diff --git a/cmd/gobatis/generator/generator.go b/cmd/gobatis/generator/generator.go
index f7265ad..576ed64 100644
--- a/cmd/gobatis/generator/generator.go
+++ b/cmd/gobatis/generator/generator.go
@@ -242,6 +242,58 @@ func (cmd *Generator) generateInterface(out io.Writer, file *goparser2.File, itf
func (cmd *Generator) generateInterfaceInit(out io.Writer, file *goparser2.File, itf *goparser2.Interface) error {
io.WriteString(out, "\r\n\r\n"+`func init() {
gobatis.Init(func(ctx *gobatis.InitContext) error {`)
+ if len(itf.SqlFragments) > 0 {
+
+ var recordTypeName string
+ recordType := itf.DetectRecordType(nil, false)
+ if recordType != nil {
+ recordTypeName = recordType.ToLiteral()
+ }
+ if recordTypeName == "" {
+ recordTypeName = "xxxxxxxxxxxxxxxxxxxx"
+ }
+
+ io.WriteString(out, "\r\n"+` var sqlExpressions = ctx.SqlExpressions`)
+ io.WriteString(out, "\r\n"+` ctx.SqlExpressions = map[string]*gobatis.SqlExpression{}`)
+ io.WriteString(out, "\r\n"+` for id, expr := range sqlExpressions {`)
+ io.WriteString(out, "\r\n"+` ctx.SqlExpressions[id] = expr`)
+ io.WriteString(out, "\r\n"+` }`)
+ for id, fragmentDialects := range itf.SqlFragments {
+ io.WriteString(out, "\r\n { /// " + id)
+ if len(fragmentDialects) > 0 {
+ hasDefaultSql := false
+ for _, dialect := range fragmentDialects {
+ if dialect.Dialect != "default" {
+ continue
+ }
+ io.WriteString(out, preprocessingSQL("sqlStr", true, dialect.SQL, recordTypeName))
+ hasDefaultSql = true
+ }
+ if !hasDefaultSql {
+ io.WriteString(out, "\r\n sqlStr := \"\"")
+ }
+
+ io.WriteString(out, "\r\n switch ctx.Dialect {")
+ for _, dialect := range fragmentDialects {
+ if dialect.Dialect == "default" {
+ continue
+ }
+ io.WriteString(out, "\r\n case "+dialect.ToGoLiteral()+":\r\n")
+ io.WriteString(out, preprocessingSQL("sqlStr", false, dialect.SQL, recordTypeName))
+ }
+ io.WriteString(out, "\r\n}")
+ io.WriteString(out, "\r\n"+` expr, err := gobatis.NewSqlExpression(ctx, sqlstr)
+ if err != nil {
+ return err
+ }
+ ctx.SqlExpressions["`+id+`"] = expr`)
+ }
+ io.WriteString(out, "\r\n}")
+ }
+ io.WriteString(out, "\r\n"+`defer func() { `)
+ io.WriteString(out, "\r\n"+` ctx.SqlExpressions = sqlExpressions`)
+ io.WriteString(out, "\r\n"+`}()`)
+ }
for _, m := range itf.Methods {
if m.Config != nil && m.Config.Reference != nil {
continue
@@ -264,7 +316,6 @@ func (cmd *Generator) generateInterfaceInit(out io.Writer, file *goparser2.File,
// 这个函数没有什么用,只是为了分隔代码
err := func() error {
var recordTypeName string
-
if m.Config != nil && m.Config.RecordType != "" {
recordTypeName = m.Config.RecordType
} else {
diff --git a/cmd/gobatis/goparser2/interface.go b/cmd/gobatis/goparser2/interface.go
index 79228da..1eb33bd 100644
--- a/cmd/gobatis/goparser2/interface.go
+++ b/cmd/gobatis/goparser2/interface.go
@@ -19,6 +19,8 @@ type Interface struct {
EmbeddedInterfaces []string
Comments []string
Methods []*Method
+
+ SqlFragments map[string][]Dialect
}
func (itf *Interface) DetectRecordType(method *Method, debug bool) *Type {
diff --git a/core/context.go b/core/context.go
index 532960f..b1e352e 100644
--- a/core/context.go
+++ b/core/context.go
@@ -64,6 +64,7 @@ type InitContext struct {
Dialect Dialect
Mapper *Mapper
Statements map[string]*MappedStatement
+ SqlExpressions map[string]SqlExpression
}
var (
diff --git a/core/statement.go b/core/statement.go
index c992939..a15ce35 100644
--- a/core/statement.go
+++ b/core/statement.go
@@ -158,7 +158,7 @@ func (stmt *MappedStatement) GenerateSQLs(ctx *Context) ([]sqlAndParam, error) {
return sqlAndParams, nil
}
-func NewMapppedStatement(ctx *InitContext, id string, statementType StatementType, resultType ResultType, sqlStr string) (*MappedStatement, error) {
+func NewMapppedStatement(ctx *StmtContext, id string, statementType StatementType, resultType ResultType, sqlStr string) (*MappedStatement, error) {
stmt := &MappedStatement{
id: id,
sqlType: statementType,
@@ -183,7 +183,7 @@ func NewMapppedStatement(ctx *InitContext, id string, statementType StatementTyp
return stmt, nil
}
-func CreateSQL(ctx *InitContext, id, sqlStr, fullText string, one bool) (DynamicSQL, error) {
+func CreateSQL(ctx *StmtContext, id, sqlStr, fullText string, one bool) (DynamicSQL, error) {
if strings.Contains(sqlStr, "{{") {
funcMap := ctx.Config.TemplateFuncs
tpl, err := template.New(id).Funcs(funcMap).Parse(sqlStr)
@@ -200,7 +200,7 @@ func CreateSQL(ctx *InitContext, id, sqlStr, fullText string, one bool) (Dynamic
// http://www.mybatis.org/mybatis-3/dynamic-sql.html
if hasXMLTag(sqlStr) {
- dynamicSQL, err := loadDynamicSQLFromXML(sqlStr)
+ dynamicSQL, err := loadDynamicSQLFromXML(ctx, sqlStr)
if err != nil {
return nil, errors.New("sql is invalid dynamic sql of '" + id + "', " + err.Error() + "\r\n\t" + sqlStr)
}
@@ -226,6 +226,27 @@ func CreateSQL(ctx *InitContext, id, sqlStr, fullText string, one bool) (Dynamic
return allParamsSQL(sqlStr), nil
}
+func NewSqlExpression(ctx *InitContext, sqlstr string) (SqlExpression, error) {
+ stmtctx := &StmtContext{
+ InitContext: ctx,
+ }
+ stmtctx.FindSqlFragment = func(id string) (SqlExpression, error) {
+ if stmtctx.InitContext.SqlExpressions != nil {
+ sf := stmtctx.InitContext.SqlExpressions[id]
+ if sf != nil {
+ return sf, nil
+ }
+ }
+ return nil, errors.New("sql '"+id+"' missing")
+ }
+ segements, err := readSQLStatementForXML(stmtctx, sqlstr)
+ if err != nil {
+ return nil, err
+ }
+ return expressionArray(segements), nil
+}
+
+
func CompileNamedQuery(txt string) ([]string, Params, error) {
idx := strings.Index(txt, "#{")
if idx < 0 {
diff --git a/core/xml.go b/core/xml.go
index 9a06cb5..e72bbe6 100644
--- a/core/xml.go
+++ b/core/xml.go
@@ -17,12 +17,18 @@ type stmtXML struct {
SQL string `xml:",innerxml"` // nolint
}
+type sqlFragmentXML struct {
+ ID string `xml:"id,attr"`
+ SQL string `xml:",innerxml"` // nolint
+}
+
type xmlConfig struct {
XMLName xml.Name `xml:"gobatis"` // nolint
Selects []stmtXML `xml:"select"` // nolint
Deletes []stmtXML `xml:"delete"` // nolint
Updates []stmtXML `xml:"update"` // nolint
Inserts []stmtXML `xml:"insert"` // nolint
+ SqlFragments []sqlFragmentXML `xml:"sql"` // nolint
}
func readMappedStatementsFromXMLFile(ctx *InitContext, filename string) ([]*MappedStatement, error) {
@@ -40,8 +46,50 @@ func readMappedStatementsFromXMLFile(ctx *InitContext, filename string) ([]*Mapp
return nil, errors.New("Error decode file '" + filename + "': " + err.Error())
}
+ var sqlFragments = map[string]SqlExpression{}
+
+ stmtctx := &StmtContext{
+ InitContext: ctx,
+ }
+
+ stmtctx.FindSqlFragment = func(id string) (SqlExpression, error) {
+ if stmtctx.InitContext.SqlExpressions != nil {
+ sf := stmtctx.InitContext.SqlExpressions[id]
+ if sf != nil {
+ return sf, nil
+ }
+ }
+
+ sf := sqlFragments[id]
+ if sf != nil {
+ return sf, nil
+ }
+
+ var sqlStr string
+ for _, stmt := range xmlObj.SqlFragments {
+ if stmt.ID == id {
+ sqlStr = stmt.SQL
+ if sqlStr == "" {
+ return nil, errors.New("sql '"+ id +"' is empty in file '"+filename+"'")
+ }
+ break
+ }
+ }
+
+ if sqlStr == "" {
+ return nil, errors.New("sql '"+id+"' missing")
+ }
+ segements, err := readSQLStatementForXML(stmtctx, sqlStr)
+ if err != nil {
+ return nil, err
+ }
+ sf = expressionArray(segements)
+ sqlFragments[id] = sf
+ return sf, nil
+ }
+
for _, deleteStmt := range xmlObj.Deletes {
- stmt, err := newMapppedStatementFromXML(ctx, deleteStmt, StatementTypeDelete)
+ stmt, err := newMapppedStatementFromXML(stmtctx, deleteStmt, StatementTypeDelete)
if err != nil {
return nil, errors.New("Error parse file '" + filename + "' on '" + deleteStmt.ID + "': " + err.Error())
}
@@ -49,21 +97,21 @@ func readMappedStatementsFromXMLFile(ctx *InitContext, filename string) ([]*Mapp
}
for _, insertStmt := range xmlObj.Inserts {
- stmt, err := newMapppedStatementFromXML(ctx, insertStmt, StatementTypeInsert)
+ stmt, err := newMapppedStatementFromXML(stmtctx, insertStmt, StatementTypeInsert)
if err != nil {
return nil, errors.New("Error parse file '" + filename + "' on '" + insertStmt.ID + "': " + err.Error())
}
statements = append(statements, stmt)
}
for _, selectStmt := range xmlObj.Selects {
- stmt, err := newMapppedStatementFromXML(ctx, selectStmt, StatementTypeSelect)
+ stmt, err := newMapppedStatementFromXML(stmtctx, selectStmt, StatementTypeSelect)
if err != nil {
return nil, errors.New("Error parse file '" + filename + "' on '" + selectStmt.ID + "': " + err.Error())
}
statements = append(statements, stmt)
}
for _, updateStmt := range xmlObj.Updates {
- stmt, err := newMapppedStatementFromXML(ctx, updateStmt, StatementTypeUpdate)
+ stmt, err := newMapppedStatementFromXML(stmtctx, updateStmt, StatementTypeUpdate)
if err != nil {
return nil, errors.New("Error parse file '" + filename + "' on '" + updateStmt.ID + "': " + err.Error())
}
@@ -72,7 +120,13 @@ func readMappedStatementsFromXMLFile(ctx *InitContext, filename string) ([]*Mapp
return statements, nil
}
-func newMapppedStatementFromXML(ctx *InitContext, stmt stmtXML, sqlType StatementType) (*MappedStatement, error) {
+type StmtContext struct {
+ *InitContext
+
+ FindSqlFragment func(string) (SqlExpression, error)
+}
+
+func newMapppedStatementFromXML(ctx *StmtContext, stmt stmtXML, sqlType StatementType) (*MappedStatement, error) {
var resultType ResultType
switch strings.ToLower(stmt.Result) {
case "":
@@ -94,15 +148,15 @@ func newMapppedStatementFromXML(ctx *InitContext, stmt stmtXML, sqlType Statemen
return s, nil
}
-func loadDynamicSQLFromXML(sqlStr string) (DynamicSQL, error) {
- segements, err := readSQLStatementForXML(sqlStr)
+func loadDynamicSQLFromXML(ctx *StmtContext, sqlStr string) (DynamicSQL, error) {
+ segements, err := readSQLStatementForXML(ctx, sqlStr)
if err != nil {
return nil, err
}
return expressionArray(segements), nil
}
-func readSQLStatementForXML(sqlStr string) ([]sqlExpression, error) {
+func readSQLStatementForXML(ctx *StmtContext, sqlStr string) ([]SqlExpression, error) {
txtBegin := `
`
txtEnd := ``
@@ -120,7 +174,7 @@ func readSQLStatementForXML(sqlStr string) ([]sqlExpression, error) {
switch el := token.(type) {
case xml.StartElement:
if el.Name.Local == "statement" {
- return readElementForXML(decoder, "")
+ return readElementForXML(ctx, decoder, "")
}
case xml.Directive, xml.ProcInst, xml.Comment:
case xml.CharData:
@@ -133,9 +187,9 @@ func readSQLStatementForXML(sqlStr string) ([]sqlExpression, error) {
}
}
-func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error) {
+func readElementForXML(ctx *StmtContext, decoder *xml.Decoder, tag string) ([]SqlExpression, error) {
var sb strings.Builder
- var expressions []sqlExpression
+ var expressions []SqlExpression
var lastPrint *string
for {
@@ -167,7 +221,7 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error
switch el.Name.Local {
case "if":
- contents, err := readElementForXML(decoder, tag+"/if")
+ contents, err := readElementForXML(ctx, decoder, tag+"/if")
if err != nil {
return nil, err
}
@@ -201,7 +255,7 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error
expressions = append(expressions, elseExpression{test: test})
case "foreach":
- contents, err := readElementForXML(decoder, tag+"/foreach")
+ contents, err := readElementForXML(ctx, decoder, tag+"/foreach")
if err != nil {
return nil, err
}
@@ -221,7 +275,7 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error
expressions = append(expressions, foreach)
case "chose":
- choseEl, err := loadChoseElementForXML(decoder, tag+"/chose")
+ choseEl, err := loadChoseElementForXML(ctx, decoder, tag+"/chose")
if err != nil {
return nil, err
}
@@ -231,14 +285,14 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error
}
expressions = append(expressions, chose)
case "where":
- array, err := readElementForXML(decoder, tag+"/where")
+ array, err := readElementForXML(ctx, decoder, tag+"/where")
if err != nil {
return nil, err
}
expressions = append(expressions, &whereExpression{expressions: array})
case "set":
- array, err := readElementForXML(decoder, tag+"/set")
+ array, err := readElementForXML(ctx, decoder, tag+"/set")
if err != nil {
return nil, err
}
@@ -354,7 +408,7 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error
expressions = append(expressions, orderBy)
case "trim":
- array, err := readElementForXML(decoder, tag+"/trim")
+ array, err := readElementForXML(ctx, decoder, tag+"/trim")
if err != nil {
return nil, err
}
@@ -370,7 +424,7 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error
if prefix := readElementAttrForXML(el.Attr, "prefix"); prefix != "" {
expr, err := newRawExpression(prefix)
if err != nil {
- return nil, errors.New("element trim.prefix is invalid - '" + prefix + "'")
+ return nil, errors.New("element trim.prefix is invalid - '" + prefix + "', " + err.Error())
}
trimExpr.prefix = expr
}
@@ -383,6 +437,26 @@ func readElementForXML(decoder *xml.Decoder, tag string) ([]sqlExpression, error
trimExpr.suffix = expr
}
expressions = append(expressions, trimExpr)
+ case "include":
+ array, err := readElementForXML(ctx, decoder, tag+"/include")
+ if err != nil {
+ return nil, err
+ }
+ if len(array) != 0 {
+ return nil, errors.New("element include must is empty element")
+ }
+
+ refid := readElementAttrForXML(el.Attr, "refid")
+ if refid == "" {
+ return nil, errors.New("element include.refid is missing")
+ }
+
+ expr, err := ctx.FindSqlFragment(refid)
+ if err != nil {
+ return nil, errors.New("element include.refid '"+refid+"' invalid, " + err.Error())
+ }
+
+ expressions = append(expressions, expr)
case "value-range", "value_range", "valuerange":
content, err := readElementTextForXML(decoder, tag+"/"+el.Name.Local)
if err != nil {
@@ -485,7 +559,7 @@ func readElementAttrForXML(attrs []xml.Attr, name string) string {
return ""
}
-func loadChoseElementForXML(decoder *xml.Decoder, tag string) (*xmlChoseElement, error) {
+func loadChoseElementForXML(ctx *StmtContext, decoder *xml.Decoder, tag string) (*xmlChoseElement, error) {
var segement xmlChoseElement
for {
token, err := decoder.Token()
@@ -499,7 +573,7 @@ func loadChoseElementForXML(decoder *xml.Decoder, tag string) (*xmlChoseElement,
switch el := token.(type) {
case xml.StartElement:
if el.Name.Local == "when" {
- contents, err := readElementForXML(decoder, "when")
+ contents, err := readElementForXML(ctx, decoder, "when")
if err != nil {
return nil, err
}
@@ -512,7 +586,7 @@ func loadChoseElementForXML(decoder *xml.Decoder, tag string) (*xmlChoseElement,
break
}
- var content sqlExpression
+ var content SqlExpression
if len(contents) == 1 {
content = contents[0]
} else if len(contents) > 1 {
@@ -527,7 +601,7 @@ func loadChoseElementForXML(decoder *xml.Decoder, tag string) (*xmlChoseElement,
}
if el.Name.Local == "otherwise" {
- contents, err := readElementForXML(decoder, "otherwise")
+ contents, err := readElementForXML(ctx, decoder, "otherwise")
if err != nil {
return nil, err
}
@@ -559,7 +633,7 @@ func loadChoseElementForXML(decoder *xml.Decoder, tag string) (*xmlChoseElement,
type xmlWhenElement struct {
test string
- content sqlExpression
+ content SqlExpression
}
func (when *xmlWhenElement) String() string {
@@ -574,7 +648,7 @@ func (when *xmlWhenElement) String() string {
type xmlChoseElement struct {
when []xmlWhenElement
- otherwise sqlExpression
+ otherwise SqlExpression
}
func (chose *xmlChoseElement) String() string {
@@ -598,7 +672,7 @@ type xmlForEachElement struct {
index string
collection string
openTag, separatorTag, closeTag string
- contents []sqlExpression
+ contents []SqlExpression
}
func (foreach *xmlForEachElement) String() string {
@@ -643,6 +717,8 @@ func hasXMLTag(sqlStr string) bool {
"= 0 {
trueExprs = segements[:elseIndex]
falseExprs = segements[elseIndex+1:]
@@ -782,7 +782,7 @@ type choseExpression struct {
el xmlChoseElement
when []whenExpression
- otherwise sqlExpression
+ otherwise SqlExpression
}
func (chose *choseExpression) String() string {
@@ -819,7 +819,7 @@ func (chose *choseExpression) writeTo(printer *sqlPrinter) {
type whenExpression struct {
test *govaluate.EvaluableExpression
- expression sqlExpression
+ expression SqlExpression
}
func (ifExpr whenExpression) String() string {
@@ -844,7 +844,7 @@ func (ifExpr whenExpression) String() string {
// }
// }
-func newChoseExpression(el xmlChoseElement) (sqlExpression, error) {
+func newChoseExpression(el xmlChoseElement) (SqlExpression, error) {
var when []whenExpression
for idx := range el.when {
@@ -872,7 +872,7 @@ func newChoseExpression(el xmlChoseElement) (sqlExpression, error) {
type forEachExpression struct {
el xmlForEachElement
- segements []sqlExpression
+ segements []SqlExpression
}
func (foreach *forEachExpression) String() string {
@@ -1012,7 +1012,7 @@ func (foreach *forEachExpression) writeTo(printer *sqlPrinter) {
}
}
-func newForEachExpression(el xmlForEachElement) (sqlExpression, error) {
+func newForEachExpression(el xmlForEachElement) (SqlExpression, error) {
if len(el.contents) == 0 {
return nil, errors.New("contents of foreach is empty")
}
@@ -1150,7 +1150,7 @@ func (set *setExpression) writeTo(printer *sqlPrinter) {
}
}
-type expressionArray []sqlExpression
+type expressionArray []SqlExpression
func (expressions expressionArray) String() string {
var sb strings.Builder
@@ -1539,9 +1539,9 @@ func (expr orderByExpression) writeTo(printer *sqlPrinter) {
type trimExpression struct {
expressions expressionArray
prefixoverride []string
- prefix sqlExpression
+ prefix SqlExpression
suffixoverride []string
- suffix sqlExpression
+ suffix SqlExpression
}
func (expr trimExpression) String() string {
@@ -1688,8 +1688,8 @@ type valueRangeExpression struct {
field string
value string
- prefix sqlExpression
- suffix sqlExpression
+ prefix SqlExpression
+ suffix SqlExpression
}
func (expr valueRangeExpression) String() string {
diff --git a/core/xml_test.go b/core/xml_test.go
index 285332d..f2f2f21 100644
--- a/core/xml_test.go
+++ b/core/xml_test.go
@@ -71,11 +71,14 @@ func TestXmlOk(t *testing.T) {
}
var query *Query = nil
- initCtx := &core.InitContext{Config: cfg,
- // Tracer: cfg.Tracer,
- Dialect: dialects.Postgres,
- Mapper: core.CreateMapper("", nil, nil),
- Statements: make(map[string]*core.MappedStatement)}
+ initCtx := &core.StmtContext{
+ InitContext: &core.InitContext{Config: cfg,
+ // Tracer: cfg.Tracer,
+ Dialect: dialects.Postgres,
+ Mapper: core.CreateMapper("", nil, nil),
+ Statements: make(map[string]*core.MappedStatement),
+ },
+ }
for idx, test := range []xmlCase{
// {
@@ -1189,7 +1192,8 @@ type xmlErrCase struct {
}
func TestXmlFail(t *testing.T) {
- cfg := &core.Config{DriverName: "postgres",
+ cfg := &core.Config{
+ DriverName: "postgres",
DataSource: "aa",
XMLPaths: []string{"tests",
"../tests",
@@ -1200,11 +1204,14 @@ func TestXmlFail(t *testing.T) {
Tracer: core.StdLogger{Logger: log.New(os.Stdout, "[gobatis] ", log.Flags())},
}
- initCtx := &core.InitContext{Config: cfg,
- // Logger: cfg.Logger,
- Dialect: dialects.Postgres,
- Mapper: core.CreateMapper("", nil, nil),
- Statements: make(map[string]*core.MappedStatement)}
+ initCtx := &core.StmtContext{
+ InitContext: &core.InitContext{Config: cfg,
+ // Logger: cfg.Logger,
+ Dialect: dialects.Postgres,
+ Mapper: core.CreateMapper("", nil, nil),
+ Statements: make(map[string]*core.MappedStatement),
+ },
+ }
for idx, test := range []xmlErrCase{
{
@@ -1370,11 +1377,15 @@ func TestXmlExpressionOk(t *testing.T) {
Tracer: core.StdLogger{Logger: log.New(os.Stdout, "[gobatis] ", log.Flags())},
}
- initCtx := &core.InitContext{Config: cfg,
- // Logger: cfg.Logger,
- Dialect: dialects.Postgres,
- Mapper: core.CreateMapper("", nil, nil),
- Statements: make(map[string]*core.MappedStatement)}
+ initCtx := &core.StmtContext{
+ InitContext: &core.InitContext{
+ Config: cfg,
+ // Logger: cfg.Logger,
+ Dialect: dialects.Postgres,
+ Mapper: core.CreateMapper("", nil, nil),
+ Statements: make(map[string]*core.MappedStatement),
+ },
+ }
for idx, test := range []xmlCase{
{
@@ -1521,11 +1532,15 @@ func TestXmlExpressionFail(t *testing.T) {
Tracer: core.StdLogger{Logger: log.New(os.Stdout, "[gobatis] ", log.Flags())},
}
- initCtx := &core.InitContext{Config: cfg,
- // Logger: cfg.Logger,
- Dialect: dialects.Postgres,
- Mapper: core.CreateMapper("", nil, nil),
- Statements: make(map[string]*core.MappedStatement)}
+ initCtx := &core.StmtContext{
+ InitContext: &core.InitContext{
+ Config: cfg,
+ // Logger: cfg.Logger,
+ Dialect: dialects.Postgres,
+ Mapper: core.CreateMapper("", nil, nil),
+ Statements: make(map[string]*core.MappedStatement),
+ },
+ }
for idx, test := range []xmlErrCase{
{
diff --git a/example_xml/example_test.go b/example_xml/example_test.go
index 8397fc9..06ffa2b 100644
--- a/example_xml/example_test.go
+++ b/example_xml/example_test.go
@@ -83,6 +83,17 @@ func ExampleSimple() {
fmt.Println("fetch user from database!")
fmt.Println(u.Nickname)
+
+ list, err := userDao.SelectAll("", "all", nil)
+ if err != nil {
+ fmt.Println(err)
+ return
+ }
+ fmt.Println("fetch all user from database!")
+ for _, u := range list {
+ fmt.Println(u.Nickname)
+ }
+
_, err = userDao.DeleteByID(id)
if err != nil {
fmt.Println(err)
@@ -97,5 +108,7 @@ func ExampleSimple() {
// update user: 1
// fetch user from database!
// ABC
+ // fetch all user from database!
+ // ABC
// delete success!
}
diff --git a/example_xml/xmlfiles/dm/test.xml b/example_xml/xmlfiles/dm/test.xml
index 1edc4cc..73023c3 100644
--- a/example_xml/xmlfiles/dm/test.xml
+++ b/example_xml/xmlfiles/dm/test.xml
@@ -3,10 +3,14 @@
+
+ SELECT * FROM gobatis_users
+