Skip to content

Commit

Permalink
ensure browser replay returns request body and task jobs finish befor…
Browse files Browse the repository at this point in the history
…e parent completed
  • Loading branch information
pyneda committed Jan 2, 2025
1 parent 1ef925a commit 2e1cd94
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 5 deletions.
9 changes: 9 additions & 0 deletions db/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,12 @@ func (d *DatabaseConnection) GetOrCreateDefaultWorkspaceTask(workspaceID uint) (
}
return task, result.Error
}

func (d *DatabaseConnection) TaskHasPendingJobs(taskID uint) (bool, error) {
var count int64
err := d.db.Model(&TaskJob{}).
Where("task_id = ? AND status IN ?", taskID, []TaskJobStatus{TaskJobScheduled, TaskJobRunning}).
Count(&count).Error

return count > 0, err
}
7 changes: 6 additions & 1 deletion pkg/browser/hijack.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,17 @@ func DumpHijackRequest(req *rod.HijackRequest) (raw string, body string) {
} else {
reader := req.Req().Body
if reader != nil {
bodyBytes, _ := io.ReadAll(reader)
bodyBytes, err := io.ReadAll(reader)
if err != nil {
log.Error().Err(err).Msg("Error reading request body in DumpHijackRequest")
}
body = string(bodyBytes)
if len(bodyBytes) > 0 {
dump.WriteString("\n")
dump.WriteString(body)
}
} else {
log.Warn().Msg("DumpHijackRequest request body is empty")
}
}
raw = dump.String()
Expand Down
26 changes: 22 additions & 4 deletions pkg/browser/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ func ReplayRequestInBrowserAndCreateHistory(opts ReplayAndCreateHistoryOptions)
opts.Note = "Create history from replay in browser"
}

var reqBody []byte

router.MustAdd("*", func(ctx *rod.Hijack) {
// https://github.com/go-rod/rod/blob/4c4ccbecdd8110a434de73de08bdbb72e8c47cb0/examples_test.go#L473-L477
if requestHandled {
Expand Down Expand Up @@ -122,6 +124,7 @@ func ReplayRequestInBrowserAndCreateHistory(opts ReplayAndCreateHistoryOptions)
opts.Request.Body = io.NopCloser(newBodyReader)
ctx.Request.Req().Body = io.NopCloser(bytes.NewReader(bodyBytes))
ctx.Request.SetBody(bodyBytes)
reqBody = bodyBytes

// Set the Content-Length header to the length of the new body
contentLength := len(bodyBytes)
Expand All @@ -134,10 +137,25 @@ func ReplayRequestInBrowserAndCreateHistory(opts ReplayAndCreateHistoryOptions)
}
history = CreateHistoryFromHijack(ctx.Request, ctx.Response, opts.Source, opts.Note, opts.WorkspaceID, opts.TaskID, opts.PlaygroundSessionID)
// NOTE: This shouldn't be necessary, but it seems that the body is not being set on the history object when replaying the request
// if history.RequestBody == nil && len(reqBody) > 0 {
// history.RequestBody = reqBody
// history, _ = db.Connection.UpdateHistory(history)
// }
if len(history.RequestBody) == 0 && len(reqBody) > 0 {
history.RequestBody = reqBody
raw := string(history.RawRequest)
parts := strings.Split(raw, "\n\n")
if len(parts) == 1 {
// No body section yet, add it
history.RawRequest = []byte(raw + "\n\n" + string(reqBody))
} else {
// Replace existing body section
history.RawRequest = []byte(parts[0] + "\n\n" + string(reqBody))
}

history, err = db.Connection.UpdateHistory(history)
if err != nil {
log.Error().Err(err).Msg("Failed to update history with request body")
} else {
log.Debug().Uint("history", history.ID).Msg("Updated history with fixed request body")
}
}

})

Expand Down
18 changes: 18 additions & 0 deletions pkg/scan/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,33 @@ func (s *ScanEngine) FullScan(options scan_options.FullScanOptions, waitCompleti
if waitCompletion {
time.Sleep(2 * time.Second)
s.wg.Wait()
waitForTaskCompletion(task.ID)
scanLog.Info().Msg("Active scans finished")
db.Connection.SetTaskStatus(task.ID, db.TaskStatusFinished)
} else {
go func() {
s.wg.Wait()
waitForTaskCompletion(task.ID)
scanLog.Info().Msg("Active scans finished")
db.Connection.SetTaskStatus(task.ID, db.TaskStatusFinished)
}()
}

return task, nil
}

func waitForTaskCompletion(taskID uint) {
scanLog := log.With().Uint("task", taskID).Logger()
for {
hasPending, err := db.Connection.TaskHasPendingJobs(taskID)
if err != nil {
scanLog.Error().Err(err).Msg("Error checking pending task jobs")
return
}
if !hasPending {
break
}
time.Sleep(2 * time.Second)
}
db.Connection.SetTaskStatus(taskID, db.TaskStatusFinished)
}

0 comments on commit 2e1cd94

Please sign in to comment.