diff --git a/cmd/gobatis/generator/generator.go b/cmd/gobatis/generator/generator.go index f7265ad..576ed64 100644 --- a/cmd/gobatis/generator/generator.go +++ b/cmd/gobatis/generator/generator.go @@ -242,6 +242,58 @@ func (cmd *Generator) generateInterface(out io.Writer, file *goparser2.File, itf func (cmd *Generator) generateInterfaceInit(out io.Writer, file *goparser2.File, itf *goparser2.Interface) error { io.WriteString(out, "\r\n\r\n"+`func init() { gobatis.Init(func(ctx *gobatis.InitContext) error {`) + if len(itf.SqlFragments) > 0 { + + var recordTypeName string + recordType := itf.DetectRecordType(nil, false) + if recordType != nil { + recordTypeName = recordType.ToLiteral() + } + if recordTypeName == "" { + recordTypeName = "xxxxxxxxxxxxxxxxxxxx" + } + + io.WriteString(out, "\r\n"+` var sqlExpressions = ctx.SqlExpressions`) + io.WriteString(out, "\r\n"+` ctx.SqlExpressions = map[string]*gobatis.SqlExpression{}`) + io.WriteString(out, "\r\n"+` for id, expr := range sqlExpressions {`) + io.WriteString(out, "\r\n"+` ctx.SqlExpressions[id] = expr`) + io.WriteString(out, "\r\n"+` }`) + for id, fragmentDialects := range itf.SqlFragments { + io.WriteString(out, "\r\n { /// " + id) + if len(fragmentDialects) > 0 { + hasDefaultSql := false + for _, dialect := range fragmentDialects { + if dialect.Dialect != "default" { + continue + } + io.WriteString(out, preprocessingSQL("sqlStr", true, dialect.SQL, recordTypeName)) + hasDefaultSql = true + } + if !hasDefaultSql { + io.WriteString(out, "\r\n sqlStr := \"\"") + } + + io.WriteString(out, "\r\n switch ctx.Dialect {") + for _, dialect := range fragmentDialects { + if dialect.Dialect == "default" { + continue + } + io.WriteString(out, "\r\n case "+dialect.ToGoLiteral()+":\r\n") + io.WriteString(out, preprocessingSQL("sqlStr", false, dialect.SQL, recordTypeName)) + } + io.WriteString(out, "\r\n}") + io.WriteString(out, "\r\n"+` expr, err := gobatis.NewSqlExpression(ctx, sqlstr) + if err != nil { + return err + } + ctx.SqlExpressions["`+id+`"] = expr`) + } + io.WriteString(out, "\r\n}") + } + io.WriteString(out, "\r\n"+`defer func() { `) + io.WriteString(out, "\r\n"+` ctx.SqlExpressions = sqlExpressions`) + io.WriteString(out, "\r\n"+`}()`) + } for _, m := range itf.Methods { if m.Config != nil && m.Config.Reference != nil { continue @@ -264,7 +316,6 @@ func (cmd *Generator) generateInterfaceInit(out io.Writer, file *goparser2.File, // 这个函数没有什么用,只是为了分隔代码 err := func() error { var recordTypeName string - if m.Config != nil && m.Config.RecordType != "" { recordTypeName = m.Config.RecordType } else { diff --git a/cmd/gobatis/goparser2/interface.go b/cmd/gobatis/goparser2/interface.go index 79228da..1eb33bd 100644 --- a/cmd/gobatis/goparser2/interface.go +++ b/cmd/gobatis/goparser2/interface.go @@ -19,6 +19,8 @@ type Interface struct { EmbeddedInterfaces []string Comments []string Methods []*Method + + SqlFragments map[string][]Dialect } func (itf *Interface) DetectRecordType(method *Method, debug bool) *Type { diff --git a/core/context.go b/core/context.go index 532960f..b1e352e 100644 --- a/core/context.go +++ b/core/context.go @@ -64,6 +64,7 @@ type InitContext struct { Dialect Dialect Mapper *Mapper Statements map[string]*MappedStatement + SqlExpressions map[string]SqlExpression } var ( diff --git a/core/statement.go b/core/statement.go index c992939..a15ce35 100644 --- a/core/statement.go +++ b/core/statement.go @@ -158,7 +158,7 @@ func (stmt *MappedStatement) GenerateSQLs(ctx *Context) ([]sqlAndParam, error) { return sqlAndParams, nil } -func NewMapppedStatement(ctx *InitContext, id string, statementType StatementType, resultType ResultType, sqlStr string) (*MappedStatement, error) { +func NewMapppedStatement(ctx *StmtContext, id string, statementType StatementType, resultType ResultType, sqlStr string) (*MappedStatement, error) { stmt := &MappedStatement{ id: id, sqlType: statementType, @@ -183,7 +183,7 @@ func NewMapppedStatement(ctx *InitContext, id string, statementType StatementTyp return stmt, nil } -func CreateSQL(ctx *InitContext, id, sqlStr, fullText string, one bool) (DynamicSQL, error) { +func CreateSQL(ctx *StmtContext, id, sqlStr, fullText string, one bool) (DynamicSQL, error) { if strings.Contains(sqlStr, "{{") { funcMap := ctx.Config.TemplateFuncs tpl, err := template.New(id).Funcs(funcMap).Parse(sqlStr) @@ -200,7 +200,7 @@ func CreateSQL(ctx *InitContext, id, sqlStr, fullText string, one bool) (Dynamic // http://www.mybatis.org/mybatis-3/dynamic-sql.html if hasXMLTag(sqlStr) { - dynamicSQL, err := loadDynamicSQLFromXML(sqlStr) + dynamicSQL, err := loadDynamicSQLFromXML(ctx, sqlStr) if err != nil { return nil, errors.New("sql is invalid dynamic sql of '" + id + "', " + err.Error() + "\r\n\t" + sqlStr) } @@ -226,6 +226,27 @@ func CreateSQL(ctx *InitContext, id, sqlStr, fullText string, one bool) (Dynamic return allParamsSQL(sqlStr), nil } +func NewSqlExpression(ctx *InitContext, sqlstr string) (SqlExpression, error) { + stmtctx := &StmtContext{ + InitContext: ctx, + } + stmtctx.FindSqlFragment = func(id string) (SqlExpression, error) { + if stmtctx.InitContext.SqlExpressions != nil { + sf := stmtctx.InitContext.SqlExpressions[id] + if sf != nil { + return sf, nil + } + } + return nil, errors.New("sql '"+id+"' missing") + } + segements, err := readSQLStatementForXML(stmtctx, sqlstr) + if err != nil { + return nil, err + } + return expressionArray(segements), nil +} + + func CompileNamedQuery(txt string) ([]string, Params, error) { idx := strings.Index(txt, "#{") if idx < 0 { diff --git a/core/xml.go b/core/xml.go index 9a06cb5..e72bbe6 100644 --- a/core/xml.go +++ b/core/xml.go @@ -17,12 +17,18 @@ type stmtXML struct { SQL string `xml:",innerxml"` // nolint } +type sqlFragmentXML struct { + ID string `xml:"id,attr"` + SQL string `xml:",innerxml"` // nolint +} + type xmlConfig struct { XMLName xml.Name `xml:"gobatis"` // nolint Selects []stmtXML `xml:"select"` // nolint Deletes []stmtXML `xml:"delete"` // nolint Updates []stmtXML `xml:"update"` // nolint Inserts []stmtXML `xml:"insert"` // nolint + SqlFragments []sqlFragmentXML `xml:"sql"` // nolint } func readMappedStatementsFromXMLFile(ctx *InitContext, filename string) ([]*MappedStatement, error) { @@ -40,8 +46,50 @@ func readMappedStatementsFromXMLFile(ctx *InitContext, filename string) ([]*Mapp return nil, errors.New("Error decode file '" + filename + "': " + err.Error()) } + var sqlFragments = map[string]SqlExpression{} + + stmtctx := &StmtContext{ + InitContext: ctx, + } + + stmtctx.FindSqlFragment = func(id string) (SqlExpression, error) { + if stmtctx.InitContext.SqlExpressions != nil { + sf := stmtctx.InitContext.SqlExpressions[id] + if sf != nil { + return sf, nil + } + } + + sf := sqlFragments[id] + if sf != nil { + return sf, nil + } + + var sqlStr string + for _, stmt := range xmlObj.SqlFragments { + if stmt.ID == id { + sqlStr = stmt.SQL + if sqlStr == "" { + return nil, errors.New("sql '"+ id +"' is empty in file '"+filename+"'") + } + break + } + } + + if sqlStr == "" { + return nil, errors.New("sql '"+id+"' missing") + } + segements, err := readSQLStatementForXML(stmtctx, sqlStr) + if err != nil { + return nil, err + } + sf = expressionArray(segements) + sqlFragments[id] = sf + return sf, nil + } + for _, deleteStmt := range xmlObj.Deletes { - stmt, err := newMapppedStatementFromXML(ctx, deleteStmt, StatementTypeDelete) + stmt, err := newMapppedStatementFromXML(stmtctx, deleteStmt, StatementTypeDelete) if err != nil { return nil, errors.New("Error parse file '" + filename + "' on '" + deleteStmt.ID + "': " + err.Error()) } @@ -49,21 +97,21 @@ func readMappedStatementsFromXMLFile(ctx *InitContext, filename string) ([]*Mapp } for _, insertStmt := range xmlObj.Inserts { - stmt, err := newMapppedStatementFromXML(ctx, insertStmt, StatementTypeInsert) + stmt, err := newMapppedStatementFromXML(stmtctx, insertStmt, StatementTypeInsert) if err != nil { return nil, errors.New("Error parse file '" + filename + "' on '" + insertStmt.ID + "': " + err.Error()) } statements = append(statements, stmt) } for _, selectStmt := range xmlObj.Selects { - stmt, err := newMapppedStatementFromXML(ctx, selectStmt, StatementTypeSelect) + stmt, err := newMapppedStatementFromXML(stmtctx, selectStmt, StatementTypeSelect) if err != nil { return nil, errors.New("Error parse file '" + filename + "' on '" + selectStmt.ID + "': " + err.Error()) } statements = append(statements, stmt) } for _, updateStmt := range xmlObj.Updates { - stmt, err := newMapppedStatementFromXML(ctx, updateStmt, StatementTypeUpdate) + stmt, err := newMapppedStatementFromXML(stmtctx, updateStmt, StatementTypeUpdate) if err != nil { return nil, errors.New("Error parse file '" + filename + "' on '" + updateStmt.ID + "': " + err.Error()) } @@ -72,7 +120,13 @@ func readMappedStatementsFromXMLFile(ctx *InitContext, filename string) ([]*Mapp return statements, nil } -func newMapppedStatementFromXML(ctx *InitContext, stmt stmtXML, sqlType StatementType) (*MappedStatement, error) { +type StmtContext struct { + *InitContext + + FindSqlFragment func(string) (SqlExpression, error) +} + +func newMapppedStatementFromXML(ctx *StmtContext, stmt stmtXML, sqlType StatementType) (*MappedStatement, error) { var resultType ResultType switch strings.ToLower(stmt.Result) { case "": @@ -94,15 +148,15 @@ func newMapppedStatementFromXML(ctx *InitContext, stmt stmtXML, sqlType Statemen return s, nil } -func loadDynamicSQLFromXML(sqlStr string) (DynamicSQL, error) { - segements, err := readSQLStatementForXML(sqlStr) +func loadDynamicSQLFromXML(ctx *StmtContext, sqlStr string) (DynamicSQL, error) { + segements, err := readSQLStatementForXML(ctx, sqlStr) if err != nil { return nil, err } return expressionArray(segements), nil } -func readSQLStatementForXML(sqlStr string) ([]sqlExpression, error) { +func readSQLStatementForXML(ctx *StmtContext, sqlStr string) ([]SqlExpression, error) { txtBegin := ` ` txtEnd := `` @@ -120,7 +174,7 @@ func readSQLStatementForXML(sqlStr string) ([]sqlExpression, error) { switch el := token.(type) { case xml.StartElement: if el.Name.Local == "statement" { - return readElementForXML(decoder, "") + return readElementForXML(ctx, decoder, "") } case xml.Directive, xml.ProcInst, xml.Comment: case xml.CharData: @@ -133,9 +187,9 @@ func readSQLStatementForXML(sqlStr string) ([]sqlExpression, error) { } } -func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error) { +func readElementForXML(ctx *StmtContext, decoder *xml.Decoder, tag string) ([]SqlExpression, error) { var sb strings.Builder - var expressions []sqlExpression + var expressions []SqlExpression var lastPrint *string for { @@ -167,7 +221,7 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error switch el.Name.Local { case "if": - contents, err := readElementForXML(decoder, tag+"/if") + contents, err := readElementForXML(ctx, decoder, tag+"/if") if err != nil { return nil, err } @@ -201,7 +255,7 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error expressions = append(expressions, elseExpression{test: test}) case "foreach": - contents, err := readElementForXML(decoder, tag+"/foreach") + contents, err := readElementForXML(ctx, decoder, tag+"/foreach") if err != nil { return nil, err } @@ -221,7 +275,7 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error expressions = append(expressions, foreach) case "chose": - choseEl, err := loadChoseElementForXML(decoder, tag+"/chose") + choseEl, err := loadChoseElementForXML(ctx, decoder, tag+"/chose") if err != nil { return nil, err } @@ -231,14 +285,14 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error } expressions = append(expressions, chose) case "where": - array, err := readElementForXML(decoder, tag+"/where") + array, err := readElementForXML(ctx, decoder, tag+"/where") if err != nil { return nil, err } expressions = append(expressions, &whereExpression{expressions: array}) case "set": - array, err := readElementForXML(decoder, tag+"/set") + array, err := readElementForXML(ctx, decoder, tag+"/set") if err != nil { return nil, err } @@ -354,7 +408,7 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error expressions = append(expressions, orderBy) case "trim": - array, err := readElementForXML(decoder, tag+"/trim") + array, err := readElementForXML(ctx, decoder, tag+"/trim") if err != nil { return nil, err } @@ -370,7 +424,7 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error if prefix := readElementAttrForXML(el.Attr, "prefix"); prefix != "" { expr, err := newRawExpression(prefix) if err != nil { - return nil, errors.New("element trim.prefix is invalid - '" + prefix + "'") + return nil, errors.New("element trim.prefix is invalid - '" + prefix + "', " + err.Error()) } trimExpr.prefix = expr } @@ -383,6 +437,26 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error trimExpr.suffix = expr } expressions = append(expressions, trimExpr) + case "include": + array, err := readElementForXML(ctx, decoder, tag+"/include") + if err != nil { + return nil, err + } + if len(array) != 0 { + return nil, errors.New("element include must is empty element") + } + + refid := readElementAttrForXML(el.Attr, "refid") + if refid == "" { + return nil, errors.New("element include.refid is missing") + } + + expr, err := ctx.FindSqlFragment(refid) + if err != nil { + return nil, errors.New("element include.refid '"+refid+"' invalid, " + err.Error()) + } + + expressions = append(expressions, expr) case "value-range", "value_range", "valuerange": content, err := readElementTextForXML(decoder, tag+"/"+el.Name.Local) if err != nil { @@ -485,7 +559,7 @@ func readElementAttrForXML(attrs []xml.Attr, name string) string { return "" } -func loadChoseElementForXML(decoder *xml.Decoder, tag string) (*xmlChoseElement, error) { +func loadChoseElementForXML(ctx *StmtContext, decoder *xml.Decoder, tag string) (*xmlChoseElement, error) { var segement xmlChoseElement for { token, err := decoder.Token() @@ -499,7 +573,7 @@ func loadChoseElementForXML(decoder *xml.Decoder, tag string) (*xmlChoseElement, switch el := token.(type) { case xml.StartElement: if el.Name.Local == "when" { - contents, err := readElementForXML(decoder, "when") + contents, err := readElementForXML(ctx, decoder, "when") if err != nil { return nil, err } @@ -512,7 +586,7 @@ func loadChoseElementForXML(decoder *xml.Decoder, tag string) (*xmlChoseElement, break } - var content sqlExpression + var content SqlExpression if len(contents) == 1 { content = contents[0] } else if len(contents) > 1 { @@ -527,7 +601,7 @@ func loadChoseElementForXML(decoder *xml.Decoder, tag string) (*xmlChoseElement, } if el.Name.Local == "otherwise" { - contents, err := readElementForXML(decoder, "otherwise") + contents, err := readElementForXML(ctx, decoder, "otherwise") if err != nil { return nil, err } @@ -559,7 +633,7 @@ func loadChoseElementForXML(decoder *xml.Decoder, tag string) (*xmlChoseElement, type xmlWhenElement struct { test string - content sqlExpression + content SqlExpression } func (when *xmlWhenElement) String() string { @@ -574,7 +648,7 @@ func (when *xmlWhenElement) String() string { type xmlChoseElement struct { when []xmlWhenElement - otherwise sqlExpression + otherwise SqlExpression } func (chose *xmlChoseElement) String() string { @@ -598,7 +672,7 @@ type xmlForEachElement struct { index string collection string openTag, separatorTag, closeTag string - contents []sqlExpression + contents []SqlExpression } func (foreach *xmlForEachElement) String() string { @@ -643,6 +717,8 @@ func hasXMLTag(sqlStr string) bool { "= 0 { trueExprs = segements[:elseIndex] falseExprs = segements[elseIndex+1:] @@ -782,7 +782,7 @@ type choseExpression struct { el xmlChoseElement when []whenExpression - otherwise sqlExpression + otherwise SqlExpression } func (chose *choseExpression) String() string { @@ -819,7 +819,7 @@ func (chose *choseExpression) writeTo(printer *sqlPrinter) { type whenExpression struct { test *govaluate.EvaluableExpression - expression sqlExpression + expression SqlExpression } func (ifExpr whenExpression) String() string { @@ -844,7 +844,7 @@ func (ifExpr whenExpression) String() string { // } // } -func newChoseExpression(el xmlChoseElement) (sqlExpression, error) { +func newChoseExpression(el xmlChoseElement) (SqlExpression, error) { var when []whenExpression for idx := range el.when { @@ -872,7 +872,7 @@ func newChoseExpression(el xmlChoseElement) (sqlExpression, error) { type forEachExpression struct { el xmlForEachElement - segements []sqlExpression + segements []SqlExpression } func (foreach *forEachExpression) String() string { @@ -1012,7 +1012,7 @@ func (foreach *forEachExpression) writeTo(printer *sqlPrinter) { } } -func newForEachExpression(el xmlForEachElement) (sqlExpression, error) { +func newForEachExpression(el xmlForEachElement) (SqlExpression, error) { if len(el.contents) == 0 { return nil, errors.New("contents of foreach is empty") } @@ -1150,7 +1150,7 @@ func (set *setExpression) writeTo(printer *sqlPrinter) { } } -type expressionArray []sqlExpression +type expressionArray []SqlExpression func (expressions expressionArray) String() string { var sb strings.Builder @@ -1539,9 +1539,9 @@ func (expr orderByExpression) writeTo(printer *sqlPrinter) { type trimExpression struct { expressions expressionArray prefixoverride []string - prefix sqlExpression + prefix SqlExpression suffixoverride []string - suffix sqlExpression + suffix SqlExpression } func (expr trimExpression) String() string { @@ -1688,8 +1688,8 @@ type valueRangeExpression struct { field string value string - prefix sqlExpression - suffix sqlExpression + prefix SqlExpression + suffix SqlExpression } func (expr valueRangeExpression) String() string { diff --git a/core/xml_test.go b/core/xml_test.go index 285332d..f2f2f21 100644 --- a/core/xml_test.go +++ b/core/xml_test.go @@ -71,11 +71,14 @@ func TestXmlOk(t *testing.T) { } var query *Query = nil - initCtx := &core.InitContext{Config: cfg, - // Tracer: cfg.Tracer, - Dialect: dialects.Postgres, - Mapper: core.CreateMapper("", nil, nil), - Statements: make(map[string]*core.MappedStatement)} + initCtx := &core.StmtContext{ + InitContext: &core.InitContext{Config: cfg, + // Tracer: cfg.Tracer, + Dialect: dialects.Postgres, + Mapper: core.CreateMapper("", nil, nil), + Statements: make(map[string]*core.MappedStatement), + }, + } for idx, test := range []xmlCase{ // { @@ -1189,7 +1192,8 @@ type xmlErrCase struct { } func TestXmlFail(t *testing.T) { - cfg := &core.Config{DriverName: "postgres", + cfg := &core.Config{ + DriverName: "postgres", DataSource: "aa", XMLPaths: []string{"tests", "../tests", @@ -1200,11 +1204,14 @@ func TestXmlFail(t *testing.T) { Tracer: core.StdLogger{Logger: log.New(os.Stdout, "[gobatis] ", log.Flags())}, } - initCtx := &core.InitContext{Config: cfg, - // Logger: cfg.Logger, - Dialect: dialects.Postgres, - Mapper: core.CreateMapper("", nil, nil), - Statements: make(map[string]*core.MappedStatement)} + initCtx := &core.StmtContext{ + InitContext: &core.InitContext{Config: cfg, + // Logger: cfg.Logger, + Dialect: dialects.Postgres, + Mapper: core.CreateMapper("", nil, nil), + Statements: make(map[string]*core.MappedStatement), + }, + } for idx, test := range []xmlErrCase{ { @@ -1370,11 +1377,15 @@ func TestXmlExpressionOk(t *testing.T) { Tracer: core.StdLogger{Logger: log.New(os.Stdout, "[gobatis] ", log.Flags())}, } - initCtx := &core.InitContext{Config: cfg, - // Logger: cfg.Logger, - Dialect: dialects.Postgres, - Mapper: core.CreateMapper("", nil, nil), - Statements: make(map[string]*core.MappedStatement)} + initCtx := &core.StmtContext{ + InitContext: &core.InitContext{ + Config: cfg, + // Logger: cfg.Logger, + Dialect: dialects.Postgres, + Mapper: core.CreateMapper("", nil, nil), + Statements: make(map[string]*core.MappedStatement), + }, + } for idx, test := range []xmlCase{ { @@ -1521,11 +1532,15 @@ func TestXmlExpressionFail(t *testing.T) { Tracer: core.StdLogger{Logger: log.New(os.Stdout, "[gobatis] ", log.Flags())}, } - initCtx := &core.InitContext{Config: cfg, - // Logger: cfg.Logger, - Dialect: dialects.Postgres, - Mapper: core.CreateMapper("", nil, nil), - Statements: make(map[string]*core.MappedStatement)} + initCtx := &core.StmtContext{ + InitContext: &core.InitContext{ + Config: cfg, + // Logger: cfg.Logger, + Dialect: dialects.Postgres, + Mapper: core.CreateMapper("", nil, nil), + Statements: make(map[string]*core.MappedStatement), + }, + } for idx, test := range []xmlErrCase{ { diff --git a/example_xml/example_test.go b/example_xml/example_test.go index 8397fc9..06ffa2b 100644 --- a/example_xml/example_test.go +++ b/example_xml/example_test.go @@ -83,6 +83,17 @@ func ExampleSimple() { fmt.Println("fetch user from database!") fmt.Println(u.Nickname) + + list, err := userDao.SelectAll("", "all", nil) + if err != nil { + fmt.Println(err) + return + } + fmt.Println("fetch all user from database!") + for _, u := range list { + fmt.Println(u.Nickname) + } + _, err = userDao.DeleteByID(id) if err != nil { fmt.Println(err) @@ -97,5 +108,7 @@ func ExampleSimple() { // update user: 1 // fetch user from database! // ABC + // fetch all user from database! + // ABC // delete success! } diff --git a/example_xml/xmlfiles/dm/test.xml b/example_xml/xmlfiles/dm/test.xml index 1edc4cc..73023c3 100644 --- a/example_xml/xmlfiles/dm/test.xml +++ b/example_xml/xmlfiles/dm/test.xml @@ -3,10 +3,14 @@ + + SELECT * FROM gobatis_users + SELECT * FROM gobatis_users where id=#{id} order by id + + SELECT * FROM gobatis_users + SELECT * FROM gobatis_users where id=#{id} order by id + + SELECT * FROM gobatis_users + SELECT * FROM gobatis_users where id=#{id} order by id + + SELECT * FROM gobatis_users +