diff --git a/client.go b/client.go index 24bd7ff2..9f50882d 100644 --- a/client.go +++ b/client.go @@ -17,6 +17,8 @@ import ( "net/url" "strings" "time" + + "golang.org/x/net/proxy" ) // ErrBadHandshake is returned when the server response to opening handshake is @@ -244,18 +246,39 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h defer cancel() } - var netDial netDialerFunc - switch { - case u.Scheme == "https" && d.NetDialTLSContext != nil: - netDial = d.NetDialTLSContext - case d.NetDialContext != nil: - netDial = d.NetDialContext - case d.NetDial != nil: - netDial = func(ctx context.Context, net, addr string) (net.Conn, error) { - return d.NetDial(net, addr) + netDial := newNetDialerFunc(u.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext) + + // If needed, wrap the dial function to connect through a proxy. + if d.Proxy != nil { + proxyURL, err := d.Proxy(req) + if err != nil { + return nil, nil, err + } + if proxyURL != nil { + forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext) + if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil { + tlsClientConfig := cloneTLSConfig(d.TLSClientConfig) + if tlsClientConfig.ServerName == "" { + _, hostNoPort := hostPortNoPort(proxyURL) + tlsClientConfig.ServerName = hostNoPort + } + netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig) + } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { + netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, nil) + } else { + dialer, err := proxy.FromURL(proxyURL, forwardDial) + if err != nil { + return nil, nil, err + } + if d, ok := dialer.(proxy.ContextDialer); ok { + netDial = d.DialContext + } else { + netDial = func(ctx context.Context, net, addr string) (net.Conn, error) { + return dialer.Dial(net, addr) + } + } + } } - default: - netDial = (&net.Dialer{}).DialContext } // If needed, wrap the dial function to set the connection deadline. @@ -275,20 +298,6 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } } - // If needed, wrap the dial function to connect through a proxy. - if d.Proxy != nil { - proxyURL, err := d.Proxy(req) - if err != nil { - return nil, nil, err - } - if proxyURL != nil { - netDial, err = proxyFromURL(proxyURL, netDial) - if err != nil { - return nil, nil, err - } - } - } - hostPort, hostNoPort := hostPortNoPort(u) trace := httptrace.ContextClientTrace(ctx) if trace != nil && trace.GetConn != nil { diff --git a/client_server_test.go b/client_server_test.go index e4546aea..149db2bf 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -85,6 +85,50 @@ func newTLSServer(t *testing.T) *cstServer { return &s } +type cstProxyServer struct{} + +func (s *cstProxyServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodConnect { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer conn.Close() + + upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host) + if err != nil { + _, _ = fmt.Fprintf(conn, "HTTP/1.1 502 Bad Gateway\r\n\r\n") + return + } + defer upstream.Close() + + _, _ = fmt.Fprintf(conn, "HTTP/1.1 200 Connection established\r\n\r\n") + + done := make(chan struct{}, 2) + go func() { + _, _ = io.Copy(upstream, conn) + done <- struct{}{} + }() + go func() { + _, _ = io.Copy(conn, upstream) + done <- struct{}{} + }() + <-done +} + +func newProxyServer() *httptest.Server { + return httptest.NewServer(&cstProxyServer{}) +} + +func newTLSProxyServer() *httptest.Server { + return httptest.NewTLSServer(&cstProxyServer{}) +} + func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Because tests wait for a response from a server, we are guaranteed that // the wait group count is incremented before the test waits on the group @@ -165,41 +209,103 @@ func sendRecv(t *testing.T, ws *Conn) { } func TestProxyDial(t *testing.T) { + testcases := []struct { + name string + isTLS bool + tlsServerName string + insecureSkipVerify bool + netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + }{{ + name: "http", + isTLS: false, + }, { + name: "https", + isTLS: true, + }, { + name: "https with ServerName", + isTLS: true, + tlsServerName: "example.com", + }, { + name: "https with insecureSkipVerify", + isTLS: true, + insecureSkipVerify: true, + }, { + name: "https with netDialTLSContext", + isTLS: true, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &tls.Dialer{ + Config: &tls.Config{ + InsecureSkipVerify: true, + }, + } + return dialer.DialContext(ctx, network, addr) + }, + }} + + for _, tc := range testcases { + t.Run(tc.name, func(tt *testing.T) { + s := newServer(tt) + defer s.Close() + + var ps *httptest.Server + if tc.isTLS { + ps = newTLSProxyServer() + } else { + ps = newProxyServer() + } - s := newServer(t) - defer s.Close() + psurl, _ := url.Parse(ps.URL) - surl, _ := url.Parse(s.Server.URL) + netDialCalled := false - cstDialer := cstDialer // make local copy for modification on next line. - cstDialer.Proxy = http.ProxyURL(surl) + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(psurl) + if tc.isTLS { + cstDialer.TLSClientConfig = &tls.Config{ + RootCAs: rootCAs(tt, ps), + ServerName: tc.tlsServerName, + InsecureSkipVerify: tc.insecureSkipVerify, + } + if tc.netDialTLSContext != nil { + cstDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + netDialCalled = true + return tc.netDialTLSContext(ctx, network, addr) + } + } else { + netDialCalled = true + } + } else { + netDialCalled = true + } - connect := false - origHandler := s.Server.Config.Handler + connect := false + origHandler := ps.Config.Handler - // Capture the request Host header. - s.Server.Config.Handler = http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodConnect { - connect = true - w.WriteHeader(http.StatusOK) - return + // Capture the request Host header. + ps.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodConnect { + connect = true + } + + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + tt.Fatalf("Dial: %v", err) } + defer ws.Close() + sendRecv(tt, ws) if !connect { - t.Log("connect not received") - http.Error(w, "connect not received", http.StatusMethodNotAllowed) - return + tt.Error("connect not received") + } + if !netDialCalled { + tt.Error("netDialTLSContext not called") } - origHandler.ServeHTTP(w, r) }) - - ws, _, err := cstDialer.Dial(s.URL, nil) - if err != nil { - t.Fatalf("Dial: %v", err) } - defer ws.Close() - sendRecv(t, ws) } func TestProxyAuthorizationDial(t *testing.T) { diff --git a/proxy.go b/proxy.go index b4683b9f..c3c0e45f 100644 --- a/proxy.go +++ b/proxy.go @@ -8,16 +8,35 @@ import ( "bufio" "bytes" "context" + "crypto/tls" "encoding/base64" "errors" "net" "net/http" "net/url" "strings" - - "golang.org/x/net/proxy" ) +func newNetDialerFunc( + scheme string, + netDial func(network, addr string) (net.Conn, error), + netDialContext func(ctx context.Context, network, addr string) (net.Conn, error), + netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error), +) netDialerFunc { + switch { + case scheme == "https" && netDialTLSContext != nil: + return netDialTLSContext + case netDialContext != nil: + return netDialContext + case netDial != nil: + return func(ctx context.Context, net, addr string) (net.Conn, error) { + return netDial(net, addr) + } + default: + return (&net.Dialer{}).DialContext + } +} + type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error) func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { @@ -28,78 +47,71 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) ( return fn(ctx, network, addr) } -func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) { - if proxyURL.Scheme == "http" { - return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil - } - dialer, err := proxy.FromURL(proxyURL, forwardDial) - if err != nil { - return nil, err - } - if d, ok := dialer.(proxy.ContextDialer); ok { - return d.DialContext, nil - } - return func(ctx context.Context, net, addr string) (net.Conn, error) { - return dialer.Dial(net, addr) - }, nil -} - -type httpProxyDialer struct { - proxyURL *url.URL - forwardDial netDialerFunc -} +// newHTTPProxyDialerFunc returns a netDialerFunc that dials using the provided +// proxyURL. The forwardDial function is used to establish the connection to the +// proxy server. If tlsClientConfig is not nil, the connection to the proxy is +// upgraded to a TLS connection with tls.Client. +func newHTTPProxyDialerFunc(proxyURL *url.URL, forwardDial netDialerFunc, tlsClientConfig *tls.Config) netDialerFunc { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + hostPort, _ := hostPortNoPort(proxyURL) + conn, err := forwardDial(ctx, network, hostPort) + if err != nil { + return nil, err + } -func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { - hostPort, _ := hostPortNoPort(hpd.proxyURL) - conn, err := hpd.forwardDial(ctx, network, hostPort) - if err != nil { - return nil, err - } + if tlsClientConfig != nil { + tlsConn := tls.Client(conn, tlsClientConfig) + if err = tlsConn.HandshakeContext(ctx); err != nil { + return nil, err + } + conn = tlsConn + } - connectHeader := make(http.Header) - if user := hpd.proxyURL.User; user != nil { - proxyUser := user.Username() - if proxyPassword, passwordSet := user.Password(); passwordSet { - credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) - connectHeader.Set("Proxy-Authorization", "Basic "+credential) + connectHeader := make(http.Header) + if user := proxyURL.User; user != nil { + proxyUser := user.Username() + if proxyPassword, passwordSet := user.Password(); passwordSet { + credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) + connectHeader.Set("Proxy-Authorization", "Basic "+credential) + } } - } - connectReq := &http.Request{ - Method: http.MethodConnect, - URL: &url.URL{Opaque: addr}, - Host: addr, - Header: connectHeader, - } + connectReq := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: connectHeader, + } - if err := connectReq.Write(conn); err != nil { - conn.Close() - return nil, err - } + if err := connectReq.Write(conn); err != nil { + conn.Close() + return nil, err + } - // Read response. It's OK to use and discard buffered reader here because - // the remote server does not speak until spoken to. - br := bufio.NewReader(conn) - resp, err := http.ReadResponse(br, connectReq) - if err != nil { - conn.Close() - return nil, err - } + // Read response. It's OK to use and discard buffered reader here because + // the remote server does not speak until spoken to. + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + conn.Close() + return nil, err + } - // Close the response body to silence false positives from linters. Reset - // the buffered reader first to ensure that Close() does not read from - // conn. - // Note: Applications must call resp.Body.Close() on a response returned - // http.ReadResponse to inspect trailers or read another response from the - // buffered reader. The call to resp.Body.Close() does not release - // resources. - br.Reset(bytes.NewReader(nil)) - _ = resp.Body.Close() + // Close the response body to silence false positives from linters. Reset + // the buffered reader first to ensure that Close() does not read from + // conn. + // Note: Applications must call resp.Body.Close() on a response returned + // http.ReadResponse to inspect trailers or read another response from the + // buffered reader. The call to resp.Body.Close() does not release + // resources. + br.Reset(bytes.NewReader(nil)) + _ = resp.Body.Close() - if resp.StatusCode != http.StatusOK { - _ = conn.Close() - f := strings.SplitN(resp.Status, " ", 2) - return nil, errors.New(f[1]) + if resp.StatusCode != http.StatusOK { + _ = conn.Close() + f := strings.SplitN(resp.Status, " ", 2) + return nil, errors.New(f[1]) + } + return conn, nil } - return conn, nil }