From 1f2fc5b28f0c5034da6ca684d6a7b8f144e00224 Mon Sep 17 00:00:00 2001 From: Yujia Qiao Date: Sat, 11 Jun 2022 11:39:11 +0800 Subject: [PATCH] feat: cache dns result for one session (#166) --- src/client.rs | 10 ++++++---- src/helper.rs | 15 ++++++++++++--- src/transport/mod.rs | 40 +++++++++++++++++++++++++++++++++++++--- src/transport/noise.rs | 4 ++-- src/transport/tcp.rs | 4 ++-- src/transport/tls.rs | 11 +++++++---- 6 files changed, 66 insertions(+), 18 deletions(-) diff --git a/src/client.rs b/src/client.rs index 8df83894..1149bec0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,7 +6,7 @@ use crate::protocol::{ self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd, DataChannelCmd, UdpTraffic, CURRENT_PROTO_VERSION, HASH_WIDTH_IN_BYTES, }; -use crate::transport::{SocketOpts, TcpTransport, Transport}; +use crate::transport::{AddrMaybeCached, SocketOpts, TcpTransport, Transport}; use anyhow::{anyhow, bail, Context, Result}; use backoff::ExponentialBackoff; use backoff::{backoff::Backoff, future::retry_notify}; @@ -150,7 +150,7 @@ impl Client { struct RunDataChannelArgs { session_key: Nonce, - remote_addr: String, + remote_addr: AddrMaybeCached, connector: Arc, socket_opts: SocketOpts, service: ClientServiceConfig, @@ -385,9 +385,12 @@ struct ControlChannelHandle { impl ControlChannel { #[instrument(skip_all)] async fn run(&mut self) -> Result<()> { + let mut remote_addr = AddrMaybeCached::new(&self.remote_addr); + remote_addr.resolve().await?; + let mut conn = self .transport - .connect(&self.remote_addr) + .connect(&remote_addr) .await .with_context(|| format!("Failed to connect to {}", &self.remote_addr))?; T::hint(&conn, SocketOpts::for_control_channel()); @@ -432,7 +435,6 @@ impl ControlChannel { // Channel ready info!("Control channel established"); - let remote_addr = self.remote_addr.clone(); // Socket options for the data channel let socket_opts = SocketOpts::from_client_cfg(&self.service); let data_ch_args = Arc::new(RunDataChannelArgs { diff --git a/src/helper.rs b/src/helper.rs index 68db8807..b795932f 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -10,6 +10,8 @@ use tokio::{ use tracing::trace; use url::Url; +use crate::transport::AddrMaybeCached; + // Tokio hesitates to expose this option...So we have to do it on our own :( // The good news is that using socket2 it can be easily done, without losing portability. // See https://github.com/tokio-rs/tokio/issues/3082 @@ -40,7 +42,7 @@ pub fn feature_not_compile(feature: &str) -> ! { ) } -async fn to_socket_addr(addr: A) -> Result { +pub async fn to_socket_addr(addr: A) -> Result { lookup_host(addr) .await? .next() @@ -68,8 +70,12 @@ pub async fn udp_connect(addr: A) -> Result { /// Create a TcpStream using a proxy /// e.g. socks5://user:pass@127.0.0.1:1080 http://127.0.0.1:8080 -pub async fn tcp_connect_with_proxy(addr: &str, proxy: Option<&Url>) -> Result { +pub async fn tcp_connect_with_proxy( + addr: &AddrMaybeCached, + proxy: Option<&Url>, +) -> Result { if let Some(url) = proxy { + let addr = &addr.addr; let mut s = TcpStream::connect(( url.host_str().expect("proxy url should have host field"), url.port().expect("proxy url should have port field"), @@ -108,7 +114,10 @@ pub async fn tcp_connect_with_proxy(addr: &str, proxy: Option<&Url>) -> Result TcpStream::connect(s).await?, + None => TcpStream::connect(&addr.addr).await?, + }) } } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index cc0e139f..d98afb80 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -1,8 +1,8 @@ use crate::config::{ClientServiceConfig, ServerServiceConfig, TcpConfig, TransportConfig}; -use crate::helper::try_set_tcp_keepalive; +use crate::helper::{to_socket_addr, try_set_tcp_keepalive}; use anyhow::{Context, Result}; use async_trait::async_trait; -use std::fmt::Debug; +use std::fmt::{Debug, Display}; use std::net::SocketAddr; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite}; @@ -14,6 +14,40 @@ pub const DEFAULT_NODELAY: bool = false; pub const DEFAULT_KEEPALIVE_SECS: u64 = 20; pub const DEFAULT_KEEPALIVE_INTERVAL: u64 = 8; +#[derive(Clone)] +pub struct AddrMaybeCached { + pub addr: String, + pub socket_addr: Option, +} + +impl AddrMaybeCached { + pub fn new(addr: &str) -> AddrMaybeCached { + AddrMaybeCached { + addr: addr.to_string(), + socket_addr: None, + } + } + + pub async fn resolve(&mut self) -> Result<()> { + match to_socket_addr(&self.addr).await { + Ok(s) => { + self.socket_addr = Some(s); + Ok(()) + } + Err(e) => Err(e), + } + } +} + +impl Display for AddrMaybeCached { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.socket_addr { + Some(s) => f.write_fmt(format_args!("{}", s)), + None => f.write_str(&self.addr), + } + } +} + /// Specify a transport layer, like TCP, TLS #[async_trait] pub trait Transport: Debug + Send + Sync { @@ -30,7 +64,7 @@ pub trait Transport: Debug + Send + Sync { /// accept must be cancel safe async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)>; async fn handshake(&self, conn: Self::RawStream) -> Result; - async fn connect(&self, addr: &str) -> Result; + async fn connect(&self, addr: &AddrMaybeCached) -> Result; } mod tcp; diff --git a/src/transport/noise.rs b/src/transport/noise.rs index b3391d63..8ffc37ce 100644 --- a/src/transport/noise.rs +++ b/src/transport/noise.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; -use super::{SocketOpts, TcpTransport, Transport}; +use super::{AddrMaybeCached, SocketOpts, TcpTransport, Transport}; use crate::config::{NoiseConfig, TransportConfig}; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; @@ -92,7 +92,7 @@ impl Transport for NoiseTransport { Ok(conn) } - async fn connect(&self, addr: &str) -> Result { + async fn connect(&self, addr: &AddrMaybeCached) -> Result { let conn = self .tcp .connect(addr) diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index f2d360a3..3c5e242e 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -3,7 +3,7 @@ use crate::{ helper::tcp_connect_with_proxy, }; -use super::{SocketOpts, Transport}; +use super::{AddrMaybeCached, SocketOpts, Transport}; use anyhow::Result; use async_trait::async_trait; use std::net::SocketAddr; @@ -46,7 +46,7 @@ impl Transport for TcpTransport { Ok(conn) } - async fn connect(&self, addr: &str) -> Result { + async fn connect(&self, addr: &AddrMaybeCached) -> Result { let s = tcp_connect_with_proxy(addr, self.cfg.proxy.as_ref()).await?; self.socket_opts.apply(&s); Ok(s) diff --git a/src/transport/tls.rs b/src/transport/tls.rs index c4f6579b..80433608 100644 --- a/src/transport/tls.rs +++ b/src/transport/tls.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; -use super::{SocketOpts, TcpTransport, Transport}; +use super::{AddrMaybeCached, SocketOpts, TcpTransport, Transport}; use crate::config::{TlsConfig, TransportConfig}; use crate::helper::host_port_pair; use anyhow::{anyhow, Context, Result}; @@ -26,7 +26,10 @@ impl Transport for TlsTransport { fn new(config: &TransportConfig) -> Result { let tcp = TcpTransport::new(config)?; - let config = config.tls.as_ref().ok_or_else(|| anyhow!("Missing tls config"))?; + let config = config + .tls + .as_ref() + .ok_or_else(|| anyhow!("Missing tls config"))?; let connector = match config.trusted_root.as_ref() { Some(path) => { @@ -87,7 +90,7 @@ impl Transport for TlsTransport { Ok(conn) } - async fn connect(&self, addr: &str) -> Result { + async fn connect(&self, addr: &AddrMaybeCached) -> Result { let conn = self.tcp.connect(addr).await?; let connector = self.connector.as_ref().unwrap(); @@ -96,7 +99,7 @@ impl Transport for TlsTransport { self.config .hostname .as_deref() - .unwrap_or(host_port_pair(addr)?.0), + .unwrap_or(host_port_pair(&addr.addr)?.0), conn, ) .await?)