From 1a4892816dedc63c7c07ffa1a7efeeccad3a108d Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 26 Jul 2024 16:53:17 -0400 Subject: [PATCH] chore: Unify connected and error response formats --- contract-tests/src/bin/sse-test-api/main.rs | 19 ++--- eventsource-client/examples/tail.rs | 7 +- eventsource-client/src/client.rs | 38 +++++---- eventsource-client/src/error.rs | 55 +------------ eventsource-client/src/event_parser.rs | 28 +++++-- eventsource-client/src/lib.rs | 2 + eventsource-client/src/response.rs | 85 +++++++++++++++++++++ 7 files changed, 137 insertions(+), 97 deletions(-) create mode 100644 eventsource-client/src/response.rs diff --git a/contract-tests/src/bin/sse-test-api/main.rs b/contract-tests/src/bin/sse-test-api/main.rs index 22f4e78..be9e967 100644 --- a/contract-tests/src/bin/sse-test-api/main.rs +++ b/contract-tests/src/bin/sse-test-api/main.rs @@ -50,25 +50,16 @@ struct Config { #[derive(Serialize, Debug)] #[serde(tag = "kind", rename_all = "camelCase")] enum EventType { - Connected { - status: u16, - headers: HashMap, - }, - Event { - event: Event, - }, - Comment { - comment: String, - }, - Error { - error: String, - }, + Connected {}, + Event { event: Event }, + Comment { comment: String }, + Error { error: String }, } impl From for EventType { fn from(event: es::SSE) -> Self { match event { - es::SSE::Connected((status, headers)) => Self::Connected { status, headers }, + es::SSE::Connected(_) => Self::Connected {}, es::SSE::Event(evt) => Self::Event { event: Event { event_type: evt.event_type, diff --git a/eventsource-client/examples/tail.rs b/eventsource-client/examples/tail.rs index e6465de..44fd6c3 100644 --- a/eventsource-client/examples/tail.rs +++ b/eventsource-client/examples/tail.rs @@ -40,8 +40,8 @@ fn tail_events(client: impl es::Client) -> impl Stream> { client .stream() .map_ok(|event| match event { - es::SSE::Connected((status, _)) => { - println!("got connected: \nstatus={}", status) + es::SSE::Connected(connection) => { + println!("got connected: \nstatus={}", connection.response().status()) } es::SSE::Event(ev) => { println!("got an event: {}\n{}", ev.event_type, ev.data) @@ -49,9 +49,6 @@ fn tail_events(client: impl es::Client) -> impl Stream> { es::SSE::Comment(comment) => { println!("got a comment: \n{}", comment) } - es::SSE::Connected(headers) => { - println!("got a connection start with headers: \n{:?}", headers) - } }) .map_err(|err| eprintln!("error streaming events: {:?}", err)) } diff --git a/eventsource-client/src/client.rs b/eventsource-client/src/client.rs index 1d4544b..da28f29 100644 --- a/eventsource-client/src/client.rs +++ b/eventsource-client/src/client.rs @@ -7,14 +7,13 @@ use hyper::{ }, header::{HeaderMap, HeaderName, HeaderValue}, service::Service, - Body, Request, StatusCode, Uri, + Body, Request, Uri, }; use log::{debug, info, trace, warn}; use pin_project::pin_project; use std::{ boxed, - collections::HashMap, - fmt::{self, Debug, Display, Formatter}, + fmt::{self, Debug, Formatter}, future::Future, io::ErrorKind, pin::Pin, @@ -28,8 +27,14 @@ use tokio::{ time::Sleep, }; -use crate::error::{Error, Result}; -use crate::{config::ReconnectOptions, ResponseWrapper}; +use crate::{ + config::ReconnectOptions, + response::{ErrorBody, Response}, +}; +use crate::{ + error::{Error, Result}, + event_parser::ConnectionDetails, +}; use hyper::client::HttpConnector; use hyper_timeout::TimeoutConnector; @@ -396,7 +401,7 @@ where return match event { SSE::Connected(_) => Poll::Ready(Some(Ok(event))), SSE::Event(ref evt) => { - *this.last_event_id = evt.id.clone(); + this.last_event_id.clone_from(&evt.id); if let Some(retry) = evt.retry { this.retry_strategy @@ -405,7 +410,6 @@ where Poll::Ready(Some(Ok(event))) } SSE::Comment(_) => Poll::Ready(Some(Ok(event))), - SSE::Connected(_) => Poll::Ready(Some(Ok(event))), }; } @@ -442,24 +446,17 @@ where self.as_mut().project().retry_strategy.reset(Instant::now()); self.as_mut().reset_redirects(); - let headers = resp.headers(); - let mut map = HashMap::new(); - for (key, value) in headers.iter() { - let key = key.to_string(); - let value = match value.to_str() { - Ok(value) => value.to_string(), - Err(_) => String::from(""), - }; - map.insert(key, value); - } - let status = resp.status().as_u16(); + let status = resp.status(); + let headers = resp.headers().clone(); self.as_mut() .project() .state .set(State::Connected(resp.into_body())); - return Poll::Ready(Some(Ok(SSE::Connected((status, map))))); + return Poll::Ready(Some(Ok(SSE::Connected(ConnectionDetails::new( + Response::new(status, headers), + ))))); } if resp.status() == 301 || resp.status() == 307 { @@ -486,7 +483,8 @@ where self.as_mut().project().state.set(State::New); return Poll::Ready(Some(Err(Error::UnexpectedResponse( - ResponseWrapper::new(resp), + Response::new(resp.status(), resp.headers().clone()), + ErrorBody::new(resp.into_body()), )))); } Err(e) => { diff --git a/eventsource-client/src/error.rs b/eventsource-client/src/error.rs index 00bb912..f84f891 100644 --- a/eventsource-client/src/error.rs +++ b/eventsource-client/src/error.rs @@ -1,53 +1,6 @@ -use std::collections::HashMap; +use crate::response::{ErrorBody, Response}; -use hyper::{body::Buf, Body, Response}; - -pub struct ResponseWrapper { - response: Response, -} - -impl ResponseWrapper { - pub fn new(response: Response) -> Self { - Self { response } - } - pub fn status(&self) -> u16 { - self.response.status().as_u16() - } - pub fn headers(&self) -> std::result::Result, HeaderError> { - let headers = self.response.headers(); - let mut map = HashMap::new(); - for (key, value) in headers.iter() { - let key = key.as_str(); - let value = match value.to_str() { - Ok(value) => value, - Err(err) => return Err(HeaderError::new(Box::new(err))), - }; - map.insert(key, value); - } - Ok(map) - } - - pub async fn body_bytes(self) -> Result> { - let body = self.response.into_body(); - - let buf = match hyper::body::aggregate(body).await { - Ok(buf) => buf, - Err(err) => return Err(Error::HttpStream(Box::new(err))), - }; - - Ok(buf.chunk().to_vec()) - } -} - -impl std::fmt::Debug for ResponseWrapper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ResponseWrapper") - .field("status", &self.status()) - .finish() - } -} - -/// Error type for invalid response headers encountered in ResponseWrapper. +/// Error type for invalid response headers encountered in ResponseDetails. #[derive(Debug)] pub struct HeaderError { /// Wrapped inner error providing details about the header issue. @@ -81,7 +34,7 @@ pub enum Error { /// An invalid request parameter InvalidParameter(Box), /// The HTTP response could not be handled. - UnexpectedResponse(ResponseWrapper), + UnexpectedResponse(Response, ErrorBody), /// An error reading from the HTTP response body. HttpStream(Box), /// The HTTP response stream ended @@ -105,7 +58,7 @@ impl std::fmt::Display for Error { TimedOut => write!(f, "timed out"), StreamClosed => write!(f, "stream closed"), InvalidParameter(err) => write!(f, "invalid parameter: {err}"), - UnexpectedResponse(r) => { + UnexpectedResponse(r, _) => { let status = r.status(); write!(f, "unexpected response: {status}") } diff --git a/eventsource-client/src/event_parser.rs b/eventsource-client/src/event_parser.rs index 0920be3..c854011 100644 --- a/eventsource-client/src/event_parser.rs +++ b/eventsource-client/src/event_parser.rs @@ -1,13 +1,11 @@ -use std::{ - collections::{HashMap, VecDeque}, - convert::TryFrom, - str::from_utf8, -}; +use std::{collections::VecDeque, convert::TryFrom, str::from_utf8}; use hyper::body::Bytes; use log::{debug, log_enabled, trace}; use pin_project::pin_project; +use crate::response::Response; + use super::error::{Error, Result}; #[derive(Default, PartialEq)] @@ -36,7 +34,7 @@ impl EventData { #[derive(Debug, Eq, PartialEq)] pub enum SSE { - Connected((u16, HashMap)), + Connected(ConnectionDetails), Event(Event), Comment(String), } @@ -75,6 +73,22 @@ impl TryFrom for Option { } } +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ConnectionDetails { + response: Response, +} + +impl ConnectionDetails { + pub(crate) fn new(response: Response) -> Self { + Self { response } + } + + /// Returns information describing the response at the time of connection. + pub fn response(&self) -> &Response { + &self.response + } +} + #[derive(Clone, Debug, Eq, PartialEq)] pub struct Event { pub event_type: String, @@ -235,7 +249,7 @@ impl EventParser { self.last_event_id = Some(value.to_string()); } - event_data.id = self.last_event_id.clone() + event_data.id.clone_from(&self.last_event_id) } else if key == "retry" { match value.parse::() { Ok(retry) => { diff --git a/eventsource-client/src/lib.rs b/eventsource-client/src/lib.rs index 7677f4f..52e9611 100644 --- a/eventsource-client/src/lib.rs +++ b/eventsource-client/src/lib.rs @@ -31,6 +31,7 @@ mod client; mod config; mod error; mod event_parser; +mod response; mod retry; pub use client::*; @@ -38,3 +39,4 @@ pub use config::*; pub use error::*; pub use event_parser::Event; pub use event_parser::SSE; +pub use response::Response; diff --git a/eventsource-client/src/response.rs b/eventsource-client/src/response.rs new file mode 100644 index 0000000..4e2eced --- /dev/null +++ b/eventsource-client/src/response.rs @@ -0,0 +1,85 @@ +use hyper::body::Buf; +use hyper::{header::HeaderValue, Body, HeaderMap, StatusCode}; + +use crate::{Error, HeaderError}; + +pub struct ErrorBody { + body: Body, +} + +impl ErrorBody { + pub fn new(body: Body) -> Self { + Self { body } + } + + /// Returns the body of the response as a vector of bytes. + /// + /// Caution: This method reads the entire body into memory. You should only use this method if + /// you know the response is of a reasonable size. + pub async fn body_bytes(self) -> Result, Error> { + let buf = match hyper::body::aggregate(self.body).await { + Ok(buf) => buf, + Err(err) => return Err(Error::HttpStream(Box::new(err))), + }; + + Ok(buf.chunk().to_vec()) + } +} + +impl std::fmt::Debug for ErrorBody { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ErrorBody").finish() + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Response { + status_code: StatusCode, + headers: HeaderMap, +} + +impl Response { + pub fn new(status_code: StatusCode, headers: HeaderMap) -> Self { + Self { + status_code, + headers, + } + } + + /// Returns the status code of this response. + pub fn status(&self) -> u16 { + self.status_code.as_u16() + } + + /// Returns the list of header keys present in this response. + pub fn get_header_keys(&self) -> Vec<&str> { + self.headers.keys().map(|key| key.as_str()).collect() + } + + /// Returns the value of a header. + /// + /// If the header contains more than one value, only the first value is returned. Refer to + /// [`get_header_values`] for a method that returns all values. + pub fn get_header_value(&self, key: &str) -> std::result::Result, HeaderError> { + if let Some(value) = self.headers.get(key) { + value + .to_str() + .map(Some) + .map_err(|e| HeaderError::new(Box::new(e))) + } else { + Ok(None) + } + } + + /// Returns all values for a header. + /// + /// If the header contains only one value, it will be returned as a single-element vector. + /// Refer to [`get_header_value`] for a method that returns only a single value. + pub fn get_header_values(&self, key: &str) -> std::result::Result, HeaderError> { + self.headers + .get_all(key) + .iter() + .map(|value| value.to_str().map_err(|e| HeaderError::new(Box::new(e)))) + .collect() + } +}