Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return VerifyError in a better way #25

Merged
merged 5 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 91 additions & 32 deletions rustls-mbedpki-provider/src/client_cert_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
use alloc::vec;
use alloc::vec::Vec;
use chrono::NaiveDateTime;
use mbedtls::x509::VerifyError;
use rustls::pki_types::{CertificateDer, UnixTime};
use rustls::{
server::danger::{ClientCertVerified, ClientCertVerifier},
DistinguishedName,
};
use utils::error::mbedtls_err_into_rustls_err_with_error_msg;

use crate::{
mbedtls_err_into_rustls_err, rustls_cert_to_mbedtls_cert, verify_certificates_active, verify_tls_signature, CertActiveCheck,
mbedtls_err_into_rustls_err, merge_verify_result, rustls_cert_to_mbedtls_cert, verify_certificates_active,
verify_tls_signature, CertActiveCheck, VerifyErrorWrapper,
};

/// A [`rustls`] [`ClientCertVerifier`] implemented using the PKI functionality of
Expand All @@ -29,6 +30,7 @@
root_subjects: Vec<DistinguishedName>,
verify_callback: Option<Arc<dyn mbedtls::x509::VerifyCallback + 'static>>,
cert_active_check: CertActiveCheck,
mbedtls_verify_error_mapping: fn(VerifyError) -> rustls::Error,
}

impl core::fmt::Debug for MbedTlsClientCertVerifier {
Expand Down Expand Up @@ -71,9 +73,27 @@
root_subjects,
verify_callback: None,
cert_active_check: CertActiveCheck::default(),
mbedtls_verify_error_mapping: Self::default_mbedtls_verify_error_mapping,
})
}

/// The default mapping of [`VerifyError`] to [`rustls::Error`].
pub fn default_mbedtls_verify_error_mapping(verify_err: VerifyError) -> rustls::Error {
rustls::Error::InvalidCertificate(rustls::CertificateError::Other(rustls::OtherError(Arc::new(
VerifyErrorWrapper(verify_err),
))))
}

/// Set the mapping of [`VerifyError`] to [`rustls::Error`].
pub fn set_mbedtls_verify_error_mapping(&mut self, mapping: fn(VerifyError) -> rustls::Error) {
self.mbedtls_verify_error_mapping = mapping;
}

/// Get the current mapping of [`VerifyError`] to [`rustls::Error`].
pub fn mbedtls_verify_error_mapping(&self) -> fn(VerifyError) -> rustls::Error {
self.mbedtls_verify_error_mapping
}

