From 38d84a156abb542ad9c46f9d708e1cc432fd5568 Mon Sep 17 00:00:00 2001 From: Derek Perkins Date: Wed, 5 Jan 2022 04:12:10 -0700 Subject: [PATCH] vitessdriver: protect against silent failures without this check, database calls could appear to succeed with no errors, but the transaction ends up abandoned, potentially leaving users in an inconsistent state Signed-off-by: Derek Perkins --- go/vt/vitessdriver/driver.go | 64 +++++++++++++++++++++++++++++-- go/vt/vitessdriver/driver_test.go | 21 ++++++++-- 2 files changed, 78 insertions(+), 7 deletions(-) diff --git a/go/vt/vitessdriver/driver.go b/go/vt/vitessdriver/driver.go index 2cd375daaab..7f8a50a4956 100644 --- a/go/vt/vitessdriver/driver.go +++ b/go/vt/vitessdriver/driver.go @@ -23,6 +23,7 @@ import ( "encoding/base64" "encoding/json" "errors" + "fmt" "google.golang.org/grpc" "google.golang.org/protobuf/proto" @@ -251,18 +252,62 @@ func (c *conn) Close() error { // DistributedTxFromSessionToken allows users to send serialized sessions over the wire and // reconnect to an existing transaction. Setting the sessionToken and address on the // supplied configuration is the minimum required -func DistributedTxFromSessionToken(ctx context.Context, c Configuration) (*sql.Tx, error) { +// WARNING: the original Tx must have already done work on all shards to be affected, +// otherwise the ShardSessions will not be sent through in the session token, and thus will +// never be committed in the source. The returned validation function checks to make sure that +// the new transaction work has not added any new ShardSessions. 
+func DistributedTxFromSessionToken(ctx context.Context, c Configuration) (*sql.Tx, func() error, error) { if c.SessionToken == "" { - return nil, errors.New("c.SessionToken is required") + return nil, nil, errors.New("c.SessionToken is required") + } + + session, err := sessionTokenToSession(c.SessionToken) + if err != nil { + return nil, nil, err + } + + // if there isn't 1 or more shards already referenced, no work in this Tx can be committed + originalShardSessionCount := len(session.ShardSessions) + if originalShardSessionCount == 0 { + return nil, nil, errors.New("there must be at least 1 ShardSession") } db, err := OpenWithConfiguration(c) if err != nil { - return nil, err + return nil, nil, err } // this should return the only connection associated with the db - return db.BeginTx(ctx, nil) + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, nil, err + } + + // this is designed to be run after all new work has been done in the tx, similar to + // where you would traditionally run a tx.Commit, to help prevent you from silently + // losing transactional data. 
+ validationFunc := func() error { + var sessionToken string + sessionToken, err = SessionTokenFromTx(ctx, tx) + if err != nil { + return err + } + + session, err = sessionTokenToSession(sessionToken) + if err != nil { + return err + } + + if len(session.ShardSessions) > originalShardSessionCount { + return fmt.Errorf("mismatched ShardSession count: originally %d, now %d", + originalShardSessionCount, len(session.ShardSessions), + ) + } + + return nil + } + + return tx, validationFunc, nil } // SessionTokenFromTx serializes the sessionFromToken on the tx, which can be reconstituted @@ -275,6 +320,17 @@ func SessionTokenFromTx(ctx context.Context, tx *sql.Tx) (string, error) { return "", err } + session, err := sessionTokenToSession(sessionToken) + if err != nil { + return "", err + } + + // if there isn't 1 or more shards already referenced, no work in this Tx can be committed + originalShardSessionCount := len(session.ShardSessions) + if originalShardSessionCount == 0 { + return "", errors.New("there must be at least 1 ShardSession") + } + return sessionToken, nil } diff --git a/go/vt/vitessdriver/driver_test.go b/go/vt/vitessdriver/driver_test.go index cd22a28fcde..6af7a534de8 100644 --- a/go/vt/vitessdriver/driver_test.go +++ b/go/vt/vitessdriver/driver_test.go @@ -667,7 +667,7 @@ func TestSessionToken(t *testing.T) { SessionToken: sessionToken, } - sameTx, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) + sameTx, sameValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) if err != nil { t.Fatal(err) } @@ -682,8 +682,18 @@ func TestSessionToken(t *testing.T) { t.Fatal(err) } + err = sameValidationFunc() + if err != nil { + t.Fatal(err) + } + // enforce that Rollback can't be called on the distributed tx - noRollbackTx, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) + noRollbackTx, noRollbackValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) + if err != nil { + t.Fatal(err) + } + + 
err = noRollbackValidationFunc() if err != nil { t.Fatal(err) } @@ -694,7 +704,12 @@ func TestSessionToken(t *testing.T) { } // enforce that Commit can't be called on the distributed tx - noCommitTx, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) + noCommitTx, noCommitValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) + if err != nil { + t.Fatal(err) + } + + err = noCommitValidationFunc() if err != nil { t.Fatal(err) }