From 7182b55dcb1498853d711304140cde3ef069228b Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 4 Jul 2024 13:06:55 +0200 Subject: [PATCH] Add the datacenter name validation if provided Fail to connect if DC name is different in the topology than the one the user entered. --- policies.go | 34 ++++++++++++++++++++++++++++++ policies_integration_test.go | 41 ++++++++++++++++++++++++++++++++++++ policies_test.go | 17 +++++++++++---- session.go | 14 ++++++++++++ 4 files changed, 102 insertions(+), 4 deletions(-) create mode 100644 policies_integration_test.go diff --git a/policies.go b/policies.go index 853ca48b3..a47f51372 100644 --- a/policies.go +++ b/policies.go @@ -319,6 +319,8 @@ type HostSelectionPolicy interface { // so it's safe to have internal state without additional synchronization as long as every call to Pick returns // a different instance of NextHost. Pick(ExecutableQuery) NextHost + // Checks if datacenter is valid if local aware policy is used (explicitly or as a fallback) + IsDatacenterValid(datacenters []string) bool } // SelectedHost is an interface returned when picking a host from a host @@ -363,6 +365,7 @@ func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {} func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {} func (r *roundRobinHostPolicy) Init(*Session) {} func (r *roundRobinHostPolicy) Reset() {} +func (r *roundRobinHostPolicy) IsDatacenterValid([]string) bool { return true } // Experimental, this interface and use may change func (r *roundRobinHostPolicy) SetTablets(tablets []*TabletInfo) {} @@ -489,6 +492,10 @@ func (t *tokenAwareHostPolicy) Reset() { t.logger = nil } +func (t *tokenAwareHostPolicy) IsDatacenterValid(datacenters []string) bool { + return t.fallback.IsDatacenterValid(datacenters) +} + func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool { return t.fallback.IsLocal(host) } @@ -823,6 +830,7 @@ type hostPoolHostPolicy struct { func (r *hostPoolHostPolicy) Init(*Session) {} func (r *hostPoolHostPolicy) Reset() {} +func (r *hostPoolHostPolicy) IsDatacenterValid([]string) bool { return true } func (r *hostPoolHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {} func (r *hostPoolHostPolicy) SetPartitioner(string) {} func (r *hostPoolHostPolicy) IsLocal(*HostInfo) bool { return true } @@ -984,6 +992,19 @@ func (d *dcAwareRR) Reset() {} func (d *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {} func (d *dcAwareRR) SetPartitioner(p string) {} +func (d *dcAwareRR) IsDatacenterValid(datacenters []string) bool { + found := false + + for _, dc := range datacenters { + if dc == d.local { + found = true + break + } + } + + return found +} + func (d *dcAwareRR) IsLocal(host *HostInfo) bool { return host.DataCenter() == d.local } @@ -1088,6 +1109,19 @@ func (d *rackAwareRR) Reset() {} func (d *rackAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {} func (d *rackAwareRR) SetPartitioner(p string) {} +func (d *rackAwareRR) IsDatacenterValid(datacenters []string) bool { + found := false + + for _, dc := range datacenters { + if dc == d.localDC { + found = true + break + } + } + + return found +} + func (d *rackAwareRR) MaxHostTier() uint { return 2 } diff --git a/policies_integration_test.go b/policies_integration_test.go new file mode 100644 index 000000000..608ac3299 --- /dev/null +++ b/policies_integration_test.go @@ -0,0 +1,41 @@ +//go:build integration && scylla +// +build integration,scylla + +package gocql + +import ( + "testing" +) + +// Check if session fail to start if DC name provided in the policy is wrong +func TestDCValidationTokenAware(t *testing.T) { + cluster := createCluster() + + fallback := DCAwareRoundRobinPolicy("WRONG_DC") + cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(fallback) + + _, err := cluster.CreateSession() + if err == nil { + t.Fatal("createSession was expected to fail with wrong DC name provided.") + } +} + +func TestDCValidationDCAware(t *testing.T) { + cluster := createCluster() + cluster.PoolConfig.HostSelectionPolicy = DCAwareRoundRobinPolicy("WRONG_DC") + + _, err := cluster.CreateSession() + if err == nil { + t.Fatal("createSession was expected to fail with wrong DC name provided.") + } +} + +func TestDCValidationRackAware(t *testing.T) { + cluster := createCluster() + cluster.PoolConfig.HostSelectionPolicy = RackAwareRoundRobinPolicy("WRONG_DC", "RACK") + + _, err := cluster.CreateSession() + if err == nil { + t.Fatal("createSession was expected to fail with wrong DC name provided.") + } +} diff --git a/policies_test.go b/policies_test.go index d8e83f341..dd4969a96 100644 --- a/policies_test.go +++ b/policies_test.go @@ -601,8 +601,8 @@ func TestHostPolicy_DCAwareRR(t *testing.T) { } -func TestHostPolicy_DCAwareRR_wrongDc(t *testing.T) { - p := DCAwareRoundRobinPolicy("wrong_dc", HostPolicyOptionDisableDCFailover) +func TestHostPolicy_DCAwareRR_disableDCFailover(t *testing.T) { + p := DCAwareRoundRobinPolicy("local", HostPolicyOptionDisableDCFailover) hosts := [...]*HostInfo{ {hostId: "0", connectAddress: net.ParseIP("10.0.0.1"), dataCenter: "local"}, @@ -616,19 +616,28 @@ func TestHostPolicy_DCAwareRR_wrongDc(t *testing.T) { } got := make(map[string]bool, len(hosts)) + var dcs []string it := p.Pick(nil) for h := it(); h != nil; h = it() { id := h.Info().hostId + dc := h.Info().dataCenter if got[id] { t.Fatalf("got duplicate host %s", id) } got[id] = true + dcs = append(dcs, dc) } - if len(got) != 0 { - t.Fatalf("expected %d hosts got %d", 0, len(got)) + if len(got) != 2 { + t.Fatalf("expected %d hosts got %d", 2, len(got)) + } + + for _, dc := range dcs { + if dc == "remote" { + t.Fatalf("got remote dc but failover was diabled") + } } } diff --git a/session.go b/session.go index 247accd7c..3bff4b0f7 100644 --- a/session.go +++ b/session.go @@ -197,6 +197,20 @@ func NewSession(cfg ClusterConfig) (*Session, error) { } } + if !s.cfg.disableInit && !s.cfg.disableControlConn { + var datacenters []string + hosts, _, err := s.hostSource.GetHosts() + if err != nil { + return nil, fmt.Errorf("gocql: unable to create session: %v", err) + } + for _, host := range hosts { + datacenters = append(datacenters, host.DataCenter()) + } + if !s.policy.IsDatacenterValid(datacenters) { + return nil, fmt.Errorf("gocql: unable to create session: datacenter provided in the policy was not found in the topology") + } + } + return s, nil }