diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index ab75086884a7..1b17495c5da3 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -48,6 +48,7 @@ use crate::usage_metrics::{MetricCounter, MetricCounterRecorder}; struct QueryData { query: String, #[serde(deserialize_with = "bytes_to_pg_text")] + #[serde(default)] params: Vec>, #[serde(default)] array_mode: Option, @@ -1105,3 +1106,63 @@ impl Discard<'_> { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_payload() { + let payload = "{\"query\":\"SELECT * FROM users WHERE name = ?\",\"params\":[\"test\"],\"arrayMode\":true}"; + let deserialized_payload: Payload = serde_json::from_str(payload).unwrap(); + + match deserialized_payload { + Payload::Single(QueryData { + query, + params, + array_mode, + }) => { + assert_eq!(query, "SELECT * FROM users WHERE name = ?"); + assert_eq!(params, vec![Some(String::from("test"))]); + assert!(array_mode.unwrap()); + } + Payload::Batch(_) => { + panic!("deserialization failed: case with single query, one param, and array mode") + } + } + + let payload = "{\"queries\":[{\"query\":\"SELECT * FROM users0 WHERE name = ?\",\"params\":[\"test0\"], \"arrayMode\":false},{\"query\":\"SELECT * FROM users1 WHERE name = ?\",\"params\":[\"test1\"],\"arrayMode\":true}]}"; + let deserialized_payload: Payload = serde_json::from_str(payload).unwrap(); + + match deserialized_payload { + Payload::Batch(BatchQueryData { queries }) => { + assert_eq!(queries.len(), 2); + for (i, query) in queries.into_iter().enumerate() { + assert_eq!( + query.query, + format!("SELECT * FROM users{i} WHERE name = ?") + ); + assert_eq!(query.params, vec![Some(format!("test{i}"))]); + assert_eq!(query.array_mode.unwrap(), i > 0); + } + } + Payload::Single(_) => panic!("deserialization failed: case with multiple queries"), + } + + let payload = "{\"query\":\"SELECT 1\"}"; + let deserialized_payload: Payload = serde_json::from_str(payload).unwrap(); + + match deserialized_payload { + Payload::Single(QueryData { + query, + params, + array_mode, + }) => { + assert_eq!(query, "SELECT 1"); + assert_eq!(params, vec![]); + assert!(array_mode.is_none()); + } + Payload::Batch(_) => panic!("deserialization failed: case with only one query"), + } + } +}