Skip to content

Commit

Permalink
feat(rust): added global buffer allocator to reduce memory fragmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
davide-baldo committed Nov 14, 2024
1 parent 6d3e50b commit d1698c0
Show file tree
Hide file tree
Showing 15 changed files with 167 additions and 26 deletions.
70 changes: 70 additions & 0 deletions implementations/rust/ockam/ockam_core/src/buffer_pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use crate::compat::sync::Mutex;
use crate::compat::vec::Vec;
use core::fmt::{Debug, Formatter};
use once_cell::sync::Lazy;

const MIN_BUFFER_SIZE: usize = 96 * 1024;
const MAX_BUFFER_SIZE: usize = 192 * 1024;
const MAX_BUFFERS: usize = 32;

/// The global instance of [`BufferPool`].
pub static GLOBAL_BUFFER_POOL: Lazy<BufferPool> = Lazy::new(BufferPool::new);

/// A buffer pool for reusing buffers at least big as [`MIN_BUFFER_SIZE`].
pub struct BufferPool {
buffers: Mutex<Inner>,
}

impl Debug for BufferPool {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
let len = self.buffers.lock().unwrap().buffers.len();
f.debug_struct("BufferPool").field("buffers", &len).finish()
}
}

struct Inner {
buffers: Vec<Vec<u8>>,
}

impl BufferPool {
fn new() -> Self {
Self {
buffers: Mutex::new(Inner {
buffers: Vec::new(),
}),
}
}

/// When the size is big enough, it'll return a buffer from the pool,
/// otherwise it'll return a new buffer.
pub fn try_size(&self, size: usize) -> Vec<u8> {
if (MIN_BUFFER_SIZE..=MAX_BUFFER_SIZE).contains(&size) {
self.take()
} else {
Vec::with_capacity(size)
}
}

/// Take a buffer from the pool.
pub fn take(&self) -> Vec<u8> {
let mut buffers = self.buffers.lock().unwrap();
if let Some(mut buffer) = buffers.buffers.pop() {
buffer.clear();
buffer
} else {
Vec::with_capacity(MIN_BUFFER_SIZE)
}
}

/// Release a buffer back to the pool, the buffer will only be reused if
/// it's capacity is within [`MIN_BUFFER_SIZE`] and [`MAX_BUFFER_SIZE`].
pub fn release(&self, buffer: Vec<u8>) {
if buffer.capacity() >= MIN_BUFFER_SIZE && buffer.capacity() <= MAX_BUFFER_SIZE {
let mut buffers = self.buffers.lock().unwrap();
if buffers.buffers.len() < MAX_BUFFERS {
buffers.buffers.push(buffer);
buffers.buffers.sort_by_key(|b| -(b.capacity() as i64));
}
}
}
}
4 changes: 2 additions & 2 deletions implementations/rust/ockam/ockam_core/src/cbor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod cow_str;
pub(crate) mod schema;

use crate::compat::vec::Vec;
use crate::Result;
use crate::{Result, GLOBAL_BUFFER_POOL};
use minicbor::{CborLen, Encode};

/// Encode a type implementing [`Encode`] and return the encoded byte vector.
Expand All @@ -18,7 +18,7 @@ where
T: Encode<()> + CborLen<()>,
{
let expected_len = minicbor::len(&x);
let mut output = Vec::with_capacity(expected_len);
let mut output = GLOBAL_BUFFER_POOL.try_size(expected_len);
minicbor::encode(x, &mut output)?;
Ok(output)
}
2 changes: 2 additions & 0 deletions implementations/rust/ockam/ockam_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub mod hex_encoding;
pub mod env;

pub mod bare;
mod buffer_pool;
mod cbor;
mod error;
mod identity;
Expand All @@ -88,6 +89,7 @@ mod uint;
mod worker;

