diff --git a/src/router/asynchronous.rs b/src/router/asynchronous.rs index 550312c0..39823a7c 100644 --- a/src/router/asynchronous.rs +++ b/src/router/asynchronous.rs @@ -236,11 +236,13 @@ impl NlRouter { .nl_seq(self.next_seq().await) .nl_payload(nl_payload) .build()?; - let flags = *msg.nl_flags(); let seq = *msg.nl_seq(); - self.socket.send(&msg).await?; let (sender, receiver) = channel(1024); self.senders.lock().await.insert(seq, sender); + let flags = *msg.nl_flags(); + + self.socket.send(&msg).await?; + Ok(NlRouterReceiverHandle::new( receiver, Arc::clone(&self.senders), diff --git a/src/router/synchronous.rs b/src/router/synchronous.rs index ab65db06..25b34fb5 100644 --- a/src/router/synchronous.rs +++ b/src/router/synchronous.rs @@ -92,6 +92,16 @@ fn spawn_processing_thread(socket: Arc, senders: Senders) -> Pro } } } + } else { + for (seq, sender) in lock.iter() { + if sender + .send(Err(RouterError::BadSeqOrPid(m.clone()))) + .is_err() + { + error!("{}", RouterError::::ClosedChannel); + seqs_to_remove.insert(*seq); + } + } } } Err(e) => { @@ -221,11 +231,14 @@ impl NlRouter { .nl_seq(self.next_seq()) .nl_payload(nl_payload) .build()?; - let flags = *msg.nl_flags(); - let seq = *msg.nl_seq(); - self.socket.send(&msg)?; + let (sender, receiver) = channel(); + let seq = *msg.nl_seq(); self.senders.lock().insert(seq, sender); + let flags = *msg.nl_flags(); + + self.socket.send(&msg)?; + Ok(NlRouterReceiverHandle::new( receiver, Arc::clone(&self.senders), @@ -474,7 +487,6 @@ where self.next_is_ack = true; } else { self.next_is_none = true; - return None; } } else if self.next_is_ack { self.next_is_none = true; diff --git a/src/rtnl.rs b/src/rtnl.rs index 71381dca..06402941 100644 --- a/src/rtnl.rs +++ b/src/rtnl.rs @@ -488,13 +488,15 @@ mod test { .unwrap(); for msg in recv { let msg = msg.unwrap(); - let handle = msg.get_payload().unwrap().rtattrs.get_attr_handle(); - handle - .get_attr_payload_as_with_len::(Ifla::Ifname) - .unwrap(); - // Assert length of ethernet address - if let Ok(attr) = handle.get_attr_payload_as_with_len::>(Ifla::Address) { - assert_eq!(attr.len(), 6); + if let Some(payload) = msg.get_payload() { + let handle = payload.rtattrs.get_attr_handle(); + handle + .get_attr_payload_as_with_len::(Ifla::Ifname) + .unwrap(); + // Assert length of ethernet address + if let Ok(attr) = handle.get_attr_payload_as_with_len::>(Ifla::Address) { + assert_eq!(attr.len(), 6); + } } } } @@ -523,8 +525,11 @@ mod test { .unwrap(); for msg in recv { let msg = msg.unwrap(); - assert!(matches!(msg.get_payload().unwrap(), Tcmsg { .. })); - assert_eq!(msg.nl_type(), &Rtm::Newqdisc); + assert!(matches!(msg.get_payload(), Some(Tcmsg { .. }) | None)); + assert!(matches!( + msg.nl_type(), + Rtm::Newqdisc | Rtm::UnrecognizedConst(3) + )); } } }