/// The certificate authority certificates used to construct this object
pub fn trusted_cas(&self) -> &mbedtls::alloc::List<mbedtls::x509::Certificate> {
&self.trusted_cas
Expand Down Expand Up @@ -137,10 +157,10 @@
.collect();

let self_verify_callback = self.verify_callback.clone();
let callback = move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut mbedtls::x509::VerifyError| {
let callback = move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut VerifyError| {
// When the "time" feature is enabled for mbedtls, it checks cert expiration. We undo that here,
// since this check is done in `verify_certificates_active()` (subject to self.cert_active_check)
flags.remove(mbedtls::x509::VerifyError::CERT_EXPIRED | mbedtls::x509::VerifyError::CERT_FUTURE);
flags.remove(VerifyError::CERT_EXPIRED | VerifyError::CERT_FUTURE);
if let Some(cb) = self_verify_callback.as_ref() {
cb(cert, depth, flags)
} else {
Expand All @@ -149,10 +169,18 @@
};

let mut error_msg = String::default();
mbedtls::x509::Certificate::verify_with_callback(&chain, &self.trusted_cas, None, Some(&mut error_msg), callback)
.map_err(|e| mbedtls_err_into_rustls_err_with_error_msg(e, &error_msg))?;
let cert_verify_res = mbedtls::x509::Certificate::verify_with_callback_return_verify_err(
&chain,
&self.trusted_cas,
None,
Some(&mut error_msg),
callback,
)
.map_err(|e| e.1);

verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?;
let validity_verify_res = verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?;

merge_verify_result(&validity_verify_res, &cert_verify_res).map_err(self.mbedtls_verify_error_mapping)?;

Ok(ClientCertVerified::assertion())
}
Expand Down Expand Up @@ -188,10 +216,10 @@
mod tests {

use chrono::DateTime;
use mbedtls::x509::VerifyError;
use rustls::pki_types::{CertificateDer, UnixTime};
use rustls::{
server::danger::ClientCertVerifier, CertificateError, ClientConfig, ClientConnection, RootCertStore, ServerConfig,
ServerConnection,
server::danger::ClientCertVerifier, ClientConfig, ClientConnection, RootCertStore, ServerConfig, ServerConnection,
};
use std::{sync::Arc, time::SystemTime};

Expand All @@ -218,6 +246,24 @@
);
}

#[test]
fn client_cert_verifier_setter_getter() {
let root_ca = CertificateDer::from(include_bytes!("../test-data/rsa/ca.der").to_vec());
let mut client_cert_verifier = MbedTlsClientCertVerifier::new([&root_ca]).unwrap();
assert!(!client_cert_verifier
.trusted_cas()
.is_empty());
const RETURN_ERR: rustls::Error = rustls::Error::BadMaxFragmentSize;
fn test_mbedtls_verify_error_mapping(_verify_err: VerifyError) -> rustls::Error {
RETURN_ERR
}
client_cert_verifier.set_mbedtls_verify_error_mapping(test_mbedtls_verify_error_mapping);
assert_eq!(
client_cert_verifier.mbedtls_verify_error_mapping()(VerifyError::empty()),
RETURN_ERR
);
}

#[test]
fn connection_client_cert_verifier() {
let client_config = ClientConfig::builder();
Expand Down Expand Up @@ -329,13 +375,16 @@
now.duration_since(SystemTime::UNIX_EPOCH)
.unwrap(),
);

assert_eq!(
verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.unwrap_err(),
rustls::Error::InvalidCertificate(CertificateError::Expired)
);
let verify_res = verifier.verify_client_cert(&cert_chain[0], &cert_chain[1..], now);
if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res {
let verify_err = other_err
.0
.downcast_ref::<crate::VerifyErrorWrapper>()
.unwrap();
assert_eq!(verify_err.0, VerifyError::CERT_EXPIRED);
} else {
panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`")

Check warning on line 386 in rustls-mbedpki-provider/src/client_cert_verifier.rs

View check run for this annotation

Codecov / codecov/patch

rustls-mbedpki-provider/src/client_cert_verifier.rs#L386

Added line #L386 was not covered by tests
}
}

#[test]
Expand All @@ -353,12 +402,16 @@
.unwrap(),
);

assert_eq!(
verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.unwrap_err(),
rustls::Error::InvalidCertificate(CertificateError::Expired)
);
let verify_res = verifier.verify_client_cert(&cert_chain[0], &cert_chain[1..], now);
if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res {
let verify_err = other_err
.0
.downcast_ref::<crate::VerifyErrorWrapper>()
.unwrap();
assert_eq!(verify_err.0, VerifyError::CERT_EXPIRED);
} else {
panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`")

Check warning on line 413 in rustls-mbedpki-provider/src/client_cert_verifier.rs

View check run for this annotation

Codecov / codecov/patch

rustls-mbedpki-provider/src/client_cert_verifier.rs#L413

Added line #L413 was not covered by tests
}

// Test that we accept expired certs when `ignore_expired` is true
verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: true, ignore_not_active_yet: false });
Expand All @@ -373,16 +426,22 @@
now.duration_since(SystemTime::UNIX_EPOCH)
.unwrap(),
);
assert_eq!(
verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.unwrap_err(),
rustls::Error::InvalidCertificate(CertificateError::NotValidYet)
);

