Skip to content
This repository has been archived by the owner on Oct 26, 2022. It is now read-only.

Commit

Permalink
Implement Sink for batched messages on NetlinkFramed
Browse files Browse the repository at this point in the history
Move some data and poll functions to hidden inner struct (they don't
need to be generic over codec and message type).
  • Loading branch information
stbuehler committed Nov 25, 2021
1 parent 64b3168 commit a937896
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 75 deletions.
60 changes: 60 additions & 0 deletions netlink-proto/src/codecs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,4 +279,64 @@ mod tests {
16 /* header */ + 4, /* padded data */
);
}

fn encode_batch_with_prefix<D>(prefix_len: usize, data: Vec<NetlinkMessage<D>>) -> BytesMut
where
D: NetlinkSerializable + std::fmt::Debug,
{
let mut buf = BytesMut::new();
// for now encoder doesn't require buffer to be "pre-aligned"; allow this
// to be tested by different (unaligned) "prefixes"
buf.resize(prefix_len, 0x7f);
for frame in data {
NetlinkCodec::encode(frame, &mut buf).unwrap();
}
buf
}

fn test_batch_encode(
prefix_len: usize,
data: Vec<NetlinkMessage<MsgNever>>,
expected_msg_len: usize,
) {
let result = encode_batch_with_prefix(prefix_len, data.clone());
assert_eq!(result.len(), prefix_len + expected_msg_len,);
test_decode(&result[prefix_len..], &data);
}

#[test]
fn test_batch_encoding_unaligned1() {
test_batch_encode(
1,
vec![overrun_msg_with_len(1), overrun_msg_with_len(4)],
2 * 16 + 2 * 4,
);
}

#[test]
fn test_batch_encoding_unaligned2() {
test_batch_encode(
1,
vec![overrun_msg_with_len(2), overrun_msg_with_len(4)],
2 * 16 + 2 * 4,
);
}

#[test]
fn test_batch_encoding_unaligned3() {
test_batch_encode(
1,
vec![overrun_msg_with_len(3), overrun_msg_with_len(4)],
2 * 16 + 2 * 4,
);
}

#[test]
fn test_batch_encoding_unaligned4() {
test_batch_encode(
1,
vec![overrun_msg_with_len(4), overrun_msg_with_len(4)],
2 * 16 + 2 * 4,
);
}
}
12 changes: 10 additions & 2 deletions netlink-proto/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ where

while !protocol.outgoing_messages.is_empty() {
trace!("found outgoing message to send checking if socket is ready");
if let Poll::Ready(Err(e)) = Pin::as_mut(&mut socket).poll_ready(cx) {
if let Poll::Ready(Err(e)) = <NetlinkFramed<T, S, C> as Sink<(
NetlinkMessage<T>,
SocketAddr,
)>>::poll_ready(Pin::as_mut(&mut socket), cx)
{
// Sink errors are usually not recoverable. The socket
// probably shut down.
warn!("netlink socket shut down: {:?}", e);
Expand All @@ -115,7 +119,11 @@ where

pub fn poll_flush(&mut self, cx: &mut Context) {
trace!("poll_flush called");
if let Poll::Ready(Err(e)) = Pin::new(&mut self.socket).poll_flush(cx) {
if let Poll::Ready(Err(e)) = <NetlinkFramed<T, S, C> as Sink<(
NetlinkMessage<T>,
SocketAddr,
)>>::poll_flush(Pin::new(&mut self.socket), cx)
{
warn!("error flushing netlink socket: {:?}", e);
self.socket_closed = true;
}
Expand Down
204 changes: 131 additions & 73 deletions netlink-proto/src/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,87 @@ use crate::{
};
use netlink_packet_core::{NetlinkDeserializable, NetlinkMessage, NetlinkSerializable};

pub struct NetlinkFramed<T, S, C> {
// some functions don't need the message type and codec type
struct FramedIO<S> {
socket: S,
// see https://doc.rust-lang.org/nomicon/phantom-data.html
// "invariant" seems like the safe choice; using `fn(T) -> T`
// should make it invariant but still Send+Sync.
msg_type: PhantomData<fn(T) -> T>, // invariant
codec: PhantomData<fn(C) -> C>, // invariant
reader: BytesMut,
writer: BytesMut,
in_addr: SocketAddr,
out_addr: SocketAddr,
flushed: bool,
}

impl<S> FramedIO<S>
where
S: AsyncSocket,
{
fn poll_next_datagram(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.reader.clear();
self.reader.reserve(INITIAL_READER_CAPACITY);

self.in_addr = match ready!(self.socket.poll_recv_from(cx, &mut self.reader)) {
Ok(addr) => addr,
Err(e) => {
error!("failed to read from netlink socket: {:?}", e);
return Poll::Ready(Err(()));
}
};
Poll::Ready(Ok(()))
}

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if !self.flushed {
match self.poll_flush(cx)? {
Poll::Ready(()) => {}
Poll::Pending => return Poll::Pending,
}
}

Poll::Ready(Ok(()))
}

fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if self.flushed {
return Poll::Ready(Ok(()));
}

trace!("flushing frame; length={}", self.writer.len());
let Self {
ref mut socket,
ref mut out_addr,
ref mut writer,
..
} = *self;

let n = ready!(socket.poll_send_to(cx, writer, out_addr))?;
trace!("written {}", n);

let wrote_all = n == self.writer.len();
self.writer.clear();
self.flushed = true;

let res = if wrote_all {
Ok(())
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"failed to write entire datagram to socket",
))
};

Poll::Ready(res)
}
}

pub struct NetlinkFramed<T, S, C> {
io: FramedIO<S>,
// see https://doc.rust-lang.org/nomicon/phantom-data.html
// "invariant" seems like the safe choice; using `fn(T) -> T`
// should make it invariant but still Send+Sync.
msg_type: PhantomData<fn(T) -> T>, // invariant
codec: PhantomData<fn(C) -> C>, // invariant
}

impl<T, S, C> Stream for NetlinkFramed<T, S, C>
where
T: NetlinkDeserializable + Debug,
Expand All @@ -41,33 +108,21 @@ where
type Item = (NetlinkMessage<T>, SocketAddr);

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let Self {
ref mut socket,
ref mut in_addr,
ref mut reader,
..
} = Pin::get_mut(self);
let Self { ref mut io, .. } = Pin::get_mut(self);

loop {
match C::decode::<T>(reader) {
Ok(Some(item)) => return Poll::Ready(Some((item, *in_addr))),
match C::decode::<T>(&mut io.reader) {
Ok(Some(item)) => return Poll::Ready(Some((item, io.in_addr))),
Ok(None) => {}
Err(e) => {
error!("unrecoverable error in decoder: {:?}", e);
return Poll::Ready(None);
}
}

reader.clear();
reader.reserve(INITIAL_READER_CAPACITY);

*in_addr = match ready!(socket.poll_recv_from(cx, reader)) {
Ok(addr) => addr,
Err(e) => {
error!("failed to read from netlink socket: {:?}", e);
return Poll::Ready(None);
}
};
if let Err(()) = ready!(io.poll_next_datagram(cx)) {
return Poll::Ready(None);
}
}
}
}
Expand All @@ -80,15 +135,8 @@ where
{
type Error = io::Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if !self.flushed {
match self.poll_flush(cx)? {
Poll::Ready(()) => {}
Poll::Pending => return Poll::Pending,
}
}

Poll::Ready(Ok(()))
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.io.poll_ready(cx)
}

fn start_send(
Expand All @@ -98,48 +146,56 @@ where
trace!("sending frame");
let (frame, out_addr) = item;
let pin = self.get_mut();
C::encode(frame, &mut pin.writer)?;
pin.out_addr = out_addr;
pin.flushed = false;
trace!("frame encoded; length={}", pin.writer.len());
C::encode(frame, &mut pin.io.writer)?;
pin.io.out_addr = out_addr;
pin.io.flushed = false;
trace!("frame encoded; length={}", pin.io.writer.len());
Ok(())
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.flushed {
return Poll::Ready(Ok(()));
}
self.io.poll_flush(cx)
}

trace!("flushing frame; length={}", self.writer.len());
let Self {
ref mut socket,
ref mut out_addr,
ref mut writer,
..
} = *self;
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.io.poll_flush(cx)
}
}

