Skip to content

Commit

Permalink
more expressive ValueCmp filter
Browse files Browse the repository at this point in the history
  • Loading branch information
0xripleys committed Jan 12, 2024
1 parent fec7d0f commit cfb9731
Showing 1 changed file with 192 additions and 72 deletions.
264 changes: 192 additions & 72 deletions rpc-client-api/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,7 @@ impl RpcFilterType {
}
}
}
RpcFilterType::ValueCmp(compare) => {
if compare.num_bytes > 8 {
Err(RpcFilterError::DataTooLarge)
} else {
Ok(())
}
}
RpcFilterType::ValueCmp(_) => Ok(()),
RpcFilterType::TokenAccountState => Ok(()),
}
}
Expand All @@ -91,7 +85,9 @@ impl RpcFilterType {
match self {
RpcFilterType::DataSize(size) => account.data().len() as u64 == *size,
RpcFilterType::Memcmp(compare) => compare.bytes_match(account.data()),
RpcFilterType::ValueCmp(compare) => compare.values_match(account.data()),
RpcFilterType::ValueCmp(compare) => {
compare.values_match(account.data()).unwrap_or(false)
}
RpcFilterType::TokenAccountState => Account::valid_account_data(account.data()),
}
}
Expand All @@ -117,6 +113,8 @@ pub enum RpcFilterError {
Base58DecodeError(#[from] bs58::decode::Error),
#[error("base64 decode error")]
Base64DecodeError(#[from] base64::DecodeError),
#[error("invalid filter")]
InvalidFilter,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
Expand Down Expand Up @@ -231,54 +229,154 @@ impl Memcmp {

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ValueCmp {
pub offset: usize,
pub num_bytes: u8,
pub value: u64,
cmp_type: ValueCmpType,
endian: EndianType,
pub left: Operand,
comparator: Comparator,
pub right: Operand,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Operand {
Mem {
offset: usize,
value_type: ValueType,
},
Constant(String),
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ValueType {
U8,
U16,
U32,
U64,
U128,
}

enum WrappedValueType {
U8(u8),
U16(u16),
U32(u32),
U64(u64),
U128(u128),
}

impl ValueCmp {
pub fn values_match(&self, data: &[u8]) -> bool {
if self.offset > data.len() {
return false;
}
if data[self.offset..].len() < self.num_bytes as usize {
return false;
fn parse_mem_into_value_type(
o: &Operand,
data: &[u8],
) -> Result<WrappedValueType, RpcFilterError> {
match o {
Operand::Mem { offset, value_type } => match value_type {
ValueType::U8 => {
if *offset >= data.len() {
return Err(RpcFilterError::InvalidFilter);
}

Ok(WrappedValueType::U8(data[*offset]))
}
ValueType::U16 => {
if *offset + 1 >= data.len() {
return Err(RpcFilterError::InvalidFilter);
}
Ok(WrappedValueType::U16(u16::from_le_bytes(
data[*offset..*offset + 2].try_into().unwrap(),
)))
}
ValueType::U32 => {
if *offset + 3 >= data.len() {
return Err(RpcFilterError::InvalidFilter);
}
Ok(WrappedValueType::U32(u32::from_le_bytes(
data[*offset..*offset + 4].try_into().unwrap(),
)))
}
ValueType::U64 => {
if *offset + 7 >= data.len() {
return Err(RpcFilterError::InvalidFilter);
}
Ok(WrappedValueType::U64(u64::from_le_bytes(
data[*offset..*offset + 8].try_into().unwrap(),
)))
}
ValueType::U128 => {
if *offset + 15 >= data.len() {
return Err(RpcFilterError::InvalidFilter);
}
Ok(WrappedValueType::U128(u128::from_le_bytes(
data[*offset..*offset + 16].try_into().unwrap(),
)))
}
},
_ => Err(RpcFilterError::InvalidFilter),
}
let bytes = &data[self.offset..self.offset + self.num_bytes as usize];
}

pub fn values_match(&self, data: &[u8]) -> Result<bool, RpcFilterError> {
match (&self.left, &self.right) {
(left @ Operand::Mem { .. }, right @ Operand::Mem { .. }) => {
let left = Self::parse_mem_into_value_type(left, data)?;
let right = Self::parse_mem_into_value_type(right, data)?;

let mut padded_bytes = [0u8; 8];
let value = match self.endian {
EndianType::Big => {
padded_bytes[8 - self.num_bytes as usize..].copy_from_slice(bytes);
u64::from_be_bytes(padded_bytes)
match (left, right) {
(WrappedValueType::U8(left), WrappedValueType::U8(right)) => {
Ok(self.comparator.compare(left, right))
}
(WrappedValueType::U16(left), WrappedValueType::U16(right)) => {
Ok(self.comparator.compare(left, right))
}
(WrappedValueType::U32(left), WrappedValueType::U32(right)) => {
Ok(self.comparator.compare(left, right))
}
(WrappedValueType::U64(left), WrappedValueType::U64(right)) => {
Ok(self.comparator.compare(left, right))
}
(WrappedValueType::U128(left), WrappedValueType::U128(right)) => {
Ok(self.comparator.compare(left, right))
}
_ => Err(RpcFilterError::InvalidFilter),
}
}
EndianType::Little => {
padded_bytes[..self.num_bytes as usize].copy_from_slice(bytes);
u64::from_le_bytes(padded_bytes)
(left @ Operand::Mem { .. }, Operand::Constant(constant)) => {
match Self::parse_mem_into_value_type(left, data)? {
WrappedValueType::U8(left) => {
let right = constant
.parse::<u8>()
.map_err(|_| RpcFilterError::InvalidFilter)?;
Ok(self.comparator.compare(left, right))
}
WrappedValueType::U16(left) => {
let right = constant
.parse::<u16>()
.map_err(|_| RpcFilterError::InvalidFilter)?;
Ok(self.comparator.compare(left, right))
}
WrappedValueType::U32(left) => {
let right = constant
.parse::<u32>()
.map_err(|_| RpcFilterError::InvalidFilter)?;
Ok(self.comparator.compare(left, right))
}
WrappedValueType::U64(left) => {
let right = constant
.parse::<u64>()
.map_err(|_| RpcFilterError::InvalidFilter)?;
Ok(self.comparator.compare(left, right))
}
WrappedValueType::U128(left) => {
let right = constant
.parse::<u128>()
.map_err(|_| RpcFilterError::InvalidFilter)?;
Ok(self.comparator.compare(left, right))
}
}
}
};

match self.cmp_type {
ValueCmpType::Eq => value == self.value,
ValueCmpType::Ne => value != self.value,
ValueCmpType::Gt => value > self.value,
ValueCmpType::Ge => value >= self.value,
ValueCmpType::Lt => value < self.value,
ValueCmpType::Le => value <= self.value,
_ => Err(RpcFilterError::InvalidFilter),
}
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EndianType {
Big = 0,
Little,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ValueCmpType {
pub enum Comparator {
Eq = 0,
Ne,
Gt,
Expand All @@ -287,6 +385,20 @@ pub enum ValueCmpType {
Le,
}

impl Comparator {
// write a generic function to compare two values
pub fn compare<T: PartialOrd>(&self, left: T, right: T) -> bool {
match self {
Comparator::Eq => left == right,
Comparator::Ne => left != right,
Comparator::Gt => left > right,
Comparator::Ge => left >= right,
Comparator::Lt => left < right,
Comparator::Le => left <= right,
}
}
}

// Internal struct to hold Memcmp filter data as either encoded String or raw Bytes
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
Expand Down Expand Up @@ -467,44 +579,52 @@ mod tests {
fn test_values_match() {
// test all the ValueCmp cases
let data = vec![1, 2, 3, 4, 5];

let filter = ValueCmp {
left: Operand::Mem {
offset: 1,
value_type: ValueType::U8,
},
comparator: Comparator::Eq,
right: Operand::Constant("2".to_string()),
};

assert!(ValueCmp {
offset: 0,
num_bytes: 4,
cmp_type: ValueCmpType::Gt,
value: 4,
endian: EndianType::Little,
left: Operand::Mem {
offset: 1,
value_type: ValueType::U8
},
comparator: Comparator::Eq,
right: Operand::Constant("2".to_string())
}
.values_match(data[0..4].as_ref()));
.values_match(&data)
.unwrap());

assert!(ValueCmp {
offset: 0,
num_bytes: 4,
cmp_type: ValueCmpType::Eq,
value: 67305985,
endian: EndianType::Little,
left: Operand::Mem {
offset: 1,
value_type: ValueType::U8
},
comparator: Comparator::Lt,
right: Operand::Constant("3".to_string())
}
.values_match(data[0..4].as_ref()));
.values_match(&data)
.unwrap());

assert!(ValueCmp {
offset: 0,
num_bytes: 2,
cmp_type: ValueCmpType::Eq,
value: 515,
endian: EndianType::Big,
left: Operand::Mem {
offset: 0,
value_type: ValueType::U32
},
comparator: Comparator::Eq,
right: Operand::Constant("67305985".to_string())
}
.values_match(data[1..3].as_ref()));

let filter = ValueCmp {
offset: 0,
num_bytes: 2,
cmp_type: ValueCmpType::Eq,
value: 515,
endian: EndianType::Big,
};
.values_match(&data)
.unwrap());

// serialize
// let s = serde_json::to_string(&filter).unwrap();
// println!("{}", s);
let s = serde_json::to_string(&filter).unwrap();
println!("{}", s);
}

#[test]
Expand Down

0 comments on commit cfb9731

Please sign in to comment.