Skip to content

Commit

Permalink
Add support for unix domain sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
argerus committed Nov 20, 2024
1 parent 9359eff commit 6ecc025
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 82 deletions.
1 change: 0 additions & 1 deletion databroker/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ glob-match = "0.2.1"
jemallocator = { version = "0.5.0", optional = true }
lazy_static = "1.4.0"
thiserror = "1.0.47"

futures = { version = "0.3.28" }
async-trait = "0.1.82"

Expand Down
79 changes: 52 additions & 27 deletions databroker/src/grpc/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@

use std::{convert::TryFrom, future::Future, time::Duration};

use tokio::net::TcpListener;
use tokio_stream::wrappers::TcpListenerStream;
use tonic::transport::Server;
use futures::Stream;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, UnixListener},
};
use tokio_stream::wrappers::{TcpListenerStream, UnixListenerStream};
#[cfg(feature = "tls")]
use tonic::transport::ServerTlsConfig;
use tonic::transport::{server::Connected, Server};
use tracing::{debug, info};

use databroker_proto::{kuksa, sdv};
Expand All @@ -34,7 +38,7 @@ pub enum ServerTLS {
Enabled { tls_config: ServerTlsConfig },
}

#[derive(PartialEq)]
#[derive(PartialEq, Clone)]
pub enum Api {
KuksaValV1,
KuksaValV2,
Expand Down Expand Up @@ -96,7 +100,7 @@ where
databroker.shutdown().await;
}

