From bea2b59ef421aebf76547fb193888fe80a93d3af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Thu, 22 Feb 2024 08:56:16 +0100 Subject: [PATCH 1/3] Column alias expanding on ORDER BY (#15302) Co-authored-by: Harshit Gangal Co-authored-by: Manan Gupta Signed-off-by: Andres Taylor --- go/mysql/sqlerror/sql_error.go | 1 + .../queries/aggregation/aggregation_test.go | 36 +- .../vtgate/queries/dml/insert_test.go | 12 +- .../queries/lookup_queries/main_test.go | 2 +- .../endtoend/vtgate/queries/misc/misc_test.go | 4 +- .../vtgate/queries/orderby/orderby_test.go | 70 ++++ .../vtgate/queries/orderby/schema.sql | 9 + .../vtgate/queries/orderby/vschema.json | 8 + .../vtgate/queries/union/union_test.go | 9 +- go/vt/schemadiff/schema.go | 2 +- go/vt/vterrors/state.go | 1 + .../testdata/postprocess_cases.json | 134 +++++++ .../testdata/unsupported_cases.json | 5 - go/vt/vtgate/semantics/analyzer.go | 19 +- go/vt/vtgate/semantics/binder.go | 15 +- go/vt/vtgate/semantics/early_rewriter.go | 362 +++++++++++++++--- go/vt/vtgate/semantics/early_rewriter_test.go | 207 ++++++++-- go/vt/vtgate/semantics/errors.go | 24 +- 18 files changed, 768 insertions(+), 152 deletions(-) diff --git a/go/mysql/sqlerror/sql_error.go b/go/mysql/sqlerror/sql_error.go index 2796189dde2..fc201be82ef 100644 --- a/go/mysql/sqlerror/sql_error.go +++ b/go/mysql/sqlerror/sql_error.go @@ -242,6 +242,7 @@ var stateToMysqlCode = map[vterrors.State]mysqlCode{ vterrors.WrongParametersToNativeFct: {num: ERWrongParametersToNativeFct, state: SSUnknownSQLState}, vterrors.KillDeniedError: {num: ERKillDenied, state: SSUnknownSQLState}, vterrors.BadNullError: {num: ERBadNullError, state: SSConstraintViolation}, + vterrors.InvalidGroupFuncUse: {num: ERInvalidGroupFuncUse, state: SSUnknownSQLState}, } func getStateToMySQLState(state vterrors.State) mysqlCode { diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index 89b3d0c8c85..6f4dd01d4e2 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -73,7 +73,7 @@ func TestAggregateTypes(t *testing.T) { mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by a", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`) mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by 2, a", `[[VARCHAR("b") INT64(1)] [VARCHAR("d") INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("c") INT64(2)] [VARCHAR("e") INT64(2)]]`) mcmp.AssertMatches("select sum(val1) from aggr_test", `[[FLOAT64(0)]]`) - t.Run("Average for sharded keyspaces", func(t *testing.T) { + mcmp.Run("Average for sharded keyspaces", func(mcmp *utils.MySQLCompare) { utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`) }) @@ -101,7 +101,7 @@ func TestEqualFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having 1 = 1", `[[INT64(5)]]`) @@ -177,7 +177,7 @@ func TestAggrOnJoin(t *testing.T) { mcmp.AssertMatches("select a.val1 from aggr_test a join t3 t on a.val2 = t.id7 group by a.val1 having count(*) 
= 4", `[[VARCHAR("a")]]`) - t.Run("Average in join for sharded", func(t *testing.T) { + mcmp.Run("Average in join for sharded", func(mcmp *utils.MySQLCompare) { utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") mcmp.AssertMatches(`select avg(a1.val2), avg(a2.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7`, "[[DECIMAL(1.5000) DECIMAL(1.0000)]]") @@ -196,7 +196,7 @@ func TestNotEqualFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having a != 5", `[]`) @@ -220,7 +220,7 @@ func TestLessFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having a < 10", `[[INT64(5)]]`) mcmp.AssertMatches("select count(*) as a from aggr_test having 1 < a", `[[INT64(5)]]`) @@ -243,7 +243,7 @@ func TestLessEqualFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having a <= 10", `[[INT64(5)]]`) @@ -267,7 +267,7 @@ func TestGreaterFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having a > 1", `[[INT64(5)]]`) @@ -291,7 +291,7 @@ func TestGreaterEqualFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having a >= 1", `[[INT64(5)]]`) @@ -326,7 +326,7 @@ func TestAggOnTopOfLimit(t *testing.T) { mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',6), (2,'a',1), (3,'b',1), (4,'c',3), (5,'c',4), (6,'b',null), (7,null,2), (8,null,null)") for _, workload := range []string{"oltp", "olap"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]") mcmp.AssertMatches("select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]") @@ -335,7 +335,7 @@ func TestAggOnTopOfLimit(t *testing.T) { mcmp.AssertMatches("select count(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0)]]") mcmp.AssertMatches("select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`) mcmp.AssertMatchesNoOrder("select 
val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`) - t.Run("Average in sharded query", func(t *testing.T) { + mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") mcmp.AssertMatches("select avg(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[NULL]]") mcmp.AssertMatchesNoOrder("select val1, avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL DECIMAL(2.0000)] [VARCHAR("a") DECIMAL(3.5000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.5000)]]`) @@ -347,7 +347,7 @@ func TestAggOnTopOfLimit(t *testing.T) { mcmp.AssertMatches("select count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]") mcmp.AssertMatches("select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]") mcmp.AssertMatches("select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`) - t.Run("Average in sharded query", func(t *testing.T) { + mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") mcmp.AssertMatches("select count(*), sum(val1), avg(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0) FLOAT64(0)]]") mcmp.AssertMatches("select count(val1), sum(id), avg(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7) DECIMAL(3.5000)]]") @@ -363,13 +363,13 @@ func TestEmptyTableAggr(t *testing.T) { defer closer() for _, workload := range []string{"oltp", "olap"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload)) mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") - t.Run("Average in sharded query", func(t *testing.T) { + mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]") @@ -380,12 +380,12 @@ func TestEmptyTableAggr(t *testing.T) { mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'b1','bar',200)") for _, workload := range []string{"oltp", "olap"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload)) mcmp.AssertMatches(" 
select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") - t.Run("Average in sharded query", func(t *testing.T) { + mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]") @@ -434,7 +434,7 @@ func TestAggregateLeftJoin(t *testing.T) { mcmp.AssertMatches("SELECT sum(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`) mcmp.AssertMatches("SELECT count(*) FROM t2 LEFT JOIN t1 ON t1.t1_id = t2.id WHERE IFNULL(t1.name, 'NOTSET') = 'r'", `[[INT64(1)]]`) - t.Run("Average in sharded query", func(t *testing.T) { + mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") mcmp.AssertMatches("SELECT avg(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(0.5000)]]`) mcmp.AssertMatches("SELECT avg(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1.0000)]]`) @@ -491,7 +491,7 @@ func TestScalarAggregate(t *testing.T) { mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)") mcmp.AssertMatches("select count(distinct val1) from aggr_test", `[[INT64(3)]]`) - t.Run("Average in sharded query", func(t *testing.T) { + mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`) }) @@ -551,7 +551,7 @@ func TestComplexAggregation(t *testing.T) { mcmp.Exec(`SELECT shardkey + MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`) mcmp.Exec(`SELECT name+COUNT(t1_id)+1 FROM t1 GROUP BY name`) mcmp.Exec(`SELECT COUNT(*)+shardkey+MIN(t1_id)+1+MAX(t1_id)*SUM(t1_id)+1+name FROM t1 GROUP BY shardkey, name`) - t.Run("Average in sharded query", func(t *testing.T) { + mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey)+AVG(t1_id) FROM t1`) }) diff --git a/go/test/endtoend/vtgate/queries/dml/insert_test.go b/go/test/endtoend/vtgate/queries/dml/insert_test.go index 80d0602b898..ce052b7b2ba 100644 --- a/go/test/endtoend/vtgate/queries/dml/insert_test.go +++ b/go/test/endtoend/vtgate/queries/dml/insert_test.go @@ -38,7 +38,7 @@ func TestSimpleInsertSelect(t *testing.T) { mcmp.Exec("insert into u_tbl(id, num) values (1,2),(3,4)") for i, mode := range []string{"oltp", "olap"} { - t.Run(mode, func(t *testing.T) { + mcmp.Run(mode, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode)) qr := mcmp.Exec(fmt.Sprintf("insert into s_tbl(id, num) select id*%d, num*%d from s_tbl where id < 10", 10+i, 20+i)) @@ -65,7 +65,7 @@ func TestFailureInsertSelect(t *testing.T) { mcmp.Exec("insert into u_tbl(id, num) values (1,2),(3,4)") for _, mode := range []string{"oltp", "olap"} { - t.Run(mode, func(t *testing.T) { + mcmp.Run(mode, func(mcmp *utils.MySQLCompare) { utils.Exec(t, 
mcmp.VtConn, fmt.Sprintf("set workload = %s", mode)) // primary key same @@ -127,7 +127,7 @@ func TestAutoIncInsertSelect(t *testing.T) { }} for _, tcase := range tcases { - t.Run(tcase.query, func(t *testing.T) { + mcmp.Run(tcase.query, func(mcmp *utils.MySQLCompare) { qr := utils.Exec(t, mcmp.VtConn, tcase.query) assert.EqualValues(t, tcase.expRowsAffected, qr.RowsAffected) assert.EqualValues(t, tcase.expInsertID, qr.InsertID) @@ -178,7 +178,7 @@ func TestAutoIncInsertSelectOlapMode(t *testing.T) { }} for _, tcase := range tcases { - t.Run(tcase.query, func(t *testing.T) { + mcmp.Run(tcase.query, func(mcmp *utils.MySQLCompare) { qr := utils.Exec(t, mcmp.VtConn, tcase.query) assert.EqualValues(t, tcase.expRowsAffected, qr.RowsAffected) assert.EqualValues(t, tcase.expInsertID, qr.InsertID) @@ -386,7 +386,7 @@ func TestInsertSelectUnshardedUsingSharded(t *testing.T) { mcmp.Exec("insert into s_tbl(id, num) values (1,2),(3,4)") for _, mode := range []string{"oltp", "olap"} { - t.Run(mode, func(t *testing.T) { + mcmp.Run(mode, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode)) qr := mcmp.Exec("insert into u_tbl(id, num) select id, num from s_tbl where s_tbl.id in (1,3)") assert.EqualValues(t, 2, qr.RowsAffected) @@ -453,7 +453,7 @@ func TestMixedCases(t *testing.T) { }} for _, tc := range tcases { - t.Run(tc.insQuery, func(t *testing.T) { + mcmp.Run(tc.insQuery, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, tc.insQuery) utils.AssertMatches(t, mcmp.VtConn, tc.selQuery, tc.exp) }) diff --git a/go/test/endtoend/vtgate/queries/lookup_queries/main_test.go b/go/test/endtoend/vtgate/queries/lookup_queries/main_test.go index c385941502a..25bf78437da 100644 --- a/go/test/endtoend/vtgate/queries/lookup_queries/main_test.go +++ b/go/test/endtoend/vtgate/queries/lookup_queries/main_test.go @@ -134,7 +134,7 @@ func TestLookupQueries(t *testing.T) { (3, 'monkey', 'monkey')`) for _, workload := range []string{"olap", "oltp"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, "set workload = "+workload) mcmp.AssertMatches("select id from user where lookup = 'apa'", "[[INT64(1)] [INT64(2)]]") diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index eca46cfcc29..7446238d764 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -276,7 +276,7 @@ func TestAnalyze(t *testing.T) { defer closer() for _, workload := range []string{"olap", "oltp"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload)) utils.Exec(t, mcmp.VtConn, "analyze table t1") utils.Exec(t, mcmp.VtConn, "analyze table uks.unsharded") @@ -309,7 +309,7 @@ func TestTransactionModeVar(t *testing.T) { }} for _, tcase := range tcases { - t.Run(tcase.setStmt, func(t *testing.T) { + mcmp.Run(tcase.setStmt, func(mcmp *utils.MySQLCompare) { if tcase.setStmt != "" { utils.Exec(t, mcmp.VtConn, tcase.setStmt) } diff --git a/go/test/endtoend/vtgate/queries/orderby/orderby_test.go b/go/test/endtoend/vtgate/queries/orderby/orderby_test.go index 43f800ee24c..993f7834301 100644 --- a/go/test/endtoend/vtgate/queries/orderby/orderby_test.go +++ b/go/test/endtoend/vtgate/queries/orderby/orderby_test.go @@ -83,3 +83,73 @@ func TestOrderBy(t *testing.T) { mcmp.AssertMatches("select id1, id2 from t4 
order by reverse(id2) desc", `[[INT64(5) VARCHAR("test")] [INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(2) VARCHAR("Abc")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(1) VARCHAR("a")]]`) } } + +func TestOrderByComplex(t *testing.T) { + // tests written to try to trick the ORDER BY engine and planner + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into user(id, col, email) values(1,1,'a'), (2,2,'Abc'), (3,3,'b'), (4,4,'c'), (5,2,'test'), (6,1,'test'), (7,2,'a'), (8,3,'b'), (9,4,'c3'), (10,2,'d')") + + queries := []string{ + "select email, max(col) from user group by email order by col", + "select email, max(col) from user group by email order by col + 1", + "select email, max(col) from user group by email order by max(col)", + "select email, max(col) from user group by email order by max(col) + 1", + "select email, max(col) from user group by email order by min(col)", + "select email, max(col) as col from user group by email order by col", + "select email, max(col) as col from user group by email order by max(col)", + "select email, max(col) as col from user group by email order by col + 1", + "select email, max(col) as col from user group by email order by email + col", + "select email, max(col) as col from user group by email order by email + max(col)", + "select email, max(col) as col from user group by email order by email, col", + "select email, max(col) as xyz from user group by email order by email, xyz", + "select email, max(col) as xyz from user group by email order by email, max(xyz)", + "select email, max(col) as xyz from user group by email order by email, abs(xyz)", + "select email, max(col) as xyz from user group by email order by email, max(col)", + "select email, max(col) as xyz from user group by email order by email, abs(col)", + "select email, max(col) as xyz from user group by email order by xyz + email", + "select email, max(col) as xyz from user group by email order by abs(xyz) + email", + "select email, max(col) as xyz from user group by email order by abs(xyz)", + "select email, max(col) as xyz from user group by email order by abs(col)", + "select email, max(col) as max_col from user group by email order by max_col desc, length(email)", + "select email, max(col) as max_col, min(col) as min_col from user group by email order by max_col - min_col", + "select email, max(col) as col1, count(*) as col2 from user group by email order by col2 * col1", + "select email, sum(col) as sum_col from user group by email having sum_col > 10 order by sum_col / count(email)", + "select email, max(col) as max_col, char_length(email) as len_email from user group by email order by len_email, max_col desc", + "select email, max(col) as col_alias from user group by email order by case when col_alias > 100 then 0 else 1 end, col_alias", + "select email, count(*) as cnt, max(col) as max_col from user group by email order by cnt desc, max_col + cnt", + "select email, max(col) as max_col from user group by email order by if(max_col > 50, max_col, -max_col) desc", + "select email, max(col) as col, sum(col) as sum_col from user group by email order by col * sum_col desc", + "select email, max(col) as col, (select min(col) from user as u2 where u2.email = user.email) as min_col from user group by email order by col - min_col", + "select email, max(col) as max_col, (max(col) % 10) as mod_col from user group by email order by mod_col, max_col", + "select email, max(col) as 'value', 
count(email) as 'number' from user group by email order by 'number', 'value'", + "select email, max(col) as col, concat('email: ', email, ' col: ', max(col)) as complex_alias from user group by email order by complex_alias desc", + "select email, max(col) as max_col from user group by email union select email, min(col) as min_col from user group by email order by email", + "select email, max(col) as col from user where col > 50 group by email order by col desc", + "select email, max(col) as col from user group by email order by length(email), col", + "select email, max(col) as max_col, substring(email, 1, 3) as sub_email from user group by email order by sub_email, max_col desc", + "select email, max(col) as max_col from user group by email order by reverse(email), max_col", + "select email, max(col) as max_col from user group by email having max_col > avg(max_col) order by max_col desc", + "select email, count(*) as count, max(col) as max_col from user group by email order by count * max_col desc", + "select email, max(col) as col_alias from user group by email order by col_alias limit 10", + "select email, max(col) as col from user group by email order by col desc, email", + "select concat(email, ' ', max(col)) as combined from user group by email order by combined desc", + "select email, max(col) as max_col from user group by email order by ascii(email), max_col", + "select email, char_length(email) as email_length, max(col) as max_col from user group by email order by email_length desc, max_col", + "select email, max(col) as col from user group by email having col between 10 and 100 order by col", + "select email, max(col) as max_col, min(col) as min_col from user group by email order by max_col + min_col desc", + "select email, max(col) as 'max', count(*) as 'count' from user group by email order by 'max' desc, 'count'", + "select email, max(col) as max_col from (select email, col from user where col > 20) as filtered group by email order by max_col", + "select a.email, a.max_col from (select email, max(col) as max_col from user group by email) as a order by a.max_col desc", + "select email, max(col) as max_col from user where email like 'a%' group by email order by max_col, email", + } + + for _, query := range queries { + mcmp.Run(query, func(mcmp *utils.MySQLCompare) { + _, _ = mcmp.ExecAllowAndCompareError(query) + }) + } +} diff --git a/go/test/endtoend/vtgate/queries/orderby/schema.sql b/go/test/endtoend/vtgate/queries/orderby/schema.sql index 8f0131db357..efaedc14754 100644 --- a/go/test/endtoend/vtgate/queries/orderby/schema.sql +++ b/go/test/endtoend/vtgate/queries/orderby/schema.sql @@ -27,3 +27,12 @@ create table t4_id2_idx ) Engine = InnoDB DEFAULT charset = utf8mb4 COLLATE = utf8mb4_general_ci; + +create table user +( + id bigint primary key, + col bigint, + email varchar(20) +) Engine = InnoDB + DEFAULT charset = utf8mb4 + COLLATE = utf8mb4_general_ci; \ No newline at end of file diff --git a/go/test/endtoend/vtgate/queries/orderby/vschema.json b/go/test/endtoend/vtgate/queries/orderby/vschema.json index 14418850a35..771676de4b9 100644 --- a/go/test/endtoend/vtgate/queries/orderby/vschema.json +++ b/go/test/endtoend/vtgate/queries/orderby/vschema.json @@ -66,6 +66,14 @@ "name": "unicode_loose_md5" } ] + }, + "user": { + "column_vindexes": [ + { + "column": "id", + "name": "hash" + } + ] } } } \ No newline at end of file diff --git a/go/test/endtoend/vtgate/queries/union/union_test.go b/go/test/endtoend/vtgate/queries/union/union_test.go index 898f8b1d659..d91ea3c4073 
100644 --- a/go/test/endtoend/vtgate/queries/union/union_test.go +++ b/go/test/endtoend/vtgate/queries/union/union_test.go @@ -57,7 +57,7 @@ func TestUnionDistinct(t *testing.T) { mcmp.Exec("insert into t2(id3, id4) values (2, 3), (3, 4), (4,4), (5,5)") for _, workload := range []string{"oltp", "olap"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, "set workload = "+workload) mcmp.AssertMatches("select 1 union select null", "[[INT64(1)] [NULL]]") mcmp.AssertMatches("select null union select null", "[[NULL]]") @@ -69,10 +69,7 @@ func TestUnionDistinct(t *testing.T) { mcmp.AssertMatchesNoOrder("select id1 from t1 where id1 = 1 union select 452 union select id1 from t1 where id1 = 4", "[[INT64(1)] [INT64(452)] [INT64(4)]]") mcmp.AssertMatchesNoOrder("select id1, id2 from t1 union select 827, 452 union select id3,id4 from t2", "[[INT64(4) INT64(4)] [INT64(1) INT64(1)] [INT64(2) INT64(2)] [INT64(3) INT64(3)] [INT64(827) INT64(452)] [INT64(2) INT64(3)] [INT64(3) INT64(4)] [INT64(5) INT64(5)]]") - t.Run("skipped for now", func(t *testing.T) { - t.Skip() - mcmp.AssertMatches("select 1 from dual where 1 IN (select 1 as col union select 2)", "[[INT64(1)]]") - }) + mcmp.AssertMatches("select 1 from dual where 1 IN (select 1 as col union select 2)", "[[INT64(1)]]") if utils.BinaryIsAtLeastAtVersion(19, "vtgate") { mcmp.AssertMatches(`SELECT 1 from t1 UNION SELECT 2 from t1`, `[[INT64(1)] [INT64(2)]]`) mcmp.AssertMatches(`SELECT 5 from t1 UNION SELECT 6 from t1`, `[[INT64(5)] [INT64(6)]]`) @@ -97,7 +94,7 @@ func TestUnionAll(t *testing.T) { mcmp.Exec("insert into t2(id3, id4) values(3, 3), (4, 4)") for _, workload := range []string{"oltp", "olap"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, "set workload = "+workload) // union all between two selectuniqueequal mcmp.AssertMatches("select id1 from t1 where id1 = 1 union all select id1 from t1 where id1 = 4", "[[INT64(1)]]") diff --git a/go/vt/schemadiff/schema.go b/go/vt/schemadiff/schema.go index 084b703b14f..e3782fdbf0b 100644 --- a/go/vt/schemadiff/schema.go +++ b/go/vt/schemadiff/schema.go @@ -1058,7 +1058,7 @@ func (s *Schema) ValidateViewReferences() error { Column: e.Column, Ambiguous: true, } - case *semantics.ColumnNotFoundError: + case semantics.ColumnNotFoundError: return &InvalidColumnReferencedInViewError{ View: view.Name(), Column: e.Column.Name.String(), diff --git a/go/vt/vterrors/state.go b/go/vt/vterrors/state.go index 5d286b0c991..c34bfb4f2b1 100644 --- a/go/vt/vterrors/state.go +++ b/go/vt/vterrors/state.go @@ -48,6 +48,7 @@ const ( WrongValue WrongArguments BadNullError + InvalidGroupFuncUse // failed precondition NoDB diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json index 6697da92c23..cef721e1c6a 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json @@ -2205,5 +2205,139 @@ "user.user" ] } + }, + { + "comment": "ORDER BY literal works fine even when the columns have the same name", + "query": "select a.id, b.id from user as a, user_extra as b union all select 1, 2 order by 1", + "plan": { + "QueryType": "SELECT", + "Original": "select a.id, b.id from user as a, user_extra as b union all select 1, 2 order by 1", + "Instructions": { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(0|2) ASC", + "ResultColumns": 2, + 
"Inputs": [ + { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,L:1", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select a.id, weight_string(a.id) from `user` as a where 1 != 1", + "Query": "select a.id, weight_string(a.id) from `user` as a", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select b.id from user_extra as b where 1 != 1", + "Query": "select b.id from user_extra as b", + "Table": "user_extra" + } + ] + }, + { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select 1, 2, weight_string(1) from dual where 1 != 1", + "Query": "select 1, 2, weight_string(1) from dual", + "Table": "dual" + } + ] + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "ORDER BY literal works fine even when the columns have the same name", + "query": "select a.id, b.id from user as a, user_extra as b union all select 1, 2 order by 2", + "plan": { + "QueryType": "SELECT", + "Original": "select a.id, b.id from user as a, user_extra as b union all select 1, 2 order by 2", + "Instructions": { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(1|2) ASC", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,R:1", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select a.id from `user` as a where 1 != 1", + "Query": "select a.id from `user` as a", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select b.id, weight_string(b.id) from user_extra as b where 1 != 1", + "Query": "select b.id, weight_string(b.id) from user_extra as b", + "Table": "user_extra" + } + ] + }, + { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select 1, 2, weight_string(2) from dual where 1 != 1", + "Query": "select 1, 2, weight_string(2) from dual", + "Table": "dual" + } + ] + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user", + "user.user_extra" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 979b50c3d3d..10cf6b84791 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -219,11 +219,6 @@ "query": "select id2 from user uu where id in (select id from user where id = uu.id and user.col in (select col from (select col, id, user_id from user_extra where user_id = 5) uu where uu.user_id = uu.id))", "plan": "VT12001: unsupported: correlated subquery is only supported for EXISTS" }, - { - "comment": "rewrite of 'order by 2' that becomes 'order by id', leading to ambiguous binding.", - "query": "select a.id, b.id from user as a, user_extra as b union select 1, 2 order by 2", - "plan": "Column 'id' in field list is ambiguous" - }, { "comment": "unsupported with clause in 
delete statement", "query": "with x as (select * from user) delete from x", diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index a9c06c8fe4d..ecba7032b96 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -69,10 +69,12 @@ func (a *analyzer) lateInit() { a.binder = newBinder(a.scoper, a, a.tables, a.typer) a.scoper.binder = a.binder a.rewriter = &earlyRewriter{ - env: a.si.Environment(), - scoper: a.scoper, binder: a.binder, + scoper: a.scoper, expandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{}, + env: a.si.Environment(), + aliasMapCache: map[*sqlparser.Select]map[string]exprContainer{}, + reAnalyze: a.lateAnalyze, } } @@ -232,10 +234,6 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { return false } - if err := a.scoper.up(cursor); err != nil { - a.setError(err) - return false - } if err := a.tables.up(cursor); err != nil { a.setError(err) return false @@ -256,6 +254,11 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { return true } + if err := a.scoper.up(cursor); err != nil { + a.setError(err) + return false + } + a.leaveProjection(cursor) return a.shouldContinue() } @@ -348,6 +351,10 @@ func (a *analyzer) analyze(statement sqlparser.Statement) error { a.lateInit() + return a.lateAnalyze(statement) +} + +func (a *analyzer) lateAnalyze(statement sqlparser.SQLNode) error { _ = sqlparser.Rewrite(statement, a.analyzeDown, a.analyzeUp) return a.err } diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index f5e7d3c6297..b010649e067 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -57,7 +57,8 @@ func newBinder(scoper *scoper, org originable, tc *tableCollector, typer *typer) } func (b *binder) up(cursor *sqlparser.Cursor) error { - switch node := cursor.Node().(type) { + node := cursor.Node() + switch node := node.(type) { case *sqlparser.Subquery: currScope := b.scoper.currentScope() b.setSubQueryDependencies(node, currScope) @@ -65,7 +66,7 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { currScope := b.scoper.currentScope() for _, ident := range node.Using { name := sqlparser.NewColName(ident.String()) - deps, err := b.resolveColumn(name, currScope, true) + deps, err := b.resolveColumn(name, currScope, true, true) if err != nil { return err } @@ -73,7 +74,7 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { } case *sqlparser.ColName: currentScope := b.scoper.currentScope() - deps, err := b.resolveColumn(node, currentScope, false) + deps, err := b.resolveColumn(node, currentScope, false, true) if err != nil { if deps.direct.IsEmpty() || !strings.HasSuffix(err.Error(), "is ambiguous") || @@ -185,7 +186,7 @@ func (b *binder) rewriteJoinUsingColName(deps dependency, node *sqlparser.ColNam return dependency{}, err } node.Qualifier = name - deps, err = b.resolveColumn(node, currentScope, false) + deps, err = b.resolveColumn(node, currentScope, false, true) if err != nil { return dependency{}, err } @@ -226,7 +227,7 @@ func (b *binder) setSubQueryDependencies(subq *sqlparser.Subquery, currScope *sc b.direct[subq] = subqDirectDeps.KeepOnly(tablesToKeep) } -func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allowMulti bool) (dependency, error) { +func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allowMulti, singleTableFallBack bool) (dependency, error) { var thisDeps dependencies first := true var tableName *sqlparser.TableName @@ -248,7 +249,7 @@ func 
(b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allow } else if err != nil { return dependency{}, err } - if current.parent == nil && len(current.tables) == 1 && first && colName.Qualifier.IsEmpty() { + if current.parent == nil && len(current.tables) == 1 && first && colName.Qualifier.IsEmpty() && singleTableFallBack { // if this is the top scope, and we still haven't been able to find a match, we know we are about to fail // we can check this last scope and see if there is a single table. if there is just one table in the scope // we assume that the column is meant to come from this table. @@ -263,7 +264,7 @@ func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allow first = false current = current.parent } - return dependency{}, ShardedError{&ColumnNotFoundError{Column: colName, Table: tableName}} + return dependency{}, ShardedError{ColumnNotFoundError{Column: colName, Table: tableName}} } func (b *binder) resolveColumnInScope(current *scope, expr *sqlparser.ColName, allowMulti bool) (dependencies, error) { diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 08d432ae9d0..58bf2e2a5c4 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -34,6 +34,12 @@ type earlyRewriter struct { warning string expandedColumns map[sqlparser.TableName][]*sqlparser.ColName env *vtenv.Environment + aliasMapCache map[*sqlparser.Select]map[string]exprContainer + + // reAnalyze is used when we are running in the late stage, after the other parts of semantic analysis + // have happened, and we are introducing or changing the AST. We invoke it so all parts of the query have been + // typed, scoped and bound correctly + reAnalyze func(n sqlparser.SQLNode) error } func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { @@ -43,15 +49,7 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { case sqlparser.SelectExprs: return r.handleSelectExprs(cursor, node) case *sqlparser.JoinTableExpr: - r.handleJoinTableExpr(node) - case sqlparser.OrderBy: - r.clause = "order clause" - iter := &orderByIterator{ - node: node, - idx: -1, - } - - return r.handleOrderByAndGroupBy(cursor.Parent(), iter) + r.handleJoinTableExprDown(node) case *sqlparser.OrExpr: rewriteOrExpr(r.env, cursor, node) case *sqlparser.AndExpr: @@ -64,7 +62,7 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { node: node, idx: -1, } - return r.handleOrderByAndGroupBy(cursor.Parent(), iter) + return r.handleGroupBy(cursor.Parent(), iter) case *sqlparser.ComparisonExpr: return handleComparisonExpr(cursor, node) case *sqlparser.With: @@ -72,24 +70,45 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { case *sqlparser.AliasedTableExpr: return r.handleAliasedTable(node) case *sqlparser.Delete: - // When we do not have any target, it is a single table delete. - // In a single table delete, the table references is always a single aliased table expression. 
- if len(node.Targets) != 0 { - return nil - } - tblExpr, ok := node.TableExprs[0].(*sqlparser.AliasedTableExpr) - if !ok { - return nil - } - tblName, err := tblExpr.TableName() - if err != nil { - return err + return handleDelete(node) + } + return nil +} + +func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error { + switch node := cursor.Node().(type) { + case *sqlparser.JoinTableExpr: + return r.handleJoinTableExprUp(node) + case sqlparser.OrderBy: + r.clause = "order clause" + iter := &orderByIterator{ + node: node, + idx: -1, + r: r, } - node.Targets = append(node.Targets, tblName) + return r.handleOrderBy(cursor.Parent(), iter) } return nil } +func handleDelete(del *sqlparser.Delete) error { + // When we do not have any target, it is a single table delete. + // In a single table delete, the table references is always a single aliased table expression. + if len(del.Targets) != 0 { + return nil + } + tblExpr, ok := del.TableExprs[0].(*sqlparser.AliasedTableExpr) + if !ok { + return nil + } + tblName, err := tblExpr.TableName() + if err != nil { + return err + } + del.Targets = append(del.Targets, tblName) + return nil +} + func (r *earlyRewriter) handleAliasedTable(node *sqlparser.AliasedTableExpr) error { tbl, ok := node.Expr.(sqlparser.TableName) if !ok || tbl.Qualifier.NotEmpty() { @@ -139,31 +158,21 @@ func rewriteNotExpr(cursor *sqlparser.Cursor, node *sqlparser.NotExpr) { cursor.Replace(cmp) } -func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error { +func (r *earlyRewriter) handleJoinTableExprUp(join *sqlparser.JoinTableExpr) error { // this rewriting is done in the `up` phase, because we need the scope to have been // filled in with the available tables - node, ok := cursor.Node().(*sqlparser.JoinTableExpr) - if !ok || len(node.Condition.Using) == 0 { + if len(join.Condition.Using) == 0 { return nil } - err := rewriteJoinUsing(r.binder, node) + err := rewriteJoinUsing(r.binder, join) if err != nil { return err } // since the binder has already been over the join, we need to invoke it again, so it // can bind columns to the right tables - sqlparser.Rewrite(node.Condition.On, nil, func(cursor *sqlparser.Cursor) bool { - innerErr := r.binder.up(cursor) - if innerErr == nil { - return true - } - - err = innerErr - return false - }) - return err + return r.reAnalyze(join.Condition.On) } // handleWhereClause processes WHERE clauses, specifically the HAVING clause. @@ -175,7 +184,7 @@ func (r *earlyRewriter) handleWhereClause(node *sqlparser.Where, parent sqlparse if node.Type != sqlparser.HavingClause { return nil } - expr, err := r.rewriteAliasesInOrderByHavingAndGroupBy(node.Expr, sel) + expr, err := r.rewriteAliasesInHavingAndGroupBy(node.Expr, sel) if err != nil { return err } @@ -193,8 +202,8 @@ func (r *earlyRewriter) handleSelectExprs(cursor *sqlparser.Cursor, node sqlpars return r.expandStar(cursor, node) } -// handleJoinTableExpr processes JOIN table expressions and handles the Straight Join type. -func (r *earlyRewriter) handleJoinTableExpr(node *sqlparser.JoinTableExpr) { +// handleJoinTableExprDown processes JOIN table expressions and handles the Straight Join type. 
+func (r *earlyRewriter) handleJoinTableExprDown(node *sqlparser.JoinTableExpr) { if node.Join != sqlparser.StraightJoinType { return } @@ -205,6 +214,7 @@ func (r *earlyRewriter) handleJoinTableExpr(node *sqlparser.JoinTableExpr) { type orderByIterator struct { node sqlparser.OrderBy idx int + r *earlyRewriter } func (it *orderByIterator) next() sqlparser.Expr { @@ -217,7 +227,7 @@ func (it *orderByIterator) next() sqlparser.Expr { return it.node[it.idx].Expr } -func (it *orderByIterator) replace(e sqlparser.Expr) error { +func (it *orderByIterator) replace(e sqlparser.Expr) (err error) { if it.idx >= len(it.node) { return vterrors.VT13001("went past the last item") } @@ -253,13 +263,57 @@ type iterator interface { replace(e sqlparser.Expr) error } -func (r *earlyRewriter) replaceLiteralsInOrderByGroupBy(e sqlparser.Expr, iter iterator) (bool, error) { +func (r *earlyRewriter) replaceLiteralsInOrderBy(e sqlparser.Expr, iter iterator) (bool, error) { lit := getIntLiteral(e) if lit == nil { return false, nil } - newExpr, err := r.rewriteOrderByExpr(lit) + newExpr, recheck, err := r.rewriteOrderByExpr(lit) + if err != nil { + return false, err + } + + if getIntLiteral(newExpr) == nil { + coll, ok := e.(*sqlparser.CollateExpr) + if ok { + coll.Expr = newExpr + newExpr = coll + } + } else { + // the expression is still a literal int. that means that we don't really need to sort by it. + // we'll just replace the number with a string instead, just like mysql would do in this situation + // mysql> explain select 1 as foo from user group by 1; + // + // mysql> show warnings; + // +-------+------+-----------------------------------------------------------------+ + // | Level | Code | Message | + // +-------+------+-----------------------------------------------------------------+ + // | Note | 1003 | /* select#1 */ select 1 AS `foo` from `test`.`user` group by '' | + // +-------+------+-----------------------------------------------------------------+ + newExpr = sqlparser.NewStrLiteral("") + } + + err = iter.replace(newExpr) + if err != nil { + return false, err + } + if recheck { + err = r.reAnalyze(newExpr) + } + if err != nil { + return false, err + } + return true, nil +} + +func (r *earlyRewriter) replaceLiteralsInGroupBy(e sqlparser.Expr, iter iterator) (bool, error) { + lit := getIntLiteral(e) + if lit == nil { + return false, nil + } + + newExpr, err := r.rewriteGroupByExpr(lit) if err != nil { return false, err } @@ -309,7 +363,41 @@ func getIntLiteral(e sqlparser.Expr) *sqlparser.Literal { } // handleOrderBy processes the ORDER BY clause. -func (r *earlyRewriter) handleOrderByAndGroupBy(parent sqlparser.SQLNode, iter iterator) error { +func (r *earlyRewriter) handleOrderBy(parent sqlparser.SQLNode, iter iterator) error { + stmt, ok := parent.(sqlparser.SelectStatement) + if !ok { + return nil + } + + sel := sqlparser.GetFirstSelect(stmt) + for e := iter.next(); e != nil; e = iter.next() { + lit, err := r.replaceLiteralsInOrderBy(e, iter) + if err != nil { + return err + } + if lit { + continue + } + + expr, err := r.rewriteAliasesInOrderBy(e, sel) + if err != nil { + return err + } + + if err = iter.replace(expr); err != nil { + return err + } + + if err = r.reAnalyze(expr); err != nil { + return err + } + } + + return nil +} + +// handleGroupBy processes the GROUP BY clause. 
+func (r *earlyRewriter) handleGroupBy(parent sqlparser.SQLNode, iter iterator) error { stmt, ok := parent.(sqlparser.SelectStatement) if !ok { return nil @@ -317,14 +405,14 @@ func (r *earlyRewriter) handleOrderByAndGroupBy(parent sqlparser.SQLNode, iter i sel := sqlparser.GetFirstSelect(stmt) for e := iter.next(); e != nil; e = iter.next() { - lit, err := r.replaceLiteralsInOrderByGroupBy(e, iter) + lit, err := r.replaceLiteralsInGroupBy(e, iter) if err != nil { return err } if lit { continue } - expr, err := r.rewriteAliasesInOrderByHavingAndGroupBy(e, sel) + expr, err := r.rewriteAliasesInHavingAndGroupBy(e, sel) if err != nil { return err } @@ -343,7 +431,7 @@ func (r *earlyRewriter) handleOrderByAndGroupBy(parent sqlparser.SQLNode, iter i // in SELECT points to that expression, not any table column. // - However, if the aliased expression is an aggregation and the column identifier in // the HAVING/ORDER BY clause is inside an aggregation function, the rule does not apply. -func (r *earlyRewriter) rewriteAliasesInOrderByHavingAndGroupBy(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { +func (r *earlyRewriter) rewriteAliasesInHavingAndGroupBy(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { type ExprContainer struct { expr sqlparser.Expr ambiguous bool @@ -435,7 +523,176 @@ func (r *earlyRewriter) rewriteAliasesInOrderByHavingAndGroupBy(node sqlparser.E return } -func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (sqlparser.Expr, error) { +// rewriteAliasesInOrderBy rewrites columns in the ORDER BY and HAVING clauses to use aliases +// from the SELECT expressions when applicable, following MySQL scoping rules: +// - A column identifier without a table qualifier that matches an alias introduced +// in SELECT points to that expression, not any table column. +// - However, if the aliased expression is an aggregation and the column identifier in +// the HAVING/ORDER BY clause is inside an aggregation function, the rule does not apply. +func (r *earlyRewriter) rewriteAliasesInOrderBy(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { + currentScope := r.scoper.currentScope() + if currentScope.isUnion { + // It is not safe to rewrite order by clauses in unions. 
+ return node, nil + } + + aliases := r.getAliasMap(sel) + insideAggr := false + dontEnterSubquery := func(node, _ sqlparser.SQLNode) bool { + switch node.(type) { + case *sqlparser.Subquery: + return false + case sqlparser.AggrFunc: + insideAggr = true + } + + _, isSubq := node.(*sqlparser.Subquery) + return !isSubq + } + output := sqlparser.CopyOnRewrite(node, dontEnterSubquery, func(cursor *sqlparser.CopyOnWriteCursor) { + var col *sqlparser.ColName + + switch node := cursor.Node().(type) { + case sqlparser.AggrFunc: + insideAggr = false + return + case *sqlparser.ColName: + col = node + default: + return + } + + if !col.Qualifier.IsEmpty() { + // we are only interested in columns not qualified by table names + return + } + + item, found := aliases[col.Name.Lowered()] + if !found { + // if there is no matching alias, there is no rewriting needed + return + } + + topLevel := col == node + if !topLevel && r.isColumnOnTable(col, currentScope) { + // we only want to replace columns that are not coming from the table + return + } + + if item.ambiguous { + err = &AmbiguousColumnError{Column: sqlparser.String(col)} + } else if insideAggr && sqlparser.ContainsAggregation(item.expr) { + err = &InvalidUserOfGroupFunction{} + } + if err != nil { + cursor.StopTreeWalk() + return + } + + cursor.Replace(sqlparser.CloneExpr(item.expr)) + }, nil) + + expr = output.(sqlparser.Expr) + return +} + +func (r *earlyRewriter) isColumnOnTable(col *sqlparser.ColName, currentScope *scope) bool { + if !currentScope.stmtScope && currentScope.parent != nil { + currentScope = currentScope.parent + } + _, err := r.binder.resolveColumn(col, currentScope, false, false) + return err == nil +} + +func (r *earlyRewriter) getAliasMap(sel *sqlparser.Select) (aliases map[string]exprContainer) { + var found bool + aliases, found = r.aliasMapCache[sel] + if found { + return + } + aliases = map[string]exprContainer{} + for _, e := range sel.SelectExprs { + ae, ok := e.(*sqlparser.AliasedExpr) + if !ok { + continue + } + + var alias string + + item := exprContainer{expr: ae.Expr} + if ae.As.NotEmpty() { + alias = ae.As.Lowered() + } else if col, ok := ae.Expr.(*sqlparser.ColName); ok { + alias = col.Name.Lowered() + } + + if old, alreadyExists := aliases[alias]; alreadyExists && !sqlparser.Equals.Expr(old.expr, item.expr) { + item.ambiguous = true + } + + aliases[alias] = item + } + return aliases +} + +type exprContainer struct { + expr sqlparser.Expr + ambiguous bool +} + +func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (expr sqlparser.Expr, needReAnalysis bool, err error) { + scope, found := r.scoper.specialExprScopes[node] + if !found { + return node, false, nil + } + num, err := strconv.Atoi(node.Val) + if err != nil { + return nil, false, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error parsing column number: %s", node.Val) + } + + stmt, isSel := scope.stmt.(*sqlparser.Select) + if !isSel { + return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error invalid statement type, expect Select, got: %T", scope.stmt) + } + + if num < 1 || num > len(stmt.SelectExprs) { + return nil, false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in '%s'", num, r.clause) + } + + // We loop like this instead of directly accessing the offset, to make sure there are no unexpanded `*` before + for i := 0; i < num; i++ { + if _, ok := stmt.SelectExprs[i].(*sqlparser.AliasedExpr); !ok { + return nil, false, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot use column 
offsets in %s when using `%s`", r.clause, sqlparser.String(stmt.SelectExprs[i])) + } + } + + colOffset := num - 1 + aliasedExpr, ok := stmt.SelectExprs[colOffset].(*sqlparser.AliasedExpr) + if !ok { + return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "don't know how to handle %s", sqlparser.String(node)) + } + + if scope.isUnion { + colName := sqlparser.NewColName(aliasedExpr.ColumnName()) + vtabl, ok := scope.tables[0].(*vTableInfo) + if !ok { + panic("BUG: not expected") + } + + // since column names can be ambiguous here, we want to do the binding by offset and not by column name + allColExprs := vtabl.cols[colOffset] + direct, recursive, typ := r.binder.org.depsForExpr(allColExprs) + r.binder.direct[colName] = direct + r.binder.recursive[colName] = recursive + r.binder.typer.m[colName] = typ + + return colName, false, nil + } + + return realCloneOfColNames(aliasedExpr.Expr, false), true, nil +} + +func (r *earlyRewriter) rewriteGroupByExpr(node *sqlparser.Literal) (sqlparser.Expr, error) { scope, found := r.scoper.specialExprScopes[node] if !found { return node, nil @@ -467,13 +724,8 @@ func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (sqlparser.E } if scope.isUnion { - col, isCol := aliasedExpr.Expr.(*sqlparser.ColName) - - if aliasedExpr.As.IsEmpty() && isCol { - return sqlparser.NewColName(col.Name.String()), nil - } - - return sqlparser.NewColName(aliasedExpr.ColumnName()), nil + colName := sqlparser.NewColName(aliasedExpr.ColumnName()) + return colName, nil } return realCloneOfColNames(aliasedExpr.Expr, false), nil diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index e681f722b1d..cf93a52447c 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -304,42 +304,84 @@ func TestRewriteJoinUsingColumns(t *testing.T) { } -func TestOrderByGroupByLiteral(t *testing.T) { +func TestGroupByLiteral(t *testing.T) { schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{}, } cDB := "db" tcases := []struct { - sql string - expSQL string - expErr string + sql string + expSQL string + expDeps TableSet + expErr string }{{ - sql: "select 1 as id from t1 order by 1", - expSQL: "select 1 as id from t1 order by '' asc", + sql: "select t1.col from t1 group by 1", + expSQL: "select t1.col from t1 group by t1.col", + expDeps: TS0, }, { - sql: "select t1.col from t1 order by 1", - expSQL: "select t1.col from t1 order by t1.col asc", + sql: "select t1.col as xyz from t1 group by 1", + expSQL: "select t1.col as xyz from t1 group by t1.col", + expDeps: TS0, }, { - sql: "select t1.col from t1 order by 1.0", - expSQL: "select t1.col from t1 order by 1.0 asc", + sql: "select id from t1 group by 2", + expErr: "Unknown column '2' in 'group clause'", }, { - sql: "select t1.col from t1 order by 'fubick'", - expSQL: "select t1.col from t1 order by 'fubick' asc", + sql: "select *, id from t1 group by 2", + expErr: "cannot use column offsets in group clause when using `*`", + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.NewTestParser().Parse(tcase.sql) + require.NoError(t, err) + selectStatement := ast.(*sqlparser.Select) + st, err := Analyze(selectStatement, cDB, schemaInfo) + if tcase.expErr == "" { + require.NoError(t, err) + assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + gb := selectStatement.GroupBy + deps := st.RecursiveDeps(gb[0]) + assert.Equal(t, tcase.expDeps, deps) + } else { + 
require.EqualError(t, err, tcase.expErr) + } + }) + } +} + +func TestOrderByLiteral(t *testing.T) { + schemaInfo := &FakeSI{ + Tables: map[string]*vindexes.Table{}, + } + cDB := "db" + tcases := []struct { + sql string + expSQL string + expDeps TableSet + expErr string + }{{ + sql: "select 1 as id from t1 order by 1", + expSQL: "select 1 as id from t1 order by '' asc", + expDeps: NoTables, }, { - sql: "select t1.col as foo from t1 order by 1", - expSQL: "select t1.col as foo from t1 order by t1.col asc", + sql: "select t1.col from t1 order by 1", + expSQL: "select t1.col from t1 order by t1.col asc", + expDeps: TS0, }, { - sql: "select t1.col from t1 group by 1", - expSQL: "select t1.col from t1 group by t1.col", + sql: "select t1.col from t1 order by 1.0", + expSQL: "select t1.col from t1 order by 1.0 asc", + expDeps: NoTables, }, { - sql: "select t1.col as xyz from t1 group by 1", - expSQL: "select t1.col as xyz from t1 group by t1.col", + sql: "select t1.col from t1 order by 'fubick'", + expSQL: "select t1.col from t1 order by 'fubick' asc", + expDeps: NoTables, }, { - sql: "select t1.col as xyz, count(*) from t1 group by 1 order by 2", - expSQL: "select t1.col as xyz, count(*) from t1 group by t1.col order by count(*) asc", + sql: "select t1.col as foo from t1 order by 1", + expSQL: "select t1.col as foo from t1 order by t1.col asc", + expDeps: TS0, }, { - sql: "select id from t1 group by 2", - expErr: "Unknown column '2' in 'group clause'", + sql: "select t1.col as xyz, count(*) from t1 group by 1 order by 2", + expSQL: "select t1.col as xyz, count(*) from t1 group by t1.col order by count(*) asc", + expDeps: TS0, }, { sql: "select id from t1 order by 2", expErr: "Unknown column '2' in 'order clause'", @@ -347,27 +389,41 @@ func TestOrderByGroupByLiteral(t *testing.T) { sql: "select *, id from t1 order by 2", expErr: "cannot use column offsets in order clause when using `*`", }, { - sql: "select *, id from t1 group by 2", - expErr: "cannot use column offsets in group clause when using `*`", + sql: "select id from t1 order by 1 collate utf8_general_ci", + expSQL: "select id from t1 order by id collate utf8_general_ci asc", + expDeps: TS0, + }, { + sql: "select id from `user` union select 1 from dual order by 1", + expSQL: "select id from `user` union select 1 from dual order by id asc", + expDeps: TS0, }, { - sql: "select id from t1 order by 1 collate utf8_general_ci", - expSQL: "select id from t1 order by id collate utf8_general_ci asc", + sql: "select id from t1 order by 2", + expErr: "Unknown column '2' in 'order clause'", }, { - sql: "select a.id from `user` union select 1 from dual order by 1", - expSQL: "select a.id from `user` union select 1 from dual order by id asc", + sql: "select a.id, b.id from user as a, user_extra as b union select 1, 2 order by 1", + expSQL: "select a.id, b.id from `user` as a, user_extra as b union select 1, 2 from dual order by id asc", + expDeps: TS0, }, { - sql: "select a.id, b.id from user as a, user_extra as b union select 1, 2 order by 1", - expErr: "Column 'id' in field list is ambiguous", + sql: "select a.id, b.id from user as a, user_extra as b union select 1, 2 order by 2", + expSQL: "select a.id, b.id from `user` as a, user_extra as b union select 1, 2 from dual order by id asc", + expDeps: TS1, + }, { + sql: "select user.id as foo from user union select col from user_extra order by 1", + expSQL: "select `user`.id as foo from `user` union select col from user_extra order by foo asc", + expDeps: MergeTableSets(TS0, TS1), }} for _, tcase := 
range tcases { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.NewTestParser().Parse(tcase.sql) require.NoError(t, err) selectStatement := ast.(sqlparser.SelectStatement) - _, err = Analyze(selectStatement, cDB, schemaInfo) + st, err := Analyze(selectStatement, cDB, schemaInfo) if tcase.expErr == "" { require.NoError(t, err) assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + ordering := selectStatement.GetOrderBy() + deps := st.RecursiveDeps(ordering[0].Expr) + assert.Equal(t, tcase.expDeps, deps) } else { require.EqualError(t, err, tcase.expErr) } @@ -375,7 +431,7 @@ func TestOrderByGroupByLiteral(t *testing.T) { } } -func TestHavingAndOrderByColumnName(t *testing.T) { +func TestHavingColumnName(t *testing.T) { schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{}, } @@ -388,28 +444,97 @@ func TestHavingAndOrderByColumnName(t *testing.T) { sql: "select id, sum(foo) as sumOfFoo from t1 having sumOfFoo > 1", expSQL: "select id, sum(foo) as sumOfFoo from t1 having sum(foo) > 1", }, { + sql: "select id, sum(foo) as foo from t1 having sum(foo) > 1", + expSQL: "select id, sum(foo) as foo from t1 having sum(foo) > 1", + }, { + sql: "select foo + 2 as foo from t1 having foo = 42", + expSQL: "select foo + 2 as foo from t1 having foo + 2 = 42", + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.NewTestParser().Parse(tcase.sql) + require.NoError(t, err) + selectStatement := ast.(sqlparser.SelectStatement) + _, err = Analyze(selectStatement, cDB, schemaInfo) + if tcase.expErr == "" { + require.NoError(t, err) + assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + } else { + require.EqualError(t, err, tcase.expErr) + } + }) + } +} + +func TestOrderByColumnName(t *testing.T) { + schemaInfo := &FakeSI{ + Tables: map[string]*vindexes.Table{ + "t1": { + Keyspace: &vindexes.Keyspace{Name: "ks", Sharded: true}, + Name: sqlparser.NewIdentifierCS("t1"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("id"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewIdentifierCI("foo"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewIdentifierCI("bar"), + Type: sqltypes.VarChar, + }}, + ColumnListAuthoritative: true, + }, + }, + } + cDB := "db" + tcases := []struct { + sql string + expSQL string + expErr string + }{{ sql: "select id, sum(foo) as sumOfFoo from t1 order by sumOfFoo", expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(foo) asc", }, { - sql: "select id, sum(foo) as foo from t1 having sum(foo) > 1", - expSQL: "select id, sum(foo) as foo from t1 having sum(foo) > 1", + sql: "select id, sum(foo) as sumOfFoo from t1 order by sumOfFoo + 1", + expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(foo) + 1 asc", + }, { + sql: "select id, sum(foo) as sumOfFoo from t1 order by abs(sumOfFoo)", + expSQL: "select id, sum(foo) as sumOfFoo from t1 order by abs(sum(foo)) asc", + }, { + sql: "select id, sum(foo) as sumOfFoo from t1 order by max(sumOfFoo)", + expErr: "Invalid use of group function", + }, { + sql: "select id, sum(foo) as foo from t1 order by foo + 1", + expSQL: "select id, sum(foo) as foo from t1 order by foo + 1 asc", + }, { + sql: "select id, sum(foo) as foo from t1 order by foo", + expSQL: "select id, sum(foo) as foo from t1 order by sum(foo) asc", }, { sql: "select id, lower(min(foo)) as foo from t1 order by min(foo)", expSQL: "select id, lower(min(foo)) as foo from t1 order by min(foo) asc", }, { - // invalid according to group by rules, but still accepted by 
mysql - sql: "select id, t1.bar as foo from t1 group by id order by min(foo)", - expSQL: "select id, t1.bar as foo from t1 group by id order by min(t1.bar) asc", + sql: "select id, lower(min(foo)) as foo from t1 order by foo", + expSQL: "select id, lower(min(foo)) as foo from t1 order by lower(min(foo)) asc", }, { - sql: "select foo + 2 as foo from t1 having foo = 42", - expSQL: "select foo + 2 as foo from t1 having foo + 2 = 42", + sql: "select id, lower(min(foo)) as foo from t1 order by abs(foo)", + expSQL: "select id, lower(min(foo)) as foo from t1 order by abs(foo) asc", }, { - sql: "select id, b as id, count(*) from t1 order by id", + sql: "select id, t1.bar as foo from t1 group by id order by min(foo)", + expSQL: "select id, t1.bar as foo from t1 group by id order by min(foo) asc", + }, { + sql: "select id, bar as id, count(*) from t1 order by id", expErr: "Column 'id' in field list is ambiguous", }, { sql: "select id, id, count(*) from t1 order by id", expSQL: "select id, id, count(*) from t1 order by id asc", - }} + }, { + sql: "select id, count(distinct foo) k from t1 group by id order by k", + expSQL: "select id, count(distinct foo) as k from t1 group by id order by count(distinct foo) asc", + }, { + sql: "select user.id as foo from user union select col from user_extra order by foo", + expSQL: "select `user`.id as foo from `user` union select col from user_extra order by foo asc", + }, + } for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.NewTestParser().Parse(tcase.sql) diff --git a/go/vt/vtgate/semantics/errors.go b/go/vt/vtgate/semantics/errors.go index 8d0b23d7f82..069c3476db0 100644 --- a/go/vt/vtgate/semantics/errors.go +++ b/go/vt/vtgate/semantics/errors.go @@ -51,6 +51,7 @@ type ( AmbiguousColumnError struct{ Column string } SubqueryColumnCountError struct{ Expected int } ColumnsMissingInSchemaError struct{} + InvalidUserOfGroupFunction struct{} UnsupportedMultiTablesInUpdateError struct { ExprCount int @@ -207,18 +208,18 @@ func (e *BuggyError) Error() string { func (e *BuggyError) bug() {} // ColumnNotFoundError -func (e *ColumnNotFoundError) Error() string { +func (e ColumnNotFoundError) Error() string { if e.Table == nil { return eprintf(e, "column '%s' not found", sqlparser.String(e.Column)) } return eprintf(e, "column '%s' not found in table '%s'", sqlparser.String(e.Column), sqlparser.String(e.Table)) } -func (e *ColumnNotFoundError) ErrorCode() vtrpcpb.Code { +func (e ColumnNotFoundError) ErrorCode() vtrpcpb.Code { return vtrpcpb.Code_INVALID_ARGUMENT } -func (e *ColumnNotFoundError) ErrorState() vterrors.State { +func (e ColumnNotFoundError) ErrorState() vterrors.State { return vterrors.BadFieldError } @@ -235,6 +236,7 @@ func (e *AmbiguousColumnError) ErrorCode() vtrpcpb.Code { return vtrpcpb.Code_INVALID_ARGUMENT } +// UnsupportedConstruct func (e *UnsupportedConstruct) unsupported() {} func (e *UnsupportedConstruct) ErrorCode() vtrpcpb.Code { @@ -245,6 +247,7 @@ func (e *UnsupportedConstruct) Error() string { return eprintf(e, e.errString) } +// SubqueryColumnCountError func (e *SubqueryColumnCountError) ErrorCode() vtrpcpb.Code { return vtrpcpb.Code_INVALID_ARGUMENT } @@ -253,7 +256,7 @@ func (e *SubqueryColumnCountError) Error() string { return fmt.Sprintf("Operand should contain %d column(s)", e.Expected) } -// MissingInVSchemaError +// ColumnsMissingInSchemaError func (e *ColumnsMissingInSchemaError) Error() string { return "VT09015: schema tracking required" } @@ -261,3 +264,16 @@ func (e 
*ColumnsMissingInSchemaError) Error() string { func (e *ColumnsMissingInSchemaError) ErrorCode() vtrpcpb.Code { return vtrpcpb.Code_INVALID_ARGUMENT } + +// InvalidUserOfGroupFunction +func (*InvalidUserOfGroupFunction) Error() string { + return "Invalid use of group function" +} + +func (*InvalidUserOfGroupFunction) ErrorCode() vtrpcpb.Code { + return vtrpcpb.Code_INVALID_ARGUMENT +} + +func (*InvalidUserOfGroupFunction) ErrorState() vterrors.State { + return vterrors.InvalidGroupFuncUse +} From 2b720fdff2a1e5cfb4511235611d88e861af3a54 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 22 Feb 2024 09:24:06 +0100 Subject: [PATCH 2/3] empty commit to trigger ci Signed-off-by: Andres Taylor From fd4e2ee14f47d1de88aac2e0529b14afb0f5557d Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 22 Feb 2024 12:24:41 +0100 Subject: [PATCH 3/3] test: fix so we skip correctly Signed-off-by: Andres Taylor --- go/test/endtoend/utils/cmp.go | 6 ++++++ .../queries/aggregation/aggregation_test.go | 18 +++++++++--------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/go/test/endtoend/utils/cmp.go b/go/test/endtoend/utils/cmp.go index 38726d6c3aa..8d0b56ac6b3 100644 --- a/go/test/endtoend/utils/cmp.go +++ b/go/test/endtoend/utils/cmp.go @@ -70,6 +70,12 @@ func (mcmp *MySQLCompare) AssertMatches(query, expected string) { } } +// SkipIfBinaryIsBelowVersion should be used instead of using utils.SkipIfBinaryIsBelowVersion(t, +// This is because we might be inside a Run block that has a different `t` variable +func (mcmp *MySQLCompare) SkipIfBinaryIsBelowVersion(majorVersion int, binary string) { + SkipIfBinaryIsBelowVersion(mcmp.t, majorVersion, binary) +} + // AssertMatchesAny ensures the given query produces any one of the expected results. func (mcmp *MySQLCompare) AssertMatchesAny(query string, expected ...string) { mcmp.t.Helper() diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index 6f4dd01d4e2..da2e14218fe 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -74,7 +74,7 @@ func TestAggregateTypes(t *testing.T) { mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by 2, a", `[[VARCHAR("b") INT64(1)] [VARCHAR("d") INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("c") INT64(2)] [VARCHAR("e") INT64(2)]]`) mcmp.AssertMatches("select sum(val1) from aggr_test", `[[FLOAT64(0)]]`) mcmp.Run("Average for sharded keyspaces", func(mcmp *utils.MySQLCompare) { - utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate") mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`) }) } @@ -178,7 +178,7 @@ func TestAggrOnJoin(t *testing.T) { `[[VARCHAR("a")]]`) mcmp.Run("Average in join for sharded", func(mcmp *utils.MySQLCompare) { - utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate") mcmp.AssertMatches(`select avg(a1.val2), avg(a2.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7`, "[[DECIMAL(1.5000) DECIMAL(1.0000)]]") @@ -336,7 +336,7 @@ func TestAggOnTopOfLimit(t *testing.T) { mcmp.AssertMatches("select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`) mcmp.AssertMatchesNoOrder("select val1, count(val2) from (select val1, val2 from aggr_test limit 
8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`) mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { - utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate") mcmp.AssertMatches("select avg(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[NULL]]") mcmp.AssertMatchesNoOrder("select val1, avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL DECIMAL(2.0000)] [VARCHAR("a") DECIMAL(3.5000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.5000)]]`) }) @@ -348,7 +348,7 @@ func TestAggOnTopOfLimit(t *testing.T) { mcmp.AssertMatches("select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]") mcmp.AssertMatches("select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`) mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { - utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate") mcmp.AssertMatches("select count(*), sum(val1), avg(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0) FLOAT64(0)]]") mcmp.AssertMatches("select count(val1), sum(id), avg(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7) DECIMAL(3.5000)]]") mcmp.AssertMatchesNoOrder("select val1, count(val2), sum(val2), avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", @@ -370,7 +370,7 @@ func TestEmptyTableAggr(t *testing.T) { mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { - utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate") mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]") }) @@ -386,7 +386,7 @@ func TestEmptyTableAggr(t *testing.T) { mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { - utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate") mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]") mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]") @@ -435,7 +435,7 @@ func TestAggregateLeftJoin(t *testing.T) { mcmp.AssertMatches("SELECT count(*) FROM t2 LEFT JOIN t1 ON t1.t1_id = t2.id WHERE 
IFNULL(t1.name, 'NOTSET') = 'r'", `[[INT64(1)]]`) mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { - utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate") mcmp.AssertMatches("SELECT avg(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(0.5000)]]`) mcmp.AssertMatches("SELECT avg(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1.0000)]]`) aggregations := []string{ @@ -492,7 +492,7 @@ func TestScalarAggregate(t *testing.T) { mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)") mcmp.AssertMatches("select count(distinct val1) from aggr_test", `[[INT64(3)]]`) mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { - utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate") mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`) }) } @@ -552,7 +552,7 @@ func TestComplexAggregation(t *testing.T) { mcmp.Exec(`SELECT name+COUNT(t1_id)+1 FROM t1 GROUP BY name`) mcmp.Exec(`SELECT COUNT(*)+shardkey+MIN(t1_id)+1+MAX(t1_id)*SUM(t1_id)+1+name FROM t1 GROUP BY shardkey, name`) mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) { - utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate") mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey)+AVG(t1_id) FROM t1`) }) }
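
A note on the union handling added in early_rewriter.go above: when the ORDER BY of a union uses a column offset, the rewriter replaces the literal with a column name built from the first SELECT's expression and records that column's dependencies by offset (via vTableInfo.cols), since resolving by name could be ambiguous across the two sides. The sketch below exercises that path through Analyze; it is not part of the patch, would live next to early_rewriter_test.go in package semantics (reusing that file's imports: testing, require, assert, sqlparser, vindexes), and its query and expectations are copied from the TestOrderByLiteral cases above.

	func TestUnionOrderByOffsetSketch(t *testing.T) {
		// Empty vschema: the analyzer can still resolve the union's output columns.
		schemaInfo := &FakeSI{Tables: map[string]*vindexes.Table{}}

		ast, err := sqlparser.NewTestParser().Parse("select id from `user` union select 1 from dual order by 1")
		require.NoError(t, err)
		sel := ast.(sqlparser.SelectStatement)

		st, err := Analyze(sel, "db", schemaInfo)
		require.NoError(t, err)

		// The literal offset has been replaced by the first SELECT's column, and the
		// ORDER BY expression's dependencies were bound by offset, not by name.
		assert.Equal(t, "select id from `user` union select 1 from dual order by id asc", sqlparser.String(sel))
		assert.Equal(t, TS0, st.RecursiveDeps(sel.GetOrderBy()[0].Expr))
	}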
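
On patch 3/3: utils.SkipIfBinaryIsBelowVersion(t, ...) called inside an mcmp.Run closure targets the enclosing test function's t rather than the sub-test created by Run, so the skip did not apply where intended; the new MySQLCompare.SkipIfBinaryIsBelowVersion helper forwards the sub-test's own t. A before/after sketch of the call-site pattern, taken from the aggregation tests changed above:

	// Before: the skip is issued against the enclosing test function's t.
	mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
		utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") // wrong t inside the Run block
		mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
	})

	// After: the helper uses the MySQLCompare's own t, so the sub-test skips correctly.
	mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
		mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
		mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
	})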