Skip to content

Commit

Permalink
rebroadcast with new blockhash + add integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Farber98 committed Dec 23, 2024
1 parent dff36b0 commit dbf4e41
Show file tree
Hide file tree
Showing 4 changed files with 335 additions and 61 deletions.
37 changes: 30 additions & 7 deletions pkg/solana/txm/pendingtx.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ type PendingTxContext interface {
// If the highest aggregated state is less than the current state, a reorg has occurred and we need to handle it.
TxHasReorg(id string) bool
// OnReorg resets the transaction state to Broadcasted for the given signature and returns the pendingTx for retrying.
OnReorg(sig solana.Signature, id string) (pendingTx, error)
OnReorg(sig solana.Signature, id string) error
// GetPendingTx returns the pendingTx for the given ID if it exists
GetPendingTx(id string) (pendingTx, error)
}

// finishedTx is used to store info required to track transactions to finality or error
Expand Down Expand Up @@ -588,7 +590,7 @@ func (c *pendingTxContext) GetSignatureInfo(sig solana.Signature) (txInfo, error
return info, nil
}

func (c *pendingTxContext) OnReorg(sig solana.Signature, id string) (pendingTx, error) {
func (c *pendingTxContext) OnReorg(sig solana.Signature, id string) error {
err := c.withReadLock(func() error {
// Check if the transaction is still in a non finalized/errored state
var broadcastedExists, confirmedExists bool
Expand All @@ -601,7 +603,7 @@ func (c *pendingTxContext) OnReorg(sig solana.Signature, id string) (pendingTx,
})
if err != nil {
// If transaction or sig are not found, return
return pendingTx{}, err
return err
}

var pTx pendingTx
Expand Down Expand Up @@ -638,11 +640,10 @@ func (c *pendingTxContext) OnReorg(sig solana.Signature, id string) (pendingTx,
})
if err != nil {
// If transaction was not found
return pendingTx{}, err
return err
}

// Returns the transaction in case we need to rebroadcast and restart the retry/bumping cycle
return pTx, nil
return nil
}

// TxHasReorg determines whether a reorg has occurred for a given tx.
Expand Down Expand Up @@ -716,6 +717,24 @@ func (c *pendingTxContext) UpdateSignatureStatus(sig solana.Signature, status Tx
return nil
}

func (c *pendingTxContext) GetPendingTx(id string) (pendingTx, error) {
c.lock.RLock()
defer c.lock.RUnlock()
var tx, tempTx pendingTx
var broadcastedExists, confirmedExists bool
if tempTx, broadcastedExists = c.broadcastedProcessedTxs[id]; broadcastedExists {
tx = tempTx
}
if tempTx, confirmedExists = c.confirmedTxs[id]; confirmedExists {
tx = tempTx
}

if !broadcastedExists && !confirmedExists {
return pendingTx{}, ErrTransactionNotFound
}
return tx, nil
}

func (c *pendingTxContext) withReadLock(fn func() error) error {
c.lock.RLock()
defer c.lock.RUnlock()
Expand Down Expand Up @@ -847,7 +866,7 @@ func (c *pendingTxContextWithProm) GetSignatureInfo(sig solana.Signature) (txInf
return c.pendingTx.GetSignatureInfo(sig)
}

func (c *pendingTxContextWithProm) OnReorg(sig solana.Signature, id string) (pendingTx, error) {
func (c *pendingTxContextWithProm) OnReorg(sig solana.Signature, id string) error {
return c.pendingTx.OnReorg(sig, id)
}

Expand All @@ -858,3 +877,7 @@ func (c *pendingTxContextWithProm) TxHasReorg(id string) bool {
func (c *pendingTxContextWithProm) UpdateSignatureStatus(sig solana.Signature, status TxState) error {
return c.pendingTx.UpdateSignatureStatus(sig, status)
}

func (c *pendingTxContextWithProm) GetPendingTx(id string) (pendingTx, error) {
return c.pendingTx.GetPendingTx(id)
}
44 changes: 39 additions & 5 deletions pkg/solana/txm/pendingtx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1468,7 +1468,8 @@ func TestPendingTxContext_OnReorg(t *testing.T) {
require.NoError(t, err)

// Call OnReorg
pTx, err := txs.OnReorg(sig, txID)
require.NoError(t, txs.OnReorg(sig, txID))
pTx, err := txs.GetPendingTx(txID)
require.NoError(t, err)
require.Equal(t, Broadcasted, pTx.state)

Expand All @@ -1487,7 +1488,8 @@ func TestPendingTxContext_OnReorg(t *testing.T) {
require.NoError(t, err)

// Call OnReorg
pTx, err := txs.OnReorg(sig, txID)
require.NoError(t, txs.OnReorg(sig, txID))
pTx, err := txs.GetPendingTx(txID)
require.NoError(t, err)
require.Equal(t, Broadcasted, pTx.state)

Expand All @@ -1512,7 +1514,7 @@ func TestPendingTxContext_OnReorg(t *testing.T) {
require.NoError(t, err)

// Call OnReorg
_, err = txs.OnReorg(sig, txID)
err = txs.OnReorg(sig, txID)
require.Error(t, err)
require.Equal(t, ErrTransactionNotFound, err)
})
Expand All @@ -1524,13 +1526,13 @@ func TestPendingTxContext_OnReorg(t *testing.T) {
require.NoError(t, err)

// Call OnReorg
_, err = txs.OnReorg(sig, txID)
err = txs.OnReorg(sig, txID)
require.Error(t, err)
require.Equal(t, ErrTransactionNotFound, err)
})

t.Run("fail to reset non-existent transaction", func(t *testing.T) {
_, err := txs.OnReorg(randomSignature(t), "non-existent")
err := txs.OnReorg(randomSignature(t), "non-existent")
require.Error(t, err)
require.Equal(t, ErrTransactionNotFound, err)
})
Expand Down Expand Up @@ -1623,3 +1625,35 @@ func TestPendingTxContext_TxHasReorg(t *testing.T) {
require.True(t, hasReorg, "expected reorg when all signatures are < transaction state")
})
}

func TestPendingTxContext_GetPendingTx(t *testing.T) {
t.Parallel()
txs := newPendingTxContext()

t.Run("successfully retrieve broadcasted transaction", func(t *testing.T) {
txID, sig := createTxAndAddSig(t, txs)
_, err := txs.OnProcessed(sig)
require.NoError(t, err)

tx, err := txs.GetPendingTx(txID)
require.NoError(t, err)
require.Equal(t, txID, tx.id)
})

t.Run("successfully retrieve confirmed transaction", func(t *testing.T) {
txID, sig := createTxAndAddSig(t, txs)
_, err := txs.OnProcessed(sig)
require.NoError(t, err)
_, err = txs.OnConfirmed(sig)
require.NoError(t, err)

tx, err := txs.GetPendingTx(txID)
require.NoError(t, err)
require.Equal(t, txID, tx.id)
})

t.Run("fail to retrieve non-existent transaction", func(t *testing.T) {
_, err := txs.GetPendingTx("non-existent-id")
require.ErrorIs(t, err, ErrTransactionNotFound)
})
}
117 changes: 74 additions & 43 deletions pkg/solana/txm/txm.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"
"time"

"github.com/gagliardetto/solana-go"
solanaGo "github.com/gagliardetto/solana-go"
"github.com/gagliardetto/solana-go/rpc"
"github.com/google/uuid"
Expand Down Expand Up @@ -445,7 +446,7 @@ func (txm *Txm) processConfirmations(ctx context.Context, client client.ReaderWr
if status == nil {
// sig not found could mean invalid tx or not picked up yet, keep polling
// we also need to check if a potential re-org has occurred for this sig and handle it
txm.handleReorg(ctx, sig, status)
txm.handleReorg(ctx, client, sig, status)
txm.handleNotFoundSignatureStatus(sig)
continue
}
Expand All @@ -460,7 +461,7 @@ func (txm *Txm) processConfirmations(ctx context.Context, client client.ReaderWr
case rpc.ConfirmationStatusProcessed:
// if signature is processed, keep polling for confirmed or finalized status
// we also need to check if a potential re-org has occurred for this sig and handle it
txm.handleReorg(ctx, sig, status)
txm.handleReorg(ctx, client, sig, status)
txm.handleProcessedSignatureStatus(sig)
case rpc.ConfirmationStatusConfirmed:
// if signature is confirmed, keep polling for finalized status
Expand Down Expand Up @@ -515,15 +516,15 @@ func (txm *Txm) handleErrorSignatureStatus(sig solanaGo.Signature, status *rpc.S

// handleReorg detects and manages transaction state regressions (re-orgs) for a given signature.
//
// A re-org occurs when the blockchain state of a signature regresses to:
// A re-org occurs when the on-chain state of a signature regresses:
// - Confirmed -> Processed || Not Found
// - Processed -> Not Found
//
// This function determines if the signature’s state regression impacts the overall transaction state and, if so, takes appropriate action:
// - For regressions from "Confirmed", our in memory layer is updated, the tx is rebroadcasted, and the retry/bumping cycle is restarted.
// - For regressions from "Confirmed", our in memory layer is updated and the tx is rebroadcasted with a new hash restarting the retry/bumping cycle.
// - For regressions from "Processed", the existing retry/bumping cycle is still running, so no immediate action is needed. We only update our in-memory state to Broadcasted.
// Future rebroadcasts, will be handled by the TxExpirationRebroadcast logic (if enabled) when the transaction expires.
func (txm *Txm) handleReorg(ctx context.Context, sig solanaGo.Signature, status *rpc.SignatureStatusesResult) {
func (txm *Txm) handleReorg(ctx context.Context, client client.ReaderWriter, sig solanaGo.Signature, status *rpc.SignatureStatusesResult) {
// Retrieve the last known status of the transaction associated with this signature from the in-memory layer.
txInfo, err := txm.txs.GetSignatureInfo(sig)
if err != nil {
Expand Down Expand Up @@ -552,26 +553,40 @@ func (txm *Txm) handleReorg(ctx context.Context, sig solanaGo.Signature, status

txm.lggr.Warnw("re-org detected for transaction", "txID", txInfo.id, "signature", sig, "previousStatus", txInfo.state, "currentStatus", currentTxState)
// update the in-memory state and return the transaction associated with the signature for rebroadcasting and restarting retry/bump cycle if needed
pTx, err := txm.txs.OnReorg(sig, txInfo.id)
err := txm.txs.OnReorg(sig, txInfo.id)
if err != nil {
txm.lggr.Errorw("failed to handle re-org", "signature", sig, "id", pTx.id, "error", err)
txm.lggr.Errorw("failed to handle re-org", "signature", sig, "id", txInfo.id, "error", err)
return
}

// For regressions from "Confirmed", rebroadcast tx and restart retry/bumping cycle.
// For regressions from "Confirmed, we'll need to rebroadcast the tx.
if regressionType == FromConfirmed {
retryCtx, cancel := context.WithTimeout(ctx, pTx.cfg.Timeout)
txm.done.Add(1)
go func() {
defer txm.done.Done()
txm.retryTx(retryCtx, cancel, pTx, pTx.tx, sig)
txm.lggr.Debugw("re-org retry completed", "id", pTx.id)
}()
pTx, err := txm.getPendingTx(txInfo.id)
if err != nil {
txm.lggr.Errorw("failed to get pending tx for rebroadcast", "id", txInfo.id, "error", err)
return
}

// Original block may be invalid. To be on the safe side, we'll use a new blockhash
blockhash, err := client.LatestBlockhash(ctx)
if err != nil {
txm.lggr.Errorw("failed to getLatestBlockhash for rebroadcast", "error", err)
return
}
if blockhash == nil || blockhash.Value == nil {
txm.lggr.Errorw("nil pointer returned from getLatestBlockhash for rebroadcast")
return
}

newSig, err := txm.rebroadcastWithNewBlockhash(ctx, pTx, blockhash.Value.Blockhash, blockhash.Value.LastValidBlockHeight)
if err != nil {
return // logging handled inside the func
}

txm.lggr.Debugw("confirmed re-orged tx was rebroadcasted successfully", "id", pTx.id, "newSig", newSig)
}
// For regressions from "Processed" do not restart the cycle immediately.
// The retry/bumping cycle for the original transaction is still active.
// If rebroadcasting becomes necessary later, it will be handled via the
// TxExpirationRebroadcast logic (if enabled) when the transaction expires.
// For regressions from "Processed" do nothing now. The retry/bumping cycle for the original tx is still active.
// If rebroadcasting with new blockhash becomes necessary later, it will be handled via TxExpirationRebroadcast when expired (if enabled)
}
}

Expand Down Expand Up @@ -645,33 +660,15 @@ func (txm *Txm) rebroadcastExpiredTxs(ctx context.Context, client client.ReaderW
return
}

// rebroadcast each expired tx after updating blockhash, lastValidBlockHeight and compute unit price (priority fee)
for _, tx := range expiredBroadcastedTxes {
txm.lggr.Debugw("transaction expired, rebroadcasting", "id", tx.id, "signature", tx.signatures, "lastValidBlockHeight", tx.lastValidBlockHeight, "currentBlockHeight", blockHeight)
// Removes all signatures associated to prior tx and cancels context.
_, err := txm.txs.Remove(tx.id)
// rebroadcast each expired tx
for _, expiredTx := range expiredBroadcastedTxes {
txm.lggr.Debugw("transaction expired, rebroadcasting", "id", expiredTx.id, "signature", expiredTx.signatures, "lastValidBlockHeight", expiredTx.lastValidBlockHeight, "currentBlockHeight", blockHeight)
newSig, err := txm.rebroadcastWithNewBlockhash(ctx, expiredTx, blockhash.Value.Blockhash, blockhash.Value.LastValidBlockHeight)
if err != nil {
txm.lggr.Errorw("failed to remove expired transaction", "id", tx.id, "error", err)
continue
continue // logging handled inside the func
}

tx.tx.Message.RecentBlockhash = blockhash.Value.Blockhash
tx.cfg.BaseComputeUnitPrice = txm.fee.BaseComputeUnitPrice()
rebroadcastTx := pendingTx{
tx: tx.tx,
cfg: tx.cfg,
id: tx.id, // using same id in case it was set by caller and we need to maintain it.
lastValidBlockHeight: blockhash.Value.LastValidBlockHeight,
}
// call sendWithRetry directly to avoid enqueuing
_, _, _, sendErr := txm.sendWithRetry(ctx, rebroadcastTx)
if sendErr != nil {
stateTransitionErr := txm.txs.OnPrebroadcastError(tx.id, txm.cfg.TxRetentionTimeout(), Errored, TxFailReject)
txm.lggr.Errorw("failed to rebroadcast transaction", "id", tx.id, "error", errors.Join(sendErr, stateTransitionErr))
continue
}

txm.lggr.Debugw("rebroadcast transaction sent", "id", tx.id)
txm.lggr.Debugw("expired tx was rebroadcasted successfully", "id", expiredTx.id, "newSig", newSig)
}
}

Expand Down Expand Up @@ -995,6 +992,36 @@ func (txm *Txm) InflightTxs() int {
return len(txm.txs.ListAllSigs())
}

// rebroadcastWithNewBlockhash attempts to rebroadcast a pending tx with a new blockhash.
// It removes all signatures associated with the prior tx, cancels the context.
// It also updates the compute unit price and assigns a new blockhash for rebroadcasting.
// Calls sendWithRetry directly to avoid enqueuing the transaction.
// If the rebroadcast fails, it logs the error. If successful, it returns the new signature.
func (txm *Txm) rebroadcastWithNewBlockhash(ctx context.Context, pTx pendingTx, blockhash solana.Hash, lastValidBlockHeight uint64) (solana.Signature, error) {
// Removes all signatures associated to prior tx and cancels context.
_, err := txm.txs.Remove(pTx.id)
if err != nil {
txm.lggr.Errorw("failed to remove tx", "id", pTx.id, "error", err)
return solana.Signature{}, err
}

// Update the pendingTx
pTx.tx.Message.RecentBlockhash = blockhash
pTx.cfg.BaseComputeUnitPrice = txm.fee.BaseComputeUnitPrice()
pTx.lastValidBlockHeight = lastValidBlockHeight

// call sendWithRetry directly to avoid enqueuing
_, _, newSig, sendErr := txm.sendWithRetry(ctx, pTx)
if sendErr != nil {
stateTransitionErr := txm.txs.OnPrebroadcastError(pTx.id, txm.cfg.TxRetentionTimeout(), Errored, TxFailReject)
combinedErr := errors.Join(sendErr, stateTransitionErr)
txm.lggr.Errorw("failed to rebroadcast tx with new blockhash", "id", pTx.id, "error", combinedErr)
return solana.Signature{}, combinedErr
}

return newSig, nil
}

// Close close service
func (txm *Txm) Close() error {
return txm.StopOnce("Txm", func() error {
Expand All @@ -1018,3 +1045,7 @@ func (txm *Txm) defaultTxConfig() TxConfig {
EstimateComputeUnitLimit: txm.cfg.EstimateComputeUnitLimit(),
}
}

func (txm *Txm) getPendingTx(txID string) (pendingTx, error) {
return txm.txs.GetPendingTx(txID)
}
Loading

0 comments on commit dbf4e41

Please sign in to comment.