From 3a31488af8fc5240f2cd01a4d677564f8b2c3173 Mon Sep 17 00:00:00 2001 From: Eder Ignatowicz Date: Thu, 19 Dec 2024 09:53:27 -0500 Subject: [PATCH] feat(bff): authorize endpoints based on kubeflow-userid and kubeflow-groups header Signed-off-by: Eder Ignatowicz --- clients/ui/bff/README.md | 85 ++++++++--- clients/ui/bff/internal/api/app.go | 47 +++--- clients/ui/bff/internal/api/middleware.go | 135 ++++++++++++++---- .../internal/api/model_registry_handler.go | 8 +- .../api/model_registry_handler_test.go | 12 +- .../internal/api/model_versions_handler.go | 10 +- .../api/model_versions_handler_test.go | 22 +-- .../ui/bff/internal/api/namespaces_handler.go | 9 +- .../internal/api/registered_models_handler.go | 12 +- .../api/registered_models_handler_test.go | 26 ++-- clients/ui/bff/internal/api/test_utils.go | 11 +- clients/ui/bff/internal/integrations/http.go | 15 +- clients/ui/bff/internal/integrations/k8s.go | 83 ++++++++--- clients/ui/bff/internal/mocks/http_mock.go | 4 + clients/ui/bff/internal/mocks/k8s_mock.go | 133 ++++++++++++++++- .../ui/bff/internal/mocks/k8s_mock_test.go | 80 ++++++++++- .../internal/repositories/model_registry.go | 4 +- .../repositories/model_registry_test.go | 30 +++- .../ui/bff/internal/repositories/namespace.go | 4 +- 19 files changed, 574 insertions(+), 156 deletions(-) diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index f700a52a..0befdc45 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -58,7 +58,7 @@ make docker-build |----------------------------------------------------------------------------------------------|----------------------------------------------|-------------------------------------------------------------| | GET /v1/healthcheck | HealthcheckHandler | Show application information. | | GET /v1/user | UserHandler | Show "kubeflow-user-id" from header information. | -| GET /v1/namespaces | NamespacesHandler | Get all user namespaces. | +| GET /v1/namespaces | NamespacesHandler | Get all user namespaces. (only enabled in devmode) | | GET /v1/model_registry | ModelRegistryHandler | Get all model registries, | | GET /v1/model_registry/{model_registry_id}/registered_models | GetAllRegisteredModelsHandler | Gets a list of all RegisteredModel entities. | | POST /v1/model_registry/{model_registry_id}/registered_models | CreateRegisteredModelHandler | Create a RegisteredModel entity. | @@ -72,32 +72,51 @@ make docker-build | GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | GetAllModelArtifactsByModelVersionHandler | Get all ModelArtifact entities by ModelVersion ID | | POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | CreateModelArtifactByModelVersion | Create a ModelArtifact entity for a specific ModelVersion | +Note: Most API paths require the namespace parameter to be passed as a query parameter. +The only exceptions are the health check (/v1/healthcheck) and user (/v1/user) paths, which do not require the namespace parameter. + ### Sample local calls -You will need to inject your requests with a kubeflow-userid header for authorization purposes. When running the service with the mocked Kubernetes client (MOCK_K8S_CLIENT=true), the user user@example.com is preconfigured with the necessary RBAC permissions to perform these actions. +You will need to inject your requests with a `kubeflow-userid` header and namespace for authorization purposes. + +When running the service with the mocked Kubernetes client (MOCK_K8S_CLIENT=true), the user `user@example.com` is preconfigured with the necessary RBAC permissions to perform these actions. ``` # GET /v1/healthcheck -curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/healthcheck +curl -i -H "kubeflow-userid: user@example.com" "localhost:4000/api/v1/healthcheck" ``` ``` # GET /v1/user -curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/user +curl -i -H "kubeflow-userid: user@example.com" "localhost:4000/api/v1/user" ``` ``` -# GET /v1/namespaces -curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/namespaces +# GET /v1/namespaces (only works when DEV_MODE=true) +curl -i -H "kubeflow-userid: user@example.com" "localhost:4000/api/v1/namespaces" ``` ``` # GET /v1/model_registry -curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/model_registry +curl -i -H "kubeflow-userid: user@example.com" "localhost:4000/api/v1/model_registry?namespace=kubeflow" +``` +``` +# GET /v1/model_registry using groups permissions +curl -i \ + -H "kubeflow-userid: non-user@example.com" \ + -H "kubeflow-groups: dora-namespace-group ,group2,group3" \ + "http://localhost:4000/api/v1/model_registry?namespace=dora-namespace" ``` ``` # GET /v1/model_registry/{model_registry_id}/registered_models -curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/model_registry/model-registry/registered_models +curl -i -H "kubeflow-userid: user@example.com" "localhost:4000/api/v1/model_registry/model-registry/registered_models?namespace=kubeflow" +``` +``` +# GET /v1/model_registry/{model_registry_id}/registered_models using group permissions +curl -i \ + -H "kubeflow-userid: non-user@example.com" \ + -H "kubeflow-groups: dora-namespace-group ,dora-service-group,group3" \ + "http://localhost:4000/api/v1/model_registry/model-registry-dora/registered_models?namespace=dora-namespace" ``` ``` #POST /v1/model_registry/{model_registry_id}/registered_models -curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models" \ +curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?namespace=kubeflow" \ -H "Content-Type: application/json" \ -d '{ "data": { "customProperties": { @@ -115,11 +134,11 @@ curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/ap ``` ``` # GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} -curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/model_registry/model-registry/registered_models/1 +curl -i -H "kubeflow-userid: user@example.com" "localhost:4000/api/v1/model_registry/model-registry/registered_models/1?namespace=kubeflow" ``` ``` # PATCH /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} -curl -i -H "kubeflow-userid: user@example.com" -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1" \ +curl -i -H "kubeflow-userid: user@example.com" -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1?namespace=kubeflow" \ -H "Content-Type: application/json" \ -d '{ "data": { "description": "New description" @@ -127,11 +146,11 @@ curl -i -H "kubeflow-userid: user@example.com" -X PATCH "http://localhost:4000/a ``` ``` # GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} -curl -i -H "kubeflow-userid: user@example.com" http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1 +curl -i -H "kubeflow-userid: user@example.com" "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1?namespace=kubeflow" ``` ``` # POST /api/v1/model_registry/{model_registry_id}/model_versions -curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions" \ +curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions?namespace=kubeflow" \ -H "Content-Type: application/json" \ -d '{ "data": { "customProperties": { @@ -150,7 +169,7 @@ curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/ap ``` ``` # PATCH /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} -curl -i -H "kubeflow-userid: user@example.com" -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1" \ +curl -i -H "kubeflow-userid: user@example.com" -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1?namespace=kubeflow" \ -H "Content-Type: application/json" \ -d '{ "data": { "description": "New description 2" @@ -158,11 +177,11 @@ curl -i -H "kubeflow-userid: user@example.com" -X PATCH "http://localhost:4000/a ``` ``` # GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions -curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions +curl -i -H "kubeflow-userid: user@example.com" "localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions?namespace=kubeflow" ``` ``` # POST /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions -curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions" \ +curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions?namespace=kubeflow" \ -H "Content-Type: application/json" \ -d '{ "data": { "customProperties": { @@ -176,16 +195,16 @@ curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/ap "name": "ModelVersion One", "state": "LIVE", "author": "alex", - "registeredModelId: "1" + "registeredModelId": "1" }}' ``` ``` -# GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts -curl -i -H "kubeflow-userid: user@example.com" http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts +# GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts +curl -i -H "kubeflow-userid: user@example.com" "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts?namespace=kubeflow" ``` ``` # POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts -curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts" \ +curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts?namespace=kubeflow" \ -H "Content-Type: application/json" \ -d '{ "data": { "customProperties": { @@ -246,13 +265,35 @@ The mock Kubernetes environment is activated when the environment variable `MOCK - **Namespaces**: - `kubeflow` - `dora-namespace` + - `bella-namespace` - **Users**: - `user@example.com` (has `cluster-admin` privileges) - `doraNonAdmin@example.com` (restricted to the `dora-namespace`) + - `bellaNonAdmin@example.com` (restricted to the `bella-namespace`) +- **Groups**: + - `dora-service-group` (has access to `model-registry-dora` inside `dora-namespace`) + - `dora-namespace-group` (has access to the `dora-namespace`) + - **Services (Model Registries)**: - `model-registry`: resides in the `kubeflow` namespace with the label `component: model-registry`. - - `model-registry-dora`: resides in the `dora-namespace` namespace with the label `component: model-registry`. - - `model-registry-bella`: resides in the `kubeflow` namespace with the label `component: model-registry`. + - `model-registry-one`: resides in the `kubeflow` namespace with the label `component: model-registry`. - `non-model-registry`: resides in the `kubeflow` namespace *without* the label `component: model-registry`. + - `model-registry-dora`: resides in the `dora-namespace` namespace with the label `component: model-registry`. + +#### 3. How BFF authorization works for kubeflow-userid and kubeflow-groups? + +Authorization is performed using Kubernetes SubjectAccessReview (SAR), which validates user access to resources. + +- `kubeflow-userid`: Required header that specifies the user’s email. Access is checked directly for the user via SAR. +- `kubeflow-groups`: Optional header with a comma-separated list of groups. If the user does not have access, SAR checks group permissions using OR logic. If any group has access, the request is authorized. + + +Access to Model Registry List: +- To list all model registries (/v1/model_registry), we perform a SAR check for get and list verbs on services within the specified namespace. +- If the user or any group has permission to get and list services in the namespace, the request is authorized. + +Access to Specific Model Registry Endpoints: +- For other endpoints (e.g., /v1/model_registry/{model_registry_id}/...), we perform a SAR check for get and list verbs on the specific service (identified by model_registry_id) within the namespace. +- If the user or any group has permission to get or list the specific service, the request is authorized. diff --git a/clients/ui/bff/internal/api/app.go b/clients/ui/bff/internal/api/app.go index 6f2d985f..52425e67 100644 --- a/clients/ui/bff/internal/api/app.go +++ b/clients/ui/bff/internal/api/app.go @@ -14,7 +14,8 @@ import ( ) const ( - Version = "1.0.0" + Version = "1.0.0" + PathPrefix = "/api/v1" ModelRegistryId = "model_registry_id" RegisteredModelId = "registered_model_id" @@ -89,33 +90,29 @@ func (app *App) Routes() http.Handler { router.NotFound = http.HandlerFunc(app.notFoundResponse) router.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse) - // HTTP client routes + // HTTP client routes (requests that we forward to Model Registry API) + // on those, we perform SAR on Specific Service on a given namespace router.GET(HealthCheckPath, app.HealthcheckHandler) - router.GET(RegisteredModelListPath, app.AttachRESTClient(app.GetAllRegisteredModelsHandler)) - router.GET(RegisteredModelPath, app.AttachRESTClient(app.GetRegisteredModelHandler)) - router.POST(RegisteredModelListPath, app.AttachRESTClient(app.CreateRegisteredModelHandler)) - router.PATCH(RegisteredModelPath, app.AttachRESTClient(app.UpdateRegisteredModelHandler)) - router.GET(RegisteredModelVersionsPath, app.AttachRESTClient(app.GetAllModelVersionsForRegisteredModelHandler)) - router.POST(RegisteredModelVersionsPath, app.AttachRESTClient(app.CreateModelVersionForRegisteredModelHandler)) - router.GET(ModelVersionPath, app.AttachRESTClient(app.GetModelVersionHandler)) - router.POST(ModelVersionListPath, app.AttachRESTClient(app.CreateModelVersionHandler)) - router.PATCH(ModelVersionPath, app.AttachRESTClient(app.UpdateModelVersionHandler)) - router.GET(ModelVersionArtifactListPath, app.AttachRESTClient(app.GetAllModelArtifactsByModelVersionHandler)) - router.POST(ModelVersionArtifactListPath, app.AttachRESTClient(app.CreateModelArtifactByModelVersionHandler)) - router.PATCH(ModelRegistryPath, app.AttachRESTClient(app.UpdateModelVersionHandler)) - - // Kubernetes client routes + router.GET(RegisteredModelListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllRegisteredModelsHandler)))) + router.GET(RegisteredModelPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetRegisteredModelHandler)))) + router.POST(RegisteredModelListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateRegisteredModelHandler)))) + router.PATCH(RegisteredModelPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.UpdateRegisteredModelHandler)))) + router.GET(RegisteredModelVersionsPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllModelVersionsForRegisteredModelHandler)))) + router.POST(RegisteredModelVersionsPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelVersionForRegisteredModelHandler)))) + router.GET(ModelVersionPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient((app.GetModelVersionHandler))))) + router.POST(ModelVersionListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelVersionHandler)))) + router.PATCH(ModelVersionPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.UpdateModelVersionHandler)))) + router.GET(ModelVersionArtifactListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllModelArtifactsByModelVersionHandler)))) + router.POST(ModelVersionArtifactListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelArtifactByModelVersionHandler)))) + router.PATCH(ModelRegistryPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.UpdateModelVersionHandler)))) + + // Kubernetes routes router.GET(UserPath, app.UserHandler) - router.GET(ModelRegistryListPath, app.ModelRegistryHandler) + // Perform SAR to Get List Services by Namspace + router.GET(ModelRegistryListPath, app.AttachNamespace(app.PerformSARonGetListServicesByNamespace(app.ModelRegistryHandler))) if app.config.DevMode { - router.GET(NamespaceListPath, app.GetNamespacesHandler) - } - - accessControlExemptPaths := map[string]struct{}{ - HealthCheckPath: {}, - UserPath: {}, - NamespaceListPath: {}, + router.GET(NamespaceListPath, app.AttachNamespace(app.GetNamespacesHandler)) } - return app.RecoverPanic(app.enableCORS(app.RequireAccessControl(app.InjectUserHeaders(router), accessControlExemptPaths))) + return app.RecoverPanic(app.enableCORS(app.InjectUserHeaders(router))) } diff --git a/clients/ui/bff/internal/api/middleware.go b/clients/ui/bff/internal/api/middleware.go index 2b826736..00a41b66 100644 --- a/clients/ui/bff/internal/api/middleware.go +++ b/clients/ui/bff/internal/api/middleware.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "strings" "github.com/julienschmidt/httprouter" "github.com/kubeflow/model-registry/ui/bff/internal/config" @@ -14,11 +15,12 @@ import ( type contextKey string const ( - httpClientKey contextKey = "httpClientKey" + ModelRegistryHttpClientKey contextKey = "ModelRegistryHttpClientKey" + NamespaceHeaderParameterKey contextKey = "namespace" //Kubeflow authorization operates using custom authentication headers: // Note: The functionality for `kubeflow-groups` is not fully operational at Kubeflow platform at this time - // But it will be soon implemented on Model Registry BFF + // but it's supported on Model Registry BFF KubeflowUserIdKey contextKey = "kubeflowUserId" // kubeflow-userid :contains the user's email address KubeflowUserIDHeader = "kubeflow-userid" KubeflowUserGroupsKey contextKey = "kubeflowUserGroups" // kubeflow-groups : Holds a comma-separated list of user groups @@ -41,17 +43,28 @@ func (app *App) RecoverPanic(next http.Handler) http.Handler { func (app *App) InjectUserHeaders(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - userId := r.Header.Get(KubeflowUserIDHeader) - userGroups := r.Header.Get(KubeflowUserGroupsIdHeader) - - //Note: The functionality for `kubeflow-groups` is not fully operational at Kubeflow platform at this time - if userId == "" { + userIdHeader := r.Header.Get(KubeflowUserIDHeader) + userGroupsHeader := r.Header.Get(KubeflowUserGroupsIdHeader) + //`kubeflow-userid`: Contains the user's email address. + if userIdHeader == "" { app.badRequestResponse(w, r, errors.New("missing required header: kubeflow-userid")) return } + // Note: The functionality for `kubeflow-groups` is not fully operational at Kubeflow platform at this time + // but it's supported on Model Registry BFF + //`kubeflow-groups`: Holds a comma-separated list of user groups. + var userGroups []string + if userGroupsHeader != "" { + userGroups = strings.Split(userGroupsHeader, ",") + // Trim spaces from each group name + for i, group := range userGroups { + userGroups[i] = strings.TrimSpace(group) + } + } + ctx := r.Context() - ctx = context.WithValue(ctx, KubeflowUserIdKey, userId) + ctx = context.WithValue(ctx, KubeflowUserIdKey, userIdHeader) ctx = context.WithValue(ctx, KubeflowUserGroupsKey, userGroups) next.ServeHTTP(w, r.WithContext(ctx)) @@ -68,29 +81,35 @@ func (app *App) enableCORS(next http.Handler) http.Handler { }) } -func (app *App) AttachRESTClient(handler func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { +func (app *App) AttachRESTClient(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { modelRegistryID := ps.ByName(ModelRegistryId) - modelRegistryBaseURL, err := resolveModelRegistryURL(modelRegistryID, app.kubernetesClient, app.config) + namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + if !ok || namespace == "" { + app.badRequestResponse(w, r, fmt.Errorf("missing namespace in the context")) + } + + modelRegistryBaseURL, err := resolveModelRegistryURL(namespace, modelRegistryID, app.kubernetesClient, app.config) if err != nil { - app.serverErrorResponse(w, r, fmt.Errorf("failed to resolve model registry base URL): %v", err)) + app.notFoundResponse(w, r) return } - client, err := integrations.NewHTTPClient(modelRegistryBaseURL) + client, err := integrations.NewHTTPClient(modelRegistryID, modelRegistryBaseURL) if err != nil { app.serverErrorResponse(w, r, fmt.Errorf("failed to create Kubernetes client: %v", err)) return } - ctx := context.WithValue(r.Context(), httpClientKey, client) - handler(w, r.WithContext(ctx), ps) + ctx := context.WithValue(r.Context(), ModelRegistryHttpClientKey, client) + next(w, r.WithContext(ctx), ps) } } -func resolveModelRegistryURL(id string, client integrations.KubernetesClientInterface, config config.EnvConfig) (string, error) { - serviceDetails, err := client.GetServiceDetailsByName(id) +func resolveModelRegistryURL(namespace string, serviceName string, client integrations.KubernetesClientInterface, config config.EnvConfig) (string, error) { + + serviceDetails, err := client.GetServiceDetailsByName(namespace, serviceName) if err != nil { return "", err } @@ -104,22 +123,84 @@ func resolveModelRegistryURL(id string, client integrations.KubernetesClientInte return url, nil } -func (app *App) RequireAccessControl(next http.Handler, exemptPaths map[string]struct{}) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func (app *App) AttachNamespace(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + namespace := r.URL.Query().Get(string(NamespaceHeaderParameterKey)) + if namespace == "" { + app.badRequestResponse(w, r, fmt.Errorf("missing required query parameter: %s", NamespaceHeaderParameterKey)) + return + } + + ctx := context.WithValue(r.Context(), NamespaceHeaderParameterKey, namespace) + r = r.WithContext(ctx) - // Skip SAR for exempt paths - if _, exempt := exemptPaths[r.URL.Path]; exempt { - next.ServeHTTP(w, r) + next(w, r, ps) + } +} + +func (app *App) PerformSARonGetListServicesByNamespace(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + user, ok := r.Context().Value(KubeflowUserIdKey).(string) + if !ok || user == "" { + app.badRequestResponse(w, r, fmt.Errorf("missing user in context")) + return + } + namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + if !ok || namespace == "" { + app.badRequestResponse(w, r, fmt.Errorf("missing namespace in context")) return } - user := r.Header.Get(KubeflowUserIDHeader) - if user == "" { - app.forbiddenResponse(w, r, "missing kubeflow-userid header") + var userGroups []string + if groups, ok := r.Context().Value(KubeflowUserGroupsKey).([]string); ok { + userGroups = groups + } else { + userGroups = []string{} + } + + allowed, err := app.kubernetesClient.PerformSARonGetListServicesByNamespace(user, userGroups, namespace) + if err != nil { + app.forbiddenResponse(w, r, fmt.Sprintf("failed to perform SAR: %v", err)) + return + } + if !allowed { + app.forbiddenResponse(w, r, "access denied") + return + } + + next(w, r, ps) + } +} + +func (app *App) PerformSARonSpecificService(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + + user, ok := r.Context().Value(KubeflowUserIdKey).(string) + if !ok || user == "" { + app.badRequestResponse(w, r, fmt.Errorf("missing user in context")) + return + } + + namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + if !ok || namespace == "" { + app.badRequestResponse(w, r, fmt.Errorf("missing namespace in context")) + return + } + + modelRegistryID := ps.ByName(ModelRegistryId) + if !ok || modelRegistryID == "" { + app.badRequestResponse(w, r, fmt.Errorf("missing namespace in context")) return } - allowed, err := app.kubernetesClient.PerformSAR(user) + var userGroups []string + if groups, ok := r.Context().Value(KubeflowUserGroupsKey).([]string); ok { + userGroups = groups + } else { + userGroups = []string{} + } + + allowed, err := app.kubernetesClient.PerformSARonSpecificService(user, userGroups, namespace, modelRegistryID) if err != nil { app.forbiddenResponse(w, r, "failed to perform SAR: %v") return @@ -129,6 +210,6 @@ func (app *App) RequireAccessControl(next http.Handler, exemptPaths map[string]s return } - next.ServeHTTP(w, r) - }) + next(w, r, ps) + } } diff --git a/clients/ui/bff/internal/api/model_registry_handler.go b/clients/ui/bff/internal/api/model_registry_handler.go index 8412d8f8..1600d004 100644 --- a/clients/ui/bff/internal/api/model_registry_handler.go +++ b/clients/ui/bff/internal/api/model_registry_handler.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "github.com/julienschmidt/httprouter" "github.com/kubeflow/model-registry/ui/bff/internal/models" "net/http" @@ -10,7 +11,12 @@ type ModelRegistryListEnvelope Envelope[[]models.ModelRegistryModel, None] func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - registries, err := app.repositories.ModelRegistry.FetchAllModelRegistries(app.kubernetesClient) + namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + if !ok || namespace == "" { + app.badRequestResponse(w, r, fmt.Errorf("missing namespace in the context")) + } + + registries, err := app.repositories.ModelRegistry.GetAllModelRegistries(app.kubernetesClient, namespace) if err != nil { app.serverErrorResponse(w, r, err) return diff --git a/clients/ui/bff/internal/api/model_registry_handler_test.go b/clients/ui/bff/internal/api/model_registry_handler_test.go index 13c460fe..872121ce 100644 --- a/clients/ui/bff/internal/api/model_registry_handler_test.go +++ b/clients/ui/bff/internal/api/model_registry_handler_test.go @@ -1,7 +1,9 @@ package api import ( + "context" "encoding/json" + "fmt" "github.com/kubeflow/model-registry/ui/bff/internal/models" "github.com/kubeflow/model-registry/ui/bff/internal/repositories" . "github.com/onsi/ginkgo/v2" @@ -23,7 +25,12 @@ var _ = Describe("TestModelRegistryHandler", func() { } By("creating the http test infrastructure") - req, err := http.NewRequest(http.MethodGet, ModelRegistryListPath, nil) + requestPath := fmt.Sprintf(" %s?namespace=kubeflow", ModelRegistryListPath) + req, err := http.NewRequest(http.MethodGet, requestPath, nil) + + ctx := context.WithValue(req.Context(), NamespaceHeaderParameterKey, "kubeflow") + req = req.WithContext(ctx) + Expect(err).NotTo(HaveOccurred()) rr := httptest.NewRecorder() @@ -43,8 +50,7 @@ var _ = Describe("TestModelRegistryHandler", func() { By("should match the expected model registries") var expected = []models.ModelRegistryModel{ {Name: "model-registry", Description: "Model Registry Description", DisplayName: "Model Registry"}, - {Name: "model-registry-bella", Description: "Model Registry Bella description", DisplayName: "Model Registry Bella"}, - {Name: "model-registry-dora", Description: "Model Registry Dora description", DisplayName: "Model Registry Dora"}, + {Name: "model-registry-one", Description: "Model Registry One description", DisplayName: "Model Registry One"}, } Expect(actual.Data).To(ConsistOf(expected)) }) diff --git a/clients/ui/bff/internal/api/model_versions_handler.go b/clients/ui/bff/internal/api/model_versions_handler.go index 425ca369..a945d049 100644 --- a/clients/ui/bff/internal/api/model_versions_handler.go +++ b/clients/ui/bff/internal/api/model_versions_handler.go @@ -19,7 +19,7 @@ type ModelArtifactListEnvelope Envelope[*openapi.ModelArtifactList, None] type ModelArtifactEnvelope Envelope[*openapi.ModelArtifact, None] func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -47,7 +47,7 @@ func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, p } func (app *App) CreateModelVersionHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -101,7 +101,7 @@ func (app *App) CreateModelVersionHandler(w http.ResponseWriter, r *http.Request } func (app *App) UpdateModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -151,7 +151,7 @@ func (app *App) UpdateModelVersionHandler(w http.ResponseWriter, r *http.Request } func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -174,7 +174,7 @@ func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, } func (app *App) CreateModelArtifactByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return diff --git a/clients/ui/bff/internal/api/model_versions_handler_test.go b/clients/ui/bff/internal/api/model_versions_handler_test.go index 1a9ef040..cce4cf02 100644 --- a/clients/ui/bff/internal/api/model_versions_handler_test.go +++ b/clients/ui/bff/internal/api/model_versions_handler_test.go @@ -15,7 +15,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { By("fetching a model version") data := mocks.GetModelVersionMocks()[0] expected := ModelVersionEnvelope{Data: &data} - actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1", nil, k8sClient, mocks.KubeflowUserIDHeaderValue) + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1?namespace=kubeflow", nil, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should match the expected model version") Expect(rs.StatusCode).To(Equal(http.StatusOK)) @@ -27,13 +27,13 @@ var _ = Describe("TestGetModelVersionHandler", func() { data := mocks.GetModelVersionMocks()[0] expected := ModelVersionEnvelope{Data: &data} body := ModelVersionEnvelope{Data: openapi.NewModelVersion("Model One", "1")} - actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions", body, k8sClient, mocks.KubeflowUserIDHeaderValue) + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions?namespace=kubeflow", body, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should match the expected model version created") Expect(rs.StatusCode).To(Equal(http.StatusCreated)) Expect(actual.Data.Name).To(Equal(expected.Data.Name)) - Expect(rs.Header.Get("Location")).To(Equal("/api/v1/model_registry/model-registry/model_versions/1")) + Expect(rs.Header.Get("Location")).To(Equal("/api/v1/model_registry/model-registry/model_versions/1?namespace=kubeflow")) }) It("should updated a model version", func() { @@ -46,7 +46,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { } body := ModelVersionUpdateEnvelope{Data: &reqData} - actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/model_versions/1", body, k8sClient, mocks.KubeflowUserIDHeaderValue) + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/model_versions/1?namespace=kubeflow", body, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should match the expected model version updated") @@ -58,7 +58,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { By("getting a model artifacts by model version") data := mocks.GetModelArtifactListMock() expected := ModelArtifactListEnvelope{Data: &data} - actual, rs, err := setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", nil, k8sClient, mocks.KubeflowUserIDHeaderValue) + actual, rs, err := setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts?namespace=kubeflow", nil, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should get all expected model version artifacts") @@ -79,7 +79,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { ArtifactType: "ARTIFACT_TYPE_ONE", } body := ModelArtifactEnvelope{Data: &artifact} - actual, rs, err := setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", body, k8sClient, mocks.KubeflowUserIDHeaderValue) + actual, rs, err := setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts?namespace=kubeflow", body, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should get all expected model artifacts") @@ -94,7 +94,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { wrongUserIDHeader := "bella@dora.com" // Incorrect username header value // Test: GET /model_versions/1 - _, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1", nil, k8sClient, wrongUserIDHeader) + _, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1?namespace=kubeflow", nil, k8sClient, wrongUserIDHeader, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should return a 403 Forbidden response") @@ -106,14 +106,14 @@ var _ = Describe("TestGetModelVersionHandler", func() { ArtifactType: "ARTIFACT_TYPE_ONE", } body := ModelArtifactEnvelope{Data: &artifact} - _, rs, err = setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", body, k8sClient, wrongUserIDHeader) + _, rs, err = setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts?namespace=kubeflow", body, k8sClient, wrongUserIDHeader, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should return a 403 Forbidden response") Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) // Test: GET /model_versions/1/artifacts - _, rs, err = setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", nil, k8sClient, wrongUserIDHeader) + _, rs, err = setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts?namespace=kubeflow", nil, k8sClient, wrongUserIDHeader, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should return a 403 Forbidden response") @@ -124,7 +124,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { Description: openapi.PtrString("New description"), } body1 := ModelVersionUpdateEnvelope{Data: &reqData} - _, rs, err = setupApiTest[ModelVersionEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/model_versions/1", body1, k8sClient, wrongUserIDHeader) + _, rs, err = setupApiTest[ModelVersionEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/model_versions/1?namespace=kubeflow", body1, k8sClient, wrongUserIDHeader, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should return a 403 Forbidden response") @@ -132,7 +132,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { // Test: POST /model_versions body2 := ModelVersionEnvelope{Data: openapi.NewModelVersion("Model One", "1")} - _, rs, err = setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions", body2, k8sClient, wrongUserIDHeader) + _, rs, err = setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions?namespace=kubeflow", body2, k8sClient, wrongUserIDHeader, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should return a 403 Forbidden response") Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) diff --git a/clients/ui/bff/internal/api/namespaces_handler.go b/clients/ui/bff/internal/api/namespaces_handler.go index fe60b190..80fb4701 100644 --- a/clients/ui/bff/internal/api/namespaces_handler.go +++ b/clients/ui/bff/internal/api/namespaces_handler.go @@ -18,7 +18,14 @@ func (app *App) GetNamespacesHandler(w http.ResponseWriter, r *http.Request, _ h return } - namespaces, err := app.repositories.Namespace.GetNamespaces(app.kubernetesClient, userId) + var userGroups []string + if groups, ok := r.Context().Value(KubeflowUserGroupsKey).([]string); ok { + userGroups = groups + } else { + userGroups = []string{} + } + + namespaces, err := app.repositories.Namespace.GetNamespaces(app.kubernetesClient, userId, userGroups) if err != nil { app.serverErrorResponse(w, r, err) return diff --git a/clients/ui/bff/internal/api/registered_models_handler.go b/clients/ui/bff/internal/api/registered_models_handler.go index f1dedde5..98daef1c 100644 --- a/clients/ui/bff/internal/api/registered_models_handler.go +++ b/clients/ui/bff/internal/api/registered_models_handler.go @@ -16,7 +16,7 @@ type RegisteredModelListEnvelope Envelope[*openapi.RegisteredModelList, None] type RegisteredModelUpdateEnvelope Envelope[*openapi.RegisteredModelUpdate, None] func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -39,7 +39,7 @@ func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Req } func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -93,7 +93,7 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ } func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -121,7 +121,7 @@ func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request } func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -171,7 +171,7 @@ func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ } func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -195,7 +195,7 @@ func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWrit } func (app *App) CreateModelVersionForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return diff --git a/clients/ui/bff/internal/api/registered_models_handler_test.go b/clients/ui/bff/internal/api/registered_models_handler_test.go index 93aa5cfc..e114a8f5 100644 --- a/clients/ui/bff/internal/api/registered_models_handler_test.go +++ b/clients/ui/bff/internal/api/registered_models_handler_test.go @@ -15,7 +15,7 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { By("fetching all model registries") data := mocks.GetRegisteredModelMocks()[0] expected := RegisteredModelEnvelope{Data: &data} - actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1", nil, k8sClient, mocks.KubeflowUserIDHeaderValue) + actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1?namespace=kubeflow", nil, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should match the expected model registry") //TODO assert the full structure, I couldn't get unmarshalling to work for the full customProperties values @@ -28,7 +28,7 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { By("fetching all registered models") data := mocks.GetRegisteredModelListMock() expected := RegisteredModelListEnvelope{Data: &data} - actual, rs, err := setupApiTest[RegisteredModelListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models", nil, k8sClient, mocks.KubeflowUserIDHeaderValue) + actual, rs, err := setupApiTest[RegisteredModelListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models?namespace=kubeflow", nil, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should match the expected model registry") Expect(rs.StatusCode).To(Equal(http.StatusOK)) @@ -43,13 +43,13 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { data := mocks.GetRegisteredModelMocks()[0] expected := RegisteredModelEnvelope{Data: &data} body := RegisteredModelEnvelope{Data: openapi.NewRegisteredModel("Model One")} - actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models", body, k8sClient, mocks.KubeflowUserIDHeaderValue) + actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models?namespace=kubeflow", body, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should do a successful post") Expect(rs.StatusCode).To(Equal(http.StatusCreated)) Expect(actual.Data.Name).To(Equal(expected.Data.Name)) - Expect(rs.Header.Get("location")).To(Equal("/api/v1/model_registry/model-registry/registered_models/1")) + Expect(rs.Header.Get("location")).To(Equal("/api/v1/model_registry/model-registry/registered_models/1?namespace=kubeflow")) }) It("updating registered models", func() { @@ -60,7 +60,7 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { Description: openapi.PtrString("This is a new description"), } body := RegisteredModelUpdateEnvelope{Data: &reqData} - actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/registered_models/1", body, k8sClient, mocks.KubeflowUserIDHeaderValue) + actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/registered_models/1?namespace=kubeflow", body, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should do a successful patch") @@ -73,7 +73,7 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { data := mocks.GetModelVersionListMock() expected := ModelVersionListEnvelope{Data: &data} - actual, rs, err := setupApiTest[ModelVersionListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1/versions", nil, k8sClient, mocks.KubeflowUserIDHeaderValue) + actual, rs, err := setupApiTest[ModelVersionListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1/versions?namespace=kubeflow", nil, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should get all items") @@ -90,7 +90,7 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { expected := ModelVersionEnvelope{Data: &data} body := ModelVersionEnvelope{Data: openapi.NewModelVersion("Version Fifty", "")} - actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models/1/versions", body, k8sClient, mocks.KubeflowUserIDHeaderValue) + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models/1/versions?namespace=kubeflow", body, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should successfully create it") @@ -105,20 +105,20 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { wrongUserIDHeader := "bella@dora.com" // Incorrect username header value // Test: GET /registered_models/1 - _, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1", nil, k8sClient, wrongUserIDHeader) + _, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1?namespace=kubeflow", nil, k8sClient, wrongUserIDHeader, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should return a 403 Forbidden response for GET registered model by ID") Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) // Test: GET /registered_models - _, rs, err = setupApiTest[RegisteredModelListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models", nil, k8sClient, wrongUserIDHeader) + _, rs, err = setupApiTest[RegisteredModelListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models?namespace=kubeflow", nil, k8sClient, wrongUserIDHeader, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should return a 403 Forbidden response for GET all registered models") Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) // Test: POST /registered_models body := RegisteredModelEnvelope{Data: openapi.NewRegisteredModel("Model One")} - _, rs, err = setupApiTest[RegisteredModelEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models", body, k8sClient, wrongUserIDHeader) + _, rs, err = setupApiTest[RegisteredModelEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models?namespace=kubeflow", body, k8sClient, wrongUserIDHeader, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should return a 403 Forbidden response for POST create registered model") Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) @@ -128,20 +128,20 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { Description: openapi.PtrString("This is a new description"), } body2 := RegisteredModelUpdateEnvelope{Data: &reqData} - _, rs, err = setupApiTest[RegisteredModelEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/registered_models/1", body2, k8sClient, wrongUserIDHeader) + _, rs, err = setupApiTest[RegisteredModelEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/registered_models/1?namespace=kubeflow", body2, k8sClient, wrongUserIDHeader, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should return a 403 Forbidden response for PATCH update registered model") Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) // Test: GET /registered_models/1/versions - _, rs, err = setupApiTest[ModelVersionListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1/versions", nil, k8sClient, wrongUserIDHeader) + _, rs, err = setupApiTest[ModelVersionListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1/versions?namespace=kubeflow", nil, k8sClient, wrongUserIDHeader, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should return a 403 Forbidden response for GET model versions of registered model") Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) // Test: POST /registered_models/1/versions body3 := ModelVersionEnvelope{Data: openapi.NewModelVersion("Version Fifty", "")} - _, rs, err = setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models/1/versions", body3, k8sClient, wrongUserIDHeader) + _, rs, err = setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models/1/versions?namespace=kubeflow", body3, k8sClient, wrongUserIDHeader, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should return a 403 Forbidden response for POST create model version for registered model") Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) diff --git a/clients/ui/bff/internal/api/test_utils.go b/clients/ui/bff/internal/api/test_utils.go index 77f56373..945e7d0e 100644 --- a/clients/ui/bff/internal/api/test_utils.go +++ b/clients/ui/bff/internal/api/test_utils.go @@ -12,7 +12,7 @@ import ( "net/http/httptest" ) -func setupApiTest[T any](method string, url string, body interface{}, k8sClient k8s.KubernetesClientInterface, kubeflowUserIDHeaderValue string) (T, *http.Response, error) { +func setupApiTest[T any](method string, url string, body interface{}, k8sClient k8s.KubernetesClientInterface, kubeflowUserIDHeaderValue string, namespace string) (T, *http.Response, error) { mockMRClient, err := mocks.NewModelRegistryClient(nil) if err != nil { return *new(T), nil, err @@ -46,7 +46,14 @@ func setupApiTest[T any](method string, url string, body interface{}, k8sClient // Set the kubeflow-userid header req.Header.Set(KubeflowUserIDHeader, kubeflowUserIDHeaderValue) - ctx := context.WithValue(req.Context(), httpClientKey, mockClient) + ctx := req.Context() + ctx = context.WithValue(ctx, ModelRegistryHttpClientKey, mockClient) + ctx = context.WithValue(ctx, KubeflowUserIdKey, kubeflowUserIDHeaderValue) + ctx = context.WithValue(ctx, NamespaceHeaderParameterKey, namespace) + mrHttpClient := k8s.HTTPClient{ + ModelRegistryID: "model-registry", + } + ctx = context.WithValue(ctx, ModelRegistryHttpClientKey, mrHttpClient) req = req.WithContext(ctx) rr := httptest.NewRecorder() diff --git a/clients/ui/bff/internal/integrations/http.go b/clients/ui/bff/internal/integrations/http.go index c20a859b..712b8556 100644 --- a/clients/ui/bff/internal/integrations/http.go +++ b/clients/ui/bff/internal/integrations/http.go @@ -10,14 +10,16 @@ import ( ) type HTTPClientInterface interface { + GetModelRegistryID() (modelRegistryService string) GET(url string) ([]byte, error) POST(url string, body io.Reader) ([]byte, error) PATCH(url string, body io.Reader) ([]byte, error) } type HTTPClient struct { - client *http.Client - baseURL string + client *http.Client + baseURL string + ModelRegistryID string } type ErrorResponse struct { @@ -34,16 +36,21 @@ func (e *HTTPError) Error() string { return fmt.Sprintf("HTTP %d: %s - %s", e.StatusCode, e.Code, e.Message) } -func NewHTTPClient(baseURL string) (HTTPClientInterface, error) { +func NewHTTPClient(modelRegistryID string, baseURL string) (HTTPClientInterface, error) { return &HTTPClient{ client: &http.Client{Transport: &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }}, - baseURL: baseURL, + baseURL: baseURL, + ModelRegistryID: modelRegistryID, }, nil } +func (c *HTTPClient) GetModelRegistryID() string { + return c.ModelRegistryID +} + func (c *HTTPClient) GET(url string) ([]byte, error) { fullURL := c.baseURL + url req, err := http.NewRequest("GET", fullURL, nil) diff --git a/clients/ui/bff/internal/integrations/k8s.go b/clients/ui/bff/internal/integrations/k8s.go index 6c3d8c37..9b89bb02 100644 --- a/clients/ui/bff/internal/integrations/k8s.go +++ b/clients/ui/bff/internal/integrations/k8s.go @@ -22,15 +22,16 @@ import ( const ComponentLabelValue = "model-registry" type KubernetesClientInterface interface { - GetServiceNames() ([]string, error) - GetServiceDetailsByName(serviceName string) (ServiceDetails, error) - GetServiceDetails() ([]ServiceDetails, error) + GetServiceNames(namespace string) ([]string, error) + GetServiceDetailsByName(namespace string, serviceName string) (ServiceDetails, error) + GetServiceDetails(namespace string) ([]ServiceDetails, error) BearerToken() (string, error) Shutdown(ctx context.Context, logger *slog.Logger) error IsInCluster() bool - PerformSAR(user string) (bool, error) + PerformSARonGetListServicesByNamespace(user string, groups []string, namespace string) (bool, error) + PerformSARonSpecificService(user string, groups []string, namespace string, serviceName string) (bool, error) IsClusterAdmin(user string) (bool, error) - GetNamespaces(user string) ([]corev1.Namespace, error) + GetNamespaces(user string, groups []string) ([]corev1.Namespace, error) } type ServiceDetails struct { @@ -151,8 +152,8 @@ func (kc *KubernetesClient) BearerToken() (string, error) { return kc.Token, nil } -func (kc *KubernetesClient) GetServiceNames() ([]string, error) { - services, err := kc.GetServiceDetails() +func (kc *KubernetesClient) GetServiceNames(namespace string) ([]string, error) { + services, err := kc.GetServiceDetails(namespace) if err != nil { return nil, err } @@ -165,7 +166,12 @@ func (kc *KubernetesClient) GetServiceNames() ([]string, error) { return names, nil } -func (kc *KubernetesClient) GetServiceDetails() ([]ServiceDetails, error) { +func (kc *KubernetesClient) GetServiceDetails(namespace string) ([]ServiceDetails, error) { + + if namespace == "" { + return nil, fmt.Errorf("namespace cannot be empty") + } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -177,6 +183,7 @@ func (kc *KubernetesClient) GetServiceDetails() ([]ServiceDetails, error) { err := kc.ControllerRuntimeClient.List(ctx, serviceList, &client.ListOptions{ LabelSelector: labelSelector, + Namespace: namespace, }) if err != nil { return nil, fmt.Errorf("failed to list services: %w", err) @@ -235,8 +242,8 @@ func (kc *KubernetesClient) GetServiceDetails() ([]ServiceDetails, error) { return services, nil } -func (kc *KubernetesClient) GetServiceDetailsByName(serviceName string) (ServiceDetails, error) { - services, err := kc.GetServiceDetails() +func (kc *KubernetesClient) GetServiceDetailsByName(namespace string, serviceName string) (ServiceDetails, error) { + services, err := kc.GetServiceDetails(namespace) if err != nil { return ServiceDetails{}, fmt.Errorf("failed to get service details: %w", err) } @@ -250,19 +257,22 @@ func (kc *KubernetesClient) GetServiceDetailsByName(serviceName string) (Service return ServiceDetails{}, fmt.Errorf("service %s not found", serviceName) } -func (kc *KubernetesClient) PerformSAR(user string) (bool, error) { +func (kc *KubernetesClient) PerformSARonGetListServicesByNamespace(user string, groups []string, namespace string) (bool, error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + verbs := []string{"get", "list"} resource := "services" for _, verb := range verbs { sar := &authv1.SubjectAccessReview{ Spec: authv1.SubjectAccessReviewSpec{ - User: user, + User: user, + Groups: groups, ResourceAttributes: &authv1.ResourceAttributes{ - Verb: verb, - Resource: resource, + Verb: verb, + Resource: resource, + Namespace: namespace, }, }, } @@ -308,7 +318,7 @@ func (kc *KubernetesClient) IsClusterAdmin(user string) (bool, error) { return false, nil } -func (kc *KubernetesClient) GetNamespaces(user string) ([]corev1.Namespace, error) { +func (kc *KubernetesClient) GetNamespaces(user string, groups []string) ([]corev1.Namespace, error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -324,7 +334,8 @@ func (kc *KubernetesClient) GetNamespaces(user string) ([]corev1.Namespace, erro for _, ns := range namespaceList.Items { sar := &authv1.SubjectAccessReview{ Spec: authv1.SubjectAccessReviewSpec{ - User: user, + User: user, + Groups: groups, ResourceAttributes: &authv1.ResourceAttributes{ Namespace: ns.Name, Verb: "get", @@ -347,3 +358,43 @@ func (kc *KubernetesClient) GetNamespaces(user string) ([]corev1.Namespace, erro return namespaces, nil } + +func (kc *KubernetesClient) PerformSARonSpecificService(user string, groups []string, namespace string, serviceName string) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + resource := "services" + verb := "get" + + sar := &authv1.SubjectAccessReview{ + Spec: authv1.SubjectAccessReviewSpec{ + User: user, + Groups: groups, + ResourceAttributes: &authv1.ResourceAttributes{ + Verb: verb, + Resource: resource, + Namespace: namespace, + Name: serviceName, + }, + }, + } + + // Perform the SAR using the native KubernetesNativeClient client + response, err := kc.KubernetesNativeClient.AuthorizationV1().SubjectAccessReviews().Create(ctx, sar, metav1.CreateOptions{}) + if err != nil { + return false, fmt.Errorf( + "failed to create SubjectAccessReview for verb %q on resource %q (service: %q) in namespace %q: %w", + verb, resource, serviceName, namespace, err, + ) + } + + if !response.Status.Allowed { + kc.Logger.Warn( + "access denied", "user", user, "verb", verb, "resource", resource, + "namespace", namespace, "service", serviceName, + ) + return false, nil + } + + return true, nil +} diff --git a/clients/ui/bff/internal/mocks/http_mock.go b/clients/ui/bff/internal/mocks/http_mock.go index c530592e..fd93aa21 100644 --- a/clients/ui/bff/internal/mocks/http_mock.go +++ b/clients/ui/bff/internal/mocks/http_mock.go @@ -9,6 +9,10 @@ type MockHTTPClient struct { mock.Mock } +func (c *MockHTTPClient) GetModelRegistryID() string { + return "model-registry" +} + func (m *MockHTTPClient) GET(url string) ([]byte, error) { args := m.Called(url) return args.Get(0).([]byte), args.Error(1) diff --git a/clients/ui/bff/internal/mocks/k8s_mock.go b/clients/ui/bff/internal/mocks/k8s_mock.go index d07e06ad..ce2b3e61 100644 --- a/clients/ui/bff/internal/mocks/k8s_mock.go +++ b/clients/ui/bff/internal/mocks/k8s_mock.go @@ -20,6 +20,9 @@ import ( const ( KubeflowUserIDHeaderValue = "user@example.com" DoraNonAdminUser = "doraNonAdmin@example.com" + BellaNonAdminUser = "bellaNonAdmin@example.com" + DoraServiceGroup = "dora-service-group" + DoraNamespaceGroup = "dora-namespace-group" ) type KubernetesClientMock struct { @@ -128,19 +131,28 @@ func setupMock(mockK8sClient client.Client, ctx context.Context) error { return err } + err = createNamespace(mockK8sClient, ctx, "bella-namespace") + if err != nil { + return err + } + err = createService(mockK8sClient, ctx, "model-registry", "kubeflow", "Model Registry", "Model Registry Description", "10.0.0.10", "model-registry") if err != nil { return err } - err = createService(mockK8sClient, ctx, "model-registry-dora", "dora-namespace", "Model Registry Dora", "Model Registry Dora description", "10.0.0.11", "model-registry") + err = createService(mockK8sClient, ctx, "model-registry-one", "kubeflow", "Model Registry One", "Model Registry One description", "10.0.0.11", "model-registry") if err != nil { return err } - err = createService(mockK8sClient, ctx, "model-registry-bella", "kubeflow", "Model Registry Bella", "Model Registry Bella description", "10.0.0.12", "model-registry") + err = createService(mockK8sClient, ctx, "model-registry-dora", "dora-namespace", "Model Registry Dora", "Model Registry Dora description", "10.0.0.12", "model-registry") if err != nil { return err } - err = createService(mockK8sClient, ctx, "non-model-registry", "kubeflow", "Not a Model Registry", "Not a Model Registry Bella description", "10.0.0.13", "") + err = createService(mockK8sClient, ctx, "model-registry-bella", "bella-namespace", "Model Registry Bella", "Model Registry Bella description", "10.0.0.13", "model-registry") + if err != nil { + return err + } + err = createService(mockK8sClient, ctx, "non-model-registry", "kubeflow", "Not a Model Registry", "Not a Model Registry Bella description", "10.0.0.14", "") if err != nil { return err } @@ -155,11 +167,26 @@ func setupMock(mockK8sClient client.Client, ctx context.Context) error { return fmt.Errorf("failed to create namespace-restricted RBAC: %w", err) } + err = createNamespaceRestrictedRBAC(mockK8sClient, ctx, BellaNonAdminUser, "bella-namespace") + if err != nil { + return fmt.Errorf("failed to create namespace-restricted RBAC: %w", err) + } + + err = createGroupAccessRBAC(mockK8sClient, ctx, DoraServiceGroup, "dora-namespace", "model-registry-dora") + if err != nil { + return fmt.Errorf("failed to create group-based RBAC: %w", err) + } + + err = createGroupNamespaceAccessRBAC(mockK8sClient, ctx, DoraNamespaceGroup, "dora-namespace") + if err != nil { + return fmt.Errorf("failed to set up group access to namespace: %w", err) + } + return nil } -func (m *KubernetesClientMock) GetServiceDetails() ([]k8s.ServiceDetails, error) { - originalServices, err := m.KubernetesClient.GetServiceDetails() +func (m *KubernetesClientMock) GetServiceDetails(namespace string) ([]k8s.ServiceDetails, error) { + originalServices, err := m.KubernetesClient.GetServiceDetails(namespace) if err != nil { return nil, fmt.Errorf("failed to get service details: %w", err) } @@ -172,8 +199,8 @@ func (m *KubernetesClientMock) GetServiceDetails() ([]k8s.ServiceDetails, error) return originalServices, nil } -func (m *KubernetesClientMock) GetServiceDetailsByName(serviceName string) (k8s.ServiceDetails, error) { - originalService, err := m.KubernetesClient.GetServiceDetailsByName(serviceName) +func (m *KubernetesClientMock) GetServiceDetailsByName(namespace string, serviceName string) (k8s.ServiceDetails, error) { + originalService, err := m.KubernetesClient.GetServiceDetailsByName(namespace, serviceName) if err != nil { return k8s.ServiceDetails{}, fmt.Errorf("failed to get service details: %w", err) } @@ -340,6 +367,98 @@ func createNamespaceRestrictedRBAC(k8sClient client.Client, ctx context.Context, return nil } +func createGroupAccessRBAC(k8sClient client.Client, ctx context.Context, groupName, namespace, serviceName string) error { + role := &rbacv1.Role{ + ObjectMeta: metav1.ObjectMeta{ + Name: "group-model-registry-access", + Namespace: namespace, + }, + Rules: []rbacv1.PolicyRule{ + { + APIGroups: []string{""}, + Resources: []string{"services"}, + Verbs: []string{"get", "list"}, + ResourceNames: []string{ + serviceName, + }, + }, + }, + } + + if err := k8sClient.Create(ctx, role); err != nil { + return fmt.Errorf("failed to create Role for group: %w", err) + } + + roleBinding := &rbacv1.RoleBinding{ + ObjectMeta: metav1.ObjectMeta{ + Name: "group-access-binding", + Namespace: namespace, + }, + Subjects: []rbacv1.Subject{ + { + Kind: "Group", + Name: groupName, + }, + }, + RoleRef: rbacv1.RoleRef{ + Kind: "Role", + Name: "group-model-registry-access", + APIGroup: "rbac.authorization.k8s.io", + }, + } + + if err := k8sClient.Create(ctx, roleBinding); err != nil { + return fmt.Errorf("failed to create RoleBinding for group: %w", err) + } + + return nil +} + +func createGroupNamespaceAccessRBAC(k8sClient client.Client, ctx context.Context, groupName, namespace string) error { + + role := &rbacv1.Role{ + ObjectMeta: metav1.ObjectMeta{ + Name: "group-namespace-access-role", + Namespace: namespace, + }, + Rules: []rbacv1.PolicyRule{ + { + APIGroups: []string{""}, + Resources: []string{"namespaces", "services"}, + Verbs: []string{"get", "list"}, + }, + }, + } + + if err := k8sClient.Create(ctx, role); err != nil { + return fmt.Errorf("failed to create Role for group namespace access: %w", err) + } + + roleBinding := &rbacv1.RoleBinding{ + ObjectMeta: metav1.ObjectMeta{ + Name: "group-namespace-access-binding", + Namespace: namespace, + }, + Subjects: []rbacv1.Subject{ + { + Kind: "Group", + Name: groupName, + }, + }, + RoleRef: rbacv1.RoleRef{ + Kind: "Role", + Name: "group-namespace-access-role", + APIGroup: "rbac.authorization.k8s.io", + }, + } + + if err := k8sClient.Create(ctx, roleBinding); err != nil { + return fmt.Errorf("failed to create RoleBinding for group namespace access: %w", err) + } + + return nil +} + func strPtr(s string) *string { return &s } diff --git a/clients/ui/bff/internal/mocks/k8s_mock_test.go b/clients/ui/bff/internal/mocks/k8s_mock_test.go index d77c658b..e236326a 100644 --- a/clients/ui/bff/internal/mocks/k8s_mock_test.go +++ b/clients/ui/bff/internal/mocks/k8s_mock_test.go @@ -11,7 +11,7 @@ var _ = Describe("Kubernetes ControllerRuntimeClient Test", func() { It("should retrieve the get all service successfully", func() { By("getting service details") - services, err := k8sClient.GetServiceDetails() + services, err := k8sClient.GetServiceDetails("kubeflow") Expect(err).NotTo(HaveOccurred(), "Failed to create HTTP request") By("checking that all services have the modified ClusterIP and HTTPPort") @@ -37,7 +37,7 @@ var _ = Describe("Kubernetes ControllerRuntimeClient Test", func() { It("should retrieve the service details by name", func() { By("getting service by name") - service, err := k8sClient.GetServiceDetailsByName("model-registry-dora") + service, err := k8sClient.GetServiceDetailsByName("dora-namespace", "model-registry-dora") Expect(err).NotTo(HaveOccurred(), "Failed to create k8s request") By("checking that service details are correct") @@ -49,11 +49,11 @@ var _ = Describe("Kubernetes ControllerRuntimeClient Test", func() { It("should retrieve the services names", func() { By("getting service by name") - services, err := k8sClient.GetServiceNames() + services, err := k8sClient.GetServiceNames("kubeflow") Expect(err).NotTo(HaveOccurred(), "Failed to create HTTP request") By("checking that service details are correct") - Expect(services).To(ConsistOf("model-registry", "model-registry-bella", "model-registry-dora")) + Expect(services).To(ConsistOf("model-registry", "model-registry-one")) }) }) @@ -65,14 +65,25 @@ var _ = Describe("KubernetesNativeClient SAR Test", func() { It("should allow allowed user to access services", func() { By("performing SAR for Kubeflow User ID") - allowed, err := k8sClient.PerformSAR(KubeflowUserIDHeaderValue) + allowed, err := k8sClient.PerformSARonGetListServicesByNamespace(KubeflowUserIDHeaderValue, []string{}, "kubeflow") Expect(err).NotTo(HaveOccurred(), "Failed to perform SAR for Kubeflow User ID\"") Expect(allowed).To(BeTrue(), "Expected Kubeflow User ID to have access") }) + It("check dora access to namespaces", func() { + By("performing SAR for dora user") + allowed, err := k8sClient.PerformSARonGetListServicesByNamespace(DoraNonAdminUser, []string{}, "kubeflow") + Expect(err).NotTo(HaveOccurred(), "Failed to perform SAR for unauthorized-dora@example.com") + Expect(allowed).To(BeFalse(), "Expected doraNonAdmin@example.com to be denied access") + + allowed, err = k8sClient.PerformSARonGetListServicesByNamespace(DoraNonAdminUser, []string{}, "dora-namespace") + Expect(err).NotTo(HaveOccurred(), "Failed to perform SAR for unauthorized-dora@example.com") + Expect(allowed).To(BeTrue(), "Expected doraNonAdmin@example.com ID to have access") + }) + It("should deny access for another user", func() { By("performing SAR for another user") - allowed, err := k8sClient.PerformSAR("unauthorized-dora@example.com") + allowed, err := k8sClient.PerformSARonGetListServicesByNamespace("unauthorized-dora@example.com", []string{}, "kubeflow") Expect(err).NotTo(HaveOccurred(), "Failed to perform SAR for unauthorized-dora@example.com") Expect(allowed).To(BeFalse(), "Expected unauthorized-dora@example.com to be denied access") }) @@ -80,6 +91,63 @@ var _ = Describe("KubernetesNativeClient SAR Test", func() { }) }) +var _ = Describe("KubernetesClient PerformSARonSpecificService Group Tests", func() { + Context("checking access using group memberships", func() { + const ( + namespace = "dora-namespace" + serviceName = "model-registry-dora" + existingUser = "bentoOnlyGroupAccess@example.com" + ) + + It("should deny access for a group that does not exist", func() { + groups := []string{"non-existent-group"} + + allowed, err := k8sClient.PerformSARonSpecificService(existingUser, groups, namespace, serviceName) + Expect(err).NotTo(HaveOccurred()) + Expect(allowed).To(BeFalse(), "Access should be denied for a non-existent group") + }) + + It("should allow service access for the DoraServiceGroup", func() { + groups := []string{DoraServiceGroup} + + allowed, err := k8sClient.PerformSARonSpecificService(existingUser, groups, namespace, serviceName) + Expect(err).NotTo(HaveOccurred()) + Expect(allowed).To(BeTrue(), "Access should be allowed for the DoraServiceGroup group") + }) + + It("should allow access when one group exists and the other does not", func() { + groups := []string{DoraServiceGroup, "non-existent-group"} + + allowed, err := k8sClient.PerformSARonSpecificService(existingUser, groups, namespace, serviceName) + Expect(err).NotTo(HaveOccurred()) + Expect(allowed).To(BeTrue(), "Access should be allowed if any group in the list has access") + }) + + It("should allow access only when I've service access and namespace access", func() { + groups := []string{DoraServiceGroup} + + allowed, err := k8sClient.PerformSARonSpecificService(existingUser, groups, namespace, serviceName) + Expect(err).NotTo(HaveOccurred()) + Expect(allowed).To(BeTrue(), "Access should be allowed for the DoraServiceGroup group") + + allowed, err = k8sClient.PerformSARonGetListServicesByNamespace(existingUser, groups, namespace) + Expect(err).NotTo(HaveOccurred()) + Expect(allowed).To(BeFalse(), "Access should not be allowed for only DoraServiceGroup group") + + allGroups := []string{DoraServiceGroup, DoraNamespaceGroup} + + allowed, err = k8sClient.PerformSARonGetListServicesByNamespace(existingUser, allGroups, namespace) + Expect(err).NotTo(HaveOccurred()) + Expect(allowed).To(BeTrue(), "Access should be allowed for both groups") + + allowed, err = k8sClient.PerformSARonSpecificService(existingUser, allGroups, namespace, serviceName) + Expect(err).NotTo(HaveOccurred()) + Expect(allowed).To(BeTrue(), "Access should be allowed for for both groups") + + }) + }) +}) + var _ = Describe("KubernetesClient isClusterAdmin Test", func() { Context("checking cluster admin status", func() { It("should confirm that user@example.com(KubeflowUserIDHeaderValue) is a cluster-admin", func() { diff --git a/clients/ui/bff/internal/repositories/model_registry.go b/clients/ui/bff/internal/repositories/model_registry.go index a60b2279..db417595 100644 --- a/clients/ui/bff/internal/repositories/model_registry.go +++ b/clients/ui/bff/internal/repositories/model_registry.go @@ -13,9 +13,9 @@ func NewModelRegistryRepository() *ModelRegistryRepository { return &ModelRegistryRepository{} } -func (m *ModelRegistryRepository) FetchAllModelRegistries(client k8s.KubernetesClientInterface) ([]models.ModelRegistryModel, error) { +func (m *ModelRegistryRepository) GetAllModelRegistries(client k8s.KubernetesClientInterface, namespace string) ([]models.ModelRegistryModel, error) { - resources, err := client.GetServiceDetails() + resources, err := client.GetServiceDetails(namespace) if err != nil { return nil, fmt.Errorf("error fetching model registries: %w", err) } diff --git a/clients/ui/bff/internal/repositories/model_registry_test.go b/clients/ui/bff/internal/repositories/model_registry_test.go index e430011c..a5a0d903 100644 --- a/clients/ui/bff/internal/repositories/model_registry_test.go +++ b/clients/ui/bff/internal/repositories/model_registry_test.go @@ -9,20 +9,44 @@ import ( var _ = Describe("TestFetchAllModelRegistry", func() { Context("with existing model registries", Ordered, func() { - It("should retrieve the get all service successfully", func() { + It("should retrieve the get all kubeflow service successfully", func() { By("fetching all model registries in the repository") modelRegistryRepository := NewModelRegistryRepository() - registries, err := modelRegistryRepository.FetchAllModelRegistries(k8sClient) + registries, err := modelRegistryRepository.GetAllModelRegistries(k8sClient, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should match the expected model registries") expectedRegistries := []models.ModelRegistryModel{ {Name: "model-registry", Description: "Model Registry Description", DisplayName: "Model Registry"}, - {Name: "model-registry-bella", Description: "Model Registry Bella description", DisplayName: "Model Registry Bella"}, + {Name: "model-registry-one", Description: "Model Registry One description", DisplayName: "Model Registry One"}, + } + Expect(registries).To(ConsistOf(expectedRegistries)) + }) + + It("should retrieve the get all dora-namespace service successfully", func() { + + By("fetching all model registries in the repository") + modelRegistryRepository := NewModelRegistryRepository() + registries, err := modelRegistryRepository.GetAllModelRegistries(k8sClient, "dora-namespace") + Expect(err).NotTo(HaveOccurred()) + + By("should match the expected model registries") + expectedRegistries := []models.ModelRegistryModel{ {Name: "model-registry-dora", Description: "Model Registry Dora description", DisplayName: "Model Registry Dora"}, } Expect(registries).To(ConsistOf(expectedRegistries)) }) + + It("should not retrieve namespaces", func() { + + By("fetching all model registries in the repository") + modelRegistryRepository := NewModelRegistryRepository() + registries, err := modelRegistryRepository.GetAllModelRegistries(k8sClient, "no-namespace") + Expect(err).NotTo(HaveOccurred()) + + By("should be empty") + Expect(registries).To(BeEmpty()) + }) }) }) diff --git a/clients/ui/bff/internal/repositories/namespace.go b/clients/ui/bff/internal/repositories/namespace.go index ae547a88..6cfda951 100644 --- a/clients/ui/bff/internal/repositories/namespace.go +++ b/clients/ui/bff/internal/repositories/namespace.go @@ -12,9 +12,9 @@ func NewNamespaceRepository() *NamespaceRepository { return &NamespaceRepository{} } -func (r *NamespaceRepository) GetNamespaces(client k8s.KubernetesClientInterface, user string) ([]models.NamespaceModel, error) { +func (r *NamespaceRepository) GetNamespaces(client k8s.KubernetesClientInterface, user string, groups []string) ([]models.NamespaceModel, error) { - namespaces, err := client.GetNamespaces(user) + namespaces, err := client.GetNamespaces(user, groups) if err != nil { return nil, fmt.Errorf("error fetching namespaces: %w", err) }