diff --git a/lib/backend/registrybackend/blobclient.go b/lib/backend/registrybackend/blobclient.go index 276636643..98117a51e 100644 --- a/lib/backend/registrybackend/blobclient.go +++ b/lib/backend/registrybackend/blobclient.go @@ -63,7 +63,7 @@ type BlobClient struct { // NewBlobClient creates a new BlobClient. func NewBlobClient(config Config) (*BlobClient, error) { config = config.applyDefaults() - authenticator, err := security.NewAuthenticator(config.Address, config.Security) + authenticator, err := config.Authenticator() if err != nil { return nil, fmt.Errorf("cannot create tag client authenticator: %s", err) } diff --git a/lib/backend/registrybackend/blobclient_test.go b/lib/backend/registrybackend/blobclient_test.go index 65142bac5..570db8da4 100644 --- a/lib/backend/registrybackend/blobclient_test.go +++ b/lib/backend/registrybackend/blobclient_test.go @@ -19,8 +19,10 @@ import ( "io" "net/http" "testing" + "time" "github.com/pressly/chi" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber/kraken/core" "github.com/uber/kraken/lib/backend/backenderrors" @@ -125,3 +127,61 @@ func TestBlobDownloadFileNotFound(t *testing.T) { var b bytes.Buffer require.Equal(backenderrors.ErrBlobNotFound, client.Download(namespace, "data", &b)) } + +func TestBlobDownloadHeaderTimeout(t *testing.T) { + require := require.New(t) + + blob := randutil.Blob(32 * memsize.KB) + namespace := core.NamespaceFixture() + + r := chi.NewRouter() + r.Get(fmt.Sprintf("/v2/%s/blobs/{blob}", namespace), func(w http.ResponseWriter, req *http.Request) { + time.Sleep(time.Second) + // ignoring errors here, as this will fail after we timeout below + _, _ = io.Copy(w, bytes.NewReader(blob)) + }) + r.Head(fmt.Sprintf("/v2/%s/blobs/{blob}", namespace), func(w http.ResponseWriter, req *http.Request) { + time.Sleep(time.Second) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(blob))) + }) + addr, stop := testutil.StartServer(r) + defer stop() + + config := newTestConfig(addr) + config.ResponseHeaderTimeout = 100 * time.Millisecond + client, err := NewBlobClient(config) + require.NoError(err) + + _, err = client.Stat(namespace, "data") + if assert.NotNil(t, err) { + assert.Contains(t, err.Error(), "timeout awaiting response headers") + } + + var b bytes.Buffer + err = client.Download(namespace, "data", &b) + if assert.NotNil(t, err) { + assert.Contains(t, err.Error(), "timeout awaiting response headers") + } +} + +// FIXME: debugging failing tests on Travis +//func TestBlobDownloadConnectTimeout(t *testing.T) { +// require := require.New(t) +// +// // unroutable address, courtesy of https://stackoverflow.com/a/904609/4867444 +// config := newTestConfig("10.255.255.1") +// config.ConnectTimeout = 100 * time.Millisecond +// client, err := NewBlobClient(config) +// require.NoError(err) +// +// _, err = client.Stat("dummynamespace", "data") +// if assert.NotNil(t, err) { +// assert.Contains(t, err.Error(), "i/o timeout") +// } +// +// var b bytes.Buffer +// err = client.Download("dummynamespace", "data", &b) +// if assert.NotNil(t, err) { +// assert.Contains(t, err.Error(), "i/o timeout") +// } +//} diff --git a/lib/backend/registrybackend/config.go b/lib/backend/registrybackend/config.go index 16cf40dc2..fd26f83e1 100644 --- a/lib/backend/registrybackend/config.go +++ b/lib/backend/registrybackend/config.go @@ -14,6 +14,8 @@ package registrybackend import ( + "net" + "net/http" "time" "github.com/uber/kraken/lib/backend/registrybackend/security" @@ -21,9 +23,13 @@ import ( // Config defines the registry address, timeout and security options. type Config struct { - Address string `yaml:"address"` - Timeout time.Duration `yaml:"timeout"` - Security security.Config `yaml:"security"` + Address string `yaml:"address"` + Timeout time.Duration `yaml:"timeout"` + // ConnectTimeout limits the time spent establishing the TCP connection (if a new one is needed). + ConnectTimeout time.Duration `yaml:"connect_timeout"` + // ResponseHeaderTimeout limits the time spent reading the headers of the response. + ResponseHeaderTimeout time.Duration `yaml:"response_header_timeout"` + Security security.Config `yaml:"security"` } // Set default configuration @@ -33,3 +39,21 @@ func (c Config) applyDefaults() Config { } return c } + +func (c Config) Authenticator() (security.Authenticator, error) { + transport := http.DefaultTransport.(*http.Transport).Clone() + + if c.ConnectTimeout != 0 { + dialer := &net.Dialer{ + Timeout: c.ConnectTimeout, + KeepAlive: 30 * time.Second, + } + transport.DialContext = dialer.DialContext + } + + if c.ResponseHeaderTimeout != 0 { + transport.ResponseHeaderTimeout = c.ResponseHeaderTimeout + } + + return security.NewAuthenticator(c.Address, c.Security, transport) +} diff --git a/lib/backend/registrybackend/security/security.go b/lib/backend/registrybackend/security/security.go index bbdd79d98..118f0f930 100644 --- a/lib/backend/registrybackend/security/security.go +++ b/lib/backend/registrybackend/security/security.go @@ -72,17 +72,16 @@ type authenticator struct { // address, TLS, and credentials configuration. It supports both basic auth and // token based authentication challenges. If TLS is disabled, no authentication // is attempted. -func NewAuthenticator(address string, config Config) (Authenticator, error) { - rt := http.DefaultTransport.(*http.Transport).Clone() +func NewAuthenticator(address string, config Config, transport *http.Transport) (Authenticator, error) { tlsClientConfig, err := config.TLS.BuildClient() if err != nil { return nil, fmt.Errorf("build tls config for %q: %s", address, err) } - rt.TLSClientConfig = tlsClientConfig + transport.TLSClientConfig = tlsClientConfig return &authenticator{ address: address, config: config, - roundTripper: rt, + roundTripper: transport, credentialStore: newCredentialStore(address, config), challengeManager: challenge.NewSimpleManager(), }, nil diff --git a/lib/backend/registrybackend/tagclient.go b/lib/backend/registrybackend/tagclient.go index de899db73..58b78b629 100644 --- a/lib/backend/registrybackend/tagclient.go +++ b/lib/backend/registrybackend/tagclient.go @@ -64,7 +64,7 @@ type TagClient struct { // NewTagClient creates a new TagClient. func NewTagClient(config Config) (*TagClient, error) { config = config.applyDefaults() - authenticator, err := security.NewAuthenticator(config.Address, config.Security) + authenticator, err := config.Authenticator() if err != nil { return nil, fmt.Errorf("cannot create tag client authenticator: %s", err) } diff --git a/lib/backend/registrybackend/tagclient_test.go b/lib/backend/registrybackend/tagclient_test.go index 4140d133b..88e6aee6c 100644 --- a/lib/backend/registrybackend/tagclient_test.go +++ b/lib/backend/registrybackend/tagclient_test.go @@ -20,8 +20,10 @@ import ( "net/http" "strings" "testing" + "time" "github.com/pressly/chi" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber/kraken/core" "github.com/uber/kraken/lib/backend/backenderrors" @@ -97,3 +99,72 @@ func TestTagDownloadFileNotFound(t *testing.T) { var b bytes.Buffer require.Equal(backenderrors.ErrBlobNotFound, client.Download(tag, tag, &b)) } + +func TestTagDownloadHeaderTimeout(t *testing.T) { + require := require.New(t) + + imageConfig := core.NewBlobFixture() + layer1 := core.NewBlobFixture() + layer2 := core.NewBlobFixture() + digest, manifest := dockerutil.ManifestFixture( + imageConfig.Digest, layer1.Digest, layer2.Digest) + + tag := core.TagFixture() + namespace := strings.Split(tag, ":")[0] + + r := chi.NewRouter() + r.Get(fmt.Sprintf("/v2/%s/manifests/{tag}", namespace), func(w http.ResponseWriter, req *http.Request) { + time.Sleep(time.Second) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(manifest))) + w.Header().Set("Docker-Content-Digest", digest.String()) + _, err := io.Copy(w, bytes.NewReader(manifest)) + require.NoError(err) + }) + r.Head(fmt.Sprintf("/v2/%s/manifests/{tag}", namespace), func(w http.ResponseWriter, req *http.Request) { + time.Sleep(time.Second) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(manifest))) + w.Header().Set("Docker-Content-Digest", digest.String()) + _, err := io.Copy(w, bytes.NewReader(manifest)) + require.NoError(err) + }) + addr, stop := testutil.StartServer(r) + defer stop() + + config := newTestConfig(addr) + config.ResponseHeaderTimeout = 100 * time.Millisecond + client, err := NewTagClient(config) + require.NoError(err) + + _, err = client.Stat(tag, tag) + if assert.NotNil(t, err) { + assert.Contains(t, err.Error(), "timeout awaiting response headers") + } + + var b bytes.Buffer + err = client.Download(tag, tag, &b) + if assert.NotNil(t, err) { + assert.Contains(t, err.Error(), "timeout awaiting response headers") + } +} + +// FIXME: debugging failing tests on Travis +//func TestTagDownloadConnectTimeout(t *testing.T) { +// require := require.New(t) +// +// // unroutable address, courtesy of https://stackoverflow.com/a/904609/4867444 +// config := newTestConfig("10.255.255.1") +// config.ConnectTimeout = 100 * time.Millisecond +// client, err := NewTagClient(config) +// require.NoError(err) +// +// _, err = client.Stat("dummynamespace", "image:tag") +// if assert.NotNil(t, err) { +// assert.Contains(t, err.Error(), "i/o timeout") +// } +// +// var b bytes.Buffer +// err = client.Download("dummynamespace", "image:tag", &b) +// if assert.NotNil(t, err) { +// assert.Contains(t, err.Error(), "i/o timeout") +// } +//}