Skip to content

Commit

Permalink
feat(rust): encoding and allocation optimizations for privileged portals
Browse files Browse the repository at this point in the history
  • Loading branch information
SanjoDeundiak committed Nov 21, 2024
1 parent f57af1b commit e358181
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 70 deletions.
5 changes: 5 additions & 0 deletions implementations/rust/ockam/ockam_core/src/cbor/cow_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ impl CowBytes<'_> {
pub fn into_owned(self) -> Vec<u8> {
self.0.into_owned()
}

/// Borrow the wrapped bytes as a plain `&[u8]` slice.
///
/// Takes `&self`, so the value is not consumed; the slice is produced by
/// calling `as_ref()` on the inner field.
pub fn as_slice(&self) -> &[u8] {
self.0.as_ref()
}
}

impl<'a> From<&'a [u8]> for CowBytes<'a> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl SecureChannelLocalInfo {
pub fn to_local_info(&self) -> Result<LocalInfo> {
Ok(LocalInfo::new(
SECURE_CHANNEL_IDENTIFIER.into(),
minicbor::to_vec(&self.their_identifier)?,
crate::cbor_encode_preallocate(&self.their_identifier)?,
))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub struct TcpTransportEbpfSupport {

links: Arc<Mutex<HashMap<Iface, IfaceLink>>>,

tcp_packet_writer: Arc<AsyncMutex<Option<Arc<dyn TcpPacketWriter>>>>,
tcp_packet_writer: Arc<AsyncMutex<Option<Box<dyn TcpPacketWriter>>>>,
raw_socket_processor_address: Address,

bpf: Arc<Mutex<Option<OckamBpf>>>,
Expand Down Expand Up @@ -77,12 +77,12 @@ impl TcpTransportEbpfSupport {
pub(crate) async fn start_raw_socket_processor_if_needed(
&self,
ctx: &Context,
) -> Result<Arc<dyn TcpPacketWriter>> {
) -> Result<Box<dyn TcpPacketWriter>> {
debug!("Starting RawSocket");

let mut tcp_packet_writer_lock = self.tcp_packet_writer.lock().await;
if let Some(tcp_packet_writer_lock) = tcp_packet_writer_lock.as_ref() {
return Ok(tcp_packet_writer_lock.clone());
return Ok(tcp_packet_writer_lock.create_new_box());
}

let (processor, tcp_packet_writer) = RawSocketProcessor::create(
Expand All @@ -92,7 +92,7 @@ impl TcpTransportEbpfSupport {
)
.await?;

*tcp_packet_writer_lock = Some(tcp_packet_writer.clone());
*tcp_packet_writer_lock = Some(tcp_packet_writer.create_new_box());

ctx.start_processor(self.raw_socket_processor_address.clone(), processor)
.await?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl TcpPacketReader for AsyncFdPacketReader {
} = Self::parse_headers(&mut self.buffer[..len])?;

let header_and_payload =
match TcpStrippedHeaderAndPayload::new(self.buffer[offset..len].to_vec()) {
match TcpStrippedHeaderAndPayload::new(self.buffer[offset..len].to_vec().into()) {
Some(header_and_payload) => header_and_payload,
None => {
return Err(TransportError::ParsingHeaders(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,49 @@ use tokio::io::Interest;

/// RawSocket packet writer implemented via tokio's AsyncFd
pub struct AsyncFdPacketWriter {
// The idea is that each writer has its own buffer, so that we avoid either
// 1. Waiting for the lock on a shared buffer
// 2. Allocating a new buffer on every write operation
buffer: Vec<u8>,
fd: Arc<AsyncFd<OwnedFd>>,
}

impl AsyncFdPacketWriter {
/// Constructor
pub fn new(fd: Arc<AsyncFd<OwnedFd>>) -> Self {
Self { fd }
Self { buffer: vec![], fd }
}
}

#[async_trait]
impl TcpPacketWriter for AsyncFdPacketWriter {
async fn write_packet(
&self,
&mut self,
src_port: Port,
destination_ip: Ipv4Addr,
dst_ip: Ipv4Addr,
dst_port: Port,
header_and_payload: TcpStrippedHeaderAndPayload,
header_and_payload: TcpStrippedHeaderAndPayload<'_>,
) -> Result<()> {
// We need to prepend the ports to the beginning of the header. Instead of cloning, let's
// add this data to the end and reverse the whole binary a few times for the same result,
// which should be more efficient
let mut packet = header_and_payload.take();
packet.reserve(4);
packet.reverse();
self.buffer.clear();
self.buffer.reserve(header_and_payload.len() + 4);

let mut ports = [0u8; 4];
let mut ports_view = tcp_header_ports::View::new(&mut ports);
ports_view.source_mut().write(src_port);
ports_view.dest_mut().write(dst_port);
ports.reverse();

packet.extend_from_slice(&ports[..]);
packet.reverse();
self.buffer.extend_from_slice(ports.as_slice());
self.buffer.extend_from_slice(header_and_payload.as_slice());

tcp_set_checksum(Ipv4Addr::UNSPECIFIED, destination_ip, &mut packet);
tcp_set_checksum(Ipv4Addr::UNSPECIFIED, dst_ip, &mut self.buffer);

let destination_addr = SockaddrIn::from(SocketAddrV4::new(destination_ip, 0));
let destination_addr = SockaddrIn::from(SocketAddrV4::new(dst_ip, 0));

// We don't pick source IP, kernel does it for us by performing Routing Table lookup.
// The problem is that if for some reason tcp packets from one connection
// use different src_ip, the connection would be disrupted.
// As an alternative, we could build IPv4 header ourselves and control it by setting
// IP_HDRINCL socket option, but that brings a lot of challenges.

enum WriteResult {
Ok { len: usize },
Expand All @@ -63,7 +68,7 @@ impl TcpPacketWriter for AsyncFdPacketWriter {
.async_io(Interest::WRITABLE, |fd| {
let res = nix::sys::socket::sendto(
fd.as_raw_fd(),
packet.as_slice(),
self.buffer.as_slice(),
&destination_addr,
MsgFlags::empty(),
);
Expand All @@ -90,14 +95,19 @@ impl TcpPacketWriter for AsyncFdPacketWriter {
WriteResult::Err(err) => return Err(err)?,
};

if len != packet.len() {
if len != self.buffer.len() {
return Err(TransportError::RawSocketWrite(format!(
"Could not write the whole packet. Packet len: {}. Actually written: {}",
packet.len(),
self.buffer.len(),
len
)))?;
}

Ok(())
}

/// Build a fresh boxed writer that uses the same file descriptor as `self`.
fn create_new_box(&self) -> Box<dyn TcpPacketWriter> {
// The fd `Arc` is cloned, so the new writer shares the descriptor with this one.
// The buffer is not shared: `AsyncFdPacketWriter::new` gives every writer its own,
// initially empty, buffer.
Box::new(AsyncFdPacketWriter::new(self.fd.clone()))
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::privileged_portal::packet::TcpStrippedHeaderAndPayload;
use minicbor::{Decode, Encode};
use minicbor::{CborLen, Decode, Encode};
use ockam_core::CowBytes;
use rand::distributions::{Distribution, Standard};
use rand::Rng;

Expand All @@ -14,7 +15,7 @@ pub type Port = u16;
pub type Proto = u8;

/// Unique random connection identifier
#[derive(Clone, Debug, Eq, PartialEq, Hash, Encode, Decode)]
#[derive(Clone, Debug, Eq, PartialEq, Hash, Encode, Decode, CborLen)]
#[cbor(transparent)]
#[rustfmt::skip]
pub struct ConnectionIdentifier(#[n(0)] u64);
Expand All @@ -26,20 +27,20 @@ impl Distribution<ConnectionIdentifier> for Standard {
}

/// Packet exchanged between the Inlet and the Outlet
#[derive(Encode, Decode)]
#[derive(Encode, Decode, CborLen)]
#[rustfmt::skip]
pub struct OckamPortalPacket {
pub struct OckamPortalPacket<'a> {
/// Unique TCP connection identifier
#[n(0)] pub connection_identifier: ConnectionIdentifier,
/// Monotonically increasing route index
#[n(1)] pub route_index: u32,
/// Stripped (without ports) TCP header and payload
#[n(2)] pub header_and_payload: Vec<u8>,
#[b(2)] pub header_and_payload: CowBytes<'a>,
}

impl OckamPortalPacket {
impl<'a> OckamPortalPacket<'a> {
/// Dissolve into parts consuming the original value to avoid clones
pub fn dissolve(self) -> Option<(ConnectionIdentifier, u32, TcpStrippedHeaderAndPayload)> {
pub fn dissolve(self) -> Option<(ConnectionIdentifier, u32, TcpStrippedHeaderAndPayload<'a>)> {
let header_and_payload = TcpStrippedHeaderAndPayload::new(self.header_and_payload)?;

Some((
Expand All @@ -53,7 +54,7 @@ impl OckamPortalPacket {
pub fn from_tcp_packet(
connection_identifier: ConnectionIdentifier,
route_index: u32,
header_and_payload: TcpStrippedHeaderAndPayload,
header_and_payload: TcpStrippedHeaderAndPayload<'a>,
) -> Self {
Self {
connection_identifier,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ use tokio::io::unix::AsyncFd;
/// AsyncFd
pub fn create_async_fd_raw_socket(
proto: Proto,
) -> Result<(Arc<dyn TcpPacketWriter>, Box<dyn TcpPacketReader>)> {
) -> Result<(Box<dyn TcpPacketWriter>, Box<dyn TcpPacketReader>)> {
let fd = create_raw_socket_fd(proto)?;
let fd = Arc::new(fd);

let writer = AsyncFdPacketWriter::new(fd.clone());
let writer = Arc::new(writer);
let writer = Box::new(writer);

let reader = AsyncFdPacketReader::new(fd);
let reader = Box::new(reader);
Expand All @@ -30,20 +30,22 @@ pub fn create_async_fd_raw_socket(
fn create_raw_socket_fd(proto: Proto) -> Result<AsyncFd<OwnedFd>> {
// Unfortunately, SockProtocol enum doesn't support arbitrary values
let proto: SockProtocol = unsafe { mem::transmute(proto as i32) };
let socket = nix::sys::socket::socket(
let res = nix::sys::socket::socket(
AddressFamily::Inet,
SockType::Raw,
SockFlag::SOCK_NONBLOCK,
Some(proto),
);

let socket = match socket {
let socket = match res {
Ok(socket) => socket,
Err(err) => {
return Err(TransportError::RawSocketCreation(err.to_string()))?;
}
};

// TODO: It's possible to bind that socket to an IP if needed

let res = unsafe {
// We don't want to construct the IPv4 header ourselves; when receiving, it will be included
// nevertheless
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::privileged_portal::packet_binary::{ipv4_header, stripped_tcp_header, tcp_header};
use crate::privileged_portal::Port;
use ockam_core::CowBytes;
use std::net::Ipv4Addr;

/// Result of reading packet from RawSocket
Expand All @@ -9,15 +10,15 @@ pub struct RawSocketReadResult {
/// Info from TCP header
pub tcp_info: TcpInfo,
/// Part of the TCP header (without ports) and TCP payload
pub header_and_payload: TcpStrippedHeaderAndPayload,
pub header_and_payload: TcpStrippedHeaderAndPayload<'static>,
}

/// TCP Header excluding first 4 bytes (src and dst ports) + payload
pub struct TcpStrippedHeaderAndPayload(Vec<u8>);
pub struct TcpStrippedHeaderAndPayload<'a>(CowBytes<'a>);

impl TcpStrippedHeaderAndPayload {
impl<'a> TcpStrippedHeaderAndPayload<'a> {
/// Constructor
pub fn new(bytes: Vec<u8>) -> Option<Self> {
pub fn new(bytes: CowBytes<'a>) -> Option<Self> {
if bytes.len() < 16 {
return None;
}
Expand All @@ -26,10 +27,15 @@ impl TcpStrippedHeaderAndPayload {
}

/// Consume and return the data
pub fn take(self) -> Vec<u8> {
pub fn take(self) -> CowBytes<'a> {
self.0
}

/// Borrow the stripped TCP header and payload bytes as a `&[u8]` slice.
///
/// Does not consume `self`; delegates to `as_slice` on the wrapped `CowBytes`.
pub fn as_slice(&self) -> &[u8] {
self.0.as_slice()
}

/// Length
pub fn len(&self) -> usize {
self.0.len()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ use std::net::Ipv4Addr;
pub trait TcpPacketWriter: Send + Sync + 'static {
/// Write packet to the RawSocket
async fn write_packet(
&self,
&mut self,
src_port: Port,
destination_ip: Ipv4Addr,
dst_ip: Ipv4Addr,
dst_port: Port,
header_and_payload: TcpStrippedHeaderAndPayload,
header_and_payload: TcpStrippedHeaderAndPayload<'_>,
) -> Result<()>;

/// Clone current implementation and wrap in a Box
fn create_new_box(&self) -> Box<dyn TcpPacketWriter>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@ use aya::programs::tc::{qdisc_detach_program, TcAttachType};
use log::{error, info, warn};
use ockam_core::Result;
use ockam_transport_core::TransportError;
use std::sync::Arc;

impl TcpTransport {
/// Start [`RawSocketProcessor`]. Should be done once.
pub(crate) async fn start_raw_socket_processor_if_needed(
&self,
) -> Result<Arc<dyn TcpPacketWriter>> {
) -> Result<Box<dyn TcpPacketWriter>> {
self.ebpf_support
.start_raw_socket_processor_if_needed(self.ctx())
.await
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::privileged_portal::packet::RawSocketReadResult;
use crate::privileged_portal::{Inlet, InletConnection, OckamPortalPacket, Outlet, PortalMode};
use log::{debug, trace, warn};
use ockam_core::{async_trait, route, LocalInfoIdentifier, LocalMessage, Processor, Result};
use ockam_core::{
async_trait, cbor_encode_preallocate, route, LocalInfoIdentifier, LocalMessage, Processor,
Result,
};
use ockam_node::Context;
use ockam_transport_core::TransportError;
use rand::random;
Expand Down Expand Up @@ -129,7 +132,7 @@ impl Processor for InternalProcessor {
LocalMessage::new()
.with_onward_route(inlet_shared_state.route().clone())
.with_return_route(route![inlet.remote_worker_address.clone()])
.with_payload(minicbor::to_vec(portal_packet)?),
.with_payload(cbor_encode_preallocate(&portal_packet)?),
ctx.address(),
)
.await?;
Expand Down Expand Up @@ -170,7 +173,7 @@ impl Processor for InternalProcessor {
LocalMessage::new()
.with_onward_route(return_route)
.with_return_route(route![outlet.remote_worker_address.clone()])
.with_payload(minicbor::to_vec(portal_packet)?),
.with_payload(cbor_encode_preallocate(&portal_packet)?),
ctx.address(),
)
.await?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use log::trace;
use ockam_core::{async_trait, Processor, Result};
use ockam_node::Context;
use ockam_transport_core::TransportError;
use std::sync::Arc;

/// Processor responsible for receiving all data with OCKAM_TCP_PORTAL_PROTOCOL on the machine
/// and redirecting it to individual portal workers.
Expand All @@ -23,7 +22,7 @@ impl RawSocketProcessor {
ip_proto: u8,
inlet_registry: InletRegistry,
outlet_registry: OutletRegistry,
) -> Result<(Self, Arc<dyn TcpPacketWriter>)> {
) -> Result<(Self, Box<dyn TcpPacketWriter>)> {
let (tcp_packet_writer, tcp_packet_reader) = create_async_fd_raw_socket(ip_proto)?;

let s = Self {
Expand Down
Loading

0 comments on commit e358181

Please sign in to comment.