Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-895537: Send query context with request #904

Merged
merged 1 commit into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,18 @@ func (sc *snowflakeConn) exec(
var err error
counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter

queryContext, err := buildQueryContext(sc.queryContextCache)
if err != nil {
logger.Errorf("error while building query context: %v", err)
sfc-gh-pfus marked this conversation as resolved.
Show resolved Hide resolved
}
req := execRequest{
SQLText: query,
AsyncExec: noResult,
Parameters: map[string]interface{}{},
IsInternal: isInternal,
DescribeOnly: describeOnly,
SequenceID: counter,
QueryContext: queryContext,
}
if key := ctx.Value(multiStatementCount); key != nil {
req.Parameters[string(multiStatementCount)] = key
Expand Down Expand Up @@ -173,6 +178,27 @@ func extractQueryContext(data *execResponse) (queryContext, error) {
return queryContext, err
}

func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) {
sfc-gh-dheyman marked this conversation as resolved.
Show resolved Hide resolved
rqc := requestQueryContext{}
if qcc == nil || len(qcc.entries) == 0 {
logger.Debugf("empty qcc")
return rqc, nil
}
for _, qce := range qcc.entries {
contextData := contextData{}
if qce.Context == "" {
contextData.Base64Data = qce.Context
}
rqc.Entries = append(rqc.Entries, requestQueryContextEntry{
ID: qce.ID,
Priority: qce.Priority,
Timestamp: qce.Timestamp,
Context: contextData,
})
}
return rqc, nil
}

func (sc *snowflakeConn) Begin() (driver.Tx, error) {
return sc.BeginTx(sc.ctx, driver.TxOptions{})
}
Expand Down
1 change: 1 addition & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) {
}

sct := &SCTest{t, sc}

test(sct)
}

Expand Down
8 changes: 4 additions & 4 deletions htap.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ type queryContext struct {
}

type queryContextEntry struct {
ID int `json:"id"`
Timestamp int64 `json:"timestamp"`
Priority int `json:"priority"`
Context any `json:"context,omitempty"`
ID int `json:"id"`
Timestamp int64 `json:"timestamp"`
Priority int `json:"priority"`
Context string `json:"context,omitempty"`
}

