diff --git a/Cargo.toml b/Cargo.toml index 05ce27f..d2103ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,9 +17,13 @@ log = "0.4.14" rand_core = "0.6.3" x25519-dalek = { version = "2.0.0-pre.1", optional = true } curve25519-dalek = { version = "4.0.0-pre.2", optional = true } -tokio = { version = "1.36", features = ["net", "io-util"]} +tokio = { version = "1.36", features = ["net", "io-util"] } +tokio-util = { version = "0.7.10", features = ["codec"] } thiserror = "1" rand = "0.8.5" +futures = "0.3.30" +pin-project = "1" +hex = "0.4.3" [dev-dependencies] hex = "0.4.3" diff --git a/examples/echo_client.rs b/examples/echo_client.rs index 59d068e..47d6aa2 100644 --- a/examples/echo_client.rs +++ b/examples/echo_client.rs @@ -1,4 +1,5 @@ use adnl::{AdnlPeer, AdnlRawPublicKey}; +use futures::{SinkExt, StreamExt}; use std::{env, error::Error}; #[tokio::main] @@ -9,20 +10,19 @@ async fn main() -> Result<(), Box> { let public_key_hex = env::args() .nth(2) - .unwrap_or_else(|| "b7d8e88f4033eff806e2f5dff3c785be7dd038c923146e2d9fe80e4fe3cb8805".to_string()); + .unwrap_or_else(|| "691a14528fb2911839649c489cb4cbec1f4aa126c244c0ea2ac294eb568a7037".to_string()); let remote_public = AdnlRawPublicKey::try_from(&*hex::decode(public_key_hex)?)?; // act as a client: connect to ADNL server and perform handshake - let mut client = AdnlPeer::connect(&remote_public, addr).await?; + let mut client = AdnlPeer::connect(&remote_public, addr).await.expect("adnl connect"); // send over ADNL - client.send(&mut "hello".as_bytes().to_vec()).await?; + client.send("hello".as_bytes().into()).await.expect("send"); - // receive result into vector - let mut result = Vec::::new(); - client.receive(&mut result).await?; + // receive result + let result = client.next().await.expect("packet must be received")?; - println!("received: {}", String::from_utf8(result).unwrap()); + println!("received: {}", String::from_utf8(result.to_vec()).unwrap()); Ok(()) } diff --git a/examples/echo_server.rs b/examples/echo_server.rs index d2f95f5..55e0b93 100644 --- a/examples/echo_server.rs +++ b/examples/echo_server.rs @@ -3,6 +3,7 @@ use std::{env, error::Error}; use adnl::{AdnlPeer, AdnlPrivateKey, AdnlPublicKey}; +use futures::{SinkExt, StreamExt}; use tokio::net::TcpListener; use x25519_dalek::StaticSecret; @@ -16,7 +17,7 @@ async fn main() -> Result<(), Box> { .unwrap_or_else(|| "127.0.0.1:8080".to_string()); // ADNL: get private key from environment variable KEY or use default insecure one - let private_key_hex = env::var("KEY").unwrap_or_else(|_| "69734189c0348245a70eb5335e12bfd75dd4cffc42baf32773e8f994ff5cf7c2".to_string()); + let private_key_hex = env::var("KEY").unwrap_or_else(|_| "f0971651aec4bb0d65ec3861c597687fda9c1e7d2ee8a93acb9a131aa9f3aee7".to_string()); let private_key_bytes: [u8; 32] = hex::decode(private_key_hex)?.try_into().unwrap(); let private_key = StaticSecret::from(private_key_bytes); @@ -27,7 +28,7 @@ async fn main() -> Result<(), Box> { println!("Listening on: {}", addr); // ADNL: print public key and adnl address associated with given private key - println!("Public key is: {}", hex::encode(private_key.public().as_bytes())); + println!("Public key is: {}", hex::encode(private_key.public().edwards_repr())); println!("Address is: {}", hex::encode(private_key.public().address().as_bytes())); loop { @@ -47,18 +48,9 @@ async fn main() -> Result<(), Box> { // ADNL: handle handshake let mut adnl_server = AdnlPeer::handle_handshake(socket, &private_key).await.expect("handshake failed"); - let mut buf = vec![0; 1024]; - // In a loop, read data from the socket and write the data back. - loop { - let n = adnl_server.receive(&mut buf) - .await - .expect("failed to read data from socket"); - - adnl_server - .send(&mut buf[..n]) - .await - .expect("failed to write data to socket"); + while let Some(Ok(packet)) = adnl_server.next().await { + let _ = adnl_server.send(packet).await; } }); } diff --git a/examples/time.rs b/examples/time.rs index a09c146..ebe7644 100644 --- a/examples/time.rs +++ b/examples/time.rs @@ -1,4 +1,5 @@ use adnl::{AdnlPeer, AdnlRawPublicKey}; +use futures::{SinkExt, StreamExt}; use std::{error::Error, net::SocketAddrV4}; #[tokio::main] @@ -9,18 +10,16 @@ async fn main() -> Result<(), Box> { let ls_ip = "65.21.74.140"; let ls_port = 46427; // act as a client: connect to ADNL server and perform handshake - let mut client = - AdnlPeer::connect(&remote_public, SocketAddrV4::new(ls_ip.parse()?, ls_port)).await?; + let mut client = AdnlPeer::connect(&remote_public, SocketAddrV4::new(ls_ip.parse()?, ls_port)).await?; // already serialized TL with gettime query - let mut query = hex::decode("7af98bb435263e6c95d6fecb497dfd0aa5f031e7d412986b5ce720496db512052e8f2d100cdf068c7904345aad16000000000000")?; + let query = hex::decode("7af98bb435263e6c95d6fecb497dfd0aa5f031e7d412986b5ce720496db512052e8f2d100cdf068c7904345aad16000000000000")?; - // send over ADNL, use random nonce - client.send(&mut query).await?; + // send over ADNL + client.send(query.into()).await?; - // receive result into vector, use 8192 bytes buffer - let mut result = Vec::::new(); - client.receive(&mut result).await?; + // receive result + let result = client.next().await.ok_or_else(|| "no result")??; // get time from serialized TL answer println!( diff --git a/src/helper_types.rs b/src/helper_types.rs index f3debc5..dee234e 100644 --- a/src/helper_types.rs +++ b/src/helper_types.rs @@ -7,22 +7,24 @@ pub trait CryptoRandom: rand_core::RngCore + rand_core::CryptoRng {} impl CryptoRandom for T where T: rand_core::RngCore + rand_core::CryptoRng {} pub trait AdnlPublicKey { + /// Derives address from a public key fn address(&self) -> AdnlAddress { let mut hasher = Sha256::new(); hasher.update([0xc6, 0xb4, 0x13, 0x48]); // type id - always ed25519 - hasher.update(self.to_bytes()); + hasher.update(self.edwards_repr()); AdnlAddress(hasher.finalize().into()) } - fn to_bytes(&self) -> [u8; 32]; + /// Gets ed25519 representation of a public key + fn edwards_repr(&self) -> [u8; 32]; } -/// Public key can be provided using raw slice +/// Public key can be provided in a ed25519 form using raw slice #[derive(Clone)] pub struct AdnlRawPublicKey([u8; 32]); impl AdnlPublicKey for AdnlRawPublicKey { - fn to_bytes(&self) -> [u8; 32] { + fn edwards_repr(&self) -> [u8; 32] { self.0 } } @@ -104,6 +106,17 @@ impl From<[u8; 160]> for AdnlAesParams { } impl AdnlAesParams { + /// Swap receiver and transciever keys + pub fn swap(self) -> Self { + Self { + rx_key: self.tx_key, + tx_key: self.rx_key, + rx_nonce: self.tx_nonce, + tx_nonce: self.rx_nonce, + padding: self.padding, + } + } + pub fn rx_key(&self) -> &[u8; 32] { &self.rx_key } @@ -188,18 +201,16 @@ impl AdnlSecret { /// Common error type #[derive(Debug, Error)] pub enum AdnlError { - #[error("Read error")] - ReadError(Error), - #[error("Write error")] - WriteError(Error), - #[error("Consume error")] - ConsumeError(Error), + #[error("IO error")] + IoError(#[from] Error), #[error("Integrity error")] IntegrityError, - #[error("TooShortPacket error")] + #[error("Too short packet (32 bytes min)")] TooShortPacket, + #[error("Too long packet (4 MiB max)")] + TooLongPacket, #[error("Receiver ADNL address mismatch")] UnknownAddr(AdnlAddress), - #[error(transparent)] - OtherError(#[from] Error), + #[error("End of stream")] + EndOfStream, } diff --git a/src/integrations/dalek.rs b/src/integrations/dalek.rs index ba37030..8dc22a6 100644 --- a/src/integrations/dalek.rs +++ b/src/integrations/dalek.rs @@ -7,7 +7,7 @@ use curve25519_dalek::montgomery::MontgomeryPoint; use x25519_dalek::{PublicKey, StaticSecret}; impl AdnlPublicKey for PublicKey { - fn to_bytes(&self) -> [u8; 32] { + fn edwards_repr(&self) -> [u8; 32] { MontgomeryPoint(self.to_bytes()) .to_edwards(0) .unwrap() @@ -18,7 +18,7 @@ impl AdnlPublicKey for PublicKey { fn edwards_to_montgomery(public_key: &P) -> PublicKey { PublicKey::from( - CompressedEdwardsY::from_slice(&public_key.to_bytes()) + CompressedEdwardsY::from_slice(&public_key.edwards_repr()) .decompress() .unwrap() .to_montgomery() diff --git a/src/lib.rs b/src/lib.rs index 1babafb..16794e7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,7 @@ pub use helper_types::{ AdnlAddress, AdnlAesParams, AdnlError, AdnlPrivateKey, AdnlPublicKey, AdnlSecret, AdnlRawPublicKey, }; pub use primitives::handshake::AdnlHandshake; -pub use primitives::receive::AdnlReceiver; -pub use primitives::send::AdnlSender; +pub use primitives::codec::AdnlCodec; pub use wrappers::builder::AdnlBuilder; pub use wrappers::peer::AdnlPeer; diff --git a/src/primitives/codec.rs b/src/primitives/codec.rs new file mode 100644 index 0000000..cc9e4bd --- /dev/null +++ b/src/primitives/codec.rs @@ -0,0 +1,102 @@ +use aes::cipher::{KeyIvInit, StreamCipher}; +use sha2::{Digest, Sha256}; +use tokio_util::{bytes::{Buf, Bytes, BytesMut}, codec::{Decoder, Encoder}}; + +use crate::{AdnlAesParams, AdnlError}; + +use super::AdnlAes; + +pub struct AdnlCodec { + aes_rx: AdnlAes, + aes_tx: AdnlAes, + last_readed_length: Option, +} + +impl AdnlCodec { + pub fn new(aes_params: &AdnlAesParams) -> Self { + Self { + aes_rx: AdnlAes::new(aes_params.rx_key().into(), aes_params.rx_nonce().into()), + aes_tx: AdnlAes::new(aes_params.tx_key().into(), aes_params.tx_nonce().into()), + last_readed_length: None, + } + } +} + +impl Decoder for AdnlCodec { + type Item = Bytes; + + type Error = AdnlError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + let length = if let Some(length) = self.last_readed_length { + length + } else { + if src.len() < 4 { + return Ok(None) + } + self.aes_rx.apply_keystream(&mut src[..4]); + let mut length_bytes = [0u8; 4]; + length_bytes.copy_from_slice(&src[..4]); + let length = u32::from_le_bytes(length_bytes) as usize; + if length < 64 { + return Err(AdnlError::TooShortPacket); + } + if length > (1 << 24) { + return Err(AdnlError::TooLongPacket); + } + src.advance(4); + self.last_readed_length = Some(length); + length + }; + + // not enough bytes, need to wait for more data + if src.len() < length { + if src.capacity() < length { + src.reserve(length - src.capacity()); + } + return Ok(None) + } + + self.last_readed_length = None; + + // decode packet + self.aes_rx.apply_keystream(&mut src[..length]); + let given_hash = &src[length-32..length]; + + // integrity check + let mut hasher = Sha256::new(); + hasher.update(&src[..length-32]); + if given_hash != hasher.finalize().as_slice() { + return Err(AdnlError::IntegrityError) + } + + // copy and return buffer + let result = Bytes::copy_from_slice(&src[32..length-32]); + src.advance(length); + Ok(Some(result)) + } +} + +impl Encoder for AdnlCodec { + type Error = AdnlError; + + fn encode(&mut self, buffer: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> { + if buffer.len() > ((1 << 24) - 64) { + return Err(AdnlError::TooLongPacket); + } + let length = ((buffer.len() + 64) as u32).to_le_bytes(); + let nonce = rand::random::<[u8; 32]>(); + let mut hash = Sha256::new(); + hash.update(&nonce); + hash.update(&buffer); + let hash = hash.finalize(); + dst.reserve(buffer.len() + 68); + dst.extend_from_slice(&length); + dst.extend_from_slice(&nonce); + dst.extend_from_slice(&buffer); + dst.extend_from_slice(&hash); + let start_offset = dst.len() - buffer.len() - 68; + self.aes_tx.apply_keystream(&mut dst[start_offset..]); + Ok(()) + } +} \ No newline at end of file diff --git a/src/primitives/handshake.rs b/src/primitives/handshake.rs index d07d0eb..94c9477 100644 --- a/src/primitives/handshake.rs +++ b/src/primitives/handshake.rs @@ -6,6 +6,8 @@ use ctr::cipher::StreamCipher; use sha2::{Digest, Sha256}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use super::codec::AdnlCodec; + /// Handshake packet, must be sent from client to server prior to any datagrams pub struct AdnlHandshake { receiver: AdnlAddress, @@ -55,12 +57,16 @@ impl AdnlHandshake

