diff --git a/Cargo.toml b/Cargo.toml index b9c39e28..b39b7714 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ nix = "0.29.0" once_cell = "1.18.0" os_pipe = "1.1.4" paste = "1.0.14" +pin-project-lite = "0.2.14" rand = "0.8.5" rustls = { version = "0.23.1", default-features = false } rustls-native-certs = "0.8.0" diff --git a/compio-io/Cargo.toml b/compio-io/Cargo.toml index 262db9c9..5278cb39 100644 --- a/compio-io/Cargo.toml +++ b/compio-io/Cargo.toml @@ -14,7 +14,7 @@ repository = { workspace = true } compio-buf = { workspace = true, features = ["arrayvec"] } futures-util = { workspace = true } paste = { workspace = true } -pin-project-lite = { version = "0.2.14", optional = true } +pin-project-lite = { workspace = true, optional = true } [dev-dependencies] compio-runtime = { workspace = true } diff --git a/compio-tls/Cargo.toml b/compio-tls/Cargo.toml index d26e7248..31c778c1 100644 --- a/compio-tls/Cargo.toml +++ b/compio-tls/Cargo.toml @@ -25,6 +25,9 @@ rustls = { workspace = true, default-features = false, optional = true, features "tls12", ] } +futures-util = { workspace = true, optional = true } +pin-project-lite = { workspace = true, optional = true } + [dev-dependencies] compio-net = { workspace = true } compio-runtime = { workspace = true } @@ -42,5 +45,7 @@ ring = ["rustls", "rustls/ring"] aws-lc-rs = ["rustls", "rustls/aws-lc-rs"] aws-lc-rs-fips = ["aws-lc-rs", "rustls/fips"] +io-compat = ["dep:futures-util", "dep:pin-project-lite"] + read_buf = ["compio-buf/read_buf", "compio-io/read_buf", "rustls?/read_buf"] nightly = ["read_buf"] diff --git a/compio-tls/src/stream/compat.rs b/compio-tls/src/stream/compat.rs new file mode 100644 index 00000000..0388ed92 --- /dev/null +++ b/compio-tls/src/stream/compat.rs @@ -0,0 +1,187 @@ +//! The code here should sync with `compio::io::compat`. + +use std::{ + future::Future, + io, + mem::MaybeUninit, + pin::Pin, + task::{Context, Poll}, +}; + +use compio_io::{AsyncRead, AsyncWrite}; +use pin_project_lite::pin_project; + +use crate::TlsStream; + +type PinBoxFuture = Pin>>; + +pin_project! { + /// A [`TlsStream`] wrapper for [`futures_util::io`] traits. + pub struct TlsStreamCompat { + #[pin] + inner: TlsStream, + read_future: Option>>, + write_future: Option>>, + shutdown_future: Option>>, + } +} + +impl TlsStreamCompat { + /// Create [`TlsStreamCompat`] from [`TlsStream`]. + pub fn new(stream: TlsStream) -> Self { + Self { + inner: stream, + read_future: None, + write_future: None, + shutdown_future: None, + } + } + + /// Get the reference of the inner stream. + pub fn get_ref(&self) -> &TlsStream { + &self.inner + } +} + +impl From> for TlsStreamCompat { + fn from(value: TlsStream) -> Self { + Self::new(value) + } +} + +macro_rules! poll_future { + ($f:expr, $cx:expr, $e:expr) => {{ + let mut future = match $f.take() { + Some(f) => f, + None => Box::pin($e), + }; + let f = future.as_mut(); + match f.poll($cx) { + Poll::Pending => { + $f.replace(future); + return Poll::Pending; + } + Poll::Ready(res) => res, + } + }}; +} + +macro_rules! poll_future_would_block { + ($f:expr, $cx:expr, $e:expr, $io:expr) => {{ + if let Some(mut f) = $f.take() { + if f.as_mut().poll($cx).is_pending() { + $f.replace(f); + return Poll::Pending; + } + } + + match $io { + Ok(len) => Poll::Ready(Ok(len)), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + $f.replace(Box::pin($e)); + $cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), + } + }}; +} + +impl futures_util::AsyncRead for TlsStreamCompat { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let this = self.project(); + let inner: &'static mut TlsStream = + unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; + + poll_future_would_block!( + this.read_future, + cx, + inner.0.get_mut().fill_read_buf(), + io::Read::read(&mut inner.0, buf) + ) + } +} + +impl TlsStreamCompat { + /// Attempt to read from the `AsyncRead` into `buf`. + /// + /// On success, returns `Poll::Ready(Ok(num_bytes_read))`. + pub fn poll_read_uninit( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [MaybeUninit], + ) -> Poll> { + let this = self.project(); + + let inner: &'static mut TlsStream = + unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; + poll_future_would_block!( + this.read_future, + cx, + inner.0.get_mut().fill_read_buf(), + super::read_buf(&mut inner.0, buf) + ) + } +} + +impl futures_util::AsyncWrite for TlsStreamCompat { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + + if this.shutdown_future.is_some() { + debug_assert!(this.write_future.is_none()); + return Poll::Pending; + } + + let inner: &'static mut TlsStream = + unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; + poll_future_would_block!( + this.write_future, + cx, + inner.0.get_mut().flush_write_buf(), + io::Write::write(&mut inner.0, buf) + ) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + if this.shutdown_future.is_some() { + debug_assert!(this.write_future.is_none()); + return Poll::Pending; + } + + let inner: &'static mut TlsStream = + unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; + let res = poll_future!(this.write_future, cx, inner.0.get_mut().flush_write_buf()); + Poll::Ready(res.map(|_| ())) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + // Avoid shutdown on flush because the inner buffer might be passed to the + // driver. + if this.write_future.is_some() { + debug_assert!(this.shutdown_future.is_none()); + return Poll::Pending; + } + + let inner: &'static mut TlsStream = + unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; + let res = poll_future!( + this.shutdown_future, + cx, + inner.0.get_mut().get_mut().shutdown() + ); + Poll::Ready(res) + } +} diff --git a/compio-tls/src/stream/mod.rs b/compio-tls/src/stream/mod.rs index 7d5e7ea2..41d91da5 100644 --- a/compio-tls/src/stream/mod.rs +++ b/compio-tls/src/stream/mod.rs @@ -6,6 +6,12 @@ use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream}; #[cfg(feature = "rustls")] mod rtls; +#[cfg(feature = "io-compat")] +mod compat; + +#[cfg(feature = "io-compat")] +pub use compat::*; + #[derive(Debug)] #[allow(clippy::large_enum_variant)] enum TlsStreamInner { @@ -113,8 +119,7 @@ impl From>> for TlsStream { #[cfg(not(feature = "read_buf"))] #[inline] -fn read_buf(reader: &mut impl io::Read, buf: &mut B) -> io::Result { - let slice: &mut [MaybeUninit] = buf.as_mut_slice(); +fn read_buf(reader: &mut impl io::Read, slice: &mut [MaybeUninit]) -> io::Result { slice.fill(MaybeUninit::new(0)); let slice = unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) }; reader.read(slice) @@ -122,8 +127,7 @@ fn read_buf(reader: &mut impl io::Read, buf: &mut B) -> io::Result< #[cfg(feature = "read_buf")] #[inline] -fn read_buf(reader: &mut impl io::Read, buf: &mut B) -> io::Result { - let slice: &mut [MaybeUninit] = buf.as_mut_slice(); +fn read_buf(reader: &mut impl io::Read, slice: &mut [MaybeUninit]) -> io::Result { let mut borrowed_buf = io::BorrowedBuf::from(slice); let mut cursor = borrowed_buf.unfilled(); reader.read_buf(cursor.reborrow())?; @@ -133,7 +137,7 @@ fn read_buf(reader: &mut impl io::Read, buf: &mut B) -> io::Result< impl AsyncRead for TlsStream { async fn read(&mut self, mut buf: B) -> BufResult { loop { - let res = read_buf(&mut self.0, &mut buf); + let res = read_buf(&mut self.0, buf.as_mut_slice()); match res { Ok(res) => { unsafe { buf.set_buf_init(res) }; diff --git a/compio/Cargo.toml b/compio/Cargo.toml index 8ff6326d..12f4ad2a 100644 --- a/compio/Cargo.toml +++ b/compio/Cargo.toml @@ -98,7 +98,12 @@ polling = [ "compio-dispatcher?/polling", ] io = ["dep:compio-io"] -io-compat = ["io", "compio-io/compat", "compio-quic?/io-compat"] +io-compat = [ + "io", + "compio-io/compat", + "compio-quic?/io-compat", + "compio-tls?/io-compat", +] runtime = ["dep:compio-runtime", "dep:compio-fs", "dep:compio-net", "io"] macros = ["dep:compio-macros", "runtime"] event = ["compio-runtime/event", "runtime"]