diff --git a/execution/execute.go b/execution/execute.go index 8681cb1..9685660 100644 --- a/execution/execute.go +++ b/execution/execute.go @@ -59,6 +59,8 @@ func exprExecutor(expr ast.Expr) (expressionExecutor, error) { return indexExprExecutor(e) case ast.PropertyExpr: return propertyExprExecutor(e) + case ast.VariableExpr: + return variableExprExecutor(e) case ast.NumberIntExpr: return numberIntExprExecutor(e) case ast.NumberFloatExpr: @@ -78,7 +80,41 @@ func exprExecutor(expr ast.Expr) (expressionExecutor, error) { func binaryExprExecutor(e ast.BinaryExpr) (expressionExecutor, error) { return func(data *model.Value) (*model.Value, error) { - panic("not implemented") + left, err := ExecuteAST(e.Left, data) + if err != nil { + return nil, fmt.Errorf("error evaluating left expression: %w", err) + } + right, err := ExecuteAST(e.Right, data) + if err != nil { + return nil, fmt.Errorf("error evaluating right expression: %w", err) + } + + switch e.Operator.Kind { + case lexer.Plus: + return left.Add(right) + case lexer.Dash: + return left.Subtract(right) + case lexer.Star: + return left.Multiply(right) + case lexer.Slash: + return left.Divide(right) + case lexer.Percent: + return left.Modulo(right) + case lexer.GreaterThan: + return left.GreaterThan(right) + case lexer.GreaterThanOrEqual: + return left.GreaterThanOrEqual(right) + case lexer.LessThan: + return left.LessThan(right) + case lexer.LessThanOrEqual: + return left.LessThanOrEqual(right) + case lexer.Equal: + return left.Equal(right) + case lexer.NotEqual: + return left.NotEqual(right) + default: + return nil, fmt.Errorf("unhandled operator: %s", e.Operator.Value) + } }, nil } @@ -155,3 +191,13 @@ func propertyExprExecutor(e ast.PropertyExpr) (expressionExecutor, error) { return data.GetMapKey(keyStr) }, nil } + +func variableExprExecutor(e ast.VariableExpr) (expressionExecutor, error) { + return func(data *model.Value) (*model.Value, error) { + varName := e.Name + if varName == "this" { + return data, nil + } + return nil, fmt.Errorf("variable %s not found", varName) + }, nil +} diff --git a/execution/execute_test.go b/execution/execute_test.go index 75cc22f..2e79c75 100644 --- a/execution/execute_test.go +++ b/execution/execute_test.go @@ -54,6 +54,75 @@ func TestExecuteSelector_HappyPath(t *testing.T) { } } + t.Run("binary expressions", func(t *testing.T) { + t.Run("math", func(t *testing.T) { + t.Run("literals", func(t *testing.T) { + t.Run("addition", runTest(testCase{ + in: model.NewValue(nil), + s: `1 + 2`, + out: model.NewIntValue(3), + })) + t.Run("subtraction", runTest(testCase{ + in: model.NewValue(nil), + s: `5 - 2`, + out: model.NewIntValue(3), + })) + t.Run("multiplication", runTest(testCase{ + in: model.NewValue(nil), + s: `5 * 2`, + out: model.NewIntValue(10), + })) + t.Run("division", runTest(testCase{ + in: model.NewValue(nil), + s: `10 / 2`, + out: model.NewIntValue(5), + })) + t.Run("modulus", runTest(testCase{ + in: model.NewValue(nil), + s: `10 % 3`, + out: model.NewIntValue(1), + })) + t.Run("ordering", runTest(testCase{ + in: model.NewValue(nil), + s: `45.2 + 5 * 4 - 2 / 2`, // 45.2 + (5 * 4) - (2 / 2) = 45.2 + 20 - 1 + out: model.NewFloatValue(64.2), + })) + }) + }) + t.Run("comparison", func(t *testing.T) { + t.Run("equal", runTest(testCase{ + in: model.NewValue(nil), + s: `1 == 1`, + out: model.NewBoolValue(true), + })) + t.Run("not equal", runTest(testCase{ + in: model.NewValue(nil), + s: `1 != 1`, + out: model.NewBoolValue(false), + })) + t.Run("greater than", runTest(testCase{ + in: model.NewValue(nil), + s: `2 > 1`, + out: model.NewBoolValue(true), + })) + t.Run("greater than or equal", runTest(testCase{ + in: model.NewValue(nil), + s: `2 >= 2`, + out: model.NewBoolValue(true), + })) + t.Run("less than", runTest(testCase{ + in: model.NewValue(nil), + s: `1 < 2`, + out: model.NewBoolValue(true), + })) + t.Run("less than or equal", runTest(testCase{ + in: model.NewValue(nil), + s: `2 <= 2`, + out: model.NewBoolValue(true), + })) + }) + }) + t.Run("literal", func(t *testing.T) { t.Run("string", runTest(testCase{ in: model.NewValue(nil), diff --git a/model/value_comparison.go b/model/value_comparison.go new file mode 100644 index 0000000..c0cc597 --- /dev/null +++ b/model/value_comparison.go @@ -0,0 +1,153 @@ +package model + +func (v *Value) Equal(other *Value) (*Value, error) { + if v.IsInt() && other.IsInt() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a == b), nil + } + if v.IsFloat() && other.IsFloat() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(a == b), nil + } + if v.IsInt() && other.IsFloat() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(float64(a) == b), nil + } + if v.IsFloat() && other.IsInt() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a == float64(b)), nil + } + return nil, &ErrIncompatibleTypes{A: v, B: other} +} + +func (v *Value) NotEqual(other *Value) (*Value, error) { + equals, err := v.Equal(other) + if err != nil { + return nil, err + } + boolValue, err := equals.BoolValue() + if err != nil { + return nil, err + } + return NewValue(!boolValue), nil +} + +func (v *Value) LessThan(other *Value) (*Value, error) { + if v.IsInt() && other.IsInt() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a < b), nil + } + if v.IsFloat() && other.IsFloat() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(a < b), nil + } + if v.IsInt() && other.IsFloat() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(float64(a) < b), nil + } + if v.IsFloat() && other.IsInt() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a < float64(b)), nil + } + return nil, &ErrIncompatibleTypes{A: v, B: other} +} + +func (v *Value) LessThanOrEqual(other *Value) (*Value, error) { + lessThan, err := v.LessThan(other) + if err != nil { + return nil, err + } + boolValue, err := lessThan.BoolValue() + if err != nil { + return nil, err + } + equals, err := v.Equal(other) + if err != nil { + return nil, err + } + boolEquals, err := equals.BoolValue() + if err != nil { + return nil, err + } + return NewValue(boolValue || boolEquals), nil +} + +func (v *Value) GreaterThan(other *Value) (*Value, error) { + lessThanOrEqual, err := v.LessThanOrEqual(other) + if err != nil { + return nil, err + } + boolValue, err := lessThanOrEqual.BoolValue() + if err != nil { + return nil, err + } + return NewValue(!boolValue), nil +} + +func (v *Value) GreaterThanOrEqual(other *Value) (*Value, error) { + lessThan, err := v.LessThan(other) + if err != nil { + return nil, err + } + boolValue, err := lessThan.BoolValue() + if err != nil { + return nil, err + } + return NewValue(!boolValue), nil +} diff --git a/model/value_math.go b/model/value_math.go new file mode 100644 index 0000000..94699dc --- /dev/null +++ b/model/value_math.go @@ -0,0 +1,266 @@ +package model + +import ( + "fmt" + "math" +) + +type ErrIncompatibleTypes struct { + A *Value + B *Value +} + +func (e *ErrIncompatibleTypes) Error() string { + return fmt.Sprintf("incompatible types: %s and %s", e.A.Type(), e.B.Type()) +} + +func (v *Value) Add(other *Value) (*Value, error) { + if v.IsInt() && other.IsInt() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a + b), nil + } + if v.IsFloat() && other.IsFloat() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(a + b), nil + } + if v.IsInt() && other.IsFloat() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(float64(a) + b), nil + } + if v.IsFloat() && other.IsInt() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a + float64(b)), nil + } + if v.IsString() && other.IsString() { + a, err := v.StringValue() + if err != nil { + return nil, err + } + b, err := other.StringValue() + if err != nil { + return nil, err + } + return NewValue(a + b), nil + } + return nil, &ErrIncompatibleTypes{A: v, B: other} +} + +func (v *Value) Subtract(other *Value) (*Value, error) { + if v.IsInt() && other.IsInt() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a - b), nil + } + if v.IsFloat() && other.IsFloat() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(a - b), nil + } + if v.IsInt() && other.IsFloat() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(float64(a) - b), nil + } + if v.IsFloat() && other.IsInt() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a - float64(b)), nil + } + return nil, &ErrIncompatibleTypes{A: v, B: other} +} + +func (v *Value) Multiply(other *Value) (*Value, error) { + if v.IsInt() && other.IsInt() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a * b), nil + } + if v.IsFloat() && other.IsFloat() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(a * b), nil + } + if v.IsInt() && other.IsFloat() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(float64(a) * b), nil + } + if v.IsFloat() && other.IsInt() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a * float64(b)), nil + } + return nil, &ErrIncompatibleTypes{A: v, B: other} +} + +func (v *Value) Divide(other *Value) (*Value, error) { + if v.IsInt() && other.IsInt() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a / b), nil + } + if v.IsFloat() && other.IsFloat() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(a / b), nil + } + if v.IsInt() && other.IsFloat() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(float64(a) / b), nil + } + if v.IsFloat() && other.IsInt() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a / float64(b)), nil + } + return nil, &ErrIncompatibleTypes{A: v, B: other} +} + +func (v *Value) Modulo(other *Value) (*Value, error) { + if v.IsInt() && other.IsInt() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(a % b), nil + } + if v.IsFloat() && other.IsFloat() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(math.Mod(a, b)), nil + } + if v.IsInt() && other.IsFloat() { + a, err := v.IntValue() + if err != nil { + return nil, err + } + b, err := other.FloatValue() + if err != nil { + return nil, err + } + return NewValue(math.Mod(float64(a), b)), nil + } + if v.IsFloat() && other.IsInt() { + a, err := v.FloatValue() + if err != nil { + return nil, err + } + b, err := other.IntValue() + if err != nil { + return nil, err + } + return NewValue(math.Mod(a, float64(b))), nil + } + return nil, &ErrIncompatibleTypes{A: v, B: other} +} diff --git a/selector/ast/expression_complex.go b/selector/ast/expression_complex.go index c616253..6fd2304 100644 --- a/selector/ast/expression_complex.go +++ b/selector/ast/expression_complex.go @@ -10,6 +10,13 @@ type BinaryExpr struct { func (BinaryExpr) expr() {} +type UnaryExpr struct { + Operator lexer.Token + Right Expr +} + +func (UnaryExpr) expr() {} + type CallExpr struct { Function string Args Expressions @@ -76,3 +83,9 @@ type MapExpr struct { } func (MapExpr) expr() {} + +type VariableExpr struct { + Name string +} + +func (VariableExpr) expr() {} diff --git a/selector/lexer/token.go b/selector/lexer/token.go index 4d83298..2c9f360 100644 --- a/selector/lexer/token.go +++ b/selector/lexer/token.go @@ -18,10 +18,9 @@ const ( CloseCurly OpenParen CloseParen - Equal - Equals - NotEqual - Not + Equal // == + Equals // = + NotEqual // != And Or Like @@ -29,7 +28,7 @@ const ( String Number Bool - Add + Plus Increment IncrementBy Dash @@ -40,6 +39,13 @@ const ( Percent Dot Spread + Dollar + Variable + GreaterThan + GreaterThanOrEqual + LessThan + LessThanOrEqual + Exclamation ) type Tokens []Token diff --git a/selector/lexer/tokenize.go b/selector/lexer/tokenize.go index bd2e73c..23d200a 100644 --- a/selector/lexer/tokenize.go +++ b/selector/lexer/tokenize.go @@ -41,6 +41,13 @@ func (p *Tokenizer) peekRuneEqual(i int, to rune) bool { return rune(p.src[i]) == to } +func (p *Tokenizer) peekRuneMatches(i int, fn func(rune) bool) bool { + if i >= p.srcLen { + return false + } + return fn(rune(p.src[i])) +} + func (p *Tokenizer) parseCurRune() (Token, error) { switch p.src[p.i] { case '.': @@ -70,6 +77,15 @@ func (p *Tokenizer) parseCurRune() (Token, error) { return NewToken(Slash, "/", p.i, 1), nil case '%': return NewToken(Percent, "%", p.i, 1), nil + case '$': + if p.peekRuneMatches(p.i+1, unicode.IsLetter) { + pos := p.i + 1 + for pos < p.srcLen && (unicode.IsLetter(rune(p.src[pos])) || unicode.IsDigit(rune(p.src[pos]))) { + pos++ + } + return NewToken(Variable, p.src[p.i+1:pos], p.i, pos-p.i), nil + } + return NewToken(Dollar, "$", p.i, 1), nil case '=': if p.peekRuneEqual(p.i+1, '=') { return NewToken(Equal, "==", p.i, 2), nil @@ -85,7 +101,7 @@ func (p *Tokenizer) parseCurRune() (Token, error) { if p.peekRuneEqual(p.i+1, '+') { return NewToken(Increment, "++", p.i, 2), nil } - return NewToken(Add, "+", p.i, 1), nil + return NewToken(Plus, "+", p.i, 1), nil case '-': if p.peekRuneEqual(p.i+1, '=') { return NewToken(DecrementBy, "-=", p.i, 2), nil @@ -94,6 +110,16 @@ func (p *Tokenizer) parseCurRune() (Token, error) { return NewToken(Decrement, "--", p.i, 2), nil } return NewToken(Dash, "-", p.i, 1), nil + case '>': + if p.peekRuneEqual(p.i+1, '=') { + return NewToken(GreaterThanOrEqual, ">=", p.i, 2), nil + } + return NewToken(GreaterThan, ">", p.i, 1), nil + case '<': + if p.peekRuneEqual(p.i+1, '=') { + return NewToken(LessThanOrEqual, "<>>=", p.i, 2), nil + } + return NewToken(LessThan, "<", p.i, 1), nil case '!': if p.peekRuneEqual(p.i+1, '=') { return NewToken(NotEqual, "!=", p.i, 2), nil @@ -101,7 +127,7 @@ func (p *Tokenizer) parseCurRune() (Token, error) { if p.peekRuneEqual(p.i+1, '~') { return NewToken(NotLike, "!~", p.i, 2), nil } - return NewToken(Not, "!", p.i, 1), nil + return NewToken(Exclamation, "!", p.i, 1), nil case '&': if p.peekRuneEqual(p.i+1, '&') { return NewToken(And, "&&", p.i, 2), nil diff --git a/selector/lexer/tokenize_test.go b/selector/lexer/tokenize_test.go index ab9685f..3d0401d 100644 --- a/selector/lexer/tokenize_test.go +++ b/selector/lexer/tokenize_test.go @@ -3,7 +3,56 @@ package lexer import "testing" func TestTokenizer_Parse(t *testing.T) { - tok := NewTokenizer("foo.bar.baz[1] != 42.123 || foo.bar.baz['hello'] == 42 && x == 'a\\'b' + false . .... asd...") + type testCase struct { + in string + out []TokenKind + } + + runTest := func(tc testCase) func(t *testing.T) { + return func(t *testing.T) { + tok := NewTokenizer(tc.in) + tokens, err := tok.Tokenize() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tokens) != len(tc.out) { + t.Fatalf("unexpected number of tokens: %d", len(tokens)) + } + for i := range tokens { + if tokens[i].Kind != tc.out[i] { + t.Errorf("unexpected token kind at position %d: exp %v, got %v", i, tc.out[i], tokens[i].Kind) + return + } + } + } + } + + t.Run("variables", runTest(testCase{ + in: "$foo $bar123 $baz $", + out: []TokenKind{ + Variable, + Variable, + Variable, + Dollar, + }, + })) + + t.Run("everything", runTest(testCase{ + in: "foo.bar.baz[1] != 42.123 || foo.bar.baz['hello'] == 42 && x == 'a\\'b' + false . .... asd... $name", + out: []TokenKind{ + Symbol, Dot, Symbol, Dot, Symbol, OpenBracket, Number, CloseBracket, NotEqual, Number, + Or, + Symbol, Dot, Symbol, Dot, Symbol, OpenBracket, String, CloseBracket, Equal, Number, + And, + Symbol, Equal, String, + Plus, Bool, + Dot, Spread, Dot, + Symbol, Spread, + Variable, + }, + })) + + tok := NewTokenizer("foo.bar.baz[1] != 42.123 || foo.bar.baz['hello'] == 42 && x == 'a\\'b' + false . .... asd... $name") tokens, err := tok.Tokenize() if err != nil { t.Fatalf("unexpected error: %v", err) @@ -14,9 +63,10 @@ func TestTokenizer_Parse(t *testing.T) { Symbol, Dot, Symbol, Dot, Symbol, OpenBracket, String, CloseBracket, Equal, Number, And, Symbol, Equal, String, - Add, Bool, + Plus, Bool, Dot, Spread, Dot, Symbol, Spread, + Variable, } if len(tokens) != len(exp) { t.Fatalf("unexpected number of tokens: %d", len(tokens)) diff --git a/selector/parser/denotations.go b/selector/parser/denotations.go new file mode 100644 index 0000000..40589a0 --- /dev/null +++ b/selector/parser/denotations.go @@ -0,0 +1,26 @@ +package parser + +import "github.com/tomwright/dasel/v3/selector/lexer" + +// null denotation tokens are tokens that expect no token to the left of them. +var nullDenotationTokens = []lexer.TokenKind{} + +// left denotation tokens are tokens that expect a token to the left of them. +var leftDenotationTokens = []lexer.TokenKind{ + lexer.Plus, + lexer.Dash, + lexer.Slash, + lexer.Star, + lexer.Percent, + lexer.Equal, + lexer.NotEqual, + lexer.GreaterThan, + lexer.GreaterThanOrEqual, + lexer.LessThan, + lexer.LessThanOrEqual, +} + +// right denotation tokens are tokens that expect a token to the right of them. +var rightDenotationTokens = []lexer.TokenKind{ + lexer.Exclamation, // Not operator +} diff --git a/selector/parser/parse_array.go b/selector/parser/parse_array.go index c376861..f6f3f33 100644 --- a/selector/parser/parse_array.go +++ b/selector/parser/parse_array.go @@ -6,7 +6,7 @@ import ( ) func parseArray(p *Parser) (ast.Expr, error) { - if err := p.expect(lexer.Symbol); err != nil { + if err := p.expect(lexer.Symbol, lexer.Variable); err != nil { return nil, err } if err := p.expectN(1, lexer.OpenBracket); err != nil { @@ -19,10 +19,24 @@ func parseArray(p *Parser) (ast.Expr, error) { if err != nil { return nil, err } - return ast.ChainExprs( - ast.PropertyExpr{ + + var e ast.Expr + + switch { + case token.IsKind(lexer.Variable): + e = ast.VariableExpr{ + Name: token.Value, + } + case token.IsKind(lexer.Symbol): + e = ast.PropertyExpr{ Property: ast.StringExpr{Value: token.Value}, - }, + } + default: + panic("unexpected token kind") + } + + return ast.ChainExprs( + e, idx, ), nil } @@ -61,7 +75,7 @@ func parseSquareBrackets(p *Parser) (ast.Expr, error) { if p.current().IsKind(lexer.Colon) { p.advance() // We have no start index - end, err = p.parseExpression() + end, _, err = p.parseExpression(nil) if err != nil { return nil, err } @@ -71,7 +85,7 @@ func parseSquareBrackets(p *Parser) (ast.Expr, error) { }, nil } - start, err = p.parseExpression() + start, _, err = p.parseExpression(nil) if err != nil { return nil, err } @@ -98,7 +112,7 @@ func parseSquareBrackets(p *Parser) (ast.Expr, error) { }, nil } - end, err = p.parseExpression() + end, _, err = p.parseExpression(nil) if err != nil { return nil, err } diff --git a/selector/parser/parse_func.go b/selector/parser/parse_func.go index 8a8475b..22bf2dd 100644 --- a/selector/parser/parse_func.go +++ b/selector/parser/parse_func.go @@ -37,7 +37,7 @@ func parseArgs(p *Parser) ([]ast.Expr, error) { break } - arg, err := p.parseExpression() + arg, _, err := p.parseExpression(nil) if err != nil { return nil, err } diff --git a/selector/parser/parse_map.go b/selector/parser/parse_map.go index 7796a35..c73469d 100644 --- a/selector/parser/parse_map.go +++ b/selector/parser/parse_map.go @@ -26,6 +26,9 @@ func parseMap(p *Parser) (ast.Expr, error) { expressions := make([]ast.Expr, 0) + var expr ast.Expr + var err error + var replaceLast bool for { if p.current().IsKind(lexer.CloseParen) { if len(expressions) == 0 { @@ -40,10 +43,14 @@ func parseMap(p *Parser) (ast.Expr, error) { continue } - expr, err := p.parseExpression() + expr, replaceLast, err = p.parseExpression(expr) if err != nil { return nil, err } + if replaceLast { + expressions[len(expressions)-1] = expr + continue + } expressions = append(expressions, expr) } diff --git a/selector/parser/parse_object.go b/selector/parser/parse_object.go index b8ecb20..1afd203 100644 --- a/selector/parser/parse_object.go +++ b/selector/parser/parse_object.go @@ -51,7 +51,7 @@ func parseObject(p *Parser) (ast.Expr, error) { continue } - key, err := p.parseExpression() + key, _, err := p.parseExpression(nil) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func parseObject(p *Parser) (ast.Expr, error) { } p.advance() - val, err := p.parseExpression() + val, _, err := p.parseExpression(nil) if err != nil { return nil, err } diff --git a/selector/parser/parse_variable.go b/selector/parser/parse_variable.go new file mode 100644 index 0000000..de70bee --- /dev/null +++ b/selector/parser/parse_variable.go @@ -0,0 +1,31 @@ +package parser + +import ( + "github.com/tomwright/dasel/v3/selector/ast" + "github.com/tomwright/dasel/v3/selector/lexer" +) + +func parseVariable(p *Parser) (ast.Expr, error) { + token := p.current() + + next := p.peek() + + if next.IsKind(lexer.OpenBracket) { + return parseArray(p) + } + + prop := ast.VariableExpr{ + Name: token.Value, + } + + if next.IsKind(lexer.Spread) { + p.advanceN(2) + return ast.ChainExprs( + prop, + ast.SpreadExpr{}, + ), nil + } + + p.advance() + return prop, nil +} diff --git a/selector/parser/parser.go b/selector/parser/parser.go index 9025c6c..b3d4dc4 100644 --- a/selector/parser/parser.go +++ b/selector/parser/parser.go @@ -2,6 +2,7 @@ package parser import ( "fmt" + "slices" "github.com/tomwright/dasel/v3/selector/ast" "github.com/tomwright/dasel/v3/selector/lexer" @@ -41,7 +42,7 @@ func (p *Parser) currentScope() scope { func (p *Parser) endOfExpressionTokens() []lexer.TokenKind { switch p.currentScope() { case scopeRoot: - return []lexer.TokenKind{lexer.EOF, lexer.Dot} + return append([]lexer.TokenKind{lexer.EOF, lexer.Dot}, leftDenotationTokens...) case scopeFuncArgs: return []lexer.TokenKind{lexer.Comma, lexer.CloseParen} case scopeMap: @@ -71,6 +72,9 @@ func NewParser(tokens lexer.Tokens) *Parser { func (p *Parser) Parse() (ast.Expr, error) { var expressions ast.Expressions + var expr ast.Expr + var err error + var replaceLast bool for p.hasToken() { if p.current().IsKind(lexer.EOF) { break @@ -79,10 +83,14 @@ func (p *Parser) Parse() (ast.Expr, error) { p.advance() continue } - expr, err := p.parseExpression() + expr, replaceLast, err = p.parseExpression(expr) if err != nil { return nil, err } + if replaceLast { + expressions[len(expressions)-1] = expr + continue + } expressions = append(expressions, expr) } switch len(expressions) { @@ -95,32 +103,41 @@ func (p *Parser) Parse() (ast.Expr, error) { } } -func (p *Parser) parseExpression() (res ast.Expr, err error) { +func (p *Parser) parseExpression(last ast.Expr) (res ast.Expr, replaceLast bool, err error) { defer func() { if err == nil { err = p.expectEndOfExpression() } }() + + if last != nil && slices.Contains(leftDenotationTokens, p.current().Kind) { + res, replaceLast, err = parseBinary(p, last) + return + } + switch p.current().Kind { case lexer.String: - return parseStringLiteral(p) + res, err = parseStringLiteral(p) case lexer.Number: - return parseNumberLiteral(p) + res, err = parseNumberLiteral(p) case lexer.Symbol: - return parseSymbol(p) + res, err = parseSymbol(p) case lexer.OpenBracket: - return parseSquareBrackets(p) + res, err = parseSquareBrackets(p) case lexer.OpenCurly: - return parseObject(p) + res, err = parseObject(p) case lexer.Bool: - return parseBoolLiteral(p) + res, err = parseBoolLiteral(p) case lexer.Spread: - return parseSpread(p) + res, err = parseSpread(p) + case lexer.Variable: + res, err = parseVariable(p) default: - return nil, &UnexpectedTokenError{ + return nil, false, &UnexpectedTokenError{ Token: p.current(), } } + return } func (p *Parser) hasToken() bool { @@ -138,6 +155,14 @@ func (p *Parser) current() lexer.Token { return lexer.Token{Kind: lexer.EOF} } +func (p *Parser) previous() lexer.Token { + i := p.i - 1 + if i > 0 && i < len(p.tokens) { + return p.tokens[i] + } + return lexer.Token{Kind: lexer.EOF} +} + func (p *Parser) advance() lexer.Token { p.i++ return p.current() diff --git a/selector/parser/parser_binary.go b/selector/parser/parser_binary.go new file mode 100644 index 0000000..dde40bf --- /dev/null +++ b/selector/parser/parser_binary.go @@ -0,0 +1,30 @@ +package parser + +import "github.com/tomwright/dasel/v3/selector/ast" + +func parseBinary(p *Parser, left ast.Expr) (ast.Expr, bool, error) { + if err := p.expect(leftDenotationTokens...); err != nil { + return nil, false, err + } + for { + if !p.current().IsKind(leftDenotationTokens...) { + break + } + + token := p.current() + p.advance() + + right, _, err := p.parseExpression(left) + if err != nil { + return nil, false, err + } + + left = ast.BinaryExpr{ + Left: left, + Operator: token, + Right: right, + } + } + + return left, true, nil +} diff --git a/selector/parser/parser_test.go b/selector/parser/parser_test.go index 494130e..e456ea8 100644 --- a/selector/parser/parser_test.go +++ b/selector/parser/parser_test.go @@ -291,4 +291,15 @@ func TestParser_Parse_HappyPath(t *testing.T) { }}, })) }) + + t.Run("variables", func(t *testing.T) { + t.Run("single variable", run(t, testCase{ + input: `$foo`, + expected: ast.VariableExpr{Name: "foo"}, + })) + t.Run("variable passed to func", run(t, testCase{ + input: `len($foo)`, + expected: ast.CallExpr{Function: "len", Args: ast.Expressions{ast.VariableExpr{Name: "foo"}}}, + })) + }) }