Skip to content

Commit

Permalink
Merge pull request #7 from azlotnikov/fix-tls-config
Browse files Browse the repository at this point in the history
Add tls config cache, set ServerName value
  • Loading branch information
9seconds authored Dec 19, 2019
2 parents e78ffda + 398fbcc commit 64554e6
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 19 deletions.
59 changes: 53 additions & 6 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"crypto/tls"
"io"
"net"
"strings"
"sync"
"time"

"github.com/valyala/fasthttp"
Expand All @@ -27,8 +29,9 @@ const (
// Client is the implementation of HTTP1 client which sets body as a
// stream.
type Client struct {
dialer Dialer
tlsConfig *tls.Config
dialer Dialer
tlsConfigMap map[string]*tls.Config
tlsConfigMapLock sync.Mutex
}

// DoTimeout does HTTP request with the given timeout.
Expand Down Expand Up @@ -71,7 +74,7 @@ func (c *Client) do(req *fasthttp.Request, resp *fasthttp.Response, readTimeout,
}

if !isHTTP {
conn = tls.Client(conn, c.tlsConfig)
conn = tls.Client(conn, c.cachedTLSConfig(addr))
}

if err = conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
Expand Down Expand Up @@ -134,10 +137,54 @@ func (c *Client) do(req *fasthttp.Request, resp *fasthttp.Response, readTimeout,
return nil
}

func (c *Client) cachedTLSConfig(addr string) *tls.Config {
c.tlsConfigMapLock.Lock()

cfg := c.tlsConfigMap[addr]
if cfg == nil {
cfg = newClientTLSConfig(addr)
c.tlsConfigMap[addr] = cfg
}
c.tlsConfigMapLock.Unlock()

return cfg
}

func newClientTLSConfig(addr string) *tls.Config {
c := &tls.Config{}
if c.ClientSessionCache == nil {
c.ClientSessionCache = tls.NewLRUClientSessionCache(0)
}

if len(c.ServerName) == 0 {
serverName := tlsServerName(addr)
if serverName == "*" {
c.InsecureSkipVerify = true
} else {
c.ServerName = serverName
}
}

return c
}

func tlsServerName(addr string) string {
if !strings.Contains(addr, ":") {
return addr
}

host, _, err := net.SplitHostPort(addr)
if err != nil {
return "*"
}

return host
}

// NewClient creates a new instance of HTTP1 client.
func NewClient(dialer Dialer, tlsConfig *tls.Config) *Client {
func NewClient(dialer Dialer) *Client {
return &Client{
dialer: dialer,
tlsConfig: tlsConfig,
dialer: dialer,
tlsConfigMap: make(map[string]*tls.Config),
}
}
9 changes: 4 additions & 5 deletions client/client_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package client

import (
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -51,7 +50,7 @@ func (suite *ClientTestSuite) TearDownSuite() {

func (suite *ClientTestSuite) TestStreamContent() {
dialer, _ := NewSimpleDialer(FastHTTPBaseDialer, time.Second)
client := NewClient(dialer, &tls.Config{InsecureSkipVerify: true}) // nolint: gosec
client := NewClient(dialer) // nolint: gosec
req := &fasthttp.Request{}
resp := &fasthttp.Response{}

Expand All @@ -66,7 +65,7 @@ func (suite *ClientTestSuite) TestPooledDialer() {
dialer, _ := NewPooledDialer(FastHTTPBaseDialer, time.Second, 5)
go dialer.Run()

client := NewClient(dialer, &tls.Config{InsecureSkipVerify: true}) // nolint: gosec
client := NewClient(dialer) // nolint: gosec
req := &fasthttp.Request{}
resp := &fasthttp.Response{}

Expand All @@ -85,7 +84,7 @@ func (suite *ClientTestSuite) TestPooledDialerSameSocket() {
dialer, _ := NewPooledDialer(FastHTTPBaseDialer, time.Second, 5)
go dialer.Run()

client := NewClient(dialer, &tls.Config{InsecureSkipVerify: true}) // nolint: gosec
client := NewClient(dialer) // nolint: gosec
req := &fasthttp.Request{}
resp := &fasthttp.Response{}

Expand Down Expand Up @@ -113,7 +112,7 @@ func (suite *ClientTestSuite) TestBrokenResponseHeader() {

go dialer.Run()

client := NewClient(dialer, &tls.Config{InsecureSkipVerify: true}) // nolint: gosec
client := NewClient(dialer) // nolint: gosec
req := &fasthttp.Request{}
resp := &fasthttp.Response{}

Expand Down
10 changes: 2 additions & 8 deletions http_clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package httransform
import (
"bufio"
"bytes"
"crypto/tls"
"net"
"net/url"
"time"
Expand All @@ -30,10 +29,6 @@ const (
DefaultHTTPTImeout = 3 * time.Minute
)

var (
defaultTLSConfig = &tls.Config{InsecureSkipVerify: true} // nolint: gosec
)

// HTTPRequestExecutor is an interface to be used for ExecuteRequest and
// ExecuteRequestTimeout functions.
type HTTPRequestExecutor interface {
Expand Down Expand Up @@ -165,12 +160,12 @@ func MakeDefaultCONNECTProxyClient(proxyURL *url.URL) HTTPRequestExecutor {

func makeStreamingClosingHTTPClient(dialer client.BaseDialer) HTTPRequestExecutor {
newDialer, _ := client.NewSimpleDialer(dialer, ConnectDialTimeout)
return client.NewClient(newDialer, defaultTLSConfig)
return client.NewClient(newDialer)
}

func makeStreamingPooledHTTPClient(dialer client.BaseDialer) HTTPRequestExecutor {
newDialer, _ := client.NewPooledDialer(dialer, ConnectDialTimeout, MaxConnsPerHost)
return client.NewClient(newDialer, defaultTLSConfig)
return client.NewClient(newDialer)
}

func makeDefaultHTTPClient(dialFunc fasthttp.DialFunc) *fasthttp.Client {
Expand All @@ -179,7 +174,6 @@ func makeDefaultHTTPClient(dialFunc fasthttp.DialFunc) *fasthttp.Client {
DialDualStack: true,
DisableHeaderNamesNormalizing: true,
MaxConnsPerHost: MaxConnsPerHost,
TLSConfig: defaultTLSConfig,
}
}

Expand Down

0 comments on commit 64554e6

Please sign in to comment.