Skip to content

Commit

Permalink
feat(tls): add tls stream compat wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Berrysoft committed Dec 20, 2024
1 parent 7ff681e commit 16a3fd4
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 7 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ nix = "0.29.0"
once_cell = "1.18.0"
os_pipe = "1.1.4"
paste = "1.0.14"
pin-project-lite = "0.2.14"
rand = "0.8.5"
rustls = { version = "0.23.1", default-features = false }
rustls-native-certs = "0.8.0"
Expand Down
2 changes: 1 addition & 1 deletion compio-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ repository = { workspace = true }
compio-buf = { workspace = true, features = ["arrayvec"] }
futures-util = { workspace = true }
paste = { workspace = true }
pin-project-lite = { version = "0.2.14", optional = true }
pin-project-lite = { workspace = true, optional = true }

[dev-dependencies]
compio-runtime = { workspace = true }
Expand Down
5 changes: 5 additions & 0 deletions compio-tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ rustls = { workspace = true, default-features = false, optional = true, features
"tls12",
] }

futures-util = { workspace = true, optional = true }
pin-project-lite = { workspace = true, optional = true }

[dev-dependencies]
compio-net = { workspace = true }
compio-runtime = { workspace = true }
Expand All @@ -42,5 +45,7 @@ ring = ["rustls", "rustls/ring"]
aws-lc-rs = ["rustls", "rustls/aws-lc-rs"]
aws-lc-rs-fips = ["aws-lc-rs", "rustls/fips"]

io-compat = ["dep:futures-util", "dep:pin-project-lite"]

read_buf = ["compio-buf/read_buf", "compio-io/read_buf", "rustls?/read_buf"]
nightly = ["read_buf"]
187 changes: 187 additions & 0 deletions compio-tls/src/stream/compat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
//! The code here should sync with `compio::io::compat`.
use std::{
future::Future,
io,
mem::MaybeUninit,
pin::Pin,
task::{Context, Poll},
};

use compio_io::{AsyncRead, AsyncWrite};
use pin_project_lite::pin_project;

use crate::TlsStream;

type PinBoxFuture<T> = Pin<Box<dyn Future<Output = T>>>;

pin_project! {
/// A [`TlsStream`] wrapper for [`futures_util::io`] traits.
pub struct TlsStreamCompat<S> {
#[pin]
inner: TlsStream<S>,
read_future: Option<PinBoxFuture<io::Result<usize>>>,
write_future: Option<PinBoxFuture<io::Result<usize>>>,
shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
}
}

impl<S> TlsStreamCompat<S> {
/// Create [`TlsStreamCompat`] from [`TlsStream`].
pub fn new(stream: TlsStream<S>) -> Self {
Self {
inner: stream,
read_future: None,
write_future: None,
shutdown_future: None,
}
}

/// Get the reference of the inner stream.
pub fn get_ref(&self) -> &TlsStream<S> {
&self.inner
}
}

impl<S> From<TlsStream<S>> for TlsStreamCompat<S> {
fn from(value: TlsStream<S>) -> Self {
Self::new(value)
}
}

macro_rules! poll_future {
($f:expr, $cx:expr, $e:expr) => {{
let mut future = match $f.take() {
Some(f) => f,
None => Box::pin($e),
};
let f = future.as_mut();
match f.poll($cx) {
Poll::Pending => {
$f.replace(future);
return Poll::Pending;
}
Poll::Ready(res) => res,
}
}};
}

macro_rules! poll_future_would_block {
($f:expr, $cx:expr, $e:expr, $io:expr) => {{
if let Some(mut f) = $f.take() {
if f.as_mut().poll($cx).is_pending() {
$f.replace(f);
return Poll::Pending;
}
}

match $io {
Ok(len) => Poll::Ready(Ok(len)),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
$f.replace(Box::pin($e));
$cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}};
}

impl<S: AsyncRead + 'static> futures_util::AsyncRead for TlsStreamCompat<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.project();
let inner: &'static mut TlsStream<S> =
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };

poll_future_would_block!(
this.read_future,
cx,
inner.0.get_mut().fill_read_buf(),
io::Read::read(&mut inner.0, buf)
)
}
}

