diff --git a/db.go b/db.go index c283f56b..abeb7aa9 100644 --- a/db.go +++ b/db.go @@ -9,6 +9,7 @@ import ( "reflect" "strings" "sync/atomic" + "time" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" @@ -32,32 +33,55 @@ func WithDiscardUnknownColumns() DBOption { } } -type DB struct { - *sql.DB +func WithReadOnlyReplica(replica *sql.DB) DBOption { + return func(db *DB) { + db.replicas = append(db.replicas, replica) + } +} - dialect schema.Dialect +type DB struct { + // Must be a pointer so we copy the state, not the state fields. + *noCopyState queryHooks []QueryHook fmter schema.Formatter - flags internal.Flag - stats DBStats } +// noCopyState contains DB fields that must not be copied on clone(), +// for example, it is forbidden to copy atomic.Pointer. +type noCopyState struct { + *sql.DB + dialect schema.Dialect + + replicas []*sql.DB + healthyReplicas atomic.Pointer[[]*sql.DB] + nextReplica atomic.Int64 + + flags internal.Flag + closed atomic.Bool +} + func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { dialect.Init(sqldb) db := &DB{ - DB: sqldb, - dialect: dialect, - fmter: schema.NewFormatter(dialect), + noCopyState: &noCopyState{ + DB: sqldb, + dialect: dialect, + }, + fmter: schema.NewFormatter(dialect), } for _, opt := range opts { opt(db) } + if len(db.replicas) > 0 { + go db.monitorReplicas() + } + return db } @@ -69,6 +93,11 @@ func (db *DB) String() string { return b.String() } +func (db *DB) Close() error { + db.closed.Store(true) + return db.DB.Close() +} + func (db *DB) DBStats() DBStats { return DBStats{ Queries: atomic.LoadUint32(&db.stats.Queries), @@ -232,6 +261,44 @@ func (db *DB) HasFeature(feat feature.Feature) bool { return db.dialect.Features().Has(feat) } +// healthyReplica returns a random healthy replica. +func (db *DB) healthyReplica() *sql.DB { + replicas := db.loadHealthyReplicas() + if len(replicas) == 0 { + return db.DB + } + if len(replicas) == 1 { + return replicas[0] + } + i := db.nextReplica.Add(1) + return replicas[int(i)%len(replicas)] +} + +func (db *DB) loadHealthyReplicas() []*sql.DB { + if ptr := db.healthyReplicas.Load(); ptr != nil { + return *ptr + } + return nil +} + +func (db *DB) monitorReplicas() { + for !db.closed.Load() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + healthy := make([]*sql.DB, 0, len(db.replicas)) + + for _, replica := range db.replicas { + if err := replica.PingContext(ctx); err == nil { + healthy = append(healthy, replica) + } + } + + db.healthyReplicas.Store(&healthy) + time.Sleep(5 * time.Second) + } +} + //------------------------------------------------------------------------------ func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { diff --git a/internal/dbtest/docker-compose.yaml b/internal/dbtest/docker-compose.yaml index d47ab5b3..0223bf70 100755 --- a/internal/dbtest/docker-compose.yaml +++ b/internal/dbtest/docker-compose.yaml @@ -1,5 +1,3 @@ -version: '3.9' - services: mysql8: image: mysql:8.0 diff --git a/query_base.go b/query_base.go index 08ff8e5d..27f55742 100644 --- a/query_base.go +++ b/query_base.go @@ -24,7 +24,7 @@ const ( type withQuery struct { name string - query schema.QueryAppender + query Query recursive bool } @@ -114,8 +114,27 @@ func (q *baseQuery) DB() *DB { return q.db } -func (q *baseQuery) GetConn() IConn { - return q.conn +func (q *baseQuery) resolveConn(query Query) IConn { + if q.conn != nil { + return q.conn + } + if len(q.db.replicas) == 0 || !isReadOnlyQuery(query) { + return q.db.DB + } + return q.db.healthyReplica() +} + +func isReadOnlyQuery(query Query) bool { + sel, ok := query.(*SelectQuery) + if !ok { + return false + } + for _, el := range sel.with { + if !isReadOnlyQuery(el.query) { + return false + } + } + return true } func (q *baseQuery) GetModel() Model { @@ -128,10 +147,8 @@ func (q *baseQuery) GetTableName() string { } for _, wq := range q.with { - if v, ok := wq.query.(Query); ok { - if model := v.GetModel(); model != nil { - return v.GetTableName() - } + if model := wq.query.GetModel(); model != nil { + return wq.query.GetTableName() } } @@ -249,7 +266,7 @@ func (q *baseQuery) isSoftDelete() bool { //------------------------------------------------------------------------------ -func (q *baseQuery) addWith(name string, query schema.QueryAppender, recursive bool) { +func (q *baseQuery) addWith(name string, query Query, recursive bool) { q.with = append(q.with, withQuery{ name: name, query: query, @@ -565,28 +582,33 @@ func (q *baseQuery) scan( hasDest bool, ) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) + res, err := q._scan(ctx, iquery, query, model, hasDest) + q.db.afterQuery(ctx, event, res, err) + return res, err +} - rows, err := q.conn.QueryContext(ctx, query) +func (q *baseQuery) _scan( + ctx context.Context, + iquery Query, + query string, + model Model, + hasDest bool, +) (sql.Result, error) { + rows, err := q.resolveConn(iquery).QueryContext(ctx, query) if err != nil { - q.db.afterQuery(ctx, event, nil, err) return nil, err } defer rows.Close() numRow, err := model.ScanRows(ctx, rows) if err != nil { - q.db.afterQuery(ctx, event, nil, err) return nil, err } if numRow == 0 && hasDest && isSingleRowModel(model) { - err = sql.ErrNoRows + return nil, sql.ErrNoRows } - - res := driver.RowsAffected(numRow) - q.db.afterQuery(ctx, event, res, err) - - return res, err + return driver.RowsAffected(numRow), nil } func (q *baseQuery) exec( @@ -595,7 +617,7 @@ func (q *baseQuery) exec( query string, ) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) - res, err := q.conn.ExecContext(ctx, query) + res, err := q.resolveConn(iquery).ExecContext(ctx, query) q.db.afterQuery(ctx, event, res, err) return res, err } diff --git a/query_column_add.go b/query_column_add.go index 50576873..de4ff15f 100644 --- a/query_column_add.go +++ b/query_column_add.go @@ -20,8 +20,7 @@ var _ Query = (*AddColumnQuery)(nil) func NewAddColumnQuery(db *DB) *AddColumnQuery { q := &AddColumnQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_column_drop.go b/query_column_drop.go index 24fc93cf..a67084a6 100644 --- a/query_column_drop.go +++ b/query_column_drop.go @@ -18,8 +18,7 @@ var _ Query = (*DropColumnQuery)(nil) func NewDropColumnQuery(db *DB) *DropColumnQuery { q := &DropColumnQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_delete.go b/query_delete.go index 1235ba71..3467fdb8 100644 --- a/query_delete.go +++ b/query_delete.go @@ -23,8 +23,7 @@ func NewDeleteQuery(db *DB) *DeleteQuery { q := &DeleteQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -56,12 +55,12 @@ func (q *DeleteQuery) Apply(fns ...func(*DeleteQuery) *DeleteQuery) *DeleteQuery return q } -func (q *DeleteQuery) With(name string, query schema.QueryAppender) *DeleteQuery { +func (q *DeleteQuery) With(name string, query Query) *DeleteQuery { q.addWith(name, query, false) return q } -func (q *DeleteQuery) WithRecursive(name string, query schema.QueryAppender) *DeleteQuery { +func (q *DeleteQuery) WithRecursive(name string, query Query) *DeleteQuery { q.addWith(name, query, true) return q } diff --git a/query_index_create.go b/query_index_create.go index 11824cfa..f229bb5c 100644 --- a/query_index_create.go +++ b/query_index_create.go @@ -28,8 +28,7 @@ func NewCreateIndexQuery(db *DB) *CreateIndexQuery { q := &CreateIndexQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } diff --git a/query_index_drop.go b/query_index_drop.go index ae28e795..6300bb67 100644 --- a/query_index_drop.go +++ b/query_index_drop.go @@ -23,8 +23,7 @@ var _ Query = (*DropIndexQuery)(nil) func NewDropIndexQuery(db *DB) *DropIndexQuery { q := &DropIndexQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_insert.go b/query_insert.go index 8bec4ce2..3013a51d 100644 --- a/query_insert.go +++ b/query_insert.go @@ -30,8 +30,7 @@ func NewInsertQuery(db *DB) *InsertQuery { q := &InsertQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -63,12 +62,12 @@ func (q *InsertQuery) Apply(fns ...func(*InsertQuery) *InsertQuery) *InsertQuery return q } -func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery { +func (q *InsertQuery) With(name string, query Query) *InsertQuery { q.addWith(name, query, false) return q } -func (q *InsertQuery) WithRecursive(name string, query schema.QueryAppender) *InsertQuery { +func (q *InsertQuery) WithRecursive(name string, query Query) *InsertQuery { q.addWith(name, query, true) return q } diff --git a/query_merge.go b/query_merge.go index 3c3f4f7f..aa30456a 100644 --- a/query_merge.go +++ b/query_merge.go @@ -25,8 +25,7 @@ var _ Query = (*MergeQuery)(nil) func NewMergeQuery(db *DB) *MergeQuery { q := &MergeQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } if q.db.dialect.Name() != dialect.MSSQL && q.db.dialect.Name() != dialect.PG { @@ -60,12 +59,12 @@ func (q *MergeQuery) Apply(fns ...func(*MergeQuery) *MergeQuery) *MergeQuery { return q } -func (q *MergeQuery) With(name string, query schema.QueryAppender) *MergeQuery { +func (q *MergeQuery) With(name string, query Query) *MergeQuery { q.addWith(name, query, false) return q } -func (q *MergeQuery) WithRecursive(name string, query schema.QueryAppender) *MergeQuery { +func (q *MergeQuery) WithRecursive(name string, query Query) *MergeQuery { q.addWith(name, query, true) return q } diff --git a/query_raw.go b/query_raw.go index 1634d0e5..b1f43af9 100644 --- a/query_raw.go +++ b/query_raw.go @@ -14,23 +14,10 @@ type RawQuery struct { args []interface{} } -// Deprecated: Use NewRaw instead. When add it to IDB, it conflicts with the sql.Conn#Raw -func (db *DB) Raw(query string, args ...interface{}) *RawQuery { - return &RawQuery{ - baseQuery: baseQuery{ - db: db, - conn: db.DB, - }, - query: query, - args: args, - } -} - func NewRawQuery(db *DB, query string, args ...interface{}) *RawQuery { return &RawQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, query: query, args: args, diff --git a/query_select.go b/query_select.go index 2b0872ae..95e04f45 100644 --- a/query_select.go +++ b/query_select.go @@ -40,8 +40,7 @@ func NewSelectQuery(db *DB) *SelectQuery { return &SelectQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -72,12 +71,12 @@ func (q *SelectQuery) Apply(fns ...func(*SelectQuery) *SelectQuery) *SelectQuery return q } -func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery { +func (q *SelectQuery) With(name string, query Query) *SelectQuery { q.addWith(name, query, false) return q } -func (q *SelectQuery) WithRecursive(name string, query schema.QueryAppender) *SelectQuery { +func (q *SelectQuery) WithRecursive(name string, query Query) *SelectQuery { q.addWith(name, query, true) return q } @@ -749,7 +748,7 @@ func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) { query := internal.String(queryBytes) ctx, event := q.db.beforeQuery(ctx, q, query, nil, query, q.model) - rows, err := q.conn.QueryContext(ctx, query) + rows, err := q.resolveConn(q).QueryContext(ctx, query) q.db.afterQuery(ctx, event, nil, err) return rows, err } @@ -877,7 +876,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var num int - err = q.conn.QueryRowContext(ctx, query).Scan(&num) + err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&num) q.db.afterQuery(ctx, event, nil, err) @@ -895,13 +894,15 @@ func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (in return int(n), nil } } - if _, ok := q.conn.(*DB); ok { - return q.scanAndCountConc(ctx, dest...) + if q.conn == nil { + return q.scanAndCountConcurrently(ctx, dest...) } return q.scanAndCountSeq(ctx, dest...) } -func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) { +func (q *SelectQuery) scanAndCountConcurrently( + ctx context.Context, dest ...interface{}, +) (int, error) { var count int var wg sync.WaitGroup var mu sync.Mutex @@ -979,7 +980,7 @@ func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) { ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var exists bool - err = q.conn.QueryRowContext(ctx, query).Scan(&exists) + err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&exists) q.db.afterQuery(ctx, event, nil, err) diff --git a/query_table_create.go b/query_table_create.go index aeb79cd3..ce14deb4 100644 --- a/query_table_create.go +++ b/query_table_create.go @@ -39,8 +39,7 @@ var _ Query = (*CreateTableQuery)(nil) func NewCreateTableQuery(db *DB) *CreateTableQuery { q := &CreateTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, varchar: db.Dialect().DefaultVarcharLen(), } diff --git a/query_table_drop.go b/query_table_drop.go index a9201451..e937723a 100644 --- a/query_table_drop.go +++ b/query_table_drop.go @@ -20,8 +20,7 @@ var _ Query = (*DropTableQuery)(nil) func NewDropTableQuery(db *DB) *DropTableQuery { q := &DropTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_table_truncate.go b/query_table_truncate.go index 1db81fb5..9805630e 100644 --- a/query_table_truncate.go +++ b/query_table_truncate.go @@ -21,8 +21,7 @@ var _ Query = (*TruncateTableQuery)(nil) func NewTruncateTableQuery(db *DB) *TruncateTableQuery { q := &TruncateTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_update.go b/query_update.go index bb926408..1d6fd3ba 100644 --- a/query_update.go +++ b/query_update.go @@ -31,8 +31,7 @@ func NewUpdateQuery(db *DB) *UpdateQuery { q := &UpdateQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -64,12 +63,12 @@ func (q *UpdateQuery) Apply(fns ...func(*UpdateQuery) *UpdateQuery) *UpdateQuery return q } -func (q *UpdateQuery) With(name string, query schema.QueryAppender) *UpdateQuery { +func (q *UpdateQuery) With(name string, query Query) *UpdateQuery { q.addWith(name, query, false) return q } -func (q *UpdateQuery) WithRecursive(name string, query schema.QueryAppender) *UpdateQuery { +func (q *UpdateQuery) WithRecursive(name string, query Query) *UpdateQuery { q.addWith(name, query, true) return q } diff --git a/query_values.go b/query_values.go index 34deb1ee..24b85aee 100644 --- a/query_values.go +++ b/query_values.go @@ -24,8 +24,7 @@ var ( func NewValuesQuery(db *DB, model interface{}) *ValuesQuery { q := &ValuesQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } q.setModel(model)