From fb14df681bec1435c5e170a95f779fb53aa68ce0 Mon Sep 17 00:00:00 2001 From: ramiroaisen <52116153+ramiroaisen@users.noreply.github.com> Date: Sat, 23 Dec 2023 20:13:25 -0300 Subject: [PATCH] feat: ws stats --- Cargo.lock | 178 +++++++++--- defs/db/WsStatsConnection.ts | 17 ++ .../api/ws/stats/connection/WS/ClientEvent.ts | 3 + .../api/ws/stats/connection/WS/Query.ts | 8 + .../api/ws/stats/connection/WS/ServerEvent.ts | 6 + rs/bin/openstream/src/main.rs | 15 ++ rs/packages/api/src/lib.rs | 1 + rs/packages/api/src/ws_stats/mod.rs | 110 ++++++++ .../api/src/ws_stats/routes/connection.rs | 253 ++++++++++++++++++ rs/packages/api/src/ws_stats/routes/mod.rs | 17 ++ rs/packages/config/src/lib.rs | 14 + rs/packages/db/src/models/mod.rs | 2 + .../db/src/models/ws_stats_connection/mod.rs | 92 +++++++ rs/packages/prex/Cargo.toml | 5 + rs/packages/prex/src/lib.rs | 1 + rs/packages/prex/src/ws.rs | 168 ++++++++++++ 16 files changed, 859 insertions(+), 31 deletions(-) create mode 100644 defs/db/WsStatsConnection.ts create mode 100644 defs/ws-stats/api/ws/stats/connection/WS/ClientEvent.ts create mode 100644 defs/ws-stats/api/ws/stats/connection/WS/Query.ts create mode 100644 defs/ws-stats/api/ws/stats/connection/WS/ServerEvent.ts create mode 100644 rs/packages/api/src/ws_stats/mod.rs create mode 100644 rs/packages/api/src/ws_stats/routes/connection.rs create mode 100644 rs/packages/api/src/ws_stats/routes/mod.rs create mode 100644 rs/packages/db/src/models/ws_stats_connection/mod.rs create mode 100644 rs/packages/prex/src/ws.rs diff --git a/Cargo.lock b/Cargo.lock index 7e14611c..7f25cb41 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -155,7 +155,7 @@ dependencies = [ "hex", "http 0.1.0", "http-range", - "hyper", + "hyper 0.14.27", "lang", "lazy-regex", "log", @@ -202,7 +202,7 @@ dependencies = [ "derive_more", "geoip", "http 0.2.9", - "hyper", + "hyper 0.14.27", "macros", "mongodb", "serde", @@ -276,7 +276,7 @@ dependencies = [ "base64-compat", "futures-util", "http 0.1.0", - "hyper", + "hyper 0.14.27", "log", "mime_guess", "owo-colors 3.5.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -394,8 +394,8 @@ dependencies = [ "bytes", "futures-util", "http 0.2.9", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.27", "itoa 1.0.5", "matchit", "memchr", @@ -420,7 +420,7 @@ dependencies = [ "bytes", "futures-util", "http 0.2.9", - "http-body", + "http-body 0.4.5", "mime", "rustversion", "tower-layer", @@ -779,7 +779,7 @@ dependencies = [ "dotenv", "ffmpeg", "futures 0.3.25", - "hyper", + "hyper 0.14.27", "jemallocator", "lazy_static", "rand 0.8.5", @@ -1337,7 +1337,7 @@ dependencies = [ "futures-util", "geoip", "human_bytes", - "hyper", + "hyper 0.14.27", "image", "indexmap 1.9.2", "lang", @@ -2332,7 +2332,7 @@ dependencies = [ name = "hls-logger" version = "0.1.0" dependencies = [ - "hyper", + "hyper 0.14.27", "log", "logger", "prex", @@ -2405,7 +2405,7 @@ name = "http" version = "0.1.0" dependencies = [ "db", - "hyper", + "hyper 0.14.27", "mongodb", "pin-project", "prex", @@ -2427,6 +2427,17 @@ dependencies = [ "itoa 1.0.5", ] +[[package]] +name = "http" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +dependencies = [ + "bytes", + "fnv", + "itoa 1.0.5", +] + [[package]] name = "http-auth-basic" version = "0.3.3" @@ -2456,6 +2467,29 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-body" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +dependencies = [ + "bytes", + "http 1.0.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" +dependencies = [ + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "pin-project-lite", +] + [[package]] name = "http-range" version = "0.1.5" @@ -2507,7 +2541,7 @@ dependencies = [ "futures-util", "h2", "http 0.2.9", - "http-body", + "http-body 0.4.5", "httparse", "httpdate", "itoa 1.0.5", @@ -2519,6 +2553,19 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5aa53871fc917b1a9ed87b683a5d86db645e23acb32c2e0785a353e522fb75" +dependencies = [ + "bytes", + "http 1.0.0", + "http-body 1.0.0", + "pin-project-lite", + "tokio", +] + [[package]] name = "hyper-rustls" version = "0.24.0" @@ -2526,7 +2573,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7" dependencies = [ "http 0.2.9", - "hyper", + "hyper 0.14.27", "rustls", "tokio", "tokio-rustls", @@ -2538,7 +2585,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper", + "hyper 0.14.27", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -2551,12 +2598,45 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.27", "native-tls", "tokio", "tokio-native-tls", ] +[[package]] +name = "hyper-tungstenite" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad" +dependencies = [ + "http-body-util", + "hyper 1.1.0", + "hyper-util", + "pin-project-lite", + "tokio", + "tokio-tungstenite", + "tungstenite", +] + +[[package]] +name = "hyper-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdea9aac0dbe5a9240d68cfd9501e2db94222c6dc06843e06640b9e07f0fdc67" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "hyper 1.1.0", + "pin-project-lite", + "socket2 0.5.3", + "tokio", + "tracing", +] + [[package]] name = "iana-time-zone" version = "0.1.53" @@ -3158,7 +3238,7 @@ checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" name = "macros" version = "0.1.0" dependencies = [ - "hyper", + "hyper 0.14.27", "macros-build", "once_cell", "parking_lot 0.12.1", @@ -3187,7 +3267,7 @@ dependencies = [ "async-trait", "css-inline", "html2text", - "hyper", + "hyper 0.14.27", "lettre", "nanohtml2text", "prex", @@ -3279,7 +3359,7 @@ dependencies = [ "drop-tracer", "ffmpeg", "futures-util", - "hyper", + "hyper 0.14.27", "log", "mongodb", "parking_lot 0.12.1", @@ -3439,7 +3519,7 @@ version = "0.1.0" dependencies = [ "bytes", "ffmpeg", - "hyper", + "hyper 0.14.27", "prex", "tokio", ] @@ -3756,7 +3836,7 @@ dependencies = [ "dotenv", "drop-tracer", "futures 0.3.25", - "hyper", + "hyper 0.14.27", "jemallocator", "local-ip-address", "log", @@ -4069,9 +4149,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.9" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -4141,9 +4221,12 @@ dependencies = [ "constants", "futures 0.3.25", "http-auth-basic", - "hyper", + "hyper 0.14.27", + "hyper-tungstenite", + "hyper-util", "ip_rfc", "log", + "pin-project-lite", "regex", "serde", "serde_json", @@ -4151,7 +4234,9 @@ dependencies = [ "test-util", "thiserror", "tokio", + "tokio-tungstenite", "tower", + "tungstenite", ] [[package]] @@ -4211,7 +4296,7 @@ dependencies = [ "dotenv", "ffmpeg", "futures 0.1.31", - "hyper", + "hyper 0.14.27", "multiqueue", "reqwest", "serde_json", @@ -4611,8 +4696,8 @@ dependencies = [ "futures-util", "h2", "http 0.2.9", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.27", "hyper-rustls", "hyper-tls", "ipnet", @@ -4724,7 +4809,7 @@ dependencies = [ "db", "futures 0.3.25", "http 0.1.0", - "hyper", + "hyper 0.14.27", "log", "mongodb", "owo-colors 3.5.0", @@ -5099,7 +5184,7 @@ dependencies = [ "base64 0.13.1", "bytes", "chrono", - "hyper", + "hyper 0.14.27", "iso8601-timestamp", "log", "mongodb", @@ -5447,7 +5532,7 @@ dependencies = [ "drop-tracer", "geoip", "http-basic-auth", - "hyper", + "hyper 0.14.27", "lazy-regex", "log", "media", @@ -5564,7 +5649,7 @@ dependencies = [ "drop-tracer", "futures 0.3.25", "http 0.1.0", - "hyper", + "hyper 0.14.27", "ip-counter", "ip_rfc", "log", @@ -6124,6 +6209,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.4" @@ -6197,8 +6294,8 @@ dependencies = [ "futures-util", "h2", "http 0.2.9", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.27", "hyper-timeout", "percent-encoding", "pin-project", @@ -6383,6 +6480,25 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b3e06c9b9d80ed6b745c7159c40b311ad2916abb34a49e9be2653b90db0d8dd" +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.0.0", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typed-arena" version = "2.0.1" diff --git a/defs/db/WsStatsConnection.ts b/defs/db/WsStatsConnection.ts new file mode 100644 index 00000000..1e55a7fb --- /dev/null +++ b/defs/db/WsStatsConnection.ts @@ -0,0 +1,17 @@ +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. +import type { CountryCode } from "../CountryCode"; +import type { DateTime } from "../DateTime"; + +export type WsStatsConnection = { + _id: string; + st: string; + dp: string; + du: number | null; + op: boolean; + cc: CountryCode | null; + ip: string; + ap: string | null; + av: number | null; + ca: DateTime; + cl: DateTime | null; +}; diff --git a/defs/ws-stats/api/ws/stats/connection/WS/ClientEvent.ts b/defs/ws-stats/api/ws/stats/connection/WS/ClientEvent.ts new file mode 100644 index 00000000..d0bc256b --- /dev/null +++ b/defs/ws-stats/api/ws/stats/connection/WS/ClientEvent.ts @@ -0,0 +1,3 @@ +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. + +export type ClientEvent = { kind: "ping" } | { kind: "pong" }; diff --git a/defs/ws-stats/api/ws/stats/connection/WS/Query.ts b/defs/ws-stats/api/ws/stats/connection/WS/Query.ts new file mode 100644 index 00000000..dc620975 --- /dev/null +++ b/defs/ws-stats/api/ws/stats/connection/WS/Query.ts @@ -0,0 +1,8 @@ +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. + +export type Query = { + connection_id?: string; + station_id: string; + app_kind?: string; + app_version?: number; +}; diff --git a/defs/ws-stats/api/ws/stats/connection/WS/ServerEvent.ts b/defs/ws-stats/api/ws/stats/connection/WS/ServerEvent.ts new file mode 100644 index 00000000..96644f44 --- /dev/null +++ b/defs/ws-stats/api/ws/stats/connection/WS/ServerEvent.ts @@ -0,0 +1,6 @@ +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. + +export type ServerEvent = + | { kind: "ping" } + | { kind: "pong" } + | ({ kind: "start" } & { connection_id: string }); diff --git a/rs/bin/openstream/src/main.rs b/rs/bin/openstream/src/main.rs index 690cc6d7..5e38f0f2 100644 --- a/rs/bin/openstream/src/main.rs +++ b/rs/bin/openstream/src/main.rs @@ -3,6 +3,7 @@ use std::process::ExitStatus; use std::sync::Arc; use api::storage::StorageServer; +use api::ws_stats::WsStatsServer; use clap::{Parser, Subcommand}; use config::Config; use db::access_token::{AccessToken, GeneratedBy}; @@ -438,6 +439,7 @@ async fn start_async(Start { config }: Start) -> Result<(), anyhow::Error> { ref assets, ref smtp, ref payments, + ref ws_stats, } = config.as_ref(); db::access_token::AccessToken::start_autoremove_job(); @@ -535,6 +537,19 @@ async fn start_async(Start { config }: Start) -> Result<(), anyhow::Error> { }.boxed()); } + if let Some(ws_stats_config) = ws_stats { + let ws_stats = WsStatsServer::new( + deployment.id.clone(), + ws_stats_config.addrs.clone(), + shutdown.clone(), + ); + let fut = ws_stats.start()?; + futs.push(async move { + fut.await.map_err(crate::error::ServerStartError::from)?; + Ok(()) + }.boxed()); + } + if let Some(static_config) = assets { let assets = StaticServer::new( static_config.addrs.clone(), diff --git a/rs/packages/api/src/lib.rs b/rs/packages/api/src/lib.rs index ddde9c0f..d2fa9563 100644 --- a/rs/packages/api/src/lib.rs +++ b/rs/packages/api/src/lib.rs @@ -6,6 +6,7 @@ pub mod qs; pub mod request_ext; pub mod routes; pub mod storage; +pub mod ws_stats; use payments::client::PaymentsClient; diff --git a/rs/packages/api/src/ws_stats/mod.rs b/rs/packages/api/src/ws_stats/mod.rs new file mode 100644 index 00000000..e5857707 --- /dev/null +++ b/rs/packages/api/src/ws_stats/mod.rs @@ -0,0 +1,110 @@ +pub mod routes; + +use futures::stream::FuturesUnordered; +use futures::TryStreamExt; +use hyper::Server; +use log::*; +use serde::{Deserialize, Serialize}; +use shutdown::Shutdown; +use socket2::{Domain, Protocol, Socket, Type}; +use std::future::Future; +use std::net::SocketAddr; + +#[derive(Debug)] +pub struct WsStatsServer { + deployment_id: String, + addrs: Vec, + shutdown: Shutdown, +} + +#[derive(Debug, thiserror::Error)] +pub enum WsStatsServerError { + #[error("io error: {0}")] + Io(#[from] std::io::Error), + #[error("hyper error: {0}")] + Hyper(#[from] hyper::Error), +} + +#[derive(Serialize, Deserialize, Debug, Clone, Copy)] +pub struct Status { + status: usize, +} + +impl WsStatsServer { + pub fn new(deployment_id: String, addrs: Vec, shutdown: Shutdown) -> Self { + Self { + deployment_id, + addrs, + shutdown, + } + } + + pub fn start( + self, + ) -> Result> + 'static, WsStatsServerError> { + let mut app = prex::prex(); + + app.with(http::middleware::server); + app.get("/status", http::middleware::status); + + app.at("/").nest(routes::router( + self.deployment_id.clone(), + self.shutdown.clone(), + )); + + let app = app.build().expect("ws stats server prex build"); + + let futs = FuturesUnordered::new(); + + for addr in self.addrs.iter().copied() { + let domain = match addr { + SocketAddr::V4(_) => Domain::IPV4, + SocketAddr::V6(_) => Domain::IPV6, + }; + + let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + + if addr.is_ipv6() { + socket.set_only_v6(true)?; + } + + socket.set_nonblocking(true)?; + socket.set_reuse_address(true)?; + // socket.set_reuse_port(true)?; + + socket.bind(&addr.into())?; + socket.listen(1024)?; + + let tcp = socket.into(); + + let server = Server::from_tcp(tcp)? + .http1_only(true) + .http1_title_case_headers(false) + .http1_preserve_header_case(false) + .http1_keepalive(false); + + { + use owo_colors::*; + info!(target: "ws-stats", "ws-stats server bound to {}", addr.yellow()); + } + + let fut = server + .serve(app.clone()) + .with_graceful_shutdown(self.shutdown.signal()); + + futs.push(fut); + } + + Ok(async move { + futs.try_collect().await?; + drop(self); + Ok(()) + }) + } +} + +impl Drop for WsStatsServer { + fn drop(&mut self) { + info!(target: "ws-stats", "ws-stats server dropped"); + } +} diff --git a/rs/packages/api/src/ws_stats/routes/connection.rs b/rs/packages/api/src/ws_stats/routes/connection.rs new file mode 100644 index 00000000..f9ee34c1 --- /dev/null +++ b/rs/packages/api/src/ws_stats/routes/connection.rs @@ -0,0 +1,253 @@ +use db::{ws_stats_connection::WsStatsConnection, Model}; +use futures_util::{sink::SinkExt, stream::StreamExt}; +use hyper::{Body, StatusCode}; +use mongodb::bson::doc; +use prex::{ + handler::Handler, + ws::tungstenite::{error::ProtocolError, Message}, + Next, Request, Response, +}; +use serde::{Deserialize, Serialize}; +use serde_util::DateTime; +use shutdown::Shutdown; +use ts_rs::TS; + +#[derive(Debug, Clone)] +pub struct WsConnectionHandler { + pub deployment_id: String, + pub shutdown: Shutdown, +} + +#[derive(Debug, thiserror::Error)] +pub enum WsConnectionHandlerError { + #[error("expecting websocket request")] + NotWs, + + #[error("query string is invalid: {0}")] + InvalidQs(#[from] serde_qs::Error), + + #[error("websocket protocol error: {0}")] + ProtocolError(#[from] ProtocolError), +} + +impl From for Response { + fn from(err: WsConnectionHandlerError) -> Self { + let body = Body::from(format!("{}", err)); + + let status = match err { + WsConnectionHandlerError::NotWs => StatusCode::BAD_REQUEST, + WsConnectionHandlerError::InvalidQs(_) => StatusCode::BAD_REQUEST, + WsConnectionHandlerError::ProtocolError(_) => StatusCode::BAD_REQUEST, + }; + + let mut res = Response::new(status); + *res.body_mut() = body; + + res + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, TS)] +#[ts( + export, + export_to = "../../../defs/ws-stats/api/ws/stats/connection/WS/" +)] +pub struct Query { + #[ts(optional)] + #[serde(skip_serializing_if = "Option::is_none")] + connection_id: Option, + + station_id: String, + + #[ts(optional)] + #[serde(skip_serializing_if = "Option::is_none")] + app_kind: Option, + + #[ts(optional)] + #[serde(skip_serializing_if = "Option::is_none")] + app_version: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, TS)] +#[ts( + export, + export_to = "../../../defs/ws-stats/api/ws/stats/connection/WS/" +)] +#[serde(tag = "kind")] +pub enum ServerEvent { + #[serde(rename = "ping")] + Ping, + #[serde(rename = "pong")] + Pong, + #[serde(rename = "start")] + Start { connection_id: String }, +} +#[derive(Debug, Clone, Serialize, Deserialize, TS)] +#[ts( + export, + export_to = "../../../defs/ws-stats/api/ws/stats/connection/WS/" +)] +#[serde(tag = "kind")] +pub enum ClientEvent { + #[serde(rename = "ping")] + Ping, + #[serde(rename = "pong")] + Pong, +} + +impl WsConnectionHandler { + async fn handle(&self, mut req: Request) -> Result { + if !prex::ws::is_upgrade_request(&req) { + return Err(WsConnectionHandlerError::NotWs); + } + + let shutdown = self.shutdown.clone(); + let deployment_id = self.deployment_id.clone(); + let qs = req.qs::()?; + let ip = req.isomorphic_ip(); + let country_code = geoip::ip_to_country_code(&ip); + + let (res, stream_future) = prex::ws::upgrade(&mut req, None)?; + + tokio::spawn(async move { + let mut stream = match stream_future.await { + Ok(stream) => stream, + Err(_) => { + // TODO: log + return; + } + }; + + let Query { + connection_id: prev_id, + station_id, + app_kind, + app_version, + } = qs; + + let connection_id: String; + let created_at: DateTime; + + macro_rules! create { + () => {{ + connection_id = WsStatsConnection::uid(); + created_at = DateTime::now(); + + let connection = WsStatsConnection { + id: connection_id.clone(), + station_id: station_id.clone(), + deployment_id, + duration_ms: None, + is_open: true, + country_code, + ip, + app_kind: app_kind.clone(), + app_version, + created_at, + closed_at: None, + }; + + match WsStatsConnection::insert(&connection).await { + Ok(_) => {} + Err(_) => return, + }; + }}; + } + + match prev_id { + None => create!(), + + Some(prev_id) => match WsStatsConnection::get_by_id(&prev_id).await { + Err(_) => return, + + Ok(None) => create!(), + + Ok(Some(connection)) => { + connection_id = connection.id; + created_at = connection.created_at; + } + }, + } + + 'start: { + let start_message = serde_json::to_string(&ServerEvent::Start { + connection_id: connection_id.clone(), + }) + .unwrap(); + + let r = tokio::select! { + _ = shutdown.signal() => break 'start, + r = stream.send(Message::text(start_message)) => r + }; + + match r { + Ok(_) => {} + _ => break 'start, + } + + 'messages: loop { + let msg = tokio::select! { + _ = shutdown.signal() => { + break 'messages; + } + + msg = stream.next() => msg, + }; + + let msg = match msg { + Some(Ok(msg)) => msg, + _ => break 'messages, + }; + + match msg { + Message::Text(text) => { + let event = match serde_json::from_str::(&text) { + Ok(event) => event, + Err(_) => continue 'messages, + }; + + match event { + ClientEvent::Pong => {} + ClientEvent::Ping => { + let pong = serde_json::to_string(&ServerEvent::Pong).unwrap(); + let r = tokio::select! { + _ = shutdown.signal() => break 'messages, + r = stream.send(Message::text(pong)) => r + }; + + match r { + Ok(_) => {} + _ => break 'messages, + } + } + } + } + + _ => continue 'messages, + } + } + } + + let duration_ms = ((*DateTime::now() - *created_at).as_seconds_f64() * 1000.0).round(); + + let update = doc! { + "$set": { + WsStatsConnection::KEY_IS_OPEN: false, + WsStatsConnection::KEY_CLOSED_AT: DateTime::now(), + WsStatsConnection::KEY_DURATION_MS: duration_ms, + } + }; + + let _ = WsStatsConnection::update_by_id(&connection_id, update).await; + }); + + Ok(res) + } +} + +#[async_trait::async_trait] +impl Handler for WsConnectionHandler { + async fn call(&self, req: Request, _: Next) -> Response { + self.handle(req).await.into() + } +} diff --git a/rs/packages/api/src/ws_stats/routes/mod.rs b/rs/packages/api/src/ws_stats/routes/mod.rs new file mode 100644 index 00000000..5c69fba3 --- /dev/null +++ b/rs/packages/api/src/ws_stats/routes/mod.rs @@ -0,0 +1,17 @@ +pub mod connection; +use prex::router::builder::Builder; +use shutdown::Shutdown; + +pub fn router(deployment_id: String, shutdown: Shutdown) -> Builder { + let mut router = prex::prex(); + + router.get( + "/ws/stats/connection", + connection::WsConnectionHandler { + deployment_id: deployment_id.clone(), + shutdown: shutdown.clone(), + }, + ); + + router +} diff --git a/rs/packages/config/src/lib.rs b/rs/packages/config/src/lib.rs index cb4e61c0..04acaa0b 100644 --- a/rs/packages/config/src/lib.rs +++ b/rs/packages/config/src/lib.rs @@ -51,6 +51,10 @@ pub struct Config { #[garde(dive)] pub storage: Option, + #[config(nested)] + #[garde(dive)] + pub ws_stats: Option, + #[serde(rename = "static")] #[config(nested, rename = "static")] #[garde(dive)] @@ -111,6 +115,16 @@ pub struct Stream { pub addrs: Vec, } +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, MetreConfig, garde::Validate)] +#[serde(deny_unknown_fields)] +#[config(rename_all = "snake_case")] +#[serde(rename_all = "snake_case")] +pub struct WsStats { + #[config(parse_env = parse_addrs)] + #[garde(length(min = 1))] + pub addrs: Vec, +} + #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, MetreConfig, garde::Validate)] #[serde(deny_unknown_fields)] #[config(rename_all = "snake_case")] diff --git a/rs/packages/db/src/models/mod.rs b/rs/packages/db/src/models/mod.rs index 7f60a2b3..4e968a13 100644 --- a/rs/packages/db/src/models/mod.rs +++ b/rs/packages/db/src/models/mod.rs @@ -32,3 +32,5 @@ pub mod payment_method; pub mod account_invitations; pub mod probe; + +pub mod ws_stats_connection; diff --git a/rs/packages/db/src/models/ws_stats_connection/mod.rs b/rs/packages/db/src/models/ws_stats_connection/mod.rs new file mode 100644 index 00000000..9f5a0162 --- /dev/null +++ b/rs/packages/db/src/models/ws_stats_connection/mod.rs @@ -0,0 +1,92 @@ +use crate::Model; +use geoip::CountryCode; +use mongodb::bson::doc; +use mongodb::IndexModel; +use serde::{Deserialize, Serialize}; +use serde_util::DateTime; +use std::net::IpAddr; +use ts_rs::TS; + +crate::register!(WsStatsConnection); + +#[derive(Debug, Clone, Serialize, Deserialize, TS)] +#[ts(export, export_to = "../../../defs/db/")] +#[serde(rename_all = "snake_case")] +#[macros::keys] +pub struct WsStatsConnection { + #[serde(rename = "_id")] + pub id: String, + + #[serde(rename = "st")] + pub station_id: String, + + #[serde(rename = "dp")] + pub deployment_id: String, + // #[serde(with = "serde_util::as_f64::option")] + // pub transfer_bytes: Option, + #[serde(rename = "du")] + #[serde(with = "serde_util::as_f64::option")] + pub duration_ms: Option, + + #[serde(rename = "op")] + pub is_open: bool, + + #[serde(rename = "cc")] + pub country_code: Option, + + #[serde(rename = "ip")] + #[serde(with = "serde_util::ip")] + pub ip: IpAddr, + + #[serde(rename = "ap")] + pub app_kind: Option, + + #[serde(rename = "av")] + pub app_version: Option, + + #[serde(rename = "ca")] + pub created_at: DateTime, + // pub request: Request, + // pub last_transfer_at: DateTime, + #[serde(rename = "cl")] + pub closed_at: Option, +} + +impl WsStatsConnection { + pub const KEY_MANNUALLY_CLOSED: &'static str = "_m"; +} + +impl Model for WsStatsConnection { + const CL_NAME: &'static str = "ws_stats_connection"; + const UID_LEN: usize = 12; + + fn indexes() -> Vec { + let station_id = IndexModel::builder() + .keys(doc! { Self::KEY_STATION_ID: 1 }) + .build(); + + let created_at = IndexModel::builder() + .keys(doc! { Self::KEY_CREATED_AT: 1 }) + .build(); + + let created_at_station_id = IndexModel::builder() + .keys(doc! { Self::KEY_CREATED_AT: 1, Self::KEY_STATION_ID: 1 }) + .build(); + + let is_open = IndexModel::builder() + .keys(doc! { Self::KEY_IS_OPEN: 1 }) + .build(); + + vec![station_id, created_at, created_at_station_id, is_open] + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn keys_match() { + assert_eq!(crate::KEY_ID, WsStatsConnection::KEY_ID); + } +} diff --git a/rs/packages/prex/Cargo.toml b/rs/packages/prex/Cargo.toml index 598f7934..cdef0bab 100644 --- a/rs/packages/prex/Cargo.toml +++ b/rs/packages/prex/Cargo.toml @@ -20,6 +20,11 @@ bytes = "1.4.0" log = "0.4.17" serde_qs = "0.12.0" constants = { version = "0.1.0", path = "../../config/constants" } +hyper-tungstenite = "0.13.0" +tokio-tungstenite = "0.21.0" +tungstenite = "0.21.0" +pin-project-lite = "0.2.13" +hyper-util = "0.1.2" [dev-dependencies] test-util = { version = "0.1.0", path = "../test-util" } diff --git a/rs/packages/prex/src/lib.rs b/rs/packages/prex/src/lib.rs index 6ff3da24..c1a83cae 100644 --- a/rs/packages/prex/src/lib.rs +++ b/rs/packages/prex/src/lib.rs @@ -11,6 +11,7 @@ pub mod path; pub mod request; pub mod response; pub mod router; +pub mod ws; pub use app::prex; pub use next::Next; diff --git a/rs/packages/prex/src/ws.rs b/rs/packages/prex/src/ws.rs new file mode 100644 index 00000000..b7c9cad3 --- /dev/null +++ b/rs/packages/prex/src/ws.rs @@ -0,0 +1,168 @@ +use hyper::header::HeaderValue; +use hyper::header::CONNECTION; +use hyper::upgrade::OnUpgrade; +use hyper::Body; +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tungstenite::handshake::derive_accept_key; +use tungstenite::protocol::{Role, WebSocketConfig}; +use tungstenite::{error::ProtocolError, Error}; + +pub use hyper; +pub use tungstenite; + +pub use tokio_tungstenite::WebSocketStream; + +/// A [`WebSocketStream`] that wraps an upgraded HTTP connection from hyper. +pub type HyperWebsocketStream = WebSocketStream; + +use crate::Request; +use crate::Response; + +pin_project! { + /// A future that resolves to a websocket stream when the associated HTTP upgrade completes. + #[derive(Debug)] + pub struct HyperWebsocket { + #[pin] + inner: hyper::upgrade::OnUpgrade, + config: Option, + } +} + +/// Try to upgrade a received `hyper::Request` to a websocket connection. +/// +/// The function returns a HTTP response and a future that resolves to the websocket stream. +/// The response body *MUST* be sent to the client before the future can be resolved. +/// +/// This functions checks `Sec-WebSocket-Key` and `Sec-WebSocket-Version` headers. +/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers. +/// You can inspect the headers manually before calling this function, +/// and modify the response headers appropriately. +/// +/// This function also does not look at the `Connection` or `Upgrade` headers. +/// To check if a request is a websocket upgrade request, you can use [`is_upgrade_request`]. +/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually. +/// +pub fn upgrade( + request: &mut Request, + config: Option, +) -> Result<(Response, HyperWebsocket), ProtocolError> { + let key = request + .headers() + .get("sec-websocket-key") + .ok_or(ProtocolError::MissingSecWebSocketKey)?; + if request + .headers() + .get("sec-websocket-version") + .map(|v| v.as_bytes()) + != Some(b"13") + { + return Err(ProtocolError::MissingSecWebSocketVersionHeader); + } + + let mut response = Response::new(hyper::StatusCode::SWITCHING_PROTOCOLS); + + response + .headers_mut() + .append(CONNECTION, HeaderValue::from_static("upgrade")); + + response.headers_mut().append( + hyper::header::UPGRADE, + HeaderValue::from_static("websocket"), + ); + + response.headers_mut().append( + "sec-websocket-accept", + HeaderValue::from_str(&derive_accept_key(key.as_bytes())).unwrap(), + ); + + *response.body_mut() = Body::from("switching to websocket protocol"); + + let on_upgrade = match request.extensions_mut().remove::() { + Some(x) => x, + None => return Err(ProtocolError::MissingConnectionUpgradeHeader), + }; + + let stream = HyperWebsocket { + inner: on_upgrade, + config, + }; + + Ok((response, stream)) +} + +/// Check if a request is a websocket upgrade request. +/// +/// If the `Upgrade` header lists multiple protocols, +/// this function returns true if of them are `"websocket"`, +/// If the server supports multiple upgrade protocols, +/// it would be more appropriate to try each listed protocol in order. +pub fn is_upgrade_request(request: &Request) -> bool { + header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade") + && header_contains_value(request.headers(), hyper::header::UPGRADE, "websocket") +} + +/// Check if there is a header of the given name containing the wanted value. +fn header_contains_value( + headers: &hyper::HeaderMap, + header: impl hyper::header::AsHeaderName, + value: impl AsRef<[u8]>, +) -> bool { + let value = value.as_ref(); + for header in headers.get_all(header) { + if header + .as_bytes() + .split(|&c| c == b',') + .any(|x| trim(x).eq_ignore_ascii_case(value)) + { + return true; + } + } + false +} + +fn trim(data: &[u8]) -> &[u8] { + trim_end(trim_start(data)) +} + +fn trim_start(data: &[u8]) -> &[u8] { + if let Some(start) = data.iter().position(|x| !x.is_ascii_whitespace()) { + &data[start..] + } else { + b"" + } +} + +fn trim_end(data: &[u8]) -> &[u8] { + if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) { + &data[..last + 1] + } else { + b"" + } +} + +impl std::future::Future for HyperWebsocket { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + let upgraded = match this.inner.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(x) => x, + }; + + let upgraded = upgraded.map_err(|_| Error::Protocol(ProtocolError::HandshakeIncomplete))?; + + let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, this.config.take()); + tokio::pin!(stream); + + // The future returned by `from_raw_socket` is always ready. + // Not sure why it is a future in the first place. + match stream.as_mut().poll(cx) { + Poll::Pending => unreachable!("from_raw_socket should always be created ready"), + Poll::Ready(x) => Poll::Ready(Ok(x)), + } + } +}