diff --git a/pkg/cmd/scylla-manager/server.go b/pkg/cmd/scylla-manager/server.go index 8c9a40c22..580683132 100644 --- a/pkg/cmd/scylla-manager/server.go +++ b/pkg/cmd/scylla-manager/server.go @@ -164,7 +164,7 @@ func (s *server) makeServices(ctx context.Context) error { func (s *server) onClusterChange(ctx context.Context, c cluster.Change) error { switch c.Type { case cluster.Update: - s.configCacheSvc.ForceUpdateCluster(ctx, c.ID) + go s.configCacheSvc.ForceUpdateCluster(ctx, c.ID) case cluster.Create: s.configCacheSvc.ForceUpdateCluster(ctx, c.ID) for _, t := range makeAutoHealthCheckTasks(c.ID) { diff --git a/pkg/scyllaclient/client_scylla.go b/pkg/scyllaclient/client_scylla.go index 193fb9e15..03210b83f 100644 --- a/pkg/scyllaclient/client_scylla.go +++ b/pkg/scyllaclient/client_scylla.go @@ -178,6 +178,15 @@ func (c *Client) Datacenters(ctx context.Context) (map[string][]string, error) { return res, errs } +// GossiperEndpointLiveGet finds live nodes (according to gossiper). +func (c *Client) GossiperEndpointLiveGet(ctx context.Context) ([]string, error) { + live, err := c.scyllaOps.GossiperEndpointLiveGet(&operations.GossiperEndpointLiveGetParams{Context: ctx}) + if err != nil { + return nil, err + } + return live.GetPayload(), nil +} + // HostDatacenter looks up the datacenter that the given host belongs to. func (c *Client) HostDatacenter(ctx context.Context, host string) (dc string, err error) { // Try reading from cache diff --git a/pkg/scyllaclient/retry_integration_test.go b/pkg/scyllaclient/retry_integration_test.go index 5e37b9996..db81d62c5 100644 --- a/pkg/scyllaclient/retry_integration_test.go +++ b/pkg/scyllaclient/retry_integration_test.go @@ -57,7 +57,7 @@ func TestRetryWithTimeoutIntegration(t *testing.T) { test := table[i] t.Run(fmt.Sprintf("block %d nodes", test.block), func(t *testing.T) { - if err := testRetry(hosts, test.block, test.timeout); err != nil { + if err := testRetry(t, hosts, test.block, test.timeout); err != nil { t.Fatal(err) } }) @@ -72,12 +72,12 @@ func allHosts() ([]string, error) { return client.Hosts(context.Background()) } -func testRetry(hosts []string, n int, shouldTimeout bool) error { +func testRetry(t *testing.T, hosts []string, n int, shouldTimeout bool) error { blockedHosts := make([]string, 0, len(hosts)) block := func(ctx context.Context, hosts []string) error { for _, h := range hosts { - err := RunIptablesCommand(h, CmdBlockScyllaREST) + err := RunIptablesCommand(t, h, CmdBlockScyllaREST) if err != nil { return err } @@ -88,7 +88,7 @@ func testRetry(hosts []string, n int, shouldTimeout bool) error { unblock := func(ctx context.Context) error { for _, h := range blockedHosts { - err := RunIptablesCommand(h, CmdUnblockScyllaREST) + err := RunIptablesCommand(t, h, CmdUnblockScyllaREST) if err != nil { return err } diff --git a/pkg/service/backup/service_backup_integration_test.go b/pkg/service/backup/service_backup_integration_test.go index 610ab0f63..0b25c39aa 100644 --- a/pkg/service/backup/service_backup_integration_test.go +++ b/pkg/service/backup/service_backup_integration_test.go @@ -859,10 +859,10 @@ func TestBackupWithNodesDownIntegration(t *testing.T) { WriteData(t, clusterSession, testKeyspace, 1) Print("Given: downed node") - if err := RunIptablesCommand(IPFromTestNet("11"), CmdBlockScyllaREST); err != nil { + if err := RunIptablesCommand(t, IPFromTestNet("11"), CmdBlockScyllaREST); err != nil { t.Fatal(err) } - defer RunIptablesCommand(IPFromTestNet("11"), CmdUnblockScyllaREST) + defer RunIptablesCommand(t, 
IPFromTestNet("11"), CmdUnblockScyllaREST) Print("When: get target") target := backup.Target{ diff --git a/pkg/service/cluster/service.go b/pkg/service/cluster/service.go index 353d67fea..e06225c84 100644 --- a/pkg/service/cluster/service.go +++ b/pkg/service/cluster/service.go @@ -9,6 +9,7 @@ import ( "fmt" "sort" "strconv" + "sync" "time" "github.com/gocql/gocql" @@ -34,7 +35,10 @@ type ProviderFunc func(ctx context.Context, id uuid.UUID) (*Cluster, error) type ChangeType int8 // ErrNoValidKnownHost is thrown when it was not possible to connect to any of the currently known hosts of the cluster. -var ErrNoValidKnownHost = errors.New("unable to connect to any of cluster's known hosts") +var ( + ErrNoValidKnownHost = errors.New("unable to connect to any of cluster's known hosts") + ErrNoLiveHostAvailable = errors.New("no single live host available") +) // ChangeType enumeration. const ( @@ -156,7 +160,7 @@ func (s *Service) clientConfig(c *Cluster) scyllaclient.Config { } func (s *Service) discoverAndSetClusterHosts(ctx context.Context, c *Cluster) error { - knownHosts, err := s.discoverClusterHosts(ctx, c) + knownHosts, _, err := s.discoverClusterHosts(ctx, c) if err != nil { if errors.Is(err, ErrNoValidKnownHost) { s.logger.Error(ctx, "There is no single valid known host for the cluster. "+ @@ -171,55 +175,125 @@ func (s *Service) discoverAndSetClusterHosts(ctx context.Context, c *Cluster) er return errors.Wrap(s.setKnownHosts(c, knownHosts), "update known_hosts in SM DB") } -func (s *Service) discoverClusterHosts(ctx context.Context, c *Cluster) ([]string, error) { - var contactPoints []string +const ( + discoverClusterHostsTimeout = 5 * time.Second +) + +func (s *Service) discoverClusterHosts(ctx context.Context, c *Cluster) (knownHosts, liveHosts []string, err error) { if c.Host != "" { - contactPoints = append(contactPoints, c.Host) // Go with the designated contact point first + knownHosts, liveHosts, err := s.discoverClusterHostUsingCoordinator(ctx, c, discoverClusterHostsTimeout, c.Host) + if err != nil { + s.logger.Error(ctx, "Couldn't discover hosts using stored coordinator host, proceeding with other known ones", + "coordinator-host", c.Host, "error", err) + } else { + return knownHosts, liveHosts, nil + } } else { s.logger.Error(ctx, "Missing --host flag. Using only previously discovered hosts instead", "cluster ID", c.ID) } - contactPoints = append(contactPoints, c.KnownHosts...) 
// In case it failed, try to contact previously discovered hosts + if len(c.KnownHosts) < 1 { + return nil, nil, ErrNoValidKnownHost + } - for _, cp := range contactPoints { - if cp == "" { - s.logger.Error(ctx, "Empty contact point", "cluster ID", c.ID, "contact points", contactPoints) - continue - } + wg := sync.WaitGroup{} + type hostsTuple struct { + live, known []string + } + result := make(chan hostsTuple, len(c.KnownHosts)) + discoverContext, discoverCancel := context.WithCancel(ctx) + defer discoverCancel() - config := scyllaclient.DefaultConfigWithTimeout(s.timeoutConfig) - if c.Port != 0 { - config.Port = strconv.Itoa(c.Port) - } - config.AuthToken = c.AuthToken - config.Hosts = []string{cp} + for _, cp := range c.KnownHosts { + wg.Add(1) - client, err := scyllaclient.NewClient(config, s.logger.Named("client")) - if err != nil { - s.logger.Error(ctx, "Couldn't connect to contact point", "contact point", cp, "error", err) - continue - } + go func(host string) { + defer wg.Done() - knownHosts, err := s.discoverHosts(ctx, client) - logutil.LogOnError(ctx, s.logger, client.Close, "Couldn't close scylla client") - if err != nil { - s.logger.Error(ctx, "Couldn't discover hosts", "host", cp, "error", err) - continue - } - return knownHosts, nil + knownHosts, liveHosts, err := s.discoverClusterHostUsingCoordinator(discoverContext, c, discoverClusterHostsTimeout, host) + if err != nil { + // Only log if the context hasn't been canceled + if !errors.Is(discoverContext.Err(), context.Canceled) { + s.logger.Error(ctx, "Couldn't discover hosts", "host", host, "error", err) + } + return + } + result <- hostsTuple{ + live: liveHosts, + known: knownHosts, + } + }(cp) + } + + go func() { + wg.Wait() + close(result) + }() + + // Read results until the channel is closed + hosts, ok := <-result + if ok { + return hosts.known, hosts.live, nil + } + + // If no valid results, return error + return nil, nil, ErrNoValidKnownHost +} + +func (s *Service) discoverClusterHostUsingCoordinator(ctx context.Context, c *Cluster, apiCallTimeout time.Duration, + host string, +) (knownHosts, liveHosts []string, err error) { + config := scyllaclient.DefaultConfigWithTimeout(s.timeoutConfig) + if c.Port != 0 { + config.Port = strconv.Itoa(c.Port) } + config.Timeout = apiCallTimeout + config.AuthToken = c.AuthToken + config.Hosts = []string{host} - return nil, ErrNoValidKnownHost + client, err := scyllaclient.NewClient(config, s.logger.Named("client")) + if err != nil { + return nil, nil, err + } + defer logutil.LogOnError(ctx, s.logger, client.Close, "Couldn't close scylla client") + + liveHosts, err = client.GossiperEndpointLiveGet(ctx) + if err != nil { + return nil, nil, err + } + knownHosts, err = s.discoverHosts(ctx, client, liveHosts) + if err != nil { + return nil, nil, err + } + return knownHosts, liveHosts, nil } // discoverHosts returns a list of all hosts sorted by DC speed. This is // an optimisation for Epsilon-Greedy host pool used internally by // scyllaclient.Client that makes it use supposedly faster hosts first. 
-func (s *Service) discoverHosts(ctx context.Context, client *scyllaclient.Client) (hosts []string, err error) { +func (s *Service) discoverHosts(ctx context.Context, client *scyllaclient.Client, liveHosts []string) (hosts []string, err error) { + if len(liveHosts) == 0 { + return nil, ErrNoLiveHostAvailable + } + dcs, err := client.Datacenters(ctx) if err != nil { return nil, err } - closest, err := client.ClosestDC(ctx, dcs) + // remove dead nodes from the map + liveSet := make(map[string]struct{}) + for _, host := range liveHosts { + liveSet[host] = struct{}{} + } + filteredDCs := make(map[string][]string) + for dc, hosts := range dcs { + for _, host := range hosts { + if _, isLive := liveSet[host]; isLive { + filteredDCs[dc] = append(filteredDCs[dc], host) + } + } + } + + closest, err := client.ClosestDC(ctx, filteredDCs) if err != nil { return nil, err } @@ -397,7 +471,7 @@ func (s *Service) PutCluster(ctx context.Context, c *Cluster) (err error) { } // Check hosts connectivity. - if err := s.validateHostsConnectivity(ctx, c); err != nil { + if err := s.ValidateHostsConnectivity(ctx, c); err != nil { var tip string switch scyllaclient.StatusCodeOf(err) { case 0: @@ -487,36 +561,34 @@ func (s *Service) PutCluster(ctx context.Context, c *Cluster) (err error) { return s.notifyChangeListener(ctx, changeEvent) } -func (s *Service) validateHostsConnectivity(ctx context.Context, c *Cluster) error { +// ValidateHostsConnectivity validates that scylla manager agent API is available and responding on all live hosts. +// Hosts are discovered using cluster.host + cluster.knownHosts saved to the manager's database. +func (s *Service) ValidateHostsConnectivity(ctx context.Context, c *Cluster) error { if err := s.loadKnownHosts(c); err != nil && !errors.Is(err, gocql.ErrNotFound) { return errors.Wrap(err, "load known hosts") } - knownHosts, err := s.discoverClusterHosts(ctx, c) + knownHosts, liveHosts, err := s.discoverClusterHosts(ctx, c) if err != nil { return errors.Wrap(err, "discover cluster hosts") } c.KnownHosts = knownHosts + if len(liveHosts) == 0 { + return util.ErrValidate(errors.New("no live nodes")) + } + config := s.clientConfig(c) + config.Hosts = liveHosts client, err := scyllaclient.NewClient(config, s.logger.Named("client")) if err != nil { return err } defer logutil.LogOnError(ctx, s.logger, client.Close, "Couldn't close scylla client") - status, err := client.Status(ctx) - if err != nil { - return errors.Wrap(err, "cluster status") - } - live := status.Live().Hosts() - if len(live) == 0 { - return util.ErrValidate(errors.New("no live nodes")) - } - var errs error - for i, err := range client.CheckHostsConnectivity(ctx, live) { - errs = multierr.Append(errs, errors.Wrap(err, live[i])) + for i, err := range client.CheckHostsConnectivity(ctx, liveHosts) { + errs = multierr.Append(errs, errors.Wrap(err, liveHosts[i])) } if errs != nil { return util.ErrValidate(errors.Wrap(errs, "connectivity check")) diff --git a/pkg/service/cluster/service_integration_test.go b/pkg/service/cluster/service_integration_test.go index e70c6f18e..71a93deef 100644 --- a/pkg/service/cluster/service_integration_test.go +++ b/pkg/service/cluster/service_integration_test.go @@ -33,6 +33,132 @@ import ( "github.com/scylladb/scylla-manager/v3/pkg/util/uuid" ) +func TestValidateHostConnectivityIntegration(t *testing.T) { + if IsIPV6Network() { + t.Skip("DB node do not have ip6tables and related modules to make it work properly") + } + + Print("given: the fresh cluster") + var ( + ctx = context.Background() + 
session      = CreateScyllaManagerDBSession(t)
+		secretsStore = store.NewTableStore(session, table.Secrets)
+		c            = &cluster.Cluster{
+			AuthToken: "token",
+			Host:      ManagedClusterHost(),
+		}
+	)
+	s, err := cluster.NewService(session, metrics.NewClusterMetrics(), secretsStore, scyllaclient.DefaultTimeoutConfig(),
+		server.DefaultConfig().ClientCacheTimeout, log.NewDevelopment())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = s.PutCluster(context.Background(), c)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	allHosts := ManagedClusterHosts()
+	for _, tc := range []struct {
+		name      string
+		hostsDown []string
+		result    error
+		timeout   time.Duration
+	}{
+		{
+			name:      "coordinator host is DOWN",
+			hostsDown: []string{ManagedClusterHost()},
+			result:    nil,
+			timeout:   6 * time.Second,
+		},
+		{
+			name:      "only one is UP",
+			hostsDown: allHosts[:len(allHosts)-1],
+			result:    nil,
+			timeout:   6 * time.Second,
+		},
+		{
+			name:      "all hosts are DOWN",
+			hostsDown: allHosts,
+			result:    cluster.ErrNoValidKnownHost,
+			timeout:   11 * time.Second, // the 5-second calls will time out twice
+		},
+		{
+			name:      "all hosts are UP",
+			hostsDown: nil,
+			result:    nil,
+			timeout:   6 * time.Second,
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			defer func() {
+				for _, host := range tc.hostsDown {
+					if err := StartService(host, "scylla"); err != nil {
+						t.Logf("error starting the stopped scylla service on host={%s}, err={%s}", host, err)
+					}
+					if err := RunIptablesCommand(t, host, CmdUnblockScyllaREST); err != nil {
+						t.Logf("error trying to unblock REST API on host={%s}, err={%s}", host, err)
+					}
+				}
+			}()
+			TryUnblockCQL(t, ManagedClusterHosts())
+			TryUnblockREST(t, ManagedClusterHosts())
+			TryUnblockAlternator(t, ManagedClusterHosts())
+			TryStartAgent(t, ManagedClusterHosts())
+			if err := EnsureNodesAreUP(t, ManagedClusterHosts(), time.Minute); err != nil {
+				t.Fatalf("not all nodes are UP, err = {%v}", err)
+			}
+
+			Printf("then: validate that the call to validate host connectivity takes less than %v seconds", tc.timeout.Seconds())
+			testCluster, err := s.GetClusterByID(context.Background(), c.ID)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if err := callValidateHostConnectivityWithTimeout(ctx, s, tc.timeout, testCluster); err != nil {
+				t.Fatal(err)
+			}
+			Printf("when: the scylla service is stopped and the scylla API is timing out on some hosts")
+			// Blocking the Scylla REST API makes the clients hang when they call the API.
+			// The Scylla service must also be stopped to make the node report DOWN status;
+			// blocking the REST API alone is not enough.
+ for _, host := range tc.hostsDown { + if err := StopService(host, "scylla"); err != nil { + t.Fatal(err) + } + if err := RunIptablesCommand(t, host, CmdBlockScyllaREST); err != nil { + t.Error(err) + } + } + + Printf("then: validate that call still takes less than %v seconds", tc.timeout.Seconds()) + if err := callValidateHostConnectivityWithTimeout(ctx, s, tc.timeout, testCluster); !errors.Is(err, tc.result) { + t.Fatal(err) + } + }) + } +} + +func callValidateHostConnectivityWithTimeout(ctx context.Context, s *cluster.Service, timeout time.Duration, + c *cluster.Cluster) error { + + callCtx, cancel := context.WithCancel(ctx) + defer cancel() + + done := make(chan error) + go func() { + done <- s.ValidateHostsConnectivity(callCtx, c) + }() + + select { + case <-time.After(timeout): + cancel() + return fmt.Errorf("expected s.ValidateHostsConnectivity to complete in less than %v seconds, time exceeded", timeout.Seconds()) + case err := <-done: + return err + } +} + func TestClientIntegration(t *testing.T) { expectedHosts := ManagedClusterHosts() @@ -553,10 +679,10 @@ func TestServiceStorageIntegration(t *testing.T) { c := validCluster() c.Host = h1 - if err := RunIptablesCommand(h2, CmdBlockScyllaREST); err != nil { + if err := RunIptablesCommand(t, h2, CmdBlockScyllaREST); err != nil { t.Fatal(err) } - defer RunIptablesCommand(h2, CmdUnblockScyllaREST) + defer RunIptablesCommand(t, h2, CmdUnblockScyllaREST) if err := s.PutCluster(ctx, c); err == nil { t.Fatal("expected put cluster to fail because of connectivity issues") diff --git a/pkg/service/healthcheck/service_integration_test.go b/pkg/service/healthcheck/service_integration_test.go index 4744e2414..894807bbc 100644 --- a/pkg/service/healthcheck/service_integration_test.go +++ b/pkg/service/healthcheck/service_integration_test.go @@ -49,10 +49,13 @@ func TestStatus_Ping_Independent_From_REST_Integration(t *testing.T) { } // Given - tryUnblockCQL(t, ManagedClusterHosts()) - tryUnblockREST(t, ManagedClusterHosts()) - tryUnblockAlternator(t, ManagedClusterHosts()) - tryStartAgent(t, ManagedClusterHosts()) + TryUnblockCQL(t, ManagedClusterHosts()) + TryUnblockREST(t, ManagedClusterHosts()) + TryUnblockAlternator(t, ManagedClusterHosts()) + TryStartAgent(t, ManagedClusterHosts()) + if err := EnsureNodesAreUP(t, ManagedClusterHosts(), time.Minute); err != nil { + t.Fatalf("not all nodes are UP, err = {%v}", err) + } logger := log.NewDevelopmentWithLevel(zapcore.InfoLevel).Named("healthcheck") @@ -117,8 +120,8 @@ func TestStatus_Ping_Independent_From_REST_Integration(t *testing.T) { } // When #2 -> one of the hosts has unresponsive REST API - defer unblockREST(t, hostWithUnresponsiveREST) - blockREST(t, hostWithUnresponsiveREST) + defer UnblockREST(t, hostWithUnresponsiveREST) + BlockREST(t, hostWithUnresponsiveREST) // Then #2 -> only REST ping fails, CQL and Alternator are fine status, err = healthSvc.Status(context.Background(), testCluster.ID) @@ -211,16 +214,16 @@ func testStatusIntegration(t *testing.T, clusterID uuid.UUID, clusterSvc cluster // Tests here do not test the dynamic t/o functionality c := DefaultConfig() - tryUnblockCQL(t, ManagedClusterHosts()) - tryUnblockREST(t, ManagedClusterHosts()) - tryUnblockAlternator(t, ManagedClusterHosts()) - tryStartAgent(t, ManagedClusterHosts()) + TryUnblockCQL(t, ManagedClusterHosts()) + TryUnblockREST(t, ManagedClusterHosts()) + TryUnblockAlternator(t, ManagedClusterHosts()) + TryStartAgent(t, ManagedClusterHosts()) defer func() { - tryUnblockCQL(t, ManagedClusterHosts()) - 
tryUnblockREST(t, ManagedClusterHosts()) - tryUnblockAlternator(t, ManagedClusterHosts()) - tryStartAgent(t, ManagedClusterHosts()) + TryUnblockCQL(t, ManagedClusterHosts()) + TryUnblockREST(t, ManagedClusterHosts()) + TryUnblockAlternator(t, ManagedClusterHosts()) + TryStartAgent(t, ManagedClusterHosts()) }() hrt := NewHackableRoundTripper(scyllaclient.DefaultTransport()) @@ -283,8 +286,8 @@ func testStatusIntegration(t *testing.T, clusterID uuid.UUID, clusterSvc cluster t.Run("node REST TIMEOUT", func(t *testing.T) { host := IPFromTestNet("12") - blockREST(t, host) - defer unblockREST(t, host) + BlockREST(t, host) + defer UnblockREST(t, host) status, err := s.Status(context.Background(), clusterID) if err != nil { @@ -309,8 +312,8 @@ func testStatusIntegration(t *testing.T, clusterID uuid.UUID, clusterSvc cluster t.Run("node CQL TIMEOUT", func(t *testing.T) { host := IPFromTestNet("12") - blockCQL(t, host, sslEnabled) - defer unblockCQL(t, host, sslEnabled) + BlockCQL(t, host, sslEnabled) + defer UnblockCQL(t, host, sslEnabled) status, err := s.Status(context.Background(), clusterID) if err != nil { @@ -335,8 +338,8 @@ func testStatusIntegration(t *testing.T, clusterID uuid.UUID, clusterSvc cluster t.Run("node Alternator TIMEOUT", func(t *testing.T) { host := IPFromTestNet("12") - blockAlternator(t, host) - defer unblockAlternator(t, host) + BlockAlternator(t, host) + defer UnblockAlternator(t, host) status, err := s.Status(context.Background(), clusterID) if err != nil { @@ -361,8 +364,8 @@ func testStatusIntegration(t *testing.T, clusterID uuid.UUID, clusterSvc cluster t.Run("node REST DOWN", func(t *testing.T) { host := IPFromTestNet("12") - stopAgent(t, host) - defer startAgent(t, host) + StopAgent(t, host) + defer StartAgent(t, host) status, err := s.Status(context.Background(), clusterID) if err != nil { @@ -440,11 +443,11 @@ func testStatusIntegration(t *testing.T, clusterID uuid.UUID, clusterSvc cluster defer cancel() for _, h := range ManagedClusterHosts() { - blockREST(t, h) + BlockREST(t, h) } defer func() { for _, h := range ManagedClusterHosts() { - unblockREST(t, h) + UnblockREST(t, h) } }() @@ -468,100 +471,6 @@ func testStatusIntegration(t *testing.T, clusterID uuid.UUID, clusterSvc cluster }) } -func blockREST(t *testing.T, h string) { - t.Helper() - if err := RunIptablesCommand(h, CmdBlockScyllaREST); err != nil { - t.Error(err) - } -} - -func unblockREST(t *testing.T, h string) { - t.Helper() - if err := RunIptablesCommand(h, CmdUnblockScyllaREST); err != nil { - t.Error(err) - } -} - -func tryUnblockREST(t *testing.T, hosts []string) { - t.Helper() - for _, host := range hosts { - _ = RunIptablesCommand(host, CmdUnblockScyllaREST) - } -} - -func blockCQL(t *testing.T, h string, sslEnabled bool) { - t.Helper() - cmd := CmdBlockScyllaCQL - if sslEnabled { - cmd = CmdBlockScyllaCQLSSL - } - if err := RunIptablesCommand(h, cmd); err != nil { - t.Error(err) - } -} - -func unblockCQL(t *testing.T, h string, sslEnabled bool) { - t.Helper() - cmd := CmdUnblockScyllaCQL - if sslEnabled { - cmd = CmdUnblockScyllaCQLSSL - } - if err := RunIptablesCommand(h, cmd); err != nil { - t.Error(err) - } -} - -func tryUnblockCQL(t *testing.T, hosts []string) { - t.Helper() - for _, host := range hosts { - _ = RunIptablesCommand(host, CmdUnblockScyllaCQL) - } -} - -func blockAlternator(t *testing.T, h string) { - t.Helper() - if err := RunIptablesCommand(h, CmdBlockScyllaAlternator); err != nil { - t.Error(err) - } -} - -func unblockAlternator(t *testing.T, h string) { - t.Helper() - if 
err := RunIptablesCommand(h, CmdUnblockScyllaAlternator); err != nil {
-		t.Error(err)
-	}
-}
-
-func tryUnblockAlternator(t *testing.T, hosts []string) {
-	t.Helper()
-	for _, host := range hosts {
-		_ = RunIptablesCommand(host, CmdUnblockScyllaAlternator)
-	}
-}
-
-const agentService = "scylla-manager-agent"
-
-func stopAgent(t *testing.T, h string) {
-	t.Helper()
-	if err := StopService(h, agentService); err != nil {
-		t.Error(err)
-	}
-}
-
-func startAgent(t *testing.T, h string) {
-	t.Helper()
-	if err := StartService(h, agentService); err != nil {
-		t.Error(err)
-	}
-}
-
-func tryStartAgent(t *testing.T, hosts []string) {
-	t.Helper()
-	for _, host := range hosts {
-		_ = StartService(host, agentService)
-	}
-}
-
 const pingPath = "/storage_service/scylla_release_version"
 
 func fakeHealthCheckStatus(host string, code int) http.RoundTripper {
diff --git a/pkg/testutils/exec.go b/pkg/testutils/exec.go
index 624dea384..556ca8b1c 100644
--- a/pkg/testutils/exec.go
+++ b/pkg/testutils/exec.go
@@ -4,10 +4,15 @@ package testutils
 
 import (
 	"bytes"
+	"fmt"
 	"net"
 	"strings"
+	"sync"
+	"testing"
+	"time"
 
 	"github.com/pkg/errors"
+	"go.uber.org/multierr"
 	"golang.org/x/crypto/ssh"
 )
 
@@ -35,6 +40,9 @@ const (
 
 	// CmdUnblockScyllaAlternator defines the command used for unblocking the Scylla Alternator access.
 	CmdUnblockScyllaAlternator = "iptables -D INPUT -p tcp --destination-port 8000 -j DROP"
+
+	// CmdOrTrueAppend allows a shell command to fail and still lets the caller proceed.
+	CmdOrTrueAppend = " || true"
 )
 
 func makeIPV6Rule(rule string) string {
@@ -42,16 +50,21 @@
 }
 
 // RunIptablesCommand executes iptables command, repeats same command for IPV6 iptables rule.
-func RunIptablesCommand(host, cmd string) error {
+func RunIptablesCommand(t *testing.T, host, cmd string) error {
+	t.Helper()
 	if IsIPV6Network() {
-		return ExecOnHostStatus(host, makeIPV6Rule(cmd))
+		return ExecOnHostStatus(t, host, makeIPV6Rule(cmd))
 	}
-	return ExecOnHostStatus(host, cmd)
+	return ExecOnHostStatus(t, host, cmd)
 }
 
 // ExecOnHostStatus executes the given command on the given host and returns on error.
-func ExecOnHostStatus(host, cmd string) error {
-	_, _, err := ExecOnHost(host, cmd)
+func ExecOnHostStatus(t *testing.T, host, cmd string) error {
+	t.Helper()
+	stdOut, stdErr, err := ExecOnHost(host, cmd)
+	if err != nil {
+		t.Logf("cmd: {%s}, stdout: {%s}, stderr: {%s}", cmd, stdOut, stdErr)
+	}
 	return errors.Wrapf(err, "run command %s", cmd)
 }
 
@@ -107,3 +120,188 @@ func StartService(h, service string) error {
 	}
 	return nil
 }
+
+// WaitForNodeUPOrTimeout waits until nodetool status reports UN status for the given node.
+// The nodetool status CLI is executed on the same node.
+func WaitForNodeUPOrTimeout(h string, timeout time.Duration) error {
+	nodeIsReady := make(chan struct{})
+	done := make(chan struct{})
+	go func() {
+		defer close(nodeIsReady)
+		for {
+			select {
+			case <-done:
+				return
+			default:
+				stdout, _, err := ExecOnHost(h, "nodetool status | grep "+h)
+				if err != nil {
+					continue
+				}
+				if strings.HasPrefix(stdout, "UN") {
+					return
+				}
+				select {
+				case <-done:
+					return
+				case <-time.After(time.Second):
+				}
+			}
+		}
+	}()
+
+	select {
+	case <-nodeIsReady:
+		return nil
+	case <-time.After(timeout):
+		close(done)
+		return fmt.Errorf("node %s hasn't reached UP status", h)
+	}
+}
+
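// --- Illustrative sketch, not part of this diff ---
// WaitForNodeUPOrTimeout is most useful right after a node was restarted. A minimal
// usage example, assuming the existing StopService/StartService helpers and the
// "scylla" service name used elsewhere in these tests; the helper name
// restartScyllaAndWait is made up for illustration:
//
//	func restartScyllaAndWait(t *testing.T, host string) {
//		t.Helper()
//		if err := StopService(host, "scylla"); err != nil {
//			t.Fatal(err)
//		}
//		if err := StartService(host, "scylla"); err != nil {
//			t.Fatal(err)
//		}
//		// Poll `nodetool status` on the host until it reports UN or the timeout expires.
//		if err := WaitForNodeUPOrTimeout(host, time.Minute); err != nil {
//			t.Fatalf("node %s did not come back up: %v", host, err)
//		}
//	}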
+// BlockREST blocks the Scylla REST API port on the h machine by dropping TCP packets.
+func BlockREST(t *testing.T, h string) {
+	t.Helper()
+	if err := RunIptablesCommand(t, h, CmdBlockScyllaREST); err != nil {
+		t.Error(err)
+	}
+}
+
+// UnblockREST unblocks the Scylla REST API port on the h machine.
+func UnblockREST(t *testing.T, h string) {
+	t.Helper()
+	if err := RunIptablesCommand(t, h, CmdUnblockScyllaREST); err != nil {
+		t.Error(err)
+	}
+}
+
+// TryUnblockREST tries to unblock the Scylla REST API ports on the []hosts machines.
+// Logs an error if the execution failed, but doesn't return it.
+func TryUnblockREST(t *testing.T, hosts []string) {
+	t.Helper()
+	for _, host := range hosts {
+		if err := RunIptablesCommand(t, host, CmdUnblockScyllaREST+CmdOrTrueAppend); err != nil {
+			t.Log(err)
+		}
+	}
+}
+
+// BlockCQL blocks the CQL port on the h machine by dropping TCP packets.
+func BlockCQL(t *testing.T, h string, sslEnabled bool) {
+	t.Helper()
+	cmd := CmdBlockScyllaCQL
+	if sslEnabled {
+		cmd = CmdBlockScyllaCQLSSL
+	}
+	if err := RunIptablesCommand(t, h, cmd); err != nil {
+		t.Error(err)
+	}
+}
+
+// UnblockCQL unblocks the CQL port on the h machine.
+func UnblockCQL(t *testing.T, h string, sslEnabled bool) {
+	t.Helper()
+	cmd := CmdUnblockScyllaCQL
+	if sslEnabled {
+		cmd = CmdUnblockScyllaCQLSSL
+	}
+	if err := RunIptablesCommand(t, h, cmd); err != nil {
+		t.Error(err)
+	}
+}
+
+// TryUnblockCQL tries to unblock the CQL ports on the []hosts machines.
+// Logs an error if the execution failed, but doesn't return it.
+func TryUnblockCQL(t *testing.T, hosts []string) {
+	t.Helper()
+	for _, host := range hosts {
+		if err := RunIptablesCommand(t, host, CmdUnblockScyllaCQL+CmdOrTrueAppend); err != nil {
+			t.Log(err)
+		}
+	}
+}
+
+// BlockAlternator blocks the Scylla Alternator port on the h machine by dropping TCP packets.
+func BlockAlternator(t *testing.T, h string) {
+	t.Helper()
+	if err := RunIptablesCommand(t, h, CmdBlockScyllaAlternator); err != nil {
+		t.Error(err)
+	}
+}
+
+// UnblockAlternator unblocks the Alternator port on the h machine.
+func UnblockAlternator(t *testing.T, h string) {
+	t.Helper()
+	if err := RunIptablesCommand(t, h, CmdUnblockScyllaAlternator); err != nil {
+		t.Error(err)
+	}
+}
+
+// TryUnblockAlternator tries to unblock the Alternator API ports on the []hosts machines.
+// Logs an error if the execution failed, but doesn't return it.
+func TryUnblockAlternator(t *testing.T, hosts []string) {
+	t.Helper()
+	for _, host := range hosts {
+		if err := RunIptablesCommand(t, host, CmdUnblockScyllaAlternator+CmdOrTrueAppend); err != nil {
+			t.Log(err)
+		}
+	}
+}
+
+const agentService = "scylla-manager-agent"
+
+// StopAgent stops the scylla-manager-agent service on the h machine.
+func StopAgent(t *testing.T, h string) {
+	t.Helper()
+	if err := StopService(h, agentService); err != nil {
+		t.Error(err)
+	}
+}
+
+// StartAgent starts the scylla-manager-agent service on the h machine.
+func StartAgent(t *testing.T, h string) {
+	t.Helper()
+	if err := StartService(h, agentService); err != nil {
+		t.Error(err)
+	}
+}
+
+// TryStartAgent tries to start the scylla-manager-agent service on the []hosts machines.
+// It logs an error on failures, but doesn't return it.
+func TryStartAgent(t *testing.T, hosts []string) {
+	t.Helper()
+	for _, host := range hosts {
+		if err := StartService(host, agentService+CmdOrTrueAppend); err != nil {
+			t.Log(err)
+		}
+	}
+}
+
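// --- Illustrative sketch, not part of this diff ---
// The Try* helpers above are meant to be combined into a best-effort reset that
// brings the test cluster back to a clean state before a scenario runs, mirroring
// how TestValidateHostConnectivityIntegration uses them. The wrapper name
// resetCluster is hypothetical:
//
//	func resetCluster(t *testing.T) {
//		t.Helper()
//		hosts := ManagedClusterHosts()
//
//		TryUnblockCQL(t, hosts)
//		TryUnblockREST(t, hosts)
//		TryUnblockAlternator(t, hosts)
//		TryStartAgent(t, hosts)
//
//		// EnsureNodesAreUP (defined below) verifies every node reports UN again.
//		if err := EnsureNodesAreUP(t, hosts, time.Minute); err != nil {
//			t.Fatalf("not all nodes are UP, err = {%v}", err)
//		}
//	}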
+// EnsureNodesAreUP validates that the Scylla service is up and running on every host from hosts
+// and that each node reports its status correctly via the `nodetool status` command.
+// It waits for each node to report the UN status for the duration specified by the timeout parameter.
+func EnsureNodesAreUP(t *testing.T, hosts []string, timeout time.Duration) error {
+	t.Helper()
+
+	var (
+		allErrors error
+		mu        sync.Mutex
+	)
+
+	wg := sync.WaitGroup{}
+	for _, host := range hosts {
+		wg.Add(1)
+
+		go func(h string) {
+			defer wg.Done()
+
+			if err := WaitForNodeUPOrTimeout(h, timeout); err != nil {
+				mu.Lock()
+				allErrors = multierr.Combine(allErrors, err)
+				mu.Unlock()
+			}
+		}(host)
+	}
+	wg.Wait()
+
+	return allErrors
+}
diff --git a/pkg/testutils/netwait_integration_test.go b/pkg/testutils/netwait_integration_test.go
index 0ce873d46..5f95d6818 100644
--- a/pkg/testutils/netwait_integration_test.go
+++ b/pkg/testutils/netwait_integration_test.go
@@ -22,11 +22,11 @@ func TestWaiterTimeoutIntegration(t *testing.T) {
 	}
 
 	host := ManagedClusterHost()
-	err := RunIptablesCommand(host, CmdBlockScyllaCQL)
+	err := RunIptablesCommand(t, host, CmdBlockScyllaCQL)
 	if err != nil {
 		t.Fatal(err)
 	}
-	defer RunIptablesCommand(host, CmdUnblockScyllaCQL)
+	defer RunIptablesCommand(t, host, CmdUnblockScyllaCQL)
 
 	w := &netwait.Waiter{
 		DialTimeout: 5 * time.Millisecond,
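For readers skimming the diff, the host-discovery change in pkg/service/cluster/service.go amounts to a "first responder wins" pattern: every known host is probed concurrently with a short per-call timeout, the first successful answer is used, and the shared context is cancelled so the remaining probes stop early. Below is a simplified, self-contained sketch of that pattern with a generic probe function standing in for discoverClusterHostUsingCoordinator; the package, function, and variable names here are illustrative, not part of the change.

package discovery

import (
	"context"
	"errors"
	"sync"
)

// firstResponder queries all hosts concurrently and returns the first successful
// answer; probe stands in for discoverClusterHostUsingCoordinator.
func firstResponder(ctx context.Context, hosts []string,
	probe func(ctx context.Context, host string) ([]string, error),
) ([]string, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel() // cancelling stops the probes that are still running

	results := make(chan []string, len(hosts)) // buffered, so late goroutines never block on send
	var wg sync.WaitGroup
	for _, h := range hosts {
		wg.Add(1)
		go func(host string) {
			defer wg.Done()
			if r, err := probe(ctx, host); err == nil {
				results <- r
			}
		}(h)
	}
	go func() {
		wg.Wait()
		close(results)
	}()

	// The first value wins; a closed, empty channel means every probe failed.
	if r, ok := <-results; ok {
		return r, nil
	}
	return nil, errors.New("unable to connect to any of cluster's known hosts")
}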
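The discoverHosts change follows the same idea in a simpler form: the gossiper's live-endpoint list (from the new GossiperEndpointLiveGet client call) is turned into a set, and the datacenter map returned by Client.Datacenters is filtered against it before ClosestDC is consulted, so only live nodes are ranked. The inline loop from the diff could equally be written as a small helper; filterLiveHosts below is a hypothetical name, shown only to make the step easier to read.

package discovery

// filterLiveHosts keeps only the hosts that the gossiper reports as live,
// preserving the datacenter grouping returned by Client.Datacenters.
func filterLiveHosts(dcs map[string][]string, liveHosts []string) map[string][]string {
	liveSet := make(map[string]struct{}, len(liveHosts))
	for _, h := range liveHosts {
		liveSet[h] = struct{}{}
	}
	filtered := make(map[string][]string, len(dcs))
	for dc, hosts := range dcs {
		for _, h := range hosts {
			if _, ok := liveSet[h]; ok {
				filtered[dc] = append(filtered[dc], h)
			}
		}
	}
	return filtered
}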