Skip to content

Commit

Permalink
ZDM-71: Introduce protocol negotiation
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-antoniak committed Jun 17, 2024
1 parent a6a5fc1 commit 9257cbd
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 19 deletions.
35 changes: 35 additions & 0 deletions integration-tests/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/datastax/go-cassandra-native-protocol/message"
"github.com/datastax/go-cassandra-native-protocol/primitive"
"github.com/datastax/zdm-proxy/integration-tests/client"
"github.com/datastax/zdm-proxy/integration-tests/env"
"github.com/datastax/zdm-proxy/integration-tests/setup"
"github.com/datastax/zdm-proxy/integration-tests/utils"
"github.com/datastax/zdm-proxy/proxy/pkg/config"
Expand Down Expand Up @@ -45,6 +46,40 @@ func TestGoCqlConnect(t *testing.T) {
require.Equal(t, "fake", iter.Columns()[0].Name)
}

func TestProtocolVersionNegotiation(t *testing.T) {
testCassandraVersion := env.CassandraVersion
env.CassandraVersion = "2.1" // downgrade C* version for protocol negotiation test
defer func() {
env.CassandraVersion = testCassandraVersion
}()
c := setup.NewTestConfig("", "")
c.ProtocolVersion = 4 // configure unsupported protocol version
testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c)
require.Nil(t, err)
defer testSetup.Cleanup()

// Connect to proxy as a "client"
proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3)

if err != nil {
t.Log("Unable to connect to proxy session.")
t.Fatal(err)
}
defer proxy.Close()

iter := proxy.Query("SELECT * FROM fakeks.faketb").Iter()
result, err := iter.SliceMap()

if err != nil {
t.Fatal("query failed:", err)
}

require.Equal(t, 0, len(result))

// simulacron generates fake response metadata when queries aren't primed
require.Equal(t, "fake", iter.Columns()[0].Name)
}

