Skip to content

Commit

Permalink
Add RunInTx method for DB
Browse files Browse the repository at this point in the history
  • Loading branch information
dmakushin authored and Dmitrii Makushin committed Nov 29, 2024
1 parent 992acfb commit ae120dd
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w=
github.com/spf13/cast v1.5.0/go.mod h1:SpXXQ5YoyJw6s3/6cMTQuxvgRl3PCJiyaX9p6b155UU=
github.com/stephenafamo/fakedb v0.0.0-20221230081958-0b86f816ed97 h1:XItoZNmhOih06TC02jK7l3wlpZ0XT/sPQYutDcGOQjg=
github.com/stephenafamo/fakedb v0.0.0-20221230081958-0b86f816ed97/go.mod h1:bM3Vmw1IakoaXocHmMIGgJFYob0vuK+CFWiJHQvz0jQ=
github.com/stephenafamo/scan v0.6.0 h1:N0joyP/wriC9VvP6w9SDxHIuQGatW4c2YW7Z5L4m45s=
github.com/stephenafamo/scan v0.6.0/go.mod h1:FhIUJ8pLNyex36xGFiazDJJ5Xry0UkAi+RkWRrEcRMg=
github.com/stephenafamo/scan v0.6.1 h1:nXokGCQwYazMuyvdNAoK0T8Z76FWcpMvDdtengpz6PU=
github.com/stephenafamo/scan v0.6.1/go.mod h1:FhIUJ8pLNyex36xGFiazDJJ5Xry0UkAi+RkWRrEcRMg=
github.com/stephenafamo/sqlparser v0.0.0-20241111104950-b04fa8a26c9c h1:JFga++XBnZG2xlnvQyHJkeBWZ9G9mGdtgvLeSRbp/BA=
Expand Down
28 changes: 28 additions & 0 deletions stdlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"

"github.com/stephenafamo/scan"
"github.com/stephenafamo/scan/stdscan"
Expand Down Expand Up @@ -96,6 +98,32 @@ func (d DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
return NewTx(tx), nil
}

// RunInTx runs the provided function in a transaction.
// If the function returns an error, the transaction is rolled back.
// Otherwise, the transaction is committed.
func (d DB) RunInTx(ctx context.Context, txOptions *sql.TxOptions, fn func(context.Context, Tx) error) error {
tx, err := d.BeginTx(ctx, txOptions)
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
}

if err := fn(ctx, tx); err != nil {
err = fmt.Errorf("call method in transaction: %w", err)

if rollbackErr := tx.Rollback(); rollbackErr != nil {
return errors.Join(err, rollbackErr)
}

return err
}

if err := tx.Commit(); err != nil {
return fmt.Errorf("commit transaction: %w", err)
}

return nil
}

// NewTx wraps an [*sql.Tx] and returns a type that implements [Queryer] but still
// retains the expected methods used by *sql.Tx
// This is useful when an existing *sql.Tx is used in other places in the codebase
Expand Down

0 comments on commit ae120dd

Please sign in to comment.