From ce40f85efa502f0779eddec23e424c753a7c5e02 Mon Sep 17 00:00:00 2001 From: Aram Peres <6775216+aramperes@users.noreply.github.com> Date: Sun, 24 Dec 2023 15:05:33 -0500 Subject: [PATCH] Cleanup usage of anyhow with_context --- src/config.rs | 70 +++++++++++++++++----------------------- src/lib.rs | 2 +- src/main.rs | 6 ++-- src/pcap.rs | 20 ++++++------ src/tunnel/tcp.rs | 6 ++-- src/tunnel/udp.rs | 6 ++-- src/virtual_iface/tcp.rs | 4 +-- src/virtual_iface/udp.rs | 4 +-- src/wg.rs | 6 ++-- 9 files changed, 55 insertions(+), 69 deletions(-) diff --git a/src/config.rs b/src/config.rs index 4c030dd..f3fdbc4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -161,14 +161,14 @@ impl Config { .map(|s| PortForwardConfig::from_notation(&s, DEFAULT_PORT_FORWARD_SOURCE)) .collect(); let port_forwards: Vec = port_forwards - .with_context(|| "Failed to parse port forward config")? + .context("Failed to parse port forward config")? .into_iter() .flatten() .collect(); // Read source-peer-ip let source_peer_ip = parse_ip(matches.get_one::("source-peer-ip")) - .with_context(|| "Invalid source peer IP")?; + .context("Invalid source peer IP")?; // Combined `remote` arg and `ONETUN_REMOTE_PORT_FORWARD_#` envs let mut port_forward_strings = HashSet::new(); @@ -196,7 +196,7 @@ impl Config { }) .collect(); let mut remote_port_forwards: Vec = remote_port_forwards - .with_context(|| "Failed to parse remote port forward config")? + .context("Failed to parse remote port forward config")? .into_iter() .flatten() .collect(); @@ -229,7 +229,7 @@ impl Config { { read_to_string(private_key_file) .map(|s| s.trim().to_string()) - .with_context(|| "Failed to read private key file") + .context("Failed to read private key file") } else { if std::env::var("ONETUN_PRIVATE_KEY").is_err() { warnings.push("Private key was passed using CLI. This is insecure. \ @@ -238,20 +238,18 @@ impl Config { matches .get_one::("private-key") .cloned() - .with_context(|| "Missing private key") + .context("Missing private key") }?; let endpoint_addr = parse_addr(matches.get_one::("endpoint-addr")) - .with_context(|| "Invalid endpoint address")?; + .context("Invalid endpoint address")?; let endpoint_bind_addr = if let Some(addr) = matches.get_one::("endpoint-bind-addr") { - let addr = parse_addr(Some(addr)).with_context(|| "Invalid bind address")?; + let addr = parse_addr(Some(addr)).context("Invalid bind address")?; // Make sure the bind address and endpoint address are the same IP version if addr.ip().is_ipv4() != endpoint_addr.ip().is_ipv4() { - return Err(anyhow::anyhow!( - "Endpoint and bind addresses must be the same IP version" - )); + bail!("Endpoint and bind addresses must be the same IP version"); } addr } else { @@ -265,21 +263,19 @@ impl Config { Ok(Self { port_forwards, remote_port_forwards, - private_key: Arc::new( - parse_private_key(&private_key).with_context(|| "Invalid private key")?, - ), + private_key: Arc::new(parse_private_key(&private_key).context("Invalid private key")?), endpoint_public_key: Arc::new( parse_public_key(matches.get_one::("endpoint-public-key")) - .with_context(|| "Invalid endpoint public key")?, + .context("Invalid endpoint public key")?, ), preshared_key: parse_preshared_key(matches.get_one::("preshared-key"))?, endpoint_addr, endpoint_bind_addr, source_peer_ip, keepalive_seconds: parse_keep_alive(matches.get_one::("keep-alive")) - .with_context(|| "Invalid keep-alive value")?, + .context("Invalid keep-alive value")?, max_transmission_unit: parse_mtu(matches.get_one::("max-transmission-unit")) - .with_context(|| "Invalid max-transmission-unit value")?, + .context("Invalid max-transmission-unit value")?, log: matches .get_one::("log") .cloned() @@ -291,22 +287,22 @@ impl Config { } fn parse_addr>(s: Option) -> anyhow::Result { - s.with_context(|| "Missing address")? + s.context("Missing address")? .as_ref() .to_socket_addrs() - .with_context(|| "Invalid address")? + .context("Invalid address")? .next() - .with_context(|| "Could not lookup address") + .context("Could not lookup address") } fn parse_ip(s: Option<&String>) -> anyhow::Result { - s.with_context(|| "Missing IP")? + s.context("Missing IP address")? .parse::() - .with_context(|| "Invalid IP address") + .context("Invalid IP address") } fn parse_private_key(s: &str) -> anyhow::Result { - let decoded = base64::decode(s).with_context(|| "Failed to decode private key")?; + let decoded = base64::decode(s).context("Failed to decode private key")?; if let Ok::<[u8; 32], _>(bytes) = decoded.try_into() { Ok(StaticSecret::from(bytes)) } else { @@ -315,8 +311,8 @@ fn parse_private_key(s: &str) -> anyhow::Result { } fn parse_public_key(s: Option<&String>) -> anyhow::Result { - let encoded = s.with_context(|| "Missing public key")?; - let decoded = base64::decode(encoded).with_context(|| "Failed to decode public key")?; + let encoded = s.context("Missing public key")?; + let decoded = base64::decode(encoded).context("Failed to decode public key")?; if let Ok::<[u8; 32], _>(bytes) = decoded.try_into() { Ok(PublicKey::from(bytes)) } else { @@ -326,7 +322,7 @@ fn parse_public_key(s: Option<&String>) -> anyhow::Result { fn parse_preshared_key(s: Option<&String>) -> anyhow::Result> { if let Some(s) = s { - let decoded = base64::decode(s).with_context(|| "Failed to decode preshared key")?; + let decoded = base64::decode(s).context("Failed to decode preshared key")?; if let Ok::<[u8; 32], _>(bytes) = decoded.try_into() { Ok(Some(bytes)) } else { @@ -352,9 +348,7 @@ fn parse_keep_alive(s: Option<&String>) -> anyhow::Result> { } fn parse_mtu(s: Option<&String>) -> anyhow::Result { - s.with_context(|| "Missing MTU")? - .parse() - .with_context(|| "Invalid MTU") + s.context("Missing MTU")?.parse().context("Invalid MTU") } #[cfg(unix)] @@ -483,27 +477,21 @@ impl PortForwardConfig { let source = ( src_addr.0.unwrap_or(default_source), - src_addr - .1 - .parse::() - .with_context(|| "Invalid source port")?, + src_addr.1.parse::().context("Invalid source port")?, ) .to_socket_addrs() - .with_context(|| "Invalid source address")? + .context("Invalid source address")? .next() - .with_context(|| "Could not resolve source address")?; + .context("Could not resolve source address")?; let destination = ( dst_addr.0, - dst_addr - .1 - .parse::() - .with_context(|| "Invalid source port")?, + dst_addr.1.parse::().context("Invalid source port")?, ) .to_socket_addrs() // TODO: Pass this as given and use DNS config instead (issue #15) - .with_context(|| "Invalid destination address")? + .context("Invalid destination address")? .next() - .with_context(|| "Could not resolve destination address")?; + .context("Could not resolve destination address")?; // Parse protocols let protocols = if let Some(protocols) = protocols { @@ -513,7 +501,7 @@ impl PortForwardConfig { } else { Ok(vec![PortProtocol::Tcp]) } - .with_context(|| "Failed to parse protocols")?; + .context("Failed to parse protocols")?; // Returns an config for each protocol Ok(protocols diff --git a/src/lib.rs b/src/lib.rs index a43d657..a76fa18 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,7 +41,7 @@ pub async fn start_tunnels(config: Config, bus: Bus) -> anyhow::Result<()> { let wg = WireGuardTunnel::new(&config, bus.clone()) .await - .with_context(|| "Failed to initialize WireGuard tunnel")?; + .context("Failed to initialize WireGuard tunnel")?; let wg = Arc::new(wg); { diff --git a/src/main.rs b/src/main.rs index 78c6419..4ff2954 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ async fn main() -> anyhow::Result<()> { use anyhow::Context; use onetun::{config::Config, events::Bus}; - let config = Config::from_args().with_context(|| "Failed to read config")?; + let config = Config::from_args().context("Configuration has errors")?; init_logger(&config)?; for warning in &config.warnings { @@ -32,7 +32,5 @@ fn init_logger(config: &onetun::config::Config) -> anyhow::Result<()> { let mut builder = pretty_env_logger::formatted_timed_builder(); builder.parse_filters(&config.log); - builder - .try_init() - .with_context(|| "Failed to initialize logger") + builder.try_init().context("Failed to initialize logger") } diff --git a/src/pcap.rs b/src/pcap.rs index 1771a33..487fbe0 100644 --- a/src/pcap.rs +++ b/src/pcap.rs @@ -16,7 +16,7 @@ impl Pcap { self.writer .flush() .await - .with_context(|| "Failed to flush pcap writer") + .context("Failed to flush pcap writer") } async fn write(&mut self, data: &[u8]) -> anyhow::Result { @@ -30,14 +30,14 @@ impl Pcap { self.writer .write_u16(value) .await - .with_context(|| "Failed to write u16 to pcap writer") + .context("Failed to write u16 to pcap writer") } async fn write_u32(&mut self, value: u32) -> anyhow::Result<()> { self.writer .write_u32(value) .await - .with_context(|| "Failed to write u32 to pcap writer") + .context("Failed to write u32 to pcap writer") } async fn global_header(&mut self) -> anyhow::Result<()> { @@ -64,14 +64,14 @@ impl Pcap { async fn packet(&mut self, timestamp: Instant, packet: &[u8]) -> anyhow::Result<()> { self.packet_header(timestamp, packet.len()) .await - .with_context(|| "Failed to write packet header to pcap writer")?; + .context("Failed to write packet header to pcap writer")?; self.write(packet) .await - .with_context(|| "Failed to write packet to pcap writer")?; + .context("Failed to write packet to pcap writer")?; self.writer .flush() .await - .with_context(|| "Failed to flush pcap writer")?; + .context("Failed to flush pcap writer")?; self.flush().await } } @@ -81,14 +81,14 @@ pub async fn capture(pcap_file: String, bus: Bus) -> anyhow::Result<()> { let mut endpoint = bus.new_endpoint(); let file = File::create(&pcap_file) .await - .with_context(|| "Failed to create pcap file")?; + .context("Failed to create pcap file")?; let writer = BufWriter::new(file); let mut writer = Pcap { writer }; writer .global_header() .await - .with_context(|| "Failed to write global header to pcap writer")?; + .context("Failed to write global header to pcap writer")?; info!("Capturing WireGuard IP packets to {}", &pcap_file); loop { @@ -98,14 +98,14 @@ pub async fn capture(pcap_file: String, bus: Bus) -> anyhow::Result<()> { writer .packet(instant, &ip) .await - .with_context(|| "Failed to write inbound IP packet to pcap writer")?; + .context("Failed to write inbound IP packet to pcap writer")?; } Event::OutboundInternetPacket(ip) => { let instant = Instant::now(); writer .packet(instant, &ip) .await - .with_context(|| "Failed to write output IP packet to pcap writer")?; + .context("Failed to write output IP packet to pcap writer")?; } _ => {} } diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index b5e1ec5..47b0197 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -27,14 +27,14 @@ pub async fn tcp_proxy_server( ) -> anyhow::Result<()> { let listener = TcpListener::bind(port_forward.source) .await - .with_context(|| "Failed to listen on TCP proxy server")?; + .context("Failed to listen on TCP proxy server")?; loop { let port_pool = port_pool.clone(); let (socket, peer_addr) = listener .accept() .await - .with_context(|| "Failed to accept connection on TCP proxy server")?; + .context("Failed to accept connection on TCP proxy server")?; // Assign a 'virtual port': this is a unique port number used to route IP packets // received from the WireGuard tunnel. It is the port number that the virtual client will @@ -192,7 +192,7 @@ impl TcpPortPool { let port = inner .queue .pop_front() - .with_context(|| "TCP virtual port pool is exhausted")?; + .context("TCP virtual port pool is exhausted")?; Ok(VirtualPort::new(port, PortProtocol::Tcp)) } diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index 32fef15..ab52dc7 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -37,7 +37,7 @@ pub async fn udp_proxy_server( let mut endpoint = bus.new_endpoint(); let socket = UdpSocket::bind(port_forward.source) .await - .with_context(|| "Failed to bind on UDP proxy address")?; + .context("Failed to bind on UDP proxy address")?; let mut buffer = [0u8; MAX_PACKET]; loop { @@ -103,7 +103,7 @@ async fn next_udp_datagram( let (size, peer_addr) = socket .recv_from(buffer) .await - .with_context(|| "Failed to accept incoming UDP datagram")?; + .context("Failed to accept incoming UDP datagram")?; // Assign a 'virtual port': this is a unique port number used to route IP packets // received from the WireGuard tunnel. It is the port number that the virtual client will @@ -212,7 +212,7 @@ impl UdpPortPool { None } }) - .with_context(|| "virtual port pool is exhausted")?; + .context("Virtual port pool is exhausted")?; inner.port_by_peer_addr.insert(peer_addr, port); inner.peer_addr_by_port.insert(port, peer_addr); diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index ab9ca08..7522d7e 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -56,7 +56,7 @@ impl TcpVirtualInterface { IpAddress::from(port_forward.destination.ip()), port_forward.destination.port(), )) - .with_context(|| "Virtual server socket failed to listen")?; + .context("Virtual server socket failed to listen")?; Ok(socket) } @@ -218,7 +218,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { ), (IpAddress::from(self.source_peer_ip), virtual_port.num()), ) - .with_context(|| "Virtual server socket failed to listen")?; + .context("Virtual server socket failed to listen")?; next_poll = None; } diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index 214c01e..3ca4c2d 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -61,7 +61,7 @@ impl UdpVirtualInterface { IpAddress::from(port_forward.destination.ip()), port_forward.destination.port(), )) - .with_context(|| "UDP virtual server socket failed to bind")?; + .context("UDP virtual server socket failed to bind")?; Ok(socket) } @@ -78,7 +78,7 @@ impl UdpVirtualInterface { let mut socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); socket .bind((IpAddress::from(source_peer_ip), client_port.num())) - .with_context(|| "UDP virtual client failed to bind")?; + .context("UDP virtual client failed to bind")?; Ok(socket) } diff --git a/src/wg.rs b/src/wg.rs index 5f5b735..fc346a2 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -41,7 +41,7 @@ impl WireGuardTunnel { let endpoint = config.endpoint_addr; let udp = UdpSocket::bind(config.endpoint_bind_addr) .await - .with_context(|| "Failed to create UDP socket for WireGuard connection")?; + .context("Failed to create UDP socket for WireGuard connection")?; Ok(Self { source_peer_ip, @@ -65,7 +65,7 @@ impl WireGuardTunnel { self.udp .send_to(packet, self.endpoint) .await - .with_context(|| "Failed to send encrypted IP packet to WireGuard endpoint.")?; + .context("Failed to send encrypted IP packet to WireGuard endpoint.")?; debug!( "Sent {} bytes to WireGuard endpoint (encrypted IP packet)", packet.len() @@ -244,7 +244,7 @@ impl WireGuardTunnel { None, ) .map_err(|s| anyhow::anyhow!("{}", s)) - .with_context(|| "Failed to initialize boringtun Tunn") + .context("Failed to initialize boringtun Tunn") } /// Determine the inner protocol of the incoming IP packet (TCP/UDP).