From e984ee904f1b2df904d0b2e6b2ca5064c87c0409 Mon Sep 17 00:00:00 2001 From: qiangxue Date: Fri, 15 Jun 2018 12:37:42 -0400 Subject: [PATCH] Added context support and bump up go version requirement --- .travis.yml | 1 - README.md | 20 +++++++++++++------- db.go | 42 ++++++++++++++++++++++++++++++++++++++++++ model_query.go | 22 ++++++++++++++++++---- query.go | 45 +++++++++++++++++++++++++++++++++++++++------ 5 files changed, 112 insertions(+), 18 deletions(-) diff --git a/.travis.yml b/.travis.yml index 40b99de..e17ee48 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ language: go go: - - 1.7 - 1.8 - 1.9 - tip diff --git a/README.md b/README.md index ae02844..fc2fbe1 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ as well as DB-agnostic query building capabilities. ozzo-dbx is not an ORM. It h ## Requirements -Go 1.7 or above. +Go 1.8 or above. ## Installation @@ -61,12 +61,6 @@ Run the following command to install the package: go get github.com/go-ozzo/ozzo-dbx ``` -You may also get specified release of the package by: - -``` -go get gopkg.in/go-ozzo/ozzo-dbx.v1 -``` - In addition, install the specific DB driver package for the kind of database to be used. Please refer to [SQL database drivers](https://github.com/golang/go/wiki/SQLDrivers) for a complete list. For example, if you are using MySQL, you may install the following package: @@ -332,6 +326,18 @@ q.One(&user) // ... ``` + +## Cancelable Queries + +Queries are cancelable when they are used with `context.Context`. In particular, by calling `Query.WithContext()` you +can associate a context with a query and use the context to cancel the query while it is running. For example, + +```go +q := db.NewQuery("SELECT id, name FROM users") +rows := q.WithContext(ctx).All() +``` + + ## Building Queries Instead of writing plain SQLs, `ozzo-dbx` allows you to build SQLs programmatically, which often leads to cleaner, diff --git a/db.go b/db.go index 546395c..752b6d7 100644 --- a/db.go +++ b/db.go @@ -7,6 +7,7 @@ package dbx import ( "bytes" + "context" "database/sql" "regexp" "strings" @@ -131,6 +132,15 @@ func (db *DB) Begin() (*Tx, error) { return &Tx{db.newBuilder(tx), tx}, nil } +// BeginTx starts a transaction with the given context and transaction options. +func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + tx, err := db.sqlDB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &Tx{db.newBuilder(tx), tx}, nil +} + // Wrap encapsulates an existing transaction. func (db *DB) Wrap(sqlTx *sql.Tx) *Tx { return &Tx{db.newBuilder(sqlTx), sqlTx} @@ -168,6 +178,38 @@ func (db *DB) Transactional(f func(*Tx) error) (err error) { return err } +// TransactionalContext starts a transaction and executes the given function with the given context and transaction options. +// If the function returns an error, the transaction will be rolled back. +// Otherwise, the transaction will be committed. +func (db *DB) TransactionalContext(ctx context.Context, opts *sql.TxOptions, f func(*Tx) error) (err error) { + tx, err := db.BeginTx(ctx, opts) + if err != nil { + return err + } + + defer func() { + if p := recover(); p != nil { + tx.Rollback() + panic(p) + } else if err != nil { + if err2 := tx.Rollback(); err2 != nil { + if err2 == sql.ErrTxDone { + return + } + err = Errors{err, err2} + } + } else { + if err = tx.Commit(); err == sql.ErrTxDone { + err = nil + } + } + }() + + err = f(tx) + + return err +} + // DriverName returns the name of the DB driver. func (db *DB) DriverName() string { return db.driverName diff --git a/model_query.go b/model_query.go index 69d3d02..3dad139 100644 --- a/model_query.go +++ b/model_query.go @@ -1,6 +1,7 @@ package dbx import ( + "context" "errors" "fmt" "reflect" @@ -15,6 +16,7 @@ type ( // ModelQuery represents a query associated with a struct model. ModelQuery struct { db *DB + ctx context.Context builder Builder model *structValue exclude []string @@ -39,6 +41,18 @@ func NewModelQuery(model interface{}, fieldMapFunc FieldMapFunc, db *DB, builder return q } + +// Context returns the context associated with the query. +func (q *ModelQuery) Context() context.Context { + return q.ctx +} + +// WithContext associates a context with the query. +func (q *ModelQuery) WithContext(ctx context.Context) *ModelQuery { + q.ctx = ctx + return q +} + // Exclude excludes the specified struct fields from being inserted/updated into the DB table. func (q *ModelQuery) Exclude(attrs ...string) *ModelQuery { q.exclude = attrs @@ -68,12 +82,12 @@ func (q *ModelQuery) Insert(attrs ...string) error { } if pkName == "" { - _, err := q.builder.Insert(q.model.tableName, Params(cols)).Execute() + _, err := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx).Execute() return err } // handle auto-incremental PK - query := q.builder.Insert(q.model.tableName, Params(cols)) + query := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx) pkValue, err := insertAndReturnPK(q.db, query, pkName) if err != nil { return err @@ -142,7 +156,7 @@ func (q *ModelQuery) Update(attrs ...string) error { for name := range pk { delete(cols, name) } - _, err := q.builder.Update(q.model.tableName, Params(cols), HashExp(pk)).Execute() + _, err := q.builder.Update(q.model.tableName, Params(cols), HashExp(pk)).WithContext(q.ctx).Execute() return err } @@ -155,6 +169,6 @@ func (q *ModelQuery) Delete() error { if len(pk) == 0 { return MissingPKError } - _, err := q.builder.Delete(q.model.tableName, HashExp(pk)).Execute() + _, err := q.builder.Delete(q.model.tableName, HashExp(pk)).WithContext(q.ctx).Execute() return err } diff --git a/query.go b/query.go index 6c33f05..df25166 100644 --- a/query.go +++ b/query.go @@ -5,6 +5,7 @@ package dbx import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -21,8 +22,12 @@ type Params map[string]interface{} type Executor interface { // Exec executes a SQL statement Exec(query string, args ...interface{}) (sql.Result, error) + // ExecContext executes a SQL statement with the given context + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) // Query queries a SQL statement Query(query string, args ...interface{}) (*sql.Rows, error) + // QueryContext queries a SQL statement with the given context + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) // Prepare creates a prepared statement Prepare(query string) (*sql.Stmt, error) } @@ -36,6 +41,7 @@ type Query struct { params Params stmt *sql.Stmt + ctx context.Context // FieldMapper maps struct field names to DB column names. FieldMapper FieldMapFunc @@ -70,6 +76,17 @@ func (q *Query) SQL() string { return q.sql } +// Context returns the context associated with the query. +func (q *Query) Context() context.Context { + return q.ctx +} + +// WithContext associates a context with the query. +func (q *Query) WithContext(ctx context.Context) *Query { + q.ctx = ctx + return q +} + // logSQL returns the SQL statement with parameters being replaced with the actual values. // The result is only for logging purpose and should not be used to execute. func (q *Query) logSQL() string { @@ -168,10 +185,18 @@ func (q *Query) Execute() (result sql.Result, err error) { defer q.log(time.Now(), true) - if q.stmt == nil { - result, err = q.executor.Exec(q.rawSQL, params...) + if q.ctx == nil { + if q.stmt == nil { + result, err = q.executor.Exec(q.rawSQL, params...) + } else { + result, err = q.stmt.Exec(params...) + } } else { - result, err = q.stmt.Exec(params...) + if q.stmt == nil { + result, err = q.executor.ExecContext(q.ctx, q.rawSQL, params...) + } else { + result, err = q.stmt.ExecContext(q.ctx, params...) + } } return } @@ -238,10 +263,18 @@ func (q *Query) Rows() (rows *Rows, err error) { defer q.log(time.Now(), false) var rr *sql.Rows - if q.stmt == nil { - rr, err = q.executor.Query(q.rawSQL, params...) + if q.ctx == nil { + if q.stmt == nil { + rr, err = q.executor.Query(q.rawSQL, params...) + } else { + rr, err = q.stmt.Query(params...) + } } else { - rr, err = q.stmt.Query(params...) + if q.stmt == nil { + rr, err = q.executor.QueryContext(q.ctx, q.rawSQL, params...) + } else { + rr, err = q.stmt.QueryContext(q.ctx, params...) + } } rows = &Rows{rr, q.FieldMapper} return