From a220a69e65f1b2a769b288ad4a7dabec0158804d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Fri, 26 Nov 2021 00:00:40 +0100 Subject: [PATCH] Allow sending batch of messages through connection handle This is motivated by netfilter; changes to netfilter must be done through a series of messages (started by NFNL_MSG_BATCH_BEGIN, ended by NFNL_MSG_BATCH_END). The full batch needs to be submitted to the kernel in one write/sendto/..., otherwise the kernel will abort the batch. (And sending a batch without an END message is interpreted as a query to verify the batch without actually committing it.) --- netlink-proto/src/connection.rs | 12 +++++ netlink-proto/src/handle.rs | 74 ++++++++++++++++++++++++++ netlink-proto/src/protocol/protocol.rs | 13 +++++ netlink-proto/src/protocol/request.rs | 5 ++ 4 files changed, 104 insertions(+) diff --git a/netlink-proto/src/connection.rs b/netlink-proto/src/connection.rs index f9b0e1f8..50eeabe5 100644 --- a/netlink-proto/src/connection.rs +++ b/netlink-proto/src/connection.rs @@ -117,6 +117,18 @@ where return; } } + OutgoingMessage::Batch(mut messages, addr) => { + for message in &mut messages { + message.finalize(); + } + + trace!("sending outgoing message"); + if let Err(e) = Pin::as_mut(&mut socket).start_send((messages, addr)) { + error!("failed to send message: {:?}", e); + self.socket_closed = true; + return; + } + } } } diff --git a/netlink-proto/src/handle.rs b/netlink-proto/src/handle.rs index 2c484867..458db0af 100644 --- a/netlink-proto/src/handle.rs +++ b/netlink-proto/src/handle.rs @@ -58,6 +58,18 @@ where Ok(rx) } + /// Start a batch of messages + /// + /// Collects multiple messages to be sent in one "request". + pub fn batch(&self, destination: SocketAddr) -> BatchHandle { + BatchHandle { + handle: self.clone(), + destination, + messages: Vec::new(), + channels: Vec::new(), + } + } + pub fn notify( &mut self, message: NetlinkMessage, @@ -83,3 +95,65 @@ impl Clone for ConnectionHandle { } } } + +/// A handle to create a batch request (multiple requests serialized in one buffer) +/// +/// The request needs to be [`sent`](`BatchHandle::send`) to actually do something. +#[derive(Debug)] +#[must_use = "A batch of messages must be sent to actually do something"] +pub struct BatchHandle +where + T: Debug, +{ + handle: ConnectionHandle, + destination: SocketAddr, + messages: Vec>, + channels: Vec>>, +} + +impl BatchHandle +where + T: Debug, +{ + /// Add a new request to the batch and get the response as a stream of messages. + /// + /// Similar to [`ConnectionHandle::request`]. + /// + /// Response stream will block until batch request is sent, and will be empty + /// if the batch request is dropped. + pub fn request(&mut self, message: NetlinkMessage) -> impl Stream> { + let (tx, rx) = unbounded::>(); + self.messages.push(message); + self.channels.push(tx); + rx + } + + /// Add a new request to the batch, but ignore response messages + /// + /// Similar to [`ConnectionHandle::notify`]. + pub fn notify(&mut self, message: NetlinkMessage) { + let _ = self.request(message); + } + + /// Send batch request + pub fn send(self) -> Result<(), Error> { + debug!("handle: forwarding new request to connection"); + let request = Request::Batch { + metadata: self.channels, + messages: self.messages, + destination: self.destination, + }; + UnboundedSender::unbounded_send(&self.handle.requests_tx, request).map_err(|e| { + // the channel is unbounded, so it can't be full. If this + // failed, it means the Connection shut down. + if e.is_full() { + panic!("internal error: unbounded channel full?!"); + } else if e.is_disconnected() { + Error::ConnectionClosed + } else { + panic!("unknown error: {:?}", e); + } + })?; + Ok(()) + } +} diff --git a/netlink-proto/src/protocol/protocol.rs b/netlink-proto/src/protocol/protocol.rs index c7d0203d..b502cd78 100644 --- a/netlink-proto/src/protocol/protocol.rs +++ b/netlink-proto/src/protocol/protocol.rs @@ -47,6 +47,7 @@ struct PendingRequest { #[derive(Debug)] pub(crate) enum OutgoingMessage { Single(NetlinkMessage, SocketAddr), + Batch(Vec>, SocketAddr), } #[derive(Debug, Default)] @@ -178,6 +179,18 @@ where self.outgoing_messages .push_back(OutgoingMessage::Single(message, destination)); } + Request::Batch { + mut messages, + metadata, + destination, + } => { + assert_eq!(messages.len(), metadata.len()); + for (msg, md) in messages.iter_mut().zip(metadata.into_iter()) { + self.request_single(msg, md, &destination); + } + self.outgoing_messages + .push_back(OutgoingMessage::Batch(messages, destination)); + } } } diff --git a/netlink-proto/src/protocol/request.rs b/netlink-proto/src/protocol/request.rs index 6fab4f5a..832befff 100644 --- a/netlink-proto/src/protocol/request.rs +++ b/netlink-proto/src/protocol/request.rs @@ -13,4 +13,9 @@ pub(crate) enum Request { message: NetlinkMessage, destination: SocketAddr, }, + Batch { + metadata: Vec, + messages: Vec>, + destination: SocketAddr, + }, }