diff --git a/Cargo.lock b/Cargo.lock index cc7a07a..c5c2a8f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -421,6 +421,12 @@ version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e825f6987101665dea6ec934c09ec6d721de7bc1bf92248e1d5810c8cd636b77" +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "futures" version = "0.3.28" @@ -553,6 +559,36 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "h2" +version = "0.3.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97ec8491ebaf99c8eaa73058b045fe58073cd6be7f596ac993ced0b0a0c01049" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2mux" +version = "0.1.0" +dependencies = [ + "bytes", + "h2", + "http", + "tokio", + "tracing", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -592,6 +628,17 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ebdb29d2ea9ed0083cd8cece49bbd968021bd99b0849edb4a9a7ee0fdf6a4e0" +[[package]] +name = "http" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "httparse" version = "1.8.0" @@ -1164,18 +1211,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.166" +version = "1.0.167" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d01b7404f9d441d3ad40e6a636a7782c377d2abdbe4fa2440e2edcc2f4f10db8" +checksum = "7daf513456463b42aa1d94cff7e0c24d682b429f020b9afa4f5ba5c40a22b237" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.166" +version = "1.0.167" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dd83d6dde2b6b2d466e14d9d1acce8816dedee94f735eac6395808b3483c6d6" +checksum = "b69b106b68bc8054f0e974e70d19984040f8a5cf9215ca82626ea4853f82c4b9" dependencies = [ "proc-macro2", "quote", @@ -1427,18 +1474,18 @@ checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" -version = "1.0.41" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c16a64ba9387ef3fdae4f9c1a7f07a0997fce91985c0336f1ddc1822b3b37802" +checksum = "a35fc5b8971143ca348fa6df4f024d4d55264f3468c71ad1c2f365b0a4d58c42" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.41" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d14928354b01c4d6a4f0e549069adef399a284e7995c7ccca94e8a07a5346c59" +checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f" dependencies = [ "proc-macro2", "quote", @@ -1486,6 +1533,20 @@ dependencies = [ "syn 2.0.23", ] +[[package]] +name = "tokio-util" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + [[package]] name = "tracing" version = "0.1.37" diff --git a/Cargo.toml b/Cargo.toml index ed1f320..44804ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,8 @@ members = [ "tunnel", - "shuttle" + "shuttle", + "h2mux" ] [patch.crates-io] diff --git a/h2mux/Cargo.toml b/h2mux/Cargo.toml index 0fb760c..a80d4df 100644 --- a/h2mux/Cargo.toml +++ b/h2mux/Cargo.toml @@ -10,3 +10,4 @@ bytes = "1.4.0" h2 = "0.3" http = "0.2.9" tokio = "1.29" +tracing = "0.1.37" diff --git a/h2mux/src/lib.rs b/h2mux/src/lib.rs new file mode 100644 index 0000000..48636ce --- /dev/null +++ b/h2mux/src/lib.rs @@ -0,0 +1,455 @@ +use bytes::{Buf, Bytes}; +use h2::{Reason, RecvStream, SendStream}; +use http::header::{HeaderName, CONNECTION, TE, TRAILER, TRANSFER_ENCODING, UPGRADE}; +use http::HeaderMap; +// use pin_project_lite::pin_project; +use std::error::Error as StdError; +use std::io::{self, Cursor, IoSlice}; +use std::mem; +use std::task::{Context, ready, Poll, self}; +use std::pin::Pin; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +// use tracing::{debug, trace, warn}; + +mod utils; +mod ping; + +// use crate::body::Body; +// use crate::common::{task, Future, Pin, Poll}; +use crate::ping::Recorder; + +// pub(crate) mod ping; + +use crate::utils::H2MapIoErr; + + + +// cfg_client! { +// pub(crate) mod client; +// pub(crate) use self::client::ClientTask; +// } + +// cfg_server! { +// pub(crate) mod server; +// pub(crate) use self::server::Server; +// } + +// /// Default initial stream window size defined in HTTP2 spec. +// pub(crate) const SPEC_WINDOW_SIZE: u32 = 65_535; + +// // List of connection headers from: +// // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection +// // +// // TE headers are allowed in HTTP/2 requests as long as the value is "trailers", so they're +// // tested separately. +// const CONNECTION_HEADERS: [HeaderName; 5] = [ +// HeaderName::from_static("keep-alive"), +// HeaderName::from_static("proxy-connection"), +// TRAILER, +// TRANSFER_ENCODING, +// UPGRADE, +// ]; + +// fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) { +// for header in &CONNECTION_HEADERS { +// if headers.remove(header).is_some() { +// warn!("Connection header illegal in HTTP/2: {}", header.as_str()); +// } +// } + +// if is_request { +// if headers +// .get(TE) +// .map(|te_header| te_header != "trailers") +// .unwrap_or(false) +// { +// warn!("TE headers not set to \"trailers\" are illegal in HTTP/2 requests"); +// headers.remove(TE); +// } +// } else if headers.remove(TE).is_some() { +// warn!("TE headers illegal in HTTP/2 responses"); +// } + +// if let Some(header) = headers.remove(CONNECTION) { +// warn!( +// "Connection header illegal in HTTP/2: {}", +// CONNECTION.as_str() +// ); +// let header_contents = header.to_str().unwrap(); + +// // A `Connection` header may have a comma-separated list of names of other headers that +// // are meant for only this specific connection. +// // +// // Iterate these names and remove them as headers. Connection-specific headers are +// // forbidden in HTTP2, as that information has been moved into frame types of the h2 +// // protocol. +// for name in header_contents.split(',') { +// let name = name.trim(); +// headers.remove(name); +// } +// } +// } + +// // body adapters used by both Client and Server + +// pin_project! { +// pub(crate) struct PipeToSendStream +// where +// S: Body, +// { +// body_tx: SendStream>, +// data_done: bool, +// #[pin] +// stream: S, +// } +// } + +// impl PipeToSendStream +// where +// S: Body, +// { +// fn new(stream: S, tx: SendStream>) -> PipeToSendStream { +// PipeToSendStream { +// body_tx: tx, +// data_done: false, +// stream, +// } +// } +// } + +// impl Future for PipeToSendStream +// where +// S: Body, +// S::Error: Into>, +// { +// type Output = crate::Result<()>; + +// fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { +// let mut me = self.project(); +// loop { +// // we don't have the next chunk of data yet, so just reserve 1 byte to make +// // sure there's some capacity available. h2 will handle the capacity management +// // for the actual body chunk. +// me.body_tx.reserve_capacity(1); + +// if me.body_tx.capacity() == 0 { +// loop { +// match ready!(me.body_tx.poll_capacity(cx)) { +// Some(Ok(0)) => {} +// Some(Ok(_)) => break, +// Some(Err(e)) => return Poll::Ready(Err(crate::Error::new_body_write(e))), +// None => { +// // None means the stream is no longer in a +// // streaming state, we either finished it +// // somehow, or the remote reset us. +// return Poll::Ready(Err(crate::Error::new_body_write( +// "send stream capacity unexpectedly closed", +// ))); +// } +// } +// } +// } else if let Poll::Ready(reason) = me +// .body_tx +// .poll_reset(cx) +// .map_err(crate::Error::new_body_write)? +// { +// debug!("stream received RST_STREAM: {:?}", reason); +// return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from(reason)))); +// } + +// match ready!(me.stream.as_mut().poll_frame(cx)) { +// Some(Ok(frame)) => { +// if frame.is_data() { +// let chunk = frame.into_data().unwrap_or_else(|_| unreachable!()); +// let is_eos = me.stream.is_end_stream(); +// trace!( +// "send body chunk: {} bytes, eos={}", +// chunk.remaining(), +// is_eos, +// ); + +// let buf = SendBuf::Buf(chunk); +// me.body_tx +// .send_data(buf, is_eos) +// .map_err(crate::Error::new_body_write)?; + +// if is_eos { +// return Poll::Ready(Ok(())); +// } +// } else if frame.is_trailers() { +// // no more DATA, so give any capacity back +// me.body_tx.reserve_capacity(0); +// me.body_tx +// .send_trailers(frame.into_trailers().unwrap_or_else(|_| unreachable!())) +// .map_err(crate::Error::new_body_write)?; +// return Poll::Ready(Ok(())); +// } else { +// trace!("discarding unknown frame"); +// // loop again +// } +// } +// Some(Err(e)) => return Poll::Ready(Err(me.body_tx.on_user_err(e))), +// None => { +// // no more frames means we're done here +// // but at this point, we haven't sent an EOS DATA, or +// // any trailers, so send an empty EOS DATA. +// return Poll::Ready(me.body_tx.send_eos_frame()); +// } +// } +// } +// } +// } + +// trait SendStreamExt { +// fn on_user_err(&mut self, err: E) -> crate::Error +// where +// E: Into>; +// fn send_eos_frame(&mut self) -> crate::Result<()>; +// } + +// impl SendStreamExt for SendStream> { +// fn on_user_err(&mut self, err: E) -> crate::Error +// where +// E: Into>, +// { +// let err = crate::Error::new_user_body(err); +// debug!("send body user stream error: {}", err); +// self.send_reset(err.h2_reason()); +// err +// } + +// fn send_eos_frame(&mut self) -> crate::Result<()> { +// trace!("send body eos"); +// self.send_data(SendBuf::None, true) +// .map_err(crate::Error::new_body_write) +// } +// } + +#[repr(usize)] +enum SendBuf { + Buf(B), + Cursor(Cursor>), + None, +} + +impl Buf for SendBuf { + #[inline] + fn remaining(&self) -> usize { + match *self { + Self::Buf(ref b) => b.remaining(), + Self::Cursor(ref c) => Buf::remaining(c), + Self::None => 0, + } + } + + #[inline] + fn chunk(&self) -> &[u8] { + match *self { + Self::Buf(ref b) => b.chunk(), + Self::Cursor(ref c) => c.chunk(), + Self::None => &[], + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + match *self { + Self::Buf(ref mut b) => b.advance(cnt), + Self::Cursor(ref mut c) => c.advance(cnt), + Self::None => {} + } + } + + fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { + match *self { + Self::Buf(ref b) => b.chunks_vectored(dst), + Self::Cursor(ref c) => c.chunks_vectored(dst), + Self::None => 0, + } + } +} + +struct H2Upgraded +where + B: Buf, +{ + ping: Recorder, + send_stream: UpgradedSendStream, + recv_stream: RecvStream, + buf: Bytes, +} + +impl AsyncRead for H2Upgraded +where + B: Buf, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + read_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.buf.is_empty() { + self.buf = loop { + match ready!(self.recv_stream.poll_data(cx)) { + None => return Poll::Ready(Ok(())), + Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => { + continue + } + Some(Ok(buf)) => { + self.ping.record_data(buf.len()); + break buf; + } + Some(Err(e)) => { + return Poll::Ready(match e.reason() { + Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()), + Some(Reason::STREAM_CLOSED) => { + Err(io::Error::new(io::ErrorKind::BrokenPipe, e)) + } + _ => Err(h2_to_io_error(e)), + }) + } + } + }; + } + let cnt = std::cmp::min(self.buf.len(), read_buf.remaining()); + read_buf.put_slice(&self.buf[..cnt]); + self.buf.advance(cnt); + let _ = self.recv_stream.flow_control().release_capacity(cnt); + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for H2Upgraded +where + B: Buf, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + self.send_stream.reserve_capacity(buf.len()); + + // We ignore all errors returned by `poll_capacity` and `write`, as we + // will get the correct from `poll_reset` anyway. + let cnt = match ready!(self.send_stream.poll_capacity(cx)) { + None => Some(0), + Some(Ok(cnt)) => self + .send_stream + .write(&buf[..cnt], false) + .ok() + .map(|()| cnt), + Some(Err(_)) => None, + }; + + if let Some(cnt) = cnt { + return Poll::Ready(Ok(cnt)); + } + + Poll::Ready(Err(h2_to_io_error( + match ready!(self.send_stream.poll_reset(cx)) { + Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { + return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } + Ok(reason) => reason.into(), + Err(e) => e, + }, + ))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if self.send_stream.write(&[], true).is_ok() { + return Poll::Ready(Ok(())); + } + + Poll::Ready(Err(h2_to_io_error( + match ready!(self.send_stream.poll_reset(cx)) { + Ok(Reason::NO_ERROR) => return Poll::Ready(Ok(())), + Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { + return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } + Ok(reason) => reason.into(), + Err(e) => e, + }, + ))) + } +} + +fn h2_to_io_error(e: h2::Error) -> io::Error { + if e.is_io() { + e.into_io().unwrap() + } else { + io::Error::new(io::ErrorKind::Other, e) + } +} + +struct UpgradedSendStream(SendStream>>); + +impl UpgradedSendStream +where + B: Buf, +{ + unsafe fn new(inner: SendStream>) -> Self { + assert_eq!(mem::size_of::(), mem::size_of::>()); + Self(mem::transmute(inner)) + } + + fn reserve_capacity(&mut self, cnt: usize) { + unsafe { self.as_inner_unchecked().reserve_capacity(cnt) } + } + + fn poll_capacity(&mut self, cx: &mut Context<'_>) -> Poll>> { + unsafe { self.as_inner_unchecked().poll_capacity(cx) } + } + + fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll> { + unsafe { self.as_inner_unchecked().poll_reset(cx) } + } + + fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> { + let send_buf = SendBuf::Cursor(Cursor::new(buf.into())); + unsafe { + self.as_inner_unchecked() + .send_data(send_buf, end_of_stream) + .map_err(h2_to_io_error) + } + } + + unsafe fn as_inner_unchecked(&mut self) -> &mut SendStream> { + &mut *(&mut self.0 as *mut _ as *mut _) + } +} + +#[repr(transparent)] +struct Neutered { + _inner: B, + impossible: Impossible, +} + +enum Impossible {} + +unsafe impl Send for Neutered {} + +impl Buf for Neutered { + fn remaining(&self) -> usize { + match self.impossible {} + } + + fn chunk(&self) -> &[u8] { + match self.impossible {} + } + + fn advance(&mut self, _cnt: usize) { + match self.impossible {} + } +} diff --git a/h2mux/src/ping.rs b/h2mux/src/ping.rs new file mode 100644 index 0000000..297e0c7 --- /dev/null +++ b/h2mux/src/ping.rs @@ -0,0 +1,506 @@ +/// HTTP2 Ping usage +/// +/// hyper uses HTTP2 pings for two purposes: +/// +/// 1. Adaptive flow control using BDP +/// 2. Connection keep-alive +/// +/// Both cases are optional. +/// +/// # BDP Algorithm +/// +/// 1. When receiving a DATA frame, if a BDP ping isn't outstanding: +/// 1a. Record current time. +/// 1b. Send a BDP ping. +/// 2. Increment the number of received bytes. +/// 3. When the BDP ping ack is received: +/// 3a. Record duration from sent time. +/// 3b. Merge RTT with a running average. +/// 3c. Calculate bdp as bytes/rtt. +/// 3d. If bdp is over 2/3 max, set new max to bdp and update windows. +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{self, Poll}; +use std::time::{Duration, Instant}; + +use h2::{Ping, PingPong}; +use tracing::{debug, trace}; + +use crate::common::time::Time; +use crate::rt::Sleep; + +type WindowSize = u32; + +pub(super) fn disabled() -> Recorder { + Recorder { shared: None } +} + +pub(super) fn channel(ping_pong: PingPong, config: Config, __timer: Time) -> (Recorder, Ponger) { + debug_assert!( + config.is_enabled(), + "ping channel requires bdp or keep-alive config", + ); + + let bdp = config.bdp_initial_window.map(|wnd| Bdp { + bdp: wnd, + max_bandwidth: 0.0, + rtt: 0.0, + ping_delay: Duration::from_millis(100), + stable_count: 0, + }); + + let (bytes, next_bdp_at) = if bdp.is_some() { + (Some(0), Some(Instant::now())) + } else { + (None, None) + }; + + let keep_alive = config.keep_alive_interval.map(|interval| KeepAlive { + interval, + timeout: config.keep_alive_timeout, + while_idle: config.keep_alive_while_idle, + sleep: __timer.sleep(interval), + state: KeepAliveState::Init, + timer: __timer, + }); + + let last_read_at = keep_alive.as_ref().map(|_| Instant::now()); + + let shared = Arc::new(Mutex::new(Shared { + bytes, + last_read_at, + is_keep_alive_timed_out: false, + ping_pong, + ping_sent_at: None, + next_bdp_at, + })); + + ( + Recorder { + shared: Some(shared.clone()), + }, + Ponger { + bdp, + keep_alive, + shared, + }, + ) +} + +#[derive(Clone)] +pub(super) struct Config { + pub(super) bdp_initial_window: Option, + /// If no frames are received in this amount of time, a PING frame is sent. + pub(super) keep_alive_interval: Option, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + pub(super) keep_alive_timeout: Duration, + /// If true, sends pings even when there are no active streams. + pub(super) keep_alive_while_idle: bool, +} + +#[derive(Clone)] +pub(crate) struct Recorder { + shared: Option>>, +} + +pub(super) struct Ponger { + bdp: Option, + keep_alive: Option, + shared: Arc>, +} + +struct Shared { + ping_pong: PingPong, + ping_sent_at: Option, + + // bdp + /// If `Some`, bdp is enabled, and this tracks how many bytes have been + /// read during the current sample. + bytes: Option, + /// We delay a variable amount of time between BDP pings. This allows us + /// to send less pings as the bandwidth stabilizes. + next_bdp_at: Option, + + // keep-alive + /// If `Some`, keep-alive is enabled, and the Instant is how long ago + /// the connection read the last frame. + last_read_at: Option, + + is_keep_alive_timed_out: bool, +} + +struct Bdp { + /// Current BDP in bytes + bdp: u32, + /// Largest bandwidth we've seen so far. + max_bandwidth: f64, + /// Round trip time in seconds + rtt: f64, + /// Delay the next ping by this amount. + /// + /// This will change depending on how stable the current bandwidth is. + ping_delay: Duration, + /// The count of ping round trips where BDP has stayed the same. + stable_count: u32, +} + +struct KeepAlive { + /// If no frames are received in this amount of time, a PING frame is sent. + interval: Duration, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + timeout: Duration, + /// If true, sends pings even when there are no active streams. + while_idle: bool, + state: KeepAliveState, + sleep: Pin>, + timer: Time, +} + +enum KeepAliveState { + Init, + Scheduled(Instant), + PingSent, +} + +pub(super) enum Ponged { + SizeUpdate(WindowSize), + KeepAliveTimedOut, +} + +#[derive(Debug)] +pub(super) struct KeepAliveTimedOut; + +// ===== impl Config ===== + +impl Config { + pub(super) fn is_enabled(&self) -> bool { + self.bdp_initial_window.is_some() || self.keep_alive_interval.is_some() + } +} + +// ===== impl Recorder ===== + +impl Recorder { + pub(crate) fn record_data(&self, len: usize) { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + locked.update_last_read_at(); + + // are we ready to send another bdp ping? + // if not, we don't need to record bytes either + + if let Some(ref next_bdp_at) = locked.next_bdp_at { + if Instant::now() < *next_bdp_at { + return; + } else { + locked.next_bdp_at = None; + } + } + + if let Some(ref mut bytes) = locked.bytes { + *bytes += len; + } else { + // no need to send bdp ping if bdp is disabled + return; + } + + if !locked.is_ping_sent() { + locked.send_ping(); + } + } + + pub(crate) fn record_non_data(&self) { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + locked.update_last_read_at(); + } + + /// If the incoming stream is already closed, convert self into + /// a disabled reporter. + #[cfg(feature = "client")] + pub(super) fn for_stream(self, stream: &h2::RecvStream) -> Self { + if stream.is_end_stream() { + disabled() + } else { + self + } + } + + pub(super) fn ensure_not_timed_out(&self) -> crate::Result<()> { + if let Some(ref shared) = self.shared { + let locked = shared.lock().unwrap(); + if locked.is_keep_alive_timed_out { + return Err(KeepAliveTimedOut.crate_error()); + } + } + + // else + Ok(()) + } +} + +// ===== impl Ponger ===== + +impl Ponger { + pub(super) fn poll(&mut self, cx: &mut task::Context<'_>) -> Poll { + let now = Instant::now(); + let mut locked = self.shared.lock().unwrap(); + let is_idle = self.is_idle(); + + if let Some(ref mut ka) = self.keep_alive { + ka.maybe_schedule(is_idle, &locked); + ka.maybe_ping(cx, &mut locked); + } + + if !locked.is_ping_sent() { + // XXX: this doesn't register a waker...? + return Poll::Pending; + } + + match locked.ping_pong.poll_pong(cx) { + Poll::Ready(Ok(_pong)) => { + let start = locked + .ping_sent_at + .expect("pong received implies ping_sent_at"); + locked.ping_sent_at = None; + let rtt = now - start; + trace!("recv pong"); + + if let Some(ref mut ka) = self.keep_alive { + locked.update_last_read_at(); + ka.maybe_schedule(is_idle, &locked); + ka.maybe_ping(cx, &mut locked); + } + + if let Some(ref mut bdp) = self.bdp { + let bytes = locked.bytes.expect("bdp enabled implies bytes"); + locked.bytes = Some(0); // reset + trace!("received BDP ack; bytes = {}, rtt = {:?}", bytes, rtt); + + let update = bdp.calculate(bytes, rtt); + locked.next_bdp_at = Some(now + bdp.ping_delay); + if let Some(update) = update { + return Poll::Ready(Ponged::SizeUpdate(update)); + } + } + } + Poll::Ready(Err(e)) => { + debug!("pong error: {}", e); + } + Poll::Pending => { + if let Some(ref mut ka) = self.keep_alive { + if let Err(KeepAliveTimedOut) = ka.maybe_timeout(cx) { + self.keep_alive = None; + locked.is_keep_alive_timed_out = true; + return Poll::Ready(Ponged::KeepAliveTimedOut); + } + } + } + } + + // XXX: this doesn't register a waker...? + Poll::Pending + } + + fn is_idle(&self) -> bool { + Arc::strong_count(&self.shared) <= 2 + } +} + +// ===== impl Shared ===== + +impl Shared { + fn send_ping(&mut self) { + match self.ping_pong.send_ping(Ping::opaque()) { + Ok(()) => { + self.ping_sent_at = Some(Instant::now()); + trace!("sent ping"); + } + Err(err) => { + debug!("error sending ping: {}", err); + } + } + } + + fn is_ping_sent(&self) -> bool { + self.ping_sent_at.is_some() + } + + fn update_last_read_at(&mut self) { + if self.last_read_at.is_some() { + self.last_read_at = Some(Instant::now()); + } + } + + fn last_read_at(&self) -> Instant { + self.last_read_at.expect("keep_alive expects last_read_at") + } +} + +// ===== impl Bdp ===== + +/// Any higher than this likely will be hitting the TCP flow control. +const BDP_LIMIT: usize = 1024 * 1024 * 16; + +impl Bdp { + fn calculate(&mut self, bytes: usize, rtt: Duration) -> Option { + // No need to do any math if we're at the limit. + if self.bdp as usize == BDP_LIMIT { + self.stabilize_delay(); + return None; + } + + // average the rtt + let rtt = seconds(rtt); + if self.rtt == 0.0 { + // First sample means rtt is first rtt. + self.rtt = rtt; + } else { + // Weigh this rtt as 1/8 for a moving average. + self.rtt += (rtt - self.rtt) * 0.125; + } + + // calculate the current bandwidth + let bw = (bytes as f64) / (self.rtt * 1.5); + trace!("current bandwidth = {:.1}B/s", bw); + + if bw < self.max_bandwidth { + // not a faster bandwidth, so don't update + self.stabilize_delay(); + return None; + } else { + self.max_bandwidth = bw; + } + + // if the current `bytes` sample is at least 2/3 the previous + // bdp, increase to double the current sample. + if bytes >= self.bdp as usize * 2 / 3 { + self.bdp = (bytes * 2).min(BDP_LIMIT) as WindowSize; + trace!("BDP increased to {}", self.bdp); + + self.stable_count = 0; + self.ping_delay /= 2; + Some(self.bdp) + } else { + self.stabilize_delay(); + None + } + } + + fn stabilize_delay(&mut self) { + if self.ping_delay < Duration::from_secs(10) { + self.stable_count += 1; + + if self.stable_count >= 2 { + self.ping_delay *= 4; + self.stable_count = 0; + } + } + } +} + +fn seconds(dur: Duration) -> f64 { + const NANOS_PER_SEC: f64 = 1_000_000_000.0; + let secs = dur.as_secs() as f64; + secs + (dur.subsec_nanos() as f64) / NANOS_PER_SEC +} + +// ===== impl KeepAlive ===== + +impl KeepAlive { + fn maybe_schedule(&mut self, is_idle: bool, shared: &Shared) { + match self.state { + KeepAliveState::Init => { + if !self.while_idle && is_idle { + return; + } + + self.schedule(shared); + } + KeepAliveState::PingSent => { + if shared.is_ping_sent() { + return; + } + self.schedule(shared); + } + KeepAliveState::Scheduled(..) => (), + } + } + + fn schedule(&mut self, shared: &Shared) { + let interval = shared.last_read_at() + self.interval; + self.state = KeepAliveState::Scheduled(interval); + self.timer.reset(&mut self.sleep, interval); + } + + fn maybe_ping(&mut self, cx: &mut task::Context<'_>, shared: &mut Shared) { + match self.state { + KeepAliveState::Scheduled(at) => { + if Pin::new(&mut self.sleep).poll(cx).is_pending() { + return; + } + // check if we've received a frame while we were scheduled + if shared.last_read_at() + self.interval > at { + self.state = KeepAliveState::Init; + cx.waker().wake_by_ref(); // schedule us again + return; + } + trace!("keep-alive interval ({:?}) reached", self.interval); + shared.send_ping(); + self.state = KeepAliveState::PingSent; + let timeout = Instant::now() + self.timeout; + self.timer.reset(&mut self.sleep, timeout); + } + KeepAliveState::Init | KeepAliveState::PingSent => (), + } + } + + fn maybe_timeout(&mut self, cx: &mut task::Context<'_>) -> Result<(), KeepAliveTimedOut> { + match self.state { + KeepAliveState::PingSent => { + if Pin::new(&mut self.sleep).poll(cx).is_pending() { + return Ok(()); + } + trace!("keep-alive timeout ({:?}) reached", self.timeout); + Err(KeepAliveTimedOut) + } + KeepAliveState::Init | KeepAliveState::Scheduled(..) => Ok(()), + } + } +} + +// ===== impl KeepAliveTimedOut ===== + +impl KeepAliveTimedOut { + pub(super) fn crate_error(self) -> crate::Error { + crate::Error::new(crate::error::Kind::Http2).with(self) + } +} + +impl fmt::Display for KeepAliveTimedOut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("keep-alive timed out") + } +} + +impl std::error::Error for KeepAliveTimedOut { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&crate::error::TimedOut) + } +} diff --git a/h2mux/src/utils.rs b/h2mux/src/utils.rs new file mode 100644 index 0000000..b0f6d59 --- /dev/null +++ b/h2mux/src/utils.rs @@ -0,0 +1,22 @@ +use std::io; + +use h2; + +pub trait H2MapIoErr { + fn map_io_err(self) -> Result; +} + +impl H2MapIoErr for Result { + fn map_io_err(self) -> Result { + match self { + Ok(ok) => Ok(ok), + Err(err) => { + if err.is_io() { + Err(err.into_io().unwrap()) + } else { + Err(io::Error::new(io::ErrorKind::InvalidData, err.to_string())) + } + } + } + } +}