From a187bdcff1c8d4bc32f8232e76d90085065ccffc Mon Sep 17 00:00:00 2001 From: Davide Baldo Date: Thu, 19 Dec 2024 11:39:27 +0100 Subject: [PATCH] feat(rust): cache and reuse tcp connections --- .../ockam_transport_core/src/hostname_port.rs | 6 +- .../src/transport/connection.rs | 83 ++++++++++++++++++- .../ockam_transport_tcp/src/transport/mod.rs | 12 ++- .../src/workers/listener.rs | 2 + .../src/workers/receiver.rs | 52 +++++++++--- .../ockam_transport_tcp/src/workers/sender.rs | 65 ++++++++++----- 6 files changed, 183 insertions(+), 37 deletions(-) diff --git a/implementations/rust/ockam/ockam_transport_core/src/hostname_port.rs b/implementations/rust/ockam/ockam_transport_core/src/hostname_port.rs index 4c09aa1ed3a..0b1f472120b 100644 --- a/implementations/rust/ockam/ockam_transport_core/src/hostname_port.rs +++ b/implementations/rust/ockam/ockam_transport_core/src/hostname_port.rs @@ -30,7 +30,7 @@ impl From for HostnamePort { } /// Hostname and port -#[derive(Debug, Clone, PartialEq, Eq, Encode, Decode, CborLen)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode, CborLen)] #[rustfmt::skip] pub struct HostnamePort { #[n(0)] hostname: String, @@ -158,6 +158,10 @@ impl HostnamePort { Ok(HostnamePort::new(hostname, port)) } + + pub fn is_localhost(&self) -> bool { + self.hostname == "localhost" || self.hostname == "127.0.0.1" + } } impl From for HostnamePort { diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/transport/connection.rs b/implementations/rust/ockam/ockam_transport_tcp/src/transport/connection.rs index 12ea8729c50..69fd7dbdb86 100644 --- a/implementations/rust/ockam/ockam_transport_tcp/src/transport/connection.rs +++ b/implementations/rust/ockam/ockam_transport_tcp/src/transport/connection.rs @@ -1,4 +1,4 @@ -use crate::transport::connect; +use crate::transport::{connect, CachedConnectionsQueue}; use crate::workers::{Addresses, TcpRecvProcessor, TcpSendWorker}; use crate::{TcpConnectionMode, TcpConnectionOptions, TcpTransport}; use core::fmt; @@ -10,6 +10,8 @@ use ockam_core::{Address, Result}; use ockam_node::Context; use ockam_transport_core::HostnamePort; use std::net::SocketAddr; +use std::sync::{Arc, Mutex as SyncMutex, Weak}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tracing::debug; /// Result of [`TcpTransport::connect`] call. @@ -84,6 +86,52 @@ impl TcpConnection { } } +pub(crate) struct ConnectionCollector { + hostname_port: HostnamePort, + connections_queue: Weak, + read_half: SyncMutex>, + write_half: SyncMutex>, +} + +impl ConnectionCollector { + fn new(hostname_port: HostnamePort, connections_queue: &Arc) -> Self { + Self { + hostname_port, + connections_queue: Arc::downgrade(connections_queue), + read_half: SyncMutex::new(None), + write_half: SyncMutex::new(None), + } + } + + pub(crate) fn collect_read_half(&self, read_half: OwnedReadHalf) { + debug!("Collecting read half for {}", self.hostname_port); + self.read_half.lock().unwrap().replace(read_half); + self.check_and_push_connection(); + } + + pub(crate) fn collect_write_half(&self, write_half: OwnedWriteHalf) { + debug!("Collecting write half for {}", self.hostname_port); + self.write_half.lock().unwrap().replace(write_half); + self.check_and_push_connection(); + } + + fn check_and_push_connection(&self) { + let mut read_half = self.read_half.lock().unwrap(); + let mut write_half = self.write_half.lock().unwrap(); + + if read_half.is_some() && write_half.is_some() { + if let Some(connections_queue) = self.connections_queue.upgrade() { + let read_half = read_half.take().unwrap(); + let write_half = write_half.take().unwrap(); + + let mut guard = connections_queue.lock().unwrap(); + let connections = guard.entry(self.hostname_port.clone()).or_default(); + connections.push_back((read_half, write_half)); + } + } + } +} + impl TcpTransport { /// Establish an outgoing TCP connection. /// @@ -103,9 +151,25 @@ impl TcpTransport { options: TcpConnectionOptions, ) -> Result { let peer = HostnamePort::from_str(&peer.into())?; - debug!("Connecting to {}", peer.clone()); - let (read_half, write_half) = connect(&peer).await?; + let (read_half, write_half) = { + let connection = { + let mut guard = self.connections.lock().unwrap(); + if let Some(connections) = guard.get_mut(&peer) { + debug!("Reusing existing connection to {}", peer.clone()); + connections.pop_front() + } else { + None + } + }; + + if let Some(read_write_half) = connection { + read_write_half + } else { + connect(&peer).await? + } + }; + let socket = read_half .peer_addr() .map_err(|e| ockam_core::Error::new(Origin::Transport, Kind::Internal, e))?; @@ -118,6 +182,17 @@ impl TcpTransport { let receiver_outgoing_access_control = options.create_receiver_outgoing_access_control(self.ctx.flow_controls()); + let connection_collector = { + if peer.is_localhost() { + None + } else { + Some(Arc::new(ConnectionCollector::new( + peer.clone(), + &self.connections, + ))) + } + }; + TcpSendWorker::start( &self.ctx, self.registry.clone(), @@ -126,6 +201,7 @@ impl TcpTransport { socket, mode, &flow_control_id, + connection_collector.clone(), ) .await?; @@ -138,6 +214,7 @@ impl TcpTransport { mode, &flow_control_id, receiver_outgoing_access_control, + connection_collector, ) .await?; diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/transport/mod.rs b/implementations/rust/ockam/ockam_transport_tcp/src/transport/mod.rs index 9db988c2393..fb3c6c0dda3 100644 --- a/implementations/rust/ockam/ockam_transport_tcp/src/transport/mod.rs +++ b/implementations/rust/ockam/ockam_transport_tcp/src/transport/mod.rs @@ -4,17 +4,23 @@ mod lifecycle; mod listener; mod portals; -pub(crate) use common::*; - pub use crate::portal::options::*; +pub(crate) use common::*; pub use connection::*; pub use listener::*; pub use portals::*; +use std::collections::{HashMap, VecDeque}; +use std::sync::Mutex as SyncMutex; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use crate::TcpRegistry; use ockam_core::compat::sync::Arc; use ockam_core::{async_trait, Result}; use ockam_node::{Context, HasContext}; +use ockam_transport_core::HostnamePort; + +type CachedConnectionsQueue = + SyncMutex>>; /// High level management interface for TCP transports /// @@ -58,6 +64,7 @@ use ockam_node::{Context, HasContext}; pub struct TcpTransport { ctx: Arc, registry: TcpRegistry, + connections: Arc, #[cfg(privileged_portals_support)] pub(crate) ebpf_support: Arc, @@ -71,6 +78,7 @@ impl TcpTransport { registry: TcpRegistry::default(), #[cfg(privileged_portals_support)] ebpf_support: Default::default(), + connections: Default::default(), } } } diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/workers/listener.rs b/implementations/rust/ockam/ockam_transport_tcp/src/workers/listener.rs index 2b590bc79d2..3ee8cfc30e0 100644 --- a/implementations/rust/ockam/ockam_transport_tcp/src/workers/listener.rs +++ b/implementations/rust/ockam/ockam_transport_tcp/src/workers/listener.rs @@ -106,6 +106,7 @@ impl Processor for TcpListenProcessor { peer, mode, &receiver_flow_control_id, + None, ) .await?; @@ -119,6 +120,7 @@ impl Processor for TcpListenProcessor { mode, &receiver_flow_control_id, receiver_outgoing_access_control, + None, ) .await?; diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/workers/receiver.rs b/implementations/rust/ockam/ockam_transport_tcp/src/workers/receiver.rs index 13b8e969319..81af07c8f5f 100644 --- a/implementations/rust/ockam/ockam_transport_tcp/src/workers/receiver.rs +++ b/implementations/rust/ockam/ockam_transport_tcp/src/workers/receiver.rs @@ -1,3 +1,4 @@ +use crate::transport::ConnectionCollector; use crate::transport_message::TcpTransportMessage; use crate::workers::Addresses; use crate::{ @@ -16,7 +17,7 @@ use ockam_core::{Processor, Result}; use ockam_node::{Context, ProcessorBuilder}; use ockam_transport_core::TransportError; use tokio::{io::AsyncReadExt, net::tcp::OwnedReadHalf}; -use tracing::{info, instrument, trace}; +use tracing::{debug, info, instrument, trace}; /// A TCP receiving message processor /// @@ -29,11 +30,12 @@ use tracing::{info, instrument, trace}; pub(crate) struct TcpRecvProcessor { registry: TcpRegistry, incoming_buffer: Vec, - read_half: OwnedReadHalf, + read_half: Option, socket_address: SocketAddr, addresses: Addresses, mode: TcpConnectionMode, flow_control_id: FlowControlId, + connection_collector: Option>, } impl TcpRecvProcessor { @@ -45,15 +47,17 @@ impl TcpRecvProcessor { addresses: Addresses, mode: TcpConnectionMode, flow_control_id: FlowControlId, + connection_collector: Option>, ) -> Self { Self { registry, incoming_buffer: Vec::new(), - read_half, + read_half: Some(read_half), socket_address, addresses, mode, flow_control_id, + connection_collector, } } @@ -68,6 +72,7 @@ impl TcpRecvProcessor { mode: TcpConnectionMode, flow_control_id: &FlowControlId, receiver_outgoing_access_control: Arc, + connection_collector: Option>, ) -> Result<()> { let receiver = TcpRecvProcessor::new( registry, @@ -76,6 +81,7 @@ impl TcpRecvProcessor { addresses.clone(), mode, flow_control_id.clone(), + connection_collector, ); let mailbox = Mailbox::new( @@ -98,7 +104,12 @@ impl TcpRecvProcessor { Ok(()) } - async fn notify_sender_stream_dropped(&self, ctx: &Context, msg: impl Display) -> Result<()> { + async fn notify_sender_stream_dropped( + &mut self, + ctx: &Context, + msg: impl Display, + ) -> Result<()> { + self.read_half.take(); info!( "Connection to peer '{}' was closed; dropping stream. {}", self.socket_address, msg @@ -129,12 +140,16 @@ impl Processor for TcpRecvProcessor { self.flow_control_id.clone(), )); - let protocol_version = match self.read_half.read_u8().await { - Ok(p) => p, - Err(e) => { - self.notify_sender_stream_dropped(ctx, e).await?; - return Err(TransportError::GenericIo)?; + let protocol_version = if let Some(read_half) = self.read_half.as_mut() { + match read_half.read_u8().await { + Ok(p) => p, + Err(e) => { + self.notify_sender_stream_dropped(ctx, e).await?; + return Err(TransportError::GenericIo)?; + } } + } else { + return Err(TransportError::ConnectionDrop)?; }; let _protocol_version = match TcpProtocolVersion::try_from(protocol_version) { @@ -158,8 +173,15 @@ impl Processor for TcpRecvProcessor { #[instrument(skip_all, name = "TcpRecvProcessor::shutdown")] async fn shutdown(&mut self, ctx: &mut Self::Context) -> Result<()> { - self.registry.remove_receiver_processor(&ctx.address()); + if let Some(connection_collector) = self.connection_collector.as_ref() { + if let Some(read_half) = self.read_half.take() { + connection_collector.collect_read_half(read_half); + } else { + debug!("Connection closed, no read half to collect"); + } + } + self.registry.remove_receiver_processor(&ctx.address()); Ok(()) } @@ -177,7 +199,13 @@ impl Processor for TcpRecvProcessor { #[instrument(skip_all, name = "TcpRecvProcessor::process", fields(worker = %ctx.address()))] async fn process(&mut self, ctx: &mut Context) -> Result { // Read the message length - let len = match self.read_half.read_u32().await { + let read_half = if let Some(read_half) = self.read_half.as_mut() { + read_half + } else { + return Ok(false); + }; + + let len = match read_half.read_u32().await { Ok(l) => l, Err(e) => { self.notify_sender_stream_dropped(ctx, e).await?; @@ -217,7 +245,7 @@ impl Processor for TcpRecvProcessor { self.incoming_buffer.resize(len_usize, 0); // Then read into the buffer - match self.read_half.read_exact(&mut self.incoming_buffer).await { + match read_half.read_exact(&mut self.incoming_buffer).await { Ok(_) => {} Err(e) => { self.notify_sender_stream_dropped(ctx, e).await?; diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs b/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs index 9f5cbb0e9a1..b3c76828f58 100644 --- a/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs +++ b/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs @@ -9,12 +9,13 @@ use ockam_core::{ use ockam_core::{Any, Decodable, Mailbox, Mailboxes, Message, Result, Routed, Worker}; use ockam_node::{Context, WorkerBuilder}; +use crate::transport::ConnectionCollector; use crate::transport_message::TcpTransportMessage; use ockam_transport_core::TransportError; use serde::{Deserialize, Serialize}; use tokio::io::AsyncWriteExt; use tokio::net::tcp::OwnedWriteHalf; -use tracing::{info, instrument, trace, warn}; +use tracing::{debug, info, instrument, trace, warn}; #[derive(Serialize, Deserialize, Message, Clone)] pub(crate) enum TcpSendWorkerMsg { @@ -32,12 +33,13 @@ pub(crate) enum TcpSendWorkerMsg { pub(crate) struct TcpSendWorker { buffer: Vec, registry: TcpRegistry, - write_half: OwnedWriteHalf, + write_half: Option, socket_address: SocketAddr, addresses: Addresses, mode: TcpConnectionMode, receiver_flow_control_id: FlowControlId, rx_should_be_stopped: bool, + connection_collector: Option>, } impl TcpSendWorker { @@ -49,15 +51,17 @@ impl TcpSendWorker { addresses: Addresses, mode: TcpConnectionMode, receiver_flow_control_id: FlowControlId, + connection_collector: Option>, ) -> Self { Self { buffer: vec![], registry, - write_half, + write_half: Some(write_half), socket_address, addresses, receiver_flow_control_id, mode, + connection_collector, rx_should_be_stopped: true, } } @@ -76,6 +80,7 @@ impl TcpSendWorker { socket_address: SocketAddr, mode: TcpConnectionMode, receiver_flow_control_id: &FlowControlId, + connection_collector: Option>, ) -> Result<()> { trace!("Creating new TCP worker pair"); let sender_worker = Self::new( @@ -85,6 +90,7 @@ impl TcpSendWorker { addresses.clone(), mode, receiver_flow_control_id.clone(), + connection_collector, ); let main_mailbox = Mailbox::new( @@ -177,19 +183,24 @@ impl Worker for TcpSendWorker { )); // First thing send our protocol version - if self - .write_half - .write_u8(TcpProtocolVersion::V1.into()) - .await - .is_err() - { - warn!( - "Failed to send protocol version to peer {}", - self.socket_address - ); - self.stop(ctx).await?; + if let Some(write_half) = self.write_half.as_mut() { + if write_half + .write_u8(TcpProtocolVersion::V1.into()) + .await + .is_err() + { + warn!( + "Failed to send protocol version to peer {}", + self.socket_address + ); + self.write_half.take(); + self.stop(ctx).await?; - return Ok(()); + return Ok(()); + } + } else { + self.stop(ctx).await?; + return Err(TransportError::ConnectionDrop)?; } Ok(()) @@ -197,6 +208,14 @@ impl Worker for TcpSendWorker { #[instrument(skip_all, name = "TcpSendWorker::shutdown")] async fn shutdown(&mut self, ctx: &mut Self::Context) -> Result<()> { + if let Some(connection_collector) = self.connection_collector.as_ref() { + if let Some(write_half) = self.write_half.take() { + connection_collector.collect_write_half(write_half); + } else { + debug!("Connection closed, no read write to collect"); + } + } + self.registry .remove_sender_worker(self.addresses.sender_address()); @@ -230,6 +249,7 @@ impl Worker for TcpSendWorker { // No need to stop Receiver as it notified us about connection drop and will // stop itself self.rx_should_be_stopped = false; + self.write_half.take(); self.stop(ctx).await?; return Ok(()); @@ -243,16 +263,23 @@ impl Worker for TcpSendWorker { if let Err(err) = self.serialize_message(local_message) { // Close the stream + self.write_half.take(); self.stop(ctx).await?; return Err(err); }; - if self.write_half.write_all(&self.buffer).await.is_err() { - warn!("Failed to send message to peer {}", self.socket_address); - self.stop(ctx).await?; + if let Some(write_half) = self.write_half.as_mut() { + if write_half.write_all(&self.buffer).await.is_err() { + warn!("Failed to send message to peer {}", self.socket_address); + self.write_half.take(); + self.stop(ctx).await?; - return Ok(()); + return Ok(()); + } + } else { + self.stop(ctx).await?; + return Err(TransportError::ConnectionDrop)?; } }