Skip to content

Commit

Permalink
WIP: refactor Transactional interface
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr authored and aeneasr committed May 7, 2024
1 parent 5e039ca commit 199f98c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 127 deletions.
34 changes: 12 additions & 22 deletions handler/oauth2/flow_authorize_code_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,26 +152,21 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex
}
}

ctx, err = storage.MaybeBeginTx(ctx, c.CoreStorage)
if err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
defer func() {
if err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil {
err = errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr))
}
if err := storage.MaybeTransaction(ctx, c.CoreStorage, func(ctx context.Context) error {
if err := c.CoreStorage.InvalidateAuthorizeCodeSession(ctx, signature); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
}()

if err = c.CoreStorage.InvalidateAuthorizeCodeSession(ctx, signature); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
} else if err = c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
} else if refreshSignature != "" {
if err = c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil {
if err := c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
if refreshSignature != "" {
if err := c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
}
return nil
}); err != nil {
return err // error already wrapped inside tx callback
}

responder.SetAccessToken(access)
Expand All @@ -182,11 +177,6 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex
if refresh != "" {
responder.SetExtra("refresh_token", refresh)
}

if err = storage.MaybeCommitTx(ctx, c.CoreStorage); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

return nil
}

Expand Down
93 changes: 33 additions & 60 deletions handler/oauth2/flow_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,34 +126,29 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con

signature := c.RefreshTokenStrategy.RefreshTokenSignature(ctx, requester.GetRequestForm().Get("refresh_token"))

ctx, err = storage.MaybeBeginTx(ctx, c.TokenRevocationStorage)
if err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
defer func() {
err = c.handleRefreshTokenEndpointStorageError(ctx, err)
}()

ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
if err != nil {
return err
} else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil {
return err
}

if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, ts.GetID(), signature); err != nil {
return err
}

storeReq := requester.Sanitize([]string{})
storeReq.SetID(ts.GetID())

if err = c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil {
return err
}

if err = c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, storeReq); err != nil {
return err
if err := storage.MaybeTransaction(ctx, c.TokenRevocationStorage, func(ctx context.Context) error {
ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
if err != nil {
return err
}
if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil {
return err
}
if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, ts.GetID(), signature); err != nil {
return err
}
storeReq := requester.Sanitize([]string{})
storeReq.SetID(ts.GetID())
if err := c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil {
return err
}
return c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, storeReq)
}); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

responder.SetAccessToken(accessToken)
Expand All @@ -163,10 +158,6 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con
responder.SetScopes(requester.GetGrantedScopes())
responder.SetExtra("refresh_token", refreshToken)

if err = storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil {
return err
}

return nil
}

Expand All @@ -179,45 +170,27 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con
// legitimate client is trying to access, in case of such an access
// attempt the valid refresh token and the access authorization
// associated with it are both revoked.
func (c *RefreshTokenGrantHandler) handleRefreshTokenReuse(ctx context.Context, signature string, req fosite.Requester) (err error) {
ctx, err = storage.MaybeBeginTx(ctx, c.TokenRevocationStorage)
if err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
defer func() {
err = c.handleRefreshTokenEndpointStorageError(ctx, err)
}()

if err = c.TokenRevocationStorage.DeleteRefreshTokenSession(ctx, signature); err != nil {
return err
} else if err = c.TokenRevocationStorage.RevokeRefreshToken(
ctx, req.GetID(),
); err != nil && !errors.Is(err, fosite.ErrNotFound) {
return err
} else if err = c.TokenRevocationStorage.RevokeAccessToken(
ctx, req.GetID(),
); err != nil && !errors.Is(err, fosite.ErrNotFound) {
return err
}

if err = storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil {
return err
}

return nil
func (c *RefreshTokenGrantHandler) handleRefreshTokenReuse(ctx context.Context, signature string, req fosite.Requester) error {
err := storage.MaybeTransaction(ctx, c.TokenRevocationStorage, func(ctx context.Context) error {
if err := c.TokenRevocationStorage.DeleteRefreshTokenSession(ctx, signature); err != nil {
return err
}
if err := c.TokenRevocationStorage.RevokeRefreshToken(ctx, req.GetID()); err != nil && !errors.Is(err, fosite.ErrNotFound) {
return err
}
if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, req.GetID()); err != nil && !errors.Is(err, fosite.ErrNotFound) {
return err
}
return nil
})
return c.handleRefreshTokenEndpointStorageError(ctx, err)
}

