Skip to content

Commit

Permalink
fix some bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
hengwei-test committed May 27, 2024
1 parent 4b6cd9d commit 96bec49
Show file tree
Hide file tree
Showing 9 changed files with 298 additions and 14 deletions.
4 changes: 2 additions & 2 deletions cmd/gobatis/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func (cmd *Generator) generateInterfaceInit(out io.Writer, file *goparser2.File,
}

io.WriteString(out, "\r\n"+` var sqlExpressions = ctx.SqlExpressions`)
io.WriteString(out, "\r\n"+` ctx.SqlExpressions = map[string]*gobatis.SqlExpression{}`)
io.WriteString(out, "\r\n"+` ctx.SqlExpressions = map[string]gobatis.SqlExpression{}`)
io.WriteString(out, "\r\n"+` for id, expr := range sqlExpressions {`)
io.WriteString(out, "\r\n"+` ctx.SqlExpressions[id] = expr`)
io.WriteString(out, "\r\n"+` }`)
Expand Down Expand Up @@ -282,7 +282,7 @@ func (cmd *Generator) generateInterfaceInit(out io.Writer, file *goparser2.File,
io.WriteString(out, preprocessingSQL("sqlStr", false, dialect.SQL, recordTypeName))
}
io.WriteString(out, "\r\n}")
io.WriteString(out, "\r\n"+` expr, err := gobatis.NewSqlExpression(ctx, sqlstr)
io.WriteString(out, "\r\n"+` expr, err := gobatis.NewSqlExpression(ctx, sqlStr)
if err != nil {
return err
}
Expand Down
59 changes: 59 additions & 0 deletions core/expr_gval.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,63 @@ var expFunctions = []gval.Language{
gval.Function("isNull", isNull),
gval.Function("isnotnull", isNotNull),
gval.Function("isNotNull", isNotNull),
gval.Constant("nil", nil),
gval.Constant("null", nil),


gval.InfixOperator("==", func(a, b interface{}) (interface{}, error) {
if a == nil {
if b == nil {
return true, nil
}
return isNilValue(b), nil
}

if b == nil {
return isNilValue(a), nil
}

return reflect.DeepEqual(a, b), nil
}),
gval.InfixOperator("!=", func(a, b interface{}) (interface{}, error) {
if a == nil {
if b == nil {
return false, nil
}
return isNotNilValue(b), nil
}
if b == nil {
return isNotNilValue(a), nil
}

return !reflect.DeepEqual(a, b), nil
}),
}

func isNilValue(v interface{}) bool {
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Interface {
rv = rv.Elem()
}
if rv.Kind() != reflect.Ptr &&
rv.Kind() != reflect.Map &&
rv.Kind() != reflect.Slice {
return false
}
return rv.IsNil()
}

func isNotNilValue(v interface{}) bool {
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Interface {
rv = rv.Elem()
}
if rv.Kind() != reflect.Ptr &&
rv.Kind() != reflect.Map &&
rv.Kind() != reflect.Slice {
return true
}
return !rv.IsNil()
}

func RegisterExprFunction(name string, fn func(args ...interface{}) (interface{}, error)) {
Expand All @@ -294,6 +351,8 @@ type gvalSelector struct {
}

func (gs gvalSelector) SelectGVal(c context.Context, key string) (interface{}, error) {
// value, err := gs.get.Get(key)
// fmt.Println(key, value, err)
return gs.get.Get(key)
}

Expand Down
73 changes: 73 additions & 0 deletions core/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,79 @@ func TestReadOnly(t *testing.T) {
})
}

