diff --git a/src/server.rs b/src/server.rs index ca01750..1cc68ec 100644 --- a/src/server.rs +++ b/src/server.rs @@ -4,13 +4,16 @@ use crate::types::edge::EdgeDB; use crate::types::{Address, Edge, U256}; use json::JsonValue; use std::error::Error; +use std::fmt::{Debug, Display, Formatter}; use std::io::Read; use std::io::{BufRead, BufReader, Write}; use std::net::{TcpListener, TcpStream}; use std::ops::Deref; +use std::str::FromStr; use std::sync::mpsc::TrySendError; use std::sync::{mpsc, Arc, Mutex, RwLock}; use std::thread; +use num_bigint::BigUint; struct JsonRpcRequest { id: JsonValue, @@ -18,6 +21,20 @@ struct JsonRpcRequest { params: JsonValue, } +struct InputValidationError(String); +impl Error for InputValidationError {} + +impl Debug for InputValidationError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Error: {}", self.0) + } +} +impl Display for InputValidationError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Error: {}", self.0) + } +} + pub fn start_server(listen_at: &str, queue_size: usize, threads: u64) { let edges: Arc>> = Arc::new(RwLock::new(Arc::new(EdgeDB::default()))); @@ -137,6 +154,27 @@ fn compute_transfer( mut socket: TcpStream, ) -> Result<(), Box> { socket.write_all(chunked_header().as_bytes())?; + + let parsed_value_param = match request.params["value"].as_str() { + Some(value_str) => match BigUint::from_str(value_str) { + Ok(parsed_value) => parsed_value, + Err(e) => { + return Err(Box::new(InputValidationError(format!( + "Invalid value: {}. Couldn't parse value: {}", + value_str, e + )))); + } + }, + None => U256::MAX.into(), + }; + + if parsed_value_param > U256::MAX.into() { + return Err(Box::new(InputValidationError(format!( + "Value {} is too large. Maximum value is {}.", + parsed_value_param, U256::MAX + )))); + } + let max_distances = if request.params["iterative"].as_bool().unwrap_or_default() { vec![Some(1), Some(2), None] } else { @@ -148,11 +186,7 @@ fn compute_transfer( &Address::from(request.params["from"].to_string().as_str()), &Address::from(request.params["to"].to_string().as_str()), edges, - if request.params.has_key("value") { - U256::from(request.params["value"].to_string().as_str()) - } else { - U256::MAX - }, + U256::from_bigint_truncating(parsed_value_param.clone()), max_distance, max_transfers, );