Skip to content

Commit

Permalink
SNOW-726742: Implement GetQueryId for statements
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Aug 28, 2023
1 parent 18daec5 commit 9561417
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 6 deletions.
30 changes: 30 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,36 @@ in place of the default randomized request ID. For example:
ctxWithID := WithRequestID(ctx, requestID)
rows, err := db.QueryContext(ctxWithID, query)
# Last query ID
If you need query ID for your query you have to use raw connection.
For queries:
```
err := conn.Raw(func(x any) error {
stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1")
rows, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
rows.(SnowflakeRows).GetQueryID()
stmt.(SnowflakeStmt).GetQueryID()
return nil
}
```
For execs:
```
err := conn.Raw(func(x any) error {
stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)")
rows, err := stmt.(driver.StmtQueryContext).ExecContext(ctx, nil)
rows.(SnowflakeResult).GetQueryID()
stmt.(SnowflakeStmt).GetQueryID()
return nil
}
```
# Canceling Query by CtrlC
From 0.5.0, a signal handling responsibility has moved to the applications. If you want to cancel a
Expand Down
30 changes: 24 additions & 6 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@ import (
"database/sql/driver"
)

// SnowflakeStmt represents the prepared statement in driver.
type SnowflakeStmt interface {
GetQueryID() string
}

type snowflakeStmt struct {
sc *snowflakeConn
query string
sc *snowflakeConn
query string
lastQueryID string
}

func (stmt *snowflakeStmt) Close() error {
Expand All @@ -26,20 +32,32 @@ func (stmt *snowflakeStmt) NumInput() int {

func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext")
return stmt.sc.ExecContext(ctx, stmt.query, args)
result, err := stmt.sc.ExecContext(ctx, stmt.query, args)
stmt.lastQueryID = result.(SnowflakeResult).GetQueryID()
return result, err
}

func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.QueryContext")
return stmt.sc.QueryContext(ctx, stmt.query, args)
rows, err := stmt.sc.QueryContext(ctx, stmt.query, args)
stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID()
return rows, err
}

func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec")
return stmt.sc.Exec(stmt.query, args)
result, err := stmt.sc.Exec(stmt.query, args)
stmt.lastQueryID = result.(SnowflakeResult).GetQueryID()
return result, err
}

func (stmt *snowflakeStmt) Query(args []driver.Value) (driver.Rows, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Query")
return stmt.sc.Query(stmt.query, args)
rows, err := stmt.sc.Query(stmt.query, args)
stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID()
return rows, err
}

func (stmt *snowflakeStmt) GetQueryID() string {
return stmt.lastQueryID
}
127 changes: 127 additions & 0 deletions statement_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright (c) 2020-2022 Snowflake Computing Inc. All rights reserved.
//lint:file-ignore SA1019 Ignore deprecated methods. We should leave them as-is to keep backward compatibility.

package gosnowflake

Expand Down Expand Up @@ -287,3 +288,129 @@ func TestUnitCheckQueryStatus(t *testing.T) {
t.Fatalf("unexpected error code. expected: %v, got: %v", ErrQueryStatus, driverErr.Number)
}
}

func TestStatementQueryIdForQueries(t *testing.T) {
ctx := context.Background()
conn := openConn(t)
defer conn.Close()

testcases := []struct {
name string
f func(stmt driver.Stmt) (driver.Rows, error)
}{
{
"query",
func(stmt driver.Stmt) (driver.Rows, error) {
return stmt.Query(nil)
},
},
{
"queryContext",
func(stmt driver.Stmt) (driver.Rows, error) {
return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
},
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
err := conn.Raw(func(x any) error {
stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1")
if err != nil {
t.Fatal(err)
}
if stmt.(SnowflakeStmt).GetQueryID() != "" {
t.Error("queryId should be empty before executing any query")
}
firstQuery, err := tc.f(stmt)
if err != nil {
t.Fatal(err)
}
if stmt.(SnowflakeStmt).GetQueryID() == "" {
t.Error("queryId should not be empty after executing query")
}
if stmt.(SnowflakeStmt).GetQueryID() != firstQuery.(SnowflakeRows).GetQueryID() {
t.Error("queryId should be equal among query result and prepared statement")
}
secondQuery, err := tc.f(stmt)
if err != nil {
t.Fatal(err)
}
if stmt.(SnowflakeStmt).GetQueryID() == "" {
t.Error("queryId should not be empty after executing query")
}
if stmt.(SnowflakeStmt).GetQueryID() != secondQuery.(SnowflakeRows).GetQueryID() {
t.Error("queryId should be equal among query result and prepared statement")
}
return nil
})
if err != nil {
t.Fatal(err)
}
})
}
}

func TestStatementQueryIdForExecs(t *testing.T) {
ctx := context.Background()
runDBTest(t, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE TestStatementQueryIdForExecs (v INTEGER)")
defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementQueryIdForExecs")

testcases := []struct {
name string
f func(stmt driver.Stmt) (driver.Result, error)
}{
{
"exec",
func(stmt driver.Stmt) (driver.Result, error) {
return stmt.Exec(nil)
},
},
{
"execContext",
func(stmt driver.Stmt) (driver.Result, error) {
return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
},
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
err := dbt.conn.Raw(func(x any) error {
stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)")
if err != nil {
t.Fatal(err)
}
if stmt.(SnowflakeStmt).GetQueryID() != "" {
t.Error("queryId should be empty before executing any query")
}
firstExec, err := tc.f(stmt)
if err != nil {
t.Fatal(err)
}
if stmt.(SnowflakeStmt).GetQueryID() == "" {
t.Error("queryId should not be empty after executing query")
}
if stmt.(SnowflakeStmt).GetQueryID() != firstExec.(SnowflakeResult).GetQueryID() {
t.Error("queryId should be equal among query result and prepared statement")
}
secondExec, err := tc.f(stmt)
if err != nil {
t.Fatal(err)
}
if stmt.(SnowflakeStmt).GetQueryID() == "" {
t.Error("queryId should not be empty after executing query")
}
if stmt.(SnowflakeStmt).GetQueryID() != secondExec.(SnowflakeResult).GetQueryID() {
t.Error("queryId should be equal among query result and prepared statement")
}
return nil
})
if err != nil {
t.Fatal(err)
}
})
}
})
}

0 comments on commit 9561417

Please sign in to comment.