Skip to content

Commit

Permalink
Add the datacenter name validation if provided
Browse files Browse the repository at this point in the history
Fail to connect if DC name is different in the topology
than the one the user entered.
  • Loading branch information
sylwiaszunejko committed Jul 4, 2024
1 parent 0bd6283 commit 7182b55
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 4 deletions.
34 changes: 34 additions & 0 deletions policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ type HostSelectionPolicy interface {
// so it's safe to have internal state without additional synchronization as long as every call to Pick returns
// a different instance of NextHost.
Pick(ExecutableQuery) NextHost
// Checks if datacenter is valid if local aware policy is used (explicitly or as a fallback)
IsDatacenterValid(datacenters []string) bool
}

// SelectedHost is an interface returned when picking a host from a host
Expand Down Expand Up @@ -363,6 +365,7 @@ func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {}
func (r *roundRobinHostPolicy) Init(*Session) {}
func (r *roundRobinHostPolicy) Reset() {}
func (r *roundRobinHostPolicy) IsDatacenterValid([]string) bool { return true }

// Experimental, this interface and use may change
func (r *roundRobinHostPolicy) SetTablets(tablets []*TabletInfo) {}
Expand Down Expand Up @@ -489,6 +492,10 @@ func (t *tokenAwareHostPolicy) Reset() {
t.logger = nil
}

func (t *tokenAwareHostPolicy) IsDatacenterValid(datacenters []string) bool {
return t.fallback.IsDatacenterValid(datacenters)
}

func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool {
return t.fallback.IsLocal(host)
}
Expand Down Expand Up @@ -823,6 +830,7 @@ type hostPoolHostPolicy struct {

func (r *hostPoolHostPolicy) Init(*Session) {}
func (r *hostPoolHostPolicy) Reset() {}
func (r *hostPoolHostPolicy) IsDatacenterValid([]string) bool { return true }
func (r *hostPoolHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *hostPoolHostPolicy) SetPartitioner(string) {}
func (r *hostPoolHostPolicy) IsLocal(*HostInfo) bool { return true }
Expand Down Expand Up @@ -984,6 +992,19 @@ func (d *dcAwareRR) Reset() {}
func (d *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *dcAwareRR) SetPartitioner(p string) {}

func (d *dcAwareRR) IsDatacenterValid(datacenters []string) bool {
found := false

for _, dc := range datacenters {
if dc == d.local {
found = true
break
}
}

return found
}

func (d *dcAwareRR) IsLocal(host *HostInfo) bool {
return host.DataCenter() == d.local
}
Expand Down Expand Up @@ -1088,6 +1109,19 @@ func (d *rackAwareRR) Reset() {}
func (d *rackAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *rackAwareRR) SetPartitioner(p string) {}

func (d *rackAwareRR) IsDatacenterValid(datacenters []string) bool {
found := false

for _, dc := range datacenters {
if dc == d.localDC {
found = true
break
}
}

return found
}

func (d *rackAwareRR) MaxHostTier() uint {
return 2
}
Expand Down
41 changes: 41 additions & 0 deletions policies_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//go:build integration && scylla
// +build integration,scylla

package gocql

import (
"testing"
)

// Check if session fail to start if DC name provided in the policy is wrong
func TestDCValidationTokenAware(t *testing.T) {
cluster := createCluster()

fallback := DCAwareRoundRobinPolicy("WRONG_DC")
cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(fallback)

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}

func TestDCValidationDCAware(t *testing.T) {
cluster := createCluster()
cluster.PoolConfig.HostSelectionPolicy = DCAwareRoundRobinPolicy("WRONG_DC")

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}

func TestDCValidationRackAware(t *testing.T) {
cluster := createCluster()
cluster.PoolConfig.HostSelectionPolicy = RackAwareRoundRobinPolicy("WRONG_DC", "RACK")

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}
17 changes: 13 additions & 4 deletions policies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,8 +601,8 @@ func TestHostPolicy_DCAwareRR(t *testing.T) {

}

func TestHostPolicy_DCAwareRR_wrongDc(t *testing.T) {
p := DCAwareRoundRobinPolicy("wrong_dc", HostPolicyOptionDisableDCFailover)
func TestHostPolicy_DCAwareRR_disableDCFailover(t *testing.T) {
p := DCAwareRoundRobinPolicy("local", HostPolicyOptionDisableDCFailover)

hosts := [...]*HostInfo{
{hostId: "0", connectAddress: net.ParseIP("10.0.0.1"), dataCenter: "local"},
Expand All @@ -616,19 +616,28 @@ func TestHostPolicy_DCAwareRR_wrongDc(t *testing.T) {
}

got := make(map[string]bool, len(hosts))
var dcs []string

it := p.Pick(nil)
for h := it(); h != nil; h = it() {
id := h.Info().hostId
dc := h.Info().dataCenter

if got[id] {
t.Fatalf("got duplicate host %s", id)
}
got[id] = true
dcs = append(dcs, dc)
}

if len(got) != 0 {
t.Fatalf("expected %d hosts got %d", 0, len(got))
if len(got) != 2 {
t.Fatalf("expected %d hosts got %d", 2, len(got))
}

for _, dc := range dcs {
if dc == "remote" {
t.Fatalf("got remote dc but failover was diabled")
}
}
}

Expand Down
14 changes: 14 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,20 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
}
}

if !s.cfg.disableInit && !s.cfg.disableControlConn {
var datacenters []string
hosts, _, err := s.hostSource.GetHosts()
if err != nil {
return nil, fmt.Errorf("gocql: unable to create session: %v", err)
}
for _, host := range hosts {
datacenters = append(datacenters, host.DataCenter())
}
if !s.policy.IsDatacenterValid(datacenters) {
return nil, fmt.Errorf("gocql: unable to create session: datacenter provided in the policy was not found in the topology")
}
}

return s, nil
}

Expand Down

0 comments on commit 7182b55

Please sign in to comment.