Skip to content

Commit

Permalink
feat(rust): added global buffer allocator to reduce memory fragmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
davide-baldo committed Nov 14, 2024
1 parent 05fbf4b commit bee4614
Show file tree
Hide file tree
Showing 16 changed files with 200 additions and 29 deletions.
6 changes: 3 additions & 3 deletions implementations/rust/ockam/ockam_api/tests/portals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ fn portal_low_bandwidth_connection_keep_working_for_60s() {
// ┌────────┐ ┌───────────┐ ┌────────┐
// │ Node └─────► TCP └────────► Node │
// │ 1 ◄─────┐Passthrough◄────────┐ 2 │
// └────┬───┘ │ 64KB/s │ └────▲───┘
// └────┬───┘ │ 128KB/s │ └────▲───┘
// │ └───────────┘ │
// │ ┌───────────┐ │
// │ Portal │ TCP │ Outlet │
Expand Down Expand Up @@ -270,8 +270,8 @@ fn portal_low_bandwidth_connection_keep_working_for_60s() {

let passthrough_server_handle = start_passthrough_server(
&second_node_listen_address.to_string(),
Disruption::LimitBandwidth(64 * 1024),
Disruption::LimitBandwidth(64 * 1024),
Disruption::LimitBandwidth(128 * 1024),
Disruption::LimitBandwidth(128 * 1024),
)
.await;

Expand Down
100 changes: 100 additions & 0 deletions implementations/rust/ockam/ockam_core/src/buffer_pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#[cfg(feature = "std")]
use once_cell::sync::Lazy;

/// The global instance of [`BufferPool`].
/// The goal of this pool is to reduce memory fragmentation by keep reusing the same buffers.
#[cfg(feature = "std")]
pub static GLOBAL_BUFFER_POOL: Lazy<std::BufferPool> = Lazy::new(std::BufferPool::new);

#[cfg(not(feature = "std"))]
pub static GLOBAL_BUFFER_POOL: BufferPool = BufferPool {};

#[cfg(feature = "std")]
mod std {
use crate::compat::sync::Mutex;
use crate::compat::vec::Vec;

const MIN_BUFFER_SIZE: usize = 96 * 1024;
const MAX_BUFFER_SIZE: usize = 192 * 1024;
const MAX_BUFFERS: usize = 32;

/// A buffer pool for reusing buffers at least big as [`MIN_BUFFER_SIZE`].
pub struct BufferPool {
buffers: Mutex<Inner>,
}

struct Inner {
buffers: Vec<Vec<u8>>,
}

impl BufferPool {
pub(super) fn new() -> Self {
Self {
buffers: Mutex::new(Inner {
buffers: Vec::new(),
}),
}
}

/// When the size is big enough, it'll return a buffer from the pool,
/// otherwise it'll return a new buffer.
pub fn try_size(&self, size: usize) -> Vec<u8> {
if (MIN_BUFFER_SIZE..=MAX_BUFFER_SIZE).contains(&size) {
self.take()
} else {
Vec::with_capacity(size)
}
}

/// Take a buffer from the pool.
pub fn take(&self) -> Vec<u8> {
let mut buffers = self.buffers.lock().unwrap();
if let Some(mut buffer) = buffers.buffers.pop() {
buffer.clear();
buffer
} else {
Vec::with_capacity(MIN_BUFFER_SIZE)
}
}

/// Release a buffer back to the pool, the buffer will only be reused if
/// it's capacity is within [`MIN_BUFFER_SIZE`] and [`MAX_BUFFER_SIZE`].
pub fn release(&self, buffer: Vec<u8>) {
if buffer.capacity() >= MIN_BUFFER_SIZE && buffer.capacity() <= MAX_BUFFER_SIZE {
let mut buffers = self.buffers.lock().unwrap();
if buffers.buffers.len() < MAX_BUFFERS {
buffers.buffers.push(buffer);
// we can assume the smaller allocations are the newer ones,
// by returning the smaller ones first, we should be able
// to avoid "refreshing" the buffer pool too often
buffers.buffers.sort_by_key(|b| -(b.capacity() as i64));
}
}
}
}
}

#[cfg(not(feature = "std"))]
mod no_std {
use crate::compat::vec::Vec;

/// A buffer pool for reusing buffers at least big as [`MIN_BUFFER_SIZE`].
pub struct BufferPool;

impl BufferPool {
/// When the size is big enough, it'll return a buffer from the pool,
/// otherwise it'll return a new buffer.
pub fn try_size(&self, size: usize) -> Vec<u8> {
Vec::with_capacity(size)
}

/// Take a buffer from the pool.
pub fn take(&self) -> Vec<u8> {
Vec::new()
}

/// Release a buffer back to the pool, the buffer will only be reused if
/// it's capacity is within [`MIN_BUFFER_SIZE`] and [`MAX_BUFFER_SIZE`].
pub fn release(&self, buffer: Vec<u8>) {}
}
}
4 changes: 2 additions & 2 deletions implementations/rust/ockam/ockam_core/src/cbor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod cow_str;
pub(crate) mod schema;

use crate::compat::vec::Vec;
use crate::Result;
use crate::{Result, GLOBAL_BUFFER_POOL};
use minicbor::{CborLen, Encode};

/// Encode a type implementing [`Encode`] and return the encoded byte vector.
Expand All @@ -18,7 +18,7 @@ where
T: Encode<()> + CborLen<()>,
{
let expected_len = minicbor::len(&x);
let mut output = Vec::with_capacity(expected_len);
let mut output = GLOBAL_BUFFER_POOL.try_size(expected_len);
minicbor::encode(x, &mut output)?;
Ok(output)
}
2 changes: 2 additions & 0 deletions implementations/rust/ockam/ockam_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub mod hex_encoding;
pub mod env;

pub mod bare;
mod buffer_pool;
mod cbor;
mod error;
mod identity;
Expand All @@ -88,6 +89,7 @@ mod uint;
mod worker;

pub use access_control::*;
pub use buffer_pool::*;
pub use cbor::*;
pub use error::*;
pub use identity::*;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use core::sync::atomic::Ordering;
use ockam_core::compat::sync::Arc;
use ockam_core::compat::vec::Vec;
use ockam_core::{route, Any, Result, Route, Routed, SecureChannelLocalInfo};
use ockam_core::{route, Any, Result, Route, Routed, SecureChannelLocalInfo, GLOBAL_BUFFER_POOL};
use ockam_core::{Decodable, LocalMessage};
use ockam_node::Context;

Expand Down Expand Up @@ -219,6 +219,8 @@ impl DecryptorHandler {
SecureChannelMessage::Close => self.handle_close(ctx).await?,
};

GLOBAL_BUFFER_POOL.release(decrypted_payload);

Ok(())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use ockam_core::compat::vec::Vec;
use ockam_core::errcode::{Kind, Origin};
use ockam_core::{
async_trait, route, CowBytes, Decodable, Error, LocalMessage, NeutralMessage, Route,
GLOBAL_BUFFER_POOL,
};
use ockam_core::{Any, Result, Routed, Worker};
use ockam_node::Context;
Expand Down Expand Up @@ -99,9 +100,9 @@ impl EncryptorWorker {
msg: SecureChannelPaddedMessage<'_>,
) -> Result<Vec<u8>> {
let payload = ockam_core::cbor_encode_preallocate(&msg)?;
let mut destination = Vec::with_capacity(SIZE_OF_ENCRYPT_OVERHEAD + payload.len());
let mut destination = GLOBAL_BUFFER_POOL.try_size(SIZE_OF_ENCRYPT_OVERHEAD + payload.len());

match self.encryptor.encrypt(&mut destination, &payload).await {
let result = match self.encryptor.encrypt(&mut destination, &payload).await {
Ok(()) => Ok(destination),
// If encryption failed, that means we have some internal error,
// and we may be in an invalid state, it's better to stop the Worker
Expand All @@ -111,7 +112,16 @@ impl EncryptorWorker {
ctx.stop_worker(address).await?;
Err(err)
}
};

GLOBAL_BUFFER_POOL.release(payload);
if let SecureChannelMessage::Payload(plaintext) = msg.message {
if !plaintext.payload.is_borrowed() {
GLOBAL_BUFFER_POOL.release(plaintext.payload.into_owned());
}
}

result
}

#[instrument(skip_all)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ockam_core::bare::{read_slice, write_slice};
use ockam_core::errcode::{Kind, Origin};
use ockam_core::{Encodable, Encoded, Message, NeutralMessage};
use ockam_core::{Encodable, Encoded, Message, NeutralMessage, GLOBAL_BUFFER_POOL};
use serde::{Deserialize, Serialize};

/// A command message type for a Portal
Expand Down Expand Up @@ -84,7 +84,7 @@ impl PortalMessage<'_> {
let capacity = 1 + payload.len() + if counter.is_some() { 3 } else { 1 } + {
ockam_core::bare::size_of_variable_length(payload.len() as u64)
};
let mut vec = Vec::with_capacity(capacity);
let mut vec = GLOBAL_BUFFER_POOL.try_size(capacity);
vec.push(3);
write_slice(&mut vec, payload);
// TODO: re-enable once orchestrator accepts packet counter
Expand All @@ -108,7 +108,7 @@ pub enum PortalInternalMessage {
}

/// Maximum allowed size for a payload
pub const MAX_PAYLOAD_SIZE: usize = 48 * 1024;
pub const MAX_PAYLOAD_SIZE: usize = 96 * 1024;

#[cfg(test)]
mod test {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tracing::{error, instrument, warn};
/// [`TcpPortalWorker::start_receiver`](crate::TcpPortalWorker::start_receiver)
pub(crate) struct TcpPortalRecvProcessor<R> {
registry: TcpRegistry,
buf: Vec<u8>,
incoming_buffer: Vec<u8>,
read_half: R,
addresses: Addresses,
onward_route: Route,
Expand All @@ -37,7 +37,7 @@ impl<R: AsyncRead + Unpin + Send + Sync + 'static> TcpPortalRecvProcessor<R> {
) -> Self {
Self {
registry,
buf: Vec::with_capacity(MAX_PAYLOAD_SIZE),
incoming_buffer: Vec::with_capacity(MAX_PAYLOAD_SIZE),
read_half,
addresses,
onward_route,
Expand Down Expand Up @@ -67,9 +67,9 @@ impl<R: AsyncRead + Unpin + Send + Sync + 'static> Processor for TcpPortalRecvPr

#[instrument(skip_all, name = "TcpPortalRecvProcessor::process")]
async fn process(&mut self, ctx: &mut Context) -> Result<bool> {
self.buf.clear();
self.incoming_buffer.clear();

let _len = match self.read_half.read_buf(&mut self.buf).await {
let _len = match self.read_half.read_buf(&mut self.incoming_buffer).await {
Ok(len) => len,
Err(err) => {
error!("Tcp Portal connection read failed with error: {}", err);
Expand All @@ -82,7 +82,7 @@ impl<R: AsyncRead + Unpin + Send + Sync + 'static> Processor for TcpPortalRecvPr
OpenTelemetryContext::inject(&cx)
});

if self.buf.is_empty() {
if self.incoming_buffer.is_empty() {
// Notify Sender that connection was closed
ctx.set_tracing_context(tracing_context.clone());
if let Err(err) = ctx
Expand Down Expand Up @@ -113,7 +113,7 @@ impl<R: AsyncRead + Unpin + Send + Sync + 'static> Processor for TcpPortalRecvPr
}

// Loop just in case buf was extended (should not happen though)
for chunk in self.buf.chunks(MAX_PAYLOAD_SIZE) {
for chunk in self.incoming_buffer.chunks(MAX_PAYLOAD_SIZE) {
let msg = LocalMessage::new()
.with_tracing_context(tracing_context.clone())
.with_onward_route(self.onward_route.clone())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use ockam_core::compat::{boxed::Box, sync::Arc};
use ockam_core::{
async_trait, AllowOnwardAddress, AllowSourceAddress, Decodable, DenyAll, IncomingAccessControl,
LocalInfoIdentifier, Mailbox, Mailboxes, OutgoingAccessControl, SecureChannelLocalInfo,
GLOBAL_BUFFER_POOL,
};
use ockam_core::{Any, Result, Route, Routed, Worker};
use ockam_node::{Context, ProcessorBuilder, WorkerBuilder};
Expand Down Expand Up @@ -509,11 +510,11 @@ impl Worker for TcpPortalWorker {
// Send to Tcp stream
match msg {
PortalMessage::Payload(payload, packet_counter) => {
self.handle_payload(ctx, payload, packet_counter).await
self.handle_payload(ctx, payload, packet_counter).await?;
}
PortalMessage::Disconnect => {
self.start_disconnection(ctx, DisconnectionReason::Remote)
.await
.await?;
}
PortalMessage::Ping | PortalMessage::Pong => {
return Err(TransportError::Protocol)?;
Expand All @@ -524,8 +525,11 @@ impl Worker for TcpPortalWorker {
if msg != PortalInternalMessage::Disconnect {
return Err(TransportError::Protocol)?;
};
self.handle_disconnect(ctx).await
self.handle_disconnect(ctx).await?;
}

GLOBAL_BUFFER_POOL.release(payload);
Ok(())
}
State::SendPing { .. } | State::SendPong { .. } => {
return Err(TransportError::PortalInvalidState)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,9 @@ impl Worker for RemoteWorker {
let their_identifier = SecureChannelLocalInfo::find_info(msg.local_message())
.map(|l| l.their_identifier())
.ok();
let return_route = msg.return_route();
let payload = msg.into_payload();
let msg = msg.into_local_message();
let return_route = msg.return_route;
let payload = msg.payload;

// TODO: Add borrowing
let msg: OckamPortalPacket = minicbor::decode(&payload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use minicbor::{CborLen, Decode, Encode};
use ockam_core::compat::string::String;
#[cfg(feature = "std")]
use ockam_core::OpenTelemetryContext;
use ockam_core::{CowBytes, LocalMessage, Route};
use ockam_core::{CowBytes, LocalMessage, Route, GLOBAL_BUFFER_POOL};

/// TCP transport message type.
#[derive(Debug, Clone, Eq, PartialEq, Encode, Decode, CborLen)]
Expand Down Expand Up @@ -57,10 +57,19 @@ impl From<TcpTransportMessage<'_>> for LocalMessage {
#[cfg(feature = "std")]
let local_message = local_message.with_tracing_context(value.tracing_context());

let payload = if !value.payload.is_borrowed() {
value.payload.into_owned()
} else {
let mut payload = GLOBAL_BUFFER_POOL.try_size(value.payload.len());
payload.resize(value.payload.len(), 0);
payload.copy_from_slice(&value.payload);
payload
};

local_message
.with_onward_route(value.onward_route)
.with_return_route(value.return_route)
.with_payload(value.payload.into_owned())
.with_payload(payload)
}
}

Expand Down
Loading

0 comments on commit bee4614

Please sign in to comment.