From d45538740d7179429bbd1dc0498f43fdf38272d2 Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Tue, 29 Aug 2023 15:21:53 +0200 Subject: [PATCH] SNOW-895536: Limit query context cache --- connection.go | 2 +- htap.go | 37 ++++++++++++++++-- htap_test.go | 104 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 4 deletions(-) diff --git a/connection.go b/connection.go index 9b4b1f808..a48af3600 100644 --- a/connection.go +++ b/connection.go @@ -141,7 +141,7 @@ func (sc *snowflakeConn) exec( return nil, err } - sc.queryContextCache.add(data.Data.QueryContext.Entries...) + sc.queryContextCache.add(sc, data.Data.QueryContext.Entries...) // handle PUT/GET commands if isFileTransfer(query) { diff --git a/htap.go b/htap.go index 587f830aa..0d15a1030 100644 --- a/htap.go +++ b/htap.go @@ -1,6 +1,14 @@ package gosnowflake -import "sync" +import ( + "strconv" + "sync" +) + +const ( + queryContextCacheSizeParamName = "QUERY_CONTEXT_CACHE_SIZE" + defaultQueryContextCacheSize = 5 +) type queryContextEntry struct { ID int `json:"id"` @@ -19,8 +27,31 @@ func (qcc *queryContextCache) init() *queryContextCache { return qcc } -func (qcc *queryContextCache) add(qces ...queryContextEntry) { +func (qcc *queryContextCache) add(sc *snowflakeConn, qces ...queryContextEntry) { qcc.mutex.Lock() defer qcc.mutex.Unlock() - qcc.entries = append(qcc.entries, qces...) + if len(qces) == 0 { + qcc.prune(0) + } else { + qcc.entries = append(qcc.entries, qces...) + qcc.prune(qcc.getQueryContextCacheSize(sc)) + } +} + +func (qcc *queryContextCache) prune(size int) { + if len(qcc.entries) > size { + qcc.entries = qcc.entries[0:size] + } +} + +func (qcc *queryContextCache) getQueryContextCacheSize(sc *snowflakeConn) int { + if sizeStr, ok := sc.cfg.Params[queryContextCacheSizeParamName]; ok { + size, err := strconv.Atoi(*sizeStr) + if err != nil { + logger.Warnf("cannot parse %v as int as query context cache size: %v", sizeStr, err) + } else { + return size + } + } + return defaultQueryContextCacheSize } diff --git a/htap_test.go b/htap_test.go index 39fae90f0..e196589a6 100644 --- a/htap_test.go +++ b/htap_test.go @@ -147,3 +147,107 @@ func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryC return false } + +func TestPruneBySessionValue(t *testing.T) { + qce1 := queryContextEntry{1, 1, 1, nil} + qce2 := queryContextEntry{2, 2, 2, nil} + qce3 := queryContextEntry{3, 3, 3, nil} + + testcases := []struct { + size string + expected []queryContextEntry + }{ + { + size: "1", + expected: []queryContextEntry{qce1}, + }, + { + size: "2", + expected: []queryContextEntry{qce1, qce2}, + }, + { + size: "3", + expected: []queryContextEntry{qce1, qce2, qce3}, + }, + { + size: "4", + expected: []queryContextEntry{qce1, qce2, qce3}, + }, + } + + for _, tc := range testcases { + t.Run(tc.size, func(t *testing.T) { + sc := &snowflakeConn{ + cfg: &Config{ + Params: map[string]*string{ + queryContextCacheSizeParamName: &tc.size, + }, + }, + } + + qcc := (&queryContextCache{}).init() + + qcc.add(sc, qce1) + qcc.add(sc, qce2) + qcc.add(sc, qce3) + + if !reflect.DeepEqual(qcc.entries, tc.expected) { + t.Errorf("unexpected cache entries. expected: %v, got: %v", tc.expected, qcc.entries) + } + }) + } +} + +func TestPruneByDefaultValue(t *testing.T) { + qce1 := queryContextEntry{1, 1, 1, nil} + qce2 := queryContextEntry{2, 2, 2, nil} + qce3 := queryContextEntry{3, 3, 3, nil} + qce4 := queryContextEntry{4, 4, 4, nil} + qce5 := queryContextEntry{5, 5, 5, nil} + qce6 := queryContextEntry{6, 6, 6, nil} + + sc := &snowflakeConn{ + cfg: &Config{ + Params: map[string]*string{}, + }, + } + + qcc := (&queryContextCache{}).init() + qcc.add(sc, qce1) + qcc.add(sc, qce2) + qcc.add(sc, qce3) + qcc.add(sc, qce4) + qcc.add(sc, qce5) + + if len(qcc.entries) != 5 { + t.Fatalf("Expected 5 elements, got: %v", len(qcc.entries)) + } + + qcc.add(sc, qce6) + if len(qcc.entries) != 5 { + t.Fatalf("Expected 5 elements, got: %v", len(qcc.entries)) + } +} + +func TestNoQcesClearsCache(t *testing.T) { + qce1 := queryContextEntry{1, 1, 1, nil} + + sc := &snowflakeConn{ + cfg: &Config{ + Params: map[string]*string{}, + }, + } + + qcc := (&queryContextCache{}).init() + qcc.add(sc, qce1) + + if len(qcc.entries) != 1 { + t.Fatalf("improperly inited cache") + } + + qcc.add(sc) + + if len(qcc.entries) != 0 { + t.Errorf("after adding empty context list cache should be cleared") + } +}