diff --git a/integration-tests/customhandler_test_utils.go b/integration-tests/customhandler_test_utils.go index 37de98e..7101ca2 100644 --- a/integration-tests/customhandler_test_utils.go +++ b/integration-tests/customhandler_test_utils.go @@ -78,7 +78,10 @@ var systemLocalColumns = []*message.ColumnMetadata{ tokensColumn, } -var systemLocalColumnsV2 = []*message.ColumnMetadata{ +// These columns are a subset of the total columns returned by OSS C* 2.0.0, and contain all the information that +// drivers need in order to establish the cluster topology and determine its characteristics. Please note that RPC address +// column is not present. +var systemLocalColumnsProtocolV2 = []*message.ColumnMetadata{ keyColumn, clusterNameColumn, cqlVersionColumn, @@ -125,10 +128,10 @@ var ( schemaVersionValue = message.Column{0xC0, 0xD1, 0xD2, 0x1E, 0xBB, 0x01, 0x41, 0x96, 0x86, 0xDB, 0xBC, 0x31, 0x7B, 0xC1, 0x79, 0x6A} ) -func systemLocalRow(cluster string, datacenter string, customPartitioner string, addr *net.Addr, version primitive.ProtocolVersion) message.Row { +func systemLocalRow(cluster string, datacenter string, customPartitioner string, addr net.Addr, version primitive.ProtocolVersion) message.Row { addrBuf := &bytes.Buffer{} if addr != nil { - inetAddr := (*addr).(*net.TCPAddr).IP + inetAddr := addr.(*net.TCPAddr).IP if inetAddr.To4() != nil { addrBuf.Write(inetAddr.To4()) } else { @@ -150,7 +153,7 @@ func systemLocalRow(cluster string, datacenter string, customPartitioner string, if customPartitioner != "" { partitionerValue = message.Column(customPartitioner) } - if addrBuf.Len() > 0 { + if version >= primitive.ProtocolVersion3 { return message.Row{ keyValue, addrBuf.Bytes(), @@ -182,8 +185,7 @@ func systemLocalRow(cluster string, datacenter string, customPartitioner string, } func fullSystemLocal(cluster string, datacenter string, customPartitioner string, request *frame.Frame, conn *client.CqlServerConnection) *frame.Frame { - localAddress := conn.LocalAddr() - systemLocalRow := systemLocalRow(cluster, datacenter, customPartitioner, &localAddress, request.Header.Version) + systemLocalRow := systemLocalRow(cluster, datacenter, customPartitioner, conn.LocalAddr(), request.Header.Version) msg := &message.RowsResult{ Metadata: &message.RowsMetadata{ ColumnCount: int32(len(systemLocalColumns)), diff --git a/integration-tests/protocolv2_test.go b/integration-tests/protocolv2_test.go index 65d1f46..5b0f57f 100644 --- a/integration-tests/protocolv2_test.go +++ b/integration-tests/protocolv2_test.go @@ -128,8 +128,8 @@ func (recv *ProtocolV2RequestHandler) HandleRequest( sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version) sysLocMsg := &message.RowsResult{ Metadata: &message.RowsMetadata{ - ColumnCount: int32(len(systemLocalColumnsV2)), - Columns: systemLocalColumnsV2, + ColumnCount: int32(len(systemLocalColumnsProtocolV2)), + Columns: systemLocalColumnsProtocolV2, }, Data: message.RowSet{sysLocRow}, } diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index a249cbf..9d60270 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -3,6 +3,7 @@ package config import ( "encoding/json" "fmt" + "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/datastax/zdm-proxy/proxy/pkg/common" "github.com/kelseyhightower/envconfig" log "github.com/sirupsen/logrus" @@ -342,22 +343,22 @@ func (c *Config) ParseReadMode() (common.ReadMode, error) { } } -func (c *Config) ParseControlConnMaxProtocolVersion() (uint, error) { - switch c.ControlConnMaxProtocolVersion { - case "Dse2": - return 0b_1_000010, nil - case "Dse1": - return 0b_1_000001, nil +func (c *Config) ParseControlConnMaxProtocolVersion() (primitive.ProtocolVersion, error) { + if strings.EqualFold(c.ControlConnMaxProtocolVersion, "DseV2") { + return primitive.ProtocolVersionDse2, nil + } + if strings.EqualFold(c.ControlConnMaxProtocolVersion, "DseV1") { + return primitive.ProtocolVersionDse1, nil } ver, err := strconv.ParseUint(c.ControlConnMaxProtocolVersion, 10, 32) if err != nil { return 0, fmt.Errorf("could not parse control connection max protocol version, valid values are "+ - "2, 3, 4, Dse1, Dse2; original err: %w", err) + "2, 3, 4, DseV1, DseV2; original err: %w", err) } if ver < 2 || ver > 4 { - return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2") + return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2") } - return uint(ver), nil + return primitive.ProtocolVersion(ver), nil } func (c *Config) ParseLogLevel() (log.Level, error) { diff --git a/proxy/pkg/config/config_test.go b/proxy/pkg/config/config_test.go index 6da7c43..cea6ce4 100644 --- a/proxy/pkg/config/config_test.go +++ b/proxy/pkg/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/stretchr/testify/require" "testing" ) @@ -111,56 +112,62 @@ func TestTargetConfig_ParsingControlConnMaxProtocolVersion(t *testing.T) { tests := []struct { name string controlConnMaxProtocolVersion string - parsedProtocolVersion uint + parsedProtocolVersion primitive.ProtocolVersion errorMessage string }{ { name: "ParsedV2", controlConnMaxProtocolVersion: "2", - parsedProtocolVersion: 2, + parsedProtocolVersion: primitive.ProtocolVersion2, errorMessage: "", }, { name: "ParsedV3", controlConnMaxProtocolVersion: "3", - parsedProtocolVersion: 3, + parsedProtocolVersion: primitive.ProtocolVersion3, errorMessage: "", }, { name: "ParsedV4", controlConnMaxProtocolVersion: "4", - parsedProtocolVersion: 4, + parsedProtocolVersion: primitive.ProtocolVersion4, errorMessage: "", }, { name: "ParsedDse1", - controlConnMaxProtocolVersion: "Dse1", - parsedProtocolVersion: 65, + controlConnMaxProtocolVersion: "DseV1", + parsedProtocolVersion: primitive.ProtocolVersionDse1, errorMessage: "", }, { name: "ParsedDse2", - controlConnMaxProtocolVersion: "Dse2", - parsedProtocolVersion: 66, + controlConnMaxProtocolVersion: "DseV2", + parsedProtocolVersion: primitive.ProtocolVersionDse2, + errorMessage: "", + }, + { + name: "ParsedDse2CaseInsensitive", + controlConnMaxProtocolVersion: "dsev2", + parsedProtocolVersion: primitive.ProtocolVersionDse2, errorMessage: "", }, { name: "UnsupportedCassandraV5", controlConnMaxProtocolVersion: "5", parsedProtocolVersion: 0, - errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", }, { name: "UnsupportedCassandraV1", controlConnMaxProtocolVersion: "1", parsedProtocolVersion: 0, - errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", }, { name: "InvalidValue", controlConnMaxProtocolVersion: "Dsev123", parsedProtocolVersion: 0, - errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", }, } diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index 1c3e962..4a2bc45 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -361,16 +361,16 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( } conn = newConn - log.Infof("Successfully opened control connection to %v using endpoint %v.", - cc.connConfig.GetClusterType(), endpoint.String()) + log.Infof("Successfully opened control connection to %v using endpoint %v with %v.", + cc.connConfig.GetClusterType(), endpoint.String(), newConn.GetProtocolVersion().Load().(primitive.ProtocolVersion)) break } return conn, endpoint } -func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoVer uint, ctx context.Context) (CqlConnection, error) { - protoVer := primitive.ProtocolVersion(initialProtoVer) +func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoVer primitive.ProtocolVersion, ctx context.Context) (CqlConnection, error) { + protoVer := initialProtoVer for { tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false) if err != nil { @@ -389,7 +389,7 @@ func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoV log.Errorf("Failed to close cql connection: %v", err2) } protoVer = downgradeProtocol(protoVer) - log.Infof("Downgrading protocol version: %v", protoVer) + log.Debugf("Downgrading protocol version: %v", protoVer) if protoVer == 0 { // we cannot downgrade anymore return nil, err diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index 929ec46..1c26651 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -19,8 +19,7 @@ import ( ) const ( - eventQueueLength = 2048 - eventQueueLengthV2 = 128 + eventQueueLength = 2048 maxIncomingPending = 2048 maxIncomingPendingV2 = 128 @@ -102,7 +101,7 @@ func NewCqlConnection( wg: &sync.WaitGroup{}, // protoVer is the proposed protocol version using which we will try to establish connectivity outgoingCh: make(chan *frame.Frame, maxOutgoingPendingRequests(protoVer)), - eventsQueue: make(chan *frame.Frame, maxEventsQueue(protoVer)), + eventsQueue: make(chan *frame.Frame, eventQueueLength), pendingOperations: make(map[int16]chan *frame.Frame), pendingOperationsLock: &sync.RWMutex{}, timedOutOperations: 0, @@ -126,14 +125,6 @@ func maxOutgoingPendingRequests(protocolVersion primitive.ProtocolVersion) int { return maxOutgoingPending } -func maxEventsQueue(protocolVersion primitive.ProtocolVersion) int { - switch protocolVersion { - case primitive.ProtocolVersion2: - return eventQueueLengthV2 - } - return eventQueueLength -} - func (c *cqlConn) SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) { c.eventHandlerLock.Lock() defer c.eventHandlerLock.Unlock()