Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TLS connection to HTTP Proxy #950

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 34 additions & 25 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"net/url"
"strings"
"time"

"golang.org/x/net/proxy"
)

// ErrBadHandshake is returned when the server response to opening handshake is
Expand Down Expand Up @@ -244,18 +246,39 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
defer cancel()
}

var netDial netDialerFunc
switch {
case u.Scheme == "https" && d.NetDialTLSContext != nil:
netDial = d.NetDialTLSContext
case d.NetDialContext != nil:
netDial = d.NetDialContext
case d.NetDial != nil:
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
return d.NetDial(net, addr)
netDial := newNetDialerFunc(u.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext)

// If needed, wrap the dial function to connect through a proxy.
if d.Proxy != nil {
proxyURL, err := d.Proxy(req)
if err != nil {
return nil, nil, err
}
if proxyURL != nil {
forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext)
if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil {
tlsClientConfig := cloneTLSConfig(d.TLSClientConfig)
if tlsClientConfig.ServerName == "" {
_, hostNoPort := hostPortNoPort(proxyURL)
tlsClientConfig.ServerName = hostNoPort
}
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig)
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, nil)
} else {
dialer, err := proxy.FromURL(proxyURL, forwardDial)
if err != nil {
return nil, nil, err
}
if d, ok := dialer.(proxy.ContextDialer); ok {
netDial = d.DialContext
} else {
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
return dialer.Dial(net, addr)
}
}
}
}
default:
netDial = (&net.Dialer{}).DialContext
}

// If needed, wrap the dial function to set the connection deadline.
Expand All @@ -275,20 +298,6 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}

// If needed, wrap the dial function to connect through a proxy.
if d.Proxy != nil {
proxyURL, err := d.Proxy(req)
if err != nil {
return nil, nil, err
}
if proxyURL != nil {
netDial, err = proxyFromURL(proxyURL, netDial)
if err != nil {
return nil, nil, err
}
}
}
Comment on lines -278 to -290
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need to set a deadline for the dialing to proxy, so moved the line up.


hostPort, hostNoPort := hostPortNoPort(u)
trace := httptrace.ContextClientTrace(ctx)
if trace != nil && trace.GetConn != nil {
Expand Down
154 changes: 130 additions & 24 deletions client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,50 @@ func newTLSServer(t *testing.T) *cstServer {
return &s
}

type cstProxyServer struct{}

func (s *cstProxyServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodConnect {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}

conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer conn.Close()

upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host)
if err != nil {
_, _ = fmt.Fprintf(conn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
return
}
defer upstream.Close()

_, _ = fmt.Fprintf(conn, "HTTP/1.1 200 Connection established\r\n\r\n")

done := make(chan struct{}, 2)
go func() {
_, _ = io.Copy(upstream, conn)
done <- struct{}{}
}()
go func() {
_, _ = io.Copy(conn, upstream)
done <- struct{}{}
}()
<-done
}

func newProxyServer() *httptest.Server {
return httptest.NewServer(&cstProxyServer{})
}

func newTLSProxyServer() *httptest.Server {
return httptest.NewTLSServer(&cstProxyServer{})
}

func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Because tests wait for a response from a server, we are guaranteed that
// the wait group count is incremented before the test waits on the group
Expand Down Expand Up @@ -165,41 +209,103 @@ func sendRecv(t *testing.T, ws *Conn) {
}

func TestProxyDial(t *testing.T) {
testcases := []struct {
name string
isTLS bool
tlsServerName string
insecureSkipVerify bool
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
}{{
name: "http",
isTLS: false,
}, {
name: "https",
isTLS: true,
}, {
name: "https with ServerName",
isTLS: true,
tlsServerName: "example.com",
}, {
name: "https with insecureSkipVerify",
isTLS: true,
insecureSkipVerify: true,
}, {
name: "https with netDialTLSContext",
isTLS: true,
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &tls.Dialer{
Config: &tls.Config{
InsecureSkipVerify: true,
},
}
return dialer.DialContext(ctx, network, addr)
},
}}

for _, tc := range testcases {
t.Run(tc.name, func(tt *testing.T) {
s := newServer(tt)
defer s.Close()

var ps *httptest.Server
if tc.isTLS {
ps = newTLSProxyServer()
} else {
ps = newProxyServer()
}

s := newServer(t)
defer s.Close()
psurl, _ := url.Parse(ps.URL)

surl, _ := url.Parse(s.Server.URL)
netDialCalled := false

cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(surl)
cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(psurl)
if tc.isTLS {
cstDialer.TLSClientConfig = &tls.Config{
RootCAs: rootCAs(tt, ps),
ServerName: tc.tlsServerName,
InsecureSkipVerify: tc.insecureSkipVerify,
}
if tc.netDialTLSContext != nil {
cstDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialCalled = true
return tc.netDialTLSContext(ctx, network, addr)
}
} else {
netDialCalled = true
}
} else {
netDialCalled = true
}

connect := false
origHandler := s.Server.Config.Handler
connect := false
origHandler := ps.Config.Handler

// Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect {
connect = true
w.WriteHeader(http.StatusOK)
return
// Capture the request Host header.
ps.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect {
connect = true
}

origHandler.ServeHTTP(w, r)
})

ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
tt.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(tt, ws)

if !connect {
t.Log("connect not received")
http.Error(w, "connect not received", http.StatusMethodNotAllowed)
return
tt.Error("connect not received")
}
if !netDialCalled {
tt.Error("netDialTLSContext not called")
}
origHandler.ServeHTTP(w, r)
})

ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
}

func TestProxyAuthorizationDial(t *testing.T) {
Expand Down
Loading