Skip to content

Commit

Permalink
Validation of protocol version
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-antoniak committed Jun 20, 2024
1 parent a244c8b commit 3002a92
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 27 deletions.
71 changes: 52 additions & 19 deletions integration-tests/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,64 @@ func TestGoCqlConnect(t *testing.T) {
}

func TestProtocolVersionNegotiation(t *testing.T) {
c := setup.NewTestConfig("", "")
c.ControlConnMaxProtocolVersion = 4 // configure unsupported protocol version
testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, true, false, 1, c, &simulacron.ClusterVersion{"2.1", "2.1"})
require.Nil(t, err)
defer testSetup.Cleanup()
tests := []struct {
name string
clusterVersion string
controlConnMaxProtocolVersion string
negotiatedProtocolVersion primitive.ProtocolVersion
}{
{
name: "Cluster2.1_MaxCCProtoVer4_NegotiatedProtoVer3",
clusterVersion: "2.1",
controlConnMaxProtocolVersion: "4",
negotiatedProtocolVersion: primitive.ProtocolVersion3, // protocol downgraded to V3, V4 is not supported
},
{
name: "Cluster3.0_MaxCCProtoVer4_NegotiatedProtoVer4",
clusterVersion: "3.0",
controlConnMaxProtocolVersion: "4",
negotiatedProtocolVersion: primitive.ProtocolVersion4,
},
{
name: "Cluster4.0_MaxCCProtoVer4_NegotiatedProtoVer4",
clusterVersion: "4.0",
controlConnMaxProtocolVersion: "4",
negotiatedProtocolVersion: primitive.ProtocolVersion4,
},
}

// Connect to proxy as a "client"
proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := setup.NewTestConfig("", "")
c.ControlConnMaxProtocolVersion = tt.controlConnMaxProtocolVersion
testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, true, false, 1, c,
&simulacron.ClusterVersion{tt.clusterVersion, tt.clusterVersion})
require.Nil(t, err)
defer testSetup.Cleanup()

if err != nil {
t.Fatal("Unable to connect to proxy session.")
}
defer proxy.Close()
// Connect to proxy as a "client"
proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3)

iter := proxy.Query("SELECT * FROM fakeks.faketb").Iter()
result, err := iter.SliceMap()
if err != nil {
t.Fatal("Unable to connect to proxy session.")
}
defer proxy.Close()

if err != nil {
t.Fatal("query failed:", err)
}
cqlConn, _ := testSetup.Proxy.GetOriginControlConn().GetConnAndContactPoint()
negotiatedProto := cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion)

require.Equal(t, 0, len(result))
require.Equal(t, tt.negotiatedProtocolVersion, negotiatedProto)

// simulacron generates fake response metadata when queries aren't primed
require.Equal(t, "fake", iter.Columns()[0].Name)
iter := proxy.Query("SELECT * FROM fakeks.faketb").Iter()
result, err := iter.SliceMap()

if err != nil {
t.Fatal("query failed:", err)
}

require.Equal(t, 0, len(result))
})
}
}

