Skip to content

Commit

Permalink
Add the validation if session is operational
Browse files Browse the repository at this point in the history
Fail to connect if DC or rack name is different in the topology
than the one the user entered.
  • Loading branch information
sylwiaszunejko committed Jul 5, 2024
1 parent 0bd6283 commit 3b17292
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 4 deletions.
48 changes: 48 additions & 0 deletions policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ type HostSelectionPolicy interface {
// so it's safe to have internal state without additional synchronization as long as every call to Pick returns
// a different instance of NextHost.
Pick(ExecutableQuery) NextHost
// IsOperational checks if host policy can properly work with given Session/Cluster/ClusterConfig
IsOperational(*Session) error
}

// SelectedHost is an interface returned when picking a host from a host
Expand Down Expand Up @@ -363,6 +365,7 @@ func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {}
func (r *roundRobinHostPolicy) Init(*Session) {}
func (r *roundRobinHostPolicy) Reset() {}
func (r *roundRobinHostPolicy) IsOperational(*Session) error { return nil }

// Experimental, this interface and use may change
func (r *roundRobinHostPolicy) SetTablets(tablets []*TabletInfo) {}
Expand Down Expand Up @@ -489,6 +492,10 @@ func (t *tokenAwareHostPolicy) Reset() {
t.logger = nil
}

func (t *tokenAwareHostPolicy) IsOperational(session *Session) error {
return t.fallback.IsOperational(session)
}

func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool {
return t.fallback.IsLocal(host)
}
Expand Down Expand Up @@ -823,6 +830,7 @@ type hostPoolHostPolicy struct {

func (r *hostPoolHostPolicy) Init(*Session) {}
func (r *hostPoolHostPolicy) Reset() {}
func (r *hostPoolHostPolicy) IsOperational(*Session) error { return nil }
func (r *hostPoolHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *hostPoolHostPolicy) SetPartitioner(string) {}
func (r *hostPoolHostPolicy) IsLocal(*HostInfo) bool { return true }
Expand Down Expand Up @@ -984,6 +992,27 @@ func (d *dcAwareRR) Reset() {}
func (d *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *dcAwareRR) SetPartitioner(p string) {}

func (d *dcAwareRR) IsOperational(session *Session) error {
if session.cfg.disableInit || session.cfg.disableControlConn {
return nil
}

hosts, _, err := session.hostSource.GetHosts()
if err != nil {
return fmt.Errorf("gocql: unable to check if session is operational: %v", err)
}
for _, host := range hosts {
if !session.cfg.filterHost(host) && host.DataCenter() == d.local {
// Policy can work properly only if there is at least one host from target DC
// No need to check host status, since it could be down due to the outage
// We only need to make sure that policy is not misconfigured with wrong DC
return nil
}
}

return fmt.Errorf("gocql: datacenter %s in the policy was not found in the topology", d.local)
}

func (d *dcAwareRR) IsLocal(host *HostInfo) bool {
return host.DataCenter() == d.local
}
Expand Down Expand Up @@ -1088,6 +1117,25 @@ func (d *rackAwareRR) Reset() {}
func (d *rackAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *rackAwareRR) SetPartitioner(p string) {}

func (d *rackAwareRR) IsOperational(session *Session) error {
if session.cfg.disableInit || session.cfg.disableControlConn {
return nil
}
hosts, _, err := session.hostSource.GetHosts()
if err != nil {
return fmt.Errorf("gocql: unable to check if session is operational: %v", err)
}
for _, host := range hosts {
if !session.cfg.filterHost(host) && host.DataCenter() == d.localDC && host.Rack() == d.localRack {
// Policy can work properly only if there is at least one host from target DC+Rack
// No need to check host status, since it could be down due to the outage
// We only need to make sure that policy is not misconfigured with wrong DC+Rack
return nil
}
}
return fmt.Errorf("gocql: rack %s/%s was not found in the topology", d.localDC, d.localRack)
}

func (d *rackAwareRR) MaxHostTier() uint {
return 2
}
Expand Down
41 changes: 41 additions & 0 deletions policies_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//go:build integration && scylla
// +build integration,scylla

package gocql

import (
"testing"
)

// Check if session fail to start if DC name provided in the policy is wrong
func TestDCValidationTokenAware(t *testing.T) {
cluster := createCluster()

fallback := DCAwareRoundRobinPolicy("WRONG_DC")
cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(fallback)

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}

func TestDCValidationDCAware(t *testing.T) {
cluster := createCluster()
cluster.PoolConfig.HostSelectionPolicy = DCAwareRoundRobinPolicy("WRONG_DC")

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}

func TestDCValidationRackAware(t *testing.T) {
cluster := createCluster()
cluster.PoolConfig.HostSelectionPolicy = RackAwareRoundRobinPolicy("WRONG_DC", "RACK")

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}
17 changes: 13 additions & 4 deletions policies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,8 +601,8 @@ func TestHostPolicy_DCAwareRR(t *testing.T) {

}

func TestHostPolicy_DCAwareRR_wrongDc(t *testing.T) {
p := DCAwareRoundRobinPolicy("wrong_dc", HostPolicyOptionDisableDCFailover)
func TestHostPolicy_DCAwareRR_disableDCFailover(t *testing.T) {
p := DCAwareRoundRobinPolicy("local", HostPolicyOptionDisableDCFailover)

hosts := [...]*HostInfo{
{hostId: "0", connectAddress: net.ParseIP("10.0.0.1"), dataCenter: "local"},
Expand All @@ -616,19 +616,28 @@ func TestHostPolicy_DCAwareRR_wrongDc(t *testing.T) {
}

got := make(map[string]bool, len(hosts))
var dcs []string

it := p.Pick(nil)
for h := it(); h != nil; h = it() {
id := h.Info().hostId
dc := h.Info().dataCenter

if got[id] {
t.Fatalf("got duplicate host %s", id)
}
got[id] = true
dcs = append(dcs, dc)
}

if len(got) != 0 {
t.Fatalf("expected %d hosts got %d", 0, len(got))
if len(got) != 2 {
t.Fatalf("expected %d hosts got %d", 2, len(got))
}

for _, dc := range dcs {
if dc == "remote" {
t.Fatalf("got remote dc but failover was diabled")
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
}
}

if s.policy.IsOperational(s) != nil {
return nil, fmt.Errorf("gocql: unable to create session: %v", err)
}

return s, nil
}

Expand Down

0 comments on commit 3b17292

Please sign in to comment.