diff --git a/go/mysql/datetime/datetime.go b/go/mysql/datetime/datetime.go index cc1fc92e091..bf73ac85c27 100644 --- a/go/mysql/datetime/datetime.go +++ b/go/mysql/datetime/datetime.go @@ -94,8 +94,8 @@ func (t Time) FormatDecimal() decimal.Decimal { return dec } -func (t Time) ToDateTime() (out DateTime) { - return NewDateTimeFromStd(t.ToStdTime(time.Local)) +func (t Time) ToDateTime(now time.Time) (out DateTime) { + return NewDateTimeFromStd(t.ToStdTime(now)) } func (t Time) IsZero() bool { @@ -421,9 +421,9 @@ func (t Time) toStdTime(year int, month time.Month, day int, loc *time.Location) return time.Date(year, month, day, hours, minutes, secs, nsecs, loc) } -func (t Time) ToStdTime(loc *time.Location) (out time.Time) { - year, month, day := time.Now().Date() - return t.toStdTime(year, month, day, loc) +func (t Time) ToStdTime(now time.Time) (out time.Time) { + year, month, day := now.Date() + return t.toStdTime(year, month, day, now.Location()) } func (t Time) AddInterval(itv *Interval, stradd bool) (Time, uint8, bool) { @@ -444,7 +444,7 @@ func (d Date) ToStdTime(loc *time.Location) (out time.Time) { return time.Date(d.Year(), time.Month(d.Month()), d.Day(), 0, 0, 0, 0, loc) } -func (dt DateTime) ToStdTime(loc *time.Location) time.Time { +func (dt DateTime) ToStdTime(now time.Time) time.Time { zerodate := dt.Date.IsZero() zerotime := dt.Time.IsZero() @@ -452,12 +452,12 @@ func (dt DateTime) ToStdTime(loc *time.Location) time.Time { case zerodate && zerotime: return time.Time{} case zerodate: - return dt.Time.ToStdTime(loc) + return dt.Time.ToStdTime(now) case zerotime: - return dt.Date.ToStdTime(loc) + return dt.Date.ToStdTime(now.Location()) default: year, month, day := dt.Date.Year(), time.Month(dt.Date.Month()), dt.Date.Day() - return dt.Time.toStdTime(year, month, day, loc) + return dt.Time.toStdTime(year, month, day, now.Location()) } } @@ -527,7 +527,10 @@ func (dt DateTime) Compare(dt2 DateTime) int { // if we're comparing a time to a datetime, we need to normalize them // both into datetimes; this normalization is not trivial because negative // times result in a date change, so let the standard library handle this - return dt.ToStdTime(time.Local).Compare(dt2.ToStdTime(time.Local)) + + // Using the current time is OK here since the comparison is relative + now := time.Now() + return dt.ToStdTime(now).Compare(dt2.ToStdTime(now)) } if cmp := dt.Date.Compare(dt2.Date); cmp != 0 { return cmp @@ -559,9 +562,10 @@ func (dt DateTime) Round(p int) (r DateTime) { r = dt if n == 1e9 { r.Time.nanosecond = 0 - return NewDateTimeFromStd(r.ToStdTime(time.Local).Add(time.Second)) + r.addInterval(&Interval{timeparts: timeparts{sec: 1}, unit: IntervalSecond}) + } else { + r.Time.nanosecond = uint32(n) } - r.Time.nanosecond = uint32(n) return r } diff --git a/go/mysql/json/parser.go b/go/mysql/json/parser.go index 322c623058e..35278263877 100644 --- a/go/mysql/json/parser.go +++ b/go/mysql/json/parser.go @@ -678,7 +678,7 @@ func (v *Value) MarshalDate() string { func (v *Value) MarshalDateTime() string { if dt, ok := v.DateTime(); ok { - return dt.ToStdTime(time.Local).Format("2006-01-02 15:04:05.000000") + return dt.ToStdTime(time.Now()).Format("2006-01-02 15:04:05.000000") } return "" } diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 105ee3f4721..47a7e30d17b 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -529,12 +529,12 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll end := env.vm.sp - elseOffset for sp := env.vm.sp - stackDepth; sp < end; sp += 2 { if env.vm.stack[sp].(*evalInt64).i != 0 { - env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation) + env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now) goto done } } if elseOffset != 0 { - env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation) + env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now) } else { env.vm.stack[env.vm.sp-stackDepth] = nil } @@ -1110,7 +1110,7 @@ func (asm *assembler) Convert_xD(offset int) { // Need to explicitly check here or we otherwise // store a nil wrapper in an interface vs. a direct // nil. - d := evalToDate(env.vm.stack[env.vm.sp-offset]) + d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now) if d == nil { env.vm.stack[env.vm.sp-offset] = nil } else { @@ -1125,7 +1125,7 @@ func (asm *assembler) Convert_xD_nz(offset int) { // Need to explicitly check here or we otherwise // store a nil wrapper in an interface vs. a direct // nil. - d := evalToDate(env.vm.stack[env.vm.sp-offset]) + d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now) if d == nil || d.isZero() { env.vm.stack[env.vm.sp-offset] = nil } else { @@ -1140,7 +1140,7 @@ func (asm *assembler) Convert_xDT(offset, prec int) { // Need to explicitly check here or we otherwise // store a nil wrapper in an interface vs. a direct // nil. - dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec) + dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now) if dt == nil { env.vm.stack[env.vm.sp-offset] = nil } else { @@ -1155,7 +1155,7 @@ func (asm *assembler) Convert_xDT_nz(offset, prec int) { // Need to explicitly check here or we otherwise // store a nil wrapper in an interface vs. a direct // nil. - dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec) + dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now) if dt == nil || dt.isZero() { env.vm.stack[env.vm.sp-offset] = nil } else { @@ -4252,7 +4252,7 @@ func (asm *assembler) Fn_DATEADD_D(unit datetime.IntervalType, sub bool) { } tmp := env.vm.stack[env.vm.sp-2].(*evalTemporal) - env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{}) + env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{}, env.now) env.vm.sp-- return 1 }, "FN DATEADD TEMPORAL(SP-2), INTERVAL(SP-1)") @@ -4274,7 +4274,7 @@ func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col col goto baddate } - env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, col) + env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, col, env.now) env.vm.sp-- return 1 diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index efcf0036acb..efbe3d0ed0c 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -457,6 +457,8 @@ func TestCompilerSingle(t *testing.T) { }, } + tz, _ := time.LoadLocation("Europe/Madrid") + for _, tc := range testCases { t.Run(tc.expression, func(t *testing.T) { expr, err := sqlparser.ParseExpr(tc.expression) @@ -478,6 +480,7 @@ func TestCompilerSingle(t *testing.T) { } env := evalengine.EmptyExpressionEnv() + env.SetTime(time.Date(2023, 10, 24, 12, 0, 0, 0, tz)) env.Row = tc.values expected, err := env.Evaluate(evalengine.Deoptimize(converted)) diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index fbc3cbca57d..97109775952 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -18,6 +18,7 @@ package evalengine import ( "strconv" + "time" "unicode/utf8" "vitess.io/vitess/go/hack" @@ -167,7 +168,7 @@ func evalIsTruthy(e eval) boolean { } } -func evalCoerce(e eval, typ sqltypes.Type, col collations.ID) (eval, error) { +func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time) (eval, error) { if e == nil { return nil, nil } @@ -199,9 +200,9 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID) (eval, error) { case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64: return evalToInt64(e).toUint64(), nil case sqltypes.Date: - return evalToDate(e), nil + return evalToDate(e, now), nil case sqltypes.Datetime, sqltypes.Timestamp: - return evalToDateTime(e, -1), nil + return evalToDateTime(e, -1, now), nil case sqltypes.Time: return evalToTime(e, -1), nil default: @@ -329,7 +330,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I return nil, err } // Separate return here to avoid nil wrapped in interface type - d := evalToDate(e) + d := evalToDate(e, time.Now()) if d == nil { return nil, nil } @@ -340,7 +341,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I return nil, err } // Separate return here to avoid nil wrapped in interface type - dt := evalToDateTime(e, -1) + dt := evalToDateTime(e, -1, time.Now()) if dt == nil { return nil, nil } diff --git a/go/vt/vtgate/evalengine/eval_temporal.go b/go/vt/vtgate/evalengine/eval_temporal.go index 13acc5bd290..34d1f17d7f8 100644 --- a/go/vt/vtgate/evalengine/eval_temporal.go +++ b/go/vt/vtgate/evalengine/eval_temporal.go @@ -1,6 +1,8 @@ package evalengine import ( + "time" + "vitess.io/vitess/go/hack" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/datetime" @@ -92,12 +94,12 @@ func (e *evalTemporal) toJSON() *evalJSON { } } -func (e *evalTemporal) toDateTime(l int) *evalTemporal { +func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal { switch e.SQLType() { case sqltypes.Datetime, sqltypes.Date: return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Round(l), prec: uint8(l)} case sqltypes.Time: - return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(), prec: uint8(l)} + return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)} default: panic("unreachable") } @@ -118,7 +120,7 @@ func (e *evalTemporal) toTime(l int) *evalTemporal { } } -func (e *evalTemporal) toDate() *evalTemporal { +func (e *evalTemporal) toDate(now time.Time) *evalTemporal { switch e.SQLType() { case sqltypes.Datetime: dt := datetime.DateTime{Date: e.dt.Date} @@ -126,7 +128,7 @@ func (e *evalTemporal) toDate() *evalTemporal { case sqltypes.Date: return e case sqltypes.Time: - dt := e.dt.Time.ToDateTime() + dt := e.dt.Time.ToDateTime(now) dt.Time = datetime.Time{} return &evalTemporal{t: sqltypes.Date, dt: dt} default: @@ -138,7 +140,7 @@ func (e *evalTemporal) isZero() bool { return e.dt.IsZero() } -func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation) eval { +func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation, now time.Time) eval { var tmp *evalTemporal var ok bool @@ -150,7 +152,7 @@ func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collatio tmp = &evalTemporal{t: e.t} tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, strcoll.Valid()) case tt == sqltypes.Datetime || tt == sqltypes.Timestamp || (tt == sqltypes.Date && interval.Unit().HasTimeParts()) || (tt == sqltypes.Time && interval.Unit().HasDateParts()): - tmp = e.toDateTime(int(e.prec)) + tmp = e.toDateTime(int(e.prec), now) tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, strcoll.Valid()) } if !ok { @@ -324,10 +326,10 @@ func evalToTime(e eval, l int) *evalTemporal { return nil } -func evalToDateTime(e eval, l int) *evalTemporal { +func evalToDateTime(e eval, l int, now time.Time) *evalTemporal { switch e := e.(type) { case *evalTemporal: - return e.toDateTime(precision(l, int(e.prec))) + return e.toDateTime(precision(l, int(e.prec)), now) case *evalBytes: if t, l, _ := datetime.ParseDateTime(e.string(), l); !t.IsZero() { return newEvalDateTime(t, l) @@ -371,10 +373,10 @@ func evalToDateTime(e eval, l int) *evalTemporal { return nil } -func evalToDate(e eval) *evalTemporal { +func evalToDate(e eval, now time.Time) *evalTemporal { switch e := e.(type) { case *evalTemporal: - return e.toDate() + return e.toDate(now) case *evalBytes: if t, _ := datetime.ParseDate(e.string()); !t.IsZero() { return newEvalDate(t) diff --git a/go/vt/vtgate/evalengine/expr_convert.go b/go/vt/vtgate/evalengine/expr_convert.go index 6531cdd6fae..4e60a8e3a8c 100644 --- a/go/vt/vtgate/evalengine/expr_convert.go +++ b/go/vt/vtgate/evalengine/expr_convert.go @@ -125,12 +125,12 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) { case p > 6: return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p) } - if dt := evalToDateTime(e, c.Length); dt != nil { + if dt := evalToDateTime(e, c.Length, env.now); dt != nil { return dt, nil } return nil, nil case "DATE": - if d := evalToDate(e); d != nil { + if d := evalToDate(e, env.now); d != nil { return d, nil } return nil, nil diff --git a/go/vt/vtgate/evalengine/expr_env.go b/go/vt/vtgate/evalengine/expr_env.go index e67e25e70a6..a92176dc8f1 100644 --- a/go/vt/vtgate/evalengine/expr_env.go +++ b/go/vt/vtgate/evalengine/expr_env.go @@ -99,6 +99,15 @@ func (env *ExpressionEnv) TypeOf(expr Expr, fields []*querypb.Field) (sqltypes.T return ty, f, nil } +func (env *ExpressionEnv) SetTime(now time.Time) { + // This function is called only once by NewExpressionEnv to ensure that all expressions in the same + // ExpressionEnv evaluate NOW() and similar SQL functions to the same value. + env.now = now + if tz := env.currentTimezone(); tz != nil { + env.now = env.now.In(tz) + } +} + // EmptyExpressionEnv returns a new ExpressionEnv with no bind vars or row func EmptyExpressionEnv() *ExpressionEnv { return NewExpressionEnv(context.Background(), nil, nil) @@ -108,14 +117,6 @@ func EmptyExpressionEnv() *ExpressionEnv { func NewExpressionEnv(ctx context.Context, bindVars map[string]*querypb.BindVariable, vc VCursor) *ExpressionEnv { env := &ExpressionEnv{BindVars: bindVars, vc: vc} env.user = callerid.ImmediateCallerIDFromContext(ctx) - - // The current time for this ExpressionEnv is set only once, during creation. - // This is to ensure that all expressions in the same ExpressionEnv evaluate NOW() - // and similar SQL functions to the same value. - env.now = time.Now() - - if tz := env.currentTimezone(); tz != nil { - env.now = env.now.In(tz) - } + env.SetTime(time.Now()) return env } diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index 189b68e4136..27765c695bf 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -520,7 +520,7 @@ func (c *CaseExpr) eval(env *ExpressionEnv) (eval, error) { return nil, nil } t, _ := c.typeof(env, nil) - return evalCoerce(result, t, ca.result().Collation) + return evalCoerce(result, t, ca.result().Collation, env.now) } func (c *CaseExpr) typeof(env *ExpressionEnv, fields []*querypb.Field) (sqltypes.Type, typeFlag) { diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index 99e0f27f755..18586227bdb 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -282,9 +282,9 @@ func (b *builtinDateFormat) eval(env *ExpressionEnv) (eval, error) { var t *evalTemporal switch e := date.(type) { case *evalTemporal: - t = e.toDateTime(datetime.DefaultPrecision) + t = e.toDateTime(datetime.DefaultPrecision, env.now) default: - t = evalToDateTime(date, datetime.DefaultPrecision) + t = evalToDateTime(date, datetime.DefaultPrecision, env.now) if t == nil || t.isZero() { return nil, nil } @@ -381,7 +381,7 @@ func (call *builtinConvertTz) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - dt := evalToDateTime(n, -1) + dt := evalToDateTime(n, -1, env.now) if dt == nil || dt.isZero() { return nil, nil } @@ -445,7 +445,7 @@ func (b *builtinDate) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil { return nil, nil } @@ -482,7 +482,7 @@ func (b *builtinDayOfMonth) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil { return nil, nil } @@ -519,7 +519,7 @@ func (b *builtinDayOfWeek) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil || d.isZero() { return nil, nil } @@ -556,7 +556,7 @@ func (b *builtinDayOfYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil || d.isZero() { return nil, nil } @@ -1178,7 +1178,7 @@ func (b *builtinMonth) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil { return nil, nil } @@ -1215,7 +1215,7 @@ func (b *builtinMonthName) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil { return nil, nil } @@ -1258,7 +1258,7 @@ func (b *builtinQuarter) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil { return nil, nil } @@ -1364,9 +1364,9 @@ func dateTimeUnixTimestamp(env *ExpressionEnv, date eval) evalNumeric { var dt *evalTemporal switch e := date.(type) { case *evalTemporal: - dt = e.toDateTime(int(e.prec)) + dt = e.toDateTime(int(e.prec), env.now) default: - dt = evalToDateTime(date, -1) + dt = evalToDateTime(date, -1, env.now) if dt == nil || dt.isZero() { var prec int32 switch d := date.(type) { @@ -1386,15 +1386,11 @@ func dateTimeUnixTimestamp(env *ExpressionEnv, date eval) evalNumeric { } } - tz := env.currentTimezone() - if tz == nil { - tz = time.Local - } - - ts := dt.dt.ToStdTime(tz) + ts := dt.dt.ToStdTime(env.now) if dt.prec == 0 { return newEvalInt64(ts.Unix()) } + dec := decimal.New(ts.Unix(), 0) dec = dec.Add(decimal.New(int64(dt.dt.Time.Nanosecond()), -9)) return newEvalDecimalWithPrec(dec, int32(dt.prec)) @@ -1458,7 +1454,7 @@ func (b *builtinWeek) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil || d.isZero() { return nil, nil } @@ -1522,7 +1518,7 @@ func (b *builtinWeekDay) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil || d.isZero() { return nil, nil } @@ -1560,7 +1556,7 @@ func (b *builtinWeekOfYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil || d.isZero() { return nil, nil } @@ -1600,7 +1596,7 @@ func (b *builtinYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil { return nil, nil } @@ -1640,7 +1636,7 @@ func (b *builtinYearWeek) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - d := evalToDate(date) + d := evalToDate(date, env.now) if d == nil || d.isZero() { return nil, nil } @@ -1726,11 +1722,11 @@ func (call *builtinDateMath) eval(env *ExpressionEnv) (eval, error) { } if tmp, ok := date.(*evalTemporal); ok { - return tmp.addInterval(interval, collations.TypedCollation{}), nil + return tmp.addInterval(interval, collations.TypedCollation{}, env.now), nil } if tmp := evalToTemporal(date); tmp != nil { - return tmp.addInterval(interval, defaultCoercionCollation(call.collate)), nil + return tmp.addInterval(interval, defaultCoercionCollation(call.collate), env.now), nil } return nil, nil diff --git a/go/vt/vtgate/evalengine/integration/fuzz_test.go b/go/vt/vtgate/evalengine/integration/fuzz_test.go index ebfaa486b19..8360e9e5baf 100644 --- a/go/vt/vtgate/evalengine/integration/fuzz_test.go +++ b/go/vt/vtgate/evalengine/integration/fuzz_test.go @@ -352,7 +352,7 @@ func compareResult(local, remote Result, cmp *testcases.Comparison) error { remoteCollationName = env.LookupName(coll) } - equals, err := cmp.Equals(local.Value, remote.Value) + equals, err := cmp.Equals(local.Value, remote.Value, time.Now()) if err != nil { return err } diff --git a/go/vt/vtgate/evalengine/testcases/helpers.go b/go/vt/vtgate/evalengine/testcases/helpers.go index f7cf5b22dd8..245d59992aa 100644 --- a/go/vt/vtgate/evalengine/testcases/helpers.go +++ b/go/vt/vtgate/evalengine/testcases/helpers.go @@ -164,7 +164,7 @@ func (cmp *Comparison) closeFloat(a, b float64) bool { return math.Abs((a-b)/b) < tolerance } -func (cmp *Comparison) Equals(local, remote sqltypes.Value) (bool, error) { +func (cmp *Comparison) Equals(local, remote sqltypes.Value, now time.Time) (bool, error) { switch { case local.IsFloat() && remote.IsFloat(): localFloat, err := local.ToFloat64() @@ -185,7 +185,7 @@ func (cmp *Comparison) Equals(local, remote sqltypes.Value) (bool, error) { if !ok { return false, fmt.Errorf("error converting remote value '%s' to datetime", remote) } - return cmp.closeDatetime(localDatetime.ToStdTime(time.Local), remoteDatetime.ToStdTime(time.Local), 1*time.Second), nil + return cmp.closeDatetime(localDatetime.ToStdTime(now), remoteDatetime.ToStdTime(now), 1*time.Second), nil case cmp.LooseTime && local.IsTime() && remote.IsTime(): localTime, _, ok := datetime.ParseTime(local.ToString(), -1) if !ok { @@ -195,7 +195,7 @@ func (cmp *Comparison) Equals(local, remote sqltypes.Value) (bool, error) { if !ok { return false, fmt.Errorf("error converting remote value '%s' to time", remote) } - return cmp.closeDatetime(localTime.ToStdTime(time.Local), remoteTime.ToStdTime(time.Local), 1*time.Second), nil + return cmp.closeDatetime(localTime.ToStdTime(now), remoteTime.ToStdTime(now), 1*time.Second), nil default: return local.String() == remote.String(), nil }