Skip to content

Commit

Permalink
Client retry only for certain error codes. Fix initial buffer on retry.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomjakstas committed Sep 11, 2023
1 parent 76bb924 commit fd2030a
Show file tree
Hide file tree
Showing 15 changed files with 157 additions and 123 deletions.
1 change: 1 addition & 0 deletions binary/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ depends = "$auto"
section = "utility"
priority = "optional"
maintainer-scripts = "../builder/debian"
revision=""
assets = [
["target/release/tftp", "usr/bin/", "755"],
["../README.md", "usr/share/doc/tftp/README", "644"],
Expand Down
150 changes: 87 additions & 63 deletions binary/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,103 +324,122 @@ mod tests {
#[test]
fn test_client_full_encryption() {
// env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init();
let bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
let key: [u8; 32] = bytes.try_into().unwrap();
let server_private_key: PrivateKey = key.into();
client_send(
EncryptionLevel::Protocol,
Some(server_private_key.clone()),
None,
None,
);
client_receive(
EncryptionLevel::Protocol,
Some(server_private_key),
None,
None,
);
for w in [1, 4] {
let bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
let key: [u8; 32] = bytes.try_into().unwrap();
let server_private_key: PrivateKey = key.into();
client_send(
EncryptionLevel::Protocol,
w,
Some(server_private_key.clone()),
None,
None,
);
client_receive(
EncryptionLevel::Protocol,
w,
Some(server_private_key),
None,
None,
);
}
}

#[allow(unused_must_use)]
#[cfg(feature = "encryption")]
#[test]
fn test_client_full_encryption_only_authorized() {
// env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init();
let bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
let key: [u8; 32] = bytes.try_into().unwrap();
let server_private_key: PrivateKey = key.into();
let bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
let key: [u8; 32] = bytes.try_into().unwrap();
let client_private_key: PrivateKey = key.into();
let mut authorized_keys = AuthorizedKeys::new();

authorized_keys.push(PublicKey::from(&client_private_key));
client_send(
EncryptionLevel::Protocol,
Some(server_private_key.clone()),
#[cfg(feature = "encryption")]
Some(authorized_keys.clone()),
Some(client_private_key.clone()),
);
client_receive(
EncryptionLevel::Protocol,
Some(server_private_key),
#[cfg(feature = "encryption")]
Some(authorized_keys),
Some(client_private_key),
);
for w in [1, 4] {
let bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
let key: [u8; 32] = bytes.try_into().unwrap();
let server_private_key: PrivateKey = key.into();
let bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
let key: [u8; 32] = bytes.try_into().unwrap();
let client_private_key: PrivateKey = key.into();
let mut authorized_keys = AuthorizedKeys::new();

authorized_keys.push(PublicKey::from(&client_private_key));
client_send(
EncryptionLevel::Protocol,
w,
Some(server_private_key.clone()),
#[cfg(feature = "encryption")]
Some(authorized_keys.clone()),
Some(client_private_key.clone()),
);
client_receive(
EncryptionLevel::Protocol,
w,
Some(server_private_key),
#[cfg(feature = "encryption")]
Some(authorized_keys),
Some(client_private_key),
);
}
}

#[cfg(feature = "encryption")]
#[test]
fn test_client_protocol_encryption() {
client_send(EncryptionLevel::Protocol, None, None, None);
client_receive(EncryptionLevel::Protocol, None, None, None);
for w in [1, 4] {
client_send(EncryptionLevel::Protocol, w, None, None, None);
client_receive(EncryptionLevel::Protocol, w, None, None, None);
}
}

#[allow(unused_must_use)]
#[cfg(feature = "encryption")]
#[test]
fn test_client_protocol_encryption_authorized() {
// env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init();
let bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
let key: [u8; 32] = bytes.try_into().unwrap();
let client_private_key: PrivateKey = key.into();
let mut authorized_keys = AuthorizedKeys::new();
authorized_keys.push(PublicKey::from(&client_private_key));
client_send(
EncryptionLevel::Protocol,
None,
Some(authorized_keys.clone()),
Some(client_private_key.clone()),
);
client_receive(
EncryptionLevel::Protocol,
None,
Some(authorized_keys),
Some(client_private_key),
);
for w in [1, 4] {
let bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
let key: [u8; 32] = bytes.try_into().unwrap();
let client_private_key: PrivateKey = key.into();
let mut authorized_keys = AuthorizedKeys::new();
authorized_keys.push(PublicKey::from(&client_private_key));
client_send(
EncryptionLevel::Protocol,
w,
None,
Some(authorized_keys.clone()),
Some(client_private_key.clone()),
);
client_receive(
EncryptionLevel::Protocol,
w,
None,
Some(authorized_keys),
Some(client_private_key),
);
}
}

#[cfg(feature = "encryption")]
#[test]
fn test_client_data_encryption() {
// env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init();
client_send(EncryptionLevel::Data, None, None, None);
client_receive(EncryptionLevel::Data, None, None, None);
for w in [1, 4] {
client_send(EncryptionLevel::Data, w, None, None, None);
client_receive(EncryptionLevel::Data, w, None, None, None);
}
}

#[test]
fn test_client_no_encryption() {
// env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug"))
// .format_timestamp_micros()
// .init();
client_send(EncryptionLevel::None, None, None, None);
client_receive(EncryptionLevel::None, None, None, None);
for w in [1, 4] {
client_send(EncryptionLevel::None, w, None, None, None);
client_receive(EncryptionLevel::None, w, None, None, None);
}
}

fn client_send(
encryption_level: EncryptionLevel,
window_size: u64,
server_private_key: Option<PrivateKey>,
authorized_keys: Option<AuthorizedKeys>,
_client_private_key: Option<PrivateKey>,
Expand Down Expand Up @@ -456,6 +475,7 @@ mod tests {
start_send_file(
server_port,
encryption_level,
window_size,
d,
#[cfg(feature = "encryption")]
server_public_key,
Expand All @@ -472,6 +492,7 @@ mod tests {

fn client_receive(
encryption_level: EncryptionLevel,
window_size: u64,
_server_private_key: Option<PrivateKey>,
_authorized_keys: Option<AuthorizedKeys>,
_client_private_key: Option<PrivateKey>,
Expand Down Expand Up @@ -508,6 +529,7 @@ mod tests {
start_receive_file(
server_port,
encryption_level,
window_size,
d,
#[cfg(feature = "encryption")]
server_public_key,
Expand All @@ -525,6 +547,7 @@ mod tests {
fn start_send_file(
server_port: u16,
_encryption_level: EncryptionLevel,
window_size: u64,
bytes: Vec<u8>,
#[cfg(feature = "encryption")] server_public_key: Option<ShortString>,
#[cfg(feature = "encryption")] private_key: Option<ShortString>,
Expand All @@ -544,7 +567,7 @@ mod tests {
encryption_level: _encryption_level.to_string().parse().unwrap(),
#[cfg(feature = "encryption")]
known_hosts: None,
window_size: 1,
window_size,
allow_server_port_change: false,
};

Expand All @@ -566,6 +589,7 @@ mod tests {
fn start_receive_file(
server_port: u16,
_encryption_level: EncryptionLevel,
window_size: u64,
bytes: Arc<Mutex<Cursor<Vec<u8>>>>,
#[cfg(feature = "encryption")] server_public_key: Option<ShortString>,
#[cfg(feature = "encryption")] private_key: Option<ShortString>,
Expand All @@ -585,7 +609,7 @@ mod tests {
encryption_level: _encryption_level.to_string().parse().unwrap(),
#[cfg(feature = "encryption")]
known_hosts: None,
window_size: 1,
window_size,
allow_server_port_change: false,
};

Expand Down
3 changes: 2 additions & 1 deletion builder/debian/postinst
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[ -z "$SERVER_HOME" ] && SERVER_HOME=/opt/tftp
[ -z "$SERVER_USER" ] && SERVER_USER=tftp

adduser --system $SERVER_USER --home $SERVER_HOME
useradd --system "$SERVER_USER" --no-create-home --shell /bin/false || true
install -o "$SERVER_USER" -g "$SERVER_USER" -m 0750 -d "$SERVER_HOME"
1 change: 1 addition & 0 deletions builder/debian/service
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Description=Tftp server

[Service]
User=tftp
Group=tftp
WorkingDirectory=/opt/tftp
ExecStart=/usr/bin/tftp server 0.0.0.0:69 . --allow-overwrite
AmbientCapabilities=CAP_NET_BIND_SERVICE
Expand Down
12 changes: 10 additions & 2 deletions src/client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use core::time::Duration;
use log::debug;
use log::error;


use crate::config::ConnectionOptions;
use crate::config::DEFAULT_DATA_BLOCK_SIZE;
use crate::encryption::EncryptionLevel;
Expand Down Expand Up @@ -46,6 +45,7 @@ pub fn query_server<'a>(
let mut initial = true;

let request_timeout = config.request_timeout;
let buffer_size = buffer.len();

loop {
let request_packet = RequestPacket {
Expand All @@ -56,6 +56,14 @@ pub fn query_server<'a>(
let packet = create_packet(request_packet);
let packet_type = packet.packet_type();

#[cfg(feature = "alloc")]
buffer.resize(buffer_size, 0);
// TODO heapless vector resizing is super slow
#[cfg(not(feature = "alloc"))]
unsafe {
buffer.set_len(buffer_size)
};

let (length, endpoint) = wait_for_initial_packet(
socket,
config.endpoint,
Expand Down Expand Up @@ -109,7 +117,7 @@ pub fn query_server<'a>(
}
(_, Ok(Packet::Error(p))) => {
// retry in case server does not support extensions
if initial && options.encryption_level == EncryptionLevel::None {
if matches!(p.code, ErrorCode::IllegalOperation | ErrorCode::Undefined) && initial && options.encryption_level == EncryptionLevel::None {
debug!("Received error {} retrying without extensions", p.message);
extensions = PacketExtensions::new();
used_extensions = Default::default();
Expand Down
19 changes: 8 additions & 11 deletions src/client/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ where

match Packet::from_bytes(&buffer) {
Ok(Packet::Data(p)) => {
let written = match write_block(
match write_block(
&socket,
endpoint,
&mut block_writer,
Expand All @@ -194,23 +194,20 @@ where
&options,
&mut last_block_ack,
) {
Ok(Some(n)) => {
if n > 0 {
timeout = instant();
total += n;
Ok(Some(written)) => {
timeout = instant();
total += written;

if written < options.block_size as usize {
info!("Client finished receiving with bytes {}", total);
return Ok((total, options.remote_public_key()));
}
n
}
Ok(None) => continue,
Err(e) => return Err(e),
};
// this would write more than expected but only by a block size maximum
handle_file_size(&socket, endpoint, total, config.max_file_size)?;

if written < options.block_size as usize {
info!("Client finished receiving with bytes {}", total);
return Ok((total, options.remote_public_key()));
}
}
Ok(Packet::Error(p)) => {
return Err(PacketError::RemoteError(p.message).into());
Expand Down
3 changes: 2 additions & 1 deletion src/client/sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ where
let last_read_length = data_block.data.len();

debug!(
"Send data block {} data size {last_read_length} retry {} left to send {packets_to_send}",
"Send data block {} data size {last_read_length} retry {} remaining packets {packets_to_send}",
data_block.block, data_block.retry
);

Expand Down Expand Up @@ -244,6 +244,7 @@ where
last_received.elapsed().as_micros(),
wait_for.unwrap_or(Duration::ZERO).as_millis()
);

let length = match socket.recv_from(&mut buffer, wait_for) {
Ok((n, s)) => {
if s != endpoint {
Expand Down
14 changes: 14 additions & 0 deletions src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,20 @@ mod tests {
.as_slice(),
b"\x00\x05\x00\x01File not found\x00"
);

assert_eq!(
Packet::Error(ErrorPacket::new(
ErrorCode::DiskFull,
format_str!(
DefaultString,
"Unable to write file {}",
"some-file-to-test.bin"
),
))
.to_bytes()
.as_slice(),
b"\x00\x05\x00\x03Unable to write file some-file-to-test.bin\x00"
);
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion src/readers/block_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::error::StorageError;
use crate::types::DataBlock;

pub trait BlockReader {
/// read next block
fn next(&mut self, retry: bool) -> Result<Option<Block>, StorageError>;

/// release block returning data size released
Expand All @@ -13,7 +14,6 @@ pub trait BlockReader {
#[derive(Debug)]
pub struct Block {
pub block: u16,
// TODO ref
pub data: DataBlock,
pub retry: bool,
}
Loading

0 comments on commit fd2030a

Please sign in to comment.