From d119df6c199ba432c2c7869258f27506383dbf7d Mon Sep 17 00:00:00 2001 From: Rumen Vasilev Date: Mon, 6 Nov 2023 03:12:58 +0100 Subject: [PATCH] add context to AnalyzeRepositories() Signed-off-by: Rumen Vasilev --- internal/core/analysis.go | 55 +++++++++++-------- .../{generic_worker.go => retrieve_worker.go} | 1 + internal/pkg/scan/github/github.go | 2 +- internal/pkg/scan/gitlab/gitlab.go | 2 +- internal/pkg/scan/localgit/localgit.go | 4 +- 5 files changed, 37 insertions(+), 27 deletions(-) rename internal/core/{generic_worker.go => retrieve_worker.go} (97%) diff --git a/internal/core/analysis.go b/internal/core/analysis.go index 8fa6e16..cfcc740 100644 --- a/internal/core/analysis.go +++ b/internal/core/analysis.go @@ -2,6 +2,7 @@ package core import ( + "context" "os" "strconv" "strings" @@ -26,7 +27,7 @@ import ( // are controlled by flags. If a directory, file, or the content pass through all of the filters then // it is scanned once per each signature which may lead to a specific secret matching multiple rules // and then generating multiple findings. -func AnalyzeRepositories(sess *session.Session, st *stats.Stats) { +func AnalyzeRepositories(ctx context.Context, sess *session.Session, st *stats.Stats) { log := log.Log st.UpdateStatus(stats.StatusAnalyzing) repoCnt := len(sess.State.Repositories) @@ -54,7 +55,7 @@ func AnalyzeRepositories(sess *session.Session, st *stats.Stats) { // Start analyzer workers for i := 0; i < threadNum; i++ { - go analyzeWorker(i, ch, &wg, sess, st) + go analyzeWorker(ctx, i, &wg, ch, sess, st) } // Feed repos to the analyzer workers @@ -68,36 +69,42 @@ func AnalyzeRepositories(sess *session.Session, st *stats.Stats) { wg.Wait() } -func analyzeWorker(tid int, ch chan coreapi.Repository, wg *sync.WaitGroup, sess *session.Session, st *stats.Stats) { +func analyzeWorker(ctx context.Context, workerID int, wg *sync.WaitGroup, ch chan coreapi.Repository, sess *session.Session, st *stats.Stats) { log := log.Log for { - log.Debug("[THREAD #%d] Requesting new repository to analyze...", tid) - repo, ok := <-ch - if !ok { - log.Debug("[THREAD #%d] No more tasks, marking WaitGroup done", tid) + select { + case <-ctx.Done(): + log.Info("Job cancellation requested.") wg.Done() return - } + case repo, ok := <-ch: + log.Debug("[THREAD #%d] Requesting new repository to analyze...", workerID) + if !ok { + log.Info("[THREAD #%d] No more repositories to analyze", workerID) + wg.Done() + return + } - // Clone the repository from the remote source or if a local repo from the path - // The path variable is returning the path that the clone was done to. The repo is cloned directly - // there. - log.Debug("[THREAD #%d][%s] Cloning repository...", tid, repo.CloneURL) - clone, path, err := cloneRepository(sess.Config, st.IncrementRepositoriesCloned, repo) - if err != nil { - log.Error("%v", err) - cleanUpPath(path) - continue - } - log.Debug("[THREAD #%d][%s] Cloned repository to: %s", tid, repo.CloneURL, path) + // Clone the repository from the remote source or if a local repo from the path + // The path variable is returning the path that the clone was done to. The repo is cloned directly + // there. + log.Debug("[THREAD #%d][%s] Cloning repository...", workerID, repo.CloneURL) + clone, path, err := cloneRepository(sess.Config, st.IncrementRepositoriesCloned, repo) + if err != nil { + log.Error("%v", err) + cleanUpPath(path) + continue + } + log.Debug("[THREAD #%d][%s] Cloned repository to: %s", workerID, repo.CloneURL, path) - analyzeHistory(sess, clone, tid, path, repo) + analyzeHistory(sess, clone, workerID, path, repo) - log.Debug("[THREAD #%d][%s] Done analyzing commits", tid, repo.CloneURL) - log.Debug("[THREAD #%d][%s] Deleted %s", tid, repo.CloneURL, path) + log.Debug("[THREAD #%d][%s] Done analyzing commits", workerID, repo.CloneURL) + log.Debug("[THREAD #%d][%s] Deleted %s", workerID, repo.CloneURL, path) - cleanUpPath(path) - st.IncrementRepositoriesScanned() + cleanUpPath(path) + st.IncrementRepositoriesScanned() + } } } diff --git a/internal/core/generic_worker.go b/internal/core/retrieve_worker.go similarity index 97% rename from internal/core/generic_worker.go rename to internal/core/retrieve_worker.go index 6b6a968..175c6f3 100644 --- a/internal/core/generic_worker.go +++ b/internal/core/retrieve_worker.go @@ -46,6 +46,7 @@ func retrieveReposWorker(ctx context.Context, workerID int, wg *sync.WaitGroup, return case target, ok := <-ch: if !ok { + log.Debug("[THREAD #%d]: No more targets to retrieve", workerID) wg.Done() return } diff --git a/internal/pkg/scan/github/github.go b/internal/pkg/scan/github/github.go index 860a94e..8446177 100644 --- a/internal/pkg/scan/github/github.go +++ b/internal/pkg/scan/github/github.go @@ -86,7 +86,7 @@ func (g Github) Run() error { } core.GatherRepositories(ctx, sess) - core.AnalyzeRepositories(sess, sess.State.Stats) + core.AnalyzeRepositories(ctx, sess, sess.State.Stats) sess.Finish() err = output.Summary(sess.State, sess.Config.Global, sess.SignatureVersion) diff --git a/internal/pkg/scan/gitlab/gitlab.go b/internal/pkg/scan/gitlab/gitlab.go index 34e8664..563e11a 100644 --- a/internal/pkg/scan/gitlab/gitlab.go +++ b/internal/pkg/scan/gitlab/gitlab.go @@ -46,7 +46,7 @@ func (g Gitlab) Run() error { core.GatherTargets(sess) core.GatherRepositories(ctx, sess) - core.AnalyzeRepositories(sess, sess.State.Stats) + core.AnalyzeRepositories(ctx, sess, sess.State.Stats) sess.Finish() err = output.Summary(sess.State, sess.Config.Global, sess.SignatureVersion) diff --git a/internal/pkg/scan/localgit/localgit.go b/internal/pkg/scan/localgit/localgit.go index 309f3cc..c18829b 100644 --- a/internal/pkg/scan/localgit/localgit.go +++ b/internal/pkg/scan/localgit/localgit.go @@ -1,6 +1,7 @@ package localgit import ( + "context" "time" "github.com/rumenvasilev/rvsecret/internal/config" @@ -20,6 +21,7 @@ type LocalGit struct { func (l LocalGit) Run() error { cfg := l.Cfg log := log.Log + ctx := context.Background() // create session sess, err := session.NewWithConfig(cfg) if err != nil { @@ -46,7 +48,7 @@ func (l LocalGit) Run() error { if err != nil { return err } - core.AnalyzeRepositories(sess, sess.State.Stats) + core.AnalyzeRepositories(ctx, sess, sess.State.Stats) sess.Finish() err = output.Summary(sess.State, sess.Config.Global, sess.SignatureVersion)