func TestQueryWithUserQuery(t *testing.T) {
tests.Run(t, func(_ testing.TB, factory *core.Session) {
mac, _ := net.ParseMAC("01:02:03:04:A5:A6")
ip := net.ParseIP("192.168.1.1")
name := "张三"
insertUser := tests.User{
Name: name,
Nickname: "haha",
Password: "password",
Description: "地球人",
Address: "沪南路1155号",
HostIP: ip,
HostMAC: mac,
HostIPPtr: &ip,
HostMACPtr: &mac,
Sex: "女",
ContactInfo: map[string]interface{}{"QQ": "8888888"},
Birth: time.Now(),
CreateTime: time.Now(),
}
users := tests.NewTestUsers(factory.SessionReference())


_, err := users.Insert(&insertUser)
if err != nil {
t.Error(err)
return
}

assertList := func(t testing.TB, list []tests.User) {
if len(list) != 1 {
t.Error("want 1 got", len(list))
return
}

if list[0].Name != name {
t.Error("want '"+name+"' got", list[0].Name)
return
}
}

list, err := users.QueryWithUserQuery1(tests.UserQuery{UseUsername: true, Username: name})
if err != nil {
t.Error(err)
return
}
assertList(t, list)

list, err = users.QueryWithUserQuery2(tests.UserQuery{UseUsername: true, Username: name})
if err != nil {
t.Error(err)
return
}
assertList(t, list)

list, err = users.QueryWithUserQuery3(tests.UserQuery{UseUsername: true, Username: name})
if err != nil {
t.Error(err)
return
}
assertList(t, list)


list, err = users.QueryWithUserQuery4(tests.UserQuery{UseUsername: true, Username: name})
if err != nil {
t.Error(err)
return
}
assertList(t, list)

})
}

func TestHandleError(t *testing.T) {
tests.Run(t, func(_ testing.TB, factory *core.Session) {
if factory.Dialect() != dialects.Postgres {
Expand Down
13 changes: 13 additions & 0 deletions core/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,19 @@ func (stmt *MappedStatement) GenerateSQLs(ctx *Context) ([]sqlAndParam, error) {
}

func NewMapppedStatement(ctx *StmtContext, id string, statementType StatementType, resultType ResultType, sqlStr string) (*MappedStatement, error) {
if ctx.FindSqlFragment == nil {
sqlExpressions := ctx.InitContext.SqlExpressions
ctx.FindSqlFragment = func(id string) (SqlExpression, error) {
if sqlExpressions != nil {
sf := sqlExpressions[id]
if sf != nil {
return sf, nil
}
}
return nil, errors.New("sql '" + id + "' missing")
}
}

stmt := &MappedStatement{
id: id,
sqlType: statementType,
Expand Down
18 changes: 16 additions & 2 deletions core/xml_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ func (ifExpr ifExpression) String() string {
func (ifExpr ifExpression) writeTo(printer *sqlPrinter) {
bResult, err := ifExpr.test.Test(evalParameters{ctx: printer.ctx})
if err != nil {
printer.err = err
printer.err = errors.New("execute `"+ifExpr.test.String()+"` fail, " + err.Error())
return
}

Expand Down Expand Up @@ -543,7 +543,7 @@ func (chose *choseExpression) writeTo(printer *sqlPrinter) {
for idx := range chose.when {
bResult, err := chose.when[idx].test.Test(evalParameters{ctx: printer.ctx})
if err != nil {
printer.err = err
printer.err = errors.New("execute `"+chose.when[idx].test.String()+"` fail, " + err.Error())
return
}

Expand Down Expand Up @@ -1854,6 +1854,20 @@ type nestParameters struct {
values map[string]func(name string) (interface{}, error)
}


func (s nestParameters) RValue(dialect Dialect, param *Param) (interface{}, error) {
get, ok := s.values[param.Name]
if ok {
value, err := get(param.Name)
if err != nil {
return nil, err
}
return toSQLType(dialect, param, value)
}
return s.Parameters.RValue(dialect, param)
}


func (s nestParameters) Get(name string) (interface{}, error) {
get, ok := s.values[name]
if ok {
Expand Down
Loading

0 comments on commit 96bec49

Please sign in to comment.