Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature issue #479 #602

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ type Dialer struct {
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar http.CookieJar

// custom proxy connect header
ProxyConnectHeader http.Header
}

// Dial creates a new client connection by calling DialContext with a background context.
Expand Down Expand Up @@ -274,7 +277,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err
}
if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
dialer, err := proxy_FromURL(proxyURL, &netDialer{d.ProxyConnectHeader, netDial})
if err != nil {
return nil, nil, err
}
Expand Down
7 changes: 7 additions & 0 deletions client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ func TestProxyDial(t *testing.T) {

cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(surl)
cstDialer.ProxyConnectHeader = map[string][]string{
"User-Agents": {"xxx"},
}

connect := false
origHandler := s.Server.Config.Handler
Expand All @@ -166,6 +169,10 @@ func TestProxyDial(t *testing.T) {
if r.Method == "CONNECT" {
connect = true
w.WriteHeader(http.StatusOK)
if r.Header.Get("User-Agents") != "xxx" {
t.Log("xxx not found in the request header")
http.Error(w, "header xxx not found", http.StatusMethodNotAllowed)
}
return
}

Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/gorilla/websocket

go 1.12

require golang.org/x/tools v0.0.0-20200619210111-0f592d2728bb
18 changes: 18 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619210111-0f592d2728bb h1:/7SQoPdMxZ0c/Zu9tBJgMbRE/BmK6i9QXflNJXKAmw0=
golang.org/x/tools v0.0.0-20200619210111-0f592d2728bb/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
23 changes: 19 additions & 4 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,32 @@ import (
"strings"
)

type netDialerFunc func(network, addr string) (net.Conn, error)
// type netDialerFunc func(network, addr string) (net.Conn, error)
//
// func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
// return fn(network, addr)
// }
type netDialer struct {
proxyHeader http.Header
f func(network, addr string) (net.Conn, error)
}

func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr)
func (n netDialer) Dial(network, addr string) (net.Conn, error) {
return n.f(network, addr)
}

func init() {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
p, _ := forwardDialer.(*netDialer)
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial, proxyHeader: p.proxyHeader}, nil
// return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
})
}

type httpProxyDialer struct {
proxyURL *url.URL
forwardDial func(network, addr string) (net.Conn, error)
proxyHeader http.Header
}

func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
Expand All @@ -47,6 +58,10 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
}
}

for k, v := range hpd.proxyHeader {
connectHeader[k] = v
}

connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: addr},
Expand Down
31 changes: 25 additions & 6 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package websocket
import (
"bufio"
"errors"
"fmt"
"io"
"net/http"
"net/url"
Expand Down Expand Up @@ -44,6 +45,7 @@ type Upgrader struct {
// WriteBufferSize.
WriteBufferPool BufferPool

// Subprotocols have lower priority than NegotiateSuprotocol.
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is not nil, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
Expand All @@ -70,6 +72,13 @@ type Upgrader struct {
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
// NegotiateSubprotocol has higher priority than Subprotocols.
// NegotiateSubprotocol returns the negotiated subprotocol for the handshake
// request. If the returned string is "", then the the Sec-Websocket-Protocol header
// is not included in the handshake response. If the function returns an error, then
// Upgrade responds to the client with http.StatusBadRequest.
// If this function is not nil, then the Upgrader.Subportocols field is ignored.
NegotiateSubprotocol func(r *http.Request) (string, error)
}

func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
Expand All @@ -96,7 +105,7 @@ func checkSameOrigin(r *http.Request) bool {
return equalASCIIFold(u.Host, r.Host)
}

func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
func (u *Upgrader) selectSubprotocol(r *http.Request) string {
if u.Subprotocols != nil {
clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols {
Expand All @@ -106,20 +115,21 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
}
}
}
} else if responseHeader != nil {
return responseHeader.Get("Sec-Websocket-Protocol")
}
return ""
}

// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-WebSocket-Protocol).
// request. Use the responseHeader to specify cookies (Set-Cookie).
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
//
// The responseHeader does not support negotiated subprotocol(Sec-Websocket-Protocol)
// IF necessary,please use Upgrader.NegotiateSubprotocol and Upgrader.Subprotocols
// Use the method to view the Upgrader struct.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
const badHandshake = "websocket: the client is not using the websocket protocol: "

Expand Down Expand Up @@ -156,7 +166,16 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
}

subprotocol := u.selectSubprotocol(r, responseHeader)
subprotocol := ""
if u.NegotiateSubprotocol != nil {
str, err := u.NegotiateSubprotocol(r)
if err != nil {
return u.returnError(w, r, http.StatusBadRequest, fmt.Sprintf("websocket:handshake negotiation protocol error:%s", err))
}
subprotocol = str
} else {
subprotocol = u.selectSubprotocol(r)
}

// Negotiate PMCE
var compress bool
Expand Down
73 changes: 73 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ package websocket
import (
"bufio"
"bytes"
"errors"
"net"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -117,3 +119,74 @@ func TestBufioReuse(t *testing.T) {
}
}
}

var negotiateSubprotocolTests = []struct {
*Upgrader
match bool
shouldErr bool
}{
{
&Upgrader{
NegotiateSubprotocol: func(r *http.Request) (s string, err error) { return "json", nil },
}, true, false,
},
{
&Upgrader{
Subprotocols: []string{"json"},
}, true, false,
},
{
&Upgrader{
Subprotocols: []string{"not-match"},
}, false, false,
},
{
&Upgrader{
NegotiateSubprotocol: func(r *http.Request) (s string, err error) { return "", errors.New("not-match") },
}, false, true,
},
}

func TestNegotiateSubprotocol(t *testing.T) {
for i := range negotiateSubprotocolTests {
upgrade := negotiateSubprotocolTests[i].Upgrader

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upgrade.Upgrade(w, r, nil)
}))

req, err := http.NewRequest("GET", s.URL, strings.NewReader(""))
if err != nil {
t.Fatalf("NewRequest retuened error %v", err)
}

req.Header.Set("Connection", "upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-Websocket-Version", "13")
req.Header.Set("Sec-Websocket-Protocol", "json")
req.Header.Set("Sec-Websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Do returned error %v", err)
}

if negotiateSubprotocolTests[i].shouldErr && resp.StatusCode != http.StatusBadRequest {
t.Errorf("The expecred status code is %d,actual status code is %d", http.StatusBadRequest, resp.StatusCode)
} else {
if negotiateSubprotocolTests[i].match {
protocol := resp.Header.Get("Sec-Websocket-Protocol")
if protocol != "json" {
t.Errorf("Negotiation protocol failed,request protocol is json,reponese protocol is %s", protocol)
}
} else {
if _, ok := resp.Header["Sec-Websocket-Protocol"]; ok {
t.Errorf("Negotiation protocol failed,Sec-Websocket-Protocol field should be empty")
}
}
}
s.Close()
resp.Body.Close()
}

}