From b7add2e1fdabd9d095d2b2b0bffb5d3d6f5834f7 Mon Sep 17 00:00:00 2001 From: qiangxue Date: Tue, 7 Jan 2020 10:38:52 -0500 Subject: [PATCH] Added DB.WithContext --- db.go | 22 +++++++++++++++++++++- db_test.go | 14 ++++++++++++++ model_query.go | 1 + query.go | 1 + select.go | 1 + 5 files changed, 38 insertions(+), 1 deletion(-) diff --git a/db.go b/db.go index bced576..b2dc969 100644 --- a/db.go +++ b/db.go @@ -60,6 +60,7 @@ type ( sqlDB *sql.DB driverName string + ctx context.Context } // Errors represents a list of errors. @@ -129,6 +130,19 @@ func (db *DB) Clone() *DB { return db2 } +// WithContext returns a new instance of DB associated with the given context. +func (db *DB) WithContext(ctx context.Context) *DB { + db2 := db.Clone() + db2.ctx = ctx + return db2 +} + +// Context returns the context associated with the DB instance. +// It returns nil if no context is associated. +func (db *DB) Context() context.Context { + return db.ctx +} + // DB returns the sql.DB instance encapsulated by dbx.DB. func (db *DB) DB() *sql.DB { return db.sqlDB @@ -143,7 +157,13 @@ func (db *DB) Close() error { // Begin starts a transaction. func (db *DB) Begin() (*Tx, error) { - tx, err := db.sqlDB.Begin() + var tx *sql.Tx + var err error + if db.ctx != nil { + tx, err = db.sqlDB.BeginTx(db.ctx, nil) + } else { + tx, err = db.sqlDB.Begin() + } if err != nil { return nil, err } diff --git a/db_test.go b/db_test.go index 59436e5..99b6817 100644 --- a/db_test.go +++ b/db_test.go @@ -5,6 +5,7 @@ package dbx import ( + "context" "database/sql" "errors" "io/ioutil" @@ -36,7 +37,13 @@ func TestDB_Open(t *testing.T) { assert.NotNil(t, db.sqlDB) assert.NotNil(t, db.FieldMapper) db2 := db.Clone() + assert.NotEqual(t, db, db2) assert.Equal(t, db.driverName, db2.driverName) + ctx := context.Background() + db3 := db.WithContext(ctx) + assert.Equal(t, ctx, db3.ctx) + assert.Equal(t, ctx, db3.Context()) + assert.NotEqual(t, db, db3) } _, err = Open("xyz", TestDSN) @@ -179,6 +186,13 @@ func TestDB_Begin(t *testing.T) { }, desc: "Wrap", }, + { + makeTx: func(db *DB) *Tx { + tx, _ := db.BeginTx(context.Background(), nil) + return tx + }, + desc: "BeginTx", + }, } db := getPreparedDB() diff --git a/model_query.go b/model_query.go index f7c565b..7289aa8 100644 --- a/model_query.go +++ b/model_query.go @@ -32,6 +32,7 @@ var ( func NewModelQuery(model interface{}, fieldMapFunc FieldMapFunc, db *DB, builder Builder) *ModelQuery { q := &ModelQuery{ db: db, + ctx: db.ctx, builder: builder, model: newStructValue(model, fieldMapFunc), } diff --git a/query.go b/query.go index 565e562..c69602b 100644 --- a/query.go +++ b/query.go @@ -68,6 +68,7 @@ func NewQuery(db *DB, executor Executor, sql string) *Query { rawSQL: rawSQL, placeholders: placeholders, params: Params{}, + ctx: db.ctx, FieldMapper: db.FieldMapper, LogFunc: db.LogFunc, PerfFunc: db.PerfFunc, diff --git a/select.go b/select.go index 77d0c8c..8a9e5a6 100644 --- a/select.go +++ b/select.go @@ -59,6 +59,7 @@ func NewSelectQuery(builder Builder, db *DB) *SelectQuery { union: []UnionInfo{}, limit: -1, params: Params{}, + ctx: db.ctx, FieldMapper: db.FieldMapper, } }