From e928c0d9f571c16ce4b355d2eab77303e15603c2 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:47:58 -0400 Subject: [PATCH 01/14] change(adapter): outbounds now store proxy name --- boltconn/src/adapter/chain.rs | 8 +++++++- boltconn/src/adapter/direct.rs | 4 ++++ boltconn/src/adapter/http.rs | 14 +++++++++++++- boltconn/src/adapter/mod.rs | 4 ++++ boltconn/src/adapter/shadowsocks.rs | 7 +++++++ boltconn/src/adapter/socks5.rs | 14 +++++++++++++- boltconn/src/adapter/ssh.rs | 7 +++++++ boltconn/src/adapter/trojan.rs | 14 +++++++++++++- boltconn/src/adapter/wireguard.rs | 7 +++++++ boltconn/src/proxy/dispatcher.rs | 26 +++++++++++++++++++------- boltconn/src/proxy/mod.rs | 2 ++ 11 files changed, 96 insertions(+), 11 deletions(-) diff --git a/boltconn/src/adapter/chain.rs b/boltconn/src/adapter/chain.rs index c15e019..900e4a3 100644 --- a/boltconn/src/adapter/chain.rs +++ b/boltconn/src/adapter/chain.rs @@ -14,12 +14,14 @@ use tokio::task::JoinHandle; #[derive(Clone)] pub struct ChainOutbound { + name: String, chains: Vec>, } impl ChainOutbound { - pub fn new(chains: Vec>) -> Self { + pub fn new(name: &str, chains: Vec>) -> Self { Self { + name: name.to_string(), chains: chains.into_iter().map(Arc::from).collect(), } } @@ -122,6 +124,10 @@ impl ChainOutbound { #[async_trait] impl Outbound for ChainOutbound { + fn id(&self) -> String { + self.name.clone() + } + fn outbound_type(&self) -> OutboundType { OutboundType::Chain } diff --git a/boltconn/src/adapter/direct.rs b/boltconn/src/adapter/direct.rs index 9d0e0b3..572d692 100644 --- a/boltconn/src/adapter/direct.rs +++ b/boltconn/src/adapter/direct.rs @@ -68,6 +68,10 @@ impl DirectOutbound { #[async_trait] impl Outbound for DirectOutbound { + fn id(&self) -> String { + "DIRECT".to_string() + } + fn outbound_type(&self) -> OutboundType { OutboundType::Direct } diff --git a/boltconn/src/adapter/http.rs b/boltconn/src/adapter/http.rs index 52d4dfd..42e00d1 100644 --- a/boltconn/src/adapter/http.rs +++ b/boltconn/src/adapter/http.rs @@ -24,6 +24,7 @@ pub struct HttpConfig { #[derive(Clone)] pub struct HttpOutbound { + name: String, iface_name: String, dst: NetworkAddr, dns: Arc, @@ -31,8 +32,15 @@ pub struct HttpOutbound { } impl HttpOutbound { - pub fn new(iface_name: &str, dst: NetworkAddr, dns: Arc, config: HttpConfig) -> Self { + pub fn new( + name: &str, + iface_name: &str, + dst: NetworkAddr, + dns: Arc, + config: HttpConfig, + ) -> Self { Self { + name: name.to_string(), iface_name: iface_name.to_string(), dst, dns, @@ -96,6 +104,10 @@ impl HttpOutbound { #[async_trait] impl Outbound for HttpOutbound { + fn id(&self) -> String { + self.name.clone() + } + fn outbound_type(&self) -> OutboundType { OutboundType::Http } diff --git a/boltconn/src/adapter/mod.rs b/boltconn/src/adapter/mod.rs index 8b454a9..e92eeac 100644 --- a/boltconn/src/adapter/mod.rs +++ b/boltconn/src/adapter/mod.rs @@ -167,6 +167,10 @@ impl OutboundType { #[async_trait] pub trait Outbound: Send + Sync { + /// Get the globally unique id of the outbound to distinguish it + /// even from others with the same type. + fn id(&self) -> String; + fn outbound_type(&self) -> OutboundType; /// Run with tokio::spawn. diff --git a/boltconn/src/adapter/shadowsocks.rs b/boltconn/src/adapter/shadowsocks.rs index de9489c..20d03dd 100644 --- a/boltconn/src/adapter/shadowsocks.rs +++ b/boltconn/src/adapter/shadowsocks.rs @@ -38,6 +38,7 @@ impl From for ServerConfig { #[derive(Clone)] pub struct SSOutbound { + name: String, iface_name: String, dst: NetworkAddr, dns: Arc, @@ -46,12 +47,14 @@ pub struct SSOutbound { impl SSOutbound { pub fn new( + name: &str, iface_name: &str, dst: NetworkAddr, dns: Arc, config: ShadowSocksConfig, ) -> Self { Self { + name: name.to_string(), iface_name: iface_name.to_string(), dst, dns, @@ -132,6 +135,10 @@ impl SSOutbound { #[async_trait] impl Outbound for SSOutbound { + fn id(&self) -> String { + self.name.clone() + } + fn outbound_type(&self) -> OutboundType { OutboundType::Shadowsocks } diff --git a/boltconn/src/adapter/socks5.rs b/boltconn/src/adapter/socks5.rs index 3583e4b..70d97ba 100644 --- a/boltconn/src/adapter/socks5.rs +++ b/boltconn/src/adapter/socks5.rs @@ -41,6 +41,7 @@ impl Socks5Config { #[derive(Clone)] pub struct Socks5Outbound { + name: String, iface_name: String, dst: NetworkAddr, dns: Arc, @@ -48,8 +49,15 @@ pub struct Socks5Outbound { } impl Socks5Outbound { - pub fn new(iface_name: &str, dst: NetworkAddr, dns: Arc, config: Socks5Config) -> Self { + pub fn new( + name: &str, + iface_name: &str, + dst: NetworkAddr, + dns: Arc, + config: Socks5Config, + ) -> Self { Self { + name: name.to_string(), iface_name: iface_name.to_string(), dst, dns, @@ -130,6 +138,10 @@ impl Socks5Outbound { #[async_trait] impl Outbound for Socks5Outbound { + fn id(&self) -> String { + self.name.clone() + } + fn outbound_type(&self) -> OutboundType { OutboundType::Socks5 } diff --git a/boltconn/src/adapter/ssh.rs b/boltconn/src/adapter/ssh.rs index 0d480df..cef89da 100644 --- a/boltconn/src/adapter/ssh.rs +++ b/boltconn/src/adapter/ssh.rs @@ -19,6 +19,7 @@ use tokio::task::JoinHandle; #[derive(Clone)] pub struct SshOutboundHandle { + name: String, iface_name: String, dst: NetworkAddr, dns: Arc, @@ -28,6 +29,7 @@ pub struct SshOutboundHandle { impl SshOutboundHandle { pub fn new( + name: &str, iface_name: &str, dst: NetworkAddr, dns: Arc, @@ -35,6 +37,7 @@ impl SshOutboundHandle { manager: Arc, ) -> Self { Self { + name: name.to_string(), iface_name: iface_name.to_string(), dst, dns, @@ -82,6 +85,10 @@ impl SshOutboundHandle { #[async_trait] impl Outbound for SshOutboundHandle { + fn id(&self) -> String { + self.name.clone() + } + fn outbound_type(&self) -> OutboundType { OutboundType::Ssh } diff --git a/boltconn/src/adapter/trojan.rs b/boltconn/src/adapter/trojan.rs index e266c76..ab2169a 100644 --- a/boltconn/src/adapter/trojan.rs +++ b/boltconn/src/adapter/trojan.rs @@ -27,6 +27,7 @@ use tokio_tungstenite::client_async; #[derive(Clone)] pub struct TrojanOutbound { + name: String, iface_name: String, dst: NetworkAddr, dns: Arc, @@ -34,8 +35,15 @@ pub struct TrojanOutbound { } impl TrojanOutbound { - pub fn new(iface_name: &str, dst: NetworkAddr, dns: Arc, config: TrojanConfig) -> Self { + pub fn new( + name: &str, + iface_name: &str, + dst: NetworkAddr, + dns: Arc, + config: TrojanConfig, + ) -> Self { Self { + name: name.to_string(), iface_name: iface_name.to_string(), dst, dns, @@ -173,6 +181,10 @@ impl TrojanOutbound { #[async_trait] impl Outbound for TrojanOutbound { + fn id(&self) -> String { + self.name.clone() + } + fn outbound_type(&self) -> OutboundType { OutboundType::Trojan } diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index 9c9f1aa..9c1a41b 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -328,6 +328,7 @@ impl WireguardManager { #[derive(Clone)] pub struct WireguardHandle { + name: String, src: SocketAddr, dst: NetworkAddr, config: Arc, @@ -337,6 +338,7 @@ pub struct WireguardHandle { impl WireguardHandle { pub fn new( + name: &str, src: SocketAddr, dst: NetworkAddr, config: WireguardConfig, @@ -344,6 +346,7 @@ impl WireguardHandle { dns_config: Arc, ) -> Self { Self { + name: name.to_string(), src, dst, config: Arc::new(config), @@ -403,6 +406,10 @@ impl WireguardHandle { #[async_trait] impl Outbound for WireguardHandle { + fn id(&self) -> String { + self.name.clone() + } + fn outbound_type(&self) -> OutboundType { OutboundType::Wireguard } diff --git a/boltconn/src/proxy/dispatcher.rs b/boltconn/src/proxy/dispatcher.rs index 085e31a..38a3921 100644 --- a/boltconn/src/proxy/dispatcher.rs +++ b/boltconn/src/proxy/dispatcher.rs @@ -86,6 +86,7 @@ impl Dispatcher { pub(super) fn build_normal_outbound( &self, + proxy_name: &str, iface_name: &str, proxy_config: &ProxyImpl, src_addr: SocketAddr, @@ -104,6 +105,7 @@ impl Dispatcher { ), ProxyImpl::Http(cfg) => ( Box::new(HttpOutbound::new( + proxy_name, iface_name, dst_addr.clone(), self.dns.clone(), @@ -113,6 +115,7 @@ impl Dispatcher { ), ProxyImpl::Socks5(cfg) => ( Box::new(Socks5Outbound::new( + proxy_name, iface_name, dst_addr.clone(), self.dns.clone(), @@ -122,6 +125,7 @@ impl Dispatcher { ), ProxyImpl::Shadowsocks(cfg) => ( Box::new(SSOutbound::new( + proxy_name, iface_name, dst_addr.clone(), self.dns.clone(), @@ -131,6 +135,7 @@ impl Dispatcher { ), ProxyImpl::Trojan(cfg) => ( Box::new(TrojanOutbound::new( + proxy_name, iface_name, dst_addr.clone(), self.dns.clone(), @@ -140,6 +145,7 @@ impl Dispatcher { ), ProxyImpl::Wireguard(cfg) => ( Box::new(WireguardHandle::new( + proxy_name, src_addr, dst_addr.clone(), cfg.clone(), @@ -150,6 +156,7 @@ impl Dispatcher { ), ProxyImpl::Ssh(cfg) => ( Box::new(SshOutboundHandle::new( + proxy_name, iface_name, dst_addr.clone(), self.dns.clone(), @@ -168,6 +175,7 @@ impl Dispatcher { pub(super) fn create_chain( &self, + chain_name: &str, vec: &[GeneralProxy], src_addr: SocketAddr, dst_addr: &NetworkAddr, @@ -176,8 +184,8 @@ impl Dispatcher { let impls: Vec<_> = vec .iter() .map(|n| match n { - GeneralProxy::Single(p) => p.get_impl(), - GeneralProxy::Group(g) => g.get_proxy().get_impl(), + GeneralProxy::Single(p) => (p.get_name(), p.get_impl()), + GeneralProxy::Group(g) => (g.get_name(), g.get_proxy().get_impl()), }) .collect(); let mut res = vec![]; @@ -187,7 +195,7 @@ impl Dispatcher { // if A->B->C, then vec is [C, B, A] dst_addrs.push(dst_addr.clone()); for idx in 1..vec.len() { - let proxy_impl = impls.get(idx - 1).unwrap().as_ref(); + let proxy_impl = impls.get(idx - 1).unwrap().1.as_ref(); if let Some(dst) = proxy_impl.server_addr() { dst_addrs.push(dst); } else { @@ -197,16 +205,18 @@ impl Dispatcher { } for idx in 0..vec.len() { + let proxy = impls.get(idx).unwrap(); let (outbounding, _) = self.build_normal_outbound( + &proxy.0, iface_name, - impls.get(idx).unwrap().as_ref(), + &proxy.1, src_addr, dst_addrs.get(idx).unwrap(), None, )?; res.push(outbounding); } - Ok(ChainOutbound::new(res)) + Ok(ChainOutbound::new(chain_name, res)) } pub async fn submit_tcp( @@ -238,7 +248,7 @@ impl Dispatcher { match proxy_config.as_ref() { ProxyImpl::Chain(vec) => ( Box::new( - self.create_chain(vec, src_addr, &dst_addr, iface_name) + self.create_chain(&proxy_name, vec, src_addr, &dst_addr, iface_name) .map_err(|_| DispatchError::BadChain)?, ), OutboundType::Chain, @@ -252,6 +262,7 @@ impl Dispatcher { } _ => self .build_normal_outbound( + &proxy_name, iface_name, proxy_config.as_ref(), src_addr, @@ -405,7 +416,7 @@ impl Dispatcher { match proxy_config.as_ref() { ProxyImpl::Chain(vec) => ( Box::new( - self.create_chain(vec, src_addr, &dst_addr, iface_name) + self.create_chain(&proxy_name, vec, src_addr, &dst_addr, iface_name) .map_err(|_| DispatchError::Reject)?, ), OutboundType::Chain, @@ -413,6 +424,7 @@ impl Dispatcher { ProxyImpl::BlackHole => return Err(DispatchError::BlackHole), _ => self .build_normal_outbound( + &proxy_name, iface_name, proxy_config.as_ref(), src_addr, diff --git a/boltconn/src/proxy/mod.rs b/boltconn/src/proxy/mod.rs index 4726dae..105e723 100644 --- a/boltconn/src/proxy/mod.rs +++ b/boltconn/src/proxy/mod.rs @@ -101,6 +101,7 @@ pub async fn latency_test( let creator: Box = match proxy.get_impl().as_ref() { ProxyImpl::Chain(vec) => { match dispatcher.create_chain( + &proxy.get_name(), vec, get_random_local_addr(&dst_addr, rng.gen_range(32768..65535)), &dst_addr, @@ -115,6 +116,7 @@ pub async fn latency_test( } proxy_config => { let creator = match dispatcher.build_normal_outbound( + &proxy.get_name(), iface.as_str(), proxy_config, get_random_local_addr(&dst_addr, rng.gen_range(32768..65535)), From 0faedfb59a46eddb02e4011b1cfb85856a21a0ba Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:56:39 -0400 Subject: [PATCH 02/14] chore(adapter): show proxy name on data path error --- boltconn/src/adapter/direct.rs | 3 ++- boltconn/src/adapter/http.rs | 2 +- boltconn/src/adapter/mod.rs | 21 ++++++++++++++------- boltconn/src/adapter/shadowsocks.rs | 3 ++- boltconn/src/adapter/socks5.rs | 3 ++- boltconn/src/adapter/ssh.rs | 2 +- boltconn/src/adapter/trojan.rs | 6 ++++-- 7 files changed, 26 insertions(+), 14 deletions(-) diff --git a/boltconn/src/adapter/direct.rs b/boltconn/src/adapter/direct.rs index 572d692..e50f426 100644 --- a/boltconn/src/adapter/direct.rs +++ b/boltconn/src/adapter/direct.rs @@ -45,7 +45,7 @@ impl DirectOutbound { }; let outbound = Egress::new(&self.iface_name).tcp_stream(dst_addr).await?; - established_tcp(inbound, outbound, abort_handle).await; + established_tcp(self.id(), inbound, outbound, abort_handle).await; Ok(()) } @@ -56,6 +56,7 @@ impl DirectOutbound { ) -> io::Result<()> { let outbound = Arc::new(Egress::new(&self.iface_name).udpv4_socket().await?); established_udp( + self.id(), inbound, DirectUdpAdapter(outbound, self.dns.clone()), None, diff --git a/boltconn/src/adapter/http.rs b/boltconn/src/adapter/http.rs index 42e00d1..e3a0914 100644 --- a/boltconn/src/adapter/http.rs +++ b/boltconn/src/adapter/http.rs @@ -92,7 +92,7 @@ impl HttpOutbound { .map_err(|_| io_err("Parse response failed"))?; if let Some(200) = resp_struct.code { let tcp_stream = buf_reader.into_inner(); - established_tcp(inbound, tcp_stream, abort_handle).await; + established_tcp(self.name, inbound, tcp_stream, abort_handle).await; Ok(()) } else { Err(io_err( diff --git a/boltconn/src/adapter/mod.rs b/boltconn/src/adapter/mod.rs index e92eeac..d6d3466 100644 --- a/boltconn/src/adapter/mod.rs +++ b/boltconn/src/adapter/mod.rs @@ -213,20 +213,25 @@ fn empty_handle() -> JoinHandle> { } #[tracing::instrument(skip_all)] -async fn established_tcp(inbound: Connector, outbound: T, abort_handle: ConnAbortHandle) -where +async fn established_tcp( + name: String, + inbound: Connector, + outbound: T, + abort_handle: ConnAbortHandle, +) where T: AsyncWrite + AsyncRead + Unpin + Send + 'static, { let (mut out_read, mut out_write) = tokio::io::split(outbound); let Connector { tx, mut rx } = inbound; // recv from inbound and send to outbound let abort_handle2 = abort_handle.clone(); + let name2 = name.clone(); let _guard = DuplexCloseGuard::new( tokio::spawn(async move { while let Some(buf) = rx.recv().await { let res = out_write.write_all(buf.as_ref()).await; if let Err(err) = res { - tracing::debug!("write to outbound failed: {}", err); + tracing::debug!("[{}] write to outbound failed: {}", name2, err); abort_handle2.cancel(); break; } @@ -252,7 +257,7 @@ where } } Err(err) => { - tracing::debug!("outbound read error: {}", err); + tracing::debug!("[{}] outbound read error: {}", name, err); abort_handle.cancel(); break; } @@ -263,6 +268,7 @@ where #[tracing::instrument(skip_all)] async fn established_udp( + name: String, inbound: AddrConnector, outbound: S, tunnel_addr: Option, @@ -274,6 +280,7 @@ async fn established_udp( let tunnel_addr2 = tunnel_addr.clone(); let AddrConnector { tx, mut rx } = inbound; let abort_handle2 = abort_handle.clone(); + let name2 = name.clone(); let _guard = UdpDropGuard(tokio::spawn(async move { // recv from outbound and send to inbound loop { @@ -292,12 +299,12 @@ async fn established_udp( } } if tx.send((buf.freeze(), addr)).await.is_err() { - tracing::debug!("write to inbound failed"); + tracing::debug!("[{}] write to inbound failed", name); break; } } Err(err) => { - tracing::debug!("outbound read error: {}", err); + tracing::debug!("[{}] outbound read error: {}", name, err); break; } } @@ -309,7 +316,7 @@ async fn established_udp( let addr = tunnel_addr2.clone().unwrap_or(addr); let res = outbound2.send_to(buf.as_ref(), addr).await; if let Err(err) = res { - tracing::debug!("write to outbound failed: {}", err); + tracing::debug!("[{}] write to outbound failed: {}", name2, err); break; } } diff --git a/boltconn/src/adapter/shadowsocks.rs b/boltconn/src/adapter/shadowsocks.rs index 20d03dd..6a91601 100644 --- a/boltconn/src/adapter/shadowsocks.rs +++ b/boltconn/src/adapter/shadowsocks.rs @@ -108,7 +108,7 @@ impl SSOutbound { let (target_addr, context, resolved_config) = self.create_internal(server_addr).await; let ss_stream = ProxyClientStream::from_stream(context, outbound, &resolved_config, target_addr); - established_tcp(inbound, ss_stream, abort_handle).await; + established_tcp(self.name, inbound, ss_stream, abort_handle).await; Ok(()) } @@ -123,6 +123,7 @@ impl SSOutbound { let (_, context, resolved_config) = self.create_internal(server_addr).await; let proxy_socket = ShadowsocksUdpAdapter::new(context, &resolved_config, adapter_or_socket); established_udp( + self.name, inbound, proxy_socket, if tunnel_only { Some(self.dst) } else { None }, diff --git a/boltconn/src/adapter/socks5.rs b/boltconn/src/adapter/socks5.rs index 70d97ba..3bdd1ae 100644 --- a/boltconn/src/adapter/socks5.rs +++ b/boltconn/src/adapter/socks5.rs @@ -91,7 +91,7 @@ impl Socks5Outbound { .request(Socks5Command::TCPConnect, target) .await .map_err(as_io_err)?; - established_tcp(inbound, socks_stream, abort_handle).await; + established_tcp(self.name, inbound, socks_stream, abort_handle).await; Ok(()) } @@ -126,6 +126,7 @@ impl Socks5Outbound { .unwrap(); out_sock.connect(bound_addr).await?; established_udp( + self.name, inbound, Socks5UdpAdapter(out_sock), if tunnel_only { Some(self.dst) } else { None }, diff --git a/boltconn/src/adapter/ssh.rs b/boltconn/src/adapter/ssh.rs index cef89da..6fc7429 100644 --- a/boltconn/src/adapter/ssh.rs +++ b/boltconn/src/adapter/ssh.rs @@ -78,7 +78,7 @@ impl SshOutboundHandle { } }; let channel = master_conn.new_mapped_connection(self.dst.clone()).await?; - established_tcp(inbound, channel, abort_handle).await; + established_tcp(self.name, inbound, channel, abort_handle).await; Ok(()) } } diff --git a/boltconn/src/adapter/trojan.rs b/boltconn/src/adapter/trojan.rs index ab2169a..dbae0c1 100644 --- a/boltconn/src/adapter/trojan.rs +++ b/boltconn/src/adapter/trojan.rs @@ -69,11 +69,11 @@ impl TrojanOutbound { .map_err(|e| io_err(e.to_string().as_str()))?; self.first_packet(first_packet, TrojanCmd::Connect, &mut stream) .await?; - established_tcp(inbound, stream, abort_handle).await; + established_tcp(self.name, inbound, stream, abort_handle).await; } else { self.first_packet(first_packet, TrojanCmd::Connect, &mut stream) .await?; - established_tcp(inbound, stream, abort_handle).await; + established_tcp(self.name, inbound, stream, abort_handle).await; } Ok(()) } @@ -103,6 +103,7 @@ impl TrojanOutbound { socket: Arc::new(udp_socket), }; established_udp( + self.name, inbound, adapter, if tunnel_only { Some(self.dst) } else { None }, @@ -117,6 +118,7 @@ impl TrojanOutbound { socket: Arc::new(udp_socket), }; established_udp( + self.name, inbound, adapter, if tunnel_only { Some(self.dst) } else { None }, From f20d15f5ce56360d0081d0b18b5be48b4147d114 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Wed, 25 Sep 2024 23:38:47 -0400 Subject: [PATCH 03/14] perf(wireguard): reduce preallocated buffer pool size --- boltconn/src/adapter/wireguard.rs | 13 +++++++++++-- boltconn/src/proxy/dispatcher.rs | 5 ++++- boltconn/src/transport/smol.rs | 3 +++ boltconn/src/transport/wireguard.rs | 3 ++- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index 9c1a41b..ff8dcb3 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -40,6 +40,7 @@ pub struct Endpoint { impl Endpoint { pub async fn new( + name: &str, outbound: AdapterOrSocket, config: &WireguardConfig, endpoint_resolver: Arc, @@ -80,7 +81,13 @@ impl Endpoint { resolver, config.dns_preference, )); - Mutex::new(SmolStack::new(iface, device, dns, Duration::from_secs(120))) + Mutex::new(SmolStack::new( + name, + iface, + device, + dns, + Duration::from_secs(120), + )) }) }; @@ -265,6 +272,7 @@ impl WireguardManager { pub async fn get_wg_conn( &self, + name: &str, config: &WireguardConfig, adapter: Option, ret_tx: tokio::sync::oneshot::Sender, @@ -310,6 +318,7 @@ impl WireguardManager { } }; let ep = Endpoint::new( + name, outbound, config, self.endpoint_resolver.clone(), @@ -385,7 +394,7 @@ impl WireguardHandle { ret_tx: tokio::sync::oneshot::Sender, ) -> io::Result> { self.manager - .get_wg_conn(&self.config, adapter, ret_tx) + .get_wg_conn(&self.name, &self.config, adapter, ret_tx) .await .map_err(|e| io_err(format!("{}", e).as_str())) } diff --git a/boltconn/src/proxy/dispatcher.rs b/boltconn/src/proxy/dispatcher.rs index 38a3921..1c005b2 100644 --- a/boltconn/src/proxy/dispatcher.rs +++ b/boltconn/src/proxy/dispatcher.rs @@ -185,7 +185,10 @@ impl Dispatcher { .iter() .map(|n| match n { GeneralProxy::Single(p) => (p.get_name(), p.get_impl()), - GeneralProxy::Group(g) => (g.get_name(), g.get_proxy().get_impl()), + GeneralProxy::Group(g) => { + let proxy = g.get_proxy(); + (proxy.get_name(), proxy.get_impl()) + } }) .collect(); let mut res = vec![]; diff --git a/boltconn/src/transport/smol.rs b/boltconn/src/transport/smol.rs index f1e9cdd..31d3964 100644 --- a/boltconn/src/transport/smol.rs +++ b/boltconn/src/transport/smol.rs @@ -384,6 +384,7 @@ impl Drop for UdpConnTask { // Program -- TCP/UDP -> SmolStack -> IP -- Internet // \ TCP/UDP <- SmolStack <- IP / pub struct SmolStack { + name: String, tcp_conn: DashMap, udp_conn: DashMap, ip_addr: InterfaceAddress, @@ -396,6 +397,7 @@ pub struct SmolStack { impl SmolStack { pub fn new( + name: &str, iface_ip: InterfaceAddress, mut ip_device: VirtualIpDevice, dns: Arc>, @@ -416,6 +418,7 @@ impl SmolStack { } }); Self { + name: name.to_string(), tcp_conn: Default::default(), udp_conn: Default::default(), ip_addr: iface_ip, diff --git a/boltconn/src/transport/wireguard.rs b/boltconn/src/transport/wireguard.rs index 316204a..a666a4b 100644 --- a/boltconn/src/transport/wireguard.rs +++ b/boltconn/src/transport/wireguard.rs @@ -125,7 +125,8 @@ impl WireguardTunnel { let buf_pool = Arc::new({ let pool = Pool::>::new(); // allocate memory in advance - const INIT_POOL_SIZE: usize = 4096; + // TODO: profile to determine removal + const INIT_POOL_SIZE: usize = 128; let mut index_arr = [0; INIT_POOL_SIZE]; for entry in index_arr.iter_mut() { *entry = pool From 6354b3576dd319ba5d090880f95c11a6c81bc92c Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Thu, 26 Sep 2024 12:55:46 -0400 Subject: [PATCH 04/14] chore(dns): add name for debug information --- boltconn/src/adapter/wireguard.rs | 1 + boltconn/src/app.rs | 1 + boltconn/src/network/dns/dns.rs | 51 +++++++++++++++++++++---------- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index ff8dcb3..be617be 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -78,6 +78,7 @@ impl Endpoint { ) }; let dns = Arc::new(GenericDns::new_with_resolver( + name, resolver, config.dns_preference, )); diff --git a/boltconn/src/app.rs b/boltconn/src/app.rs index d523585..b83bafd 100644 --- a/boltconn/src/app.rs +++ b/boltconn/src/app.rs @@ -447,6 +447,7 @@ async fn initialize_dns( .await .map_err(|e| anyhow!("Parse nameserver policy failed: {e}"))?; Arc::new(Dns::with_config( + "default", outbound_iface, config.preference, &config.hosts, diff --git a/boltconn/src/network/dns/dns.rs b/boltconn/src/network/dns/dns.rs index 0cc9511..704ddc5 100644 --- a/boltconn/src/network/dns/dns.rs +++ b/boltconn/src/network/dns/dns.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use std::time::Duration; pub struct GenericDns { + name: String, table: DnsTable, preference: DnsPreference, host_resolver: ArcSwap, @@ -28,6 +29,7 @@ pub type Dns = GenericDns; impl Dns { pub fn with_config( + name: &str, iface_name: &str, preference: DnsPreference, hosts: &HashMap, @@ -45,6 +47,7 @@ impl Dns { } let host_resolver = HostsResolver::new(hosts); Dns { + name: name.to_string(), table: DnsTable::new(), preference, host_resolver: ArcSwap::new(Arc::new(host_resolver)), @@ -78,10 +81,12 @@ impl Dns { impl GenericDns

