Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release-19.0] Column alias expanding on ORDER BY (#15302) #15329

Merged
merged 3 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go/mysql/sqlerror/sql_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ var stateToMysqlCode = map[vterrors.State]mysqlCode{
vterrors.WrongParametersToNativeFct: {num: ERWrongParametersToNativeFct, state: SSUnknownSQLState},
vterrors.KillDeniedError: {num: ERKillDenied, state: SSUnknownSQLState},
vterrors.BadNullError: {num: ERBadNullError, state: SSConstraintViolation},
vterrors.InvalidGroupFuncUse: {num: ERInvalidGroupFuncUse, state: SSUnknownSQLState},
}

func getStateToMySQLState(state vterrors.State) mysqlCode {
Expand Down
6 changes: 6 additions & 0 deletions go/test/endtoend/utils/cmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ func (mcmp *MySQLCompare) AssertMatches(query, expected string) {
}
}

// SkipIfBinaryIsBelowVersion should be used instead of using utils.SkipIfBinaryIsBelowVersion(t,
// This is because we might be inside a Run block that has a different `t` variable
func (mcmp *MySQLCompare) SkipIfBinaryIsBelowVersion(majorVersion int, binary string) {
SkipIfBinaryIsBelowVersion(mcmp.t, majorVersion, binary)
}

// AssertMatchesAny ensures the given query produces any one of the expected results.
func (mcmp *MySQLCompare) AssertMatchesAny(query string, expected ...string) {
mcmp.t.Helper()
Expand Down
54 changes: 27 additions & 27 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ func TestAggregateTypes(t *testing.T) {
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by a", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by 2, a", `[[VARCHAR("b") INT64(1)] [VARCHAR("d") INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("c") INT64(2)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select sum(val1) from aggr_test", `[[FLOAT64(0)]]`)
t.Run("Average for sharded keyspaces", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average for sharded keyspaces", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
})
}
Expand All @@ -101,7 +101,7 @@ func TestEqualFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))

mcmp.AssertMatches("select count(*) as a from aggr_test having 1 = 1", `[[INT64(5)]]`)
Expand Down Expand Up @@ -177,8 +177,8 @@ func TestAggrOnJoin(t *testing.T) {
mcmp.AssertMatches("select a.val1 from aggr_test a join t3 t on a.val2 = t.id7 group by a.val1 having count(*) = 4",
`[[VARCHAR("a")]]`)

t.Run("Average in join for sharded", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in join for sharded", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches(`select avg(a1.val2), avg(a2.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7`,
"[[DECIMAL(1.5000) DECIMAL(1.0000)]]")

Expand All @@ -196,7 +196,7 @@ func TestNotEqualFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))

mcmp.AssertMatches("select count(*) as a from aggr_test having a != 5", `[]`)
Expand All @@ -220,7 +220,7 @@ func TestLessFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))
mcmp.AssertMatches("select count(*) as a from aggr_test having a < 10", `[[INT64(5)]]`)
mcmp.AssertMatches("select count(*) as a from aggr_test having 1 < a", `[[INT64(5)]]`)
Expand All @@ -243,7 +243,7 @@ func TestLessEqualFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))

mcmp.AssertMatches("select count(*) as a from aggr_test having a <= 10", `[[INT64(5)]]`)
Expand All @@ -267,7 +267,7 @@ func TestGreaterFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))

mcmp.AssertMatches("select count(*) as a from aggr_test having a > 1", `[[INT64(5)]]`)
Expand All @@ -291,7 +291,7 @@ func TestGreaterEqualFilterOnScatter(t *testing.T) {

workloads := []string{"oltp", "olap"}
for _, workload := range workloads {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))

mcmp.AssertMatches("select count(*) as a from aggr_test having a >= 1", `[[INT64(5)]]`)
Expand Down Expand Up @@ -326,7 +326,7 @@ func TestAggOnTopOfLimit(t *testing.T) {
mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',6), (2,'a',1), (3,'b',1), (4,'c',3), (5,'c',4), (6,'b',null), (7,null,2), (8,null,null)")

for _, workload := range []string{"oltp", "olap"} {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))
mcmp.AssertMatches("select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]")
mcmp.AssertMatches("select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]")
Expand All @@ -335,8 +335,8 @@ func TestAggOnTopOfLimit(t *testing.T) {
mcmp.AssertMatches("select count(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0)]]")
mcmp.AssertMatches("select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`)
mcmp.AssertMatchesNoOrder("select val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`)
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches("select avg(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[NULL]]")
mcmp.AssertMatchesNoOrder("select val1, avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL DECIMAL(2.0000)] [VARCHAR("a") DECIMAL(3.5000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.5000)]]`)
})
Expand All @@ -347,8 +347,8 @@ func TestAggOnTopOfLimit(t *testing.T) {
mcmp.AssertMatches("select count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]")
mcmp.AssertMatches("select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]")
mcmp.AssertMatches("select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`)
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches("select count(*), sum(val1), avg(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0) FLOAT64(0)]]")
mcmp.AssertMatches("select count(val1), sum(id), avg(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7) DECIMAL(3.5000)]]")
mcmp.AssertMatchesNoOrder("select val1, count(val2), sum(val2), avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1",
Expand All @@ -363,14 +363,14 @@ func TestEmptyTableAggr(t *testing.T) {
defer closer()

for _, workload := range []string{"oltp", "olap"} {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]")
})
Expand All @@ -380,13 +380,13 @@ func TestEmptyTableAggr(t *testing.T) {
mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'b1','bar',200)")

for _, workload := range []string{"oltp", "olap"} {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]")
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
Expand Down Expand Up @@ -434,8 +434,8 @@ func TestAggregateLeftJoin(t *testing.T) {
mcmp.AssertMatches("SELECT sum(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`)
mcmp.AssertMatches("SELECT count(*) FROM t2 LEFT JOIN t1 ON t1.t1_id = t2.id WHERE IFNULL(t1.name, 'NOTSET') = 'r'", `[[INT64(1)]]`)

t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches("SELECT avg(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(0.5000)]]`)
mcmp.AssertMatches("SELECT avg(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1.0000)]]`)
aggregations := []string{
Expand Down Expand Up @@ -491,8 +491,8 @@ func TestScalarAggregate(t *testing.T) {

mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)")
mcmp.AssertMatches("select count(distinct val1) from aggr_test", `[[INT64(3)]]`)
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
})
}
Expand Down Expand Up @@ -551,8 +551,8 @@ func TestComplexAggregation(t *testing.T) {
mcmp.Exec(`SELECT shardkey + MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
mcmp.Exec(`SELECT name+COUNT(t1_id)+1 FROM t1 GROUP BY name`)
mcmp.Exec(`SELECT COUNT(*)+shardkey+MIN(t1_id)+1+MAX(t1_id)*SUM(t1_id)+1+name FROM t1 GROUP BY shardkey, name`)
t.Run("Average in sharded query", func(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp.Run("Average in sharded query", func(mcmp *utils.MySQLCompare) {
mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate")
mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey)+AVG(t1_id) FROM t1`)
})
}
Expand Down
12 changes: 6 additions & 6 deletions go/test/endtoend/vtgate/queries/dml/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestSimpleInsertSelect(t *testing.T) {
mcmp.Exec("insert into u_tbl(id, num) values (1,2),(3,4)")

for i, mode := range []string{"oltp", "olap"} {
t.Run(mode, func(t *testing.T) {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))

qr := mcmp.Exec(fmt.Sprintf("insert into s_tbl(id, num) select id*%d, num*%d from s_tbl where id < 10", 10+i, 20+i))
Expand All @@ -65,7 +65,7 @@ func TestFailureInsertSelect(t *testing.T) {
mcmp.Exec("insert into u_tbl(id, num) values (1,2),(3,4)")

for _, mode := range []string{"oltp", "olap"} {
t.Run(mode, func(t *testing.T) {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))

// primary key same
Expand Down Expand Up @@ -127,7 +127,7 @@ func TestAutoIncInsertSelect(t *testing.T) {
}}

for _, tcase := range tcases {
t.Run(tcase.query, func(t *testing.T) {
mcmp.Run(tcase.query, func(mcmp *utils.MySQLCompare) {
qr := utils.Exec(t, mcmp.VtConn, tcase.query)
assert.EqualValues(t, tcase.expRowsAffected, qr.RowsAffected)
assert.EqualValues(t, tcase.expInsertID, qr.InsertID)
Expand Down Expand Up @@ -178,7 +178,7 @@ func TestAutoIncInsertSelectOlapMode(t *testing.T) {
}}

for _, tcase := range tcases {
t.Run(tcase.query, func(t *testing.T) {
mcmp.Run(tcase.query, func(mcmp *utils.MySQLCompare) {
qr := utils.Exec(t, mcmp.VtConn, tcase.query)
assert.EqualValues(t, tcase.expRowsAffected, qr.RowsAffected)
assert.EqualValues(t, tcase.expInsertID, qr.InsertID)
Expand Down Expand Up @@ -386,7 +386,7 @@ func TestInsertSelectUnshardedUsingSharded(t *testing.T) {
mcmp.Exec("insert into s_tbl(id, num) values (1,2),(3,4)")

for _, mode := range []string{"oltp", "olap"} {
t.Run(mode, func(t *testing.T) {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))
qr := mcmp.Exec("insert into u_tbl(id, num) select id, num from s_tbl where s_tbl.id in (1,3)")
assert.EqualValues(t, 2, qr.RowsAffected)
Expand Down Expand Up @@ -453,7 +453,7 @@ func TestMixedCases(t *testing.T) {
}}

for _, tc := range tcases {
t.Run(tc.insQuery, func(t *testing.T) {
mcmp.Run(tc.insQuery, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, tc.insQuery)
utils.AssertMatches(t, mcmp.VtConn, tc.selQuery, tc.exp)
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func TestLookupQueries(t *testing.T) {
(3, 'monkey', 'monkey')`)

for _, workload := range []string{"olap", "oltp"} {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, "set workload = "+workload)

mcmp.AssertMatches("select id from user where lookup = 'apa'", "[[INT64(1)] [INT64(2)]]")
Expand Down
4 changes: 2 additions & 2 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func TestAnalyze(t *testing.T) {
defer closer()

for _, workload := range []string{"olap", "oltp"} {
t.Run(workload, func(t *testing.T) {
mcmp.Run(workload, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
utils.Exec(t, mcmp.VtConn, "analyze table t1")
utils.Exec(t, mcmp.VtConn, "analyze table uks.unsharded")
Expand Down Expand Up @@ -309,7 +309,7 @@ func TestTransactionModeVar(t *testing.T) {
}}

for _, tcase := range tcases {
t.Run(tcase.setStmt, func(t *testing.T) {
mcmp.Run(tcase.setStmt, func(mcmp *utils.MySQLCompare) {
if tcase.setStmt != "" {
utils.Exec(t, mcmp.VtConn, tcase.setStmt)
}
Expand Down
Loading
Loading