diff --git a/core/mocks/activity_service.go b/core/mocks/activity_service.go new file mode 100644 index 000000000..ddb6d4ba0 --- /dev/null +++ b/core/mocks/activity_service.go @@ -0,0 +1,81 @@ +// Code generated by mockery v2.32.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// ActivityService is an autogenerated mock type for the ActivityService type +type ActivityService struct { + mock.Mock +} + +type ActivityService_Expecter struct { + mock *mock.Mock +} + +func (_m *ActivityService) EXPECT() *ActivityService_Expecter { + return &ActivityService_Expecter{mock: &_m.Mock} +} + +// Log provides a mock function with given fields: ctx, action, actor, data +func (_m *ActivityService) Log(ctx context.Context, action string, actor string, data interface{}) error { + ret := _m.Called(ctx, action, actor, data) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, interface{}) error); ok { + r0 = rf(ctx, action, actor, data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ActivityService_Log_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Log' +type ActivityService_Log_Call struct { + *mock.Call +} + +// Log is a helper method to define mock.On call +// - ctx context.Context +// - action string +// - actor string +// - data interface{} +func (_e *ActivityService_Expecter) Log(ctx interface{}, action interface{}, actor interface{}, data interface{}) *ActivityService_Log_Call { + return &ActivityService_Log_Call{Call: _e.mock.On("Log", ctx, action, actor, data)} +} + +func (_c *ActivityService_Log_Call) Run(run func(ctx context.Context, action string, actor string, data interface{})) *ActivityService_Log_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(interface{})) + }) + return _c +} + +func (_c *ActivityService_Log_Call) Return(_a0 error) *ActivityService_Log_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *ActivityService_Log_Call) RunAndReturn(run func(context.Context, string, string, interface{}) error) *ActivityService_Log_Call { + _c.Call.Return(run) + return _c +} + +// NewActivityService creates a new instance of ActivityService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewActivityService(t interface { + mock.TestingT + Cleanup(func()) +}) *ActivityService { + mock := &ActivityService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/mocks/user_repository.go b/core/mocks/user_repository.go new file mode 100644 index 000000000..8411e6cab --- /dev/null +++ b/core/mocks/user_repository.go @@ -0,0 +1,465 @@ +// Code generated by mockery v2.32.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + user "github.com/goto/shield/core/user" + mock "github.com/stretchr/testify/mock" +) + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// Create provides a mock function with given fields: ctx, _a1 +func (_m *Repository) Create(ctx context.Context, _a1 user.User) (user.User, error) { + ret := _m.Called(ctx, _a1) + + var r0 user.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, user.User) (user.User, error)); ok { + return rf(ctx, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, user.User) user.User); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Get(0).(user.User) + } + + if rf, ok := ret.Get(1).(func(context.Context, user.User) error); ok { + r1 = rf(ctx, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type Repository_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - _a1 user.User +func (_e *Repository_Expecter) Create(ctx interface{}, _a1 interface{}) *Repository_Create_Call { + return &Repository_Create_Call{Call: _e.mock.On("Create", ctx, _a1)} +} + +func (_c *Repository_Create_Call) Run(run func(ctx context.Context, _a1 user.User)) *Repository_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(user.User)) + }) + return _c +} + +func (_c *Repository_Create_Call) Return(_a0 user.User, _a1 error) *Repository_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_Create_Call) RunAndReturn(run func(context.Context, user.User) (user.User, error)) *Repository_Create_Call { + _c.Call.Return(run) + return _c +} + +// CreateMetadataKey provides a mock function with given fields: ctx, key +func (_m *Repository) CreateMetadataKey(ctx context.Context, key user.UserMetadataKey) (user.UserMetadataKey, error) { + ret := _m.Called(ctx, key) + + var r0 user.UserMetadataKey + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, user.UserMetadataKey) (user.UserMetadataKey, error)); ok { + return rf(ctx, key) + } + if rf, ok := ret.Get(0).(func(context.Context, user.UserMetadataKey) user.UserMetadataKey); ok { + r0 = rf(ctx, key) + } else { + r0 = ret.Get(0).(user.UserMetadataKey) + } + + if rf, ok := ret.Get(1).(func(context.Context, user.UserMetadataKey) error); ok { + r1 = rf(ctx, key) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_CreateMetadataKey_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateMetadataKey' +type Repository_CreateMetadataKey_Call struct { + *mock.Call +} + +// CreateMetadataKey is a helper method to define mock.On call +// - ctx context.Context +// - key user.UserMetadataKey +func (_e *Repository_Expecter) CreateMetadataKey(ctx interface{}, key interface{}) *Repository_CreateMetadataKey_Call { + return &Repository_CreateMetadataKey_Call{Call: _e.mock.On("CreateMetadataKey", ctx, key)} +} + +func (_c *Repository_CreateMetadataKey_Call) Run(run func(ctx context.Context, key user.UserMetadataKey)) *Repository_CreateMetadataKey_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(user.UserMetadataKey)) + }) + return _c +} + +func (_c *Repository_CreateMetadataKey_Call) Return(_a0 user.UserMetadataKey, _a1 error) *Repository_CreateMetadataKey_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_CreateMetadataKey_Call) RunAndReturn(run func(context.Context, user.UserMetadataKey) (user.UserMetadataKey, error)) *Repository_CreateMetadataKey_Call { + _c.Call.Return(run) + return _c +} + +// GetByEmail provides a mock function with given fields: ctx, email +func (_m *Repository) GetByEmail(ctx context.Context, email string) (user.User, error) { + ret := _m.Called(ctx, email) + + var r0 user.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (user.User, error)); ok { + return rf(ctx, email) + } + if rf, ok := ret.Get(0).(func(context.Context, string) user.User); ok { + r0 = rf(ctx, email) + } else { + r0 = ret.Get(0).(user.User) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, email) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_GetByEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByEmail' +type Repository_GetByEmail_Call struct { + *mock.Call +} + +// GetByEmail is a helper method to define mock.On call +// - ctx context.Context +// - email string +func (_e *Repository_Expecter) GetByEmail(ctx interface{}, email interface{}) *Repository_GetByEmail_Call { + return &Repository_GetByEmail_Call{Call: _e.mock.On("GetByEmail", ctx, email)} +} + +func (_c *Repository_GetByEmail_Call) Run(run func(ctx context.Context, email string)) *Repository_GetByEmail_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_GetByEmail_Call) Return(_a0 user.User, _a1 error) *Repository_GetByEmail_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_GetByEmail_Call) RunAndReturn(run func(context.Context, string) (user.User, error)) *Repository_GetByEmail_Call { + _c.Call.Return(run) + return _c +} + +// GetByID provides a mock function with given fields: ctx, id +func (_m *Repository) GetByID(ctx context.Context, id string) (user.User, error) { + ret := _m.Called(ctx, id) + + var r0 user.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (user.User, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) user.User); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(user.User) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_GetByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByID' +type Repository_GetByID_Call struct { + *mock.Call +} + +// GetByID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) GetByID(ctx interface{}, id interface{}) *Repository_GetByID_Call { + return &Repository_GetByID_Call{Call: _e.mock.On("GetByID", ctx, id)} +} + +func (_c *Repository_GetByID_Call) Run(run func(ctx context.Context, id string)) *Repository_GetByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_GetByID_Call) Return(_a0 user.User, _a1 error) *Repository_GetByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_GetByID_Call) RunAndReturn(run func(context.Context, string) (user.User, error)) *Repository_GetByID_Call { + _c.Call.Return(run) + return _c +} + +// GetByIDs provides a mock function with given fields: ctx, userIds +func (_m *Repository) GetByIDs(ctx context.Context, userIds []string) ([]user.User, error) { + ret := _m.Called(ctx, userIds) + + var r0 []user.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []string) ([]user.User, error)); ok { + return rf(ctx, userIds) + } + if rf, ok := ret.Get(0).(func(context.Context, []string) []user.User); ok { + r0 = rf(ctx, userIds) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]user.User) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok { + r1 = rf(ctx, userIds) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_GetByIDs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByIDs' +type Repository_GetByIDs_Call struct { + *mock.Call +} + +// GetByIDs is a helper method to define mock.On call +// - ctx context.Context +// - userIds []string +func (_e *Repository_Expecter) GetByIDs(ctx interface{}, userIds interface{}) *Repository_GetByIDs_Call { + return &Repository_GetByIDs_Call{Call: _e.mock.On("GetByIDs", ctx, userIds)} +} + +func (_c *Repository_GetByIDs_Call) Run(run func(ctx context.Context, userIds []string)) *Repository_GetByIDs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]string)) + }) + return _c +} + +func (_c *Repository_GetByIDs_Call) Return(_a0 []user.User, _a1 error) *Repository_GetByIDs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_GetByIDs_Call) RunAndReturn(run func(context.Context, []string) ([]user.User, error)) *Repository_GetByIDs_Call { + _c.Call.Return(run) + return _c +} + +// List provides a mock function with given fields: ctx, flt +func (_m *Repository) List(ctx context.Context, flt user.Filter) ([]user.User, error) { + ret := _m.Called(ctx, flt) + + var r0 []user.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, user.Filter) ([]user.User, error)); ok { + return rf(ctx, flt) + } + if rf, ok := ret.Get(0).(func(context.Context, user.Filter) []user.User); ok { + r0 = rf(ctx, flt) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]user.User) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, user.Filter) error); ok { + r1 = rf(ctx, flt) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type Repository_List_Call struct { + *mock.Call +} + +// List is a helper method to define mock.On call +// - ctx context.Context +// - flt user.Filter +func (_e *Repository_Expecter) List(ctx interface{}, flt interface{}) *Repository_List_Call { + return &Repository_List_Call{Call: _e.mock.On("List", ctx, flt)} +} + +func (_c *Repository_List_Call) Run(run func(ctx context.Context, flt user.Filter)) *Repository_List_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(user.Filter)) + }) + return _c +} + +func (_c *Repository_List_Call) Return(_a0 []user.User, _a1 error) *Repository_List_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_List_Call) RunAndReturn(run func(context.Context, user.Filter) ([]user.User, error)) *Repository_List_Call { + _c.Call.Return(run) + return _c +} + +// UpdateByEmail provides a mock function with given fields: ctx, toUpdate +func (_m *Repository) UpdateByEmail(ctx context.Context, toUpdate user.User) (user.User, error) { + ret := _m.Called(ctx, toUpdate) + + var r0 user.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, user.User) (user.User, error)); ok { + return rf(ctx, toUpdate) + } + if rf, ok := ret.Get(0).(func(context.Context, user.User) user.User); ok { + r0 = rf(ctx, toUpdate) + } else { + r0 = ret.Get(0).(user.User) + } + + if rf, ok := ret.Get(1).(func(context.Context, user.User) error); ok { + r1 = rf(ctx, toUpdate) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_UpdateByEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateByEmail' +type Repository_UpdateByEmail_Call struct { + *mock.Call +} + +// UpdateByEmail is a helper method to define mock.On call +// - ctx context.Context +// - toUpdate user.User +func (_e *Repository_Expecter) UpdateByEmail(ctx interface{}, toUpdate interface{}) *Repository_UpdateByEmail_Call { + return &Repository_UpdateByEmail_Call{Call: _e.mock.On("UpdateByEmail", ctx, toUpdate)} +} + +func (_c *Repository_UpdateByEmail_Call) Run(run func(ctx context.Context, toUpdate user.User)) *Repository_UpdateByEmail_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(user.User)) + }) + return _c +} + +func (_c *Repository_UpdateByEmail_Call) Return(_a0 user.User, _a1 error) *Repository_UpdateByEmail_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_UpdateByEmail_Call) RunAndReturn(run func(context.Context, user.User) (user.User, error)) *Repository_UpdateByEmail_Call { + _c.Call.Return(run) + return _c +} + +// UpdateByID provides a mock function with given fields: ctx, toUpdate +func (_m *Repository) UpdateByID(ctx context.Context, toUpdate user.User) (user.User, error) { + ret := _m.Called(ctx, toUpdate) + + var r0 user.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, user.User) (user.User, error)); ok { + return rf(ctx, toUpdate) + } + if rf, ok := ret.Get(0).(func(context.Context, user.User) user.User); ok { + r0 = rf(ctx, toUpdate) + } else { + r0 = ret.Get(0).(user.User) + } + + if rf, ok := ret.Get(1).(func(context.Context, user.User) error); ok { + r1 = rf(ctx, toUpdate) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_UpdateByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateByID' +type Repository_UpdateByID_Call struct { + *mock.Call +} + +// UpdateByID is a helper method to define mock.On call +// - ctx context.Context +// - toUpdate user.User +func (_e *Repository_Expecter) UpdateByID(ctx interface{}, toUpdate interface{}) *Repository_UpdateByID_Call { + return &Repository_UpdateByID_Call{Call: _e.mock.On("UpdateByID", ctx, toUpdate)} +} + +func (_c *Repository_UpdateByID_Call) Run(run func(ctx context.Context, toUpdate user.User)) *Repository_UpdateByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(user.User)) + }) + return _c +} + +func (_c *Repository_UpdateByID_Call) Return(_a0 user.User, _a1 error) *Repository_UpdateByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_UpdateByID_Call) RunAndReturn(run func(context.Context, user.User) (user.User, error)) *Repository_UpdateByID_Call { + _c.Call.Return(run) + return _c +} + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/user/service.go b/core/user/service.go index 03a50853c..0f38e8af9 100644 --- a/core/user/service.go +++ b/core/user/service.go @@ -6,13 +6,14 @@ import ( "strings" "github.com/goto/salt/log" + pkgctx "github.com/goto/shield/pkg/context" "github.com/goto/shield/pkg/uuid" ) const ( - auditKeyUserCreate = "user.create" - auditKeyUserUpdate = "user.update" + AuditKeyUserCreate = "user.create" + AuditKeyUserUpdate = "user.update" auditKeyUserMetadataKeyCreate = "user_metadata_key.create" ) @@ -61,7 +62,7 @@ func (s Service) Create(ctx context.Context, user User) (User, error) { newUser, err := s.repository.Create(ctx, User{ Name: user.Name, - Email: user.Email, + Email: strings.ToLower(user.Email), Metadata: user.Metadata, }) if err != nil { @@ -71,7 +72,7 @@ func (s Service) Create(ctx context.Context, user User) (User, error) { go func() { ctx := pkgctx.WithoutCancel(ctx) userLogData := newUser.ToUserLogData() - if err := s.activityService.Log(ctx, auditKeyUserCreate, currentUser.ID, userLogData); err != nil { + if err := s.activityService.Log(ctx, AuditKeyUserCreate, currentUser.ID, userLogData); err != nil { s.logger.Error(fmt.Sprintf("%s: %s", ErrLogActivity.Error(), err.Error())) } }() @@ -122,7 +123,12 @@ func (s Service) UpdateByID(ctx context.Context, toUpdate User) (User, error) { s.logger.Error(fmt.Sprintf("%s: %s", ErrInvalidEmail.Error(), err.Error())) } - updatedUser, err := s.repository.UpdateByID(ctx, toUpdate) + updatedUser, err := s.repository.UpdateByID(ctx, User{ + ID: toUpdate.ID, + Name: toUpdate.Name, + Email: strings.ToLower(toUpdate.Email), + Metadata: toUpdate.Metadata, + }) if err != nil { return User{}, err } @@ -130,7 +136,7 @@ func (s Service) UpdateByID(ctx context.Context, toUpdate User) (User, error) { go func() { ctx := pkgctx.WithoutCancel(ctx) userLogData := updatedUser.ToUserLogData() - if err := s.activityService.Log(ctx, auditKeyUserUpdate, currentUser.ID, userLogData); err != nil { + if err := s.activityService.Log(ctx, AuditKeyUserUpdate, currentUser.ID, userLogData); err != nil { s.logger.Error(fmt.Sprintf("%s: %s", ErrLogActivity.Error(), err.Error())) } }() @@ -144,7 +150,11 @@ func (s Service) UpdateByEmail(ctx context.Context, toUpdate User) (User, error) s.logger.Error(fmt.Sprintf("%s: %s", ErrInvalidEmail.Error(), err.Error())) } - updatedUser, err := s.repository.UpdateByEmail(ctx, toUpdate) + updatedUser, err := s.repository.UpdateByEmail(ctx, User{ + Name: toUpdate.Name, + Email: strings.ToLower(toUpdate.Email), + Metadata: toUpdate.Metadata, + }) if err != nil { return User{}, err } @@ -152,7 +162,7 @@ func (s Service) UpdateByEmail(ctx context.Context, toUpdate User) (User, error) go func() { ctx := pkgctx.WithoutCancel(ctx) userLogData := updatedUser.ToUserLogData() - if err := s.activityService.Log(ctx, auditKeyUserUpdate, currentUser.ID, userLogData); err != nil { + if err := s.activityService.Log(ctx, AuditKeyUserUpdate, currentUser.ID, userLogData); err != nil { s.logger.Error(fmt.Sprintf("%s: %s", ErrLogActivity.Error(), err.Error())) } }() diff --git a/core/user/service_test.go b/core/user/service_test.go new file mode 100644 index 000000000..30587f2c2 --- /dev/null +++ b/core/user/service_test.go @@ -0,0 +1,385 @@ +package user_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/goto/shield/core/user" + shieldlogger "github.com/goto/shield/pkg/logger" + + "github.com/goto/shield/core/mocks" + "github.com/goto/shield/pkg/logger" +) + +func TestService_Create(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + user user.User + setup func(t *testing.T) *user.Service + urn string + want user.User + wantErr error + }{ + { + name: "CreateUserWithUpperCase", + user: user.User{ + Name: "John Doe", + Email: "John.Doe@gotocompany.com", + }, + setup: func(t *testing.T) *user.Service { + t.Helper() + repository := &mocks.Repository{} + activityService := &mocks.ActivityService{} + logger := shieldlogger.InitLogger(logger.Config{}) + repository.EXPECT(). + Create(mock.Anything, user.User{ + Name: "John Doe", + Email: "john.doe@gotocompany.com"}). + Return(user.User{ + Name: "John Doe", + Email: "john.doe@gotocompany.com"}, nil).Once() + + activityService.EXPECT(). + Log(mock.Anything, user.AuditKeyUserCreate, "", user.UserLogData{Entity: "user", Name: "John Doe", Email: "john.doe@gotocompany.com"}).Return(nil).Once() + return user.NewService(logger, repository, activityService) + }, + want: user.User{ + Name: "John Doe", + Email: "john.doe@gotocompany.com", + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := tt.setup(t) + + assert.NotNil(t, svc) + + got, err := svc.Create(context.Background(), tt.user) + if tt.wantErr != nil { + assert.Error(t, err) + assert.True(t, errors.Is(err, tt.wantErr)) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestService_UpdateByID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + user user.User + setup func(t *testing.T) *user.Service + urn string + want user.User + wantErr error + }{ + { + name: "UpdateUserWithUpperCase", + user: user.User{ + ID: "1", + Name: "John Doe", + Email: "John.Doe2@gotocompany.com", + }, + setup: func(t *testing.T) *user.Service { + t.Helper() + repository := &mocks.Repository{} + activityService := &mocks.ActivityService{} + logger := shieldlogger.InitLogger(logger.Config{}) + repository.EXPECT(). + UpdateByID(mock.Anything, user.User{ + ID: "1", + Name: "John Doe", + Email: "john.doe2@gotocompany.com"}). + Return(user.User{ + ID: "1", + Name: "John Doe", + Email: "john.doe2@gotocompany.com"}, nil).Once() + + activityService.EXPECT(). + Log(mock.Anything, user.AuditKeyUserUpdate, "", user.UserLogData{Entity: "user", Name: "John Doe", Email: "john.doe2@gotocompany.com"}).Return(nil).Once() + return user.NewService(logger, repository, activityService) + }, + want: user.User{ + ID: "1", + Name: "John Doe", + Email: "john.doe2@gotocompany.com", + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := tt.setup(t) + + assert.NotNil(t, svc) + + got, err := svc.UpdateByID(context.Background(), tt.user) + if tt.wantErr != nil { + assert.Error(t, err) + assert.True(t, errors.Is(err, tt.wantErr)) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestService_UpdateByEmail(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + user user.User + setup func(t *testing.T) *user.Service + urn string + want user.User + wantErr error + }{ + { + name: "UpdateUserWithUpperCase", + user: user.User{ + Name: "John Doe", + Email: "John.Doe2@gotocompany.com", + }, + setup: func(t *testing.T) *user.Service { + t.Helper() + repository := &mocks.Repository{} + activityService := &mocks.ActivityService{} + logger := shieldlogger.InitLogger(logger.Config{}) + repository.EXPECT(). + UpdateByEmail(mock.Anything, user.User{ + Name: "John Doe", + Email: "john.doe2@gotocompany.com"}). + Return(user.User{ + Name: "John Doe", + Email: "john.doe2@gotocompany.com"}, nil).Once() + + activityService.EXPECT(). + Log(mock.Anything, user.AuditKeyUserUpdate, "", user.UserLogData{Entity: "user", Name: "John Doe", Email: "john.doe2@gotocompany.com"}).Return(nil).Once() + return user.NewService(logger, repository, activityService) + }, + want: user.User{ + Name: "John Doe", + Email: "john.doe2@gotocompany.com", + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := tt.setup(t) + + assert.NotNil(t, svc) + + got, err := svc.UpdateByEmail(context.Background(), tt.user) + if tt.wantErr != nil { + assert.Error(t, err) + assert.True(t, errors.Is(err, tt.wantErr)) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestService_GetByEmail(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + email string + setup func(t *testing.T) *user.Service + urn string + want user.User + wantErr error + }{ + { + name: "GetUserByEmail", + email: "john.doe@gotocompany.com", + setup: func(t *testing.T) *user.Service { + t.Helper() + repository := &mocks.Repository{} + activityService := &mocks.ActivityService{} + logger := shieldlogger.InitLogger(logger.Config{}) + repository.EXPECT(). + GetByEmail(mock.Anything, "john.doe@gotocompany.com"). + Return(user.User{ + ID: "1", + Name: "John Doe", + Email: "john.doe@gotocompany.com"}, nil).Once() + + activityService.EXPECT(). + Log(mock.Anything, user.AuditKeyUserUpdate, "", user.UserLogData{Entity: "user", Name: "John Doe", Email: "john.doe2@gotocompany.com"}).Return(nil).Once() + return user.NewService(logger, repository, activityService) + }, + want: user.User{ + ID: "1", + Name: "John Doe", + Email: "john.doe@gotocompany.com", + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := tt.setup(t) + + assert.NotNil(t, svc) + + got, err := svc.GetByEmail(context.Background(), tt.email) + if tt.wantErr != nil { + assert.Error(t, err) + assert.True(t, errors.Is(err, tt.wantErr)) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestService_GetByID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + uuid string + setup func(t *testing.T) *user.Service + urn string + want user.User + wantErr error + }{ + { + name: "GetUserByEmail", + uuid: "qwer-1234-tyui-5678-opas-90", + setup: func(t *testing.T) *user.Service { + t.Helper() + repository := &mocks.Repository{} + activityService := &mocks.ActivityService{} + logger := shieldlogger.InitLogger(logger.Config{}) + repository.EXPECT(). + GetByID(mock.Anything, "qwer-1234-tyui-5678-opas-90"). + Return(user.User{ + ID: "qwer-1234-tyui-5678-opas-90", + Name: "John Doe", + Email: "john.doe@gotocompany.com"}, nil).Once() + + activityService.EXPECT(). + Log(mock.Anything, user.AuditKeyUserUpdate, "", user.UserLogData{Entity: "user", Name: "John Doe", Email: "john.doe2@gotocompany.com"}).Return(nil).Once() + return user.NewService(logger, repository, activityService) + }, + want: user.User{ + ID: "qwer-1234-tyui-5678-opas-90", + Name: "John Doe", + Email: "john.doe@gotocompany.com", + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := tt.setup(t) + + assert.NotNil(t, svc) + + got, err := svc.GetByID(context.Background(), tt.uuid) + if tt.wantErr != nil { + assert.Error(t, err) + assert.True(t, errors.Is(err, tt.wantErr)) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestService_FetchCurrentUser(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + email string + setup func(t *testing.T) *user.Service + urn string + want user.User + wantErr error + }{ + { + name: "FetchCurrentUser", + email: "john.doe@gotocompany.com", + setup: func(t *testing.T) *user.Service { + t.Helper() + repository := &mocks.Repository{} + activityService := &mocks.ActivityService{} + logger := shieldlogger.InitLogger(logger.Config{}) + repository.EXPECT(). + GetByEmail(mock.Anything, "john.doe@gotocompany.com"). + Return(user.User{ + ID: "qwer-1234-tyui-5678-opas-90", + Name: "John Doe", + Email: "john.doe@gotocompany.com"}, nil).Once() + + activityService.EXPECT(). + Log(mock.Anything, user.AuditKeyUserUpdate, "", user.UserLogData{Entity: "user", Name: "John Doe", Email: "john.doe2@gotocompany.com"}).Return(nil).Once() + return user.NewService(logger, repository, activityService) + }, + want: user.User{ + ID: "qwer-1234-tyui-5678-opas-90", + Name: "John Doe", + Email: "john.doe@gotocompany.com", + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := tt.setup(t) + + assert.NotNil(t, svc) + + ctx := user.SetContextWithEmail(context.Background(), tt.email) + + got, err := svc.FetchCurrentUser(ctx) + if tt.wantErr != nil { + assert.Error(t, err) + assert.True(t, errors.Is(err, tt.wantErr)) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/api/v1beta1/errors.go b/internal/api/v1beta1/errors.go index a2d17e2ec..42ce431a3 100644 --- a/internal/api/v1beta1/errors.go +++ b/internal/api/v1beta1/errors.go @@ -1,9 +1,10 @@ package v1beta1 import ( - "github.com/goto/shield/pkg/errors" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + "github.com/goto/shield/pkg/errors" ) // HTTP Codes defined here: diff --git a/internal/api/v1beta1/user.go b/internal/api/v1beta1/user.go index 86c51613b..8dff7d230 100644 --- a/internal/api/v1beta1/user.go +++ b/internal/api/v1beta1/user.go @@ -3,6 +3,7 @@ package v1beta1 import ( "context" "errors" + "net/mail" "strings" grpczap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" @@ -89,6 +90,10 @@ func (h Handler) CreateUser(ctx context.Context, request *shieldv1beta1.CreateUs email = currentUserEmail } + if !isValidEmail(email) { + return nil, user.ErrInvalidEmail + } + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { logger.Error(err.Error()) @@ -240,6 +245,10 @@ func (h Handler) UpdateUser(ctx context.Context, request *shieldv1beta1.UpdateUs return nil, grpcBadBodyError } + if !isValidEmail(email) { + return nil, user.ErrInvalidEmail + } + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { return nil, grpcBadBodyError @@ -248,9 +257,9 @@ func (h Handler) UpdateUser(ctx context.Context, request *shieldv1beta1.UpdateUs id := request.GetId() if uuid.IsValid(id) { updatedUser, err = h.userService.UpdateByID(ctx, user.User{ - ID: request.GetId(), + ID: id, Name: request.GetBody().GetName(), - Email: request.GetBody().GetEmail(), + Email: email, Metadata: metaDataMap, }) if err != nil { @@ -282,7 +291,7 @@ func (h Handler) UpdateUser(ctx context.Context, request *shieldv1beta1.UpdateUs updatedUser, err = h.userService.UpdateByEmail(ctx, user.User{ Name: request.GetBody().GetName(), - Email: request.GetBody().GetEmail(), + Email: email, Metadata: metaDataMap, }) if err != nil { @@ -392,3 +401,8 @@ func transformUserToPB(usr user.User) (shieldv1beta1.User, error) { UpdatedAt: timestamppb.New(usr.UpdatedAt), }, nil } + +func isValidEmail(email string) bool { + _, err := mail.ParseAddress(email) + return err == nil +} diff --git a/internal/api/v1beta1/user_test.go b/internal/api/v1beta1/user_test.go index dd04750b6..2d940fb77 100644 --- a/internal/api/v1beta1/user_test.go +++ b/internal/api/v1beta1/user_test.go @@ -173,7 +173,23 @@ func TestCreateUser(t *testing.T) { want: nil, err: grpcBadBodyError, }, - + { + title: "should return invalid email error if email is invalid", + setup: func(ctx context.Context, us *mocks.UserService) context.Context { + return user.SetContextWithEmail(ctx, email) + }, + req: &shieldv1beta1.CreateUserRequest{Body: &shieldv1beta1.UserRequestBody{ + Name: "some user", + Email: "invalid email", + Metadata: &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "foo": structpb.NewNullValue(), + }, + }, + }}, + want: nil, + err: user.ErrInvalidEmail, + }, { title: "should return already exist error if user service return error conflict", setup: func(ctx context.Context, us *mocks.UserService) context.Context { @@ -568,6 +584,25 @@ func TestUpdateUser(t *testing.T) { want: nil, err: grpcBadBodyError, }, + { + title: "should return invalid email error if email is invalid", + setup: func(us *mocks.UserService) { + }, + req: &shieldv1beta1.UpdateUserRequest{ + Id: someID, + Body: &shieldv1beta1.UserRequestBody{ + Name: "abc user", + Email: "invalid email", + Metadata: &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "foo": structpb.NewStringValue("bar"), + }, + }, + }, + }, + want: nil, + err: user.ErrInvalidEmail, + }, { title: "should return bad request error if empty request body", req: &shieldv1beta1.UpdateUserRequest{Id: someID, Body: nil},