diff --git a/limitador-server/src/envoy_rls/server.rs b/limitador-server/src/envoy_rls/server.rs index 37476c6d..3975d08e 100644 --- a/limitador-server/src/envoy_rls/server.rs +++ b/limitador-server/src/envoy_rls/server.rs @@ -253,7 +253,7 @@ mod tests { namespace, 1, 60, - vec!["req_method == 'GET'"], + vec!["req_method == 'GET'".to_string()], vec!["app_id"], ) .expect("This must be a valid limit!"); @@ -395,10 +395,16 @@ mod tests { let namespace = "test_namespace"; vec![ - Limit::new(namespace, 10, 60, vec!["x == '1'"], vec!["z"]) - .expect("This must be a valid limit!"), - Limit::new(namespace, 0, 60, vec!["x == '1'", "y == '2'"], vec!["z"]) + Limit::new(namespace, 10, 60, vec!["x == '1'".to_string()], vec!["z"]) .expect("This must be a valid limit!"), + Limit::new( + namespace, + 0, + 60, + vec!["x == '1'".to_string(), "y == '2'".to_string()], + vec!["z"], + ) + .expect("This must be a valid limit!"), ] .into_iter() .for_each(|limit| { @@ -462,7 +468,7 @@ mod tests { #[tokio::test] async fn test_takes_into_account_the_hits_addend_param() { let namespace = "test_namespace"; - let limit = Limit::new(namespace, 10, 60, vec!["x == '1'"], vec!["y"]) + let limit = Limit::new(namespace, 10, 60, vec!["x == '1'".to_string()], vec!["y"]) .expect("This must be a valid limit!"); let limiter = RateLimiter::new(10_000); @@ -532,7 +538,7 @@ mod tests { // "hits_addend" is optional according to the spec, and should default // to 1, However, with the autogenerated structs it defaults to 0. let namespace = "test_namespace"; - let limit = Limit::new(namespace, 1, 60, vec!["x == '1'"], vec!["y"]) + let limit = Limit::new(namespace, 1, 60, vec!["x == '1'".to_string()], vec!["y"]) .expect("This must be a valid limit!"); let limiter = RateLimiter::new(10_000); diff --git a/limitador-server/src/http_api/request_types.rs b/limitador-server/src/http_api/request_types.rs index 1cae899a..af80e7b5 100644 --- a/limitador-server/src/http_api/request_types.rs +++ b/limitador-server/src/http_api/request_types.rs @@ -1,6 +1,5 @@ use limitador::counter::Counter as LimitadorCounter; -use limitador::errors::LimitadorError; -use limitador::limit::Limit as LimitadorLimit; +use limitador::limit::{Limit as LimitadorLimit, ParseError}; use paperclip::actix::Apiv2Schema; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; @@ -42,7 +41,7 @@ impl From<&LimitadorLimit> for Limit { } impl TryFrom for LimitadorLimit { - type Error = LimitadorError; + type Error = ParseError; fn try_from(limit: Limit) -> Result { let mut limitador_limit = if let Some(id) = limit.id { diff --git a/limitador/src/errors.rs b/limitador/src/errors.rs index d590aafa..82efa829 100644 --- a/limitador/src/errors.rs +++ b/limitador/src/errors.rs @@ -1,4 +1,5 @@ -use crate::limit::ConditionParsingError; +use crate::limit::cel::EvaluationError; +use crate::limit::ParseError; use crate::storage::StorageErr; use std::convert::Infallible; use std::error::Error; @@ -7,7 +8,7 @@ use std::fmt::{Display, Formatter}; #[derive(Debug)] pub enum LimitadorError { StorageError(StorageErr), - InterpreterError(ConditionParsingError), + InterpreterError(EvaluationError), } impl Display for LimitadorError { @@ -38,13 +39,13 @@ impl From for LimitadorError { } } -impl From for LimitadorError { - fn from(err: ConditionParsingError) -> Self { +impl From for LimitadorError { + fn from(err: EvaluationError) -> Self { LimitadorError::InterpreterError(err) } } -impl From for LimitadorError { +impl From for ParseError { fn from(value: Infallible) -> Self { unreachable!("unexpected infallible value: {:?}", value) } diff --git a/limitador/src/limit.rs b/limitador/src/limit.rs index b6729634..4ed66495 100644 --- a/limitador/src/limit.rs +++ b/limitador/src/limit.rs @@ -1,14 +1,9 @@ -use crate::limit::conditions::{ErrorType, Literal, SyntaxError, Token, TokenType}; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::collections::{BTreeSet, HashMap, HashSet}; -use std::error::Error; -use std::fmt::{Debug, Display, Formatter}; +use std::fmt::Debug; use std::hash::{Hash, Hasher}; -use crate::errors::LimitadorError; -use crate::LimitadorResult; - #[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord, Serialize, Deserialize)] pub struct Namespace(String); @@ -47,211 +42,22 @@ pub struct Limit { variables: BTreeSet, } -#[derive(Deserialize, Serialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)] -#[serde(try_from = "String", into = "String")] -pub struct Condition { - var_name: String, - predicate: Predicate, - operand: String, -} - -#[derive(Debug)] -pub struct ConditionParsingError { - error: SyntaxError, - pub tokens: Vec, - condition: String, -} - -impl Display for ConditionParsingError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{} of condition \"{}\"", self.error, self.condition) - } -} - -impl Error for ConditionParsingError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - Some(&self.error) - } -} - -impl TryFrom<&str> for Condition { - type Error = ConditionParsingError; - - fn try_from(value: &str) -> Result { - value.to_owned().try_into() - } -} - -impl TryFrom for Condition { - type Error = ConditionParsingError; - - fn try_from(value: String) -> Result { - match conditions::Scanner::scan(value.clone()) { - Ok(tokens) => match tokens.len().cmp(&(3_usize)) { - Ordering::Equal => { - match ( - &tokens[0].token_type, - &tokens[1].token_type, - &tokens[2].token_type, - ) { - ( - TokenType::Identifier, - TokenType::EqualEqual | TokenType::NotEqual, - TokenType::String, - ) => { - if let ( - Some(Literal::Identifier(var_name)), - Some(Literal::String(operand)), - ) = (&tokens[0].literal, &tokens[2].literal) - { - let predicate = match &tokens[1].token_type { - TokenType::EqualEqual => Predicate::Equal, - TokenType::NotEqual => Predicate::NotEqual, - _ => unreachable!(), - }; - Ok(Condition { - var_name: var_name.clone(), - predicate, - operand: operand.clone(), - }) - } else { - panic!( - "Unexpected state {tokens:?} returned from Scanner for: `{value}`" - ) - } - } - ( - TokenType::String, - TokenType::EqualEqual | TokenType::NotEqual, - TokenType::Identifier, - ) => { - if let ( - Some(Literal::String(operand)), - Some(Literal::Identifier(var_name)), - ) = (&tokens[0].literal, &tokens[2].literal) - { - let predicate = match &tokens[1].token_type { - TokenType::EqualEqual => Predicate::Equal, - TokenType::NotEqual => Predicate::NotEqual, - _ => unreachable!(), - }; - Ok(Condition { - var_name: var_name.clone(), - predicate, - operand: operand.clone(), - }) - } else { - panic!( - "Unexpected state {tokens:?} returned from Scanner for: `{value}`" - ) - } - } - (t1, t2, _) => { - let faulty = match (t1, t2) { - ( - TokenType::Identifier | TokenType::String, - TokenType::EqualEqual | TokenType::NotEqual, - ) => 2, - (TokenType::Identifier | TokenType::String, _) => 1, - (_, _) => 0, - }; - Err(ConditionParsingError { - error: SyntaxError { - pos: tokens[faulty].pos, - error: ErrorType::UnexpectedToken(tokens[faulty].clone()), - }, - tokens, - condition: value, - }) - } - } - } - Ordering::Less => Err(ConditionParsingError { - error: SyntaxError { - pos: value.len(), - error: ErrorType::MissingToken, - }, - tokens, - condition: value, - }), - Ordering::Greater => Err(ConditionParsingError { - error: SyntaxError { - pos: tokens[3].pos, - error: ErrorType::UnexpectedToken(tokens[3].clone()), - }, - tokens, - condition: value, - }), - }, - Err(err) => Err(ConditionParsingError { - error: err, - tokens: Vec::new(), - condition: value, - }), - } - } -} - -impl From for String { - fn from(condition: Condition) -> Self { - let p = &condition.predicate; - let predicate: String = p.clone().into(); - let quotes = if condition.operand.contains('"') { - '\'' - } else { - '"' - }; - format!( - "{} {} {}{}{}", - condition.var_name, predicate, quotes, condition.operand, quotes - ) - } -} - -#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Hash)] -pub enum Predicate { - Equal, - NotEqual, -} - -#[allow(dead_code)] -impl Predicate { - fn test(&self, lhs: &str, rhs: &str) -> bool { - match self { - Predicate::Equal => lhs == rhs, - Predicate::NotEqual => lhs != rhs, - } - } -} - -impl From for String { - fn from(op: Predicate) -> Self { - match op { - Predicate::Equal => "==".to_string(), - Predicate::NotEqual => "!=".to_string(), - } - } -} - impl Limit { - pub fn new, T: TryInto>( + pub fn new, T: TryInto>( namespace: N, max_value: u64, seconds: u64, conditions: impl IntoIterator, variables: impl IntoIterator>, - ) -> LimitadorResult + ) -> Result where >::Error: core::fmt::Debug, - >::Error: core::fmt::Debug, - LimitadorError: From<>::Error>, + >::Error: core::fmt::Debug, + ParseError: From<>::Error>, { // the above where-clause is needed in order to call unwrap(). - let conditions: Result, _> = conditions - .into_iter() - .map(|cond| cond.try_into()) - .map(|r| r.map(|c| cel::Predicate::parse::(c.into()).unwrap())) - .collect(); + let conditions: Result, _> = + conditions.into_iter().map(|cond| cond.try_into()).collect(); match conditions { Ok(conditions) => Ok(Self { id: None, @@ -266,23 +72,18 @@ impl Limit { } } - pub fn with_id, N: Into, T: TryInto>( + pub fn with_id, N: Into, T: TryInto>( id: S, namespace: N, max_value: u64, seconds: u64, conditions: impl IntoIterator, variables: impl IntoIterator>, - ) -> LimitadorResult + ) -> Result where - LimitadorError: From<>::Error>, + ParseError: From<>::Error>, { - match conditions - .into_iter() - .map(|cond| cond.try_into()) - .map(|r| r.map(|c| cel::Predicate::parse::(c.into()).unwrap())) - .collect() - { + match conditions.into_iter().map(|cond| cond.try_into()).collect() { Ok(conditions) => Ok(Self { id: Some(id.into()), namespace: namespace.into(), @@ -401,402 +202,7 @@ impl PartialEq for Limit { } } -mod conditions { - use std::error::Error; - use std::fmt::{Debug, Display, Formatter}; - use std::num::IntErrorKind; - - #[derive(Debug)] - pub struct SyntaxError { - pub pos: usize, - pub error: ErrorType, - } - - #[derive(Debug, Eq, PartialEq)] - pub enum ErrorType { - UnexpectedToken(Token), - MissingToken, - InvalidCharacter(char), - InvalidNumber, - UnclosedStringLiteral(char), - } - - impl Display for SyntaxError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match &self.error { - ErrorType::UnexpectedToken(token) => write!( - f, - "SyntaxError: Unexpected token `{}` at offset {}", - token, self.pos - ), - ErrorType::InvalidCharacter(char) => write!( - f, - "SyntaxError: Invalid character `{}` at offset {}", - char, self.pos - ), - ErrorType::InvalidNumber => { - write!(f, "SyntaxError: Invalid number at offset {}", self.pos) - } - ErrorType::MissingToken => { - write!(f, "SyntaxError: Expected token at offset {}", self.pos) - } - ErrorType::UnclosedStringLiteral(char) => { - write!(f, "SyntaxError: Missing closing `{}` for string literal starting at offset {}", char, self.pos) - } - } - } - } - - impl Error for SyntaxError {} - - #[derive(Clone, Eq, PartialEq, Debug)] - pub enum TokenType { - // Predicates - EqualEqual, - NotEqual, - - //Literals - Identifier, - String, - Number, - } - - #[derive(Clone, Eq, PartialEq, Debug)] - pub enum Literal { - Identifier(String), - String(String), - Number(i64), - } - - impl Display for Literal { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Literal::Identifier(id) => write!(f, "{id}"), - Literal::String(string) => write!(f, "'{string}'"), - Literal::Number(number) => write!(f, "{number}"), - } - } - } - - #[derive(Clone, Eq, PartialEq, Debug)] - pub struct Token { - pub token_type: TokenType, - pub literal: Option, - pub pos: usize, - } - - impl Display for Token { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.token_type { - TokenType::EqualEqual => write!(f, "Equality (==)"), - TokenType::NotEqual => write!(f, "Unequal (!=)"), - TokenType::Identifier => { - write!(f, "Identifier: {}", self.literal.as_ref().unwrap()) - } - TokenType::String => { - write!(f, "String literal: {}", self.literal.as_ref().unwrap()) - } - TokenType::Number => { - write!(f, "Number literal: {}", self.literal.as_ref().unwrap()) - } - } - } - } - - pub struct Scanner { - input: Vec, - pos: usize, - } - - impl Scanner { - pub fn scan(condition: String) -> Result, SyntaxError> { - let mut tokens: Vec = Vec::with_capacity(3); - let mut scanner = Scanner { - input: condition.chars().collect(), - pos: 0, - }; - while !scanner.done() { - match scanner.next_token() { - Ok(token) => { - if let Some(token) = token { - tokens.push(token) - } - } - Err(err) => { - return Err(err); - } - } - } - Ok(tokens) - } - - fn next_token(&mut self) -> Result, SyntaxError> { - let character = self.advance(); - match character { - '=' => { - if self.next_matches('=') { - Ok(Some(Token { - token_type: TokenType::EqualEqual, - literal: None, - pos: self.pos - 1, - })) - } else { - Err(SyntaxError { - pos: self.pos, - error: ErrorType::InvalidCharacter(self.input[self.pos - 1]), - }) - } - } - '!' => { - if self.next_matches('=') { - Ok(Some(Token { - token_type: TokenType::NotEqual, - literal: None, - pos: self.pos - 1, - })) - } else { - Err(SyntaxError { - pos: self.pos, - error: ErrorType::InvalidCharacter(self.input[self.pos - 1]), - }) - } - } - '"' | '\'' => self.scan_string(character).map(Some), - ' ' | '\n' | '\r' | '\t' => Ok(None), - _ => { - if character.is_alphabetic() { - self.scan_identifier().map(Some) - } else if character.is_numeric() { - self.scan_number().map(Some) - } else { - Err(SyntaxError { - pos: self.pos, - error: ErrorType::InvalidCharacter(character), - }) - } - } - } - } - - fn scan_identifier(&mut self) -> Result { - let start = self.pos; - while !self.done() && self.valid_id_char() { - self.advance(); - } - Ok(Token { - token_type: TokenType::Identifier, - literal: Some(Literal::Identifier( - self.input[start - 1..self.pos].iter().collect(), - )), - pos: start, - }) - } - - fn valid_id_char(&mut self) -> bool { - let char = self.input[self.pos]; - char.is_alphanumeric() || char == '.' || char == '_' - } - - fn scan_string(&mut self, until: char) -> Result { - let start = self.pos; - loop { - if self.done() { - return Err(SyntaxError { - pos: start, - error: ErrorType::UnclosedStringLiteral(until), - }); - } - if self.advance() == until { - return Ok(Token { - token_type: TokenType::String, - literal: Some(Literal::String( - self.input[start..self.pos - 1].iter().collect(), - )), - pos: start, - }); - } - } - } - - fn scan_number(&mut self) -> Result { - let start = self.pos; - while !self.done() && self.input[self.pos].is_numeric() { - self.advance(); - } - let number_str = self.input[start - 1..self.pos].iter().collect::(); - match number_str.parse::() { - Ok(number) => Ok(Token { - token_type: TokenType::Number, - literal: Some(Literal::Number(number)), - pos: start, - }), - Err(err) => { - let syntax_error = match err.kind() { - IntErrorKind::Empty => { - unreachable!("This means a bug in the scanner!") - } - IntErrorKind::Zero => { - unreachable!("We're parsing Numbers as i64, so 0 should always work!") - } - _ => SyntaxError { - pos: start, - error: ErrorType::InvalidNumber, - }, - }; - Err(syntax_error) - } - } - } - - fn advance(&mut self) -> char { - let char = self.input[self.pos]; - self.pos += 1; - char - } - - fn next_matches(&mut self, c: char) -> bool { - if self.done() || self.input[self.pos] != c { - return false; - } - - self.pos += 1; - true - } - - fn done(&self) -> bool { - self.pos >= self.input.len() - } - } - - #[cfg(test)] - mod tests { - use crate::limit::conditions::Literal::Identifier; - use crate::limit::conditions::{ErrorType, Literal, Scanner, Token, TokenType}; - - #[test] - fn test_scanner() { - let mut tokens = - Scanner::scan("foo=='bar '".to_owned()).expect("Should parse alright!"); - assert_eq!(tokens.len(), 3); - assert_eq!( - tokens[0], - Token { - token_type: TokenType::Identifier, - literal: Some(Identifier("foo".to_owned())), - pos: 1, - } - ); - assert_eq!( - tokens[1], - Token { - token_type: TokenType::EqualEqual, - literal: None, - pos: 4, - } - ); - assert_eq!( - tokens[2], - Token { - token_type: TokenType::String, - literal: Some(Literal::String("bar ".to_owned())), - pos: 6, - } - ); - - tokens[1].pos += 1; - tokens[2].pos += 2; - assert_eq!( - tokens, - Scanner::scan("foo == 'bar '".to_owned()).expect("Should parse alright!") - ); - - tokens[0].pos += 2; - tokens[1].pos += 2; - tokens[2].pos += 2; - assert_eq!( - tokens, - Scanner::scan(" foo == 'bar ' ".to_owned()).expect("Should parse alright!") - ); - - tokens[1].pos += 2; - tokens[2].pos += 4; - assert_eq!( - tokens, - Scanner::scan(" foo == 'bar ' ".to_owned()).expect("Should parse alright!") - ); - } - - #[test] - fn test_number_literal() { - let tokens = Scanner::scan("var == 42".to_owned()).expect("Should parse alright!"); - assert_eq!(tokens.len(), 3); - assert_eq!( - tokens[0], - Token { - token_type: TokenType::Identifier, - literal: Some(Identifier("var".to_owned())), - pos: 1, - } - ); - assert_eq!( - tokens[1], - Token { - token_type: TokenType::EqualEqual, - literal: None, - pos: 5, - } - ); - assert_eq!( - tokens[2], - Token { - token_type: TokenType::Number, - literal: Some(Literal::Number(42)), - pos: 8, - } - ); - } - - #[test] - fn test_charset() { - let tokens = - Scanner::scan(" 変数 == ' 💖 '".to_owned()).expect("Should parse alright!"); - assert_eq!(tokens.len(), 3); - assert_eq!( - tokens[0], - Token { - token_type: TokenType::Identifier, - literal: Some(Identifier("変数".to_owned())), - pos: 2, - } - ); - assert_eq!( - tokens[1], - Token { - token_type: TokenType::EqualEqual, - literal: None, - pos: 5, - } - ); - assert_eq!( - tokens[2], - Token { - token_type: TokenType::String, - literal: Some(Literal::String(" 💖 ".to_owned())), - pos: 8, - } - ); - } - - #[test] - fn unclosed_string_literal() { - let error = Scanner::scan("foo == 'ba".to_owned()).expect_err("Should fail!"); - assert_eq!(error.pos, 8); - assert_eq!(error.error, ErrorType::UnclosedStringLiteral('\'')); - } - } -} - -use crate::limit::cel::Context; +use crate::limit::cel::{Context, Predicate}; pub use cel::Expression as CelExpression; pub use cel::ParseError; pub use cel::Predicate as CelPredicate; @@ -906,68 +312,6 @@ mod tests { assert!(!limit.applies(&values)) } - #[test] - fn valid_condition_literal_parsing() { - let result: Condition = serde_json::from_str(r#""x == '5'""#).expect("Should deserialize"); - assert_eq!( - result, - Condition { - var_name: "x".to_string(), - predicate: Predicate::Equal, - operand: "5".to_string(), - } - ); - - let result: Condition = - serde_json::from_str(r#"" foobar=='ok' ""#).expect("Should deserialize"); - assert_eq!( - result, - Condition { - var_name: "foobar".to_string(), - predicate: Predicate::Equal, - operand: "ok".to_string(), - } - ); - - let result: Condition = - serde_json::from_str(r#"" foobar == 'ok' ""#).expect("Should deserialize"); - assert_eq!( - result, - Condition { - var_name: "foobar".to_string(), - predicate: Predicate::Equal, - operand: "ok".to_string(), - } - ); - } - - #[test] - fn invalid_deprecated_condition_parsing() { - serde_json::from_str::(r#""x == 5""#).expect_err("Should fail!"); - } - - #[test] - fn invalid_condition_parsing() { - let result = serde_json::from_str::(r#""x != 5 && x > 12""#) - .expect_err("should fail parsing"); - assert_eq!( - result.to_string(), - "SyntaxError: Invalid character `&` at offset 8 of condition \"x != 5 && x > 12\"" - .to_string() - ); - } - - #[test] - fn condition_serialization() { - let condition = Condition { - var_name: "foobar".to_string(), - predicate: Predicate::Equal, - operand: "ok".to_string(), - }; - let result = serde_json::to_string(&condition).expect("Should serialize"); - assert_eq!(result, r#""foobar == \"ok\"""#.to_string()); - } - #[test] fn limit_id() { let limit = Limit::with_id( diff --git a/limitador/src/limit/cel.rs b/limitador/src/limit/cel.rs index c35b1492..66533d22 100644 --- a/limitador/src/limit/cel.rs +++ b/limitador/src/limit/cel.rs @@ -1,4 +1,3 @@ -use crate::limit::cel::errors::EvaluationError; use crate::limit::Limit; use cel_interpreter::{ExecutionError, Value}; pub use errors::ParseError; @@ -73,6 +72,8 @@ pub(super) mod errors { } } +pub use errors::EvaluationError; + pub struct Context<'a> { variables: HashSet, ctx: cel_interpreter::Context<'a>, @@ -261,6 +262,14 @@ impl TryFrom for Predicate { } } +impl TryFrom<&str> for Predicate { + type Error = ParseError; + + fn try_from(value: &str) -> Result { + Self::parse(value) + } +} + impl From for String { fn from(value: Predicate) -> Self { value.expression.source diff --git a/limitador/src/storage/keys.rs b/limitador/src/storage/keys.rs index 300c6dc4..e426f6b0 100644 --- a/limitador/src/storage/keys.rs +++ b/limitador/src/storage/keys.rs @@ -124,7 +124,7 @@ mod tests { ) .expect("This must be a valid limit!"); assert_eq!( - "namespace:{example.com},counters_of_limit:{\"namespace\":\"example.com\",\"seconds\":60,\"conditions\":[\"req_method == \\\"GET\\\"\"],\"variables\":[\"app_id\"]}".as_bytes(), + "namespace:{example.com},counters_of_limit:{\"namespace\":\"example.com\",\"seconds\":60,\"conditions\":[\"req_method == 'GET'\"],\"variables\":[\"app_id\"]}".as_bytes(), key_for_counters_of_limit(&limit)) } diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index f9ece6a7..edb562b0 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -1267,7 +1267,7 @@ mod test { namespace, max_hits, 60, - vec!["req_method == 'GET'"], + vec!["req_method == 'GET'".to_string()], vec!["app_id"], ) .expect("This must be a valid limit!");