From 9475ed0ecf2051f19849590f0f042827acb197ab Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Wed, 18 Sep 2024 00:18:58 -0400 Subject: [PATCH] impl(http-inbound): legacy proxy --- boltconn/src/adapter/tcp_adapter.rs | 11 +- boltconn/src/common/mod.rs | 2 + boltconn/src/proxy/dispatcher.rs | 7 +- boltconn/src/proxy/http_inbound.rs | 225 ++++++++++++++++++++++++---- boltconn/src/proxy/mixed_inbound.rs | 3 +- 5 files changed, 205 insertions(+), 43 deletions(-) diff --git a/boltconn/src/adapter/tcp_adapter.rs b/boltconn/src/adapter/tcp_adapter.rs index 0a931ac..a75c18a 100644 --- a/boltconn/src/adapter/tcp_adapter.rs +++ b/boltconn/src/adapter/tcp_adapter.rs @@ -1,5 +1,5 @@ use crate::adapter::{Connector, DuplexCloseGuard, TcpIndicatorGuard, TcpStatus}; -use crate::common::{read_to_bytes_mut, MAX_PKT_SIZE}; +use crate::common::{read_to_bytes_mut, StreamOutboundTrait, MAX_PKT_SIZE}; use crate::proxy::{ConnAbortHandle, ConnContext, NetworkAddr}; use bytes::BytesMut; use io::Result; @@ -8,17 +8,16 @@ use std::net::SocketAddr; use std::sync::atomic::AtomicU8; use std::sync::Arc; use tokio::io::AsyncWriteExt; -use tokio::net::TcpStream; -pub struct TcpAdapter { +pub struct TcpAdapter { stat: TcpStatus, info: Arc, - inbound: TcpStream, + inbound: S, connector: Connector, abort_handle: ConnAbortHandle, } -impl TcpAdapter { +impl TcpAdapter { const BUF_SIZE: usize = 65536; #[allow(clippy::too_many_arguments)] @@ -26,7 +25,7 @@ impl TcpAdapter { src_addr: SocketAddr, dst_addr: NetworkAddr, info: Arc, - inbound: TcpStream, + inbound: S, available: Arc, connector: Connector, abort_handle: ConnAbortHandle, diff --git a/boltconn/src/common/mod.rs b/boltconn/src/common/mod.rs index c0caa5a..1a19576 100644 --- a/boltconn/src/common/mod.rs +++ b/boltconn/src/common/mod.rs @@ -37,6 +37,8 @@ where pub trait StreamOutboundTrait: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static {} +impl StreamOutboundTrait for tokio::net::TcpStream {} + #[cfg(target_os = "windows")] impl StreamOutboundTrait for tokio::net::windows::named_pipe::NamedPipeServer {} #[cfg(not(target_os = "windows"))] diff --git a/boltconn/src/proxy/dispatcher.rs b/boltconn/src/proxy/dispatcher.rs index 41f4029..085e31a 100644 --- a/boltconn/src/proxy/dispatcher.rs +++ b/boltconn/src/proxy/dispatcher.rs @@ -4,6 +4,7 @@ use crate::adapter::{ TrojanOutbound, TunUdpAdapter, WireguardHandle, WireguardManager, }; use crate::common::duplex_chan::DuplexChan; +use crate::common::StreamOutboundTrait; use crate::dispatch::{ ConnInfo, Dispatching, GeneralProxy, InboundIdentity, InboundInfo, ProxyImpl, }; @@ -19,7 +20,7 @@ use std::net::SocketAddr; use std::sync::atomic::{AtomicBool, AtomicU8}; use std::sync::Arc; use std::time::Duration; -use tokio::net::{TcpStream, UdpSocket}; +use tokio::net::UdpSocket; use tokio::sync::mpsc; pub(crate) enum DispatchError { @@ -208,13 +209,13 @@ impl Dispatcher { Ok(ChainOutbound::new(res)) } - pub async fn submit_tcp( + pub async fn submit_tcp( &self, inbound: InboundInfo, src_addr: SocketAddr, dst_addr: NetworkAddr, indicator: Arc, - stream: TcpStream, + stream: S, ) -> Result<(), DispatchError> { let process_info = process::get_pid(src_addr, process::NetworkType::Tcp) .map_or(None, process::get_process_info); diff --git a/boltconn/src/proxy/http_inbound.rs b/boltconn/src/proxy/http_inbound.rs index 105c32a..2872513 100644 --- a/boltconn/src/proxy/http_inbound.rs +++ b/boltconn/src/proxy/http_inbound.rs @@ -1,8 +1,17 @@ +use crate::adapter::Connector; +use crate::common::duplex_chan::DuplexChan; use crate::dispatch::{InboundIdentity, InboundInfo}; +use crate::intercept::HyperBody; use crate::proxy::error::TransportError; -use crate::proxy::Dispatcher; +use crate::proxy::{Dispatcher, NetworkAddr}; use base64::Engine; -use httparse::Request; +use bytes::Bytes; +use http::{HeaderMap, HeaderValue, Request, Response}; +use http_body_util::combinators::BoxBody; +use http_body_util::BodyExt; +use hyper::body::Incoming; +use hyper::service::service_fn; +use hyper_util::rt::TokioIo; use std::collections::HashMap; use std::io; use std::net::SocketAddr; @@ -65,7 +74,7 @@ impl HttpInbound { ) -> Result<(), TransportError> { // get response let mut buf_reader = BufReader::new(socket); - let mut req = first_byte.unwrap_or_default(); + let mut req = String::new(); while !req.ends_with("\r\n\r\n") { if buf_reader.read_line(&mut req).await? == 0 { return Err(TransportError::Http("Connecting: EOF")); @@ -76,7 +85,7 @@ impl HttpInbound { } let mut socket = buf_reader.into_inner(); let mut buf = [httparse::EMPTY_HEADER; 16]; - let mut req_struct = Request::new(buf.as_mut()); + let mut req_struct = httparse::Request::new(buf.as_mut()); req_struct .parse(req.as_bytes()) .map_err(|_| TransportError::Http("Failed to parse request header"))?; @@ -85,8 +94,6 @@ impl HttpInbound { && req_struct.version.map_or(false, |v| v == 1) { if let Some(Ok(dest)) = req_struct.path.map(|p| p.parse()) { - // None: invalid - // Some(None): valid but empty auth let authorized = if auth.is_empty() { Some(None) } else { @@ -97,32 +104,8 @@ impl HttpInbound { let Ok(value) = std::str::from_utf8(hdr.value) else { break; }; - // manually split - if value.is_ascii() && value.len() > 6 { - let (left, right) = value.split_at(6); - if left.eq_ignore_ascii_case("basic ") { - let b64decoder = base64::engine::general_purpose::STANDARD; - let code = b64decoder.decode(right).map_err(|_| { - TransportError::Http("bad authorization encoding") - })?; - let text = - std::str::from_utf8(code.as_slice()).map_err(|_| { - TransportError::Http("bad authorization utf-8") - })?; - let v: Vec = - text.split(':').map(|s| s.to_string()).collect(); - if v.len() == 2 - && auth - .get(v.first().unwrap()) - .is_some_and(|pwd| pwd == v.get(1).unwrap()) - { - r = Some(Some(v.first().unwrap().clone())); - } else { - r = None; - } - break; - } - } + r = validate_auth(Some(value), &auth); + break; } } r @@ -153,6 +136,32 @@ impl HttpInbound { Err(TransportError::Http("Invalid CONNECT request")) } + pub(super) async fn serve_legacy_connection( + self_port: u16, + socket: TcpStream, + auth: Arc>, + src: SocketAddr, + dispatcher: Arc, + ) -> Result<(), TransportError> { + let legacy_proxy = LegacyProxy { + client: Arc::new(tokio::sync::Mutex::new(None)), + auth, + port: self_port, + src, + dispatcher, + }; + + let service = service_fn(move |req| legacy_proxy.clone().serve_connection(req)); + + tokio::spawn( + hyper::server::conn::http1::Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .serve_connection(TokioIo::new(socket), service), + ); + Ok(()) + } + const fn response403() -> &'static str { "HTTP/1.1 403 Forbidden\r\n\r\n" } @@ -161,3 +170,155 @@ impl HttpInbound { "HTTP/1.1 200 OK\r\n\r\n" } } + +#[derive(Clone)] +struct LegacyProxy { + // Since we only support http/1.1 in legacy proxy, there is no concurrent request. + client: Arc>>>, + auth: Arc>, + port: u16, + src: SocketAddr, + dispatcher: Arc, +} + +impl LegacyProxy { + pub async fn serve_connection( + self, + mut req: Request, + ) -> hyper::Result> { + let conn_keep_alive = check_keep_alive(req.headers()); + let dest = match req.uri().authority() { + Some(auth) => { + let host = auth.host(); + let port = auth.port_u16().unwrap_or(80); + NetworkAddr::DomainName { + domain_name: host.to_string(), + port, + } + } + None => { + return Ok(Response::builder() + .status(400) + .body(HyperBody::new( + http_body_util::Full::new(Bytes::new()).map_err(|e| match e {}), + )) + .unwrap()); + } + }; + let Some(http_auth) = validate_auth( + if let Some(value) = req.headers().get("Proxy-Authorization") { + value.to_str().ok() + } else { + None + }, + &self.auth, + ) else { + // Unauthorized + return Ok(Response::builder() + .status(403) + .body(HyperBody::new( + http_body_util::Full::new(Bytes::new()).map_err(|e| match e {}), + )) + .unwrap()); + }; + clean_headers(req.headers_mut()); + set_keep_alive(req.headers_mut(), conn_keep_alive); + let mut client_holder = self.client.lock().await; + if client_holder.is_none() { + let (left, right) = Connector::new_pair(10); + let _ = self + .dispatcher + .submit_tcp( + InboundInfo::Http(InboundIdentity { + user: http_auth, + port: Some(self.port), + }), + self.src, + dest, + Arc::new(AtomicU8::new(2)), + DuplexChan::new(right), + ) + .await; + let (send_req, conn) = hyper::client::conn::http1::Builder::new() + .handshake(TokioIo::new(DuplexChan::new(left))) + .await?; + tokio::spawn(conn); + *client_holder = Some(send_req); + } + let client = client_holder.as_mut().unwrap(); + let mut res = client.send_request(req).await?; + drop(client_holder); + let resp_keep_alive = conn_keep_alive && check_keep_alive(res.headers()); + clean_headers(res.headers_mut()); + set_keep_alive(res.headers_mut(), resp_keep_alive); + Ok(res.map(BoxBody::new)) + } +} + +// Return value: +// - None: invalid +// - Some(None): valid but empty auth +// - Some(Some(user)): valid auth +fn validate_auth( + auth: Option<&str>, + server_auth: &HashMap, +) -> Option> { + if server_auth.is_empty() { + return Some(None); + } else if let Some(value) = auth { + // manually split + if value.is_ascii() && value.len() > 6 { + let (left, right) = value.split_at(6); + if left.eq_ignore_ascii_case("basic ") { + let b64decoder = base64::engine::general_purpose::STANDARD; + let code = b64decoder.decode(right).ok()?; + let text = std::str::from_utf8(code.as_slice()).ok()?; + let v: Vec = text.split(':').map(|s| s.to_string()).collect(); + if v.len() == 2 + && server_auth + .get(v.first().unwrap()) + .is_some_and(|pwd| pwd == v.get(1).unwrap()) + { + return Some(Some(v.first().unwrap().clone())); + } + } + } + } + None +} + +fn check_keep_alive(headers: &HeaderMap) -> bool { + headers.get("Connection").map_or(false, |v| { + v.to_str() + .unwrap_or_default() + .eq_ignore_ascii_case("keep-alive") + }) || headers.get("Proxy-Connection").map_or(false, |v| { + v.to_str() + .unwrap_or_default() + .eq_ignore_ascii_case("keep-alive") + }) +} + +fn clean_headers(headers: &mut HeaderMap) { + const HOP_BY_HOP_HEADERS: [&str; 10] = [ + "Keep-Alive", + "Transfer-Encoding", + "TE", + "Connection", + "Trailer", + "Upgrade", + "Proxy-Authorization", + "Proxy-Authenticate", + "Proxy-Connection", // Not standard, but many implementations do send this header + "Connection", + ]; + for key in HOP_BY_HOP_HEADERS.iter() { + while headers.remove(*key).is_some() {} + } +} + +fn set_keep_alive(headers: &mut HeaderMap, keep_alive: bool) { + if !keep_alive { + headers.insert("Connection", HeaderValue::from_static("close")); + } +} diff --git a/boltconn/src/proxy/mixed_inbound.rs b/boltconn/src/proxy/mixed_inbound.rs index a6a23b6..e89e4b7 100644 --- a/boltconn/src/proxy/mixed_inbound.rs +++ b/boltconn/src/proxy/mixed_inbound.rs @@ -4,7 +4,6 @@ use std::collections::HashMap; use std::io; use std::net::SocketAddr; use std::sync::Arc; -use tokio::io::AsyncReadExt; use tokio::net::{TcpListener, TcpStream}; pub struct MixedInbound { @@ -59,7 +58,7 @@ impl MixedInbound { async fn serve_connection( self_port: u16, - mut socks_stream: TcpStream, + socks_stream: TcpStream, http_auth: Arc>, socks_auth: Arc>, src_addr: SocketAddr,