From c27ce4c3b85d0d2b83eeb585eebce757f1459a19 Mon Sep 17 00:00:00 2001 From: Tomas Jakstas Date: Sun, 23 Apr 2023 00:28:07 +0300 Subject: [PATCH] Server load balance, multithread support. Initial flow control. --- CHANGELOG.md | 13 +- Cargo.lock | 15 +- Cargo.toml | 7 +- README.md | 113 +++-- binary/Cargo.toml | 6 +- binary/src/cli.rs | 3 + binary/src/io.rs | 4 +- binary/src/main.rs | 66 ++- binary/src/socket.rs | 247 ++++++++-- build | 4 +- src/client/client.rs | 148 +++--- src/client/extensions.rs | 16 +- src/config.rs | 15 +- src/encryption.rs | 4 +- src/error.rs | 10 +- src/flow_control.rs | 385 +++++++++++---- src/key_management.rs | 2 +- src/server/connection.rs | 127 +++-- src/server/extensions.rs | 4 +- src/server/server.rs | 996 ++++++++++++++++++++++++++------------- src/server/validation.rs | 2 +- src/socket.rs | 62 ++- src/storage.rs | 208 ++++---- test | 4 +- 24 files changed, 1659 insertions(+), 802 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0161f4e..1bbd7da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,11 +2,20 @@ All visible changes will be documented here. This project adheres to Semantic Versioning. -## [Unreleased] +## [0.2.0] - 2023-04-22 -- Default retry timeout in milli seconds +### Added + +- Server load balance connections +- Server use multi threads by default - Secure server directory so that files do not escape provided directory - Retry window size from last acknoledged +- Initial flow control + +### Changed + +- Default retry timeout from seconds to milliseconds +- Default retry timeout from 1000ms to 80ms ## [0.1.0] - 2023-02-15 diff --git a/Cargo.lock b/Cargo.lock index 27cf932..522a04f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -864,6 +864,16 @@ dependencies = [ "serde", ] +[[package]] +name = "socket2" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc8d618c6641ae355025c449427f9e96b98abf99a772be3cef6708d15c77147a" +dependencies = [ + "libc", + "windows-sys 0.45.0", +] + [[package]] name = "spin" version = "0.9.4" @@ -950,7 +960,7 @@ dependencies = [ [[package]] name = "tftp" -version = "0.1.0" +version = "0.2.0" dependencies = [ "arrayvec", "base64", @@ -965,7 +975,7 @@ dependencies = [ [[package]] name = "tftp-binary" -version = "0.1.0" +version = "0.2.0" dependencies = [ "cargo-deb", "clap", @@ -973,6 +983,7 @@ dependencies = [ "log", "polling", "rand", + "socket2", "tftp", ] diff --git a/Cargo.toml b/Cargo.toml index d236eac..4125ef1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tftp" -version = "0.1.0" +version = "0.2.0" edition = "2021" [profile.release] @@ -28,9 +28,10 @@ rand = "0.8" vfs = "0.9" [features] -default = ["std", "alloc", "encryption"] +default = ["std", "alloc", "encryption", "multi_thread"] std = ["x25519-dalek?/std"] alloc = ["chacha20poly1305?/alloc"] encryption = ["x25519-dalek", "chacha20poly1305", "base64"] seek = [] -stack_large_window = [] \ No newline at end of file +stack_large_window = [] +multi_thread = [] \ No newline at end of file diff --git a/README.md b/README.md index 79cb3be..3cf30b7 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,8 @@ encrypt traffic or data. Install deb ``` -wget https://github.com/songokas/tftp/releases/download/v0.1.0/tftp_0.1.0_amd64.deb \ - && sudo apt install ./tftp_0.1.0_amd64.deb +wget https://github.com/songokas/tftp/releases/download/v0.2.0/tftp_0.2.0_amd64.deb \ + && sudo apt install ./tftp_0.2.0_amd64.deb ``` Download binary @@ -118,64 +118,63 @@ echo "hello" | tftp send 127.0.0.1:9000 /dev/stdin --allow-server-port-change ### Stats ``` - Send 100Mb - +-----------------------------------------------------------------------------------------------------------+ - 18 |-+ + + + + + + +-| - | | - | x | - | | - 16 |-+ +-| - | | - | | - 14 |-+ +-| - | | - | | - 12 |-+ +-| -Time | | - | | - 10 |-+x +-| - | | - | | - 8 |-+ +-| - | x | - | | - 6 |-+ x +-| - | x x | - | + x x x + + x + + + x | - 4 +-----------------------------------------------------------------------------------------------------------+ - 0 10 20 30 40 50 60 - WindowSize - + Send 100Mb + +----------------------------------------------------------------------------------+ + 18 |-+ + + + + + + +-| + | | + | x | + | | + 16 |-+ +-| + | | + | | + 14 |-+ +-| + | | + | | + 12 |-+ +-| +Time | | + | | + 10 |-+x +-| + | | + | | + 8 |-+ +-| + | x | + | | + 6 |-+ x +-| + | x x | + | + x x x + + x + + + x | + 4 +----------------------------------------------------------------------------------+ + 0 10 20 30 40 50 60 + WindowSize - Receive 100Mb - +-----------------------------------------------------------------------------------------------------------+ - 18 |-+ + + + + + + +-| - | | - | x | - 16 |-+ +-| - | | - 14 |-+ +-| - | | - | | - 12 |-+ +-| - | | - | x | -Time 10 |-+ +-| - | | - | | - 8 |-+ +-| - | | - 6 |-+ x +-| - | | - | x x | - 4 |-+ x x x x +-| - | x x | - | + + + + + + | - 2 +-----------------------------------------------------------------------------------------------------------+ - 0 10 20 30 40 50 60 - WindowSize + Receive 100Mb + +----------------------------------------------------------------------------------+ + 18 |-+ + + + + + + +-| + | | + | x | + 16 |-+ +-| + | | + 14 |-+ +-| + | | + | | + 12 |-+ +-| + | | + | x | +Time 10 |-+ +-| + | | + | | + 8 |-+ +-| + | | + 6 |-+ x +-| + | | + | x x | + 4 |-+ x x x x +-| + | x x | + | + + + + + + | + 2 +----------------------------------------------------------------------------------+ + 0 10 20 30 40 50 60 + WinddowSize ``` diff --git a/binary/Cargo.toml b/binary/Cargo.toml index 0633311..98e9c68 100644 --- a/binary/Cargo.toml +++ b/binary/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tftp-binary" -version = "0.1.0" +version = "0.2.0" edition = "2021" authors = ["Tomas Jakstas "] @@ -21,18 +21,20 @@ tftp = { path = "..", default-features = false} rand = { version = "0.8", optional = false, default-features = false, features = ["getrandom"]} # std compile time requirement polling = "2" +socket2 = { version = "~0.5.1", features = ["all"] } [dev-dependencies] rand = "0.8" env_logger = "0.10" [features] -default = ["std", "alloc", "encryption"] +default = ["std", "alloc", "encryption", "multi_thread"] std = ["tftp/std", "polling/std"] alloc = ["tftp/alloc"] encryption = ["tftp/encryption"] seek = ["tftp/seek"] stack_large_window = ["tftp/stack_large_window"] +multi_thread = ["tftp/multi_thread"] [[bin]] name = "tftp" diff --git a/binary/src/cli.rs b/binary/src/cli.rs index 1af0b22..8a90d84 100644 --- a/binary/src/cli.rs +++ b/binary/src/cli.rs @@ -155,6 +155,9 @@ pub enum Commands { #[arg(long, default_value_t = 1 as u64, value_parser = clap::value_parser!(u64).range(1..=(MAX_BLOCKS_READER as u64)))] max_blocks_in_queue: u64, + + #[arg(long, default_value_t = true)] + ignore_rate_control: bool, }, Receive { diff --git a/binary/src/io.rs b/binary/src/io.rs index c1da83b..6f730d3 100644 --- a/binary/src/io.rs +++ b/binary/src/io.rs @@ -21,7 +21,7 @@ pub fn create_reader(path: &FilePath) -> BoxedResult<(Option, StdCompatFile let file_size = file.metadata().map_err(from_io_err)?.len(); #[cfg(not(feature = "std"))] let file = StdCompatFile(file); - Ok(((file_size > 0).then(|| file_size), file)) + Ok(((file_size > 0).then_some(file_size), file)) } pub fn create_server_reader( @@ -230,6 +230,6 @@ impl tftp::std_compat::io::BufRead for StdBufReader { pub fn std_into_path(path: PathBuf) -> FilePath { let mut f = FilePath::new(); // TODO alloc in stack - let _result = f.push_str(&path.to_string_lossy().to_string()); + f.push_str(&path.to_string_lossy()); f } diff --git a/binary/src/main.rs b/binary/src/main.rs index c39d91b..a732c93 100644 --- a/binary/src/main.rs +++ b/binary/src/main.rs @@ -23,8 +23,10 @@ use tftp::{ error::{BoxedResult, EncryptionError}, key_management::{append_to_known_hosts, get_from_known_hosts}, server::server, + socket::Socket, std_compat::{ io::{Read, Seek, Write}, + net::SocketAddr, time::Instant, }, types::FilePath, @@ -51,12 +53,14 @@ fn main() -> BinResult<()> { remote_path, max_blocks_in_queue, config, + ignore_rate_control, } => start_send( local_path, remote_path, config, max_blocks_in_queue as u16, create_reader, + ignore_rate_control, ) .map(|_| ()), @@ -75,11 +79,13 @@ fn main() -> BinResult<()> { .map(|_| ()), Commands::Server(config) => { let config = config.try_into()?; + // init_logger(config.listen); server( config, create_server_reader, create_server_writer, create_socket, + create_bound_socket, instant_callback, OsRng, ) @@ -94,13 +100,15 @@ fn start_send( config: ClientCliConfig, max_blocks_in_queue: u16, create_reader: CreateReader, + ignore_rate_control: bool, ) -> BinResult where R: Read + Seek, CreateReader: Fn(&FilePath) -> BoxedResult<(Option, R)>, { - let socket = - create_socket(config.listen.as_str(), 1).map_err(|e| BinError::from(e.to_string()))?; + let socket = create_socket(config.listen.as_str(), 1, false) + .map_err(|e| BinError::from(e.to_string()))?; + // init_logger(socket.local_addr().expect("local address")); let options = ConnectionOptions { block_size: config.block_size as u16, @@ -126,7 +134,7 @@ where Some(p) => p, None => Path::new(local_path.as_str()) .file_name() - .ok_or_else(|| "Invalid local filename")? + .ok_or("Invalid local filename")? .to_string_lossy() .parse() .expect("Invalid local file name"), @@ -140,10 +148,11 @@ where socket, instant_callback, OsRng, + ignore_rate_control, ) .map(|(total, _remote_key)| { debug!("Client total sent {}", total); - let file = known_hosts_file.as_ref().map(|s| s.as_str()); + let file = known_hosts_file.as_deref(); handle_hosts_file(file, _remote_key, &endpoint); total }) @@ -161,7 +170,10 @@ where W: Write + Seek, CreateWriter: Fn(&FilePath) -> BoxedResult, { - let socket = create_socket(&config.listen, 1).map_err(|e| BinError::from(e.to_string()))?; + let socket = + create_socket(&config.listen, 1, false).map_err(|e| BinError::from(e.to_string()))?; + // init_logger(socket.local_addr().expect("local address")); + let options = ConnectionOptions { block_size: config.block_size as u16, retry_packet_after_timeout: Duration::from_millis(config.retry_timeout), @@ -187,7 +199,7 @@ where Some(p) => p, None => Path::new(remote_path.as_str()) .file_name() - .ok_or_else(|| "Invalid remote file name")? + .ok_or("Invalid remote file name")? .to_string_lossy() .parse() .expect("Invalid remote file name"), @@ -204,7 +216,7 @@ where ) .map(|(total, _remote_key)| { debug!("Client total received {}", total); - let file = known_hosts_file.as_ref().map(|s| s.as_str()); + let file = known_hosts_file.as_deref(); handle_hosts_file(file, _remote_key, &endpoint); total }) @@ -252,6 +264,24 @@ fn handle_hosts_file( }; } +#[allow(dead_code)] +fn init_logger(local_addr: SocketAddr) { + use std::io::Write; + // builder using box + Builder::from_env(Env::default().default_filter_or("debug")) + .format(move |buf, record| { + writeln!( + buf, + "[{local_addr} {} {}]: {}", + record.level(), + buf.timestamp_micros(), + record.args() + ) + }) + .try_init() + .unwrap_or_default(); +} + #[cfg(test)] mod tests { use std::{ @@ -279,6 +309,7 @@ mod tests { #[cfg(feature = "encryption")] #[test] fn test_client_full_encryption() { + // env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init(); let bytes: Vec = (0..32).map(|_| rand::random::()).collect(); let key: [u8; 32] = bytes.try_into().unwrap(); let server_private_key: PrivateKey = key.into(); @@ -367,7 +398,9 @@ mod tests { #[test] fn test_client_no_encryption() { - // env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace")).init(); + // env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace")) + // .format_timestamp_micros() + // .init(); client_send(EncryptionLevel::None, None, None, None); client_receive(EncryptionLevel::None, None, None, None); } @@ -401,7 +434,7 @@ mod tests { #[cfg(feature = "encryption")] let client_private_key = _client_private_key .as_ref() - .map(|k| encode_private_key(&k).unwrap()); + .map(|k| encode_private_key(k).unwrap()); let client = { let d = bytes.clone(); spawn(move || { @@ -453,7 +486,7 @@ mod tests { #[cfg(feature = "encryption")] let client_private_key = _client_private_key .as_ref() - .map(|k| encode_private_key(&k).unwrap()); + .map(|k| encode_private_key(k).unwrap()); let client = { let d = expected_data.clone(); spawn(move || { @@ -470,7 +503,7 @@ mod tests { }) }; let result = client.join().unwrap(); - assert!(result.is_ok(), "{:?}", result); + assert!(result.is_ok(), "{result:?}"); assert_eq!(result.unwrap(), expected_size); assert_eq!(&bytes, expected_data.lock().unwrap().get_ref()); } @@ -483,7 +516,7 @@ mod tests { #[cfg(feature = "encryption")] private_key: Option, ) -> BinResult { let cli_config = ClientCliConfig { - endpoint: format!("127.0.0.1:{}", server_port).parse().unwrap(), + endpoint: format!("127.0.0.1:{server_port}").parse().unwrap(), listen: "127.0.0.1:0".parse().unwrap(), request_timeout: 1000, block_size: 100, @@ -506,7 +539,7 @@ mod tests { let create_reader = |_path: &FilePath| Ok((Some(bytes.len() as u64), CursorReader::new(bytes.clone()))); - start_send(local_file, remote_file, cli_config, 4, create_reader) + start_send(local_file, remote_file, cli_config, 4, create_reader, false) } fn start_receive_file( @@ -517,7 +550,7 @@ mod tests { #[cfg(feature = "encryption")] private_key: Option, ) -> BinResult { let cli_config = ClientCliConfig { - endpoint: format!("127.0.0.1:{}", server_port).parse().unwrap(), + endpoint: format!("127.0.0.1:{server_port}").parse().unwrap(), listen: "127.0.0.1:0".parse().unwrap(), request_timeout: 1000, block_size: 100, @@ -549,7 +582,7 @@ mod tests { private_key: Option, authorized_keys: Option, ) -> DefaultBoxedResult { - let listen: std::net::SocketAddr = format!("127.0.0.1:{}", server_port).parse().unwrap(); + let listen: std::net::SocketAddr = format!("127.0.0.1:{server_port}").parse().unwrap(); #[cfg(not(feature = "std"))] let listen = std_to_socket_addr(listen); let config = ServerConfig { @@ -558,7 +591,7 @@ mod tests { allow_overwrite: false, max_queued_blocks_reader: 4, max_queued_blocks_writer: 4, - request_timeout: Duration::from_millis(100), + request_timeout: Duration::from_millis(1000), max_connections: 10, max_file_size: 2000, max_block_size: MAX_DATA_BLOCK_SIZE, @@ -583,6 +616,7 @@ mod tests { create_reader, create_writer, create_socket, + create_bound_socket, instant_callback, OsRng, ) diff --git a/binary/src/socket.rs b/binary/src/socket.rs index c560291..d0ff208 100644 --- a/binary/src/socket.rs +++ b/binary/src/socket.rs @@ -1,7 +1,14 @@ use core::time::Duration; +use std::net::UdpSocket; -use log::info; -use polling::{Event, Poller}; +#[cfg(not(target_family = "windows"))] +use std::os::fd::{AsFd, AsRawFd, RawFd}; +#[cfg(target_family = "windows")] +use std::os::windows::io::{AsRawSocket, RawSocket}; + +use log::{info, trace}; +use polling::{Event, Poller, Source}; +use socket2::{Domain, Protocol, SockAddr, Type}; use tftp::{ config::ConnectionOptions, encryption::EncryptionLevel, @@ -16,80 +23,207 @@ use crate::{ io::from_io_err, }; -pub fn create_socket(listen: &str, socket_id: usize) -> BoxedResult { - let socket = std::net::UdpSocket::bind(listen).map_err(from_io_err)?; +pub fn create_socket(listen: &str, socket_id: usize, reuse: bool) -> BoxedResult { + let address: std::net::SocketAddr = listen + .parse() + .map_err(|_| io::Error::from(io::ErrorKind::AddrNotAvailable))?; + let socket = socket2::Socket::new( + Domain::for_address(address), + Type::DGRAM, + Protocol::UDP.into(), + ) + .map_err(from_io_err)?; + + socket.set_reuse_address(reuse).map_err(from_io_err)?; + #[cfg(not(target_family = "windows"))] + socket.set_reuse_port(reuse).map_err(from_io_err)?; + socket.bind(&address.into()).map_err(from_io_err)?; + + let socket: UdpSocket = socket.into(); socket.set_nonblocking(true).map_err(from_io_err)?; - let local_addr = socket.local_addr().map_err(from_io_err)?; - let poller = if socket_id > 0 { - let poller = Poller::new().map_err(from_io_err)?; - poller - .add(&socket, Event::readable(socket_id)) - .map_err(from_io_err)?; - poller.into() - } else { - None - }; + + let poller = Poller::new().map_err(from_io_err)?; + poller + .add(&socket, Event::readable(socket_id)) + .map_err(from_io_err)?; let socket = StdSocket { socket, poller, socket_id, + events: Vec::new(), + }; + Ok(socket) +} + +pub fn create_bound_socket( + listen: &str, + socket_id: usize, + endpoint: SocketAddr, +) -> BoxedResult { + let endpoint = socket_addr_to_std(endpoint); + let socket = socket2::Socket::new( + Domain::for_address(endpoint), + Type::DGRAM, + Protocol::UDP.into(), + ) + .map_err(from_io_err)?; + + socket.set_reuse_address(true).map_err(from_io_err)?; + #[cfg(not(target_family = "windows"))] + socket.set_reuse_port(true).map_err(from_io_err)?; + + let address: std::net::SocketAddr = listen + .parse() + .map_err(|_| io::Error::from(io::ErrorKind::AddrNotAvailable))?; + socket.bind(&address.into()).unwrap(); + + let socket: UdpSocket = socket.into(); + socket.set_nonblocking(true).map_err(from_io_err)?; + socket.connect(endpoint).map_err(from_io_err)?; + let poller = Poller::new().map_err(from_io_err)?; + poller + .add(&socket, Event::readable(socket_id)) + .map_err(from_io_err)?; + let socket = StdBoundSocket { + socket, + poller, + socket_id, }; Ok(socket) } pub struct StdSocket { - socket: std::net::UdpSocket, - poller: Option, + socket: UdpSocket, + poller: Poller, socket_id: usize, + // TODO alloc in stack + events: Vec, } impl Socket for StdSocket { fn recv_from( - &self, - buf: &mut DataBuffer, + &mut self, + buff: &mut DataBuffer, wait_for: Option, ) -> io::Result<(usize, SocketAddr)> { - if let (Some(d), Some(poller)) = (wait_for, &self.poller) { - poller - .modify(&self.socket, Event::readable(self.socket_id)) - .map_err(from_io_err)?; - // TODO alloc in stack - let mut events = Vec::new(); - poller.wait(&mut events, d.into()).map_err(from_io_err)?; - } + self.modify_interest(self.socket_id(), self.as_raw_fd())?; + self.poller + .wait(&mut self.events, wait_for.or_else(|| Duration::ZERO.into())) + .map_err(from_io_err)?; #[cfg(feature = "std")] - let result = self.socket.recv_from(buf); + let result = self.socket.recv_from(buff); #[cfg(not(feature = "std"))] let result = self .socket - .recv_from(buf) + .recv_from(buff) .map(|(b, s)| (b, std_to_socket_addr(s))) .map_err(from_io_err); + if let Ok((size, client)) = result.as_ref() { + trace!("Received from {client} {size} {:x?}", buff); + } result } - fn send_to(&self, buff: &mut DataBuffer, addr: SocketAddr) -> io::Result { + fn send_to(&self, buff: &mut DataBuffer, client: SocketAddr) -> io::Result { #[cfg(feature = "std")] - return self.socket.send_to(&buff, addr); + let result = self.socket.send_to(buff, client); + #[cfg(not(feature = "std"))] + let result = self + .socket + .send_to(&buff, socket_addr_to_std(client)) + .map_err(from_io_err); + trace!("Send to {client} {} {:x?}", buff.len(), buff); + result + } + + fn local_addr(&self) -> io::Result { + #[cfg(feature = "std")] + return self.socket.local_addr(); #[cfg(not(feature = "std"))] self.socket - .send_to(&buff, socket_addr_to_std(addr)) + .local_addr() + .map(|s| std_to_socket_addr(s)) + .map_err(from_io_err) + } + + fn notified(&self, socket: &impl ToSocketId) -> bool { + self.events.iter().any(|e| e.key == socket.socket_id()) + } + + fn add_interest(&self, socket: &impl ToSocketId) -> io::Result<()> { + self.poller + .add( + RawCInt(socket.as_raw_fd()), + Event::readable(socket.socket_id()), + ) + .map_err(from_io_err) + } + + fn modify_interest(&mut self, socket_id: usize, raw_fd: SocketRawFd) -> io::Result<()> { + self.events.retain(|e| e.key != socket_id); + self.poller + .modify(RawCInt(raw_fd), Event::readable(socket_id)) .map_err(from_io_err) } +} + +#[cfg(target_family = "windows")] +struct RawCInt(u64); + +#[cfg(target_family = "windows")] +impl Source for RawCInt { + fn raw(&self) -> RawSocket { + self.0 as RawSocket + } +} + +#[cfg(not(target_family = "windows"))] +struct RawCInt(i32); + +#[cfg(not(target_family = "windows"))] +impl Source for RawCInt { + fn raw(&self) -> RawFd { + self.0 as RawFd + } +} - fn try_clone(&self) -> io::Result - where - Self: Sized, - { - Ok(Self { - #[cfg(feature = "std")] - socket: self.socket.try_clone()?, - #[cfg(not(feature = "std"))] - socket: self.socket.try_clone().map_err(from_io_err)?, - poller: None, - socket_id: 0, - }) +impl ToSocketId for StdSocket { + fn as_raw_fd(&self) -> SocketRawFd { + #[cfg(target_family = "windows")] + return self.socket.as_raw_socket(); + #[cfg(not(target_family = "windows"))] + self.socket.as_fd().as_raw_fd() + } + + fn socket_id(&self) -> usize { + self.socket_id + } +} + +pub struct StdBoundSocket { + socket: UdpSocket, + poller: Poller, + socket_id: usize, +} + +impl BoundSocket for StdBoundSocket { + fn recv(&self, buff: &mut DataBuffer, wait_for: Option) -> io::Result { + if let Some(d) = wait_for { + self.poller + .modify(&self.socket, Event::readable(self.socket_id)) + .map_err(from_io_err)?; + // TODO alloc in stack + let mut events = Vec::new(); + self.poller + .wait(&mut events, d.into()) + .map_err(from_io_err)?; + } + self.socket.recv(buff).map_err(from_io_err) + } + + fn send(&self, buff: &mut DataBuffer) -> io::Result { + self.socket.send(buff).map_err(from_io_err) } fn local_addr(&self) -> io::Result { @@ -103,6 +237,19 @@ impl Socket for StdSocket { } } +impl ToSocketId for StdBoundSocket { + fn as_raw_fd(&self) -> SocketRawFd { + #[cfg(target_family = "windows")] + return self.socket.as_raw_socket(); + #[cfg(not(target_family = "windows"))] + self.socket.as_fd().as_raw_fd() + } + + fn socket_id(&self) -> usize { + self.socket_id + } +} + #[cfg(not(feature = "std"))] pub fn std_to_socket_addr(addr: std::net::SocketAddr) -> SocketAddr { match addr { @@ -117,8 +264,10 @@ pub fn std_to_socket_addr(addr: std::net::SocketAddr) -> SocketAddr { } } -#[cfg(not(feature = "std"))] pub fn socket_addr_to_std(addr: SocketAddr) -> std::net::SocketAddr { + #[cfg(feature = "std")] + return addr; + #[cfg(not(feature = "std"))] match addr.ip { tftp::std_compat::net::IpVersion::Ipv4(b) => std::net::SocketAddr::V4( std::net::SocketAddrV4::new(std::net::Ipv4Addr::from(b), addr.port), @@ -137,8 +286,8 @@ mod tests { #[test] fn test_receive_wait_for() { - let socket_r = create_socket("127.0.0.1:9000", 1).unwrap(); - let socket_s = create_socket("127.0.0.1:0", 0).unwrap(); + let mut socket_r = create_socket("127.0.0.1:9000", 1, false).unwrap(); + let socket_s = create_socket("127.0.0.1:0", 0, false).unwrap(); let mut buf = DataBuffer::new(); #[allow(unused_must_use)] { @@ -146,19 +295,19 @@ mod tests { } let now = Instant::now(); - let wait_for = Duration::from_millis(15); + let wait_for = Duration::from_millis(30); let result = socket_r.recv_from(&mut buf, wait_for.into()); assert_eq!(result.unwrap_err().kind(), io::ErrorKind::WouldBlock); assert!(now.elapsed() >= wait_for); let now = Instant::now(); - let wait_for = Duration::from_micros(15); + let wait_for = Duration::from_micros(30); let result = socket_r.recv_from(&mut buf, wait_for.into()); assert_eq!(result.unwrap_err().kind(), io::ErrorKind::WouldBlock); assert!(now.elapsed() >= wait_for); let now = Instant::now(); - let wait_for = Duration::from_micros(15); + let wait_for = Duration::from_micros(30); let result = socket_r.recv_from(&mut buf, None); assert_eq!(result.unwrap_err().kind(), io::ErrorKind::WouldBlock); assert!(now.elapsed() < wait_for); diff --git a/build b/build index d9294bd..09197ab 100755 --- a/build +++ b/build @@ -27,8 +27,8 @@ build() { cargo build -p tftp-binary --target "$target" "--$RELEASE" --features seek done - cargo build -p tftp-binary --target "x86_64-unknown-linux-musl" "--$RELEASE" - cp target/x86_64-unknown-linux-musl/release/tftp dist/tftp-static + # cargo build -p tftp-binary --target "x86_64-unknown-linux-musl" "--$RELEASE" + # cp target/x86_64-unknown-linux-musl/release/tftp dist/tftp-static build_deb } diff --git a/src/client/client.rs b/src/client/client.rs index 8784cac..6f789e8 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -46,15 +46,17 @@ pub struct ClientConfig { pub allow_server_port_change: bool, } +#[allow(clippy::too_many_arguments)] pub fn send_file( config: ClientConfig, local_file_path: FilePath, remote_file_path: FilePath, mut options: ConnectionOptions, create_reader: CreateReader, - socket: Sock, + mut socket: Sock, instant: InstantCallback, rng: Rng, + ignore_rate_control: bool, ) -> BoxedResult<(usize, Option)> where R: Read + Seek, @@ -71,7 +73,7 @@ where ); #[cfg(feature = "encryption")] - let (socket, initial_keys) = create_initial_socket(socket, &config, &mut options, rng)?; + let (mut socket, initial_keys) = create_initial_socket(socket, &config, &mut options, rng)?; let mut max_buffer_size = max( options.block_size + DATA_PACKET_HEADER_SIZE as u16, @@ -90,20 +92,30 @@ where d }; - print_options("Client initial", &options); + // print_options("Client initial", &options); + + let mut rate_control = RateControl::new(instant); + + rate_control.start_rtt(1); let (_, acknowledge, mut options, endpoint) = query_server( - &socket, + &mut socket, &mut buffer, - PacketType::Write, + Packet::Write, remote_file_path, options, instant, &config, )?; + let initial_rtt = rate_control + .end_rtt(1) + .unwrap_or_else(|| Duration::from_millis(1)); + + debug!("Initial exchange took {}", initial_rtt.as_secs_f32()); + #[cfg(feature = "encryption")] - let (socket, options) = configure_socket(socket, initial_keys, options); + let (mut socket, options) = configure_socket(socket, initial_keys, options); print_options("Client using", &options); @@ -116,7 +128,6 @@ where reader, config.max_blocks_in_memory, options.block_size, - options.retry_packet_after_timeout, instant, options.window_size, ); @@ -128,45 +139,54 @@ where let mut total_unconfirmed = 0; let mut total_confirmed = 0; - let mut rate_control = RateControl::new(instant); - let mut stats_print = instant(); let mut stats_calculate = instant(); - + let mut no_work: u8 = 0; let mut packets_to_send = u32::MAX; - let packet_send_window: u32 = 200; + let flow_control_period = Duration::from_millis(200); + rate_control.acknoledged_data(options.block_size as usize, 1); + rate_control.calculate_transmit_rate( + options.block_size, + options.window_size, + options.retry_packet_after_timeout, + initial_rtt, + ); loop { - if stats_calculate.elapsed().as_millis() > packet_send_window as u128 { + if stats_calculate.elapsed() > flow_control_period { rate_control.calculate_transmit_rate( options.block_size, options.window_size, - options.retry_packet_after_timeout.as_secs_f64(), + options.retry_packet_after_timeout, + stats_calculate.elapsed(), ); stats_calculate = instant(); - packets_to_send = u32::MAX; - // rate_control.packets_to_send(packet_send_window, options.block_size as u32); - } - if stats_print.elapsed().as_secs() > 2 { - rate_control.print_info(); - stats_print = instant(); + + // packets_to_send = u32::MAX; + packets_to_send = if ignore_rate_control { + u32::MAX + } else { + rate_control.packets_to_send(flow_control_period, options.block_size) + }; } if packets_to_send > 0 { - if let Some(data_block) = block_reader.next()? { + let timeout_interval = rate_control + .timeout_interval(options.retry_packet_after_timeout, options.block_size); + if let Some(data_block) = block_reader.next(timeout_interval)? { let data_length = data_block.data.len(); debug!( - "Send data block {} data size {} ack {}", - data_block.block, data_length, data_block.expect_ack + "Send data block {} data size {data_length} ack {} to send {packets_to_send}", + data_block.block, data_block.expect_ack ); if data_block.expect_ack { + if data_block.retry > 0 { + rate_control.increment_errors(); + } rate_control.start_rtt(data_block.block); } - if data_block.retry > 0 { - rate_control.increment_errors(); - } let data_packet = Packet::Data(DataPacket { block: data_block.block, @@ -192,6 +212,7 @@ where } } else { no_work = no_work.wrapping_add(1); + rate_control.mark_as_data_limited(); } #[cfg(feature = "alloc")] @@ -208,7 +229,7 @@ where None }; - debug!( + trace!( "Last sent {}us Last received {}us waiting {}ms", last_sent.elapsed().as_micros(), last_received.elapsed().as_micros(), @@ -255,9 +276,15 @@ where timeout = instant(); let data_length = block_reader.free_block(p.block); - rate_control.data_received(data_length); + rate_control.acknoledged_data( + data_length, + (data_length / options.block_size as usize) as u32, + ); if let Some(rtt) = rate_control.end_rtt(p.block) { - debug!("Rtt for block {} elapsed {}us", p.block, rtt.as_micros()); + trace!("Rtt for block {} elapsed {}us", p.block, rtt.as_micros()); + // can not measure rtt for random blocks + } else if config.max_blocks_in_memory == 1 && options.window_size >= 1 { + rate_control.increment_errors(); } total_confirmed += data_length; @@ -278,6 +305,7 @@ where } } +#[allow(clippy::too_many_arguments)] pub fn receive_file( config: ClientConfig, local_file_path: FilePath, @@ -307,7 +335,7 @@ where ); #[cfg(feature = "encryption")] - let (socket, initial_keys) = create_initial_socket(socket, &config, &mut options, rng)?; + let (mut socket, initial_keys) = create_initial_socket(socket, &config, &mut options, rng)?; #[allow(unused_must_use)] let mut buffer = { @@ -316,20 +344,24 @@ where d }; - print_options("Client initial", &options); - + let initial_rtt = instant(); let (mut received_length, acknowledge, mut options, endpoint) = query_server( - &socket, + &mut socket, &mut buffer, - PacketType::Read, + Packet::Read, remote_file_path, options, instant, &config, )?; + debug!( + "Initial exchange took {}", + initial_rtt.elapsed().as_secs_f32() + ); + #[cfg(feature = "encryption")] - let (socket, options) = configure_socket(socket, initial_keys, options); + let (mut socket, options) = configure_socket(socket, initial_keys, options); let writer = create_writer(&local_file_path)?; let mut block_writer = FileWriter::from_writer( @@ -372,7 +404,6 @@ where let mut timeout = instant(); let mut total = 0; - let mut no_work: u8 = 0; loop { #[cfg(feature = "alloc")] @@ -383,18 +414,12 @@ where buffer.set_len(max_buffer_size as usize) }; - let wait_for = if no_work > 1 { - Duration::from_millis(no_work as u64).into() - } else { - None - }; - let length = match socket.recv_from(&mut buffer, wait_for) { + let length = match socket.recv_from(&mut buffer, Duration::from_secs(1).into()) { Ok((n, s)) => { if s != endpoint { continue; } - debug!("Received packet size {}", n); - no_work = 1; + trace!("Received packet size {}", n); n } Err(ref e) if e.kind() == ErrorKind::WouldBlock => { @@ -405,7 +430,6 @@ where } return Err(PacketError::Timeout(elapsed).into()); } - no_work = no_work.wrapping_add(1); continue; } Err(e) => { @@ -435,11 +459,8 @@ where Ok(n) => { if n > 0 { timeout = instant(); - no_work = 1; - } else { - no_work = no_work.wrapping_add(1); + total += n; } - total += n; } Err(e) => return Err(e), } @@ -561,10 +582,10 @@ fn configure_socket( (socket, options) } -fn query_server( - socket: &impl Socket, +fn query_server<'a>( + socket: &mut impl Socket, buffer: &mut DataBuffer, - packet_type: PacketType, + create_packet: impl Fn(RequestPacket) -> Packet<'a>, file_path: FilePath, options: ConnectionOptions, instant: InstantCallback, @@ -583,11 +604,8 @@ fn query_server( mode: Mode::Octet, extensions, }; - let packet = match packet_type { - PacketType::Read => Packet::Read(request_packet), - PacketType::Write => Packet::Write(request_packet), - _ => panic!("Invalid packet type provided"), - }; + let packet = create_packet(request_packet); + let packet_type = packet.packet_type(); let (length, endpoint) = wait_for_initial_packet( socket, @@ -596,10 +614,11 @@ fn query_server( buffer, request_timeout, instant, + options.retry_packet_after_timeout, )?; if config.endpoint != endpoint { if !config.allow_server_port_change { - error!("Server is using new port, however configuration does not allow it"); + error!("Server is using a new port, however configuration does not allow it"); return Err(PacketError::Invalid.into()); } else { debug!("Using new endpoint {}", endpoint); @@ -733,19 +752,24 @@ fn write_block( } fn wait_for_initial_packet( - socket: &impl Socket, + socket: &mut impl Socket, endpoint: SocketAddr, packet: Packet, buffer: &mut DataBuffer, request_timeout: Duration, instant: InstantCallback, + mut retry_timeout: Duration, ) -> BoxedResult<(usize, SocketAddr)> { let timeout = instant(); loop { socket.send_to(&mut packet.clone().to_bytes(), endpoint)?; - debug!("Initial packet elapsed {}", timeout.elapsed().as_secs_f32()); + debug!( + "Initial packet elapsed {} wait {}", + timeout.elapsed().as_secs_f32(), + retry_timeout.as_secs_f32() + ); - match socket.recv_from(buffer, Duration::from_millis(200).into()) { + match socket.recv_from(buffer, retry_timeout.into()) { Ok((n, s)) => { if s.ip() == endpoint.ip() { return Ok((n, s)); @@ -757,6 +781,10 @@ fn wait_for_initial_packet( if elapsed > request_timeout { return Err(PacketError::Timeout(elapsed).into()); } + retry_timeout = min( + request_timeout.saturating_sub(elapsed), + retry_timeout.saturating_mul(2), + ); continue; } Err(e) => { diff --git a/src/client/extensions.rs b/src/client/extensions.rs index 1fe07d5..25a9863 100644 --- a/src/client/extensions.rs +++ b/src/client/extensions.rs @@ -80,7 +80,7 @@ pub fn create_extensions(options: &ConnectionOptions) -> PacketExtensions { if options.encryption_level != EncryptionLevel::None { match (&options.encryption_keys, options.encryption_level) { (Some(EncryptionKeys::ClientKey(s)), l) => { - let value = encode_public_key(&s).expect("invalid key"); + let value = encode_public_key(s).expect("invalid key"); extensions.insert(Extension::PublicKey, value); extensions.insert(Extension::Nonce, "0".parse().expect("convert to string")); extensions.insert( @@ -97,7 +97,7 @@ pub fn create_extensions(options: &ConnectionOptions) -> PacketExtensions { | EncryptionLevel::OptionalProtocol ) => { - let value = encode_public_key(&local).expect("invalid key"); + let value = encode_public_key(local).expect("invalid key"); extensions.insert(Extension::PublicKey, value); extensions.insert(Extension::Nonce, "0".parse().expect("convert to string")); extensions.insert( @@ -195,15 +195,13 @@ pub fn parse_extensions( if matches!( _expected_encryption_level, EncryptionLevel::Data | EncryptionLevel::Protocol + ) && !matches!( + options.encryption_level, + EncryptionLevel::Data | EncryptionLevel::Protocol ) { - if !matches!( + return Err(ExtensionError::ClientRequiredEncryption( options.encryption_level, - EncryptionLevel::Data | EncryptionLevel::Protocol - ) { - return Err(ExtensionError::ClientRequiredEncryption( - options.encryption_level, - )); - } + )); } } Ok(options) diff --git a/src/config.rs b/src/config.rs index 158e7d7..c0a75bf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -91,10 +91,17 @@ impl ConnectionOptions { pub fn remote_public_key(&self) -> Option { match self.encryption_keys { - Some(EncryptionKeys::LocalToRemote(_, p)) => p.clone().into(), + Some(EncryptionKeys::LocalToRemote(_, p)) => p.into(), _ => None, } } + + pub fn is_encrypting(&self) -> bool { + matches!( + self.encryption_keys, + Some(crate::encryption::EncryptionKeys::LocalToRemote(..)) + ) + } } pub fn print_options(context: &str, options: &ConnectionOptions) { @@ -106,9 +113,7 @@ pub fn print_options(context: &str, options: &ConnectionOptions) { options.file_size.unwrap_or(0), options.retry_packet_after_timeout.as_millis(), options.encryption_level, - matches!( - options.encryption_keys, - Some(crate::encryption::EncryptionKeys::LocalToRemote(..)) - ), + options.is_encrypting(), + ); } diff --git a/src/encryption.rs b/src/encryption.rs index 07b1aad..0bc3325 100644 --- a/src/encryption.rs +++ b/src/encryption.rs @@ -275,7 +275,7 @@ mod tests { ]; for (expected, public) in data { let encoded = decode_public_key(public.as_bytes()); - assert_eq!(expected, encoded.is_ok(), "{}", public); + assert_eq!(expected, encoded.is_ok(), "{public}"); } } @@ -289,7 +289,7 @@ mod tests { ]; for (expected, nonce) in data { let encoded = decode_nonce(nonce.as_bytes()); - assert_eq!(expected, encoded.is_ok(), "{}", nonce); + assert_eq!(expected, encoded.is_ok(), "{nonce}"); } } diff --git a/src/error.rs b/src/error.rs index cd4df77..8087b9b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -164,7 +164,7 @@ impl Display for StorageError { StorageError::AlreadyWriten => write!(f, "Block has been already written"), StorageError::FileTooBig => write!(f, "File is too big"), StorageError::ExpectedBlock((expected, current)) => { - write!(f, "Expecting block {} current block {}", expected, current) + write!(f, "Expecting block {expected} current block {current}") } } } @@ -210,10 +210,10 @@ impl Display for ExtensionError { ExtensionError::InvalidPublicKey => write!(f, "Invalid public key received",), ExtensionError::InvalidNonce => write!(f, "Invalid nonce received",), ExtensionError::EncryptionError(s) => { - write!(f, "Invalid extension parsing error {}", s) + write!(f, "Invalid extension parsing error {s}") } ExtensionError::InvalidExtension(s) => { - write!(f, "Invalid extension {}", s) + write!(f, "Invalid extension {s}") } } } @@ -257,8 +257,8 @@ impl Display for EncryptionError { match self { EncryptionError::Encrypt => write!(f, "Failed to encrypt"), EncryptionError::Decrypt => write!(f, "Failed to decrypt"), - EncryptionError::Encode(t) => write!(f, "Failed to encode {}", t), - EncryptionError::Decode(t) => write!(f, "Failed to decode {}", t), + EncryptionError::Encode(t) => write!(f, "Failed to encode {t}"), + EncryptionError::Decode(t) => write!(f, "Failed to decode {t}"), } } } diff --git a/src/flow_control.rs b/src/flow_control.rs index ce72bf3..24e2a36 100644 --- a/src/flow_control.rs +++ b/src/flow_control.rs @@ -1,56 +1,82 @@ -use core::{ops::Div, time::Duration}; +use core::{ + cmp::{max, min}, + ops::Div, + time::Duration, +}; use log::{debug, trace}; -use crate::{std_compat::time::Instant, time::InstantCallback}; +use crate::{config::MAX_DATA_BLOCK_SIZE, std_compat::time::Instant, time::InstantCallback}; + +// const MSS: u16 = MAX_DATA_BLOCK_SIZE; +const INITIAL_TCP_WINDOW: u16 = 4380; pub struct RateControl { // bytes/s - avg_transmit_rate: u32, + allowed_transmit_rate: u32, // for measuring rtt current_rtt: Instant, rtt_for_packet: u16, - rtt_estimate: f64, + rtt_estimate: f32, // // for packet loss rate - total_packets: u32, + total_acknoledged_packets: u32, error_packets: u32, // - // for speed estimations - start: Instant, // bytes - total_send: u64, + total_data_sent: u32, // bytes - total_received: u64, + total_acknoledged_data: u32, // instant: InstantCallback, + + // received bytes per second + receive_set: arrayvec::ArrayVec, + + new_loss: bool, + no_feedback_timer: Duration, + feedback_timer_expired: bool, + data_limited: bool, } impl RateControl { pub fn new(instant: InstantCallback) -> Self { + let mut receive_set = arrayvec::ArrayVec::new(); + receive_set.push(u32::MAX); Self { rtt_for_packet: 0, rtt_estimate: 0.0, - total_send: 0, - total_received: 0, + total_data_sent: 0, + total_acknoledged_data: 0, instant, - start: instant(), current_rtt: instant(), error_packets: 0, - avg_transmit_rate: 0, - total_packets: 0, + // allowed_transmit_rate: initial_rate(MSS) as u32, + allowed_transmit_rate: u32::MAX, + total_acknoledged_packets: 0, + receive_set, + new_loss: false, + feedback_timer_expired: false, + data_limited: false, + no_feedback_timer: Duration::from_secs(2), } } pub fn increment_errors(&mut self) { self.error_packets += 1; + self.new_loss = true; + } + + pub fn mark_as_data_limited(&mut self) { + self.data_limited = true; } pub fn start_rtt(&mut self, block: u16) { - if self.current_rtt.elapsed().as_secs() > 1 { + if self.current_rtt.elapsed() >= self.no_feedback_timer { self.rtt_for_packet = 0; + self.feedback_timer_expired = true; } if self.rtt_for_packet != 0 { return; @@ -64,9 +90,9 @@ impl RateControl { if self.rtt_for_packet == block { let elapsed = self.current_rtt.elapsed(); if self.rtt_estimate == 0.0 { - self.rtt_estimate = elapsed.as_secs_f64(); + self.rtt_estimate = elapsed.as_secs_f32(); } else { - self.rtt_estimate = smooth_rtt_estimate(self.rtt_estimate, elapsed.as_secs_f64()); + self.rtt_estimate = smooth_rtt_estimate(self.rtt_estimate, elapsed.as_secs_f32()); } elapsed_duration = elapsed.into() } @@ -75,98 +101,146 @@ impl RateControl { } pub fn data_sent(&mut self, size: usize) { - self.total_send += size as u64; - self.total_packets += 1; + self.total_data_sent += size as u32; + } + + pub fn acknoledged_data(&mut self, size: usize, packets: u32) { + self.total_acknoledged_data += size as u32; + self.total_acknoledged_packets += packets; + } + + pub fn timeout_interval(&self, min_retry_timeout: Duration, block_size: u16) -> Duration { + let max_duration = Duration::from_secs(1); + if min_retry_timeout >= max_duration || self.rtt_estimate == 0.0 { + return min_retry_timeout; + } + let timeout = (4.0 * self.rtt_estimate) + .max(2.0 * block_size as f32 / self.allowed_transmit_rate as f32); + min(Duration::from_secs_f32(timeout), max_duration) } - pub fn data_received(&mut self, size: usize) { - self.total_received += size as u64; + fn maximize_set(&mut self, received: u32) { + let mut max_value = self.receive_set.iter().max().copied().unwrap_or_default(); + if received > max_value || max_value == u32::MAX { + max_value = received; + } + self.receive_set.clear(); + self.receive_set.push(max_value); + } + + fn update_set(&mut self, received: u32) { + if self.receive_set.is_full() { + self.receive_set.pop(); + } + self.receive_set.insert(0, received); } pub fn calculate_transmit_rate( &mut self, block_size: u16, window_size: u16, - retransmission_timeout: f64, - ) { - let loss_event_rate = self.error_packets as f64 / self.total_packets as f64; - self.avg_transmit_rate = average_transmit_rate( - self.rtt_estimate, - block_size as f64, - window_size as f64, - loss_event_rate, - retransmission_timeout, - ); - } + min_retry_timeout: Duration, + received_in: Duration, + ) -> u32 { + let mut received = if self.feedback_timer_expired { + initial_rate(block_size) as u32 + } else { + if self.total_acknoledged_data == 0 { + return self.allowed_transmit_rate; + } + (self.total_acknoledged_data as f32 / received_in.as_secs_f32()) as u32 + }; - pub fn print_info(&self) { - debug!( - "Expected rate: {} bytes/s Current rrt: {} Average send speed: {} bytes/s Average receive speed: {} bytes/s Total packets: {} Errors: {}", - self.avg_transmit_rate, - self.rtt_estimate, - self.average_send_speed(), - self.average_receive_speed(), - self.total_packets, - self.error_packets - ) - } - - pub fn average_send_speed(&self) -> u64 { - let passed = self.start.elapsed(); - if passed.as_secs() > 0 { - self.total_send / passed.as_secs() + let recv_limit = if self.data_limited { + if self.new_loss || self.feedback_timer_expired { + for v in self.receive_set.iter_mut().filter(|v| **v > 0) { + *v /= 2; + } + received = (0.85 * (received as f32)) as u32; + self.maximize_set(received); + self.receive_set.iter().max().copied().unwrap_or_default() + } else { + self.maximize_set(received); + 2_u32.saturating_mul(self.receive_set.iter().max().copied().unwrap_or_default()) + } } else { - 0 - } - } + self.update_set(received); + 2_u32.saturating_mul(self.receive_set.iter().max().copied().unwrap_or_default()) + }; + let loss_event_rate = + 1_f32.min(self.error_packets as f32 / max(self.total_acknoledged_packets, 1) as f32); + let timeout_interval = self.timeout_interval(min_retry_timeout, block_size); - pub fn average_receive_speed(&self) -> u64 { - let passed = self.start.elapsed(); - if passed.as_secs() > 0 { - self.total_received / passed.as_secs() + if loss_event_rate > 0.0 { + let avg_transmit_rate = average_transmit_rate( + self.rtt_estimate, + block_size as f32, + window_size as f32, + loss_event_rate, + timeout_interval.as_secs_f32(), + ); + let minimum_rate = block_size as u32 / 64; + self.allowed_transmit_rate = max(min(avg_transmit_rate, recv_limit), minimum_rate); } else { - 0 + self.allowed_transmit_rate = max( + min(2_u32.saturating_mul(self.allowed_transmit_rate), recv_limit), + initial_rate(block_size) as u32, + ); } + + debug!( + "Allowed rate {} bytes/s Rtt {}s Loss rate {loss_event_rate} Received {received} bytes Receive limit {recv_limit} bytes Data {} bytes Packets {} Errors {}", + self.allowed_transmit_rate, timeout_interval.as_secs_f32(), self.total_acknoledged_data, self.total_acknoledged_packets, self.error_packets + ); + + self.no_feedback_timer = timeout_interval; + self.new_loss = false; + self.feedback_timer_expired = false; + self.total_acknoledged_data = 0; + self.total_acknoledged_packets = 0; + self.error_packets = 0; + self.data_limited = false; + + self.allowed_transmit_rate } - /// send_window is in milliseconds and must be less than a second - #[allow(dead_code)] - pub fn packets_to_send(&self, send_window: u32, block_size: u32) -> u32 { - packets_to_send(self.avg_transmit_rate, send_window, block_size) + pub fn packets_to_send(&self, time_window: Duration, block_size: u16) -> u32 { + packets_to_send(self.allowed_transmit_rate, time_window, block_size) } } -fn packets_to_send(avg_rate: u32, send_window: u32, block_size: u32) -> u32 { - if avg_rate > 0 { - let mut packets = (avg_rate / block_size / (1000 / send_window)) * 2; - if packets == 0 { - packets = 1; - } - packets +fn packets_to_send(allowed_rate: u32, time_window: Duration, block_size: u16) -> u32 { + if allowed_rate > 0 { + let packets = (allowed_rate as f32 / block_size as f32 * time_window.as_secs_f32()) as u32; + max(packets, 1) } else { u32::MAX } } -fn smooth_rtt_estimate(rtt_estimate: f64, current_rrt: f64) -> f64 { - 0.9 * rtt_estimate + (1_f64 - 0.9) * current_rrt +fn smooth_rtt_estimate(rtt_estimate: f32, current_rrt: f32) -> f32 { + 0.9 * rtt_estimate + (1_f32 - 0.9) * current_rrt } // loss_event_rate = error_packets / total_packets; fn average_transmit_rate( - round_trip_time: f64, - block_size: f64, - window_size: f64, - loss_event_rate: f64, - retransmission_timeout: f64, + round_trip_time: f32, + block_size: f32, + window_size: f32, + loss_event_rate: f32, + retransmission_timeout: f32, ) -> u32 { (block_size - / (round_trip_time * f64::sqrt(2_f64 * window_size * loss_event_rate / 3_f64) + / (round_trip_time * f32::sqrt(2.0 * window_size * loss_event_rate / 3.0) + (retransmission_timeout - * (3_f64 - * f64::sqrt(3_f64 * window_size * loss_event_rate / 8_f64) + * (3.0 + * f32::sqrt(3.0 * window_size * loss_event_rate / 8.0) * loss_event_rate - * (1_f64 + 32_f64 * loss_event_rate.powf(2.0)))))) as u32 + * (1.0 + 32.0 * loss_event_rate * loss_event_rate))))) as u32 +} + +fn initial_rate(block_size: u16) -> u16 { + min(4 * block_size, max(2 * block_size, INITIAL_TCP_WINDOW)) } #[cfg(test)] @@ -196,35 +270,166 @@ mod tests { assert!(rate.rtt_estimate != current); } + #[cfg(feature = "std")] + #[test] + fn test_timeout_interval() { + use std::thread::sleep; + + let mut rate = RateControl::new(std::time::Instant::now); + + let result = rate.timeout_interval(Duration::from_millis(80), 512); + assert_eq!(result.as_millis(), 80); + + rate.start_rtt(1); + sleep(result); + rate.end_rtt(1); + + let transmit_rate = + rate.calculate_transmit_rate(512, 8, result, Duration::from_millis(1000)); + assert_eq!(transmit_rate, 4294967295); + + let transmit_rate = + rate.calculate_transmit_rate(512, 8, result, Duration::from_millis(1000)); + assert_eq!(transmit_rate, 4294967295); + + let result = rate.timeout_interval(Duration::from_millis(80), 512); + assert_eq!(result.as_millis(), 320); + + rate.feedback_timer_expired = true; + rate.data_limited = true; + + let transmit_rate = + rate.calculate_transmit_rate(512, 8, result, Duration::from_millis(1000)); + assert_eq!(transmit_rate, 2147483647); + + let result = rate.timeout_interval(Duration::from_millis(80), 512); + assert_eq!(result.as_millis(), 320); + + // timeout depends on data or rtt + rate.start_rtt(1); + sleep(Duration::from_millis(1)); + rate.end_rtt(1); + + rate.acknoledged_data(8000, 8); + let transmit_rate = rate.calculate_transmit_rate( + 512, + 8, + Duration::from_millis(80), + Duration::from_millis(200), + ); + + let result = rate.timeout_interval(Duration::from_millis(80), 512); + assert_eq!(result.as_millis(), 288); + } + + #[cfg(feature = "std")] + #[test] + fn test_transmit_rate() { + use std::thread::sleep; + + let mut rate = RateControl::new(std::time::Instant::now); + + rate.start_rtt(1); + sleep(Duration::from_millis(1)); + rate.end_rtt(1); + + rate.acknoledged_data(8000, 8); + let transmit_rate = rate.calculate_transmit_rate( + 512, + 8, + Duration::from_millis(80), + Duration::from_millis(200), + ); + assert_eq!(transmit_rate, 4294967295); + + rate.acknoledged_data(8000, 8); + let transmit_rate = rate.calculate_transmit_rate( + 512, + 8, + Duration::from_millis(80), + Duration::from_millis(200), + ); + assert_eq!(transmit_rate, 4294967295); + + rate.acknoledged_data(8000, 8); + let transmit_rate = rate.calculate_transmit_rate( + 512, + 8, + Duration::from_millis(80), + Duration::from_millis(200), + ); + assert_eq!(transmit_rate, 80000); + + rate.acknoledged_data(8000, 8); + let transmit_rate = rate.calculate_transmit_rate( + 512, + 8, + Duration::from_millis(80), + Duration::from_millis(200), + ); + assert_eq!(transmit_rate, 80000); + } + #[test] fn test_average_transmit() { - let result = average_transmit_rate(0.01, 512_f64, 1_f64, 0_f64, 1_f64); + let result = average_transmit_rate(0.01, 512_f32, 1_f32, 0_f32, 1_f32); assert_eq!(result, 4294967295); - let result = average_transmit_rate(0.01, 512_f64, 1_f64, 0.001_f64, 1_f64); + let result = average_transmit_rate(0.01, 512_f32, 1_f32, 0_f32, 0.001); + assert_eq!(result, 4294967295); + + let result = average_transmit_rate(0.01, 512_f32, 1_f32, 0.001, 1_f32); assert_eq!(result, 1618739); - let result = average_transmit_rate(0.01, 512_f64, 8_f64, 0.001_f64, 1_f64); + let result = average_transmit_rate(0.01, 512_f32, 8_f32, 0.001, 1_f32); assert_eq!(result, 572310); + + let result = average_transmit_rate(0.080_099_8, 512_f32, 8_f32, 0.001, 0.080); + assert_eq!(result, 87330); + + let result = average_transmit_rate(0.080_099_8, 512_f32, 8_f32, 0.0, 0.080); + assert_eq!(result, 4294967295); + + let result = average_transmit_rate(0.380_099_8, 1400.0, 8.0, 25.0 / 158.0, 0.080); + assert_eq!(result, 3532); + + let result = average_transmit_rate(0.011, 1400.0, 8.0, 1.0, 0.080); + assert_eq!(result, 101); + + let result = average_transmit_rate(0.127, 1400.0, 8.0, 1.0, 0.080); + assert_eq!(result, 99); } #[test] fn test_packets_to_send() { - assert_eq!(446, packets_to_send(572310, 200, 512)); - assert_eq!(1264, packets_to_send(1618739, 200, 512)); - // send 1677721 packets * 512 / 1024 / 1024 = 819Mb in 200ms - assert_eq!(1677721 * 2, packets_to_send(4294967295, 200, 512)); - assert_eq!(4294967295, packets_to_send(0, 200, 512)); - assert_eq!(1, packets_to_send(1, 200, 512)); - assert_eq!(1, packets_to_send(128, 200, 512)); - assert_eq!(6, packets_to_send(10000, 200, 512)); + assert_eq!( + 223, + packets_to_send(572310, Duration::from_millis(200), 512) + ); + assert_eq!( + 632, + packets_to_send(1618739, Duration::from_millis(200), 512) + ); + assert_eq!( + 1677721, + packets_to_send(4294967295, Duration::from_millis(200), 512) + ); + assert_eq!( + 4294967295, + packets_to_send(0, Duration::from_millis(200), 512) + ); + assert_eq!(1, packets_to_send(1, Duration::from_millis(200), 512)); + assert_eq!(1, packets_to_send(128, Duration::from_millis(200), 512)); + assert_eq!(3, packets_to_send(10000, Duration::from_millis(200), 512)); + + assert_eq!(19, packets_to_send(10000, Duration::from_millis(1000), 512)); } #[test] fn test_smooth_rrt_estimate() { assert_eq!(0.275, smooth_rtt_estimate(0.25, 0.5)); assert_eq!(0.55, smooth_rtt_estimate(0.5, 1.0)); - assert_eq!(0.46, smooth_rtt_estimate(0.5, 0.1)); + assert_eq!(0.45999998, smooth_rtt_estimate(0.5, 0.1)); assert_eq!(0.45, smooth_rtt_estimate(0.5, 0.0)); } } diff --git a/src/key_management.rs b/src/key_management.rs index 86e0502..2c237cb 100644 --- a/src/key_management.rs +++ b/src/key_management.rs @@ -91,7 +91,7 @@ pub fn append_to_known_hosts( file.write_fmt(format_args!( "\n{} {}\n", endpoint, - encode_public_key(&public_key)? + encode_public_key(public_key)? ))?; Ok(()) } diff --git a/src/server/connection.rs b/src/server/connection.rs index a2e758f..35e56e9 100644 --- a/src/server/connection.rs +++ b/src/server/connection.rs @@ -1,6 +1,6 @@ use core::{cmp::min, time::Duration}; -use log::{debug, error}; +use log::{debug, error, warn}; use rand::{CryptoRng, RngCore}; use super::{extensions::create_options, validation::validate_request_options}; @@ -17,7 +17,7 @@ use crate::{ RequestPacket, }, server::ServerConfig, - socket::Socket, + socket::{BoundSocket, Socket, ToSocketId}, std_compat::{ io::{self, Read, Seek, Write}, net::SocketAddr, @@ -33,32 +33,34 @@ pub enum ClientType { Writer(FileWriter), } -pub struct Connection { - pub socket: S, +pub struct Connection { + pub socket: B, pub options: ConnectionOptions, pub encryptor: Option, pub last_updated: Instant, pub transfer: usize, pub client_type: ClientType, pub endpoint: SocketAddr, + pub invalid: bool, } -impl Connection { - pub fn recv_from( - &self, - buffer: &mut DataBuffer, - wait_for: Option, - ) -> io::Result<(usize, SocketAddr)> { - self.socket.recv_from(buffer, wait_for) +impl Connection { + pub fn recv(&self, buffer: &mut DataBuffer, wait_for: Option) -> io::Result { + self.socket.recv(buffer, wait_for) } - pub fn receive_packet(&self, _buffer: &mut DataBuffer) -> bool { + pub fn decrypt_packet(&self, _buffer: &mut DataBuffer) -> bool { #[cfg(feature = "encryption")] if let (EncryptionLevel::Protocol | EncryptionLevel::Full, Some(encryptor)) = (self.options.encryption_level, &self.encryptor) { if encryptor.decrypt(_buffer).is_err() { - error!("Failed to decrypt packet {:x?}", &_buffer); + debug!( + "Failed to decrypt packet from {} {} {:x?}", + self.endpoint, + &_buffer.len(), + &_buffer + ); return false; } } @@ -69,21 +71,26 @@ impl Connection { self.options.encryption_level, &self.encryptor, ) { - if let Err(_) = overwrite_data_packet(_buffer, |buf| encryptor.decrypt(buf)) { - error!("Failed to decrypt data {:x?}", &_buffer); + if overwrite_data_packet(_buffer, |buf| encryptor.decrypt(buf)).is_err() { + debug!( + "Failed to decrypt data from {} {} {:x?}", + self.endpoint, + &_buffer.len(), + &_buffer + ); return false; } } - return true; + true } pub fn send_packet(&self, packet: Packet) -> bool { let packet_name = packet.packet_type(); match &packet { - Packet::Data(d) => debug!("Send {} {} {}", packet_name, d.block, self.endpoint), - Packet::Ack(d) => debug!("Send {} {} {}", packet_name, d.block, self.endpoint), - _ => debug!("Send {} {}", packet_name, self.endpoint), + Packet::Data(d) => debug!("Send {} {} to {}", packet_name, d.block, self.endpoint), + Packet::Ack(d) => debug!("Send {} {} to {}", packet_name, d.block, self.endpoint), + _ => debug!("Send {} to {}", packet_name, self.endpoint), }; let mut data = packet.to_bytes(); @@ -106,8 +113,8 @@ impl Connection { return false; } } - if let Err(e) = self.socket.send_to(&mut data, self.endpoint) { - error!("Failed to send {} for {} {}", packet_name, self.endpoint, e); + if let Err(e) = self.socket.send(&mut data) { + error!("Failed to send {} to {} {}", packet_name, self.endpoint, e); return false; } true @@ -206,21 +213,23 @@ impl<'a> ConnectionBuilder<'a> { Ok(()) } - pub(crate) fn build_writer( + pub fn build_writer( mut self, socket: &S, client: SocketAddr, create_writer: &CreateWriter, - create_socket: &CreateSocket, + create_bound_socket: &CreateBoundSocket, instant: fn() -> Instant, - ) -> BoxedResult<(Connection, PacketExtensions, Option)> + socket_id: usize, + ) -> ConnectionResult where S: Socket, + B: BoundSocket + ToSocketId, W: Write + Seek, - CreateSocket: Fn(&str, usize) -> BoxedResult, + CreateBoundSocket: Fn(&str, usize, SocketAddr) -> BoxedResult, CreateWriter: Fn(&FilePath, &ServerConfig) -> BoxedResult, { - let file_name = self.file_name.ok_or_else(|| FileError::InvalidFileName)?; + let file_name = self.file_name.ok_or(FileError::InvalidFileName)?; let (encryptor, finalized_keys) = if self.options.encryption_level != EncryptionLevel::Full { (None, self.finalized_keys.take()) @@ -249,12 +258,27 @@ impl<'a> ConnectionBuilder<'a> { return Err(e); } }; - let new_socket = if self.config.require_server_port_change { - let listen = format_str!(DefaultString, "{}:{}", self.config.listen.ip(), 0); - create_socket(&listen, 0)? + let listen = if self.config.require_server_port_change { + format_str!(DefaultString, "{}:{}", self.config.listen.ip(), 0) } else { - socket.try_clone()? + format_str!( + DefaultString, + "{}:{}", + self.config.listen.ip(), + self.config.listen.port() + ) }; + let new_socket = match create_bound_socket(&listen, socket_id, client) { + Ok(s) => s, + Err(e) => { + error!("Failed to create socket {}", e); + return Err(e); + } + }; + #[cfg(not(feature = "multi_thread"))] + if let Err(_) = socket.add_interest(&new_socket) { + warn!("Unable to add socket {} to epoll", new_socket.socket_id()); + } Ok(( Connection { socket: new_socket, @@ -269,27 +293,30 @@ impl<'a> ConnectionBuilder<'a> { options: self.options, endpoint: client, encryptor, + invalid: false, }, self.used_extensions, finalized_keys, )) } - pub fn build_reader( + pub fn build_reader( mut self, socket: &S, client: SocketAddr, create_reader: &CreateReader, - create_socket: &CreateSocket, + create_bound_socket: &CreateBoundSocket, instant: fn() -> Instant, - ) -> BoxedResult<(Connection, PacketExtensions, Option)> + socket_id: usize, + ) -> ConnectionResult where S: Socket, - CreateSocket: Fn(&str, usize) -> BoxedResult, + B: BoundSocket + ToSocketId, + CreateBoundSocket: Fn(&str, usize, SocketAddr) -> BoxedResult, R: Read + Seek, CreateReader: Fn(&FilePath, &ServerConfig) -> BoxedResult<(Option, R)>, { - let file_name = self.file_name.ok_or_else(|| FileError::InvalidFileName)?; + let file_name = self.file_name.ok_or(FileError::InvalidFileName)?; let (encryptor, finalized_keys) = if self.options.encryption_level != EncryptionLevel::Full { (None, self.finalized_keys.take()) @@ -304,7 +331,7 @@ impl<'a> ConnectionBuilder<'a> { &self.used_extensions, self.config, )?; - let (transfer_size, reader) = match create_reader(&file_path, &self.config) { + let (transfer_size, reader) = match create_reader(&file_path, self.config) { Ok(f) => f, Err(e) => { error!("Failed to open file {} {}", file_path, e); @@ -336,16 +363,30 @@ impl<'a> ConnectionBuilder<'a> { reader, self.config.max_queued_blocks_reader, self.options.block_size, - self.options.retry_packet_after_timeout, instant, self.options.window_size, ); - let new_socket = if self.config.require_server_port_change { - let listen = format_str!(DefaultString, "{}:{}", self.config.listen.ip(), 0); - create_socket(&listen, 0)? + let listen = if self.config.require_server_port_change { + format_str!(DefaultString, "{}:{}", self.config.listen.ip(), 0) } else { - socket.try_clone()? + format_str!( + DefaultString, + "{}:{}", + self.config.listen.ip(), + self.config.listen.port() + ) }; + let new_socket = match create_bound_socket(&listen, socket_id, client) { + Ok(s) => s, + Err(e) => { + error!("Failed to create socket {}", e); + return Err(e); + } + }; + #[cfg(not(feature = "multi_thread"))] + if let Err(_) = socket.add_interest(&new_socket) { + warn!("Unable to add socket {} to epoll", new_socket.socket_id()); + } Ok(( Connection { socket: new_socket, @@ -355,6 +396,7 @@ impl<'a> ConnectionBuilder<'a> { endpoint: client, encryptor, last_updated: instant(), + invalid: false, }, self.used_extensions, finalized_keys, @@ -406,3 +448,6 @@ fn handle_encrypted( remote_key, ))) } + +type ConnectionResult = + BoxedResult<(Connection, PacketExtensions, Option)>; diff --git a/src/server/extensions.rs b/src/server/extensions.rs index 0ad3d3e..d69e98b 100644 --- a/src/server/extensions.rs +++ b/src/server/extensions.rs @@ -136,7 +136,7 @@ pub fn create_options( ); used_extensions.insert( Extension::Nonce, - encode_nonce(&final_keys.nonce()).expect("nonce encoder"), + encode_nonce(final_keys.nonce()).expect("nonce encoder"), ); options.encryption_keys = Some(EncryptionKeys::LocalToRemote( final_keys.public, @@ -199,7 +199,7 @@ mod tests { extensions.insert(Extension::WindowSize, "8".parse().unwrap()); let (extensions, options, _) = create_options(extensions, options, &create_config(), None, 8, OsRng).unwrap(); - assert_eq!(extensions.len(), 4, "{:?}", extensions); + assert_eq!(extensions.len(), 4, "{extensions:?}"); assert_eq!(options.window_size, 8); assert_eq!(options.block_size, 101); assert_eq!(options.file_size, Some(6)); diff --git a/src/server/server.rs b/src/server/server.rs index 5e618c5..84e40cd 100644 --- a/src/server/server.rs +++ b/src/server/server.rs @@ -1,4 +1,4 @@ -use core::{cmp::max, fmt::write, mem::size_of_val, time::Duration}; +use core::{cmp::max, fmt::write, mem::size_of_val, num::NonZeroU32, time::Duration}; use log::{debug, error, info, trace, warn}; use rand::{CryptoRng, RngCore}; @@ -19,10 +19,11 @@ use crate::{ connection::{ClientType, Connection, ConnectionBuilder}, validation::handle_file_size, }, - socket::Socket, + socket::{BoundSocket, Socket, ToSocketId}, std_compat::{ io::{ErrorKind, Read, Seek, Write}, net::SocketAddr, + time::Instant, }, storage::{BlockReader, BlockWriter}, string::format_str, @@ -30,11 +31,17 @@ use crate::{ types::{DataBuffer, DefaultString, FilePath}, }; -#[cfg(feature = "alloc")] -type Clients = Map>; -#[cfg(not(feature = "alloc"))] -type Clients = - Map, { crate::config::MAX_CLIENTS as usize }>; +#[cfg(all(feature = "alloc", not(feature = "multi_thread")))] +type Clients = Map>; +#[cfg(all(not(feature = "alloc"), not(feature = "multi_thread")))] +type Clients = + Map, { crate::config::MAX_CLIENTS as usize }>; + +#[cfg(all(feature = "std", feature = "alloc", feature = "multi_thread"))] +type Handles = Map>; +#[cfg(all(feature = "std", not(feature = "alloc"), feature = "multi_thread"))] +type Handles = + Map, { crate::config::MAX_CLIENTS as usize }>; #[cfg(all(feature = "alloc", feature = "encryption"))] pub type AuthorizedKeys = alloc::vec::Vec; @@ -60,41 +67,31 @@ pub struct ServerConfig { pub require_server_port_change: bool, } -pub fn server( +#[allow(clippy::too_many_arguments)] +pub fn server( config: ServerConfig, create_reader: CreateReader, create_writer: CreateWriter, create_socket: CreateSocket, + create_bound_socket: CreateBoundSocket, instant: InstantCallback, mut rng: Rng, ) -> DefaultBoxedResult where - S: Socket, + S: Socket + ToSocketId, + B: BoundSocket + ToSocketId + Send + 'static, Rng: CryptoRng + RngCore + Copy, - R: Read + Seek, - CreateSocket: Fn(&str, usize) -> BoxedResult, + R: Read + Seek + Send + 'static, + CreateSocket: Fn(&str, usize, bool) -> BoxedResult, + CreateBoundSocket: Fn(&str, usize, SocketAddr) -> BoxedResult, CreateReader: Fn(&FilePath, &ServerConfig) -> BoxedResult<(Option, R)>, - W: Write + Seek, + W: Write + Seek + Send + 'static, CreateWriter: Fn(&FilePath, &ServerConfig) -> BoxedResult, { info!("Starting server on {}", config.listen); - let max_buffer_size = max( - config.max_block_size + DATA_PACKET_HEADER_SIZE as u16, - MIN_BUFFER_SIZE, - ); - assert!(max_buffer_size <= MAX_BUFFER_SIZE); - #[allow(unused_must_use)] - let mut buffer = { - let mut d = DataBuffer::new(); - d.resize(max_buffer_size as usize, 0); - d - }; - let mut clients: Clients<_, _, _> = Clients::new(); - debug!( - "Size of all clients in memory {} bytes", - size_of_val(&clients) - ); + let mut buffer = create_max_buffer(config.max_block_size); + let max_buffer_size = buffer.len(); #[cfg(feature = "encryption")] if let Some(private_key) = config.private_key.as_ref() { @@ -110,280 +107,579 @@ where &config.listen.ip(), &config.listen.port() ); - let socket = create_socket(&listen, 1)?; - let mut timeout_duration = instant(); - let mut last_socket_addr: Option = None; - let mut no_work: u8 = 0; - let mut last_received = instant(); - loop { - let send_duration = instant(); - if timeout_duration.elapsed() > Duration::from_secs(2) { - timeout_clients(&mut clients, config.request_timeout); - timeout_duration = instant(); - } - let sent = send_data_blocks(&mut clients); - if sent > 0 { - no_work = 1; - } else { - no_work = no_work.wrapping_add(1); - } + let mut socket_id = NonZeroU32::new(1).expect("Socket id must be more than zero"); + let mut socket = create_socket(&listen, 0, true)?; - #[cfg(feature = "alloc")] - buffer.resize(max_buffer_size as usize, 0); - // TODO heapless vector resizing is super slow - #[cfg(not(feature = "alloc"))] - unsafe { - buffer.set_len(max_buffer_size as usize) - }; - let wait_for = if clients.is_empty() { - Duration::from_millis(500).into() - } else if no_work > 2 { - Duration::from_millis(no_work as u64).into() - } else { - None - }; - - trace!( - "Total clients {} sent {} packets in {}us waiting {}ms last received {}us", - clients.len(), - sent, - send_duration.elapsed().as_micros(), - wait_for.unwrap_or(Duration::ZERO).as_millis(), - last_received.elapsed().as_micros(), + #[cfg(not(feature = "multi_thread"))] + { + let mut timeout_duration = instant(); + let mut next_client_to_send = 0; + let mut next_client_to_receive = 0; + let mut wait_control = WaitControl::new(); + let execute_timeout_client = Duration::from_secs(2); + let mut clients: Clients<_, _, _> = Clients::new(); + debug!( + "Size of all clients in memory {} bytes", + size_of_val(&clients) ); - let (mut received_length, from_client) = match socket.recv_from(&mut buffer, wait_for) { - Ok(n) => { - no_work = 1; - last_received = instant(); - n - } - Err(ref e) if e.kind() == ErrorKind::WouldBlock => { - if config.require_server_port_change { - let mut recv = None; - for (s, c) in clients.iter() { - if last_socket_addr == Some(*s) { - continue; - } - match c.recv_from(&mut buffer, None) { - Ok(n) => { - recv = Some(n); - no_work = 1; - last_received = instant(); - break; - } - _ => continue, - } - } - if let Some(p) = recv { - last_socket_addr = Some(p.1); - p - } else { - last_socket_addr = None; - continue; - } - } else { - no_work = no_work.wrapping_add(1); - continue; - } + loop { + let send_duration = instant(); + if timeout_duration.elapsed() > execute_timeout_client { + clients.retain(|client, connection: &mut Connection<_, _, B>| { + !timeout_client(connection, config.request_timeout) + }); + timeout_duration = instant(); } - Err(e) => return Err(e.into()), - }; - buffer.truncate(received_length); - let client_length = clients.len(); - match clients.entry(from_client) { - Entry::Occupied(mut entry) => { - let mut connection = entry.get_mut(); - if !connection.receive_packet(&mut buffer) { - continue; - } + let sent_in = instant(); - let packet_type = PacketType::from_bytes(&buffer); - if !matches!( - packet_type, - Ok(PacketType::Data | PacketType::Ack | PacketType::Error) - ) { - debug!("Incorrect packet type received {:x?}", buffer); - continue; - } + let (sent, recv_next_client_to_send) = + send_data_blocks(&mut clients, next_client_to_send); + next_client_to_send = recv_next_client_to_send; + wait_control.sending(sent); - match Packet::from_bytes(&buffer) { - Ok(Packet::Data(p)) => { - let data_length = p.data.len(); + debug!( + "Sent {sent} next {recv_next_client_to_send} in {}", + sent_in.elapsed().as_secs_f32() + ); - debug!( - "Packet received block {} size {} total {} from {}", - p.block, data_length, connection.transfer, from_client - ); + #[cfg(feature = "alloc")] + buffer.resize(max_buffer_size as usize, 0); + // TODO heapless vector resizing is super slow + #[cfg(not(feature = "alloc"))] + unsafe { + buffer.set_len(max_buffer_size as usize) + }; - let mut write_elapsed = instant(); - match write_block(&mut connection, p.block, p.data) { - Ok(n) if n > 0 => { - connection.last_updated = instant(); - connection.transfer += n; - trace!( - "Block {} written in {}us", - p.block, - write_elapsed.elapsed().as_micros() - ); - } - Ok(_) => continue, - Err(e) => { - connection.send_packet(Packet::Error(e)); - entry.remove(); - continue; + let client_received_in = instant(); + + let client_received = clients.iter().skip(next_client_to_receive).find_map( + |(client_socket_addr, connection)| { + next_client_to_receive += 1; + if socket.notified(&connection.socket) { + match connection.recv(&mut buffer, None) { + Ok(b) => { + if let Err(_) = socket.modify_interest( + connection.socket.socket_id(), + connection.socket.as_raw_fd(), + ) { + warn!("Unable to modify epoll"); + } + Some((b, *client_socket_addr)) } + _ => None, } - - // this would write more than expected but only by a block size maximum - if let Err(e) = - handle_file_size(connection.transfer as u64, config.max_file_size) - { - connection.send_packet(Packet::Error(e)); - entry.remove(); - continue; - } + } else { + None } - Ok(Packet::Ack(p)) => { - let ClientType::Reader(ref mut block_reader): ClientType = connection.client_type else { - continue; - }; + }, + ); - debug!("Ack received {} {}", p.block, from_client); + debug!( + "Received from client {:?} next {next_client_to_receive} in {}", + client_received, + client_received_in.elapsed().as_secs_f32() + ); - if block_reader.free_block(p.block) > 0 { - connection.last_updated = instant(); - } - if block_reader.is_finished() { - info!("Client read {} finished", from_client); - entry.remove(); + let received_in = instant(); + + let (received_length, from_client) = match client_received { + Some(r) => r, + None => match socket.recv_from(&mut buffer, wait_control.wait_for(clients.len())) { + Ok((received, from_client)) => { + // ignore existing connection attemps on the main socket + if clients.contains_key(&from_client) { + next_client_to_receive = 0; continue; } + trace!("New connection from {from_client} next {next_client_to_receive}"); + next_client_to_receive = 0; + (received, from_client) } - Ok(Packet::Error(p)) => { - error!("Error received {:?} {}", p.code, p.message); - entry.remove(); + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { + next_client_to_receive = 0; + wait_control.receiver_idle(); continue; } - _ => { - debug!("Incorrect packet received {:x?}", buffer); + Err(e) => return Err(e.into()), + }, + }; + + debug!( + "Received connection from {from_client} in {}", + received_in.elapsed().as_secs_f32(), + ); + + wait_control.receiving(); + buffer.truncate(received_length); + + let clients_len = clients.len(); + let processed_in = instant(); + + match clients.entry(from_client) { + Entry::Occupied(mut entry) => { + let mut connection = entry.get_mut(); + match &connection.client_type { + ClientType::Reader(r) => handle_read(&mut connection, &mut buffer, instant), + ClientType::Writer(r) => handle_write( + &mut connection, + &mut buffer, + instant, + config.max_file_size, + ), + }; + } + Entry::Vacant(entry) => { + if clients_len >= config.max_connections as usize { + info!( + "Max connections {} reached. Ignoring connection from {}", + config.max_connections, from_client + ); continue; } + let Some(connection) = create_new_connection(&config, &mut socket_id, &mut buffer, &socket, from_client, &create_reader, &create_writer, &create_bound_socket, instant, rng) else { + continue; }; + entry.insert(connection); + } } - Entry::Vacant(entry) => { - if client_length >= config.max_connections as usize { - error!( - "Max connections {} reached. Ignoring connection from {}", - config.max_connections, from_client - ); - continue; + + debug!( + "Processed connection from {from_client} in {}", + processed_in.elapsed().as_secs_f32(), + ); + } + } + + #[cfg(all(feature = "std", feature = "multi_thread"))] + { + let mut handles = Handles::new(); + + loop { + #[cfg(feature = "alloc")] + buffer.resize(max_buffer_size, 0); + // TODO heapless vector resizing is super slow + #[cfg(not(feature = "alloc"))] + unsafe { + buffer.set_len(max_buffer_size) + }; + + let received_in = instant(); + let (received_length, from_client) = + match socket.recv_from(&mut buffer, Duration::from_secs(1).into()) { + Ok(connection_received) => connection_received, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { + continue; + } + Err(e) => return Err(e.into()), + }; + + debug!( + "Received connection from {from_client} in {} exists {}", + received_in.elapsed().as_secs_f32(), + handles.contains_key(&from_client), + ); + + if handles.contains_key(&from_client) { + continue; + } + + buffer.truncate(received_length); + + if handles.len() >= config.max_connections as usize { + info!( + "Max connections {} reached. Ignoring connection from {}", + config.max_connections, from_client + ); + continue; + } + + let Some(connection) = create_new_connection(&config, &mut socket_id, &mut buffer, &socket, from_client, &create_reader, &create_writer, &create_bound_socket, instant, rng) else { + continue; + }; + let handle = match &connection.client_type { + ClientType::Reader(r) => spawn_reader(connection, instant, config.request_timeout), + ClientType::Writer(r) => spawn_writer( + connection, + instant, + config.request_timeout, + config.max_file_size, + ), + }; + handles.insert(from_client, handle); + handles.retain(|_, t| !t.is_finished()); + } + } +} + +#[allow(clippy::too_many_arguments)] +fn create_new_connection( + config: &ServerConfig, + socket_id: &mut NonZeroU32, + buffer: &mut DataBuffer, + socket: &S, + from_client: SocketAddr, + create_reader: &CreateReader, + create_writer: &CreateWriter, + create_bound_socket: &CreateBoundSocket, + instant: InstantCallback, + mut rng: Rng, +) -> Option> +where + S: Socket, + B: BoundSocket + ToSocketId, + Rng: CryptoRng + RngCore + Copy, + R: Read + Seek, + CreateBoundSocket: Fn(&str, usize, SocketAddr) -> BoxedResult, + CreateReader: Fn(&FilePath, &ServerConfig) -> BoxedResult<(Option, R)>, + W: Write + Seek, + CreateWriter: Fn(&FilePath, &ServerConfig) -> BoxedResult, +{ + let mut builder = match ConnectionBuilder::from_new_connection(config, buffer, rng) { + Ok(b) => b, + Err(e) => { + debug!("New connection error {}", e); + return None; + } + }; + + match Packet::from_bytes(&buffer) { + Ok(Packet::Write(p)) => { + debug!( + "New client {from_client} writing to file {} in directory {}", + p.file_name, config.directory + ); + + let Ok(()) = builder.with_request(p, config.max_window_size, rng) else { + return None; + }; + + *socket_id = socket_id + .checked_add(1) + .or_else(|| NonZeroU32::new(1)) + .expect("Socket id expected"); + + let Ok((mut connection, used_extensions, encrypt_new_connection)): Result<(Connection<_, W, _>, _, _), _> = + builder.build_writer(socket, from_client, &create_writer, &create_bound_socket, instant, socket_id.get() as usize) + else { + return None; + }; + + if !used_extensions.is_empty() { + if !connection.send_packet(Packet::OptionalAck(OptionalAck { + extensions: used_extensions, + })) { + return None; } + } else if !connection.send_packet(Packet::Ack(AckPacket { block: 0 })) { + return None; + } + // new encryption starts only here + if let Some(keys) = encrypt_new_connection { + connection.encryptor = keys.encryptor.into(); + } - let mut builder = - match ConnectionBuilder::from_new_connection(&config, &mut buffer, rng) { - Ok(b) => b, - Err(e) => { - debug!("New connection error {}", e); - continue; - } - }; + print_options("Server writing using", &connection.options); - if !matches!( - PacketType::from_bytes(&buffer), - Ok(PacketType::Write | PacketType::Read) - ) { - debug!("Incorrect packet type received {:x?}", buffer); - continue; + connection.into() + } + Ok(Packet::Read(p)) => { + debug!( + "New client {from_client} reading file {} in directory {}", + p.file_name, config.directory + ); + + let Ok(()) = builder.with_request(p, config.max_window_size, rng) else { + return None; + }; + *socket_id = socket_id + .checked_add(1) + .or_else(|| NonZeroU32::new(1)) + .expect("Socket id expected"); + + let Ok((mut connection, used_extensions, encrypt_new_connection)): Result<(Connection, _, _), _> = + builder.build_reader(socket, from_client, &create_reader, &create_bound_socket, instant, socket_id.get() as usize) else { + return None; + }; + + if !used_extensions.is_empty() { + if !connection.send_packet(Packet::OptionalAck(OptionalAck { + extensions: used_extensions, + })) { + return None; } + } - debug!("Received from new client {}", from_client); + // new encryption starts only here + if let Some(keys) = encrypt_new_connection { + connection.encryptor = keys.encryptor.into(); + } - match Packet::from_bytes(&buffer) { - Ok(Packet::Write(p)) => { - debug!( - "New client writing to file {} in directory {}", - p.file_name, config.directory - ); + print_options("Server reading using", &connection.options); - let Ok(()) = builder.with_request(p, config.max_window_size, rng) else { - continue; - }; + connection.into() + } + _ => { + debug!("Incorrect packet received {:x?}", buffer); + None + } + } +} - let Ok((mut connection, used_extensions, encrypt_new_connection)): Result<(Connection<_, W, _>, _, _), _> = - builder.build_writer(&socket, from_client, &create_writer, &create_socket, instant) - else { - continue; - }; +#[cfg(all(feature = "std", feature = "multi_thread"))] +fn spawn_reader< + R: Read + Seek + Send + 'static, + W: Write + Seek + Send + 'static, + B: BoundSocket + Send + 'static, +>( + mut connection: Connection, + instant: InstantCallback, + request_timeout: Duration, +) -> std::thread::JoinHandle<()> { + use crate::config::ENCRYPTION_TAG_SIZE; + + std::thread::spawn(move || { + let mut buffer = create_max_buffer( + connection.options.block_size + + if connection.options.is_encrypting() { + ENCRYPTION_TAG_SIZE as u16 + } else { + 0 + }, + ); + let mut wait_control = WaitControl::new(); + let max_buffer_size = buffer.len(); + loop { + if timeout_client(&mut connection, request_timeout) { + return; + } + let sent = send_data_block(&mut connection); + wait_control.sending(sent); + + #[cfg(feature = "alloc")] + buffer.resize(max_buffer_size, 0); + // TODO heapless vector resizing is super slow + #[cfg(not(feature = "alloc"))] + unsafe { + buffer.set_len(max_buffer_size) + }; - if !used_extensions.is_empty() { - if !connection.send_packet(Packet::OptionalAck(OptionalAck { - extensions: used_extensions, - })) { - continue; - } - } else if !connection.send_packet(Packet::Ack(AckPacket { block: 0 })) { - continue; - } - // // new encryption starts only here - if let Some(keys) = encrypt_new_connection { - connection.encryptor = keys.encryptor.into(); - } + let received_length = match connection.recv(&mut buffer, wait_control.wait_for(1)) { + Ok(connection_received) => connection_received, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { + wait_control.receiver_idle(); + continue; + } + Err(e) => return, + }; + wait_control.receiving(); + buffer.truncate(received_length); + handle_read(&mut connection, &mut buffer, instant); + } + }) +} - print_options("Server writing using", &connection.options); +#[cfg(all(feature = "std", feature = "multi_thread"))] +fn spawn_writer< + R: Send + 'static, + W: Write + Seek + Send + 'static, + B: BoundSocket + Send + 'static, +>( + mut connection: Connection, + instant: InstantCallback, + request_timeout: Duration, + max_file_size: u64, +) -> std::thread::JoinHandle<()> { + use crate::config::ENCRYPTION_TAG_SIZE; + + std::thread::spawn(move || { + let mut buffer = create_max_buffer( + connection.options.block_size + + if connection.options.is_encrypting() { + ENCRYPTION_TAG_SIZE as u16 + } else { + 0 + }, + ); + let max_buffer_size = buffer.len(); + loop { + if timeout_client(&mut connection, request_timeout) { + return; + } + #[cfg(feature = "alloc")] + buffer.resize(max_buffer_size, 0); + // TODO heapless vector resizing is super slow + #[cfg(not(feature = "alloc"))] + unsafe { + buffer.set_len(max_buffer_size) + }; - entry.insert(connection); - } - Ok(Packet::Read(p)) => { - debug!( - "New client reading file {} in directory {}", - p.file_name, config.directory - ); + let received_length = match connection.recv(&mut buffer, Duration::from_secs(1).into()) + { + Ok(connection_received) => connection_received, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { + continue; + } + Err(e) => return, + }; + buffer.truncate(received_length); + handle_write(&mut connection, &mut buffer, instant, max_file_size); + } + }) +} - let Ok(()) = builder.with_request(p, config.max_window_size, rng) else { - continue; - }; +fn create_max_buffer(max_block_size: u16) -> DataBuffer { + let max_buffer_size = max( + max_block_size + DATA_PACKET_HEADER_SIZE as u16, + MIN_BUFFER_SIZE, + ); + assert!(max_buffer_size <= MAX_BUFFER_SIZE); + #[allow(unused_must_use)] + let mut buffer = { + let mut d = DataBuffer::new(); + d.resize(max_buffer_size as usize, 0); + d + }; + buffer +} - let Ok((mut connection, used_extensions, encrypt_new_connection)): Result<(Connection, _, _), _> = - builder.build_reader(&socket, from_client, &create_reader, &create_socket, instant) else { - continue; - }; +fn handle_write( + connection: &mut Connection, + mut buffer: &mut DataBuffer, + instant: InstantCallback, + max_file_size: u64, +) -> Option<()> { + if !connection.decrypt_packet(buffer) { + return None; + } - if !used_extensions.is_empty() { - if !connection.send_packet(Packet::OptionalAck(OptionalAck { - extensions: used_extensions, - })) { - continue; - } - } + let packet_type = PacketType::from_bytes(buffer); + if !matches!( + packet_type, + Ok(PacketType::Data | PacketType::Ack | PacketType::Error) + ) { + debug!( + "Incorrect packet type received from {} {} {:x?}", + connection.endpoint, + buffer.len(), + buffer, + ); + return None; + } - // new encryption starts only here - if let Some(keys) = encrypt_new_connection { - connection.encryptor = keys.encryptor.into(); - } + match Packet::from_bytes(buffer) { + Ok(Packet::Data(p)) => { + let data_length = p.data.len(); - print_options("Server reading using", &connection.options); + debug!( + "Packet received block {} size {} total {} from {}", + p.block, data_length, connection.transfer, connection.endpoint + ); - entry.insert(connection); - } - _ => { - debug!("Incorrect packet received {:x?}", buffer); - continue; - } - }; + let mut write_elapsed = instant(); + match write_block(connection, p.block, p.data) { + Ok(n) if n > 0 => { + connection.last_updated = instant(); + connection.transfer += n; + trace!( + "Block {} written in {}us", + p.block, + write_elapsed.elapsed().as_micros() + ); + } + Ok(_) => return None, + Err(e) => { + connection.send_packet(Packet::Error(e)); + connection.invalid = true; + return None; + } } + + // this would write more than expected but only by a block size maximum + if let Err(e) = handle_file_size(connection.transfer as u64, max_file_size) { + connection.send_packet(Packet::Error(e)); + connection.invalid = true; + return None; + } + } + Ok(Packet::Ack(_)) => { + return None; + } + Ok(Packet::Error(p)) => { + error!("Error received {:?} {}", p.code, p.message); + connection.invalid = true; + return None; } + _ => { + debug!( + "Incorrect packet received from {} {:x?}", + connection.endpoint, buffer + ); + return None; + } + }; + Some(()) +} + +fn handle_read( + connection: &mut Connection, + mut buffer: &mut DataBuffer, + instant: InstantCallback, +) -> Option<()> { + if !connection.decrypt_packet(buffer) { + return None; } + + let packet_type = PacketType::from_bytes(buffer); + if !matches!(packet_type, Ok(PacketType::Ack | PacketType::Error)) { + debug!( + "Incorrect packet type received from {} {:x?}", + connection.endpoint, buffer, + ); + return None; + } + + match Packet::from_bytes(buffer) { + Ok(Packet::Ack(p)) => { + let ClientType::Reader(ref mut block_reader): ClientType = connection.client_type else { + return None; + }; + + debug!("Ack received {} {}", p.block, connection.endpoint); + + if block_reader.free_block(p.block) > 0 { + connection.last_updated = instant(); + } + if block_reader.is_finished() { + info!("Client read {} finished", connection.endpoint); + connection.invalid = true; + return None; + } + } + Ok(Packet::Error(p)) => { + error!("Error received {:?} {}", p.code, p.message); + connection.invalid = true; + return None; + } + _ => { + debug!( + "Incorrect packet received from {} {:x?}", + connection.endpoint, buffer + ); + return None; + } + }; + Some(()) } -fn write_block( - connection: &mut Connection, +fn write_block( + connection: &mut Connection, mut block: u16, data: &[u8], ) -> Result where - S: Socket, + B: BoundSocket, { let ClientType::Writer(ref mut block_writer): ClientType<_, W> = connection.client_type else { return Ok(0); @@ -432,75 +728,145 @@ where Ok(length) } -fn timeout_clients( - clients: &mut Clients, +fn timeout_client( + connection: &mut Connection, request_timeout: Duration, -) { - clients.retain(|client, connection| { - let client_type = match connection.client_type { - ClientType::Writer(ref w) => { - if w.is_finished_below(connection.options.block_size) { - info!("Client write finished {}", client); - return false; - } - "write" +) -> bool { + let client_type = match connection.client_type { + ClientType::Writer(ref w) => { + if w.is_finished_below(connection.options.block_size) { + info!("Client write finished {}", connection.endpoint); + return true; } - ClientType::Reader(_) => "read", - }; - if connection.last_updated.elapsed() <= request_timeout { - return true; + "write" } + ClientType::Reader(_) => "read", + }; + if connection.invalid { + return true; + } + if connection.last_updated.elapsed() <= request_timeout { + return false; + } - warn!( - "Client {} timeout {} {}", - client_type, - client, - connection.last_updated.elapsed().as_secs_f32() - ); + warn!( + "Client {} timeout {} {}", + client_type, + connection.endpoint, + connection.last_updated.elapsed().as_secs_f32() + ); - let message = format_str!( - DefaultString, - "Client timeout {}", - connection.last_updated.elapsed().as_secs_f32() - ); - connection.send_packet(Packet::Error(ErrorPacket::new( - ErrorCode::AccessVioliation, - message, - ))); - false - }); + let message = format_str!( + DefaultString, + "Client timeout {}", + connection.last_updated.elapsed().as_secs_f32() + ); + connection.send_packet(Packet::Error(ErrorPacket::new( + ErrorCode::AccessVioliation, + message, + ))); + true } -fn send_data_blocks(clients: &mut Clients) -> usize { - let mut sent = 0; - clients.retain(|_, connection| { - for _ in 0..1 { - let block_reader = match &mut connection.client_type { - ClientType::Reader(r) => r, - ClientType::Writer(_) => return true, - }; +fn send_data_block( + connection: &mut Connection, +) -> bool { + let block_reader = match &mut connection.client_type { + ClientType::Reader(r) => r, + ClientType::Writer(_) => return false, + }; - let packet_block = match block_reader.next() { - Ok(Some(b)) => b, - Ok(None) => return true, - Err(e) => { - error!("Failed to read {}", e); - connection.send_packet(Packet::Error(ErrorPacket::new( - ErrorCode::AccessVioliation, - format_str!(DefaultString, "{}", e), - ))); - return false; - } - }; - let packet_sent = connection.send_packet(Packet::Data(DataPacket { - block: packet_block.block, - data: &packet_block.data, - })); - if packet_sent { - sent += 1; - } + let packet_block = match block_reader.next(connection.options.retry_packet_after_timeout) { + Ok(Some(b)) => b, + Ok(None) => return false, + Err(e) => { + error!("Failed to read {}", e); + connection.send_packet(Packet::Error(ErrorPacket::new( + ErrorCode::AccessVioliation, + format_str!(DefaultString, "{}", e), + ))); + connection.invalid = true; + return false; } - true - }); - sent + }; + connection.send_packet(Packet::Data(DataPacket { + block: packet_block.block, + data: &packet_block.data, + })) +} + +#[cfg(not(feature = "multi_thread"))] +fn send_data_blocks( + clients: &mut Clients, + next_client: usize, +) -> (bool, usize) { + let mut current_client: Option = clients + .iter_mut() + .filter(|(_, connection)| { + matches!(connection.client_type, ClientType::Reader(_)) && !connection.invalid + }) + .enumerate() + .skip(next_client) + .find_map(|(index, (s, c))| send_data_block(c).then(|| index + 1)); + if current_client.is_none() { + current_client = clients + .iter_mut() + .filter(|(_, connection)| { + matches!(connection.client_type, ClientType::Reader(_)) && !connection.invalid + }) + .take(next_client) + .enumerate() + .find_map(|(index, (s, c))| send_data_block(c).then(|| index + 1)); + } + + ( + current_client.is_some(), + current_client.unwrap_or(next_client), + ) +} + +struct WaitControl { + idle: u8, + sending: bool, + receiving: bool, +} + +impl WaitControl { + fn new() -> Self { + Self { + idle: 0, + sending: false, + receiving: false, + } + } + + fn sending(&mut self, sent: bool) { + if sent { + self.idle = 0; + self.sending = true; + } else { + self.idle = self.idle.wrapping_add(1); + self.sending = false; + } + } + + fn receiver_idle(&mut self) { + self.idle = self.idle.wrapping_add(1); + self.receiving = false; + } + + fn receiving(&mut self) { + self.idle = 0; + self.receiving = true; + } + + fn wait_for(&self, client_size: usize) -> Option { + if client_size == 0 { + Duration::from_millis(500).into() + } else if !self.sending && !self.receiving { + Duration::from_millis(self.idle as u64).into() + } else { + None + } + } } diff --git a/src/server/validation.rs b/src/server/validation.rs index 1cea93f..7ee0645 100644 --- a/src/server/validation.rs +++ b/src/server/validation.rs @@ -57,7 +57,7 @@ pub fn validate_request_options( Err(e) => { let packet = Packet::Error(e); socket.send_to(&mut packet.to_bytes(), client)?; - return Err(FileError::InvalidFileName.into()); + Err(FileError::InvalidFileName.into()) } } } diff --git a/src/socket.rs b/src/socket.rs index 30d9cc9..2a635e3 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -10,19 +10,36 @@ use crate::{ types::DataBuffer, }; -pub trait Socket { +pub trait Socket: ToSocketId { fn recv_from( - &self, + &mut self, buf: &mut DataBuffer, wait_for: Option, ) -> Result<(usize, SocketAddr)>; fn send_to(&self, buf: &mut DataBuffer, addr: SocketAddr) -> Result; - fn try_clone(&self) -> Result - where - Self: Sized; + fn local_addr(&self) -> Result; + + fn notified(&self, to_socket_id: &impl ToSocketId) -> bool; + fn add_interest(&self, to_socket_id: &impl ToSocketId) -> Result<()>; + fn modify_interest(&mut self, socket_id: usize, raw_fd: SocketRawFd) -> Result<()>; +} + +pub trait BoundSocket: ToSocketId { + fn recv(&self, buff: &mut DataBuffer, wait_for: Option) -> Result; + fn send(&self, buff: &mut DataBuffer) -> Result; fn local_addr(&self) -> Result; } +pub trait ToSocketId { + fn as_raw_fd(&self) -> SocketRawFd; + fn socket_id(&self) -> usize; +} + +#[cfg(target_family = "windows")] +pub type SocketRawFd = u64; +#[cfg(not(target_family = "windows"))] +pub type SocketRawFd = i32; + #[cfg(feature = "encryption")] pub struct EncryptionBoundSocket { pub socket: S, @@ -63,13 +80,12 @@ where S: Socket, { fn recv_from( - &self, + &mut self, buff: &mut DataBuffer, wait_for: Option, ) -> Result<(usize, SocketAddr)> { let (received_length, s) = self.socket.recv_from(buff, wait_for)?; buff.truncate(received_length); - log::trace!("Received data {:x?}", buff); match (self.encryption_level, &self.encryptor) { (EncryptionLevel::Protocol | EncryptionLevel::Full, Some(encryptor)) => { encryptor @@ -87,7 +103,6 @@ where fn send_to(&self, buff: &mut DataBuffer, endpoint: SocketAddr) -> Result { use crate::packet::PacketType; - log::trace!("Send data {:x?}", buff); match (self.encryption_level, &self.encryptor) { (EncryptionLevel::Protocol | EncryptionLevel::Full, Some(encryptor)) => { let packet_type = PacketType::from_bytes(buff); @@ -114,14 +129,33 @@ where self.socket.send_to(buff, endpoint) } - fn try_clone(&self) -> Result - where - Self: Sized, - { + fn local_addr(&self) -> Result { + self.socket.local_addr() + } + + fn add_interest(&self, to_socket_id: &impl ToSocketId) -> Result<()> { unimplemented!() } - fn local_addr(&self) -> Result { - self.socket.local_addr() + fn modify_interest(&mut self, socket_id: usize, raw_fd: SocketRawFd) -> Result<()> { + unimplemented!() + } + + fn notified(&self, to_socket_id: &impl ToSocketId) -> bool { + unimplemented!() + } +} + +#[cfg(feature = "encryption")] +impl ToSocketId for EncryptionBoundSocket +where + S: ToSocketId, +{ + fn as_raw_fd(&self) -> SocketRawFd { + unimplemented!() + } + + fn socket_id(&self) -> usize { + unimplemented!() } } diff --git a/src/storage.rs b/src/storage.rs index be5f6e1..cc11127 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -97,7 +97,10 @@ where } let index = self.block_mapper.index(block); if index <= self.current_block_written { - return Err(StorageError::AlreadyWriten); + return Err(StorageError::ExpectedBlock(( + self.block_mapper.block(self.current_block_written + 1), + self.block_mapper.block(self.current_block_written), + ))); } if self.window_size > 1 && index > self.current_block_written + 1 { @@ -197,7 +200,7 @@ where /// Since not all packet are received/acknowledged this serves as a repeater pub trait BlockReader { - fn next(&mut self) -> Result, StorageError>; + fn next(&mut self, retry_timeout: Duration) -> Result, StorageError>; /// release block returning data size released fn free_block(&mut self, block: u16) -> usize; fn is_finished(&self) -> bool; @@ -211,7 +214,6 @@ pub struct FileReader { file_reading_finished: bool, current_block_read: u64, block_size: u16, - retry_timeout: Duration, instant: InstantCallback, block_mapper: BlockMapper, start_window: u64, @@ -228,7 +230,6 @@ where reader: T, max_blocks_in_memory: u16, block_size: u16, - retry_timeout: Duration, instant: InstantCallback, window_size: u16, ) -> Self { @@ -251,7 +252,6 @@ where file_reading_finished: false, current_block_read: 0, block_size, - retry_timeout, instant, block_mapper: BlockMapper::new(), start_window: 1, @@ -262,7 +262,7 @@ where } fn read_block(&mut self, block: u64) -> Result, StorageError> { - let position = (block as u64 - 1) * self.block_size as u64; + let position = (block - 1) * self.block_size as u64; #[cfg(feature = "seek")] if matches!(self.reader.seek(SeekFrom::Current(0)), Ok(p) if p != position) { self.reader.seek(SeekFrom::Start(position))?; @@ -283,13 +283,13 @@ impl BlockReader for FileReader where T: Read + Seek, { - fn next(&mut self) -> Result, StorageError> { + fn next(&mut self, retry_timeout: Duration) -> Result, StorageError> { if self.is_finished() { return Ok(None); } if self.window_size > 1 { - if self.window_last_read.elapsed() > self.retry_timeout { + if self.window_last_read.elapsed() > retry_timeout { self.window_last_read = (self.instant)(); self.current_block_read = self.start_window - 1; } @@ -301,7 +301,7 @@ where let next = self .blocks .iter_mut() - .find(|(_, t)| t.last_read.elapsed() >= self.retry_timeout) + .find(|(_, t)| t.last_read.elapsed() >= retry_timeout) .map(|(b, t)| (b, t)); if next.is_some() @@ -338,7 +338,6 @@ where return Ok(None); }; } - let earliest = self .blocks .iter() @@ -416,17 +415,15 @@ where if self.window_size <= 1 { size += self.blocks.remove(&index).map(|t| t.size).unwrap_or(0); - } else { - if index <= self.end_window && index >= self.start_window { - for b in self.start_window..=index { - size += self.blocks.remove(&b).map(|t| t.size).unwrap_or(0); - } - self.start_window = index + 1; - self.current_block_read = index; - if !self.file_reading_finished { - self.end_window = self.current_block_read + self.window_size as u64; - } - }; + } else if index <= self.end_window && index >= self.start_window { + for b in self.start_window..=index { + size += self.blocks.remove(&b).map(|t| t.size).unwrap_or(0); + } + self.start_window = index + 1; + self.current_block_read = index; + if !self.file_reading_finished { + self.end_window = self.current_block_read + self.window_size as u64; + } } size } @@ -458,7 +455,7 @@ impl BlockMapper { if block < 10000 { let next_block = self.next_block_set - 1; return (next_block * u16::MAX as u64) + block as u64 + next_block; - } else if block >= 10000 && block < 20000 { + } else if (10000..20000).contains(&block) { self.current_block_set += 1; } } @@ -541,9 +538,8 @@ mod tests { block_writer.write_block(1, &random_bytes[..20]).unwrap(); let result = block_writer.write_block(1, &random_bytes[..20]); assert!( - matches!(result, Err(StorageError::AlreadyWriten)), - "{:?}", - result + matches!(result, Err(StorageError::ExpectedBlock((2, 1)))), + "{result:?}", ); assert!(!block_writer.is_finished_below(20)); block_writer.write_block(3, &random_bytes[40..60]).unwrap(); @@ -552,8 +548,7 @@ mod tests { let result = block_writer.write_block(6, &random_bytes[100..102]); assert!( matches!(result, Err(StorageError::CapacityReached)), - "{:?}", - result + "{result:?}", ); assert!(!block_writer.is_finished_below(20)); block_writer.write_block(2, &random_bytes[20..40]).unwrap(); @@ -579,8 +574,7 @@ mod tests { let result = block_writer.write_block(5, &random_bytes[..20]); assert!( matches!(result, Err(StorageError::CapacityReached)), - "{:?}", - result + "{result:?}", ); block_writer.write_block(1, &random_bytes[..20]).unwrap(); block_writer.write_block(4, &random_bytes[..20]).unwrap(); @@ -588,29 +582,25 @@ mod tests { let result = block_writer.write_block(5, &random_bytes[..20]); assert!( matches!(result, Err(StorageError::AlreadyWriten)), - "{:?}", - result + "{result:?}", ); let result = block_writer.write_block(6, &random_bytes[..20]); assert!( matches!(result, Err(StorageError::CapacityReached)), - "{:?}", - result + "{result:?}", ); let result = block_writer.write_block(9, &random_bytes[..20]); assert!( matches!(result, Err(StorageError::CapacityReached)), - "{:?}", - result + "{result:?}", ); block_writer.write_block(2, &random_bytes[..20]).unwrap(); block_writer.write_block(6, &random_bytes[..2]).unwrap(); assert!(!block_writer.is_finished_below(20)); let result = block_writer.write_block(2, &random_bytes[..20]); assert!( - matches!(result, Err(StorageError::AlreadyWriten)), - "{:?}", - result + matches!(result, Err(StorageError::ExpectedBlock(_))), + "{result:?}", ); block_writer.write_block(3, &random_bytes[..20]).unwrap(); assert!(block_writer.is_finished_below(20)); @@ -629,22 +619,19 @@ mod tests { assert!(!last_in_windown); let result = block_writer.write_block(1, &random_bytes[..20]); assert!( - matches!(result, Err(StorageError::AlreadyWriten)), - "{:?}", - result + matches!(result, Err(StorageError::ExpectedBlock(_))), + "{result:?}", ); assert!(!block_writer.is_finished_below(20)); let result = block_writer.write_block(3, &random_bytes[40..60]); assert!( matches!(result, Err(StorageError::ExpectedBlock(_))), - "{:?}", - result + "{result:?}", ); let result = block_writer.write_block(6, &random_bytes[100..102]); assert!( matches!(result, Err(StorageError::CapacityReached)), - "{:?}", - result + "{result:?}", ); assert!(!block_writer.is_finished_below(20)); @@ -673,9 +660,7 @@ mod tests { fn test_block_write_window_size_1_packet() { let random_bytes: Vec = (0..102).map(|_| rand::random::()).collect(); let cursor = Arc::new(Mutex::new(Cursor::new(vec![]))); - let writer = CursorWriter { - cursor: cursor.clone(), - }; + let writer = CursorWriter { cursor }; let mut block_writer = FileWriter::from_writer(writer, 20, 0, 3); let (s, last_in_windown) = block_writer.write_block(1, &random_bytes[..16]).unwrap(); assert_eq!(s, 16); @@ -691,42 +676,36 @@ mod tests { let inner_reader = CursorReader { cursor: inner_reader, }; - let mut block_reader = FileReader::from_reader( - inner_reader, - 2, - 20, - Duration::from_millis(100), - instant_callback, - 1, - ); + let retry_timeout = Duration::from_millis(100); + let mut block_reader = FileReader::from_reader(inner_reader, 2, 20, instant_callback, 1); //can read upto maximum blocks - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 1); assert_eq!(&block.data, &random_bytes[0..20]); sleep(Duration::from_millis(20)); - let result = block_reader.next().unwrap(); + let result = block_reader.next(retry_timeout).unwrap(); assert_eq!(result.unwrap().block, 2); - let result = block_reader.next().unwrap(); + let result = block_reader.next(retry_timeout).unwrap(); assert!(result.is_none()); // retry reading last blocks sleep(Duration::from_millis(101)); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 1); assert_eq!(&block.data, &random_bytes[0..20]); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 2); assert_eq!(&block.data, &random_bytes[20..40]); - let result = block_reader.next().unwrap(); + let result = block_reader.next(retry_timeout).unwrap(); assert!(result.is_none()); // can read more blocks after free let size = block_reader.free_block(1); assert_eq!(size, 20); - let result = block_reader.next().unwrap(); + let result = block_reader.next(retry_timeout).unwrap(); assert_eq!(result.unwrap().block, 3); - let result = block_reader.next().unwrap(); + let result = block_reader.next(retry_timeout).unwrap(); assert!(result.is_none()); let size = block_reader.free_block(2); @@ -735,50 +714,45 @@ mod tests { assert_eq!(size, 20); let size = block_reader.free_block(4); assert_eq!(size, 0); - let result = block_reader.next().unwrap(); + let result = block_reader.next(retry_timeout).unwrap(); assert_eq!(result.unwrap().block, 4); - let result = block_reader.next().unwrap(); + let result = block_reader.next(retry_timeout).unwrap(); assert_eq!(result.unwrap().block, 5); let size = block_reader.free_block(5); assert_eq!(size, 20); // last block is empty - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 6); assert_eq!(block.data, []); - let block = block_reader.next().unwrap(); + let block = block_reader.next(retry_timeout).unwrap(); assert!(block.is_none()); let size = block_reader.free_block(6); assert_eq!(size, 0); // its not finished until all blocks are freed - assert!(!block_reader.is_finished(), "{:?}", block_reader); - let result = block_reader.next().unwrap(); + assert!(!block_reader.is_finished(), "{block_reader:?}"); + let result = block_reader.next(retry_timeout).unwrap(); assert!(result.is_none()); let size = block_reader.free_block(4); assert_eq!(size, 20); - assert!(block_reader.is_finished(), "{:?}", block_reader); - let result = block_reader.next().unwrap(); + assert!(block_reader.is_finished(), "{block_reader:?}"); + let result = block_reader.next(retry_timeout).unwrap(); assert!(result.is_none()); } #[test] fn test_block_read_window_size() { let random_bytes: Vec = (0..100).map(|_| rand::random::()).collect(); - let inner_reader = Cursor::new(random_bytes.clone()); + let inner_reader = Cursor::new(random_bytes); #[cfg(not(feature = "std"))] let inner_reader = CursorReader { cursor: inner_reader, }; - let mut block_reader = FileReader::from_reader( - inner_reader, - 2, - 20, - Duration::from_millis(100), - instant_callback, - 4, - ); + let retry_timeout = Duration::from_millis(100); + + let mut block_reader = FileReader::from_reader(inner_reader, 2, 20, instant_callback, 4); let size = block_reader.free_block(1); assert_eq!(size, 0); @@ -786,7 +760,7 @@ mod tests { assert_eq!(size, 0); // can free first block - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 1); let size = block_reader.free_block(0); assert_eq!(size, 0); @@ -794,101 +768,95 @@ mod tests { assert_eq!(size, 20); // can free multiple blocks which are less than a windown size - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 2); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 3); let size = block_reader.free_block(3); assert_eq!(size, 40); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 4); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 5); let size = block_reader.free_block(6); assert_eq!(size, 0); let size = block_reader.free_block(5); assert_eq!(size, 40); - assert!(!block_reader.is_finished(), "{:?}", block_reader); + assert!(!block_reader.is_finished(), "{block_reader:?}"); // free last empty block - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 6); assert_eq!(block.data, []); - assert!(!block_reader.is_finished(), "{:?}", block_reader); - let block = block_reader.next().unwrap(); + assert!(!block_reader.is_finished(), "{block_reader:?}"); + let block = block_reader.next(retry_timeout).unwrap(); assert!(block.is_none()); let size = block_reader.free_block(6); assert_eq!(size, 0); - assert!(block_reader.is_finished(), "{:?}", block_reader); - let block = block_reader.next().unwrap(); + assert!(block_reader.is_finished(), "{block_reader:?}"); + let block = block_reader.next(retry_timeout).unwrap(); assert!(block.is_none()); } #[test] fn test_block_read_window_size_full() { let random_bytes: Vec = (0..100).map(|_| rand::random::()).collect(); - let inner_reader = Cursor::new(random_bytes.clone()); + let inner_reader = Cursor::new(random_bytes); #[cfg(not(feature = "std"))] let inner_reader = CursorReader { cursor: inner_reader, }; - let mut block_reader = FileReader::from_reader( - inner_reader, - 2, - 20, - Duration::from_millis(100), - instant_callback, - 3, - ); + let retry_timeout = Duration::from_millis(100); + let mut block_reader = FileReader::from_reader(inner_reader, 2, 20, instant_callback, 3); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 1); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 2); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 3); - let result = block_reader.next().unwrap(); + let result = block_reader.next(retry_timeout).unwrap(); assert!(result.is_none()); let size = block_reader.free_block(3); assert_eq!(size, 60); - assert!(!block_reader.is_finished(), "{:?}", block_reader); + assert!(!block_reader.is_finished(), "{block_reader:?}"); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 4); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 5); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 6); - let result = block_reader.next().unwrap(); + let result = block_reader.next(retry_timeout).unwrap(); assert!(result.is_none()); // read window again after timeout sleep(Duration::from_millis(101)); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 4); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 5); - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 6); - let result = block_reader.next().unwrap(); + let result = block_reader.next(retry_timeout).unwrap(); assert!(result.is_none()); - assert!(!block_reader.is_finished(), "{:?}", block_reader); + assert!(!block_reader.is_finished(), "{block_reader:?}"); let size = block_reader.free_block(3); - assert_eq!(size, 0, "{:?}", block_reader); + assert_eq!(size, 0, "{block_reader:?}"); let size = block_reader.free_block(4); - assert_eq!(size, 20, "{:?}", block_reader); + assert_eq!(size, 20, "{block_reader:?}"); // block starts from next - let block = block_reader.next().unwrap().unwrap(); + let block = block_reader.next(retry_timeout).unwrap().unwrap(); assert_eq!(block.block, 5); let size = block_reader.free_block(6); - assert_eq!(size, 20, "{:?}", block_reader); - assert!(block_reader.is_finished(), "{:?}", block_reader); + assert_eq!(size, 20, "{block_reader:?}"); + assert!(block_reader.is_finished(), "{block_reader:?}"); } #[test] diff --git a/test b/test index ce72562..bebdb24 100755 --- a/test +++ b/test @@ -88,10 +88,10 @@ test_high_load() { head -c 1M /dev/urandom > $TEST_DIR/server/samples/$i done; for ((i=1;i<=size;i++)); do - $client send 127.0.0.1:9000 $TEST_DIR/client/samples/$i --window-size $window_size & pids+=($!) + $client send 127.0.0.1:9000 $TEST_DIR/client/samples/$i --listen 127.0.0.1:0 --window-size $window_size & pids+=($!) done for ((i=size+1;i<=max_size;i++)); do - $client receive 127.0.0.1:9000 $i --window-size $window_size --local-path $TEST_DIR/delete/$i & pids+=($!) + $client receive 127.0.0.1:9000 $i --window-size $window_size --listen 127.0.0.1:0 --local-path $TEST_DIR/delete/$i & pids+=($!) done time wait "${pids[@]}"