Skip to content

Commit

Permalink
chore: Add status information and return connection tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
spolu authored and keelerm84 committed Jul 30, 2024
1 parent cc4ae52 commit a772843
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 18 deletions.
19 changes: 14 additions & 5 deletions contract-tests/src/bin/sse-test-api/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,25 @@ struct Config {
#[derive(Serialize, Debug)]
#[serde(tag = "kind", rename_all = "camelCase")]
enum EventType {
Connected {},
Event { event: Event },
Comment { comment: String },
Error { error: String },
Connected {
status: u16,
headers: HashMap<String, String>,
},
Event {
event: Event,
},
Comment {
comment: String,
},
Error {
error: String,
},
}

impl From<es::SSE> for EventType {
fn from(event: es::SSE) -> Self {
match event {
es::SSE::Connected((status, headers)) => Self::Connected { status, headers },
es::SSE::Event(evt) => Self::Event {
event: Event {
event_type: evt.event_type,
Expand All @@ -67,7 +77,6 @@ impl From<es::SSE> for EventType {
},
},
es::SSE::Comment(comment) => Self::Comment { comment },
es::SSE::Connected(_) => Self::Connected {},
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion contract-tests/src/bin/sse-test-api/stream_entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl Inner {
match stream.try_next().await {
Ok(Some(event)) => {
let event_type: EventType = event.into();
if let EventType::Connected {} = event_type {
if matches!(event_type, EventType::Connected { .. }) {
continue;
}

Expand Down
3 changes: 3 additions & 0 deletions eventsource-client/examples/tail.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ fn tail_events(client: impl es::Client) -> impl Stream<Item = Result<(), ()>> {
client
.stream()
.map_ok(|event| match event {
es::SSE::Connected((status, _)) => {
println!("got connected: \nstatus={}", status)
}
es::SSE::Event(ev) => {
println!("got an event: {}\n{}", ev.event_type, ev.data)
}
Expand Down
27 changes: 22 additions & 5 deletions eventsource-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use log::{debug, info, trace, warn};
use pin_project::pin_project;
use std::{
boxed,
collections::HashMap,
fmt::{self, Debug, Display, Formatter},
future::Future,
io::ErrorKind,
Expand All @@ -27,8 +28,8 @@ use tokio::{
time::Sleep,
};

use crate::config::ReconnectOptions;
use crate::error::{Error, Result};
use crate::{config::ReconnectOptions, ResponseWrapper};

use hyper::client::HttpConnector;
use hyper_timeout::TimeoutConnector;
Expand Down Expand Up @@ -393,6 +394,7 @@ where
let this = self.as_mut().project();
if let Some(event) = this.event_parser.get_event() {
return match event {
SSE::Connected(_) => Poll::Ready(Some(Ok(event))),
SSE::Event(ref evt) => {
*this.last_event_id = evt.id.clone();

Expand Down Expand Up @@ -437,15 +439,27 @@ where
debug!("HTTP response: {:#?}", resp);

if resp.status().is_success() {
let reply =
Poll::Ready(Some(Ok(SSE::Connected(resp.headers().to_owned()))));
self.as_mut().project().retry_strategy.reset(Instant::now());
self.as_mut().reset_redirects();

let headers = resp.headers();
let mut map = HashMap::new();
for (key, value) in headers.iter() {
let key = key.to_string();
let value = match value.to_str() {
Ok(value) => value.to_string(),
Err(_) => String::from(""),
};
map.insert(key, value);
}
let status = resp.status().as_u16();

self.as_mut()
.project()
.state
.set(State::Connected(resp.into_body()));
return reply;

return Poll::Ready(Some(Ok(SSE::Connected((status, map)))));
}

if resp.status() == 301 || resp.status() == 307 {
Expand All @@ -470,7 +484,10 @@ where

self.as_mut().reset_redirects();
self.as_mut().project().state.set(State::New);
return Poll::Ready(Some(Err(Error::UnexpectedResponse(resp.status()))));

return Poll::Ready(Some(Err(Error::UnexpectedResponse(
ResponseWrapper::new(resp),
))));
}
Err(e) => {
// This seems basically impossible. AFAIK we can only get this way if we
Expand Down
82 changes: 79 additions & 3 deletions eventsource-client/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,77 @@
use hyper::StatusCode;
use std::collections::HashMap;

use hyper::{body::Buf, Body, Response};

pub struct ResponseWrapper {
response: Response<Body>,
}

impl ResponseWrapper {
pub fn new(response: Response<Body>) -> Self {
Self { response }
}
pub fn status(&self) -> u16 {
self.response.status().as_u16()
}
pub fn headers(&self) -> std::result::Result<HashMap<&str, &str>, HeaderError> {
let headers = self.response.headers();
let mut map = HashMap::new();
for (key, value) in headers.iter() {
let key = key.as_str();
let value = match value.to_str() {
Ok(value) => value,
Err(err) => return Err(HeaderError::new(Box::new(err))),
};
map.insert(key, value);
}
Ok(map)
}

pub async fn body_bytes(self) -> Result<Vec<u8>> {
let body = self.response.into_body();

let buf = match hyper::body::aggregate(body).await {
Ok(buf) => buf,
Err(err) => return Err(Error::HttpStream(Box::new(err))),
};

Ok(buf.chunk().to_vec())
}
}

impl std::fmt::Debug for ResponseWrapper {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ResponseWrapper")
.field("status", &self.status())
.finish()
}
}

/// Error type for invalid response headers encountered in ResponseWrapper.
#[derive(Debug)]
pub struct HeaderError {
/// Wrapped inner error providing details about the header issue.
inner_error: Box<dyn std::error::Error + Send + Sync + 'static>,
}

impl HeaderError {
/// Constructs a new `HeaderError` wrapping an existing error.
pub fn new(err: Box<dyn std::error::Error + Send + Sync + 'static>) -> Self {
HeaderError { inner_error: err }
}
}

impl std::fmt::Display for HeaderError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Invalid response header: {}", self.inner_error)
}
}

impl std::error::Error for HeaderError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(self.inner_error.as_ref())
}
}

/// Error type returned from this library's functions.
#[derive(Debug)]
Expand All @@ -8,7 +81,7 @@ pub enum Error {
/// An invalid request parameter
InvalidParameter(Box<dyn std::error::Error + Send + Sync + 'static>),
/// The HTTP response could not be handled.
UnexpectedResponse(StatusCode),
UnexpectedResponse(ResponseWrapper),
/// An error reading from the HTTP response body.
HttpStream(Box<dyn std::error::Error + Send + Sync + 'static>),
/// The HTTP response stream ended
Expand All @@ -32,7 +105,10 @@ impl std::fmt::Display for Error {
TimedOut => write!(f, "timed out"),
StreamClosed => write!(f, "stream closed"),
InvalidParameter(err) => write!(f, "invalid parameter: {err}"),
UnexpectedResponse(status_code) => write!(f, "unexpected response: {status_code}"),
UnexpectedResponse(r) => {
let status = r.status();
write!(f, "unexpected response: {status}")
}
HttpStream(err) => write!(f, "http error: {err}"),
Eof => write!(f, "eof"),
UnexpectedEof => write!(f, "unexpected eof"),
Expand Down
10 changes: 7 additions & 3 deletions eventsource-client/src/event_parser.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::{collections::VecDeque, convert::TryFrom, str::from_utf8};
use std::{
collections::{HashMap, VecDeque},
convert::TryFrom,
str::from_utf8,
};

use hyper::{body::Bytes, http::HeaderValue, HeaderMap};
use hyper::body::Bytes;
use log::{debug, log_enabled, trace};
use pin_project::pin_project;

Expand Down Expand Up @@ -32,7 +36,7 @@ impl EventData {

#[derive(Debug, Eq, PartialEq)]
pub enum SSE {
Connected(HeaderMap<HeaderValue>),
Connected((u16, HashMap<String, String>)),
Event(Event),
Comment(String),
}
Expand Down
2 changes: 1 addition & 1 deletion eventsource-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
//! .map_ok(|event| match event {
//! SSE::Comment(comment) => println!("got a comment event: {:?}", comment),
//! SSE::Event(evt) => println!("got an event: {}", evt.event_type),
//! SSE::Connected(headers) => println!("got connection start with headers: {:?}", headers)
//! SSE::Connected(_) => println!("got connected")
//! })
//! .map_err(|e| println!("error streaming events: {:?}", e));
//! # while let Ok(Some(_)) = stream.try_next().await {}
Expand Down

0 comments on commit a772843

Please sign in to comment.