Skip to content

Commit

Permalink
Merge pull request #233 from EspressoSystems/jb/sink-lifetime
Browse files Browse the repository at this point in the history
Allow socket messages to/from clients to be handled in separate tasks
  • Loading branch information
jbearer authored Jul 2, 2024
2 parents c99d61d + 88fc1c8 commit 1253cd8
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 95 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tide-disco"
version = "0.8.0"
version = "0.9.0"
edition = "2021"
authors = ["Espresso Systems <[email protected]>"]
description = "Discoverability for Tide"
Expand Down Expand Up @@ -42,6 +42,7 @@ libc = "0.2"
markdown = "0.3"
maud = { version = "0.26", features = ["tide"] }
parking_lot = "0.12"
pin-project = "1.0"
prometheus = "0.13"
reqwest = { version = "0.12", features = ["json"] }
routefinder = "0.5"
Expand Down
46 changes: 6 additions & 40 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

86 changes: 84 additions & 2 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,89 @@ where
/// }.boxed());
/// # }
/// ```
//
///
/// In some cases, it may be desirable to handle messages to and from the client in separate
/// tasks. There are two ways of doing this:
///
/// ## Split the connection into separate stream and sink
///
/// ```
/// use async_std::task::spawn;
/// use futures::{future::{join, FutureExt}, sink::SinkExt, stream::StreamExt};
/// use tide_disco::{error::ServerError, socket::Connection, Api};
/// # use vbs::version::StaticVersion;
///
/// # fn ex(api: &mut Api<(), ServerError, StaticVersion<0, 1>>) {
/// api.socket("endpoint", |_req, mut conn: Connection<i32, i32, ServerError, StaticVersion<0, 1>>, _state| async move {
/// let (mut sink, mut stream) = conn.split();
/// let recv = spawn(async move {
/// while let Some(Ok(msg)) = stream.next().await {
/// // Handle message from client.
/// }
/// });
/// let send = spawn(async move {
/// loop {
/// let msg = // get message to send to client
/// # 0;
/// sink.send(msg).await;
/// }
/// });
///
/// join(send, recv).await;
/// Ok(())
/// }.boxed());
/// # }
/// ```
///
/// This approach requires messages to be sent to the client by value, consuming the message.
/// This is because, if we were to use the `Sync<&ToClient>` implementation for `Connection`,
/// the lifetime for `&ToClient` would be fixed after `split` is called, since the lifetime
/// appears in the return type, `SplitSink<Connection<...>, &ToClient>`. Thus, this lifetime
/// outlives any scoped local variables created after the `split` call, such as `msg` in the
/// `loop`.
///
/// If we want to use the message after sending it to the client, we would have to clone it,
/// which may be inefficient or impossible. Thus, there is another approach:
///
/// ## Clone the connection
///
/// ```
/// use async_std::task::spawn;
/// use futures::{future::{join, FutureExt}, sink::SinkExt, stream::StreamExt};
/// use tide_disco::{error::ServerError, socket::Connection, Api};
/// # use vbs::version::StaticVersion;
///
/// # fn ex(api: &mut Api<(), ServerError, StaticVersion<0, 1>>) {
/// api.socket("endpoint", |_req, mut conn: Connection<i32, i32, ServerError, StaticVersion<0, 1>>, _state| async move {
/// let recv = {
/// let mut conn = conn.clone();
/// spawn(async move {
/// while let Some(Ok(msg)) = conn.next().await {
/// // Handle message from client.
/// }
/// })
/// };
/// let send = spawn(async move {
/// loop {
/// let msg = // get message to send to client
/// # 0;
/// conn.send(&msg).await;
/// // msg is still live at this point.
/// drop(msg);
/// }
/// });
///
/// join(send, recv).await;
/// Ok(())
/// }.boxed());
/// # }
/// ```
///
/// Depending on the exact situation, this method may end up being more verbose than the
/// previous example. But it allows us to retain the higher-ranked trait bound `conn: for<'a>
/// Sink<&'a ToClient>` instead of fixing the lifetime, which can prevent an unnecessary clone
/// in certain situations.
///
/// # Errors
///
/// If the route `name` does not exist in the API specification, or if the route already has a
Expand Down Expand Up @@ -1531,7 +1613,7 @@ mod test {
.unwrap()
.socket(
"once",
|_req, mut conn: Connection<_, (), _, StaticVer01>, _state| {
|_req, mut conn: Connection<str, (), _, StaticVer01>, _state| {
async move {
conn.send("msg").boxed().await?;
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ mod test {
.unwrap()
.socket(
"socket_test",
|_req, mut conn: Connection<_, (), _, StaticVer01>, _state| {
|_req, mut conn: Connection<str, (), _, StaticVer01>, _state| {
async move {
conn.send("SOCKET").await.unwrap();
Ok(())
Expand Down
94 changes: 44 additions & 50 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use futures::{
task::{Context, Poll},
FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt,
};
use pin_project::pin_project;
use serde::{de::DeserializeOwned, Serialize};
use std::borrow::Cow;
use std::fmt::{self, Display, Formatter};
Expand Down Expand Up @@ -133,7 +134,9 @@ enum MessageType {
///
/// [Connection] implements [Stream], which can be used to receive `FromClient` messages from the
/// client, and [Sink] which can be used to send `ToClient` messages to the client.
#[pin_project]
pub struct Connection<ToClient: ?Sized, FromClient, Error, VER: StaticVersionType> {
#[pin]
conn: WebSocketConnection,
// [Sink] wrapper around `conn`
sink: Pin<Box<dyn Send + Sink<Message, Error = SocketError<Error>>>>,
Expand All @@ -150,7 +153,7 @@ impl<ToClient: ?Sized, FromClient: DeserializeOwned, E, VER: StaticVersionType>
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// Get a `Pin<&mut WebSocketConnection>` for the underlying connection, so we can use the
// `Stream` implementation of that field.
match self.pinned_inner().poll_next(cx) {
match self.project().conn.poll_next(cx) {
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))),
Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(match msg {
Expand Down Expand Up @@ -194,29 +197,25 @@ impl<ToClient: Serialize + ?Sized, FromClient, E, VER: StaticVersionType> Sink<&
}
}

impl<ToClient: ?Sized, FromClient, Error, VER: StaticVersionType> Drop
for Connection<ToClient, FromClient, Error, VER>
impl<ToClient: Serialize, FromClient, E, VER: StaticVersionType> Sink<ToClient>
for Connection<ToClient, FromClient, E, VER>
{
fn drop(&mut self) {
// This is the idiomatic way to implement [drop] for a type that uses pinning. Since [drop]
// is implicitly called with `&mut self` even on types that were pinned, we place any
// implementation inside [inner_drop], which takes `Pin<&mut Self>`, when the commpiler will
// be able to check that we do not do anything that we couldn't have done on a
// `Pin<&mut Self>`.
//
// The [drop] implementation for this type is trivial, and it would be safe to use the
// automatically generated [drop] implementation, but we nonetheless implement [drop]
// explicitly in the idiomatic fashion so that it is impossible to accidentally implement an
// unsafe version of [drop] for this type in the future.

// `new_unchecked` is okay because we know this value is never used again after being
// dropped.
inner_drop(unsafe { Pin::new_unchecked(self) });
fn inner_drop<ToClient: ?Sized, FromClient, Error, VER: StaticVersionType>(
_this: Pin<&mut Connection<ToClient, FromClient, Error, VER>>,
) {
// Any logic goes here.
}
type Error = SocketError<E>;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Sink::<&ToClient>::poll_ready(self, cx)
}

fn start_send(self: Pin<&mut Self>, item: ToClient) -> Result<(), Self::Error> {
self.start_send(&item)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Sink::<&ToClient>::poll_flush(self, cx)
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Sink::<&ToClient>::poll_close(self, cx)
}
}

Expand All @@ -233,39 +232,34 @@ impl<ToClient: ?Sized, FromClient, E, VER: StaticVersionType>
unreachable!()
};
Ok(Self {
sink: Box::pin(sink::unfold(
(conn.clone(), ty),
|(conn, accept), msg| async move {
conn.send(msg).await?;
Ok((conn, accept))
},
)),
sink: Self::sink(conn.clone()),
conn,
accept: ty,
_phantom: Default::default(),
})
}

/// Project a `Pin<&mut Self>` to a pinned reference to the underlying connection.
fn pinned_inner(self: Pin<&mut Self>) -> Pin<&mut WebSocketConnection> {
// # Soundness
//
// This implements _structural pinning_ for [Connection]. This comes with some requirements
// to maintain safety, as described at
// https://doc.rust-lang.org/std/pin/index.html#pinning-is-structural-for-field:
//
// 1. The struct must only be [Unpin] if all the structural fields are [Unpin]. This is the
// default, and we don't explicitly implement [Unpin] for [Connection].
// 2. The destructor of the struct must not move structural fields out of its argument. This
// is enforced by the compiler in our [Drop] implementation, which follows the idiom for
// safe [Drop] implementations for pinned structs.
// 3. You must make sure that you uphold the [Drop] guarantee: once your struct is pinned,
// the memory that contains the content is not overwritten or deallocated without calling
// the content’s destructors. This is also enforced by our [Drop] implementation.
// 4. You must not offer any other operations that could lead to data being moved out of the
// structural fields when your type is pinned. There are no operations on this type that
// move out of `conn`.
unsafe { self.map_unchecked_mut(|s| &mut s.conn) }
/// Wrap a `WebSocketConnection` in a type that implements `Sink<Message>`.
fn sink(
conn: WebSocketConnection,
) -> Pin<Box<dyn Send + Sink<Message, Error = SocketError<E>>>> {
Box::pin(sink::unfold(conn, |conn, msg| async move {
conn.send(msg).await?;
Ok(conn)
}))
}
}

impl<ToClient: ?Sized, FromClient, E, VER: StaticVersionType> Clone
for Connection<ToClient, FromClient, E, VER>
{
fn clone(&self) -> Self {
Self {
sink: Self::sink(self.conn.clone()),
conn: self.conn.clone(),
accept: self.accept,
_phantom: Default::default(),
}
}
}

Expand Down

0 comments on commit 1253cd8

Please sign in to comment.