diff --git a/grype/db/curator_test.go b/grype/db/curator_test.go index 5efa9e0c07e..c652b1ef78a 100644 --- a/grype/db/curator_test.go +++ b/grype/db/curator_test.go @@ -2,15 +2,18 @@ package db import ( "bufio" + "errors" "fmt" "io" "net/http" + "net/http/httptest" "net/url" "os" "os/exec" "path" "path/filepath" "strconv" + "strings" "syscall" "testing" "time" @@ -370,3 +373,77 @@ func TestCurator_validateStaleness(t *testing.T) { }) } } + +func TestCuratorTimeoutBehavior(t *testing.T) { + failAfter := 10 * time.Second + success := make(chan struct{}) + errs := make(chan error) + timeout := time.After(failAfter) + + hangForeverHandler := func(w http.ResponseWriter, r *http.Request) { + select {} // hang forever + } + + ts := httptest.NewServer(http.HandlerFunc(hangForeverHandler)) + + cfg := Config{ + DBRootDir: "", + ListingURL: fmt.Sprintf("%s/listing.json", ts.URL), + CACert: "", + ValidateByHashOnGet: false, + ValidateAge: false, + MaxAllowedBuiltAge: 0, + ListingFileTimeout: 400 * time.Millisecond, + UpdateTimeout: 400 * time.Millisecond, + } + + curator, err := NewCurator(cfg) + require.NoError(t, err) + + u, err := url.Parse(fmt.Sprintf("%s/some-db.tar.gz", ts.URL)) + require.NoError(t, err) + + entry := ListingEntry{ + Built: time.Now(), + Version: 5, + URL: u, + Checksum: "83b52a2aa6aff35d208520f40dd36144", + } + + downloadProgress := progress.NewManual(10) + importProgress := progress.NewManual(10) + stage := progress.NewAtomicStage("some-stage") + + runTheTest := func(success chan struct{}, errs chan error) { + _, _, _, err = curator.IsUpdateAvailable() + if err == nil { + errs <- errors.New("expected timeout error but got nil") + return + } + if !strings.Contains(err.Error(), "Timeout exceeded") { + errs <- fmt.Errorf("expected %q but got %q", "Timeout exceeded", err.Error()) + return + } + + err = curator.UpdateTo(&entry, downloadProgress, importProgress, stage) + if err == nil { + errs <- errors.New("expected timeout error but got nil") + return + } + if !strings.Contains(err.Error(), "Timeout exceeded") { + errs <- fmt.Errorf("expected %q but got %q", "Timeout exceeded", err.Error()) + return + } + success <- struct{}{} + } + go runTheTest(success, errs) + + select { + case <-success: + return + case err := <-errs: + t.Error(err) + case <-timeout: + t.Fatalf("timeout exceeded (%v)", failAfter) + } +}