diff --git a/Cargo.toml b/Cargo.toml index ad64614d..b3cb2ed1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,10 @@ memoffset = { version = "0.7.1", default-features = false } rstest = { version = "0.15", default-features = false } testaso = { version = "0.1", default-features = false } +[features] +default = [] +insecure = [] + [profile.release] incremental = false codegen-units = 1 diff --git a/src/ext/kvm.rs b/src/ext/kvm.rs index fa9a15dd..2be06ac4 100644 --- a/src/ext/kvm.rs +++ b/src/ext/kvm.rs @@ -19,7 +19,7 @@ impl ExtVerifier for Kvm { const OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.3.6.1.4.1.58270.1.1"); const ATT: bool = true; - fn verify(&self, _cri: &CertReqInfo<'_>, ext: &Extension<'_>, dbg: bool) -> Result { + fn verify(&self, _cri: &CertReqInfo<'_>, ext: &Extension<'_>) -> Result { if ext.critical { return Err(anyhow!("kvm extension cannot be critical")); } @@ -28,10 +28,10 @@ impl ExtVerifier for Kvm { return Err(anyhow!("invalid kvm extension")); } - if !dbg { - return Err(anyhow!("steward not in debug mode")); - } + #[cfg(not(feature = "insecure"))] + return Err(anyhow!("steward not in debug mode")); + #[cfg(feature = "insecure")] Ok(true) } } diff --git a/src/ext/mod.rs b/src/ext/mod.rs index df54f5c9..96a0dbc6 100644 --- a/src/ext/mod.rs +++ b/src/ext/mod.rs @@ -28,5 +28,5 @@ pub trait ExtVerifier { /// certificate. Returning `Ok(false)` will allow the certification request /// to continue, but this particular extension will not be included /// in the resulting certificate. - fn verify(&self, cri: &CertReqInfo<'_>, ext: &Extension<'_>, dbg: bool) -> Result; + fn verify(&self, cri: &CertReqInfo<'_>, ext: &Extension<'_>) -> Result; } diff --git a/src/ext/sgx/mod.rs b/src/ext/sgx/mod.rs index d2828dbd..d674ace2 100644 --- a/src/ext/sgx/mod.rs +++ b/src/ext/sgx/mod.rs @@ -11,8 +11,12 @@ use std::fmt::Debug; use anyhow::{anyhow, Result}; use const_oid::ObjectIdentifier; -use der::{Decode, Encode}; +use der::Decode; +#[cfg(not(feature = "insecure"))] +use der::Encode; +#[cfg(not(feature = "insecure"))] use sgx::parameters::{Attributes, MiscSelect}; +#[cfg(not(feature = "insecure"))] use sha2::{Digest, Sha256}; use x509::{ext::Extension, request::CertReqInfo, Certificate, TbsCertificate}; @@ -42,7 +46,7 @@ impl ExtVerifier for Sgx { const OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.3.6.1.4.1.58270.1.2"); const ATT: bool = true; - fn verify(&self, cri: &CertReqInfo<'_>, ext: &Extension<'_>, dbg: bool) -> Result { + fn verify(&self, cri: &CertReqInfo<'_>, ext: &Extension<'_>) -> Result { if ext.critical { return Err(anyhow!("sgx extension cannot be critical")); } @@ -62,7 +66,8 @@ impl ExtVerifier for Sgx { // Validate the report. let pck = self.trusted(&chain)?; - let rpt = quote.verify(pck)?; + #[cfg(feature = "insecure")] + quote.verify(pck)?; // Force certs to have the same key type as the PCK. // @@ -82,7 +87,10 @@ impl ExtVerifier for Sgx { return Err(anyhow!("sgx pck algorithm mismatch")); } - if !dbg { + #[cfg(not(feature = "insecure"))] + { + let rpt = quote.verify(pck)?; + // TODO: Validate that the certification request came from an SGX enclave. let hash = Sha256::digest(&cri.public_key.to_vec()?); if hash.as_slice() != &rpt.reportdata[..hash.as_slice().len()] { diff --git a/src/ext/snp/mod.rs b/src/ext/snp/mod.rs index 92db26c0..bb1a66da 100644 --- a/src/ext/snp/mod.rs +++ b/src/ext/snp/mod.rs @@ -13,6 +13,7 @@ use der::asn1::UIntRef; use der::{Decode, Encode, Sequence}; use flagset::{flags, FlagSet}; use sec1::pkcs8::AlgorithmIdentifier; +#[cfg(not(feature = "insecure"))] use sha2::Digest; use x509::ext::Extension; use x509::{request::CertReqInfo, Certificate}; @@ -241,7 +242,7 @@ impl ExtVerifier for Snp { const OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.3.6.1.4.1.58270.1.3"); const ATT: bool = true; - fn verify(&self, cri: &CertReqInfo<'_>, ext: &Extension<'_>, dbg: bool) -> Result { + fn verify(&self, cri: &CertReqInfo<'_>, ext: &Extension<'_>) -> Result { if ext.critical { return Err(anyhow!("snp extension cannot be critical")); } @@ -372,7 +373,8 @@ impl ExtVerifier for Snp { } } - if !dbg { + #[cfg(not(feature = "insecure"))] + { // Validate that the certification request came from an SNP VM. let hash = sha2::Sha384::digest(&cri.public_key.to_vec()?); if hash.as_slice() != &report.body.report_data[..hash.as_slice().len()] { diff --git a/src/main.rs b/src/main.rs index 02a6d89f..73c8d8e6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,16 +26,17 @@ use axum::routing::{get, post}; use axum::Router; use clap::Parser; use confargs::{prefix_char_filter, Toml}; -#[cfg(debug_assertions)] +#[cfg(feature = "insecure")] use const_oid::db::rfc5280::{ID_CE_BASIC_CONSTRAINTS, ID_CE_KEY_USAGE}; use const_oid::db::rfc5280::{ ID_CE_EXT_KEY_USAGE, ID_CE_SUBJECT_ALT_NAME, ID_KP_CLIENT_AUTH, ID_KP_SERVER_AUTH, }; use const_oid::db::rfc5912::ID_EXTENSION_REQ; -#[cfg(debug_assertions)] +#[cfg(feature = "insecure")] use der::asn1::GeneralizedTime; use der::asn1::{Ia5StringRef, UIntRef}; use der::{Decode, Encode, Sequence}; +#[cfg(feature = "insecure")] use ext::kvm::Kvm; use ext::sgx::Sgx; use ext::snp::Snp; @@ -51,7 +52,7 @@ use tower_http::LatencyUnit; use tracing::{debug, Level}; use x509::attr::Attribute; use x509::ext::pkix::name::GeneralName; -#[cfg(debug_assertions)] +#[cfg(feature = "insecure")] use x509::ext::pkix::{BasicConstraints, KeyUsage, KeyUsages}; use x509::ext::pkix::{ExtendedKeyUsage, SubjectAltName}; use x509::name::RdnSequence; @@ -71,7 +72,8 @@ const BUNDLE: &str = "application/vnd.steward.pkcs10-bundle.v1"; /// The configuration file must contain valid TOML table mapping argument /// names to their values. #[derive(Clone, Debug, Parser)] -#[command(author, version, about)] +#[clap(author, version, about)] +#[cfg_attr(feature = "insecure", clap(about = "Insecure Mode", long_about = None))] struct Args { #[arg(short, long, env = "STEWARD_KEY")] key: Option, @@ -85,7 +87,7 @@ struct Args { #[arg(short, long, env = "ROCKET_ADDRESS", default_value = "::")] addr: IpAddr, - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[clap(short, long, env = "RENDER_EXTERNAL_HOSTNAME")] host: Option, @@ -147,9 +149,9 @@ impl State { // Validate the syntax of the files. PrivateKeyInfo::from_der(key.as_ref())?; - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] Certificate::from_der(crt.as_ref())?; - #[cfg(not(debug_assertions))] + #[cfg(not(feature = "insecure"))] { let cert = Certificate::from_der(crt.as_ref())?; let iss = &cert.tbs_certificate; @@ -162,7 +164,7 @@ impl State { Ok(State { crt, san, key }) } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] pub fn generate(san: Option, hostname: &str) -> anyhow::Result { use const_oid::db::rfc5912::SECP_256_R_1 as P256; @@ -233,11 +235,14 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); } + #[cfg(feature = "insecure")] + println!("Running in insecure mode."); + let args = confargs::args::(prefix_char_filter::<'@'>) .context("Failed to parse config") .map(Args::parse_from)?; - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] let state = match (args.key, args.crt, args.host) { (None, None, Some(host)) => State::generate(args.san, &host)?, (Some(key), Some(crt), _) => State::load(args.san, key, crt)?, @@ -247,7 +252,7 @@ async fn main() -> anyhow::Result<()> { } }; - #[cfg(not(debug_assertions))] + #[cfg(not(feature = "insecure"))] let state = match (args.key, args.crt) { (Some(key), Some(crt)) => State::load(args.san, key, crt)?, _ => { @@ -344,14 +349,6 @@ fn attest_request( StatusCode::BAD_REQUEST })?; - let dbg = if cfg!(debug_assertions) { - // If the issuer is self-signed, we are in debug mode. - let iss = &issuer.tbs_certificate; - iss.issuer_unique_id == iss.subject_unique_id && iss.issuer == iss.subject - } else { - false - }; - let mut extensions = Vec::new(); let mut attested = false; for Attribute { oid, values } in info.attributes.iter() { @@ -367,9 +364,10 @@ fn attest_request( for ext in Vec::from(ereq) { // Validate the extension. let (copy, att) = match ext.extn_id { - Kvm::OID => (Kvm::default().verify(&info, &ext, dbg), Kvm::ATT), - Sgx::OID => (Sgx::default().verify(&info, &ext, dbg), Sgx::ATT), - Snp::OID => (Snp::default().verify(&info, &ext, dbg), Snp::ATT), + #[cfg(feature = "insecure")] + Kvm::OID => (Kvm::default().verify(&info, &ext), Kvm::ATT), + Sgx::OID => (Sgx::default().verify(&info, &ext), Sgx::ATT), + Snp::OID => (Snp::default().verify(&info, &ext), Snp::ATT), oid => { debug!("extension `{oid}` is unsupported"); return Err(StatusCode::BAD_REQUEST); @@ -503,21 +501,21 @@ async fn attest( mod tests { mod attest { use crate::ext::{kvm::Kvm, ExtVerifier}; - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] use crate::ext::{sgx::Sgx, snp::Snp}; use crate::*; - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] use const_oid::db::rfc5912::SECP_384_R_1; use const_oid::db::rfc5912::{ID_EXTENSION_REQ, SECP_256_R_1}; use const_oid::ObjectIdentifier; use der::{AnyRef, Encode}; use x509::attr::Attribute; use x509::request::{CertReq, CertReqInfo, ExtensionReq}; - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] use x509::PkiPath; use x509::{ext::Extension, name::RdnSequence}; - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] use axum::response::Response; use http::header::CONTENT_TYPE; use http::Request; @@ -528,17 +526,17 @@ mod tests { fn certificates_state() -> State { #[cfg(not(target_os = "wasi"))] { - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] return State::load(None, "testdata/ca.key", "testdata/ca.crt") .expect("failed to load state"); - #[cfg(not(debug_assertions))] + #[cfg(not(feature = "insecure"))] return State::load(None, "testdata/test.key", "testdata/test.crt") .expect("failed to load state"); } #[cfg(target_os = "wasi")] { - let (crt, key) = if cfg!(debug_assertions) { + let (crt, key) = if cfg!(feature = "insecure") { ( std::io::BufReader::new(include_bytes!("../testdata/ca.crt").as_slice()), std::io::BufReader::new(include_bytes!("../testdata/ca.key").as_slice()), @@ -554,7 +552,7 @@ mod tests { } } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] fn hostname_state() -> State { State::generate(None, "localhost").unwrap() } @@ -588,7 +586,7 @@ mod tests { } } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] async fn attest_response(state: State, response: Response, multi: bool) { let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); @@ -627,6 +625,7 @@ mod tests { assert_eq!(encoded, reencoded); } + #[cfg(feature = "insecure")] #[rstest] #[case(PKCS10, false)] #[case(BUNDLE, true)] @@ -646,18 +645,11 @@ mod tests { .unwrap(); let response = app(certificates_state()).oneshot(request).await.unwrap(); - #[cfg(debug_assertions)] - { - assert_eq!(response.status(), StatusCode::OK); - attest_response(certificates_state(), response, multi).await; - } - #[cfg(not(debug_assertions))] - { - assert_eq!(response.status(), StatusCode::BAD_REQUEST); - } + assert_eq!(response.status(), StatusCode::OK); + attest_response(certificates_state(), response, multi).await; } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[rstest] #[case(PKCS10, false)] #[case(BUNDLE, true)] @@ -684,7 +676,7 @@ mod tests { // Though similar to the above test, this is the only test which // actually sends many CSRs, versus an array of just one CSR. - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[tokio::test] async fn kvm_hostname_many_certs() { let ext = Extension { @@ -722,7 +714,7 @@ mod tests { assert_eq!(output.issued.len(), five_crs.len()); } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[rstest] #[case(PKCS10, false)] #[case(BUNDLE, true)] @@ -751,7 +743,7 @@ mod tests { } } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[rstest] #[case(PKCS10, false)] #[case(BUNDLE, true)] @@ -781,7 +773,7 @@ mod tests { } } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[rstest] #[case(PKCS10, false)] #[case(BUNDLE, true)] @@ -812,7 +804,7 @@ mod tests { attest_response(certificates_state(), response, multi).await; } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[rstest] #[case(PKCS10, false)] #[case(BUNDLE, true)] @@ -844,7 +836,7 @@ mod tests { attest_response(state, response, multi).await; } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[rstest] #[case(PKCS10, false)] #[case(BUNDLE, true)] @@ -861,7 +853,7 @@ mod tests { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[tokio::test] async fn err_no_attestation_hostname() { let request = Request::builder() @@ -875,7 +867,7 @@ mod tests { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[rstest] #[case(false)] #[case(true)] @@ -891,7 +883,7 @@ mod tests { assert_eq!(response.status(), StatusCode::BAD_REQUEST); } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[tokio::test] async fn err_empty_body() { let request = Request::builder() @@ -904,7 +896,7 @@ mod tests { assert_eq!(response.status(), StatusCode::BAD_REQUEST); } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[tokio::test] async fn err_bad_body() { let request = Request::builder() @@ -917,7 +909,7 @@ mod tests { assert_eq!(response.status(), StatusCode::BAD_REQUEST); } - #[cfg(debug_assertions)] + #[cfg(feature = "insecure")] #[tokio::test] async fn err_bad_csr_sig() { let mut cr = cr(SECP_256_R_1, vec![], true);