Skip to content

Commit

Permalink
Add the datacenter name validation if provided
Browse files Browse the repository at this point in the history
Fail to connect if DC name is different in the topology
than the one the user entered.
  • Loading branch information
sylwiaszunejko committed Jul 4, 2024
1 parent e6d9bec commit ad02224
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 4 deletions.
34 changes: 34 additions & 0 deletions policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ type HostSelectionPolicy interface {
// so it's safe to have internal state without additional synchronization as long as every call to Pick returns
// a different instance of NextHost.
Pick(ExecutableQuery) NextHost
// Checks if datacenter is valid if local aware policy is used (explicitly or as a fallback)
IsDatacenterValid(datacenters []string) bool
}

// SelectedHost is an interface returned when picking a host from a host
Expand Down Expand Up @@ -363,6 +365,7 @@ func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {}
func (r *roundRobinHostPolicy) Init(*Session) {}
func (r *roundRobinHostPolicy) Reset() {}
func (r *roundRobinHostPolicy) IsDatacenterValid([]string) bool { return true }

// Experimental, this interface and use may change
func (r *roundRobinHostPolicy) SetTablets(tablets []*TabletInfo) {}
Expand Down Expand Up @@ -489,6 +492,10 @@ func (t *tokenAwareHostPolicy) Reset() {
t.logger = nil
}

func (t *tokenAwareHostPolicy) IsDatacenterValid(datacenters []string) bool {
return t.fallback.IsDatacenterValid(datacenters)
}

func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool {
return t.fallback.IsLocal(host)
}
Expand Down Expand Up @@ -823,6 +830,7 @@ type hostPoolHostPolicy struct {

func (r *hostPoolHostPolicy) Init(*Session) {}
func (r *hostPoolHostPolicy) Reset() {}
func (r *hostPoolHostPolicy) IsDatacenterValid([]string) bool { return true }
func (r *hostPoolHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *hostPoolHostPolicy) SetPartitioner(string) {}
func (r *hostPoolHostPolicy) IsLocal(*HostInfo) bool { return true }
Expand Down Expand Up @@ -984,6 +992,19 @@ func (d *dcAwareRR) Reset() {}
func (d *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *dcAwareRR) SetPartitioner(p string) {}

func (d *dcAwareRR) IsDatacenterValid(datacenters []string) bool {
found := false

for _, dc := range datacenters {
if dc == d.local {
found = true
break
}
}

return found
}

func (d *dcAwareRR) IsLocal(host *HostInfo) bool {
return host.DataCenter() == d.local
}
Expand Down Expand Up @@ -1090,6 +1111,19 @@ func (d *rackAwareRR) Reset() {}
func (d *rackAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *rackAwareRR) SetPartitioner(p string) {}

func (d *rackAwareRR) IsDatacenterValid(datacenters []string) bool {
found := false

for _, dc := range datacenters {
if dc == d.localDC {
found = true
break
}
}

return found
}

func (d *rackAwareRR) MaxHostTier() uint {
return 2
}
Expand Down
41 changes: 41 additions & 0 deletions policies_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//go:build integration && scylla
// +build integration,scylla

package gocql

import (
"testing"
)

// Check if session fail to start if DC name provided in the policy is wrong
func TestDCValidationTokenAware(t *testing.T) {
cluster := createCluster()

fallback := DCAwareRoundRobinPolicy("WRONG_DC")
cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(fallback)

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}

func TestDCValidationDCAware(t *testing.T) {
cluster := createCluster()
cluster.PoolConfig.HostSelectionPolicy = DCAwareRoundRobinPolicy("WRONG_DC")

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}

func TestDCValidationRackAware(t *testing.T) {
cluster := createCluster()
cluster.PoolConfig.HostSelectionPolicy = RackAwareRoundRobinPolicy("WRONG_DC", "RACK")

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}
17 changes: 13 additions & 4 deletions policies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,8 +601,8 @@ func TestHostPolicy_DCAwareRR(t *testing.T) {

}

func TestHostPolicy_DCAwareRR_wrongDc(t *testing.T) {
p := DCAwareRoundRobinPolicy("wrong_dc", HostPolicyOptionDisableDCFailover)
func TestHostPolicy_DCAwareRR_disableDCFailover(t *testing.T) {
p := DCAwareRoundRobinPolicy("local", HostPolicyOptionDisableDCFailover)

hosts := [...]*HostInfo{
{hostId: "0", connectAddress: net.ParseIP("10.0.0.1"), dataCenter: "local"},
Expand All @@ -616,19 +616,28 @@ func TestHostPolicy_DCAwareRR_wrongDc(t *testing.T) {
}

got := make(map[string]bool, len(hosts))
var dcs []string

it := p.Pick(nil)
for h := it(); h != nil; h = it() {
id := h.Info().hostId
dc := h.Info().dataCenter

if got[id] {
t.Fatalf("got duplicate host %s", id)
}
got[id] = true
dcs = append(dcs, dc)
}

if len(got) != 0 {
t.Fatalf("expected %d hosts got %d", 0, len(got))
if len(got) != 2 {
t.Fatalf("expected %d hosts got %d", 2, len(got))
}

for _, dc := range dcs {
if dc == "remote" {
t.Fatalf("got remote dc but failover was diabled")
}
}
}

Expand Down
8 changes: 8 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,14 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
}
}

var datacenters []string
for _, host := range s.ring.allHosts() {
datacenters = append(datacenters, host.DataCenter())
}
if !s.policy.IsDatacenterValid(datacenters) {
return nil, fmt.Errorf("gocql: unable to create session: datacenter provided in the policy is not found in the topology")
}

return s, nil
}

Expand Down

0 comments on commit ad02224

Please sign in to comment.