Skip to content

Commit

Permalink
1/2: Rust oneway tls support (#112)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #112

Add support for rust PID binary oneway TLS

Reviewed By: danbunnell

Differential Revision: D43613367

fbshipit-source-id: cbeb8de110c2e4e9b3ec2c6ff786083cdb4ab300
  • Loading branch information
Justus Ge authored and facebook-github-bot committed Mar 6, 2023
1 parent e5b69d3 commit c2f55d5
Show file tree
Hide file tree
Showing 15 changed files with 266 additions and 53 deletions.
78 changes: 68 additions & 10 deletions protocol-rpc/src/connect/create_client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// SPDX-License-Identifier: Apache-2.0

use std::env;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
Expand Down Expand Up @@ -37,13 +38,24 @@ pub fn create_client(
} else {
match (tls_dir, tls_key, tls_cert, tls_ca) {
(Some(d), None, None, None) => {
info!("using dir for tls files {}", d);
info!("using dir for TLS files {}", d);
Some(tls::TlsContext::from_dir(d, false))
}
// Two-way TLS support
(None, Some(key), Some(cert), Some(ca)) => {
debug!("using paths directly to read the files");
debug!("using paths directly to read TLS files");
Some(tls::TlsContext::from_paths(key, cert, ca))
}
// One-way TLS support
(None, None, None, Some(ca)) => {
let full_ca_path = if env::var("HOME").is_ok() {
env::var("HOME").unwrap() + "/" + ca
} else {
"/".to_owned() + ca
};
info!("full ca path: {}", full_ca_path);
Some(tls::TlsContext::from_path_client(full_ca_path.as_str()))
}
_ => {
let msg = "Supporting --tls-dir together with direct paths is not supported yet";
error!("{}", msg);
Expand All @@ -62,7 +74,7 @@ pub fn create_client(
.domain()
.unwrap_or_else(|| {
panic!(
"Cannot extract domain neither from host {}\
"Cannot extract domain neither from host {} \
nor --tls-domain arg was specified",
host
)
Expand All @@ -71,16 +83,26 @@ pub fn create_client(
};

info!(
"tls domain name: {} (--tls-domain can can override)",
"tls domain name: {} (--tls-domain can override)",
domain_name
);

Some(
ClientTlsConfig::new()
.domain_name(domain_name)
.identity(ctx.identity)
.ca_certificate(ctx.ca),
)
if ctx.identity.is_some() {
// Two-way TLS
Some(
ClientTlsConfig::new()
.domain_name(domain_name)
.identity(ctx.identity.unwrap())
.ca_certificate(ctx.ca.unwrap()),
)
} else {
// One-way TLS
Some(
ClientTlsConfig::new()
.domain_name(domain_name)
.ca_certificate(ctx.ca.unwrap()),
)
}
}
None => None,
};
Expand Down Expand Up @@ -181,4 +203,40 @@ mod test {
"private-id-multi-key".to_string(),
);
}

#[test]
#[should_panic(expected = "ca.pem not found")]
fn test_create_client_with_oneway_tls() {
use std::fs::File;
use std::io::Write;

use tempfile::tempdir;

// Create a directory inside of `std::env::temp_dir()`.
let dir = tempdir().unwrap();
use rcgen::*;
let ca_subject_alt_names: &[_] = &["ca.world.example".to_string(), "localhost".to_string()];

let ca_cert = generate_simple_self_signed(ca_subject_alt_names).unwrap();
let ca_pem = ca_cert.serialize_pem().unwrap();

let file_path_ca_pem = dir.path().join("ca.pem");
let mut file_ca_pem = File::create(file_path_ca_pem).unwrap();
file_ca_pem.write_all(ca_pem.as_bytes()).unwrap();

// create_client will use HOME env as the prefix of path, not temp dir, it will throw pem not found error
let _ = create_client(
false,
Some("localhost:10009"),
None,
None,
None,
Some("ca.pem"),
Some("localhost"),
"private-id-multi-key".to_string(),
);

drop(file_ca_pem);
dir.close().unwrap();
}
}
81 changes: 72 additions & 9 deletions protocol-rpc/src/connect/create_server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// SPDX-License-Identifier: Apache-2.0

use std::env;

use log::info;
use log::warn;
use tonic::transport::Server;
Expand All @@ -25,13 +27,33 @@ pub fn create_server(
} else {
match (tls_dir, tls_key, tls_cert, tls_ca) {
(Some(d), None, None, None) => {
info!("using dir for tls files {}", d);
info!("using dir for TLS files {}", d);
Some(tls::TlsContext::from_dir(d, true))
}
// Two-way TLS
(None, Some(key), Some(cert), Some(ca)) => {
debug!("using paths diretcly to read the files");
debug!("using paths directly to read the files");
Some(tls::TlsContext::from_paths(key, cert, ca))
}
// One-way TLS
(None, Some(key), Some(cert), None) => {
let full_key_path = if env::var("HOME").is_ok() {
env::var("HOME").unwrap() + "/" + key
} else {
"/".to_owned() + key
};
let full_cert_path = if env::var("HOME").is_ok() {
env::var("HOME").unwrap() + "/" + cert
} else {
"/".to_owned() + cert
};
info!("full key path: {}", full_key_path);
info!("full cert path: {}", full_cert_path);
Some(tls::TlsContext::from_paths_server(
full_key_path.as_str(),
full_cert_path.as_str(),
))
}
_ => {
let msg = "Supporting --tls-dir together with direct paths is not supported yet";
error!("{}", msg);
Expand All @@ -49,13 +71,21 @@ pub fn create_server(
server = match tls_context {
Some(ctx) => {
info!("Starting server with TLS support");
server
.tls_config(
ServerTlsConfig::new()
.identity(ctx.identity)
.client_ca_root(ctx.ca),
)
.unwrap()
if ctx.ca.is_some() {
// Two-way TLS
server
.tls_config(
ServerTlsConfig::new()
.identity(ctx.identity.unwrap())
.client_ca_root(ctx.ca.unwrap()),
)
.unwrap()
} else {
// One-way TLS
server
.tls_config(ServerTlsConfig::new().identity(ctx.identity.unwrap()))
.unwrap()
}
}
None => server,
};
Expand All @@ -76,4 +106,37 @@ mod test {
fn test_create_server_tls_panic() {
let _ = create_server(false, None, None, None, None);
}

#[test]
#[should_panic(expected = "private.key not found")]
fn test_create_server_with_oneway_tls() {
use std::fs::File;
use std::io::Write;

use tempfile::tempdir;

// Create a directory inside of `std::env::temp_dir()`.
let dir = tempdir().unwrap();
use rcgen::*;
let subject_alt_names: &[_] = &["hello.world.example".to_string(), "localhost".to_string()];

let server_cert = generate_simple_self_signed(subject_alt_names).unwrap();
let server_pem = server_cert.serialize_pem().unwrap();
let private_key = server_cert.serialize_private_key_pem();

let file_path_server_pem = dir.path().join("server.pem");
let mut file_server_pem = File::create(file_path_server_pem).unwrap();
file_server_pem.write_all(server_pem.as_bytes()).unwrap();

let file_path_private_key = dir.path().join("private.key");
let mut file_private_key = File::create(file_path_private_key).unwrap();
file_private_key.write_all(private_key.as_bytes()).unwrap();

// create_server will use HOME env as the prefix of path, not temp dir, it will throw key not found error
let _ = create_server(false, None, Some("private.key"), Some("server.pem"), None);

drop(file_server_pem);
drop(file_private_key);
dir.close().unwrap();
}
}
124 changes: 120 additions & 4 deletions protocol-rpc/src/connect/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use url::Url;

#[derive(Clone, Debug)]
pub struct TlsContext {
pub identity: tonic::transport::Identity,
pub ca: tonic::transport::Certificate,
pub identity: Option<tonic::transport::Identity>,
pub ca: Option<tonic::transport::Certificate>,
}

impl TlsContext {
Expand Down Expand Up @@ -47,8 +47,8 @@ impl TlsContext {
debug!("Successfully read key, cert and CA cert");

TlsContext {
identity: Identity::from_pem(cert, key),
ca: Certificate::from_pem(ca),
identity: Some(Identity::from_pem(cert, key)),
ca: Some(Certificate::from_pem(ca)),
}
}

Expand Down Expand Up @@ -92,6 +92,63 @@ impl TlsContext {

TlsContext::from_paths(key_path.as_path(), cert_path.as_path(), ca_path.as_path())
}

/// Construct TlsContext for Client from corresponding file in the system
/// panics if file not found
pub fn from_path_client<T>(ca_path: T) -> TlsContext
where
T: AsRef<Path> + Copy,
{
info!("Reading TLS file, ca: {}", ca_path.as_ref().display());
[ca_path].iter().for_each(|p| {
if !Path::new(p.as_ref()).exists() {
panic!("File {} not found", p.as_ref().display())
}
});

let z = async {
let ca = tokio::fs::read(ca_path).await.unwrap();
ca
};
let ca = block_on(z);
info!("Successfully read CA cert");

TlsContext {
identity: None,
ca: Some(Certificate::from_pem(ca)),
}
}

/// Construct TlsContext for Serever from corresponding files in the system
/// panics if file not found
pub fn from_paths_server<T>(key_path: T, cert_path: T) -> TlsContext
where
T: AsRef<Path> + Copy,
{
info!(
"Reading TLS files, key: {}, cert: {}",
key_path.as_ref().display(),
cert_path.as_ref().display(),
);
[key_path, cert_path].iter().for_each(|p| {
if !Path::new(p.as_ref()).exists() {
panic!("File {} not found", p.as_ref().display())
}
});

let z = async {
let key = tokio::fs::read(key_path).await.unwrap();
let cert = tokio::fs::read(cert_path).await.unwrap();
(key, cert)
};
let (key, cert) = block_on(z);
debug!("Successfully read key and cert");

TlsContext {
identity: Some(Identity::from_pem(cert, key)),
ca: None,
}
}
}

/// Converts host string into URI object
Expand Down Expand Up @@ -191,4 +248,63 @@ mod tests {
drop(file_client_key);
dir.close().unwrap();
}

#[tokio::test]
async fn test_from_path_client() {
use std::fs::File;
use std::io::Write;

use tempfile::tempdir;

// Create a directory inside of `std::env::temp_dir()`.
let dir = tempdir().unwrap();
use rcgen::*;
let ca_subject_alt_names: &[_] = &["ca.world.example".to_string(), "localhost".to_string()];

let ca_cert = generate_simple_self_signed(ca_subject_alt_names).unwrap();
let ca_pem = ca_cert.serialize_pem().unwrap();

let file_path_ca_pem = dir.path().join("ca.pem");
let mut file_ca_pem = File::create(file_path_ca_pem).unwrap();
file_ca_pem.write_all(ca_pem.as_bytes()).unwrap();

let _ = TlsContext::from_path_client(dir.path().join(Path::new("ca.pem")).as_path());

drop(file_ca_pem);
dir.close().unwrap();
}

#[tokio::test]
async fn test_from_path_server() {
use std::fs::File;
use std::io::Write;

use tempfile::tempdir;

// Create a directory inside of `std::env::temp_dir()`.
let dir = tempdir().unwrap();
use rcgen::*;
let subject_alt_names: &[_] = &["hello.world.example".to_string(), "localhost".to_string()];

let server_cert = generate_simple_self_signed(subject_alt_names).unwrap();
let server_pem = server_cert.serialize_pem().unwrap();
let private_key = server_cert.serialize_private_key_pem();

let file_path_server_pem = dir.path().join("server.pem");
let mut file_server_pem = File::create(file_path_server_pem).unwrap();
file_server_pem.write_all(server_pem.as_bytes()).unwrap();

let file_path_private_key = dir.path().join("private.key");
let mut file_private_key = File::create(file_path_private_key).unwrap();
file_private_key.write_all(private_key.as_bytes()).unwrap();

let _ = TlsContext::from_paths_server(
dir.path().join(Path::new("private.key")).as_path(),
dir.path().join(Path::new("server.pem")).as_path(),
);

drop(file_server_pem);
drop(file_private_key);
dir.close().unwrap();
}
}
4 changes: 1 addition & 3 deletions protocol-rpc/src/rpc/cross-psi-xor/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Arg::with_name("tls-ca")
.long("tls-ca")
.takes_value(true)
.requires("tls-key")
.requires("tls-cert")
.help("Path to root CA certificate issued cert and keys"),
Arg::with_name("tls-domain")
.long("tls-domain")
Expand All @@ -105,7 +103,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
])
.groups(&[
ArgGroup::with_name("tls")
.args(&["no-tls", "tls-dir", "tls-key"])
.args(&["no-tls", "tls-dir", "tls-ca"])
.required(true),
ArgGroup::with_name("out")
.args(&["output", "stdout"])
Expand Down
Loading

0 comments on commit c2f55d5

Please sign in to comment.