From 2ce33b7c26dfdc6eec76979bce428f87b213b634 Mon Sep 17 00:00:00 2001 From: Richard Zak Date: Fri, 9 Dec 2022 12:23:05 -0500 Subject: [PATCH] chore: refactor unit tests Signed-off-by: Richard Zak --- Cargo.lock | 44 +- Cargo.toml | 22 +- crates/attestation/src/lib.rs | 3 +- crates/server/Cargo.toml | 34 ++ {src => crates/server/src}/kvm.rs | 0 crates/server/src/lib.rs | 844 ++++++++++++++++++++++++++++ src/main.rs | 883 +----------------------------- 7 files changed, 932 insertions(+), 898 deletions(-) create mode 100644 crates/server/Cargo.toml rename {src => crates/server/src}/kvm.rs (100%) create mode 100644 crates/server/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 259121bb..b9c6034b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1029,6 +1029,35 @@ dependencies = [ "serde", ] +[[package]] +name = "server" +version = "0.2.0" +dependencies = [ + "anyhow", + "attestation", + "axum", + "const-oid", + "der", + "http", + "hyper", + "memoffset", + "rstest", + "rustls-pemfile", + "sec1", + "serde", + "sgx", + "testaso", + "tokio", + "toml", + "tower", + "tower-http", + "tracing", + "tracing-subscriber", + "uuid", + "x509-cert", + "zeroize", +] + [[package]] name = "sgx" version = "0.6.0" @@ -1120,31 +1149,18 @@ dependencies = [ "anyhow", "attestation", "axum", - "base64", "clap", "confargs", - "const-oid", "der", - "flagset", "http", - "hyper", - "memoffset", - "mime", - "rstest", - "rustls-pemfile", - "sec1", - "serde", + "server", "sgx", - "testaso", "tokio", "toml", "tower", "tower-http", "tracing", - "tracing-subscriber", - "uuid", "x509-cert", - "zeroize", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index c6ca521a..d86019ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,39 +11,26 @@ license = "AGPL-3.0" [dependencies] attestation = { path = "crates/attestation", features = ["sgx", "snp"] } +server = { path = "crates/server" } anyhow = { version = "^1.0.66", default-features = false } axum = { version = "^0.5.17", features = ["headers"], default-features = false } -base64 = { version = "^0.13.1", default-features = false } clap = { version = "^4.0.29", features = ["help", "usage", "error-context", "std", "derive", "env"], default-features = false } confargs = { version = "^0.1.3", default-features = false } -const-oid = { version = "0.9.1", features = ["db"], default-features = false } -der = { version = "0.6", features = ["std"], default-features = false } -flagset = { version = "0.4.3", default-features = false} -hyper = { git = "https://github.com/rjzak/hyper", branch = "wasi_wip", features = ["http1", "server"], default-features = false } -mime = { version = "^0.3.16", default-features = false } -rustls-pemfile = {version = "1.0.1", default-features = false } -sec1 = { version = "0.3", features = ["std", "pkcs8"], default-features = false } -serde = { version = "1.0", features = ["derive"], default-features = false } tokio = { version = "^1.23.0", features = ["rt", "macros"], default-features = false } -toml = { version = "0.5", default-features = false } tower-http = { version = "^0.3.5", features = ["trace"], default-features = false } -tracing-subscriber = { version="^0.3.15", features = ["env-filter", "json", "fmt"], default-features = false } tracing = { version = "^0.1.29", default-features = false } -uuid = { version = "^1.2.2", features = ["v4"], default-features = false } -x509 = { version = "0.1", features = ["std"], package = "x509-cert", default-features = false } -zeroize = { version = "^1.5.2", features = ["alloc"], default-features = false } [target.'cfg(not(target_os = "wasi"))'.dependencies] tokio = { version = "^1.23.0", features = ["rt-multi-thread", "macros"], default-features = false } [dev-dependencies] axum = { version = "^0.5.17", default-features = false } +der = { version = "0.6", features = ["std"], default-features = false } http = { version = "^0.2.6", default-features = false } -memoffset = { version = "0.7.1", default-features = false } -rstest = { version = "0.16", default-features = false } sgx = { version = "0.6.0", default-features = false } +toml = { version = "0.5", default-features = false } tower = { version = "^0.4.11", features = ["util"], default-features = false } -testaso = { version = "0.1", default-features = false } +x509 = { version = "0.1", features = ["std"], package = "x509-cert", default-features = false } [profile.release] incremental = false @@ -55,4 +42,5 @@ strip = true resolver = '2' members = [ 'crates/attestation', + 'crates/server', ] diff --git a/crates/attestation/src/lib.rs b/crates/attestation/src/lib.rs index fb58aff3..d20fede1 100644 --- a/crates/attestation/src/lib.rs +++ b/crates/attestation/src/lib.rs @@ -7,12 +7,13 @@ pub mod sgx; #[cfg(feature = "snp")] pub mod snp; -use serde::{Deserialize, Deserializer}; use std::borrow::Borrow; use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::ops::Deref; +use serde::{Deserialize, Deserializer}; + /// Digest generic in hash size `N` #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct Digest(pub [u8; N]); diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml new file mode 100644 index 00000000..4e005afe --- /dev/null +++ b/crates/server/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "server" +version = "0.2.0" +edition = "2021" +license = "AGPL-3.0" +description = "Server library for Steward" + +[dependencies] +attestation = { path = "../../crates/attestation", features = ["sgx", "snp"] } +anyhow = { version = "^1.0.66", default-features = false } +axum = { version = "^0.5.17", features = ["headers"], default-features = false } +const-oid = { version = "0.9.1", features = ["db"], default-features = false } +der = { version = "0.6", features = ["std"], default-features = false } +hyper = { git = "https://github.com/rjzak/hyper", branch = "wasi_wip", features = ["http1", "server"], default-features = false } +rustls-pemfile = {version = "1.0.1", default-features = false } +sec1 = { version = "0.3", features = ["std", "pkcs8"], default-features = false } +serde = { version = "1.0", features = ["derive"], default-features = false } +tokio = { version = "^1.23.0", features = ["rt", "macros"], default-features = false } +toml = { version = "0.5", default-features = false } +tower-http = { version = "^0.3.5", features = ["trace"], default-features = false } +tracing = { version = "^0.1.29", default-features = false } +tracing-subscriber = { version="^0.3.15", features = ["env-filter", "json", "fmt"], default-features = false } +uuid = { version = "^1.2.2", features = ["v4"], default-features = false } +x509 = { version = "0.1", features = ["std"], package = "x509-cert", default-features = false } +zeroize = { version = "^1.5.2", features = ["alloc"], default-features = false } + +[dev-dependencies] +axum = { version = "^0.5.17", default-features = false } +http = { version = "^0.2.6", default-features = false } +memoffset = { version = "0.7.1", default-features = false } +rstest = { version = "0.16", default-features = false } +sgx = { version = "0.6.0", default-features = false } +tower = { version = "^0.4.11", features = ["util"], default-features = false } +testaso = { version = "0.1", default-features = false } diff --git a/src/kvm.rs b/crates/server/src/kvm.rs similarity index 100% rename from src/kvm.rs rename to crates/server/src/kvm.rs diff --git a/crates/server/src/lib.rs b/crates/server/src/lib.rs new file mode 100644 index 00000000..67535621 --- /dev/null +++ b/crates/server/src/lib.rs @@ -0,0 +1,844 @@ +// SPDX-FileCopyrightText: 2022 Profian Inc. +// SPDX-License-Identifier: AGPL-3.0-only + +#![warn(rust_2018_idioms, unused_lifetimes, unused_qualifications, clippy::all)] + +mod kvm; + +use attestation::crypto::{CertReqExt, PrivateKeyInfoExt, TbsCertificateExt}; +use attestation::sgx::Sgx; +use attestation::snp::Snp; +use kvm::Kvm; + +use std::io::BufRead; +use std::path::Path; +use std::sync::Arc; +use std::time::{Duration, SystemTime}; + +use anyhow::{anyhow, Context}; +use axum::body::Bytes; +use axum::extract::{Extension, TypedHeader}; +use axum::headers::ContentType; +use axum::response::IntoResponse; +use axum::routing::{get, post}; +use axum::Router; +use const_oid::db::rfc5280::{ + ID_CE_BASIC_CONSTRAINTS, ID_CE_EXT_KEY_USAGE, ID_CE_KEY_USAGE, ID_CE_SUBJECT_ALT_NAME, + ID_KP_CLIENT_AUTH, ID_KP_SERVER_AUTH, +}; +use const_oid::db::rfc5912::ID_EXTENSION_REQ; +use der::asn1::{GeneralizedTime, Ia5StringRef, UIntRef}; +use der::{Decode, Encode, Sequence}; +use hyper::StatusCode; +use sec1::pkcs8::PrivateKeyInfo; +use serde::Deserialize; +use tower_http::trace::{ + DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, + TraceLayer, +}; +use tower_http::LatencyUnit; +use tracing::{debug, Level}; +use x509::attr::Attribute; +use x509::ext::pkix::name::GeneralName; +use x509::ext::pkix::{BasicConstraints, ExtendedKeyUsage, KeyUsage, KeyUsages, SubjectAltName}; +use x509::name::RdnSequence; +use x509::request::{CertReq, ExtensionReq}; +use x509::time::{Time, Validity}; +use x509::{Certificate, TbsCertificate}; +use zeroize::Zeroizing; + +pub const PKCS10: &str = "application/pkcs10"; +pub const BUNDLE: &str = "application/vnd.steward.pkcs10-bundle.v1"; + +#[derive(Clone, Deserialize, Debug, Default, Eq, PartialEq)] +pub struct Config { + pub sgx: Option, + pub snp: Option, +} + +#[derive(Clone, Debug)] +pub struct State { + key: Zeroizing>, + pub crt: Vec, + san: Option, + config: Config, +} + +/// ASN.1 +/// Output ::= SEQUENCE { +/// chain SEQUENCE OF Certificate, +/// issued SEQUENCE OF Certificate, +/// } +#[derive(Clone, Debug, Default, Sequence)] +pub struct Output<'a> { + /// The signing certificate chain back to the root. + pub chain: Vec>, + + /// All issued certificates. + pub issued: Vec>, +} + +impl State { + pub fn load( + san: Option, + key: impl AsRef, + crt: impl AsRef, + config: Option, + ) -> anyhow::Result { + // Load the key file. + let key = std::io::BufReader::new(std::fs::File::open(key)?); + + // Load the crt file. + let crt = std::io::BufReader::new(std::fs::File::open(crt)?); + + Self::read(san, key, crt, config) + } + + pub fn read( + san: Option, + mut key: impl BufRead, + mut crt: impl BufRead, + config: Option, + ) -> anyhow::Result { + let key = match rustls_pemfile::read_one(&mut key)? { + Some(rustls_pemfile::Item::PKCS8Key(buf)) => Zeroizing::new(buf), + _ => return Err(anyhow!("invalid key file")), + }; + + let crt = match rustls_pemfile::read_one(&mut crt)? { + Some(rustls_pemfile::Item::X509Certificate(buf)) => buf, + _ => return Err(anyhow!("invalid key file")), + }; + + // Validate the syntax of the files. + PrivateKeyInfo::from_der(key.as_ref())?; + Certificate::from_der(crt.as_ref())?; + + let config = if let Some(path) = config { + let config = std::fs::read_to_string(path).context("failed to read config file")?; + toml::from_str(&config).context("failed to parse config")? + } else { + Config::default() + }; + + Ok(State { + crt, + san, + key, + config, + }) + } + + pub fn generate(san: Option, hostname: &str) -> anyhow::Result { + use const_oid::db::rfc5912::SECP_256_R_1 as P256; + + // Generate the private key. + let key = PrivateKeyInfo::generate(P256)?; + let pki = PrivateKeyInfo::from_der(key.as_ref())?; + + // Create a relative distinguished name. + let rdns = RdnSequence::encode_from_string(&format!("CN={hostname}"))?; + let rdns = RdnSequence::from_der(&rdns)?; + + // Create the extensions. + let ku = KeyUsage(KeyUsages::KeyCertSign.into()).to_vec()?; + let bc = BasicConstraints { + ca: true, + path_len_constraint: Some(0), + } + .to_vec()?; + + // Create the certificate duration. + let now = SystemTime::now(); + let dur = Duration::from_secs(60 * 60 * 24 * 365); + let validity = Validity { + not_before: Time::GeneralTime(GeneralizedTime::from_system_time(now)?), + not_after: Time::GeneralTime(GeneralizedTime::from_system_time(now + dur)?), + }; + + // Create the certificate body. + let tbs = TbsCertificate { + version: x509::Version::V3, + serial_number: UIntRef::new(&[0u8])?, + signature: pki.signs_with()?, + issuer: rdns.clone(), + validity, + subject: rdns, + subject_public_key_info: pki.public_key()?, + issuer_unique_id: None, + subject_unique_id: None, + extensions: Some(vec![ + x509::ext::Extension { + extn_id: ID_CE_KEY_USAGE, + critical: true, + extn_value: &ku, + }, + x509::ext::Extension { + extn_id: ID_CE_BASIC_CONSTRAINTS, + critical: true, + extn_value: &bc, + }, + ]), + }; + + // Self-sign the certificate. + let crt = tbs.sign(&pki)?; + Ok(Self { + key, + crt, + san, + config: Default::default(), + }) + } +} + +#[derive(Debug, Clone, Default)] +struct SpanMaker; + +impl tower_http::trace::MakeSpan for SpanMaker { + fn make_span(&mut self, request: &axum::http::request::Request) -> tracing::span::Span { + let reqid = uuid::Uuid::new_v4(); + tracing::span!( + Level::INFO, + "request", + method = %request.method(), + uri = %request.uri(), + version = ?request.version(), + headers = ?request.headers(), + request_id = %reqid, + ) + } +} + +pub fn app(state: State) -> Router { + Router::new() + .route("/", post(attest)) + .route("/", get(health)) + .layer(Extension(Arc::new(state))) + .layer( + TraceLayer::new_for_http() + .make_span_with(SpanMaker::default()) + .on_request(DefaultOnRequest::new().level(Level::INFO)) + .on_response( + DefaultOnResponse::new() + .level(Level::INFO) + .latency_unit(LatencyUnit::Micros), + ) + .on_body_chunk(DefaultOnBodyChunk::new()) + .on_eos( + DefaultOnEos::new() + .level(Level::INFO) + .latency_unit(LatencyUnit::Micros), + ) + .on_failure( + DefaultOnFailure::new() + .level(Level::INFO) + .latency_unit(LatencyUnit::Micros), + ), + ) +} + +async fn health() -> StatusCode { + StatusCode::OK +} + +fn attest_request( + issuer: &Certificate<'_>, + pki: &PrivateKeyInfo<'_>, + sans: SubjectAltName<'_>, + cr: CertReq<'_>, + validity: &Validity, + state: &State, +) -> Result, StatusCode> { + let info = cr.verify().map_err(|e| { + debug!("failed to verify certificate info: {e}"); + StatusCode::BAD_REQUEST + })?; + + let mut extensions = Vec::new(); + let mut attested = false; + for Attribute { oid, values } in info.attributes.iter() { + if *oid != ID_EXTENSION_REQ { + debug!("invalid extension {oid}"); + return Err(StatusCode::BAD_REQUEST); + } + for any in values.iter() { + let ereq: ExtensionReq<'_> = any.decode_into().map_err(|e| { + debug!("failed to decode extension request: {e}"); + StatusCode::BAD_REQUEST + })?; + for ext in Vec::from(ereq) { + // If the issuer is self-signed, we are in debug mode. + let iss = &issuer.tbs_certificate; + let dbg = iss.issuer_unique_id == iss.subject_unique_id; + let dbg = dbg && iss.issuer == iss.subject; + + // Validate the extension. + let (copy, att) = match ext.extn_id { + Kvm::OID => (Kvm::default().verify(&info, &ext, dbg), Kvm::ATT), + Sgx::OID => ( + Sgx::default().verify(&info, &ext, state.config.sgx.as_ref(), dbg), + Sgx::ATT, + ), + Snp::OID => ( + Snp::default().verify(&info, &ext, state.config.snp.as_ref(), dbg), + Snp::ATT, + ), + oid => { + debug!("extension `{oid}` is unsupported"); + return Err(StatusCode::BAD_REQUEST); + } + }; + let copy = copy.map_err(|e| { + debug!("extension validation failed: {e}"); + StatusCode::BAD_REQUEST + })?; + + // Save results. + attested |= att; + if copy { + extensions.push(ext); + } + } + } + } + if !attested { + debug!("attestation failed"); + return Err(StatusCode::UNAUTHORIZED); + } + + // Add Subject Alternative Name + let sans: Vec = sans.to_vec().or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; + extensions.push(x509::ext::Extension { + extn_id: ID_CE_SUBJECT_ALT_NAME, + critical: false, + extn_value: &sans, + }); + + // Add extended key usage. + let eku = ExtendedKeyUsage(vec![ID_KP_SERVER_AUTH, ID_KP_CLIENT_AUTH]) + .to_vec() + .or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; + extensions.push(x509::ext::Extension { + extn_id: ID_CE_EXT_KEY_USAGE, + critical: false, + extn_value: &eku, + }); + + // Generate the instance id. + let uuid = uuid::Uuid::new_v4(); + let serial_number = UIntRef::new(uuid.as_bytes()).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; + + let signature = pki + .signs_with() + .or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; + + // Create and sign the new certificate. + TbsCertificate { + version: x509::Version::V3, + serial_number, + signature, + issuer: issuer.tbs_certificate.subject.clone(), + validity: *validity, + subject: RdnSequence(Vec::new()), + subject_public_key_info: info.public_key, + issuer_unique_id: issuer.tbs_certificate.subject_unique_id, + subject_unique_id: None, + extensions: Some(extensions), + } + .sign(pki) + .or(Err(StatusCode::INTERNAL_SERVER_ERROR)) +} + +/// Receives: +/// ASN.1 SEQUENCE OF CertRequest. +/// Returns: +/// ASN.1 SEQUENCE OF Output. +pub async fn attest( + TypedHeader(ct): TypedHeader, + body: Bytes, + Extension(state): Extension>, +) -> Result, impl IntoResponse> { + // Decode the signing certificate and key. + let issuer = Certificate::from_der(&state.crt).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; + let isskey = PrivateKeyInfo::from_der(&state.key).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; + + const TTL: Duration = Duration::from_secs(60 * 60 * 24 * 28); + let now = SystemTime::now(); + let end = now + TTL; + let validity = Validity { + not_before: Time::try_from(now).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?, + not_after: Time::try_from(end).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?, + }; + + // Check for correct mime type. + let reqs = match ct.to_string().as_ref() { + PKCS10 => vec![CertReq::from_der(body.as_ref()).or(Err(StatusCode::BAD_REQUEST))?], + BUNDLE => Vec::from_der(body.as_ref()).or(Err(StatusCode::BAD_REQUEST))?, + _ => return Err(StatusCode::BAD_REQUEST), + }; + + // Decode and verify the certification requests. + reqs.into_iter() + .map(|cr| { + // Create the basic subject alt name. + let name = Ia5StringRef::new("foo.bar.hub.profian.com") + .or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; + let mut sans = vec![GeneralName::DnsName(name)]; + + // Optionally, add the configured subject alt name. + if let Some(name) = &state.san { + let name = Ia5StringRef::new(name).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; + sans.push(GeneralName::DnsName(name)); + } + attest_request( + &issuer, + &isskey, + SubjectAltName(sans), + cr, + &validity, + &state, + ) + }) + .collect::, _>>() + .and_then(|issued| { + let issued: Vec> = issued + .iter() + .map(|c| Certificate::from_der(c).or(Err(StatusCode::INTERNAL_SERVER_ERROR))) + .collect::>()?; + + match ct.to_string().as_ref() { + PKCS10 => vec![issuer, issued[0].clone()].to_vec(), + BUNDLE => Output { + chain: vec![issuer], + issued, + } + .to_vec(), + _ => return Err(StatusCode::BAD_REQUEST), + } + .or(Err(StatusCode::INTERNAL_SERVER_ERROR)) + }) +} + +pub fn init_tracing() { + if std::env::var("RUST_LOG_JSON").is_ok() { + tracing_subscriber::fmt::fmt() + .json() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .init(); + } else { + tracing_subscriber::fmt::init(); + } +} + +#[cfg(test)] +mod tests { + use super::init_tracing; + use std::sync::Once; + + static TRACING: Once = Once::new(); + + mod attest { + use super::{init_tracing, TRACING}; + use crate::*; + use attestation::crypto::CertReqInfoExt; + use attestation::sgx::Sgx; + use attestation::snp::{Evidence, Snp}; + use const_oid::db::rfc5912::{ID_EXTENSION_REQ, SECP_256_R_1, SECP_384_R_1}; + use const_oid::ObjectIdentifier; + use der::{AnyRef, Encode}; + use kvm::Kvm; + use x509::attr::Attribute; + use x509::request::{CertReq, CertReqInfo, ExtensionReq}; + use x509::PkiPath; + use x509::{ext::Extension, name::RdnSequence}; + + use axum::response::Response; + use http::header::CONTENT_TYPE; + use http::Request; + use hyper::Body; + use rstest::rstest; + use tower::ServiceExt; // for `app.oneshot()` + + fn certificates_state() -> State { + #[cfg(not(target_os = "wasi"))] + return State::load(None, "../../testdata/ca.key", "../../testdata/ca.crt", None) + .expect("failed to load state"); + #[cfg(target_os = "wasi")] + { + let crt = + std::io::BufReader::new(include_bytes!("../../testdata/ca.crt").as_slice()); + let key = + std::io::BufReader::new(include_bytes!("../../testdata/ca.key").as_slice()); + + State::read(None, key, crt, None).expect("failed to load state") + } + } + + fn hostname_state() -> State { + State::generate(None, "localhost").unwrap() + } + + fn cr(curve: ObjectIdentifier, exts: Vec>, multi: bool) -> Vec { + let pki = PrivateKeyInfo::generate(curve).unwrap(); + let pki = PrivateKeyInfo::from_der(pki.as_ref()).unwrap(); + let spki = pki.public_key().unwrap(); + + let req = ExtensionReq::from(exts).to_vec().unwrap(); + let any = AnyRef::from_der(&req).unwrap(); + let att = Attribute { + oid: ID_EXTENSION_REQ, + values: vec![any].try_into().unwrap(), + }; + + // Create a certification request information structure. + let cri = CertReqInfo { + version: x509::request::Version::V1, + attributes: vec![att].try_into().unwrap(), + subject: RdnSequence::default(), + public_key: spki, + }; + + // Sign the request. + let signed = cri.sign(&pki).unwrap(); + if multi { + vec![CertReq::from_der(&signed).unwrap()].to_vec().unwrap() + } else { + signed + } + } + + async fn attest_response(state: State, response: Response, multi: bool) { + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + + let path = if multi { + let response = Output::from_der(body.as_ref()).unwrap(); + let mut path = response.chain; + path.push(response.issued[0].clone()); + path + } else { + PkiPath::from_der(&body).unwrap() + }; + + let issr = Certificate::from_der(&state.crt).unwrap(); + assert_eq!(2, path.len()); + assert_eq!(issr, path[0]); + issr.tbs_certificate.verify_crt(&path[1]).unwrap(); + } + + #[test] + fn reencode_multi() { + let encoded = cr(SECP_256_R_1, vec![], true); + let crs = Vec::>::from_der(&encoded).unwrap(); + assert_eq!(crs.len(), 1); + + let encoded: Vec = crs[0].to_vec().unwrap(); + let decoded = CertReq::from_der(&encoded).unwrap(); + let reencoded: Vec = decoded.to_vec().unwrap(); + assert_eq!(encoded, reencoded); + } + + #[test] + fn reencode_single() { + let encoded = cr(SECP_256_R_1, vec![], false); + let decoded = CertReq::from_der(&encoded).unwrap(); + let reencoded = decoded.to_vec().unwrap(); + assert_eq!(encoded, reencoded); + } + + #[rstest] + #[case(PKCS10, false)] + #[case(BUNDLE, true)] + #[tokio::test] + async fn kvm_certs(#[case] header: &str, #[case] multi: bool) { + let ext = Extension { + extn_id: Kvm::OID, + critical: false, + extn_value: &[], + }; + + let request = Request::builder() + .method("POST") + .uri("/") + .header(CONTENT_TYPE, header) + .body(Body::from(cr(SECP_256_R_1, vec![ext], multi))) + .unwrap(); + + let response = app(certificates_state()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + attest_response(certificates_state(), response, multi).await; + } + + #[rstest] + #[case(PKCS10, false)] + #[case(BUNDLE, true)] + #[tokio::test] + async fn kvm_hostname(#[case] header: &str, #[case] multi: bool) { + let ext = Extension { + extn_id: Kvm::OID, + critical: false, + extn_value: &[], + }; + + let request = Request::builder() + .method("POST") + .uri("/") + .header(CONTENT_TYPE, header) + .body(Body::from(cr(SECP_256_R_1, vec![ext], multi))) + .unwrap(); + + let state = hostname_state(); + let response = app(state.clone()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + attest_response(state, response, multi).await; + } + + // Though similar to the above test, this is the only test which + // actually sends many CSRs, versus an array of just one CSR. + #[tokio::test] + async fn kvm_hostname_many_certs() { + let ext = Extension { + extn_id: Kvm::OID, + critical: false, + extn_value: &[], + }; + + let one_cr_bytes = cr(SECP_256_R_1, vec![ext], true); + let crs = Vec::>::from_der(&one_cr_bytes).unwrap(); + assert_eq!(crs.len(), 1); + + let five_crs = vec![ + crs[0].clone(), + crs[0].clone(), + crs[0].clone(), + crs[0].clone(), + crs[0].clone(), + ]; + + let request = Request::builder() + .method("POST") + .uri("/") + .header(CONTENT_TYPE, BUNDLE) + .body(Body::from(five_crs.to_vec().unwrap())) + .unwrap(); + + let state = hostname_state(); + let response = app(state.clone()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let output = Output::from_der(body.as_ref()).unwrap(); + + assert_eq!(output.issued.len(), five_crs.len()); + } + + #[rstest] + #[case(PKCS10, false)] + #[case(BUNDLE, true)] + #[tokio::test] + async fn sgx_certs(#[case] header: &str, #[case] multi: bool) { + TRACING.call_once(init_tracing); + for quote in [ + include_bytes!("../../attestation/src/sgx/quote.unknown").as_slice(), + include_bytes!("../../attestation/src/sgx/quote.icelake").as_slice(), + ] { + let ext = Extension { + extn_id: Sgx::OID, + critical: false, + extn_value: quote, + }; + + let request = Request::builder() + .method("POST") + .uri("/") + .header(CONTENT_TYPE, header) + .body(Body::from(cr(SECP_256_R_1, vec![ext], multi))) + .unwrap(); + + let response = app(certificates_state()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + attest_response(certificates_state(), response, multi).await; + } + } + + #[rstest] + #[case(PKCS10, false)] + #[case(BUNDLE, true)] + #[tokio::test] + async fn sgx_hostname(#[case] header: &str, #[case] multi: bool) { + TRACING.call_once(init_tracing); + for quote in [ + include_bytes!("../../attestation/src/sgx/quote.unknown").as_slice(), + include_bytes!("../../attestation/src/sgx/quote.icelake").as_slice(), + ] { + let ext = Extension { + extn_id: Sgx::OID, + critical: false, + extn_value: quote, + }; + + let request = Request::builder() + .method("POST") + .uri("/") + .header(CONTENT_TYPE, header) + .body(Body::from(cr(SECP_256_R_1, vec![ext], multi))) + .unwrap(); + + let state = hostname_state(); + let response = app(state.clone()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + attest_response(state, response, multi).await; + } + } + + #[rstest] + #[case(PKCS10, false)] + #[case(BUNDLE, true)] + #[tokio::test] + async fn snp_certs(#[case] header: &str, #[case] multi: bool) { + TRACING.call_once(init_tracing); + let evidence = Evidence { + vcek: Certificate::from_der(include_bytes!("../../attestation/src/snp/milan.vcek")) + .unwrap(), + report: include_bytes!("../../attestation/src/snp/milan.rprt"), + } + .to_vec() + .unwrap(); + + let ext = Extension { + extn_id: Snp::OID, + critical: false, + extn_value: &evidence, + }; + + let request = Request::builder() + .method("POST") + .uri("/") + .header(CONTENT_TYPE, header) + .body(Body::from(cr(SECP_384_R_1, vec![ext], multi))) + .unwrap(); + + let response = app(certificates_state()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + attest_response(certificates_state(), response, multi).await; + } + + #[rstest] + #[case(PKCS10, false)] + #[case(BUNDLE, true)] + #[tokio::test] + async fn snp_hostname(#[case] header: &str, #[case] multi: bool) { + TRACING.call_once(init_tracing); + let evidence = Evidence { + vcek: Certificate::from_der(include_bytes!("../../attestation/src/snp/milan.vcek")) + .unwrap(), + report: include_bytes!("../../attestation/src/snp/milan.rprt"), + } + .to_vec() + .unwrap(); + + let ext = Extension { + extn_id: Snp::OID, + critical: false, + extn_value: &evidence, + }; + + let request = Request::builder() + .method("POST") + .uri("/") + .header(CONTENT_TYPE, header) + .body(Body::from(cr(SECP_384_R_1, vec![ext], multi))) + .unwrap(); + + let state = hostname_state(); + let response = app(state.clone()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + attest_response(state, response, multi).await; + } + + #[rstest] + #[case(PKCS10, false)] + #[case(BUNDLE, true)] + #[tokio::test] + async fn err_no_attestation_certs(#[case] header: &str, #[case] multi: bool) { + let request = Request::builder() + .method("POST") + .uri("/") + .header(CONTENT_TYPE, header) + .body(Body::from(cr(SECP_256_R_1, vec![], multi))) + .unwrap(); + + let response = app(certificates_state()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn err_no_attestation_hostname() { + let request = Request::builder() + .method("POST") + .uri("/") + .header(CONTENT_TYPE, BUNDLE) + .body(Body::from(cr(SECP_256_R_1, vec![], true))) + .unwrap(); + + let response = app(hostname_state()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + } + + #[rstest] + #[case(false)] + #[case(true)] + #[tokio::test] + async fn err_no_content_type(#[case] multi: bool) { + let request = Request::builder() + .method("POST") + .uri("/") + .body(Body::from(cr(SECP_256_R_1, vec![], multi))) + .unwrap(); + + let response = app(certificates_state()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn err_empty_body() { + let request = Request::builder() + .method("POST") + .uri("/") + .body(Body::empty()) + .unwrap(); + + let response = app(certificates_state()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn err_bad_body() { + let request = Request::builder() + .method("POST") + .uri("/") + .body(Body::from(vec![0x01, 0x02, 0x03, 0x04])) + .unwrap(); + + let response = app(certificates_state()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn err_bad_csr_sig() { + let mut cr = cr(SECP_256_R_1, vec![], true); + let last = cr.last_mut().unwrap(); + *last = last.wrapping_add(1); // Modify the signature... + + let request = Request::builder() + .method("POST") + .uri("/") + .header(CONTENT_TYPE, BUNDLE) + .body(Body::from(cr)) + .unwrap(); + + let response = app(certificates_state()).oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + } + } +} diff --git a/src/main.rs b/src/main.rs index effa8ad5..4656c8ce 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,57 +3,14 @@ #![warn(rust_2018_idioms, unused_lifetimes, unused_qualifications, clippy::all)] -#[macro_use] -extern crate anyhow; -mod kvm; +use server::{app, init_tracing, State}; -use attestation::crypto::{CertReqExt, PrivateKeyInfoExt, TbsCertificateExt}; -use attestation::sgx::Sgx; -use attestation::snp::Snp; -use kvm::Kvm; - -use std::io::BufRead; use std::net::IpAddr; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; +use std::path::PathBuf; -use anyhow::Context; -use axum::body::Bytes; -use axum::extract::{Extension, TypedHeader}; -use axum::headers::ContentType; -use axum::response::IntoResponse; -use axum::routing::{get, post}; -use axum::Router; +use anyhow::{anyhow, Context}; use clap::Parser; use confargs::{prefix_char_filter, Toml}; -use const_oid::db::rfc5280::{ - ID_CE_BASIC_CONSTRAINTS, ID_CE_EXT_KEY_USAGE, ID_CE_KEY_USAGE, ID_CE_SUBJECT_ALT_NAME, - ID_KP_CLIENT_AUTH, ID_KP_SERVER_AUTH, -}; -use const_oid::db::rfc5912::ID_EXTENSION_REQ; -use der::asn1::{GeneralizedTime, Ia5StringRef, UIntRef}; -use der::{Decode, Encode, Sequence}; -use hyper::StatusCode; -use sec1::pkcs8::PrivateKeyInfo; -use serde::Deserialize; -use tower_http::trace::{ - DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, - TraceLayer, -}; -use tower_http::LatencyUnit; -use tracing::{debug, Level}; -use x509::attr::Attribute; -use x509::ext::pkix::name::GeneralName; -use x509::ext::pkix::{BasicConstraints, ExtendedKeyUsage, KeyUsage, KeyUsages, SubjectAltName}; -use x509::name::RdnSequence; -use x509::request::{CertReq, ExtensionReq}; -use x509::time::{Time, Validity}; -use x509::{Certificate, TbsCertificate}; -use zeroize::Zeroizing; - -const PKCS10: &str = "application/pkcs10"; -const BUNDLE: &str = "application/vnd.steward.pkcs10-bundle.v1"; /// Attestation server for use with Enarx. /// @@ -87,160 +44,10 @@ struct Args { config: Option, } -#[derive(Clone, Deserialize, Debug, Default, Eq, PartialEq)] -struct Config { - sgx: Option, - snp: Option, -} - -#[derive(Debug)] -#[cfg_attr(test, derive(Clone))] -struct State { - key: Zeroizing>, - crt: Vec, - san: Option, - config: Config, -} - -/// ASN.1 -/// Output ::= SEQUENCE { -/// chain SEQUENCE OF Certificate, -/// issued SEQUENCE OF Certificate, -/// } -#[derive(Clone, Debug, Default, Sequence)] -struct Output<'a> { - /// The signing certificate chain back to the root. - pub chain: Vec>, - - /// All issued certificates. - pub issued: Vec>, -} - -impl State { - pub fn load( - san: Option, - key: impl AsRef, - crt: impl AsRef, - config: Option, - ) -> anyhow::Result { - // Load the key file. - let key = std::io::BufReader::new(std::fs::File::open(key)?); - - // Load the crt file. - let crt = std::io::BufReader::new(std::fs::File::open(crt)?); - - Self::read(san, key, crt, config) - } - - pub fn read( - san: Option, - mut key: impl BufRead, - mut crt: impl BufRead, - config: Option, - ) -> anyhow::Result { - let key = match rustls_pemfile::read_one(&mut key)? { - Some(rustls_pemfile::Item::PKCS8Key(buf)) => Zeroizing::new(buf), - _ => return Err(anyhow!("invalid key file")), - }; - - let crt = match rustls_pemfile::read_one(&mut crt)? { - Some(rustls_pemfile::Item::X509Certificate(buf)) => buf, - _ => return Err(anyhow!("invalid key file")), - }; - - // Validate the syntax of the files. - PrivateKeyInfo::from_der(key.as_ref())?; - Certificate::from_der(crt.as_ref())?; - - let config = if let Some(path) = config { - let config = std::fs::read_to_string(path).context("failed to read config file")?; - toml::from_str(&config).context("failed to parse config")? - } else { - Config::default() - }; - - Ok(State { - crt, - san, - key, - config, - }) - } - - pub fn generate(san: Option, hostname: &str) -> anyhow::Result { - use const_oid::db::rfc5912::SECP_256_R_1 as P256; - - // Generate the private key. - let key = PrivateKeyInfo::generate(P256)?; - let pki = PrivateKeyInfo::from_der(key.as_ref())?; - - // Create a relative distinguished name. - let rdns = RdnSequence::encode_from_string(&format!("CN={hostname}"))?; - let rdns = RdnSequence::from_der(&rdns)?; - - // Create the extensions. - let ku = KeyUsage(KeyUsages::KeyCertSign.into()).to_vec()?; - let bc = BasicConstraints { - ca: true, - path_len_constraint: Some(0), - } - .to_vec()?; - - // Create the certificate duration. - let now = SystemTime::now(); - let dur = Duration::from_secs(60 * 60 * 24 * 365); - let validity = Validity { - not_before: Time::GeneralTime(GeneralizedTime::from_system_time(now)?), - not_after: Time::GeneralTime(GeneralizedTime::from_system_time(now + dur)?), - }; - - // Create the certificate body. - let tbs = TbsCertificate { - version: x509::Version::V3, - serial_number: UIntRef::new(&[0u8])?, - signature: pki.signs_with()?, - issuer: rdns.clone(), - validity, - subject: rdns, - subject_public_key_info: pki.public_key()?, - issuer_unique_id: None, - subject_unique_id: None, - extensions: Some(vec![ - x509::ext::Extension { - extn_id: ID_CE_KEY_USAGE, - critical: true, - extn_value: &ku, - }, - x509::ext::Extension { - extn_id: ID_CE_BASIC_CONSTRAINTS, - critical: true, - extn_value: &bc, - }, - ]), - }; - - // Self-sign the certificate. - let crt = tbs.sign(&pki)?; - Ok(Self { - key, - crt, - san, - config: Default::default(), - }) - } -} - #[cfg_attr(not(target_os = "wasi"), tokio::main)] #[cfg_attr(target_os = "wasi", tokio::main(flavor = "current_thread"))] async fn main() -> anyhow::Result<()> { - if std::env::var("RUST_LOG_JSON").is_ok() { - tracing_subscriber::fmt::fmt() - .json() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .init(); - } else { - tracing_subscriber::fmt::init(); - } + init_tracing(); let args = confargs::args::(prefix_char_filter::<'@'>) .context("Failed to parse config") @@ -280,678 +87,21 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -#[derive(Debug, Clone, Default)] -struct SpanMaker; - -impl tower_http::trace::MakeSpan for SpanMaker { - fn make_span(&mut self, request: &axum::http::request::Request) -> tracing::span::Span { - let reqid = uuid::Uuid::new_v4(); - tracing::span!( - Level::INFO, - "request", - method = %request.method(), - uri = %request.uri(), - version = ?request.version(), - headers = ?request.headers(), - request_id = %reqid, - ) - } -} - -fn app(state: State) -> Router { - Router::new() - .route("/", post(attest)) - .route("/", get(health)) - .layer(Extension(Arc::new(state))) - .layer( - TraceLayer::new_for_http() - .make_span_with(SpanMaker::default()) - .on_request(DefaultOnRequest::new().level(Level::INFO)) - .on_response( - DefaultOnResponse::new() - .level(Level::INFO) - .latency_unit(LatencyUnit::Micros), - ) - .on_body_chunk(DefaultOnBodyChunk::new()) - .on_eos( - DefaultOnEos::new() - .level(Level::INFO) - .latency_unit(LatencyUnit::Micros), - ) - .on_failure( - DefaultOnFailure::new() - .level(Level::INFO) - .latency_unit(LatencyUnit::Micros), - ), - ) -} - -async fn health() -> StatusCode { - StatusCode::OK -} - -fn attest_request( - issuer: &Certificate<'_>, - pki: &PrivateKeyInfo<'_>, - sans: SubjectAltName<'_>, - cr: CertReq<'_>, - validity: &Validity, - state: &State, -) -> Result, StatusCode> { - let info = cr.verify().map_err(|e| { - debug!("failed to verify certificate info: {e}"); - StatusCode::BAD_REQUEST - })?; - - let mut extensions = Vec::new(); - let mut attested = false; - for Attribute { oid, values } in info.attributes.iter() { - if *oid != ID_EXTENSION_REQ { - debug!("invalid extension {oid}"); - return Err(StatusCode::BAD_REQUEST); - } - for any in values.iter() { - let ereq: ExtensionReq<'_> = any.decode_into().map_err(|e| { - debug!("failed to decode extension request: {e}"); - StatusCode::BAD_REQUEST - })?; - for ext in Vec::from(ereq) { - // If the issuer is self-signed, we are in debug mode. - let iss = &issuer.tbs_certificate; - let dbg = iss.issuer_unique_id == iss.subject_unique_id; - let dbg = dbg && iss.issuer == iss.subject; - - // Validate the extension. - let (copy, att) = match ext.extn_id { - Kvm::OID => (Kvm::default().verify(&info, &ext, dbg), Kvm::ATT), - Sgx::OID => ( - Sgx::default().verify(&info, &ext, state.config.sgx.as_ref(), dbg), - Sgx::ATT, - ), - Snp::OID => ( - Snp::default().verify(&info, &ext, state.config.snp.as_ref(), dbg), - Snp::ATT, - ), - oid => { - debug!("extension `{oid}` is unsupported"); - return Err(StatusCode::BAD_REQUEST); - } - }; - let copy = copy.map_err(|e| { - debug!("extension validation failed: {e}"); - StatusCode::BAD_REQUEST - })?; - - // Save results. - attested |= att; - if copy { - extensions.push(ext); - } - } - } - } - if !attested { - debug!("attestation failed"); - return Err(StatusCode::UNAUTHORIZED); - } - - // Add Subject Alternative Name - let sans: Vec = sans.to_vec().or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; - extensions.push(x509::ext::Extension { - extn_id: ID_CE_SUBJECT_ALT_NAME, - critical: false, - extn_value: &sans, - }); - - // Add extended key usage. - let eku = ExtendedKeyUsage(vec![ID_KP_SERVER_AUTH, ID_KP_CLIENT_AUTH]) - .to_vec() - .or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; - extensions.push(x509::ext::Extension { - extn_id: ID_CE_EXT_KEY_USAGE, - critical: false, - extn_value: &eku, - }); - - // Generate the instance id. - let uuid = uuid::Uuid::new_v4(); - let serial_number = UIntRef::new(uuid.as_bytes()).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; - - let signature = pki - .signs_with() - .or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; - - // Create and sign the new certificate. - TbsCertificate { - version: x509::Version::V3, - serial_number, - signature, - issuer: issuer.tbs_certificate.subject.clone(), - validity: *validity, - subject: RdnSequence(Vec::new()), - subject_public_key_info: info.public_key, - issuer_unique_id: issuer.tbs_certificate.subject_unique_id, - subject_unique_id: None, - extensions: Some(extensions), - } - .sign(pki) - .or(Err(StatusCode::INTERNAL_SERVER_ERROR)) -} - -/// Receives: -/// ASN.1 SEQUENCE OF CertRequest. -/// Returns: -/// ASN.1 SEQUENCE OF Output. -async fn attest( - TypedHeader(ct): TypedHeader, - body: Bytes, - Extension(state): Extension>, -) -> Result, impl IntoResponse> { - // Decode the signing certificate and key. - let issuer = Certificate::from_der(&state.crt).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; - let isskey = PrivateKeyInfo::from_der(&state.key).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; - - const TTL: Duration = Duration::from_secs(60 * 60 * 24 * 28); - let now = SystemTime::now(); - let end = now + TTL; - let validity = Validity { - not_before: Time::try_from(now).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?, - not_after: Time::try_from(end).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?, - }; - - // Check for correct mime type. - let reqs = match ct.to_string().as_ref() { - PKCS10 => vec![CertReq::from_der(body.as_ref()).or(Err(StatusCode::BAD_REQUEST))?], - BUNDLE => Vec::from_der(body.as_ref()).or(Err(StatusCode::BAD_REQUEST))?, - _ => return Err(StatusCode::BAD_REQUEST), - }; - - // Decode and verify the certification requests. - reqs.into_iter() - .map(|cr| { - // Create the basic subject alt name. - let name = Ia5StringRef::new("foo.bar.hub.profian.com") - .or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; - let mut sans = vec![GeneralName::DnsName(name)]; - - // Optionally, add the configured subject alt name. - if let Some(name) = &state.san { - let name = Ia5StringRef::new(name).or(Err(StatusCode::INTERNAL_SERVER_ERROR))?; - sans.push(GeneralName::DnsName(name)); - } - attest_request( - &issuer, - &isskey, - SubjectAltName(sans), - cr, - &validity, - &state, - ) - }) - .collect::, _>>() - .and_then(|issued| { - let issued: Vec> = issued - .iter() - .map(|c| Certificate::from_der(c).or(Err(StatusCode::INTERNAL_SERVER_ERROR))) - .collect::>()?; - - match ct.to_string().as_ref() { - PKCS10 => vec![issuer, issued[0].clone()].to_vec(), - BUNDLE => Output { - chain: vec![issuer], - issued, - } - .to_vec(), - _ => return Err(StatusCode::BAD_REQUEST), - } - .or(Err(StatusCode::INTERNAL_SERVER_ERROR)) - }) -} - #[cfg(test)] mod tests { - use std::sync::Once; - static TRACING: Once = Once::new(); - - pub fn init_tracing() { - TRACING.call_once(|| { - if std::env::var("RUST_LOG_JSON").is_ok() { - tracing_subscriber::fmt::fmt() - .json() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .init(); - } else { - tracing_subscriber::fmt::init(); - } - }); - } - - mod attest { - use super::init_tracing; - use crate::*; - use attestation::crypto::CertReqInfoExt; - use attestation::sgx::Sgx; - use attestation::snp::{Evidence, Snp}; - use const_oid::db::rfc5912::{ID_EXTENSION_REQ, SECP_256_R_1, SECP_384_R_1}; - use const_oid::ObjectIdentifier; - use der::{AnyRef, Encode}; - use kvm::Kvm; - use x509::attr::Attribute; - use x509::request::{CertReq, CertReqInfo, ExtensionReq}; - use x509::PkiPath; - use x509::{ext::Extension, name::RdnSequence}; - - use axum::response::Response; - use http::header::CONTENT_TYPE; - use http::Request; - use hyper::Body; - use rstest::rstest; - use tower::ServiceExt; // for `app.oneshot()` - - fn certificates_state() -> State { - #[cfg(not(target_os = "wasi"))] - return State::load(None, "testdata/ca.key", "testdata/ca.crt", None) - .expect("failed to load state"); - #[cfg(target_os = "wasi")] - { - let crt = std::io::BufReader::new(include_bytes!("../testdata/ca.crt").as_slice()); - let key = std::io::BufReader::new(include_bytes!("../testdata/ca.key").as_slice()); - - State::read(None, key, crt, None).expect("failed to load state") - } - } - - fn hostname_state() -> State { - State::generate(None, "localhost").unwrap() - } - - fn cr(curve: ObjectIdentifier, exts: Vec>, multi: bool) -> Vec { - let pki = PrivateKeyInfo::generate(curve).unwrap(); - let pki = PrivateKeyInfo::from_der(pki.as_ref()).unwrap(); - let spki = pki.public_key().unwrap(); - - let req = ExtensionReq::from(exts).to_vec().unwrap(); - let any = AnyRef::from_der(&req).unwrap(); - let att = Attribute { - oid: ID_EXTENSION_REQ, - values: vec![any].try_into().unwrap(), - }; - - // Create a certification request information structure. - let cri = CertReqInfo { - version: x509::request::Version::V1, - attributes: vec![att].try_into().unwrap(), - subject: RdnSequence::default(), - public_key: spki, - }; - - // Sign the request. - let signed = cri.sign(&pki).unwrap(); - if multi { - vec![CertReq::from_der(&signed).unwrap()].to_vec().unwrap() - } else { - signed - } - } - - async fn attest_response(state: State, response: Response, multi: bool) { - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - - let path = if multi { - let response = Output::from_der(body.as_ref()).unwrap(); - let mut path = response.chain; - path.push(response.issued[0].clone()); - path - } else { - PkiPath::from_der(&body).unwrap() - }; - - let issr = Certificate::from_der(&state.crt).unwrap(); - assert_eq!(2, path.len()); - assert_eq!(issr, path[0]); - issr.tbs_certificate.verify_crt(&path[1]).unwrap(); - } - - #[test] - fn reencode_multi() { - let encoded = cr(SECP_256_R_1, vec![], true); - let crs = Vec::>::from_der(&encoded).unwrap(); - assert_eq!(crs.len(), 1); - - let encoded: Vec = crs[0].to_vec().unwrap(); - let decoded = CertReq::from_der(&encoded).unwrap(); - let reencoded: Vec = decoded.to_vec().unwrap(); - assert_eq!(encoded, reencoded); - } - - #[test] - fn reencode_single() { - let encoded = cr(SECP_256_R_1, vec![], false); - let decoded = CertReq::from_der(&encoded).unwrap(); - let reencoded = decoded.to_vec().unwrap(); - assert_eq!(encoded, reencoded); - } - - #[rstest] - #[case(PKCS10, false)] - #[case(BUNDLE, true)] - #[tokio::test] - async fn kvm_certs(#[case] header: &str, #[case] multi: bool) { - init_tracing(); - let ext = Extension { - extn_id: Kvm::OID, - critical: false, - extn_value: &[], - }; - - let request = Request::builder() - .method("POST") - .uri("/") - .header(CONTENT_TYPE, header) - .body(Body::from(cr(SECP_256_R_1, vec![ext], multi))) - .unwrap(); - - let response = app(certificates_state()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); - attest_response(certificates_state(), response, multi).await; - } - - #[rstest] - #[case(PKCS10, false)] - #[case(BUNDLE, true)] - #[tokio::test] - async fn kvm_hostname(#[case] header: &str, #[case] multi: bool) { - init_tracing(); - let ext = Extension { - extn_id: Kvm::OID, - critical: false, - extn_value: &[], - }; - - let request = Request::builder() - .method("POST") - .uri("/") - .header(CONTENT_TYPE, header) - .body(Body::from(cr(SECP_256_R_1, vec![ext], multi))) - .unwrap(); - - let state = hostname_state(); - let response = app(state.clone()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); - attest_response(state, response, multi).await; - } - - // Though similar to the above test, this is the only test which - // actually sends many CSRs, versus an array of just one CSR. - #[tokio::test] - async fn kvm_hostname_many_certs() { - init_tracing(); - let ext = Extension { - extn_id: Kvm::OID, - critical: false, - extn_value: &[], - }; - - let one_cr_bytes = cr(SECP_256_R_1, vec![ext], true); - let crs = Vec::>::from_der(&one_cr_bytes).unwrap(); - assert_eq!(crs.len(), 1); - - let five_crs = vec![ - crs[0].clone(), - crs[0].clone(), - crs[0].clone(), - crs[0].clone(), - crs[0].clone(), - ]; - - let request = Request::builder() - .method("POST") - .uri("/") - .header(CONTENT_TYPE, BUNDLE) - .body(Body::from(five_crs.to_vec().unwrap())) - .unwrap(); - - let state = hostname_state(); - let response = app(state.clone()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); - - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - let output = Output::from_der(body.as_ref()).unwrap(); - - assert_eq!(output.issued.len(), five_crs.len()); - } - - #[rstest] - #[case(PKCS10, false)] - #[case(BUNDLE, true)] - #[tokio::test] - async fn sgx_certs(#[case] header: &str, #[case] multi: bool) { - init_tracing(); - for quote in [ - include_bytes!("../crates/attestation/src/sgx/quote.unknown").as_slice(), - include_bytes!("../crates/attestation/src/sgx/quote.icelake").as_slice(), - ] { - let ext = Extension { - extn_id: Sgx::OID, - critical: false, - extn_value: quote, - }; - - let request = Request::builder() - .method("POST") - .uri("/") - .header(CONTENT_TYPE, header) - .body(Body::from(cr(SECP_256_R_1, vec![ext], multi))) - .unwrap(); - - let response = app(certificates_state()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); - attest_response(certificates_state(), response, multi).await; - } - } - - #[rstest] - #[case(PKCS10, false)] - #[case(BUNDLE, true)] - #[tokio::test] - async fn sgx_hostname(#[case] header: &str, #[case] multi: bool) { - init_tracing(); - for quote in [ - include_bytes!("../crates/attestation/src/sgx/quote.unknown").as_slice(), - include_bytes!("../crates/attestation/src/sgx/quote.icelake").as_slice(), - ] { - let ext = Extension { - extn_id: Sgx::OID, - critical: false, - extn_value: quote, - }; - - let request = Request::builder() - .method("POST") - .uri("/") - .header(CONTENT_TYPE, header) - .body(Body::from(cr(SECP_256_R_1, vec![ext], multi))) - .unwrap(); - - let state = hostname_state(); - let response = app(state.clone()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); - attest_response(state, response, multi).await; - } - } - - #[rstest] - #[case(PKCS10, false)] - #[case(BUNDLE, true)] - #[tokio::test] - async fn snp_certs(#[case] header: &str, #[case] multi: bool) { - init_tracing(); - let evidence = Evidence { - vcek: Certificate::from_der(include_bytes!( - "../crates/attestation/src/snp/milan.vcek" - )) - .unwrap(), - report: include_bytes!("../crates/attestation/src/snp/milan.rprt"), - } - .to_vec() - .unwrap(); - - let ext = Extension { - extn_id: Snp::OID, - critical: false, - extn_value: &evidence, - }; - - let request = Request::builder() - .method("POST") - .uri("/") - .header(CONTENT_TYPE, header) - .body(Body::from(cr(SECP_384_R_1, vec![ext], multi))) - .unwrap(); - - let response = app(certificates_state()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); - attest_response(certificates_state(), response, multi).await; - } - - #[rstest] - #[case(PKCS10, false)] - #[case(BUNDLE, true)] - #[tokio::test] - async fn snp_hostname(#[case] header: &str, #[case] multi: bool) { - init_tracing(); - let evidence = Evidence { - vcek: Certificate::from_der(include_bytes!( - "../crates/attestation/src/snp/milan.vcek" - )) - .unwrap(), - report: include_bytes!("../crates/attestation/src/snp/milan.rprt"), - } - .to_vec() - .unwrap(); - - let ext = Extension { - extn_id: Snp::OID, - critical: false, - extn_value: &evidence, - }; - - let request = Request::builder() - .method("POST") - .uri("/") - .header(CONTENT_TYPE, header) - .body(Body::from(cr(SECP_384_R_1, vec![ext], multi))) - .unwrap(); - - let state = hostname_state(); - let response = app(state.clone()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); - attest_response(state, response, multi).await; - } - - #[rstest] - #[case(PKCS10, false)] - #[case(BUNDLE, true)] - #[tokio::test] - async fn err_no_attestation_certs(#[case] header: &str, #[case] multi: bool) { - init_tracing(); - let request = Request::builder() - .method("POST") - .uri("/") - .header(CONTENT_TYPE, header) - .body(Body::from(cr(SECP_256_R_1, vec![], multi))) - .unwrap(); - - let response = app(certificates_state()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - } - - #[tokio::test] - async fn err_no_attestation_hostname() { - init_tracing(); - let request = Request::builder() - .method("POST") - .uri("/") - .header(CONTENT_TYPE, BUNDLE) - .body(Body::from(cr(SECP_256_R_1, vec![], true))) - .unwrap(); - - let response = app(hostname_state()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - } - - #[rstest] - #[case(false)] - #[case(true)] - #[tokio::test] - async fn err_no_content_type(#[case] multi: bool) { - init_tracing(); - let request = Request::builder() - .method("POST") - .uri("/") - .body(Body::from(cr(SECP_256_R_1, vec![], multi))) - .unwrap(); - - let response = app(certificates_state()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); - } - - #[tokio::test] - async fn err_empty_body() { - init_tracing(); - let request = Request::builder() - .method("POST") - .uri("/") - .body(Body::empty()) - .unwrap(); - - let response = app(certificates_state()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); - } - - #[tokio::test] - async fn err_bad_body() { - init_tracing(); - let request = Request::builder() - .method("POST") - .uri("/") - .body(Body::from(vec![0x01, 0x02, 0x03, 0x04])) - .unwrap(); - - let response = app(certificates_state()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); - } - - #[tokio::test] - async fn err_bad_csr_sig() { - init_tracing(); - let mut cr = cr(SECP_256_R_1, vec![], true); - let last = cr.last_mut().unwrap(); - *last = last.wrapping_add(1); // Modify the signature... - - let request = Request::builder() - .method("POST") - .uri("/") - .header(CONTENT_TYPE, BUNDLE) - .body(Body::from(cr)) - .unwrap(); - - let response = app(certificates_state()).oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); - } - } - // Unit tests for configuration mod config { - use super::init_tracing; - use crate::Config; use attestation::sgx::quote::traits::ParseBytes; use attestation::sgx::quote::Quote; use attestation::snp::{Evidence, PolicyFlags, Report, Snp}; use attestation::{Digest, Measurements}; + use server::{init_tracing, Config}; + + use std::collections::HashSet; + use std::sync::Once; + use der::Decode; use sgx::parameters::MiscSelect; - use std::collections::HashSet; use x509::attr::Attribute; use x509::request::{CertReq, ExtensionReq}; use x509::Certificate; @@ -960,12 +110,12 @@ mod tests { const ICELAKE_CSR: &[u8] = include_bytes!("../crates/attestation/src/sgx/icelake.signed.csr"); const MILAN_CSR: &[u8] = include_bytes!("../crates/attestation/src/snp/milan.signed.csr"); + static TRACING: Once = Once::new(); fn assert_sgx_config( csr: &CertReq<'_>, conf: &attestation::sgx::config::Config, ) -> anyhow::Result<()> { - init_tracing(); let sgx = attestation::sgx::Sgx::default(); #[allow(unused_variables)] @@ -994,7 +144,6 @@ mod tests { csr: &CertReq<'_>, conf: &attestation::snp::config::Config, ) -> anyhow::Result<()> { - init_tracing(); let snp = Snp::default(); #[allow(unused_variables)] for Attribute { oid, values } in csr.info.attributes.iter() { @@ -1021,6 +170,7 @@ mod tests { #[test] fn test_config() { + TRACING.call_once(init_tracing); let config: Config = toml::from_str( r#" [snp] @@ -1032,7 +182,7 @@ mod tests { signer = ["2eba0f494f428e799c22d6f12778aebea4dc8d991f9e63fd3cddd57ac6eb5dd9"] "#, ) - .expect("Couldn't deserialize"); + .expect("Couldn't deserialize"); let snp = attestation::snp::config::Config { measurements: Measurements { @@ -1084,6 +234,7 @@ mod tests { #[test] fn test_sgx_signed_canned_csr() { + TRACING.call_once(init_tracing); let csr = CertReq::from_der(ICELAKE_CSR).unwrap(); let config: Config = toml::from_str(DEFAULT_CONFIG).expect("Couldn't deserialize"); assert_sgx_config(&csr, &config.sgx.unwrap()).unwrap(); @@ -1091,6 +242,7 @@ mod tests { #[test] fn test_sgx_signed_csr_bad_config_signer() { + TRACING.call_once(init_tracing); let csr = CertReq::from_der(ICELAKE_CSR).unwrap(); let config: Config = toml::from_str( r#" @@ -1120,6 +272,7 @@ mod tests { #[test] fn test_snp_signed_canned_csr() { + TRACING.call_once(init_tracing); let csr = CertReq::from_der(MILAN_CSR).unwrap(); let config: Config = toml::from_str(DEFAULT_CONFIG).expect("Couldn't deserialize"); assert!(assert_snp_config(&csr, &config.snp.unwrap()).is_ok()); @@ -1135,8 +288,7 @@ mod tests { policy_flags = ["SingleSocket", "Debug"] signer = ["e368c18e60842db9325778532dd81594d732078bf01aa91686be40333da639e08733b910bd057bdda50d715968b075ce"] "#, - ) - .expect("Couldn't deserialize"); + ).expect("Couldn't deserialize"); assert!(assert_snp_config(&csr, &config.snp.unwrap()).is_err()); } @@ -1151,8 +303,7 @@ mod tests { signer = ["5b2181f5e2294fa0709d22b3f85d9d88b287b897c6b7289004802b53bbf09bc50f5469f98a6d6718d5f9c918d3d3c16f"] abi = ">254.0" "#, - ) - .expect("Couldn't deserialize"); + ).expect("Couldn't deserialize"); assert!(assert_snp_config(&csr, &config.snp.unwrap()).is_err()); }