Skip to content

Commit

Permalink
fix: handle ssl only scylla clusters
Browse files Browse the repository at this point in the history
This fixes how SM decides which port to use when connecting to Scylla
nodes.
  • Loading branch information
VAveryanov8 committed Nov 15, 2024
1 parent c2e2bb1 commit 378d7b8
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 26 deletions.
5 changes: 3 additions & 2 deletions pkg/ping/cqlping/cqlping_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ package cqlping
import (
"context"
"crypto/tls"
"github.com/scylladb/scylla-manager/v3/pkg/testutils/testconfig"
"testing"
"time"

"github.com/scylladb/scylla-manager/v3/pkg/testutils/testconfig"

"github.com/scylladb/go-log"
"github.com/scylladb/scylla-manager/v3/pkg/ping"
"github.com/scylladb/scylla-manager/v3/pkg/scyllaclient"
Expand All @@ -24,7 +25,7 @@ func TestPingIntegration(t *testing.T) {
client := newTestClient(t, log.NewDevelopmentWithLevel(zapcore.InfoLevel).Named("client"), nil)
defer client.Close()

sessionHosts, err := cluster.GetRPCAddresses(context.Background(), client, []string{testconfig.ManagedClusterHost()})
sessionHosts, err := cluster.GetRPCAddresses(context.Background(), client, []string{testconfig.ManagedClusterHost()}, false)
if err != nil {
t.Fatal(err)
}
Expand Down
42 changes: 20 additions & 22 deletions pkg/service/cluster/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,12 @@ func SingleHostSessionConfigOption(host string) SessionConfigOption {
return errors.Wrapf(err, "fetch node (%s) info", host)
}
cqlAddr := ni.CQLAddr(host)
if ni.ClientEncryptionEnabled {
cqlAddr = ni.CQLSSLAddr(host)
}
cfg.Hosts = []string{cqlAddr}
cfg.HostFilter = gocql.WhiteListHostFilter(cqlAddr)
cfg.DisableInitialHostLookup = true
cfg.HostFilter = gocql.WhiteListHostFilter(cqlAddr)
return nil
}
}
Expand All @@ -643,9 +646,14 @@ func (s *Service) GetSession(ctx context.Context, clusterID uuid.UUID, opts ...S
return session, err
}
}
// Fill hosts if they weren't specified by the options

clusterInfo, err := s.GetClusterByID(ctx, clusterID)
if err != nil {
return session, errors.Wrap(err, "cluster by id")
}
// Fill hosts if they weren't specified by the options or make sure that they use correct rpc address.
if len(cfg.Hosts) == 0 {
sessionHosts, err := GetRPCAddresses(ctx, client, client.Config().Hosts)
sessionHosts, err := GetRPCAddresses(ctx, client, client.Config().Hosts, clusterInfo.ForceTLSDisabled || clusterInfo.ForceNonSSLSessionPort)
if err != nil {
s.logger.Info(ctx, "Gets session", "err", err)
if errors.Is(err, ErrNoRPCAddressesFound) {
Expand All @@ -662,7 +670,7 @@ func (s *Service) GetSession(ctx context.Context, clusterID uuid.UUID, opts ...S
if err := s.extendClusterConfigWithAuthentication(clusterID, ni, cfg); err != nil {
return session, err
}
if err := s.extendClusterConfigWithTLS(ctx, clusterID, ni, cfg); err != nil {
if err := s.extendClusterConfigWithTLS(clusterInfo, ni, cfg); err != nil {
return session, err
}

Expand Down Expand Up @@ -695,36 +703,22 @@ func (s *Service) extendClusterConfigWithAuthentication(clusterID uuid.UUID, ni
return nil
}

func (s *Service) extendClusterConfigWithTLS(ctx context.Context, clusterID uuid.UUID, ni *scyllaclient.NodeInfo, cfg *gocql.ClusterConfig) error {
cluster, err := s.GetClusterByID(ctx, clusterID)
if err != nil {
return errors.Wrap(err, "get cluster by id")
}

cqlPort := ni.CQLPort()
func (s *Service) extendClusterConfigWithTLS(cluster *Cluster, ni *scyllaclient.NodeInfo, cfg *gocql.ClusterConfig) error {
if ni.ClientEncryptionEnabled && !cluster.ForceTLSDisabled {
if !cluster.ForceNonSSLSessionPort {
cqlPort = ni.CQLSSLPort()
}
cfg.SslOpts = &gocql.SslOptions{
Config: &tls.Config{
InsecureSkipVerify: true,
},
}
if ni.ClientEncryptionRequireAuth {
keyPair, err := s.loadTLSIdentity(clusterID)
keyPair, err := s.loadTLSIdentity(cluster.ID)
if err != nil {
return err
}
cfg.SslOpts.Config.Certificates = []tls.Certificate{keyPair}
}
}

p, err := strconv.Atoi(cqlPort)
if err != nil {
return errors.Wrap(err, "parse cql port")
}
cfg.Port = p
return nil
}

Expand Down Expand Up @@ -770,7 +764,7 @@ var ErrNoRPCAddressesFound = errors.New("no RPC addresses found")
// GetRPCAddresses accepts client and hosts parameters that are used later on to query client.NodeInfo endpoint
// returning RPC addresses for given hosts.
// RPC addresses are the ones that scylla uses to accept CQL connections.
func GetRPCAddresses(ctx context.Context, client *scyllaclient.Client, hosts []string) ([]string, error) {
func GetRPCAddresses(ctx context.Context, client *scyllaclient.Client, hosts []string, clusterSSLDisabled bool) ([]string, error) {
var sessionHosts []string
var combinedError error
for _, h := range hosts {
Expand All @@ -779,7 +773,11 @@ func GetRPCAddresses(ctx context.Context, client *scyllaclient.Client, hosts []s
combinedError = multierr.Append(combinedError, err)
continue
}
sessionHosts = append(sessionHosts, ni.CQLAddr(h))
addr := ni.CQLAddr(h)
if ni.ClientEncryptionEnabled && !clusterSSLDisabled {
addr = ni.CQLSSLAddr(h)
}
sessionHosts = append(sessionHosts, addr)
}

if len(sessionHosts) == 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/service/restore/service_restore_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1737,7 +1737,7 @@ func (h *restoreTestHelper) restartScylla() {
b := backoff.WithContext(backoff.WithMaxRetries(
backoff.NewConstantBackOff(500*time.Millisecond), 10), ctx)
if err := backoff.Retry(func() error {
sessionHosts, err = cluster.GetRPCAddresses(ctx, h.Client, []string{host})
sessionHosts, err = cluster.GetRPCAddresses(ctx, h.Client, []string{host}, false)
return err
}, b); err != nil {
h.T.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/testutils/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func CreateManagedClusterSession(tb testing.TB, empty bool, client *scyllaclient
tb.Helper()
ctx := context.Background()

sessionHosts, err := cluster.GetRPCAddresses(ctx, client, client.Config().Hosts)
sessionHosts, err := cluster.GetRPCAddresses(ctx, client, client.Config().Hosts, false)
if err != nil {
tb.Log(err)
if errors.Is(err, cluster.ErrNoRPCAddressesFound) {
Expand Down

0 comments on commit 378d7b8

Please sign in to comment.