diff --git a/iroh/examples/connect-unreliable.rs b/iroh/examples/connect-unreliable.rs index ce1825d935..b908438360 100644 --- a/iroh/examples/connect-unreliable.rs +++ b/iroh/examples/connect-unreliable.rs @@ -8,7 +8,7 @@ use std::net::SocketAddr; use clap::Parser; -use iroh::{Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; +use iroh::{watchable::Watcher as _, Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; use tracing::info; // An example ALPN that we are using to communicate over the `Endpoint` @@ -50,7 +50,7 @@ async fn main() -> anyhow::Result<()> { .bind() .await?; - let node_addr = endpoint.node_addr().await?; + let node_addr = endpoint.node_addr()?.initialized().await?; let me = node_addr.node_id; println!("node id: {me}"); println!("node listening addresses:"); diff --git a/iroh/examples/listen-unreliable.rs b/iroh/examples/listen-unreliable.rs index 552c24daf7..84ab987e27 100644 --- a/iroh/examples/listen-unreliable.rs +++ b/iroh/examples/listen-unreliable.rs @@ -3,7 +3,7 @@ //! This example uses the default relay servers to attempt to holepunch, and will use that relay server to relay packets if the two devices cannot establish a direct UDP connection. //! run this example from the project root: //! $ cargo run --example listen-unreliable -use iroh::{Endpoint, RelayMode, SecretKey}; +use iroh::{watchable::Watcher as _, Endpoint, RelayMode, SecretKey}; use tracing::{info, warn}; // An example ALPN that we are using to communicate over the `Endpoint` @@ -35,7 +35,7 @@ async fn main() -> anyhow::Result<()> { println!("node id: {me}"); println!("node listening addresses:"); - let node_addr = endpoint.node_addr().await?; + let node_addr = endpoint.node_addr()?.initialized().await?; let local_addrs = node_addr .direct_addresses .into_iter() diff --git a/iroh/examples/listen.rs b/iroh/examples/listen.rs index 23199246ae..e6f257f29d 100644 --- a/iroh/examples/listen.rs +++ b/iroh/examples/listen.rs @@ -5,7 +5,7 @@ //! $ cargo run --example listen use std::time::Duration; -use iroh::{endpoint::ConnectionError, Endpoint, RelayMode, SecretKey}; +use iroh::{endpoint::ConnectionError, watchable::Watcher as _, Endpoint, RelayMode, SecretKey}; use tracing::{debug, info, warn}; // An example ALPN that we are using to communicate over the `Endpoint` @@ -37,7 +37,7 @@ async fn main() -> anyhow::Result<()> { println!("node id: {me}"); println!("node listening addresses:"); - let node_addr = endpoint.node_addr().await?; + let node_addr = endpoint.node_addr()?.initialized().await?; let local_addrs = node_addr .direct_addresses .into_iter() diff --git a/iroh/src/discovery.rs b/iroh/src/discovery.rs index d860e37532..a46359fd7a 100644 --- a/iroh/src/discovery.rs +++ b/iroh/src/discovery.rs @@ -450,7 +450,7 @@ mod tests { use tokio_util::task::AbortOnDropHandle; use super::*; - use crate::RelayMode; + use crate::{watchable::Watcher as _, RelayMode}; type InfoStore = HashMap, BTreeSet, u64)>; @@ -580,7 +580,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for our address to be updated and thus published at least once - ep1.node_addr().await?; + ep1.node_addr()?.initialized().await?; let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?; Ok(()) } @@ -606,7 +606,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for out address to be updated and thus published at least once - ep1.node_addr().await?; + ep1.node_addr()?.initialized().await?; let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?; Ok(()) } @@ -636,7 +636,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for out address to be updated and thus published at least once - ep1.node_addr().await?; + ep1.node_addr()?.initialized().await?; let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?; Ok(()) } @@ -659,7 +659,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for out address to be updated and thus published at least once - ep1.node_addr().await?; + ep1.node_addr()?.initialized().await?; let res = ep2.connect(ep1_addr, TEST_ALPN).await; assert!(res.is_err()); Ok(()) @@ -682,7 +682,7 @@ mod tests { new_endpoint(secret, disco).await }; // wait for out address to be updated and thus published at least once - ep1.node_addr().await?; + ep1.node_addr()?.initialized().await?; let ep1_wrong_addr = NodeAddr { node_id: ep1.node_id(), relay_url: None, diff --git a/iroh/src/endpoint.rs b/iroh/src/endpoint.rs index d014648cd6..6fd9673168 100644 --- a/iroh/src/endpoint.rs +++ b/iroh/src/endpoint.rs @@ -38,7 +38,7 @@ use crate::{ dns::{default_resolver, DnsResolver}, magicsock::{self, Handle, QuicMappedAddr}, tls, - watchable::{DirectWatcher, Watcher as _}, + watchable::{DirectWatcher, Watcher}, }; mod rtt_actor; @@ -759,29 +759,27 @@ impl Endpoint { /// The returned [`NodeAddr`] will have the current [`RelayUrl`] and direct addresses /// as they would be returned by [`Endpoint::home_relay`] and /// [`Endpoint::direct_addresses`]. - pub async fn node_addr(&self) -> Result { - let mut watch_addrs = self.direct_addresses(); - let mut watch_relay = self.home_relay(); - tokio::select! { - addrs = watch_addrs.initialized() => { - let addrs = addrs?; - let relay = self.home_relay().get()?; - Ok(NodeAddr::from_parts( - self.node_id(), - relay, - addrs.into_iter().map(|x| x.addr), - )) - }, - relay = watch_relay.initialized() => { - let relay = relay?; - let addrs = self.direct_addresses().get()?.unwrap_or_default(); - Ok(NodeAddr::from_parts( - self.node_id(), - Some(relay), - addrs.into_iter().map(|x| x.addr), - )) - }, - } + pub fn node_addr(&self) -> Result>> { + let watch_addrs = self.direct_addresses(); + let watch_relay = self.home_relay(); + let node_id = self.node_id(); + let watcher = + watch_addrs + .or(watch_relay) + .map(move |(addrs, relay)| match (addrs, relay) { + (Some(addrs), relay) => Some(NodeAddr::from_parts( + node_id, + relay, + addrs.into_iter().map(|x| x.addr), + )), + (None, Some(relay)) => Some(NodeAddr::from_parts( + node_id, + Some(relay), + std::iter::empty(), + )), + (None, None) => None, + })?; + Ok(watcher) } /// Returns a [`Watcher`] for the [`RelayUrl`] of the Relay server used as home relay. @@ -1447,7 +1445,7 @@ mod tests { .bind() .await .unwrap(); - let my_addr = ep.node_addr().await.unwrap(); + let my_addr = ep.node_addr().unwrap().initialized().await.unwrap(); let res = ep.connect(my_addr.clone(), TEST_ALPN).await; assert!(res.is_err()); let err = res.err().unwrap(); @@ -1729,8 +1727,8 @@ mod tests { .bind() .await .unwrap(); - let ep1_nodeaddr = ep1.node_addr().await.unwrap(); - let ep2_nodeaddr = ep2.node_addr().await.unwrap(); + let ep1_nodeaddr = ep1.node_addr().unwrap().initialized().await.unwrap(); + let ep2_nodeaddr = ep2.node_addr().unwrap().initialized().await.unwrap(); ep1.add_node_addr(ep2_nodeaddr.clone()).unwrap(); ep2.add_node_addr(ep1_nodeaddr.clone()).unwrap(); let ep1_nodeid = ep1.node_id(); @@ -1853,7 +1851,7 @@ mod tests { let ep1_nodeid = ep1.node_id(); let ep2_nodeid = ep2.node_id(); - let ep1_nodeaddr = ep1.node_addr().await.unwrap(); + let ep1_nodeaddr = ep1.node_addr().unwrap().initialized().await.unwrap(); tracing::info!( "node id 1 {ep1_nodeid}, relay URL {:?}", ep1_nodeaddr.relay_url() diff --git a/iroh/src/magicsock.rs b/iroh/src/magicsock.rs index 5a1f41c651..bd0fc05bf3 100644 --- a/iroh/src/magicsock.rs +++ b/iroh/src/magicsock.rs @@ -3192,7 +3192,7 @@ mod tests { println!("first conn!"); let conn = m1 .endpoint - .connect(m2.endpoint.node_addr().await?, ALPN) + .connect(m2.endpoint.node_addr()?.initialized().await?, ALPN) .await?; println!("Closing first conn"); conn.close(0u32.into(), b"bye lolz"); diff --git a/iroh/src/watchable.rs b/iroh/src/watchable.rs index d1894e63d2..bd9fffe565 100644 --- a/iroh/src/watchable.rs +++ b/iroh/src/watchable.rs @@ -101,20 +101,6 @@ impl Watchable { } } -/// An observer for a value. -/// -/// The [`Watcher`] can get the current value, and will be notified when the value changes. -/// Only the most recent value is accessible, and if the thread with the [`Watchable`] -/// changes the value faster than the thread with the [`Watcher`] can keep up with, then -/// it'll miss in-between values. -/// When the thread changing the [`Watchable`] pauses updating, the [`Watcher`] will always -/// end up reporting the most recent state eventually. -#[derive(Debug, Clone)] -pub struct DirectWatcher { - epoch: u64, - shared: Weak>, -} - /// A handle to a value that's represented by one or more underlying [`Watchable`]s. /// /// This handle allows one to observe the latest state and be notified @@ -217,6 +203,38 @@ pub trait Watcher: Clone { watcher: self, } } + + /// Maps this watcher with a function that transforms the observed values. + fn map T>( + self, + map: F, + ) -> Result, Disconnected> { + Ok(MapWatcher { + current: (map)(self.get()?), + map, + watcher: self, + }) + } + + /// Returns a watcher that updates every time this or the other watcher + /// updates, and yields both watcher's items together when that happens. + fn or(self, other: W) -> OrWatcher { + OrWatcher(self, other) + } +} + +/// An observer for a value. +/// +/// The [`Watcher`] can get the current value, and will be notified when the value changes. +/// Only the most recent value is accessible, and if the thread with the [`Watchable`] +/// changes the value faster than the thread with the [`Watcher`] can keep up with, then +/// it'll miss in-between values. +/// When the thread changing the [`Watchable`] pauses updating, the [`Watcher`] will always +/// end up reporting the most recent state eventually. +#[derive(Debug, Clone)] +pub struct DirectWatcher { + epoch: u64, + shared: Weak>, } impl Watcher for DirectWatcher { @@ -245,10 +263,13 @@ impl Watcher for DirectWatcher { } /// Combines two [`Watcher`]s into a single watcher. +/// +/// This watcher updates when one of the inner watchers +/// is updated at least once. #[derive(Clone, Debug)] -pub struct Or(S, T); +pub struct OrWatcher(S, T); -impl Watcher for Or { +impl Watcher for OrWatcher { type Value = (S::Value, T::Value); fn get(&self) -> Result { @@ -270,6 +291,38 @@ impl Watcher for Or { } } +/// Maps a [`Watcher`] and allows filtering updates. +#[derive(Clone, Debug)] +pub struct MapWatcher T> { + map: F, + watcher: W, + current: T, +} + +impl T> Watcher for MapWatcher { + type Value = T; + + fn get(&self) -> Result { + Ok((self.map)(self.watcher.get()?)) + } + + fn poll_updated( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + loop { + let value = futures_lite::ready!(self.watcher.poll_updated(cx)?); + let mapped = (self.map)(value); + if mapped != self.current { + self.current = mapped.clone(); + return Poll::Ready(Ok(mapped)); + } else { + self.current = mapped; + } + } + } +} + /// Future returning the next item after the current one in a [`Watcher`]. /// /// See [`Watcher::updated`].