let n = ready!(socket.poll_send_to(cx, writer, out_addr))?;
trace!("written {}", n);
impl<T, S, C> Sink<(Vec<NetlinkMessage<T>>, SocketAddr)> for NetlinkFramed<T, S, C>
where
T: NetlinkSerializable + Debug,
S: AsyncSocket,
C: NetlinkMessageCodec,
{
type Error = io::Error;

let wrote_all = n == self.writer.len();
self.writer.clear();
self.flushed = true;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.io.poll_ready(cx)
}

let res = if wrote_all {
Ok(())
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"failed to write entire datagram to socket",
))
};
fn start_send(
self: Pin<&mut Self>,
item: (Vec<NetlinkMessage<T>>, SocketAddr),
) -> Result<(), Self::Error> {
trace!("sending frame");
let (frames, out_addr) = item;
let pin = self.get_mut();
for frame in frames {
C::encode(frame, &mut pin.io.writer)?;
}
pin.io.out_addr = out_addr;
pin.io.flushed = false;
trace!("frame encoded; length={}", pin.io.writer.len());
Ok(())
}

Poll::Ready(res)
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.io.poll_flush(cx)
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.poll_flush(cx))?;
Poll::Ready(Ok(()))
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.io.poll_flush(cx)
}
}

Expand All @@ -155,14 +211,16 @@ impl<T, S, C> NetlinkFramed<T, S, C> {
/// See struct level documentation for more details.
pub fn new(socket: S) -> Self {
Self {
socket,
io: FramedIO {
socket,
out_addr: SocketAddr::new(0, 0),
in_addr: SocketAddr::new(0, 0),
reader: BytesMut::with_capacity(INITIAL_READER_CAPACITY),
writer: BytesMut::with_capacity(INITIAL_WRITER_CAPACITY),
flushed: true,
},
msg_type: PhantomData,
codec: PhantomData,
out_addr: SocketAddr::new(0, 0),
in_addr: SocketAddr::new(0, 0),
reader: BytesMut::with_capacity(INITIAL_READER_CAPACITY),
writer: BytesMut::with_capacity(INITIAL_WRITER_CAPACITY),
flushed: true,
}
}

Expand All @@ -174,7 +232,7 @@ impl<T, S, C> NetlinkFramed<T, S, C> {
/// coming in as it may corrupt the stream of frames otherwise being worked
/// with.
pub fn get_ref(&self) -> &S {
&self.socket
&self.io.socket
}

/// Returns a mutable reference to the underlying I/O stream wrapped by
Expand All @@ -186,11 +244,11 @@ impl<T, S, C> NetlinkFramed<T, S, C> {
/// coming in as it may corrupt the stream of frames otherwise being worked
/// with.
pub fn get_mut(&mut self) -> &mut S {
&mut self.socket
&mut self.io.socket
}

/// Consumes the `Framed`, returning its underlying I/O stream.
pub fn into_inner(self) -> S {
self.socket
self.io.socket
}
}

0 comments on commit a937896

Please sign in to comment.