Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry pick #1788

Merged
merged 10 commits into from
Sep 6, 2023
7 changes: 6 additions & 1 deletion sqle/driver/mysql/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

rulepkg "github.com/actiontech/sqle/sqle/driver/mysql/rule"
"github.com/actiontech/sqle/sqle/driver/mysql/session"
"github.com/actiontech/sqle/sqle/driver/mysql/util"
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
"github.com/actiontech/sqle/sqle/utils"
Expand Down Expand Up @@ -34,6 +35,7 @@ const (
)

const CheckInvalidErrorFormat = "预检查失败: %v"
const CheckInvalidError = "预检查失败"

func (i *MysqlDriverImpl) CheckInvalid(node ast.Node) error {
var err error
Expand Down Expand Up @@ -65,7 +67,10 @@ func (i *MysqlDriverImpl) CheckInvalid(node ast.Node) error {
case *ast.UnparsedStmt:
err = i.checkUnparsedStmt(stmt)
}
if err != nil {

if err != nil && session.IsParseShowCreateTableContentErr(err) {
return err // todo #1630 直接返回原始错误类型,方便跳过
} else if err != nil {
return fmt.Errorf(CheckInvalidErrorFormat, err)
}
return nil
Expand Down
14 changes: 12 additions & 2 deletions sqle/driver/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,10 @@ func (i *MysqlDriverImpl) audit(ctx context.Context, sql string) (*driverV2.Audi
} else {
err = i.CheckInvalid(nodes[0])
}
if err != nil {
if err != nil && session.IsParseShowCreateTableContentErr(err) {
i.Logger().Errorf("check invalid failed: %v", err)
i.result.Add(driverV2.RuleLevelWarn, CheckInvalidError, fmt.Sprintf(CheckInvalidErrorFormat, "解析建表语句失败,部分在线审核规则可能失效,请人工确认"))
} else if err != nil {
return nil, err
}

Expand Down Expand Up @@ -354,6 +357,11 @@ func (i *MysqlDriverImpl) audit(ctx context.Context, sql string) (*driverV2.Audi
}

if err := handler.Func(input); err != nil {
// todo #1630 临时跳过解析建表语句失败导致的规则
if session.IsParseShowCreateTableContentErr(err) {
i.Logger().Errorf("skip rule, rule_desc_name=%v rule_desc=%v err:%v", rule.Name, rule.Desc, err.Error())
continue
}
return nil, err
}
}
Expand Down Expand Up @@ -396,7 +404,9 @@ func (i *MysqlDriverImpl) audit(ctx context.Context, sql string) (*driverV2.Audi

// print osc
oscCommandLine, err := i.generateOSCCommandLine(nodes[0])
if err != nil {
if err != nil && session.IsParseShowCreateTableContentErr(err) {
i.Logger().Errorf("generate osc command failed: %v", err.Error()) // todo #1630 临时跳过创表语句解析错误
} else if err != nil {
return nil, err
}
if oscCommandLine != "" {
Expand Down
39 changes: 31 additions & 8 deletions sqle/driver/mysql/session/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type TableInfo struct {
// OriginalTable save parser object from db by query "show create table ...";
// using in inspect and generate rollback sql
OriginalTable *ast.CreateTableStmt
// OriginalTableError save the error about getting original table
OriginalTableError error // todo #1630 临时缓存错误,方便跳过解析建表语句的错误

//
MergedTable *ast.CreateTableStmt
Expand Down Expand Up @@ -370,6 +372,9 @@ func (c *Context) UpdateContext(node ast.Node) {
}
info.MergedTable, _ = util.MergeAlterToTable(oldTable, s)
info.AlterTables = append(info.AlterTables, s)
if info.MergedTable == nil || info.MergedTable.Table == nil {
return
}
// rename table
if s.Table.Name.String() != info.MergedTable.Table.Name.String() {
schemaName := c.GetSchemaName(s.Table)
Expand Down Expand Up @@ -503,6 +508,19 @@ func (c *Context) AddSystemVariable(name, value string) {
c.sysVars[name] = value
}

type ParseShowCreateTableContentError struct { // todo #1630 临时返回一个指定的错误类型,方便跳过解析建表语句的错误
Msg string
}

func (p *ParseShowCreateTableContentError) Error() string {
return fmt.Sprintf("parse show create table content failed: %v", p.Msg)
}

func IsParseShowCreateTableContentErr(err error) bool {
var target *ParseShowCreateTableContentError
return errors.As(err, &target)
}

// GetCreateTableStmt get create table stmtNode for db by query; if table not exist, return null.
func (c *Context) GetCreateTableStmt(stmt *ast.TableName) (*ast.CreateTableStmt, bool, error) {
exist, err := c.IsTableExist(stmt)
Expand All @@ -525,17 +543,21 @@ func (c *Context) GetCreateTableStmt(stmt *ast.TableName) (*ast.CreateTableStmt,
return nil, false, nil
}

if info.OriginalTableError != nil && IsParseShowCreateTableContentErr(info.OriginalTableError) { // todo #1630 临时减少解析失败时的调用次数
return nil, false, info.OriginalTableError
}
createTableSql, err := c.e.ShowCreateTable(utils.SupplementalQuotationMarks(stmt.Schema.String()), utils.SupplementalQuotationMarks(stmt.Name.String()))
if err != nil {
return nil, exist, err
}
createStmt, err := util.ParseCreateTableStmt(createTableSql)
if err != nil {
createStmt, errByMysqlParser := util.ParseCreateTableStmt(createTableSql)
if errByMysqlParser != nil {
//todo to be compatible with OceanBase-MySQL-Mode
log.Logger().Warnf("parse create table stmt failed. try to parse it as OB-MySQL-Mode. err:%v", err)
createStmt, err = c.parseObMysqlCreateTableSql(createTableSql)
log.Logger().Warnf("parse create table stmt failed. try to parse it with compatible method. err:%v", errByMysqlParser)
createStmt, err = c.parseCreateTableSqlCompatibly(createTableSql)
if err != nil {
return nil, exist, err
info.OriginalTableError = &ParseShowCreateTableContentError{Msg: errByMysqlParser.Error()}
return nil, exist, info.OriginalTableError
}
}
info.OriginalTable = createStmt
Expand Down Expand Up @@ -586,7 +608,7 @@ partition p15)
建表语句后半段是options,oceanbase mysql模式下的show create table结果返回的options中包含mysql不支持的options, 为了能解析, 方法将会倒着遍历建表语句, 每次找到右括号时截断后面的部分, 然后尝试解析一次, 直到解析成功, 此时剩余的建表语句将不在包含OB特有options

*/
func (c *Context) parseObMysqlCreateTableSql(createTableSql string) (*ast.CreateTableStmt, error) {
func (c *Context) parseCreateTableSqlCompatibly(createTableSql string) (*ast.CreateTableStmt, error) {
for i := len(createTableSql) - 1; i >= 0; i-- {
if createTableSql[i] == ')' {
stmt, err := util.ParseCreateTableStmt(createTableSql[0 : i+1])
Expand All @@ -595,8 +617,9 @@ func (c *Context) parseObMysqlCreateTableSql(createTableSql string) (*ast.Create
}
}
}

return nil, fmt.Errorf("convert OB MySQL create table sql failed")
errMsg := "parse create table sql with compatible method failed"
log.Logger().Errorf(errMsg)
return nil, errors.New(errMsg)
}

// GetCollationDatabase get collation database.
Expand Down
16 changes: 10 additions & 6 deletions sqle/driver/mysql/util/parser_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,13 +503,17 @@ func GetLimitCount(limit *ast.Limit, _default int64) (int64, error) {
}

func MergeAlterToTable(oldTable *ast.CreateTableStmt, alterTable *ast.AlterTableStmt) (*ast.CreateTableStmt, error) {
newTable := &ast.CreateTableStmt{
Table: oldTable.Table,
Cols: oldTable.Cols,
Constraints: oldTable.Constraints,
Options: oldTable.Options,
Partition: oldTable.Partition,
newTable := &ast.CreateTableStmt{}
if oldTable != nil {
newTable = &ast.CreateTableStmt{
Table: oldTable.Table,
Cols: oldTable.Cols,
Constraints: oldTable.Constraints,
Options: oldTable.Options,
Partition: oldTable.Partition,
}
}

for _, spec := range GetAlterTableSpecByTp(alterTable.Specs, ast.AlterTableRenameTable) {
newTable.Table = spec.NewTable
}
Expand Down
6 changes: 5 additions & 1 deletion sqle/server/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"github.com/actiontech/sqle/sqle/driver"
"github.com/actiontech/sqle/sqle/driver/mysql/session"
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
"github.com/actiontech/sqle/sqle/model"
"github.com/actiontech/sqle/sqle/utils"
Expand Down Expand Up @@ -278,7 +279,10 @@ func genRollbackSQL(l *logrus.Entry, task *model.Task, p driver.Plugin) ([]*mode
rollbackSQLs := make([]*model.RollbackSQL, 0, len(task.ExecuteSQLs))
for _, executeSQL := range task.ExecuteSQLs {
rollbackSQL, reason, err := p.GenRollbackSQL(context.TODO(), executeSQL.Content)
if err != nil {
if err != nil && session.IsParseShowCreateTableContentErr(err) {
l.Errorf("gen rollback sql error, %v", err) // todo #1630 临时跳过创表语句解析错误
return nil, nil
} else if err != nil {
l.Errorf("gen rollback sql error, %v", err)
return nil, err
}
Expand Down
Loading