diff --git a/parser/lexer_test.go b/parser/lexer_test.go index 85810bf..4b5e1ef 100644 --- a/parser/lexer_test.go +++ b/parser/lexer_test.go @@ -86,11 +86,15 @@ func TestLexerIdentifiers(t *testing.T) { {`"system"`, tkIdentifier, "system"}, {`"system"`, tkIdentifier, "system"}, {`"System"`, tkIdentifier, "System"}, - {`""""`, tkIdentifier, "\""}, - {`""""""`, tkIdentifier, "\"\""}, - {`"A"""""`, tkIdentifier, "A\"\""}, - {`"""A"""`, tkIdentifier, "\"A\""}, - {`"""""A"`, tkIdentifier, "\"\"A"}, + // below test verify correct escaping double quote character as per CQL definition: + // identifier ::= unquoted_identifier | quoted_identifier + // unquoted_identifier ::= re('[a-zA-Z][link:[a-zA-Z0-9]]*') + // quoted_identifier ::= '"' (any character where " can appear if doubled)+ '"' + {`""""`, tkIdentifier, "\""}, // outermost quotes indicate quoted string, inner two double quotes shall be treated as single quote + {`""""""`, tkIdentifier, "\"\""}, // same as above, but 4 inner quotes result in 2 quotes + {`"A"""""`, tkIdentifier, "A\"\""}, // outermost quotes indicate quoted string, 4 quotes after A result in 2 quotes + {`"""A"""`, tkIdentifier, "\"A\""}, // outermost quotes indicate quoted string, 2 quotes before and after A result in single quotes + {`"""""A"`, tkIdentifier, "\"\"A"}, // analogical to previous tests {`";`, tkInvalid, ""}, {`"""`, tkIdentifier, ""}, } diff --git a/proxy/proxy.go b/proxy/proxy.go index 57443cb..fc585f4 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -667,11 +667,9 @@ func (c *client) handleExecute(raw *frame.RawFrame, msg *partialExecute, customP } func (c *client) handleQuery(raw *frame.RawFrame, msg *partialQuery, customPayload map[string][]byte) { - c.proxy.logger.Debug("handling query", zap.String("query", msg.query), zap.Int16("stream", raw.Header.StreamId)) - handled, stmt, err := parser.IsQueryHandled(parser.IdentifierFromString(c.keyspace), msg.query) - if handled { + c.proxy.logger.Debug("Query handled by proxy", zap.String("query", msg.query), zap.Int16("stream", raw.Header.StreamId)) if err != nil { c.proxy.logger.Error("error parsing query to see if it's handled", zap.Error(err)) c.send(raw.Header, &message.Invalid{ErrorMessage: err.Error()}) @@ -679,6 +677,7 @@ func (c *client) handleQuery(raw *frame.RawFrame, msg *partialQuery, customPaylo c.interceptSystemQuery(raw.Header, stmt) } } else { + c.proxy.logger.Debug("Query not handled by proxy, forwarding", zap.String("query", msg.query), zap.Int16("stream", raw.Header.StreamId)) c.execute(raw, c.getDefaultIdempotency(customPayload), c.keyspace, msg) } } @@ -813,9 +812,20 @@ func (c *client) interceptSystemQuery(hdr *frame.Header, stmt interface{}) { } case *parser.UseStatement: if _, err := c.proxy.maybeCreateSession(hdr.Version, s.Keyspace); err != nil { - c.send(hdr, &message.ServerError{ErrorMessage: "Proxy unable to create new session for keyspace"}) + errMsg := "Proxy unable to create new session for keyspace" + var cqlError *proxycore.CqlError + if errors.As(err, &cqlError) { + // copy detailed error reason from downstream message + errMsg = cqlError.Message.GetErrorMessage() + } + c.send(hdr, &message.ServerError{ErrorMessage: errMsg}) } else { c.keyspace = s.Keyspace + // We might have received a quoted keyspace name in the UseStatement so remove any + // quotes before sending back this result message. This keeps us consistent with + // how Cassandra implements the same functionality and avoids any issues with + // drivers sending follow-on "USE" requests after wrapping the keyspace name in + // quotes. ks := parser.IdentifierFromString(s.Keyspace) c.send(hdr, &message.SetKeyspaceResult{Keyspace: ks.ID()}) } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 46e1c44..995b597 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -206,13 +206,48 @@ func TestProxy_UseKeyspace(t *testing.T) { cl := connectTestClient(t, ctx, proxyContactPoint) - resp, err := cl.SendAndReceive(ctx, frame.NewFrame(primitive.ProtocolVersion4, 0, &message.Query{Query: "USE system"})) + testKeyspaces := []string{"system", "\"system\""} + for _, testKeyspace := range testKeyspaces { + + resp, err := cl.SendAndReceive(ctx, frame.NewFrame(primitive.ProtocolVersion4, 0, &message.Query{Query: "USE " + testKeyspace})) + require.NoError(t, err) + + assert.Equal(t, primitive.OpCodeResult, resp.Header.OpCode) + res, ok := resp.Body.Message.(*message.SetKeyspaceResult) + require.True(t, ok, "expected set keyspace result") + assert.Equal(t, "system", res.Keyspace) + } +} + +func TestProxy_UseKeyspace_Error(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + tester, proxyContactPoint, err := setupProxyTest(ctx, 1, proxycore.MockRequestHandlers{ + primitive.OpCodeQuery: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message { + qry := frm.Body.Message.(*message.Query) + if qry.Query == "USE non_existing" { + return &message.ServerError{ + ErrorMessage: "Keyspace 'non_existing' does not exist", + } + } + return cl.InterceptQuery(frm.Header, frm.Body.Message.(*message.Query)) + }}) + defer func() { + cancel() + tester.shutdown() + }() require.NoError(t, err) - assert.Equal(t, primitive.OpCodeResult, resp.Header.OpCode) - res, ok := resp.Body.Message.(*message.SetKeyspaceResult) - require.True(t, ok, "expected set keyspace result") - assert.Equal(t, "system", res.Keyspace) + cl := connectTestClient(t, ctx, proxyContactPoint) + + resp, err := cl.SendAndReceive(ctx, frame.NewFrame(primitive.ProtocolVersion4, 0, &message.Query{Query: "USE non_existing"})) + require.NoError(t, err) + + assert.Equal(t, primitive.OpCodeError, resp.Header.OpCode) + res, ok := resp.Body.Message.(*message.ServerError) + require.True(t, ok) + // make sure that CQL Proxy returns the same error of 'USE keyspace' command + // as backend C* cluster has and does not wrap it inside a custom one + assert.Equal(t, "Keyspace 'non_existing' does not exist", res.ErrorMessage) } func TestProxy_NegotiateProtocolV5(t *testing.T) {