From 0cf60de88be1183f9da57342565d0da4440b4d38 Mon Sep 17 00:00:00 2001 From: Software Developer <7852635+dsuhinin@users.noreply.github.com> Date: Wed, 20 Nov 2024 13:06:07 +0100 Subject: [PATCH] Move `DELETE /mlflow/registered-models/delete` endpoint. (#89) Signed-off-by: Software Developer <7852635+dsuhinin@users.noreply.github.com> --- magefiles/generate/endpoints.go | 2 +- mlflow_go/store/model_registry.py | 5 ++ pkg/contract/service/model_registry.g.go | 1 + pkg/lib/model_registry.g.go | 8 +++ pkg/model_registry/service/model_versions.go | 10 ++++ .../store/sql/model_versions.go | 57 +++++++++++++++++++ pkg/model_registry/store/store.go | 1 + pkg/server/routes/model_registry.g.go | 11 ++++ 8 files changed, 94 insertions(+), 1 deletion(-) diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index eca4319..cc761a6 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -51,7 +51,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{ // "createRegisteredModel", "renameRegisteredModel", "updateRegisteredModel", - // "deleteRegisteredModel", + "deleteRegisteredModel", // "getRegisteredModel", // "searchRegisteredModels", "getLatestVersions", diff --git a/mlflow_go/store/model_registry.py b/mlflow_go/store/model_registry.py index 981fdac..ead2ec6 100644 --- a/mlflow_go/store/model_registry.py +++ b/mlflow_go/store/model_registry.py @@ -3,6 +3,7 @@ from mlflow.entities.model_registry import ModelVersion, RegisteredModel from mlflow.protos.model_registry_pb2 import ( + DeleteRegisteredModel, GetLatestVersions, RenameRegisteredModel, UpdateRegisteredModel, @@ -55,6 +56,10 @@ def rename_registered_model(self, name, new_name): ) return RegisteredModel.from_proto(response.registered_model) + def delete_registered_model(self, name): + request = DeleteRegisteredModel(name=name) + self.service.call_endpoint(get_lib().ModelRegistryServiceDeleteRegisteredModel, request) + def ModelRegistryStore(cls): return type(cls.__name__, (_ModelRegistryStore, cls), {}) diff --git a/pkg/contract/service/model_registry.g.go b/pkg/contract/service/model_registry.g.go index ce187cb..42659ef 100644 --- a/pkg/contract/service/model_registry.g.go +++ b/pkg/contract/service/model_registry.g.go @@ -12,5 +12,6 @@ type ModelRegistryService interface { contract.Destroyer RenameRegisteredModel(ctx context.Context, input *protos.RenameRegisteredModel) (*protos.RenameRegisteredModel_Response, *contract.Error) UpdateRegisteredModel(ctx context.Context, input *protos.UpdateRegisteredModel) (*protos.UpdateRegisteredModel_Response, *contract.Error) + DeleteRegisteredModel(ctx context.Context, input *protos.DeleteRegisteredModel) (*protos.DeleteRegisteredModel_Response, *contract.Error) GetLatestVersions(ctx context.Context, input *protos.GetLatestVersions) (*protos.GetLatestVersions_Response, *contract.Error) } diff --git a/pkg/lib/model_registry.g.go b/pkg/lib/model_registry.g.go index f6c33d3..878f204 100644 --- a/pkg/lib/model_registry.g.go +++ b/pkg/lib/model_registry.g.go @@ -23,6 +23,14 @@ func ModelRegistryServiceUpdateRegisteredModel(serviceID int64, requestData unsa } return invokeServiceMethod(service.UpdateRegisteredModel, new(protos.UpdateRegisteredModel), requestData, requestSize, responseSize) } +//export ModelRegistryServiceDeleteRegisteredModel +func ModelRegistryServiceDeleteRegisteredModel(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { + service, err := modelRegistryServices.Get(serviceID) + if err != nil { + return makePointerFromError(err, responseSize) + } + return invokeServiceMethod(service.DeleteRegisteredModel, new(protos.DeleteRegisteredModel), requestData, requestSize, responseSize) +} //export ModelRegistryServiceGetLatestVersions func ModelRegistryServiceGetLatestVersions(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { service, err := modelRegistryServices.Get(serviceID) diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index ae97f4d..b7e22b7 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -53,3 +53,13 @@ func (m *ModelRegistryService) RenameRegisteredModel( RegisteredModel: registeredModel.ToProto(), }, nil } + +func (m *ModelRegistryService) DeleteRegisteredModel( + ctx context.Context, input *protos.DeleteRegisteredModel, +) (*protos.DeleteRegisteredModel_Response, *contract.Error) { + if err := m.store.DeleteRegisteredModel(ctx, input.GetName()); err != nil { + return nil, err + } + + return &protos.DeleteRegisteredModel_Response{}, nil +} diff --git a/pkg/model_registry/store/sql/model_versions.go b/pkg/model_registry/store/sql/model_versions.go index 3e35fbf..2863909 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -199,3 +199,60 @@ func (m *ModelRegistrySQLStore) RenameRegisteredModel( return registeredModel, nil } + +func (m *ModelRegistrySQLStore) DeleteRegisteredModel(ctx context.Context, name string) *contract.Error { + registeredModel, err := m.GetRegisteredModelByName(ctx, name) + if err != nil { + return err + } + + if err := m.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + if err := transaction.Where( + "name = ?", registeredModel.Name, + ).Delete( + models.ModelVersionTag{}, + ).Error; err != nil { + return err + } + + if err := transaction.Where( + "name = ?", registeredModel.Name, + ).Delete( + models.ModelVersion{}, + ).Error; err != nil { + return err + } + + if err := transaction.Where( + "name = ?", registeredModel.Name, + ).Delete( + models.RegisteredModelTag{}, + ).Error; err != nil { + return err + } + + if err := transaction.Where( + "name = ?", registeredModel.Name, + ).Delete( + models.RegisteredModelAlias{}, + ).Error; err != nil { + return err + } + + if err := transaction.Where( + "name = ?", registeredModel.Name, + ).Delete( + models.RegisteredModel{}, + ).Error; err != nil { + return err + } + + return nil + }); err != nil { + return contract.NewError( + protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("error deleting registered model: %v", err), + ) + } + + return nil +} diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index 0fa667a..fe19f37 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -13,4 +13,5 @@ type ModelRegistryStore interface { GetLatestVersions(ctx context.Context, name string, stages []string) ([]*protos.ModelVersion, *contract.Error) UpdateRegisteredModel(ctx context.Context, name, description string) (*entities.RegisteredModel, *contract.Error) RenameRegisteredModel(ctx context.Context, name, newName string) (*entities.RegisteredModel, *contract.Error) + DeleteRegisteredModel(ctx context.Context, name string) *contract.Error } diff --git a/pkg/server/routes/model_registry.g.go b/pkg/server/routes/model_registry.g.go index 6e349ff..f4e6e37 100644 --- a/pkg/server/routes/model_registry.g.go +++ b/pkg/server/routes/model_registry.g.go @@ -33,6 +33,17 @@ func RegisterModelRegistryServiceRoutes(service service.ModelRegistryService, pa } return ctx.JSON(output) }) + app.Delete("/mlflow/registered-models/delete", func(ctx *fiber.Ctx) error { + input := &protos.DeleteRegisteredModel{} + if err := parser.ParseBody(ctx, input); err != nil { + return err + } + output, err := service.DeleteRegisteredModel(utils.NewContextWithLoggerFromFiberContext(ctx), input) + if err != nil { + return err + } + return ctx.JSON(output) + }) app.Post("/mlflow/registered-models/get-latest-versions", func(ctx *fiber.Ctx) error { input := &protos.GetLatestVersions{} if err := parser.ParseBody(ctx, input); err != nil {