Skip to content

Commit

Permalink
ZDM-71: Introduce protocol negotiation
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-antoniak committed Jun 18, 2024
1 parent 9257cbd commit 72e5518
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 22 deletions.
5 changes: 2 additions & 3 deletions integration-tests/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestProtocolVersionNegotiation(t *testing.T) {
env.CassandraVersion = testCassandraVersion
}()
c := setup.NewTestConfig("", "")
c.ProtocolVersion = 4 // configure unsupported protocol version
c.ControlConnMaxProtocolVersion = 4 // configure unsupported protocol version
testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c)
require.Nil(t, err)
defer testSetup.Cleanup()
Expand All @@ -62,8 +62,7 @@ func TestProtocolVersionNegotiation(t *testing.T) {
proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3)

if err != nil {
t.Log("Unable to connect to proxy session.")
t.Fatal(err)
t.Fatal("Unable to connect to proxy session.")
}
defer proxy.Close()

Expand Down
2 changes: 1 addition & 1 deletion integration-tests/setup/testcluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config {
conf.ReadMode = config.ReadModePrimaryOnly
conf.SystemQueriesMode = config.SystemQueriesModeOrigin
conf.AsyncHandshakeTimeoutMs = 4000
conf.ProtocolVersion = 3
conf.ControlConnMaxProtocolVersion = 3

conf.ProxyRequestTimeoutMs = 10000

Expand Down
12 changes: 6 additions & 6 deletions proxy/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ type Config struct {

// Global bucket

PrimaryCluster string `default:"ORIGIN" split_words:"true"`
ReadMode string `default:"PRIMARY_ONLY" split_words:"true"`
ReplaceCqlFunctions bool `default:"false" split_words:"true"`
AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"`
LogLevel string `default:"INFO" split_words:"true"`
ProtocolVersion uint `default:"3" split_words:"true"`
PrimaryCluster string `default:"ORIGIN" split_words:"true"`
ReadMode string `default:"PRIMARY_ONLY" split_words:"true"`
ReplaceCqlFunctions bool `default:"false" split_words:"true"`
AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"`
LogLevel string `default:"INFO" split_words:"true"`
ControlConnMaxProtocolVersion uint `default:"3" split_words:"true"`

// Proxy Topology (also known as system.peers "virtualization") bucket

Expand Down
8 changes: 3 additions & 5 deletions proxy/pkg/zdmproxy/controlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ type ControlConn struct {
protocolEventSubscribers map[ProtocolEventObserver]interface{}
authEnabled *atomic.Value
metricsHandler *metrics.MetricHandler
protocolVersion primitive.ProtocolVersion
}

const ProxyVirtualRack = "rack0"
Expand Down Expand Up @@ -321,7 +320,7 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) (
currentIndex := (firstEndpointIndex + i) % len(endpoints)
endpoint = endpoints[currentIndex]

newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ProtocolVersion, ctx)
newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ControlConnMaxProtocolVersion, ctx)

if err == nil {
newConn.SetEventHandler(func(f *frame.Frame, c CqlConnection) {
Expand Down Expand Up @@ -394,7 +393,6 @@ func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoV
}
continue // retry lower protocol version
} else {
cc.protocolVersion = protoVer
return newConn, err // we may have successfully established connection or faced other error
}
}
Expand Down Expand Up @@ -429,7 +427,7 @@ func (cc *ControlConn) Close() {
}

func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([]*Host, error) {
localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx)
localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), ctx)
if err != nil {
return nil, fmt.Errorf("could not fetch information from system.local table: %w", err)
}
Expand All @@ -452,7 +450,7 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([]
}
}

peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx)
peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), ctx)
if err != nil {
return nil, fmt.Errorf("could not fetch information from system.peers table: %w", err)
}
Expand Down
19 changes: 12 additions & 7 deletions proxy/pkg/zdmproxy/cqlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -32,7 +33,7 @@ type CqlConnection interface {
SendAndReceive(request *frame.Frame, ctx context.Context) (*frame.Frame, error)
Close() error
Execute(msg message.Message, ctx context.Context) (message.Message, error)
Query(cql string, genericTypeCodec *GenericTypeCodec, version primitive.ProtocolVersion, ctx context.Context) (*ParsedRowSet, error)
Query(cql string, genericTypeCodec *GenericTypeCodec, ctx context.Context) (*ParsedRowSet, error)
SendHeartbeat(ctx context.Context) error
SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection))
SubscribeToProtocolEvents(ctx context.Context, eventTypes []primitive.EventType) error
Expand All @@ -59,7 +60,7 @@ type cqlConn struct {
eventHandlerLock *sync.Mutex
authEnabled bool
frameProcessor FrameProcessor
protocolVersion primitive.ProtocolVersion
protocolVersion *atomic.Value
}

var (
Expand Down Expand Up @@ -238,7 +239,8 @@ func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx conte
return fmt.Errorf("failed to perform handshake: %w", err)
}

c.protocolVersion = version
c.protocolVersion = &atomic.Value{}
c.protocolVersion.Store(version)
c.initialized = true
c.authEnabled = authEnabled
return nil
Expand Down Expand Up @@ -369,15 +371,16 @@ func (c *cqlConn) PerformHandshake(version primitive.ProtocolVersion, ctx contex
}

func (c *cqlConn) Query(
cql string, genericTypeCodec *GenericTypeCodec, version primitive.ProtocolVersion, ctx context.Context) (*ParsedRowSet, error) {
cql string, genericTypeCodec *GenericTypeCodec, ctx context.Context) (*ParsedRowSet, error) {
queryMsg := &message.Query{
Query: cql,
Options: &message.QueryOptions{
Consistency: primitive.ConsistencyLevelLocalQuorum,
},
}

queryFrame := frame.NewFrame(c.protocolVersion, -1, queryMsg)
version := c.protocolVersion.Load().(primitive.ProtocolVersion)
queryFrame := frame.NewFrame(version, -1, queryMsg)
var rowSet *ParsedRowSet
for {
localResponse, err := c.SendAndReceive(queryFrame, ctx)
Expand Down Expand Up @@ -431,7 +434,8 @@ func (c *cqlConn) Query(
}

func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Message, error) {
queryFrame := frame.NewFrame(c.protocolVersion, -1, msg)
version := c.protocolVersion.Load().(primitive.ProtocolVersion)
queryFrame := frame.NewFrame(version, -1, msg)
localResponse, err := c.SendAndReceive(queryFrame, ctx)
if err != nil {
return nil, err
Expand All @@ -442,7 +446,8 @@ func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Mes

func (c *cqlConn) SendHeartbeat(ctx context.Context) error {
optionsMsg := &message.Options{}
heartBeatFrame := frame.NewFrame(c.protocolVersion, -1, optionsMsg)
version := c.protocolVersion.Load().(primitive.ProtocolVersion)
heartBeatFrame := frame.NewFrame(version, -1, optionsMsg)

response, err := c.SendAndReceive(heartBeatFrame, ctx)
if err != nil {
Expand Down

0 comments on commit 72e5518

Please sign in to comment.