diff --git a/integration-tests/customhandler_test_utils.go b/integration-tests/customhandler_test_utils.go index 97fdc7a..37de98e 100644 --- a/integration-tests/customhandler_test_utils.go +++ b/integration-tests/customhandler_test_utils.go @@ -78,6 +78,19 @@ var systemLocalColumns = []*message.ColumnMetadata{ tokensColumn, } +var systemLocalColumnsV2 = []*message.ColumnMetadata{ + keyColumn, + clusterNameColumn, + cqlVersionColumn, + datacenterColumn, + hostIdColumn, + partitionerColumn, + rackColumn, + releaseVersionColumn, + schemaVersionColumn, + tokensColumn, +} + var ( peerColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "peer", Type: datatype.Inet} datacenterPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "data_center", Type: datatype.Varchar} @@ -112,13 +125,15 @@ var ( schemaVersionValue = message.Column{0xC0, 0xD1, 0xD2, 0x1E, 0xBB, 0x01, 0x41, 0x96, 0x86, 0xDB, 0xBC, 0x31, 0x7B, 0xC1, 0x79, 0x6A} ) -func systemLocalRow(cluster string, datacenter string, customPartitioner string, addr net.Addr, version primitive.ProtocolVersion) message.Row { +func systemLocalRow(cluster string, datacenter string, customPartitioner string, addr *net.Addr, version primitive.ProtocolVersion) message.Row { addrBuf := &bytes.Buffer{} - inetAddr := addr.(*net.TCPAddr).IP - if inetAddr.To4() != nil { - addrBuf.Write(inetAddr.To4()) - } else { - addrBuf.Write(inetAddr) + if addr != nil { + inetAddr := (*addr).(*net.TCPAddr).IP + if inetAddr.To4() != nil { + addrBuf.Write(inetAddr.To4()) + } else { + addrBuf.Write(inetAddr) + } } // emulate {'-9223372036854775808'} (entire ring) tokensBuf := &bytes.Buffer{} @@ -135,25 +150,40 @@ func systemLocalRow(cluster string, datacenter string, customPartitioner string, if customPartitioner != "" { partitionerValue = message.Column(customPartitioner) } + if addrBuf.Len() > 0 { + return message.Row{ + keyValue, + addrBuf.Bytes(), + message.Column(cluster), + cqlVersionValue, + message.Column(datacenter), + hostIdValue, + addrBuf.Bytes(), + partitionerValue, + rackValue, + releaseVersionValue, + addrBuf.Bytes(), + schemaVersionValue, + tokensBuf.Bytes(), + } + } return message.Row{ keyValue, - addrBuf.Bytes(), message.Column(cluster), cqlVersionValue, message.Column(datacenter), hostIdValue, - addrBuf.Bytes(), partitionerValue, rackValue, releaseVersionValue, - addrBuf.Bytes(), schemaVersionValue, tokensBuf.Bytes(), } } func fullSystemLocal(cluster string, datacenter string, customPartitioner string, request *frame.Frame, conn *client.CqlServerConnection) *frame.Frame { - systemLocalRow := systemLocalRow(cluster, datacenter, customPartitioner, conn.LocalAddr(), request.Header.Version) + localAddress := conn.LocalAddr() + systemLocalRow := systemLocalRow(cluster, datacenter, customPartitioner, &localAddress, request.Header.Version) msg := &message.RowsResult{ Metadata: &message.RowsMetadata{ ColumnCount: int32(len(systemLocalColumns)), diff --git a/integration-tests/protocolv2_test.go b/integration-tests/protocolv2_test.go new file mode 100644 index 0000000..8cba750 --- /dev/null +++ b/integration-tests/protocolv2_test.go @@ -0,0 +1,102 @@ +package integration_tests + +import ( + "github.com/datastax/go-cassandra-native-protocol/client" + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/stretchr/testify/require" + "net" + "testing" +) + +func TestProtocolV2Basic(t *testing.T) { + originAddress := "127.0.0.2" + targetAddress := "127.0.0.3" + + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf.ControlConnMaxProtocolVersion = "2" + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewProtocolV2RequestHandler("origin", "dc1", "127.0.0.4") + targetRequestHandler := NewProtocolV2RequestHandler("target", "dc1", "127.0.0.5") + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, primitive.ProtocolVersion2) + require.Nil(t, err) + + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + + require.Nil(t, err) +} + +type ProtocolV2RequestHandler struct { + cluster string + datacenter string + peerIP string +} + +func NewProtocolV2RequestHandler(cluster string, datacenter string, peerIP string) *ProtocolV2RequestHandler { + return &ProtocolV2RequestHandler{ + cluster: cluster, + datacenter: datacenter, + peerIP: peerIP, + } +} + +func (recv *ProtocolV2RequestHandler) HandleRequest( + request *frame.Frame, + conn *client.CqlServerConnection, + ctx client.RequestHandlerContext) (response *frame.Frame) { + switch request.Body.Message.GetOpCode() { + case primitive.OpCodeStartup: + case primitive.OpCodeRegister: + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) + case primitive.OpCodeQuery: + query := request.Body.Message.(*message.Query) + switch query.Query { + case "SELECT * FROM system.local": + // C* 2.0.0 does not store local endpoint details in system.local table + sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version) + sysLocMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumnsV2)), + Columns: systemLocalColumnsV2, + }, + Data: message.RowSet{sysLocRow}, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) + case "SELECT * FROM system.peers": + sysPeerRow := systemPeersRow( + recv.datacenter, + &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, + primitive.ProtocolVersion2, + ) + sysPeeMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: int32(len(systemPeersColumns)), + Columns: systemPeersColumns, + }, + Data: message.RowSet{sysPeerRow}, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg) + } + } + return nil +}