Skip to content

Commit

Permalink
SNOW-857631 Handle multistatement query type
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Aug 17, 2023
1 parent 4bc6cd4 commit 66d08d6
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 15 deletions.
23 changes: 23 additions & 0 deletions arrow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,29 @@ import (
"time"
)

func TestCheckVersion(t *testing.T) {
conn := openConn(t)
defer conn.Close()

queries := []string{"SELECT current_version()", "show parameters like 'SHOW%'"}
for _, query := range queries {
rows, err := conn.QueryContext(context.Background(), query)
if err != nil {
t.Error(err)
}
defer rows.Close()

if !rows.Next() {
t.Fatalf("failed to find any row")
}
var s string
if err = rows.Scan(&s); err != nil {
t.Fatal(err)
}
println(s)
}
}

func TestArrowBigInt(t *testing.T) {
conn := openConn(t)
defer conn.Close()
Expand Down
3 changes: 2 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ const (
)

const (
statementTypeIDMulti = int64(0x1000)
statementTypeIDSelect = int64(0x1000)
statementTypeIDDml = int64(0x3000)
statementTypeIDMultiTableInsert = statementTypeIDDml + int64(0x500)
statementTypeIDMultistatement = int64(0xA000)
)

const (
Expand Down
4 changes: 2 additions & 2 deletions connection_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ func updateRows(data execResponseData) (int64, error) {
// Note that the statement type code is also equivalent to type INSERT, so an
// additional check of the name is required
func isMultiStmt(data *execResponseData) bool {
return data.StatementTypeID == statementTypeIDMulti &&
data.RowType[0].Name == "multiple statement execution"
var isMultistatementByReturningSelect = data.StatementTypeID == statementTypeIDSelect && data.RowType[0].Name == "multiple statement execution"
return isMultistatementByReturningSelect || data.StatementTypeID == statementTypeIDMultistatement
}

func getResumeQueryID(ctx context.Context) (string, error) {
Expand Down
34 changes: 22 additions & 12 deletions multistatement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
Expand All @@ -22,11 +23,8 @@ func TestMultiStatementExecuteNoResultSet(t *testing.T) {
"insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" +
"commit;"

runDBTest(t, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_txn")
dbt.mustExec(`create or replace table test_multi_statement_txn(
c1 number, c2 string) as select 10, 'z'`)
defer dbt.mustExec("drop table if exists test_multi_statement_txn")
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec(`create or replace table test_multi_statement_txn(c1 number, c2 string) as select 10, 'z'`)

res := dbt.mustExecContext(ctx, multiStmtQuery)
count, err := res.RowsAffected()
Expand All @@ -48,7 +46,8 @@ func TestMultiStatementQueryResultSet(t *testing.T) {

var v1, v2, v3 int64
var v4 string
runDBTest(t, func(dbt *DBTest) {

testForAllMultistatementTypes(t, func(dbt *DBTest) {
rows := dbt.mustQueryContext(ctx, multiStmtQuery)
defer rows.Close()

Expand Down Expand Up @@ -120,7 +119,7 @@ func TestMultiStatementExecuteResultSet(t *testing.T) {
"select 2;\n" +
"rollback;"

runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_txn_rb")
dbt.mustExec(`create or replace table test_multi_statement_txn_rb(
c1 number, c2 string) as select 10, 'z'`)
Expand All @@ -144,7 +143,7 @@ func TestMultiStatementQueryNoResultSet(t *testing.T) {
"insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" +
"commit;"

runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_txn")
dbt.mustExec(`create or replace table test_multi_statement_txn(
c1 number, c2 string) as select 10, 'z'`)
Expand All @@ -161,7 +160,7 @@ func TestMultiStatementExecuteMix(t *testing.T) {
"insert into test_multi values (1), (2);\n" +
"select cola from test_multi order by cola asc;"

runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_txn")
dbt.mustExec(`create or replace table test_multi_statement_txn(
c1 number, c2 string) as select 10, 'z'`)
Expand All @@ -185,7 +184,7 @@ func TestMultiStatementQueryMix(t *testing.T) {
"select cola from test_multi order by cola asc;"

var count, v int
runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_txn")
dbt.mustExec(`create or replace table test_multi_statement_txn(
c1 number, c2 string) as select 10, 'z'`)
Expand Down Expand Up @@ -232,7 +231,7 @@ func TestMultiStatementCountZero(t *testing.T) {
var v3 float64
var v4 bool

runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
// first query
multiStmtQuery1 := "select 123;\n" +
"select '456';"
Expand Down Expand Up @@ -352,7 +351,7 @@ func TestMultiStatementVaryingColumnCount(t *testing.T) {
ctx, _ := WithMultiStatement(context.Background(), 0)

var v1, v2 int
runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec("create or replace table test_tbl(c1 int, c2 int)")
dbt.mustExec("insert into test_tbl values(1, 0)")
defer dbt.mustExec("drop table if exists test_tbl")
Expand Down Expand Up @@ -593,3 +592,14 @@ func TestUnitHandleMultiQuery(t *testing.T) {
t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number)
}
}

func testForAllMultistatementTypes(t *testing.T, test func(dbt *DBTest)) {
for _, enableMultistatementType := range []bool{false, true} {
t.Run(fmt.Sprintf("enableMultistatementType=%v", enableMultistatementType), func(t *testing.T) {
runDBTest(t, func(dbt *DBTest) {
dbt.mustExec(fmt.Sprintf("ALTER SESSION SET ENABLE_MULTI_STMT_QUERY_TYPE = %v", enableMultistatementType))
test(dbt)
})
})
}
}

0 comments on commit 66d08d6

Please sign in to comment.