{ let mut packet = [0u8; 256]; packet[..32].copy_from_slice(self.receiver.as_bytes()); - packet[32..64].copy_from_slice(&self.sender.to_bytes()); + packet[32..64].copy_from_slice(&self.sender.edwards_repr()); packet[64..96].copy_from_slice(&hash); packet[96..256].copy_from_slice(&raw_params); packet } + pub fn make_codec(&self) -> AdnlCodec { + AdnlCodec::new(&self.aes_params) + } + /// Send handshake over the given transport, build [`AdnlClient`] on top of it pub async fn perform_handshake( &self, @@ -111,7 +117,7 @@ impl AdnlHandshake { Ok(Self { receiver, sender, - aes_params: AdnlAesParams::from(raw_params), + aes_params: AdnlAesParams::from(raw_params).swap(), secret, }) } diff --git a/src/primitives/mod.rs b/src/primitives/mod.rs index b32d077..1836048 100644 --- a/src/primitives/mod.rs +++ b/src/primitives/mod.rs @@ -3,6 +3,5 @@ use ctr::Ctr128BE; pub type AdnlAes = Ctr128BE; +pub mod codec; pub mod handshake; -pub mod receive; -pub mod send; diff --git a/src/primitives/receive.rs b/src/primitives/receive.rs deleted file mode 100644 index 253e124..0000000 --- a/src/primitives/receive.rs +++ /dev/null @@ -1,118 +0,0 @@ -use crate::primitives::AdnlAes; -use crate::{AdnlAesParams, AdnlError}; -use aes::cipher::KeyIvInit; -use ctr::cipher::StreamCipher; -use sha2::{Digest, Sha256}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - -/// Low-level incoming datagram processor -pub struct AdnlReceiver { - aes: AdnlAes, -} - -impl AdnlReceiver { - /// Create receiver with given session parameters - pub fn new(aes_params: &AdnlAesParams) -> Self { - Self { - aes: AdnlAes::new(aes_params.rx_key().into(), aes_params.rx_nonce().into()), - } - } - - /// Receive datagram from `transport`. Received parts of the decrypted buffer - /// will be sent to `consumer`, which usually can be just `Vec`. Note that - /// data can be processed before this function will return, but in case of - /// [`AdnlError::IntegrityError`] you must assume that the data was tampered. - /// - /// Returns received packet length. - /// - /// You can adjust `BUFFER` according to your memory requirements. - /// Recommended size is 8192 bytes. - pub async fn receive( - &mut self, - transport: &mut R, - consumer: &mut C, - ) -> Result { - // read length - let mut length = [0u8; 4]; - log::debug!("reading length"); - transport - .read_exact(&mut length) - .await - .map_err(AdnlError::ReadError)?; - self.aes.apply_keystream(&mut length); - let length = u32::from_le_bytes(length); - log::debug!("length = {}", length); - if length < 64 { - return Err(AdnlError::TooShortPacket); - } - - let mut hasher = Sha256::new(); - - // read nonce - let mut nonce = [0u8; 32]; - log::debug!("reading nonce"); - transport - .read_exact(&mut nonce) - .await - .map_err(AdnlError::ReadError)?; - self.aes.apply_keystream(&mut nonce); - hasher.update(nonce); - - // read buffer chunks, decrypt and write to consumer - if BUFFER > 0 { - let mut buffer = [0u8; BUFFER]; - let mut bytes_to_read = length as usize - 64; - while bytes_to_read >= BUFFER { - log::debug!( - "chunked read (chunk len = {}), {} bytes remaining", - BUFFER, - bytes_to_read - ); - transport - .read_exact(&mut buffer) - .await - .map_err(AdnlError::ReadError)?; - self.aes.apply_keystream(&mut buffer); - hasher.update(buffer); - consumer - .write_all(&buffer) - .await - .map_err(AdnlError::WriteError)?; - bytes_to_read -= BUFFER; - } - - // read last chunk - if bytes_to_read > 0 { - log::debug!("last chunk, {} bytes remaining", bytes_to_read); - let buffer = &mut buffer[..bytes_to_read]; - transport - .read_exact(buffer) - .await - .map_err(AdnlError::ReadError)?; - self.aes.apply_keystream(buffer); - hasher.update(&buffer); - consumer - .write_all(buffer) - .await - .map_err(AdnlError::WriteError)?; - } - } - - let mut given_hash = [0u8; 32]; - log::debug!("reading hash"); - transport - .read_exact(&mut given_hash) - .await - .map_err(AdnlError::ReadError)?; - self.aes.apply_keystream(&mut given_hash); - - let real_hash = hasher.finalize(); - if real_hash.as_slice() != given_hash { - return Err(AdnlError::IntegrityError); - } - - log::debug!("receive finished successfully"); - - Ok(length as usize) - } -} diff --git a/src/primitives/send.rs b/src/primitives/send.rs deleted file mode 100644 index 41a7e93..0000000 --- a/src/primitives/send.rs +++ /dev/null @@ -1,71 +0,0 @@ -use aes::cipher::KeyIvInit; -use ctr::cipher::StreamCipher; -use sha2::{Digest, Sha256}; -use tokio::io::AsyncWriteExt; - -use crate::primitives::AdnlAes; -use crate::{AdnlAesParams, AdnlError}; - -/// Low-level outgoing datagram generator -pub struct AdnlSender { - aes: AdnlAes, -} - -impl AdnlSender { - /// Create sender with given session parameters - pub fn new(aes_params: &AdnlAesParams) -> Self { - Self { - aes: AdnlAes::new(aes_params.tx_key().into(), aes_params.tx_nonce().into()), - } - } - - /// Get estimated size of datagram for the given buffer - pub fn estimate_packet_length(buffer: &[u8]) -> u32 { - buffer.len() as u32 + 68 - } - - /// Send `buffer` over `transport` with `nonce`. Note that `nonce` must be random - /// in order to prevent bit-flipping attacks when an attacker knows whole plaintext in datagram. - pub async fn send( - &mut self, - transport: &mut W, - nonce: &mut [u8; 32], - buffer: &mut [u8], - ) -> Result<(), AdnlError> { - // remember not to send more than 4 GiB in a single packet - let mut length = ((buffer.len() + 64) as u32).to_le_bytes(); - - // calc hash - let mut hasher = Sha256::new(); - hasher.update(*nonce); - hasher.update(&*buffer); - let mut hash: [u8; 32] = hasher.finalize().into(); - - // encrypt packet - self.aes.apply_keystream(&mut length); - self.aes.apply_keystream(nonce); - self.aes.apply_keystream(buffer); - self.aes.apply_keystream(&mut hash); - - // write to transport - transport - .write_all(&length) - .await - .map_err(AdnlError::WriteError)?; - transport - .write_all(nonce) - .await - .map_err(AdnlError::WriteError)?; - transport - .write_all(buffer) - .await - .map_err(AdnlError::WriteError)?; - transport - .write_all(&hash) - .await - .map_err(AdnlError::WriteError)?; - transport.flush().await.map_err(AdnlError::WriteError)?; - - Ok(()) - } -} diff --git a/src/tests.rs b/src/tests.rs index 68c343c..5d11af5 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,7 +1,13 @@ extern crate alloc; +use std::error::Error; + use super::*; use alloc::vec::Vec; +use futures::{SinkExt, StreamExt}; +use tokio::net::TcpListener; +use tokio_util::{bytes::BytesMut, codec::{Decoder, Encoder}}; +use x25519_dalek::StaticSecret; #[test] fn test_handshake_1() { @@ -84,78 +90,115 @@ fn test_handshake( let handshake2 = AdnlHandshake::decrypt_from_raw(expected_handshake.as_slice().try_into().unwrap(), &key).expect("invalid handshake"); assert_eq!(handshake2.aes_params().to_bytes(), aes_params_raw, "aes_params mismatch"); assert_eq!(handshake2.receiver(), &remote_public.address(), "receiver mismatch"); - assert_eq!(handshake2.sender().to_bytes(), local_public.to_bytes(), "sender mismatch"); + assert_eq!(handshake2.sender().edwards_repr(), local_public.edwards_repr(), "sender mismatch"); assert_eq!(&handshake2.to_bytes(), expected_handshake.as_slice(), "reencryption failed"); } -#[tokio::test] -async fn test_send_1() { +#[test] +fn test_send_1() { let aes_params = hex::decode("b3d529e34b839a521518447b68343aebaae9314ac95aaacfdb687a2163d1a98638db306b63409ef7bc906b4c9dc115488cf90dfa964f520542c69e1a4a495edf9ae9ee72023203c8b266d552f251e8d724929733428c8e276ab3bd6291367336a6ab8dc3d36243419bd0b742f76691a5dec14edbd50f7c1b58ec961ae45be58cbf6623f3ec9705bd5d227761ec79cee377e2566ff668f863552bddfd6ff3a16b").unwrap(); - let nonce = - hex::decode("9a5ecd5d9afdfff2823e7520fa1c338f2baf1a21f51e6fdab0491d45a50066f7").unwrap(); + let _nonce = hex::decode("9a5ecd5d9afdfff2823e7520fa1c338f2baf1a21f51e6fdab0491d45a50066f7").unwrap(); let buffer = hex::decode("7af98bb471ff48e9b263959b17a04faae4a23501380d2aa932b09eac6f9846fcbae9bbcb0cdf068c7904345aad16000000000000").unwrap(); let expected_packet = hex::decode("250d70d08526791bc2b6278ded7bf2b051afb441b309dda06f76e4419d7c31d4d5baafc4ff71e0ebabe246d4ea19e3e579bd15739c8fc916feaf46ea7a6bc562ed1cf87c9bf4220eb037b9a0b58f663f0474b8a8b18fa24db515e41e4b02e509d8ef261a27ba894cbbecc92e59fc44bf5ff7c8281cb5e900").unwrap(); - test_send(aes_params, nonce, buffer, expected_packet).await; + test_send(aes_params, buffer, expected_packet); } -#[tokio::test] -async fn test_send_2() { +#[test] +fn test_send_2() { let aes_params = hex::decode("7e3c66de7c64d4bee4368e69560101991db4b084430a336cffe676c9ac0a795d8c98367309422a8e927e62ed657ba3eaeeb6acd3bbe5564057dfd1d60609a25a48963cbb7d14acf4fc83ec59254673bc85be22d04e80e7b83c641d37cae6e1d82a400bf159490bbc0048e69234ad89e999d792eefdaa56734202546d9188706e95e1272267206a8e7ee1f7c077f76bd26e494972e34d72e257bf20364dbf39b0").unwrap(); - let nonce = - hex::decode("d36d0683da23e62910fa0e8a9331dfc257db4cde0ba8d63893e88ac4de7d8d6c").unwrap(); + let _nonce = hex::decode("d36d0683da23e62910fa0e8a9331dfc257db4cde0ba8d63893e88ac4de7d8d6c").unwrap(); let buffer = hex::decode("7af98bb47bcae111ea0e56457826b1aec7f0f59b9b6579678b3db3839d17b63eb60174f20cdf068c7904345aad16000000000000").unwrap(); let expected_packet = hex::decode("24c709a0f676750ddaeafc8564d84546bfc831af27fb66716de382a347a1c32adef1a27e597c8a07605a09087fff32511d314970cad3983baefff01e7ee51bb672b17f7914a6d3f229a13acb14cdc14d98beae8a1e96510756726913541f558c2ffac63ed6cb076d0e888c3c0bb014d9f229c2a3f62e0847").unwrap(); - test_send(aes_params, nonce, buffer, expected_packet).await; + test_send(aes_params, buffer, expected_packet); } -async fn test_send(aes_params: Vec, nonce: Vec, buffer: Vec, expected_packet: Vec) { - let mut nonce = nonce.try_into().unwrap(); - let mut buffer = buffer; +fn test_send(aes_params: Vec, buffer: Vec, expected_packet: Vec) { let aes_params: [u8; 160] = aes_params.try_into().unwrap(); - let aes_params = AdnlAesParams::from(aes_params); - let mut protocol_client = AdnlSender::new(&aes_params); - let mut packet = Vec::::new(); - let _result = protocol_client - .send(&mut packet, &mut nonce, &mut buffer) - .await; - assert_eq!( - packet.as_slice(), - &expected_packet, - "outcoming packet is wrong" - ); + let mut codec = AdnlCodec::new(&aes_params.into()); + let mut packet = BytesMut::new(); + codec.encode(buffer.clone().into(), &mut packet).expect("packet must be encoded correctly"); + + // do not check nonce and hash as it's random + assert_eq!(&packet[..4], &expected_packet[..4], "outcoming packet length is wrong"); + assert_eq!(&packet[36..packet.len()-32], &expected_packet[36..expected_packet.len()-32], "outcoming packet length is wrong"); + + // check packet decoding to original buffer + // swap aes params + let mut new_aes_params = [0u8; 160]; + new_aes_params[..32].copy_from_slice(&aes_params[32..64]); + new_aes_params[32..64].copy_from_slice(&aes_params[..32]); + new_aes_params[64..80].copy_from_slice(&aes_params[80..96]); + new_aes_params[80..96].copy_from_slice(&aes_params[64..80]); + new_aes_params[96..160].copy_from_slice(&aes_params[96..160]); + let mut codec = AdnlCodec::new(&new_aes_params.into()); + test_recv(&mut codec, packet.into(), buffer); } -#[tokio::test] -async fn test_recv_1() { +#[test] +fn test_recv_1() { let encrypted_data = hex::decode("81e95e433c87c9ad2a716637b3a12644fbfb12dbd02996abc40ed2beb352483d6ecf9e2ad181a5abde4d4146ca3a8524739d3acebb2d7599cc6b81967692a62118997e16").unwrap(); let expected_data = Vec::new(); let aes_params = hex::decode("b3d529e34b839a521518447b68343aebaae9314ac95aaacfdb687a2163d1a98638db306b63409ef7bc906b4c9dc115488cf90dfa964f520542c69e1a4a495edf9ae9ee72023203c8b266d552f251e8d724929733428c8e276ab3bd6291367336a6ab8dc3d36243419bd0b742f76691a5dec14edbd50f7c1b58ec961ae45be58cbf6623f3ec9705bd5d227761ec79cee377e2566ff668f863552bddfd6ff3a16b").unwrap(); - let aes_params: [u8; 160] = aes_params.try_into().unwrap(); - let aes_params = AdnlAesParams::from(aes_params); - let mut protocol_client = AdnlReceiver::new(&aes_params); - test_recv(&mut protocol_client, encrypted_data, expected_data).await; + let aes_params: [u8; 160] = aes_params.as_slice().try_into().unwrap(); + let mut codec = AdnlCodec::new(&aes_params.into()); + test_recv(&mut codec, encrypted_data, expected_data); let encrypted_data = hex::decode("4b72a32bf31894cce9ceffd2dd97176e502946524e45e62689bd8c5d31ad53603c5fd3b402771f707cd2747747fad9df52e6c23ceec9fa2ee5b0f68b61c33c7790db03d1c593798a29d716505cea75acdf0e031c25447c55c4d29d32caab29bd5a0787644843bafc04160c92140aab0ecc990927").unwrap(); let expected_data = hex::decode("1684ac0f71ff48e9b263959b17a04faae4a23501380d2aa932b09eac6f9846fcbae9bbcb080d0053e9a3ac3062000000").unwrap(); - test_recv(&mut protocol_client, encrypted_data, expected_data).await; + test_recv(&mut codec, encrypted_data, expected_data); } -#[tokio::test] -async fn test_recv_2() { +#[test] +fn test_recv_2() { let encrypted_data = hex::decode("b75dcf27582beb4031d6d3700c9b7925bf84a78f2bd16b186484d36427a8824ac86e27cea81eb5bcbac447a37269845c65be51babd11c80627f81b4247f84df16d05c4f1").unwrap(); let expected_data = Vec::new(); let aes_params = hex::decode("7e3c66de7c64d4bee4368e69560101991db4b084430a336cffe676c9ac0a795d8c98367309422a8e927e62ed657ba3eaeeb6acd3bbe5564057dfd1d60609a25a48963cbb7d14acf4fc83ec59254673bc85be22d04e80e7b83c641d37cae6e1d82a400bf159490bbc0048e69234ad89e999d792eefdaa56734202546d9188706e95e1272267206a8e7ee1f7c077f76bd26e494972e34d72e257bf20364dbf39b0").unwrap(); - let aes_params: [u8; 160] = aes_params.try_into().unwrap(); - let aes_params = AdnlAesParams::from(aes_params); - let mut protocol_client = AdnlReceiver::new(&aes_params); - test_recv(&mut protocol_client, encrypted_data, expected_data).await; + let aes_params: [u8; 160] = aes_params.as_slice().try_into().unwrap(); + let mut codec = AdnlCodec::new(&aes_params.into()); + test_recv(&mut codec, encrypted_data, expected_data); let encrypted_data = hex::decode("77ebea5a6e6c8758e7703d889abad16e7e3c4e0c10c4e81ca10d0d9abddabb6f008905133a070ff825ad3f4b0ae969e04dbd8b280864d3d2175f3bc7cf3deb31de5497fa43997d8e2acafb9a31de2a22ecb279b5854c00791216e39c2e65863539d82716fc020e9647b2dd99d0f14e4f553b645f").unwrap(); let expected_data = hex::decode("1684ac0f7bcae111ea0e56457826b1aec7f0f59b9b6579678b3db3839d17b63eb60174f2080d0053e90bb03062000000").unwrap(); - test_recv(&mut protocol_client, encrypted_data, expected_data).await; + test_recv(&mut codec, encrypted_data, expected_data); } -async fn test_recv(client: &mut AdnlReceiver, encrypted_packet: Vec, expected_data: Vec) { - let mut data = Vec::::new(); - let mut binding = encrypted_packet.as_slice(); - let _r = client.receive::<_, _, 8192>(&mut binding, &mut data).await; +fn test_recv(codec: &mut AdnlCodec, encrypted_packet: Vec, expected_data: Vec) { + let data = codec.decode(&mut encrypted_packet.as_slice().into()).expect("decoding must be correct").expect("input must contain full packet"); assert_eq!(data, expected_data.as_slice(), "incoming packet is wrong"); } + +#[test] +fn test_public_key_consistency() { + let private_key: [u8; 32] = hex::decode("7d6336b1ca12d641bc26733ca1f866ccd8d89f3fa57b6168f484211db2e712cb").unwrap().try_into().unwrap(); + let private_key: StaticSecret = StaticSecret::from(private_key); + let public_key_raw = AdnlRawPublicKey::from(private_key.public().edwards_repr()); + assert_eq!(public_key_raw.edwards_repr(), private_key.public().edwards_repr()); + assert_eq!(public_key_raw.address(), private_key.public().address()); +} + +#[tokio::test] +async fn integrity_test() { + let server_private = StaticSecret::new(rand::thread_rng()); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let server_public = server_private.public(); + tokio::spawn(async move { + loop { + let (socket, _) = listener.accept().await.unwrap(); + let private_key = server_private.clone(); + tokio::spawn(async move { + let mut adnl_server = AdnlPeer::handle_handshake(socket, &private_key).await.expect("handshake failed"); + while let Some(Ok(packet)) = adnl_server.next().await { + let _ = adnl_server.send(packet).await; + } + }); + } + }); + + // act as a client: connect to ADNL server and perform handshake + let mut client = AdnlPeer::connect(&server_public, ("127.0.0.1", port)).await.expect("adnl connect"); + + // send over ADNL + client.send("hello".as_bytes().into()).await.expect("send"); + + // receive result + let result = client.next().await.expect("packet must be received").expect("packet must be decoded properly"); +} \ No newline at end of file diff --git a/src/wrappers/peer.rs b/src/wrappers/peer.rs index 59a3d25..3668361 100644 --- a/src/wrappers/peer.rs +++ b/src/wrappers/peer.rs @@ -1,13 +1,22 @@ -use crate::{AdnlBuilder, AdnlError, AdnlHandshake, AdnlPrivateKey, AdnlPublicKey, AdnlReceiver, AdnlSender}; -use tokio::io::{empty, AsyncReadExt, AsyncWriteExt}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::{AdnlBuilder, AdnlError, AdnlHandshake, AdnlPrivateKey, AdnlPublicKey}; +use pin_project::pin_project; +use tokio::io::{empty, AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpStream, ToSocketAddrs}; +use tokio_util::bytes::Bytes; +use tokio_util::codec::{Decoder, Encoder, Framed}; use x25519_dalek::StaticSecret; +use futures::{Sink, SinkExt, Stream, StreamExt}; + +use crate::primitives::codec::AdnlCodec; /// Abstraction over [`AdnlSender`] and [`AdnlReceiver`] to keep things simple -pub struct AdnlPeer { - sender: AdnlSender, - receiver: AdnlReceiver, - transport: T, +#[pin_project] +pub struct AdnlPeer where T: AsyncRead + AsyncWrite { + #[pin] + stream: Framed, } impl AdnlPeer { @@ -36,27 +45,22 @@ impl AdnlPeer { impl AdnlPeer { /// Act as a client: send `handshake` over `transport` and check that handshake was successful /// Returns client part of ADNL connection - pub async fn perform_handshake( - mut transport: T, - handshake: &AdnlHandshake

