Skip to content

Commit

Permalink
fix wallet connect provider chain id validation
Browse files Browse the repository at this point in the history
  • Loading branch information
jhesgodi committed May 30, 2024
1 parent 622dc96 commit 763cc76
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 12 deletions.
1 change: 1 addition & 0 deletions packages/checkout/sdk/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export { IMMUTABLE_API_BASE_URL } from './env';
export {
getPassportProviderDetail,
getMetaMaskProviderDetail,
validateProvider,
} from './provider';

export {
Expand Down
70 changes: 66 additions & 4 deletions packages/checkout/sdk/src/provider/getUnderlyingProvider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,23 @@ import { WalletAction } from '../types/wallet';
import { CheckoutErrorType } from '../errors';

describe('getUnderlyingChainId', () => {
it('should return the underlying chain id', async () => {
it('should return underlying chain id from property', async () => {
const provider = {
provider: {
request: jest.fn().mockResolvedValue('0xAA36A7'),
chainId: ChainId.SEPOLIA,
request: jest.fn(),
},
} as unknown as Web3Provider;

const chainId = await getUnderlyingChainId(provider);
expect(chainId).toEqual(ChainId.SEPOLIA);
expect(provider.provider.request).not.toBeCalled();
});

it('should return the underlying chain id from rpc call', async () => {
const provider = {
provider: {
request: jest.fn().mockResolvedValue('0xaa36a7'),
},
} as unknown as Web3Provider;

Expand All @@ -20,11 +33,35 @@ describe('getUnderlyingChainId', () => {
});
});

it('should properly parse chain id', async () => {
const intChainId = 13473;
const strChainId = intChainId.toString();
const hexChainId = `0x${intChainId.toString(16)}`;
const getMockProvider = (chainId: unknown) => ({ provider: { chainId } } as unknown as Web3Provider);

// Number
expect(await getUnderlyingChainId(getMockProvider(intChainId))).toEqual(
intChainId,
);

// String to Number
expect(await getUnderlyingChainId(getMockProvider(strChainId))).toEqual(
intChainId,
);

// Hex to Number
expect(await getUnderlyingChainId(getMockProvider(hexChainId))).toEqual(
intChainId,
);
});

it('should throw an error if provider missing from web3provider', async () => {
try {
await getUnderlyingChainId({} as Web3Provider);
} catch (err: any) {
expect(err.message).toEqual('Parsed provider is not a valid Web3Provider');
expect(err.message).toEqual(
'Parsed provider is not a valid Web3Provider',
);
expect(err.type).toEqual(CheckoutErrorType.WEB3_PROVIDER_ERROR);
}
});
Expand All @@ -33,8 +70,33 @@ describe('getUnderlyingChainId', () => {
try {
await getUnderlyingChainId({ provider: {} } as Web3Provider);
} catch (err: any) {
expect(err.message).toEqual('Parsed provider is not a valid Web3Provider');
expect(err.message).toEqual(
'Parsed provider is not a valid Web3Provider',
);
expect(err.type).toEqual(CheckoutErrorType.WEB3_PROVIDER_ERROR);
}
});

it('should throw an error if invalid chain id value from property', async () => {
const provider = {
provider: {
chainId: 'invalid',
request: jest.fn(),
},
} as unknown as Web3Provider;

expect(provider.provider.request).not.toHaveBeenCalled();
expect(getUnderlyingChainId(provider)).rejects.toThrow('Invalid chainId');
});

it('should throw an error if invalid chain id value returned from rpc call ', async () => {
const provider = {
provider: {
request: jest.fn().mockResolvedValue('invalid'),
},
} as unknown as Web3Provider;

expect(getUnderlyingChainId(provider)).rejects.toThrow('Invalid chainId');
expect(provider.provider.request).toHaveBeenCalled();
});
});
43 changes: 39 additions & 4 deletions packages/checkout/sdk/src/provider/getUnderlyingProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,53 @@ import { Web3Provider } from '@ethersproject/providers';
import { CheckoutError, CheckoutErrorType } from '../errors';
import { WalletAction } from '../types';

// this gives us access to the properties of the underlying provider object
export async function getUnderlyingChainId(web3Provider: Web3Provider) {
const parseChainId = (chainId: unknown): number => {
if (typeof chainId === 'number') {
return chainId;
}

if (typeof chainId === 'string' && !Number.isNaN(Number(chainId))) {
return chainId.startsWith('0x') ? parseInt(chainId, 16) : Number(chainId);
}

throw new CheckoutError(
'Invalid chainId',
CheckoutErrorType.WEB3_PROVIDER_ERROR,
);
};

/**
* Get chain id from RPC method
* @param web3Provider
* @returns chainId number
*/
async function requestChainId(web3Provider: Web3Provider): Promise<number> {
if (!web3Provider.provider?.request) {
throw new CheckoutError(
'Parsed provider is not a valid Web3Provider',
CheckoutErrorType.WEB3_PROVIDER_ERROR,
);
}

const chainId = await web3Provider.provider.request({
const chainId: string = await web3Provider.provider.request({
method: WalletAction.GET_CHAINID,
params: [],
});
return parseInt(chainId, 16);

return parseChainId(chainId);
}

/**
* Get the underlying chain id from the provider
* @param web3Provider
* @returns chainId number
*/
export async function getUnderlyingChainId(web3Provider: Web3Provider): Promise<number> {
const chainId = (web3Provider.provider as any)?.chainId;

if (chainId) {
return parseChainId(chainId);
}

return requestChainId(web3Provider);
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Box } from '@biom3/react';
import EthereumProvider from '@walletconnect/ethereum-provider';
import {
ChainId,
CheckoutErrorType,
Expand Down Expand Up @@ -200,20 +201,25 @@ export function WalletList(props: WalletListProps) {
[checkout],
);

const connectCallback = async (ethereumProvider) => {
const connectCallback = async (ethereumProvider: EthereumProvider) => {
if (ethereumProvider.connected && ethereumProvider.session) {
const web3Provider = new Web3Provider(ethereumProvider as any);
const web3Provider = new Web3Provider(ethereumProvider);
selectWeb3Provider(web3Provider, 'walletconnect');

const chainId = await web3Provider.getSigner().getChainId();

if (ethereumProvider.chainId !== targetChainId) {
// @ts-ignore allow protected method `switchEthereumChain` to be called
await ethereumProvider.switchEthereumChain(targetChainId);
}

if (chainId !== targetChainId) {
viewDispatch({
payload: {
type: ViewActions.UPDATE_VIEW,
view: { type: ConnectWidgetViews.SWITCH_NETWORK },
},
});
return;
}

viewDispatch({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,19 @@ export function SaleUI() {
[passportInstance]
);
const factory = useMemo(
() => new WidgetsFactory(checkout, { theme: WidgetTheme.DARK }),
() =>
new WidgetsFactory(checkout, {
theme: WidgetTheme.DARK,
walletConnect: {
projectId: "938b553484e344b1e0b4bb80edf8c362",
metadata: {
name: "Checkout Marketplace",
description: "Checkout Marketplace",
url: "http://localhost:3000/marketplace-orchestrator",
icons: [],
},
},
}),
[checkout]
);
const saleWidget = useMemo(
Expand Down

0 comments on commit 763cc76

Please sign in to comment.