Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type coercion between the sides of an UNION #15340

Merged
merged 6 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion go/test/endtoend/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,13 @@ func AssertMatchesAny(t testing.TB, conn *mysql.Conn, query string, expected ...
return
}
}
t.Errorf("Query: %s (-want +got):\n%v\nGot:%s", query, expected, got)

var err strings.Builder
_, _ = fmt.Fprintf(&err, "Query did not match:\n%s\n", query)
for i, e := range expected {
_, _ = fmt.Fprintf(&err, "Expected query %d does not match.\nwant: %v\ngot: %v\n\n", i, e, got)
}
t.Error(err.String())
}

// AssertMatchesCompareMySQL executes the given query on both Vitess and MySQL and make sure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,26 @@ func TestInfrSchemaAndUnionAll(t *testing.T) {
}
}

func TestInfoschemaTypes(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")

require.NoError(t,
utils.WaitForAuthoritative(t, "ks", "t1", clusterInstance.VtgateProcess.ReadVSchema))

mcmp, closer := start(t)
defer closer()

mcmp.Exec(`
SELECT ORDINAL_POSITION
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = 'ks' AND TABLE_NAME = 't1'
UNION
SELECT ORDINAL_POSITION
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = 'ks' AND TABLE_NAME = 't2';
`)
}

