From 1245f308f0d2a89e193c0bcd857ec2fa09497b77 Mon Sep 17 00:00:00 2001 From: Tommy Situ Date: Wed, 28 Aug 2024 23:08:44 +0100 Subject: [PATCH] Encapsulate templating data source and optimize locking --- core/hoverfly_service_test.go | 33 +++++++------ core/templating/datasource_sql_over_csv.go | 13 ++++-- .../datasource_sql_over_csv_test.go | 2 +- core/templating/template_datasource.go | 46 +++++++++---------- core/templating/template_helpers.go | 23 ++++------ 5 files changed, 58 insertions(+), 59 deletions(-) diff --git a/core/hoverfly_service_test.go b/core/hoverfly_service_test.go index e17fd9310..2991d3f08 100644 --- a/core/hoverfly_service_test.go +++ b/core/hoverfly_service_test.go @@ -1523,25 +1523,30 @@ func TestHoverfly_SetMultipleTemplateDataSource(t *testing.T) { Expect(err1).To(BeNil()) Expect(err2).To(BeNil()) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources).ToNot(BeNil()) + Expect(unit.templator.TemplateHelper.TemplateDataSource.GetAllDataSources()).ToNot(BeNil()) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv1"]).NotTo(BeNil()) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv2"]).NotTo(BeNil()) + csv1, exists1:= unit.templator.TemplateHelper.TemplateDataSource.GetDataSource("test-csv1") + Expect(csv1).NotTo(BeNil()) + Expect(exists1).To(BeTrue()) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv1"].Name).To(Equal("test-csv1")) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv2"].Name).To(Equal("test-csv2")) + csv2, exists2:= unit.templator.TemplateHelper.TemplateDataSource.GetDataSource("test-csv2") + Expect(csv2).NotTo(BeNil()) + Expect(exists2).To(BeTrue()) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv1"].SourceType).To(Equal("csv")) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv2"].SourceType).To(Equal("csv")) + Expect(csv1.Name).To(Equal("test-csv1")) + Expect(csv2.Name).To(Equal("test-csv2")) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv1"].Data).To(HaveLen(3)) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv2"].Data).To(HaveLen(4)) + Expect(csv1.SourceType).To(Equal("csv")) + Expect(csv2.SourceType).To(Equal("csv")) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv1"].Data[1][2]).To(Equal("55")) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv2"].Data[2][2]).To(Equal("New York")) + Expect(csv1.Data).To(HaveLen(3)) + Expect(csv2.Data).To(HaveLen(4)) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv1"].Data[2][1]).To(Equal("Test2")) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources["test-csv2"].Data[3][0]).To(Equal("31")) + Expect(csv1.Data[1][2]).To(Equal("55")) + Expect(csv2.Data[2][2]).To(Equal("New York")) + + Expect(csv1.Data[2][1]).To(Equal("Test2")) + Expect(csv2.Data[3][0]).To(Equal("31")) } func TestHoverfly_DeleteTemplateDataSource(t *testing.T) { @@ -1557,7 +1562,7 @@ func TestHoverfly_DeleteTemplateDataSource(t *testing.T) { unit.DeleteDataSource("test-csv1") Expect(err).To(BeNil()) - Expect(unit.templator.TemplateHelper.TemplateDataSource.DataSources).To(HaveLen(0)) + Expect(unit.templator.TemplateHelper.TemplateDataSource.GetAllDataSources()).To(HaveLen(0)) } func TestHoverfly_GetTemplateDataSources(t *testing.T) { diff --git a/core/templating/datasource_sql_over_csv.go b/core/templating/datasource_sql_over_csv.go index d1468f88f..6d7fc1501 100644 --- a/core/templating/datasource_sql_over_csv.go +++ b/core/templating/datasource_sql_over_csv.go @@ -55,14 +55,16 @@ func parseSqlCommand(query string, datasource *TemplateDataSource) (SQLStatement } columnsPart = matches[1] dataSourceName = matches[2] - if !datasource.DataSourceExists(dataSourceName) { + source, exits := datasource.GetDataSource(dataSourceName) + if !exits { return SQLStatement{}, errors.New("data source does not exist") } if len(matches) == 4 { wherePart = matches[3] } - headers := datasource.DataSources[dataSourceName].Data[0] + + headers := source.Data[0] columns, err := parseColumns(columnsPart, headers) if err != nil { return SQLStatement{}, err @@ -86,14 +88,15 @@ func parseSqlCommand(query string, datasource *TemplateDataSource) (SQLStatement return SQLStatement{}, errors.New("invalid UPDATE query format") } dataSourceName = matches[1] - if !datasource.DataSourceExists(dataSourceName) { + source, exits := datasource.GetDataSource(dataSourceName) + if !exits { return SQLStatement{}, errors.New("data source does not exist") } setPart := matches[2] if len(matches) == 4 { wherePart = matches[3] } - headers := datasource.DataSources[dataSourceName].Data[0] + headers := source.Data[0] setClauses, err := parseSetClauses(setPart, headers) if err != nil { return SQLStatement{}, err @@ -116,7 +119,7 @@ func parseSqlCommand(query string, datasource *TemplateDataSource) (SQLStatement return SQLStatement{}, errors.New("invalid DELETE query format") } dataSourceName = matches[1] - if !datasource.DataSourceExists(dataSourceName) { + if _, exits := datasource.GetDataSource(dataSourceName); !exits { return SQLStatement{}, errors.New("data source does not exist") } if len(matches) == 3 { diff --git a/core/templating/datasource_sql_over_csv_test.go b/core/templating/datasource_sql_over_csv_test.go index 6e2b7d4d2..e8f93b25c 100644 --- a/core/templating/datasource_sql_over_csv_test.go +++ b/core/templating/datasource_sql_over_csv_test.go @@ -21,7 +21,7 @@ func TestParseCommand(t *testing.T) { }, } templateDataSource := NewTemplateDataSource() - templateDataSource.DataSources = dataSources + templateDataSource.dataSources = dataSources tests := []struct { query string diff --git a/core/templating/template_datasource.go b/core/templating/template_datasource.go index 7165fb10e..4152dd366 100644 --- a/core/templating/template_datasource.go +++ b/core/templating/template_datasource.go @@ -5,47 +5,45 @@ import ( ) type TemplateDataSource struct { - DataSources map[string]*DataSource - RWMutex sync.RWMutex + dataSources map[string]*DataSource + rwMutex sync.RWMutex } func NewTemplateDataSource() *TemplateDataSource { return &TemplateDataSource{ - DataSources: make(map[string]*DataSource), + dataSources: make(map[string]*DataSource), } } -func (templateDataSource *TemplateDataSource) SetDataSource(dataSourceName string, dataSource *DataSource) { +func (t *TemplateDataSource) SetDataSource(dataSourceName string, dataSource *DataSource) { - templateDataSource.RWMutex.Lock() - templateDataSource.DataSources[dataSourceName] = dataSource - templateDataSource.RWMutex.Unlock() + t.rwMutex.Lock() + defer t.rwMutex.Unlock() + + t.dataSources[dataSourceName] = dataSource } -func (templateDataSource *TemplateDataSource) DeleteDataSource(dataSourceName string) { +func (t *TemplateDataSource) DeleteDataSource(dataSourceName string) { - templateDataSource.RWMutex.Lock() + t.rwMutex.Lock() + defer t.rwMutex.Unlock() - if _, ok := templateDataSource.DataSources[dataSourceName]; ok { - delete(templateDataSource.DataSources, dataSourceName) - } - templateDataSource.RWMutex.Unlock() + delete(t.dataSources, dataSourceName) } -func (templateDataSource *TemplateDataSource) GetAllDataSources() map[string]*DataSource { +func (t *TemplateDataSource) GetAllDataSources() map[string]*DataSource { - return templateDataSource.DataSources + t.rwMutex.RLock() + defer t.rwMutex.RUnlock() + + return t.dataSources } -func (templateDataSource *TemplateDataSource) DataSourceExists(name string) bool { - templateDataSource.RWMutex.Lock() - defer templateDataSource.RWMutex.Unlock() +func (t *TemplateDataSource) GetDataSource(name string) (*DataSource, bool) { + t.rwMutex.RLock() + defer t.rwMutex.RUnlock() - for _, dataSource := range templateDataSource.DataSources { - if dataSource.Name == name { - return true - } - } - return false + source, exits := t.dataSources[name] + return source, exits } diff --git a/core/templating/template_helpers.go b/core/templating/template_helpers.go index c545e068a..2edd98729 100644 --- a/core/templating/template_helpers.go +++ b/core/templating/template_helpers.go @@ -239,8 +239,7 @@ func (t templateHelpers) faker(fakerType string) []reflect.Value { } func (t templateHelpers) fetchSingleFieldCsv(dataSourceName, searchFieldName, searchFieldValue, returnFieldName string, options *raymond.Options) string { - templateDataSources := t.TemplateDataSource.DataSources - source, exists := templateDataSources[dataSourceName] + source, exists := t.TemplateDataSource.GetDataSource(dataSourceName) if !exists { log.Error("could not find datasource " + dataSourceName) return getEvaluationString("csv", options) @@ -274,8 +273,7 @@ func (t templateHelpers) fetchSingleFieldCsv(dataSourceName, searchFieldName, se } func (t templateHelpers) fetchMatchingRowsCsv(dataSourceName string, searchFieldName string, searchFieldValue string) []RowMap { - templateDataSources := t.TemplateDataSource.DataSources - source, exists := templateDataSources[dataSourceName] + source, exists := t.TemplateDataSource.GetDataSource(dataSourceName) if !exists { log.Error("could not find datasource " + dataSourceName) return []RowMap{} @@ -316,8 +314,7 @@ func (t templateHelpers) fetchMatchingRowsCsv(dataSourceName string, searchField } func (t templateHelpers) csvAsArray(dataSourceName string) [][]string { - templateDataSources := t.TemplateDataSource.DataSources - source, exists := templateDataSources[dataSourceName] + source, exists := t.TemplateDataSource.GetDataSource(dataSourceName) if exists { source.mu.Lock() defer source.mu.Unlock() @@ -330,8 +327,7 @@ func (t templateHelpers) csvAsArray(dataSourceName string) [][]string { func (t templateHelpers) csvAsMap(dataSourceName string) []RowMap { - templateDataSources := t.TemplateDataSource.DataSources - source, exists := templateDataSources[dataSourceName] + source, exists := t.TemplateDataSource.GetDataSource(dataSourceName) if !exists { log.Error("could not find datasource " + dataSourceName) return []RowMap{} @@ -357,8 +353,7 @@ func (t templateHelpers) csvAsMap(dataSourceName string) []RowMap { } func (t templateHelpers) csvAddRow(dataSourceName string, newRow []string) string { - templateDataSources := t.TemplateDataSource.DataSources - source, exists := templateDataSources[dataSourceName] + source, exists := t.TemplateDataSource.GetDataSource(dataSourceName) if exists { source.mu.Lock() defer source.mu.Unlock() @@ -370,8 +365,7 @@ func (t templateHelpers) csvAddRow(dataSourceName string, newRow []string) strin } func (t templateHelpers) csvDeleteRows(dataSourceName, searchFieldName, searchFieldValue string, output bool) string { - templateDataSources := t.TemplateDataSource.DataSources - source, exists := templateDataSources[dataSourceName] + source, exists := t.TemplateDataSource.GetDataSource(dataSourceName) if !exists { log.Error("could not find datasource " + dataSourceName) return "" @@ -411,8 +405,7 @@ func (t templateHelpers) csvDeleteRows(dataSourceName, searchFieldName, searchFi } func (t templateHelpers) csvCountRows(dataSourceName string) string { - templateDataSources := t.TemplateDataSource.DataSources - source, exists := templateDataSources[dataSourceName] + source, exists := t.TemplateDataSource.GetDataSource(dataSourceName) if !exists { log.Error("could not find datasource " + dataSourceName) return "" @@ -436,7 +429,7 @@ func (t templateHelpers) csvSqlCommand(commandString string) []RowMap { } // Find the data source by name - source, exists := t.TemplateDataSource.DataSources[command.DataSourceName] + source, exists := t.TemplateDataSource.GetDataSource(command.DataSourceName) if !exists { log.Error("Could not find datasource " + command.DataSourceName) return []RowMap{}