diff --git a/reaper/client.go b/reaper/client.go index 73234be..fb597c8 100644 --- a/reaper/client.go +++ b/reaper/client.go @@ -28,7 +28,7 @@ type Client interface { DeleteCluster(ctx context.Context, cluster string) error // GetRepairRuns returns a list of repair runs, optionally filtering according to the provided search options. - GetRepairRuns(ctx context.Context, searchOptions *RepairRunSearchOptions) ([]*RepairRun, error) + GetRepairRuns(ctx context.Context, searchOptions *RepairRunSearchOptions) (map[uuid.UUID]*RepairRun, error) // GetRepairRun returns a repair run object identified by its id. GetRepairRun(ctx context.Context, repairRunId uuid.UUID) (*RepairRun, error) @@ -58,7 +58,7 @@ type Client interface { ResumeRepairRun(ctx context.Context, repairRunId uuid.UUID) error // GetRepairRunSegments returns the list of segments of a repair run identified by its id. - GetRepairRunSegments(ctx context.Context, repairRunId uuid.UUID) ([]*RepairSegment, error) + GetRepairRunSegments(ctx context.Context, repairRunId uuid.UUID) (map[uuid.UUID]*RepairSegment, error) // AbortRepairRunSegment aborts a running segment and puts it back in NOT_STARTED state. The segment will be // processed again later during the lifetime of the repair run. diff --git a/reaper/repair_run.go b/reaper/repair_run.go index 6cf5582..cbf7ae5 100644 --- a/reaper/repair_run.go +++ b/reaper/repair_run.go @@ -192,13 +192,17 @@ type RepairRunCreateOptions struct { RepairThreadCount int `url:"repairThreadCount,omitempty"` } -func (c *client) GetRepairRuns(ctx context.Context, searchOptions *RepairRunSearchOptions) ([]*RepairRun, error) { +func (c *client) GetRepairRuns(ctx context.Context, searchOptions *RepairRunSearchOptions) (map[uuid.UUID]*RepairRun, error) { res, err := c.doGet(ctx, "/repair_run", searchOptions, http.StatusOK) if err == nil { repairRuns := make([]*RepairRun, 0) err = c.readBodyAsJson(res, &repairRuns) if err == nil { - return repairRuns, nil + repairRunsMap := make(map[uuid.UUID]*RepairRun, len(repairRuns)) + for _, repairRun := range repairRuns { + repairRunsMap[repairRun.Id] = repairRun + } + return repairRunsMap, nil } } return nil, fmt.Errorf("failed to get repair runs: %w", err) @@ -281,14 +285,18 @@ func (c *client) ResumeRepairRun(ctx context.Context, repairRunId uuid.UUID) err return fmt.Errorf("failed to resume repair run %v: %w", repairRunId, err) } -func (c *client) GetRepairRunSegments(ctx context.Context, repairRunId uuid.UUID) ([]*RepairSegment, error) { +func (c *client) GetRepairRunSegments(ctx context.Context, repairRunId uuid.UUID) (map[uuid.UUID]*RepairSegment, error) { path := fmt.Sprint("/repair_run/", repairRunId, "/segments") res, err := c.doGet(ctx, path, nil, http.StatusOK) if err == nil { repairRunSegments := make([]*RepairSegment, 0) err = c.readBodyAsJson(res, &repairRunSegments) if err == nil { - return repairRunSegments, nil + repairRunSegmentsMap := make(map[uuid.UUID]*RepairSegment, len(repairRunSegments)) + for _, segment := range repairRunSegments { + repairRunSegmentsMap[segment.Id] = segment + } + return repairRunSegmentsMap, nil } } return nil, fmt.Errorf("failed to get segments of repair run %v: %w", repairRunId, err) diff --git a/reaper/repair_run_test.go b/reaper/repair_run_test.go index 8b882a1..090fd65 100644 --- a/reaper/repair_run_test.go +++ b/reaper/repair_run_test.go @@ -58,8 +58,8 @@ func testGetRepairRuns(t *testing.T, client Client) { repairRuns, err := client.GetRepairRuns(context.Background(), nil) require.Nil(t, err) assert.Len(t, repairRuns, 2) - assert.Contains(t, repairRuns, run1) - assert.Contains(t, repairRuns, run2) + assert.Contains(t, repairRuns, run1.Id) + assert.Contains(t, repairRuns, run2.Id) } func testGetRepairRunsFilteredByCluster(t *testing.T, client Client) { @@ -75,8 +75,8 @@ func testGetRepairRunsFilteredByCluster(t *testing.T, client Client) { ) require.Nil(t, err) assert.Len(t, repairRuns, 1) - assert.Contains(t, repairRuns, run1) - assert.NotContains(t, repairRuns, run2) + assert.Contains(t, repairRuns, run1.Id) + assert.NotContains(t, repairRuns, run2.Id) } func testGetRepairRunsFilteredByKeyspace(t *testing.T, client Client) { @@ -92,8 +92,8 @@ func testGetRepairRunsFilteredByKeyspace(t *testing.T, client Client) { ) require.Nil(t, err) assert.Len(t, repairRuns, 2) - assert.Contains(t, repairRuns, run1) - assert.Contains(t, repairRuns, run2) + assert.Contains(t, repairRuns, run1.Id) + assert.Contains(t, repairRuns, run2.Id) repairRuns, err = client.GetRepairRuns( context.Background(), &RepairRunSearchOptions{ @@ -117,8 +117,8 @@ func testGetRepairRunsFilteredByState(t *testing.T, client Client) { ) require.Nil(t, err) assert.Len(t, repairRuns, 2) - assert.Contains(t, repairRuns, run1) - assert.Contains(t, repairRuns, run2) + assert.Contains(t, repairRuns, run1.Id) + assert.Contains(t, repairRuns, run2.Id) repairRuns, err = client.GetRepairRuns( context.Background(), &RepairRunSearchOptions{ @@ -147,12 +147,11 @@ func testCreateStartFinishRepairRun(t *testing.T, client Client) { segments, err := client.GetRepairRunSegments(context.Background(), done.Id) require.Nil(t, err) assert.Len(t, segments, 2) - assert.NotNil(t, segments[0].StartTime) - assert.NotNil(t, segments[1].StartTime) - assert.NotNil(t, segments[0].EndTime) - assert.NotNil(t, segments[1].EndTime) - assert.Equal(t, RepairSegmentStateDone, segments[0].State) - assert.Equal(t, RepairSegmentStateDone, segments[1].State) + for _, segment := range segments { + assert.Equal(t, RepairSegmentStateDone, segment.State) + assert.NotNil(t, segment.StartTime) + assert.NotNil(t, segment.EndTime) + } } func testCreateStartPauseUpdateResumeRepairRun(t *testing.T, client Client) { @@ -192,26 +191,22 @@ func testGetRepairRunSegments(t *testing.T, client Client) { segments, err := client.GetRepairRunSegments(context.Background(), run.Id) require.Nil(t, err) assert.Len(t, segments, 2) - assert.Nil(t, segments[0].StartTime) - assert.Nil(t, segments[1].StartTime) - assert.Nil(t, segments[0].EndTime) - assert.Nil(t, segments[1].EndTime) - assert.Equal(t, RepairSegmentStateNotStarted, segments[0].State) - assert.Equal(t, RepairSegmentStateNotStarted, segments[1].State) - checkRepairRunSegment(t, run, segments[0]) - checkRepairRunSegment(t, run, segments[1]) + for _, segment := range segments { + assert.Equal(t, RepairSegmentStateNotStarted, segment.State) + assert.Nil(t, segment.StartTime) + assert.Nil(t, segment.EndTime) + checkRepairRunSegment(t, run, segment) + } err = client.StartRepairRun(context.Background(), run.Id) require.Nil(t, err) segments = waitForSegmentsStarted(t, client, run) - // could be STARTED, RUNNING or DONE - assert.NotEqual(t, RepairSegmentStateNotStarted, segments[0].State) - assert.NotEqual(t, RepairSegmentStateNotStarted, segments[1].State) - assert.NotNil(t, segments[0].StartTime) - assert.NotNil(t, segments[1].StartTime) - assert.True(t, segments[0].EndTime == nil || segments[0].State == RepairSegmentStateDone) - assert.True(t, segments[1].EndTime == nil || segments[1].State == RepairSegmentStateDone) - checkRepairRunSegment(t, run, segments[0]) - checkRepairRunSegment(t, run, segments[1]) + for _, segment := range segments { + // could be STARTED, RUNNING or DONE + assert.NotEqual(t, RepairSegmentStateNotStarted, segment.State) + assert.NotNil(t, segment.StartTime) + assert.True(t, segment.EndTime == nil || segment.State == RepairSegmentStateDone) + checkRepairRunSegment(t, run, segment) + } err = client.PauseRepairRun(context.Background(), run.Id) if err != nil { // pause not possible because repair is DONE @@ -220,13 +215,12 @@ func testGetRepairRunSegments(t *testing.T, client Client) { segments, err = client.GetRepairRunSegments(context.Background(), run.Id) require.Nil(t, err) assert.Len(t, segments, 2) - // some segments may be DONE or even RUNNING: cannot assert state here - assert.NotNil(t, segments[0].StartTime) - assert.NotNil(t, segments[1].StartTime) - assert.True(t, segments[0].EndTime == nil || segments[0].State == RepairSegmentStateDone) - assert.True(t, segments[1].EndTime == nil || segments[1].State == RepairSegmentStateDone) - checkRepairRunSegment(t, run, segments[0]) - checkRepairRunSegment(t, run, segments[1]) + for _, segment := range segments { + // some segments may be DONE or even RUNNING: cannot assert state here + assert.NotNil(t, segment.StartTime) + assert.True(t, segment.EndTime == nil || segment.State == RepairSegmentStateDone) + checkRepairRunSegment(t, run, segment) + } err = client.StartRepairRun(context.Background(), run.Id) require.Nil(t, err) done := waitForRepairRun(t, client, run, RepairRunStateDone) @@ -235,14 +229,12 @@ func testGetRepairRunSegments(t *testing.T, client Client) { segments, err = client.GetRepairRunSegments(context.Background(), run.Id) require.Nil(t, err) assert.Len(t, segments, 2) - assert.Equal(t, RepairSegmentStateDone, segments[0].State) - assert.Equal(t, RepairSegmentStateDone, segments[1].State) - assert.NotNil(t, segments[0].StartTime) - assert.NotNil(t, segments[1].StartTime) - assert.NotNil(t, segments[0].EndTime) - assert.NotNil(t, segments[1].EndTime) - checkRepairRunSegment(t, run, segments[0]) - checkRepairRunSegment(t, run, segments[1]) + for _, segment := range segments { + assert.Equal(t, RepairSegmentStateDone, segment.State) + assert.NotNil(t, segment.StartTime) + assert.NotNil(t, segment.EndTime) + checkRepairRunSegment(t, run, segment) + } } func testAbortRepairRunSegments(t *testing.T, client Client) { @@ -251,28 +243,22 @@ func testAbortRepairRunSegments(t *testing.T, client Client) { err := client.StartRepairRun(context.Background(), run.Id) require.Nil(t, err) segments := waitForSegmentsStarted(t, client, run) - // could be STARTED, RUNNING or DONE - assert.NotEqual(t, RepairSegmentStateNotStarted, segments[0].State) - assert.NotEqual(t, RepairSegmentStateNotStarted, segments[1].State) - expectedStates := map[uuid.UUID]RepairSegmentState{ - segments[0].Id: RepairSegmentStateNotStarted, - segments[1].Id: RepairSegmentStateNotStarted, - } - err = client.AbortRepairRunSegment(context.Background(), run.Id, segments[0].Id) - if err != nil { - require.Contains(t, err.Error(), "Cannot abort segment on repair run with status DONE") - expectedStates[segments[0].Id] = RepairSegmentStateDone - } - err = client.AbortRepairRunSegment(context.Background(), run.Id, segments[1].Id) - if err != nil { - require.Contains(t, err.Error(), "Cannot abort segment on repair run with status DONE") - expectedStates[segments[1].Id] = RepairSegmentStateDone + for _, segment := range segments { + // could be STARTED, RUNNING or DONE + assert.NotEqual(t, RepairSegmentStateNotStarted, segment.State) + err = client.AbortRepairRunSegment(context.Background(), run.Id, segment.Id) + if err != nil { + require.Contains(t, err.Error(), "Cannot abort segment on repair run with status DONE") + } } segments, err = client.GetRepairRunSegments(context.Background(), run.Id) require.Nil(t, err) require.Len(t, segments, 2) - assert.Equal(t, expectedStates[segments[0].Id], segments[0].State) - assert.Equal(t, expectedStates[segments[1].Id], segments[1].State) + for _, segment := range segments { + assert.True(t, + segment.State == RepairSegmentStateNotStarted || + segment.State == RepairSegmentStateDone) + } } func testDeleteRepairRunNotFound(t *testing.T, client Client) { @@ -378,14 +364,20 @@ func waitForRepairRun(t *testing.T, client Client, run *RepairRun, state RepairR return nil } -func waitForSegmentsStarted(t *testing.T, client Client, run *RepairRun) []*RepairSegment { +func waitForSegmentsStarted(t *testing.T, client Client, run *RepairRun) map[uuid.UUID]*RepairSegment { success := assert.Eventually( t, func() bool { segments, err := client.GetRepairRunSegments(context.Background(), run.Id) - return err == nil && - segments[0].State != RepairSegmentStateNotStarted && - segments[1].State != RepairSegmentStateNotStarted + if err != nil { + return false + } + for _, segment := range segments { + if segment.State == RepairSegmentStateNotStarted { + return false + } + } + return true }, 5*time.Minute, 5*time.Second, @@ -395,10 +387,6 @@ func waitForSegmentsStarted(t *testing.T, client Client, run *RepairRun) []*Repa if success { return segments } - t.Fatalf( - "timed out waiting for repair to start, last states were: %s, %s", - segments[0].State, - segments[1].State, - ) + t.Fatal("timed out waiting for repair to start") return nil }