diff --git a/auth/cookie.go b/auth/cookie.go index f93679470..113206a54 100644 --- a/auth/cookie.go +++ b/auth/cookie.go @@ -163,10 +163,11 @@ func NewRedirectCookie(ctx context.Context, redirectURL string) *http.Cookie { } } +// GetAuthFlowEndRedirect returns the redirect URI according to data in request. // At the end of the OAuth flow, the server needs to send the user somewhere. This should have been stored as a cookie // during the initial /login call. If that cookie is missing from the request, it will default to the one configured // in this package's Config object. -func getAuthFlowEndRedirect(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request) string { +func GetAuthFlowEndRedirect(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request) string { queryParams := request.URL.Query() // Use the redirect URL specified in the request if one is available. if redirectURL := queryParams.Get(RedirectURLParameter); len(redirectURL) > 0 { diff --git a/auth/cookie_test.go b/auth/cookie_test.go index b525ba8b4..575dffe21 100644 --- a/auth/cookie_test.go +++ b/auth/cookie_test.go @@ -9,11 +9,12 @@ import ( "net/url" "testing" - "github.com/flyteorg/flyteadmin/auth/config" - "github.com/flyteorg/flyteadmin/auth/interfaces/mocks" stdConfig "github.com/flyteorg/flytestdlib/config" "github.com/gorilla/securecookie" "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyteadmin/auth/config" + "github.com/flyteorg/flyteadmin/auth/interfaces/mocks" ) func mustParseURL(t testing.TB, u string) url.URL { @@ -131,7 +132,7 @@ func TestGetAuthFlowEndRedirect(t *testing.T) { assert.NotNil(t, cookie) request.AddCookie(cookie) mockAuthCtx := &mocks.AuthenticationContext{} - redirect := getAuthFlowEndRedirect(ctx, mockAuthCtx, request) + redirect := GetAuthFlowEndRedirect(ctx, mockAuthCtx, request) assert.Equal(t, "/console", redirect) }) @@ -145,7 +146,7 @@ func TestGetAuthFlowEndRedirect(t *testing.T) { RedirectURL: stdConfig.URL{URL: mustParseURL(t, "/api/v1/projects")}, }, }) - redirect := getAuthFlowEndRedirect(ctx, mockAuthCtx, request) + redirect := GetAuthFlowEndRedirect(ctx, mockAuthCtx, request) assert.Equal(t, "/api/v1/projects", redirect) }) } diff --git a/auth/handlers.go b/auth/handlers.go index 0ee2cc776..94a28e341 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -48,6 +48,7 @@ func (e *PreRedirectHookError) Error() string { // PreRedirectHookError is the error interface which allows the user to set correct http status code and Message to be set in case the function returns an error // without which the current usage in GetCallbackHandler will set this to InternalServerError type PreRedirectHookFunc func(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, w http.ResponseWriter) *PreRedirectHookError +type LogoutHookFunc func(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, w http.ResponseWriter) error type HTTPRequestToMetadataAnnotator func(ctx context.Context, request *http.Request) metadata.MD type UserInfoForwardResponseHandler func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error @@ -68,7 +69,7 @@ func RegisterHandlers(ctx context.Context, handler interfaces.HandlerRegisterer, handler.HandleFunc(fmt.Sprintf("/%s", OIdCMetadataEndpoint), GetOIdCMetadataEndpointRedirectHandler(ctx, authCtx)) // These endpoints require authentication - handler.HandleFunc("/logout", GetLogoutEndpointHandler(ctx, authCtx)) + handler.HandleFunc("/logout", GetLogoutEndpointHandler(ctx, authCtx, pluginRegistry)) } // Look for access token and refresh token, if both are present and the access token is expired, then attempt to @@ -123,7 +124,7 @@ func RefreshTokensIfExists(ctx context.Context, authCtx interfaces.Authenticatio return } - redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request) + redirectURL := GetAuthFlowEndRedirect(ctx, authCtx, request) http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) } } @@ -210,7 +211,7 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo } logger.Info(ctx, "Successfully called the preRedirect hook") } - redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request) + redirectURL := GetAuthFlowEndRedirect(ctx, authCtx, request) http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) } } @@ -466,9 +467,19 @@ func GetOIdCMetadataEndpointRedirectHandler(ctx context.Context, authCtx interfa } } -func GetLogoutEndpointHandler(ctx context.Context, authCtx interfaces.AuthenticationContext) http.HandlerFunc { +func GetLogoutEndpointHandler(ctx context.Context, authCtx interfaces.AuthenticationContext, pluginRegistry *plugins.Registry) http.HandlerFunc { return func(writer http.ResponseWriter, request *http.Request) { - logger.Debugf(ctx, "Deleting auth cookies") + hook := plugins.Get[LogoutHookFunc](pluginRegistry, plugins.PluginIDLogoutHook) + if hook != nil { + if err := hook(ctx, authCtx, request, writer); err != nil { + logger.Errorf(ctx, "logout hook failed: %v", err) + writer.WriteHeader(http.StatusInternalServerError) + return + } + logger.Debugf(ctx, "logout hook called") + } + + logger.Debugf(ctx, "deleting auth cookies") authCtx.CookieManager().DeleteCookies(ctx, writer) // Redirect if one was given diff --git a/auth/handlers_test.go b/auth/handlers_test.go index 449b13c4a..d6cfac743 100644 --- a/auth/handlers_test.go +++ b/auth/handlers_test.go @@ -2,6 +2,7 @@ package auth import ( "context" + "errors" "fmt" "io" "net/http" @@ -11,8 +12,11 @@ import ( "testing" "github.com/coreos/go-oidc" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + stdConfig "github.com/flyteorg/flytestdlib/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "golang.org/x/oauth2" "google.golang.org/protobuf/types/known/structpb" @@ -21,8 +25,6 @@ import ( "github.com/flyteorg/flyteadmin/auth/interfaces/mocks" "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/plugins" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" - stdConfig "github.com/flyteorg/flytestdlib/config" ) const ( @@ -50,8 +52,8 @@ func setupMockedAuthContextAtEndpoint(endpoint string) *mocks.AuthenticationCont Timeout: IdpConnectionTimeout, } mockAuthCtx.OnCookieManagerMatch().Return(mockCookieHandler) - mockCookieHandler.OnSetTokenCookiesMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - mockCookieHandler.OnSetUserInfoCookieMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockCookieHandler.OnSetTokenCookiesMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + mockCookieHandler.OnSetUserInfoCookieMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() mockAuthCtx.OnOAuth2ClientConfigMatch(mock.Anything).Return(&dummyOAuth2Config) mockAuthCtx.OnGetHTTPClient().Return(dummyHTTPClient) return mockAuthCtx @@ -255,6 +257,97 @@ func TestGetLoginHandler(t *testing.T) { assert.True(t, strings.Contains(w.Header().Get("Set-Cookie"), "flyte_csrf_state=")) } +func TestGetLogoutHandler(t *testing.T) { + ctx := context.Background() + + t.Run("no_hook_no_redirect", func(t *testing.T) { + cookieHandler := &CookieManager{} + authCtx := mocks.AuthenticationContext{} + authCtx.OnCookieManager().Return(cookieHandler).Once() + w := httptest.NewRecorder() + r := plugins.NewRegistry() + req, err := http.NewRequest(http.MethodGet, "/logout", nil) + require.NoError(t, err) + + GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + require.Len(t, w.Result().Cookies(), 3) + authCtx.AssertExpectations(t) + }) + + t.Run("no_hook_with_redirect", func(t *testing.T) { + ctx := context.Background() + cookieHandler := &CookieManager{} + authCtx := mocks.AuthenticationContext{} + authCtx.OnCookieManager().Return(cookieHandler).Once() + w := httptest.NewRecorder() + r := plugins.NewRegistry() + req, err := http.NewRequest(http.MethodGet, "/logout?redirect_url=/foo", nil) + require.NoError(t, err) + + GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req) + + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + authCtx.AssertExpectations(t) + require.Len(t, w.Result().Cookies(), 3) + }) + + t.Run("with_hook_with_redirect", func(t *testing.T) { + ctx := context.Background() + cookieHandler := &CookieManager{} + authCtx := mocks.AuthenticationContext{} + authCtx.OnCookieManager().Return(cookieHandler).Once() + w := httptest.NewRecorder() + r := plugins.NewRegistry() + hook := new(mock.Mock) + err := r.Register(plugins.PluginIDLogoutHook, LogoutHookFunc(func( + ctx context.Context, + authCtx interfaces.AuthenticationContext, + request *http.Request, + w http.ResponseWriter) error { + return hook.MethodCalled("hook").Error(0) + })) + hook.On("hook").Return(nil).Once() + require.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, "/logout?redirect_url=/foo", nil) + require.NoError(t, err) + + GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req) + + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + require.Len(t, w.Result().Cookies(), 3) + authCtx.AssertExpectations(t) + hook.AssertExpectations(t) + }) + + t.Run("hook_error", func(t *testing.T) { + ctx := context.Background() + authCtx := mocks.AuthenticationContext{} + w := httptest.NewRecorder() + r := plugins.NewRegistry() + hook := new(mock.Mock) + err := r.Register(plugins.PluginIDLogoutHook, LogoutHookFunc(func( + ctx context.Context, + authCtx interfaces.AuthenticationContext, + request *http.Request, + w http.ResponseWriter) error { + return hook.MethodCalled("hook").Error(0) + })) + hook.On("hook").Return(errors.New("fail")).Once() + require.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, "/logout?redirect_url=/foo", nil) + require.NoError(t, err) + + GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Empty(t, w.Result().Cookies()) + authCtx.AssertExpectations(t) + hook.AssertExpectations(t) + }) +} + func TestGetHTTPRequestCookieToMetadataHandler(t *testing.T) { ctx := context.Background() // These were generated for unit testing only. diff --git a/plugins/registry.go b/plugins/registry.go index 14682f7e8..90389008b 100644 --- a/plugins/registry.go +++ b/plugins/registry.go @@ -13,6 +13,7 @@ const ( PluginIDDataProxy PluginID = "DataProxy" PluginIDUnaryServiceMiddleware PluginID = "UnaryServiceMiddleware" PluginIDPreRedirectHook PluginID = "PreRedirectHook" + PluginIDLogoutHook PluginID = "LogoutHook" ) type AtomicRegistry struct { diff --git a/plugins/registry_test.go b/plugins/registry_test.go index 0737c1281..15b0cb93b 100644 --- a/plugins/registry_test.go +++ b/plugins/registry_test.go @@ -41,6 +41,26 @@ func TestRedirectHook(t *testing.T) { assert.Equal(t, fmt.Errorf("redirect hook error"), err) } +type LogoutHook func(context.Context) error + +func TestLogoutHook(t *testing.T) { + ar := NewAtomicRegistry(nil) + r := NewRegistry() + + hook := LogoutHook(func(ctx context.Context) error { + return fmt.Errorf("redirect hook error") + }) + err := r.Register(PluginIDLogoutHook, hook) + assert.NoError(t, err) + + ar.Store(r) + r = ar.Load() + fn := Get[LogoutHook](r, PluginIDLogoutHook) + err = fn(context.Background()) + + assert.Equal(t, fmt.Errorf("redirect hook error"), err) +} + func TestRegistry_RegisterDefault(t *testing.T) { r := NewRegistry() r.RegisterDefault("hello", 5)