diff --git a/go/mysql/collations/charset/eightbit/8bit.go b/go/mysql/collations/charset/eightbit/8bit.go index 5bd930c61cb..12630749d5d 100644 --- a/go/mysql/collations/charset/eightbit/8bit.go +++ b/go/mysql/collations/charset/eightbit/8bit.go @@ -81,3 +81,17 @@ func (Charset_8bit) Length(src []byte) int { func (Charset_8bit) MaxWidth() int { return 1 } + +func (Charset_8bit) Slice(src []byte, from, to int) []byte { + if from >= len(src) { + return nil + } + if to > len(src) { + to = len(src) + } + return src[from:to] +} + +func (Charset_8bit) Validate(src []byte) bool { + return true +} diff --git a/go/mysql/collations/charset/eightbit/binary.go b/go/mysql/collations/charset/eightbit/binary.go index 44824bbc342..fa36fcf66a5 100644 --- a/go/mysql/collations/charset/eightbit/binary.go +++ b/go/mysql/collations/charset/eightbit/binary.go @@ -62,3 +62,17 @@ func (Charset_binary) Length(src []byte) int { func (Charset_binary) MaxWidth() int { return 1 } + +func (Charset_binary) Slice(src []byte, from, to int) []byte { + if from >= len(src) { + return nil + } + if to > len(src) { + to = len(src) + } + return src[from:to] +} + +func (Charset_binary) Validate(src []byte) bool { + return true +} diff --git a/go/mysql/collations/charset/eightbit/latin1.go b/go/mysql/collations/charset/eightbit/latin1.go index 67fa07c62c2..f32b4523a18 100644 --- a/go/mysql/collations/charset/eightbit/latin1.go +++ b/go/mysql/collations/charset/eightbit/latin1.go @@ -230,3 +230,17 @@ func (Charset_latin1) Length(src []byte) int { func (Charset_latin1) MaxWidth() int { return 1 } + +func (Charset_latin1) Slice(src []byte, from, to int) []byte { + if from >= len(src) { + return nil + } + if to > len(src) { + to = len(src) + } + return src[from:to] +} + +func (Charset_latin1) Validate(src []byte) bool { + return true +} diff --git a/go/mysql/collations/colldata/collation.go b/go/mysql/collations/colldata/collation.go index 7697c08cbed..a041006ddc7 100644 --- a/go/mysql/collations/colldata/collation.go +++ b/go/mysql/collations/colldata/collation.go @@ -17,6 +17,7 @@ limitations under the License. package colldata import ( + "bytes" "fmt" "math" @@ -380,3 +381,46 @@ coerceToRight: return charset.Convert(dst, rightCS, in, leftCS) }, nil, nil } + +func Index(col Collation, str, sub []byte, offset int) int { + cs := col.Charset() + if offset > 0 { + l := charset.Length(cs, str) + if offset > l { + return -1 + } + str = charset.Slice(cs, str, offset, len(str)) + } + + pos := instr(col, str, sub) + if pos < 0 { + return -1 + } + return offset + pos +} + +func instr(col Collation, str, sub []byte) int { + if len(sub) == 0 { + return 0 + } + + if len(str) == 0 { + return -1 + } + + if col.IsBinary() && col.Charset().MaxWidth() == 1 { + return bytes.Index(str, sub) + } + + var pos int + cs := col.Charset() + for len(str) > 0 { + if col.Collate(str, sub, true) == 0 { + return pos + } + _, size := cs.DecodeRune(str) + str = str[size:] + pos++ + } + return -1 +} diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 7525bfdaec4..c80fabb5dca 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -1147,6 +1147,18 @@ func (cached *builtinLn) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinLocate) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinLog) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index e017a949a07..5097d54dbd6 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2986,6 +2986,40 @@ func (asm *assembler) Like_collate(expr *LikeExpr, collation colldata.Collation) }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) } +func (asm *assembler) Locate3(collation colldata.Collation) { + asm.adjustStack(-2) + + asm.emit(func(env *ExpressionEnv) int { + substr := env.vm.stack[env.vm.sp-3].(*evalBytes) + str := env.vm.stack[env.vm.sp-2].(*evalBytes) + pos := env.vm.stack[env.vm.sp-1].(*evalInt64) + env.vm.sp -= 2 + + if pos.i < 1 || pos.i > math.MaxInt { + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(0) + return 1 + } + + found := colldata.Index(collation, str.bytes, substr.bytes, int(pos.i)-1) + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(found) + 1) + return 1 + }, "LOCATE VARCHAR(SP-3), VARCHAR(SP-2) INT64(SP-1) COLLATE '%s'", collation.Name()) +} + +func (asm *assembler) Locate2(collation colldata.Collation) { + asm.adjustStack(-1) + + asm.emit(func(env *ExpressionEnv) int { + substr := env.vm.stack[env.vm.sp-2].(*evalBytes) + str := env.vm.stack[env.vm.sp-1].(*evalBytes) + env.vm.sp-- + + found := colldata.Index(collation, str.bytes, substr.bytes, 0) + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(found) + 1) + return 1 + }, "LOCATE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) +} + func (asm *assembler) Strcmp(collation collations.TypedCollation) { asm.adjustStack(-1) @@ -3833,11 +3867,6 @@ func (asm *assembler) Fn_LAST_DAY() { return 1 } arg := env.vm.stack[env.vm.sp-1].(*evalTemporal) - if arg.dt.IsZero() { - env.vm.stack[env.vm.sp-1] = nil - return 1 - } - d := lastDay(env.currentTimezone(), arg.dt) env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalDate(d) return 1 @@ -3850,12 +3879,8 @@ func (asm *assembler) Fn_TO_DAYS() { return 1 } arg := env.vm.stack[env.vm.sp-1].(*evalTemporal) - if arg.dt.Date.IsZero() { - env.vm.stack[env.vm.sp-1] = nil - } else { - numDays := datetime.MysqlDayNumber(arg.dt.Date.Year(), arg.dt.Date.Month(), arg.dt.Date.Day()) - env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(numDays)) - } + numDays := datetime.MysqlDayNumber(arg.dt.Date.Year(), arg.dt.Date.Month(), arg.dt.Date.Day()) + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(int64(numDays)) return 1 }, "FN TO_DAYS DATE(SP-1)") } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 7b2c92783ee..09e08ad0d48 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -615,6 +615,18 @@ func TestCompilerSingle(t *testing.T) { expression: `time('1111:66:56')`, result: `NULL`, }, + { + expression: `locate('Å', 'a')`, + result: `INT64(1)`, + }, + { + expression: `locate('a', 'Å')`, + result: `INT64(1)`, + }, + { + expression: `locate("", "😊😂🤢", 3)`, + result: `INT64(3)`, + }, } tz, _ := time.LoadLocation("Europe/Madrid") diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 23ff1cbdca3..e65800c9824 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -110,6 +110,11 @@ type ( CallExpr collate collations.ID } + + builtinLocate struct { + CallExpr + collate collations.ID + } ) var _ IR = (*builtinInsert)(nil) @@ -1265,6 +1270,111 @@ func (call *builtinSubstring) compile(c *compiler) (ctype, error) { return ctype{Type: tt, Col: col, Flag: flagNullable}, nil } +func (call *builtinLocate) eval(env *ExpressionEnv) (eval, error) { + substr, err := call.Arguments[0].eval(env) + if err != nil || substr == nil { + return nil, err + } + + str, err := call.Arguments[1].eval(env) + if err != nil || str == nil { + return nil, err + } + + if _, ok := str.(*evalBytes); !ok { + str, err = evalToVarchar(str, call.collate, true) + if err != nil { + return nil, err + } + } + + col := str.(*evalBytes).col.Collation + substr, err = evalToVarchar(substr, col, true) + if err != nil { + return nil, err + } + + pos := int64(1) + if len(call.Arguments) > 2 { + p, err := call.Arguments[2].eval(env) + if err != nil || p == nil { + return nil, err + } + pos = evalToInt64(p).i + if pos < 1 || pos > math.MaxInt { + return newEvalInt64(0), nil + } + } + + var coll colldata.Collation + if typeIsTextual(substr.SQLType()) && typeIsTextual(str.SQLType()) { + coll = colldata.Lookup(col) + } else { + coll = colldata.Lookup(collations.CollationBinaryID) + } + found := colldata.Index(coll, str.ToRawBytes(), substr.ToRawBytes(), int(pos)-1) + return newEvalInt64(int64(found) + 1), nil +} + +func (call *builtinLocate) compile(c *compiler) (ctype, error) { + substr, err := call.Arguments[0].compile(c) + if err != nil { + return ctype{}, err + } + + str, err := call.Arguments[1].compile(c) + if err != nil { + return ctype{}, err + } + + skip1 := c.compileNullCheck2(substr, str) + var skip2 *jump + if len(call.Arguments) > 2 { + l, err := call.Arguments[2].compile(c) + if err != nil { + return ctype{}, err + } + skip2 = c.compileNullCheck2(str, l) + _ = c.compileToInt64(l, 1) + } + + if !str.isTextual() { + c.asm.Convert_xce(len(call.Arguments)-1, sqltypes.VarChar, c.collation) + str.Col = collations.TypedCollation{ + Collation: c.collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + fromCharset := colldata.Lookup(substr.Col.Collation).Charset() + toCharset := colldata.Lookup(str.Col.Collation).Charset() + if !substr.isTextual() || (fromCharset != toCharset && !toCharset.IsSuperset(fromCharset)) { + c.asm.Convert_xce(len(call.Arguments), sqltypes.VarChar, str.Col.Collation) + substr.Col = collations.TypedCollation{ + Collation: str.Col.Collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireASCII, + } + } + + var coll colldata.Collation + if typeIsTextual(substr.Type) && typeIsTextual(str.Type) { + coll = colldata.Lookup(str.Col.Collation) + } else { + coll = colldata.Lookup(collations.CollationBinaryID) + } + + if len(call.Arguments) > 2 { + c.asm.Locate3(coll) + } else { + c.asm.Locate2(coll) + } + + c.asm.jumpDestination(skip1, skip2) + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}, nil +} + type builtinConcat struct { CallExpr collate collations.ID diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index 319f8d1b328..430e4e123ac 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -282,8 +282,8 @@ func (b *builtinDateFormat) eval(env *ExpressionEnv) (eval, error) { case *evalTemporal: t = e.toDateTime(datetime.DefaultPrecision, env.now) default: - t = evalToDateTime(date, datetime.DefaultPrecision, env.now, env.sqlmode.AllowZeroDate()) - if t == nil || t.isZero() { + t = evalToDateTime(date, datetime.DefaultPrecision, env.now, false) + if t == nil { return nil, nil } } @@ -379,8 +379,8 @@ func (call *builtinConvertTz) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - dt := evalToDateTime(n, -1, env.now, env.sqlmode.AllowZeroDate()) - if dt == nil || dt.isZero() { + dt := evalToDateTime(n, -1, env.now, false) + if dt == nil { return nil, nil } @@ -388,7 +388,7 @@ func (call *builtinConvertTz) eval(env *ExpressionEnv) (eval, error) { if !ok { return nil, nil } - return newEvalDateTime(out, int(dt.prec), env.sqlmode.AllowZeroDate()), nil + return newEvalDateTime(out, int(dt.prec), false), nil } func (call *builtinConvertTz) compile(c *compiler) (ctype, error) { @@ -504,8 +504,8 @@ func (b *builtinDayOfWeek) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } return newEvalInt64(int64(d.dt.Date.Weekday() + 1)), nil @@ -537,8 +537,8 @@ func (b *builtinDayOfYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } return newEvalInt64(int64(d.dt.Date.ToStdTime(env.currentTimezone()).YearDay())), nil @@ -815,7 +815,7 @@ func (b *builtinMakedate) eval(env *ExpressionEnv) (eval, error) { if t.IsZero() { return nil, nil } - return newEvalDate(datetime.NewDateTimeFromStd(t).Date, env.sqlmode.AllowZeroDate()), nil + return newEvalDate(datetime.NewDateTimeFromStd(t).Date, false), nil } func (call *builtinMakedate) compile(c *compiler) (ctype, error) { @@ -1189,7 +1189,7 @@ func (b *builtinMonthName) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) + d := evalToDate(date, env.now, false) if d == nil { return nil, nil } @@ -1212,7 +1212,7 @@ func (call *builtinMonthName) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1, c.sqlmode.AllowZeroDate()) + c.asm.Convert_xD(1, false) } col := typedCoercionCollation(sqltypes.VarChar, call.collate) c.asm.Fn_MONTHNAME(col) @@ -1272,8 +1272,8 @@ func (b *builtinToDays) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - dt := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if dt == nil || dt.isZero() { + dt := evalToDate(date, env.now, false) + if dt == nil { return nil, nil } @@ -1292,7 +1292,7 @@ func (call *builtinToDays) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1, true) + c.asm.Convert_xD(1, false) } c.asm.Fn_TO_DAYS() c.asm.jumpDestination(skip) @@ -1477,8 +1477,8 @@ func dateTimeUnixTimestamp(env *ExpressionEnv, date eval) evalNumeric { case *evalTemporal: dt = e.toDateTime(int(e.prec), env.now) default: - dt = evalToDateTime(date, -1, env.now, env.sqlmode.AllowZeroDate()) - if dt == nil || dt.isZero() { + dt = evalToDateTime(date, -1, env.now, false) + if dt == nil { var prec int32 switch d := date.(type) { case *evalInt64, *evalUint64: @@ -1584,8 +1584,8 @@ func (b *builtinWeek) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } @@ -1644,8 +1644,8 @@ func (b *builtinWeekDay) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } return newEvalInt64(int64(d.dt.Date.Weekday()+6) % 7), nil @@ -1678,8 +1678,8 @@ func (b *builtinWeekOfYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } @@ -1750,8 +1750,8 @@ func (b *builtinYearWeek) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) - if d == nil || d.isZero() { + d := evalToDate(date, env.now, false) + if d == nil { return nil, nil } diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 9d9cdfa248e..9dbb7276e12 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -82,6 +82,7 @@ var Cases = []TestCase{ {Run: FnRTrim}, {Run: FnTrim}, {Run: FnSubstr}, + {Run: FnLocate}, {Run: FnConcat}, {Run: FnConcatWs}, {Run: FnHex}, @@ -1527,6 +1528,34 @@ func FnSubstr(yield Query) { } } +func FnLocate(yield Query) { + mysqlDocSamples := []string{ + `LOCATE('bar', 'foobarbar')`, + `LOCATE('xbar', 'foobar')`, + `LOCATE('bar', 'foobarbar', 5)`, + `INSTR('foobarbar', 'bar')`, + `INSTR('xbar', 'foobar')`, + `POSITION('bar' IN 'foobarbar')`, + `POSITION('xbar' IN 'foobar')`, + } + + for _, q := range mysqlDocSamples { + yield(q, nil) + } + + for _, substr := range locateStrings { + for _, str := range locateStrings { + yield(fmt.Sprintf("LOCATE(%s, %s)", substr, str), nil) + yield(fmt.Sprintf("INSTR(%s, %s)", str, substr), nil) + yield(fmt.Sprintf("POSITION(%s IN %s)", str, substr), nil) + + for _, i := range radianInputs { + yield(fmt.Sprintf("LOCATE(%s, %s, %s)", substr, str, i), nil) + } + } + } +} + func FnConcat(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("CONCAT(%s)", str), nil) diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index c453f904c96..c4ab2fdb92d 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -234,6 +234,42 @@ var insertStrings = []string{ // "_ucs2 'AabcÅå'", } +var locateStrings = []string{ + "NULL", + "\"\"", + "\"a\"", + "\"abc\"", + "1", + "-1", + "0123", + "0xAACC", + "3.1415926", + // MySQL has broken behavior for these inputs, + // see https://bugs.mysql.com/bug.php?id=113933 + // "\"Å å\"", + // "\"中文测试\"", + // "\"日本語テスト\"", + // "\"한국어 시험\"", + // "\"😊😂🤢\"", + // "_utf8mb4 'abcABCÅå'", + "DATE '2022-10-11'", + "TIME '11:02:23'", + "'123'", + "9223372036854775807", + "-9223372036854775808", + "999999999999999999999999", + "-999999999999999999999999", + "_binary 'Müller' ", + "_utf8mb4 'abcABCÅå'", + "_latin1 0xFF", + // TODO: support other multibyte encodings + // "_dec8 'ÒòÅå'", + // "_utf8mb3 'abcABCÅå'", + // "_utf16 'AabcÅå'", + // "_utf32 'AabcÅå'", + // "_ucs2 'AabcÅå'", +} + var inputConversionTypes = []string{ "BINARY", "BINARY(1)", "BINARY(0)", "BINARY(16)", "BINARY(-1)", "CHAR", "CHAR(1)", "CHAR(0)", "CHAR(16)", "CHAR(-1)", diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 11618bb1d1a..73beb7fd59e 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -604,6 +604,12 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) { return nil, argError(method) } return &builtinStrcmp{CallExpr: call, collate: ast.cfg.Collation}, nil + case "instr": + if len(args) != 2 { + return nil, argError(method) + } + call = CallExpr{Arguments: []IR{call.Arguments[1], call.Arguments[0]}, Method: method} + return &builtinLocate{CallExpr: call, collate: ast.cfg.Collation}, nil default: return nil, translateExprNotSupported(fn) } @@ -729,7 +735,7 @@ func (ast *astCompiler) translateCallable(call sqlparser.Callable) (IR, error) { case *sqlparser.CurTimeFuncExpr: if call.Fsp > 6 { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision 12 specified for '%s'. Maximum is 6.", call.Name.String()) + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for '%s'. Maximum is 6.", call.Fsp, call.Name.String()) } var cexpr = CallExpr{Arguments: nil, Method: call.Name.String()} @@ -802,6 +808,31 @@ func (ast *astCompiler) translateCallable(call sqlparser.Callable) (IR, error) { CallExpr: cexpr, collate: ast.cfg.Collation, }, nil + case *sqlparser.LocateExpr: + var args []IR + substr, err := ast.translateExpr(call.SubStr) + if err != nil { + return nil, err + } + args = append(args, substr) + str, err := ast.translateExpr(call.Str) + if err != nil { + return nil, err + } + args = append(args, str) + + if call.Pos != nil { + to, err := ast.translateExpr(call.Pos) + if err != nil { + return nil, err + } + args = append(args, to) + } + var cexpr = CallExpr{Arguments: args, Method: "LOCATE"} + return &builtinLocate{ + CallExpr: cexpr, + collate: ast.cfg.Collation, + }, nil case *sqlparser.IntervalDateExpr: var err error args := make([]IR, 2)