Skip to content

Commit

Permalink
refactor(NetworkStateManager): update NetworkStateManager with new …
Browse files Browse the repository at this point in the history
…configuration (#369)

* chore: update network state mgr with default network config

* chore: lint fix

* chore: update function comment

* chore: lint fix
  • Loading branch information
stanleyyconsensys authored Oct 8, 2024
1 parent c96f75e commit 5b2696f
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 24 deletions.
16 changes: 15 additions & 1 deletion packages/starknet-snap/src/config.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
import { SnapEnv } from './utils/constants';
import type { Network } from './types/snapState';
import {
SnapEnv,
STARKNET_MAINNET_NETWORK,
STARKNET_SEPOLIA_TESTNET_NETWORK,
} from './utils/constants';
import { LogLevel } from './utils/logger';

export type SnapConfig = {
logLevel: string;
snapEnv: SnapEnv;
defaultNetwork: Network;
availableNetworks: Network[];
};

export const Config: SnapConfig = {
// eslint-disable-next-line no-restricted-globals
logLevel: process.env.LOG_LEVEL ?? LogLevel.OFF.valueOf().toString(),
// eslint-disable-next-line no-restricted-globals
snapEnv: (process.env.SNAP_ENV ?? SnapEnv.Prod) as unknown as SnapEnv,

defaultNetwork: STARKNET_MAINNET_NETWORK,

availableNetworks: [
STARKNET_MAINNET_NETWORK,
STARKNET_SEPOLIA_TESTNET_NETWORK,
],
};
49 changes: 36 additions & 13 deletions packages/starknet-snap/src/state/network-state-manager.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { constants } from 'starknet';

import { Config } from '../config';
import type { Network } from '../types/snapState';
import {
STARKNET_MAINNET_NETWORK,
STARKNET_SEPOLIA_TESTNET_NETWORK,
STARKNET_TESTNET_NETWORK,
} from '../utils/constants';
import { mockState } from './__tests__/helper';
import { NetworkStateManager, ChainIdFilter } from './network-state-manager';
Expand All @@ -14,7 +16,7 @@ describe('NetworkStateManager', () => {
it('returns the network', async () => {
const chainId = constants.StarknetChainId.SN_SEPOLIA;
await mockState({
networks: [STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK],
networks: Config.availableNetworks,
});

const stateManager = new NetworkStateManager();
Expand All @@ -25,15 +27,27 @@ describe('NetworkStateManager', () => {
expect(result).toStrictEqual(STARKNET_SEPOLIA_TESTNET_NETWORK);
});

it('returns null if the network can not be found', async () => {
const chainId = constants.StarknetChainId.SN_SEPOLIA;
it('looks up the configuration if the network cant be found in state', async () => {
await mockState({
networks: [STARKNET_MAINNET_NETWORK],
});

const stateManager = new NetworkStateManager();
const result = await stateManager.getNetwork({
chainId,
chainId: STARKNET_SEPOLIA_TESTNET_NETWORK.chainId,
});

expect(result).toStrictEqual(STARKNET_SEPOLIA_TESTNET_NETWORK);
});

it('returns null if the network can not be found', async () => {
await mockState({
networks: Config.availableNetworks,
});

const stateManager = new NetworkStateManager();
const result = await stateManager.getNetwork({
chainId: '0x9999',
});

expect(result).toBeNull();
Expand Down Expand Up @@ -103,7 +117,7 @@ describe('NetworkStateManager', () => {
it('returns the list of network by chainId', async () => {
const chainId = constants.StarknetChainId.SN_SEPOLIA;
await mockState({
networks: [STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK],
networks: Config.availableNetworks,
});

const stateManager = new NetworkStateManager();
Expand Down Expand Up @@ -163,7 +177,7 @@ describe('NetworkStateManager', () => {
describe('getCurrentNetwork', () => {
it('get the current network', async () => {
await mockState({
networks: [STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK],
networks: Config.availableNetworks,
currentNetwork: STARKNET_MAINNET_NETWORK,
});

Expand All @@ -173,15 +187,27 @@ describe('NetworkStateManager', () => {
expect(result).toStrictEqual(STARKNET_MAINNET_NETWORK);
});

it('returns null if the current network is null or undefined', async () => {
it(`returns default network if the current network is null or undefined`, async () => {
await mockState({
networks: [STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK],
networks: Config.availableNetworks,
});

const stateManager = new NetworkStateManager();
const result = await stateManager.getCurrentNetwork();

expect(result).toBeNull();
expect(result).toStrictEqual(Config.defaultNetwork);
});

it(`returns default network if the current network is neither mainnet or sepolia testnet`, async () => {
await mockState({
networks: Config.availableNetworks,
currentNetwork: STARKNET_TESTNET_NETWORK,
});

const stateManager = new NetworkStateManager();
const result = await stateManager.getCurrentNetwork();

expect(result).toStrictEqual(Config.defaultNetwork);
});
});

Expand Down Expand Up @@ -213,10 +239,7 @@ describe('NetworkStateManager', () => {
updateTo: Network;
}) => {
const { state } = await mockState({
networks: [
STARKNET_MAINNET_NETWORK,
STARKNET_SEPOLIA_TESTNET_NETWORK,
],
networks: Config.availableNetworks,
currentNetwork,
});

Expand Down
34 changes: 27 additions & 7 deletions packages/starknet-snap/src/state/network-state-manager.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { assert, string } from 'superstruct';

import { Config } from '../config';
import type { Network, SnapState } from '../types/snapState';
import type { IFilter } from './filter';
import { ChainIdFilter as BaseChainIdFilter } from './filter';
Expand Down Expand Up @@ -59,6 +60,9 @@ export class NetworkStateManager extends StateManager<Network> {

/**
* Finds a network based on the given chainId.
* The query will first be looked up in the state. If the result is false, it will then fallback to the available Networks constants.
*
* (Note) Due to the returned network object may not exist in the state, it may failed to execute `updateNetwork` with the returned network object.
*
* @param param - The param object.
* @param param.chainId - The chainId to search for.
Expand All @@ -74,7 +78,12 @@ export class NetworkStateManager extends StateManager<Network> {
state?: SnapState,
): Promise<Network | null> {
const filters: INetworkFilter[] = [new ChainIdFilter([chainId])];
return this.find(filters, state);
// in case the network not found from the state, try to get the network from the available Networks constants
return (
(await this.find(filters, state)) ??
Config.availableNetworks.find((network) => network.chainId === chainId) ??
null
);
}

/**
Expand All @@ -88,10 +97,9 @@ export class NetworkStateManager extends StateManager<Network> {
async updateNetwork(data: Network): Promise<void> {
try {
await this.update(async (state: SnapState) => {
const dataInState = await this.getNetwork(
{
chainId: data.chainId,
},
// Use underlying function `find` to avoid searching network from constants
const dataInState = await this.find(
[new ChainIdFilter([data.chainId])],
state,
);

Expand All @@ -111,8 +119,20 @@ export class NetworkStateManager extends StateManager<Network> {
* @param [state] - The optional SnapState object.
* @returns A Promise that resolves with the current Network object if found, or null if not found.
*/
async getCurrentNetwork(state?: SnapState): Promise<Network | null> {
return (state ?? (await this.get())).currentNetwork ?? null;
async getCurrentNetwork(state?: SnapState): Promise<Network> {
const { currentNetwork } = state ?? (await this.get());

// Make sure the current network is either Sepolia testnet or Mainnet. By default it will be Mainnet.
if (
!currentNetwork ||
!Config.availableNetworks.find(
(network) => network.chainId === currentNetwork.chainId,
)
) {
return Config.defaultNetwork;
}

return currentNetwork;
}

/**
Expand Down
6 changes: 3 additions & 3 deletions packages/starknet-snap/src/utils/snapUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type {
UniversalDetails,
} from 'starknet';

import { Config } from '../config';
import {
FeeToken,
type AddErc20TokenRequestParams,
Expand All @@ -34,7 +35,6 @@ import {
MAXIMUM_TOKEN_SYMBOL_LENGTH,
PRELOADED_NETWORKS,
PRELOADED_TOKENS,
STARKNET_MAINNET_NETWORK,
STARKNET_SEPOLIA_TESTNET_NETWORK,
} from './constants';
import { DeployRequiredError, UpgradeRequiredError } from './exceptions';
Expand Down Expand Up @@ -855,7 +855,7 @@ export function getNetworkFromChainId(
state: SnapState,
targerChainId: string | undefined,
) {
const chainId = targerChainId ?? STARKNET_MAINNET_NETWORK.chainId;
const chainId = targerChainId ?? Config.defaultNetwork.chainId;
const network = getNetwork(state, chainId);
if (network === undefined) {
throw new Error(
Expand Down Expand Up @@ -1117,7 +1117,7 @@ export async function removeAcceptedTransaction(
* @param state
*/
export function getCurrentNetwork(state: SnapState) {
return state.currentNetwork ?? STARKNET_MAINNET_NETWORK;
return state.currentNetwork ?? Config.defaultNetwork;
}

/**
Expand Down

0 comments on commit 5b2696f

Please sign in to comment.