Skip to content

Commit

Permalink
Added DB.WithContext
Browse files Browse the repository at this point in the history
  • Loading branch information
qiangxue committed Jan 7, 2020
1 parent 10d1ad7 commit b7add2e
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 1 deletion.
22 changes: 21 additions & 1 deletion db.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type (

sqlDB *sql.DB
driverName string
ctx context.Context
}

// Errors represents a list of errors.
Expand Down Expand Up @@ -129,6 +130,19 @@ func (db *DB) Clone() *DB {
return db2
}

// WithContext returns a new instance of DB associated with the given context.
func (db *DB) WithContext(ctx context.Context) *DB {
db2 := db.Clone()
db2.ctx = ctx
return db2
}

// Context returns the context associated with the DB instance.
// It returns nil if no context is associated.
func (db *DB) Context() context.Context {
return db.ctx
}

// DB returns the sql.DB instance encapsulated by dbx.DB.
func (db *DB) DB() *sql.DB {
return db.sqlDB
Expand All @@ -143,7 +157,13 @@ func (db *DB) Close() error {

// Begin starts a transaction.
func (db *DB) Begin() (*Tx, error) {
tx, err := db.sqlDB.Begin()
var tx *sql.Tx
var err error
if db.ctx != nil {
tx, err = db.sqlDB.BeginTx(db.ctx, nil)
} else {
tx, err = db.sqlDB.Begin()
}
if err != nil {
return nil, err
}
Expand Down
14 changes: 14 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package dbx

import (
"context"
"database/sql"
"errors"
"io/ioutil"
Expand Down Expand Up @@ -36,7 +37,13 @@ func TestDB_Open(t *testing.T) {
assert.NotNil(t, db.sqlDB)
assert.NotNil(t, db.FieldMapper)
db2 := db.Clone()
assert.NotEqual(t, db, db2)
assert.Equal(t, db.driverName, db2.driverName)
ctx := context.Background()
db3 := db.WithContext(ctx)
assert.Equal(t, ctx, db3.ctx)
assert.Equal(t, ctx, db3.Context())
assert.NotEqual(t, db, db3)
}

_, err = Open("xyz", TestDSN)
Expand Down Expand Up @@ -179,6 +186,13 @@ func TestDB_Begin(t *testing.T) {
},
desc: "Wrap",
},
{
makeTx: func(db *DB) *Tx {
tx, _ := db.BeginTx(context.Background(), nil)
return tx
},
desc: "BeginTx",
},
}

db := getPreparedDB()
Expand Down
1 change: 1 addition & 0 deletions model_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
func NewModelQuery(model interface{}, fieldMapFunc FieldMapFunc, db *DB, builder Builder) *ModelQuery {
q := &ModelQuery{
db: db,
ctx: db.ctx,
builder: builder,
model: newStructValue(model, fieldMapFunc),
}
Expand Down
1 change: 1 addition & 0 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func NewQuery(db *DB, executor Executor, sql string) *Query {
rawSQL: rawSQL,
placeholders: placeholders,
params: Params{},
ctx: db.ctx,
FieldMapper: db.FieldMapper,
LogFunc: db.LogFunc,
PerfFunc: db.PerfFunc,
Expand Down
1 change: 1 addition & 0 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func NewSelectQuery(builder Builder, db *DB) *SelectQuery {
union: []UnionInfo{},
limit: -1,
params: Params{},
ctx: db.ctx,
FieldMapper: db.FieldMapper,
}
}
Expand Down

0 comments on commit b7add2e

Please sign in to comment.