Skip to content

Commit

Permalink
feat: add <include /> support
Browse files Browse the repository at this point in the history
  • Loading branch information
hengwei-test committed May 16, 2024
1 parent 562466c commit 8b27452
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 72 deletions.
53 changes: 52 additions & 1 deletion cmd/gobatis/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,58 @@ func (cmd *Generator) generateInterface(out io.Writer, file *goparser2.File, itf
func (cmd *Generator) generateInterfaceInit(out io.Writer, file *goparser2.File, itf *goparser2.Interface) error {
io.WriteString(out, "\r\n\r\n"+`func init() {
gobatis.Init(func(ctx *gobatis.InitContext) error {`)
if len(itf.SqlFragments) > 0 {

var recordTypeName string
recordType := itf.DetectRecordType(nil, false)
if recordType != nil {
recordTypeName = recordType.ToLiteral()
}
if recordTypeName == "" {
recordTypeName = "xxxxxxxxxxxxxxxxxxxx"
}

io.WriteString(out, "\r\n"+` var sqlExpressions = ctx.SqlExpressions`)
io.WriteString(out, "\r\n"+` ctx.SqlExpressions = map[string]*gobatis.SqlExpression{}`)
io.WriteString(out, "\r\n"+` for id, expr := range sqlExpressions {`)
io.WriteString(out, "\r\n"+` ctx.SqlExpressions[id] = expr`)
io.WriteString(out, "\r\n"+` }`)
for id, fragmentDialects := range itf.SqlFragments {
io.WriteString(out, "\r\n { /// " + id)
if len(fragmentDialects) > 0 {
hasDefaultSql := false
for _, dialect := range fragmentDialects {
if dialect.Dialect != "default" {
continue
}
io.WriteString(out, preprocessingSQL("sqlStr", true, dialect.SQL, recordTypeName))
hasDefaultSql = true
}
if !hasDefaultSql {
io.WriteString(out, "\r\n sqlStr := \"\"")
}

io.WriteString(out, "\r\n switch ctx.Dialect {")
for _, dialect := range fragmentDialects {
if dialect.Dialect == "default" {
continue
}
io.WriteString(out, "\r\n case "+dialect.ToGoLiteral()+":\r\n")
io.WriteString(out, preprocessingSQL("sqlStr", false, dialect.SQL, recordTypeName))
}
io.WriteString(out, "\r\n}")
io.WriteString(out, "\r\n"+` expr, err := gobatis.NewSqlExpression(ctx, sqlstr)
if err != nil {
return err
}
ctx.SqlExpressions["`+id+`"] = expr`)
}
io.WriteString(out, "\r\n}")
}
io.WriteString(out, "\r\n"+`defer func() { `)
io.WriteString(out, "\r\n"+` ctx.SqlExpressions = sqlExpressions`)
io.WriteString(out, "\r\n"+`}()`)
}
for _, m := range itf.Methods {
if m.Config != nil && m.Config.Reference != nil {
continue
Expand All @@ -264,7 +316,6 @@ func (cmd *Generator) generateInterfaceInit(out io.Writer, file *goparser2.File,
// 这个函数没有什么用,只是为了分隔代码
err := func() error {
var recordTypeName string

if m.Config != nil && m.Config.RecordType != "" {
recordTypeName = m.Config.RecordType
} else {
Expand Down
2 changes: 2 additions & 0 deletions cmd/gobatis/goparser2/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ type Interface struct {
EmbeddedInterfaces []string
Comments []string
Methods []*Method

SqlFragments map[string][]Dialect
}

func (itf *Interface) DetectRecordType(method *Method, debug bool) *Type {
Expand Down
1 change: 1 addition & 0 deletions core/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type InitContext struct {
Dialect Dialect
Mapper *Mapper
Statements map[string]*MappedStatement
SqlExpressions map[string]SqlExpression
}

var (
Expand Down
27 changes: 24 additions & 3 deletions core/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func (stmt *MappedStatement) GenerateSQLs(ctx *Context) ([]sqlAndParam, error) {
return sqlAndParams, nil
}

func NewMapppedStatement(ctx *InitContext, id string, statementType StatementType, resultType ResultType, sqlStr string) (*MappedStatement, error) {
func NewMapppedStatement(ctx *StmtContext, id string, statementType StatementType, resultType ResultType, sqlStr string) (*MappedStatement, error) {
stmt := &MappedStatement{
id: id,
sqlType: statementType,
Expand All @@ -183,7 +183,7 @@ func NewMapppedStatement(ctx *InitContext, id string, statementType StatementTyp
return stmt, nil
}

func CreateSQL(ctx *InitContext, id, sqlStr, fullText string, one bool) (DynamicSQL, error) {
func CreateSQL(ctx *StmtContext, id, sqlStr, fullText string, one bool) (DynamicSQL, error) {
if strings.Contains(sqlStr, "{{") {
funcMap := ctx.Config.TemplateFuncs
tpl, err := template.New(id).Funcs(funcMap).Parse(sqlStr)
Expand All @@ -200,7 +200,7 @@ func CreateSQL(ctx *InitContext, id, sqlStr, fullText string, one bool) (Dynamic

// http://www.mybatis.org/mybatis-3/dynamic-sql.html
if hasXMLTag(sqlStr) {
dynamicSQL, err := loadDynamicSQLFromXML(sqlStr)
dynamicSQL, err := loadDynamicSQLFromXML(ctx, sqlStr)
if err != nil {
return nil, errors.New("sql is invalid dynamic sql of '" + id + "', " + err.Error() + "\r\n\t" + sqlStr)
}
Expand All @@ -226,6 +226,27 @@ func CreateSQL(ctx *InitContext, id, sqlStr, fullText string, one bool) (Dynamic
return allParamsSQL(sqlStr), nil
}

func NewSqlExpression(ctx *InitContext, sqlstr string) (SqlExpression, error) {
stmtctx := &StmtContext{
InitContext: ctx,
}
stmtctx.FindSqlFragment = func(id string) (SqlExpression, error) {
if stmtctx.InitContext.SqlExpressions != nil {
sf := stmtctx.InitContext.SqlExpressions[id]
if sf != nil {
return sf, nil
}
}
return nil, errors.New("sql '"+id+"' missing")
}
segements, err := readSQLStatementForXML(stmtctx, sqlstr)
if err != nil {
return nil, err
}
return expressionArray(segements), nil
}


func CompileNamedQuery(txt string) ([]string, Params, error) {
idx := strings.Index(txt, "#{")
if idx < 0 {
Expand Down
Loading

0 comments on commit 8b27452

Please sign in to comment.