From 72e55189e6725817910e42ef8345cfdaa0858678 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Tue, 18 Jun 2024 12:06:42 +0200 Subject: [PATCH] ZDM-71: Introduce protocol negotiation --- integration-tests/connect_test.go | 5 ++--- integration-tests/setup/testcluster.go | 2 +- proxy/pkg/config/config.go | 12 ++++++------ proxy/pkg/zdmproxy/controlconn.go | 8 +++----- proxy/pkg/zdmproxy/cqlconn.go | 19 ++++++++++++------- 5 files changed, 24 insertions(+), 22 deletions(-) diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index 77fb11b..b3dbff4 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -53,7 +53,7 @@ func TestProtocolVersionNegotiation(t *testing.T) { env.CassandraVersion = testCassandraVersion }() c := setup.NewTestConfig("", "") - c.ProtocolVersion = 4 // configure unsupported protocol version + c.ControlConnMaxProtocolVersion = 4 // configure unsupported protocol version testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c) require.Nil(t, err) defer testSetup.Cleanup() @@ -62,8 +62,7 @@ func TestProtocolVersionNegotiation(t *testing.T) { proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3) if err != nil { - t.Log("Unable to connect to proxy session.") - t.Fatal(err) + t.Fatal("Unable to connect to proxy session.") } defer proxy.Close() diff --git a/integration-tests/setup/testcluster.go b/integration-tests/setup/testcluster.go index dac16ac..929850c 100644 --- a/integration-tests/setup/testcluster.go +++ b/integration-tests/setup/testcluster.go @@ -452,7 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config { conf.ReadMode = config.ReadModePrimaryOnly conf.SystemQueriesMode = config.SystemQueriesModeOrigin conf.AsyncHandshakeTimeoutMs = 4000 - conf.ProtocolVersion = 3 + conf.ControlConnMaxProtocolVersion = 3 conf.ProxyRequestTimeoutMs = 10000 diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index 6e6c402..c48b5e6 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -16,12 +16,12 @@ type Config struct { // Global bucket - PrimaryCluster string `default:"ORIGIN" split_words:"true"` - ReadMode string `default:"PRIMARY_ONLY" split_words:"true"` - ReplaceCqlFunctions bool `default:"false" split_words:"true"` - AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"` - LogLevel string `default:"INFO" split_words:"true"` - ProtocolVersion uint `default:"3" split_words:"true"` + PrimaryCluster string `default:"ORIGIN" split_words:"true"` + ReadMode string `default:"PRIMARY_ONLY" split_words:"true"` + ReplaceCqlFunctions bool `default:"false" split_words:"true"` + AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"` + LogLevel string `default:"INFO" split_words:"true"` + ControlConnMaxProtocolVersion uint `default:"3" split_words:"true"` // Proxy Topology (also known as system.peers "virtualization") bucket diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index e99f683..cef8bb3 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -54,7 +54,6 @@ type ControlConn struct { protocolEventSubscribers map[ProtocolEventObserver]interface{} authEnabled *atomic.Value metricsHandler *metrics.MetricHandler - protocolVersion primitive.ProtocolVersion } const ProxyVirtualRack = "rack0" @@ -321,7 +320,7 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( currentIndex := (firstEndpointIndex + i) % len(endpoints) endpoint = endpoints[currentIndex] - newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ProtocolVersion, ctx) + newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ControlConnMaxProtocolVersion, ctx) if err == nil { newConn.SetEventHandler(func(f *frame.Frame, c CqlConnection) { @@ -394,7 +393,6 @@ func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoV } continue // retry lower protocol version } else { - cc.protocolVersion = protoVer return newConn, err // we may have successfully established connection or faced other error } } @@ -429,7 +427,7 @@ func (cc *ControlConn) Close() { } func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([]*Host, error) { - localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx) + localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), ctx) if err != nil { return nil, fmt.Errorf("could not fetch information from system.local table: %w", err) } @@ -452,7 +450,7 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([] } } - peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx) + peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), ctx) if err != nil { return nil, fmt.Errorf("could not fetch information from system.peers table: %w", err) } diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index c8a6e43..d7bb7a6 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -14,6 +14,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "time" ) @@ -32,7 +33,7 @@ type CqlConnection interface { SendAndReceive(request *frame.Frame, ctx context.Context) (*frame.Frame, error) Close() error Execute(msg message.Message, ctx context.Context) (message.Message, error) - Query(cql string, genericTypeCodec *GenericTypeCodec, version primitive.ProtocolVersion, ctx context.Context) (*ParsedRowSet, error) + Query(cql string, genericTypeCodec *GenericTypeCodec, ctx context.Context) (*ParsedRowSet, error) SendHeartbeat(ctx context.Context) error SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) SubscribeToProtocolEvents(ctx context.Context, eventTypes []primitive.EventType) error @@ -59,7 +60,7 @@ type cqlConn struct { eventHandlerLock *sync.Mutex authEnabled bool frameProcessor FrameProcessor - protocolVersion primitive.ProtocolVersion + protocolVersion *atomic.Value } var ( @@ -238,7 +239,8 @@ func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx conte return fmt.Errorf("failed to perform handshake: %w", err) } - c.protocolVersion = version + c.protocolVersion = &atomic.Value{} + c.protocolVersion.Store(version) c.initialized = true c.authEnabled = authEnabled return nil @@ -369,7 +371,7 @@ func (c *cqlConn) PerformHandshake(version primitive.ProtocolVersion, ctx contex } func (c *cqlConn) Query( - cql string, genericTypeCodec *GenericTypeCodec, version primitive.ProtocolVersion, ctx context.Context) (*ParsedRowSet, error) { + cql string, genericTypeCodec *GenericTypeCodec, ctx context.Context) (*ParsedRowSet, error) { queryMsg := &message.Query{ Query: cql, Options: &message.QueryOptions{ @@ -377,7 +379,8 @@ func (c *cqlConn) Query( }, } - queryFrame := frame.NewFrame(c.protocolVersion, -1, queryMsg) + version := c.protocolVersion.Load().(primitive.ProtocolVersion) + queryFrame := frame.NewFrame(version, -1, queryMsg) var rowSet *ParsedRowSet for { localResponse, err := c.SendAndReceive(queryFrame, ctx) @@ -431,7 +434,8 @@ func (c *cqlConn) Query( } func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Message, error) { - queryFrame := frame.NewFrame(c.protocolVersion, -1, msg) + version := c.protocolVersion.Load().(primitive.ProtocolVersion) + queryFrame := frame.NewFrame(version, -1, msg) localResponse, err := c.SendAndReceive(queryFrame, ctx) if err != nil { return nil, err @@ -442,7 +446,8 @@ func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Mes func (c *cqlConn) SendHeartbeat(ctx context.Context) error { optionsMsg := &message.Options{} - heartBeatFrame := frame.NewFrame(c.protocolVersion, -1, optionsMsg) + version := c.protocolVersion.Load().(primitive.ProtocolVersion) + heartBeatFrame := frame.NewFrame(version, -1, optionsMsg) response, err := c.SendAndReceive(heartBeatFrame, ctx) if err != nil {