From c2f55d5f2b5167edf405637b2e2049abe71e5c91 Mon Sep 17 00:00:00 2001 From: Justus Ge Date: Mon, 6 Mar 2023 15:44:59 -0800 Subject: [PATCH] 1/2: Rust oneway tls support (#112) Summary: Pull Request resolved: https://github.com/facebookresearch/Private-ID/pull/112 Add support for rust PID binary oneway TLS Reviewed By: danbunnell Differential Revision: D43613367 fbshipit-source-id: cbeb8de110c2e4e9b3ec2c6ff786083cdb4ab300 --- protocol-rpc/src/connect/create_client.rs | 78 +++++++++-- protocol-rpc/src/connect/create_server.rs | 81 ++++++++++-- protocol-rpc/src/connect/tls.rs | 124 +++++++++++++++++- protocol-rpc/src/rpc/cross-psi-xor/client.rs | 4 +- protocol-rpc/src/rpc/cross-psi-xor/server.rs | 2 - protocol-rpc/src/rpc/cross-psi/client.rs | 4 +- protocol-rpc/src/rpc/cross-psi/server.rs | 2 - protocol-rpc/src/rpc/pjc/client.rs | 4 +- protocol-rpc/src/rpc/pjc/server.rs | 2 - .../src/rpc/private-id-multi-key/client.rs | 4 +- .../src/rpc/private-id-multi-key/server.rs | 2 - protocol-rpc/src/rpc/private-id/client.rs | 4 +- protocol-rpc/src/rpc/private-id/server.rs | 2 - protocol-rpc/src/rpc/suid-create/client.rs | 4 +- protocol-rpc/src/rpc/suid-create/server.rs | 2 - 15 files changed, 266 insertions(+), 53 deletions(-) diff --git a/protocol-rpc/src/connect/create_client.rs b/protocol-rpc/src/connect/create_client.rs index 5876e49..6d2a23b 100644 --- a/protocol-rpc/src/connect/create_client.rs +++ b/protocol-rpc/src/connect/create_client.rs @@ -1,6 +1,7 @@ // Copyright (c) Facebook, Inc. and its affiliates. // SPDX-License-Identifier: Apache-2.0 +use std::env; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::sync::Arc; @@ -37,13 +38,24 @@ pub fn create_client( } else { match (tls_dir, tls_key, tls_cert, tls_ca) { (Some(d), None, None, None) => { - info!("using dir for tls files {}", d); + info!("using dir for TLS files {}", d); Some(tls::TlsContext::from_dir(d, false)) } + // Two-way TLS support (None, Some(key), Some(cert), Some(ca)) => { - debug!("using paths directly to read the files"); + debug!("using paths directly to read TLS files"); Some(tls::TlsContext::from_paths(key, cert, ca)) } + // One-way TLS support + (None, None, None, Some(ca)) => { + let full_ca_path = if env::var("HOME").is_ok() { + env::var("HOME").unwrap() + "/" + ca + } else { + "/".to_owned() + ca + }; + info!("full ca path: {}", full_ca_path); + Some(tls::TlsContext::from_path_client(full_ca_path.as_str())) + } _ => { let msg = "Supporting --tls-dir together with direct paths is not supported yet"; error!("{}", msg); @@ -62,7 +74,7 @@ pub fn create_client( .domain() .unwrap_or_else(|| { panic!( - "Cannot extract domain neither from host {}\ + "Cannot extract domain neither from host {} \ nor --tls-domain arg was specified", host ) @@ -71,16 +83,26 @@ pub fn create_client( }; info!( - "tls domain name: {} (--tls-domain can can override)", + "tls domain name: {} (--tls-domain can override)", domain_name ); - Some( - ClientTlsConfig::new() - .domain_name(domain_name) - .identity(ctx.identity) - .ca_certificate(ctx.ca), - ) + if ctx.identity.is_some() { + // Two-way TLS + Some( + ClientTlsConfig::new() + .domain_name(domain_name) + .identity(ctx.identity.unwrap()) + .ca_certificate(ctx.ca.unwrap()), + ) + } else { + // One-way TLS + Some( + ClientTlsConfig::new() + .domain_name(domain_name) + .ca_certificate(ctx.ca.unwrap()), + ) + } } None => None, }; @@ -181,4 +203,40 @@ mod test { "private-id-multi-key".to_string(), ); } + + #[test] + #[should_panic(expected = "ca.pem not found")] + fn test_create_client_with_oneway_tls() { + use std::fs::File; + use std::io::Write; + + use tempfile::tempdir; + + // Create a directory inside of `std::env::temp_dir()`. + let dir = tempdir().unwrap(); + use rcgen::*; + let ca_subject_alt_names: &[_] = &["ca.world.example".to_string(), "localhost".to_string()]; + + let ca_cert = generate_simple_self_signed(ca_subject_alt_names).unwrap(); + let ca_pem = ca_cert.serialize_pem().unwrap(); + + let file_path_ca_pem = dir.path().join("ca.pem"); + let mut file_ca_pem = File::create(file_path_ca_pem).unwrap(); + file_ca_pem.write_all(ca_pem.as_bytes()).unwrap(); + + // create_client will use HOME env as the prefix of path, not temp dir, it will throw pem not found error + let _ = create_client( + false, + Some("localhost:10009"), + None, + None, + None, + Some("ca.pem"), + Some("localhost"), + "private-id-multi-key".to_string(), + ); + + drop(file_ca_pem); + dir.close().unwrap(); + } } diff --git a/protocol-rpc/src/connect/create_server.rs b/protocol-rpc/src/connect/create_server.rs index c1a02a7..2be9fd5 100644 --- a/protocol-rpc/src/connect/create_server.rs +++ b/protocol-rpc/src/connect/create_server.rs @@ -1,6 +1,8 @@ // Copyright (c) Facebook, Inc. and its affiliates. // SPDX-License-Identifier: Apache-2.0 +use std::env; + use log::info; use log::warn; use tonic::transport::Server; @@ -25,13 +27,33 @@ pub fn create_server( } else { match (tls_dir, tls_key, tls_cert, tls_ca) { (Some(d), None, None, None) => { - info!("using dir for tls files {}", d); + info!("using dir for TLS files {}", d); Some(tls::TlsContext::from_dir(d, true)) } + // Two-way TLS (None, Some(key), Some(cert), Some(ca)) => { - debug!("using paths diretcly to read the files"); + debug!("using paths directly to read the files"); Some(tls::TlsContext::from_paths(key, cert, ca)) } + // One-way TLS + (None, Some(key), Some(cert), None) => { + let full_key_path = if env::var("HOME").is_ok() { + env::var("HOME").unwrap() + "/" + key + } else { + "/".to_owned() + key + }; + let full_cert_path = if env::var("HOME").is_ok() { + env::var("HOME").unwrap() + "/" + cert + } else { + "/".to_owned() + cert + }; + info!("full key path: {}", full_key_path); + info!("full cert path: {}", full_cert_path); + Some(tls::TlsContext::from_paths_server( + full_key_path.as_str(), + full_cert_path.as_str(), + )) + } _ => { let msg = "Supporting --tls-dir together with direct paths is not supported yet"; error!("{}", msg); @@ -49,13 +71,21 @@ pub fn create_server( server = match tls_context { Some(ctx) => { info!("Starting server with TLS support"); - server - .tls_config( - ServerTlsConfig::new() - .identity(ctx.identity) - .client_ca_root(ctx.ca), - ) - .unwrap() + if ctx.ca.is_some() { + // Two-way TLS + server + .tls_config( + ServerTlsConfig::new() + .identity(ctx.identity.unwrap()) + .client_ca_root(ctx.ca.unwrap()), + ) + .unwrap() + } else { + // One-way TLS + server + .tls_config(ServerTlsConfig::new().identity(ctx.identity.unwrap())) + .unwrap() + } } None => server, }; @@ -76,4 +106,37 @@ mod test { fn test_create_server_tls_panic() { let _ = create_server(false, None, None, None, None); } + + #[test] + #[should_panic(expected = "private.key not found")] + fn test_create_server_with_oneway_tls() { + use std::fs::File; + use std::io::Write; + + use tempfile::tempdir; + + // Create a directory inside of `std::env::temp_dir()`. + let dir = tempdir().unwrap(); + use rcgen::*; + let subject_alt_names: &[_] = &["hello.world.example".to_string(), "localhost".to_string()]; + + let server_cert = generate_simple_self_signed(subject_alt_names).unwrap(); + let server_pem = server_cert.serialize_pem().unwrap(); + let private_key = server_cert.serialize_private_key_pem(); + + let file_path_server_pem = dir.path().join("server.pem"); + let mut file_server_pem = File::create(file_path_server_pem).unwrap(); + file_server_pem.write_all(server_pem.as_bytes()).unwrap(); + + let file_path_private_key = dir.path().join("private.key"); + let mut file_private_key = File::create(file_path_private_key).unwrap(); + file_private_key.write_all(private_key.as_bytes()).unwrap(); + + // create_server will use HOME env as the prefix of path, not temp dir, it will throw key not found error + let _ = create_server(false, None, Some("private.key"), Some("server.pem"), None); + + drop(file_server_pem); + drop(file_private_key); + dir.close().unwrap(); + } } diff --git a/protocol-rpc/src/connect/tls.rs b/protocol-rpc/src/connect/tls.rs index 313cee9..cc1b9e9 100644 --- a/protocol-rpc/src/connect/tls.rs +++ b/protocol-rpc/src/connect/tls.rs @@ -14,8 +14,8 @@ use url::Url; #[derive(Clone, Debug)] pub struct TlsContext { - pub identity: tonic::transport::Identity, - pub ca: tonic::transport::Certificate, + pub identity: Option, + pub ca: Option, } impl TlsContext { @@ -47,8 +47,8 @@ impl TlsContext { debug!("Successfully read key, cert and CA cert"); TlsContext { - identity: Identity::from_pem(cert, key), - ca: Certificate::from_pem(ca), + identity: Some(Identity::from_pem(cert, key)), + ca: Some(Certificate::from_pem(ca)), } } @@ -92,6 +92,63 @@ impl TlsContext { TlsContext::from_paths(key_path.as_path(), cert_path.as_path(), ca_path.as_path()) } + + /// Construct TlsContext for Client from corresponding file in the system + /// panics if file not found + pub fn from_path_client(ca_path: T) -> TlsContext + where + T: AsRef + Copy, + { + info!("Reading TLS file, ca: {}", ca_path.as_ref().display()); + [ca_path].iter().for_each(|p| { + if !Path::new(p.as_ref()).exists() { + panic!("File {} not found", p.as_ref().display()) + } + }); + + let z = async { + let ca = tokio::fs::read(ca_path).await.unwrap(); + ca + }; + let ca = block_on(z); + info!("Successfully read CA cert"); + + TlsContext { + identity: None, + ca: Some(Certificate::from_pem(ca)), + } + } + + /// Construct TlsContext for Serever from corresponding files in the system + /// panics if file not found + pub fn from_paths_server(key_path: T, cert_path: T) -> TlsContext + where + T: AsRef + Copy, + { + info!( + "Reading TLS files, key: {}, cert: {}", + key_path.as_ref().display(), + cert_path.as_ref().display(), + ); + [key_path, cert_path].iter().for_each(|p| { + if !Path::new(p.as_ref()).exists() { + panic!("File {} not found", p.as_ref().display()) + } + }); + + let z = async { + let key = tokio::fs::read(key_path).await.unwrap(); + let cert = tokio::fs::read(cert_path).await.unwrap(); + (key, cert) + }; + let (key, cert) = block_on(z); + debug!("Successfully read key and cert"); + + TlsContext { + identity: Some(Identity::from_pem(cert, key)), + ca: None, + } + } } /// Converts host string into URI object @@ -191,4 +248,63 @@ mod tests { drop(file_client_key); dir.close().unwrap(); } + + #[tokio::test] + async fn test_from_path_client() { + use std::fs::File; + use std::io::Write; + + use tempfile::tempdir; + + // Create a directory inside of `std::env::temp_dir()`. + let dir = tempdir().unwrap(); + use rcgen::*; + let ca_subject_alt_names: &[_] = &["ca.world.example".to_string(), "localhost".to_string()]; + + let ca_cert = generate_simple_self_signed(ca_subject_alt_names).unwrap(); + let ca_pem = ca_cert.serialize_pem().unwrap(); + + let file_path_ca_pem = dir.path().join("ca.pem"); + let mut file_ca_pem = File::create(file_path_ca_pem).unwrap(); + file_ca_pem.write_all(ca_pem.as_bytes()).unwrap(); + + let _ = TlsContext::from_path_client(dir.path().join(Path::new("ca.pem")).as_path()); + + drop(file_ca_pem); + dir.close().unwrap(); + } + + #[tokio::test] + async fn test_from_path_server() { + use std::fs::File; + use std::io::Write; + + use tempfile::tempdir; + + // Create a directory inside of `std::env::temp_dir()`. + let dir = tempdir().unwrap(); + use rcgen::*; + let subject_alt_names: &[_] = &["hello.world.example".to_string(), "localhost".to_string()]; + + let server_cert = generate_simple_self_signed(subject_alt_names).unwrap(); + let server_pem = server_cert.serialize_pem().unwrap(); + let private_key = server_cert.serialize_private_key_pem(); + + let file_path_server_pem = dir.path().join("server.pem"); + let mut file_server_pem = File::create(file_path_server_pem).unwrap(); + file_server_pem.write_all(server_pem.as_bytes()).unwrap(); + + let file_path_private_key = dir.path().join("private.key"); + let mut file_private_key = File::create(file_path_private_key).unwrap(); + file_private_key.write_all(private_key.as_bytes()).unwrap(); + + let _ = TlsContext::from_paths_server( + dir.path().join(Path::new("private.key")).as_path(), + dir.path().join(Path::new("server.pem")).as_path(), + ); + + drop(file_server_pem); + drop(file_private_key); + dir.close().unwrap(); + } } diff --git a/protocol-rpc/src/rpc/cross-psi-xor/client.rs b/protocol-rpc/src/rpc/cross-psi-xor/client.rs index 46076d2..5f5a61b 100644 --- a/protocol-rpc/src/rpc/cross-psi-xor/client.rs +++ b/protocol-rpc/src/rpc/cross-psi-xor/client.rs @@ -95,8 +95,6 @@ async fn main() -> Result<(), Box> { Arg::with_name("tls-ca") .long("tls-ca") .takes_value(true) - .requires("tls-key") - .requires("tls-cert") .help("Path to root CA certificate issued cert and keys"), Arg::with_name("tls-domain") .long("tls-domain") @@ -105,7 +103,7 @@ async fn main() -> Result<(), Box> { ]) .groups(&[ ArgGroup::with_name("tls") - .args(&["no-tls", "tls-dir", "tls-key"]) + .args(&["no-tls", "tls-dir", "tls-ca"]) .required(true), ArgGroup::with_name("out") .args(&["output", "stdout"]) diff --git a/protocol-rpc/src/rpc/cross-psi-xor/server.rs b/protocol-rpc/src/rpc/cross-psi-xor/server.rs index fc4624f..17a8806 100644 --- a/protocol-rpc/src/rpc/cross-psi-xor/server.rs +++ b/protocol-rpc/src/rpc/cross-psi-xor/server.rs @@ -72,13 +72,11 @@ async fn main() -> Result<(), Box> { .long("tls-key") .takes_value(true) .requires("tls-cert") - .requires("tls-ca") .help("Path to tls key (non-encrypted)"), Arg::with_name("tls-cert") .long("tls-cert") .takes_value(true) .requires("tls-key") - .requires("tls-ca") .help( "Path to tls certificate (pem format), SINGLE cert, \ NO CHAINING, required by client as well", diff --git a/protocol-rpc/src/rpc/cross-psi/client.rs b/protocol-rpc/src/rpc/cross-psi/client.rs index 96e3751..708fcc9 100644 --- a/protocol-rpc/src/rpc/cross-psi/client.rs +++ b/protocol-rpc/src/rpc/cross-psi/client.rs @@ -85,8 +85,6 @@ async fn main() -> Result<(), Box> { Arg::with_name("tls-ca") .long("tls-ca") .takes_value(true) - .requires("tls-key") - .requires("tls-cert") .help("Path to root CA certificate issued cert and keys"), Arg::with_name("tls-domain") .long("tls-domain") @@ -95,7 +93,7 @@ async fn main() -> Result<(), Box> { ]) .groups(&[ ArgGroup::with_name("tls") - .args(&["no-tls", "tls-dir", "tls-key"]) + .args(&["no-tls", "tls-dir", "tls-ca"]) .required(true), ArgGroup::with_name("out") .args(&["output", "stdout"]) diff --git a/protocol-rpc/src/rpc/cross-psi/server.rs b/protocol-rpc/src/rpc/cross-psi/server.rs index 08282e6..e4b89af 100644 --- a/protocol-rpc/src/rpc/cross-psi/server.rs +++ b/protocol-rpc/src/rpc/cross-psi/server.rs @@ -69,13 +69,11 @@ async fn main() -> Result<(), Box> { .long("tls-key") .takes_value(true) .requires("tls-cert") - .requires("tls-ca") .help("Path to tls key (non-encrypted)"), Arg::with_name("tls-cert") .long("tls-cert") .takes_value(true) .requires("tls-key") - .requires("tls-ca") .help( "Path to tls certificate (pem format), SINGLE cert, \ NO CHAINING, required by client as well", diff --git a/protocol-rpc/src/rpc/pjc/client.rs b/protocol-rpc/src/rpc/pjc/client.rs index 8b99415..f7ce54d 100644 --- a/protocol-rpc/src/rpc/pjc/client.rs +++ b/protocol-rpc/src/rpc/pjc/client.rs @@ -81,8 +81,6 @@ async fn main() -> Result<(), Box> { Arg::with_name("tls-ca") .long("tls-ca") .takes_value(true) - .requires("tls-key") - .requires("tls-cert") .help("Path to root CA certificate issued cert and keys"), Arg::with_name("tls-domain") .long("tls-domain") @@ -91,7 +89,7 @@ async fn main() -> Result<(), Box> { ]) .groups(&[ ArgGroup::with_name("tls") - .args(&["no-tls", "tls-dir", "tls-key"]) + .args(&["no-tls", "tls-dir", "tls-ca"]) .required(true), ArgGroup::with_name("out") .args(&["output", "stdout"]) diff --git a/protocol-rpc/src/rpc/pjc/server.rs b/protocol-rpc/src/rpc/pjc/server.rs index e57ea34..fbe1448 100644 --- a/protocol-rpc/src/rpc/pjc/server.rs +++ b/protocol-rpc/src/rpc/pjc/server.rs @@ -65,13 +65,11 @@ async fn main() -> Result<(), Box> { .long("tls-key") .takes_value(true) .requires("tls-cert") - .requires("tls-ca") .help("Path to tls key (non-encrypted)"), Arg::with_name("tls-cert") .long("tls-cert") .takes_value(true) .requires("tls-key") - .requires("tls-ca") .help( "Path to tls certificate (pem format), SINGLE cert, \ NO CHAINING, required by client as well", diff --git a/protocol-rpc/src/rpc/private-id-multi-key/client.rs b/protocol-rpc/src/rpc/private-id-multi-key/client.rs index f52e8d4..b1078e5 100644 --- a/protocol-rpc/src/rpc/private-id-multi-key/client.rs +++ b/protocol-rpc/src/rpc/private-id-multi-key/client.rs @@ -93,8 +93,6 @@ async fn main() -> Result<(), Box> { Arg::with_name("tls-ca") .long("tls-ca") .takes_value(true) - .requires("tls-key") - .requires("tls-cert") .help("Path to root CA certificate issued cert and keys"), Arg::with_name("tls-domain") .long("tls-domain") @@ -108,7 +106,7 @@ async fn main() -> Result<(), Box> { ]) .groups(&[ ArgGroup::with_name("tls") - .args(&["no-tls", "tls-dir", "tls-key"]) + .args(&["no-tls", "tls-dir", "tls-ca"]) .required(true), ArgGroup::with_name("out") .args(&["output", "stdout"]) diff --git a/protocol-rpc/src/rpc/private-id-multi-key/server.rs b/protocol-rpc/src/rpc/private-id-multi-key/server.rs index 74af762..5312095 100644 --- a/protocol-rpc/src/rpc/private-id-multi-key/server.rs +++ b/protocol-rpc/src/rpc/private-id-multi-key/server.rs @@ -77,13 +77,11 @@ async fn main() -> Result<(), Box> { .long("tls-key") .takes_value(true) .requires("tls-cert") - .requires("tls-ca") .help("Path to tls key (non-encrypted)"), Arg::with_name("tls-cert") .long("tls-cert") .takes_value(true) .requires("tls-key") - .requires("tls-ca") .help( "Path to tls certificate (pem format), SINGLE cert, \ NO CHAINING, required by client as well", diff --git a/protocol-rpc/src/rpc/private-id/client.rs b/protocol-rpc/src/rpc/private-id/client.rs index 411947c..2122d22 100644 --- a/protocol-rpc/src/rpc/private-id/client.rs +++ b/protocol-rpc/src/rpc/private-id/client.rs @@ -93,8 +93,6 @@ async fn main() -> Result<(), Box> { Arg::with_name("tls-ca") .long("tls-ca") .takes_value(true) - .requires("tls-key") - .requires("tls-cert") .help("Path to root CA certificate issued cert and keys"), Arg::with_name("tls-domain") .long("tls-domain") @@ -116,7 +114,7 @@ async fn main() -> Result<(), Box> { ]) .groups(&[ ArgGroup::with_name("tls") - .args(&["no-tls", "tls-dir", "tls-key"]) + .args(&["no-tls", "tls-dir", "tls-ca"]) .required(true), ArgGroup::with_name("out") .args(&["output", "stdout"]) diff --git a/protocol-rpc/src/rpc/private-id/server.rs b/protocol-rpc/src/rpc/private-id/server.rs index 05f6287..d7f200b 100644 --- a/protocol-rpc/src/rpc/private-id/server.rs +++ b/protocol-rpc/src/rpc/private-id/server.rs @@ -77,13 +77,11 @@ async fn main() -> Result<(), Box> { .long("tls-key") .takes_value(true) .requires("tls-cert") - .requires("tls-ca") .help("Path to tls key (non-encrypted)"), Arg::with_name("tls-cert") .long("tls-cert") .takes_value(true) .requires("tls-key") - .requires("tls-ca") .help( "Path to tls certificate (pem format), SINGLE cert, \ NO CHAINING, required by client as well", diff --git a/protocol-rpc/src/rpc/suid-create/client.rs b/protocol-rpc/src/rpc/suid-create/client.rs index f7e8a41..7c35894 100644 --- a/protocol-rpc/src/rpc/suid-create/client.rs +++ b/protocol-rpc/src/rpc/suid-create/client.rs @@ -86,8 +86,6 @@ async fn main() -> Result<(), Box> { Arg::with_name("tls-ca") .long("tls-ca") .takes_value(true) - .requires("tls-key") - .requires("tls-cert") .help("Path to root CA certificate issued cert and keys"), Arg::with_name("tls-domain") .long("tls-domain") @@ -96,7 +94,7 @@ async fn main() -> Result<(), Box> { ]) .groups(&[ ArgGroup::with_name("tls") - .args(&["no-tls", "tls-dir", "tls-key"]) + .args(&["no-tls", "tls-dir", "tls-ca"]) .required(true), ArgGroup::with_name("out") .args(&["output", "stdout"]) diff --git a/protocol-rpc/src/rpc/suid-create/server.rs b/protocol-rpc/src/rpc/suid-create/server.rs index b3551d6..97e32f1 100644 --- a/protocol-rpc/src/rpc/suid-create/server.rs +++ b/protocol-rpc/src/rpc/suid-create/server.rs @@ -71,13 +71,11 @@ async fn main() -> Result<(), Box> { .long("tls-key") .takes_value(true) .requires("tls-cert") - .requires("tls-ca") .help("Path to tls key (non-encrypted)"), Arg::with_name("tls-cert") .long("tls-cert") .takes_value(true) .requires("tls-key") - .requires("tls-ca") .help( "Path to tls certificate (pem format), SINGLE cert, \ NO CHAINING, required by client as well",