Skip to content

Commit

Permalink
fix(wireguard): contended get_wg_conn
Browse files Browse the repository at this point in the history
  • Loading branch information
XOR-op committed Oct 3, 2024
1 parent 9fcacc4 commit 1069407
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 54 deletions.
122 changes: 69 additions & 53 deletions boltconn/src/adapter/wireguard.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::adapter::{AddrConnector, AddrConnectorWrapper, Connector, Outbound, OutboundType};
use std::collections::HashMap;

use crate::adapter;
use crate::adapter::udp_over_tcp::UdpOverTcpAdapter;
Expand All @@ -13,6 +12,8 @@ use crate::transport::wireguard::{WireguardConfig, WireguardTunnel};
use crate::transport::{AdapterOrSocket, InterfaceAddress, UdpSocketAdapter};
use async_trait::async_trait;
use bytes::Bytes;
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use hickory_resolver::config::{ResolverConfig, ResolverOpts};
use hickory_resolver::name_server::GenericConnector;
use hickory_resolver::proto::udp::DnsUdpSocket;
Expand Down Expand Up @@ -274,8 +275,7 @@ impl Endpoint {

pub struct WireguardManager {
iface: String,
// We use an async wrapper to avoid deadlock in DashMap
active_conn: Mutex<HashMap<WireguardConfig, Arc<Endpoint>>>,
active_conn: DashMap<WireguardConfig, Arc<Endpoint>>,
endpoint_resolver: Arc<Dns>,
timeout: Duration,
}
Expand All @@ -297,68 +297,84 @@ impl WireguardManager {
adapter: Option<AdapterOrSocket>,
ret_tx: tokio::sync::oneshot::Sender<bool>,
) -> Result<Arc<Endpoint>, TransportError> {
// optimistic trial to avoid extra config.clone()
if let Some(ep) = self.active_conn.get(config) {
if ep.is_active.alive() {
let _ = ret_tx.send(false);
return Ok(ep.clone());
}
}
// loop is only used for reconnecting a removed connection
for _ in 0..10 {
// get an existing conn, or create
let mut guard = self.active_conn.lock().await;
if let Some(endpoint) = guard.get(config) {
if endpoint.is_active.alive() {
let _ = ret_tx.send(false);
return Ok(endpoint.clone());
} else {
guard.remove(config);
continue;
}
} else {
let _ = ret_tx.send(true);
let server_addr =
adapter::get_dst(&self.endpoint_resolver, &config.endpoint).await?;
let outbound = match adapter {
Some(a) => a,
None => {
if config.over_tcp {
let stream = Egress::new(&self.iface).tcp_stream(server_addr).await?;
AdapterOrSocket::Adapter(Arc::new(UdpOverTcpAdapter::new(
stream,
server_addr,
)?))
} else {
AdapterOrSocket::Socket(match server_addr {
SocketAddr::V4(_) => {
let socket = Egress::new(&self.iface).udpv4_socket().await?;
socket.connect(server_addr).await?;
socket
}
SocketAddr::V6(_) => {
let socket = Egress::new(&self.iface).udpv6_socket().await?;
socket.connect(server_addr).await?;
socket
}
})
}
// warning: if two keys fall into the same shard, the reconnecting may block this shard
match self.active_conn.entry(config.clone()) {
Entry::Occupied(entry) => {
if entry.get().is_active.alive() {
let _ = ret_tx.send(false);
return Ok(entry.get().clone());
} else {
entry.remove();
continue;
}
};
let ep = Endpoint::new(
name,
outbound,
config,
self.endpoint_resolver.clone(),
self.timeout,
)
.await?;
guard.insert(config.clone(), ep.clone());
return Ok(ep);
}
Entry::Vacant(e) => {
let _ = ret_tx.send(true);
let ep = self.create_endpoint(name, config, adapter).await?;
e.insert(ep.clone());
return Ok(ep);
}
}
}
Err(TransportError::WireGuard(
"get_wg_conn: unexpected loop time",
))
}

async fn create_endpoint(
&self,
name: &str,
config: &WireguardConfig,
adapter: Option<AdapterOrSocket>,
) -> Result<Arc<Endpoint>, TransportError> {
let outbound = match adapter {
Some(a) => a,
None => {
let server_addr =
adapter::get_dst(&self.endpoint_resolver, &config.endpoint).await?;
if config.over_tcp {
let stream = Egress::new(&self.iface).tcp_stream(server_addr).await?;
AdapterOrSocket::Adapter(Arc::new(UdpOverTcpAdapter::new(stream, server_addr)?))
} else {
AdapterOrSocket::Socket(match server_addr {
SocketAddr::V4(_) => {
let socket = Egress::new(&self.iface).udpv4_socket().await?;
socket.connect(server_addr).await?;
socket
}
SocketAddr::V6(_) => {
let socket = Egress::new(&self.iface).udpv6_socket().await?;
socket.connect(server_addr).await?;
socket
}
})
}
}
};
Endpoint::new(
name,
outbound,
config,
self.endpoint_resolver.clone(),
self.timeout,
)
.await
}

pub async fn debug_internal_state(&self) -> Vec<boltapi::MasterConnectionStatus> {
let conns = self.active_conn.lock().await;
let mut ret = Vec::new();
for (_, v) in conns.iter() {
let r = v.debug_internal_state().await;
for entry in self.active_conn.iter() {
let r = entry.debug_internal_state().await;
ret.push(r);
}
ret
Expand Down
5 changes: 4 additions & 1 deletion boltconn/src/transport/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub struct WireguardConfig {
impl Debug for WireguardConfig {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("")
.field(&self.name)
.field(&self.ip_addr)
.field(&self.ip_addr6)
.field(&self.endpoint)
Expand All @@ -51,7 +52,8 @@ impl Debug for WireguardConfig {

impl PartialEq for WireguardConfig {
fn eq(&self, other: &Self) -> bool {
self.public_key == other.public_key
self.name == other.name
&& self.public_key == other.public_key
&& self.ip_addr == other.ip_addr
&& self.ip_addr6 == other.ip_addr6
&& self.endpoint == other.endpoint
Expand All @@ -62,6 +64,7 @@ impl Eq for WireguardConfig {}

impl Hash for WireguardConfig {
fn hash<H: Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.ip_addr.hash(state);
self.ip_addr6.hash(state);
self.public_key.hash(state);
Expand Down

0 comments on commit 1069407

Please sign in to comment.