Skip to content

Commit

Permalink
feat(rust): cache and reuse tcp connections
Browse files Browse the repository at this point in the history
  • Loading branch information
davide-baldo committed Dec 19, 2024
1 parent 2ff6e39 commit a187bdc
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl From<StaticHostnamePort> for HostnamePort {
}

/// Hostname and port
#[derive(Debug, Clone, PartialEq, Eq, Encode, Decode, CborLen)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode, CborLen)]
#[rustfmt::skip]
pub struct HostnamePort {
#[n(0)] hostname: String,
Expand Down Expand Up @@ -158,6 +158,10 @@ impl HostnamePort {

Ok(HostnamePort::new(hostname, port))
}

pub fn is_localhost(&self) -> bool {
self.hostname == "localhost" || self.hostname == "127.0.0.1"
}
}

impl From<SocketAddr> for HostnamePort {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::transport::connect;
use crate::transport::{connect, CachedConnectionsQueue};
use crate::workers::{Addresses, TcpRecvProcessor, TcpSendWorker};
use crate::{TcpConnectionMode, TcpConnectionOptions, TcpTransport};
use core::fmt;
Expand All @@ -10,6 +10,8 @@ use ockam_core::{Address, Result};
use ockam_node::Context;
use ockam_transport_core::HostnamePort;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex as SyncMutex, Weak};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tracing::debug;

/// Result of [`TcpTransport::connect`] call.
Expand Down Expand Up @@ -84,6 +86,52 @@ impl TcpConnection {
}
}

pub(crate) struct ConnectionCollector {
hostname_port: HostnamePort,
connections_queue: Weak<CachedConnectionsQueue>,
read_half: SyncMutex<Option<OwnedReadHalf>>,
write_half: SyncMutex<Option<OwnedWriteHalf>>,
}

impl ConnectionCollector {
fn new(hostname_port: HostnamePort, connections_queue: &Arc<CachedConnectionsQueue>) -> Self {
Self {
hostname_port,
connections_queue: Arc::downgrade(connections_queue),
read_half: SyncMutex::new(None),
write_half: SyncMutex::new(None),
}
}

pub(crate) fn collect_read_half(&self, read_half: OwnedReadHalf) {
debug!("Collecting read half for {}", self.hostname_port);
self.read_half.lock().unwrap().replace(read_half);
self.check_and_push_connection();
}

pub(crate) fn collect_write_half(&self, write_half: OwnedWriteHalf) {
debug!("Collecting write half for {}", self.hostname_port);
self.write_half.lock().unwrap().replace(write_half);
self.check_and_push_connection();
}

fn check_and_push_connection(&self) {
let mut read_half = self.read_half.lock().unwrap();
let mut write_half = self.write_half.lock().unwrap();

if read_half.is_some() && write_half.is_some() {
if let Some(connections_queue) = self.connections_queue.upgrade() {
let read_half = read_half.take().unwrap();
let write_half = write_half.take().unwrap();

let mut guard = connections_queue.lock().unwrap();
let connections = guard.entry(self.hostname_port.clone()).or_default();
connections.push_back((read_half, write_half));
}
}
}
}

