From fd2030af3420ae9144f1ee1188572e9de1290804 Mon Sep 17 00:00:00 2001 From: Tomas Jakstas Date: Mon, 11 Sep 2023 20:07:21 +0300 Subject: [PATCH] Client retry only for certain error codes. Fix initial buffer on retry. --- binary/Cargo.toml | 1 + binary/src/main.rs | 150 ++++++++++++++++++------------- builder/debian/postinst | 3 +- builder/debian/service | 1 + src/client/connection.rs | 12 ++- src/client/receiver.rs | 19 ++-- src/client/sender.rs | 3 +- src/packet.rs | 14 +++ src/readers/block_reader.rs | 2 +- src/server/connection_builder.rs | 8 -- src/server/helpers/connection.rs | 12 +-- src/server/helpers/write.rs | 2 +- src/server/multi_thread.rs | 4 +- src/server/single_thread.rs | 44 ++++----- src/server/validation.rs | 5 +- 15 files changed, 157 insertions(+), 123 deletions(-) diff --git a/binary/Cargo.toml b/binary/Cargo.toml index 578feb2..415f5f5 100644 --- a/binary/Cargo.toml +++ b/binary/Cargo.toml @@ -52,6 +52,7 @@ depends = "$auto" section = "utility" priority = "optional" maintainer-scripts = "../builder/debian" +revision="" assets = [ ["target/release/tftp", "usr/bin/", "755"], ["../README.md", "usr/share/doc/tftp/README", "644"], diff --git a/binary/src/main.rs b/binary/src/main.rs index f2ec695..0493869 100644 --- a/binary/src/main.rs +++ b/binary/src/main.rs @@ -324,21 +324,25 @@ mod tests { #[test] fn test_client_full_encryption() { // env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init(); - let bytes: Vec = (0..32).map(|_| rand::random::()).collect(); - let key: [u8; 32] = bytes.try_into().unwrap(); - let server_private_key: PrivateKey = key.into(); - client_send( - EncryptionLevel::Protocol, - Some(server_private_key.clone()), - None, - None, - ); - client_receive( - EncryptionLevel::Protocol, - Some(server_private_key), - None, - None, - ); + for w in [1, 4] { + let bytes: Vec = (0..32).map(|_| rand::random::()).collect(); + let key: [u8; 32] = bytes.try_into().unwrap(); + let server_private_key: PrivateKey = key.into(); + client_send( + EncryptionLevel::Protocol, + w, + Some(server_private_key.clone()), + None, + None, + ); + client_receive( + EncryptionLevel::Protocol, + w, + Some(server_private_key), + None, + None, + ); + } } #[allow(unused_must_use)] @@ -346,36 +350,42 @@ mod tests { #[test] fn test_client_full_encryption_only_authorized() { // env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init(); - let bytes: Vec = (0..32).map(|_| rand::random::()).collect(); - let key: [u8; 32] = bytes.try_into().unwrap(); - let server_private_key: PrivateKey = key.into(); - let bytes: Vec = (0..32).map(|_| rand::random::()).collect(); - let key: [u8; 32] = bytes.try_into().unwrap(); - let client_private_key: PrivateKey = key.into(); - let mut authorized_keys = AuthorizedKeys::new(); - - authorized_keys.push(PublicKey::from(&client_private_key)); - client_send( - EncryptionLevel::Protocol, - Some(server_private_key.clone()), - #[cfg(feature = "encryption")] - Some(authorized_keys.clone()), - Some(client_private_key.clone()), - ); - client_receive( - EncryptionLevel::Protocol, - Some(server_private_key), - #[cfg(feature = "encryption")] - Some(authorized_keys), - Some(client_private_key), - ); + for w in [1, 4] { + let bytes: Vec = (0..32).map(|_| rand::random::()).collect(); + let key: [u8; 32] = bytes.try_into().unwrap(); + let server_private_key: PrivateKey = key.into(); + let bytes: Vec = (0..32).map(|_| rand::random::()).collect(); + let key: [u8; 32] = bytes.try_into().unwrap(); + let client_private_key: PrivateKey = key.into(); + let mut authorized_keys = AuthorizedKeys::new(); + + authorized_keys.push(PublicKey::from(&client_private_key)); + client_send( + EncryptionLevel::Protocol, + w, + Some(server_private_key.clone()), + #[cfg(feature = "encryption")] + Some(authorized_keys.clone()), + Some(client_private_key.clone()), + ); + client_receive( + EncryptionLevel::Protocol, + w, + Some(server_private_key), + #[cfg(feature = "encryption")] + Some(authorized_keys), + Some(client_private_key), + ); + } } #[cfg(feature = "encryption")] #[test] fn test_client_protocol_encryption() { - client_send(EncryptionLevel::Protocol, None, None, None); - client_receive(EncryptionLevel::Protocol, None, None, None); + for w in [1, 4] { + client_send(EncryptionLevel::Protocol, w, None, None, None); + client_receive(EncryptionLevel::Protocol, w, None, None, None); + } } #[allow(unused_must_use)] @@ -383,31 +393,37 @@ mod tests { #[test] fn test_client_protocol_encryption_authorized() { // env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init(); - let bytes: Vec = (0..32).map(|_| rand::random::()).collect(); - let key: [u8; 32] = bytes.try_into().unwrap(); - let client_private_key: PrivateKey = key.into(); - let mut authorized_keys = AuthorizedKeys::new(); - authorized_keys.push(PublicKey::from(&client_private_key)); - client_send( - EncryptionLevel::Protocol, - None, - Some(authorized_keys.clone()), - Some(client_private_key.clone()), - ); - client_receive( - EncryptionLevel::Protocol, - None, - Some(authorized_keys), - Some(client_private_key), - ); + for w in [1, 4] { + let bytes: Vec = (0..32).map(|_| rand::random::()).collect(); + let key: [u8; 32] = bytes.try_into().unwrap(); + let client_private_key: PrivateKey = key.into(); + let mut authorized_keys = AuthorizedKeys::new(); + authorized_keys.push(PublicKey::from(&client_private_key)); + client_send( + EncryptionLevel::Protocol, + w, + None, + Some(authorized_keys.clone()), + Some(client_private_key.clone()), + ); + client_receive( + EncryptionLevel::Protocol, + w, + None, + Some(authorized_keys), + Some(client_private_key), + ); + } } #[cfg(feature = "encryption")] #[test] fn test_client_data_encryption() { // env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init(); - client_send(EncryptionLevel::Data, None, None, None); - client_receive(EncryptionLevel::Data, None, None, None); + for w in [1, 4] { + client_send(EncryptionLevel::Data, w, None, None, None); + client_receive(EncryptionLevel::Data, w, None, None, None); + } } #[test] @@ -415,12 +431,15 @@ mod tests { // env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")) // .format_timestamp_micros() // .init(); - client_send(EncryptionLevel::None, None, None, None); - client_receive(EncryptionLevel::None, None, None, None); + for w in [1, 4] { + client_send(EncryptionLevel::None, w, None, None, None); + client_receive(EncryptionLevel::None, w, None, None, None); + } } fn client_send( encryption_level: EncryptionLevel, + window_size: u64, server_private_key: Option, authorized_keys: Option, _client_private_key: Option, @@ -456,6 +475,7 @@ mod tests { start_send_file( server_port, encryption_level, + window_size, d, #[cfg(feature = "encryption")] server_public_key, @@ -472,6 +492,7 @@ mod tests { fn client_receive( encryption_level: EncryptionLevel, + window_size: u64, _server_private_key: Option, _authorized_keys: Option, _client_private_key: Option, @@ -508,6 +529,7 @@ mod tests { start_receive_file( server_port, encryption_level, + window_size, d, #[cfg(feature = "encryption")] server_public_key, @@ -525,6 +547,7 @@ mod tests { fn start_send_file( server_port: u16, _encryption_level: EncryptionLevel, + window_size: u64, bytes: Vec, #[cfg(feature = "encryption")] server_public_key: Option, #[cfg(feature = "encryption")] private_key: Option, @@ -544,7 +567,7 @@ mod tests { encryption_level: _encryption_level.to_string().parse().unwrap(), #[cfg(feature = "encryption")] known_hosts: None, - window_size: 1, + window_size, allow_server_port_change: false, }; @@ -566,6 +589,7 @@ mod tests { fn start_receive_file( server_port: u16, _encryption_level: EncryptionLevel, + window_size: u64, bytes: Arc>>>, #[cfg(feature = "encryption")] server_public_key: Option, #[cfg(feature = "encryption")] private_key: Option, @@ -585,7 +609,7 @@ mod tests { encryption_level: _encryption_level.to_string().parse().unwrap(), #[cfg(feature = "encryption")] known_hosts: None, - window_size: 1, + window_size, allow_server_port_change: false, }; diff --git a/builder/debian/postinst b/builder/debian/postinst index c09cb34..5e870b3 100644 --- a/builder/debian/postinst +++ b/builder/debian/postinst @@ -1,4 +1,5 @@ [ -z "$SERVER_HOME" ] && SERVER_HOME=/opt/tftp [ -z "$SERVER_USER" ] && SERVER_USER=tftp -adduser --system $SERVER_USER --home $SERVER_HOME +useradd --system "$SERVER_USER" --no-create-home --shell /bin/false || true +install -o "$SERVER_USER" -g "$SERVER_USER" -m 0750 -d "$SERVER_HOME" diff --git a/builder/debian/service b/builder/debian/service index b28d906..b5670da 100644 --- a/builder/debian/service +++ b/builder/debian/service @@ -3,6 +3,7 @@ Description=Tftp server [Service] User=tftp +Group=tftp WorkingDirectory=/opt/tftp ExecStart=/usr/bin/tftp server 0.0.0.0:69 . --allow-overwrite AmbientCapabilities=CAP_NET_BIND_SERVICE diff --git a/src/client/connection.rs b/src/client/connection.rs index d83cd35..5f33089 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -4,7 +4,6 @@ use core::time::Duration; use log::debug; use log::error; - use crate::config::ConnectionOptions; use crate::config::DEFAULT_DATA_BLOCK_SIZE; use crate::encryption::EncryptionLevel; @@ -46,6 +45,7 @@ pub fn query_server<'a>( let mut initial = true; let request_timeout = config.request_timeout; + let buffer_size = buffer.len(); loop { let request_packet = RequestPacket { @@ -56,6 +56,14 @@ pub fn query_server<'a>( let packet = create_packet(request_packet); let packet_type = packet.packet_type(); + #[cfg(feature = "alloc")] + buffer.resize(buffer_size, 0); + // TODO heapless vector resizing is super slow + #[cfg(not(feature = "alloc"))] + unsafe { + buffer.set_len(buffer_size) + }; + let (length, endpoint) = wait_for_initial_packet( socket, config.endpoint, @@ -109,7 +117,7 @@ pub fn query_server<'a>( } (_, Ok(Packet::Error(p))) => { // retry in case server does not support extensions - if initial && options.encryption_level == EncryptionLevel::None { + if matches!(p.code, ErrorCode::IllegalOperation | ErrorCode::Undefined) && initial && options.encryption_level == EncryptionLevel::None { debug!("Received error {} retrying without extensions", p.message); extensions = PacketExtensions::new(); used_extensions = Default::default(); diff --git a/src/client/receiver.rs b/src/client/receiver.rs index 0bc89ba..5eac0ce 100644 --- a/src/client/receiver.rs +++ b/src/client/receiver.rs @@ -185,7 +185,7 @@ where match Packet::from_bytes(&buffer) { Ok(Packet::Data(p)) => { - let written = match write_block( + match write_block( &socket, endpoint, &mut block_writer, @@ -194,23 +194,20 @@ where &options, &mut last_block_ack, ) { - Ok(Some(n)) => { - if n > 0 { - timeout = instant(); - total += n; + Ok(Some(written)) => { + timeout = instant(); + total += written; + + if written < options.block_size as usize { + info!("Client finished receiving with bytes {}", total); + return Ok((total, options.remote_public_key())); } - n } Ok(None) => continue, Err(e) => return Err(e), }; // this would write more than expected but only by a block size maximum handle_file_size(&socket, endpoint, total, config.max_file_size)?; - - if written < options.block_size as usize { - info!("Client finished receiving with bytes {}", total); - return Ok((total, options.remote_public_key())); - } } Ok(Packet::Error(p)) => { return Err(PacketError::RemoteError(p.message).into()); diff --git a/src/client/sender.rs b/src/client/sender.rs index e193096..39eb7ca 100644 --- a/src/client/sender.rs +++ b/src/client/sender.rs @@ -185,7 +185,7 @@ where let last_read_length = data_block.data.len(); debug!( - "Send data block {} data size {last_read_length} retry {} left to send {packets_to_send}", + "Send data block {} data size {last_read_length} retry {} remaining packets {packets_to_send}", data_block.block, data_block.retry ); @@ -244,6 +244,7 @@ where last_received.elapsed().as_micros(), wait_for.unwrap_or(Duration::ZERO).as_millis() ); + let length = match socket.recv_from(&mut buffer, wait_for) { Ok((n, s)) => { if s != endpoint { diff --git a/src/packet.rs b/src/packet.rs index 01d649f..97da082 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -647,6 +647,20 @@ mod tests { .as_slice(), b"\x00\x05\x00\x01File not found\x00" ); + + assert_eq!( + Packet::Error(ErrorPacket::new( + ErrorCode::DiskFull, + format_str!( + DefaultString, + "Unable to write file {}", + "some-file-to-test.bin" + ), + )) + .to_bytes() + .as_slice(), + b"\x00\x05\x00\x03Unable to write file some-file-to-test.bin\x00" + ); } #[test] diff --git a/src/readers/block_reader.rs b/src/readers/block_reader.rs index 49a011d..8a3c1c2 100644 --- a/src/readers/block_reader.rs +++ b/src/readers/block_reader.rs @@ -2,6 +2,7 @@ use crate::error::StorageError; use crate::types::DataBlock; pub trait BlockReader { + /// read next block fn next(&mut self, retry: bool) -> Result, StorageError>; /// release block returning data size released @@ -13,7 +14,6 @@ pub trait BlockReader { #[derive(Debug)] pub struct Block { pub block: u16, - // TODO ref pub data: DataBlock, pub retry: bool, } diff --git a/src/server/connection_builder.rs b/src/server/connection_builder.rs index 4165b10..4db089e 100644 --- a/src/server/connection_builder.rs +++ b/src/server/connection_builder.rs @@ -401,7 +401,6 @@ fn handle_encrypted( let remote_key = remote_public_key.into(); let finalized_keys = create_finalized_keys(private_key, &remote_key, Some(remote_nonce.into()), rng); - // finalized_keys.encryptor.nonce = remote_nonce.into(); Ok(Some(( size_of::() + size_of::() + size_of::(), finalized_keys, @@ -453,13 +452,6 @@ fn block_writer(writer: W) -> Writers { Writers::Single(SingleBlockWriter::new(writer)) } -// type ConnectionResult = BoxedResult<( -// Connection, -// ClientType, -// PacketExtensions, -// Option, -// )>; - #[cfg(test)] mod tests { #[allow(unused_imports)] diff --git a/src/server/helpers/connection.rs b/src/server/helpers/connection.rs index e158be3..ae26562 100644 --- a/src/server/helpers/connection.rs +++ b/src/server/helpers/connection.rs @@ -81,10 +81,10 @@ pub fn accept_connection( used_extensions: PacketExtensions, encrypt_new_connection: Option, ) -> Option<()> { - match connection_type { - ConnectionType::Read => { - debug!("Server extensions {:?}", used_extensions); + debug!("Server extensions {:?}", used_extensions); + match connection_type { + ConnectionType::Write => { if !used_extensions.is_empty() { if !connection.send_packet(Packet::OptionalAck(OptionalAck { extensions: used_extensions, @@ -102,12 +102,8 @@ pub fn accept_connection( print_options("Server writing using", &connection.options); Some(()) - - // Some((connection, ClientType::Writer(client_type))) } - ConnectionType::Write => { - debug!("Server extensions {:?}", used_extensions); - + ConnectionType::Read => { if !used_extensions.is_empty() && !connection.send_packet(Packet::OptionalAck(OptionalAck { extensions: used_extensions, diff --git a/src/server/helpers/write.rs b/src/server/helpers/write.rs index 5a6ff1c..e4dd954 100644 --- a/src/server/helpers/write.rs +++ b/src/server/helpers/write.rs @@ -55,7 +55,7 @@ pub fn handle_write( #[allow(unused_mut)] let mut write_elapsed = instant(); match write_block(connection, block_writer, p.block, p.data) { - Ok(Some(n)) if n > 0 => { + Ok(Some(n)) => { connection.last_updated = instant(); connection.transfer += n; trace!( diff --git a/src/server/multi_thread.rs b/src/server/multi_thread.rs index 9ebe193..4defde1 100644 --- a/src/server/multi_thread.rs +++ b/src/server/multi_thread.rs @@ -4,8 +4,8 @@ use std::thread::JoinHandle; use core::num::NonZeroU32; use core::time::Duration; -use log::debug; use log::info; +use log::trace; use rand::CryptoRng; use rand::RngCore; @@ -128,7 +128,7 @@ where Err(e) => return Err(e.into()), }; - debug!( + trace!( "Received connection from {from_client} in {} exists {}", received_in.elapsed().as_secs_f32(), handles.contains_key(&from_client), diff --git a/src/server/single_thread.rs b/src/server/single_thread.rs index 0f0a9d3..bf520a0 100644 --- a/src/server/single_thread.rs +++ b/src/server/single_thread.rs @@ -115,7 +115,7 @@ where let mut clients: Clients<_, _, _> = Clients::new(); - debug!( + trace!( "Size of all clients in memory {} bytes", size_of_val(&clients) ); @@ -134,19 +134,19 @@ where buffer.set_len(max_buffer_size as usize) }; - // let sent_in = instant(); + let sent_in = instant(); let (sent, recv_next_client_to_send) = send_data_blocks(&mut clients, next_client_to_send); next_client_to_send = recv_next_client_to_send; wait_control.sending(sent); - // debug!( - // "Sent {sent} next {recv_next_client_to_send} in {}", - // sent_in.elapsed().as_secs_f32() - // ); + trace!( + "Sent {sent} next {recv_next_client_to_send} in {}", + sent_in.elapsed().as_secs_f32() + ); - // let client_received_in = instant(); + let client_received_in = instant(); let client_received = clients.iter().skip(next_client_to_receive).find_map( |(client_socket_addr, (connection, _))| { @@ -170,13 +170,13 @@ where }, ); - // debug!( - // "Received from client {:?} next {next_client_to_receive} in {}", - // client_received, - // client_received_in.elapsed().as_secs_f32() - // ); + trace!( + "Received from client {:?} next {next_client_to_receive} in {}", + client_received, + client_received_in.elapsed().as_secs_f32() + ); - // let received_in = instant(); + let received_in = instant(); let (received_length, from_client) = match client_received { Some(r) => r, @@ -200,16 +200,16 @@ where }, }; - // debug!( - // "Received connection from {from_client} in {}", - // received_in.elapsed().as_secs_f32(), - // ); + trace!( + "Received connection from {from_client} in {}", + received_in.elapsed().as_secs_f32(), + ); wait_control.receiving(); buffer.truncate(received_length); let clients_len = clients.len(); - // let processed_in = instant(); + let processed_in = instant(); match clients.entry(from_client) { Entry::Occupied(mut entry) => { @@ -354,10 +354,10 @@ where } } - // debug!( - // "Processed connection from {from_client} in {}", - // processed_in.elapsed().as_secs_f32(), - // ); + trace!( + "Processed connection from {from_client} in {}", + processed_in.elapsed().as_secs_f32(), + ); } } diff --git a/src/server/validation.rs b/src/server/validation.rs index 300a9ba..c389e3a 100644 --- a/src/server/validation.rs +++ b/src/server/validation.rs @@ -35,7 +35,7 @@ pub fn validate_request_options( }; let packet = Packet::Error(ErrorPacket::new( - ErrorCode::IllegalOperation, + ErrorCode::AccessVioliation, format_str!( DefaultString, "Missing extension {} while {} provided", @@ -75,9 +75,8 @@ pub fn handle_file_size(received_size: u64, max_file_size: u64) -> Result<(), Er ); let message = format_str!( DefaultString, - "Invalid file size received {} expected {}", + "Invalid file size received {}", received_size, - max_file_size ); return Err(ErrorPacket::new(ErrorCode::DiskFull, message)); }