From 18c21b2f74b1f25c993c07259817896237da0915 Mon Sep 17 00:00:00 2001 From: Jinpeng Zhang Date: Mon, 18 Dec 2023 10:00:21 -0800 Subject: [PATCH] sqlmodel(*): optimize O(N*N) generated column check to O(N) (#10314) close pingcap/tiflow#10313 --- pkg/sqlmodel/multirow.go | 13 ++++++++++--- pkg/sqlmodel/row_change.go | 4 +++- pkg/sqlmodel/utils.go | 10 ++++++---- pkg/sqlmodel/utils_test.go | 30 ++++++++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 8 deletions(-) diff --git a/pkg/sqlmodel/multirow.go b/pkg/sqlmodel/multirow.go index 22bb4c81fd8..0cf5c5add4f 100644 --- a/pkg/sqlmodel/multirow.go +++ b/pkg/sqlmodel/multirow.go @@ -129,12 +129,16 @@ func GenUpdateSQL(changes ...*RowChange) (string, []any) { whenCaseStmts[i] = whereBuf.String() } + // Build generated columns lower name set to accelerate the following check + targetGeneratedColSet := generatedColumnsNameSet(first.targetTableInfo.Columns) + // Generate `ColumnName`=CASE WHEN .. THEN .. END // Use this value in order to identify which is the first CaseWhenThen line, // because generated column can happen any where and it will be skipped. 
isFirstCaseWhenThenLine := true for _, column := range first.targetTableInfo.Columns { - if isGenerated(first.targetTableInfo.Columns, column.Name) { + // skip generated columns + if _, ok := targetGeneratedColSet[column.Name.L]; ok { continue } if !isFirstCaseWhenThenLine { @@ -166,7 +170,7 @@ func GenUpdateSQL(changes ...*RowChange) (string, []any) { var assignValueColumnCount int var skipColIdx []int for i, col := range first.sourceTableInfo.Columns { - if isGenerated(first.targetTableInfo.Columns, col.Name) { + if _, ok := targetGeneratedColSet[col.Name.L]; ok { skipColIdx = append(skipColIdx, i) continue } @@ -235,8 +239,11 @@ func GenInsertSQL(tp DMLType, changes ...*RowChange) (string, []interface{}) { buf.WriteString(" (") columnNum := 0 var skipColIdx []int + + // build generated columns lower name set to accelerate the following check + generatedColumns := generatedColumnsNameSet(first.targetTableInfo.Columns) for i, col := range first.sourceTableInfo.Columns { - if isGenerated(first.targetTableInfo.Columns, col.Name) { + if _, ok := generatedColumns[col.Name.L]; ok { skipColIdx = append(skipColIdx, i) continue } diff --git a/pkg/sqlmodel/row_change.go b/pkg/sqlmodel/row_change.go index a4ae24444bb..a8fa76835b8 100644 --- a/pkg/sqlmodel/row_change.go +++ b/pkg/sqlmodel/row_change.go @@ -307,10 +307,12 @@ func (r *RowChange) genUpdateSQL() (string, []interface{}) { buf.WriteString(r.targetTable.QuoteString()) buf.WriteString(" SET ") + // Build target generated columns lower names set to accelerate following check + generatedColumns := generatedColumnsNameSet(r.targetTableInfo.Columns) args := make([]interface{}, 0, len(r.preValues)+len(r.postValues)) writtenFirstCol := false for i, col := range r.sourceTableInfo.Columns { - if isGenerated(r.targetTableInfo.Columns, col.Name) { + if _, ok := generatedColumns[col.Name.L]; ok { continue } diff --git a/pkg/sqlmodel/utils.go b/pkg/sqlmodel/utils.go index 1d70c35d25c..e7db35ae55d 100644 --- 
a/pkg/sqlmodel/utils.go +++ b/pkg/sqlmodel/utils.go @@ -50,13 +50,15 @@ func valuesHolder(n int) string { return builder.String() } -func isGenerated(columns []*timodel.ColumnInfo, name timodel.CIStr) bool { +// generatedColumnsNameSet returns a set of generated columns' name. +func generatedColumnsNameSet(columns []*timodel.ColumnInfo) map[string]struct{} { + m := make(map[string]struct{}) for _, col := range columns { - if col.Name.L == name.L { - return col.IsGenerated() + if col.IsGenerated() { + m[col.Name.L] = struct{}{} } } - return false + return m } // ColValAsStr convert column value as string diff --git a/pkg/sqlmodel/utils_test.go b/pkg/sqlmodel/utils_test.go index dcda5ec7819..c7c5fa27a81 100644 --- a/pkg/sqlmodel/utils_test.go +++ b/pkg/sqlmodel/utils_test.go @@ -16,6 +16,7 @@ package sqlmodel import ( "testing" + timodel "github.com/pingcap/tidb/pkg/parser/model" "github.com/shopspring/decimal" "github.com/stretchr/testify/require" ) @@ -40,3 +41,32 @@ func TestValidatorGenColData(t *testing.T) { res = ColValAsStr(decimal.NewFromInt(222123123)) require.Equal(t, "222123123", res) } + +func TestGeneratedColumnsNameSet(t *testing.T) { + t.Parallel() + + cols := []*timodel.ColumnInfo{ + { + Name: timodel.CIStr{O: "A", L: "a"}, + GeneratedExprString: "generated_expr", + }, + { + Name: timodel.CIStr{O: "B", L: "b"}, + }, + { + Name: timodel.CIStr{O: "C", L: "c"}, + GeneratedExprString: "generated_expr", + }, + { + Name: timodel.CIStr{O: "D", L: "d"}, + GeneratedExprString: "generated_expr", + }, + } + + m := generatedColumnsNameSet(cols) + require.Equal(t, map[string]struct{}{ + "a": {}, + "c": {}, + "d": {}, + }, m) +}