diff --git a/Cargo.lock b/Cargo.lock index e99641d2..c1f440bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1227,9 +1227,9 @@ dependencies = [ [[package]] name = "quictransport-quinn" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d65e68ccd0d35f48bbe9723a2924e6829b819cd165b19d7b2914f85ede01398" +checksum = "04227967142d740ffc66367bad009a6315e626c4830fcdb9d55904bd3ae1e3f9" dependencies = [ "bytes", "quinn", @@ -2201,19 +2201,18 @@ checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" [[package]] name = "webtransport-generic" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc1fd0d5c7e24e485aa58040fba18d6a4204d4354eca19d34b14540ecd9147b8" +checksum = "3796cc7d83f889b8fd4c1a731b08d83618ea1a3a2e3fe09225562754acc9b814" dependencies = [ "bytes", - "tokio", ] [[package]] name = "webtransport-proto" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebeada5037d6302980ae2e0ab8d840e329c1697c612c6c077172de2b7631a276" +checksum = "7de84935ba0f2292c5f78f042758fc4a0ce506699e674d059c517f56b04091be" dependencies = [ "bytes", "http", @@ -2223,9 +2222,9 @@ dependencies = [ [[package]] name = "webtransport-quinn" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27b0ad39e557756d066277901c3024e586aa3e026e4ee9b377f2d45d782e39ff" +checksum = "18ac46cd68286fd4a70bc69dcab62bafbbe407927ae643f3bfce26f27090b19a" dependencies = [ "bytes", "futures", diff --git a/moq-api/src/client.rs b/moq-api/src/client.rs index 5f07d110..905d8ea8 100644 --- a/moq-api/src/client.rs +++ b/moq-api/src/client.rs @@ -27,16 +27,16 @@ impl Client { Ok(Some(origin)) } - pub async fn set_origin(&mut self, id: &str, origin: &Origin) -> Result<(), ApiError> { + pub async fn set_origin(&self, id: &str, origin: Origin) -> Result<(), 
ApiError> { let url = self.url.join("origin/")?.join(id)?; - let resp = self.client.post(url).json(origin).send().await?; + let resp = self.client.post(url).json(&origin).send().await?; resp.error_for_status()?; Ok(()) } - pub async fn delete_origin(&mut self, id: &str) -> Result<(), ApiError> { + pub async fn delete_origin(&self, id: &str) -> Result<(), ApiError> { let url = self.url.join("origin/")?.join(id)?; let resp = self.client.delete(url).send().await?; @@ -45,10 +45,10 @@ impl Client { Ok(()) } - pub async fn patch_origin(&mut self, id: &str, origin: &Origin) -> Result<(), ApiError> { + pub async fn patch_origin(&self, id: &str, origin: Origin) -> Result<(), ApiError> { let url = self.url.join("origin/")?.join(id)?; - let resp = self.client.patch(url).json(origin).send().await?; + let resp = self.client.patch(url).json(&origin).send().await?; resp.error_for_status()?; Ok(()) diff --git a/moq-api/src/server.rs b/moq-api/src/server.rs index 8a05c16a..15b86e72 100644 --- a/moq-api/src/server.rs +++ b/moq-api/src/server.rs @@ -90,6 +90,18 @@ async fn set_origin( // Convert the input back to JSON after validating it add adding any fields (TODO) let payload = serde_json::to_string(&origin)?; + // Attempt to get the current value for the key + let current: Option = redis::cmd("GET").arg(&key).query_async(&mut redis).await?; + + if let Some(current) = ¤t { + if current.eq(&payload) { + // The value is the same, so we're done. 
+ return Ok(()); + } else { + return Err(AppError::Duplicate); + } + } + let res: Option = redis::cmd("SET") .arg(key) .arg(payload) diff --git a/moq-clock/Cargo.toml b/moq-clock/Cargo.toml index 47e6fd70..63de6554 100644 --- a/moq-clock/Cargo.toml +++ b/moq-clock/Cargo.toml @@ -18,9 +18,9 @@ moq-transport = { path = "../moq-transport", version = "0.3" } # QUIC quinn = "0.10" -webtransport-quinn = "0.8" -webtransport-generic = "0.8" -quictransport-quinn = "0.8" +webtransport-quinn = { version = "0.9" } +webtransport-generic = { version = "0.9" } +quictransport-quinn = { version = "0.9" } url = "2" # Crypto diff --git a/moq-clock/src/clock.rs b/moq-clock/src/clock.rs index e372e934..fec26c13 100644 --- a/moq-clock/src/clock.rs +++ b/moq-clock/src/clock.rs @@ -1,14 +1,17 @@ use anyhow::Context; -use moq_transport::serve; +use moq_transport::serve::{ + DatagramsReader, Group, GroupWriter, GroupsReader, GroupsWriter, ObjectsReader, StreamReader, TrackReader, + TrackReaderMode, +}; use chrono::prelude::*; pub struct Publisher { - track: serve::TrackPublisher, + track: GroupsWriter, } impl Publisher { - pub fn new(track: serve::TrackPublisher) -> Self { + pub fn new(track: GroupsWriter) -> Self { Self { track } } @@ -22,9 +25,9 @@ impl Publisher { loop { let segment = self .track - .create_group(serve::Group { - id: sequence as u64, - send_order: 0, + .create(Group { + group_id: sequence as u64, + priority: 0, }) .context("failed to create minute segment")?; @@ -46,19 +49,15 @@ impl Publisher { } } - async fn send_segment(mut segment: serve::GroupPublisher, mut now: DateTime) -> anyhow::Result<()> { + async fn send_segment(mut segment: GroupWriter, mut now: DateTime) -> anyhow::Result<()> { // Everything but the second. 
let base = now.format("%Y-%m-%d %H:%M:").to_string(); - segment - .write_object(base.clone().into()) - .context("failed to write base")?; + segment.write(base.clone().into()).context("failed to write base")?; loop { let delta = now.format("%S").to_string(); - segment - .write_object(delta.clone().into()) - .context("failed to write delta")?; + segment.write(delta.clone().into()).context("failed to write delta")?; println!("{}{}", base, delta); @@ -79,83 +78,69 @@ impl Publisher { } } pub struct Subscriber { - track: serve::TrackSubscriber, + track: TrackReader, } impl Subscriber { - pub fn new(track: serve::TrackSubscriber) -> Self { + pub fn new(track: TrackReader) -> Self { Self { track } } - pub async fn run(mut self) -> anyhow::Result<()> { - while let Some(stream) = self.track.next().await.context("failed to get stream")? { - match stream { - serve::TrackMode::Group(group) => tokio::spawn(async move { - if let Err(err) = Self::recv_group(group).await { - log::warn!("failed to receive group: {:?}", err); - } - }), - serve::TrackMode::Object(object) => tokio::spawn(async move { - if let Err(err) = Self::recv_object(object).await { - log::warn!("failed to receive group: {:?}", err); - } - }), - serve::TrackMode::Stream(stream) => tokio::spawn(async move { - if let Err(err) = Self::recv_track(stream).await { - log::warn!("failed to receive stream: {:?}", err); - } - }), - serve::TrackMode::Datagram(datagram) => tokio::spawn(async move { - if let Err(err) = Self::recv_datagram(datagram) { - log::warn!("failed to receive datagram: {:?}", err); - } - }), - }; + pub async fn run(self) -> anyhow::Result<()> { + match self.track.mode().await.context("failed to get mode")? 
{ + TrackReaderMode::Stream(stream) => Self::recv_stream(stream).await, + TrackReaderMode::Groups(groups) => Self::recv_groups(groups).await, + TrackReaderMode::Objects(objects) => Self::recv_objects(objects).await, + TrackReaderMode::Datagrams(datagrams) => Self::recv_datagrams(datagrams).await, } - - Ok(()) } - async fn recv_track(mut track: serve::StreamSubscriber) -> anyhow::Result<()> { - while let Some(fragment) = track.next().await? { - let str = String::from_utf8_lossy(&fragment.payload); - println!("{}", str); + async fn recv_stream(mut track: StreamReader) -> anyhow::Result<()> { + while let Some(mut group) = track.next().await? { + while let Some(object) = group.read_next().await? { + let str = String::from_utf8_lossy(&object); + println!("{}", str); + } } Ok(()) } - async fn recv_group(mut segment: serve::GroupSubscriber) -> anyhow::Result<()> { - let mut first = segment - .next() - .await - .context("failed to get first fragment")? - .context("no fragments in segment")?; - - let base = first.read_all().await?; - let base = String::from_utf8_lossy(&base); + async fn recv_groups(mut groups: GroupsReader) -> anyhow::Result<()> { + while let Some(mut group) = groups.next().await? { + let base = group + .read_next() + .await + .context("failed to get first object")? + .context("empty group")?; - while let Some(mut fragment) = segment.next().await? { - let value = fragment.read_all().await.context("failed to read fragment")?; - let str = String::from_utf8_lossy(&value); + let base = String::from_utf8_lossy(&base); - println!("{}{}", base, str); + while let Some(object) = group.read_next().await? 
{ + let str = String::from_utf8_lossy(&object); + println!("{}{}", base, str); + } } Ok(()) } - async fn recv_object(mut object: serve::ObjectSubscriber) -> anyhow::Result<()> { - let value = object.read_all().await.context("failed to read object")?; - let str = String::from_utf8_lossy(&value); + async fn recv_objects(mut objects: ObjectsReader) -> anyhow::Result<()> { + while let Some(mut object) = objects.next().await? { + let payload = object.read_all().await?; + let str = String::from_utf8_lossy(&payload); + println!("{}", str); + } - println!("{}", str); Ok(()) } - fn recv_datagram(datagram: serve::Datagram) -> anyhow::Result<()> { - let str = String::from_utf8_lossy(&datagram.payload); - println!("{}", str); + async fn recv_datagrams(mut datagrams: DatagramsReader) -> anyhow::Result<()> { + while let Some(datagram) = datagrams.read().await? { + let str = String::from_utf8_lossy(&datagram.payload); + println!("{}", str); + } + Ok(()) } } diff --git a/moq-clock/src/main.rs b/moq-clock/src/main.rs index 85e9a605..4cb1778b 100644 --- a/moq-clock/src/main.rs +++ b/moq-clock/src/main.rs @@ -92,7 +92,7 @@ async fn main() -> anyhow::Result<()> { async fn run(session: S, config: cli::Config) -> anyhow::Result<()> { if config.publish { - let (session, publisher) = moq_transport::Publisher::connect(session) + let (session, mut publisher) = moq_transport::Publisher::connect(session) .await .context("failed to create MoQ Transport session")?; @@ -102,7 +102,7 @@ async fn run(session: S, config: cli::Config) .produce(); let track = broadcast.create_track(&config.track)?; - let clock = clock::Publisher::new(track); + let clock = clock::Publisher::new(track.groups()?); tokio::select! 
{ res = session.run() => res.context("session error")?, diff --git a/moq-pub/Cargo.toml b/moq-pub/Cargo.toml index bc37c131..b3560238 100644 --- a/moq-pub/Cargo.toml +++ b/moq-pub/Cargo.toml @@ -18,9 +18,9 @@ moq-transport = { path = "../moq-transport", version = "0.3" } # QUIC quinn = "0.10" -webtransport-quinn = "0.8" -quictransport-quinn = "0.8" -webtransport-generic = "0.8" +webtransport-quinn = { version = "0.9" } +quictransport-quinn = { version = "0.9" } +webtransport-generic = { version = "0.9" } url = "2" # Crypto diff --git a/moq-pub/src/main.rs b/moq-pub/src/main.rs index bb42bbf8..662e71c6 100644 --- a/moq-pub/src/main.rs +++ b/moq-pub/src/main.rs @@ -12,7 +12,7 @@ use tokio::io::AsyncRead; // TODO: clap complete -#[tokio::main] +#[tokio::main(flavor = "current_thread")] async fn main() -> anyhow::Result<()> { env_logger::init(); @@ -99,9 +99,9 @@ async fn main() -> anyhow::Result<()> { async fn run( session: T, mut media: Media, - broadcast: serve::BroadcastSubscriber, + broadcast: serve::BroadcastReader, ) -> anyhow::Result<()> { - let (session, publisher) = moq_transport::Publisher::connect(session) + let (session, mut publisher) = moq_transport::Publisher::connect(session) .await .context("failed to create MoQ Transport publisher")?; diff --git a/moq-pub/src/media.rs b/moq-pub/src/media.rs index eadf295d..14a3091d 100644 --- a/moq-pub/src/media.rs +++ b/moq-pub/src/media.rs @@ -1,5 +1,5 @@ use anyhow::{self, Context}; -use moq_transport::serve::{BroadcastPublisher, Group, GroupPublisher, TrackPublisher}; +use moq_transport::serve::{BroadcastWriter, GroupWriter, GroupsWriter, TrackWriter}; use mp4::{self, ReadBox}; use serde_json::json; use std::cmp::max; @@ -15,7 +15,7 @@ pub struct Media { } impl Media { - pub async fn new(mut input: I, mut broadcast: BroadcastPublisher) -> anyhow::Result { + pub async fn new(mut input: I, mut broadcast: BroadcastWriter) -> anyhow::Result { let ftyp = read_atom(&mut input).await?; anyhow::ensure!(&ftyp[4..8] == 
b"ftyp", "expected ftyp atom"); @@ -34,11 +34,8 @@ impl Media { let moov = mp4::MoovBox::read_box(&mut moov_reader, moov_header.size)?; // Create the catalog track with a single segment. - let mut init_track = broadcast.create_track("0.mp4")?; - - init_track - .create_group(Group { id: 0, send_order: 0 })? - .write_object(init.into())?; + let mut init_track = broadcast.create_track("0.mp4")?.groups()?; + init_track.next(0)?.write(init.into())?; let mut tracks = HashMap::new(); @@ -54,10 +51,10 @@ impl Media { tracks.insert(id, track); } - let mut catalog = broadcast.create_track(".catalog")?; + let catalog = broadcast.create_track(".catalog")?; // Create the catalog track - Self::serve_catalog(&mut catalog, &init_track.name, &moov)?; + Self::serve_catalog(catalog, &init_track.name, &moov)?; Ok(Media { tracks, input }) } @@ -105,12 +102,8 @@ impl Media { } } - fn serve_catalog( - track: &mut TrackPublisher, - init_track_name: &str, - moov: &mp4::MoovBox, - ) -> Result<(), anyhow::Error> { - let mut segment = track.create_group(Group { id: 0, send_order: 0 })?; + fn serve_catalog(track: TrackWriter, init_track_name: &str, moov: &mp4::MoovBox) -> Result<(), anyhow::Error> { + let mut segment = track.groups()?.next(0)?; let mut tracks = Vec::new(); @@ -192,7 +185,7 @@ impl Media { log::info!("catalog: {}", catalog_str); // Create a single fragment for the segment. - segment.write_object(catalog_str.into())?; + segment.write(catalog_str.into())?; Ok(()) } @@ -237,23 +230,19 @@ async fn read_atom(reader: &mut R) -> anyhow::Result, + current: Option, // The number of units per second. timescale: u64, - - // The number of segments produced. 
- sequence: u64, } impl Track { - fn new(track: TrackPublisher, timescale: u64) -> Self { + fn new(track: TrackWriter, timescale: u64) -> Self { Self { - track, - sequence: 0, + track: track.groups().unwrap(), current: None, timescale, } @@ -263,7 +252,7 @@ impl Track { if let Some(current) = self.current.as_mut() { if !fragment.keyframe { // Use the existing segment - current.write_object(raw.into())?; + current.write(raw.into())?; return Ok(()); } } @@ -278,18 +267,13 @@ impl Track { .try_into() .context("timestamp too large")?; - // Create a new segment. - let mut segment = self.track.create_group(Group { - id: self.sequence, - - // Newer segments are higher priority - send_order: u32::MAX.checked_sub(timestamp).context("priority too large")?.into(), - })?; + let priority = u32::MAX.checked_sub(timestamp).context("priority too large")?.into(); - self.sequence += 1; + // Create a new segment. + let mut segment = self.track.next(priority)?; // Write the fragment in it's own object. - segment.write_object(raw.into())?; + segment.write(raw.into())?; // Save for the next iteration self.current = Some(segment); @@ -299,7 +283,7 @@ impl Track { pub fn data(&mut self, raw: Vec) -> anyhow::Result<()> { let segment = self.current.as_mut().context("missing current fragment")?; - segment.write_object(raw.into())?; + segment.write(raw.into())?; Ok(()) } diff --git a/moq-relay/Cargo.toml b/moq-relay/Cargo.toml index 9d0138e0..169ebaef 100644 --- a/moq-relay/Cargo.toml +++ b/moq-relay/Cargo.toml @@ -17,9 +17,9 @@ moq-api = { path = "../moq-api", version = "0.0.1" } # QUIC quinn = "0.10" -quictransport-quinn = "0.8" -webtransport-quinn = "0.8" -webtransport-generic = "0.8" +quictransport-quinn = { version = "0.9" } +webtransport-quinn = { version = "0.9" } +webtransport-generic = { version = "0.9" } url = "2" # Crypto diff --git a/moq-relay/src/connection.rs b/moq-relay/src/connection.rs index f866a66a..0d46cbb0 100644 --- a/moq-relay/src/connection.rs +++ 
b/moq-relay/src/connection.rs @@ -1,18 +1,22 @@ use anyhow::Context; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; -use moq_transport::session::{Announced, Publisher, SessionError, Subscriber}; +use moq_transport::{ + serve::ServeError, + session::{Announced, Publisher, SessionError, Subscribed, Subscriber}, +}; -use crate::{Origin, OriginPublisher}; +use crate::{LocalsConsumer, LocalsProducer, RelayError, RemotesConsumer}; #[derive(Clone)] pub struct Connection { - origin: Origin, + locals: (LocalsProducer, LocalsConsumer), + remotes: Option, } impl Connection { - pub fn new(origin: Origin) -> Self { - Self { origin } + pub fn new(locals: (LocalsProducer, LocalsConsumer), remotes: Option) -> Self { + Self { locals, remotes } } pub async fn run(self, mut conn: quinn::Connecting) -> anyhow::Result<()> { @@ -70,16 +74,16 @@ impl Connection { let mut tasks = FuturesUnordered::new(); tasks.push(session.run().boxed()); - if let Some(publisher) = publisher { - tasks.push(Self::serve_publisher(publisher, self.origin.clone()).boxed()); + if let Some(remote) = publisher { + tasks.push(Self::serve_subscriber(self.clone(), remote).boxed()); } - if let Some(subscriber) = subscriber { - tasks.push(Self::serve_subscriber(subscriber, self.origin).boxed()); + if let Some(remote) = subscriber { + tasks.push(Self::serve_publisher(self.clone(), remote).boxed()); } // Return the first error - tasks.next().await.unwrap()?; + tasks.select_next_some().await?; Ok(()) } @@ -92,80 +96,156 @@ impl Connection { let mut tasks = FuturesUnordered::new(); tasks.push(session.run().boxed()); - if let Some(publisher) = publisher { - tasks.push(Self::serve_publisher(publisher, self.origin.clone()).boxed()); + if let Some(remote) = publisher { + tasks.push(Self::serve_subscriber(self.clone(), remote).boxed()); } - if let Some(subscriber) = subscriber { - tasks.push(Self::serve_subscriber(subscriber, self.origin).boxed()); + if let Some(remote) = subscriber { + 
tasks.push(Self::serve_publisher(self.clone(), remote).boxed()); } // Return the first error - tasks.next().await.unwrap()?; + tasks.select_next_some().await?; Ok(()) } - async fn serve_publisher( - mut publisher: Publisher, - origin: Origin, + async fn serve_subscriber( + self, + mut remote: Publisher, ) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); loop { tokio::select! { - res = tasks.next(), if !tasks.is_empty() => { - if let Err(err) = res.unwrap() { - log::info!("failed serving subscribe: err={}", err) - } + subscribe = remote.subscribed() => { + let conn = self.clone(); + + tasks.push(async move { + let info = subscribe.info.clone(); + log::info!("serving subscribe: {:?}", info); + + if let Err(err) = conn.serve_subscribe(subscribe).await { + log::warn!("failed serving subscribe: {:?}, error: {}", info, err) + } + }) }, - res = publisher.subscribed() => { - let subscribe = res?; - log::info!("serving subscribe: namespace={} name={}", subscribe.namespace(), subscribe.name()); + _= tasks.next(), if !tasks.is_empty() => {}, + }; + } + } + + async fn serve_subscribe( + self, + subscribe: Subscribed, + ) -> Result<(), RelayError> { + if let Some(local) = self.locals.1.route(&subscribe.namespace) { + log::debug!("using local announce: {:?}", local.info); + if let Some(track) = local.subscribe(&subscribe.name)? { + log::info!("serving from local: {:?}", track.info); + // NOTE: Depends on drop(track) being called afterwards + return Ok(subscribe.serve(track.reader).await?); + } + } - let track = origin.subscribe(subscribe.namespace(), subscribe.name())?; - tasks.push(subscribe.serve(track).boxed()); + if let Some(remotes) = &self.remotes { + if let Some(remote) = remotes.route(&subscribe.namespace).await? { + log::debug!("using remote announce: {:?}", remote.info); + if let Some(track) = remote.subscribe(&subscribe.namespace, &subscribe.name)? 
{ + log::info!("serving from remote: {:?} {:?}", remote.info, track.info); + + // NOTE: Depends on drop(track) being called afterwards + return Ok(subscribe.serve(track.reader).await?); } - }; + } } + + Err(ServeError::NotFound.into()) } - async fn serve_subscriber( - mut subscriber: Subscriber, - origin: Origin, + async fn serve_publisher( + self, + mut remote: Subscriber, ) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); loop { tokio::select! { - res = tasks.next(), if !tasks.is_empty() => { - if let Err(err) = res.unwrap() { - log::info!("failed serving announce: err={}", err) - } + announce = remote.announced() => { + let remote = remote.clone(); + let conn = self.clone(); + + tasks.push(async move { + let info = announce.info.clone(); + log::info!("serving announce: {:?}", info); + + if let Err(err) = conn.serve_announce(remote, announce).await { + log::warn!("failed serving announce: {:?}, error: {}", info, err) + } + }); }, - res = subscriber.announced() => { - let announce = res?; - log::info!("serving announce: namespace={}", announce.namespace()); - - let publisher = origin.announce(announce.namespace())?; - tasks.push(Self::serve_announce(subscriber.clone(), publisher, announce)); - } + _ = tasks.next(), if !tasks.is_empty() => {}, }; } } async fn serve_announce( - mut subscriber: Subscriber, - mut publisher: OriginPublisher, + mut self, + remote: Subscriber, mut announce: Announced, - ) -> Result<(), SessionError> { - // Send ANNOUNCE_OK - // We sent ANNOUNCE_CANCEL when the scope drops - announce.accept()?; + ) -> Result<(), RelayError> { + let mut publisher = match self.locals.0.announce(&announce.namespace).await { + Ok(publisher) => { + announce.ok()?; + publisher + } + Err(err) => { + // TODO use better error codes + announce.close(err.clone().into())?; + return Err(err); + } + }; + + let mut tasks = FuturesUnordered::new(); + + let mut done = None; loop { - let track = publisher.requested().await?; - 
subscriber.subscribe(track)?; + tokio::select! { + // If the announce is closed, return the error + res = announce.closed(), if done.is_none() => done = Some(res), + + // Wait for the next subscriber and serve the track. + res = publisher.requested(), if done.is_none() => { + let track = match res? { + Some(track) => track, + None => { + done = Some(Ok(())); + continue + }, + }; + + let mut subscriber = remote.clone(); + + tasks.push(async move { + let info = track.info.clone(); + log::info!("relaying track: track={:?}", info); + + let res = match subscriber.subscribe(track) { + Ok(subscribe) => subscribe.closed().await, + Err(err) => Err(err), + }; + + if let Err(err) = res { + log::warn!("failed serving track: {:?}, error: {}", info, err) + } + }); + }, + _ = tasks.next(), if !tasks.is_empty() => {} + + // Done must be set and there are no tasks left + else => return Ok(done.unwrap()?), + } } } } diff --git a/moq-relay/src/error.rs b/moq-relay/src/error.rs index 9e1565f9..65258c6f 100644 --- a/moq-relay/src/error.rs +++ b/moq-relay/src/error.rs @@ -1,25 +1,37 @@ +use std::sync::Arc; + +use moq_transport::serve::ServeError; use thiserror::Error; -#[derive(Error, Debug)] +#[derive(Error, Debug, Clone)] pub enum RelayError { - #[error("transport error: {0}")] - Transport(#[from] moq_transport::session::SessionError), + #[error("session error: {0}")] + Transport(#[from] moq_transport::SessionError), #[error("serve error: {0}")] - Cache(#[from] moq_transport::serve::ServeError), + Serve(#[from] ServeError), #[error("api error: {0}")] - MoqApi(#[from] moq_api::ApiError), + Api(#[from] Arc), #[error("url error: {0}")] Url(#[from] url::ParseError), #[error("webtransport client error: {0}")] - WebTransportClient(#[from] webtransport_quinn::ClientError), + Client(#[from] webtransport_quinn::ClientError), #[error("webtransport server error: {0}")] - WebTransportServer(#[from] webtransport_quinn::ServerError), + Server(#[from] webtransport_quinn::ServerError), 
#[error("missing node")] MissingNode, } + +impl From for ServeError { + fn from(err: RelayError) -> Self { + match err { + RelayError::Serve(err) => err, + _ => ServeError::Internal(err.to_string()), + } + } +} diff --git a/moq-relay/src/local.rs b/moq-relay/src/local.rs new file mode 100644 index 00000000..d8c93390 --- /dev/null +++ b/moq-relay/src/local.rs @@ -0,0 +1,348 @@ +use std::collections::hash_map; +use std::collections::HashMap; + +use std::collections::VecDeque; +use std::fmt; +use std::ops; +use std::sync::Arc; +use std::sync::Weak; + +use moq_transport::serve::{self, ServeError, TrackReader, TrackWriter}; +use moq_transport::util::State; +use tokio::time; +use url::Url; + +use crate::RelayError; + +pub struct Locals { + pub api: Option, + pub node: Option, +} + +impl Locals { + pub fn produce(self) -> (LocalsProducer, LocalsConsumer) { + let (send, recv) = State::init(); + let info = Arc::new(self); + + let producer = LocalsProducer::new(info.clone(), send); + let consumer = LocalsConsumer::new(info, recv); + + (producer, consumer) + } +} + +#[derive(Default)] +struct LocalsState { + lookup: HashMap, +} + +#[derive(Clone)] +pub struct LocalsProducer { + info: Arc, + state: State, +} + +impl LocalsProducer { + fn new(info: Arc, state: State) -> Self { + Self { info, state } + } + + pub async fn announce(&mut self, namespace: &str) -> Result { + let (mut writer, reader) = Local { + namespace: namespace.to_string(), + locals: self.info.clone(), + } + .produce(self.clone()); + + // Try to insert with the API. 
+ writer.register().await?; + + let mut state = self.state.lock_mut().unwrap(); + match state.lookup.entry(namespace.to_string()) { + hash_map::Entry::Vacant(entry) => entry.insert(reader), + hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate.into()), + }; + + Ok(writer) + } + + async fn unannounce(&mut self, namespace: &str) -> Result<(), RelayError> { + if let Some(mut state) = self.state.lock_mut() { + state.lookup.remove(namespace).ok_or(ServeError::NotFound)?; + } + + if let Some(api) = self.api.as_ref() { + api.delete_origin(namespace).await.map_err(Arc::new)?; + } + + Ok(()) + } +} + +impl ops::Deref for LocalsProducer { + type Target = Locals; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +#[derive(Clone)] +pub struct LocalsConsumer { + pub info: Arc, + state: State, +} + +impl LocalsConsumer { + fn new(info: Arc, state: State) -> Self { + Self { info, state } + } + + pub fn route(&self, namespace: &str) -> Option { + let state = self.state.lock(); + state.lookup.get(namespace).cloned() + } +} + +impl ops::Deref for LocalsConsumer { + type Target = Locals; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +pub struct Local { + pub namespace: String, + pub locals: Arc, +} + +impl Local { + /// Create a new broadcast. + fn produce(self, parent: LocalsProducer) -> (LocalProducer, LocalConsumer) { + let (send, recv) = State::init(); + let info = Arc::new(self); + + let writer = LocalProducer::new(info.clone(), send, parent); + let reader = LocalConsumer::new(info, recv); + + (writer, reader) + } +} + +impl fmt::Debug for Local { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Local").field("namespace", &self.namespace).finish() + } +} + +#[derive(Default)] +struct LocalState { + tracks: HashMap, + requested: VecDeque, +} + +impl Drop for LocalState { + fn drop(&mut self) { + for track in self.requested.drain(..) 
{ + track.close(ServeError::NotFound).ok(); + } + } +} + +/// Publish new tracks for a broadcast by name. +pub struct LocalProducer { + pub info: Arc, + state: State, + + parent: LocalsProducer, + refresh: tokio::time::Interval, +} + +impl LocalProducer { + fn new(info: Arc, state: State, parent: LocalsProducer) -> Self { + let delay = time::Duration::from_secs(300); + let mut refresh = time::interval(delay); + refresh.reset_after(delay); // Skip the first tick + + Self { + info, + state, + refresh, + parent, + } + } + + /// Block until the next track requested by a reader. + pub async fn requested(&mut self) -> Result, RelayError> { + loop { + let notify = { + let state = self.state.lock(); + if !state.requested.is_empty() { + return Ok(state.into_mut().and_then(|mut state| state.requested.pop_front())); + } + + match state.modified() { + Some(notify) => notify, + None => return Ok(None), + } + }; + + tokio::select! { + // TODO make this fully async so we don't block requested() + _ = self.refresh.tick() => self.register().await?, + _ = notify => {}, + } + } + } + + pub async fn register(&mut self) -> Result<(), RelayError> { + if let (Some(api), Some(node)) = (self.info.locals.api.as_ref(), self.info.locals.node.as_ref()) { + // Refresh the origin in moq-api. + let origin = moq_api::Origin { url: node.clone() }; + log::debug!("registering origin: namespace={} url={}", self.namespace, node); + api.set_origin(&self.namespace, origin).await.map_err(Arc::new)?; + } + + Ok(()) + } +} + +impl ops::Deref for LocalProducer { + type Target = Local; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +impl Drop for LocalProducer { + fn drop(&mut self) { + // TODO this is super lazy, but doing async stuff in Drop is annoying. + let mut parent = self.parent.clone(); + let namespace = self.namespace.clone(); + tokio::spawn(async move { parent.unannounce(&namespace).await }); + } +} + +/// Subscribe to a broadcast by requesting tracks. 
+/// +/// This can be cloned to create handles. +#[derive(Clone)] +pub struct LocalConsumer { + pub info: Arc, + state: State, +} + +impl LocalConsumer { + fn new(info: Arc, state: State) -> Self { + Self { info, state } + } + + pub fn subscribe(&self, name: &str) -> Result, RelayError> { + let state = self.state.lock(); + + // Try to reuse the track if there are still active readers + if let Some(track) = state.tracks.get(name) { + if let Some(track) = track.upgrade() { + return Ok(Some(track)); + } + } + + // Create a new track. + let (writer, reader) = serve::Track { + namespace: self.info.namespace.clone(), + name: name.to_string(), + } + .produce(); + + let reader = LocalTrackReader::new(reader, self.state.clone()); + + // Upgrade the lock to mutable. + let mut state = match state.into_mut() { + Some(state) => state, + None => return Ok(None), + }; + + // Insert the track into our Map so we deduplicate future requests. + state.tracks.insert(name.to_string(), reader.downgrade()); + + // Send the track to the writer to handle. 
+ state.requested.push_back(writer); + + Ok(Some(reader)) + } +} + +impl ops::Deref for LocalConsumer { + type Target = Local; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +#[derive(Clone)] +pub struct LocalTrackReader { + pub reader: TrackReader, + drop: Arc, +} + +impl LocalTrackReader { + fn new(reader: TrackReader, parent: State) -> Self { + let drop = Arc::new(LocalTrackDrop { + parent, + name: reader.name.clone(), + }); + + Self { reader, drop } + } + + fn downgrade(&self) -> LocalTrackWeak { + LocalTrackWeak { + reader: self.reader.clone(), + drop: Arc::downgrade(&self.drop), + } + } +} + +impl ops::Deref for LocalTrackReader { + type Target = TrackReader; + + fn deref(&self) -> &Self::Target { + &self.reader + } +} + +impl ops::DerefMut for LocalTrackReader { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.reader + } +} + +struct LocalTrackWeak { + reader: TrackReader, + drop: Weak, +} + +impl LocalTrackWeak { + fn upgrade(&self) -> Option { + Some(LocalTrackReader { + reader: self.reader.clone(), + drop: self.drop.upgrade()?, + }) + } +} + +struct LocalTrackDrop { + parent: State, + name: String, +} + +impl Drop for LocalTrackDrop { + fn drop(&mut self) { + if let Some(mut parent) = self.parent.lock_mut() { + parent.tracks.remove(&self.name); + } + } +} diff --git a/moq-relay/src/main.rs b/moq-relay/src/main.rs index 126c7d76..2cb98755 100644 --- a/moq-relay/src/main.rs +++ b/moq-relay/src/main.rs @@ -4,20 +4,22 @@ use clap::Parser; mod config; mod connection; mod error; -mod origin; -mod quic; +mod local; +mod relay; +mod remote; mod tls; mod web; pub use config::*; pub use connection::*; pub use error::*; -pub use origin::*; -pub use quic::*; +pub use local::*; +pub use relay::*; +pub use remote::*; pub use tls::*; pub use web::*; -#[tokio::main] +#[tokio::main(flavor = "current_thread")] async fn main() -> anyhow::Result<()> { env_logger::init(); @@ -31,7 +33,7 @@ async fn main() -> anyhow::Result<()> { let tls = 
Tls::load(&config)?; // Create a QUIC server for media. - let quic = Quic::new(config.clone(), tls.clone()) + let relay = Relay::new(config.clone(), tls.clone()) .await .context("failed to create server")?; @@ -42,10 +44,10 @@ async fn main() -> anyhow::Result<()> { // Unfortunately we can't use preconditions because Tokio still executes the branch; just ignore the result tokio::select! { - res = quic.serve() => res.context("failed to run quic server"), + res = relay.run() => res.context("failed to run quic server"), res = web.serve() => res.context("failed to run web server"), } } else { - quic.serve().await.context("failed to run quic server") + relay.run().await.context("failed to run quic server") } } diff --git a/moq-relay/src/origin.rs b/moq-relay/src/origin.rs deleted file mode 100644 index 37afab77..00000000 --- a/moq-relay/src/origin.rs +++ /dev/null @@ -1,390 +0,0 @@ -use std::collections::hash_map; -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; - -use std::collections::VecDeque; - -use moq_transport::serve::{self, ServeError, TrackPublisher, TrackSubscriber}; -use moq_transport::session::SessionError; -use moq_transport::util::Watch; -use url::Url; - -#[derive(Clone)] -pub struct Origin { - // An API client used to get/set broadcasts. - // If None then we never use a remote origin. - // TODO: Stub this out instead. - _api: Option, - - // The internal address of our node. - // If None then we can never advertise ourselves as an origin. - // TODO: Stub this out instead. - _node: Option, - - // A map of active broadcasts by namespace. - origins: Arc>>, - - // A QUIC endpoint we'll use to fetch from other origins. 
- _quic: quinn::Endpoint, -} - -impl Origin { - pub fn new(_api: Option, _node: Option, _quic: quinn::Endpoint) -> Self { - Self { - _api, - _node, - origins: Default::default(), - _quic, - } - } - - pub fn announce(&self, namespace: &str) -> Result { - let mut origins = self.origins.lock().unwrap(); - let entry = match origins.entry(namespace.to_string()) { - hash_map::Entry::Vacant(entry) => entry, - hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate.into()), - }; - - let (publisher, subscriber) = self.produce(namespace); - entry.insert(subscriber); - - Ok(publisher) - } - - /* - // Create a publisher that constantly updates itself as the origin in moq-api. - // It holds a reference to the subscriber to prevent dropping early. - let mut publisher = Publisher { - broadcast: publisher, - subscriber, - api: None, - }; - - // Insert the publisher into the database. - if let Some(api) = self.api.as_mut() { - // Make a URL for the broadcast. - let url = self.node.as_ref().ok_or(RelayError::MissingNode)?.clone().join(id)?; - let origin = moq_api::Origin { url }; - api.set_origin(id, &origin).await?; - - // Refresh every 5 minutes - publisher.api = Some((api.clone(), origin)); - } - - - Ok(()) - */ - - pub fn subscribe(&self, namespace: &str, name: &str) -> Result { - let mut origin = self - .origins - .lock() - .unwrap() - .get(namespace) - .cloned() - .ok_or(ServeError::NotFound)?; - - let track = origin.request_track(name)?; - Ok(track) - /* - let mut routes = self.local.lock().unwrap(); - - if let Some(broadcast) = routes.get(id) { - if let Some(broadcast) = broadcast.upgrade() { - return broadcast; - } - } - - let (publisher, subscriber) = broadcast::new(id); - let subscriber = Arc::new(Subscriber { - broadcast: subscriber, - origin: self.clone(), - }); - - cache.insert(id.to_string(), Arc::downgrade(&subscriber)); - - let mut this = self.clone(); - let id = id.to_string(); - - // Rather than fetching from the API and connecting via QUIC inline, we'll 
spawn a task to do it. - // This way we could stop polling this session and it won't impact other session. - // It also means we'll only connect the API and QUIC once if N subscribers suddenly show up. - // However, the downside is that we don't return an error immediately. - // If that's important, it can be done but it gets a bit racey. - tokio::spawn(async move { - if let Err(err) = this.serve(&id, publisher).await { - log::warn!("failed to serve remote broadcast: id={} err={}", id, err); - } - }); - - subscriber - */ - } - - /* - async fn serve(&mut self, id: &str, publisher: broadcast::Publisher) -> Result<(), RelayError> { - log::debug!("finding origin: id={}", id); - - // Fetch the origin from the API. - let origin = self - .api - .as_mut() - .ok_or(ServeError::NotFound)? - .get_origin(id) - .await? - .ok_or(ServeError::NotFound)?; - - log::debug!("fetching from origin: id={} url={}", id, origin.url); - - // Establish the webtransport session. - let session = webtransport_quinn::connect(&self.quic, &origin.url).await?; - let session = moq_transport::session::Client::subscriber(session, publisher).await?; - - session.run().await?; - - Ok(()) - } - */ - - /// Create a new broadcast. - fn produce(&self, namespace: &str) -> (OriginPublisher, OriginSubscriber) { - let state = Watch::new(State::new(namespace)); - - let publisher = OriginPublisher::new(state.clone()); - let subscriber = OriginSubscriber::new(state); - - (publisher, subscriber) - } -} - -#[derive(Debug)] -struct State { - namespace: String, - tracks: HashMap, - requested: VecDeque, - closed: Result<(), ServeError>, -} - -impl State { - pub fn new(namespace: &str) -> Self { - Self { - namespace: namespace.to_string(), - tracks: HashMap::new(), - requested: VecDeque::new(), - closed: Ok(()), - } - } - - pub fn get_track(&self, name: &str) -> Result, ServeError> { - // Insert the track into our Map so we deduplicate future requests. 
- if let Some(track) = self.tracks.get(name) { - return Ok(Some(track.clone())); - } - - self.closed.clone()?; - Ok(None) - } - - pub fn request_track(&mut self, name: &str) -> Result { - // Insert the track into our Map so we deduplicate future requests. - let entry = match self.tracks.entry(name.to_string()) { - hash_map::Entry::Vacant(entry) => entry, - hash_map::Entry::Occupied(entry) => return Ok(entry.get().clone()), - }; - - self.closed.clone()?; - - // Create a new track. - let (publisher, subscriber) = serve::Track { - namespace: self.namespace.clone(), - name: name.to_string(), - } - .produce(); - - // Deduplicate with others - // TODO This should be weak - entry.insert(subscriber.clone()); - - // Send the track to the Publisher to handle. - self.requested.push_back(publisher); - - Ok(subscriber) - } - - pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.closed.clone()?; - self.closed = Err(err); - Ok(()) - } -} - -impl Drop for State { - fn drop(&mut self) { - for mut track in self.requested.drain(..) { - track.close(ServeError::NotFound).ok(); - } - } -} - -/// Publish new tracks for a broadcast by name. -pub struct OriginPublisher { - state: Watch, -} - -impl OriginPublisher { - fn new(state: Watch) -> Self { - Self { state } - } - - /// Block until the next track requested by a subscriber. - pub async fn requested(&mut self) -> Result { - loop { - let notify = { - let state = self.state.lock(); - if !state.requested.is_empty() { - return Ok(state.into_mut().requested.pop_front().unwrap()); - } - - state.closed.clone()?; - state.changed() - }; - - notify.await; - } - } -} - -impl Drop for OriginPublisher { - fn drop(&mut self) { - self.state.lock_mut().close(ServeError::Done).ok(); - } -} - -/// Subscribe to a broadcast by requesting tracks. -/// -/// This can be cloned to create handles. 
-#[derive(Clone)] -pub struct OriginSubscriber { - state: Watch, - _dropped: Arc, -} - -impl OriginSubscriber { - fn new(state: Watch) -> Self { - let _dropped = Arc::new(Dropped::new(state.clone())); - Self { state, _dropped } - } - - pub fn get_track(&self, name: &str) -> Result, ServeError> { - self.state.lock_mut().get_track(name) - } - - pub fn request_track(&mut self, name: &str) -> Result { - self.state.lock_mut().request_track(name) - } - - /// Wait until if the broadcast is closed, either because the publisher was dropped or called [Publisher::close]. - pub async fn closed(&self) -> ServeError { - loop { - let notify = { - let state = self.state.lock(); - if let Some(err) = state.closed.as_ref().err() { - return err.clone(); - } - - state.changed() - }; - - notify.await; - } - } -} - -struct Dropped { - state: Watch, -} - -impl Dropped { - fn new(state: Watch) -> Self { - Self { state } - } -} - -impl Drop for Dropped { - fn drop(&mut self) { - self.state.lock_mut().close(ServeError::Done).ok(); - } -} - -/* -pub struct Subscriber { - pub broadcast: broadcast::Subscriber, - - origin: Origin, -} - -impl Drop for Subscriber { - fn drop(&mut self) { - self.origin.cache.lock().unwrap().remove(&self.broadcast.id); - } -} - -impl Deref for Subscriber { - type Target = broadcast::Subscriber; - - fn deref(&self) -> &Self::Target { - &self.broadcast - } -} - -pub struct Publisher { - pub broadcast: broadcast::Publisher, - - api: Option<(moq_api::Client, moq_api::Origin)>, - - #[allow(dead_code)] - subscriber: Arc, -} - -impl Publisher { - pub async fn run(&mut self) -> Result<(), ApiError> { - // Every 5m tell the API we're still alive. 
- // TODO don't hard-code these values - let mut interval = time::interval(time::Duration::from_secs(60 * 5)); - - loop { - if let Some((api, origin)) = self.api.as_mut() { - api.patch_origin(&self.broadcast.id, origin).await?; - } - - // TODO move to start of loop; this is just for testing - interval.tick().await; - } - } - - pub async fn close(&mut self) -> Result<(), ApiError> { - if let Some((api, _)) = self.api.as_mut() { - api.delete_origin(&self.broadcast.id).await?; - } - - Ok(()) - } -} - -impl Deref for Publisher { - type Target = broadcast::Publisher; - - fn deref(&self) -> &Self::Target { - &self.broadcast - } -} - -impl DerefMut for Publisher { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.broadcast - } -} - -*/ diff --git a/moq-relay/src/quic.rs b/moq-relay/src/relay.rs similarity index 68% rename from moq-relay/src/quic.rs rename to moq-relay/src/relay.rs index 8727450e..945be6f5 100644 --- a/moq-relay/src/quic.rs +++ b/moq-relay/src/relay.rs @@ -4,19 +4,18 @@ use anyhow::Context; use tokio::task::JoinSet; -use crate::{Config, Connection, Origin, Tls}; +use crate::{ + Config, Connection, Locals, LocalsConsumer, LocalsProducer, Remotes, RemotesConsumer, RemotesProducer, Tls, +}; -pub struct Quic { +pub struct Relay { quic: quinn::Endpoint, - // The active connections. - conns: JoinSet>, - - // The map of active broadcasts by path. - origin: Origin, + locals: (LocalsProducer, LocalsConsumer), + remotes: Option<(RemotesProducer, RemotesConsumer)>, } -impl Quic { +impl Relay { // Create a QUIC endpoint that can be used for both clients and servers. 
pub async fn new(config: Config, tls: Tls) -> anyhow::Result { let mut client_config = tls.client.clone(); @@ -53,32 +52,47 @@ impl Quic { moq_api::Client::new(url) }); - if let Some(ref node) = config.api_node { + let node = config.api_node.map(|node| { log::info!("advertising origin: url={}", node); - } + node + }); - let origin = Origin::new(api, config.api_node, quic.clone()); - let conns = JoinSet::new(); + let remotes = api.clone().map(|api| { + Remotes { + api, + quic: quic.clone(), + } + .produce() + }); + let locals = Locals { api, node }.produce(); - Ok(Self { quic, origin, conns }) + Ok(Self { quic, locals, remotes }) } - pub async fn serve(mut self) -> anyhow::Result<()> { + pub async fn run(self) -> anyhow::Result<()> { log::info!("listening on {}", self.quic.local_addr()?); + let mut tasks = JoinSet::new(); + + let remotes = self.remotes.map(|(producer, consumer)| { + tasks.spawn(producer.run()); + consumer + }); + loop { tokio::select! { res = self.quic.accept() => { let conn = res.context("failed to accept QUIC connection")?; - let session = Connection::new(self.origin.clone()); - self.conns.spawn(session.run(conn)); - }, - res = self.conns.join_next(), if !self.conns.is_empty() => { - let res = res.expect("no tasks").expect("task aborted"); - if let Err(err) = res { - log::warn!("connection terminated: {:?}", err); - } + let session = Connection::new(self.locals.clone(), remotes.clone()); + + tasks.spawn(async move { + if let Err(err) = session.run(conn).await { + log::warn!("connection terminated: {:?}", err); + } + Ok(()) + }); }, + res = tasks.join_next(), if !tasks.is_empty() => res.expect("no tasks").expect("task aborted")?, } } } diff --git a/moq-relay/src/remote.rs b/moq-relay/src/remote.rs new file mode 100644 index 00000000..c89e30de --- /dev/null +++ b/moq-relay/src/remote.rs @@ -0,0 +1,420 @@ +use std::collections::HashMap; + +use std::collections::VecDeque; +use std::fmt; +use std::ops; +use std::sync::Arc; +use std::sync::Weak; + +use 
futures::stream::FuturesUnordered; +use futures::FutureExt; +use futures::StreamExt; +use moq_transport::serve::{Track, TrackReader, TrackWriter}; +use moq_transport::util::State; +use url::Url; + +use crate::RelayError; + +pub struct Remotes { + /// The client we use to fetch/store origin information. + pub api: moq_api::Client, + + // A QUIC endpoint we'll use to fetch from other origins. + pub quic: quinn::Endpoint, +} + +impl Remotes { + pub fn produce(self) -> (RemotesProducer, RemotesConsumer) { + let (send, recv) = State::init(); + let info = Arc::new(self); + + let producer = RemotesProducer::new(info.clone(), send); + let consumer = RemotesConsumer::new(info, recv); + + (producer, consumer) + } +} + +#[derive(Default)] +struct RemotesState { + lookup: HashMap, + requested: VecDeque, +} + +// Clone for convenience, but there should only be one instance of this +#[derive(Clone)] +pub struct RemotesProducer { + info: Arc, + state: State, +} + +impl RemotesProducer { + fn new(info: Arc, state: State) -> Self { + Self { info, state } + } + + async fn next(&mut self) -> Result, RelayError> { + loop { + let notify = { + let state = self.state.lock(); + if !state.requested.is_empty() { + return Ok(state.into_mut().and_then(|mut state| state.requested.pop_front())); + } + + match state.modified() { + Some(notified) => notified, + None => return Ok(None), + } + }; + + notify.await + } + } + + pub async fn run(mut self) -> Result<(), RelayError> { + let mut tasks = FuturesUnordered::new(); + + loop { + tokio::select! { + remote = self.next() => { + let remote = match remote? 
{ + Some(remote) => remote, + None => return Ok(()), + }; + + let url = remote.url.clone(); + + tasks.push(async move { + let info = remote.info.clone(); + + log::warn!("serving remote: {:?}", info); + if let Err(err) = remote.run().await { + log::warn!("failed serving remote: {:?}, error: {}", info, err); + } + + url + }); + } + res = tasks.next(), if !tasks.is_empty() => { + let url = res.unwrap(); + + if let Some(mut state) = self.state.lock_mut() { + state.lookup.remove(&url); + } + }, + } + } + } +} + +impl ops::Deref for RemotesProducer { + type Target = Remotes; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +#[derive(Clone)] +pub struct RemotesConsumer { + pub info: Arc, + state: State, +} + +impl RemotesConsumer { + fn new(info: Arc, state: State) -> Self { + Self { info, state } + } + + pub async fn route(&self, namespace: &str) -> Result, RelayError> { + // Always fetch the origin instead of using the (potentially invalid) cache. + let origin = match self.api.get_origin(namespace).await.map_err(Arc::new)? 
{ + None => return Ok(None), + Some(origin) => origin, + }; + + let state = self.state.lock(); + if let Some(remote) = state.lookup.get(&origin.url).cloned() { + return Ok(Some(remote)); + } + + let mut state = match state.into_mut() { + Some(state) => state, + None => return Ok(None), + }; + + let remote = Remote { + url: origin.url.clone(), + remotes: self.info.clone(), + }; + + let (writer, reader) = remote.produce(); + state.requested.push_back(writer); + + state.lookup.insert(origin.url, reader.clone()); + + Ok(Some(reader)) + } +} + +impl ops::Deref for RemotesConsumer { + type Target = Remotes; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +pub struct Remote { + pub remotes: Arc, + pub url: Url, +} + +impl fmt::Debug for Remote { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Remote").field("url", &self.url.to_string()).finish() + } +} + +impl ops::Deref for Remote { + type Target = Remotes; + + fn deref(&self) -> &Self::Target { + &self.remotes + } +} + +impl Remote { + /// Create a new broadcast. 
+ pub fn produce(self) -> (RemoteProducer, RemoteConsumer) { + let (send, recv) = State::init(); + let info = Arc::new(self); + + let consumer = RemoteConsumer::new(info.clone(), recv); + let producer = RemoteProducer::new(info, send); + + (producer, consumer) + } +} + +struct RemoteState { + tracks: HashMap<(String, String), RemoteTrackWeak>, + requested: VecDeque, + closed: Result<(), RelayError>, +} + +impl Default for RemoteState { + fn default() -> Self { + Self { + tracks: HashMap::new(), + requested: VecDeque::new(), + closed: Ok(()), + } + } +} + +pub struct RemoteProducer { + pub info: Arc, + state: State, +} + +impl RemoteProducer { + fn new(info: Arc, state: State) -> Self { + Self { info, state } + } + + pub async fn run(mut self) -> Result<(), RelayError> { + if let Err(err) = self.run_inner().await { + if let Some(mut state) = self.state.lock_mut() { + state.closed = Err(err.clone()); + } + + return Err(err); + } + + Ok(()) + } + + pub async fn run_inner(&mut self) -> Result<(), RelayError> { + // TODO reuse QUIC and MoQ sessions + let session = webtransport_quinn::connect(&self.quic, &self.url).await?; + let (session, mut subscriber) = moq_transport::Subscriber::connect(session).await?; + + // Run the session + let mut session = session.run().boxed(); + let mut tasks = FuturesUnordered::new(); + + let mut done = None; + + loop { + tokio::select! 
{ + track = self.next(), if done.is_none() => { + let track = match track { + Ok(Some(track)) => track, + Ok(None) => { done = Some(Ok(())); continue }, + Err(err) => { done = Some(Err(err)); continue }, + }; + + let info = track.info.clone(); + + let subscribe = match subscriber.subscribe(track) { + Ok(subscribe) => subscribe, + Err(err) => { + log::warn!("failed subscribing: {:?}, error: {}", info, err); + continue + } + }; + + tasks.push(async move { + if let Err(err) = subscribe.closed().await { + log::warn!("failed serving track: {:?}, error: {}", info, err); + } + }); + } + _ = tasks.next(), if !tasks.is_empty() => {}, + + // Keep running the session + res = &mut session, if !tasks.is_empty() || done.is_none() => return Ok(res?), + + else => return done.unwrap(), + } + } + } + + /// Block until the next track requested by a consumer. + async fn next(&self) -> Result, RelayError> { + loop { + let notify = { + let state = self.state.lock(); + if !state.requested.is_empty() { + return Ok(state.into_mut().and_then(|mut state| state.requested.pop_front())); + } + + match state.modified() { + Some(notified) => notified, + None => return Ok(None), + } + }; + + notify.await + } + } +} + +impl ops::Deref for RemoteProducer { + type Target = Remote; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +#[derive(Clone)] +pub struct RemoteConsumer { + pub info: Arc, + state: State, +} + +impl RemoteConsumer { + fn new(info: Arc, state: State) -> Self { + Self { info, state } + } + + /// Request a track from the broadcast. 
+ pub fn subscribe(&self, namespace: &str, name: &str) -> Result, RelayError> { + let key = (namespace.to_string(), name.to_string()); + let state = self.state.lock(); + if let Some(track) = state.tracks.get(&key) { + if let Some(track) = track.upgrade() { + return Ok(Some(track)); + } + } + + let mut state = match state.into_mut() { + Some(state) => state, + None => return Ok(None), + }; + + let (writer, reader) = Track::new(namespace, name).produce(); + let reader = RemoteTrackReader::new(reader, self.state.clone()); + + // Insert the track into our Map so we deduplicate future requests. + state.tracks.insert(key, reader.downgrade()); + state.requested.push_back(writer); + + Ok(Some(reader)) + } +} + +impl ops::Deref for RemoteConsumer { + type Target = Remote; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +#[derive(Clone)] +pub struct RemoteTrackReader { + pub reader: TrackReader, + drop: Arc, +} + +impl RemoteTrackReader { + fn new(reader: TrackReader, parent: State) -> Self { + let drop = Arc::new(RemoteTrackDrop { + parent, + key: (reader.namespace.clone(), reader.name.clone()), + }); + + Self { reader, drop } + } + + fn downgrade(&self) -> RemoteTrackWeak { + RemoteTrackWeak { + reader: self.reader.clone(), + drop: Arc::downgrade(&self.drop), + } + } +} + +impl ops::Deref for RemoteTrackReader { + type Target = TrackReader; + + fn deref(&self) -> &Self::Target { + &self.reader + } +} + +impl ops::DerefMut for RemoteTrackReader { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.reader + } +} + +struct RemoteTrackWeak { + reader: TrackReader, + drop: Weak, +} + +impl RemoteTrackWeak { + fn upgrade(&self) -> Option { + Some(RemoteTrackReader { + reader: self.reader.clone(), + drop: self.drop.upgrade()?, + }) + } +} + +struct RemoteTrackDrop { + parent: State, + key: (String, String), +} + +impl Drop for RemoteTrackDrop { + fn drop(&mut self) { + if let Some(mut parent) = self.parent.lock_mut() { + parent.tracks.remove(&self.key); + } 
+ } +} diff --git a/moq-transport/Cargo.toml b/moq-transport/Cargo.toml index 00343c00..f00c76d9 100644 --- a/moq-transport/Cargo.toml +++ b/moq-transport/Cargo.toml @@ -20,7 +20,7 @@ thiserror = "1" tokio = { version = "1", features = ["macros", "io-util", "sync"] } log = "0.4" -webtransport-generic = "0.8" +webtransport-generic = { version = "0.9" } paste = "1" futures = "0.3" diff --git a/moq-transport/src/coding/mod.rs b/moq-transport/src/coding/mod.rs index 8753e205..a3ff6f78 100644 --- a/moq-transport/src/coding/mod.rs +++ b/moq-transport/src/coding/mod.rs @@ -1,14 +1,10 @@ mod decode; mod encode; mod params; -mod reader; mod string; mod varint; -mod writer; pub use decode::*; pub use encode::*; pub use params::*; -pub use reader::*; pub use varint::*; -pub use writer::*; diff --git a/moq-transport/src/coding/reader.rs b/moq-transport/src/coding/reader.rs index df823a52..0338b98a 100644 --- a/moq-transport/src/coding/reader.rs +++ b/moq-transport/src/coding/reader.rs @@ -63,7 +63,11 @@ impl Reader { Ok(self.buffer.is_empty() && self.stream.read_buf(&mut self.buffer).await? 
== 0) } - pub fn into_inner(self) -> (bytes::BytesMut, S) { - (self.buffer, self.stream) + pub fn buffered(&self) -> &[u8] { + &self.buffer + } + + pub fn into_inner(self) -> S { + self.stream } } diff --git a/moq-transport/src/coding/writer.rs b/moq-transport/src/coding/writer.rs deleted file mode 100644 index 25bb2a0e..00000000 --- a/moq-transport/src/coding/writer.rs +++ /dev/null @@ -1,36 +0,0 @@ -use tokio::io::{AsyncWrite, AsyncWriteExt}; - -use crate::coding::Encode; - -use super::EncodeError; - -pub struct Writer { - stream: S, - buffer: bytes::BytesMut, -} - -impl Writer { - pub fn new(stream: S) -> Self { - Self { - stream, - buffer: Default::default(), - } - } - - pub async fn encode(&mut self, msg: &T) -> Result<(), EncodeError> { - self.buffer.clear(); - msg.encode(&mut self.buffer)?; - self.stream.write_all(&self.buffer).await?; - - Ok(()) - } - - pub async fn write(&mut self, buf: &[u8]) -> Result<(), EncodeError> { - self.stream.write_all(buf).await?; - Ok(()) - } - - pub fn into_inner(self) -> S { - self.stream - } -} diff --git a/moq-transport/src/lib.rs b/moq-transport/src/lib.rs index 2456c30f..3e58f272 100644 --- a/moq-transport/src/lib.rs +++ b/moq-transport/src/lib.rs @@ -13,4 +13,4 @@ pub mod session; pub mod setup; pub mod util; -pub use session::{Publisher, Session, Subscriber}; +pub use session::{Publisher, Session, SessionError, Subscriber}; diff --git a/moq-transport/src/serve/broadcast.rs b/moq-transport/src/serve/broadcast.rs index e85d4733..caaabedc 100644 --- a/moq-transport/src/serve/broadcast.rs +++ b/moq-transport/src/serve/broadcast.rs @@ -1,24 +1,23 @@ -//! A broadcast is a collection of tracks, split into two handles: [Publisher] and [Subscriber]. +//! A broadcast is a collection of tracks, split into two handles: [Writer] and [Reader]. //! -//! The [Publisher] can create tracks, either manually or on request. -//! It receives all requests by a [Subscriber] for a tracks that don't exist. +//! 
The [Writer] can create tracks, either manually or on request. +//! It receives all requests by a [Reader] for a tracks that don't exist. //! The simplest implementation is to close every unknown track with [ServeError::NotFound]. //! -//! A [Subscriber] can request tracks by name. +//! A [Reader] can request tracks by name. //! If the track already exists, it will be returned. //! If the track doesn't exist, it will be sent to [Unknown] to be handled. -//! A [Subscriber] can be cloned to create multiple subscriptions. +//! A [Reader] can be cloned to create multiple subscriptions. //! -//! The broadcast is automatically closed with [ServeError::Done] when [Publisher] is dropped, or all [Subscriber]s are dropped. +//! The broadcast is automatically closed with [ServeError::Done] when [Writer] is dropped, or all [Reader]s are dropped. use std::{ collections::{hash_map, HashMap}, - fmt, ops::Deref, sync::Arc, }; -use super::{ServeError, Track, TrackPublisher, TrackSubscriber}; -use crate::util::Watch; +use super::{ServeError, Track, TrackReader, TrackWriter}; +use crate::util::State; /// Static information about a broadcast. #[derive(Debug)] @@ -33,26 +32,25 @@ impl Broadcast { } } - pub fn produce(self) -> (BroadcastPublisher, BroadcastSubscriber) { - let state = Watch::new(State::default()); + pub fn produce(self) -> (BroadcastWriter, BroadcastReader) { + let (send, recv) = State::init(); let info = Arc::new(self); - let publisher = BroadcastPublisher::new(state.clone(), info.clone()); - let subscriber = BroadcastSubscriber::new(state, info); + let writer = BroadcastWriter::new(send, info.clone()); + let reader = BroadcastReader::new(recv, info); - (publisher, subscriber) + (writer, reader) } } /// Dynamic information about the broadcast. 
-#[derive(Debug)] -struct State { - tracks: HashMap, +struct BroadcastState { + tracks: HashMap, closed: Result<(), ServeError>, } -impl State { - pub fn get(&self, name: &str) -> Result, ServeError> { +impl BroadcastState { + pub fn get(&self, name: &str) -> Result, ServeError> { match self.tracks.get(name) { Some(track) => Ok(Some(track.clone())), // Return any error if we couldn't find a track. @@ -60,7 +58,7 @@ impl State { } } - pub fn insert(&mut self, track: TrackSubscriber) -> Result<(), ServeError> { + pub fn insert(&mut self, track: TrackReader) -> Result<(), ServeError> { match self.tracks.entry(track.name.clone()) { hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate), hash_map::Entry::Vacant(v) => v.insert(track), @@ -69,18 +67,12 @@ impl State { Ok(()) } - pub fn remove(&mut self, name: &str) -> Option { + pub fn remove(&mut self, name: &str) -> Option { self.tracks.remove(name) } - - pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.closed.clone()?; - self.closed = Err(err); - Ok(()) - } } -impl Default for State { +impl Default for BroadcastState { fn default() -> Self { Self { tracks: HashMap::new(), @@ -90,40 +82,47 @@ impl Default for State { } /// Publish new tracks for a broadcast by name. -#[derive(Debug)] -pub struct BroadcastPublisher { - state: Watch, - info: Arc, +pub struct BroadcastWriter { + state: State, + pub info: Arc, } -impl BroadcastPublisher { - fn new(state: Watch, info: Arc) -> Self { - Self { state, info } +impl BroadcastWriter { + fn new(state: State, broadcast: Arc) -> Self { + Self { state, info: broadcast } } /// Create a new track with the given name, inserting it into the broadcast. 
- pub fn create_track(&mut self, track: &str) -> Result { - let (publisher, subscriber) = Track { + pub fn create_track(&mut self, track: &str) -> Result { + let (writer, reader) = Track { namespace: self.namespace.clone(), name: track.to_owned(), } .produce(); - self.state.lock_mut().insert(subscriber)?; - Ok(publisher) + self.state.lock_mut().ok_or(ServeError::Cancel)?.insert(reader)?; + + Ok(writer) } - pub fn remove_track(&mut self, track: &str) -> Option { - self.state.lock_mut().remove(track) + pub fn remove_track(&mut self, track: &str) -> Option { + self.state.lock_mut()?.remove(track) } /// Close the broadcast with an error. - pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.state.lock_mut().close(err) + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + if let Some(mut state) = state.into_mut() { + state.closed = Err(err); + } + + Ok(()) } } -impl Deref for BroadcastPublisher { +impl Deref for BroadcastWriter { type Target = Broadcast; fn deref(&self) -> &Self::Target { @@ -134,67 +133,27 @@ impl Deref for BroadcastPublisher { /// Subscribe to a broadcast by requesting tracks. /// /// This can be cloned to create handles. -#[derive(Clone, Debug)] -pub struct BroadcastSubscriber { - state: Watch, - info: Arc, - _dropped: Arc, +#[derive(Clone)] +pub struct BroadcastReader { + state: State, + pub info: Arc, } -impl BroadcastSubscriber { - fn new(state: Watch, info: Arc) -> Self { - let _dropped = Arc::new(Dropped::new(state.clone())); - Self { state, info, _dropped } +impl BroadcastReader { + fn new(state: State, broadcast: Arc) -> Self { + Self { state, info: broadcast } } /// Get a track from the broadcast by name. 
- pub fn get_track(&self, name: &str) -> Result, ServeError> { + pub fn get_track(&self, name: &str) -> Result, ServeError> { self.state.lock().get(name) } - - /// Wait until if the broadcast is closed, either because the publisher was dropped or called [Publisher::close]. - pub async fn closed(&self) -> ServeError { - loop { - let notify = { - let state = self.state.lock(); - if let Some(err) = state.closed.as_ref().err() { - return err.clone(); - } - - state.changed() - }; - - notify.await; - } - } } -impl Deref for BroadcastSubscriber { +impl Deref for BroadcastReader { type Target = Broadcast; fn deref(&self) -> &Self::Target { &self.info } } - -struct Dropped { - state: Watch, -} - -impl Dropped { - fn new(state: Watch) -> Self { - Self { state } - } -} - -impl Drop for Dropped { - fn drop(&mut self) { - self.state.lock_mut().close(ServeError::Done).ok(); - } -} - -impl fmt::Debug for Dropped { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Dropped").finish() - } -} diff --git a/moq-transport/src/serve/datagram.rs b/moq-transport/src/serve/datagram.rs index 8ec67208..67e03a1c 100644 --- a/moq-transport/src/serve/datagram.rs +++ b/moq-transport/src/serve/datagram.rs @@ -1,11 +1,124 @@ -use std::fmt; +use std::{fmt, sync::Arc}; + +use crate::util::State; + +use super::{ServeError, Track}; + +pub struct Datagrams { + pub track: Arc, +} + +impl Datagrams { + pub fn produce(self) -> (DatagramsWriter, DatagramsReader) { + let (writer, reader) = State::init(); + + let writer = DatagramsWriter::new(writer, self.track.clone()); + let reader = DatagramsReader::new(reader, self.track); + + (writer, reader) + } +} + +struct DatagramsState { + // The latest datagram + latest: Option, + + // Increased each time datagram changes. + epoch: u64, + + // Set when the writer or all readers are dropped. 
+ closed: Result<(), ServeError>, +} + +impl Default for DatagramsState { + fn default() -> Self { + Self { + latest: None, + epoch: 0, + closed: Ok(()), + } + } +} + +pub struct DatagramsWriter { + state: State, + pub track: Arc, +} + +impl DatagramsWriter { + fn new(state: State, track: Arc) -> Self { + Self { state, track } + } + + pub fn write(&mut self, datagram: Datagram) -> Result<(), ServeError> { + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + + state.latest = Some(datagram); + state.epoch += 1; + + Ok(()) + } + + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; + state.closed = Err(err); + + Ok(()) + } +} + +#[derive(Clone)] +pub struct DatagramsReader { + state: State, + pub track: Arc, + + epoch: u64, +} + +impl DatagramsReader { + fn new(state: State, track: Arc) -> Self { + Self { state, track, epoch: 0 } + } + + pub async fn read(&mut self) -> Result, ServeError> { + loop { + let notify = { + let state = self.state.lock(); + if self.epoch < state.epoch { + self.epoch = state.epoch; + return Ok(state.latest.clone()); + } + + state.closed.clone()?; + match state.modified() { + Some(notify) => notify, + None => return Ok(None), // No more updates will come + } + }; + + notify.await; + } + } + + // Returns the largest group/sequence + pub fn latest(&self) -> Option<(u64, u64)> { + let state = self.state.lock(); + state + .latest + .as_ref() + .map(|datagram| (datagram.group_id, datagram.object_id)) + } +} /// Static information about the datagram. 
#[derive(Clone)] pub struct Datagram { - pub object_id: u64, pub group_id: u64, - pub send_order: u64, + pub object_id: u64, + pub priority: u64, pub payload: bytes::Bytes, } @@ -14,7 +127,7 @@ impl fmt::Debug for Datagram { f.debug_struct("Datagram") .field("object_id", &self.object_id) .field("group_id", &self.group_id) - .field("send_order", &self.send_order) + .field("priority", &self.priority) .field("payload", &self.payload.len()) .finish() } diff --git a/moq-transport/src/serve/error.rs b/moq-transport/src/serve/error.rs index 1637650b..eebc066d 100644 --- a/moq-transport/src/serve/error.rs +++ b/moq-transport/src/serve/error.rs @@ -1,8 +1,12 @@ #[derive(thiserror::Error, Debug, Clone, PartialEq)] pub enum ServeError { + // TODO stop using? #[error("done")] Done, + #[error("cancelled")] + Cancel, + #[error("closed, code={0}")] Closed(u64), @@ -16,18 +20,23 @@ pub enum ServeError { Mode, #[error("wrong size")] - WrongSize, + Size, + + #[error("internal error: {0}")] + Internal(String), } impl ServeError { pub fn code(&self) -> u64 { match self { Self::Done => 0, + Self::Cancel => 1, Self::Closed(code) => *code, Self::NotFound => 404, Self::Duplicate => 409, Self::Mode => 400, - Self::WrongSize => 413, + Self::Size => 413, + Self::Internal(_) => 500, } } } diff --git a/moq-transport/src/serve/group.rs b/moq-transport/src/serve/group.rs index 8e607e9f..b58b2234 100644 --- a/moq-transport/src/serve/group.rs +++ b/moq-transport/src/serve/group.rs @@ -1,59 +1,233 @@ -//! A stream is a stream of objects with a header, split into a [Publisher] and [Subscriber] handle. +//! A stream is a stream of objects with a header, split into a [Writer] and [Reader] handle. //! -//! A [Publisher] writes an ordered stream of objects. -//! Each object can have a sequence number, allowing the subscriber to detect gaps objects. +//! A [Writer] writes an ordered stream of objects. +//! Each object can have a sequence number, allowing the reader to detect gaps objects. //! -//! 
A [Subscriber] reads an ordered stream of objects. -//! The subscriber can be cloned, in which case each subscriber receives a copy of each object. (fanout) +//! A [Reader] reads an ordered stream of objects. +//! The reader can be cloned, in which case each reader receives a copy of each object. (fanout) //! -//! The stream is closed with [ServeError::Closed] when all publishers or subscribers are dropped. -use std::{fmt, ops::Deref, sync::Arc}; +//! The stream is closed with [ServeError::Closed] when all writers or readers are dropped. +use bytes::Bytes; +use std::{cmp, ops::Deref, sync::Arc}; -use crate::util::Watch; +use crate::util::State; -use super::{ObjectHeader, ObjectPublisher, ObjectSubscriber, ServeError}; +use super::{ServeError, Track}; -/// Static information about the stream. -#[derive(Debug)] +pub struct Groups { + pub track: Arc, +} + +impl Groups { + pub fn produce(self) -> (GroupsWriter, GroupsReader) { + let (writer, reader) = State::init(); + + let writer = GroupsWriter::new(writer, self.track.clone()); + let reader = GroupsReader::new(reader, self.track); + + (writer, reader) + } +} + +impl Deref for Groups { + type Target = Track; + + fn deref(&self) -> &Self::Target { + &self.track + } +} + +// State shared between the writer and reader. +struct GroupsState { + latest: Option, + epoch: u64, // Updated each time latest changes + closed: Result<(), ServeError>, +} + +impl Default for GroupsState { + fn default() -> Self { + Self { + latest: None, + epoch: 0, + closed: Ok(()), + } + } +} + +pub struct GroupsWriter { + pub info: Arc, + state: State, + next: u64, // Not in the state to avoid a lock +} + +impl GroupsWriter { + fn new(state: State, track: Arc) -> Self { + Self { + info: track, + state, + next: 0, + } + } + + // Helper to increment the group by one. 
+ pub fn next(&mut self, priority: u64) -> Result { + self.create(Group { + group_id: self.next, + priority, + }) + } + + pub fn create(&mut self, group: Group) -> Result { + let group = GroupInfo { + track: self.info.clone(), + group_id: group.group_id, + priority: group.priority, + }; + let (writer, reader) = group.produce(); + + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + + if let Some(latest) = &state.latest { + match writer.group_id.cmp(&latest.group_id) { + cmp::Ordering::Less => return Ok(writer), // dropped immediately, lul + cmp::Ordering::Equal => return Err(ServeError::Duplicate), + cmp::Ordering::Greater => state.latest = Some(reader), + } + } else { + state.latest = Some(reader); + } + + self.next = state.latest.as_ref().unwrap().group_id + 1; + state.epoch += 1; + + Ok(writer) + } + + /// Close the segment with an error. + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; + state.closed = Err(err); + + Ok(()) + } +} + +impl Deref for GroupsWriter { + type Target = Track; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +#[derive(Clone)] +pub struct GroupsReader { + pub info: Arc, + state: State, + epoch: u64, +} + +impl GroupsReader { + fn new(state: State, track: Arc) -> Self { + Self { + info: track, + state, + epoch: 0, + } + } + + pub async fn next(&mut self) -> Result, ServeError> { + loop { + let notify = { + let state = self.state.lock(); + + if self.epoch != state.epoch { + self.epoch = state.epoch; + return Ok(state.latest.clone()); + } + + state.closed.clone()?; + match state.modified() { + Some(notify) => notify, + None => return Ok(None), + } + }; + + notify.await; // Try again when the state changes + } + } + + // Returns the largest group/sequence + pub fn latest(&self) -> Option<(u64, u64)> { + let state = self.state.lock(); + state.latest.as_ref().map(|group| 
(group.group_id, group.latest())) + } +} + +impl Deref for GroupsReader { + type Target = Track; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +/// Parameters that can be specified by the user +#[derive(Debug, Clone, PartialEq)] pub struct Group { - // The sequence number of the stream within the track. + // The sequence number of the group within the track. // NOTE: These may be received out of order or with gaps. - pub id: u64, + pub group_id: u64, - // The priority of the stream within the BROADCAST. - pub send_order: u64, + // The priority of the group within the track. + pub priority: u64, } -impl Group { - pub fn produce(self) -> (GroupPublisher, GroupSubscriber) { - let state = Watch::new(State::default()); +/// Static information about the group +#[derive(Debug, Clone, PartialEq)] +pub struct GroupInfo { + pub track: Arc, + + // The sequence number of the group within the track. + // NOTE: These may be received out of order or with gaps. + pub group_id: u64, + + // The priority of the group within the track. + pub priority: u64, +} + +impl GroupInfo { + pub fn produce(self) -> (GroupWriter, GroupReader) { + let (writer, reader) = State::init(); let info = Arc::new(self); - let publisher = GroupPublisher::new(state.clone(), info.clone()); - let subscriber = GroupSubscriber::new(state, info); + let writer = GroupWriter::new(writer, info.clone()); + let reader = GroupReader::new(reader, info); - (publisher, subscriber) + (writer, reader) } } -#[derive(Debug)] -struct State { - // The data that has been received thus far. - objects: Vec, +impl Deref for GroupInfo { + type Target = Track; - // Set when the publisher is dropped. - closed: Result<(), ServeError>, + fn deref(&self) -> &Self::Target { + &self.track + } } -impl State { - pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.closed.clone()?; - self.closed = Err(err); - Ok(()) - } +struct GroupState { + // The data that has been received thus far. 
+ objects: Vec, + + // Set when the writer or all readers are dropped. + closed: Result<(), ServeError>, } -impl Default for State { +impl Default for GroupState { fn default() -> Self { Self { objects: Vec::new(), @@ -62,59 +236,66 @@ impl Default for State { } } -/// Used to write data to a stream and notify subscribers. -#[derive(Debug)] -pub struct GroupPublisher { +/// Used to write data to a stream and notify readers. +pub struct GroupWriter { // Mutable stream state. - state: Watch, + state: State, // Immutable stream state. - info: Arc, + pub info: Arc, // The next object sequence number to use. next: u64, } -impl GroupPublisher { - fn new(state: Watch, info: Arc) -> Self { - Self { state, info, next: 0 } +impl GroupWriter { + fn new(state: State, group: Arc) -> Self { + Self { + state, + info: group, + next: 0, + } } /// Create the next object ID with the given payload. - pub fn write_object(&mut self, payload: bytes::Bytes) -> Result<(), ServeError> { - let mut object = self.create_object(payload.len())?; + pub fn write(&mut self, payload: bytes::Bytes) -> Result<(), ServeError> { + let mut object = self.create(payload.len())?; object.write(payload)?; Ok(()) } /// Write an object over multiple writes. /// - /// BAD STUFF will happen if the size is wrong. - pub fn create_object(&mut self, size: usize) -> Result { - let (publisher, subscriber) = ObjectHeader { - group_id: self.id, + /// BAD STUFF will happen if the size is wrong; this is an advanced feature. + pub fn create(&mut self, size: usize) -> Result { + let (writer, reader) = GroupObject { + group: self.info.clone(), object_id: self.next, - send_order: self.send_order, size, } .produce(); self.next += 1; - let mut state = self.state.lock_mut(); - state.closed.clone()?; - state.objects.push(subscriber); - Ok(publisher) + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + state.objects.push(reader); + + Ok(writer) } /// Close the stream with an error. 
- pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.state.lock_mut().close(err) + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; + state.closed = Err(err); + Ok(()) } } -impl Deref for GroupPublisher { - type Target = Group; +impl Deref for GroupWriter { + type Target = GroupInfo; fn deref(&self) -> &Self::Target { &self.info @@ -122,38 +303,42 @@ impl Deref for GroupPublisher { } /// Notified when a stream has new data available. -#[derive(Clone, Debug)] -pub struct GroupSubscriber { +#[derive(Clone)] +pub struct GroupReader { // Modify the stream state. - state: Watch, + state: State, // Immutable stream state. - info: Arc, + pub info: Arc, // The number of chunks that we've read. - // NOTE: Cloned subscribers inherit this index, but then run in parallel. + // NOTE: Cloned readers inherit this index, but then run in parallel. index: usize, - - _dropped: Arc, } -impl GroupSubscriber { - fn new(state: Watch, info: Arc) -> Self { - let _dropped = Arc::new(Dropped::new(state.clone())); +impl GroupReader { + fn new(state: State, group: Arc) -> Self { Self { state, - info, + info: group, index: 0, - _dropped, } } pub fn latest(&self) -> u64 { - self.state.lock().objects.len() as u64 + let state = self.state.lock(); + state.objects.last().map(|o| o.object_id).unwrap_or_default() } - /// Block until the next object is available. 
- pub async fn next(&mut self) -> Result, ServeError> { + pub async fn read_next(&mut self) -> Result, ServeError> { + let object = self.next().await?; + match object { + Some(mut object) => Ok(Some(object.read_all().await?)), + None => Ok(None), + } + } + + pub async fn next(&mut self) -> Result, ServeError> { loop { let notify = { let state = self.state.lock(); @@ -164,10 +349,10 @@ impl GroupSubscriber { return Ok(Some(object)); } - match &state.closed { - Ok(()) => state.changed(), - Err(ServeError::Done) => return Ok(None), - Err(err) => return Err(err.clone()), + state.closed.clone()?; + match state.modified() { + Some(notify) => notify, + None => return Ok(None), } }; @@ -176,8 +361,8 @@ impl GroupSubscriber { } } -impl Deref for GroupSubscriber { - type Target = Group; +impl Deref for GroupReader { + type Target = GroupInfo; fn deref(&self) -> &Self::Target { &self.info @@ -185,33 +370,184 @@ impl Deref for GroupSubscriber { } /// A subset of Object, since we use the group's info. -#[derive(Debug)] +#[derive(Clone, PartialEq, Debug)] pub struct GroupObject { - // The sequence number of the object within the group. + pub group: Arc, + pub object_id: u64, // The size of the object. pub size: usize, } -struct Dropped { - state: Watch, +impl GroupObject { + pub fn produce(self) -> (GroupObjectWriter, GroupObjectReader) { + let (writer, reader) = State::init(); + let info = Arc::new(self); + + let writer = GroupObjectWriter::new(writer, info.clone()); + let reader = GroupObjectReader::new(reader, info); + + (writer, reader) + } +} + +impl Deref for GroupObject { + type Target = GroupInfo; + + fn deref(&self) -> &Self::Target { + &self.group + } +} + +struct GroupObjectState { + // The data that has been received thus far. + chunks: Vec, + + // Set when the writer is dropped. 
+ closed: Result<(), ServeError>, } -impl Dropped { - fn new(state: Watch) -> Self { - Self { state } +impl Default for GroupObjectState { + fn default() -> Self { + Self { + chunks: Vec::new(), + closed: Ok(()), + } } } -impl Drop for Dropped { +/// Used to write data to a segment and notify readers. +pub struct GroupObjectWriter { + // Mutable segment state. + state: State, + + // Immutable segment state. + pub info: Arc, + + // The amount of promised data that has yet to be written. + remain: usize, +} + +impl GroupObjectWriter { + /// Create a new segment with the given info. + fn new(state: State, object: Arc) -> Self { + Self { + state, + remain: object.size, + info: object, + } + } + + /// Write a new chunk of bytes. + pub fn write(&mut self, chunk: Bytes) -> Result<(), ServeError> { + if chunk.len() > self.remain { + return Err(ServeError::Size); + } + self.remain -= chunk.len(); + + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + state.chunks.push(chunk); + + Ok(()) + } + + /// Close the segment with an error. + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + if self.remain != 0 { + return Err(ServeError::Size); + } + + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; + state.closed = Err(err); + + Ok(()) + } +} + +impl Drop for GroupObjectWriter { fn drop(&mut self) { - self.state.lock_mut().close(ServeError::Done).ok(); + if self.remain == 0 { + return; + } + + if let Some(mut state) = self.state.lock_mut() { + state.closed = Err(ServeError::Size); + } + } +} + +impl Deref for GroupObjectWriter { + type Target = GroupObject; + + fn deref(&self) -> &Self::Target { + &self.info } } -impl fmt::Debug for Dropped { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Dropped").finish() +/// Notified when a segment has new data available. +#[derive(Clone)] +pub struct GroupObjectReader { + // Modify the segment state. 
+ state: State, + + // Immutable segment state. + pub info: Arc, + + // The number of chunks that we've read. + // NOTE: Cloned readers inherit this index, but then run in parallel. + index: usize, +} + +impl GroupObjectReader { + fn new(state: State, object: Arc) -> Self { + Self { + state, + info: object, + index: 0, + } + } + + /// Block until the next chunk of bytes is available. + pub async fn read(&mut self) -> Result, ServeError> { + loop { + let notify = { + let state = self.state.lock(); + + if self.index < state.chunks.len() { + let chunk = state.chunks[self.index].clone(); + self.index += 1; + return Ok(Some(chunk)); + } + + state.closed.clone()?; + match state.modified() { + Some(notify) => notify, + None => return Ok(None), // No more changes will come + } + }; + + notify.await; // Try again when the state changes + } + } + + pub async fn read_all(&mut self) -> Result { + let mut chunks = Vec::new(); + while let Some(chunk) = self.read().await? { + chunks.push(chunk); + } + + Ok(Bytes::from(chunks.concat())) + } +} + +impl Deref for GroupObjectReader { + type Target = GroupObject; + + fn deref(&self) -> &Self::Target { + &self.info } } diff --git a/moq-transport/src/serve/object.rs b/moq-transport/src/serve/object.rs index bc380e33..733c71dc 100644 --- a/moq-transport/src/serve/object.rs +++ b/moq-transport/src/serve/object.rs @@ -1,22 +1,173 @@ -//! A fragment is a stream of bytes with a header, split into a [Publisher] and [Subscriber] handle. +//! A fragment is a stream of bytes with a header, split into a [Writer] and [Reader] handle. //! -//! A [Publisher] writes an ordered stream of bytes in chunks. +//! A [Writer] writes an ordered stream of bytes in chunks. //! There's no framing, so these chunks can be of any size or position, and won't be maintained over the network. //! -//! A [Subscriber] reads an ordered stream of bytes in chunks. +//! A [Reader] reads an ordered stream of bytes in chunks. //! 
These chunks are returned directly from the QUIC connection, so they may be of any size or position. -//! You can clone the [Subscriber] and each will read a copy of of all future chunks. (fanout) +//! You can clone the [Reader] and each will read a copy of of all future chunks. (fanout) //! -//! The fragment is closed with [ServeError::Closed] when all publishers or subscribers are dropped. -use std::{fmt, ops::Deref, sync::Arc}; +//! The fragment is closed with [ServeError::Closed] when all writers or readers are dropped. +use std::{cmp, ops::Deref, sync::Arc}; -use super::ServeError; -use crate::util::Watch; +use super::{ServeError, Track}; +use crate::util::State; use bytes::Bytes; +pub struct Objects { + pub track: Arc, +} + +impl Objects { + pub fn produce(self) -> (ObjectsWriter, ObjectsReader) { + let (writer, reader) = State::init(); + + let writer = ObjectsWriter { + state: writer, + track: self.track.clone(), + }; + let reader = ObjectsReader::new(reader, self.track); + + (writer, reader) + } +} + +struct ObjectsState { + // The latest group. + objects: Vec, + + // Increased each time objects changes. + epoch: usize, + + // Can be sent by the writer with an explicit error code. 
+ closed: Result<(), ServeError>, +} + +impl Default for ObjectsState { + fn default() -> Self { + Self { + objects: Vec::new(), + epoch: 0, + closed: Ok(()), + } + } +} + +pub struct ObjectsWriter { + state: State, + pub track: Arc, +} + +impl ObjectsWriter { + pub fn write(&mut self, object: Object, payload: Bytes) -> Result<(), ServeError> { + let mut writer = self.create(object)?; + writer.write(payload)?; + Ok(()) + } + + pub fn create(&mut self, object: Object) -> Result { + let object = ObjectInfo { + track: self.track.clone(), + group_id: object.group_id, + object_id: object.object_id, + priority: object.priority, + }; + + let (writer, reader) = object.produce(); + + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + + if let Some(first) = state.objects.first() { + match writer.group_id.cmp(&first.group_id) { + // Drop this old group + cmp::Ordering::Less => return Ok(writer), + cmp::Ordering::Greater => state.objects.clear(), + cmp::Ordering::Equal => {} + } + } + + state.objects.push(reader); + state.epoch += 1; + + Ok(writer) + } + + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; + state.closed = Err(err); + + Ok(()) + } +} + +impl Deref for ObjectsWriter { + type Target = Track; + + fn deref(&self) -> &Self::Target { + &self.track + } +} + +#[derive(Clone)] +pub struct ObjectsReader { + state: State, + pub track: Arc, + epoch: usize, +} + +impl ObjectsReader { + fn new(state: State, track: Arc) -> Self { + Self { state, track, epoch: 0 } + } + + pub async fn next(&mut self) -> Result, ServeError> { + loop { + let notify = { + let state = self.state.lock(); + if self.epoch < state.epoch { + let index = state.objects.len().saturating_sub(state.epoch - self.epoch); + self.epoch = state.epoch - state.objects.len() + index + 1; + return Ok(Some(state.objects[index].clone())); + } + + 
state.closed.clone()?; + match state.modified() { + Some(notify) => notify, + None => return Ok(None), // No more updates will come + } + }; + + notify.await; + } + } + + // Returns the largest group/sequence + pub fn latest(&self) -> Option<(u64, u64)> { + let state = self.state.lock(); + state + .objects + .iter() + .max_by_key(|a| (a.group_id, a.object_id)) + .map(|a| (a.group_id, a.object_id)) + } +} + +impl Deref for ObjectsReader { + type Target = Track; + + fn deref(&self) -> &Self::Target { + &self.track + } +} + /// Static information about the segment. -#[derive(Clone, Debug)] -pub struct ObjectHeader { +#[derive(Clone, PartialEq, Debug)] +pub struct ObjectInfo { + pub track: Arc, + // The sequence number of the group within the track. pub group_id: u64, @@ -24,26 +175,29 @@ pub struct ObjectHeader { pub object_id: u64, // The priority of the stream. - pub send_order: u64, + pub priority: u64, +} + +impl Deref for ObjectInfo { + type Target = Track; - // The size of the object - pub size: usize, + fn deref(&self) -> &Self::Target { + &self.track + } } -impl ObjectHeader { - pub fn produce(self) -> (ObjectPublisher, ObjectSubscriber) { - let state = Watch::new(State::default()); +impl ObjectInfo { + pub fn produce(self) -> (ObjectWriter, ObjectReader) { + let (writer, reader) = State::init(); let info = Arc::new(self); - let publisher = ObjectPublisher::new(state.clone(), info.clone()); - let subscriber = ObjectSubscriber::new(state, info); + let writer = ObjectWriter::new(writer, info.clone()); + let reader = ObjectReader::new(reader, info); - (publisher, subscriber) + (writer, reader) } } -/// Same as below but with a fully known payload. -#[derive(Clone)] pub struct Object { // The sequence number of the group within the track. pub group_id: u64, @@ -52,61 +206,18 @@ pub struct Object { pub object_id: u64, // The priority of the stream. - pub send_order: u64, - - // The payload. 
- pub payload: Bytes, -} - -impl From for ObjectHeader { - fn from(info: Object) -> Self { - Self { - group_id: info.group_id, - object_id: info.object_id, - send_order: info.send_order, - size: info.payload.len(), - } - } -} - -impl fmt::Debug for Object { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Object") - .field("group_id", &self.group_id) - .field("object_id", &self.object_id) - .field("send_order", &self.send_order) - .field("payload", &self.payload.len()) - .finish() - } + pub priority: u64, } -struct State { +struct ObjectState { // The data that has been received thus far. chunks: Vec, - // Set when the publisher is dropped. + // Set when the writer is dropped. closed: Result<(), ServeError>, } -impl State { - pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.closed.clone()?; - self.closed = Err(err); - Ok(()) - } -} - -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("State") - .field("chunks", &self.chunks.len()) - .field("size", &self.chunks.iter().map(|c| c.len()).sum::()) - .field("closed", &self.closed) - .finish() - } -} - -impl Default for State { +impl Default for ObjectState { fn default() -> Self { Self { chunks: Vec::new(), @@ -115,62 +226,43 @@ impl Default for State { } } -/// Used to write data to a segment and notify subscribers. -#[derive(Debug)] -pub struct ObjectPublisher { +/// Used to write data to a segment and notify readers. +pub struct ObjectWriter { // Mutable segment state. - state: Watch, + state: State, // Immutable segment state. - info: Arc, - - // The amount of promised data that has yet to be written. - remain: usize, + pub info: Arc, } -impl ObjectPublisher { +impl ObjectWriter { /// Create a new segment with the given info. 
- fn new(state: Watch, info: Arc) -> Self { - Self { - state, - remain: info.size, - info, - } + fn new(state: State, object: Arc) -> Self { + Self { state, info: object } } /// Write a new chunk of bytes. pub fn write(&mut self, chunk: Bytes) -> Result<(), ServeError> { - if chunk.len() > self.remain { - return Err(ServeError::WrongSize); - } - self.remain -= chunk.len(); - - let mut state = self.state.lock_mut(); - state.closed.clone()?; + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; state.chunks.push(chunk); Ok(()) } /// Close the segment with an error. - pub fn close(&mut self, mut err: ServeError) -> Result<(), ServeError> { - if err == ServeError::Done && self.remain != 0 { - err = ServeError::WrongSize; - } + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; - self.state.lock_mut().close(err)?; - Ok(()) - } -} + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; + state.closed = Err(err); -impl Drop for ObjectPublisher { - fn drop(&mut self) { - self.close(ServeError::Done).ok(); + Ok(()) } } -impl Deref for ObjectPublisher { - type Target = ObjectHeader; +impl Deref for ObjectWriter { + type Target = ObjectInfo; fn deref(&self) -> &Self::Target { &self.info @@ -178,29 +270,25 @@ impl Deref for ObjectPublisher { } /// Notified when a segment has new data available. -#[derive(Clone, Debug)] -pub struct ObjectSubscriber { +#[derive(Clone)] +pub struct ObjectReader { // Modify the segment state. - state: Watch, + state: State, // Immutable segment state. - info: Arc, + pub info: Arc, // The number of chunks that we've read. - // NOTE: Cloned subscribers inherit this index, but then run in parallel. + // NOTE: Cloned readers inherit this index, but then run in parallel. 
index: usize, - - _dropped: Arc, } -impl ObjectSubscriber { - fn new(state: Watch, info: Arc) -> Self { - let _dropped = Arc::new(Dropped::new(state.clone())); +impl ObjectReader { + fn new(state: State, object: Arc) -> Self { Self { state, - info, + info: object, index: 0, - _dropped, } } @@ -216,10 +304,10 @@ impl ObjectSubscriber { return Ok(Some(chunk)); } - match &state.closed { - Err(ServeError::Done) => return Ok(None), - Err(err) => return Err(err.clone()), - Ok(()) => state.changed(), + state.closed.clone()?; + match state.modified() { + Some(notify) => notify, + None => return Ok(None), // No more updates will come } }; @@ -237,32 +325,10 @@ impl ObjectSubscriber { } } -impl Deref for ObjectSubscriber { - type Target = ObjectHeader; +impl Deref for ObjectReader { + type Target = ObjectInfo; fn deref(&self) -> &Self::Target { &self.info } } - -struct Dropped { - state: Watch, -} - -impl Dropped { - fn new(state: Watch) -> Self { - Self { state } - } -} - -impl Drop for Dropped { - fn drop(&mut self) { - self.state.lock_mut().close(ServeError::Done).ok(); - } -} - -impl fmt::Debug for Dropped { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Dropped").finish() - } -} diff --git a/moq-transport/src/serve/stream.rs b/moq-transport/src/serve/stream.rs index 3c2fdaa5..1c28c707 100644 --- a/moq-transport/src/serve/stream.rs +++ b/moq-transport/src/serve/stream.rs @@ -1,88 +1,123 @@ -use std::{fmt, ops::Deref, sync::Arc}; +use bytes::Bytes; +use std::{ops::Deref, sync::Arc}; -use crate::util::Watch; +use crate::util::State; -use super::ServeError; +use super::{ServeError, Track}; -#[derive(Debug)] +#[derive(Debug, PartialEq, Clone)] pub struct Stream { - pub namespace: String, - pub name: String, - pub send_order: u64, + pub track: Arc, + pub priority: u64, } impl Stream { - pub fn produce(self) -> (StreamPublisher, StreamSubscriber) { - let state = Watch::new(State::default()); + pub fn produce(self) -> (StreamWriter, StreamReader) 
{ + let (writer, reader) = State::init(); let info = Arc::new(self); - let publisher = StreamPublisher::new(state.clone(), info.clone()); - let subscriber = StreamSubscriber::new(state, info); + let writer = StreamWriter::new(writer, info.clone()); + let reader = StreamReader::new(reader, info); - (publisher, subscriber) + (writer, reader) } } -#[derive(Debug)] -struct State { - // The data that has been received thus far. - objects: Vec, +impl Deref for Stream { + type Target = Track; - // Set when the publisher is dropped. - closed: Result<(), ServeError>, + fn deref(&self) -> &Self::Target { + &self.track + } } -impl State { - pub fn insert_object(&mut self, object: StreamObject) -> Result<(), ServeError> { - self.closed.clone()?; - self.objects.push(object); - Ok(()) - } +struct StreamState { + // The latest group. + latest: Option, - pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.closed.clone()?; - self.closed = Err(err); - Ok(()) - } + // Updated each time objects changes. + epoch: usize, + + // Set when the writer is dropped. + closed: Result<(), ServeError>, } -impl Default for State { +impl Default for StreamState { fn default() -> Self { Self { - objects: Vec::new(), + latest: None, + epoch: 0, closed: Ok(()), } } } -/// Used to write data to a stream and notify subscribers. -#[derive(Debug)] -pub struct StreamPublisher { +/// Used to write data to a stream and notify readers. +/// +/// This is Clone as a work-around, but be very careful because it's meant to be sequential. +#[derive(Clone)] +pub struct StreamWriter { // Mutable stream state. - state: Watch, + state: State, // Immutable stream state. - info: Arc, + pub info: Arc, } -impl StreamPublisher { - fn new(state: Watch, info: Arc) -> Self { +impl StreamWriter { + fn new(state: State, info: Arc) -> Self { Self { state, info } } - /// Create an object with the given info and payload. 
- pub fn write_object(&mut self, info: StreamObject) -> Result<(), ServeError> { - self.state.lock_mut().insert_object(info)?; - Ok(()) + pub fn create(&mut self, group_id: u64) -> Result { + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + + if let Some(latest) = &state.latest { + if latest.group_id > group_id { + return Err(ServeError::Duplicate); + } + } + + let group = Arc::new(StreamGroup { + stream: self.info.clone(), + group_id, + }); + + let (writer, reader) = State::init(); + + let reader = StreamGroupReader::new(reader, group.clone()); + let writer = StreamGroupWriter::new(writer, group); + + state.latest = Some(reader); + state.epoch += 1; + + Ok(writer) + } + + pub fn append(&mut self) -> Result { + let next = self + .state + .lock() + .latest + .as_ref() + .map(|g| g.group_id + 1) + .unwrap_or_default(); + self.create(next) } /// Close the stream with an error. - pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.state.lock_mut().close(err) + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; + state.closed = Err(err); + + Ok(()) } } -impl Deref for StreamPublisher { +impl Deref for StreamWriter { type Target = Stream; fn deref(&self) -> &Self::Target { @@ -91,47 +126,39 @@ impl Deref for StreamPublisher { } /// Notified when a stream has new data available. -#[derive(Clone, Debug)] -pub struct StreamSubscriber { +#[derive(Clone)] +pub struct StreamReader { // Modify the stream state. - state: Watch, + state: State, // Immutable stream state. - info: Arc, + pub info: Arc, // The number of chunks that we've read. - // NOTE: Cloned subscribers inherit this index, but then run in parallel. - index: usize, - - _dropped: Arc, + // NOTE: Cloned readers inherit this index, but then run in parallel. 
+ epoch: usize, } -impl StreamSubscriber { - fn new(state: Watch, info: Arc) -> Self { - let _dropped = Arc::new(Dropped::new(state.clone())); - Self { - state, - info, - index: 0, - _dropped, - } +impl StreamReader { + fn new(state: State, info: Arc) -> Self { + Self { state, info, epoch: 0 } } - /// Block until the next object is available. - pub async fn next(&mut self) -> Result, ServeError> { + /// Block until the next group is available. + pub async fn next(&mut self) -> Result, ServeError> { loop { let notify = { let state = self.state.lock(); - if self.index < state.objects.len() { - let object = state.objects[self.index].clone(); - self.index += 1; - return Ok(Some(object)); + if self.epoch != state.epoch { + self.epoch = state.epoch; + let latest = state.latest.clone().unwrap(); + return Ok(Some(latest)); } - match &state.closed { - Ok(()) => state.changed(), - Err(ServeError::Done) => return Ok(None), - Err(err) => return Err(err.clone()), + state.closed.clone()?; + match state.modified() { + Some(notify) => notify, + None => return Ok(None), } }; @@ -139,17 +166,14 @@ impl StreamSubscriber { } } + // Returns the largest group/sequence pub fn latest(&self) -> Option<(u64, u64)> { - self.state - .lock() - .objects - .iter() - .max_by_key(|a| (a.group_id, a.object_id)) - .map(|a| (a.group_id, a.object_id)) + let state = self.state.lock(); + state.latest.as_ref().map(|group| (group.group_id, group.latest())) } } -impl Deref for StreamSubscriber { +impl Deref for StreamReader { type Target = Stream; fn deref(&self) -> &Self::Target { @@ -157,46 +181,323 @@ impl Deref for StreamSubscriber { } } -struct Dropped { - state: Watch, +#[derive(Clone, PartialEq, Debug)] +pub struct StreamGroup { + pub stream: Arc, + pub group_id: u64, } -impl Dropped { - fn new(state: Watch) -> Self { - Self { state } +impl Deref for StreamGroup { + type Target = Stream; + + fn deref(&self) -> &Self::Target { + &self.stream } } -impl Drop for Dropped { - fn drop(&mut self) { - 
self.state.lock_mut().close(ServeError::Done).ok(); +struct StreamGroupState { + // The objects that have been received thus far. + objects: Vec, + closed: Result<(), ServeError>, +} + +impl Default for StreamGroupState { + fn default() -> Self { + Self { + objects: Vec::new(), + closed: Ok(()), + } } } -impl fmt::Debug for Dropped { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Dropped").finish() +pub struct StreamGroupWriter { + state: State, + pub group: Arc, + next: u64, +} + +impl StreamGroupWriter { + fn new(state: State, group: Arc) -> Self { + Self { state, group, next: 0 } + } + + /// Add a new object to the group. + pub fn write(&mut self, payload: Bytes) -> Result<(), ServeError> { + let mut writer = self.create(payload.len())?; + writer.write(payload)?; + Ok(()) + } + + pub fn create(&mut self, size: usize) -> Result { + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + + let (writer, reader) = StreamObject { + group: self.group.clone(), + object_id: self.next, + size, + } + .produce(); + + state.objects.push(reader); + + Ok(writer) + } + + /// Close the stream with an error. + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; + state.closed = Err(err); + + Ok(()) + } +} + +impl Deref for StreamGroupWriter { + type Target = StreamGroup; + + fn deref(&self) -> &Self::Target { + &self.group } } #[derive(Clone)] +pub struct StreamGroupReader { + pub group: Arc, + state: State, + index: usize, +} + +impl StreamGroupReader { + fn new(state: State, group: Arc) -> Self { + Self { state, group, index: 0 } + } + + pub async fn read_next(&mut self) -> Result, ServeError> { + if let Some(mut reader) = self.next().await? 
{ + Ok(Some(reader.read_all().await?)) + } else { + Ok(None) + } + } + + pub async fn next(&mut self) -> Result, ServeError> { + loop { + let notify = { + let state = self.state.lock(); + if self.index < state.objects.len() { + self.index += 1; + return Ok(Some(state.objects[self.index].clone())); + } + + state.closed.clone()?; + match state.modified() { + Some(notify) => notify, + None => return Ok(None), + } + }; + + notify.await + } + } + + pub fn latest(&self) -> u64 { + let state = self.state.lock(); + state.objects.last().map(|o| o.object_id).unwrap_or_default() + } +} + +impl Deref for StreamGroupReader { + type Target = StreamGroup; + + fn deref(&self) -> &Self::Target { + &self.group + } +} + +/// A subset of Object, since we use the group's info. +#[derive(Clone, PartialEq, Debug)] pub struct StreamObject { - // The sequence number of the group within the track. - pub group_id: u64, + // The group this belongs to. + pub group: Arc, - // The sequence number of the object within the group. pub object_id: u64, - // The payload. - pub payload: bytes::Bytes, + // The size of the object. + pub size: usize, +} + +impl StreamObject { + pub fn produce(self) -> (StreamObjectWriter, StreamObjectReader) { + let (writer, reader) = State::init(); + let info = Arc::new(self); + + let writer = StreamObjectWriter::new(writer, info.clone()); + let reader = StreamObjectReader::new(reader, info); + + (writer, reader) + } +} + +impl Deref for StreamObject { + type Target = StreamGroup; + + fn deref(&self) -> &Self::Target { + &self.group + } +} + +struct StreamObjectState { + // The data that has been received thus far. + chunks: Vec, + + closed: Result<(), ServeError>, +} + +impl Default for StreamObjectState { + fn default() -> Self { + Self { + chunks: Vec::new(), + closed: Ok(()), + } + } +} + +/// Used to write data to a segment and notify readers. +pub struct StreamObjectWriter { + // Mutable segment state. + state: State, + + // Immutable segment state. 
+ pub object: Arc, + + // The amount of promised data that has yet to be written. + remain: usize, +} + +impl StreamObjectWriter { + /// Create a new segment with the given info. + fn new(state: State, object: Arc) -> Self { + Self { + state, + remain: object.size, + object, + } + } + + /// Write a new chunk of bytes. + pub fn write(&mut self, chunk: Bytes) -> Result<(), ServeError> { + if chunk.len() > self.remain { + return Err(ServeError::Size); + } + self.remain -= chunk.len(); + + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + state.chunks.push(chunk); + + Ok(()) + } + + /// Close the stream with an error. + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; + state.closed = Err(err); + + Ok(()) + } } -impl fmt::Debug for StreamObject { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("StreamObject") - .field("group_id", &self.group_id) - .field("object_id", &self.object_id) - .field("payload", &self.payload.len()) - .finish() +impl Drop for StreamObjectWriter { + // Make sure we fully write the segment, otherwise close it with an error. + fn drop(&mut self) { + if self.remain == 0 { + return; + } + + let state = self.state.lock(); + if state.closed.is_err() { + return; + } + + if let Some(mut state) = state.into_mut() { + state.closed = Err(ServeError::Size); + } + } +} + +impl Deref for StreamObjectWriter { + type Target = StreamObject; + + fn deref(&self) -> &Self::Target { + &self.object + } +} + +/// Notified when a segment has new data available. +#[derive(Clone)] +pub struct StreamObjectReader { + // Modify the segment state. + state: State, + + // Immutable segment state. + pub object: Arc, + + // The number of chunks that we've read. + // NOTE: Cloned readers inherit this index, but then run in parallel. 
+ index: usize, +} + +impl StreamObjectReader { + fn new(state: State, object: Arc) -> Self { + Self { + state, + object, + index: 0, + } + } + + /// Block until the next chunk of bytes is available. + pub async fn read(&mut self) -> Result, ServeError> { + loop { + let notify = { + let state = self.state.lock(); + + if self.index < state.chunks.len() { + let chunk = state.chunks[self.index].clone(); + self.index += 1; + return Ok(Some(chunk)); + } + + state.closed.clone()?; + match state.modified() { + Some(notify) => notify, + None => return Ok(None), + } + }; + + notify.await; // Try again when the state changes + } + } + + pub async fn read_all(&mut self) -> Result { + let mut chunks = Vec::new(); + while let Some(chunk) = self.read().await? { + chunks.push(chunk); + } + + Ok(Bytes::from(chunks.concat())) + } +} + +impl Deref for StreamObjectReader { + type Target = StreamObject; + + fn deref(&self) -> &Self::Target { + &self.object } } diff --git a/moq-transport/src/serve/track.rs b/moq-transport/src/serve/track.rs index a7425f76..12cf36b4 100644 --- a/moq-transport/src/serve/track.rs +++ b/moq-transport/src/serve/track.rs @@ -1,27 +1,28 @@ -//! A track is a collection of semi-reliable and semi-ordered streams, split into a [Publisher] and [Subscriber] handle. +//! A track is a collection of semi-reliable and semi-ordered streams, split into a [Writer] and [Reader] handle. //! -//! A [Publisher] creates streams with a sequence number and priority. +//! A [Writer] creates streams with a sequence number and priority. //! The sequest number is used to determine the order of streams, while the priority is used to determine which stream to transmit first. //! This may seem counter-intuitive, but is designed for live streaming where the newest streams may be higher priority. -//! A cloned [Publisher] can be used to create streams in parallel, but will error if a duplicate sequence number is used. +//! 
A cloned [Writer] can be used to create streams in parallel, but will error if a duplicate sequence number is used. //! -//! A [Subscriber] may not receive all streams in order or at all. +//! A [Reader] may not receive all streams in order or at all. //! These streams are meant to be transmitted over congested networks and the key to MoQ Tranport is to not block on them. //! streams will be cached for a potentially limited duration added to the unreliable nature. -//! A cloned [Subscriber] will receive a copy of all new stream going forward (fanout). +//! A cloned [Reader] will receive a copy of all new stream going forward (fanout). //! -//! The track is closed with [ServeError::Closed] when all publishers or subscribers are dropped. +//! The track is closed with [ServeError::Closed] when all writers or readers are dropped. -use crate::util::Watch; +use crate::util::State; use super::{ - Datagram, Group, GroupPublisher, GroupSubscriber, Object, ObjectHeader, ObjectPublisher, ObjectSubscriber, - ServeError, Stream, StreamPublisher, StreamSubscriber, + Datagrams, DatagramsReader, DatagramsWriter, Groups, GroupsReader, GroupsWriter, Objects, ObjectsReader, + ObjectsWriter, ServeError, Stream, StreamReader, StreamWriter, }; -use std::{cmp, fmt, ops::Deref, sync::Arc}; +use paste::paste; +use std::{ops::Deref, sync::Arc}; /// Static information about a track. 
-#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub struct Track { pub namespace: String, pub name: String, @@ -35,202 +36,100 @@ impl Track { } } - pub fn produce(self) -> (TrackPublisher, TrackSubscriber) { - let state = Watch::new(State::default()); + pub fn produce(self) -> (TrackWriter, TrackReader) { + let (writer, reader) = State::init(); let info = Arc::new(self); - let publisher = TrackPublisher::new(state.clone(), info.clone()); - let subscriber = TrackSubscriber::new(state, info); + let writer = TrackWriter::new(writer, info.clone()); + let reader = TrackReader::new(reader, info); - (publisher, subscriber) + (writer, reader) } } -// The state of the cache, depending on the mode> -#[derive(Debug)] -enum Mode { - Init, - Stream(StreamSubscriber), - Group(GroupSubscriber), - Object(Vec), - Datagram(Datagram), -} - -#[derive(Debug)] -struct State { - mode: Mode, - epoch: usize, - - // Set when the publisher is closed/dropped, or all subscribers are dropped. +struct TrackState { + mode: Option, closed: Result<(), ServeError>, } -impl State { - pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.closed.clone()?; - self.closed = Err(err); - Ok(()) - } - - pub fn insert_group(&mut self, group: GroupSubscriber) -> Result<(), ServeError> { - self.closed.clone()?; - - match &self.mode { - Mode::Init => {} - Mode::Group(old) => match group.id.cmp(&old.id) { - cmp::Ordering::Less => return Ok(()), - cmp::Ordering::Equal => return Err(ServeError::Duplicate), - cmp::Ordering::Greater => {} - }, - _ => return Err(ServeError::Mode), - }; - - self.mode = Mode::Group(group); - self.epoch += 1; - - Ok(()) - } - - pub fn insert_object(&mut self, object: ObjectSubscriber) -> Result<(), ServeError> { - self.closed.clone()?; - - match &mut self.mode { - Mode::Init => { - self.mode = Mode::Object(vec![object]); - } - Mode::Object(objects) => { - let first = objects.first().unwrap(); - - match object.group_id.cmp(&first.group_id) { - // Drop this old 
group - cmp::Ordering::Less => return Ok(()), - cmp::Ordering::Greater => objects.clear(), - cmp::Ordering::Equal => {} - } - - objects.push(object); - } - _ => return Err(ServeError::Mode), - }; - - self.epoch += 1; - - Ok(()) - } - - pub fn insert_datagram(&mut self, datagram: Datagram) -> Result<(), ServeError> { - self.closed.clone()?; - - match &self.mode { - Mode::Init | Mode::Datagram(_) => {} - _ => return Err(ServeError::Mode), - }; - - self.mode = Mode::Datagram(datagram); - self.epoch += 1; - - Ok(()) - } - - pub fn set_stream(&mut self, stream: StreamSubscriber) -> Result<(), ServeError> { - self.closed.clone()?; - - match &self.mode { - Mode::Init => {} - _ => return Err(ServeError::Mode), - }; - - self.mode = Mode::Stream(stream); - self.epoch += 1; - - Ok(()) - } -} - -impl Default for State { +impl Default for TrackState { fn default() -> Self { Self { - mode: Mode::Init, - epoch: 0, + mode: None, closed: Ok(()), } } } /// Creates new streams for a track. -#[derive(Debug)] -pub struct TrackPublisher { - state: Watch, - info: Arc, +pub struct TrackWriter { + state: State, + pub info: Arc, } -impl TrackPublisher { +impl TrackWriter { /// Create a track with the given name. - fn new(state: Watch, info: Arc) -> Self { + fn new(state: State, info: Arc) -> Self { Self { state, info } } - /// Create a group with the given info. - pub fn create_group(&mut self, group: Group) -> Result { - let (publisher, subscriber) = group.produce(); - self.state.lock_mut().insert_group(subscriber)?; - Ok(publisher) - } + pub fn stream(self, priority: u64) -> Result { + let (writer, reader) = Stream { + track: self.info.clone(), + priority, + } + .produce(); - /// Create an object with the given info and payload. 
- pub fn write_object(&mut self, object: Object) -> Result<(), ServeError> { - let payload = object.payload.clone(); - let header = ObjectHeader::from(object); - let (mut publisher, subscriber) = header.produce(); - publisher.write(payload)?; - self.state.lock_mut().insert_object(subscriber)?; - Ok(()) + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + state.mode = Some(reader.into()); + Ok(writer) } - /// Create an object with the given info and size, but no payload yet. - pub fn create_object(&mut self, object: ObjectHeader) -> Result { - let (publisher, subscriber) = object.produce(); - self.state.lock_mut().insert_object(subscriber)?; - Ok(publisher) + pub fn groups(self) -> Result { + let (writer, reader) = Groups { + track: self.info.clone(), + } + .produce(); + + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + state.mode = Some(reader.into()); + Ok(writer) } - /// Create a single stream for the entire track, served in strict order. - pub fn create_stream(&mut self, send_order: u64) -> Result { - let (publisher, subscriber) = Stream { - namespace: self.namespace.clone(), - name: self.name.clone(), - send_order, + pub fn objects(self) -> Result { + let (writer, reader) = Objects { + track: self.info.clone(), } .produce(); - self.state.lock_mut().set_stream(subscriber)?; - Ok(publisher) - } - /// Create a datagram that is not cached. - pub fn write_datagram(&mut self, info: Datagram) -> Result<(), ServeError> { - self.state.lock_mut().insert_datagram(info)?; - Ok(()) + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + state.mode = Some(reader.into()); + Ok(writer) } - /// Close the stream with an error. 
- pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.state.lock_mut().close(err) + pub fn datagrams(self) -> Result { + let (writer, reader) = Datagrams { + track: self.info.clone(), + } + .produce(); + + let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; + state.mode = Some(reader.into()); + Ok(writer) } - pub async fn closed(&self) -> Result<(), ServeError> { - loop { - let notify = { - let state = self.state.lock(); - state.closed.clone()?; - state.changed() - }; + /// Close the track with an error. + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; - notify.await - } + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; + state.closed = Err(err); + Ok(()) } } -impl Deref for TrackPublisher { +impl Deref for TrackWriter { type Target = Track; fn deref(&self) -> &Self::Target { @@ -239,79 +138,41 @@ impl Deref for TrackPublisher { } /// Receives new streams for a track. 
-#[derive(Clone, Debug)] -pub struct TrackSubscriber { - state: Watch, - info: Arc, - epoch: usize, - _dropped: Arc, +#[derive(Clone)] +pub struct TrackReader { + state: State, + pub info: Arc, } -impl TrackSubscriber { - fn new(state: Watch, info: Arc) -> Self { - let _dropped = Arc::new(Dropped::new(state.clone())); - Self { - state, - info, - epoch: 0, - _dropped, - } +impl TrackReader { + fn new(state: State, info: Arc) -> Self { + Self { state, info } } - /// Block until the next stream arrives - pub async fn next(&mut self) -> Result, ServeError> { + pub async fn mode(self) -> Result { loop { let notify = { let state = self.state.lock(); - - if self.epoch != state.epoch { - match &state.mode { - Mode::Init => {} - Mode::Stream(stream) => { - self.epoch = state.epoch; - return Ok(Some(stream.clone().into())); - } - Mode::Group(group) => { - self.epoch = state.epoch; - return Ok(Some(group.clone().into())); - } - Mode::Object(objects) => { - let index = objects.len().saturating_sub(state.epoch - self.epoch); - self.epoch = state.epoch - objects.len() + index + 1; - return Ok(Some(objects[index].clone().into())); - } - Mode::Datagram(datagram) => { - self.epoch = state.epoch; - return Ok(Some(datagram.clone().into())); - } - } + if let Some(mode) = &state.mode { + return Ok(mode.clone()); } - // Otherwise check if we need to return an error. 
- match &state.closed { - Ok(()) => state.changed(), - Err(ServeError::Done) => return Ok(None), - Err(err) => return Err(err.clone()), + state.closed.clone()?; + match state.modified() { + Some(notify) => notify, + None => return Err(ServeError::Done), } }; - notify.await + notify.await; } } // Returns the largest group/sequence pub fn latest(&self) -> Option<(u64, u64)> { - let state = self.state.lock(); - match &state.mode { - Mode::Init => None, - Mode::Datagram(datagram) => Some((datagram.group_id, datagram.object_id)), - Mode::Group(group) => Some((group.id, group.latest())), - Mode::Object(objects) => objects - .iter() - .max_by_key(|a| (a.group_id, a.object_id)) - .map(|a| (a.group_id, a.object_id)), - Mode::Stream(stream) => stream.latest(), - } + // We don't even know the mode yet. + // TODO populate from SUBSCRIBE_OK + None } pub async fn closed(&self) -> Result<(), ServeError> { @@ -319,7 +180,11 @@ impl TrackSubscriber { let notify = { let state = self.state.lock(); state.closed.clone()?; - state.changed() + + match state.modified() { + Some(notify) => notify, + None => return Ok(()), + } }; notify.await @@ -327,7 +192,7 @@ impl TrackSubscriber { } } -impl Deref for TrackSubscriber { +impl Deref for TrackReader { type Target = Track; fn deref(&self) -> &Self::Target { @@ -335,56 +200,55 @@ impl Deref for TrackSubscriber { } } -#[derive(Debug)] -pub enum TrackMode { - Stream(StreamSubscriber), - Group(GroupSubscriber), - Object(ObjectSubscriber), - Datagram(Datagram), -} - -impl From for TrackMode { - fn from(subscriber: StreamSubscriber) -> Self { - Self::Stream(subscriber) - } -} +macro_rules! track_readers { + {$($name:ident,)*} => { + paste! 
{ + #[derive(Clone)] + pub enum TrackReaderMode { + $($name([<$name Reader>])),* + } -impl From for TrackMode { - fn from(subscriber: GroupSubscriber) -> Self { - Self::Group(subscriber) - } -} + $(impl From<[<$name Reader>]> for TrackReaderMode { + fn from(reader: [<$name Reader >]) -> Self { + Self::$name(reader) + } + })* -impl From for TrackMode { - fn from(subscriber: ObjectSubscriber) -> Self { - Self::Object(subscriber) + impl TrackReaderMode { + pub fn latest(&self) -> Option<(u64, u64)> { + match self { + $(Self::$name(reader) => reader.latest(),)* + } + } + } + } } } -impl From for TrackMode { - fn from(info: Datagram) -> Self { - Self::Datagram(info) - } -} +track_readers!(Stream, Groups, Objects, Datagrams,); -struct Dropped { - state: Watch, -} +macro_rules! track_writers { + {$($name:ident,)*} => { + paste! { + pub enum TrackWriterMode { + $($name([<$name Writer>])),* + } -impl Dropped { - fn new(state: Watch) -> Self { - Self { state } - } -} + $(impl From<[<$name Writer>]> for TrackWriterMode { + fn from(writer: [<$name Writer>]) -> Self { + Self::$name(writer) + } + })* -impl Drop for Dropped { - fn drop(&mut self) { - self.state.lock_mut().close(ServeError::Done).ok(); + impl TrackWriterMode { + pub fn close(self, err: ServeError) -> Result<(), ServeError>{ + match self { + $(Self::$name(writer) => writer.close(err),)* + } + } + } + } } } -impl fmt::Debug for Dropped { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Dropped").finish() - } -} +track_writers!(Track, Stream, Groups, Objects, Datagrams,); diff --git a/moq-transport/src/session/announce.rs b/moq-transport/src/session/announce.rs index 93a9f07d..7d63b65f 100644 --- a/moq-transport/src/session/announce.rs +++ b/moq-transport/src/session/announce.rs @@ -1,80 +1,158 @@ -use crate::{message, serve::ServeError, util::Watch}; +use std::{collections::VecDeque, ops}; -use super::Publisher; +use crate::{message, serve::ServeError, Publisher}; -pub struct Announce { 
- session: Publisher, - msg: message::Announce, - state: Watch, +use super::Subscribed; + +use crate::util::State; + +#[derive(Debug, Clone)] +pub struct AnnounceInfo { + pub namespace: String, } -impl Announce { - pub(super) fn new(session: Publisher, msg: message::Announce) -> (Announce, AnnounceRecv) { - let state = Watch::default(); - let recv = AnnounceRecv { state: state.clone() }; +struct AnnounceState { + subscribers: VecDeque>, + ok: bool, + closed: Result<(), ServeError>, +} - let announce = Self { session, msg, state }; +impl Default for AnnounceState { + fn default() -> Self { + Self { + subscribers: Default::default(), + ok: false, + closed: Ok(()), + } + } +} - (announce, recv) +impl Drop for AnnounceState { + fn drop(&mut self) { + for subscriber in self.subscribers.drain(..) { + subscriber.close(ServeError::NotFound).ok(); + } } +} + +pub struct Announce { + publisher: Publisher, + state: State>, - pub fn namespace(&self) -> &str { - &self.msg.namespace + pub info: AnnounceInfo, +} + +impl Announce { + pub(super) fn new(mut publisher: Publisher, namespace: String) -> (Announce, AnnounceRecv) { + let info = AnnounceInfo { + namespace: namespace.clone(), + }; + + publisher.send_message(message::Announce { + namespace, + params: Default::default(), + }); + + let (send, recv) = State::init(); + + let send = Self { + publisher, + info, + state: send, + }; + let recv = AnnounceRecv { state: recv }; + + (send, recv) } - fn close(&mut self) -> Result<(), ServeError> { - let mut state = self.state.lock_mut(); - state.closed.clone()?; - state.closed = Err(ServeError::Done); + // Run until we get an error + pub async fn serve(self) -> Result<(), ServeError> { + loop { + let notify = { + let state = self.state.lock(); + state.closed.clone()?; - self.session - .send_message(message::Unannounce { - namespace: self.msg.namespace.clone(), - }) - .ok(); + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + }; - Ok(()) + notify.await 
+ } } - pub async fn closed(&self) -> Result<(), ServeError> { + pub async fn subscribed(&mut self) -> Result>, ServeError> { loop { let notify = { let state = self.state.lock(); + if !state.subscribers.is_empty() { + return Ok(state.into_mut().and_then(|mut state| state.subscribers.pop_front())); + } + state.closed.clone()?; - state.changed() + match state.modified() { + Some(notified) => notified, + None => return Ok(None), + } }; - notify.await; + notify.await } } } impl Drop for Announce { fn drop(&mut self) { - self.close().ok(); - self.session.drop_announce(&self.msg.namespace); + if self.state.lock().closed.is_err() { + return; + } + + self.publisher.send_message(message::Unannounce { + namespace: self.info.namespace.to_string(), + }); } } -pub(super) struct AnnounceRecv { - state: Watch, +impl ops::Deref for Announce { + type Target = AnnounceInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +pub(super) struct AnnounceRecv { + state: State>, } -impl AnnounceRecv { - pub fn recv_error(&mut self, err: ServeError) -> Result<(), ServeError> { - let mut state = self.state.lock_mut(); +impl AnnounceRecv { + pub fn recv_ok(&mut self) -> Result<(), ServeError> { + if let Some(mut state) = self.state.lock_mut() { + if state.ok { + return Err(ServeError::Duplicate); + } + + state.ok = true; + } + + Ok(()) + } + + pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; state.closed = Err(err); + Ok(()) } -} -struct State { - closed: Result<(), ServeError>, -} + pub fn recv_subscribe(&mut self, subscriber: Subscribed) -> Result<(), ServeError> { + let mut state = self.state.lock_mut().ok_or(ServeError::Done)?; + state.subscribers.push_back(subscriber); -impl Default for State { - fn default() -> Self { - Self { closed: Ok(()) } + Ok(()) } } diff --git a/moq-transport/src/session/announced.rs 
b/moq-transport/src/session/announced.rs index 40a38baa..228deb15 100644 --- a/moq-transport/src/session/announced.rs +++ b/moq-transport/src/session/announced.rs @@ -1,120 +1,103 @@ -use crate::{message, serve::ServeError, util::Watch}; +use std::ops; -use super::Subscriber; +use crate::{message, serve::ServeError, util::State}; + +use super::{AnnounceInfo, Subscriber}; + +// There's currently no feedback from the peer, so the shared state is empty. +// If Unannounce contained an error code then we'd be talking. +#[derive(Default)] +struct AnnouncedState {} pub struct Announced { session: Subscriber, - namespace: String, - state: Watch>, + state: State, + + pub info: AnnounceInfo, + + ok: bool, + error: Option, } impl Announced { - pub(super) fn new(session: Subscriber, namespace: String) -> (Announced, AnnouncedRecv) { - let state = Watch::new(State::new(session.clone(), namespace.clone())); - let recv = AnnouncedRecv { state: state.clone() }; + pub(super) fn new(session: Subscriber, namespace: String) -> (Announced, AnnouncedRecv) { + let info = AnnounceInfo { namespace }; - let announced = Self { + let (send, recv) = State::init(); + let send = Self { session, - namespace, - state, + info, + ok: false, + error: None, + state: send, }; + let recv = AnnouncedRecv { _state: recv }; - (announced, recv) - } - - pub fn namespace(&self) -> &str { - &self.namespace + (send, recv) } // Send an ANNOUNCE_OK - pub fn accept(&mut self) -> Result<(), ServeError> { - self.state.lock_mut().accept() - } + pub fn ok(&mut self) -> Result<(), ServeError> { + if self.ok { + return Err(ServeError::Duplicate); + } + + self.session.send_message(message::AnnounceOk { + namespace: self.namespace.clone(), + }); - pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.state.lock_mut().close(err) + self.ok = true; + + Ok(()) } pub async fn closed(&self) -> Result<(), ServeError> { loop { - let notify = { - let state = self.state.lock(); - state.closed.clone()?; - 
state.changed() - }; - - notify.await; + // Wow this is dumb and yet pretty cool. + // Basically loop until the state changes and exit when Recv is dropped. + self.state.lock().modified().ok_or(ServeError::Cancel)?.await; } } -} -impl Drop for Announced { - fn drop(&mut self) { - self.close(ServeError::Done).ok(); - self.session.drop_announce(&self.namespace); + pub fn close(mut self, err: ServeError) -> Result<(), ServeError> { + self.error = Some(err); + Ok(()) } } -pub(super) struct AnnouncedRecv { - state: Watch>, -} +impl ops::Deref for Announced { + type Target = AnnounceInfo; -impl AnnouncedRecv { - pub fn recv_unannounce(&mut self) -> Result<(), ServeError> { - self.state.lock_mut().close(ServeError::Done) + fn deref(&self) -> &AnnounceInfo { + &self.info } } -struct State { - namespace: String, - session: Subscriber, - ok: bool, - closed: Result<(), ServeError>, -} - -impl State { - fn new(session: Subscriber, namespace: String) -> Self { - Self { - session, - namespace, - ok: false, - closed: Ok(()), - } - } - - pub fn accept(&mut self) -> Result<(), ServeError> { - self.closed.clone()?; - self.ok = true; - - self.session - .send_message(message::AnnounceOk { - namespace: self.namespace.clone(), - }) - .ok(); - - Ok(()) - } - - pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.closed.clone()?; - self.closed = Err(err.clone()); +impl Drop for Announced { + fn drop(&mut self) { + let err = self.error.clone().unwrap_or(ServeError::Done); if self.ok { - self.session - .send_message(message::AnnounceCancel { - namespace: self.namespace.clone(), - }) - .ok(); + self.session.send_message(message::AnnounceCancel { + namespace: self.namespace.clone(), + }); } else { - self.session - .send_message(message::AnnounceError { - namespace: self.namespace.clone(), - code: err.code(), - reason: err.to_string(), - }) - .ok(); + self.session.send_message(message::AnnounceError { + namespace: self.namespace.clone(), + code: err.code(), + reason: 
err.to_string(), + }); } + } +} + +pub(super) struct AnnouncedRecv { + _state: State, +} +impl AnnouncedRecv { + pub fn recv_unannounce(self) -> Result<(), ServeError> { + // Will cause the state to be dropped Ok(()) } } diff --git a/moq-transport/src/session/error.rs b/moq-transport/src/session/error.rs index 10702f4d..780cede9 100644 --- a/moq-transport/src/session/error.rs +++ b/moq-transport/src/session/error.rs @@ -1,15 +1,19 @@ -use std::{io, sync}; +use std::sync::Arc; + +use webtransport_generic::ErrorCode; use crate::{coding, serve, setup}; #[derive(thiserror::Error, Debug, Clone)] pub enum SessionError { #[error("webtransport error: {0}")] - WebTransport(sync::Arc), + WebTransport(Arc), + + #[error("write error: {0}")] + Write(Arc), - // This needs an Arc because it's not Clone. - #[error("io error: {0}")] - Io(sync::Arc), + #[error("read error: {0}")] + Read(Arc), #[error("encode error: {0}")] Encode(#[from] coding::EncodeError), @@ -40,27 +44,35 @@ pub enum SessionError { #[error("internal error")] Internal, - #[error("cache error: {0}")] - Cache(#[from] serve::ServeError), + #[error("serve error: {0}")] + Serve(#[from] serve::ServeError), #[error("wrong size")] WrongSize, } +impl SessionError { + pub(super) fn from_webtransport(err: E) -> Self { + Self::WebTransport(Arc::new(err)) + } + + pub(super) fn from_read(err: E) -> Self { + Self::Read(Arc::new(err)) + } + + pub(super) fn from_write(err: E) -> Self { + Self::Write(Arc::new(err)) + } +} + /* impl From for SessionError { fn from(err: T) -> Self { - Self::WebTransport(sync::Arc::new(err)) + Self::WebTransport(Arc::new(err)) } } */ -impl From for SessionError { - fn from(err: io::Error) -> Self { - Self::Io(sync::Arc::new(err)) - } -} - impl SessionError { /// An integer code that is sent over the wire. pub fn code(&self) -> u64 { @@ -68,16 +80,25 @@ impl SessionError { Self::RoleIncompatible(..) 
=> 406, Self::RoleViolation => 405, Self::WebTransport(_) => 503, + Self::Read(_) => 500, + Self::Write(_) => 500, Self::Version(..) => 406, Self::Decode(_) => 400, Self::Encode(_) => 500, - Self::Io(_) => 500, Self::BoundsExceeded(_) => 500, Self::Duplicate => 409, Self::Internal => 500, Self::WrongSize => 400, + Self::Serve(err) => err.code(), + } + } +} - Self::Cache(err) => err.code(), +impl From for serve::ServeError { + fn from(err: SessionError) -> Self { + match err { + SessionError::Serve(err) => err, + _ => serve::ServeError::Internal(err.to_string()), } } } diff --git a/moq-transport/src/session/mod.rs b/moq-transport/src/session/mod.rs index 55f00ab6..4eadbf7a 100644 --- a/moq-transport/src/session/mod.rs +++ b/moq-transport/src/session/mod.rs @@ -2,11 +2,11 @@ mod announce; mod announced; mod error; mod publisher; +mod reader; mod subscribe; mod subscribed; mod subscriber; - -use std::sync::Arc; +mod writer; pub use announce::*; pub use announced::*; @@ -16,21 +16,26 @@ pub use subscribe::*; pub use subscribed::*; pub use subscriber::*; +use reader::*; +use writer::*; + use futures::FutureExt; use futures::{stream::FuturesUnordered, StreamExt}; -use crate::coding::{Reader, Writer}; -use crate::{message, setup, util::Queue}; +use crate::message::Message; +use crate::util::Queue; +use crate::{message, setup}; pub struct Session { webtransport: S, sender: Writer, recver: Reader, - outgoing: Queue, publisher: Option>, subscriber: Option>, + + outgoing: Queue, } impl Session { @@ -41,19 +46,18 @@ impl Session { role: setup::Role, ) -> (Self, Option>, Option>) { let outgoing = Queue::default(); - let publisher = role .is_publisher() - .then(|| Publisher::new(webtransport.clone(), outgoing.clone())); + .then(|| Publisher::new(outgoing.clone(), webtransport.clone())); let subscriber = role.is_subscriber().then(|| Subscriber::new(outgoing.clone())); let session = Self { webtransport, sender, recver, - outgoing, publisher: publisher.clone(), subscriber: 
subscriber.clone(), + outgoing, }; (session, publisher, subscriber) @@ -69,10 +73,7 @@ impl Session { session: S, role: setup::Role, ) -> Result<(Session, Option>, Option>), SessionError> { - let control = session - .open_bi() - .await - .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; + let control = session.open_bi().await.map_err(SessionError::from_webtransport)?; let mut sender = Writer::new(control.0); let mut recver = Reader::new(control.1); @@ -116,10 +117,7 @@ impl Session { session: S, role: setup::Role, ) -> Result<(Session, Option>, Option>), SessionError> { - let control = session - .accept_bi() - .await - .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; + let control = session.accept_bi().await.map_err(SessionError::from_webtransport)?; let mut sender = Writer::new(control.0); let mut recver = Reader::new(control.1); @@ -162,24 +160,26 @@ impl Session { pub async fn run(self) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); - tasks.push(Self::run_send(self.outgoing, self.sender).boxed()); + tasks.push(Self::run_recv(self.recver, self.publisher, self.subscriber.clone()).boxed()); + tasks.push(Self::run_send(self.sender, self.outgoing).boxed()); if let Some(subscriber) = self.subscriber { tasks.push(Self::run_streams(self.webtransport.clone(), subscriber.clone()).boxed()); tasks.push(Self::run_datagrams(self.webtransport, subscriber).boxed()); } - let res = tasks.next().await.unwrap(); + let res = tasks.select_next_some().await; Err(res.expect_err("run terminated with OK")) } async fn run_send( - outgoing: Queue, mut sender: Writer, + outgoing: Queue, ) -> Result<(), SessionError> { loop { - let msg = outgoing.pop().await?; + let msg = outgoing.pop().await; + log::debug!("sending message: {:?}", msg); sender.encode(&msg).await?; } } @@ -191,6 +191,7 @@ impl Session { ) -> Result<(), SessionError> { loop { let msg: message::Message = recver.decode().await?; + log::debug!("received message: {:?}", msg); let msg = match 
TryInto::::try_into(msg) { Ok(msg) => { @@ -225,10 +226,16 @@ impl Session { loop { tokio::select! { res = webtransport.accept_uni() => { - let stream = res.map_err(|e| SessionError::WebTransport(Arc::new(e)))?; - tasks.push(Subscriber::recv_stream(subscriber.clone(), stream)); + let stream = res.map_err(SessionError::from_webtransport)?; + let subscriber = subscriber.clone(); + + tasks.push(async move { + if let Err(err) = Subscriber::recv_stream(subscriber, stream).await { + log::warn!("failed to serve stream: {}", err); + }; + }); }, - res = tasks.next(), if !tasks.is_empty() => res.unwrap()?, + _ = tasks.next(), if !tasks.is_empty() => {}, }; } } @@ -238,9 +245,9 @@ impl Session { let datagram = webtransport .recv_datagram() .await - .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; + .map_err(SessionError::from_webtransport)?; - subscriber.recv_datagram(datagram).await?; + subscriber.recv_datagram(datagram)?; } } } diff --git a/moq-transport/src/session/publisher.rs b/moq-transport/src/session/publisher.rs index 8df995b9..40ec83fb 100644 --- a/moq-transport/src/session/publisher.rs +++ b/moq-transport/src/session/publisher.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, Mutex}, }; -use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; +use futures::{stream::FuturesUnordered, StreamExt}; use crate::{ message::{self, Message}, @@ -19,20 +19,20 @@ use super::{Announce, AnnounceRecv, Session, SessionError, Subscribed, Subscribe pub struct Publisher { webtransport: S, - announces: Arc>>, - subscribed: Arc>>>, - subscribed_queue: Queue, SessionError>, + announces: Arc>>>, + subscribed: Arc>>, + unknown: Queue>, - outgoing: Queue, + outgoing: Queue, } impl Publisher { - pub(crate) fn new(webtransport: S, outgoing: Queue) -> Self { + pub(crate) fn new(outgoing: Queue, webtransport: S) -> Self { Self { webtransport, announces: Default::default(), subscribed: Default::default(), - subscribed_queue: Default::default(), + unknown: Default::default(), outgoing, } 
} @@ -50,144 +50,180 @@ impl Publisher { pub fn announce(&mut self, namespace: &str) -> Result, SessionError> { let mut announces = self.announces.lock().unwrap(); - // Insert the abort handle into the lookup table. let entry = match announces.entry(namespace.to_string()) { hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate.into()), hash_map::Entry::Vacant(entry) => entry, }; - let msg = message::Announce { - namespace: namespace.to_string(), - params: Default::default(), - }; - self.send_message(msg.clone())?; - - let (announce, recv) = Announce::new(self.clone(), msg); + let (send, recv) = Announce::new(self.clone(), namespace.to_string()); entry.insert(recv); - Ok(announce) + // Unannounce on close + Ok(send) } - pub async fn subscribed(&mut self) -> Result, SessionError> { - self.subscribed_queue.pop().await - } + // Helper function to announce and serve a list of tracks. + pub async fn serve(&mut self, broadcast: serve::BroadcastReader) -> Result<(), SessionError> { + let mut announce = self.announce(&broadcast.namespace)?; - // Helper to announce and serve any matching subscribers. - // TODO this currently takes over the connection; definitely remove Clone - pub async fn serve(mut self, broadcast: serve::BroadcastSubscriber) -> Result<(), SessionError> { - log::info!("serving broadcast: {}", broadcast.namespace); - - let announce = self.announce(&broadcast.namespace)?; let mut tasks = FuturesUnordered::new(); + let mut done = None; + loop { tokio::select! 
{ - err = announce.closed() => err?, - res = tasks.next(), if !tasks.is_empty() => { - // TODO preseve the track name too - log::debug!("served track: namespace={} res={:?}", broadcast.namespace, res); - }, - sub = self.subscribed() => { - let mut subscribe = sub?; - match self.serve_track(&broadcast, &subscribe) { - Ok(track) => { - log::info!("serving track: namespace={} name={}", track.namespace, track.name); - tasks.push(subscribe.serve(track).boxed()); - }, - Err(err) => { - log::debug!("failed serving track: namespace={} name={} err={}", subscribe.namespace(), subscribe.name(), err); - subscribe.close(err).ok(); - } + subscribe = announce.subscribed(), if done.is_none() => { + let subscribe = match subscribe { + Ok(Some(subscribe)) => subscribe, + Ok(None) => { done = Some(Ok(())); continue }, + Err(err) => { done = Some(Err(err)); continue }, }; - } + + let broadcast = broadcast.clone(); + + tasks.push(async move { + let info = subscribe.info.clone(); + + match broadcast.get_track(&subscribe.name) { + Ok(track) => if let Err(err) = Self::serve_subscribe(subscribe, track).await { + log::warn!("failed serving subscribe: {:?}, error: {}", info, err) + }, + Err(err) => { + log::warn!("failed getting subscribe: {:?}, error: {}", info, err) + }, + } + }); + }, + _ = tasks.next(), if !tasks.is_empty() => {}, + else => return Ok(done.unwrap()?) } } } - fn serve_track( - &self, - broadcast: &serve::BroadcastSubscriber, - subscribe: &Subscribed, - ) -> Result { - if subscribe.namespace() != broadcast.namespace { - return Err(ServeError::NotFound); - } + pub async fn serve_subscribe( + subscribe: Subscribed, + track: Option, + ) -> Result<(), SessionError> { + match track { + Some(track) => subscribe.serve(track).await?, + None => subscribe.close(ServeError::NotFound)?, + }; + + Ok(()) + } - broadcast.get_track(subscribe.name())?.ok_or(ServeError::NotFound) + // Returns subscriptions that do not map to an active announce. 
+ pub async fn subscribed(&mut self) -> Subscribed { + self.unknown.pop().await } pub(crate) fn recv_message(&mut self, msg: message::Subscriber) -> Result<(), SessionError> { - log::debug!("received message: {:?}", msg); - - match msg { + let res = match msg { message::Subscriber::AnnounceOk(msg) => self.recv_announce_ok(msg), message::Subscriber::AnnounceError(msg) => self.recv_announce_error(msg), message::Subscriber::AnnounceCancel(msg) => self.recv_announce_cancel(msg), message::Subscriber::Subscribe(msg) => self.recv_subscribe(msg), message::Subscriber::Unsubscribe(msg) => self.recv_unsubscribe(msg), + }; + + if let Err(err) = res { + log::warn!("failed to process message: {}", err); } + + Ok(()) } - fn recv_announce_ok(&mut self, _msg: message::AnnounceOk) -> Result<(), SessionError> { - // Who cares - // TODO make AnnouncePending so we're forced to care + fn recv_announce_ok(&mut self, msg: message::AnnounceOk) -> Result<(), SessionError> { + if let Some(announce) = self.announces.lock().unwrap().get_mut(&msg.namespace) { + announce.recv_ok()?; + } + Ok(()) } fn recv_announce_error(&mut self, msg: message::AnnounceError) -> Result<(), SessionError> { - if let Some(announce) = self.announces.lock().unwrap().get_mut(&msg.namespace) { - announce.recv_error(ServeError::Closed(msg.code)).ok(); + if let Some(announce) = self.announces.lock().unwrap().remove(&msg.namespace) { + announce.recv_error(ServeError::Closed(msg.code))?; } Ok(()) } - fn recv_announce_cancel(&mut self, _msg: message::AnnounceCancel) -> Result<(), SessionError> { - unimplemented!("recv_announce_cancel") + fn recv_announce_cancel(&mut self, msg: message::AnnounceCancel) -> Result<(), SessionError> { + if let Some(announce) = self.announces.lock().unwrap().remove(&msg.namespace) { + announce.recv_error(ServeError::Cancel)?; + } + + Ok(()) } fn recv_subscribe(&mut self, msg: message::Subscribe) -> Result<(), SessionError> { - let mut subscribes = self.subscribed.lock().unwrap(); + let namespace 
= msg.track_namespace.clone(); - // Insert the abort handle into the lookup table. - let entry = match subscribes.entry(msg.id) { - hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), - hash_map::Entry::Vacant(entry) => entry, + let subscribe = { + let mut subscribes = self.subscribed.lock().unwrap(); + + // Insert the abort handle into the lookup table. + let entry = match subscribes.entry(msg.id) { + hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), + hash_map::Entry::Vacant(entry) => entry, + }; + + let (send, recv) = Subscribed::new(self.clone(), msg); + entry.insert(recv); + + send }; - let (subscribe, recv) = Subscribed::new(self.clone(), msg); - entry.insert(recv); - self.subscribed_queue.push(subscribe) + // If we have an announce, route the subscribe to it. + // Otherwise, put it in the unknown queue. + // TODO Have some way to detect if the application is not reading from the unknown queue. + match self.announces.lock().unwrap().get_mut(&namespace) { + Some(announce) => announce.recv_subscribe(subscribe)?, + None => self.unknown.push(subscribe), + }; + + Ok(()) } fn recv_unsubscribe(&mut self, msg: message::Unsubscribe) -> Result<(), SessionError> { if let Some(subscribed) = self.subscribed.lock().unwrap().get_mut(&msg.id) { - subscribed.recv_unsubscribe().ok(); + subscribed.recv_unsubscribe()?; } Ok(()) } - pub fn send_message>(&self, msg: T) -> Result<(), SessionError> { + pub(super) fn send_message + Into>(&mut self, msg: T) { let msg = msg.into(); - log::debug!("sending message: {:?}", msg); + match &msg { + message::Publisher::SubscribeDone(msg) => self.drop_subscribe(msg.id), + message::Publisher::SubscribeError(msg) => self.drop_subscribe(msg.id), + message::Publisher::Unannounce(msg) => self.drop_announce(msg.namespace.as_str()), + _ => (), + }; + self.outgoing.push(msg.into()) } - pub(super) fn drop_subscribe(&mut self, id: u64) { + fn drop_subscribe(&mut self, id: u64) { 
self.subscribed.lock().unwrap().remove(&id); } - pub(super) fn drop_announce(&mut self, namespace: &str) { + fn drop_announce(&mut self, namespace: &str) { self.announces.lock().unwrap().remove(namespace); } - pub(super) fn webtransport(&mut self) -> &mut S { - &mut self.webtransport + pub(super) async fn open_uni(&self) -> Result { + self.webtransport + .open_uni() + .await + .map_err(SessionError::from_webtransport) } - pub fn close(self, err: SessionError) { - self.outgoing.close(err.clone()).ok(); - self.subscribed_queue.close(err).ok(); + pub(super) fn send_datagram(&self, data: bytes::Bytes) -> Result<(), SessionError> { + self.webtransport + .send_datagram(data) + .map_err(SessionError::from_webtransport) } } diff --git a/moq-transport/src/session/reader.rs b/moq-transport/src/session/reader.rs new file mode 100644 index 00000000..0c86b1a5 --- /dev/null +++ b/moq-transport/src/session/reader.rs @@ -0,0 +1,88 @@ +use std::{cmp, io}; + +use bytes::{Buf, Bytes, BytesMut}; + +use crate::coding::{Decode, DecodeError}; + +use super::SessionError; + +pub struct Reader { + stream: S, + buffer: BytesMut, +} + +impl Reader { + pub fn new(stream: S) -> Self { + Self { + stream, + buffer: Default::default(), + } + } + + pub async fn decode(&mut self) -> Result { + loop { + let mut cursor = io::Cursor::new(&self.buffer); + + // Try to decode with the current buffer. + let mut remain = match T::decode(&mut cursor) { + Ok(msg) => { + self.buffer.advance(cursor.position() as usize); + return Ok(msg); + } + Err(DecodeError::More(remain)) => remain, // Try again with more data + Err(err) => return Err(err.into()), + }; + + // Read in more data until we reach the requested amount. 
+ // We always read at least once to avoid an infinite loop if some dingus puts remain=0 + loop { + let size = self + .stream + .read_buf(&mut self.buffer) + .await + .map_err(SessionError::from_read)?; + if size == 0 { + return Err(DecodeError::More(remain).into()); + } + + remain = remain.saturating_sub(size); + if remain == 0 { + break; + } + } + } + } + + pub async fn read_chunk(&mut self, max: usize) -> Result, SessionError> { + if !self.buffer.is_empty() { + let size = cmp::min(max, self.buffer.len()); + let data = self.buffer.split_to(size).freeze(); + return Ok(Some(data)); + } + + let chunk = match self.stream.read_chunk().await.map_err(SessionError::from_read)? { + Some(chunk) if chunk.len() <= max => Some(chunk), + Some(mut chunk) => { + // The chunk is too big; add the tail to the buffer for next read. + self.buffer.extend_from_slice(&chunk.split_off(max)); + Some(chunk) + } + None => None, + }; + + Ok(chunk) + } + + pub async fn done(&mut self) -> Result { + if !self.buffer.is_empty() { + return Ok(false); + } + + let size = self + .stream + .read_buf(&mut self.buffer) + .await + .map_err(SessionError::from_read)?; + Ok(size == 0) + } +} diff --git a/moq-transport/src/session/subscribe.rs b/moq-transport/src/session/subscribe.rs index 667974ed..1c193945 100644 --- a/moq-transport/src/session/subscribe.rs +++ b/moq-transport/src/session/subscribe.rs @@ -1,169 +1,209 @@ -use std::sync::{Arc, Mutex}; +use std::ops; use crate::{ - coding::Reader, data, message, - serve::{self, ServeError}, + serve::{self, ServeError, TrackWriter, TrackWriterMode}, + util::State, }; -use super::{SessionError, Subscriber}; +use super::Subscriber; -#[derive(Clone)] +#[derive(Debug, Clone)] +pub struct SubscribeInfo { + pub namespace: String, + pub name: String, +} + +struct SubscribeState { + ok: bool, + closed: Result<(), ServeError>, +} + +impl Default for SubscribeState { + fn default() -> Self { + Self { + ok: Default::default(), + closed: Ok(()), + } + } +} + +// Held by 
the application pub struct Subscribe { - session: Subscriber, + state: State, + subscriber: Subscriber, id: u64, - track: Arc>, + + pub info: SubscribeInfo, } impl Subscribe { - pub(super) fn new(session: Subscriber, id: u64, track: serve::TrackPublisher) -> Self { - Self { - session, + pub(super) fn new(mut subscriber: Subscriber, id: u64, track: TrackWriter) -> (Subscribe, SubscribeRecv) { + subscriber.send_message(message::Subscribe { id, - track: Arc::new(Mutex::new(track)), - } - } + track_alias: id, + track_namespace: track.namespace.clone(), + track_name: track.name.clone(), + // TODO add these to the publisher. + start: Default::default(), + end: Default::default(), + params: Default::default(), + }); + + let info = SubscribeInfo { + namespace: track.namespace.clone(), + name: track.name.clone(), + }; + + let (send, recv) = State::init(); + + let send = Subscribe { + state: send, + subscriber, + id, + info, + }; - pub fn recv_ok(&mut self, _msg: message::SubscribeOk) -> Result<(), ServeError> { - // TODO - Ok(()) - } + let recv = SubscribeRecv { + state: recv, + writer: Some(track.into()), + }; - pub fn recv_error(&mut self, code: u64) -> Result<(), ServeError> { - self.track.lock().unwrap().close(ServeError::Closed(code))?; - Ok(()) + (send, recv) } - pub fn recv_done(&mut self, code: u64) -> Result<(), ServeError> { - self.track.lock().unwrap().close(ServeError::Closed(code))?; - Ok(()) - } + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + let notify = { + let state = self.state.lock(); + state.closed.clone()?; - pub async fn recv_stream( - &mut self, - header: data::Header, - reader: Reader, - ) -> Result<(), SessionError> { - match header { - data::Header::Track(track) => self.recv_track(track, reader).await, - data::Header::Group(group) => self.recv_group(group, reader).await, - data::Header::Object(object) => self.recv_object(object, reader).await, + match state.modified() { + Some(notify) => notify, + None => return Ok(()), + } + }; 
+ + notify.await } } +} - async fn recv_track( - &mut self, - header: data::TrackHeader, - mut reader: Reader, - ) -> Result<(), SessionError> { - log::trace!("received track: {:?}", header); +impl Drop for Subscribe { + fn drop(&mut self) { + self.subscriber.send_message(message::Unsubscribe { id: self.id }); + } +} - let mut track = self.track.lock().unwrap().create_stream(header.send_order)?; +impl ops::Deref for Subscribe { + type Target = SubscribeInfo; - while !reader.done().await? { - let chunk: data::TrackObject = reader.decode().await?; + fn deref(&self) -> &SubscribeInfo { + &self.info + } +} - let mut remain = chunk.size; +pub(super) struct SubscribeRecv { + state: State, + writer: Option, +} - let mut chunks = vec![]; - while remain > 0 { - let chunk = reader.read(remain).await?.ok_or(SessionError::WrongSize)?; - log::trace!("received track payload: {:?}", chunk.len()); - remain -= chunk.len(); - chunks.push(chunk); - } +impl SubscribeRecv { + pub fn ok(&mut self) -> Result<(), ServeError> { + let state = self.state.lock(); + if state.ok { + return Err(ServeError::Duplicate); + } - let object = serve::StreamObject { - object_id: chunk.object_id, - group_id: chunk.group_id, - payload: bytes::Bytes::from(chunks.concat()), - }; - log::trace!("received track object: {:?}", track); + if let Some(mut state) = state.into_mut() { + state.ok = true; + } - track.write_object(object)?; + Ok(()) + } + + pub fn error(mut self, err: ServeError) -> Result<(), ServeError> { + if let Some(writer) = self.writer.take() { + writer.close(err.clone())?; } + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Cancel)?; + state.closed = Err(err); + Ok(()) } - async fn recv_group( - &mut self, - header: data::GroupHeader, - mut reader: Reader, - ) -> Result<(), SessionError> { - log::trace!("received group: {:?}", header); + pub fn track(&mut self, header: data::TrackHeader) -> Result { + let writer = 
self.writer.take().ok_or(ServeError::Done)?; - let mut group = self.track.lock().unwrap().create_group(serve::Group { - id: header.group_id, - send_order: header.send_order, - })?; + let stream = match writer { + TrackWriterMode::Track(init) => init.stream(header.send_order)?, + _ => return Err(ServeError::Mode), + }; - while !reader.done().await? { - let object: data::GroupObject = reader.decode().await?; + self.writer = Some(stream.clone().into()); - log::trace!("received group object: {:?}", object); - let mut remain = object.size; - let mut object = group.create_object(object.size)?; + Ok(stream) + } - while remain > 0 { - let data = reader.read(remain).await?.ok_or(SessionError::WrongSize)?; - log::trace!("received group payload: {:?}", data.len()); - remain -= data.len(); - object.write(data)?; - } - } + pub fn group(&mut self, header: data::GroupHeader) -> Result { + let writer = self.writer.take().ok_or(ServeError::Done)?; - Ok(()) + let mut groups = match writer { + TrackWriterMode::Track(init) => init.groups()?, + TrackWriterMode::Groups(groups) => groups, + _ => return Err(ServeError::Mode), + }; + + let writer = groups.create(serve::Group { + group_id: header.group_id, + priority: header.send_order, + })?; + + self.writer = Some(groups.into()); + + Ok(writer) } - async fn recv_object( - &mut self, - header: data::ObjectHeader, - mut reader: Reader, - ) -> Result<(), SessionError> { - log::trace!("received object: {:?}", header); - - // TODO avoid buffering the entire object to learn the size. - let mut chunks = vec![]; - while let Some(data) = reader.read(usize::MAX).await? 
{ - log::trace!("received object payload: {:?}", data.len()); - chunks.push(data); - } + pub fn object(&mut self, header: data::ObjectHeader) -> Result { + let writer = self.writer.take().ok_or(ServeError::Done)?; - let mut object = self.track.lock().unwrap().create_object(serve::ObjectHeader { + let mut objects = match writer { + TrackWriterMode::Track(init) => init.objects()?, + TrackWriterMode::Objects(objects) => objects, + _ => return Err(ServeError::Mode), + }; + + let writer = objects.create(serve::Object { group_id: header.group_id, object_id: header.object_id, - send_order: header.send_order, - size: chunks.iter().map(|c| c.len()).sum(), + priority: header.send_order, })?; - log::trace!("received object: {:?}", object); - - for chunk in chunks { - object.write(chunk)?; - } + self.writer = Some(objects.into()); - Ok(()) + Ok(writer) } - pub fn recv_datagram(&self, datagram: data::Datagram) -> Result<(), SessionError> { - log::trace!("received datagram: {:?}", datagram); + pub fn datagram(&mut self, datagram: data::Datagram) -> Result<(), ServeError> { + let writer = self.writer.take().ok_or(ServeError::Done)?; + + let mut datagrams = match writer { + TrackWriterMode::Track(init) => init.datagrams()?, + TrackWriterMode::Datagrams(datagrams) => datagrams, + _ => return Err(ServeError::Mode), + }; - self.track.lock().unwrap().write_datagram(serve::Datagram { + datagrams.write(serve::Datagram { group_id: datagram.group_id, object_id: datagram.object_id, + priority: datagram.send_order, payload: datagram.payload, - send_order: datagram.send_order, })?; Ok(()) } } - -impl Drop for Subscribe { - fn drop(&mut self) { - let msg = message::Unsubscribe { id: self.id }; - self.session.send_message(msg).ok(); - self.session.drop_subscribe(self.id); - } -} diff --git a/moq-transport/src/session/subscribed.rs b/moq-transport/src/session/subscribed.rs index cc4aaf18..8a988a89 100644 --- a/moq-transport/src/session/subscribed.rs +++ b/moq-transport/src/session/subscribed.rs 
@@ -1,82 +1,178 @@ -use std::sync::Arc; +use std::ops; use futures::stream::FuturesUnordered; -use futures::{FutureExt, StreamExt}; +use futures::StreamExt; -use crate::coding::{Encode, Writer}; -use crate::serve::ServeError; -use crate::util::{Watch, WatchWeak}; -use crate::{data, message, serve}; +use webtransport_generic::SendStream; -use super::{Publisher, SessionError}; +use crate::coding::Encode; +use crate::serve::{ServeError, TrackReaderMode}; +use crate::util::State; +use crate::{data, message, serve, Publisher}; + +use super::{SessionError, SubscribeInfo, Writer}; + +#[derive(Debug)] +struct SubscribedState { + max: Option<(u64, u64)>, + closed: Result<(), ServeError>, +} + +impl SubscribedState { + fn update_max(&mut self, group_id: u64, object_id: u64) -> Result<(), ServeError> { + if let Some((max_group, max_object)) = self.max { + if group_id >= max_group && object_id >= max_object { + self.max = Some((group_id, object_id)); + } + } + + Ok(()) + } +} + +impl Default for SubscribedState { + fn default() -> Self { + Self { + max: None, + closed: Ok(()), + } + } +} -#[derive(Clone)] pub struct Subscribed { - session: Publisher, - state: Watch>, + publisher: Publisher, + state: State, msg: message::Subscribe, + ok: bool, + + pub info: SubscribeInfo, } impl Subscribed { - pub(super) fn new(session: Publisher, msg: message::Subscribe) -> (Subscribed, SubscribedRecv) { - let state = Watch::new(State::new(session.clone(), msg.id)); - let recv = SubscribedRecv { - state: state.downgrade(), + pub(super) fn new(publisher: Publisher, msg: message::Subscribe) -> (Self, SubscribedRecv) { + let (send, recv) = State::init(); + let info = SubscribeInfo { + namespace: msg.track_namespace.clone(), + name: msg.track_name.clone(), }; - let subscribed = Self { session, state, msg }; - (subscribed, recv) + let send = Self { + publisher, + state: send, + msg, + info, + ok: false, + }; + + // Prevents updates after being closed + let recv = SubscribedRecv { state: recv }; + + 
(send, recv) } - pub fn namespace(&self) -> &str { - self.msg.track_namespace.as_str() + pub async fn serve(mut self, track: serve::TrackReader) -> Result<(), SessionError> { + let res = self.serve_inner(track).await; + if let Err(err) = &res { + self.close(err.clone().into())?; + } + + res } - pub fn name(&self) -> &str { - self.msg.track_name.as_str() + async fn serve_inner(&mut self, track: serve::TrackReader) -> Result<(), SessionError> { + let latest = track.latest(); + self.state.lock_mut().ok_or(ServeError::Cancel)?.max = latest; + + self.publisher.send_message(message::SubscribeOk { + id: self.msg.id, + expires: None, + latest, + }); + + self.ok = true; // So we sent SubscribeDone on drop + + match track.mode().await? { + // TODO cancel track/datagrams on closed + TrackReaderMode::Stream(stream) => self.serve_track(stream).await, + TrackReaderMode::Groups(groups) => self.serve_groups(groups).await, + TrackReaderMode::Objects(objects) => self.serve_objects(objects).await, + TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + } } - pub async fn serve(mut self, mut track: serve::TrackSubscriber) -> Result<(), SessionError> { - let mut tasks = FuturesUnordered::new(); + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; - self.state.lock_mut().ok(track.latest())?; - let mut done = false; + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + Ok(()) + } + + pub async fn closed(&self) -> Result<(), ServeError> { loop { - tokio::select! { - next = track.next(), if !done => { - let next = match next? 
{ - Some(next) => next, - None => { done = true; continue }, - }; - - match next { - serve::TrackMode::Stream(stream) => return self.serve_track(stream).await, - serve::TrackMode::Group(group) => tasks.push(Self::serve_group(self.clone(), group).boxed()), - serve::TrackMode::Object(object) => tasks.push(Self::serve_object(self.clone(), object).boxed()), - serve::TrackMode::Datagram(datagram) => self.serve_datagram(datagram).await?, - } - }, - task = tasks.next(), if !tasks.is_empty() => task.unwrap()?, - else => return Ok(()), + let notify = { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notify) => notify, + None => return Ok(()), + } }; + + notify.await; } } +} + +impl ops::Deref for Subscribed { + type Target = SubscribeInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} - async fn serve_track(mut self, mut track: serve::StreamSubscriber) -> Result<(), SessionError> { - let stream = self - .session - .webtransport() - .open_uni() - .await - .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; +impl Drop for Subscribed { + fn drop(&mut self) { + let state = self.state.lock(); + let err = state.closed.as_ref().err().cloned().unwrap_or(ServeError::Done); + let max = state.max; + drop(state); // Important to avoid a deadlock + + if self.ok { + self.publisher.send_message(message::SubscribeDone { + id: self.msg.id, + last: max, + code: err.code(), + reason: err.to_string(), + }); + } else { + self.publisher.send_message(message::SubscribeError { + id: self.msg.id, + alias: 0, + code: err.code(), + reason: err.to_string(), + }); + }; + } +} + +impl Subscribed { + async fn serve_track(&mut self, mut track: serve::StreamReader) -> Result<(), SessionError> { + let mut stream = self.publisher.open_uni().await?; + + // TODO figure out u32 vs u64 priority + stream.priority(track.priority as i32); let mut writer = Writer::new(stream); let header: data::Header = data::TrackHeader { subscribe_id: self.msg.id, 
track_alias: self.msg.track_alias, - send_order: track.send_order, + send_order: track.priority, } .into(); @@ -84,47 +180,84 @@ impl Subscribed { log::trace!("sent track header: {:?}", header); - while let Some(object) = track.next().await? { - // TODO support streaming chunks - // TODO check if closed + while let Some(mut group) = track.next().await? { + while let Some(mut object) = group.next().await? { + let header = data::TrackObject { + group_id: object.group_id, + object_id: object.object_id, + size: object.size, + }; - let header = data::TrackObject { - group_id: object.group_id, - object_id: object.object_id, - size: object.payload.len(), - }; + self.state + .lock_mut() + .ok_or(ServeError::Done)? + .update_max(object.group_id, object.object_id)?; - writer.encode(&header).await?; + writer.encode(&header).await?; - log::trace!("sent track object: {:?}", header); + log::trace!("sent track object: {:?}", header); - self.state.lock_mut().update_max(object.group_id, object.object_id)?; - writer.write(&object.payload).await?; + while let Some(chunk) = object.read().await? { + writer.write(&chunk).await?; + log::trace!("sent track payload: {:?}", chunk.len()); + } - log::trace!("sent track payload: {:?}", object.payload.len()); - log::trace!("sent track done"); + log::trace!("sent track done"); + } } Ok(()) } - pub async fn serve_group(mut self, mut group: serve::GroupSubscriber) -> Result<(), SessionError> { - let stream = self - .session - .webtransport() - .open_uni() - .await - .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; - let mut writer = Writer::new(stream); + async fn serve_groups(&mut self, mut groups: serve::GroupsReader) -> Result<(), SessionError> { + let mut tasks = FuturesUnordered::new(); + let mut done: Option> = None; - let header: data::Header = data::GroupHeader { - subscribe_id: self.msg.id, - track_alias: self.msg.track_alias, - group_id: group.id, - send_order: group.send_order, + loop { + tokio::select! 
{ + res = groups.next(), if done.is_none() => match res { + Ok(Some(group)) => { + let header = data::GroupHeader { + subscribe_id: self.msg.id, + track_alias: self.msg.track_alias, + group_id: group.group_id, + send_order: group.priority, + }; + + let publisher = self.publisher.clone(); + let state = self.state.clone(); + let info = group.info.clone(); + + tasks.push(async move { + if let Err(err) = Self::serve_group(header, group, publisher, state).await { + log::warn!("failed to serve group: {:?}, error: {}", info, err); + } + }); + }, + Ok(None) => done = Some(Ok(())), + Err(err) => done = Some(Err(err)), + }, + res = self.closed(), if done.is_none() => done = Some(res), + _ = tasks.next(), if !tasks.is_empty() => {}, + else => return Ok(done.unwrap()?), + } } - .into(); + } + + async fn serve_group( + header: data::GroupHeader, + mut group: serve::GroupReader, + publisher: Publisher, + state: State, + ) -> Result<(), SessionError> { + let mut stream = publisher.open_uni().await?; + // TODO figure out u32 vs u64 priority + stream.priority(group.priority as i32); + + let mut writer = Writer::new(stream); + + let header: data::Header = header.into(); writer.encode(&header).await?; log::trace!("sent group: {:?}", header); @@ -137,7 +270,10 @@ impl Subscribed { writer.encode(&header).await?; - self.state.lock_mut().update_max(group.id, object.object_id)?; + state + .lock_mut() + .ok_or(ServeError::Done)? 
+ .update_max(group.group_id, object.object_id)?; log::trace!("sent group object: {:?}", header); @@ -152,30 +288,65 @@ impl Subscribed { Ok(()) } - pub async fn serve_object(mut self, mut object: serve::ObjectSubscriber) -> Result<(), SessionError> { - let stream = self - .session - .webtransport() - .open_uni() - .await - .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; - let mut writer = Writer::new(stream); + pub async fn serve_objects(&mut self, mut objects: serve::ObjectsReader) -> Result<(), SessionError> { + let mut tasks = FuturesUnordered::new(); + let mut done = None; - let header: data::Header = data::ObjectHeader { - subscribe_id: self.msg.id, - track_alias: self.msg.track_alias, - group_id: object.group_id, - object_id: object.object_id, - send_order: object.send_order, + loop { + tokio::select! { + res = objects.next(), if done.is_none() => match res { + Ok(Some(object)) => { + let header = data::ObjectHeader { + subscribe_id: self.msg.id, + track_alias: self.msg.track_alias, + group_id: object.group_id, + object_id: object.object_id, + send_order: object.priority, + }; + + let publisher = self.publisher.clone(); + let state = self.state.clone(); + let info = object.info.clone(); + + tasks.push(async move { + if let Err(err) = Self::serve_object(header, object, publisher, state).await { + log::warn!("failed to serve object: {:?}, error: {}", info, err); + }; + }); + }, + Ok(None) => done = Some(Ok(())), + Err(err) => done = Some(Err(err)), + }, + _ = tasks.next(), if !tasks.is_empty() => {}, + res = self.closed(), if done.is_none() => done = Some(res), + else => return Ok(done.unwrap()?), + } } - .into(); + } + + async fn serve_object( + header: data::ObjectHeader, + mut object: serve::ObjectReader, + publisher: Publisher, + state: State, + ) -> Result<(), SessionError> { + state + .lock_mut() + .ok_or(ServeError::Done)? 
+ .update_max(object.group_id, object.object_id)?; + + let mut stream = publisher.open_uni().await?; + + // TODO figure out u32 vs u64 priority + stream.priority(object.priority as i32); + let mut writer = Writer::new(stream); + + let header: data::Header = header.into(); writer.encode(&header).await?; log::trace!("sent object: {:?}", header); - self.state.lock_mut().update_max(object.group_id, object.object_id)?; - while let Some(chunk) = object.read().await? { writer.write(&chunk).await?; log::trace!("sent object payload: {:?}", chunk.len()); @@ -186,141 +357,46 @@ impl Subscribed { Ok(()) } - pub async fn serve_datagram(&mut self, datagram: serve::Datagram) -> Result<(), SessionError> { - let datagram = data::Datagram { - subscribe_id: self.msg.id, - track_alias: self.msg.track_alias, - group_id: datagram.group_id, - object_id: datagram.object_id, - send_order: datagram.send_order, - payload: datagram.payload, - }; - - let mut buffer = Vec::with_capacity(datagram.payload.len() + 100); - datagram.encode(&mut buffer)?; - - log::trace!("sent datagram: {:?}", datagram); - - // TODO send the datagram - //self.session.webtransport().send_datagram(&buffer)?; - - self.state - .lock_mut() - .update_max(datagram.group_id, datagram.object_id)?; - - Ok(()) - } - - pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.state.lock_mut().close(err) - } - - pub async fn closed(&self) -> Result<(), ServeError> { - loop { - let notify = { - let state = self.state.lock(); - state.closed.clone()?; - state.changed() + async fn serve_datagrams(&mut self, mut datagrams: serve::DatagramsReader) -> Result<(), SessionError> { + while let Some(datagram) = datagrams.read().await? 
{ + let datagram = data::Datagram { + subscribe_id: self.msg.id, + track_alias: self.msg.track_alias, + group_id: datagram.group_id, + object_id: datagram.object_id, + send_order: datagram.priority, + payload: datagram.payload, }; - notify.await - } - } -} + let mut buffer = bytes::BytesMut::with_capacity(datagram.payload.len() + 100); + datagram.encode(&mut buffer)?; -pub(super) struct SubscribedRecv { - state: WatchWeak>, -} + self.publisher.send_datagram(buffer.into())?; + log::trace!("sent datagram: {:?}", datagram); -impl SubscribedRecv { - pub fn recv_unsubscribe(&mut self) -> Result<(), ServeError> { - if let Some(state) = self.state.upgrade() { - state.lock_mut().close(ServeError::Done)?; + self.state + .lock_mut() + .ok_or(ServeError::Done)? + .update_max(datagram.group_id, datagram.object_id)?; } + Ok(()) } } -struct State { - session: Publisher, - id: u64, - - ok: bool, - max: Option<(u64, u64)>, - closed: Result<(), ServeError>, -} - -impl State { - fn new(session: Publisher, id: u64) -> Self { - Self { - session, - id, - ok: false, - max: None, - closed: Ok(()), - } - } +pub(super) struct SubscribedRecv { + state: State, } -impl State { - fn ok(&mut self, latest: Option<(u64, u64)>) -> Result<(), ServeError> { - self.ok = true; - self.max = latest; - - self.session - .send_message(message::SubscribeOk { - id: self.id, - expires: None, - latest, - }) - .ok(); - - Ok(()) - } - - fn close(&mut self, err: ServeError) -> Result<(), ServeError> { - self.closed.clone()?; - self.closed = Err(err.clone()); - - if self.ok { - self.session - .send_message(message::SubscribeDone { - id: self.id, - last: self.max, - code: err.code(), - reason: err.to_string(), - }) - .ok(); - } else { - self.session - .send_message(message::SubscribeError { - id: self.id, - alias: 0, - code: err.code(), - reason: err.to_string(), - }) - .ok(); - } - - Ok(()) - } - - fn update_max(&mut self, group_id: u64, object_id: u64) -> Result<(), ServeError> { - self.closed.clone()?; +impl 
SubscribedRecv { + pub fn recv_unsubscribe(&mut self) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; - if let Some((max_group, max_object)) = self.max { - if group_id >= max_group && object_id >= max_object { - self.max = Some((group_id, object_id)); - } + if let Some(mut state) = state.into_mut() { + state.closed = Err(ServeError::Cancel); } Ok(()) } } - -impl Drop for State { - fn drop(&mut self) { - self.close(ServeError::Done).ok(); - self.session.drop_subscribe(self.id); - } -} diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index d150c33c..e5a9ce26 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -5,27 +5,30 @@ use std::{ }; use crate::{ - coding::{Decode, Reader}, - data, message, serve, setup, + coding::Decode, + data, + message::{self, Message}, + serve::{self, ServeError}, + setup, util::Queue, }; -use super::{Announced, AnnouncedRecv, Session, SessionError, Subscribe}; +use super::{Announced, AnnouncedRecv, Reader, Session, SessionError, Subscribe, SubscribeRecv}; // TODO remove Clone. 
#[derive(Clone)] pub struct Subscriber { - announced: Arc>>>, - announced_queue: Queue, SessionError>, + announced: Arc>>, + announced_queue: Queue>, - subscribes: Arc>>>, + subscribes: Arc>>, subscribe_next: Arc, - outgoing: Queue, + outgoing: Queue, } impl Subscriber { - pub(super) fn new(outgoing: Queue) -> Self { + pub(super) fn new(outgoing: Queue) -> Self { Self { announced: Default::default(), announced_queue: Default::default(), @@ -45,51 +48,50 @@ impl Subscriber { Ok((session, subscriber.unwrap())) } - pub async fn announced(&mut self) -> Result, SessionError> { + pub async fn announced(&mut self) -> Announced { self.announced_queue.pop().await } - pub fn subscribe(&mut self, track: serve::TrackPublisher) -> Result<(), SessionError> { + pub fn subscribe(&mut self, track: serve::TrackWriter) -> Result, ServeError> { let id = self.subscribe_next.fetch_add(1, atomic::Ordering::Relaxed); - let msg = message::Subscribe { - id, - track_alias: id, - track_namespace: track.namespace.to_string(), - track_name: track.name.to_string(), - // TODO add these to the publisher. - start: Default::default(), - end: Default::default(), - params: Default::default(), - }; - - self.send_message(msg.clone())?; + let (send, recv) = Subscribe::new(self.clone(), id, track); + self.subscribes.lock().unwrap().insert(id, recv); - let publisher = Subscribe::new(self.clone(), msg.id, track); - self.subscribes.lock().unwrap().insert(id, publisher); - - Ok(()) + Ok(send) } - pub(super) fn send_message>(&mut self, msg: M) -> Result<(), SessionError> { + pub(super) fn send_message>(&mut self, msg: M) { let msg = msg.into(); - log::debug!("sending message: {:?}", msg); - self.outgoing.push(msg.into()) + + // Remove our entry on terminal state. 
+ match &msg { + message::Subscriber::AnnounceCancel(msg) => self.drop_announce(&msg.namespace), + message::Subscriber::AnnounceError(msg) => self.drop_announce(&msg.namespace), + _ => {} + } + + self.outgoing.push(msg.into()); } pub(super) fn recv_message(&mut self, msg: message::Publisher) -> Result<(), SessionError> { - log::debug!("received message: {:?}", msg); - - match msg { + let res = match &msg { message::Publisher::Announce(msg) => self.recv_announce(msg), message::Publisher::Unannounce(msg) => self.recv_unannounce(msg), message::Publisher::SubscribeOk(msg) => self.recv_subscribe_ok(msg), message::Publisher::SubscribeError(msg) => self.recv_subscribe_error(msg), message::Publisher::SubscribeDone(msg) => self.recv_subscribe_done(msg), + }; + + if let Err(SessionError::Serve(err)) = res { + log::debug!("failed to process message: {:?} {}", msg, err); + return Ok(()); } + + res } - fn recv_announce(&mut self, msg: message::Announce) -> Result<(), SessionError> { + fn recv_announce(&mut self, msg: &message::Announce) -> Result<(), SessionError> { let mut announces = self.announced.lock().unwrap(); let entry = match announces.entry(msg.namespace.clone()) { @@ -97,83 +99,174 @@ impl Subscriber { hash_map::Entry::Vacant(entry) => entry, }; - let (announced, recv) = Announced::new(self.clone(), msg.namespace); - self.announced_queue.push(announced)?; + let (announced, recv) = Announced::new(self.clone(), msg.namespace.to_string()); + self.announced_queue.push(announced); entry.insert(recv); Ok(()) } - fn recv_unannounce(&mut self, msg: message::Unannounce) -> Result<(), SessionError> { - if let Some(announce) = self.announced.lock().unwrap().get_mut(&msg.namespace) { - announce.recv_unannounce().ok(); + fn recv_unannounce(&mut self, msg: &message::Unannounce) -> Result<(), SessionError> { + if let Some(announce) = self.announced.lock().unwrap().remove(&msg.namespace) { + announce.recv_unannounce()?; } Ok(()) } - fn recv_subscribe_ok(&mut self, msg: 
message::SubscribeOk) -> Result<(), SessionError> { - if let Some(sub) = self.subscribes.lock().unwrap().get_mut(&msg.id) { - sub.recv_ok(msg).ok(); + fn recv_subscribe_ok(&mut self, msg: &message::SubscribeOk) -> Result<(), SessionError> { + if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&msg.id) { + subscribe.ok()?; } Ok(()) } - fn recv_subscribe_error(&mut self, msg: message::SubscribeError) -> Result<(), SessionError> { - if let Some(subscriber) = self.subscribes.lock().unwrap().get_mut(&msg.id) { - subscriber.recv_error(msg.code).ok(); + fn recv_subscribe_error(&mut self, msg: &message::SubscribeError) -> Result<(), SessionError> { + if let Some(subscribe) = self.subscribes.lock().unwrap().remove(&msg.id) { + subscribe.error(ServeError::Closed(msg.code))?; } Ok(()) } - fn recv_subscribe_done(&mut self, msg: message::SubscribeDone) -> Result<(), SessionError> { - if let Some(subscriber) = self.subscribes.lock().unwrap().get_mut(&msg.id) { - subscriber.recv_done(msg.code).ok(); + fn recv_subscribe_done(&mut self, msg: &message::SubscribeDone) -> Result<(), SessionError> { + if let Some(subscribe) = self.subscribes.lock().unwrap().remove(&msg.id) { + subscribe.error(ServeError::Closed(msg.code))?; } Ok(()) } - pub(super) fn drop_subscribe(&mut self, id: u64) { - self.subscribes.lock().unwrap().remove(&id); - } - - pub(super) fn drop_announce(&mut self, namespace: &str) { + fn drop_announce(&mut self, namespace: &str) { self.announced.lock().unwrap().remove(namespace); } - pub(super) async fn recv_stream(self, stream: S::RecvStream) -> Result<(), SessionError> { + pub(super) async fn recv_stream(mut self, stream: S::RecvStream) -> Result<(), SessionError> { let mut reader = Reader::new(stream); let header: data::Header = reader.decode().await?; let id = header.subscribe_id(); - let subscribe = self.subscribes.lock().unwrap().get(&id).cloned(); - if let Some(mut subscribe) = subscribe { - subscribe.recv_stream(header, reader).await? 
+ let res = self.recv_stream_inner(reader, header).await; + if let Err(SessionError::Serve(err)) = &res { + // The writer is closed, so we should terminate. + // TODO it would be nice to do this immediately when the Writer is closed. + if let Some(subscribe) = self.subscribes.lock().unwrap().remove(&id) { + subscribe.error(err.clone())?; + } } + res + } + + async fn recv_stream_inner( + &mut self, + reader: Reader, + header: data::Header, + ) -> Result<(), SessionError> { + let id = header.subscribe_id(); + + // This is super silly, but I couldn't figure out a way to avoid the mutex guard across awaits. + enum Writer { + Track(serve::StreamWriter), + Group(serve::GroupWriter), + Object(serve::ObjectWriter), + } + + let writer = { + let mut subscribes = self.subscribes.lock().unwrap(); + let subscribe = subscribes.get_mut(&id).ok_or(ServeError::NotFound)?; + + match header { + data::Header::Track(track) => Writer::Track(subscribe.track(track)?), + data::Header::Group(group) => Writer::Group(subscribe.group(group)?), + data::Header::Object(object) => Writer::Object(subscribe.object(object)?), + } + }; + + match writer { + Writer::Track(track) => Self::recv_track(track, reader).await?, + Writer::Group(group) => Self::recv_group(group, reader).await?, + Writer::Object(object) => Self::recv_object(object, reader).await?, + }; + Ok(()) } - // TODO should not be async - pub async fn recv_datagram(&mut self, datagram: bytes::Bytes) -> Result<(), SessionError> { - let mut cursor = io::Cursor::new(datagram); - let datagram = data::Datagram::decode(&mut cursor)?; + async fn recv_track(mut track: serve::StreamWriter, mut reader: Reader) -> Result<(), SessionError> { + log::trace!("received track: {:?}", track.info); - let subscribe = self.subscribes.lock().unwrap().get(&datagram.subscribe_id).cloned(); + let mut prev: Option = None; - if let Some(subscribe) = subscribe { - subscribe.recv_datagram(datagram)?; + while !reader.done().await? 
{ + let chunk: data::TrackObject = reader.decode().await?; + + let mut group = match prev { + Some(group) if group.group_id == chunk.group_id => group, + _ => track.create(chunk.group_id)?, + }; + + let mut object = group.create(chunk.size)?; + + let mut remain = chunk.size; + while remain > 0 { + let chunk = reader.read_chunk(remain).await?.ok_or(SessionError::WrongSize)?; + + log::trace!("received track payload: {:?}", chunk.len()); + remain -= chunk.len(); + object.write(chunk)?; + } + + prev = Some(group); + } + + Ok(()) + } + + async fn recv_group(mut group: serve::GroupWriter, mut reader: Reader) -> Result<(), SessionError> { + log::trace!("received group: {:?}", group.info); + + while !reader.done().await? { + let object: data::GroupObject = reader.decode().await?; + + log::trace!("received group object: {:?}", object); + let mut remain = object.size; + let mut object = group.create(object.size)?; + + while remain > 0 { + let data = reader.read_chunk(remain).await?.ok_or(SessionError::WrongSize)?; + log::trace!("received group payload: {:?}", data.len()); + remain -= data.len(); + object.write(data)?; + } } Ok(()) } - pub fn close(self, err: SessionError) { - self.outgoing.close(err.clone()).ok(); - self.announced_queue.close(err).ok(); + async fn recv_object( + mut object: serve::ObjectWriter, + mut reader: Reader, + ) -> Result<(), SessionError> { + log::trace!("received object: {:?}", object.info); + + while let Some(data) = reader.read_chunk(usize::MAX).await? 
{ + log::trace!("received object payload: {:?}", data.len()); + object.write(data)?; + } + + Ok(()) + } + + pub fn recv_datagram(&mut self, datagram: bytes::Bytes) -> Result<(), SessionError> { + let mut cursor = io::Cursor::new(datagram); + let datagram = data::Datagram::decode(&mut cursor)?; + + if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&datagram.subscribe_id) { + subscribe.datagram(datagram)?; + } + + Ok(()) } } diff --git a/moq-transport/src/session/writer.rs b/moq-transport/src/session/writer.rs new file mode 100644 index 00000000..131de74c --- /dev/null +++ b/moq-transport/src/session/writer.rs @@ -0,0 +1,51 @@ +use std::io; + +use crate::coding::{Encode, EncodeError}; + +use super::SessionError; +use bytes::Buf; + +pub struct Writer { + stream: S, + buffer: bytes::BytesMut, +} + +impl Writer { + pub fn new(stream: S) -> Self { + Self { + stream, + buffer: Default::default(), + } + } + + pub async fn encode(&mut self, msg: &T) -> Result<(), SessionError> { + self.buffer.clear(); + msg.encode(&mut self.buffer)?; + + while !self.buffer.is_empty() { + self.stream + .write_buf(&mut self.buffer) + .await + .map_err(SessionError::from_write)?; + } + + Ok(()) + } + + pub async fn write(&mut self, buf: &[u8]) -> Result<(), SessionError> { + let mut cursor = io::Cursor::new(buf); + + while cursor.has_remaining() { + let size = self + .stream + .write_buf(&mut cursor) + .await + .map_err(SessionError::from_write)?; + if size == 0 { + return Err(EncodeError::More(cursor.remaining()).into()); + } + } + + Ok(()) + } +} diff --git a/moq-transport/src/util/mod.rs b/moq-transport/src/util/mod.rs index 93a76691..0263f841 100644 --- a/moq-transport/src/util/mod.rs +++ b/moq-transport/src/util/mod.rs @@ -1,5 +1,7 @@ mod queue; +mod state; mod watch; pub use queue::*; +pub use state::*; pub use watch::*; diff --git a/moq-transport/src/util/queue.rs b/moq-transport/src/util/queue.rs index 6e0e890e..26290192 100644 --- a/moq-transport/src/util/queue.rs +++ 
b/moq-transport/src/util/queue.rs @@ -2,11 +2,12 @@ use std::collections::VecDeque; use super::Watch; -pub struct Queue { - state: Watch>, +// TODO replace with mpsc or similar +pub struct Queue { + state: Watch>, } -impl Clone for Queue { +impl Clone for Queue { fn clone(&self) -> Self { Self { state: self.state.clone(), @@ -14,7 +15,7 @@ impl Clone for Queue { } } -impl Default for Queue { +impl Default for Queue { fn default() -> Self { Self { state: Default::default(), @@ -22,48 +23,22 @@ impl Default for Queue { } } -struct State { - queue: VecDeque, - closed: Result<(), E>, -} - -impl Default for State { - fn default() -> Self { - Self { - queue: Default::default(), - closed: Ok(()), - } +impl Queue { + pub fn push(&self, item: T) { + self.state.lock_mut().push_back(item); } -} -impl Queue { - pub fn push(&self, item: T) -> Result<(), E> { - let mut state = self.state.lock_mut(); - state.closed.clone()?; - state.queue.push_back(item); - Ok(()) - } - - pub async fn pop(&self) -> Result { + pub async fn pop(&self) -> T { loop { let notify = { - let state = self.state.lock(); - state.closed.clone()?; - - if !state.queue.is_empty() { - return Ok(state.into_mut().queue.pop_front().unwrap()); + let queue = self.state.lock(); + if !queue.is_empty() { + return queue.into_mut().pop_front().unwrap(); } - state.changed() + queue.changed() }; notify.await } } - - pub fn close(&self, err: E) -> Result<(), E> { - let mut state = self.state.lock_mut(); - state.closed.clone()?; - state.closed = Err(err); - Ok(()) - } } diff --git a/moq-transport/src/util/state.rs b/moq-transport/src/util/state.rs new file mode 100644 index 00000000..905a378f --- /dev/null +++ b/moq-transport/src/util/state.rs @@ -0,0 +1,249 @@ +use std::{ + fmt, + future::Future, + ops::{Deref, DerefMut}, + pin::Pin, + sync::{Arc, Mutex, MutexGuard, Weak}, + task, +}; + +struct StateInner { + value: T, + wakers: Vec, + epoch: usize, + dropped: Option<()>, +} + +impl StateInner { + pub fn new(value: T) -> 
Self { + Self { + value, + wakers: Vec::new(), + epoch: 0, + dropped: Some(()), + } + } + + pub fn register(&mut self, waker: &task::Waker) { + self.wakers.retain(|existing| !existing.will_wake(waker)); + self.wakers.push(waker.clone()); + } + + pub fn notify(&mut self) { + self.epoch += 1; + for waker in self.wakers.drain(..) { + waker.wake(); + } + } +} + +impl Default for StateInner { + fn default() -> Self { + Self::new(T::default()) + } +} + +impl fmt::Debug for StateInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.value.fmt(f) + } +} + +pub struct State { + state: Arc>>, + drop: Arc>, +} + +impl State { + pub fn new(initial: T) -> (Self, Self) { + let state = Arc::new(Mutex::new(StateInner::new(initial))); + + ( + Self { + state: state.clone(), + drop: Arc::new(StateDrop { state: state.clone() }), + }, + Self { + state: state.clone(), + drop: Arc::new(StateDrop { state }), + }, + ) + } + + pub fn lock(&self) -> StateRef { + StateRef { + state: self.state.clone(), + drop: self.drop.clone(), + lock: self.state.lock().unwrap(), + } + } + + pub fn lock_mut(&self) -> Option> { + let lock = self.state.lock().unwrap(); + lock.dropped?; + Some(StateMut { + lock, + _drop: self.drop.clone(), + }) + } + + pub fn downgrade(&self) -> StateWeak { + StateWeak { + state: Arc::downgrade(&self.state), + drop: Arc::downgrade(&self.drop), + } + } +} + +impl Clone for State { + fn clone(&self) -> Self { + Self { + state: self.state.clone(), + drop: self.drop.clone(), + } + } +} + +impl State { + pub fn init() -> (Self, Self) { + Self::new(T::default()) + } +} + +impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.state.try_lock() { + Ok(lock) => lock.value.fmt(f), + Err(_) => write!(f, ""), + } + } +} + +pub struct StateRef<'a, T> { + state: Arc>>, + lock: MutexGuard<'a, StateInner>, + drop: Arc>, +} + +impl<'a, T> StateRef<'a, T> { + // Release the lock and wait for a notification when next updated. 
+ pub fn modified(self) -> Option> { + self.lock.dropped?; + + Some(StateChanged { + state: self.state, + epoch: self.lock.epoch, + }) + } + + // Upgrade to a mutable reference that automatically calls notify on drop. + pub fn into_mut(self) -> Option> { + self.lock.dropped?; + Some(StateMut { + lock: self.lock, + _drop: self.drop, + }) + } +} + +impl<'a, T> Deref for StateRef<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.lock.value + } +} + +impl<'a, T: fmt::Debug> fmt::Debug for StateRef<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.lock.fmt(f) + } +} + +pub struct StateMut<'a, T> { + lock: MutexGuard<'a, StateInner>, + _drop: Arc>, +} + +impl<'a, T> Deref for StateMut<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.lock.value + } +} + +impl<'a, T> DerefMut for StateMut<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.lock.value + } +} + +impl<'a, T> Drop for StateMut<'a, T> { + fn drop(&mut self) { + self.lock.notify(); + } +} + +impl<'a, T: fmt::Debug> fmt::Debug for StateMut<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.lock.fmt(f) + } +} + +pub struct StateChanged { + state: Arc>>, + epoch: usize, +} + +impl Future for StateChanged { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { + // TODO is there an API we can make that doesn't drop this lock? 
+ let mut state = self.state.lock().unwrap(); + + if state.epoch > self.epoch { + task::Poll::Ready(()) + } else { + state.register(cx.waker()); + task::Poll::Pending + } + } +} + +pub struct StateWeak { + state: Weak>>, + drop: Weak>, +} + +impl StateWeak { + pub fn upgrade(&self) -> Option> { + if let (Some(state), Some(drop)) = (self.state.upgrade(), self.drop.upgrade()) { + Some(State { state, drop }) + } else { + None + } + } +} + +impl Clone for StateWeak { + fn clone(&self) -> Self { + Self { + state: self.state.clone(), + drop: self.drop.clone(), + } + } +} + +struct StateDrop { + state: Arc>>, +} + +impl Drop for StateDrop { + fn drop(&mut self) { + let mut state = self.state.lock().unwrap(); + state.dropped = None; + state.notify(); + } +}