diff --git a/integration/networktest/actions/context.go b/integration/networktest/actions/context.go index c54536b1af..8a231c63a6 100644 --- a/integration/networktest/actions/context.go +++ b/integration/networktest/actions/context.go @@ -16,18 +16,18 @@ var KeyNumberOfTestUsers = ActionKey("numberOfTestUsers") // ActionKey is the type for all test data stored in the context. Go documentation recommends using a typed key rather than string to avoid conflicts. type ActionKey string -func storeTestUser(ctx context.Context, userNumber int, user *userwallet.UserWallet) context.Context { +func storeTestUser(ctx context.Context, userNumber int, user userwallet.User) context.Context { return context.WithValue(ctx, userKey(userNumber), user) } -func FetchTestUser(ctx context.Context, userNumber int) (*userwallet.UserWallet, error) { +func FetchTestUser(ctx context.Context, userNumber int) (userwallet.User, error) { u := ctx.Value(userKey(userNumber)) if u == nil { - return nil, fmt.Errorf("no UserWallet found in context for userNumber=%d", userNumber) + return nil, fmt.Errorf("no userWallet found in context for userNumber=%d", userNumber) } - user, ok := u.(*userwallet.UserWallet) + user, ok := u.(userwallet.User) if !ok { - return nil, fmt.Errorf("user retrieved from context was not of expected type UserWallet for userNumber=%d type=%T", userNumber, u) + return nil, fmt.Errorf("user retrieved from context was not of expected type userWallet for userNumber=%d type=%T", userNumber, u) } return user, nil } diff --git a/integration/networktest/actions/native_fund_actions.go b/integration/networktest/actions/native_fund_actions.go index d356867cac..d01a60d420 100644 --- a/integration/networktest/actions/native_fund_actions.go +++ b/integration/networktest/actions/native_fund_actions.go @@ -18,7 +18,7 @@ type SendNativeFunds struct { GasLimit *big.Int SkipVerify bool - user *userwallet.UserWallet + user userwallet.User txHash *common.Hash } @@ -35,7 +35,7 @@ func (s *SendNativeFunds) Run(ctx context.Context, _ networktest.NetworkConnecto if err != nil { return ctx, err } - txHash, err := user.SendFunds(ctx, target.Address(), s.Amount) + txHash, err := user.SendFunds(ctx, target.Wallet().Address(), s.Amount) if err != nil { return nil, err } diff --git a/integration/networktest/actions/node_actions.go b/integration/networktest/actions/node_actions.go index c6da34fc80..2ff7ed0be9 100644 --- a/integration/networktest/actions/node_actions.go +++ b/integration/networktest/actions/node_actions.go @@ -141,7 +141,7 @@ func (w *waitForValidatorHealthCheckAction) Run(ctx context.Context, network net validator := network.GetValidatorNode(w.validatorIdx) // poll the health check until success or timeout err := retry.Do(func() error { - return networktest.NodeHealthCheck(validator.HostRPCAddress()) + return networktest.NodeHealthCheck(validator.HostRPCWSAddress()) }, retry.NewTimeoutStrategy(w.maxWait, 1*time.Second)) if err != nil { return nil, err @@ -158,7 +158,7 @@ func WaitForSequencerHealthCheck(maxWait time.Duration) networktest.Action { sequencer := network.GetSequencerNode() // poll the health check until success or timeout err := retry.Do(func() error { - return networktest.NodeHealthCheck(sequencer.HostRPCAddress()) + return networktest.NodeHealthCheck(sequencer.HostRPCWSAddress()) }, retry.NewTimeoutStrategy(maxWait, 1*time.Second)) if err != nil { return nil, err diff --git a/integration/networktest/actions/setup_actions.go b/integration/networktest/actions/setup_actions.go index 174a506403..2fa447306a 100644 --- a/integration/networktest/actions/setup_actions.go +++ b/integration/networktest/actions/setup_actions.go @@ -12,7 +12,8 @@ import ( ) type CreateTestUser struct { - UserID int + UserID int + UseGateway bool } func (c *CreateTestUser) String() string { @@ -21,9 +22,22 @@ func (c *CreateTestUser) String() string { func (c *CreateTestUser) Run(ctx context.Context, network networktest.NetworkConnector) (context.Context, error) { logger := testlog.Logger() + wal := datagenerator.RandomWallet(integration.TenChainID) - // traffic sim users are round robin-ed onto the validators for now (todo (@matt) - make that overridable) - user := userwallet.NewUserWallet(wal.PrivateKey(), network.ValidatorRPCAddress(c.UserID%network.NumValidators()), logger) + var user userwallet.User + if c.UseGateway { + gwURL, err := network.GetGatewayURL() + if err != nil { + return ctx, fmt.Errorf("failed to get required gateway URL: %w", err) + } + user, err = userwallet.NewGatewayUser(wal, gwURL, logger) + if err != nil { + return ctx, fmt.Errorf("failed to create gateway user: %w", err) + } + } else { + // traffic sim users are round robin-ed onto the validators for now (todo (@matt) - make that overridable) + user = userwallet.NewUserWallet(wal, network.ValidatorRPCAddress(c.UserID%network.NumValidators()), logger) + } return storeTestUser(ctx, c.UserID, user), nil } @@ -44,7 +58,7 @@ func (a *AllocateFaucetFunds) Run(ctx context.Context, network networktest.Netwo if err != nil { return ctx, err } - return ctx, network.AllocateFaucetFunds(ctx, user.Address()) + return ctx, network.AllocateFaucetFunds(ctx, user.Wallet().Address()) } func (a *AllocateFaucetFunds) Verify(_ context.Context, _ networktest.NetworkConnector) error { @@ -65,23 +79,3 @@ func CreateAndFundTestUsers(numUsers int) *MultiAction { newUserActions = append(newUserActions, SnapshotUserBalances(SnapAfterAllocation)) return Series(newUserActions...) } - -func AuthenticateAllUsers() networktest.Action { - return RunOnlyAction(func(ctx context.Context, network networktest.NetworkConnector) (context.Context, error) { - numUsers, err := FetchNumberOfTestUsers(ctx) - if err != nil { - return nil, fmt.Errorf("expected number of test users to be set on the context") - } - for i := 0; i < numUsers; i++ { - user, err := FetchTestUser(ctx, i) - if err != nil { - return nil, err - } - err = user.ResetClient(ctx) - if err != nil { - return nil, fmt.Errorf("unable to (re)authenticate client %d - %w", i, err) - } - } - return ctx, nil - }) -} diff --git a/integration/networktest/env/dev_network.go b/integration/networktest/env/dev_network.go index 89281d564d..988a00bb3b 100644 --- a/integration/networktest/env/dev_network.go +++ b/integration/networktest/env/dev_network.go @@ -28,12 +28,12 @@ func (d *devNetworkEnv) Prepare() (networktest.NetworkConnector, func(), error) } func awaitNodesAvailable(nc networktest.NetworkConnector) error { - err := awaitHealthStatus(nc.GetSequencerNode().HostRPCAddress(), 60*time.Second) + err := awaitHealthStatus(nc.GetSequencerNode().HostRPCWSAddress(), 60*time.Second) if err != nil { return err } for i := 0; i < nc.NumValidators(); i++ { - err := awaitHealthStatus(nc.GetValidatorNode(i).HostRPCAddress(), 60*time.Second) + err := awaitHealthStatus(nc.GetValidatorNode(i).HostRPCWSAddress(), 60*time.Second) if err != nil { return err } @@ -61,8 +61,12 @@ func awaitHealthStatus(rpcAddress string, timeout time.Duration) error { }, retry.NewTimeoutStrategy(timeout, 200*time.Millisecond)) } -func LocalDevNetwork() networktest.Environment { - return &devNetworkEnv{inMemDevNetwork: devnetwork.DefaultDevNetwork()} +func LocalDevNetwork(opts ...LocalNetworkOption) networktest.Environment { + config := &LocalNetworkConfig{} + for _, opt := range opts { + opt(config) + } + return &devNetworkEnv{inMemDevNetwork: devnetwork.DefaultDevNetwork(config.TenGatewayEnabled)} } // LocalNetworkLiveL1 creates a local network that points to a live running L1. @@ -70,3 +74,15 @@ func LocalDevNetwork() networktest.Environment { func LocalNetworkLiveL1(seqWallet wallet.Wallet, validatorWallets []wallet.Wallet, l1RPCURLs []string) networktest.Environment { return &devNetworkEnv{inMemDevNetwork: devnetwork.LiveL1DevNetwork(seqWallet, validatorWallets, l1RPCURLs)} } + +type LocalNetworkConfig struct { + TenGatewayEnabled bool +} + +type LocalNetworkOption func(*LocalNetworkConfig) + +func WithTenGateway() LocalNetworkOption { + return func(c *LocalNetworkConfig) { + c.TenGatewayEnabled = true + } +} diff --git a/integration/networktest/env/network_setup.go b/integration/networktest/env/network_setup.go index 464fc829ff..507700b461 100644 --- a/integration/networktest/env/network_setup.go +++ b/integration/networktest/env/network_setup.go @@ -11,6 +11,7 @@ func SepoliaTestnet() networktest.Environment { []string{"http://erpc.sepolia-testnet.ten.xyz:80"}, "http://sepolia-testnet-faucet.uksouth.azurecontainer.io/fund/eth", "https://rpc.sepolia.org/", + "https://testnet.ten.xyz", // :81 for websocket ) return &testnetEnv{connector} } @@ -21,6 +22,7 @@ func UATTestnet() networktest.Environment { []string{"http://erpc.uat-testnet.ten.xyz:80"}, "http://uat-testnet-faucet.uksouth.azurecontainer.io/fund/eth", "ws://uat-testnet-eth2network.uksouth.cloudapp.azure.com:9000", + "https://uat-testnet.ten.xyz", ) return &testnetEnv{connector} } @@ -31,6 +33,7 @@ func DevTestnet() networktest.Environment { []string{"http://erpc.dev-testnet.ten.xyz:80"}, "http://dev-testnet-faucet.uksouth.azurecontainer.io/fund/eth", "ws://dev-testnet-eth2network.uksouth.cloudapp.azure.com:9000", + "https://dev-testnet.ten.xyz", ) return &testnetEnv{connector} } @@ -42,6 +45,7 @@ func LongRunningLocalNetwork(l1WSURL string) networktest.Environment { []string{"ws://127.0.0.1:37901"}, genesis.TestnetPrefundedPK, l1WSURL, + "", ) return &testnetEnv{connector} } diff --git a/integration/networktest/env/testnet.go b/integration/networktest/env/testnet.go index fbe5a3e08d..103defc574 100644 --- a/integration/networktest/env/testnet.go +++ b/integration/networktest/env/testnet.go @@ -32,28 +32,32 @@ type testnetConnector struct { validatorRPCAddresses []string faucetHTTPAddress string l1RPCURL string - faucetWallet *userwallet.UserWallet + tenGatewayURL string + faucetWallet userwallet.User } -func NewTestnetConnector(seqRPCAddr string, validatorRPCAddressses []string, faucetHTTPAddress string, l1WSURL string) networktest.NetworkConnector { +func NewTestnetConnector(seqRPCAddr string, validatorRPCAddressses []string, faucetHTTPAddress string, l1WSURL string, tenGatewayURL string) networktest.NetworkConnector { return &testnetConnector{ seqRPCAddress: seqRPCAddr, validatorRPCAddresses: validatorRPCAddressses, faucetHTTPAddress: faucetHTTPAddress, l1RPCURL: l1WSURL, + tenGatewayURL: tenGatewayURL, } } -func NewTestnetConnectorWithFaucetAccount(seqRPCAddr string, validatorRPCAddressses []string, faucetPK string, l1RPCAddress string) networktest.NetworkConnector { +func NewTestnetConnectorWithFaucetAccount(seqRPCAddr string, validatorRPCAddressses []string, faucetPK string, l1RPCAddress string, tenGatewayURL string) networktest.NetworkConnector { ecdsaKey, err := crypto.HexToECDSA(faucetPK) if err != nil { panic(err) } + wal := wallet.NewInMemoryWalletFromPK(big.NewInt(integration.TenChainID), ecdsaKey, testlog.Logger()) return &testnetConnector{ seqRPCAddress: seqRPCAddr, validatorRPCAddresses: validatorRPCAddressses, - faucetWallet: userwallet.NewUserWallet(ecdsaKey, validatorRPCAddressses[0], testlog.Logger(), userwallet.WithChainID(big.NewInt(integration.TenChainID))), + faucetWallet: userwallet.NewUserWallet(wal, validatorRPCAddressses[0], testlog.Logger()), l1RPCURL: l1RPCAddress, + tenGatewayURL: tenGatewayURL, } } @@ -131,3 +135,14 @@ func (t *testnetConnector) AllocateFaucetFundsWithWallet(ctx context.Context, ac func (t *testnetConnector) GetMCOwnerWallet() (wallet.Wallet, error) { return nil, errors.New("testnet connector environments cannot access the MC owner wallet") } + +func (t *testnetConnector) GetGatewayClient() (ethadapter.EthClient, error) { + if t.tenGatewayURL == "" { + return nil, errors.New("gateway client not set for this environment") + } + return ethadapter.NewEthClientFromURL(t.tenGatewayURL, time.Minute, gethcommon.Address{}, testlog.Logger()) +} + +func (t *testnetConnector) GetGatewayURL() (string, error) { + return t.tenGatewayURL, nil +} diff --git a/integration/networktest/interfaces.go b/integration/networktest/interfaces.go index acc27fb7f7..1a74845d30 100644 --- a/integration/networktest/interfaces.go +++ b/integration/networktest/interfaces.go @@ -25,6 +25,7 @@ type NetworkConnector interface { GetValidatorNode(idx int) NodeOperator GetL1Client() (ethadapter.EthClient, error) GetMCOwnerWallet() (wallet.Wallet, error) // wallet that owns the management contract (network admin) + GetGatewayURL() (string, error) } // Action is any step in a test, they will typically be either minimally small steps in the test or they will be containers @@ -62,5 +63,6 @@ type NodeOperator interface { StartHost() error StopHost() error - HostRPCAddress() string + HostRPCHTTPAddress() string + HostRPCWSAddress() string } diff --git a/integration/networktest/tests/gateway/gateway_test.go b/integration/networktest/tests/gateway/gateway_test.go new file mode 100644 index 0000000000..3034291644 --- /dev/null +++ b/integration/networktest/tests/gateway/gateway_test.go @@ -0,0 +1,41 @@ +package gateway + +import ( + "math/big" + "testing" + + "github.com/ten-protocol/go-ten/integration/networktest" + "github.com/ten-protocol/go-ten/integration/networktest/actions" + "github.com/ten-protocol/go-ten/integration/networktest/env" +) + +var _transferAmount = big.NewInt(100_000_000) + +// TestGatewayHappyPath tests ths same functionality as the smoke_test but with the gateway: +// 1. Create two test users +// 2. Allocate funds to the first user +// 3. Send funds from the first user to the second +// 4. Verify the second user has the funds +// 5. Verify the first user has the funds deducted +// To run this test with a local network use the flag to start it with the gateway enabled. +func TestGatewayHappyPath(t *testing.T) { + networktest.TestOnlyRunsInIDE(t) + networktest.Run( + "gateway-happy-path", + t, + env.LocalDevNetwork(env.WithTenGateway()), + actions.Series( + &actions.CreateTestUser{UserID: 0, UseGateway: true}, + &actions.CreateTestUser{UserID: 1, UseGateway: true}, + actions.SetContextValue(actions.KeyNumberOfTestUsers, 2), + + &actions.AllocateFaucetFunds{UserID: 0}, + actions.SnapshotUserBalances(actions.SnapAfterAllocation), // record user balances (we have no guarantee on how much the network faucet allocates) + + &actions.SendNativeFunds{FromUser: 0, ToUser: 1, Amount: _transferAmount}, + + &actions.VerifyBalanceAfterTest{UserID: 1, ExpectedBalance: _transferAmount}, + &actions.VerifyBalanceDiffAfterTest{UserID: 0, Snapshot: actions.SnapAfterAllocation, ExpectedDiff: big.NewInt(0).Neg(_transferAmount)}, + ), + ) +} diff --git a/integration/networktest/tests/helpful/availability_test.go b/integration/networktest/tests/helpful/availability_test.go index a9ff34e9d8..8c03bcdef4 100644 --- a/integration/networktest/tests/helpful/availability_test.go +++ b/integration/networktest/tests/helpful/availability_test.go @@ -13,7 +13,7 @@ import ( "github.com/ten-protocol/go-ten/integration/networktest/env" ) -const _testTimeSpan = 120 * time.Second +const _testTimeSpan = 30 * time.Second // basic test that verifies it can connect the L1 client and L2 client and sees block numbers increasing (useful to sanity check testnet issues etc.) func TestNetworkAvailability(t *testing.T) { @@ -21,7 +21,7 @@ func TestNetworkAvailability(t *testing.T) { networktest.Run( "network-availability", t, - env.DevTestnet(), + env.SepoliaTestnet(), actions.RunOnlyAction(func(ctx context.Context, network networktest.NetworkConnector) (context.Context, error) { client, err := network.GetL1Client() if err != nil { diff --git a/integration/networktest/tests/nodescenario/restart_network_test.go b/integration/networktest/tests/nodescenario/restart_network_test.go index 42a64307a1..4b88854a9c 100644 --- a/integration/networktest/tests/nodescenario/restart_network_test.go +++ b/integration/networktest/tests/nodescenario/restart_network_test.go @@ -56,10 +56,6 @@ func TestRestartNetwork(t *testing.T) { // This needs investigating but it suggests to me that the health check is succeeding prematurely actions.SleepAction(5*time.Second), // allow time for re-sync - // resubmit user viewing keys (all users will have lost their "session") - // todo: get rid of this once the enclave persists viewing keys correctly - actions.AuthenticateAllUsers(), - // another load test, check that the network is still working actions.GenerateUsersRandomisedTransferActionsInParallel(4, 60*time.Second), ), diff --git a/integration/networktest/tests/nodescenario/restart_validator_enclave_test.go b/integration/networktest/tests/nodescenario/restart_validator_enclave_test.go index 1559bd8d7b..6eb85c0506 100644 --- a/integration/networktest/tests/nodescenario/restart_validator_enclave_test.go +++ b/integration/networktest/tests/nodescenario/restart_validator_enclave_test.go @@ -32,10 +32,6 @@ func TestRestartValidatorEnclave(t *testing.T) { // This needs investigating but it suggests to me that the health check is succeeding prematurely actions.SleepAction(5*time.Second), // allow time for re-sync - // resubmit user viewing keys (any users attached to the restarted node will have lost their "session") - // todo (@matt) - get rid of this once the enclave persists viewing keys correctly - actions.AuthenticateAllUsers(), - // another load test (important that at least one of the users will be using the validator with restarted enclave) actions.GenerateUsersRandomisedTransferActionsInParallel(4, 10*time.Second), ), diff --git a/integration/networktest/tests/nodescenario/restart_validator_test.go b/integration/networktest/tests/nodescenario/restart_validator_test.go index 2c6f3b8f2f..bfb35e0602 100644 --- a/integration/networktest/tests/nodescenario/restart_validator_test.go +++ b/integration/networktest/tests/nodescenario/restart_validator_test.go @@ -38,10 +38,6 @@ func TestRestartValidatorNode(t *testing.T) { // This needs investigating but it suggests to me that the health check is succeeding prematurely actions.SleepAction(5*time.Second), // allow time for re-sync - // resubmit user viewing keys (any users attached to the restarted node will have lost their "session") - // todo (@matt) - get rid of this once the enclave persists viewing keys correctly - actions.AuthenticateAllUsers(), - // another load test (important that at least one of the users will be using the validator with restarted enclave) actions.GenerateUsersRandomisedTransferActionsInParallel(4, 10*time.Second), ), diff --git a/integration/networktest/userwallet/authclient.go b/integration/networktest/userwallet/authclient.go new file mode 100644 index 0000000000..7402bb3097 --- /dev/null +++ b/integration/networktest/userwallet/authclient.go @@ -0,0 +1,131 @@ +package userwallet + +import ( + "context" + "errors" + "fmt" + "math/big" + "time" + + gethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + gethlog "github.com/ethereum/go-ethereum/log" + "github.com/ten-protocol/go-ten/go/common/retry" + "github.com/ten-protocol/go-ten/go/obsclient" + "github.com/ten-protocol/go-ten/go/rpc" + "github.com/ten-protocol/go-ten/go/wallet" +) + +const ( + _maxReceiptWaitTime = 30 * time.Second + _receiptPollInterval = 1 * time.Second // todo (@matt) this should be configured using network timings provided by env +) + +// AuthClientUser is a test user that uses the auth client to talk to directly to a node +// Note: AuthClientUser is **not** thread-safe for a single wallet (creates nonce conflicts etc.) +type AuthClientUser struct { + wal wallet.Wallet + rpcEndpoint string + + client *obsclient.AuthObsClient // lazily initialised and authenticated on first usage + logger gethlog.Logger +} + +func NewUserWallet(wal wallet.Wallet, rpcEndpoint string, logger gethlog.Logger) *AuthClientUser { + return &AuthClientUser{ + wal: wal, + rpcEndpoint: rpcEndpoint, + logger: logger, + } +} + +func (s *AuthClientUser) SendFunds(ctx context.Context, addr gethcommon.Address, value *big.Int) (*gethcommon.Hash, error) { + err := s.EnsureClientSetup(ctx) + if err != nil { + return nil, fmt.Errorf("unable to prepare client to send funds - %w", err) + } + + txData := &types.LegacyTx{ + Value: value, + To: &addr, + } + tx := s.client.EstimateGasAndGasPrice(txData) //nolint: contextcheck + + txHash, err := s.SendTransaction(ctx, tx) + if err != nil { + return nil, fmt.Errorf("unable to send transaction - %w", err) + } + + return txHash, nil +} + +func (s *AuthClientUser) SendTransaction(ctx context.Context, tx types.TxData) (*gethcommon.Hash, error) { + signedTx, err := s.wal.SignTransaction(tx) + if err != nil { + return nil, fmt.Errorf("unable to sign transaction - %w", err) + } + // fmt.Printf("waiting for receipt hash %s\n", signedTx.Hash()) + err = s.client.SendTransaction(ctx, signedTx) + if err != nil { + return nil, fmt.Errorf("unable to send transaction - %w", err) + } + + txHash := signedTx.Hash() + // transaction has been sent, we increment the nonce + s.wal.GetNonceAndIncrement() + return &txHash, nil +} + +func (s *AuthClientUser) AwaitReceipt(ctx context.Context, txHash *gethcommon.Hash) (*types.Receipt, error) { + var receipt *types.Receipt + var err error + err = retry.Do(func() error { + receipt, err = s.client.TransactionReceipt(ctx, *txHash) + if !errors.Is(err, rpc.ErrNilResponse) { + // nil response means not found. Any other error is unexpected, so we stop polling and fail immediately + return retry.FailFast(err) + } + return err + }, retry.NewTimeoutStrategy(_maxReceiptWaitTime, _receiptPollInterval)) + return receipt, err +} + +// EnsureClientSetup creates an authenticated RPC client (with a viewing key generated, signed and registered) when first called +// Also fetches current nonce value. +func (s *AuthClientUser) EnsureClientSetup(ctx context.Context) error { + if s.client != nil { + // client already setup + return nil + } + authClient, err := obsclient.DialWithAuth(s.rpcEndpoint, s.wal, s.logger) + if err != nil { + return err + } + s.client = authClient + + // fetch current nonce for account + nonce, err := authClient.NonceAt(ctx, big.NewInt(-1)) + if err != nil { + return fmt.Errorf("unable to fetch client nonce - %w", err) + } + s.wal.SetNonce(nonce) + + return nil +} + +func (s *AuthClientUser) NativeBalance(ctx context.Context) (*big.Int, error) { + err := s.EnsureClientSetup(ctx) + if err != nil { + return nil, err + } + return s.client.BalanceAt(ctx, nil) +} + +// Init forces VK setup: currently the faucet http server requires a viewing key for a wallet to even *receive* funds :( +func (s *AuthClientUser) Init(ctx context.Context) (*AuthClientUser, error) { + return s, s.EnsureClientSetup(ctx) +} + +func (s *AuthClientUser) Wallet() wallet.Wallet { + return s.wal +} diff --git a/integration/networktest/userwallet/gateway.go b/integration/networktest/userwallet/gateway.go new file mode 100644 index 0000000000..a6ca793191 --- /dev/null +++ b/integration/networktest/userwallet/gateway.go @@ -0,0 +1,115 @@ +package userwallet + +import ( + "context" + "errors" + "fmt" + "math/big" + "time" + + "github.com/ethereum/go-ethereum" + gethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethclient" + gethlog "github.com/ethereum/go-ethereum/log" + "github.com/ten-protocol/go-ten/go/common/retry" + "github.com/ten-protocol/go-ten/go/rpc" + "github.com/ten-protocol/go-ten/go/wallet" + "github.com/ten-protocol/go-ten/tools/walletextension/lib" +) + +type GatewayUser struct { + wal wallet.Wallet + + gwLib *lib.TGLib // TenGateway utility + client *ethclient.Client + + // state managed by the wallet + nonce uint64 + + logger gethlog.Logger +} + +func NewGatewayUser(wal wallet.Wallet, gatewayURL string, logger gethlog.Logger) (*GatewayUser, error) { + gwLib := lib.NewTenGatewayLibrary(gatewayURL, "") // not providing wsURL for now, add if we need it + + err := gwLib.Join() + if err != nil { + return nil, fmt.Errorf("failed to join TenGateway: %w", err) + } + err = gwLib.RegisterAccount(wal.PrivateKey(), wal.Address()) + if err != nil { + return nil, fmt.Errorf("failed to register account with TenGateway: %w", err) + } + + client, err := ethclient.Dial(gwLib.HTTP()) + if err != nil { + return nil, fmt.Errorf("failed to dial TenGateway HTTP: %w", err) + } + + fmt.Printf("Registered acc with TenGateway: %s (%s)\n", wal.Address(), gwLib.HTTP()) + + return &GatewayUser{ + wal: wal, + gwLib: gwLib, + client: client, + logger: logger, + }, nil +} + +func (g *GatewayUser) SendFunds(ctx context.Context, addr gethcommon.Address, value *big.Int) (*gethcommon.Hash, error) { + txData := &types.LegacyTx{ + Nonce: g.nonce, + Value: value, + To: &addr, + } + gasPrice, err := g.client.SuggestGasPrice(ctx) + if err != nil { + return nil, fmt.Errorf("unable to suggest gas price - %w", err) + } + txData.GasPrice = gasPrice + gasLimit, err := g.client.EstimateGas(ctx, ethereum.CallMsg{ + From: g.wal.Address(), + To: &addr, + }) + if err != nil { + return nil, fmt.Errorf("unable to estimate gas - %w", err) + } + txData.Gas = gasLimit + signedTx, err := g.wal.SignTransaction(txData) + if err != nil { + return nil, fmt.Errorf("unable to sign transaction - %w", err) + } + err = g.client.SendTransaction(ctx, signedTx) + if err != nil { + return nil, fmt.Errorf("unable to send transaction - %w", err) + } + txHash := signedTx.Hash() + // transaction has been sent, we increment the nonce + g.nonce++ + return &txHash, nil +} + +func (g *GatewayUser) AwaitReceipt(ctx context.Context, txHash *gethcommon.Hash) (*types.Receipt, error) { + var receipt *types.Receipt + var err error + err = retry.Do(func() error { + receipt, err = g.client.TransactionReceipt(ctx, *txHash) + if !errors.Is(err, rpc.ErrNilResponse) { + return retry.FailFast(err) + } + return err + }, retry.NewTimeoutStrategy(20*time.Second, 1*time.Second)) + if err != nil { + return nil, fmt.Errorf("unable to get receipt - %w", err) + } + return receipt, nil +} + +func (g *GatewayUser) NativeBalance(ctx context.Context) (*big.Int, error) { + return g.client.BalanceAt(ctx, g.wal.Address(), nil) +} + +func (g *GatewayUser) Wallet() wallet.Wallet { + return g.wal +} diff --git a/integration/networktest/userwallet/user.go b/integration/networktest/userwallet/user.go new file mode 100644 index 0000000000..105e2cd8e6 --- /dev/null +++ b/integration/networktest/userwallet/user.go @@ -0,0 +1,22 @@ +package userwallet + +import ( + "context" + "math/big" + + gethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ten-protocol/go-ten/go/wallet" +) + +// User - abstraction for networktest users - two implementations initially: +// 1. AuthClientUser - a user that uses the auth client to talk to the network +// 2. GatewayUser - a user that uses the gateway to talk to the network +// +// This abstraction allows us to use the same tests for both types of users +type User interface { + Wallet() wallet.Wallet + SendFunds(ctx context.Context, addr gethcommon.Address, value *big.Int) (*gethcommon.Hash, error) + AwaitReceipt(ctx context.Context, txHash *gethcommon.Hash) (*types.Receipt, error) + NativeBalance(ctx context.Context) (*big.Int, error) +} diff --git a/integration/networktest/userwallet/userwallet.go b/integration/networktest/userwallet/userwallet.go deleted file mode 100644 index 6cbd6c6c9d..0000000000 --- a/integration/networktest/userwallet/userwallet.go +++ /dev/null @@ -1,235 +0,0 @@ -package userwallet - -import ( - "context" - "crypto/ecdsa" - "errors" - "fmt" - "math/big" - "time" - - "github.com/ten-protocol/go-ten/integration/common/testlog" - "github.com/ten-protocol/go-ten/integration/datagenerator" - "github.com/ten-protocol/go-ten/integration/networktest" - - gethcommon "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" - gethlog "github.com/ethereum/go-ethereum/log" - "github.com/ten-protocol/go-ten/go/common/retry" - "github.com/ten-protocol/go-ten/go/obsclient" - "github.com/ten-protocol/go-ten/go/rpc" - "github.com/ten-protocol/go-ten/integration" -) - -const ( - _maxReceiptWaitTime = 30 * time.Second - _receiptPollInterval = 1 * time.Second // todo (@matt) this should be configured using network timings provided by env -) - -// UserWallet implements wallet.Wallet so it can be used with the original Wallet code. -// But it aims to provide a wider range of functionality, akin to the software and hardware wallets that users interact with. -// Note: UserWallet is **not** thread-safe for a single wallet (creates nonce conflicts etc.) -type UserWallet struct { - privateKey *ecdsa.PrivateKey - publicKey *ecdsa.PublicKey - accountAddress gethcommon.Address - chainID *big.Int - rpcEndpoint string - - // state managed by the wallet - nonce uint64 - - client *obsclient.AuthObsClient // lazily initialised and authenticated on first usage - logger gethlog.Logger -} - -// Option modifies a UserWallet. See below for options, in the form `WithXxx(xxx)` that can be chained into constructor -type Option func(wallet *UserWallet) - -// GenerateRandomWallet will generate a random wallet with a UserWallet wrapper, connecting to a random validator node -// Note: will use testlog.Logger() as the logger -func GenerateRandomWallet(network networktest.NetworkConnector) *UserWallet { - wallet := datagenerator.RandomWallet(network.ChainID()) - _, err := obsclient.DialWithAuth(network.SequencerRPCAddress(), wallet, testlog.Logger()) - if err != nil { - panic(err) - } - - rndValidatorIdx := int(datagenerator.RandomUInt64()) % network.NumValidators() - return NewUserWallet(wallet.PrivateKey(), network.ValidatorRPCAddress(rndValidatorIdx), testlog.Logger()) -} - -func NewUserWallet(pk *ecdsa.PrivateKey, rpcEndpoint string, logger gethlog.Logger, opts ...Option) *UserWallet { - publicKeyECDSA, ok := pk.Public().(*ecdsa.PublicKey) - if !ok { - // this shouldn't happen - logger.Crit("error casting public key to ECDSA") - } - wal := &UserWallet{ - privateKey: pk, - publicKey: publicKeyECDSA, - accountAddress: crypto.PubkeyToAddress(*publicKeyECDSA), - chainID: big.NewInt(integration.TenChainID), // default, overridable using `WithChainID(...) opt` - rpcEndpoint: rpcEndpoint, - logger: logger, - } - // apply any optional config to the wallet - for _, opt := range opts { - opt(wal) - } - return wal -} - -func (s *UserWallet) ChainID() *big.Int { - return big.NewInt(integration.TenChainID) -} - -func (s *UserWallet) SendFunds(ctx context.Context, addr gethcommon.Address, value *big.Int) (*gethcommon.Hash, error) { - err := s.EnsureClientSetup(ctx) - if err != nil { - return nil, fmt.Errorf("unable to prepare client to send funds - %w", err) - } - - txData := &types.LegacyTx{ - Nonce: s.nonce, - Value: value, - To: &addr, - } - tx := s.client.EstimateGasAndGasPrice(txData) //nolint: contextcheck - - txHash, err := s.SendTransaction(ctx, tx) - if err != nil { - return nil, fmt.Errorf("unable to send transaction - %w", err) - } - - return txHash, nil -} - -func (s *UserWallet) SendTransaction(ctx context.Context, tx types.TxData) (*gethcommon.Hash, error) { - signedTx, err := s.SignTransaction(tx) - if err != nil { - return nil, fmt.Errorf("unable to sign transaction - %w", err) - } - // fmt.Printf("waiting for receipt hash %s\n", signedTx.Hash()) - err = s.client.SendTransaction(ctx, signedTx) - if err != nil { - return nil, fmt.Errorf("unable to send transaction - %w", err) - } - - txHash := signedTx.Hash() - // transaction has been sent, we increment the nonce - s.nonce++ - return &txHash, nil -} - -func (s *UserWallet) AwaitReceipt(ctx context.Context, txHash *gethcommon.Hash) (*types.Receipt, error) { - var receipt *types.Receipt - var err error - err = retry.Do(func() error { - receipt, err = s.client.TransactionReceipt(ctx, *txHash) - if !errors.Is(err, rpc.ErrNilResponse) { - // nil response means not found. Any other error is unexpected, so we stop polling and fail immediately - return retry.FailFast(err) - } - return err - }, retry.NewTimeoutStrategy(_maxReceiptWaitTime, _receiptPollInterval)) - return receipt, err -} - -func (s *UserWallet) Address() gethcommon.Address { - return s.accountAddress -} - -func (s *UserWallet) SignTransaction(tx types.TxData) (*types.Transaction, error) { - return types.SignNewTx(s.privateKey, types.NewLondonSigner(s.chainID), tx) -} - -func (s *UserWallet) SignTransactionForChainID(tx types.TxData, chainID *big.Int) (*types.Transaction, error) { - return types.SignNewTx(s.privateKey, types.NewLondonSigner(chainID), tx) -} - -func (s *UserWallet) GetNonce() uint64 { - return s.nonce -} - -func (s *UserWallet) PrivateKey() *ecdsa.PrivateKey { - return s.privateKey -} - -func (s *UserWallet) SetNonce(_ uint64) { - panic("UserWallet is designed to manage its own nonce - this method exists to support legacy interface methods") -} - -func (s *UserWallet) GetNonceAndIncrement() uint64 { - panic("UserWallet is designed to manage its own nonce - this method exists to support legacy interface methods") -} - -// EnsureClientSetup creates an authenticated RPC client (with a viewing key generated, signed and registered) when first called -// Also fetches current nonce value. -func (s *UserWallet) EnsureClientSetup(ctx context.Context) error { - if s.client != nil { - // client already setup - return nil - } - authClient, err := obsclient.DialWithAuth(s.rpcEndpoint, s, s.logger) - if err != nil { - return err - } - s.client = authClient - - // fetch current nonce for account - nonce, err := authClient.NonceAt(ctx, nil) - if err != nil { - return fmt.Errorf("unable to fetch client nonce - %w", err) - } - s.nonce = nonce - - return nil -} - -// ResetClient creates an authenticated RPC client (with a viewing key generated, signed and registered) -// Also fetches current nonce value. It closes previous client if it exists. -func (s *UserWallet) ResetClient(ctx context.Context) error { - if s.client != nil { - // client already setup, close it before re-authenticating - s.client.Close() - } - authClient, err := obsclient.DialWithAuth(s.rpcEndpoint, s, s.logger) - if err != nil { - return err - } - s.client = authClient - - // fetch current nonce for account - nonce, err := authClient.NonceAt(ctx, nil) - if err != nil { - return fmt.Errorf("unable to fetch client nonce - %w", err) - } - s.nonce = nonce - - return nil -} - -func (s *UserWallet) NativeBalance(ctx context.Context) (*big.Int, error) { - err := s.EnsureClientSetup(ctx) - if err != nil { - return nil, err - } - return s.client.BalanceAt(ctx, nil) -} - -// Init forces VK setup: currently the faucet http server requires a viewing key for a wallet to even *receive* funds :( -func (s *UserWallet) Init(ctx context.Context) (*UserWallet, error) { - return s, s.EnsureClientSetup(ctx) -} - -// UserWalletOptions can be passed into the constructor to override default values -// e.g. NewUserWallet(pk, rpcAddr, logger, WithChainId(123)) -// NewUserWallet(pk, rpcAddr, logger, WithChainId(123), WithRPCTimeout(20*time.Second)), ) - -func WithChainID(chainID *big.Int) Option { - return func(wallet *UserWallet) { - wallet.chainID = chainID - } -} diff --git a/integration/simulation/devnetwork/config.go b/integration/simulation/devnetwork/config.go index 51a4998a5c..0db3e04632 100644 --- a/integration/simulation/devnetwork/config.go +++ b/integration/simulation/devnetwork/config.go @@ -36,7 +36,7 @@ type ObscuroConfig struct { } // DefaultDevNetwork provides an off-the-shelf default config for a sim network -func DefaultDevNetwork() *InMemDevNetwork { +func DefaultDevNetwork(tenGateway bool) *InMemDevNetwork { numNodes := 4 // Default sim currently uses 4 L1 nodes. Obscuro nodes: 1 seq, 3 validators networkWallets := params.NewSimWallets(0, numNodes, integration.EthereumChainID, integration.TenChainID) l1Config := &L1Config{ @@ -57,7 +57,8 @@ func DefaultDevNetwork() *InMemDevNetwork { L1BlockTime: 15 * time.Second, SequencerID: networkWallets.NodeWallets[0].Address(), }, - faucetLock: sync.Mutex{}, + tenGatewayEnabled: tenGateway, + faucetLock: sync.Mutex{}, } } diff --git a/integration/simulation/devnetwork/dev_network.go b/integration/simulation/devnetwork/dev_network.go index a668bdc5d5..cbf1af8a76 100644 --- a/integration/simulation/devnetwork/dev_network.go +++ b/integration/simulation/devnetwork/dev_network.go @@ -9,6 +9,8 @@ import ( "github.com/ten-protocol/go-ten/integration/common/testlog" "github.com/ten-protocol/go-ten/integration/simulation/network" + gatewaycfg "github.com/ten-protocol/go-ten/tools/walletextension/config" + "github.com/ten-protocol/go-ten/tools/walletextension/container" "github.com/ten-protocol/go-ten/go/ethadapter" @@ -25,6 +27,12 @@ import ( "github.com/ten-protocol/go-ten/integration/simulation/params" ) +const ( + // these ports were picked arbitrarily, if we want plan to use these tests on CI we need to use ports in the constants.go file + _gwHTTPPort = 11180 + _gwWSPort = 11181 +) + var _defaultFaucetAmount = big.NewInt(750_000_000_000_000) // InMemDevNetwork is a local dev network (L1 and L2) - the obscuro nodes are in-memory in a single go process, the L1 nodes are a docker geth network @@ -46,14 +54,24 @@ type InMemDevNetwork struct { // - if it is nil when `Start()` is called then Obscuro contracts will be deployed on the L1 l1SetupData *params.L1SetupData - obscuroConfig ObscuroConfig - obscuroSequencer *InMemNodeOperator - obscuroValidators []*InMemNodeOperator + obscuroConfig ObscuroConfig + obscuroSequencer *InMemNodeOperator + obscuroValidators []*InMemNodeOperator + tenGatewayContainer *container.WalletExtensionContainer + + tenGatewayEnabled bool - faucet *userwallet.UserWallet + faucet userwallet.User faucetLock sync.Mutex } +func (s *InMemDevNetwork) GetGatewayURL() (string, error) { + if !s.tenGatewayEnabled { + return "", fmt.Errorf("ten gateway not enabled") + } + return fmt.Sprintf("http://localhost:%d", _gwHTTPPort), nil +} + func (s *InMemDevNetwork) GetMCOwnerWallet() (wallet.Wallet, error) { return s.networkWallets.MCOwnerWallet, nil } @@ -88,12 +106,12 @@ func (s *InMemDevNetwork) AllocateFaucetFunds(ctx context.Context, account gethc func (s *InMemDevNetwork) SequencerRPCAddress() string { seq := s.GetSequencerNode() - return seq.HostRPCAddress() + return seq.HostRPCWSAddress() } func (s *InMemDevNetwork) ValidatorRPCAddress(idx int) string { val := s.GetValidatorNode(idx) - return val.HostRPCAddress() + return val.HostRPCWSAddress() } // GetL1Client returns the first client we have for our local L1 network @@ -129,10 +147,16 @@ func (s *InMemDevNetwork) Start() { } fmt.Println("Starting obscuro nodes") s.startNodes() + + if s.tenGatewayEnabled { + s.startTenGateway() + } + // sleep to allow the nodes to start + time.Sleep(10 * time.Second) } -func (s *InMemDevNetwork) DeployL1StandardContracts() { - // todo (@matt) - separate out L1 contract deployment from the geth network setup to give better sim control +func (s *InMemDevNetwork) GetGatewayClient() (ethadapter.EthClient, error) { + return nil, fmt.Errorf("not implemented") } func (s *InMemDevNetwork) startNodes() { @@ -159,7 +183,38 @@ func (s *InMemDevNetwork) startNodes() { } }(v) } - s.faucet = userwallet.NewUserWallet(s.networkWallets.L2FaucetWallet.PrivateKey(), s.SequencerRPCAddress(), s.logger) + s.faucet = userwallet.NewUserWallet(s.networkWallets.L2FaucetWallet, s.SequencerRPCAddress(), s.logger) +} + +func (s *InMemDevNetwork) startTenGateway() { + validator := s.GetValidatorNode(0) + validatorHTTP := validator.HostRPCHTTPAddress() + // remove http:// prefix for the gateway config + validatorHTTP = validatorHTTP[len("http://"):] + validatorWS := validator.HostRPCWSAddress() + // remove ws:// prefix for the gateway config + validatorWS = validatorWS[len("ws://"):] + cfg := gatewaycfg.Config{ + WalletExtensionHost: "127.0.0.1", + WalletExtensionPortHTTP: _gwHTTPPort, + WalletExtensionPortWS: _gwWSPort, + NodeRPCHTTPAddress: validatorHTTP, + NodeRPCWebsocketAddress: validatorWS, + LogPath: "sys_out", + VerboseFlag: false, + DBType: "sqlite", + TenChainID: integration.TenChainID, + } + tenGWContainer := container.NewWalletExtensionContainerFromConfig(cfg, s.logger) + go func() { + fmt.Println("Starting Ten Gateway") + err := tenGWContainer.Start() + if err != nil { + s.logger.Error("failed to start ten gateway", "err", err) + panic(err) + } + s.tenGatewayContainer = tenGWContainer + }() } func (s *InMemDevNetwork) CleanUp() { @@ -178,6 +233,14 @@ func (s *InMemDevNetwork) CleanUp() { } }() go s.l1Network.CleanUp() + if s.tenGatewayContainer != nil { + go func() { + err := s.tenGatewayContainer.Stop() + if err != nil { + fmt.Println("failed to stop ten gateway", err.Error()) + } + }() + } s.logger.Info("Waiting for servers to stop.") time.Sleep(3 * time.Second) diff --git a/integration/simulation/devnetwork/node.go b/integration/simulation/devnetwork/node.go index c7ebd1e845..493ea6a415 100644 --- a/integration/simulation/devnetwork/node.go +++ b/integration/simulation/devnetwork/node.go @@ -200,11 +200,16 @@ func (n *InMemNodeOperator) Stop() error { return nil } -func (n *InMemNodeOperator) HostRPCAddress() string { +func (n *InMemNodeOperator) HostRPCWSAddress() string { hostPort := n.config.PortStart + integration.DefaultHostRPCWSOffset + n.operatorIdx return fmt.Sprintf("ws://%s:%d", network.Localhost, hostPort) } +func (n *InMemNodeOperator) HostRPCHTTPAddress() string { + hostPort := n.config.PortStart + integration.DefaultHostRPCHTTPOffset + n.operatorIdx + return fmt.Sprintf("http://%s:%d", network.Localhost, hostPort) +} + func (n *InMemNodeOperator) StopEnclave() error { err := n.enclave.Stop() if err != nil { diff --git a/tools/walletextension/api/utils.go b/tools/walletextension/api/utils.go index f97ffed18a..4ca99ac44d 100644 --- a/tools/walletextension/api/utils.go +++ b/tools/walletextension/api/utils.go @@ -30,9 +30,15 @@ func parseRequest(body []byte) (*common.RPCRequest, error) { // we extract the params into a JSON list var params []interface{} - err = json.Unmarshal(reqJSONMap[common.JSONKeyParams], ¶ms) - if err != nil { - return nil, fmt.Errorf("could not unmarshal params list from JSON-RPC request body: %s ; %w", string(body), err) + // params key is optional in JSON-RPC request + _, exists := reqJSONMap[common.JSONKeyParams] + if exists { + err = json.Unmarshal(reqJSONMap[common.JSONKeyParams], ¶ms) + if err != nil { + return nil, fmt.Errorf("could not unmarshal params list from JSON-RPC request body: %s ; %w", string(body), err) + } + } else { + params = []interface{}{} } return &common.RPCRequest{ diff --git a/tools/walletextension/cache/cache.go b/tools/walletextension/cache/cache.go index e1ebc81223..74be7c3a7b 100644 --- a/tools/walletextension/cache/cache.go +++ b/tools/walletextension/cache/cache.go @@ -16,24 +16,30 @@ const ( shortCacheTTL = 1 * time.Second ) +// Define a struct to hold the cache TTL and auth requirement +type RPCMethodCacheConfig struct { + CacheTTL time.Duration + RequiresAuth bool +} + // CacheableRPCMethods is a map of Ethereum JSON-RPC methods that can be cached and their TTL -var cacheableRPCMethods = map[string]time.Duration{ +var cacheableRPCMethods = map[string]RPCMethodCacheConfig{ // Ethereum JSON-RPC methods that can be cached long time - "eth_getBlockByNumber": longCacheTTL, - "eth_getBlockByHash": longCacheTTL, - "eth_getTransactionByHash": longCacheTTL, - "eth_chainId": longCacheTTL, + "eth_getBlockByNumber": {longCacheTTL, false}, + "eth_getBlockByHash": {longCacheTTL, false}, + "eth_getTransactionByHash": {longCacheTTL, true}, + "eth_chainId": {longCacheTTL, false}, // Ethereum JSON-RPC methods that can be cached short time - "eth_blockNumber": shortCacheTTL, - "eth_getCode": shortCacheTTL, - // "eth_getBalance": shortCacheTTL,// excluded for test: gen_cor_059 - "eth_getTransactionReceipt": shortCacheTTL, - "eth_call": shortCacheTTL, - "eth_gasPrice": shortCacheTTL, - // "eth_getTransactionCount": shortCacheTTL, // excluded for test: gen_cor_009 - "eth_estimateGas": shortCacheTTL, - "eth_feeHistory": shortCacheTTL, + "eth_blockNumber": {shortCacheTTL, false}, + "eth_getCode": {shortCacheTTL, true}, + // "eth_getBalance": {longCacheTTL, true},// excluded for test: gen_cor_059 + "eth_getTransactionReceipt": {shortCacheTTL, true}, + "eth_call": {shortCacheTTL, true}, + "eth_gasPrice": {shortCacheTTL, false}, + // "eth_getTransactionCount": {longCacheTTL, true}, // excluded for test: gen_cor_009 + "eth_estimateGas": {shortCacheTTL, true}, + "eth_feeHistory": {shortCacheTTL, false}, } type Cache interface { @@ -46,31 +52,37 @@ func NewCache(logger log.Logger) (Cache, error) { } // IsCacheable checks if the given RPC request is cacheable and returns the cache key and TTL -func IsCacheable(key *common.RPCRequest) (bool, string, time.Duration) { +func IsCacheable(key *common.RPCRequest, encryptionToken string) (bool, string, time.Duration) { if key == nil || key.Method == "" { return false, "", 0 } // Check if the method is cacheable - ttl, isCacheable := cacheableRPCMethods[key.Method] + methodCacheConfig, isCacheable := cacheableRPCMethods[key.Method] + + // If method does not need to be authenticated, we can don't need to cache it per user + if !methodCacheConfig.RequiresAuth { + encryptionToken = "" + } if isCacheable { - // method is cacheable - select cache key + // method is cacheable - select cache key and ttl switch key.Method { - case "eth_getCode", "eth_getBalance", "eth_getTransactionCount", "eth_estimateGas", "eth_call": + case "eth_getCode", "eth_getBalance", "eth_estimateGas", "eth_call": if len(key.Params) == 1 || len(key.Params) == 2 && (key.Params[1] == "latest" || key.Params[1] == "pending") { - return true, GenerateCacheKey(key.Method, key.Params...), ttl + return true, GenerateCacheKey(key.Method, encryptionToken, key.Params...), methodCacheConfig.CacheTTL } // in this case, we have a fixed block number, and we can cache the result for a long time - return true, GenerateCacheKey(key.Method, key.Params...), longCacheTTL + return true, GenerateCacheKey(key.Method, encryptionToken, key.Params...), longCacheTTL case "eth_feeHistory": if len(key.Params) == 2 || len(key.Params) == 3 && (key.Params[2] == "latest" || key.Params[2] == "pending") { - return true, GenerateCacheKey(key.Method, key.Params...), ttl + return true, GenerateCacheKey(key.Method, encryptionToken, key.Params...), methodCacheConfig.CacheTTL } // in this case, we have a fixed block number, and we can cache the result for a long time - return true, GenerateCacheKey(key.Method, key.Params...), longCacheTTL + return true, GenerateCacheKey(key.Method, encryptionToken, key.Params...), longCacheTTL + default: - return true, GenerateCacheKey(key.Method, key.Params...), ttl + return true, GenerateCacheKey(key.Method, encryptionToken, key.Params...), methodCacheConfig.CacheTTL } } @@ -78,8 +90,9 @@ func IsCacheable(key *common.RPCRequest) (bool, string, time.Duration) { return false, "", 0 } -// GenerateCacheKey generates a cache key for the given method and parameters -func GenerateCacheKey(method string, params ...interface{}) string { +// GenerateCacheKey generates a cache key for the given method, encryptionToken and parameters +// encryptionToken is used to generate a unique cache key for each user and empty string should be used for public data +func GenerateCacheKey(method string, encryptionToken string, params ...interface{}) string { // Serialize parameters paramBytes, err := json.Marshal(params) if err != nil { @@ -87,7 +100,7 @@ func GenerateCacheKey(method string, params ...interface{}) string { } // Concatenate method name and parameters - rawKey := method + string(paramBytes) + rawKey := method + encryptionToken + string(paramBytes) // Optional: Apply hashing hasher := sha256.New() diff --git a/tools/walletextension/cache/cache_test.go b/tools/walletextension/cache/cache_test.go index d5f714d6d6..30012f9dcf 100644 --- a/tools/walletextension/cache/cache_test.go +++ b/tools/walletextension/cache/cache_test.go @@ -17,11 +17,17 @@ var tests = map[string]func(t *testing.T){ } var cacheTests = map[string]func(cache Cache, t *testing.T){ - "testResultsAreCached": testResultsAreCached, - "testCacheTTL": testCacheTTL, + "testResultsAreCached": testResultsAreCached, + "testCacheTTL": testCacheTTL, + "testCachingAuthenticatedMethods": testCachingAuthenticatedMethods, + "testCachingNonAuthenticatedMethods": testCachingNonAuthenticatedMethods, } -var nonCacheableMethods = []string{"eth_sendrawtransaction", "eth_sendtransaction", "join", "authenticate"} +var ( + nonCacheableMethods = []string{"eth_sendrawtransaction", "eth_sendtransaction", "join", "authenticate"} + encryptionToken = "test" + encryptionToken2 = "not-test" +) func TestGatewayCaching(t *testing.T) { for name, test := range tests { @@ -47,7 +53,7 @@ func TestGatewayCaching(t *testing.T) { func testCacheableMethods(t *testing.T) { for method := range cacheableRPCMethods { key := &common.RPCRequest{Method: method} - isCacheable, _, _ := IsCacheable(key) + isCacheable, _, _ := IsCacheable(key, encryptionToken) if isCacheable != true { t.Errorf("method %s should be cacheable", method) } @@ -58,7 +64,7 @@ func testCacheableMethods(t *testing.T) { func testNonCacheableMethods(t *testing.T) { for _, method := range nonCacheableMethods { key := &common.RPCRequest{Method: method} - isCacheable, _, _ := IsCacheable(key) + isCacheable, _, _ := IsCacheable(key, encryptionToken) if isCacheable == true { t.Errorf("method %s should not be cacheable", method) } @@ -70,13 +76,13 @@ func testMethodsWithLatestOrPendingParameter(t *testing.T) { methods := []string{"eth_getCode", "eth_estimateGas", "eth_call"} for _, method := range methods { key := &common.RPCRequest{Method: method, Params: []interface{}{"0x123", "latest"}} - _, _, ttl := IsCacheable(key) + _, _, ttl := IsCacheable(key, encryptionToken) if ttl != shortCacheTTL { t.Errorf("method %s with latest parameter should have TTL of %s, but %s received", method, shortCacheTTL, ttl) } key = &common.RPCRequest{Method: method, Params: []interface{}{"0x123", "pending"}} - _, _, ttl = IsCacheable(key) + _, _, ttl = IsCacheable(key, encryptionToken) if ttl != shortCacheTTL { t.Errorf("method %s with pending parameter should have TTL of %s, but %s received", method, shortCacheTTL, ttl) } @@ -88,7 +94,7 @@ func testResultsAreCached(cache Cache, t *testing.T) { // prepare a cacheable request and imaginary response req := &common.RPCRequest{Method: "eth_getBlockByNumber", Params: []interface{}{"0x123"}} res := map[string]interface{}{"result": "block"} - isCacheable, key, ttl := IsCacheable(req) + isCacheable, key, ttl := IsCacheable(req, encryptionToken) if !isCacheable { t.Errorf("method %s should be cacheable", req.Method) } @@ -112,7 +118,7 @@ func testResultsAreCached(cache Cache, t *testing.T) { func testCacheTTL(cache Cache, t *testing.T) { req := &common.RPCRequest{Method: "eth_blockNumber", Params: []interface{}{"0x123"}} res := map[string]interface{}{"result": "100"} - isCacheable, key, ttl := IsCacheable(req) + isCacheable, key, ttl := IsCacheable(req, encryptionToken) if !isCacheable { t.Errorf("method %s should be cacheable", req.Method) @@ -145,3 +151,99 @@ func testCacheTTL(cache Cache, t *testing.T) { t.Errorf("value should not be in the cache after TTL") } } + +func testCachingAuthenticatedMethods(cache Cache, t *testing.T) { + // eth_getTransactionByHash + authMethods := []string{ + "eth_getTransactionByHash", + "eth_getCode", + "eth_getTransactionReceipt", + "eth_call", + "eth_estimateGas", + } + for _, method := range authMethods { + req := &common.RPCRequest{Method: method, Params: []interface{}{"0x123"}} + res := map[string]interface{}{"result": "transaction"} + + // store the response in cache for the first user using encryptionToken + isCacheable, key, ttl := IsCacheable(req, encryptionToken) + + if !isCacheable { + t.Errorf("method %s should be cacheable", req.Method) + } + + // set the response in the cache with a TTL + if !cache.Set(key, res, ttl) { + t.Errorf("failed to set value in cache for %s", req) + } + time.Sleep(50 * time.Millisecond) // wait for the cache to be set + + // check if the value is in the cache + value, ok := cache.Get(key) + if !ok { + t.Errorf("failed to get cached value for %s", req) + } + + // for the first error we should have the value in cache + if !reflect.DeepEqual(value, res) { + t.Errorf("expected %v, got %v", res, value) + } + + // now check with the second user asking for the same request, but with a different encryptionToken + _, key2, _ := IsCacheable(req, encryptionToken2) + + _, okSecondUser := cache.Get(key2) + if okSecondUser { + t.Errorf("another user should not see a value the first user cached %s", req) + } + } +} + +func testCachingNonAuthenticatedMethods(cache Cache, t *testing.T) { + // eth_getTransactionByHash + nonAuthMethods := []string{ + "eth_getBlockByNumber", + "eth_getBlockByHash", + "eth_chainId", + "eth_blockNumber", + "eth_gasPrice", + "eth_feeHistory", + } + + for _, method := range nonAuthMethods { + req := &common.RPCRequest{Method: method, Params: []interface{}{"0x123"}} + res := map[string]interface{}{"result": "transaction"} + + // store the response in cache for the first user using encryptionToken + isCacheable, key, ttl := IsCacheable(req, encryptionToken) + + if !isCacheable { + t.Errorf("method %s should be cacheable", req.Method) + } + + // set the response in the cache with a TTL + if !cache.Set(key, res, ttl) { + t.Errorf("failed to set value in cache for %s", req) + } + time.Sleep(50 * time.Millisecond) // wait for the cache to be set + + // check if the value is in the cache + value, ok := cache.Get(key) + if !ok { + t.Errorf("failed to get cached value for %s", req) + } + + // for the first error we should have the value in cache + if !reflect.DeepEqual(value, res) { + t.Errorf("expected %v, got %v", res, value) + } + + // now check with the second user asking for the same request, but with a different encryptionToken + _, key2, _ := IsCacheable(req, encryptionToken2) + + _, okSecondUser := cache.Get(key2) + if !okSecondUser { + t.Errorf("another user should see a value the first user cached %s", req) + } + } +} diff --git a/tools/walletextension/useraccountmanager/user_account_manager.go b/tools/walletextension/useraccountmanager/user_account_manager.go index b450181801..9af05a8779 100644 --- a/tools/walletextension/useraccountmanager/user_account_manager.go +++ b/tools/walletextension/useraccountmanager/user_account_manager.go @@ -21,7 +21,7 @@ type UserAccountManager struct { hostRPCBinAddrHTTP string hostRPCBinAddrWS string logger gethlog.Logger - mu sync.RWMutex + mu sync.Mutex } func NewUserAccountManager(unauthenticatedClient rpc.Client, logger gethlog.Logger, storage storage.Storage, hostRPCBindAddrHTTP string, hostRPCBindAddrWS string) UserAccountManager { @@ -37,8 +37,8 @@ func NewUserAccountManager(unauthenticatedClient rpc.Client, logger gethlog.Logg // AddAndReturnAccountManager adds new UserAccountManager if it doesn't exist and returns it, if UserAccountManager already exists for that user just return it func (m *UserAccountManager) AddAndReturnAccountManager(userID string) *accountmanager.AccountManager { - m.mu.RLock() - defer m.mu.RUnlock() + m.mu.Lock() + defer m.mu.Unlock() existingUserAccountManager, exists := m.userAccountManager[userID] if exists { return existingUserAccountManager diff --git a/tools/walletextension/wallet_extension.go b/tools/walletextension/wallet_extension.go index a130ba7fc6..f8625f6a66 100644 --- a/tools/walletextension/wallet_extension.go +++ b/tools/walletextension/wallet_extension.go @@ -107,20 +107,20 @@ func (w *WalletExtension) ProxyEthRequest(request *common.RPCRequest, conn userc // start measuring time for request requestStartTime := time.Now() - //// Check if the request is in the cache - //isCacheable, key, ttl := cache.IsCacheable(request) - // - //// in case of cache hit return the response from the cache - //if isCacheable { - // if value, ok := w.cache.Get(key); ok { - // requestEndTime := time.Now() - // duration := requestEndTime.Sub(requestStartTime) - // w.fileLogger.Info(fmt.Sprintf("Request method: %s, request params: %s, encryptionToken of sender: %s, response: %s, duration: %d ", request.Method, request.Params, hexUserID, value, duration.Milliseconds())) - // // adjust requestID - // value[common.JSONKeyID] = request.ID - // return value, nil - // } - //} + // Check if the request is in the cache + isCacheable, key, ttl := cache.IsCacheable(request, hexUserID) + + // in case of cache hit return the response from the cache + if isCacheable { + if value, ok := w.cache.Get(key); ok { + requestEndTime := time.Now() + duration := requestEndTime.Sub(requestStartTime) + w.fileLogger.Info(fmt.Sprintf("Request method: %s, request params: %s, encryptionToken of sender: %s, response: %s, duration: %d ", request.Method, request.Params, hexUserID, value, duration.Milliseconds())) + // adjust requestID + value[common.JSONKeyID] = request.ID + return value, nil + } + } // proxyRequest will find the correct client to proxy the request (or try them all if appropriate) var rpcResp interface{} @@ -166,9 +166,9 @@ func (w *WalletExtension) ProxyEthRequest(request *common.RPCRequest, conn userc w.fileLogger.Info(fmt.Sprintf("Request method: %s, request params: %s, encryptionToken of sender: %s, response: %s, duration: %d ", request.Method, request.Params, hexUserID, response, duration.Milliseconds())) // if the request is cacheable, store the response in the cache - //if isCacheable { - // w.cache.Set(key, response, ttl) - //} + if isCacheable { + w.cache.Set(key, response, ttl) + } return response, nil }