diff --git a/pkg/core/artifact.go b/pkg/core/artifact.go new file mode 100644 index 00000000..f35050dd --- /dev/null +++ b/pkg/core/artifact.go @@ -0,0 +1,332 @@ +package core + +import ( + "context" + "fmt" + + "github.com/golang/glog" + "github.com/kubeflow/model-registry/internal/apiutils" + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// ARTIFACTS + +// UpsertArtifact creates a new artifact if the provided artifact's ID is nil, or updates an existing artifact if the +// ID is provided. +// A model version ID must be provided to disambiguate between artifacts. +// Upon creation, new artifacts will be associated with their corresponding model version. +func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, modelVersionId *string) (*openapi.Artifact, error) { + if artifact == nil { + return nil, fmt.Errorf("invalid artifact pointer, can't upsert nil") + } + creating := false + if ma := artifact.ModelArtifact; ma != nil { + if ma.Id == nil { + creating = true + glog.Info("Creating model artifact") + if modelVersionId == nil { + return nil, fmt.Errorf("missing model version id, cannot create artifact without model version: %w", api.ErrBadRequest) + } + _, err := serv.GetModelVersionById(*modelVersionId) + if err != nil { + return nil, fmt.Errorf("no model version found for id %s: %w", *modelVersionId, api.ErrNotFound) + } + } else { + glog.Info("Updating model artifact") + existing, err := serv.GetModelArtifactById(*ma.Id) + if err != nil { + return nil, err + } + + withNotEditable, err := serv.openapiConv.OverrideNotEditableForModelArtifact(converter.NewOpenapiUpdateWrapper(existing, ma)) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + ma = &withNotEditable + + _, err = serv.getModelVersionByArtifactId(*ma.Id) + if err != nil { + return nil, err + } + } + } else if da := artifact.DocArtifact; da != nil { + if da.Id == nil { + creating = true + glog.Info("Creating doc artifact") + if modelVersionId == nil { + return nil, fmt.Errorf("missing model version id, cannot create artifact without model version: %w", api.ErrBadRequest) + } + _, err := serv.GetModelVersionById(*modelVersionId) + if err != nil { + return nil, fmt.Errorf("no model version found for id %s: %w", *modelVersionId, api.ErrNotFound) + } + } else { + glog.Info("Updating doc artifact") + existing, err := serv.GetArtifactById(*da.Id) + if err != nil { + return nil, err + } + if existing.DocArtifact == nil { + return nil, fmt.Errorf("mismatched types, artifact with id %s is not a doc artifact: %w", *da.Id, api.ErrBadRequest) + } + + withNotEditable, err := serv.openapiConv.OverrideNotEditableForDocArtifact(converter.NewOpenapiUpdateWrapper(existing.DocArtifact, da)) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + da = &withNotEditable + + _, err = serv.getModelVersionByArtifactId(*da.Id) + if err != nil { + return nil, err + } + } + } else { + return nil, fmt.Errorf("invalid artifact type, must be either ModelArtifact or DocArtifact: %w", api.ErrBadRequest) + } + pa, err := serv.mapper.MapFromArtifact(artifact, modelVersionId) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + artifactsResp, err := serv.mlmdClient.PutArtifacts(context.Background(), &proto.PutArtifactsRequest{ + Artifacts: []*proto.Artifact{pa}, + }) + if err != nil { + return nil, err + } + + if creating { + // add explicit Attribution between Artifact and ModelVersion + modelVersionId, err := converter.StringToInt64(modelVersionId) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + attributions := []*proto.Attribution{} + for _, a := range artifactsResp.ArtifactIds { + attributions = append(attributions, &proto.Attribution{ + ContextId: modelVersionId, + ArtifactId: &a, + }) + } + _, err = serv.mlmdClient.PutAttributionsAndAssociations(context.Background(), &proto.PutAttributionsAndAssociationsRequest{ + Attributions: attributions, + Associations: make([]*proto.Association, 0), + }) + if err != nil { + return nil, err + } + } + + idAsString := converter.Int64ToString(&artifactsResp.ArtifactIds[0]) + return serv.GetArtifactById(*idAsString) +} + +func (serv *ModelRegistryService) GetArtifactById(id string) (*openapi.Artifact, error) { + idAsInt, err := converter.StringToInt64(&id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + artifactsResp, err := serv.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{ + ArtifactIds: []int64{int64(*idAsInt)}, + }) + if err != nil { + return nil, err + } + if len(artifactsResp.Artifacts) > 1 { + return nil, fmt.Errorf("multiple artifacts found for id %s: %w", id, api.ErrNotFound) + } + if len(artifactsResp.Artifacts) == 0 { + return nil, fmt.Errorf("no artifact found for id %s: %w", id, api.ErrNotFound) + } + return serv.mapper.MapToArtifact(artifactsResp.Artifacts[0]) +} + +func (serv *ModelRegistryService) GetArtifacts(listOptions api.ListOptions, modelVersionId *string) (*openapi.ArtifactList, error) { + listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + var artifacts []*proto.Artifact + var nextPageToken *string + if modelVersionId == nil { + return nil, fmt.Errorf("missing model version id, cannot get artifacts without model version: %w", api.ErrBadRequest) + } + ctxId, err := converter.StringToInt64(modelVersionId) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{ + ContextId: ctxId, + Options: listOperationOptions, + }) + if err != nil { + return nil, err + } + artifacts = artifactsResp.Artifacts + nextPageToken = artifactsResp.NextPageToken + + results := []openapi.Artifact{} + for _, a := range artifacts { + mapped, err := serv.mapper.MapToArtifact(a) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + results = append(results, *mapped) + } + + toReturn := openapi.ArtifactList{ + NextPageToken: apiutils.ZeroIfNil(nextPageToken), + PageSize: apiutils.ZeroIfNil(listOptions.PageSize), + Size: int32(len(results)), + Items: results, + } + return &toReturn, nil +} + +// MODEL ARTIFACTS + +// UpsertModelArtifact creates a new model artifact if the provided model artifact's ID is nil, +// or updates an existing model artifact if the ID is provided. +// If a model version ID is provided and the model artifact is newly created, establishes an +// explicit attribution between the model version and the created model artifact. +func (serv *ModelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, modelVersionId *string) (*openapi.ModelArtifact, error) { + art, err := serv.UpsertArtifact(&openapi.Artifact{ + ModelArtifact: modelArtifact, + }, modelVersionId) + if err != nil { + return nil, err + } + return art.ModelArtifact, err +} + +// GetModelArtifactById retrieves a model artifact by its unique identifier (ID). +func (serv *ModelRegistryService) GetModelArtifactById(id string) (*openapi.ModelArtifact, error) { + art, err := serv.GetArtifactById(id) + if err != nil { + return nil, err + } + ma := art.ModelArtifact + if ma == nil { + return nil, fmt.Errorf("artifact with id %s is not a model artifact: %w", id, api.ErrNotFound) + } + return ma, err +} + +// GetModelArtifactByInferenceService retrieves the model artifact associated with the specified inference service ID. +func (serv *ModelRegistryService) GetModelArtifactByInferenceService(inferenceServiceId string) (*openapi.ModelArtifact, error) { + mv, err := serv.GetModelVersionByInferenceService(inferenceServiceId) + if err != nil { + return nil, err + } + + artifactList, err := serv.GetModelArtifacts(api.ListOptions{}, mv.Id) + if err != nil { + return nil, err + } + + if artifactList.Size == 0 { + return nil, fmt.Errorf("no artifacts found for model version %s: %w", *mv.Id, api.ErrNotFound) + } + + return &artifactList.Items[0], nil +} + +// GetModelArtifactByParams retrieves a model artifact based on specified parameters, such as (artifact name and model version ID), or external ID. +// If multiple or no model artifacts are found, an error is returned. +func (serv *ModelRegistryService) GetModelArtifactByParams(artifactName *string, modelVersionId *string, externalId *string) (*openapi.ModelArtifact, error) { + var artifact0 *proto.Artifact + + filterQuery := "" + if externalId != nil { + filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) + } else if artifactName != nil && modelVersionId != nil { + filterQuery = fmt.Sprintf("name = \"%s\"", converter.PrefixWhenOwned(modelVersionId, *artifactName)) + } else { + return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and modelVersionId), or externalId: %w", api.ErrBadRequest) + } + glog.Info("filterQuery ", filterQuery) + + artifactsResponse, err := serv.mlmdClient.GetArtifactsByType(context.Background(), &proto.GetArtifactsByTypeRequest{ + TypeName: &serv.nameConfig.ModelArtifactTypeName, + Options: &proto.ListOperationOptions{ + FilterQuery: &filterQuery, + }, + }) + if err != nil { + return nil, err + } + + if len(artifactsResponse.Artifacts) > 1 { + return nil, fmt.Errorf("multiple model artifacts found for artifactName=%v, modelVersionId=%v, externalId=%v: %w", apiutils.ZeroIfNil(artifactName), apiutils.ZeroIfNil(modelVersionId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + if len(artifactsResponse.Artifacts) == 0 { + return nil, fmt.Errorf("no model artifacts found for artifactName=%v, modelVersionId=%v, externalId=%v: %w", apiutils.ZeroIfNil(artifactName), apiutils.ZeroIfNil(modelVersionId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + artifact0 = artifactsResponse.Artifacts[0] + + result, err := serv.mapper.MapToModelArtifact(artifact0) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return result, nil +} + +// GetModelArtifacts retrieves a list of model artifacts based on the provided list options and optional model version ID. +func (serv *ModelRegistryService) GetModelArtifacts(listOptions api.ListOptions, modelVersionId *string) (*openapi.ModelArtifactList, error) { + listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + var artifacts []*proto.Artifact + var nextPageToken *string + if modelVersionId != nil { + ctxId, err := converter.StringToInt64(modelVersionId) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{ + ContextId: ctxId, + Options: listOperationOptions, + }) + if err != nil { + return nil, err + } + artifacts = artifactsResp.Artifacts + nextPageToken = artifactsResp.NextPageToken + } else { + artifactsResp, err := serv.mlmdClient.GetArtifactsByType(context.Background(), &proto.GetArtifactsByTypeRequest{ + TypeName: &serv.nameConfig.ModelArtifactTypeName, + Options: listOperationOptions, + }) + if err != nil { + return nil, err + } + artifacts = artifactsResp.Artifacts + nextPageToken = artifactsResp.NextPageToken + } + + results := []openapi.ModelArtifact{} + for _, a := range artifacts { + mapped, err := serv.mapper.MapToModelArtifact(a) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + results = append(results, *mapped) + } + + toReturn := openapi.ModelArtifactList{ + NextPageToken: apiutils.ZeroIfNil(nextPageToken), + PageSize: apiutils.ZeroIfNil(listOptions.PageSize), + Size: int32(len(results)), + Items: results, + } + return &toReturn, nil +} diff --git a/pkg/core/artifact_test.go b/pkg/core/artifact_test.go new file mode 100644 index 00000000..e68e1962 --- /dev/null +++ b/pkg/core/artifact_test.go @@ -0,0 +1,572 @@ +package core + +import ( + "context" + "fmt" + + "github.com/kubeflow/model-registry/internal/apiutils" + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// ARTIFACTS + +func (suite *CoreTestSuite) TestCreateArtifact() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + createdArt, err := service.UpsertArtifact(&openapi.Artifact{ + DocArtifact: &openapi.DocArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + Description: &artifactDescription, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + }, + }, &modelVersionId) + suite.Nilf(err, "error creating new artifact for %d: %v", modelVersionId, err) + + docArtifact := createdArt.DocArtifact + suite.NotNilf(docArtifact, "error creating new artifact for %d", modelVersionId) + state, _ := openapi.NewArtifactStateFromValue(artifactState) + suite.NotNil(docArtifact.Id, "created artifact id should not be nil") + suite.Equal(artifactName, *docArtifact.Name) + suite.Equal(*state, *docArtifact.State) + suite.Equal(artifactUri, *docArtifact.Uri) + suite.Equal(artifactDescription, *docArtifact.Description) + suite.Equal(customString, (*docArtifact.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) +} + +func (suite *CoreTestSuite) TestCreateArtifactFailure() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := "9998" + + var artifact openapi.Artifact + artifact.DocArtifact = &openapi.DocArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + _, err := service.UpsertArtifact(&artifact, nil) + suite.NotNil(err) + suite.Equal("missing model version id, cannot create artifact without model version: bad request", err.Error()) + + _, err = service.UpsertArtifact(&artifact, &modelVersionId) + suite.NotNil(err) + suite.Equal("no model version found for id 9998: not found", err.Error()) +} + +func (suite *CoreTestSuite) TestUpdateArtifact() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + createdArtifact, err := service.UpsertArtifact(&openapi.Artifact{ + DocArtifact: &openapi.DocArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + }, + }, &modelVersionId) + suite.Nilf(err, "error creating new artifact for %d", modelVersionId) + + newState := "MARKED_FOR_DELETION" + createdArtifact.DocArtifact.State = (*openapi.ArtifactState)(&newState) + updatedArtifact, err := service.UpsertArtifact(createdArtifact, &modelVersionId) + suite.Nilf(err, "error updating artifact for %d: %v", modelVersionId, err) + + createdArtifactId, _ := converter.StringToInt64(createdArtifact.DocArtifact.Id) + updatedArtifactId, _ := converter.StringToInt64(updatedArtifact.DocArtifact.Id) + suite.Equal(createdArtifactId, updatedArtifactId) + + getById, err := suite.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{ + ArtifactIds: []int64{*createdArtifactId}, + }) + suite.Nilf(err, "error getting artifact by id %d", createdArtifactId) + + suite.Equal(*createdArtifactId, *getById.Artifacts[0].Id) + suite.Equal(fmt.Sprintf("%s:%s", modelVersionId, *createdArtifact.DocArtifact.Name), *getById.Artifacts[0].Name) + suite.Equal(string(newState), getById.Artifacts[0].State.String()) + suite.Equal(*createdArtifact.DocArtifact.Uri, *getById.Artifacts[0].Uri) + suite.Equal((*createdArtifact.DocArtifact.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, getById.Artifacts[0].CustomProperties["custom_string_prop"].GetStringValue()) +} + +func (suite *CoreTestSuite) TestUpdateArtifactFailure() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + createdArtifact, err := service.UpsertArtifact(&openapi.Artifact{ + DocArtifact: &openapi.DocArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + }, + }, &modelVersionId) + suite.Nilf(err, "error creating new artifact for model version %s", modelVersionId) + suite.NotNilf(createdArtifact.DocArtifact.Id, "created model artifact should not have nil Id") + + newState := "MARKED_FOR_DELETION" + createdArtifact.DocArtifact.State = (*openapi.ArtifactState)(&newState) + updatedArtifact, err := service.UpsertArtifact(createdArtifact, &modelVersionId) + suite.Nilf(err, "error updating artifact for %d: %v", modelVersionId, err) + + wrongId := "5555" + updatedArtifact.DocArtifact.Id = &wrongId + _, err = service.UpsertArtifact(updatedArtifact, &modelVersionId) + suite.NotNil(err) + suite.Equal(fmt.Sprintf("no artifact found for id %s: not found", wrongId), err.Error()) +} + +func (suite *CoreTestSuite) TestGetArtifactById() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + createdArtifact, err := service.UpsertArtifact(&openapi.Artifact{ + DocArtifact: &openapi.DocArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + }, + }, &modelVersionId) + suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) + + createdArtifactId, _ := converter.StringToInt64(createdArtifact.DocArtifact.Id) + + getById, err := service.GetArtifactById(*createdArtifact.DocArtifact.Id) + suite.Nilf(err, "error getting artifact by id %d", createdArtifactId) + + state, _ := openapi.NewArtifactStateFromValue(artifactState) + suite.NotNil(createdArtifact.DocArtifact.Id, "created artifact id should not be nil") + suite.Equal(artifactName, *getById.DocArtifact.Name) + suite.Equal(*state, *getById.DocArtifact.State) + suite.Equal(artifactUri, *getById.DocArtifact.Uri) + suite.Equal(customString, (*getById.DocArtifact.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) + + suite.Equal(*createdArtifact, *getById, "artifacts returned during creation and on get by id should be equal") +} + +func (suite *CoreTestSuite) TestGetArtifacts() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + secondArtifactName := "second-name" + secondArtifactExtId := "second-ext-id" + secondArtifactUri := "second-uri" + + createdArtifact1, err := service.UpsertArtifact(&openapi.Artifact{ + ModelArtifact: &openapi.ModelArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + ExternalId: &artifactExtId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + }, + }, &modelVersionId) + suite.Nilf(err, "error creating new artifact for %d", modelVersionId) + createdArtifact2, err := service.UpsertArtifact(&openapi.Artifact{ + DocArtifact: &openapi.DocArtifact{ + Name: &secondArtifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &secondArtifactUri, + ExternalId: &secondArtifactExtId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + }, + }, &modelVersionId) + suite.Nilf(err, "error creating new artifact for %d", modelVersionId) + + createdArtifactId1, _ := converter.StringToInt64(createdArtifact1.ModelArtifact.Id) + createdArtifactId2, _ := converter.StringToInt64(createdArtifact2.DocArtifact.Id) + + getAll, err := service.GetArtifacts(api.ListOptions{}, &modelVersionId) + suite.Nilf(err, "error getting all model artifacts") + suite.Equalf(int32(2), getAll.Size, "expected two artifacts") + + suite.Equal(*converter.Int64ToString(createdArtifactId1), *getAll.Items[0].ModelArtifact.Id) + suite.Equal(*converter.Int64ToString(createdArtifactId2), *getAll.Items[1].DocArtifact.Id) + + orderByLastUpdate := "LAST_UPDATE_TIME" + getAllByModelVersion, err := service.GetArtifacts(api.ListOptions{ + OrderBy: &orderByLastUpdate, + SortOrder: &descOrderDirection, + }, &modelVersionId) + suite.Nilf(err, "error getting all model artifacts for %d", modelVersionId) + suite.Equalf(int32(2), getAllByModelVersion.Size, "expected 2 artifacts for model version %d", modelVersionId) + + suite.Equal(*converter.Int64ToString(createdArtifactId1), *getAllByModelVersion.Items[1].ModelArtifact.Id) + suite.Equal(*converter.Int64ToString(createdArtifactId2), *getAllByModelVersion.Items[0].DocArtifact.Id) +} + +// MODEL ARTIFACTS + +func (suite *CoreTestSuite) TestCreateModelArtifact() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + modelArtifact, err := service.UpsertModelArtifact(&openapi.ModelArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + Description: &artifactDescription, + ModelFormatName: apiutils.Of("onnx"), + ModelFormatVersion: apiutils.Of("1"), + StorageKey: apiutils.Of("aws-connection-models"), + StoragePath: apiutils.Of("bucket"), + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + }, &modelVersionId) + suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) + + state, _ := openapi.NewArtifactStateFromValue(artifactState) + suite.NotNil(modelArtifact.Id, "created artifact id should not be nil") + suite.Equal(artifactName, *modelArtifact.Name) + suite.Equal(*state, *modelArtifact.State) + suite.Equal(artifactUri, *modelArtifact.Uri) + suite.Equal(artifactDescription, *modelArtifact.Description) + suite.Equal("onnx", *modelArtifact.ModelFormatName) + suite.Equal("1", *modelArtifact.ModelFormatVersion) + suite.Equal("aws-connection-models", *modelArtifact.StorageKey) + suite.Equal("bucket", *modelArtifact.StoragePath) + suite.Equal(customString, (*modelArtifact.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) +} + +func (suite *CoreTestSuite) TestCreateModelArtifactFailure() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := "9998" + + modelArtifact := &openapi.ModelArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + _, err := service.UpsertModelArtifact(modelArtifact, nil) + suite.NotNil(err) + suite.Equal("missing model version id, cannot create artifact without model version: bad request", err.Error()) + + _, err = service.UpsertModelArtifact(modelArtifact, &modelVersionId) + suite.NotNil(err) + suite.Equal("no model version found for id 9998: not found", err.Error()) +} + +func (suite *CoreTestSuite) TestUpdateModelArtifact() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + modelArtifact := &openapi.ModelArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdArtifact, err := service.UpsertModelArtifact(modelArtifact, &modelVersionId) + suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) + + newState := "MARKED_FOR_DELETION" + createdArtifact.State = (*openapi.ArtifactState)(&newState) + updatedArtifact, err := service.UpsertModelArtifact(createdArtifact, &modelVersionId) + suite.Nilf(err, "error updating model artifact for %d: %v", modelVersionId, err) + + createdArtifactId, _ := converter.StringToInt64(createdArtifact.Id) + updatedArtifactId, _ := converter.StringToInt64(updatedArtifact.Id) + suite.Equal(createdArtifactId, updatedArtifactId) + + getById, err := suite.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{ + ArtifactIds: []int64{*createdArtifactId}, + }) + suite.Nilf(err, "error getting model artifact by id %d", createdArtifactId) + + suite.Equal(*createdArtifactId, *getById.Artifacts[0].Id) + suite.Equal(fmt.Sprintf("%s:%s", modelVersionId, *createdArtifact.Name), *getById.Artifacts[0].Name) + suite.Equal(string(newState), getById.Artifacts[0].State.String()) + suite.Equal(*createdArtifact.Uri, *getById.Artifacts[0].Uri) + suite.Equal((*createdArtifact.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, getById.Artifacts[0].CustomProperties["custom_string_prop"].GetStringValue()) +} + +func (suite *CoreTestSuite) TestUpdateModelArtifactFailure() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + modelArtifact := &openapi.ModelArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdArtifact, err := service.UpsertModelArtifact(modelArtifact, &modelVersionId) + suite.Nilf(err, "error creating new model artifact for model version %s", modelVersionId) + suite.NotNilf(createdArtifact.Id, "created model artifact should not have nil Id") +} + +func (suite *CoreTestSuite) TestGetModelArtifactById() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + modelArtifact := &openapi.ModelArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdArtifact, err := service.UpsertModelArtifact(modelArtifact, &modelVersionId) + suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) + + createdArtifactId, _ := converter.StringToInt64(createdArtifact.Id) + + getById, err := service.GetModelArtifactById(*createdArtifact.Id) + suite.Nilf(err, "error getting model artifact by id %d", createdArtifactId) + + state, _ := openapi.NewArtifactStateFromValue(artifactState) + suite.NotNil(createdArtifact.Id, "created artifact id should not be nil") + suite.Equal(artifactName, *getById.Name) + suite.Equal(*state, *getById.State) + suite.Equal(artifactUri, *getById.Uri) + suite.Equal(customString, (*getById.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) + + suite.Equal(*createdArtifact, *getById, "artifacts returned during creation and on get by id should be equal") +} + +func (suite *CoreTestSuite) TestGetModelArtifactByParams() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + modelArtifact := &openapi.ModelArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + ExternalId: &artifactExtId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdArtifact, err := service.UpsertModelArtifact(modelArtifact, &modelVersionId) + suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) + + createdArtifactId, _ := converter.StringToInt64(createdArtifact.Id) + + state, _ := openapi.NewArtifactStateFromValue(artifactState) + + getByName, err := service.GetModelArtifactByParams(&artifactName, &modelVersionId, nil) + suite.Nilf(err, "error getting model artifact by id %d", createdArtifactId) + + suite.NotNil(createdArtifact.Id, "created artifact id should not be nil") + suite.Equal(artifactName, *getByName.Name) + suite.Equal(artifactExtId, *getByName.ExternalId) + suite.Equal(*state, *getByName.State) + suite.Equal(artifactUri, *getByName.Uri) + suite.Equal(customString, (*getByName.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) + + suite.Equal(*createdArtifact, *getByName, "artifacts returned during creation and on get by name should be equal") + + getByExtId, err := service.GetModelArtifactByParams(nil, nil, &artifactExtId) + suite.Nilf(err, "error getting model artifact by id %d", createdArtifactId) + + suite.NotNil(createdArtifact.Id, "created artifact id should not be nil") + suite.Equal(artifactName, *getByExtId.Name) + suite.Equal(artifactExtId, *getByExtId.ExternalId) + suite.Equal(*state, *getByExtId.State) + suite.Equal(artifactUri, *getByExtId.Uri) + suite.Equal(customString, (*getByExtId.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) + + suite.Equal(*createdArtifact, *getByExtId, "artifacts returned during creation and on get by ext id should be equal") +} + +func (suite *CoreTestSuite) TestGetModelArtifactByEmptyParams() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + modelArtifact := &openapi.ModelArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + ExternalId: &artifactExtId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + _, err := service.UpsertModelArtifact(modelArtifact, &modelVersionId) + suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) + + _, err = service.GetModelArtifactByParams(nil, nil, nil) + suite.NotNil(err) + suite.Equal("invalid parameters call, supply either (artifactName and modelVersionId), or externalId: bad request", err.Error()) +} + +func (suite *CoreTestSuite) TestGetModelArtifactByParamsWithNoResults() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + _, err := service.GetModelArtifactByParams(apiutils.Of("not-present"), &modelVersionId, nil) + suite.NotNil(err) + suite.Equal("no model artifacts found for artifactName=not-present, modelVersionId=2, externalId=: not found", err.Error()) +} + +func (suite *CoreTestSuite) TestGetModelArtifacts() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + modelArtifact1 := &openapi.ModelArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + ExternalId: &artifactExtId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + secondArtifactName := "second-name" + secondArtifactExtId := "second-ext-id" + secondArtifactUri := "second-uri" + modelArtifact2 := &openapi.ModelArtifact{ + Name: &secondArtifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &secondArtifactUri, + ExternalId: &secondArtifactExtId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + thirdArtifactName := "third-name" + thirdArtifactExtId := "third-ext-id" + thirdArtifactUri := "third-uri" + modelArtifact3 := &openapi.ModelArtifact{ + Name: &thirdArtifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &thirdArtifactUri, + ExternalId: &thirdArtifactExtId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdArtifact1, err := service.UpsertModelArtifact(modelArtifact1, &modelVersionId) + suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) + createdArtifact2, err := service.UpsertModelArtifact(modelArtifact2, &modelVersionId) + suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) + createdArtifact3, err := service.UpsertModelArtifact(modelArtifact3, &modelVersionId) + suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) + + createdArtifactId1, _ := converter.StringToInt64(createdArtifact1.Id) + createdArtifactId2, _ := converter.StringToInt64(createdArtifact2.Id) + createdArtifactId3, _ := converter.StringToInt64(createdArtifact3.Id) + + getAll, err := service.GetModelArtifacts(api.ListOptions{}, nil) + suite.Nilf(err, "error getting all model artifacts") + suite.Equalf(int32(3), getAll.Size, "expected three model artifacts") + + suite.Equal(*converter.Int64ToString(createdArtifactId1), *getAll.Items[0].Id) + suite.Equal(*converter.Int64ToString(createdArtifactId2), *getAll.Items[1].Id) + suite.Equal(*converter.Int64ToString(createdArtifactId3), *getAll.Items[2].Id) + + orderByLastUpdate := "LAST_UPDATE_TIME" + getAllByModelVersion, err := service.GetModelArtifacts(api.ListOptions{ + OrderBy: &orderByLastUpdate, + SortOrder: &descOrderDirection, + }, &modelVersionId) + suite.Nilf(err, "error getting all model artifacts for %d", modelVersionId) + suite.Equalf(int32(3), getAllByModelVersion.Size, "expected three model artifacts for model version %d", modelVersionId) + + suite.Equal(*converter.Int64ToString(createdArtifactId1), *getAllByModelVersion.Items[2].Id) + suite.Equal(*converter.Int64ToString(createdArtifactId2), *getAllByModelVersion.Items[1].Id) + suite.Equal(*converter.Int64ToString(createdArtifactId3), *getAllByModelVersion.Items[0].Id) +} diff --git a/pkg/core/core.go b/pkg/core/core.go index fae55169..8d13051e 100644 --- a/pkg/core/core.go +++ b/pkg/core/core.go @@ -3,17 +3,12 @@ package core import ( "context" "fmt" - "strings" - "github.com/golang/glog" - "github.com/kubeflow/model-registry/internal/apiutils" - "github.com/kubeflow/model-registry/internal/converter" "github.com/kubeflow/model-registry/internal/converter/generated" "github.com/kubeflow/model-registry/internal/mapper" "github.com/kubeflow/model-registry/internal/ml_metadata/proto" "github.com/kubeflow/model-registry/internal/mlmdtypes" "github.com/kubeflow/model-registry/pkg/api" - "github.com/kubeflow/model-registry/pkg/openapi" "google.golang.org/grpc" ) @@ -111,1350 +106,3 @@ func BuildTypesMap(cc grpc.ClientConnInterface, nameConfig mlmdtypes.MLMDTypeNam } return typesMap, nil } - -// REGISTERED MODELS - -// UpsertRegisteredModel creates a new registered model if the given registered model's ID is nil, -// or updates an existing registered model if the ID is provided. -func (serv *ModelRegistryService) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) { - var err error - var existing *openapi.RegisteredModel - - if registeredModel.Id == nil { - glog.Info("Creating new registered model") - } else { - glog.Infof("Updating registered model %s", *registeredModel.Id) - existing, err = serv.GetRegisteredModelById(*registeredModel.Id) - if err != nil { - return nil, err - } - - withNotEditable, err := serv.openapiConv.OverrideNotEditableForRegisteredModel(converter.NewOpenapiUpdateWrapper(existing, registeredModel)) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - registeredModel = &withNotEditable - } - - modelCtx, err := serv.mapper.MapFromRegisteredModel(registeredModel) - if err != nil { - return nil, err - } - - modelCtxResp, err := serv.mlmdClient.PutContexts(context.Background(), &proto.PutContextsRequest{ - Contexts: []*proto.Context{ - modelCtx, - }, - }) - if err != nil { - return nil, err - } - - idAsString := converter.Int64ToString(&modelCtxResp.ContextIds[0]) - model, err := serv.GetRegisteredModelById(*idAsString) - if err != nil { - return nil, err - } - - return model, nil -} - -// GetRegisteredModelById retrieves a registered model by its unique identifier (ID). -func (serv *ModelRegistryService) GetRegisteredModelById(id string) (*openapi.RegisteredModel, error) { - glog.Infof("Getting registered model %s", id) - - idAsInt, err := converter.StringToInt64(&id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{int64(*idAsInt)}, - }) - if err != nil { - return nil, err - } - - if len(getByIdResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple registered models found for id %s: %w", id, api.ErrNotFound) - } - - if len(getByIdResp.Contexts) == 0 { - return nil, fmt.Errorf("no registered model found for id %s: %w", id, api.ErrNotFound) - } - - regModel, err := serv.mapper.MapToRegisteredModel(getByIdResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - return regModel, nil -} - -// GetRegisteredModelByInferenceService retrieves a registered model associated with the specified inference service ID. -func (serv *ModelRegistryService) GetRegisteredModelByInferenceService(inferenceServiceId string) (*openapi.RegisteredModel, error) { - is, err := serv.GetInferenceServiceById(inferenceServiceId) - if err != nil { - return nil, err - } - return serv.GetRegisteredModelById(is.RegisteredModelId) -} - -// getRegisteredModelByVersionId retrieves a registered model associated with the specified model version ID. -func (serv *ModelRegistryService) getRegisteredModelByVersionId(id string) (*openapi.RegisteredModel, error) { - glog.Infof("Getting registered model for model version %s", id) - - idAsInt, err := converter.StringToInt64(&id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - getParentResp, err := serv.mlmdClient.GetParentContextsByContext(context.Background(), &proto.GetParentContextsByContextRequest{ - ContextId: idAsInt, - }) - if err != nil { - return nil, err - } - - if len(getParentResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple registered models found for model version %s: %w", id, api.ErrNotFound) - } - - if len(getParentResp.Contexts) == 0 { - return nil, fmt.Errorf("no registered model found for model version %s: %w", id, api.ErrNotFound) - } - - regModel, err := serv.mapper.MapToRegisteredModel(getParentResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - return regModel, nil -} - -// GetRegisteredModelByParams retrieves a registered model based on specified parameters, such as name or external ID. -// If multiple or no registered models are found, an error is returned accordingly. -func (serv *ModelRegistryService) GetRegisteredModelByParams(name *string, externalId *string) (*openapi.RegisteredModel, error) { - glog.Infof("Getting registered model by params name=%v, externalId=%v", name, externalId) - - filterQuery := "" - if name != nil { - filterQuery = fmt.Sprintf("name = \"%s\"", *name) - } else if externalId != nil { - filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) - } else { - return nil, fmt.Errorf("invalid parameters call, supply either name or externalId: %w", api.ErrBadRequest) - } - glog.Info("filterQuery ", filterQuery) - - getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: &serv.nameConfig.RegisteredModelTypeName, - Options: &proto.ListOperationOptions{ - FilterQuery: &filterQuery, - }, - }) - if err != nil { - return nil, err - } - - if len(getByParamsResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple registered models found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) - } - - if len(getByParamsResp.Contexts) == 0 { - return nil, fmt.Errorf("no registered models found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) - } - - regModel, err := serv.mapper.MapToRegisteredModel(getByParamsResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - return regModel, nil -} - -// GetRegisteredModels retrieves a list of registered models based on the provided list options. -func (serv *ModelRegistryService) GetRegisteredModels(listOptions api.ListOptions) (*openapi.RegisteredModelList, error) { - listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) - if err != nil { - return nil, err - } - contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: &serv.nameConfig.RegisteredModelTypeName, - Options: listOperationOptions, - }) - if err != nil { - return nil, err - } - - results := []openapi.RegisteredModel{} - for _, c := range contextsResp.Contexts { - mapped, err := serv.mapper.MapToRegisteredModel(c) - if err != nil { - return nil, err - } - results = append(results, *mapped) - } - - toReturn := openapi.RegisteredModelList{ - NextPageToken: apiutils.ZeroIfNil(contextsResp.NextPageToken), - PageSize: apiutils.ZeroIfNil(listOptions.PageSize), - Size: int32(len(results)), - Items: results, - } - return &toReturn, nil -} - -// MODEL VERSIONS - -// UpsertModelVersion creates a new model version if the provided model version's ID is nil, -// or updates an existing model version if the ID is provided. -func (serv *ModelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, registeredModelId *string) (*openapi.ModelVersion, error) { - var err error - var existing *openapi.ModelVersion - var registeredModel *openapi.RegisteredModel - - if modelVersion.Id == nil { - // create - glog.Info("Creating new model version") - if registeredModelId == nil { - return nil, fmt.Errorf("missing registered model id, cannot create model version without registered model: %w", api.ErrBadRequest) - } - registeredModel, err = serv.GetRegisteredModelById(*registeredModelId) - if err != nil { - return nil, err - } - } else { - // update - glog.Infof("Updating model version %s", *modelVersion.Id) - existing, err = serv.GetModelVersionById(*modelVersion.Id) - if err != nil { - return nil, err - } - - withNotEditable, err := serv.openapiConv.OverrideNotEditableForModelVersion(converter.NewOpenapiUpdateWrapper(existing, modelVersion)) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - modelVersion = &withNotEditable - - registeredModel, err = serv.getRegisteredModelByVersionId(*modelVersion.Id) - if err != nil { - return nil, err - } - } - - modelCtx, err := serv.mapper.MapFromModelVersion(modelVersion, *registeredModel.Id, registeredModel.Name) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - modelCtxResp, err := serv.mlmdClient.PutContexts(context.Background(), &proto.PutContextsRequest{ - Contexts: []*proto.Context{ - modelCtx, - }, - }) - if err != nil { - return nil, err - } - - modelId := &modelCtxResp.ContextIds[0] - if modelVersion.Id == nil { - registeredModelId, err := converter.StringToInt64(registeredModel.Id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - _, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{ - ParentContexts: []*proto.ParentContext{{ - ChildId: modelId, - ParentId: registeredModelId, - }}, - TransactionOptions: &proto.TransactionOptions{}, - }) - if err != nil { - return nil, err - } - } - - idAsString := converter.Int64ToString(modelId) - model, err := serv.GetModelVersionById(*idAsString) - if err != nil { - return nil, err - } - - return model, nil -} - -// GetModelVersionById retrieves a model version by its unique identifier (ID). -func (serv *ModelRegistryService) GetModelVersionById(id string) (*openapi.ModelVersion, error) { - idAsInt, err := converter.StringToInt64(&id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{int64(*idAsInt)}, - }) - if err != nil { - return nil, err - } - - if len(getByIdResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple model versions found for id %s: %w", id, api.ErrNotFound) - } - - if len(getByIdResp.Contexts) == 0 { - return nil, fmt.Errorf("no model version found for id %s: %w", id, api.ErrNotFound) - } - - modelVer, err := serv.mapper.MapToModelVersion(getByIdResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - return modelVer, nil -} - -// GetModelVersionByInferenceService retrieves the model version associated with the specified inference service ID. -func (serv *ModelRegistryService) GetModelVersionByInferenceService(inferenceServiceId string) (*openapi.ModelVersion, error) { - is, err := serv.GetInferenceServiceById(inferenceServiceId) - if err != nil { - return nil, err - } - if is.ModelVersionId != nil { - return serv.GetModelVersionById(*is.ModelVersionId) - } - // modelVersionId: ID of the ModelVersion to serve. If it's unspecified, then the latest ModelVersion by creation order will be served. - orderByCreateTime := "CREATE_TIME" - sortOrderDesc := "DESC" - versions, err := serv.GetModelVersions(api.ListOptions{OrderBy: &orderByCreateTime, SortOrder: &sortOrderDesc}, &is.RegisteredModelId) - if err != nil { - return nil, err - } - if len(versions.Items) == 0 { - return nil, fmt.Errorf("no model versions found for id %s: %w", is.RegisteredModelId, api.ErrNotFound) - } - return &versions.Items[0], nil -} - -// getModelVersionByArtifactId retrieves the model version associated with the specified model artifact ID. -func (serv *ModelRegistryService) getModelVersionByArtifactId(id string) (*openapi.ModelVersion, error) { - glog.Infof("Getting model version for model artifact %s", id) - - idAsInt, err := converter.StringToInt64(&id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - getParentResp, err := serv.mlmdClient.GetContextsByArtifact(context.Background(), &proto.GetContextsByArtifactRequest{ - ArtifactId: idAsInt, - }) - if err != nil { - return nil, err - } - - if len(getParentResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple model versions found for artifact %s: %w", id, api.ErrNotFound) - } - - if len(getParentResp.Contexts) == 0 { - return nil, fmt.Errorf("no model version found for artifact %s: %w", id, api.ErrNotFound) - } - - modelVersion, err := serv.mapper.MapToModelVersion(getParentResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - return modelVersion, nil -} - -// GetModelVersionByParams retrieves a model version based on specified parameters, such as (version name and registered model ID), or external ID. -// If multiple or no model versions are found, an error is returned. -func (serv *ModelRegistryService) GetModelVersionByParams(versionName *string, registeredModelId *string, externalId *string) (*openapi.ModelVersion, error) { - filterQuery := "" - if versionName != nil && registeredModelId != nil { - filterQuery = fmt.Sprintf("name = \"%s\"", converter.PrefixWhenOwned(registeredModelId, *versionName)) - } else if externalId != nil { - filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) - } else { - return nil, fmt.Errorf("invalid parameters call, supply either (versionName and registeredModelId), or externalId: %w", api.ErrBadRequest) - } - - getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: &serv.nameConfig.ModelVersionTypeName, - Options: &proto.ListOperationOptions{ - FilterQuery: &filterQuery, - }, - }) - if err != nil { - return nil, err - } - - if len(getByParamsResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple model versions found for versionName=%v, registeredModelId=%v, externalId=%v: %w", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(registeredModelId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) - } - - if len(getByParamsResp.Contexts) == 0 { - return nil, fmt.Errorf("no model versions found for versionName=%v, registeredModelId=%v, externalId=%v: %w", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(registeredModelId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) - } - - modelVer, err := serv.mapper.MapToModelVersion(getByParamsResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - return modelVer, nil -} - -// GetModelVersions retrieves a list of model versions based on the provided list options and optional registered model ID. -func (serv *ModelRegistryService) GetModelVersions(listOptions api.ListOptions, registeredModelId *string) (*openapi.ModelVersionList, error) { - listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - if registeredModelId != nil { - queryParentCtxId := fmt.Sprintf("parent_contexts_a.id = %s", *registeredModelId) - listOperationOptions.FilterQuery = &queryParentCtxId - } - - contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: &serv.nameConfig.ModelVersionTypeName, - Options: listOperationOptions, - }) - if err != nil { - return nil, err - } - - results := []openapi.ModelVersion{} - for _, c := range contextsResp.Contexts { - mapped, err := serv.mapper.MapToModelVersion(c) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - results = append(results, *mapped) - } - - toReturn := openapi.ModelVersionList{ - NextPageToken: apiutils.ZeroIfNil(contextsResp.NextPageToken), - PageSize: apiutils.ZeroIfNil(listOptions.PageSize), - Size: int32(len(results)), - Items: results, - } - return &toReturn, nil -} - -// ARTIFACTS - -// UpsertArtifact creates a new artifact if the provided artifact's ID is nil, or updates an existing artifact if the -// ID is provided. -// A model version ID must be provided to disambiguate between artifacts. -// Upon creation, new artifacts will be associated with their corresponding model version. -func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, modelVersionId *string) (*openapi.Artifact, error) { - if artifact == nil { - return nil, fmt.Errorf("invalid artifact pointer, can't upsert nil") - } - creating := false - if ma := artifact.ModelArtifact; ma != nil { - if ma.Id == nil { - creating = true - glog.Info("Creating model artifact") - if modelVersionId == nil { - return nil, fmt.Errorf("missing model version id, cannot create artifact without model version: %w", api.ErrBadRequest) - } - _, err := serv.GetModelVersionById(*modelVersionId) - if err != nil { - return nil, fmt.Errorf("no model version found for id %s: %w", *modelVersionId, api.ErrNotFound) - } - } else { - glog.Info("Updating model artifact") - existing, err := serv.GetModelArtifactById(*ma.Id) - if err != nil { - return nil, err - } - - withNotEditable, err := serv.openapiConv.OverrideNotEditableForModelArtifact(converter.NewOpenapiUpdateWrapper(existing, ma)) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - ma = &withNotEditable - - _, err = serv.getModelVersionByArtifactId(*ma.Id) - if err != nil { - return nil, err - } - } - } else if da := artifact.DocArtifact; da != nil { - if da.Id == nil { - creating = true - glog.Info("Creating doc artifact") - if modelVersionId == nil { - return nil, fmt.Errorf("missing model version id, cannot create artifact without model version: %w", api.ErrBadRequest) - } - _, err := serv.GetModelVersionById(*modelVersionId) - if err != nil { - return nil, fmt.Errorf("no model version found for id %s: %w", *modelVersionId, api.ErrNotFound) - } - } else { - glog.Info("Updating doc artifact") - existing, err := serv.GetArtifactById(*da.Id) - if err != nil { - return nil, err - } - if existing.DocArtifact == nil { - return nil, fmt.Errorf("mismatched types, artifact with id %s is not a doc artifact: %w", *da.Id, api.ErrBadRequest) - } - - withNotEditable, err := serv.openapiConv.OverrideNotEditableForDocArtifact(converter.NewOpenapiUpdateWrapper(existing.DocArtifact, da)) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - da = &withNotEditable - - _, err = serv.getModelVersionByArtifactId(*da.Id) - if err != nil { - return nil, err - } - } - } else { - return nil, fmt.Errorf("invalid artifact type, must be either ModelArtifact or DocArtifact: %w", api.ErrBadRequest) - } - pa, err := serv.mapper.MapFromArtifact(artifact, modelVersionId) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - artifactsResp, err := serv.mlmdClient.PutArtifacts(context.Background(), &proto.PutArtifactsRequest{ - Artifacts: []*proto.Artifact{pa}, - }) - if err != nil { - return nil, err - } - - if creating { - // add explicit Attribution between Artifact and ModelVersion - modelVersionId, err := converter.StringToInt64(modelVersionId) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - attributions := []*proto.Attribution{} - for _, a := range artifactsResp.ArtifactIds { - attributions = append(attributions, &proto.Attribution{ - ContextId: modelVersionId, - ArtifactId: &a, - }) - } - _, err = serv.mlmdClient.PutAttributionsAndAssociations(context.Background(), &proto.PutAttributionsAndAssociationsRequest{ - Attributions: attributions, - Associations: make([]*proto.Association, 0), - }) - if err != nil { - return nil, err - } - } - - idAsString := converter.Int64ToString(&artifactsResp.ArtifactIds[0]) - return serv.GetArtifactById(*idAsString) -} - -func (serv *ModelRegistryService) GetArtifactById(id string) (*openapi.Artifact, error) { - idAsInt, err := converter.StringToInt64(&id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - artifactsResp, err := serv.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{ - ArtifactIds: []int64{int64(*idAsInt)}, - }) - if err != nil { - return nil, err - } - if len(artifactsResp.Artifacts) > 1 { - return nil, fmt.Errorf("multiple artifacts found for id %s: %w", id, api.ErrNotFound) - } - if len(artifactsResp.Artifacts) == 0 { - return nil, fmt.Errorf("no artifact found for id %s: %w", id, api.ErrNotFound) - } - return serv.mapper.MapToArtifact(artifactsResp.Artifacts[0]) -} - -func (serv *ModelRegistryService) GetArtifacts(listOptions api.ListOptions, modelVersionId *string) (*openapi.ArtifactList, error) { - listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - var artifacts []*proto.Artifact - var nextPageToken *string - if modelVersionId == nil { - return nil, fmt.Errorf("missing model version id, cannot get artifacts without model version: %w", api.ErrBadRequest) - } - ctxId, err := converter.StringToInt64(modelVersionId) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{ - ContextId: ctxId, - Options: listOperationOptions, - }) - if err != nil { - return nil, err - } - artifacts = artifactsResp.Artifacts - nextPageToken = artifactsResp.NextPageToken - - results := []openapi.Artifact{} - for _, a := range artifacts { - mapped, err := serv.mapper.MapToArtifact(a) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - results = append(results, *mapped) - } - - toReturn := openapi.ArtifactList{ - NextPageToken: apiutils.ZeroIfNil(nextPageToken), - PageSize: apiutils.ZeroIfNil(listOptions.PageSize), - Size: int32(len(results)), - Items: results, - } - return &toReturn, nil -} - -// MODEL ARTIFACTS - -// UpsertModelArtifact creates a new model artifact if the provided model artifact's ID is nil, -// or updates an existing model artifact if the ID is provided. -// If a model version ID is provided and the model artifact is newly created, establishes an -// explicit attribution between the model version and the created model artifact. -func (serv *ModelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, modelVersionId *string) (*openapi.ModelArtifact, error) { - art, err := serv.UpsertArtifact(&openapi.Artifact{ - ModelArtifact: modelArtifact, - }, modelVersionId) - if err != nil { - return nil, err - } - return art.ModelArtifact, err -} - -// GetModelArtifactById retrieves a model artifact by its unique identifier (ID). -func (serv *ModelRegistryService) GetModelArtifactById(id string) (*openapi.ModelArtifact, error) { - art, err := serv.GetArtifactById(id) - if err != nil { - return nil, err - } - ma := art.ModelArtifact - if ma == nil { - return nil, fmt.Errorf("artifact with id %s is not a model artifact: %w", id, api.ErrNotFound) - } - return ma, err -} - -// GetModelArtifactByInferenceService retrieves the model artifact associated with the specified inference service ID. -func (serv *ModelRegistryService) GetModelArtifactByInferenceService(inferenceServiceId string) (*openapi.ModelArtifact, error) { - mv, err := serv.GetModelVersionByInferenceService(inferenceServiceId) - if err != nil { - return nil, err - } - - artifactList, err := serv.GetModelArtifacts(api.ListOptions{}, mv.Id) - if err != nil { - return nil, err - } - - if artifactList.Size == 0 { - return nil, fmt.Errorf("no artifacts found for model version %s: %w", *mv.Id, api.ErrNotFound) - } - - return &artifactList.Items[0], nil -} - -// GetModelArtifactByParams retrieves a model artifact based on specified parameters, such as (artifact name and model version ID), or external ID. -// If multiple or no model artifacts are found, an error is returned. -func (serv *ModelRegistryService) GetModelArtifactByParams(artifactName *string, modelVersionId *string, externalId *string) (*openapi.ModelArtifact, error) { - var artifact0 *proto.Artifact - - filterQuery := "" - if externalId != nil { - filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) - } else if artifactName != nil && modelVersionId != nil { - filterQuery = fmt.Sprintf("name = \"%s\"", converter.PrefixWhenOwned(modelVersionId, *artifactName)) - } else { - return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and modelVersionId), or externalId: %w", api.ErrBadRequest) - } - glog.Info("filterQuery ", filterQuery) - - artifactsResponse, err := serv.mlmdClient.GetArtifactsByType(context.Background(), &proto.GetArtifactsByTypeRequest{ - TypeName: &serv.nameConfig.ModelArtifactTypeName, - Options: &proto.ListOperationOptions{ - FilterQuery: &filterQuery, - }, - }) - if err != nil { - return nil, err - } - - if len(artifactsResponse.Artifacts) > 1 { - return nil, fmt.Errorf("multiple model artifacts found for artifactName=%v, modelVersionId=%v, externalId=%v: %w", apiutils.ZeroIfNil(artifactName), apiutils.ZeroIfNil(modelVersionId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) - } - - if len(artifactsResponse.Artifacts) == 0 { - return nil, fmt.Errorf("no model artifacts found for artifactName=%v, modelVersionId=%v, externalId=%v: %w", apiutils.ZeroIfNil(artifactName), apiutils.ZeroIfNil(modelVersionId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) - } - - artifact0 = artifactsResponse.Artifacts[0] - - result, err := serv.mapper.MapToModelArtifact(artifact0) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - return result, nil -} - -// GetModelArtifacts retrieves a list of model artifacts based on the provided list options and optional model version ID. -func (serv *ModelRegistryService) GetModelArtifacts(listOptions api.ListOptions, modelVersionId *string) (*openapi.ModelArtifactList, error) { - listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - var artifacts []*proto.Artifact - var nextPageToken *string - if modelVersionId != nil { - ctxId, err := converter.StringToInt64(modelVersionId) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{ - ContextId: ctxId, - Options: listOperationOptions, - }) - if err != nil { - return nil, err - } - artifacts = artifactsResp.Artifacts - nextPageToken = artifactsResp.NextPageToken - } else { - artifactsResp, err := serv.mlmdClient.GetArtifactsByType(context.Background(), &proto.GetArtifactsByTypeRequest{ - TypeName: &serv.nameConfig.ModelArtifactTypeName, - Options: listOperationOptions, - }) - if err != nil { - return nil, err - } - artifacts = artifactsResp.Artifacts - nextPageToken = artifactsResp.NextPageToken - } - - results := []openapi.ModelArtifact{} - for _, a := range artifacts { - mapped, err := serv.mapper.MapToModelArtifact(a) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - results = append(results, *mapped) - } - - toReturn := openapi.ModelArtifactList{ - NextPageToken: apiutils.ZeroIfNil(nextPageToken), - PageSize: apiutils.ZeroIfNil(listOptions.PageSize), - Size: int32(len(results)), - Items: results, - } - return &toReturn, nil -} - -// SERVING ENVIRONMENT - -// UpsertServingEnvironment creates a new serving environment if the provided serving environment's ID is nil, -// or updates an existing serving environment if the ID is provided. -func (serv *ModelRegistryService) UpsertServingEnvironment(servingEnvironment *openapi.ServingEnvironment) (*openapi.ServingEnvironment, error) { - var err error - var existing *openapi.ServingEnvironment - - if servingEnvironment.Id == nil { - glog.Info("Creating new serving environment") - } else { - glog.Infof("Updating serving environment %s", *servingEnvironment.Id) - existing, err = serv.GetServingEnvironmentById(*servingEnvironment.Id) - if err != nil { - return nil, err - } - - withNotEditable, err := serv.openapiConv.OverrideNotEditableForServingEnvironment(converter.NewOpenapiUpdateWrapper(existing, servingEnvironment)) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - servingEnvironment = &withNotEditable - } - - protoCtx, err := serv.mapper.MapFromServingEnvironment(servingEnvironment) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - protoCtxResp, err := serv.mlmdClient.PutContexts(context.Background(), &proto.PutContextsRequest{ - Contexts: []*proto.Context{ - protoCtx, - }, - }) - if err != nil { - return nil, err - } - - idAsString := converter.Int64ToString(&protoCtxResp.ContextIds[0]) - openapiModel, err := serv.GetServingEnvironmentById(*idAsString) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - return openapiModel, nil -} - -// GetServingEnvironmentById retrieves a serving environment by its unique identifier (ID). -func (serv *ModelRegistryService) GetServingEnvironmentById(id string) (*openapi.ServingEnvironment, error) { - glog.Infof("Getting serving environment %s", id) - - idAsInt, err := converter.StringToInt64(&id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{*idAsInt}, - }) - if err != nil { - return nil, err - } - - if len(getByIdResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple serving environments found for id %s: %w", id, api.ErrNotFound) - } - - if len(getByIdResp.Contexts) == 0 { - return nil, fmt.Errorf("no serving environment found for id %s: %w", id, api.ErrNotFound) - } - - openapiModel, err := serv.mapper.MapToServingEnvironment(getByIdResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - return openapiModel, nil -} - -// GetServingEnvironmentByParams retrieves a serving environment based on specified parameters, such as name or external ID. -// If multiple or no serving environments are found, an error is returned accordingly. -func (serv *ModelRegistryService) GetServingEnvironmentByParams(name *string, externalId *string) (*openapi.ServingEnvironment, error) { - glog.Infof("Getting serving environment by params name=%v, externalId=%v", name, externalId) - - filterQuery := "" - if name != nil { - filterQuery = fmt.Sprintf("name = \"%s\"", *name) - } else if externalId != nil { - filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) - } else { - return nil, fmt.Errorf("invalid parameters call, supply either name or externalId: %w", api.ErrBadRequest) - } - - getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: &serv.nameConfig.ServingEnvironmentTypeName, - Options: &proto.ListOperationOptions{ - FilterQuery: &filterQuery, - }, - }) - if err != nil { - return nil, err - } - - if len(getByParamsResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple serving environments found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) - } - - if len(getByParamsResp.Contexts) == 0 { - return nil, fmt.Errorf("no serving environments found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) - } - - openapiModel, err := serv.mapper.MapToServingEnvironment(getByParamsResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - return openapiModel, nil -} - -// GetServingEnvironments retrieves a list of serving environments based on the provided list options. -func (serv *ModelRegistryService) GetServingEnvironments(listOptions api.ListOptions) (*openapi.ServingEnvironmentList, error) { - listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: &serv.nameConfig.ServingEnvironmentTypeName, - Options: listOperationOptions, - }) - if err != nil { - return nil, err - } - - results := []openapi.ServingEnvironment{} - for _, c := range contextsResp.Contexts { - mapped, err := serv.mapper.MapToServingEnvironment(c) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - results = append(results, *mapped) - } - - toReturn := openapi.ServingEnvironmentList{ - NextPageToken: apiutils.ZeroIfNil(contextsResp.NextPageToken), - PageSize: apiutils.ZeroIfNil(listOptions.PageSize), - Size: int32(len(results)), - Items: results, - } - return &toReturn, nil -} - -// INFERENCE SERVICE - -// UpsertInferenceService creates a new inference service if the provided inference service's ID is nil, -// or updates an existing inference service if the ID is provided. -func (serv *ModelRegistryService) UpsertInferenceService(inferenceService *openapi.InferenceService) (*openapi.InferenceService, error) { - var err error - var existing *openapi.InferenceService - var servingEnvironment *openapi.ServingEnvironment - - if inferenceService.Id == nil { - // create - glog.Info("Creating new InferenceService") - servingEnvironment, err = serv.GetServingEnvironmentById(inferenceService.ServingEnvironmentId) - if err != nil { - return nil, err - } - } else { - // update - glog.Infof("Updating InferenceService %s", *inferenceService.Id) - - existing, err = serv.GetInferenceServiceById(*inferenceService.Id) - if err != nil { - return nil, err - } - - withNotEditable, err := serv.openapiConv.OverrideNotEditableForInferenceService(converter.NewOpenapiUpdateWrapper(existing, inferenceService)) - if err != nil { - return nil, err - } - inferenceService = &withNotEditable - - servingEnvironment, err = serv.getServingEnvironmentByInferenceServiceId(*inferenceService.Id) - if err != nil { - return nil, err - } - } - - // validate RegisteredModelId is also valid - if _, err := serv.GetRegisteredModelById(inferenceService.RegisteredModelId); err != nil { - return nil, err - } - - // if already existing assure the name is the same - if existing != nil && inferenceService.Name == nil { - // user did not provide it - // need to set it to avoid mlmd error "context name should not be empty" - inferenceService.Name = existing.Name - } - - protoCtx, err := serv.mapper.MapFromInferenceService(inferenceService, *servingEnvironment.Id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - protoCtxResp, err := serv.mlmdClient.PutContexts(context.Background(), &proto.PutContextsRequest{ - Contexts: []*proto.Context{ - protoCtx, - }, - }) - if err != nil { - return nil, err - } - - inferenceServiceId := &protoCtxResp.ContextIds[0] - if inferenceService.Id == nil { - servingEnvironmentId, err := converter.StringToInt64(servingEnvironment.Id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - _, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{ - ParentContexts: []*proto.ParentContext{{ - ChildId: inferenceServiceId, - ParentId: servingEnvironmentId, - }}, - TransactionOptions: &proto.TransactionOptions{}, - }) - if err != nil { - return nil, err - } - } - - idAsString := converter.Int64ToString(inferenceServiceId) - toReturn, err := serv.GetInferenceServiceById(*idAsString) - if err != nil { - return nil, err - } - - return toReturn, nil -} - -// getServingEnvironmentByInferenceServiceId retrieves the serving environment associated with the specified inference service ID. -func (serv *ModelRegistryService) getServingEnvironmentByInferenceServiceId(id string) (*openapi.ServingEnvironment, error) { - glog.Infof("Getting ServingEnvironment for InferenceService %s", id) - - idAsInt, err := converter.StringToInt64(&id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - getParentResp, err := serv.mlmdClient.GetParentContextsByContext(context.Background(), &proto.GetParentContextsByContextRequest{ - ContextId: idAsInt, - }) - if err != nil { - return nil, err - } - - if len(getParentResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple ServingEnvironments found for InferenceService %s: %w", id, api.ErrNotFound) - } - - if len(getParentResp.Contexts) == 0 { - return nil, fmt.Errorf("no ServingEnvironments found for InferenceService %s: %w", id, api.ErrNotFound) - } - - toReturn, err := serv.mapper.MapToServingEnvironment(getParentResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - return toReturn, nil -} - -// GetInferenceServiceById retrieves an inference service by its unique identifier (ID). -func (serv *ModelRegistryService) GetInferenceServiceById(id string) (*openapi.InferenceService, error) { - glog.Infof("Getting InferenceService by id %s", id) - - idAsInt, err := converter.StringToInt64(&id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{*idAsInt}, - }) - if err != nil { - return nil, err - } - - if len(getByIdResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple InferenceServices found for id %s: %w", id, api.ErrNotFound) - } - - if len(getByIdResp.Contexts) == 0 { - return nil, fmt.Errorf("no InferenceService found for id %s: %w", id, api.ErrNotFound) - } - - toReturn, err := serv.mapper.MapToInferenceService(getByIdResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - return toReturn, nil -} - -// GetInferenceServiceByParams retrieves an inference service based on specified parameters, such as (name and serving environment ID), or external ID. -// If multiple or no serving environments are found, an error is returned accordingly. -func (serv *ModelRegistryService) GetInferenceServiceByParams(name *string, servingEnvironmentId *string, externalId *string) (*openapi.InferenceService, error) { - filterQuery := "" - if name != nil && servingEnvironmentId != nil { - filterQuery = fmt.Sprintf("name = \"%s\"", converter.PrefixWhenOwned(servingEnvironmentId, *name)) - } else if externalId != nil { - filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) - } else { - return nil, fmt.Errorf("invalid parameters call, supply either (name and servingEnvironmentId), or externalId: %w", api.ErrBadRequest) - } - - getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: &serv.nameConfig.InferenceServiceTypeName, - Options: &proto.ListOperationOptions{ - FilterQuery: &filterQuery, - }, - }) - if err != nil { - return nil, err - } - - if len(getByParamsResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple inference services found for name=%v, servingEnvironmentId=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(servingEnvironmentId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) - } - - if len(getByParamsResp.Contexts) == 0 { - return nil, fmt.Errorf("no inference services found for name=%v, servingEnvironmentId=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(servingEnvironmentId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) - } - - toReturn, err := serv.mapper.MapToInferenceService(getByParamsResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - return toReturn, nil -} - -// GetInferenceServices retrieves a list of inference services based on the provided list options and optional serving environment ID and runtime. -func (serv *ModelRegistryService) GetInferenceServices(listOptions api.ListOptions, servingEnvironmentId *string, runtime *string) (*openapi.InferenceServiceList, error) { - listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - queries := []string{} - if servingEnvironmentId != nil { - queryParentCtxId := fmt.Sprintf("parent_contexts_a.id = %s", *servingEnvironmentId) - queries = append(queries, queryParentCtxId) - } - - if runtime != nil { - queryRuntimeProp := fmt.Sprintf("properties.runtime.string_value = \"%s\"", *runtime) - queries = append(queries, queryRuntimeProp) - } - - query := strings.Join(queries, " and ") - listOperationOptions.FilterQuery = &query - - contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: &serv.nameConfig.InferenceServiceTypeName, - Options: listOperationOptions, - }) - if err != nil { - return nil, err - } - - results := []openapi.InferenceService{} - for _, c := range contextsResp.Contexts { - mapped, err := serv.mapper.MapToInferenceService(c) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - results = append(results, *mapped) - } - - toReturn := openapi.InferenceServiceList{ - NextPageToken: apiutils.ZeroIfNil(contextsResp.NextPageToken), - PageSize: apiutils.ZeroIfNil(listOptions.PageSize), - Size: int32(len(results)), - Items: results, - } - return &toReturn, nil -} - -// SERVE MODEL - -// UpsertServeModel creates a new serve model if the provided serve model's ID is nil, -// or updates an existing serve model if the ID is provided. -func (serv *ModelRegistryService) UpsertServeModel(serveModel *openapi.ServeModel, inferenceServiceId *string) (*openapi.ServeModel, error) { - var err error - var existing *openapi.ServeModel - - if serveModel.Id == nil { - // create - glog.Info("Creating new ServeModel") - if inferenceServiceId == nil { - return nil, fmt.Errorf("missing inferenceServiceId, cannot create ServeModel without parent resource InferenceService: %w", api.ErrBadRequest) - } - _, err = serv.GetInferenceServiceById(*inferenceServiceId) - if err != nil { - return nil, err - } - } else { - // update - glog.Infof("Updating ServeModel %s", *serveModel.Id) - - existing, err = serv.GetServeModelById(*serveModel.Id) - if err != nil { - return nil, err - } - - withNotEditable, err := serv.openapiConv.OverrideNotEditableForServeModel(converter.NewOpenapiUpdateWrapper(existing, serveModel)) - if err != nil { - return nil, err - } - serveModel = &withNotEditable - - _, err = serv.getInferenceServiceByServeModel(*serveModel.Id) - if err != nil { - return nil, err - } - } - _, err = serv.GetModelVersionById(serveModel.ModelVersionId) - if err != nil { - return nil, err - } - - // if already existing assure the name is the same - if existing != nil && serveModel.Name == nil { - // user did not provide it - // need to set it to avoid mlmd error "artifact name should not be empty" - serveModel.Name = existing.Name - } - - execution, err := serv.mapper.MapFromServeModel(serveModel, *inferenceServiceId) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - executionsResp, err := serv.mlmdClient.PutExecutions(context.Background(), &proto.PutExecutionsRequest{ - Executions: []*proto.Execution{execution}, - }) - if err != nil { - return nil, err - } - - // add explicit Association between ServeModel and InferenceService - if inferenceServiceId != nil && serveModel.Id == nil { - inferenceServiceId, err := converter.StringToInt64(inferenceServiceId) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - associations := []*proto.Association{} - for _, a := range executionsResp.ExecutionIds { - associations = append(associations, &proto.Association{ - ContextId: inferenceServiceId, - ExecutionId: &a, - }) - } - _, err = serv.mlmdClient.PutAttributionsAndAssociations(context.Background(), &proto.PutAttributionsAndAssociationsRequest{ - Attributions: make([]*proto.Attribution, 0), - Associations: associations, - }) - if err != nil { - return nil, err - } - } - - idAsString := converter.Int64ToString(&executionsResp.ExecutionIds[0]) - mapped, err := serv.GetServeModelById(*idAsString) - if err != nil { - return nil, err - } - return mapped, nil -} - -// getInferenceServiceByServeModel retrieves the inference service associated with the specified serve model ID. -func (serv *ModelRegistryService) getInferenceServiceByServeModel(id string) (*openapi.InferenceService, error) { - glog.Infof("Getting InferenceService for ServeModel %s", id) - - idAsInt, err := converter.StringToInt64(&id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - getParentResp, err := serv.mlmdClient.GetContextsByExecution(context.Background(), &proto.GetContextsByExecutionRequest{ - ExecutionId: idAsInt, - }) - if err != nil { - return nil, err - } - - if len(getParentResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple InferenceService found for ServeModel %s: %w", id, api.ErrNotFound) - } - - if len(getParentResp.Contexts) == 0 { - return nil, fmt.Errorf("no InferenceService found for ServeModel %s: %w", id, api.ErrNotFound) - } - - toReturn, err := serv.mapper.MapToInferenceService(getParentResp.Contexts[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - return toReturn, nil -} - -// GetServeModelById retrieves a serve model by its unique identifier (ID). -func (serv *ModelRegistryService) GetServeModelById(id string) (*openapi.ServeModel, error) { - idAsInt, err := converter.StringToInt64(&id) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - executionsResp, err := serv.mlmdClient.GetExecutionsByID(context.Background(), &proto.GetExecutionsByIDRequest{ - ExecutionIds: []int64{int64(*idAsInt)}, - }) - if err != nil { - return nil, err - } - - if len(executionsResp.Executions) > 1 { - return nil, fmt.Errorf("multiple ServeModels found for id %s: %w", id, api.ErrNotFound) - } - - if len(executionsResp.Executions) == 0 { - return nil, fmt.Errorf("no ServeModel found for id %s: %w", id, api.ErrNotFound) - } - - result, err := serv.mapper.MapToServeModel(executionsResp.Executions[0]) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - return result, nil -} - -// GetServeModels retrieves a list of serve models based on the provided list options and optional inference service ID. -func (serv *ModelRegistryService) GetServeModels(listOptions api.ListOptions, inferenceServiceId *string) (*openapi.ServeModelList, error) { - listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - - var executions []*proto.Execution - var nextPageToken *string - if inferenceServiceId != nil { - ctxId, err := converter.StringToInt64(inferenceServiceId) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - executionsResp, err := serv.mlmdClient.GetExecutionsByContext(context.Background(), &proto.GetExecutionsByContextRequest{ - ContextId: ctxId, - Options: listOperationOptions, - }) - if err != nil { - return nil, err - } - executions = executionsResp.Executions - nextPageToken = executionsResp.NextPageToken - } else { - executionsResp, err := serv.mlmdClient.GetExecutionsByType(context.Background(), &proto.GetExecutionsByTypeRequest{ - TypeName: &serv.nameConfig.ServeModelTypeName, - Options: listOperationOptions, - }) - if err != nil { - return nil, err - } - executions = executionsResp.Executions - nextPageToken = executionsResp.NextPageToken - } - - results := []openapi.ServeModel{} - for _, a := range executions { - mapped, err := serv.mapper.MapToServeModel(a) - if err != nil { - return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) - } - results = append(results, *mapped) - } - - toReturn := openapi.ServeModelList{ - NextPageToken: apiutils.ZeroIfNil(nextPageToken), - PageSize: apiutils.ZeroIfNil(listOptions.PageSize), - Size: int32(len(results)), - Items: results, - } - return &toReturn, nil -} diff --git a/pkg/core/core_test.go b/pkg/core/core_test.go index ec729514..e625e39a 100644 --- a/pkg/core/core_test.go +++ b/pkg/core/core_test.go @@ -2,7 +2,6 @@ package core import ( "context" - "fmt" "strings" "testing" "time" @@ -580,2749 +579,3 @@ func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInServeModel( suite.NotNil(err) suite.Regexp("error setting up execution type "+*serveModelTypeName+": rpc error: code = AlreadyExists.*", err.Error()) } - -// REGISTERED MODELS - -func (suite *CoreTestSuite) TestCreateRegisteredModel() { - // create mode registry service - service := suite.setupModelRegistryService() - - state := openapi.REGISTEREDMODELSTATE_ARCHIVED - // register a new model - registeredModel := &openapi.RegisteredModel{ - Name: modelName, - ExternalId: &modelExternalId, - Description: &modelDescription, - Owner: &modelOwner, - State: &state, - CustomProperties: &map[string]openapi.MetadataValue{ - "myCustomProp": { - MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), - }, - }, - } - - // test - createdModel, err := service.UpsertRegisteredModel(registeredModel) - - // checks - suite.Nilf(err, "error creating registered model: %v", err) - suite.NotNilf(createdModel.Id, "created registered model should not have nil Id") - - createdModelId, _ := converter.StringToInt64(createdModel.Id) - ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{*createdModelId}, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - - ctx := ctxById.Contexts[0] - ctxId := converter.Int64ToString(ctx.Id) - suite.Equal(*createdModel.Id, *ctxId, "returned model id should match the mlmd one") - suite.Equal(modelName, *ctx.Name, "saved model name should match the provided one") - suite.Equal(modelExternalId, *ctx.ExternalId, "saved external id should match the provided one") - suite.Equal(modelDescription, ctx.Properties["description"].GetStringValue(), "saved description should match the provided one") - suite.Equal(modelOwner, ctx.Properties["owner"].GetStringValue(), "saved owner should match the provided one") - suite.Equal(string(state), ctx.Properties["state"].GetStringValue(), "saved state should match the provided one") - suite.Equal(myCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") - - getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) - suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) - suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") -} - -func (suite *CoreTestSuite) TestUpdateRegisteredModel() { - // create mode registry service - service := suite.setupModelRegistryService() - - // register a new model - registeredModel := &openapi.RegisteredModel{ - Name: modelName, - Owner: &modelOwner, - ExternalId: &modelExternalId, - CustomProperties: &map[string]openapi.MetadataValue{ - "myCustomProp": { - MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), - }, - }, - } - - // test - createdModel, err := service.UpsertRegisteredModel(registeredModel) - - // checks - suite.Nilf(err, "error creating registered model: %v", err) - suite.NotNilf(createdModel.Id, "created registered model should not have nil Id") - createdModelId, _ := converter.StringToInt64(createdModel.Id) - - // checks created model matches original one except for Id - suite.Equal(registeredModel.Name, createdModel.Name, "returned model name should match the original one") - suite.Equal(*registeredModel.ExternalId, *createdModel.ExternalId, "returned model external id should match the original one") - suite.Equal(*registeredModel.CustomProperties, *createdModel.CustomProperties, "returned model custom props should match the original one") - - // update existing model - newModelExternalId := "newExternalId" - newOwner := "newOwner" - newCustomProp := "updated myCustomProp" - - createdModel.ExternalId = &newModelExternalId - createdModel.Owner = &newOwner - (*createdModel.CustomProperties)["myCustomProp"] = openapi.MetadataValue{ - MetadataStringValue: converter.NewMetadataStringValue(newCustomProp), - } - // check can also define customProperty of name "owner", in addition to built-in property "owner" - (*createdModel.CustomProperties)["owner"] = openapi.MetadataValue{ - MetadataStringValue: converter.NewMetadataStringValue(newCustomProp), - } - - // update the model - createdModel, err = service.UpsertRegisteredModel(createdModel) - suite.Nilf(err, "error creating registered model: %v", err) - - // still one registered model - getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) - suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) - suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") - - ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{*createdModelId}, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - - ctx := ctxById.Contexts[0] - ctxId := converter.Int64ToString(ctx.Id) - suite.Equal(*createdModel.Id, *ctxId, "returned model id should match the mlmd one") - suite.Equal(modelName, *ctx.Name, "saved model name should match the provided one") - suite.Equal(newModelExternalId, *ctx.ExternalId, "saved external id should match the provided one") - suite.Equal(newOwner, ctx.Properties["owner"].GetStringValue(), "saved owner custom property should match the provided one") - suite.Equal(newCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") - suite.Equal(newCustomProp, ctx.CustomProperties["owner"].GetStringValue(), "check can define custom property 'onwer' and should match the provided one") - - // update the model keeping nil name - newModelExternalId = "newNewExternalId" - createdModel.ExternalId = &newModelExternalId - createdModel.Name = "" - createdModel, err = service.UpsertRegisteredModel(createdModel) - suite.Nilf(err, "error creating registered model: %v", err) - - // still one registered model - getAllResp, err = suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) - suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) - suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") - - ctxById, err = suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{*createdModelId}, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - - ctx = ctxById.Contexts[0] - ctxId = converter.Int64ToString(ctx.Id) - suite.Equal(*createdModel.Id, *ctxId, "returned model id should match the mlmd one") - suite.Equal(modelName, *ctx.Name, "saved model name should match the provided one") - suite.Equal(newModelExternalId, *ctx.ExternalId, "saved external id should match the provided one") - suite.Equal(newOwner, ctx.Properties["owner"].GetStringValue(), "saved owner custom property should match the provided one") - suite.Equal(newCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") - suite.Equal(newCustomProp, ctx.CustomProperties["owner"].GetStringValue(), "check can define custom property 'onwer' and should match the provided one") -} - -func (suite *CoreTestSuite) TestGetRegisteredModelById() { - // create mode registry service - service := suite.setupModelRegistryService() - - state := openapi.REGISTEREDMODELSTATE_LIVE - // register a new model - registeredModel := &openapi.RegisteredModel{ - Name: modelName, - ExternalId: &modelExternalId, - State: &state, - CustomProperties: &map[string]openapi.MetadataValue{ - "myCustomProp": { - MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), - }, - }, - } - - // test - createdModel, err := service.UpsertRegisteredModel(registeredModel) - - // checks - suite.Nilf(err, "error creating registered model: %v", err) - - getModelById, err := service.GetRegisteredModelById(*createdModel.Id) - suite.Nilf(err, "error getting registered model by id %s: %v", *createdModel.Id, err) - - // checks created model matches original one except for Id - suite.Equal(registeredModel.Name, getModelById.Name, "saved model name should match the original one") - suite.Equal(*registeredModel.ExternalId, *getModelById.ExternalId, "saved model external id should match the original one") - suite.Equal(*registeredModel.State, *getModelById.State, "saved model state should match the original one") - suite.Equal(*registeredModel.CustomProperties, *getModelById.CustomProperties, "saved model custom props should match the original one") -} - -func (suite *CoreTestSuite) TestGetRegisteredModelByParamsWithNoResults() { - // create mode registry service - service := suite.setupModelRegistryService() - - _, err := service.GetRegisteredModelByParams(apiutils.Of("not-present"), nil) - suite.NotNil(err) - suite.Equal("no registered models found for name=not-present, externalId=: not found", err.Error()) -} - -func (suite *CoreTestSuite) TestGetRegisteredModelByParamsName() { - // create mode registry service - service := suite.setupModelRegistryService() - - // register a new model - registeredModel := &openapi.RegisteredModel{ - Name: modelName, - ExternalId: &modelExternalId, - } - - createdModel, err := service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - byName, err := service.GetRegisteredModelByParams(&modelName, nil) - suite.Nilf(err, "error getting registered model by name: %v", err) - - suite.Equalf(*createdModel.Id, *byName.Id, "the returned model id should match the retrieved by name") -} - -func (suite *CoreTestSuite) TestGetRegisteredModelByParamsExternalId() { - // create mode registry service - service := suite.setupModelRegistryService() - - // register a new model - registeredModel := &openapi.RegisteredModel{ - Name: modelName, - ExternalId: &modelExternalId, - } - - createdModel, err := service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - byName, err := service.GetRegisteredModelByParams(nil, &modelExternalId) - suite.Nilf(err, "error getting registered model by external id: %v", err) - - suite.Equalf(*createdModel.Id, *byName.Id, "the returned model id should match the retrieved by name") -} - -func (suite *CoreTestSuite) TestGetRegisteredModelByEmptyParams() { - // create mode registry service - service := suite.setupModelRegistryService() - - // register a new model - registeredModel := &openapi.RegisteredModel{ - Name: modelName, - ExternalId: &modelExternalId, - } - - _, err := service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - _, err = service.GetRegisteredModelByParams(nil, nil) - suite.NotNil(err) - suite.Equal("invalid parameters call, supply either name or externalId: bad request", err.Error()) -} - -func (suite *CoreTestSuite) TestGetRegisteredModelsOrderedById() { - // create mode registry service - service := suite.setupModelRegistryService() - - orderBy := "ID" - - // register a new model - registeredModel := &openapi.RegisteredModel{ - Name: modelName, - ExternalId: &modelExternalId, - } - - _, err := service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - newModelName := "PricingModel2" - newModelExternalId := "myExternalId2" - registeredModel.Name = newModelName - registeredModel.ExternalId = &newModelExternalId - _, err = service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - newModelName = "PricingModel3" - newModelExternalId = "myExternalId3" - registeredModel.Name = newModelName - registeredModel.ExternalId = &newModelExternalId - _, err = service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - orderedById, err := service.GetRegisteredModels(api.ListOptions{ - OrderBy: &orderBy, - SortOrder: &ascOrderDirection, - }) - suite.Nilf(err, "error getting registered models: %v", err) - - suite.Equal(3, int(orderedById.Size)) - for i := 0; i < int(orderedById.Size)-1; i++ { - suite.Less(*orderedById.Items[i].Id, *orderedById.Items[i+1].Id) - } - - orderedById, err = service.GetRegisteredModels(api.ListOptions{ - OrderBy: &orderBy, - SortOrder: &descOrderDirection, - }) - suite.Nilf(err, "error getting registered models: %v", err) - - suite.Equal(3, int(orderedById.Size)) - for i := 0; i < int(orderedById.Size)-1; i++ { - suite.Greater(*orderedById.Items[i].Id, *orderedById.Items[i+1].Id) - } -} - -func (suite *CoreTestSuite) TestGetRegisteredModelsOrderedByLastUpdate() { - // create mode registry service - service := suite.setupModelRegistryService() - - orderBy := "LAST_UPDATE_TIME" - - // register a new model - registeredModel := &openapi.RegisteredModel{ - Name: modelName, - ExternalId: &modelExternalId, - } - - firstModel, err := service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - newModelName := "PricingModel2" - newModelExternalId := "myExternalId2" - registeredModel.Name = newModelName - registeredModel.ExternalId = &newModelExternalId - secondModel, err := service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - newModelName = "PricingModel3" - newModelExternalId = "myExternalId3" - registeredModel.Name = newModelName - registeredModel.ExternalId = &newModelExternalId - thirdModel, err := service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - // update second model - secondModel.ExternalId = nil - _, err = service.UpsertRegisteredModel(secondModel) - suite.Nilf(err, "error creating registered model: %v", err) - - orderedById, err := service.GetRegisteredModels(api.ListOptions{ - OrderBy: &orderBy, - SortOrder: &ascOrderDirection, - }) - suite.Nilf(err, "error getting registered models: %v", err) - - suite.Equal(3, int(orderedById.Size)) - suite.Equal(*firstModel.Id, *orderedById.Items[0].Id) - suite.Equal(*thirdModel.Id, *orderedById.Items[1].Id) - suite.Equal(*secondModel.Id, *orderedById.Items[2].Id) - - orderedById, err = service.GetRegisteredModels(api.ListOptions{ - OrderBy: &orderBy, - SortOrder: &descOrderDirection, - }) - suite.Nilf(err, "error getting registered models: %v", err) - - suite.Equal(3, int(orderedById.Size)) - suite.Equal(*secondModel.Id, *orderedById.Items[0].Id) - suite.Equal(*thirdModel.Id, *orderedById.Items[1].Id) - suite.Equal(*firstModel.Id, *orderedById.Items[2].Id) -} - -func (suite *CoreTestSuite) TestGetRegisteredModelsWithPageSize() { - // create mode registry service - service := suite.setupModelRegistryService() - - pageSize := int32(1) - pageSize2 := int32(2) - modelName := "PricingModel1" - modelExternalId := "myExternalId1" - - // register a new model - registeredModel := &openapi.RegisteredModel{ - Name: modelName, - ExternalId: &modelExternalId, - } - - firstModel, err := service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - newModelName := "PricingModel2" - newModelExternalId := "myExternalId2" - registeredModel.Name = newModelName - registeredModel.ExternalId = &newModelExternalId - secondModel, err := service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - newModelName = "PricingModel3" - newModelExternalId = "myExternalId3" - registeredModel.Name = newModelName - registeredModel.ExternalId = &newModelExternalId - thirdModel, err := service.UpsertRegisteredModel(registeredModel) - suite.Nilf(err, "error creating registered model: %v", err) - - truncatedList, err := service.GetRegisteredModels(api.ListOptions{ - PageSize: &pageSize, - }) - suite.Nilf(err, "error getting registered models: %v", err) - - suite.Equal(1, int(truncatedList.Size)) - suite.NotEqual("", truncatedList.NextPageToken, "next page token should not be empty") - suite.Equal(*firstModel.Id, *truncatedList.Items[0].Id) - - truncatedList, err = service.GetRegisteredModels(api.ListOptions{ - PageSize: &pageSize2, - NextPageToken: &truncatedList.NextPageToken, - }) - suite.Nilf(err, "error getting registered models: %v", err) - - suite.Equal(2, int(truncatedList.Size)) - suite.Equal("", truncatedList.NextPageToken, "next page token should be empty as list item returned") - suite.Equal(*secondModel.Id, *truncatedList.Items[0].Id) - suite.Equal(*thirdModel.Id, *truncatedList.Items[1].Id) -} - -// MODEL VERSIONS - -func (suite *CoreTestSuite) TestCreateModelVersion() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - - state := openapi.MODELVERSIONSTATE_LIVE - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - Description: &modelVersionDescription, - State: &state, - Author: &author, - } - - createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - suite.Equal((*createdVersion).RegisteredModelId, registeredModelId, "RegisteredModelId should match the actual owner-entity") - - suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") - - createdVersionId, _ := converter.StringToInt64(createdVersion.Id) - - byId, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *createdVersionId, - }, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - suite.Equal(1, len(byId.Contexts), "there should be just one context saved in mlmd") - - suite.Equal(*createdVersionId, *byId.Contexts[0].Id, "returned model id should match the mlmd one") - suite.Equal(fmt.Sprintf("%s:%s", registeredModelId, modelVersionName), *byId.Contexts[0].Name, "saved model name should match the provided one") - suite.Equal(versionExternalId, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") - suite.Equal(author, byId.Contexts[0].Properties["author"].GetStringValue(), "saved author property should match the provided one") - suite.Equal(modelVersionDescription, byId.Contexts[0].Properties["description"].GetStringValue(), "saved description should match the provided one") - suite.Equal(string(state), byId.Contexts[0].Properties["state"].GetStringValue(), "saved state should match the provided one") - suite.Equalf(*modelVersionTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *modelVersionTypeName) - - getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) - suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) - suite.Equal(2, len(getAllResp.Contexts), "there should be two contexts saved in mlmd") -} - -func (suite *CoreTestSuite) TestCreateModelVersionFailure() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := "9999" - - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - Author: &author, - } - - _, err := service.UpsertModelVersion(modelVersion, nil) - suite.NotNil(err) - suite.Equal("missing registered model id, cannot create model version without registered model: bad request", err.Error()) - - _, err = service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.NotNil(err) - suite.Equal("no registered model found for id 9999: not found", err.Error()) -} - -func (suite *CoreTestSuite) TestUpdateModelVersion() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - Author: &author, - } - - createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - - suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") - createdVersionId, _ := converter.StringToInt64(createdVersion.Id) - - newExternalId := "org.my_awesome_model@v1" - newScore := 0.95 - - createdVersion.ExternalId = &newExternalId - (*createdVersion.CustomProperties)["score"] = openapi.MetadataValue{ - MetadataDoubleValue: converter.NewMetadataDoubleValue(newScore), - } - - updatedVersion, err := service.UpsertModelVersion(createdVersion, ®isteredModelId) - suite.Nilf(err, "error updating new model version for %s: %v", registeredModelId, err) - suite.Equal((*updatedVersion).RegisteredModelId, registeredModelId, "RegisteredModelId should match the actual owner-entity") - - updateVersionId, _ := converter.StringToInt64(updatedVersion.Id) - suite.Equal(*createdVersionId, *updateVersionId, "created and updated model version should have same id") - - byId, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *updateVersionId, - }, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - suite.Equal(1, len(byId.Contexts), "there should be just one context saved in mlmd") - - suite.Equal(*updateVersionId, *byId.Contexts[0].Id, "returned model id should match the mlmd one") - suite.Equal(fmt.Sprintf("%s:%s", registeredModelId, modelVersionName), *byId.Contexts[0].Name, "saved model name should match the provided one") - suite.Equal(newExternalId, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") - suite.Equal(author, byId.Contexts[0].Properties["author"].GetStringValue(), "saved author property should match the provided one") - suite.Equal(newScore, byId.Contexts[0].CustomProperties["score"].GetDoubleValue(), "saved score custom property should match the provided one") - suite.Equalf(*modelVersionTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *modelVersionTypeName) - - getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) - suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) - suite.Equal(2, len(getAllResp.Contexts), "there should be two contexts saved in mlmd") - - // update with nil name - newExternalId = "org.my_awesome_model_@v1" - updatedVersion.ExternalId = &newExternalId - updatedVersion.Name = "" - updatedVersion, err = service.UpsertModelVersion(updatedVersion, ®isteredModelId) - suite.Nilf(err, "error updating new model version for %s: %v", registeredModelId, err) - - updateVersionId, _ = converter.StringToInt64(updatedVersion.Id) - suite.Equal(*createdVersionId, *updateVersionId, "created and updated model version should have same id") - - byId, err = suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *updateVersionId, - }, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - suite.Equal(1, len(byId.Contexts), "there should be just one context saved in mlmd") - - suite.Equal(*updateVersionId, *byId.Contexts[0].Id, "returned model id should match the mlmd one") - suite.Equal(fmt.Sprintf("%s:%s", registeredModelId, modelVersionName), *byId.Contexts[0].Name, "saved model name should match the provided one") - suite.Equal(newExternalId, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") - suite.Equal(author, byId.Contexts[0].Properties["author"].GetStringValue(), "saved author property should match the provided one") - suite.Equal(newScore, byId.Contexts[0].CustomProperties["score"].GetDoubleValue(), "saved score custom property should match the provided one") - suite.Equalf(*modelVersionTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *modelVersionTypeName) -} - -func (suite *CoreTestSuite) TestUpdateModelVersionFailure() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - Author: &author, - } - - createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %s", registeredModelId) - suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") - - newExternalId := "org.my_awesome_model@v1" - newScore := 0.95 - - createdVersion.ExternalId = &newExternalId - (*createdVersion.CustomProperties)["score"] = openapi.MetadataValue{ - MetadataDoubleValue: converter.NewMetadataDoubleValue(newScore), - } - - wrongId := "9999" - createdVersion.Id = &wrongId - _, err = service.UpsertModelVersion(createdVersion, ®isteredModelId) - suite.NotNil(err) - suite.Equal(fmt.Sprintf("no model version found for id %s: not found", wrongId), err.Error()) -} - -func (suite *CoreTestSuite) TestGetModelVersionById() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - - state := openapi.MODELVERSIONSTATE_ARCHIVED - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - State: &state, - Author: &author, - } - - createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - - suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") - createdVersionId, _ := converter.StringToInt64(createdVersion.Id) - - getById, err := service.GetModelVersionById(*createdVersion.Id) - suite.Nilf(err, "error getting model version with id %d", *createdVersionId) - - ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *createdVersionId, - }, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - - ctx := ctxById.Contexts[0] - suite.Equal(*converter.Int64ToString(ctx.Id), *getById.Id, "returned model version id should match the mlmd context one") - suite.Equal(modelVersion.Name, getById.Name, "saved model name should match the provided one") - suite.Equal(*modelVersion.ExternalId, *getById.ExternalId, "saved external id should match the provided one") - suite.Equal(*modelVersion.State, *getById.State, "saved model state should match the original one") - suite.Equal(*getById.Author, author, "saved author property should match the provided one") -} - -func (suite *CoreTestSuite) TestGetModelVersionByParamsWithNoResults() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - - _, err := service.GetModelVersionByParams(apiutils.Of("not-present"), ®isteredModelId, nil) - suite.NotNil(err) - suite.Equal("no model versions found for versionName=not-present, registeredModelId=1, externalId=: not found", err.Error()) -} - -func (suite *CoreTestSuite) TestGetModelVersionByParamsName() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - Author: &author, - } - - createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - - suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") - createdVersionId, _ := converter.StringToInt64(createdVersion.Id) - - getByName, err := service.GetModelVersionByParams(&modelVersionName, ®isteredModelId, nil) - suite.Nilf(err, "error getting model version by name %d", *createdVersionId) - - ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *createdVersionId, - }, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - - ctx := ctxById.Contexts[0] - suite.Equal(*converter.Int64ToString(ctx.Id), *getByName.Id, "returned model version id should match the mlmd context one") - suite.Equal(fmt.Sprintf("%s:%s", registeredModelId, getByName.Name), *ctx.Name, "saved model name should match the provided one") - suite.Equal(*ctx.ExternalId, *getByName.ExternalId, "saved external id should match the provided one") - suite.Equal(ctx.Properties["author"].GetStringValue(), *getByName.Author, "saved author property should match the provided one") -} - -func (suite *CoreTestSuite) TestGetModelVersionByParamsExternalId() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - Author: &author, - } - - createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - - suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") - createdVersionId, _ := converter.StringToInt64(createdVersion.Id) - - getByExternalId, err := service.GetModelVersionByParams(nil, nil, modelVersion.ExternalId) - suite.Nilf(err, "error getting model version by external id %d", *modelVersion.ExternalId) - - ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *createdVersionId, - }, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - - ctx := ctxById.Contexts[0] - suite.Equal(*converter.Int64ToString(ctx.Id), *getByExternalId.Id, "returned model version id should match the mlmd context one") - suite.Equal(fmt.Sprintf("%s:%s", registeredModelId, getByExternalId.Name), *ctx.Name, "saved model name should match the provided one") - suite.Equal(*ctx.ExternalId, *getByExternalId.ExternalId, "saved external id should match the provided one") - suite.Equal(ctx.Properties["author"].GetStringValue(), *getByExternalId.Author, "saved author property should match the provided one") -} - -func (suite *CoreTestSuite) TestGetModelVersionByEmptyParams() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - Author: &author, - } - - createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") - - _, err = service.GetModelVersionByParams(nil, nil, nil) - suite.NotNil(err) - suite.Equal("invalid parameters call, supply either (versionName and registeredModelId), or externalId: bad request", err.Error()) -} - -func (suite *CoreTestSuite) TestGetModelVersions() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - - modelVersion1 := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - } - - secondModelVersionName := "v2" - secondModelVersionExtId := "org.myawesomemodel@v2" - modelVersion2 := &openapi.ModelVersion{ - Name: secondModelVersionName, - ExternalId: &secondModelVersionExtId, - } - - thirdModelVersionName := "v3" - thirdModelVersionExtId := "org.myawesomemodel@v3" - modelVersion3 := &openapi.ModelVersion{ - Name: thirdModelVersionName, - ExternalId: &thirdModelVersionExtId, - } - - createdVersion1, err := service.UpsertModelVersion(modelVersion1, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - - createdVersion2, err := service.UpsertModelVersion(modelVersion2, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - - createdVersion3, err := service.UpsertModelVersion(modelVersion3, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - - anotherRegModelName := "AnotherModel" - anotherRegModelExtId := "org.another" - anotherRegisteredModelId := suite.registerModel(service, &anotherRegModelName, &anotherRegModelExtId) - - anotherModelVersionName := "v1.0" - anotherModelVersionExtId := "org.another@v1.0" - modelVersionAnother := &openapi.ModelVersion{ - Name: anotherModelVersionName, - ExternalId: &anotherModelVersionExtId, - } - - _, err = service.UpsertModelVersion(modelVersionAnother, &anotherRegisteredModelId) - suite.Nilf(err, "error creating new model version for %d", anotherRegisteredModelId) - - createdVersionId1, _ := converter.StringToInt64(createdVersion1.Id) - createdVersionId2, _ := converter.StringToInt64(createdVersion2.Id) - createdVersionId3, _ := converter.StringToInt64(createdVersion3.Id) - - getAll, err := service.GetModelVersions(api.ListOptions{}, nil) - suite.Nilf(err, "error getting all model versions") - suite.Equal(int32(4), getAll.Size, "expected four model versions across all registered models") - - getAllByRegModel, err := service.GetModelVersions(api.ListOptions{}, ®isteredModelId) - suite.Nilf(err, "error getting all model versions") - suite.Equalf(int32(3), getAllByRegModel.Size, "expected three model versions for registered model %d", registeredModelId) - - suite.Equal(*converter.Int64ToString(createdVersionId1), *getAllByRegModel.Items[0].Id) - suite.Equal(*converter.Int64ToString(createdVersionId2), *getAllByRegModel.Items[1].Id) - suite.Equal(*converter.Int64ToString(createdVersionId3), *getAllByRegModel.Items[2].Id) - - // order by last update time, expecting last created as first - orderByLastUpdate := "LAST_UPDATE_TIME" - getAllByRegModel, err = service.GetModelVersions(api.ListOptions{ - OrderBy: &orderByLastUpdate, - SortOrder: &descOrderDirection, - }, ®isteredModelId) - suite.Nilf(err, "error getting all model versions") - suite.Equalf(int32(3), getAllByRegModel.Size, "expected three model versions for registered model %d", registeredModelId) - - suite.Equal(*converter.Int64ToString(createdVersionId1), *getAllByRegModel.Items[2].Id) - suite.Equal(*converter.Int64ToString(createdVersionId2), *getAllByRegModel.Items[1].Id) - suite.Equal(*converter.Int64ToString(createdVersionId3), *getAllByRegModel.Items[0].Id) - - // update the second version - newVersionExternalId := "updated.org:v2" - createdVersion2.ExternalId = &newVersionExternalId - createdVersion2, err = service.UpsertModelVersion(createdVersion2, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - - suite.Equal(newVersionExternalId, *createdVersion2.ExternalId) - - getAllByRegModel, err = service.GetModelVersions(api.ListOptions{ - OrderBy: &orderByLastUpdate, - SortOrder: &descOrderDirection, - }, ®isteredModelId) - suite.Nilf(err, "error getting all model versions") - suite.Equalf(int32(3), getAllByRegModel.Size, "expected three model versions for registered model %d", registeredModelId) - - suite.Equal(*converter.Int64ToString(createdVersionId1), *getAllByRegModel.Items[2].Id) - suite.Equal(*converter.Int64ToString(createdVersionId2), *getAllByRegModel.Items[0].Id) - suite.Equal(*converter.Int64ToString(createdVersionId3), *getAllByRegModel.Items[1].Id) -} - -// ARTIFACTS - -func (suite *CoreTestSuite) TestCreateArtifact() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - createdArt, err := service.UpsertArtifact(&openapi.Artifact{ - DocArtifact: &openapi.DocArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - Description: &artifactDescription, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - }, - }, &modelVersionId) - suite.Nilf(err, "error creating new artifact for %d: %v", modelVersionId, err) - - docArtifact := createdArt.DocArtifact - suite.NotNilf(docArtifact, "error creating new artifact for %d", modelVersionId) - state, _ := openapi.NewArtifactStateFromValue(artifactState) - suite.NotNil(docArtifact.Id, "created artifact id should not be nil") - suite.Equal(artifactName, *docArtifact.Name) - suite.Equal(*state, *docArtifact.State) - suite.Equal(artifactUri, *docArtifact.Uri) - suite.Equal(artifactDescription, *docArtifact.Description) - suite.Equal(customString, (*docArtifact.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) -} - -func (suite *CoreTestSuite) TestCreateArtifactFailure() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := "9998" - - var artifact openapi.Artifact - artifact.DocArtifact = &openapi.DocArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - _, err := service.UpsertArtifact(&artifact, nil) - suite.NotNil(err) - suite.Equal("missing model version id, cannot create artifact without model version: bad request", err.Error()) - - _, err = service.UpsertArtifact(&artifact, &modelVersionId) - suite.NotNil(err) - suite.Equal("no model version found for id 9998: not found", err.Error()) -} - -func (suite *CoreTestSuite) TestUpdateArtifact() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - createdArtifact, err := service.UpsertArtifact(&openapi.Artifact{ - DocArtifact: &openapi.DocArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - }, - }, &modelVersionId) - suite.Nilf(err, "error creating new artifact for %d", modelVersionId) - - newState := "MARKED_FOR_DELETION" - createdArtifact.DocArtifact.State = (*openapi.ArtifactState)(&newState) - updatedArtifact, err := service.UpsertArtifact(createdArtifact, &modelVersionId) - suite.Nilf(err, "error updating artifact for %d: %v", modelVersionId, err) - - createdArtifactId, _ := converter.StringToInt64(createdArtifact.DocArtifact.Id) - updatedArtifactId, _ := converter.StringToInt64(updatedArtifact.DocArtifact.Id) - suite.Equal(createdArtifactId, updatedArtifactId) - - getById, err := suite.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{ - ArtifactIds: []int64{*createdArtifactId}, - }) - suite.Nilf(err, "error getting artifact by id %d", createdArtifactId) - - suite.Equal(*createdArtifactId, *getById.Artifacts[0].Id) - suite.Equal(fmt.Sprintf("%s:%s", modelVersionId, *createdArtifact.DocArtifact.Name), *getById.Artifacts[0].Name) - suite.Equal(string(newState), getById.Artifacts[0].State.String()) - suite.Equal(*createdArtifact.DocArtifact.Uri, *getById.Artifacts[0].Uri) - suite.Equal((*createdArtifact.DocArtifact.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, getById.Artifacts[0].CustomProperties["custom_string_prop"].GetStringValue()) -} - -func (suite *CoreTestSuite) TestUpdateArtifactFailure() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - createdArtifact, err := service.UpsertArtifact(&openapi.Artifact{ - DocArtifact: &openapi.DocArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - }, - }, &modelVersionId) - suite.Nilf(err, "error creating new artifact for model version %s", modelVersionId) - suite.NotNilf(createdArtifact.DocArtifact.Id, "created model artifact should not have nil Id") - - newState := "MARKED_FOR_DELETION" - createdArtifact.DocArtifact.State = (*openapi.ArtifactState)(&newState) - updatedArtifact, err := service.UpsertArtifact(createdArtifact, &modelVersionId) - suite.Nilf(err, "error updating artifact for %d: %v", modelVersionId, err) - - wrongId := "5555" - updatedArtifact.DocArtifact.Id = &wrongId - _, err = service.UpsertArtifact(updatedArtifact, &modelVersionId) - suite.NotNil(err) - suite.Equal(fmt.Sprintf("no artifact found for id %s: not found", wrongId), err.Error()) -} - -func (suite *CoreTestSuite) TestGetArtifactById() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - createdArtifact, err := service.UpsertArtifact(&openapi.Artifact{ - DocArtifact: &openapi.DocArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - }, - }, &modelVersionId) - suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) - - createdArtifactId, _ := converter.StringToInt64(createdArtifact.DocArtifact.Id) - - getById, err := service.GetArtifactById(*createdArtifact.DocArtifact.Id) - suite.Nilf(err, "error getting artifact by id %d", createdArtifactId) - - state, _ := openapi.NewArtifactStateFromValue(artifactState) - suite.NotNil(createdArtifact.DocArtifact.Id, "created artifact id should not be nil") - suite.Equal(artifactName, *getById.DocArtifact.Name) - suite.Equal(*state, *getById.DocArtifact.State) - suite.Equal(artifactUri, *getById.DocArtifact.Uri) - suite.Equal(customString, (*getById.DocArtifact.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) - - suite.Equal(*createdArtifact, *getById, "artifacts returned during creation and on get by id should be equal") -} - -func (suite *CoreTestSuite) TestGetArtifacts() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - secondArtifactName := "second-name" - secondArtifactExtId := "second-ext-id" - secondArtifactUri := "second-uri" - - createdArtifact1, err := service.UpsertArtifact(&openapi.Artifact{ - ModelArtifact: &openapi.ModelArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - ExternalId: &artifactExtId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - }, - }, &modelVersionId) - suite.Nilf(err, "error creating new artifact for %d", modelVersionId) - createdArtifact2, err := service.UpsertArtifact(&openapi.Artifact{ - DocArtifact: &openapi.DocArtifact{ - Name: &secondArtifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &secondArtifactUri, - ExternalId: &secondArtifactExtId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - }, - }, &modelVersionId) - suite.Nilf(err, "error creating new artifact for %d", modelVersionId) - - createdArtifactId1, _ := converter.StringToInt64(createdArtifact1.ModelArtifact.Id) - createdArtifactId2, _ := converter.StringToInt64(createdArtifact2.DocArtifact.Id) - - getAll, err := service.GetArtifacts(api.ListOptions{}, &modelVersionId) - suite.Nilf(err, "error getting all model artifacts") - suite.Equalf(int32(2), getAll.Size, "expected two artifacts") - - suite.Equal(*converter.Int64ToString(createdArtifactId1), *getAll.Items[0].ModelArtifact.Id) - suite.Equal(*converter.Int64ToString(createdArtifactId2), *getAll.Items[1].DocArtifact.Id) - - orderByLastUpdate := "LAST_UPDATE_TIME" - getAllByModelVersion, err := service.GetArtifacts(api.ListOptions{ - OrderBy: &orderByLastUpdate, - SortOrder: &descOrderDirection, - }, &modelVersionId) - suite.Nilf(err, "error getting all model artifacts for %d", modelVersionId) - suite.Equalf(int32(2), getAllByModelVersion.Size, "expected 2 artifacts for model version %d", modelVersionId) - - suite.Equal(*converter.Int64ToString(createdArtifactId1), *getAllByModelVersion.Items[1].ModelArtifact.Id) - suite.Equal(*converter.Int64ToString(createdArtifactId2), *getAllByModelVersion.Items[0].DocArtifact.Id) -} - -// MODEL ARTIFACTS - -func (suite *CoreTestSuite) TestCreateModelArtifact() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - modelArtifact, err := service.UpsertModelArtifact(&openapi.ModelArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - Description: &artifactDescription, - ModelFormatName: apiutils.Of("onnx"), - ModelFormatVersion: apiutils.Of("1"), - StorageKey: apiutils.Of("aws-connection-models"), - StoragePath: apiutils.Of("bucket"), - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - }, &modelVersionId) - suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) - - state, _ := openapi.NewArtifactStateFromValue(artifactState) - suite.NotNil(modelArtifact.Id, "created artifact id should not be nil") - suite.Equal(artifactName, *modelArtifact.Name) - suite.Equal(*state, *modelArtifact.State) - suite.Equal(artifactUri, *modelArtifact.Uri) - suite.Equal(artifactDescription, *modelArtifact.Description) - suite.Equal("onnx", *modelArtifact.ModelFormatName) - suite.Equal("1", *modelArtifact.ModelFormatVersion) - suite.Equal("aws-connection-models", *modelArtifact.StorageKey) - suite.Equal("bucket", *modelArtifact.StoragePath) - suite.Equal(customString, (*modelArtifact.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) -} - -func (suite *CoreTestSuite) TestCreateModelArtifactFailure() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := "9998" - - modelArtifact := &openapi.ModelArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - _, err := service.UpsertModelArtifact(modelArtifact, nil) - suite.NotNil(err) - suite.Equal("missing model version id, cannot create artifact without model version: bad request", err.Error()) - - _, err = service.UpsertModelArtifact(modelArtifact, &modelVersionId) - suite.NotNil(err) - suite.Equal("no model version found for id 9998: not found", err.Error()) -} - -func (suite *CoreTestSuite) TestUpdateModelArtifact() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - modelArtifact := &openapi.ModelArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdArtifact, err := service.UpsertModelArtifact(modelArtifact, &modelVersionId) - suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) - - newState := "MARKED_FOR_DELETION" - createdArtifact.State = (*openapi.ArtifactState)(&newState) - updatedArtifact, err := service.UpsertModelArtifact(createdArtifact, &modelVersionId) - suite.Nilf(err, "error updating model artifact for %d: %v", modelVersionId, err) - - createdArtifactId, _ := converter.StringToInt64(createdArtifact.Id) - updatedArtifactId, _ := converter.StringToInt64(updatedArtifact.Id) - suite.Equal(createdArtifactId, updatedArtifactId) - - getById, err := suite.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{ - ArtifactIds: []int64{*createdArtifactId}, - }) - suite.Nilf(err, "error getting model artifact by id %d", createdArtifactId) - - suite.Equal(*createdArtifactId, *getById.Artifacts[0].Id) - suite.Equal(fmt.Sprintf("%s:%s", modelVersionId, *createdArtifact.Name), *getById.Artifacts[0].Name) - suite.Equal(string(newState), getById.Artifacts[0].State.String()) - suite.Equal(*createdArtifact.Uri, *getById.Artifacts[0].Uri) - suite.Equal((*createdArtifact.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, getById.Artifacts[0].CustomProperties["custom_string_prop"].GetStringValue()) -} - -func (suite *CoreTestSuite) TestUpdateModelArtifactFailure() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - modelArtifact := &openapi.ModelArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdArtifact, err := service.UpsertModelArtifact(modelArtifact, &modelVersionId) - suite.Nilf(err, "error creating new model artifact for model version %s", modelVersionId) - suite.NotNilf(createdArtifact.Id, "created model artifact should not have nil Id") -} - -func (suite *CoreTestSuite) TestGetModelArtifactById() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - modelArtifact := &openapi.ModelArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdArtifact, err := service.UpsertModelArtifact(modelArtifact, &modelVersionId) - suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) - - createdArtifactId, _ := converter.StringToInt64(createdArtifact.Id) - - getById, err := service.GetModelArtifactById(*createdArtifact.Id) - suite.Nilf(err, "error getting model artifact by id %d", createdArtifactId) - - state, _ := openapi.NewArtifactStateFromValue(artifactState) - suite.NotNil(createdArtifact.Id, "created artifact id should not be nil") - suite.Equal(artifactName, *getById.Name) - suite.Equal(*state, *getById.State) - suite.Equal(artifactUri, *getById.Uri) - suite.Equal(customString, (*getById.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) - - suite.Equal(*createdArtifact, *getById, "artifacts returned during creation and on get by id should be equal") -} - -func (suite *CoreTestSuite) TestGetModelArtifactByParams() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - modelArtifact := &openapi.ModelArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - ExternalId: &artifactExtId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdArtifact, err := service.UpsertModelArtifact(modelArtifact, &modelVersionId) - suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) - - createdArtifactId, _ := converter.StringToInt64(createdArtifact.Id) - - state, _ := openapi.NewArtifactStateFromValue(artifactState) - - getByName, err := service.GetModelArtifactByParams(&artifactName, &modelVersionId, nil) - suite.Nilf(err, "error getting model artifact by id %d", createdArtifactId) - - suite.NotNil(createdArtifact.Id, "created artifact id should not be nil") - suite.Equal(artifactName, *getByName.Name) - suite.Equal(artifactExtId, *getByName.ExternalId) - suite.Equal(*state, *getByName.State) - suite.Equal(artifactUri, *getByName.Uri) - suite.Equal(customString, (*getByName.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) - - suite.Equal(*createdArtifact, *getByName, "artifacts returned during creation and on get by name should be equal") - - getByExtId, err := service.GetModelArtifactByParams(nil, nil, &artifactExtId) - suite.Nilf(err, "error getting model artifact by id %d", createdArtifactId) - - suite.NotNil(createdArtifact.Id, "created artifact id should not be nil") - suite.Equal(artifactName, *getByExtId.Name) - suite.Equal(artifactExtId, *getByExtId.ExternalId) - suite.Equal(*state, *getByExtId.State) - suite.Equal(artifactUri, *getByExtId.Uri) - suite.Equal(customString, (*getByExtId.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) - - suite.Equal(*createdArtifact, *getByExtId, "artifacts returned during creation and on get by ext id should be equal") -} - -func (suite *CoreTestSuite) TestGetModelArtifactByEmptyParams() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - modelArtifact := &openapi.ModelArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - ExternalId: &artifactExtId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - _, err := service.UpsertModelArtifact(modelArtifact, &modelVersionId) - suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) - - _, err = service.GetModelArtifactByParams(nil, nil, nil) - suite.NotNil(err) - suite.Equal("invalid parameters call, supply either (artifactName and modelVersionId), or externalId: bad request", err.Error()) -} - -func (suite *CoreTestSuite) TestGetModelArtifactByParamsWithNoResults() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - _, err := service.GetModelArtifactByParams(apiutils.Of("not-present"), &modelVersionId, nil) - suite.NotNil(err) - suite.Equal("no model artifacts found for artifactName=not-present, modelVersionId=2, externalId=: not found", err.Error()) -} - -func (suite *CoreTestSuite) TestGetModelArtifacts() { - // create mode registry service - service := suite.setupModelRegistryService() - - modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) - - modelArtifact1 := &openapi.ModelArtifact{ - Name: &artifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &artifactUri, - ExternalId: &artifactExtId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - secondArtifactName := "second-name" - secondArtifactExtId := "second-ext-id" - secondArtifactUri := "second-uri" - modelArtifact2 := &openapi.ModelArtifact{ - Name: &secondArtifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &secondArtifactUri, - ExternalId: &secondArtifactExtId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - thirdArtifactName := "third-name" - thirdArtifactExtId := "third-ext-id" - thirdArtifactUri := "third-uri" - modelArtifact3 := &openapi.ModelArtifact{ - Name: &thirdArtifactName, - State: (*openapi.ArtifactState)(&artifactState), - Uri: &thirdArtifactUri, - ExternalId: &thirdArtifactExtId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdArtifact1, err := service.UpsertModelArtifact(modelArtifact1, &modelVersionId) - suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) - createdArtifact2, err := service.UpsertModelArtifact(modelArtifact2, &modelVersionId) - suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) - createdArtifact3, err := service.UpsertModelArtifact(modelArtifact3, &modelVersionId) - suite.Nilf(err, "error creating new model artifact for %d", modelVersionId) - - createdArtifactId1, _ := converter.StringToInt64(createdArtifact1.Id) - createdArtifactId2, _ := converter.StringToInt64(createdArtifact2.Id) - createdArtifactId3, _ := converter.StringToInt64(createdArtifact3.Id) - - getAll, err := service.GetModelArtifacts(api.ListOptions{}, nil) - suite.Nilf(err, "error getting all model artifacts") - suite.Equalf(int32(3), getAll.Size, "expected three model artifacts") - - suite.Equal(*converter.Int64ToString(createdArtifactId1), *getAll.Items[0].Id) - suite.Equal(*converter.Int64ToString(createdArtifactId2), *getAll.Items[1].Id) - suite.Equal(*converter.Int64ToString(createdArtifactId3), *getAll.Items[2].Id) - - orderByLastUpdate := "LAST_UPDATE_TIME" - getAllByModelVersion, err := service.GetModelArtifacts(api.ListOptions{ - OrderBy: &orderByLastUpdate, - SortOrder: &descOrderDirection, - }, &modelVersionId) - suite.Nilf(err, "error getting all model artifacts for %d", modelVersionId) - suite.Equalf(int32(3), getAllByModelVersion.Size, "expected three model artifacts for model version %d", modelVersionId) - - suite.Equal(*converter.Int64ToString(createdArtifactId1), *getAllByModelVersion.Items[2].Id) - suite.Equal(*converter.Int64ToString(createdArtifactId2), *getAllByModelVersion.Items[1].Id) - suite.Equal(*converter.Int64ToString(createdArtifactId3), *getAllByModelVersion.Items[0].Id) -} - -// SERVING ENVIRONMENT - -func (suite *CoreTestSuite) TestCreateServingEnvironment() { - // create mode registry service - service := suite.setupModelRegistryService() - - // register a new ServingEnvironment - eut := &openapi.ServingEnvironment{ - Name: &entityName, - ExternalId: &entityExternalId, - Description: &entityDescription, - CustomProperties: &map[string]openapi.MetadataValue{ - "myCustomProp": { - MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), - }, - }, - } - - // test - createdEntity, err := service.UpsertServingEnvironment(eut) - - // checks - suite.Nilf(err, "error creating uut: %v", err) - suite.NotNilf(createdEntity.Id, "created uut should not have nil Id") - - createdEntityId, _ := converter.StringToInt64(createdEntity.Id) - ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{*createdEntityId}, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - - ctx := ctxById.Contexts[0] - ctxId := converter.Int64ToString(ctx.Id) - suite.Equal(*createdEntity.Id, *ctxId, "returned id should match the mlmd one") - suite.Equal(entityName, *ctx.Name, "saved name should match the provided one") - suite.Equal(entityExternalId, *ctx.ExternalId, "saved external id should match the provided one") - suite.Equal(entityDescription, ctx.Properties["description"].GetStringValue(), "saved description should match the provided one") - suite.Equal(myCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") - - getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) - suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) - suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") -} - -func (suite *CoreTestSuite) TestUpdateServingEnvironment() { - // create mode registry service - service := suite.setupModelRegistryService() - - // register a new ServingEnvironment - eut := &openapi.ServingEnvironment{ - Name: &entityName, - ExternalId: &entityExternalId, - CustomProperties: &map[string]openapi.MetadataValue{ - "myCustomProp": { - MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), - }, - }, - } - - // test - createdEntity, err := service.UpsertServingEnvironment(eut) - - // checks - suite.Nilf(err, "error creating uut: %v", err) - suite.NotNilf(createdEntity.Id, "created uut should not have nil Id") - createdEntityId, _ := converter.StringToInt64(createdEntity.Id) - - // checks created entity matches original one except for Id - suite.Equal(*eut.Name, *createdEntity.Name, "returned entity should match the original one") - suite.Equal(*eut.ExternalId, *createdEntity.ExternalId, "returned entity external id should match the original one") - suite.Equal(*eut.CustomProperties, *createdEntity.CustomProperties, "returned entity custom props should match the original one") - - // update existing entity - newExternalId := "newExternalId" - newCustomProp := "newCustomProp" - - createdEntity.ExternalId = &newExternalId - (*createdEntity.CustomProperties)["myCustomProp"] = openapi.MetadataValue{ - MetadataStringValue: converter.NewMetadataStringValue(newCustomProp), - } - - // update the entity - createdEntity, err = service.UpsertServingEnvironment(createdEntity) - suite.Nilf(err, "error creating uut: %v", err) - - // still one expected MLMD type - getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) - suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) - suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") - - ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{*createdEntityId}, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - - ctx := ctxById.Contexts[0] - ctxId := converter.Int64ToString(ctx.Id) - suite.Equal(*createdEntity.Id, *ctxId, "returned entity id should match the mlmd one") - suite.Equal(entityName, *ctx.Name, "saved entity name should match the provided one") - suite.Equal(newExternalId, *ctx.ExternalId, "saved external id should match the provided one") - suite.Equal(newCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") - - // update the entity under test, keeping nil name - newExternalId = "newNewExternalId" - createdEntity.ExternalId = &newExternalId - createdEntity.Name = nil - createdEntity, err = service.UpsertServingEnvironment(createdEntity) - suite.Nilf(err, "error creating entity: %v", err) - - // still one registered entity - getAllResp, err = suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) - suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) - suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") - - ctxById, err = suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{*createdEntityId}, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - - ctx = ctxById.Contexts[0] - ctxId = converter.Int64ToString(ctx.Id) - suite.Equal(*createdEntity.Id, *ctxId, "returned entity id should match the mlmd one") - suite.Equal(entityName, *ctx.Name, "saved entity name should match the provided one") - suite.Equal(newExternalId, *ctx.ExternalId, "saved external id should match the provided one") - suite.Equal(newCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") -} - -func (suite *CoreTestSuite) TestGetServingEnvironmentById() { - // create mode registry service - service := suite.setupModelRegistryService() - - // register a new entity - eut := &openapi.ServingEnvironment{ - Name: &entityName, - ExternalId: &entityExternalId, - CustomProperties: &map[string]openapi.MetadataValue{ - "myCustomProp": { - MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), - }, - }, - } - - // test - createdEntity, err := service.UpsertServingEnvironment(eut) - - // checks - suite.Nilf(err, "error creating eut: %v", err) - - getEntityById, err := service.GetServingEnvironmentById(*createdEntity.Id) - suite.Nilf(err, "error getting eut by id %s: %v", *createdEntity.Id, err) - - // checks created entity matches original one except for Id - suite.Equal(*eut.Name, *getEntityById.Name, "saved name should match the original one") - suite.Equal(*eut.ExternalId, *getEntityById.ExternalId, "saved external id should match the original one") - suite.Equal(*eut.CustomProperties, *getEntityById.CustomProperties, "saved custom props should match the original one") -} - -func (suite *CoreTestSuite) TestGetServingEnvironmentByParamsWithNoResults() { - // create mode registry service - service := suite.setupModelRegistryService() - - _, err := service.GetServingEnvironmentByParams(apiutils.Of("not-present"), nil) - suite.NotNil(err) - suite.Equal("no serving environments found for name=not-present, externalId=: not found", err.Error()) -} - -func (suite *CoreTestSuite) TestGetServingEnvironmentByParamsName() { - // create mode registry service - service := suite.setupModelRegistryService() - - // register a new ServingEnvironment - eut := &openapi.ServingEnvironment{ - Name: &entityName, - ExternalId: &entityExternalId, - } - - createdEntity, err := service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - byName, err := service.GetServingEnvironmentByParams(&entityName, nil) - suite.Nilf(err, "error getting ServingEnvironment by name: %v", err) - - suite.Equalf(*createdEntity.Id, *byName.Id, "the returned entity id should match the retrieved by name") -} - -func (suite *CoreTestSuite) TestGetServingEnvironmentByParamsExternalId() { - // create mode registry service - service := suite.setupModelRegistryService() - - // register a new ServingEnvironment - eut := &openapi.ServingEnvironment{ - Name: &entityName, - ExternalId: &entityExternalId, - } - - createdEntity, err := service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - byName, err := service.GetServingEnvironmentByParams(nil, &entityExternalId) - suite.Nilf(err, "error getting ServingEnvironment by external id: %v", err) - - suite.Equalf(*createdEntity.Id, *byName.Id, "the returned entity id should match the retrieved by name") -} - -func (suite *CoreTestSuite) TestGetServingEnvironmentByEmptyParams() { - // create mode registry service - service := suite.setupModelRegistryService() - - // register a new ServingEnvironment - eut := &openapi.ServingEnvironment{ - Name: &entityName, - ExternalId: &entityExternalId, - } - - _, err := service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - _, err = service.GetServingEnvironmentByParams(nil, nil) - suite.NotNil(err) - suite.Equal("invalid parameters call, supply either name or externalId: bad request", err.Error()) -} - -func (suite *CoreTestSuite) TestGetServingEnvironmentsOrderedById() { - // create mode registry service - service := suite.setupModelRegistryService() - - orderBy := "ID" - - // register a new ServingEnvironment - eut := &openapi.ServingEnvironment{ - Name: &entityName, - ExternalId: &entityExternalId, - } - - _, err := service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - newName := "Pricingentity2" - newExternalId := "myExternalId2" - eut.Name = &newName - eut.ExternalId = &newExternalId - _, err = service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - newName = "Pricingentity3" - newExternalId = "myExternalId3" - eut.Name = &newName - eut.ExternalId = &newExternalId - _, err = service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - orderedById, err := service.GetServingEnvironments(api.ListOptions{ - OrderBy: &orderBy, - SortOrder: &ascOrderDirection, - }) - suite.Nilf(err, "error getting ServingEnvironment: %v", err) - - suite.Equal(3, int(orderedById.Size)) - for i := 0; i < int(orderedById.Size)-1; i++ { - suite.Less(*orderedById.Items[i].Id, *orderedById.Items[i+1].Id) - } - - orderedById, err = service.GetServingEnvironments(api.ListOptions{ - OrderBy: &orderBy, - SortOrder: &descOrderDirection, - }) - suite.Nilf(err, "error getting ServingEnvironments: %v", err) - - suite.Equal(3, int(orderedById.Size)) - for i := 0; i < int(orderedById.Size)-1; i++ { - suite.Greater(*orderedById.Items[i].Id, *orderedById.Items[i+1].Id) - } -} - -func (suite *CoreTestSuite) TestGetServingEnvironmentsOrderedByLastUpdate() { - // create mode registry service - service := suite.setupModelRegistryService() - - orderBy := "LAST_UPDATE_TIME" - - // register a new ServingEnvironment - eut := &openapi.ServingEnvironment{ - Name: &entityName, - ExternalId: &entityExternalId, - } - - firstEntity, err := service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - newName := "Pricingentity2" - newExternalId := "myExternalId2" - eut.Name = &newName - eut.ExternalId = &newExternalId - secondEntity, err := service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - newName = "Pricingentity3" - newExternalId = "myExternalId3" - eut.Name = &newName - eut.ExternalId = &newExternalId - thirdEntity, err := service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - // update second entity - secondEntity.ExternalId = nil - _, err = service.UpsertServingEnvironment(secondEntity) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - orderedById, err := service.GetServingEnvironments(api.ListOptions{ - OrderBy: &orderBy, - SortOrder: &ascOrderDirection, - }) - suite.Nilf(err, "error getting ServingEnvironments: %v", err) - - suite.Equal(3, int(orderedById.Size)) - suite.Equal(*firstEntity.Id, *orderedById.Items[0].Id) - suite.Equal(*thirdEntity.Id, *orderedById.Items[1].Id) - suite.Equal(*secondEntity.Id, *orderedById.Items[2].Id) - - orderedById, err = service.GetServingEnvironments(api.ListOptions{ - OrderBy: &orderBy, - SortOrder: &descOrderDirection, - }) - suite.Nilf(err, "error getting ServingEnvironments: %v", err) - - suite.Equal(3, int(orderedById.Size)) - suite.Equal(*secondEntity.Id, *orderedById.Items[0].Id) - suite.Equal(*thirdEntity.Id, *orderedById.Items[1].Id) - suite.Equal(*firstEntity.Id, *orderedById.Items[2].Id) -} - -func (suite *CoreTestSuite) TestGetServingEnvironmentsWithPageSize() { - // create mode registry service - service := suite.setupModelRegistryService() - - pageSize := int32(1) - pageSize2 := int32(2) - entityName := "Pricingentity1" - entityExternalId := "myExternalId1" - - // register a new ServingEnvironment - eut := &openapi.ServingEnvironment{ - Name: &entityName, - ExternalId: &entityExternalId, - } - - firstEntity, err := service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating registered entity: %v", err) - - newName := "Pricingentity2" - newExternalId := "myExternalId2" - eut.Name = &newName - eut.ExternalId = &newExternalId - secondEntity, err := service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - newName = "Pricingentity3" - newExternalId = "myExternalId3" - eut.Name = &newName - eut.ExternalId = &newExternalId - thirdEntity, err := service.UpsertServingEnvironment(eut) - suite.Nilf(err, "error creating ServingEnvironment: %v", err) - - truncatedList, err := service.GetServingEnvironments(api.ListOptions{ - PageSize: &pageSize, - }) - suite.Nilf(err, "error getting ServingEnvironments: %v", err) - - suite.Equal(1, int(truncatedList.Size)) - suite.NotEqual("", truncatedList.NextPageToken, "next page token should not be empty") - suite.Equal(*firstEntity.Id, *truncatedList.Items[0].Id) - - truncatedList, err = service.GetServingEnvironments(api.ListOptions{ - PageSize: &pageSize2, - NextPageToken: &truncatedList.NextPageToken, - }) - suite.Nilf(err, "error getting ServingEnvironments: %v", err) - - suite.Equal(2, int(truncatedList.Size)) - suite.Equal("", truncatedList.NextPageToken, "next page token should be empty as list item returned") - suite.Equal(*secondEntity.Id, *truncatedList.Items[0].Id) - suite.Equal(*thirdEntity.Id, *truncatedList.Items[1].Id) -} - -// INFERENCE SERVICE - -func (suite *CoreTestSuite) TestCreateInferenceService() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - registeredModelId := suite.registerModel(service, nil, nil) - runtime := "model-server" - desiredState := openapi.INFERENCESERVICESTATE_DEPLOYED - - eut := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - Description: &entityDescription, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - Runtime: &runtime, - DesiredState: &desiredState, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity, err := service.UpsertInferenceService(eut) - suite.Nilf(err, "error creating new eut for %s: %v", parentResourceId, err) - - suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") - - createdEntityId, _ := converter.StringToInt64(createdEntity.Id) - - byId, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *createdEntityId, - }, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - suite.Equal(1, len(byId.Contexts), "there should be just one context saved in mlmd") - - suite.Equal(*createdEntityId, *byId.Contexts[0].Id, "returned id should match the mlmd one") - suite.Equal(fmt.Sprintf("%s:%s", parentResourceId, entityName), *byId.Contexts[0].Name, "saved name should match the provided one") - suite.Equal(entityExternalId2, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") - suite.Equal(customString, byId.Contexts[0].CustomProperties["custom_string_prop"].GetStringValue(), "saved custom_string_prop custom property should match the provided one") - suite.Equal(entityDescription, byId.Contexts[0].Properties["description"].GetStringValue(), "saved description should match the provided one") - suite.Equal(runtime, byId.Contexts[0].Properties["runtime"].GetStringValue(), "saved runtime should match the provided one") - suite.Equal(string(desiredState), byId.Contexts[0].Properties["desired_state"].GetStringValue(), "saved state should match the provided one") - suite.Equalf(*inferenceServiceTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *inferenceServiceTypeName) - - getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) - suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) - suite.Equal(3, len(getAllResp.Contexts), "there should be 3 contexts (RegisteredModel, ServingEnvironment, InferenceService) saved in mlmd") -} - -func (suite *CoreTestSuite) TestCreateInferenceServiceFailure() { - // create mode registry service - service := suite.setupModelRegistryService() - - eut := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - ServingEnvironmentId: "9999", - RegisteredModelId: "9998", - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - _, err := service.UpsertInferenceService(eut) - suite.NotNil(err) - suite.Equal("no serving environment found for id 9999: not found", err.Error()) - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - eut.ServingEnvironmentId = parentResourceId - - _, err = service.UpsertInferenceService(eut) - suite.NotNil(err) - suite.Equal("no registered model found for id 9998: not found", err.Error()) -} - -func (suite *CoreTestSuite) TestUpdateInferenceService() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - registeredModelId := suite.registerModel(service, nil, nil) - - eut := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - Description: &entityDescription, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity, err := service.UpsertInferenceService(eut) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - - suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") - - createdEntityId, _ := converter.StringToInt64(createdEntity.Id) - - newExternalId := "org.my_awesome_entity@v1" - newScore := 0.95 - - createdEntity.ExternalId = &newExternalId - (*createdEntity.CustomProperties)["score"] = openapi.MetadataValue{ - MetadataDoubleValue: converter.NewMetadataDoubleValue(newScore), - } - - updatedEntity, err := service.UpsertInferenceService(createdEntity) - suite.Nilf(err, "error updating new entity for %s: %v", registeredModelId, err) - - updateEntityId, _ := converter.StringToInt64(updatedEntity.Id) - suite.Equal(*createdEntityId, *updateEntityId, "created and updated should have same id") - - byId, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *updateEntityId, - }, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - suite.Equal(1, len(byId.Contexts), "there should be 1 context saved in mlmd by id") - - suite.Equal(*updateEntityId, *byId.Contexts[0].Id, "returned id should match the mlmd one") - suite.Equal(fmt.Sprintf("%s:%s", parentResourceId, *eut.Name), *byId.Contexts[0].Name, "saved name should match the provided one") - suite.Equal(newExternalId, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") - suite.Equal(customString, byId.Contexts[0].CustomProperties["custom_string_prop"].GetStringValue(), "saved custom_string_prop custom property should match the provided one") - suite.Equal(newScore, byId.Contexts[0].CustomProperties["score"].GetDoubleValue(), "saved score custom property should match the provided one") - suite.Equalf(*inferenceServiceTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *inferenceServiceTypeName) - - getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) - suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) - suite.Equal(3, len(getAllResp.Contexts), "there should be 3 contexts saved in mlmd") - - // update with nil name - newExternalId = "org.my_awesome_entity_@v1" - updatedEntity.ExternalId = &newExternalId - updatedEntity.Name = nil - updatedEntity, err = service.UpsertInferenceService(updatedEntity) - suite.Nilf(err, "error updating new model version for %s: %v", updateEntityId, err) - - updateEntityId, _ = converter.StringToInt64(updatedEntity.Id) - suite.Equal(*createdEntityId, *updateEntityId, "created and updated should have same id") - - byId, err = suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *updateEntityId, - }, - }) - suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) - suite.Equal(1, len(byId.Contexts), "there should be 1 context saved in mlmd by id") - - suite.Equal(*updateEntityId, *byId.Contexts[0].Id, "returned id should match the mlmd one") - suite.Equal(fmt.Sprintf("%s:%s", parentResourceId, *eut.Name), *byId.Contexts[0].Name, "saved name should match the provided one") - suite.Equal(newExternalId, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") - suite.Equal(customString, byId.Contexts[0].CustomProperties["custom_string_prop"].GetStringValue(), "saved custom_string_prop custom property should match the provided one") - suite.Equal(newScore, byId.Contexts[0].CustomProperties["score"].GetDoubleValue(), "saved score custom property should match the provided one") - suite.Equalf(*inferenceServiceTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *inferenceServiceTypeName) - - // update with empty registeredModelId - newExternalId = "org.my_awesome_entity_@v1" - prevRegModelId := updatedEntity.RegisteredModelId - updatedEntity.RegisteredModelId = "" - updatedEntity, err = service.UpsertInferenceService(updatedEntity) - suite.Nil(err) - suite.Equal(prevRegModelId, updatedEntity.RegisteredModelId) -} - -func (suite *CoreTestSuite) TestUpdateInferenceServiceFailure() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - registeredModelId := suite.registerModel(service, nil, nil) - - eut := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - Description: &entityDescription, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity, err := service.UpsertInferenceService(eut) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - - suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") - - newExternalId := "org.my_awesome_entity@v1" - newScore := 0.95 - - createdEntity.ExternalId = &newExternalId - (*createdEntity.CustomProperties)["score"] = openapi.MetadataValue{ - MetadataDoubleValue: converter.NewMetadataDoubleValue(newScore), - } - - wrongId := "9999" - createdEntity.Id = &wrongId - _, err = service.UpsertInferenceService(createdEntity) - suite.NotNil(err) - suite.Equal(fmt.Sprintf("no InferenceService found for id %s: not found", wrongId), err.Error()) -} - -func (suite *CoreTestSuite) TestGetInferenceServiceById() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - registeredModelId := suite.registerModel(service, nil, nil) - - state := openapi.INFERENCESERVICESTATE_UNDEPLOYED - eut := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - Description: &entityDescription, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - DesiredState: &state, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity, err := service.UpsertInferenceService(eut) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - - suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") - createdEntityId, _ := converter.StringToInt64(createdEntity.Id) - - getById, err := service.GetInferenceServiceById(*createdEntity.Id) - suite.Nilf(err, "error getting model version with id %d", *createdEntityId) - - ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *createdEntityId, - }, - }) - suite.Nilf(err, "error retrieving context, not related to the test itself: %v", err) - - ctx := ctxById.Contexts[0] - suite.Equal(*getById.Id, *converter.Int64ToString(ctx.Id), "returned id should match the mlmd context one") - suite.Equal(*eut.Name, *getById.Name, "saved name should match the provided one") - suite.Equal(*eut.ExternalId, *getById.ExternalId, "saved external id should match the provided one") - suite.Equal(*eut.DesiredState, *getById.DesiredState, "saved state should match the provided one") - suite.Equal((*getById.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, customString, "saved custom_string_prop custom property should match the provided one") -} - -func (suite *CoreTestSuite) TestGetRegisteredModelByInferenceServiceId() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - registeredModelId := suite.registerModel(service, nil, nil) - - eut := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - Description: &entityDescription, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - createdEntity, err := service.UpsertInferenceService(eut) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") - - getRM, err := service.GetRegisteredModelByInferenceService(*createdEntity.Id) - suite.Nilf(err, "error getting using id %s", *createdEntity.Id) - - suite.Equal(registeredModelId, *getRM.Id, "returned id should match the original registeredModelId") -} - -func (suite *CoreTestSuite) TestGetModelVersionByInferenceServiceId() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - registeredModelId := suite.registerModel(service, nil, nil) - - modelVersion1Name := "v1" - modelVersion1 := &openapi.ModelVersion{Name: modelVersion1Name, Description: &modelVersionDescription} - createdVersion1, err := service.UpsertModelVersion(modelVersion1, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - createdVersion1Id := *createdVersion1.Id - - modelVersion2Name := "v2" - modelVersion2 := &openapi.ModelVersion{Name: modelVersion2Name, Description: &modelVersionDescription} - createdVersion2, err := service.UpsertModelVersion(modelVersion2, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - createdVersion2Id := *createdVersion2.Id - // end of data preparation - - eut := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - Description: &entityDescription, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - ModelVersionId: nil, // first we test by unspecified - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - createdEntity, err := service.UpsertInferenceService(eut) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - - getVModel, err := service.GetModelVersionByInferenceService(*createdEntity.Id) - suite.Nilf(err, "error getting using id %s", *createdEntity.Id) - suite.Equal(createdVersion2Id, *getVModel.Id, "returned id shall be the latest ModelVersion by creation order") - - // here we used the returned entity (so ID is populated), and we update to specify the "ID of the ModelVersion to serve" - createdEntity.ModelVersionId = &createdVersion1Id - _, err = service.UpsertInferenceService(createdEntity) - suite.Nilf(err, "error updating eut for %v", parentResourceId) - - getVModel, err = service.GetModelVersionByInferenceService(*createdEntity.Id) - suite.Nilf(err, "error getting using id %s", *createdEntity.Id) - suite.Equal(createdVersion1Id, *getVModel.Id, "returned id shall be the specified one") -} - -func (suite *CoreTestSuite) TestGetModelArtifactByInferenceServiceId() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - registeredModelId := suite.registerModel(service, nil, nil) - - modelVersion1Name := "v1" - modelVersion1 := &openapi.ModelVersion{Name: modelVersion1Name, Description: &modelVersionDescription} - createdVersion1, err := service.UpsertModelVersion(modelVersion1, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %s", registeredModelId) - modelArtifact1Name := "v1-artifact" - modelArtifact1 := &openapi.ModelArtifact{Name: &modelArtifact1Name} - createdArtifact1, err := service.UpsertModelArtifact(modelArtifact1, createdVersion1.Id) - suite.Nilf(err, "error creating new model artifact for %s", *createdVersion1.Id) - - modelVersion2Name := "v2" - modelVersion2 := &openapi.ModelVersion{Name: modelVersion2Name, Description: &modelVersionDescription} - createdVersion2, err := service.UpsertModelVersion(modelVersion2, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %s", registeredModelId) - modelArtifact2Name := "v2-artifact" - modelArtifact2 := &openapi.ModelArtifact{Name: &modelArtifact2Name} - createdArtifact2, err := service.UpsertModelArtifact(modelArtifact2, createdVersion2.Id) - suite.Nilf(err, "error creating new model artifact for %s", *createdVersion2.Id) - // end of data preparation - - eut := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - Description: &entityDescription, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - ModelVersionId: nil, // first we test by unspecified - } - createdEntity, err := service.UpsertInferenceService(eut) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - - getModelArt, err := service.GetModelArtifactByInferenceService(*createdEntity.Id) - suite.Nilf(err, "error getting using id %s", *createdEntity.Id) - suite.Equal(*createdArtifact2.Id, *getModelArt.Id, "returned id shall be the latest ModelVersion by creation order") - - // here we used the returned entity (so ID is populated), and we update to specify the "ID of the ModelVersion to serve" - createdEntity.ModelVersionId = createdVersion1.Id - _, err = service.UpsertInferenceService(createdEntity) - suite.Nilf(err, "error updating eut for %v", parentResourceId) - - getModelArt, err = service.GetModelArtifactByInferenceService(*createdEntity.Id) - suite.Nilf(err, "error getting using id %s", *createdEntity.Id) - suite.Equal(*createdArtifact1.Id, *getModelArt.Id, "returned id shall be the specified one") -} - -func (suite *CoreTestSuite) TestGetInferenceServiceByParamsWithNoResults() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - - _, err := service.GetInferenceServiceByParams(apiutils.Of("not-present"), &parentResourceId, nil) - suite.NotNil(err) - suite.Equal("no inference services found for name=not-present, servingEnvironmentId=1, externalId=: not found", err.Error()) -} - -func (suite *CoreTestSuite) TestGetInferenceServiceByParamsName() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - registeredModelId := suite.registerModel(service, nil, nil) - - eut := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - Description: &entityDescription, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity, err := service.UpsertInferenceService(eut) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - - suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") - createdEntityId, _ := converter.StringToInt64(createdEntity.Id) - - getByName, err := service.GetInferenceServiceByParams(&entityName, &parentResourceId, nil) - suite.Nilf(err, "error getting model version by name %d", *createdEntityId) - - ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *createdEntityId, - }, - }) - suite.Nilf(err, "error retrieving context, not related to the test itself: %v", err) - - ctx := ctxById.Contexts[0] - suite.Equal(*converter.Int64ToString(ctx.Id), *getByName.Id, "returned id should match the mlmd context one") - suite.Equal(fmt.Sprintf("%s:%s", parentResourceId, *getByName.Name), *ctx.Name, "saved name should match the provided one") - suite.Equal(*ctx.ExternalId, *getByName.ExternalId, "saved external id should match the provided one") - suite.Equal(ctx.CustomProperties["custom_string_prop"].GetStringValue(), (*getByName.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, "saved custom_string_prop custom property should match the provided one") -} - -func (suite *CoreTestSuite) TestGetInfernenceServiceByParamsExternalId() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - registeredModelId := suite.registerModel(service, nil, nil) - - eut := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - Description: &entityDescription, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity, err := service.UpsertInferenceService(eut) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - - suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") - createdEntityId, _ := converter.StringToInt64(createdEntity.Id) - - getByExternalId, err := service.GetInferenceServiceByParams(nil, nil, eut.ExternalId) - suite.Nilf(err, "error getting by external id %d", *eut.ExternalId) - - ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ - ContextIds: []int64{ - *createdEntityId, - }, - }) - suite.Nilf(err, "error retrieving context, not related to the test itself: %v", err) - - ctx := ctxById.Contexts[0] - suite.Equal(*converter.Int64ToString(ctx.Id), *getByExternalId.Id, "returned id should match the mlmd context one") - suite.Equal(fmt.Sprintf("%s:%s", parentResourceId, *getByExternalId.Name), *ctx.Name, "saved name should match the provided one") - suite.Equal(*ctx.ExternalId, *getByExternalId.ExternalId, "saved external id should match the provided one") - suite.Equal(ctx.CustomProperties["custom_string_prop"].GetStringValue(), (*getByExternalId.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, "saved custom_string_prop custom property should match the provided one") -} - -func (suite *CoreTestSuite) TestGetInferenceServiceByEmptyParams() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - registeredModelId := suite.registerModel(service, nil, nil) - - eut := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - Description: &entityDescription, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity, err := service.UpsertInferenceService(eut) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - - suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") - - _, err = service.GetInferenceServiceByParams(nil, nil, nil) - suite.NotNil(err) - suite.Equal("invalid parameters call, supply either (name and servingEnvironmentId), or externalId: bad request", err.Error()) -} - -func (suite *CoreTestSuite) TestGetInferenceServices() { - // create mode registry service - service := suite.setupModelRegistryService() - - parentResourceId := suite.registerServingEnvironment(service, nil, nil) - registeredModelId := suite.registerModel(service, nil, nil) - - eut1 := &openapi.InferenceService{ - Name: &entityName, - ExternalId: &entityExternalId2, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - Runtime: apiutils.Of("model-server0"), - } - - secondName := "v2" - secondExtId := "org.myawesomeentity@v2" - eut2 := &openapi.InferenceService{ - Name: &secondName, - ExternalId: &secondExtId, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - Runtime: apiutils.Of("model-server1"), - } - - thirdName := "v3" - thirdExtId := "org.myawesomeentity@v3" - eut3 := &openapi.InferenceService{ - Name: &thirdName, - ExternalId: &thirdExtId, - ServingEnvironmentId: parentResourceId, - RegisteredModelId: registeredModelId, - Runtime: apiutils.Of("model-server2"), - } - - createdEntity1, err := service.UpsertInferenceService(eut1) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - - createdEntity2, err := service.UpsertInferenceService(eut2) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - - createdEntity3, err := service.UpsertInferenceService(eut3) - suite.Nilf(err, "error creating new eut for %v", parentResourceId) - - anotherParentResourceName := "AnotherModel" - anotherParentResourceExtId := "org.another" - anotherParentResourceId := suite.registerServingEnvironment(service, &anotherParentResourceName, &anotherParentResourceExtId) - - anotherName := "v1.0" - anotherExtId := "org.another@v1.0" - eutAnother := &openapi.InferenceService{ - Name: &anotherName, - ExternalId: &anotherExtId, - ServingEnvironmentId: anotherParentResourceId, - RegisteredModelId: registeredModelId, - Runtime: apiutils.Of("model-server3"), - } - - _, err = service.UpsertInferenceService(eutAnother) - suite.Nilf(err, "error creating new model version for %d", anotherParentResourceId) - - createdId1, _ := converter.StringToInt64(createdEntity1.Id) - createdId2, _ := converter.StringToInt64(createdEntity2.Id) - createdId3, _ := converter.StringToInt64(createdEntity3.Id) - - getAll, err := service.GetInferenceServices(api.ListOptions{}, nil, nil) - suite.Nilf(err, "error getting all") - suite.Equal(int32(4), getAll.Size, "expected 4 across all parent resources") - - getAllByParentResource, err := service.GetInferenceServices(api.ListOptions{}, &parentResourceId, nil) - suite.Nilf(err, "error getting all") - suite.Equalf(int32(3), getAllByParentResource.Size, "expected 3 for parent resource %d", parentResourceId) - - suite.Equal(*converter.Int64ToString(createdId1), *getAllByParentResource.Items[0].Id) - suite.Equal(*converter.Int64ToString(createdId2), *getAllByParentResource.Items[1].Id) - suite.Equal(*converter.Int64ToString(createdId3), *getAllByParentResource.Items[2].Id) - - modelServer := "model-server1" - getAllByParentResourceAndRuntime, err := service.GetInferenceServices(api.ListOptions{}, &parentResourceId, &modelServer) - suite.Nilf(err, "error getting all") - suite.Equalf(int32(1), getAllByParentResourceAndRuntime.Size, "expected 1 for parent resource %s and runtime %s", parentResourceId, modelServer) - - suite.Equal(*converter.Int64ToString(createdId1), *getAllByParentResource.Items[0].Id) - - // order by last update time, expecting last created as first - orderByLastUpdate := "LAST_UPDATE_TIME" - getAllByParentResource, err = service.GetInferenceServices(api.ListOptions{ - OrderBy: &orderByLastUpdate, - SortOrder: &descOrderDirection, - }, &parentResourceId, nil) - suite.Nilf(err, "error getting all") - suite.Equalf(int32(3), getAllByParentResource.Size, "expected 3 for parent resource %d", parentResourceId) - - suite.Equal(*converter.Int64ToString(createdId1), *getAllByParentResource.Items[2].Id) - suite.Equal(*converter.Int64ToString(createdId2), *getAllByParentResource.Items[1].Id) - suite.Equal(*converter.Int64ToString(createdId3), *getAllByParentResource.Items[0].Id) - - // update the second entity - newExternalId := "updated.org:v2" - createdEntity2.ExternalId = &newExternalId - createdEntity2, err = service.UpsertInferenceService(createdEntity2) - suite.Nilf(err, "error creating new eut2 for %d", parentResourceId) - - suite.Equal(newExternalId, *createdEntity2.ExternalId) - - getAllByParentResource, err = service.GetInferenceServices(api.ListOptions{ - OrderBy: &orderByLastUpdate, - SortOrder: &descOrderDirection, - }, &parentResourceId, nil) - suite.Nilf(err, "error getting all") - suite.Equalf(int32(3), getAllByParentResource.Size, "expected 3 for parent resource %d", parentResourceId) - - suite.Equal(*converter.Int64ToString(createdId1), *getAllByParentResource.Items[2].Id) - suite.Equal(*converter.Int64ToString(createdId2), *getAllByParentResource.Items[0].Id) - suite.Equal(*converter.Int64ToString(createdId3), *getAllByParentResource.Items[1].Id) -} - -// SERVE MODEL - -func (suite *CoreTestSuite) TestCreateServeModel() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) - - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - Description: &modelVersionDescription, - Author: &author, - } - createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - createdVersionId := *createdVersion.Id - createdVersionIdAsInt, _ := converter.StringToInt64(&createdVersionId) - // end of data preparation - - eut := &openapi.ServeModel{ - LastKnownState: (*openapi.ExecutionState)(&executionState), - ExternalId: &entityExternalId2, - Description: &entityDescription, - Name: &entityName, - ModelVersionId: createdVersionId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity, err := service.UpsertServeModel(eut, &inferenceServiceId) - suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) - suite.NotNil(createdEntity.Id, "created id should not be nil") - - state, _ := openapi.NewExecutionStateFromValue(executionState) - suite.Equal(entityName, *createdEntity.Name) - suite.Equal(*state, *createdEntity.LastKnownState) - suite.Equal(createdVersionId, createdEntity.ModelVersionId) - suite.Equal(entityDescription, *createdEntity.Description) - suite.Equal(customString, (*createdEntity.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) - - createdEntityId, _ := converter.StringToInt64(createdEntity.Id) - getById, err := suite.mlmdClient.GetExecutionsByID(context.Background(), &proto.GetExecutionsByIDRequest{ - ExecutionIds: []int64{*createdEntityId}, - }) - suite.Nilf(err, "error getting Execution by id %d", createdEntityId) - - suite.Equal(*createdEntityId, *getById.Executions[0].Id) - suite.Equal(fmt.Sprintf("%s:%s", inferenceServiceId, *createdEntity.Name), *getById.Executions[0].Name) - suite.Equal(string(*createdEntity.LastKnownState), getById.Executions[0].LastKnownState.String()) - suite.Equal(*createdVersionIdAsInt, getById.Executions[0].Properties["model_version_id"].GetIntValue()) - suite.Equal(*createdEntity.Description, getById.Executions[0].Properties["description"].GetStringValue()) - suite.Equal((*createdEntity.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, getById.Executions[0].CustomProperties["custom_string_prop"].GetStringValue()) - - inferenceServiceIdAsInt, _ := converter.StringToInt64(&inferenceServiceId) - byCtx, _ := suite.mlmdClient.GetExecutionsByContext(context.Background(), &proto.GetExecutionsByContextRequest{ - ContextId: (*int64)(inferenceServiceIdAsInt), - }) - suite.Equal(1, len(byCtx.Executions)) - suite.Equal(*createdEntityId, *byCtx.Executions[0].Id) -} - -func (suite *CoreTestSuite) TestCreateServeModelFailure() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) - // end of data preparation - - eut := &openapi.ServeModel{ - LastKnownState: (*openapi.ExecutionState)(&executionState), - ExternalId: &entityExternalId2, - Description: &entityDescription, - Name: &entityName, - ModelVersionId: "9998", - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - _, err := service.UpsertServeModel(eut, nil) - suite.NotNil(err) - suite.Equal("missing inferenceServiceId, cannot create ServeModel without parent resource InferenceService: bad request", err.Error()) - - _, err = service.UpsertServeModel(eut, &inferenceServiceId) - suite.NotNil(err) - suite.Equal("no model version found for id 9998: not found", err.Error()) -} - -func (suite *CoreTestSuite) TestUpdateServeModel() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) - - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - Description: &modelVersionDescription, - Author: &author, - } - createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - createdVersionId := *createdVersion.Id - createdVersionIdAsInt, _ := converter.StringToInt64(&createdVersionId) - // end of data preparation - - eut := &openapi.ServeModel{ - LastKnownState: (*openapi.ExecutionState)(&executionState), - ExternalId: &entityExternalId2, - Description: &entityDescription, - Name: &entityName, - ModelVersionId: createdVersionId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity, err := service.UpsertServeModel(eut, &inferenceServiceId) - suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) - - newState := "UNKNOWN" - createdEntity.LastKnownState = (*openapi.ExecutionState)(&newState) - updatedEntity, err := service.UpsertServeModel(createdEntity, &inferenceServiceId) - suite.Nilf(err, "error updating entity for %d: %v", inferenceServiceId, err) - - createdEntityId, _ := converter.StringToInt64(createdEntity.Id) - updatedEntityId, _ := converter.StringToInt64(updatedEntity.Id) - suite.Equal(createdEntityId, updatedEntityId) - - getById, err := suite.mlmdClient.GetExecutionsByID(context.Background(), &proto.GetExecutionsByIDRequest{ - ExecutionIds: []int64{*createdEntityId}, - }) - suite.Nilf(err, "error getting by id %d", createdEntityId) - - suite.Equal(*createdEntityId, *getById.Executions[0].Id) - suite.Equal(fmt.Sprintf("%s:%s", inferenceServiceId, *createdEntity.Name), *getById.Executions[0].Name) - suite.Equal(string(newState), getById.Executions[0].LastKnownState.String()) - suite.Equal(*createdVersionIdAsInt, getById.Executions[0].Properties["model_version_id"].GetIntValue()) - suite.Equal((*createdEntity.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, getById.Executions[0].CustomProperties["custom_string_prop"].GetStringValue()) - - prevModelVersionId := updatedEntity.ModelVersionId - updatedEntity.ModelVersionId = "" - updatedEntity, err = service.UpsertServeModel(updatedEntity, &inferenceServiceId) - suite.Nilf(err, "error updating entity for %d: %v", inferenceServiceId, err) - suite.Equal(prevModelVersionId, updatedEntity.ModelVersionId) -} - -func (suite *CoreTestSuite) TestUpdateServeModelFailure() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) - - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - Description: &modelVersionDescription, - Author: &author, - } - createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - createdVersionId := *createdVersion.Id - // end of data preparation - - eut := &openapi.ServeModel{ - LastKnownState: (*openapi.ExecutionState)(&executionState), - ExternalId: &entityExternalId2, - Description: &entityDescription, - Name: &entityName, - ModelVersionId: createdVersionId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity, err := service.UpsertServeModel(eut, &inferenceServiceId) - suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) - suite.NotNil(createdEntity.Id, "created id should not be nil") - - newState := "UNKNOWN" - createdEntity.LastKnownState = (*openapi.ExecutionState)(&newState) - updatedEntity, err := service.UpsertServeModel(createdEntity, &inferenceServiceId) - suite.Nilf(err, "error updating entity for %d: %v", inferenceServiceId, err) - - wrongId := "9998" - updatedEntity.Id = &wrongId - _, err = service.UpsertServeModel(updatedEntity, &inferenceServiceId) - suite.NotNil(err) - suite.Equal(fmt.Sprintf("no ServeModel found for id %s: not found", wrongId), err.Error()) -} - -func (suite *CoreTestSuite) TestGetServeModelById() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) - - modelVersion := &openapi.ModelVersion{ - Name: modelVersionName, - ExternalId: &versionExternalId, - Description: &modelVersionDescription, - Author: &author, - } - createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - createdVersionId := *createdVersion.Id - // end of data preparation - - eut := &openapi.ServeModel{ - LastKnownState: (*openapi.ExecutionState)(&executionState), - ExternalId: &entityExternalId2, - Description: &entityDescription, - Name: &entityName, - ModelVersionId: createdVersionId, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity, err := service.UpsertServeModel(eut, &inferenceServiceId) - suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) - - getById, err := service.GetServeModelById(*createdEntity.Id) - suite.Nilf(err, "error getting entity by id %d", *createdEntity.Id) - - state, _ := openapi.NewExecutionStateFromValue(executionState) - suite.NotNil(createdEntity.Id, "created artifact id should not be nil") - suite.Equal(entityName, *getById.Name) - suite.Equal(*state, *getById.LastKnownState) - suite.Equal(createdVersionId, getById.ModelVersionId) - suite.Equal(customString, (*getById.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) - - suite.Equal(*createdEntity, *getById, "artifacts returned during creation and on get by id should be equal") -} - -func (suite *CoreTestSuite) TestGetServeModels() { - // create mode registry service - service := suite.setupModelRegistryService() - - registeredModelId := suite.registerModel(service, nil, nil) - inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) - - modelVersion1Name := "v1" - modelVersion1 := &openapi.ModelVersion{Name: modelVersion1Name, Description: &modelVersionDescription} - createdVersion1, err := service.UpsertModelVersion(modelVersion1, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - createdVersion1Id := *createdVersion1.Id - - modelVersion2Name := "v2" - modelVersion2 := &openapi.ModelVersion{Name: modelVersion2Name, Description: &modelVersionDescription} - createdVersion2, err := service.UpsertModelVersion(modelVersion2, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - createdVersion2Id := *createdVersion2.Id - - modelVersion3Name := "v3" - modelVersion3 := &openapi.ModelVersion{Name: modelVersion3Name, Description: &modelVersionDescription} - createdVersion3, err := service.UpsertModelVersion(modelVersion3, ®isteredModelId) - suite.Nilf(err, "error creating new model version for %d", registeredModelId) - createdVersion3Id := *createdVersion3.Id - // end of data preparation - - eut1Name := "sm1" - eut1 := &openapi.ServeModel{ - LastKnownState: (*openapi.ExecutionState)(&executionState), - Description: &entityDescription, - Name: &eut1Name, - ModelVersionId: createdVersion1Id, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - eut2Name := "sm2" - eut2 := &openapi.ServeModel{ - LastKnownState: (*openapi.ExecutionState)(&executionState), - Description: &entityDescription, - Name: &eut2Name, - ModelVersionId: createdVersion2Id, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - eut3Name := "sm3" - eut3 := &openapi.ServeModel{ - LastKnownState: (*openapi.ExecutionState)(&executionState), - Description: &entityDescription, - Name: &eut3Name, - ModelVersionId: createdVersion3Id, - CustomProperties: &map[string]openapi.MetadataValue{ - "custom_string_prop": { - MetadataStringValue: converter.NewMetadataStringValue(customString), - }, - }, - } - - createdEntity1, err := service.UpsertServeModel(eut1, &inferenceServiceId) - suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) - createdEntity2, err := service.UpsertServeModel(eut2, &inferenceServiceId) - suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) - createdEntity3, err := service.UpsertServeModel(eut3, &inferenceServiceId) - suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) - - createdEntityId1, _ := converter.StringToInt64(createdEntity1.Id) - createdEntityId2, _ := converter.StringToInt64(createdEntity2.Id) - createdEntityId3, _ := converter.StringToInt64(createdEntity3.Id) - - getAll, err := service.GetServeModels(api.ListOptions{}, nil) - suite.Nilf(err, "error getting all ServeModel") - suite.Equalf(int32(3), getAll.Size, "expected three ServeModel") - - suite.Equal(*converter.Int64ToString(createdEntityId1), *getAll.Items[0].Id) - suite.Equal(*converter.Int64ToString(createdEntityId2), *getAll.Items[1].Id) - suite.Equal(*converter.Int64ToString(createdEntityId3), *getAll.Items[2].Id) - - orderByLastUpdate := "LAST_UPDATE_TIME" - getAllByInferenceService, err := service.GetServeModels(api.ListOptions{ - OrderBy: &orderByLastUpdate, - SortOrder: &descOrderDirection, - }, &inferenceServiceId) - suite.Nilf(err, "error getting all ServeModels for %d", inferenceServiceId) - suite.Equalf(int32(3), getAllByInferenceService.Size, "expected three ServeModels for InferenceServiceId %d", inferenceServiceId) - - suite.Equal(*converter.Int64ToString(createdEntityId1), *getAllByInferenceService.Items[2].Id) - suite.Equal(*converter.Int64ToString(createdEntityId2), *getAllByInferenceService.Items[1].Id) - suite.Equal(*converter.Int64ToString(createdEntityId3), *getAllByInferenceService.Items[0].Id) -} diff --git a/pkg/core/inference_service.go b/pkg/core/inference_service.go new file mode 100644 index 00000000..d7584bbc --- /dev/null +++ b/pkg/core/inference_service.go @@ -0,0 +1,253 @@ +package core + +import ( + "context" + "fmt" + "strings" + + "github.com/golang/glog" + "github.com/kubeflow/model-registry/internal/apiutils" + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// INFERENCE SERVICE + +// UpsertInferenceService creates a new inference service if the provided inference service's ID is nil, +// or updates an existing inference service if the ID is provided. +func (serv *ModelRegistryService) UpsertInferenceService(inferenceService *openapi.InferenceService) (*openapi.InferenceService, error) { + var err error + var existing *openapi.InferenceService + var servingEnvironment *openapi.ServingEnvironment + + if inferenceService.Id == nil { + // create + glog.Info("Creating new InferenceService") + servingEnvironment, err = serv.GetServingEnvironmentById(inferenceService.ServingEnvironmentId) + if err != nil { + return nil, err + } + } else { + // update + glog.Infof("Updating InferenceService %s", *inferenceService.Id) + + existing, err = serv.GetInferenceServiceById(*inferenceService.Id) + if err != nil { + return nil, err + } + + withNotEditable, err := serv.openapiConv.OverrideNotEditableForInferenceService(converter.NewOpenapiUpdateWrapper(existing, inferenceService)) + if err != nil { + return nil, err + } + inferenceService = &withNotEditable + + servingEnvironment, err = serv.getServingEnvironmentByInferenceServiceId(*inferenceService.Id) + if err != nil { + return nil, err + } + } + + // validate RegisteredModelId is also valid + if _, err := serv.GetRegisteredModelById(inferenceService.RegisteredModelId); err != nil { + return nil, err + } + + // if already existing assure the name is the same + if existing != nil && inferenceService.Name == nil { + // user did not provide it + // need to set it to avoid mlmd error "context name should not be empty" + inferenceService.Name = existing.Name + } + + protoCtx, err := serv.mapper.MapFromInferenceService(inferenceService, *servingEnvironment.Id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + protoCtxResp, err := serv.mlmdClient.PutContexts(context.Background(), &proto.PutContextsRequest{ + Contexts: []*proto.Context{ + protoCtx, + }, + }) + if err != nil { + return nil, err + } + + inferenceServiceId := &protoCtxResp.ContextIds[0] + if inferenceService.Id == nil { + servingEnvironmentId, err := converter.StringToInt64(servingEnvironment.Id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + _, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{ + ParentContexts: []*proto.ParentContext{{ + ChildId: inferenceServiceId, + ParentId: servingEnvironmentId, + }}, + TransactionOptions: &proto.TransactionOptions{}, + }) + if err != nil { + return nil, err + } + } + + idAsString := converter.Int64ToString(inferenceServiceId) + toReturn, err := serv.GetInferenceServiceById(*idAsString) + if err != nil { + return nil, err + } + + return toReturn, nil +} + +// getServingEnvironmentByInferenceServiceId retrieves the serving environment associated with the specified inference service ID. +func (serv *ModelRegistryService) getServingEnvironmentByInferenceServiceId(id string) (*openapi.ServingEnvironment, error) { + glog.Infof("Getting ServingEnvironment for InferenceService %s", id) + + idAsInt, err := converter.StringToInt64(&id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + getParentResp, err := serv.mlmdClient.GetParentContextsByContext(context.Background(), &proto.GetParentContextsByContextRequest{ + ContextId: idAsInt, + }) + if err != nil { + return nil, err + } + + if len(getParentResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple ServingEnvironments found for InferenceService %s: %w", id, api.ErrNotFound) + } + + if len(getParentResp.Contexts) == 0 { + return nil, fmt.Errorf("no ServingEnvironments found for InferenceService %s: %w", id, api.ErrNotFound) + } + + toReturn, err := serv.mapper.MapToServingEnvironment(getParentResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return toReturn, nil +} + +// GetInferenceServiceById retrieves an inference service by its unique identifier (ID). +func (serv *ModelRegistryService) GetInferenceServiceById(id string) (*openapi.InferenceService, error) { + glog.Infof("Getting InferenceService by id %s", id) + + idAsInt, err := converter.StringToInt64(&id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{*idAsInt}, + }) + if err != nil { + return nil, err + } + + if len(getByIdResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple InferenceServices found for id %s: %w", id, api.ErrNotFound) + } + + if len(getByIdResp.Contexts) == 0 { + return nil, fmt.Errorf("no InferenceService found for id %s: %w", id, api.ErrNotFound) + } + + toReturn, err := serv.mapper.MapToInferenceService(getByIdResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return toReturn, nil +} + +// GetInferenceServiceByParams retrieves an inference service based on specified parameters, such as (name and serving environment ID), or external ID. +// If multiple or no serving environments are found, an error is returned accordingly. +func (serv *ModelRegistryService) GetInferenceServiceByParams(name *string, servingEnvironmentId *string, externalId *string) (*openapi.InferenceService, error) { + filterQuery := "" + if name != nil && servingEnvironmentId != nil { + filterQuery = fmt.Sprintf("name = \"%s\"", converter.PrefixWhenOwned(servingEnvironmentId, *name)) + } else if externalId != nil { + filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) + } else { + return nil, fmt.Errorf("invalid parameters call, supply either (name and servingEnvironmentId), or externalId: %w", api.ErrBadRequest) + } + + getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ + TypeName: &serv.nameConfig.InferenceServiceTypeName, + Options: &proto.ListOperationOptions{ + FilterQuery: &filterQuery, + }, + }) + if err != nil { + return nil, err + } + + if len(getByParamsResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple inference services found for name=%v, servingEnvironmentId=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(servingEnvironmentId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + if len(getByParamsResp.Contexts) == 0 { + return nil, fmt.Errorf("no inference services found for name=%v, servingEnvironmentId=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(servingEnvironmentId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + toReturn, err := serv.mapper.MapToInferenceService(getByParamsResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + return toReturn, nil +} + +// GetInferenceServices retrieves a list of inference services based on the provided list options and optional serving environment ID and runtime. +func (serv *ModelRegistryService) GetInferenceServices(listOptions api.ListOptions, servingEnvironmentId *string, runtime *string) (*openapi.InferenceServiceList, error) { + listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + queries := []string{} + if servingEnvironmentId != nil { + queryParentCtxId := fmt.Sprintf("parent_contexts_a.id = %s", *servingEnvironmentId) + queries = append(queries, queryParentCtxId) + } + + if runtime != nil { + queryRuntimeProp := fmt.Sprintf("properties.runtime.string_value = \"%s\"", *runtime) + queries = append(queries, queryRuntimeProp) + } + + query := strings.Join(queries, " and ") + listOperationOptions.FilterQuery = &query + + contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ + TypeName: &serv.nameConfig.InferenceServiceTypeName, + Options: listOperationOptions, + }) + if err != nil { + return nil, err + } + + results := []openapi.InferenceService{} + for _, c := range contextsResp.Contexts { + mapped, err := serv.mapper.MapToInferenceService(c) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + results = append(results, *mapped) + } + + toReturn := openapi.InferenceServiceList{ + NextPageToken: apiutils.ZeroIfNil(contextsResp.NextPageToken), + PageSize: apiutils.ZeroIfNil(listOptions.PageSize), + Size: int32(len(results)), + Items: results, + } + return &toReturn, nil +} diff --git a/pkg/core/inference_service_test.go b/pkg/core/inference_service_test.go new file mode 100644 index 00000000..4a76560e --- /dev/null +++ b/pkg/core/inference_service_test.go @@ -0,0 +1,649 @@ +package core + +import ( + "context" + "fmt" + + "github.com/kubeflow/model-registry/internal/apiutils" + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// INFERENCE SERVICE + +func (suite *CoreTestSuite) TestCreateInferenceService() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + registeredModelId := suite.registerModel(service, nil, nil) + runtime := "model-server" + desiredState := openapi.INFERENCESERVICESTATE_DEPLOYED + + eut := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + Description: &entityDescription, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + Runtime: &runtime, + DesiredState: &desiredState, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity, err := service.UpsertInferenceService(eut) + suite.Nilf(err, "error creating new eut for %s: %v", parentResourceId, err) + + suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") + + createdEntityId, _ := converter.StringToInt64(createdEntity.Id) + + byId, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *createdEntityId, + }, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + suite.Equal(1, len(byId.Contexts), "there should be just one context saved in mlmd") + + suite.Equal(*createdEntityId, *byId.Contexts[0].Id, "returned id should match the mlmd one") + suite.Equal(fmt.Sprintf("%s:%s", parentResourceId, entityName), *byId.Contexts[0].Name, "saved name should match the provided one") + suite.Equal(entityExternalId2, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") + suite.Equal(customString, byId.Contexts[0].CustomProperties["custom_string_prop"].GetStringValue(), "saved custom_string_prop custom property should match the provided one") + suite.Equal(entityDescription, byId.Contexts[0].Properties["description"].GetStringValue(), "saved description should match the provided one") + suite.Equal(runtime, byId.Contexts[0].Properties["runtime"].GetStringValue(), "saved runtime should match the provided one") + suite.Equal(string(desiredState), byId.Contexts[0].Properties["desired_state"].GetStringValue(), "saved state should match the provided one") + suite.Equalf(*inferenceServiceTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *inferenceServiceTypeName) + + getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) + suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) + suite.Equal(3, len(getAllResp.Contexts), "there should be 3 contexts (RegisteredModel, ServingEnvironment, InferenceService) saved in mlmd") +} + +func (suite *CoreTestSuite) TestCreateInferenceServiceFailure() { + // create mode registry service + service := suite.setupModelRegistryService() + + eut := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + ServingEnvironmentId: "9999", + RegisteredModelId: "9998", + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + _, err := service.UpsertInferenceService(eut) + suite.NotNil(err) + suite.Equal("no serving environment found for id 9999: not found", err.Error()) + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + eut.ServingEnvironmentId = parentResourceId + + _, err = service.UpsertInferenceService(eut) + suite.NotNil(err) + suite.Equal("no registered model found for id 9998: not found", err.Error()) +} + +func (suite *CoreTestSuite) TestUpdateInferenceService() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + registeredModelId := suite.registerModel(service, nil, nil) + + eut := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + Description: &entityDescription, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity, err := service.UpsertInferenceService(eut) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + + suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") + + createdEntityId, _ := converter.StringToInt64(createdEntity.Id) + + newExternalId := "org.my_awesome_entity@v1" + newScore := 0.95 + + createdEntity.ExternalId = &newExternalId + (*createdEntity.CustomProperties)["score"] = openapi.MetadataValue{ + MetadataDoubleValue: converter.NewMetadataDoubleValue(newScore), + } + + updatedEntity, err := service.UpsertInferenceService(createdEntity) + suite.Nilf(err, "error updating new entity for %s: %v", registeredModelId, err) + + updateEntityId, _ := converter.StringToInt64(updatedEntity.Id) + suite.Equal(*createdEntityId, *updateEntityId, "created and updated should have same id") + + byId, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *updateEntityId, + }, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + suite.Equal(1, len(byId.Contexts), "there should be 1 context saved in mlmd by id") + + suite.Equal(*updateEntityId, *byId.Contexts[0].Id, "returned id should match the mlmd one") + suite.Equal(fmt.Sprintf("%s:%s", parentResourceId, *eut.Name), *byId.Contexts[0].Name, "saved name should match the provided one") + suite.Equal(newExternalId, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") + suite.Equal(customString, byId.Contexts[0].CustomProperties["custom_string_prop"].GetStringValue(), "saved custom_string_prop custom property should match the provided one") + suite.Equal(newScore, byId.Contexts[0].CustomProperties["score"].GetDoubleValue(), "saved score custom property should match the provided one") + suite.Equalf(*inferenceServiceTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *inferenceServiceTypeName) + + getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) + suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) + suite.Equal(3, len(getAllResp.Contexts), "there should be 3 contexts saved in mlmd") + + // update with nil name + newExternalId = "org.my_awesome_entity_@v1" + updatedEntity.ExternalId = &newExternalId + updatedEntity.Name = nil + updatedEntity, err = service.UpsertInferenceService(updatedEntity) + suite.Nilf(err, "error updating new model version for %s: %v", updateEntityId, err) + + updateEntityId, _ = converter.StringToInt64(updatedEntity.Id) + suite.Equal(*createdEntityId, *updateEntityId, "created and updated should have same id") + + byId, err = suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *updateEntityId, + }, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + suite.Equal(1, len(byId.Contexts), "there should be 1 context saved in mlmd by id") + + suite.Equal(*updateEntityId, *byId.Contexts[0].Id, "returned id should match the mlmd one") + suite.Equal(fmt.Sprintf("%s:%s", parentResourceId, *eut.Name), *byId.Contexts[0].Name, "saved name should match the provided one") + suite.Equal(newExternalId, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") + suite.Equal(customString, byId.Contexts[0].CustomProperties["custom_string_prop"].GetStringValue(), "saved custom_string_prop custom property should match the provided one") + suite.Equal(newScore, byId.Contexts[0].CustomProperties["score"].GetDoubleValue(), "saved score custom property should match the provided one") + suite.Equalf(*inferenceServiceTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *inferenceServiceTypeName) + + // update with empty registeredModelId + newExternalId = "org.my_awesome_entity_@v1" + prevRegModelId := updatedEntity.RegisteredModelId + updatedEntity.RegisteredModelId = "" + updatedEntity, err = service.UpsertInferenceService(updatedEntity) + suite.Nil(err) + suite.Equal(prevRegModelId, updatedEntity.RegisteredModelId) +} + +func (suite *CoreTestSuite) TestUpdateInferenceServiceFailure() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + registeredModelId := suite.registerModel(service, nil, nil) + + eut := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + Description: &entityDescription, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity, err := service.UpsertInferenceService(eut) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + + suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") + + newExternalId := "org.my_awesome_entity@v1" + newScore := 0.95 + + createdEntity.ExternalId = &newExternalId + (*createdEntity.CustomProperties)["score"] = openapi.MetadataValue{ + MetadataDoubleValue: converter.NewMetadataDoubleValue(newScore), + } + + wrongId := "9999" + createdEntity.Id = &wrongId + _, err = service.UpsertInferenceService(createdEntity) + suite.NotNil(err) + suite.Equal(fmt.Sprintf("no InferenceService found for id %s: not found", wrongId), err.Error()) +} + +func (suite *CoreTestSuite) TestGetInferenceServiceById() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + registeredModelId := suite.registerModel(service, nil, nil) + + state := openapi.INFERENCESERVICESTATE_UNDEPLOYED + eut := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + Description: &entityDescription, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + DesiredState: &state, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity, err := service.UpsertInferenceService(eut) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + + suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") + createdEntityId, _ := converter.StringToInt64(createdEntity.Id) + + getById, err := service.GetInferenceServiceById(*createdEntity.Id) + suite.Nilf(err, "error getting model version with id %d", *createdEntityId) + + ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *createdEntityId, + }, + }) + suite.Nilf(err, "error retrieving context, not related to the test itself: %v", err) + + ctx := ctxById.Contexts[0] + suite.Equal(*getById.Id, *converter.Int64ToString(ctx.Id), "returned id should match the mlmd context one") + suite.Equal(*eut.Name, *getById.Name, "saved name should match the provided one") + suite.Equal(*eut.ExternalId, *getById.ExternalId, "saved external id should match the provided one") + suite.Equal(*eut.DesiredState, *getById.DesiredState, "saved state should match the provided one") + suite.Equal((*getById.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, customString, "saved custom_string_prop custom property should match the provided one") +} + +func (suite *CoreTestSuite) TestGetRegisteredModelByInferenceServiceId() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + registeredModelId := suite.registerModel(service, nil, nil) + + eut := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + Description: &entityDescription, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + createdEntity, err := service.UpsertInferenceService(eut) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") + + getRM, err := service.GetRegisteredModelByInferenceService(*createdEntity.Id) + suite.Nilf(err, "error getting using id %s", *createdEntity.Id) + + suite.Equal(registeredModelId, *getRM.Id, "returned id should match the original registeredModelId") +} + +func (suite *CoreTestSuite) TestGetModelVersionByInferenceServiceId() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + registeredModelId := suite.registerModel(service, nil, nil) + + modelVersion1Name := "v1" + modelVersion1 := &openapi.ModelVersion{Name: modelVersion1Name, Description: &modelVersionDescription} + createdVersion1, err := service.UpsertModelVersion(modelVersion1, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + createdVersion1Id := *createdVersion1.Id + + modelVersion2Name := "v2" + modelVersion2 := &openapi.ModelVersion{Name: modelVersion2Name, Description: &modelVersionDescription} + createdVersion2, err := service.UpsertModelVersion(modelVersion2, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + createdVersion2Id := *createdVersion2.Id + // end of data preparation + + eut := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + Description: &entityDescription, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + ModelVersionId: nil, // first we test by unspecified + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + createdEntity, err := service.UpsertInferenceService(eut) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + + getVModel, err := service.GetModelVersionByInferenceService(*createdEntity.Id) + suite.Nilf(err, "error getting using id %s", *createdEntity.Id) + suite.Equal(createdVersion2Id, *getVModel.Id, "returned id shall be the latest ModelVersion by creation order") + + // here we used the returned entity (so ID is populated), and we update to specify the "ID of the ModelVersion to serve" + createdEntity.ModelVersionId = &createdVersion1Id + _, err = service.UpsertInferenceService(createdEntity) + suite.Nilf(err, "error updating eut for %v", parentResourceId) + + getVModel, err = service.GetModelVersionByInferenceService(*createdEntity.Id) + suite.Nilf(err, "error getting using id %s", *createdEntity.Id) + suite.Equal(createdVersion1Id, *getVModel.Id, "returned id shall be the specified one") +} + +func (suite *CoreTestSuite) TestGetModelArtifactByInferenceServiceId() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + registeredModelId := suite.registerModel(service, nil, nil) + + modelVersion1Name := "v1" + modelVersion1 := &openapi.ModelVersion{Name: modelVersion1Name, Description: &modelVersionDescription} + createdVersion1, err := service.UpsertModelVersion(modelVersion1, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %s", registeredModelId) + modelArtifact1Name := "v1-artifact" + modelArtifact1 := &openapi.ModelArtifact{Name: &modelArtifact1Name} + createdArtifact1, err := service.UpsertModelArtifact(modelArtifact1, createdVersion1.Id) + suite.Nilf(err, "error creating new model artifact for %s", *createdVersion1.Id) + + modelVersion2Name := "v2" + modelVersion2 := &openapi.ModelVersion{Name: modelVersion2Name, Description: &modelVersionDescription} + createdVersion2, err := service.UpsertModelVersion(modelVersion2, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %s", registeredModelId) + modelArtifact2Name := "v2-artifact" + modelArtifact2 := &openapi.ModelArtifact{Name: &modelArtifact2Name} + createdArtifact2, err := service.UpsertModelArtifact(modelArtifact2, createdVersion2.Id) + suite.Nilf(err, "error creating new model artifact for %s", *createdVersion2.Id) + // end of data preparation + + eut := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + Description: &entityDescription, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + ModelVersionId: nil, // first we test by unspecified + } + createdEntity, err := service.UpsertInferenceService(eut) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + + getModelArt, err := service.GetModelArtifactByInferenceService(*createdEntity.Id) + suite.Nilf(err, "error getting using id %s", *createdEntity.Id) + suite.Equal(*createdArtifact2.Id, *getModelArt.Id, "returned id shall be the latest ModelVersion by creation order") + + // here we used the returned entity (so ID is populated), and we update to specify the "ID of the ModelVersion to serve" + createdEntity.ModelVersionId = createdVersion1.Id + _, err = service.UpsertInferenceService(createdEntity) + suite.Nilf(err, "error updating eut for %v", parentResourceId) + + getModelArt, err = service.GetModelArtifactByInferenceService(*createdEntity.Id) + suite.Nilf(err, "error getting using id %s", *createdEntity.Id) + suite.Equal(*createdArtifact1.Id, *getModelArt.Id, "returned id shall be the specified one") +} + +func (suite *CoreTestSuite) TestGetInferenceServiceByParamsWithNoResults() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + + _, err := service.GetInferenceServiceByParams(apiutils.Of("not-present"), &parentResourceId, nil) + suite.NotNil(err) + suite.Equal("no inference services found for name=not-present, servingEnvironmentId=1, externalId=: not found", err.Error()) +} + +func (suite *CoreTestSuite) TestGetInferenceServiceByParamsName() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + registeredModelId := suite.registerModel(service, nil, nil) + + eut := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + Description: &entityDescription, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity, err := service.UpsertInferenceService(eut) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + + suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") + createdEntityId, _ := converter.StringToInt64(createdEntity.Id) + + getByName, err := service.GetInferenceServiceByParams(&entityName, &parentResourceId, nil) + suite.Nilf(err, "error getting model version by name %d", *createdEntityId) + + ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *createdEntityId, + }, + }) + suite.Nilf(err, "error retrieving context, not related to the test itself: %v", err) + + ctx := ctxById.Contexts[0] + suite.Equal(*converter.Int64ToString(ctx.Id), *getByName.Id, "returned id should match the mlmd context one") + suite.Equal(fmt.Sprintf("%s:%s", parentResourceId, *getByName.Name), *ctx.Name, "saved name should match the provided one") + suite.Equal(*ctx.ExternalId, *getByName.ExternalId, "saved external id should match the provided one") + suite.Equal(ctx.CustomProperties["custom_string_prop"].GetStringValue(), (*getByName.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, "saved custom_string_prop custom property should match the provided one") +} + +func (suite *CoreTestSuite) TestGetInfernenceServiceByParamsExternalId() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + registeredModelId := suite.registerModel(service, nil, nil) + + eut := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + Description: &entityDescription, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity, err := service.UpsertInferenceService(eut) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + + suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") + createdEntityId, _ := converter.StringToInt64(createdEntity.Id) + + getByExternalId, err := service.GetInferenceServiceByParams(nil, nil, eut.ExternalId) + suite.Nilf(err, "error getting by external id %d", *eut.ExternalId) + + ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *createdEntityId, + }, + }) + suite.Nilf(err, "error retrieving context, not related to the test itself: %v", err) + + ctx := ctxById.Contexts[0] + suite.Equal(*converter.Int64ToString(ctx.Id), *getByExternalId.Id, "returned id should match the mlmd context one") + suite.Equal(fmt.Sprintf("%s:%s", parentResourceId, *getByExternalId.Name), *ctx.Name, "saved name should match the provided one") + suite.Equal(*ctx.ExternalId, *getByExternalId.ExternalId, "saved external id should match the provided one") + suite.Equal(ctx.CustomProperties["custom_string_prop"].GetStringValue(), (*getByExternalId.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, "saved custom_string_prop custom property should match the provided one") +} + +func (suite *CoreTestSuite) TestGetInferenceServiceByEmptyParams() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + registeredModelId := suite.registerModel(service, nil, nil) + + eut := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + Description: &entityDescription, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity, err := service.UpsertInferenceService(eut) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + + suite.NotNilf(createdEntity.Id, "created eut should not have nil Id") + + _, err = service.GetInferenceServiceByParams(nil, nil, nil) + suite.NotNil(err) + suite.Equal("invalid parameters call, supply either (name and servingEnvironmentId), or externalId: bad request", err.Error()) +} + +func (suite *CoreTestSuite) TestGetInferenceServices() { + // create mode registry service + service := suite.setupModelRegistryService() + + parentResourceId := suite.registerServingEnvironment(service, nil, nil) + registeredModelId := suite.registerModel(service, nil, nil) + + eut1 := &openapi.InferenceService{ + Name: &entityName, + ExternalId: &entityExternalId2, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + Runtime: apiutils.Of("model-server0"), + } + + secondName := "v2" + secondExtId := "org.myawesomeentity@v2" + eut2 := &openapi.InferenceService{ + Name: &secondName, + ExternalId: &secondExtId, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + Runtime: apiutils.Of("model-server1"), + } + + thirdName := "v3" + thirdExtId := "org.myawesomeentity@v3" + eut3 := &openapi.InferenceService{ + Name: &thirdName, + ExternalId: &thirdExtId, + ServingEnvironmentId: parentResourceId, + RegisteredModelId: registeredModelId, + Runtime: apiutils.Of("model-server2"), + } + + createdEntity1, err := service.UpsertInferenceService(eut1) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + + createdEntity2, err := service.UpsertInferenceService(eut2) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + + createdEntity3, err := service.UpsertInferenceService(eut3) + suite.Nilf(err, "error creating new eut for %v", parentResourceId) + + anotherParentResourceName := "AnotherModel" + anotherParentResourceExtId := "org.another" + anotherParentResourceId := suite.registerServingEnvironment(service, &anotherParentResourceName, &anotherParentResourceExtId) + + anotherName := "v1.0" + anotherExtId := "org.another@v1.0" + eutAnother := &openapi.InferenceService{ + Name: &anotherName, + ExternalId: &anotherExtId, + ServingEnvironmentId: anotherParentResourceId, + RegisteredModelId: registeredModelId, + Runtime: apiutils.Of("model-server3"), + } + + _, err = service.UpsertInferenceService(eutAnother) + suite.Nilf(err, "error creating new model version for %d", anotherParentResourceId) + + createdId1, _ := converter.StringToInt64(createdEntity1.Id) + createdId2, _ := converter.StringToInt64(createdEntity2.Id) + createdId3, _ := converter.StringToInt64(createdEntity3.Id) + + getAll, err := service.GetInferenceServices(api.ListOptions{}, nil, nil) + suite.Nilf(err, "error getting all") + suite.Equal(int32(4), getAll.Size, "expected 4 across all parent resources") + + getAllByParentResource, err := service.GetInferenceServices(api.ListOptions{}, &parentResourceId, nil) + suite.Nilf(err, "error getting all") + suite.Equalf(int32(3), getAllByParentResource.Size, "expected 3 for parent resource %d", parentResourceId) + + suite.Equal(*converter.Int64ToString(createdId1), *getAllByParentResource.Items[0].Id) + suite.Equal(*converter.Int64ToString(createdId2), *getAllByParentResource.Items[1].Id) + suite.Equal(*converter.Int64ToString(createdId3), *getAllByParentResource.Items[2].Id) + + modelServer := "model-server1" + getAllByParentResourceAndRuntime, err := service.GetInferenceServices(api.ListOptions{}, &parentResourceId, &modelServer) + suite.Nilf(err, "error getting all") + suite.Equalf(int32(1), getAllByParentResourceAndRuntime.Size, "expected 1 for parent resource %s and runtime %s", parentResourceId, modelServer) + + suite.Equal(*converter.Int64ToString(createdId1), *getAllByParentResource.Items[0].Id) + + // order by last update time, expecting last created as first + orderByLastUpdate := "LAST_UPDATE_TIME" + getAllByParentResource, err = service.GetInferenceServices(api.ListOptions{ + OrderBy: &orderByLastUpdate, + SortOrder: &descOrderDirection, + }, &parentResourceId, nil) + suite.Nilf(err, "error getting all") + suite.Equalf(int32(3), getAllByParentResource.Size, "expected 3 for parent resource %d", parentResourceId) + + suite.Equal(*converter.Int64ToString(createdId1), *getAllByParentResource.Items[2].Id) + suite.Equal(*converter.Int64ToString(createdId2), *getAllByParentResource.Items[1].Id) + suite.Equal(*converter.Int64ToString(createdId3), *getAllByParentResource.Items[0].Id) + + // update the second entity + newExternalId := "updated.org:v2" + createdEntity2.ExternalId = &newExternalId + createdEntity2, err = service.UpsertInferenceService(createdEntity2) + suite.Nilf(err, "error creating new eut2 for %d", parentResourceId) + + suite.Equal(newExternalId, *createdEntity2.ExternalId) + + getAllByParentResource, err = service.GetInferenceServices(api.ListOptions{ + OrderBy: &orderByLastUpdate, + SortOrder: &descOrderDirection, + }, &parentResourceId, nil) + suite.Nilf(err, "error getting all") + suite.Equalf(int32(3), getAllByParentResource.Size, "expected 3 for parent resource %d", parentResourceId) + + suite.Equal(*converter.Int64ToString(createdId1), *getAllByParentResource.Items[2].Id) + suite.Equal(*converter.Int64ToString(createdId2), *getAllByParentResource.Items[0].Id) + suite.Equal(*converter.Int64ToString(createdId3), *getAllByParentResource.Items[1].Id) +} diff --git a/pkg/core/model_version.go b/pkg/core/model_version.go new file mode 100644 index 00000000..feec215d --- /dev/null +++ b/pkg/core/model_version.go @@ -0,0 +1,253 @@ +package core + +import ( + "context" + "fmt" + + "github.com/golang/glog" + "github.com/kubeflow/model-registry/internal/apiutils" + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// MODEL VERSIONS + +// UpsertModelVersion creates a new model version if the provided model version's ID is nil, +// or updates an existing model version if the ID is provided. +func (serv *ModelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, registeredModelId *string) (*openapi.ModelVersion, error) { + var err error + var existing *openapi.ModelVersion + var registeredModel *openapi.RegisteredModel + + if modelVersion.Id == nil { + // create + glog.Info("Creating new model version") + if registeredModelId == nil { + return nil, fmt.Errorf("missing registered model id, cannot create model version without registered model: %w", api.ErrBadRequest) + } + registeredModel, err = serv.GetRegisteredModelById(*registeredModelId) + if err != nil { + return nil, err + } + } else { + // update + glog.Infof("Updating model version %s", *modelVersion.Id) + existing, err = serv.GetModelVersionById(*modelVersion.Id) + if err != nil { + return nil, err + } + + withNotEditable, err := serv.openapiConv.OverrideNotEditableForModelVersion(converter.NewOpenapiUpdateWrapper(existing, modelVersion)) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + modelVersion = &withNotEditable + + registeredModel, err = serv.getRegisteredModelByVersionId(*modelVersion.Id) + if err != nil { + return nil, err + } + } + + modelCtx, err := serv.mapper.MapFromModelVersion(modelVersion, *registeredModel.Id, registeredModel.Name) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + modelCtxResp, err := serv.mlmdClient.PutContexts(context.Background(), &proto.PutContextsRequest{ + Contexts: []*proto.Context{ + modelCtx, + }, + }) + if err != nil { + return nil, err + } + + modelId := &modelCtxResp.ContextIds[0] + if modelVersion.Id == nil { + registeredModelId, err := converter.StringToInt64(registeredModel.Id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + _, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{ + ParentContexts: []*proto.ParentContext{{ + ChildId: modelId, + ParentId: registeredModelId, + }}, + TransactionOptions: &proto.TransactionOptions{}, + }) + if err != nil { + return nil, err + } + } + + idAsString := converter.Int64ToString(modelId) + model, err := serv.GetModelVersionById(*idAsString) + if err != nil { + return nil, err + } + + return model, nil +} + +// GetModelVersionById retrieves a model version by its unique identifier (ID). +func (serv *ModelRegistryService) GetModelVersionById(id string) (*openapi.ModelVersion, error) { + idAsInt, err := converter.StringToInt64(&id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{int64(*idAsInt)}, + }) + if err != nil { + return nil, err + } + + if len(getByIdResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple model versions found for id %s: %w", id, api.ErrNotFound) + } + + if len(getByIdResp.Contexts) == 0 { + return nil, fmt.Errorf("no model version found for id %s: %w", id, api.ErrNotFound) + } + + modelVer, err := serv.mapper.MapToModelVersion(getByIdResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return modelVer, nil +} + +// GetModelVersionByInferenceService retrieves the model version associated with the specified inference service ID. +func (serv *ModelRegistryService) GetModelVersionByInferenceService(inferenceServiceId string) (*openapi.ModelVersion, error) { + is, err := serv.GetInferenceServiceById(inferenceServiceId) + if err != nil { + return nil, err + } + if is.ModelVersionId != nil { + return serv.GetModelVersionById(*is.ModelVersionId) + } + // modelVersionId: ID of the ModelVersion to serve. If it's unspecified, then the latest ModelVersion by creation order will be served. + orderByCreateTime := "CREATE_TIME" + sortOrderDesc := "DESC" + versions, err := serv.GetModelVersions(api.ListOptions{OrderBy: &orderByCreateTime, SortOrder: &sortOrderDesc}, &is.RegisteredModelId) + if err != nil { + return nil, err + } + if len(versions.Items) == 0 { + return nil, fmt.Errorf("no model versions found for id %s: %w", is.RegisteredModelId, api.ErrNotFound) + } + return &versions.Items[0], nil +} + +// getModelVersionByArtifactId retrieves the model version associated with the specified model artifact ID. +func (serv *ModelRegistryService) getModelVersionByArtifactId(id string) (*openapi.ModelVersion, error) { + glog.Infof("Getting model version for model artifact %s", id) + + idAsInt, err := converter.StringToInt64(&id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + getParentResp, err := serv.mlmdClient.GetContextsByArtifact(context.Background(), &proto.GetContextsByArtifactRequest{ + ArtifactId: idAsInt, + }) + if err != nil { + return nil, err + } + + if len(getParentResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple model versions found for artifact %s: %w", id, api.ErrNotFound) + } + + if len(getParentResp.Contexts) == 0 { + return nil, fmt.Errorf("no model version found for artifact %s: %w", id, api.ErrNotFound) + } + + modelVersion, err := serv.mapper.MapToModelVersion(getParentResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return modelVersion, nil +} + +// GetModelVersionByParams retrieves a model version based on specified parameters, such as (version name and registered model ID), or external ID. +// If multiple or no model versions are found, an error is returned. +func (serv *ModelRegistryService) GetModelVersionByParams(versionName *string, registeredModelId *string, externalId *string) (*openapi.ModelVersion, error) { + filterQuery := "" + if versionName != nil && registeredModelId != nil { + filterQuery = fmt.Sprintf("name = \"%s\"", converter.PrefixWhenOwned(registeredModelId, *versionName)) + } else if externalId != nil { + filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) + } else { + return nil, fmt.Errorf("invalid parameters call, supply either (versionName and registeredModelId), or externalId: %w", api.ErrBadRequest) + } + + getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ + TypeName: &serv.nameConfig.ModelVersionTypeName, + Options: &proto.ListOperationOptions{ + FilterQuery: &filterQuery, + }, + }) + if err != nil { + return nil, err + } + + if len(getByParamsResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple model versions found for versionName=%v, registeredModelId=%v, externalId=%v: %w", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(registeredModelId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + if len(getByParamsResp.Contexts) == 0 { + return nil, fmt.Errorf("no model versions found for versionName=%v, registeredModelId=%v, externalId=%v: %w", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(registeredModelId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + modelVer, err := serv.mapper.MapToModelVersion(getByParamsResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + return modelVer, nil +} + +// GetModelVersions retrieves a list of model versions based on the provided list options and optional registered model ID. +func (serv *ModelRegistryService) GetModelVersions(listOptions api.ListOptions, registeredModelId *string) (*openapi.ModelVersionList, error) { + listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + if registeredModelId != nil { + queryParentCtxId := fmt.Sprintf("parent_contexts_a.id = %s", *registeredModelId) + listOperationOptions.FilterQuery = &queryParentCtxId + } + + contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ + TypeName: &serv.nameConfig.ModelVersionTypeName, + Options: listOperationOptions, + }) + if err != nil { + return nil, err + } + + results := []openapi.ModelVersion{} + for _, c := range contextsResp.Contexts { + mapped, err := serv.mapper.MapToModelVersion(c) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + results = append(results, *mapped) + } + + toReturn := openapi.ModelVersionList{ + NextPageToken: apiutils.ZeroIfNil(contextsResp.NextPageToken), + PageSize: apiutils.ZeroIfNil(listOptions.PageSize), + Size: int32(len(results)), + Items: results, + } + return &toReturn, nil +} diff --git a/pkg/core/model_version_test.go b/pkg/core/model_version_test.go new file mode 100644 index 00000000..ccbcde4b --- /dev/null +++ b/pkg/core/model_version_test.go @@ -0,0 +1,425 @@ +package core + +import ( + "context" + "fmt" + + "github.com/kubeflow/model-registry/internal/apiutils" + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// MODEL VERSIONS + +func (suite *CoreTestSuite) TestCreateModelVersion() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + + state := openapi.MODELVERSIONSTATE_LIVE + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + Description: &modelVersionDescription, + State: &state, + Author: &author, + } + + createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + suite.Equal((*createdVersion).RegisteredModelId, registeredModelId, "RegisteredModelId should match the actual owner-entity") + + suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") + + createdVersionId, _ := converter.StringToInt64(createdVersion.Id) + + byId, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *createdVersionId, + }, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + suite.Equal(1, len(byId.Contexts), "there should be just one context saved in mlmd") + + suite.Equal(*createdVersionId, *byId.Contexts[0].Id, "returned model id should match the mlmd one") + suite.Equal(fmt.Sprintf("%s:%s", registeredModelId, modelVersionName), *byId.Contexts[0].Name, "saved model name should match the provided one") + suite.Equal(versionExternalId, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") + suite.Equal(author, byId.Contexts[0].Properties["author"].GetStringValue(), "saved author property should match the provided one") + suite.Equal(modelVersionDescription, byId.Contexts[0].Properties["description"].GetStringValue(), "saved description should match the provided one") + suite.Equal(string(state), byId.Contexts[0].Properties["state"].GetStringValue(), "saved state should match the provided one") + suite.Equalf(*modelVersionTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *modelVersionTypeName) + + getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) + suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) + suite.Equal(2, len(getAllResp.Contexts), "there should be two contexts saved in mlmd") +} + +func (suite *CoreTestSuite) TestCreateModelVersionFailure() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := "9999" + + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + Author: &author, + } + + _, err := service.UpsertModelVersion(modelVersion, nil) + suite.NotNil(err) + suite.Equal("missing registered model id, cannot create model version without registered model: bad request", err.Error()) + + _, err = service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.NotNil(err) + suite.Equal("no registered model found for id 9999: not found", err.Error()) +} + +func (suite *CoreTestSuite) TestUpdateModelVersion() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + Author: &author, + } + + createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + + suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") + createdVersionId, _ := converter.StringToInt64(createdVersion.Id) + + newExternalId := "org.my_awesome_model@v1" + newScore := 0.95 + + createdVersion.ExternalId = &newExternalId + (*createdVersion.CustomProperties)["score"] = openapi.MetadataValue{ + MetadataDoubleValue: converter.NewMetadataDoubleValue(newScore), + } + + updatedVersion, err := service.UpsertModelVersion(createdVersion, ®isteredModelId) + suite.Nilf(err, "error updating new model version for %s: %v", registeredModelId, err) + suite.Equal((*updatedVersion).RegisteredModelId, registeredModelId, "RegisteredModelId should match the actual owner-entity") + + updateVersionId, _ := converter.StringToInt64(updatedVersion.Id) + suite.Equal(*createdVersionId, *updateVersionId, "created and updated model version should have same id") + + byId, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *updateVersionId, + }, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + suite.Equal(1, len(byId.Contexts), "there should be just one context saved in mlmd") + + suite.Equal(*updateVersionId, *byId.Contexts[0].Id, "returned model id should match the mlmd one") + suite.Equal(fmt.Sprintf("%s:%s", registeredModelId, modelVersionName), *byId.Contexts[0].Name, "saved model name should match the provided one") + suite.Equal(newExternalId, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") + suite.Equal(author, byId.Contexts[0].Properties["author"].GetStringValue(), "saved author property should match the provided one") + suite.Equal(newScore, byId.Contexts[0].CustomProperties["score"].GetDoubleValue(), "saved score custom property should match the provided one") + suite.Equalf(*modelVersionTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *modelVersionTypeName) + + getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) + suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) + suite.Equal(2, len(getAllResp.Contexts), "there should be two contexts saved in mlmd") + + // update with nil name + newExternalId = "org.my_awesome_model_@v1" + updatedVersion.ExternalId = &newExternalId + updatedVersion.Name = "" + updatedVersion, err = service.UpsertModelVersion(updatedVersion, ®isteredModelId) + suite.Nilf(err, "error updating new model version for %s: %v", registeredModelId, err) + + updateVersionId, _ = converter.StringToInt64(updatedVersion.Id) + suite.Equal(*createdVersionId, *updateVersionId, "created and updated model version should have same id") + + byId, err = suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *updateVersionId, + }, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + suite.Equal(1, len(byId.Contexts), "there should be just one context saved in mlmd") + + suite.Equal(*updateVersionId, *byId.Contexts[0].Id, "returned model id should match the mlmd one") + suite.Equal(fmt.Sprintf("%s:%s", registeredModelId, modelVersionName), *byId.Contexts[0].Name, "saved model name should match the provided one") + suite.Equal(newExternalId, *byId.Contexts[0].ExternalId, "saved external id should match the provided one") + suite.Equal(author, byId.Contexts[0].Properties["author"].GetStringValue(), "saved author property should match the provided one") + suite.Equal(newScore, byId.Contexts[0].CustomProperties["score"].GetDoubleValue(), "saved score custom property should match the provided one") + suite.Equalf(*modelVersionTypeName, *byId.Contexts[0].Type, "saved context should be of type of %s", *modelVersionTypeName) +} + +func (suite *CoreTestSuite) TestUpdateModelVersionFailure() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + Author: &author, + } + + createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %s", registeredModelId) + suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") + + newExternalId := "org.my_awesome_model@v1" + newScore := 0.95 + + createdVersion.ExternalId = &newExternalId + (*createdVersion.CustomProperties)["score"] = openapi.MetadataValue{ + MetadataDoubleValue: converter.NewMetadataDoubleValue(newScore), + } + + wrongId := "9999" + createdVersion.Id = &wrongId + _, err = service.UpsertModelVersion(createdVersion, ®isteredModelId) + suite.NotNil(err) + suite.Equal(fmt.Sprintf("no model version found for id %s: not found", wrongId), err.Error()) +} + +func (suite *CoreTestSuite) TestGetModelVersionById() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + + state := openapi.MODELVERSIONSTATE_ARCHIVED + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + State: &state, + Author: &author, + } + + createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + + suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") + createdVersionId, _ := converter.StringToInt64(createdVersion.Id) + + getById, err := service.GetModelVersionById(*createdVersion.Id) + suite.Nilf(err, "error getting model version with id %d", *createdVersionId) + + ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *createdVersionId, + }, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + + ctx := ctxById.Contexts[0] + suite.Equal(*converter.Int64ToString(ctx.Id), *getById.Id, "returned model version id should match the mlmd context one") + suite.Equal(modelVersion.Name, getById.Name, "saved model name should match the provided one") + suite.Equal(*modelVersion.ExternalId, *getById.ExternalId, "saved external id should match the provided one") + suite.Equal(*modelVersion.State, *getById.State, "saved model state should match the original one") + suite.Equal(*getById.Author, author, "saved author property should match the provided one") +} + +func (suite *CoreTestSuite) TestGetModelVersionByParamsWithNoResults() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + + _, err := service.GetModelVersionByParams(apiutils.Of("not-present"), ®isteredModelId, nil) + suite.NotNil(err) + suite.Equal("no model versions found for versionName=not-present, registeredModelId=1, externalId=: not found", err.Error()) +} + +func (suite *CoreTestSuite) TestGetModelVersionByParamsName() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + Author: &author, + } + + createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + + suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") + createdVersionId, _ := converter.StringToInt64(createdVersion.Id) + + getByName, err := service.GetModelVersionByParams(&modelVersionName, ®isteredModelId, nil) + suite.Nilf(err, "error getting model version by name %d", *createdVersionId) + + ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *createdVersionId, + }, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + + ctx := ctxById.Contexts[0] + suite.Equal(*converter.Int64ToString(ctx.Id), *getByName.Id, "returned model version id should match the mlmd context one") + suite.Equal(fmt.Sprintf("%s:%s", registeredModelId, getByName.Name), *ctx.Name, "saved model name should match the provided one") + suite.Equal(*ctx.ExternalId, *getByName.ExternalId, "saved external id should match the provided one") + suite.Equal(ctx.Properties["author"].GetStringValue(), *getByName.Author, "saved author property should match the provided one") +} + +func (suite *CoreTestSuite) TestGetModelVersionByParamsExternalId() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + Author: &author, + } + + createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + + suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") + createdVersionId, _ := converter.StringToInt64(createdVersion.Id) + + getByExternalId, err := service.GetModelVersionByParams(nil, nil, modelVersion.ExternalId) + suite.Nilf(err, "error getting model version by external id %d", *modelVersion.ExternalId) + + ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{ + *createdVersionId, + }, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + + ctx := ctxById.Contexts[0] + suite.Equal(*converter.Int64ToString(ctx.Id), *getByExternalId.Id, "returned model version id should match the mlmd context one") + suite.Equal(fmt.Sprintf("%s:%s", registeredModelId, getByExternalId.Name), *ctx.Name, "saved model name should match the provided one") + suite.Equal(*ctx.ExternalId, *getByExternalId.ExternalId, "saved external id should match the provided one") + suite.Equal(ctx.Properties["author"].GetStringValue(), *getByExternalId.Author, "saved author property should match the provided one") +} + +func (suite *CoreTestSuite) TestGetModelVersionByEmptyParams() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + Author: &author, + } + + createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") + + _, err = service.GetModelVersionByParams(nil, nil, nil) + suite.NotNil(err) + suite.Equal("invalid parameters call, supply either (versionName and registeredModelId), or externalId: bad request", err.Error()) +} + +func (suite *CoreTestSuite) TestGetModelVersions() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + + modelVersion1 := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + } + + secondModelVersionName := "v2" + secondModelVersionExtId := "org.myawesomemodel@v2" + modelVersion2 := &openapi.ModelVersion{ + Name: secondModelVersionName, + ExternalId: &secondModelVersionExtId, + } + + thirdModelVersionName := "v3" + thirdModelVersionExtId := "org.myawesomemodel@v3" + modelVersion3 := &openapi.ModelVersion{ + Name: thirdModelVersionName, + ExternalId: &thirdModelVersionExtId, + } + + createdVersion1, err := service.UpsertModelVersion(modelVersion1, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + + createdVersion2, err := service.UpsertModelVersion(modelVersion2, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + + createdVersion3, err := service.UpsertModelVersion(modelVersion3, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + + anotherRegModelName := "AnotherModel" + anotherRegModelExtId := "org.another" + anotherRegisteredModelId := suite.registerModel(service, &anotherRegModelName, &anotherRegModelExtId) + + anotherModelVersionName := "v1.0" + anotherModelVersionExtId := "org.another@v1.0" + modelVersionAnother := &openapi.ModelVersion{ + Name: anotherModelVersionName, + ExternalId: &anotherModelVersionExtId, + } + + _, err = service.UpsertModelVersion(modelVersionAnother, &anotherRegisteredModelId) + suite.Nilf(err, "error creating new model version for %d", anotherRegisteredModelId) + + createdVersionId1, _ := converter.StringToInt64(createdVersion1.Id) + createdVersionId2, _ := converter.StringToInt64(createdVersion2.Id) + createdVersionId3, _ := converter.StringToInt64(createdVersion3.Id) + + getAll, err := service.GetModelVersions(api.ListOptions{}, nil) + suite.Nilf(err, "error getting all model versions") + suite.Equal(int32(4), getAll.Size, "expected four model versions across all registered models") + + getAllByRegModel, err := service.GetModelVersions(api.ListOptions{}, ®isteredModelId) + suite.Nilf(err, "error getting all model versions") + suite.Equalf(int32(3), getAllByRegModel.Size, "expected three model versions for registered model %d", registeredModelId) + + suite.Equal(*converter.Int64ToString(createdVersionId1), *getAllByRegModel.Items[0].Id) + suite.Equal(*converter.Int64ToString(createdVersionId2), *getAllByRegModel.Items[1].Id) + suite.Equal(*converter.Int64ToString(createdVersionId3), *getAllByRegModel.Items[2].Id) + + // order by last update time, expecting last created as first + orderByLastUpdate := "LAST_UPDATE_TIME" + getAllByRegModel, err = service.GetModelVersions(api.ListOptions{ + OrderBy: &orderByLastUpdate, + SortOrder: &descOrderDirection, + }, ®isteredModelId) + suite.Nilf(err, "error getting all model versions") + suite.Equalf(int32(3), getAllByRegModel.Size, "expected three model versions for registered model %d", registeredModelId) + + suite.Equal(*converter.Int64ToString(createdVersionId1), *getAllByRegModel.Items[2].Id) + suite.Equal(*converter.Int64ToString(createdVersionId2), *getAllByRegModel.Items[1].Id) + suite.Equal(*converter.Int64ToString(createdVersionId3), *getAllByRegModel.Items[0].Id) + + // update the second version + newVersionExternalId := "updated.org:v2" + createdVersion2.ExternalId = &newVersionExternalId + createdVersion2, err = service.UpsertModelVersion(createdVersion2, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + + suite.Equal(newVersionExternalId, *createdVersion2.ExternalId) + + getAllByRegModel, err = service.GetModelVersions(api.ListOptions{ + OrderBy: &orderByLastUpdate, + SortOrder: &descOrderDirection, + }, ®isteredModelId) + suite.Nilf(err, "error getting all model versions") + suite.Equalf(int32(3), getAllByRegModel.Size, "expected three model versions for registered model %d", registeredModelId) + + suite.Equal(*converter.Int64ToString(createdVersionId1), *getAllByRegModel.Items[2].Id) + suite.Equal(*converter.Int64ToString(createdVersionId2), *getAllByRegModel.Items[0].Id) + suite.Equal(*converter.Int64ToString(createdVersionId3), *getAllByRegModel.Items[1].Id) +} diff --git a/pkg/core/registered_model.go b/pkg/core/registered_model.go new file mode 100644 index 00000000..829a03e0 --- /dev/null +++ b/pkg/core/registered_model.go @@ -0,0 +1,205 @@ +package core + +import ( + "context" + "fmt" + + "github.com/golang/glog" + "github.com/kubeflow/model-registry/internal/apiutils" + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// REGISTERED MODELS + +// UpsertRegisteredModel creates a new registered model if the given registered model's ID is nil, +// or updates an existing registered model if the ID is provided. +func (serv *ModelRegistryService) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) { + var err error + var existing *openapi.RegisteredModel + + if registeredModel.Id == nil { + glog.Info("Creating new registered model") + } else { + glog.Infof("Updating registered model %s", *registeredModel.Id) + existing, err = serv.GetRegisteredModelById(*registeredModel.Id) + if err != nil { + return nil, err + } + + withNotEditable, err := serv.openapiConv.OverrideNotEditableForRegisteredModel(converter.NewOpenapiUpdateWrapper(existing, registeredModel)) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + registeredModel = &withNotEditable + } + + modelCtx, err := serv.mapper.MapFromRegisteredModel(registeredModel) + if err != nil { + return nil, err + } + + modelCtxResp, err := serv.mlmdClient.PutContexts(context.Background(), &proto.PutContextsRequest{ + Contexts: []*proto.Context{ + modelCtx, + }, + }) + if err != nil { + return nil, err + } + + idAsString := converter.Int64ToString(&modelCtxResp.ContextIds[0]) + model, err := serv.GetRegisteredModelById(*idAsString) + if err != nil { + return nil, err + } + + return model, nil +} + +// GetRegisteredModelById retrieves a registered model by its unique identifier (ID). +func (serv *ModelRegistryService) GetRegisteredModelById(id string) (*openapi.RegisteredModel, error) { + glog.Infof("Getting registered model %s", id) + + idAsInt, err := converter.StringToInt64(&id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{int64(*idAsInt)}, + }) + if err != nil { + return nil, err + } + + if len(getByIdResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple registered models found for id %s: %w", id, api.ErrNotFound) + } + + if len(getByIdResp.Contexts) == 0 { + return nil, fmt.Errorf("no registered model found for id %s: %w", id, api.ErrNotFound) + } + + regModel, err := serv.mapper.MapToRegisteredModel(getByIdResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return regModel, nil +} + +// GetRegisteredModelByInferenceService retrieves a registered model associated with the specified inference service ID. +func (serv *ModelRegistryService) GetRegisteredModelByInferenceService(inferenceServiceId string) (*openapi.RegisteredModel, error) { + is, err := serv.GetInferenceServiceById(inferenceServiceId) + if err != nil { + return nil, err + } + return serv.GetRegisteredModelById(is.RegisteredModelId) +} + +// getRegisteredModelByVersionId retrieves a registered model associated with the specified model version ID. +func (serv *ModelRegistryService) getRegisteredModelByVersionId(id string) (*openapi.RegisteredModel, error) { + glog.Infof("Getting registered model for model version %s", id) + + idAsInt, err := converter.StringToInt64(&id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + getParentResp, err := serv.mlmdClient.GetParentContextsByContext(context.Background(), &proto.GetParentContextsByContextRequest{ + ContextId: idAsInt, + }) + if err != nil { + return nil, err + } + + if len(getParentResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple registered models found for model version %s: %w", id, api.ErrNotFound) + } + + if len(getParentResp.Contexts) == 0 { + return nil, fmt.Errorf("no registered model found for model version %s: %w", id, api.ErrNotFound) + } + + regModel, err := serv.mapper.MapToRegisteredModel(getParentResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return regModel, nil +} + +// GetRegisteredModelByParams retrieves a registered model based on specified parameters, such as name or external ID. +// If multiple or no registered models are found, an error is returned accordingly. +func (serv *ModelRegistryService) GetRegisteredModelByParams(name *string, externalId *string) (*openapi.RegisteredModel, error) { + glog.Infof("Getting registered model by params name=%v, externalId=%v", name, externalId) + + filterQuery := "" + if name != nil { + filterQuery = fmt.Sprintf("name = \"%s\"", *name) + } else if externalId != nil { + filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) + } else { + return nil, fmt.Errorf("invalid parameters call, supply either name or externalId: %w", api.ErrBadRequest) + } + glog.Info("filterQuery ", filterQuery) + + getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ + TypeName: &serv.nameConfig.RegisteredModelTypeName, + Options: &proto.ListOperationOptions{ + FilterQuery: &filterQuery, + }, + }) + if err != nil { + return nil, err + } + + if len(getByParamsResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple registered models found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + if len(getByParamsResp.Contexts) == 0 { + return nil, fmt.Errorf("no registered models found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + regModel, err := serv.mapper.MapToRegisteredModel(getByParamsResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + return regModel, nil +} + +// GetRegisteredModels retrieves a list of registered models based on the provided list options. +func (serv *ModelRegistryService) GetRegisteredModels(listOptions api.ListOptions) (*openapi.RegisteredModelList, error) { + listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) + if err != nil { + return nil, err + } + contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ + TypeName: &serv.nameConfig.RegisteredModelTypeName, + Options: listOperationOptions, + }) + if err != nil { + return nil, err + } + + results := []openapi.RegisteredModel{} + for _, c := range contextsResp.Contexts { + mapped, err := serv.mapper.MapToRegisteredModel(c) + if err != nil { + return nil, err + } + results = append(results, *mapped) + } + + toReturn := openapi.RegisteredModelList{ + NextPageToken: apiutils.ZeroIfNil(contextsResp.NextPageToken), + PageSize: apiutils.ZeroIfNil(listOptions.PageSize), + Size: int32(len(results)), + Items: results, + } + return &toReturn, nil +} diff --git a/pkg/core/registered_model_test.go b/pkg/core/registered_model_test.go new file mode 100644 index 00000000..75d84a3f --- /dev/null +++ b/pkg/core/registered_model_test.go @@ -0,0 +1,414 @@ +package core + +import ( + "context" + + "github.com/kubeflow/model-registry/internal/apiutils" + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// REGISTERED MODELS + +func (suite *CoreTestSuite) TestCreateRegisteredModel() { + // create mode registry service + service := suite.setupModelRegistryService() + + state := openapi.REGISTEREDMODELSTATE_ARCHIVED + // register a new model + registeredModel := &openapi.RegisteredModel{ + Name: modelName, + ExternalId: &modelExternalId, + Description: &modelDescription, + Owner: &modelOwner, + State: &state, + CustomProperties: &map[string]openapi.MetadataValue{ + "myCustomProp": { + MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), + }, + }, + } + + // test + createdModel, err := service.UpsertRegisteredModel(registeredModel) + + // checks + suite.Nilf(err, "error creating registered model: %v", err) + suite.NotNilf(createdModel.Id, "created registered model should not have nil Id") + + createdModelId, _ := converter.StringToInt64(createdModel.Id) + ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{*createdModelId}, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + + ctx := ctxById.Contexts[0] + ctxId := converter.Int64ToString(ctx.Id) + suite.Equal(*createdModel.Id, *ctxId, "returned model id should match the mlmd one") + suite.Equal(modelName, *ctx.Name, "saved model name should match the provided one") + suite.Equal(modelExternalId, *ctx.ExternalId, "saved external id should match the provided one") + suite.Equal(modelDescription, ctx.Properties["description"].GetStringValue(), "saved description should match the provided one") + suite.Equal(modelOwner, ctx.Properties["owner"].GetStringValue(), "saved owner should match the provided one") + suite.Equal(string(state), ctx.Properties["state"].GetStringValue(), "saved state should match the provided one") + suite.Equal(myCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") + + getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) + suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) + suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") +} + +func (suite *CoreTestSuite) TestUpdateRegisteredModel() { + // create mode registry service + service := suite.setupModelRegistryService() + + // register a new model + registeredModel := &openapi.RegisteredModel{ + Name: modelName, + Owner: &modelOwner, + ExternalId: &modelExternalId, + CustomProperties: &map[string]openapi.MetadataValue{ + "myCustomProp": { + MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), + }, + }, + } + + // test + createdModel, err := service.UpsertRegisteredModel(registeredModel) + + // checks + suite.Nilf(err, "error creating registered model: %v", err) + suite.NotNilf(createdModel.Id, "created registered model should not have nil Id") + createdModelId, _ := converter.StringToInt64(createdModel.Id) + + // checks created model matches original one except for Id + suite.Equal(registeredModel.Name, createdModel.Name, "returned model name should match the original one") + suite.Equal(*registeredModel.ExternalId, *createdModel.ExternalId, "returned model external id should match the original one") + suite.Equal(*registeredModel.CustomProperties, *createdModel.CustomProperties, "returned model custom props should match the original one") + + // update existing model + newModelExternalId := "newExternalId" + newOwner := "newOwner" + newCustomProp := "updated myCustomProp" + + createdModel.ExternalId = &newModelExternalId + createdModel.Owner = &newOwner + (*createdModel.CustomProperties)["myCustomProp"] = openapi.MetadataValue{ + MetadataStringValue: converter.NewMetadataStringValue(newCustomProp), + } + // check can also define customProperty of name "owner", in addition to built-in property "owner" + (*createdModel.CustomProperties)["owner"] = openapi.MetadataValue{ + MetadataStringValue: converter.NewMetadataStringValue(newCustomProp), + } + + // update the model + createdModel, err = service.UpsertRegisteredModel(createdModel) + suite.Nilf(err, "error creating registered model: %v", err) + + // still one registered model + getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) + suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) + suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") + + ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{*createdModelId}, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + + ctx := ctxById.Contexts[0] + ctxId := converter.Int64ToString(ctx.Id) + suite.Equal(*createdModel.Id, *ctxId, "returned model id should match the mlmd one") + suite.Equal(modelName, *ctx.Name, "saved model name should match the provided one") + suite.Equal(newModelExternalId, *ctx.ExternalId, "saved external id should match the provided one") + suite.Equal(newOwner, ctx.Properties["owner"].GetStringValue(), "saved owner custom property should match the provided one") + suite.Equal(newCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") + suite.Equal(newCustomProp, ctx.CustomProperties["owner"].GetStringValue(), "check can define custom property 'onwer' and should match the provided one") + + // update the model keeping nil name + newModelExternalId = "newNewExternalId" + createdModel.ExternalId = &newModelExternalId + createdModel.Name = "" + createdModel, err = service.UpsertRegisteredModel(createdModel) + suite.Nilf(err, "error creating registered model: %v", err) + + // still one registered model + getAllResp, err = suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) + suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) + suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") + + ctxById, err = suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{*createdModelId}, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + + ctx = ctxById.Contexts[0] + ctxId = converter.Int64ToString(ctx.Id) + suite.Equal(*createdModel.Id, *ctxId, "returned model id should match the mlmd one") + suite.Equal(modelName, *ctx.Name, "saved model name should match the provided one") + suite.Equal(newModelExternalId, *ctx.ExternalId, "saved external id should match the provided one") + suite.Equal(newOwner, ctx.Properties["owner"].GetStringValue(), "saved owner custom property should match the provided one") + suite.Equal(newCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") + suite.Equal(newCustomProp, ctx.CustomProperties["owner"].GetStringValue(), "check can define custom property 'onwer' and should match the provided one") +} + +func (suite *CoreTestSuite) TestGetRegisteredModelById() { + // create mode registry service + service := suite.setupModelRegistryService() + + state := openapi.REGISTEREDMODELSTATE_LIVE + // register a new model + registeredModel := &openapi.RegisteredModel{ + Name: modelName, + ExternalId: &modelExternalId, + State: &state, + CustomProperties: &map[string]openapi.MetadataValue{ + "myCustomProp": { + MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), + }, + }, + } + + // test + createdModel, err := service.UpsertRegisteredModel(registeredModel) + + // checks + suite.Nilf(err, "error creating registered model: %v", err) + + getModelById, err := service.GetRegisteredModelById(*createdModel.Id) + suite.Nilf(err, "error getting registered model by id %s: %v", *createdModel.Id, err) + + // checks created model matches original one except for Id + suite.Equal(registeredModel.Name, getModelById.Name, "saved model name should match the original one") + suite.Equal(*registeredModel.ExternalId, *getModelById.ExternalId, "saved model external id should match the original one") + suite.Equal(*registeredModel.State, *getModelById.State, "saved model state should match the original one") + suite.Equal(*registeredModel.CustomProperties, *getModelById.CustomProperties, "saved model custom props should match the original one") +} + +func (suite *CoreTestSuite) TestGetRegisteredModelByParamsWithNoResults() { + // create mode registry service + service := suite.setupModelRegistryService() + + _, err := service.GetRegisteredModelByParams(apiutils.Of("not-present"), nil) + suite.NotNil(err) + suite.Equal("no registered models found for name=not-present, externalId=: not found", err.Error()) +} + +func (suite *CoreTestSuite) TestGetRegisteredModelByParamsName() { + // create mode registry service + service := suite.setupModelRegistryService() + + // register a new model + registeredModel := &openapi.RegisteredModel{ + Name: modelName, + ExternalId: &modelExternalId, + } + + createdModel, err := service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + byName, err := service.GetRegisteredModelByParams(&modelName, nil) + suite.Nilf(err, "error getting registered model by name: %v", err) + + suite.Equalf(*createdModel.Id, *byName.Id, "the returned model id should match the retrieved by name") +} + +func (suite *CoreTestSuite) TestGetRegisteredModelByParamsExternalId() { + // create mode registry service + service := suite.setupModelRegistryService() + + // register a new model + registeredModel := &openapi.RegisteredModel{ + Name: modelName, + ExternalId: &modelExternalId, + } + + createdModel, err := service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + byName, err := service.GetRegisteredModelByParams(nil, &modelExternalId) + suite.Nilf(err, "error getting registered model by external id: %v", err) + + suite.Equalf(*createdModel.Id, *byName.Id, "the returned model id should match the retrieved by name") +} + +func (suite *CoreTestSuite) TestGetRegisteredModelByEmptyParams() { + // create mode registry service + service := suite.setupModelRegistryService() + + // register a new model + registeredModel := &openapi.RegisteredModel{ + Name: modelName, + ExternalId: &modelExternalId, + } + + _, err := service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + _, err = service.GetRegisteredModelByParams(nil, nil) + suite.NotNil(err) + suite.Equal("invalid parameters call, supply either name or externalId: bad request", err.Error()) +} + +func (suite *CoreTestSuite) TestGetRegisteredModelsOrderedById() { + // create mode registry service + service := suite.setupModelRegistryService() + + orderBy := "ID" + + // register a new model + registeredModel := &openapi.RegisteredModel{ + Name: modelName, + ExternalId: &modelExternalId, + } + + _, err := service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + newModelName := "PricingModel2" + newModelExternalId := "myExternalId2" + registeredModel.Name = newModelName + registeredModel.ExternalId = &newModelExternalId + _, err = service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + newModelName = "PricingModel3" + newModelExternalId = "myExternalId3" + registeredModel.Name = newModelName + registeredModel.ExternalId = &newModelExternalId + _, err = service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + orderedById, err := service.GetRegisteredModels(api.ListOptions{ + OrderBy: &orderBy, + SortOrder: &ascOrderDirection, + }) + suite.Nilf(err, "error getting registered models: %v", err) + + suite.Equal(3, int(orderedById.Size)) + for i := 0; i < int(orderedById.Size)-1; i++ { + suite.Less(*orderedById.Items[i].Id, *orderedById.Items[i+1].Id) + } + + orderedById, err = service.GetRegisteredModels(api.ListOptions{ + OrderBy: &orderBy, + SortOrder: &descOrderDirection, + }) + suite.Nilf(err, "error getting registered models: %v", err) + + suite.Equal(3, int(orderedById.Size)) + for i := 0; i < int(orderedById.Size)-1; i++ { + suite.Greater(*orderedById.Items[i].Id, *orderedById.Items[i+1].Id) + } +} + +func (suite *CoreTestSuite) TestGetRegisteredModelsOrderedByLastUpdate() { + // create mode registry service + service := suite.setupModelRegistryService() + + orderBy := "LAST_UPDATE_TIME" + + // register a new model + registeredModel := &openapi.RegisteredModel{ + Name: modelName, + ExternalId: &modelExternalId, + } + + firstModel, err := service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + newModelName := "PricingModel2" + newModelExternalId := "myExternalId2" + registeredModel.Name = newModelName + registeredModel.ExternalId = &newModelExternalId + secondModel, err := service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + newModelName = "PricingModel3" + newModelExternalId = "myExternalId3" + registeredModel.Name = newModelName + registeredModel.ExternalId = &newModelExternalId + thirdModel, err := service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + // update second model + secondModel.ExternalId = nil + _, err = service.UpsertRegisteredModel(secondModel) + suite.Nilf(err, "error creating registered model: %v", err) + + orderedById, err := service.GetRegisteredModels(api.ListOptions{ + OrderBy: &orderBy, + SortOrder: &ascOrderDirection, + }) + suite.Nilf(err, "error getting registered models: %v", err) + + suite.Equal(3, int(orderedById.Size)) + suite.Equal(*firstModel.Id, *orderedById.Items[0].Id) + suite.Equal(*thirdModel.Id, *orderedById.Items[1].Id) + suite.Equal(*secondModel.Id, *orderedById.Items[2].Id) + + orderedById, err = service.GetRegisteredModels(api.ListOptions{ + OrderBy: &orderBy, + SortOrder: &descOrderDirection, + }) + suite.Nilf(err, "error getting registered models: %v", err) + + suite.Equal(3, int(orderedById.Size)) + suite.Equal(*secondModel.Id, *orderedById.Items[0].Id) + suite.Equal(*thirdModel.Id, *orderedById.Items[1].Id) + suite.Equal(*firstModel.Id, *orderedById.Items[2].Id) +} + +func (suite *CoreTestSuite) TestGetRegisteredModelsWithPageSize() { + // create mode registry service + service := suite.setupModelRegistryService() + + pageSize := int32(1) + pageSize2 := int32(2) + modelName := "PricingModel1" + modelExternalId := "myExternalId1" + + // register a new model + registeredModel := &openapi.RegisteredModel{ + Name: modelName, + ExternalId: &modelExternalId, + } + + firstModel, err := service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + newModelName := "PricingModel2" + newModelExternalId := "myExternalId2" + registeredModel.Name = newModelName + registeredModel.ExternalId = &newModelExternalId + secondModel, err := service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + newModelName = "PricingModel3" + newModelExternalId = "myExternalId3" + registeredModel.Name = newModelName + registeredModel.ExternalId = &newModelExternalId + thirdModel, err := service.UpsertRegisteredModel(registeredModel) + suite.Nilf(err, "error creating registered model: %v", err) + + truncatedList, err := service.GetRegisteredModels(api.ListOptions{ + PageSize: &pageSize, + }) + suite.Nilf(err, "error getting registered models: %v", err) + + suite.Equal(1, int(truncatedList.Size)) + suite.NotEqual("", truncatedList.NextPageToken, "next page token should not be empty") + suite.Equal(*firstModel.Id, *truncatedList.Items[0].Id) + + truncatedList, err = service.GetRegisteredModels(api.ListOptions{ + PageSize: &pageSize2, + NextPageToken: &truncatedList.NextPageToken, + }) + suite.Nilf(err, "error getting registered models: %v", err) + + suite.Equal(2, int(truncatedList.Size)) + suite.Equal("", truncatedList.NextPageToken, "next page token should be empty as list item returned") + suite.Equal(*secondModel.Id, *truncatedList.Items[0].Id) + suite.Equal(*thirdModel.Id, *truncatedList.Items[1].Id) +} diff --git a/pkg/core/serve_model.go b/pkg/core/serve_model.go new file mode 100644 index 00000000..6226ea7e --- /dev/null +++ b/pkg/core/serve_model.go @@ -0,0 +1,220 @@ +package core + +import ( + "context" + "fmt" + + "github.com/golang/glog" + "github.com/kubeflow/model-registry/internal/apiutils" + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// SERVE MODEL + +// UpsertServeModel creates a new serve model if the provided serve model's ID is nil, +// or updates an existing serve model if the ID is provided. +func (serv *ModelRegistryService) UpsertServeModel(serveModel *openapi.ServeModel, inferenceServiceId *string) (*openapi.ServeModel, error) { + var err error + var existing *openapi.ServeModel + + if serveModel.Id == nil { + // create + glog.Info("Creating new ServeModel") + if inferenceServiceId == nil { + return nil, fmt.Errorf("missing inferenceServiceId, cannot create ServeModel without parent resource InferenceService: %w", api.ErrBadRequest) + } + _, err = serv.GetInferenceServiceById(*inferenceServiceId) + if err != nil { + return nil, err + } + } else { + // update + glog.Infof("Updating ServeModel %s", *serveModel.Id) + + existing, err = serv.GetServeModelById(*serveModel.Id) + if err != nil { + return nil, err + } + + withNotEditable, err := serv.openapiConv.OverrideNotEditableForServeModel(converter.NewOpenapiUpdateWrapper(existing, serveModel)) + if err != nil { + return nil, err + } + serveModel = &withNotEditable + + _, err = serv.getInferenceServiceByServeModel(*serveModel.Id) + if err != nil { + return nil, err + } + } + _, err = serv.GetModelVersionById(serveModel.ModelVersionId) + if err != nil { + return nil, err + } + + // if already existing assure the name is the same + if existing != nil && serveModel.Name == nil { + // user did not provide it + // need to set it to avoid mlmd error "artifact name should not be empty" + serveModel.Name = existing.Name + } + + execution, err := serv.mapper.MapFromServeModel(serveModel, *inferenceServiceId) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + executionsResp, err := serv.mlmdClient.PutExecutions(context.Background(), &proto.PutExecutionsRequest{ + Executions: []*proto.Execution{execution}, + }) + if err != nil { + return nil, err + } + + // add explicit Association between ServeModel and InferenceService + if inferenceServiceId != nil && serveModel.Id == nil { + inferenceServiceId, err := converter.StringToInt64(inferenceServiceId) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + associations := []*proto.Association{} + for _, a := range executionsResp.ExecutionIds { + associations = append(associations, &proto.Association{ + ContextId: inferenceServiceId, + ExecutionId: &a, + }) + } + _, err = serv.mlmdClient.PutAttributionsAndAssociations(context.Background(), &proto.PutAttributionsAndAssociationsRequest{ + Attributions: make([]*proto.Attribution, 0), + Associations: associations, + }) + if err != nil { + return nil, err + } + } + + idAsString := converter.Int64ToString(&executionsResp.ExecutionIds[0]) + mapped, err := serv.GetServeModelById(*idAsString) + if err != nil { + return nil, err + } + return mapped, nil +} + +// getInferenceServiceByServeModel retrieves the inference service associated with the specified serve model ID. +func (serv *ModelRegistryService) getInferenceServiceByServeModel(id string) (*openapi.InferenceService, error) { + glog.Infof("Getting InferenceService for ServeModel %s", id) + + idAsInt, err := converter.StringToInt64(&id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + getParentResp, err := serv.mlmdClient.GetContextsByExecution(context.Background(), &proto.GetContextsByExecutionRequest{ + ExecutionId: idAsInt, + }) + if err != nil { + return nil, err + } + + if len(getParentResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple InferenceService found for ServeModel %s: %w", id, api.ErrNotFound) + } + + if len(getParentResp.Contexts) == 0 { + return nil, fmt.Errorf("no InferenceService found for ServeModel %s: %w", id, api.ErrNotFound) + } + + toReturn, err := serv.mapper.MapToInferenceService(getParentResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return toReturn, nil +} + +// GetServeModelById retrieves a serve model by its unique identifier (ID). +func (serv *ModelRegistryService) GetServeModelById(id string) (*openapi.ServeModel, error) { + idAsInt, err := converter.StringToInt64(&id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + executionsResp, err := serv.mlmdClient.GetExecutionsByID(context.Background(), &proto.GetExecutionsByIDRequest{ + ExecutionIds: []int64{int64(*idAsInt)}, + }) + if err != nil { + return nil, err + } + + if len(executionsResp.Executions) > 1 { + return nil, fmt.Errorf("multiple ServeModels found for id %s: %w", id, api.ErrNotFound) + } + + if len(executionsResp.Executions) == 0 { + return nil, fmt.Errorf("no ServeModel found for id %s: %w", id, api.ErrNotFound) + } + + result, err := serv.mapper.MapToServeModel(executionsResp.Executions[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return result, nil +} + +// GetServeModels retrieves a list of serve models based on the provided list options and optional inference service ID. +func (serv *ModelRegistryService) GetServeModels(listOptions api.ListOptions, inferenceServiceId *string) (*openapi.ServeModelList, error) { + listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + var executions []*proto.Execution + var nextPageToken *string + if inferenceServiceId != nil { + ctxId, err := converter.StringToInt64(inferenceServiceId) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + executionsResp, err := serv.mlmdClient.GetExecutionsByContext(context.Background(), &proto.GetExecutionsByContextRequest{ + ContextId: ctxId, + Options: listOperationOptions, + }) + if err != nil { + return nil, err + } + executions = executionsResp.Executions + nextPageToken = executionsResp.NextPageToken + } else { + executionsResp, err := serv.mlmdClient.GetExecutionsByType(context.Background(), &proto.GetExecutionsByTypeRequest{ + TypeName: &serv.nameConfig.ServeModelTypeName, + Options: listOperationOptions, + }) + if err != nil { + return nil, err + } + executions = executionsResp.Executions + nextPageToken = executionsResp.NextPageToken + } + + results := []openapi.ServeModel{} + for _, a := range executions { + mapped, err := serv.mapper.MapToServeModel(a) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + results = append(results, *mapped) + } + + toReturn := openapi.ServeModelList{ + NextPageToken: apiutils.ZeroIfNil(nextPageToken), + PageSize: apiutils.ZeroIfNil(listOptions.PageSize), + Size: int32(len(results)), + Items: results, + } + return &toReturn, nil +} diff --git a/pkg/core/serve_model_test.go b/pkg/core/serve_model_test.go new file mode 100644 index 00000000..e0bd9d8f --- /dev/null +++ b/pkg/core/serve_model_test.go @@ -0,0 +1,360 @@ +package core + +import ( + "context" + "fmt" + + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// SERVE MODEL + +func (suite *CoreTestSuite) TestCreateServeModel() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) + + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + Description: &modelVersionDescription, + Author: &author, + } + createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + createdVersionId := *createdVersion.Id + createdVersionIdAsInt, _ := converter.StringToInt64(&createdVersionId) + // end of data preparation + + eut := &openapi.ServeModel{ + LastKnownState: (*openapi.ExecutionState)(&executionState), + ExternalId: &entityExternalId2, + Description: &entityDescription, + Name: &entityName, + ModelVersionId: createdVersionId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity, err := service.UpsertServeModel(eut, &inferenceServiceId) + suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) + suite.NotNil(createdEntity.Id, "created id should not be nil") + + state, _ := openapi.NewExecutionStateFromValue(executionState) + suite.Equal(entityName, *createdEntity.Name) + suite.Equal(*state, *createdEntity.LastKnownState) + suite.Equal(createdVersionId, createdEntity.ModelVersionId) + suite.Equal(entityDescription, *createdEntity.Description) + suite.Equal(customString, (*createdEntity.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) + + createdEntityId, _ := converter.StringToInt64(createdEntity.Id) + getById, err := suite.mlmdClient.GetExecutionsByID(context.Background(), &proto.GetExecutionsByIDRequest{ + ExecutionIds: []int64{*createdEntityId}, + }) + suite.Nilf(err, "error getting Execution by id %d", createdEntityId) + + suite.Equal(*createdEntityId, *getById.Executions[0].Id) + suite.Equal(fmt.Sprintf("%s:%s", inferenceServiceId, *createdEntity.Name), *getById.Executions[0].Name) + suite.Equal(string(*createdEntity.LastKnownState), getById.Executions[0].LastKnownState.String()) + suite.Equal(*createdVersionIdAsInt, getById.Executions[0].Properties["model_version_id"].GetIntValue()) + suite.Equal(*createdEntity.Description, getById.Executions[0].Properties["description"].GetStringValue()) + suite.Equal((*createdEntity.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, getById.Executions[0].CustomProperties["custom_string_prop"].GetStringValue()) + + inferenceServiceIdAsInt, _ := converter.StringToInt64(&inferenceServiceId) + byCtx, _ := suite.mlmdClient.GetExecutionsByContext(context.Background(), &proto.GetExecutionsByContextRequest{ + ContextId: (*int64)(inferenceServiceIdAsInt), + }) + suite.Equal(1, len(byCtx.Executions)) + suite.Equal(*createdEntityId, *byCtx.Executions[0].Id) +} + +func (suite *CoreTestSuite) TestCreateServeModelFailure() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) + // end of data preparation + + eut := &openapi.ServeModel{ + LastKnownState: (*openapi.ExecutionState)(&executionState), + ExternalId: &entityExternalId2, + Description: &entityDescription, + Name: &entityName, + ModelVersionId: "9998", + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + _, err := service.UpsertServeModel(eut, nil) + suite.NotNil(err) + suite.Equal("missing inferenceServiceId, cannot create ServeModel without parent resource InferenceService: bad request", err.Error()) + + _, err = service.UpsertServeModel(eut, &inferenceServiceId) + suite.NotNil(err) + suite.Equal("no model version found for id 9998: not found", err.Error()) +} + +func (suite *CoreTestSuite) TestUpdateServeModel() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) + + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + Description: &modelVersionDescription, + Author: &author, + } + createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + createdVersionId := *createdVersion.Id + createdVersionIdAsInt, _ := converter.StringToInt64(&createdVersionId) + // end of data preparation + + eut := &openapi.ServeModel{ + LastKnownState: (*openapi.ExecutionState)(&executionState), + ExternalId: &entityExternalId2, + Description: &entityDescription, + Name: &entityName, + ModelVersionId: createdVersionId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity, err := service.UpsertServeModel(eut, &inferenceServiceId) + suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) + + newState := "UNKNOWN" + createdEntity.LastKnownState = (*openapi.ExecutionState)(&newState) + updatedEntity, err := service.UpsertServeModel(createdEntity, &inferenceServiceId) + suite.Nilf(err, "error updating entity for %d: %v", inferenceServiceId, err) + + createdEntityId, _ := converter.StringToInt64(createdEntity.Id) + updatedEntityId, _ := converter.StringToInt64(updatedEntity.Id) + suite.Equal(createdEntityId, updatedEntityId) + + getById, err := suite.mlmdClient.GetExecutionsByID(context.Background(), &proto.GetExecutionsByIDRequest{ + ExecutionIds: []int64{*createdEntityId}, + }) + suite.Nilf(err, "error getting by id %d", createdEntityId) + + suite.Equal(*createdEntityId, *getById.Executions[0].Id) + suite.Equal(fmt.Sprintf("%s:%s", inferenceServiceId, *createdEntity.Name), *getById.Executions[0].Name) + suite.Equal(string(newState), getById.Executions[0].LastKnownState.String()) + suite.Equal(*createdVersionIdAsInt, getById.Executions[0].Properties["model_version_id"].GetIntValue()) + suite.Equal((*createdEntity.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue, getById.Executions[0].CustomProperties["custom_string_prop"].GetStringValue()) + + prevModelVersionId := updatedEntity.ModelVersionId + updatedEntity.ModelVersionId = "" + updatedEntity, err = service.UpsertServeModel(updatedEntity, &inferenceServiceId) + suite.Nilf(err, "error updating entity for %d: %v", inferenceServiceId, err) + suite.Equal(prevModelVersionId, updatedEntity.ModelVersionId) +} + +func (suite *CoreTestSuite) TestUpdateServeModelFailure() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) + + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + Description: &modelVersionDescription, + Author: &author, + } + createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + createdVersionId := *createdVersion.Id + // end of data preparation + + eut := &openapi.ServeModel{ + LastKnownState: (*openapi.ExecutionState)(&executionState), + ExternalId: &entityExternalId2, + Description: &entityDescription, + Name: &entityName, + ModelVersionId: createdVersionId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity, err := service.UpsertServeModel(eut, &inferenceServiceId) + suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) + suite.NotNil(createdEntity.Id, "created id should not be nil") + + newState := "UNKNOWN" + createdEntity.LastKnownState = (*openapi.ExecutionState)(&newState) + updatedEntity, err := service.UpsertServeModel(createdEntity, &inferenceServiceId) + suite.Nilf(err, "error updating entity for %d: %v", inferenceServiceId, err) + + wrongId := "9998" + updatedEntity.Id = &wrongId + _, err = service.UpsertServeModel(updatedEntity, &inferenceServiceId) + suite.NotNil(err) + suite.Equal(fmt.Sprintf("no ServeModel found for id %s: not found", wrongId), err.Error()) +} + +func (suite *CoreTestSuite) TestGetServeModelById() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) + + modelVersion := &openapi.ModelVersion{ + Name: modelVersionName, + ExternalId: &versionExternalId, + Description: &modelVersionDescription, + Author: &author, + } + createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + createdVersionId := *createdVersion.Id + // end of data preparation + + eut := &openapi.ServeModel{ + LastKnownState: (*openapi.ExecutionState)(&executionState), + ExternalId: &entityExternalId2, + Description: &entityDescription, + Name: &entityName, + ModelVersionId: createdVersionId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity, err := service.UpsertServeModel(eut, &inferenceServiceId) + suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) + + getById, err := service.GetServeModelById(*createdEntity.Id) + suite.Nilf(err, "error getting entity by id %d", *createdEntity.Id) + + state, _ := openapi.NewExecutionStateFromValue(executionState) + suite.NotNil(createdEntity.Id, "created artifact id should not be nil") + suite.Equal(entityName, *getById.Name) + suite.Equal(*state, *getById.LastKnownState) + suite.Equal(createdVersionId, getById.ModelVersionId) + suite.Equal(customString, (*getById.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) + + suite.Equal(*createdEntity, *getById, "artifacts returned during creation and on get by id should be equal") +} + +func (suite *CoreTestSuite) TestGetServeModels() { + // create mode registry service + service := suite.setupModelRegistryService() + + registeredModelId := suite.registerModel(service, nil, nil) + inferenceServiceId := suite.registerInferenceService(service, registeredModelId, nil, nil, nil, nil) + + modelVersion1Name := "v1" + modelVersion1 := &openapi.ModelVersion{Name: modelVersion1Name, Description: &modelVersionDescription} + createdVersion1, err := service.UpsertModelVersion(modelVersion1, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + createdVersion1Id := *createdVersion1.Id + + modelVersion2Name := "v2" + modelVersion2 := &openapi.ModelVersion{Name: modelVersion2Name, Description: &modelVersionDescription} + createdVersion2, err := service.UpsertModelVersion(modelVersion2, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + createdVersion2Id := *createdVersion2.Id + + modelVersion3Name := "v3" + modelVersion3 := &openapi.ModelVersion{Name: modelVersion3Name, Description: &modelVersionDescription} + createdVersion3, err := service.UpsertModelVersion(modelVersion3, ®isteredModelId) + suite.Nilf(err, "error creating new model version for %d", registeredModelId) + createdVersion3Id := *createdVersion3.Id + // end of data preparation + + eut1Name := "sm1" + eut1 := &openapi.ServeModel{ + LastKnownState: (*openapi.ExecutionState)(&executionState), + Description: &entityDescription, + Name: &eut1Name, + ModelVersionId: createdVersion1Id, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + eut2Name := "sm2" + eut2 := &openapi.ServeModel{ + LastKnownState: (*openapi.ExecutionState)(&executionState), + Description: &entityDescription, + Name: &eut2Name, + ModelVersionId: createdVersion2Id, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + eut3Name := "sm3" + eut3 := &openapi.ServeModel{ + LastKnownState: (*openapi.ExecutionState)(&executionState), + Description: &entityDescription, + Name: &eut3Name, + ModelVersionId: createdVersion3Id, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + createdEntity1, err := service.UpsertServeModel(eut1, &inferenceServiceId) + suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) + createdEntity2, err := service.UpsertServeModel(eut2, &inferenceServiceId) + suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) + createdEntity3, err := service.UpsertServeModel(eut3, &inferenceServiceId) + suite.Nilf(err, "error creating new ServeModel for %d", inferenceServiceId) + + createdEntityId1, _ := converter.StringToInt64(createdEntity1.Id) + createdEntityId2, _ := converter.StringToInt64(createdEntity2.Id) + createdEntityId3, _ := converter.StringToInt64(createdEntity3.Id) + + getAll, err := service.GetServeModels(api.ListOptions{}, nil) + suite.Nilf(err, "error getting all ServeModel") + suite.Equalf(int32(3), getAll.Size, "expected three ServeModel") + + suite.Equal(*converter.Int64ToString(createdEntityId1), *getAll.Items[0].Id) + suite.Equal(*converter.Int64ToString(createdEntityId2), *getAll.Items[1].Id) + suite.Equal(*converter.Int64ToString(createdEntityId3), *getAll.Items[2].Id) + + orderByLastUpdate := "LAST_UPDATE_TIME" + getAllByInferenceService, err := service.GetServeModels(api.ListOptions{ + OrderBy: &orderByLastUpdate, + SortOrder: &descOrderDirection, + }, &inferenceServiceId) + suite.Nilf(err, "error getting all ServeModels for %d", inferenceServiceId) + suite.Equalf(int32(3), getAllByInferenceService.Size, "expected three ServeModels for InferenceServiceId %d", inferenceServiceId) + + suite.Equal(*converter.Int64ToString(createdEntityId1), *getAllByInferenceService.Items[2].Id) + suite.Equal(*converter.Int64ToString(createdEntityId2), *getAllByInferenceService.Items[1].Id) + suite.Equal(*converter.Int64ToString(createdEntityId3), *getAllByInferenceService.Items[0].Id) +} diff --git a/pkg/core/serving_environment.go b/pkg/core/serving_environment.go new file mode 100644 index 00000000..9d4f306c --- /dev/null +++ b/pkg/core/serving_environment.go @@ -0,0 +1,163 @@ +package core + +import ( + "context" + "fmt" + + "github.com/golang/glog" + "github.com/kubeflow/model-registry/internal/apiutils" + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// SERVING ENVIRONMENT + +// UpsertServingEnvironment creates a new serving environment if the provided serving environment's ID is nil, +// or updates an existing serving environment if the ID is provided. +func (serv *ModelRegistryService) UpsertServingEnvironment(servingEnvironment *openapi.ServingEnvironment) (*openapi.ServingEnvironment, error) { + var err error + var existing *openapi.ServingEnvironment + + if servingEnvironment.Id == nil { + glog.Info("Creating new serving environment") + } else { + glog.Infof("Updating serving environment %s", *servingEnvironment.Id) + existing, err = serv.GetServingEnvironmentById(*servingEnvironment.Id) + if err != nil { + return nil, err + } + + withNotEditable, err := serv.openapiConv.OverrideNotEditableForServingEnvironment(converter.NewOpenapiUpdateWrapper(existing, servingEnvironment)) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + servingEnvironment = &withNotEditable + } + + protoCtx, err := serv.mapper.MapFromServingEnvironment(servingEnvironment) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + protoCtxResp, err := serv.mlmdClient.PutContexts(context.Background(), &proto.PutContextsRequest{ + Contexts: []*proto.Context{ + protoCtx, + }, + }) + if err != nil { + return nil, err + } + + idAsString := converter.Int64ToString(&protoCtxResp.ContextIds[0]) + openapiModel, err := serv.GetServingEnvironmentById(*idAsString) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return openapiModel, nil +} + +// GetServingEnvironmentById retrieves a serving environment by its unique identifier (ID). +func (serv *ModelRegistryService) GetServingEnvironmentById(id string) (*openapi.ServingEnvironment, error) { + glog.Infof("Getting serving environment %s", id) + + idAsInt, err := converter.StringToInt64(&id) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{*idAsInt}, + }) + if err != nil { + return nil, err + } + + if len(getByIdResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple serving environments found for id %s: %w", id, api.ErrNotFound) + } + + if len(getByIdResp.Contexts) == 0 { + return nil, fmt.Errorf("no serving environment found for id %s: %w", id, api.ErrNotFound) + } + + openapiModel, err := serv.mapper.MapToServingEnvironment(getByIdResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return openapiModel, nil +} + +// GetServingEnvironmentByParams retrieves a serving environment based on specified parameters, such as name or external ID. +// If multiple or no serving environments are found, an error is returned accordingly. +func (serv *ModelRegistryService) GetServingEnvironmentByParams(name *string, externalId *string) (*openapi.ServingEnvironment, error) { + glog.Infof("Getting serving environment by params name=%v, externalId=%v", name, externalId) + + filterQuery := "" + if name != nil { + filterQuery = fmt.Sprintf("name = \"%s\"", *name) + } else if externalId != nil { + filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) + } else { + return nil, fmt.Errorf("invalid parameters call, supply either name or externalId: %w", api.ErrBadRequest) + } + + getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ + TypeName: &serv.nameConfig.ServingEnvironmentTypeName, + Options: &proto.ListOperationOptions{ + FilterQuery: &filterQuery, + }, + }) + if err != nil { + return nil, err + } + + if len(getByParamsResp.Contexts) > 1 { + return nil, fmt.Errorf("multiple serving environments found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + if len(getByParamsResp.Contexts) == 0 { + return nil, fmt.Errorf("no serving environments found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + openapiModel, err := serv.mapper.MapToServingEnvironment(getByParamsResp.Contexts[0]) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + return openapiModel, nil +} + +// GetServingEnvironments retrieves a list of serving environments based on the provided list options. +func (serv *ModelRegistryService) GetServingEnvironments(listOptions api.ListOptions) (*openapi.ServingEnvironmentList, error) { + listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ + TypeName: &serv.nameConfig.ServingEnvironmentTypeName, + Options: listOperationOptions, + }) + if err != nil { + return nil, err + } + + results := []openapi.ServingEnvironment{} + for _, c := range contextsResp.Contexts { + mapped, err := serv.mapper.MapToServingEnvironment(c) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + results = append(results, *mapped) + } + + toReturn := openapi.ServingEnvironmentList{ + NextPageToken: apiutils.ZeroIfNil(contextsResp.NextPageToken), + PageSize: apiutils.ZeroIfNil(listOptions.PageSize), + Size: int32(len(results)), + Items: results, + } + return &toReturn, nil +} diff --git a/pkg/core/serving_environment_test.go b/pkg/core/serving_environment_test.go new file mode 100644 index 00000000..335113f8 --- /dev/null +++ b/pkg/core/serving_environment_test.go @@ -0,0 +1,395 @@ +package core + +import ( + "context" + + "github.com/kubeflow/model-registry/internal/apiutils" + "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +// SERVING ENVIRONMENT + +func (suite *CoreTestSuite) TestCreateServingEnvironment() { + // create mode registry service + service := suite.setupModelRegistryService() + + // register a new ServingEnvironment + eut := &openapi.ServingEnvironment{ + Name: &entityName, + ExternalId: &entityExternalId, + Description: &entityDescription, + CustomProperties: &map[string]openapi.MetadataValue{ + "myCustomProp": { + MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), + }, + }, + } + + // test + createdEntity, err := service.UpsertServingEnvironment(eut) + + // checks + suite.Nilf(err, "error creating uut: %v", err) + suite.NotNilf(createdEntity.Id, "created uut should not have nil Id") + + createdEntityId, _ := converter.StringToInt64(createdEntity.Id) + ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{*createdEntityId}, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + + ctx := ctxById.Contexts[0] + ctxId := converter.Int64ToString(ctx.Id) + suite.Equal(*createdEntity.Id, *ctxId, "returned id should match the mlmd one") + suite.Equal(entityName, *ctx.Name, "saved name should match the provided one") + suite.Equal(entityExternalId, *ctx.ExternalId, "saved external id should match the provided one") + suite.Equal(entityDescription, ctx.Properties["description"].GetStringValue(), "saved description should match the provided one") + suite.Equal(myCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") + + getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) + suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) + suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") +} + +func (suite *CoreTestSuite) TestUpdateServingEnvironment() { + // create mode registry service + service := suite.setupModelRegistryService() + + // register a new ServingEnvironment + eut := &openapi.ServingEnvironment{ + Name: &entityName, + ExternalId: &entityExternalId, + CustomProperties: &map[string]openapi.MetadataValue{ + "myCustomProp": { + MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), + }, + }, + } + + // test + createdEntity, err := service.UpsertServingEnvironment(eut) + + // checks + suite.Nilf(err, "error creating uut: %v", err) + suite.NotNilf(createdEntity.Id, "created uut should not have nil Id") + createdEntityId, _ := converter.StringToInt64(createdEntity.Id) + + // checks created entity matches original one except for Id + suite.Equal(*eut.Name, *createdEntity.Name, "returned entity should match the original one") + suite.Equal(*eut.ExternalId, *createdEntity.ExternalId, "returned entity external id should match the original one") + suite.Equal(*eut.CustomProperties, *createdEntity.CustomProperties, "returned entity custom props should match the original one") + + // update existing entity + newExternalId := "newExternalId" + newCustomProp := "newCustomProp" + + createdEntity.ExternalId = &newExternalId + (*createdEntity.CustomProperties)["myCustomProp"] = openapi.MetadataValue{ + MetadataStringValue: converter.NewMetadataStringValue(newCustomProp), + } + + // update the entity + createdEntity, err = service.UpsertServingEnvironment(createdEntity) + suite.Nilf(err, "error creating uut: %v", err) + + // still one expected MLMD type + getAllResp, err := suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) + suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) + suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") + + ctxById, err := suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{*createdEntityId}, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + + ctx := ctxById.Contexts[0] + ctxId := converter.Int64ToString(ctx.Id) + suite.Equal(*createdEntity.Id, *ctxId, "returned entity id should match the mlmd one") + suite.Equal(entityName, *ctx.Name, "saved entity name should match the provided one") + suite.Equal(newExternalId, *ctx.ExternalId, "saved external id should match the provided one") + suite.Equal(newCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") + + // update the entity under test, keeping nil name + newExternalId = "newNewExternalId" + createdEntity.ExternalId = &newExternalId + createdEntity.Name = nil + createdEntity, err = service.UpsertServingEnvironment(createdEntity) + suite.Nilf(err, "error creating entity: %v", err) + + // still one registered entity + getAllResp, err = suite.mlmdClient.GetContexts(context.Background(), &proto.GetContextsRequest{}) + suite.Nilf(err, "error retrieving all contexts, not related to the test itself: %v", err) + suite.Equal(1, len(getAllResp.Contexts), "there should be just one context saved in mlmd") + + ctxById, err = suite.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ + ContextIds: []int64{*createdEntityId}, + }) + suite.Nilf(err, "error retrieving context by type and name, not related to the test itself: %v", err) + + ctx = ctxById.Contexts[0] + ctxId = converter.Int64ToString(ctx.Id) + suite.Equal(*createdEntity.Id, *ctxId, "returned entity id should match the mlmd one") + suite.Equal(entityName, *ctx.Name, "saved entity name should match the provided one") + suite.Equal(newExternalId, *ctx.ExternalId, "saved external id should match the provided one") + suite.Equal(newCustomProp, ctx.CustomProperties["myCustomProp"].GetStringValue(), "saved myCustomProp custom property should match the provided one") +} + +func (suite *CoreTestSuite) TestGetServingEnvironmentById() { + // create mode registry service + service := suite.setupModelRegistryService() + + // register a new entity + eut := &openapi.ServingEnvironment{ + Name: &entityName, + ExternalId: &entityExternalId, + CustomProperties: &map[string]openapi.MetadataValue{ + "myCustomProp": { + MetadataStringValue: converter.NewMetadataStringValue(myCustomProp), + }, + }, + } + + // test + createdEntity, err := service.UpsertServingEnvironment(eut) + + // checks + suite.Nilf(err, "error creating eut: %v", err) + + getEntityById, err := service.GetServingEnvironmentById(*createdEntity.Id) + suite.Nilf(err, "error getting eut by id %s: %v", *createdEntity.Id, err) + + // checks created entity matches original one except for Id + suite.Equal(*eut.Name, *getEntityById.Name, "saved name should match the original one") + suite.Equal(*eut.ExternalId, *getEntityById.ExternalId, "saved external id should match the original one") + suite.Equal(*eut.CustomProperties, *getEntityById.CustomProperties, "saved custom props should match the original one") +} + +func (suite *CoreTestSuite) TestGetServingEnvironmentByParamsWithNoResults() { + // create mode registry service + service := suite.setupModelRegistryService() + + _, err := service.GetServingEnvironmentByParams(apiutils.Of("not-present"), nil) + suite.NotNil(err) + suite.Equal("no serving environments found for name=not-present, externalId=: not found", err.Error()) +} + +func (suite *CoreTestSuite) TestGetServingEnvironmentByParamsName() { + // create mode registry service + service := suite.setupModelRegistryService() + + // register a new ServingEnvironment + eut := &openapi.ServingEnvironment{ + Name: &entityName, + ExternalId: &entityExternalId, + } + + createdEntity, err := service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + byName, err := service.GetServingEnvironmentByParams(&entityName, nil) + suite.Nilf(err, "error getting ServingEnvironment by name: %v", err) + + suite.Equalf(*createdEntity.Id, *byName.Id, "the returned entity id should match the retrieved by name") +} + +func (suite *CoreTestSuite) TestGetServingEnvironmentByParamsExternalId() { + // create mode registry service + service := suite.setupModelRegistryService() + + // register a new ServingEnvironment + eut := &openapi.ServingEnvironment{ + Name: &entityName, + ExternalId: &entityExternalId, + } + + createdEntity, err := service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + byName, err := service.GetServingEnvironmentByParams(nil, &entityExternalId) + suite.Nilf(err, "error getting ServingEnvironment by external id: %v", err) + + suite.Equalf(*createdEntity.Id, *byName.Id, "the returned entity id should match the retrieved by name") +} + +func (suite *CoreTestSuite) TestGetServingEnvironmentByEmptyParams() { + // create mode registry service + service := suite.setupModelRegistryService() + + // register a new ServingEnvironment + eut := &openapi.ServingEnvironment{ + Name: &entityName, + ExternalId: &entityExternalId, + } + + _, err := service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + _, err = service.GetServingEnvironmentByParams(nil, nil) + suite.NotNil(err) + suite.Equal("invalid parameters call, supply either name or externalId: bad request", err.Error()) +} + +func (suite *CoreTestSuite) TestGetServingEnvironmentsOrderedById() { + // create mode registry service + service := suite.setupModelRegistryService() + + orderBy := "ID" + + // register a new ServingEnvironment + eut := &openapi.ServingEnvironment{ + Name: &entityName, + ExternalId: &entityExternalId, + } + + _, err := service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + newName := "Pricingentity2" + newExternalId := "myExternalId2" + eut.Name = &newName + eut.ExternalId = &newExternalId + _, err = service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + newName = "Pricingentity3" + newExternalId = "myExternalId3" + eut.Name = &newName + eut.ExternalId = &newExternalId + _, err = service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + orderedById, err := service.GetServingEnvironments(api.ListOptions{ + OrderBy: &orderBy, + SortOrder: &ascOrderDirection, + }) + suite.Nilf(err, "error getting ServingEnvironment: %v", err) + + suite.Equal(3, int(orderedById.Size)) + for i := 0; i < int(orderedById.Size)-1; i++ { + suite.Less(*orderedById.Items[i].Id, *orderedById.Items[i+1].Id) + } + + orderedById, err = service.GetServingEnvironments(api.ListOptions{ + OrderBy: &orderBy, + SortOrder: &descOrderDirection, + }) + suite.Nilf(err, "error getting ServingEnvironments: %v", err) + + suite.Equal(3, int(orderedById.Size)) + for i := 0; i < int(orderedById.Size)-1; i++ { + suite.Greater(*orderedById.Items[i].Id, *orderedById.Items[i+1].Id) + } +} + +func (suite *CoreTestSuite) TestGetServingEnvironmentsOrderedByLastUpdate() { + // create mode registry service + service := suite.setupModelRegistryService() + + orderBy := "LAST_UPDATE_TIME" + + // register a new ServingEnvironment + eut := &openapi.ServingEnvironment{ + Name: &entityName, + ExternalId: &entityExternalId, + } + + firstEntity, err := service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + newName := "Pricingentity2" + newExternalId := "myExternalId2" + eut.Name = &newName + eut.ExternalId = &newExternalId + secondEntity, err := service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + newName = "Pricingentity3" + newExternalId = "myExternalId3" + eut.Name = &newName + eut.ExternalId = &newExternalId + thirdEntity, err := service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + // update second entity + secondEntity.ExternalId = nil + _, err = service.UpsertServingEnvironment(secondEntity) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + orderedById, err := service.GetServingEnvironments(api.ListOptions{ + OrderBy: &orderBy, + SortOrder: &ascOrderDirection, + }) + suite.Nilf(err, "error getting ServingEnvironments: %v", err) + + suite.Equal(3, int(orderedById.Size)) + suite.Equal(*firstEntity.Id, *orderedById.Items[0].Id) + suite.Equal(*thirdEntity.Id, *orderedById.Items[1].Id) + suite.Equal(*secondEntity.Id, *orderedById.Items[2].Id) + + orderedById, err = service.GetServingEnvironments(api.ListOptions{ + OrderBy: &orderBy, + SortOrder: &descOrderDirection, + }) + suite.Nilf(err, "error getting ServingEnvironments: %v", err) + + suite.Equal(3, int(orderedById.Size)) + suite.Equal(*secondEntity.Id, *orderedById.Items[0].Id) + suite.Equal(*thirdEntity.Id, *orderedById.Items[1].Id) + suite.Equal(*firstEntity.Id, *orderedById.Items[2].Id) +} + +func (suite *CoreTestSuite) TestGetServingEnvironmentsWithPageSize() { + // create mode registry service + service := suite.setupModelRegistryService() + + pageSize := int32(1) + pageSize2 := int32(2) + entityName := "Pricingentity1" + entityExternalId := "myExternalId1" + + // register a new ServingEnvironment + eut := &openapi.ServingEnvironment{ + Name: &entityName, + ExternalId: &entityExternalId, + } + + firstEntity, err := service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating registered entity: %v", err) + + newName := "Pricingentity2" + newExternalId := "myExternalId2" + eut.Name = &newName + eut.ExternalId = &newExternalId + secondEntity, err := service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + newName = "Pricingentity3" + newExternalId = "myExternalId3" + eut.Name = &newName + eut.ExternalId = &newExternalId + thirdEntity, err := service.UpsertServingEnvironment(eut) + suite.Nilf(err, "error creating ServingEnvironment: %v", err) + + truncatedList, err := service.GetServingEnvironments(api.ListOptions{ + PageSize: &pageSize, + }) + suite.Nilf(err, "error getting ServingEnvironments: %v", err) + + suite.Equal(1, int(truncatedList.Size)) + suite.NotEqual("", truncatedList.NextPageToken, "next page token should not be empty") + suite.Equal(*firstEntity.Id, *truncatedList.Items[0].Id) + + truncatedList, err = service.GetServingEnvironments(api.ListOptions{ + PageSize: &pageSize2, + NextPageToken: &truncatedList.NextPageToken, + }) + suite.Nilf(err, "error getting ServingEnvironments: %v", err) + + suite.Equal(2, int(truncatedList.Size)) + suite.Equal("", truncatedList.NextPageToken, "next page token should be empty as list item returned") + suite.Equal(*secondEntity.Id, *truncatedList.Items[0].Id) + suite.Equal(*thirdEntity.Id, *truncatedList.Items[1].Id) +}