diff --git a/async.go b/async.go index 65434e4a9..d29b24b12 100644 --- a/async.go +++ b/async.go @@ -105,7 +105,7 @@ func (sr *snowflakeRestful) getAsync( } - sc := &snowflakeConn{rest: sr, cfg: cfg} + sc := &snowflakeConn{rest: sr, cfg: cfg, queryContextCache: (&queryContextCache{}).init()} if respd.Success { if resType == execResultType { res.insertID = -1 diff --git a/connection.go b/connection.go index cf129a174..cf888abed 100644 --- a/connection.go +++ b/connection.go @@ -60,12 +60,13 @@ const ( const privateLinkSuffix = "privatelink.snowflakecomputing.com" type snowflakeConn struct { - ctx context.Context - cfg *Config - rest *snowflakeRestful - SequenceCounter uint64 - telemetry *snowflakeTelemetry - internal InternalClient + ctx context.Context + cfg *Config + rest *snowflakeRestful + SequenceCounter uint64 + telemetry *snowflakeTelemetry + internal InternalClient + queryContextCache *queryContextCache } var ( @@ -143,6 +144,8 @@ func (sc *snowflakeConn) exec( return nil, err } + sc.queryContextCache.addAll(data.Data.QueryContext.Entries) + // handle PUT/GET commands if isFileTransfer(query) { data, err = sc.processFileTransfer(ctx, data, query, isInternal) @@ -679,9 +682,10 @@ func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamB func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) { sc := &snowflakeConn{ - SequenceCounter: 0, - ctx: ctx, - cfg: &config, + SequenceCounter: 0, + ctx: ctx, + cfg: &config, + queryContextCache: (&queryContextCache{}).init(), } var st http.RoundTripper = SnowflakeTransport if sc.cfg.Transporter == nil { diff --git a/connection_test.go b/connection_test.go index 3f77c66e7..90d0bc41a 100644 --- a/connection_test.go +++ b/connection_test.go @@ -90,8 +90,9 @@ func TestExecWithEmptyRequestID(t *testing.T) { } sc := &snowflakeConn{ - cfg: &Config{Params: map[string]*string{}}, - rest: sr, + cfg: &Config{Params: map[string]*string{}}, + rest: sr, + queryContextCache: (&queryContextCache{}).init(), } if _, err := sc.exec(ctx, "", false /* noResult */, false, /* isInternal */ false /* describeOnly */, nil); err != nil { @@ -161,8 +162,9 @@ func TestExecWithSpecificRequestID(t *testing.T) { } sc := &snowflakeConn{ - cfg: &Config{Params: map[string]*string{}}, - rest: sr, + cfg: &Config{Params: map[string]*string{}}, + rest: sr, + queryContextCache: (&queryContextCache{}).init(), } if _, err := sc.exec(ctx, "", false /* noResult */, false, /* isInternal */ false /* describeOnly */, nil); err != nil { @@ -181,8 +183,9 @@ func TestServiceName(t *testing.T) { } sc := &snowflakeConn{ - cfg: &Config{Params: map[string]*string{}}, - rest: sr, + cfg: &Config{Params: map[string]*string{}}, + rest: sr, + queryContextCache: (&queryContextCache{}).init(), } expectServiceName := serviceNameStub @@ -219,9 +222,10 @@ func TestCloseIgnoreSessionGone(t *testing.T) { FuncCloseSession: closeSessionMock, } sc := &snowflakeConn{ - cfg: &Config{Params: map[string]*string{}}, - rest: sr, - telemetry: testTelemetry, + cfg: &Config{Params: map[string]*string{}}, + rest: sr, + telemetry: testTelemetry, + queryContextCache: (&queryContextCache{}).init(), } if sc.Close() != nil { diff --git a/connection_util.go b/connection_util.go index 2a45e363c..c3e4cb60a 100644 --- a/connection_util.go +++ b/connection_util.go @@ -39,7 +39,7 @@ func (sc *snowflakeConn) stopHeartBeat() { if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { return } - if sc.rest != nil { + if sc.rest != nil && sc.rest.HeartBeat != nil { sc.rest.HeartBeat.stop() } } diff --git a/driver_test.go b/driver_test.go index 634a53145..653c62777 100644 --- a/driver_test.go +++ b/driver_test.go @@ -328,6 +328,23 @@ func runDBTest(t *testing.T, test func(dbt *DBTest)) { test(dbt) } +func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) { + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) + } + defer sc.Close() + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) + } + + test(sc) +} + func runningOnAWS() bool { return os.Getenv("CLOUD_PROVIDER") == "AWS" } diff --git a/htap.go b/htap.go index c0e7665da..efc623e23 100644 --- a/htap.go +++ b/htap.go @@ -1,8 +1,26 @@ package gosnowflake +import "sync" + type queryContextEntry struct { ID int `json:"id"` Timestamp int64 `json:"timestamp"` Priority int `json:"priority"` Context any `json:"context,omitempty"` } + +type queryContextCache struct { + mutex *sync.Mutex + entries []queryContextEntry +} + +func (qcc *queryContextCache) init() *queryContextCache { + qcc.mutex = &sync.Mutex{} + return qcc +} + +func (qcc *queryContextCache) addAll(qces []queryContextEntry) { + qcc.mutex.Lock() + defer qcc.mutex.Unlock() + qcc.entries = append(qcc.entries, qces...) +} diff --git a/htap_test.go b/htap_test.go index 9fff10a27..39fae90f0 100644 --- a/htap_test.go +++ b/htap_test.go @@ -100,3 +100,50 @@ func trimWhitespaces(s string) string { "\n", "", ) } + +func TestAddingQueryContextCacheEntry(t *testing.T) { + runSnowflakeConnTest(t, func(sc *snowflakeConn) { + t.Run("First query (may be on empty cache)", func(t *testing.T) { + entriesBefore := sc.queryContextCache.entries + if _, err := sc.Query("SELECT 1", nil); err != nil { + t.Fatalf("cannot query. %v", err) + } + entriesAfter := sc.queryContextCache.entries + + if !containsNewEntries(entriesAfter, entriesBefore) { + t.Error("no new entries added to the query context cache") + } + }) + + t.Run("Second query (cache should not be empty)", func(t *testing.T) { + entriesBefore := sc.queryContextCache.entries + if len(entriesBefore) == 0 { + t.Fatalf("cache should not be empty after first query") + } + if _, err := sc.Query("SELECT 1", nil); err != nil { + t.Fatalf("cannot query. %v", err) + } + entriesAfter := sc.queryContextCache.entries + + if !containsNewEntries(entriesAfter, entriesBefore) { + t.Error("no new entries added to the query context cache") + } + }) + }) +} + +func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryContextEntry) bool { + if len(entriesAfter) > len(entriesBefore) { + return true + } + + for _, entryAfter := range entriesAfter { + for _, entryBefore := range entriesBefore { + if !reflect.DeepEqual(entryBefore, entryAfter) { + return true + } + } + } + + return false +} diff --git a/query.go b/query.go index db76d1624..5d7dff053 100644 --- a/query.go +++ b/query.go @@ -118,6 +118,11 @@ type execResponseData struct { Command string `json:"command,omitempty"` Kind string `json:"kind,omitempty"` Operation string `json:"operation,omitempty"` + + // HTAP + QueryContext struct { + Entries []queryContextEntry `json:"entries,omitempty"` + } `json:"queryContext,omitempty"` } type execResponse struct {