From 6ed02934e39ddc913f22ad1d6197153c8e55fb4a Mon Sep 17 00:00:00 2001 From: h7x4 Date: Wed, 23 Aug 2023 15:51:48 +0200 Subject: [PATCH] Add support for modifying devicelist during execution --- Cargo.toml | 2 +- examples/cdc_acm_serial.rs | 18 +- examples/hid_keyboard.rs | 28 +-- examples/host.rs | 3 +- src/lib.rs | 375 ++++++++++++++++++++++++++++++------- src/util.rs | 19 +- 6 files changed, 353 insertions(+), 92 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f53fa56..de692d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ description = "A library to run USB/IP server" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tokio = { version = "1.22.0", features = ["rt", "net", "io-util"] } +tokio = { version = "1.22.0", features = ["rt", "net", "io-util", "sync"] } log = "0.4.17" num-traits = "0.2.15" num-derive = "0.3.3" diff --git a/examples/cdc_acm_serial.rs b/examples/cdc_acm_serial.rs index 880bb6c..06a9458 100644 --- a/examples/cdc_acm_serial.rs +++ b/examples/cdc_acm_serial.rs @@ -11,14 +11,16 @@ async fn main() { let handler = Arc::new(Mutex::new(Box::new(usbip::cdc::UsbCdcAcmHandler::new()) as Box)); - let server = usbip::UsbIpServer::new_simulated(vec![usbip::UsbDevice::new(0).with_interface( - usbip::ClassCode::CDC as u8, - usbip::cdc::CDC_ACM_SUBCLASS, - 0x00, - "Test CDC ACM", - usbip::cdc::UsbCdcAcmHandler::endpoints(), - handler.clone(), - )]); + let server = Arc::new(usbip::UsbIpServer::new_simulated(vec![ + usbip::UsbDevice::new(0).with_interface( + usbip::ClassCode::CDC as u8, + usbip::cdc::CDC_ACM_SUBCLASS, + 0x00, + "Test CDC ACM", + usbip::cdc::UsbCdcAcmHandler::endpoints(), + handler.clone(), + ), + ])); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 3240); tokio::spawn(usbip::server(addr, server)); diff --git a/examples/hid_keyboard.rs b/examples/hid_keyboard.rs index 923dd10..3b5c6df 100644 --- a/examples/hid_keyboard.rs +++ b/examples/hid_keyboard.rs @@ -12,19 +12,21 @@ async fn main() { Box::new(usbip::hid::UsbHidKeyboardHandler::new_keyboard()) as Box, )); - let server = usbip::UsbIpServer::new_simulated(vec![usbip::UsbDevice::new(0).with_interface( - usbip::ClassCode::HID as u8, - 0x00, - 0x00, - "Test HID", - vec![usbip::UsbEndpoint { - address: 0x81, // IN - attributes: 0x03, // Interrupt - max_packet_size: 0x08, // 8 bytes - interval: 10, - }], - handler.clone(), - )]); + let server = Arc::new(usbip::UsbIpServer::new_simulated(vec![ + usbip::UsbDevice::new(0).with_interface( + usbip::ClassCode::HID as u8, + 0x00, + 0x00, + "Test HID", + vec![usbip::UsbEndpoint { + address: 0x81, // IN + attributes: 0x03, // Interrupt + max_packet_size: 0x08, // 8 bytes + interval: 10, + }], + handler.clone(), + ), + ])); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 3240); tokio::spawn(usbip::server(addr, server)); diff --git a/examples/host.rs b/examples/host.rs index df0e07e..235947b 100644 --- a/examples/host.rs +++ b/examples/host.rs @@ -1,12 +1,13 @@ use env_logger; use std::net::*; +use std::sync::Arc; use std::time::Duration; use usbip; #[tokio::main] async fn main() { env_logger::init(); - let server = usbip::UsbIpServer::new_from_host(); + let server = Arc::new(usbip::UsbIpServer::new_from_host()); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 3240); tokio::spawn(usbip::server(addr, server)); diff --git a/src/lib.rs b/src/lib.rs index d98fdf9..afcf740 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ use std::sync::{Arc, Mutex}; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::net::TcpListener; +use tokio::sync::RwLock; pub mod cdc; mod consts; @@ -31,14 +32,19 @@ pub use setup::*; pub use util::*; /// Main struct of a USB/IP server +#[derive(Default)] pub struct UsbIpServer { - devices: Vec, + available_devices: RwLock>, + used_devices: RwLock>, } impl UsbIpServer { /// Create a [UsbIpServer] with simulated devices pub fn new_simulated(devices: Vec) -> Self { - Self { devices } + Self { + available_devices: RwLock::new(devices), + used_devices: RwLock::new(HashMap::new()), + } } fn with_devices(device_list: Vec>) -> Vec { @@ -48,7 +54,7 @@ impl UsbIpServer { let open_device = match dev.open() { Ok(dev) => dev, Err(err) => { - println!("Impossible to share {:?}: {}", dev, err); + warn!("Impossible to share {:?}: {}", dev, err); continue; } }; @@ -180,10 +186,11 @@ impl UsbIpServer { devs.push(d) } Self { - devices: Self::with_devices(devs), + available_devices: RwLock::new(Self::with_devices(devs)), + ..Default::default() } } - Err(_) => Self { devices: vec![] }, + Err(_) => Default::default(), } } @@ -198,22 +205,61 @@ impl UsbIpServer { devs.push(d) } Self { - devices: Self::with_devices(devs), + available_devices: RwLock::new(Self::with_devices(devs)), + ..Default::default() } } - Err(_) => Self { devices: vec![] }, + Err(_) => Default::default(), + } + } + + pub async fn add_device(&self, device: UsbDevice) { + self.available_devices.write().await.push(device); + } + + pub async fn remove_device(&self, bus_id: &str) -> Result<()> { + let mut available_devices = self.available_devices.write().await; + + if let Some(device) = available_devices.iter().position(|d| d.bus_id == bus_id) { + available_devices.remove(device); + Ok(()) + } else if let Some(device) = self + .used_devices + .read() + .await + .values() + .find(|d| d.bus_id == bus_id) + { + Err(std::io::Error::new( + ErrorKind::Other, + format!("Device {} is in use", device.bus_id), + )) + } else { + Err(std::io::Error::new( + ErrorKind::NotFound, + format!("Device {} not found", bus_id), + )) } } } -async fn handler( +pub async fn handler( mut socket: &mut T, server: Arc, ) -> Result<()> { - let mut current_import_device = None; + let mut current_import_device_id: Option = None; loop { let mut command = [0u8; 4]; if let Err(err) = socket.read_exact(&mut command).await { + if let Some(dev_id) = current_import_device_id { + let mut used_devices = server.used_devices.write().await; + let mut available_devices = server.available_devices.write().await; + match used_devices.remove(&dev_id) { + Some(dev) => available_devices.push(dev), + None => unreachable!(), + } + } + if err.kind() == ErrorKind::UnexpectedEof { info!("Remote closed the connection"); return Ok(()); @@ -221,16 +267,24 @@ async fn handler( return Err(err); } } + + let used_devices = server.used_devices.read().await; + let mut current_import_device = current_import_device_id + .clone() + .and_then(|ref id| used_devices.get(id)); + match command { [0x01, 0x11, 0x80, 0x05] => { trace!("Got OP_REQ_DEVLIST"); let _status = socket.read_u32().await?; + let devices = server.available_devices.read().await; + // OP_REP_DEVLIST socket.write_u32(0x01110005).await?; socket.write_u32(0).await?; - socket.write_u32(server.devices.len() as u32).await?; - for dev in &server.devices { + socket.write_u32(devices.len() as u32).await?; + for dev in devices.iter() { dev.write_dev_with_interfaces(&mut socket).await?; } trace!("Sent OP_REP_DEVLIST"); @@ -240,13 +294,22 @@ async fn handler( let _status = socket.read_u32().await?; let mut bus_id = [0u8; 32]; socket.read_exact(&mut bus_id).await?; + + current_import_device_id = None; current_import_device = None; - for device in &server.devices { - let mut expected = device.bus_id.as_bytes().to_vec(); + std::mem::drop(used_devices); + + let mut used_devices = server.used_devices.write().await; + let mut available_devices = server.available_devices.write().await; + for (i, dev) in available_devices.iter().enumerate() { + let mut expected = dev.bus_id.as_bytes().to_vec(); expected.resize(32, 0); if expected == bus_id { - current_import_device = Some(device); - info!("Found device {:?}", device.path); + let dev = available_devices.remove(i); + let dev_id = dev.bus_id.clone(); + used_devices.insert(dev.bus_id.clone(), dev); + current_import_device_id = dev_id.clone().into(); + current_import_device = Some(used_devices.get(&dev_id).unwrap()); break; } } @@ -347,6 +410,22 @@ async fn handler( let mut padding = [0u8; 6 * 4]; socket.read_exact(&mut padding).await?; + std::mem::drop(used_devices); + + let mut used_devices = server.used_devices.write().await; + let mut available_devices = server.available_devices.write().await; + + let dev = match current_import_device_id + .clone() + .and_then(|ref k| used_devices.remove(k)) + { + Some(dev) => dev, + None => unreachable!(), + }; + + available_devices.push(dev); + current_import_device_id = None; + // USBIP_RET_UNLINK // command socket.write_u32(0x4).await?; @@ -356,7 +435,7 @@ async fn handler( socket.write_u32(0).await?; // status socket.write_u32(0).await?; - socket.write_all(&mut padding).await?; + socket.write_all(&padding).await?; } _ => warn!("Got unknown command {:?}", command), } @@ -364,16 +443,15 @@ async fn handler( } /// Spawn a USB/IP server at `addr` using [TcpListener] -pub async fn server(addr: SocketAddr, server: UsbIpServer) { +pub async fn server(addr: SocketAddr, server: Arc) { let listener = TcpListener::bind(addr).await.expect("bind to addr"); let server = async move { - let usbip_server = Arc::new(server); loop { match listener.accept().await { Ok((mut socket, _addr)) => { info!("Got connection from {:?}", socket.peer_addr()); - let new_server = usbip_server.clone(); + let new_server = server.clone(); tokio::spawn(async move { let res = handler(&mut socket, new_server).await; info!("Handler ended with {:?}", res); @@ -391,12 +469,46 @@ pub async fn server(addr: SocketAddr, server: UsbIpServer) { #[cfg(test)] mod test { + use tokio::{net::TcpStream, task::JoinSet}; + use super::*; use crate::util::tests::*; + fn new_server_with_single_device() -> UsbIpServer { + UsbIpServer::new_simulated(vec![UsbDevice::new(0).with_interface( + ClassCode::CDC as u8, + cdc::CDC_ACM_SUBCLASS, + 0x00, + "Test CDC ACM", + cdc::UsbCdcAcmHandler::endpoints(), + Arc::new(Mutex::new( + Box::new(cdc::UsbCdcAcmHandler::new()) as Box + )), + )]) + } + + fn op_req_import(bus_id: u32) -> Vec { + let mut req = vec![0x01, 0x11, 0x80, 0x03, 0x00, 0x00, 0x00, 0x00]; + let mut path = bus_id.to_string().as_bytes().to_vec(); + path.resize(32, 0); + req.extend(path); + req + } + + async fn attach_device(connection: &mut TcpStream, bus_id: u32) -> u32 { + let req = op_req_import(bus_id); + connection.write_all(req.as_slice()).await.unwrap(); + connection.read_u32().await.unwrap(); + let result = connection.read_u32().await.unwrap(); + if result == 0 { + connection.read_exact(&mut vec![0; 0x138]).await.unwrap(); + } + return result; + } + #[tokio::test] async fn req_empty_devlist() { - let server = UsbIpServer { devices: vec![] }; + let server = UsbIpServer::new_simulated(vec![]); // OP_REQ_DEVLIST let mut mock_socket = MockSocket::new(vec![0x01, 0x11, 0x80, 0x05, 0x00, 0x00, 0x00, 0x00]); @@ -410,20 +522,7 @@ mod test { #[tokio::test] async fn req_sample_devlist() { - let intf_handler = Arc::new(Mutex::new( - Box::new(cdc::UsbCdcAcmHandler::new()) as Box - )); - let server = UsbIpServer { - devices: vec![UsbDevice::new(0).with_interface( - ClassCode::CDC as u8, - cdc::CDC_ACM_SUBCLASS, - 0x00, - "Test CDC ACM", - cdc::UsbCdcAcmHandler::endpoints(), - intf_handler.clone(), - )], - }; - + let server = new_server_with_single_device(); // OP_REQ_DEVLIST let mut mock_socket = MockSocket::new(vec![0x01, 0x11, 0x80, 0x05, 0x00, 0x00, 0x00, 0x00]); handler(&mut mock_socket, Arc::new(server)).await.ok(); @@ -436,52 +535,192 @@ mod test { #[tokio::test] async fn req_import() { - let intf_handler = Arc::new(Mutex::new( - Box::new(cdc::UsbCdcAcmHandler::new()) as Box - )); - let server = UsbIpServer { - devices: vec![UsbDevice::new(0).with_interface( - ClassCode::CDC as u8, - cdc::CDC_ACM_SUBCLASS, - 0x00, - "Test CDC ACM", - cdc::UsbCdcAcmHandler::endpoints(), - intf_handler.clone(), - )], - }; + let server = new_server_with_single_device(); // OP_REQ_IMPORT - let mut req = vec![0x01, 0x11, 0x80, 0x03, 0x00, 0x00, 0x00, 0x00]; - let mut path = "0".as_bytes().to_vec(); - path.resize(32, 0); - req.extend(path); + let req = op_req_import(0); let mut mock_socket = MockSocket::new(req); handler(&mut mock_socket, Arc::new(server)).await.ok(); // OP_REQ_IMPORT assert_eq!(mock_socket.output.len(), 0x140); } + #[tokio::test] + async fn add_and_remove_10_devices() { + let server_ = Arc::new(UsbIpServer::new_simulated(vec![])); + let addr = get_free_address().await; + tokio::spawn(server(addr, server_.clone())); + + let mut join_set = JoinSet::new(); + let devices = (0..10).map(UsbDevice::new).collect::>(); + + for device in devices.iter() { + let new_server = server_.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.add_device(new_device).await; + }); + } + + for device in devices.iter() { + let new_server = server_.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.remove_device(&new_device.bus_id).await.unwrap(); + }); + } + + while join_set.join_next().await.is_some() {} + + let device_len = server_.clone().available_devices.read().await.len(); + + assert_eq!(device_len, 0); + } + + #[tokio::test] + async fn send_usb_traffic_while_adding_and_removing_devices() { + let server_ = Arc::new(new_server_with_single_device()); + + let addr = get_free_address().await; + tokio::spawn(server(addr, server_.clone())); + + let cmd_loop_handle = tokio::spawn(async move { + let mut connection = poll_connect(addr).await; + let result = attach_device(&mut connection, 0).await; + assert_eq!(result, 0); + + let cdc_loopback_bulk_cmd = vec![ + 0x00, 0x00, 0x00, 0x01, // command + 0x00, 0x00, 0x00, 0x01, // seq num + 0x00, 0x00, 0x00, 0x00, // dev id + 0x00, 0x00, 0x00, 0x00, // OUT + 0x00, 0x00, 0x00, 0x02, // ep 2 + 0x00, 0x00, 0x00, 0x00, // transfer flags + 0x00, 0x00, 0x00, 0x08, // transfer buffer length 8 + 0x00, 0x00, 0x00, 0x00, // start frame + 0x00, 0x00, 0x00, 0x00, // number of packets + 0x00, 0x00, 0x00, 0x00, // interval + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Empty setup packet + 0x01, 0x02, 0x03, 0x04, // data + 0x05, 0x06, 0x07, 0x08, // data + ]; + loop { + connection + .write_all(cdc_loopback_bulk_cmd.as_slice()) + .await + .unwrap(); + let mut result = vec![0; 4 * 12]; + connection.read_exact(&mut result).await.unwrap(); + } + }); + + let add_and_remove_device_handle = tokio::spawn(async move { + let mut join_set = JoinSet::new(); + let devices = (1..4).map(UsbDevice::new).collect::>(); + + loop { + for device in devices.iter() { + let new_server = server_.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.add_device(new_device).await; + }); + } + + for device in devices.iter() { + let new_server = server_.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.remove_device(&new_device.bus_id).await.unwrap(); + }); + } + while join_set.join_next().await.is_some() {} + tokio::time::sleep(tokio::time::Duration::from_millis(20)).await; + } + }); + + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + cmd_loop_handle.abort(); + add_and_remove_device_handle.abort(); + } + + #[tokio::test] + async fn only_single_connection_allowed_to_device() { + let server_ = Arc::new(new_server_with_single_device()); + + let addr = get_free_address().await; + tokio::spawn(server(addr, server_.clone())); + + let mut first_connection = poll_connect(addr).await; + let mut second_connection = TcpStream::connect(addr).await.unwrap(); + + let result = attach_device(&mut first_connection, 0).await; + assert_eq!(result, 0); + + let result = attach_device(&mut second_connection, 0).await; + assert_eq!(result, 1); + } + + #[tokio::test] + async fn device_gets_released_on_cmd_unlink() { + let server_ = Arc::new(new_server_with_single_device()); + + let addr = get_free_address().await; + tokio::spawn(server(addr, server_.clone())); + + let mut connection = poll_connect(addr).await; + + let result = attach_device(&mut connection, 0).await; + assert_eq!(result, 0); + + let unlink_req = vec![ + 0x00, 0x00, 0x00, 0x02, // cmd + 0x00, 0x00, 0x00, 0x01, // seq_num + 0x00, 0x00, 0x00, 0x00, // dev_id + 0x00, 0x00, 0x00, 0x00, // direction + 0x00, 0x00, 0x00, 0x00, // ep + 0x00, 0x00, 0x00, 0x00, // seq_num_submit + 0x00, 0x00, 0x00, 0x00, // padding + 0x00, 0x00, 0x00, 0x00, // padding + 0x00, 0x00, 0x00, 0x00, // padding + 0x00, 0x00, 0x00, 0x00, // padding + 0x00, 0x00, 0x00, 0x00, // padding + 0x00, 0x00, 0x00, 0x00, // padding + ]; + connection.write_all(unlink_req.as_slice()).await.unwrap(); + connection.read_exact(&mut vec![0; 4 * 5]).await.unwrap(); + let result = connection.read_u32().await.unwrap(); + connection.read_exact(&mut vec![0; 4 * 6]).await.unwrap(); + assert_eq!(result, 0); + + let result = attach_device(&mut connection, 0).await; + assert_eq!(result, 0); + } + + #[tokio::test] + async fn device_gets_released_on_closed_socket() { + let server_ = Arc::new(new_server_with_single_device()); + + let addr = get_free_address().await; + tokio::spawn(server(addr, server_.clone())); + + let mut connection = poll_connect(addr).await; + let result = attach_device(&mut connection, 0).await; + assert_eq!(result, 0); + + std::mem::drop(connection); + + let mut connection = TcpStream::connect(addr).await.unwrap(); + let result = attach_device(&mut connection, 0).await; + assert_eq!(result, 0); + } + #[tokio::test] async fn req_import_get_device_desc() { - let intf_handler = Arc::new(Mutex::new( - Box::new(cdc::UsbCdcAcmHandler::new()) as Box - )); - let server = UsbIpServer { - devices: vec![UsbDevice::new(0).with_interface( - ClassCode::CDC as u8, - cdc::CDC_ACM_SUBCLASS, - 0x00, - "Test CDC ACM", - cdc::UsbCdcAcmHandler::endpoints(), - intf_handler.clone(), - )], - }; + let server = new_server_with_single_device(); // OP_REQ_IMPORT - let mut req = vec![0x01, 0x11, 0x80, 0x03, 0x00, 0x00, 0x00, 0x00]; - let mut path = "0".as_bytes().to_vec(); - path.resize(32, 0); - req.extend(path); + let mut req = op_req_import(0); // USBIP_CMD_SUBMIT req.extend(vec![ 0x00, 0x00, 0x00, 0x01, // command diff --git a/src/util.rs b/src/util.rs index d0d579b..508260a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -24,10 +24,14 @@ pub fn verify_descriptor(desc: &[u8]) { pub(crate) mod tests { use std::{ io::*, + net::SocketAddr, pin::Pin, task::{Context, Poll}, }; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::{TcpListener, TcpStream}, + }; pub(crate) struct MockSocket { pub input: Cursor>, @@ -73,4 +77,17 @@ pub(crate) mod tests { Poll::Ready(Ok(())) } } + + pub(crate) async fn get_free_address() -> SocketAddr { + let stream = TcpListener::bind("127.0.0.1:0").await.unwrap(); + stream.local_addr().unwrap() + } + + pub(crate) async fn poll_connect(addr: SocketAddr) -> TcpStream { + loop { + if let Ok(stream) = TcpStream::connect(addr).await { + return stream; + } + } + } }