func (c *RefreshTokenGrantHandler) handleRefreshTokenEndpointStorageError(ctx context.Context, storageErr error) (err error) {
if storageErr == nil {
return nil
}

defer func() {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.TokenRevocationStorage); rollBackTxnErr != nil {
err = errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr))
}
}()

if errors.Is(storageErr, fosite.ErrSerializationFailure) {
return errorsx.WithStack(fosite.ErrInvalidRequest.
WithDebugf(storageErr.Error()).
Expand Down
58 changes: 13 additions & 45 deletions storage/transactional.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,22 @@ package storage

import "context"

// A storage provider that has support for transactions should implement this interface to ensure atomicity for certain flows
// that require transactional semantics. Fosite will call these methods (when atomicity is required) if and only if the storage
// provider has implemented `Transactional`. It is expected that the storage provider will examine context for an existing transaction
// each time a database operation is to be performed.
// A storage provider that has support for transactions should implement this
// interface to ensure atomicity for certain flows that require transactional
// semantics. When atomicity is required, Fosite will group calls to the storage
// provider in a function and passes that to Transaction. Implementations are
// expected to execute these calls in a transactional manner. Typically, a
// handle to the transaction will be stored in the context.
//
// An implementation of `BeginTX` should attempt to initiate a new transaction and store that under a unique key
// in the context that can be accessible by `Commit` and `Rollback`. The "transactional aware" context will then be
// returned for further propagation, eventually to be consumed by `Commit` or `Rollback` to finish the transaction.
//
// Implementations for `Commit` & `Rollback` should look for the transaction object inside the supplied context using the same
// key used by `BeginTX`. If these methods have been called, it is expected that a txn object should be available in the provided
// context.
// Implementations should rollback (or retry) the transaction if the callback
// returns an error.
type Transactional interface {
BeginTX(ctx context.Context) (context.Context, error)
Commit(ctx context.Context) error
Rollback(ctx context.Context) error
}

// MaybeBeginTx is a helper function that can be used to initiate a transaction if the supplied storage
// implements the `Transactional` interface.
func MaybeBeginTx(ctx context.Context, storage interface{}) (context.Context, error) {
// the type assertion checks whether the dynamic type of `storage` implements `Transactional`
txnStorage, transactional := storage.(Transactional)
if transactional {
return txnStorage.BeginTX(ctx)
} else {
return ctx, nil
}
}

// MaybeCommitTx is a helper function that can be used to commit a transaction if the supplied storage
// implements the `Transactional` interface.
func MaybeCommitTx(ctx context.Context, storage interface{}) error {
txnStorage, transactional := storage.(Transactional)
if transactional {
return txnStorage.Commit(ctx)
} else {
return nil
}
Transaction(context.Context, func(context.Context) error) error
}

// MaybeRollbackTx is a helper function that can be used to rollback a transaction if the supplied storage
// implements the `Transactional` interface.
func MaybeRollbackTx(ctx context.Context, storage interface{}) error {
txnStorage, transactional := storage.(Transactional)
if transactional {
return txnStorage.Rollback(ctx)
} else {
return nil
func MaybeTransaction(ctx context.Context, storage any, f func(context.Context) error) error {
if tx, ok := storage.(Transactional); ok {
return tx.Transaction(ctx, f)
}
return f(ctx)
}

0 comments on commit 199f98c

Please sign in to comment.