func TestTypeORMQuery(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
// This test checks that we can run queries similar to the ones that the TypeORM framework uses
Expand Down
1 change: 1 addition & 0 deletions go/test/endtoend/vtgate/queries/orderby/orderby_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ func TestOrderByComplex(t *testing.T) {
"select email, max(col) as max_col from (select email, col from user where col > 20) as filtered group by email order by max_col",
"select a.email, a.max_col from (select email, max(col) as max_col from user group by email) as a order by a.max_col desc",
"select email, max(col) as max_col from user where email like 'a%' group by email order by max_col, email",
`select email, max(col) as max_col from user group by email union select email, avg(col) as avg_col from user group by email order by email desc`,
}

for _, query := range queries {
Expand Down
117 changes: 55 additions & 62 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@
return nil, err
}

fields, err := c.getFields(res)
fields, fieldTypes, err := c.getFieldTypes(vcursor, res)
if err != nil {
return nil, err
}

var rows [][]sqltypes.Value
err = c.coerceAndVisitResults(res, fields, func(result *sqltypes.Result) error {
err = c.coerceAndVisitResults(res, fieldTypes, func(result *sqltypes.Result) error {
rows = append(rows, result.Rows...)
return nil
}, evalengine.ParseSQLMode(vcursor.SQLMode()))
Expand All @@ -116,17 +116,17 @@
}, nil
}

func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field, sqlmode evalengine.SQLMode) error {
if len(row) != len(fields) {
func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fieldTypes []evalengine.Type, sqlmode evalengine.SQLMode) error {
if len(row) != len(fieldTypes) {
return errWrongNumberOfColumnsInSelect
}

for i, value := range row {
if _, found := c.NoNeedToTypeCheck[i]; found {
continue
}
if fields[i].Type != value.Type() {
newValue, err := evalengine.CoerceTo(value, fields[i].Type, sqlmode)
if fieldTypes[i].Type() != value.Type() {
newValue, err := evalengine.CoerceTo(value, fieldTypes[i], sqlmode)
if err != nil {
return err
}
Expand All @@ -136,44 +136,44 @@
return nil
}

func (c *Concatenate) getFields(res []*sqltypes.Result) (resultFields []*querypb.Field, err error) {
func (c *Concatenate) getFieldTypes(vcursor VCursor, res []*sqltypes.Result) ([]*querypb.Field, []evalengine.Type, error) {
if len(res) == 0 {
return nil, nil
return nil, nil, nil

Check warning on line 141 in go/vt/vtgate/engine/concatenate.go

View check run for this annotation

Codecov / codecov/patch

go/vt/vtgate/engine/concatenate.go#L141

Added line #L141 was not covered by tests
}

resultFields = res[0].Fields
columns := make([][]sqltypes.Type, len(resultFields))

addFields := func(fields []*querypb.Field) error {
if len(fields) != len(columns) {
return errWrongNumberOfColumnsInSelect
}
for idx, field := range fields {
columns[idx] = append(columns[idx], field.Type)
}
return nil
}
typers := make([]evalengine.TypeAggregator, len(res[0].Fields))
collations := vcursor.Environment().CollationEnv()

for _, r := range res {
if r == nil || r.Fields == nil {
continue
}
err := addFields(r.Fields)
if err != nil {
return nil, err
if len(r.Fields) != len(typers) {
return nil, nil, errWrongNumberOfColumnsInSelect
}
for idx, field := range r.Fields {
if err := typers[idx].AddField(field, collations); err != nil {
return nil, nil, err

Check warning on line 156 in go/vt/vtgate/engine/concatenate.go

View check run for this annotation

Codecov / codecov/patch

go/vt/vtgate/engine/concatenate.go#L156

Added line #L156 was not covered by tests
}
}
}

// The resulting column types need to be the coercion of all the input columns
for colIdx, t := range columns {
fields := make([]*querypb.Field, 0, len(typers))
types := make([]evalengine.Type, 0, len(typers))
for colIdx, typer := range typers {
f := res[0].Fields[colIdx]

if _, found := c.NoNeedToTypeCheck[colIdx]; found {
fields = append(fields, f)
types = append(types, evalengine.NewTypeFromField(f))
continue
}

resultFields[colIdx].Type = evalengine.AggregateTypes(t)
t := typer.Type()
fields = append(fields, t.ToField(f.Name))
types = append(types, t)
}

return resultFields, nil
return fields, types, nil
}

func (c *Concatenate) execSources(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) {
Expand Down Expand Up @@ -250,7 +250,7 @@
condFields = sync.NewCond(&muFields) // Condition var for field arrival
wg errgroup.Group // Wait group for all streaming goroutines
rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields
fields []*querypb.Field // Cached final field types
fieldTypes []evalengine.Type // Cached final field types
)

// Process each result chunk, considering type coercion.
Expand All @@ -263,7 +263,7 @@
needsCoercion := false
for idx, field := range rest[srcIdx].Fields {
_, skip := c.NoNeedToTypeCheck[idx]
if !skip && fields[idx].Type != field.Type {
if !skip && fieldTypes[idx].Type() != field.Type {
needsCoercion = true
break
}
Expand All @@ -272,7 +272,7 @@
// Apply type coercion if needed.
if needsCoercion {
for _, row := range res.Rows {
if err := c.coerceValuesTo(row, fields, sqlmode); err != nil {
if err := c.coerceValuesTo(row, fieldTypes, sqlmode); err != nil {
return err
}
}
Expand All @@ -299,11 +299,10 @@

// We have received fields from all sources. We can now calculate the output types
var err error
fields, err = c.getFields(rest)
resultChunk.Fields, fieldTypes, err = c.getFieldTypes(vcursor, rest)
if err != nil {
return err
}
resultChunk.Fields = fields

defer condFields.Broadcast()
return callback(resultChunk, currIndex)
Expand Down Expand Up @@ -370,12 +369,12 @@
firsts[i] = result[0]
}

fields, err := c.getFields(firsts)
_, fieldTypes, err := c.getFieldTypes(vcursor, firsts)
if err != nil {
return err
}
for _, res := range results {
if err = c.coerceAndVisitResults(res, fields, callback, sqlmode); err != nil {
if err = c.coerceAndVisitResults(res, fieldTypes, callback, sqlmode); err != nil {
return err
}
}
Expand All @@ -385,26 +384,26 @@

func (c *Concatenate) coerceAndVisitResults(
res []*sqltypes.Result,
fields []*querypb.Field,
fieldTypes []evalengine.Type,
callback func(*sqltypes.Result) error,
sqlmode evalengine.SQLMode,
) error {
for _, r := range res {
if len(r.Rows) > 0 &&
len(fields) != len(r.Rows[0]) {
len(fieldTypes) != len(r.Rows[0]) {
return errWrongNumberOfColumnsInSelect
}

needsCoercion := false
for idx, field := range r.Fields {
if fields[idx].Type != field.Type {
if fieldTypes[idx].Type() != field.Type {
needsCoercion = true
break
}
}
if needsCoercion {
for _, row := range r.Rows {
err := c.coerceValuesTo(row, fields, sqlmode)
err := c.coerceValuesTo(row, fieldTypes, sqlmode)
if err != nil {
return err
}
Expand All @@ -420,35 +419,29 @@

// GetFields fetches the field info.
func (c *Concatenate) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
res, err := c.Sources[0].GetFields(ctx, vcursor, bindVars)
if err != nil {
return nil, err
}

columns := make([][]sqltypes.Type, len(res.Fields))

addFields := func(fields []*querypb.Field) {
for idx, field := range fields {
columns[idx] = append(columns[idx], field.Type)
}
}

addFields(res.Fields)

for i := 1; i < len(c.Sources); i++ {
result, err := c.Sources[i].GetFields(ctx, vcursor, bindVars)
sourceFields := make([][]*querypb.Field, 0, len(c.Sources))
for _, src := range c.Sources {
f, err := src.GetFields(ctx, vcursor, bindVars)
if err != nil {
return nil, err
}
addFields(result.Fields)
sourceFields = append(sourceFields, f.Fields)
}

// The resulting column types need to be the coercion of all the input columns
for colIdx, t := range columns {
res.Fields[colIdx].Type = evalengine.AggregateTypes(t)
}
fields := make([]*querypb.Field, 0, len(sourceFields[0]))
collations := vcursor.Environment().CollationEnv()

return res, nil
for colIdx := 0; colIdx < len(sourceFields[0]); colIdx++ {
var typer evalengine.TypeAggregator
for _, src := range sourceFields {
if err := typer.AddField(src[colIdx], collations); err != nil {
return nil, err

Check warning on line 438 in go/vt/vtgate/engine/concatenate.go

View check run for this annotation

Codecov / codecov/patch

go/vt/vtgate/engine/concatenate.go#L438

Added line #L438 was not covered by tests
}
}
name := sourceFields[0][colIdx].Name
fields = append(fields, typer.Field(name))
}
return &sqltypes.Result{Fields: fields}, nil
}

// NeedsTransaction returns whether a transaction is needed for this primitive
Expand Down
28 changes: 19 additions & 9 deletions go/vt/vtgate/engine/concatenate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"strings"
"testing"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/test/utils"

"github.com/stretchr/testify/assert"
Expand All @@ -32,7 +33,17 @@ import (
)

func r(names, types string, rows ...string) *sqltypes.Result {
return sqltypes.MakeTestResult(sqltypes.MakeTestFields(names, types), rows...)
fields := sqltypes.MakeTestFields(names, types)
for _, f := range fields {
if sqltypes.IsText(f.Type) {
f.Charset = collations.CollationUtf8mb4ID
} else {
f.Charset = collations.CollationBinaryID
}
_, flags := sqltypes.TypeToMySQL(f.Type)
f.Flags = uint32(flags)
}
return sqltypes.MakeTestResult(fields, rows...)
}

func TestConcatenate_NoErrors(t *testing.T) {
Expand Down Expand Up @@ -173,12 +184,12 @@ func TestConcatenateTypes(t *testing.T) {
tests := []struct {
t1, t2, expected string
}{
{t1: "int32", t2: "int64", expected: "int64"},
{t1: "int32", t2: "int32", expected: "int32"},
{t1: "int32", t2: "varchar", expected: "varchar"},
{t1: "int32", t2: "decimal", expected: "decimal"},
{t1: "hexval", t2: "uint64", expected: "varchar"},
{t1: "varchar", t2: "varbinary", expected: "varbinary"},
{t1: "int32", t2: "int64", expected: `[name:"id" type:int64 charset:63]`},
{t1: "int32", t2: "int32", expected: `[name:"id" type:int32 charset:63]`},
{t1: "int32", t2: "varchar", expected: `[name:"id" type:varchar charset:255]`},
{t1: "int32", t2: "decimal", expected: `[name:"id" type:decimal charset:63]`},
{t1: "hexval", t2: "uint64", expected: `[name:"id" type:varchar charset:255]`},
{t1: "varchar", t2: "varbinary", expected: `[name:"id" type:varbinary charset:63 flags:128]`},
}

for _, test := range tests {
Expand All @@ -196,8 +207,7 @@ func TestConcatenateTypes(t *testing.T) {
res, err := concatenate.GetFields(context.Background(), &noopVCursor{}, nil)
require.NoError(t, err)

expected := fmt.Sprintf(`[name:"id" type:%s]`, test.expected)
assert.Equal(t, expected, strings.ToLower(fmt.Sprintf("%v", res.Fields)))
assert.Equal(t, test.expected, strings.ToLower(fmt.Sprintf("%v", res.Fields)))
})
}
}
Loading
Loading