From de058ef79c85e226378c24a5131c86df4f06e219 Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Mon, 28 Aug 2023 08:37:52 +0200 Subject: [PATCH] SNOW-726742: Implement GetQueryId for statements --- doc.go | 30 +++++++++++ statement.go | 30 ++++++++--- statement_test.go | 127 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 181 insertions(+), 6 deletions(-) diff --git a/doc.go b/doc.go index fd496f67c..84b859803 100644 --- a/doc.go +++ b/doc.go @@ -168,6 +168,36 @@ in place of the default randomized request ID. For example: ctxWithID := WithRequestID(ctx, requestID) rows, err := db.QueryContext(ctxWithID, query) +# Last query ID + +If you need query ID for your query you have to use raw connection. + +For queries: +``` + + err := conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1") + rows, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) + rows.(SnowflakeRows).GetQueryID() + stmt.(SnowflakeStmt).GetQueryID() + return nil + } + +``` + +For execs: +``` + + err := conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)") + result, err := stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + result.(SnowflakeResult).GetQueryID() + stmt.(SnowflakeStmt).GetQueryID() + return nil + } + +``` + # Canceling Query by CtrlC From 0.5.0, a signal handling responsibility has moved to the applications. If you want to cancel a diff --git a/statement.go b/statement.go index 203009986..70d4479a7 100644 --- a/statement.go +++ b/statement.go @@ -7,9 +7,15 @@ import ( "database/sql/driver" ) +// SnowflakeStmt represents the prepared statement in driver. +type SnowflakeStmt interface { + GetQueryID() string +} + type snowflakeStmt struct { - sc *snowflakeConn - query string + sc *snowflakeConn + query string + lastQueryID string } func (stmt *snowflakeStmt) Close() error { @@ -26,20 +32,32 @@ func (stmt *snowflakeStmt) NumInput() int { func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext") - return stmt.sc.ExecContext(ctx, stmt.query, args) + result, err := stmt.sc.ExecContext(ctx, stmt.query, args) + stmt.lastQueryID = result.(SnowflakeResult).GetQueryID() + return result, err } func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.QueryContext") - return stmt.sc.QueryContext(ctx, stmt.query, args) + rows, err := stmt.sc.QueryContext(ctx, stmt.query, args) + stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID() + return rows, err } func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec") - return stmt.sc.Exec(stmt.query, args) + result, err := stmt.sc.Exec(stmt.query, args) + stmt.lastQueryID = result.(SnowflakeResult).GetQueryID() + return result, err } func (stmt *snowflakeStmt) Query(args []driver.Value) (driver.Rows, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Query") - return stmt.sc.Query(stmt.query, args) + rows, err := stmt.sc.Query(stmt.query, args) + stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID() + return rows, err +} + +func (stmt *snowflakeStmt) GetQueryID() string { + return stmt.lastQueryID } diff --git a/statement_test.go b/statement_test.go index de496c8aa..68f1e7a7b 100644 --- a/statement_test.go +++ b/statement_test.go @@ -1,4 +1,5 @@ // Copyright (c) 2020-2022 Snowflake Computing Inc. All rights reserved. +//lint:file-ignore SA1019 Ignore deprecated methods. We should leave them as-is to keep backward compatibility. package gosnowflake @@ -287,3 +288,129 @@ func TestUnitCheckQueryStatus(t *testing.T) { t.Fatalf("unexpected error code. expected: %v, got: %v", ErrQueryStatus, driverErr.Number) } } + +func TestStatementQueryIdForQueries(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + testcases := []struct { + name string + f func(stmt driver.Stmt) (driver.Rows, error) + }{ + { + "query", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.Query(nil) + }, + }, + { + "queryContext", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1") + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() != "" { + t.Error("queryId should be empty before executing any query") + } + firstQuery, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != firstQuery.(SnowflakeRows).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + secondQuery, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != secondQuery.(SnowflakeRows).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func TestStatementQueryIdForExecs(t *testing.T) { + ctx := context.Background() + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE TestStatementQueryIdForExecs (v INTEGER)") + defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementQueryIdForExecs") + + testcases := []struct { + name string + f func(stmt driver.Stmt) (driver.Result, error) + }{ + { + "exec", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.Exec(nil) + }, + }, + { + "execContext", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := dbt.conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)") + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() != "" { + t.Error("queryId should be empty before executing any query") + } + firstExec, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != firstExec.(SnowflakeResult).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + secondExec, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != secondExec.(SnowflakeResult).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } + }) +}