Skip to content

Commit

Permalink
Protocol V2 stubbed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-antoniak committed Jun 27, 2024
1 parent 081bec0 commit 65ce9a0
Showing 1 changed file with 77 additions and 8 deletions.
85 changes: 77 additions & 8 deletions integration-tests/protocolv2_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package integration_tests

import (
"context"
"fmt"
"github.com/datastax/go-cassandra-native-protocol/client"
"github.com/datastax/go-cassandra-native-protocol/datatype"
"github.com/datastax/go-cassandra-native-protocol/frame"
"github.com/datastax/go-cassandra-native-protocol/message"
"github.com/datastax/go-cassandra-native-protocol/primitive"
Expand All @@ -11,13 +14,13 @@ import (
"testing"
)

func TestProtocolV2Basic(t *testing.T) {
func TestProtocolV2Connect(t *testing.T) {
originAddress := "127.0.0.2"
targetAddress := "127.0.0.3"

serverConf := setup.NewTestConfig(originAddress, targetAddress)
proxyConf := setup.NewTestConfig(originAddress, targetAddress)
proxyConf.ControlConnMaxProtocolVersion = "2"
proxyConf.ControlConnMaxProtocolVersion = "3" // simulate protocol downgrade to V2

testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false)
require.Nil(t, err)
Expand All @@ -42,8 +45,51 @@ func TestProtocolV2Basic(t *testing.T) {
if proxy != nil {
defer proxy.Shutdown()
}
require.Nil(t, err)
}

func TestProtocolV2Query(t *testing.T) {
originAddress := "127.0.0.2"
targetAddress := "127.0.0.3"

serverConf := setup.NewTestConfig(originAddress, targetAddress)
proxyConf := setup.NewTestConfig(originAddress, targetAddress)
proxyConf.ControlConnMaxProtocolVersion = "2"

testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false)
require.Nil(t, err)
defer testSetup.Cleanup()

originRequestHandler := NewProtocolV2RequestHandler("origin", "dc1", "")
targetRequestHandler := NewProtocolV2RequestHandler("target", "dc1", "")

testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{
originRequestHandler.HandleRequest,
client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}),
}
testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{
targetRequestHandler.HandleRequest,
client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}),
}

err = testSetup.Start(nil, false, primitive.ProtocolVersion2)
require.Nil(t, err)

proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy
if proxy != nil {
defer proxy.Shutdown()
}
require.Nil(t, err)

cqlConn, err := testSetup.Client.CqlClient.Connect(context.Background())
query := &message.Query{
Query: "SELECT * FROM fakeks.faketb",
Options: &message.QueryOptions{Consistency: primitive.ConsistencyLevelOne},
}

response, err := cqlConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion2, 0, query))
resultSet := response.Body.Message.(*message.RowsResult).Data
require.Equal(t, 1, len(resultSet))
}

type ProtocolV2RequestHandler struct {
Expand All @@ -66,6 +112,12 @@ func (recv *ProtocolV2RequestHandler) HandleRequest(
ctx client.RequestHandlerContext) (response *frame.Frame) {
switch request.Body.Message.GetOpCode() {
case primitive.OpCodeStartup:
if request.Header.Version != primitive.ProtocolVersion2 {
return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.ProtocolError{
ErrorMessage: fmt.Sprintf("Invalid or unsupported protocol version (%d)", request.Header.Version),
})
}
return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{})
case primitive.OpCodeRegister:
return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{})
case primitive.OpCodeQuery:
Expand All @@ -83,19 +135,36 @@ func (recv *ProtocolV2RequestHandler) HandleRequest(
}
return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg)
case "SELECT * FROM system.peers":
sysPeerRow := systemPeersRow(
recv.datacenter,
&net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042},
primitive.ProtocolVersion2,
)
var sysPeerRows message.RowSet
if len(recv.peerIP) > 0 {
sysPeerRows = append(sysPeerRows, systemPeersRow(
recv.datacenter,
&net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042},
primitive.ProtocolVersion2,
))
}
sysPeeMsg := &message.RowsResult{
Metadata: &message.RowsMetadata{
ColumnCount: int32(len(systemPeersColumns)),
Columns: systemPeersColumns,
},
Data: message.RowSet{sysPeerRow},
Data: sysPeerRows,
}
return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg)
case "SELECT * FROM fakeks.faketb":
sysLocMsg := &message.RowsResult{
Metadata: &message.RowsMetadata{
ColumnCount: 2,
Columns: []*message.ColumnMetadata{
{Keyspace: "fakeks", Table: "faketb", Name: "key", Type: datatype.Varchar},
{Keyspace: "fakeks", Table: "faketb", Name: "value", Type: datatype.Uuid},
},
},
Data: message.RowSet{
message.Row{keyValue, hostIdValue},
},
}
return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg)
}
}
return nil
Expand Down

0 comments on commit 65ce9a0

Please sign in to comment.