let verify_res = verifier.verify_client_cert(&cert_chain[0], &cert_chain[1..], now);
if let Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Other(other_err))) = verify_res {
let verify_err = other_err
.0
.downcast_ref::<crate::VerifyErrorWrapper>()
.unwrap();
assert_eq!(verify_err.0, VerifyError::CERT_FUTURE);
} else {
panic!("should get an error with type: `rustls::Error::InvalidCertificate(rustls::CertificateError::Other(..))`")

Check warning on line 437 in rustls-mbedpki-provider/src/client_cert_verifier.rs

View check run for this annotation

Codecov / codecov/patch

rustls-mbedpki-provider/src/client_cert_verifier.rs#L437

Added line #L437 was not covered by tests
}
// Test that we accept certs that are not valid yet when `ignore_not_active_yet` is true
verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: false, ignore_not_active_yet: true });

assert!(verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.is_ok());
verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: true, ignore_not_active_yet: true });
assert!(verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.is_ok());
Expand All @@ -406,8 +465,8 @@
assert!(matches!(verify_res, Err(rustls::Error::InvalidCertificate(_))));

verifier.set_verify_callback(Some(Arc::new(
move |_cert: &mbedtls::x509::Certificate, _depth: i32, flags: &mut mbedtls::x509::VerifyError| {
flags.remove(mbedtls::x509::VerifyError::CERT_NOT_TRUSTED);
move |_cert: &mbedtls::x509::Certificate, _depth: i32, flags: &mut VerifyError| {
flags.remove(VerifyError::CERT_NOT_TRUSTED);
Ok(())
},
)));
Expand Down
106 changes: 85 additions & 21 deletions rustls-mbedpki-provider/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@
#[cfg(not(test))]
extern crate std;

use core::fmt::Display;

use chrono::NaiveDateTime;
use mbedtls::x509::VerifyError;
use rustls::pki_types::CertificateDer;
use rustls::SignatureScheme;

Expand Down Expand Up @@ -85,19 +88,20 @@
rustls::SignatureScheme::RSA_PSS_SHA256,
];

/// Helper function to convert a [`mbedtls::x509::InvalidTimeError`] to a [`rustls::Error`]
fn time_err_to_err(_time_err: mbedtls::x509::InvalidTimeError) -> rustls::Error {
rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding)
}

/// Verifies that certificates are active, i.e., `now` is between not_before and not_after for
/// each certificate
fn verify_certificates_active<'a>(
chain: impl IntoIterator<Item = &'a mbedtls::x509::Certificate>,
now: NaiveDateTime,
active_check: &CertActiveCheck,
) -> Result<(), rustls::Error> {
) -> Result<Result<(), VerifyError>, rustls::Error> {
if active_check.ignore_expired && active_check.ignore_not_active_yet {
return Ok(());
}

fn time_err_to_err(_time_err: mbedtls::x509::InvalidTimeError) -> rustls::Error {
rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding)
return Ok(Ok(()));
}

for cert in chain.into_iter() {
Expand All @@ -108,7 +112,7 @@
.try_into()
.map_err(time_err_to_err)?;
if now > not_after {
return Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Expired));
return Ok(Err(VerifyError::CERT_EXPIRED));
}
}
if !active_check.ignore_not_active_yet {
Expand All @@ -118,11 +122,11 @@
.try_into()
.map_err(time_err_to_err)?;
if now < not_before {
return Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidYet));
return Ok(Err(VerifyError::CERT_FUTURE));
}
}
}
Ok(())
Ok(Ok(()))
}

