Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

iroh-net transport #105

Merged
merged 11 commits into from
Nov 11, 2024
91 changes: 79 additions & 12 deletions src/transport/iroh_net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ use crate::{
};

use std::{
collections::BTreeSet,
fmt,
future::Future,
io,
iter::once,
marker::PhantomData,
net::SocketAddr,
pin::pin,
pin::Pin,
pin::{pin, Pin},
sync::Arc,
task::{Context, Poll},
};

use futures_lite::{Stream, StreamExt};
use futures_sink::Sink;
use futures_util::FutureExt;
use iroh_net::NodeAddr;
use iroh_net::{NodeAddr, NodeId};
use pin_project::pin_project;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::oneshot;
Expand Down Expand Up @@ -70,6 +70,16 @@ impl Drop for ServerEndpointInner {
}
}

/// Access control for the server, either unrestricted or limited to a list of nodes that can
/// connect to the server endpoint
#[derive(Debug, Clone)]
pub enum AccessControl {
/// Unrestricted access, anyone can connect
Unrestricted,
/// Restricted access, only nodes in the list can connect, all other nodes will be rejected
Allowed(Vec<NodeId>),
}

/// A server endpoint using a quinn connection
#[derive(Debug)]
pub struct IrohNetServerEndpoint<In: RpcMessage, Out: RpcMessage> {
Expand Down Expand Up @@ -103,40 +113,97 @@ impl<In: RpcMessage, Out: RpcMessage> IrohNetServerEndpoint<In, Out> {
}
}

async fn endpoint_handler(endpoint: iroh_net::Endpoint, sender: flume::Sender<SocketInner>) {
async fn endpoint_handler(
endpoint: iroh_net::Endpoint,
sender: flume::Sender<SocketInner>,
allowed_node_ids: BTreeSet<NodeId>,
) {
loop {
tracing::debug!("Waiting for incoming connection...");
let connecting = match endpoint.accept().await {
Some(connecting) => connecting,
None => break,
};

tracing::debug!("Awaiting connection from connect...");
let conection = match connecting.await {
Ok(conection) => conection,
let connection = match connecting.await {
Ok(connection) => connection,
Err(e) => {
tracing::warn!("Error accepting connection: {}", e);
continue;
}
};

// When the `allowed_node_ids` is empty, it's empty forever, so the CPU's branch
// prediction should always optimize this block away from this loop.
// The same applies when it isn't empty, ignoring the check for emptiness and always
// extracting the node id and checking if it's in the set.
if !allowed_node_ids.is_empty() {
let Ok(client_node_id) = iroh_net::endpoint::get_remote_node_id(&connection)
.map_err(|e| {
tracing::error!(
?e,
"Failed to extract iroh-net node id from incoming connection from {:?}",
connection.remote_address()
)
})
else {
connection.close(0u32.into(), b"failed to extract iroh-net node id");
continue;
};

if !allowed_node_ids.contains(&client_node_id) {
connection.close(0u32.into(), b"forbidden node id");
continue;
}
}

tracing::debug!(
"Connection established from {:?}",
conection.remote_address()
connection.remote_address()
);

tracing::debug!("Spawning connection handler...");
tokio::spawn(Self::connection_handler(conection, sender.clone()));
tokio::spawn(Self::connection_handler(connection, sender.clone()));
}
}

/// Create a new server channel, given a quinn endpoint.
///
/// The endpoint must be a server endpoint.
/// Create a new server channel, given a quinn endpoint, with unrestricted access by node id
///
/// The server channel will take care of listening on the endpoint and spawning
/// handlers for new connections.
pub fn new(endpoint: iroh_net::Endpoint) -> io::Result<Self> {
Self::new_with_access_control(endpoint, AccessControl::Unrestricted)
}

/// Create a new server endpoint, with specified access control
///
/// The server channel will take care of listening on the endpoint and spawning
/// handlers for new connections.
pub fn new_with_access_control(
endpoint: iroh_net::Endpoint,
access_control: AccessControl,
) -> io::Result<Self> {
let allowed_node_ids = match access_control {
AccessControl::Unrestricted => BTreeSet::new(),
AccessControl::Allowed(list) if list.is_empty() => {
fogodev marked this conversation as resolved.
Show resolved Hide resolved
tracing::warn!(
fogodev marked this conversation as resolved.
Show resolved Hide resolved
"Allowed list of `NodeId`s is empty, iroh-net \
quic-rpc endpoint will have unrestricted access!"
);
BTreeSet::new()
}
AccessControl::Allowed(list) => BTreeSet::from_iter(list),
};

let (ipv4_socket_addr, maybe_ipv6_socket_addr) = endpoint.bound_sockets();
let (sender, receiver) = flume::bounded(16);
let task = tokio::spawn(Self::endpoint_handler(endpoint.clone(), sender));
let task = tokio::spawn(Self::endpoint_handler(
endpoint.clone(),
sender,
allowed_node_ids,
));

Ok(Self {
inner: Arc::new(ServerEndpointInner {
endpoint: Some(endpoint),
Expand Down
Loading