pub async fn serve<F>(
pub async fn serve_tcp<F>(
addr: impl Into<std::net::SocketAddr>,
broker: broker::DataBroker,
#[cfg(feature = "tls")] server_tls: ServerTLS,
Expand All @@ -110,25 +114,14 @@ where
let socket_addr = addr.into();
let listener = TcpListener::bind(socket_addr).await?;

/* On Linux systems try to notify daemon readiness to systemd.
* This function determines whether the a system is using systemd
* or not, so it is safe to use on non-systemd systems as well.
*/
#[cfg(target_os = "linux")]
{
match sd_notify::booted() {
Ok(true) => {
info!("Notifying systemd that the service is ready");
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
}
_ => {
debug!("System is not using systemd, will not try to notify");
}
}
if let Ok(addr) = listener.local_addr() {
info!("Listening on {}", addr);
}

let incoming = TcpListenerStream::new(listener);

serve_with_incoming_shutdown(
listener,
incoming,
broker,
#[cfg(feature = "tls")]
server_tls,
Expand All @@ -139,23 +132,55 @@ where
.await
}

pub async fn serve_with_incoming_shutdown<F>(
listener: TcpListener,
pub async fn serve_uds<F>(
path: impl AsRef<std::path::Path>,
broker: broker::DataBroker,
#[cfg(feature = "tls")] server_tls: ServerTLS,
apis: &[Api],
authorization: Authorization,
signal: F,
) -> Result<(), Box<dyn std::error::Error>>
where
F: Future<Output = ()>,
{
broker.start_housekeeping_task();
let listener = UnixListener::bind(path)?;

if let Ok(addr) = listener.local_addr() {
info!("Listening on {}", addr);
match addr.as_pathname() {
Some(pathname) => info!("Listening on unix socket at {}", pathname.display()),
None => info!("Listening on unix socket (unknown path)"),
}
}

let incoming = TcpListenerStream::new(listener);
let incoming = UnixListenerStream::new(listener);

serve_with_incoming_shutdown(
incoming,
broker,
ServerTLS::Disabled,
apis,
authorization,
signal,
)
.await
}

pub async fn serve_with_incoming_shutdown<F, I, IO, IE>(
incoming: I,
broker: broker::DataBroker,
#[cfg(feature = "tls")] server_tls: ServerTLS,
apis: &[Api],
authorization: Authorization,
signal: F,
) -> Result<(), Box<dyn std::error::Error>>
where
F: Future<Output = ()>,
I: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IO::ConnectInfo: Clone + Send + Sync + 'static,
IE: Into<Box<dyn std::error::Error + Send + Sync>>,
{
broker.start_housekeeping_task();

let mut server = Server::builder()
.http2_keepalive_interval(Some(Duration::from_secs(10)))
.http2_keepalive_timeout(Some(Duration::from_secs(20)));
Expand Down
90 changes: 86 additions & 4 deletions databroker/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;

static DEFAULT_UNIX_SOCKET_PATH: &str = "/run/kuksa/databroker.sock";

use std::io;
use std::os::unix::fs::FileTypeExt;
use std::path::Path;

use databroker::authorization::Authorization;
use databroker::broker::RegistrationError;

Expand Down Expand Up @@ -179,6 +185,15 @@ async fn read_metadata_file<'a, 'b>(
Ok(())
}

fn unlink_unix_domain_socket(path: impl AsRef<Path>) -> Result<(), io::Error> {
if let Ok(metadata) = std::fs::metadata(&path) {
if metadata.file_type().is_socket() {
std::fs::remove_file(&path)?;
}
};
Ok(())
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
let version = option_env!("CARGO_PKG_VERSION").unwrap_or_default();
let commit_sha = option_env!("VERGEN_GIT_SHA").unwrap_or_default();
Expand Down Expand Up @@ -228,8 +243,26 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.default_value("55555"),
)
.arg(
Arg::new("vss-file")
Arg::new("enable-unix-socket")
.display_order(3)
.long("enable-unix-socket")
.help("Listen on unix socket, default /run/kuksa/databroker.sock")
.action(ArgAction::SetTrue)
.env("KUKSA_DATABROKER_ENABLE_UNIX_SOCKET")
)
.arg(
Arg::new("unix-socket")
.display_order(4)
.long("unix-socket")
.help("Listen on unix socket, e.g. /tmp/kuksa/databroker.sock")
.action(ArgAction::Set)
.value_name("PATH")
.required(false)
.env("KUKSA_DATABROKER_UNIX_SOCKET"),
)
.arg(
Arg::new("vss-file")
.display_order(5)
.alias("metadata")
.long("vss")
.help("Populate data broker with VSS metadata from (comma-separated) list of files")
Expand All @@ -242,7 +275,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
)
.arg(
Arg::new("jwt-public-key")
.display_order(5)
.display_order(6)
.long("jwt-public-key")
.help("Public key used to verify JWT access tokens")
.action(ArgAction::Set)
Expand All @@ -251,7 +284,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
)
.arg(
Arg::new("disable-authorization")
.display_order(6)
.display_order(7)
.long("disable-authorization")
.help("Disable authorization")
.action(ArgAction::SetTrue),
Expand Down Expand Up @@ -489,7 +522,56 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
if args.get_flag("enable-databroker-v1") {
apis.push(grpc::server::Api::SdvDatabrokerV1);
}
grpc::server::serve(

let unix_socket_path = args.get_one::<String>("unix-socket").cloned().or_else(|| {
// If the --unix-socket PATH is not explicitly set, check whether it
// should be enabled using the default path
if args.get_flag("enable-unix-socket") {
Some(DEFAULT_UNIX_SOCKET_PATH.into())
} else {
None
}
});

if let Some(path) = unix_socket_path {
// We cannot assume that the socket was closed down properly
// so unlink before we recreate it.
unlink_unix_domain_socket(&path)?;
std::fs::create_dir_all(Path::new(&path).parent().unwrap())?;
let broker = broker.clone();
let authorization = authorization.clone();
let apis = apis.clone();
tokio::spawn(async move {
if let Err(err) =
grpc::server::serve_uds(&path, broker, &apis, authorization, shutdown_handler())
.await
{
error!("{err}");
}

info!("Unlinking unix domain socket at {}", path);
unlink_unix_domain_socket(path)
.unwrap_or_else(|_| error!("Failed to unlink unix domain socket"));
});
}

// On Linux systems try to notify daemon readiness to systemd.
// This function determines whether the a system is using systemd
// or not, so it is safe to use on non-systemd systems as well.
#[cfg(target_os = "linux")]
{
match sd_notify::booted() {
Ok(true) => {
info!("Notifying systemd that the service is ready");
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
}
_ => {
debug!("System is not using systemd, will not try to notify");
}
}
}

grpc::server::serve_tcp(
addr,
broker,
#[cfg(feature = "tls")]
Expand Down
4 changes: 3 additions & 1 deletion databroker/tests/world/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use databroker::{
};

use tokio::net::TcpListener;
use tokio_stream::wrappers::TcpListenerStream;
use tracing::debug;

use lazy_static::lazy_static;
Expand Down Expand Up @@ -188,6 +189,7 @@ impl DataBrokerWorld {
let addr = listener
.local_addr()
.expect("failed to determine listener's port");
let incoming = TcpListenerStream::new(listener);

tokio::spawn(async move {
let commit_sha = option_env!("VERGEN_GIT_SHA").unwrap_or("unknown");
Expand Down Expand Up @@ -230,7 +232,7 @@ impl DataBrokerWorld {
}

grpc::server::serve_with_incoming_shutdown(
listener,
incoming,
data_broker,
#[cfg(feature = "tls")]
CERTS.server_tls_config(),
Expand Down
Loading

0 comments on commit 6ecc025

Please sign in to comment.