Skip to content

Commit

Permalink
chore: Unify connected and error response formats
Browse files Browse the repository at this point in the history
  • Loading branch information
keelerm84 committed Jul 30, 2024
1 parent 90bbea9 commit 1a48928
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 97 deletions.
19 changes: 5 additions & 14 deletions contract-tests/src/bin/sse-test-api/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,16 @@ struct Config {
#[derive(Serialize, Debug)]
#[serde(tag = "kind", rename_all = "camelCase")]
enum EventType {
Connected {
status: u16,
headers: HashMap<String, String>,
},
Event {
event: Event,
},
Comment {
comment: String,
},
Error {
error: String,
},
Connected {},
Event { event: Event },
Comment { comment: String },
Error { error: String },
}

impl From<es::SSE> for EventType {
fn from(event: es::SSE) -> Self {
match event {
es::SSE::Connected((status, headers)) => Self::Connected { status, headers },
es::SSE::Connected(_) => Self::Connected {},
es::SSE::Event(evt) => Self::Event {
event: Event {
event_type: evt.event_type,
Expand Down
7 changes: 2 additions & 5 deletions eventsource-client/examples/tail.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,15 @@ fn tail_events(client: impl es::Client) -> impl Stream<Item = Result<(), ()>> {
client
.stream()
.map_ok(|event| match event {
es::SSE::Connected((status, _)) => {
println!("got connected: \nstatus={}", status)
es::SSE::Connected(connection) => {
println!("got connected: \nstatus={}", connection.response().status())
}
es::SSE::Event(ev) => {
println!("got an event: {}\n{}", ev.event_type, ev.data)
}
es::SSE::Comment(comment) => {
println!("got a comment: \n{}", comment)
}
es::SSE::Connected(headers) => {
println!("got a connection start with headers: \n{:?}", headers)
}
})
.map_err(|err| eprintln!("error streaming events: {:?}", err))
}
38 changes: 18 additions & 20 deletions eventsource-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ use hyper::{
},
header::{HeaderMap, HeaderName, HeaderValue},
service::Service,
Body, Request, StatusCode, Uri,
Body, Request, Uri,
};
use log::{debug, info, trace, warn};
use pin_project::pin_project;
use std::{
boxed,
collections::HashMap,
fmt::{self, Debug, Display, Formatter},
fmt::{self, Debug, Formatter},
future::Future,
io::ErrorKind,
pin::Pin,
Expand All @@ -28,8 +27,14 @@ use tokio::{
time::Sleep,
};

use crate::error::{Error, Result};
use crate::{config::ReconnectOptions, ResponseWrapper};
use crate::{
config::ReconnectOptions,
response::{ErrorBody, Response},
};
use crate::{
error::{Error, Result},
event_parser::ConnectionDetails,
};

use hyper::client::HttpConnector;
use hyper_timeout::TimeoutConnector;
Expand Down Expand Up @@ -396,7 +401,7 @@ where
return match event {
SSE::Connected(_) => Poll::Ready(Some(Ok(event))),
SSE::Event(ref evt) => {
*this.last_event_id = evt.id.clone();
this.last_event_id.clone_from(&evt.id);

if let Some(retry) = evt.retry {
this.retry_strategy
Expand All @@ -405,7 +410,6 @@ where
Poll::Ready(Some(Ok(event)))
}
SSE::Comment(_) => Poll::Ready(Some(Ok(event))),
SSE::Connected(_) => Poll::Ready(Some(Ok(event))),
};
}

Expand Down Expand Up @@ -442,24 +446,17 @@ where
self.as_mut().project().retry_strategy.reset(Instant::now());
self.as_mut().reset_redirects();

let headers = resp.headers();
let mut map = HashMap::new();
for (key, value) in headers.iter() {
let key = key.to_string();
let value = match value.to_str() {
Ok(value) => value.to_string(),
Err(_) => String::from(""),
};
map.insert(key, value);
}
let status = resp.status().as_u16();
let status = resp.status();
let headers = resp.headers().clone();

self.as_mut()
.project()
.state
.set(State::Connected(resp.into_body()));

return Poll::Ready(Some(Ok(SSE::Connected((status, map)))));
return Poll::Ready(Some(Ok(SSE::Connected(ConnectionDetails::new(
Response::new(status, headers),
)))));
}

if resp.status() == 301 || resp.status() == 307 {
Expand All @@ -486,7 +483,8 @@ where
self.as_mut().project().state.set(State::New);

return Poll::Ready(Some(Err(Error::UnexpectedResponse(
ResponseWrapper::new(resp),
Response::new(resp.status(), resp.headers().clone()),
ErrorBody::new(resp.into_body()),
))));
}
Err(e) => {
Expand Down
55 changes: 4 additions & 51 deletions eventsource-client/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,53 +1,6 @@
use std::collections::HashMap;
use crate::response::{ErrorBody, Response};

use hyper::{body::Buf, Body, Response};

pub struct ResponseWrapper {
response: Response<Body>,
}

impl ResponseWrapper {
pub fn new(response: Response<Body>) -> Self {
Self { response }
}
pub fn status(&self) -> u16 {
self.response.status().as_u16()
}
pub fn headers(&self) -> std::result::Result<HashMap<&str, &str>, HeaderError> {
let headers = self.response.headers();
let mut map = HashMap::new();
for (key, value) in headers.iter() {
let key = key.as_str();
let value = match value.to_str() {
Ok(value) => value,
Err(err) => return Err(HeaderError::new(Box::new(err))),
};
map.insert(key, value);
}
Ok(map)
}

pub async fn body_bytes(self) -> Result<Vec<u8>> {
let body = self.response.into_body();

let buf = match hyper::body::aggregate(body).await {
Ok(buf) => buf,
Err(err) => return Err(Error::HttpStream(Box::new(err))),
};

Ok(buf.chunk().to_vec())
}
}

impl std::fmt::Debug for ResponseWrapper {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ResponseWrapper")
.field("status", &self.status())
.finish()
}
}

/// Error type for invalid response headers encountered in ResponseWrapper.
/// Error type for invalid response headers encountered in ResponseDetails.
#[derive(Debug)]
pub struct HeaderError {
/// Wrapped inner error providing details about the header issue.
Expand Down Expand Up @@ -81,7 +34,7 @@ pub enum Error {
/// An invalid request parameter
InvalidParameter(Box<dyn std::error::Error + Send + Sync + 'static>),
/// The HTTP response could not be handled.
UnexpectedResponse(ResponseWrapper),
UnexpectedResponse(Response, ErrorBody),
/// An error reading from the HTTP response body.
HttpStream(Box<dyn std::error::Error + Send + Sync + 'static>),
/// The HTTP response stream ended
Expand All @@ -105,7 +58,7 @@ impl std::fmt::Display for Error {
TimedOut => write!(f, "timed out"),
StreamClosed => write!(f, "stream closed"),
InvalidParameter(err) => write!(f, "invalid parameter: {err}"),
UnexpectedResponse(r) => {
UnexpectedResponse(r, _) => {
let status = r.status();
write!(f, "unexpected response: {status}")
}
Expand Down
28 changes: 21 additions & 7 deletions eventsource-client/src/event_parser.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use std::{
collections::{HashMap, VecDeque},
convert::TryFrom,
str::from_utf8,
};
use std::{collections::VecDeque, convert::TryFrom, str::from_utf8};

use hyper::body::Bytes;
use log::{debug, log_enabled, trace};
use pin_project::pin_project;

use crate::response::Response;

use super::error::{Error, Result};

#[derive(Default, PartialEq)]
Expand Down Expand Up @@ -36,7 +34,7 @@ impl EventData {

#[derive(Debug, Eq, PartialEq)]
pub enum SSE {
Connected((u16, HashMap<String, String>)),
Connected(ConnectionDetails),
Event(Event),
Comment(String),
}
Expand Down Expand Up @@ -75,6 +73,22 @@ impl TryFrom<EventData> for Option<SSE> {
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ConnectionDetails {
response: Response,
}

impl ConnectionDetails {
pub(crate) fn new(response: Response) -> Self {
Self { response }
}

/// Returns information describing the response at the time of connection.
pub fn response(&self) -> &Response {
&self.response
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Event {
pub event_type: String,
Expand Down Expand Up @@ -235,7 +249,7 @@ impl EventParser {
self.last_event_id = Some(value.to_string());
}

event_data.id = self.last_event_id.clone()
event_data.id.clone_from(&self.last_event_id)
} else if key == "retry" {
match value.parse::<u64>() {
Ok(retry) => {
Expand Down
2 changes: 2 additions & 0 deletions eventsource-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ mod client;
mod config;
mod error;
mod event_parser;
mod response;
mod retry;

pub use client::*;
pub use config::*;
pub use error::*;
pub use event_parser::Event;
pub use event_parser::SSE;
pub use response::Response;
85 changes: 85 additions & 0 deletions eventsource-client/src/response.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use hyper::body::Buf;
use hyper::{header::HeaderValue, Body, HeaderMap, StatusCode};

use crate::{Error, HeaderError};

pub struct ErrorBody {
body: Body,
}

impl ErrorBody {
pub fn new(body: Body) -> Self {
Self { body }
}

/// Returns the body of the response as a vector of bytes.
///
/// Caution: This method reads the entire body into memory. You should only use this method if
/// you know the response is of a reasonable size.
pub async fn body_bytes(self) -> Result<Vec<u8>, Error> {
let buf = match hyper::body::aggregate(self.body).await {
Ok(buf) => buf,
Err(err) => return Err(Error::HttpStream(Box::new(err))),
};

Ok(buf.chunk().to_vec())
}
}

impl std::fmt::Debug for ErrorBody {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ErrorBody").finish()
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Response {
status_code: StatusCode,
headers: HeaderMap<HeaderValue>,
}

impl Response {
pub fn new(status_code: StatusCode, headers: HeaderMap<HeaderValue>) -> Self {
Self {
status_code,
headers,
}
}

/// Returns the status code of this response.
pub fn status(&self) -> u16 {
self.status_code.as_u16()
}

/// Returns the list of header keys present in this response.
pub fn get_header_keys(&self) -> Vec<&str> {
self.headers.keys().map(|key| key.as_str()).collect()
}

/// Returns the value of a header.
///
/// If the header contains more than one value, only the first value is returned. Refer to
/// [`get_header_values`] for a method that returns all values.
pub fn get_header_value(&self, key: &str) -> std::result::Result<Option<&str>, HeaderError> {
if let Some(value) = self.headers.get(key) {
value
.to_str()
.map(Some)
.map_err(|e| HeaderError::new(Box::new(e)))
} else {
Ok(None)
}
}

/// Returns all values for a header.
///
/// If the header contains only one value, it will be returned as a single-element vector.
/// Refer to [`get_header_value`] for a method that returns only a single value.
pub fn get_header_values(&self, key: &str) -> std::result::Result<Vec<&str>, HeaderError> {
self.headers
.get_all(key)
.iter()
.map(|value| value.to_str().map_err(|e| HeaderError::new(Box::new(e))))
.collect()
}
}

0 comments on commit 1a48928

Please sign in to comment.