From 65ce9a0243a63c8a674396ef7fb1375245470514 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 27 Jun 2024 16:13:58 +0200 Subject: [PATCH] Protocol V2 stubbed tests --- integration-tests/protocolv2_test.go | 85 +++++++++++++++++++++++++--- 1 file changed, 77 insertions(+), 8 deletions(-) diff --git a/integration-tests/protocolv2_test.go b/integration-tests/protocolv2_test.go index 8cba750..65d1f46 100644 --- a/integration-tests/protocolv2_test.go +++ b/integration-tests/protocolv2_test.go @@ -1,7 +1,10 @@ package integration_tests import ( + "context" + "fmt" "github.com/datastax/go-cassandra-native-protocol/client" + "github.com/datastax/go-cassandra-native-protocol/datatype" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" @@ -11,13 +14,13 @@ import ( "testing" ) -func TestProtocolV2Basic(t *testing.T) { +func TestProtocolV2Connect(t *testing.T) { originAddress := "127.0.0.2" targetAddress := "127.0.0.3" serverConf := setup.NewTestConfig(originAddress, targetAddress) proxyConf := setup.NewTestConfig(originAddress, targetAddress) - proxyConf.ControlConnMaxProtocolVersion = "2" + proxyConf.ControlConnMaxProtocolVersion = "3" // simulate protocol downgrade to V2 testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) require.Nil(t, err) @@ -42,8 +45,51 @@ func TestProtocolV2Basic(t *testing.T) { if proxy != nil { defer proxy.Shutdown() } + require.Nil(t, err) +} + +func TestProtocolV2Query(t *testing.T) { + originAddress := "127.0.0.2" + targetAddress := "127.0.0.3" + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf.ControlConnMaxProtocolVersion = "2" + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewProtocolV2RequestHandler("origin", "dc1", "") + targetRequestHandler := NewProtocolV2RequestHandler("target", "dc1", "") + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, primitive.ProtocolVersion2) + require.Nil(t, err) + + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + require.Nil(t, err) + + cqlConn, err := testSetup.Client.CqlClient.Connect(context.Background()) + query := &message.Query{ + Query: "SELECT * FROM fakeks.faketb", + Options: &message.QueryOptions{Consistency: primitive.ConsistencyLevelOne}, + } + + response, err := cqlConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion2, 0, query)) + resultSet := response.Body.Message.(*message.RowsResult).Data + require.Equal(t, 1, len(resultSet)) } type ProtocolV2RequestHandler struct { @@ -66,6 +112,12 @@ func (recv *ProtocolV2RequestHandler) HandleRequest( ctx client.RequestHandlerContext) (response *frame.Frame) { switch request.Body.Message.GetOpCode() { case primitive.OpCodeStartup: + if request.Header.Version != primitive.ProtocolVersion2 { + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.ProtocolError{ + ErrorMessage: fmt.Sprintf("Invalid or unsupported protocol version (%d)", request.Header.Version), + }) + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) case primitive.OpCodeRegister: return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) case primitive.OpCodeQuery: @@ -83,19 +135,36 @@ func (recv *ProtocolV2RequestHandler) HandleRequest( } return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) case "SELECT * FROM system.peers": - sysPeerRow := systemPeersRow( - recv.datacenter, - &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, - primitive.ProtocolVersion2, - ) + var sysPeerRows message.RowSet + if len(recv.peerIP) > 0 { + sysPeerRows = append(sysPeerRows, systemPeersRow( + recv.datacenter, + &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, + primitive.ProtocolVersion2, + )) + } sysPeeMsg := &message.RowsResult{ Metadata: &message.RowsMetadata{ ColumnCount: int32(len(systemPeersColumns)), Columns: systemPeersColumns, }, - Data: message.RowSet{sysPeerRow}, + Data: sysPeerRows, } return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg) + case "SELECT * FROM fakeks.faketb": + sysLocMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: 2, + Columns: []*message.ColumnMetadata{ + {Keyspace: "fakeks", Table: "faketb", Name: "key", Type: datatype.Varchar}, + {Keyspace: "fakeks", Table: "faketb", Name: "value", Type: datatype.Uuid}, + }, + }, + Data: message.RowSet{ + message.Row{keyValue, hostIdValue}, + }, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) } } return nil