diff --git a/pkg/solana/chain.go b/pkg/solana/chain.go index 56f37bc07..c47e1cf1b 100644 --- a/pkg/solana/chain.go +++ b/pkg/solana/chain.go @@ -22,11 +22,12 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/types/core" - - mn "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/multinode" + "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" + mn "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/multinode" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" "github.com/smartcontractkit/chainlink-solana/pkg/solana/monitor" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" ) @@ -90,7 +91,7 @@ type chain struct { // if multiNode is enabled, the clientCache will not be used multiNode *mn.MultiNode[mn.StringID, *client.MultiNodeClient] - txSender *mn.TransactionSender[*solanago.Transaction, mn.StringID, *client.MultiNodeClient] + txSender *mn.TransactionSender[*solanago.Transaction, *client.SendTxResult, mn.StringID, *client.MultiNodeClient] // tracking node chain id for verification clientCache map[string]*verifiedCachedClient // map URL -> {client, chainId} [mainnet/testnet/devnet/localnet] @@ -230,6 +231,12 @@ func newChain(id string, cfg *config.TOMLConfig, ks loop.Keystore, lggr logger.L clientCache: map[string]*verifiedCachedClient{}, } + var tc internal.Loader[client.ReaderWriter] = utils.NewLazyLoad(func() (client.ReaderWriter, error) { return ch.getClient() }) + var bc internal.Loader[monitor.BalanceClient] = utils.NewLazyLoad(func() (monitor.BalanceClient, error) { return ch.getClient() }) + + // txm will default to sending transactions using a single RPC client if sendTx is nil + var sendTx func(ctx context.Context, tx *solanago.Transaction) (solanago.Signature, error) + if cfg.MultiNode.Enabled() { chainFamily := "solana" @@ -268,18 +275,12 @@ func newChain(id string, cfg *config.TOMLConfig, ks loop.Keystore, lggr logger.L mnCfg.DeathDeclarationDelay(), ) - // TODO: implement error classification; move logic to separate file if large - // TODO: might be useful to reference anza-xyz/agave@master/sdk/src/transaction/error.rs - classifySendError := func(tx *solanago.Transaction, err error) mn.SendTxReturnCode { - return 0 // TODO ClassifySendError(err, clientErrors, logger.Sugared(logger.Nop()), tx, common.Address{}, false) - } - - txSender := mn.NewTransactionSender[*solanago.Transaction, mn.StringID, *client.MultiNodeClient]( + txSender := mn.NewTransactionSender[*solanago.Transaction, *client.SendTxResult, mn.StringID, *client.MultiNodeClient]( lggr, mn.StringID(id), chainFamily, multiNode, - classifySendError, + client.NewSendTxResult, 0, // use the default value provided by the implementation ) @@ -288,13 +289,24 @@ func newChain(id string, cfg *config.TOMLConfig, ks loop.Keystore, lggr logger.L // clientCache will not be used if multinode is enabled ch.clientCache = nil - } - tc := func() (client.ReaderWriter, error) { - return ch.getClient() + // Send tx using MultiNode transaction sender + sendTx = func(ctx context.Context, tx *solanago.Transaction) (solanago.Signature, error) { + result := ch.txSender.SendTransaction(ctx, tx) + if result == nil { + return solanago.Signature{}, errors.New("tx sender returned nil result") + } + if result.Error() != nil { + return solanago.Signature{}, result.Error() + } + return result.Signature(), result.TxError() + } + + tc = internal.NewLoader[client.ReaderWriter](func() (client.ReaderWriter, error) { return ch.multiNode.SelectRPC() }) + bc = internal.NewLoader[monitor.BalanceClient](func() (monitor.BalanceClient, error) { return ch.multiNode.SelectRPC() }) } - ch.txm = txm.NewTxm(ch.id, tc, cfg, ks, lggr) - bc := func() (monitor.BalanceClient, error) { return ch.getClient() } + + ch.txm = txm.NewTxm(ch.id, tc, sendTx, cfg, ks, lggr) ch.balanceMonitor = monitor.NewBalanceMonitor(ch.id, cfg, lggr, ks, bc) return &ch, nil } diff --git a/pkg/solana/chain_test.go b/pkg/solana/chain_test.go index 9f32096d6..b705860c9 100644 --- a/pkg/solana/chain_test.go +++ b/pkg/solana/chain_test.go @@ -11,20 +11,22 @@ import ( "strings" "sync" "testing" + "time" "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/programs/system" "github.com/gagliardetto/solana-go/rpc" "github.com/google/uuid" + "github.com/smartcontractkit/chainlink-common/pkg/config" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.uber.org/zap/zapcore" - "github.com/smartcontractkit/chainlink-common/pkg/config" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" + mn "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/multinode" solcfg "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" @@ -126,61 +128,6 @@ func TestSolanaChain_GetClient(t *testing.T) { assert.NoError(t, err) } -func TestSolanaChain_MultiNode_GetClient(t *testing.T) { - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - out := fmt.Sprintf(TestSolanaGenesisHashTemplate, client.MainnetGenesisHash) // mainnet genesis hash - if !strings.Contains(r.URL.Path, "/mismatch") { - // devnet gensis hash - out = fmt.Sprintf(TestSolanaGenesisHashTemplate, client.DevnetGenesisHash) - } - _, err := w.Write([]byte(out)) - require.NoError(t, err) - })) - defer mockServer.Close() - - ch := solcfg.Chain{} - ch.SetDefaults() - mn := solcfg.MultiNodeConfig{ - MultiNode: solcfg.MultiNode{ - Enabled: ptr(true), - }, - } - mn.SetDefaults() - - cfg := &solcfg.TOMLConfig{ - ChainID: ptr("devnet"), - Chain: ch, - MultiNode: mn, - } - cfg.Nodes = []*solcfg.Node{ - { - Name: ptr("devnet"), - URL: config.MustParseURL(mockServer.URL + "/1"), - }, - { - Name: ptr("devnet"), - URL: config.MustParseURL(mockServer.URL + "/2"), - }, - } - - testChain, err := newChain("devnet", cfg, nil, logger.Test(t)) - require.NoError(t, err) - - err = testChain.Start(tests.Context(t)) - require.NoError(t, err) - defer func() { - closeErr := testChain.Close() - require.NoError(t, closeErr) - }() - - selectedClient, err := testChain.getClient() - assert.NoError(t, err) - - id, err := selectedClient.ChainID(tests.Context(t)) - assert.NoError(t, err) - assert.Equal(t, "devnet", id.String()) -} - func TestSolanaChain_VerifiedClient(t *testing.T) { ctx := tests.Context(t) called := false @@ -371,3 +318,290 @@ func TestChain_Transact(t *testing.T) { require.NoError(t, err) assert.Equal(t, fees.ComputeUnitLimit(500), limit) } + +func TestSolanaChain_MultiNode_GetClient(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := fmt.Sprintf(TestSolanaGenesisHashTemplate, client.MainnetGenesisHash) // mainnet genesis hash + if !strings.Contains(r.URL.Path, "/mismatch") { + // devnet gensis hash + out = fmt.Sprintf(TestSolanaGenesisHashTemplate, client.DevnetGenesisHash) + } + _, err := w.Write([]byte(out)) + require.NoError(t, err) + })) + defer mockServer.Close() + + ch := solcfg.Chain{} + ch.SetDefaults() + mnCfg := solcfg.MultiNodeConfig{ + MultiNode: solcfg.MultiNode{ + Enabled: ptr(true), + }, + } + mnCfg.SetDefaults() + + cfg := &solcfg.TOMLConfig{ + ChainID: ptr("devnet"), + Chain: ch, + MultiNode: mnCfg, + } + cfg.Nodes = []*solcfg.Node{ + { + Name: ptr("devnet"), + URL: config.MustParseURL(mockServer.URL + "/1"), + }, + { + Name: ptr("devnet"), + URL: config.MustParseURL(mockServer.URL + "/2"), + }, + } + + testChain, err := newChain("devnet", cfg, nil, logger.Test(t)) + require.NoError(t, err) + + err = testChain.Start(tests.Context(t)) + require.NoError(t, err) + defer func() { + closeErr := testChain.Close() + require.NoError(t, closeErr) + }() + + selectedClient, err := testChain.getClient() + assert.NoError(t, err) + + id, err := selectedClient.ChainID(tests.Context(t)) + assert.NoError(t, err) + assert.Equal(t, "devnet", id.String()) +} + +func TestChain_MultiNode_TransactionSender(t *testing.T) { + ctx := tests.Context(t) + url := client.SetupLocalSolNode(t) + lgr, _ := logger.TestObserved(t, zapcore.DebugLevel) + + // transaction parameters + sender, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + receiver, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + client.FundTestAccounts(t, solana.PublicKeySlice{sender.PublicKey()}, url) + + // configuration + cfg := solcfg.NewDefault() + cfg.MultiNode.MultiNode.Enabled = ptr(true) + cfg.Nodes = append(cfg.Nodes, + &solcfg.Node{ + Name: ptr("localnet-" + t.Name() + "-primary"), + URL: config.MustParseURL(client.SetupLocalSolNode(t)), + SendOnly: false, + }) + + // mocked keystore + mkey := mocks.NewSimpleKeystore(t) + c, err := newChain("localnet", cfg, mkey, lgr) + require.NoError(t, err) + require.NoError(t, c.Start(ctx)) + defer func() { + require.NoError(t, c.Close()) + }() + + createTx := func(from solana.PrivateKey, to solana.PrivateKey) *solana.Transaction { + cl, err := c.getClient() + require.NoError(t, err) + + hash, hashErr := cl.LatestBlockhash(tests.Context(t)) + assert.NoError(t, hashErr) + + tx, txErr := solana.NewTransaction( + []solana.Instruction{ + system.NewTransferInstruction( + 1, + from.PublicKey(), + to.PublicKey(), + ).Build(), + }, + hash.Value.Blockhash, + solana.TransactionPayer(from.PublicKey()), + ) + assert.NoError(t, txErr) + _, signErr := tx.Sign( + func(key solana.PublicKey) *solana.PrivateKey { + if from.PublicKey().Equals(key) { + return &from + } + return nil + }, + ) + assert.NoError(t, signErr) + return tx + } + + t.Run("successful transaction", func(t *testing.T) { + // Send tx using transaction sender + result := c.txSender.SendTransaction(ctx, createTx(sender, receiver)) + require.NotNil(t, result) + require.NoError(t, result.Error()) + require.Equal(t, mn.Successful, result.Code()) + require.NotEmpty(t, result.Signature()) + }) + + t.Run("unsigned transaction error", func(t *testing.T) { + // create + sign transaction + unsignedTx := func(to solana.PublicKey) *solana.Transaction { + cl, err := c.getClient() + require.NoError(t, err) + + hash, hashErr := cl.LatestBlockhash(tests.Context(t)) + assert.NoError(t, hashErr) + + tx, txErr := solana.NewTransaction( + []solana.Instruction{ + system.NewTransferInstruction( + 1, + sender.PublicKey(), + to, + ).Build(), + }, + hash.Value.Blockhash, + solana.TransactionPayer(sender.PublicKey()), + ) + assert.NoError(t, txErr) + return tx + } + + // Send tx using transaction sender + result := c.txSender.SendTransaction(ctx, unsignedTx(receiver.PublicKey())) + require.NotNil(t, result) + require.NoError(t, result.Error()) + require.Error(t, result.TxError()) + require.Equal(t, mn.Fatal, result.Code()) + require.Empty(t, result.Signature()) + }) + + t.Run("empty transaction", func(t *testing.T) { + result := c.txSender.SendTransaction(ctx, &solana.Transaction{}) + require.NotNil(t, result) + require.NoError(t, result.Error()) + require.Error(t, result.TxError()) + require.Equal(t, mn.Fatal, result.Code()) + require.Empty(t, result.Signature()) + }) +} + +func TestSolanaChain_MultiNode_Txm(t *testing.T) { + cfg := solcfg.NewDefault() + cfg.MultiNode.MultiNode.Enabled = ptr(true) + cfg.Nodes = []*solcfg.Node{ + { + Name: ptr("primary"), + URL: config.MustParseURL(client.SetupLocalSolNode(t)), + }, + } + + // setup keys + key, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + pubKey := key.PublicKey() + + // setup receiver key + privKeyReceiver, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + pubKeyReceiver := privKeyReceiver.PublicKey() + + // mocked keystore + mkey := mocks.NewSimpleKeystore(t) + mkey.On("Sign", mock.Anything, pubKey.String(), mock.Anything).Return(func(_ context.Context, _ string, data []byte) []byte { + sig, _ := key.Sign(data) + return sig[:] + }, nil) + mkey.On("Sign", mock.Anything, pubKeyReceiver.String(), mock.Anything).Return([]byte{}, config.KeyNotFoundError{ID: pubKeyReceiver.String(), KeyType: "Solana"}) + + testChain, err := newChain("localnet", cfg, mkey, logger.Test(t)) + require.NoError(t, err) + + err = testChain.Start(tests.Context(t)) + require.NoError(t, err) + defer func() { + require.NoError(t, testChain.Close()) + }() + + // fund keys + client.FundTestAccounts(t, []solana.PublicKey{pubKey}, cfg.Nodes[0].URL.String()) + + // track initial balance + selectedClient, err := testChain.getClient() + require.NoError(t, err) + receiverBal, err := selectedClient.Balance(tests.Context(t), pubKeyReceiver) + assert.NoError(t, err) + assert.Equal(t, uint64(0), receiverBal) + + createTx := func(signer solana.PublicKey, sender solana.PublicKey, receiver solana.PublicKey, amt uint64) *solana.Transaction { + selectedClient, err = testChain.getClient() + assert.NoError(t, err) + hash, hashErr := selectedClient.LatestBlockhash(tests.Context(t)) + assert.NoError(t, hashErr) + tx, txErr := solana.NewTransaction( + []solana.Instruction{ + system.NewTransferInstruction( + amt, + sender, + receiver, + ).Build(), + }, + hash.Value.Blockhash, + solana.TransactionPayer(signer), + ) + require.NoError(t, txErr) + return tx + } + + // Send funds twice, along with an invalid transaction + require.NoError(t, testChain.txm.Enqueue(tests.Context(t), "test_success", createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL))) + + // Wait for new block hash + currentBh, err := selectedClient.LatestBlockhash(tests.Context(t)) + require.NoError(t, err) + timeout := time.After(time.Minute) + +NewBlockHash: + for { + select { + case <-timeout: + t.Fatal("timed out waiting for new block hash") + default: + newBh, bhErr := selectedClient.LatestBlockhash(tests.Context(t)) + require.NoError(t, bhErr) + if newBh.Value.LastValidBlockHeight > currentBh.Value.LastValidBlockHeight { + break NewBlockHash + } + } + } + + require.NoError(t, testChain.txm.Enqueue(tests.Context(t), "test_success_2", createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL))) + require.Error(t, testChain.txm.Enqueue(tests.Context(t), "test_invalidSigner", createTx(pubKeyReceiver, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL))) // cannot sign tx before enqueuing + + // wait for all txes to finish + ctx, cancel := context.WithCancel(tests.Context(t)) + t.Cleanup(cancel) + ticker := time.NewTicker(time.Second) + defer ticker.Stop() +loop: + for { + select { + case <-ctx.Done(): + assert.Equal(t, 0, testChain.txm.InflightTxs()) + break loop + case <-ticker.C: + if testChain.txm.InflightTxs() == 0 { + cancel() // exit for loop + } + } + } + + // verify funds were transferred through transaction sender + selectedClient, err = testChain.getClient() + assert.NoError(t, err) + receiverBal, err = selectedClient.Balance(tests.Context(t), pubKeyReceiver) + assert.NoError(t, err) + require.Equal(t, 2*solana.LAMPORTS_PER_SOL, receiverBal) +} diff --git a/pkg/solana/client/classify_errors.go b/pkg/solana/client/classify_errors.go new file mode 100644 index 000000000..ae3402694 --- /dev/null +++ b/pkg/solana/client/classify_errors.go @@ -0,0 +1,75 @@ +package client + +import ( + "regexp" + + "github.com/gagliardetto/solana-go" + + mn "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/multinode" +) + +// Solana error patters +// https://github.com/anza-xyz/agave/blob/master/sdk/src/transaction/error.rs +var ( + ErrAccountInUse = regexp.MustCompile(`Account in use`) + ErrAccountLoadedTwice = regexp.MustCompile(`Account loaded twice`) + ErrAccountNotFound = regexp.MustCompile(`Attempt to debit an account but found no record of a prior credit\.`) + ErrProgramAccountNotFound = regexp.MustCompile(`Attempt to load a program that does not exist`) + ErrInsufficientFundsForFee = regexp.MustCompile(`Insufficient funds for fee`) + ErrInvalidAccountForFee = regexp.MustCompile(`This account may not be used to pay transaction fees`) + ErrAlreadyProcessed = regexp.MustCompile(`This transaction has already been processed`) + ErrBlockhashNotFound = regexp.MustCompile(`Blockhash not found`) + ErrInstructionError = regexp.MustCompile(`Error processing Instruction \d+: .+`) + ErrCallChainTooDeep = regexp.MustCompile(`Loader call chain is too deep`) + ErrMissingSignatureForFee = regexp.MustCompile(`Transaction requires a fee but has no signature present`) + ErrInvalidAccountIndex = regexp.MustCompile(`Transaction contains an invalid account reference`) + ErrSignatureFailure = regexp.MustCompile(`Transaction did not pass signature verification`) + ErrInvalidProgramForExecution = regexp.MustCompile(`This program may not be used for executing instructions`) + ErrSanitizeFailure = regexp.MustCompile(`Transaction failed to sanitize accounts offsets correctly`) + ErrClusterMaintenance = regexp.MustCompile(`Transactions are currently disabled due to cluster maintenance`) + ErrAccountBorrowOutstanding = regexp.MustCompile(`Transaction processing left an account with an outstanding borrowed reference`) + ErrWouldExceedMaxBlockCostLimit = regexp.MustCompile(`Transaction would exceed max Block Cost Limit`) + ErrUnsupportedVersion = regexp.MustCompile(`Transaction version is unsupported`) + ErrInvalidWritableAccount = regexp.MustCompile(`Transaction loads a writable account that cannot be written`) + ErrWouldExceedMaxAccountCostLimit = regexp.MustCompile(`Transaction would exceed max account limit within the block`) + ErrWouldExceedAccountDataBlockLimit = regexp.MustCompile(`Transaction would exceed account data limit within the block`) + ErrTooManyAccountLocks = regexp.MustCompile(`Transaction locked too many accounts`) + ErrAddressLookupTableNotFound = regexp.MustCompile(`Transaction loads an address table account that doesn't exist`) + ErrInvalidAddressLookupTableOwner = regexp.MustCompile(`Transaction loads an address table account with an invalid owner`) + ErrInvalidAddressLookupTableData = regexp.MustCompile(`Transaction loads an address table account with invalid data`) + ErrInvalidAddressLookupTableIndex = regexp.MustCompile(`Transaction address table lookup uses an invalid index`) + ErrInvalidRentPayingAccount = regexp.MustCompile(`Transaction leaves an account with a lower balance than rent-exempt minimum`) + ErrWouldExceedMaxVoteCostLimit = regexp.MustCompile(`Transaction would exceed max Vote Cost Limit`) + ErrWouldExceedAccountDataTotalLimit = regexp.MustCompile(`Transaction would exceed total account data limit`) + ErrDuplicateInstruction = regexp.MustCompile(`Transaction contains a duplicate instruction \(\d+\) that is not allowed`) + ErrInsufficientFundsForRent = regexp.MustCompile(`Transaction results in an account \(\d+\) with insufficient funds for rent`) + ErrMaxLoadedAccountsDataSizeExceeded = regexp.MustCompile(`Transaction exceeded max loaded accounts data size cap`) + ErrInvalidLoadedAccountsDataSizeLimit = regexp.MustCompile(`LoadedAccountsDataSizeLimit set for transaction must be greater than 0\.`) + ErrResanitizationNeeded = regexp.MustCompile(`Sanitized transaction differed before/after feature activation\. Needs to be resanitized\.`) + ErrProgramExecutionTemporarilyRestricted = regexp.MustCompile(`Execution of the program referenced by account at index \d+ is temporarily restricted\.`) + ErrUnbalancedTransaction = regexp.MustCompile(`Sum of account balances before and after transaction do not match`) + ErrProgramCacheHitMaxLimit = regexp.MustCompile(`Program cache hit max limit`) +) + +// errCodes maps regex patterns to their corresponding return code +// errors are considered Retryable by default if not in this map +var errCodes = map[*regexp.Regexp]mn.SendTxReturnCode{ + ErrSanitizeFailure: mn.Fatal, // Transaction formatting is invalid and cannot be processed or retried + ErrAlreadyProcessed: mn.TransactionAlreadyKnown, // Transaction was already processed and thus known by the RPC + ErrInsufficientFundsForFee: mn.InsufficientFunds, // Transaction was rejected due to insufficient funds for gas fees +} + +// ClassifySendError returns the corresponding return code based on the error. +func ClassifySendError(_ *solana.Transaction, err error) mn.SendTxReturnCode { + if err == nil { + return mn.Successful + } + + errMsg := err.Error() + for pattern, code := range errCodes { + if pattern.MatchString(errMsg) { + return code + } + } + return mn.Retryable +} diff --git a/pkg/solana/client/classify_errors_test.go b/pkg/solana/client/classify_errors_test.go new file mode 100644 index 000000000..29d6a3bc0 --- /dev/null +++ b/pkg/solana/client/classify_errors_test.go @@ -0,0 +1,69 @@ +package client + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + mn "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/multinode" +) + +func TestClassifySendError(t *testing.T) { + tests := []struct { + errMsg string + expectedCode mn.SendTxReturnCode + }{ + // Static error cases + {"Account in use", mn.Retryable}, + {"Account loaded twice", mn.Retryable}, + {"Attempt to debit an account but found no record of a prior credit.", mn.Retryable}, + {"Attempt to load a program that does not exist", mn.Retryable}, + {"Insufficient funds for fee", mn.InsufficientFunds}, + {"This account may not be used to pay transaction fees", mn.Retryable}, + {"This transaction has already been processed", mn.TransactionAlreadyKnown}, + {"Blockhash not found", mn.Retryable}, + {"Loader call chain is too deep", mn.Retryable}, + {"Transaction requires a fee but has no signature present", mn.Retryable}, + {"Transaction contains an invalid account reference", mn.Retryable}, + {"Transaction did not pass signature verification", mn.Retryable}, + {"This program may not be used for executing instructions", mn.Retryable}, + {"Transaction failed to sanitize accounts offsets correctly", mn.Fatal}, + {"Transactions are currently disabled due to cluster maintenance", mn.Retryable}, + {"Transaction processing left an account with an outstanding borrowed reference", mn.Retryable}, + {"Transaction would exceed max Block Cost Limit", mn.Retryable}, + {"Transaction version is unsupported", mn.Retryable}, + {"Transaction loads a writable account that cannot be written", mn.Retryable}, + {"Transaction would exceed max account limit within the block", mn.Retryable}, + {"Transaction would exceed account data limit within the block", mn.Retryable}, + {"Transaction locked too many accounts", mn.Retryable}, + {"Address lookup table not found", mn.Retryable}, + {"Attempted to lookup addresses from an account owned by the wrong program", mn.Retryable}, + {"Attempted to lookup addresses from an invalid account", mn.Retryable}, + {"Address table lookup uses an invalid index", mn.Retryable}, + {"Transaction leaves an account with a lower balance than rent-exempt minimum", mn.Retryable}, + {"Transaction would exceed max Vote Cost Limit", mn.Retryable}, + {"Transaction would exceed total account data limit", mn.Retryable}, + {"Transaction contains a duplicate instruction", mn.Retryable}, + {"Transaction exceeded max loaded accounts data size cap", mn.Retryable}, + {"LoadedAccountsDataSizeLimit set for transaction must be greater than 0.", mn.Retryable}, + {"Sanitized transaction differed before/after feature activation. Needs to be resanitized.", mn.Retryable}, + {"Program cache hit max limit", mn.Retryable}, + + // Dynamic error cases + {"Transaction results in an account (123) with insufficient funds for rent", mn.Retryable}, + {"Error processing Instruction 2: Some error details", mn.Retryable}, + {"Execution of the program referenced by account at index 3 is temporarily restricted.", mn.Retryable}, + + // Edge cases + {"Unknown error message", mn.Retryable}, + {"", mn.Retryable}, // Empty message + } + + for _, tt := range tests { + t.Run(tt.errMsg, func(t *testing.T) { + result := ClassifySendError(nil, errors.New(tt.errMsg)) + assert.Equal(t, tt.expectedCode, result, "Expected %v but got %v for error message: %s", tt.expectedCode, result, tt.errMsg) + }) + } +} diff --git a/pkg/solana/client/multinode/multi_node.go b/pkg/solana/client/multinode/multi_node.go index 8b7efc46b..bd97ebc7b 100644 --- a/pkg/solana/client/multinode/multi_node.go +++ b/pkg/solana/client/multinode/multi_node.go @@ -90,11 +90,9 @@ func (c *MultiNode[CHAIN_ID, RPC]) ChainID() CHAIN_ID { return c.chainID } -func (c *MultiNode[CHAIN_ID, RPC]) DoAll(baseCtx context.Context, do func(ctx context.Context, rpc RPC, isSendOnly bool)) error { +func (c *MultiNode[CHAIN_ID, RPC]) DoAll(ctx context.Context, do func(ctx context.Context, rpc RPC, isSendOnly bool)) error { var err error ok := c.IfNotStopped(func() { - ctx, _ := c.chStop.Ctx(baseCtx) - callsCompleted := 0 for _, n := range c.primaryNodes { select { diff --git a/pkg/solana/client/multinode/transaction_sender.go b/pkg/solana/client/multinode/transaction_sender.go index fbd5acca5..bd11a71a5 100644 --- a/pkg/solana/client/multinode/transaction_sender.go +++ b/pkg/solana/client/multinode/transaction_sender.go @@ -24,52 +24,49 @@ var ( }, []string{"network", "chainId", "invariant"}) ) -// TxErrorClassifier - defines interface of a function that transforms raw RPC error into the SendTxReturnCode enum -// (e.g. Successful, Fatal, Retryable, etc.) -type TxErrorClassifier[TX any] func(tx TX, err error) SendTxReturnCode - -type sendTxResult struct { - Err error - ResultCode SendTxReturnCode +type SendTxResult interface { + Code() SendTxReturnCode + TxError() error + Error() error } const sendTxQuorum = 0.7 // SendTxRPCClient - defines interface of an RPC used by TransactionSender to broadcast transaction -type SendTxRPCClient[TX any] interface { +type SendTxRPCClient[TX any, RESULT SendTxResult] interface { // SendTransaction errors returned should include name or other unique identifier of the RPC - SendTransaction(ctx context.Context, tx TX) error + SendTransaction(ctx context.Context, tx TX) RESULT } -func NewTransactionSender[TX any, CHAIN_ID ID, RPC SendTxRPCClient[TX]]( +func NewTransactionSender[TX any, RESULT SendTxResult, CHAIN_ID ID, RPC SendTxRPCClient[TX, RESULT]]( lggr logger.Logger, chainID CHAIN_ID, chainFamily string, multiNode *MultiNode[CHAIN_ID, RPC], - txErrorClassifier TxErrorClassifier[TX], + newResult func(err error) RESULT, sendTxSoftTimeout time.Duration, -) *TransactionSender[TX, CHAIN_ID, RPC] { +) *TransactionSender[TX, RESULT, CHAIN_ID, RPC] { if sendTxSoftTimeout == 0 { sendTxSoftTimeout = QueryTimeout / 2 } - return &TransactionSender[TX, CHAIN_ID, RPC]{ + return &TransactionSender[TX, RESULT, CHAIN_ID, RPC]{ chainID: chainID, chainFamily: chainFamily, lggr: logger.Sugared(lggr).Named("TransactionSender").With("chainID", chainID.String()), multiNode: multiNode, - txErrorClassifier: txErrorClassifier, + newResult: newResult, sendTxSoftTimeout: sendTxSoftTimeout, chStop: make(services.StopChan), } } -type TransactionSender[TX any, CHAIN_ID ID, RPC SendTxRPCClient[TX]] struct { +type TransactionSender[TX any, RESULT SendTxResult, CHAIN_ID ID, RPC SendTxRPCClient[TX, RESULT]] struct { services.StateMachine chainID CHAIN_ID chainFamily string lggr logger.SugaredLogger multiNode *MultiNode[CHAIN_ID, RPC] - txErrorClassifier TxErrorClassifier[TX] + newResult func(err error) RESULT sendTxSoftTimeout time.Duration // defines max waiting time from first response til responses evaluation wg sync.WaitGroup // waits for all reporting goroutines to finish @@ -94,16 +91,26 @@ type TransactionSender[TX any, CHAIN_ID ID, RPC SendTxRPCClient[TX]] struct { // * If there is at least one terminal error - returns terminal error // * If there is both success and terminal error - returns success and reports invariant violation // * Otherwise, returns any (effectively random) of the errors. -func (txSender *TransactionSender[TX, CHAIN_ID, RPC]) SendTransaction(ctx context.Context, tx TX) (SendTxReturnCode, error) { - txResults := make(chan sendTxResult) - txResultsToReport := make(chan sendTxResult) +func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) SendTransaction(ctx context.Context, tx TX) RESULT { + txResults := make(chan RESULT) + txResultsToReport := make(chan RESULT) primaryNodeWg := sync.WaitGroup{} - ctx, cancel := txSender.chStop.Ctx(ctx) - defer cancel() + if txSender.State() != "Started" { + return txSender.newResult(errors.New("TransactionSender not started")) + } + + txSenderCtx, cancel := txSender.chStop.NewCtx() + reportWg := sync.WaitGroup{} + defer func() { + go func() { + reportWg.Wait() + cancel() + }() + }() healthyNodesNum := 0 - err := txSender.multiNode.DoAll(ctx, func(ctx context.Context, rpc RPC, isSendOnly bool) { + err := txSender.multiNode.DoAll(txSenderCtx, func(ctx context.Context, rpc RPC, isSendOnly bool) { if isSendOnly { txSender.wg.Add(1) go func() { @@ -120,17 +127,17 @@ func (txSender *TransactionSender[TX, CHAIN_ID, RPC]) SendTransaction(ctx contex primaryNodeWg.Add(1) go func() { defer primaryNodeWg.Done() - result := txSender.broadcastTxAsync(ctx, rpc, tx) + r := txSender.broadcastTxAsync(ctx, rpc, tx) select { case <-ctx.Done(): return - case txResults <- result: + case txResults <- r: } select { case <-ctx.Done(): return - case txResultsToReport <- result: + case txResultsToReport <- r: } }() }) @@ -145,77 +152,80 @@ func (txSender *TransactionSender[TX, CHAIN_ID, RPC]) SendTransaction(ctx contex }() if err != nil { - return Retryable, err + return txSender.newResult(err) } txSender.wg.Add(1) - go txSender.reportSendTxAnomalies(tx, txResultsToReport) + reportWg.Add(1) + go func() { + defer reportWg.Done() + txSender.reportSendTxAnomalies(tx, txResultsToReport) + }() return txSender.collectTxResults(ctx, tx, healthyNodesNum, txResults) } -func (txSender *TransactionSender[TX, CHAIN_ID, RPC]) broadcastTxAsync(ctx context.Context, rpc RPC, tx TX) sendTxResult { - txErr := rpc.SendTransaction(ctx, tx) - txSender.lggr.Debugw("Node sent transaction", "tx", tx, "err", txErr) - resultCode := txSender.txErrorClassifier(tx, txErr) - if !slices.Contains(sendTxSuccessfulCodes, resultCode) { - txSender.lggr.Warnw("RPC returned error", "tx", tx, "err", txErr) +func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) broadcastTxAsync(ctx context.Context, rpc RPC, tx TX) RESULT { + result := rpc.SendTransaction(ctx, tx) + txSender.lggr.Debugw("Node sent transaction", "tx", tx, "err", result.TxError()) + if !slices.Contains(sendTxSuccessfulCodes, result.Code()) { + txSender.lggr.Warnw("RPC returned error", "tx", tx, "err", result.TxError()) } - return sendTxResult{Err: txErr, ResultCode: resultCode} + return result } -func (txSender *TransactionSender[TX, CHAIN_ID, RPC]) reportSendTxAnomalies(tx TX, txResults <-chan sendTxResult) { +func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) reportSendTxAnomalies(tx TX, txResults <-chan RESULT) { defer txSender.wg.Done() - resultsByCode := sendTxResults{} + resultsByCode := sendTxResults[RESULT]{} // txResults eventually will be closed for txResult := range txResults { - resultsByCode[txResult.ResultCode] = append(resultsByCode[txResult.ResultCode], txResult.Err) + resultsByCode[txResult.Code()] = append(resultsByCode[txResult.Code()], txResult) } - _, _, criticalErr := aggregateTxResults(resultsByCode) + _, criticalErr := aggregateTxResults[RESULT](resultsByCode) if criticalErr != nil { txSender.lggr.Criticalw("observed invariant violation on SendTransaction", "tx", tx, "resultsByCode", resultsByCode, "err", criticalErr) PromMultiNodeInvariantViolations.WithLabelValues(txSender.chainFamily, txSender.chainID.String(), criticalErr.Error()).Inc() } } -type sendTxResults map[SendTxReturnCode][]error +type sendTxResults[RESULT any] map[SendTxReturnCode][]RESULT -func aggregateTxResults(resultsByCode sendTxResults) (returnCode SendTxReturnCode, txResult error, err error) { - severeCode, severeErrors, hasSevereErrors := findFirstIn(resultsByCode, sendTxSevereErrors) - successCode, successResults, hasSuccess := findFirstIn(resultsByCode, sendTxSuccessfulCodes) +func aggregateTxResults[RESULT any](resultsByCode sendTxResults[RESULT]) (result RESULT, criticalErr error) { + severeErrors, hasSevereErrors := findFirstIn(resultsByCode, sendTxSevereErrors) + successResults, hasSuccess := findFirstIn(resultsByCode, sendTxSuccessfulCodes) if hasSuccess { // We assume that primary node would never report false positive txResult for a transaction. // Thus, if such case occurs it's probably due to misconfiguration or a bug and requires manual intervention. if hasSevereErrors { const errMsg = "found contradictions in nodes replies on SendTransaction: got success and severe error" // return success, since at least 1 node has accepted our broadcasted Tx, and thus it can now be included onchain - return successCode, successResults[0], errors.New(errMsg) + return successResults[0], errors.New(errMsg) } // other errors are temporary - we are safe to return success - return successCode, successResults[0], nil + return successResults[0], nil } if hasSevereErrors { - return severeCode, severeErrors[0], nil + return severeErrors[0], nil } // return temporary error - for code, result := range resultsByCode { - return code, result[0], nil + for _, r := range resultsByCode { + return r[0], nil } - err = fmt.Errorf("expected at least one response on SendTransaction") - return Retryable, err, err + criticalErr = fmt.Errorf("expected at least one response on SendTransaction") + return result, criticalErr } -func (txSender *TransactionSender[TX, CHAIN_ID, RPC]) collectTxResults(ctx context.Context, tx TX, healthyNodesNum int, txResults <-chan sendTxResult) (SendTxReturnCode, error) { +func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) collectTxResults(ctx context.Context, tx TX, healthyNodesNum int, txResults <-chan RESULT) RESULT { if healthyNodesNum == 0 { - return Retryable, ErroringNodeError + return txSender.newResult(ErroringNodeError) } requiredResults := int(math.Ceil(float64(healthyNodesNum) * sendTxQuorum)) - errorsByCode := sendTxResults{} + errorsByCode := sendTxResults[RESULT]{} var softTimeoutChan <-chan time.Time var resultsCount int loop: @@ -223,11 +233,11 @@ loop: select { case <-ctx.Done(): txSender.lggr.Debugw("Failed to collect of the results before context was done", "tx", tx, "errorsByCode", errorsByCode) - return Retryable, ctx.Err() - case result := <-txResults: - errorsByCode[result.ResultCode] = append(errorsByCode[result.ResultCode], result.Err) + return txSender.newResult(ctx.Err()) + case r := <-txResults: + errorsByCode[r.Code()] = append(errorsByCode[r.Code()], r) resultsCount++ - if slices.Contains(sendTxSuccessfulCodes, result.ResultCode) || resultsCount >= requiredResults { + if slices.Contains(sendTxSuccessfulCodes, r.Code()) || resultsCount >= requiredResults { break loop } case <-softTimeoutChan: @@ -245,17 +255,17 @@ loop: } // ignore critical error as it's reported in reportSendTxAnomalies - returnCode, result, _ := aggregateTxResults(errorsByCode) - return returnCode, result + result, _ := aggregateTxResults(errorsByCode) + return result } -func (txSender *TransactionSender[TX, CHAIN_ID, RPC]) Start(ctx context.Context) error { +func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) Start(ctx context.Context) error { return txSender.StartOnce("TransactionSender", func() error { return nil }) } -func (txSender *TransactionSender[TX, CHAIN_ID, RPC]) Close() error { +func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) Close() error { return txSender.StopOnce("TransactionSender", func() error { close(txSender.chStop) txSender.wg.Wait() @@ -264,13 +274,12 @@ func (txSender *TransactionSender[TX, CHAIN_ID, RPC]) Close() error { } // findFirstIn - returns the first existing key and value for the slice of keys -func findFirstIn[K comparable, V any](set map[K]V, keys []K) (K, V, bool) { +func findFirstIn[K comparable, V any](set map[K]V, keys []K) (V, bool) { for _, k := range keys { if v, ok := set[k]; ok { - return k, v, true + return v, true } } - var zeroK K var zeroV V - return zeroK, zeroV, false + return zeroV, false } diff --git a/pkg/solana/client/multinode_client.go b/pkg/solana/client/multinode_client.go index 086699cef..0a68b78f6 100644 --- a/pkg/solana/client/multinode_client.go +++ b/pkg/solana/client/multinode_client.go @@ -41,7 +41,7 @@ func (h *Head) IsValid() bool { } var _ mn.RPCClient[mn.StringID, *Head] = (*MultiNodeClient)(nil) -var _ mn.SendTxRPCClient[*solana.Transaction] = (*MultiNodeClient)(nil) +var _ mn.SendTxRPCClient[*solana.Transaction, *SendTxResult] = (*MultiNodeClient)(nil) type MultiNodeClient struct { Client @@ -300,8 +300,43 @@ func (m *MultiNodeClient) GetInterceptedChainInfo() (latest, highestUserObservat return m.latestChainInfo, m.highestUserObservations } -func (m *MultiNodeClient) SendTransaction(ctx context.Context, tx *solana.Transaction) error { - // TODO: Use Transaction Sender - _, err := m.SendTx(ctx, tx) - return err +type SendTxResult struct { + err error + txErr error + code mn.SendTxReturnCode + sig solana.Signature +} + +var _ mn.SendTxResult = (*SendTxResult)(nil) + +func NewSendTxResult(err error) *SendTxResult { + result := &SendTxResult{ + err: err, + txErr: err, + } + result.code = ClassifySendError(nil, err) + return result +} + +func (r *SendTxResult) Error() error { + return r.err +} + +func (r *SendTxResult) TxError() error { + return r.txErr +} + +func (r *SendTxResult) Code() mn.SendTxReturnCode { + return r.code +} + +func (r *SendTxResult) Signature() solana.Signature { + return r.sig +} + +func (m *MultiNodeClient) SendTransaction(ctx context.Context, tx *solana.Transaction) *SendTxResult { + var sendTxResult = &SendTxResult{} + sendTxResult.sig, sendTxResult.txErr = m.SendTx(ctx, tx) + sendTxResult.code = ClassifySendError(tx, sendTxResult.txErr) + return sendTxResult } diff --git a/pkg/solana/fees/block_history.go b/pkg/solana/fees/block_history.go index 5a44e9640..c4eb55b2e 100644 --- a/pkg/solana/fees/block_history.go +++ b/pkg/solana/fees/block_history.go @@ -7,11 +7,11 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" - "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/utils/mathutil" "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" ) var _ Estimator = &blockHistoryEstimator{} @@ -23,7 +23,7 @@ type blockHistoryEstimator struct { chStop services.StopChan done sync.WaitGroup - client *utils.LazyLoad[client.ReaderWriter] + client internal.Loader[client.ReaderWriter] cfg config.Config lgr logger.Logger @@ -34,7 +34,7 @@ type blockHistoryEstimator struct { // NewBlockHistoryEstimator creates a new fee estimator that parses historical fees from a fetched block // Note: getRecentPrioritizationFees is not used because it provides the lowest prioritization fee for an included tx in the block // which is not effective enough for increasing the chances of block inclusion -func NewBlockHistoryEstimator(c *utils.LazyLoad[client.ReaderWriter], cfg config.Config, lgr logger.Logger) (*blockHistoryEstimator, error) { +func NewBlockHistoryEstimator(c internal.Loader[client.ReaderWriter], cfg config.Config, lgr logger.Logger) (*blockHistoryEstimator, error) { if cfg.BlockHistorySize() < 1 { return nil, fmt.Errorf("invalid block history depth: %d", cfg.BlockHistorySize()) } diff --git a/pkg/solana/internal/loader.go b/pkg/solana/internal/loader.go new file mode 100644 index 000000000..ba0bc5ee4 --- /dev/null +++ b/pkg/solana/internal/loader.go @@ -0,0 +1,24 @@ +package internal + +type Loader[T any] interface { + Get() (T, error) + Reset() +} + +var _ Loader[any] = (*loader[any])(nil) + +type loader[T any] struct { + getClient func() (T, error) +} + +func (c *loader[T]) Get() (T, error) { + return c.getClient() +} + +func (c *loader[T]) Reset() { /* do nothing */ } + +func NewLoader[T any](getClient func() (T, error)) *loader[T] { + return &loader[T]{ + getClient: getClient, + } +} diff --git a/pkg/solana/internal/loader_test.go b/pkg/solana/internal/loader_test.go new file mode 100644 index 000000000..8d17a27ea --- /dev/null +++ b/pkg/solana/internal/loader_test.go @@ -0,0 +1,33 @@ +package internal + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type testLoader struct { + Loader[any] + callCount int +} + +func (t *testLoader) load() (any, error) { + t.callCount++ + return nil, nil +} + +func newTestLoader() *testLoader { + loader := testLoader{} + loader.Loader = NewLoader[any](loader.load) + return &loader +} + +func TestLoader(t *testing.T) { + t.Run("direct loading", func(t *testing.T) { + loader := newTestLoader() + _, _ = loader.Get() + _, _ = loader.Get() + _, _ = loader.Get() + require.Equal(t, 3, loader.callCount) + }) +} diff --git a/pkg/solana/monitor/balance.go b/pkg/solana/monitor/balance.go index a1ab59d69..10ea487db 100644 --- a/pkg/solana/monitor/balance.go +++ b/pkg/solana/monitor/balance.go @@ -9,6 +9,8 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/utils" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" ) // Config defines the monitor configuration. @@ -26,19 +28,19 @@ type BalanceClient interface { } // NewBalanceMonitor returns a balance monitoring services.Service which reports the SOL balance of all ks keys to prometheus. -func NewBalanceMonitor(chainID string, cfg Config, lggr logger.Logger, ks Keystore, newReader func() (BalanceClient, error)) services.Service { - return newBalanceMonitor(chainID, cfg, lggr, ks, newReader) +func NewBalanceMonitor(chainID string, cfg Config, lggr logger.Logger, ks Keystore, reader internal.Loader[BalanceClient]) services.Service { + return newBalanceMonitor(chainID, cfg, lggr, ks, reader) } -func newBalanceMonitor(chainID string, cfg Config, lggr logger.Logger, ks Keystore, newReader func() (BalanceClient, error)) *balanceMonitor { +func newBalanceMonitor(chainID string, cfg Config, lggr logger.Logger, ks Keystore, reader internal.Loader[BalanceClient]) *balanceMonitor { b := balanceMonitor{ - chainID: chainID, - cfg: cfg, - lggr: logger.Named(lggr, "BalanceMonitor"), - ks: ks, - newReader: newReader, - stop: make(chan struct{}), - done: make(chan struct{}), + chainID: chainID, + cfg: cfg, + lggr: logger.Named(lggr, "BalanceMonitor"), + ks: ks, + reader: reader, + stop: make(chan struct{}), + done: make(chan struct{}), } b.updateFn = b.updateProm return &b @@ -46,14 +48,13 @@ func newBalanceMonitor(chainID string, cfg Config, lggr logger.Logger, ks Keysto type balanceMonitor struct { services.StateMachine - chainID string - cfg Config - lggr logger.Logger - ks Keystore - newReader func() (BalanceClient, error) - updateFn func(acc solana.PublicKey, lamports uint64) // overridable for testing + chainID string + cfg Config + lggr logger.Logger + ks Keystore + updateFn func(acc solana.PublicKey, lamports uint64) // overridable for testing - reader BalanceClient + reader internal.Loader[BalanceClient] stop services.StopChan done chan struct{} @@ -99,18 +100,6 @@ func (b *balanceMonitor) monitor() { } } -// getReader returns the cached solanaClient.Reader, or creates a new one if nil. -func (b *balanceMonitor) getReader() (BalanceClient, error) { - if b.reader == nil { - var err error - b.reader, err = b.newReader() - if err != nil { - return nil, err - } - } - return b.reader, nil -} - func (b *balanceMonitor) updateBalances(ctx context.Context) { ctx, cancel := b.stop.Ctx(ctx) defer cancel() @@ -123,7 +112,7 @@ func (b *balanceMonitor) updateBalances(ctx context.Context) { if len(keys) == 0 { return } - reader, err := b.getReader() + reader, err := b.reader.Get() if err != nil { b.lggr.Errorw("Failed to get client", "err", err) return @@ -151,6 +140,6 @@ func (b *balanceMonitor) updateBalances(ctx context.Context) { } if !gotSomeBals { // Try a new client next time. - b.reader = nil + b.reader.Reset() } } diff --git a/pkg/solana/monitor/balance_test.go b/pkg/solana/monitor/balance_test.go index 9321d6a52..a6cc231c9 100644 --- a/pkg/solana/monitor/balance_test.go +++ b/pkg/solana/monitor/balance_test.go @@ -62,7 +62,10 @@ func TestBalanceMonitor(t *testing.T) { close(done) } } - b.reader = client + + b.reader = internal.NewLoader[BalanceClient](func() (BalanceClient, error) { + return client, nil + }) servicetest.Run(t, b) select { diff --git a/pkg/solana/txm/txm.go b/pkg/solana/txm/txm.go index 45eb289fc..7cd09cf5e 100644 --- a/pkg/solana/txm/txm.go +++ b/pkg/solana/txm/txm.go @@ -23,6 +23,7 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" ) const ( @@ -54,8 +55,11 @@ type Txm struct { cfg config.Config txs PendingTxContext ks SimpleKeystore - client *utils.LazyLoad[client.ReaderWriter] + client internal.Loader[client.ReaderWriter] fee fees.Estimator + // sendTx is an override for sending transactions rather than using a single client + // Enabling MultiNode uses this function to send transactions to all RPCs + sendTx func(ctx context.Context, tx *solanaGo.Transaction) (solanaGo.Signature, error) } type TxConfig struct { @@ -79,7 +83,20 @@ type pendingTx struct { } // NewTxm creates a txm. Uses simulation so should only be used to send txes to trusted contracts i.e. OCR. -func NewTxm(chainID string, tc func() (client.ReaderWriter, error), cfg config.Config, ks SimpleKeystore, lggr logger.Logger) *Txm { +func NewTxm(chainID string, client internal.Loader[client.ReaderWriter], + sendTx func(ctx context.Context, tx *solanaGo.Transaction) (solanaGo.Signature, error), + cfg config.Config, ks SimpleKeystore, lggr logger.Logger) *Txm { + if sendTx == nil { + // default sendTx using a single RPC + sendTx = func(ctx context.Context, tx *solanaGo.Transaction) (solanaGo.Signature, error) { + c, err := client.Get() + if err != nil { + return solanaGo.Signature{}, err + } + return c.SendTx(ctx, tx) + } + } + return &Txm{ lggr: logger.Named(lggr, "Txm"), chSend: make(chan pendingTx, MaxQueueLen), // queue can support 1000 pending txs @@ -88,7 +105,8 @@ func NewTxm(chainID string, tc func() (client.ReaderWriter, error), cfg config.C cfg: cfg, txs: newPendingTxContextWithProm(chainID), ks: ks, - client: utils.NewLazyLoad(tc), + client: client, + sendTx: sendTx, } } @@ -157,12 +175,6 @@ func (txm *Txm) run() { } func (txm *Txm) sendWithRetry(ctx context.Context, baseTx solanaGo.Transaction, txcfg TxConfig) (solanaGo.Transaction, uuid.UUID, solanaGo.Signature, error) { - // fetch client - client, clientErr := txm.client.Get() - if clientErr != nil { - return solanaGo.Transaction{}, uuid.Nil, solanaGo.Signature{}, fmt.Errorf("failed to get client in soltxm.sendWithRetry: %w", clientErr) - } - // get key // fee payer account is index 0 account // https://github.com/gagliardetto/solana-go/blob/main/transaction.go#L252 @@ -222,7 +234,7 @@ func (txm *Txm) sendWithRetry(ctx context.Context, baseTx solanaGo.Transaction, ctx, cancel := context.WithTimeout(ctx, txcfg.Timeout) // send initial tx (do not retry and exit early if fails) - sig, initSendErr := client.SendTx(ctx, &initTx) + sig, initSendErr := txm.sendTx(ctx, &initTx) if initSendErr != nil { cancel() // cancel context when exiting early txm.txs.OnError(sig, TxFailReject) // increment failed metric @@ -293,7 +305,7 @@ func (txm *Txm) sendWithRetry(ctx context.Context, baseTx solanaGo.Transaction, go func(bump bool, count int, retryTx solanaGo.Transaction) { defer wg.Done() - retrySig, retrySendErr := client.SendTx(ctx, &retryTx) + retrySig, retrySendErr := txm.sendTx(ctx, &retryTx) // this could occur if endpoint goes down or if ctx cancelled if retrySendErr != nil { if strings.Contains(retrySendErr.Error(), "context canceled") || strings.Contains(retrySendErr.Error(), "context deadline exceeded") { diff --git a/pkg/solana/txm/txm_internal_test.go b/pkg/solana/txm/txm_internal_test.go index ef46ce785..802dc93b2 100644 --- a/pkg/solana/txm/txm_internal_test.go +++ b/pkg/solana/txm/txm_internal_test.go @@ -27,6 +27,7 @@ import ( relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" ) @@ -111,9 +112,8 @@ func TestTxm(t *testing.T) { mkey := keyMocks.NewSimpleKeystore(t) mkey.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil) - txm := NewTxm(id, func() (client.ReaderWriter, error) { - return mc, nil - }, cfg, mkey, lggr) + loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) + txm := NewTxm(id, loader, nil, cfg, mkey, lggr) require.NoError(t, txm.Start(ctx)) // tracking prom metrics @@ -726,9 +726,8 @@ func TestTxm_Enqueue(t *testing.T) { ) require.NoError(t, err) - txm := NewTxm("enqueue_test", func() (client.ReaderWriter, error) { - return mc, nil - }, cfg, mkey, lggr) + loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) + txm := NewTxm("enqueue_test", loader, nil, cfg, mkey, lggr) require.ErrorContains(t, txm.Enqueue(ctx, "txmUnstarted", &solana.Transaction{}), "not started") require.NoError(t, txm.Start(ctx)) diff --git a/pkg/solana/txm/txm_load_test.go b/pkg/solana/txm/txm_load_test.go index ff7831b02..744610e1f 100644 --- a/pkg/solana/txm/txm_load_test.go +++ b/pkg/solana/txm/txm_load_test.go @@ -22,6 +22,7 @@ import ( relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" ) @@ -70,10 +71,8 @@ func TestTxm_Integration(t *testing.T) { cfg.Chain.FeeEstimatorMode = &estimator client, err := solanaClient.NewClient(url, cfg, 2*time.Second, lggr) require.NoError(t, err) - getClient := func() (solanaClient.ReaderWriter, error) { - return client, nil - } - txm := txm.NewTxm("localnet", getClient, cfg, mkey, lggr) + loader := utils.NewLazyLoad(func() (solanaClient.ReaderWriter, error) { return client, nil }) + txm := txm.NewTxm("localnet", loader, nil, cfg, mkey, lggr) // track initial balance initBal, err := client.Balance(ctx, pubKey) diff --git a/pkg/solana/txm/txm_race_test.go b/pkg/solana/txm/txm_race_test.go index 480ad35c8..81f2c15f6 100644 --- a/pkg/solana/txm/txm_race_test.go +++ b/pkg/solana/txm/txm_race_test.go @@ -12,6 +12,7 @@ import ( "go.uber.org/zap/zapcore" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" solanaClient "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" @@ -61,12 +62,11 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { tx := NewTestTx() testRunner := func(t *testing.T, client solanaClient.ReaderWriter) { - getClient := func() (solanaClient.ReaderWriter, error) { - return client, nil - } - // build minimal txm - txm := NewTxm("retry_race", getClient, cfg, ks, lggr) + loader := utils.NewLazyLoad(func() (solanaClient.ReaderWriter, error) { + return client, nil + }) + txm := NewTxm("retry_race", loader, nil, cfg, ks, lggr) txm.fee = fee _, _, _, err := txm.sendWithRetry( diff --git a/pkg/solana/txm/txm_unit_test.go b/pkg/solana/txm/txm_unit_test.go index e1a7aaf1b..bb2108f4e 100644 --- a/pkg/solana/txm/txm_unit_test.go +++ b/pkg/solana/txm/txm_unit_test.go @@ -18,6 +18,7 @@ import ( keyMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils" bigmath "github.com/smartcontractkit/chainlink-common/pkg/utils/big_math" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" ) @@ -45,10 +46,8 @@ func TestTxm_EstimateComputeUnitLimit(t *testing.T) { cfg := config.NewDefault() client := clientmocks.NewReaderWriter(t) require.NoError(t, err) - getClient := func() (solanaClient.ReaderWriter, error) { - return client, nil - } - txm := solanatxm.NewTxm("localnet", getClient, cfg, mkey, lggr) + loader := utils.NewLazyLoad(func() (solanaClient.ReaderWriter, error) { return client, nil }) + txm := solanatxm.NewTxm("localnet", loader, nil, cfg, mkey, lggr) t.Run("successfully sets estimated compute unit limit", func(t *testing.T) { usedCompute := uint64(100)