diff --git a/client/client.go b/client/client.go index 3677a69f..d6d2fa44 100644 --- a/client/client.go +++ b/client/client.go @@ -6,6 +6,8 @@ import ( "crypto/tls" "io" "net" + "strings" + "sync" "time" "github.com/valyala/fasthttp" @@ -27,8 +29,9 @@ const ( // Client is the implementation of HTTP1 client which sets body as a // stream. type Client struct { - dialer Dialer - tlsConfig *tls.Config + dialer Dialer + tlsConfigMap map[string]*tls.Config + tlsConfigMapLock sync.Mutex } // DoTimeout does HTTP request with the given timeout. @@ -71,7 +74,7 @@ func (c *Client) do(req *fasthttp.Request, resp *fasthttp.Response, readTimeout, } if !isHTTP { - conn = tls.Client(conn, c.tlsConfig) + conn = tls.Client(conn, c.cachedTLSConfig(addr)) } if err = conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { @@ -134,10 +137,54 @@ func (c *Client) do(req *fasthttp.Request, resp *fasthttp.Response, readTimeout, return nil } +func (c *Client) cachedTLSConfig(addr string) *tls.Config { + c.tlsConfigMapLock.Lock() + + cfg := c.tlsConfigMap[addr] + if cfg == nil { + cfg = newClientTLSConfig(addr) + c.tlsConfigMap[addr] = cfg + } + c.tlsConfigMapLock.Unlock() + + return cfg +} + +func newClientTLSConfig(addr string) *tls.Config { + c := &tls.Config{} + if c.ClientSessionCache == nil { + c.ClientSessionCache = tls.NewLRUClientSessionCache(0) + } + + if len(c.ServerName) == 0 { + serverName := tlsServerName(addr) + if serverName == "*" { + c.InsecureSkipVerify = true + } else { + c.ServerName = serverName + } + } + + return c +} + +func tlsServerName(addr string) string { + if !strings.Contains(addr, ":") { + return addr + } + + host, _, err := net.SplitHostPort(addr) + if err != nil { + return "*" + } + + return host +} + // NewClient creates a new instance of HTTP1 client. -func NewClient(dialer Dialer, tlsConfig *tls.Config) *Client { +func NewClient(dialer Dialer) *Client { return &Client{ - dialer: dialer, - tlsConfig: tlsConfig, + dialer: dialer, + tlsConfigMap: make(map[string]*tls.Config), } } diff --git a/client/client_test.go b/client/client_test.go index 8cf4de11..befbf5a3 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,7 +1,6 @@ package client import ( - "crypto/tls" "net/http" "net/http/httptest" "testing" @@ -51,7 +50,7 @@ func (suite *ClientTestSuite) TearDownSuite() { func (suite *ClientTestSuite) TestStreamContent() { dialer, _ := NewSimpleDialer(FastHTTPBaseDialer, time.Second) - client := NewClient(dialer, &tls.Config{InsecureSkipVerify: true}) // nolint: gosec + client := NewClient(dialer) // nolint: gosec req := &fasthttp.Request{} resp := &fasthttp.Response{} @@ -66,7 +65,7 @@ func (suite *ClientTestSuite) TestPooledDialer() { dialer, _ := NewPooledDialer(FastHTTPBaseDialer, time.Second, 5) go dialer.Run() - client := NewClient(dialer, &tls.Config{InsecureSkipVerify: true}) // nolint: gosec + client := NewClient(dialer) // nolint: gosec req := &fasthttp.Request{} resp := &fasthttp.Response{} @@ -85,7 +84,7 @@ func (suite *ClientTestSuite) TestPooledDialerSameSocket() { dialer, _ := NewPooledDialer(FastHTTPBaseDialer, time.Second, 5) go dialer.Run() - client := NewClient(dialer, &tls.Config{InsecureSkipVerify: true}) // nolint: gosec + client := NewClient(dialer) // nolint: gosec req := &fasthttp.Request{} resp := &fasthttp.Response{} @@ -113,7 +112,7 @@ func (suite *ClientTestSuite) TestBrokenResponseHeader() { go dialer.Run() - client := NewClient(dialer, &tls.Config{InsecureSkipVerify: true}) // nolint: gosec + client := NewClient(dialer) // nolint: gosec req := &fasthttp.Request{} resp := &fasthttp.Response{} diff --git a/http_clients.go b/http_clients.go index 8d7c2335..efa114cd 100644 --- a/http_clients.go +++ b/http_clients.go @@ -3,7 +3,6 @@ package httransform import ( "bufio" "bytes" - "crypto/tls" "net" "net/url" "time" @@ -30,10 +29,6 @@ const ( DefaultHTTPTImeout = 3 * time.Minute ) -var ( - defaultTLSConfig = &tls.Config{InsecureSkipVerify: true} // nolint: gosec -) - // HTTPRequestExecutor is an interface to be used for ExecuteRequest and // ExecuteRequestTimeout functions. type HTTPRequestExecutor interface { @@ -165,12 +160,12 @@ func MakeDefaultCONNECTProxyClient(proxyURL *url.URL) HTTPRequestExecutor { func makeStreamingClosingHTTPClient(dialer client.BaseDialer) HTTPRequestExecutor { newDialer, _ := client.NewSimpleDialer(dialer, ConnectDialTimeout) - return client.NewClient(newDialer, defaultTLSConfig) + return client.NewClient(newDialer) } func makeStreamingPooledHTTPClient(dialer client.BaseDialer) HTTPRequestExecutor { newDialer, _ := client.NewPooledDialer(dialer, ConnectDialTimeout, MaxConnsPerHost) - return client.NewClient(newDialer, defaultTLSConfig) + return client.NewClient(newDialer) } func makeDefaultHTTPClient(dialFunc fasthttp.DialFunc) *fasthttp.Client { @@ -179,7 +174,6 @@ func makeDefaultHTTPClient(dialFunc fasthttp.DialFunc) *fasthttp.Client { DialDualStack: true, DisableHeaderNamesNormalizing: true, MaxConnsPerHost: MaxConnsPerHost, - TLSConfig: defaultTLSConfig, } }