Skip to content

Commit

Permalink
vtexplain: Fix setting up the column information (#15275)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink authored Feb 19, 2024
1 parent f38dab3 commit 9a78e7d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
18 changes: 6 additions & 12 deletions go/vt/mysqlctl/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,6 @@ func (mysqld *Mysqld) executeSchemaCommands(ctx context.Context, sql string) err
return mysqld.executeMysqlScript(ctx, params, sql)
}

func encodeEntityName(name string) string {
var buf strings.Builder
sqltypes.NewVarChar(name).EncodeSQL(&buf)
return buf.String()
}

// tableListSQL returns an IN clause "('t1', 't2'...) for a list of tables."
func tableListSQL(tables []string) (string, error) {
if len(tables) == 0 {
Expand All @@ -80,7 +74,7 @@ func tableListSQL(tables []string) (string, error) {

encodedTables := make([]string, len(tables))
for i, tableName := range tables {
encodedTables[i] = encodeEntityName(tableName)
encodedTables[i] = sqltypes.EncodeStringSQL(tableName)
}

return "(" + strings.Join(encodedTables, ", ") + ")", nil
Expand Down Expand Up @@ -307,13 +301,13 @@ func GetColumnsList(dbName, tableName string, exec func(string, int, bool) (*sql
if dbName == "" {
dbName2 = "database()"
} else {
dbName2 = encodeEntityName(dbName)
dbName2 = sqltypes.EncodeStringSQL(dbName)
}
sanitizedTableName, err := sqlescape.UnescapeID(tableName)
if err != nil {
return "", err
}
query := fmt.Sprintf(GetColumnNamesQuery, dbName2, encodeEntityName(sanitizedTableName))
query := fmt.Sprintf(GetColumnNamesQuery, dbName2, sqltypes.EncodeStringSQL(sanitizedTableName))
qr, err := exec(query, -1, true)
if err != nil {
return "", err
Expand Down Expand Up @@ -407,7 +401,7 @@ func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, t
FROM information_schema.STATISTICS
WHERE TABLE_SCHEMA = %s AND TABLE_NAME IN %s AND LOWER(INDEX_NAME) = 'primary'
ORDER BY table_name, SEQ_IN_INDEX`
sql = fmt.Sprintf(sql, encodeEntityName(dbName), tableList)
sql = fmt.Sprintf(sql, sqltypes.EncodeStringSQL(dbName), tableList)
qr, err := conn.Conn.ExecuteFetch(sql, len(tables)*100, true)
if err != nil {
return nil, err
Expand Down Expand Up @@ -631,8 +625,8 @@ func GetPrimaryKeyEquivalentColumns(ctx context.Context, exec func(string, int,
) AS pke ON index_cols.INDEX_NAME = pke.INDEX_NAME
WHERE index_cols.TABLE_SCHEMA = %s AND index_cols.TABLE_NAME = %s AND NON_UNIQUE = 0 AND NULLABLE != 'YES'
ORDER BY SEQ_IN_INDEX ASC`
encodedDbName := encodeEntityName(dbName)
encodedTable := encodeEntityName(table)
encodedDbName := sqltypes.EncodeStringSQL(dbName)
encodedTable := sqltypes.EncodeStringSQL(table)
sql = fmt.Sprintf(sql, encodedDbName, encodedTable, encodedDbName, encodedTable, encodedDbName, encodedTable)
qr, err := exec(sql, 1000, true)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions go/vt/vtexplain/vtexplain_vttablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,8 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options, collatio
}
tEnv.addResult(query, tEnv.getResult(likeQuery))

likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sanitizedLikeTable)
query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sanitizedTable)
likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(sanitizedLikeTable))
query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(sanitizedTable))
if tEnv.getResult(likeQuery) == nil {
return nil, fmt.Errorf("check your schema, table[%s] doesn't exist", likeTable)
}
Expand Down Expand Up @@ -516,7 +516,7 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options, collatio
tEnv.addResult("SELECT * FROM "+backtickedTable+" WHERE 1 != 1", &sqltypes.Result{
Fields: rowTypes,
})
query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sanitizedTable)
query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(sanitizedTable))
tEnv.addResult(query, &sqltypes.Result{
Fields: colTypes,
Rows: colValues,
Expand Down Expand Up @@ -618,7 +618,7 @@ func (t *explainTablet) handleSelect(query string) (*sqltypes.Result, error) {

// Gen4 supports more complex queries so we now need to
// handle multiple FROM clauses
tables := make([]*sqlparser.AliasedTableExpr, len(selStmt.From))
tables := make([]*sqlparser.AliasedTableExpr, 0, len(selStmt.From))
for _, from := range selStmt.From {
tables = append(tables, getTables(from)...)
}
Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtexplain/vtexplain_vttablet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ create table t2 (
require.NoError(t, err)
defer vte.Stop()

// Check if the correct schema query is registered.
_, found := vte.globalTabletEnv.schemaQueries["SELECT COLUMN_NAME as column_name\n\t\tFROM INFORMATION_SCHEMA.COLUMNS\n\t\tWHERE TABLE_SCHEMA = database() AND TABLE_NAME = 't1'\n\t\tORDER BY ORDINAL_POSITION"]
assert.True(t, found)

sql := "SELECT * FROM t1 INNER JOIN t2 ON t1.id = t2.id"

_, err = vte.Run(sql)
Expand Down

0 comments on commit 9a78e7d

Please sign in to comment.