/// Verifies the tls signature, matches verify functions in rustls `ClientCertVerifier` and
Expand All @@ -140,18 +144,10 @@
// for tls 1.3, we need to verify the advertised curve in signature scheme matches the public key
if is_tls13 {
let signature_curve = utils::pk::rustls_signature_scheme_to_mbedtls_curve_id(dss.scheme);
match signature_curve {
mbedtls::pk::EcGroupId::None => (),
_ => {
let curves_match = pk
.curve()
.is_ok_and(|pk_curve| pk_curve == signature_curve);
if !curves_match {
return Err(rustls::Error::PeerMisbehaved(
rustls::PeerMisbehaved::SignedHandshakeWithUnadvertisedSigScheme,
));
}
}
if !check_ec_signature_curve_match(signature_curve, pk) {
return Err(rustls::Error::PeerMisbehaved(
rustls::PeerMisbehaved::SignedHandshakeWithUnadvertisedSigScheme,
));

Check warning on line 150 in rustls-mbedpki-provider/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

rustls-mbedpki-provider/src/lib.rs#L148-L150

Added lines #L148 - L150 were not covered by tests
}
}

Expand All @@ -168,8 +164,76 @@
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn check_ec_signature_curve_match(signature_curve: mbedtls::pk::EcGroupId, pk: &mbedtls::pk::Pk) -> bool {
match signature_curve {
mbedtls::pk::EcGroupId::None => true,
_ => pk
.curve()
.is_ok_and(|pk_curve| pk_curve == signature_curve),
}
}

/// Helper function to convert a [`CertificateDer`] to [`mbedtls::x509::Certificate`]
pub fn rustls_cert_to_mbedtls_cert(cert: &CertificateDer) -> mbedtls::Result<mbedtls::alloc::Box<mbedtls::x509::Certificate>> {
let cert = mbedtls::x509::Certificate::from_der(cert)?;
Ok(cert)
}

pub(crate) fn merge_verify_result(
first: &Result<(), VerifyError>,
second: &Result<(), VerifyError>,
) -> Result<(), VerifyError> {
match (first, second) {
(Ok(()), Ok(())) => Ok(()),
(Ok(()), Err(second_err)) => Err(*second_err),
(Err(first_err), Ok(())) => Err(*first_err),
(Err(first_err), Err(second_err)) => Err(*first_err | *second_err),

Check warning on line 190 in rustls-mbedpki-provider/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

rustls-mbedpki-provider/src/lib.rs#L190

Added line #L190 was not covered by tests
}
}

/// A wrapper on [`mbedtls::x509::VerifyError`] to impl [`std::error::Error`] for it.
#[derive(Debug)]
pub struct VerifyErrorWrapper(pub VerifyError);

impl Display for VerifyErrorWrapper {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
use core::fmt::Debug;
self.0.fmt(f)
}
}

impl std::error::Error for VerifyErrorWrapper {}

#[cfg(test)]
mod tests {
use super::*;
use mbedtls::pk::EcGroupId;
use mbedtls::x509::InvalidTimeError;
use rustls::{CertificateError, Error};

#[test]
fn test_time_err_to_err() {
// Create a sample InvalidTimeError
let time_err = InvalidTimeError;

// Call the function and check the result
let result = time_err_to_err(time_err);
assert_eq!(result, Error::InvalidCertificate(CertificateError::BadEncoding));
}

#[test]
fn verify_error_wrapper_display() {
let verify_error = VerifyError::CERT_EXPIRED; // Replace with actual instantiation of VerifyError
let wrapper = VerifyErrorWrapper(verify_error);
assert_eq!(format!("{}", wrapper), format!("{:?}", verify_error));
}

#[test]
fn test_check_ec_signature_curve_match() {
let cert = CertificateDer::from(include_bytes!("../test-data/rsa/ca.der").to_vec());
let cert = rustls_cert_to_mbedtls_cert(&cert).unwrap();
let pk = cert.public_key();
assert!(check_ec_signature_curve_match(EcGroupId::None, pk));
assert!(!check_ec_signature_curve_match(EcGroupId::SecP256R1, pk));
}
}
Loading