Skip to content

Commit

Permalink
[release-19.0] Column alias expanding on ORDER BY (#15302) (#15329)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
Co-authored-by: Andrés Taylor <[email protected]>
Co-authored-by: Harshit Gangal <[email protected]>
Co-authored-by: Manan Gupta <[email protected]>
  • Loading branch information
4 people authored Feb 22, 2024
1 parent e8283bf commit 979014c
Show file tree
Hide file tree
Showing 19 changed files with 783 additions and 161 deletions.
1 change: 1 addition & 0 deletions go/mysql/sqlerror/sql_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ var stateToMysqlCode = map[vterrors.State]mysqlCode{
vterrors.WrongParametersToNativeFct: {num: ERWrongParametersToNativeFct, state: SSUnknownSQLState},
vterrors.KillDeniedError: {num: ERKillDenied, state: SSUnknownSQLState},
vterrors.BadNullError: {num: ERBadNullError, state: SSConstraintViolation},
vterrors.InvalidGroupFuncUse: {num: ERInvalidGroupFuncUse, state: SSUnknownSQLState},
}

func getStateToMySQLState(state vterrors.State) mysqlCode {
Expand Down
6 changes: 6 additions & 0 deletions go/test/endtoend/utils/cmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ func (mcmp *MySQLCompare) AssertMatches(query, expected string) {
}
}

// SkipIfBinaryIsBelowVersion should be used instead of using utils.SkipIfBinaryIsBelowVersion(t,
// This is because we might be inside a Run block that has a different `t` variable
func (mcmp *MySQLCompare) SkipIfBinaryIsBelowVersion(majorVersion int, binary string) {
SkipIfBinaryIsBelowVersion(mcmp.t, majorVersion, binary)
}

// AssertMatchesAny ensures the given query produces any one of the expected results.
func (mcmp *MySQLCompare) AssertMatchesAny(query string, expected ...string) {
mcmp.t.Helper()
Expand Down
54 changes: 27 additions & 27 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ func TestAggregateTypes(t *testing.T) {
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by a", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by 2, a", `[[VARCHAR("b") INT64(1)] [VARCHAR("d") INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("c") INT64(2)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select sum(val1) from aggr_test", `[[FLOAT64(0)]]`)
t.Run("Average for sharded keyspaces", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average for sharded keyspaces", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
})
}
Expand All @@ -101,7 +101,7 @@ func TestEqualFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))

mcmp.AssertMatches("select count(*) as a from aggr_test having 1 = 1", `[[INT64(5)]]`)
Expand Down Expand Up @@ -177,8 +177,8 @@ func TestAggrOnJoin(t *testing.T) {
mcmp.AssertMatches("select a.val1 from aggr_test a join t3 t on a.val2 = t.id7 group by a.val1 having count(*) = 4",
`[[VARCHAR("a")]]`)

t.Run("Average in join for sharded", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in join for sharded", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches(`select avg(a1.val2), avg(a2.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7`,
"[[DECIMAL(1.5000) DECIMAL(1.0000)]]")

Expand All @@ -196,7 +196,7 @@ func TestNotEqualFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))

mcmp.AssertMatches("select count(*) as a from aggr_test having a != 5", `[]`)
Expand All @@ -220,7 +220,7 @@ func TestLessFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))
mcmp.AssertMatches("select count(*) as a from aggr_test having a < 10", `[[INT64(5)]]`)
mcmp.AssertMatches("select count(*) as a from aggr_test having 1 < a", `[[INT64(5)]]`)
Expand All @@ -243,7 +243,7 @@ func TestLessEqualFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))

mcmp.AssertMatches("select count(*) as a from aggr_test having a <= 10", `[[INT64(5)]]`)
Expand All @@ -267,7 +267,7 @@ func TestGreaterFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))

mcmp.AssertMatches("select count(*) as a from aggr_test having a > 1", `[[INT64(5)]]`)
Expand All @@ -291,7 +291,7 @@ func TestGreaterEqualFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))

mcmp.AssertMatches("select count(*) as a from aggr_test having a >= 1", `[[INT64(5)]]`)
Expand Down Expand Up @@ -326,7 +326,7 @@ func TestAggOnTopOfLimit(t *testing.T) {
mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',6), (2,'a',1), (3,'b',1), (4,'c',3), (5,'c',4), (6,'b',null), (7,null,2), (8,null,null)")

for _, workload := range []string{"oltp", "olap"} {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))
mcmp.AssertMatches("select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches("select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]")
Expand All @@ -335,8 +335,8 @@ func TestAggOnTopOfLimit(t *testing.T) {
mcmp.AssertMatches("select count(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0)]]")
mcmp.AssertMatches("select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`)
mcmp.AssertMatchesNoOrder("select val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`)
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches("select avg(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[NULL]]")
mcmp.AssertMatchesNoOrder("select val1, avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL DECIMAL(2.0000)] [VARCHAR("a") DECIMAL(3.5000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.5000)]]`)
})
Expand All @@ -347,8 +347,8 @@ func TestAggOnTopOfLimit(t *testing.T) {
mcmp.AssertMatches("select count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]")
mcmp.AssertMatches("select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]")
mcmp.AssertMatches("select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`)
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches("select count(*), sum(val1), avg(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0) FLOAT64(0)]]")
mcmp.AssertMatches("select count(val1), sum(id), avg(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7) DECIMAL(3.5000)]]")
mcmp.AssertMatchesNoOrder("select val1, count(val2), sum(val2), avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1",
Expand All @@ -363,14 +363,14 @@ func TestEmptyTableAggr(t *testing.T) {
defer closer()

for _, workload := range []string{"oltp", "olap"} {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]")
})
Expand All @@ -380,13 +380,13 @@ func TestEmptyTableAggr(t *testing.T) {
mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'b1','bar',200)")

for _, workload := range []string{"oltp", "olap"} {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
Expand Down Expand Up @@ -434,8 +434,8 @@ func TestAggregateLeftJoin(t *testing.T) {
mcmp.AssertMatches("SELECT sum(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`)
mcmp.AssertMatches("SELECT count(*) FROM t2 LEFT JOIN t1 ON t1.t1_id = t2.id WHERE IFNULL(t1.name, 'NOTSET') = 'r'", `[[INT64(1)]]`)

t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches("SELECT avg(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(0.5000)]]`)
mcmp.AssertMatches("SELECT avg(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1.0000)]]`)
aggregations := []string{
Expand Down Expand Up @@ -491,8 +491,8 @@ func TestScalarAggregate(t *testing.T) {

mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)")
mcmp.AssertMatches("select count(distinct val1) from aggr_test", `[[INT64(3)]]`)
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
})
}
Expand Down Expand Up @@ -551,8 +551,8 @@ func TestComplexAggregation(t *testing.T) {
mcmp.Exec(`SELECT shardkey + MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
mcmp.Exec(`SELECT name+COUNT(t1_id)+1 FROM t1 GROUP BY name`)
mcmp.Exec(`SELECT COUNT(*)+shardkey+MIN(t1_id)+1+MAX(t1_id)*SUM(t1_id)+1+name FROM t1 GROUP BY shardkey, name`)
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey)+AVG(t1_id) FROM t1`)
})
}
Expand Down
12 changes: 6 additions & 6 deletions go/test/endtoend/vtgate/queries/dml/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestSimpleInsertSelect(t *testing.T) {
mcmp.Exec("insert into u_tbl(id, num) values (1,2),(3,4)")

for i, mode := range []string{"oltp", "olap"} {
t.Run(mode, func(t *testing.T) {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))

qr := mcmp.Exec(fmt.Sprintf("insert into s_tbl(id, num) select id*%d, num*%d from s_tbl where id < 10", 10+i, 20+i))
Expand All @@ -65,7 +65,7 @@ func TestFailureInsertSelect(t *testing.T) {
mcmp.Exec("insert into u_tbl(id, num) values (1,2),(3,4)")

for _, mode := range []string{"oltp", "olap"} {
t.Run(mode, func(t *testing.T) {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))

// primary key same
Expand Down Expand Up @@ -127,7 +127,7 @@ func TestAutoIncInsertSelect(t *testing.T) {
}}

for _, tcase := range tcases {
t.Run(tcase.query, func(t *testing.T) {
mcmp.Run(tcase.query, func(mcmp *utils.MySQLCompare) {
qr := utils.Exec(t, mcmp.VtConn, tcase.query)
assert.EqualValues(t, tcase.expRowsAffected, qr.RowsAffected)
assert.EqualValues(t, tcase.expInsertID, qr.InsertID)
Expand Down Expand Up @@ -178,7 +178,7 @@ func TestAutoIncInsertSelectOlapMode(t *testing.T) {
}}

for _, tcase := range tcases {
t.Run(tcase.query, func(t *testing.T) {
mcmp.Run(tcase.query, func(mcmp *utils.MySQLCompare) {
qr := utils.Exec(t, mcmp.VtConn, tcase.query)
assert.EqualValues(t, tcase.expRowsAffected, qr.RowsAffected)
assert.EqualValues(t, tcase.expInsertID, qr.InsertID)
Expand Down Expand Up @@ -386,7 +386,7 @@ func TestInsertSelectUnshardedUsingSharded(t *testing.T) {
mcmp.Exec("insert into s_tbl(id, num) values (1,2),(3,4)")

for _, mode := range []string{"oltp", "olap"} {
t.Run(mode, func(t *testing.T) {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))
qr := mcmp.Exec("insert into u_tbl(id, num) select id, num from s_tbl where s_tbl.id in (1,3)")
assert.EqualValues(t, 2, qr.RowsAffected)
Expand Down Expand Up @@ -453,7 +453,7 @@ func TestMixedCases(t *testing.T) {
}}

for _, tc := range tcases {
t.Run(tc.insQuery, func(t *testing.T) {
mcmp.Run(tc.insQuery, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, tc.insQuery)
utils.AssertMatches(t, mcmp.VtConn, tc.selQuery, tc.exp)
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func TestLookupQueries(t *testing.T) {
(3, 'monkey', 'monkey')`)

for _, workload := range []string{"olap", "oltp"} {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, "set workload = "+workload)

mcmp.AssertMatches("select id from user where lookup = 'apa'", "[[INT64(1)] [INT64(2)]]")
Expand Down
4 changes: 2 additions & 2 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func TestAnalyze(t *testing.T) {
defer closer()

for _, workload := range []string{"olap", "oltp"} {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
utils.Exec(t, mcmp.VtConn, "analyze table t1")
utils.Exec(t, mcmp.VtConn, "analyze table uks.unsharded")
Expand Down Expand Up @@ -309,7 +309,7 @@ func TestTransactionModeVar(t *testing.T) {
}}

for _, tcase := range tcases {
t.Run(tcase.setStmt, func(t *testing.T) {
mcmp.Run(tcase.setStmt, func(mcmp *utils.MySQLCompare) {
if tcase.setStmt != "" {
utils.Exec(t, mcmp.VtConn, tcase.setStmt)
}
Expand Down
Loading

0 comments on commit 979014c

Please sign in to comment.