Skip to content

Commit

Permalink
Return VerifyError in a better way (#25)
Browse files Browse the repository at this point in the history
* special fix: return cert validity check result first

* return verify_error in a better way

* add set function

* add some unit tests

* add some more unit tests
  • Loading branch information
Taowyoo authored Dec 20, 2023
1 parent 98a438e commit 03cf570
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 77 deletions.
123 changes: 91 additions & 32 deletions rustls-mbedpki-provider/src/client_cert_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use chrono::NaiveDateTime;
use mbedtls::x509::VerifyError;
use rustls::pki_types::{CertificateDer, UnixTime};
use rustls::{
server::danger::{ClientCertVerified, ClientCertVerifier},
DistinguishedName,
};
use utils::error::mbedtls_err_into_rustls_err_with_error_msg;

use crate::{
mbedtls_err_into_rustls_err, rustls_cert_to_mbedtls_cert, verify_certificates_active, verify_tls_signature, CertActiveCheck,
mbedtls_err_into_rustls_err, merge_verify_result, rustls_cert_to_mbedtls_cert, verify_certificates_active,
verify_tls_signature, CertActiveCheck, VerifyErrorWrapper,
};

/// A [`rustls`] [`ClientCertVerifier`] implemented using the PKI functionality of
Expand All @@ -29,6 +30,7 @@ pub struct MbedTlsClientCertVerifier {
root_subjects: Vec<DistinguishedName>,
verify_callback: Option<Arc<dyn mbedtls::x509::VerifyCallback + 'static>>,
cert_active_check: CertActiveCheck,
mbedtls_verify_error_mapping: fn(VerifyError) -> rustls::Error,
}

impl core::fmt::Debug for MbedTlsClientCertVerifier {
Expand Down Expand Up @@ -71,9 +73,27 @@ impl MbedTlsClientCertVerifier {
root_subjects,
verify_callback: None,
cert_active_check: CertActiveCheck::default(),
mbedtls_verify_error_mapping: Self::default_mbedtls_verify_error_mapping,
})
}

/// The default mapping of [`VerifyError`] to [`rustls::Error`].
pub fn default_mbedtls_verify_error_mapping(verify_err: VerifyError) -> rustls::Error {
rustls::Error::InvalidCertificate(rustls::CertificateError::Other(rustls::OtherError(Arc::new(
VerifyErrorWrapper(verify_err),
))))
}

/// Set the mapping of [`VerifyError`] to [`rustls::Error`].
pub fn set_mbedtls_verify_error_mapping(&mut self, mapping: fn(VerifyError) -> rustls::Error) {
self.mbedtls_verify_error_mapping = mapping;
}

/// Get the current mapping of [`VerifyError`] to [`rustls::Error`].
pub fn mbedtls_verify_error_mapping(&self) -> fn(VerifyError) -> rustls::Error {
self.mbedtls_verify_error_mapping
}

