From 6dbbdda6d5832a12e046055fe9f0d9cd2f9653d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Sun, 1 Dec 2024 00:46:46 +0900 Subject: [PATCH] fix(driver): return buffer back if op cancelled --- compio-driver/src/buffer_pool/fallback.rs | 94 +++++++++++++++++++---- compio-driver/src/op.rs | 8 +- 2 files changed, 83 insertions(+), 19 deletions(-) diff --git a/compio-driver/src/buffer_pool/fallback.rs b/compio-driver/src/buffer_pool/fallback.rs index 260bf871..b5da4ecc 100644 --- a/compio-driver/src/buffer_pool/fallback.rs +++ b/compio-driver/src/buffer_pool/fallback.rs @@ -6,16 +6,28 @@ use std::{ io, mem::ManuallyDrop, ops::{Deref, DerefMut}, + rc::Rc, }; -use compio_buf::{IntoInner, IoBuf, SetBufInit, Slice}; +use compio_buf::{IntoInner, IoBuf, IoBufMut, SetBufInit, Slice}; + +struct BufferPoolInner { + buffers: RefCell>>, +} + +impl BufferPoolInner { + pub(crate) fn add_buffer(&self, mut buffer: Vec) { + buffer.clear(); + self.buffers.borrow_mut().push_back(buffer) + } +} /// Buffer pool /// /// A buffer pool to allow user no need to specify a specific buffer to do the /// IO operation pub struct BufferPool { - buffers: RefCell>>, + inner: Rc, } impl Debug for BufferPool { @@ -31,12 +43,14 @@ impl BufferPool { .collect(); Self { - buffers: RefCell::new(buffers), + inner: Rc::new(BufferPoolInner { + buffers: RefCell::new(buffers), + }), } } - pub(crate) fn get_buffer(&self, len: usize) -> io::Result>> { - let buffer = self.buffers.borrow_mut().pop_front().ok_or_else(|| { + pub(crate) fn get_buffer(&self, len: usize) -> io::Result { + let buffer = self.inner.buffers.borrow_mut().pop_front().ok_or_else(|| { io::Error::new(io::ErrorKind::Other, "buffer ring has no available buffer") })?; let len = if len == 0 { @@ -44,24 +58,74 @@ impl BufferPool { } else { buffer.capacity().min(len) }; - Ok(buffer.slice(..len)) + Ok(OwnedBuffer::new(buffer.slice(..len), self.inner.clone())) } - pub(crate) fn add_buffer(&self, mut buffer: Vec) { - buffer.clear(); - self.buffers.borrow_mut().push_back(buffer) + pub(crate) fn add_buffer(&self, buffer: Vec) { + self.inner.add_buffer(buffer); } /// Safety: `len` should be valid - pub(crate) unsafe fn create_proxy( - &self, - mut slice: Slice>, - len: usize, - ) -> BorrowedBuffer { + pub(crate) unsafe fn create_proxy(&self, mut slice: OwnedBuffer, len: usize) -> BorrowedBuffer { unsafe { slice.set_buf_init(len); } - BorrowedBuffer::new(slice, self) + BorrowedBuffer::new(slice.into_inner(), self) + } +} + +pub(crate) struct OwnedBuffer { + buffer: ManuallyDrop>>, + pool: Rc, +} + +impl OwnedBuffer { + fn new(buffer: Slice>, pool: Rc) -> Self { + Self { + buffer: ManuallyDrop::new(buffer), + pool, + } + } +} + +unsafe impl IoBuf for OwnedBuffer { + fn as_buf_ptr(&self) -> *const u8 { + self.buffer.as_buf_ptr() + } + + fn buf_len(&self) -> usize { + self.buffer.buf_len() + } + + fn buf_capacity(&self) -> usize { + self.buffer.buf_capacity() + } +} + +unsafe impl IoBufMut for OwnedBuffer { + fn as_buf_mut_ptr(&mut self) -> *mut u8 { + self.buffer.as_buf_mut_ptr() + } +} + +impl SetBufInit for OwnedBuffer { + unsafe fn set_buf_init(&mut self, len: usize) { + self.buffer.set_buf_init(len); + } +} + +impl Drop for OwnedBuffer { + fn drop(&mut self) { + self.pool + .add_buffer(unsafe { ManuallyDrop::take(&mut self.buffer) }.into_inner()); + } +} + +impl IntoInner for OwnedBuffer { + type Inner = Slice>; + + fn into_inner(mut self) -> Self::Inner { + unsafe { ManuallyDrop::take(&mut self.buffer) } } } diff --git a/compio-driver/src/op.rs b/compio-driver/src/op.rs index 45949dbb..83de834f 100644 --- a/compio-driver/src/op.rs +++ b/compio-driver/src/op.rs @@ -272,14 +272,14 @@ impl Connect { pub(crate) mod managed { use std::io; - use compio_buf::{IntoInner, Slice}; + use compio_buf::IntoInner; use super::{ReadAt, Recv}; - use crate::{BorrowedBuffer, BufferPool, SharedFd, TakeBuffer}; + use crate::{BorrowedBuffer, BufferPool, OwnedBuffer, SharedFd, TakeBuffer}; /// Read a file at specified position into managed buffer. pub struct ReadManagedAt { - pub(crate) op: ReadAt>, S>, + pub(crate) op: ReadAt, } impl ReadManagedAt { @@ -322,7 +322,7 @@ pub(crate) mod managed { /// Receive data from remote into managed buffer. pub struct RecvManaged { - pub(crate) op: Recv>, S>, + pub(crate) op: Recv, } impl RecvManaged {