diff --git a/ast/node.go b/ast/node.go index dfdf715f..6d98090b 100644 --- a/ast/node.go +++ b/ast/node.go @@ -183,13 +183,13 @@ type BuiltinNode struct { Map Node // Used by optimizer to fold filter() and map() builtins. } -// ClosureNode represents a predicate. +// PredicateNode represents a predicate. // Example: // // filter(foo, .bar == 1) // // The predicate is ".bar == 1". -type ClosureNode struct { +type PredicateNode struct { base Node Node // Node of the predicate body. } diff --git a/ast/print.go b/ast/print.go index 6a7d698a..f5937715 100644 --- a/ast/print.go +++ b/ast/print.go @@ -162,7 +162,7 @@ func (n *BuiltinNode) String() string { return fmt.Sprintf("%s(%s)", n.Name, strings.Join(arguments, ", ")) } -func (n *ClosureNode) String() string { +func (n *PredicateNode) String() string { return n.Node.String() } diff --git a/ast/visitor.go b/ast/visitor.go index 90bc9f1d..03d72d11 100644 --- a/ast/visitor.go +++ b/ast/visitor.go @@ -45,7 +45,7 @@ func Walk(node *Node, v Visitor) { for i := range n.Arguments { Walk(&n.Arguments[i], v) } - case *ClosureNode: + case *PredicateNode: Walk(&n.Node, v) case *PointerNode: case *VariableDeclaratorNode: diff --git a/checker/checker.go b/checker/checker.go index fae8f5a1..0ad22f4e 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -152,8 +152,8 @@ func (v *checker) visit(node ast.Node) Nature { nt = v.CallNode(n) case *ast.BuiltinNode: nt = v.BuiltinNode(n) - case *ast.ClosureNode: - nt = v.ClosureNode(n) + case *ast.PredicateNode: + nt = v.PredicateNode(n) case *ast.PointerNode: nt = v.PointerNode(n) case *ast.VariableDeclaratorNode: @@ -194,19 +194,14 @@ func (v *checker) IdentifierNode(node *ast.IdentifierNode) Nature { if node.Value == "$env" { return unknown } - return v.ident(node, node.Value, true, true) + + return v.ident(node, node.Value, v.config.Env.Strict, true) } // ident method returns type of environment variable, builtin or function. func (v *checker) ident(node ast.Node, name string, strict, builtins bool) Nature { - if t, ok := v.config.Types[name]; ok { - if t.Ambiguous { - return v.error(node, "ambiguous identifier %v", name) - } - if t.Type == nil { - return nilNature - } - return Nature{Type: t.Type, Method: t.Method} + if nt, ok := v.config.Env.Get(name); ok { + return nt } if builtins { if fn, ok := v.config.Functions[name]; ok { @@ -219,9 +214,6 @@ func (v *checker) ident(node ast.Node, name string, strict, builtins bool) Natur if v.config.Strict && strict { return v.error(node, "unknown name %v", name) } - if v.config.DefaultType != nil { - return Nature{Type: v.config.DefaultType} - } return unknown } @@ -419,16 +411,10 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature { case "..": if isInteger(l) && isInteger(r) { - return Nature{ - Type: arrayType, - SubType: Array{Of: integerNature}, - } + return arrayOf(integerNature) } if or(l, r, isInteger) { - return Nature{ - Type: arrayType, - SubType: Array{Of: integerNature}, - } + return arrayOf(integerNature) } case "??": @@ -501,6 +487,13 @@ func (v *checker) MemberNode(node *ast.MemberNode) Nature { if !prop.AssignableTo(base.Key()) && !isUnknown(prop) { return v.error(node.Property, "cannot use %v to get an element from %v", prop, base) } + if prop, ok := node.Property.(*ast.StringNode); ok { + if field, ok := base.Fields[prop.Value]; ok { + return field + } else if base.Strict { + return v.error(node.Property, "unknown field %v", prop.Value) + } + } return base.Elem() case reflect.Array, reflect.Slice: @@ -512,7 +505,7 @@ func (v *checker) MemberNode(node *ast.MemberNode) Nature { case reflect.Struct: if name, ok := node.Property.(*ast.StringNode); ok { propertyName := name.Value - if field, ok := fetchField(base, propertyName); ok { + if field, ok := base.FieldByName(propertyName); ok { return Nature{Type: field.Type} } if node.Method { @@ -625,15 +618,15 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { } v.begin(collection) - closure := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isUnknown(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - if !isBool(closure.Out(0)) && !isUnknown(closure.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) + if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) } return boolNature } @@ -646,23 +639,20 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { } v.begin(collection) - closure := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isUnknown(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - if !isBool(closure.Out(0)) && !isUnknown(closure.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) + if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) } if isUnknown(collection) { return arrayNature } - return Nature{ - Type: arrayType, - SubType: Array{Of: collection.Elem()}, - } + return arrayOf(collection.Elem()) } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -673,17 +663,14 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { } v.begin(collection, scopeVar{"index", integerNature}) - closure := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isUnknown(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - return Nature{ - Type: arrayType, - SubType: Array{Of: closure.Out(0)}, - } + return arrayOf(*predicate.PredicateOut) } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -698,14 +685,14 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { } v.begin(collection) - closure := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isUnknown(closure.In(0)) { - if !isBool(closure.Out(0)) && !isUnknown(closure.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) } return integerNature @@ -720,13 +707,13 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { if len(node.Arguments) == 2 { v.begin(collection) - closure := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isUnknown(closure.In(0)) { - return closure.Out(0) + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + return predicate.Out(0) } } else { if isUnknown(collection) { @@ -742,15 +729,15 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { } v.begin(collection) - closure := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isUnknown(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - if !isBool(closure.Out(0)) && !isUnknown(closure.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) + if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) } if isUnknown(collection) { return unknown @@ -766,15 +753,15 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { } v.begin(collection) - closure := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isUnknown(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - if !isBool(closure.Out(0)) && !isUnknown(closure.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) + if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) } return integerNature } @@ -787,14 +774,15 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { } v.begin(collection) - closure := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isUnknown(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - return Nature{Type: reflect.TypeOf(map[any][]any{})} + groups := arrayOf(collection.Elem()) + return Nature{Type: reflect.TypeOf(map[any][]any{}), ArrayOf: &groups} } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -805,18 +793,18 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { } v.begin(collection) - closure := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() if len(node.Arguments) == 3 { _ = v.visit(node.Arguments[2]) } - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isUnknown(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - return Nature{Type: reflect.TypeOf([]any{})} + return collection } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -827,15 +815,15 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { } v.begin(collection, scopeVar{"index", integerNature}, scopeVar{"acc", unknown}) - closure := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() if len(node.Arguments) == 3 { _ = v.visit(node.Arguments[2]) } - if isFunc(closure) && closure.NumOut() == 1 { - return closure.Out(0) + if isFunc(predicate) && predicate.NumOut() == 1 { + return *predicate.PredicateOut } return v.error(node.Arguments[1], "predicate should has two input and one output param") @@ -879,7 +867,9 @@ func (v *checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { if id, ok := node.Arguments[0].(*ast.IdentifierNode); ok && id.Value == "$env" { if s, ok := node.Arguments[1].(*ast.StringNode); ok { - return Nature{Type: v.config.Types[s.Value].Type} + if nt, ok := v.config.Env.Get(s.Value); ok { + return nt + } } return unknown } @@ -1106,24 +1096,23 @@ func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newNature Na } } -func (v *checker) ClosureNode(node *ast.ClosureNode) Nature { +func (v *checker) PredicateNode(node *ast.PredicateNode) Nature { nt := v.visit(node.Node) - var out reflect.Type + var out []reflect.Type if isUnknown(nt) { - out = anyType - } else { - out = nt.Type + out = append(out, anyType) + } else if !isNil(nt) { + out = append(out, nt.Type) + } + return Nature{ + Type: reflect.FuncOf([]reflect.Type{anyType}, out, false), + PredicateOut: &nt, } - return Nature{Type: reflect.FuncOf( - []reflect.Type{anyType}, - []reflect.Type{out}, - false, - )} } func (v *checker) PointerNode(node *ast.PointerNode) Nature { if len(v.predicateScopes) == 0 { - return v.error(node, "cannot use pointer accessor outside closure") + return v.error(node, "cannot use pointer accessor outside predicate") } scope := v.predicateScopes[len(v.predicateScopes)-1] if node.Name == "" { @@ -1145,7 +1134,7 @@ func (v *checker) PointerNode(node *ast.PointerNode) Nature { } func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) Nature { - if _, ok := v.config.Types[node.Name]; ok { + if _, ok := v.config.Env.Get(node.Name); ok { return v.error(node, "cannot redeclare %v", node.Name) } if _, ok := v.config.Functions[node.Name]; ok { @@ -1210,10 +1199,7 @@ func (v *checker) ArrayNode(node *ast.ArrayNode) Nature { prev = curr } if allElementsAreSameType { - return Nature{ - Type: arrayNature.Type, - SubType: Array{Of: prev}, - } + return arrayOf(prev) } return arrayNature } diff --git a/checker/checker_test.go b/checker/checker_test.go index ae42392c..0639f1f4 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -9,6 +9,7 @@ import ( "github.com/expr-lang/expr/internal/testify/assert" "github.com/expr-lang/expr/internal/testify/require" + "github.com/expr-lang/expr/types" "github.com/expr-lang/expr" "github.com/expr-lang/expr/ast" @@ -758,26 +759,6 @@ func TestCheck_TaggedFieldName(t *testing.T) { assert.NoError(t, err) } -func TestCheck_Ambiguous(t *testing.T) { - type A struct { - Ambiguous bool - } - type B struct { - Ambiguous int - } - type Env struct { - A - B - } - - tree, err := parser.Parse(`Ambiguous == 1`) - require.NoError(t, err) - - _, err = checker.Check(tree, conf.New(Env{})) - assert.Error(t, err) - assert.Contains(t, err.Error(), "ambiguous identifier Ambiguous") -} - func TestCheck_NoConfig(t *testing.T) { tree, err := parser.Parse(`any`) require.NoError(t, err) @@ -831,7 +812,7 @@ func TestCheck_AllowUndefinedVariables_OptionalChaining(t *testing.T) { func TestCheck_PointerNode(t *testing.T) { _, err := checker.Check(&parser.Tree{Node: &ast.PointerNode{}}, nil) assert.Error(t, err) - assert.Contains(t, err.Error(), "cannot use pointer accessor outside closure") + assert.Contains(t, err.Error(), "cannot use pointer accessor outside predicate") } func TestCheck_TypeWeights(t *testing.T) { @@ -1088,3 +1069,41 @@ func TestCheck_builtin_without_call(t *testing.T) { }) } } + +func TestCheck_types(t *testing.T) { + env := types.Map{ + "foo": types.StrictMap{ + "bar": types.Map{ + "baz": "", + }, + }, + } + + noerr := "no error" + tests := []struct { + code string + err string + }{ + {`unknown`, noerr}, + {`foo.bar.baz > 0`, `invalid operation: > (mismatched types string and int)`}, + {`foo.unknown.baz`, `unknown field unknown (1:5)`}, + {`foo.bar.unknown`, noerr}, + {`[foo] | map(.unknown)`, `unknown field unknown`}, + {`[foo] | map(.bar) | filter(.baz)`, `predicate should return boolean (got string)`}, + } + + for _, test := range tests { + t.Run(test.code, func(t *testing.T) { + tree, err := parser.Parse(test.code) + require.NoError(t, err) + + _, err = checker.Check(tree, conf.New(env)) + if test.err == noerr { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), test.err) + } + }) + } +} diff --git a/checker/info.go b/checker/info.go index 3c9396fd..f1cc92eb 100644 --- a/checker/info.go +++ b/checker/info.go @@ -4,15 +4,17 @@ import ( "reflect" "github.com/expr-lang/expr/ast" - "github.com/expr-lang/expr/conf" + . "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/vm" ) -func FieldIndex(types conf.TypesTable, node ast.Node) (bool, []int, string) { +func FieldIndex(env Nature, node ast.Node) (bool, []int, string) { switch n := node.(type) { case *ast.IdentifierNode: - if t, ok := types[n.Value]; ok && len(t.FieldIndex) > 0 { - return true, t.FieldIndex, n.Value + if env.Kind() == reflect.Struct { + if field, ok := env.Get(n.Value); ok && len(field.FieldIndex) > 0 { + return true, field.FieldIndex, n.Value + } } case *ast.MemberNode: base := n.Node.Nature() @@ -20,8 +22,8 @@ func FieldIndex(types conf.TypesTable, node ast.Node) (bool, []int, string) { if base.Kind() == reflect.Struct { if prop, ok := n.Property.(*ast.StringNode); ok { name := prop.Value - if field, ok := fetchField(base, name); ok { - return true, field.Index, name + if field, ok := base.FieldByName(name); ok { + return true, field.FieldIndex, name } } } @@ -29,11 +31,13 @@ func FieldIndex(types conf.TypesTable, node ast.Node) (bool, []int, string) { return false, nil, "" } -func MethodIndex(types conf.TypesTable, node ast.Node) (bool, int, string) { +func MethodIndex(env Nature, node ast.Node) (bool, int, string) { switch n := node.(type) { case *ast.IdentifierNode: - if t, ok := types[n.Value]; ok { - return t.Method, t.MethodIndex, n.Value + if env.Kind() == reflect.Struct { + if m, ok := env.Get(n.Value); ok { + return m.Method, m.MethodIndex, n.Value + } } case *ast.MemberNode: if name, ok := n.Property.(*ast.StringNode); ok { diff --git a/checker/nature/nature.go b/checker/nature/nature.go index a7365998..a385521c 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -12,16 +12,19 @@ var ( ) type Nature struct { - Type reflect.Type - SubType SubType - Func *builtin.Function - Method bool + Type reflect.Type // Type of the value. If nil, then value is unknown. + Func *builtin.Function // Used to pass function type from callee to CallNode. + ArrayOf *Nature // Elem nature of array type (usually Type is []any, but ArrayOf can be any nature). + PredicateOut *Nature // Out nature of predicate. + Fields map[string]Nature // Fields of map type. + Strict bool // If map is types.StrictMap. + Nil bool // If value is nil. + Method bool // If value retrieved from method. Usually used to determine amount of in arguments. + MethodIndex int // Index of method in type. + FieldIndex []int // Index of field in type. } func (n Nature) String() string { - if n.SubType != nil { - return n.SubType.String() - } if n.Type != nil { return n.Type.String() } @@ -54,8 +57,8 @@ func (n Nature) Elem() Nature { case reflect.Map, reflect.Ptr: return Nature{Type: n.Type.Elem()} case reflect.Array, reflect.Slice: - if array, ok := n.SubType.(Array); ok { - return array.Of + if n.ArrayOf != nil { + return *n.ArrayOf } return Nature{Type: n.Type.Elem()} } @@ -63,6 +66,12 @@ func (n Nature) Elem() Nature { } func (n Nature) AssignableTo(nt Nature) bool { + if n.Nil { + // Untyped nil is assignable to any interface, but implements only the empty interface. + if nt.Type != nil && nt.Type.Kind() == reflect.Interface { + return true + } + } if n.Type == nil || nt.Type == nil { return false } @@ -88,24 +97,14 @@ func (n Nature) MethodByName(name string) (Nature, bool) { // the same interface. return Nature{Type: method.Type}, true } else { - return Nature{Type: method.Type, Method: true}, true + return Nature{ + Type: method.Type, + Method: true, + MethodIndex: method.Index, + }, true } } -func (n Nature) NumField() int { - if n.Type == nil { - return 0 - } - return n.Type.NumField() -} - -func (n Nature) Field(i int) reflect.StructField { - if n.Type == nil { - return reflect.StructField{} - } - return n.Type.Field(i) -} - func (n Nature) NumIn() int { if n.Type == nil { return 0 @@ -140,3 +139,89 @@ func (n Nature) IsVariadic() bool { } return n.Type.IsVariadic() } + +func (n Nature) FieldByName(name string) (Nature, bool) { + if n.Type == nil { + return unknown, false + } + field, ok := fetchField(n.Type, name) + return Nature{Type: field.Type, FieldIndex: field.Index}, ok +} + +func (n Nature) IsFastMap() bool { + if n.Type == nil { + return false + } + if n.Type.Kind() == reflect.Map && + n.Type.Key().Kind() == reflect.String && + n.Type.Elem().Kind() == reflect.Interface { + return true + } + return false +} + +func (n Nature) Get(name string) (Nature, bool) { + if n.Type == nil { + return unknown, false + } + + if m, ok := n.MethodByName(name); ok { + return m, true + } + + t := deref.Type(n.Type) + + switch t.Kind() { + case reflect.Struct: + if f, ok := fetchField(t, name); ok { + return Nature{ + Type: f.Type, + FieldIndex: f.Index, + }, true + } + case reflect.Map: + if f, ok := n.Fields[name]; ok { + return f, true + } + } + return unknown, false +} + +func (n Nature) All() map[string]Nature { + table := make(map[string]Nature) + + if n.Type == nil { + return table + } + + for i := 0; i < n.Type.NumMethod(); i++ { + method := n.Type.Method(i) + table[method.Name] = Nature{ + Type: method.Type, + Method: true, + MethodIndex: method.Index, + } + } + + t := deref.Type(n.Type) + + switch t.Kind() { + case reflect.Struct: + for name, nt := range StructFields(t) { + if _, ok := table[name]; ok { + continue + } + table[name] = nt + } + + case reflect.Map: + for key, nt := range n.Fields { + if _, ok := table[key]; ok { + continue + } + table[key] = nt + } + } + + return table +} diff --git a/checker/nature/of.go b/checker/nature/of.go new file mode 100644 index 00000000..93235f9c --- /dev/null +++ b/checker/nature/of.go @@ -0,0 +1,47 @@ +package nature + +import ( + "fmt" + "reflect" + + "github.com/expr-lang/expr/types" +) + +func Of(value any) Nature { + if value == nil { + return Nature{Nil: true} + } + + v := reflect.ValueOf(value) + + switch v.Kind() { + case reflect.Map: + _, strict := value.(types.StrictMap) + fields := make(map[string]Nature, v.Len()) + for _, key := range v.MapKeys() { + elem := v.MapIndex(key) + if !elem.IsValid() || !elem.CanInterface() { + panic(fmt.Sprintf("invalid map value: %s", key)) + } + face := elem.Interface() + switch face.(type) { + case types.Map, types.StrictMap: + fields[key.String()] = Of(face) + default: + if face == nil { + fields[key.String()] = Nature{Nil: true} + continue + } + fields[key.String()] = Nature{Type: reflect.TypeOf(face)} + + } + } + return Nature{ + Type: v.Type(), + Fields: fields, + Strict: strict, + } + } + + return Nature{Type: v.Type()} +} diff --git a/checker/nature/types.go b/checker/nature/types.go deleted file mode 100644 index 1f9955e9..00000000 --- a/checker/nature/types.go +++ /dev/null @@ -1,13 +0,0 @@ -package nature - -type SubType interface { - String() string -} - -type Array struct { - Of Nature -} - -func (a Array) String() string { - return "[]" + a.Of.String() -} diff --git a/checker/nature/utils.go b/checker/nature/utils.go new file mode 100644 index 00000000..c242f91a --- /dev/null +++ b/checker/nature/utils.go @@ -0,0 +1,76 @@ +package nature + +import ( + "reflect" + + "github.com/expr-lang/expr/internal/deref" +) + +func fieldName(field reflect.StructField) string { + if taggedName := field.Tag.Get("expr"); taggedName != "" { + return taggedName + } + return field.Name +} + +func fetchField(t reflect.Type, name string) (reflect.StructField, bool) { + // First check all structs fields. + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + // Search all fields, even embedded structs. + if fieldName(field) == name { + return field, true + } + } + + // Second check fields of embedded structs. + for i := 0; i < t.NumField(); i++ { + anon := t.Field(i) + if anon.Anonymous { + anonType := anon.Type + if anonType.Kind() == reflect.Pointer { + anonType = anonType.Elem() + } + if field, ok := fetchField(anonType, name); ok { + field.Index = append(anon.Index, field.Index...) + return field, true + } + } + } + + return reflect.StructField{}, false +} + +func StructFields(t reflect.Type) map[string]Nature { + table := make(map[string]Nature) + + t = deref.Type(t) + if t == nil { + return table + } + + switch t.Kind() { + case reflect.Struct: + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + + if f.Anonymous { + for name, typ := range StructFields(f.Type) { + if _, ok := table[name]; ok { + continue + } + typ.FieldIndex = append(f.Index, typ.FieldIndex...) + table[name] = typ + } + } + + table[fieldName(f)] = Nature{ + Type: f.Type, + FieldIndex: f.Index, + } + + } + } + + return table +} diff --git a/checker/types.go b/checker/types.go index 2eb5392e..ef93cf03 100644 --- a/checker/types.go +++ b/checker/types.go @@ -5,12 +5,11 @@ import ( "time" . "github.com/expr-lang/expr/checker/nature" - "github.com/expr-lang/expr/conf" ) var ( unknown = Nature{} - nilNature = Nature{Type: reflect.TypeOf(Nil{})} + nilNature = Nature{Nil: true} boolNature = Nature{Type: reflect.TypeOf(true)} integerNature = Nature{Type: reflect.TypeOf(0)} floatNature = Nature{Type: reflect.TypeOf(float64(0))} @@ -28,14 +27,15 @@ var ( arrayType = reflect.TypeOf([]any{}) ) -// Nil is a special type to represent nil. -type Nil struct{} +func arrayOf(nt Nature) Nature { + return Nature{ + Type: arrayType, + ArrayOf: &nt, + } +} func isNil(nt Nature) bool { - if nt.Type == nil { - return false - } - return nt.Type == nilNature.Type + return nt.Nil } func combined(l, r Nature) Nature { @@ -72,7 +72,7 @@ func or(l, r Nature, fns ...func(Nature) bool) bool { func isUnknown(nt Nature) bool { switch { - case nt.Type == nil: + case nt.Type == nil && !nt.Nil: return true case nt.Kind() == reflect.Interface: return true @@ -166,34 +166,6 @@ func isFunc(nt Nature) bool { return false } -func fetchField(nt Nature, name string) (reflect.StructField, bool) { - // First check all structs fields. - for i := 0; i < nt.NumField(); i++ { - field := nt.Field(i) - // Search all fields, even embedded structs. - if conf.FieldName(field) == name { - return field, true - } - } - - // Second check fields of embedded structs. - for i := 0; i < nt.NumField(); i++ { - anon := nt.Field(i) - if anon.Anonymous { - anonType := anon.Type - if anonType.Kind() == reflect.Pointer { - anonType = anonType.Elem() - } - if field, ok := fetchField(Nature{Type: anonType}, name); ok { - field.Index = append(anon.Index, field.Index...) - return field, true - } - } - } - - return reflect.StructField{}, false -} - func kind(t reflect.Type) reflect.Kind { if t == nil { return reflect.Invalid diff --git a/compiler/compiler.go b/compiler/compiler.go index 5e993540..68eae5ea 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -9,6 +9,7 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/checker" + . "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" "github.com/expr-lang/expr/parser" @@ -259,8 +260,8 @@ func (c *compiler) compile(node ast.Node) { c.CallNode(n) case *ast.BuiltinNode: c.BuiltinNode(n) - case *ast.ClosureNode: - c.ClosureNode(n) + case *ast.PredicateNode: + c.PredicateNode(n) case *ast.PointerNode: c.PointerNode(n) case *ast.VariableDeclaratorNode: @@ -292,21 +293,19 @@ func (c *compiler) IdentifierNode(node *ast.IdentifierNode) { return } - var mapEnv bool - var types conf.TypesTable + var env Nature if c.config != nil { - mapEnv = c.config.MapEnv - types = c.config.Types + env = c.config.Env } - if mapEnv { + if env.IsFastMap() { c.emit(OpLoadFast, c.addConstant(node.Value)) - } else if ok, index, name := checker.FieldIndex(types, node); ok { + } else if ok, index, name := checker.FieldIndex(env, node); ok { c.emit(OpLoadField, c.addConstant(&runtime.Field{ Index: index, Path: []string{name}, })) - } else if ok, index, name := checker.MethodIndex(types, node); ok { + } else if ok, index, name := checker.MethodIndex(env, node); ok { c.emit(OpLoadMethod, c.addConstant(&runtime.Method{ Name: name, Index: index, @@ -647,12 +646,12 @@ func (c *compiler) ChainNode(node *ast.ChainNode) { } func (c *compiler) MemberNode(node *ast.MemberNode) { - var types conf.TypesTable + var env Nature if c.config != nil { - types = c.config.Types + env = c.config.Env } - if ok, index, name := checker.MethodIndex(types, node); ok { + if ok, index, name := checker.MethodIndex(env, node); ok { c.compile(node.Node) c.emit(OpMethod, c.addConstant(&runtime.Method{ Name: name, @@ -663,14 +662,14 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { op := OpFetch base := node.Node - ok, index, nodeName := checker.FieldIndex(types, node) + ok, index, nodeName := checker.FieldIndex(env, node) path := []string{nodeName} if ok { op = OpFetchField for !node.Optional { if ident, isIdent := base.(*ast.IdentifierNode); isIdent { - if ok, identIndex, name := checker.FieldIndex(types, ident); ok { + if ok, identIndex, name := checker.FieldIndex(env, ident); ok { index = append(identIndex, index...) path = append([]string{name}, path...) c.emitLocation(ident.Location(), OpLoadField, c.addConstant( @@ -681,7 +680,7 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { } if member, isMember := base.(*ast.MemberNode); isMember { - if ok, memberIndex, name := checker.FieldIndex(types, member); ok { + if ok, memberIndex, name := checker.FieldIndex(env, member); ok { index = append(memberIndex, index...) path = append([]string{name}, path...) node = member @@ -730,7 +729,7 @@ func (c *compiler) SliceNode(node *ast.SliceNode) { func (c *compiler) CallNode(node *ast.CallNode) { fn := node.Callee.Type() - if kind(fn) == reflect.Func { + if fn.Kind() == reflect.Func { fnInOffset := 0 fnNumIn := fn.NumIn() switch callee := node.Callee.(type) { @@ -742,7 +741,7 @@ func (c *compiler) CallNode(node *ast.CallNode) { } } case *ast.IdentifierNode: - if t, ok := c.config.Types[callee.Value]; ok && t.Method { + if t, ok := c.config.Env.MethodByName(callee.Value); ok && t.Method { fnInOffset = 1 fnNumIn-- } @@ -777,7 +776,7 @@ func (c *compiler) CallNode(node *ast.CallNode) { } c.compile(node.Callee) - isMethod, _, _ := checker.MethodIndex(c.config.Types, node.Callee) + isMethod, _, _ := checker.MethodIndex(c.config.Env, node.Callee) if index, ok := checker.TypedFuncIndex(node.Callee.Type(), isMethod); ok { c.emit(OpCallTyped, index) return @@ -1117,7 +1116,7 @@ func (c *compiler) emitLoopBackwards(body func()) { c.patchJump(end) } -func (c *compiler) ClosureNode(node *ast.ClosureNode) { +func (c *compiler) PredicateNode(node *ast.PredicateNode) { c.compile(node.Node) } @@ -1199,6 +1198,9 @@ func (c *compiler) PairNode(node *ast.PairNode) { } func (c *compiler) derefInNeeded(node ast.Node) { + if node.Nature().Nil { + return + } switch node.Type().Kind() { case reflect.Ptr, reflect.Interface: c.emit(OpDeref) diff --git a/conf/config.go b/conf/config.go index 01a407a1..77bb2a67 100644 --- a/conf/config.go +++ b/conf/config.go @@ -6,33 +6,32 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" + "github.com/expr-lang/expr/checker/nature" + "github.com/expr-lang/expr/types" "github.com/expr-lang/expr/vm/runtime" ) type FunctionsTable map[string]*builtin.Function type Config struct { - Env any - Types TypesTable - MapEnv bool - DefaultType reflect.Type - Expect reflect.Kind - ExpectAny bool - Optimize bool - Strict bool - Profile bool - ConstFns map[string]reflect.Value - Visitors []ast.Visitor - Functions FunctionsTable - Builtins FunctionsTable - Disabled map[string]bool // disabled builtins + EnvObject any + Env nature.Nature + Expect reflect.Kind + ExpectAny bool + Optimize bool + Strict bool + Profile bool + ConstFns map[string]reflect.Value + Visitors []ast.Visitor + Functions FunctionsTable + Builtins FunctionsTable + Disabled map[string]bool // disabled builtins } // CreateNew creates new config with default values. func CreateNew() *Config { c := &Config{ Optimize: true, - Types: make(TypesTable), ConstFns: make(map[string]reflect.Value), Functions: make(map[string]*builtin.Function), Builtins: make(map[string]*builtin.Function), @@ -52,31 +51,20 @@ func New(env any) *Config { } func (c *Config) WithEnv(env any) { - var mapEnv bool - var mapValueType reflect.Type - if _, ok := env.(map[string]any); ok { - mapEnv = true - } else { - if reflect.ValueOf(env).Kind() == reflect.Map { - mapValueType = reflect.TypeOf(env).Elem() - } - } - - c.Env = env - types := CreateTypesTable(env) - for name, t := range types { - c.Types[name] = t - } - c.MapEnv = mapEnv - c.DefaultType = mapValueType c.Strict = true + c.EnvObject = env + c.Env = nature.Of(env) + c.Env.Strict = true // To keep backward compatibility with expr.AllowUndefinedVariables() + if _, ok := env.(types.Map); ok { + c.Env.Strict = false + } } func (c *Config) ConstExpr(name string) { - if c.Env == nil { + if c.EnvObject == nil { panic("no environment is specified for ConstExpr()") } - fn := reflect.ValueOf(runtime.Fetch(c.Env, name)) + fn := reflect.ValueOf(runtime.Fetch(c.EnvObject, name)) if fn.Kind() != reflect.Func { panic(fmt.Errorf("const expression %q must be a function", name)) } @@ -99,7 +87,7 @@ func (c *Config) IsOverridden(name string) bool { if _, ok := c.Functions[name]; ok { return true } - if _, ok := c.Types[name]; ok { + if _, ok := c.Env.Get(name); ok { return true } return false diff --git a/conf/types_table.go b/conf/types_table.go deleted file mode 100644 index a42a4287..00000000 --- a/conf/types_table.go +++ /dev/null @@ -1,121 +0,0 @@ -package conf - -import ( - "reflect" - - "github.com/expr-lang/expr/internal/deref" -) - -type TypesTable map[string]Tag - -type Tag struct { - Type reflect.Type - Ambiguous bool - FieldIndex []int - Method bool - MethodIndex int -} - -// CreateTypesTable creates types table for type checks during parsing. -// If struct is passed, all fields will be treated as variables, -// as well as all fields of embedded structs and struct itself. -// -// If map is passed, all items will be treated as variables -// (key as name, value as type). -func CreateTypesTable(i any) TypesTable { - if i == nil { - return nil - } - - types := make(TypesTable) - v := reflect.ValueOf(i) - t := reflect.TypeOf(i) - - d := t - if t.Kind() == reflect.Ptr { - d = t.Elem() - } - - switch d.Kind() { - case reflect.Struct: - types = FieldsFromStruct(d) - - // Methods of struct should be gathered from original struct with pointer, - // as methods maybe declared on pointer receiver. Also this method retrieves - // all embedded structs methods as well, no need to recursion. - for i := 0; i < t.NumMethod(); i++ { - m := t.Method(i) - types[m.Name] = Tag{ - Type: m.Type, - Method: true, - MethodIndex: i, - } - } - - case reflect.Map: - for _, key := range v.MapKeys() { - value := v.MapIndex(key) - if key.Kind() == reflect.String && value.IsValid() && value.CanInterface() { - if key.String() == "$env" { // Could check for all keywords here - panic("attempt to misuse env keyword as env map key") - } - types[key.String()] = Tag{Type: reflect.TypeOf(value.Interface())} - } - } - - // A map may have method too. - for i := 0; i < t.NumMethod(); i++ { - m := t.Method(i) - types[m.Name] = Tag{ - Type: m.Type, - Method: true, - MethodIndex: i, - } - } - } - - return types -} - -func FieldsFromStruct(t reflect.Type) TypesTable { - types := make(TypesTable) - t = deref.Type(t) - if t == nil { - return types - } - - switch t.Kind() { - case reflect.Struct: - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - - if f.Anonymous { - for name, typ := range FieldsFromStruct(f.Type) { - if _, ok := types[name]; ok { - types[name] = Tag{Ambiguous: true} - } else { - typ.FieldIndex = append(f.Index, typ.FieldIndex...) - types[name] = typ - } - } - } - if fn := FieldName(f); fn == "$env" { // Could check for all keywords here - panic("attempt to misuse env keyword as env struct field tag") - } else { - types[FieldName(f)] = Tag{ - Type: f.Type, - FieldIndex: f.Index, - } - } - } - } - - return types -} - -func FieldName(field reflect.StructField) string { - if taggedName := field.Tag.Get("expr"); taggedName != "" { - return taggedName - } - return field.Name -} diff --git a/docgen/docgen.go b/docgen/docgen.go index aed0f48f..d9abcf0e 100644 --- a/docgen/docgen.go +++ b/docgen/docgen.go @@ -5,7 +5,7 @@ import ( "regexp" "strings" - "github.com/expr-lang/expr/conf" + "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/internal/deref" ) @@ -84,8 +84,8 @@ func CreateDoc(i any) *Context { PkgPath: deref.Type(reflect.TypeOf(i)).PkgPath(), } - for name, t := range conf.CreateTypesTable(i) { - if t.Ambiguous { + for name, t := range nature.Of(i).All() { + if _, ok := c.Variables[Identifier(name)]; ok { continue } c.Variables[Identifier(name)] = c.use(t.Type, fromMethod(t.Method)) @@ -220,8 +220,11 @@ appendix: c.Types[name] = a } - for name, field := range conf.FieldsFromStruct(t) { - if isPrivate(name) || isProtobuf(name) || field.Ambiguous { + for name, field := range nature.StructFields(t) { + if isPrivate(name) || isProtobuf(name) { + continue + } + if _, ok := a.Fields[Identifier(name)]; ok { continue } a.Fields[Identifier(name)] = c.use(field.Type) diff --git a/docgen/docgen_test.go b/docgen/docgen_test.go index 26cf4f7a..2532a0c7 100644 --- a/docgen/docgen_test.go +++ b/docgen/docgen_test.go @@ -131,7 +131,7 @@ func TestCreateDoc(t *testing.T) { PkgPath: "github.com/expr-lang/expr/docgen_test", } - assert.EqualValues(t, expected, doc) + assert.Equal(t, expected.Markdown(), doc.Markdown()) } type A struct { @@ -160,6 +160,9 @@ func TestCreateDoc_Ambiguous(t *testing.T) { Kind: "struct", Name: "A", }, + "AmbiguousField": { + Kind: "int", + }, "B": { Kind: "struct", Name: "B", @@ -189,16 +192,17 @@ func TestCreateDoc_Ambiguous(t *testing.T) { "C": { Kind: "struct", Fields: map[Identifier]*Type{ - "A": {Kind: "struct", Name: "A"}, - "B": {Kind: "struct", Name: "B"}, - "OkField": {Kind: "int"}, + "A": {Kind: "struct", Name: "A"}, + "AmbiguousField": {Kind: "int"}, + "B": {Kind: "struct", Name: "B"}, + "OkField": {Kind: "int"}, }, }, }, PkgPath: "github.com/expr-lang/expr/docgen_test", } - assert.EqualValues(t, expected, doc) + assert.Equal(t, expected.Markdown(), doc.Markdown()) } func TestCreateDoc_FromMap(t *testing.T) { @@ -247,7 +251,7 @@ func TestCreateDoc_FromMap(t *testing.T) { }, } - require.EqualValues(t, expected, doc) + require.EqualValues(t, expected.Markdown(), doc.Markdown()) } func TestContext_Markdown(t *testing.T) { diff --git a/expr.go b/expr.go index 8c619e1c..5e60791e 100644 --- a/expr.go +++ b/expr.go @@ -45,7 +45,7 @@ func Operator(operator string, fn ...string) Option { p := &patcher.OperatorOverloading{ Operator: operator, Overloads: fn, - Types: c.Types, + Env: &c.Env, Functions: c.Functions, } c.Visitors = append(c.Visitors, p) diff --git a/expr_test.go b/expr_test.go index 3724467f..e71393f3 100644 --- a/expr_test.go +++ b/expr_test.go @@ -12,6 +12,7 @@ import ( "github.com/expr-lang/expr/internal/testify/assert" "github.com/expr-lang/expr/internal/testify/require" + "github.com/expr-lang/expr/types" "github.com/expr-lang/expr" "github.com/expr-lang/expr/ast" @@ -1673,10 +1674,6 @@ func TestIssue105(t *testing.T) { _, err := expr.Compile(code, expr.Env(Env{})) require.NoError(t, err) - - _, err = expr.Compile(`Field == ''`, expr.Env(Env{})) - require.Error(t, err) - require.Contains(t, err.Error(), "ambiguous identifier Field") } func TestIssue_nested_closures(t *testing.T) { @@ -2704,3 +2701,38 @@ func TestExpr_nil_op_str(t *testing.T) { }) } } + +func TestExpr_env_types_map(t *testing.T) { + envTypes := types.Map{ + "foo": types.StrictMap{ + "bar": "value", + }, + } + + program, err := expr.Compile(`foo.bar`, expr.Env(envTypes)) + require.NoError(t, err) + + env := map[string]any{ + "foo": map[string]any{ + "bar": "value", + }, + } + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, "value", output) +} + +func TestExpr_env_types_map_error(t *testing.T) { + envTypes := types.Map{ + "foo": types.StrictMap{ + "bar": "value", + }, + } + + program, err := expr.Compile(`foo.bar`, expr.Env(envTypes)) + require.NoError(t, err) + + _, err = expr.Run(program, envTypes) + require.Error(t, err) +} diff --git a/optimizer/filter_map.go b/optimizer/filter_map.go index e916dd75..c6de6c73 100644 --- a/optimizer/filter_map.go +++ b/optimizer/filter_map.go @@ -10,14 +10,14 @@ func (*filterMap) Visit(node *Node) { if mapBuiltin, ok := (*node).(*BuiltinNode); ok && mapBuiltin.Name == "map" && len(mapBuiltin.Arguments) == 2 { - if closure, ok := mapBuiltin.Arguments[1].(*ClosureNode); ok { + if predicate, ok := mapBuiltin.Arguments[1].(*PredicateNode); ok { if filter, ok := mapBuiltin.Arguments[0].(*BuiltinNode); ok && filter.Name == "filter" && filter.Map == nil /* not already optimized */ { patchCopyType(node, &BuiltinNode{ Name: "filter", Arguments: filter.Arguments, - Map: closure.Node, + Map: predicate.Node, }) } } diff --git a/optimizer/fold.go b/optimizer/fold.go index 2f4562c2..bb40eab9 100644 --- a/optimizer/fold.go +++ b/optimizer/fold.go @@ -298,8 +298,8 @@ func (fold *fold) Visit(node *Node) { base.Arguments[0], &BinaryNode{ Operator: "&&", - Left: base.Arguments[1].(*ClosureNode).Node, - Right: n.Arguments[1].(*ClosureNode).Node, + Left: base.Arguments[1].(*PredicateNode).Node, + Right: n.Arguments[1].(*PredicateNode).Node, }, }, }) diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index 56a89049..7830c0e7 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -184,7 +184,7 @@ func TestOptimize_filter_len(t *testing.T) { Name: "count", Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ + &ast.PredicateNode{ Node: &ast.BinaryNode{ Operator: "==", Left: &ast.MemberNode{ @@ -211,7 +211,7 @@ func TestOptimize_filter_0(t *testing.T) { Name: "find", Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ + &ast.PredicateNode{ Node: &ast.BinaryNode{ Operator: "==", Left: &ast.MemberNode{ @@ -239,7 +239,7 @@ func TestOptimize_filter_first(t *testing.T) { Name: "find", Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ + &ast.PredicateNode{ Node: &ast.BinaryNode{ Operator: "==", Left: &ast.MemberNode{ @@ -267,7 +267,7 @@ func TestOptimize_filter_minus_1(t *testing.T) { Name: "findLast", Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ + &ast.PredicateNode{ Node: &ast.BinaryNode{ Operator: "==", Left: &ast.MemberNode{ @@ -295,7 +295,7 @@ func TestOptimize_filter_last(t *testing.T) { Name: "findLast", Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ + &ast.PredicateNode{ Node: &ast.BinaryNode{ Operator: "==", Left: &ast.MemberNode{ @@ -323,7 +323,7 @@ func TestOptimize_filter_map(t *testing.T) { Name: "filter", Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ + &ast.PredicateNode{ Node: &ast.BinaryNode{ Operator: "==", Left: &ast.MemberNode{ @@ -354,7 +354,7 @@ func TestOptimize_filter_map_first(t *testing.T) { Name: "find", Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ + &ast.PredicateNode{ Node: &ast.BinaryNode{ Operator: "==", Left: &ast.MemberNode{ @@ -402,7 +402,7 @@ func TestOptimize_predicate_combination(t *testing.T) { Name: tt.fn, Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ + &ast.PredicateNode{ Node: &ast.BinaryNode{ Operator: tt.wantOp, Left: &ast.BinaryNode{ @@ -452,7 +452,7 @@ func TestOptimize_predicate_combination_nested(t *testing.T) { Name: "all", Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ + &ast.PredicateNode{ Node: &ast.BuiltinNode{ Name: "all", Arguments: []ast.Node{ @@ -460,7 +460,7 @@ func TestOptimize_predicate_combination_nested(t *testing.T) { Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Friends"}, }, - &ast.ClosureNode{ + &ast.PredicateNode{ Node: &ast.BinaryNode{ Operator: "&&", Left: &ast.BinaryNode{ diff --git a/optimizer/predicate_combination.go b/optimizer/predicate_combination.go index 62e296d1..65f88e34 100644 --- a/optimizer/predicate_combination.go +++ b/optimizer/predicate_combination.go @@ -21,19 +21,19 @@ func (v *predicateCombination) Visit(node *Node) { if combinedOp, ok := combinedOperator(left.Name, op.Operator); ok { if right, ok := op.Right.(*BuiltinNode); ok && right.Name == left.Name { if left.Arguments[0].Type() == right.Arguments[0].Type() && left.Arguments[0].String() == right.Arguments[0].String() { - closure := &ClosureNode{ + predicate := &PredicateNode{ Node: &BinaryNode{ Operator: combinedOp, - Left: left.Arguments[1].(*ClosureNode).Node, - Right: right.Arguments[1].(*ClosureNode).Node, + Left: left.Arguments[1].(*PredicateNode).Node, + Right: right.Arguments[1].(*PredicateNode).Node, }, } - v.Visit(&closure.Node) + v.Visit(&predicate.Node) patchCopyType(node, &BuiltinNode{ Name: left.Name, Arguments: []Node{ left.Arguments[0], - closure, + predicate, }, }) } diff --git a/optimizer/sum_map_test.go b/optimizer/sum_map_test.go index 96bdcfd3..2d0ffeb7 100644 --- a/optimizer/sum_map_test.go +++ b/optimizer/sum_map_test.go @@ -22,7 +22,7 @@ func TestOptimize_sum_map(t *testing.T) { Name: "sum", Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ + &ast.PredicateNode{ Node: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Age"}, diff --git a/parser/parser.go b/parser/parser.go index 77b2a700..0817f6e4 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -19,7 +19,7 @@ type arg byte const ( expr arg = 1 << iota - closure + predicate ) const optional arg = 1 << 7 @@ -27,21 +27,21 @@ const optional arg = 1 << 7 var predicates = map[string]struct { args []arg }{ - "all": {[]arg{expr, closure}}, - "none": {[]arg{expr, closure}}, - "any": {[]arg{expr, closure}}, - "one": {[]arg{expr, closure}}, - "filter": {[]arg{expr, closure}}, - "map": {[]arg{expr, closure}}, - "count": {[]arg{expr, closure | optional}}, - "sum": {[]arg{expr, closure | optional}}, - "find": {[]arg{expr, closure}}, - "findIndex": {[]arg{expr, closure}}, - "findLast": {[]arg{expr, closure}}, - "findLastIndex": {[]arg{expr, closure}}, - "groupBy": {[]arg{expr, closure}}, - "sortBy": {[]arg{expr, closure, expr | optional}}, - "reduce": {[]arg{expr, closure, expr | optional}}, + "all": {[]arg{expr, predicate}}, + "none": {[]arg{expr, predicate}}, + "any": {[]arg{expr, predicate}}, + "one": {[]arg{expr, predicate}}, + "filter": {[]arg{expr, predicate}}, + "map": {[]arg{expr, predicate}}, + "count": {[]arg{expr, predicate | optional}}, + "sum": {[]arg{expr, predicate | optional}}, + "find": {[]arg{expr, predicate}}, + "findIndex": {[]arg{expr, predicate}}, + "findLast": {[]arg{expr, predicate}}, + "findLastIndex": {[]arg{expr, predicate}}, + "groupBy": {[]arg{expr, predicate}}, + "sortBy": {[]arg{expr, predicate, expr | optional}}, + "reduce": {[]arg{expr, predicate, expr | optional}}, } type parser struct { @@ -49,7 +49,7 @@ type parser struct { current Token pos int err *file.Error - depth int // closure call depth + depth int // predicate call depth config *conf.Config } @@ -298,7 +298,7 @@ func (p *parser) parsePrimary() Node { } } else { if token.Is(Operator, "#") || token.Is(Operator, ".") { - p.error("cannot use pointer accessor outside closure") + p.error("cannot use pointer accessor outside predicate") } } @@ -448,8 +448,8 @@ func (p *parser) parseCall(token Token, arguments []Node, checkOverrides bool) N switch { case arg&expr == expr: node = p.parseExpression(0) - case arg&closure == closure: - node = p.parseClosure() + case arg&predicate == predicate: + node = p.parsePredicate() } arguments = append(arguments, node) } @@ -497,7 +497,7 @@ func (p *parser) parseArguments(arguments []Node) []Node { return arguments } -func (p *parser) parseClosure() Node { +func (p *parser) parsePredicate() Node { startToken := p.current expectClosingBracket := false if p.current.Is(Bracket, "{") { @@ -512,11 +512,11 @@ func (p *parser) parseClosure() Node { if expectClosingBracket { p.expect(Bracket, "}") } - closure := &ClosureNode{ + predicateNode := &PredicateNode{ Node: node, } - closure.SetLocation(startToken.Location) - return closure + predicateNode.SetLocation(startToken.Location) + return predicateNode } func (p *parser) parseArrayExpression(token Token) Node { diff --git a/parser/parser_test.go b/parser/parser_test.go index 3c6ee5b2..a280bf39 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -427,7 +427,7 @@ world`}, Name: "all", Arguments: []Node{ &IdentifierNode{Value: "Tickets"}, - &ClosureNode{ + &PredicateNode{ Node: &PointerNode{}, }}}, }, @@ -437,7 +437,7 @@ world`}, Name: "all", Arguments: []Node{ &IdentifierNode{Value: "Tickets"}, - &ClosureNode{ + &PredicateNode{ Node: &BinaryNode{ Operator: ">", Left: &MemberNode{Node: &PointerNode{}, @@ -450,7 +450,7 @@ world`}, Name: "one", Arguments: []Node{ &IdentifierNode{Value: "Tickets"}, - &ClosureNode{ + &PredicateNode{ Node: &BinaryNode{ Operator: ">", Left: &MemberNode{ @@ -463,7 +463,7 @@ world`}, "filter(Prices, {# > 100})", &BuiltinNode{Name: "filter", Arguments: []Node{&IdentifierNode{Value: "Prices"}, - &ClosureNode{Node: &BinaryNode{Operator: ">", + &PredicateNode{Node: &BinaryNode{Operator: ">", Left: &PointerNode{}, Right: &IntegerNode{Value: 100}}}}}, }, @@ -550,7 +550,7 @@ world`}, Name: "map", Arguments: []Node{ &ArrayNode{}, - &ClosureNode{ + &PredicateNode{ Node: &PointerNode{Name: "index"}, }, }, @@ -694,7 +694,7 @@ a map key must be a quoted string, a number, a identifier, or an expression encl | .....^ .foo -cannot use pointer accessor outside closure (1:1) +cannot use pointer accessor outside predicate (1:1) | .foo | ^ @@ -904,7 +904,7 @@ func TestParse_pipe_operator(t *testing.T) { Name: "map", Arguments: []Node{ &IdentifierNode{Value: "arr"}, - &ClosureNode{ + &PredicateNode{ Node: &MemberNode{ Node: &PointerNode{}, Property: &StringNode{Value: "foo"}, diff --git a/patcher/operator_override.go b/patcher/operator_override.go index 551fe09b..b2d20ec4 100644 --- a/patcher/operator_override.go +++ b/patcher/operator_override.go @@ -6,13 +6,14 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" + "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/conf" ) type OperatorOverloading struct { Operator string // Operator token to overload. Overloads []string // List of function names to replace operator with. - Types conf.TypesTable // Env types. + Env *nature.Nature // Env type. Functions conf.FunctionsTable // Env functions. applied bool // Flag to indicate if any changes were made to the tree. } @@ -56,7 +57,7 @@ func (p *OperatorOverloading) FindSuitableOperatorOverload(l, r reflect.Type) (r func (p *OperatorOverloading) findSuitableOperatorOverloadInTypes(l, r reflect.Type) (reflect.Type, string, bool) { for _, fn := range p.Overloads { - fnType, ok := p.Types[fn] + fnType, ok := p.Env.Get(fn) if !ok { continue } @@ -103,7 +104,7 @@ func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex func (p *OperatorOverloading) Check() { for _, fn := range p.Overloads { - fnType, foundType := p.Types[fn] + fnType, foundType := p.Env.Get(fn) fnFunc, foundFunc := p.Functions[fn] if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) { panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, p.Operator)) @@ -119,7 +120,7 @@ func (p *OperatorOverloading) Check() { } } -func checkType(fnType conf.Tag, fn string, operator string) { +func checkType(fnType nature.Nature, fn string, operator string) { requiredNumIn := 2 if fnType.Method { requiredNumIn = 3 // As first argument of method is receiver. diff --git a/types/types.go b/types/types.go new file mode 100644 index 00000000..9057d14a --- /dev/null +++ b/types/types.go @@ -0,0 +1,5 @@ +package types + +type Map map[string]any + +type StrictMap map[string]any