/// The certificate authority certificates used to construct this object
pub fn trusted_cas(&self) -> &mbedtls::alloc::List<mbedtls::x509::Certificate> {
&self.trusted_cas
Expand Down Expand Up @@ -137,10 +157,10 @@ impl ClientCertVerifier for MbedTlsClientCertVerifier {
.collect();

let self_verify_callback = self.verify_callback.clone();
let callback = move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut mbedtls::x509::VerifyError| {
let callback = move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut VerifyError| {
// When the "time" feature is enabled for mbedtls, it checks cert expiration. We undo that here,
// since this check is done in `verify_certificates_active()` (subject to self.cert_active_check)
flags.remove(mbedtls::x509::VerifyError::CERT_EXPIRED | mbedtls::x509::VerifyError::CERT_FUTURE);
flags.remove(VerifyError::CERT_EXPIRED | VerifyError::CERT_FUTURE);
if let Some(cb) = self_verify_callback.as_ref() {
cb(cert, depth, flags)
} else {
Expand All @@ -149,10 +169,18 @@ impl ClientCertVerifier for MbedTlsClientCertVerifier {
};

let mut error_msg = String::default();
mbedtls::x509::Certificate::verify_with_callback(&chain, &self.trusted_cas, None, Some(&mut error_msg), callback)
.map_err(|e| mbedtls_err_into_rustls_err_with_error_msg(e, &error_msg))?;
let cert_verify_res = mbedtls::x509::Certificate::verify_with_callback_return_verify_err(
&chain,
&self.trusted_cas,
None,
Some(&mut error_msg),
callback,
)
.map_err(|e| e.1);

verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?;
let validity_verify_res = verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?;

merge_verify_result(&validity_verify_res, &cert_verify_res).map_err(self.mbedtls_verify_error_mapping)?;

Ok(ClientCertVerified::assertion())
}
Expand Down Expand Up @@ -188,10 +216,10 @@ impl ClientCertVerifier for MbedTlsClientCertVerifier {
mod tests {

use chrono::DateTime;
use mbedtls::x509::VerifyError;
use rustls::pki_types::{CertificateDer, UnixTime};
use rustls::{
server::danger::ClientCertVerifier, CertificateError, ClientConfig, ClientConnection, RootCertStore, ServerConfig,
ServerConnection,
server::danger::ClientCertVerifier, ClientConfig, ClientConnection, RootCertStore, ServerConfig, ServerConnection,
};
use std::{sync::Arc, time::SystemTime};

Expand All @@ -218,6 +246,24 @@ mod tests {
);
}

#[test]
fn client_cert_verifier_setter_getter() {
let root_ca = CertificateDer::from(include_bytes!("../test-data/rsa/ca.der").to_vec());
let mut client_cert_verifier = MbedTlsClientCertVerifier::new([&root_ca]).unwrap();
assert!(!client_cert_verifier
.trusted_cas()
.is_empty());
const RETURN_ERR: rustls::Error = rustls::Error::BadMaxFragmentSize;
fn test_mbedtls_verify_error_mapping(_verify_err: VerifyError) -> rustls::Error {
RETURN_ERR
}
client_cert_verifier.set_mbedtls_verify_error_mapping(test_mbedtls_verify_error_mapping);
assert_eq!(
client_cert_verifier.mbedtls_verify_error_mapping()(VerifyError::empty()),
RETURN_ERR
);
}

#[test]
fn connection_client_cert_verifier() {
let client_config = ClientConfig::builder();
Expand Down Expand Up @@ -329,13 +375,16 @@ mod tests {
now.duration_since(SystemTime::UNIX_EPOCH)
.unwrap(),
);

assert_eq!(
verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.unwrap_err(),
rustls::Error::InvalidCertificate(CertificateError::Expired)
);
let verify_res = verifier.verify_client_cert(&cert_chain[0], &cert_chain[1..], now);
if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res {
let verify_err = other_err
.0
.downcast_ref::<crate::VerifyErrorWrapper>()
.unwrap();
assert_eq!(verify_err.0, VerifyError::CERT_EXPIRED);
} else {
panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`")
}
}

#[test]
Expand All @@ -353,12 +402,16 @@ mod tests {
.unwrap(),
);

assert_eq!(
verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.unwrap_err(),
rustls::Error::InvalidCertificate(CertificateError::Expired)
);
let verify_res = verifier.verify_client_cert(&cert_chain[0], &cert_chain[1..], now);
if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res {
let verify_err = other_err
.0
.downcast_ref::<crate::VerifyErrorWrapper>()
.unwrap();
assert_eq!(verify_err.0, VerifyError::CERT_EXPIRED);
} else {
panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`")
}

// Test that we accept expired certs when `ignore_expired` is true
verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: true, ignore_not_active_yet: false });
Expand All @@ -373,16 +426,22 @@ mod tests {
now.duration_since(SystemTime::UNIX_EPOCH)
.unwrap(),
);
assert_eq!(
verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.unwrap_err(),
rustls::Error::InvalidCertificate(CertificateError::NotValidYet)
);

let verify_res = verifier.verify_client_cert(&cert_chain[0], &cert_chain[1..], now);
if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res {
let verify_err = other_err
.0
.downcast_ref::<crate::VerifyErrorWrapper>()
.unwrap();
assert_eq!(verify_err.0, VerifyError::CERT_FUTURE);
} else {
panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`")
}
// Test that we accept certs that are not valid yet when `ignore_not_active_yet` is true
verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: false, ignore_not_active_yet: true });

assert!(verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.is_ok());
verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: true, ignore_not_active_yet: true });
assert!(verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.is_ok());
Expand All @@ -406,8 +465,8 @@ mod tests {
assert!(matches!(verify_res, Err(rustls::Error::InvalidCertificate(_))));

verifier.set_verify_callback(Some(Arc::new(
move |_cert: &mbedtls::x509::Certificate, _depth: i32, flags: &mut mbedtls::x509::VerifyError| {
flags.remove(mbedtls::x509::VerifyError::CERT_NOT_TRUSTED);
move |_cert: &mbedtls::x509::Certificate, _depth: i32, flags: &mut VerifyError| {
flags.remove(VerifyError::CERT_NOT_TRUSTED);
Ok(())
},
)));
Expand Down
106 changes: 85 additions & 21 deletions rustls-mbedpki-provider/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ extern crate alloc;
#[cfg(not(test))]
extern crate std;

use core::fmt::Display;

use chrono::NaiveDateTime;
use mbedtls::x509::VerifyError;
use rustls::pki_types::CertificateDer;
use rustls::SignatureScheme;

Expand Down Expand Up @@ -85,19 +88,20 @@ pub const SUPPORTED_SIGNATURE_SCHEMA: [SignatureScheme; 9] = [
rustls::SignatureScheme::RSA_PSS_SHA256,
];

/// Helper function to convert a [`mbedtls::x509::InvalidTimeError`] to a [`rustls::Error`]
fn time_err_to_err(_time_err: mbedtls::x509::InvalidTimeError) -> rustls::Error {
rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding)
}

/// Verifies that certificates are active, i.e., `now` is between not_before and not_after for
/// each certificate
fn verify_certificates_active<'a>(
chain: impl IntoIterator<Item = &'a mbedtls::x509::Certificate>,
now: NaiveDateTime,
active_check: &CertActiveCheck,
) -> Result<(), rustls::Error> {
) -> Result<Result<(), VerifyError>, rustls::Error> {
if active_check.ignore_expired && active_check.ignore_not_active_yet {
return Ok(());
}

fn time_err_to_err(_time_err: mbedtls::x509::InvalidTimeError) -> rustls::Error {
rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding)
return Ok(Ok(()));
}

for cert in chain.into_iter() {
Expand All @@ -108,7 +112,7 @@ fn verify_certificates_active<'a>(
.try_into()
.map_err(time_err_to_err)?;
if now > not_after {
return Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Expired));
return Ok(Err(VerifyError::CERT_EXPIRED));
}
}
if !active_check.ignore_not_active_yet {
Expand All @@ -118,11 +122,11 @@ fn verify_certificates_active<'a>(
.try_into()
.map_err(time_err_to_err)?;
if now < not_before {
return Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidYet));
return Ok(Err(VerifyError::CERT_FUTURE));
}
}
}
Ok(())
Ok(Ok(()))
}

/// Verifies the tls signature, matches verify functions in rustls `ClientCertVerifier` and
Expand All @@ -140,18 +144,10 @@ fn verify_tls_signature(
// for tls 1.3, we need to verify the advertised curve in signature scheme matches the public key
if is_tls13 {
let signature_curve = utils::pk::rustls_signature_scheme_to_mbedtls_curve_id(dss.scheme);
match signature_curve {
mbedtls::pk::EcGroupId::None => (),
_ => {
let curves_match = pk
.curve()
.is_ok_and(|pk_curve| pk_curve == signature_curve);
if !curves_match {
return Err(rustls::Error::PeerMisbehaved(
rustls::PeerMisbehaved::SignedHandshakeWithUnadvertisedSigScheme,
));
}
}
if !check_ec_signature_curve_match(signature_curve, pk) {
return Err(rustls::Error::PeerMisbehaved(
rustls::PeerMisbehaved::SignedHandshakeWithUnadvertisedSigScheme,
));
}
}

Expand All @@ -168,8 +164,76 @@ fn verify_tls_signature(
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn check_ec_signature_curve_match(signature_curve: mbedtls::pk::EcGroupId, pk: &mbedtls::pk::Pk) -> bool {
match signature_curve {
mbedtls::pk::EcGroupId::None => true,
_ => pk
.curve()
.is_ok_and(|pk_curve| pk_curve == signature_curve),
}
}

/// Helper function to convert a [`CertificateDer`] to [`mbedtls::x509::Certificate`]
pub fn rustls_cert_to_mbedtls_cert(cert: &CertificateDer) -> mbedtls::Result<mbedtls::alloc::Box<mbedtls::x509::Certificate>> {
let cert = mbedtls::x509::Certificate::from_der(cert)?;
Ok(cert)
}

pub(crate) fn merge_verify_result(
first: &Result<(), VerifyError>,
second: &Result<(), VerifyError>,
) -> Result<(), VerifyError> {
match (first, second) {
(Ok(()), Ok(())) => Ok(()),
(Ok(()), Err(second_err)) => Err(*second_err),
(Err(first_err), Ok(())) => Err(*first_err),
(Err(first_err), Err(second_err)) => Err(*first_err | *second_err),
}
}

/// A wrapper on [`mbedtls::x509::VerifyError`] to impl [`std::error::Error`] for it.
#[derive(Debug)]
pub struct VerifyErrorWrapper(pub VerifyError);

impl Display for VerifyErrorWrapper {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
use core::fmt::Debug;
self.0.fmt(f)
}
}

impl std::error::Error for VerifyErrorWrapper {}

#[cfg(test)]
mod tests {
use super::*;
use mbedtls::pk::EcGroupId;
use mbedtls::x509::InvalidTimeError;
use rustls::{CertificateError, Error};

#[test]
fn test_time_err_to_err() {
// Create a sample InvalidTimeError
let time_err = InvalidTimeError;

// Call the function and check the result
let result = time_err_to_err(time_err);
assert_eq!(result, Error::InvalidCertificate(CertificateError::BadEncoding));
}

#[test]
fn verify_error_wrapper_display() {
let verify_error = VerifyError::CERT_EXPIRED; // Replace with actual instantiation of VerifyError
let wrapper = VerifyErrorWrapper(verify_error);
assert_eq!(format!("{}", wrapper), format!("{:?}", verify_error));
}

#[test]
fn test_check_ec_signature_curve_match() {
let cert = CertificateDer::from(include_bytes!("../test-data/rsa/ca.der").to_vec());
let cert = rustls_cert_to_mbedtls_cert(&cert).unwrap();
let pk = cert.public_key();
assert!(check_ec_signature_curve_match(EcGroupId::None, pk));
assert!(!check_ec_signature_curve_match(EcGroupId::SecP256R1, pk));
}
}
Loading

0 comments on commit 03cf570

Please sign in to comment.