func TestMaxClientsThreshold(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion integration-tests/setup/testcluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config {
conf.ReadMode = config.ReadModePrimaryOnly
conf.SystemQueriesMode = config.SystemQueriesModeOrigin
conf.AsyncHandshakeTimeoutMs = 4000
conf.ControlConnMaxProtocolVersion = 3
conf.ControlConnMaxProtocolVersion = "3"

conf.ProxyRequestTimeoutMs = 10000

Expand Down
25 changes: 24 additions & 1 deletion proxy/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type Config struct {
ReplaceCqlFunctions bool `default:"false" split_words:"true"`
AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"`
LogLevel string `default:"INFO" split_words:"true"`
ControlConnMaxProtocolVersion uint `default:"3" split_words:"true"`
ControlConnMaxProtocolVersion string `default:"3" split_words:"true"` // Numeric Cassandra OSS protocol version or Dse1 / Dse2

// Proxy Topology (also known as system.peers "virtualization") bucket

Expand Down Expand Up @@ -283,6 +283,11 @@ func (c *Config) Validate() error {
return err
}

_, err = c.ParseControlConnMaxProtocolVersion()
if err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -337,6 +342,24 @@ func (c *Config) ParseReadMode() (common.ReadMode, error) {
}
}

func (c *Config) ParseControlConnMaxProtocolVersion() (uint, error) {
switch c.ControlConnMaxProtocolVersion {
case "Dse2":
return 0b_1_000010, nil
case "Dse1":
return 0b_1_000001, nil
}
ver, err := strconv.ParseUint(c.ControlConnMaxProtocolVersion, 10, 32)
if err != nil {
return 0, fmt.Errorf("could not parse control connection max protocol version, valid values are "+
"2, 3, 4, Dse1, Dse2; original err: %w", err)
}
if ver < 2 || ver > 4 {
return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2")
}
return uint(ver), nil
}

func (c *Config) ParseLogLevel() (log.Level, error) {
level, err := log.ParseLevel(strings.TrimSpace(c.LogLevel))
if err != nil {
Expand Down
84 changes: 84 additions & 0 deletions proxy/pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,87 @@ func TestTargetConfig_WithHostnameButWithoutPort(t *testing.T) {
require.Nil(t, err)
require.Equal(t, 9042, c.TargetPort)
}

func TestTargetConfig_ParsingControlConnMaxProtocolVersion(t *testing.T) {
defer clearAllEnvVars()

// general setup
clearAllEnvVars()
setOriginCredentialsEnvVars()
setTargetCredentialsEnvVars()
setOriginContactPointsAndPortEnvVars()

// test-specific setup
setTargetContactPointsAndPortEnvVars()

conf, _ := New().ParseEnvVars()

tests := []struct {
name string
controlConnMaxProtocolVersion string
parsedProtocolVersion uint
errorMessage string
}{
{
name: "ParsedV2",
controlConnMaxProtocolVersion: "2",
parsedProtocolVersion: 2,
errorMessage: "",
},
{
name: "ParsedV3",
controlConnMaxProtocolVersion: "3",
parsedProtocolVersion: 3,
errorMessage: "",
},
{
name: "ParsedV4",
controlConnMaxProtocolVersion: "4",
parsedProtocolVersion: 4,
errorMessage: "",
},
{
name: "ParsedDse1",
controlConnMaxProtocolVersion: "Dse1",
parsedProtocolVersion: 65,
errorMessage: "",
},
{
name: "ParsedDse2",
controlConnMaxProtocolVersion: "Dse2",
parsedProtocolVersion: 66,
errorMessage: "",
},
{
name: "UnsupportedCassandraV5",
controlConnMaxProtocolVersion: "5",
parsedProtocolVersion: 0,
errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2",
},
{
name: "UnsupportedCassandraV1",
controlConnMaxProtocolVersion: "1",
parsedProtocolVersion: 0,
errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2",
},
{
name: "InvalidValue",
controlConnMaxProtocolVersion: "Dsev123",
parsedProtocolVersion: 0,
errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf.ControlConnMaxProtocolVersion = tt.controlConnMaxProtocolVersion
ver, err := conf.ParseControlConnMaxProtocolVersion()
if ver == 0 {
require.NotNil(t, err)
require.Contains(t, err.Error(), tt.errorMessage)
} else {
require.Equal(t, tt.parsedProtocolVersion, ver)
}
})
}
}
11 changes: 6 additions & 5 deletions proxy/pkg/zdmproxy/controlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (cc *ControlConn) Start(wg *sync.WaitGroup, ctx context.Context) error {

log.Infof("Received topology event from %v, refreshing topology.", cc.connConfig.GetClusterType())

conn, _ := cc.getConnAndContactPoint()
conn, _ := cc.GetConnAndContactPoint()
if conn == nil {
log.Debugf("Topology refresh scheduled but the control connection isn't open. " +
"Falling back to the connection where the event was received.")
Expand Down Expand Up @@ -162,7 +162,7 @@ func (cc *ControlConn) Start(wg *sync.WaitGroup, ctx context.Context) error {
cc.Close()
}

conn, _ := cc.getConnAndContactPoint()
conn, _ := cc.GetConnAndContactPoint()
if conn == nil {
useContactPointsOnly := false
if !lastOpenSuccessful {
Expand Down Expand Up @@ -251,7 +251,7 @@ func (cc *ControlConn) ReadFailureCounter() int {
}

func (cc *ControlConn) Open(contactPointsOnly bool, ctx context.Context) (CqlConnection, error) {
oldConn, _ := cc.getConnAndContactPoint()
oldConn, _ := cc.GetConnAndContactPoint()
if oldConn != nil {
cc.Close()
oldConn = nil
Expand Down Expand Up @@ -321,7 +321,8 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) (
currentIndex := (firstEndpointIndex + i) % len(endpoints)
endpoint = endpoints[currentIndex]

newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ControlConnMaxProtocolVersion, ctx)
maxProtoVer, _ := cc.conf.ParseControlConnMaxProtocolVersion()
newConn, err := cc.connAndNegotiateProtoVer(endpoint, maxProtoVer, ctx)

if err == nil {
newConn.SetEventHandler(func(f *frame.Frame, c CqlConnection) {
Expand Down Expand Up @@ -678,7 +679,7 @@ func (cc *ControlConn) setConn(oldConn CqlConnection, newConn CqlConnection, new
return cc.cqlConn, cc.currentContactPoint
}

func (cc *ControlConn) getConnAndContactPoint() (CqlConnection, Endpoint) {
func (cc *ControlConn) GetConnAndContactPoint() (CqlConnection, Endpoint) {
cc.cqlConnLock.Lock()
conn := cc.cqlConn
contactPoint := cc.currentContactPoint
Expand Down
7 changes: 6 additions & 1 deletion proxy/pkg/zdmproxy/cqlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type CqlConnection interface {
SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection))
SubscribeToProtocolEvents(ctx context.Context, eventTypes []primitive.EventType) error
IsAuthEnabled() (bool, error)
GetProtocolVersion() *atomic.Value
}

// Not thread safe
Expand Down Expand Up @@ -98,6 +99,7 @@ func NewCqlConnection(
eventHandlerLock: &sync.Mutex{},
authEnabled: true,
frameProcessor: NewStreamIdProcessor(NewInternalStreamIdMapper(conf.ProxyMaxStreamIds, nil)),
protocolVersion: &atomic.Value{},
}
cqlConn.StartRequestLoop()
cqlConn.StartResponseLoop()
Expand Down Expand Up @@ -233,13 +235,16 @@ func (c *cqlConn) IsAuthEnabled() (bool, error) {
return c.authEnabled, nil
}

func (c *cqlConn) GetProtocolVersion() *atomic.Value {
return c.protocolVersion
}

func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx context.Context) error {
authEnabled, err := c.PerformHandshake(version, ctx)
if err != nil {
return fmt.Errorf("failed to perform handshake: %w", err)
}

c.protocolVersion = &atomic.Value{}
c.protocolVersion.Store(version)
c.initialized = true
c.authEnabled = authEnabled
Expand Down

0 comments on commit 3002a92

Please sign in to comment.