Skip to content

Commit

Permalink
Added context support and bump up go version requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
qiangxue committed Jun 15, 2018
1 parent db79922 commit e984ee9
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 18 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
language: go

go:
- 1.7
- 1.8
- 1.9
- tip
Expand Down
20 changes: 13 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ as well as DB-agnostic query building capabilities. ozzo-dbx is not an ORM. It h

## Requirements

Go 1.7 or above.
Go 1.8 or above.

## Installation

Expand All @@ -61,12 +61,6 @@ Run the following command to install the package:
go get github.com/go-ozzo/ozzo-dbx
```

You may also get specified release of the package by:

```
go get gopkg.in/go-ozzo/ozzo-dbx.v1
```

In addition, install the specific DB driver package for the kind of database to be used. Please refer to
[SQL database drivers](https://github.com/golang/go/wiki/SQLDrivers) for a complete list. For example, if you are
using MySQL, you may install the following package:
Expand Down Expand Up @@ -332,6 +326,18 @@ q.One(&user)
// ...
```


## Cancelable Queries

Queries are cancelable when they are used with `context.Context`. In particular, by calling `Query.WithContext()` you
can associate a context with a query and use the context to cancel the query while it is running. For example,

```go
q := db.NewQuery("SELECT id, name FROM users")
rows := q.WithContext(ctx).All()
```


## Building Queries

Instead of writing plain SQLs, `ozzo-dbx` allows you to build SQLs programmatically, which often leads to cleaner,
Expand Down
42 changes: 42 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package dbx

import (
"bytes"
"context"
"database/sql"
"regexp"
"strings"
Expand Down Expand Up @@ -131,6 +132,15 @@ func (db *DB) Begin() (*Tx, error) {
return &Tx{db.newBuilder(tx), tx}, nil
}

// BeginTx starts a transaction with the given context and transaction options.
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
tx, err := db.sqlDB.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &Tx{db.newBuilder(tx), tx}, nil
}

// Wrap encapsulates an existing transaction.
func (db *DB) Wrap(sqlTx *sql.Tx) *Tx {
return &Tx{db.newBuilder(sqlTx), sqlTx}
Expand Down Expand Up @@ -168,6 +178,38 @@ func (db *DB) Transactional(f func(*Tx) error) (err error) {
return err
}

// TransactionalContext starts a transaction and executes the given function with the given context and transaction options.
// If the function returns an error, the transaction will be rolled back.
// Otherwise, the transaction will be committed.
func (db *DB) TransactionalContext(ctx context.Context, opts *sql.TxOptions, f func(*Tx) error) (err error) {
tx, err := db.BeginTx(ctx, opts)
if err != nil {
return err
}

defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
} else if err != nil {
if err2 := tx.Rollback(); err2 != nil {
if err2 == sql.ErrTxDone {
return
}
err = Errors{err, err2}
}
} else {
if err = tx.Commit(); err == sql.ErrTxDone {
err = nil
}
}
}()

err = f(tx)

return err
}

