diff --git a/db.go b/db.go index 2f52a2248..c283f56bd 100644 --- a/db.go +++ b/db.go @@ -35,8 +35,7 @@ func WithDiscardUnknownColumns() DBOption { type DB struct { *sql.DB - dialect schema.Dialect - features feature.Feature + dialect schema.Dialect queryHooks []QueryHook @@ -50,10 +49,9 @@ func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { dialect.Init(sqldb) db := &DB{ - DB: sqldb, - dialect: dialect, - features: dialect.Features(), - fmter: schema.NewFormatter(dialect), + DB: sqldb, + dialect: dialect, + fmter: schema.NewFormatter(dialect), } for _, opt := range opts { @@ -231,7 +229,7 @@ func (db *DB) UpdateFQN(alias, column string) Ident { // HasFeature uses feature package to report whether the underlying DBMS supports this feature. func (db *DB) HasFeature(feat feature.Feature) bool { - return db.fmter.HasFeature(feat) + return db.dialect.Features().Has(feat) } //------------------------------------------------------------------------------ @@ -513,7 +511,7 @@ func (tx Tx) commitTX() error { } func (tx Tx) commitSP() error { - if tx.Dialect().Features().Has(feature.MSSavepoint) { + if tx.db.HasFeature(feature.MSSavepoint) { return nil } query := "RELEASE SAVEPOINT " + tx.name @@ -537,7 +535,7 @@ func (tx Tx) rollbackTX() error { func (tx Tx) rollbackSP() error { query := "ROLLBACK TO SAVEPOINT " + tx.name - if tx.Dialect().Features().Has(feature.MSSavepoint) { + if tx.db.HasFeature(feature.MSSavepoint) { query = "ROLLBACK TRANSACTION " + tx.name } _, err := tx.ExecContext(tx.ctx, query) @@ -601,7 +599,7 @@ func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) { qName := "SP_" + hex.EncodeToString(sp) query := "SAVEPOINT " + qName - if tx.Dialect().Features().Has(feature.MSSavepoint) { + if tx.db.HasFeature(feature.MSSavepoint) { query = "SAVE TRANSACTION " + qName } _, err = tx.ExecContext(ctx, query) diff --git a/dialect/mssqldialect/dialect.go b/dialect/mssqldialect/dialect.go index bde140963..651790e9f 100755 --- a/dialect/mssqldialect/dialect.go +++ b/dialect/mssqldialect/dialect.go @@ -39,7 +39,7 @@ type Dialect struct { features feature.Feature } -func New() *Dialect { +func New(opts ...DialectOption) *Dialect { d := new(Dialect) d.tables = schema.NewTables(d) d.features = feature.CTE | @@ -49,9 +49,21 @@ func New() *Dialect { feature.OffsetFetch | feature.UpdateFromTable | feature.MSSavepoint + + for _, opt := range opts { + opt(d) + } return d } +type DialectOption func(d *Dialect) + +func WithoutFeature(other feature.Feature) DialectOption { + return func(d *Dialect) { + d.features = d.features.Remove(other) + } +} + func (d *Dialect) Init(db *sql.DB) { var version string if err := db.QueryRow("SELECT @@VERSION").Scan(&version); err != nil { diff --git a/dialect/mysqldialect/dialect.go b/dialect/mysqldialect/dialect.go index 90a2b9cd9..83644c987 100644 --- a/dialect/mysqldialect/dialect.go +++ b/dialect/mysqldialect/dialect.go @@ -27,8 +27,6 @@ func init() { } } -type DialectOption func(d *Dialect) - type Dialect struct { schema.BaseDialect @@ -60,6 +58,8 @@ func New(opts ...DialectOption) *Dialect { return d } +type DialectOption func(d *Dialect) + func WithTimeLocation(loc string) DialectOption { return func(d *Dialect) { location, err := time.LoadLocation(loc) @@ -70,6 +70,12 @@ func WithTimeLocation(loc string) DialectOption { } } +func WithoutFeature(other feature.Feature) DialectOption { + return func(d *Dialect) { + d.features = d.features.Remove(other) + } +} + func (d *Dialect) Init(db *sql.DB) { var version string if err := db.QueryRow("SELECT version()").Scan(&version); err != nil { diff --git a/dialect/oracledialect/dialect.go b/dialect/oracledialect/dialect.go index c9d3d3dda..71f87d198 100644 --- a/dialect/oracledialect/dialect.go +++ b/dialect/oracledialect/dialect.go @@ -27,7 +27,7 @@ type Dialect struct { features feature.Feature } -func New() *Dialect { +func New(opts ...DialectOption) *Dialect { d := new(Dialect) d.tables = schema.NewTables(d) d.features = feature.CTE | @@ -42,9 +42,22 @@ func New() *Dialect { feature.AutoIncrement | feature.CompositeIn | feature.DeleteReturning + + for _, opt := range opts { + opt(d) + } + return d } +type DialectOption func(d *Dialect) + +func WithoutFeature(other feature.Feature) DialectOption { + return func(d *Dialect) { + d.features = d.features.Remove(other) + } +} + func (d *Dialect) Init(*sql.DB) {} func (d *Dialect) Name() dialect.Name { diff --git a/dialect/pgdialect/dialect.go b/dialect/pgdialect/dialect.go index 040163f98..cdca0444e 100644 --- a/dialect/pgdialect/dialect.go +++ b/dialect/pgdialect/dialect.go @@ -34,7 +34,7 @@ var _ schema.Dialect = (*Dialect)(nil) var _ sqlschema.InspectorDialect = (*Dialect)(nil) var _ sqlschema.MigratorDialect = (*Dialect)(nil) -func New() *Dialect { +func New(opts ...DialectOption) *Dialect { d := new(Dialect) d.tables = schema.NewTables(d) d.features = feature.CTE | @@ -55,9 +55,22 @@ func New() *Dialect { feature.GeneratedIdentity | feature.CompositeIn | feature.DeleteReturning + + for _, opt := range opts { + opt(d) + } + return d } +type DialectOption func(d *Dialect) + +func WithoutFeature(other feature.Feature) DialectOption { + return func(d *Dialect) { + d.features = d.features.Remove(other) + } +} + func (d *Dialect) Init(*sql.DB) {} func (d *Dialect) Name() dialect.Name { diff --git a/dialect/sqlitedialect/dialect.go b/dialect/sqlitedialect/dialect.go index 92959482e..1280d0d69 100644 --- a/dialect/sqlitedialect/dialect.go +++ b/dialect/sqlitedialect/dialect.go @@ -26,7 +26,7 @@ type Dialect struct { features feature.Feature } -func New() *Dialect { +func New(opts ...DialectOption) *Dialect { d := new(Dialect) d.tables = schema.NewTables(d) d.features = feature.CTE | @@ -42,9 +42,22 @@ func New() *Dialect { feature.AutoIncrement | feature.CompositeIn | feature.DeleteReturning + + for _, opt := range opts { + opt(d) + } + return d } +type DialectOption func(d *Dialect) + +func WithoutFeature(other feature.Feature) DialectOption { + return func(d *Dialect) { + d.features = d.features.Remove(other) + } +} + func (d *Dialect) Init(*sql.DB) {} func (d *Dialect) Name() dialect.Name { diff --git a/model_map_slice.go b/model_map_slice.go index 1e96c898c..8e4a22f6b 100644 --- a/model_map_slice.go +++ b/model_map_slice.go @@ -99,7 +99,7 @@ func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte slice := *m.dest b = append(b, "VALUES "...) - if m.db.features.Has(feature.ValuesRow) { + if m.db.HasFeature(feature.ValuesRow) { b = append(b, "ROW("...) } else { b = append(b, '(') @@ -118,7 +118,7 @@ func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte for i, el := range slice { if i > 0 { b = append(b, "), "...) - if m.db.features.Has(feature.ValuesRow) { + if m.db.HasFeature(feature.ValuesRow) { b = append(b, "ROW("...) } else { b = append(b, '(') diff --git a/query_base.go b/query_base.go index 52b0c1e22..08ff8e5d9 100644 --- a/query_base.go +++ b/query_base.go @@ -201,7 +201,7 @@ func (q *baseQuery) beforeAppendModel(ctx context.Context, query Query) error { } func (q *baseQuery) hasFeature(feature feature.Feature) bool { - return q.db.features.Has(feature) + return q.db.HasFeature(feature) } //------------------------------------------------------------------------------ diff --git a/query_delete.go b/query_delete.go index ccfeb1997..1235ba718 100644 --- a/query_delete.go +++ b/query_delete.go @@ -201,7 +201,7 @@ func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e return upd.AppendQuery(fmter, b) } - withAlias := q.db.features.Has(feature.DeleteTableAlias) + withAlias := q.db.HasFeature(feature.DeleteTableAlias) b, err = q.appendWith(fmter, b) if err != nil { diff --git a/query_insert.go b/query_insert.go index b6747cd65..8bec4ce26 100644 --- a/query_insert.go +++ b/query_insert.go @@ -190,7 +190,7 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e } b = append(b, "INTO "...) - if q.db.features.Has(feature.InsertTableAlias) && !q.on.IsZero() { + if q.db.HasFeature(feature.InsertTableAlias) && !q.on.IsZero() { b, err = q.appendFirstTableWithAlias(fmter, b) } else { b, err = q.appendFirstTable(fmter, b) @@ -385,9 +385,9 @@ func (q *InsertQuery) appendSliceValues( } func (q *InsertQuery) getFields() ([]*schema.Field, error) { - hasIdentity := q.db.features.Has(feature.Identity) + hasIdentity := q.db.HasFeature(feature.Identity) - if len(q.columns) > 0 || q.db.features.Has(feature.DefaultPlaceholder) && !hasIdentity { + if len(q.columns) > 0 || q.db.HasFeature(feature.DefaultPlaceholder) && !hasIdentity { return q.baseQuery.getFields() } @@ -640,8 +640,8 @@ func (q *InsertQuery) afterInsertHook(ctx context.Context) error { } func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error { - if q.db.features.Has(feature.Returning) || - q.db.features.Has(feature.Output) || + if q.db.HasFeature(feature.Returning) || + q.db.HasFeature(feature.Output) || q.table == nil || len(q.table.PKs) != 1 || !q.table.PKs[0].AutoIncrement { diff --git a/query_table_truncate.go b/query_table_truncate.go index 9ac5599d9..1db81fb53 100644 --- a/query_table_truncate.go +++ b/query_table_truncate.go @@ -110,7 +110,7 @@ func (q *TruncateTableQuery) AppendQuery( return nil, err } - if q.db.features.Has(feature.TableIdentity) { + if q.db.HasFeature(feature.TableIdentity) { if q.continueIdentity { b = append(b, " CONTINUE IDENTITY"...) } else { diff --git a/query_values.go b/query_values.go index 5c2abef60..34deb1ee4 100644 --- a/query_values.go +++ b/query_values.go @@ -145,7 +145,7 @@ func (q *ValuesQuery) appendQuery( fields []*schema.Field, ) (_ []byte, err error) { b = append(b, "VALUES "...) - if q.db.features.Has(feature.ValuesRow) { + if q.db.HasFeature(feature.ValuesRow) { b = append(b, "ROW("...) } else { b = append(b, '(') @@ -168,7 +168,7 @@ func (q *ValuesQuery) appendQuery( for i := 0; i < sliceLen; i++ { if i > 0 { b = append(b, "), "...) - if q.db.features.Has(feature.ValuesRow) { + if q.db.HasFeature(feature.ValuesRow) { b = append(b, "ROW("...) } else { b = append(b, '(') diff --git a/relation_join.go b/relation_join.go index 487f776ed..19dede4f9 100644 --- a/relation_join.go +++ b/relation_join.go @@ -63,7 +63,7 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery { var where []byte - if q.db.dialect.Features().Has(feature.CompositeIn) { + if q.db.HasFeature(feature.CompositeIn) { return j.manyQueryCompositeIn(where, q) } return j.manyQueryMulti(where, q)