Skip to content
This repository has been archived by the owner on Oct 26, 2022. It is now read-only.

Commit

Permalink
Allow sending batch of messages through connection handle
Browse files Browse the repository at this point in the history
This is motivated by netfilter; changes to netfilter must be done
through a series of messages (started by NFNL_MSG_BATCH_BEGIN, ended by
NFNL_MSG_BATCH_END).  The full batch needs to be submitted to the kernel
in one write/sendto/..., otherwise the kernel will abort the batch.

(And sending a batch without an END message is interpreted as a query to
verify the batch without actually committing it.)
  • Loading branch information
stbuehler committed Oct 30, 2021
1 parent 64ede76 commit 7253106
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 19 deletions.
30 changes: 23 additions & 7 deletions netlink-proto/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::{
codecs::NetlinkCodec,
framed::NetlinkFramed,
sys::{Socket, SocketAddr},
BatchQueueElem,
Protocol,
Request,
Response,
Expand Down Expand Up @@ -92,14 +93,29 @@ where
return;
}

let (mut message, addr) = protocol.outgoing_messages.pop_front().unwrap();
message.finalize();
match protocol.outgoing_messages.pop_front().unwrap() {
BatchQueueElem::Single(mut message, addr) => {
message.finalize();

trace!("sending outgoing message");
if let Err(e) = Pin::as_mut(&mut socket).start_send((message, addr)) {
error!("failed to send message: {:?}", e);
self.socket_closed = true;
return;
trace!("sending outgoing message");
if let Err(e) = Pin::as_mut(&mut socket).start_send((message, addr)) {
error!("failed to send message: {:?}", e);
self.socket_closed = true;
return;
}
}
BatchQueueElem::Batch(mut messages, addr) => {
for message in &mut messages {
message.finalize();
}

trace!("sending outgoing message");
if let Err(e) = Pin::as_mut(&mut socket).start_send((messages, addr)) {
error!("failed to send message: {:?}", e);
self.socket_closed = true;
return;
}
}
}
}

Expand Down
69 changes: 69 additions & 0 deletions netlink-proto/src/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ where
Ok(rx)
}

/// Start a batch of messages
///
/// Collects multiple messages to be sent in one "request".
pub fn batch(&self, destination: SocketAddr) -> BatchHandle<T> {
BatchHandle {
handle: self.clone(),
destination,
messages: Vec::new(),
channels: Vec::new(),
}
}

pub fn notify(
&mut self,
message: NetlinkMessage<T>,
Expand All @@ -76,3 +88,60 @@ where
.map_err(|_| ErrorKind::ConnectionClosed.into())
}
}

/// A handle to create a batch request (multiple requests serialized in one buffer)
#[derive(Debug)]
#[must_use = "A batch of messages must be sent to actually do something"]
pub struct BatchHandle<T>
where
T: Debug + Clone + Eq + PartialEq,
{
handle: ConnectionHandle<T>,
destination: SocketAddr,
messages: Vec<NetlinkMessage<T>>,
channels: Vec<UnboundedSender<NetlinkMessage<T>>>,
}

impl<T> BatchHandle<T>
where
T: Debug + Clone + Eq + PartialEq,
{
/// Add a new request to the batch and get the response as a stream of messages.
///
/// Similar to [`ConnectionHandle::request`].
pub fn request(&mut self, message: NetlinkMessage<T>) -> impl Stream<Item = NetlinkMessage<T>> {
let (tx, rx) = unbounded::<NetlinkMessage<T>>();
self.messages.push(message);
self.channels.push(tx);
rx
}

/// Add a new request to the batch, but ignore response messages
///
/// Similar to [`ConnectionHandle::notify`].
pub fn notify(&mut self, message: NetlinkMessage<T>) {
let _ = self.request(message);
}

/// Send of batch request
pub fn send(self) -> Result<(), Error<T>> {
debug!("handle: forwarding new request to connection");
let request = Request::Batch {
metadata: self.channels,
messages: self.messages,
destination: self.destination,
};
UnboundedSender::unbounded_send(&self.handle.requests_tx, request).map_err(|e| {
// the channel is unbounded, so it can't be full. If this
// failed, it means the Connection shut down.
if e.is_full() {
panic!("internal error: unbounded channel full?!");
} else if e.is_disconnected() {
Error::from(ErrorKind::ConnectionClosed)
} else {
panic!("unknown error: {:?}", e);
}
})?;
Ok(())
}
}
2 changes: 1 addition & 1 deletion netlink-proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ mod framed;
pub use crate::framed::*;

