diff --git a/packages/passport/sdk/src/Passport.ts b/packages/passport/sdk/src/Passport.ts index 4b1fdb37d8..4c96718489 100644 --- a/packages/passport/sdk/src/Passport.ts +++ b/packages/passport/sdk/src/Passport.ts @@ -20,6 +20,7 @@ import { PassportEventMap, PassportEvents, PassportModuleConfiguration, + User, UserProfile, } from './types'; import { ConfirmationScreen } from './confirmation'; @@ -160,7 +161,7 @@ export class Passport { anonymousId?: string; }): Promise { const { useCachedSession = false } = options || {}; - let user = null; + let user: User | null = null; try { user = await this.authManager.getUser(); } catch (error) { diff --git a/packages/passport/sdk/src/authManager.test.ts b/packages/passport/sdk/src/authManager.test.ts index 3acec6efa5..25646ba81d 100644 --- a/packages/passport/sdk/src/authManager.test.ts +++ b/packages/passport/sdk/src/authManager.test.ts @@ -6,7 +6,7 @@ import { PassportError, PassportErrorType } from './errors/passportError'; import { PassportConfiguration } from './config'; import { mockUser, mockUserImx, mockUserZkEvm } from './test/mocks'; import { isTokenExpired } from './utils/token'; -import { PassportModuleConfiguration } from './types'; +import { isUserZkEvm, PassportModuleConfiguration } from './types'; jest.mock('jwt-decode'); jest.mock('./utils/token'); @@ -416,6 +416,68 @@ describe('AuthManager', () => { }); }); }); + + describe('when the user does not meet the type assertion', () => { + it('should return null', async () => { + mockGetUser.mockReturnValue(mockOidcUser); + (isTokenExpired as jest.Mock).mockReturnValue(false); + + const result = await authManager.getUser(isUserZkEvm); + + expect(result).toBeNull(); + }); + }); + + describe('when the user does meet the type assertion', () => { + it('should return the user', async () => { + mockGetUser.mockReturnValue(mockOidcUser); + (jwt_decode as jest.Mock).mockReturnValue({ + passport: { + zkevm_eth_address: mockUserZkEvm.zkEvm.ethAddress, + zkevm_user_admin_address: mockUserZkEvm.zkEvm.userAdminAddress, + }, + }); + (isTokenExpired as jest.Mock).mockReturnValue(false); + + const result = await authManager.getUser(isUserZkEvm); + + expect(result).toEqual(mockUserZkEvm); + }); + }); + + describe('when the user is refreshing', () => { + it('should return the refreshed used', async () => { + mockSigninSilent.mockReturnValue(mockOidcUser); + + authManager.forceUserRefreshInBackground(); + + const result = await authManager.getUser(); + expect(result).toEqual(mockUser); + + expect(mockSigninSilent).toBeCalledTimes(1); + expect(mockGetUser).toBeCalledTimes(0); + }); + }); + }); + + describe('getUserZkEvm', () => { + it('should throw an error if no user is returned', async () => { + mockGetUser.mockReturnValue(null); + + await expect(() => authManager.getUserZkEvm()).rejects.toThrow( + new Error('Failed to obtain a User with the required ZkEvm attributes'), + ); + }); + }); + + describe('getUserImx', () => { + it('should throw an error if no user is returned', async () => { + mockGetUser.mockReturnValue(null); + + await expect(() => authManager.getUserImx()).rejects.toThrow( + new Error('Failed to obtain a User with the required IMX attributes'), + ); + }); }); describe('getDeviceFlowEndSessionEndpoint', () => { diff --git a/packages/passport/sdk/src/authManager.ts b/packages/passport/sdk/src/authManager.ts index 915c31ab8a..28d1d0707c 100644 --- a/packages/passport/sdk/src/authManager.ts +++ b/packages/passport/sdk/src/authManager.ts @@ -1,4 +1,6 @@ import { + ErrorResponse, + ErrorTimeout, InMemoryWebStorage, User as OidcUser, UserManager, @@ -11,7 +13,7 @@ import * as crypto from 'crypto'; import jwt_decode from 'jwt-decode'; import { getDetail, Detail } from '@imtbl/metrics'; import { isTokenExpired } from './utils/token'; -import { PassportErrorType, withPassportError } from './errors/passportError'; +import { PassportError, PassportErrorType, withPassportError } from './errors/passportError'; import { PassportMetadata, User, @@ -21,6 +23,10 @@ import { DeviceErrorResponse, IdTokenPayload, OidcConfiguration, + UserZkEvm, + isUserZkEvm, + UserImx, + isUserImx, } from './types'; import { PassportConfiguration } from './config'; @@ -85,10 +91,10 @@ function sha256(buffer: string) { export default class AuthManager { private userManager; - private config: PassportConfiguration; - private deviceCredentialsManager: DeviceCredentialsManager; + private readonly config: PassportConfiguration; + private readonly logoutMode: Exclude; /** @@ -179,7 +185,7 @@ export default class AuthManager { } public async getUserOrLogin(): Promise { - let user = null; + let user: User | null = null; try { user = await this.getUser(); } catch (err) { @@ -370,6 +376,13 @@ export default class AuthManager { return this.userManager.signoutSilentCallback(url); } + public forceUserRefreshInBackground() { + this.refreshTokenAndUpdatePromise().catch((error) => { + // eslint-disable-next-line no-console + console.warn('Failed to refresh user token', error); + }); + } + public async forceUserRefresh(): Promise { return this.refreshTokenAndUpdatePromise(); } @@ -391,7 +404,19 @@ export default class AuthManager { } resolve(null); } catch (err) { - reject(err); + let passportErrorType = PassportErrorType.AUTHENTICATION_ERROR; + let errorMessage = 'Failed to refresh token'; + + if (err instanceof ErrorTimeout) { + passportErrorType = PassportErrorType.SILENT_LOGIN_ERROR; + } else if (err instanceof ErrorResponse) { + passportErrorType = PassportErrorType.NOT_LOGGED_IN_ERROR; + errorMessage = `${err.message}: ${err.error_description}`; + } else if (err instanceof Error) { + errorMessage = err.message; + } + + reject(new PassportError(errorMessage, passportErrorType)); } finally { this.refreshingPromise = null; // Reset the promise after completion } @@ -401,23 +426,65 @@ export default class AuthManager { } /** - * Get the user from the cache or refresh the token if it's expired. - * return null if there's no refresh token. + * + * @param typeAssertion {(user: User) => boolean} - Optional. If provided, then the User will be checked against + * the typeAssertion. If the user meets the requirements, then it will be typed as T and returned. If the User + * does NOT meet the type assertion, then execution will continue, and we will attempt to obtain a User that does + * meet the type assertion. + * + * This function will attempt to obtain a User in the following order: + * 1. If the User is currently refreshing, wait for the refresh to complete. + * 2. Attempt to obtain a User from storage that has not expired. + * 3. Attempt to refresh the User if a refresh token is present. + * 4. Return null if no valid User can be obtained. */ - public async getUser(): Promise { - return withPassportError(async () => { - const oidcUser = await this.userManager.getUser(); - if (!oidcUser) return null; + public async getUser( + typeAssertion: (user: User) => user is T = (user: User): user is T => true, + ): Promise { + if (this.refreshingPromise) { + const user = await this.refreshingPromise; + if (user && typeAssertion(user)) { + return user; + } - if (!isTokenExpired(oidcUser)) { - return AuthManager.mapOidcUserToDomainModel(oidcUser); + return null; + } + + const oidcUser = await this.userManager.getUser(); + if (!oidcUser) return null; + + if (!isTokenExpired(oidcUser)) { + const user = AuthManager.mapOidcUserToDomainModel(oidcUser); + if (user && typeAssertion(user)) { + return user; } + } - if (oidcUser.refresh_token) { - return this.refreshTokenAndUpdatePromise(); + if (oidcUser.refresh_token) { + const user = await this.refreshTokenAndUpdatePromise(); + if (user && typeAssertion(user)) { + return user; } + } - return null; - }, PassportErrorType.NOT_LOGGED_IN_ERROR); + return null; + } + + public async getUserZkEvm(): Promise { + const user = await this.getUser(isUserZkEvm); + if (!user) { + throw new Error('Failed to obtain a User with the required ZkEvm attributes'); + } + + return user; + } + + public async getUserImx(): Promise { + const user = await this.getUser(isUserImx); + if (!user) { + throw new Error('Failed to obtain a User with the required IMX attributes'); + } + + return user; } } diff --git a/packages/passport/sdk/src/guardian/index.test.ts b/packages/passport/sdk/src/guardian/index.test.ts index c90a38aded..05c50ba2d6 100644 --- a/packages/passport/sdk/src/guardian/index.test.ts +++ b/packages/passport/sdk/src/guardian/index.test.ts @@ -11,57 +11,53 @@ import { PassportConfiguration } from '../config'; jest.mock('@imtbl/guardian'); jest.mock('../confirmation/confirmation'); -let guardianClient: GuardianClient; - describe('Guardian', () => { afterEach(jest.resetAllMocks); let mockGetTransactionByID: jest.Mock; let mockEvaluateTransaction: jest.Mock; let mockEvaluateMessage : jest.Mock; + let getUserImxMock: jest.Mock; + let getUserZkEvmMock: jest.Mock; const mockConfirmationScreen = new ConfirmationScreen({} as any); + const getGuardianClient = (crossSdkBridgeEnabled: boolean = false) => ( + new GuardianClient({ + confirmationScreen: mockConfirmationScreen, + config: new PassportConfiguration({ + baseConfig: {} as ImmutableConfiguration, + clientId: 'client123', + logoutRedirectUri: 'http://localhost:3000/logout', + redirectUri: 'http://localhost:3000/redirect', + crossSdkBridgeEnabled, + }), + authManager: { + getUserImx: getUserImxMock, + getUserZkEvm: getUserZkEvmMock, + } as unknown as AuthManager, + }) + ); + beforeEach(() => { mockGetTransactionByID = jest.fn(); mockEvaluateTransaction = jest.fn(); mockEvaluateMessage = jest.fn(); + getUserImxMock = jest.fn().mockReturnValue(mockUserImx); + getUserZkEvmMock = jest.fn().mockReturnValue(mockUserZkEvm); (guardian.TransactionsApi as jest.Mock).mockImplementation(() => ({ getTransactionByID: mockGetTransactionByID, evaluateTransaction: mockEvaluateTransaction, })); (guardian.MessagesApi as jest.Mock).mockImplementation(() => ({ evaluateMessage: mockEvaluateMessage })); - - guardianClient = new GuardianClient({ - confirmationScreen: mockConfirmationScreen, - config: new PassportConfiguration({ - baseConfig: {} as ImmutableConfiguration, - clientId: 'client123', - logoutRedirectUri: 'http://localhost:3000/logout', - redirectUri: 'http://localhost:3000/redirect', - }), - authManager: { getUser: jest.fn().mockResolvedValue(mockUserImx) } as unknown as AuthManager, - }); }); describe('evaluateImxTransaction', () => { - beforeAll(() => { - guardianClient = new GuardianClient({ - confirmationScreen: mockConfirmationScreen, - config: new PassportConfiguration({ - baseConfig: {} as ImmutableConfiguration, - clientId: 'client123', - logoutRedirectUri: 'http://localhost:3000/logout', - redirectUri: 'http://localhost:3000/redirect', - }), - authManager: { getUser: jest.fn().mockResolvedValue(mockUserImx) } as unknown as AuthManager, - }); - }); afterEach(jest.clearAllMocks); it('should retry getting transaction details and throw an error when transaction does not exist', async () => { mockGetTransactionByID.mockResolvedValue({ data: { id: '1234' } }); mockEvaluateTransaction.mockResolvedValue({ data: { confirmationRequired: false } }); - await guardianClient.evaluateImxTransaction({ payloadHash: 'hash' }); + await getGuardianClient().evaluateImxTransaction({ payloadHash: 'hash' }); expect(mockConfirmationScreen.requestConfirmation).toBeCalledTimes(0); expect(mockEvaluateTransaction).toBeCalledWith({ @@ -76,7 +72,7 @@ describe('Guardian', () => { mockGetTransactionByID.mockResolvedValue({ data: { id: '1234' } }); mockEvaluateTransaction.mockResolvedValue({ data: { confirmationRequired: false } }); - await guardianClient.evaluateImxTransaction({ payloadHash: 'hash' }); + await getGuardianClient().evaluateImxTransaction({ payloadHash: 'hash' }); expect(mockConfirmationScreen.requestConfirmation).toBeCalledTimes(0); }); @@ -87,7 +83,7 @@ describe('Guardian', () => { .mockResolvedValueOnce({ data: { confirmationRequired: true } }); (mockConfirmationScreen.requestConfirmation as jest.Mock).mockResolvedValueOnce({ confirmed: true }); - await guardianClient.evaluateImxTransaction({ payloadHash: 'hash' }); + await getGuardianClient().evaluateImxTransaction({ payloadHash: 'hash' }); expect(mockConfirmationScreen.requestConfirmation).toHaveBeenCalledWith('hash', mockUserImx.imx.ethAddress, 'starkex'); }); @@ -98,7 +94,7 @@ describe('Guardian', () => { .mockResolvedValueOnce({ data: { confirmationRequired: true } }); (mockConfirmationScreen.requestConfirmation as jest.Mock).mockResolvedValueOnce({ confirmed: false }); - await expect(guardianClient.evaluateImxTransaction({ payloadHash: 'hash' })).rejects.toThrow('Transaction rejected by user'); + await expect(getGuardianClient().evaluateImxTransaction({ payloadHash: 'hash' })).rejects.toThrow('Transaction rejected by user'); }); describe('crossSdkBridgeEnabled', () => { @@ -107,17 +103,7 @@ describe('Guardian', () => { mockEvaluateTransaction .mockResolvedValueOnce({ data: { confirmationRequired: true } }); - guardianClient = new GuardianClient({ - confirmationScreen: mockConfirmationScreen, - config: new PassportConfiguration({ - baseConfig: {} as ImmutableConfiguration, - clientId: 'client123', - logoutRedirectUri: 'http://localhost:3000/logout', - redirectUri: 'http://localhost:3000/redirect', - crossSdkBridgeEnabled: true, - }), - authManager: { getUser: jest.fn().mockResolvedValue(mockUserImx) } as unknown as AuthManager, - }); + const guardianClient = getGuardianClient(true); await expect(guardianClient.evaluateImxTransaction({ payloadHash: 'hash' })) .rejects @@ -128,18 +114,6 @@ describe('Guardian', () => { describe('validateEVMTransaction', () => { afterEach(jest.resetAllMocks); - beforeEach(() => { - guardianClient = new GuardianClient({ - confirmationScreen: mockConfirmationScreen, - config: new PassportConfiguration({ - baseConfig: {} as ImmutableConfiguration, - clientId: 'client123', - logoutRedirectUri: 'http://localhost:3000/logout', - redirectUri: 'http://localhost:3000/redirect', - }), - authManager: { getUser: jest.fn().mockResolvedValue(mockUserZkEvm) } as unknown as AuthManager, - }); - }); it('throws an error if the request data fails to be parsed', async () => { const transactionRequest: TransactionRequest = { to: mockUserZkEvm.zkEvm.ethAddress, @@ -148,7 +122,7 @@ describe('Guardian', () => { }; await expect( - guardianClient.validateEVMTransaction({ + getGuardianClient().validateEVMTransaction({ chainId: 'epi123', nonce: '5', metaTransactions: [ @@ -184,7 +158,7 @@ describe('Guardian', () => { mockEvaluateTransaction.mockResolvedValue({ data: { confirmationRequired: false } }); - await guardianClient.validateEVMTransaction({ + await getGuardianClient().validateEVMTransaction({ chainId: 'epi123', nonce: '5', metaTransactions: [ @@ -228,17 +202,6 @@ describe('Guardian', () => { describe('crossSdkBridgeEnabled', () => { it('throws an error if confirmation is required and the cross sdk bridge flag is enabled', async () => { mockEvaluateTransaction.mockResolvedValue({ data: { confirmationRequired: true } }); - guardianClient = new GuardianClient({ - confirmationScreen: mockConfirmationScreen, - config: new PassportConfiguration({ - baseConfig: {} as ImmutableConfiguration, - clientId: 'client123', - logoutRedirectUri: 'http://localhost:3000/logout', - redirectUri: 'http://localhost:3000/redirect', - crossSdkBridgeEnabled: true, - }), - authManager: { getUser: jest.fn().mockResolvedValue(mockUserZkEvm) } as unknown as AuthManager, - }); const transactionRequest: TransactionRequest = { to: mockUserZkEvm.zkEvm.ethAddress, @@ -247,7 +210,7 @@ describe('Guardian', () => { }; await expect( - guardianClient.validateEVMTransaction({ + getGuardianClient(true).validateEVMTransaction({ chainId: 'epi123', nonce: '5', metaTransactions: [ @@ -279,13 +242,13 @@ describe('Guardian', () => { describe('withConfirmationScreenTask', () => { it('should call the task and close the confirmation screen if the task fails', async () => { const mockTask = jest.fn().mockRejectedValueOnce(new Error('Task failed')); - await expect(guardianClient.withConfirmationScreenTask()(mockTask)()).rejects.toThrow('Task failed'); + await expect(getGuardianClient().withConfirmationScreenTask()(mockTask)()).rejects.toThrow('Task failed'); expect(mockConfirmationScreen.closeWindow).toBeCalledTimes(1); }); it('should call the task and return the result if the task succeeds', async () => { const mockTask = jest.fn().mockResolvedValueOnce('result'); - const wrappedTask = guardianClient.withConfirmationScreenTask()(mockTask); + const wrappedTask = getGuardianClient().withConfirmationScreenTask()(mockTask); await expect(wrappedTask()).resolves.toEqual('result'); @@ -296,13 +259,13 @@ describe('Guardian', () => { it('should call the task and close the confirmation screen if the task fails', async () => { const mockTask = jest.fn().mockRejectedValueOnce(new Error('Task failed')); - await expect(guardianClient.withConfirmationScreen()(mockTask)).rejects.toThrow('Task failed'); + await expect(getGuardianClient().withConfirmationScreen()(mockTask)).rejects.toThrow('Task failed'); expect(mockConfirmationScreen.closeWindow).toBeCalledTimes(1); }); it('should call the task and return the result if the task succeeds', async () => { const mockTask = jest.fn().mockResolvedValueOnce('result'); - const promise = guardianClient.withConfirmationScreen()(mockTask); + const promise = getGuardianClient().withConfirmationScreen()(mockTask); await expect(promise).resolves.toEqual('result'); expect(mockConfirmationScreen.closeWindow).toBeCalledTimes(0); @@ -313,13 +276,13 @@ describe('Guardian', () => { it('should call the task and close the confirmation screen if the task fails', async () => { const mockTask = jest.fn().mockRejectedValueOnce(new Error('Task failed')); - await expect(guardianClient.withDefaultConfirmationScreenTask(mockTask)()).rejects.toThrow('Task failed'); + await expect(getGuardianClient().withDefaultConfirmationScreenTask(mockTask)()).rejects.toThrow('Task failed'); expect(mockConfirmationScreen.closeWindow).toBeCalledTimes(1); }); it('should call the task and return the result if the task succeeds', async () => { const mockTask = jest.fn().mockResolvedValueOnce('result'); - const wrappedTask = guardianClient.withDefaultConfirmationScreenTask(mockTask); + const wrappedTask = getGuardianClient().withDefaultConfirmationScreenTask(mockTask); await expect(wrappedTask()).resolves.toEqual('result'); expect(mockConfirmationScreen.closeWindow).toBeCalledTimes(0); @@ -329,34 +292,23 @@ describe('Guardian', () => { describe('validateMessage', () => { afterEach(jest.resetAllMocks); - beforeEach(() => { - guardianClient = new GuardianClient({ - confirmationScreen: mockConfirmationScreen, - config: new PassportConfiguration({ - baseConfig: {} as ImmutableConfiguration, - clientId: 'client123', - logoutRedirectUri: 'http://localhost:3000/logout', - redirectUri: 'http://localhost:3000/redirect', - }), - authManager: { getUser: jest.fn().mockResolvedValue(mockUserZkEvm) } as unknown as AuthManager, - }); - }); + const mockPayload = { chainID: '0x1234', payload: {} as guardian.EIP712Message, user: mockUserZkEvm }; it('surfaces error message if message evaluation fails', async () => { mockEvaluateMessage.mockRejectedValueOnce(new Error('401: Unauthorized')); - await expect(guardianClient.validateMessage(mockPayload)) + await expect(getGuardianClient().validateMessage(mockPayload)) .rejects.toThrow('Message failed to validate with error: 401: Unauthorized'); }); it('displays confirmation screen if confirmation is required', async () => { mockEvaluateMessage.mockResolvedValueOnce({ data: { confirmationRequired: true, messageId: 'asd123' } }); (mockConfirmationScreen.requestMessageConfirmation as jest.Mock).mockResolvedValueOnce({ confirmed: true }); - await guardianClient.validateMessage(mockPayload); + await getGuardianClient().validateMessage(mockPayload); expect(mockConfirmationScreen.requestMessageConfirmation).toBeCalledTimes(1); }); it('displays rejection error message if user rejects confirmation', async () => { mockEvaluateMessage.mockResolvedValueOnce({ data: { confirmationRequired: true, messageId: 'asd123' } }); (mockConfirmationScreen.requestMessageConfirmation as jest.Mock).mockResolvedValueOnce({ confirmed: false }); - await expect(guardianClient.validateMessage(mockPayload)).rejects.toEqual(new JsonRpcError(RpcErrorCode.TRANSACTION_REJECTED, 'Signature rejected by user')); + await expect(getGuardianClient().validateMessage(mockPayload)).rejects.toEqual(new JsonRpcError(RpcErrorCode.TRANSACTION_REJECTED, 'Signature rejected by user')); }); }); }); diff --git a/packages/passport/sdk/src/guardian/index.ts b/packages/passport/sdk/src/guardian/index.ts index 64188c138f..01bd7c866d 100644 --- a/packages/passport/sdk/src/guardian/index.ts +++ b/packages/passport/sdk/src/guardian/index.ts @@ -7,7 +7,6 @@ import { ConfirmationScreen } from '../confirmation'; import { retryWithDelay } from '../network/retry'; import { JsonRpcError, RpcErrorCode } from '../zkEvm/JsonRpcError'; import { MetaTransaction, TypedDataPayload } from '../zkEvm/types'; -import { UserImx, UserZkEvm } from '../types'; import { PassportConfiguration } from '../config'; export type GuardianClientParams = { @@ -114,7 +113,7 @@ export default class GuardianClient { const finallyFn = () => { this.confirmationScreen.closeWindow(); }; - const user = await this.authManager.getUser() as UserImx; + const user = await this.authManager.getUserImx(); const headers = { Authorization: `Bearer ${user.accessToken}` }; const transactionRes = await retryWithDelay( @@ -161,7 +160,7 @@ export default class GuardianClient { nonce, metaTransactions, }: GuardianEVMValidationParams): Promise { - const user = await this.authManager.getUser() as UserZkEvm; + const user = await this.authManager.getUserZkEvm(); const headers = { Authorization: `Bearer ${user.accessToken}` }; const guardianTransactions = transformGuardianTransactions(metaTransactions); try { @@ -210,7 +209,7 @@ export default class GuardianClient { } if (confirmationRequired && !!transactionId) { - const user = await this.authManager.getUser() as UserZkEvm; + const user = await this.authManager.getUserZkEvm(); const confirmationResult = await this.confirmationScreen.requestConfirmation( transactionId, user.zkEvm.ethAddress, @@ -233,7 +232,7 @@ export default class GuardianClient { { chainID, payload }:GuardianMessageValidationParams, ): Promise { try { - const user = await this.authManager.getUser() as UserZkEvm; + const user = await this.authManager.getUserZkEvm(); if (user === null) { throw new PassportError('evaluateMessage requires a valid ID token or refresh token. Please log in first', PassportErrorType.NOT_LOGGED_IN_ERROR); } @@ -254,7 +253,7 @@ export default class GuardianClient { throw new JsonRpcError(RpcErrorCode.TRANSACTION_REJECTED, transactionRejectedCrossSdkBridgeError); } if (confirmationRequired && !!messageId) { - const user = await this.authManager.getUser() as UserZkEvm; + const user = await this.authManager.getUserZkEvm(); const confirmationResult = await this.confirmationScreen.requestMessageConfirmation( messageId, user.zkEvm.ethAddress, diff --git a/packages/passport/sdk/src/mocks/zkEvm/msw.ts b/packages/passport/sdk/src/mocks/zkEvm/msw.ts index bbe4ea3433..5684b2d3ad 100644 --- a/packages/passport/sdk/src/mocks/zkEvm/msw.ts +++ b/packages/passport/sdk/src/mocks/zkEvm/msw.ts @@ -4,7 +4,7 @@ import { SetupServer, setupServer } from 'msw/node'; import { ChainName } from 'network/chains'; import { RelayerTransactionRequest } from '../../zkEvm/relayerClient'; import { JsonRpcRequestPayload } from '../../zkEvm/types'; -import { chainId, chainIdHex } from '../../test/mocks'; +import { chainId, chainIdHex, mockUserZkEvm } from '../../test/mocks'; export const relayerId = '0x745'; export const transactionHash = '0x867'; @@ -34,7 +34,12 @@ export const mswHandlers = { counterfactualAddress: { success: rest.post( `https://api.sandbox.immutable.com/v2/chains/${chainName}/passport/counterfactual-address`, - (req, res, ctx) => res(ctx.status(201)), + (req, res, ctx) => res( + ctx.status(201), + ctx.json({ + counterfactual_address: mockUserZkEvm.zkEvm.ethAddress, + }), + ), ), internalServerError: rest.post( `https://api.sandbox.immutable.com/v2/chains/${chainName}/passport/counterfactual-address`, diff --git a/packages/passport/sdk/src/starkEx/passportImxProvider.ts b/packages/passport/sdk/src/starkEx/passportImxProvider.ts index 404a93e6be..21249b062c 100644 --- a/packages/passport/sdk/src/starkEx/passportImxProvider.ts +++ b/packages/passport/sdk/src/starkEx/passportImxProvider.ts @@ -26,7 +26,7 @@ import TypedEventEmitter from '../utils/typedEventEmitter'; import AuthManager from '../authManager'; import GuardianClient from '../guardian'; import { - PassportEventMap, PassportEvents, UserImx, User, IMXSigners, + PassportEventMap, PassportEvents, UserImx, User, IMXSigners, isUserImx, } from '../types'; import { PassportError, PassportErrorType } from '../errors/passportError'; import { @@ -112,7 +112,7 @@ export class PassportImxProvider implements IMXProvider { * @see getAuthenticatedUserAndSigners * */ - private async initialiseSigners(): Promise { + private initialiseSigners() { const generateSigners = async (): Promise => { const user = await this.authManager.getUser(); // The user will be present because the factory validates it @@ -160,7 +160,6 @@ export class PassportImxProvider implements IMXProvider { protected async getRegisteredImxUserAndSigners(): Promise { const { user, starkSigner, ethSigner } = await this.getAuthenticatedUserAndSigners(); - const isUserImx = (oidcUser: User | null): oidcUser is UserImx => oidcUser?.imx !== undefined; if (!isUserImx(user)) { throw new PassportError( diff --git a/packages/passport/sdk/src/types.ts b/packages/passport/sdk/src/types.ts index 4991cb9b93..cdb7132aea 100644 --- a/packages/passport/sdk/src/types.ts +++ b/packages/passport/sdk/src/types.ts @@ -80,6 +80,9 @@ type WithRequired = T & { [P in K]-?: T[P] }; export type UserImx = WithRequired; export type UserZkEvm = WithRequired; +export const isUserZkEvm = (user: User): user is UserZkEvm => !!user.zkEvm; +export const isUserImx = (user: User): user is UserImx => !!user.imx; + // Device code auth export type DeviceConnectResponse = { diff --git a/packages/passport/sdk/src/zkEvm/relayerClient.test.ts b/packages/passport/sdk/src/zkEvm/relayerClient.test.ts index 708580e8ed..33f554716d 100644 --- a/packages/passport/sdk/src/zkEvm/relayerClient.test.ts +++ b/packages/passport/sdk/src/zkEvm/relayerClient.test.ts @@ -20,7 +20,9 @@ describe('relayerClient', () => { const relayerClient = new RelayerClient({ config: config as PassportConfiguration, jsonRpcProvider: jsonRpcProvider as JsonRpcProvider, - authManager: { getUser: jest.fn().mockResolvedValue(user as UserZkEvm) } as unknown as AuthManager, + authManager: { + getUserZkEvm: jest.fn().mockResolvedValue(user as UserZkEvm), + } as unknown as AuthManager, }); let originalFetch: any; diff --git a/packages/passport/sdk/src/zkEvm/relayerClient.ts b/packages/passport/sdk/src/zkEvm/relayerClient.ts index 4bd884252b..1b5f95f6ae 100644 --- a/packages/passport/sdk/src/zkEvm/relayerClient.ts +++ b/packages/passport/sdk/src/zkEvm/relayerClient.ts @@ -95,12 +95,12 @@ export class RelayerClient { ...request, }; - const user = await this.authManager.getUser(); + const user = await this.authManager.getUserZkEvm(); const response = await fetch(`${this.config.relayerUrl}/v1/transactions`, { method: 'POST', headers: { - Authorization: `Bearer ${user?.accessToken}`, + Authorization: `Bearer ${user.accessToken}`, 'Content-Type': 'application/json', }, body: JSON.stringify(body), diff --git a/packages/passport/sdk/src/zkEvm/user/registerZkEvmUser.test.ts b/packages/passport/sdk/src/zkEvm/user/registerZkEvmUser.test.ts index 90a2d8d7d4..8702128fd0 100644 --- a/packages/passport/sdk/src/zkEvm/user/registerZkEvmUser.test.ts +++ b/packages/passport/sdk/src/zkEvm/user/registerZkEvmUser.test.ts @@ -5,7 +5,7 @@ import { MultiRollupApiClients } from '@imtbl/generated-clients'; import { ChainId, ChainName } from 'network/chains'; import { registerZkEvmUser } from './registerZkEvmUser'; import AuthManager from '../../authManager'; -import { mockListChains, mockUser, mockUserZkEvm } from '../../test/mocks'; +import { mockListChains, mockUserZkEvm } from '../../test/mocks'; jest.mock('@ethersproject/providers'); jest.mock('@ethersproject/abstract-signer'); @@ -17,7 +17,7 @@ describe('registerZkEvmUser', () => { }; const authManager = { getUser: jest.fn(), - forceUserRefresh: jest.fn(), + forceUserRefreshInBackground: jest.fn(), }; const multiRollupApiClients = { passportApi: { @@ -59,49 +59,14 @@ describe('registerZkEvmUser', () => { }); }); - describe('when getUser fails to return a user', () => { - it('should throw an error', async () => { - multiRollupApiClients.passportApi.createCounterfactualAddressV2.mockResolvedValue({ - status: 201, - }); - - authManager.getUser.mockResolvedValue(null); - - await expect(async () => registerZkEvmUser({ - authManager: authManager as unknown as AuthManager, - ethSigner: ethSignerMock as unknown as Signer, - multiRollupApiClients: multiRollupApiClients as unknown as MultiRollupApiClients, - accessToken, - jsonRpcProvider: jsonRPCProvider as unknown as JsonRpcProvider, - })).rejects.toThrow('Failed to refresh user details'); - }); - }); - - describe('when getUser returns a user that has not registered with zkEvm', () => { - it('should throw an error', async () => { - multiRollupApiClients.passportApi.createCounterfactualAddressV2.mockResolvedValue({ - status: 201, - }); - - authManager.getUser.mockResolvedValue(mockUser); - - await expect(async () => registerZkEvmUser({ - authManager: authManager as unknown as AuthManager, - ethSigner: ethSignerMock as unknown as Signer, - multiRollupApiClients: multiRollupApiClients as unknown as MultiRollupApiClients, - accessToken, - jsonRpcProvider: jsonRPCProvider as unknown as JsonRpcProvider, - })).rejects.toThrow('Failed to refresh user details'); - }); - }); - it('should return a user that has registered with zkEvm', async () => { multiRollupApiClients.passportApi.createCounterfactualAddressV2.mockResolvedValue({ status: 201, + data: { + counterfactual_address: mockUserZkEvm.zkEvm.ethAddress, + }, }); - authManager.forceUserRefresh.mockResolvedValue(mockUserZkEvm); - const result = await registerZkEvmUser({ authManager: authManager as unknown as AuthManager, ethSigner: ethSignerMock as unknown as Signer, @@ -110,7 +75,7 @@ describe('registerZkEvmUser', () => { jsonRpcProvider: jsonRPCProvider as unknown as JsonRpcProvider, }); - expect(result).toEqual(mockUserZkEvm); + expect(result).toEqual(mockUserZkEvm.zkEvm.ethAddress); expect(multiRollupApiClients.passportApi.createCounterfactualAddressV2).toHaveBeenCalledWith({ chainName: ChainName.IMTBL_ZKEVM_TESTNET, createCounterfactualAddressRequest: { @@ -122,6 +87,6 @@ describe('registerZkEvmUser', () => { Authorization: `Bearer ${accessToken}`, }, }); - expect(authManager.forceUserRefresh).toHaveBeenCalledTimes(1); + expect(authManager.forceUserRefreshInBackground).toHaveBeenCalledTimes(1); }); }); diff --git a/packages/passport/sdk/src/zkEvm/user/registerZkEvmUser.ts b/packages/passport/sdk/src/zkEvm/user/registerZkEvmUser.ts index f24d1378ff..565d09ed26 100644 --- a/packages/passport/sdk/src/zkEvm/user/registerZkEvmUser.ts +++ b/packages/passport/sdk/src/zkEvm/user/registerZkEvmUser.ts @@ -3,7 +3,6 @@ import { MultiRollupApiClients } from '@imtbl/generated-clients'; import { signRaw } from '@imtbl/toolkit'; import { getEip155ChainId } from 'zkEvm/walletHelpers'; import { Signer } from '@ethersproject/abstract-signer'; -import { UserZkEvm } from '../../types'; import AuthManager from '../../authManager'; import { JsonRpcError, RpcErrorCode } from '../JsonRpcError'; @@ -23,7 +22,7 @@ export async function registerZkEvmUser({ multiRollupApiClients, accessToken, jsonRpcProvider, -}: RegisterZkEvmUserInput): Promise { +}: RegisterZkEvmUserInput): Promise { const [ethereumAddress, ethereumSignature, network, chainListResponse] = await Promise.all([ ethSigner.getAddress(), signRaw(MESSAGE_TO_SIGN, ethSigner), @@ -41,7 +40,7 @@ export async function registerZkEvmUser({ } try { - await multiRollupApiClients.passportApi.createCounterfactualAddressV2({ + const registrationResponse = await multiRollupApiClients.passportApi.createCounterfactualAddressV2({ chainName, createCounterfactualAddressRequest: { ethereum_address: ethereumAddress, @@ -50,14 +49,11 @@ export async function registerZkEvmUser({ }, { headers: { Authorization: `Bearer ${accessToken}` }, }); + + authManager.forceUserRefreshInBackground(); + + return registrationResponse.data.counterfactual_address; } catch (error) { throw new JsonRpcError(RpcErrorCode.INTERNAL_ERROR, `Failed to create counterfactual address: ${error}`); } - - const user = await authManager.forceUserRefresh(); - if (!user?.zkEvm) { - throw new JsonRpcError(RpcErrorCode.INTERNAL_ERROR, 'Failed to refresh user details'); - } - - return user as UserZkEvm; } diff --git a/packages/passport/sdk/src/zkEvm/zkEvmProvider.ts b/packages/passport/sdk/src/zkEvm/zkEvmProvider.ts index de14d2ec82..314cb69dd0 100644 --- a/packages/passport/sdk/src/zkEvm/zkEvmProvider.ts +++ b/packages/passport/sdk/src/zkEvm/zkEvmProvider.ts @@ -128,7 +128,7 @@ export class ZkEvmProvider implements Provider { * @see #getSigner * */ - async #initialiseEthSigner(user: User) { + #initialiseEthSigner(user: User) { const generateSigner = async (): Promise => { const magicRpcProvider = await this.#magicAdapter.login(user.idToken!); const web3Provider = new Web3Provider(magicRpcProvider); @@ -173,16 +173,13 @@ export class ZkEvmProvider implements Provider { if (!isZkEvmUser(user)) { const ethSigner = await this.#getSigner(); - - const userZkEvm = await registerZkEvmUser({ + this.#zkEvmAddress = await registerZkEvmUser({ ethSigner, authManager: this.#authManager, multiRollupApiClients: this.#multiRollupApiClients, accessToken: user.accessToken, jsonRpcProvider: this.#jsonRpcProvider, }); - - this.#zkEvmAddress = userZkEvm.zkEvm.ethAddress; } else { this.#zkEvmAddress = user.zkEvm.ethAddress; }