Skip to content

Commit

Permalink
added support to call tx.Rollback and Commit within DB.Transactional
Browse files Browse the repository at this point in the history
  • Loading branch information
qiangxue committed Mar 1, 2018
1 parent 21ba863 commit db79922
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
29 changes: 22 additions & 7 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,33 @@ func (db *DB) Wrap(sqlTx *sql.Tx) *Tx {
// Transactional starts a transaction and executes the given function.
// If the function returns an error, the transaction will be rolled back.
// Otherwise, the transaction will be committed.
func (db *DB) Transactional(f func(*Tx) error) error {
func (db *DB) Transactional(f func(*Tx) error) (err error) {
tx, err := db.Begin()
if err != nil {
return err
}
if err := f(tx); err != nil {
if e := tx.Rollback(); e != nil {
return Errors{err, e}

defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
} else if err != nil {
if err2 := tx.Rollback(); err2 != nil {
if err2 == sql.ErrTxDone {
return
}
err = Errors{err, err2}
}
} else {
if err = tx.Commit(); err == sql.ErrTxDone {
err = nil
}
}
return err
}
return tx.Commit()
}()

err = f(tx)

return err
}

// DriverName returns the name of the DB driver.
Expand Down
36 changes: 36 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,42 @@ func TestDB_Transactional(t *testing.T) {
db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name)
assert.Equal(t, "Go in Action", name)
}

// Rollback called within Transactional and return error
err = db.Transactional(func(tx *Tx) error {
_, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute()
if err != nil {
return err
}
_, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute()
if err != nil {
tx.Rollback()
return err
}
return nil
})
if assert.NotNil(t, err) {
db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name)
assert.Equal(t, "Go in Action", name)
}

// Rollback called within Transactional without returning error
err = db.Transactional(func(tx *Tx) error {
_, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute()
if err != nil {
return err
}
_, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute()
if err != nil {
tx.Rollback()
return nil
}
return nil
})
if assert.Nil(t, err) {
db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name)
assert.Equal(t, "Go in Action", name)
}
}

func TestErrors_Error(t *testing.T) {
Expand Down

0 comments on commit db79922

Please sign in to comment.