diff --git a/Makefile b/Makefile index a03b62b..0dc7471 100644 --- a/Makefile +++ b/Makefile @@ -25,7 +25,7 @@ wait-for-reaper: .PHONY: test test: @echo Running tests: - go test -v -race -cover ./reaper/... + go test -v -race -cover -timeout 30m ./reaper/... .PHONY: test-cleanup test-cleanup: diff --git a/go.mod b/go.mod index 8dec78d..c556706 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,9 @@ module github.com/k8ssandra/reaper-client-go go 1.15 -require github.com/stretchr/testify v1.5.1 +require ( + github.com/google/go-querystring v1.1.0 + github.com/google/uuid v1.2.0 + github.com/stretchr/testify v1.5.1 + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c +) diff --git a/go.sum b/go.sum index c0565d7..2a6273e 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,20 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= +github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= +github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= diff --git a/reaper/client.go b/reaper/client.go index 386b64c..7e5252d 100644 --- a/reaper/client.go +++ b/reaper/client.go @@ -2,19 +2,13 @@ package reaper import ( "context" - "encoding/json" - "fmt" - "io/ioutil" - "log" - "math" + "github.com/google/uuid" "net/http" "net/url" - "runtime" - "sync" "time" ) -type ReaperClient interface { +type Client interface { IsReaperUp(ctx context.Context) (bool, error) GetClusterNames(ctx context.Context) ([]string, error) @@ -33,303 +27,85 @@ type ReaperClient interface { DeleteCluster(ctx context.Context, cluster string) error - RepairSchedules(ctx context.Context) ([]RepairSchedule, error) + // RepairRuns returns a list of repair runs, optionally filtering according to the provided search options. 
+ RepairRuns(ctx context.Context, searchOptions *RepairRunSearchOptions) (map[uuid.UUID]*RepairRun, error) - RepairSchedulesForCluster(ctx context.Context, clusterName string) ([]RepairSchedule, error) -} + // RepairRun returns a repair run object identified by its id. + RepairRun(ctx context.Context, repairRunId uuid.UUID) (*RepairRun, error) -type Client struct { - BaseURL *url.URL - UserAgent string - httpClient *http.Client -} + // CreateRepairRun creates a new repair run for the given cluster and keyspace. Does not actually trigger the run: + // creating a repair run includes generating the repair segments. Returns the id of the newly-created repair run if + // successful. The owner name can be any string identifying the owner. + CreateRepairRun( + ctx context.Context, + cluster string, + keyspace string, + owner string, + options *RepairRunCreateOptions, + ) (uuid.UUID, error) -func newClient(reaperBaseURL string) (*Client, error) { - if baseURL, err := url.Parse(reaperBaseURL); err != nil { - return nil, err - } else { - return &Client{BaseURL: baseURL, UserAgent: "", httpClient: &http.Client{Timeout: 3 * time.Second}}, nil - } + // UpdateRepairRun modifies the intensity of a PAUSED repair run identified by its id. + UpdateRepairRun(ctx context.Context, repairRunId uuid.UUID, newIntensity Intensity) error -} + // StartRepairRun starts (or resumes) a repair run identified by its id. Can also be used to reattempt a repair run + // in state "ERROR", picking up where it left off. + StartRepairRun(ctx context.Context, repairRunId uuid.UUID) error -func NewReaperClient(baseURL string) (ReaperClient, error) { - return newClient(baseURL) -} + // PauseRepairRun pauses a repair run identified by its id. The repair run must be in RUNNING state. + PauseRepairRun(ctx context.Context, repairRunId uuid.UUID) error -func (c *Client) IsReaperUp(ctx context.Context) (bool, error) { - rel := &url.URL{Path: "/ping"} - u := c.BaseURL.ResolveReference(rel) - req, err := http.NewRequest(http.MethodGet, u.String(), nil) - if err != nil { - return false, err - } + // ResumeRepairRun is an alias for StartRepairRun. + ResumeRepairRun(ctx context.Context, repairRunId uuid.UUID) error - if resp, err := c.doRequest(ctx, req, nil); err == nil { - return resp.StatusCode == http.StatusNoContent, nil - } else { - return false, err - } -} - -func (c *Client) GetClusterNames(ctx context.Context) ([]string, error) { - rel := &url.URL{Path: "/cluster"} - u := c.BaseURL.ResolveReference(rel) - req, err := http.NewRequest(http.MethodGet, u.String(), nil) - if err != nil { - return nil, err - } + // AbortRepairRun aborts a repair run identified by its id. The repair run must not be in ERROR state. + AbortRepairRun(ctx context.Context, repairRunId uuid.UUID) error - //req.Header.Set("User-Agent", c.UserAgent) + // RepairRunSegments returns the list of segments of a repair run identified by its id. 
+ RepairRunSegments(ctx context.Context, repairRunId uuid.UUID) (map[uuid.UUID]*RepairSegment, error) - clusterNames := []string{} - _, err = c.doJsonRequest(ctx, req, &clusterNames) - - if err != nil { - return nil, fmt.Errorf("failed to get cluster names: %w", err) - } - - return clusterNames, nil -} - -func (c *Client) GetCluster(ctx context.Context, name string) (*Cluster, error) { - rel := &url.URL{Path: fmt.Sprintf("/cluster/%s", name)} - u := c.BaseURL.ResolveReference(rel) - req, err := http.NewRequest(http.MethodGet, u.String(), nil) - if err != nil { - return nil, err - } + // AbortRepairRunSegment aborts a running segment and puts it back in NOT_STARTED state. The segment will be + // processed again later during the lifetime of the repair run. + AbortRepairRunSegment(ctx context.Context, repairRunId uuid.UUID, segmentId uuid.UUID) error - clusterState := &clusterStatus{} - resp, err := c.doJsonRequest(ctx, req, clusterState) + // DeleteRepairRun deletes a repair run object identified by its id. Repair run and all the related repair segments + // will be deleted from the database. If the given owner does not match the stored owner, the delete request will + // fail. + DeleteRepairRun(ctx context.Context, repairRunId uuid.UUID, owner string) error - if err != nil { - fmt.Printf("response: %+v", resp) - return nil, fmt.Errorf("failed to get cluster (%s): %w", name, err) - } - - if resp.StatusCode == http.StatusNotFound { - return nil, CassandraClusterNotFound - } - - cluster := newCluster(clusterState) - - return cluster, nil -} - -// Fetches all clusters. This function is async and may return before any or all results are -// available. The concurrency is currently determined by min(5, NUM_CPUS). -func (c *Client) GetClusters(ctx context.Context) <-chan GetClusterResult { - // TODO Make the concurrency configurable - concurrency := int(math.Min(5, float64(runtime.NumCPU()))) - results := make(chan GetClusterResult, concurrency) - - clusterNames, err := c.GetClusterNames(ctx) - if err != nil { - close(results) - return results - } - - var wg sync.WaitGroup - - go func() { - defer close(results) - for _, clusterName := range clusterNames { - wg.Add(1) - go func(name string) { - defer wg.Done() - cluster, err := c.GetCluster(ctx, name) - result := GetClusterResult{Cluster: cluster, Error: err} - results <- result - }(clusterName) - } - wg.Wait() - }() - - return results -} - -// Fetches all clusters in a synchronous or blocking manner. Note that this function fails -// fast if there is an error and no clusters will be returned. -func (c *Client) GetClustersSync(ctx context.Context) ([]*Cluster, error) { - clusters := make([]*Cluster, 0) - - for result := range c.GetClusters(ctx) { - if result.Error != nil { - return nil, result.Error - } - clusters = append(clusters, result.Cluster) - } - - return clusters, nil -} + // PurgeRepairRuns purges repairs and returns the number of repair runs purged. 
+ PurgeRepairRuns(ctx context.Context) (int, error) -func (c *Client) AddCluster(ctx context.Context, cluster string, seed string) error { - rel := &url.URL{Path: fmt.Sprintf("/cluster/%s", cluster)} - u := c.BaseURL.ResolveReference(rel) - - req, err := http.NewRequest(http.MethodPut, u.String(), nil) - if err != nil { - return err - } - req.Header.Set("Accept", "application/json") - q := req.URL.Query() - q.Add("seedHost", seed) - req.URL.RawQuery = q.Encode() - req.WithContext(ctx) - - resp, err := c.httpClient.Do(req) - if err != nil { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - return err - } - defer resp.Body.Close() - - switch { - case resp.StatusCode < 300: - return nil - case resp.StatusCode >= 300 && resp.StatusCode < 400: - return ErrRedirectsNotSupported - default: - if body, err := getBodyAsString(resp); err == nil { - return fmt.Errorf("request failed: msg (%s), status code (%d)", body, resp.StatusCode) - } - log.Printf("failed to get response body: %s", err) - return fmt.Errorf("request failed: status code (%d)", resp.StatusCode) - } -} - -func (c *Client) DeleteCluster(ctx context.Context, cluster string) error { - rel := &url.URL{Path: fmt.Sprintf("/cluster/%s", cluster)} - u := c.BaseURL.ResolveReference(rel) - req, err := http.NewRequest(http.MethodDelete, u.String(), nil) - if err != nil { - return err - } - - _, err = c.doJsonRequest(ctx, req, nil) - - // TODO check response status code - - if err != nil { - return fmt.Errorf("failed to delete cluster (%s): %w", cluster, err) - } - - return nil -} - -func (c *Client) RepairSchedules(ctx context.Context) ([]RepairSchedule, error) { - rel := &url.URL{Path: "/repair_schedule"} - return c.fetchRepairSchedules(ctx, rel) -} - -func (c *Client) RepairSchedulesForCluster(ctx context.Context, clusterName string) ([]RepairSchedule, error) { - rel := &url.URL{Path: fmt.Sprintf("/repair_schedule/cluster/%s", clusterName)} - return c.fetchRepairSchedules(ctx, rel) -} - -func (c *Client) fetchRepairSchedules(ctx context.Context, rel *url.URL) ([]RepairSchedule, error) { - u := c.BaseURL.ResolveReference(rel) - req, err := http.NewRequest(http.MethodGet, u.String(), nil) - - if err != nil { - return nil, err - } - - schedules := make([]RepairSchedule, 0) - resp, err := c.doJsonRequest(ctx, req, &schedules) - if err != nil { - return nil, err - } - if resp.StatusCode != 200 { - return nil, fmt.Errorf("Failed to fetch repair schedules: %v\n", resp.StatusCode) - } + RepairSchedules(ctx context.Context) ([]RepairSchedule, error) - return schedules, nil + RepairSchedulesForCluster(ctx context.Context, clusterName string) ([]RepairSchedule, error) } -func (c *Client) doJsonRequest(ctx context.Context, req *http.Request, v interface{}) (*http.Response, error) { - req.Header.Set("Accept", "application/json") - return c.doRequest(ctx, req, v) +type client struct { + baseURL *url.URL + userAgent string + httpClient *http.Client } -func (c *Client) doRequest(ctx context.Context, req *http.Request, v interface{}) (*http.Response, error) { - req.WithContext(ctx) - - resp, err := c.httpClient.Do(req) - if err != nil { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - return nil, err +func NewClient(reaperBaseURL *url.URL, options ...ClientCreateOption) Client { + client := &client{baseURL: reaperBaseURL, httpClient: &http.Client{ + Timeout: 10 * time.Second, + }} + for _, option := range options { + option(client) } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotFound { - return 
resp, nil - } - - if v != nil { - err = json.NewDecoder(resp.Body).Decode(v) - } - - return resp, err + return client } -func newCluster(state *clusterStatus) *Cluster { - cluster := Cluster{ - Name: state.Name, - JmxUsername: state.JmxUsername, - JmxPasswordSet: state.JmxPasswordSet, - Seeds: state.Seeds, - NodeState: NodeState{}, - } +type ClientCreateOption func(client *client) - for _, gs := range state.NodeStatus.EndpointStates { - gossipState := GossipState{ - SourceNode: gs.SourceNode, - EndpointNames: gs.EndpointNames, - TotalLoad: gs.TotalLoad, - DataCenters: map[string]DataCenterState{}, - } - for dc, dcStateInternal := range gs.Endpoints { - dcState := DataCenterState{Name: dc, Racks: map[string]RackState{}} - for rack, endpoints := range dcStateInternal { - rackState := RackState{Name: rack} - for _, ep := range endpoints { - endpoint := EndpointState{ - Endpoint: ep.Endpoint, - DataCenter: ep.DataCenter, - Rack: ep.Rack, - HostId: ep.HostId, - Status: ep.Status, - Severity: ep.Severity, - ReleaseVersion: ep.ReleaseVersion, - Tokens: ep.Tokens, - Load: ep.Load, - } - rackState.Endpoints = append(rackState.Endpoints, endpoint) - } - dcState.Racks[rack] = rackState - } - gossipState.DataCenters[dc] = dcState - } - cluster.NodeState.GossipStates = append(cluster.NodeState.GossipStates, gossipState) +func WithUserAgent(userAgent string) ClientCreateOption { + return func(client *client) { + client.userAgent = userAgent } - - return &cluster } -func getBodyAsString(resp *http.Response) (string, error) { - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return "", err +func WithHttpClient(httpClient *http.Client) ClientCreateOption { + return func(client *client) { + client.httpClient = httpClient } - return string(body), nil } diff --git a/reaper/client_test.go b/reaper/client_test.go index e97b1b0..a74d57a 100644 --- a/reaper/client_test.go +++ b/reaper/client_test.go @@ -2,23 +2,25 @@ package reaper import ( "context" - "testing" - "time" - + "fmt" "github.com/k8ssandra/reaper-client-go/testenv" - "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" + "net/url" + "os" + "testing" ) const ( reaperURL = "http://localhost:8080" + keyspace = "reaper_client_test" ) -type clientTest func(*testing.T, *Client) +type clientTest func(*testing.T, Client) -func run(client *Client, test clientTest) func (*testing.T) { +func run(client Client, test clientTest) func(*testing.T) { return func(t *testing.T) { - //name := runtime.FuncForPC(reflect.ValueOf(test).Pointer()).Name() - //t.Logf("running %s\n", name) + // name := runtime.FuncForPC(reflect.ValueOf(test).Pointer()).Name() + // t.Logf("running %s\n", name) test(t, client) } } @@ -26,199 +28,162 @@ func run(client *Client, test clientTest) func (*testing.T) { func TestClient(t *testing.T) { t.Log("starting test") - client, err := newClient(reaperURL) - if err != nil { - t.Fatalf("failed to create reaper client: (%s)", err) - } + u, _ := url.Parse(reaperURL) + client := NewClient(u) - if err = testenv.ResetServices(t); err != nil { - t.Fatalf("failed to reset docker services: %s", err) - } + ctx := context.Background() - if err = testenv.WaitForClusterReady(t,"cluster-1-node-0", 2); err != nil { - t.Fatalf("cluster-1 readiness check failed: %s", err) - } - if err = testenv.WaitForClusterReady(t,"cluster-2-node-0", 2); err != nil { - t.Fatalf("cluster-2 readiness check failed: %s", err) - } - if err = testenv.WaitForClusterReady(t,"cluster-3-node-0", 1); err != nil { - t.Fatalf("cluster-1 readiness check failed: %s", 
err) - } + prepareEnvironment(t, ctx) - isUp := false - for i := 0; i < 10; i++ { - t.Log("checking if reaper is ready") - if isUp, err = client.IsReaperUp(context.Background()); err == nil { - if isUp { - t.Log("reaper is ready!") - break - } - } else { - t.Logf("reaper readiness check failed: %s", err) - } - time.Sleep(6 * time.Second) - } - if !isUp { - t.Fatalf("reaper readiness check timed out") - } + t.Run("Ping", run(client, testIsReaperUp)) - if err = testenv.AddCluster(t,"cluster-1", "cluster-1-node-0"); err != nil { - t.Fatalf("failed to add cluster-1: %s", err) - } - if err = testenv.AddCluster(t,"cluster-2", "cluster-2-node-0"); err != nil { - t.Fatalf("failed to add cluster-2: %s", err) - } + registerClusters(t, ctx) + + t.Log("running Cluster resource tests...") t.Run("GetClusterNames", run(client, testGetClusterNames)) t.Run("GetCluster", run(client, testGetCluster)) t.Run("GetClusterNotFound", run(client, testGetClusterNotFound)) t.Run("GetClusters", run(client, testGetClusters)) - t.Run("GetClustersSync", run(client, testGetClustersSyc)) + t.Run("GetClustersSync", run(client, testGetClustersSync)) t.Run("AddDeleteCluster", run(client, testAddDeleteCluster)) -} - -func testGetClusterNames(t *testing.T, client *Client) { - expected := []string{"cluster-1", "cluster-2"} - - actual, err := client.GetClusterNames(context.TODO()) - if err != nil { - t.Fatalf("failed to get cluster names: (%s)", err) - } - - assert.ElementsMatch(t, expected, actual) -} - -func testGetCluster(t *testing.T, client *Client) { - name := "cluster-1" - cluster, err := client.GetCluster(context.TODO(), name) - if err != nil { - t.Fatalf("failed to get cluster (%s): %s", name, err) - } - assert.Equal(t, cluster.Name, name) - assert.Equal(t, cluster.JmxUsername, "reaperUser") - assert.True(t, cluster.JmxPasswordSet) - assert.Equal(t, len(cluster.Seeds), 2) - assert.Equal(t, 1, len(cluster.NodeState.GossipStates)) - - gossipState := cluster.NodeState.GossipStates[0] - assert.NotEmpty(t, gossipState.SourceNode) - assert.True(t, gossipState.TotalLoad > 0.0) - assert.Equal(t, 2, len(gossipState.EndpointNames), "EndpointNames (%s)", gossipState.EndpointNames) - - assert.Equal(t, 1, len(gossipState.DataCenters), "DataCenters (%+v)", gossipState.DataCenters) - dcName := "datacenter1" - dc, found := gossipState.DataCenters[dcName] - if !found { - t.Fatalf("failed to find DataCenter (%s)", dcName) - } - assert.Equal(t, dcName, dc.Name) - - assert.Equal(t, 1, len(dc.Racks)) - rackName := "rack1" - rack, found := dc.Racks[rackName] - if !found { - t.Fatalf("failed to find Rack (%s)", rackName) - } + createFixtures(t, ctx) + + t.Log("running RepairRun resource tests...") + + t.Run("GetRepairRun", run(client, testGetRepairRun)) + t.Run("GetRepairRunNotFound", run(client, testGetRepairRunNotFound)) + t.Run("GetRepairRunIgnoredTables", run(client, testGetRepairRunIgnoredTables)) + t.Run("GetRepairRuns", run(client, testGetRepairRuns)) + t.Run("GetRepairRunsFilteredByCluster", run(client, testGetRepairRunsFilteredByCluster)) + t.Run("GetRepairRunsFilteredByKeyspace", run(client, testGetRepairRunsFilteredByKeyspace)) + t.Run("GetRepairRunsFilteredByState", run(client, testGetRepairRunsFilteredByState)) + t.Run("CreateDeleteRepairRun", run(client, testCreateDeleteRepairRun)) + t.Run("DeleteRepairRunNotFound", run(client, testDeleteRepairRunNotFound)) + t.Run("CreateStartFinishRepairRun", run(client, testCreateStartFinishRepairRun)) + t.Run("CreateStartPauseUpdateResumeRepairRun", run(client, 
testCreateStartPauseUpdateResumeRepairRun)) + t.Run("CreateAbortRepairRun", run(client, testCreateAbortRepairRun)) + t.Run("GetRepairRunSegments", run(client, testGetRepairRunSegments)) + t.Run("AbortRepairRunSegments", run(client, testAbortRepairRunSegments)) + t.Run("PurgeRepairRun", run(client, testPurgeRepairRun)) - assert.Equal(t, 2, len(rack.Endpoints)) - for _, ep := range rack.Endpoints { - assert.True(t, ep.Endpoint == gossipState.EndpointNames[0] || ep.Endpoint == gossipState.EndpointNames[1]) - assert.NotEmpty(t, ep.HostId) - assert.Equal(t, dcName, ep.DataCenter) - assert.Equal(t, rackName, ep.Rack) - assert.NotEmpty(t, ep.Status) - assert.Equal(t, "3.11.8", ep.ReleaseVersion) - assert.NotEmpty(t, ep.Tokens) - } } -func testGetClusterNotFound(t *testing.T, client *Client) { - name := "cluster-notfound" - cluster, err := client.GetCluster(context.TODO(), name) - - if err != CassandraClusterNotFound { - t.Errorf("expected (%s) but got (%s)", CassandraClusterNotFound, err) - } - - assert.Nil(t, cluster, "expected non-existent cluster to be nil") -} - -func testGetClusters(t *testing.T, client *Client) { - results := make([]GetClusterResult, 0) - - for result := range client.GetClusters(context.TODO()) { - results = append(results, result) - } - - // Verify that we got the expected number of results - assert.Equal(t, 2, len(results)) - - // Verify that there were no errors - for _, result := range results { - assert.Nil(t, result.Error) +func prepareEnvironment(t *testing.T, parent context.Context) { + if err := testenv.ResetServices(t); err != nil { + t.Fatalf("failed to reset docker services: %s", err) } - - assertGetClusterResultsContains(t, results, "cluster-1") - assertGetClusterResultsContains(t, results, "cluster-2") -} - -func assertGetClusterResultsContains(t *testing.T, results []GetClusterResult, clusterName string) { - var cluster *Cluster - for _, result := range results { - if result.Cluster.Name == clusterName { - cluster = result.Cluster - break + clusterReadinessGroup, ctx := errgroup.WithContext(parent) + t.Log("checking cassandra cluster-1 status...") + clusterReadinessGroup.Go(func() error { + if err := testenv.WaitForClusterReady(ctx, "cluster-1-node-0", 2); err != nil { + return fmt.Errorf("cluster-1 readiness check failed: %w", err) + } + return nil + }) + t.Log("checking cassandra cluster-2 status...") + clusterReadinessGroup.Go(func() error { + if err := testenv.WaitForClusterReady(ctx, "cluster-2-node-0", 2); err != nil { + return fmt.Errorf("cluster-2 readiness check failed: %w", err) } + return nil + }) + t.Log("checking cassandra cluster-3 status...") + clusterReadinessGroup.Go(func() error { + if err := testenv.WaitForClusterReady(ctx, "cluster-3-node-0", 1); err != nil { + return fmt.Errorf("cluster-3 readiness check failed: %w", err) + } + return nil + }) + if err := clusterReadinessGroup.Wait(); err != nil { + t.Fatal(err) } - assert.NotNil(t, cluster, "failed to find %s", clusterName) } -func testGetClustersSyc(t *testing.T, client *Client) { - clusters, err := client.GetClustersSync(context.TODO()) - - if err != nil { - t.Fatalf("failed to get clusters synchronously: %s", err) - } - - // Verify that we got the expected number of results - assert.Equal(t, 2, len(clusters)) - - assertClustersContains(t, clusters, "cluster-1") - assertClustersContains(t, clusters, "cluster-2") -} - -func assertClustersContains(t *testing.T, clusters []*Cluster, clusterName string) { - for _, cluster := range clusters { - if cluster.Name == clusterName { - return 
+func registerClusters(t *testing.T, parent context.Context) { + addClusterGroup, ctx := errgroup.WithContext(parent) + t.Log("adding cluster-1 in Reaper...") + addClusterGroup.Go(func() error { + if err := testenv.AddCluster(ctx, "cluster-1", "cluster-1-node-0"); err != nil { + return fmt.Errorf("failed to add cluster-1: %w", err) + } + return nil + }) + t.Log("adding cluster-2 in Reaper...") + addClusterGroup.Go(func() error { + if err := testenv.AddCluster(ctx, "cluster-2", "cluster-2-node-0"); err != nil { + return fmt.Errorf("failed to add cluster-2: %w", err) } + return nil + }) + // cluster-3 will be added by a test + if err := addClusterGroup.Wait(); err != nil { + t.Fatal(err) } - t.Errorf("failed to find cluster (%s)", clusterName) } -func testAddDeleteCluster(t *testing.T, client *Client) { - cluster := "cluster-3" - seed := "cluster-3-node-0" - - if err := client.AddCluster(context.TODO(), cluster, seed); err != nil { - t.Fatalf("failed to add cluster (%s): %s", cluster, err) - } - - if clusterNames, err := client.GetClusterNames(context.TODO()); err != nil { - t.Fatalf("failed to get cluster names: %s", err) - } else { - assert.Equal(t, 3, len(clusterNames)) - } - - if err := client.DeleteCluster(context.TODO(), cluster); err != nil { - t.Fatalf("failed to delete cluster (%s): %s", cluster, err) - } - - if clusterNames, err := client.GetClusterNames(context.TODO()); err != nil { - t.Fatalf("failed to get cluster names: %s", err) - } else { - assert.Equal(t, 2, len(clusterNames)) - assert.NotContains(t, clusterNames, cluster) +func createFixtures(t *testing.T, parent context.Context) { + scriptsGroup, ctx := errgroup.WithContext(parent) + scripts := make(chan *os.File, 2) + t.Log("generating CQL scripts...") + scriptsGroup.Go(func() error { + if script, err := testenv.CreateCqlInsertScript(keyspace, "table1"); err != nil { + return fmt.Errorf("failed to create table1 CQL script: %w", err) + } else { + scripts <- script + return nil + } + }) + scriptsGroup.Go(func() error { + if script, err := testenv.CreateCqlInsertScript(keyspace, "table2"); err != nil { + return fmt.Errorf("failed to create table2 CQL script: %w", err) + } else { + scripts <- script + return nil + } + }) + if err := scriptsGroup.Wait(); err != nil { + t.Fatal(err) + } + script1 := <-scripts + script2 := <-scripts + cqlFixturesGroup, ctx := errgroup.WithContext(parent) + t.Log("populating test keyspace in cluster-1...") + cqlFixturesGroup.Go(func() error { + if err := testenv.WaitForCqlReady(ctx, "cluster-1-node-0"); err != nil { + return fmt.Errorf("CQL cluster-1 readiness check failed: %w", err) + } else if err = testenv.CreateKeyspace(ctx, "cluster-1-node-0", keyspace, 2); err != nil { + return fmt.Errorf("failed to create keyspace on cluster-1: %w", err) + } else if err = testenv.CreateTable(ctx, "cluster-1-node-0", keyspace, "table1"); err != nil { + return fmt.Errorf("failed to create keyspace on cluster-1: %w", err) + } else if err = testenv.CreateTable(ctx, "cluster-1-node-0", keyspace, "table2"); err != nil { + return fmt.Errorf("failed to create keyspace on cluster-1: %w", err) + } else if err := testenv.ExecuteCqlScript(ctx, "cluster-1-node-0", script1); err != nil { + return fmt.Errorf("failed to execute CQL script 1 on cluster-1: %w", err) + } else if err := testenv.ExecuteCqlScript(ctx, "cluster-1-node-0", script2); err != nil { + return fmt.Errorf("failed to execute CQL script 2 on cluster-1: %w", err) + } + return nil + }) + t.Log("populating test keyspace in cluster-2...") + 
cqlFixturesGroup.Go(func() error { + if err := testenv.WaitForCqlReady(ctx, "cluster-2-node-0"); err != nil { + return fmt.Errorf("CQL cluster-2 readiness check failed: %s", err) + } else if err = testenv.CreateKeyspace(ctx, "cluster-2-node-0", keyspace, 2); err != nil { + return fmt.Errorf("failed to create keyspace on cluster-2: %s", err) + } else if err = testenv.CreateTable(ctx, "cluster-2-node-0", keyspace, "table1"); err != nil { + return fmt.Errorf("failed to create keyspace on cluster-2: %s", err) + } else if err = testenv.CreateTable(ctx, "cluster-2-node-0", keyspace, "table2"); err != nil { + return fmt.Errorf("failed to create keyspace on cluster-2: %s", err) + } else if err := testenv.ExecuteCqlScript(ctx, "cluster-2-node-0", script1); err != nil { + return fmt.Errorf("failed to execute CQL script 1 on cluster-2: %s", err) + } else if err := testenv.ExecuteCqlScript(ctx, "cluster-2-node-0", script2); err != nil { + return fmt.Errorf("failed to execute CQL script 2 on cluster-2: %s", err) + } + return nil + }) + if err := cqlFixturesGroup.Wait(); err != nil { + t.Fatal(err) } } diff --git a/reaper/cluster.go b/reaper/cluster.go new file mode 100644 index 0000000..789060a --- /dev/null +++ b/reaper/cluster.go @@ -0,0 +1,243 @@ +package reaper + +import ( + "context" + "fmt" + "math" + "net/http" + "net/url" + "runtime" + "sync" + "time" +) + +type Cluster struct { + Name string + JmxUsername string + JmxPasswordSet bool + Seeds []string + NodeState NodeState +} + +type NodeState struct { + GossipStates []GossipState +} + +type GossipState struct { + SourceNode string + EndpointNames []string + TotalLoad float64 + DataCenters map[string]DataCenterState +} + +type DataCenterState struct { + Name string + Racks map[string]RackState +} + +type RackState struct { + Name string + Endpoints []EndpointState +} + +type EndpointState struct { + Endpoint string + DataCenter string + Rack string + HostId string + Status string + Severity float64 + ReleaseVersion string + Tokens string + Load float64 +} + +type GetClusterResult struct { + Cluster *Cluster + Error error +} + +type RepairSchedule struct { + Id string `json:"id"` + Owner string `json:"owner,omitempty"` + State string `json:"state,omitempty"` + Intensity float64 `json:"intensity,omitempty"` + ClusterName string `json:"cluster_name,omitempty"` + KeyspaceName string `json:"keyspace_name,omitempty"` + RepairParallelism string `json:"repair_parallelism,omitempty"` + IncrementalRepair bool `json:"incremental_repair,omitempty"` + RepairThreadCount int `json:"repair_thread_count,omitempty"` + RepairUnitId string `json:"repair_unit_id,omitempty"` + DaysBetween int `json:"scheduled_days_between,omitempty"` + Created time.Time `json:"creation_time,omitempty"` + Paused time.Time `json:"pause_time,omitempty"` + NextActivation time.Time `json:"next_activation,omitempty"` +} + +// All the following types are used internally by the client and not part of the public API + +type clusterStatus struct { + Name string `json:"name"` + JmxUsername string `json:"jmx_username,omitempty"` + JmxPasswordSet bool `json:"jmx_password_is_set,omitempty"` + Seeds []string `json:"seed_hosts,omitempty"` + NodeStatus nodeStatus `json:"nodes_status"` +} + +type nodeStatus struct { + EndpointStates []gossipStatus `json:"endpointStates,omitempty"` +} + +type gossipStatus struct { + SourceNode string `json:"sourceNode"` + EndpointNames []string `json:"endpointNames,omitempty"` + TotalLoad float64 `json:"totalLoad,omitempty"` + Endpoints 
map[string]map[string][]endpointStatus +} + +type endpointStatus struct { + Endpoint string `json:"endpoint"` + DataCenter string `json:"dc"` + Rack string `json:"rack"` + HostId string `json:"hostId"` + Status string `json:"status"` + Severity float64 `json:"severity"` + ReleaseVersion string `json:"releaseVersion"` + Tokens string `json:"tokens"` + Load float64 `json:"load"` +} + +func (c *client) GetClusterNames(ctx context.Context) ([]string, error) { + res, err := c.doGet(ctx, "/cluster", nil, http.StatusOK) + if err == nil { + clusterNames := make([]string, 0) + err = c.readBodyAsJson(res, &clusterNames) + if err == nil { + return clusterNames, nil + } + } + return nil, fmt.Errorf("failed to get cluster names: %w", err) +} + +func (c *client) GetCluster(ctx context.Context, name string) (*Cluster, error) { + path := "/cluster/" + url.PathEscape(name) + res, err := c.doGet(ctx, path, nil, http.StatusOK) + if err == nil { + clusterStatus := &clusterStatus{} + err = c.readBodyAsJson(res, clusterStatus) + if err == nil { + return newCluster(clusterStatus), nil + } + } + return nil, fmt.Errorf("failed to get cluster %s: %w", name, err) +} + +// GetClusters fetches all clusters. This function is async and may return before any or all results are +// available. The concurrency is currently determined by min(5, NUM_CPUS). +func (c *client) GetClusters(ctx context.Context) <-chan GetClusterResult { + // TODO Make the concurrency configurable + concurrency := int(math.Min(5, float64(runtime.NumCPU()))) + results := make(chan GetClusterResult, concurrency) + + clusterNames, err := c.GetClusterNames(ctx) + if err != nil { + close(results) + return results + } + + var wg sync.WaitGroup + + go func() { + defer close(results) + for _, clusterName := range clusterNames { + wg.Add(1) + go func(name string) { + defer wg.Done() + cluster, err := c.GetCluster(ctx, name) + result := GetClusterResult{Cluster: cluster, Error: err} + results <- result + }(clusterName) + } + wg.Wait() + }() + + return results +} + +// GetClustersSync fetches all clusters in a synchronous or blocking manner. Note that this function fails +// fast if there is an error and no clusters will be returned. 
+func (c *client) GetClustersSync(ctx context.Context) ([]*Cluster, error) { + clusters := make([]*Cluster, 0) + + for result := range c.GetClusters(ctx) { + if result.Error != nil { + return nil, result.Error + } + clusters = append(clusters, result.Cluster) + } + + return clusters, nil +} + +func (c *client) AddCluster(ctx context.Context, cluster string, seed string) error { + queryParams := &url.Values{"seedHost": {seed}} + path := "/cluster/" + url.PathEscape(cluster) + _, err := c.doPut(ctx, path, queryParams, nil, http.StatusCreated, http.StatusNoContent, http.StatusOK) + if err == nil { + return nil + } + return fmt.Errorf("failed to create cluster %s: %w", cluster, err) +} + +func (c *client) DeleteCluster(ctx context.Context, cluster string) error { + path := "/cluster/" + url.PathEscape(cluster) + _, err := c.doDelete(ctx, path, nil, http.StatusAccepted) + if err == nil { + return nil + } + return fmt.Errorf("failed to delete cluster %s: %w", cluster, err) +} + +func newCluster(state *clusterStatus) *Cluster { + cluster := Cluster{ + Name: state.Name, + JmxUsername: state.JmxUsername, + JmxPasswordSet: state.JmxPasswordSet, + Seeds: state.Seeds, + NodeState: NodeState{}, + } + + for _, gs := range state.NodeStatus.EndpointStates { + gossipState := GossipState{ + SourceNode: gs.SourceNode, + EndpointNames: gs.EndpointNames, + TotalLoad: gs.TotalLoad, + DataCenters: map[string]DataCenterState{}, + } + for dc, dcStateInternal := range gs.Endpoints { + dcState := DataCenterState{Name: dc, Racks: map[string]RackState{}} + for rack, endpoints := range dcStateInternal { + rackState := RackState{Name: rack} + for _, ep := range endpoints { + endpoint := EndpointState{ + Endpoint: ep.Endpoint, + DataCenter: ep.DataCenter, + Rack: ep.Rack, + HostId: ep.HostId, + Status: ep.Status, + Severity: ep.Severity, + ReleaseVersion: ep.ReleaseVersion, + Tokens: ep.Tokens, + Load: ep.Load, + } + rackState.Endpoints = append(rackState.Endpoints, endpoint) + } + dcState.Racks[rack] = rackState + } + gossipState.DataCenters[dc] = dcState + } + cluster.NodeState.GossipStates = append(cluster.NodeState.GossipStates, gossipState) + } + + return &cluster +} diff --git a/reaper/cluster_test.go b/reaper/cluster_test.go new file mode 100644 index 0000000..c4de32d --- /dev/null +++ b/reaper/cluster_test.go @@ -0,0 +1,150 @@ +package reaper + +import ( + "context" + "github.com/stretchr/testify/assert" + "testing" +) + +func testGetClusterNames(t *testing.T, client Client) { + expected := []string{"cluster-1", "cluster-2"} + + actual, err := client.GetClusterNames(context.TODO()) + if err != nil { + t.Fatalf("failed to get cluster names: (%s)", err) + } + + assert.ElementsMatch(t, expected, actual) +} + +func testGetCluster(t *testing.T, client Client) { + name := "cluster-1" + cluster, err := client.GetCluster(context.TODO(), name) + if err != nil { + t.Fatalf("failed to get cluster (%s): %s", name, err) + } + + assert.Equal(t, cluster.Name, name) + assert.Equal(t, cluster.JmxUsername, "reaperUser") + assert.True(t, cluster.JmxPasswordSet) + assert.Equal(t, len(cluster.Seeds), 2) + assert.Equal(t, 1, len(cluster.NodeState.GossipStates)) + + gossipState := cluster.NodeState.GossipStates[0] + assert.NotEmpty(t, gossipState.SourceNode) + assert.True(t, gossipState.TotalLoad > 0.0) + assert.Equal(t, 2, len(gossipState.EndpointNames), "EndpointNames (%s)", gossipState.EndpointNames) + + assert.Equal(t, 1, len(gossipState.DataCenters), "DataCenters (%+v)", gossipState.DataCenters) + dcName := "datacenter1" + dc, 
found := gossipState.DataCenters[dcName] + if !found { + t.Fatalf("failed to find DataCenter (%s)", dcName) + } + assert.Equal(t, dcName, dc.Name) + + assert.Equal(t, 1, len(dc.Racks)) + rackName := "rack1" + rack, found := dc.Racks[rackName] + if !found { + t.Fatalf("failed to find Rack (%s)", rackName) + } + + assert.Equal(t, 2, len(rack.Endpoints)) + for _, ep := range rack.Endpoints { + assert.True(t, ep.Endpoint == gossipState.EndpointNames[0] || ep.Endpoint == gossipState.EndpointNames[1]) + assert.NotEmpty(t, ep.HostId) + assert.Equal(t, dcName, ep.DataCenter) + assert.Equal(t, rackName, ep.Rack) + assert.NotEmpty(t, ep.Status) + assert.Equal(t, "3.11.8", ep.ReleaseVersion) + assert.NotEmpty(t, ep.Tokens) + } +} + +func testGetClusterNotFound(t *testing.T, client Client) { + name := "cluster-notfound" + cluster, err := client.GetCluster(context.TODO(), name) + assert.NotNil(t, err) + assert.Containsf(t, err.Error(), "cluster with name \"cluster-notfound\" not found", name) + assert.Nil(t, cluster, "expected non-existent cluster to be nil") +} + +func testGetClusters(t *testing.T, client Client) { + results := make([]GetClusterResult, 0) + + for result := range client.GetClusters(context.TODO()) { + results = append(results, result) + } + + // Verify that we got the expected number of results + assert.Equal(t, 2, len(results)) + + // Verify that there were no errors + for _, result := range results { + assert.Nil(t, result.Error) + } + + assertGetClusterResultsContains(t, results, "cluster-1") + assertGetClusterResultsContains(t, results, "cluster-2") +} + +func assertGetClusterResultsContains(t *testing.T, results []GetClusterResult, clusterName string) { + var cluster *Cluster + for _, result := range results { + if result.Cluster.Name == clusterName { + cluster = result.Cluster + break + } + } + assert.NotNil(t, cluster, "failed to find %s", clusterName) +} + +func testGetClustersSync(t *testing.T, client Client) { + clusters, err := client.GetClustersSync(context.TODO()) + + if err != nil { + t.Fatalf("failed to get clusters synchronously: %s", err) + } + + // Verify that we got the expected number of results + assert.Equal(t, 2, len(clusters)) + + assertClustersContains(t, clusters, "cluster-1") + assertClustersContains(t, clusters, "cluster-2") +} + +func assertClustersContains(t *testing.T, clusters []*Cluster, clusterName string) { + for _, cluster := range clusters { + if cluster.Name == clusterName { + return + } + } + t.Errorf("failed to find cluster (%s)", clusterName) +} + +func testAddDeleteCluster(t *testing.T, client Client) { + cluster := "cluster-3" + seed := "cluster-3-node-0" + + if err := client.AddCluster(context.TODO(), cluster, seed); err != nil { + t.Fatalf("failed to add cluster (%s): %s", cluster, err) + } + + if clusterNames, err := client.GetClusterNames(context.TODO()); err != nil { + t.Fatalf("failed to get cluster names: %s", err) + } else { + assert.Equal(t, 3, len(clusterNames)) + } + + if err := client.DeleteCluster(context.TODO(), cluster); err != nil { + t.Fatalf("failed to delete cluster (%s): %s", cluster, err) + } + + if clusterNames, err := client.GetClusterNames(context.TODO()); err != nil { + t.Fatalf("failed to get cluster names: %s", err) + } else { + assert.Equal(t, 2, len(clusterNames)) + assert.NotContains(t, clusterNames, cluster) + } +} diff --git a/reaper/errors.go b/reaper/errors.go deleted file mode 100644 index fb42c88..0000000 --- a/reaper/errors.go +++ /dev/null @@ -1,10 +0,0 @@ -package reaper - -import "errors" - -var ( - 
CassandraClusterNotFound = errors.New("cassandra cluster not found") - - ErrRedirectsNotSupported = errors.New("http redirects are not supported") -) - diff --git a/reaper/http.go b/reaper/http.go new file mode 100644 index 0000000..907e2fa --- /dev/null +++ b/reaper/http.go @@ -0,0 +1,206 @@ +package reaper + +import ( + "context" + "encoding/json" + "fmt" + "github.com/google/go-querystring/query" + "io" + "io/ioutil" + "net/http" + "net/url" + "strconv" + "strings" +) + +func (c *client) doGet( + ctx context.Context, + path string, + queryParams interface{}, + expectedStatuses ...int, +) (*http.Response, error) { + return c.doRequest(ctx, http.MethodGet, path, queryParams, nil, expectedStatuses...) +} + +func (c *client) doPost( + ctx context.Context, + path string, + queryParams interface{}, + formData interface{}, + expectedStatuses ...int, +) (*http.Response, error) { + return c.doRequest(ctx, http.MethodPost, path, queryParams, formData, expectedStatuses...) +} + +func (c *client) doPut( + ctx context.Context, + path string, + queryParams interface{}, + formData interface{}, + expectedStatuses ...int, +) (*http.Response, error) { + return c.doRequest(ctx, http.MethodPut, path, queryParams, formData, expectedStatuses...) +} + +func (c *client) doDelete( + ctx context.Context, + path string, + queryParams interface{}, + expectedStatuses ...int, +) (*http.Response, error) { + return c.doRequest(ctx, http.MethodDelete, path, queryParams, nil, expectedStatuses...) +} + +func (c *client) doHead( + ctx context.Context, + path string, + queryParams interface{}, + expectedStatuses ...int, +) (*http.Response, error) { + return c.doRequest(ctx, http.MethodHead, path, queryParams, nil, expectedStatuses...) +} + +func (c *client) doRequest( + ctx context.Context, + method string, + path string, + queryParams interface{}, + formData interface{}, + expectedStatuses ...int, +) (*http.Response, error) { + u := c.resolveURL(path) + if queryParams != nil { + queryValues, err := c.paramSourceToValues(queryParams) + if err != nil { + return nil, err + } + u.RawQuery = queryValues.Encode() + } + var body string + var bodyReader io.Reader + if formData != nil { + formValues, err := c.paramSourceToValues(formData) + if err != nil { + return nil, err + } + body = formValues.Encode() + bodyReader = strings.NewReader(body) + } + req, err := http.NewRequestWithContext(ctx, method, u.String(), bodyReader) + if err != nil { + return nil, err + } + // TODO authentication headers + c.addCommonHeaders(req) + res, err := c.httpClient.Do(req) + if err == nil { + err = c.checkResponseStatus(res, expectedStatuses...) + } + return res, err +} + +func (c *client) mergeParamSources(paramSources ...interface{}) (*url.Values, error) { + mergedValues := url.Values{} + for _, paramSource := range paramSources { + paramSourceValues, err := c.paramSourceToValues(paramSource) + if err != nil { + return nil, err + } + for key, values := range *paramSourceValues { + mergedValues[key] = append(mergedValues[key], values...) 
+ } + } + return &mergedValues, nil +} + +func (c *client) paramSourceToValues(paramSource interface{}) (*url.Values, error) { + if paramSource == nil { + return nil, nil + } + if values, ok := paramSource.(*url.Values); ok { + return values, nil + } + if m, ok := paramSource.(map[string]string); ok { + values := make(url.Values) + for key, val := range m { + values.Add(key, val) + } + return &values, nil + } + if m, ok := paramSource.(map[string][]string); ok { + values := url.Values(m) + return &values, nil + } + values, err := query.Values(paramSource) + if err != nil { + return nil, err + } + return &values, nil +} + +func (c *client) resolveURL(path string) *url.URL { + rel := &url.URL{Path: path} + u := c.baseURL.ResolveReference(rel) + return u +} + +func (c *client) addFormHeaders(req *http.Request, requestBody string) { + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(requestBody))) +} + +func (c *client) addCommonHeaders(req *http.Request) { + req.Header.Set("Accept", "application/json;q=0.9,text/plain") + if c.userAgent != "" { + req.Header.Set("User-Agent", c.userAgent) + } +} + +func (c *client) readBodyAsString(res *http.Response) (string, error) { + b, err := ioutil.ReadAll(res.Body) + if err != nil { + return "", err + } + return string(b), nil +} + +func (c *client) readBodyAsJson(res *http.Response, v interface{}) error { + err := json.NewDecoder(res.Body).Decode(v) + _ = res.Body.Close() + return err +} + +func (c *client) checkResponseStatus(res *http.Response, expectedStatuses ...int) error { + if len(expectedStatuses) == 0 { + // the caller didn't specify any status: assume they will deal with statuses themselves + return nil + } + for _, status := range expectedStatuses { + if res.StatusCode == status { + return nil + } + } + return c.bodyToError(res) +} + +func (c *client) bodyToError(res *http.Response) error { + message, err := c.readBodyAsString(res) + if message != "" && err == nil { + contentType := res.Header.Get("Content-Type") + if contentType == "application/json" { + payload := &errorPayload{} + err = json.NewDecoder(strings.NewReader(message)).Decode(payload) + if err == nil && payload.Message != "" { + message = payload.Message + } + } + } else { + message = http.StatusText(res.StatusCode) + } + return fmt.Errorf("%s (HTTP status %d)", message, res.StatusCode) +} + +type errorPayload struct { + Code int `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} diff --git a/reaper/ping.go b/reaper/ping.go new file mode 100644 index 0000000..14633e3 --- /dev/null +++ b/reaper/ping.go @@ -0,0 +1,14 @@ +package reaper + +import ( + "context" + "net/http" +) + +func (c *client) IsReaperUp(ctx context.Context) (bool, error) { + if resp, err := c.doHead(ctx, "/ping", nil); err == nil { + return resp.StatusCode == http.StatusNoContent, nil + } else { + return false, err + } +} diff --git a/reaper/ping_test.go b/reaper/ping_test.go new file mode 100644 index 0000000..c10a2d6 --- /dev/null +++ b/reaper/ping_test.go @@ -0,0 +1,25 @@ +package reaper + +import ( + "context" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func testIsReaperUp(t *testing.T, client Client) { + success := assert.Eventually( + t, + func() bool { + isUp, err := client.IsReaperUp(context.Background()) + return isUp && err == nil + }, + 5*time.Minute, + 1*time.Second, + ) + if success { + t.Log("reaper is ready!") + } else { + t.Fatalf("reaper ping timeout") + } +} diff --git 
a/reaper/repair_run.go b/reaper/repair_run.go new file mode 100644 index 0000000..c2f4d69 --- /dev/null +++ b/reaper/repair_run.go @@ -0,0 +1,355 @@ +package reaper + +import ( + "context" + "encoding/json" + "fmt" + "github.com/google/uuid" + "math/big" + "net/http" + "net/url" + "strconv" + "time" +) + +type RepairRun struct { + Id uuid.UUID `json:"id"` + Cluster string `json:"cluster_name"` + Owner string `json:"owner"` + Keyspace string `json:"keyspace_name"` + Tables []string `json:"column_families"` + Cause string `json:"cause"` + State RepairRunState `json:"state"` + Intensity Intensity `json:"intensity"` + IncrementalRepair bool `json:"incremental_repair"` + TotalSegments int `json:"total_segments"` + RepairParallelism RepairParallelism `json:"repair_parallelism"` + SegmentsRepaired int `json:"segments_repaired"` + LastEvent string `json:"last_event"` + Duration string `json:"duration"` + Nodes []string `json:"nodes"` + Datacenters []string `json:"datacenters"` + IgnoredTables []string `json:"blacklisted_tables"` + RepairThreadCount int `json:"repair_thread_count"` + RepairUnitId uuid.UUID `json:"repair_unit_id"` +} + +func (r RepairRun) String() string { + return fmt.Sprintf("Repair run %v on %v/%v (%v)", r.Id, r.Cluster, r.Keyspace, r.State) +} + +type RepairSegment struct { + Id uuid.UUID + RunId uuid.UUID + RepairUnitId uuid.UUID + TokenRange *Segment + FailCount int + State RepairSegmentState + Coordinator string + Replicas map[string]string + StartTime *time.Time + EndTime *time.Time +} + +func (r RepairSegment) String() string { + return fmt.Sprintf("Repair run segment %v (%v)", r.Id, r.State) +} + +func (r *RepairSegment) UnmarshalJSON(data []byte) error { + temp := struct { + Id uuid.UUID `json:"id"` + RunId uuid.UUID `json:"runId"` + RepairUnitId uuid.UUID `json:"repairUnitId"` + TokenRange *Segment `json:"tokenRange"` + FailCount int `json:"failCount"` + State RepairSegmentState `json:"state"` + Coordinator string `json:"coordinatorHost"` + Replicas map[string]string `json:"replicas"` + StartTimeMillis int64 `json:"startTime,omitempty"` + EndTimeMillis int64 `json:"endTime,omitempty"` + }{} + err := json.Unmarshal(data, &temp) + if err != nil { + return err + } + r.Id = temp.Id + r.RunId = temp.RunId + r.RepairUnitId = temp.RepairUnitId + r.TokenRange = temp.TokenRange + r.FailCount = temp.FailCount + r.State = temp.State + r.Coordinator = temp.Coordinator + r.Replicas = temp.Replicas + if temp.StartTimeMillis != 0 { + unix := time.Unix(0, temp.StartTimeMillis*int64(time.Millisecond)) + r.StartTime = &unix + } + if temp.EndTimeMillis != 0 { + unix := time.Unix(0, temp.EndTimeMillis*int64(time.Millisecond)) + r.EndTime = &unix + } + return nil +} + +type Segment struct { + BaseRange *TokenRange `json:"baseRange"` + TokenRanges []*TokenRange `json:"tokenRanges"` + Replicas map[string]string `json:"replicas"` +} + +type TokenRange struct { + Start *big.Int `json:"start"` + End *big.Int `json:"end"` +} + +// Intensity controls the eagerness by which Reaper triggers repair segments. Must be in range (0.0, 1.0]. Reaper will +// use the duration of the previous repair segment to compute how much time to wait before triggering the next one. The +// idea behind this is that long segments mean a lot of data mismatch, and thus a lot of streaming and compaction. +// Intensity allows Reaper to adequately back off and give the cluster time to handle the load caused by the repair. 
+type Intensity = float64 + +type RepairRunState string + +const ( + RepairRunStateNotStarted = RepairRunState("NOT_STARTED") + RepairRunStateRunning = RepairRunState("RUNNING") + RepairRunStateError = RepairRunState("ERROR") + RepairRunStateDone = RepairRunState("DONE") + RepairRunStatePaused = RepairRunState("PAUSED") + RepairRunStateAborted = RepairRunState("ABORTED") + RepairRunStateDeleted = RepairRunState("DELETED") +) + +func (s RepairRunState) isActive() bool { + return s == RepairRunStateRunning || s == RepairRunStatePaused +} + +func (s RepairRunState) isTerminated() bool { + return s == RepairRunStateDone || s == RepairRunStateError || s == RepairRunStateAborted || s == RepairRunStateDeleted +} + +type RepairSegmentState string + +const ( + RepairSegmentStateNotStarted = RepairSegmentState("NOT_STARTED") + RepairSegmentStateRunning = RepairSegmentState("RUNNING") + RepairSegmentStateDone = RepairSegmentState("DONE") + RepairSegmentStateStarted = RepairSegmentState("STARTED") +) + +type RepairParallelism string + +const ( + RepairParallelismSequential = RepairParallelism("SEQUENTIAL") + RepairParallelismParallel = RepairParallelism("PARALLEL") + RepairParallelismDatacenterAware = RepairParallelism("DATACENTER_AWARE") +) + +type RepairRunSearchOptions struct { + + // Only return repair runs belonging to this cluster. + Cluster string `url:"cluster_name,omitempty"` + + // Only return repair runs belonging to this keyspace. + Keyspace string `url:"keyspace_name,omitempty"` + + // Restrict the search to repair runs whose states are in this list. + States []RepairRunState `url:"state,comma,omitempty"` +} + +type RepairRunCreateOptions struct { + + // Allows specifying which tables are targeted by a repair run. When this parameter is omitted, the + // repair run will target all the tables in its target keyspace. + Tables []string `url:"tables,comma,omitempty"` + + // Allows specifying a list of tables that should not be repaired. Cannot be used in conjunction with Tables. + IgnoredTables []string `url:"blacklistedTables,comma,omitempty"` + + // Identifies the process or cause that triggered the repair run. + Cause string `url:"cause,omitempty"` + + // Defines the number of segments per node to create for the repair run. The value must be >0 and <=1000. + SegmentCountPerNode int `url:"segmentCountPerNode,omitempty"` + + // Defines the repair parallelism to use for the repair run. + RepairParallelism RepairParallelism `url:"repairParallelism,omitempty"` + + // Defines the intensity to use for the repair run. + Intensity Intensity `url:"intensity,omitempty"` + + // Defines whether incremental repair should be performed. + IncrementalRepair bool `url:"incrementalRepair,omitempty"` + + // Allows specifying a list of nodes whose tokens should be repaired. + Nodes []string `url:"nodes,comma,omitempty"` + + // Allows specifying a list of datacenters to repair. + Datacenters []string `url:"datacenters,comma,omitempty"` + + // Defines the thread count to use for repair. Since Cassandra 2.2, repairs can be performed with + // up to 4 threads in order to parallelize the work on different token ranges. 
+	RepairThreadCount   int    `url:"repairThreadCount,omitempty"`
+}
+
+func (c *client) RepairRuns(ctx context.Context, searchOptions *RepairRunSearchOptions) (map[uuid.UUID]*RepairRun, error) {
+	res, err := c.doGet(ctx, "/repair_run", searchOptions, http.StatusOK)
+	if err == nil {
+		repairRuns := make([]*RepairRun, 0)
+		err = c.readBodyAsJson(res, &repairRuns)
+		if err == nil {
+			repairRunsMap := make(map[uuid.UUID]*RepairRun, len(repairRuns))
+			for _, repairRun := range repairRuns {
+				repairRunsMap[repairRun.Id] = repairRun
+			}
+			return repairRunsMap, nil
+		}
+	}
+	return nil, fmt.Errorf("failed to get repair runs: %w", err)
+}
+
+func (c *client) RepairRun(ctx context.Context, repairRunId uuid.UUID) (*RepairRun, error) {
+	path := fmt.Sprint("/repair_run/", repairRunId)
+	res, err := c.doGet(ctx, path, nil, http.StatusOK)
+	if err == nil {
+		repairRun := &RepairRun{}
+		err = c.readBodyAsJson(res, repairRun)
+		if err == nil {
+			return repairRun, nil
+		}
+	}
+	return nil, fmt.Errorf("failed to get repair run %v: %w", repairRunId, err)
+}
+
+func (c *client) CreateRepairRun(ctx context.Context, cluster string, keyspace string, owner string, options *RepairRunCreateOptions) (uuid.UUID, error) {
+	queryParams, err := c.mergeParamSources(
+		map[string]string{
+			"clusterName": cluster,
+			"keyspace":    keyspace,
+			"owner":       owner,
+		},
+		options,
+	)
+	if err == nil {
+		if options != nil && options.SegmentCountPerNode > 0 {
+			// Some Reaper versions accept "segmentCount", others "segmentCountPerNode";
+			// make sure we include both in the query string.
+			queryParams.Set("segmentCount", strconv.Itoa(options.SegmentCountPerNode))
+		}
+		var res *http.Response
+		res, err = c.doPost(ctx, "/repair_run", queryParams, nil, http.StatusCreated)
+		if err == nil {
+			repairRun := &RepairRun{}
+			err = c.readBodyAsJson(res, repairRun)
+			if err == nil {
+				return repairRun.Id, nil
+			}
+		}
+	}
+	return uuid.Nil, fmt.Errorf("failed to create repair run: %w", err)
+}
+
+func (c *client) UpdateRepairRun(ctx context.Context, repairRunId uuid.UUID, newIntensity Intensity) error {
+	path := fmt.Sprint("/repair_run/", repairRunId, "/intensity/", newIntensity)
+	_, err := c.doPut(ctx, path, nil, nil, http.StatusOK)
+	if err == nil {
+		return nil
+	}
+	return fmt.Errorf("failed to update intensity of repair run %v: %w", repairRunId, err)
+}
+
+func (c *client) StartRepairRun(ctx context.Context, repairRunId uuid.UUID) error {
+	path := fmt.Sprint("/repair_run/", repairRunId, "/state/", RepairRunStateRunning)
+	_, err := c.doPut(ctx, path, nil, nil, http.StatusOK, http.StatusNoContent)
+	if err == nil {
+		return nil
+	}
+	return fmt.Errorf("failed to start repair run %v: %w", repairRunId, err)
+}
+
+func (c *client) PauseRepairRun(ctx context.Context, repairRunId uuid.UUID) error {
+	path := fmt.Sprint("/repair_run/", repairRunId, "/state/", RepairRunStatePaused)
+	_, err := c.doPut(ctx, path, nil, nil, http.StatusOK, http.StatusNoContent)
+	if err == nil {
+		return nil
+	}
+	return fmt.Errorf("failed to pause repair run %v: %w", repairRunId, err)
+}
+
+func (c *client) ResumeRepairRun(ctx context.Context, repairRunId uuid.UUID) error {
+	path := fmt.Sprint("/repair_run/", repairRunId, "/state/", RepairRunStateRunning)
+	_, err := c.doPut(ctx, path, nil, nil, http.StatusOK, http.StatusNoContent)
+	if err == nil {
+		return nil
+	}
+	return fmt.Errorf("failed to resume repair run %v: %w", repairRunId, err)
+}
+
+func (c *client) AbortRepairRun(ctx context.Context, repairRunId uuid.UUID) error {
+	path := fmt.Sprint("/repair_run/", repairRunId, "/state/", RepairRunStateAborted)
+	_, err := c.doPut(ctx, path, nil, nil, http.StatusOK, http.StatusNoContent)
+	if err == nil {
+		return nil
+	}
+	return fmt.Errorf("failed to abort repair run %v: %w", repairRunId, err)
+}
+
+func (c *client) RepairRunSegments(ctx context.Context, repairRunId uuid.UUID) (map[uuid.UUID]*RepairSegment, error) {
+	path := fmt.Sprint("/repair_run/", repairRunId, "/segments")
+	res, err := c.doGet(ctx, path, nil, http.StatusOK)
+	if err == nil {
+		repairRunSegments := make([]*RepairSegment, 0)
+		err = c.readBodyAsJson(res, &repairRunSegments)
+		if err == nil {
+			repairRunSegmentsMap := make(map[uuid.UUID]*RepairSegment, len(repairRunSegments))
+			for _, segment := range repairRunSegments {
+				repairRunSegmentsMap[segment.Id] = segment
+			}
+			return repairRunSegmentsMap, nil
+		}
+	}
+	return nil, fmt.Errorf("failed to get segments of repair run %v: %w", repairRunId, err)
+}
+
+func (c *client) AbortRepairRunSegment(ctx context.Context, repairRunId uuid.UUID, segmentId uuid.UUID) error {
+	path := fmt.Sprint("/repair_run/", repairRunId, "/segments/abort/", segmentId)
+	_, err := c.doPost(ctx, path, nil, nil, http.StatusOK)
+	if err == nil {
+		return nil
+	}
+	return fmt.Errorf("failed to abort segment %v of repair run %v: %w", segmentId, repairRunId, err)
+}
+
+func (c *client) DeleteRepairRun(ctx context.Context, repairRunId uuid.UUID, owner string) error {
+	path := fmt.Sprint("/repair_run/", repairRunId)
+	queryParams := &url.Values{"owner": {owner}}
+	res, err := c.doDelete(ctx, path, queryParams, http.StatusAccepted)
+	if err == nil {
+		return nil
+	} else {
+		// FIXME this REST resource currently returns 500 for successful deletes
+		if res != nil && res.StatusCode == http.StatusInternalServerError {
+			_, err2 := c.doGet(ctx, path, nil, http.StatusNotFound)
+			if err2 == nil {
+				return nil
+			}
+		}
+	}
+	return fmt.Errorf("failed to delete repair run %v: %w", repairRunId, err)
+}
+
+func (c *client) PurgeRepairRuns(ctx context.Context) (int, error) {
+	res, err := c.doPost(ctx, "/repair_run/purge", nil, nil, http.StatusOK)
+	if err == nil {
+		var purgedStr string
+		purgedStr, err = c.readBodyAsString(res)
+		if err == nil {
+			var purged int
+			purged, err = strconv.Atoi(purgedStr)
+			if err == nil {
+				return purged, nil
+			}
+		}
+	}
+	return 0, fmt.Errorf("failed to purge repair runs: %w", err)
+}
diff --git a/reaper/repair_run_test.go b/reaper/repair_run_test.go
new file mode 100644
index 0000000..778f3f5
--- /dev/null
+++ b/reaper/repair_run_test.go
@@ -0,0 +1,405 @@
+package reaper
+
+import (
+	"context"
+	"fmt"
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"strings"
+	"testing"
+	"time"
+)
+
+func testGetRepairRun(t *testing.T, client Client) {
+	expected := createRepairRun(t, client, "cluster-1")
+	defer deleteRepairRun(t, client, expected)
+	actual, err := client.RepairRun(
+		context.Background(),
+		expected.Id,
+	)
+	require.Nil(t, err)
+	assert.Equal(t, expected, actual)
+}
+
+func testGetRepairRunNotFound(t *testing.T, client Client) {
+	nonExistentRepairRun, _ := uuid.NewUUID()
+	actual, err := client.RepairRun(
+		context.Background(),
+		nonExistentRepairRun,
+	)
+	assert.Nil(t, actual)
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), fmt.Sprintf("repair run %v doesn't exist", nonExistentRepairRun))
+}
+
+func testGetRepairRunIgnoredTables(t *testing.T, client Client) {
+	runId, err := client.CreateRepairRun(
+		context.Background(),
+		"cluster-2",
+		keyspace,
+		"Bob",
+		&RepairRunCreateOptions{IgnoredTables: []string{"table2"}},
+	)
+	require.Nil(t, err)
+	repairRun, err := client.RepairRun(context.Background(), runId)
+	assert.Nil(t, err)
+	assert.Equal(t, repairRun.Tables, []string{"table1"})
+	assert.Equal(t, repairRun.IgnoredTables, []string{"table2"})
+	err = client.DeleteRepairRun(context.Background(), runId, "Bob")
+	assert.Nil(t, err)
+}
+
+func testGetRepairRuns(t *testing.T, client Client) {
+	run1 := createRepairRun(t, client, "cluster-1")
+	run2 := createRepairRun(t, client, "cluster-2")
+	defer deleteRepairRun(t, client, run1)
+	defer deleteRepairRun(t, client, run2)
+	repairRuns, err := client.RepairRuns(context.Background(), nil)
+	require.Nil(t, err)
+	assert.Len(t, repairRuns, 2)
+	assert.Contains(t, repairRuns, run1.Id)
+	assert.Contains(t, repairRuns, run2.Id)
+}
+
+func testGetRepairRunsFilteredByCluster(t *testing.T, client Client) {
+	run1 := createRepairRun(t, client, "cluster-1")
+	run2 := createRepairRun(t, client, "cluster-2")
+	defer deleteRepairRun(t, client, run1)
+	defer deleteRepairRun(t, client, run2)
+	repairRuns, err := client.RepairRuns(
+		context.Background(),
+		&RepairRunSearchOptions{
+			Cluster: "cluster-1",
+		},
+	)
+	require.Nil(t, err)
+	assert.Len(t, repairRuns, 1)
+	assert.Contains(t, repairRuns, run1.Id)
+	assert.NotContains(t, repairRuns, run2.Id)
+}
+
+func testGetRepairRunsFilteredByKeyspace(t *testing.T, client Client) {
+	run1 := createRepairRun(t, client, "cluster-1")
+	run2 := createRepairRun(t, client, "cluster-2")
+	defer deleteRepairRun(t, client, run1)
+	defer deleteRepairRun(t, client, run2)
+	repairRuns, err := client.RepairRuns(
+		context.Background(),
+		&RepairRunSearchOptions{
+			Keyspace: keyspace,
+		},
+	)
+	require.Nil(t, err)
+	assert.Len(t, repairRuns, 2)
+	assert.Contains(t, repairRuns, run1.Id)
+	assert.Contains(t, repairRuns, run2.Id)
+	repairRuns, err = client.RepairRuns(
+		context.Background(),
+		&RepairRunSearchOptions{
+			Keyspace: "nonexistent_keyspace",
+		},
+	)
+	require.Nil(t, err)
+	assert.Len(t, repairRuns, 0)
+}
+
+func testGetRepairRunsFilteredByState(t *testing.T, client Client) {
+	run1 := createRepairRun(t, client, "cluster-1")
+	run2 := createRepairRun(t, client, "cluster-2")
+	defer deleteRepairRun(t, client, run1)
+	defer deleteRepairRun(t, client, run2)
+	repairRuns, err := client.RepairRuns(
+		context.Background(),
+		&RepairRunSearchOptions{
+			States: []RepairRunState{RepairRunStateNotStarted},
+		},
+	)
+	require.Nil(t, err)
+	assert.Len(t, repairRuns, 2)
+	assert.Contains(t, repairRuns, run1.Id)
+	assert.Contains(t, repairRuns, run2.Id)
+	repairRuns, err = client.RepairRuns(
+		context.Background(),
+		&RepairRunSearchOptions{
+			States: []RepairRunState{RepairRunStateRunning, RepairRunStateDone},
+		},
+	)
+	require.Nil(t, err)
+	assert.Len(t, repairRuns, 0)
+}
+
+func testCreateDeleteRepairRun(t *testing.T, client Client) {
+	run := createRepairRun(t, client, "cluster-1")
+	deleteRepairRun(t, client, run)
+	repairRuns, err := client.RepairRuns(context.Background(), nil)
+	require.Nil(t, err)
+	assert.Len(t, repairRuns, 0)
+}
+
+func testCreateStartFinishRepairRun(t *testing.T, client Client) {
+	run := createRepairRun(t, client, "cluster-1")
+	defer deleteRepairRun(t, client, run)
+	err := client.StartRepairRun(context.Background(), run.Id)
+	require.Nil(t, err)
+	done := waitForRepairRun(t, client, run, RepairRunStateDone)
+	assert.Equal(t, RepairRunStateDone, done.State)
+	segments, err := client.RepairRunSegments(context.Background(), done.Id)
+	require.Nil(t, err)
+	for _, segment := range segments {
+		assert.Equal(t, RepairSegmentStateDone, segment.State)
+		assert.NotNil(t, segment.StartTime)
+		assert.NotNil(t, segment.EndTime)
+	}
+}
+
+func testCreateStartPauseUpdateResumeRepairRun(t *testing.T, client Client) {
+	run := createRepairRun(t, client, "cluster-1")
+	defer deleteRepairRun(t, client, run)
+	err := client.StartRepairRun(context.Background(), run.Id)
+	require.Nil(t, err)
+	started, err := client.RepairRun(context.Background(), run.Id)
+	require.Nil(t, err)
+	assert.Equal(t, RepairRunStateRunning, started.State)
+	err = client.PauseRepairRun(context.Background(), run.Id)
+	if err != nil {
+		// pause not possible because repair is DONE
+		require.Contains(t, err.Error(), "Transition DONE->PAUSED not supported")
+		done, err := client.RepairRun(context.Background(), run.Id)
+		require.Nil(t, err)
+		assert.Equal(t, RepairRunStateDone, done.State)
+	} else {
+		paused, err := client.RepairRun(context.Background(), run.Id)
+		require.Nil(t, err)
+		assert.Equal(t, RepairRunStatePaused, paused.State)
+		err = client.UpdateRepairRun(context.Background(), run.Id, 0.5)
+		require.Nil(t, err)
+		updated, err := client.RepairRun(context.Background(), run.Id)
+		require.Nil(t, err)
+		assert.InDelta(t, 0.5, updated.Intensity, 0.001)
+		err = client.ResumeRepairRun(context.Background(), run.Id)
+		require.Nil(t, err)
+		done := waitForRepairRun(t, client, run, RepairRunStateDone)
+		assert.Equal(t, RepairRunStateDone, done.State)
+	}
+}
+
+func testCreateAbortRepairRun(t *testing.T, client Client) {
+	run := createRepairRun(t, client, "cluster-1")
+	defer deleteRepairRun(t, client, run)
+	err := client.StartRepairRun(context.Background(), run.Id)
+	require.Nil(t, err)
+	started, err := client.RepairRun(context.Background(), run.Id)
+	require.Nil(t, err)
+	assert.Equal(t, RepairRunStateRunning, started.State)
+	err = client.AbortRepairRun(context.Background(), run.Id)
+	require.Nil(t, err)
+	aborted, err := client.RepairRun(context.Background(), run.Id)
+	require.Nil(t, err)
+	assert.Equal(t, RepairRunStateAborted, aborted.State)
+}
+
+func testGetRepairRunSegments(t *testing.T, client Client) {
+	run := createRepairRun(t, client, "cluster-1")
+	defer deleteRepairRun(t, client, run)
+	segments, err := client.RepairRunSegments(context.Background(), run.Id)
+	require.Nil(t, err)
+	for _, segment := range segments {
+		assert.Equal(t, RepairSegmentStateNotStarted, segment.State)
+		assert.Nil(t, segment.StartTime)
+		assert.Nil(t, segment.EndTime)
+		checkRepairRunSegment(t, run, segment)
+	}
+	err = client.StartRepairRun(context.Background(), run.Id)
+	require.Nil(t, err)
+	segments = waitForSegmentsStarted(t, client, run)
+	for _, segment := range segments {
+		// could be STARTED, RUNNING or DONE
+		assert.NotEqual(t, RepairSegmentStateNotStarted, segment.State)
+		assert.NotNil(t, segment.StartTime)
+		assert.True(t, segment.EndTime == nil || segment.State == RepairSegmentStateDone)
+		checkRepairRunSegment(t, run, segment)
+	}
+	err = client.PauseRepairRun(context.Background(), run.Id)
+	if err != nil {
+		// pause not possible because repair is DONE
+		require.Contains(t, err.Error(), "Transition DONE->PAUSED not supported")
+	} else {
+		segments, err = client.RepairRunSegments(context.Background(), run.Id)
+		require.Nil(t, err)
+		for _, segment := range segments {
+			// some segments may be DONE or even RUNNING: cannot assert state here
+			assert.NotNil(t, segment.StartTime)
+			assert.True(t, segment.EndTime == nil || segment.State == RepairSegmentStateDone)
+			checkRepairRunSegment(t, run, segment)
+		}
+		err = client.StartRepairRun(context.Background(), run.Id)
+		require.Nil(t, err)
+		done := waitForRepairRun(t, client, run, RepairRunStateDone)
+		assert.Equal(t, RepairRunStateDone, done.State)
+	}
+	segments, err = client.RepairRunSegments(context.Background(), run.Id)
+	require.Nil(t, err)
+	for _, segment := range segments {
+		assert.Equal(t, RepairSegmentStateDone, segment.State)
+		assert.NotNil(t, segment.StartTime)
+		assert.NotNil(t, segment.EndTime)
+		checkRepairRunSegment(t, run, segment)
+	}
+}
+
+func testAbortRepairRunSegments(t *testing.T, client Client) {
+	run := createRepairRun(t, client, "cluster-1")
+	defer deleteRepairRun(t, client, run)
+	err := client.StartRepairRun(context.Background(), run.Id)
+	require.Nil(t, err)
+	segments := waitForSegmentsStarted(t, client, run)
+	for _, segment := range segments {
+		// could be STARTED, RUNNING or DONE
+		assert.NotEqual(t, RepairSegmentStateNotStarted, segment.State)
+		err = client.AbortRepairRunSegment(context.Background(), run.Id, segment.Id)
+		if err != nil {
+			require.Contains(t, err.Error(), "Cannot abort segment on repair run with status DONE")
+		}
+	}
+	segments, err = client.RepairRunSegments(context.Background(), run.Id)
+	require.Nil(t, err)
+	for _, segment := range segments {
+		assert.True(t,
+			segment.State == RepairSegmentStateNotStarted ||
+				segment.State == RepairSegmentStateDone)
+	}
+}
+
+func testDeleteRepairRunNotFound(t *testing.T, client Client) {
+	nonExistentRepairRun, _ := uuid.NewUUID()
+	err := client.DeleteRepairRun(context.Background(), nonExistentRepairRun, "Alice")
+	assert.NotNil(t, err)
+	// Reaper returns a spurious '%s' in the error message
+	assert.Contains(t, err.Error(), fmt.Sprintf("Repair run %%s%v not found", nonExistentRepairRun))
+}
+
+func testPurgeRepairRun(t *testing.T, client Client) {
+	run := createRepairRun(t, client, "cluster-1")
+	defer deleteRepairRun(t, client, run)
+	purged, err := client.PurgeRepairRuns(context.Background())
+	require.Nil(t, err)
+	assert.Equal(t, 0, purged)
+}
+
+func createRepairRun(t *testing.T, client Client, cluster string) *RepairRun {
+	runId, err := client.CreateRepairRun(
+		context.Background(),
+		cluster,
+		keyspace,
+		"Alice",
+		&RepairRunCreateOptions{
+			Tables:              []string{"table1", "table2"},
+			SegmentCountPerNode: 3,
+			RepairParallelism:   RepairParallelismParallel,
+			Intensity:           0.1,
+			IncrementalRepair:   false,
+			RepairThreadCount:   4,
+			Cause:               "testing repair runs",
+		},
+	)
+	require.Nil(t, err)
+	repairRun, err := client.RepairRun(context.Background(), runId)
+	require.Nil(t, err)
+	return checkRepairRun(t, cluster, repairRun)
+}
+
+func checkRepairRun(t *testing.T, cluster string, actual *RepairRun) *RepairRun {
+	assert.NotNil(t, actual.Id)
+	assert.Equal(t, cluster, actual.Cluster)
+	assert.Equal(t, keyspace, actual.Keyspace)
+	assert.Equal(t, "Alice", actual.Owner)
+	assert.Equal(t, "testing repair runs", actual.Cause)
+	assert.ElementsMatch(t, []string{"table1", "table2"}, actual.Tables)
+	assert.Equal(t, RepairRunStateNotStarted, actual.State)
+	assert.InDelta(t, 0.1, actual.Intensity, 0.001)
+	assert.False(t, actual.IncrementalRepair)
+	// Can't really guess the total
+	assert.NotZero(t, actual.TotalSegments)
+	assert.Equal(t, RepairParallelismParallel, actual.RepairParallelism)
+	assert.Equal(t, 0, actual.SegmentsRepaired)
+	assert.Equal(t, "no events", actual.LastEvent)
+	assert.Empty(t, actual.Duration)
+	assert.Empty(t, actual.Nodes)
+	assert.Empty(t, actual.Datacenters)
+	assert.Empty(t, actual.IgnoredTables)
+	assert.Equal(t, 4, actual.RepairThreadCount)
+	assert.NotNil(t, actual.RepairUnitId)
+	return actual
+}
+
+func checkRepairRunSegment(t *testing.T, run *RepairRun, actual *RepairSegment) {
+	assert.NotNil(t, actual.Id)
+	assert.Equal(t, run.Id, actual.RunId)
+	assert.NotNil(t, actual.RepairUnitId)
+	assert.NotNil(t, actual.TokenRange)
+	assert.NotNil(t, actual.Coordinator)
+}
+
+func deleteRepairRun(t *testing.T, client Client, run *RepairRun) {
+	_ = client.PauseRepairRun(context.Background(), run.Id)
+	err := client.DeleteRepairRun(context.Background(), run.Id, "Alice")
+	if err != nil {
+		assert.True(t,
+			strings.Contains(err.Error(), "is currently running") ||
+				strings.Contains(err.Error(), "has running segments"))
+		waitForRepairRun(t, client, run, RepairRunStateNotStarted)
+		err = client.DeleteRepairRun(context.Background(), run.Id, "Alice")
+	}
+	require.Nil(t, err)
+}
+
+func waitForRepairRun(t *testing.T, client Client, run *RepairRun, state RepairRunState) *RepairRun {
+	success := assert.Eventually(
+		t,
+		func() bool {
+			actual, err := client.RepairRun(context.Background(), run.Id)
+			return err == nil && actual.State == state
+		},
+		15*time.Minute,
+		5*time.Second,
+	)
+	actual, err := client.RepairRun(context.Background(), run.Id)
+	require.Nil(t, err)
+	if success {
+		return actual
+	}
+	t.Fatalf(
+		"timed out waiting for repair to reach state %s, last state was: %s",
+		state,
+		actual.State,
+	)
+	return nil
+}
+
+func waitForSegmentsStarted(t *testing.T, client Client, run *RepairRun) map[uuid.UUID]*RepairSegment {
+	success := assert.Eventually(
+		t,
+		func() bool {
+			segments, err := client.RepairRunSegments(context.Background(), run.Id)
+			if err != nil {
+				return false
+			}
+			for _, segment := range segments {
+				if segment.State == RepairSegmentStateNotStarted {
+					return false
+				}
+			}
+			return true
+		},
+		15*time.Minute,
+		5*time.Second,
+	)
+	segments, err := client.RepairRunSegments(context.Background(), run.Id)
+	require.Nil(t, err)
+	if success {
+		return segments
+	}
+	t.Fatal("timed out waiting for repair to start")
+	return nil
+}
diff --git a/reaper/repair_schedule.go b/reaper/repair_schedule.go
new file mode 100644
index 0000000..e5d64ce
--- /dev/null
+++ b/reaper/repair_schedule.go
@@ -0,0 +1,29 @@
+package reaper
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"net/url"
+)
+
+func (c *client) RepairSchedules(ctx context.Context) ([]RepairSchedule, error) {
+	return c.fetchRepairSchedules(ctx, "/repair_schedule")
+}
+
+func (c *client) RepairSchedulesForCluster(ctx context.Context, clusterName string) ([]RepairSchedule, error) {
+	path := "/repair_schedule/cluster/" + url.PathEscape(clusterName)
+	return c.fetchRepairSchedules(ctx, path)
+}
+
+func (c *client) fetchRepairSchedules(ctx context.Context, path string) ([]RepairSchedule, error) {
+	res, err := c.doGet(ctx, path, nil, http.StatusOK)
+	if err == nil {
+		repairSchedules := make([]RepairSchedule, 0)
+		err = c.readBodyAsJson(res, &repairSchedules)
+		if err == nil {
+			return repairSchedules, nil
+		}
+	}
+	return nil, fmt.Errorf("failed to fetch repair schedules: %w", err)
+}
diff --git a/reaper/types.go b/reaper/types.go
deleted file mode 100644
index 39b2655..0000000
--- a/reaper/types.go
+++ /dev/null
@@ -1,99 +0,0 @@
-package reaper
-
-import "time"
-
-type Cluster struct {
-	Name           string
-	JmxUsername    string
-	JmxPasswordSet bool
-	Seeds          []string
-	NodeState      NodeState
-}
-
-type NodeState struct {
-	GossipStates []GossipState
-}
-
-type GossipState struct {
-	SourceNode    string
-	EndpointNames []string
-	TotalLoad     float64
-	DataCenters   map[string]DataCenterState
-}
-
-type DataCenterState struct {
-	Name  string
-	Racks map[string]RackState
-}
-
-type RackState struct {
-	Name      string
-	Endpoints []EndpointState
-}
-
-type EndpointState struct {
-	Endpoint       string
-	DataCenter     string
-	Rack           string
-	HostId         string
-	Status         string
-	Severity       float64
-	ReleaseVersion string
-	Tokens         string
-	Load           float64
-}
-
-type GetClusterResult struct {
-	Cluster *Cluster
-	Error   error
-}
-
-type RepairSchedule struct {
-	Id                string    `json:"id"`
-	Owner             string    `json:"owner,omitempty"`
-	State             string    `json:"state,omitempty"`
-	Intensity         float64   `json:"intensity,omitempty"`
-	ClusterName       string    `json:"cluster_name,omitempty"`
-	KeyspaceName      string    `json:"keyspace_name,omitempty"`
-	RepairParallism   string    `json:"repair_parallelism,omitempty"`
-	IncrementalRepair bool      `json:"incremental_repair,omitempty"`
-	ThreadCount       int       `json:"repair_thread_count,omitempty"`
-	UnitId            string    `json:"repair_unit_id,omitempty"`
-	DaysBetween       int       `json:"scheduled_days_between,omitempty"`
-	Created           time.Time `json:"creation_time,omitempty"`
-	Paused            time.Time `json:"pause_time,omitempty"`
-	NextActivation    time.Time `json:"next_activation,omitempty"`
-}
-
-// All the following types are used internally by the client and not part of the public API
-
-type clusterStatus struct {
-	Name           string     `json:"name"`
-	JmxUsername    string     `json:"jmx_username,omitempty"`
-	JmxPasswordSet bool       `json:"jmx_password_is_set,omitempty"`
-	Seeds          []string   `json:"seed_hosts,omitempty"`
-	NodeStatus     nodeStatus `json:"nodes_status"`
-}
-
-type nodeStatus struct {
-	EndpointStates []gossipStatus `json:"endpointStates,omitempty"`
-}
-
-type gossipStatus struct {
-	SourceNode    string   `json:"sourceNode"`
-	EndpointNames []string `json:"endpointNames,omitempty"`
-	TotalLoad     float64  `json:"totalLoad,omitempty"`
-	Endpoints     map[string]map[string][]endpointStatus
-}
-
-type endpointStatus struct {
-	Endpoint       string  `json:"endpoint"`
-	DataCenter     string  `json:"dc"`
-	Rack           string  `json:"rack"`
-	HostId         string  `json:"hostId"`
-	Status         string  `json:"status"`
-	Severity       float64 `json:"severity"`
-	ReleaseVersion string  `json:"releaseVersion"`
-	Tokens         string  `json:"tokens"`
-	Load           float64 `json:"load"`
-}
diff --git a/testenv/util.go b/testenv/util.go
index 083dc7b..8f5db59 100644
--- a/testenv/util.go
+++ b/testenv/util.go
@@ -1,10 +1,13 @@
 package testenv
 
 import (
+	"context"
 	"fmt"
 	"io/ioutil"
+	"math/rand"
 	"os"
 	"os/exec"
+	"path"
 	"path/filepath"
 	"regexp"
 	"testing"
@@ -13,7 +16,7 @@ import (
 
 var cassandraReadyStatusRegex = regexp.MustCompile(`\nUN `)
 
-// Stops all services declared in docker-compose.yaml. This function blocks until the
+// StopServices stops all services declared in docker-compose.yaml. This function blocks until the
 // operation completes.
 func StopServices(t *testing.T) error {
 	t.Log("stopping services")
@@ -21,29 +24,29 @@ func StopServices(t *testing.T) error {
 	return stopServices.Run()
 }
 
-// Starts all services declared in docker-compose.yaml in detached mode.
+// StartServices starts all services declared in docker-compose.yaml in detached mode.
 func StartServices(t *testing.T) error {
 	t.Log("starting services")
 	startServices := exec.Command("docker-compose", "up", "-d")
 	return startServices.Run()
 }
 
-// Deletes all contents under PROJECT_ROOT/data/cassandra
+// PurgeCassandraDataDir deletes all contents under PROJECT_ROOT/data/cassandra
 func PurgeCassandraDataDir(t *testing.T) error {
-	t.Log("puring cassandra data dir")
-	cassandrDataDir, err := filepath.Abs("../data/cassandra")
+	t.Log("purging cassandra data dir")
+	cassandraDataDir, err := filepath.Abs("../data/cassandra")
 	if err != nil {
 		return fmt.Errorf("failed to get path of cassandra data dir: %w", err)
 	}
-	if err := os.RemoveAll(cassandrDataDir); err != nil {
-		return fmt.Errorf("failed to purge %s: %w", cassandrDataDir, err)
+	if err := os.RemoveAll(cassandraDataDir); err != nil {
+		return fmt.Errorf("failed to purge %s: %w", cassandraDataDir, err)
 	}
 	return nil
 }
 
-// A convenience function that does the following:
+// ResetServices is a convenience function that does the following:
 //
 // * stop all services
 // * purge cassandra data directory
@@ -100,38 +103,188 @@ func checkCassandraStatus(seed string) ([]byte, error) {
 	return bytes, nil
 }
 
-// Runs nodetool status against the seed node. Blocks until numNodes nodes report a status of UN.
-// This function will perform a max of 10 checks with a delay of one second between retries.
-func WaitForClusterReady(t *testing.T, seed string, numNodes int) error {
-	// TODO make the number of checks configurable
-	for i := 0; i < 10; i++ {
-		t.Logf("checking cassandra cluster status with seed %s", seed)
-		bytes, err := checkCassandraStatus(seed)
-		if err == nil {
-			matches := cassandraReadyStatusRegex.FindAll(bytes, -1)
-			if matches != nil && len(matches) == numNodes {
-				return nil
+// WaitForClusterReady runs nodetool status against the seed node. Blocks until numNodes nodes report a status of UN.
+func WaitForClusterReady(ctx context.Context, seed string, numNodes int) error {
+	// TODO make the timeout configurable
+	ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
+	defer cancel()
+	for {
+		select {
+		default:
+			b, err := checkCassandraStatus(seed)
+			if err == nil {
+				matches := cassandraReadyStatusRegex.FindAll(b, -1)
+				if matches != nil && len(matches) == numNodes {
+					return nil
+				}
 			}
+			// TODO make the duration configurable
+			time.Sleep(time.Second)
+		case <-ctx.Done():
+			return fmt.Errorf("timed out waiting for nodetool status with seed (%s)", seed)
 		}
-		// TODO make the duration configurable
-		time.Sleep(1 * time.Second)
 	}
-
-	return fmt.Errorf("timed out waiting for nodetool status with seed (%s)", seed)
 }
 
-// Adds a cluster to Reaper without using the client.
-func AddCluster(t *testing.T, cluster string, seed string) error {
-	t.Logf("registering cluster %s", cluster)
+// AddCluster adds a cluster to Reaper without using the client.
+func AddCluster(ctx context.Context, cluster string, seed string) error {
 	relPath := "../scripts/add-cluster.sh"
-	path, err := filepath.Abs(relPath)
+	p, err := filepath.Abs(relPath)
 	if err != nil {
 		return fmt.Errorf("failed to get absolute path of (%s): %w", relPath, err)
 	}
-	script := exec.Command(path, cluster, seed)
-	if err = script.Run(); err != nil {
-		return fmt.Errorf("add cluster script (%s) failed with seed (%s): %w", path, seed, err)
+	return exec.CommandContext(ctx, p, cluster, seed).Run()
+}
+
+func WaitForCqlReady(ctx context.Context, seed string) error {
+	ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
+	defer cancel()
+	for {
+		select {
+		default:
+			err := checkCqlStatus(ctx, seed)
+			if err == nil {
+				return nil
+			}
+			time.Sleep(time.Second)
+		case <-ctx.Done():
+			return fmt.Errorf("timed out waiting for CQL readiness with seed (%s)", seed)
+		}
 	}
+}
-
-	return nil
-}
\ No newline at end of file
+
+func checkCqlStatus(ctx context.Context, node string) error {
+	s := "SELECT release_version FROM system.local"
+	checkStatus := exec.CommandContext(
+		ctx,
+		"docker-compose",
+		"exec",
+		"-T",
+		node,
+		"cqlsh",
+		"-u",
+		"reaperUser",
+		"-p",
+		"reaperPass",
+		"-e",
+		s,
+		node,
+		"9042",
+	)
+	return checkStatus.Run()
+}
+
+func CreateKeyspace(ctx context.Context, node string, keyspace string, rf int) error {
+	stmt := fmt.Sprintf(
+		"CREATE KEYSPACE IF NOT EXISTS \"%s\" "+
+			"WITH replication = {'class':'NetworkTopologyStrategy', 'datacenter1':%d} "+
+			"AND durable_writes = true",
+		keyspace,
+		rf,
+	)
+	createKeyspace := exec.CommandContext(
+		ctx,
+		"docker-compose",
+		"exec",
+		"-T",
+		node,
+		"cqlsh",
+		"-u",
+		"reaperUser",
+		"-p",
+		"reaperPass",
+		"-e",
+		stmt,
+		node,
+		"9042",
+	)
+	return createKeyspace.Run()
+}
+
+func CreateTable(ctx context.Context, node string, keyspace string, table string) error {
+	stmt := fmt.Sprintf(
+		"CREATE TABLE IF NOT EXISTS \"%s\".\"%s\" "+
+			"(pk int, cc timeuuid, v text, "+
+			"PRIMARY KEY (pk, cc))",
+		keyspace,
+		table,
+	)
+	createTable := exec.CommandContext(
+		ctx,
+		"docker-compose",
+		"exec",
+		"-T",
+		node,
+		"cqlsh",
+		"-u",
+		"reaperUser",
+		"-p",
+		"reaperPass",
+		"-e",
+		stmt,
+		node,
+		"9042",
+	)
+	return createTable.Run()
+}
+
+func CreateCqlInsertScript(keyspace string, table string) (*os.File, error) {
+	script, err := ioutil.TempFile(os.TempDir(), "insert-*.cql")
+	if err != nil {
+		return nil, err
+	}
+	defer func(script *os.File) {
+		_ = script.Close()
+	}(script)
+	for pk := 0; pk < 1000; pk++ {
+		insert := fmt.Sprintf(
+			"INSERT INTO \"%s\".\"%s\" (pk, cc, v) VALUES (%d, now(), '%s');\n",
+			keyspace,
+			table,
+			pk,
+			randomString(10, 100),
+		)
+		if _, err = script.WriteString(insert); err != nil {
+			return nil, err
+		}
+	}
+	return script, nil
+}
+
+func ExecuteCqlScript(ctx context.Context, node string, script *os.File) error {
+	remotePath := "/tmp/" + path.Base(script.Name())
+	copyScript := exec.Command("docker", "cp", script.Name(), "reaper-client-go_"+node+"_1:"+remotePath)
+	err := copyScript.Run()
+	if err != nil {
+		return err
+	}
+	execScript := exec.CommandContext(
+		ctx,
+		"docker-compose",
+		"exec",
+		"-T",
+		node,
+		"cqlsh",
+		"-u",
+		"reaperUser",
+		"-p",
+		"reaperPass",
+		"-f",
+		remotePath,
+		node,
+		"9042",
+	)
+	return execScript.Run()
+}
+
+var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
+
+func randomString(min int, max int) string {
+	rand.Seed(time.Now().UnixNano())
+	length := rand.Intn(max-min+1) + min
+	s := make([]rune, length)
+	for i := range s {
+		s[i] = letters[rand.Intn(len(letters))]
+	}
+	return string(s)
+}
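
Usage sketch (not part of the patch): the snippet below shows how the repair-run API added above might be driven end to end, mirroring what the integration tests do. The NewClient constructor is an assumption here, since the factory function is defined outside this hunk, and "cluster-1" / "my_keyspace" are placeholder names; the method names, option fields, and states are the ones introduced in this diff.

package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/k8ssandra/reaper-client-go/reaper"
)

func main() {
	ctx := context.Background()
	// Assumption: a constructor along these lines exists; it is not part of this diff.
	client, err := reaper.NewClient("http://localhost:8080")
	if err != nil {
		log.Fatal(err)
	}
	// Creating a repair run only generates the segments; it does not start repairing.
	runId, err := client.CreateRepairRun(ctx, "cluster-1", "my_keyspace", "Alice",
		&reaper.RepairRunCreateOptions{
			Tables:              []string{"table1"},
			SegmentCountPerNode: 3,
			Intensity:           0.5,
		})
	if err != nil {
		log.Fatal(err)
	}
	// Kick off the run, then poll until it reaches the DONE state.
	if err := client.StartRepairRun(ctx, runId); err != nil {
		log.Fatal(err)
	}
	for {
		run, err := client.RepairRun(ctx, runId)
		if err != nil {
			log.Fatal(err)
		}
		if run.State == reaper.RepairRunStateDone {
			fmt.Printf("repair run %v finished: %d/%d segments repaired\n",
				runId, run.SegmentsRepaired, run.TotalSegments)
			break
		}
		time.Sleep(5 * time.Second)
	}
}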