{ pub fn new_with_resolver( + name: &str, resolver: AsyncResolver>, preference: DnsPreference, ) -> Self { Self { + name: name.to_string(), table: DnsTable::new(), preference, host_resolver: ArcSwap::new(Arc::new(HostsResolver::empty())), @@ -109,7 +114,7 @@ impl GenericDns

{ async fn genuine_lookup_v4(&self, domain_name: &str) -> Option { for r in self.resolvers.load().iter() { - if let Some(ip) = Self::genuine_lookup_one_v4(domain_name, r).await { + if let Some(ip) = Self::genuine_lookup_one_v4(&self.name, domain_name, r).await { return Some(ip); } } @@ -117,6 +122,7 @@ impl GenericDns

{ } async fn genuine_lookup_one_v4( + name: &str, domain_name: &str, resolver: &AsyncResolver>, ) -> Option { @@ -129,14 +135,14 @@ impl GenericDns

{ } } } else { - tracing::debug!("DNS v4 lookup for {domain_name} timeout: 5s"); + tracing::debug!("DNS {name} v4 lookup for {domain_name} timeout: 5s"); } None } async fn genuine_lookup_v6(&self, domain_name: &str) -> Option { for r in self.resolvers.load().iter() { - if let Some(ip) = Self::genuine_lookup_one_v6(domain_name, r).await { + if let Some(ip) = Self::genuine_lookup_one_v6(&self.name, domain_name, r).await { return Some(ip); } } @@ -144,6 +150,7 @@ impl GenericDns

{ } async fn genuine_lookup_one_v6( + name: &str, domain_name: &str, resolver: &AsyncResolver>, ) -> Option { @@ -156,29 +163,37 @@ impl GenericDns

{ } } } else { - tracing::debug!("DNS v6 lookup for {domain_name} timeout: 5s"); + tracing::debug!("DNS {name} v6 lookup for {domain_name} timeout: 5s"); } None } - async fn one_v4_wrapper(domain_name: &str, resolver: &DispatchedDnsResolver) -> Option { + async fn one_v4_wrapper( + name: &str, + domain_name: &str, + resolver: &DispatchedDnsResolver, + ) -> Option { match resolver { DispatchedDnsResolver::Iface(resolver) => { - Self::genuine_lookup_one_v4(domain_name, resolver).await + Self::genuine_lookup_one_v4(name, domain_name, resolver).await } DispatchedDnsResolver::Plain(resolver) => { - Self::genuine_lookup_one_v4(domain_name, resolver).await + Self::genuine_lookup_one_v4(name, domain_name, resolver).await } } } - async fn one_v6_wrapper(domain_name: &str, resolver: &DispatchedDnsResolver) -> Option { + async fn one_v6_wrapper( + name: &str, + domain_name: &str, + resolver: &DispatchedDnsResolver, + ) -> Option { match resolver { DispatchedDnsResolver::Iface(resolver) => { - Self::genuine_lookup_one_v6(domain_name, resolver).await + Self::genuine_lookup_one_v6(name, domain_name, resolver).await } DispatchedDnsResolver::Plain(resolver) => { - Self::genuine_lookup_one_v6(domain_name, resolver).await + Self::genuine_lookup_one_v6(name, domain_name, resolver).await } } } @@ -201,20 +216,24 @@ impl GenericDns

{ } if let Some(resolver) = self.ns_policy.load().resolve(domain_name) { return match pref { - DnsPreference::Ipv4Only => Self::one_v4_wrapper(domain_name, resolver).await, - DnsPreference::Ipv6Only => Self::one_v6_wrapper(domain_name, resolver).await, + DnsPreference::Ipv4Only => { + Self::one_v4_wrapper(&self.name, domain_name, resolver).await + } + DnsPreference::Ipv6Only => { + Self::one_v6_wrapper(&self.name, domain_name, resolver).await + } DnsPreference::PreferIpv4 => { - if let Some(a) = Self::one_v4_wrapper(domain_name, resolver).await { + if let Some(a) = Self::one_v4_wrapper(&self.name, domain_name, resolver).await { Some(a) } else { - Self::one_v6_wrapper(domain_name, resolver).await + Self::one_v6_wrapper(&self.name, domain_name, resolver).await } } DnsPreference::PreferIpv6 => { - if let Some(a) = Self::one_v6_wrapper(domain_name, resolver).await { + if let Some(a) = Self::one_v6_wrapper(&self.name, domain_name, resolver).await { Some(a) } else { - Self::one_v4_wrapper(domain_name, resolver).await + Self::one_v4_wrapper(&self.name, domain_name, resolver).await } } }; From 7e70f280f3e3646ac102c92380a77314332887ce Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Sat, 28 Sep 2024 21:58:06 -0400 Subject: [PATCH 05/14] chore(wireguard): add name for debug information --- boltconn/src/adapter/wireguard.rs | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index be617be..aded020 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -31,6 +31,7 @@ use tokio::task::JoinHandle; // Shared Wireguard Tunnel between multiple client connections pub struct Endpoint { + name: String, wg: Arc, stack: Arc>, stop_sender: broadcast::Sender<()>, @@ -101,15 +102,15 @@ impl Endpoint { let tunnel = tunnel.clone(); let stop_send = stop_send.clone(); let timer = last_active.clone(); + let name = name.to_string(); tokio::spawn(async move { let mut buf = [0u8; MAX_PKT_SIZE]; loop { - if tunnel - .send_outgoing_packet(&mut smol_wg_rx, &mut buf) - .await - .is_err() - { + if let Err(e) = tunnel.send_outgoing_packet(&mut smol_wg_rx, &mut buf).await { let _ = stop_send.send(()); + tracing::trace!( + "[WireGuard] Close connection #{name} for send_outgoing_packet for {e}", + ); return; } *timer.lock().await = Instant::now(); @@ -121,6 +122,7 @@ impl Endpoint { let tunnel = tunnel.clone(); let stop_send = stop_send.clone(); let timer = last_active.clone(); + let name = name.to_string(); tokio::spawn(async move { let mut buf = [0u8; MAX_PKT_SIZE]; let mut wg_buf = [0u8; MAX_PKT_SIZE]; @@ -131,8 +133,9 @@ impl Endpoint { { Ok(true) => *timer.lock().await = Instant::now(), Ok(false) => {} - Err(_) => { + Err(e) => { let _ = stop_send.send(()); + tracing::trace!("[WireGuard] Close connection #{} for {}", name, e); return; } } @@ -229,6 +232,7 @@ impl Endpoint { }); } + let name_clone = name.to_string(); tokio::spawn(async move { // kill all coroutine let _ = stop_recv.recv().await; @@ -237,9 +241,13 @@ impl Endpoint { wg_in.abort(); wg_tick.abort(); smol_drive.abort(); + tracing::trace!("[WireGuard] connection #{} killed", name_clone); }); + tracing::info!("[WireGuard] Established master connection #{}", name); + Ok(Arc::new(Self { + name: name.to_string(), wg: tunnel, stack: smol_stack, stop_sender: stop_send, From 77da01eb2998c48c1e2577c39203a920b8a2d02f Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Tue, 1 Oct 2024 12:51:34 -0400 Subject: [PATCH 06/14] chore(debug): show internal wg stats (internal-test feature required) --- boltapi/src/rpc.rs | 6 ++- boltapi/src/schema.rs | 10 +++++ boltconn/src/adapter/wireguard.rs | 49 ++++++++++++++++++------- boltconn/src/cli/mod.rs | 26 +++++++++++-- boltconn/src/cli/request.rs | 21 +++++++++++ boltconn/src/cli/request_uds.rs | 6 ++- boltconn/src/external/controller.rs | 7 +++- boltconn/src/external/uds_controller.rs | 6 ++- boltconn/src/main.rs | 9 ----- boltconn/src/proxy/dispatcher.rs | 4 ++ boltconn/src/transport/wireguard.rs | 5 +++ 11 files changed, 118 insertions(+), 31 deletions(-) diff --git a/boltapi/src/rpc.rs b/boltapi/src/rpc.rs index c7f3b2c..a7dfa80 100644 --- a/boltapi/src/rpc.rs +++ b/boltapi/src/rpc.rs @@ -1,6 +1,6 @@ use crate::{ - ConnectionSchema, GetGroupRespSchema, GetInterceptDataResp, HttpInterceptSchema, TrafficResp, - TunStatusSchema, + ConnectionSchema, GetGroupRespSchema, GetInterceptDataResp, HttpInterceptSchema, + MasterConnectionStatus, TrafficResp, TunStatusSchema, }; pub const MAX_CODEC_FRAME_LENGTH: usize = 512 * 1024 * 1024; @@ -55,6 +55,8 @@ pub trait ControlService { async fn get_conn_log_limit() -> u32; + async fn get_master_conn_stat() -> Vec; + async fn reload(); // Streaming diff --git a/boltapi/src/schema.rs b/boltapi/src/schema.rs index c3aed87..3742ebb 100644 --- a/boltapi/src/schema.rs +++ b/boltapi/src/schema.rs @@ -118,3 +118,13 @@ pub struct TrafficResp { pub struct TunStatusSchema { pub enabled: bool, } + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(deny_unknown_fields)] +pub struct MasterConnectionStatus { + pub name: String, + pub alive: bool, + pub last_active: u64, + pub last_handshake: u64, + pub hand_shake_is_expired: bool, +} diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index aded020..7d6eb10 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -21,7 +21,6 @@ use hickory_resolver::AsyncResolver; use std::io; use std::io::ErrorKind; use std::net::{IpAddr, SocketAddr}; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::task::{ready, Context, Poll}; use std::time::{Duration, Instant}; @@ -36,7 +35,8 @@ pub struct Endpoint { stack: Arc>, stop_sender: broadcast::Sender<()>, notify: Arc, - is_active: Arc, + is_active: AbortCanary, + last_active: Arc>, } impl Endpoint { @@ -94,8 +94,7 @@ impl Endpoint { }; let last_active = Arc::new(Mutex::new(Instant::now())); - let indicator = Arc::new(AtomicBool::new(true)); - let indi_write = indicator.clone(); + let (indicator, indi_write) = AbortCanary::pair(); // drive wg tunnel let wg_out = { @@ -175,15 +174,15 @@ impl Endpoint { }; // drive smol - let smol_drive = { + { let smol_stack = smol_stack.clone(); let notifier = notify.clone(); - let (abort_canary, canary_clone) = AbortCanary::pair(); + let smol_canary = indicator.clone(); local_async_run(async move { let mut immediate_next_loop = false; notifier.notified().await; - while abort_canary.alive() { + while smol_canary.alive() { let mut stack_handle = smol_stack.lock().await; stack_handle.drive_iface(); immediate_next_loop |= stack_handle.poll_all_tcp().await; @@ -206,19 +205,20 @@ impl Endpoint { } immediate_next_loop = false; } + smol_canary.abort(); }); - canary_clone - }; + } // timeout inactive tunnel { let stop_send = stop_send.clone(); - let indi_write = indi_write.clone(); + let indi_write = indicator.clone(); let name = config.name.clone(); + let last_active = last_active.clone(); tokio::spawn(async move { loop { if last_active.lock().await.elapsed() > timeout { - indi_write.store(false, Ordering::Relaxed); + indi_write.abort(); let _ = stop_send.send(()); tracing::debug!( "[WireGuard] Stop inactive tunnel #{} after for {}s.", @@ -236,11 +236,10 @@ impl Endpoint { tokio::spawn(async move { // kill all coroutine let _ = stop_recv.recv().await; - indi_write.store(false, Ordering::Relaxed); + indi_write.abort(); wg_out.abort(); wg_in.abort(); wg_tick.abort(); - smol_drive.abort(); tracing::trace!("[WireGuard] connection #{} killed", name_clone); }); @@ -253,12 +252,24 @@ impl Endpoint { stop_sender: stop_send, notify, is_active: indicator, + last_active, })) } pub fn clone_notify(&self) -> Arc { self.notify.clone() } + + pub async fn debug_internal_state(&self) -> boltapi::MasterConnectionStatus { + let tunn_state = self.wg.stats().await; + boltapi::MasterConnectionStatus { + name: self.name.clone(), + alive: self.is_active.alive(), + last_active: self.last_active.lock().await.elapsed().as_secs(), + last_handshake: tunn_state.1.map(|x| x.as_secs()).unwrap_or(114514), + hand_shake_is_expired: tunn_state.0, + } + } } pub struct WireguardManager { @@ -290,7 +301,7 @@ impl WireguardManager { // get an existing conn, or create let mut guard = self.active_conn.lock().await; if let Some(endpoint) = guard.get(config) { - if endpoint.is_active.load(Ordering::Relaxed) { + if endpoint.is_active.alive() { let _ = ret_tx.send(false); return Ok(endpoint.clone()); } else { @@ -342,6 +353,16 @@ impl WireguardManager { "get_wg_conn: unexpected loop time", )) } + + pub async fn debug_internal_state(&self) -> Vec { + let conns = self.active_conn.lock().await; + let mut ret = Vec::new(); + for (_, v) in conns.iter() { + let r = v.debug_internal_state().await; + ret.push(r); + } + ret + } } #[derive(Clone)] diff --git a/boltconn/src/cli/mod.rs b/boltconn/src/cli/mod.rs index 7bc0a8c..4aed2ae 100644 --- a/boltconn/src/cli/mod.rs +++ b/boltconn/src/cli/mod.rs @@ -174,6 +174,12 @@ pub(crate) enum LogsLimitOptions { Get, } +#[derive(Debug, Clone, Copy, Subcommand)] +pub(crate) enum MasterConnOptions { + /// Show the WireGuard master connections + Wg, +} + #[derive(Debug, Subcommand)] pub(crate) enum SubCommand { /// Start the main program @@ -209,6 +215,10 @@ pub(crate) enum SubCommand { Generate(GenerateOptions), #[cfg(feature = "internal-test")] #[clap(hide = true)] + #[command(subcommand)] + MasterConn(MasterConnOptions), + #[cfg(feature = "internal-test")] + #[clap(hide = true)] Internal, } @@ -375,6 +385,9 @@ pub(crate) async fn controller_main(args: ProgramArgs) -> ! { DnsOptions::Lookup { domain_name } => requester.real_lookup(domain_name).await, DnsOptions::Mapping { fake_ip } => requester.fake_ip_to_real(fake_ip).await, }, + SubCommand::MasterConn(opt) => match opt { + MasterConnOptions::Wg => requester.master_conn_stats().await, + }, SubCommand::Start(_) | SubCommand::Generate(_) | SubCommand::Clean @@ -383,9 +396,7 @@ pub(crate) async fn controller_main(args: ProgramArgs) -> ! { unreachable!() } #[cfg(feature = "internal-test")] - SubCommand::Internal => { - unreachable!() - } + SubCommand::Internal => internal_code(requester).await, }; match result { Ok(_) => exit(0), @@ -395,3 +406,12 @@ pub(crate) async fn controller_main(args: ProgramArgs) -> ! { } } } + +#[cfg(feature = "internal-test")] +use crate::cli::request::Requester; +/// This function is a shortcut for testing things conveniently. Only for development use. +#[cfg(feature = "internal-test")] +async fn internal_code(_requester: Requester) -> anyhow::Result<()> { + println!("This option is not for end-user."); + Ok(()) +} diff --git a/boltconn/src/cli/request.rs b/boltconn/src/cli/request.rs index eb16458..50532e4 100644 --- a/boltconn/src/cli/request.rs +++ b/boltconn/src/cli/request.rs @@ -306,6 +306,27 @@ impl Requester { Inner::Uds(c) => c.reload_config().await, } } + + pub async fn master_conn_stats(&self) -> Result<()> { + match &self.inner { + Inner::Web(_) => Err(anyhow::anyhow!("conn-stats: Not supported by RESTful API")), + Inner::Uds(c) => { + let list = c.get_master_conn_stat().await?; + for entry in list { + let alive_str = |alive: bool| if alive { "alive" } else { "dead" }; + println!( + "{}:\t smol[{}, last active in {}s], wg=[{}, last handshake in {}s]", + entry.name, + alive_str(entry.alive), + entry.last_active, + alive_str(!entry.hand_shake_is_expired), + entry.last_handshake + ); + } + Ok(()) + } + } + } } fn pretty_size(data: u64) -> String { diff --git a/boltconn/src/cli/request_uds.rs b/boltconn/src/cli/request_uds.rs index 0d7c35e..fb06bd2 100644 --- a/boltconn/src/cli/request_uds.rs +++ b/boltconn/src/cli/request_uds.rs @@ -3,7 +3,7 @@ use boltapi::multiplex::rpc_multiplex_twoway; use boltapi::rpc::{ClientStreamServiceRequest, ClientStreamServiceResponse, ControlServiceClient}; use boltapi::{ ConnectionSchema, GetGroupRespSchema, GetInterceptDataResp, HttpInterceptSchema, - TunStatusSchema, + MasterConnectionStatus, TunStatusSchema, }; use tarpc::context::Context; use tarpc::tokio_util::codec::LengthDelimitedCodec; @@ -171,4 +171,8 @@ impl UdsConnector { pub async fn reload_config(&self) -> Result<()> { Ok(self.client.reload(Context::current()).await?) } + + pub async fn get_master_conn_stat(&self) -> Result> { + Ok(self.client.get_master_conn_stat(Context::current()).await?) + } } diff --git a/boltconn/src/external/controller.rs b/boltconn/src/external/controller.rs index 34084a4..b1fb2e0 100644 --- a/boltconn/src/external/controller.rs +++ b/boltconn/src/external/controller.rs @@ -9,7 +9,8 @@ use crate::proxy::{ }; use boltapi::{ ConnectionSchema, GetGroupRespSchema, GetInterceptDataResp, GetInterceptRangeReq, - HttpInterceptSchema, ProcessSchema, ProxyData, SessionSchema, TrafficResp, TunStatusSchema, + HttpInterceptSchema, MasterConnectionStatus, ProcessSchema, ProxyData, SessionSchema, + TrafficResp, TunStatusSchema, }; use std::collections::HashSet; use std::io::Write; @@ -389,6 +390,10 @@ impl Controller { self.stat_center.get_conn_log_limit() } + pub async fn get_master_conn_stat(&self) -> Vec { + self.dispatcher.get_wg_mgr().debug_internal_state().await + } + pub async fn real_lookup(&self, domain_name: String) -> Option { self.dns .genuine_lookup(domain_name.as_str()) diff --git a/boltconn/src/external/uds_controller.rs b/boltconn/src/external/uds_controller.rs index d06991d..5ac0603 100644 --- a/boltconn/src/external/uds_controller.rs +++ b/boltconn/src/external/uds_controller.rs @@ -5,7 +5,7 @@ use boltapi::multiplex::rpc_multiplex_twoway; use boltapi::rpc::{ClientStreamServiceClient, ControlService}; use boltapi::{ ConnectionSchema, GetGroupRespSchema, GetInterceptDataResp, GetInterceptRangeReq, - HttpInterceptSchema, TrafficResp, TunStatusSchema, + HttpInterceptSchema, MasterConnectionStatus, TrafficResp, TunStatusSchema, }; use std::io; use std::sync::Arc; @@ -343,6 +343,10 @@ impl ControlService for UdsRpcServer { self.controller.get_conn_log_limit() } + async fn get_master_conn_stat(self, _ctx: Context) -> Vec { + self.controller.get_master_conn_stat().await + } + async fn reload(self, _ctx: Context) { self.controller.reload().await } diff --git a/boltconn/src/main.rs b/boltconn/src/main.rs index 11b563b..0eafe55 100644 --- a/boltconn/src/main.rs +++ b/boltconn/src/main.rs @@ -53,8 +53,6 @@ fn main() -> ExitCode { let args: ProgramArgs = ProgramArgs::parse(); let cmds = match args.cmd { SubCommand::Start(sub) => sub, - #[cfg(feature = "internal-test")] - SubCommand::Internal => return internal_code(), _ => rt.block_on(cli::controller_main(args)), }; if !is_root() { @@ -137,10 +135,3 @@ pub(crate) fn process_path(cmds: &StartOptions) -> Result<(PathBuf, PathBuf, Pat } Ok((config_path, data_path, cert_path)) } - -/// This function is a shortcut for testing things conveniently. Only for development use. -#[cfg(feature = "internal-test")] -fn internal_code() -> ExitCode { - println!("This option is not for end-user."); - ExitCode::SUCCESS -} diff --git a/boltconn/src/proxy/dispatcher.rs b/boltconn/src/proxy/dispatcher.rs index 1c005b2..b43c22c 100644 --- a/boltconn/src/proxy/dispatcher.rs +++ b/boltconn/src/proxy/dispatcher.rs @@ -80,6 +80,10 @@ impl Dispatcher { self.modifier.store(Arc::new(closure)); } + pub fn get_wg_mgr(&self) -> Arc { + self.wireguard_mgr.clone() + } + pub(super) fn get_iface_name(&self) -> String { self.iface_name.clone() } diff --git a/boltconn/src/transport/wireguard.rs b/boltconn/src/transport/wireguard.rs index a666a4b..e9a73e3 100644 --- a/boltconn/src/transport/wireguard.rs +++ b/boltconn/src/transport/wireguard.rs @@ -316,6 +316,11 @@ impl WireguardTunnel { } } } + + pub async fn stats(&self) -> (bool, Option) { + let tun = self.tunnel.lock().await; + (tun.is_expired(), tun.stats().0) + } } impl WireguardTunnelInner { From 4d8070f5b39b06683016359c48058ee01f1deb30 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Tue, 1 Oct 2024 19:57:38 -0400 Subject: [PATCH 07/14] fix(wireguard): correctly close the connection when error occurs --- boltconn/src/proxy/tun_udp_inbound.rs | 4 +++- boltconn/src/transport/wireguard.rs | 27 +++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/boltconn/src/proxy/tun_udp_inbound.rs b/boltconn/src/proxy/tun_udp_inbound.rs index b946cde..d185496 100644 --- a/boltconn/src/proxy/tun_udp_inbound.rs +++ b/boltconn/src/proxy/tun_udp_inbound.rs @@ -109,7 +109,9 @@ impl TunUdpInbound { // hijack dns if let Ok(answer) = self.dns.respond_to_query(payload.as_ref()) { let raw_data = create_raw_udp_pkt(answer.as_ref(), dst, src); - let _ = self.tun_tx.send(raw_data.freeze()); + if let Err(e) = self.tun_tx.send_async(raw_data.freeze()).await { + tracing::error!("TUN back tx closed"); + } } } else { // retry once diff --git a/boltconn/src/transport/wireguard.rs b/boltconn/src/transport/wireguard.rs index e9a73e3..1a01f0a 100644 --- a/boltconn/src/transport/wireguard.rs +++ b/boltconn/src/transport/wireguard.rs @@ -151,6 +151,7 @@ impl WireguardTunnel { let socket = Arc::new(s); let socket_clone = socket.clone(); let pool = buf_pool.clone(); + let name = config.name.clone(); local_async_run(async move { // dedicated to poll UDP from small kernel buffer loop { @@ -166,8 +167,16 @@ impl WireguardTunnel { } None => { let mut buf = vec![0; MAX_UDP_PKT_SIZE]; - let Ok(len) = socket.recv(&mut buf).await else { - break; + let len = match socket.recv(&mut buf).await { + Ok(len) => len, + Err(e) => { + tracing::warn!( + "WireGuard #{} failed to receive: {}", + name, + e + ); + break; + } }; buf.resize(len, 0); BufferIndex::Raw(buf) @@ -181,9 +190,17 @@ impl WireguardTunnel { } } }); + let name = config.name.clone(); tokio::spawn(async move { while let Ok(data) = out_rx.recv_async().await { - socket_clone.send(&data).await?; + if let Err(e) = socket_clone.send(&data).await { + tracing::warn!( + "WireGuard #{} outbound send failed: {}", + name, + e + ); + return Err(e); + } } Ok::<(), io::Error>(()) }); @@ -340,7 +357,9 @@ impl WireguardTunnelInner { AdapterOrChannel::Channel(c, _) => { let data = Bytes::copy_from_slice(data); let len = data.len(); - let _ = c.send(data); + c.send_async(data) + .await + .map_err(|_| io_err("WireGuard outbound channel closed"))?; Ok(len) } } From 1004793945bb31ef8cd3530be2b942a42d62237a Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Tue, 1 Oct 2024 20:08:11 -0400 Subject: [PATCH 08/14] fix: compliation error about feature gates --- boltconn/src/cli/mod.rs | 1 + boltconn/src/proxy/tun_udp_inbound.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/boltconn/src/cli/mod.rs b/boltconn/src/cli/mod.rs index 4aed2ae..afcf6b0 100644 --- a/boltconn/src/cli/mod.rs +++ b/boltconn/src/cli/mod.rs @@ -385,6 +385,7 @@ pub(crate) async fn controller_main(args: ProgramArgs) -> ! { DnsOptions::Lookup { domain_name } => requester.real_lookup(domain_name).await, DnsOptions::Mapping { fake_ip } => requester.fake_ip_to_real(fake_ip).await, }, + #[cfg(feature = "internal-test")] SubCommand::MasterConn(opt) => match opt { MasterConnOptions::Wg => requester.master_conn_stats().await, }, diff --git a/boltconn/src/proxy/tun_udp_inbound.rs b/boltconn/src/proxy/tun_udp_inbound.rs index d185496..adc03c4 100644 --- a/boltconn/src/proxy/tun_udp_inbound.rs +++ b/boltconn/src/proxy/tun_udp_inbound.rs @@ -109,7 +109,7 @@ impl TunUdpInbound { // hijack dns if let Ok(answer) = self.dns.respond_to_query(payload.as_ref()) { let raw_data = create_raw_udp_pkt(answer.as_ref(), dst, src); - if let Err(e) = self.tun_tx.send_async(raw_data.freeze()).await { + if self.tun_tx.send_async(raw_data.freeze()).await.is_err() { tracing::error!("TUN back tx closed"); } } From 4e80001671a8584f7ca6593cf659f6e26a5759ab Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Wed, 2 Oct 2024 18:06:30 -0400 Subject: [PATCH 09/14] fix(smol): handle dropped errors --- boltconn/src/transport/smol.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/boltconn/src/transport/smol.rs b/boltconn/src/transport/smol.rs index 31d3964..cd3a8e6 100644 --- a/boltconn/src/transport/smol.rs +++ b/boltconn/src/transport/smol.rs @@ -143,7 +143,9 @@ impl TcpConnTask { // notify smol when new message comes tokio::spawn(async move { while let Some(buf) = back_rx.recv().await { - let _ = tx.send_async(buf).await; + if tx.send_async(buf).await.is_err() { + return; + } notify.notify_one(); } }); @@ -307,7 +309,9 @@ impl UdpConnTask { .await .map(|ip| SocketAddr::new(ip, port)), } { - let _ = tx.send_async((buf, dst)).await; + if tx.send_async((buf, dst)).await.is_err() { + return; + } notify.notify_one(); } } From ce278f7678eab35763b90677de21ae3b70237a7d Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Thu, 3 Oct 2024 13:07:04 -0400 Subject: [PATCH 10/14] fix(dns): handle IO errors --- boltconn/src/adapter/direct.rs | 2 +- boltconn/src/adapter/mod.rs | 15 ++-- boltconn/src/adapter/wireguard.rs | 8 +- boltconn/src/dispatch/action.rs | 2 +- boltconn/src/external/controller.rs | 8 +- boltconn/src/network/dns/dns.rs | 131 +++++++++++++++------------- boltconn/src/transport/smol.rs | 8 +- boltconn/src/transport/wireguard.rs | 2 +- 8 files changed, 95 insertions(+), 81 deletions(-) diff --git a/boltconn/src/adapter/direct.rs b/boltconn/src/adapter/direct.rs index e50f426..6f018a0 100644 --- a/boltconn/src/adapter/direct.rs +++ b/boltconn/src/adapter/direct.rs @@ -127,7 +127,7 @@ impl UdpSocketAdapter for DirectUdpAdapter { let addr = match addr { NetworkAddr::Raw(s) => s, NetworkAddr::DomainName { domain_name, port } => { - let Some(ip) = self.1.genuine_lookup(domain_name.as_str()).await else { + let Ok(Some(ip)) = self.1.genuine_lookup(domain_name.as_str()).await else { // drop return Ok(()); }; diff --git a/boltconn/src/adapter/mod.rs b/boltconn/src/adapter/mod.rs index d6d3466..9bd2680 100644 --- a/boltconn/src/adapter/mod.rs +++ b/boltconn/src/adapter/mod.rs @@ -426,10 +426,10 @@ async fn lookup(dns: &Dns, addr: &NetworkAddr) -> io::Result { ref domain_name, port, } => { - let resp = dns - .genuine_lookup(domain_name.as_str()) - .await - .ok_or_else(|| io_err("dns not found"))?; + let resp = match dns.genuine_lookup(domain_name.as_str()).await { + Ok(Some(resp)) => resp, + _ => return Err(io_err("dns not found")), + }; SocketAddr::new(resp, *port) } }) @@ -440,9 +440,10 @@ pub(super) async fn get_dst(dns: &Dns, dst: &NetworkAddr) -> io::Result { // translate fake ip SocketAddr::new( - dns.genuine_lookup(domain_name.as_str()) - .await - .ok_or_else(|| io_err("DNS failed"))?, + match dns.genuine_lookup(domain_name.as_str()).await { + Ok(Some(resp)) => resp, + _ => return Err(io_err("dns not found")), + }, *port, ) } diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index 7d6eb10..4436912 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -407,10 +407,10 @@ impl WireguardHandle { let dst = match self.dst { NetworkAddr::Raw(s) => s, NetworkAddr::DomainName { domain_name, port } => SocketAddr::new( - smol_dns - .genuine_lookup(domain_name.as_str()) - .await - .ok_or::(ErrorKind::AddrNotAvailable.into())?, + match smol_dns.genuine_lookup(domain_name.as_str()).await { + Ok(Some(addr)) => addr, + _ => return Err(ErrorKind::AddrNotAvailable.into()), + }, port, ), }; diff --git a/boltconn/src/dispatch/action.rs b/boltconn/src/dispatch/action.rs index 0f91e93..c4d048e 100644 --- a/boltconn/src/dispatch/action.rs +++ b/boltconn/src/dispatch/action.rs @@ -26,7 +26,7 @@ impl LocalResolve { pub async fn resolve_to(&self, info: &mut ConnInfo) { if info.resolved_dst.is_none() { if let NetworkAddr::DomainName { domain_name, port } = &info.dst { - if let Some(addr) = self.dns.genuine_lookup(domain_name).await { + if let Ok(Some(addr)) = self.dns.genuine_lookup(domain_name).await { info.resolved_dst = Some(SocketAddr::new(addr, *port)); } } diff --git a/boltconn/src/external/controller.rs b/boltconn/src/external/controller.rs index b1fb2e0..2acf19b 100644 --- a/boltconn/src/external/controller.rs +++ b/boltconn/src/external/controller.rs @@ -395,10 +395,10 @@ impl Controller { } pub async fn real_lookup(&self, domain_name: String) -> Option { - self.dns - .genuine_lookup(domain_name.as_str()) - .await - .map(|ip| ip.to_string()) + match self.dns.genuine_lookup(domain_name.as_str()).await { + Ok(Some(ip)) => Some(ip.to_string()), + _ => None, + } } pub fn fake_ip_to_real(&self, fake_ip: String) -> Option { diff --git a/boltconn/src/network/dns/dns.rs b/boltconn/src/network/dns/dns.rs index 704ddc5..004768f 100644 --- a/boltconn/src/network/dns/dns.rs +++ b/boltconn/src/network/dns/dns.rs @@ -3,19 +3,55 @@ use crate::network::dns::dns_table::DnsTable; use crate::network::dns::hosts::HostsResolver; use crate::network::dns::ns_policy::{DispatchedDnsResolver, NameserverPolicies}; use crate::network::dns::provider::IfaceProvider; +use crate::proxy::error::TransportError; use arc_swap::ArcSwap; use hickory_proto::op::{Message, MessageType, ResponseCode}; use hickory_proto::rr::{DNSClass, RData, Record, RecordType}; use hickory_resolver::config::*; +use hickory_resolver::error::ResolveErrorKind; use hickory_resolver::name_server::{GenericConnector, RuntimeProvider}; use hickory_resolver::AsyncResolver; use std::collections::HashMap; use std::io; -use std::io::Result; use std::net::IpAddr; use std::sync::Arc; use std::time::Duration; +macro_rules! impl_genuine_lookup { + ($func_name:ident, $lookup_type:ident) => { + async fn $func_name( + name: &str, + domain_name: &str, + resolver: &AsyncResolver>, + ) -> Result, TransportError> { + const TIMEOUT_SEC: u64 = 5; + if let Ok(r) = tokio::time::timeout( + Duration::from_secs(TIMEOUT_SEC), + resolver.$lookup_type(domain_name), + ) + .await + { + return match r { + Ok(result) => Ok(result.iter().next().map(|i| i.0.into())), + Err(e) => match e.kind().clone() { + ResolveErrorKind::Io(err) => Err(TransportError::Io(err)), + _ => Ok(None), + }, + }; + } else { + tracing::debug!( + "DNS {} {} lookup for {} timeout: {}s", + name, + stringify!($lookup_type), + domain_name, + TIMEOUT_SEC + ); + } + Ok(None) + } + }; +} + pub struct GenericDns { name: String, table: DnsTable, @@ -112,67 +148,31 @@ impl GenericDns

{ }) } - async fn genuine_lookup_v4(&self, domain_name: &str) -> Option { + async fn genuine_lookup_v4(&self, domain_name: &str) -> Result, TransportError> { for r in self.resolvers.load().iter() { - if let Some(ip) = Self::genuine_lookup_one_v4(&self.name, domain_name, r).await { - return Some(ip); + if let Some(ip) = Self::genuine_lookup_one_v4(&self.name, domain_name, r).await? { + return Ok(Some(ip)); } } - None + Ok(None) } - - async fn genuine_lookup_one_v4( - name: &str, - domain_name: &str, - resolver: &AsyncResolver>, - ) -> Option { - if let Ok(r) = - tokio::time::timeout(Duration::from_secs(5), resolver.ipv4_lookup(domain_name)).await - { - if let Ok(result) = r { - if let Some(i) = result.iter().next() { - return Some(i.0.into()); - } - } - } else { - tracing::debug!("DNS {name} v4 lookup for {domain_name} timeout: 5s"); - } - None - } - - async fn genuine_lookup_v6(&self, domain_name: &str) -> Option { + async fn genuine_lookup_v6(&self, domain_name: &str) -> Result, TransportError> { for r in self.resolvers.load().iter() { - if let Some(ip) = Self::genuine_lookup_one_v6(&self.name, domain_name, r).await { - return Some(ip); + if let Some(ip) = Self::genuine_lookup_one_v6(&self.name, domain_name, r).await? { + return Ok(Some(ip)); } } - None + Ok(None) } - async fn genuine_lookup_one_v6( - name: &str, - domain_name: &str, - resolver: &AsyncResolver>, - ) -> Option { - if let Ok(r) = - tokio::time::timeout(Duration::from_secs(5), resolver.ipv6_lookup(domain_name)).await - { - if let Ok(result) = r { - if let Some(i) = result.iter().next() { - return Some(i.0.into()); - } - } - } else { - tracing::debug!("DNS {name} v6 lookup for {domain_name} timeout: 5s"); - } - None - } + impl_genuine_lookup!(genuine_lookup_one_v4, ipv4_lookup); + impl_genuine_lookup!(genuine_lookup_one_v6, ipv6_lookup); async fn one_v4_wrapper( name: &str, domain_name: &str, resolver: &DispatchedDnsResolver, - ) -> Option { + ) -> Result, TransportError> { match resolver { DispatchedDnsResolver::Iface(resolver) => { Self::genuine_lookup_one_v4(name, domain_name, resolver).await @@ -187,7 +187,7 @@ impl GenericDns

{ name: &str, domain_name: &str, resolver: &DispatchedDnsResolver, - ) -> Option { + ) -> Result, TransportError> { match resolver { DispatchedDnsResolver::Iface(resolver) => { Self::genuine_lookup_one_v6(name, domain_name, resolver).await @@ -198,7 +198,10 @@ impl GenericDns

{ } } - pub async fn genuine_lookup(&self, domain_name: &str) -> Option { + pub async fn genuine_lookup( + &self, + domain_name: &str, + ) -> Result, TransportError> { self.genuine_lookup_with(domain_name, self.preference).await } @@ -206,12 +209,12 @@ impl GenericDns

{ &self, domain_name: &str, pref: DnsPreference, - ) -> Option { + ) -> Result, TransportError> { if let Some(ip) = self.host_resolver.load().resolve(domain_name) { if (matches!(pref, DnsPreference::Ipv6Only) && ip.is_ipv6()) || (matches!(pref, DnsPreference::Ipv4Only) && ip.is_ipv4()) { - return Some(ip); + return Ok(Some(ip)); } } if let Some(resolver) = self.ns_policy.load().resolve(domain_name) { @@ -223,15 +226,19 @@ impl GenericDns

{ Self::one_v6_wrapper(&self.name, domain_name, resolver).await } DnsPreference::PreferIpv4 => { - if let Some(a) = Self::one_v4_wrapper(&self.name, domain_name, resolver).await { - Some(a) + if let Ok(Some(a)) = + Self::one_v4_wrapper(&self.name, domain_name, resolver).await + { + Ok(Some(a)) } else { Self::one_v6_wrapper(&self.name, domain_name, resolver).await } } DnsPreference::PreferIpv6 => { - if let Some(a) = Self::one_v6_wrapper(&self.name, domain_name, resolver).await { - Some(a) + if let Ok(Some(a)) = + Self::one_v6_wrapper(&self.name, domain_name, resolver).await + { + Ok(Some(a)) } else { Self::one_v4_wrapper(&self.name, domain_name, resolver).await } @@ -242,15 +249,15 @@ impl GenericDns

{ DnsPreference::Ipv4Only => self.genuine_lookup_v4(domain_name).await, DnsPreference::Ipv6Only => self.genuine_lookup_v6(domain_name).await, DnsPreference::PreferIpv4 => { - if let Some(a) = self.genuine_lookup_v4(domain_name).await { - Some(a) + if let Ok(Some(a)) = self.genuine_lookup_v4(domain_name).await { + Ok(Some(a)) } else { self.genuine_lookup_v6(domain_name).await } } DnsPreference::PreferIpv6 => { - if let Some(a) = self.genuine_lookup_v6(domain_name).await { - Some(a) + if let Ok(Some(a)) = self.genuine_lookup_v6(domain_name).await { + Ok(Some(a)) } else { self.genuine_lookup_v4(domain_name).await } @@ -263,6 +270,8 @@ impl GenericDns

{ if let Some(record) = self.table.query_by_ip(fake_ip) { self.genuine_lookup(&record.domain_name) .await + .ok() + .flatten() .unwrap_or(fake_ip) } else { tracing::debug!("Failed to extract fake_ip: {}", fake_ip); @@ -270,7 +279,7 @@ impl GenericDns

{ } } - pub fn respond_to_query(&self, pkt: &[u8]) -> Result> { + pub fn respond_to_query(&self, pkt: &[u8]) -> io::Result> { // https://stackoverflow.com/questions/55092830/how-to-perform-dns-lookup-with-multiple-questions // There should be no >1 questions in on query let err = Err(io::Error::new(io::ErrorKind::InvalidData, "fail to answer")); diff --git a/boltconn/src/transport/smol.rs b/boltconn/src/transport/smol.rs index cd3a8e6..434e5a5 100644 --- a/boltconn/src/transport/smol.rs +++ b/boltconn/src/transport/smol.rs @@ -298,7 +298,7 @@ impl UdpConnTask { None } } - NetworkAddr::DomainName { domain_name, port } => dns + NetworkAddr::DomainName { domain_name, port } => match dns .genuine_lookup_with( domain_name.as_str(), match socket_version { @@ -307,7 +307,11 @@ impl UdpConnTask { }, ) .await - .map(|ip| SocketAddr::new(ip, port)), + { + Ok(Some(ip)) => Some(SocketAddr::new(ip, port)), + Ok(None) => None, + Err(_) => return, + }, } { if tx.send_async((buf, dst)).await.is_err() { return; diff --git a/boltconn/src/transport/wireguard.rs b/boltconn/src/transport/wireguard.rs index 1a01f0a..4be8821 100644 --- a/boltconn/src/transport/wireguard.rs +++ b/boltconn/src/transport/wireguard.rs @@ -109,7 +109,7 @@ impl WireguardTunnel { } => { let resp = dns .genuine_lookup(domain_name) - .await + .await? .ok_or_else(|| io_err("dns not found"))?; SocketAddr::new(resp, port) } From 9fcacc4a59fbd873330ab4aad2c6fb123dad9d3a Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Thu, 3 Oct 2024 14:03:44 -0400 Subject: [PATCH 11/14] chore(adapter): change trait error type --- boltconn/src/adapter/chain.rs | 16 +++--- boltconn/src/adapter/direct.rs | 21 ++++--- boltconn/src/adapter/http.rs | 31 +++++------ boltconn/src/adapter/mod.rs | 19 +++---- boltconn/src/adapter/shadowsocks.rs | 22 ++++---- boltconn/src/adapter/socks5.rs | 20 +++---- boltconn/src/adapter/ssh.rs | 18 +++--- boltconn/src/adapter/trojan.rs | 26 ++++----- boltconn/src/adapter/wireguard.rs | 33 ++++++----- boltconn/src/proxy/error.rs | 4 ++ boltconn/src/transport/smol.rs | 85 +++++++++++------------------ 11 files changed, 134 insertions(+), 161 deletions(-) diff --git a/boltconn/src/adapter/chain.rs b/boltconn/src/adapter/chain.rs index 900e4a3..cd7b82b 100644 --- a/boltconn/src/adapter/chain.rs +++ b/boltconn/src/adapter/chain.rs @@ -3,11 +3,11 @@ use crate::adapter::{ UdpTransferType, }; use async_trait::async_trait; -use std::io; use std::sync::Arc; use crate::common::duplex_chan::DuplexChan; use crate::common::StreamOutboundTrait; +use crate::proxy::error::TransportError; use crate::proxy::ConnAbortHandle; use crate::transport::UdpSocketAdapter; use tokio::task::JoinHandle; @@ -32,7 +32,7 @@ impl ChainOutbound { mut inbound_tcp_container: Option, mut inbound_udp_container: Option, abort_handle: ConnAbortHandle, - ) -> JoinHandle> { + ) -> JoinHandle> { tokio::spawn(async move { let mut not_first_jump = false; let mut need_next_jump = true; @@ -136,7 +136,7 @@ impl Outbound for ChainOutbound { &self, inbound: Connector, abort_handle: ConnAbortHandle, - ) -> JoinHandle> { + ) -> JoinHandle> { self.clone().spawn(true, Some(inbound), None, abort_handle) } @@ -146,9 +146,9 @@ impl Outbound for ChainOutbound { _tcp_outbound: Option>, _udp_outbound: Option>, _abort_handle: ConnAbortHandle, - ) -> io::Result { + ) -> Result { tracing::error!("spawn_tcp_with_outbound() should not be called with ChainOutbound"); - return Err(io::ErrorKind::InvalidData.into()); + Err(TransportError::Internal("Invalid outbound")) } fn spawn_udp( @@ -156,7 +156,7 @@ impl Outbound for ChainOutbound { inbound: AddrConnector, abort_handle: ConnAbortHandle, _tunnel_only: bool, - ) -> JoinHandle> { + ) -> JoinHandle> { self.clone().spawn(false, None, Some(inbound), abort_handle) } @@ -167,8 +167,8 @@ impl Outbound for ChainOutbound { _udp_outbound: Option>, _abort_handle: ConnAbortHandle, _tunnel_only: bool, - ) -> io::Result { + ) -> Result { tracing::error!("spawn_udp_with_outbound() should not be called with ChainUdpOutbound"); - return Err(io::ErrorKind::InvalidData.into()); + Err(TransportError::Internal("Invalod outbound")) } } diff --git a/boltconn/src/adapter/direct.rs b/boltconn/src/adapter/direct.rs index 6f018a0..f39d1cf 100644 --- a/boltconn/src/adapter/direct.rs +++ b/boltconn/src/adapter/direct.rs @@ -8,7 +8,6 @@ use crate::proxy::error::TransportError; use crate::proxy::{ConnAbortHandle, NetworkAddr}; use crate::transport::UdpSocketAdapter; use async_trait::async_trait; -use std::io; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::UdpSocket; @@ -37,7 +36,11 @@ impl DirectOutbound { } } - async fn run_tcp(self, inbound: Connector, abort_handle: ConnAbortHandle) -> io::Result<()> { + async fn run_tcp( + self, + inbound: Connector, + abort_handle: ConnAbortHandle, + ) -> Result<(), TransportError> { let dst_addr = if let Some(dst) = self.resolved_dst { dst } else { @@ -53,7 +56,7 @@ impl DirectOutbound { self, inbound: AddrConnector, abort_handle: ConnAbortHandle, - ) -> io::Result<()> { + ) -> Result<(), TransportError> { let outbound = Arc::new(Egress::new(&self.iface_name).udpv4_socket().await?); established_udp( self.id(), @@ -81,7 +84,7 @@ impl Outbound for DirectOutbound { &self, inbound: Connector, abort_handle: ConnAbortHandle, - ) -> JoinHandle> { + ) -> JoinHandle> { tokio::spawn(self.clone().run_tcp(inbound, abort_handle)) } @@ -91,9 +94,9 @@ impl Outbound for DirectOutbound { _tcp_outbound: Option>, _udp_outbound: Option>, _abort_handle: ConnAbortHandle, - ) -> io::Result { + ) -> Result { tracing::error!("spawn_tcp_with_outbound() should not be called with DirectOutbound"); - return Err(io::ErrorKind::InvalidData.into()); + Err(TransportError::Internal("Invalid outbound")) } fn spawn_udp( @@ -101,7 +104,7 @@ impl Outbound for DirectOutbound { inbound: AddrConnector, abort_handle: ConnAbortHandle, _tunnel_only: bool, - ) -> JoinHandle> { + ) -> JoinHandle> { tokio::spawn(self.clone().run_udp(inbound, abort_handle)) } @@ -112,9 +115,9 @@ impl Outbound for DirectOutbound { _udp_outbound: Option>, _abort_handle: ConnAbortHandle, _tunnel_only: bool, - ) -> io::Result { + ) -> Result { tracing::error!("spawn_udp_with_outbound() should not be called with DirectOutbound"); - return Err(io::ErrorKind::InvalidData.into()); + Err(TransportError::Internal("Invalid outbound")) } } diff --git a/boltconn/src/adapter/http.rs b/boltconn/src/adapter/http.rs index e3a0914..899dd75 100644 --- a/boltconn/src/adapter/http.rs +++ b/boltconn/src/adapter/http.rs @@ -6,12 +6,12 @@ use crate::common::{io_err, StreamOutboundTrait}; use crate::config::AuthData; use crate::network::dns::Dns; use crate::network::egress::Egress; +use crate::proxy::error::TransportError; use crate::proxy::{ConnAbortHandle, NetworkAddr}; use crate::transport::UdpSocketAdapter; use async_trait::async_trait; use base64::Engine; use httparse::Response; -use std::io; use std::sync::Arc; use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; use tokio::task::JoinHandle; @@ -53,7 +53,7 @@ impl HttpOutbound { inbound: Connector, mut outbound: S, abort_handle: ConnAbortHandle, - ) -> io::Result<()> + ) -> Result<(), TransportError> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -79,25 +79,25 @@ impl HttpOutbound { let mut resp = String::new(); while !resp.ends_with("\r\n\r\n") { if buf_reader.read_line(&mut resp).await? == 0 { - return Err(io_err("EOF")); + return Err(TransportError::Http("EOF")); } if resp.len() > 4096 { - return Err(io_err("Too long resp")); + return Err(TransportError::Http("Response too long")); } } let mut buf = [httparse::EMPTY_HEADER; 16]; let mut resp_struct = Response::new(buf.as_mut()); resp_struct .parse(resp.as_bytes()) - .map_err(|_| io_err("Parse response failed"))?; + .map_err(|_| TransportError::Http("Parsing failed"))?; if let Some(200) = resp_struct.code { let tcp_stream = buf_reader.into_inner(); established_tcp(self.name, inbound, tcp_stream, abort_handle).await; Ok(()) } else { - Err(io_err( + Err(TransportError::Io(io_err( format!("Http Connect Failed: {:?}", resp_struct.code).as_str(), - )) + ))) } } } @@ -116,7 +116,7 @@ impl Outbound for HttpOutbound { &self, inbound: Connector, abort_handle: ConnAbortHandle, - ) -> JoinHandle> { + ) -> JoinHandle> { let self_clone = self.clone(); tokio::spawn(async move { let server_addr = @@ -124,10 +124,7 @@ impl Outbound for HttpOutbound { let tcp_stream = Egress::new(&self_clone.iface_name) .tcp_stream(server_addr) .await?; - self_clone - .run_tcp(inbound, tcp_stream, abort_handle) - .await - .map_err(|e| io_err(e.to_string().as_str())) + self_clone.run_tcp(inbound, tcp_stream, abort_handle).await }) } @@ -137,10 +134,10 @@ impl Outbound for HttpOutbound { tcp_outbound: Option>, udp_outbound: Option>, abort_handle: ConnAbortHandle, - ) -> io::Result { + ) -> Result { if tcp_outbound.is_none() || udp_outbound.is_some() { tracing::error!("Invalid HTTP proxy tcp spawn"); - return Err(io::ErrorKind::InvalidData.into()); + return Err(TransportError::Internal("Invalid outbound")); } let self_clone = self.clone(); tokio::spawn(async move { @@ -157,7 +154,7 @@ impl Outbound for HttpOutbound { _inbound: AddrConnector, _abort_handle: ConnAbortHandle, _tunnel_only: bool, - ) -> JoinHandle> { + ) -> JoinHandle> { tracing::error!("spawn_udp() should not be called with HttpOutbound"); empty_handle() } @@ -169,8 +166,8 @@ impl Outbound for HttpOutbound { _udp_outbound: Option>, _abort_handle: ConnAbortHandle, _tunnel_only: bool, - ) -> io::Result { + ) -> Result { tracing::error!("spawn_udp_with_outbound() should not be called with HttpOutbound"); - return Err(io::ErrorKind::InvalidData.into()); + Err(TransportError::Internal("Invalid outbound")) } } diff --git a/boltconn/src/adapter/mod.rs b/boltconn/src/adapter/mod.rs index 9bd2680..aeec73f 100644 --- a/boltconn/src/adapter/mod.rs +++ b/boltconn/src/adapter/mod.rs @@ -36,7 +36,6 @@ pub use direct::*; pub use socks5::*; pub use ssh::*; use std::future::Future; -use std::io::ErrorKind; pub use tcp_adapter::*; pub use trojan::*; pub use udp_adapter::*; @@ -178,7 +177,7 @@ pub trait Outbound: Send + Sync { &self, inbound: Connector, abort_handle: ConnAbortHandle, - ) -> JoinHandle>; + ) -> JoinHandle>; /// Return whether outbound is used async fn spawn_tcp_with_outbound( @@ -187,7 +186,7 @@ pub trait Outbound: Send + Sync { tcp_outbound: Option>, udp_outbound: Option>, abort_handle: ConnAbortHandle, - ) -> io::Result; + ) -> Result; /// Run with tokio::spawn. fn spawn_udp( @@ -195,7 +194,7 @@ pub trait Outbound: Send + Sync { inbound: AddrConnector, abort_handle: ConnAbortHandle, tunnel_only: bool, - ) -> JoinHandle>; + ) -> JoinHandle>; /// Return whether outbound is used async fn spawn_udp_with_outbound( @@ -205,11 +204,11 @@ pub trait Outbound: Send + Sync { udp_outbound: Option>, abort_handle: ConnAbortHandle, tunnel_only: bool, - ) -> io::Result; + ) -> Result; } -fn empty_handle() -> JoinHandle> { - tokio::spawn(async move { Err(io_err("Invalid spawn")) }) +fn empty_handle() -> JoinHandle> { + tokio::spawn(async move { Err(TransportError::Internal("Invalid spawn")) }) } #[tracing::instrument(skip_all)] @@ -451,14 +450,14 @@ pub(super) async fn get_dst(dns: &Dns, dst: &NetworkAddr) -> io::Result>>( +pub(super) async fn connect_timeout>>( future: F, component_str: &str, -) -> io::Result<()> { +) -> Result<(), TransportError> { tokio::time::timeout(Duration::from_secs(10), future) .await .unwrap_or_else(|_| { tracing::debug!("{} timeout after 10s", component_str); - Err(ErrorKind::TimedOut.into()) + Err(TransportError::Timeout("connect")) }) } diff --git a/boltconn/src/adapter/shadowsocks.rs b/boltconn/src/adapter/shadowsocks.rs index 6a91601..c711033 100644 --- a/boltconn/src/adapter/shadowsocks.rs +++ b/boltconn/src/adapter/shadowsocks.rs @@ -2,7 +2,7 @@ use crate::adapter::{ established_tcp, established_udp, lookup, AddrConnector, Connector, Outbound, OutboundType, }; -use crate::common::{io_err, StreamOutboundTrait}; +use crate::common::StreamOutboundTrait; use crate::network::dns::Dns; use crate::network::egress::Egress; use crate::proxy::error::TransportError; @@ -101,7 +101,7 @@ impl SSOutbound { outbound: S, server_addr: SocketAddr, abort_handle: ConnAbortHandle, - ) -> io::Result<()> + ) -> Result<(), TransportError> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -119,7 +119,7 @@ impl SSOutbound { server_addr: SocketAddr, abort_handle: ConnAbortHandle, tunnel_only: bool, - ) -> io::Result<()> { + ) -> Result<(), TransportError> { let (_, context, resolved_config) = self.create_internal(server_addr).await; let proxy_socket = ShadowsocksUdpAdapter::new(context, &resolved_config, adapter_or_socket); established_udp( @@ -148,7 +148,7 @@ impl Outbound for SSOutbound { &self, inbound: Connector, abort_handle: ConnAbortHandle, - ) -> JoinHandle> { + ) -> JoinHandle> { let self_clone = self.clone(); tokio::spawn(async move { let server_addr = self_clone.get_server_addr().await?; @@ -167,10 +167,10 @@ impl Outbound for SSOutbound { tcp_outbound: Option>, udp_outbound: Option>, abort_handle: ConnAbortHandle, - ) -> io::Result { + ) -> Result { if tcp_outbound.is_none() || udp_outbound.is_some() { tracing::error!("Invalid Shadowsocks tcp spawn"); - return Err(io::ErrorKind::InvalidData.into()); + return Err(TransportError::Internal("Invalid outbound")); } let self_clone = self.clone(); tokio::spawn(async move { @@ -187,14 +187,16 @@ impl Outbound for SSOutbound { inbound: AddrConnector, abort_handle: ConnAbortHandle, tunnel_only: bool, - ) -> JoinHandle> { + ) -> JoinHandle> { let self_clone = self.clone(); tokio::spawn(async move { let server_addr = self_clone.get_server_addr().await?; let out_sock = { let socket = match server_addr { SocketAddr::V4(_) => Egress::new(&self_clone.iface_name).udpv4_socket().await?, - SocketAddr::V6(_) => return Err(io_err("ss ipv6 udp not supported now")), + SocketAddr::V6(_) => { + return Err(TransportError::Internal("IPv6 not supported")) + } }; socket.connect(server_addr).await?; socket @@ -218,10 +220,10 @@ impl Outbound for SSOutbound { udp_outbound: Option>, abort_handle: ConnAbortHandle, tunnel_only: bool, - ) -> io::Result { + ) -> Result { if tcp_outbound.is_some() || udp_outbound.is_none() { tracing::error!("Invalid Shadowsocks UDP outbound ancestor"); - return Err(io::ErrorKind::InvalidData.into()); + return Err(TransportError::Internal("Invalid outbound")); } let udp_outbound = udp_outbound.unwrap(); let self_clone = self.clone(); diff --git a/boltconn/src/adapter/socks5.rs b/boltconn/src/adapter/socks5.rs index 3bdd1ae..216e1cc 100644 --- a/boltconn/src/adapter/socks5.rs +++ b/boltconn/src/adapter/socks5.rs @@ -2,7 +2,7 @@ use crate::adapter::{ established_tcp, established_udp, lookup, AddrConnector, Connector, Outbound, OutboundType, }; -use crate::common::{as_io_err, io_err, StreamOutboundTrait}; +use crate::common::{as_io_err, StreamOutboundTrait}; use crate::config::AuthData; use crate::network::dns::Dns; use crate::network::egress::Egress; @@ -81,7 +81,7 @@ impl Socks5Outbound { inbound: Connector, outbound: S, abort_handle: ConnAbortHandle, - ) -> io::Result<()> + ) -> Result<(), TransportError> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -101,7 +101,7 @@ impl Socks5Outbound { outbound: S, abort_handle: ConnAbortHandle, tunnel_only: bool, - ) -> io::Result<()> + ) -> Result<(), TransportError> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -115,7 +115,7 @@ impl Socks5Outbound { let mut socks_stream = self.connect_proxy(outbound).await?; let out_sock = Arc::new(match server_addr { SocketAddr::V4(_) => Egress::new(&self.iface_name).udpv4_socket().await?, - SocketAddr::V6(_) => return Err(io_err("udp v6 not supported")), + SocketAddr::V6(_) => return Err(TransportError::Socks5Extra("IPv6 not supported")), }); let bound_addr = socks_stream .request(Socks5Command::UDPAssociate, target.clone()) @@ -151,7 +151,7 @@ impl Outbound for Socks5Outbound { &self, inbound: Connector, abort_handle: ConnAbortHandle, - ) -> JoinHandle> { + ) -> JoinHandle> { let self_clone = self.clone(); tokio::spawn(async move { let server_addr = @@ -169,10 +169,10 @@ impl Outbound for Socks5Outbound { tcp_outbound: Option>, udp_outbound: Option>, abort_handle: ConnAbortHandle, - ) -> io::Result { + ) -> Result { if tcp_outbound.is_none() || udp_outbound.is_some() { tracing::error!("Invalid Socks5 tcp spawn"); - return Err(io::ErrorKind::InvalidData.into()); + return Err(TransportError::Internal("Invalid outbound")); } let self_clone = self.clone(); tokio::spawn(async move { @@ -188,7 +188,7 @@ impl Outbound for Socks5Outbound { inbound: AddrConnector, abort_handle: ConnAbortHandle, tunnel_only: bool, - ) -> JoinHandle> { + ) -> JoinHandle> { let self_clone = self.clone(); tokio::spawn(async move { let server_addr = @@ -209,9 +209,9 @@ impl Outbound for Socks5Outbound { _udp_outbound: Option>, _abort_handle: ConnAbortHandle, _tunnel_only: bool, - ) -> io::Result { + ) -> Result { tracing::error!("Socks5 does not support UDP chain"); - return Err(io::ErrorKind::InvalidData.into()); + Err(TransportError::Internal("Invalid outbound")) } } diff --git a/boltconn/src/adapter/ssh.rs b/boltconn/src/adapter/ssh.rs index 6fc7429..6cb37f0 100644 --- a/boltconn/src/adapter/ssh.rs +++ b/boltconn/src/adapter/ssh.rs @@ -11,8 +11,6 @@ use crate::transport::ssh::{SshConfig, SshTunnel}; use crate::transport::UdpSocketAdapter; use async_trait::async_trait; use std::collections::HashMap; -use std::io; -use std::io::ErrorKind; use std::sync::Arc; use std::time::Duration; use tokio::task::JoinHandle; @@ -97,7 +95,7 @@ impl Outbound for SshOutboundHandle { &self, inbound: Connector, abort_handle: ConnAbortHandle, - ) -> JoinHandle> { + ) -> JoinHandle> { let (tx, _) = tokio::sync::oneshot::channel(); let self_clone = self.clone(); tokio::spawn(async move { @@ -105,7 +103,7 @@ impl Outbound for SshOutboundHandle { let r = self_clone.attach_tcp(inbound, None, abort_handle, tx).await; if let Err(e) = r { abort_handle2.cancel(); - return Err(io_err(format!("SSH TCP spawn error: {:?}", e).as_str())); + return Err(e); } Ok(()) }) @@ -117,10 +115,10 @@ impl Outbound for SshOutboundHandle { tcp_outbound: Option>, udp_outbound: Option>, abort_handle: ConnAbortHandle, - ) -> std::io::Result { + ) -> Result { if tcp_outbound.is_none() || udp_outbound.is_some() { tracing::error!("Invalid SSH proxy tcp spawn"); - return Err(io::ErrorKind::InvalidData.into()); + return Err(TransportError::Internal("Invalid outbound")); } let (comp_tx, comp_rx) = tokio::sync::oneshot::channel(); let self_clone = self.clone(); @@ -137,7 +135,7 @@ impl Outbound for SshOutboundHandle { }); comp_rx .await - .map_err(|_| ErrorKind::ConnectionAborted.into()) + .map_err(|_| TransportError::ShadowSocks("Aborted")) } fn spawn_udp( @@ -145,7 +143,7 @@ impl Outbound for SshOutboundHandle { _inbound: AddrConnector, _abort_handle: ConnAbortHandle, _tunnel_only: bool, - ) -> JoinHandle> { + ) -> JoinHandle> { tracing::error!("spawn_udp() should not be called with SshOutbound"); empty_handle() } @@ -157,9 +155,9 @@ impl Outbound for SshOutboundHandle { _udp_outbound: Option>, _abort_handle: ConnAbortHandle, _tunnel_only: bool, - ) -> std::io::Result { + ) -> Result { tracing::error!("spawn_udp() should not be called with SshOutbound"); - Err(io::ErrorKind::InvalidData.into()) + Err(TransportError::Internal("Invalid outbound")) } } diff --git a/boltconn/src/adapter/trojan.rs b/boltconn/src/adapter/trojan.rs index dbae0c1..ed241f7 100644 --- a/boltconn/src/adapter/trojan.rs +++ b/boltconn/src/adapter/trojan.rs @@ -56,17 +56,14 @@ impl TrojanOutbound { mut inbound: Connector, outbound: S, abort_handle: ConnAbortHandle, - ) -> io::Result<()> + ) -> Result<(), TransportError> where S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, { let mut stream = self.connect_proxy(outbound).await?; let first_packet = inbound.rx.recv().await.ok_or_else(|| io_err("No resp"))?; if let Some(ref uri) = self.config.websocket_path { - let mut stream = self - .with_websocket(stream, uri.as_str()) - .await - .map_err(|e| io_err(e.to_string().as_str()))?; + let mut stream = self.with_websocket(stream, uri.as_str()).await?; self.first_packet(first_packet, TrojanCmd::Connect, &mut stream) .await?; established_tcp(self.name, inbound, stream, abort_handle).await; @@ -84,7 +81,7 @@ impl TrojanOutbound { outbound: S, abort_handle: ConnAbortHandle, tunnel_only: bool, - ) -> io::Result<()> + ) -> Result<(), TransportError> where S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, { @@ -92,10 +89,7 @@ impl TrojanOutbound { let (data, dst) = inbound.rx.recv().await.ok_or_else(|| io_err("No resp"))?; let first_packet = Bytes::from(encapsule_udp_packet(data.as_ref(), dst)); if let Some(ref uri) = self.config.websocket_path { - let mut stream = self - .with_websocket(stream, uri.as_str()) - .await - .map_err(|e| io_err(e.to_string().as_str()))?; + let mut stream = self.with_websocket(stream, uri.as_str()).await?; self.first_packet(first_packet, TrojanCmd::Associate, &mut stream) .await?; let udp_socket = TrojanUdpSocket::bind(stream); @@ -195,7 +189,7 @@ impl Outbound for TrojanOutbound { &self, inbound: Connector, abort_handle: ConnAbortHandle, - ) -> JoinHandle> { + ) -> JoinHandle> { let self_clone = self.clone(); tokio::spawn(async move { let server_addr = @@ -213,10 +207,10 @@ impl Outbound for TrojanOutbound { tcp_outbound: Option>, udp_outbound: Option>, abort_handle: ConnAbortHandle, - ) -> io::Result { + ) -> Result { if tcp_outbound.is_none() || udp_outbound.is_some() { tracing::error!("Invalid Trojan UDP outbound ancestor"); - return Err(io::ErrorKind::InvalidData.into()); + return Err(TransportError::Internal("Invalid outbound")); } let self_clone = self.clone(); tokio::spawn(async move { @@ -232,7 +226,7 @@ impl Outbound for TrojanOutbound { inbound: AddrConnector, abort_handle: ConnAbortHandle, tunnel_only: bool, - ) -> JoinHandle> { + ) -> JoinHandle> { let self_clone = self.clone(); tokio::spawn(async move { let server_addr = @@ -253,10 +247,10 @@ impl Outbound for TrojanOutbound { udp_outbound: Option>, abort_handle: ConnAbortHandle, tunnel_only: bool, - ) -> io::Result { + ) -> Result { if tcp_outbound.is_none() || udp_outbound.is_some() { tracing::error!("Invalid Trojan UDP outbound ancestor"); - return Err(io::ErrorKind::InvalidData.into()); + return Err(TransportError::Internal("Invalid outbound")); } let tcp_outbound = tcp_outbound.unwrap(); let self_clone = self.clone(); diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index 4436912..7df0988 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -3,10 +3,10 @@ use std::collections::HashMap; use crate::adapter; use crate::adapter::udp_over_tcp::UdpOverTcpAdapter; -use crate::common::{io_err, local_async_run, AbortCanary, StreamOutboundTrait, MAX_PKT_SIZE}; +use crate::common::{local_async_run, AbortCanary, StreamOutboundTrait, MAX_PKT_SIZE}; use crate::network::dns::{Dns, GenericDns}; use crate::network::egress::Egress; -use crate::proxy::error::TransportError; +use crate::proxy::error::{DnsError, TransportError}; use crate::proxy::{ConnAbortHandle, NetworkAddr}; use crate::transport::smol::{SmolDnsProvider, SmolStack, VirtualIpDevice}; use crate::transport::wireguard::{WireguardConfig, WireguardTunnel}; @@ -400,7 +400,7 @@ impl WireguardHandle { abort_handle: ConnAbortHandle, adapter: Option, ret_tx: tokio::sync::oneshot::Sender, - ) -> io::Result<()> { + ) -> Result<(), TransportError> { let endpoint = self.get_endpoint(adapter, ret_tx).await?; let notify = endpoint.clone_notify(); let smol_dns = endpoint.stack.lock().await.get_dns(); @@ -409,24 +409,23 @@ impl WireguardHandle { NetworkAddr::DomainName { domain_name, port } => SocketAddr::new( match smol_dns.genuine_lookup(domain_name.as_str()).await { Ok(Some(addr)) => addr, - _ => return Err(ErrorKind::AddrNotAvailable.into()), + _ => return Err(TransportError::Dns(DnsError::ResolveDomain(domain_name))), }, port, ), }; let mut x = endpoint.stack.lock().await; - x.open_tcp(self.src, dst, inbound, abort_handle, notify) + Ok(x.open_tcp(self.src, dst, inbound, abort_handle, notify)?) } async fn get_endpoint( &self, adapter: Option, ret_tx: tokio::sync::oneshot::Sender, - ) -> io::Result> { + ) -> Result, TransportError> { self.manager .get_wg_conn(&self.name, &self.config, adapter, ret_tx) .await - .map_err(|e| io_err(format!("{}", e).as_str())) } async fn attach_udp( @@ -435,11 +434,11 @@ impl WireguardHandle { abort_handle: ConnAbortHandle, adapter: Option, ret_tx: tokio::sync::oneshot::Sender, - ) -> io::Result<()> { + ) -> Result<(), TransportError> { let endpoint = self.get_endpoint(adapter, ret_tx).await?; let notify = endpoint.clone_notify(); let mut x = endpoint.stack.lock().await; - x.open_udp(self.src, inbound, abort_handle, notify) + Ok(x.open_udp(self.src, inbound, abort_handle, notify)?) } } @@ -457,7 +456,7 @@ impl Outbound for WireguardHandle { &self, inbound: Connector, abort_handle: ConnAbortHandle, - ) -> JoinHandle> { + ) -> JoinHandle> { let (tx, _) = tokio::sync::oneshot::channel(); tokio::spawn(adapter::connect_timeout( self.clone().attach_tcp(inbound, abort_handle, None, tx), @@ -471,10 +470,10 @@ impl Outbound for WireguardHandle { tcp_outbound: Option>, udp_outbound: Option>, abort_handle: ConnAbortHandle, - ) -> io::Result { + ) -> Result { if tcp_outbound.is_some() || udp_outbound.is_none() { tracing::error!("Invalid Wireguard UDP outbound ancestor"); - return Err(ErrorKind::InvalidData.into()); + return Err(TransportError::Internal("Invalid outbound")); } let udp_outbound = udp_outbound.unwrap(); let (ret_tx, ret_rx) = tokio::sync::oneshot::channel(); @@ -489,7 +488,7 @@ impl Outbound for WireguardHandle { )); ret_rx .await - .map_err(|_| ErrorKind::ConnectionAborted.into()) + .map_err(|_| TransportError::Internal("Return rx closed")) } fn spawn_udp( @@ -497,7 +496,7 @@ impl Outbound for WireguardHandle { inbound: AddrConnector, abort_handle: ConnAbortHandle, _tunnel_only: bool, - ) -> JoinHandle> { + ) -> JoinHandle> { let (ret_tx, _) = tokio::sync::oneshot::channel(); tokio::spawn(adapter::connect_timeout( self.clone().attach_udp(inbound, abort_handle, None, ret_tx), @@ -512,10 +511,10 @@ impl Outbound for WireguardHandle { udp_outbound: Option>, abort_handle: ConnAbortHandle, _tunnel_only: bool, - ) -> io::Result { + ) -> Result { if tcp_outbound.is_some() || udp_outbound.is_none() { tracing::error!("Invalid Wireguard UDP outbound ancestor"); - return Err(ErrorKind::InvalidData.into()); + return Err(TransportError::Internal("Invalid outbound")); } let udp_outbound = udp_outbound.unwrap(); let (ret_tx, ret_rx) = tokio::sync::oneshot::channel(); @@ -530,7 +529,7 @@ impl Outbound for WireguardHandle { )); ret_rx .await - .map_err(|_| ErrorKind::ConnectionAborted.into()) + .map_err(|_| TransportError::Internal("Return rx closed")) } } diff --git a/boltconn/src/proxy/error.rs b/boltconn/src/proxy/error.rs index e3925a6..d6a00bf 100644 --- a/boltconn/src/proxy/error.rs +++ b/boltconn/src/proxy/error.rs @@ -44,6 +44,8 @@ pub enum TransportError { WireGuard(&'static str), #[error("SSH error: {0}")] Ssh(#[from] russh::Error), + #[error("Timeout: {0}")] + Timeout(&'static str), } #[derive(Error, Debug)] @@ -52,6 +54,8 @@ pub enum DnsError { MissingBootstrap(String), #[error("Failed to resolve dns server: {0}")] ResolveServer(String), + #[error("Failed to resolve domain name: {0}")] + ResolveDomain(String), } #[derive(Error, Debug)] diff --git a/boltconn/src/transport/smol.rs b/boltconn/src/transport/smol.rs index 434e5a5..f74835e 100644 --- a/boltconn/src/transport/smol.rs +++ b/boltconn/src/transport/smol.rs @@ -464,30 +464,20 @@ impl SmolStack { abort_handle: ConnAbortHandle, notify: Arc, ) -> io::Result<()> { - if local_addr.port() == 0 { - for _ in 0..10 { - let port = rand::thread_rng().gen_range(32768..65534); - match self.tcp_conn.entry(port) { - Entry::Occupied(_) => continue, - Entry::Vacant(e) => { - let handle = Self::open_tcp_inner( - &mut self.iface, - &mut self.socket_set, - self.ip_addr - .matched_if_addr(remote_addr.ip()) - .ok_or::(ErrorKind::AddrNotAvailable.into())?, - port, - remote_addr, - )?; - e.insert(TcpConnTask::new(connector, handle, abort_handle, notify)); - return Ok(()); + let choose_a_local_port = local_addr.port() == 0; + for _ in 0..10 { + let local_port = if choose_a_local_port { + rand::thread_rng().gen_range(32768..65534) + } else { + local_addr.port() + }; + return match self.tcp_conn.entry(local_port) { + Entry::Occupied(_) => { + if choose_a_local_port { + continue; } + Err(ErrorKind::AddrInUse.into()) } - } - Err(ErrorKind::AddrNotAvailable.into()) - } else { - match self.tcp_conn.entry(local_addr.port()) { - Entry::Occupied(_) => Err(ErrorKind::AddrInUse.into()), Entry::Vacant(e) => { let handle = Self::open_tcp_inner( &mut self.iface, @@ -501,8 +491,9 @@ impl SmolStack { e.insert(TcpConnTask::new(connector, handle, abort_handle, notify)); Ok(()) } - } + }; } + Err(ErrorKind::AddrNotAvailable.into()) } fn open_tcp_inner( @@ -555,36 +546,21 @@ impl SmolStack { buffer_packet_cnt: usize, ) -> io::Result<()> { // todo: IPv6 support when local_addr is a V4 address - if local_addr.port() == 0 { - for _ in 0..10 { - let port = rand::thread_rng().gen_range(32768..65534); - match self.udp_conn.entry(port) { - Entry::Occupied(_) => continue, - Entry::Vacant(e) => { - let handle = Self::open_udp_inner( - &mut self.socket_set, - self.ip_addr - .matched_if_addr(local_addr.ip()) - .ok_or::(ErrorKind::AddrNotAvailable.into())?, - port, - buffer_packet_cnt, - )?; - e.insert(UdpConnTask::new( - connector, - handle, - abort_handle, - self.dns.clone(), - notify, - IPVersion::from_addr(&local_addr.ip()), - )); - return Ok(()); + let choose_a_local_port = local_addr.port() == 0; + + for _ in 0..10 { + let port = if choose_a_local_port { + rand::thread_rng().gen_range(32768..65534) + } else { + local_addr.port() + }; + return match self.udp_conn.entry(port) { + Entry::Occupied(_) => { + if choose_a_local_port { + continue; } + Err(ErrorKind::AddrInUse.into()) } - } - Err(ErrorKind::AddrNotAvailable.into()) - } else { - match self.udp_conn.entry(local_addr.port()) { - Entry::Occupied(_) => Err(ErrorKind::AddrInUse.into()), Entry::Vacant(e) => { let handle = Self::open_udp_inner( &mut self.socket_set, @@ -604,8 +580,9 @@ impl SmolStack { )); Ok(()) } - } + }; } + Err(ErrorKind::AddrNotAvailable.into()) } fn open_udp_inner( @@ -860,7 +837,7 @@ impl RuntimeProvider for SmolDnsProvider { fn connect_tcp( &self, server_addr: SocketAddr, - ) -> Pin>>> { + ) -> Pin>>> { let smol = self.smol.upgrade(); let handle = self.abort_handle.clone(); let notify = self.notify.clone(); @@ -889,7 +866,7 @@ impl RuntimeProvider for SmolDnsProvider { &self, local_addr: SocketAddr, _server_addr: SocketAddr, - ) -> Pin>>> { + ) -> Pin>>> { let smol = self.smol.upgrade(); let notify = self.notify.clone(); let handle = self.abort_handle.clone(); From 1069407d69589bfb875b28abea27f9e2e0bcf58f Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Thu, 3 Oct 2024 17:05:33 -0400 Subject: [PATCH 12/14] fix(wireguard): contended get_wg_conn --- boltconn/src/adapter/wireguard.rs | 122 ++++++++++++++++------------ boltconn/src/transport/wireguard.rs | 5 +- 2 files changed, 73 insertions(+), 54 deletions(-) diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index 7df0988..872ee69 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -1,5 +1,4 @@ use crate::adapter::{AddrConnector, AddrConnectorWrapper, Connector, Outbound, OutboundType}; -use std::collections::HashMap; use crate::adapter; use crate::adapter::udp_over_tcp::UdpOverTcpAdapter; @@ -13,6 +12,8 @@ use crate::transport::wireguard::{WireguardConfig, WireguardTunnel}; use crate::transport::{AdapterOrSocket, InterfaceAddress, UdpSocketAdapter}; use async_trait::async_trait; use bytes::Bytes; +use dashmap::mapref::entry::Entry; +use dashmap::DashMap; use hickory_resolver::config::{ResolverConfig, ResolverOpts}; use hickory_resolver::name_server::GenericConnector; use hickory_resolver::proto::udp::DnsUdpSocket; @@ -274,8 +275,7 @@ impl Endpoint { pub struct WireguardManager { iface: String, - // We use an async wrapper to avoid deadlock in DashMap - active_conn: Mutex>>, + active_conn: DashMap>, endpoint_resolver: Arc, timeout: Duration, } @@ -297,56 +297,33 @@ impl WireguardManager { adapter: Option, ret_tx: tokio::sync::oneshot::Sender, ) -> Result, TransportError> { + // optimistic trial to avoid extra config.clone() + if let Some(ep) = self.active_conn.get(config) { + if ep.is_active.alive() { + let _ = ret_tx.send(false); + return Ok(ep.clone()); + } + } + // loop is only used for reconnecting a removed connection for _ in 0..10 { // get an existing conn, or create - let mut guard = self.active_conn.lock().await; - if let Some(endpoint) = guard.get(config) { - if endpoint.is_active.alive() { - let _ = ret_tx.send(false); - return Ok(endpoint.clone()); - } else { - guard.remove(config); - continue; - } - } else { - let _ = ret_tx.send(true); - let server_addr = - adapter::get_dst(&self.endpoint_resolver, &config.endpoint).await?; - let outbound = match adapter { - Some(a) => a, - None => { - if config.over_tcp { - let stream = Egress::new(&self.iface).tcp_stream(server_addr).await?; - AdapterOrSocket::Adapter(Arc::new(UdpOverTcpAdapter::new( - stream, - server_addr, - )?)) - } else { - AdapterOrSocket::Socket(match server_addr { - SocketAddr::V4(_) => { - let socket = Egress::new(&self.iface).udpv4_socket().await?; - socket.connect(server_addr).await?; - socket - } - SocketAddr::V6(_) => { - let socket = Egress::new(&self.iface).udpv6_socket().await?; - socket.connect(server_addr).await?; - socket - } - }) - } + // warning: if two keys fall into the same shard, the reconnecting may block this shard + match self.active_conn.entry(config.clone()) { + Entry::Occupied(entry) => { + if entry.get().is_active.alive() { + let _ = ret_tx.send(false); + return Ok(entry.get().clone()); + } else { + entry.remove(); + continue; } - }; - let ep = Endpoint::new( - name, - outbound, - config, - self.endpoint_resolver.clone(), - self.timeout, - ) - .await?; - guard.insert(config.clone(), ep.clone()); - return Ok(ep); + } + Entry::Vacant(e) => { + let _ = ret_tx.send(true); + let ep = self.create_endpoint(name, config, adapter).await?; + e.insert(ep.clone()); + return Ok(ep); + } } } Err(TransportError::WireGuard( @@ -354,11 +331,50 @@ impl WireguardManager { )) } + async fn create_endpoint( + &self, + name: &str, + config: &WireguardConfig, + adapter: Option, + ) -> Result, TransportError> { + let outbound = match adapter { + Some(a) => a, + None => { + let server_addr = + adapter::get_dst(&self.endpoint_resolver, &config.endpoint).await?; + if config.over_tcp { + let stream = Egress::new(&self.iface).tcp_stream(server_addr).await?; + AdapterOrSocket::Adapter(Arc::new(UdpOverTcpAdapter::new(stream, server_addr)?)) + } else { + AdapterOrSocket::Socket(match server_addr { + SocketAddr::V4(_) => { + let socket = Egress::new(&self.iface).udpv4_socket().await?; + socket.connect(server_addr).await?; + socket + } + SocketAddr::V6(_) => { + let socket = Egress::new(&self.iface).udpv6_socket().await?; + socket.connect(server_addr).await?; + socket + } + }) + } + } + }; + Endpoint::new( + name, + outbound, + config, + self.endpoint_resolver.clone(), + self.timeout, + ) + .await + } + pub async fn debug_internal_state(&self) -> Vec { - let conns = self.active_conn.lock().await; let mut ret = Vec::new(); - for (_, v) in conns.iter() { - let r = v.debug_internal_state().await; + for entry in self.active_conn.iter() { + let r = entry.debug_internal_state().await; ret.push(r); } ret diff --git a/boltconn/src/transport/wireguard.rs b/boltconn/src/transport/wireguard.rs index 4be8821..7f94371 100644 --- a/boltconn/src/transport/wireguard.rs +++ b/boltconn/src/transport/wireguard.rs @@ -41,6 +41,7 @@ pub struct WireguardConfig { impl Debug for WireguardConfig { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_tuple("") + .field(&self.name) .field(&self.ip_addr) .field(&self.ip_addr6) .field(&self.endpoint) @@ -51,7 +52,8 @@ impl Debug for WireguardConfig { impl PartialEq for WireguardConfig { fn eq(&self, other: &Self) -> bool { - self.public_key == other.public_key + self.name == other.name + && self.public_key == other.public_key && self.ip_addr == other.ip_addr && self.ip_addr6 == other.ip_addr6 && self.endpoint == other.endpoint @@ -62,6 +64,7 @@ impl Eq for WireguardConfig {} impl Hash for WireguardConfig { fn hash(&self, state: &mut H) { + self.name.hash(state); self.ip_addr.hash(state); self.ip_addr6.hash(state); self.public_key.hash(state); From e18ce17736006e5eae82ed8b31d83cf62dfbb029 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Sun, 13 Oct 2024 19:07:58 -0400 Subject: [PATCH 13/14] chore(wireguard): add interface to stop a master connection --- boltapi/src/rpc.rs | 2 ++ boltconn/src/adapter/wireguard.rs | 21 +++++++++++++++++++++ boltconn/src/cli/mod.rs | 14 +++++++++++--- boltconn/src/cli/request.rs | 12 ++++++++++++ boltconn/src/cli/request_uds.rs | 4 ++++ boltconn/src/external/controller.rs | 4 ++++ boltconn/src/external/uds_controller.rs | 4 ++++ 7 files changed, 58 insertions(+), 3 deletions(-) diff --git a/boltapi/src/rpc.rs b/boltapi/src/rpc.rs index a7dfa80..8ad35fc 100644 --- a/boltapi/src/rpc.rs +++ b/boltapi/src/rpc.rs @@ -57,6 +57,8 @@ pub trait ControlService { async fn get_master_conn_stat() -> Vec; + async fn stop_master_conn(id: String); + async fn reload(); // Streaming diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index 872ee69..826ba78 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -261,6 +261,10 @@ impl Endpoint { self.notify.clone() } + pub fn abort_connection(&self) { + let _ = self.stop_sender.send(()); + } + pub async fn debug_internal_state(&self) -> boltapi::MasterConnectionStatus { let tunn_state = self.wg.stats().await; boltapi::MasterConnectionStatus { @@ -371,6 +375,23 @@ impl WireguardManager { .await } + pub async fn stop_master_conn(&self, name: &str) { + let mut stopped = false; + for entry in self.active_conn.iter() { + if entry.value().name == name { + stopped = true; + entry.value().abort_connection(); + tracing::info!("Stop WireGuard master connection #{}", name); + } + } + if !stopped { + tracing::warn!( + "Stop WireGuard master connection #{} failed: no such connection", + name + ); + } + } + pub async fn debug_internal_state(&self) -> Vec { let mut ret = Vec::new(); for entry in self.active_conn.iter() { diff --git a/boltconn/src/cli/mod.rs b/boltconn/src/cli/mod.rs index afcf6b0..3fdcee4 100644 --- a/boltconn/src/cli/mod.rs +++ b/boltconn/src/cli/mod.rs @@ -174,10 +174,17 @@ pub(crate) enum LogsLimitOptions { Get, } -#[derive(Debug, Clone, Copy, Subcommand)] +#[derive(Clone, Debug, Args)] +pub(crate) struct WgOptions { + pub name: String, +} + +#[derive(Debug, Clone, Subcommand)] pub(crate) enum MasterConnOptions { /// Show the WireGuard master connections - Wg, + ListWg, + /// Stop a WireGuard connection + StopWg(WgOptions), } #[derive(Debug, Subcommand)] @@ -387,7 +394,8 @@ pub(crate) async fn controller_main(args: ProgramArgs) -> ! { }, #[cfg(feature = "internal-test")] SubCommand::MasterConn(opt) => match opt { - MasterConnOptions::Wg => requester.master_conn_stats().await, + MasterConnOptions::ListWg => requester.master_conn_stats().await, + MasterConnOptions::StopWg(opt) => requester.stop_master_conn(opt.name).await, }, SubCommand::Start(_) | SubCommand::Generate(_) diff --git a/boltconn/src/cli/request.rs b/boltconn/src/cli/request.rs index 50532e4..c975925 100644 --- a/boltconn/src/cli/request.rs +++ b/boltconn/src/cli/request.rs @@ -327,6 +327,18 @@ impl Requester { } } } + + pub async fn stop_master_conn(&self, id: String) -> Result<()> { + match &self.inner { + Inner::Web(_) => Err(anyhow::anyhow!( + "stop master conn: Not supported by RESTful API" + )), + Inner::Uds(c) => { + c.stop_master_conn(id).await?; + Ok(()) + } + } + } } fn pretty_size(data: u64) -> String { diff --git a/boltconn/src/cli/request_uds.rs b/boltconn/src/cli/request_uds.rs index fb06bd2..aa5ba29 100644 --- a/boltconn/src/cli/request_uds.rs +++ b/boltconn/src/cli/request_uds.rs @@ -175,4 +175,8 @@ impl UdsConnector { pub async fn get_master_conn_stat(&self) -> Result> { Ok(self.client.get_master_conn_stat(Context::current()).await?) } + + pub async fn stop_master_conn(&self, id: String) -> Result<()> { + Ok(self.client.stop_master_conn(Context::current(), id).await?) + } } diff --git a/boltconn/src/external/controller.rs b/boltconn/src/external/controller.rs index 2acf19b..d42ad79 100644 --- a/boltconn/src/external/controller.rs +++ b/boltconn/src/external/controller.rs @@ -394,6 +394,10 @@ impl Controller { self.dispatcher.get_wg_mgr().debug_internal_state().await } + pub async fn stop_master_conn(&self, id: String) { + self.dispatcher.get_wg_mgr().stop_master_conn(&id).await; + } + pub async fn real_lookup(&self, domain_name: String) -> Option { match self.dns.genuine_lookup(domain_name.as_str()).await { Ok(Some(ip)) => Some(ip.to_string()), diff --git a/boltconn/src/external/uds_controller.rs b/boltconn/src/external/uds_controller.rs index 5ac0603..7950d53 100644 --- a/boltconn/src/external/uds_controller.rs +++ b/boltconn/src/external/uds_controller.rs @@ -347,6 +347,10 @@ impl ControlService for UdsRpcServer { self.controller.get_master_conn_stat().await } + async fn stop_master_conn(self, _ctx: Context, id: String) { + self.controller.stop_master_conn(id).await + } + async fn reload(self, _ctx: Context) { self.controller.reload().await } From 1098b4b8c78611e3d3194bc6f86d0aef91f3cbb7 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Sun, 13 Oct 2024 21:01:12 -0400 Subject: [PATCH 14/14] fix(wireguard): remove all tasks in smol side when wireguard side exits --- boltconn/src/adapter/wireguard.rs | 26 ++++++++++++++++---------- boltconn/src/transport/smol.rs | 6 ++++++ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index 826ba78..cf32db9 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -233,16 +233,22 @@ impl Endpoint { }); } - let name_clone = name.to_string(); - tokio::spawn(async move { - // kill all coroutine - let _ = stop_recv.recv().await; - indi_write.abort(); - wg_out.abort(); - wg_in.abort(); - wg_tick.abort(); - tracing::trace!("[WireGuard] connection #{} killed", name_clone); - }); + { + let name = name.to_string(); + let stack = smol_stack.clone(); + tokio::spawn(async move { + // kill all coroutine + let _ = stop_recv.recv().await; + indi_write.abort(); + wg_out.abort(); + wg_in.abort(); + wg_tick.abort(); + // reset smol stack to drop all channel sender, + // so the receiver can report errors correctly + stack.lock().await.terminate_all(); + tracing::trace!("[WireGuard] connection #{} killed", name); + }); + } tracing::info!("[WireGuard] Established master connection #{}", name); diff --git a/boltconn/src/transport/smol.rs b/boltconn/src/transport/smol.rs index f74835e..09dd5fd 100644 --- a/boltconn/src/transport/smol.rs +++ b/boltconn/src/transport/smol.rs @@ -700,6 +700,12 @@ impl SmolStack { } }); } + + /// Terminate all connections with fatal errors + pub fn terminate_all(&mut self) { + self.tcp_conn.clear(); + self.udp_conn.clear(); + } } // -----------------------------------------------------------------------------------