From 7e0adabfda5d45c7ae92392c3c1789db76d2035b Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Tue, 19 Nov 2024 18:43:25 +0100 Subject: [PATCH] Extract getSchemaAgreement to the separate function and unit test it against zero-token node --- conn.go | 83 +++++++++++++++++++++++---------------------- conn_test.go | 95 +++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 137 insertions(+), 41 deletions(-) diff --git a/conn.go b/conn.go index df7682d2d..e0a9c7ada 100644 --- a/conn.go +++ b/conn.go @@ -1815,6 +1815,42 @@ func (c *Conn) querySystemLocal(ctx context.Context) *Iter { return c.query(ctx, "SELECT * FROM system.local WHERE key='local'"+usingClause) } +func getSchemaAgreement(queryLocalSchemasRows []string, querySystemPeersRows []map[string]interface{}, connectAddress net.IP, port int, translateAddressPort func(addr net.IP, port int) (net.IP, int), logger StdLogger) (err error) { + versions := make(map[string]struct{}) + + for _, row := range querySystemPeersRows { + var host *HostInfo + host, err = hostInfoFromMap(row, &HostInfo{connectAddress: connectAddress, port: port}, translateAddressPort) + if err != nil { + return err + } + if !isValidPeer(host) || host.schemaVersion == "" { + logger.Printf("invalid peer or peer with empty schema_version: peer=%q", host) + continue + } else if isZeroToken(host) { + continue + } + + versions[host.schemaVersion] = struct{}{} + } + + for _, schemaVersion := range queryLocalSchemasRows { + versions[schemaVersion] = struct{}{} + schemaVersion = "" + } + + if len(versions) > 1 { + schemas := make([]string, 0, len(versions)) + for schema := range versions { + schemas = append(schemas, schema) + } + + return &ErrSchemaMismatch{schemas: schemas} + } + + return nil +} + func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { usingClause := "" if c.session.control != nil { @@ -1822,7 +1858,6 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { } var localSchemas = "SELECT schema_version FROM system.local WHERE key='local'" + usingClause - var versions map[string]struct{} var schemaVersion string endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement) @@ -1840,61 +1875,29 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { } } - getSchemaAgreement := func() error { + for time.Now().Before(endDeadline) { iter := c.querySystemPeers(ctx, c.host.version) - - versions = make(map[string]struct{}) - - var rows []map[string]interface{} - rows, err = iter.SliceMap() + var systemPeersRows []map[string]interface{} + systemPeersRows, err = iter.SliceMap() if err != nil { return err } - - for _, row := range rows { - var host *HostInfo - host, err = hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}, c.session.cfg.translateAddressPort) - if err != nil { - return err - } - if !isValidPeer(host) || host.schemaVersion == "" { - c.logger.Printf("invalid peer or peer with empty schema_version: peer=%q", host) - continue - } else if isZeroToken(host) { - continue - } - - versions[host.schemaVersion] = struct{}{} - } - if err = iter.Close(); err != nil { return err } + schemaVersions := []string{} + iter = c.query(ctx, localSchemas) for iter.Scan(&schemaVersion) { - versions[schemaVersion] = struct{}{} + schemaVersions = append(schemaVersions, schemaVersion) schemaVersion = "" } if err = iter.Close(); err != nil { return err } - - if len(versions) > 1 { - schemas := make([]string, 0, len(versions)) - for schema := range versions { - schemas = append(schemas, schema) - } - - return &ErrSchemaMismatch{schemas: schemas} - } - - return nil - } - - for time.Now().Before(endDeadline) { - err = getSchemaAgreement() + err = getSchemaAgreement(schemaVersions, systemPeersRows, c.host.ConnectAddress(), c.session.cfg.Port, c.session.cfg.translateAddressPort, c.logger) if err == ErrConnectionClosed || err == nil { return err diff --git a/conn_test.go b/conn_test.go index 479ebe6f8..36e2536e2 100644 --- a/conn_test.go +++ b/conn_test.go @@ -14,7 +14,6 @@ import ( "crypto/x509" "errors" "fmt" - "github.com/google/go-cmp/cmp" "io" "io/ioutil" "math/rand" @@ -26,6 +25,9 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/gocql/gocql/internal/streams" ) @@ -1394,3 +1396,94 @@ func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) { return framer, nil } + +func TestGetSchemaAgreement(t *testing.T) { + host_id1, _ := ParseUUID("b2035fd9-e0ca-4857-8c45-e63c00fb7c43") + host_id2, _ := ParseUUID("4b21ee4c-acea-4267-8e20-aaed5361a0dd") + host_id3, _ := ParseUUID("dfef4a22-b8d8-47e9-aee5-8c19d4b7a9e3") + + schema_version1, _ := ParseUUID("af810386-a694-11ef-81fa-3aea73156247") + schema_version2, _ := ParseUUID("875a938a-a695-11ef-4314-85c8ef0ebaa2") + + peersRows := []map[string]interface{}{ + { + "data_center": "datacenter1", + "host_id": host_id1, + "peer": "127.0.0.3", + "preferred_ip": "127.0.0.3", + "rack": "rack1", + "release_version": "3.0.8", + "rpc_address": "127.0.0.3", + "schema_version": schema_version1, + "tokens": []string{"-1296227678594315580994457470329811265"}, + }, + { + "data_center": "datacenter1", + "host_id": host_id2, + "peer": "127.0.0.2", + "preferred_ip": "127.0.0.2", + "rack": "rack1", + "release_version": "3.0.8", + "rpc_address": "127.0.0.2", + "schema_version": schema_version1, + "tokens": []string{"-1129762924682054333"}, + }, + { + "data_center": "datacenter2", + "host_id": host_id3, + "peer": "127.0.0.5", + "preferred_ip": "127.0.0.5", + "rack": "rack1", + "release_version": "3.0.8", + "rpc_address": "127.0.0.5", + "schema_version": schema_version2, + "tokens": []string{}, + }, + } + + translateAddressPort := func(addr net.IP, port int) (net.IP, int) { + return addr, port + } + + var logger StdLogger + + t.Run("SchemaNotConsistent", func(t *testing.T) { + err := getSchemaAgreement( + []string{"875a938a-a695-11ef-4314-85c8ef0ebaa2"}, + peersRows, + net.ParseIP("127.0.0.1"), + 9042, + translateAddressPort, + logger, + ) + + assert.Error(t, err, "error expected when local schema is different then others") + }) + + t.Run("ZeroTokenNodeSchemaNotConsistent", func(t *testing.T) { + err := getSchemaAgreement( + []string{"af810386-a694-11ef-81fa-3aea73156247"}, + peersRows, + net.ParseIP("127.0.0.1"), + 9042, + translateAddressPort, + logger, + ) + + assert.NoError(t, err, "expected no error when zero-token node has different schema because it is ommitted") + }) + + t.Run("SchemaConsistent", func(t *testing.T) { + peersRows[2]["schema_version"] = schema_version1 + err := getSchemaAgreement( + []string{"af810386-a694-11ef-81fa-3aea73156247"}, + peersRows, + net.ParseIP("127.0.0.1"), + 9042, + translateAddressPort, + logger, + ) + + assert.NoError(t, err, "expected no error when all nodes have the same schema") + }) +}