diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index f289438a1c9..e2205a6f6a8 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -34,6 +34,7 @@ type analyzer struct { binder *binder typer *typer rewriter *earlyRewriter + fk *fkManager sig QuerySignature si SchemaInformation currentDb string @@ -78,6 +79,12 @@ func (a *analyzer) lateInit() { reAnalyze: a.reAnalyze, tables: a.tables, } + a.fk = &fkManager{ + binder: a.binder, + tables: a.tables, + si: a.si, + getError: a.getError, + } } // Analyze analyzes the parsed query. @@ -142,7 +149,6 @@ func (a *analyzer) newSemTable( ColumnEqualities: map[columnName][]sqlparser.Expr{}, ExpandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{}, columns: map[*sqlparser.Union]sqlparser.SelectExprs{}, - comparator: nil, StatementIDs: a.scoper.statementIDs, QuerySignature: QuerySignature{}, childForeignKeysInvolved: map[TableSet][]vindexes.ChildFKInfo{}, @@ -157,7 +163,7 @@ func (a *analyzer) newSemTable( columns[union] = info.exprs } - childFks, parentFks, childFkToUpdExprs, err := a.getInvolvedForeignKeys(statement, fkChecksState) + childFks, parentFks, childFkToUpdExprs, err := a.fk.getInvolvedForeignKeys(statement, fkChecksState) if err != nil { return nil, err } @@ -448,149 +454,6 @@ func (a *analyzer) noteQuerySignature(node sqlparser.SQLNode) { } } -// getInvolvedForeignKeys gets the foreign keys that might require taking care off when executing the given statement. -func (a *analyzer) getInvolvedForeignKeys(statement sqlparser.Statement, fkChecksState *bool) (map[TableSet][]vindexes.ChildFKInfo, map[TableSet][]vindexes.ParentFKInfo, map[string]sqlparser.UpdateExprs, error) { - if fkChecksState != nil && !*fkChecksState { - return nil, nil, nil, nil - } - // There are only the DML statements that require any foreign keys handling. - switch stmt := statement.(type) { - case *sqlparser.Delete: - // For DELETE statements, none of the parent foreign keys require handling. - // So we collect all the child foreign keys. - allChildFks, _, err := a.getAllManagedForeignKeys() - return allChildFks, nil, nil, err - case *sqlparser.Insert: - // For INSERT statements, we have 3 different cases: - // 1. REPLACE statement: REPLACE statements are essentially DELETEs and INSERTs rolled into one. - // So we need to the parent foreign keys to ensure we are inserting the correct values, and the child foreign keys - // to ensure we don't change a row that breaks the constraint or cascade any operations on the child tables. - // 2. Normal INSERT statement: We don't need to check anything on the child foreign keys, so we just get all the parent foreign keys. - // 3. INSERT with ON DUPLICATE KEY UPDATE: This might trigger an update on the columns specified in the ON DUPLICATE KEY UPDATE clause. - allChildFks, allParentFKs, err := a.getAllManagedForeignKeys() - if err != nil { - return nil, nil, nil, err - } - if stmt.Action == sqlparser.ReplaceAct { - return allChildFks, allParentFKs, nil, nil - } - if len(stmt.OnDup) == 0 { - return nil, allParentFKs, nil, nil - } - // If only a certain set of columns are being updated, then there might be some child foreign keys that don't need any consideration since their columns aren't being updated. - // So, we filter these child foreign keys out. We can't filter any parent foreign keys because the statement will INSERT a row too, which requires validating all the parent foreign keys. - updatedChildFks, _, childFkToUpdExprs, err := a.filterForeignKeysUsingUpdateExpressions(allChildFks, nil, sqlparser.UpdateExprs(stmt.OnDup)) - return updatedChildFks, allParentFKs, childFkToUpdExprs, err - case *sqlparser.Update: - // For UPDATE queries we get all the parent and child foreign keys, but we can filter some of them out if the columns that they consist off aren't being updated or are set to NULLs. - allChildFks, allParentFks, err := a.getAllManagedForeignKeys() - if err != nil { - return nil, nil, nil, err - } - return a.filterForeignKeysUsingUpdateExpressions(allChildFks, allParentFks, stmt.Exprs) - default: - return nil, nil, nil, nil - } -} - -// filterForeignKeysUsingUpdateExpressions filters the child and parent foreign key constraints that don't require any validations/cascades given the updated expressions. -func (a *analyzer) filterForeignKeysUsingUpdateExpressions(allChildFks map[TableSet][]vindexes.ChildFKInfo, allParentFks map[TableSet][]vindexes.ParentFKInfo, updExprs sqlparser.UpdateExprs) (map[TableSet][]vindexes.ChildFKInfo, map[TableSet][]vindexes.ParentFKInfo, map[string]sqlparser.UpdateExprs, error) { - if len(allChildFks) == 0 && len(allParentFks) == 0 { - return nil, nil, nil, nil - } - - pFksRequired := make(map[TableSet][]bool, len(allParentFks)) - cFksRequired := make(map[TableSet][]bool, len(allChildFks)) - for ts, fks := range allParentFks { - pFksRequired[ts] = make([]bool, len(fks)) - } - for ts, fks := range allChildFks { - cFksRequired[ts] = make([]bool, len(fks)) - } - - // updExprToTableSet stores the tables that the updated expressions are from. - updExprToTableSet := make(map[*sqlparser.ColName]TableSet) - - // childFKToUpdExprs stores child foreign key to update expressions mapping. - childFKToUpdExprs := map[string]sqlparser.UpdateExprs{} - - // Go over all the update expressions - for _, updateExpr := range updExprs { - deps := a.binder.direct.dependencies(updateExpr.Name) - if deps.NumberOfTables() != 1 { - // If we don't get exactly one table for the given update expression, we would have definitely run into an error - // during the binder phase that we would have stored. We should return that error, since we can't safely proceed with - // foreign key related changes without having all the information. - return nil, nil, nil, a.getError() - } - updExprToTableSet[updateExpr.Name] = deps - // Get all the child and parent foreign keys for the given table that the update expression belongs to. - childFks := allChildFks[deps] - parentFKs := allParentFks[deps] - - // Any foreign key to a child table for a column that has been updated - // will require the cascade operations or restrict verification to happen, so we include all such foreign keys. - for idx, childFk := range childFks { - if childFk.ParentColumns.FindColumn(updateExpr.Name.Name) >= 0 { - cFksRequired[deps][idx] = true - tbl, _ := a.tables.tableInfoFor(deps) - ue := childFKToUpdExprs[childFk.String(tbl.GetVindexTable())] - ue = append(ue, updateExpr) - childFKToUpdExprs[childFk.String(tbl.GetVindexTable())] = ue - } - } - // If we are setting a column to NULL, then we don't need to verify the existence of an - // equivalent row in the parent table, even if this column was part of a foreign key to a parent table. - if sqlparser.IsNull(updateExpr.Expr) { - continue - } - // We add all the possible parent foreign key constraints that need verification that an equivalent row - // exists, given that this column has changed. - for idx, parentFk := range parentFKs { - if parentFk.ChildColumns.FindColumn(updateExpr.Name.Name) >= 0 { - pFksRequired[deps][idx] = true - } - } - } - // For the parent foreign keys, if any of the columns part of the fk is set to NULL, - // then, we don't care for the existence of an equivalent row in the parent table. - for _, updateExpr := range updExprs { - if !sqlparser.IsNull(updateExpr.Expr) { - continue - } - ts := updExprToTableSet[updateExpr.Name] - parentFKs := allParentFks[ts] - for idx, parentFk := range parentFKs { - if parentFk.ChildColumns.FindColumn(updateExpr.Name.Name) >= 0 { - pFksRequired[ts][idx] = false - } - } - } - - // Create new maps with only the required foreign keys. - pFksNeedsHandling := map[TableSet][]vindexes.ParentFKInfo{} - cFksNeedsHandling := map[TableSet][]vindexes.ChildFKInfo{} - for ts, parentFks := range allParentFks { - var pFKNeeded []vindexes.ParentFKInfo - for idx, fk := range parentFks { - if pFksRequired[ts][idx] { - pFKNeeded = append(pFKNeeded, fk) - } - } - pFksNeedsHandling[ts] = pFKNeeded - } - for ts, childFks := range allChildFks { - var cFKNeeded []vindexes.ChildFKInfo - for idx, fk := range childFks { - if cFksRequired[ts][idx] { - cFKNeeded = append(cFKNeeded, fk) - } - } - cFksNeedsHandling[ts] = cFKNeeded - } - return cFksNeedsHandling, pFksNeedsHandling, childFKToUpdExprs, nil -} - // getError gets the error stored in the analyzer during previous phases. func (a *analyzer) getError() error { if a.projErr != nil { @@ -602,40 +465,6 @@ func (a *analyzer) getError() error { return a.err } -// getAllManagedForeignKeys gets all the foreign keys for the query we are analyzing that Vitess is responsible for managing. -func (a *analyzer) getAllManagedForeignKeys() (map[TableSet][]vindexes.ChildFKInfo, map[TableSet][]vindexes.ParentFKInfo, error) { - allChildFKs := make(map[TableSet][]vindexes.ChildFKInfo) - allParentFKs := make(map[TableSet][]vindexes.ParentFKInfo) - - // Go over all the tables and collect the foreign keys. - for idx, table := range a.tables.Tables { - vi := table.GetVindexTable() - if vi == nil || vi.Keyspace == nil { - // If is not a real table, so should be skipped. - continue - } - // Check whether Vitess needs to manage the foreign keys in this keyspace or not. - fkMode, err := a.si.ForeignKeyMode(vi.Keyspace.Name) - if err != nil { - return nil, nil, err - } - if fkMode != vschemapb.Keyspace_managed { - continue - } - // Cyclic foreign key constraints error is stored in the keyspace. - ksErr := a.si.KeyspaceError(vi.Keyspace.Name) - if ksErr != nil { - return nil, nil, ksErr - } - - // Add all the child and parent foreign keys to our map. - ts := SingleTableSet(idx) - allChildFKs[ts] = vi.ChildForeignKeys - allParentFKs[ts] = vi.ParentForeignKeys - } - return allChildFKs, allParentFKs, nil -} - // ProjError is used to mark an error as something that should only be returned // if the planner fails to merge everything down to a single route type ProjError struct { diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index 9d91f6523cf..f93dd579898 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -57,73 +57,93 @@ func newBinder(scoper *scoper, org originable, tc *tableCollector, typer *typer) } func (b *binder) up(cursor *sqlparser.Cursor) error { - node := cursor.Node() - switch node := node.(type) { + switch node := cursor.Node().(type) { case *sqlparser.Subquery: - currScope := b.scoper.currentScope() - b.setSubQueryDependencies(node, currScope) + return b.setSubQueryDependencies(node) case *sqlparser.JoinCondition: - currScope := b.scoper.currentScope() - for _, ident := range node.Using { - name := sqlparser.NewColName(ident.String()) - deps, err := b.resolveColumn(name, currScope, true, true) - if err != nil { - return err - } - currScope.joinUsing[ident.Lowered()] = deps.direct - } + return b.bindJoinCondition(node) case *sqlparser.ColName: - currentScope := b.scoper.currentScope() - deps, err := b.resolveColumn(node, currentScope, false, true) - if err != nil { - if deps.direct.IsEmpty() || - !strings.HasSuffix(err.Error(), "is ambiguous") || - !b.canRewriteUsingJoin(deps, node) { - return err - } - - // if we got here it means we are dealing with a ColName that is involved in a JOIN USING. - // we do the rewriting of these ColName structs here because it would be difficult to copy all the - // needed state over to the earlyRewriter - deps, err = b.rewriteJoinUsingColName(deps, node, currentScope) - if err != nil { - return err - } - } - b.recursive[node] = deps.recursive - b.direct[node] = deps.direct - if deps.typ.Valid() { - b.typer.setTypeFor(node, deps.typ) - } + return b.bindColName(node) case *sqlparser.CountStar: - b.bindCountStar(node) + return b.bindCountStar(node) case *sqlparser.Union: - info := b.tc.unionInfo[node] - // TODO: this check can be removed and available type information should be used. - if !info.isAuthoritative { - return nil + return b.bindUnion(node) + case sqlparser.TableNames: + return b.bindTableNames(cursor, node) + default: + return nil + } +} + +func (b *binder) bindTableNames(cursor *sqlparser.Cursor, tables sqlparser.TableNames) error { + _, isDelete := cursor.Parent().(*sqlparser.Delete) + if !isDelete { + return nil + } + current := b.scoper.currentScope() + for _, target := range tables { + finalDep, err := b.findDependentTableSet(current, target) + if err != nil { + return err } + b.targets[target.Name] = finalDep.direct + } + return nil +} - for i, expr := range info.exprs { - ae := expr.(*sqlparser.AliasedExpr) - b.recursive[ae.Expr] = info.recursive[i] - if t := info.types[i]; t.Valid() { - b.typer.m[ae.Expr] = t - } +func (b *binder) bindUnion(union *sqlparser.Union) error { + info := b.tc.unionInfo[union] + // TODO: this check can be removed and available type information should be used. + if !info.isAuthoritative { + return nil + } + + for i, expr := range info.exprs { + ae := expr.(*sqlparser.AliasedExpr) + b.recursive[ae.Expr] = info.recursive[i] + if t := info.types[i]; t.Valid() { + b.typer.m[ae.Expr] = t } - case sqlparser.TableNames: - _, isDelete := cursor.Parent().(*sqlparser.Delete) - if !isDelete { - return nil + } + return nil +} + +func (b *binder) bindColName(col *sqlparser.ColName) error { + currentScope := b.scoper.currentScope() + deps, err := b.resolveColumn(col, currentScope, false, true) + if err != nil { + s := err.Error() + if deps.direct.IsEmpty() || + !strings.HasSuffix(s, "is ambiguous") || + !b.canRewriteUsingJoin(deps, col) { + return err } - current := b.scoper.currentScope() - for _, target := range node { - finalDep, err := b.findDependentTableSet(current, target) - if err != nil { - return err - } - b.targets[target.Name] = finalDep.direct + + // if we got here it means we are dealing with a ColName that is involved in a JOIN USING. + // we do the rewriting of these ColName structs here because it would be difficult to copy all the + // needed state over to the earlyRewriter + deps, err = b.rewriteJoinUsingColName(deps, col, currentScope) + if err != nil { + return err + } + } + b.recursive[col] = deps.recursive + b.direct[col] = deps.direct + if deps.typ.Valid() { + b.typer.setTypeFor(col, deps.typ) + } + return nil +} + +func (b *binder) bindJoinCondition(condition *sqlparser.JoinCondition) error { + currScope := b.scoper.currentScope() + for _, ident := range condition.Using { + name := sqlparser.NewColName(ident.String()) + deps, err := b.resolveColumn(name, currScope, true, true) + if err != nil { + return err } + currScope.joinUsing[ident.Lowered()] = deps.direct } return nil } @@ -142,7 +162,7 @@ func (b *binder) findDependentTableSet(current *scope, target sqlparser.TableNam c := createCertain(ts, ts, evalengine.Type{}) deps = deps.merge(c, false) } - finalDep, err := deps.get() + finalDep, err := deps.get(nil) if err != nil { return dependency{}, err } @@ -152,7 +172,7 @@ func (b *binder) findDependentTableSet(current *scope, target sqlparser.TableNam return finalDep, nil } -func (b *binder) bindCountStar(node *sqlparser.CountStar) { +func (b *binder) bindCountStar(node *sqlparser.CountStar) error { scope := b.scoper.currentScope() var ts TableSet for _, tbl := range scope.tables { @@ -169,6 +189,7 @@ func (b *binder) bindCountStar(node *sqlparser.CountStar) { } b.recursive[node] = ts b.direct[node] = ts + return nil } func (b *binder) rewriteJoinUsingColName(deps dependency, node *sqlparser.ColName, currentScope *scope) (dependency, error) { @@ -210,7 +231,8 @@ func (b *binder) canRewriteUsingJoin(deps dependency, node *sqlparser.ColName) b // the binder usually only sets the dependencies of ColNames, but we need to // handle the subquery dependencies differently, so they are set manually here // this method will only keep dependencies to tables outside the subquery -func (b *binder) setSubQueryDependencies(subq *sqlparser.Subquery, currScope *scope) { +func (b *binder) setSubQueryDependencies(subq *sqlparser.Subquery) error { + currScope := b.scoper.currentScope() subqRecursiveDeps := b.recursive.dependencies(subq) subqDirectDeps := b.direct.dependencies(subq) @@ -225,11 +247,12 @@ func (b *binder) setSubQueryDependencies(subq *sqlparser.Subquery, currScope *sc b.recursive[subq] = subqRecursiveDeps.KeepOnly(tablesToKeep) b.direct[subq] = subqDirectDeps.KeepOnly(tablesToKeep) + return nil } func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allowMulti, singleTableFallBack bool) (dependency, error) { if !current.stmtScope && current.inGroupBy { - return b.resolveColInGroupBy(colName, current, allowMulti, singleTableFallBack) + return b.resolveColInGroupBy(colName, current, allowMulti) } if !current.stmtScope && current.inHaving && !current.inHavingAggr { return b.resolveColumnInHaving(colName, current, allowMulti) @@ -243,11 +266,10 @@ func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allow var err error thisDeps, err = b.resolveColumnInScope(current, colName, allowMulti) if err != nil { - return dependency{}, makeAmbiguousError(colName, err) + return dependency{}, err } if !thisDeps.empty() { - deps, err := thisDeps.get() - return deps, makeAmbiguousError(colName, err) + return thisDeps.get(colName) } if current.parent == nil && len(current.tables) == 1 && @@ -294,16 +316,12 @@ func (b *binder) resolveColumnInHaving(colName *sqlparser.ColName, current *scop // Here we are searching among the SELECT expressions for a match thisDeps, err := b.resolveColumnInScope(current, colName, allowMulti) if err != nil { - return dependency{}, makeAmbiguousError(colName, err) + return dependency{}, err } if !thisDeps.empty() { // we found something! let's return it - deps, err := thisDeps.get() - if err != nil { - err = makeAmbiguousError(colName, err) - } - return deps, err + return thisDeps.get(colName) } notFoundErr := &ColumnNotFoundClauseError{Column: colName.Name.String(), Clause: "having clause"} @@ -376,7 +394,6 @@ func (b *binder) resolveColInGroupBy( colName *sqlparser.ColName, current *scope, allowMulti bool, - singleTableFallBack bool, ) (dependency, error) { if current.parent == nil { return dependency{}, vterrors.VT13001("did not expect this to be the last scope") @@ -408,7 +425,7 @@ func (b *binder) resolveColInGroupBy( } return deps, firstErr } - return dependencies.get() + return dependencies.get(colName) } func (b *binder) resolveColumnInScope(current *scope, expr *sqlparser.ColName, allowMulti bool) (dependencies, error) { @@ -425,18 +442,11 @@ func (b *binder) resolveColumnInScope(current *scope, expr *sqlparser.ColName, a } if deps, isUncertain := deps.(*uncertain); isUncertain && deps.fail { // if we have a failure from uncertain, we matched the column to multiple non-authoritative tables - return nil, ProjError{Inner: &AmbiguousColumnError{Column: sqlparser.String(expr)}} + return nil, ProjError{Inner: newAmbiguousColumnError(expr)} } return deps, nil } -func makeAmbiguousError(colName *sqlparser.ColName, err error) error { - if err == ambigousErr { - err = &AmbiguousColumnError{Column: sqlparser.String(colName)} - } - return err -} - // GetSubqueryAndOtherSide returns the subquery and other side of a comparison, iff one of the sides is a SubQuery func GetSubqueryAndOtherSide(node *sqlparser.ComparisonExpr) (*sqlparser.Subquery, sqlparser.Expr) { var subq *sqlparser.Subquery diff --git a/go/vt/vtgate/semantics/dependencies.go b/go/vt/vtgate/semantics/dependencies.go index 714fa97c2c4..70167ff02fc 100644 --- a/go/vt/vtgate/semantics/dependencies.go +++ b/go/vt/vtgate/semantics/dependencies.go @@ -18,8 +18,7 @@ package semantics import ( querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/evalengine" ) @@ -28,7 +27,7 @@ type ( // tables and figure out bindings and/or errors by merging dependencies together dependencies interface { empty() bool - get() (dependency, error) + get(col *sqlparser.ColName) (dependency, error) merge(other dependencies, allowMulti bool) dependencies } dependency struct { @@ -40,7 +39,7 @@ type ( nothing struct{} certain struct { dependency - err error + err bool } uncertain struct { dependency @@ -48,8 +47,6 @@ type ( } ) -var ambigousErr = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "ambiguous") - func createCertain(direct TableSet, recursive TableSet, qt evalengine.Type) *certain { c := &certain{ dependency: dependency{ @@ -82,9 +79,9 @@ func (u *uncertain) empty() bool { return false } -func (u *uncertain) get() (dependency, error) { +func (u *uncertain) get(col *sqlparser.ColName) (dependency, error) { if u.fail { - return dependency{}, ambigousErr + return dependency{}, newAmbiguousColumnError(col) } return u.dependency, nil } @@ -107,8 +104,11 @@ func (c *certain) empty() bool { return false } -func (c *certain) get() (dependency, error) { - return c.dependency, c.err +func (c *certain) get(col *sqlparser.ColName) (dependency, error) { + if c.err { + return c.dependency, newAmbiguousColumnError(col) + } + return c.dependency, nil } func (c *certain) merge(d dependencies, allowMulti bool) dependencies { @@ -120,7 +120,7 @@ func (c *certain) merge(d dependencies, allowMulti bool) dependencies { c.direct = c.direct.Merge(d.direct) c.recursive = c.recursive.Merge(d.recursive) if !allowMulti { - c.err = ambigousErr + c.err = true } return c @@ -133,7 +133,7 @@ func (n *nothing) empty() bool { return true } -func (n *nothing) get() (dependency, error) { +func (n *nothing) get(*sqlparser.ColName) (dependency, error) { return dependency{certain: true}, nil } diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 16ffc9ee019..a8e1442edb8 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -463,22 +463,12 @@ func (r *earlyRewriter) rewriteAliasesInGroupBy(node sqlparser.Expr, sel *sqlpar currentScope := r.scoper.currentScope() aliases := r.getAliasMap(sel) - insideAggr := false - downF := func(node, _ sqlparser.SQLNode) bool { - switch node.(type) { - case *sqlparser.Subquery: - return false - case sqlparser.AggrFunc: - insideAggr = true - } + aggrTrack := &aggrTracker{} - return true - } - - output := sqlparser.CopyOnRewrite(node, downF, func(cursor *sqlparser.CopyOnWriteCursor) { + output := sqlparser.CopyOnRewrite(node, aggrTrack.down, func(cursor *sqlparser.CopyOnWriteCursor) { switch col := cursor.Node().(type) { case sqlparser.AggrFunc: - insideAggr = false + aggrTrack.popAggr() case *sqlparser.ColName: if col.Qualifier.NonEmpty() { // we are only interested in columns not qualified by table names @@ -504,8 +494,8 @@ func (r *earlyRewriter) rewriteAliasesInGroupBy(node sqlparser.Expr, sel *sqlpar } if item.ambiguous { - err = &AmbiguousColumnError{Column: sqlparser.String(col)} - } else if insideAggr && sqlparser.ContainsAggregation(item.expr) { + err = newAmbiguousColumnError(col) + } else if aggrTrack.insideAggr && sqlparser.ContainsAggregation(item.expr) { err = &InvalidUseOfGroupFunction{} } if err != nil { @@ -529,23 +519,13 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars } aliases := r.getAliasMap(sel) - insideAggr := false - dontEnterSubquery := func(node, _ sqlparser.SQLNode) bool { - switch node.(type) { - case *sqlparser.Subquery: - return false - case sqlparser.AggrFunc: - insideAggr = true - } - - return true - } - output := sqlparser.CopyOnRewrite(node, dontEnterSubquery, func(cursor *sqlparser.CopyOnWriteCursor) { + aggrTrack := &aggrTracker{} + output := sqlparser.CopyOnRewrite(node, aggrTrack.down, func(cursor *sqlparser.CopyOnWriteCursor) { var col *sqlparser.ColName switch node := cursor.Node().(type) { case sqlparser.AggrFunc: - insideAggr = false + aggrTrack.popAggr() return case *sqlparser.ColName: col = node @@ -559,7 +539,7 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars } item, found := aliases[col.Name.Lowered()] - if insideAggr { + if aggrTrack.insideAggr { // inside aggregations, we want to first look for columns in the FROM clause isColumnOnTable, sure := r.isColumnOnTable(col, currentScope) if isColumnOnTable { @@ -576,8 +556,8 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars // If we get here, it means we have found an alias and want to use it if item.ambiguous { - err = &AmbiguousColumnError{Column: sqlparser.String(col)} - } else if insideAggr && sqlparser.ContainsAggregation(item.expr) { + err = newAmbiguousColumnError(col) + } else if aggrTrack.insideAggr && sqlparser.ContainsAggregation(item.expr) { err = &InvalidUseOfGroupFunction{} } if err != nil { @@ -594,6 +574,25 @@ func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlpars return } +type aggrTracker struct { + insideAggr bool +} + +func (at *aggrTracker) down(node, _ sqlparser.SQLNode) bool { + switch node.(type) { + case *sqlparser.Subquery: + return false + case sqlparser.AggrFunc: + at.insideAggr = true + } + + return true +} + +func (at *aggrTracker) popAggr() { + at.insideAggr = false +} + // rewriteAliasesInOrderBy rewrites columns in the ORDER BY to use aliases // from the SELECT expressions when applicable, following MySQL scoping rules: // - A column identifier without a table qualifier that matches an alias introduced @@ -608,23 +607,13 @@ func (r *earlyRewriter) rewriteAliasesInOrderBy(node sqlparser.Expr, sel *sqlpar } aliases := r.getAliasMap(sel) - insideAggr := false - dontEnterSubquery := func(node, _ sqlparser.SQLNode) bool { - switch node.(type) { - case *sqlparser.Subquery: - return false - case sqlparser.AggrFunc: - insideAggr = true - } - - return true - } - output := sqlparser.CopyOnRewrite(node, dontEnterSubquery, func(cursor *sqlparser.CopyOnWriteCursor) { + aggrTrack := &aggrTracker{} + output := sqlparser.CopyOnRewrite(node, aggrTrack.down, func(cursor *sqlparser.CopyOnWriteCursor) { var col *sqlparser.ColName switch node := cursor.Node().(type) { case sqlparser.AggrFunc: - insideAggr = false + aggrTrack.popAggr() return case *sqlparser.ColName: col = node @@ -661,8 +650,8 @@ func (r *earlyRewriter) rewriteAliasesInOrderBy(node sqlparser.Expr, sel *sqlpar } if item.ambiguous { - err = &AmbiguousColumnError{Column: sqlparser.String(col)} - } else if insideAggr && sqlparser.ContainsAggregation(item.expr) { + err = newAmbiguousColumnError(col) + } else if aggrTrack.insideAggr && sqlparser.ContainsAggregation(item.expr) { err = &InvalidUseOfGroupFunction{} } if err != nil { diff --git a/go/vt/vtgate/semantics/errors.go b/go/vt/vtgate/semantics/errors.go index 297f2b9613e..3a66a7adb24 100644 --- a/go/vt/vtgate/semantics/errors.go +++ b/go/vt/vtgate/semantics/errors.go @@ -88,6 +88,10 @@ func eprintf(e error, format string, args ...any) string { return fmt.Sprintf(format, args...) } +func newAmbiguousColumnError(name *sqlparser.ColName) error { + return &AmbiguousColumnError{Column: sqlparser.String(name)} +} + // Specific error implementations follow // UnionColumnsDoNotMatchError diff --git a/go/vt/vtgate/semantics/foreign_keys.go b/go/vt/vtgate/semantics/foreign_keys.go new file mode 100644 index 00000000000..4da2f5a232f --- /dev/null +++ b/go/vt/vtgate/semantics/foreign_keys.go @@ -0,0 +1,207 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package semantics + +import ( + vschemapb "vitess.io/vitess/go/vt/proto/vschema" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +type fkManager struct { + binder *binder + tables *tableCollector + si SchemaInformation + getError func() error +} + +// getInvolvedForeignKeys gets the foreign keys that might require taking care off when executing the given statement. +func (fk *fkManager) getInvolvedForeignKeys(statement sqlparser.Statement, fkChecksState *bool) (map[TableSet][]vindexes.ChildFKInfo, map[TableSet][]vindexes.ParentFKInfo, map[string]sqlparser.UpdateExprs, error) { + if fkChecksState != nil && !*fkChecksState { + return nil, nil, nil, nil + } + // There are only the DML statements that require any foreign keys handling. + switch stmt := statement.(type) { + case *sqlparser.Delete: + // For DELETE statements, none of the parent foreign keys require handling. + // So we collect all the child foreign keys. + allChildFks, _, err := fk.getAllManagedForeignKeys() + return allChildFks, nil, nil, err + case *sqlparser.Insert: + // For INSERT statements, we have 3 different cases: + // 1. REPLACE statement: REPLACE statements are essentially DELETEs and INSERTs rolled into one. + // So we need to the parent foreign keys to ensure we are inserting the correct values, and the child foreign keys + // to ensure we don't change a row that breaks the constraint or cascade any operations on the child tables. + // 2. Normal INSERT statement: We don't need to check anything on the child foreign keys, so we just get all the parent foreign keys. + // 3. INSERT with ON DUPLICATE KEY UPDATE: This might trigger an update on the columns specified in the ON DUPLICATE KEY UPDATE clause. + allChildFks, allParentFKs, err := fk.getAllManagedForeignKeys() + if err != nil { + return nil, nil, nil, err + } + if stmt.Action == sqlparser.ReplaceAct { + return allChildFks, allParentFKs, nil, nil + } + if len(stmt.OnDup) == 0 { + return nil, allParentFKs, nil, nil + } + // If only a certain set of columns are being updated, then there might be some child foreign keys that don't need any consideration since their columns aren't being updated. + // So, we filter these child foreign keys out. We can't filter any parent foreign keys because the statement will INSERT a row too, which requires validating all the parent foreign keys. + updatedChildFks, _, childFkToUpdExprs, err := fk.filterForeignKeysUsingUpdateExpressions(allChildFks, nil, sqlparser.UpdateExprs(stmt.OnDup)) + return updatedChildFks, allParentFKs, childFkToUpdExprs, err + case *sqlparser.Update: + // For UPDATE queries we get all the parent and child foreign keys, but we can filter some of them out if the columns that they consist off aren't being updated or are set to NULLs. + allChildFks, allParentFks, err := fk.getAllManagedForeignKeys() + if err != nil { + return nil, nil, nil, err + } + return fk.filterForeignKeysUsingUpdateExpressions(allChildFks, allParentFks, stmt.Exprs) + default: + return nil, nil, nil, nil + } +} + +// filterForeignKeysUsingUpdateExpressions filters the child and parent foreign key constraints that don't require any validations/cascades given the updated expressions. +func (fk *fkManager) filterForeignKeysUsingUpdateExpressions(allChildFks map[TableSet][]vindexes.ChildFKInfo, allParentFks map[TableSet][]vindexes.ParentFKInfo, updExprs sqlparser.UpdateExprs) (map[TableSet][]vindexes.ChildFKInfo, map[TableSet][]vindexes.ParentFKInfo, map[string]sqlparser.UpdateExprs, error) { + if len(allChildFks) == 0 && len(allParentFks) == 0 { + return nil, nil, nil, nil + } + + pFksRequired := make(map[TableSet][]bool, len(allParentFks)) + cFksRequired := make(map[TableSet][]bool, len(allChildFks)) + for ts, fks := range allParentFks { + pFksRequired[ts] = make([]bool, len(fks)) + } + for ts, fks := range allChildFks { + cFksRequired[ts] = make([]bool, len(fks)) + } + + // updExprToTableSet stores the tables that the updated expressions are from. + updExprToTableSet := make(map[*sqlparser.ColName]TableSet) + + // childFKToUpdExprs stores child foreign key to update expressions mapping. + childFKToUpdExprs := map[string]sqlparser.UpdateExprs{} + + // Go over all the update expressions + for _, updateExpr := range updExprs { + deps := fk.binder.direct.dependencies(updateExpr.Name) + if deps.NumberOfTables() != 1 { + // If we don't get exactly one table for the given update expression, we would have definitely run into an error + // during the binder phase that we would have stored. We should return that error, since we can't safely proceed with + // foreign key related changes without having all the information. + return nil, nil, nil, fk.getError() + } + updExprToTableSet[updateExpr.Name] = deps + // Get all the child and parent foreign keys for the given table that the update expression belongs to. + childFks := allChildFks[deps] + parentFKs := allParentFks[deps] + + // Any foreign key to a child table for a column that has been updated + // will require the cascade operations or restrict verification to happen, so we include all such foreign keys. + for idx, childFk := range childFks { + if childFk.ParentColumns.FindColumn(updateExpr.Name.Name) >= 0 { + cFksRequired[deps][idx] = true + tbl, _ := fk.tables.tableInfoFor(deps) + ue := childFKToUpdExprs[childFk.String(tbl.GetVindexTable())] + ue = append(ue, updateExpr) + childFKToUpdExprs[childFk.String(tbl.GetVindexTable())] = ue + } + } + // If we are setting a column to NULL, then we don't need to verify the existence of an + // equivalent row in the parent table, even if this column was part of a foreign key to a parent table. + if sqlparser.IsNull(updateExpr.Expr) { + continue + } + // We add all the possible parent foreign key constraints that need verification that an equivalent row + // exists, given that this column has changed. + for idx, parentFk := range parentFKs { + if parentFk.ChildColumns.FindColumn(updateExpr.Name.Name) >= 0 { + pFksRequired[deps][idx] = true + } + } + } + // For the parent foreign keys, if any of the columns part of the fk is set to NULL, + // then, we don't care for the existence of an equivalent row in the parent table. + for _, updateExpr := range updExprs { + if !sqlparser.IsNull(updateExpr.Expr) { + continue + } + ts := updExprToTableSet[updateExpr.Name] + parentFKs := allParentFks[ts] + for idx, parentFk := range parentFKs { + if parentFk.ChildColumns.FindColumn(updateExpr.Name.Name) >= 0 { + pFksRequired[ts][idx] = false + } + } + } + + // Create new maps with only the required foreign keys. + pFksNeedsHandling := map[TableSet][]vindexes.ParentFKInfo{} + cFksNeedsHandling := map[TableSet][]vindexes.ChildFKInfo{} + for ts, parentFks := range allParentFks { + var pFKNeeded []vindexes.ParentFKInfo + for idx, fk := range parentFks { + if pFksRequired[ts][idx] { + pFKNeeded = append(pFKNeeded, fk) + } + } + pFksNeedsHandling[ts] = pFKNeeded + } + for ts, childFks := range allChildFks { + var cFKNeeded []vindexes.ChildFKInfo + for idx, fk := range childFks { + if cFksRequired[ts][idx] { + cFKNeeded = append(cFKNeeded, fk) + } + } + cFksNeedsHandling[ts] = cFKNeeded + } + return cFksNeedsHandling, pFksNeedsHandling, childFKToUpdExprs, nil +} + +// getAllManagedForeignKeys gets all the foreign keys for the query we are analyzing that Vitess is responsible for managing. +func (fk *fkManager) getAllManagedForeignKeys() (map[TableSet][]vindexes.ChildFKInfo, map[TableSet][]vindexes.ParentFKInfo, error) { + allChildFKs := make(map[TableSet][]vindexes.ChildFKInfo) + allParentFKs := make(map[TableSet][]vindexes.ParentFKInfo) + + // Go over all the tables and collect the foreign keys. + for idx, table := range fk.tables.Tables { + vi := table.GetVindexTable() + if vi == nil || vi.Keyspace == nil { + // If is not a real table, so should be skipped. + continue + } + // Check whether Vitess needs to manage the foreign keys in this keyspace or not. + fkMode, err := fk.si.ForeignKeyMode(vi.Keyspace.Name) + if err != nil { + return nil, nil, err + } + if fkMode != vschemapb.Keyspace_managed { + continue + } + // Cyclic foreign key constraints error is stored in the keyspace. + ksErr := fk.si.KeyspaceError(vi.Keyspace.Name) + if ksErr != nil { + return nil, nil, ksErr + } + + // Add all the child and parent foreign keys to our map. + ts := SingleTableSet(idx) + allChildFKs[ts] = vi.ChildForeignKeys + allParentFKs[ts] = vi.ParentForeignKeys + } + return allChildFKs, allParentFKs, nil +} diff --git a/go/vt/vtgate/semantics/analyzer_fk_test.go b/go/vt/vtgate/semantics/foreign_keys_test.go similarity index 95% rename from go/vt/vtgate/semantics/analyzer_fk_test.go rename to go/vt/vtgate/semantics/foreign_keys_test.go index 05a5991b49f..e1c26ecf569 100644 --- a/go/vt/vtgate/semantics/analyzer_fk_test.go +++ b/go/vt/vtgate/semantics/foreign_keys_test.go @@ -133,14 +133,14 @@ var tbl = map[string]TableInfo{ func TestGetAllManagedForeignKeys(t *testing.T) { tests := []struct { name string - analyzer *analyzer + fkManager *fkManager childFkWanted map[TableSet][]vindexes.ChildFKInfo parentFkWanted map[TableSet][]vindexes.ParentFKInfo expectedErr string }{ { name: "Collect all foreign key constraints", - analyzer: &analyzer{ + fkManager: &fkManager{ tables: &tableCollector{ Tables: []TableInfo{ tbl["t0"], @@ -170,7 +170,7 @@ func TestGetAllManagedForeignKeys(t *testing.T) { }, { name: "keyspace not found in schema information", - analyzer: &analyzer{ + fkManager: &fkManager{ tables: &tableCollector{ Tables: []TableInfo{ tbl["t2"], @@ -187,7 +187,7 @@ func TestGetAllManagedForeignKeys(t *testing.T) { }, { name: "Cyclic fk constraints error", - analyzer: &analyzer{ + fkManager: &fkManager{ tables: &tableCollector{ Tables: []TableInfo{ tbl["t0"], tbl["t1"], @@ -209,7 +209,7 @@ func TestGetAllManagedForeignKeys(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - childFk, parentFk, err := tt.analyzer.getAllManagedForeignKeys() + childFk, parentFk, err := tt.fkManager.getAllManagedForeignKeys() if tt.expectedErr != "" { require.EqualError(t, err, tt.expectedErr) return @@ -226,7 +226,7 @@ func TestFilterForeignKeysUsingUpdateExpressions(t *testing.T) { colb := sqlparser.NewColName("colb") colc := sqlparser.NewColName("colc") cold := sqlparser.NewColName("cold") - a := &analyzer{ + a := &fkManager{ binder: &binder{ direct: map[sqlparser.Expr]TableSet{ cola: SingleTableSet(0), @@ -235,7 +235,7 @@ func TestFilterForeignKeysUsingUpdateExpressions(t *testing.T) { cold: SingleTableSet(1), }, }, - unshardedErr: fmt.Errorf("ambiguous test error"), + getError: func() error { return fmt.Errorf("ambiguous test error") }, tables: &tableCollector{ Tables: []TableInfo{ tbl["t4"], @@ -256,7 +256,7 @@ func TestFilterForeignKeysUsingUpdateExpressions(t *testing.T) { } tests := []struct { name string - analyzer *analyzer + fkManager *fkManager allChildFks map[TableSet][]vindexes.ChildFKInfo allParentFks map[TableSet][]vindexes.ParentFKInfo updExprs sqlparser.UpdateExprs @@ -266,7 +266,7 @@ func TestFilterForeignKeysUsingUpdateExpressions(t *testing.T) { }{ { name: "Child Foreign Keys Filtering", - analyzer: a, + fkManager: a, allParentFks: nil, allChildFks: map[TableSet][]vindexes.ChildFKInfo{ SingleTableSet(0): tbl["t4"].(*RealTable).Table.ChildForeignKeys, @@ -285,8 +285,8 @@ func TestFilterForeignKeysUsingUpdateExpressions(t *testing.T) { }, parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{}, }, { - name: "Parent Foreign Keys Filtering", - analyzer: a, + name: "Parent Foreign Keys Filtering", + fkManager: a, allParentFks: map[TableSet][]vindexes.ParentFKInfo{ SingleTableSet(0): tbl["t4"].(*RealTable).Table.ParentForeignKeys, SingleTableSet(1): tbl["t5"].(*RealTable).Table.ParentForeignKeys, @@ -304,8 +304,8 @@ func TestFilterForeignKeysUsingUpdateExpressions(t *testing.T) { }, }, }, { - name: "Unknown column", - analyzer: a, + name: "Unknown column", + fkManager: a, allParentFks: map[TableSet][]vindexes.ParentFKInfo{ SingleTableSet(0): tbl["t4"].(*RealTable).Table.ParentForeignKeys, SingleTableSet(1): tbl["t5"].(*RealTable).Table.ParentForeignKeys, @@ -319,7 +319,7 @@ func TestFilterForeignKeysUsingUpdateExpressions(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - childFks, parentFks, _, err := tt.analyzer.filterForeignKeysUsingUpdateExpressions(tt.allChildFks, tt.allParentFks, tt.updExprs) + childFks, parentFks, _, err := tt.fkManager.filterForeignKeysUsingUpdateExpressions(tt.allChildFks, tt.allParentFks, tt.updExprs) require.EqualValues(t, tt.childFksWanted, childFks) require.EqualValues(t, tt.parentFksWanted, parentFks) if tt.errWanted != "" { @@ -340,7 +340,7 @@ func TestGetInvolvedForeignKeys(t *testing.T) { tests := []struct { name string stmt sqlparser.Statement - analyzer *analyzer + fkManager *fkManager childFksWanted map[TableSet][]vindexes.ChildFKInfo parentFksWanted map[TableSet][]vindexes.ParentFKInfo childFkUpdateExprsWanted map[string]sqlparser.UpdateExprs @@ -349,7 +349,7 @@ func TestGetInvolvedForeignKeys(t *testing.T) { { name: "Delete Query", stmt: &sqlparser.Delete{}, - analyzer: &analyzer{ + fkManager: &fkManager{ tables: &tableCollector{ Tables: []TableInfo{ tbl["t0"], @@ -380,7 +380,7 @@ func TestGetInvolvedForeignKeys(t *testing.T) { &sqlparser.UpdateExpr{Name: cold, Expr: &sqlparser.NullVal{}}, }, }, - analyzer: &analyzer{ + fkManager: &fkManager{ binder: &binder{ direct: map[sqlparser.Expr]TableSet{ cola: SingleTableSet(0), @@ -432,7 +432,7 @@ func TestGetInvolvedForeignKeys(t *testing.T) { stmt: &sqlparser.Insert{ Action: sqlparser.ReplaceAct, }, - analyzer: &analyzer{ + fkManager: &fkManager{ tables: &tableCollector{ Tables: []TableInfo{ tbl["t0"], @@ -464,7 +464,7 @@ func TestGetInvolvedForeignKeys(t *testing.T) { stmt: &sqlparser.Insert{ Action: sqlparser.InsertAct, }, - analyzer: &analyzer{ + fkManager: &fkManager{ tables: &tableCollector{ Tables: []TableInfo{ tbl["t0"], @@ -495,7 +495,7 @@ func TestGetInvolvedForeignKeys(t *testing.T) { &sqlparser.UpdateExpr{Name: colb, Expr: &sqlparser.NullVal{}}, }, }, - analyzer: &analyzer{ + fkManager: &fkManager{ binder: &binder{ direct: map[sqlparser.Expr]TableSet{ cola: SingleTableSet(0), @@ -535,7 +535,7 @@ func TestGetInvolvedForeignKeys(t *testing.T) { { name: "Insert error", stmt: &sqlparser.Insert{}, - analyzer: &analyzer{ + fkManager: &fkManager{ tables: &tableCollector{ Tables: []TableInfo{ tbl["t2"], @@ -553,7 +553,7 @@ func TestGetInvolvedForeignKeys(t *testing.T) { { name: "Update error", stmt: &sqlparser.Update{}, - analyzer: &analyzer{ + fkManager: &fkManager{ tables: &tableCollector{ Tables: []TableInfo{ tbl["t2"], @@ -572,7 +572,7 @@ func TestGetInvolvedForeignKeys(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { fkState := true - childFks, parentFks, childFkUpdateExprs, err := tt.analyzer.getInvolvedForeignKeys(tt.stmt, &fkState) + childFks, parentFks, childFkUpdateExprs, err := tt.fkManager.getInvolvedForeignKeys(tt.stmt, &fkState) if tt.expectedErr != "" { require.EqualError(t, err, tt.expectedErr) return