Skip to content

Commit

Permalink
Merge branch 'main' into fix-platform
Browse files Browse the repository at this point in the history
  • Loading branch information
nojaf authored Dec 5, 2024
2 parents b92e128 + a2bd841 commit a4c8565
Show file tree
Hide file tree
Showing 14 changed files with 278 additions and 12 deletions.
4 changes: 2 additions & 2 deletions magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
// "createRegisteredModel",
"renameRegisteredModel",
"updateRegisteredModel",
// "deleteRegisteredModel",
// "getRegisteredModel",
"deleteRegisteredModel",
"getRegisteredModel",
// "searchRegisteredModels",
"getLatestVersions",
// "createModelVersion",
Expand Down
24 changes: 24 additions & 0 deletions mlflow_go/store/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from mlflow.entities.model_registry import ModelVersion, RegisteredModel
from mlflow.protos.model_registry_pb2 import (
DeleteRegisteredModel,
GetLatestVersions,
GetRegisteredModel,
RenameRegisteredModel,
UpdateRegisteredModel,
)
Expand Down Expand Up @@ -55,6 +57,28 @@ def rename_registered_model(self, name, new_name):
)
return RegisteredModel.from_proto(response.registered_model)

def delete_registered_model(self, name):
request = DeleteRegisteredModel(name=name)
self.service.call_endpoint(get_lib().ModelRegistryServiceDeleteRegisteredModel, request)

def get_registered_model(self, name):
request = GetRegisteredModel(name=name)
response = self.service.call_endpoint(
get_lib().ModelRegistryServiceGetRegisteredModel, request
)

entity = RegisteredModel.from_proto(response.registered_model)
if entity.description == "":
entity.description = None

# during convertion to proto, `version` value became a `string` value.
# convert it back to `int` value again to satisfy all the Python tests and related logic.
for key in entity.aliases:
if entity.aliases[key].isnumeric():
entity.aliases[key] = int(entity.aliases[key])

return entity


def ModelRegistryStore(cls):
return type(cls.__name__, (_ModelRegistryStore, cls), {})
Expand Down
2 changes: 2 additions & 0 deletions pkg/contract/service/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions pkg/entities/model_version.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package entities

import (
"strconv"

"github.com/mlflow/mlflow-go/pkg/protos"
"github.com/mlflow/mlflow-go/pkg/utils"
)

type ModelVersion struct {
Name string
Version int32
CreationTime int64
LastUpdatedTime int64
Description string
UserID string
CurrentStage string
Source string
RunID string
Status string
StatusMessage string
RunLink string
StorageLocation string
}

func (mv ModelVersion) ToProto() *protos.ModelVersion {
return &protos.ModelVersion{
Version: utils.PtrTo(strconv.Itoa(int(mv.Version))),
CurrentStage: utils.PtrTo(mv.CurrentStage),
}
}
10 changes: 10 additions & 0 deletions pkg/entities/registered_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
type RegisteredModel struct {
Name string
Tags []*RegisteredModelTag
Aliases []*RegisteredModelAlias
Versions []*ModelVersion
Description *string
CreationTime int64
LastUpdatedTime int64
Expand All @@ -26,5 +28,13 @@ func (m RegisteredModel) ToProto() *protos.RegisteredModel {
registeredModel.Tags = append(registeredModel.Tags, tag.ToProto())
}

for _, alias := range m.Aliases {
registeredModel.Aliases = append(registeredModel.Aliases, alias.ToProto())
}

for _, version := range m.Versions {
registeredModel.LatestVersions = append(registeredModel.LatestVersions, version.ToProto())
}

return &registeredModel
}
18 changes: 18 additions & 0 deletions pkg/entities/registered_model_alias.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package entities

import (
"github.com/mlflow/mlflow-go/pkg/protos"
"github.com/mlflow/mlflow-go/pkg/utils"
)

type RegisteredModelAlias struct {
Alias string
Version string
}

func (t RegisteredModelAlias) ToProto() *protos.RegisteredModelAlias {
return &protos.RegisteredModelAlias{
Alias: utils.PtrTo(t.Alias),
Version: utils.PtrTo(t.Version),
}
}
16 changes: 16 additions & 0 deletions pkg/lib/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions pkg/model_registry/service/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,26 @@ func (m *ModelRegistryService) RenameRegisteredModel(
RegisteredModel: registeredModel.ToProto(),
}, nil
}

