From 2e1cd94a9a222c021c48955fcef4929b5c138b22 Mon Sep 17 00:00:00 2001 From: pyneda <11313340+pyneda@users.noreply.github.com> Date: Thu, 2 Jan 2025 02:32:46 +0100 Subject: [PATCH] ensure browser replay returns request body and task jobs finish before parent completed --- db/tasks.go | 9 +++++++++ pkg/browser/hijack.go | 7 ++++++- pkg/browser/utils.go | 26 ++++++++++++++++++++++---- pkg/scan/engine/engine.go | 18 ++++++++++++++++++ 4 files changed, 55 insertions(+), 5 deletions(-) diff --git a/db/tasks.go b/db/tasks.go index 32fb1cd..c428269 100644 --- a/db/tasks.go +++ b/db/tasks.go @@ -336,3 +336,12 @@ func (d *DatabaseConnection) GetOrCreateDefaultWorkspaceTask(workspaceID uint) ( } return task, result.Error } + +func (d *DatabaseConnection) TaskHasPendingJobs(taskID uint) (bool, error) { + var count int64 + err := d.db.Model(&TaskJob{}). + Where("task_id = ? AND status IN ?", taskID, []TaskJobStatus{TaskJobScheduled, TaskJobRunning}). + Count(&count).Error + + return count > 0, err +} diff --git a/pkg/browser/hijack.go b/pkg/browser/hijack.go index 6b80be7..d651b47 100644 --- a/pkg/browser/hijack.go +++ b/pkg/browser/hijack.go @@ -188,12 +188,17 @@ func DumpHijackRequest(req *rod.HijackRequest) (raw string, body string) { } else { reader := req.Req().Body if reader != nil { - bodyBytes, _ := io.ReadAll(reader) + bodyBytes, err := io.ReadAll(reader) + if err != nil { + log.Error().Err(err).Msg("Error reading request body in DumpHijackRequest") + } body = string(bodyBytes) if len(bodyBytes) > 0 { dump.WriteString("\n") dump.WriteString(body) } + } else { + log.Warn().Msg("DumpHijackRequest request body is empty") } } raw = dump.String() diff --git a/pkg/browser/utils.go b/pkg/browser/utils.go index 1771c81..0a58c79 100644 --- a/pkg/browser/utils.go +++ b/pkg/browser/utils.go @@ -92,6 +92,8 @@ func ReplayRequestInBrowserAndCreateHistory(opts ReplayAndCreateHistoryOptions) opts.Note = "Create history from replay in browser" } + var reqBody []byte + router.MustAdd("*", func(ctx *rod.Hijack) { // https://github.com/go-rod/rod/blob/4c4ccbecdd8110a434de73de08bdbb72e8c47cb0/examples_test.go#L473-L477 if requestHandled { @@ -122,6 +124,7 @@ func ReplayRequestInBrowserAndCreateHistory(opts ReplayAndCreateHistoryOptions) opts.Request.Body = io.NopCloser(newBodyReader) ctx.Request.Req().Body = io.NopCloser(bytes.NewReader(bodyBytes)) ctx.Request.SetBody(bodyBytes) + reqBody = bodyBytes // Set the Content-Length header to the length of the new body contentLength := len(bodyBytes) @@ -134,10 +137,25 @@ func ReplayRequestInBrowserAndCreateHistory(opts ReplayAndCreateHistoryOptions) } history = CreateHistoryFromHijack(ctx.Request, ctx.Response, opts.Source, opts.Note, opts.WorkspaceID, opts.TaskID, opts.PlaygroundSessionID) // NOTE: This shouldn't be necessary, but it seems that the body is not being set on the history object when replaying the request - // if history.RequestBody == nil && len(reqBody) > 0 { - // history.RequestBody = reqBody - // history, _ = db.Connection.UpdateHistory(history) - // } + if len(history.RequestBody) == 0 && len(reqBody) > 0 { + history.RequestBody = reqBody + raw := string(history.RawRequest) + parts := strings.Split(raw, "\n\n") + if len(parts) == 1 { + // No body section yet, add it + history.RawRequest = []byte(raw + "\n\n" + string(reqBody)) + } else { + // Replace existing body section + history.RawRequest = []byte(parts[0] + "\n\n" + string(reqBody)) + } + + history, err = db.Connection.UpdateHistory(history) + if err != nil { + log.Error().Err(err).Msg("Failed to update history with request body") + } else { + log.Debug().Uint("history", history.ID).Msg("Updated history with fixed request body") + } + } }) diff --git a/pkg/scan/engine/engine.go b/pkg/scan/engine/engine.go index af1a2c6..a4611b4 100644 --- a/pkg/scan/engine/engine.go +++ b/pkg/scan/engine/engine.go @@ -280,11 +280,13 @@ func (s *ScanEngine) FullScan(options scan_options.FullScanOptions, waitCompleti if waitCompletion { time.Sleep(2 * time.Second) s.wg.Wait() + waitForTaskCompletion(task.ID) scanLog.Info().Msg("Active scans finished") db.Connection.SetTaskStatus(task.ID, db.TaskStatusFinished) } else { go func() { s.wg.Wait() + waitForTaskCompletion(task.ID) scanLog.Info().Msg("Active scans finished") db.Connection.SetTaskStatus(task.ID, db.TaskStatusFinished) }() @@ -292,3 +294,19 @@ func (s *ScanEngine) FullScan(options scan_options.FullScanOptions, waitCompleti return task, nil } + +func waitForTaskCompletion(taskID uint) { + scanLog := log.With().Uint("task", taskID).Logger() + for { + hasPending, err := db.Connection.TaskHasPendingJobs(taskID) + if err != nil { + scanLog.Error().Err(err).Msg("Error checking pending task jobs") + return + } + if !hasPending { + break + } + time.Sleep(2 * time.Second) + } + db.Connection.SetTaskStatus(taskID, db.TaskStatusFinished) +}