impl TcpTransport {
/// Establish an outgoing TCP connection.
///
Expand All @@ -103,9 +151,25 @@ impl TcpTransport {
options: TcpConnectionOptions,
) -> Result<TcpConnection> {
let peer = HostnamePort::from_str(&peer.into())?;
debug!("Connecting to {}", peer.clone());

let (read_half, write_half) = connect(&peer).await?;
let (read_half, write_half) = {
let connection = {
let mut guard = self.connections.lock().unwrap();
if let Some(connections) = guard.get_mut(&peer) {
debug!("Reusing existing connection to {}", peer.clone());
connections.pop_front()
} else {
None
}
};

if let Some(read_write_half) = connection {
read_write_half
} else {
connect(&peer).await?
}
};

let socket = read_half
.peer_addr()
.map_err(|e| ockam_core::Error::new(Origin::Transport, Kind::Internal, e))?;
Expand All @@ -118,6 +182,17 @@ impl TcpTransport {
let receiver_outgoing_access_control =
options.create_receiver_outgoing_access_control(self.ctx.flow_controls());

let connection_collector = {
if peer.is_localhost() {
None
} else {
Some(Arc::new(ConnectionCollector::new(
peer.clone(),
&self.connections,
)))
}
};

TcpSendWorker::start(
&self.ctx,
self.registry.clone(),
Expand All @@ -126,6 +201,7 @@ impl TcpTransport {
socket,
mode,
&flow_control_id,
connection_collector.clone(),
)
.await?;

Expand All @@ -138,6 +214,7 @@ impl TcpTransport {
mode,
&flow_control_id,
receiver_outgoing_access_control,
connection_collector,
)
.await?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,23 @@ mod lifecycle;
mod listener;
mod portals;

pub(crate) use common::*;

pub use crate::portal::options::*;
pub(crate) use common::*;
pub use connection::*;
pub use listener::*;
pub use portals::*;
use std::collections::{HashMap, VecDeque};
use std::sync::Mutex as SyncMutex;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};

use crate::TcpRegistry;
use ockam_core::compat::sync::Arc;
use ockam_core::{async_trait, Result};
use ockam_node::{Context, HasContext};
use ockam_transport_core::HostnamePort;

type CachedConnectionsQueue =
SyncMutex<HashMap<HostnamePort, VecDeque<(OwnedReadHalf, OwnedWriteHalf)>>>;

/// High level management interface for TCP transports
///
Expand Down Expand Up @@ -58,6 +64,7 @@ use ockam_node::{Context, HasContext};
pub struct TcpTransport {
ctx: Arc<Context>,
registry: TcpRegistry,
connections: Arc<CachedConnectionsQueue>,

#[cfg(privileged_portals_support)]
pub(crate) ebpf_support: Arc<crate::privileged_portal::TcpTransportEbpfSupport>,
Expand All @@ -71,6 +78,7 @@ impl TcpTransport {
registry: TcpRegistry::default(),
#[cfg(privileged_portals_support)]
ebpf_support: Default::default(),
connections: Default::default(),
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ impl Processor for TcpListenProcessor {
peer,
mode,
&receiver_flow_control_id,
None,
)
.await?;

Expand All @@ -119,6 +120,7 @@ impl Processor for TcpListenProcessor {
mode,
&receiver_flow_control_id,
receiver_outgoing_access_control,
None,
)
.await?;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::transport::ConnectionCollector;
use crate::transport_message::TcpTransportMessage;
use crate::workers::Addresses;
use crate::{
Expand All @@ -16,7 +17,7 @@ use ockam_core::{Processor, Result};
use ockam_node::{Context, ProcessorBuilder};
use ockam_transport_core::TransportError;
use tokio::{io::AsyncReadExt, net::tcp::OwnedReadHalf};
use tracing::{info, instrument, trace};
use tracing::{debug, info, instrument, trace};

/// A TCP receiving message processor
///
Expand All @@ -29,11 +30,12 @@ use tracing::{info, instrument, trace};
pub(crate) struct TcpRecvProcessor {
registry: TcpRegistry,
incoming_buffer: Vec<u8>,
read_half: OwnedReadHalf,
read_half: Option<OwnedReadHalf>,
socket_address: SocketAddr,
addresses: Addresses,
mode: TcpConnectionMode,
flow_control_id: FlowControlId,
connection_collector: Option<Arc<ConnectionCollector>>,
}

impl TcpRecvProcessor {
Expand All @@ -45,15 +47,17 @@ impl TcpRecvProcessor {
addresses: Addresses,
mode: TcpConnectionMode,
flow_control_id: FlowControlId,
connection_collector: Option<Arc<ConnectionCollector>>,
) -> Self {
Self {
registry,
incoming_buffer: Vec::new(),
read_half,
read_half: Some(read_half),
socket_address,
addresses,
mode,
flow_control_id,
connection_collector,
}
}

Expand All @@ -68,6 +72,7 @@ impl TcpRecvProcessor {
mode: TcpConnectionMode,
flow_control_id: &FlowControlId,
receiver_outgoing_access_control: Arc<dyn OutgoingAccessControl>,
connection_collector: Option<Arc<ConnectionCollector>>,
) -> Result<()> {
let receiver = TcpRecvProcessor::new(
registry,
Expand All @@ -76,6 +81,7 @@ impl TcpRecvProcessor {
addresses.clone(),
mode,
flow_control_id.clone(),
connection_collector,
);

let mailbox = Mailbox::new(
Expand All @@ -98,7 +104,12 @@ impl TcpRecvProcessor {
Ok(())
}

async fn notify_sender_stream_dropped(&self, ctx: &Context, msg: impl Display) -> Result<()> {
async fn notify_sender_stream_dropped(
&mut self,
ctx: &Context,
msg: impl Display,
) -> Result<()> {
self.read_half.take();
info!(
"Connection to peer '{}' was closed; dropping stream. {}",
self.socket_address, msg
Expand Down Expand Up @@ -129,12 +140,16 @@ impl Processor for TcpRecvProcessor {
self.flow_control_id.clone(),
));

let protocol_version = match self.read_half.read_u8().await {
Ok(p) => p,
Err(e) => {
self.notify_sender_stream_dropped(ctx, e).await?;
return Err(TransportError::GenericIo)?;
let protocol_version = if let Some(read_half) = self.read_half.as_mut() {
match read_half.read_u8().await {
Ok(p) => p,
Err(e) => {
self.notify_sender_stream_dropped(ctx, e).await?;
return Err(TransportError::GenericIo)?;
}
}
} else {
return Err(TransportError::ConnectionDrop)?;
};

let _protocol_version = match TcpProtocolVersion::try_from(protocol_version) {
Expand All @@ -158,8 +173,15 @@ impl Processor for TcpRecvProcessor {

#[instrument(skip_all, name = "TcpRecvProcessor::shutdown")]
async fn shutdown(&mut self, ctx: &mut Self::Context) -> Result<()> {
self.registry.remove_receiver_processor(&ctx.address());
if let Some(connection_collector) = self.connection_collector.as_ref() {
if let Some(read_half) = self.read_half.take() {
connection_collector.collect_read_half(read_half);
} else {
debug!("Connection closed, no read half to collect");
}
}

self.registry.remove_receiver_processor(&ctx.address());
Ok(())
}

Expand All @@ -177,7 +199,13 @@ impl Processor for TcpRecvProcessor {
#[instrument(skip_all, name = "TcpRecvProcessor::process", fields(worker = %ctx.address()))]
async fn process(&mut self, ctx: &mut Context) -> Result<bool> {
// Read the message length
let len = match self.read_half.read_u32().await {
let read_half = if let Some(read_half) = self.read_half.as_mut() {
read_half
} else {
return Ok(false);
};

let len = match read_half.read_u32().await {
Ok(l) => l,
Err(e) => {
self.notify_sender_stream_dropped(ctx, e).await?;
Expand Down Expand Up @@ -217,7 +245,7 @@ impl Processor for TcpRecvProcessor {
self.incoming_buffer.resize(len_usize, 0);

// Then read into the buffer
match self.read_half.read_exact(&mut self.incoming_buffer).await {
match read_half.read_exact(&mut self.incoming_buffer).await {
Ok(_) => {}
Err(e) => {
self.notify_sender_stream_dropped(ctx, e).await?;
Expand Down
Loading

0 comments on commit a187bdc

Please sign in to comment.