Skip to content

Commit

Permalink
Apply review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-antoniak committed Jun 28, 2024
1 parent 58680ad commit 7247e64
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 44 deletions.
14 changes: 8 additions & 6 deletions integration-tests/customhandler_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ var systemLocalColumns = []*message.ColumnMetadata{
tokensColumn,
}

var systemLocalColumnsV2 = []*message.ColumnMetadata{
// These columns are a subset of the total columns returned by OSS C* 2.0.0, and contain all the information that
// drivers need in order to establish the cluster topology and determine its characteristics. Please note that RPC address
// column is not present.
var systemLocalColumnsProtocolV2 = []*message.ColumnMetadata{
keyColumn,
clusterNameColumn,
cqlVersionColumn,
Expand Down Expand Up @@ -125,10 +128,10 @@ var (
schemaVersionValue = message.Column{0xC0, 0xD1, 0xD2, 0x1E, 0xBB, 0x01, 0x41, 0x96, 0x86, 0xDB, 0xBC, 0x31, 0x7B, 0xC1, 0x79, 0x6A}
)

func systemLocalRow(cluster string, datacenter string, customPartitioner string, addr *net.Addr, version primitive.ProtocolVersion) message.Row {
func systemLocalRow(cluster string, datacenter string, customPartitioner string, addr net.Addr, version primitive.ProtocolVersion) message.Row {
addrBuf := &bytes.Buffer{}
if addr != nil {
inetAddr := (*addr).(*net.TCPAddr).IP
inetAddr := addr.(*net.TCPAddr).IP
if inetAddr.To4() != nil {
addrBuf.Write(inetAddr.To4())
} else {
Expand All @@ -150,7 +153,7 @@ func systemLocalRow(cluster string, datacenter string, customPartitioner string,
if customPartitioner != "" {
partitionerValue = message.Column(customPartitioner)
}
if addrBuf.Len() > 0 {
if version >= primitive.ProtocolVersion3 {
return message.Row{
keyValue,
addrBuf.Bytes(),
Expand Down Expand Up @@ -182,8 +185,7 @@ func systemLocalRow(cluster string, datacenter string, customPartitioner string,
}

func fullSystemLocal(cluster string, datacenter string, customPartitioner string, request *frame.Frame, conn *client.CqlServerConnection) *frame.Frame {
localAddress := conn.LocalAddr()
systemLocalRow := systemLocalRow(cluster, datacenter, customPartitioner, &localAddress, request.Header.Version)
systemLocalRow := systemLocalRow(cluster, datacenter, customPartitioner, conn.LocalAddr(), request.Header.Version)
msg := &message.RowsResult{
Metadata: &message.RowsMetadata{
ColumnCount: int32(len(systemLocalColumns)),
Expand Down
4 changes: 2 additions & 2 deletions integration-tests/protocolv2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ func (recv *ProtocolV2RequestHandler) HandleRequest(
sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version)
sysLocMsg := &message.RowsResult{
Metadata: &message.RowsMetadata{
ColumnCount: int32(len(systemLocalColumnsV2)),
Columns: systemLocalColumnsV2,
ColumnCount: int32(len(systemLocalColumnsProtocolV2)),
Columns: systemLocalColumnsProtocolV2,
},
Data: message.RowSet{sysLocRow},
}
Expand Down
19 changes: 10 additions & 9 deletions proxy/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package config
import (
"encoding/json"
"fmt"
"github.com/datastax/go-cassandra-native-protocol/primitive"
"github.com/datastax/zdm-proxy/proxy/pkg/common"
"github.com/kelseyhightower/envconfig"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -342,22 +343,22 @@ func (c *Config) ParseReadMode() (common.ReadMode, error) {
}
}

func (c *Config) ParseControlConnMaxProtocolVersion() (uint, error) {
switch c.ControlConnMaxProtocolVersion {
case "Dse2":
return 0b_1_000010, nil
case "Dse1":
return 0b_1_000001, nil
func (c *Config) ParseControlConnMaxProtocolVersion() (primitive.ProtocolVersion, error) {
if strings.EqualFold(c.ControlConnMaxProtocolVersion, "DseV2") {
return primitive.ProtocolVersionDse2, nil
}
if strings.EqualFold(c.ControlConnMaxProtocolVersion, "DseV1") {
return primitive.ProtocolVersionDse1, nil
}
ver, err := strconv.ParseUint(c.ControlConnMaxProtocolVersion, 10, 32)
if err != nil {
return 0, fmt.Errorf("could not parse control connection max protocol version, valid values are "+
"2, 3, 4, Dse1, Dse2; original err: %w", err)
"2, 3, 4, DseV1, DseV2; original err: %w", err)
}
if ver < 2 || ver > 4 {
return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2")
return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2")
}
return uint(ver), nil
return primitive.ProtocolVersion(ver), nil
}

func (c *Config) ParseLogLevel() (log.Level, error) {
Expand Down
29 changes: 18 additions & 11 deletions proxy/pkg/config/config_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package config

import (
"github.com/datastax/go-cassandra-native-protocol/primitive"
"github.com/stretchr/testify/require"
"testing"
)
Expand Down Expand Up @@ -111,56 +112,62 @@ func TestTargetConfig_ParsingControlConnMaxProtocolVersion(t *testing.T) {
tests := []struct {
name string
controlConnMaxProtocolVersion string
parsedProtocolVersion uint
parsedProtocolVersion primitive.ProtocolVersion
errorMessage string
}{
{
name: "ParsedV2",
controlConnMaxProtocolVersion: "2",
parsedProtocolVersion: 2,
parsedProtocolVersion: primitive.ProtocolVersion2,
errorMessage: "",
},
{
name: "ParsedV3",
controlConnMaxProtocolVersion: "3",
parsedProtocolVersion: 3,
parsedProtocolVersion: primitive.ProtocolVersion3,
errorMessage: "",
},
{
name: "ParsedV4",
controlConnMaxProtocolVersion: "4",
parsedProtocolVersion: 4,
parsedProtocolVersion: primitive.ProtocolVersion4,
errorMessage: "",
},
{
name: "ParsedDse1",
controlConnMaxProtocolVersion: "Dse1",
parsedProtocolVersion: 65,
controlConnMaxProtocolVersion: "DseV1",
parsedProtocolVersion: primitive.ProtocolVersionDse1,
errorMessage: "",
},
{
name: "ParsedDse2",
controlConnMaxProtocolVersion: "Dse2",
parsedProtocolVersion: 66,
controlConnMaxProtocolVersion: "DseV2",
parsedProtocolVersion: primitive.ProtocolVersionDse2,
errorMessage: "",
},
{
name: "ParsedDse2CaseInsensitive",
controlConnMaxProtocolVersion: "dsev2",
parsedProtocolVersion: primitive.ProtocolVersionDse2,
errorMessage: "",
},
{
name: "UnsupportedCassandraV5",
controlConnMaxProtocolVersion: "5",
parsedProtocolVersion: 0,
errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2",
errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2",
},
{
name: "UnsupportedCassandraV1",
controlConnMaxProtocolVersion: "1",
parsedProtocolVersion: 0,
errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2",
errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2",
},
{
name: "InvalidValue",
controlConnMaxProtocolVersion: "Dsev123",
parsedProtocolVersion: 0,
errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2",
errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2",
},
}

Expand Down
10 changes: 5 additions & 5 deletions proxy/pkg/zdmproxy/controlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,16 +361,16 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) (
}

conn = newConn
log.Infof("Successfully opened control connection to %v using endpoint %v.",
cc.connConfig.GetClusterType(), endpoint.String())
log.Infof("Successfully opened control connection to %v using endpoint %v with %v.",
cc.connConfig.GetClusterType(), endpoint.String(), newConn.GetProtocolVersion().Load().(primitive.ProtocolVersion))
break
}

return conn, endpoint
}

func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoVer uint, ctx context.Context) (CqlConnection, error) {
protoVer := primitive.ProtocolVersion(initialProtoVer)
func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoVer primitive.ProtocolVersion, ctx context.Context) (CqlConnection, error) {
protoVer := initialProtoVer
for {
tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false)
if err != nil {
Expand All @@ -389,7 +389,7 @@ func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoV
log.Errorf("Failed to close cql connection: %v", err2)
}
protoVer = downgradeProtocol(protoVer)
log.Infof("Downgrading protocol version: %v", protoVer)
log.Debugf("Downgrading protocol version: %v", protoVer)
if protoVer == 0 {
// we cannot downgrade anymore
return nil, err
Expand Down
13 changes: 2 additions & 11 deletions proxy/pkg/zdmproxy/cqlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ import (
)

const (
eventQueueLength = 2048
eventQueueLengthV2 = 128
eventQueueLength = 2048

maxIncomingPending = 2048
maxIncomingPendingV2 = 128
Expand Down Expand Up @@ -102,7 +101,7 @@ func NewCqlConnection(
wg: &sync.WaitGroup{},
// protoVer is the proposed protocol version using which we will try to establish connectivity
outgoingCh: make(chan *frame.Frame, maxOutgoingPendingRequests(protoVer)),
eventsQueue: make(chan *frame.Frame, maxEventsQueue(protoVer)),
eventsQueue: make(chan *frame.Frame, eventQueueLength),
pendingOperations: make(map[int16]chan *frame.Frame),
pendingOperationsLock: &sync.RWMutex{},
timedOutOperations: 0,
Expand All @@ -126,14 +125,6 @@ func maxOutgoingPendingRequests(protocolVersion primitive.ProtocolVersion) int {
return maxOutgoingPending
}

func maxEventsQueue(protocolVersion primitive.ProtocolVersion) int {
switch protocolVersion {
case primitive.ProtocolVersion2:
return eventQueueLengthV2
}
return eventQueueLength
}

func (c *cqlConn) SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) {
c.eventHandlerLock.Lock()
defer c.eventHandlerLock.Unlock()
Expand Down

0 comments on commit 7247e64

Please sign in to comment.