diff --git a/pkg/cmd/scylla-manager/server.go b/pkg/cmd/scylla-manager/server.go
index 8c9a40c22..580683132 100644
--- a/pkg/cmd/scylla-manager/server.go
+++ b/pkg/cmd/scylla-manager/server.go
@@ -164,7 +164,7 @@ func (s *server) makeServices(ctx context.Context) error {
 func (s *server) onClusterChange(ctx context.Context, c cluster.Change) error {
 	switch c.Type {
 	case cluster.Update:
-		s.configCacheSvc.ForceUpdateCluster(ctx, c.ID)
+		go s.configCacheSvc.ForceUpdateCluster(ctx, c.ID)
 	case cluster.Create:
 		s.configCacheSvc.ForceUpdateCluster(ctx, c.ID)
 		for _, t := range makeAutoHealthCheckTasks(c.ID) {
diff --git a/pkg/scyllaclient/client_scylla.go b/pkg/scyllaclient/client_scylla.go
index 193fb9e15..03210b83f 100644
--- a/pkg/scyllaclient/client_scylla.go
+++ b/pkg/scyllaclient/client_scylla.go
@@ -178,6 +178,15 @@ func (c *Client) Datacenters(ctx context.Context) (map[string][]string, error) {
 	return res, errs
 }
 
+// GossiperEndpointLiveGet finds live nodes (according to gossiper).
+func (c *Client) GossiperEndpointLiveGet(ctx context.Context) ([]string, error) {
+	live, err := c.scyllaOps.GossiperEndpointLiveGet(&operations.GossiperEndpointLiveGetParams{Context: ctx})
+	if err != nil {
+		return nil, err
+	}
+	return live.GetPayload(), nil
+}
+
 // HostDatacenter looks up the datacenter that the given host belongs to.
 func (c *Client) HostDatacenter(ctx context.Context, host string) (dc string, err error) {
 	// Try reading from cache
diff --git a/pkg/service/cluster/service.go b/pkg/service/cluster/service.go
index 353d67fea..160393009 100644
--- a/pkg/service/cluster/service.go
+++ b/pkg/service/cluster/service.go
@@ -9,6 +9,7 @@ import (
 	"fmt"
 	"sort"
 	"strconv"
+	"sync"
 	"time"
 
 	"github.com/gocql/gocql"
@@ -34,7 +35,10 @@ type ProviderFunc func(ctx context.Context, id uuid.UUID) (*Cluster, error)
 type ChangeType int8
 
 // ErrNoValidKnownHost is thrown when it was not possible to connect to any of the currently known hosts of the cluster.
-var ErrNoValidKnownHost = errors.New("unable to connect to any of cluster's known hosts")
+var (
+	ErrNoValidKnownHost    = errors.New("unable to connect to any of cluster's known hosts")
+	ErrNoLiveHostAvailable = errors.New("no single live host available")
+)
 
 // ChangeType enumeration.
 const (
@@ -156,7 +160,7 @@ func (s *Service) clientConfig(c *Cluster) scyllaclient.Config {
 }
 
 func (s *Service) discoverAndSetClusterHosts(ctx context.Context, c *Cluster) error {
-	knownHosts, err := s.discoverClusterHosts(ctx, c)
+	knownHosts, _, err := s.discoverClusterHosts(ctx, c)
 	if err != nil {
 		if errors.Is(err, ErrNoValidKnownHost) {
 			s.logger.Error(ctx, "There is no single valid known host for the cluster. "+
"+ @@ -171,55 +175,124 @@ func (s *Service) discoverAndSetClusterHosts(ctx context.Context, c *Cluster) er return errors.Wrap(s.setKnownHosts(c, knownHosts), "update known_hosts in SM DB") } -func (s *Service) discoverClusterHosts(ctx context.Context, c *Cluster) ([]string, error) { - var contactPoints []string +const ( + discoverClusterHostsTimeout = 5 * time.Second +) + +func (s *Service) discoverClusterHosts(ctx context.Context, c *Cluster) (knownHosts, liveHosts []string, err error) { if c.Host != "" { - contactPoints = append(contactPoints, c.Host) // Go with the designated contact point first + knownHosts, liveHosts, err := s.discoverClusterHostUsingCoordinator(ctx, c, discoverClusterHostsTimeout, c.Host) + if err != nil { + s.logger.Error(ctx, "Couldn't discover hosts using stored coordinator host, proceeding with other known ones", + "coordinator-host", c.Host, "error", err) + } else { + return knownHosts, liveHosts, nil + } } else { s.logger.Error(ctx, "Missing --host flag. Using only previously discovered hosts instead", "cluster ID", c.ID) } - contactPoints = append(contactPoints, c.KnownHosts...) // In case it failed, try to contact previously discovered hosts + if len(c.KnownHosts) < 1 { + return nil, nil, ErrNoValidKnownHost + } - for _, cp := range contactPoints { - if cp == "" { - s.logger.Error(ctx, "Empty contact point", "cluster ID", c.ID, "contact points", contactPoints) - continue - } + wg := sync.WaitGroup{} + type hostsTuple struct { + live, known []string + } + result := make(chan hostsTuple, len(c.KnownHosts)) + discoverContext, discoverCancel := context.WithCancel(ctx) + defer discoverCancel() - config := scyllaclient.DefaultConfigWithTimeout(s.timeoutConfig) - if c.Port != 0 { - config.Port = strconv.Itoa(c.Port) - } - config.AuthToken = c.AuthToken - config.Hosts = []string{cp} + for _, cp := range c.KnownHosts { + wg.Add(1) - client, err := scyllaclient.NewClient(config, s.logger.Named("client")) - if err != nil { - s.logger.Error(ctx, "Couldn't connect to contact point", "contact point", cp, "error", err) - continue - } + go func(host string) { + defer wg.Done() - knownHosts, err := s.discoverHosts(ctx, client) - logutil.LogOnError(ctx, s.logger, client.Close, "Couldn't close scylla client") - if err != nil { - s.logger.Error(ctx, "Couldn't discover hosts", "host", cp, "error", err) - continue - } - return knownHosts, nil + knownHosts, liveHosts, err := s.discoverClusterHostUsingCoordinator(discoverContext, c, discoverClusterHostsTimeout, host) + if err != nil { + // Only log if the context hasn't been canceled + if !errors.Is(discoverContext.Err(), context.Canceled) { + s.logger.Error(ctx, "Couldn't discover hosts", "host", host, "error", err) + } + return + } + result <- hostsTuple{ + live: liveHosts, + known: knownHosts, + } + }(cp) + } + + go func() { + wg.Wait() + close(result) + }() + + // Read results until the channel is closed + for hosts := range result { + return hosts.known, hosts.live, nil + } + + // If no valid results, return error< + return nil, nil, ErrNoValidKnownHost +} + +func (s *Service) discoverClusterHostUsingCoordinator(ctx context.Context, c *Cluster, apiCallTimeout time.Duration, + host string, +) (knownHosts, liveHosts []string, err error) { + config := scyllaclient.DefaultConfigWithTimeout(s.timeoutConfig) + if c.Port != 0 { + config.Port = strconv.Itoa(c.Port) } + config.Timeout = apiCallTimeout + config.AuthToken = c.AuthToken + config.Hosts = []string{host} - return nil, ErrNoValidKnownHost + client, err := 
+	if err != nil {
+		return nil, nil, err
+	}
+	defer logutil.LogOnError(ctx, s.logger, client.Close, "Couldn't close scylla client")
+
+	liveHosts, err = client.GossiperEndpointLiveGet(ctx)
+	if err != nil {
+		return nil, nil, err
+	}
+	knownHosts, err = s.discoverHosts(ctx, client, liveHosts)
+	if err != nil {
+		return nil, nil, err
+	}
+	return knownHosts, liveHosts, nil
 }
 
 // discoverHosts returns a list of all hosts sorted by DC speed. This is
 // an optimisation for Epsilon-Greedy host pool used internally by
 // scyllaclient.Client that makes it use supposedly faster hosts first.
-func (s *Service) discoverHosts(ctx context.Context, client *scyllaclient.Client) (hosts []string, err error) {
+func (s *Service) discoverHosts(ctx context.Context, client *scyllaclient.Client, liveHosts []string) (hosts []string, err error) {
+	if len(liveHosts) == 0 {
+		return nil, ErrNoLiveHostAvailable
+	}
+
 	dcs, err := client.Datacenters(ctx)
 	if err != nil {
 		return nil, err
 	}
-	closest, err := client.ClosestDC(ctx, dcs)
+	// Filter out nodes that are not live from the datacenter map
+	liveSet := make(map[string]struct{})
+	for _, host := range liveHosts {
+		liveSet[host] = struct{}{}
+	}
+	filteredDCs := make(map[string][]string)
+	for dc, hosts := range dcs {
+		for _, host := range hosts {
+			if _, isLive := liveSet[host]; isLive {
+				filteredDCs[dc] = append(filteredDCs[dc], host)
+			}
+		}
+	}
+
+	closest, err := client.ClosestDC(ctx, filteredDCs)
 	if err != nil {
 		return nil, err
 	}
@@ -397,7 +470,7 @@ func (s *Service) PutCluster(ctx context.Context, c *Cluster) (err error) {
 	}
 
 	// Check hosts connectivity.
-	if err := s.validateHostsConnectivity(ctx, c); err != nil {
+	if err := s.ValidateHostsConnectivity(ctx, c); err != nil {
 		var tip string
 		switch scyllaclient.StatusCodeOf(err) {
 		case 0:
@@ -487,36 +560,34 @@ func (s *Service) PutCluster(ctx context.Context, c *Cluster) (err error) {
 	return s.notifyChangeListener(ctx, changeEvent)
 }
 
-func (s *Service) validateHostsConnectivity(ctx context.Context, c *Cluster) error {
+// ValidateHostsConnectivity validates that the Scylla Manager Agent API is available and responding on all live hosts.
+// Hosts are discovered using the cluster.Host and cluster.KnownHosts values saved in the manager's database.
+func (s *Service) ValidateHostsConnectivity(ctx context.Context, c *Cluster) error {
 	if err := s.loadKnownHosts(c); err != nil && !errors.Is(err, gocql.ErrNotFound) {
 		return errors.Wrap(err, "load known hosts")
 	}
 
-	knownHosts, err := s.discoverClusterHosts(ctx, c)
+	knownHosts, liveHosts, err := s.discoverClusterHosts(ctx, c)
 	if err != nil {
 		return errors.Wrap(err, "discover cluster hosts")
 	}
 	c.KnownHosts = knownHosts
 
+	if len(liveHosts) == 0 {
+		return util.ErrValidate(errors.New("no live nodes"))
+	}
+
 	config := s.clientConfig(c)
+	config.Hosts = liveHosts
 	client, err := scyllaclient.NewClient(config, s.logger.Named("client"))
 	if err != nil {
 		return err
 	}
 	defer logutil.LogOnError(ctx, s.logger, client.Close, "Couldn't close scylla client")
 
-	status, err := client.Status(ctx)
-	if err != nil {
-		return errors.Wrap(err, "cluster status")
-	}
-	live := status.Live().Hosts()
-	if len(live) == 0 {
-		return util.ErrValidate(errors.New("no live nodes"))
-	}
-
 	var errs error
-	for i, err := range client.CheckHostsConnectivity(ctx, live) {
-		errs = multierr.Append(errs, errors.Wrap(err, live[i]))
+	for i, err := range client.CheckHostsConnectivity(ctx, liveHosts) {
+		errs = multierr.Append(errs, errors.Wrap(err, liveHosts[i]))
 	}
 	if errs != nil {
 		return util.ErrValidate(errors.Wrap(errs, "connectivity check"))
diff --git a/pkg/service/cluster/service_integration_test.go b/pkg/service/cluster/service_integration_test.go
index e70c6f18e..551ef6965 100644
--- a/pkg/service/cluster/service_integration_test.go
+++ b/pkg/service/cluster/service_integration_test.go
@@ -33,6 +33,121 @@ import (
 	"github.com/scylladb/scylla-manager/v3/pkg/util/uuid"
 )
 
+func TestValidateHostConnectivityIntegration(t *testing.T) {
+	Print("given: the fresh cluster")
+	var (
+		ctx          = context.Background()
+		session      = CreateScyllaManagerDBSession(t)
+		secretsStore = store.NewTableStore(session, table.Secrets)
+		c            = &cluster.Cluster{
+			AuthToken: "token",
+			Host:      ManagedClusterHost(),
+		}
+	)
+	s, err := cluster.NewService(session, metrics.NewClusterMetrics(), secretsStore, scyllaclient.DefaultTimeoutConfig(),
+		server.DefaultConfig().ClientCacheTimeout, log.NewDevelopment())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = s.PutCluster(context.Background(), c)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	allHosts := ManagedClusterHosts()
+	for _, tc := range []struct {
+		name      string
+		hostsDown []string
+		result    error
+		timeout   time.Duration
+	}{
+		{
+			name:      "coordinator host is DOWN",
+			hostsDown: []string{ManagedClusterHost()},
+			result:    nil,
+			timeout:   6 * time.Second,
+		},
+		{
+			name:      "only one is UP",
+			hostsDown: allHosts[:len(allHosts)-1],
+			result:    nil,
+			timeout:   6 * time.Second,
+		},
+		{
+			name:      "all hosts are DOWN",
+			hostsDown: allHosts,
+			result:    cluster.ErrNoValidKnownHost,
+			timeout:   11 * time.Second, // the 5-second calls will time out twice
+		},
+		{
+			name:      "all hosts are UP",
+			hostsDown: nil,
+			result:    nil,
+			timeout:   6 * time.Second,
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			defer func() {
+				for _, host := range tc.hostsDown {
+					if err := StartService(host, "scylla"); err != nil {
+						t.Logf("error starting stopped scylla service on host={%s}, err={%s}", host, err)
+					}
+					if err := RunIptablesCommand(host, CmdUnblockScyllaREST); err != nil {
+						t.Logf("error trying to unblock REST API on host={%s}, err={%s}", host, err)
+					}
+				}
+			}()
+
+			Printf("then: validate that the ValidateHostsConnectivity call takes less than %v seconds", tc.timeout.Seconds())
+			testCluster, err := s.GetClusterByID(context.Background(), c.ID)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if err := callValidateHostConnectivityWithTimeout(ctx, s, tc.timeout, testCluster); err != nil {
+				t.Fatal(err)
+			}
+			Printf("when: the scylla service is stopped and the scylla API is timing out on some hosts")
+			// It's needed to block Scylla REST API, so that the clients are just hanging when they call the API.
+			// Scylla service must be stopped to make the node report DOWN status. Blocking REST API is not
+			// enough.
+			for _, host := range tc.hostsDown {
+				if err := StopService(host, "scylla"); err != nil {
+					t.Fatal(err)
+				}
+				if err := RunIptablesCommand(host, CmdBlockScyllaREST); err != nil {
+					t.Error(err)
+				}
+			}
+
+			Printf("then: validate that call still takes less than %v seconds", tc.timeout.Seconds())
+			if err := callValidateHostConnectivityWithTimeout(ctx, s, tc.timeout, testCluster); !errors.Is(err, tc.result) {
+				t.Fatal(err)
+			}
+		})
+	}
+}
+
+func callValidateHostConnectivityWithTimeout(ctx context.Context, s *cluster.Service, timeout time.Duration,
+	c *cluster.Cluster) error {
+
+	callCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	done := make(chan error)
+	go func() {
+		done <- s.ValidateHostsConnectivity(callCtx, c)
+	}()
+
+	select {
+	case <-time.After(timeout):
+		cancel()
+		return fmt.Errorf("expected s.ValidateHostsConnectivity to complete in less than %v seconds, time exceeded", timeout.Seconds())
+	case err := <-done:
+		return err
+	}
+}
+
 func TestClientIntegration(t *testing.T) {
 	expectedHosts := ManagedClusterHosts()
 
diff --git a/pkg/service/healthcheck/service_integration_test.go b/pkg/service/healthcheck/service_integration_test.go
index 4744e2414..2cd1b9e32 100644
--- a/pkg/service/healthcheck/service_integration_test.go
+++ b/pkg/service/healthcheck/service_integration_test.go
@@ -13,6 +13,7 @@ import (
 	"net/http"
 	"os"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -22,6 +23,7 @@ import (
 	"github.com/scylladb/scylla-manager/v3/pkg/metrics"
 	"github.com/scylladb/scylla-manager/v3/pkg/service/cluster"
 	"github.com/scylladb/scylla-manager/v3/pkg/service/configcache"
+	"go.uber.org/multierr"
 	"go.uber.org/zap/zapcore"
 
 	"github.com/scylladb/scylla-manager/v3/pkg/schema/table"
@@ -53,6 +55,9 @@ func TestStatus_Ping_Independent_From_REST_Integration(t *testing.T) {
 	tryUnblockREST(t, ManagedClusterHosts())
 	tryUnblockAlternator(t, ManagedClusterHosts())
 	tryStartAgent(t, ManagedClusterHosts())
+	if err := ensureNodesAreUP(t, ManagedClusterHosts(), time.Minute); err != nil {
+		t.Fatalf("not all nodes are UP, err = {%v}", err)
+	}
 
 	logger := log.NewDevelopmentWithLevel(zapcore.InfoLevel).Named("healthcheck")
 
@@ -562,6 +567,33 @@ func tryStartAgent(t *testing.T, hosts []string) {
 	}
 }
 
+func ensureNodesAreUP(t *testing.T, hosts []string, timeout time.Duration) error {
+	t.Helper()
+
+	var (
+		allErrors error
+		mu        sync.Mutex
+	)
+
+	wg := sync.WaitGroup{}
+	for _, host := range hosts {
+		wg.Add(1)
+
+		go func(h string) {
+			defer wg.Done()
+
+			if err := WaitForNodeUPOrTimeout(h, timeout); err != nil {
+				mu.Lock()
+				allErrors = multierr.Combine(allErrors, err)
+				mu.Unlock()
+			}
+		}(host)
+	}
+	wg.Wait()
+
+	return allErrors
+}
+
 const pingPath = "/storage_service/scylla_release_version"
 
 func fakeHealthCheckStatus(host string, code int) http.RoundTripper {
diff --git a/pkg/testutils/exec.go b/pkg/testutils/exec.go
index 624dea384..745516106 100644
--- a/pkg/testutils/exec.go
+++ b/pkg/testutils/exec.go
@@ -4,8 +4,10 @@ package testutils
 
 import (
 	"bytes"
+	"fmt"
 	"net"
 	"strings"
+	"time"
"github.com/pkg/errors" "golang.org/x/crypto/ssh" @@ -107,3 +109,40 @@ func StartService(h, service string) error { } return nil } + +// WaitForNodeUPOrTimeout waits until nodetool status report UN status for the given node. +// The nodetool status CLI is executed on the same node. +func WaitForNodeUPOrTimeout(h string, timeout time.Duration) error { + nodeIsReady := make(chan struct{}) + done := make(chan struct{}) + go func() { + defer close(nodeIsReady) + for { + select { + case <-done: + return + default: + stdout, _, err := ExecOnHost(h, "nodetool status | grep "+h) + if err != nil { + continue + } + if strings.HasPrefix(stdout, "UN") { + return + } + select { + case <-done: + return + case <-time.After(time.Second): + } + } + } + }() + + select { + case <-nodeIsReady: + return nil + case <-time.After(timeout): + close(done) + return fmt.Errorf("node %s haven't reach UP status", h) + } +}