pub use access_control::*;
pub use buffer_pool::*;
pub use cbor::*;
pub use error::*;
pub use identity::*;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use core::sync::atomic::Ordering;
use ockam_core::compat::sync::Arc;
use ockam_core::compat::vec::Vec;
use ockam_core::{route, Any, Result, Route, Routed, SecureChannelLocalInfo};
use ockam_core::{route, Any, Result, Route, Routed, SecureChannelLocalInfo, GLOBAL_BUFFER_POOL};
use ockam_core::{Decodable, LocalMessage};
use ockam_node::Context;

Expand Down Expand Up @@ -219,6 +219,8 @@ impl DecryptorHandler {
SecureChannelMessage::Close => self.handle_close(ctx).await?,
};

GLOBAL_BUFFER_POOL.release(decrypted_payload);

Ok(())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use ockam_core::compat::vec::Vec;
use ockam_core::errcode::{Kind, Origin};
use ockam_core::{
async_trait, route, CowBytes, Decodable, Error, LocalMessage, NeutralMessage, Route,
GLOBAL_BUFFER_POOL,
};
use ockam_core::{Any, Result, Routed, Worker};
use ockam_node::Context;
Expand Down Expand Up @@ -99,9 +100,9 @@ impl EncryptorWorker {
msg: SecureChannelPaddedMessage<'_>,
) -> Result<Vec<u8>> {
let payload = ockam_core::cbor_encode_preallocate(&msg)?;
let mut destination = Vec::with_capacity(SIZE_OF_ENCRYPT_OVERHEAD + payload.len());
let mut destination = GLOBAL_BUFFER_POOL.try_size(SIZE_OF_ENCRYPT_OVERHEAD + payload.len());

match self.encryptor.encrypt(&mut destination, &payload).await {
let result = match self.encryptor.encrypt(&mut destination, &payload).await {
Ok(()) => Ok(destination),
// If encryption failed, that means we have some internal error,
// and we may be in an invalid state, it's better to stop the Worker
Expand All @@ -111,7 +112,16 @@ impl EncryptorWorker {
ctx.stop_worker(address).await?;
Err(err)
}
};

GLOBAL_BUFFER_POOL.release(payload);
if let SecureChannelMessage::Payload(plaintext) = msg.message {
if !plaintext.payload.is_borrowed() {
GLOBAL_BUFFER_POOL.release(plaintext.payload.into_owned());
}
}

result
}

#[instrument(skip_all)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ockam_core::bare::{read_slice, write_slice};
use ockam_core::errcode::{Kind, Origin};
use ockam_core::{Encodable, Encoded, Message, NeutralMessage};
use ockam_core::{Encodable, Encoded, Message, NeutralMessage, GLOBAL_BUFFER_POOL};
use serde::{Deserialize, Serialize};

/// A command message type for a Portal
Expand Down Expand Up @@ -84,7 +84,7 @@ impl PortalMessage<'_> {
let capacity = 1 + payload.len() + if counter.is_some() { 3 } else { 1 } + {
ockam_core::bare::size_of_variable_length(payload.len() as u64)
};
let mut vec = Vec::with_capacity(capacity);
let mut vec = GLOBAL_BUFFER_POOL.try_size(capacity);
vec.push(3);
write_slice(&mut vec, payload);
// TODO: re-enable once orchestrator accepts packet counter
Expand All @@ -108,7 +108,7 @@ pub enum PortalInternalMessage {
}

/// Maximum allowed size for a payload
pub const MAX_PAYLOAD_SIZE: usize = 48 * 1024;
pub const MAX_PAYLOAD_SIZE: usize = 96 * 1024;

#[cfg(test)]
mod test {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tracing::{error, instrument, warn};
/// [`TcpPortalWorker::start_receiver`](crate::TcpPortalWorker::start_receiver)
pub(crate) struct TcpPortalRecvProcessor<R> {
registry: TcpRegistry,
buf: Vec<u8>,
incoming_buffer: Vec<u8>,
read_half: R,
addresses: Addresses,
onward_route: Route,
Expand All @@ -37,7 +37,7 @@ impl<R: AsyncRead + Unpin + Send + Sync + 'static> TcpPortalRecvProcessor<R> {
) -> Self {
Self {
registry,
buf: Vec::with_capacity(MAX_PAYLOAD_SIZE),
incoming_buffer: Vec::with_capacity(MAX_PAYLOAD_SIZE),
read_half,
addresses,
onward_route,
Expand Down Expand Up @@ -67,9 +67,9 @@ impl<R: AsyncRead + Unpin + Send + Sync + 'static> Processor for TcpPortalRecvPr

#[instrument(skip_all, name = "TcpPortalRecvProcessor::process")]
async fn process(&mut self, ctx: &mut Context) -> Result<bool> {
self.buf.clear();
self.incoming_buffer.clear();

let _len = match self.read_half.read_buf(&mut self.buf).await {
let _len = match self.read_half.read_buf(&mut self.incoming_buffer).await {
Ok(len) => len,
Err(err) => {
error!("Tcp Portal connection read failed with error: {}", err);
Expand All @@ -82,7 +82,7 @@ impl<R: AsyncRead + Unpin + Send + Sync + 'static> Processor for TcpPortalRecvPr
OpenTelemetryContext::inject(&cx)
});

if self.buf.is_empty() {
if self.incoming_buffer.is_empty() {
// Notify Sender that connection was closed
ctx.set_tracing_context(tracing_context.clone());
if let Err(err) = ctx
Expand Down Expand Up @@ -113,7 +113,7 @@ impl<R: AsyncRead + Unpin + Send + Sync + 'static> Processor for TcpPortalRecvPr
}

// Loop just in case buf was extended (should not happen though)
for chunk in self.buf.chunks(MAX_PAYLOAD_SIZE) {
for chunk in self.incoming_buffer.chunks(MAX_PAYLOAD_SIZE) {
let msg = LocalMessage::new()
.with_tracing_context(tracing_context.clone())
.with_onward_route(self.onward_route.clone())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use ockam_core::compat::{boxed::Box, sync::Arc};
use ockam_core::{
async_trait, AllowOnwardAddress, AllowSourceAddress, Decodable, DenyAll, IncomingAccessControl,
LocalInfoIdentifier, Mailbox, Mailboxes, OutgoingAccessControl, SecureChannelLocalInfo,
GLOBAL_BUFFER_POOL,
};
use ockam_core::{Any, Result, Route, Routed, Worker};
use ockam_node::{Context, ProcessorBuilder, WorkerBuilder};
Expand Down Expand Up @@ -509,11 +510,11 @@ impl Worker for TcpPortalWorker {
// Send to Tcp stream
match msg {
PortalMessage::Payload(payload, packet_counter) => {
self.handle_payload(ctx, payload, packet_counter).await
self.handle_payload(ctx, payload, packet_counter).await?;
}
PortalMessage::Disconnect => {
self.start_disconnection(ctx, DisconnectionReason::Remote)
.await
.await?;
}
PortalMessage::Ping | PortalMessage::Pong => {
return Err(TransportError::Protocol)?;
Expand All @@ -524,8 +525,11 @@ impl Worker for TcpPortalWorker {
if msg != PortalInternalMessage::Disconnect {
return Err(TransportError::Protocol)?;
};
self.handle_disconnect(ctx).await
self.handle_disconnect(ctx).await?;
}

GLOBAL_BUFFER_POOL.release(payload);
Ok(())
}
State::SendPing { .. } | State::SendPong { .. } => {
return Err(TransportError::PortalInvalidState)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,9 @@ impl Worker for RemoteWorker {
let their_identifier = SecureChannelLocalInfo::find_info(msg.local_message())
.map(|l| l.their_identifier())
.ok();
let return_route = msg.return_route();
let payload = msg.into_payload();
let msg = msg.into_local_message();
let return_route = msg.return_route;
let payload = msg.payload;

// TODO: Add borrowing
let msg: OckamPortalPacket = minicbor::decode(&payload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use minicbor::{CborLen, Decode, Encode};
use ockam_core::compat::string::String;
#[cfg(feature = "std")]
use ockam_core::OpenTelemetryContext;
use ockam_core::{CowBytes, LocalMessage, Route};
use ockam_core::{CowBytes, LocalMessage, Route, GLOBAL_BUFFER_POOL};

/// TCP transport message type.
#[derive(Debug, Clone, Eq, PartialEq, Encode, Decode, CborLen)]
Expand Down Expand Up @@ -57,10 +57,19 @@ impl From<TcpTransportMessage<'_>> for LocalMessage {
#[cfg(feature = "std")]
let local_message = local_message.with_tracing_context(value.tracing_context());

let payload = if !value.payload.is_borrowed() {
value.payload.into_owned()
} else {
let mut payload = GLOBAL_BUFFER_POOL.try_size(value.payload.len());
payload.resize(value.payload.len(), 0);
payload.copy_from_slice(&value.payload);
payload
};

local_message
.with_onward_route(value.onward_route)
.with_return_route(value.return_route)
.with_payload(value.payload.into_owned())
.with_payload(payload)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use tracing::{info, instrument, trace};
/// the node message system.
pub(crate) struct TcpRecvProcessor {
registry: TcpRegistry,
incoming_buffer: Vec<u8>,
read_half: OwnedReadHalf,
socket_address: SocketAddr,
addresses: Addresses,
Expand All @@ -47,6 +48,7 @@ impl TcpRecvProcessor {
) -> Self {
Self {
registry,
incoming_buffer: Vec::new(),
read_half,
socket_address,
addresses,
Expand Down Expand Up @@ -210,10 +212,12 @@ impl Processor for TcpRecvProcessor {
trace!("Received message header for {} bytes", len);

// Allocate a buffer of that size
let mut buf = vec![0; len_usize];
self.incoming_buffer.clear();
self.incoming_buffer.reserve(len_usize);
self.incoming_buffer.resize(len_usize, 0);

// Then read into the buffer
match self.read_half.read_exact(&mut buf).await {
match self.read_half.read_exact(&mut self.incoming_buffer).await {
Ok(_) => {}
Err(e) => {
self.notify_sender_stream_dropped(ctx, e).await?;
Expand All @@ -222,7 +226,7 @@ impl Processor for TcpRecvProcessor {
}

// Deserialize the message now
let transport_message: TcpTransportMessage = match minicbor::decode(&buf) {
let transport_message: TcpTransportMessage = match minicbor::decode(&self.incoming_buffer) {
Ok(msg) => msg,
Err(e) => {
self.notify_sender_stream_dropped(ctx, e).await?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use ockam_core::flow_control::FlowControlId;
use ockam_core::{
async_trait,
compat::{net::SocketAddr, sync::Arc},
AllowAll, AllowSourceAddress, DenyAll, LocalMessage,
AllowAll, AllowSourceAddress, DenyAll, LocalMessage, GLOBAL_BUFFER_POOL,
};
use ockam_core::{Any, Decodable, Mailbox, Mailboxes, Message, Result, Routed, Worker};
use ockam_node::{Context, WorkerBuilder};
Expand Down Expand Up @@ -138,6 +138,10 @@ impl TcpSendWorker {
minicbor::encode(&transport_message, &mut self.buffer)
.map_err(|_| TransportError::Encoding)?;

if !transport_message.payload.is_borrowed() {
GLOBAL_BUFFER_POOL.release(transport_message.payload.into_owned());
}

// Should not ever happen...
if self.buffer.len() < LENGTH_VALUE_SIZE {
return Err(TransportError::Encoding)?;
Expand Down Expand Up @@ -244,7 +248,6 @@ impl Worker for TcpSendWorker {
if let Err(err) = self.serialize_message(local_message) {
// Close the stream
self.stop(ctx).await?;

return Err(err);
};

Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions tools/profile/portal.memory_profile
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ fi

"${OCKAM}" node delete portal -y >/dev/null 2>&1 || true
export OCKAM_LOG_LEVEL=info
export OCKAM_OPENTELEMETRY_EXPORT=0

if [ "$(uname)" == "Darwin" ]; then
rm -rf /tmp/ockam.trace/
Expand Down
Loading

0 comments on commit d1698c0

Please sign in to comment.