From 574bab0dc1162f656073505c9f5f5d6e67dbc543 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 17 Jun 2024 16:01:25 -0700 Subject: [PATCH 01/29] initial commit --- internal/evaluation/flag.go | 106 ++++++++++++++ pkg/experiment/local/cohort.go | 11 ++ pkg/experiment/local/cohort_download_api.go | 148 ++++++++++++++++++++ pkg/experiment/local/cohort_loader.go | 109 ++++++++++++++ pkg/experiment/local/cohort_storage.go | 95 +++++++++++++ pkg/experiment/local/config.go | 23 +++ pkg/experiment/local/flag_config_api.go | 73 ++++++++++ pkg/experiment/local/flag_config_storage.go | 62 ++++++++ pkg/experiment/local/flag_config_util.go | 113 +++++++++++++++ 9 files changed, 740 insertions(+) create mode 100644 internal/evaluation/flag.go create mode 100644 pkg/experiment/local/cohort.go create mode 100644 pkg/experiment/local/cohort_download_api.go create mode 100644 pkg/experiment/local/cohort_loader.go create mode 100644 pkg/experiment/local/cohort_storage.go create mode 100644 pkg/experiment/local/flag_config_api.go create mode 100644 pkg/experiment/local/flag_config_storage.go create mode 100644 pkg/experiment/local/flag_config_util.go diff --git a/internal/evaluation/flag.go b/internal/evaluation/flag.go new file mode 100644 index 0000000..8a717e8 --- /dev/null +++ b/internal/evaluation/flag.go @@ -0,0 +1,106 @@ +package evaluation + +// IsCohortFilter checks if the condition is a cohort filter. +func (f Flag) IsCohortFilter(condition map[string]interface{}) bool { + op, opExists := condition["op"].(string) + selector, selectorExists := condition["selector"].([]interface{}) + if opExists && selectorExists && len(selector) > 0 && selector[len(selector)-1] == "cohort_ids" { + return op == "set contains any" || op == "set does not contain any" + } + return false +} + +// GetGroupedCohortConditionIDs extracts grouped cohort condition IDs from a segment. +func (f Flag) GetGroupedCohortConditionIDs(segment map[string]interface{}) map[string]map[string]bool { + cohortIDs := make(map[string]map[string]bool) + conditions, ok := segment["conditions"].([]interface{}) + if !ok { + return cohortIDs + } + for _, outer := range conditions { + outerCondition, ok := outer.(map[string]interface{}) + if !ok { + continue + } + for _, condition := range outerCondition { + conditionMap, ok := condition.(map[string]interface{}) + if !ok { + continue + } + if f.IsCohortFilter(conditionMap) { + selector, _ := conditionMap["selector"].([]interface{}) + if len(selector) > 2 { + contextSubtype := selector[1].(string) + var groupType string + if contextSubtype == "user" { + groupType = cohort.USER_GROUP_TYPE + } else if selectorContainsGroups(selector) { + groupType = selector[2].(string) + } else { + continue + } + values, _ := conditionMap["values"].([]interface{}) + cohortIDs[groupType] = map[string]bool{} + for _, value := range values { + cohortIDs[groupType][value.(string)] = true + } + } + } + } + } + return cohortIDs +} + +// helper function to check if selector contains groups +func selectorContainsGroups(selector []interface{}) bool { + for _, s := range selector { + if s == "groups" { + return true + } + } + return false +} + +// GetGroupedCohortIDsFromFlag extracts grouped cohort IDs from a flag. +func (f Flag) GetGroupedCohortIDsFromFlag(flag map[string]interface{}) map[string]map[string]bool { + cohortIDs := make(map[string]map[string]bool) + segments, ok := flag["segments"].([]interface{}) + if !ok { + return cohortIDs + } + for _, seg := range segments { + segment, _ := seg.(map[string]interface{}) + for key, values := range f.GetGroupedCohortConditionIDs(segment) { + if _, exists := cohortIDs[key]; !exists { + cohortIDs[key] = make(map[string]bool) + } + for id := range values { + cohortIDs[key][id] = true + } + } + } + return cohortIDs +} + +// GetAllCohortIDsFromFlag extracts all cohort IDs from a flag. +func (f Flag) GetAllCohortIDsFromFlag(flag map[string]interface{}) map[string]bool { + cohortIDs := make(map[string]bool) + groupedIDs := f.GetGroupedCohortIDsFromFlag(flag) + for _, values := range groupedIDs { + for id := range values { + cohortIDs[id] = true + } + } + return cohortIDs +} + +// GetAllCohortIDsFromFlags extracts all cohort IDs from multiple flags. +func (f Flag) GetAllCohortIDsFromFlags(flags []map[string]interface{}) map[string]bool { + cohortIDs := make(map[string]bool) + for _, flag := range flags { + for id := range f.GetAllCohortIDsFromFlag(flag) { + cohortIDs[id] = true + } + } + return cohortIDs +} diff --git a/pkg/experiment/local/cohort.go b/pkg/experiment/local/cohort.go new file mode 100644 index 0000000..cc51451 --- /dev/null +++ b/pkg/experiment/local/cohort.go @@ -0,0 +1,11 @@ +package local + +const userGroupType = "user" + +type Cohort struct { + ID string + LastModified int64 + Size int + MemberIDs map[string]struct{} + GroupType string +} diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go new file mode 100644 index 0000000..081efe7 --- /dev/null +++ b/pkg/experiment/local/cohort_download_api.go @@ -0,0 +1,148 @@ +package local + +import ( + "encoding/base64" + "encoding/json" + "log" + "net/http" + "strconv" + "time" +) + +const ( + CdnCohortSyncUrl = "https://cohort-v2.lab.amplitude.com" +) + +type HTTPErrorResponseException struct { + StatusCode int + Message string +} + +func (e *HTTPErrorResponseException) Error() string { + return e.Message +} + +type CohortTooLargeException struct { + Message string +} + +func (e *CohortTooLargeException) Error() string { + return e.Message +} + +type CohortNotModifiedException struct { + Message string +} + +func (e *CohortNotModifiedException) Error() string { + return e.Message +} + +type CohortDownloadApi interface { + GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) +} + +type DirectCohortDownloadApi struct { + ApiKey string + SecretKey string + MaxCohortSize int + CohortRequestDelayMillis int + Debug bool + Logger *log.Logger +} + +func NewDirectCohortDownloadApi(apiKey, secretKey string, maxCohortSize, cohortRequestDelayMillis int, debug bool) *DirectCohortDownloadApi { + api := &DirectCohortDownloadApi{ + ApiKey: apiKey, + SecretKey: secretKey, + MaxCohortSize: maxCohortSize, + CohortRequestDelayMillis: cohortRequestDelayMillis, + Debug: debug, + Logger: log.New(log.Writer(), "Amplitude: ", log.LstdFlags), + } + if debug { + api.Logger.SetFlags(log.LstdFlags | log.Lshortfile) + } + return api +} + +func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) { + api.Logger.Printf("getCohortMembers(%s): start", cohortID) + errors := 0 + client := &http.Client{} + + for { + response, err := api.getCohortMembersRequest(client, cohortID, cohort) + if err != nil { + api.Logger.Printf("getCohortMembers(%s): request-status error %d - %v", cohortID, errors, err) + errors++ + if errors >= 3 || isSpecificError(err) { + return nil, err + } + time.Sleep(time.Duration(api.CohortRequestDelayMillis) * time.Millisecond) + continue + } + + if response.StatusCode == http.StatusOK { + var cohortInfo struct { + CohortId string `json:"cohortId"` + LastModified int64 `json:"lastModified"` + Size int `json:"size"` + MemberIds []string `json:"memberIds"` + GroupType string `json:"groupType"` + } + if err := json.NewDecoder(response.Body).Decode(&cohortInfo); err != nil { + return nil, err + } + memberIDs := make(map[string]struct{}, len(cohortInfo.MemberIds)) + for _, id := range cohortInfo.MemberIds { + memberIDs[id] = struct{}{} + } + api.Logger.Printf("getCohortMembers(%s): end - resultSize=%d", cohortID, cohortInfo.Size) + return &Cohort{ + ID: cohortInfo.CohortId, + LastModified: cohortInfo.LastModified, + Size: cohortInfo.Size, + MemberIDs: memberIDs, + GroupType: cohortInfo.GroupType, + }, nil + } else if response.StatusCode == http.StatusNoContent { + return nil, &CohortNotModifiedException{Message: "Cohort not modified"} + } else if response.StatusCode == http.StatusRequestEntityTooLarge { + return nil, &CohortTooLargeException{Message: "Cohort exceeds max cohort size"} + } else { + return nil, &HTTPErrorResponseException{StatusCode: response.StatusCode, Message: "Unexpected response code"} + } + } +} + +func isSpecificError(err error) bool { + switch err.(type) { + case *CohortNotModifiedException, *CohortTooLargeException: + return true + default: + return false + } +} + +func (api *DirectCohortDownloadApi) getCohortMembersRequest(client *http.Client, cohortID string, cohort *Cohort) (*http.Response, error) { + req, err := http.NewRequest("GET", api.buildCohortURL(cohortID, cohort), nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Basic "+api.getBasicAuth()) + return client.Do(req) +} + +func (api *DirectCohortDownloadApi) getBasicAuth() string { + auth := api.ApiKey + ":" + api.SecretKey + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func (api *DirectCohortDownloadApi) buildCohortURL(cohortID string, cohort *Cohort) string { + url := CdnCohortSyncUrl + "/sdk/v1/cohort/" + cohortID + "?maxCohortSize=" + strconv.Itoa(api.MaxCohortSize) + if cohort != nil { + url += "&lastModified=" + strconv.FormatInt(cohort.LastModified, 10) + } + return url +} diff --git a/pkg/experiment/local/cohort_loader.go b/pkg/experiment/local/cohort_loader.go new file mode 100644 index 0000000..43e3b01 --- /dev/null +++ b/pkg/experiment/local/cohort_loader.go @@ -0,0 +1,109 @@ +package local + +import ( + "sync" + "sync/atomic" +) + +// CohortLoader handles the loading of cohorts using CohortDownloadApi and CohortStorage. +type CohortLoader struct { + cohortDownloadApi CohortDownloadApi + cohortStorage CohortStorage + jobs sync.Map + executor *sync.Pool + lockJobs sync.Mutex +} + +// NewCohortLoader creates a new instance of CohortLoader. +func NewCohortLoader(cohortDownloadApi CohortDownloadApi, cohortStorage CohortStorage) *CohortLoader { + return &CohortLoader{ + cohortDownloadApi: cohortDownloadApi, + cohortStorage: cohortStorage, + executor: &sync.Pool{ + New: func() interface{} { + return &CohortLoaderTask{} + }, + }, + } +} + +// LoadCohort initiates the loading of a cohort. +func (cl *CohortLoader) LoadCohort(cohortID string) *CohortLoaderTask { + cl.lockJobs.Lock() + defer cl.lockJobs.Unlock() + + task, ok := cl.jobs.Load(cohortID) + if !ok { + task = cl.executor.Get().(*CohortLoaderTask) + task.(*CohortLoaderTask).init(cl, cohortID) + cl.jobs.Store(cohortID, task) + go task.(*CohortLoaderTask).run() + } + + return task.(*CohortLoaderTask) +} + +// removeJob removes a job from the jobs map. +func (cl *CohortLoader) removeJob(cohortID string) { + cl.jobs.Delete(cohortID) +} + +// CohortLoaderTask represents a task for loading a cohort. +type CohortLoaderTask struct { + loader *CohortLoader + cohortID string + done int32 + doneChan chan struct{} + err error +} + +// init initializes a CohortLoaderTask. +func (task *CohortLoaderTask) init(loader *CohortLoader, cohortID string) { + task.loader = loader + task.cohortID = cohortID + task.done = 0 + task.doneChan = make(chan struct{}) + task.err = nil +} + +// run executes the task of loading a cohort. +func (task *CohortLoaderTask) run() { + defer task.loader.executor.Put(task) + + cohort, err := task.loader.downloadCohort(task.cohortID) + if err != nil { + task.err = err + } else { + task.loader.cohortStorage.PutCohort(cohort) + } + + task.loader.removeJob(task.cohortID) + atomic.StoreInt32(&task.done, 1) + close(task.doneChan) +} + +// Wait waits for the task to complete. +func (task *CohortLoaderTask) Wait() error { + <-task.doneChan + return task.err +} + +// downloadCohort downloads a cohort. +func (cl *CohortLoader) downloadCohort(cohortID string) (*Cohort, error) { + cohort := cl.cohortStorage.GetCohort(cohortID) + return cl.cohortDownloadApi.GetCohort(cohortID, cohort) +} + +type CohortDownloadApiImpl struct{} + +// GetCohort gets a cohort. +func (api *CohortDownloadApiImpl) GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) { + // Placeholder implementation + return &Cohort{ + ID: cohortID, + LastModified: 0, + Size: 0, + MemberIDs: make(map[string]struct{}), + GroupType: "example", + }, nil +} diff --git a/pkg/experiment/local/cohort_storage.go b/pkg/experiment/local/cohort_storage.go new file mode 100644 index 0000000..35b9645 --- /dev/null +++ b/pkg/experiment/local/cohort_storage.go @@ -0,0 +1,95 @@ +package local + +import ( + "sync" +) + +// CohortStorage defines the interface for cohort storage operations +type CohortStorage interface { + GetCohort(cohortID string) *Cohort + GetCohorts() map[string]*Cohort + GetCohortsForUser(userID string, cohortIDs map[string]struct{}) map[string]struct{} + GetCohortsForGroup(groupType, groupName string, cohortIDs map[string]struct{}) map[string]struct{} + PutCohort(cohort *Cohort) + DeleteCohort(groupType, cohortID string) +} + +// InMemoryCohortStorage is an in-memory implementation of CohortStorage +type InMemoryCohortStorage struct { + lock sync.RWMutex + groupToCohortStore map[string]map[string]struct{} + cohortStore map[string]*Cohort +} + +// NewInMemoryCohortStorage creates a new InMemoryCohortStorage instance +func NewInMemoryCohortStorage() *InMemoryCohortStorage { + return &InMemoryCohortStorage{ + groupToCohortStore: make(map[string]map[string]struct{}), + cohortStore: make(map[string]*Cohort), + } +} + +// GetCohort retrieves a cohort by its ID +func (s *InMemoryCohortStorage) GetCohort(cohortID string) *Cohort { + s.lock.RLock() + defer s.lock.RUnlock() + return s.cohortStore[cohortID] +} + +// GetCohorts retrieves all cohorts +func (s *InMemoryCohortStorage) GetCohorts() map[string]*Cohort { + s.lock.RLock() + defer s.lock.RUnlock() + cohorts := make(map[string]*Cohort) + for id, cohort := range s.cohortStore { + cohorts[id] = cohort + } + return cohorts +} + +// GetCohortsForUser retrieves cohorts for a user based on cohort IDs +func (s *InMemoryCohortStorage) GetCohortsForUser(userID string, cohortIDs map[string]struct{}) map[string]struct{} { + return s.GetCohortsForGroup(userGroupType, userID, cohortIDs) +} + +// GetCohortsForGroup retrieves cohorts for a group based on cohort IDs +func (s *InMemoryCohortStorage) GetCohortsForGroup(groupType, groupName string, cohortIDs map[string]struct{}) map[string]struct{} { + result := make(map[string]struct{}) + s.lock.RLock() + defer s.lock.RUnlock() + groupTypeCohorts := s.groupToCohortStore[groupType] + for cohortID := range groupTypeCohorts { + if _, exists := cohortIDs[cohortID]; exists { + if cohort, found := s.cohortStore[cohortID]; found { + if _, memberExists := cohort.MemberIDs[groupName]; memberExists { + result[cohortID] = struct{}{} + } + } + } + } + return result +} + +// PutCohort stores a cohort +func (s *InMemoryCohortStorage) PutCohort(cohort *Cohort) { + s.lock.Lock() + defer s.lock.Unlock() + if _, exists := s.groupToCohortStore[cohort.GroupType]; !exists { + s.groupToCohortStore[cohort.GroupType] = make(map[string]struct{}) + } + s.groupToCohortStore[cohort.GroupType][cohort.ID] = struct{}{} + s.cohortStore[cohort.ID] = cohort +} + +// DeleteCohort deletes a cohort by its ID and group type +func (s *InMemoryCohortStorage) DeleteCohort(groupType, cohortID string) { + s.lock.Lock() + defer s.lock.Unlock() + if groupCohorts, exists := s.groupToCohortStore[groupType]; exists { + delete(groupCohorts, cohortID) + if len(groupCohorts) == 0 { + delete(s.groupToCohortStore, groupType) + } + } + delete(s.cohortStore, cohortID) +} diff --git a/pkg/experiment/local/config.go b/pkg/experiment/local/config.go index 9c9cb9d..fd746ab 100644 --- a/pkg/experiment/local/config.go +++ b/pkg/experiment/local/config.go @@ -11,6 +11,7 @@ type Config struct { FlagConfigPollerInterval time.Duration FlagConfigPollerRequestTimeout time.Duration AssignmentConfig *AssignmentConfig + CohortSyncConfig *CohortSyncConfig } type AssignmentConfig struct { @@ -18,6 +19,14 @@ type AssignmentConfig struct { CacheCapacity int } +// CohortSyncConfig holds configuration for cohort synchronization. +type CohortSyncConfig struct { + ApiKey string + SecretKey string + MaxCohortSize int + CohortRequestDelayMillis int +} + var DefaultConfig = &Config{ Debug: false, ServerUrl: "https://api.lab.amplitude.com/", @@ -29,6 +38,11 @@ var DefaultAssignmentConfig = &AssignmentConfig{ CacheCapacity: 524288, } +var DefaultCohortSyncConfig = &CohortSyncConfig{ + MaxCohortSize: 15000, + CohortRequestDelayMillis: 5000, +} + func fillConfigDefaults(c *Config) *Config { if c == nil { return DefaultConfig @@ -45,5 +59,14 @@ func fillConfigDefaults(c *Config) *Config { if c.AssignmentConfig != nil && c.AssignmentConfig.CacheCapacity == 0 { c.AssignmentConfig.CacheCapacity = DefaultAssignmentConfig.CacheCapacity } + + if c.CohortSyncConfig != nil && c.CohortSyncConfig.MaxCohortSize == 0 { + c.CohortSyncConfig.MaxCohortSize = DefaultCohortSyncConfig.MaxCohortSize + } + + if c.CohortSyncConfig != nil && c.CohortSyncConfig.CohortRequestDelayMillis == 0 { + c.CohortSyncConfig.CohortRequestDelayMillis = DefaultCohortSyncConfig.CohortRequestDelayMillis + } + return c } diff --git a/pkg/experiment/local/flag_config_api.go b/pkg/experiment/local/flag_config_api.go new file mode 100644 index 0000000..dffbeba --- /dev/null +++ b/pkg/experiment/local/flag_config_api.go @@ -0,0 +1,73 @@ +package local + +import ( + "context" + "encoding/json" + "fmt" + "github.com/amplitude/experiment-go-server/internal/evaluation" + "github.com/amplitude/experiment-go-server/pkg/experiment" + "io/ioutil" + "net/http" + "net/url" + "time" +) + +// FlagConfigApi defines an interface for retrieving flag configurations. +type FlagConfigApi interface { + GetFlagConfigs() []interface{} +} + +// FlagConfigApiV2 is an implementation of the FlagConfigApi interface for version 2 of the API. +type FlagConfigApiV2 struct { + DeploymentKey string + ServerURL string + FlagConfigPollerRequestTimeoutMillis time.Duration +} + +// NewFlagConfigApiV2 creates a new instance of FlagConfigApiV2. +func NewFlagConfigApiV2(deploymentKey, serverURL string, flagConfigPollerRequestTimeoutMillis time.Duration) *FlagConfigApiV2 { + return &FlagConfigApiV2{ + DeploymentKey: deploymentKey, + ServerURL: serverURL, + FlagConfigPollerRequestTimeoutMillis: flagConfigPollerRequestTimeoutMillis, + } +} + +func (a *FlagConfigApiV2) GetFlagConfigs() (map[string]*evaluation.Flag, error) { + client := &http.Client{} + endpoint, err := url.Parse("https://api.lab.amplitude.com/") + if err != nil { + return nil, err + } + endpoint.Path = "sdk/v2/flags" + endpoint.RawQuery = "v=0" + ctx, cancel := context.WithTimeout(context.Background(), a.FlagConfigPollerRequestTimeoutMillis) + defer cancel() + req, err := http.NewRequest("GET", endpoint.String(), nil) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + req.Header.Set("Authorization", fmt.Sprintf("Api-Key %s", a.DeploymentKey)) + req.Header.Set("Content-Type", "application/json; charset=UTF-8") + req.Header.Set("X-Amp-Exp-Library", fmt.Sprintf("experiment-go-server/%v", experiment.VERSION)) + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var flagsArray []*evaluation.Flag + err = json.Unmarshal(body, &flagsArray) + if err != nil { + return nil, err + } + flags := make(map[string]*evaluation.Flag) + for _, flag := range flagsArray { + flags[flag.Key] = flag + } + return flags, nil +} diff --git a/pkg/experiment/local/flag_config_storage.go b/pkg/experiment/local/flag_config_storage.go new file mode 100644 index 0000000..6c4d76a --- /dev/null +++ b/pkg/experiment/local/flag_config_storage.go @@ -0,0 +1,62 @@ +package local + +import ( + "sync" +) + +// FlagConfigStorage defines an interface for managing flag configurations. +type FlagConfigStorage interface { + GetFlagConfig(key string) map[string]interface{} + GetFlagConfigs() map[string]map[string]interface{} + PutFlagConfig(flagConfig map[string]interface{}) + RemoveIf(condition func(map[string]interface{}) bool) +} + +// InMemoryFlagConfigStorage is an in-memory implementation of FlagConfigStorage. +type InMemoryFlagConfigStorage struct { + flagConfigs map[string]map[string]interface{} + flagConfigsLock sync.Mutex +} + +// NewInMemoryFlagConfigStorage creates a new instance of InMemoryFlagConfigStorage. +func NewInMemoryFlagConfigStorage() *InMemoryFlagConfigStorage { + return &InMemoryFlagConfigStorage{ + flagConfigs: make(map[string]map[string]interface{}), + } +} + +// GetFlagConfig retrieves a flag configuration by key. +func (storage *InMemoryFlagConfigStorage) GetFlagConfig(key string) map[string]interface{} { + storage.flagConfigsLock.Lock() + defer storage.flagConfigsLock.Unlock() + return storage.flagConfigs[key] +} + +// GetFlagConfigs retrieves all flag configurations. +func (storage *InMemoryFlagConfigStorage) GetFlagConfigs() map[string]map[string]interface{} { + storage.flagConfigsLock.Lock() + defer storage.flagConfigsLock.Unlock() + copyFlagConfigs := make(map[string]map[string]interface{}) + for key, value := range storage.flagConfigs { + copyFlagConfigs[key] = value + } + return copyFlagConfigs +} + +// PutFlagConfig stores a flag configuration. +func (storage *InMemoryFlagConfigStorage) PutFlagConfig(flagConfig map[string]interface{}) { + storage.flagConfigsLock.Lock() + defer storage.flagConfigsLock.Unlock() + storage.flagConfigs[flagConfig["key"].(string)] = flagConfig +} + +// RemoveIf removes flag configurations based on a condition. +func (storage *InMemoryFlagConfigStorage) RemoveIf(condition func(map[string]interface{}) bool) { + storage.flagConfigsLock.Lock() + defer storage.flagConfigsLock.Unlock() + for key, value := range storage.flagConfigs { + if condition(value) { + delete(storage.flagConfigs, key) + } + } +} diff --git a/pkg/experiment/local/flag_config_util.go b/pkg/experiment/local/flag_config_util.go new file mode 100644 index 0000000..768b89b --- /dev/null +++ b/pkg/experiment/local/flag_config_util.go @@ -0,0 +1,113 @@ +package local + +import ( + "github.com/amplitude/experiment-go-server/internal/evaluation" +) + +// IsCohortFilter checks if the condition is a cohort filter. +func IsCohortFilter(condition *evaluation.Condition) bool { + op := condition.Op + selector := condition.Selector + if len(selector) > 0 && selector[len(selector)-1] == "cohort_ids" { + return op == "set contains any" || op == "set does not contain any" + } + return false +} + +// GetGroupedCohortConditionIDs extracts grouped cohort condition IDs from a segment. +func GetGroupedCohortConditionIDs(segment *evaluation.Segment) map[string]map[string]bool { + cohortIDs := make(map[string]map[string]bool) + if segment == nil { + return cohortIDs + } + + for _, outer := range segment.Conditions { + for _, condition := range outer { + if IsCohortFilter(condition) { + selector := condition.Selector + if len(selector) > 2 { + contextSubtype := selector[1] + var groupType string + if contextSubtype == "user" { + groupType = userGroupType + } else if selectorContainsGroups(selector) { + groupType = selector[2] + } else { + continue + } + values := condition.Values + cohortIDs[groupType] = make(map[string]bool) + for _, value := range values { + cohortIDs[groupType][value] = true + } + } + } + } + } + return cohortIDs +} + +// GetGroupedCohortIDsFromFlag extracts grouped cohort IDs from a flag. +func GetGroupedCohortIDsFromFlag(flag *evaluation.Flag) map[string]map[string]bool { + cohortIDs := make(map[string]map[string]bool) + for _, segment := range flag.Segments { + for key, values := range GetGroupedCohortConditionIDs(segment) { + if _, exists := cohortIDs[key]; !exists { + cohortIDs[key] = make(map[string]bool) + } + for id := range values { + cohortIDs[key][id] = true + } + } + } + return cohortIDs +} + +// GetAllCohortIDsFromFlag extracts all cohort IDs from a flag. +func GetAllCohortIDsFromFlag(flag *evaluation.Flag) map[string]bool { + cohortIDs := make(map[string]bool) + groupedIDs := GetGroupedCohortIDsFromFlag(flag) + for _, values := range groupedIDs { + for id := range values { + cohortIDs[id] = true + } + } + return cohortIDs +} + +// GetGroupedCohortIDsFromFlags extracts grouped cohort IDs from multiple flags. +func GetGroupedCohortIDsFromFlags(flags []*evaluation.Flag) map[string]map[string]bool { + cohortIDs := make(map[string]map[string]bool) + for _, flag := range flags { + for key, values := range GetGroupedCohortIDsFromFlag(flag) { + if _, exists := cohortIDs[key]; !exists { + cohortIDs[key] = make(map[string]bool) + } + for id := range values { + cohortIDs[key][id] = true + } + } + } + return cohortIDs +} + +// GetAllCohortIDsFromFlags extracts all cohort IDs from multiple flags. +func GetAllCohortIDsFromFlags(flags []*evaluation.Flag) map[string]bool { + cohortIDs := make(map[string]bool) + for _, flag := range flags { + for id := range GetAllCohortIDsFromFlag(flag) { + cohortIDs[id] = true + } + } + return cohortIDs +} + +// helper function to check if selector contains groups +func selectorContainsGroups(selector []string) bool { + for _, s := range selector { + if s == "groups" { + return true + } + } + return false +} From b527601e15ae9c3125977a529013453048d34917 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 27 Jun 2024 09:29:02 -0700 Subject: [PATCH 02/29] add deployment runner --- go.mod | 6 +- go.sum | 11 +- internal/evaluation/flag.go | 106 --------- pkg/experiment/local/client.go | 68 +++++- pkg/experiment/local/cohort.go | 25 +- pkg/experiment/local/cohort_download_api.go | 21 +- .../local/cohort_download_api_test.go | 190 +++++++++++++++ pkg/experiment/local/cohort_loader.go | 23 -- pkg/experiment/local/cohort_loader_test.go | 105 +++++++++ pkg/experiment/local/cohort_storage.go | 28 +-- pkg/experiment/local/config.go | 6 + pkg/experiment/local/deployment_runner.go | 153 ++++++++++++ .../local/deployment_runner_test.go | 107 +++++++++ pkg/experiment/local/flag_config_api.go | 5 +- pkg/experiment/local/flag_config_storage.go | 32 ++- pkg/experiment/local/flag_config_test.go | 221 ++++++++++++++++++ pkg/experiment/local/flag_config_util.go | 60 ++--- pkg/experiment/types.go | 57 +++-- 18 files changed, 989 insertions(+), 235 deletions(-) delete mode 100644 internal/evaluation/flag.go create mode 100644 pkg/experiment/local/cohort_download_api_test.go create mode 100644 pkg/experiment/local/cohort_loader_test.go create mode 100644 pkg/experiment/local/deployment_runner.go create mode 100644 pkg/experiment/local/deployment_runner_test.go create mode 100644 pkg/experiment/local/flag_config_test.go diff --git a/go.mod b/go.mod index d282c4d..81525f6 100644 --- a/go.mod +++ b/go.mod @@ -4,4 +4,8 @@ go 1.12 require github.com/spaolacci/murmur3 v1.1.0 -require github.com/amplitude/analytics-go v1.0.1 +require ( + github.com/amplitude/analytics-go v1.0.1 + github.com/jarcoal/httpmock v1.3.1 + github.com/stretchr/testify v1.9.0 +) diff --git a/go.sum b/go.sum index 40e774e..46aa182 100644 --- a/go.sum +++ b/go.sum @@ -5,18 +5,25 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= +github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= +github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= +github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/evaluation/flag.go b/internal/evaluation/flag.go deleted file mode 100644 index 8a717e8..0000000 --- a/internal/evaluation/flag.go +++ /dev/null @@ -1,106 +0,0 @@ -package evaluation - -// IsCohortFilter checks if the condition is a cohort filter. -func (f Flag) IsCohortFilter(condition map[string]interface{}) bool { - op, opExists := condition["op"].(string) - selector, selectorExists := condition["selector"].([]interface{}) - if opExists && selectorExists && len(selector) > 0 && selector[len(selector)-1] == "cohort_ids" { - return op == "set contains any" || op == "set does not contain any" - } - return false -} - -// GetGroupedCohortConditionIDs extracts grouped cohort condition IDs from a segment. -func (f Flag) GetGroupedCohortConditionIDs(segment map[string]interface{}) map[string]map[string]bool { - cohortIDs := make(map[string]map[string]bool) - conditions, ok := segment["conditions"].([]interface{}) - if !ok { - return cohortIDs - } - for _, outer := range conditions { - outerCondition, ok := outer.(map[string]interface{}) - if !ok { - continue - } - for _, condition := range outerCondition { - conditionMap, ok := condition.(map[string]interface{}) - if !ok { - continue - } - if f.IsCohortFilter(conditionMap) { - selector, _ := conditionMap["selector"].([]interface{}) - if len(selector) > 2 { - contextSubtype := selector[1].(string) - var groupType string - if contextSubtype == "user" { - groupType = cohort.USER_GROUP_TYPE - } else if selectorContainsGroups(selector) { - groupType = selector[2].(string) - } else { - continue - } - values, _ := conditionMap["values"].([]interface{}) - cohortIDs[groupType] = map[string]bool{} - for _, value := range values { - cohortIDs[groupType][value.(string)] = true - } - } - } - } - } - return cohortIDs -} - -// helper function to check if selector contains groups -func selectorContainsGroups(selector []interface{}) bool { - for _, s := range selector { - if s == "groups" { - return true - } - } - return false -} - -// GetGroupedCohortIDsFromFlag extracts grouped cohort IDs from a flag. -func (f Flag) GetGroupedCohortIDsFromFlag(flag map[string]interface{}) map[string]map[string]bool { - cohortIDs := make(map[string]map[string]bool) - segments, ok := flag["segments"].([]interface{}) - if !ok { - return cohortIDs - } - for _, seg := range segments { - segment, _ := seg.(map[string]interface{}) - for key, values := range f.GetGroupedCohortConditionIDs(segment) { - if _, exists := cohortIDs[key]; !exists { - cohortIDs[key] = make(map[string]bool) - } - for id := range values { - cohortIDs[key][id] = true - } - } - } - return cohortIDs -} - -// GetAllCohortIDsFromFlag extracts all cohort IDs from a flag. -func (f Flag) GetAllCohortIDsFromFlag(flag map[string]interface{}) map[string]bool { - cohortIDs := make(map[string]bool) - groupedIDs := f.GetGroupedCohortIDsFromFlag(flag) - for _, values := range groupedIDs { - for id := range values { - cohortIDs[id] = true - } - } - return cohortIDs -} - -// GetAllCohortIDsFromFlags extracts all cohort IDs from multiple flags. -func (f Flag) GetAllCohortIDsFromFlags(flags []map[string]interface{}) map[string]bool { - cohortIDs := make(map[string]bool) - for _, flag := range flags { - for id := range f.GetAllCohortIDsFromFlag(flag) { - cohortIDs[id] = true - } - } - return cohortIDs -} diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index 7ba17e1..7697d8e 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -31,6 +31,10 @@ type Client struct { flagsMutex *sync.RWMutex engine *evaluation.Engine assignmentService *assignmentService + cohortStorage CohortStorage + flagConfigStorage FlagConfigStorage + cohortLoader *CohortLoader + deploymentRunner *DeploymentRunner } func Initialize(apiKey string, config *Config) *Client { @@ -43,23 +47,36 @@ func Initialize(apiKey string, config *Config) *Client { config = fillConfigDefaults(config) log := logger.New(config.Debug) var as *assignmentService - if config.AssignmentConfig != nil && config.AssignmentConfig.APIKey != "" { + if config.AssignmentConfig != nil && config.AssignmentConfig.APIKey != "" { amplitudeClient := amplitude.NewClient(config.AssignmentConfig.Config) as = &assignmentService{ amplitude: &litudeClient, - filter: newAssignmentFilter(config.AssignmentConfig.CacheCapacity), + filter: newAssignmentFilter(config.AssignmentConfig.CacheCapacity), } } + cohortStorage := NewInMemoryCohortStorage() + flagConfigStorage := NewInMemoryFlagConfigStorage() + var cohortLoader *CohortLoader + var deploymentRunner *DeploymentRunner + if config.CohortSyncConfig != nil { + cohortDownloadApi := NewDirectCohortDownloadApi(config.CohortSyncConfig.ApiKey, config.CohortSyncConfig.SecretKey, config.CohortSyncConfig.MaxCohortSize, config.CohortSyncConfig.CohortRequestDelayMillis, config.CohortSyncConfig.CohortServerUrl, config.Debug) + cohortLoader = NewCohortLoader(cohortDownloadApi, cohortStorage) + deploymentRunner = NewDeploymentRunner(config, NewFlagConfigApiV2(apiKey, config.ServerUrl, config.FlagConfigPollerRequestTimeout), flagConfigStorage, cohortStorage, cohortLoader) + } client = &Client{ - log: log, - apiKey: apiKey, - config: config, - client: &http.Client{}, - poller: newPoller(), - flags: make(map[string]*evaluation.Flag), - flagsMutex: &sync.RWMutex{}, - engine: evaluation.NewEngine(log), + log: log, + apiKey: apiKey, + config: config, + client: &http.Client{}, + poller: newPoller(), + flags: make(map[string]*evaluation.Flag), + flagsMutex: &sync.RWMutex{}, + engine: evaluation.NewEngine(log), assignmentService: as, + cohortStorage: cohortStorage, + flagConfigStorage: flagConfigStorage, + cohortLoader: cohortLoader, + deploymentRunner: deploymentRunner, } client.log.Debug("config: %v", *config) clients[apiKey] = client @@ -329,3 +346,34 @@ func coerceString(value interface{}) string { } return fmt.Sprintf("%v", value) } + +func (c *Client) enrichUser(user *experiment.User, flagConfigs map[string]evaluation.Flag) (*experiment.User, error) { + flagConfigSlice := make([]*evaluation.Flag, 0, len(flagConfigs)) + + for _, value := range flagConfigs { + flagConfigSlice = append(flagConfigSlice, &value) + } + groupedCohortIDs := getGroupedCohortIDsFromFlags(flagConfigSlice) + + if cohortIDs, ok := groupedCohortIDs[userGroupType]; ok { + if len(cohortIDs) > 0 && user.UserId != "" { + user.CohortIDs = c.cohortStorage.GetCohortsForUser(user.UserId, cohortIDs) + } + } + + if user.Groups != nil { + for groupType, groupNames := range user.Groups { + groupName := "" + if len(groupNames) > 0 { + groupName = groupNames[0] + } + if groupName == "" { + continue + } + if cohortIDs, ok := groupedCohortIDs[groupType]; ok { + user.AddGroupCohortIDs(groupType, groupName, c.cohortStorage.GetCohortsForGroup(groupType, groupName, cohortIDs)) + } + } + } + return user, nil +} diff --git a/pkg/experiment/local/cohort.go b/pkg/experiment/local/cohort.go index cc51451..584bcb9 100644 --- a/pkg/experiment/local/cohort.go +++ b/pkg/experiment/local/cohort.go @@ -1,11 +1,34 @@ package local +import "sort" + const userGroupType = "user" type Cohort struct { ID string LastModified int64 Size int - MemberIDs map[string]struct{} + MemberIDs []string GroupType string } + +func CohortEquals(c1, c2 *Cohort) bool { + if c1.ID != c2.ID || c1.LastModified != c2.LastModified || c1.Size != c2.Size || c1.GroupType != c2.GroupType { + return false + } + if len(c1.MemberIDs) != len(c2.MemberIDs) { + return false + } + + // Sort MemberIDs before comparing + sort.Strings(c1.MemberIDs) + sort.Strings(c2.MemberIDs) + + for i := range c1.MemberIDs { + if c1.MemberIDs[i] != c2.MemberIDs[i] { + return false + } + } + + return true +} diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index 081efe7..837886f 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -47,16 +47,18 @@ type DirectCohortDownloadApi struct { SecretKey string MaxCohortSize int CohortRequestDelayMillis int + ServerUrl string Debug bool Logger *log.Logger } -func NewDirectCohortDownloadApi(apiKey, secretKey string, maxCohortSize, cohortRequestDelayMillis int, debug bool) *DirectCohortDownloadApi { +func NewDirectCohortDownloadApi(apiKey, secretKey string, maxCohortSize, cohortRequestDelayMillis int, serverUrl string, debug bool) *DirectCohortDownloadApi { api := &DirectCohortDownloadApi{ ApiKey: apiKey, SecretKey: secretKey, MaxCohortSize: maxCohortSize, CohortRequestDelayMillis: cohortRequestDelayMillis, + ServerUrl: serverUrl, Debug: debug, Logger: log.New(log.Writer(), "Amplitude: ", log.LstdFlags), } @@ -85,7 +87,7 @@ func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) ( if response.StatusCode == http.StatusOK { var cohortInfo struct { - CohortId string `json:"cohortId"` + CohortId string `json:"Id"` LastModified int64 `json:"lastModified"` Size int `json:"size"` MemberIds []string `json:"memberIds"` @@ -94,17 +96,18 @@ func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) ( if err := json.NewDecoder(response.Body).Decode(&cohortInfo); err != nil { return nil, err } - memberIDs := make(map[string]struct{}, len(cohortInfo.MemberIds)) - for _, id := range cohortInfo.MemberIds { - memberIDs[id] = struct{}{} - } api.Logger.Printf("getCohortMembers(%s): end - resultSize=%d", cohortID, cohortInfo.Size) return &Cohort{ ID: cohortInfo.CohortId, LastModified: cohortInfo.LastModified, Size: cohortInfo.Size, - MemberIDs: memberIDs, - GroupType: cohortInfo.GroupType, + MemberIDs: cohortInfo.MemberIds, + GroupType: func() string { + if cohortInfo.GroupType == "" { + return userGroupType + } + return cohortInfo.GroupType + }(), }, nil } else if response.StatusCode == http.StatusNoContent { return nil, &CohortNotModifiedException{Message: "Cohort not modified"} @@ -140,7 +143,7 @@ func (api *DirectCohortDownloadApi) getBasicAuth() string { } func (api *DirectCohortDownloadApi) buildCohortURL(cohortID string, cohort *Cohort) string { - url := CdnCohortSyncUrl + "/sdk/v1/cohort/" + cohortID + "?maxCohortSize=" + strconv.Itoa(api.MaxCohortSize) + url := api.ServerUrl + "/sdk/v1/cohort/" + cohortID + "?maxCohortSize=" + strconv.Itoa(api.MaxCohortSize) if cohort != nil { url += "&lastModified=" + strconv.FormatInt(cohort.LastModified, 10) } diff --git a/pkg/experiment/local/cohort_download_api_test.go b/pkg/experiment/local/cohort_download_api_test.go new file mode 100644 index 0000000..428f2f6 --- /dev/null +++ b/pkg/experiment/local/cohort_download_api_test.go @@ -0,0 +1,190 @@ +package local + +import ( + "net/http" + "testing" + + "github.com/jarcoal/httpmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockCohortDownloadApi struct { + mock.Mock +} + +func (m *MockCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) { + args := m.Called(cohortID, cohort) + if args.Get(0) != nil { + return args.Get(0).(*Cohort), args.Error(1) + } + return nil, args.Error(1) +} + +func TestCohortDownloadApi(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + api := NewDirectCohortDownloadApi("api", "secret", 15000, 100, "https://server.amplitude.com", false) + + t.Run("test_cohort_download_success", func(t *testing.T) { + cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.GetCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort, resultCohort) + }) + + t.Run("test_cohort_download_many_202s_success", func(t *testing.T) { + cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + + for i := 0; i < 9; i++ { + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(202, ""), + ) + } + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.GetCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort, resultCohort) + }) + + t.Run("test_cohort_request_status_with_two_failures_succeeds", func(t *testing.T) { + cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(503, ""), + ) + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(503, ""), + ) + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.GetCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort, resultCohort) + }) + + t.Run("test_cohort_request_status_429s_keep_retrying", func(t *testing.T) { + cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + + for i := 0; i < 9; i++ { + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(429, ""), + ) + } + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.GetCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort, resultCohort) + }) + + t.Run("test_group_cohort_download_success", func(t *testing.T) { + cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"group"}, GroupType: "org name"} + response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"group"}, GroupType: "org name"} + + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.GetCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort, resultCohort) + }) + + t.Run("test_group_cohort_request_status_429s_keep_retrying", func(t *testing.T) { + cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"group"}, GroupType: "org name"} + response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"group"}, GroupType: "org name"} + + for i := 0; i < 9; i++ { + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(429, ""), + ) + } + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.GetCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort, resultCohort) + }) + + t.Run("test_cohort_size_too_large", func(t *testing.T) { + cohort := &Cohort{ID: "1234", LastModified: 0, Size: 16000, MemberIDs: []string{}} + + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(413, ""), + ) + + _, err := api.GetCohort("1234", cohort) + assert.Error(t, err) + _, isCohortTooLargeException := err.(*CohortTooLargeException) + assert.True(t, isCohortTooLargeException) + }) + + t.Run("test_cohort_not_modified_exception", func(t *testing.T) { + cohort := &Cohort{ID: "1234", LastModified: 1000, Size: 1, MemberIDs: []string{}} + + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(204, ""), + ) + + _, err := api.GetCohort("1234", cohort) + assert.Error(t, err) + _, isCohortNotModifiedException := err.(*CohortNotModifiedException) + assert.True(t, isCohortNotModifiedException) + }) +} diff --git a/pkg/experiment/local/cohort_loader.go b/pkg/experiment/local/cohort_loader.go index 43e3b01..eefddcd 100644 --- a/pkg/experiment/local/cohort_loader.go +++ b/pkg/experiment/local/cohort_loader.go @@ -5,7 +5,6 @@ import ( "sync/atomic" ) -// CohortLoader handles the loading of cohorts using CohortDownloadApi and CohortStorage. type CohortLoader struct { cohortDownloadApi CohortDownloadApi cohortStorage CohortStorage @@ -14,7 +13,6 @@ type CohortLoader struct { lockJobs sync.Mutex } -// NewCohortLoader creates a new instance of CohortLoader. func NewCohortLoader(cohortDownloadApi CohortDownloadApi, cohortStorage CohortStorage) *CohortLoader { return &CohortLoader{ cohortDownloadApi: cohortDownloadApi, @@ -27,7 +25,6 @@ func NewCohortLoader(cohortDownloadApi CohortDownloadApi, cohortStorage CohortSt } } -// LoadCohort initiates the loading of a cohort. func (cl *CohortLoader) LoadCohort(cohortID string) *CohortLoaderTask { cl.lockJobs.Lock() defer cl.lockJobs.Unlock() @@ -43,12 +40,10 @@ func (cl *CohortLoader) LoadCohort(cohortID string) *CohortLoaderTask { return task.(*CohortLoaderTask) } -// removeJob removes a job from the jobs map. func (cl *CohortLoader) removeJob(cohortID string) { cl.jobs.Delete(cohortID) } -// CohortLoaderTask represents a task for loading a cohort. type CohortLoaderTask struct { loader *CohortLoader cohortID string @@ -57,7 +52,6 @@ type CohortLoaderTask struct { err error } -// init initializes a CohortLoaderTask. func (task *CohortLoaderTask) init(loader *CohortLoader, cohortID string) { task.loader = loader task.cohortID = cohortID @@ -66,7 +60,6 @@ func (task *CohortLoaderTask) init(loader *CohortLoader, cohortID string) { task.err = nil } -// run executes the task of loading a cohort. func (task *CohortLoaderTask) run() { defer task.loader.executor.Put(task) @@ -82,28 +75,12 @@ func (task *CohortLoaderTask) run() { close(task.doneChan) } -// Wait waits for the task to complete. func (task *CohortLoaderTask) Wait() error { <-task.doneChan return task.err } -// downloadCohort downloads a cohort. func (cl *CohortLoader) downloadCohort(cohortID string) (*Cohort, error) { cohort := cl.cohortStorage.GetCohort(cohortID) return cl.cohortDownloadApi.GetCohort(cohortID, cohort) } - -type CohortDownloadApiImpl struct{} - -// GetCohort gets a cohort. -func (api *CohortDownloadApiImpl) GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) { - // Placeholder implementation - return &Cohort{ - ID: cohortID, - LastModified: 0, - Size: 0, - MemberIDs: make(map[string]struct{}), - GroupType: "example", - }, nil -} diff --git a/pkg/experiment/local/cohort_loader_test.go b/pkg/experiment/local/cohort_loader_test.go new file mode 100644 index 0000000..25506e7 --- /dev/null +++ b/pkg/experiment/local/cohort_loader_test.go @@ -0,0 +1,105 @@ +package local + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/mock" +) + +func TestLoadSuccess(t *testing.T) { + api := &MockCohortDownloadApi{} + storage := NewInMemoryCohortStorage() + loader := NewCohortLoader(api, storage) + + // Define mock behavior + api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "a", LastModified: 0, Size: 1, MemberIDs: []string{"1"}, GroupType: userGroupType}, nil) + api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "b", LastModified: 0, Size: 2, MemberIDs: []string{"1", "2"}, GroupType: userGroupType}, nil) + + futureA := loader.LoadCohort("a") + futureB := loader.LoadCohort("b") + + if err := futureA.Wait(); err != nil { + t.Errorf("futureA.Wait() returned error: %v", err) + } + if err := futureB.Wait(); err != nil { + t.Errorf("futureB.Wait() returned error: %v", err) + } + + storageDescriptionA := storage.GetCohort("a") + storageDescriptionB := storage.GetCohort("b") + expectedA := &Cohort{ID: "a", LastModified: 0, Size: 1, MemberIDs: []string{"1"}, GroupType: userGroupType} + expectedB := &Cohort{ID: "b", LastModified: 0, Size: 2, MemberIDs: []string{"1", "2"}, GroupType: userGroupType} + + if !CohortEquals(storageDescriptionA, expectedA) { + t.Errorf("Unexpected cohort A stored: %+v", storageDescriptionA) + } + if !CohortEquals(storageDescriptionB, expectedB) { + t.Errorf("Unexpected cohort B stored: %+v", storageDescriptionB) + } + + storageUser1Cohorts := storage.GetCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}}) + storageUser2Cohorts := storage.GetCohortsForUser("2", map[string]struct{}{"a": {}, "b": {}}) + if len(storageUser1Cohorts) != 2 || len(storageUser2Cohorts) != 1 { + t.Errorf("Unexpected user cohorts: User1: %+v, User2: %+v", storageUser1Cohorts, storageUser2Cohorts) + } +} + +func TestFilterCohortsAlreadyComputed(t *testing.T) { + api := &MockCohortDownloadApi{} + storage := NewInMemoryCohortStorage() + loader := NewCohortLoader(api, storage) + + storage.PutCohort(&Cohort{ID: "a", LastModified: 0, Size: 0, MemberIDs: []string{}}) + storage.PutCohort(&Cohort{ID: "b", LastModified: 0, Size: 0, MemberIDs: []string{}}) + + // Define mock behavior + api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "a", LastModified: 0, Size: 0, MemberIDs: []string{}, GroupType: userGroupType}, nil) + api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "b", LastModified: 1, Size: 2, MemberIDs: []string{"1", "2"}, GroupType: userGroupType}, nil) + + loader.LoadCohort("a").Wait() + loader.LoadCohort("b").Wait() + + storageDescriptionA := storage.GetCohort("a") + storageDescriptionB := storage.GetCohort("b") + expectedA := &Cohort{ID: "a", LastModified: 0, Size: 0, MemberIDs: []string{}, GroupType: userGroupType} + expectedB := &Cohort{ID: "b", LastModified: 1, Size: 2, MemberIDs: []string{"1", "2"}, GroupType: userGroupType} + + if !CohortEquals(storageDescriptionA, expectedA) { + t.Errorf("Unexpected cohort A stored: %+v", storageDescriptionA) + } + if !CohortEquals(storageDescriptionB, expectedB) { + t.Errorf("Unexpected cohort B stored: %+v", storageDescriptionB) + } + + storageUser1Cohorts := storage.GetCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}}) + storageUser2Cohorts := storage.GetCohortsForUser("2", map[string]struct{}{"a": {}, "b": {}}) + if len(storageUser1Cohorts) != 1 || len(storageUser2Cohorts) != 1 { + t.Errorf("Unexpected user cohorts: User1: %+v, User2: %+v", storageUser1Cohorts, storageUser2Cohorts) + } +} + +func TestLoadDownloadFailureThrows(t *testing.T) { + api := &MockCohortDownloadApi{} + storage := NewInMemoryCohortStorage() + loader := NewCohortLoader(api, storage) + + // Define mock behavior + api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "a", LastModified: 0, Size: 1, MemberIDs: []string{"1"}, GroupType: userGroupType}, nil) + api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(nil, errors.New("connection timed out")) + api.On("GetCohort", "c", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "c", LastModified: 0, Size: 1, MemberIDs: []string{"1"}, GroupType: userGroupType}, nil) + + loader.LoadCohort("a").Wait() + errB := loader.LoadCohort("b").Wait() + loader.LoadCohort("c").Wait() + + if errB == nil || errB.Error() != "connection timed out" { + t.Errorf("futureB.Wait() expected 'Connection timed out' error, got: %v", errB) + } + + expectedCohorts := map[string]struct{}{"a": {}, "c": {}} + actualCohorts := storage.GetCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}, "c": {}}) + if len(actualCohorts) != len(expectedCohorts) { + t.Errorf("Expected cohorts for user '1': %+v, but got: %+v", expectedCohorts, actualCohorts) + } +} diff --git a/pkg/experiment/local/cohort_storage.go b/pkg/experiment/local/cohort_storage.go index 35b9645..737e4ed 100644 --- a/pkg/experiment/local/cohort_storage.go +++ b/pkg/experiment/local/cohort_storage.go @@ -4,7 +4,6 @@ import ( "sync" ) -// CohortStorage defines the interface for cohort storage operations type CohortStorage interface { GetCohort(cohortID string) *Cohort GetCohorts() map[string]*Cohort @@ -14,14 +13,12 @@ type CohortStorage interface { DeleteCohort(groupType, cohortID string) } -// InMemoryCohortStorage is an in-memory implementation of CohortStorage type InMemoryCohortStorage struct { lock sync.RWMutex groupToCohortStore map[string]map[string]struct{} cohortStore map[string]*Cohort } -// NewInMemoryCohortStorage creates a new InMemoryCohortStorage instance func NewInMemoryCohortStorage() *InMemoryCohortStorage { return &InMemoryCohortStorage{ groupToCohortStore: make(map[string]map[string]struct{}), @@ -29,14 +26,12 @@ func NewInMemoryCohortStorage() *InMemoryCohortStorage { } } -// GetCohort retrieves a cohort by its ID func (s *InMemoryCohortStorage) GetCohort(cohortID string) *Cohort { s.lock.RLock() defer s.lock.RUnlock() return s.cohortStore[cohortID] } -// GetCohorts retrieves all cohorts func (s *InMemoryCohortStorage) GetCohorts() map[string]*Cohort { s.lock.RLock() defer s.lock.RUnlock() @@ -47,30 +42,36 @@ func (s *InMemoryCohortStorage) GetCohorts() map[string]*Cohort { return cohorts } -// GetCohortsForUser retrieves cohorts for a user based on cohort IDs func (s *InMemoryCohortStorage) GetCohortsForUser(userID string, cohortIDs map[string]struct{}) map[string]struct{} { return s.GetCohortsForGroup(userGroupType, userID, cohortIDs) } -// GetCohortsForGroup retrieves cohorts for a group based on cohort IDs func (s *InMemoryCohortStorage) GetCohortsForGroup(groupType, groupName string, cohortIDs map[string]struct{}) map[string]struct{} { result := make(map[string]struct{}) s.lock.RLock() defer s.lock.RUnlock() - groupTypeCohorts := s.groupToCohortStore[groupType] - for cohortID := range groupTypeCohorts { - if _, exists := cohortIDs[cohortID]; exists { + + groupTypeCohorts, groupExists := s.groupToCohortStore[groupType] + if !groupExists { + return result + } + + for cohortID := range cohortIDs { + if _, exists := groupTypeCohorts[cohortID]; exists { if cohort, found := s.cohortStore[cohortID]; found { - if _, memberExists := cohort.MemberIDs[groupName]; memberExists { - result[cohortID] = struct{}{} + for _, memberID := range cohort.MemberIDs { + if memberID == groupName { + result[cohortID] = struct{}{} + break + } } } } } + return result } -// PutCohort stores a cohort func (s *InMemoryCohortStorage) PutCohort(cohort *Cohort) { s.lock.Lock() defer s.lock.Unlock() @@ -81,7 +82,6 @@ func (s *InMemoryCohortStorage) PutCohort(cohort *Cohort) { s.cohortStore[cohort.ID] = cohort } -// DeleteCohort deletes a cohort by its ID and group type func (s *InMemoryCohortStorage) DeleteCohort(groupType, cohortID string) { s.lock.Lock() defer s.lock.Unlock() diff --git a/pkg/experiment/local/config.go b/pkg/experiment/local/config.go index fd746ab..79101c6 100644 --- a/pkg/experiment/local/config.go +++ b/pkg/experiment/local/config.go @@ -25,6 +25,7 @@ type CohortSyncConfig struct { SecretKey string MaxCohortSize int CohortRequestDelayMillis int + CohortServerUrl string } var DefaultConfig = &Config{ @@ -41,6 +42,7 @@ var DefaultAssignmentConfig = &AssignmentConfig{ var DefaultCohortSyncConfig = &CohortSyncConfig{ MaxCohortSize: 15000, CohortRequestDelayMillis: 5000, + CohortServerUrl: "https://cohort-v2.lab.amplitude.com", } func fillConfigDefaults(c *Config) *Config { @@ -68,5 +70,9 @@ func fillConfigDefaults(c *Config) *Config { c.CohortSyncConfig.CohortRequestDelayMillis = DefaultCohortSyncConfig.CohortRequestDelayMillis } + if c.CohortSyncConfig != nil && c.CohortSyncConfig.CohortServerUrl == "" { + c.CohortSyncConfig.CohortServerUrl = DefaultCohortSyncConfig.CohortServerUrl + } + return c } diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go new file mode 100644 index 0000000..0de3358 --- /dev/null +++ b/pkg/experiment/local/deployment_runner.go @@ -0,0 +1,153 @@ +package local + +import ( + "github.com/amplitude/experiment-go-server/internal/evaluation" + "log" + "sync" +) + +type DeploymentRunner struct { + config *Config + flagConfigApi FlagConfigApi + flagConfigStorage FlagConfigStorage + cohortStorage CohortStorage + cohortLoader *CohortLoader + lock sync.Mutex + poller *poller + logger *log.Logger +} + +func NewDeploymentRunner( + config *Config, + flagConfigApi FlagConfigApi, + flagConfigStorage FlagConfigStorage, + cohortStorage CohortStorage, + cohortLoader *CohortLoader, +) *DeploymentRunner { + dr := &DeploymentRunner{ + config: config, + flagConfigApi: flagConfigApi, + flagConfigStorage: flagConfigStorage, + cohortStorage: cohortStorage, + cohortLoader: cohortLoader, + logger: log.New(log.Writer(), "Amplitude: ", log.LstdFlags), + } + if config.Debug { + dr.logger.SetFlags(log.LstdFlags | log.Lshortfile) + } + dr.poller = newPoller() + return dr +} + +// Start begins the deployment runner's periodic refresh. +func (dr *DeploymentRunner) Start() error { + dr.lock.Lock() + defer dr.lock.Unlock() + + if err := dr.refresh(); err != nil { + dr.logger.Printf("Initial refresh failed: %v", err) + return err + } + + dr.poller.Poll(dr.config.FlagConfigPollerInterval, func() { + if err := dr.periodicRefresh(); err != nil { + dr.logger.Printf("Periodic refresh failed: %v", err) + } + }) + return nil +} + +func (dr *DeploymentRunner) periodicRefresh() error { + defer func() { + if r := recover(); r != nil { + dr.logger.Printf("Recovered in periodicRefresh: %v", r) + } + }() + return dr.refresh() +} + +func (dr *DeploymentRunner) refresh() error { + dr.logger.Println("Refreshing flag configs.") + flagConfigs, err := dr.flagConfigApi.GetFlagConfigs() + if err != nil { + dr.logger.Printf("Failed to fetch flag configs: %v", err) + return err + } + + flagKeys := make(map[string]struct{}) + for _, flag := range flagConfigs { + flagKeys[flag.Key] = struct{}{} + } + + dr.flagConfigStorage.RemoveIf(func(f evaluation.Flag) bool { + _, exists := flagKeys[f.Key] + return !exists + }) + + for _, flagConfig := range flagConfigs { + cohortIDs := getAllCohortIDsFromFlag(flagConfig) + if dr.cohortLoader == nil || len(cohortIDs) == 0 { + dr.logger.Printf("Putting non-cohort flag %s", flagConfig.Key) + dr.flagConfigStorage.PutFlagConfig(*flagConfig) + continue + } + + oldFlagConfig := dr.flagConfigStorage.GetFlagConfig(flagConfig.Key) + + err := dr.loadCohorts(*flagConfig, cohortIDs) + if err != nil { + dr.logger.Printf("Failed to load all cohorts for flag %s. Using the old flag config.", flagConfig.Key) + dr.flagConfigStorage.PutFlagConfig(oldFlagConfig) + continue + } + + dr.flagConfigStorage.PutFlagConfig(*flagConfig) + dr.logger.Printf("Stored flag config %s", flagConfig.Key) + } + + dr.deleteUnusedCohorts() + dr.logger.Printf("Refreshed %d flag configs.", len(flagConfigs)) + return nil +} + +func (dr *DeploymentRunner) loadCohorts(flagConfig evaluation.Flag, cohortIDs map[string]struct{}) error { + task := func() error { + for cohortID := range cohortIDs { + task := dr.cohortLoader.LoadCohort(cohortID) + err := task.Wait() + if err != nil { + dr.logger.Printf("Failed to load cohort %s for flag %s: %v", cohortID, flagConfig.Key, err) + return err + } + dr.logger.Printf("Cohort %s loaded for flag %s", cohortID, flagConfig.Key) + } + return nil + } + + // Using a goroutine to simulate async task execution + errCh := make(chan error) + go func() { + errCh <- task() + }() + err := <-errCh + return err +} + +func (dr *DeploymentRunner) deleteUnusedCohorts() { + flagCohortIDs := make(map[string]struct{}) + for _, flag := range dr.flagConfigStorage.GetFlagConfigs() { + for cohortID := range getAllCohortIDsFromFlag(&flag) { + flagCohortIDs[cohortID] = struct{}{} + } + } + + storageCohorts := dr.cohortStorage.GetCohorts() + for cohortID := range storageCohorts { + if _, exists := flagCohortIDs[cohortID]; !exists { + cohort := storageCohorts[cohortID] + if cohort != nil { + dr.cohortStorage.DeleteCohort(cohort.GroupType, cohortID) + } + } + } +} diff --git a/pkg/experiment/local/deployment_runner_test.go b/pkg/experiment/local/deployment_runner_test.go new file mode 100644 index 0000000..aed7a1c --- /dev/null +++ b/pkg/experiment/local/deployment_runner_test.go @@ -0,0 +1,107 @@ +package local + +import ( + "errors" + "fmt" + "testing" + + "github.com/amplitude/experiment-go-server/internal/evaluation" +) + +const ( + CohortId = "1234" +) + +func TestStartThrowsIfFirstFlagConfigLoadFails(t *testing.T) { + flagAPI := &mockFlagConfigApi{getFlagConfigsFunc: func() (map[string]*evaluation.Flag, error) { + return nil, errors.New("test") + }} + cohortDownloadAPI := &mockCohortDownloadApi{} + flagConfigStorage := NewInMemoryFlagConfigStorage() + cohortStorage := NewInMemoryCohortStorage() + cohortLoader := NewCohortLoader(cohortDownloadAPI, cohortStorage) + + runner := NewDeploymentRunner( + &Config{}, + flagAPI, + flagConfigStorage, + cohortStorage, + cohortLoader, + ) + + err := runner.Start() + + if err == nil { + t.Error("Expected error but got nil") + } +} + +func TestStartThrowsIfFirstCohortLoadFails(t *testing.T) { + flagAPI := &mockFlagConfigApi{getFlagConfigsFunc: func() (map[string]*evaluation.Flag, error) { + return map[string]*evaluation.Flag{"flag": createTestFlag()}, nil + }} + cohortDownloadAPI := &mockCohortDownloadApi{getCohortFunc: func(cohortID string, cohort *Cohort) (*Cohort, error) { + return nil, errors.New("test") + }} + flagConfigStorage := NewInMemoryFlagConfigStorage() + cohortStorage := NewInMemoryCohortStorage() + cohortLoader := NewCohortLoader(cohortDownloadAPI, cohortStorage) + + runner := NewDeploymentRunner( + &Config{}, + flagAPI, + flagConfigStorage, + cohortStorage, + cohortLoader, + ) + + err := runner.Start() + + if err == nil { + t.Error("Expected error but got nil") + } +} + +// Mock implementations for interfaces used in tests + +type mockFlagConfigApi struct { + getFlagConfigsFunc func() (map[string]*evaluation.Flag, error) +} + +func (m *mockFlagConfigApi) GetFlagConfigs() (map[string]*evaluation.Flag, error) { + if m.getFlagConfigsFunc != nil { + return m.getFlagConfigsFunc() + } + return nil, fmt.Errorf("mock not implemented") +} + +type mockCohortDownloadApi struct { + getCohortFunc func(cohortID string, cohort *Cohort) (*Cohort, error) +} + +func (m *mockCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) { + if m.getCohortFunc != nil { + return m.getCohortFunc(cohortID, cohort) + } + return nil, fmt.Errorf("mock not implemented") +} + +func createTestFlag() *evaluation.Flag { + return &evaluation.Flag{ + Key: "flag", + Variants: map[string]*evaluation.Variant{}, + Segments: []*evaluation.Segment{ + { + Conditions: [][]*evaluation.Condition{ + { + { + Selector: []string{"context", "user", "cohort_ids"}, + Op: "set contains any", + Values: []string{CohortId}, + }, + }, + }, + }, + }, + } +} diff --git a/pkg/experiment/local/flag_config_api.go b/pkg/experiment/local/flag_config_api.go index dffbeba..f9d8649 100644 --- a/pkg/experiment/local/flag_config_api.go +++ b/pkg/experiment/local/flag_config_api.go @@ -12,19 +12,16 @@ import ( "time" ) -// FlagConfigApi defines an interface for retrieving flag configurations. type FlagConfigApi interface { - GetFlagConfigs() []interface{} + GetFlagConfigs() (map[string]*evaluation.Flag, error) } -// FlagConfigApiV2 is an implementation of the FlagConfigApi interface for version 2 of the API. type FlagConfigApiV2 struct { DeploymentKey string ServerURL string FlagConfigPollerRequestTimeoutMillis time.Duration } -// NewFlagConfigApiV2 creates a new instance of FlagConfigApiV2. func NewFlagConfigApiV2(deploymentKey, serverURL string, flagConfigPollerRequestTimeoutMillis time.Duration) *FlagConfigApiV2 { return &FlagConfigApiV2{ DeploymentKey: deploymentKey, diff --git a/pkg/experiment/local/flag_config_storage.go b/pkg/experiment/local/flag_config_storage.go index 6c4d76a..accb2f2 100644 --- a/pkg/experiment/local/flag_config_storage.go +++ b/pkg/experiment/local/flag_config_storage.go @@ -1,57 +1,51 @@ package local import ( + "github.com/amplitude/experiment-go-server/internal/evaluation" "sync" ) -// FlagConfigStorage defines an interface for managing flag configurations. type FlagConfigStorage interface { - GetFlagConfig(key string) map[string]interface{} - GetFlagConfigs() map[string]map[string]interface{} - PutFlagConfig(flagConfig map[string]interface{}) - RemoveIf(condition func(map[string]interface{}) bool) + GetFlagConfig(key string) evaluation.Flag + GetFlagConfigs() map[string]evaluation.Flag + PutFlagConfig(flagConfig evaluation.Flag) + RemoveIf(condition func(evaluation.Flag) bool) } -// InMemoryFlagConfigStorage is an in-memory implementation of FlagConfigStorage. type InMemoryFlagConfigStorage struct { - flagConfigs map[string]map[string]interface{} + flagConfigs map[string]evaluation.Flag flagConfigsLock sync.Mutex } -// NewInMemoryFlagConfigStorage creates a new instance of InMemoryFlagConfigStorage. func NewInMemoryFlagConfigStorage() *InMemoryFlagConfigStorage { return &InMemoryFlagConfigStorage{ - flagConfigs: make(map[string]map[string]interface{}), + flagConfigs: make(map[string]evaluation.Flag), } } -// GetFlagConfig retrieves a flag configuration by key. -func (storage *InMemoryFlagConfigStorage) GetFlagConfig(key string) map[string]interface{} { +func (storage *InMemoryFlagConfigStorage) GetFlagConfig(key string) evaluation.Flag { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() return storage.flagConfigs[key] } -// GetFlagConfigs retrieves all flag configurations. -func (storage *InMemoryFlagConfigStorage) GetFlagConfigs() map[string]map[string]interface{} { +func (storage *InMemoryFlagConfigStorage) GetFlagConfigs() map[string]evaluation.Flag { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() - copyFlagConfigs := make(map[string]map[string]interface{}) + copyFlagConfigs := make(map[string]evaluation.Flag) for key, value := range storage.flagConfigs { copyFlagConfigs[key] = value } return copyFlagConfigs } -// PutFlagConfig stores a flag configuration. -func (storage *InMemoryFlagConfigStorage) PutFlagConfig(flagConfig map[string]interface{}) { +func (storage *InMemoryFlagConfigStorage) PutFlagConfig(flagConfig evaluation.Flag) { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() - storage.flagConfigs[flagConfig["key"].(string)] = flagConfig + storage.flagConfigs[flagConfig.Key] = flagConfig } -// RemoveIf removes flag configurations based on a condition. -func (storage *InMemoryFlagConfigStorage) RemoveIf(condition func(map[string]interface{}) bool) { +func (storage *InMemoryFlagConfigStorage) RemoveIf(condition func(evaluation.Flag) bool) { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() for key, value := range storage.flagConfigs { diff --git a/pkg/experiment/local/flag_config_test.go b/pkg/experiment/local/flag_config_test.go new file mode 100644 index 0000000..37e3d6f --- /dev/null +++ b/pkg/experiment/local/flag_config_test.go @@ -0,0 +1,221 @@ +package local + +import ( + "testing" + + "github.com/amplitude/experiment-go-server/internal/evaluation" + "github.com/stretchr/testify/assert" +) + +func TestGetAllCohortIDsFromFlag(t *testing.T) { + flags := getTestFlags() + expectedCohortIDs := []string{ + "cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6", "cohort7", "cohort8", + } + expectedCohortIDSet := make(map[string]bool) + for _, id := range expectedCohortIDs { + expectedCohortIDSet[id] = true + } + + for _, flag := range flags { + cohortIDs := getAllCohortIDsFromFlag(flag) + for id := range cohortIDs { + assert.True(t, expectedCohortIDSet[id]) + } + } +} + +func TestGetGroupedCohortIDsFromFlag(t *testing.T) { + flags := getTestFlags() + expectedGroupedCohortIDs := map[string][]string{ + "user": {"cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6"}, + "group_name": {"cohort7", "cohort8"}, + } + + for _, flag := range flags { + groupedCohortIDs := getGroupedCohortIDsFromFlag(flag) + for key, values := range groupedCohortIDs { + assert.Contains(t, expectedGroupedCohortIDs, key) + expectedSet := make(map[string]bool) + for _, id := range expectedGroupedCohortIDs[key] { + expectedSet[id] = true + } + for id := range values { + assert.True(t, expectedSet[id]) + } + } + } +} + +func TestGetAllCohortIDsFromFlags(t *testing.T) { + flags := getTestFlags() + expectedCohortIDs := []string{ + "cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6", "cohort7", "cohort8", + } + expectedCohortIDSet := make(map[string]bool) + for _, id := range expectedCohortIDs { + expectedCohortIDSet[id] = true + } + + cohortIDs := getAllCohortIDsFromFlags(flags) + for id := range cohortIDs { + assert.True(t, expectedCohortIDSet[id]) + } +} + +func TestGetGroupedCohortIDsFromFlags(t *testing.T) { + flags := getTestFlags() + expectedGroupedCohortIDs := map[string][]string{ + "user": {"cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6"}, + "group_name": {"cohort7", "cohort8"}, + } + + groupedCohortIDs := getGroupedCohortIDsFromFlags(flags) + for key, values := range groupedCohortIDs { + assert.Contains(t, expectedGroupedCohortIDs, key) + expectedSet := make(map[string]bool) + for _, id := range expectedGroupedCohortIDs[key] { + expectedSet[id] = true + } + for id := range values { + assert.True(t, expectedSet[id]) + } + } +} + +func getTestFlags() []*evaluation.Flag { + return []*evaluation.Flag{ + { + Key: "flag-1", + Metadata: map[string]interface{}{ + "deployed": true, + "evaluationMode": "local", + "flagType": "release", + "flagVersion": 1, + }, + Segments: []*evaluation.Segment{ + { + Conditions: [][]*evaluation.Condition{ + { + { + Op: "set contains any", + Selector: []string{"context", "user", "cohort_ids"}, + Values: []string{"cohort1", "cohort2"}, + }, + }, + }, + Metadata: map[string]interface{}{ + "segmentName": "Segment A", + }, + Variant: "on", + }, + { + Metadata: map[string]interface{}{ + "segmentName": "All Other Users", + }, + Variant: "off", + }, + }, + Variants: map[string]*evaluation.Variant{ + "off": { + Key: "off", + Metadata: map[string]interface{}{ + "default": true, + }, + }, + "on": { + Key: "on", + Value: "on", + }, + }, + }, + { + Key: "flag-2", + Metadata: map[string]interface{}{ + "deployed": true, + "evaluationMode": "local", + "flagType": "release", + "flagVersion": 2, + }, + Segments: []*evaluation.Segment{ + { + Conditions: [][]*evaluation.Condition{ + { + { + Op: "set contains any", + Selector: []string{"context", "user", "cohort_ids"}, + Values: []string{"cohort3", "cohort4", "cohort5", "cohort6"}, + }, + }, + }, + Metadata: map[string]interface{}{ + "segmentName": "Segment B", + }, + Variant: "on", + }, + { + Metadata: map[string]interface{}{ + "segmentName": "All Other Users", + }, + Variant: "off", + }, + }, + Variants: map[string]*evaluation.Variant{ + "off": { + Key: "off", + Metadata: map[string]interface{}{ + "default": true, + }, + }, + "on": { + Key: "on", + Value: "on", + }, + }, + }, + { + Key: "flag-3", + Metadata: map[string]interface{}{ + "deployed": true, + "evaluationMode": "local", + "flagType": "release", + "flagVersion": 3, + }, + Segments: []*evaluation.Segment{ + { + Conditions: [][]*evaluation.Condition{ + { + { + Op: "set contains any", + Selector: []string{"context", "groups", "group_name", "cohort_ids"}, + Values: []string{"cohort7", "cohort8"}, + }, + }, + }, + Metadata: map[string]interface{}{ + "segmentName": "Segment C", + }, + Variant: "on", + }, + { + Metadata: map[string]interface{}{ + "segmentName": "All Other Groups", + }, + Variant: "off", + }, + }, + Variants: map[string]*evaluation.Variant{ + "off": { + Key: "off", + Metadata: map[string]interface{}{ + "default": true, + }, + }, + "on": { + Key: "on", + Value: "on", + }, + }, + }, + } +} diff --git a/pkg/experiment/local/flag_config_util.go b/pkg/experiment/local/flag_config_util.go index 768b89b..2c4231e 100644 --- a/pkg/experiment/local/flag_config_util.go +++ b/pkg/experiment/local/flag_config_util.go @@ -4,8 +4,8 @@ import ( "github.com/amplitude/experiment-go-server/internal/evaluation" ) -// IsCohortFilter checks if the condition is a cohort filter. -func IsCohortFilter(condition *evaluation.Condition) bool { +// isCohortFilter checks if the condition is a cohort filter. +func isCohortFilter(condition *evaluation.Condition) bool { op := condition.Op selector := condition.Selector if len(selector) > 0 && selector[len(selector)-1] == "cohort_ids" { @@ -14,16 +14,16 @@ func IsCohortFilter(condition *evaluation.Condition) bool { return false } -// GetGroupedCohortConditionIDs extracts grouped cohort condition IDs from a segment. -func GetGroupedCohortConditionIDs(segment *evaluation.Segment) map[string]map[string]bool { - cohortIDs := make(map[string]map[string]bool) +// getGroupedCohortConditionIDs extracts grouped cohort condition IDs from a segment. +func getGroupedCohortConditionIDs(segment *evaluation.Segment) map[string]map[string]struct{} { + cohortIDs := make(map[string]map[string]struct{}) if segment == nil { return cohortIDs } for _, outer := range segment.Conditions { for _, condition := range outer { - if IsCohortFilter(condition) { + if isCohortFilter(condition) { selector := condition.Selector if len(selector) > 2 { contextSubtype := selector[1] @@ -36,9 +36,9 @@ func GetGroupedCohortConditionIDs(segment *evaluation.Segment) map[string]map[st continue } values := condition.Values - cohortIDs[groupType] = make(map[string]bool) + cohortIDs[groupType] = make(map[string]struct{}) for _, value := range values { - cohortIDs[groupType][value] = true + cohortIDs[groupType][value] = struct{}{} } } } @@ -47,56 +47,56 @@ func GetGroupedCohortConditionIDs(segment *evaluation.Segment) map[string]map[st return cohortIDs } -// GetGroupedCohortIDsFromFlag extracts grouped cohort IDs from a flag. -func GetGroupedCohortIDsFromFlag(flag *evaluation.Flag) map[string]map[string]bool { - cohortIDs := make(map[string]map[string]bool) +// getGroupedCohortIDsFromFlag extracts grouped cohort IDs from a flag. +func getGroupedCohortIDsFromFlag(flag *evaluation.Flag) map[string]map[string]struct{} { + cohortIDs := make(map[string]map[string]struct{}) for _, segment := range flag.Segments { - for key, values := range GetGroupedCohortConditionIDs(segment) { + for key, values := range getGroupedCohortConditionIDs(segment) { if _, exists := cohortIDs[key]; !exists { - cohortIDs[key] = make(map[string]bool) + cohortIDs[key] = make(map[string]struct{}) } for id := range values { - cohortIDs[key][id] = true + cohortIDs[key][id] = struct{}{} } } } return cohortIDs } -// GetAllCohortIDsFromFlag extracts all cohort IDs from a flag. -func GetAllCohortIDsFromFlag(flag *evaluation.Flag) map[string]bool { - cohortIDs := make(map[string]bool) - groupedIDs := GetGroupedCohortIDsFromFlag(flag) +// getAllCohortIDsFromFlag extracts all cohort IDs from a flag. +func getAllCohortIDsFromFlag(flag *evaluation.Flag) map[string]struct{} { + cohortIDs := make(map[string]struct{}) + groupedIDs := getGroupedCohortIDsFromFlag(flag) for _, values := range groupedIDs { for id := range values { - cohortIDs[id] = true + cohortIDs[id] = struct{}{} } } return cohortIDs } -// GetGroupedCohortIDsFromFlags extracts grouped cohort IDs from multiple flags. -func GetGroupedCohortIDsFromFlags(flags []*evaluation.Flag) map[string]map[string]bool { - cohortIDs := make(map[string]map[string]bool) +// getGroupedCohortIDsFromFlags extracts grouped cohort IDs from multiple flags. +func getGroupedCohortIDsFromFlags(flags []*evaluation.Flag) map[string]map[string]struct{} { + cohortIDs := make(map[string]map[string]struct{}) for _, flag := range flags { - for key, values := range GetGroupedCohortIDsFromFlag(flag) { + for key, values := range getGroupedCohortIDsFromFlag(flag) { if _, exists := cohortIDs[key]; !exists { - cohortIDs[key] = make(map[string]bool) + cohortIDs[key] = make(map[string]struct{}) } for id := range values { - cohortIDs[key][id] = true + cohortIDs[key][id] = struct{}{} } } } return cohortIDs } -// GetAllCohortIDsFromFlags extracts all cohort IDs from multiple flags. -func GetAllCohortIDsFromFlags(flags []*evaluation.Flag) map[string]bool { - cohortIDs := make(map[string]bool) +// getAllCohortIDsFromFlags extracts all cohort IDs from multiple flags. +func getAllCohortIDsFromFlags(flags []*evaluation.Flag) map[string]struct{} { + cohortIDs := make(map[string]struct{}) for _, flag := range flags { - for id := range GetAllCohortIDsFromFlag(flag) { - cohortIDs[id] = true + for id := range getAllCohortIDsFromFlag(flag) { + cohortIDs[id] = struct{}{} } } return cohortIDs diff --git a/pkg/experiment/types.go b/pkg/experiment/types.go index 82910b5..d4e6ba5 100644 --- a/pkg/experiment/types.go +++ b/pkg/experiment/types.go @@ -1,24 +1,49 @@ package experiment +import "sync" + const VERSION = "1.5.0" type User struct { - UserId string `json:"user_id,omitempty"` - DeviceId string `json:"device_id,omitempty"` - Country string `json:"country,omitempty"` - Region string `json:"region,omitempty"` - Dma string `json:"dma,omitempty"` - City string `json:"city,omitempty"` - Language string `json:"language,omitempty"` - Platform string `json:"platform,omitempty"` - Version string `json:"version,omitempty"` - Os string `json:"os,omitempty"` - DeviceManufacturer string `json:"device_manufacturer,omitempty"` - DeviceBrand string `json:"device_brand,omitempty"` - DeviceModel string `json:"device_model,omitempty"` - Carrier string `json:"carrier,omitempty"` - Library string `json:"library,omitempty"` - UserProperties map[string]interface{} `json:"user_properties,omitempty"` + UserId string `json:"user_id,omitempty"` + DeviceId string `json:"device_id,omitempty"` + Country string `json:"country,omitempty"` + Region string `json:"region,omitempty"` + Dma string `json:"dma,omitempty"` + City string `json:"city,omitempty"` + Language string `json:"language,omitempty"` + Platform string `json:"platform,omitempty"` + Version string `json:"version,omitempty"` + Os string `json:"os,omitempty"` + DeviceManufacturer string `json:"device_manufacturer,omitempty"` + DeviceBrand string `json:"device_brand,omitempty"` + DeviceModel string `json:"device_model,omitempty"` + Carrier string `json:"carrier,omitempty"` + Library string `json:"library,omitempty"` + UserProperties map[string]interface{} `json:"user_properties,omitempty"` + Groups map[string][]string `json:"groups,omitempty"` + CohortIDs map[string]struct{} `json:"cohort_ids,omitempty"` + GroupCohortIDs map[string]map[string]struct{} `json:"group_cohort_ids,omitempty"` + lock sync.Mutex +} + +func (u *User) AddGroupCohortIDs(groupType, groupName string, cohortIDs map[string]struct{}) { + u.lock.Lock() + defer u.lock.Unlock() + + if u.GroupCohortIDs == nil { + u.GroupCohortIDs = make(map[string]map[string]struct{}) + } + + groupNames := u.GroupCohortIDs[groupType] + if groupNames == nil { + groupNames = make(map[string]struct{}) + u.GroupCohortIDs[groupType] = groupNames + } + + for id := range cohortIDs { + groupNames[id] = struct{}{} + } } type Variant struct { From 287d71fc7002a4e28c99bbe88e1c24cf5a83f608 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 27 Jun 2024 10:44:12 -0700 Subject: [PATCH 03/29] fix tests --- pkg/experiment/local/client.go | 4 +++- pkg/experiment/local/cohort_download_api_test.go | 8 ++++---- pkg/experiment/local/deployment_runner.go | 3 +-- pkg/experiment/local/deployment_runner_test.go | 4 +--- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index 7697d8e..2506697 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -127,7 +127,9 @@ func (c *Client) Evaluate(user *experiment.User, flagKeys []string) (map[string] } func (c *Client) EvaluateV2(user *experiment.User, flagKeys []string) (map[string]experiment.Variant, error) { - userContext := evaluation.UserToContext(user) + flagConfigs := c.flagConfigStorage.GetFlagConfigs() + enrichedUser, err := c.enrichUser(user, flagConfigs) + userContext := evaluation.UserToContext(enrichedUser) c.flagsMutex.RLock() sortedFlags, err := topologicalSort(c.flags, flagKeys) c.flagsMutex.RUnlock() diff --git a/pkg/experiment/local/cohort_download_api_test.go b/pkg/experiment/local/cohort_download_api_test.go index 428f2f6..0d7dcff 100644 --- a/pkg/experiment/local/cohort_download_api_test.go +++ b/pkg/experiment/local/cohort_download_api_test.go @@ -28,7 +28,7 @@ func TestCohortDownloadApi(t *testing.T) { api := NewDirectCohortDownloadApi("api", "secret", 15000, 100, "https://server.amplitude.com", false) t.Run("test_cohort_download_success", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}, GroupType: userGroupType} response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), @@ -47,7 +47,7 @@ func TestCohortDownloadApi(t *testing.T) { }) t.Run("test_cohort_download_many_202s_success", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}, GroupType: userGroupType} response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} for i := 0; i < 9; i++ { @@ -71,7 +71,7 @@ func TestCohortDownloadApi(t *testing.T) { }) t.Run("test_cohort_request_status_with_two_failures_succeeds", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}, GroupType: userGroupType} response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), @@ -96,7 +96,7 @@ func TestCohortDownloadApi(t *testing.T) { }) t.Run("test_cohort_request_status_429s_keep_retrying", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}, GroupType: userGroupType} response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} for i := 0; i < 9; i++ { diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index 0de3358..c6b10c1 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -39,7 +39,6 @@ func NewDeploymentRunner( return dr } -// Start begins the deployment runner's periodic refresh. func (dr *DeploymentRunner) Start() error { dr.lock.Lock() defer dr.lock.Unlock() @@ -98,7 +97,7 @@ func (dr *DeploymentRunner) refresh() error { if err != nil { dr.logger.Printf("Failed to load all cohorts for flag %s. Using the old flag config.", flagConfig.Key) dr.flagConfigStorage.PutFlagConfig(oldFlagConfig) - continue + return err } dr.flagConfigStorage.PutFlagConfig(*flagConfig) diff --git a/pkg/experiment/local/deployment_runner_test.go b/pkg/experiment/local/deployment_runner_test.go index aed7a1c..2f1dc85 100644 --- a/pkg/experiment/local/deployment_runner_test.go +++ b/pkg/experiment/local/deployment_runner_test.go @@ -48,7 +48,7 @@ func TestStartThrowsIfFirstCohortLoadFails(t *testing.T) { cohortLoader := NewCohortLoader(cohortDownloadAPI, cohortStorage) runner := NewDeploymentRunner( - &Config{}, + DefaultConfig, flagAPI, flagConfigStorage, cohortStorage, @@ -62,8 +62,6 @@ func TestStartThrowsIfFirstCohortLoadFails(t *testing.T) { } } -// Mock implementations for interfaces used in tests - type mockFlagConfigApi struct { getFlagConfigsFunc func() (map[string]*evaluation.Flag, error) } From f186f00fc33400ef73e6fa925dbbf1a4ff391c23 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 27 Jun 2024 11:28:57 -0700 Subject: [PATCH 04/29] fix flag storage to use pointer --- internal/evaluation/context.go | 39 +++++++++++++++ pkg/experiment/local/client.go | 22 +++------ pkg/experiment/local/deployment_runner.go | 8 ++-- pkg/experiment/local/flag_config_storage.go | 22 ++++----- pkg/experiment/types.go | 53 ++++++++++----------- 5 files changed, 86 insertions(+), 58 deletions(-) diff --git a/internal/evaluation/context.go b/internal/evaluation/context.go index 62d4147..b73958d 100644 --- a/internal/evaluation/context.go +++ b/internal/evaluation/context.go @@ -8,6 +8,7 @@ func UserToContext(user *experiment.User) map[string]interface{} { } context := make(map[string]interface{}) userMap := make(map[string]interface{}) + if len(user.UserId) != 0 { userMap["user_id"] = user.UserId } @@ -56,6 +57,44 @@ func UserToContext(user *experiment.User) map[string]interface{} { if len(user.UserProperties) != 0 { userMap["user_properties"] = user.UserProperties } + context["user"] = userMap + + if user.Groups == nil { + return context + } + + groups := make(map[string]interface{}) + for groupType, groupNames := range user.Groups { + if len(groupNames) > 0 { + groupName := groupNames[0] + groupNameMap := map[string]interface{}{ + "group_name": groupName, + } + + if user.GroupProperties != nil { + if groupPropertiesType, ok := user.GroupProperties[groupType]; ok { + if groupPropertiesName, ok := groupPropertiesType[groupName]; ok { + groupNameMap["group_properties"] = groupPropertiesName + } + } + } + + if user.GroupCohortIds != nil { + if groupCohortIdsType, ok := user.GroupCohortIds[groupType]; ok { + if groupCohortIdsName, ok := groupCohortIdsType[groupName]; ok { + groupNameMap["cohort_ids"] = groupCohortIdsName + } + } + } + + groups[groupType] = groupNameMap + } + } + + if len(groups) > 0 { + context["groups"] = groups + } + return context } diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index 2506697..ebd5a36 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -61,8 +61,8 @@ func Initialize(apiKey string, config *Config) *Client { if config.CohortSyncConfig != nil { cohortDownloadApi := NewDirectCohortDownloadApi(config.CohortSyncConfig.ApiKey, config.CohortSyncConfig.SecretKey, config.CohortSyncConfig.MaxCohortSize, config.CohortSyncConfig.CohortRequestDelayMillis, config.CohortSyncConfig.CohortServerUrl, config.Debug) cohortLoader = NewCohortLoader(cohortDownloadApi, cohortStorage) - deploymentRunner = NewDeploymentRunner(config, NewFlagConfigApiV2(apiKey, config.ServerUrl, config.FlagConfigPollerRequestTimeout), flagConfigStorage, cohortStorage, cohortLoader) } + deploymentRunner = NewDeploymentRunner(config, NewFlagConfigApiV2(apiKey, config.ServerUrl, config.FlagConfigPollerRequestTimeout), flagConfigStorage, cohortStorage, cohortLoader) client = &Client{ log: log, apiKey: apiKey, @@ -86,20 +86,10 @@ func Initialize(apiKey string, config *Config) *Client { } func (c *Client) Start() error { - result, err := c.doFlagsV2() + err := c.deploymentRunner.Start() if err != nil { return err } - c.flags = result - c.poller.Poll(c.config.FlagConfigPollerInterval, func() { - result, err := c.doFlagsV2() - if err != nil { - return - } - c.flagsMutex.Lock() - c.flags = result - c.flagsMutex.Unlock() - }) return nil } @@ -131,7 +121,7 @@ func (c *Client) EvaluateV2(user *experiment.User, flagKeys []string) (map[strin enrichedUser, err := c.enrichUser(user, flagConfigs) userContext := evaluation.UserToContext(enrichedUser) c.flagsMutex.RLock() - sortedFlags, err := topologicalSort(c.flags, flagKeys) + sortedFlags, err := topologicalSort(flagConfigs, flagKeys) c.flagsMutex.RUnlock() if err != nil { return nil, err @@ -349,17 +339,17 @@ func coerceString(value interface{}) string { return fmt.Sprintf("%v", value) } -func (c *Client) enrichUser(user *experiment.User, flagConfigs map[string]evaluation.Flag) (*experiment.User, error) { +func (c *Client) enrichUser(user *experiment.User, flagConfigs map[string]*evaluation.Flag) (*experiment.User, error) { flagConfigSlice := make([]*evaluation.Flag, 0, len(flagConfigs)) for _, value := range flagConfigs { - flagConfigSlice = append(flagConfigSlice, &value) + flagConfigSlice = append(flagConfigSlice, value) } groupedCohortIDs := getGroupedCohortIDsFromFlags(flagConfigSlice) if cohortIDs, ok := groupedCohortIDs[userGroupType]; ok { if len(cohortIDs) > 0 && user.UserId != "" { - user.CohortIDs = c.cohortStorage.GetCohortsForUser(user.UserId, cohortIDs) + user.CohortIds = c.cohortStorage.GetCohortsForUser(user.UserId, cohortIDs) } } diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index c6b10c1..a972aa2 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -78,7 +78,7 @@ func (dr *DeploymentRunner) refresh() error { flagKeys[flag.Key] = struct{}{} } - dr.flagConfigStorage.RemoveIf(func(f evaluation.Flag) bool { + dr.flagConfigStorage.RemoveIf(func(f *evaluation.Flag) bool { _, exists := flagKeys[f.Key] return !exists }) @@ -87,7 +87,7 @@ func (dr *DeploymentRunner) refresh() error { cohortIDs := getAllCohortIDsFromFlag(flagConfig) if dr.cohortLoader == nil || len(cohortIDs) == 0 { dr.logger.Printf("Putting non-cohort flag %s", flagConfig.Key) - dr.flagConfigStorage.PutFlagConfig(*flagConfig) + dr.flagConfigStorage.PutFlagConfig(flagConfig) continue } @@ -100,7 +100,7 @@ func (dr *DeploymentRunner) refresh() error { return err } - dr.flagConfigStorage.PutFlagConfig(*flagConfig) + dr.flagConfigStorage.PutFlagConfig(flagConfig) dr.logger.Printf("Stored flag config %s", flagConfig.Key) } @@ -135,7 +135,7 @@ func (dr *DeploymentRunner) loadCohorts(flagConfig evaluation.Flag, cohortIDs ma func (dr *DeploymentRunner) deleteUnusedCohorts() { flagCohortIDs := make(map[string]struct{}) for _, flag := range dr.flagConfigStorage.GetFlagConfigs() { - for cohortID := range getAllCohortIDsFromFlag(&flag) { + for cohortID := range getAllCohortIDsFromFlag(flag) { flagCohortIDs[cohortID] = struct{}{} } } diff --git a/pkg/experiment/local/flag_config_storage.go b/pkg/experiment/local/flag_config_storage.go index accb2f2..12229ba 100644 --- a/pkg/experiment/local/flag_config_storage.go +++ b/pkg/experiment/local/flag_config_storage.go @@ -6,46 +6,46 @@ import ( ) type FlagConfigStorage interface { - GetFlagConfig(key string) evaluation.Flag - GetFlagConfigs() map[string]evaluation.Flag - PutFlagConfig(flagConfig evaluation.Flag) - RemoveIf(condition func(evaluation.Flag) bool) + GetFlagConfig(key string) *evaluation.Flag + GetFlagConfigs() map[string]*evaluation.Flag + PutFlagConfig(flagConfig *evaluation.Flag) + RemoveIf(condition func(*evaluation.Flag) bool) } type InMemoryFlagConfigStorage struct { - flagConfigs map[string]evaluation.Flag + flagConfigs map[string]*evaluation.Flag flagConfigsLock sync.Mutex } func NewInMemoryFlagConfigStorage() *InMemoryFlagConfigStorage { return &InMemoryFlagConfigStorage{ - flagConfigs: make(map[string]evaluation.Flag), + flagConfigs: make(map[string]*evaluation.Flag), } } -func (storage *InMemoryFlagConfigStorage) GetFlagConfig(key string) evaluation.Flag { +func (storage *InMemoryFlagConfigStorage) GetFlagConfig(key string) *evaluation.Flag { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() return storage.flagConfigs[key] } -func (storage *InMemoryFlagConfigStorage) GetFlagConfigs() map[string]evaluation.Flag { +func (storage *InMemoryFlagConfigStorage) GetFlagConfigs() map[string]*evaluation.Flag { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() - copyFlagConfigs := make(map[string]evaluation.Flag) + copyFlagConfigs := make(map[string]*evaluation.Flag) for key, value := range storage.flagConfigs { copyFlagConfigs[key] = value } return copyFlagConfigs } -func (storage *InMemoryFlagConfigStorage) PutFlagConfig(flagConfig evaluation.Flag) { +func (storage *InMemoryFlagConfigStorage) PutFlagConfig(flagConfig *evaluation.Flag) { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() storage.flagConfigs[flagConfig.Key] = flagConfig } -func (storage *InMemoryFlagConfigStorage) RemoveIf(condition func(evaluation.Flag) bool) { +func (storage *InMemoryFlagConfigStorage) RemoveIf(condition func(*evaluation.Flag) bool) { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() for key, value := range storage.flagConfigs { diff --git a/pkg/experiment/types.go b/pkg/experiment/types.go index d4e6ba5..19c465a 100644 --- a/pkg/experiment/types.go +++ b/pkg/experiment/types.go @@ -5,25 +5,26 @@ import "sync" const VERSION = "1.5.0" type User struct { - UserId string `json:"user_id,omitempty"` - DeviceId string `json:"device_id,omitempty"` - Country string `json:"country,omitempty"` - Region string `json:"region,omitempty"` - Dma string `json:"dma,omitempty"` - City string `json:"city,omitempty"` - Language string `json:"language,omitempty"` - Platform string `json:"platform,omitempty"` - Version string `json:"version,omitempty"` - Os string `json:"os,omitempty"` - DeviceManufacturer string `json:"device_manufacturer,omitempty"` - DeviceBrand string `json:"device_brand,omitempty"` - DeviceModel string `json:"device_model,omitempty"` - Carrier string `json:"carrier,omitempty"` - Library string `json:"library,omitempty"` - UserProperties map[string]interface{} `json:"user_properties,omitempty"` - Groups map[string][]string `json:"groups,omitempty"` - CohortIDs map[string]struct{} `json:"cohort_ids,omitempty"` - GroupCohortIDs map[string]map[string]struct{} `json:"group_cohort_ids,omitempty"` + UserId string `json:"user_id,omitempty"` + DeviceId string `json:"device_id,omitempty"` + Country string `json:"country,omitempty"` + Region string `json:"region,omitempty"` + Dma string `json:"dma,omitempty"` + City string `json:"city,omitempty"` + Language string `json:"language,omitempty"` + Platform string `json:"platform,omitempty"` + Version string `json:"version,omitempty"` + Os string `json:"os,omitempty"` + DeviceManufacturer string `json:"device_manufacturer,omitempty"` + DeviceBrand string `json:"device_brand,omitempty"` + DeviceModel string `json:"device_model,omitempty"` + Carrier string `json:"carrier,omitempty"` + Library string `json:"library,omitempty"` + UserProperties map[string]interface{} `json:"user_properties,omitempty"` + GroupProperties map[string]map[string]string `json:"group_properties,omitempty"` + Groups map[string][]string `json:"groups,omitempty"` + CohortIds map[string]struct{} `json:"cohort_ids,omitempty"` + GroupCohortIds map[string]map[string]map[string]struct{} `json:"group_cohort_ids,omitempty"` lock sync.Mutex } @@ -31,19 +32,17 @@ func (u *User) AddGroupCohortIDs(groupType, groupName string, cohortIDs map[stri u.lock.Lock() defer u.lock.Unlock() - if u.GroupCohortIDs == nil { - u.GroupCohortIDs = make(map[string]map[string]struct{}) + if u.GroupCohortIds == nil { + u.GroupCohortIds = make(map[string]map[string]map[string]struct{}) } - groupNames := u.GroupCohortIDs[groupType] + groupNames := u.GroupCohortIds[groupType] if groupNames == nil { - groupNames = make(map[string]struct{}) - u.GroupCohortIDs[groupType] = groupNames + groupNames = make(map[string]map[string]struct{}) + u.GroupCohortIds[groupType] = groupNames } - for id := range cohortIDs { - groupNames[id] = struct{}{} - } + groupNames[groupName] = cohortIDs } type Variant struct { From d16abcff2b54afdb50dbf35dcabd4e81452f3ba4 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 1 Jul 2024 09:13:55 -0700 Subject: [PATCH 05/29] fix evaluation context --- internal/evaluation/context.go | 20 ++++++++++++++++++-- pkg/experiment/local/client.go | 4 +--- pkg/experiment/local/cohort.go | 2 +- pkg/experiment/local/cohort_download_api.go | 8 ++------ 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/internal/evaluation/context.go b/internal/evaluation/context.go index b73958d..002c2c1 100644 --- a/internal/evaluation/context.go +++ b/internal/evaluation/context.go @@ -1,6 +1,8 @@ package evaluation -import "github.com/amplitude/experiment-go-server/pkg/experiment" +import ( + "github.com/amplitude/experiment-go-server/pkg/experiment" +) func UserToContext(user *experiment.User) map[string]interface{} { if user == nil { @@ -57,6 +59,12 @@ func UserToContext(user *experiment.User) map[string]interface{} { if len(user.UserProperties) != 0 { userMap["user_properties"] = user.UserProperties } + if len(user.Groups) != 0 { + userMap["groups"] = user.Groups + } + if len(user.CohortIds) != 0 { + userMap["cohort_ids"] = extractKeys(user.CohortIds) + } context["user"] = userMap @@ -83,7 +91,7 @@ func UserToContext(user *experiment.User) map[string]interface{} { if user.GroupCohortIds != nil { if groupCohortIdsType, ok := user.GroupCohortIds[groupType]; ok { if groupCohortIdsName, ok := groupCohortIdsType[groupName]; ok { - groupNameMap["cohort_ids"] = groupCohortIdsName + groupNameMap["cohort_ids"] = extractKeys(groupCohortIdsName) } } } @@ -98,3 +106,11 @@ func UserToContext(user *experiment.User) map[string]interface{} { return context } + +func extractKeys(m map[string]struct{}) []string { + keys := make([]string, 0, len(m)) + for key := range m { + keys = append(keys, key) + } + return keys +} diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index ebd5a36..e993c4b 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -158,9 +158,7 @@ func (c *Client) FlagsV2() (string, error) { // FlagMetadata returns a copy of the flag's metadata. If the flag is not found then nil is returned. func (c *Client) FlagMetadata(flagKey string) map[string]interface{} { - c.flagsMutex.RLock() - f := c.flags[flagKey] - c.flagsMutex.RUnlock() + f := c.flagConfigStorage.GetFlagConfig(flagKey) if f == nil { return nil } diff --git a/pkg/experiment/local/cohort.go b/pkg/experiment/local/cohort.go index 584bcb9..54d2f4d 100644 --- a/pkg/experiment/local/cohort.go +++ b/pkg/experiment/local/cohort.go @@ -2,7 +2,7 @@ package local import "sort" -const userGroupType = "user" +const userGroupType = "User" type Cohort struct { ID string diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index 837886f..14bae26 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -9,10 +9,6 @@ import ( "time" ) -const ( - CdnCohortSyncUrl = "https://cohort-v2.lab.amplitude.com" -) - type HTTPErrorResponseException struct { StatusCode int Message string @@ -87,7 +83,7 @@ func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) ( if response.StatusCode == http.StatusOK { var cohortInfo struct { - CohortId string `json:"Id"` + Id string `json:"cohortId"` LastModified int64 `json:"lastModified"` Size int `json:"size"` MemberIds []string `json:"memberIds"` @@ -98,7 +94,7 @@ func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) ( } api.Logger.Printf("getCohortMembers(%s): end - resultSize=%d", cohortID, cohortInfo.Size) return &Cohort{ - ID: cohortInfo.CohortId, + ID: cohortInfo.Id, LastModified: cohortInfo.LastModified, Size: cohortInfo.Size, MemberIDs: cohortInfo.MemberIds, From d6d6b9a12943ee1d7494747be0e4bb258ab83c77 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 1 Jul 2024 14:14:10 -0700 Subject: [PATCH 06/29] add config tests, refactor exceptions, ignore CohortNotModifiedException --- pkg/experiment/local/cohort_download_api.go | 47 +----- pkg/experiment/local/config.go | 23 ++- pkg/experiment/local/config_test.go | 175 ++++++++++++++++++++ pkg/experiment/local/deployment_runner.go | 7 +- pkg/experiment/local/exception.go | 30 ++++ 5 files changed, 238 insertions(+), 44 deletions(-) create mode 100644 pkg/experiment/local/config_test.go create mode 100644 pkg/experiment/local/exception.go diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index 14bae26..00eaeab 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -9,35 +9,6 @@ import ( "time" ) -type HTTPErrorResponseException struct { - StatusCode int - Message string -} - -func (e *HTTPErrorResponseException) Error() string { - return e.Message -} - -type CohortTooLargeException struct { - Message string -} - -func (e *CohortTooLargeException) Error() string { - return e.Message -} - -type CohortNotModifiedException struct { - Message string -} - -func (e *CohortNotModifiedException) Error() string { - return e.Message -} - -type CohortDownloadApi interface { - GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) -} - type DirectCohortDownloadApi struct { ApiKey string SecretKey string @@ -74,7 +45,14 @@ func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) ( if err != nil { api.Logger.Printf("getCohortMembers(%s): request-status error %d - %v", cohortID, errors, err) errors++ - if errors >= 3 || isSpecificError(err) { + if errors >= 3 || func(err error) bool { + switch err.(type) { + case *CohortNotModifiedException, *CohortTooLargeException: + return true + default: + return false + } + }(err) { return nil, err } time.Sleep(time.Duration(api.CohortRequestDelayMillis) * time.Millisecond) @@ -115,15 +93,6 @@ func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) ( } } -func isSpecificError(err error) bool { - switch err.(type) { - case *CohortNotModifiedException, *CohortTooLargeException: - return true - default: - return false - } -} - func (api *DirectCohortDownloadApi) getCohortMembersRequest(client *http.Client, cohortID string, cohort *Cohort) (*http.Response, error) { req, err := http.NewRequest("GET", api.buildCohortURL(cohortID, cohort), nil) if err != nil { diff --git a/pkg/experiment/local/config.go b/pkg/experiment/local/config.go index 79101c6..443bc5e 100644 --- a/pkg/experiment/local/config.go +++ b/pkg/experiment/local/config.go @@ -2,12 +2,17 @@ package local import ( "github.com/amplitude/analytics-go/amplitude" + "strings" "time" ) +const EUFlagServerUrl = "https://flag.lab.eu.amplitude.com" +const EUCohortSyncUrl = "https://cohort-v2.lab.eu.amplitude.com" + type Config struct { Debug bool ServerUrl string + ServerZone string FlagConfigPollerInterval time.Duration FlagConfigPollerRequestTimeout time.Duration AssignmentConfig *AssignmentConfig @@ -19,7 +24,6 @@ type AssignmentConfig struct { CacheCapacity int } -// CohortSyncConfig holds configuration for cohort synchronization. type CohortSyncConfig struct { ApiKey string SecretKey string @@ -31,6 +35,7 @@ type CohortSyncConfig struct { var DefaultConfig = &Config{ Debug: false, ServerUrl: "https://api.lab.amplitude.com/", + ServerZone: "us", FlagConfigPollerInterval: 30 * time.Second, FlagConfigPollerRequestTimeout: 10 * time.Second, } @@ -49,9 +54,17 @@ func fillConfigDefaults(c *Config) *Config { if c == nil { return DefaultConfig } + if c.ServerZone == "" { + c.ServerZone = DefaultConfig.ServerZone + } if c.ServerUrl == "" { - c.ServerUrl = DefaultConfig.ServerUrl + if strings.ToLower(c.ServerZone) == DefaultConfig.ServerZone { + c.ServerUrl = DefaultConfig.ServerUrl + } else if strings.ToLower(c.ServerZone) == "eu" { + c.ServerUrl = EUFlagServerUrl + } } + if c.FlagConfigPollerInterval == 0 { c.FlagConfigPollerInterval = DefaultConfig.FlagConfigPollerInterval } @@ -71,7 +84,11 @@ func fillConfigDefaults(c *Config) *Config { } if c.CohortSyncConfig != nil && c.CohortSyncConfig.CohortServerUrl == "" { - c.CohortSyncConfig.CohortServerUrl = DefaultCohortSyncConfig.CohortServerUrl + if strings.ToLower(c.ServerZone) == DefaultConfig.ServerZone { + c.CohortSyncConfig.CohortServerUrl = DefaultCohortSyncConfig.CohortServerUrl + } else if strings.ToLower(c.ServerZone) == "eu" { + c.CohortSyncConfig.CohortServerUrl = EUCohortSyncUrl + } } return c diff --git a/pkg/experiment/local/config_test.go b/pkg/experiment/local/config_test.go new file mode 100644 index 0000000..4932281 --- /dev/null +++ b/pkg/experiment/local/config_test.go @@ -0,0 +1,175 @@ +package local + +import ( + "strings" + "testing" + "time" +) + +func TestFillConfigDefaults_ServerZoneAndServerUrl(t *testing.T) { + tests := []struct { + name string + input *Config + expectedZone string + expectedUrl string + }{ + { + name: "Nil config", + input: nil, + expectedZone: DefaultConfig.ServerZone, + expectedUrl: DefaultConfig.ServerUrl, + }, + { + name: "Empty ServerZone", + input: &Config{}, + expectedZone: DefaultConfig.ServerZone, + expectedUrl: DefaultConfig.ServerUrl, + }, + { + name: "ServerZone US", + input: &Config{ServerZone: "us"}, + expectedZone: "us", + expectedUrl: DefaultConfig.ServerUrl, + }, + { + name: "ServerZone EU", + input: &Config{ServerZone: "eu"}, + expectedZone: "eu", + expectedUrl: EUFlagServerUrl, + }, + { + name: "Uppercase ServerZone EU", + input: &Config{ServerZone: "EU"}, + expectedZone: "EU", + expectedUrl: EUFlagServerUrl, + }, + { + name: "Custom ServerUrl", + input: &Config{ServerZone: "us", ServerUrl: "https://custom.url/"}, + expectedZone: "us", + expectedUrl: "https://custom.url/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := fillConfigDefaults(tt.input) + if strings.ToLower(result.ServerZone) != strings.ToLower(tt.expectedZone) { + t.Errorf("expected ServerZone %s, got %s", tt.expectedZone, result.ServerZone) + } + if result.ServerUrl != tt.expectedUrl { + t.Errorf("expected ServerUrl %s, got %s", tt.expectedUrl, result.ServerUrl) + } + }) + } +} + +func TestFillConfigDefaults_CohortSyncConfig(t *testing.T) { + tests := []struct { + name string + input *Config + expectedUrl string + }{ + { + name: "Nil CohortSyncConfig", + input: &Config{ + ServerZone: "eu", + CohortSyncConfig: nil, + }, + expectedUrl: "", + }, + { + name: "CohortSyncConfig with empty CohortServerUrl", + input: &Config{ + ServerZone: "eu", + CohortSyncConfig: &CohortSyncConfig{}, + }, + expectedUrl: EUCohortSyncUrl, + }, + { + name: "CohortSyncConfig with custom CohortServerUrl", + input: &Config{ + ServerZone: "us", + CohortSyncConfig: &CohortSyncConfig{ + CohortServerUrl: "https://custom-cohort.url/", + }, + }, + expectedUrl: "https://custom-cohort.url/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := fillConfigDefaults(tt.input) + if tt.input.CohortSyncConfig == nil { + if result.CohortSyncConfig == nil { + return + } + if result.CohortSyncConfig.CohortServerUrl != tt.expectedUrl { + t.Errorf("expected CohortServerUrl %s, got %s", tt.expectedUrl, result.CohortSyncConfig.CohortServerUrl) + } + } else { + if result.CohortSyncConfig.CohortServerUrl != tt.expectedUrl { + t.Errorf("expected CohortServerUrl %s, got %s", tt.expectedUrl, result.CohortSyncConfig.CohortServerUrl) + } + } + }) + } +} + +func TestFillConfigDefaults_DefaultValues(t *testing.T) { + tests := []struct { + name string + input *Config + expected *Config + }{ + { + name: "Nil config", + input: nil, + expected: DefaultConfig, + }, + { + name: "Empty config", + input: &Config{}, + expected: &Config{ + ServerZone: DefaultConfig.ServerZone, + ServerUrl: DefaultConfig.ServerUrl, + FlagConfigPollerInterval: DefaultConfig.FlagConfigPollerInterval, + FlagConfigPollerRequestTimeout: DefaultConfig.FlagConfigPollerRequestTimeout, + }, + }, + { + name: "Custom values", + input: &Config{ + ServerZone: "eu", + ServerUrl: "https://custom.url/", + FlagConfigPollerInterval: 60 * time.Second, + FlagConfigPollerRequestTimeout: 20 * time.Second, + }, + expected: &Config{ + ServerZone: "eu", + ServerUrl: "https://custom.url/", + FlagConfigPollerInterval: 60 * time.Second, + FlagConfigPollerRequestTimeout: 20 * time.Second, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := fillConfigDefaults(tt.input) + if result.ServerZone != tt.expected.ServerZone { + t.Errorf("expected ServerZone %s, got %s", tt.expected.ServerZone, result.ServerZone) + } + if result.ServerUrl != tt.expected.ServerUrl { + t.Errorf("expected ServerUrl %s, got %s", tt.expected.ServerUrl, result.ServerUrl) + } + if result.FlagConfigPollerInterval != tt.expected.FlagConfigPollerInterval { + t.Errorf("expected FlagConfigPollerInterval %v, got %v", tt.expected.FlagConfigPollerInterval, result.FlagConfigPollerInterval) + } + if result.FlagConfigPollerRequestTimeout != tt.expected.FlagConfigPollerRequestTimeout { + t.Errorf("expected FlagConfigPollerRequestTimeout %v, got %v", tt.expected.FlagConfigPollerRequestTimeout, result.FlagConfigPollerRequestTimeout) + } + }) + } +} diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index a972aa2..bf1e29f 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -115,8 +115,11 @@ func (dr *DeploymentRunner) loadCohorts(flagConfig evaluation.Flag, cohortIDs ma task := dr.cohortLoader.LoadCohort(cohortID) err := task.Wait() if err != nil { - dr.logger.Printf("Failed to load cohort %s for flag %s: %v", cohortID, flagConfig.Key, err) - return err + if _, ok := err.(*CohortNotModifiedException); !ok { + dr.logger.Printf("Failed to load cohort %s for flag %s: %v", cohortID, flagConfig.Key, err) + return err + } + continue } dr.logger.Printf("Cohort %s loaded for flag %s", cohortID, flagConfig.Key) } diff --git a/pkg/experiment/local/exception.go b/pkg/experiment/local/exception.go new file mode 100644 index 0000000..402f0e1 --- /dev/null +++ b/pkg/experiment/local/exception.go @@ -0,0 +1,30 @@ +package local + +type HTTPErrorResponseException struct { + StatusCode int + Message string +} + +func (e *HTTPErrorResponseException) Error() string { + return e.Message +} + +type CohortTooLargeException struct { + Message string +} + +func (e *CohortTooLargeException) Error() string { + return e.Message +} + +type CohortNotModifiedException struct { + Message string +} + +func (e *CohortNotModifiedException) Error() string { + return e.Message +} + +type CohortDownloadApi interface { + GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) +} From 8ba6b45c57915aab2c0c5a7c7cc1455a258bbe40 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 1 Jul 2024 14:28:07 -0700 Subject: [PATCH 07/29] fix logging --- pkg/experiment/local/cohort_download_api.go | 15 ++++------ pkg/experiment/local/deployment_runner.go | 31 ++++++++++----------- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index 00eaeab..b4881ad 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -3,7 +3,7 @@ package local import ( "encoding/base64" "encoding/json" - "log" + "github.com/amplitude/experiment-go-server/internal/logger" "net/http" "strconv" "time" @@ -16,7 +16,7 @@ type DirectCohortDownloadApi struct { CohortRequestDelayMillis int ServerUrl string Debug bool - Logger *log.Logger + log *logger.Log } func NewDirectCohortDownloadApi(apiKey, secretKey string, maxCohortSize, cohortRequestDelayMillis int, serverUrl string, debug bool) *DirectCohortDownloadApi { @@ -27,23 +27,20 @@ func NewDirectCohortDownloadApi(apiKey, secretKey string, maxCohortSize, cohortR CohortRequestDelayMillis: cohortRequestDelayMillis, ServerUrl: serverUrl, Debug: debug, - Logger: log.New(log.Writer(), "Amplitude: ", log.LstdFlags), - } - if debug { - api.Logger.SetFlags(log.LstdFlags | log.Lshortfile) + log: logger.New(debug), } return api } func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) { - api.Logger.Printf("getCohortMembers(%s): start", cohortID) + api.log.Debug("getCohortMembers(%s): start", cohortID) errors := 0 client := &http.Client{} for { response, err := api.getCohortMembersRequest(client, cohortID, cohort) if err != nil { - api.Logger.Printf("getCohortMembers(%s): request-status error %d - %v", cohortID, errors, err) + api.log.Error("getCohortMembers(%s): request-status error %d - %v", cohortID, errors, err) errors++ if errors >= 3 || func(err error) bool { switch err.(type) { @@ -70,7 +67,7 @@ func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) ( if err := json.NewDecoder(response.Body).Decode(&cohortInfo); err != nil { return nil, err } - api.Logger.Printf("getCohortMembers(%s): end - resultSize=%d", cohortID, cohortInfo.Size) + api.log.Debug("getCohortMembers(%s): end - resultSize=%d", cohortID, cohortInfo.Size) return &Cohort{ ID: cohortInfo.Id, LastModified: cohortInfo.LastModified, diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index bf1e29f..63b43ca 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -2,7 +2,7 @@ package local import ( "github.com/amplitude/experiment-go-server/internal/evaluation" - "log" + "github.com/amplitude/experiment-go-server/internal/logger" "sync" ) @@ -14,7 +14,7 @@ type DeploymentRunner struct { cohortLoader *CohortLoader lock sync.Mutex poller *poller - logger *log.Logger + log *logger.Log } func NewDeploymentRunner( @@ -30,10 +30,7 @@ func NewDeploymentRunner( flagConfigStorage: flagConfigStorage, cohortStorage: cohortStorage, cohortLoader: cohortLoader, - logger: log.New(log.Writer(), "Amplitude: ", log.LstdFlags), - } - if config.Debug { - dr.logger.SetFlags(log.LstdFlags | log.Lshortfile) + log: logger.New(config.Debug), } dr.poller = newPoller() return dr @@ -44,13 +41,13 @@ func (dr *DeploymentRunner) Start() error { defer dr.lock.Unlock() if err := dr.refresh(); err != nil { - dr.logger.Printf("Initial refresh failed: %v", err) + dr.log.Error("Initial refresh failed: %v", err) return err } dr.poller.Poll(dr.config.FlagConfigPollerInterval, func() { if err := dr.periodicRefresh(); err != nil { - dr.logger.Printf("Periodic refresh failed: %v", err) + dr.log.Error("Periodic refresh failed: %v", err) } }) return nil @@ -59,17 +56,17 @@ func (dr *DeploymentRunner) Start() error { func (dr *DeploymentRunner) periodicRefresh() error { defer func() { if r := recover(); r != nil { - dr.logger.Printf("Recovered in periodicRefresh: %v", r) + dr.log.Error("Recovered in periodicRefresh: %v", r) } }() return dr.refresh() } func (dr *DeploymentRunner) refresh() error { - dr.logger.Println("Refreshing flag configs.") + dr.log.Debug("Refreshing flag configs.") flagConfigs, err := dr.flagConfigApi.GetFlagConfigs() if err != nil { - dr.logger.Printf("Failed to fetch flag configs: %v", err) + dr.log.Error("Failed to fetch flag configs: %v", err) return err } @@ -86,7 +83,7 @@ func (dr *DeploymentRunner) refresh() error { for _, flagConfig := range flagConfigs { cohortIDs := getAllCohortIDsFromFlag(flagConfig) if dr.cohortLoader == nil || len(cohortIDs) == 0 { - dr.logger.Printf("Putting non-cohort flag %s", flagConfig.Key) + dr.log.Debug("Putting non-cohort flag %s", flagConfig.Key) dr.flagConfigStorage.PutFlagConfig(flagConfig) continue } @@ -95,17 +92,17 @@ func (dr *DeploymentRunner) refresh() error { err := dr.loadCohorts(*flagConfig, cohortIDs) if err != nil { - dr.logger.Printf("Failed to load all cohorts for flag %s. Using the old flag config.", flagConfig.Key) + dr.log.Error("Failed to load all cohorts for flag %s. Using the old flag config.", flagConfig.Key) dr.flagConfigStorage.PutFlagConfig(oldFlagConfig) return err } dr.flagConfigStorage.PutFlagConfig(flagConfig) - dr.logger.Printf("Stored flag config %s", flagConfig.Key) + dr.log.Debug("Stored flag config %s", flagConfig.Key) } dr.deleteUnusedCohorts() - dr.logger.Printf("Refreshed %d flag configs.", len(flagConfigs)) + dr.log.Debug("Refreshed %d flag configs.", len(flagConfigs)) return nil } @@ -116,12 +113,12 @@ func (dr *DeploymentRunner) loadCohorts(flagConfig evaluation.Flag, cohortIDs ma err := task.Wait() if err != nil { if _, ok := err.(*CohortNotModifiedException); !ok { - dr.logger.Printf("Failed to load cohort %s for flag %s: %v", cohortID, flagConfig.Key, err) + dr.log.Error("Failed to load cohort %s for flag %s: %v", cohortID, flagConfig.Key, err) return err } continue } - dr.logger.Printf("Cohort %s loaded for flag %s", cohortID, flagConfig.Key) + dr.log.Debug("Cohort %s loaded for flag %s", cohortID, flagConfig.Key) } return nil } From dcc2a0702e2c0ff8f886fa386157051abaa6e640 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 1 Jul 2024 14:39:10 -0700 Subject: [PATCH 08/29] fix lint --- pkg/experiment/local/client.go | 3 +++ pkg/experiment/local/cohort_loader_test.go | 23 ++++++++++++++++++---- pkg/experiment/local/config.go | 8 ++++---- pkg/experiment/local/config_test.go | 2 +- 4 files changed, 27 insertions(+), 9 deletions(-) diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index e993c4b..6205308 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -119,6 +119,9 @@ func (c *Client) Evaluate(user *experiment.User, flagKeys []string) (map[string] func (c *Client) EvaluateV2(user *experiment.User, flagKeys []string) (map[string]experiment.Variant, error) { flagConfigs := c.flagConfigStorage.GetFlagConfigs() enrichedUser, err := c.enrichUser(user, flagConfigs) + if err != nil { + return nil, err + } userContext := evaluation.UserToContext(enrichedUser) c.flagsMutex.RLock() sortedFlags, err := topologicalSort(flagConfigs, flagKeys) diff --git a/pkg/experiment/local/cohort_loader_test.go b/pkg/experiment/local/cohort_loader_test.go index 25506e7..fe3ab80 100644 --- a/pkg/experiment/local/cohort_loader_test.go +++ b/pkg/experiment/local/cohort_loader_test.go @@ -57,8 +57,15 @@ func TestFilterCohortsAlreadyComputed(t *testing.T) { api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "a", LastModified: 0, Size: 0, MemberIDs: []string{}, GroupType: userGroupType}, nil) api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "b", LastModified: 1, Size: 2, MemberIDs: []string{"1", "2"}, GroupType: userGroupType}, nil) - loader.LoadCohort("a").Wait() - loader.LoadCohort("b").Wait() + futureA := loader.LoadCohort("a") + futureB := loader.LoadCohort("b") + + if err := futureA.Wait(); err != nil { + t.Errorf("futureA.Wait() returned error: %v", err) + } + if err := futureB.Wait(); err != nil { + t.Errorf("futureB.Wait() returned error: %v", err) + } storageDescriptionA := storage.GetCohort("a") storageDescriptionB := storage.GetCohort("b") @@ -89,14 +96,22 @@ func TestLoadDownloadFailureThrows(t *testing.T) { api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(nil, errors.New("connection timed out")) api.On("GetCohort", "c", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "c", LastModified: 0, Size: 1, MemberIDs: []string{"1"}, GroupType: userGroupType}, nil) - loader.LoadCohort("a").Wait() + futureA := loader.LoadCohort("a") errB := loader.LoadCohort("b").Wait() - loader.LoadCohort("c").Wait() + futureC := loader.LoadCohort("c") + + if err := futureA.Wait(); err != nil { + t.Errorf("futureA.Wait() returned error: %v", err) + } if errB == nil || errB.Error() != "connection timed out" { t.Errorf("futureB.Wait() expected 'Connection timed out' error, got: %v", errB) } + if err := futureC.Wait(); err != nil { + t.Errorf("futureC.Wait() returned error: %v", err) + } + expectedCohorts := map[string]struct{}{"a": {}, "c": {}} actualCohorts := storage.GetCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}, "c": {}}) if len(actualCohorts) != len(expectedCohorts) { diff --git a/pkg/experiment/local/config.go b/pkg/experiment/local/config.go index 443bc5e..75c0daf 100644 --- a/pkg/experiment/local/config.go +++ b/pkg/experiment/local/config.go @@ -58,9 +58,9 @@ func fillConfigDefaults(c *Config) *Config { c.ServerZone = DefaultConfig.ServerZone } if c.ServerUrl == "" { - if strings.ToLower(c.ServerZone) == DefaultConfig.ServerZone { + if strings.EqualFold(strings.ToLower(c.ServerZone), strings.ToLower(DefaultConfig.ServerZone)) { c.ServerUrl = DefaultConfig.ServerUrl - } else if strings.ToLower(c.ServerZone) == "eu" { + } else if strings.EqualFold(strings.ToLower(c.ServerZone), "eu") { c.ServerUrl = EUFlagServerUrl } } @@ -84,9 +84,9 @@ func fillConfigDefaults(c *Config) *Config { } if c.CohortSyncConfig != nil && c.CohortSyncConfig.CohortServerUrl == "" { - if strings.ToLower(c.ServerZone) == DefaultConfig.ServerZone { + if strings.EqualFold(strings.ToLower(c.ServerZone), strings.ToLower(DefaultConfig.ServerZone)) { c.CohortSyncConfig.CohortServerUrl = DefaultCohortSyncConfig.CohortServerUrl - } else if strings.ToLower(c.ServerZone) == "eu" { + } else if strings.EqualFold(strings.ToLower(c.ServerZone), "eu") { c.CohortSyncConfig.CohortServerUrl = EUCohortSyncUrl } } diff --git a/pkg/experiment/local/config_test.go b/pkg/experiment/local/config_test.go index 4932281..8622b0c 100644 --- a/pkg/experiment/local/config_test.go +++ b/pkg/experiment/local/config_test.go @@ -54,7 +54,7 @@ func TestFillConfigDefaults_ServerZoneAndServerUrl(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := fillConfigDefaults(tt.input) - if strings.ToLower(result.ServerZone) != strings.ToLower(tt.expectedZone) { + if !strings.EqualFold(result.ServerZone, tt.expectedZone) { t.Errorf("expected ServerZone %s, got %s", tt.expectedZone, result.ServerZone) } if result.ServerUrl != tt.expectedUrl { From 96f5f3299ecca30ba8b1c92a0409029bedbdcf2f Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 1 Jul 2024 14:55:10 -0700 Subject: [PATCH 09/29] fix cohort_download_api_test.go, rename cohort vars --- pkg/experiment/local/cohort.go | 18 ++--- pkg/experiment/local/cohort_download_api.go | 4 +- .../local/cohort_download_api_test.go | 72 +++++++++++++------ pkg/experiment/local/cohort_loader_test.go | 24 +++---- pkg/experiment/local/cohort_storage.go | 6 +- 5 files changed, 78 insertions(+), 46 deletions(-) diff --git a/pkg/experiment/local/cohort.go b/pkg/experiment/local/cohort.go index 54d2f4d..c94660b 100644 --- a/pkg/experiment/local/cohort.go +++ b/pkg/experiment/local/cohort.go @@ -5,27 +5,27 @@ import "sort" const userGroupType = "User" type Cohort struct { - ID string + Id string LastModified int64 Size int - MemberIDs []string + MemberIds []string GroupType string } func CohortEquals(c1, c2 *Cohort) bool { - if c1.ID != c2.ID || c1.LastModified != c2.LastModified || c1.Size != c2.Size || c1.GroupType != c2.GroupType { + if c1.Id != c2.Id || c1.LastModified != c2.LastModified || c1.Size != c2.Size || c1.GroupType != c2.GroupType { return false } - if len(c1.MemberIDs) != len(c2.MemberIDs) { + if len(c1.MemberIds) != len(c2.MemberIds) { return false } - // Sort MemberIDs before comparing - sort.Strings(c1.MemberIDs) - sort.Strings(c2.MemberIDs) + // Sort MemberIds before comparing + sort.Strings(c1.MemberIds) + sort.Strings(c2.MemberIds) - for i := range c1.MemberIDs { - if c1.MemberIDs[i] != c2.MemberIDs[i] { + for i := range c1.MemberIds { + if c1.MemberIds[i] != c2.MemberIds[i] { return false } } diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index b4881ad..ae523dd 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -69,10 +69,10 @@ func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) ( } api.log.Debug("getCohortMembers(%s): end - resultSize=%d", cohortID, cohortInfo.Size) return &Cohort{ - ID: cohortInfo.Id, + Id: cohortInfo.Id, LastModified: cohortInfo.LastModified, Size: cohortInfo.Size, - MemberIDs: cohortInfo.MemberIds, + MemberIds: cohortInfo.MemberIds, GroupType: func() string { if cohortInfo.GroupType == "" { return userGroupType diff --git a/pkg/experiment/local/cohort_download_api_test.go b/pkg/experiment/local/cohort_download_api_test.go index 0d7dcff..4351b67 100644 --- a/pkg/experiment/local/cohort_download_api_test.go +++ b/pkg/experiment/local/cohort_download_api_test.go @@ -13,6 +13,14 @@ type MockCohortDownloadApi struct { mock.Mock } +type cohortInfo struct { + Id string `json:"cohortId"` + LastModified int64 `json:"lastModified"` + Size int `json:"size"` + MemberIds []string `json:"memberIds"` + GroupType string `json:"groupType"` +} + func (m *MockCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) { args := m.Called(cohortID, cohort) if args.Get(0) != nil { @@ -28,8 +36,8 @@ func TestCohortDownloadApi(t *testing.T) { api := NewDirectCohortDownloadApi("api", "secret", 15000, 100, "https://server.amplitude.com", false) t.Run("test_cohort_download_success", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}, GroupType: userGroupType} - response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + response := cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), func(req *http.Request) (*http.Response, error) { @@ -43,12 +51,16 @@ func TestCohortDownloadApi(t *testing.T) { resultCohort, err := api.GetCohort("1234", cohort) assert.NoError(t, err) - assert.Equal(t, cohort, resultCohort) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) }) t.Run("test_cohort_download_many_202s_success", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}, GroupType: userGroupType} - response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + response := &cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} for i := 0; i < 9; i++ { httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), @@ -67,12 +79,16 @@ func TestCohortDownloadApi(t *testing.T) { resultCohort, err := api.GetCohort("1234", cohort) assert.NoError(t, err) - assert.Equal(t, cohort, resultCohort) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) }) t.Run("test_cohort_request_status_with_two_failures_succeeds", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}, GroupType: userGroupType} - response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + response := &cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), httpmock.NewStringResponder(503, ""), @@ -92,12 +108,16 @@ func TestCohortDownloadApi(t *testing.T) { resultCohort, err := api.GetCohort("1234", cohort) assert.NoError(t, err) - assert.Equal(t, cohort, resultCohort) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) }) t.Run("test_cohort_request_status_429s_keep_retrying", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}, GroupType: userGroupType} - response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"user"}} + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + response := &cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} for i := 0; i < 9; i++ { httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), @@ -116,12 +136,16 @@ func TestCohortDownloadApi(t *testing.T) { resultCohort, err := api.GetCohort("1234", cohort) assert.NoError(t, err) - assert.Equal(t, cohort, resultCohort) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) }) t.Run("test_group_cohort_download_success", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"group"}, GroupType: "org name"} - response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"group"}, GroupType: "org name"} + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"group"}, GroupType: "org name"} + response := &cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"group"}, GroupType: "org name"} httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), func(req *http.Request) (*http.Response, error) { @@ -135,12 +159,16 @@ func TestCohortDownloadApi(t *testing.T) { resultCohort, err := api.GetCohort("1234", cohort) assert.NoError(t, err) - assert.Equal(t, cohort, resultCohort) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) }) t.Run("test_group_cohort_request_status_429s_keep_retrying", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"group"}, GroupType: "org name"} - response := &Cohort{ID: "1234", LastModified: 0, Size: 1, MemberIDs: []string{"group"}, GroupType: "org name"} + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"group"}, GroupType: "org name"} + response := &cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"group"}, GroupType: "org name"} for i := 0; i < 9; i++ { httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), @@ -159,11 +187,15 @@ func TestCohortDownloadApi(t *testing.T) { resultCohort, err := api.GetCohort("1234", cohort) assert.NoError(t, err) - assert.Equal(t, cohort, resultCohort) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) }) t.Run("test_cohort_size_too_large", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 0, Size: 16000, MemberIDs: []string{}} + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 16000, MemberIds: []string{}} httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), httpmock.NewStringResponder(413, ""), @@ -176,7 +208,7 @@ func TestCohortDownloadApi(t *testing.T) { }) t.Run("test_cohort_not_modified_exception", func(t *testing.T) { - cohort := &Cohort{ID: "1234", LastModified: 1000, Size: 1, MemberIDs: []string{}} + cohort := &Cohort{Id: "1234", LastModified: 1000, Size: 1, MemberIds: []string{}} httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), httpmock.NewStringResponder(204, ""), diff --git a/pkg/experiment/local/cohort_loader_test.go b/pkg/experiment/local/cohort_loader_test.go index fe3ab80..d2dd799 100644 --- a/pkg/experiment/local/cohort_loader_test.go +++ b/pkg/experiment/local/cohort_loader_test.go @@ -13,8 +13,8 @@ func TestLoadSuccess(t *testing.T) { loader := NewCohortLoader(api, storage) // Define mock behavior - api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "a", LastModified: 0, Size: 1, MemberIDs: []string{"1"}, GroupType: userGroupType}, nil) - api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "b", LastModified: 0, Size: 2, MemberIDs: []string{"1", "2"}, GroupType: userGroupType}, nil) + api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) + api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "b", LastModified: 0, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType}, nil) futureA := loader.LoadCohort("a") futureB := loader.LoadCohort("b") @@ -28,8 +28,8 @@ func TestLoadSuccess(t *testing.T) { storageDescriptionA := storage.GetCohort("a") storageDescriptionB := storage.GetCohort("b") - expectedA := &Cohort{ID: "a", LastModified: 0, Size: 1, MemberIDs: []string{"1"}, GroupType: userGroupType} - expectedB := &Cohort{ID: "b", LastModified: 0, Size: 2, MemberIDs: []string{"1", "2"}, GroupType: userGroupType} + expectedA := &Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType} + expectedB := &Cohort{Id: "b", LastModified: 0, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType} if !CohortEquals(storageDescriptionA, expectedA) { t.Errorf("Unexpected cohort A stored: %+v", storageDescriptionA) @@ -50,12 +50,12 @@ func TestFilterCohortsAlreadyComputed(t *testing.T) { storage := NewInMemoryCohortStorage() loader := NewCohortLoader(api, storage) - storage.PutCohort(&Cohort{ID: "a", LastModified: 0, Size: 0, MemberIDs: []string{}}) - storage.PutCohort(&Cohort{ID: "b", LastModified: 0, Size: 0, MemberIDs: []string{}}) + storage.PutCohort(&Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}}) + storage.PutCohort(&Cohort{Id: "b", LastModified: 0, Size: 0, MemberIds: []string{}}) // Define mock behavior - api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "a", LastModified: 0, Size: 0, MemberIDs: []string{}, GroupType: userGroupType}, nil) - api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "b", LastModified: 1, Size: 2, MemberIDs: []string{"1", "2"}, GroupType: userGroupType}, nil) + api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}, GroupType: userGroupType}, nil) + api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "b", LastModified: 1, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType}, nil) futureA := loader.LoadCohort("a") futureB := loader.LoadCohort("b") @@ -69,8 +69,8 @@ func TestFilterCohortsAlreadyComputed(t *testing.T) { storageDescriptionA := storage.GetCohort("a") storageDescriptionB := storage.GetCohort("b") - expectedA := &Cohort{ID: "a", LastModified: 0, Size: 0, MemberIDs: []string{}, GroupType: userGroupType} - expectedB := &Cohort{ID: "b", LastModified: 1, Size: 2, MemberIDs: []string{"1", "2"}, GroupType: userGroupType} + expectedA := &Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}, GroupType: userGroupType} + expectedB := &Cohort{Id: "b", LastModified: 1, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType} if !CohortEquals(storageDescriptionA, expectedA) { t.Errorf("Unexpected cohort A stored: %+v", storageDescriptionA) @@ -92,9 +92,9 @@ func TestLoadDownloadFailureThrows(t *testing.T) { loader := NewCohortLoader(api, storage) // Define mock behavior - api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "a", LastModified: 0, Size: 1, MemberIDs: []string{"1"}, GroupType: userGroupType}, nil) + api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(nil, errors.New("connection timed out")) - api.On("GetCohort", "c", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{ID: "c", LastModified: 0, Size: 1, MemberIDs: []string{"1"}, GroupType: userGroupType}, nil) + api.On("GetCohort", "c", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "c", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) futureA := loader.LoadCohort("a") errB := loader.LoadCohort("b").Wait() diff --git a/pkg/experiment/local/cohort_storage.go b/pkg/experiment/local/cohort_storage.go index 737e4ed..ed879ce 100644 --- a/pkg/experiment/local/cohort_storage.go +++ b/pkg/experiment/local/cohort_storage.go @@ -59,7 +59,7 @@ func (s *InMemoryCohortStorage) GetCohortsForGroup(groupType, groupName string, for cohortID := range cohortIDs { if _, exists := groupTypeCohorts[cohortID]; exists { if cohort, found := s.cohortStore[cohortID]; found { - for _, memberID := range cohort.MemberIDs { + for _, memberID := range cohort.MemberIds { if memberID == groupName { result[cohortID] = struct{}{} break @@ -78,8 +78,8 @@ func (s *InMemoryCohortStorage) PutCohort(cohort *Cohort) { if _, exists := s.groupToCohortStore[cohort.GroupType]; !exists { s.groupToCohortStore[cohort.GroupType] = make(map[string]struct{}) } - s.groupToCohortStore[cohort.GroupType][cohort.ID] = struct{}{} - s.cohortStore[cohort.ID] = cohort + s.groupToCohortStore[cohort.GroupType][cohort.Id] = struct{}{} + s.cohortStore[cohort.Id] = cohort } func (s *InMemoryCohortStorage) DeleteCohort(groupType, cohortID string) { From 9754ca20c8cd43e3afe011fb8b05c997947e33df Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 1 Jul 2024 15:06:44 -0700 Subject: [PATCH 10/29] fix flag_config_test and handle old flag config not existing in deployment_runner --- pkg/experiment/local/deployment_runner.go | 6 ++++-- pkg/experiment/local/flag_config_test.go | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index 63b43ca..95e88c6 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -92,8 +92,10 @@ func (dr *DeploymentRunner) refresh() error { err := dr.loadCohorts(*flagConfig, cohortIDs) if err != nil { - dr.log.Error("Failed to load all cohorts for flag %s. Using the old flag config.", flagConfig.Key) - dr.flagConfigStorage.PutFlagConfig(oldFlagConfig) + if oldFlagConfig != nil { + dr.log.Error("Failed to load all cohorts for flag %s. Using the old flag config.", flagConfig.Key) + dr.flagConfigStorage.PutFlagConfig(oldFlagConfig) + } return err } diff --git a/pkg/experiment/local/flag_config_test.go b/pkg/experiment/local/flag_config_test.go index 37e3d6f..686e672 100644 --- a/pkg/experiment/local/flag_config_test.go +++ b/pkg/experiment/local/flag_config_test.go @@ -28,7 +28,7 @@ func TestGetAllCohortIDsFromFlag(t *testing.T) { func TestGetGroupedCohortIDsFromFlag(t *testing.T) { flags := getTestFlags() expectedGroupedCohortIDs := map[string][]string{ - "user": {"cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6"}, + "User": {"cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6"}, "group_name": {"cohort7", "cohort8"}, } @@ -66,7 +66,7 @@ func TestGetAllCohortIDsFromFlags(t *testing.T) { func TestGetGroupedCohortIDsFromFlags(t *testing.T) { flags := getTestFlags() expectedGroupedCohortIDs := map[string][]string{ - "user": {"cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6"}, + "User": {"cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6"}, "group_name": {"cohort7", "cohort8"}, } From b50319df1f7bc7571b2df973bb160fc3a3afde9f Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 1 Jul 2024 15:17:27 -0700 Subject: [PATCH 11/29] Update max cohort size, remove unused local flags var --- pkg/experiment/local/client.go | 2 -- pkg/experiment/local/config.go | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index 6205308..bce584e 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -27,7 +27,6 @@ type Client struct { config *Config client *http.Client poller *poller - flags map[string]*evaluation.Flag flagsMutex *sync.RWMutex engine *evaluation.Engine assignmentService *assignmentService @@ -69,7 +68,6 @@ func Initialize(apiKey string, config *Config) *Client { config: config, client: &http.Client{}, poller: newPoller(), - flags: make(map[string]*evaluation.Flag), flagsMutex: &sync.RWMutex{}, engine: evaluation.NewEngine(log), assignmentService: as, diff --git a/pkg/experiment/local/config.go b/pkg/experiment/local/config.go index 75c0daf..94850ff 100644 --- a/pkg/experiment/local/config.go +++ b/pkg/experiment/local/config.go @@ -2,6 +2,7 @@ package local import ( "github.com/amplitude/analytics-go/amplitude" + "math" "strings" "time" ) @@ -45,7 +46,7 @@ var DefaultAssignmentConfig = &AssignmentConfig{ } var DefaultCohortSyncConfig = &CohortSyncConfig{ - MaxCohortSize: 15000, + MaxCohortSize: math.MaxInt32, CohortRequestDelayMillis: 5000, CohortServerUrl: "https://cohort-v2.lab.amplitude.com", } From 1afeeb92d93e049774c731f3a5cee7df9c562e05 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 3 Jul 2024 13:10:59 -0700 Subject: [PATCH 12/29] Refactor deployment_runner, rename classes and methods --- pkg/experiment/local/client.go | 28 +-- pkg/experiment/local/cohort_download_api.go | 18 +- .../local/cohort_download_api_test.go | 20 +-- pkg/experiment/local/cohort_loader.go | 79 ++++++--- pkg/experiment/local/cohort_loader_test.go | 88 +++++----- pkg/experiment/local/cohort_storage.go | 43 +++-- pkg/experiment/local/deployment_runner.go | 164 +++++++++++------- .../local/deployment_runner_test.go | 18 +- pkg/experiment/local/exception.go | 4 - pkg/experiment/local/flag_config_api.go | 10 +- pkg/experiment/local/flag_config_storage.go | 24 +-- 11 files changed, 292 insertions(+), 204 deletions(-) diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index bce584e..5552adc 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -31,9 +31,9 @@ type Client struct { engine *evaluation.Engine assignmentService *assignmentService cohortStorage CohortStorage - flagConfigStorage FlagConfigStorage - cohortLoader *CohortLoader - deploymentRunner *DeploymentRunner + flagConfigStorage flagConfigStorage + cohortLoader *cohortLoader + deploymentRunner *deploymentRunner } func Initialize(apiKey string, config *Config) *Client { @@ -53,15 +53,15 @@ func Initialize(apiKey string, config *Config) *Client { filter: newAssignmentFilter(config.AssignmentConfig.CacheCapacity), } } - cohortStorage := NewInMemoryCohortStorage() + cohortStorage := newInMemoryCohortStorage() flagConfigStorage := NewInMemoryFlagConfigStorage() - var cohortLoader *CohortLoader - var deploymentRunner *DeploymentRunner + var cohortLoader *cohortLoader + var deploymentRunner *deploymentRunner if config.CohortSyncConfig != nil { - cohortDownloadApi := NewDirectCohortDownloadApi(config.CohortSyncConfig.ApiKey, config.CohortSyncConfig.SecretKey, config.CohortSyncConfig.MaxCohortSize, config.CohortSyncConfig.CohortRequestDelayMillis, config.CohortSyncConfig.CohortServerUrl, config.Debug) - cohortLoader = NewCohortLoader(cohortDownloadApi, cohortStorage) + cohortDownloadApi := newDirectCohortDownloadApi(config.CohortSyncConfig.ApiKey, config.CohortSyncConfig.SecretKey, config.CohortSyncConfig.MaxCohortSize, config.CohortSyncConfig.CohortRequestDelayMillis, config.CohortSyncConfig.CohortServerUrl, config.Debug) + cohortLoader = newCohortLoader(cohortDownloadApi, cohortStorage) } - deploymentRunner = NewDeploymentRunner(config, NewFlagConfigApiV2(apiKey, config.ServerUrl, config.FlagConfigPollerRequestTimeout), flagConfigStorage, cohortStorage, cohortLoader) + deploymentRunner = newDeploymentRunner(config, newFlagConfigApiV2(apiKey, config.ServerUrl, config.FlagConfigPollerRequestTimeout), flagConfigStorage, cohortStorage, cohortLoader) client = &Client{ log: log, apiKey: apiKey, @@ -84,7 +84,7 @@ func Initialize(apiKey string, config *Config) *Client { } func (c *Client) Start() error { - err := c.deploymentRunner.Start() + err := c.deploymentRunner.start() if err != nil { return err } @@ -115,7 +115,7 @@ func (c *Client) Evaluate(user *experiment.User, flagKeys []string) (map[string] } func (c *Client) EvaluateV2(user *experiment.User, flagKeys []string) (map[string]experiment.Variant, error) { - flagConfigs := c.flagConfigStorage.GetFlagConfigs() + flagConfigs := c.flagConfigStorage.getFlagConfigs() enrichedUser, err := c.enrichUser(user, flagConfigs) if err != nil { return nil, err @@ -159,7 +159,7 @@ func (c *Client) FlagsV2() (string, error) { // FlagMetadata returns a copy of the flag's metadata. If the flag is not found then nil is returned. func (c *Client) FlagMetadata(flagKey string) map[string]interface{} { - f := c.flagConfigStorage.GetFlagConfig(flagKey) + f := c.flagConfigStorage.getFlagConfig(flagKey) if f == nil { return nil } @@ -348,7 +348,7 @@ func (c *Client) enrichUser(user *experiment.User, flagConfigs map[string]*evalu if cohortIDs, ok := groupedCohortIDs[userGroupType]; ok { if len(cohortIDs) > 0 && user.UserId != "" { - user.CohortIds = c.cohortStorage.GetCohortsForUser(user.UserId, cohortIDs) + user.CohortIds = c.cohortStorage.getCohortsForUser(user.UserId, cohortIDs) } } @@ -362,7 +362,7 @@ func (c *Client) enrichUser(user *experiment.User, flagConfigs map[string]*evalu continue } if cohortIDs, ok := groupedCohortIDs[groupType]; ok { - user.AddGroupCohortIDs(groupType, groupName, c.cohortStorage.GetCohortsForGroup(groupType, groupName, cohortIDs)) + user.AddGroupCohortIDs(groupType, groupName, c.cohortStorage.getCohortsForGroup(groupType, groupName, cohortIDs)) } } } diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index ae523dd..45e5320 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -9,7 +9,11 @@ import ( "time" ) -type DirectCohortDownloadApi struct { +type cohortDownloadApi interface { + getCohort(cohortID string, cohort *Cohort) (*Cohort, error) +} + +type directCohortDownloadApi struct { ApiKey string SecretKey string MaxCohortSize int @@ -19,8 +23,8 @@ type DirectCohortDownloadApi struct { log *logger.Log } -func NewDirectCohortDownloadApi(apiKey, secretKey string, maxCohortSize, cohortRequestDelayMillis int, serverUrl string, debug bool) *DirectCohortDownloadApi { - api := &DirectCohortDownloadApi{ +func newDirectCohortDownloadApi(apiKey, secretKey string, maxCohortSize, cohortRequestDelayMillis int, serverUrl string, debug bool) *directCohortDownloadApi { + api := &directCohortDownloadApi{ ApiKey: apiKey, SecretKey: secretKey, MaxCohortSize: maxCohortSize, @@ -32,7 +36,7 @@ func NewDirectCohortDownloadApi(apiKey, secretKey string, maxCohortSize, cohortR return api } -func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) { +func (api *directCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) (*Cohort, error) { api.log.Debug("getCohortMembers(%s): start", cohortID) errors := 0 client := &http.Client{} @@ -90,7 +94,7 @@ func (api *DirectCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) ( } } -func (api *DirectCohortDownloadApi) getCohortMembersRequest(client *http.Client, cohortID string, cohort *Cohort) (*http.Response, error) { +func (api *directCohortDownloadApi) getCohortMembersRequest(client *http.Client, cohortID string, cohort *Cohort) (*http.Response, error) { req, err := http.NewRequest("GET", api.buildCohortURL(cohortID, cohort), nil) if err != nil { return nil, err @@ -99,12 +103,12 @@ func (api *DirectCohortDownloadApi) getCohortMembersRequest(client *http.Client, return client.Do(req) } -func (api *DirectCohortDownloadApi) getBasicAuth() string { +func (api *directCohortDownloadApi) getBasicAuth() string { auth := api.ApiKey + ":" + api.SecretKey return base64.StdEncoding.EncodeToString([]byte(auth)) } -func (api *DirectCohortDownloadApi) buildCohortURL(cohortID string, cohort *Cohort) string { +func (api *directCohortDownloadApi) buildCohortURL(cohortID string, cohort *Cohort) string { url := api.ServerUrl + "/sdk/v1/cohort/" + cohortID + "?maxCohortSize=" + strconv.Itoa(api.MaxCohortSize) if cohort != nil { url += "&lastModified=" + strconv.FormatInt(cohort.LastModified, 10) diff --git a/pkg/experiment/local/cohort_download_api_test.go b/pkg/experiment/local/cohort_download_api_test.go index 4351b67..8029283 100644 --- a/pkg/experiment/local/cohort_download_api_test.go +++ b/pkg/experiment/local/cohort_download_api_test.go @@ -21,7 +21,7 @@ type cohortInfo struct { GroupType string `json:"groupType"` } -func (m *MockCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) { +func (m *MockCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) (*Cohort, error) { args := m.Called(cohortID, cohort) if args.Get(0) != nil { return args.Get(0).(*Cohort), args.Error(1) @@ -33,7 +33,7 @@ func TestCohortDownloadApi(t *testing.T) { httpmock.Activate() defer httpmock.DeactivateAndReset() - api := NewDirectCohortDownloadApi("api", "secret", 15000, 100, "https://server.amplitude.com", false) + api := newDirectCohortDownloadApi("api", "secret", 15000, 100, "https://server.amplitude.com", false) t.Run("test_cohort_download_success", func(t *testing.T) { cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} @@ -49,7 +49,7 @@ func TestCohortDownloadApi(t *testing.T) { }, ) - resultCohort, err := api.GetCohort("1234", cohort) + resultCohort, err := api.getCohort("1234", cohort) assert.NoError(t, err) assert.Equal(t, cohort.Id, resultCohort.Id) assert.Equal(t, cohort.LastModified, resultCohort.LastModified) @@ -77,7 +77,7 @@ func TestCohortDownloadApi(t *testing.T) { }, ) - resultCohort, err := api.GetCohort("1234", cohort) + resultCohort, err := api.getCohort("1234", cohort) assert.NoError(t, err) assert.Equal(t, cohort.Id, resultCohort.Id) assert.Equal(t, cohort.LastModified, resultCohort.LastModified) @@ -106,7 +106,7 @@ func TestCohortDownloadApi(t *testing.T) { }, ) - resultCohort, err := api.GetCohort("1234", cohort) + resultCohort, err := api.getCohort("1234", cohort) assert.NoError(t, err) assert.Equal(t, cohort.Id, resultCohort.Id) assert.Equal(t, cohort.LastModified, resultCohort.LastModified) @@ -134,7 +134,7 @@ func TestCohortDownloadApi(t *testing.T) { }, ) - resultCohort, err := api.GetCohort("1234", cohort) + resultCohort, err := api.getCohort("1234", cohort) assert.NoError(t, err) assert.Equal(t, cohort.Id, resultCohort.Id) assert.Equal(t, cohort.LastModified, resultCohort.LastModified) @@ -157,7 +157,7 @@ func TestCohortDownloadApi(t *testing.T) { }, ) - resultCohort, err := api.GetCohort("1234", cohort) + resultCohort, err := api.getCohort("1234", cohort) assert.NoError(t, err) assert.Equal(t, cohort.Id, resultCohort.Id) assert.Equal(t, cohort.LastModified, resultCohort.LastModified) @@ -185,7 +185,7 @@ func TestCohortDownloadApi(t *testing.T) { }, ) - resultCohort, err := api.GetCohort("1234", cohort) + resultCohort, err := api.getCohort("1234", cohort) assert.NoError(t, err) assert.Equal(t, cohort.Id, resultCohort.Id) assert.Equal(t, cohort.LastModified, resultCohort.LastModified) @@ -201,7 +201,7 @@ func TestCohortDownloadApi(t *testing.T) { httpmock.NewStringResponder(413, ""), ) - _, err := api.GetCohort("1234", cohort) + _, err := api.getCohort("1234", cohort) assert.Error(t, err) _, isCohortTooLargeException := err.(*CohortTooLargeException) assert.True(t, isCohortTooLargeException) @@ -214,7 +214,7 @@ func TestCohortDownloadApi(t *testing.T) { httpmock.NewStringResponder(204, ""), ) - _, err := api.GetCohort("1234", cohort) + _, err := api.getCohort("1234", cohort) assert.Error(t, err) _, isCohortNotModifiedException := err.(*CohortNotModifiedException) assert.True(t, isCohortNotModifiedException) diff --git a/pkg/experiment/local/cohort_loader.go b/pkg/experiment/local/cohort_loader.go index eefddcd..7023306 100644 --- a/pkg/experiment/local/cohort_loader.go +++ b/pkg/experiment/local/cohort_loader.go @@ -1,20 +1,22 @@ package local import ( + "fmt" + "strings" "sync" "sync/atomic" ) -type CohortLoader struct { - cohortDownloadApi CohortDownloadApi +type cohortLoader struct { + cohortDownloadApi cohortDownloadApi cohortStorage CohortStorage jobs sync.Map executor *sync.Pool lockJobs sync.Mutex } -func NewCohortLoader(cohortDownloadApi CohortDownloadApi, cohortStorage CohortStorage) *CohortLoader { - return &CohortLoader{ +func newCohortLoader(cohortDownloadApi cohortDownloadApi, cohortStorage CohortStorage) *cohortLoader { + return &cohortLoader{ cohortDownloadApi: cohortDownloadApi, cohortStorage: cohortStorage, executor: &sync.Pool{ @@ -25,36 +27,36 @@ func NewCohortLoader(cohortDownloadApi CohortDownloadApi, cohortStorage CohortSt } } -func (cl *CohortLoader) LoadCohort(cohortID string) *CohortLoaderTask { +func (cl *cohortLoader) loadCohort(cohortId string) *CohortLoaderTask { cl.lockJobs.Lock() defer cl.lockJobs.Unlock() - task, ok := cl.jobs.Load(cohortID) + task, ok := cl.jobs.Load(cohortId) if !ok { task = cl.executor.Get().(*CohortLoaderTask) - task.(*CohortLoaderTask).init(cl, cohortID) - cl.jobs.Store(cohortID, task) + task.(*CohortLoaderTask).init(cl, cohortId) + cl.jobs.Store(cohortId, task) go task.(*CohortLoaderTask).run() } return task.(*CohortLoaderTask) } -func (cl *CohortLoader) removeJob(cohortID string) { - cl.jobs.Delete(cohortID) +func (cl *cohortLoader) removeJob(cohortId string) { + cl.jobs.Delete(cohortId) } type CohortLoaderTask struct { - loader *CohortLoader - cohortID string + loader *cohortLoader + cohortId string done int32 doneChan chan struct{} err error } -func (task *CohortLoaderTask) init(loader *CohortLoader, cohortID string) { +func (task *CohortLoaderTask) init(loader *cohortLoader, cohortId string) { task.loader = loader - task.cohortID = cohortID + task.cohortId = cohortId task.done = 0 task.doneChan = make(chan struct{}) task.err = nil @@ -63,24 +65,59 @@ func (task *CohortLoaderTask) init(loader *CohortLoader, cohortID string) { func (task *CohortLoaderTask) run() { defer task.loader.executor.Put(task) - cohort, err := task.loader.downloadCohort(task.cohortID) + cohort, err := task.loader.downloadCohort(task.cohortId) if err != nil { task.err = err } else { - task.loader.cohortStorage.PutCohort(cohort) + task.loader.cohortStorage.putCohort(cohort) } - task.loader.removeJob(task.cohortID) + task.loader.removeJob(task.cohortId) atomic.StoreInt32(&task.done, 1) close(task.doneChan) } -func (task *CohortLoaderTask) Wait() error { +func (task *CohortLoaderTask) wait() error { <-task.doneChan return task.err } -func (cl *CohortLoader) downloadCohort(cohortID string) (*Cohort, error) { - cohort := cl.cohortStorage.GetCohort(cohortID) - return cl.cohortDownloadApi.GetCohort(cohortID, cohort) +func (cl *cohortLoader) downloadCohort(cohortID string) (*Cohort, error) { + cohort := cl.cohortStorage.getCohort(cohortID) + return cl.cohortDownloadApi.getCohort(cohortID, cohort) +} + +func (cl *cohortLoader) updateStoredCohorts() error { + var wg sync.WaitGroup + errorChan := make(chan error, len(cl.cohortStorage.getCohortIds())) + + cohortIds := make([]string, 0, len(cl.cohortStorage.getCohortIds())) + for id := range cl.cohortStorage.getCohortIds() { + cohortIds = append(cohortIds, id) + } + + for _, cohortID := range cohortIds { + wg.Add(1) + go func(id string) { + defer wg.Done() + task := cl.loadCohort(id) + if err := task.wait(); err != nil { + errorChan <- fmt.Errorf("cohort %s: %v", id, err) + } + }(cohortID) + } + + wg.Wait() + close(errorChan) + + var errorMessages []string + for err := range errorChan { + errorMessages = append(errorMessages, err.Error()) + } + + if len(errorMessages) > 0 { + return fmt.Errorf("One or more cohorts failed to download:\n%s", + strings.Join(errorMessages, "\n")) + } + return nil } diff --git a/pkg/experiment/local/cohort_loader_test.go b/pkg/experiment/local/cohort_loader_test.go index d2dd799..4f483e1 100644 --- a/pkg/experiment/local/cohort_loader_test.go +++ b/pkg/experiment/local/cohort_loader_test.go @@ -9,25 +9,25 @@ import ( func TestLoadSuccess(t *testing.T) { api := &MockCohortDownloadApi{} - storage := NewInMemoryCohortStorage() - loader := NewCohortLoader(api, storage) + storage := newInMemoryCohortStorage() + loader := newCohortLoader(api, storage) // Define mock behavior - api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) - api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "b", LastModified: 0, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType}, nil) + api.On("getCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) + api.On("getCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "b", LastModified: 0, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType}, nil) - futureA := loader.LoadCohort("a") - futureB := loader.LoadCohort("b") + futureA := loader.loadCohort("a") + futureB := loader.loadCohort("b") - if err := futureA.Wait(); err != nil { - t.Errorf("futureA.Wait() returned error: %v", err) + if err := futureA.wait(); err != nil { + t.Errorf("futureA.wait() returned error: %v", err) } - if err := futureB.Wait(); err != nil { - t.Errorf("futureB.Wait() returned error: %v", err) + if err := futureB.wait(); err != nil { + t.Errorf("futureB.wait() returned error: %v", err) } - storageDescriptionA := storage.GetCohort("a") - storageDescriptionB := storage.GetCohort("b") + storageDescriptionA := storage.getCohort("a") + storageDescriptionB := storage.getCohort("b") expectedA := &Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType} expectedB := &Cohort{Id: "b", LastModified: 0, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType} @@ -38,8 +38,8 @@ func TestLoadSuccess(t *testing.T) { t.Errorf("Unexpected cohort B stored: %+v", storageDescriptionB) } - storageUser1Cohorts := storage.GetCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}}) - storageUser2Cohorts := storage.GetCohortsForUser("2", map[string]struct{}{"a": {}, "b": {}}) + storageUser1Cohorts := storage.getCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}}) + storageUser2Cohorts := storage.getCohortsForUser("2", map[string]struct{}{"a": {}, "b": {}}) if len(storageUser1Cohorts) != 2 || len(storageUser2Cohorts) != 1 { t.Errorf("Unexpected user cohorts: User1: %+v, User2: %+v", storageUser1Cohorts, storageUser2Cohorts) } @@ -47,28 +47,28 @@ func TestLoadSuccess(t *testing.T) { func TestFilterCohortsAlreadyComputed(t *testing.T) { api := &MockCohortDownloadApi{} - storage := NewInMemoryCohortStorage() - loader := NewCohortLoader(api, storage) + storage := newInMemoryCohortStorage() + loader := newCohortLoader(api, storage) - storage.PutCohort(&Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}}) - storage.PutCohort(&Cohort{Id: "b", LastModified: 0, Size: 0, MemberIds: []string{}}) + storage.putCohort(&Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}}) + storage.putCohort(&Cohort{Id: "b", LastModified: 0, Size: 0, MemberIds: []string{}}) // Define mock behavior - api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}, GroupType: userGroupType}, nil) - api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "b", LastModified: 1, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType}, nil) + api.On("getCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}, GroupType: userGroupType}, nil) + api.On("getCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "b", LastModified: 1, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType}, nil) - futureA := loader.LoadCohort("a") - futureB := loader.LoadCohort("b") + futureA := loader.loadCohort("a") + futureB := loader.loadCohort("b") - if err := futureA.Wait(); err != nil { - t.Errorf("futureA.Wait() returned error: %v", err) + if err := futureA.wait(); err != nil { + t.Errorf("futureA.wait() returned error: %v", err) } - if err := futureB.Wait(); err != nil { - t.Errorf("futureB.Wait() returned error: %v", err) + if err := futureB.wait(); err != nil { + t.Errorf("futureB.wait() returned error: %v", err) } - storageDescriptionA := storage.GetCohort("a") - storageDescriptionB := storage.GetCohort("b") + storageDescriptionA := storage.getCohort("a") + storageDescriptionB := storage.getCohort("b") expectedA := &Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}, GroupType: userGroupType} expectedB := &Cohort{Id: "b", LastModified: 1, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType} @@ -79,8 +79,8 @@ func TestFilterCohortsAlreadyComputed(t *testing.T) { t.Errorf("Unexpected cohort B stored: %+v", storageDescriptionB) } - storageUser1Cohorts := storage.GetCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}}) - storageUser2Cohorts := storage.GetCohortsForUser("2", map[string]struct{}{"a": {}, "b": {}}) + storageUser1Cohorts := storage.getCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}}) + storageUser2Cohorts := storage.getCohortsForUser("2", map[string]struct{}{"a": {}, "b": {}}) if len(storageUser1Cohorts) != 1 || len(storageUser2Cohorts) != 1 { t.Errorf("Unexpected user cohorts: User1: %+v, User2: %+v", storageUser1Cohorts, storageUser2Cohorts) } @@ -88,32 +88,32 @@ func TestFilterCohortsAlreadyComputed(t *testing.T) { func TestLoadDownloadFailureThrows(t *testing.T) { api := &MockCohortDownloadApi{} - storage := NewInMemoryCohortStorage() - loader := NewCohortLoader(api, storage) + storage := newInMemoryCohortStorage() + loader := newCohortLoader(api, storage) // Define mock behavior - api.On("GetCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) - api.On("GetCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(nil, errors.New("connection timed out")) - api.On("GetCohort", "c", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "c", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) + api.On("getCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) + api.On("getCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(nil, errors.New("connection timed out")) + api.On("getCohort", "c", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "c", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) - futureA := loader.LoadCohort("a") - errB := loader.LoadCohort("b").Wait() - futureC := loader.LoadCohort("c") + futureA := loader.loadCohort("a") + errB := loader.loadCohort("b").wait() + futureC := loader.loadCohort("c") - if err := futureA.Wait(); err != nil { - t.Errorf("futureA.Wait() returned error: %v", err) + if err := futureA.wait(); err != nil { + t.Errorf("futureA.wait() returned error: %v", err) } if errB == nil || errB.Error() != "connection timed out" { - t.Errorf("futureB.Wait() expected 'Connection timed out' error, got: %v", errB) + t.Errorf("futureB.wait() expected 'Connection timed out' error, got: %v", errB) } - if err := futureC.Wait(); err != nil { - t.Errorf("futureC.Wait() returned error: %v", err) + if err := futureC.wait(); err != nil { + t.Errorf("futureC.wait() returned error: %v", err) } expectedCohorts := map[string]struct{}{"a": {}, "c": {}} - actualCohorts := storage.GetCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}, "c": {}}) + actualCohorts := storage.getCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}, "c": {}}) if len(actualCohorts) != len(expectedCohorts) { t.Errorf("Expected cohorts for user '1': %+v, but got: %+v", expectedCohorts, actualCohorts) } diff --git a/pkg/experiment/local/cohort_storage.go b/pkg/experiment/local/cohort_storage.go index ed879ce..66981e5 100644 --- a/pkg/experiment/local/cohort_storage.go +++ b/pkg/experiment/local/cohort_storage.go @@ -5,34 +5,35 @@ import ( ) type CohortStorage interface { - GetCohort(cohortID string) *Cohort - GetCohorts() map[string]*Cohort - GetCohortsForUser(userID string, cohortIDs map[string]struct{}) map[string]struct{} - GetCohortsForGroup(groupType, groupName string, cohortIDs map[string]struct{}) map[string]struct{} - PutCohort(cohort *Cohort) - DeleteCohort(groupType, cohortID string) + getCohort(cohortID string) *Cohort + getCohorts() map[string]*Cohort + getCohortsForUser(userID string, cohortIDs map[string]struct{}) map[string]struct{} + getCohortsForGroup(groupType, groupName string, cohortIDs map[string]struct{}) map[string]struct{} + putCohort(cohort *Cohort) + deleteCohort(groupType, cohortID string) + getCohortIds() map[string]struct{} } -type InMemoryCohortStorage struct { +type inMemoryCohortStorage struct { lock sync.RWMutex groupToCohortStore map[string]map[string]struct{} cohortStore map[string]*Cohort } -func NewInMemoryCohortStorage() *InMemoryCohortStorage { - return &InMemoryCohortStorage{ +func newInMemoryCohortStorage() *inMemoryCohortStorage { + return &inMemoryCohortStorage{ groupToCohortStore: make(map[string]map[string]struct{}), cohortStore: make(map[string]*Cohort), } } -func (s *InMemoryCohortStorage) GetCohort(cohortID string) *Cohort { +func (s *inMemoryCohortStorage) getCohort(cohortID string) *Cohort { s.lock.RLock() defer s.lock.RUnlock() return s.cohortStore[cohortID] } -func (s *InMemoryCohortStorage) GetCohorts() map[string]*Cohort { +func (s *inMemoryCohortStorage) getCohorts() map[string]*Cohort { s.lock.RLock() defer s.lock.RUnlock() cohorts := make(map[string]*Cohort) @@ -42,11 +43,11 @@ func (s *InMemoryCohortStorage) GetCohorts() map[string]*Cohort { return cohorts } -func (s *InMemoryCohortStorage) GetCohortsForUser(userID string, cohortIDs map[string]struct{}) map[string]struct{} { - return s.GetCohortsForGroup(userGroupType, userID, cohortIDs) +func (s *inMemoryCohortStorage) getCohortsForUser(userID string, cohortIDs map[string]struct{}) map[string]struct{} { + return s.getCohortsForGroup(userGroupType, userID, cohortIDs) } -func (s *InMemoryCohortStorage) GetCohortsForGroup(groupType, groupName string, cohortIDs map[string]struct{}) map[string]struct{} { +func (s *inMemoryCohortStorage) getCohortsForGroup(groupType, groupName string, cohortIDs map[string]struct{}) map[string]struct{} { result := make(map[string]struct{}) s.lock.RLock() defer s.lock.RUnlock() @@ -72,7 +73,7 @@ func (s *InMemoryCohortStorage) GetCohortsForGroup(groupType, groupName string, return result } -func (s *InMemoryCohortStorage) PutCohort(cohort *Cohort) { +func (s *inMemoryCohortStorage) putCohort(cohort *Cohort) { s.lock.Lock() defer s.lock.Unlock() if _, exists := s.groupToCohortStore[cohort.GroupType]; !exists { @@ -82,7 +83,7 @@ func (s *InMemoryCohortStorage) PutCohort(cohort *Cohort) { s.cohortStore[cohort.Id] = cohort } -func (s *InMemoryCohortStorage) DeleteCohort(groupType, cohortID string) { +func (s *inMemoryCohortStorage) deleteCohort(groupType, cohortID string) { s.lock.Lock() defer s.lock.Unlock() if groupCohorts, exists := s.groupToCohortStore[groupType]; exists { @@ -93,3 +94,13 @@ func (s *InMemoryCohortStorage) DeleteCohort(groupType, cohortID string) { } delete(s.cohortStore, cohortID) } + +func (s *inMemoryCohortStorage) getCohortIds() map[string]struct{} { + s.lock.RLock() + defer s.lock.RUnlock() + cohortIds := make(map[string]struct{}) + for id := range s.cohortStore { + cohortIds[id] = struct{}{} + } + return cohortIds +} diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index 95e88c6..cf82a2f 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -1,30 +1,33 @@ package local import ( + "fmt" + "strings" + "sync" + "github.com/amplitude/experiment-go-server/internal/evaluation" "github.com/amplitude/experiment-go-server/internal/logger" - "sync" ) -type DeploymentRunner struct { +type deploymentRunner struct { config *Config - flagConfigApi FlagConfigApi - flagConfigStorage FlagConfigStorage + flagConfigApi flagConfigApi + flagConfigStorage flagConfigStorage cohortStorage CohortStorage - cohortLoader *CohortLoader + cohortLoader *cohortLoader lock sync.Mutex poller *poller log *logger.Log } -func NewDeploymentRunner( +func newDeploymentRunner( config *Config, - flagConfigApi FlagConfigApi, - flagConfigStorage FlagConfigStorage, + flagConfigApi flagConfigApi, + flagConfigStorage flagConfigStorage, cohortStorage CohortStorage, - cohortLoader *CohortLoader, -) *DeploymentRunner { - dr := &DeploymentRunner{ + cohortLoader *cohortLoader, +) *deploymentRunner { + dr := &deploymentRunner{ config: config, flagConfigApi: flagConfigApi, flagConfigStorage: flagConfigStorage, @@ -36,33 +39,39 @@ func NewDeploymentRunner( return dr } -func (dr *DeploymentRunner) Start() error { +func (dr *deploymentRunner) start() error { dr.lock.Lock() defer dr.lock.Unlock() - if err := dr.refresh(); err != nil { - dr.log.Error("Initial refresh failed: %v", err) + if err := dr.updateFlagConfigs(); err != nil { + dr.log.Error("Initial updateFlagConfigs failed: %v", err) return err } dr.poller.Poll(dr.config.FlagConfigPollerInterval, func() { if err := dr.periodicRefresh(); err != nil { - dr.log.Error("Periodic refresh failed: %v", err) + dr.log.Error("Periodic updateFlagConfigs failed: %v", err) } }) + + if dr.cohortLoader != nil { + dr.poller.Poll(dr.config.FlagConfigPollerInterval, func() { + dr.updateStoredCohorts() + }) + } return nil } -func (dr *DeploymentRunner) periodicRefresh() error { +func (dr *deploymentRunner) periodicRefresh() error { defer func() { if r := recover(); r != nil { dr.log.Error("Recovered in periodicRefresh: %v", r) } }() - return dr.refresh() + return dr.updateFlagConfigs() } -func (dr *DeploymentRunner) refresh() error { +func (dr *deploymentRunner) updateFlagConfigs() error { dr.log.Debug("Refreshing flag configs.") flagConfigs, err := dr.flagConfigApi.GetFlagConfigs() if err != nil { @@ -75,80 +84,111 @@ func (dr *DeploymentRunner) refresh() error { flagKeys[flag.Key] = struct{}{} } - dr.flagConfigStorage.RemoveIf(func(f *evaluation.Flag) bool { + dr.flagConfigStorage.removeIf(func(f *evaluation.Flag) bool { _, exists := flagKeys[f.Key] return !exists }) - for _, flagConfig := range flagConfigs { - cohortIDs := getAllCohortIDsFromFlag(flagConfig) - if dr.cohortLoader == nil || len(cohortIDs) == 0 { + if dr.cohortLoader == nil { + for _, flagConfig := range flagConfigs { dr.log.Debug("Putting non-cohort flag %s", flagConfig.Key) - dr.flagConfigStorage.PutFlagConfig(flagConfig) - continue + dr.flagConfigStorage.putFlagConfig(flagConfig) + } + return nil + } + + newCohortIDs := make(map[string]struct{}) + for _, flagConfig := range flagConfigs { + for cohortID := range getAllCohortIDsFromFlag(flagConfig) { + newCohortIDs[cohortID] = struct{}{} } + } - oldFlagConfig := dr.flagConfigStorage.GetFlagConfig(flagConfig.Key) + existingCohortIDs := dr.cohortStorage.getCohortIds() + cohortIDsToDownload := difference(newCohortIDs, existingCohortIDs) + var cohortDownloadErrors []string - err := dr.loadCohorts(*flagConfig, cohortIDs) - if err != nil { - if oldFlagConfig != nil { - dr.log.Error("Failed to load all cohorts for flag %s. Using the old flag config.", flagConfig.Key) - dr.flagConfigStorage.PutFlagConfig(oldFlagConfig) - } - return err + // Download all new cohorts + for cohortID := range cohortIDsToDownload { + if err := dr.cohortLoader.loadCohort(cohortID).wait(); err != nil { + cohortDownloadErrors = append(cohortDownloadErrors, fmt.Sprintf("Cohort %s: %v", cohortID, err)) + dr.log.Error("Download cohort %s failed: %v", cohortID, err) } + } - dr.flagConfigStorage.PutFlagConfig(flagConfig) - dr.log.Debug("Stored flag config %s", flagConfig.Key) + // Get updated set of cohort ids + updatedCohortIDs := dr.cohortStorage.getCohortIds() + // Iterate through new flag configs and check if their required cohorts exist + failedFlagCount := 0 + for _, flagConfig := range flagConfigs { + cohortIDs := getAllCohortIDsFromFlag(flagConfig) + if len(cohortIDs) == 0 || dr.cohortLoader == nil { + dr.flagConfigStorage.putFlagConfig(flagConfig) + dr.log.Debug("Putting non-cohort flag %s", flagConfig.Key) + } else if subset(cohortIDs, updatedCohortIDs) { + dr.flagConfigStorage.putFlagConfig(flagConfig) + dr.log.Debug("Putting flag %s", flagConfig.Key) + } else { + dr.log.Error("Flag %s not updated because not all required cohorts could be loaded", flagConfig.Key) + failedFlagCount++ + } } + // Delete unused cohorts dr.deleteUnusedCohorts() - dr.log.Debug("Refreshed %d flag configs.", len(flagConfigs)) + dr.log.Debug("Refreshed %d flag configs.", len(flagConfigs)-failedFlagCount) + + // If there are any download errors, raise an aggregated exception + if len(cohortDownloadErrors) > 0 { + errorCount := len(cohortDownloadErrors) + errorMessages := strings.Join(cohortDownloadErrors, "\n") + return fmt.Errorf("%d cohort(s) failed to download:\n%s", errorCount, errorMessages) + } + return nil } -func (dr *DeploymentRunner) loadCohorts(flagConfig evaluation.Flag, cohortIDs map[string]struct{}) error { - task := func() error { - for cohortID := range cohortIDs { - task := dr.cohortLoader.LoadCohort(cohortID) - err := task.Wait() - if err != nil { - if _, ok := err.(*CohortNotModifiedException); !ok { - dr.log.Error("Failed to load cohort %s for flag %s: %v", cohortID, flagConfig.Key, err) - return err - } - continue - } - dr.log.Debug("Cohort %s loaded for flag %s", cohortID, flagConfig.Key) - } - return nil +func (dr *deploymentRunner) updateStoredCohorts() { + err := dr.cohortLoader.updateStoredCohorts() + if err != nil { + dr.log.Error("Error updating stored cohorts: %v", err) } - - // Using a goroutine to simulate async task execution - errCh := make(chan error) - go func() { - errCh <- task() - }() - err := <-errCh - return err } -func (dr *DeploymentRunner) deleteUnusedCohorts() { +func (dr *deploymentRunner) deleteUnusedCohorts() { flagCohortIDs := make(map[string]struct{}) - for _, flag := range dr.flagConfigStorage.GetFlagConfigs() { + for _, flag := range dr.flagConfigStorage.getFlagConfigs() { for cohortID := range getAllCohortIDsFromFlag(flag) { flagCohortIDs[cohortID] = struct{}{} } } - storageCohorts := dr.cohortStorage.GetCohorts() + storageCohorts := dr.cohortStorage.getCohorts() for cohortID := range storageCohorts { if _, exists := flagCohortIDs[cohortID]; !exists { cohort := storageCohorts[cohortID] if cohort != nil { - dr.cohortStorage.DeleteCohort(cohort.GroupType, cohortID) + dr.cohortStorage.deleteCohort(cohort.GroupType, cohortID) } } } } + +func difference(set1, set2 map[string]struct{}) map[string]struct{} { + diff := make(map[string]struct{}) + for k := range set1 { + if _, exists := set2[k]; !exists { + diff[k] = struct{}{} + } + } + return diff +} + +func subset(subset, set map[string]struct{}) bool { + for k := range subset { + if _, exists := set[k]; !exists { + return false + } + } + return true +} diff --git a/pkg/experiment/local/deployment_runner_test.go b/pkg/experiment/local/deployment_runner_test.go index 2f1dc85..6eb59da 100644 --- a/pkg/experiment/local/deployment_runner_test.go +++ b/pkg/experiment/local/deployment_runner_test.go @@ -18,10 +18,10 @@ func TestStartThrowsIfFirstFlagConfigLoadFails(t *testing.T) { }} cohortDownloadAPI := &mockCohortDownloadApi{} flagConfigStorage := NewInMemoryFlagConfigStorage() - cohortStorage := NewInMemoryCohortStorage() - cohortLoader := NewCohortLoader(cohortDownloadAPI, cohortStorage) + cohortStorage := newInMemoryCohortStorage() + cohortLoader := newCohortLoader(cohortDownloadAPI, cohortStorage) - runner := NewDeploymentRunner( + runner := newDeploymentRunner( &Config{}, flagAPI, flagConfigStorage, @@ -29,7 +29,7 @@ func TestStartThrowsIfFirstFlagConfigLoadFails(t *testing.T) { cohortLoader, ) - err := runner.Start() + err := runner.start() if err == nil { t.Error("Expected error but got nil") @@ -44,10 +44,10 @@ func TestStartThrowsIfFirstCohortLoadFails(t *testing.T) { return nil, errors.New("test") }} flagConfigStorage := NewInMemoryFlagConfigStorage() - cohortStorage := NewInMemoryCohortStorage() - cohortLoader := NewCohortLoader(cohortDownloadAPI, cohortStorage) + cohortStorage := newInMemoryCohortStorage() + cohortLoader := newCohortLoader(cohortDownloadAPI, cohortStorage) - runner := NewDeploymentRunner( + runner := newDeploymentRunner( DefaultConfig, flagAPI, flagConfigStorage, @@ -55,7 +55,7 @@ func TestStartThrowsIfFirstCohortLoadFails(t *testing.T) { cohortLoader, ) - err := runner.Start() + err := runner.start() if err == nil { t.Error("Expected error but got nil") @@ -77,7 +77,7 @@ type mockCohortDownloadApi struct { getCohortFunc func(cohortID string, cohort *Cohort) (*Cohort, error) } -func (m *mockCohortDownloadApi) GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) { +func (m *mockCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) (*Cohort, error) { if m.getCohortFunc != nil { return m.getCohortFunc(cohortID, cohort) } diff --git a/pkg/experiment/local/exception.go b/pkg/experiment/local/exception.go index 402f0e1..c151916 100644 --- a/pkg/experiment/local/exception.go +++ b/pkg/experiment/local/exception.go @@ -24,7 +24,3 @@ type CohortNotModifiedException struct { func (e *CohortNotModifiedException) Error() string { return e.Message } - -type CohortDownloadApi interface { - GetCohort(cohortID string, cohort *Cohort) (*Cohort, error) -} diff --git a/pkg/experiment/local/flag_config_api.go b/pkg/experiment/local/flag_config_api.go index f9d8649..d6ff58a 100644 --- a/pkg/experiment/local/flag_config_api.go +++ b/pkg/experiment/local/flag_config_api.go @@ -12,25 +12,25 @@ import ( "time" ) -type FlagConfigApi interface { +type flagConfigApi interface { GetFlagConfigs() (map[string]*evaluation.Flag, error) } -type FlagConfigApiV2 struct { +type flagConfigApiV2 struct { DeploymentKey string ServerURL string FlagConfigPollerRequestTimeoutMillis time.Duration } -func NewFlagConfigApiV2(deploymentKey, serverURL string, flagConfigPollerRequestTimeoutMillis time.Duration) *FlagConfigApiV2 { - return &FlagConfigApiV2{ +func newFlagConfigApiV2(deploymentKey, serverURL string, flagConfigPollerRequestTimeoutMillis time.Duration) *flagConfigApiV2 { + return &flagConfigApiV2{ DeploymentKey: deploymentKey, ServerURL: serverURL, FlagConfigPollerRequestTimeoutMillis: flagConfigPollerRequestTimeoutMillis, } } -func (a *FlagConfigApiV2) GetFlagConfigs() (map[string]*evaluation.Flag, error) { +func (a *flagConfigApiV2) GetFlagConfigs() (map[string]*evaluation.Flag, error) { client := &http.Client{} endpoint, err := url.Parse("https://api.lab.amplitude.com/") if err != nil { diff --git a/pkg/experiment/local/flag_config_storage.go b/pkg/experiment/local/flag_config_storage.go index 12229ba..69ad9d3 100644 --- a/pkg/experiment/local/flag_config_storage.go +++ b/pkg/experiment/local/flag_config_storage.go @@ -5,31 +5,31 @@ import ( "sync" ) -type FlagConfigStorage interface { - GetFlagConfig(key string) *evaluation.Flag - GetFlagConfigs() map[string]*evaluation.Flag - PutFlagConfig(flagConfig *evaluation.Flag) - RemoveIf(condition func(*evaluation.Flag) bool) +type flagConfigStorage interface { + getFlagConfig(key string) *evaluation.Flag + getFlagConfigs() map[string]*evaluation.Flag + putFlagConfig(flagConfig *evaluation.Flag) + removeIf(condition func(*evaluation.Flag) bool) } -type InMemoryFlagConfigStorage struct { +type inMemoryFlagConfigStorage struct { flagConfigs map[string]*evaluation.Flag flagConfigsLock sync.Mutex } -func NewInMemoryFlagConfigStorage() *InMemoryFlagConfigStorage { - return &InMemoryFlagConfigStorage{ +func NewInMemoryFlagConfigStorage() *inMemoryFlagConfigStorage { + return &inMemoryFlagConfigStorage{ flagConfigs: make(map[string]*evaluation.Flag), } } -func (storage *InMemoryFlagConfigStorage) GetFlagConfig(key string) *evaluation.Flag { +func (storage *inMemoryFlagConfigStorage) getFlagConfig(key string) *evaluation.Flag { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() return storage.flagConfigs[key] } -func (storage *InMemoryFlagConfigStorage) GetFlagConfigs() map[string]*evaluation.Flag { +func (storage *inMemoryFlagConfigStorage) getFlagConfigs() map[string]*evaluation.Flag { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() copyFlagConfigs := make(map[string]*evaluation.Flag) @@ -39,13 +39,13 @@ func (storage *InMemoryFlagConfigStorage) GetFlagConfigs() map[string]*evaluatio return copyFlagConfigs } -func (storage *InMemoryFlagConfigStorage) PutFlagConfig(flagConfig *evaluation.Flag) { +func (storage *inMemoryFlagConfigStorage) putFlagConfig(flagConfig *evaluation.Flag) { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() storage.flagConfigs[flagConfig.Key] = flagConfig } -func (storage *InMemoryFlagConfigStorage) RemoveIf(condition func(*evaluation.Flag) bool) { +func (storage *inMemoryFlagConfigStorage) removeIf(condition func(*evaluation.Flag) bool) { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() for key, value := range storage.flagConfigs { From 1df8194a2b2f7c9e342b1fd782a5e07b29e18084 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 3 Jul 2024 13:22:20 -0700 Subject: [PATCH 13/29] Fix class/method names --- pkg/experiment/local/client.go | 4 ++-- pkg/experiment/local/deployment_runner.go | 2 +- pkg/experiment/local/deployment_runner_test.go | 6 +++--- pkg/experiment/local/flag_config_api.go | 4 ++-- pkg/experiment/local/flag_config_storage.go | 2 +- pkg/experiment/local/flag_config_util.go | 7 ------- pkg/experiment/types.go | 4 ++-- 7 files changed, 11 insertions(+), 18 deletions(-) diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index 5552adc..73a1d69 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -54,7 +54,7 @@ func Initialize(apiKey string, config *Config) *Client { } } cohortStorage := newInMemoryCohortStorage() - flagConfigStorage := NewInMemoryFlagConfigStorage() + flagConfigStorage := newInMemoryFlagConfigStorage() var cohortLoader *cohortLoader var deploymentRunner *deploymentRunner if config.CohortSyncConfig != nil { @@ -362,7 +362,7 @@ func (c *Client) enrichUser(user *experiment.User, flagConfigs map[string]*evalu continue } if cohortIDs, ok := groupedCohortIDs[groupType]; ok { - user.AddGroupCohortIDs(groupType, groupName, c.cohortStorage.getCohortsForGroup(groupType, groupName, cohortIDs)) + user.AddGroupCohortIds(groupType, groupName, c.cohortStorage.getCohortsForGroup(groupType, groupName, cohortIDs)) } } } diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index cf82a2f..0f7b133 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -73,7 +73,7 @@ func (dr *deploymentRunner) periodicRefresh() error { func (dr *deploymentRunner) updateFlagConfigs() error { dr.log.Debug("Refreshing flag configs.") - flagConfigs, err := dr.flagConfigApi.GetFlagConfigs() + flagConfigs, err := dr.flagConfigApi.getFlagConfigs() if err != nil { dr.log.Error("Failed to fetch flag configs: %v", err) return err diff --git a/pkg/experiment/local/deployment_runner_test.go b/pkg/experiment/local/deployment_runner_test.go index 6eb59da..5c22064 100644 --- a/pkg/experiment/local/deployment_runner_test.go +++ b/pkg/experiment/local/deployment_runner_test.go @@ -17,7 +17,7 @@ func TestStartThrowsIfFirstFlagConfigLoadFails(t *testing.T) { return nil, errors.New("test") }} cohortDownloadAPI := &mockCohortDownloadApi{} - flagConfigStorage := NewInMemoryFlagConfigStorage() + flagConfigStorage := newInMemoryFlagConfigStorage() cohortStorage := newInMemoryCohortStorage() cohortLoader := newCohortLoader(cohortDownloadAPI, cohortStorage) @@ -43,7 +43,7 @@ func TestStartThrowsIfFirstCohortLoadFails(t *testing.T) { cohortDownloadAPI := &mockCohortDownloadApi{getCohortFunc: func(cohortID string, cohort *Cohort) (*Cohort, error) { return nil, errors.New("test") }} - flagConfigStorage := NewInMemoryFlagConfigStorage() + flagConfigStorage := newInMemoryFlagConfigStorage() cohortStorage := newInMemoryCohortStorage() cohortLoader := newCohortLoader(cohortDownloadAPI, cohortStorage) @@ -66,7 +66,7 @@ type mockFlagConfigApi struct { getFlagConfigsFunc func() (map[string]*evaluation.Flag, error) } -func (m *mockFlagConfigApi) GetFlagConfigs() (map[string]*evaluation.Flag, error) { +func (m *mockFlagConfigApi) getFlagConfigs() (map[string]*evaluation.Flag, error) { if m.getFlagConfigsFunc != nil { return m.getFlagConfigsFunc() } diff --git a/pkg/experiment/local/flag_config_api.go b/pkg/experiment/local/flag_config_api.go index d6ff58a..2964891 100644 --- a/pkg/experiment/local/flag_config_api.go +++ b/pkg/experiment/local/flag_config_api.go @@ -13,7 +13,7 @@ import ( ) type flagConfigApi interface { - GetFlagConfigs() (map[string]*evaluation.Flag, error) + getFlagConfigs() (map[string]*evaluation.Flag, error) } type flagConfigApiV2 struct { @@ -30,7 +30,7 @@ func newFlagConfigApiV2(deploymentKey, serverURL string, flagConfigPollerRequest } } -func (a *flagConfigApiV2) GetFlagConfigs() (map[string]*evaluation.Flag, error) { +func (a *flagConfigApiV2) getFlagConfigs() (map[string]*evaluation.Flag, error) { client := &http.Client{} endpoint, err := url.Parse("https://api.lab.amplitude.com/") if err != nil { diff --git a/pkg/experiment/local/flag_config_storage.go b/pkg/experiment/local/flag_config_storage.go index 69ad9d3..4635354 100644 --- a/pkg/experiment/local/flag_config_storage.go +++ b/pkg/experiment/local/flag_config_storage.go @@ -17,7 +17,7 @@ type inMemoryFlagConfigStorage struct { flagConfigsLock sync.Mutex } -func NewInMemoryFlagConfigStorage() *inMemoryFlagConfigStorage { +func newInMemoryFlagConfigStorage() *inMemoryFlagConfigStorage { return &inMemoryFlagConfigStorage{ flagConfigs: make(map[string]*evaluation.Flag), } diff --git a/pkg/experiment/local/flag_config_util.go b/pkg/experiment/local/flag_config_util.go index 2c4231e..d3fc46b 100644 --- a/pkg/experiment/local/flag_config_util.go +++ b/pkg/experiment/local/flag_config_util.go @@ -4,7 +4,6 @@ import ( "github.com/amplitude/experiment-go-server/internal/evaluation" ) -// isCohortFilter checks if the condition is a cohort filter. func isCohortFilter(condition *evaluation.Condition) bool { op := condition.Op selector := condition.Selector @@ -14,7 +13,6 @@ func isCohortFilter(condition *evaluation.Condition) bool { return false } -// getGroupedCohortConditionIDs extracts grouped cohort condition IDs from a segment. func getGroupedCohortConditionIDs(segment *evaluation.Segment) map[string]map[string]struct{} { cohortIDs := make(map[string]map[string]struct{}) if segment == nil { @@ -47,7 +45,6 @@ func getGroupedCohortConditionIDs(segment *evaluation.Segment) map[string]map[st return cohortIDs } -// getGroupedCohortIDsFromFlag extracts grouped cohort IDs from a flag. func getGroupedCohortIDsFromFlag(flag *evaluation.Flag) map[string]map[string]struct{} { cohortIDs := make(map[string]map[string]struct{}) for _, segment := range flag.Segments { @@ -63,7 +60,6 @@ func getGroupedCohortIDsFromFlag(flag *evaluation.Flag) map[string]map[string]st return cohortIDs } -// getAllCohortIDsFromFlag extracts all cohort IDs from a flag. func getAllCohortIDsFromFlag(flag *evaluation.Flag) map[string]struct{} { cohortIDs := make(map[string]struct{}) groupedIDs := getGroupedCohortIDsFromFlag(flag) @@ -75,7 +71,6 @@ func getAllCohortIDsFromFlag(flag *evaluation.Flag) map[string]struct{} { return cohortIDs } -// getGroupedCohortIDsFromFlags extracts grouped cohort IDs from multiple flags. func getGroupedCohortIDsFromFlags(flags []*evaluation.Flag) map[string]map[string]struct{} { cohortIDs := make(map[string]map[string]struct{}) for _, flag := range flags { @@ -91,7 +86,6 @@ func getGroupedCohortIDsFromFlags(flags []*evaluation.Flag) map[string]map[strin return cohortIDs } -// getAllCohortIDsFromFlags extracts all cohort IDs from multiple flags. func getAllCohortIDsFromFlags(flags []*evaluation.Flag) map[string]struct{} { cohortIDs := make(map[string]struct{}) for _, flag := range flags { @@ -102,7 +96,6 @@ func getAllCohortIDsFromFlags(flags []*evaluation.Flag) map[string]struct{} { return cohortIDs } -// helper function to check if selector contains groups func selectorContainsGroups(selector []string) bool { for _, s := range selector { if s == "groups" { diff --git a/pkg/experiment/types.go b/pkg/experiment/types.go index 19c465a..378b319 100644 --- a/pkg/experiment/types.go +++ b/pkg/experiment/types.go @@ -28,7 +28,7 @@ type User struct { lock sync.Mutex } -func (u *User) AddGroupCohortIDs(groupType, groupName string, cohortIDs map[string]struct{}) { +func (u *User) AddGroupCohortIds(groupType, groupName string, cohortIds map[string]struct{}) { u.lock.Lock() defer u.lock.Unlock() @@ -42,7 +42,7 @@ func (u *User) AddGroupCohortIDs(groupType, groupName string, cohortIDs map[stri u.GroupCohortIds[groupType] = groupNames } - groupNames[groupName] = cohortIDs + groupNames[groupName] = cohortIds } type Variant struct { From b76f0918bc4ccaea3e44522950f5d7b34daf91d9 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 15 Jul 2024 10:38:00 -0700 Subject: [PATCH 14/29] remove lock for user type --- pkg/experiment/types.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pkg/experiment/types.go b/pkg/experiment/types.go index 378b319..08ea1a1 100644 --- a/pkg/experiment/types.go +++ b/pkg/experiment/types.go @@ -1,7 +1,5 @@ package experiment -import "sync" - const VERSION = "1.5.0" type User struct { @@ -25,13 +23,9 @@ type User struct { Groups map[string][]string `json:"groups,omitempty"` CohortIds map[string]struct{} `json:"cohort_ids,omitempty"` GroupCohortIds map[string]map[string]map[string]struct{} `json:"group_cohort_ids,omitempty"` - lock sync.Mutex } func (u *User) AddGroupCohortIds(groupType, groupName string, cohortIds map[string]struct{}) { - u.lock.Lock() - defer u.lock.Unlock() - if u.GroupCohortIds == nil { u.GroupCohortIds = make(map[string]map[string]map[string]struct{}) } From 2c69e4fad2349bea2106c383c9a513189c9024e4 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 25 Jul 2024 11:52:24 -0700 Subject: [PATCH 15/29] add CI tests and update build.yml --- .github/workflows/build.yml | 6 +++ .gitignore | 1 + go.mod | 1 + go.sum | 2 + pkg/experiment/local/client_eu_test.go | 46 ++++++++++++++++++++ pkg/experiment/local/client_test.go | 56 +++++++++++++++++++++++-- pkg/experiment/local/flag_config_api.go | 2 +- 7 files changed, 110 insertions(+), 4 deletions(-) create mode 100644 pkg/experiment/local/client_eu_test.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a1fcf5e..2c361ef 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,6 +25,7 @@ jobs: with: version: latest test: + environment: Unit Test runs-on: 'ubuntu-latest' steps: - name: Checkout @@ -35,5 +36,10 @@ jobs: go-version: '1.17' check-latest: true - name: Test + env: + API_KEY: ${{ secrets.API_KEY }} + SECRET_KEY: ${{ secrets.SECRET_KEY }} + EU_API_KEY: ${{ secrets.EU_API_KEY }} + EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} run: | go test ./... diff --git a/.gitignore b/.gitignore index ff53c67..9e9437e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ xpmt .DS_Store cmd/xpmt/bin/ +pkg/experiment/local/.env diff --git a/go.mod b/go.mod index 81525f6..bad5092 100644 --- a/go.mod +++ b/go.mod @@ -7,5 +7,6 @@ require github.com/spaolacci/murmur3 v1.1.0 require ( github.com/amplitude/analytics-go v1.0.1 github.com/jarcoal/httpmock v1.3.1 + github.com/joho/godotenv v1.5.1 github.com/stretchr/testify v1.9.0 ) diff --git a/go.sum b/go.sum index 46aa182..ec3c7e2 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/pkg/experiment/local/client_eu_test.go b/pkg/experiment/local/client_eu_test.go new file mode 100644 index 0000000..b4a01ff --- /dev/null +++ b/pkg/experiment/local/client_eu_test.go @@ -0,0 +1,46 @@ +package local + +import ( + "github.com/amplitude/experiment-go-server/pkg/experiment" + "github.com/joho/godotenv" + "log" + "os" + "testing" +) + +var clientEU *Client + +func init() { + err := godotenv.Load() + if err != nil { + log.Fatalf("Error loading .env file: %v", err) + } + projectApiKey := os.Getenv("EU_API_KEY") + secretKey := os.Getenv("EU_SECRET_KEY") + cohortSyncConfig := CohortSyncConfig{ + ApiKey: projectApiKey, + SecretKey: secretKey, + } + clientEU = Initialize("server-Qlp7XiSu6JtP2S3JzA95PnP27duZgQCF", + &Config{CohortSyncConfig: &cohortSyncConfig, ServerZone: "eu"}) + err = clientEU.Start() + if err != nil { + panic(err) + } +} + +func TestEvaluateV2CohortEU(t *testing.T) { + user := &experiment.User{UserId: "1", DeviceId: "0"} + flagKeys := []string{"sdk-local-evaluation-user-cohort"} + result, err := clientEU.EvaluateV2(user, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant := result["sdk-local-evaluation-user-cohort"] + if variant.Key != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "on" { + t.Fatalf("Unexpected variant %v", variant) + } +} diff --git a/pkg/experiment/local/client_test.go b/pkg/experiment/local/client_test.go index be7cc67..2cb346b 100644 --- a/pkg/experiment/local/client_test.go +++ b/pkg/experiment/local/client_test.go @@ -2,14 +2,28 @@ package local import ( "github.com/amplitude/experiment-go-server/pkg/experiment" + "github.com/joho/godotenv" + "log" + "os" "testing" ) var client *Client func init() { - client = Initialize("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz", nil) - err := client.Start() + err := godotenv.Load() + if err != nil { + log.Fatalf("Error loading .env file: %v", err) + } + projectApiKey := os.Getenv("API_KEY") + secretKey := os.Getenv("SECRET_KEY") + cohortSyncConfig := CohortSyncConfig{ + ApiKey: projectApiKey, + SecretKey: secretKey, + } + client = Initialize("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz", + &Config{CohortSyncConfig: &cohortSyncConfig}) + err = client.Start() if err != nil { panic(err) } @@ -52,7 +66,6 @@ func TestEvaluate(t *testing.T) { } } - func TestEvaluateV2AllFlags(t *testing.T) { user := &experiment.User{UserId: "test_user"} result, err := client.EvaluateV2(user, nil) @@ -157,3 +170,40 @@ func TestFlagMetadataLocalFlagKey(t *testing.T) { t.Fatalf("Unexpected metadata %v", md) } } + +func TestEvaluateV2Cohort(t *testing.T) { + user := &experiment.User{UserId: "12345"} + flagKeys := []string{"sdk-local-evaluation-user-cohort-ci-test"} + result, err := client.EvaluateV2(user, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant := result["sdk-local-evaluation-user-cohort-ci-test"] + if variant.Key != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "on" { + t.Fatalf("Unexpected variant %v", variant) + } +} + +func TestEvaluateV2GroupCohort(t *testing.T) { + user := &experiment.User{ + UserId: "12345", + DeviceId: "device_id", + Groups: map[string][]string{ + "org id": {"1"}, + }} + flagKeys := []string{"sdk-local-evaluation-group-cohort-ci-test"} + result, err := client.EvaluateV2(user, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant := result["sdk-local-evaluation-group-cohort-ci-test"] + if variant.Key != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "on" { + t.Fatalf("Unexpected variant %v", variant) + } +} diff --git a/pkg/experiment/local/flag_config_api.go b/pkg/experiment/local/flag_config_api.go index 2964891..b8dc324 100644 --- a/pkg/experiment/local/flag_config_api.go +++ b/pkg/experiment/local/flag_config_api.go @@ -32,7 +32,7 @@ func newFlagConfigApiV2(deploymentKey, serverURL string, flagConfigPollerRequest func (a *flagConfigApiV2) getFlagConfigs() (map[string]*evaluation.Flag, error) { client := &http.Client{} - endpoint, err := url.Parse("https://api.lab.amplitude.com/") + endpoint, err := url.Parse(a.ServerURL) if err != nil { return nil, err } From 27e08762500682f6f941d8e338edfe95eb9d116b Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 25 Jul 2024 13:14:52 -0700 Subject: [PATCH 16/29] handle godotenv error with print --- pkg/experiment/local/client_eu_test.go | 2 +- pkg/experiment/local/client_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/experiment/local/client_eu_test.go b/pkg/experiment/local/client_eu_test.go index b4a01ff..ea21bc9 100644 --- a/pkg/experiment/local/client_eu_test.go +++ b/pkg/experiment/local/client_eu_test.go @@ -13,7 +13,7 @@ var clientEU *Client func init() { err := godotenv.Load() if err != nil { - log.Fatalf("Error loading .env file: %v", err) + log.Printf("Error loading .env file: %v", err) } projectApiKey := os.Getenv("EU_API_KEY") secretKey := os.Getenv("EU_SECRET_KEY") diff --git a/pkg/experiment/local/client_test.go b/pkg/experiment/local/client_test.go index 2cb346b..ee32125 100644 --- a/pkg/experiment/local/client_test.go +++ b/pkg/experiment/local/client_test.go @@ -13,7 +13,7 @@ var client *Client func init() { err := godotenv.Load() if err != nil { - log.Fatalf("Error loading .env file: %v", err) + log.Printf("Error loading .env file: %v", err) } projectApiKey := os.Getenv("API_KEY") secretKey := os.Getenv("SECRET_KEY") From 40df4cb9a23dd6fde627bcaeeae7a95b01a3f109 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 25 Jul 2024 14:20:07 -0700 Subject: [PATCH 17/29] update client tests --- pkg/experiment/local/client_eu_test.go | 13 +++++++++-- pkg/experiment/local/client_test.go | 31 ++++++++++++++++++++++---- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/pkg/experiment/local/client_eu_test.go b/pkg/experiment/local/client_eu_test.go index ea21bc9..c0805e0 100644 --- a/pkg/experiment/local/client_eu_test.go +++ b/pkg/experiment/local/client_eu_test.go @@ -30,9 +30,10 @@ func init() { } func TestEvaluateV2CohortEU(t *testing.T) { - user := &experiment.User{UserId: "1", DeviceId: "0"} + targetedUser := &experiment.User{UserId: "1", DeviceId: "0"} + nonTargetedUser := &experiment.User{UserId: "not_targeted", DeviceId: "0"} flagKeys := []string{"sdk-local-evaluation-user-cohort"} - result, err := clientEU.EvaluateV2(user, flagKeys) + result, err := clientEU.EvaluateV2(targetedUser, flagKeys) if err != nil { t.Fatalf("Unexpected error %v", err) } @@ -43,4 +44,12 @@ func TestEvaluateV2CohortEU(t *testing.T) { if variant.Value != "on" { t.Fatalf("Unexpected variant %v", variant) } + result, err = clientEU.EvaluateV2(nonTargetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant = result["sdk-local-evaluation-user-cohort"] + if variant.Key != "off" { + t.Fatalf("Unexpected variant %v", variant) + } } diff --git a/pkg/experiment/local/client_test.go b/pkg/experiment/local/client_test.go index ee32125..bada72a 100644 --- a/pkg/experiment/local/client_test.go +++ b/pkg/experiment/local/client_test.go @@ -172,9 +172,10 @@ func TestFlagMetadataLocalFlagKey(t *testing.T) { } func TestEvaluateV2Cohort(t *testing.T) { - user := &experiment.User{UserId: "12345"} + targetedUser := &experiment.User{UserId: "12345"} + nonTargetedUser := &experiment.User{UserId: "not_targeted"} flagKeys := []string{"sdk-local-evaluation-user-cohort-ci-test"} - result, err := client.EvaluateV2(user, flagKeys) + result, err := client.EvaluateV2(targetedUser, flagKeys) if err != nil { t.Fatalf("Unexpected error %v", err) } @@ -185,17 +186,31 @@ func TestEvaluateV2Cohort(t *testing.T) { if variant.Value != "on" { t.Fatalf("Unexpected variant %v", variant) } + result, err = client.EvaluateV2(nonTargetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant = result["sdk-local-evaluation-user-cohort-ci-test"] + if variant.Key != "off" { + t.Fatalf("Unexpected variant %v", variant) + } } func TestEvaluateV2GroupCohort(t *testing.T) { - user := &experiment.User{ + targetedUser := &experiment.User{ UserId: "12345", DeviceId: "device_id", Groups: map[string][]string{ "org id": {"1"}, }} + nonTargetedUser := &experiment.User{ + UserId: "12345", + DeviceId: "device_id", + Groups: map[string][]string{ + "org id": {"not_targeted"}, + }} flagKeys := []string{"sdk-local-evaluation-group-cohort-ci-test"} - result, err := client.EvaluateV2(user, flagKeys) + result, err := client.EvaluateV2(targetedUser, flagKeys) if err != nil { t.Fatalf("Unexpected error %v", err) } @@ -206,4 +221,12 @@ func TestEvaluateV2GroupCohort(t *testing.T) { if variant.Value != "on" { t.Fatalf("Unexpected variant %v", variant) } + result, err = client.EvaluateV2(nonTargetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant = result["sdk-local-evaluation-group-cohort-ci-test"] + if variant.Key != "off" { + t.Fatalf("Unexpected variant %v", variant) + } } From c4c426085d39a0a332c9c1cac29c481653baddb7 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 25 Jul 2024 14:27:06 -0700 Subject: [PATCH 18/29] update README --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 300c716..ccf1fcf 100644 --- a/README.md +++ b/README.md @@ -110,3 +110,12 @@ Fetch variants for a user given an experiment user JSON object ``` > Note: must use single quotes around JSON object string + +### Running unit tests suite +To set up for running test on local, create a `.env` file in `pkg/experiment/local` with following +contents, and replace `{API_KEY}` and `{SECRET_KEY}` (or `{EU_API_KEY}` and `{EU_SECRET_KEY}` for EU data center) for the project in test: + +``` +API_KEY={API_KEY} +SECRET_KEY={SECRET_KEY} +``` From 73672b2d516875312561b4b91e0cec8c210089f4 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 30 Jul 2024 17:02:57 -0700 Subject: [PATCH 19/29] Cohort not modified should not throw exception --- pkg/experiment/local/cohort_download_api.go | 4 ++-- pkg/experiment/local/cohort_download_api_test.go | 7 +++---- pkg/experiment/local/cohort_loader.go | 4 +++- pkg/experiment/local/exception.go | 8 -------- 4 files changed, 8 insertions(+), 15 deletions(-) diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index 45e5320..6d09fcd 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -48,7 +48,7 @@ func (api *directCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) ( errors++ if errors >= 3 || func(err error) bool { switch err.(type) { - case *CohortNotModifiedException, *CohortTooLargeException: + case *CohortTooLargeException: return true default: return false @@ -85,7 +85,7 @@ func (api *directCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) ( }(), }, nil } else if response.StatusCode == http.StatusNoContent { - return nil, &CohortNotModifiedException{Message: "Cohort not modified"} + return nil, nil } else if response.StatusCode == http.StatusRequestEntityTooLarge { return nil, &CohortTooLargeException{Message: "Cohort exceeds max cohort size"} } else { diff --git a/pkg/experiment/local/cohort_download_api_test.go b/pkg/experiment/local/cohort_download_api_test.go index 8029283..fbd6422 100644 --- a/pkg/experiment/local/cohort_download_api_test.go +++ b/pkg/experiment/local/cohort_download_api_test.go @@ -214,9 +214,8 @@ func TestCohortDownloadApi(t *testing.T) { httpmock.NewStringResponder(204, ""), ) - _, err := api.getCohort("1234", cohort) - assert.Error(t, err) - _, isCohortNotModifiedException := err.(*CohortNotModifiedException) - assert.True(t, isCohortNotModifiedException) + result, err := api.getCohort("1234", cohort) + assert.Nil(t, result) + assert.NoError(t, err) }) } diff --git a/pkg/experiment/local/cohort_loader.go b/pkg/experiment/local/cohort_loader.go index 7023306..94d30c0 100644 --- a/pkg/experiment/local/cohort_loader.go +++ b/pkg/experiment/local/cohort_loader.go @@ -69,7 +69,9 @@ func (task *CohortLoaderTask) run() { if err != nil { task.err = err } else { - task.loader.cohortStorage.putCohort(cohort) + if cohort != nil { + task.loader.cohortStorage.putCohort(cohort) + } } task.loader.removeJob(task.cohortId) diff --git a/pkg/experiment/local/exception.go b/pkg/experiment/local/exception.go index c151916..498d3ab 100644 --- a/pkg/experiment/local/exception.go +++ b/pkg/experiment/local/exception.go @@ -16,11 +16,3 @@ type CohortTooLargeException struct { func (e *CohortTooLargeException) Error() string { return e.Message } - -type CohortNotModifiedException struct { - Message string -} - -func (e *CohortNotModifiedException) Error() string { - return e.Message -} From aa97fc4c85453c5fc7a4db46dff60408d0a31a95 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 30 Jul 2024 17:04:48 -0700 Subject: [PATCH 20/29] debug message for cohort not modified --- pkg/experiment/local/cohort_download_api.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index 6d09fcd..2b022ed 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -85,6 +85,7 @@ func (api *directCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) ( }(), }, nil } else if response.StatusCode == http.StatusNoContent { + api.log.Debug("getCohortMembers(%s): Cohort not modified", cohortID) return nil, nil } else if response.StatusCode == http.StatusRequestEntityTooLarge { return nil, &CohortTooLargeException{Message: "Cohort exceeds max cohort size"} From a2d74f80604db2743dc4cff9f2078c88e87b32a5 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 30 Jul 2024 17:08:55 -0700 Subject: [PATCH 21/29] nit: update test name --- pkg/experiment/local/cohort_download_api_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/experiment/local/cohort_download_api_test.go b/pkg/experiment/local/cohort_download_api_test.go index fbd6422..f000ae0 100644 --- a/pkg/experiment/local/cohort_download_api_test.go +++ b/pkg/experiment/local/cohort_download_api_test.go @@ -207,7 +207,7 @@ func TestCohortDownloadApi(t *testing.T) { assert.True(t, isCohortTooLargeException) }) - t.Run("test_cohort_not_modified_exception", func(t *testing.T) { + t.Run("test_cohort_not_modified", func(t *testing.T) { cohort := &Cohort{Id: "1234", LastModified: 1000, Size: 1, MemberIds: []string{}} httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), From 683d0c2d2274c130a79d71c6aba1102c1eb9457d Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 1 Aug 2024 15:43:52 -0700 Subject: [PATCH 22/29] do not throw error on cohort download failure, log error on evaluate --- pkg/experiment/local/client.go | 29 +++++++++++++-- pkg/experiment/local/cohort_download_api.go | 2 +- pkg/experiment/local/deployment_runner.go | 35 ++++--------------- .../local/deployment_runner_test.go | 6 ++-- 4 files changed, 37 insertions(+), 35 deletions(-) diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index 73a1d69..489032c 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -116,13 +116,14 @@ func (c *Client) Evaluate(user *experiment.User, flagKeys []string) (map[string] func (c *Client) EvaluateV2(user *experiment.User, flagKeys []string) (map[string]experiment.Variant, error) { flagConfigs := c.flagConfigStorage.getFlagConfigs() - enrichedUser, err := c.enrichUser(user, flagConfigs) + sortedFlags, err := topologicalSort(flagConfigs, flagKeys) + c.requiredCohortsInStorage(sortedFlags) + enrichedUser, err := c.enrichUserWithCohorts(user, flagConfigs) if err != nil { return nil, err } userContext := evaluation.UserToContext(enrichedUser) c.flagsMutex.RLock() - sortedFlags, err := topologicalSort(flagConfigs, flagKeys) c.flagsMutex.RUnlock() if err != nil { return nil, err @@ -338,7 +339,29 @@ func coerceString(value interface{}) string { return fmt.Sprintf("%v", value) } -func (c *Client) enrichUser(user *experiment.User, flagConfigs map[string]*evaluation.Flag) (*experiment.User, error) { +func (c *Client) requiredCohortsInStorage(flagConfigs []*evaluation.Flag) { + storedCohortIDs := c.cohortStorage.getCohortIds() + for _, flag := range flagConfigs { + flagCohortIDs := getAllCohortIDsFromFlag(flag) + missingCohorts := difference(flagCohortIDs, storedCohortIDs) + + if len(missingCohorts) > 0 { + if c.config.CohortSyncConfig != nil { + c.log.Debug( + "Evaluating flag %s dependent on cohorts %v without %v in storage", + flag.Key, flagCohortIDs, missingCohorts, + ) + } else { + c.log.Debug( + "Evaluating flag %s dependent on cohorts %v without cohort syncing configured", + flag.Key, flagCohortIDs, + ) + } + } + } +} + +func (c *Client) enrichUserWithCohorts(user *experiment.User, flagConfigs map[string]*evaluation.Flag) (*experiment.User, error) { flagConfigSlice := make([]*evaluation.Flag, 0, len(flagConfigs)) for _, value := range flagConfigs { diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index 2b022ed..3c10d2d 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -88,7 +88,7 @@ func (api *directCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) ( api.log.Debug("getCohortMembers(%s): Cohort not modified", cohortID) return nil, nil } else if response.StatusCode == http.StatusRequestEntityTooLarge { - return nil, &CohortTooLargeException{Message: "Cohort exceeds max cohort size"} + return nil, &CohortTooLargeException{Message: "Cohort exceeds max cohort size of " + strconv.Itoa(api.MaxCohortSize)} } else { return nil, &HTTPErrorResponseException{StatusCode: response.StatusCode, Message: "Unexpected response code"} } diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index 0f7b133..6080d9c 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -2,7 +2,6 @@ package local import ( "fmt" - "strings" "sync" "github.com/amplitude/experiment-go-server/internal/evaluation" @@ -119,31 +118,20 @@ func (dr *deploymentRunner) updateFlagConfigs() error { // Get updated set of cohort ids updatedCohortIDs := dr.cohortStorage.getCohortIds() // Iterate through new flag configs and check if their required cohorts exist - failedFlagCount := 0 for _, flagConfig := range flagConfigs { cohortIDs := getAllCohortIDsFromFlag(flagConfig) - if len(cohortIDs) == 0 || dr.cohortLoader == nil { - dr.flagConfigStorage.putFlagConfig(flagConfig) - dr.log.Debug("Putting non-cohort flag %s", flagConfig.Key) - } else if subset(cohortIDs, updatedCohortIDs) { - dr.flagConfigStorage.putFlagConfig(flagConfig) - dr.log.Debug("Putting flag %s", flagConfig.Key) - } else { - dr.log.Error("Flag %s not updated because not all required cohorts could be loaded", flagConfig.Key) - failedFlagCount++ + missingCohorts := difference(cohortIDs, updatedCohortIDs) + + dr.flagConfigStorage.putFlagConfig(flagConfig) + dr.log.Debug("Putting flag %s", flagConfig.Key) + if len(missingCohorts) != 0 { + dr.log.Error("Flag %s - failed to load cohorts: %v", flagConfig.Key, missingCohorts) } } // Delete unused cohorts dr.deleteUnusedCohorts() - dr.log.Debug("Refreshed %d flag configs.", len(flagConfigs)-failedFlagCount) - - // If there are any download errors, raise an aggregated exception - if len(cohortDownloadErrors) > 0 { - errorCount := len(cohortDownloadErrors) - errorMessages := strings.Join(cohortDownloadErrors, "\n") - return fmt.Errorf("%d cohort(s) failed to download:\n%s", errorCount, errorMessages) - } + dr.log.Debug("Refreshed %d flag configs.", len(flagConfigs)) return nil } @@ -183,12 +171,3 @@ func difference(set1, set2 map[string]struct{}) map[string]struct{} { } return diff } - -func subset(subset, set map[string]struct{}) bool { - for k := range subset { - if _, exists := set[k]; !exists { - return false - } - } - return true -} diff --git a/pkg/experiment/local/deployment_runner_test.go b/pkg/experiment/local/deployment_runner_test.go index 5c22064..691dae7 100644 --- a/pkg/experiment/local/deployment_runner_test.go +++ b/pkg/experiment/local/deployment_runner_test.go @@ -36,7 +36,7 @@ func TestStartThrowsIfFirstFlagConfigLoadFails(t *testing.T) { } } -func TestStartThrowsIfFirstCohortLoadFails(t *testing.T) { +func TestStartSucceedsEvenIfFirstCohortLoadFails(t *testing.T) { flagAPI := &mockFlagConfigApi{getFlagConfigsFunc: func() (map[string]*evaluation.Flag, error) { return map[string]*evaluation.Flag{"flag": createTestFlag()}, nil }} @@ -57,8 +57,8 @@ func TestStartThrowsIfFirstCohortLoadFails(t *testing.T) { err := runner.start() - if err == nil { - t.Error("Expected error but got nil") + if err != nil { + t.Errorf("Expected no error but got %v", err) } } From eb77b8b60f0502f3f9ed7f80867dab2375bb9619 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 1 Aug 2024 15:50:04 -0700 Subject: [PATCH 23/29] fix lint --- pkg/experiment/local/client.go | 5 +++-- pkg/experiment/local/deployment_runner.go | 3 --- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index 489032c..876b2e5 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -117,14 +117,15 @@ func (c *Client) Evaluate(user *experiment.User, flagKeys []string) (map[string] func (c *Client) EvaluateV2(user *experiment.User, flagKeys []string) (map[string]experiment.Variant, error) { flagConfigs := c.flagConfigStorage.getFlagConfigs() sortedFlags, err := topologicalSort(flagConfigs, flagKeys) + if err != nil { + return nil, err + } c.requiredCohortsInStorage(sortedFlags) enrichedUser, err := c.enrichUserWithCohorts(user, flagConfigs) if err != nil { return nil, err } userContext := evaluation.UserToContext(enrichedUser) - c.flagsMutex.RLock() - c.flagsMutex.RUnlock() if err != nil { return nil, err } diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index 6080d9c..e557b4d 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -1,7 +1,6 @@ package local import ( - "fmt" "sync" "github.com/amplitude/experiment-go-server/internal/evaluation" @@ -105,12 +104,10 @@ func (dr *deploymentRunner) updateFlagConfigs() error { existingCohortIDs := dr.cohortStorage.getCohortIds() cohortIDsToDownload := difference(newCohortIDs, existingCohortIDs) - var cohortDownloadErrors []string // Download all new cohorts for cohortID := range cohortIDsToDownload { if err := dr.cohortLoader.loadCohort(cohortID).wait(); err != nil { - cohortDownloadErrors = append(cohortDownloadErrors, fmt.Sprintf("Cohort %s: %v", cohortID, err)) dr.log.Error("Download cohort %s failed: %v", cohortID, err) } } From c21c81c6f9bda468830a3fdee88e8e1caa1b6235 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 5 Aug 2024 16:17:41 -0700 Subject: [PATCH 24/29] make exceptions private to package --- pkg/experiment/local/cohort_download_api.go | 6 +++--- pkg/experiment/local/cohort_download_api_test.go | 2 +- pkg/experiment/local/exception.go | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index 3c10d2d..5706943 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -48,7 +48,7 @@ func (api *directCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) ( errors++ if errors >= 3 || func(err error) bool { switch err.(type) { - case *CohortTooLargeException: + case *cohortTooLargeException: return true default: return false @@ -88,9 +88,9 @@ func (api *directCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) ( api.log.Debug("getCohortMembers(%s): Cohort not modified", cohortID) return nil, nil } else if response.StatusCode == http.StatusRequestEntityTooLarge { - return nil, &CohortTooLargeException{Message: "Cohort exceeds max cohort size of " + strconv.Itoa(api.MaxCohortSize)} + return nil, &cohortTooLargeException{Message: "Cohort exceeds max cohort size of " + strconv.Itoa(api.MaxCohortSize)} } else { - return nil, &HTTPErrorResponseException{StatusCode: response.StatusCode, Message: "Unexpected response code"} + return nil, &httpErrorResponseException{StatusCode: response.StatusCode, Message: "Unexpected response code"} } } } diff --git a/pkg/experiment/local/cohort_download_api_test.go b/pkg/experiment/local/cohort_download_api_test.go index f000ae0..7e2b319 100644 --- a/pkg/experiment/local/cohort_download_api_test.go +++ b/pkg/experiment/local/cohort_download_api_test.go @@ -203,7 +203,7 @@ func TestCohortDownloadApi(t *testing.T) { _, err := api.getCohort("1234", cohort) assert.Error(t, err) - _, isCohortTooLargeException := err.(*CohortTooLargeException) + _, isCohortTooLargeException := err.(*cohortTooLargeException) assert.True(t, isCohortTooLargeException) }) diff --git a/pkg/experiment/local/exception.go b/pkg/experiment/local/exception.go index 498d3ab..8e8dba6 100644 --- a/pkg/experiment/local/exception.go +++ b/pkg/experiment/local/exception.go @@ -1,18 +1,18 @@ package local -type HTTPErrorResponseException struct { +type httpErrorResponseException struct { StatusCode int Message string } -func (e *HTTPErrorResponseException) Error() string { +func (e *httpErrorResponseException) Error() string { return e.Message } -type CohortTooLargeException struct { +type cohortTooLargeException struct { Message string } -func (e *CohortTooLargeException) Error() string { +func (e *cohortTooLargeException) Error() string { return e.Message } From 5d6055f35e5542e6463f84860aa3d30b55a2c1fb Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 5 Aug 2024 22:53:32 -0700 Subject: [PATCH 25/29] download cohorts gets cohort ids from flags --- pkg/experiment/local/cohort_loader.go | 37 ----------------- pkg/experiment/local/deployment_runner.go | 45 ++++++++++++++++----- pkg/experiment/local/flag_config_storage.go | 12 ++++++ 3 files changed, 48 insertions(+), 46 deletions(-) diff --git a/pkg/experiment/local/cohort_loader.go b/pkg/experiment/local/cohort_loader.go index 94d30c0..3fa597b 100644 --- a/pkg/experiment/local/cohort_loader.go +++ b/pkg/experiment/local/cohort_loader.go @@ -1,8 +1,6 @@ package local import ( - "fmt" - "strings" "sync" "sync/atomic" ) @@ -88,38 +86,3 @@ func (cl *cohortLoader) downloadCohort(cohortID string) (*Cohort, error) { cohort := cl.cohortStorage.getCohort(cohortID) return cl.cohortDownloadApi.getCohort(cohortID, cohort) } - -func (cl *cohortLoader) updateStoredCohorts() error { - var wg sync.WaitGroup - errorChan := make(chan error, len(cl.cohortStorage.getCohortIds())) - - cohortIds := make([]string, 0, len(cl.cohortStorage.getCohortIds())) - for id := range cl.cohortStorage.getCohortIds() { - cohortIds = append(cohortIds, id) - } - - for _, cohortID := range cohortIds { - wg.Add(1) - go func(id string) { - defer wg.Done() - task := cl.loadCohort(id) - if err := task.wait(); err != nil { - errorChan <- fmt.Errorf("cohort %s: %v", id, err) - } - }(cohortID) - } - - wg.Wait() - close(errorChan) - - var errorMessages []string - for err := range errorChan { - errorMessages = append(errorMessages, err.Error()) - } - - if len(errorMessages) > 0 { - return fmt.Errorf("One or more cohorts failed to download:\n%s", - strings.Join(errorMessages, "\n")) - } - return nil -} diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index e557b4d..ff80d49 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -1,6 +1,8 @@ package local import ( + "fmt" + "strings" "sync" "github.com/amplitude/experiment-go-server/internal/evaluation" @@ -106,11 +108,7 @@ func (dr *deploymentRunner) updateFlagConfigs() error { cohortIDsToDownload := difference(newCohortIDs, existingCohortIDs) // Download all new cohorts - for cohortID := range cohortIDsToDownload { - if err := dr.cohortLoader.loadCohort(cohortID).wait(); err != nil { - dr.log.Error("Download cohort %s failed: %v", cohortID, err) - } - } + dr.downloadCohorts(cohortIDsToDownload) // Get updated set of cohort ids updatedCohortIDs := dr.cohortStorage.getCohortIds() @@ -134,10 +132,8 @@ func (dr *deploymentRunner) updateFlagConfigs() error { } func (dr *deploymentRunner) updateStoredCohorts() { - err := dr.cohortLoader.updateStoredCohorts() - if err != nil { - dr.log.Error("Error updating stored cohorts: %v", err) - } + cohortIDs := getAllCohortIDsFromFlags(dr.flagConfigStorage.getFlagConfigsArray()) + dr.downloadCohorts(cohortIDs) } func (dr *deploymentRunner) deleteUnusedCohorts() { @@ -168,3 +164,34 @@ func difference(set1, set2 map[string]struct{}) map[string]struct{} { } return diff } + +func (dr *deploymentRunner) downloadCohorts(cohortIDs map[string]struct{}) { + var wg sync.WaitGroup + errorChan := make(chan error, len(cohortIDs)) + + for cohortID := range cohortIDs { + wg.Add(1) + go func(id string) { + defer wg.Done() + task := dr.cohortLoader.loadCohort(id) + if err := task.wait(); err != nil { + errorChan <- fmt.Errorf("cohort %s: %v", id, err) + } + }(cohortID) + } + + go func() { + wg.Wait() + close(errorChan) + }() + + var errorMessages []string + for err := range errorChan { + errorMessages = append(errorMessages, err.Error()) + dr.log.Error("Error downloading cohort: %v", err) + } + + if len(errorMessages) > 0 { + dr.log.Error("One or more cohorts failed to download:\n%s", strings.Join(errorMessages, "\n")) + } +} diff --git a/pkg/experiment/local/flag_config_storage.go b/pkg/experiment/local/flag_config_storage.go index 4635354..02daea6 100644 --- a/pkg/experiment/local/flag_config_storage.go +++ b/pkg/experiment/local/flag_config_storage.go @@ -8,6 +8,7 @@ import ( type flagConfigStorage interface { getFlagConfig(key string) *evaluation.Flag getFlagConfigs() map[string]*evaluation.Flag + getFlagConfigsArray() []*evaluation.Flag putFlagConfig(flagConfig *evaluation.Flag) removeIf(condition func(*evaluation.Flag) bool) } @@ -39,6 +40,17 @@ func (storage *inMemoryFlagConfigStorage) getFlagConfigs() map[string]*evaluatio return copyFlagConfigs } +func (storage *inMemoryFlagConfigStorage) getFlagConfigsArray() []*evaluation.Flag { + storage.flagConfigsLock.Lock() + defer storage.flagConfigsLock.Unlock() + + var copyFlagConfigs []*evaluation.Flag + for _, value := range storage.flagConfigs { + copyFlagConfigs = append(copyFlagConfigs, value) + } + return copyFlagConfigs +} + func (storage *inMemoryFlagConfigStorage) putFlagConfig(flagConfig *evaluation.Flag) { storage.flagConfigsLock.Lock() defer storage.flagConfigsLock.Unlock() From 28c1793260b8cb7e6fb2f5a56cfc7b874dd2b614 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 6 Aug 2024 11:12:44 -0700 Subject: [PATCH 26/29] use const for cohort poller interval --- pkg/experiment/local/deployment_runner.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index ff80d49..332c633 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -4,11 +4,14 @@ import ( "fmt" "strings" "sync" + "time" "github.com/amplitude/experiment-go-server/internal/evaluation" "github.com/amplitude/experiment-go-server/internal/logger" ) +const CohortPollerInterval = 60 * time.Second + type deploymentRunner struct { config *Config flagConfigApi flagConfigApi @@ -55,7 +58,7 @@ func (dr *deploymentRunner) start() error { }) if dr.cohortLoader != nil { - dr.poller.Poll(dr.config.FlagConfigPollerInterval, func() { + dr.poller.Poll(CohortPollerInterval, func() { dr.updateStoredCohorts() }) } From 82cc98a84d6994b0f853085de574a496598029c6 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 6 Aug 2024 14:26:07 -0700 Subject: [PATCH 27/29] update cohort_sync_config fields: include polling and remove request delay, use enum for serverzone, update tests accordingly --- pkg/experiment/local/client.go | 2 +- pkg/experiment/local/client_eu_test.go | 2 +- pkg/experiment/local/cohort_download_api.go | 32 +++++++------- .../local/cohort_download_api_test.go | 2 +- pkg/experiment/local/config.go | 44 +++++++++++-------- pkg/experiment/local/config_test.go | 35 +++++++-------- pkg/experiment/local/deployment_runner.go | 12 ++--- 7 files changed, 66 insertions(+), 63 deletions(-) diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index 876b2e5..989b710 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -58,7 +58,7 @@ func Initialize(apiKey string, config *Config) *Client { var cohortLoader *cohortLoader var deploymentRunner *deploymentRunner if config.CohortSyncConfig != nil { - cohortDownloadApi := newDirectCohortDownloadApi(config.CohortSyncConfig.ApiKey, config.CohortSyncConfig.SecretKey, config.CohortSyncConfig.MaxCohortSize, config.CohortSyncConfig.CohortRequestDelayMillis, config.CohortSyncConfig.CohortServerUrl, config.Debug) + cohortDownloadApi := newDirectCohortDownloadApi(config.CohortSyncConfig.ApiKey, config.CohortSyncConfig.SecretKey, config.CohortSyncConfig.MaxCohortSize, config.CohortSyncConfig.CohortServerUrl, config.Debug) cohortLoader = newCohortLoader(cohortDownloadApi, cohortStorage) } deploymentRunner = newDeploymentRunner(config, newFlagConfigApiV2(apiKey, config.ServerUrl, config.FlagConfigPollerRequestTimeout), flagConfigStorage, cohortStorage, cohortLoader) diff --git a/pkg/experiment/local/client_eu_test.go b/pkg/experiment/local/client_eu_test.go index c0805e0..0bca7c2 100644 --- a/pkg/experiment/local/client_eu_test.go +++ b/pkg/experiment/local/client_eu_test.go @@ -22,7 +22,7 @@ func init() { SecretKey: secretKey, } clientEU = Initialize("server-Qlp7XiSu6JtP2S3JzA95PnP27duZgQCF", - &Config{CohortSyncConfig: &cohortSyncConfig, ServerZone: "eu"}) + &Config{CohortSyncConfig: &cohortSyncConfig, ServerZone: EUServerZone}) err = clientEU.Start() if err != nil { panic(err) diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index 5706943..8fabac0 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -9,29 +9,29 @@ import ( "time" ) +const cohortRequestDelay = 100 * time.Millisecond + type cohortDownloadApi interface { getCohort(cohortID string, cohort *Cohort) (*Cohort, error) } type directCohortDownloadApi struct { - ApiKey string - SecretKey string - MaxCohortSize int - CohortRequestDelayMillis int - ServerUrl string - Debug bool - log *logger.Log + ApiKey string + SecretKey string + MaxCohortSize int + ServerUrl string + Debug bool + log *logger.Log } -func newDirectCohortDownloadApi(apiKey, secretKey string, maxCohortSize, cohortRequestDelayMillis int, serverUrl string, debug bool) *directCohortDownloadApi { +func newDirectCohortDownloadApi(apiKey, secretKey string, maxCohortSize int, serverUrl string, debug bool) *directCohortDownloadApi { api := &directCohortDownloadApi{ - ApiKey: apiKey, - SecretKey: secretKey, - MaxCohortSize: maxCohortSize, - CohortRequestDelayMillis: cohortRequestDelayMillis, - ServerUrl: serverUrl, - Debug: debug, - log: logger.New(debug), + ApiKey: apiKey, + SecretKey: secretKey, + MaxCohortSize: maxCohortSize, + ServerUrl: serverUrl, + Debug: debug, + log: logger.New(debug), } return api } @@ -56,7 +56,7 @@ func (api *directCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) ( }(err) { return nil, err } - time.Sleep(time.Duration(api.CohortRequestDelayMillis) * time.Millisecond) + time.Sleep(cohortRequestDelay) continue } diff --git a/pkg/experiment/local/cohort_download_api_test.go b/pkg/experiment/local/cohort_download_api_test.go index 7e2b319..7ddc870 100644 --- a/pkg/experiment/local/cohort_download_api_test.go +++ b/pkg/experiment/local/cohort_download_api_test.go @@ -33,7 +33,7 @@ func TestCohortDownloadApi(t *testing.T) { httpmock.Activate() defer httpmock.DeactivateAndReset() - api := newDirectCohortDownloadApi("api", "secret", 15000, 100, "https://server.amplitude.com", false) + api := newDirectCohortDownloadApi("api", "secret", 15000, "https://server.amplitude.com", false) t.Run("test_cohort_download_success", func(t *testing.T) { cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} diff --git a/pkg/experiment/local/config.go b/pkg/experiment/local/config.go index 94850ff..c896fd7 100644 --- a/pkg/experiment/local/config.go +++ b/pkg/experiment/local/config.go @@ -3,17 +3,23 @@ package local import ( "github.com/amplitude/analytics-go/amplitude" "math" - "strings" "time" ) const EUFlagServerUrl = "https://flag.lab.eu.amplitude.com" const EUCohortSyncUrl = "https://cohort-v2.lab.eu.amplitude.com" +type ServerZone int + +const ( + USServerZone ServerZone = iota + EUServerZone +) + type Config struct { Debug bool ServerUrl string - ServerZone string + ServerZone ServerZone FlagConfigPollerInterval time.Duration FlagConfigPollerRequestTimeout time.Duration AssignmentConfig *AssignmentConfig @@ -26,17 +32,17 @@ type AssignmentConfig struct { } type CohortSyncConfig struct { - ApiKey string - SecretKey string - MaxCohortSize int - CohortRequestDelayMillis int - CohortServerUrl string + ApiKey string + SecretKey string + MaxCohortSize int + CohortPollingInterval time.Duration + CohortServerUrl string } var DefaultConfig = &Config{ Debug: false, ServerUrl: "https://api.lab.amplitude.com/", - ServerZone: "us", + ServerZone: USServerZone, FlagConfigPollerInterval: 30 * time.Second, FlagConfigPollerRequestTimeout: 10 * time.Second, } @@ -46,22 +52,23 @@ var DefaultAssignmentConfig = &AssignmentConfig{ } var DefaultCohortSyncConfig = &CohortSyncConfig{ - MaxCohortSize: math.MaxInt32, - CohortRequestDelayMillis: 5000, - CohortServerUrl: "https://cohort-v2.lab.amplitude.com", + MaxCohortSize: math.MaxInt32, + CohortPollingInterval: 60 * time.Second, + CohortServerUrl: "https://cohort-v2.lab.amplitude.com", } func fillConfigDefaults(c *Config) *Config { if c == nil { return DefaultConfig } - if c.ServerZone == "" { + if c.ServerZone == 0 { c.ServerZone = DefaultConfig.ServerZone } if c.ServerUrl == "" { - if strings.EqualFold(strings.ToLower(c.ServerZone), strings.ToLower(DefaultConfig.ServerZone)) { + switch c.ServerZone { + case USServerZone: c.ServerUrl = DefaultConfig.ServerUrl - } else if strings.EqualFold(strings.ToLower(c.ServerZone), "eu") { + case EUServerZone: c.ServerUrl = EUFlagServerUrl } } @@ -80,14 +87,15 @@ func fillConfigDefaults(c *Config) *Config { c.CohortSyncConfig.MaxCohortSize = DefaultCohortSyncConfig.MaxCohortSize } - if c.CohortSyncConfig != nil && c.CohortSyncConfig.CohortRequestDelayMillis == 0 { - c.CohortSyncConfig.CohortRequestDelayMillis = DefaultCohortSyncConfig.CohortRequestDelayMillis + if c.CohortSyncConfig != nil && (c.CohortSyncConfig.CohortPollingInterval < DefaultCohortSyncConfig.CohortPollingInterval) { + c.CohortSyncConfig.CohortPollingInterval = DefaultCohortSyncConfig.CohortPollingInterval } if c.CohortSyncConfig != nil && c.CohortSyncConfig.CohortServerUrl == "" { - if strings.EqualFold(strings.ToLower(c.ServerZone), strings.ToLower(DefaultConfig.ServerZone)) { + switch c.ServerZone { + case USServerZone: c.CohortSyncConfig.CohortServerUrl = DefaultCohortSyncConfig.CohortServerUrl - } else if strings.EqualFold(strings.ToLower(c.ServerZone), "eu") { + case EUServerZone: c.CohortSyncConfig.CohortServerUrl = EUCohortSyncUrl } } diff --git a/pkg/experiment/local/config_test.go b/pkg/experiment/local/config_test.go index 8622b0c..6c790e7 100644 --- a/pkg/experiment/local/config_test.go +++ b/pkg/experiment/local/config_test.go @@ -1,7 +1,6 @@ package local import ( - "strings" "testing" "time" ) @@ -10,7 +9,7 @@ func TestFillConfigDefaults_ServerZoneAndServerUrl(t *testing.T) { tests := []struct { name string input *Config - expectedZone string + expectedZone ServerZone expectedUrl string }{ { @@ -27,26 +26,26 @@ func TestFillConfigDefaults_ServerZoneAndServerUrl(t *testing.T) { }, { name: "ServerZone US", - input: &Config{ServerZone: "us"}, - expectedZone: "us", + input: &Config{ServerZone: USServerZone}, + expectedZone: USServerZone, expectedUrl: DefaultConfig.ServerUrl, }, { name: "ServerZone EU", - input: &Config{ServerZone: "eu"}, - expectedZone: "eu", + input: &Config{ServerZone: EUServerZone}, + expectedZone: EUServerZone, expectedUrl: EUFlagServerUrl, }, { name: "Uppercase ServerZone EU", - input: &Config{ServerZone: "EU"}, - expectedZone: "EU", + input: &Config{ServerZone: EUServerZone}, + expectedZone: EUServerZone, expectedUrl: EUFlagServerUrl, }, { name: "Custom ServerUrl", - input: &Config{ServerZone: "us", ServerUrl: "https://custom.url/"}, - expectedZone: "us", + input: &Config{ServerZone: USServerZone, ServerUrl: "https://custom.url/"}, + expectedZone: USServerZone, expectedUrl: "https://custom.url/", }, } @@ -54,8 +53,8 @@ func TestFillConfigDefaults_ServerZoneAndServerUrl(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := fillConfigDefaults(tt.input) - if !strings.EqualFold(result.ServerZone, tt.expectedZone) { - t.Errorf("expected ServerZone %s, got %s", tt.expectedZone, result.ServerZone) + if result.ServerZone != tt.expectedZone { + t.Errorf("expected ServerZone %d, got %d", tt.expectedZone, result.ServerZone) } if result.ServerUrl != tt.expectedUrl { t.Errorf("expected ServerUrl %s, got %s", tt.expectedUrl, result.ServerUrl) @@ -73,7 +72,7 @@ func TestFillConfigDefaults_CohortSyncConfig(t *testing.T) { { name: "Nil CohortSyncConfig", input: &Config{ - ServerZone: "eu", + ServerZone: EUServerZone, CohortSyncConfig: nil, }, expectedUrl: "", @@ -81,7 +80,7 @@ func TestFillConfigDefaults_CohortSyncConfig(t *testing.T) { { name: "CohortSyncConfig with empty CohortServerUrl", input: &Config{ - ServerZone: "eu", + ServerZone: EUServerZone, CohortSyncConfig: &CohortSyncConfig{}, }, expectedUrl: EUCohortSyncUrl, @@ -89,7 +88,7 @@ func TestFillConfigDefaults_CohortSyncConfig(t *testing.T) { { name: "CohortSyncConfig with custom CohortServerUrl", input: &Config{ - ServerZone: "us", + ServerZone: USServerZone, CohortSyncConfig: &CohortSyncConfig{ CohortServerUrl: "https://custom-cohort.url/", }, @@ -141,13 +140,13 @@ func TestFillConfigDefaults_DefaultValues(t *testing.T) { { name: "Custom values", input: &Config{ - ServerZone: "eu", + ServerZone: EUServerZone, ServerUrl: "https://custom.url/", FlagConfigPollerInterval: 60 * time.Second, FlagConfigPollerRequestTimeout: 20 * time.Second, }, expected: &Config{ - ServerZone: "eu", + ServerZone: EUServerZone, ServerUrl: "https://custom.url/", FlagConfigPollerInterval: 60 * time.Second, FlagConfigPollerRequestTimeout: 20 * time.Second, @@ -159,7 +158,7 @@ func TestFillConfigDefaults_DefaultValues(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := fillConfigDefaults(tt.input) if result.ServerZone != tt.expected.ServerZone { - t.Errorf("expected ServerZone %s, got %s", tt.expected.ServerZone, result.ServerZone) + t.Errorf("expected ServerZone %d, got %d", tt.expected.ServerZone, result.ServerZone) } if result.ServerUrl != tt.expected.ServerUrl { t.Errorf("expected ServerUrl %s, got %s", tt.expected.ServerUrl, result.ServerUrl) diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index 332c633..040ea0f 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -2,16 +2,12 @@ package local import ( "fmt" - "strings" - "sync" - "time" - "github.com/amplitude/experiment-go-server/internal/evaluation" "github.com/amplitude/experiment-go-server/internal/logger" + "strings" + "sync" ) -const CohortPollerInterval = 60 * time.Second - type deploymentRunner struct { config *Config flagConfigApi flagConfigApi @@ -57,8 +53,8 @@ func (dr *deploymentRunner) start() error { } }) - if dr.cohortLoader != nil { - dr.poller.Poll(CohortPollerInterval, func() { + if dr.config.CohortSyncConfig != nil { + dr.poller.Poll(dr.config.CohortSyncConfig.CohortPollingInterval, func() { dr.updateStoredCohorts() }) } From 1fc1c3d0dabca009da359b005722a7452129d875 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 6 Aug 2024 14:30:49 -0700 Subject: [PATCH 28/29] make cohort storage private --- pkg/experiment/local/client.go | 2 +- pkg/experiment/local/cohort_loader.go | 4 ++-- pkg/experiment/local/cohort_storage.go | 2 +- pkg/experiment/local/deployment_runner.go | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index 989b710..07286a7 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -30,7 +30,7 @@ type Client struct { flagsMutex *sync.RWMutex engine *evaluation.Engine assignmentService *assignmentService - cohortStorage CohortStorage + cohortStorage cohortStorage flagConfigStorage flagConfigStorage cohortLoader *cohortLoader deploymentRunner *deploymentRunner diff --git a/pkg/experiment/local/cohort_loader.go b/pkg/experiment/local/cohort_loader.go index 3fa597b..d325315 100644 --- a/pkg/experiment/local/cohort_loader.go +++ b/pkg/experiment/local/cohort_loader.go @@ -7,13 +7,13 @@ import ( type cohortLoader struct { cohortDownloadApi cohortDownloadApi - cohortStorage CohortStorage + cohortStorage cohortStorage jobs sync.Map executor *sync.Pool lockJobs sync.Mutex } -func newCohortLoader(cohortDownloadApi cohortDownloadApi, cohortStorage CohortStorage) *cohortLoader { +func newCohortLoader(cohortDownloadApi cohortDownloadApi, cohortStorage cohortStorage) *cohortLoader { return &cohortLoader{ cohortDownloadApi: cohortDownloadApi, cohortStorage: cohortStorage, diff --git a/pkg/experiment/local/cohort_storage.go b/pkg/experiment/local/cohort_storage.go index 66981e5..7a81f84 100644 --- a/pkg/experiment/local/cohort_storage.go +++ b/pkg/experiment/local/cohort_storage.go @@ -4,7 +4,7 @@ import ( "sync" ) -type CohortStorage interface { +type cohortStorage interface { getCohort(cohortID string) *Cohort getCohorts() map[string]*Cohort getCohortsForUser(userID string, cohortIDs map[string]struct{}) map[string]struct{} diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index 040ea0f..2b85437 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -12,7 +12,7 @@ type deploymentRunner struct { config *Config flagConfigApi flagConfigApi flagConfigStorage flagConfigStorage - cohortStorage CohortStorage + cohortStorage cohortStorage cohortLoader *cohortLoader lock sync.Mutex poller *poller @@ -23,7 +23,7 @@ func newDeploymentRunner( config *Config, flagConfigApi flagConfigApi, flagConfigStorage flagConfigStorage, - cohortStorage CohortStorage, + cohortStorage cohortStorage, cohortLoader *cohortLoader, ) *deploymentRunner { dr := &deploymentRunner{ From 6d78f4956f10e26dfdd811334c7cc2cac82a4187 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 7 Aug 2024 14:01:55 -0700 Subject: [PATCH 29/29] add SDK+version to cohort request headers --- pkg/experiment/local/cohort_download_api.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go index 8fabac0..2d5ab8c 100644 --- a/pkg/experiment/local/cohort_download_api.go +++ b/pkg/experiment/local/cohort_download_api.go @@ -3,7 +3,9 @@ package local import ( "encoding/base64" "encoding/json" + "fmt" "github.com/amplitude/experiment-go-server/internal/logger" + "github.com/amplitude/experiment-go-server/pkg/experiment" "net/http" "strconv" "time" @@ -101,6 +103,7 @@ func (api *directCohortDownloadApi) getCohortMembersRequest(client *http.Client, return nil, err } req.Header.Set("Authorization", "Basic "+api.getBasicAuth()) + req.Header.Set("X-Amp-Exp-Library", fmt.Sprintf("experiment-go-server/%v", experiment.VERSION)) return client.Do(req) }