From ce80e5912e9fda65a3553132581a9ca7afd93b25 Mon Sep 17 00:00:00 2001 From: Yuxiang Cao Date: Mon, 18 Dec 2023 11:29:08 -0800 Subject: [PATCH 1/5] special fix: return cert validity check result first --- rustls-mbedpki-provider/src/client_cert_verifier.rs | 4 ++-- rustls-mbedpki-provider/src/server_cert_verifier.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rustls-mbedpki-provider/src/client_cert_verifier.rs b/rustls-mbedpki-provider/src/client_cert_verifier.rs index 43870d8..024bad2 100644 --- a/rustls-mbedpki-provider/src/client_cert_verifier.rs +++ b/rustls-mbedpki-provider/src/client_cert_verifier.rs @@ -136,6 +136,8 @@ impl ClientCertVerifier for MbedTlsClientCertVerifier { .into_iter() .collect(); + verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?; + let self_verify_callback = self.verify_callback.clone(); let callback = move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut mbedtls::x509::VerifyError| { // When the "time" feature is enabled for mbedtls, it checks cert expiration. We undo that here, @@ -152,8 +154,6 @@ impl ClientCertVerifier for MbedTlsClientCertVerifier { mbedtls::x509::Certificate::verify_with_callback(&chain, &self.trusted_cas, None, Some(&mut error_msg), callback) .map_err(|e| mbedtls_err_into_rustls_err_with_error_msg(e, &error_msg))?; - verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?; - Ok(ClientCertVerified::assertion()) } diff --git a/rustls-mbedpki-provider/src/server_cert_verifier.rs b/rustls-mbedpki-provider/src/server_cert_verifier.rs index 3ab4dd5..511cdbf 100644 --- a/rustls-mbedpki-provider/src/server_cert_verifier.rs +++ b/rustls-mbedpki-provider/src/server_cert_verifier.rs @@ -137,6 +137,8 @@ impl ServerCertVerifier for MbedTlsServerCertVerifier { let server_name_str = server_name_to_str(server_name); + verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?; + let self_verify_callback = self.verify_callback.clone(); let callback = move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut mbedtls::x509::VerifyError| { // When the "time" feature is enabled for mbedtls, it checks cert expiration. We undo that here, @@ -160,8 +162,6 @@ impl ServerCertVerifier for MbedTlsServerCertVerifier { ) .map_err(|e| mbedtls_err_into_rustls_err_with_error_msg(e, &error_msg))?; - verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?; - Ok(ServerCertVerified::assertion()) } From 864c4f19c6fa97496faf06d781a8d86733b971a6 Mon Sep 17 00:00:00 2001 From: Yuxiang Cao Date: Mon, 18 Dec 2023 17:34:42 -0800 Subject: [PATCH 2/5] return verify_error in a better way --- .../src/client_cert_verifier.rs | 92 ++++++++++++------- rustls-mbedpki-provider/src/lib.rs | 38 +++++++- .../src/server_cert_verifier.rs | 76 ++++++++++----- 3 files changed, 145 insertions(+), 61 deletions(-) diff --git a/rustls-mbedpki-provider/src/client_cert_verifier.rs b/rustls-mbedpki-provider/src/client_cert_verifier.rs index 024bad2..a5ab2a9 100644 --- a/rustls-mbedpki-provider/src/client_cert_verifier.rs +++ b/rustls-mbedpki-provider/src/client_cert_verifier.rs @@ -10,15 +10,16 @@ use alloc::sync::Arc; use alloc::vec; use alloc::vec::Vec; use chrono::NaiveDateTime; +use mbedtls::x509::VerifyError; use rustls::pki_types::{CertificateDer, UnixTime}; use rustls::{ server::danger::{ClientCertVerified, ClientCertVerifier}, DistinguishedName, }; -use utils::error::mbedtls_err_into_rustls_err_with_error_msg; use crate::{ - mbedtls_err_into_rustls_err, rustls_cert_to_mbedtls_cert, verify_certificates_active, verify_tls_signature, CertActiveCheck, + mbedtls_err_into_rustls_err, merge_verify_result, rustls_cert_to_mbedtls_cert, verify_certificates_active, + verify_tls_signature, CertActiveCheck, VerifyErrorWrapper, }; /// A [`rustls`] [`ClientCertVerifier`] implemented using the PKI functionality of @@ -29,6 +30,7 @@ pub struct MbedTlsClientCertVerifier { root_subjects: Vec, verify_callback: Option>, cert_active_check: CertActiveCheck, + mbedtls_verify_error_mapping: fn(VerifyError) -> rustls::Error, } impl core::fmt::Debug for MbedTlsClientCertVerifier { @@ -71,9 +73,17 @@ impl MbedTlsClientCertVerifier { root_subjects, verify_callback: None, cert_active_check: CertActiveCheck::default(), + mbedtls_verify_error_mapping: Self::default_mbedtls_verify_error_mapping, }) } + /// The default mapping of [`VerifyError`] to [`rustls::Error`]. + pub fn default_mbedtls_verify_error_mapping(verify_err: VerifyError) -> rustls::Error { + rustls::Error::InvalidCertificate(rustls::CertificateError::Other(rustls::OtherError(Arc::new( + VerifyErrorWrapper(verify_err), + )))) + } + /// The certificate authority certificates used to construct this object pub fn trusted_cas(&self) -> &mbedtls::alloc::List { &self.trusted_cas @@ -136,13 +146,11 @@ impl ClientCertVerifier for MbedTlsClientCertVerifier { .into_iter() .collect(); - verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?; - let self_verify_callback = self.verify_callback.clone(); - let callback = move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut mbedtls::x509::VerifyError| { + let callback = move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut VerifyError| { // When the "time" feature is enabled for mbedtls, it checks cert expiration. We undo that here, // since this check is done in `verify_certificates_active()` (subject to self.cert_active_check) - flags.remove(mbedtls::x509::VerifyError::CERT_EXPIRED | mbedtls::x509::VerifyError::CERT_FUTURE); + flags.remove(VerifyError::CERT_EXPIRED | VerifyError::CERT_FUTURE); if let Some(cb) = self_verify_callback.as_ref() { cb(cert, depth, flags) } else { @@ -151,8 +159,18 @@ impl ClientCertVerifier for MbedTlsClientCertVerifier { }; let mut error_msg = String::default(); - mbedtls::x509::Certificate::verify_with_callback(&chain, &self.trusted_cas, None, Some(&mut error_msg), callback) - .map_err(|e| mbedtls_err_into_rustls_err_with_error_msg(e, &error_msg))?; + let cert_verify_res = mbedtls::x509::Certificate::verify_with_callback_return_verify_err( + &chain, + &self.trusted_cas, + None, + Some(&mut error_msg), + callback, + ) + .map_err(|e| e.1); + + let validity_verify_res = verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?; + + merge_verify_result(&validity_verify_res, &cert_verify_res).map_err(self.mbedtls_verify_error_mapping)?; Ok(ClientCertVerified::assertion()) } @@ -188,10 +206,10 @@ impl ClientCertVerifier for MbedTlsClientCertVerifier { mod tests { use chrono::DateTime; + use mbedtls::x509::VerifyError; use rustls::pki_types::{CertificateDer, UnixTime}; use rustls::{ - server::danger::ClientCertVerifier, CertificateError, ClientConfig, ClientConnection, RootCertStore, ServerConfig, - ServerConnection, + server::danger::ClientCertVerifier, ClientConfig, ClientConnection, RootCertStore, ServerConfig, ServerConnection, }; use std::{sync::Arc, time::SystemTime}; @@ -329,13 +347,16 @@ mod tests { now.duration_since(SystemTime::UNIX_EPOCH) .unwrap(), ); - - assert_eq!( - verifier - .verify_client_cert(&cert_chain[0], &cert_chain[1..], now) - .unwrap_err(), - rustls::Error::InvalidCertificate(CertificateError::Expired) - ); + let verify_res = verifier.verify_client_cert(&cert_chain[0], &cert_chain[1..], now); + if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res { + let verify_err = other_err + .0 + .downcast_ref::() + .unwrap(); + assert_eq!(verify_err.0, VerifyError::CERT_EXPIRED); + } else { + panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`") + } } #[test] @@ -353,12 +374,16 @@ mod tests { .unwrap(), ); - assert_eq!( - verifier - .verify_client_cert(&cert_chain[0], &cert_chain[1..], now) - .unwrap_err(), - rustls::Error::InvalidCertificate(CertificateError::Expired) - ); + let verify_res = verifier.verify_client_cert(&cert_chain[0], &cert_chain[1..], now); + if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res { + let verify_err = other_err + .0 + .downcast_ref::() + .unwrap(); + assert_eq!(verify_err.0, VerifyError::CERT_EXPIRED); + } else { + panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`") + } // Test that we accept expired certs when `ignore_expired` is true verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: true, ignore_not_active_yet: false }); @@ -373,13 +398,16 @@ mod tests { now.duration_since(SystemTime::UNIX_EPOCH) .unwrap(), ); - assert_eq!( - verifier - .verify_client_cert(&cert_chain[0], &cert_chain[1..], now) - .unwrap_err(), - rustls::Error::InvalidCertificate(CertificateError::NotValidYet) - ); - + let verify_res = verifier.verify_client_cert(&cert_chain[0], &cert_chain[1..], now); + if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res { + let verify_err = other_err + .0 + .downcast_ref::() + .unwrap(); + assert_eq!(verify_err.0, VerifyError::CERT_FUTURE); + } else { + panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`") + } // Test that we accept certs that are not valid yet when `ignore_not_active_yet` is true verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: false, ignore_not_active_yet: true }); @@ -406,8 +434,8 @@ mod tests { assert!(matches!(verify_res, Err(rustls::Error::InvalidCertificate(_)))); verifier.set_verify_callback(Some(Arc::new( - move |_cert: &mbedtls::x509::Certificate, _depth: i32, flags: &mut mbedtls::x509::VerifyError| { - flags.remove(mbedtls::x509::VerifyError::CERT_NOT_TRUSTED); + move |_cert: &mbedtls::x509::Certificate, _depth: i32, flags: &mut VerifyError| { + flags.remove(VerifyError::CERT_NOT_TRUSTED); Ok(()) }, ))); diff --git a/rustls-mbedpki-provider/src/lib.rs b/rustls-mbedpki-provider/src/lib.rs index 16ef59b..8343dae 100644 --- a/rustls-mbedpki-provider/src/lib.rs +++ b/rustls-mbedpki-provider/src/lib.rs @@ -43,7 +43,10 @@ extern crate alloc; #[cfg(not(test))] extern crate std; +use core::fmt::Display; + use chrono::NaiveDateTime; +use mbedtls::x509::VerifyError; use rustls::pki_types::CertificateDer; use rustls::SignatureScheme; @@ -91,9 +94,9 @@ fn verify_certificates_active<'a>( chain: impl IntoIterator, now: NaiveDateTime, active_check: &CertActiveCheck, -) -> Result<(), rustls::Error> { +) -> Result, rustls::Error> { if active_check.ignore_expired && active_check.ignore_not_active_yet { - return Ok(()); + return Ok(Ok(())); } fn time_err_to_err(_time_err: mbedtls::x509::InvalidTimeError) -> rustls::Error { @@ -108,7 +111,7 @@ fn verify_certificates_active<'a>( .try_into() .map_err(time_err_to_err)?; if now > not_after { - return Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Expired)); + return Ok(Err(VerifyError::CERT_EXPIRED)); } } if !active_check.ignore_not_active_yet { @@ -118,11 +121,11 @@ fn verify_certificates_active<'a>( .try_into() .map_err(time_err_to_err)?; if now < not_before { - return Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidYet)); + return Ok(Err(VerifyError::CERT_FUTURE)); } } } - Ok(()) + Ok(Ok(())) } /// Verifies the tls signature, matches verify functions in rustls `ClientCertVerifier` and @@ -173,3 +176,28 @@ pub fn rustls_cert_to_mbedtls_cert(cert: &CertificateDer) -> mbedtls::Result, + second: &Result<(), VerifyError>, +) -> Result<(), VerifyError> { + match (first, second) { + (Ok(()), Ok(())) => Ok(()), + (Ok(()), Err(second_err)) => Err(*second_err), + (Err(first_err), Ok(())) => Err(*first_err), + (Err(first_err), Err(second_err)) => Err(*first_err | *second_err), + } +} + +/// A wrapper on [`mbedtls::x509::VerifyError`] to impl [`std::error::Error`] for it. +#[derive(Debug)] +pub struct VerifyErrorWrapper(pub VerifyError); + +impl Display for VerifyErrorWrapper { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + use core::fmt::Debug; + self.0.fmt(f) + } +} + +impl std::error::Error for VerifyErrorWrapper {} diff --git a/rustls-mbedpki-provider/src/server_cert_verifier.rs b/rustls-mbedpki-provider/src/server_cert_verifier.rs index 511cdbf..4da0a25 100644 --- a/rustls-mbedpki-provider/src/server_cert_verifier.rs +++ b/rustls-mbedpki-provider/src/server_cert_verifier.rs @@ -11,11 +11,13 @@ use alloc::sync::Arc; use alloc::vec; use alloc::vec::Vec; use chrono::NaiveDateTime; +use mbedtls::x509::VerifyError; use rustls::client::danger::{ServerCertVerified, ServerCertVerifier}; use rustls::pki_types::ServerName; use rustls::pki_types::{CertificateDer, UnixTime}; -use utils::error::mbedtls_err_into_rustls_err_with_error_msg; +use crate::merge_verify_result; +use crate::VerifyErrorWrapper; use crate::{ mbedtls_err_into_rustls_err, rustls_cert_to_mbedtls_cert, verify_certificates_active, verify_tls_signature, CertActiveCheck, }; @@ -26,6 +28,7 @@ pub struct MbedTlsServerCertVerifier { trusted_cas: mbedtls::alloc::List, verify_callback: Option>, cert_active_check: CertActiveCheck, + mbedtls_verify_error_mapping: fn(VerifyError) -> rustls::Error, } impl core::fmt::Debug for MbedTlsServerCertVerifier { @@ -66,9 +69,17 @@ impl MbedTlsServerCertVerifier { trusted_cas, verify_callback: None, cert_active_check: CertActiveCheck::default(), + mbedtls_verify_error_mapping: Self::default_mbedtls_verify_error_mapping, }) } + /// The default mapping of [`VerifyError`] to [`rustls::Error`]. + pub fn default_mbedtls_verify_error_mapping(verify_err: VerifyError) -> rustls::Error { + rustls::Error::InvalidCertificate(rustls::CertificateError::Other(rustls::OtherError(Arc::new( + VerifyErrorWrapper(verify_err), + )))) + } + /// The certificate authority certificates used to construct this object pub fn trusted_cas(&self) -> &mbedtls::alloc::List { &self.trusted_cas @@ -137,30 +148,30 @@ impl ServerCertVerifier for MbedTlsServerCertVerifier { let server_name_str = server_name_to_str(server_name); - verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?; - let self_verify_callback = self.verify_callback.clone(); - let callback = move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut mbedtls::x509::VerifyError| { + let callback = move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut VerifyError| { // When the "time" feature is enabled for mbedtls, it checks cert expiration. We undo that here, // since this check is done in `verify_certificates_active()` (subject to self.cert_active_check) - flags.remove(mbedtls::x509::VerifyError::CERT_EXPIRED | mbedtls::x509::VerifyError::CERT_FUTURE); + flags.remove(VerifyError::CERT_EXPIRED | VerifyError::CERT_FUTURE); if let Some(cb) = self_verify_callback.as_ref() { cb(cert, depth, flags) } else { Ok(()) } }; - - let mut error_msg = String::default(); - mbedtls::x509::Certificate::verify_with_callback_expected_common_name( + let cert_verify_res = mbedtls::x509::Certificate::verify_with_callback_expected_common_name_return_verify_err( &chain, &self.trusted_cas, None, - Some(&mut error_msg), + None, callback, server_name_str.as_deref(), ) - .map_err(|e| mbedtls_err_into_rustls_err_with_error_msg(e, &error_msg))?; + .map_err(|e| e.1); + + let validity_verify_res = verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?; + + merge_verify_result(&validity_verify_res, &cert_verify_res).map_err(self.mbedtls_verify_error_mapping)?; Ok(ServerCertVerified::assertion()) } @@ -196,6 +207,7 @@ impl ServerCertVerifier for MbedTlsServerCertVerifier { mod tests { use std::{sync::Arc, time::SystemTime}; + use mbedtls::x509::VerifyError; use rustls::pki_types::{CertificateDer, UnixTime}; use rustls::{ client::danger::ServerCertVerifier, @@ -340,10 +352,15 @@ mod tests { .unwrap(), ); let verify_res = verifier.verify_server_cert(&cert_chain[0], &cert_chain[1..], &server_name, &[], now); - assert_eq!( - verify_res.unwrap_err(), - rustls::Error::InvalidCertificate(rustls::CertificateError::Expired) - ); + if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res { + let verify_err = other_err + .0 + .downcast_ref::() + .unwrap(); + assert_eq!(verify_err.0, VerifyError::CERT_EXPIRED); + } else { + panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`") + } } #[test] @@ -362,10 +379,15 @@ mod tests { // Test that we reject expired certs let verify_res = verifier.verify_server_cert(&cert_chain[0], &cert_chain[1..], &server_name, &[], now); - assert_eq!( - verify_res.unwrap_err(), - rustls::Error::InvalidCertificate(rustls::CertificateError::Expired) - ); + if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res { + let verify_err = other_err + .0 + .downcast_ref::() + .unwrap(); + assert_eq!(verify_err.0, VerifyError::CERT_EXPIRED); + } else { + panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`") + } // Test that we accept expired certs when `ignore_expired` is true verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: true, ignore_not_active_yet: false }); @@ -380,10 +402,16 @@ mod tests { // Test that we reject certs that are not valid yet let verify_res = verifier.verify_server_cert(&cert_chain[0], &cert_chain[1..], &server_name, &[], now); - assert_eq!( - verify_res.unwrap_err(), - rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidYet) - ); + + if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res { + let verify_err = other_err + .0 + .downcast_ref::() + .unwrap(); + assert_eq!(verify_err.0, VerifyError::CERT_FUTURE); + } else { + panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`") + } // Test that we accept certs that are not valid yet when `ignore_not_active_yet` is true verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: false, ignore_not_active_yet: true }); @@ -427,8 +455,8 @@ mod tests { assert!(matches!(verify_res, Err(rustls::Error::InvalidCertificate(_)))); verifier.set_verify_callback(Some(Arc::new( - move |_cert: &mbedtls::x509::Certificate, _depth: i32, flags: &mut mbedtls::x509::VerifyError| { - flags.remove(mbedtls::x509::VerifyError::CERT_CN_MISMATCH); + move |_cert: &mbedtls::x509::Certificate, _depth: i32, flags: &mut VerifyError| { + flags.remove(VerifyError::CERT_CN_MISMATCH); Ok(()) }, ))); From c6182c8dbd8ee85cef4dac3e4d7a585f4d4dae67 Mon Sep 17 00:00:00 2001 From: Yuxiang Cao Date: Mon, 18 Dec 2023 17:38:25 -0800 Subject: [PATCH 3/5] add set function --- rustls-mbedpki-provider/src/client_cert_verifier.rs | 5 +++++ rustls-mbedpki-provider/src/server_cert_verifier.rs | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/rustls-mbedpki-provider/src/client_cert_verifier.rs b/rustls-mbedpki-provider/src/client_cert_verifier.rs index a5ab2a9..42fc42b 100644 --- a/rustls-mbedpki-provider/src/client_cert_verifier.rs +++ b/rustls-mbedpki-provider/src/client_cert_verifier.rs @@ -84,6 +84,11 @@ impl MbedTlsClientCertVerifier { )))) } + /// Set the mapping of [`VerifyError`] to [`rustls::Error`]. + pub fn set_mbedtls_verify_error_mapping(&mut self, mapping: fn(VerifyError) -> rustls::Error) { + self.mbedtls_verify_error_mapping = mapping; + } + /// The certificate authority certificates used to construct this object pub fn trusted_cas(&self) -> &mbedtls::alloc::List { &self.trusted_cas diff --git a/rustls-mbedpki-provider/src/server_cert_verifier.rs b/rustls-mbedpki-provider/src/server_cert_verifier.rs index 4da0a25..8285a6d 100644 --- a/rustls-mbedpki-provider/src/server_cert_verifier.rs +++ b/rustls-mbedpki-provider/src/server_cert_verifier.rs @@ -80,6 +80,11 @@ impl MbedTlsServerCertVerifier { )))) } + /// Set the mapping of [`VerifyError`] to [`rustls::Error`]. + pub fn set_mbedtls_verify_error_mapping(&mut self, mapping: fn(VerifyError) -> rustls::Error) { + self.mbedtls_verify_error_mapping = mapping; + } + /// The certificate authority certificates used to construct this object pub fn trusted_cas(&self) -> &mbedtls::alloc::List { &self.trusted_cas From f875ccd0dd97b33e079504ed478c76eafa7ea2ef Mon Sep 17 00:00:00 2001 From: Yuxiang Cao Date: Tue, 19 Dec 2023 11:49:36 -0800 Subject: [PATCH 4/5] add some unit tests --- .../src/client_cert_verifier.rs | 28 +++++++++++++- .../src/server_cert_verifier.rs | 37 ++++++++++++++++++- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/rustls-mbedpki-provider/src/client_cert_verifier.rs b/rustls-mbedpki-provider/src/client_cert_verifier.rs index 42fc42b..6f512c7 100644 --- a/rustls-mbedpki-provider/src/client_cert_verifier.rs +++ b/rustls-mbedpki-provider/src/client_cert_verifier.rs @@ -89,6 +89,11 @@ impl MbedTlsClientCertVerifier { self.mbedtls_verify_error_mapping = mapping; } + /// Get the current mapping of [`VerifyError`] to [`rustls::Error`]. + pub fn mbedtls_verify_error_mapping(&self) -> fn(VerifyError) -> rustls::Error { + self.mbedtls_verify_error_mapping + } + /// The certificate authority certificates used to construct this object pub fn trusted_cas(&self) -> &mbedtls::alloc::List { &self.trusted_cas @@ -241,6 +246,24 @@ mod tests { ); } + #[test] + fn client_cert_verifier_setter_getter() { + let root_ca = CertificateDer::from(include_bytes!("../test-data/rsa/ca.der").to_vec()); + let mut client_cert_verifier = MbedTlsClientCertVerifier::new([&root_ca]).unwrap(); + assert!(!client_cert_verifier + .trusted_cas() + .is_empty()); + const RETURN_ERR: rustls::Error = rustls::Error::BadMaxFragmentSize; + fn test_mbedtls_verify_error_mapping(_verify_err: VerifyError) -> rustls::Error { + RETURN_ERR + } + client_cert_verifier.set_mbedtls_verify_error_mapping(test_mbedtls_verify_error_mapping); + assert_eq!( + client_cert_verifier.mbedtls_verify_error_mapping()(VerifyError::empty()), + RETURN_ERR + ); + } + #[test] fn connection_client_cert_verifier() { let client_config = ClientConfig::builder(); @@ -415,7 +438,10 @@ mod tests { } // Test that we accept certs that are not valid yet when `ignore_not_active_yet` is true verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: false, ignore_not_active_yet: true }); - + assert!(verifier + .verify_client_cert(&cert_chain[0], &cert_chain[1..], now) + .is_ok()); + verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: true, ignore_not_active_yet: true }); assert!(verifier .verify_client_cert(&cert_chain[0], &cert_chain[1..], now) .is_ok()); diff --git a/rustls-mbedpki-provider/src/server_cert_verifier.rs b/rustls-mbedpki-provider/src/server_cert_verifier.rs index 8285a6d..ee086f9 100644 --- a/rustls-mbedpki-provider/src/server_cert_verifier.rs +++ b/rustls-mbedpki-provider/src/server_cert_verifier.rs @@ -85,6 +85,11 @@ impl MbedTlsServerCertVerifier { self.mbedtls_verify_error_mapping = mapping; } + /// Get the current mapping of [`VerifyError`] to [`rustls::Error`]. + pub fn mbedtls_verify_error_mapping(&self) -> fn(VerifyError) -> rustls::Error { + self.mbedtls_verify_error_mapping + } + /// The certificate authority certificates used to construct this object pub fn trusted_cas(&self) -> &mbedtls::alloc::List { &self.trusted_cas @@ -213,7 +218,7 @@ mod tests { use std::{sync::Arc, time::SystemTime}; use mbedtls::x509::VerifyError; - use rustls::pki_types::{CertificateDer, UnixTime}; + use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use rustls::{ client::danger::ServerCertVerifier, version::{TLS12, TLS13}, @@ -221,6 +226,7 @@ mod tests { SupportedProtocolVersion, }; + use crate::server_cert_verifier::server_name_to_str; use crate::tests_common::{do_handshake_until_error, get_chain, get_key, VerifierWithSupportedVerifySchemes}; use super::MbedTlsServerCertVerifier; @@ -242,6 +248,24 @@ mod tests { ); } + #[test] + fn server_cert_verifier_setter_getter() { + let root_ca = CertificateDer::from(include_bytes!("../test-data/rsa/ca.der").to_vec()); + let mut server_cert_verifier = MbedTlsServerCertVerifier::new([&root_ca]).unwrap(); + assert!(!server_cert_verifier + .trusted_cas() + .is_empty()); + const RETURN_ERR: rustls::Error = rustls::Error::BadMaxFragmentSize; + fn test_mbedtls_verify_error_mapping(_verify_err: VerifyError) -> rustls::Error { + RETURN_ERR + } + server_cert_verifier.set_mbedtls_verify_error_mapping(test_mbedtls_verify_error_mapping); + assert_eq!( + server_cert_verifier.mbedtls_verify_error_mapping()(VerifyError::empty()), + RETURN_ERR + ); + } + fn test_connection_server_cert_verifier_with_invalid_certs( invalid_cert_chain: Vec>, ) -> rustls::Error { @@ -422,6 +446,9 @@ mod tests { verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: false, ignore_not_active_yet: true }); let verify_res = verifier.verify_server_cert(&cert_chain[0], &cert_chain[1..], &server_name, &[], now); assert!(verify_res.is_ok()); + verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: true, ignore_not_active_yet: true }); + let verify_res = verifier.verify_server_cert(&cert_chain[0], &cert_chain[1..], &server_name, &[], now); + assert!(verify_res.is_ok()); } #[test] @@ -503,4 +530,12 @@ mod tests { test_server_cert_verifier_invalid_chain(&broken_chain); } } + + #[test] + fn test_server_name_to_str() { + let server_name = ServerName::DnsName("example.com".try_into().unwrap()); + assert_eq!(server_name_to_str(&server_name), Some("example.com".to_string())); + let server_name = ServerName::IpAddress("127.0.0.1".try_into().unwrap()); + assert_eq!(server_name_to_str(&server_name), None); + } } From 5c50a6ecb72dda2504d6c1765fec2421e58737fc Mon Sep 17 00:00:00 2001 From: Yuxiang Cao Date: Wed, 20 Dec 2023 08:50:00 -0800 Subject: [PATCH 5/5] add some more unit tests --- rustls-mbedpki-provider/src/lib.rs | 68 +++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 16 deletions(-) diff --git a/rustls-mbedpki-provider/src/lib.rs b/rustls-mbedpki-provider/src/lib.rs index 8343dae..7e4ac38 100644 --- a/rustls-mbedpki-provider/src/lib.rs +++ b/rustls-mbedpki-provider/src/lib.rs @@ -88,6 +88,11 @@ pub const SUPPORTED_SIGNATURE_SCHEMA: [SignatureScheme; 9] = [ rustls::SignatureScheme::RSA_PSS_SHA256, ]; +/// Helper function to convert a [`mbedtls::x509::InvalidTimeError`] to a [`rustls::Error`] +fn time_err_to_err(_time_err: mbedtls::x509::InvalidTimeError) -> rustls::Error { + rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding) +} + /// Verifies that certificates are active, i.e., `now` is between not_before and not_after for /// each certificate fn verify_certificates_active<'a>( @@ -99,10 +104,6 @@ fn verify_certificates_active<'a>( return Ok(Ok(())); } - fn time_err_to_err(_time_err: mbedtls::x509::InvalidTimeError) -> rustls::Error { - rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding) - } - for cert in chain.into_iter() { if !active_check.ignore_expired { let not_after = cert @@ -143,18 +144,10 @@ fn verify_tls_signature( // for tls 1.3, we need to verify the advertised curve in signature scheme matches the public key if is_tls13 { let signature_curve = utils::pk::rustls_signature_scheme_to_mbedtls_curve_id(dss.scheme); - match signature_curve { - mbedtls::pk::EcGroupId::None => (), - _ => { - let curves_match = pk - .curve() - .is_ok_and(|pk_curve| pk_curve == signature_curve); - if !curves_match { - return Err(rustls::Error::PeerMisbehaved( - rustls::PeerMisbehaved::SignedHandshakeWithUnadvertisedSigScheme, - )); - } - } + if !check_ec_signature_curve_match(signature_curve, pk) { + return Err(rustls::Error::PeerMisbehaved( + rustls::PeerMisbehaved::SignedHandshakeWithUnadvertisedSigScheme, + )); } } @@ -171,6 +164,15 @@ fn verify_tls_signature( Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } +fn check_ec_signature_curve_match(signature_curve: mbedtls::pk::EcGroupId, pk: &mbedtls::pk::Pk) -> bool { + match signature_curve { + mbedtls::pk::EcGroupId::None => true, + _ => pk + .curve() + .is_ok_and(|pk_curve| pk_curve == signature_curve), + } +} + /// Helper function to convert a [`CertificateDer`] to [`mbedtls::x509::Certificate`] pub fn rustls_cert_to_mbedtls_cert(cert: &CertificateDer) -> mbedtls::Result> { let cert = mbedtls::x509::Certificate::from_der(cert)?; @@ -201,3 +203,37 @@ impl Display for VerifyErrorWrapper { } impl std::error::Error for VerifyErrorWrapper {} + +#[cfg(test)] +mod tests { + use super::*; + use mbedtls::pk::EcGroupId; + use mbedtls::x509::InvalidTimeError; + use rustls::{CertificateError, Error}; + + #[test] + fn test_time_err_to_err() { + // Create a sample InvalidTimeError + let time_err = InvalidTimeError; + + // Call the function and check the result + let result = time_err_to_err(time_err); + assert_eq!(result, Error::InvalidCertificate(CertificateError::BadEncoding)); + } + + #[test] + fn verify_error_wrapper_display() { + let verify_error = VerifyError::CERT_EXPIRED; // Replace with actual instantiation of VerifyError + let wrapper = VerifyErrorWrapper(verify_error); + assert_eq!(format!("{}", wrapper), format!("{:?}", verify_error)); + } + + #[test] + fn test_check_ec_signature_curve_match() { + let cert = CertificateDer::from(include_bytes!("../test-data/rsa/ca.der").to_vec()); + let cert = rustls_cert_to_mbedtls_cert(&cert).unwrap(); + let pk = cert.public_key(); + assert!(check_ec_signature_curve_match(EcGroupId::None, pk)); + assert!(!check_ec_signature_curve_match(EcGroupId::SecP256R1, pk)); + } +}