// DriverName returns the name of the DB driver.
func (db *DB) DriverName() string {
return db.driverName
Expand Down
22 changes: 18 additions & 4 deletions model_query.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dbx

import (
"context"
"errors"
"fmt"
"reflect"
Expand All @@ -15,6 +16,7 @@ type (
// ModelQuery represents a query associated with a struct model.
ModelQuery struct {
db *DB
ctx context.Context
builder Builder
model *structValue
exclude []string
Expand All @@ -39,6 +41,18 @@ func NewModelQuery(model interface{}, fieldMapFunc FieldMapFunc, db *DB, builder
return q
}


// Context returns the context associated with the query.
func (q *ModelQuery) Context() context.Context {
return q.ctx
}

// WithContext associates a context with the query.
func (q *ModelQuery) WithContext(ctx context.Context) *ModelQuery {
q.ctx = ctx
return q
}

// Exclude excludes the specified struct fields from being inserted/updated into the DB table.
func (q *ModelQuery) Exclude(attrs ...string) *ModelQuery {
q.exclude = attrs
Expand Down Expand Up @@ -68,12 +82,12 @@ func (q *ModelQuery) Insert(attrs ...string) error {
}

if pkName == "" {
_, err := q.builder.Insert(q.model.tableName, Params(cols)).Execute()
_, err := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx).Execute()
return err
}

// handle auto-incremental PK
query := q.builder.Insert(q.model.tableName, Params(cols))
query := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx)
pkValue, err := insertAndReturnPK(q.db, query, pkName)
if err != nil {
return err
Expand Down Expand Up @@ -142,7 +156,7 @@ func (q *ModelQuery) Update(attrs ...string) error {
for name := range pk {
delete(cols, name)
}
_, err := q.builder.Update(q.model.tableName, Params(cols), HashExp(pk)).Execute()
_, err := q.builder.Update(q.model.tableName, Params(cols), HashExp(pk)).WithContext(q.ctx).Execute()
return err
}

Expand All @@ -155,6 +169,6 @@ func (q *ModelQuery) Delete() error {
if len(pk) == 0 {
return MissingPKError
}
_, err := q.builder.Delete(q.model.tableName, HashExp(pk)).Execute()
_, err := q.builder.Delete(q.model.tableName, HashExp(pk)).WithContext(q.ctx).Execute()
return err
}
45 changes: 39 additions & 6 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package dbx

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
Expand All @@ -21,8 +22,12 @@ type Params map[string]interface{}
type Executor interface {
// Exec executes a SQL statement
Exec(query string, args ...interface{}) (sql.Result, error)
// ExecContext executes a SQL statement with the given context
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
// Query queries a SQL statement
Query(query string, args ...interface{}) (*sql.Rows, error)
// QueryContext queries a SQL statement with the given context
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
// Prepare creates a prepared statement
Prepare(query string) (*sql.Stmt, error)
}
Expand All @@ -36,6 +41,7 @@ type Query struct {
params Params

stmt *sql.Stmt
ctx context.Context

// FieldMapper maps struct field names to DB column names.
FieldMapper FieldMapFunc
Expand Down Expand Up @@ -70,6 +76,17 @@ func (q *Query) SQL() string {
return q.sql
}

// Context returns the context associated with the query.
func (q *Query) Context() context.Context {
return q.ctx
}

// WithContext associates a context with the query.
func (q *Query) WithContext(ctx context.Context) *Query {
q.ctx = ctx
return q
}

// logSQL returns the SQL statement with parameters being replaced with the actual values.
// The result is only for logging purpose and should not be used to execute.
func (q *Query) logSQL() string {
Expand Down Expand Up @@ -168,10 +185,18 @@ func (q *Query) Execute() (result sql.Result, err error) {

defer q.log(time.Now(), true)

if q.stmt == nil {
result, err = q.executor.Exec(q.rawSQL, params...)
if q.ctx == nil {
if q.stmt == nil {
result, err = q.executor.Exec(q.rawSQL, params...)
} else {
result, err = q.stmt.Exec(params...)
}
} else {
result, err = q.stmt.Exec(params...)
if q.stmt == nil {
result, err = q.executor.ExecContext(q.ctx, q.rawSQL, params...)
} else {
result, err = q.stmt.ExecContext(q.ctx, params...)
}
}
return
}
Expand Down Expand Up @@ -238,10 +263,18 @@ func (q *Query) Rows() (rows *Rows, err error) {
defer q.log(time.Now(), false)

var rr *sql.Rows
if q.stmt == nil {
rr, err = q.executor.Query(q.rawSQL, params...)
if q.ctx == nil {
if q.stmt == nil {
rr, err = q.executor.Query(q.rawSQL, params...)
} else {
rr, err = q.stmt.Query(params...)
}
} else {
rr, err = q.stmt.Query(params...)
if q.stmt == nil {
rr, err = q.executor.QueryContext(q.ctx, q.rawSQL, params...)
} else {
rr, err = q.stmt.QueryContext(q.ctx, params...)
}
}
rows = &Rows{rr, q.FieldMapper}
return
Expand Down

0 comments on commit e984ee9

Please sign in to comment.