From 1069407d69589bfb875b28abea27f9e2e0bcf58f Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Thu, 3 Oct 2024 17:05:33 -0400 Subject: [PATCH] fix(wireguard): contended get_wg_conn --- boltconn/src/adapter/wireguard.rs | 122 ++++++++++++++++------------ boltconn/src/transport/wireguard.rs | 5 +- 2 files changed, 73 insertions(+), 54 deletions(-) diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index 7df0988..872ee69 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -1,5 +1,4 @@ use crate::adapter::{AddrConnector, AddrConnectorWrapper, Connector, Outbound, OutboundType}; -use std::collections::HashMap; use crate::adapter; use crate::adapter::udp_over_tcp::UdpOverTcpAdapter; @@ -13,6 +12,8 @@ use crate::transport::wireguard::{WireguardConfig, WireguardTunnel}; use crate::transport::{AdapterOrSocket, InterfaceAddress, UdpSocketAdapter}; use async_trait::async_trait; use bytes::Bytes; +use dashmap::mapref::entry::Entry; +use dashmap::DashMap; use hickory_resolver::config::{ResolverConfig, ResolverOpts}; use hickory_resolver::name_server::GenericConnector; use hickory_resolver::proto::udp::DnsUdpSocket; @@ -274,8 +275,7 @@ impl Endpoint { pub struct WireguardManager { iface: String, - // We use an async wrapper to avoid deadlock in DashMap - active_conn: Mutex>>, + active_conn: DashMap>, endpoint_resolver: Arc, timeout: Duration, } @@ -297,56 +297,33 @@ impl WireguardManager { adapter: Option, ret_tx: tokio::sync::oneshot::Sender, ) -> Result, TransportError> { + // optimistic trial to avoid extra config.clone() + if let Some(ep) = self.active_conn.get(config) { + if ep.is_active.alive() { + let _ = ret_tx.send(false); + return Ok(ep.clone()); + } + } + // loop is only used for reconnecting a removed connection for _ in 0..10 { // get an existing conn, or create - let mut guard = self.active_conn.lock().await; - if let Some(endpoint) = guard.get(config) { - if endpoint.is_active.alive() { - let _ = ret_tx.send(false); - return Ok(endpoint.clone()); - } else { - guard.remove(config); - continue; - } - } else { - let _ = ret_tx.send(true); - let server_addr = - adapter::get_dst(&self.endpoint_resolver, &config.endpoint).await?; - let outbound = match adapter { - Some(a) => a, - None => { - if config.over_tcp { - let stream = Egress::new(&self.iface).tcp_stream(server_addr).await?; - AdapterOrSocket::Adapter(Arc::new(UdpOverTcpAdapter::new( - stream, - server_addr, - )?)) - } else { - AdapterOrSocket::Socket(match server_addr { - SocketAddr::V4(_) => { - let socket = Egress::new(&self.iface).udpv4_socket().await?; - socket.connect(server_addr).await?; - socket - } - SocketAddr::V6(_) => { - let socket = Egress::new(&self.iface).udpv6_socket().await?; - socket.connect(server_addr).await?; - socket - } - }) - } + // warning: if two keys fall into the same shard, the reconnecting may block this shard + match self.active_conn.entry(config.clone()) { + Entry::Occupied(entry) => { + if entry.get().is_active.alive() { + let _ = ret_tx.send(false); + return Ok(entry.get().clone()); + } else { + entry.remove(); + continue; } - }; - let ep = Endpoint::new( - name, - outbound, - config, - self.endpoint_resolver.clone(), - self.timeout, - ) - .await?; - guard.insert(config.clone(), ep.clone()); - return Ok(ep); + } + Entry::Vacant(e) => { + let _ = ret_tx.send(true); + let ep = self.create_endpoint(name, config, adapter).await?; + e.insert(ep.clone()); + return Ok(ep); + } } } Err(TransportError::WireGuard( @@ -354,11 +331,50 @@ impl WireguardManager { )) } + async fn create_endpoint( + &self, + name: &str, + config: &WireguardConfig, + adapter: Option, + ) -> Result, TransportError> { + let outbound = match adapter { + Some(a) => a, + None => { + let server_addr = + adapter::get_dst(&self.endpoint_resolver, &config.endpoint).await?; + if config.over_tcp { + let stream = Egress::new(&self.iface).tcp_stream(server_addr).await?; + AdapterOrSocket::Adapter(Arc::new(UdpOverTcpAdapter::new(stream, server_addr)?)) + } else { + AdapterOrSocket::Socket(match server_addr { + SocketAddr::V4(_) => { + let socket = Egress::new(&self.iface).udpv4_socket().await?; + socket.connect(server_addr).await?; + socket + } + SocketAddr::V6(_) => { + let socket = Egress::new(&self.iface).udpv6_socket().await?; + socket.connect(server_addr).await?; + socket + } + }) + } + } + }; + Endpoint::new( + name, + outbound, + config, + self.endpoint_resolver.clone(), + self.timeout, + ) + .await + } + pub async fn debug_internal_state(&self) -> Vec { - let conns = self.active_conn.lock().await; let mut ret = Vec::new(); - for (_, v) in conns.iter() { - let r = v.debug_internal_state().await; + for entry in self.active_conn.iter() { + let r = entry.debug_internal_state().await; ret.push(r); } ret diff --git a/boltconn/src/transport/wireguard.rs b/boltconn/src/transport/wireguard.rs index 4be8821..7f94371 100644 --- a/boltconn/src/transport/wireguard.rs +++ b/boltconn/src/transport/wireguard.rs @@ -41,6 +41,7 @@ pub struct WireguardConfig { impl Debug for WireguardConfig { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_tuple("") + .field(&self.name) .field(&self.ip_addr) .field(&self.ip_addr6) .field(&self.endpoint) @@ -51,7 +52,8 @@ impl Debug for WireguardConfig { impl PartialEq for WireguardConfig { fn eq(&self, other: &Self) -> bool { - self.public_key == other.public_key + self.name == other.name + && self.public_key == other.public_key && self.ip_addr == other.ip_addr && self.ip_addr6 == other.ip_addr6 && self.endpoint == other.endpoint @@ -62,6 +64,7 @@ impl Eq for WireguardConfig {} impl Hash for WireguardConfig { fn hash(&self, state: &mut H) { + self.name.hash(state); self.ip_addr.hash(state); self.ip_addr6.hash(state); self.public_key.hash(state);