diff --git a/libsql-server/src/http/admin/mod.rs b/libsql-server/src/http/admin/mod.rs index 4e427307c8..75c2443b8e 100644 --- a/libsql-server/src/http/admin/mod.rs +++ b/libsql-server/src/http/admin/mod.rs @@ -13,6 +13,7 @@ use std::io::ErrorKind; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; +use tokio::sync::Notify; use tokio_util::io::ReaderStream; use url::Url; @@ -60,6 +61,7 @@ pub async fn run( namespaces: NamespaceStore, connector: C, disable_metrics: bool, + shutdown: Arc, ) -> anyhow::Result<()> where A: crate::net::Accept, @@ -124,6 +126,7 @@ where hyper::server::Server::builder(acceptor) .serve(router.into_make_service()) + .with_graceful_shutdown(shutdown.notified()) .await .context("Could not bind admin HTTP API server")?; Ok(()) diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index c6f1390340..bc4edc77dd 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -21,7 +21,7 @@ use hyper::{header, Body, Request, Response, StatusCode}; use serde::de::DeserializeOwned; use serde::Serialize; use serde_json::Number; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::{mpsc, oneshot, Notify}; use tokio::task::JoinSet; use tonic::transport::Server; use tower_http::trace::DefaultOnResponse; @@ -237,6 +237,7 @@ pub struct UserApi { pub enable_console: bool, pub self_url: Option, pub path: Arc, + pub shutdown: Arc, } impl UserApi @@ -441,6 +442,7 @@ where join_set.spawn(async move { hyper::server::Server::builder(acceptor) .serve(h2c) + .with_graceful_shutdown(self.shutdown.notified()) .await .context("http server")?; Ok(()) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index a66b95fecd..0f5ffddd11 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -131,6 +131,7 @@ struct Services { db_config: DbConfig, auth: Arc, path: Arc, + shutdown: Arc, } impl Services @@ -156,6 +157,7 @@ where enable_console: self.user_api_config.enable_http_console, self_url: self.user_api_config.self_url, path: self.path.clone(), + shutdown: self.shutdown.clone(), }; let user_http_service = user_http.configure(join_set); @@ -166,12 +168,14 @@ where disable_metrics, }) = self.admin_api_config { + let shutdown = self.shutdown.clone(); join_set.spawn(http::admin::run( acceptor, user_http_service, self.namespaces, connector, disable_metrics, + shutdown, )); } } @@ -398,6 +402,7 @@ where db_config: self.db_config, auth, path: self.path.clone(), + shutdown: self.shutdown.clone(), }; services.configure(&mut join_set); @@ -433,6 +438,7 @@ where db_config: self.db_config, auth, path: self.path.clone(), + shutdown: self.shutdown.clone(), }; services.configure(&mut join_set);