diff --git a/cmd/grype/cli/options/database.go b/cmd/grype/cli/options/database.go index 83fd1770b3c..69882885a81 100644 --- a/cmd/grype/cli/options/database.go +++ b/cmd/grype/cli/options/database.go @@ -12,15 +12,23 @@ import ( ) type Database struct { - Dir string `yaml:"cache-dir" json:"cache-dir" mapstructure:"cache-dir"` - UpdateURL string `yaml:"update-url" json:"update-url" mapstructure:"update-url"` - CACert string `yaml:"ca-cert" json:"ca-cert" mapstructure:"ca-cert"` - AutoUpdate bool `yaml:"auto-update" json:"auto-update" mapstructure:"auto-update"` - ValidateByHashOnStart bool `yaml:"validate-by-hash-on-start" json:"validate-by-hash-on-start" mapstructure:"validate-by-hash-on-start"` - ValidateAge bool `yaml:"validate-age" json:"validate-age" mapstructure:"validate-age"` - MaxAllowedBuiltAge time.Duration `yaml:"max-allowed-built-age" json:"max-allowed-built-age" mapstructure:"max-allowed-built-age"` + Dir string `yaml:"cache-dir" json:"cache-dir" mapstructure:"cache-dir"` + UpdateURL string `yaml:"update-url" json:"update-url" mapstructure:"update-url"` + CACert string `yaml:"ca-cert" json:"ca-cert" mapstructure:"ca-cert"` + AutoUpdate bool `yaml:"auto-update" json:"auto-update" mapstructure:"auto-update"` + ValidateByHashOnStart bool `yaml:"validate-by-hash-on-start" json:"validate-by-hash-on-start" mapstructure:"validate-by-hash-on-start"` + ValidateAge bool `yaml:"validate-age" json:"validate-age" mapstructure:"validate-age"` + MaxAllowedBuiltAge time.Duration `yaml:"max-allowed-built-age" json:"max-allowed-built-age" mapstructure:"max-allowed-built-age"` + UpdateAvailableTimeout time.Duration `yaml:"update-available-timeout" json:"update-available-timeout" mapstructure:"update-available-timeout"` + UpdateDownloadTimeout time.Duration `yaml:"update-download-timeout" json:"update-download-timeout" mapstructure:"update-download-timeout"` } +const ( + defaultMaxDBAge time.Duration = time.Hour * 24 * 5 + defaultUpdateAvailableTimeout = time.Second * 30 + defaultUpdateDownloadTimeout = time.Second * 120 +) + func DefaultDatabase(id clio.Identification) Database { return Database{ Dir: path.Join(xdg.CacheHome, id.Name, "db"), @@ -28,7 +36,9 @@ func DefaultDatabase(id clio.Identification) Database { AutoUpdate: true, ValidateAge: true, // After this period (5 days) the db data is considered stale - MaxAllowedBuiltAge: time.Hour * 24 * 5, + MaxAllowedBuiltAge: defaultMaxDBAge, + UpdateAvailableTimeout: defaultUpdateAvailableTimeout, + UpdateDownloadTimeout: defaultUpdateDownloadTimeout, } } @@ -40,5 +50,7 @@ func (cfg Database) ToCuratorConfig() db.Config { ValidateByHashOnGet: cfg.ValidateByHashOnStart, ValidateAge: cfg.ValidateAge, MaxAllowedBuiltAge: cfg.MaxAllowedBuiltAge, + ListingFileTimeout: cfg.UpdateAvailableTimeout, + UpdateTimeout: cfg.UpdateDownloadTimeout, } } diff --git a/grype/db/curator.go b/grype/db/curator.go index da0b21da1dd..34d2cd92c40 100644 --- a/grype/db/curator.go +++ b/grype/db/curator.go @@ -37,11 +37,14 @@ type Config struct { ValidateByHashOnGet bool ValidateAge bool MaxAllowedBuiltAge time.Duration + ListingFileTimeout time.Duration + UpdateTimeout time.Duration } type Curator struct { fs afero.Fs - downloader file.Getter + listingDownloader file.Getter + updateDownloader file.Getter targetSchema int dbDir string dbPath string @@ -55,15 +58,23 @@ func NewCurator(cfg Config) (Curator, error) { dbDir := path.Join(cfg.DBRootDir, strconv.Itoa(vulnerability.SchemaVersion)) fs := afero.NewOsFs() - httpClient, err := defaultHTTPClient(fs, cfg.CACert) + listingClient, err := defaultHTTPClient(fs, cfg.CACert) if err != nil { return Curator{}, err } + listingClient.Timeout = cfg.ListingFileTimeout + + dbClient, err := defaultHTTPClient(fs, cfg.CACert) + if err != nil { + return Curator{}, err + } + dbClient.Timeout = cfg.UpdateTimeout return Curator{ fs: fs, targetSchema: vulnerability.SchemaVersion, - downloader: file.NewGetter(httpClient), + listingDownloader: file.NewGetter(listingClient), + updateDownloader: file.NewGetter(dbClient), dbDir: dbDir, dbPath: path.Join(dbDir, FileName), listingURL: cfg.ListingURL, @@ -283,7 +294,7 @@ func (c *Curator) download(listing *ListingEntry, downloadProgress *progress.Man url.RawQuery = query.Encode() // go-getter will automatically extract all files within the archive to the temp dir - err = c.downloader.GetToDir(tempDir, listing.URL.String(), downloadProgress) + err = c.updateDownloader.GetToDir(tempDir, listing.URL.String(), downloadProgress) if err != nil { return "", fmt.Errorf("unable to download db: %w", err) } @@ -375,7 +386,7 @@ func (c Curator) ListingFromURL() (Listing, error) { }() // download the listing file - err = c.downloader.GetFile(tempFile.Name(), c.listingURL) + err = c.listingDownloader.GetFile(tempFile.Name(), c.listingURL) if err != nil { return Listing{}, fmt.Errorf("unable to download listing: %w", err) } @@ -390,6 +401,7 @@ func (c Curator) ListingFromURL() (Listing, error) { func defaultHTTPClient(fs afero.Fs, caCertPath string) (*http.Client, error) { httpClient := cleanhttp.DefaultClient() + httpClient.Timeout = 30 * time.Second if caCertPath != "" { rootCAs := x509.NewCertPool() diff --git a/grype/db/curator_test.go b/grype/db/curator_test.go index bb3ee22ac09..c652b1ef78a 100644 --- a/grype/db/curator_test.go +++ b/grype/db/curator_test.go @@ -2,15 +2,18 @@ package db import ( "bufio" + "errors" "fmt" "io" "net/http" + "net/http/httptest" "net/url" "os" "os/exec" "path" "path/filepath" "strconv" + "strings" "syscall" "testing" "time" @@ -68,13 +71,14 @@ func newTestCurator(tb testing.TB, fs afero.Fs, getter file.Getter, dbDir, metad require.NoError(tb, err) - c.downloader = getter + c.listingDownloader = getter + c.updateDownloader = getter c.fs = fs return c } -func Test_defaultHTTPClient(t *testing.T) { +func Test_defaultHTTPClientHasCert(t *testing.T) { tests := []struct { name string hasCert bool @@ -105,11 +109,16 @@ func Test_defaultHTTPClient(t *testing.T) { } else { assert.Nil(t, httpClient.Transport.(*http.Transport).TLSClientConfig) } - }) } } +func Test_defaultHTTPClientTimeout(t *testing.T) { + c, err := defaultHTTPClient(afero.NewMemMapFs(), "") + require.NoError(t, err) + assert.Equal(t, 30*time.Second, c.Timeout) +} + func generateCertFixture(t *testing.T) string { path := "test-fixtures/tls/server.crt" if _, err := os.Stat(path); !os.IsNotExist(err) { @@ -364,3 +373,77 @@ func TestCurator_validateStaleness(t *testing.T) { }) } } + +func TestCuratorTimeoutBehavior(t *testing.T) { + failAfter := 10 * time.Second + success := make(chan struct{}) + errs := make(chan error) + timeout := time.After(failAfter) + + hangForeverHandler := func(w http.ResponseWriter, r *http.Request) { + select {} // hang forever + } + + ts := httptest.NewServer(http.HandlerFunc(hangForeverHandler)) + + cfg := Config{ + DBRootDir: "", + ListingURL: fmt.Sprintf("%s/listing.json", ts.URL), + CACert: "", + ValidateByHashOnGet: false, + ValidateAge: false, + MaxAllowedBuiltAge: 0, + ListingFileTimeout: 400 * time.Millisecond, + UpdateTimeout: 400 * time.Millisecond, + } + + curator, err := NewCurator(cfg) + require.NoError(t, err) + + u, err := url.Parse(fmt.Sprintf("%s/some-db.tar.gz", ts.URL)) + require.NoError(t, err) + + entry := ListingEntry{ + Built: time.Now(), + Version: 5, + URL: u, + Checksum: "83b52a2aa6aff35d208520f40dd36144", + } + + downloadProgress := progress.NewManual(10) + importProgress := progress.NewManual(10) + stage := progress.NewAtomicStage("some-stage") + + runTheTest := func(success chan struct{}, errs chan error) { + _, _, _, err = curator.IsUpdateAvailable() + if err == nil { + errs <- errors.New("expected timeout error but got nil") + return + } + if !strings.Contains(err.Error(), "Timeout exceeded") { + errs <- fmt.Errorf("expected %q but got %q", "Timeout exceeded", err.Error()) + return + } + + err = curator.UpdateTo(&entry, downloadProgress, importProgress, stage) + if err == nil { + errs <- errors.New("expected timeout error but got nil") + return + } + if !strings.Contains(err.Error(), "Timeout exceeded") { + errs <- fmt.Errorf("expected %q but got %q", "Timeout exceeded", err.Error()) + return + } + success <- struct{}{} + } + go runTheTest(success, errs) + + select { + case <-success: + return + case err := <-errs: + t.Error(err) + case <-timeout: + t.Fatalf("timeout exceeded (%v)", failAfter) + } +}