From e5035e81fe9f02bac75a420c7ce05f0e5fffba13 Mon Sep 17 00:00:00 2001 From: Nikos Date: Tue, 12 Nov 2024 13:43:18 +0200 Subject: [PATCH] refactor: merge user and device code storage --- handler/rfc8628/auth_handler.go | 40 +++------- handler/rfc8628/auth_handler_test.go | 25 ++---- handler/rfc8628/storage.go | 31 ++------ handler/rfc8628/token_handler_test.go | 32 ++++++-- integration/helper_setup_test.go | 5 +- internal/device_code_storage.go | 71 +++++++++-------- internal/rfc8628_core_storage.go | 40 ++-------- storage/memory.go | 105 ++++++++------------------ 8 files changed, 130 insertions(+), 219 deletions(-) diff --git a/handler/rfc8628/auth_handler.go b/handler/rfc8628/auth_handler.go index 70a6af8a..b5ebec12 100644 --- a/handler/rfc8628/auth_handler.go +++ b/handler/rfc8628/auth_handler.go @@ -30,14 +30,7 @@ type DeviceAuthHandler struct { func (d *DeviceAuthHandler) HandleDeviceEndpointRequest(ctx context.Context, dar fosite.DeviceRequester, resp fosite.DeviceResponder) error { var err error - var deviceCode string - deviceCode, err = d.handleDeviceCode(ctx, dar) - if err != nil { - return err - } - - var userCode string - userCode, err = d.handleUserCode(ctx, dar) + deviceCode, userCode, err := d.handleDeviceAuthSession(ctx, dar) if err != nil { return err } @@ -52,23 +45,15 @@ func (d *DeviceAuthHandler) HandleDeviceEndpointRequest(ctx context.Context, dar return nil } -func (d *DeviceAuthHandler) handleDeviceCode(ctx context.Context, dar fosite.DeviceRequester) (string, error) { - code, signature, err := d.Strategy.GenerateDeviceCode(ctx) - if err != nil { - return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } +func (d *DeviceAuthHandler) handleDeviceAuthSession(ctx context.Context, dar fosite.DeviceRequester) (string, string, error) { + var userCode, userCodeSignature string - dar.GetSession().SetExpiresAt(fosite.DeviceCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx))) - if err = d.Storage.CreateDeviceCodeSession(ctx, signature, dar.Sanitize(nil)); err != nil { - return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + deviceCode, deviceCodeSignature, err := d.Strategy.GenerateDeviceCode(ctx) + if err != nil { + return "", "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } - return code, nil -} - -func (d *DeviceAuthHandler) handleUserCode(ctx context.Context, dar fosite.DeviceRequester) (string, error) { - var err error - var userCode, signature string + dar.GetSession().SetExpiresAt(fosite.UserCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)).Round(time.Second)) // Note: the retries are added here because we need to ensure uniqueness of user codes. // The chances of duplicates should however be diminishing, because they are the same // chance an attacker will be able to hit a valid code with few guesses. However, as @@ -76,17 +61,16 @@ func (d *DeviceAuthHandler) handleUserCode(ctx context.Context, dar fosite.Devic // the chances of hitting a duplicate here can be higher. // Three retries should be plenty, as otherwise the entropy is definitely off. for i := 0; i < MaxAttempts; i++ { - userCode, signature, err = d.Strategy.GenerateUserCode(ctx) + userCode, userCodeSignature, err = d.Strategy.GenerateUserCode(ctx) if err != nil { - return "", err + return "", "", err } - dar.GetSession().SetExpiresAt(fosite.UserCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)).Round(time.Second)) - if err = d.Storage.CreateUserCodeSession(ctx, signature, dar.Sanitize(nil)); err == nil { - return userCode, nil + if err = d.Storage.CreateDeviceAuthSession(ctx, deviceCodeSignature, userCodeSignature, dar.Sanitize(nil)); err == nil { + return deviceCode, userCode, nil } } errMsg := fmt.Sprintf("Exceeded user-code generation max attempts %v: %s", MaxAttempts, err.Error()) - return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(errMsg)) + return "", "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(errMsg)) } diff --git a/handler/rfc8628/auth_handler_test.go b/handler/rfc8628/auth_handler_test.go index bbe2d525..e09d8cb7 100644 --- a/handler/rfc8628/auth_handler_test.go +++ b/handler/rfc8628/auth_handler_test.go @@ -85,20 +85,15 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) { EXPECT(). GenerateDeviceCode(ctx). Return("deviceCode", "signature", nil) - mockRFC8628CoreStorage. - EXPECT(). - CreateDeviceCodeSession(ctx, "signature", gomock.Any()). - Return(nil) mockRFC8628CodeStrategy. EXPECT(). GenerateUserCode(ctx). - Return("userCode", "signature", nil). + Return("userCode", "signature2", nil). Times(1) mockRFC8628CoreStorage. EXPECT(). - CreateUserCodeSession(ctx, "signature", gomock.Any()). - Return(nil). - Times(1) + CreateDeviceAuthSession(ctx, "signature", "signature2", gomock.Any()). + Return(nil) }, check: func(t *testing.T, resp *fosite.DeviceResponse) { assert.Equal(t, "userCode", resp.GetUserCode()) @@ -111,10 +106,6 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) { EXPECT(). GenerateDeviceCode(ctx). Return("deviceCode", "signature", nil) - mockRFC8628CoreStorage. - EXPECT(). - CreateDeviceCodeSession(ctx, "signature", gomock.Any()). - Return(nil) gomock.InOrder( mockRFC8628CodeStrategy. EXPECT(). @@ -122,7 +113,7 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) { Return("duplicatedUserCode", "duplicatedSignature", nil), mockRFC8628CoreStorage. EXPECT(). - CreateUserCodeSession(ctx, "duplicatedSignature", gomock.Any()). + CreateDeviceAuthSession(ctx, "signature", "duplicatedSignature", gomock.Any()). Return(errors.New("unique constraint violation")), mockRFC8628CodeStrategy. EXPECT(). @@ -130,7 +121,7 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) { Return("uniqueUserCode", "uniqueSignature", nil), mockRFC8628CoreStorage. EXPECT(). - CreateUserCodeSession(ctx, "uniqueSignature", gomock.Any()). + CreateDeviceAuthSession(ctx, "signature", "uniqueSignature", gomock.Any()). Return(nil), ) }, @@ -145,10 +136,6 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) { EXPECT(). GenerateDeviceCode(ctx). Return("deviceCode", "signature", nil) - mockRFC8628CoreStorage. - EXPECT(). - CreateDeviceCodeSession(ctx, "signature", gomock.Any()). - Return(nil) mockRFC8628CodeStrategy. EXPECT(). GenerateUserCode(ctx). @@ -156,7 +143,7 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) { Times(rfc8628.MaxAttempts) mockRFC8628CoreStorage. EXPECT(). - CreateUserCodeSession(ctx, "duplicatedSignature", gomock.Any()). + CreateDeviceAuthSession(ctx, "signature", "duplicatedSignature", gomock.Any()). Return(errors.New("unique constraint violation")). Times(rfc8628.MaxAttempts) }, diff --git a/handler/rfc8628/storage.go b/handler/rfc8628/storage.go index 8ae1b35e..e15dc97d 100644 --- a/handler/rfc8628/storage.go +++ b/handler/rfc8628/storage.go @@ -12,16 +12,15 @@ import ( // RFC8628CoreStorage is the storage needed for the DeviceAuthHandler type RFC8628CoreStorage interface { - DeviceCodeStorage - UserCodeStorage + DeviceAuthStorage oauth2.AccessTokenStorage oauth2.RefreshTokenStorage } -// DeviceCodeStorage handles the device_code storage -type DeviceCodeStorage interface { - // CreateDeviceCodeSession stores the device request for a given device code. - CreateDeviceCodeSession(ctx context.Context, signature string, request fosite.Requester) (err error) +// DeviceAuthStorage handles the device auth session storage +type DeviceAuthStorage interface { + // CreateDeviceAuthSession stores the device auth request session. + CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, request fosite.Requester) (err error) // GetDeviceCodeSession hydrates the session based on the given device code and returns the device request. // If the device code has been invalidated with `InvalidateDeviceCodeSession`, this @@ -30,26 +29,8 @@ type DeviceCodeStorage interface { // Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedDeviceCode error! GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) - // InvalidateDeviceCodeSession is called when a device code is being used. The state of the user + // InvalidateDeviceCodeSession is called when a device code is being used. The state of the device // code should be set to invalid and consecutive requests to GetDeviceCodeSession should return the // ErrInvalidatedDeviceCode error. InvalidateDeviceCodeSession(ctx context.Context, signature string) (err error) } - -// UserCodeStorage handles the user_code storage -type UserCodeStorage interface { - // CreateUserCodeSession stores the device request for a given user code. - CreateUserCodeSession(ctx context.Context, signature string, request fosite.Requester) (err error) - - // GetUserCodeSession hydrates the session based on the given user code and returns the device request. - // If the user code has been invalidated with `InvalidateUserCodeSession`, this - // method should return the ErrInvalidatedUserCode error. - // - // Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedUserCode error! - GetUserCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) - - // InvalidateUserCodeSession is called when a user code is being used. The state of the user - // code should be set to invalid and consecutive requests to GetUserCodeSession should return the - // ErrInvalidatedUserCode error. - InvalidateUserCodeSession(ctx context.Context, signature string) (err error) -} diff --git a/handler/rfc8628/token_handler_test.go b/handler/rfc8628/token_handler_test.go index c8a71d55..7a4a4d36 100644 --- a/handler/rfc8628/token_handler_test.go +++ b/handler/rfc8628/token_handler_test.go @@ -154,9 +154,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) { code, signature, err := strategy.GenerateDeviceCode(context.TODO()) require.NoError(t, err) + _, userCodeSignature, err := strategy.GenerateUserCode(context.TODO()) + require.NoError(t, err) areq.Form.Add("device_code", code) - require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq)) + require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq)) }, expectErr: fosite.ErrAuthorizationPending, }, @@ -192,9 +194,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) { code, signature, err := strategy.GenerateDeviceCode(context.TODO()) require.NoError(t, err) + _, userCodeSignature, err := strategy.GenerateUserCode(context.TODO()) + require.NoError(t, err) areq.Form.Add("device_code", code) - require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq)) + require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq)) }, expectErr: fosite.ErrDeviceExpiredToken, }, @@ -227,9 +231,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) { token, signature, err := strategy.GenerateDeviceCode(context.TODO()) require.NoError(t, err) + _, userCodeSignature, err := strategy.GenerateUserCode(context.TODO()) + require.NoError(t, err) areq.Form = url.Values{"device_code": {token}} - require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq)) + require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq)) }, expectErr: fosite.ErrInvalidGrant, }, @@ -263,9 +269,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) { token, signature, err := strategy.GenerateDeviceCode(context.TODO()) require.NoError(t, err) + _, userCodeSignature, err := strategy.GenerateUserCode(context.TODO()) + require.NoError(t, err) areq.Form = url.Values{"device_code": {token}} - require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq)) + require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq)) }, }, } @@ -342,9 +350,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest_RateLimiting(t *testing.T) { token, signature, err := strategy.GenerateDeviceCode(context.TODO()) require.NoError(t, err) + _, userCodeSignature, err := strategy.GenerateUserCode(context.TODO()) + require.NoError(t, err) areq.Form = url.Values{"device_code": {token}} - require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq)) + require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq)) err = h.HandleTokenEndpointRequest(context.Background(), areq) require.NoError(t, err, "%+v", err) err = h.HandleTokenEndpointRequest(context.Background(), areq) @@ -441,9 +451,11 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) { setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest, _ *fosite.Config) { code, signature, err := strategy.GenerateDeviceCode(context.TODO()) require.NoError(t, err) + _, userCodeSignature, err := strategy.GenerateUserCode(context.TODO()) + require.NoError(t, err) areq.Form.Add("device_code", code) - require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq)) + require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq)) }, check: func(t *testing.T, aresp *fosite.AccessResponse) { assert.NotEmpty(t, aresp.AccessToken) @@ -483,9 +495,11 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) { config.RefreshTokenScopes = []string{} code, signature, err := strategy.GenerateDeviceCode(context.TODO()) require.NoError(t, err) + _, userCodeSignature, err := strategy.GenerateUserCode(context.TODO()) + require.NoError(t, err) areq.Form.Add("device_code", code) - require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq)) + require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq)) }, check: func(t *testing.T, aresp *fosite.AccessResponse) { assert.NotEmpty(t, aresp.AccessToken) @@ -524,9 +538,11 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) { setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest, config *fosite.Config) { code, signature, err := strategy.GenerateDeviceCode(context.TODO()) require.NoError(t, err) + _, userCodeSignature, err := strategy.GenerateUserCode(context.TODO()) + require.NoError(t, err) areq.Form.Add("device_code", code) - require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq)) + require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq)) }, check: func(t *testing.T, aresp *fosite.AccessResponse) { assert.NotEmpty(t, aresp.AccessToken) diff --git a/integration/helper_setup_test.go b/integration/helper_setup_test.go index 61be3b4b..f314e885 100644 --- a/integration/helper_setup_test.go +++ b/integration/helper_setup_test.go @@ -123,9 +123,8 @@ var fositeStore = &storage.MemoryStore{ AccessTokenRequestIDs: map[string]string{}, RefreshTokenRequestIDs: map[string]string{}, PARSessions: map[string]fosite.AuthorizeRequester{}, - DeviceCodes: map[string]fosite.Requester{}, - UserCodes: map[string]fosite.Requester{}, - DeviceCodesRequestIDs: map[string]string{}, + DeviceAuths: map[string]fosite.Requester{}, + DeviceCodesRequestIDs: map[string]storage.DeviceAuthPair{}, UserCodesRequestIDs: map[string]string{}, } diff --git a/internal/device_code_storage.go b/internal/device_code_storage.go index d4363044..cca61326 100644 --- a/internal/device_code_storage.go +++ b/internal/device_code_storage.go @@ -2,12 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ory/fosite/handler/rfc8628 (interfaces: DeviceCodeStorage) -// -// Generated by this command: -// -// mockgen -package internal -destination internal/device_code_storage.go github.com/ory/fosite/handler/rfc8628 DeviceCodeStorage -// +// Source: github.com/ory/fosite/handler/rfc8628 (interfaces: DeviceAuthStorage) + // Package internal is a generated GoMock package. package internal @@ -15,49 +11,49 @@ import ( context "context" reflect "reflect" + gomock "github.com/golang/mock/gomock" fosite "github.com/ory/fosite" - gomock "go.uber.org/mock/gomock" ) -// MockDeviceCodeStorage is a mock of DeviceCodeStorage interface. -type MockDeviceCodeStorage struct { +// MockDeviceAuthStorage is a mock of DeviceAuthStorage interface. +type MockDeviceAuthStorage struct { ctrl *gomock.Controller - recorder *MockDeviceCodeStorageMockRecorder + recorder *MockDeviceAuthStorageMockRecorder } -// MockDeviceCodeStorageMockRecorder is the mock recorder for MockDeviceCodeStorage. -type MockDeviceCodeStorageMockRecorder struct { - mock *MockDeviceCodeStorage +// MockDeviceAuthStorageMockRecorder is the mock recorder for MockDeviceAuthStorage. +type MockDeviceAuthStorageMockRecorder struct { + mock *MockDeviceAuthStorage } -// NewMockDeviceCodeStorage creates a new mock instance. -func NewMockDeviceCodeStorage(ctrl *gomock.Controller) *MockDeviceCodeStorage { - mock := &MockDeviceCodeStorage{ctrl: ctrl} - mock.recorder = &MockDeviceCodeStorageMockRecorder{mock} +// NewMockDeviceAuthStorage creates a new mock instance. +func NewMockDeviceAuthStorage(ctrl *gomock.Controller) *MockDeviceAuthStorage { + mock := &MockDeviceAuthStorage{ctrl: ctrl} + mock.recorder = &MockDeviceAuthStorageMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockDeviceCodeStorage) EXPECT() *MockDeviceCodeStorageMockRecorder { +func (m *MockDeviceAuthStorage) EXPECT() *MockDeviceAuthStorageMockRecorder { return m.recorder } -// CreateDeviceCodeSession mocks base method. -func (m *MockDeviceCodeStorage) CreateDeviceCodeSession(arg0 context.Context, arg1 string, arg2 fosite.Requester) error { +// CreateDeviceAuthSession mocks base method. +func (m *MockDeviceAuthStorage) CreateDeviceAuthSession(arg0 context.Context, arg1, arg2 string, arg3 fosite.Requester) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateDeviceCodeSession", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "CreateDeviceAuthSession", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } -// CreateDeviceCodeSession indicates an expected call of CreateDeviceCodeSession. -func (mr *MockDeviceCodeStorageMockRecorder) CreateDeviceCodeSession(arg0, arg1, arg2 any) *gomock.Call { +// CreateDeviceAuthSession indicates an expected call of CreateDeviceAuthSession. +func (mr *MockDeviceAuthStorageMockRecorder) CreateDeviceAuthSession(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDeviceCodeSession", reflect.TypeOf((*MockDeviceCodeStorage)(nil).CreateDeviceCodeSession), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDeviceAuthSession", reflect.TypeOf((*MockDeviceAuthStorage)(nil).CreateDeviceAuthSession), arg0, arg1, arg2, arg3) } // GetDeviceCodeSession mocks base method. -func (m *MockDeviceCodeStorage) GetDeviceCodeSession(arg0 context.Context, arg1 string, arg2 fosite.Session) (fosite.Requester, error) { +func (m *MockDeviceAuthStorage) GetDeviceCodeSession(arg0 context.Context, arg1 string, arg2 fosite.Session) (fosite.Requester, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetDeviceCodeSession", arg0, arg1, arg2) ret0, _ := ret[0].(fosite.Requester) @@ -66,13 +62,28 @@ func (m *MockDeviceCodeStorage) GetDeviceCodeSession(arg0 context.Context, arg1 } // GetDeviceCodeSession indicates an expected call of GetDeviceCodeSession. -func (mr *MockDeviceCodeStorageMockRecorder) GetDeviceCodeSession(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockDeviceAuthStorageMockRecorder) GetDeviceCodeSession(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDeviceCodeSession", reflect.TypeOf((*MockDeviceAuthStorage)(nil).GetDeviceCodeSession), arg0, arg1, arg2) +} + +// GetUserCodeSession mocks base method. +func (m *MockDeviceAuthStorage) GetUserCodeSession(arg0 context.Context, arg1 string, arg2 fosite.Session) (fosite.Requester, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserCodeSession", arg0, arg1, arg2) + ret0, _ := ret[0].(fosite.Requester) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserCodeSession indicates an expected call of GetUserCodeSession. +func (mr *MockDeviceAuthStorageMockRecorder) GetUserCodeSession(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDeviceCodeSession", reflect.TypeOf((*MockDeviceCodeStorage)(nil).GetDeviceCodeSession), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserCodeSession", reflect.TypeOf((*MockDeviceAuthStorage)(nil).GetUserCodeSession), arg0, arg1, arg2) } // InvalidateDeviceCodeSession mocks base method. -func (m *MockDeviceCodeStorage) InvalidateDeviceCodeSession(arg0 context.Context, arg1 string) error { +func (m *MockDeviceAuthStorage) InvalidateDeviceCodeSession(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "InvalidateDeviceCodeSession", arg0, arg1) ret0, _ := ret[0].(error) @@ -80,7 +91,7 @@ func (m *MockDeviceCodeStorage) InvalidateDeviceCodeSession(arg0 context.Context } // InvalidateDeviceCodeSession indicates an expected call of InvalidateDeviceCodeSession. -func (mr *MockDeviceCodeStorageMockRecorder) InvalidateDeviceCodeSession(arg0, arg1 any) *gomock.Call { +func (mr *MockDeviceAuthStorageMockRecorder) InvalidateDeviceCodeSession(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InvalidateDeviceCodeSession", reflect.TypeOf((*MockDeviceCodeStorage)(nil).InvalidateDeviceCodeSession), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InvalidateDeviceCodeSession", reflect.TypeOf((*MockDeviceAuthStorage)(nil).InvalidateDeviceCodeSession), arg0, arg1) } diff --git a/internal/rfc8628_core_storage.go b/internal/rfc8628_core_storage.go index fa9d1ab5..89c43548 100644 --- a/internal/rfc8628_core_storage.go +++ b/internal/rfc8628_core_storage.go @@ -56,18 +56,18 @@ func (mr *MockRFC8628CoreStorageMockRecorder) CreateAccessTokenSession(arg0, arg return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAccessTokenSession", reflect.TypeOf((*MockRFC8628CoreStorage)(nil).CreateAccessTokenSession), arg0, arg1, arg2) } -// CreateDeviceCodeSession mocks base method. -func (m *MockRFC8628CoreStorage) CreateDeviceCodeSession(arg0 context.Context, arg1 string, arg2 fosite.Requester) error { +// CreateDeviceAuthSession mocks base method. +func (m *MockRFC8628CoreStorage) CreateDeviceAuthSession(arg0 context.Context, arg1, arg2 string, arg3 fosite.Requester) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateDeviceCodeSession", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "CreateDeviceAuthSession", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } -// CreateDeviceCodeSession indicates an expected call of CreateDeviceCodeSession. -func (mr *MockRFC8628CoreStorageMockRecorder) CreateDeviceCodeSession(arg0, arg1, arg2 any) *gomock.Call { +// CreateDeviceAuthSession indicates an expected call of CreateDeviceAuthSession. +func (mr *MockRFC8628CoreStorageMockRecorder) CreateDeviceAuthSession(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDeviceCodeSession", reflect.TypeOf((*MockRFC8628CoreStorage)(nil).CreateDeviceCodeSession), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDeviceAuthSession", reflect.TypeOf((*MockRFC8628CoreStorage)(nil).CreateDeviceAuthSession), arg0, arg1, arg2, arg3) } // CreateRefreshTokenSession mocks base method. @@ -84,20 +84,6 @@ func (mr *MockRFC8628CoreStorageMockRecorder) CreateRefreshTokenSession(arg0, ar return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRefreshTokenSession", reflect.TypeOf((*MockRFC8628CoreStorage)(nil).CreateRefreshTokenSession), arg0, arg1, arg2) } -// CreateUserCodeSession mocks base method. -func (m *MockRFC8628CoreStorage) CreateUserCodeSession(arg0 context.Context, arg1 string, arg2 fosite.Requester) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateUserCodeSession", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// CreateUserCodeSession indicates an expected call of CreateUserCodeSession. -func (mr *MockRFC8628CoreStorageMockRecorder) CreateUserCodeSession(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUserCodeSession", reflect.TypeOf((*MockRFC8628CoreStorage)(nil).CreateUserCodeSession), arg0, arg1, arg2) -} - // DeleteAccessTokenSession mocks base method. func (m *MockRFC8628CoreStorage) DeleteAccessTokenSession(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() @@ -199,17 +185,3 @@ func (mr *MockRFC8628CoreStorageMockRecorder) InvalidateDeviceCodeSession(arg0, mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InvalidateDeviceCodeSession", reflect.TypeOf((*MockRFC8628CoreStorage)(nil).InvalidateDeviceCodeSession), arg0, arg1) } - -// InvalidateUserCodeSession mocks base method. -func (m *MockRFC8628CoreStorage) InvalidateUserCodeSession(arg0 context.Context, arg1 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InvalidateUserCodeSession", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// InvalidateUserCodeSession indicates an expected call of InvalidateUserCodeSession. -func (mr *MockRFC8628CoreStorageMockRecorder) InvalidateUserCodeSession(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InvalidateUserCodeSession", reflect.TypeOf((*MockRFC8628CoreStorage)(nil).InvalidateUserCodeSession), arg0, arg1) -} diff --git a/storage/memory.go b/storage/memory.go index a4e6dde7..7887938e 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -36,21 +36,25 @@ type PublicKeyScopes struct { Scopes []string } +type DeviceAuthPair struct { + d string + u string +} + type MemoryStore struct { Clients map[string]fosite.Client AuthorizeCodes map[string]StoreAuthorizeCode IDSessions map[string]fosite.Requester AccessTokens map[string]fosite.Requester RefreshTokens map[string]StoreRefreshToken - DeviceCodes map[string]fosite.Requester - UserCodes map[string]fosite.Requester + DeviceAuths map[string]fosite.Requester PKCES map[string]fosite.Requester Users map[string]MemoryUserRelation BlacklistedJTIs map[string]time.Time // In-memory request ID to token signatures AccessTokenRequestIDs map[string]string RefreshTokenRequestIDs map[string]string - DeviceCodesRequestIDs map[string]string + DeviceCodesRequestIDs map[string]DeviceAuthPair UserCodesRequestIDs map[string]string // Public keys to check signature in auth grant jwt assertion. IssuerPublicKeys map[string]IssuerPublicKeys @@ -61,15 +65,13 @@ type MemoryStore struct { idSessionsMutex sync.RWMutex accessTokensMutex sync.RWMutex refreshTokensMutex sync.RWMutex - userCodesMutex sync.RWMutex - deviceCodesMutex sync.RWMutex + deviceAuthsMutex sync.RWMutex pkcesMutex sync.RWMutex usersMutex sync.RWMutex blacklistedJTIsMutex sync.RWMutex accessTokenRequestIDsMutex sync.RWMutex refreshTokenRequestIDsMutex sync.RWMutex - userCodesRequestIDsMutex sync.RWMutex - deviceCodesRequestIDsMutex sync.RWMutex + deviceAuthsRequestIDsMutex sync.RWMutex issuerPublicKeysMutex sync.RWMutex parSessionsMutex sync.RWMutex } @@ -81,13 +83,12 @@ func NewMemoryStore() *MemoryStore { IDSessions: make(map[string]fosite.Requester), AccessTokens: make(map[string]fosite.Requester), RefreshTokens: make(map[string]StoreRefreshToken), - DeviceCodes: make(map[string]fosite.Requester), - UserCodes: make(map[string]fosite.Requester), + DeviceAuths: make(map[string]fosite.Requester), PKCES: make(map[string]fosite.Requester), Users: make(map[string]MemoryUserRelation), AccessTokenRequestIDs: make(map[string]string), RefreshTokenRequestIDs: make(map[string]string), - DeviceCodesRequestIDs: make(map[string]string), + DeviceCodesRequestIDs: make(map[string]DeviceAuthPair), UserCodesRequestIDs: make(map[string]string), BlacklistedJTIs: make(map[string]time.Time), IssuerPublicKeys: make(map[string]IssuerPublicKeys), @@ -152,11 +153,10 @@ func NewExampleStore() *MemoryStore { AccessTokens: map[string]fosite.Requester{}, RefreshTokens: map[string]StoreRefreshToken{}, PKCES: map[string]fosite.Requester{}, - DeviceCodes: make(map[string]fosite.Requester), - UserCodes: make(map[string]fosite.Requester), + DeviceAuths: make(map[string]fosite.Requester), AccessTokenRequestIDs: map[string]string{}, RefreshTokenRequestIDs: map[string]string{}, - DeviceCodesRequestIDs: make(map[string]string), + DeviceCodesRequestIDs: make(map[string]DeviceAuthPair), UserCodesRequestIDs: make(map[string]string), IssuerPublicKeys: map[string]IssuerPublicKeys{}, PARSessions: map[string]fosite.AuthorizeRequester{}, @@ -515,41 +515,25 @@ func (s *MemoryStore) DeletePARSession(ctx context.Context, requestURI string) ( return nil } -// CreateDeviceCodeSession stores the device code session -func (s *MemoryStore) CreateDeviceCodeSession(_ context.Context, signature string, req fosite.Requester) error { - // We first lock accessTokenRequestIDsMutex and then accessTokensMutex because this is the same order - // locking happens in RevokeAccessToken and using the same order prevents deadlocks. - s.deviceCodesRequestIDsMutex.Lock() - defer s.deviceCodesRequestIDsMutex.Unlock() - s.deviceCodesMutex.Lock() - defer s.deviceCodesMutex.Unlock() +// CreateDeviceAuthSession stores the device auth session +func (s *MemoryStore) CreateDeviceAuthSession(_ context.Context, deviceCodeSignature, userCodeSignature string, req fosite.Requester) error { + s.deviceAuthsRequestIDsMutex.Lock() + defer s.deviceAuthsRequestIDsMutex.Unlock() + s.deviceAuthsMutex.Lock() + defer s.deviceAuthsMutex.Unlock() - s.DeviceCodes[signature] = req - s.DeviceCodesRequestIDs[req.GetID()] = signature - return nil -} - -// UpdateDeviceCodeSession updates the device code session -func (s *MemoryStore) UpdateDeviceCodeSession(_ context.Context, signature string, req fosite.Requester) error { - s.deviceCodesRequestIDsMutex.Lock() - defer s.deviceCodesRequestIDsMutex.Unlock() - s.deviceCodesMutex.Lock() - defer s.deviceCodesMutex.Unlock() - - // Only update if exist - if _, exists := s.DeviceCodes[signature]; exists { - s.DeviceCodes[signature] = req - s.DeviceCodesRequestIDs[req.GetID()] = signature - } + s.DeviceAuths[deviceCodeSignature] = req + s.DeviceAuths[userCodeSignature] = req + s.DeviceCodesRequestIDs[req.GetID()] = DeviceAuthPair{d: deviceCodeSignature, u: userCodeSignature} return nil } // GetDeviceCodeSession gets the device code session func (s *MemoryStore) GetDeviceCodeSession(_ context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { - s.deviceCodesMutex.RLock() - defer s.deviceCodesMutex.RUnlock() + s.deviceAuthsMutex.RLock() + defer s.deviceAuthsMutex.RUnlock() - rel, ok := s.DeviceCodes[signature] + rel, ok := s.DeviceAuths[signature] if !ok { return nil, fosite.ErrNotFound } @@ -558,46 +542,23 @@ func (s *MemoryStore) GetDeviceCodeSession(_ context.Context, signature string, // InvalidateDeviceCodeSession invalidates the device code session func (s *MemoryStore) InvalidateDeviceCodeSession(_ context.Context, code string) error { - s.deviceCodesRequestIDsMutex.Lock() - defer s.deviceCodesRequestIDsMutex.Unlock() - s.deviceCodesMutex.Lock() - defer s.deviceCodesMutex.Unlock() + s.deviceAuthsRequestIDsMutex.Lock() + defer s.deviceAuthsRequestIDsMutex.Unlock() + s.deviceAuthsMutex.Lock() + defer s.deviceAuthsMutex.Unlock() - delete(s.DeviceCodes, code) - return nil -} - -// CreateUserCodeSession stores the user code session -func (s *MemoryStore) CreateUserCodeSession(_ context.Context, signature string, req fosite.Requester) error { - s.userCodesRequestIDsMutex.Lock() - defer s.userCodesRequestIDsMutex.Unlock() - s.userCodesMutex.Lock() - defer s.userCodesMutex.Unlock() - - s.UserCodes[signature] = req - s.UserCodesRequestIDs[req.GetID()] = signature + delete(s.DeviceAuths, code) return nil } // GetUserCodeSession gets the user code session func (s *MemoryStore) GetUserCodeSession(_ context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { - s.userCodesMutex.RLock() - defer s.userCodesMutex.RUnlock() + s.deviceAuthsMutex.RLock() + defer s.deviceAuthsMutex.RUnlock() - rel, ok := s.UserCodes[signature] + rel, ok := s.DeviceAuths[signature] if !ok { return nil, fosite.ErrNotFound } return rel, nil } - -// GetUserCodeSession invalidates the user code session -func (s *MemoryStore) InvalidateUserCodeSession(_ context.Context, code string) error { - s.userCodesRequestIDsMutex.Lock() - defer s.userCodesRequestIDsMutex.Unlock() - s.userCodesMutex.Lock() - defer s.userCodesMutex.Unlock() - - delete(s.UserCodes, code) - return nil -}