Skip to content

Commit

Permalink
feat(rust): cache and reuse tcp connections
Browse files Browse the repository at this point in the history
  • Loading branch information
davide-baldo committed Dec 19, 2024
1 parent 2ff6e39 commit 5694915
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl From<StaticHostnamePort> for HostnamePort {
}

/// Hostname and port
#[derive(Debug, Clone, PartialEq, Eq, Encode, Decode, CborLen)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode, CborLen)]
#[rustfmt::skip]
pub struct HostnamePort {
#[n(0)] hostname: String,
Expand Down Expand Up @@ -158,6 +158,10 @@ impl HostnamePort {

Ok(HostnamePort::new(hostname, port))
}

pub fn is_localhost(&self) -> bool {
self.hostname == "localhost" || self.hostname == "127.0.0.1"
}
}

impl From<SocketAddr> for HostnamePort {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::transport::connect;
use crate::transport::{connect, CachedConnectionsQueue};
use crate::workers::{Addresses, TcpRecvProcessor, TcpSendWorker};
use crate::{TcpConnectionMode, TcpConnectionOptions, TcpTransport};
use core::fmt;
Expand All @@ -10,6 +10,9 @@ use ockam_core::{Address, Result};
use ockam_node::Context;
use ockam_transport_core::HostnamePort;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex as SyncMutex, Weak};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::time::Instant;
use tracing::debug;

/// Result of [`TcpTransport::connect`] call.
Expand Down Expand Up @@ -84,6 +87,55 @@ impl TcpConnection {
}
}

pub(crate) struct ConnectionCollector {
hostname_port: HostnamePort,
connections_queue: Weak<CachedConnectionsQueue>,
read_half: SyncMutex<Option<(Instant, OwnedReadHalf)>>,
write_half: SyncMutex<Option<OwnedWriteHalf>>,
}

impl ConnectionCollector {
fn new(hostname_port: HostnamePort, connections_queue: &Arc<CachedConnectionsQueue>) -> Self {
Self {
hostname_port,
connections_queue: Arc::downgrade(connections_queue),
read_half: SyncMutex::new(None),
write_half: SyncMutex::new(None),
}
}

pub(crate) fn collect_read_half(&self, last_known_reply: Instant, read_half: OwnedReadHalf) {
debug!("Collecting read half for {}", self.hostname_port);
self.read_half
.lock()
.unwrap()
.replace((last_known_reply, read_half));
self.check_and_push_connection();
}

pub(crate) fn collect_write_half(&self, write_half: OwnedWriteHalf) {
debug!("Collecting write half for {}", self.hostname_port);
self.write_half.lock().unwrap().replace(write_half);
self.check_and_push_connection();
}

fn check_and_push_connection(&self) {
let mut read_half = self.read_half.lock().unwrap();
let mut write_half = self.write_half.lock().unwrap();

if read_half.is_some() && write_half.is_some() {
if let Some(connections_queue) = self.connections_queue.upgrade() {
let (last_known_reply, read_half) = read_half.take().unwrap();
let write_half = write_half.take().unwrap();

let mut guard = connections_queue.lock().unwrap();
let connections = guard.entry(self.hostname_port.clone()).or_default();
connections.push_back((last_known_reply, read_half, write_half));
}
}
}
}

impl TcpTransport {
/// Establish an outgoing TCP connection.
///
Expand All @@ -103,9 +155,41 @@ impl TcpTransport {
options: TcpConnectionOptions,
) -> Result<TcpConnection> {
let peer = HostnamePort::from_str(&peer.into())?;
debug!("Connecting to {}", peer.clone());

let (read_half, write_half) = connect(&peer).await?;
let (last_known_reply, skip_initialization, read_half, write_half) = {
let connection = {
let mut guard = self.connections.lock().unwrap();
if let Some(connections) = guard.get_mut(&peer) {
loop {
if let Some((last_known_reply, read_half, write_half)) =
connections.pop_front()
{
let elapsed = last_known_reply.elapsed();
if elapsed.as_secs() < 2 {
debug!(
"Reusing existing connection to {}, {}ms old",
peer.clone(),
elapsed.as_millis()
);
break Some((last_known_reply, true, read_half, write_half));
}
} else {
break None;
}
}
} else {
None
}
};

if let Some(read_write_half) = connection {
read_write_half
} else {
let (read_half, write_half) = connect(&peer).await?;
(Instant::now(), false, read_half, write_half)
}
};

let socket = read_half
.peer_addr()
.map_err(|e| ockam_core::Error::new(Origin::Transport, Kind::Internal, e))?;
Expand All @@ -118,26 +202,39 @@ impl TcpTransport {
let receiver_outgoing_access_control =
options.create_receiver_outgoing_access_control(self.ctx.flow_controls());

let connection_collector = {
if peer.is_localhost() {
None
} else {
Some(Arc::new(ConnectionCollector::new(peer, &self.connections)))
}
};

TcpSendWorker::start(
&self.ctx,
self.registry.clone(),
write_half,
skip_initialization,
&addresses,
socket,
mode,
&flow_control_id,
connection_collector.clone(),
)
.await?;

TcpRecvProcessor::start(
&self.ctx,
self.registry.clone(),
read_half,
skip_initialization,
last_known_reply,
&addresses,
socket,
mode,
&flow_control_id,
receiver_outgoing_access_control,
connection_collector,
)
.await?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,23 @@ mod lifecycle;
mod listener;
mod portals;

pub(crate) use common::*;

pub use crate::portal::options::*;
use crate::TcpRegistry;
pub(crate) use common::*;
pub use connection::*;
pub use listener::*;
pub use portals::*;

use crate::TcpRegistry;
use ockam_core::compat::sync::Arc;
use ockam_core::{async_trait, Result};
use ockam_node::{Context, HasContext};
use ockam_transport_core::HostnamePort;
pub use portals::*;
use std::collections::{HashMap, VecDeque};
use std::sync::Mutex as SyncMutex;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::time::Instant;

type CachedConnectionsQueue =
SyncMutex<HashMap<HostnamePort, VecDeque<(Instant, OwnedReadHalf, OwnedWriteHalf)>>>;

/// High level management interface for TCP transports
///
Expand Down Expand Up @@ -58,6 +64,7 @@ use ockam_node::{Context, HasContext};
pub struct TcpTransport {
ctx: Arc<Context>,
registry: TcpRegistry,
connections: Arc<CachedConnectionsQueue>,

#[cfg(privileged_portals_support)]
pub(crate) ebpf_support: Arc<crate::privileged_portal::TcpTransportEbpfSupport>,
Expand All @@ -71,6 +78,7 @@ impl TcpTransport {
registry: TcpRegistry::default(),
#[cfg(privileged_portals_support)]
ebpf_support: Default::default(),
connections: Default::default(),
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use ockam_core::{Address, Processor, Result};
use ockam_node::Context;
use ockam_transport_core::TransportError;
use tokio::net::TcpListener;
use tokio::time::Instant;
use tracing::{debug, instrument};

/// A TCP Listen processor
Expand Down Expand Up @@ -102,10 +103,12 @@ impl Processor for TcpListenProcessor {
ctx,
self.registry.clone(),
write_half,
false,
&addresses,
peer,
mode,
&receiver_flow_control_id,
None,
)
.await?;

Expand All @@ -114,11 +117,14 @@ impl Processor for TcpListenProcessor {
ctx,
self.registry.clone(),
read_half,
false,
Instant::now(),
&addresses,
peer,
mode,
&receiver_flow_control_id,
receiver_outgoing_access_control,
None,
)
.await?;

Expand Down
Loading

0 comments on commit 5694915

Please sign in to comment.