From e0325be4f8ddbaad1cc1c9fa48b149f17ddd491f Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 11 Jun 2024 21:52:42 +0200 Subject: [PATCH] Ensure that we check the correct collation for foreign keys (#16109) Signed-off-by: Dirkjan Bussink --- go/vt/schemadiff/schema.go | 75 ++++++++++++++++++++++----------- go/vt/schemadiff/schema_test.go | 20 +++++++++ 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/go/vt/schemadiff/schema.go b/go/vt/schemadiff/schema.go index 1738b9a4836..df1e23f0ccc 100644 --- a/go/vt/schemadiff/schema.go +++ b/go/vt/schemadiff/schema.go @@ -359,30 +359,6 @@ func (s *Schema) normalize(hints *DiffHints) error { return errors.Join(errs, err) } } - colTypeCompatibleForForeignKey := func(child, parent *sqlparser.ColumnType) bool { - if child.Type == parent.Type { - return true - } - if child.Type == "char" && parent.Type == "varchar" { - return true - } - if child.Type == "varchar" && parent.Type == "char" { - return true - } - return false - } - colTypeEqualForForeignKey := func(child, parent *sqlparser.ColumnType) bool { - if colTypeCompatibleForForeignKey(child, parent) && - child.Unsigned == parent.Unsigned && - child.Zerofill == parent.Zerofill && - sqlparser.Equals.ColumnCharset(child.Charset, parent.Charset) && - child.Options.Collate == parent.Options.Collate && - sqlparser.Equals.SliceOfString(child.EnumValues, parent.EnumValues) { - // Complete identify (other than precision which is ignored) - return true - } - return false - } // Now validate foreign key columns: // - referenced table columns must exist @@ -429,7 +405,7 @@ func (s *Schema) normalize(hints *DiffHints) error { if !ok { return errors.Join(errs, &InvalidReferencedColumnInForeignKeyConstraintError{Table: t.Name(), Constraint: cs.Name.String(), ReferencedTable: referencedTableName, ReferencedColumn: referencedColumnName}) } - if !colTypeEqualForForeignKey(coveredColumn.Type, referencedColumn.Type) { + if !colTypeEqualForForeignKey(s.env, t.TableSpec, referencedTable.CreateTable.TableSpec, coveredColumn.Type, referencedColumn.Type) { return errors.Join(errs, &ForeignKeyColumnTypeMismatchError{Table: t.Name(), Constraint: cs.Name.String(), Column: coveredColumn.Name.String(), ReferencedTable: referencedTableName, ReferencedColumn: referencedColumnName}) } } @@ -442,6 +418,55 @@ func (s *Schema) normalize(hints *DiffHints) error { return errs } +func colTypeCompatibleForForeignKey(child, parent *sqlparser.ColumnType) bool { + if child.Type == parent.Type { + return true + } + if child.Type == "char" && parent.Type == "varchar" { + return true + } + if child.Type == "varchar" && parent.Type == "char" { + return true + } + return false +} + +func colTypeEqualForForeignKey(env *Environment, ct, pt *sqlparser.TableSpec, child, parent *sqlparser.ColumnType) bool { + if colTypeCompatibleForForeignKey(child, parent) && + child.Unsigned == parent.Unsigned && + child.Zerofill == parent.Zerofill && + colCollationEqualForForeignKey(env, ct, pt, child, parent) && + sqlparser.Equals.SliceOfString(child.EnumValues, parent.EnumValues) { + // Complete identify (other than precision which is ignored) + return true + } + return false +} + +func colCollationEqualForForeignKey(env *Environment, ct, pt *sqlparser.TableSpec, child, parent *sqlparser.ColumnType) bool { + return *colCollation(env, ct, child) == *colCollation(env, pt, parent) +} + +func colCollation(env *Environment, t *sqlparser.TableSpec, col *sqlparser.ColumnType) *charsetCollate { + tc := getTableCharsetCollate(env, &t.Options) + cc := &charsetCollate{} + if col.Charset.Name != "" { + cc.charset = col.Charset.Name + } else if tc.charset != "" { + cc.charset = tc.charset + } else { + cc.charset = env.CollationEnv().LookupCharsetName(env.DefaultColl) + } + if col.Options != nil && col.Options.Collate != "" { + cc.collate = col.Options.Collate + } else if tc.collate != "" { + cc.collate = tc.collate + } else { + cc.collate = env.CollationEnv().LookupName(env.DefaultColl) + } + return cc +} + // Entities returns this schema's entities in good order (may be applied without error) func (s *Schema) Entities() []Entity { return s.sorted diff --git a/go/vt/schemadiff/schema_test.go b/go/vt/schemadiff/schema_test.go index 23782e676a1..c35cc224714 100644 --- a/go/vt/schemadiff/schema_test.go +++ b/go/vt/schemadiff/schema_test.go @@ -480,6 +480,26 @@ func TestInvalidSchema(t *testing.T) { schema: "create table t10(id VARCHAR(50) charset utf8mb4 collate utf8mb4_0900_ai_ci primary key); create table t11 (id int primary key, i VARCHAR(100) charset utf8mb4 collate utf8mb4_general_ci, key ix(i), constraint f10 foreign key (i) references t10(id) on delete restrict)", expectErr: &ForeignKeyColumnTypeMismatchError{Table: "t11", Constraint: "f10", Column: "i", ReferencedTable: "t10", ReferencedColumn: "id"}, }, + { + schema: "create table post (id varchar(191) not null, `title` text, primary key (`id`)); create table post_fks (id varchar(191) not null, `post_id` varchar(191) collate utf8mb4_0900_ai_ci not null, primary key (id), constraint post_fk foreign key (post_id) references post (id)) charset utf8mb4, collate utf8mb4_0900_as_ci;", + }, + { + schema: "create table post (id varchar(191) not null, `title` text, primary key (`id`)); create table post_fks (id varchar(191) not null, `post_id` varchar(191) collate utf8mb4_0900_ai_ci not null, primary key (id), constraint post_fk foreign key (post_id) references post (id)) collate utf8mb4_0900_as_ci;", + }, + { + schema: "create table post (id varchar(191) not null, `title` text, primary key (`id`)); create table post_fks (id varchar(191) not null, `post_id` varchar(191) collate utf8mb4_0900_ai_ci not null, primary key (id), constraint post_fk foreign key (post_id) references post (id)) charset utf8mb4;", + }, + { + schema: "create table post (id varchar(191) not null, `title` text, primary key (`id`)); create table post_fks (id varchar(191) not null, `post_id` varchar(191), primary key (id), constraint post_fk foreign key (post_id) references post (id)) charset utf8mb4, collate utf8mb4_0900_as_ci;", + expectErr: &ForeignKeyColumnTypeMismatchError{Table: "post_fks", Constraint: "post_fk", Column: "post_id", ReferencedTable: "post", ReferencedColumn: "id"}, + }, + { + schema: "create table post (id varchar(191) charset utf8mb4 not null, `title` text, primary key (`id`)); create table post_fks (id varchar(191) not null, `post_id` varchar(191), primary key (id), constraint post_fk foreign key (post_id) references post (id)) charset utf8mb4, collate utf8mb4_0900_as_ci;", + expectErr: &ForeignKeyColumnTypeMismatchError{Table: "post_fks", Constraint: "post_fk", Column: "post_id", ReferencedTable: "post", ReferencedColumn: "id"}, + }, + { + schema: "create table post (id varchar(191) charset utf8mb4 not null, `title` text, primary key (`id`)); create table post_fks (id varchar(191) not null, `post_id` varchar(191) collate utf8mb4_0900_ai_ci, primary key (id), constraint post_fk foreign key (post_id) references post (id)) charset utf8mb4, collate utf8mb4_0900_as_ci;", + }, } for _, ts := range tt { t.Run(ts.schema, func(t *testing.T) {