Skip to content

Commit

Permalink
SNOW-968719: Remember queryId for failed queries (#967)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pbulawa authored Nov 15, 2023
1 parent 46d412c commit 25d3956
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 2 deletions.
14 changes: 13 additions & 1 deletion statement.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
// Copyright (c) 2017-2023 Snowflake Computing Inc. All rights reserved.

package gosnowflake

import (
"context"
"database/sql/driver"
"errors"
"fmt"
)

Expand Down Expand Up @@ -35,6 +36,7 @@ func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedV
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext")
result, err := stmt.sc.ExecContext(ctx, stmt.query, args)
if err != nil {
stmt.setQueryIDFromError(err)
return nil, err
}
r, ok := result.(SnowflakeResult)
Expand All @@ -49,6 +51,7 @@ func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.Named
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.QueryContext")
rows, err := stmt.sc.QueryContext(ctx, stmt.query, args)
if err != nil {
stmt.setQueryIDFromError(err)
return nil, err
}
r, ok := rows.(SnowflakeRows)
Expand All @@ -63,6 +66,7 @@ func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec")
result, err := stmt.sc.Exec(stmt.query, args)
if err != nil {
stmt.setQueryIDFromError(err)
return nil, err
}
r, ok := result.(SnowflakeResult)
Expand All @@ -77,6 +81,7 @@ func (stmt *snowflakeStmt) Query(args []driver.Value) (driver.Rows, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Query")
rows, err := stmt.sc.Query(stmt.query, args)
if err != nil {
stmt.setQueryIDFromError(err)
return nil, err
}
r, ok := rows.(SnowflakeRows)
Expand All @@ -90,3 +95,10 @@ func (stmt *snowflakeStmt) Query(args []driver.Value) (driver.Rows, error) {
func (stmt *snowflakeStmt) GetQueryID() string {
return stmt.lastQueryID
}

func (stmt *snowflakeStmt) setQueryIDFromError(err error) {
var snowflakeError *SnowflakeError
if errors.As(err, &snowflakeError) {
stmt.lastQueryID = snowflakeError.QueryID
}
}
146 changes: 145 additions & 1 deletion statement_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022 Snowflake Computing Inc. All rights reserved.
// Copyright (c) 2020-2023 Snowflake Computing Inc. All rights reserved.
//lint:file-ignore SA1019 Ignore deprecated methods. We should leave them as-is to keep backward compatibility.

package gosnowflake
Expand All @@ -7,6 +7,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -39,6 +40,149 @@ func openConn(t *testing.T) *sql.Conn {
return conn
}

func TestFailedQueryIdInSnowflakeError(t *testing.T) {
failingQuery := "SELECTT 1"
failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE"

runDBTest(t, func(dbt *DBTest) {
testcases := []struct {
name string
query string
f func(dbt *DBTest) (any, error)
}{
{
name: "query",
f: func(dbt *DBTest) (any, error) {
return dbt.query(failingQuery)
},
},
{
name: "exec",
f: func(dbt *DBTest) (any, error) {
return dbt.exec(failingExec)
},
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.f(dbt)
if err == nil {
t.Error("should have failed")
}
var snowflakeError *SnowflakeError
if !errors.As(err, &snowflakeError) {
t.Error("should be a SnowflakeError")
}
if snowflakeError.QueryID == "" {
t.Error("QueryID should be set")
}
})
}
})
}

func TestSetFailedQueryId(t *testing.T) {
ctx := context.Background()
failingQuery := "SELECTT 1"
failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE"

runDBTest(t, func(dbt *DBTest) {
testcases := []struct {
name string
query string
f func(stmt driver.Stmt) (any, error)
}{
{
name: "query",
query: failingQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.Query(nil)
},
},
{
name: "exec",
query: failingExec,
f: func(stmt driver.Stmt) (any, error) {
return stmt.Exec(nil)
},
},
{
name: "queryContext",
query: failingQuery,
f: func(stmt driver.Stmt) (any, error) {
return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
},
},
{
name: "execContext",
query: failingExec,
f: func(stmt driver.Stmt) (any, error) {
return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
},
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
err := dbt.conn.Raw(func(x any) error {
stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query)
if err != nil {
t.Error(err)
}
if stmt.(SnowflakeStmt).GetQueryID() != "" {
t.Error("queryId should be empty before executing any query")
}
if _, err := tc.f(stmt); err == nil {
t.Error("should have failed to execute the query")
}
if stmt.(SnowflakeStmt).GetQueryID() == "" {
t.Error("should have set the query id")
}
return nil
})
if err != nil {
t.Fatal(err)
}
})
}
})
}

func TestAsyncFailQueryId(t *testing.T) {
ctx := WithAsyncMode(context.Background())
runDBTest(t, func(dbt *DBTest) {
err := dbt.conn.Raw(func(x any) error {
stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECTT 1")
if err != nil {
t.Error(err)
}
if stmt.(SnowflakeStmt).GetQueryID() != "" {
t.Error("queryId should be empty before executing any query")
}
rows, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
if err != nil {
t.Error("should not fail the initial request")
}
if rows.(SnowflakeRows).GetStatus() != QueryStatusInProgress {
t.Error("should be in progress")
}
// Wait for the query to complete
rows.Next(nil)
if rows.(SnowflakeRows).GetStatus() != QueryFailed {
t.Error("should have failed")
}
if rows.(SnowflakeRows).GetQueryID() != stmt.(SnowflakeStmt).GetQueryID() {
t.Error("last query id should be the same as rows query id")
}
return nil
})
if err != nil {
t.Fatal(err)
}
})
}

func TestGetQueryID(t *testing.T) {
ctx := context.Background()
conn := openConn(t)
Expand Down

0 comments on commit 25d3956

Please sign in to comment.