Skip to content

Commit

Permalink
Introduce MapWatcher and make Endpoint::node_addr return `impl Wa…
Browse files Browse the repository at this point in the history
…tcher`
  • Loading branch information
matheus23 committed Dec 13, 2024
1 parent 5cf71b5 commit 2529cde
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 57 deletions.
4 changes: 2 additions & 2 deletions iroh/examples/connect-unreliable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
use std::net::SocketAddr;

use clap::Parser;
use iroh::{Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey};
use iroh::{watchable::Watcher as _, Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey};
use tracing::info;

// An example ALPN that we are using to communicate over the `Endpoint`
Expand Down Expand Up @@ -50,7 +50,7 @@ async fn main() -> anyhow::Result<()> {
.bind()
.await?;

let node_addr = endpoint.node_addr().await?;
let node_addr = endpoint.node_addr()?.initialized().await?;
let me = node_addr.node_id;
println!("node id: {me}");
println!("node listening addresses:");
Expand Down
4 changes: 2 additions & 2 deletions iroh/examples/listen-unreliable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! This example uses the default relay servers to attempt to holepunch, and will use that relay server to relay packets if the two devices cannot establish a direct UDP connection.
//! run this example from the project root:
//! $ cargo run --example listen-unreliable
use iroh::{Endpoint, RelayMode, SecretKey};
use iroh::{watchable::Watcher as _, Endpoint, RelayMode, SecretKey};
use tracing::{info, warn};

// An example ALPN that we are using to communicate over the `Endpoint`
Expand Down Expand Up @@ -35,7 +35,7 @@ async fn main() -> anyhow::Result<()> {
println!("node id: {me}");
println!("node listening addresses:");

let node_addr = endpoint.node_addr().await?;
let node_addr = endpoint.node_addr()?.initialized().await?;
let local_addrs = node_addr
.direct_addresses
.into_iter()
Expand Down
4 changes: 2 additions & 2 deletions iroh/examples/listen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//! $ cargo run --example listen
use std::time::Duration;

use iroh::{endpoint::ConnectionError, Endpoint, RelayMode, SecretKey};
use iroh::{endpoint::ConnectionError, watchable::Watcher as _, Endpoint, RelayMode, SecretKey};
use tracing::{debug, info, warn};

// An example ALPN that we are using to communicate over the `Endpoint`
Expand Down Expand Up @@ -37,7 +37,7 @@ async fn main() -> anyhow::Result<()> {
println!("node id: {me}");
println!("node listening addresses:");

let node_addr = endpoint.node_addr().await?;
let node_addr = endpoint.node_addr()?.initialized().await?;
let local_addrs = node_addr
.direct_addresses
.into_iter()
Expand Down
12 changes: 6 additions & 6 deletions iroh/src/discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ mod tests {
use tokio_util::task::AbortOnDropHandle;

use super::*;
use crate::RelayMode;
use crate::{watchable::Watcher as _, RelayMode};

type InfoStore = HashMap<NodeId, (Option<RelayUrl>, BTreeSet<SocketAddr>, u64)>;

Expand Down Expand Up @@ -580,7 +580,7 @@ mod tests {
};
let ep1_addr = NodeAddr::new(ep1.node_id());
// wait for our address to be updated and thus published at least once
ep1.node_addr().await?;
ep1.node_addr()?.initialized().await?;
let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?;
Ok(())
}
Expand All @@ -606,7 +606,7 @@ mod tests {
};
let ep1_addr = NodeAddr::new(ep1.node_id());
// wait for out address to be updated and thus published at least once
ep1.node_addr().await?;
ep1.node_addr()?.initialized().await?;
let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?;
Ok(())
}
Expand Down Expand Up @@ -636,7 +636,7 @@ mod tests {
};
let ep1_addr = NodeAddr::new(ep1.node_id());
// wait for out address to be updated and thus published at least once
ep1.node_addr().await?;
ep1.node_addr()?.initialized().await?;
let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?;
Ok(())
}
Expand All @@ -659,7 +659,7 @@ mod tests {
};
let ep1_addr = NodeAddr::new(ep1.node_id());
// wait for out address to be updated and thus published at least once
ep1.node_addr().await?;
ep1.node_addr()?.initialized().await?;
let res = ep2.connect(ep1_addr, TEST_ALPN).await;
assert!(res.is_err());
Ok(())
Expand All @@ -682,7 +682,7 @@ mod tests {
new_endpoint(secret, disco).await
};
// wait for out address to be updated and thus published at least once
ep1.node_addr().await?;
ep1.node_addr()?.initialized().await?;
let ep1_wrong_addr = NodeAddr {
node_id: ep1.node_id(),
relay_url: None,
Expand Down
54 changes: 26 additions & 28 deletions iroh/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::{
dns::{default_resolver, DnsResolver},
magicsock::{self, Handle, QuicMappedAddr},
tls,
watchable::{DirectWatcher, Watcher as _},
watchable::{DirectWatcher, Watcher},
};

mod rtt_actor;
Expand Down Expand Up @@ -759,29 +759,27 @@ impl Endpoint {
/// The returned [`NodeAddr`] will have the current [`RelayUrl`] and direct addresses
/// as they would be returned by [`Endpoint::home_relay`] and
/// [`Endpoint::direct_addresses`].
pub async fn node_addr(&self) -> Result<NodeAddr> {
let mut watch_addrs = self.direct_addresses();
let mut watch_relay = self.home_relay();
tokio::select! {
addrs = watch_addrs.initialized() => {
let addrs = addrs?;
let relay = self.home_relay().get()?;
Ok(NodeAddr::from_parts(
self.node_id(),
relay,
addrs.into_iter().map(|x| x.addr),
))
},
relay = watch_relay.initialized() => {
let relay = relay?;
let addrs = self.direct_addresses().get()?.unwrap_or_default();
Ok(NodeAddr::from_parts(
self.node_id(),
Some(relay),
addrs.into_iter().map(|x| x.addr),
))
},
}
pub fn node_addr(&self) -> Result<impl Watcher<Value = Option<NodeAddr>>> {
let watch_addrs = self.direct_addresses();
let watch_relay = self.home_relay();
let node_id = self.node_id();
let watcher =
watch_addrs
.or(watch_relay)
.map(move |(addrs, relay)| match (addrs, relay) {
(Some(addrs), relay) => Some(NodeAddr::from_parts(
node_id,
relay,
addrs.into_iter().map(|x| x.addr),
)),
(None, Some(relay)) => Some(NodeAddr::from_parts(
node_id,
Some(relay),
std::iter::empty(),
)),
(None, None) => None,
})?;
Ok(watcher)
}

/// Returns a [`Watcher`] for the [`RelayUrl`] of the Relay server used as home relay.
Expand Down Expand Up @@ -1447,7 +1445,7 @@ mod tests {
.bind()
.await
.unwrap();
let my_addr = ep.node_addr().await.unwrap();
let my_addr = ep.node_addr().unwrap().initialized().await.unwrap();
let res = ep.connect(my_addr.clone(), TEST_ALPN).await;
assert!(res.is_err());
let err = res.err().unwrap();
Expand Down Expand Up @@ -1729,8 +1727,8 @@ mod tests {
.bind()
.await
.unwrap();
let ep1_nodeaddr = ep1.node_addr().await.unwrap();
let ep2_nodeaddr = ep2.node_addr().await.unwrap();
let ep1_nodeaddr = ep1.node_addr().unwrap().initialized().await.unwrap();
let ep2_nodeaddr = ep2.node_addr().unwrap().initialized().await.unwrap();
ep1.add_node_addr(ep2_nodeaddr.clone()).unwrap();
ep2.add_node_addr(ep1_nodeaddr.clone()).unwrap();
let ep1_nodeid = ep1.node_id();
Expand Down Expand Up @@ -1853,7 +1851,7 @@ mod tests {
let ep1_nodeid = ep1.node_id();
let ep2_nodeid = ep2.node_id();

let ep1_nodeaddr = ep1.node_addr().await.unwrap();
let ep1_nodeaddr = ep1.node_addr().unwrap().initialized().await.unwrap();
tracing::info!(
"node id 1 {ep1_nodeid}, relay URL {:?}",
ep1_nodeaddr.relay_url()
Expand Down
2 changes: 1 addition & 1 deletion iroh/src/magicsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3192,7 +3192,7 @@ mod tests {
println!("first conn!");
let conn = m1
.endpoint
.connect(m2.endpoint.node_addr().await?, ALPN)
.connect(m2.endpoint.node_addr()?.initialized().await?, ALPN)
.await?;
println!("Closing first conn");
conn.close(0u32.into(), b"bye lolz");
Expand Down
85 changes: 69 additions & 16 deletions iroh/src/watchable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,6 @@ impl<T: Clone + Eq> Watchable<T> {
}
}

/// An observer for a value.
///
/// The [`Watcher`] can get the current value, and will be notified when the value changes.
/// Only the most recent value is accessible, and if the thread with the [`Watchable`]
/// changes the value faster than the thread with the [`Watcher`] can keep up with, then
/// it'll miss in-between values.
/// When the thread changing the [`Watchable`] pauses updating, the [`Watcher`] will always
/// end up reporting the most recent state eventually.
#[derive(Debug, Clone)]
pub struct DirectWatcher<T> {
epoch: u64,
shared: Weak<Shared<T>>,
}

/// A handle to a value that's represented by one or more underlying [`Watchable`]s.
///
/// This handle allows one to observe the latest state and be notified
Expand Down Expand Up @@ -217,6 +203,38 @@ pub trait Watcher: Clone {
watcher: self,
}
}

/// Maps this watcher with a function that transforms the observed values.
fn map<T: Clone + Eq, F: Clone + Fn(Self::Value) -> T>(
self,
map: F,
) -> Result<MapWatcher<Self, T, F>, Disconnected> {
Ok(MapWatcher {
current: (map)(self.get()?),
map,
watcher: self,
})
}

/// Returns a watcher that updates every time this or the other watcher
/// updates, and yields both watcher's items together when that happens.
fn or<W: Watcher>(self, other: W) -> OrWatcher<Self, W> {
OrWatcher(self, other)
}
}

/// An observer for a value.
///
/// The [`Watcher`] can get the current value, and will be notified when the value changes.
/// Only the most recent value is accessible, and if the thread with the [`Watchable`]
/// changes the value faster than the thread with the [`Watcher`] can keep up with, then
/// it'll miss in-between values.
/// When the thread changing the [`Watchable`] pauses updating, the [`Watcher`] will always
/// end up reporting the most recent state eventually.
#[derive(Debug, Clone)]
pub struct DirectWatcher<T> {
epoch: u64,
shared: Weak<Shared<T>>,
}

impl<T: Clone + Eq> Watcher for DirectWatcher<T> {
Expand Down Expand Up @@ -245,10 +263,13 @@ impl<T: Clone + Eq> Watcher for DirectWatcher<T> {
}

/// Combines two [`Watcher`]s into a single watcher.
///
/// This watcher updates when one of the inner watchers
/// is updated at least once.
#[derive(Clone, Debug)]
pub struct Or<S: Watcher, T: Watcher>(S, T);
pub struct OrWatcher<S: Watcher, T: Watcher>(S, T);

impl<S: Watcher, T: Watcher> Watcher for Or<S, T> {
impl<S: Watcher, T: Watcher> Watcher for OrWatcher<S, T> {
type Value = (S::Value, T::Value);

fn get(&self) -> Result<Self::Value, Disconnected> {
Expand All @@ -270,6 +291,38 @@ impl<S: Watcher, T: Watcher> Watcher for Or<S, T> {
}
}

/// Maps a [`Watcher`] and allows filtering updates.
#[derive(Clone, Debug)]
pub struct MapWatcher<W: Watcher, T: Clone + Eq, F: Clone + Fn(W::Value) -> T> {
map: F,
watcher: W,
current: T,
}

impl<W: Watcher, T: Clone + Eq, F: Clone + Fn(W::Value) -> T> Watcher for MapWatcher<W, T, F> {
type Value = T;

fn get(&self) -> Result<Self::Value, Disconnected> {
Ok((self.map)(self.watcher.get()?))
}

fn poll_updated(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<Self::Value, Disconnected>> {
loop {
let value = futures_lite::ready!(self.watcher.poll_updated(cx)?);
let mapped = (self.map)(value);
if mapped != self.current {
self.current = mapped.clone();
return Poll::Ready(Ok(mapped));
} else {
self.current = mapped;
}
}
}
}

/// Future returning the next item after the current one in a [`Watcher`].
///
/// See [`Watcher::updated`].
Expand Down

0 comments on commit 2529cde

Please sign in to comment.