diff --git a/Cargo.lock b/Cargo.lock index cb3b5c8..7235ebb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -660,9 +660,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.22.0-alpha.4" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c23376606de66c7b9249d091b59ee55b52df72063e1cae7bb44e0691c9e5150" +checksum = "5bc238b76c51bbc449c55ffbc39d03772a057cc8cf783c49d4af4c2537b74a8b" dependencies = [ "log", "ring", @@ -682,7 +682,7 @@ dependencies = [ "mbedtls", "rustls", "rustls-mbedtls-provider-utils", - "rustls-pemfile 2.0.0-alpha.2", + "rustls-pemfile 2.0.0", "rustls-pki-types", "rustls-webpki", "webpki-roots", @@ -697,7 +697,7 @@ dependencies = [ "mbedtls", "rustls", "rustls-mbedtls-provider-utils", - "rustls-pemfile 1.0.3", + "rustls-pemfile 2.0.0", "rustls-pki-types", "x509-parser", ] @@ -711,6 +711,8 @@ dependencies = [ "rustls-mbedcrypto-provider", "rustls-mbedpki-provider", "rustls-native-certs", + "rustls-pemfile 2.0.0", + "rustls-pki-types", ] [[package]] @@ -745,9 +747,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.0.0-alpha.2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e9975e1f0807681e097d288d545dc40c98a4d3a6ef95a40b18d00e5e4daa9a4" +checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4" dependencies = [ "base64", "rustls-pki-types", @@ -755,15 +757,15 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "0.2.2" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdf0cbc2bc68777eb846b2b7fedf03807bb763adc585bf006ac2fa2884daa9d1" +checksum = "e7673e0aa20ee4937c6aacfc12bb8341cfbf054cdd21df6bec5fd0629fe9339b" [[package]] name = "rustls-webpki" -version = "0.102.0-alpha.6" +version = "0.102.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34d9ed3a8267782ba32d257ff5b197b63eef19a467dbd1be011caaae35ee416e" +checksum = "de2635c8bc2b88d367767c5de8ea1d8db9af3f6219eba28442242d9ab81d1b89" dependencies = [ "ring", "rustls-pki-types", @@ -1012,9 +1014,9 @@ checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" [[package]] name = "webpki-roots" -version = "0.26.0-alpha.2" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87e3d99d80231fabcc72d887ed09f843b7f3942c75907285e51112a46c8f6f81" +checksum = "0de2cfda980f21be5a7ed2eadb3e6fe074d56022bea2cdeb1a62eb220fc04188" dependencies = [ "rustls-pki-types", ] diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 5e54756..3cad206 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -7,8 +7,10 @@ description = "rustls-mbedtls-provider example code." publish = false [dependencies] -rustls-mbedcrypto-provider = { path = "../rustls-mbedcrypto-provider" } +rustls-mbedcrypto-provider = { path = "../rustls-mbedcrypto-provider", features = ["tls12"] } rustls-mbedpki-provider = { path = "../rustls-mbedpki-provider" } env_logger = "0.10" -rustls = { version = "0.22.0-alpha.4", default-features = false } +rustls = { version = "0.22.0", default-features = false } rustls-native-certs = "0.6.3" +rustls-pki-types = "1" +rustls-pemfile = "2" diff --git a/examples/src/main.rs b/examples/src/main.rs index 6c9c878..a502904 100644 --- a/examples/src/main.rs +++ b/examples/src/main.rs @@ -9,7 +9,7 @@ use std::io::{stderr, stdout, Read, Write}; use std::net::TcpStream; use std::sync::Arc; -use rustls_mbedcrypto_provider::MBEDTLS; +use rustls_mbedcrypto_provider::mbedtls_crypto_provider; use rustls_mbedpki_provider::MbedTlsServerCertVerifier; fn main() { @@ -21,9 +21,7 @@ fn main() { .map(|cert| cert.0.into()) .collect(); let server_cert_verifier = MbedTlsServerCertVerifier::new(&root_certs).unwrap(); - let config = rustls::ClientConfig::builder_with_provider(MBEDTLS) - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() + let config = rustls::ClientConfig::builder_with_provider(mbedtls_crypto_provider().into()) .with_safe_default_protocol_versions() .unwrap() .dangerous() diff --git a/rustls-mbedcrypto-provider/Cargo.toml b/rustls-mbedcrypto-provider/Cargo.toml index 34f4757..a015620 100644 --- a/rustls-mbedcrypto-provider/Cargo.toml +++ b/rustls-mbedcrypto-provider/Cargo.toml @@ -12,15 +12,15 @@ categories = ["network-programming", "cryptography"] resolver = "2" [dependencies] -rustls = { version = "0.22.0-alpha.4", default-features = false } +rustls = { version = "0.22.0", default-features = false } mbedtls = { version = "0.12.0-alpha.2", default-features = false, features = [ "std", ] } log = { version = "0.4.4", optional = true } -pki-types = { package = "rustls-pki-types", version = "0.2.1", features = [ +pki-types = { package = "rustls-pki-types", version = "1", features = [ "std", ] } -webpki = { package = "rustls-webpki", version = "0.102.0-alpha.6", features = [ +webpki = { package = "rustls-webpki", version = "0.102.0", features = [ "alloc", "std", ], default-features = false } @@ -36,16 +36,11 @@ mbedtls = { version = "0.12.0-alpha.2", default-features = false, features = [ ] } [dev-dependencies] -rustls = { version = "0.22.0-alpha.4", default-features = false, features = [ +rustls = { version = "0.22.0", default-features = false, features = [ "ring", ] } -webpki = { package = "rustls-webpki", version = "0.102.0-alpha.1", default-features = false, features = [ - "alloc", - "std", -] } -pki-types = { package = "rustls-pki-types", version = "0.2.0" } -webpki-roots = "0.26.0-alpha.2" -rustls-pemfile = "=2.0.0-alpha.2" +webpki-roots = "0.26.0" +rustls-pemfile = "2" env_logger = "0.10" log = { version = "0.4.4" } diff --git a/rustls-mbedcrypto-provider/examples/client.rs b/rustls-mbedcrypto-provider/examples/client.rs index e708701..3610572 100644 --- a/rustls-mbedcrypto-provider/examples/client.rs +++ b/rustls-mbedcrypto-provider/examples/client.rs @@ -9,7 +9,7 @@ use std::io::{stdout, Read, Write}; use std::net::TcpStream; use std::sync::Arc; -use rustls_mbedcrypto_provider::MBEDTLS; +use rustls_mbedcrypto_provider::mbedtls_crypto_provider; fn main() { env_logger::init(); @@ -21,8 +21,9 @@ fn main() { .cloned(), ); - let config = rustls::ClientConfig::builder_with_provider(MBEDTLS) - .with_safe_defaults() + let config = rustls::ClientConfig::builder_with_provider(mbedtls_crypto_provider().into()) + .with_safe_default_protocol_versions() + .unwrap() .with_root_certificates(root_store) .with_no_client_auth(); diff --git a/rustls-mbedcrypto-provider/examples/internal/bench.rs b/rustls-mbedcrypto-provider/examples/internal/bench.rs index 02b4164..d774c7f 100644 --- a/rustls-mbedcrypto-provider/examples/internal/bench.rs +++ b/rustls-mbedcrypto-provider/examples/internal/bench.rs @@ -22,6 +22,7 @@ use pki_types::{CertificateDer, PrivateKeyDer}; use rustls::client::Resumption; use rustls::crypto::ring::{cipher_suite, Ticketer}; +use rustls::crypto::CryptoProvider; use rustls::server::{NoServerSessionStorage, ServerSessionMemoryCache, WebPkiClientVerifier}; use rustls::RootCertStore; use rustls::{ClientConfig, ClientConnection}; @@ -298,9 +299,7 @@ fn make_server_config( ClientAuth::No => WebPkiClientVerifier::no_client_auth(), }; - let mut cfg = ServerConfig::builder_with_provider(rustls_mbedcrypto_provider::MBEDTLS) - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() + let mut cfg = ServerConfig::builder_with_provider(rustls_mbedcrypto_provider::mbedtls_crypto_provider().into()) .with_protocol_versions(&[params.version]) .unwrap() .with_client_cert_verifier(client_auth) @@ -324,12 +323,16 @@ fn make_client_config(params: &BenchmarkParam, clientauth: ClientAuth, resume: R let mut rootbuf = io::BufReader::new(fs::File::open(params.key_type.path_for("ca.cert")).unwrap()); root_store.add_parsable_certificates(rustls_pemfile::certs(&mut rootbuf).map(|result| result.unwrap())); - let cfg = ClientConfig::builder_with_provider(rustls_mbedcrypto_provider::MBEDTLS) - .with_cipher_suites(&[params.ciphersuite]) - .with_safe_default_kx_groups() - .with_protocol_versions(&[params.version]) - .unwrap() - .with_root_certificates(root_store); + let cfg = ClientConfig::builder_with_provider( + CryptoProvider { + cipher_suites: vec![params.ciphersuite], + ..rustls_mbedcrypto_provider::mbedtls_crypto_provider() + } + .into(), + ) + .with_protocol_versions(&[params.version]) + .unwrap() + .with_root_certificates(root_store); let mut cfg = if clientauth == ClientAuth::Yes { cfg.with_client_auth_cert(params.key_type.get_client_chain(), params.key_type.get_client_key()) diff --git a/rustls-mbedcrypto-provider/src/lib.rs b/rustls-mbedcrypto-provider/src/lib.rs index b491c9f..9a75c2d 100644 --- a/rustls-mbedcrypto-provider/src/lib.rs +++ b/rustls-mbedcrypto-provider/src/lib.rs @@ -91,7 +91,10 @@ pub(crate) mod tls12; pub(crate) mod tls13; use mbedtls::rng::Random; -use rustls::{SignatureScheme, SupportedCipherSuite, WebPkiSupportedAlgorithms}; +use rustls::{ + crypto::{CryptoProvider, KeyProvider, SecureRandom, WebPkiSupportedAlgorithms}, + SignatureScheme, SupportedCipherSuite, +}; /// RNG supported by *mbedtls* pub mod rng { @@ -115,43 +118,43 @@ pub mod rng { } } -/// A `CryptoProvider` backed by the [*mbedtls*] crate. +/// returns a `CryptoProvider` backed by the [*mbedtls*] crate. /// /// [*mbedtls*]: https://github.com/fortanix/rust-mbedtls -pub static MBEDTLS: &'static dyn rustls::crypto::CryptoProvider = &Mbedtls; +pub fn mbedtls_crypto_provider() -> CryptoProvider { + CryptoProvider { + cipher_suites: ALL_CIPHER_SUITES.to_vec(), + kx_groups: ALL_KX_GROUPS.to_vec(), + signature_verification_algorithms: SUPPORTED_SIG_ALGS, + secure_random: &MbedtlsSecureRandom, + key_provider: &MbedtlsKeyProvider, + } +} -/// Crypto provider based on the [*mbedtls*] crate. -/// -/// [*mbedtls*]: https://github.com/fortanix/rust-mbedtls #[derive(Debug)] -struct Mbedtls; +/// Implements `SecureRandom` using `mbedtls` +pub struct MbedtlsSecureRandom; -impl rustls::crypto::CryptoProvider for Mbedtls { - fn fill_random(&self, bytes: &mut [u8]) -> Result<(), rustls::crypto::GetRandomFailed> { +impl SecureRandom for MbedtlsSecureRandom { + fn fill(&self, buf: &mut [u8]) -> Result<(), rustls::crypto::GetRandomFailed> { rng::rng_new() .ok_or(rustls::crypto::GetRandomFailed)? - .random(bytes) + .random(buf) .map_err(|_| rustls::crypto::GetRandomFailed) } +} - fn default_cipher_suites(&self) -> &'static [SupportedCipherSuite] { - ALL_CIPHER_SUITES - } - - fn default_kx_groups(&self) -> &'static [&'static dyn rustls::crypto::SupportedKxGroup] { - ALL_KX_GROUPS - } +#[derive(Debug)] +/// Implements `KeyProvider` using `mbedtls` +pub struct MbedtlsKeyProvider; +impl KeyProvider for MbedtlsKeyProvider { fn load_private_key( &self, - key_der: pki_types::PrivateKeyDer<'static>, + key_der: webpki::types::PrivateKeyDer<'static>, ) -> Result, rustls::Error> { Ok(alloc::sync::Arc::new(sign::MbedTlsPkSigningKey::new(&key_der)?)) } - - fn signature_verification_algorithms(&self) -> WebPkiSupportedAlgorithms { - SUPPORTED_SIG_ALGS - } } /// The cipher suite configuration that an application should use by default. diff --git a/rustls-mbedcrypto-provider/src/tls12.rs b/rustls-mbedcrypto-provider/src/tls12.rs index 781c699..00b2349 100644 --- a/rustls-mbedcrypto-provider/src/tls12.rs +++ b/rustls-mbedcrypto-provider/src/tls12.rs @@ -15,13 +15,11 @@ use rustls::crypto::cipher::{ PlainMessage, Tls12AeadAlgorithm, UnsupportedOperationError, NONCE_LEN, }; use rustls::crypto::tls12::PrfUsingHmac; -use rustls::crypto::KeyExchangeAlgorithm; +use rustls::crypto::{CipherSuiteCommon, KeyExchangeAlgorithm}; use super::aead::{self, Algorithm, AES128_GCM, AES256_GCM}; use alloc::string::String; -use rustls::{ - CipherSuite, CipherSuiteCommon, ConnectionTrafficSecrets, Error, SignatureScheme, SupportedCipherSuite, Tls12CipherSuite, -}; +use rustls::{CipherSuite, ConnectionTrafficSecrets, Error, SignatureScheme, SupportedCipherSuite, Tls12CipherSuite}; pub(crate) const GCM_FIXED_IV_LEN: usize = 4; pub(crate) const GCM_EXPLICIT_NONCE_LEN: usize = 8; @@ -34,6 +32,8 @@ pub static TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: SupportedCipherSuite = common: CipherSuiteCommon { suite: CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, hash_provider: &super::hash::SHA256, + confidentiality_limit: u64::MAX, + integrity_limit: 1 << 36, }, kx: KeyExchangeAlgorithm::ECDHE, sign: TLS12_ECDSA_SCHEMES, @@ -46,6 +46,8 @@ pub static TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: SupportedCipherSuite = S common: CipherSuiteCommon { suite: CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, hash_provider: &super::hash::SHA256, + confidentiality_limit: u64::MAX, + integrity_limit: 1 << 36, }, kx: KeyExchangeAlgorithm::ECDHE, sign: TLS12_RSA_SCHEMES, @@ -58,6 +60,8 @@ pub static TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: SupportedCipherSuite = Support common: CipherSuiteCommon { suite: CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, hash_provider: &super::hash::SHA256, + confidentiality_limit: 1 << 23, + integrity_limit: 1 << 52, }, kx: KeyExchangeAlgorithm::ECDHE, sign: TLS12_RSA_SCHEMES, @@ -70,6 +74,8 @@ pub static TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: SupportedCipherSuite = Support common: CipherSuiteCommon { suite: CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, hash_provider: &super::hash::SHA384, + confidentiality_limit: 1 << 23, + integrity_limit: 1 << 52, }, kx: KeyExchangeAlgorithm::ECDHE, sign: TLS12_RSA_SCHEMES, @@ -82,6 +88,8 @@ pub static TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: SupportedCipherSuite = Suppo common: CipherSuiteCommon { suite: CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, hash_provider: &super::hash::SHA256, + confidentiality_limit: 1 << 23, + integrity_limit: 1 << 52, }, kx: KeyExchangeAlgorithm::ECDHE, sign: TLS12_ECDSA_SCHEMES, @@ -94,6 +102,8 @@ pub static TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: SupportedCipherSuite = Suppo common: CipherSuiteCommon { suite: CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, hash_provider: &super::hash::SHA384, + confidentiality_limit: 1 << 23, + integrity_limit: 1 << 52, }, kx: KeyExchangeAlgorithm::ECDHE, sign: TLS12_ECDSA_SCHEMES, @@ -192,7 +202,7 @@ struct GcmMessageDecrypter { } impl MessageDecrypter for GcmMessageDecrypter { - fn decrypt(&self, msg: OpaqueMessage, seq: u64) -> Result { + fn decrypt(&mut self, msg: OpaqueMessage, seq: u64) -> Result { let payload = msg.payload(); if payload.len() < GCM_OVERHEAD { return Err(Error::DecryptError); @@ -242,7 +252,7 @@ impl MessageDecrypter for GcmMessageDecrypter { } impl MessageEncrypter for GcmMessageEncrypter { - fn encrypt(&self, msg: BorrowedPlainMessage, seq: u64) -> Result { + fn encrypt(&mut self, msg: BorrowedPlainMessage, seq: u64) -> Result { let nonce = Nonce::new(&self.iv, seq).0; let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len()); let mut tag = [0u8; aead::TAG_LEN]; @@ -271,6 +281,10 @@ impl MessageEncrypter for GcmMessageEncrypter { Ok(OpaqueMessage::new(msg.typ, msg.version, payload)) } + + fn encrypted_payload_len(&self, payload_len: usize) -> usize { + payload_len + GCM_EXPLICIT_NONCE_LEN + aead::TAG_LEN + } } /// The RFC7905/RFC7539 ChaCha20Poly1305 construction. @@ -292,7 +306,7 @@ struct ChaCha20Poly1305MessageDecrypter { const CHACHAPOLY1305_OVERHEAD: usize = 16; impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter { - fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result { + fn decrypt(&mut self, mut msg: OpaqueMessage, seq: u64) -> Result { let payload = msg.payload(); if payload.len() < CHACHAPOLY1305_OVERHEAD { @@ -343,7 +357,7 @@ impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter { } impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter { - fn encrypt(&self, msg: BorrowedPlainMessage, seq: u64) -> Result { + fn encrypt(&mut self, msg: BorrowedPlainMessage, seq: u64) -> Result { let nonce = Nonce::new(&self.enc_offset, seq).0; let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len()); let mut tag = [0u8; aead::TAG_LEN]; @@ -376,6 +390,10 @@ impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter { Ok(OpaqueMessage::new(msg.typ, msg.version, payload)) } + + fn encrypted_payload_len(&self, payload_len: usize) -> usize { + payload_len + aead::TAG_LEN + } } /// Generate GCM IV based on given IV and explicit nonce. diff --git a/rustls-mbedcrypto-provider/src/tls13.rs b/rustls-mbedcrypto-provider/src/tls13.rs index a008f85..793c8a7 100644 --- a/rustls-mbedcrypto-provider/src/tls13.rs +++ b/rustls-mbedcrypto-provider/src/tls13.rs @@ -17,10 +17,10 @@ use rustls::crypto::cipher::{ Tls13AeadAlgorithm, UnsupportedOperationError, }; use rustls::crypto::tls13::HkdfUsingHmac; +use rustls::crypto::CipherSuiteCommon; use rustls::internal::msgs::codec::Codec; use rustls::{ - CipherSuite, CipherSuiteCommon, ConnectionTrafficSecrets, ContentType, Error, ProtocolVersion, SupportedCipherSuite, - Tls13CipherSuite, + CipherSuite, ConnectionTrafficSecrets, ContentType, Error, ProtocolVersion, SupportedCipherSuite, Tls13CipherSuite, }; /// The TLS1.3 ciphersuite TLS_CHACHA20_POLY1305_SHA256 @@ -31,9 +31,12 @@ pub(crate) static TLS13_CHACHA20_POLY1305_SHA256_INTERNAL: &Tls13CipherSuite = & common: CipherSuiteCommon { suite: CipherSuite::TLS13_CHACHA20_POLY1305_SHA256, hash_provider: &super::hash::SHA256, + confidentiality_limit: u64::MAX, + integrity_limit: 1 << 36, }, hkdf_provider: &HkdfUsingHmac(&super::hmac::HMAC_SHA256), aead_alg: &AeadAlgorithm(&aead::CHACHA20_POLY1305), + quic: None, }; /// The TLS1.3 ciphersuite TLS_AES_256_GCM_SHA384 @@ -41,9 +44,12 @@ pub static TLS13_AES_256_GCM_SHA384: SupportedCipherSuite = SupportedCipherSuite common: CipherSuiteCommon { suite: CipherSuite::TLS13_AES_256_GCM_SHA384, hash_provider: &super::hash::SHA384, + confidentiality_limit: 1 << 23, + integrity_limit: 1 << 52, }, hkdf_provider: &HkdfUsingHmac(&super::hmac::HMAC_SHA384), aead_alg: &AeadAlgorithm(&aead::AES256_GCM), + quic: None, }); /// The TLS1.3 ciphersuite TLS_AES_128_GCM_SHA256 @@ -51,9 +57,12 @@ pub static TLS13_AES_128_GCM_SHA256: SupportedCipherSuite = SupportedCipherSuite common: CipherSuiteCommon { suite: CipherSuite::TLS13_AES_128_GCM_SHA256, hash_provider: &super::hash::SHA256, + confidentiality_limit: 1 << 23, + integrity_limit: 1 << 52, }, hkdf_provider: &HkdfUsingHmac(&super::hmac::HMAC_SHA256), aead_alg: &AeadAlgorithm(&aead::AES128_GCM), + quic: None, }); // common encrypter/decrypter/key_len items for above Tls13AeadAlgorithm impls @@ -95,7 +104,7 @@ struct Tls13MessageDecrypter { } impl MessageEncrypter for Tls13MessageEncrypter { - fn encrypt(&self, msg: BorrowedPlainMessage, seq: u64) -> Result { + fn encrypt(&mut self, msg: BorrowedPlainMessage, seq: u64) -> Result { let total_len = msg.payload.len() + 1 + aead::TAG_LEN; let mut payload = Vec::with_capacity(total_len); payload.extend_from_slice(msg.payload); @@ -134,10 +143,14 @@ impl MessageEncrypter for Tls13MessageEncrypter { payload, )) } + + fn encrypted_payload_len(&self, payload_len: usize) -> usize { + payload_len + 1 + aead::TAG_LEN + } } impl MessageDecrypter for Tls13MessageDecrypter { - fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result { + fn decrypt(&mut self, mut msg: OpaqueMessage, seq: u64) -> Result { let payload = msg.payload_mut(); if payload.len() < aead::TAG_LEN { return Err(Error::DecryptError); diff --git a/rustls-mbedcrypto-provider/tests/api.rs b/rustls-mbedcrypto-provider/tests/api.rs index 84074ee..ed304a8 100644 --- a/rustls-mbedcrypto-provider/tests/api.rs +++ b/rustls-mbedcrypto-provider/tests/api.rs @@ -18,10 +18,11 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::sync::Mutex; -use pki_types::{CertificateDer, PrivateKeyDer, UnixTime}; -use primary_provider::cipher_suite; +use pki_types::{CertificateDer, ServerName, UnixTime}; use primary_provider::sign::MbedTlsPkSigningKey as RsaSigningKey; +use primary_provider::{mbedtls_crypto_provider, MbedtlsSecureRandom}; use rustls::client::{verify_server_cert_signed_by_trust_anchor, ResolvesClientCert, Resumption}; +use rustls::crypto::CryptoProvider; use rustls::internal::msgs::base::Payload; use rustls::internal::msgs::codec::Codec; use rustls::internal::msgs::enums::AlertLevel; @@ -216,80 +217,6 @@ fn check_read_buf_err(reader: &mut dyn io::Read, err_kind: io::ErrorKind) { assert!(matches!(err, err if err.kind() == err_kind)) } -#[test] -fn config_builder_for_client_rejects_empty_kx_groups() { - assert_eq!( - client_config_builder() - .with_safe_default_cipher_suites() - .with_kx_groups(&[]) - .with_safe_default_protocol_versions() - .err(), - Some(Error::General("no kx groups configured".into())) - ); -} - -#[test] -fn config_builder_for_client_rejects_empty_cipher_suites() { - assert_eq!( - client_config_builder() - .with_cipher_suites(&[]) - .with_safe_default_kx_groups() - .with_safe_default_protocol_versions() - .err(), - Some(Error::General("no usable cipher suites configured".into())) - ); -} - -#[cfg(feature = "tls12")] -#[test] -fn config_builder_for_client_rejects_incompatible_cipher_suites() { - assert_eq!( - client_config_builder() - .with_cipher_suites(&[cipher_suite::TLS13_AES_256_GCM_SHA384]) - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS12]) - .err(), - Some(Error::General("no usable cipher suites configured".into())) - ); -} - -#[test] -fn config_builder_for_server_rejects_empty_kx_groups() { - assert_eq!( - server_config_builder() - .with_safe_default_cipher_suites() - .with_kx_groups(&[]) - .with_safe_default_protocol_versions() - .err(), - Some(Error::General("no kx groups configured".into())) - ); -} - -#[test] -fn config_builder_for_server_rejects_empty_cipher_suites() { - assert_eq!( - server_config_builder() - .with_cipher_suites(&[]) - .with_safe_default_kx_groups() - .with_safe_default_protocol_versions() - .err(), - Some(Error::General("no usable cipher suites configured".into())) - ); -} - -#[cfg(feature = "tls12")] -#[test] -fn config_builder_for_server_rejects_incompatible_cipher_suites() { - assert_eq!( - server_config_builder() - .with_cipher_suites(&[cipher_suite::TLS13_AES_256_GCM_SHA384]) - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS12]) - .err(), - Some(Error::General("no usable cipher suites configured".into())) - ); -} - #[test] fn buffered_client_data_sent() { let server_config = Arc::new(make_server_config(KeyType::Rsa)); @@ -432,71 +359,6 @@ fn server_can_get_client_cert_after_resumption() { } } -#[test] -#[cfg(feature = "ring")] -fn test_config_builders_debug() { - let b = server_config_builder(); - assert_eq!( - "ConfigBuilder { state: WantsCipherSuites(Ring) }", - format!("{:?}", b) - ); - let b = b.with_cipher_suites(&[cipher_suite::TLS13_CHACHA20_POLY1305_SHA256]); - assert_eq!("ConfigBuilder { state: WantsKxGroups { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], provider: Ring } }", format!("{:?}", b)); - let b = b.with_kx_groups(&[primary_provider::kx_group::X25519]); - assert_eq!("ConfigBuilder { state: WantsVersions { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], kx_groups: [X25519], provider: Ring } }", format!("{:?}", b)); - let b = b - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap(); - let b = b.with_no_client_auth(); - assert_eq!("ConfigBuilder { state: WantsServerCert { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], kx_groups: [X25519], provider: Ring, versions: [TLSv1_3], verifier: NoClientAuth } }", format!("{:?}", b)); - - let b = client_config_builder(); - assert_eq!( - "ConfigBuilder { state: WantsCipherSuites(Ring) }", - format!("{:?}", b) - ); - let b = b.with_cipher_suites(&[cipher_suite::TLS13_CHACHA20_POLY1305_SHA256]); - assert_eq!("ConfigBuilder { state: WantsKxGroups { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], provider: Ring } }", format!("{:?}", b)); - let b = b.with_kx_groups(&[primary_provider::kx_group::X25519]); - assert_eq!("ConfigBuilder { state: WantsVersions { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], kx_groups: [X25519], provider: Ring } }", format!("{:?}", b)); - let b = b - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap(); - assert_eq!("ConfigBuilder { state: WantsVerifier { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], kx_groups: [X25519], provider: Ring, versions: [TLSv1_3] } }", format!("{:?}", b)); -} - -#[test] -fn test_config_builders_debug_mbedtls() { - let b = server_config_builder(); - assert_eq!( - "ConfigBuilder { state: WantsCipherSuites(Mbedtls) }", - format!("{:?}", b) - ); - let b = b.with_cipher_suites(&[cipher_suite::TLS13_CHACHA20_POLY1305_SHA256]); - assert_eq!("ConfigBuilder { state: WantsKxGroups { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], provider: Mbedtls } }", format!("{:?}", b)); - let b = b.with_kx_groups(&[primary_provider::kx_group::X25519]); - assert_eq!("ConfigBuilder { state: WantsVersions { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], kx_groups: [X25519], provider: Mbedtls } }", format!("{:?}", b)); - let b = b - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap(); - let b = b.with_no_client_auth(); - assert_eq!("ConfigBuilder { state: WantsServerCert { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], kx_groups: [X25519], provider: Mbedtls, versions: [TLSv1_3], verifier: NoClientAuth } }", format!("{:?}", b)); - - let b = client_config_builder(); - assert_eq!( - "ConfigBuilder { state: WantsCipherSuites(Mbedtls) }", - format!("{:?}", b) - ); - let b = b.with_cipher_suites(&[cipher_suite::TLS13_CHACHA20_POLY1305_SHA256]); - assert_eq!("ConfigBuilder { state: WantsKxGroups { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], provider: Mbedtls } }", format!("{:?}", b)); - let b = b.with_kx_groups(&[primary_provider::kx_group::X25519]); - assert_eq!("ConfigBuilder { state: WantsVersions { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], kx_groups: [X25519], provider: Mbedtls } }", format!("{:?}", b)); - let b = b - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap(); - assert_eq!("ConfigBuilder { state: WantsVerifier { cipher_suites: [TLS13_CHACHA20_POLY1305_SHA256], kx_groups: [X25519], provider: Mbedtls, versions: [TLSv1_3] } }", format!("{:?}", b)); -} - /// Test that the server handles combination of `offer_client_auth()` returning true /// and `client_auth_mandatory` returning `Some(false)`. This exercises both the /// client's and server's ability to "recover" from the server asking for a client @@ -512,7 +374,8 @@ fn server_allow_any_anonymous_or_authenticated_client() { .unwrap(); let server_config = server_config_builder() - .with_safe_defaults() + .with_safe_default_protocol_versions() + .unwrap() .with_client_cert_verifier(client_auth) .with_single_cert(kt.get_chain(), kt.get_key()) .unwrap(); @@ -890,11 +753,11 @@ fn client_trims_terminating_dot() { fn check_sigalgs_reduced_by_ciphersuite(kt: KeyType, suite: CipherSuite, expected_sigalgs: Vec) { let client_config = finish_client_config( kt, - client_config_builder() - .with_cipher_suites(&[find_suite(suite)]) - .with_safe_default_kx_groups() - .with_safe_default_protocol_versions() - .unwrap(), + ClientConfig::builder_with_provider( + CryptoProvider { cipher_suites: vec![find_suite(suite)], ..mbedtls_crypto_provider() }.into(), + ) + .with_safe_default_protocol_versions() + .unwrap(), ); let mut server_config = make_server_config(kt); @@ -1721,12 +1584,6 @@ where os } - fn new_fails(sess: &'a mut C) -> OtherSession<'a, C, S> { - let mut os = OtherSession::new(sess); - os.fail_ok = true; - os - } - fn flush_vectored(&mut self, b: &[io::IoSlice<'_>]) -> io::Result { let mut total = 0; let mut lengths = vec![]; @@ -2214,159 +2071,6 @@ fn stream_write_swallows_underlying_io_error_after_plaintext_processed() { assert_eq!(format!("{:?}", rc), "Ok(5)"); } -fn make_disjoint_suite_configs() -> (ClientConfig, ServerConfig) { - let kt = KeyType::Rsa; - let server_config = finish_server_config( - kt, - server_config_builder() - .with_cipher_suites(&[cipher_suite::TLS13_CHACHA20_POLY1305_SHA256]) - .with_safe_default_kx_groups() - .with_safe_default_protocol_versions() - .unwrap(), - ); - - let client_config = finish_client_config( - kt, - client_config_builder() - .with_cipher_suites(&[cipher_suite::TLS13_AES_256_GCM_SHA384]) - .with_safe_default_kx_groups() - .with_safe_default_protocol_versions() - .unwrap(), - ); - - (client_config, server_config) -} - -#[test] -fn client_stream_handshake_error() { - let (client_config, server_config) = make_disjoint_suite_configs(); - let (mut client, mut server) = make_pair_for_configs(client_config, server_config); - - { - let mut pipe = OtherSession::new_fails(&mut server); - let mut client_stream = Stream::new(&mut client, &mut pipe); - let rc = client_stream.write(b"hello"); - assert!(rc.is_err()); - assert_eq!( - format!("{:?}", rc), - "Err(Custom { kind: InvalidData, error: AlertReceived(HandshakeFailure) })" - ); - let rc = client_stream.write(b"hello"); - assert!(rc.is_err()); - assert_eq!( - format!("{:?}", rc), - "Err(Custom { kind: InvalidData, error: AlertReceived(HandshakeFailure) })" - ); - } -} - -#[test] -fn client_streamowned_handshake_error() { - let (client_config, server_config) = make_disjoint_suite_configs(); - let (client, mut server) = make_pair_for_configs(client_config, server_config); - - let pipe = OtherSession::new_fails(&mut server); - let mut client_stream = StreamOwned::new(client, pipe); - let rc = client_stream.write(b"hello"); - assert!(rc.is_err()); - assert_eq!( - format!("{:?}", rc), - "Err(Custom { kind: InvalidData, error: AlertReceived(HandshakeFailure) })" - ); - let rc = client_stream.write(b"hello"); - assert!(rc.is_err()); - assert_eq!( - format!("{:?}", rc), - "Err(Custom { kind: InvalidData, error: AlertReceived(HandshakeFailure) })" - ); - - let (_, _) = client_stream.into_parts(); -} - -#[test] -fn server_stream_handshake_error() { - let (client_config, server_config) = make_disjoint_suite_configs(); - let (mut client, mut server) = make_pair_for_configs(client_config, server_config); - - client - .writer() - .write_all(b"world") - .unwrap(); - - { - let mut pipe = OtherSession::new_fails(&mut client); - let mut server_stream = Stream::new(&mut server, &mut pipe); - let mut bytes = [0u8; 5]; - let rc = server_stream.read(&mut bytes); - assert!(rc.is_err()); - assert_eq!( - format!("{:?}", rc), - "Err(Custom { kind: InvalidData, error: PeerIncompatible(NoCipherSuitesInCommon) })" - ); - } -} - -#[test] -fn server_streamowned_handshake_error() { - let (client_config, server_config) = make_disjoint_suite_configs(); - let (mut client, server) = make_pair_for_configs(client_config, server_config); - - client - .writer() - .write_all(b"world") - .unwrap(); - - let pipe = OtherSession::new_fails(&mut client); - let mut server_stream = StreamOwned::new(server, pipe); - let mut bytes = [0u8; 5]; - let rc = server_stream.read(&mut bytes); - assert!(rc.is_err()); - assert_eq!( - format!("{:?}", rc), - "Err(Custom { kind: InvalidData, error: PeerIncompatible(NoCipherSuitesInCommon) })" - ); -} - -#[test] -fn server_config_is_clone() { - let _ = make_server_config(KeyType::Rsa); -} - -#[test] -fn client_config_is_clone() { - let _ = make_client_config(KeyType::Rsa); -} - -#[test] -fn client_connection_is_debug() { - let (client, _) = make_pair(KeyType::Rsa); - println!("{:?}", client); -} - -#[test] -fn server_connection_is_debug() { - let (_, server) = make_pair(KeyType::Rsa); - println!("{:?}", server); -} - -#[test] -fn server_complete_io_for_handshake_ending_with_alert() { - let (client_config, server_config) = make_disjoint_suite_configs(); - let (mut client, mut server) = make_pair_for_configs(client_config, server_config); - - assert!(server.is_handshaking()); - - let mut pipe = OtherSession::new_fails(&mut client); - let rc = server.complete_io(&mut pipe); - assert!(rc.is_err(), "server io failed due to handshake failure"); - assert!(!server.wants_write(), "but server did send its alert"); - assert_eq!( - format!("{:?}", pipe.last_error), - "Some(AlertReceived(HandshakeFailure))", - "which was received by client" - ); -} - #[test] fn server_exposes_offered_sni() { let kt = KeyType::Rsa; @@ -2757,11 +2461,11 @@ fn negotiated_ciphersuite_client() { let scs = find_suite(suite); let client_config = finish_client_config( kt, - client_config_builder() - .with_cipher_suites(&[scs]) - .with_safe_default_kx_groups() - .with_protocol_versions(&[version]) - .unwrap(), + ClientConfig::builder_with_provider( + CryptoProvider { cipher_suites: vec![scs], ..mbedtls_crypto_provider() }.into(), + ) + .with_protocol_versions(&[version]) + .unwrap(), ); do_suite_test(client_config, make_server_config(kt), scs, version.version); @@ -2775,11 +2479,11 @@ fn negotiated_ciphersuite_server() { let scs = find_suite(suite); let server_config = finish_server_config( kt, - server_config_builder() - .with_cipher_suites(&[scs]) - .with_safe_default_kx_groups() - .with_protocol_versions(&[version]) - .unwrap(), + ServerConfig::builder_with_provider( + CryptoProvider { cipher_suites: vec![scs], ..mbedtls_crypto_provider() }.into(), + ) + .with_protocol_versions(&[version]) + .unwrap(), ); do_suite_test(make_client_config(kt), server_config, scs, version.version); @@ -3190,13 +2894,13 @@ impl rustls::server::StoresServerSessions for ServerStorage { #[derive(Debug, Clone)] enum ClientStorageOp { - SetKxHint(rustls::ServerName, rustls::NamedGroup), - GetKxHint(rustls::ServerName, Option), - SetTls12Session(rustls::ServerName), - GetTls12Session(rustls::ServerName, bool), - RemoveTls12Session(rustls::ServerName), - InsertTls13Ticket(rustls::ServerName), - TakeTls13Ticket(rustls::ServerName, bool), + SetKxHint(ServerName<'static>, rustls::NamedGroup), + GetKxHint(ServerName<'static>, Option), + SetTls12Session(ServerName<'static>), + GetTls12Session(ServerName<'static>, bool), + RemoveTls12Session(ServerName<'static>), + InsertTls13Ticket(ServerName<'static>), + TakeTls13Ticket(ServerName<'static>, bool), } struct ClientStorage { @@ -3230,7 +2934,7 @@ impl fmt::Debug for ClientStorage { } impl rustls::client::ClientSessionStore for ClientStorage { - fn set_kx_hint(&self, server_name: &rustls::ServerName, group: rustls::NamedGroup) { + fn set_kx_hint(&self, server_name: ServerName<'static>, group: rustls::NamedGroup) { self.ops .lock() .unwrap() @@ -3239,16 +2943,16 @@ impl rustls::client::ClientSessionStore for ClientStorage { .set_kx_hint(server_name, group) } - fn kx_hint(&self, server_name: &rustls::ServerName) -> Option { + fn kx_hint(&self, server_name: &ServerName) -> Option { let rc = self.storage.kx_hint(server_name); self.ops .lock() .unwrap() - .push(ClientStorageOp::GetKxHint(server_name.clone(), rc)); + .push(ClientStorageOp::GetKxHint(server_name.to_owned(), rc)); rc } - fn set_tls12_session(&self, server_name: &rustls::ServerName, value: rustls::client::Tls12ClientSessionValue) { + fn set_tls12_session(&self, server_name: ServerName<'static>, value: rustls::client::Tls12ClientSessionValue) { self.ops .lock() .unwrap() @@ -3257,16 +2961,16 @@ impl rustls::client::ClientSessionStore for ClientStorage { .set_tls12_session(server_name, value) } - fn tls12_session(&self, server_name: &rustls::ServerName) -> Option { + fn tls12_session(&self, server_name: &ServerName) -> Option { let rc = self.storage.tls12_session(server_name); self.ops .lock() .unwrap() - .push(ClientStorageOp::GetTls12Session(server_name.clone(), rc.is_some())); + .push(ClientStorageOp::GetTls12Session(server_name.to_owned(), rc.is_some())); rc } - fn remove_tls12_session(&self, server_name: &rustls::ServerName) { + fn remove_tls12_session(&self, server_name: &ServerName<'static>) { self.ops .lock() .unwrap() @@ -3275,7 +2979,7 @@ impl rustls::client::ClientSessionStore for ClientStorage { .remove_tls12_session(server_name); } - fn insert_tls13_ticket(&self, server_name: &rustls::ServerName, value: rustls::client::Tls13ClientSessionValue) { + fn insert_tls13_ticket(&self, server_name: ServerName<'static>, value: rustls::client::Tls13ClientSessionValue) { self.ops .lock() .unwrap() @@ -3284,7 +2988,7 @@ impl rustls::client::ClientSessionStore for ClientStorage { .insert_tls13_ticket(server_name, value); } - fn take_tls13_ticket(&self, server_name: &rustls::ServerName) -> Option { + fn take_tls13_ticket(&self, server_name: &ServerName<'static>) -> Option { let rc = self .storage .take_tls13_ticket(server_name); @@ -3639,7 +3343,7 @@ fn test_client_sends_helloretryrequest() { #[test] fn test_client_rejects_hrr_with_varied_session_id() { use rustls::internal::msgs::handshake::SessionId; - let different_session_id = SessionId::random(PROVIDER).unwrap(); + let different_session_id = SessionId::random(&MbedtlsSecureRandom).unwrap(); let assert_client_sends_hello_with_secp384 = |msg: &mut Message| -> Altered { if let MessagePayload::Handshake { parsed, encoded } = &mut msg.payload { @@ -4057,8 +3761,6 @@ fn test_client_tls12_no_resume_after_server_downgrade() { let server_config_1 = Arc::new(common::finish_server_config( KeyType::Rsa, server_config_builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() .with_protocol_versions(&[&rustls::version::TLS13]) .unwrap(), )); @@ -4066,8 +3768,6 @@ fn test_client_tls12_no_resume_after_server_downgrade() { let mut server_config_2 = common::finish_server_config( KeyType::Rsa, server_config_builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() .with_protocol_versions(&[&rustls::version::TLS12]) .unwrap(), ); @@ -4259,152 +3959,12 @@ fn test_no_warning_logging_during_successful_sessions() { } } -/// Test that secrets can be extracted and used for encryption/decryption. -#[cfg(feature = "tls12")] -#[test] -fn test_secret_extraction_enabled() { - use rustls::ConnectionTrafficSecrets; - // Normally, secret extraction would be used to configure kTLS (TLS offload - // to the kernel). We want this test to run on any platform, though, so - // instead we just compare secrets for equality. - - // TLS 1.2 and 1.3 have different mechanisms for key exchange and handshake, - // and secrets are stored/extracted differently, so we want to test them both. - // We support 3 different AEAD algorithms (AES-128-GCM mode, AES-256-GCM, and - // Chacha20Poly1305), so that's 2*3 = 6 combinations to test. - let kt = KeyType::Rsa; - for suite in [ - cipher_suite::TLS13_AES_128_GCM_SHA256, - cipher_suite::TLS13_AES_256_GCM_SHA384, - cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, - cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - ] { - let version = suite.version(); - println!("Testing suite {:?}", suite.suite().as_str()); - - // Only offer the cipher suite (and protocol version) that we're testing - let mut server_config = server_config_builder() - .with_cipher_suites(&[suite]) - .with_safe_default_kx_groups() - .with_protocol_versions(&[version]) - .unwrap() - .with_no_client_auth() - .with_single_cert(kt.get_chain(), kt.get_key()) - .unwrap(); - // Opt into secret extraction from both sides - server_config.enable_secret_extraction = true; - let server_config = Arc::new(server_config); - - let mut client_config = make_client_config(kt); - client_config.enable_secret_extraction = true; - - let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config), &server_config); - - do_handshake(&mut client, &mut server); - - // The handshake is finished, we're now able to extract traffic secrets - let client_secrets = client - .dangerous_extract_secrets() - .unwrap(); - let server_secrets = server - .dangerous_extract_secrets() - .unwrap(); - - // Comparing secrets for equality is something you should never have to - // do in production code, so ConnectionTrafficSecrets doesn't implement - // PartialEq/Eq on purpose. Instead, we have to get creative. - fn explode_secrets(s: &ConnectionTrafficSecrets) -> (&[u8], &[u8]) { - match s { - ConnectionTrafficSecrets::Aes128Gcm { key, iv } => (key.as_ref(), iv.as_ref()), - ConnectionTrafficSecrets::Aes256Gcm { key, iv } => (key.as_ref(), iv.as_ref()), - ConnectionTrafficSecrets::Chacha20Poly1305 { key, iv } => (key.as_ref(), iv.as_ref()), - _ => panic!("unexpected secret type"), - } - } - - fn assert_secrets_equal( - (l_seq, l_sec): (u64, ConnectionTrafficSecrets), - (r_seq, r_sec): (u64, ConnectionTrafficSecrets), - ) { - assert_eq!(l_seq, r_seq); - assert_eq!(explode_secrets(&l_sec), explode_secrets(&r_sec)); - } - - assert_secrets_equal(client_secrets.tx, server_secrets.rx); - assert_secrets_equal(client_secrets.rx, server_secrets.tx); - } -} - -/// Test that secrets cannot be extracted unless explicitly enabled, and until -/// the handshake is done. -#[cfg(feature = "tls12")] -#[test] -fn test_secret_extraction_disabled_or_too_early() { - let suite = cipher_suite::TLS13_AES_128_GCM_SHA256; - let kt = KeyType::Rsa; - - for (server_enable, client_enable) in [(true, false), (false, true)] { - let mut server_config = server_config_builder() - .with_cipher_suites(&[suite]) - .with_safe_default_kx_groups() - .with_safe_default_protocol_versions() - .unwrap() - .with_no_client_auth() - .with_single_cert(kt.get_chain(), kt.get_key()) - .unwrap(); - server_config.enable_secret_extraction = server_enable; - let server_config = Arc::new(server_config); - - let mut client_config = make_client_config(kt); - client_config.enable_secret_extraction = client_enable; - - let client_config = Arc::new(client_config); - - let (client, server) = make_pair_for_arc_configs(&client_config, &server_config); - - assert!( - client - .dangerous_extract_secrets() - .is_err(), - "extraction should fail until handshake completes" - ); - assert!( - server - .dangerous_extract_secrets() - .is_err(), - "extraction should fail until handshake completes" - ); - - let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); - - do_handshake(&mut client, &mut server); - - assert_eq!( - server_enable, - server - .dangerous_extract_secrets() - .is_ok() - ); - assert_eq!( - client_enable, - client - .dangerous_extract_secrets() - .is_ok() - ); - } -} - #[test] fn test_received_plaintext_backpressure() { - let suite = cipher_suite::TLS13_AES_128_GCM_SHA256; let kt = KeyType::Rsa; let server_config = Arc::new( server_config_builder() - .with_cipher_suites(&[suite]) - .with_safe_default_kx_groups() .with_safe_default_protocol_versions() .unwrap() .with_no_client_auth() @@ -4468,125 +4028,21 @@ fn test_received_plaintext_backpressure() { ); } -#[test] -fn test_debug_server_name_from_ip() { - assert_eq!( - format!("{:?}", rustls::ServerName::IpAddress("127.0.0.1".parse().unwrap())), - "IpAddress(127.0.0.1)" - ) -} - -#[test] -fn test_debug_server_name_from_string() { - assert_eq!( - format!("{:?}", rustls::ServerName::try_from("a.com").unwrap()), - "DnsName(\"a.com\")" - ) -} - #[test] fn test_explicit_provider_selection() { let client_config = finish_client_config( KeyType::Rsa, - rustls::ClientConfig::builder_with_provider(rustls::crypto::ring::RING).with_safe_defaults(), + rustls::ClientConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider())) + .with_safe_default_protocol_versions() + .unwrap(), ); let server_config = finish_server_config( KeyType::Rsa, - rustls::ServerConfig::builder_with_provider(rustls_mbedcrypto_provider::MBEDTLS).with_safe_defaults(), + rustls::ServerConfig::builder_with_provider(Arc::new(rustls_mbedcrypto_provider::mbedtls_crypto_provider())) + .with_safe_default_protocol_versions() + .unwrap(), ); let (mut client, mut server) = make_pair_for_configs(client_config, server_config); do_handshake(&mut client, &mut server); } - -#[derive(Debug)] -struct FaultyRandomProvider { - parent: &'static dyn rustls::crypto::CryptoProvider, - - // when empty, `fill_random` requests return `GetRandomFailed` - rand_queue: Mutex<&'static [u8]>, -} - -impl rustls::crypto::CryptoProvider for FaultyRandomProvider { - fn fill_random(&self, output: &mut [u8]) -> Result<(), rustls::crypto::GetRandomFailed> { - let mut queue = self.rand_queue.lock().unwrap(); - - println!("fill_random request for {} bytes (got {})", output.len(), queue.len()); - - if queue.len() < output.len() { - return Err(rustls::crypto::GetRandomFailed); - } - - let fixed_output = &queue[..output.len()]; - output.copy_from_slice(fixed_output); - *queue = &queue[output.len()..]; - Ok(()) - } - - fn default_cipher_suites(&self) -> &'static [SupportedCipherSuite] { - self.parent.default_cipher_suites() - } - - fn default_kx_groups(&self) -> &'static [&'static (dyn rustls::crypto::SupportedKxGroup)] { - self.parent.default_kx_groups() - } - - fn load_private_key(&self, key_der: PrivateKeyDer<'static>) -> Result, Error> { - self.parent.load_private_key(key_der) - } - - fn signature_verification_algorithms(&self) -> rustls::WebPkiSupportedAlgorithms { - self.parent - .signature_verification_algorithms() - } -} - -#[test] -fn test_client_construction_fails_if_random_source_fails_in_first_request() { - static TEST_PROVIDER: FaultyRandomProvider = FaultyRandomProvider { parent: PROVIDER, rand_queue: Mutex::new(b"") }; - - let client_config = finish_client_config( - KeyType::Rsa, - rustls::ClientConfig::builder_with_provider(&TEST_PROVIDER).with_safe_defaults(), - ); - - assert_eq!( - ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap_err(), - Error::FailedToGetRandomBytes - ); -} - -#[test] -fn test_client_construction_fails_if_random_source_fails_in_second_request() { - static TEST_PROVIDER: FaultyRandomProvider = - FaultyRandomProvider { parent: PROVIDER, rand_queue: Mutex::new(b"nice random number generator huh") }; - - let client_config = finish_client_config( - KeyType::Rsa, - rustls::ClientConfig::builder_with_provider(&TEST_PROVIDER).with_safe_defaults(), - ); - - assert_eq!( - ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap_err(), - Error::FailedToGetRandomBytes - ); -} - -#[test] -fn test_client_construction_requires_64_bytes_of_random_material() { - static TEST_PROVIDER: FaultyRandomProvider = FaultyRandomProvider { - parent: PROVIDER, - rand_queue: Mutex::new( - b"nice random number generator !!!\ - it's really not very good is it?", - ), - }; - - let client_config = finish_client_config( - KeyType::Rsa, - rustls::ClientConfig::builder_with_provider(&TEST_PROVIDER).with_safe_defaults(), - ); - - ClientConnection::new(Arc::new(client_config), server_name("localhost")) - .expect("check how much random material ClientConnection::new consumes"); -} diff --git a/rustls-mbedcrypto-provider/tests/common/mod.rs b/rustls-mbedcrypto-provider/tests/common/mod.rs index ef85ae7..be659b8 100644 --- a/rustls-mbedcrypto-provider/tests/common/mod.rs +++ b/rustls-mbedcrypto-provider/tests/common/mod.rs @@ -11,10 +11,11 @@ use std::io; use std::ops::{Deref, DerefMut}; use std::sync::Arc; -use pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer}; -use webpki::extract_trust_anchor; +use pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName}; +use primary_provider::mbedtls_crypto_provider; use rustls::client::{ServerCertVerifierBuilder, WebPkiServerVerifier}; +use rustls::crypto::CryptoProvider; use rustls::internal::msgs::codec::Reader; use rustls::internal::msgs::message::{Message, OpaqueMessage, PlainMessage}; use rustls::server::{ClientCertVerifierBuilder, WebPkiClientVerifier}; @@ -25,7 +26,6 @@ use rustls::{ClientConfig, ClientConnection}; use rustls::{ConnectionCommon, ServerConfig, ServerConnection, SideData}; pub use rustls_mbedcrypto_provider as primary_provider; -pub use rustls_mbedcrypto_provider::MBEDTLS as PROVIDER; macro_rules! embed_files { ( @@ -271,7 +271,7 @@ impl KeyType { } } -pub fn server_config_builder() -> rustls::ConfigBuilder { +pub fn server_config_builder() -> rustls::ConfigBuilder { // ensure `ServerConfig::builder()` is covered, even though it is // equivalent to `builder_with_provider(PROVIDER)`. #[cfg(feature = "ring")] @@ -280,11 +280,11 @@ pub fn server_config_builder() -> rustls::ConfigBuilder rustls::ConfigBuilder { +pub fn client_config_builder() -> rustls::ConfigBuilder { // ensure `ClientConfig::builder()` is covered, even though it is // equivalent to `builder_with_provider(PROVIDER)`. #[cfg(feature = "ring")] @@ -294,7 +294,7 @@ pub fn client_config_builder() -> rustls::ConfigBuilder ServerConfig { - finish_server_config(kt, server_config_builder().with_safe_defaults()) + finish_server_config( + kt, + server_config_builder() + .with_safe_default_protocol_versions() + .unwrap(), + ) } pub fn make_server_config_with_versions(kt: KeyType, versions: &[&'static rustls::SupportedProtocolVersion]) -> ServerConfig { finish_server_config( kt, server_config_builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() .with_protocol_versions(versions) .unwrap(), ) @@ -325,11 +328,11 @@ pub fn make_server_config_with_kx_groups( ) -> ServerConfig { finish_server_config( kt, - server_config_builder() - .with_safe_default_cipher_suites() - .with_kx_groups(kx_groups) - .with_safe_default_protocol_versions() - .unwrap(), + ServerConfig::builder_with_provider( + CryptoProvider { kx_groups: kx_groups.to_vec(), ..mbedtls_crypto_provider() }.into(), + ) + .with_safe_default_protocol_versions() + .unwrap(), ) } @@ -339,7 +342,7 @@ pub fn get_client_root_store(kt: KeyType) -> Arc { let chain = kt.get_chain(); let trust_anchor = chain.last().unwrap(); RootCertStore { - roots: vec![extract_trust_anchor(trust_anchor) + roots: vec![webpki::anchor_from_trusted_cert(trust_anchor) .unwrap() .to_owned()], } @@ -372,7 +375,8 @@ pub fn make_server_config_with_optional_client_auth( pub fn make_server_config_with_client_verifier(kt: KeyType, verifier_builder: ClientCertVerifierBuilder) -> ServerConfig { server_config_builder() - .with_safe_defaults() + .with_safe_default_protocol_versions() + .unwrap() .with_client_cert_verifier(verifier_builder.build().unwrap()) .with_single_cert(kt.get_chain(), kt.get_key()) .unwrap() @@ -404,32 +408,40 @@ pub fn finish_client_config_with_creds( } pub fn make_client_config(kt: KeyType) -> ClientConfig { - finish_client_config(kt, client_config_builder().with_safe_defaults()) + finish_client_config( + kt, + client_config_builder() + .with_safe_default_protocol_versions() + .unwrap(), + ) } pub fn make_client_config_with_kx_groups( kt: KeyType, kx_groups: &[&'static dyn rustls::crypto::SupportedKxGroup], ) -> ClientConfig { - let builder = client_config_builder() - .with_safe_default_cipher_suites() - .with_kx_groups(kx_groups) - .with_safe_default_protocol_versions() - .unwrap(); + let builder = ClientConfig::builder_with_provider( + CryptoProvider { kx_groups: kx_groups.to_vec(), ..mbedtls_crypto_provider() }.into(), + ) + .with_safe_default_protocol_versions() + .unwrap(); finish_client_config(kt, builder) } pub fn make_client_config_with_versions(kt: KeyType, versions: &[&'static rustls::SupportedProtocolVersion]) -> ClientConfig { let builder = client_config_builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() .with_protocol_versions(versions) .unwrap(); finish_client_config(kt, builder) } pub fn make_client_config_with_auth(kt: KeyType) -> ClientConfig { - finish_client_config_with_creds(kt, client_config_builder().with_safe_defaults()) + finish_client_config_with_creds( + kt, + client_config_builder() + .with_safe_default_protocol_versions() + .unwrap(), + ) } pub fn make_client_config_with_versions_with_auth( @@ -437,8 +449,6 @@ pub fn make_client_config_with_versions_with_auth( versions: &[&'static rustls::SupportedProtocolVersion], ) -> ClientConfig { let builder = client_config_builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() .with_protocol_versions(versions) .unwrap(); finish_client_config_with_creds(kt, builder) @@ -449,8 +459,6 @@ pub fn make_client_config_with_verifier( verifier_builder: ServerCertVerifierBuilder, ) -> ClientConfig { client_config_builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() .with_protocol_versions(versions) .unwrap() .dangerous() @@ -459,6 +467,7 @@ pub fn make_client_config_with_verifier( } pub fn webpki_client_verifier_builder(roots: Arc) -> ClientCertVerifierBuilder { + // TODO we don't have a ring feature! #[cfg(feature = "ring")] { WebPkiClientVerifier::builder(roots) @@ -466,7 +475,7 @@ pub fn webpki_client_verifier_builder(roots: Arc) -> ClientCertVe #[cfg(not(feature = "ring"))] { - WebPkiClientVerifier::builder_with_provider(roots, PROVIDER) + WebPkiClientVerifier::builder_with_provider(roots, mbedtls_crypto_provider().into()) } } @@ -478,7 +487,7 @@ pub fn webpki_server_verifier_builder(roots: Arc) -> ServerCertVe #[cfg(not(feature = "ring"))] { - WebPkiServerVerifier::builder_with_provider(roots, PROVIDER) + WebPkiServerVerifier::builder_with_provider(roots, mbedtls_crypto_provider().into()) } } @@ -566,7 +575,7 @@ pub fn do_handshake_until_both_error( } } -pub fn server_name(name: &'static str) -> rustls::ServerName { +pub fn server_name(name: &'static str) -> ServerName { name.try_into().unwrap() } diff --git a/rustls-mbedpki-provider/Cargo.toml b/rustls-mbedpki-provider/Cargo.toml index aa7f6a1..cb37586 100644 --- a/rustls-mbedpki-provider/Cargo.toml +++ b/rustls-mbedpki-provider/Cargo.toml @@ -11,7 +11,7 @@ categories = ["network-programming", "cryptography"] resolver = "2" [dependencies] -rustls = { version = "0.22.0-alpha.4", default_features = false } +rustls = { version = "0.22.0", default_features = false } mbedtls = { version = "0.12.0-alpha.2", features = [ "x509", "chrono", @@ -20,7 +20,7 @@ mbedtls = { version = "0.12.0-alpha.2", features = [ x509-parser = "0.15" chrono = "0.4" -pki-types = { package = "rustls-pki-types", version = "0.2.1", features = [ +pki-types = { package = "rustls-pki-types", version = "1", features = [ "std", ] } utils = { package = "rustls-mbedtls-provider-utils", path = "../rustls-mbedtls-provider-utils", version = "0.1.0-alpha.1" } @@ -35,5 +35,6 @@ mbedtls = { version = "0.12.0-alpha.2", default-features = false, features = [ ] } [dev-dependencies] -rustls-pemfile = "1.0" -rustls = { version = "0.22.0-alpha.4" } +rustls-pemfile = "2" +rustls = { version = "0.22.0" } + diff --git a/rustls-mbedpki-provider/src/client_cert_verifier.rs b/rustls-mbedpki-provider/src/client_cert_verifier.rs index a3f7f2d..aabc13d 100644 --- a/rustls-mbedpki-provider/src/client_cert_verifier.rs +++ b/rustls-mbedpki-provider/src/client_cert_verifier.rs @@ -202,7 +202,6 @@ mod tests { fn server_config_with_verifier(client_cert_verifier: MbedTlsClientCertVerifier) -> ServerConfig { ServerConfig::builder() - .with_safe_defaults() .with_client_cert_verifier(Arc::new(client_cert_verifier)) .with_single_cert( get_chain(include_bytes!("../test-data/rsa/end.fullchain")), @@ -223,7 +222,7 @@ mod tests { #[test] fn connection_client_cert_verifier() { - let client_config = ClientConfig::builder().with_safe_defaults(); + let client_config = ClientConfig::builder(); let root_ca = CertificateDer::from(include_bytes!("../test-data/rsa/ca.der").to_vec()); let mut root_store = RootCertStore::empty(); root_store.add(root_ca.clone()).unwrap(); @@ -245,7 +244,7 @@ mod tests { } fn test_connection_client_cert_verifier_with_invalid_certs(invalid_cert_chain: Vec>) { - let client_config = ClientConfig::builder().with_safe_defaults(); + let client_config = ClientConfig::builder(); let root_ca = CertificateDer::from(include_bytes!("../test-data/rsa/ca.der").to_vec()); let mut root_store = RootCertStore::empty(); root_store.add(root_ca.clone()).unwrap(); diff --git a/rustls-mbedpki-provider/src/server_cert_verifier.rs b/rustls-mbedpki-provider/src/server_cert_verifier.rs index ac4ee17..745bdd1 100644 --- a/rustls-mbedpki-provider/src/server_cert_verifier.rs +++ b/rustls-mbedpki-provider/src/server_cert_verifier.rs @@ -11,11 +11,9 @@ use alloc::sync::Arc; use alloc::vec; use alloc::vec::Vec; use chrono::NaiveDateTime; +use pki_types::ServerName; use pki_types::{CertificateDer, UnixTime}; -use rustls::{ - client::danger::{ServerCertVerified, ServerCertVerifier}, - ServerName, -}; +use rustls::client::danger::{ServerCertVerified, ServerCertVerifier}; use utils::error::mbedtls_err_into_rustls_err_with_error_msg; use crate::{ @@ -99,10 +97,10 @@ impl MbedTlsServerCertVerifier { } } -fn server_name_to_str(server_name: &ServerName) -> String { +fn server_name_to_str(server_name: &ServerName) -> Option { match server_name { - ServerName::DnsName(name) => name.as_ref().to_string(), - ServerName::IpAddress(addr) => addr.to_string(), + ServerName::DnsName(name) => Some(name.as_ref().to_string()), + ServerName::IpAddress(_) => None, // We have this case because rustls::ServerName is marked as non-exhaustive. _ => { panic!("unknown server name: {server_name:?}") @@ -152,7 +150,7 @@ impl ServerCertVerifier for MbedTlsServerCertVerifier { move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut mbedtls::x509::VerifyError| { callback(cert, depth, flags) }, - Some(&server_name_str), + server_name_str.as_deref(), ) .map_err(|e| mbedtls_err_into_rustls_err_with_error_msg(e, &error_msg))?; } @@ -161,7 +159,7 @@ impl ServerCertVerifier for MbedTlsServerCertVerifier { &self.trusted_cas, None, Some(&mut error_msg), - Some(&server_name_str), + server_name_str.as_deref(), ) .map_err(|e| mbedtls_err_into_rustls_err_with_error_msg(e, &error_msg))?, }; @@ -214,10 +212,6 @@ mod tests { fn client_config_with_verifier(server_cert_verifier: V) -> ClientConfig { ClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_safe_default_protocol_versions() - .unwrap() .dangerous() .with_custom_certificate_verifier(Arc::new(server_cert_verifier)) .with_no_client_auth() @@ -242,7 +236,6 @@ mod tests { let client_config = client_config_with_verifier(MbedTlsServerCertVerifier::new(&[root_ca]).unwrap()); let server_config = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() .with_single_cert(invalid_cert_chain, get_key(include_bytes!("../test-data/rsa/end.key"))) .unwrap(); @@ -269,11 +262,7 @@ mod tests { supported_verify_schemes, }; let client_config = client_config_with_verifier(verifier); - let server_config = ServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(protocol_versions) - .unwrap() + let server_config = ServerConfig::builder_with_protocol_versions(protocol_versions) .with_no_client_auth() .with_single_cert(cert_chain, get_key(include_bytes!("../test-data/rsa/end.key"))) .unwrap(); diff --git a/rustls-mbedpki-provider/src/tests_common.rs b/rustls-mbedpki-provider/src/tests_common.rs index 950f934..be36fc8 100644 --- a/rustls-mbedpki-provider/src/tests_common.rs +++ b/rustls-mbedpki-provider/src/tests_common.rs @@ -12,14 +12,13 @@ use core::{ ops::{Deref, DerefMut}, }; -use pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, UnixTime}; +use pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; use rustls::{client::danger::ServerCertVerifier, ClientConnection, ConnectionCommon, ServerConnection, SideData}; /// Get a certificate chain from the contents of a pem file pub(crate) fn get_chain(bytes: &[u8]) -> Vec { rustls_pemfile::certs(&mut io::BufReader::new(bytes)) - .unwrap() - .into_iter() + .map(Result::unwrap) .map(CertificateDer::from) .collect() } @@ -27,11 +26,10 @@ pub(crate) fn get_chain(bytes: &[u8]) -> Vec { /// Get a private key from the contents of a pem file pub(crate) fn get_key(bytes: &[u8]) -> PrivateKeyDer { let value = rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(bytes)) - .unwrap() - .into_iter() .next() + .unwrap() .unwrap(); - PrivateKeyDer::from(PrivatePkcs8KeyDer::from(value)) + PrivateKeyDer::from(value) } // Copied from rustls repo @@ -98,7 +96,7 @@ impl ServerCertVerifier for VerifierWithSupportedVerifySc &self, end_entity: &CertificateDer, intermediates: &[CertificateDer], - server_name: &rustls::ServerName, + server_name: &ServerName, ocsp_response: &[u8], now: UnixTime, ) -> Result { diff --git a/rustls-mbedtls-provider-utils/Cargo.toml b/rustls-mbedtls-provider-utils/Cargo.toml index 5d1033f..03b789f 100644 --- a/rustls-mbedtls-provider-utils/Cargo.toml +++ b/rustls-mbedtls-provider-utils/Cargo.toml @@ -12,11 +12,11 @@ categories = ["network-programming", "cryptography"] resolver = "2" [dependencies] -rustls = { version = "0.22.0-alpha.4", default-features = false } +rustls = { version = "0.22.0", default-features = false } mbedtls = { version = "0.12.0-alpha.2", default-features = false, features = [ "std", ] } -pki-types = { package = "rustls-pki-types", version = "0.2.1", features = [ +pki-types = { package = "rustls-pki-types", version = "1", features = [ "std", ] } diff --git a/rustls-mbedtls-provider-utils/src/error.rs b/rustls-mbedtls-provider-utils/src/error.rs index 06b1861..efa4c1d 100644 --- a/rustls-mbedtls-provider-utils/src/error.rs +++ b/rustls-mbedtls-provider-utils/src/error.rs @@ -1,4 +1,5 @@ use alloc::{format, sync::Arc}; +use rustls::OtherError; /// Converts an `mbedtls::Error` into a `rustls::Error` pub fn mbedtls_err_into_rustls_err(err: mbedtls::Error) -> rustls::Error { @@ -30,7 +31,7 @@ pub fn mbedtls_err_into_rustls_err_with_error_msg(err: mbedtls::Error, msg: &str mbedtls::Error::X509SigMismatch | mbedtls::Error::X509UnknownOid | mbedtls::Error::X509UnknownSigAlg | - mbedtls::Error::X509UnknownVersion => rustls::Error::InvalidCertificate(rustls::CertificateError::Other(Arc::new(err))), + mbedtls::Error::X509UnknownVersion => rustls::Error::InvalidCertificate(rustls::CertificateError::Other(OtherError(Arc::new(err)))), mbedtls::Error::X509InvalidName => rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName), @@ -92,7 +93,9 @@ mod tests { ), format!( "{:?}", - rustls::Error::InvalidCertificate(CertificateError::Other(Arc::new(mbedtls::Error::X509UnknownVersion))) + rustls::Error::InvalidCertificate(CertificateError::Other(OtherError(Arc::new( + mbedtls::Error::X509UnknownVersion + )))) ) ); assert_eq!( @@ -102,7 +105,9 @@ mod tests { ), format!( "{:?}", - rustls::Error::InvalidCertificate(CertificateError::Other(Arc::new(mbedtls::Error::X509InvalidSerial))) + rustls::Error::InvalidCertificate(CertificateError::Other(OtherError(Arc::new( + mbedtls::Error::X509InvalidSerial + )))) ) ); }