mod protocol;
pub(crate) use self::protocol::{Protocol, Response};
pub(crate) use self::protocol::{BatchQueueElem, Protocol, Response};
pub(crate) type Request<T> =
self::protocol::Request<T, UnboundedSender<crate::packet::NetlinkMessage<T>>>;

Expand Down
2 changes: 1 addition & 1 deletion netlink-proto/src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
mod protocol;
mod request;

pub(crate) use protocol::{Protocol, Response};
pub(crate) use protocol::{BatchQueueElem, Protocol, Response};
pub(crate) use request::Request;
53 changes: 43 additions & 10 deletions netlink-proto/src/protocol/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ struct PendingRequest<M> {
metadata: M,
}

#[derive(Debug)]
pub enum BatchQueueElem<T>
where
T: Debug + Clone + PartialEq + Eq + NetlinkSerializable<T> + NetlinkDeserializable<T>,
{
Single(NetlinkMessage<T>, SocketAddr),
Batch(Vec<NetlinkMessage<T>>, SocketAddr),
}

#[derive(Debug, Default)]
pub(crate) struct Protocol<T, M>
where
Expand All @@ -66,7 +75,7 @@ where
pub incoming_requests: VecDeque<(NetlinkMessage<T>, SocketAddr)>,

/// The messages to be sent out
pub outgoing_messages: VecDeque<(NetlinkMessage<T>, SocketAddr)>,
pub outgoing_messages: VecDeque<BatchQueueElem<T>>,
}

impl<T, M> Protocol<T, M>
Expand Down Expand Up @@ -136,17 +145,15 @@ where
debug!("done handling response to request {:?}", request_id);
}

pub fn request(&mut self, request: Request<T, M>) {
let Request::Single {
mut message,
metadata,
destination,
} = request;

self.set_sequence_id(&mut message);
fn request_single(
&mut self,
message: &mut NetlinkMessage<T>,
metadata: M,
destination: &SocketAddr,
) {
self.set_sequence_id(message);
let request_id = RequestId::new(self.sequence_id, destination.port_number());
let flags = message.header.flags;
self.outgoing_messages.push_back((message, destination));

// If we expect a response, we store the request id so that we
// can map the response to this specific request.
Expand All @@ -170,6 +177,32 @@ where
}
}

pub fn request(&mut self, request: Request<T, M>) {
match request {
Request::Single {
mut message,
metadata,
destination,
} => {
self.request_single(&mut message, metadata, &destination);
self.outgoing_messages
.push_back(BatchQueueElem::Single(message, destination));
}
Request::Batch {
mut messages,
metadata,
destination,
} => {
assert_eq!(messages.len(), metadata.len());
for (msg, md) in messages.iter_mut().zip(metadata.into_iter()) {
self.request_single(msg, md, &destination);
}
self.outgoing_messages
.push_back(BatchQueueElem::Batch(messages, destination));
}
}
}

fn set_sequence_id(&mut self, message: &mut NetlinkMessage<T>) {
self.sequence_id += 1;
message.header.sequence_number = self.sequence_id;
Expand Down
5 changes: 5 additions & 0 deletions netlink-proto/src/protocol/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,9 @@ where
message: NetlinkMessage<T>,
destination: SocketAddr,
},
Batch {
metadata: Vec<M>,
messages: Vec<NetlinkMessage<T>>,
destination: SocketAddr,
},
}

0 comments on commit 7253106

Please sign in to comment.