impl<S: AsyncRead + 'static> TlsStreamCompat<S> {
/// Attempt to read from the `AsyncRead` into `buf`.
///
/// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
pub fn poll_read_uninit(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [MaybeUninit<u8>],
) -> Poll<io::Result<usize>> {
let this = self.project();

let inner: &'static mut TlsStream<S> =
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
poll_future_would_block!(
this.read_future,
cx,
inner.0.get_mut().fill_read_buf(),
super::read_buf(&mut inner.0, buf)
)
}
}

impl<S: AsyncWrite + 'static> futures_util::AsyncWrite for TlsStreamCompat<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.project();

if this.shutdown_future.is_some() {
debug_assert!(this.write_future.is_none());
return Poll::Pending;
}

let inner: &'static mut TlsStream<S> =
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
poll_future_would_block!(
this.write_future,
cx,
inner.0.get_mut().flush_write_buf(),
io::Write::write(&mut inner.0, buf)
)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.project();

if this.shutdown_future.is_some() {
debug_assert!(this.write_future.is_none());
return Poll::Pending;
}

let inner: &'static mut TlsStream<S> =
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
let res = poll_future!(this.write_future, cx, inner.0.get_mut().flush_write_buf());
Poll::Ready(res.map(|_| ()))
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.project();

// Avoid shutdown on flush because the inner buffer might be passed to the
// driver.
if this.write_future.is_some() {
debug_assert!(this.shutdown_future.is_none());
return Poll::Pending;
}

let inner: &'static mut TlsStream<S> =
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
let res = poll_future!(
this.shutdown_future,
cx,
inner.0.get_mut().get_mut().shutdown()
);
Poll::Ready(res)
}
}
14 changes: 9 additions & 5 deletions compio-tls/src/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
#[cfg(feature = "rustls")]
mod rtls;

#[cfg(feature = "io-compat")]
mod compat;

#[cfg(feature = "io-compat")]
pub use compat::*;

#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum TlsStreamInner<S> {
Expand Down Expand Up @@ -113,17 +119,15 @@ impl<S> From<native_tls::TlsStream<SyncStream<S>>> for TlsStream<S> {

#[cfg(not(feature = "read_buf"))]
#[inline]
fn read_buf<B: IoBufMut>(reader: &mut impl io::Read, buf: &mut B) -> io::Result<usize> {
let slice: &mut [MaybeUninit<u8>] = buf.as_mut_slice();
fn read_buf(reader: &mut impl io::Read, slice: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
slice.fill(MaybeUninit::new(0));
let slice = unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) };
reader.read(slice)
}

#[cfg(feature = "read_buf")]
#[inline]
fn read_buf<B: IoBufMut>(reader: &mut impl io::Read, buf: &mut B) -> io::Result<usize> {
let slice: &mut [MaybeUninit<u8>] = buf.as_mut_slice();
fn read_buf(reader: &mut impl io::Read, slice: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
let mut borrowed_buf = io::BorrowedBuf::from(slice);
let mut cursor = borrowed_buf.unfilled();
reader.read_buf(cursor.reborrow())?;
Expand All @@ -133,7 +137,7 @@ fn read_buf<B: IoBufMut>(reader: &mut impl io::Read, buf: &mut B) -> io::Result<
impl<S: AsyncRead> AsyncRead for TlsStream<S> {
async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
loop {
let res = read_buf(&mut self.0, &mut buf);
let res = read_buf(&mut self.0, buf.as_mut_slice());
match res {
Ok(res) => {
unsafe { buf.set_buf_init(res) };
Expand Down
7 changes: 6 additions & 1 deletion compio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@ polling = [
"compio-dispatcher?/polling",
]
io = ["dep:compio-io"]
io-compat = ["io", "compio-io/compat", "compio-quic?/io-compat"]
io-compat = [
"io",
"compio-io/compat",
"compio-quic?/io-compat",
"compio-tls?/io-compat",
]
runtime = ["dep:compio-runtime", "dep:compio-fs", "dep:compio-net", "io"]
macros = ["dep:compio-macros", "runtime"]
event = ["compio-runtime/event", "runtime"]
Expand Down

0 comments on commit 16a3fd4

Please sign in to comment.