func TestMaxClientsThreshold(t *testing.T) {
maxClients := 10
goCqlConnectionsPerHost := 1
Expand Down
1 change: 1 addition & 0 deletions integration-tests/setup/testcluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config {
conf.ReadMode = config.ReadModePrimaryOnly
conf.SystemQueriesMode = config.SystemQueriesModeOrigin
conf.AsyncHandshakeTimeoutMs = 4000
conf.ProtocolVersion = 3

conf.ProxyRequestTimeoutMs = 10000

Expand Down
9 changes: 7 additions & 2 deletions integration-tests/utils/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ func CheckMetricsEndpointResult(httpAddr string, success bool) error {
return nil
}

// ConnectToCluster is used to connect to source and destination clusters
func ConnectToCluster(hostname string, username string, password string, port int) (*gocql.Session, error) {
func ConnectToClusterUsingVersion(hostname string, username string, password string, port int, protoVersion int) (*gocql.Session, error) {
cluster := NewCluster(hostname, username, password, port)
cluster.ProtoVersion = protoVersion
session, err := cluster.CreateSession()
log.Debugf("Connection established with Cluster: %s:%d", cluster.Hosts[0], cluster.Port)
if err != nil {
Expand All @@ -127,6 +127,11 @@ func ConnectToCluster(hostname string, username string, password string, port in
return session, nil
}

// ConnectToCluster is used to connect to source and destination clusters
func ConnectToCluster(hostname string, username string, password string, port int) (*gocql.Session, error) {
return ConnectToClusterUsingVersion(hostname, username, password, port, 4)
}

// NewCluster initializes a ClusterConfig object with common settings
func NewCluster(hostname string, username string, password string, port int) *gocql.ClusterConfig {
cluster := gocql.NewCluster(hostname)
Expand Down
1 change: 1 addition & 0 deletions proxy/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Config struct {
ReplaceCqlFunctions bool `default:"false" split_words:"true"`
AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"`
LogLevel string `default:"INFO" split_words:"true"`
ProtocolVersion uint `default:"3" split_words:"true"`

// Proxy Topology (also known as system.peers "virtualization") bucket

Expand Down
70 changes: 56 additions & 14 deletions proxy/pkg/zdmproxy/controlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ type ControlConn struct {
protocolEventSubscribers map[ProtocolEventObserver]interface{}
authEnabled *atomic.Value
metricsHandler *metrics.MetricHandler
protocolVersion primitive.ProtocolVersion
}

const ProxyVirtualRack = "rack0"
const ProxyVirtualPartitioner = "org.apache.cassandra.dht.Murmur3Partitioner"
const ccProtocolVersion = primitive.ProtocolVersion3
const ccWriteTimeout = 5 * time.Second
const ccReadTimeout = 10 * time.Second

Expand Down Expand Up @@ -320,15 +320,9 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) (

currentIndex := (firstEndpointIndex + i) % len(endpoints)
endpoint = endpoints[currentIndex]
tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false)
if err != nil {
log.Warnf("Failed to open control connection to %v using endpoint %v: %v",
cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err)
continue
}

newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf)
err = newConn.InitializeContext(ccProtocolVersion, ctx)
newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ProtocolVersion, ctx)

if err == nil {
newConn.SetEventHandler(func(f *frame.Frame, c CqlConnection) {
switch f.Body.Message.(type) {
Expand All @@ -355,9 +349,11 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) (
log.Warnf("Error while initializing a new cql connection for the control connection of %v: %v",
cc.connConfig.GetClusterType(), err)
}
err2 := newConn.Close()
if err2 != nil {
log.Errorf("Failed to close cql connection: %v", err2)
if newConn != nil {
err2 := newConn.Close()
if err2 != nil {
log.Errorf("Failed to close cql connection: %v", err2)
}
}

continue
Expand All @@ -372,6 +368,52 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) (
return conn, endpoint
}

func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoVer uint, ctx context.Context) (CqlConnection, error) {
protoVer := primitive.ProtocolVersion(initialProtoVer)
for {
tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false)
if err != nil {
log.Warnf("Failed to open control connection to %v using endpoint %v: %v",
cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err)
return nil, err
}
newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf)
err = newConn.InitializeContext(protoVer, ctx)
if err != nil && strings.Contains(err.Error(), "Invalid or unsupported protocol version") {
// unsupported protocol version
// protocol renegotiation requires opening a new TCP connection
err2 := newConn.Close()
if err2 != nil {
log.Errorf("Failed to close cql connection: %v", err2)
}
protoVer = downgradeProtocol(protoVer)
log.Infof("Downgrading protocol version: %v", protoVer)
if protoVer == 0 {
// we cannot downgrade anymore
return nil, err
}
continue // retry lower protocol version
} else {
cc.protocolVersion = protoVer
return newConn, err // we may have successfully established connection or faced other error
}
}
}

func downgradeProtocol(version primitive.ProtocolVersion) primitive.ProtocolVersion {
switch version {
case primitive.ProtocolVersionDse2:
return primitive.ProtocolVersionDse1
case primitive.ProtocolVersionDse1:
return primitive.ProtocolVersion4
case primitive.ProtocolVersion4:
return primitive.ProtocolVersion3
case primitive.ProtocolVersion3:
return primitive.ProtocolVersion2
}
return 0
}

func (cc *ControlConn) Close() {
cc.cqlConnLock.Lock()
conn := cc.cqlConn
Expand All @@ -387,7 +429,7 @@ func (cc *ControlConn) Close() {
}

func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([]*Host, error) {
localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), ccProtocolVersion, ctx)
localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx)
if err != nil {
return nil, fmt.Errorf("could not fetch information from system.local table: %w", err)
}
Expand All @@ -410,7 +452,7 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([]
}
}

peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), ccProtocolVersion, ctx)
peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx)
if err != nil {
return nil, fmt.Errorf("could not fetch information from system.peers table: %w", err)
}
Expand Down
8 changes: 5 additions & 3 deletions proxy/pkg/zdmproxy/cqlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type cqlConn struct {
eventHandlerLock *sync.Mutex
authEnabled bool
frameProcessor FrameProcessor
protocolVersion primitive.ProtocolVersion
}

var (
Expand Down Expand Up @@ -237,6 +238,7 @@ func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx conte
return fmt.Errorf("failed to perform handshake: %w", err)
}

c.protocolVersion = version
c.initialized = true
c.authEnabled = authEnabled
return nil
Expand Down Expand Up @@ -375,7 +377,7 @@ func (c *cqlConn) Query(
},
}

queryFrame := frame.NewFrame(ccProtocolVersion, -1, queryMsg)
queryFrame := frame.NewFrame(c.protocolVersion, -1, queryMsg)
var rowSet *ParsedRowSet
for {
localResponse, err := c.SendAndReceive(queryFrame, ctx)
Expand Down Expand Up @@ -429,7 +431,7 @@ func (c *cqlConn) Query(
}

func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Message, error) {
queryFrame := frame.NewFrame(ccProtocolVersion, -1, msg)
queryFrame := frame.NewFrame(c.protocolVersion, -1, msg)
localResponse, err := c.SendAndReceive(queryFrame, ctx)
if err != nil {
return nil, err
Expand All @@ -440,7 +442,7 @@ func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Mes

func (c *cqlConn) SendHeartbeat(ctx context.Context) error {
optionsMsg := &message.Options{}
heartBeatFrame := frame.NewFrame(ccProtocolVersion, -1, optionsMsg)
heartBeatFrame := frame.NewFrame(c.protocolVersion, -1, optionsMsg)

response, err := c.SendAndReceive(heartBeatFrame, ctx)
if err != nil {
Expand Down

0 comments on commit 9257cbd

Please sign in to comment.