Skip to content

Commit

Permalink
feat(io): reimplement vectored extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
Berrysoft committed Dec 20, 2024
1 parent 7ff681e commit 634988d
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 94 deletions.
1 change: 1 addition & 0 deletions compio-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ pub mod compat;
mod read;
mod split;
pub mod util;
mod vectored;
mod write;

pub(crate) type IoResult<T> = std::io::Result<T>;
Expand Down
93 changes: 46 additions & 47 deletions compio-io/src/read/ext.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#[cfg(feature = "allocator_api")]
use std::alloc::Allocator;

use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut, t_alloc};
use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut, SetBufInit, t_alloc};

use crate::{AsyncRead, AsyncReadAt, IoResult, util::Take};
use crate::{AsyncRead, AsyncReadAt, IoResult, util::Take, vectored::VectoredWrap};

/// Shared code for read a scalar value from the underlying reader.
macro_rules! read_scalar {
Expand Down Expand Up @@ -36,63 +36,35 @@ macro_rules! read_scalar {

/// Shared code for loop reading until reaching a certain length.
macro_rules! loop_read_exact {
($buf:ident, $len:expr, $tracker:ident,loop $read_expr:expr) => {
let mut $tracker = 0;
($buf:ident, $len:expr, $tracker:ident, $read_expr:expr, $update_expr:expr, $buf_expr:expr) => {
let mut $tracker = 0usize;
let len = $len;

while $tracker < len {
match $read_expr.await.into_inner() {
BufResult(Ok(0), buf) => {
let BufResult(res, buf) = $read_expr;
$buf = buf;
match res {
Ok(0) => {
return BufResult(
Err(::std::io::Error::new(
::std::io::ErrorKind::UnexpectedEof,
"failed to fill whole buffer",
)),
buf,
$buf_expr,
);
}
BufResult(Ok(n), buf) => {
$tracker += n;
$buf = buf;
}
BufResult(Err(ref e), buf) if e.kind() == ::std::io::ErrorKind::Interrupted => {
$buf = buf;
Ok(n) => {
$tracker += n as usize;
$update_expr;
}
BufResult(Err(e), buf) => return BufResult(Err(e), buf),
Err(ref e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
Err(e) => return BufResult(Err(e), $buf_expr),
}
}
return BufResult(Ok(()), $buf)
return BufResult(Ok(()), $buf_expr)
};
}

macro_rules! loop_read_vectored {
($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident,loop $read_expr:expr) => {{
use ::compio_buf::OwnedIterator;

let mut $iter = match $buf.owned_iter() {
Ok(buf) => buf,
Err(buf) => return BufResult(Ok(()), buf),
};
let mut $tracker: $tracker_ty = 0;

loop {
let len = $iter.buf_capacity();
if len > 0 {
match $read_expr.await {
BufResult(Ok(()), ret) => {
$iter = ret;
$tracker += len as $tracker_ty;
}
BufResult(Err(e), $iter) => return BufResult(Err(e), $iter.into_inner()),
};
}

match $iter.next() {
Ok(next) => $iter = next,
Err(buf) => return BufResult(Ok(()), buf),
}
}
}};
($buf:ident, $iter:ident, $read_expr:expr) => {{
use ::compio_buf::OwnedIterator;

Expand Down Expand Up @@ -158,7 +130,14 @@ pub trait AsyncReadExt: AsyncRead {

/// Read the exact number of bytes required to fill the buf.
async fn read_exact<T: IoBufMut>(&mut self, mut buf: T) -> BufResult<(), T> {
loop_read_exact!(buf, buf.buf_capacity(), read, loop self.read(buf.slice(read..)));
loop_read_exact!(
buf,
buf.buf_capacity(),
read,
self.read(buf.slice(read..)).await.into_inner(),
{},
buf
);
}

/// Read all bytes until underlying reader reaches `EOF`.
Expand All @@ -171,7 +150,15 @@ pub trait AsyncReadExt: AsyncRead {

/// Read the exact number of bytes required to fill the vectored buf.
async fn read_vectored_exact<T: IoVectoredBufMut>(&mut self, buf: T) -> BufResult<(), T> {
loop_read_vectored!(buf, _total: usize, iter, loop self.read_exact(iter))
let mut buf = VectoredWrap::new(buf);
loop_read_exact!(
buf,
buf.capacity(),
read,
self.read_vectored(buf).await,
unsafe { buf.set_buf_init(read) },
buf.into_inner()
);
}

/// Creates an adaptor which reads at most `limit` bytes from it.
Expand Down Expand Up @@ -234,7 +221,11 @@ pub trait AsyncReadAtExt: AsyncReadAt {
buf,
buf.buf_capacity(),
read,
loop self.read_at(buf.slice(read..), pos + read as u64)
self.read_at(buf.slice(read..), pos + read as u64)
.await
.into_inner(),
{},
buf
);
}

Expand Down Expand Up @@ -262,7 +253,15 @@ pub trait AsyncReadAtExt: AsyncReadAt {
buf: T,
pos: u64,
) -> BufResult<(), T> {
loop_read_vectored!(buf, total: u64, iter, loop self.read_exact_at(iter, pos + total))
let mut buf = VectoredWrap::new(buf);
loop_read_exact!(
buf,
buf.capacity(),
read,
self.read_vectored_at(buf, pos + read as u64).await,
unsafe { buf.set_buf_init(read) },
buf.into_inner()
);
}
}

Expand Down
153 changes: 153 additions & 0 deletions compio-io/src/vectored.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
use std::pin::Pin;

use compio_buf::{
Indexable, IndexableMut, IndexedIter, IntoInner, IoBuf, IoBufMut, IoVectoredBuf,
IoVectoredBufMut, MaybeOwned, MaybeOwnedMut, SetBufInit,
};

pub struct VectoredWrap<T> {
buffers: Pin<Box<T>>,
wraps: Vec<BufWrap>,
vec_off: usize,
}

impl<T: IoVectoredBuf> VectoredWrap<T> {
pub fn new(buffers: T) -> Self {
let buffers = Box::pin(buffers);
let wraps = buffers.iter_buf().map(|buf| BufWrap::new(&*buf)).collect();
Self {
buffers,
wraps,
vec_off: 0,
}
}

pub fn len(&self) -> usize {
self.wraps.iter().map(|buf| buf.len).sum()
}

pub fn capacity(&self) -> usize {
self.wraps.iter().map(|buf| buf.capacity).sum()
}
}

impl<T: IoVectoredBuf + 'static> IoVectoredBuf for VectoredWrap<T> {
type Buf = BufWrap;
type OwnedIter = IndexedIter<Self>;

fn iter_buf(&self) -> impl Iterator<Item = MaybeOwned<'_, Self::Buf>> {
self.wraps
.iter()
.skip(self.vec_off)
.map(MaybeOwned::Borrowed)
}

fn owned_iter(self) -> Result<Self::OwnedIter, Self>
where
Self: Sized,
{
IndexedIter::new(self)
}
}

impl<T: IoVectoredBufMut + 'static> IoVectoredBufMut for VectoredWrap<T> {
fn iter_buf_mut(&mut self) -> impl Iterator<Item = MaybeOwnedMut<'_, Self::Buf>> {
self.wraps
.iter_mut()
.skip(self.vec_off)
.map(MaybeOwnedMut::Borrowed)
}
}

impl<T: SetBufInit> SetBufInit for VectoredWrap<T> {
unsafe fn set_buf_init(&mut self, mut len: usize) {
self.buffers.as_mut().get_unchecked_mut().set_buf_init(len);
self.vec_off = 0;
for buf in self.wraps.iter_mut().skip(self.vec_off) {
let capacity = (*buf).buf_capacity();
let buf_new_len = len.min(capacity);
buf.set_buf_init(buf_new_len);
*buf = buf.offset(buf_new_len);
if len >= capacity {
len -= capacity;
} else {
break;
}
self.vec_off += 1;
}
}
}

impl<T> Indexable for VectoredWrap<T> {
type Output = BufWrap;

fn index(&self, n: usize) -> Option<&Self::Output> {
self.wraps.get(n + self.vec_off)
}
}

impl<T> IndexableMut for VectoredWrap<T> {
fn index_mut(&mut self, n: usize) -> Option<&mut Self::Output> {
self.wraps.get_mut(n + self.vec_off)
}
}

impl<T> IntoInner for VectoredWrap<T> {
type Inner = T;

fn into_inner(self) -> Self::Inner {
// Safety: no pointers still maintaining
*unsafe { Pin::into_inner_unchecked(self.buffers) }
}
}

pub struct BufWrap {
ptr: *mut u8,
len: usize,
capacity: usize,
}

impl BufWrap {
fn new<T: IoBuf>(buf: &T) -> Self {
Self {
ptr: buf.as_buf_ptr().cast_mut(),
len: buf.buf_len(),
capacity: buf.buf_capacity(),
}
}

fn offset(&self, off: usize) -> Self {
Self {
ptr: unsafe { self.ptr.add(off) },
len: self.len.saturating_sub(off),
capacity: self.capacity.saturating_sub(off),
}
}
}

unsafe impl IoBuf for BufWrap {
fn as_buf_ptr(&self) -> *const u8 {
self.ptr.cast_const()
}

fn buf_len(&self) -> usize {
self.len
}

fn buf_capacity(&self) -> usize {
self.capacity
}
}

unsafe impl IoBufMut for BufWrap {
fn as_buf_mut_ptr(&mut self) -> *mut u8 {
self.ptr
}
}

impl SetBufInit for BufWrap {
unsafe fn set_buf_init(&mut self, len: usize) {
debug_assert!(len <= self.capacity, "{} > {}", len, self.capacity);
self.len = self.len.max(len);
}
}
Loading

0 comments on commit 634988d

Please sign in to comment.