, - ) -> Result { + pub async fn perform_handshake(mut transport: T, handshake: &AdnlHandshake

) -> Result { // send handshake transport .write_all(&handshake.to_bytes()) .await - .map_err(AdnlError::WriteError)?; + .map_err(AdnlError::IoError)?; + + let mut stream = handshake.make_codec().framed(transport); // receive empty message to ensure that server knows our AES keys - let mut client = Self { - sender: AdnlSender::new(handshake.aes_params()), - receiver: AdnlReceiver::new(handshake.aes_params()), - transport, - }; - let mut empty = empty(); - client - .receive_with_buffer::<_, 0>(&mut empty) - .await?; - Ok(client) + if let Some(x) = stream.next().await { + x?; + Ok(Self { stream }) + } else { + Err(AdnlError::EndOfStream) + } } /// Act as a server: receive handshake over transport. @@ -66,56 +70,46 @@ impl AdnlPeer { pub async fn handle_handshake(mut transport: T, private_key: &S) -> Result { // receive handshake let mut packet = [0u8; 256]; - transport.read_exact(&mut packet).await.map_err(AdnlError::ReadError)?; + transport.read_exact(&mut packet).await.map_err(AdnlError::IoError)?; let handshake = AdnlHandshake::decrypt_from_raw(&packet, private_key)?; let mut server = Self { - sender: AdnlSender::new(handshake.aes_params()), - receiver: AdnlReceiver::new(handshake.aes_params()), - transport, + stream: handshake.make_codec().framed(transport), }; // send empty packet to proof knowledge of AES keys - server.send(&mut []).await?; + server.send(Bytes::new()).await?; Ok(server) } +} - /// Send `data` to another peer with random nonce - pub async fn send(&mut self, data: &mut [u8]) -> Result<(), AdnlError> { - self.sender - .send(&mut self.transport, &mut rand::random(), data) - .await +impl Stream for AdnlPeer where T: AsyncRead + AsyncWrite +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_next(cx) } +} - /// Send `data` to another peer. Random `nonce` must be provided to eliminate bit-flipping attacks. - pub async fn send_with_nonce( - &mut self, - data: &mut [u8], - nonce: &mut [u8; 32], - ) -> Result<(), AdnlError> { - self.sender.send(&mut self.transport, nonce, data).await +impl Sink for AdnlPeer where T: AsyncWrite + AsyncRead +{ + type Error = AdnlError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_ready(cx) } - /// Receive data from another peer into `consumer` which will process the data with - /// a `BUFFER` size of 8192 bytes. - pub async fn receive( - &mut self, - consumer: &mut C, - ) -> Result { - self.receiver - .receive::<_, _, 8192>(&mut self.transport, consumer) - .await + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.project().stream.start_send(item) } - /// Receive data from another peer into `consumer` which will process the data. Set `BUFFER` - /// according to your memory requirements, recommended size is 8192 bytes. - pub async fn receive_with_buffer( - &mut self, - consumer: &mut C, - ) -> Result { - self.receiver - .receive::<_, _, BUFFER>(&mut self.transport, consumer) - .await + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_close(cx) } } \ No newline at end of file