Skip to content

Commit

Permalink
feat: enhance verifiers and expose util functions (#10)
Browse files Browse the repository at this point in the history
* feat: enhance verifiers and expose util functions

- Change verifiers to be able to ignore expired cert.
- Change verifiers to be able to have a custom verify callback.
- Make some util functions public.

* test: add tests

* refactor

* refactor

* code improve

* fmt

* fix rustdoc
  • Loading branch information
Taowyoo authored Nov 9, 2023
1 parent 3225e03 commit 9f0c916
Show file tree
Hide file tree
Showing 5 changed files with 366 additions and 62 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ jobs:
uses: dtolnay/rust-toolchain@stable
with:
components: clippy
- run: cargo clippy --locked --package rustls-mbedcrypto-provider --all-features --all-targets -- --deny warnings
- run: cargo clippy --locked --package rustls-mbedcrypto-provider --no-default-features --all-targets -- --deny warnings
- run: cargo clippy --locked --all-features --all-targets -- --deny warnings
- run: cargo clippy --locked --no-default-features --all-targets -- --deny warnings

clippy-nightly:
name: Clippy (Nightly)
Expand All @@ -212,5 +212,5 @@ jobs:
uses: dtolnay/rust-toolchain@nightly
with:
components: clippy
- run: cargo clippy --locked --package rustls-mbedcrypto-provider --all-features --all-targets
- run: cargo clippy --locked --package rustls-mbedcrypto-provider --no-default-features --all-targets
- run: cargo clippy --locked --all-features --all-targets
- run: cargo clippy --locked --no-default-features --all-targets
135 changes: 125 additions & 10 deletions rustls-mbedpki-provider/src/client_cert_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/

use std::sync::Arc;

use chrono::NaiveDateTime;
use pki_types::{CertificateDer, UnixTime};
use rustls::{
Expand All @@ -14,19 +16,21 @@ use rustls::{

use crate::{
mbedtls_err_into_rustls_err, mbedtls_err_into_rustls_err_with_error_msg, rustls_cert_to_mbedtls_cert,
verify_certificates_active, verify_tls_signature,
verify_certificates_active, verify_tls_signature, CertActiveCheck,
};

/// A `rustls` `ClientCertVerifier` implemented using the PKI functionality of
/// A [`rustls`] [`ClientCertVerifier`] implemented using the PKI functionality of
/// `mbedtls`
#[derive(Clone)]
pub struct MbedTlsClientCertVerifier {
trusted_cas: mbedtls::alloc::List<mbedtls::x509::Certificate>,
root_subjects: Vec<rustls::DistinguishedName>,
root_subjects: Vec<DistinguishedName>,
verify_callback: Option<Arc<dyn mbedtls::x509::VerifyCallback + 'static>>,
cert_active_check: CertActiveCheck,
}

impl MbedTlsClientCertVerifier {
/// Constructs a new `MbedTlsClientCertVerifier` object given the provided trusted certificate authority
/// Constructs a new [`MbedTlsClientCertVerifier`] object given the provided trusted certificate authority
/// certificates.
///
/// Returns an error if any of the certificates are invalid.
Expand All @@ -40,7 +44,7 @@ impl MbedTlsClientCertVerifier {
Self::new_from_mbedtls_trusted_cas(trusted_cas)
}

/// Constructs a new `MbedTlsClientCertVerifier` object given the provided trusted certificate authority
/// Constructs a new [`MbedTlsClientCertVerifier`] object given the provided trusted certificate authority
/// certificates.
pub fn new_from_mbedtls_trusted_cas(
trusted_cas: mbedtls::alloc::List<mbedtls::x509::Certificate>,
Expand All @@ -49,7 +53,12 @@ impl MbedTlsClientCertVerifier {
for ca in trusted_cas.iter() {
root_subjects.push(DistinguishedName::from(ca.subject_raw()?));
}
Ok(Self { trusted_cas, root_subjects })
Ok(Self {
trusted_cas,
root_subjects,
verify_callback: None,
cert_active_check: CertActiveCheck::default(),
})
}

/// The certificate authority certificates used to construct this object
Expand All @@ -62,6 +71,28 @@ impl MbedTlsClientCertVerifier {
pub fn root_subjects(&self) -> &[DistinguishedName] {
self.root_subjects.as_ref()
}

/// Retrieves the verification callback function set for the certificate verification process.
pub fn verify_callback(&self) -> Option<Arc<dyn mbedtls::x509::VerifyCallback + 'static>> {
self.verify_callback.clone()
}

/// Sets the verification callback for mbedtls certificate verification process,
///
/// This callback function allows you to add logic at end of mbedtls verification before returning.
pub fn set_verify_callback(&mut self, callback: Option<Arc<dyn mbedtls::x509::VerifyCallback + 'static>>) {
self.verify_callback = callback;
}

/// Getter for [`CertActiveCheck`]
pub fn cert_active_check(&self) -> &CertActiveCheck {
&self.cert_active_check
}

/// Setter for [`CertActiveCheck`]
pub fn set_cert_active_check(&mut self, check: CertActiveCheck) {
self.cert_active_check = check;
}
}

impl ClientCertVerifier for MbedTlsClientCertVerifier {
Expand All @@ -74,7 +105,7 @@ impl ClientCertVerifier for MbedTlsClientCertVerifier {
end_entity: &CertificateDer,
intermediates: &[CertificateDer],
now: UnixTime,
) -> Result<rustls::server::danger::ClientCertVerified, rustls::Error> {
) -> Result<ClientCertVerified, rustls::Error> {
let now = NaiveDateTime::from_timestamp_opt(
now.as_secs()
.try_into()
Expand All @@ -92,11 +123,25 @@ impl ClientCertVerifier for MbedTlsClientCertVerifier {
.into_iter()
.collect();

verify_certificates_active(chain.iter().map(|c| &**c), now)?;
verify_certificates_active(chain.iter().map(|c| &**c), now, &self.cert_active_check)?;

let mut error_msg = String::default();
mbedtls::x509::Certificate::verify(&chain, &self.trusted_cas, None, Some(&mut error_msg))
.map_err(|e| mbedtls_err_into_rustls_err_with_error_msg(e, &error_msg))?;
match &self.verify_callback {
Some(callback) => {
let callback = Arc::clone(callback);
mbedtls::x509::Certificate::verify_with_callback(
&chain,
&self.trusted_cas,
None,
Some(&mut error_msg),
move |cert: &mbedtls::x509::Certificate, depth: i32, flags: &mut mbedtls::x509::VerifyError| {
callback(cert, depth, flags)
},
)
}
None => mbedtls::x509::Certificate::verify(&chain, &self.trusted_cas, None, Some(&mut error_msg)),
}
.map_err(|e| mbedtls_err_into_rustls_err_with_error_msg(e, &error_msg))?;

Ok(ClientCertVerified::assertion())
}
Expand Down Expand Up @@ -272,4 +317,74 @@ mod tests {
rustls::Error::InvalidCertificate(CertificateError::Expired)
);
}

#[test]
fn client_cert_verifier_active_check() {
let cert_chain = get_chain(include_bytes!("../test-data/rsa/client.fullchain"));
let trusted_cas = [CertificateDer::from(include_bytes!("../test-data/rsa/ca.der").to_vec())];

let mut verifier = MbedTlsClientCertVerifier::new(trusted_cas.iter()).unwrap();
let now = SystemTime::from(DateTime::parse_from_rfc3339("2052-11-26T12:00:00+00:00").unwrap());
let now = UnixTime::since_unix_epoch(
now.duration_since(SystemTime::UNIX_EPOCH)
.unwrap(),
);

assert_eq!(
verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.unwrap_err(),
rustls::Error::InvalidCertificate(CertificateError::Expired)
);
verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: true, ignore_not_active_yet: false });

assert!(verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.is_ok());

let now = SystemTime::from(DateTime::parse_from_rfc3339("2002-11-26T12:00:00+00:00").unwrap());
let now = UnixTime::since_unix_epoch(
now.duration_since(SystemTime::UNIX_EPOCH)
.unwrap(),
);
assert_eq!(
verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.unwrap_err(),
rustls::Error::InvalidCertificate(CertificateError::NotValidYet)
);
verifier.set_cert_active_check(crate::CertActiveCheck { ignore_expired: false, ignore_not_active_yet: true });

assert!(verifier
.verify_client_cert(&cert_chain[0], &cert_chain[1..], now)
.is_ok());
}

#[test]
fn client_cert_verifier_callback() {
let mut cert_chain = get_chain(include_bytes!("../test-data/rsa/client.fullchain"));
cert_chain.remove(1);
let trusted_cas = [CertificateDer::from(include_bytes!("../test-data/rsa/ca.der").to_vec())];

let mut verifier = MbedTlsClientCertVerifier::new(trusted_cas.iter()).unwrap();
assert!(verifier.verify_callback().is_none());
let now = SystemTime::from(DateTime::parse_from_rfc3339("2023-11-26T12:00:00+00:00").unwrap());
let now = UnixTime::since_unix_epoch(
now.duration_since(SystemTime::UNIX_EPOCH)
.unwrap(),
);

let verify_res = verifier.verify_client_cert(&cert_chain[0], &cert_chain[1..], now);
assert!(matches!(verify_res, Err(rustls::Error::InvalidCertificate(_))));

verifier.set_verify_callback(Some(Arc::new(
move |_cert: &mbedtls::x509::Certificate, _depth: i32, flags: &mut mbedtls::x509::VerifyError| {
flags.remove(mbedtls::x509::VerifyError::CERT_NOT_TRUSTED);
Ok(())
},
)));
assert!(verifier.verify_callback().is_some());
let verify_res = verifier.verify_client_cert(&cert_chain[0], &cert_chain[1..], now);
assert!(verify_res.is_ok(), "{:?}", verify_res);
}
}
95 changes: 76 additions & 19 deletions rustls-mbedpki-provider/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,33 @@
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/

//! rustls-mbedpki-provider
//!
//! rustls-mbedpki-provider is a pki provider for rustls based on [mbedtls].
//!
//! [mbedtls]: https://github.com/fortanix/rust-mbedtls
// Require docs for public APIs, deny unsafe code, etc.
#![forbid(unsafe_code, unused_must_use)]
#![cfg_attr(not(bench), forbid(unstable_features))]
#![deny(
clippy::alloc_instead_of_core,
clippy::clone_on_ref_ptr,
clippy::std_instead_of_core,
clippy::use_self,
clippy::upper_case_acronyms,
trivial_casts,
trivial_numeric_casts,
missing_docs,
unreachable_pub,
unused_import_braces,
unused_extern_crates,
unused_qualifications
)]
// Enable documentation for all features on docs.rs
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
#![cfg_attr(bench, feature(test))]

use chrono::NaiveDateTime;
use mbedtls::hash::Type;
use pki_types::CertificateDer;
Expand All @@ -14,12 +41,28 @@ use std::sync::Arc;
#[cfg(test)]
mod tests_common;

/// module for implementation of [`ClientCertVerifier`]
///
/// [`ClientCertVerifier`]: rustls::server::danger::ClientCertVerifier
pub mod client_cert_verifier;
/// module for implementation of [`ServerCertVerifier`]
///
/// [`ServerCertVerifier`]: rustls::client::danger::ServerCertVerifier
pub mod server_cert_verifier;

pub use client_cert_verifier::MbedTlsClientCertVerifier;
pub use server_cert_verifier::MbedTlsServerCertVerifier;

/// A config about whether to check certificate validity period
#[derive(Debug, PartialEq, Eq, Clone, Default)]
pub struct CertActiveCheck {
/// Accept expired certificates
pub ignore_expired: bool,
/// Accept certificates that are not yet active
pub ignore_not_active_yet: bool,
}

/// Helper function to convert a [`CertificateDer`] to [`mbedtls::x509::Certificate`]
pub fn rustls_cert_to_mbedtls_cert(cert: &CertificateDer) -> mbedtls::Result<mbedtls::alloc::Box<mbedtls::x509::Certificate>> {
let cert = mbedtls::x509::Certificate::from_der(cert)?;
Ok(cert)
Expand All @@ -30,6 +73,7 @@ pub fn mbedtls_err_into_rustls_err(err: mbedtls::Error) -> rustls::Error {
mbedtls_err_into_rustls_err_with_error_msg(err, "")
}

/// All supported signature schemas
pub const SUPPORTED_SIGNATURE_SCHEMA: [SignatureScheme; 9] = [
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA384,
Expand Down Expand Up @@ -75,7 +119,8 @@ pub fn mbedtls_err_into_rustls_err_with_error_msg(err: mbedtls::Error, msg: &str
}
}

fn rustls_signature_scheme_to_mbedtls_hash_type(signature_scheme: SignatureScheme) -> mbedtls::hash::Type {
/// Helper function to convert rustls [`SignatureScheme`] to mbedtls [`Type`]
pub fn rustls_signature_scheme_to_mbedtls_hash_type(signature_scheme: SignatureScheme) -> Type {
match signature_scheme {
SignatureScheme::RSA_PKCS1_SHA1 => Type::Sha1,
SignatureScheme::ECDSA_SHA1_Legacy => Type::Sha1,
Expand All @@ -95,7 +140,8 @@ fn rustls_signature_scheme_to_mbedtls_hash_type(signature_scheme: SignatureSchem
}
}

fn rustls_signature_scheme_to_mbedtls_pk_options(signature_scheme: SignatureScheme) -> Option<mbedtls::pk::Options> {
/// Helper function to convert rustls [`SignatureScheme`] to mbedtls [`mbedtls::pk::Options`]
pub fn rustls_signature_scheme_to_mbedtls_pk_options(signature_scheme: SignatureScheme) -> Option<mbedtls::pk::Options> {
use mbedtls::pk::Options;
use mbedtls::pk::RsaPadding;
// reference: https://www.rfc-editor.org/rfc/rfc8446.html#section-4.2.3
Expand All @@ -118,7 +164,8 @@ fn rustls_signature_scheme_to_mbedtls_pk_options(signature_scheme: SignatureSche
}
}

fn rustls_signature_scheme_to_mbedtls_curve_id(signature_scheme: SignatureScheme) -> mbedtls::pk::EcGroupId {
/// Helper function to convert rustls [`SignatureScheme`] to mbedtls [`mbedtls::pk::EcGroupId`]
pub fn rustls_signature_scheme_to_mbedtls_curve_id(signature_scheme: SignatureScheme) -> mbedtls::pk::EcGroupId {
// reference: https://www.rfc-editor.org/rfc/rfc8446.html#section-4.2.3
use mbedtls::pk::EcGroupId;
match signature_scheme {
Expand All @@ -141,7 +188,7 @@ fn rustls_signature_scheme_to_mbedtls_curve_id(signature_scheme: SignatureScheme
}

/// Returns the size of the message digest given the hash type.
fn hash_size_bytes(hash_type: mbedtls::hash::Type) -> Option<usize> {
fn hash_size_bytes(hash_type: Type) -> Option<usize> {
match hash_type {
mbedtls::hash::Type::None => None,
mbedtls::hash::Type::Md2 => Some(16),
Expand All @@ -156,7 +203,8 @@ fn hash_size_bytes(hash_type: mbedtls::hash::Type) -> Option<usize> {
}
}

fn buffer_for_hash_type(hash_type: mbedtls::hash::Type) -> Option<Vec<u8>> {
/// Returns the a ready to use empty [`Vec<u8>`] for the message digest with given hash type.
pub fn buffer_for_hash_type(hash_type: Type) -> Option<Vec<u8>> {
let size = hash_size_bytes(hash_type)?;
Some(vec![0; size])
}
Expand All @@ -166,27 +214,36 @@ fn buffer_for_hash_type(hash_type: mbedtls::hash::Type) -> Option<Vec<u8>> {
fn verify_certificates_active<'a>(
chain: impl IntoIterator<Item = &'a mbedtls::x509::Certificate>,
now: NaiveDateTime,
active_check: &CertActiveCheck,
) -> Result<(), rustls::Error> {
if active_check.ignore_expired && active_check.ignore_not_active_yet {
return Ok(());
}

fn time_err_to_err(_time_err: mbedtls::x509::InvalidTimeError) -> rustls::Error {
rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding)
}

for cert in chain.into_iter() {
let not_after = cert
.not_after()
.map_err(mbedtls_err_into_rustls_err)?
.try_into()
.map_err(time_err_to_err)?;
if now > not_after {
return Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Expired));
if !active_check.ignore_expired {
let not_after = cert
.not_after()
.map_err(mbedtls_err_into_rustls_err)?
.try_into()
.map_err(time_err_to_err)?;
if now > not_after {
return Err(rustls::Error::InvalidCertificate(rustls::CertificateError::Expired));
}
}
let not_before = cert
.not_before()
.map_err(mbedtls_err_into_rustls_err)?
.try_into()
.map_err(time_err_to_err)?;
if now < not_before {
return Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidYet));
if !active_check.ignore_not_active_yet {
let not_before = cert
.not_before()
.map_err(mbedtls_err_into_rustls_err)?
.try_into()
.map_err(time_err_to_err)?;
if now < not_before {
return Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidYet));
}
}
}
Ok(())
Expand Down
Loading

0 comments on commit 9f0c916

Please sign in to comment.