Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
134130 committed Jul 18, 2024
1 parent bad5b0a commit 3dddd1c
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 8 deletions.
9 changes: 4 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext)
if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil {
tlsClientConfig := cloneTLSConfig(d.TLSClientConfig)
if d.TLSClientConfig == nil {
tlsClientConfig = &tls.Config{
ServerName: proxyURL.Hostname(),
}
if tlsClientConfig.ServerName == "" {
_, hostNoPort := hostPortNoPort(proxyURL)
tlsClientConfig.ServerName = hostNoPort
}
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig)
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
Expand Down Expand Up @@ -369,7 +368,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
if proto != "http/1.1" {
return nil, nil, fmt.Errorf(
"websocket: protocol %q was given but is not supported;"+
"sharing tls.Config with net/http Transport can cause this error: %w",
"sharing tlsServerName.Config with net/http Transport can cause this error: %w",
proto, err,
)
}
Expand Down
150 changes: 147 additions & 3 deletions client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,51 @@ func newTLSServer(t *testing.T) *cstServer {
return &s
}

type cstProxyServer struct{}

func (s *cstProxyServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodConnect {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}

conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer conn.Close()

upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host)
if err != nil {
_, _ = fmt.Fprintf(conn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
return
}
defer upstream.Close()

_, _ = fmt.Fprintf(conn, "HTTP/1.1 200 Connection established\r\n\r\n")

wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
_, _ = io.Copy(upstream, conn)
}()
go func() {
defer wg.Done()
_, _ = io.Copy(conn, upstream)
}()
wg.Wait()
}

func newProxyServer() *httptest.Server {
return httptest.NewServer(&cstProxyServer{})
}

func newTLSProxyServer() *httptest.Server {
return httptest.NewTLSServer(&cstProxyServer{})
}

func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Because tests wait for a response from a server, we are guaranteed that
// the wait group count is incremented before the test waits on the group
Expand Down Expand Up @@ -165,7 +210,6 @@ func sendRecv(t *testing.T, ws *Conn) {
}

func TestProxyDial(t *testing.T) {

s := newServer(t)
defer s.Close()

Expand Down Expand Up @@ -202,6 +246,106 @@ func TestProxyDial(t *testing.T) {
sendRecv(t, ws)
}

func TestProxyDialer(t *testing.T) {
testcases := []struct {
name string
isTLS bool
tlsServerName string // optional host for tls ServerName
insecureSkipVerify bool
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
}{{
name: "http",
isTLS: false,
}, {
name: "https",
isTLS: true,
}, {
name: "https with ServerName",
isTLS: true,
tlsServerName: "example.com",
}, {
name: "https with insecureSkipVerify",
isTLS: true,
insecureSkipVerify: true,
}, {
name: "https with netDialTLSContext",
isTLS: true,
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &tls.Dialer{
Config: &tls.Config{
InsecureSkipVerify: true,
},
}
return dialer.DialContext(ctx, network, addr)
},
}}

for _, tc := range testcases {
t.Run(tc.name, func(tt *testing.T) {
s := newServer(tt)
defer s.Close()

var ps *httptest.Server
if tc.isTLS {
ps = newTLSProxyServer()
} else {
ps = newProxyServer()
}

psurl, _ := url.Parse(ps.URL)

netDialCalled := false

cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(psurl)
if tc.isTLS {
cstDialer.TLSClientConfig = &tls.Config{
RootCAs: rootCAs(tt, ps),
ServerName: tc.tlsServerName,
InsecureSkipVerify: tc.insecureSkipVerify,
}
if tc.netDialTLSContext != nil {
cstDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialCalled = true
return tc.netDialTLSContext(ctx, network, addr)
}
} else {
netDialCalled = true
}
} else {
netDialCalled = true
}

connect := false
origHandler := ps.Config.Handler

// Capture the request Host header.
ps.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect {
connect = true
}

origHandler.ServeHTTP(w, r)
})

ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
tt.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(tt, ws)

if !connect {
tt.Error("connect not received")
}
if !netDialCalled {
tt.Error("netDialTLSContext not called")
}
})
}
}

func TestProxyAuthorizationDial(t *testing.T) {
s := newServer(t)
defer s.Close()
Expand Down Expand Up @@ -652,7 +796,7 @@ func TestHost(t *testing.T) {
server *httptest.Server // server to use
url string // host for request URI
header string // optional request host header
tls string // optional host for tls ServerName
tls string // optional host for tlsServerName ServerName
wantAddr string // expected host for dial
wantHeader string // expected request header on server
insecureSkipVerify bool
Expand Down Expand Up @@ -759,7 +903,7 @@ func TestHost(t *testing.T) {
}

check := func(protos map[*httptest.Server]string) {
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tlsServerName.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
if gotAddr != tt.wantAddr {
t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
}
Expand Down

0 comments on commit 3dddd1c

Please sign in to comment.