From 4bd37851aab796b44ddaebac3b00304b14cca26e Mon Sep 17 00:00:00 2001
From: Lukasz Antoniak
Date: Wed, 17 Jul 2024 15:50:54 +0200
Subject: [PATCH] New maximum stream IDs test

---
 README.md                                  |  15 +-
 integration-tests/connect_test.go          |  16 +-
 integration-tests/protocolv2_test.go       | 171 --------------
 integration-tests/protocolversions_test.go | 245 +++++++++++++++++++++
 integration-tests/streamids_test.go        | 210 ++++++++++++++++++
 proxy/pkg/zdmproxy/clienthandler.go        |  26 ++-
 proxy/pkg/zdmproxy/clusterconn.go          |   5 +-
 proxy/pkg/zdmproxy/controlconn.go          |   2 +-
 proxy/pkg/zdmproxy/cqlconn.go              |  15 +-
 9 files changed, 508 insertions(+), 197 deletions(-)
 delete mode 100644 integration-tests/protocolv2_test.go
 create mode 100644 integration-tests/protocolversions_test.go
 create mode 100644 integration-tests/streamids_test.go

diff --git a/README.md b/README.md
index a73ea1b..723c005 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,18 @@ It technically doesn't support v5, but handles protocol negotiation so that the
 the protocol version to v4 if v5 is requested. This means that any client application using a recent driver that
 supports protocol version v5 can be migrated using the ZDM Proxy (as long as it does not use v5-specific
 functionality).
+ZDM Proxy requires the origin and target clusters to support at least one protocol version in common. It is therefore
+not feasible to configure Apache Cassandra 2.0 as origin and 3.x / 4.x as target. The table below lists the protocol
+versions supported by each C* version:
+
+| Apache Cassandra | Protocol Version |
+|------------------|------------------|
+| 2.0              | V2               |
+| 2.1              | V2, V3           |
+| 2.2              | V3, V4           |
+| 3.x              | V3, V4           |
+| 4.x              | V3, V4, V5       |
+
 ---
 :warning: **Thrift is not supported by ZDM Proxy.** If you are using a very old driver or cluster version that only
 supports Thrift then you need to change your client application to use CQL and potentially upgrade your cluster before starting the
 migration process.

@@ -110,7 +122,8 @@

 In practice this means that ZDM Proxy supports the following cluster versions (as Origin and / or Target):

-- Apache Cassandra from 2.0+ up to (and including) Apache Cassandra 4.x.
+- Apache Cassandra from 2.1+ up to (and including) Apache Cassandra 4.x.
+- Apache Cassandra 2.0, provided that the other cluster runs 2.0 or 2.1 (the only versions that also support protocol V2).
 - DataStax Enterprise 4.8+. DataStax Enterprise 4.6 and 4.7 support will be introduced when protocol version v2 is supported.
 - DataStax Astra DB (both Serverless and Classic)

diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go
index b5d6e21..8b31103 100644
--- a/integration-tests/connect_test.go
+++ b/integration-tests/connect_test.go
@@ -46,7 +46,9 @@ func TestGoCqlConnect(t *testing.T) {
 	require.Equal(t, "fake", iter.Columns()[0].Name)
 }
 
-func TestProtocolVersionNegotiation(t *testing.T) {
+// Simulacron-based test to make sure that we can handle an invalid protocol error and downgrade
+// the protocol version used on the control connection. 
ORIGIN and TARGET are using the same C* version +func TestControlConnectionProtocolVersionNegotiation(t *testing.T) { tests := []struct { name string clusterVersion string @@ -63,13 +65,13 @@ func TestProtocolVersionNegotiation(t *testing.T) { name: "Cluster3.0_MaxCCProtoVer4_NegotiatedProtoVer4", clusterVersion: "3.0", controlConnMaxProtocolVersion: "4", - negotiatedProtocolVersion: primitive.ProtocolVersion4, + negotiatedProtocolVersion: primitive.ProtocolVersion4, // make sure that protocol negotiation does not fail if it is not actually needed }, { - name: "Cluster4.0_MaxCCProtoVer4_NegotiatedProtoVer4", - clusterVersion: "4.0", - controlConnMaxProtocolVersion: "4", - negotiatedProtocolVersion: primitive.ProtocolVersion4, + name: "Cluster3.0_MaxCCProtoVer3_NegotiatedProtoVer3", + clusterVersion: "3.0", + controlConnMaxProtocolVersion: "3", + negotiatedProtocolVersion: primitive.ProtocolVersion3, // protocol V3 applied as it is the maximum configured }, } @@ -103,7 +105,7 @@ func TestProtocolVersionNegotiation(t *testing.T) { defer cqlClientConn.Close() cqlConn, _ := testSetup.Proxy.GetOriginControlConn().GetConnAndContactPoint() - negotiatedProto := cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion) + negotiatedProto := cqlConn.GetProtocolVersion() require.Equal(t, tt.negotiatedProtocolVersion, negotiatedProto) queryMsg := &message.Query{ diff --git a/integration-tests/protocolv2_test.go b/integration-tests/protocolv2_test.go deleted file mode 100644 index 5b0f57f..0000000 --- a/integration-tests/protocolv2_test.go +++ /dev/null @@ -1,171 +0,0 @@ -package integration_tests - -import ( - "context" - "fmt" - "github.com/datastax/go-cassandra-native-protocol/client" - "github.com/datastax/go-cassandra-native-protocol/datatype" - "github.com/datastax/go-cassandra-native-protocol/frame" - "github.com/datastax/go-cassandra-native-protocol/message" - "github.com/datastax/go-cassandra-native-protocol/primitive" - "github.com/datastax/zdm-proxy/integration-tests/setup" - "github.com/stretchr/testify/require" - "net" - "testing" -) - -func TestProtocolV2Connect(t *testing.T) { - originAddress := "127.0.0.2" - targetAddress := "127.0.0.3" - - serverConf := setup.NewTestConfig(originAddress, targetAddress) - proxyConf := setup.NewTestConfig(originAddress, targetAddress) - proxyConf.ControlConnMaxProtocolVersion = "3" // simulate protocol downgrade to V2 - - testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) - require.Nil(t, err) - defer testSetup.Cleanup() - - originRequestHandler := NewProtocolV2RequestHandler("origin", "dc1", "127.0.0.4") - targetRequestHandler := NewProtocolV2RequestHandler("target", "dc1", "127.0.0.5") - - testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ - originRequestHandler.HandleRequest, - client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), - } - testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ - targetRequestHandler.HandleRequest, - client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), - } - - err = testSetup.Start(nil, false, primitive.ProtocolVersion2) - require.Nil(t, err) - - proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy - if proxy != nil { - defer proxy.Shutdown() - } - require.Nil(t, err) -} - -func TestProtocolV2Query(t *testing.T) { - originAddress := "127.0.0.2" - targetAddress := "127.0.0.3" - - serverConf := setup.NewTestConfig(originAddress, targetAddress) - proxyConf := 
setup.NewTestConfig(originAddress, targetAddress) - proxyConf.ControlConnMaxProtocolVersion = "2" - - testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) - require.Nil(t, err) - defer testSetup.Cleanup() - - originRequestHandler := NewProtocolV2RequestHandler("origin", "dc1", "") - targetRequestHandler := NewProtocolV2RequestHandler("target", "dc1", "") - - testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ - originRequestHandler.HandleRequest, - client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), - } - testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ - targetRequestHandler.HandleRequest, - client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), - } - - err = testSetup.Start(nil, false, primitive.ProtocolVersion2) - require.Nil(t, err) - - proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy - if proxy != nil { - defer proxy.Shutdown() - } - require.Nil(t, err) - - cqlConn, err := testSetup.Client.CqlClient.Connect(context.Background()) - query := &message.Query{ - Query: "SELECT * FROM fakeks.faketb", - Options: &message.QueryOptions{Consistency: primitive.ConsistencyLevelOne}, - } - - response, err := cqlConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion2, 0, query)) - resultSet := response.Body.Message.(*message.RowsResult).Data - require.Equal(t, 1, len(resultSet)) -} - -type ProtocolV2RequestHandler struct { - cluster string - datacenter string - peerIP string -} - -func NewProtocolV2RequestHandler(cluster string, datacenter string, peerIP string) *ProtocolV2RequestHandler { - return &ProtocolV2RequestHandler{ - cluster: cluster, - datacenter: datacenter, - peerIP: peerIP, - } -} - -func (recv *ProtocolV2RequestHandler) HandleRequest( - request *frame.Frame, - conn *client.CqlServerConnection, - ctx client.RequestHandlerContext) (response *frame.Frame) { - switch request.Body.Message.GetOpCode() { - case primitive.OpCodeStartup: - if request.Header.Version != primitive.ProtocolVersion2 { - return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.ProtocolError{ - ErrorMessage: fmt.Sprintf("Invalid or unsupported protocol version (%d)", request.Header.Version), - }) - } - return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) - case primitive.OpCodeRegister: - return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) - case primitive.OpCodeQuery: - query := request.Body.Message.(*message.Query) - switch query.Query { - case "SELECT * FROM system.local": - // C* 2.0.0 does not store local endpoint details in system.local table - sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version) - sysLocMsg := &message.RowsResult{ - Metadata: &message.RowsMetadata{ - ColumnCount: int32(len(systemLocalColumnsProtocolV2)), - Columns: systemLocalColumnsProtocolV2, - }, - Data: message.RowSet{sysLocRow}, - } - return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) - case "SELECT * FROM system.peers": - var sysPeerRows message.RowSet - if len(recv.peerIP) > 0 { - sysPeerRows = append(sysPeerRows, systemPeersRow( - recv.datacenter, - &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, - primitive.ProtocolVersion2, - )) - } - sysPeeMsg := &message.RowsResult{ - Metadata: &message.RowsMetadata{ - ColumnCount: int32(len(systemPeersColumns)), - Columns: systemPeersColumns, - 
},
-				Data: sysPeerRows,
-			}
-			return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg)
-		case "SELECT * FROM fakeks.faketb":
-			sysLocMsg := &message.RowsResult{
-				Metadata: &message.RowsMetadata{
-					ColumnCount: 2,
-					Columns: []*message.ColumnMetadata{
-						{Keyspace: "fakeks", Table: "faketb", Name: "key", Type: datatype.Varchar},
-						{Keyspace: "fakeks", Table: "faketb", Name: "value", Type: datatype.Uuid},
-					},
-				},
-				Data: message.RowSet{
-					message.Row{keyValue, hostIdValue},
-				},
-			}
-			return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg)
-		}
-	}
-	return nil
-}
diff --git a/integration-tests/protocolversions_test.go b/integration-tests/protocolversions_test.go
new file mode 100644
index 0000000..cd9bddd
--- /dev/null
+++ b/integration-tests/protocolversions_test.go
@@ -0,0 +1,245 @@
+package integration_tests
+
+import (
+	"context"
+	"fmt"
+	"github.com/datastax/go-cassandra-native-protocol/client"
+	"github.com/datastax/go-cassandra-native-protocol/datatype"
+	"github.com/datastax/go-cassandra-native-protocol/frame"
+	"github.com/datastax/go-cassandra-native-protocol/message"
+	"github.com/datastax/go-cassandra-native-protocol/primitive"
+	"github.com/datastax/zdm-proxy/integration-tests/setup"
+	"github.com/stretchr/testify/require"
+	"net"
+	"slices"
+	"testing"
+)
+
+// Test that the proxy can establish connectivity with ORIGIN and TARGET
+// clusters that support different sets of protocol versions. Also verify that
+// the client driver can connect and successfully insert and query data.
+func TestProtocolNegotiationDifferentClusters(t *testing.T) {
+	tests := []struct {
+		name              string
+		proxyMaxProtoVer  string
+		originProtoVer    []primitive.ProtocolVersion
+		targetProtoVer    []primitive.ProtocolVersion
+		clientProtoVer    primitive.ProtocolVersion
+		failClientConnect bool
+	}{
+		{
+			name:             "OriginV2_TargetV2_ClientV2",
+			proxyMaxProtoVer: "2",
+			originProtoVer:   []primitive.ProtocolVersion{primitive.ProtocolVersion2},
+			targetProtoVer:   []primitive.ProtocolVersion{primitive.ProtocolVersion2},
+			clientProtoVer:   primitive.ProtocolVersion2,
+		},
+		{
+			name:             "OriginV2_TargetV2_ClientV2_ProxyControlConnNegotiation",
+			proxyMaxProtoVer: "4",
+			originProtoVer:   []primitive.ProtocolVersion{primitive.ProtocolVersion2},
+			targetProtoVer:   []primitive.ProtocolVersion{primitive.ProtocolVersion2},
+			clientProtoVer:   primitive.ProtocolVersion2,
+		},
+		{
+			name:             "OriginV2_TargetV23_ClientV2",
+			proxyMaxProtoVer: "3",
+			originProtoVer:   []primitive.ProtocolVersion{primitive.ProtocolVersion2},
+			targetProtoVer:   []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3},
+			clientProtoVer:   primitive.ProtocolVersion2,
+		},
+		{
+			name:             "OriginV23_TargetV2_ClientV2",
+			proxyMaxProtoVer: "3",
+			originProtoVer:   []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3},
+			targetProtoVer:   []primitive.ProtocolVersion{primitive.ProtocolVersion2},
+			clientProtoVer:   primitive.ProtocolVersion2,
+		},
+		{
+			// most common setup with OSS Cassandra
+			name:             "OriginV345_TargetV345_ClientV4",
+			proxyMaxProtoVer: "3",
+			originProtoVer:   []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5},
+			targetProtoVer:   []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5},
+			clientProtoVer:   primitive.ProtocolVersion4,
+		},
+		{
+			// most common setup with DSE
+			name:             "OriginV345_TargetV34Dse1Dse2_ClientV4",
+			proxyMaxProtoVer: "3",
+			
originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse2}, + clientProtoVer: primitive.ProtocolVersion4, + }, + { + name: "OriginV2_TargetV3_ClientV2", + proxyMaxProtoVer: "3", + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, + clientProtoVer: primitive.ProtocolVersion2, + // client connection should fail as there is no common protocol version between origin and target + failClientConnect: true, + }, + } + + originAddress := "127.0.1.1" + targetAddress := "127.0.1.2" + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + + queryInsert := &message.Query{ + Query: "INSERT INTO test_ks.test(key, value) VALUES(1, '1')", // use INSERT to route request to both clusters + } + querySelect := &message.Query{ + Query: "SELECT * FROM test_ks.test", + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxyConf.ControlConnMaxProtocolVersion = test.proxyMaxProtoVer + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewProtocolNegotiationRequestHandler("origin", "dc1", originAddress, test.originProtoVer) + targetRequestHandler := NewProtocolNegotiationRequestHandler("target", "dc1", targetAddress, test.targetProtoVer) + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, test.clientProtoVer) + require.Nil(t, err) + + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + require.Nil(t, err) + + cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), test.clientProtoVer, 0) + if test.failClientConnect { + require.NotNil(t, err) + return + } + require.Nil(t, err) + defer cqlConn.Close() + + response, err := cqlConn.SendAndReceive(frame.NewFrame(test.clientProtoVer, 0, queryInsert)) + require.Nil(t, err) + require.IsType(t, &message.VoidResult{}, response.Body.Message) + + response, err = cqlConn.SendAndReceive(frame.NewFrame(test.clientProtoVer, 0, querySelect)) + require.Nil(t, err) + resultSet := response.Body.Message.(*message.RowsResult).Data + require.Equal(t, 1, len(resultSet)) + }) + } +} + +type ProtocolNegotiationRequestHandler struct { + cluster string + datacenter string + peerIP string + protocolVersions []primitive.ProtocolVersion // accepted protocol versions + // store negotiated protocol versions by socket port number + // protocol version negotiated by proxy on control connections can be different from the one + // used by client driver with ORIGIN and TARGET nodes. In the scenario 'OriginV2_TargetV23_ClientV2', proxy + // will establish control connection with ORIGIN using version 2, and TARGET with version 3. 
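+	// The proxy, however, forwards client requests using the protocol version negotiated by the client driver.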
+ // Protocol version applied on client connections with TARGET will be different - V2. + negotiatedProtoVer map[int]primitive.ProtocolVersion // negotiated protocol version on different sockets +} + +func NewProtocolNegotiationRequestHandler(cluster string, datacenter string, peerIP string, + protocolVersion []primitive.ProtocolVersion) *ProtocolNegotiationRequestHandler { + return &ProtocolNegotiationRequestHandler{ + cluster: cluster, + datacenter: datacenter, + peerIP: peerIP, + protocolVersions: protocolVersion, + negotiatedProtoVer: make(map[int]primitive.ProtocolVersion), + } +} + +func (recv *ProtocolNegotiationRequestHandler) HandleRequest( + request *frame.Frame, + conn *client.CqlServerConnection, + ctx client.RequestHandlerContext) (response *frame.Frame) { + port := conn.RemoteAddr().(*net.TCPAddr).Port + negotiatedProtoVer := recv.negotiatedProtoVer[port] + if !slices.Contains(recv.protocolVersions, request.Header.Version) || (negotiatedProtoVer != 0 && negotiatedProtoVer != request.Header.Version) { + // server does not support given protocol version, or it was not the one negotiated + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.ProtocolError{ + ErrorMessage: fmt.Sprintf("Invalid or unsupported protocol version (%d)", request.Header.Version), + }) + } + switch request.Body.Message.GetOpCode() { + case primitive.OpCodeStartup: + recv.negotiatedProtoVer[port] = request.Header.Version + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) + case primitive.OpCodeRegister: + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) + case primitive.OpCodeQuery: + query := request.Body.Message.(*message.Query) + switch query.Query { + case "SELECT * FROM system.local": + // C* 2.0.0 does not store local endpoint details in system.local table + sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version) + metadata := &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumns)), + Columns: systemLocalColumns, + } + if negotiatedProtoVer == primitive.ProtocolVersion2 { + metadata = &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumnsProtocolV2)), + Columns: systemLocalColumnsProtocolV2, + } + } + sysLocMsg := &message.RowsResult{ + Metadata: metadata, + Data: message.RowSet{sysLocRow}, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) + case "SELECT * FROM system.peers": + var sysPeerRows message.RowSet + if len(recv.peerIP) > 0 { + sysPeerRows = append(sysPeerRows, systemPeersRow( + recv.datacenter, + &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, + negotiatedProtoVer, + )) + } + sysPeeMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: int32(len(systemPeersColumns)), + Columns: systemPeersColumns, + }, + Data: sysPeerRows, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg) + case "SELECT * FROM test_ks.test": + qryMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: 2, + Columns: []*message.ColumnMetadata{ + {Keyspace: "test_ks", Table: "test", Name: "key", Type: datatype.Varchar}, + {Keyspace: "test_ks", Table: "test", Name: "value", Type: datatype.Uuid}, + }, + }, + Data: message.RowSet{ + message.Row{keyValue, hostIdValue}, + }, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, qryMsg) + case "INSERT INTO test_ks.test(key, value) VALUES(1, '1')": + 
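// the INSERT is acknowledged with a void result; the test uses INSERT so that the proxy routes the request to both clusters
+			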
return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.VoidResult{})
+		}
+	}
+	return nil
+}
diff --git a/integration-tests/streamids_test.go b/integration-tests/streamids_test.go
new file mode 100644
index 0000000..3a781fc
--- /dev/null
+++ b/integration-tests/streamids_test.go
@@ -0,0 +1,210 @@
+package integration_tests
+
+import (
+	"context"
+	"fmt"
+	"github.com/datastax/go-cassandra-native-protocol/client"
+	"github.com/datastax/go-cassandra-native-protocol/frame"
+	"github.com/datastax/go-cassandra-native-protocol/message"
+	"github.com/datastax/go-cassandra-native-protocol/primitive"
+	"github.com/datastax/zdm-proxy/integration-tests/setup"
+	"github.com/datastax/zdm-proxy/integration-tests/utils"
+	log "github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/require"
+	"net"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+)
+
+// Test sending more concurrent, asynchronous requests than there are allowed stream IDs.
+// Origin and target clusters are stubbed and will return a protocol error
+// if we observe a greater stream ID value than expected. We cannot easily test
+// exceeding the 127 stream IDs allowed in protocol V2, because the client would
+// fail to serialize such frames.
+func TestMaxStreamIds(t *testing.T) {
+	originAddress := "127.0.1.1"
+	targetAddress := "127.0.1.2"
+	originProtoVer := primitive.ProtocolVersion2
+	targetProtoVer := primitive.ProtocolVersion2
+	requestCount := 20
+	maxStreamIdsConf := 10
+	maxStreamIdsExpected := 10
+	serverConf := setup.NewTestConfig(originAddress, targetAddress)
+	proxyConf := setup.NewTestConfig(originAddress, targetAddress)
+
+	queryInsert := &message.Query{
+		Query: "INSERT INTO test_ks.test(key, value) VALUES(1, '1')", // use INSERT to route the request to both clusters
+	}
+
+	buffer := utils.CreateLogHooks(log.WarnLevel, log.ErrorLevel)
+	defer log.StandardLogger().ReplaceHooks(make(log.LevelHooks))
+
+	testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false)
+	require.Nil(t, err)
+	defer testSetup.Cleanup()
+
+	originRequestHandler := NewMaxStreamIdsRequestHandler("origin", "dc1", originAddress, maxStreamIdsExpected)
+	targetRequestHandler := NewProtocolNegotiationRequestHandler("target", "dc1", targetAddress, []primitive.ProtocolVersion{targetProtoVer})
+
+	testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{
+		originRequestHandler.HandleRequest,
+		client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}),
+	}
+	testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{
+		targetRequestHandler.HandleRequest,
+		client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}),
+	}
+
+	err = testSetup.Start(nil, false, originProtoVer)
+	require.Nil(t, err)
+
+	proxyConf.ProxyMaxStreamIds = maxStreamIdsConf
+	proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy
+	if proxy != nil {
+		defer proxy.Shutdown()
+	}
+	require.Nil(t, err)
+
+	testSetup.Client.CqlClient.MaxInFlight = 127 // set to 127, otherwise the client fails to serialize frames in protocol V2
+	cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), originProtoVer, 0)
+	require.Nil(t, err)
+	defer cqlConn.Close()
+
+	remainingRequests := requestCount
+
+	for j := 0; j < 10; j++ {
+		var responses []client.InFlightRequest
+		for i := 0; i < remainingRequests; i++ {
+			inFlightReq, err := cqlConn.Send(frame.NewFrame(originProtoVer, 0, queryInsert))
+			require.Nil(t, err)
+			responses = append(responses, inFlightReq)
+		}
+
+		for _, response := range responses {
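+			// drain all in-flight responses; requests rejected with an Overloaded error are retried in the next pass of the outer loop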
+			select {
+			case msg := <-response.Incoming():
+				if response.Err() != nil {
+					t.Fatal(response.Err())
+				}
+				switch msg.Body.Message.(type) {
+				case *message.VoidResult:
+					// expected, we have received a successful response
+					remainingRequests--
+				case *message.Overloaded:
+					// the client received an overloaded message because the stream ID pool was exhausted,
+					// the outer loop will retry the request
+				default:
+					t.Fatalf("received unexpected message: %v", msg.Body.Message)
+				}
+			}
+		}
+
+		if remainingRequests == 0 {
+			break
+		}
+	}
+
+	require.True(t, strings.Contains(buffer.String(), "no stream id available"))
+
+	require.True(t, len(originRequestHandler.usedStreamIdsPerConn) >= 1)
+	for _, idMap := range originRequestHandler.usedStreamIdsPerConn {
+		maxId := int16(0)
+		for streamId := range idMap {
+			if streamId > maxId {
+				maxId = streamId
+			}
+		}
+		require.True(t, maxId < int16(maxStreamIdsExpected))
+	}
+}
+
+type MaxStreamIdsRequestHandler struct {
+	lock                 sync.Mutex
+	cluster              string
+	datacenter           string
+	peerIP               string
+	maxStreamIds         int
+	usedStreamIdsPerConn map[int]map[int16]bool
+}
+
+func NewMaxStreamIdsRequestHandler(cluster string, datacenter string, peerIP string, maxStreamIds int) *MaxStreamIdsRequestHandler {
+	return &MaxStreamIdsRequestHandler{
+		cluster:              cluster,
+		datacenter:           datacenter,
+		peerIP:               peerIP,
+		maxStreamIds:         maxStreamIds,
+		usedStreamIdsPerConn: make(map[int]map[int16]bool),
+	}
+}
+
+func (recv *MaxStreamIdsRequestHandler) HandleRequest(
+	request *frame.Frame,
+	conn *client.CqlServerConnection,
+	ctx client.RequestHandlerContext) (response *frame.Frame) {
+	port := conn.RemoteAddr().(*net.TCPAddr).Port
+
+	switch request.Body.Message.GetOpCode() {
+	case primitive.OpCodeStartup:
+		// return nil so that STARTUP is answered by the driver connection initialization handler registered after this one
+	case primitive.OpCodeRegister:
+		return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{})
+	case primitive.OpCodeQuery:
+		query := request.Body.Message.(*message.Query)
+		switch query.Query {
+		case "SELECT * FROM system.local":
+			// C* 2.0.0 does not store local endpoint details in the system.local table
+			sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version)
+			metadata := &message.RowsMetadata{
+				ColumnCount: int32(len(systemLocalColumns)),
+				Columns:     systemLocalColumns,
+			}
+			if request.Header.Version == primitive.ProtocolVersion2 {
+				metadata = &message.RowsMetadata{
+					ColumnCount: int32(len(systemLocalColumnsProtocolV2)),
+					Columns:     systemLocalColumnsProtocolV2,
+				}
+			}
+			sysLocMsg := &message.RowsResult{
+				Metadata: metadata,
+				Data:     message.RowSet{sysLocRow},
+			}
+			return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg)
+		case "SELECT * FROM system.peers":
+			var sysPeerRows message.RowSet
+			if len(recv.peerIP) > 0 {
+				sysPeerRows = append(sysPeerRows, systemPeersRow(
+					recv.datacenter,
+					&net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042},
+					request.Header.Version,
+				))
+			}
+			sysPeeMsg := &message.RowsResult{
+				Metadata: &message.RowsMetadata{
+					ColumnCount: int32(len(systemPeersColumns)),
+					Columns:     systemPeersColumns,
+				},
+				Data: sysPeerRows,
+			}
+			return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg)
+		case "INSERT INTO test_ks.test(key, value) VALUES(1, '1')":
+			recv.lock.Lock()
+			usedStreamIdsMap := recv.usedStreamIdsPerConn[port]
+			if usedStreamIdsMap == nil {
+				usedStreamIdsMap = make(map[int16]bool)
+				recv.usedStreamIdsPerConn[port] = usedStreamIdsMap
+			}
+			usedStreamIdsMap[request.Header.StreamId] = true
+			recv.lock.Unlock()
+
+			time.Sleep(5 * time.Millisecond) // introduce 
some delay so that stream IDs are not released immediately + + if len(usedStreamIdsMap) > recv.maxStreamIds { + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.ProtocolError{ + ErrorMessage: fmt.Sprintf("Too many stream IDs used (%d)", len(usedStreamIdsMap)), + }) + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.VoidResult{}) + } + } + return nil +} diff --git a/proxy/pkg/zdmproxy/clienthandler.go b/proxy/pkg/zdmproxy/clienthandler.go index 30b59bc..bfd91ba 100644 --- a/proxy/pkg/zdmproxy/clienthandler.go +++ b/proxy/pkg/zdmproxy/clienthandler.go @@ -168,8 +168,8 @@ func NewClientHandler( requestsDoneCtx, requestsDoneCancelFn := context.WithCancel(context.Background()) // Initialize stream id processors to manage the ids sent to the clusters - originCCProtoVer := originControlConn.cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion) - targetCCProtoVer := targetControlConn.cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion) + originCCProtoVer := originControlConn.cqlConn.GetProtocolVersion() + targetCCProtoVer := targetControlConn.cqlConn.GetProtocolVersion() streamIds := maxStreamIds(originCCProtoVer, targetCCProtoVer, conf) originFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeOrigin) targetFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeTarget) @@ -1485,17 +1485,27 @@ func (ch *ClientHandler) executeRequest( case forwardToBoth: log.Tracef("Forwarding request with opcode %v for stream %v to %v and %v", f.Header.OpCode, f.Header.StreamId, common.ClusterTypeOrigin, common.ClusterTypeTarget) - ch.originCassandraConnector.sendRequestToCluster(originRequest) - ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + sendErr := ch.originCassandraConnector.sendRequestToCluster(originRequest) + if sendErr != nil { + ch.clientConnector.sendOverloadedToClient(frameContext.frame) + } else { + ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + } case forwardToOrigin: log.Tracef("Forwarding request with opcode %v for stream %v to %v", f.Header.OpCode, f.Header.StreamId, common.ClusterTypeOrigin) - ch.originCassandraConnector.sendRequestToCluster(originRequest) + sendErr := ch.originCassandraConnector.sendRequestToCluster(originRequest) + if sendErr != nil { + ch.clientConnector.sendOverloadedToClient(frameContext.frame) + } ch.targetCassandraConnector.sendHeartbeat(startupFrameVersion, ch.conf.HeartbeatIntervalMs) case forwardToTarget: log.Tracef("Forwarding request with opcode %v for stream %v to %v", f.Header.OpCode, f.Header.StreamId, common.ClusterTypeTarget) - ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + sendErr := ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + if sendErr != nil { + ch.clientConnector.sendOverloadedToClient(frameContext.frame) + } ch.originCassandraConnector.sendHeartbeat(startupFrameVersion, ch.conf.HeartbeatIntervalMs) case forwardToAsyncOnly: default: @@ -2258,13 +2268,13 @@ func newFrameProcessor(maxStreamIds int, nodeMetrics *metrics.NodeMetrics, conne // uses negotiated protocol version to configure maximum number of stream IDs on node connections. Driver does NOT // change the number of stream IDs on per node basis. 
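+// The effective pool size is the smaller of the protocol maximum and conf.ProxyMaxStreamIds, e.g. with a
+// protocol V2 cluster on either side the limit is maxStreamIdsV2 (127) even if ProxyMaxStreamIds is higher.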
func maxStreamIds(originProtoVer primitive.ProtocolVersion, targetProtoVer primitive.ProtocolVersion, conf *config.Config) int { - maxSupported := maxOutgoingPending + maxSupported := maxStreamIdsV3 protoVer := originProtoVer if targetProtoVer < originProtoVer { protoVer = targetProtoVer } if protoVer == primitive.ProtocolVersion2 { - maxSupported = maxOutgoingPendingV2 + maxSupported = maxStreamIdsV2 } if maxSupported < conf.ProxyMaxStreamIds { return maxSupported diff --git a/proxy/pkg/zdmproxy/clusterconn.go b/proxy/pkg/zdmproxy/clusterconn.go index a1a35e9..21bea2c 100644 --- a/proxy/pkg/zdmproxy/clusterconn.go +++ b/proxy/pkg/zdmproxy/clusterconn.go @@ -394,17 +394,18 @@ func (cc *ClusterConnector) handleAsyncResponse(response *frame.RawFrame) *frame return nil } -func (cc *ClusterConnector) sendRequestToCluster(frame *frame.RawFrame) { +func (cc *ClusterConnector) sendRequestToCluster(frame *frame.RawFrame) error { var err error if cc.frameProcessor != nil { frame, err = cc.frameProcessor.AssignUniqueId(frame) } if err != nil { log.Errorf("[%v] Couldn't assign stream id to frame %v: %v", string(cc.connectorType), frame.Header.OpCode, err) - return + return err } else { cc.writeCoalescer.Enqueue(frame) } + return nil } func (cc *ClusterConnector) validateAsyncStateForRequest(frame *frame.RawFrame) bool { diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index 4a2bc45..17fd804 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -362,7 +362,7 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( conn = newConn log.Infof("Successfully opened control connection to %v using endpoint %v with %v.", - cc.connConfig.GetClusterType(), endpoint.String(), newConn.GetProtocolVersion().Load().(primitive.ProtocolVersion)) + cc.connConfig.GetClusterType(), endpoint.String(), newConn.GetProtocolVersion()) break } diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index a9fdabb..92a047f 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -21,10 +21,11 @@ import ( const ( eventQueueLength = 2048 - maxIncomingPending = 2048 - maxIncomingPendingV2 = 128 - maxOutgoingPending = 2048 - maxOutgoingPendingV2 = 128 + maxIncomingPending = 2048 + maxOutgoingPending = 2048 + + maxStreamIdsV3 = 2048 + maxStreamIdsV2 = 127 timeOutsThreshold = 1024 ) @@ -41,7 +42,7 @@ type CqlConnection interface { SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) SubscribeToProtocolEvents(ctx context.Context, eventTypes []primitive.EventType) error IsAuthEnabled() (bool, error) - GetProtocolVersion() *atomic.Value + GetProtocolVersion() primitive.ProtocolVersion } // Not thread safe @@ -245,8 +246,8 @@ func (c *cqlConn) IsAuthEnabled() (bool, error) { return c.authEnabled, nil } -func (c *cqlConn) GetProtocolVersion() *atomic.Value { - return c.protocolVersion +func (c *cqlConn) GetProtocolVersion() primitive.ProtocolVersion { + return c.protocolVersion.Load().(primitive.ProtocolVersion) } func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx context.Context) error {