Skip to content

Commit

Permalink
Add a flag to signal that failed query could have been executed and i…
Browse files Browse the repository at this point in the history
…t might be not safe to retry it
  • Loading branch information
sylwiaszunejko committed Dec 18, 2024
1 parent e12494d commit 11bc473
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 21 deletions.
46 changes: 31 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1101,13 +1101,13 @@ func (c *Conn) addCall(call *callReq) error {

func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error) {
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
return nil, &QueryError{err: ctxErr, potentiallyExecuted: false}
}

// TODO: move tracer onto conn
stream, ok := c.streams.GetStream()
if !ok {
return nil, ErrNoStreams
return nil, &QueryError{err: ErrNoStreams, potentiallyExecuted: false}
}

// resp is basically a waiting semaphore protecting the framer
Expand All @@ -1125,7 +1125,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
}

if err := c.addCall(call); err != nil {
return nil, err
return nil, &QueryError{err: err, potentiallyExecuted: false}
}

// After this point, we need to either read from call.resp or close(call.timeout)
Expand Down Expand Up @@ -1157,7 +1157,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
// We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil
// check above could fail.
c.releaseStream(call)
return nil, err
return nil, &QueryError{err: err, potentiallyExecuted: false}
}

n, err := c.w.writeContext(ctx, framer.buf)
Expand Down Expand Up @@ -1185,7 +1185,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
// send a frame on, with all the streams used up and not returned.
c.closeWithError(err)
}
return nil, err
return nil, &QueryError{err: err, potentiallyExecuted: true}
}

var timeoutCh <-chan time.Time
Expand Down Expand Up @@ -1222,7 +1222,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
// connection to close.
c.releaseStream(call)
}
return nil, resp.err
return nil, &QueryError{err: resp.err, potentiallyExecuted: true}
}
// don't release the stream if we detect a timeout, as another request can reuse
// that stream and get a response for the old request, which we have no
Expand All @@ -1233,20 +1233,20 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
defer c.releaseStream(call)

if v := resp.framer.header.version.version(); v != c.version {
return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
return nil, &QueryError{err: NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version), potentiallyExecuted: true}
}

return resp.framer, nil
case <-timeoutCh:
close(call.timeout)
c.handleTimeout()
return nil, ErrTimeoutNoResponse
return nil, &QueryError{err: ErrTimeoutNoResponse, potentiallyExecuted: true}
case <-ctxDone:
close(call.timeout)
return nil, ctx.Err()
return nil, &QueryError{err: ctx.Err(), potentiallyExecuted: true}
case <-c.ctx.Done():
close(call.timeout)
return nil, ErrConnectionClosed
return nil, &QueryError{err: ErrConnectionClosed, potentiallyExecuted: true}
}
}

Expand Down Expand Up @@ -1906,11 +1906,14 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error {
}

var (
ErrQueryArgLength = errors.New("gocql: query argument length mismatch")
ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
ErrConnectionClosed = errors.New("gocql: connection closed waiting for response")
ErrNoStreams = errors.New("gocql: no streams available on connection")
ErrQueryArgLength = errors.New("gocql: query argument length mismatch")
ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
ErrConnectionClosed = errors.New("gocql: connection closed waiting for response")
ErrNoStreams = errors.New("gocql: no streams available on connection")
ErrHostDown = errors.New("gocql: host is nil or down")
ErrNoPool = errors.New("gocql: host does not have a pool")
ErrNoConnectionsInPool = errors.New("gocql: host pool does not have connections")
)

type ErrSchemaMismatch struct {
Expand All @@ -1920,3 +1923,16 @@ type ErrSchemaMismatch struct {
// Error implements the error interface. The message embeds e.schemas
// formatted with %+v.
func (e *ErrSchemaMismatch) Error() string {
	const format = "gocql: cluster schema versions not consistent: %+v"
	return fmt.Sprintf(format, e.schemas)
}

// QueryError wraps an error returned while executing a request and records
// whether the request could already have been executed when the error
// occurred. Use errors.Is / errors.As to inspect the wrapped error.
type QueryError struct {
	// err is the underlying cause; it is exposed through Unwrap.
	err error
	// potentiallyExecuted is true when the failed request may have been
	// executed, so retrying it might not be safe.
	potentiallyExecuted bool
}

// Error implements the error interface. It returns the wrapped error's
// message followed by the potentially-executed flag. A QueryError without a
// wrapped error is rendered with "<nil>" instead of panicking.
func (e *QueryError) Error() string {
	msg := "<nil>"
	if e.err != nil {
		msg = e.err.Error()
	}
	return fmt.Sprintf("%s (potentially executed: %v)", msg, e.potentiallyExecuted)
}

// Unwrap returns the wrapped error so that errors.Is and errors.As can see
// through a QueryError.
func (e *QueryError) Unwrap() error {
	return e.err
}

// PotentiallyExecuted reports whether the failed request may have been
// executed before the error was returned.
func (e *QueryError) PotentiallyExecuted() bool {
	return e.potentiallyExecuted
}
12 changes: 6 additions & 6 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func TestCancel(t *testing.T) {
wg.Add(1)

go func() {
if err := qry.Exec(); err != context.Canceled {
if err := qry.Exec(); !errors.Is(err, context.Canceled) {
t.Fatalf("expected to get context cancel error: '%v', got '%v'", context.Canceled, err)
}
wg.Done()
Expand Down Expand Up @@ -573,7 +573,7 @@ func TestQueryTimeout(t *testing.T) {

select {
case err := <-ch:
if err != ErrTimeoutNoResponse {
if !errors.Is(err, ErrTimeoutNoResponse) {
t.Fatalf("expected to get %v for timeout got %v", ErrTimeoutNoResponse, err)
}
case <-time.After(40*time.Millisecond + db.cfg.Timeout):
Expand Down Expand Up @@ -667,8 +667,8 @@ func TestQueryTimeoutClose(t *testing.T) {
t.Fatal("timedout waiting to get a response once cluster is closed")
}

if err != ErrConnectionClosed {
t.Fatalf("expected to get %v got %v", ErrConnectionClosed, err)
if !errors.Is(err, ErrConnectionClosed) {
t.Fatalf("expected to get %v or an error wrapping it, got %v", ErrConnectionClosed, err)
}
}

Expand Down Expand Up @@ -721,7 +721,7 @@ func TestContext_Timeout(t *testing.T) {
cancel()

err = db.Query("timeout").WithContext(ctx).Exec()
if err != context.Canceled {
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)
}
}
Expand Down Expand Up @@ -838,7 +838,7 @@ func TestContext_CanceledBeforeExec(t *testing.T) {
cancel()

err = db.Query("timeout").WithContext(ctx).Exec()
if err != context.Canceled {
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)
}

Expand Down
4 changes: 4 additions & 0 deletions query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne

lastErr = iter.err

if customErr, ok := iter.err.(*QueryError); ok && customErr.potentiallyExecuted && !qry.IsIdempotent() {
return iter
}

var retry_type RetryType
if use_lwt_rt {
retry_type = lwt_rt.GetRetryTypeLWT(iter.err)
Expand Down

0 comments on commit 11bc473

Please sign in to comment.