diff --git a/.mockery.yaml b/.mockery.yaml index 410e9b031..9913a542d 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -143,4 +143,13 @@ packages: filename: "activity_service.go" Repository: config: - filename: "servicedata_repository.go" \ No newline at end of file + filename: "servicedata_repository.go" + github.com/goto/shield/internal/store/inmemory: + config: + dir: "internal/store/inmemory/mocks" + outpkg: "mocks" + mockname: "{{.InterfaceName}}" + interfaces: + GroupRepository: + config: + filename: "group_repository.go" \ No newline at end of file diff --git a/cmd/serve.go b/cmd/serve.go index fd96f090b..52a837b23 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -34,6 +34,7 @@ import ( "github.com/goto/shield/internal/schema" "github.com/goto/shield/internal/server" "github.com/goto/shield/internal/store/blob" + "github.com/goto/shield/internal/store/inmemory" "github.com/goto/shield/internal/store/postgres" "github.com/goto/shield/internal/store/spicedb" "github.com/goto/shield/pkg/db" @@ -149,7 +150,7 @@ func StartServer(logger *log.Zap, cfg *config.Shield) error { return err } - deps, err := BuildAPIDependencies(ctx, logger, activityRepository, resourceBlobRepository, dbClient, spiceDBClient) + deps, err := BuildAPIDependencies(ctx, logger, activityRepository, resourceBlobRepository, dbClient, spiceDBClient, cfg) if err != nil { return err } @@ -191,7 +192,13 @@ func BuildAPIDependencies( resourceBlobRepository *blob.ResourcesRepository, dbc *db.Client, sdb *spicedb.SpiceDB, + cfg *config.Shield, ) (api.Deps, error) { + cache, err := inmemory.NewCache(cfg.App.CacheConfig) + if err != nil { + return api.Deps{}, err + } + appConfig := activity.AppConfig{Version: config.Version} activityService := activity.NewService(appConfig, activityRepository) @@ -212,7 +219,8 @@ func BuildAPIDependencies( relationService := relation.NewService(logger, relationPGRepository, relationSpiceRepository, userService, activityService) groupRepository := postgres.NewGroupRepository(dbc) - groupService := group.NewService(logger, groupRepository, relationService, userService, activityService) + cachedGroupRepository := inmemory.NewCachedGroupRepository(cache, groupRepository) + groupService := group.NewService(logger, groupRepository, cachedGroupRepository, relationService, userService, activityService) organizationRepository := postgres.NewOrganizationRepository(dbc) organizationService := organization.NewService(logger, organizationRepository, relationService, userService, activityService) diff --git a/core/group/group.go b/core/group/group.go index c924709cc..33c27edf4 100644 --- a/core/group/group.go +++ b/core/group/group.go @@ -22,6 +22,10 @@ type Repository interface { ListGroupRelations(ctx context.Context, objectId, subjectType, role string) ([]relation.RelationV2, error) } +type CachedRepository interface { + GetBySlug(ctx context.Context, slug string) (Group, error) +} + type Group struct { ID string Name string diff --git a/core/group/service.go b/core/group/service.go index 57b2cd3a7..632cc41c0 100644 --- a/core/group/service.go +++ b/core/group/service.go @@ -40,15 +40,17 @@ type ActivityService interface { type Service struct { logger log.Logger repository Repository + cacheRepository CachedRepository relationService RelationService userService UserService activityService ActivityService } -func NewService(logger log.Logger, repository Repository, relationService RelationService, userService UserService, activityService ActivityService) *Service { +func NewService(logger log.Logger, repository Repository, cacheRepository CachedRepository, relationService RelationService, userService UserService, activityService ActivityService) *Service { return &Service{ logger: logger, repository: repository, + cacheRepository: cacheRepository, relationService: relationService, userService: userService, activityService: activityService, @@ -90,7 +92,7 @@ func (s Service) Get(ctx context.Context, idOrSlug string) (Group, error) { } func (s Service) GetBySlug(ctx context.Context, slug string) (Group, error) { - return s.repository.GetBySlug(ctx, slug) + return s.cacheRepository.GetBySlug(ctx, slug) } func (s Service) GetByIDs(ctx context.Context, groupIDs []string) ([]Group, error) { diff --git a/go.mod b/go.mod index c0eec2a53..bc838217c 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/authzed/authzed-go v0.7.1-0.20221109204547-1aa903788b3b github.com/authzed/grpcutil v0.0.0-20230908193239-4286bb1d6403 github.com/authzed/spicedb v1.15.0 + github.com/dgraph-io/ristretto v0.1.1 github.com/doug-martin/goqu/v9 v9.18.0 github.com/envoyproxy/protoc-gen-validate v1.0.4 github.com/ghodss/yaml v1.0.0 @@ -76,6 +77,7 @@ require ( github.com/cenkalti/backoff v2.2.1+incompatible // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/charmbracelet/glamour v0.6.0 // indirect github.com/cli/safeexec v1.0.1 // indirect github.com/containerd/continuity v0.3.0 // indirect @@ -84,6 +86,7 @@ require ( github.com/dlclark/regexp2 v1.9.0 // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/dustin/go-humanize v1.0.0 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/fatih/color v1.14.1 // indirect github.com/felixge/fgprof v0.9.3 // indirect @@ -93,6 +96,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/golang/glog v1.2.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/cel-go v0.13.0 // indirect github.com/google/pprof v0.0.0-20221219190121-3cb0bae90811 // indirect diff --git a/go.sum b/go.sum index fb18f969a..5e8604064 100644 --- a/go.sum +++ b/go.sum @@ -632,6 +632,8 @@ github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d/go.mod h1:sGbDF6 github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charmbracelet/glamour v0.6.0 h1:wi8fse3Y7nfcabbbDuwolqTqMQPMnVPeZhDM273bISc= github.com/charmbracelet/glamour v0.6.0/go.mod h1:taqWV4swIMMbWALc0m7AfE9JkPSU8om2538k9ITBxOc= github.com/checkpoint-restore/go-criu/v4 v4.1.0/go.mod h1:xUQBLp4RLc5zJtWY++yjOoMoB5lihDt7fai+75m+rGw= @@ -811,8 +813,12 @@ github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27 github.com/dennwc/varint v1.0.0/go.mod h1:hnItb35rvZvJrbTALZtY/iQfDs48JKRG1RPpgziApxA= github.com/denverdino/aliyungo v0.0.0-20190125010748-a747050bb1ba/go.mod h1:dV8lFg6daOBZbT6/BDGIz6Y3WFGn8juu6G+CQ6LHtl0= github.com/devigned/tab v0.1.1/go.mod h1:XG9mPq0dFghrYvoBF3xdRrJzSTX1b7IQrvaL9mzjeJY= +github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= +github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= github.com/dgrijalva/jwt-go v0.0.0-20170104182250-a601269ab70c/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= +github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dgryski/go-sip13 v0.0.0-20200911182023-62edffca9245/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dhui/dktest v0.3.16 h1:i6gq2YQEtcrjKbeJpBkWjE8MmLZPYllcjOFbTZuPDnw= @@ -852,6 +858,7 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3 github.com/doug-martin/goqu/v9 v9.18.0 h1:/6bcuEtAe6nsSMVK/M+fOiXUNfyFF3yYtE07DBPFMYY= github.com/doug-martin/goqu/v9 v9.18.0/go.mod h1:nf0Wc2/hV3gYK9LiyqIrzBEVGlI8qW3GuDCEobC4wBQ= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= @@ -2509,6 +2516,7 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220908150016-7ac13a9a928d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/server/config.go b/internal/server/config.go index a93116ec3..a45ee2dfd 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -2,6 +2,8 @@ package server import ( "fmt" + + "github.com/goto/shield/internal/store/inmemory" ) type GRPCConfig struct { @@ -63,4 +65,6 @@ type Config struct { ServiceData ServiceDataConfig `yaml:"service_data" mapstructure:"service_data"` PublicAPIPrefix string `yaml:"public_api_prefix" mapstructure:"public_api_prefix" default:"/shield"` + + CacheConfig inmemory.Config `yaml:"cache" mapstructure:"cache"` } diff --git a/internal/store/inmemory/cache.go b/internal/store/inmemory/cache.go new file mode 100644 index 000000000..53d3a65e8 --- /dev/null +++ b/internal/store/inmemory/cache.go @@ -0,0 +1,39 @@ +package inmemory + +import ( + "errors" + + "github.com/dgraph-io/ristretto" +) + +var ErrParsing = errors.New("parsing error") + +type Config struct { + NumCounters int64 `yaml:"num_counters" mapstructure:"num_counters" default:"10000000"` + MaxCost int64 `yaml:"max_cost" mapstructure:"max_cost" default:"1073741824"` + BufferItems int64 `yaml:"buffer_items" mapstructure:"buffer_items" default:"64"` + Metrics bool `yaml:"metrics" mapstructure:"metrics" default:"true"` + TTLInSeconds int `yaml:"ttl_in_seconds" mapstructure:"ttl_in_seconds" default:"3600"` +} + +type Cache struct { + *ristretto.Cache + config Config +} + +func NewCache(cfg Config) (Cache, error) { + cache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: cfg.NumCounters, + MaxCost: cfg.MaxCost, + BufferItems: cfg.BufferItems, + Metrics: cfg.Metrics, + }) + if err != nil { + return Cache{}, err + } + + return Cache{ + Cache: cache, + config: cfg, + }, nil +} diff --git a/internal/store/inmemory/group_repository.go b/internal/store/inmemory/group_repository.go new file mode 100644 index 000000000..c30e5e19f --- /dev/null +++ b/internal/store/inmemory/group_repository.go @@ -0,0 +1,51 @@ +package inmemory + +import ( + "context" + "fmt" + + "github.com/goto/shield/core/group" +) + +var keyPrefix = "group" + +type GroupRepository interface { + GetBySlug(ctx context.Context, slug string) (group.Group, error) +} + +type CachedGroupRepository struct { + cache Cache + repository GroupRepository +} + +func NewCachedGroupRepository(cache Cache, repository GroupRepository) *CachedGroupRepository { + return &CachedGroupRepository{ + cache: cache, + repository: repository, + } +} + +func getKey(identifier string) string { + return fmt.Sprintf("%s:%s", keyPrefix, identifier) +} + +func (r CachedGroupRepository) GetBySlug(ctx context.Context, slug string) (group.Group, error) { + key := getKey(slug) + grp, found := r.cache.Get(key) + if !found { + grp, err := r.repository.GetBySlug(ctx, slug) + if err != nil { + return group.Group{}, err + } + + r.cache.Set(key, grp, 0) + return grp, nil + } + + grpParsed, ok := grp.(group.Group) + if !ok { + return group.Group{}, ErrParsing + } + + return grpParsed, nil +} diff --git a/internal/store/inmemory/group_repository_test.go b/internal/store/inmemory/group_repository_test.go new file mode 100644 index 000000000..9a1e9d13c --- /dev/null +++ b/internal/store/inmemory/group_repository_test.go @@ -0,0 +1,124 @@ +package inmemory + +import ( + "context" + "errors" + "testing" + + "github.com/goto/shield/core/group" + "github.com/goto/shield/internal/store/inmemory/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + testCacheConfig = Config{ + NumCounters: 10000000, + MaxCost: 1073741824, + BufferItems: 64, + Metrics: true, + TTLInSeconds: 3600, + } + testGroupSlug = "test-group-slug" + testGroup = group.Group{ + ID: "test-group-id", + Slug: testGroupSlug, + Name: "test group", + } +) + +func TestGetBySlug(t *testing.T) { + t.Parallel() + + testCases := []struct { + description string + slug string + setup func(t *testing.T) *CachedGroupRepository + want group.Group + wantErr error + }{ + { + description: "should retrieve group from cache", + slug: testGroupSlug, + setup: func(t *testing.T) *CachedGroupRepository { + t.Helper() + groupRepository := &mocks.GroupRepository{} + c, err := NewCache(testCacheConfig) + if err != nil { + return nil + } + c.Set(getKey(testGroupSlug), testGroup, 0) + c.Wait() + return NewCachedGroupRepository(c, groupRepository) + }, + want: testGroup, + }, + { + description: "should retrieve group from repository", + slug: testGroupSlug, + setup: func(t *testing.T) *CachedGroupRepository { + t.Helper() + groupRepository := &mocks.GroupRepository{} + c, err := NewCache(testCacheConfig) + if err != nil { + return nil + } + groupRepository.EXPECT().GetBySlug(mock.Anything, testGroupSlug). + Return(testGroup, nil) + return NewCachedGroupRepository(c, groupRepository) + }, + want: testGroup, + }, + { + description: "should return parse error if cache data invalid", + slug: testGroupSlug, + setup: func(t *testing.T) *CachedGroupRepository { + t.Helper() + groupRepository := &mocks.GroupRepository{} + c, err := NewCache(testCacheConfig) + if err != nil { + return nil + } + c.Set(getKey(testGroupSlug), "invalid-group-data", 0) + c.Wait() + return NewCachedGroupRepository(c, groupRepository) + }, + wantErr: ErrParsing, + }, + { + description: "should return error from repository", + slug: testGroupSlug, + setup: func(t *testing.T) *CachedGroupRepository { + t.Helper() + groupRepository := &mocks.GroupRepository{} + c, err := NewCache(testCacheConfig) + if err != nil { + return nil + } + groupRepository.EXPECT().GetBySlug(mock.Anything, testGroupSlug). + Return(group.Group{}, group.ErrInvalidDetail) + return NewCachedGroupRepository(c, groupRepository) + }, + wantErr: group.ErrInvalidDetail, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.description, func(t *testing.T) { + t.Parallel() + cacheRepo := tc.setup(t) + assert.NotNil(t, cacheRepo) + + ctx := context.TODO() + got, err := cacheRepo.GetBySlug(ctx, tc.slug) + if tc.wantErr != nil { + assert.Error(t, err) + assert.True(t, errors.Is(err, tc.wantErr)) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/internal/store/inmemory/mocks/group_repository.go b/internal/store/inmemory/mocks/group_repository.go new file mode 100644 index 000000000..0e2c718be --- /dev/null +++ b/internal/store/inmemory/mocks/group_repository.go @@ -0,0 +1,94 @@ +// Code generated by mockery v2.42.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + group "github.com/goto/shield/core/group" + mock "github.com/stretchr/testify/mock" +) + +// GroupRepository is an autogenerated mock type for the GroupRepository type +type GroupRepository struct { + mock.Mock +} + +type GroupRepository_Expecter struct { + mock *mock.Mock +} + +func (_m *GroupRepository) EXPECT() *GroupRepository_Expecter { + return &GroupRepository_Expecter{mock: &_m.Mock} +} + +// GetBySlug provides a mock function with given fields: ctx, slug +func (_m *GroupRepository) GetBySlug(ctx context.Context, slug string) (group.Group, error) { + ret := _m.Called(ctx, slug) + + if len(ret) == 0 { + panic("no return value specified for GetBySlug") + } + + var r0 group.Group + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (group.Group, error)); ok { + return rf(ctx, slug) + } + if rf, ok := ret.Get(0).(func(context.Context, string) group.Group); ok { + r0 = rf(ctx, slug) + } else { + r0 = ret.Get(0).(group.Group) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, slug) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GroupRepository_GetBySlug_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetBySlug' +type GroupRepository_GetBySlug_Call struct { + *mock.Call +} + +// GetBySlug is a helper method to define mock.On call +// - ctx context.Context +// - slug string +func (_e *GroupRepository_Expecter) GetBySlug(ctx interface{}, slug interface{}) *GroupRepository_GetBySlug_Call { + return &GroupRepository_GetBySlug_Call{Call: _e.mock.On("GetBySlug", ctx, slug)} +} + +func (_c *GroupRepository_GetBySlug_Call) Run(run func(ctx context.Context, slug string)) *GroupRepository_GetBySlug_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *GroupRepository_GetBySlug_Call) Return(_a0 group.Group, _a1 error) *GroupRepository_GetBySlug_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *GroupRepository_GetBySlug_Call) RunAndReturn(run func(context.Context, string) (group.Group, error)) *GroupRepository_GetBySlug_Call { + _c.Call.Return(run) + return _c +} + +// NewGroupRepository creates a new instance of GroupRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewGroupRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *GroupRepository { + mock := &GroupRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/test/e2e_test/testbench/testbench.go b/test/e2e_test/testbench/testbench.go index 24bf46fff..84326bc1d 100644 --- a/test/e2e_test/testbench/testbench.go +++ b/test/e2e_test/testbench/testbench.go @@ -16,6 +16,7 @@ import ( "github.com/goto/shield/config" "github.com/goto/shield/internal/proxy" "github.com/goto/shield/internal/server" + "github.com/goto/shield/internal/store/inmemory" "github.com/goto/shield/internal/store/postgres/migrations" "github.com/goto/shield/internal/store/spicedb" "github.com/goto/shield/pkg/db" @@ -206,6 +207,13 @@ func SetupTests(t *testing.T) (shieldv1beta1.ShieldServiceClient, shieldv1beta1. MaxNumUpsertData: 1, }, PublicAPIPrefix: "/shield", + CacheConfig: inmemory.Config{ + NumCounters: 10000000, + MaxCost: 1073741824, + BufferItems: 64, + Metrics: true, + TTLInSeconds: 3600, + }, }, Proxy: proxy.ServicesConfig{ Services: []proxy.Config{