diff --git a/pkg/solana/txm/pendingtx.go b/pkg/solana/txm/pendingtx.go index fdaa661e6..1e2c66728 100644 --- a/pkg/solana/txm/pendingtx.go +++ b/pkg/solana/txm/pendingtx.go @@ -30,6 +30,8 @@ type PendingTxContext interface { ListAll() []solana.Signature // ListAllExpiredBroadcastedTxs returns all the txes that are in broadcasted state and have expired for given slot height compared against their lastValidBlockHeight. ListAllExpiredBroadcastedTxs(currHeight uint64) []pendingTx + // ListAllBroadcastedTxs returns all the txes that are in broadcasted state. + ListAllBroadcastedTxs() []pendingTx // Expired returns whether or not confirmation timeout amount of time has passed since creation Expired(sig solana.Signature, confirmationTimeout time.Duration) bool // OnProcessed marks transactions as Processed @@ -215,7 +217,7 @@ func (c *pendingTxContext) ListAll() []solana.Signature { return maps.Keys(c.sigToID) } -// ListAllExpiredBroadcastedTxs returns all the expired broadcasted that are in broadcasted state and have expired for given slot height. +// ListAllExpiredBroadcastedTxs returns all the txes that are in broadcasted state and have expired for given slot height compared against their lastValidBlockHeight. func (c *pendingTxContext) ListAllExpiredBroadcastedTxs(currHeight uint64) []pendingTx { c.lock.RLock() defer c.lock.RUnlock() @@ -228,6 +230,19 @@ func (c *pendingTxContext) ListAllExpiredBroadcastedTxs(currHeight uint64) []pen return broadcastedTxes } +// ListAllBroadcastedTxs returns all the txes that are in broadcasted state. +func (c *pendingTxContext) ListAllBroadcastedTxs() []pendingTx { + c.lock.RLock() + defer c.lock.RUnlock() + broadcastedTxes := make([]pendingTx, 0, len(c.broadcastedProcessedTxs)) // worst case, all of them + for _, tx := range c.broadcastedProcessedTxs { + if tx.state == Broadcasted { + broadcastedTxes = append(broadcastedTxes, tx) + } + } + return broadcastedTxes +} + // Expired returns if the timeout for trying to confirm a signature has been reached func (c *pendingTxContext) Expired(sig solana.Signature, confirmationTimeout time.Duration) bool { c.lock.RLock() @@ -638,3 +653,7 @@ func (c *pendingTxContextWithProm) TrimFinalizedErroredTxs() { func (c *pendingTxContextWithProm) GetTxRebroadcastCount(id string) (int, error) { return c.pendingTx.GetTxRebroadcastCount(id) } + +func (c *pendingTxContextWithProm) ListAllBroadcastedTxs() []pendingTx { + return c.pendingTx.ListAllBroadcastedTxs() +} diff --git a/pkg/solana/txm/txm.go b/pkg/solana/txm/txm.go index ea82593f3..4052cfe3e 100644 --- a/pkg/solana/txm/txm.go +++ b/pkg/solana/txm/txm.go @@ -392,6 +392,12 @@ func (txm *Txm) confirm() { break } txm.processConfirmations(ctx, client) + + // In case all txes where confirmed and there's nothing to rebroadcast. + // This check saves making 2 RPC calls (slot height + blockhash) when there's nothing to process. + if len(txm.txs.ListAllBroadcastedTxs()) == 0 { + break + } if txm.cfg.TxExpirationRebroadcast() { txm.rebroadcastExpiredTxs(ctx, client) } diff --git a/pkg/solana/txm/txm_internal_test.go b/pkg/solana/txm/txm_internal_test.go index 6fb044471..b872e9871 100644 --- a/pkg/solana/txm/txm_internal_test.go +++ b/pkg/solana/txm/txm_internal_test.go @@ -1093,94 +1093,351 @@ func addSigAndLimitToTx(t *testing.T, keystore SimpleKeystore, pubkey solana.Pub func TestTxm_ExpirationRebroadcast(t *testing.T) { t.Parallel() - // Set up configurations estimator := "fixed" id := "mocknet-" + estimator + "-" + uuid.NewString() - t.Logf("Starting new iteration: %s", id) - ctx := tests.Context(t) - lggr := logger.Test(t) cfg := config.NewDefault() cfg.Chain.FeeEstimatorMode = &estimator - txExpirationRebroadcast := true - cfg.Chain.TxExpirationRebroadcast = &txExpirationRebroadcast // enable expiration rebroadcast cfg.Chain.TxConfirmTimeout = relayconfig.MustNewDuration(5 * time.Second) cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(10 * time.Second) // Enable retention to keep transactions after finality and be able to check. + lggr := logger.Test(t) + ctx := tests.Context(t) - mc := mocks.NewReaderWriter(t) + // Helper function to set up common test environment + setupTxmTest := func( + txExpirationRebroadcast bool, + latestBlockhashFunc func() (*rpc.GetLatestBlockhashResult, error), + slotHeightFunc func() (uint64, error), + sendTxFunc func() (solana.Signature, error), + statuses map[solana.Signature]func() *rpc.SignatureStatusesResult, + ) (*Txm, *mocks.ReaderWriter, *keyMocks.SimpleKeystore) { + cfg.Chain.TxExpirationRebroadcast = &txExpirationRebroadcast + + mc := mocks.NewReaderWriter(t) + if latestBlockhashFunc != nil { + mc.On("LatestBlockhash", mock.Anything).Return( + func(_ context.Context) (*rpc.GetLatestBlockhashResult, error) { + return latestBlockhashFunc() + }, + ).Maybe() + } + if slotHeightFunc != nil { + mc.On("SlotHeight", mock.Anything).Return( + func(_ context.Context) (uint64, error) { + return slotHeightFunc() + }, + ).Maybe() + } + if sendTxFunc != nil { + mc.On("SendTx", mock.Anything, mock.Anything).Return( + func(_ context.Context, _ *solana.Transaction) (solana.Signature, error) { + return sendTxFunc() + }, + ).Maybe() + } - // First blockhash is set on sender. Second blockhash (the one returned here) is set on txExpirationRebroadcast before rebroadcasting. - // The first one will be invalid as it's initialized in 0 by default. This call will get a valid one greater than slotHeight and go through. - mc.On("LatestBlockhash", mock.Anything).Return(func(_ context.Context) (*rpc.GetLatestBlockhashResult, error) { - return &rpc.GetLatestBlockhashResult{ - Value: &rpc.LatestBlockhashResult{ - LastValidBlockHeight: uint64(2000), - }, - }, nil - }).Maybe() - - // Set up SlotHeight to return a value greater than 0 so the initial LastValidBlockHeight is invalid. - mc.On("SlotHeight", mock.Anything).Return(uint64(1500), nil).Maybe() - mkey := keyMocks.NewSimpleKeystore(t) - mkey.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil) - loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) - txm := NewTxm(id, loader, nil, cfg, mkey, lggr) - require.NoError(t, txm.Start(ctx)) - t.Cleanup(func() { require.NoError(t, txm.Close()) }) - sig1 := randomSignature(t) - mc.On("SendTx", mock.Anything, mock.Anything).Return(sig1, nil).Maybe() - mc.On("SimulateTx", mock.Anything, mock.Anything, mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe() - statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} - mc.On("SignatureStatuses", mock.Anything, mock.AnythingOfType("[]solana.Signature")).Return( - func(_ context.Context, sigs []solana.Signature) (out []*rpc.SignatureStatusesResult) { - for i := range sigs { - get, exists := statuses[sigs[i]] - if !exists { - out = append(out, nil) - continue + mc.On("SimulateTx", mock.Anything, mock.Anything, mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil) + if statuses != nil { + mc.On("SignatureStatuses", mock.Anything, mock.AnythingOfType("[]solana.Signature")).Return( + func(_ context.Context, sigs []solana.Signature) ([]*rpc.SignatureStatusesResult, error) { + var out []*rpc.SignatureStatusesResult + for _, sig := range sigs { + getStatus, exists := statuses[sig] + if !exists { + out = append(out, nil) + } else { + out = append(out, getStatus()) + } + } + return out, nil + }, + ).Maybe() + } + + mkey := keyMocks.NewSimpleKeystore(t) + mkey.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil) + + loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) + txm := NewTxm(id, loader, nil, cfg, mkey, lggr) + require.NoError(t, txm.Start(ctx)) + t.Cleanup(func() { require.NoError(t, txm.Close()) }) + + return txm, mc, mkey + } + + t.Run("WithRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := true + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + + // Mock SlotHeight to return a value greater than 0 + slotHeightFunc := func() (uint64, error) { + return uint64(1500), nil + } + + // Mock LatestBlockhash to return a valid blockhash greater than slotHeight + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(2000), + }, + }, nil + } + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + nowTs := time.Now() + sigStatusCallCount := 0 + var wg sync.WaitGroup + wg.Add(1) + statuses[sig1] = func() *rpc.SignatureStatusesResult { + // First transaction should be rebroadcasted. + if time.Since(nowTs) < cfg.TxConfirmTimeout()-2*time.Second { + return nil + } else { + // Second transaction should reach finalization. + sigStatusCallCount++ + if sigStatusCallCount == 1 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusProcessed, + } + } else if sigStatusCallCount == 2 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusConfirmed, + } + } else { + wg.Done() + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusFinalized, + } } - out = append(out, get()) } - return out - }, nil, - ) + } - nowTs := time.Now() - sigStatusCallCount := 0 - var wg sync.WaitGroup - wg.Add(1) - statuses[sig1] = func() *rpc.SignatureStatusesResult { - // first transaction should be rebroadcasted. - if time.Since(nowTs) < cfg.TxConfirmTimeout()-2*time.Second { - return nil - } else { - // second transaction should reach finalization. - sigStatusCallCount++ - if sigStatusCallCount == 1 { - return &rpc.SignatureStatusesResult{ - ConfirmationStatus: rpc.ConfirmationStatusProcessed, - } - } else if sigStatusCallCount == 2 { - return &rpc.SignatureStatusesResult{ - ConfirmationStatus: rpc.ConfirmationStatusConfirmed, - } + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, slotHeightFunc, sendTxFunc, statuses) + + tx, _ := getTx(t, 0, mkey) + txID := "test-rebroadcast" + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID)) // Will create an expired transaction as lastValidBlockHeight is 0 by default. + wg.Wait() + time.Sleep(2 * time.Second) // Sleep to allow for rebroadcasting + + // Check that transaction for txID has been finalized and rebroadcasted + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + rebroadcastCount, err := txm.txs.GetTxRebroadcastCount(txID) + require.NoError(t, err) + require.Equal(t, 1, rebroadcastCount) + }) + + t.Run("WithoutRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := false + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + nowTs := time.Now() + var wg sync.WaitGroup + wg.Add(1) + statuses[sig1] = func() *rpc.SignatureStatusesResult { + // Transaction remains unconfirmed and should not be rebroadcasted. + if time.Since(nowTs) < cfg.TxConfirmTimeout() { + return nil } else { wg.Done() - return &rpc.SignatureStatusesResult{ - ConfirmationStatus: rpc.ConfirmationStatusFinalized, + return nil + } + } + // No LatestBlockhash nor slotHeight needed because there's no rebroadcast. + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, nil, nil, sendTxFunc, statuses) + + tx, _ := getTx(t, 5, mkey) + txID := "test-no-rebroadcast" + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID)) + wg.Wait() + time.Sleep(2 * time.Second) // Sleep to ensure no rebroadcast + + // Check that transaction for txID has not been finalized and has not been rebroadcasted + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Failed, status) + rebroadcastCount, err := txm.txs.GetTxRebroadcastCount(txID) + require.NoError(t, err) + require.Equal(t, 0, rebroadcastCount) + }) + + t.Run("WithMultipleRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := true + expectedRebroadcastsCount := 3 + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + + // Mock SlotHeight to return a value greater than 0 + slotHeightFunc := func() (uint64, error) { + return uint64(1500), nil + } + + // Mock LatestBlockhash to return a invalid blockhash first 2 attempts and a valid blockhash third time + // the third one is valid because it is greater than the slotHeight + callCount := 1 + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + defer func() { callCount++ }() + if callCount < expectedRebroadcastsCount { + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(1000), + }, + }, nil + } + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(2000), + }, + }, nil + } + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + nowTs := time.Now() + sigStatusCallCount := 0 + var wg sync.WaitGroup + wg.Add(1) + statuses[sig1] = func() *rpc.SignatureStatusesResult { + // transaction should be rebroadcasted multiple times. + if time.Since(nowTs) < cfg.TxConfirmTimeout()-2*time.Second { + return nil + } else { + // Second transaction should reach finalization. + sigStatusCallCount++ + if sigStatusCallCount == 1 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusProcessed, + } + } else if sigStatusCallCount == 2 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusConfirmed, + } + } else { + wg.Done() + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusFinalized, + } } } } - } - tx, _ := getTx(t, 0, mkey) - txID := "test" - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID)) // Will create a expired transaction as lastValidBlockHeight is 0 by default. - wg.Wait() - time.Sleep(2 * time.Second) // Sleep to allow for rebroadcasting - // Check that transaction for txID has been finalized and rebroadcasted - status, err := txm.GetTransactionStatus(ctx, txID) - require.NoError(t, err) - require.Equal(t, types.Finalized, status) - rebroadcastCount, err := txm.txs.GetTxRebroadcastCount(txID) - require.NoError(t, err) - require.Equal(t, 1, rebroadcastCount) + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, slotHeightFunc, sendTxFunc, statuses) + tx, _ := getTx(t, 0, mkey) + txID := "test-rebroadcast" + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID)) // Will create an expired transaction as lastValidBlockHeight is 0 by default. + wg.Wait() + time.Sleep(2 * time.Second) // Sleep to allow for rebroadcasting + + // Check that transaction for txID has been finalized and rebroadcasted + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + rebroadcastCount, err := txm.txs.GetTxRebroadcastCount(txID) + require.NoError(t, err) + require.Equal(t, expectedRebroadcastsCount, rebroadcastCount) + }) + + t.Run("ConfirmedBeforeRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := true + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + var wg sync.WaitGroup + wg.Add(1) + count := 0 + statuses[sig1] = func() *rpc.SignatureStatusesResult { + defer func() { count++ }() + + out := &rpc.SignatureStatusesResult{} + if count == 1 { + out.ConfirmationStatus = rpc.ConfirmationStatusConfirmed + return out + } + if count == 2 { + out.ConfirmationStatus = rpc.ConfirmationStatusFinalized + wg.Done() + return out + } + out.ConfirmationStatus = rpc.ConfirmationStatusProcessed + return out + } + + // No LatestBlockhash nor SlotHeight needed + // Our check will detect there are no rebroadcasts to process saving 2 rpc calls and ending loop. + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, nil, nil, sendTxFunc, statuses) + tx, _ := getTx(t, 0, mkey) + txID := "test-confirmed-before-rebroadcast" + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID)) + wg.Wait() + time.Sleep(1 * time.Second) // Allow for processing + + // Check that transaction has been finalized without rebroadcast + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + rebroadcastCount, err := txm.txs.GetTxRebroadcastCount(txID) + require.NoError(t, err) + require.Equal(t, 0, rebroadcastCount) + }) + + t.Run("RebroadcastWithError", func(t *testing.T) { + txExpirationRebroadcast := true + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + + // Mock SlotHeight to return a value greater than 0 + slotHeightFunc := func() (uint64, error) { + return uint64(1500), nil + } + + // Mock LatestBlockhash to return a valid blockhash greater than slotHeight + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(2000), + }, + }, nil + } + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + statuses[sig1] = func() *rpc.SignatureStatusesResult { + // Transaction remains unconfirmed + return nil + } + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, slotHeightFunc, sendTxFunc, statuses) + tx, _ := getTx(t, 0, mkey) + txID := "test-rebroadcast-error" + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID)) + time.Sleep(2 * time.Second) // Allow for processing + + // TODO: Add check that transaction status is failed due to rebroadcast error when prebroadcast is implemented and we have an error in sendWithRetry + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Pending, status) // TODO: Change to Failed when prebroadcast error is implemented + rebroadcastCount, err := txm.txs.GetTxRebroadcastCount(txID) + require.NoError(t, err) + require.Equal(t, 1, rebroadcastCount) // Attempted to rebroadcast 1 time but encountered error + time.Sleep(2 * time.Second) // Allow for processing + rebroadcastCount, err = txm.txs.GetTxRebroadcastCount(txID) // rebroadcast should still be 1. We should not be rebroadcasting. + require.NoError(t, err) + require.Equal(t, 1, rebroadcastCount) + }) }