type queryContextCache struct {
Expand Down
195 changes: 89 additions & 106 deletions htap_test.go
Original file line number Diff line number Diff line change
@@ -1,106 +1,13 @@
package gosnowflake

import (
"encoding/json"
"database/sql/driver"
"fmt"
"reflect"
"strings"
"testing"
"time"
)

func TestMarshallAndDecodeOpaqueContext(t *testing.T) {
testcases := []struct {
json string
qc queryContextEntry
}{
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3
}`,
qc: queryContextEntry{1, 2, 3, nil},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": "abc"
}`,
qc: queryContextEntry{1, 2, 3, "abc"},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": {
"val": "abc"
}
}`,
qc: queryContextEntry{1, 2, 3, map[string]interface{}{"val": "abc"}},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": [
"abc"
]
}`,
qc: queryContextEntry{1, 2, 3, []any{"abc"}},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": [
{
"val": "abc"
}
]
}`,
qc: queryContextEntry{1, 2, 3, []any{map[string]interface{}{"val": "abc"}}},
},
}

for _, tc := range testcases {
t.Run(trimWhitespaces(tc.json), func(t *testing.T) {
var qc queryContextEntry

err := json.NewDecoder(strings.NewReader(tc.json)).Decode(&qc)
if err != nil {
t.Fatalf("failed to decode json. %v", err)
}

if !reflect.DeepEqual(tc.qc, qc) {
t.Errorf("failed to decode json. expected: %v, got: %v", tc.qc, qc)
}

bytes, err := json.Marshal(qc)
if err != nil {
t.Fatalf("failed to encode json. %v", err)
}

resultJSON := string(bytes)
if resultJSON != trimWhitespaces(tc.json) {
t.Errorf("failed to encode json. epxected: %v, got: %v", trimWhitespaces(tc.json), resultJSON)
}
})
}
}

func trimWhitespaces(s string) string {
return strings.ReplaceAll(
strings.ReplaceAll(
strings.ReplaceAll(s, "\t", ""),
" ", ""),
"\n", "",
)
}

func TestSortingByPriority(t *testing.T) {
qcc := (&queryContextCache{}).init()
sc := htapTestSnowflakeConn()
Expand Down Expand Up @@ -302,9 +209,9 @@ func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryC
}

func TestPruneBySessionValue(t *testing.T) {
qce1 := queryContextEntry{1, 1, 1, nil}
qce2 := queryContextEntry{2, 2, 2, nil}
qce3 := queryContextEntry{3, 3, 3, nil}
qce1 := queryContextEntry{1, 1, 1, ""}
qce2 := queryContextEntry{2, 2, 2, ""}
qce3 := queryContextEntry{3, 3, 3, ""}

testcases := []struct {
size string
Expand Down Expand Up @@ -352,12 +259,12 @@ func TestPruneBySessionValue(t *testing.T) {
}

func TestPruneByDefaultValue(t *testing.T) {
qce1 := queryContextEntry{1, 1, 1, nil}
qce2 := queryContextEntry{2, 2, 2, nil}
qce3 := queryContextEntry{3, 3, 3, nil}
qce4 := queryContextEntry{4, 4, 4, nil}
qce5 := queryContextEntry{5, 5, 5, nil}
qce6 := queryContextEntry{6, 6, 6, nil}
qce1 := queryContextEntry{1, 1, 1, ""}
qce2 := queryContextEntry{2, 2, 2, ""}
qce3 := queryContextEntry{3, 3, 3, ""}
qce4 := queryContextEntry{4, 4, 4, ""}
qce5 := queryContextEntry{5, 5, 5, ""}
qce6 := queryContextEntry{6, 6, 6, ""}

sc := &snowflakeConn{
cfg: &Config{
Expand All @@ -383,7 +290,7 @@ func TestPruneByDefaultValue(t *testing.T) {
}

func TestNoQcesClearsCache(t *testing.T) {
qce1 := queryContextEntry{1, 1, 1, nil}
qce1 := queryContextEntry{1, 1, 1, ""}

sc := &snowflakeConn{
cfg: &Config{
Expand Down Expand Up @@ -426,3 +333,79 @@ func TestQueryContextCacheDisabled(t *testing.T) {
}
})
}

func TestHybridTablesE2E(t *testing.T) {
if runningOnGithubAction() && !runningOnAWS() {
t.Skip("HTAP is enabled only on AWS")
}
runID := time.Now().UnixMilli()
testDb1 := fmt.Sprintf("hybrid_db_test_%v", runID)
testDb2 := fmt.Sprintf("hybrid_db_test_%v_2", runID)
runSnowflakeConnTest(t, func(sct *SCTest) {
dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil)
defer dbQuery.Close()
currentDb := make([]driver.Value, 1)
dbQuery.Next(currentDb)
defer func() {
sct.mustExec(fmt.Sprintf("USE DATABASE %v", currentDb[0]), nil)
sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb1), nil)
sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb2), nil)
}()

t.Run("Run tests on first database", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb1), nil)
sct.mustExec("CREATE HYBRID TABLE test_hybrid_table (id INT PRIMARY KEY, text VARCHAR)", nil)

sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil)
rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows.Close()
row := make([]driver.Value, 2)
rows.Next(row)
if row[0] != "1" || row[1] != "a" {
t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
}

sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil)
rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows2.Close()
rows2.Next(row)
if row[0] != "1" || row[1] != "a" {
t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
}
rows2.Next(row)
if row[0] != "2" || row[1] != "b" {
t.Errorf("expected 2, got %v and expected b, got %v", row[0], row[1])
}
if len(sct.sc.queryContextCache.entries) != 2 {
t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
t.Run("Run tests on second database", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb2), nil)
sct.mustExec("CREATE HYBRID TABLE test_hybrid_table_2 (id INT PRIMARY KEY, text VARCHAR)", nil)
sct.mustExec("INSERT INTO test_hybrid_table_2 VALUES (3, 'c')", nil)

rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil)
defer rows.Close()
row := make([]driver.Value, 2)
rows.Next(row)
if row[0] != "3" || row[1] != "c" {
t.Errorf("expected 3, got %v and expected c, got %v", row[0], row[1])
}
if len(sct.sc.queryContextCache.entries) != 3 {
t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
t.Run("Run tests on first database again", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("USE DATABASE %v", testDb1), nil)

sct.mustExec("INSERT INTO test_hybrid_table VALUES (4, 'd')", nil)

rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows.Close()
if len(sct.sc.queryContextCache.entries) != 3 {
t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
})
}
16 changes: 16 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ type execRequest struct {
Parameters map[string]interface{} `json:"parameters,omitempty"`
Bindings map[string]execBindParameter `json:"bindings,omitempty"`
BindStage string `json:"bindStage,omitempty"`
QueryContext requestQueryContext `json:"queryContextDTO,omitempty"`
}

type requestQueryContext struct {
Entries []requestQueryContextEntry `json:"entries,omitempty"`
}

type requestQueryContextEntry struct {
Context contextData `json:"context,omitempty"`
ID int `json:"id"`
Priority int `json:"priority"`
Timestamp int64 `json:"timestamp,omitempty"`
}

type contextData struct {
Base64Data string `json:"base64Data,omitempty"`
}

type execResponseRowType struct {
Expand Down
Loading