diff --git a/pkg/ping/cqlping/cqlping_integration_test.go b/pkg/ping/cqlping/cqlping_integration_test.go index 7cc007bf7..52b5a15c1 100644 --- a/pkg/ping/cqlping/cqlping_integration_test.go +++ b/pkg/ping/cqlping/cqlping_integration_test.go @@ -8,10 +8,11 @@ package cqlping import ( "context" "crypto/tls" - "github.com/scylladb/scylla-manager/v3/pkg/testutils/testconfig" "testing" "time" + "github.com/scylladb/scylla-manager/v3/pkg/testutils/testconfig" + "github.com/scylladb/go-log" "github.com/scylladb/scylla-manager/v3/pkg/ping" "github.com/scylladb/scylla-manager/v3/pkg/scyllaclient" @@ -24,7 +25,7 @@ func TestPingIntegration(t *testing.T) { client := newTestClient(t, log.NewDevelopmentWithLevel(zapcore.InfoLevel).Named("client"), nil) defer client.Close() - sessionHosts, err := cluster.GetRPCAddresses(context.Background(), client, []string{testconfig.ManagedClusterHost()}) + sessionHosts, err := cluster.GetRPCAddresses(context.Background(), client, []string{testconfig.ManagedClusterHost()}, false) if err != nil { t.Fatal(err) } diff --git a/pkg/service/cluster/service.go b/pkg/service/cluster/service.go index 0a9bb7077..bb7d92cc8 100644 --- a/pkg/service/cluster/service.go +++ b/pkg/service/cluster/service.go @@ -617,9 +617,12 @@ func SingleHostSessionConfigOption(host string) SessionConfigOption { return errors.Wrapf(err, "fetch node (%s) info", host) } cqlAddr := ni.CQLAddr(host) + if ni.ClientEncryptionEnabled { + cqlAddr = ni.CQLSSLAddr(host) + } cfg.Hosts = []string{cqlAddr} - cfg.HostFilter = gocql.WhiteListHostFilter(cqlAddr) cfg.DisableInitialHostLookup = true + cfg.HostFilter = gocql.WhiteListHostFilter(cqlAddr) return nil } } @@ -643,9 +646,14 @@ func (s *Service) GetSession(ctx context.Context, clusterID uuid.UUID, opts ...S return session, err } } - // Fill hosts if they weren't specified by the options + + clusterInfo, err := s.GetClusterByID(ctx, clusterID) + if err != nil { + return session, errors.Wrap(err, "cluster by id") + } + // Fill hosts if they weren't specified by the options or make sure that they use correct rpc address. if len(cfg.Hosts) == 0 { - sessionHosts, err := GetRPCAddresses(ctx, client, client.Config().Hosts) + sessionHosts, err := GetRPCAddresses(ctx, client, client.Config().Hosts, clusterInfo.ForceTLSDisabled || clusterInfo.ForceNonSSLSessionPort) if err != nil { s.logger.Info(ctx, "Gets session", "err", err) if errors.Is(err, ErrNoRPCAddressesFound) { @@ -662,7 +670,7 @@ func (s *Service) GetSession(ctx context.Context, clusterID uuid.UUID, opts ...S if err := s.extendClusterConfigWithAuthentication(clusterID, ni, cfg); err != nil { return session, err } - if err := s.extendClusterConfigWithTLS(ctx, clusterID, ni, cfg); err != nil { + if err := s.extendClusterConfigWithTLS(clusterInfo, ni, cfg); err != nil { return session, err } @@ -695,24 +703,15 @@ func (s *Service) extendClusterConfigWithAuthentication(clusterID uuid.UUID, ni return nil } -func (s *Service) extendClusterConfigWithTLS(ctx context.Context, clusterID uuid.UUID, ni *scyllaclient.NodeInfo, cfg *gocql.ClusterConfig) error { - cluster, err := s.GetClusterByID(ctx, clusterID) - if err != nil { - return errors.Wrap(err, "get cluster by id") - } - - cqlPort := ni.CQLPort() +func (s *Service) extendClusterConfigWithTLS(cluster *Cluster, ni *scyllaclient.NodeInfo, cfg *gocql.ClusterConfig) error { if ni.ClientEncryptionEnabled && !cluster.ForceTLSDisabled { - if !cluster.ForceNonSSLSessionPort { - cqlPort = ni.CQLSSLPort() - } cfg.SslOpts = &gocql.SslOptions{ Config: &tls.Config{ InsecureSkipVerify: true, }, } if ni.ClientEncryptionRequireAuth { - keyPair, err := s.loadTLSIdentity(clusterID) + keyPair, err := s.loadTLSIdentity(cluster.ID) if err != nil { return err } @@ -720,11 +719,6 @@ func (s *Service) extendClusterConfigWithTLS(ctx context.Context, clusterID uuid } } - p, err := strconv.Atoi(cqlPort) - if err != nil { - return errors.Wrap(err, "parse cql port") - } - cfg.Port = p return nil } @@ -770,7 +764,7 @@ var ErrNoRPCAddressesFound = errors.New("no RPC addresses found") // GetRPCAddresses accepts client and hosts parameters that are used later on to query client.NodeInfo endpoint // returning RPC addresses for given hosts. // RPC addresses are the ones that scylla uses to accept CQL connections. -func GetRPCAddresses(ctx context.Context, client *scyllaclient.Client, hosts []string) ([]string, error) { +func GetRPCAddresses(ctx context.Context, client *scyllaclient.Client, hosts []string, clusterSSLDisabled bool) ([]string, error) { var sessionHosts []string var combinedError error for _, h := range hosts { @@ -779,7 +773,11 @@ func GetRPCAddresses(ctx context.Context, client *scyllaclient.Client, hosts []s combinedError = multierr.Append(combinedError, err) continue } - sessionHosts = append(sessionHosts, ni.CQLAddr(h)) + addr := ni.CQLAddr(h) + if ni.ClientEncryptionEnabled && !clusterSSLDisabled { + addr = ni.CQLSSLAddr(h) + } + sessionHosts = append(sessionHosts, addr) } if len(sessionHosts) == 0 { diff --git a/pkg/service/restore/service_restore_integration_test.go b/pkg/service/restore/service_restore_integration_test.go index d3a79ff39..1892128ee 100644 --- a/pkg/service/restore/service_restore_integration_test.go +++ b/pkg/service/restore/service_restore_integration_test.go @@ -1737,7 +1737,7 @@ func (h *restoreTestHelper) restartScylla() { b := backoff.WithContext(backoff.WithMaxRetries( backoff.NewConstantBackOff(500*time.Millisecond), 10), ctx) if err := backoff.Retry(func() error { - sessionHosts, err = cluster.GetRPCAddresses(ctx, h.Client, []string{host}) + sessionHosts, err = cluster.GetRPCAddresses(ctx, h.Client, []string{host}, false) return err }, b); err != nil { h.T.Fatal(err) diff --git a/pkg/testutils/db/db.go b/pkg/testutils/db/db.go index 5d59b6ac5..05fff3bce 100644 --- a/pkg/testutils/db/db.go +++ b/pkg/testutils/db/db.go @@ -83,7 +83,7 @@ func CreateManagedClusterSession(tb testing.TB, empty bool, client *scyllaclient tb.Helper() ctx := context.Background() - sessionHosts, err := cluster.GetRPCAddresses(ctx, client, client.Config().Hosts) + sessionHosts, err := cluster.GetRPCAddresses(ctx, client, client.Config().Hosts, false) if err != nil { tb.Log(err) if errors.Is(err, cluster.ErrNoRPCAddressesFound) {