From 7e7b1f2d5ab99c181d891db8a4f661f330f28378 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 24 Jul 2024 12:03:43 +0200 Subject: [PATCH] Support protocol v2 and protocol negotiation --- README.md | 17 +- go.mod | 4 +- go.sum | 6 +- integration-tests/asyncreads_test.go | 2 +- integration-tests/connect_test.go | 117 +++++++- integration-tests/customhandler_test_utils.go | 52 +++- integration-tests/functioncalls_test.go | 23 +- integration-tests/prepared_statements_test.go | 98 ++++-- integration-tests/protocolversions_test.go | 282 ++++++++++++++++++ integration-tests/setup/testcluster.go | 11 +- integration-tests/simulacron/cluster.go | 4 +- integration-tests/simulacron/http.go | 12 +- integration-tests/streamids_test.go | 257 ++++++++++++++++ proxy/pkg/config/config.go | 35 ++- proxy/pkg/config/config_test.go | 93 ++++++ proxy/pkg/zdmproxy/clientconn.go | 20 +- proxy/pkg/zdmproxy/clienthandler.go | 90 ++++-- proxy/pkg/zdmproxy/clienthandler_test.go | 34 +++ proxy/pkg/zdmproxy/clusterconn.go | 13 +- proxy/pkg/zdmproxy/controlconn.go | 85 ++++-- proxy/pkg/zdmproxy/cqlconn.go | 43 ++- proxy/pkg/zdmproxy/cqlparser.go | 11 +- .../cqlparser_adv_workloads_utils_test.go | 6 +- proxy/pkg/zdmproxy/cqlparser_test.go | 10 +- proxy/pkg/zdmproxy/frameprocessor.go | 4 +- proxy/pkg/zdmproxy/host.go | 25 +- proxy/pkg/zdmproxy/nativeprotocol.go | 12 +- proxy/pkg/zdmproxy/parametermodifier_test.go | 6 +- proxy/pkg/zdmproxy/querymodifier.go | 8 +- proxy/pkg/zdmproxy/querymodifier_test.go | 10 +- proxy/pkg/zdmproxy/response.go | 23 +- proxy/pkg/zdmproxy/streamidmapper.go | 68 ++++- proxy/pkg/zdmproxy/streamidmapper_test.go | 8 +- 33 files changed, 1282 insertions(+), 207 deletions(-) create mode 100644 integration-tests/protocolversions_test.go create mode 100644 integration-tests/streamids_test.go create mode 100644 proxy/pkg/zdmproxy/clienthandler_test.go diff --git a/README.md b/README.md index 917a4c83..70fa60fd 100644 --- a/README.md +++ b/README.md @@ -95,12 +95,24 @@ containerized sandbox environment. ## Supported Protocol Versions -**ZDM Proxy supports protocol versions v3, v4, DSE_V1 and DSE_V2.** +**ZDM Proxy supports protocol versions v2, v3, v4, DSE_V1 and DSE_V2.** It technically doesn't support v5, but handles protocol negotiation so that the client application properly downgrades the protocol version to v4 if v5 is requested. This means that any client application using a recent driver that supports protocol version v5 can be migrated using the ZDM Proxy (as long as it does not use v5-specific functionality). +ZDM Proxy requires origin and target clusters to have at least one protocol version in common. It is therefore not feasible +to configure Apache Cassandra 2.0 as origin and 3.x / 4.x as target. Below table displays protocol versions supported by +various C* versions: + +| Apache Cassandra | Protocol Version | +|------------------|------------------| +| 2.0 | V2 | +| 2.1 | V2, V3 | +| 2.2 | V2, V3, V4 | +| 3.x | V3, V4 | +| 4.x | V3, V4, V5 | + --- :warning: **Thrift is not supported by ZDM Proxy.** If you are using a very old driver or cluster version that only supports Thrift then you need to change your client application to use CQL and potentially upgrade your cluster before starting the @@ -110,8 +122,7 @@ migration process. In practice this means that ZDM Proxy supports the following cluster versions (as Origin and / or Target): -- Apache Cassandra from 2.1+ up to (and including) Apache Cassandra 4.x. Apache Cassandra 2.0 support will be introduced -when protocol version v2 is supported. +- Apache Cassandra from 2.0+ up to (and including) Apache Cassandra 4.x. (although both clusters have to support a common protocol version as mentioned above). - DataStax Enterprise 4.8+. DataStax Enterprise 4.6 and 4.7 support will be introduced when protocol version v2 is supported. - DataStax Astra DB (both Serverless and Classic) diff --git a/go.mod b/go.mod index e5cd8e7b..28211350 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.19 require ( github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20211106181442-e4c1a74c66bd - github.com/datastax/go-cassandra-native-protocol v0.0.0-20220525125956-6158d9e218b8 + github.com/datastax/go-cassandra-native-protocol v0.0.0-20240626123646-2abea740da8d github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e github.com/google/uuid v1.1.1 github.com/jpillora/backoff v1.0.0 @@ -15,7 +15,6 @@ require ( github.com/rs/zerolog v1.20.0 github.com/sirupsen/logrus v1.6.0 github.com/stretchr/testify v1.8.0 - gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -35,4 +34,5 @@ require ( github.com/prometheus/procfs v0.0.8 // indirect golang.org/x/sys v0.3.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 5a3c50fa..387403a2 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,8 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dR github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/datastax/go-cassandra-native-protocol v0.0.0-20220525125956-6158d9e218b8 h1:NKLtNzC76ssf68VOenDAzMyQGg+QkxuD2QCubX+GvLk= -github.com/datastax/go-cassandra-native-protocol v0.0.0-20220525125956-6158d9e218b8/go.mod h1:yFD0OKoVV9d1QW7Es58c1Gv6ijrqTGPcxgHv27wdC4Q= +github.com/datastax/go-cassandra-native-protocol v0.0.0-20240626123646-2abea740da8d h1:UnPtAA8Ux3GvHLazSSUydERFuoQRyxHrB8puzXyjXIE= +github.com/datastax/go-cassandra-native-protocol v0.0.0-20240626123646-2abea740da8d/go.mod h1:6FzirJfdffakAVqmHjwVfFkpru/gNbIazUOK5rIhndc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -105,7 +105,6 @@ github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -131,7 +130,6 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= diff --git a/integration-tests/asyncreads_test.go b/integration-tests/asyncreads_test.go index 3510d0b2..d487f2aa 100644 --- a/integration-tests/asyncreads_test.go +++ b/integration-tests/asyncreads_test.go @@ -287,7 +287,7 @@ func TestAsyncReadsRequestTypes(t *testing.T) { } testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig( - t, false, false, 1, nil) + t, false, false, 1, nil, nil) require.Nil(t, err) defer testSetup.Cleanup() diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index b7df5f71..b687ae44 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -3,12 +3,13 @@ package integration_tests import ( "bytes" "context" - client2 "github.com/datastax/go-cassandra-native-protocol/client" + cqlClient "github.com/datastax/go-cassandra-native-protocol/client" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/datastax/zdm-proxy/integration-tests/client" "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/datastax/zdm-proxy/integration-tests/simulacron" "github.com/datastax/zdm-proxy/integration-tests/utils" "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/rs/zerolog" @@ -45,6 +46,82 @@ func TestGoCqlConnect(t *testing.T) { require.Equal(t, "fake", iter.Columns()[0].Name) } +// Simulacron-based test to make sure that we can handle invalid protocol error and downgrade +// used protocol on control connection. ORIGIN and TARGET are using the same C* version +func TestControlConnectionProtocolVersionNegotiation(t *testing.T) { + tests := []struct { + name string + clusterVersion string + controlConnMaxProtocolVersion string + negotiatedProtocolVersion primitive.ProtocolVersion + }{ + { + name: "Cluster2.1_MaxCCProtoVer4_NegotiatedProtoVer3", + clusterVersion: "2.1", + controlConnMaxProtocolVersion: "4", + negotiatedProtocolVersion: primitive.ProtocolVersion3, // protocol downgraded to V3, V4 is not supported + }, + { + name: "Cluster3.0_MaxCCProtoVer4_NegotiatedProtoVer4", + clusterVersion: "3.0", + controlConnMaxProtocolVersion: "4", + negotiatedProtocolVersion: primitive.ProtocolVersion4, // make sure that protocol negotiation does not fail if it is not actually needed + }, + { + name: "Cluster3.0_MaxCCProtoVer3_NegotiatedProtoVer3", + clusterVersion: "3.0", + controlConnMaxProtocolVersion: "3", + negotiatedProtocolVersion: primitive.ProtocolVersion3, // protocol V3 applied as it is the maximum configured + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := setup.NewTestConfig("", "") + c.ControlConnMaxProtocolVersion = tt.controlConnMaxProtocolVersion + testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, true, false, 1, c, + &simulacron.ClusterVersion{tt.clusterVersion, tt.clusterVersion}) + require.Nil(t, err) + defer testSetup.Cleanup() + + query := "SELECT * FROM test" + expectedRows := simulacron.NewRowsResult( + map[string]simulacron.DataType{ + "company": simulacron.DataTypeText, + }).WithRow(map[string]interface{}{ + "company": "TBD", + }) + + err = testSetup.Origin.Prime(simulacron.WhenQuery( + query, + simulacron.NewWhenQueryOptions()). + ThenRowsSuccess(expectedRows)) + require.Nil(t, err) + + // Connect to proxy as a "client" + client := cqlClient.NewCqlClient("127.0.0.1:14002", nil) + cqlClientConn, err := client.ConnectAndInit(context.Background(), tt.negotiatedProtocolVersion, 0) + require.Nil(t, err) + defer cqlClientConn.Close() + + cqlConn, _ := testSetup.Proxy.GetOriginControlConn().GetConnAndContactPoint() + negotiatedProto := cqlConn.GetProtocolVersion() + require.Equal(t, tt.negotiatedProtocolVersion, negotiatedProto) + + queryMsg := &message.Query{ + Query: "SELECT * FROM test", + Options: &message.QueryOptions{Consistency: primitive.ConsistencyLevelOne}, + } + rsp, err := cqlClientConn.SendAndReceive(frame.NewFrame(primitive.ProtocolVersion3, 0, queryMsg)) + if err != nil { + t.Fatal("query failed:", err) + } + + require.Equal(t, 1, len(rsp.Body.Message.(*message.RowsResult).Data)) + }) + } +} + func TestMaxClientsThreshold(t *testing.T) { maxClients := 10 goCqlConnectionsPerHost := 1 @@ -82,20 +159,23 @@ func TestMaxClientsThreshold(t *testing.T) { func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) { tests := []struct { - name string - requestVersion primitive.ProtocolVersion - expectedVersion primitive.ProtocolVersion - errExpected string + name string + requestVersion primitive.ProtocolVersion + negotiatedVersion string + expectedVersion primitive.ProtocolVersion + errExpected string }{ { "request v5, response v4", primitive.ProtocolVersion5, + "4", primitive.ProtocolVersion4, "Invalid or unsupported protocol version (5)", }, { "request v1, response v4", primitive.ProtocolVersion(0x1), + "4", primitive.ProtocolVersion4, "Invalid or unsupported protocol version (1)", }, @@ -112,13 +192,14 @@ func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) { defer zerolog.SetGlobalLevel(oldZeroLogLevel) cfg := setup.NewTestConfig("127.0.1.1", "127.0.1.2") + cfg.ControlConnMaxProtocolVersion = test.negotiatedVersion cfg.LogLevel = "TRACE" // saw 1 test failure here once but logs didn't show enough info testSetup, err := setup.NewCqlServerTestSetup(t, cfg, false, false, false) require.Nil(t, err) defer testSetup.Cleanup() - testSetup.Origin.CqlServer.RequestHandlers = []client2.RequestHandler{client2.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {})} - testSetup.Target.CqlServer.RequestHandlers = []client2.RequestHandler{client2.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {})} + testSetup.Origin.CqlServer.RequestHandlers = []cqlClient.RequestHandler{cqlClient.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {})} + testSetup.Target.CqlServer.RequestHandlers = []cqlClient.RequestHandler{cqlClient.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {})} err = testSetup.Start(cfg, false, primitive.ProtocolVersion3) require.Nil(t, err) @@ -141,16 +222,18 @@ func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) { func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { type test struct { - name string - requestVersion primitive.ProtocolVersion - returnedVersion primitive.ProtocolVersion - expectedVersion primitive.ProtocolVersion - errExpected string + name string + requestVersion primitive.ProtocolVersion + negotiatedVersion string + returnedVersion primitive.ProtocolVersion + expectedVersion primitive.ProtocolVersion + errExpected string } tests := []*test{ { "DSE_V2 request, v5 returned, v4 expected", primitive.ProtocolVersionDse2, + "4", primitive.ProtocolVersion5, primitive.ProtocolVersion4, "Invalid or unsupported protocol version (5)", @@ -158,6 +241,7 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { { "DSE_V2 request, v1 returned, v4 expected", primitive.ProtocolVersionDse2, + "4", primitive.ProtocolVersion(0x01), primitive.ProtocolVersion4, "Invalid or unsupported protocol version (1)", @@ -165,6 +249,7 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { } runTestFunc := func(t *testing.T, test *test, cfg *config.Config) { + cfg.ControlConnMaxProtocolVersion = test.negotiatedVersion // simulate what version was negotiated on control connection testSetup, err := setup.NewCqlServerTestSetup(t, cfg, false, false, false) require.Nil(t, err) defer testSetup.Cleanup() @@ -172,7 +257,7 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { enableHandlers := atomic.Value{} enableHandlers.Store(false) - rawHandler := func(request *frame.Frame, conn *client2.CqlServerConnection, ctx client2.RequestHandlerContext) (response []byte) { + rawHandler := func(request *frame.Frame, conn *cqlClient.CqlServerConnection, ctx cqlClient.RequestHandlerContext) (response []byte) { if enableHandlers.Load().(bool) && request.Header.Version == test.requestVersion { encodedFrame, err := createFrameWithUnsupportedVersion(test.returnedVersion, request.Header.StreamId, true) if err != nil { @@ -184,8 +269,8 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { return nil } - testSetup.Origin.CqlServer.RequestRawHandlers = []client2.RawRequestHandler{rawHandler} - testSetup.Target.CqlServer.RequestRawHandlers = []client2.RawRequestHandler{rawHandler} + testSetup.Origin.CqlServer.RequestRawHandlers = []cqlClient.RawRequestHandler{rawHandler} + testSetup.Target.CqlServer.RequestRawHandlers = []cqlClient.RawRequestHandler{rawHandler} err = testSetup.Start(cfg, false, primitive.ProtocolVersion4) require.Nil(t, err) @@ -222,7 +307,7 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) { } func createFrameWithUnsupportedVersion(version primitive.ProtocolVersion, streamId int16, isResponse bool) ([]byte, error) { - mostSimilarVersion := primitive.ProtocolVersion4 + mostSimilarVersion := version if version > primitive.ProtocolVersionDse2 { mostSimilarVersion = primitive.ProtocolVersionDse2 } else if version < primitive.ProtocolVersion2 { diff --git a/integration-tests/customhandler_test_utils.go b/integration-tests/customhandler_test_utils.go index 7e440d6f..7101ca27 100644 --- a/integration-tests/customhandler_test_utils.go +++ b/integration-tests/customhandler_test_utils.go @@ -57,7 +57,7 @@ var ( releaseVersionColumn = &message.ColumnMetadata{Keyspace: "system", Table: "local", Name: "release_version", Type: datatype.Varchar} rpcAddressColumn = &message.ColumnMetadata{Keyspace: "system", Table: "local", Name: "rpc_address", Type: datatype.Inet} schemaVersionColumn = &message.ColumnMetadata{Keyspace: "system", Table: "local", Name: "schema_version", Type: datatype.Uuid} - tokensColumn = &message.ColumnMetadata{Keyspace: "system", Table: "local", Name: "tokens", Type: datatype.NewSetType(datatype.Varchar)} + tokensColumn = &message.ColumnMetadata{Keyspace: "system", Table: "local", Name: "tokens", Type: datatype.NewSet(datatype.Varchar)} ) // These columns are a subset of the total columns returned by OSS C* 3.11.2, and contain all the information that @@ -78,6 +78,22 @@ var systemLocalColumns = []*message.ColumnMetadata{ tokensColumn, } +// These columns are a subset of the total columns returned by OSS C* 2.0.0, and contain all the information that +// drivers need in order to establish the cluster topology and determine its characteristics. Please note that RPC address +// column is not present. +var systemLocalColumnsProtocolV2 = []*message.ColumnMetadata{ + keyColumn, + clusterNameColumn, + cqlVersionColumn, + datacenterColumn, + hostIdColumn, + partitionerColumn, + rackColumn, + releaseVersionColumn, + schemaVersionColumn, + tokensColumn, +} + var ( peerColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "peer", Type: datatype.Inet} datacenterPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "data_center", Type: datatype.Varchar} @@ -86,7 +102,7 @@ var ( releaseVersionPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "release_version", Type: datatype.Varchar} rpcAddressPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "rpc_address", Type: datatype.Inet} schemaVersionPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "schema_version", Type: datatype.Uuid} - tokensPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "tokens", Type: datatype.NewSetType(datatype.Varchar)} + tokensPeersColumn = &message.ColumnMetadata{Keyspace: "system", Table: "peers", Name: "tokens", Type: datatype.NewSet(datatype.Varchar)} ) // These columns are a subset of the total columns returned by OSS C* 3.11.2, and contain all the information that @@ -114,11 +130,13 @@ var ( func systemLocalRow(cluster string, datacenter string, customPartitioner string, addr net.Addr, version primitive.ProtocolVersion) message.Row { addrBuf := &bytes.Buffer{} - inetAddr := addr.(*net.TCPAddr).IP - if inetAddr.To4() != nil { - addrBuf.Write(inetAddr.To4()) - } else { - addrBuf.Write(inetAddr) + if addr != nil { + inetAddr := addr.(*net.TCPAddr).IP + if inetAddr.To4() != nil { + addrBuf.Write(inetAddr.To4()) + } else { + addrBuf.Write(inetAddr) + } } // emulate {'-9223372036854775808'} (entire ring) tokensBuf := &bytes.Buffer{} @@ -135,18 +153,32 @@ func systemLocalRow(cluster string, datacenter string, customPartitioner string, if customPartitioner != "" { partitionerValue = message.Column(customPartitioner) } + if version >= primitive.ProtocolVersion3 { + return message.Row{ + keyValue, + addrBuf.Bytes(), + message.Column(cluster), + cqlVersionValue, + message.Column(datacenter), + hostIdValue, + addrBuf.Bytes(), + partitionerValue, + rackValue, + releaseVersionValue, + addrBuf.Bytes(), + schemaVersionValue, + tokensBuf.Bytes(), + } + } return message.Row{ keyValue, - addrBuf.Bytes(), message.Column(cluster), cqlVersionValue, message.Column(datacenter), hostIdValue, - addrBuf.Bytes(), partitionerValue, rackValue, releaseVersionValue, - addrBuf.Bytes(), schemaVersionValue, tokensBuf.Bytes(), } diff --git a/integration-tests/functioncalls_test.go b/integration-tests/functioncalls_test.go index 7cf04d76..96fbbfeb 100644 --- a/integration-tests/functioncalls_test.go +++ b/integration-tests/functioncalls_test.go @@ -854,7 +854,7 @@ func TestNowFunctionReplacementPreparedStatement(t *testing.T) { isReplacedNow: false, value: []int{11, 22, 33}, valueSimulacron: []int{11, 22, 33}, - dataType: datatype.NewListType(datatype.Int), + dataType: datatype.NewList(datatype.Int), simulacronType: "list", }, { @@ -880,7 +880,7 @@ func TestNowFunctionReplacementPreparedStatement(t *testing.T) { {1, 2, 3}, {2, 3, 4}, }, - dataType: datatype.NewListType(datatype.NewTupleType(datatype.Int, datatype.Int, datatype.Int)), + dataType: datatype.NewList(datatype.NewTuple(datatype.Int, datatype.Int, datatype.Int)), simulacronType: "list>", }, { @@ -2261,7 +2261,7 @@ func TestNowFunctionReplacementBatchStatement(t *testing.T) { } expectedBatchChildQueries = append(expectedBatchChildQueries, expectedBatchChildQuery) - var queryOrId interface{} + var batchChild *message.BatchChild if childStatement.prepared { when := simulacron.NewWhenQueryOptions() for _, p := range expectedChildQueryParams { @@ -2285,18 +2285,21 @@ func TestNowFunctionReplacementBatchStatement(t *testing.T) { require.Nil(t, err) prepared, ok := resp.Body.Message.(*message.PreparedResult) require.True(t, ok) - queryOrId = prepared.PreparedQueryId + batchChild = &message.BatchChild{ + Id: prepared.PreparedQueryId, + Values: positionalValues, + } validateForwardedPrepare(simulacronSetup.Origin, childStatement) validateForwardedPrepare(simulacronSetup.Target, childStatement) } else { - queryOrId = childStatement.originalQuery + batchChild = &message.BatchChild{ + Query: childStatement.originalQuery, + Values: positionalValues, + } } - batchChildStatements = append(batchChildStatements, &message.BatchChild{ - QueryOrId: queryOrId, - Values: positionalValues, - }) + batchChildStatements = append(batchChildStatements, batchChild) } batchMsg := &message.Batch{ @@ -2325,7 +2328,7 @@ func TestNowFunctionReplacementBatchStatement(t *testing.T) { actualStmt := matching[0].QueriesOrIds[idx] actualParams := matching[0].Values[idx] if childStatement.prepared { - b64ExpectedValue := base64.StdEncoding.EncodeToString(batchChildStatements[idx].QueryOrId.([]byte)) + b64ExpectedValue := base64.StdEncoding.EncodeToString(batchChildStatements[idx].Id) require.Equal(t, b64ExpectedValue, actualStmt, idx) } else { if enableNowReplacement { diff --git a/integration-tests/prepared_statements_test.go b/integration-tests/prepared_statements_test.go index cf8edc11..53a50294 100644 --- a/integration-tests/prepared_statements_test.go +++ b/integration-tests/prepared_statements_test.go @@ -353,9 +353,9 @@ func TestPreparedIdReplacement(t *testing.T) { var batchPrepareMsg *message.Prepare var expectedBatchPrepareMsg *message.Prepare if test.batchQuery != "" { - batchPrepareMsg = prepareMsg.Clone().(*message.Prepare) + batchPrepareMsg = prepareMsg.DeepCopy() batchPrepareMsg.Query = test.batchQuery - expectedBatchPrepareMsg = batchPrepareMsg.Clone().(*message.Prepare) + expectedBatchPrepareMsg = batchPrepareMsg.DeepCopy() expectedBatchPrepareMsg.Query = test.expectedBatchQuery prepareResp, err = testSetup.Client.CqlConnection.SendAndReceive( frame.NewFrame(primitive.ProtocolVersion4, 10, batchPrepareMsg)) @@ -391,15 +391,15 @@ func TestPreparedIdReplacement(t *testing.T) { Type: primitive.BatchTypeLogged, Children: []*message.BatchChild{ { - QueryOrId: test.query, + Query: test.query, // the decoder uses empty slices instead of nil so this has to be initialized this way // so that the equality assertions work later in this test Values: make([]*primitive.Value, 0), }, { - QueryOrId: originBatchPreparedId, - Values: make([]*primitive.Value, 0), + Id: originBatchPreparedId, + Values: make([]*primitive.Value, 0), }, }, Consistency: primitive.ConsistencyLevelLocalQuorum, @@ -482,7 +482,7 @@ func TestPreparedIdReplacement(t *testing.T) { require.Equal(t, originPreparedId, originExecuteMessages[0].QueryId) if expectedOriginBatches > 0 { require.Equal(t, 2, len(originBatchMessages[0].Children)) - require.Equal(t, originBatchPreparedId, originBatchMessages[0].Children[1].QueryOrId) + require.Equal(t, originBatchPreparedId, originBatchMessages[0].Children[1].Id) } for _, targetExecute := range targetExecuteMessages { @@ -491,7 +491,7 @@ func TestPreparedIdReplacement(t *testing.T) { } if expectedTargetBatches > 0 { require.Equal(t, 2, len(targetBatchMessages[0].Children)) - require.Equal(t, targetBatchPreparedId, targetBatchMessages[0].Children[1].QueryOrId) + require.Equal(t, targetBatchPreparedId, targetBatchMessages[0].Children[1].Id) require.NotEqual(t, batchMsg, targetBatchMessages[0]) } @@ -508,8 +508,8 @@ func TestPreparedIdReplacement(t *testing.T) { require.NotEqual(t, len(executeMsg.Options.PositionalValues), len(originExecuteMessages[0].Options.PositionalValues)) // check if only the positional values are different, we test the parameter replacement in depth on other tests - modifiedOriginExecuteMsg := originExecuteMessages[0].Clone() - modifiedOriginExecuteMsg.(*message.Execute).Options.PositionalValues = executeMsg.Options.PositionalValues + modifiedOriginExecuteMsg := originExecuteMessages[0].DeepCopy() + modifiedOriginExecuteMsg.Options.PositionalValues = executeMsg.Options.PositionalValues require.Equal(t, executeMsg, modifiedOriginExecuteMsg) require.Equal(t, originExecuteMessages[0].Options, targetExecuteMessages[0].Options) } else { @@ -524,19 +524,19 @@ func TestPreparedIdReplacement(t *testing.T) { require.Equal(t, expectedBatchPrepareMsg, originPrepareMessages[1]) if test.expectedBatchPreparedStmtVariables != nil { - require.NotEqual(t, batchMsg.Children[0].QueryOrId, originBatchMessages[0].Children[0].QueryOrId) - require.NotEqual(t, batchMsg.Children[0].QueryOrId, targetBatchMessages[0].Children[0].QueryOrId) - require.Equal(t, originBatchMessages[0].Children[0].QueryOrId, targetBatchMessages[0].Children[0].QueryOrId) + batchChildNotEqual(t, batchMsg.Children[0], originBatchMessages[0].Children[0]) + batchChildNotEqual(t, batchMsg.Children[0], targetBatchMessages[0].Children[0]) + batchChildEqual(t, originBatchMessages[0].Children[0], targetBatchMessages[0].Children[0]) require.Equal(t, 0, len(targetBatchMessages[0].Children[0].Values)) require.Equal(t, 0, len(originBatchMessages[0].Children[0].Values)) require.Equal(t, 0, len(batchMsg.Children[0].Values)) - require.Equal(t, batchMsg.Children[1].QueryOrId, originBatchMessages[0].Children[1].QueryOrId) - require.NotEqual(t, batchMsg.Children[1].QueryOrId, targetBatchMessages[0].Children[1].QueryOrId) - require.NotEqual(t, originBatchMessages[0].Children[1].QueryOrId, targetBatchMessages[0].Children[1].QueryOrId) - require.Equal(t, targetBatchPreparedId, targetBatchMessages[0].Children[1].QueryOrId) - require.Equal(t, originBatchPreparedId, originBatchMessages[0].Children[1].QueryOrId) - require.Equal(t, originBatchPreparedId, batchMsg.Children[1].QueryOrId) + batchChildEqual(t, batchMsg.Children[1], originBatchMessages[0].Children[1]) + batchChildNotEqual(t, batchMsg.Children[1], targetBatchMessages[0].Children[1]) + batchChildNotEqual(t, originBatchMessages[0].Children[1], targetBatchMessages[0].Children[1]) + require.Equal(t, targetBatchPreparedId, targetBatchMessages[0].Children[1].Id) + require.Equal(t, originBatchPreparedId, originBatchMessages[0].Children[1].Id) + require.Equal(t, originBatchPreparedId, batchMsg.Children[1].Id) require.Equal(t, len(test.expectedBatchPreparedStmtVariables.Columns), len(targetBatchMessages[0].Children[1].Values)) require.Equal(t, len(test.expectedBatchPreparedStmtVariables.Columns), len(originBatchMessages[0].Children[1].Values)) require.Equal(t, 0, len(batchMsg.Children[1].Values)) @@ -546,8 +546,8 @@ func TestPreparedIdReplacement(t *testing.T) { } else { require.Equal(t, batchMsg, originBatchMessages[0]) require.NotEqual(t, batchMsg, targetBatchMessages[0]) - clonedBatchMsg := targetBatchMessages[0].Clone().(*message.Batch) - clonedBatchMsg.Children[1].QueryOrId = originBatchPreparedId + clonedBatchMsg := targetBatchMessages[0].DeepCopy() + clonedBatchMsg.Children[1].Id = originBatchPreparedId require.Equal(t, batchMsg, clonedBatchMsg) } } @@ -555,6 +555,46 @@ func TestPreparedIdReplacement(t *testing.T) { } } +func batchChildEqual(t *testing.T, child1 *message.BatchChild, child2 *message.BatchChild) { + id := false + if child1.Id != nil && child2.Id != nil { + id = true + require.Equal(t, child1.Id, child2.Id) + } else if child1.Id != nil || child2.Id != nil { + require.Fail(t, "unexpected id field presence: [%v], [%v]", child1.Id, child2.Id) + } + + query := false + if len(child1.Query) > 0 && len(child2.Query) > 0 { + query = true + require.Equal(t, child1.Query, child2.Query) + } else if len(child1.Query) > 0 || len(child2.Query) > 0 { + require.Fail(t, "unexpected query field presence: [%v], [%v]", child1.Query, child2.Query) + } + + require.True(t, id || query, "id or query fields should be present") +} + +func batchChildNotEqual(t *testing.T, child1 *message.BatchChild, child2 *message.BatchChild) { + id := false + if child1.Id != nil && child2.Id != nil { + id = true + require.NotEqual(t, child1.Id, child2.Id) + } else if child1.Id != nil || child2.Id != nil { + require.Fail(t, "unexpected query field presence: [%v], [%v]", child1.Id, child2.Id) + } + + query := false + if len(child1.Query) > 0 && len(child2.Query) > 0 { + query = true + require.NotEqual(t, child1.Query, child2.Query) + } else if len(child1.Query) > 0 || len(child2.Query) > 0 { + require.Fail(t, "unexpected query field presence: [%v], [%v]", child1.Query, child2.Query) + } + + require.True(t, id || query, "id or query fields should be present") +} + func TestUnpreparedIdReplacement(t *testing.T) { type test struct { name string @@ -706,7 +746,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { var batchMsg *message.Batch var batchPrepareMsg *message.Prepare if test.batchQuery != "" { - batchPrepareMsg = prepareMsg.Clone().(*message.Prepare) + batchPrepareMsg = prepareMsg.DeepCopy() batchPrepareMsg.Query = test.batchQuery prepareResp, err = testSetup.Client.CqlConnection.SendAndReceive( frame.NewFrame(primitive.ProtocolVersion4, 10, batchPrepareMsg)) @@ -721,14 +761,14 @@ func TestUnpreparedIdReplacement(t *testing.T) { Type: primitive.BatchTypeLogged, Children: []*message.BatchChild{ { - QueryOrId: test.query, + Query: test.query, // the decoder uses empty slices instead of nil so this has to be initialized this way // so that the equality assertions work later in this test Values: make([]*primitive.Value, 0), }, { - QueryOrId: originBatchPreparedId, - Values: make([]*primitive.Value, 0), + Id: originBatchPreparedId, + Values: make([]*primitive.Value, 0), }, }, Consistency: primitive.ConsistencyLevelLocalQuorum, @@ -843,7 +883,7 @@ func TestUnpreparedIdReplacement(t *testing.T) { if expectedTargetBatches > 0 { for _, batch := range targetBatchMessages { require.Equal(t, 2, len(batch.Children)) - require.Equal(t, targetBatchPreparedId, batch.Children[1].QueryOrId) + require.Equal(t, targetBatchPreparedId, batch.Children[1].Id) require.NotEqual(t, batchMsg, batch) } } @@ -1117,13 +1157,11 @@ func NewPreparedTestHandler( func checkIfPreparedIdMatches(batchMsg *message.Batch, preparedId []byte) (bool, []byte) { var batchPreparedId []byte for _, child := range batchMsg.Children { - switch queryOrId := child.QueryOrId.(type) { - case []byte: - batchPreparedId = queryOrId - if !bytes.Equal(queryOrId, preparedId) { + if child.Id != nil { + batchPreparedId = child.Id + if !bytes.Equal(child.Id, preparedId) { return false, batchPreparedId } - default: } } diff --git a/integration-tests/protocolversions_test.go b/integration-tests/protocolversions_test.go new file mode 100644 index 00000000..4da2980d --- /dev/null +++ b/integration-tests/protocolversions_test.go @@ -0,0 +1,282 @@ +package integration_tests + +import ( + "context" + "fmt" + "github.com/datastax/go-cassandra-native-protocol/client" + "github.com/datastax/go-cassandra-native-protocol/datatype" + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/stretchr/testify/require" + "net" + "slices" + "testing" +) + +// Test that proxy can establish connectivity with ORIGIN and TARGET +// clusters that support different set of protocol versions. Verify also that +// client driver can connect and successfully insert or query data. +func TestProtocolNegotiationDifferentClusters(t *testing.T) { + tests := []struct { + name string + proxyMaxProtoVer string + proxyOriginContConnVer primitive.ProtocolVersion + proxyTargetContConnVer primitive.ProtocolVersion + originProtoVer []primitive.ProtocolVersion + targetProtoVer []primitive.ProtocolVersion + clientProtoVer primitive.ProtocolVersion + failClientConnect bool + failProxyStartup bool + }{ + { + name: "OriginV2_TargetV2_ClientV2", + proxyMaxProtoVer: "2", + proxyOriginContConnVer: primitive.ProtocolVersion2, + proxyTargetContConnVer: primitive.ProtocolVersion2, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + clientProtoVer: primitive.ProtocolVersion2, + }, + { + name: "OriginV2_TargetV2_ClientV2_ProxyControlConnNegotiation", + proxyMaxProtoVer: "4", + proxyOriginContConnVer: primitive.ProtocolVersion2, + proxyTargetContConnVer: primitive.ProtocolVersion2, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + clientProtoVer: primitive.ProtocolVersion2, + }, + { + name: "OriginV2_TargetV23_ClientV2", + proxyMaxProtoVer: "3", + proxyOriginContConnVer: primitive.ProtocolVersion2, + proxyTargetContConnVer: primitive.ProtocolVersion3, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3}, + clientProtoVer: primitive.ProtocolVersion2, + }, + { + name: "OriginV23_TargetV2_ClientV2", + proxyMaxProtoVer: "3", + proxyOriginContConnVer: primitive.ProtocolVersion3, + proxyTargetContConnVer: primitive.ProtocolVersion2, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + clientProtoVer: primitive.ProtocolVersion2, + }, + { + // most common setup with OSS Cassandra + name: "OriginV345_TargetV345_ClientV4", + proxyMaxProtoVer: "DseV2", + proxyOriginContConnVer: primitive.ProtocolVersion4, + proxyTargetContConnVer: primitive.ProtocolVersion4, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + clientProtoVer: primitive.ProtocolVersion4, + }, + { + // most common setup with DSE + name: "OriginV345_TargetV34Dse1Dse2_ClientV4", + proxyMaxProtoVer: "DseV2", + proxyOriginContConnVer: primitive.ProtocolVersion4, + proxyTargetContConnVer: primitive.ProtocolVersionDse2, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersion5}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1, primitive.ProtocolVersionDse2}, + clientProtoVer: primitive.ProtocolVersion4, + }, + { + name: "OriginV2_TargetV3_ClientV2", + proxyMaxProtoVer: "3", + proxyOriginContConnVer: primitive.ProtocolVersion2, + proxyTargetContConnVer: primitive.ProtocolVersion3, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion2}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, + clientProtoVer: primitive.ProtocolVersion2, + // client connection should fail as there is no common protocol version between origin and target + failClientConnect: true, + }, { + name: "OriginV3_TargetV3_ClientV3_Too_Low_Proto_Configured", + proxyMaxProtoVer: "2", + proxyOriginContConnVer: primitive.ProtocolVersion3, + proxyTargetContConnVer: primitive.ProtocolVersion3, + originProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, + targetProtoVer: []primitive.ProtocolVersion{primitive.ProtocolVersion3}, + clientProtoVer: primitive.ProtocolVersion2, + // client proxy startup, because configured protocol version is too low + failProxyStartup: true, + }, + } + + originAddress := "127.0.1.1" + targetAddress := "127.0.1.2" + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + + queryInsert := &message.Query{ + Query: "INSERT INTO test_ks.test(key, value) VALUES(1, '1')", // use INSERT to route request to both clusters + } + querySelect := &message.Query{ + Query: "SELECT * FROM test_ks.test", + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxyConf.ControlConnMaxProtocolVersion = test.proxyMaxProtoVer + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewProtocolNegotiationRequestHandler("origin", "dc1", originAddress, test.originProtoVer) + targetRequestHandler := NewProtocolNegotiationRequestHandler("target", "dc1", targetAddress, test.targetProtoVer) + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, test.clientProtoVer) + require.Nil(t, err) + + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + if test.failProxyStartup { + require.NotNil(t, err) + return + } else { + require.Nil(t, err) + } + + cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), test.clientProtoVer, 0) + if test.failClientConnect { + require.NotNil(t, err) + return + } + require.Nil(t, err) + defer cqlConn.Close() + + response, err := cqlConn.SendAndReceive(frame.NewFrame(test.clientProtoVer, 0, queryInsert)) + require.Nil(t, err) + require.IsType(t, &message.VoidResult{}, response.Body.Message) + + response, err = cqlConn.SendAndReceive(frame.NewFrame(test.clientProtoVer, 0, querySelect)) + require.Nil(t, err) + resultSet := response.Body.Message.(*message.RowsResult).Data + require.Equal(t, 1, len(resultSet)) + + proxyCqlConn, _ := proxy.GetOriginControlConn().GetConnAndContactPoint() + require.Equal(t, test.proxyOriginContConnVer, proxyCqlConn.GetProtocolVersion()) + proxyCqlConn, _ = proxy.GetTargetControlConn().GetConnAndContactPoint() + require.Equal(t, test.proxyTargetContConnVer, proxyCqlConn.GetProtocolVersion()) + }) + } +} + +type ProtocolNegotiationRequestHandler struct { + cluster string + datacenter string + peerIP string + protocolVersions []primitive.ProtocolVersion // accepted protocol versions + // store negotiated protocol versions by socket port number + // protocol version negotiated by proxy on control connections can be different from the one + // used by client driver with ORIGIN and TARGET nodes. In the scenario 'OriginV2_TargetV23_ClientV2', proxy + // will establish control connection with ORIGIN using version 2, and TARGET with version 3. + // Protocol version applied on client connections with TARGET will be different - V2. + negotiatedProtoVer map[int]primitive.ProtocolVersion // negotiated protocol version on different sockets +} + +func NewProtocolNegotiationRequestHandler(cluster string, datacenter string, peerIP string, + protocolVersion []primitive.ProtocolVersion) *ProtocolNegotiationRequestHandler { + return &ProtocolNegotiationRequestHandler{ + cluster: cluster, + datacenter: datacenter, + peerIP: peerIP, + protocolVersions: protocolVersion, + negotiatedProtoVer: make(map[int]primitive.ProtocolVersion), + } +} + +func (recv *ProtocolNegotiationRequestHandler) HandleRequest( + request *frame.Frame, + conn *client.CqlServerConnection, + ctx client.RequestHandlerContext) (response *frame.Frame) { + port := conn.RemoteAddr().(*net.TCPAddr).Port + negotiatedProtoVer := recv.negotiatedProtoVer[port] + if !slices.Contains(recv.protocolVersions, request.Header.Version) || (negotiatedProtoVer != 0 && negotiatedProtoVer != request.Header.Version) { + // server does not support given protocol version, or it was not the one negotiated + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.ProtocolError{ + ErrorMessage: fmt.Sprintf("Invalid or unsupported protocol version (%d)", request.Header.Version), + }) + } + switch request.Body.Message.GetOpCode() { + case primitive.OpCodeStartup: + recv.negotiatedProtoVer[port] = request.Header.Version + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) + case primitive.OpCodeRegister: + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) + case primitive.OpCodeQuery: + query := request.Body.Message.(*message.Query) + switch query.Query { + case "SELECT * FROM system.local": + // C* 2.0.0 does not store local endpoint details in system.local table + sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version) + metadata := &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumns)), + Columns: systemLocalColumns, + } + if negotiatedProtoVer == primitive.ProtocolVersion2 { + metadata = &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumnsProtocolV2)), + Columns: systemLocalColumnsProtocolV2, + } + } + sysLocMsg := &message.RowsResult{ + Metadata: metadata, + Data: message.RowSet{sysLocRow}, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) + case "SELECT * FROM system.peers": + var sysPeerRows message.RowSet + if len(recv.peerIP) > 0 { + sysPeerRows = append(sysPeerRows, systemPeersRow( + recv.datacenter, + &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, + negotiatedProtoVer, + )) + } + sysPeeMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: int32(len(systemPeersColumns)), + Columns: systemPeersColumns, + }, + Data: sysPeerRows, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg) + case "SELECT * FROM test_ks.test": + qryMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: 2, + Columns: []*message.ColumnMetadata{ + {Keyspace: "test_ks", Table: "test", Name: "key", Type: datatype.Varchar}, + {Keyspace: "test_ks", Table: "test", Name: "value", Type: datatype.Uuid}, + }, + }, + Data: message.RowSet{ + message.Row{keyValue, hostIdValue}, + }, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, qryMsg) + case "INSERT INTO test_ks.test(key, value) VALUES(1, '1')": + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.VoidResult{}) + } + } + return nil +} diff --git a/integration-tests/setup/testcluster.go b/integration-tests/setup/testcluster.go index 1eb60144..55ee74f7 100644 --- a/integration-tests/setup/testcluster.go +++ b/integration-tests/setup/testcluster.go @@ -127,22 +127,22 @@ func NewSimulacronTestSetupWithSession(t *testing.T, createProxy bool, createSes } func NewSimulacronTestSetupWithSessionAndConfig(t *testing.T, createProxy bool, createSession bool, config *config.Config) (*SimulacronTestSetup, error) { - return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, 1, config) + return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, 1, config, nil) } func NewSimulacronTestSetupWithSessionAndNodes(t *testing.T, createProxy bool, createSession bool, nodes int) (*SimulacronTestSetup, error) { - return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, nodes, nil) + return NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, createProxy, createSession, nodes, nil, nil) } -func NewSimulacronTestSetupWithSessionAndNodesAndConfig(t *testing.T, createProxy bool, createSession bool, nodes int, config *config.Config) (*SimulacronTestSetup, error) { +func NewSimulacronTestSetupWithSessionAndNodesAndConfig(t *testing.T, createProxy bool, createSession bool, nodes int, config *config.Config, version *simulacron.ClusterVersion) (*SimulacronTestSetup, error) { if !env.RunMockTests { t.Skip("Skipping Simulacron tests, RUN_MOCKTESTS is set false") } - origin, err := simulacron.GetNewCluster(createSession, nodes) + origin, err := simulacron.GetNewCluster(createSession, nodes, version) if err != nil { log.Panic("simulacron origin startup failed: ", err) } - target, err := simulacron.GetNewCluster(createSession, nodes) + target, err := simulacron.GetNewCluster(createSession, nodes, version) if err != nil { log.Panic("simulacron target startup failed: ", err) } @@ -452,6 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config { conf.ReadMode = config.ReadModePrimaryOnly conf.SystemQueriesMode = config.SystemQueriesModeOrigin conf.AsyncHandshakeTimeoutMs = 4000 + conf.ControlConnMaxProtocolVersion = "DseV2" conf.ProxyRequestTimeoutMs = 10000 diff --git a/integration-tests/simulacron/cluster.go b/integration-tests/simulacron/cluster.go index 0c98387c..6423c833 100644 --- a/integration-tests/simulacron/cluster.go +++ b/integration-tests/simulacron/cluster.go @@ -83,14 +83,14 @@ func (baseSimulacron *baseSimulacron) GetId() string { return baseSimulacron.id } -func GetNewCluster(startSession bool, numberOfNodes int) (*Cluster, error) { +func GetNewCluster(startSession bool, numberOfNodes int, version *ClusterVersion) (*Cluster, error) { process, err := GetOrCreateGlobalSimulacronProcess() if err != nil { return nil, err } - cluster, createErr := process.Create(startSession, numberOfNodes) + cluster, createErr := process.Create(startSession, numberOfNodes, version) if createErr != nil { return nil, createErr diff --git a/integration-tests/simulacron/http.go b/integration-tests/simulacron/http.go index cfb18241..ff18967b 100644 --- a/integration-tests/simulacron/http.go +++ b/integration-tests/simulacron/http.go @@ -18,6 +18,11 @@ type ClusterData struct { Datacenters []*DatacenterData `json:"data_centers"` } +type ClusterVersion struct { + Cassandra string + Dse string +} + type DatacenterData struct { Id int `json:"id"` Nodes []*NodeData `json:"nodes"` @@ -31,11 +36,14 @@ type NodeData struct { const createUrl = "/cluster?data_centers=%s&cassandra_version=%s&dse_version=%s&name=%s&activity_log=%s&num_tokens=%d" -func (process *Process) Create(startSession bool, numberOfNodes int) (*Cluster, error) { +func (process *Process) Create(startSession bool, numberOfNodes int, version *ClusterVersion) (*Cluster, error) { + if version == nil { + version = &ClusterVersion{env.CassandraVersion, env.DseVersion} + } name := "test_" + uuid.New().String() resp, err := process.execHttp( "POST", - fmt.Sprintf(createUrl, strconv.FormatInt(int64(numberOfNodes), 10), env.CassandraVersion, env.DseVersion, name, "true", 1), + fmt.Sprintf(createUrl, strconv.FormatInt(int64(numberOfNodes), 10), version.Cassandra, version.Dse, name, "true", 1), nil) if err != nil { diff --git a/integration-tests/streamids_test.go b/integration-tests/streamids_test.go new file mode 100644 index 00000000..4796cc86 --- /dev/null +++ b/integration-tests/streamids_test.go @@ -0,0 +1,257 @@ +package integration_tests + +import ( + "context" + "fmt" + "github.com/datastax/go-cassandra-native-protocol/client" + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/datastax/zdm-proxy/integration-tests/utils" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "net" + "strings" + "sync" + "testing" + "time" +) + +// Test sending more concurrent, async request than allowed stream IDs. +// Origin and target clusters are stubbed and will return protocol error +// if we notice greater stream ID value than expected. We cannot easily test +// exceeding 127 stream IDs allowed in protocol V2, because clients will +// fail serializing the frame +func TestLimitStreamIdsGeneration(t *testing.T) { + originAddress := "127.0.1.1" + targetAddress := "127.0.1.2" + originProtoVer := primitive.ProtocolVersion2 + targetProtoVer := primitive.ProtocolVersion2 + requestCount := 20 + maxStreamIdsConf := 10 + maxStreamIdsExpected := 10 + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + + queryInsert := &message.Query{ + Query: "INSERT INTO test_ks.test(key, value) VALUES(1, '1')", // use INSERT to route request to both clusters + } + + buffer := utils.CreateLogHooks(log.WarnLevel, log.ErrorLevel) + defer log.StandardLogger().ReplaceHooks(make(log.LevelHooks)) + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewMaxStreamIdsRequestHandler("origin", "dc1", originAddress, maxStreamIdsExpected) + targetRequestHandler := NewProtocolNegotiationRequestHandler("target", "dc1", targetAddress, []primitive.ProtocolVersion{targetProtoVer}) + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, originProtoVer) + require.Nil(t, err) + + proxyConf.ProxyMaxStreamIds = maxStreamIdsConf + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + require.Nil(t, err) + + testSetup.Client.CqlClient.MaxInFlight = 127 // set to 127, otherwise we fail to serialize in protocol + cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), originProtoVer, 0) + require.Nil(t, err) + defer cqlConn.Close() + + remainingRequests := requestCount + + for j := 0; j < 10; j++ { + var responses []client.InFlightRequest + for i := 0; i < remainingRequests; i++ { + inFlightReq, err := cqlConn.Send(frame.NewFrame(originProtoVer, 0, queryInsert)) + require.Nil(t, err) + responses = append(responses, inFlightReq) + } + + for _, response := range responses { + select { + case msg := <-response.Incoming(): + if response.Err() != nil { + t.Fatalf(response.Err().Error()) + } + switch msg.Body.Message.(type) { + case *message.VoidResult: + // expected, we have received successful response + remainingRequests-- + case *message.Overloaded: + // client received overloaded message due to insufficient stream ID pool, retry the request + default: + t.Fatalf(response.Err().Error()) + } + } + } + + if remainingRequests == 0 { + break + } + } + + require.True(t, strings.Contains(buffer.String(), "no stream id available")) + + require.True(t, len(originRequestHandler.usedStreamIdsPerConn) >= 1) + for _, idMap := range originRequestHandler.usedStreamIdsPerConn { + maxId := int16(0) + for streamId := range idMap { + if streamId > maxId { + maxId = streamId + } + } + require.True(t, maxId < int16(maxStreamIdsExpected)) + } +} + +func TestFailOnNegativeStreamIDsFromClient(t *testing.T) { + originAddress := "127.0.1.1" + targetAddress := "127.0.1.2" + originProtoVer := primitive.ProtocolVersion2 + targetProtoVer := primitive.ProtocolVersion2 + serverConf := setup.NewTestConfig(originAddress, targetAddress) + proxyConf := setup.NewTestConfig(originAddress, targetAddress) + + queryInsert := &message.Query{ + Query: "INSERT INTO test_ks.test(key, value) VALUES(1, '1')", // use INSERT to route request to both clusters + } + + testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false) + require.Nil(t, err) + defer testSetup.Cleanup() + + originRequestHandler := NewMaxStreamIdsRequestHandler("origin", "dc1", originAddress, 100) + targetRequestHandler := NewProtocolNegotiationRequestHandler("target", "dc1", targetAddress, []primitive.ProtocolVersion{targetProtoVer}) + + testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{ + originRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}), + } + testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{ + targetRequestHandler.HandleRequest, + client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}), + } + + err = testSetup.Start(nil, false, originProtoVer) + require.Nil(t, err) + + proxyConf.ProxyMaxStreamIds = 100 + proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy + if proxy != nil { + defer proxy.Shutdown() + } + require.Nil(t, err) + + cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), originProtoVer, 0) + require.Nil(t, err) + defer cqlConn.Close() + + response, _ := cqlConn.SendAndReceive(frame.NewFrame(originProtoVer, -1, queryInsert)) + require.IsType(t, response.Body.Message, &message.ProtocolError{}) + require.Equal(t, "negative stream id: -1", response.Body.Message.(*message.ProtocolError).ErrorMessage) +} + +type MaxStreamIdsRequestHandler struct { + lock sync.Mutex + cluster string + datacenter string + peerIP string + maxStreamIds int + usedStreamIdsPerConn map[int]map[int16]bool +} + +func NewMaxStreamIdsRequestHandler(cluster string, datacenter string, peerIP string, maxStreamIds int) *MaxStreamIdsRequestHandler { + return &MaxStreamIdsRequestHandler{ + cluster: cluster, + datacenter: datacenter, + peerIP: peerIP, + maxStreamIds: maxStreamIds, + usedStreamIdsPerConn: make(map[int]map[int16]bool), + } +} + +func (recv *MaxStreamIdsRequestHandler) HandleRequest( + request *frame.Frame, + conn *client.CqlServerConnection, + ctx client.RequestHandlerContext) (response *frame.Frame) { + port := conn.RemoteAddr().(*net.TCPAddr).Port + + switch request.Body.Message.GetOpCode() { + case primitive.OpCodeStartup: + case primitive.OpCodeRegister: + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) + case primitive.OpCodeQuery: + query := request.Body.Message.(*message.Query) + switch query.Query { + case "SELECT * FROM system.local": + // C* 2.0.0 does not store local endpoint details in system.local table + sysLocRow := systemLocalRow(recv.cluster, recv.datacenter, "Murmur3Partitioner", nil, request.Header.Version) + metadata := &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumns)), + Columns: systemLocalColumns, + } + if request.Header.Version == primitive.ProtocolVersion2 { + metadata = &message.RowsMetadata{ + ColumnCount: int32(len(systemLocalColumnsProtocolV2)), + Columns: systemLocalColumnsProtocolV2, + } + } + sysLocMsg := &message.RowsResult{ + Metadata: metadata, + Data: message.RowSet{sysLocRow}, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysLocMsg) + case "SELECT * FROM system.peers": + var sysPeerRows message.RowSet + if len(recv.peerIP) > 0 { + sysPeerRows = append(sysPeerRows, systemPeersRow( + recv.datacenter, + &net.TCPAddr{IP: net.ParseIP(recv.peerIP), Port: 9042}, + request.Header.Version, + )) + } + sysPeeMsg := &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: int32(len(systemPeersColumns)), + Columns: systemPeersColumns, + }, + Data: sysPeerRows, + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, sysPeeMsg) + case "INSERT INTO test_ks.test(key, value) VALUES(1, '1')": + recv.lock.Lock() + usedStreamIdsMap := recv.usedStreamIdsPerConn[port] + if usedStreamIdsMap == nil { + usedStreamIdsMap = make(map[int16]bool) + recv.usedStreamIdsPerConn[port] = usedStreamIdsMap + } + usedStreamIdsMap[request.Header.StreamId] = true + recv.lock.Unlock() + + time.Sleep(5 * time.Millisecond) // introduce some delay so that stream IDs are not released immediately + + if len(usedStreamIdsMap) > recv.maxStreamIds { + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.ProtocolError{ + ErrorMessage: fmt.Sprintf("Too many stream IDs used (%d)", len(usedStreamIdsMap)), + }) + } + return frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.VoidResult{}) + } + } + return nil +} diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index afd2f220..cc62842b 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -3,6 +3,7 @@ package config import ( "encoding/json" "fmt" + "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/datastax/zdm-proxy/proxy/pkg/common" "github.com/kelseyhightower/envconfig" def "github.com/mcuadros/go-defaults" @@ -19,11 +20,12 @@ type Config struct { // Global bucket - PrimaryCluster string `default:"ORIGIN" split_words:"true" yaml:"primary_cluster"` - ReadMode string `default:"PRIMARY_ONLY" split_words:"true" yaml:"read_mode"` - ReplaceCqlFunctions bool `default:"false" split_words:"true" yaml:"replace_cql_functions"` - AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true" yaml:"async_handshake_timeout_ms"` - LogLevel string `default:"INFO" split_words:"true" yaml:"log_level"` + PrimaryCluster string `default:"ORIGIN" split_words:"true" yaml:"primary_cluster"` + ReadMode string `default:"PRIMARY_ONLY" split_words:"true" yaml:"read_mode"` + ReplaceCqlFunctions bool `default:"false" split_words:"true" yaml:"replace_cql_functions"` + AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true" yaml:"async_handshake_timeout_ms"` + LogLevel string `default:"INFO" split_words:"true" yaml:"log_level"` + ControlConnMaxProtocolVersion string `default:"DseV2" split_words:"true" yaml:"control_conn_max_protocol_version"` // Numeric Cassandra OSS protocol version or DseV1 / DseV2 // Proxy Topology (also known as system.peers "virtualization") bucket @@ -315,6 +317,11 @@ func (c *Config) Validate() error { return err } + _, err = c.ParseControlConnMaxProtocolVersion() + if err != nil { + return err + } + return nil } @@ -369,6 +376,24 @@ func (c *Config) ParseReadMode() (common.ReadMode, error) { } } +func (c *Config) ParseControlConnMaxProtocolVersion() (primitive.ProtocolVersion, error) { + if strings.EqualFold(c.ControlConnMaxProtocolVersion, "DseV2") { + return primitive.ProtocolVersionDse2, nil + } + if strings.EqualFold(c.ControlConnMaxProtocolVersion, "DseV1") { + return primitive.ProtocolVersionDse1, nil + } + ver, err := strconv.ParseUint(c.ControlConnMaxProtocolVersion, 10, 32) + if err != nil { + return 0, fmt.Errorf("could not parse control connection max protocol version, valid values are "+ + "2, 3, 4, DseV1, DseV2; original err: %w", err) + } + if ver < 2 || ver > 4 { + return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2") + } + return primitive.ProtocolVersion(ver), nil +} + func (c *Config) ParseLogLevel() (log.Level, error) { level, err := log.ParseLevel(strings.TrimSpace(c.LogLevel)) if err != nil { diff --git a/proxy/pkg/config/config_test.go b/proxy/pkg/config/config_test.go index e211fc5d..74eaa557 100644 --- a/proxy/pkg/config/config_test.go +++ b/proxy/pkg/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/stretchr/testify/require" "testing" ) @@ -94,6 +95,98 @@ func TestTargetConfig_WithHostnameButWithoutPort(t *testing.T) { require.Equal(t, 9042, c.TargetPort) } +func TestTargetConfig_ParsingControlConnMaxProtocolVersion(t *testing.T) { + defer clearAllEnvVars() + + // general setup + clearAllEnvVars() + setOriginCredentialsEnvVars() + setTargetCredentialsEnvVars() + setOriginContactPointsAndPortEnvVars() + + // test-specific setup + setTargetContactPointsAndPortEnvVars() + + conf := New() + err := conf.parseEnvVars() + require.Nil(t, err) + + tests := []struct { + name string + controlConnMaxProtocolVersion string + parsedProtocolVersion primitive.ProtocolVersion + errorMessage string + }{ + { + name: "ParsedV2", + controlConnMaxProtocolVersion: "2", + parsedProtocolVersion: primitive.ProtocolVersion2, + errorMessage: "", + }, + { + name: "ParsedV3", + controlConnMaxProtocolVersion: "3", + parsedProtocolVersion: primitive.ProtocolVersion3, + errorMessage: "", + }, + { + name: "ParsedV4", + controlConnMaxProtocolVersion: "4", + parsedProtocolVersion: primitive.ProtocolVersion4, + errorMessage: "", + }, + { + name: "ParsedDse1", + controlConnMaxProtocolVersion: "DseV1", + parsedProtocolVersion: primitive.ProtocolVersionDse1, + errorMessage: "", + }, + { + name: "ParsedDse2", + controlConnMaxProtocolVersion: "DseV2", + parsedProtocolVersion: primitive.ProtocolVersionDse2, + errorMessage: "", + }, + { + name: "ParsedDse2CaseInsensitive", + controlConnMaxProtocolVersion: "dsev2", + parsedProtocolVersion: primitive.ProtocolVersionDse2, + errorMessage: "", + }, + { + name: "UnsupportedCassandraV5", + controlConnMaxProtocolVersion: "5", + parsedProtocolVersion: 0, + errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", + }, + { + name: "UnsupportedCassandraV1", + controlConnMaxProtocolVersion: "1", + parsedProtocolVersion: 0, + errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", + }, + { + name: "InvalidValue", + controlConnMaxProtocolVersion: "Dsev123", + parsedProtocolVersion: 0, + errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, DseV1, DseV2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conf.ControlConnMaxProtocolVersion = tt.controlConnMaxProtocolVersion + ver, err := conf.ParseControlConnMaxProtocolVersion() + if ver == 0 { + require.NotNil(t, err) + require.Contains(t, err.Error(), tt.errorMessage) + } else { + require.Equal(t, tt.parsedProtocolVersion, ver) + } + }) + } +} + func TestConfig_LoadNotExistingFile(t *testing.T) { defer clearAllEnvVars() clearAllEnvVars() diff --git a/proxy/pkg/zdmproxy/clientconn.go b/proxy/pkg/zdmproxy/clientconn.go index b6bafe65..33e8b66d 100644 --- a/proxy/pkg/zdmproxy/clientconn.go +++ b/proxy/pkg/zdmproxy/clientconn.go @@ -56,6 +56,8 @@ type ClientConnector struct { readScheduler *Scheduler shutdownRequestCtx context.Context + + minProtoVer primitive.ProtocolVersion } func NewClientConnector( @@ -71,7 +73,8 @@ func NewClientConnector( readScheduler *Scheduler, writeScheduler *Scheduler, shutdownRequestCtx context.Context, - clientHandlerShutdownRequestCancelFn context.CancelFunc) *ClientConnector { + clientHandlerShutdownRequestCancelFn context.CancelFunc, + minProtoVer primitive.ProtocolVersion) *ClientConnector { return &ClientConnector{ connection: connection, @@ -97,6 +100,7 @@ func NewClientConnector( readScheduler: readScheduler, shutdownRequestCtx: shutdownRequestCtx, clientHandlerShutdownRequestCancelFn: clientHandlerShutdownRequestCancelFn, + minProtoVer: minProtoVer, } } @@ -176,7 +180,7 @@ func (cc *ClientConnector) listenForRequests() { for cc.clientHandlerContext.Err() == nil { f, err := readRawFrame(bufferedReader, connectionAddr, cc.clientHandlerContext) - protocolErrResponseFrame, err, _ := checkProtocolError(f, err, protocolErrOccurred, ClientConnectorLogPrefix) + protocolErrResponseFrame, err, _ := checkProtocolError(f, cc.minProtoVer, err, protocolErrOccurred, ClientConnectorLogPrefix) if err != nil { handleConnectionError( err, cc.clientHandlerContext, cc.clientHandlerCancelFunc, ClientConnectorLogPrefix, "reading", connectionAddr) @@ -187,7 +191,7 @@ func (cc *ClientConnector) listenForRequests() { cc.sendResponseToClient(protocolErrResponseFrame) continue } else if alreadySentProtocolErr != nil { - clonedProtocolErr := alreadySentProtocolErr.Clone() + clonedProtocolErr := alreadySentProtocolErr.DeepCopy() clonedProtocolErr.Header.StreamId = f.Header.StreamId cc.sendResponseToClient(clonedProtocolErr) continue @@ -224,7 +228,7 @@ func (cc *ClientConnector) sendOverloadedToClient(request *frame.RawFrame) { } } -func checkProtocolError(f *frame.RawFrame, connErr error, protocolErrorOccurred bool, prefix string) (protocolErrResponse *frame.RawFrame, fatalErr error, errorCode int8) { +func checkProtocolError(f *frame.RawFrame, protoVer primitive.ProtocolVersion, connErr error, protocolErrorOccurred bool, prefix string) (protocolErrResponse *frame.RawFrame, fatalErr error, errorCode int8) { var protocolErrMsg *message.ProtocolError var streamId int16 var logMsg string @@ -244,7 +248,7 @@ func checkProtocolError(f *frame.RawFrame, connErr error, protocolErrorOccurred if !protocolErrorOccurred { log.Debugf("[%v] %v Returning a protocol error to the client to force a downgrade: %v.", prefix, logMsg, protocolErrMsg) } - rawProtocolErrResponse, err := generateProtocolErrorResponseFrame(streamId, protocolErrMsg) + rawProtocolErrResponse, err := generateProtocolErrorResponseFrame(streamId, protoVer, protocolErrMsg) if err != nil { return nil, fmt.Errorf("could not generate protocol error response raw frame (%v): %v", protocolErrMsg, err), -1 } else { @@ -255,10 +259,8 @@ func checkProtocolError(f *frame.RawFrame, connErr error, protocolErrorOccurred } } -func generateProtocolErrorResponseFrame(streamId int16, protocolErrMsg *message.ProtocolError) (*frame.RawFrame, error) { - // ideally we would use the maximum version between the versions used by both control connections if - // control connections implemented protocol version negotiation - response := frame.NewFrame(primitive.ProtocolVersion4, streamId, protocolErrMsg) +func generateProtocolErrorResponseFrame(streamId int16, protoVer primitive.ProtocolVersion, protocolErrMsg *message.ProtocolError) (*frame.RawFrame, error) { + response := frame.NewFrame(protoVer, streamId, protocolErrMsg) rawResponse, err := defaultCodec.ConvertToRawFrame(response) if err != nil { return nil, err diff --git a/proxy/pkg/zdmproxy/clienthandler.go b/proxy/pkg/zdmproxy/clienthandler.go index 1eb5e1dd..066acf0a 100644 --- a/proxy/pkg/zdmproxy/clienthandler.go +++ b/proxy/pkg/zdmproxy/clienthandler.go @@ -16,6 +16,7 @@ import ( log "github.com/sirupsen/logrus" "net" "sort" + "strings" "sync" "sync/atomic" "time" @@ -168,9 +169,18 @@ func NewClientHandler( requestsDoneCtx, requestsDoneCancelFn := context.WithCancel(context.Background()) // Initialize stream id processors to manage the ids sent to the clusters - originFrameProcessor := newFrameProcessor(conf, nodeMetrics, ClusterConnectorTypeOrigin) - targetFrameProcessor := newFrameProcessor(conf, nodeMetrics, ClusterConnectorTypeTarget) - asyncFrameProcessor := newFrameProcessor(conf, nodeMetrics, ClusterConnectorTypeAsync) + originCCProtoVer := originControlConn.cqlConn.GetProtocolVersion() + targetCCProtoVer := targetControlConn.cqlConn.GetProtocolVersion() + // Calculate maximum number of stream IDs. Take the oldest protocol version negotiated between two clusters + // and apply limit defined in proxy configuration. If origin or target cluster are still running protocol V2, + // we will limit maximum number of stream IDs to 127 on both clusters. Logic is based on Java driver version 3.x. + // Java driver 3.x was the last one supporting protocol V2. It establishes control connection first, and then + // uses negotiated protocol version to configure maximum number of stream IDs on node connections. Driver does NOT + // change the number of stream IDs on per node basis. Maximum stream ID is calculated while creating stream ID mapper. + minimalProtoVer := minProtoVer(originCCProtoVer, targetCCProtoVer) + originFrameProcessor := newFrameProcessor(minimalProtoVer, conf, nodeMetrics, ClusterConnectorTypeOrigin) + targetFrameProcessor := newFrameProcessor(minimalProtoVer, conf, nodeMetrics, ClusterConnectorTypeTarget) + asyncFrameProcessor := newFrameProcessor(minimalProtoVer, conf, nodeMetrics, ClusterConnectorTypeAsync) closeFrameProcessors := func() { originFrameProcessor.Close() @@ -197,7 +207,7 @@ func NewClientHandler( originConnector, err := NewClusterConnector( originCassandraConnInfo, conf, psCache, nodeMetrics, localClientHandlerWg, clientHandlerRequestWg, clientHandlerContext, clientHandlerCancelFunc, respChannel, readScheduler, writeScheduler, requestsDoneCtx, - false, nil, handshakeDone, originFrameProcessor) + false, nil, handshakeDone, originFrameProcessor, originCCProtoVer) if err != nil { clientHandlerCancelFunc() return nil, err @@ -206,7 +216,7 @@ func NewClientHandler( targetConnector, err := NewClusterConnector( targetCassandraConnInfo, conf, psCache, nodeMetrics, localClientHandlerWg, clientHandlerRequestWg, clientHandlerContext, clientHandlerCancelFunc, respChannel, readScheduler, writeScheduler, requestsDoneCtx, - false, nil, handshakeDone, targetFrameProcessor) + false, nil, handshakeDone, targetFrameProcessor, targetCCProtoVer) if err != nil { clientHandlerCancelFunc() return nil, err @@ -224,7 +234,7 @@ func NewClientHandler( asyncConnector, err = NewClusterConnector( asyncConnInfo, conf, psCache, nodeMetrics, localClientHandlerWg, clientHandlerRequestWg, clientHandlerContext, clientHandlerCancelFunc, respChannel, readScheduler, writeScheduler, requestsDoneCtx, - true, asyncPendingRequests, handshakeDone, asyncFrameProcessor) + true, asyncPendingRequests, handshakeDone, asyncFrameProcessor, originCCProtoVer) if err != nil { log.Errorf("Could not create async cluster connector to %s, async requests will not be forwarded: %s", asyncConnInfo.connConfig.GetClusterType(), err.Error()) asyncConnector = nil @@ -260,7 +270,8 @@ func NewClientHandler( readScheduler, writeScheduler, clientHandlerShutdownRequestContext, - clientHandlerShutdownRequestCancelFn), + clientHandlerShutdownRequestCancelFn, + minProtoVer(originCCProtoVer, targetCCProtoVer)), asyncConnector: asyncConnector, originCassandraConnector: originConnector, @@ -891,7 +902,7 @@ func (ch *ClientHandler) processClientResponse( return nil, fmt.Errorf("invalid cluster type: %v", responseClusterType) } - newFrame = decodedFrame.Clone() + newFrame = decodedFrame.DeepCopy() newUnprepared := &message.Unprepared{ ErrorMessage: fmt.Sprintf("Prepared query with ID %s not found (either the query was not prepared "+ "on this host (maybe the host has been restarted?) or you have prepared too many queries and it has "+ @@ -945,7 +956,7 @@ func (ch *ClientHandler) processPreparedResponse( return nil, fmt.Errorf("replaced terms in the prepared statement but prepared result doesn't have variables metadata: %v", bodyMsg) } - newResponse = response.Clone() + newResponse = response.DeepCopy() newPreparedBody, ok := newResponse.Body.Message.(*message.PreparedResult) if !ok { return nil, fmt.Errorf("could not modify prepared result to remove generated parameters because "+ @@ -1482,17 +1493,27 @@ func (ch *ClientHandler) executeRequest( case forwardToBoth: log.Tracef("Forwarding request with opcode %v for stream %v to %v and %v", f.Header.OpCode, f.Header.StreamId, common.ClusterTypeOrigin, common.ClusterTypeTarget) - ch.originCassandraConnector.sendRequestToCluster(originRequest) - ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + sendErr := ch.originCassandraConnector.sendRequestToCluster(originRequest) + if sendErr != nil { + ch.handleRequestSendFailure(sendErr, frameContext) + } else { + ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + } case forwardToOrigin: log.Tracef("Forwarding request with opcode %v for stream %v to %v", f.Header.OpCode, f.Header.StreamId, common.ClusterTypeOrigin) - ch.originCassandraConnector.sendRequestToCluster(originRequest) + sendErr := ch.originCassandraConnector.sendRequestToCluster(originRequest) + if sendErr != nil { + ch.handleRequestSendFailure(sendErr, frameContext) + } ch.targetCassandraConnector.sendHeartbeat(startupFrameVersion, ch.conf.HeartbeatIntervalMs) case forwardToTarget: log.Tracef("Forwarding request with opcode %v for stream %v to %v", f.Header.OpCode, f.Header.StreamId, common.ClusterTypeTarget) - ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + sendErr := ch.targetCassandraConnector.sendRequestToCluster(targetRequest) + if sendErr != nil { + ch.handleRequestSendFailure(sendErr, frameContext) + } ch.originCassandraConnector.sendHeartbeat(startupFrameVersion, ch.conf.HeartbeatIntervalMs) case forwardToAsyncOnly: default: @@ -1513,6 +1534,21 @@ func (ch *ClientHandler) executeRequest( overallRequestStartTime, requestTimeout) } +func (ch *ClientHandler) handleRequestSendFailure(err error, frameContext *frameDecodeContext) { + if strings.Contains(err.Error(), "no stream id available") { + ch.clientConnector.sendOverloadedToClient(frameContext.frame) + } else if strings.Contains(err.Error(), "negative stream id") { + responseMessage := &message.ProtocolError{ErrorMessage: err.Error()} + responseFrame, err := generateProtocolErrorResponseFrame( + frameContext.frame.Header.StreamId, frameContext.frame.Header.Version, responseMessage) + if err != nil { + log.Errorf("could not generate protocol error response raw frame (%v): %v", responseMessage, err) + } else { + ch.clientConnector.sendResponseToClient(responseFrame) + } + } +} + func (ch *ClientHandler) handleInterceptedRequest( requestInfo RequestInfo, frameContext *frameDecodeContext, currentKeyspace string) (*frame.RawFrame, error) { @@ -1655,7 +1691,7 @@ func (ch *ClientHandler) handleExecuteRequest( } replacementTimeUuids = ch.parameterModifier.generateTimeUuids(prepareRequestInfo) - newOriginRequest := clientRequest.Clone() + newOriginRequest := clientRequest.DeepCopy() _, err = ch.parameterModifier.AddValuesToExecuteFrame( newOriginRequest, prepareRequestInfo, preparedData.GetOriginVariablesMetadata(), replacementTimeUuids) if err != nil { @@ -1677,7 +1713,7 @@ func (ch *ClientHandler) handleExecuteRequest( return nil, nil, nil, fmt.Errorf("could not decode execute raw frame: %w", err) } - newTargetRequest := clientRequest.Clone() + newTargetRequest := clientRequest.DeepCopy() var newTargetExecuteMsg *message.Execute if len(replacedTerms) > 0 { if replacementTimeUuids == nil { @@ -1726,7 +1762,7 @@ func (ch *ClientHandler) handleBatchRequest( var newOriginRequest *frame.Frame var newOriginBatchMsg *message.Batch - newTargetRequest := decodedFrame.Clone() + newTargetRequest := decodedFrame.DeepCopy() newTargetBatchMsg, ok := newTargetRequest.Body.Message.(*message.Batch) if !ok { return nil, nil, fmt.Errorf("expected Batch but got %v instead", newTargetRequest.Body.Message.GetOpCode()) @@ -1736,7 +1772,7 @@ func (ch *ClientHandler) handleBatchRequest( prepareRequestInfo := preparedData.GetPrepareRequestInfo() if len(prepareRequestInfo.GetReplacedTerms()) > 0 { if newOriginRequest == nil { - newOriginRequest = decodedFrame.Clone() + newOriginRequest = decodedFrame.DeepCopy() newOriginBatchMsg, ok = newOriginRequest.Body.Message.(*message.Batch) if !ok { return nil, nil, fmt.Errorf("expected Batch but got %v instead", newOriginRequest.Body.Message.GetOpCode()) @@ -1754,8 +1790,8 @@ func (ch *ClientHandler) handleBatchRequest( } } - originalQueryId := newTargetBatchMsg.Children[stmtIdx].QueryOrId.([]byte) - newTargetBatchMsg.Children[stmtIdx].QueryOrId = preparedData.GetTargetPreparedId() + originalQueryId := newTargetBatchMsg.Children[stmtIdx].Id + newTargetBatchMsg.Children[stmtIdx].Id = preparedData.GetTargetPreparedId() log.Tracef("Replacing prepared ID %s within a BATCH with %s for target cluster.", hex.EncodeToString(originalQueryId), hex.EncodeToString(preparedData.GetTargetPreparedId())) } @@ -1803,7 +1839,7 @@ func (ch *ClientHandler) sendToAsyncConnector( } if sendAlsoToAsync { - asyncRequest = asyncRequest.Clone() // forwardToAsyncOnly requests don't need to be cloned because they are only sent to 1 connector + asyncRequest = asyncRequest.DeepCopy() // forwardToAsyncOnly requests don't need to be cloned because they are only sent to 1 connector } if isFireAndForget { @@ -2230,7 +2266,8 @@ func GetNodeMetricsByClusterConnector(nodeMetrics *metrics.NodeMetrics, connecto } } -func newFrameProcessor(conf *config.Config, nodeMetrics *metrics.NodeMetrics, connectorType ClusterConnectorType) FrameProcessor { +func newFrameProcessor(protoVer primitive.ProtocolVersion, config *config.Config, nodeMetrics *metrics.NodeMetrics, + connectorType ClusterConnectorType) FrameProcessor { var streamIdsMetric metrics.Gauge connectorMetrics, err := GetNodeMetricsByClusterConnector(nodeMetrics, connectorType) if err != nil { @@ -2241,9 +2278,16 @@ func newFrameProcessor(conf *config.Config, nodeMetrics *metrics.NodeMetrics, co } var mapper StreamIdMapper if connectorType == ClusterConnectorTypeAsync { - mapper = NewInternalStreamIdMapper(conf.ProxyMaxStreamIds, streamIdsMetric) + mapper = NewInternalStreamIdMapper(protoVer, config, streamIdsMetric) } else { - mapper = NewStreamIdMapper(conf.ProxyMaxStreamIds, streamIdsMetric) + mapper = NewStreamIdMapper(protoVer, config, streamIdsMetric) } return NewStreamIdProcessor(mapper) } + +func minProtoVer(version1 primitive.ProtocolVersion, version2 primitive.ProtocolVersion) primitive.ProtocolVersion { + if version1 < version2 { + return version1 + } + return version2 +} diff --git a/proxy/pkg/zdmproxy/clienthandler_test.go b/proxy/pkg/zdmproxy/clienthandler_test.go new file mode 100644 index 00000000..b5452557 --- /dev/null +++ b/proxy/pkg/zdmproxy/clienthandler_test.go @@ -0,0 +1,34 @@ +package zdmproxy + +import ( + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/proxy/pkg/config" + "github.com/stretchr/testify/require" + "testing" +) + +func TestMaxStreamIds(t *testing.T) { + type args struct { + originProtoVer primitive.ProtocolVersion + targetProtoVer primitive.ProtocolVersion + config *config.Config + expectedMaxStreamIds int + } + tests := []struct { + name string + args args + expectedMaxStreamIds int + }{ + {"OriginV3_TargetV3_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion3, targetProtoVer: primitive.ProtocolVersion3, config: &config.Config{ProxyMaxStreamIds: 2048}}, 2048}, + {"OriginV3_TargetV4_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion3, targetProtoVer: primitive.ProtocolVersion4, config: &config.Config{ProxyMaxStreamIds: 2048}}, 2048}, + {"OriginV3_TargetV4_LowerConfig", args{originProtoVer: primitive.ProtocolVersion3, targetProtoVer: primitive.ProtocolVersion4, config: &config.Config{ProxyMaxStreamIds: 1024}}, 1024}, + {"OriginV2_TargetV3_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion2, targetProtoVer: primitive.ProtocolVersion3, config: &config.Config{ProxyMaxStreamIds: 2048}}, 127}, + {"OriginV2_TargetV2_DefaultConfig", args{originProtoVer: primitive.ProtocolVersion2, targetProtoVer: primitive.ProtocolVersion2, config: &config.Config{ProxyMaxStreamIds: 2048}}, 127}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ids := maxStreamIds(minProtoVer(tt.args.originProtoVer, tt.args.targetProtoVer), tt.args.config) + require.Equal(t, tt.expectedMaxStreamIds, ids) + }) + } +} diff --git a/proxy/pkg/zdmproxy/clusterconn.go b/proxy/pkg/zdmproxy/clusterconn.go index a1a35e93..deeeaa45 100644 --- a/proxy/pkg/zdmproxy/clusterconn.go +++ b/proxy/pkg/zdmproxy/clusterconn.go @@ -75,6 +75,8 @@ type ClusterConnector struct { lastHeartbeatTime *atomic.Value lastHeartbeatLock sync.Mutex + + ccProtoVer primitive.ProtocolVersion } func NewClusterConnectionInfo(connConfig ConnectionConfig, endpointConfig Endpoint, isOriginCassandra bool) *ClusterConnectionInfo { @@ -101,7 +103,8 @@ func NewClusterConnector( asyncConnector bool, asyncPendingRequests *pendingRequests, handshakeDone *atomic.Value, - frameProcessor FrameProcessor) (*ClusterConnector, error) { + frameProcessor FrameProcessor, + ccProtoVer primitive.ProtocolVersion) (*ClusterConnector, error) { var connectorType ClusterConnectorType var clusterType common.ClusterType @@ -181,6 +184,7 @@ func NewClusterConnector( asyncPendingRequests: asyncPendingRequests, handshakeDone: handshakeDone, lastHeartbeatTime: lastHeartbeatTime, + ccProtoVer: ccProtoVer, }, nil } @@ -247,7 +251,7 @@ func (cc *ClusterConnector) runResponseListeningLoop() { protocolErrOccurred := false for { response, err := readRawFrame(bufferedReader, connectionAddr, cc.clusterConnContext) - protocolErrResponseFrame, err, errCode := checkProtocolError(response, err, protocolErrOccurred, string(cc.connectorType)) + protocolErrResponseFrame, err, errCode := checkProtocolError(response, cc.ccProtoVer, err, protocolErrOccurred, string(cc.connectorType)) if err != nil { handleConnectionError( @@ -394,17 +398,18 @@ func (cc *ClusterConnector) handleAsyncResponse(response *frame.RawFrame) *frame return nil } -func (cc *ClusterConnector) sendRequestToCluster(frame *frame.RawFrame) { +func (cc *ClusterConnector) sendRequestToCluster(frame *frame.RawFrame) error { var err error if cc.frameProcessor != nil { frame, err = cc.frameProcessor.AssignUniqueId(frame) } if err != nil { log.Errorf("[%v] Couldn't assign stream id to frame %v: %v", string(cc.connectorType), frame.Header.OpCode, err) - return + return err } else { cc.writeCoalescer.Enqueue(frame) } + return nil } func (cc *ClusterConnector) validateAsyncStateForRequest(frame *frame.RawFrame) bool { diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index e32bc967..17fd8048 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -2,6 +2,7 @@ package zdmproxy import ( "context" + "errors" "fmt" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" @@ -58,7 +59,6 @@ type ControlConn struct { const ProxyVirtualRack = "rack0" const ProxyVirtualPartitioner = "org.apache.cassandra.dht.Murmur3Partitioner" -const ccProtocolVersion = primitive.ProtocolVersion3 const ccWriteTimeout = 5 * time.Second const ccReadTimeout = 10 * time.Second @@ -125,7 +125,7 @@ func (cc *ControlConn) Start(wg *sync.WaitGroup, ctx context.Context) error { log.Infof("Received topology event from %v, refreshing topology.", cc.connConfig.GetClusterType()) - conn, _ := cc.getConnAndContactPoint() + conn, _ := cc.GetConnAndContactPoint() if conn == nil { log.Debugf("Topology refresh scheduled but the control connection isn't open. " + "Falling back to the connection where the event was received.") @@ -162,7 +162,7 @@ func (cc *ControlConn) Start(wg *sync.WaitGroup, ctx context.Context) error { cc.Close() } - conn, _ := cc.getConnAndContactPoint() + conn, _ := cc.GetConnAndContactPoint() if conn == nil { useContactPointsOnly := false if !lastOpenSuccessful { @@ -251,7 +251,7 @@ func (cc *ControlConn) ReadFailureCounter() int { } func (cc *ControlConn) Open(contactPointsOnly bool, ctx context.Context) (CqlConnection, error) { - oldConn, _ := cc.getConnAndContactPoint() + oldConn, _ := cc.GetConnAndContactPoint() if oldConn != nil { cc.Close() oldConn = nil @@ -320,15 +320,10 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( currentIndex := (firstEndpointIndex + i) % len(endpoints) endpoint = endpoints[currentIndex] - tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false) - if err != nil { - log.Warnf("Failed to open control connection to %v using endpoint %v: %v", - cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err) - continue - } - newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf) - err = newConn.InitializeContext(ccProtocolVersion, ctx) + maxProtoVer, _ := cc.conf.ParseControlConnMaxProtocolVersion() + newConn, err := cc.connAndNegotiateProtoVer(endpoint, maxProtoVer, ctx) + if err == nil { newConn.SetEventHandler(func(f *frame.Frame, c CqlConnection) { switch f.Body.Message.(type) { @@ -355,23 +350,71 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( log.Warnf("Error while initializing a new cql connection for the control connection of %v: %v", cc.connConfig.GetClusterType(), err) } - err2 := newConn.Close() - if err2 != nil { - log.Errorf("Failed to close cql connection: %v", err2) + if newConn != nil { + err2 := newConn.Close() + if err2 != nil { + log.Errorf("Failed to close cql connection: %v", err2) + } } continue } conn = newConn - log.Infof("Successfully opened control connection to %v using endpoint %v.", - cc.connConfig.GetClusterType(), endpoint.String()) + log.Infof("Successfully opened control connection to %v using endpoint %v with %v.", + cc.connConfig.GetClusterType(), endpoint.String(), newConn.GetProtocolVersion()) break } return conn, endpoint } +func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoVer primitive.ProtocolVersion, ctx context.Context) (CqlConnection, error) { + protoVer := initialProtoVer + for { + tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false) + if err != nil { + log.Warnf("Failed to open control connection to %v using endpoint %v: %v", + cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err) + return nil, err + } + newConn := NewCqlConnection(endpoint, tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf, protoVer) + err = newConn.InitializeContext(protoVer, ctx) + var respErr *ResponseError + if err != nil && errors.As(err, &respErr) && respErr.IsProtocolError() && strings.Contains(err.Error(), "Invalid or unsupported protocol version") { + // unsupported protocol version + // protocol renegotiation requires opening a new TCP connection + err2 := newConn.Close() + if err2 != nil { + log.Errorf("Failed to close cql connection: %v", err2) + } + protoVer = downgradeProtocol(protoVer) + log.Debugf("Downgrading protocol version: %v", protoVer) + if protoVer == 0 { + // we cannot downgrade anymore + return nil, err + } + continue // retry lower protocol version + } else { + return newConn, err // we may have successfully established connection or faced other error + } + } +} + +func downgradeProtocol(version primitive.ProtocolVersion) primitive.ProtocolVersion { + switch version { + case primitive.ProtocolVersionDse2: + return primitive.ProtocolVersionDse1 + case primitive.ProtocolVersionDse1: + return primitive.ProtocolVersion4 + case primitive.ProtocolVersion4: + return primitive.ProtocolVersion3 + case primitive.ProtocolVersion3: + return primitive.ProtocolVersion2 + } + return 0 +} + func (cc *ControlConn) Close() { cc.cqlConnLock.Lock() conn := cc.cqlConn @@ -387,12 +430,12 @@ func (cc *ControlConn) Close() { } func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([]*Host, error) { - localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), ccProtocolVersion, ctx) + localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), ctx) if err != nil { return nil, fmt.Errorf("could not fetch information from system.local table: %w", err) } - localInfo, localHost, err := ParseSystemLocalResult(localQueryResult, cc.defaultPort) + localInfo, localHost, err := ParseSystemLocalResult(localQueryResult, conn.GetEndpoint(), cc.defaultPort) if err != nil { return nil, err } @@ -410,7 +453,7 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([] } } - peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), ccProtocolVersion, ctx) + peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), ctx) if err != nil { return nil, fmt.Errorf("could not fetch information from system.peers table: %w", err) } @@ -636,7 +679,7 @@ func (cc *ControlConn) setConn(oldConn CqlConnection, newConn CqlConnection, new return cc.cqlConn, cc.currentContactPoint } -func (cc *ControlConn) getConnAndContactPoint() (CqlConnection, Endpoint) { +func (cc *ControlConn) GetConnAndContactPoint() (CqlConnection, Endpoint) { cc.cqlConnLock.Lock() conn := cc.cqlConn contactPoint := cc.currentContactPoint diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index 041894fe..16d9ad4e 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -14,6 +14,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "time" ) @@ -23,26 +24,32 @@ const ( maxIncomingPending = 2048 maxOutgoingPending = 2048 + maxStreamIdsV3 = 2048 + maxStreamIdsV2 = 127 + timeOutsThreshold = 1024 ) type CqlConnection interface { + GetEndpoint() Endpoint IsInitialized() bool InitializeContext(version primitive.ProtocolVersion, ctx context.Context) error SendAndReceive(request *frame.Frame, ctx context.Context) (*frame.Frame, error) Close() error Execute(msg message.Message, ctx context.Context) (message.Message, error) - Query(cql string, genericTypeCodec *GenericTypeCodec, version primitive.ProtocolVersion, ctx context.Context) (*ParsedRowSet, error) + Query(cql string, genericTypeCodec *GenericTypeCodec, ctx context.Context) (*ParsedRowSet, error) SendHeartbeat(ctx context.Context) error SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) SubscribeToProtocolEvents(ctx context.Context, eventTypes []primitive.EventType) error IsAuthEnabled() (bool, error) + GetProtocolVersion() primitive.ProtocolVersion } // Not thread safe type cqlConn struct { readTimeout time.Duration writeTimeout time.Duration + endpoint Endpoint conn net.Conn credentials *AuthCredentials initialized bool @@ -59,25 +66,31 @@ type cqlConn struct { eventHandlerLock *sync.Mutex authEnabled bool frameProcessor FrameProcessor + protocolVersion *atomic.Value } var ( StreamIdMismatchErr = errors.New("stream id of the response is different from the stream id of the request") ) +func (c *cqlConn) GetEndpoint() Endpoint { + return c.endpoint +} + func (c *cqlConn) String() string { return fmt.Sprintf("cqlConn{conn: %v}", c.conn.RemoteAddr().String()) } func NewCqlConnection( - conn net.Conn, + endpoint Endpoint, conn net.Conn, username string, password string, readTimeout time.Duration, writeTimeout time.Duration, - conf *config.Config) CqlConnection { + conf *config.Config, protoVer primitive.ProtocolVersion) CqlConnection { ctx, cFn := context.WithCancel(context.Background()) cqlConn := &cqlConn{ readTimeout: readTimeout, writeTimeout: writeTimeout, + endpoint: endpoint, conn: conn, credentials: &AuthCredentials{ Username: username, @@ -95,7 +108,9 @@ func NewCqlConnection( closed: false, eventHandlerLock: &sync.Mutex{}, authEnabled: true, - frameProcessor: NewStreamIdProcessor(NewInternalStreamIdMapper(conf.ProxyMaxStreamIds, nil)), + // protoVer is the proposed protocol version using which we will try to establish connectivity + frameProcessor: NewStreamIdProcessor(NewInternalStreamIdMapper(protoVer, conf, nil)), + protocolVersion: &atomic.Value{}, } cqlConn.StartRequestLoop() cqlConn.StartResponseLoop() @@ -231,12 +246,17 @@ func (c *cqlConn) IsAuthEnabled() (bool, error) { return c.authEnabled, nil } +func (c *cqlConn) GetProtocolVersion() primitive.ProtocolVersion { + return c.protocolVersion.Load().(primitive.ProtocolVersion) +} + func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx context.Context) error { authEnabled, err := c.PerformHandshake(version, ctx) if err != nil { return fmt.Errorf("failed to perform handshake: %w", err) } + c.protocolVersion.Store(version) c.initialized = true c.authEnabled = authEnabled return nil @@ -353,6 +373,8 @@ func (c *cqlConn) PerformHandshake(version primitive.ProtocolVersion, ctx contex } } } + case *message.ProtocolError: + err = &ResponseError{Response: response} default: err = fmt.Errorf("expected AUTHENTICATE or READY, got %v", response.Body.Message) } @@ -367,15 +389,16 @@ func (c *cqlConn) PerformHandshake(version primitive.ProtocolVersion, ctx contex } func (c *cqlConn) Query( - cql string, genericTypeCodec *GenericTypeCodec, version primitive.ProtocolVersion, ctx context.Context) (*ParsedRowSet, error) { + cql string, genericTypeCodec *GenericTypeCodec, ctx context.Context) (*ParsedRowSet, error) { queryMsg := &message.Query{ Query: cql, Options: &message.QueryOptions{ - Consistency: primitive.ConsistencyLevelLocalQuorum, + Consistency: primitive.ConsistencyLevelOne, }, } - queryFrame := frame.NewFrame(ccProtocolVersion, -1, queryMsg) + version := c.protocolVersion.Load().(primitive.ProtocolVersion) + queryFrame := frame.NewFrame(version, -1, queryMsg) var rowSet *ParsedRowSet for { localResponse, err := c.SendAndReceive(queryFrame, ctx) @@ -429,7 +452,8 @@ func (c *cqlConn) Query( } func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Message, error) { - queryFrame := frame.NewFrame(ccProtocolVersion, -1, msg) + version := c.protocolVersion.Load().(primitive.ProtocolVersion) + queryFrame := frame.NewFrame(version, -1, msg) localResponse, err := c.SendAndReceive(queryFrame, ctx) if err != nil { return nil, err @@ -440,7 +464,8 @@ func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Mes func (c *cqlConn) SendHeartbeat(ctx context.Context) error { optionsMsg := &message.Options{} - heartBeatFrame := frame.NewFrame(ccProtocolVersion, -1, optionsMsg) + version := c.protocolVersion.Load().(primitive.ProtocolVersion) + heartBeatFrame := frame.NewFrame(version, -1, optionsMsg) response, err := c.SendAndReceive(heartBeatFrame, ctx) if err != nil { diff --git a/proxy/pkg/zdmproxy/cqlparser.go b/proxy/pkg/zdmproxy/cqlparser.go index 56446ba9..937afe73 100644 --- a/proxy/pkg/zdmproxy/cqlparser.go +++ b/proxy/pkg/zdmproxy/cqlparser.go @@ -115,15 +115,13 @@ func buildRequestInfo( } preparedDataByStmtIdxMap := make(map[int]PreparedData) for childIdx, child := range batchMsg.Children { - switch queryOrId := child.QueryOrId.(type) { - case []byte: - preparedData, err := getPreparedData(psCache, mh, queryOrId, primitive.OpCodeBatch, decodedFrame) + if child.Id != nil { + preparedData, err := getPreparedData(psCache, mh, child.Id, primitive.OpCodeBatch, decodedFrame) if err != nil { return nil, err } else { preparedDataByStmtIdxMap[childIdx] = preparedData } - default: } } return NewBatchRequestInfo(preparedDataByStmtIdxMap), nil @@ -352,11 +350,10 @@ func (recv *frameDecodeContext) inspectStatements(currentKeyspace string, timeUu currentKeyspace = typedMsg.Keyspace } for idx, childStmt := range typedMsg.Children { - switch typedQueryOrId := childStmt.QueryOrId.(type) { - case string: + if len(childStmt.Query) > 0 { statementsQueryData = append( statementsQueryData, &statementQueryData{ - statementIndex: idx, queryData: inspectCqlQuery(typedQueryOrId, currentKeyspace, timeUuidGenerator)}) + statementIndex: idx, queryData: inspectCqlQuery(childStmt.Query, currentKeyspace, timeUuidGenerator)}) } } default: diff --git a/proxy/pkg/zdmproxy/cqlparser_adv_workloads_utils_test.go b/proxy/pkg/zdmproxy/cqlparser_adv_workloads_utils_test.go index 4ae50b02..46cefec9 100644 --- a/proxy/pkg/zdmproxy/cqlparser_adv_workloads_utils_test.go +++ b/proxy/pkg/zdmproxy/cqlparser_adv_workloads_utils_test.go @@ -39,6 +39,8 @@ func getGeneralParamsForTests(t *testing.T) params { } func buildQueryMessageForTests(queryString string) *message.Query { + var defaultTimestamp int64 = 1647023221311969 + var serialConsistency = primitive.ConsistencyLevelLocalSerial return &message.Query{ Query: queryString, Options: &message.QueryOptions{ @@ -49,8 +51,8 @@ func buildQueryMessageForTests(queryString string) *message.Query { PageSize: 5000, PageSizeInBytes: false, PagingState: nil, - SerialConsistency: &primitive.NillableConsistencyLevel{Value: primitive.ConsistencyLevelLocalSerial}, - DefaultTimestamp: &primitive.NillableInt64{Value: 1647023221311969}, + SerialConsistency: &serialConsistency, + DefaultTimestamp: &defaultTimestamp, Keyspace: "", NowInSeconds: nil, ContinuousPagingOptions: &message.ContinuousPagingOptions{ diff --git a/proxy/pkg/zdmproxy/cqlparser_test.go b/proxy/pkg/zdmproxy/cqlparser_test.go index a211abba..198cf822 100644 --- a/proxy/pkg/zdmproxy/cqlparser_test.go +++ b/proxy/pkg/zdmproxy/cqlparser_test.go @@ -182,7 +182,15 @@ func mockExecuteFrame(t *testing.T, preparedId string) *frame.RawFrame { } func mockBatch(t *testing.T, query interface{}) *frame.RawFrame { - batchMsg := &message.Batch{Children: []*message.BatchChild{{QueryOrId: query}}} + var child message.BatchChild + switch query.(type) { + case []byte: + child = message.BatchChild{Id: query.([]byte)} + default: + child = message.BatchChild{Query: query.(string)} + + } + batchMsg := &message.Batch{Children: []*message.BatchChild{&child}} return mockFrame(t, batchMsg, primitive.ProtocolVersion4) } diff --git a/proxy/pkg/zdmproxy/frameprocessor.go b/proxy/pkg/zdmproxy/frameprocessor.go index ee392a61..2beb62b1 100644 --- a/proxy/pkg/zdmproxy/frameprocessor.go +++ b/proxy/pkg/zdmproxy/frameprocessor.go @@ -84,7 +84,7 @@ func setRawFrameStreamId(f *frame.RawFrame, id int16) *frame.RawFrame { if f.Header.StreamId == id { return f } - newHeader := f.Header.Clone() + newHeader := f.Header.DeepCopy() newHeader.StreamId = id return &frame.RawFrame{ Header: newHeader, @@ -98,7 +98,7 @@ func setFrameStreamId(f *frame.Frame, id int16) *frame.Frame { if f.Header.StreamId == id { return f } - newHeader := f.Header.Clone() + newHeader := f.Header.DeepCopy() newHeader.StreamId = id return &frame.Frame{ Header: newHeader, diff --git a/proxy/pkg/zdmproxy/host.go b/proxy/pkg/zdmproxy/host.go index 4033b1c0..bc2168ca 100644 --- a/proxy/pkg/zdmproxy/host.go +++ b/proxy/pkg/zdmproxy/host.go @@ -7,6 +7,8 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" "net" + "strconv" + "strings" ) type Host struct { @@ -48,7 +50,7 @@ func (recv *Host) String() string { hex.EncodeToString(recv.HostId[:])) } -func ParseSystemLocalResult(rs *ParsedRowSet, defaultPort int) (map[string]*optionalColumn, *Host, error) { +func ParseSystemLocalResult(rs *ParsedRowSet, ccEndpoint Endpoint, defaultPort int) (map[string]*optionalColumn, *Host, error) { if len(rs.Rows) < 1 { return nil, nil, fmt.Errorf("could not parse system local query result: query returned %d rows", len(rs.Rows)) } @@ -60,6 +62,10 @@ func ParseSystemLocalResult(rs *ParsedRowSet, defaultPort int) (map[string]*opti row := rs.Rows[0] addr, port, err := ParseRpcAddress(false, row, defaultPort) + if addr == nil { + // could not resolve address from system.local table (e.g. not present in C* 2.0.0) + addr, port, err = ParseEndpoint(ccEndpoint) + } if err != nil { return nil, nil, err } @@ -179,6 +185,9 @@ func ParseRpcAddress(isPeersV2 bool, row *ParsedRow, defaultPort int) (net.IP, i } else { addr = parseRpcAddressLocalOrPeersV1(row) } + if addr == nil { + return nil, -1, nil + } if addr.IsUnspecified() { peer, peerExists := row.GetByColumn("peer") @@ -215,6 +224,20 @@ func ParseRpcAddress(isPeersV2 bool, row *ParsedRow, defaultPort int) (net.IP, i return addr, rpcPort, nil } +func ParseEndpoint(endpoint Endpoint) (net.IP, int, error) { + socketEndpoint := endpoint.GetSocketEndpoint() + parts := strings.Split(socketEndpoint, ":") + if len(parts) != 2 { + return nil, -1, fmt.Errorf("invalid endpoint: %s", socketEndpoint) + } + addr := parts[0] + port, err := strconv.Atoi(parts[1]) + if err != nil { + return nil, -1, fmt.Errorf("invalid endpoint: %s", socketEndpoint) + } + return net.ParseIP(addr), port, nil +} + func parseRpcPortPeersV2(row *ParsedRow) (int, bool) { val, ok := row.GetByColumn("native_port") if ok && val != nil { diff --git a/proxy/pkg/zdmproxy/nativeprotocol.go b/proxy/pkg/zdmproxy/nativeprotocol.go index 8f3515ed..98b0dfe1 100644 --- a/proxy/pkg/zdmproxy/nativeprotocol.go +++ b/proxy/pkg/zdmproxy/nativeprotocol.go @@ -262,10 +262,10 @@ var ( storagePortColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "storage_port", Type: datatype.Int} storagePortSslColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "storage_port_ssl", Type: datatype.Int} thriftVersionColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "thrift_version", Type: datatype.Varchar} - tokensColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "tokens", Type: datatype.NewSetType(datatype.Varchar)} - truncatedAtColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "truncated_at", Type: datatype.NewMapType(datatype.Uuid, datatype.Blob)} + tokensColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "tokens", Type: datatype.NewSet(datatype.Varchar)} + truncatedAtColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "truncated_at", Type: datatype.NewMap(datatype.Uuid, datatype.Blob)} workloadColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "workload", Type: datatype.Varchar} - workloadsColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "workloads", Type: datatype.NewSetType(datatype.Varchar)} + workloadsColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemLocalTableName, Name: "workloads", Type: datatype.NewSet(datatype.Varchar)} ) var systemLocalColumns = []*message.ColumnMetadata{ @@ -367,7 +367,7 @@ func columnFromSelector( } // we are assuming here that resultColumn always refers to an unaliased column because the cql grammar doesn't support alias recursion - aliasedColumn := resultColumn.Clone() + aliasedColumn := resultColumn.DeepCopy() aliasedColumn.Name = s.alias return aliasedColumn, isCountSelector, nil default: @@ -605,9 +605,9 @@ var ( serverIdPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "server_id", Type: datatype.Varchar} storagePortPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "storage_port", Type: datatype.Int} storagePortSslPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "storage_port_ssl", Type: datatype.Int} - tokensPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "tokens", Type: datatype.NewSetType(datatype.Varchar)} + tokensPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "tokens", Type: datatype.NewSet(datatype.Varchar)} workloadPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "workload", Type: datatype.Varchar} - workloadsPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "workloads", Type: datatype.NewSetType(datatype.Varchar)} + workloadsPeersColumn = &message.ColumnMetadata{Keyspace: systemKeyspaceName, Table: systemPeersTableName, Name: "workloads", Type: datatype.NewSet(datatype.Varchar)} ) var systemPeersColumns = []*message.ColumnMetadata{ diff --git a/proxy/pkg/zdmproxy/parametermodifier_test.go b/proxy/pkg/zdmproxy/parametermodifier_test.go index dc0f99cf..0cb85ade 100644 --- a/proxy/pkg/zdmproxy/parametermodifier_test.go +++ b/proxy/pkg/zdmproxy/parametermodifier_test.go @@ -25,7 +25,7 @@ func TestAddValuesToExecuteFrame_NoReplacedTerms(t *testing.T) { PkIndices: nil, Columns: nil, } - fClone := f.Clone() + fClone := f.DeepCopy() replacementTimeUuids := parameterModifier.generateTimeUuids(prepareRequestInfo) newMsg, err := parameterModifier.AddValuesToExecuteFrame(fClone, prepareRequestInfo, variablesMetadata, replacementTimeUuids) require.Same(t, fClone.Body.Message, newMsg) @@ -198,7 +198,7 @@ func TestAddValuesToExecuteFrame_PositionalValues(t *testing.T) { require.Nil(t, err) parameterModifier := NewParameterModifier(generator) queryOpts := &message.QueryOptions{PositionalValues: requestPosVals} - clonedQueryOpts := queryOpts.Clone() // we use this so that we keep the "original" request options + clonedQueryOpts := queryOpts.DeepCopy() // we use this so that we keep the "original" request options f := frame.NewFrame(primitive.ProtocolVersion4, 1, &message.Execute{ QueryId: nil, ResultMetadataId: nil, @@ -344,7 +344,7 @@ func TestAddValuesToExecuteFrame_NamedValues(t *testing.T) { require.Nil(t, err) parameterModifier := NewParameterModifier(generator) queryOpts := &message.QueryOptions{NamedValues: requestNamedVals} - clonedQueryOpts := queryOpts.Clone() // we use this so that we keep the "original" request options + clonedQueryOpts := queryOpts.DeepCopy() // we use this so that we keep the "original" request options f := frame.NewFrame(primitive.ProtocolVersion4, 1, &message.Execute{ QueryId: nil, ResultMetadataId: nil, diff --git a/proxy/pkg/zdmproxy/querymodifier.go b/proxy/pkg/zdmproxy/querymodifier.go index 91eb9c05..d2a0dd25 100644 --- a/proxy/pkg/zdmproxy/querymodifier.go +++ b/proxy/pkg/zdmproxy/querymodifier.go @@ -89,7 +89,7 @@ func (recv *QueryModifier) replaceQueryInBatchMessage( return decodedFrame, []*statementReplacedTerms{}, statementsQueryData, nil } - newFrame := decodedFrame.Clone() + newFrame := decodedFrame.DeepCopy() newBatchMsg, ok := newFrame.Body.Message.(*message.Batch) if !ok { return nil, nil, nil, fmt.Errorf("expected Batch in cloned frame but got %v instead", newFrame.Body.Message.GetOpCode()) @@ -100,7 +100,7 @@ func (recv *QueryModifier) replaceQueryInBatchMessage( return nil, nil, nil, fmt.Errorf("new query data statement index (%v) is greater or equal than "+ "number of batch child statements (%v)", newStmtQueryData.statementIndex, len(newBatchMsg.Children)) } - newBatchMsg.Children[newStmtQueryData.statementIndex].QueryOrId = newStmtQueryData.queryData.getQuery() + newBatchMsg.Children[newStmtQueryData.statementIndex].Query = newStmtQueryData.queryData.getQuery() } return newFrame, statementsReplacedTerms, newStatementsQueryData, nil @@ -117,7 +117,7 @@ func (recv *QueryModifier) replaceQueryInQueryMessage( return decodedFrame, []*statementReplacedTerms{}, statementsQueryData, nil } newQueryData, replacedTerms := stmtQueryData.queryData.replaceNowFunctionCallsWithLiteral() - newFrame := decodedFrame.Clone() + newFrame := decodedFrame.DeepCopy() newQueryMsg, ok := newFrame.Body.Message.(*message.Query) if !ok { return nil, nil, nil, fmt.Errorf("expected Query in cloned frame but got %v instead", newFrame.Body.Message.GetOpCode()) @@ -143,7 +143,7 @@ func (recv *QueryModifier) replaceQueryInPrepareMessage( } else { newQueryData, replacedTerms = stmtQueryData.queryData.replaceNowFunctionCallsWithPositionalBindMarkers() } - newFrame := decodedFrame.Clone() + newFrame := decodedFrame.DeepCopy() newPrepareMsg, ok := newFrame.Body.Message.(*message.Prepare) if !ok { return nil, nil, nil, fmt.Errorf("expected Prepare in cloned frame but got %v instead", newFrame.Body.Message.GetOpCode()) diff --git a/proxy/pkg/zdmproxy/querymodifier_test.go b/proxy/pkg/zdmproxy/querymodifier_test.go index d7634f84..da315367 100644 --- a/proxy/pkg/zdmproxy/querymodifier_test.go +++ b/proxy/pkg/zdmproxy/querymodifier_test.go @@ -123,7 +123,7 @@ func TestReplaceQueryString(t *testing.T) { {"OpCodeBatch Mixed Prepared and Simple", mockBatchWithChildren(t, []*message.BatchChild{ { - QueryOrId: "UPDATE blah SET a = ?, b = 123 " + + Query: "UPDATE blah SET a = ?, b = 123 " + "WHERE f[now()] = ? IF " + "g[123] IN (2, 3, ?, now(), ?, now()) AND " + "d IN ? AND " + @@ -132,12 +132,12 @@ func TestReplaceQueryString(t *testing.T) { Values: []*primitive.Value{}, // not used by the SUT (system under test) }, { - QueryOrId: []byte{0}, - Values: []*primitive.Value{}, // not used by the SUT + Id: []byte{0}, + Values: []*primitive.Value{}, // not used by the SUT }, { - QueryOrId: "DELETE FROM blah WHERE b = 123 AND a = now()", - Values: []*primitive.Value{}, // not used by the SUT + Query: "DELETE FROM blah WHERE b = 123 AND a = now()", + Values: []*primitive.Value{}, // not used by the SUT }}), []*statementReplacedTerms{ {statementIndex: 0, replacedTerms: []*term{ diff --git a/proxy/pkg/zdmproxy/response.go b/proxy/pkg/zdmproxy/response.go index 9c531a4c..c328c3e2 100644 --- a/proxy/pkg/zdmproxy/response.go +++ b/proxy/pkg/zdmproxy/response.go @@ -1,6 +1,10 @@ package zdmproxy -import "github.com/datastax/go-cassandra-native-protocol/frame" +import ( + "fmt" + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" +) type Response struct { responseFrame *frame.RawFrame @@ -37,3 +41,20 @@ func (r *Response) GetStreamId() int16 { return r.requestFrame.Header.StreamId } } + +type ResponseError struct { + Response *frame.Frame +} + +func (pre *ResponseError) Error() string { + return fmt.Sprintf("%v", pre.Response.Body.Message) +} + +func (pre *ResponseError) IsProtocolError() bool { + switch pre.Response.Body.Message.(type) { + case *message.ProtocolError: + return true + default: + return false + } +} diff --git a/proxy/pkg/zdmproxy/streamidmapper.go b/proxy/pkg/zdmproxy/streamidmapper.go index 02a78d21..652e9ff6 100644 --- a/proxy/pkg/zdmproxy/streamidmapper.go +++ b/proxy/pkg/zdmproxy/streamidmapper.go @@ -2,7 +2,10 @@ package zdmproxy import ( "fmt" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/datastax/zdm-proxy/proxy/pkg/metrics" + "math" "sync" ) @@ -17,30 +20,35 @@ type StreamIdMapper interface { type streamIdMapper struct { sync.Mutex - idMapper map[int16]int16 - clusterIds chan int16 - metrics metrics.Gauge + idMapper map[int16]int16 + clusterIds chan int16 + metrics metrics.Gauge + protocolVersion primitive.ProtocolVersion } type internalStreamIdMapper struct { - clusterIds chan int16 - metrics metrics.Gauge + clusterIds chan int16 + metrics metrics.Gauge + protocolVersion primitive.ProtocolVersion } // NewInternalStreamIdMapper is used to assign unique ids to frames that have no initial stream id defined, such as // CQL queries initiated by the proxy or ASYNC requests. -func NewInternalStreamIdMapper(maxStreamIds int, metrics metrics.Gauge) StreamIdMapper { - streamIdsQueue := make(chan int16, maxStreamIds) - for i := int16(0); i < int16(maxStreamIds); i++ { +func NewInternalStreamIdMapper(protocolVersion primitive.ProtocolVersion, config *config.Config, metrics metrics.Gauge) StreamIdMapper { + maximumStreamIds := maxStreamIds(protocolVersion, config) + streamIdsQueue := make(chan int16, maximumStreamIds) + for i := int16(0); i < int16(maximumStreamIds); i++ { streamIdsQueue <- i } return &internalStreamIdMapper{ - clusterIds: streamIdsQueue, - metrics: metrics, + protocolVersion: protocolVersion, + clusterIds: streamIdsQueue, + metrics: metrics, } } func (csid *internalStreamIdMapper) GetNewIdFor(_ int16) (int16, error) { + // do not validate provided stream ID select { case id := <-csid.clusterIds: if csid.metrics != nil { @@ -73,20 +81,25 @@ func (csid *internalStreamIdMapper) Close() { } } -func NewStreamIdMapper(maxStreamIds int, metrics metrics.Gauge) StreamIdMapper { +func NewStreamIdMapper(protocolVersion primitive.ProtocolVersion, config *config.Config, metrics metrics.Gauge) StreamIdMapper { + maximumStreamIds := maxStreamIds(protocolVersion, config) idMapper := make(map[int16]int16) - streamIdsQueue := make(chan int16, maxStreamIds) - for i := int16(0); i < int16(maxStreamIds); i++ { + streamIdsQueue := make(chan int16, maximumStreamIds) + for i := int16(0); i < int16(maximumStreamIds); i++ { streamIdsQueue <- i } return &streamIdMapper{ - idMapper: idMapper, - clusterIds: streamIdsQueue, - metrics: metrics, + protocolVersion: protocolVersion, + idMapper: idMapper, + clusterIds: streamIdsQueue, + metrics: metrics, } } func (sim *streamIdMapper) GetNewIdFor(streamId int16) (int16, error) { + if err := validateStreamId(sim.protocolVersion, streamId); err != nil { + return -1, err + } select { case id := <-sim.clusterIds: if sim.metrics != nil { @@ -134,3 +147,26 @@ func (sim *streamIdMapper) Close() { sim.metrics.Subtract(cap(sim.clusterIds) - len(sim.clusterIds)) } } + +func maxStreamIds(protoVer primitive.ProtocolVersion, conf *config.Config) int { + maxSupported := maxStreamIdsV3 + if protoVer == primitive.ProtocolVersion2 { + maxSupported = maxStreamIdsV2 + } + if maxSupported < conf.ProxyMaxStreamIds { + return maxSupported + } + return conf.ProxyMaxStreamIds +} + +func validateStreamId(version primitive.ProtocolVersion, streamId int16) error { + if version < primitive.ProtocolVersion3 { + if streamId > math.MaxInt8 || streamId < math.MinInt8 { + return fmt.Errorf("stream id out of range for %v: %v", version, streamId) + } + } + if streamId < 0 { + return fmt.Errorf("negative stream id: %v", streamId) + } + return nil +} diff --git a/proxy/pkg/zdmproxy/streamidmapper_test.go b/proxy/pkg/zdmproxy/streamidmapper_test.go index 92d04792..d72e5f59 100644 --- a/proxy/pkg/zdmproxy/streamidmapper_test.go +++ b/proxy/pkg/zdmproxy/streamidmapper_test.go @@ -1,20 +1,22 @@ package zdmproxy import ( + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/proxy/pkg/config" "github.com/stretchr/testify/require" "sync" "testing" ) func TestStreamIdMapper(t *testing.T) { - var mapper = NewStreamIdMapper(2048, nil) + var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, &config.Config{ProxyMaxStreamIds: 2048}, nil) var syntheticId, _ = mapper.GetNewIdFor(1000) var originalId, _ = mapper.ReleaseId(syntheticId) require.Equal(t, int16(1000), originalId) } func BenchmarkStreamIdMapper(b *testing.B) { - var mapper = NewStreamIdMapper(2048, nil) + var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, &config.Config{ProxyMaxStreamIds: 2048}, nil) for i := 0; i < b.N; i++ { var originalId = int16(i) var syntheticId, _ = mapper.GetNewIdFor(originalId) @@ -28,7 +30,7 @@ func TestConcurrentStreamIdMapper(t *testing.T) { var wg = sync.WaitGroup{} wg.Add(concurrency) for i := 0; i < concurrency; i++ { - var mapper = NewStreamIdMapper(2048, nil) + var mapper = NewStreamIdMapper(primitive.ProtocolVersion3, &config.Config{ProxyMaxStreamIds: 2048}, nil) getAndReleaseIds(t, mapper, int16(i), requestCount, &wg) } wg.Wait()