Skip to content

Commit

Permalink
vitessdriver: protect against silent failures
Browse files Browse the repository at this point in the history
without this check, database calls could appear to succeed with no errors, but the transaction ends up abandoned, potentially leaving users in an inconsistent state

Signed-off-by: Derek Perkins <[email protected]>
  • Loading branch information
derekperkins committed Jan 5, 2022
1 parent 3207038 commit 38d84a1
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 7 deletions.
64 changes: 60 additions & 4 deletions go/vt/vitessdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"

"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -251,18 +252,62 @@ func (c *conn) Close() error {
// DistributedTxFromSessionToken allows users to send serialized sessions over the wire and
// reconnect to an existing transaction. Setting the sessionToken and address on the
// supplied configuration is the minimum required
func DistributedTxFromSessionToken(ctx context.Context, c Configuration) (*sql.Tx, error) {
// WARNING: the original Tx must already have already done work on all shards to be affected,
// otherwise the ShardSessions will not be sent through in the session token, and thus will
// never be committed in the source. The returned validation function checks to make sure that
// the new transaction work has not added any new ShardSessions.
func DistributedTxFromSessionToken(ctx context.Context, c Configuration) (*sql.Tx, func() error, error) {
if c.SessionToken == "" {
return nil, errors.New("c.SessionToken is required")
return nil, nil, errors.New("c.SessionToken is required")
}

session, err := sessionTokenToSession(c.SessionToken)
if err != nil {
return nil, nil, err
}

// if there isn't 1 or more shards already referenced, no work in this Tx can be committed
originalShardSessionCount := len(session.ShardSessions)
if originalShardSessionCount == 0 {
return nil, nil, errors.New("there must be at least 1 ShardSession")
}

db, err := OpenWithConfiguration(c)
if err != nil {
return nil, err
return nil, nil, err
}

// this should return the only connection associated with the db
return db.BeginTx(ctx, nil)
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, nil, err
}

// this is designed to be run after all new work has been done in the tx, similar to
// where you would traditionally run a tx.Commit, to help prevent you from silently
// losing transactional data.
validationFunc := func() error {
var sessionToken string
sessionToken, err = SessionTokenFromTx(ctx, tx)
if err != nil {
return err
}

session, err = sessionTokenToSession(sessionToken)
if err != nil {
return err
}

if len(session.ShardSessions) > originalShardSessionCount {
return fmt.Errorf("mismatched ShardSession count: originally %d, now %d",
originalShardSessionCount, len(session.ShardSessions),
)
}

return nil
}

return tx, validationFunc, nil
}

// SessionTokenFromTx serializes the sessionFromToken on the tx, which can be reconstituted
Expand All @@ -275,6 +320,17 @@ func SessionTokenFromTx(ctx context.Context, tx *sql.Tx) (string, error) {
return "", err
}

session, err := sessionTokenToSession(sessionToken)
if err != nil {
return "", err
}

// if there isn't 1 or more shards already referenced, no work in this Tx can be committed
originalShardSessionCount := len(session.ShardSessions)
if originalShardSessionCount == 0 {
return "", errors.New("there must be at least 1 ShardSession")
}

return sessionToken, nil
}

Expand Down
21 changes: 18 additions & 3 deletions go/vt/vitessdriver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ func TestSessionToken(t *testing.T) {
SessionToken: sessionToken,
}

sameTx, err := DistributedTxFromSessionToken(ctx, distributedTxConfig)
sameTx, sameValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig)
if err != nil {
t.Fatal(err)
}
Expand All @@ -682,8 +682,18 @@ func TestSessionToken(t *testing.T) {
t.Fatal(err)
}

err = sameValidationFunc()
if err != nil {
t.Fatal(err)
}

// enforce that Rollback can't be called on the distributed tx
noRollbackTx, err := DistributedTxFromSessionToken(ctx, distributedTxConfig)
noRollbackTx, noRollbackValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig)
if err != nil {
t.Fatal(err)
}

err = noRollbackValidationFunc()
if err != nil {
t.Fatal(err)
}
Expand All @@ -694,7 +704,12 @@ func TestSessionToken(t *testing.T) {
}

// enforce that Commit can't be called on the distributed tx
noCommitTx, err := DistributedTxFromSessionToken(ctx, distributedTxConfig)
noCommitTx, noCommitValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig)
if err != nil {
t.Fatal(err)
}

err = noCommitValidationFunc()
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 38d84a1

Please sign in to comment.