func (m *ModelRegistryService) DeleteRegisteredModel(
ctx context.Context, input *protos.DeleteRegisteredModel,
) (*protos.DeleteRegisteredModel_Response, *contract.Error) {
if err := m.store.DeleteRegisteredModel(ctx, input.GetName()); err != nil {
return nil, err
}

return &protos.DeleteRegisteredModel_Response{}, nil
}

func (m *ModelRegistryService) GetRegisteredModel(
ctx context.Context, input *protos.GetRegisteredModel,
) (*protos.GetRegisteredModel_Response, *contract.Error) {
registeredModel, err := m.store.GetRegisteredModel(ctx, input.GetName())
if err != nil {
return nil, err
}

return &protos.GetRegisteredModel_Response{
RegisteredModel: registeredModel.ToProto(),
}, nil
}
71 changes: 67 additions & 4 deletions pkg/model_registry/store/sql/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,20 @@ func (m *ModelRegistrySQLStore) GetLatestVersions(
return results, nil
}

func (m *ModelRegistrySQLStore) GetRegisteredModelByName(
func (m *ModelRegistrySQLStore) GetRegisteredModel(
ctx context.Context, name string,
) (*entities.RegisteredModel, *contract.Error) {
var registeredModel models.RegisteredModel
if err := m.db.WithContext(
ctx,
).Where(
"name = ?", name,
).Preload(
"Tags",
).Preload(
"Aliases",
).Preload(
"Versions",
).First(
&registeredModel,
).Error; err != nil {
Expand All @@ -128,7 +134,7 @@ func (m *ModelRegistrySQLStore) GetRegisteredModelByName(
func (m *ModelRegistrySQLStore) UpdateRegisteredModel(
ctx context.Context, name, description string,
) (*entities.RegisteredModel, *contract.Error) {
registeredModel, err := m.GetRegisteredModelByName(ctx, name)
registeredModel, err := m.GetRegisteredModel(ctx, name)
if err != nil {
return nil, err
}
Expand All @@ -151,7 +157,7 @@ func (m *ModelRegistrySQLStore) UpdateRegisteredModel(
func (m *ModelRegistrySQLStore) RenameRegisteredModel(
ctx context.Context, name, newName string,
) (*entities.RegisteredModel, *contract.Error) {
registeredModel, err := m.GetRegisteredModelByName(ctx, name)
registeredModel, err := m.GetRegisteredModel(ctx, name)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -192,10 +198,67 @@ func (m *ModelRegistrySQLStore) RenameRegisteredModel(
return nil, contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to rename registered model", err)
}

registeredModel, err = m.GetRegisteredModelByName(ctx, newName)
registeredModel, err = m.GetRegisteredModel(ctx, newName)
if err != nil {
return nil, err
}

return registeredModel, nil
}

func (m *ModelRegistrySQLStore) DeleteRegisteredModel(ctx context.Context, name string) *contract.Error {
registeredModel, err := m.GetRegisteredModel(ctx, name)
if err != nil {
return err
}

if err := m.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error {
if err := transaction.Where(
"name = ?", registeredModel.Name,
).Delete(
models.ModelVersionTag{},
).Error; err != nil {
return err
}

if err := transaction.Where(
"name = ?", registeredModel.Name,
).Delete(
models.ModelVersion{},
).Error; err != nil {
return err
}

if err := transaction.Where(
"name = ?", registeredModel.Name,
).Delete(
models.RegisteredModelTag{},
).Error; err != nil {
return err
}

if err := transaction.Where(
"name = ?", registeredModel.Name,
).Delete(
models.RegisteredModelAlias{},
).Error; err != nil {
return err
}

if err := transaction.Where(
"name = ?", registeredModel.Name,
).Delete(
models.RegisteredModel{},
).Error; err != nil {
return err
}

return nil
}); err != nil {
return contract.NewError(
protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("error deleting registered model: %v", err),
)
}

return nil
}
19 changes: 19 additions & 0 deletions pkg/model_registry/store/sql/models/model_versions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package models

import (
"github.com/mlflow/mlflow-go/pkg/entities"
"github.com/mlflow/mlflow-go/pkg/protos"
"github.com/mlflow/mlflow-go/pkg/utils"
)
Expand Down Expand Up @@ -47,3 +48,21 @@ func (mv ModelVersion) ToProto() *protos.ModelVersion {
RunLink: &mv.RunLink,
}
}

func (mv ModelVersion) ToEntity() *entities.ModelVersion {
return &entities.ModelVersion{
Name: mv.Name,
Version: mv.Version,
CreationTime: mv.CreationTime,
LastUpdatedTime: mv.LastUpdatedTime,
Description: mv.Description,
UserID: mv.UserID,
CurrentStage: mv.CurrentStage.String(),
Source: mv.Source,
RunID: mv.RunID,
Status: mv.Status,
StatusMessage: mv.StatusMessage,
RunLink: mv.RunLink,
StorageLocation: mv.StorageLocation,
}
}
15 changes: 14 additions & 1 deletion pkg/model_registry/store/sql/models/registered_model_aliases.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
package models

import (
"strconv"

"github.com/mlflow/mlflow-go/pkg/entities"
)

// RegisteredModelAlias mapped from table <registered_model_aliases>.
type RegisteredModelAlias struct {
Name string `db:"name" gorm:"column:name;primaryKey"`
Alias string `db:"alias" gorm:"column:alias;primaryKey"`
Version int32 `db:"version" gorm:"column:version;not null"`
Name string `db:"name" gorm:"column:name;primaryKey"`
}

func (a RegisteredModelAlias) ToEntity() *entities.RegisteredModelAlias {
return &entities.RegisteredModelAlias{
Alias: a.Alias,
Version: strconv.Itoa(int(a.Version)),
}
}
33 changes: 28 additions & 5 deletions pkg/model_registry/store/sql/models/registered_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,21 @@ import (

// RegisteredModel mapped from table <registered_models>.
type RegisteredModel struct {
Name string `gorm:"column:name;primaryKey"`
Tags []RegisteredModelTag `gorm:"foreignKey:Name;references:Name"`
Description sql.NullString `gorm:"column:description"`
CreationTime int64 `gorm:"column:creation_time"`
LastUpdatedTime int64 `gorm:"column:last_updated_time"`
Name string `gorm:"column:name;primaryKey"`
Tags []RegisteredModelTag `gorm:"foreignKey:Name;references:Name"`
Aliases []RegisteredModelAlias `gorm:"foreignKey:Name;references:Name"`
Versions []ModelVersion `gorm:"foreignKey:Name;references:Name"`
Description sql.NullString `gorm:"column:description"`
CreationTime int64 `gorm:"column:creation_time"`
LastUpdatedTime int64 `gorm:"column:last_updated_time"`
}

func (m *RegisteredModel) ToEntity() *entities.RegisteredModel {
model := entities.RegisteredModel{
Name: m.Name,
Tags: make([]*entities.RegisteredModelTag, 0, len(m.Tags)),
Aliases: make([]*entities.RegisteredModelAlias, 0, len(m.Aliases)),
Versions: make([]*entities.ModelVersion, 0),
CreationTime: m.CreationTime,
LastUpdatedTime: m.LastUpdatedTime,
}
Expand All @@ -31,5 +35,24 @@ func (m *RegisteredModel) ToEntity() *entities.RegisteredModel {
model.Tags = append(model.Tags, tag.ToEntity())
}

for _, alias := range m.Aliases {
model.Aliases = append(model.Aliases, alias.ToEntity())
}

latestVersionsByStage := map[string]*ModelVersion{}

for _, currentVersion := range m.Versions {
stage := currentVersion.CurrentStage.String()
if stage != StageDeletedInternal {
if latestVersion, ok := latestVersionsByStage[stage]; !ok || latestVersion.Version < currentVersion.Version {
latestVersionsByStage[stage] = &currentVersion
}
}
}

for _, latestVersion := range latestVersionsByStage {
model.Versions = append(model.Versions, latestVersion.ToEntity())
}

return &model
}
Loading

0 comments on commit a4c8565

Please sign in to comment.