Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ring describer refactor #368

Merged
merged 8 commits into from
Dec 11, 2024
90 changes: 43 additions & 47 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) {
// Deprecated.
var TimeoutLimit int64 = 0

type ConnInterface interface {
Close()
exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error)
awaitSchemaAgreement(ctx context.Context) error
executeQuery(ctx context.Context, qry *Query) *Iter
querySystem(ctx context.Context, query string) *Iter
getIsSchemaV2() bool
setSchemaV2(s bool)
query(ctx context.Context, statement string, values ...interface{}) (iter *Iter)
getScyllaSupported() scyllaSupported
}

// Conn is a single connection to a Cassandra node. It can be used to execute
// queries, but users are usually advised to use a more reliable, higher
// level API.
Expand Down Expand Up @@ -212,6 +224,18 @@ type Conn struct {
tabletsRoutingV1 int32
}

func (c *Conn) getIsSchemaV2() bool {
return c.isSchemaV2
}

func (c *Conn) setSchemaV2(s bool) {
c.isSchemaV2 = s
}

func (c *Conn) getScyllaSupported() scyllaSupported {
return c.scyllaSupported
}

// connect establishes a connection to a Cassandra node using session's connection config.
func (s *Session) connect(ctx context.Context, host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
return s.dial(ctx, host, s.connCfg, errorHandler)
Expand Down Expand Up @@ -350,6 +374,10 @@ func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error {
c.w = newWriteCoalescer(c.conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
}

if c.isScyllaConn() { // ScyllaDB does not support system.peers_v2
c.setSchemaV2(false)
}

go c.serve(ctx)
go c.heartBeat(ctx)

Expand Down Expand Up @@ -1768,52 +1796,19 @@ func (c *Conn) query(ctx context.Context, statement string, values ...interface{
return c.executeQuery(ctx, q)
}

func (c *Conn) querySystemPeers(ctx context.Context, version cassVersion) *Iter {
func (c *Conn) querySystem(ctx context.Context, query string) *Iter {
usingClause := ""
if c.session.control != nil {
usingClause = c.session.usingTimeoutClause
}
var (
peerSchema = "SELECT * FROM system.peers" + usingClause
peerV2Schemas = "SELECT * FROM system.peers_v2" + usingClause
)

c.mu.Lock()
if isScyllaConn((c)) { // ScyllaDB does not support system.peers_v2
c.isSchemaV2 = false
}

isSchemaV2 := c.isSchemaV2
c.mu.Unlock()

if version.AtLeast(4, 0, 0) && isSchemaV2 {
// Try "system.peers_v2" and fallback to "system.peers" if it's not found
iter := c.query(ctx, peerV2Schemas)

err := iter.checkErrAndNotFound()
if err != nil {
if errFrame, ok := err.(errorFrame); ok && errFrame.code == ErrCodeInvalid { // system.peers_v2 not found, try system.peers
c.mu.Lock()
c.isSchemaV2 = false
c.mu.Unlock()
return c.query(ctx, peerSchema)
} else {
return iter
}
}
return iter
} else {
return c.query(ctx, peerSchema)
}
queryStmt := query + usingClause
return c.query(ctx, queryStmt)
}

func (c *Conn) querySystemLocal(ctx context.Context) *Iter {
usingClause := ""
if c.session.control != nil {
usingClause = c.session.usingTimeoutClause
}
return c.query(ctx, "SELECT * FROM system.local WHERE key='local'"+usingClause)
}
const qrySystemPeers = "SELECT * FROM system.peers"
const qrySystemPeersV2 = "SELECT * FROM system.peers_2"

const qrySystemLocal = "SELECT * FROM system.local WHERE key='local'"

func getSchemaAgreement(queryLocalSchemasRows []string, querySystemPeersRows []map[string]interface{}, connectAddress net.IP, port int, translateAddressPort func(addr net.IP, port int) (net.IP, int), logger StdLogger) (err error) {
versions := make(map[string]struct{})
Expand Down Expand Up @@ -1850,11 +1845,7 @@ func getSchemaAgreement(queryLocalSchemasRows []string, querySystemPeersRows []m
}

func (c *Conn) awaitSchemaAgreement(ctx context.Context) error {
usingClause := ""
if c.session.control != nil {
usingClause = c.session.usingTimeoutClause
}
var localSchemas = "SELECT schema_version FROM system.local WHERE key='local'" + usingClause
var localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"

var schemaVersion string

Expand All @@ -1874,7 +1865,12 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error {
}

for time.Now().Before(endDeadline) {
iter := c.querySystemPeers(ctx, c.host.version)
var iter *Iter
if c.getIsSchemaV2() {
iter = c.querySystem(ctx, qrySystemPeersV2)
} else {
iter = c.querySystem(ctx, qrySystemPeers)
}
var systemPeersRows []map[string]interface{}
systemPeersRows, err = iter.SliceMap()
if err != nil {
Expand All @@ -1886,7 +1882,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error {

schemaVersions := []string{}

iter = c.query(ctx, localSchemas)
iter = c.querySystem(ctx, localSchemas)
for iter.Scan(&schemaVersion) {
schemaVersions = append(schemaVersions, schemaVersion)
schemaVersion = ""
Expand Down
2 changes: 1 addition & 1 deletion connectionpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func (pool *hostConnPool) initConnPicker(conn *Conn) {
return
}

if isScyllaConn(conn) {
if conn.isScyllaConn() {
pool.connPicker = newScyllaConnPicker(conn)
return
}
Expand Down
64 changes: 23 additions & 41 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ const (
controlConnClosing = -1
)

type controlConnection interface {
getConn() *connHost
awaitSchemaAgreement() error
query(statement string, values ...interface{}) (iter *Iter)
discoverProtocol(hosts []*HostInfo) (int, error)
connect(hosts []*HostInfo) error
close()
getSession() *Session
}

// Ensure that the atomic variable is aligned to a 64bit boundary
// so that atomic operations can be applied on 32bit architectures.
type controlConn struct {
Expand All @@ -49,6 +59,10 @@ type controlConn struct {
quit chan struct{}
}

func (c *controlConn) getSession() *Session {
return c.session
}

func createControlConn(session *Session) *controlConn {

control := &controlConn{
Expand Down Expand Up @@ -264,18 +278,18 @@ func (c *controlConn) connect(hosts []*HostInfo) error {
}

type connHost struct {
conn *Conn
conn ConnInterface
host *HostInfo
}

func (c *controlConn) setupConn(conn *Conn) error {
// we need up-to-date host info for the filterHost call below
iter := conn.querySystemLocal(context.TODO())
iter := conn.querySystem(context.TODO(), qrySystemLocal)
defaultPort := 9042
if tcpAddr, ok := conn.conn.RemoteAddr().(*net.TCPAddr); ok {
defaultPort = tcpAddr.Port
}
host, err := c.session.hostInfoFromIter(iter, conn.host.connectAddress, defaultPort)
host, err := hostInfoFromIter(iter, conn.host.connectAddress, defaultPort, c.session.cfg.translateAddressPort)
if err != nil {
return err
}
Expand Down Expand Up @@ -359,7 +373,7 @@ func (c *controlConn) reconnect() {
return
}

err = c.session.refreshRing()
err = c.session.refreshRingNow()
if err != nil {
c.session.logger.Printf("gocql: unable to refresh ring: %v\n", err)
}
Expand Down Expand Up @@ -462,45 +476,14 @@ func (c *controlConn) writeFrame(w frameBuilder) (frame, error) {
return framer.parseFrame()
}

func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter {
const maxConnectAttempts = 5
connectAttempts := 0

for i := 0; i < maxConnectAttempts; i++ {
ch := c.getConn()
if ch == nil {
if connectAttempts > maxConnectAttempts {
break
}

connectAttempts++

c.reconnect()
continue
}

return fn(ch)
}

return &Iter{err: errNoControl}
}

func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
return c.withConnHost(func(ch *connHost) *Iter {
return fn(ch.conn)
})
}

// query will return nil if the connection is closed or nil
func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil)

for {
iter = c.withConn(func(conn *Conn) *Iter {
// we want to keep the query on the control connection
q.conn = conn
return conn.executeQuery(context.TODO(), q)
})
ch := c.getConn()
q.conn = ch.conn.(*Conn)
iter = ch.conn.executeQuery(context.TODO(), q)

if gocqlDebug && iter.err != nil {
c.session.logger.Printf("control: error executing %q: %v\n", statement, iter.err)
Expand All @@ -516,9 +499,8 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
}

func (c *controlConn) awaitSchemaAgreement() error {
return c.withConn(func(conn *Conn) *Iter {
return &Iter{err: conn.awaitSchemaAgreement(context.TODO())}
}).err
ch := c.getConn()
return (&Iter{err: ch.conn.awaitSchemaAgreement(context.TODO())}).err
}

func (c *controlConn) close() {
Expand Down
18 changes: 16 additions & 2 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,12 @@ type FrameHeaderObserver interface {
ObserveFrameHeader(context.Context, ObservedFrameHeader)
}

type framerInterface interface {
ReadBytesInternal() ([]byte, error)
GetCustomPayload() map[string][]byte
GetHeaderWarnings() []string
}

// a framer is responsible for reading, writing and parsing frames on a single stream
type framer struct {
proto byte
Expand Down Expand Up @@ -1866,7 +1872,7 @@ func (f *framer) readStringList() []string {
return l
}

func (f *framer) readBytesInternal() ([]byte, error) {
func (f *framer) ReadBytesInternal() ([]byte, error) {
size := f.readInt()
if size < 0 {
return nil, nil
Expand All @@ -1883,7 +1889,7 @@ func (f *framer) readBytesInternal() ([]byte, error) {
}

func (f *framer) readBytes() []byte {
l, err := f.readBytesInternal()
l, err := f.ReadBytesInternal()
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -2015,6 +2021,14 @@ func (f *framer) writeCustomPayload(customPayload *map[string][]byte) {
}
}

func (f *framer) GetCustomPayload() map[string][]byte {
return f.customPayload
}

func (f *framer) GetHeaderWarnings() []string {
return f.header.warnings
}

// these are protocol level binary types
func (f *framer) writeInt(n int32) {
f.buf = appendInt(f.buf, n)
Expand Down
Loading
Loading