Skip to content

Commit

Permalink
fix foreach
Browse files Browse the repository at this point in the history
  • Loading branch information
hengwei-test committed May 30, 2024
1 parent 8bb88cc commit 63b5eed
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 23 deletions.
89 changes: 68 additions & 21 deletions core/parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ func (sf structFinder) RValue(dialect Dialect, param *Param) (interface{}, error
}

type kvFinder struct {
Parameters Parameters
mapper *Mapper
paramNames []string
paramValues []interface{}
Expand Down Expand Up @@ -186,6 +187,24 @@ func (kvf *kvFinder) RValue(dialect Dialect, param *Param) (interface{}, error)
return kvf.get(param.Name, rvalueGetter{dialect: dialect, param: param})
}



func (kvf *kvFinder) getByIndex(foundIdx int) (reflect.Value, bool) {
if kvf.cachedValues == nil {
kvf.cachedValues = make([]reflect.Value, len(kvf.paramValues))
}
rValue := kvf.cachedValues[foundIdx]
if !rValue.IsValid() {
value := kvf.paramValues[foundIdx]
if value == nil {
return reflect.Value{}, false
}
rValue = reflect.ValueOf(value)
kvf.cachedValues[foundIdx] = rValue
}
return rValue, true
}

func (kvf *kvFinder) get(name string, getter valueGetter) (interface{}, error) {
foundIdx := -1
for idx := range kvf.paramNames {
Expand All @@ -199,6 +218,19 @@ func (kvf *kvFinder) get(name string, getter valueGetter) (interface{}, error) {
return getter.value(kvf.paramValues[foundIdx])
}

if kvf.Parameters != nil {
value, err := kvf.Parameters.Get(name)
if err == nil {
return value, nil
}

if err != ErrNotFound {
return nil, err
}

return getter.value(value)
}

dotIndex := strings.IndexByte(name, '.')
if dotIndex < 0 {
// 这里的是为下面情况的特殊处理
Expand All @@ -211,32 +243,47 @@ func (kvf *kvFinder) get(name string, getter valueGetter) (interface{}, error) {
}
dotIndex = -1
foundIdx = 0
} else {
fieldName := name[:dotIndex]
for idx := range kvf.paramNames {
if kvf.paramNames[idx] == fieldName {
foundIdx = idx
break
}
}
if foundIdx < 0 {
return nil, ErrNotFound


rValue, ok := kvf.getByIndex(foundIdx)
if !ok {
return nil, ErrNotFound // errors.New("canot read param '" + name[dotIndex+1:] + "', param '" + name[:dotIndex+1] + "' is nil")
}

return kvf.getFieldValue(name[dotIndex+1:], rValue, getter)
}

if kvf.cachedValues == nil {
kvf.cachedValues = make([]reflect.Value, len(kvf.paramValues))
fieldName := name[:dotIndex]
for idx := range kvf.paramNames {
if kvf.paramNames[idx] == fieldName {
foundIdx = idx
break
}
}
rValue := kvf.cachedValues[foundIdx]
if !rValue.IsValid() {
value := kvf.paramValues[foundIdx]
if value == nil {
if foundIdx >= 0 {
rValue, ok := kvf.getByIndex(foundIdx)
if !ok {
return nil, ErrNotFound // errors.New("canot read param '" + name[dotIndex+1:] + "', param '" + name[:dotIndex+1] + "' is nil")
}
rValue = reflect.ValueOf(value)
kvf.cachedValues[foundIdx] = rValue
return kvf.getFieldValue(name[dotIndex+1:], rValue, getter)
}

if kvf.Parameters == nil {
return nil, ErrNotFound
}

value, err := kvf.Parameters.Get(fieldName)
if err != nil {
return nil, err
}
if value == nil {
return nil, ErrNotFound
}

return kvf.getFieldValue(name[dotIndex+1:], reflect.ValueOf(value), getter)
}

func (kvf *kvFinder) getFieldValue(fieldName string,rValue reflect.Value, getter valueGetter) (interface{}, error) {
kind := rValue.Kind()
if kind == reflect.Ptr {
kind = rValue.Type().Elem().Kind()
Expand All @@ -251,7 +298,7 @@ func (kvf *kvFinder) get(name string, getter valueGetter) (interface{}, error) {
return nil, ErrNotFound // errors.New("canot read param '" + name[:dotIndex+1] + "', param '" + name[:dotIndex+1] + "' is nil")
}

value := rValue.MapIndex(reflect.ValueOf(name[dotIndex+1:]))
value := rValue.MapIndex(reflect.ValueOf(fieldName))
if !value.IsValid() {
return getter.value(nil)
// return nil, ErrNotFound //errors.New("canot read param '" + name[:dotIndex+1] + "', param '" + name[:dotIndex+1] + "' is nil")
Expand All @@ -266,9 +313,9 @@ func (kvf *kvFinder) get(name string, getter valueGetter) (interface{}, error) {

tm := kvf.mapper.TypeMap(rValue.Type())

fi, ok := tm.FieldNames[name[dotIndex+1:]]
fi, ok := tm.FieldNames[fieldName]
if !ok {
fi, ok = tm.Names[name[dotIndex+1:]]
fi, ok = tm.Names[fieldName]
if !ok {
return nil, ErrNotFound
}
Expand Down
3 changes: 1 addition & 2 deletions core/xml_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,7 @@ func (foreach *forEachExpression) execOne(printer *sqlPrinter, key, value interf
// newPrinter.ctx = &ctx

printer.ctx.finder = &kvFinder{
Parameters: oldFinder,
mapper: printer.ctx.Mapper,
paramNames: []string{foreach.el.item, foreach.el.index},
paramValues: []interface{}{value, key},
Expand Down Expand Up @@ -1859,7 +1860,6 @@ type nestParameters struct {
values map[string]func(name string) (interface{}, error)
}


func (s nestParameters) RValue(dialect Dialect, param *Param) (interface{}, error) {
get, ok := s.values[param.Name]
if ok {
Expand All @@ -1872,7 +1872,6 @@ func (s nestParameters) RValue(dialect Dialect, param *Param) (interface{}, erro
return s.Parameters.RValue(dialect, param)
}


func (s nestParameters) Get(name string) (interface{}, error) {
get, ok := s.values[name]
if ok {
Expand Down
22 changes: 22 additions & 0 deletions core/xml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,17 @@ func TestXmlOk(t *testing.T) {
exceptedSQL: "aa ($1,$2,$3,$4)",
execeptedParams: []interface{}{0, 1, 2, 3},
},

{
name: "foreach array parent scope value",
sql: `aa <foreach collection="aa" index="index" item="item" open="(" separator="," close=")">#{item},#{parentValue}</foreach>`,
paramNames: []string{"aa", "parentValue"},
paramValues: []interface{}{[]interface{}{"a"}, "apr"},
exceptedSQL: "aa ($1,$2)",
execeptedParams: []interface{}{"a", "apr"},
isUnsortable: true,
},

{
name: "foreach map index",
sql: `aa <foreach collection="aa" index="index" item="item" open="(" separator="," close=")">#{index}</foreach>`,
Expand Down Expand Up @@ -564,6 +575,17 @@ func TestXmlOk(t *testing.T) {
execeptedParams: []interface{}{"a", "b", "c", "d"},
isUnsortable: true,
},


{
name: "foreach map parent scope value",
sql: `aa <foreach collection="aa" index="index" item="item" open="(" separator="," close=")">#{item},#{parentValue}</foreach>`,
paramNames: []string{"aa", "parentValue"},
paramValues: []interface{}{map[interface{}]interface{}{"1": "a"}, "apr"},
exceptedSQL: "aa ($1,$2)",
execeptedParams: []interface{}{"a", "apr"},
isUnsortable: true,
},
{
name: "choose ok",
sql: `aa <choose><when test="a==1">#{i1}</when></choose>`,
Expand Down

0 comments on commit 63b5eed

Please sign in to comment.