From c544791e10d90fe71ece85af21ca7bc37bca0bec Mon Sep 17 00:00:00 2001 From: Matteo Mortari Date: Mon, 26 Feb 2024 10:24:16 +0100 Subject: [PATCH] kubeflow: make MLMD type names (and prefix) pluggable Signed-off-by: Matteo Mortari --- .../python/src/model_registry/types/base.py | 4 +- cmd/proxy.go | 5 +- internal/constants/constants.go | 12 --- .../converter/mlmd_converter_util_test.go | 20 ++-- .../converter/mlmd_openapi_converter_util.go | 6 +- .../converter/openapi_mlmd_converter_util.go | 16 ++-- internal/defaults/defaults.go | 12 +++ internal/mapper/mapper.go | 34 +++---- internal/mapper/mapper_test.go | 54 +++++------ internal/mlmdtypes/mlmdtypes.go | 94 +++++++++++-------- pkg/core/core.go | 84 ++++++++--------- pkg/core/core_test.go | 42 ++++++--- test/robot/MRandLogicalModel.robot | 6 +- 13 files changed, 203 insertions(+), 186 deletions(-) delete mode 100644 internal/constants/constants.go create mode 100644 internal/defaults/defaults.go diff --git a/clients/python/src/model_registry/types/base.py b/clients/python/src/model_registry/types/base.py index 2e3834a5..1d67bcce 100644 --- a/clients/python/src/model_registry/types/base.py +++ b/clients/python/src/model_registry/types/base.py @@ -20,9 +20,9 @@ def get_proto_type_name(cls) -> str: """Name of the proto type. Returns: - Name of the class prefixed with `kfmr.` + Name of the class prefixed with `kf.` """ - return f"kfmr.{cls.__name__}" + return f"kf.{cls.__name__}" @property @abstractmethod diff --git a/cmd/proxy.go b/cmd/proxy.go index f69d16b9..f940ce56 100644 --- a/cmd/proxy.go +++ b/cmd/proxy.go @@ -49,11 +49,12 @@ func runProxyServer(cmd *cobra.Command, args []string) error { defer conn.Close() glog.Infof("connected to MLMD server") - _, err = mlmdtypes.CreateMLMDTypes(conn) + mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults() + _, err = mlmdtypes.CreateMLMDTypes(conn, mlmdTypeNamesConfig) if err != nil { return fmt.Errorf("error creating MLMD types: %v", err) } - service, err := core.NewModelRegistryService(conn) + service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig) if err != nil { return fmt.Errorf("error creating core service: %v", err) } diff --git a/internal/constants/constants.go b/internal/constants/constants.go deleted file mode 100644 index 70427926..00000000 --- a/internal/constants/constants.go +++ /dev/null @@ -1,12 +0,0 @@ -package constants - -// MLMD type names -const ( - RegisteredModelTypeName = "kfmr.RegisteredModel" - ModelVersionTypeName = "kfmr.ModelVersion" - ModelArtifactTypeName = "kfmr.ModelArtifact" - DocArtifactTypeName = "kfmr.DocArtifact" - ServingEnvironmentTypeName = "kfmr.ServingEnvironment" - InferenceServiceTypeName = "kfmr.InferenceService" - ServeModelTypeName = "kfmr.ServeModel" -) diff --git a/internal/converter/mlmd_converter_util_test.go b/internal/converter/mlmd_converter_util_test.go index 5eddc08a..72a0adfb 100644 --- a/internal/converter/mlmd_converter_util_test.go +++ b/internal/converter/mlmd_converter_util_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - "github.com/kubeflow/model-registry/internal/constants" + "github.com/kubeflow/model-registry/internal/defaults" "github.com/kubeflow/model-registry/internal/ml_metadata/proto" "github.com/kubeflow/model-registry/pkg/openapi" "github.com/stretchr/testify/assert" @@ -208,7 +208,7 @@ func TestMapRegisteredModelType(t *testing.T) { typeName := MapRegisteredModelType(&openapi.RegisteredModel{}) assertion.NotNil(typeName) - assertion.Equal(constants.RegisteredModelTypeName, *typeName) + assertion.Equal(defaults.RegisteredModelTypeName, *typeName) } func TestMapModelVersionProperties(t *testing.T) { @@ -236,7 +236,7 @@ func TestMapModelVersionType(t *testing.T) { typeName := MapModelVersionType(&openapi.ModelVersion{}) assertion.NotNil(typeName) - assertion.Equal(constants.ModelVersionTypeName, *typeName) + assertion.Equal(defaults.ModelVersionTypeName, *typeName) } func TestMapModelVersionName(t *testing.T) { @@ -287,7 +287,7 @@ func TestMapModelArtifactType(t *testing.T) { typeName := MapModelArtifactType(&openapi.ModelArtifact{}) assertion.NotNil(typeName) - assertion.Equal(constants.ModelArtifactTypeName, *typeName) + assertion.Equal(defaults.ModelArtifactTypeName, *typeName) } func TestMapModelArtifactName(t *testing.T) { @@ -346,7 +346,7 @@ func TestMapDocArtifactType(t *testing.T) { typeName := MapModelArtifactType(&openapi.ModelArtifact{}) assertion.NotNil(typeName) - assertion.Equal(constants.ModelArtifactTypeName, *typeName) + assertion.Equal(defaults.ModelArtifactTypeName, *typeName) } func TestMapDocArtifactName(t *testing.T) { @@ -577,13 +577,13 @@ func TestMapArtifactType(t *testing.T) { assertion := setup(t) artifactType, err := MapArtifactType(&proto.Artifact{ - Type: of(constants.ModelArtifactTypeName), + Type: of(defaults.ModelArtifactTypeName), }) assertion.Nil(err) assertion.Equal("model-artifact", artifactType) artifactType, err = MapArtifactType(&proto.Artifact{ - Type: of(constants.DocArtifactTypeName), + Type: of(defaults.DocArtifactTypeName), }) assertion.Nil(err) assertion.Equal("doc-artifact", artifactType) @@ -659,7 +659,7 @@ func TestMapServingEnvironmentType(t *testing.T) { typeName := MapServingEnvironmentType(&openapi.ServingEnvironment{}) assertion.NotNil(typeName) - assertion.Equal(constants.ServingEnvironmentTypeName, *typeName) + assertion.Equal(defaults.ServingEnvironmentTypeName, *typeName) } func TestMapInferenceServiceType(t *testing.T) { @@ -667,7 +667,7 @@ func TestMapInferenceServiceType(t *testing.T) { typeName := MapInferenceServiceType(&openapi.InferenceService{}) assertion.NotNil(typeName) - assertion.Equal(constants.InferenceServiceTypeName, *typeName) + assertion.Equal(defaults.InferenceServiceTypeName, *typeName) } func TestMapInferenceServiceProperties(t *testing.T) { @@ -710,7 +710,7 @@ func TestMapServeModelType(t *testing.T) { typeName := MapServeModelType(&openapi.ServeModel{}) assertion.NotNil(typeName) - assertion.Equal(constants.ServeModelTypeName, *typeName) + assertion.Equal(defaults.ServeModelTypeName, *typeName) } func TestMapServeModelProperties(t *testing.T) { diff --git a/internal/converter/mlmd_openapi_converter_util.go b/internal/converter/mlmd_openapi_converter_util.go index 4dec4f50..d7b48590 100644 --- a/internal/converter/mlmd_openapi_converter_util.go +++ b/internal/converter/mlmd_openapi_converter_util.go @@ -6,7 +6,7 @@ import ( "fmt" "strings" - "github.com/kubeflow/model-registry/internal/constants" + "github.com/kubeflow/model-registry/internal/defaults" "github.com/kubeflow/model-registry/internal/ml_metadata/proto" "github.com/kubeflow/model-registry/pkg/openapi" ) @@ -87,9 +87,9 @@ func MapArtifactType(source *proto.Artifact) (string, error) { return "", fmt.Errorf("artifact type is nil") } switch *source.Type { - case constants.ModelArtifactTypeName: + case defaults.ModelArtifactTypeName: return "model-artifact", nil - case constants.DocArtifactTypeName: + case defaults.DocArtifactTypeName: return "doc-artifact", nil default: return "", fmt.Errorf("invalid artifact type found: %v", source.Type) diff --git a/internal/converter/openapi_mlmd_converter_util.go b/internal/converter/openapi_mlmd_converter_util.go index 5f730ea9..8bea7d64 100644 --- a/internal/converter/openapi_mlmd_converter_util.go +++ b/internal/converter/openapi_mlmd_converter_util.go @@ -7,7 +7,7 @@ import ( "strconv" "github.com/google/uuid" - "github.com/kubeflow/model-registry/internal/constants" + "github.com/kubeflow/model-registry/internal/defaults" "github.com/kubeflow/model-registry/internal/ml_metadata/proto" "github.com/kubeflow/model-registry/pkg/openapi" "google.golang.org/protobuf/types/known/structpb" @@ -144,7 +144,7 @@ func MapRegisteredModelProperties(source *openapi.RegisteredModel) (map[string]* // MapRegisteredModelType return RegisteredModel corresponding MLMD context type func MapRegisteredModelType(_ *openapi.RegisteredModel) *string { - return of(constants.RegisteredModelTypeName) + return of(defaults.RegisteredModelTypeName) } // MODEL VERSION @@ -194,7 +194,7 @@ func MapModelVersionProperties(source *OpenAPIModelWrapper[openapi.ModelVersion] // MapModelVersionType return ModelVersion corresponding MLMD context type func MapModelVersionType(_ *openapi.ModelVersion) *string { - return of(constants.ModelVersionTypeName) + return of(defaults.ModelVersionTypeName) } // MapModelVersionName maps the user-provided name into MLMD one, i.e., prefixing it with @@ -222,7 +222,7 @@ func MapOpenAPIArtifactState(source *openapi.ArtifactState) (*proto.Artifact_Sta // get DocArtifact MLMD type name func MapDocArtifactType(_ *openapi.DocArtifact) *string { - return of(constants.DocArtifactTypeName) + return of(defaults.DocArtifactTypeName) } func MapDocArtifactProperties(source *openapi.DocArtifact) (map[string]*proto.Value, error) { @@ -307,7 +307,7 @@ func MapModelArtifactProperties(source *openapi.ModelArtifact) (map[string]*prot // MapModelArtifactType return ModelArtifact corresponding MLMD context type func MapModelArtifactType(_ *openapi.ModelArtifact) *string { - return of(constants.ModelArtifactTypeName) + return of(defaults.ModelArtifactTypeName) } // MapModelArtifactName maps the user-provided name into MLMD one, i.e., prefixing it with @@ -328,7 +328,7 @@ func MapModelArtifactName(source *OpenAPIModelWrapper[openapi.ModelArtifact]) *s // MapServingEnvironmentType return ServingEnvironment corresponding MLMD context type func MapServingEnvironmentType(_ *openapi.ServingEnvironment) *string { - return of(constants.ServingEnvironmentTypeName) + return of(defaults.ServingEnvironmentTypeName) } // MapServingEnvironmentProperties maps ServingEnvironment fields to specific MLMD properties @@ -350,7 +350,7 @@ func MapServingEnvironmentProperties(source *openapi.ServingEnvironment) (map[st // MapInferenceServiceType return InferenceService corresponding MLMD context type func MapInferenceServiceType(_ *openapi.InferenceService) *string { - return of(constants.InferenceServiceTypeName) + return of(defaults.InferenceServiceTypeName) } // MapInferenceServiceProperties maps InferenceService fields to specific MLMD properties @@ -436,7 +436,7 @@ func MapInferenceServiceName(source *OpenAPIModelWrapper[openapi.InferenceServic // MapServeModelType return ServeModel corresponding MLMD context type func MapServeModelType(_ *openapi.ServeModel) *string { - return of(constants.ServeModelTypeName) + return of(defaults.ServeModelTypeName) } // MapServeModelProperties maps ServeModel fields to specific MLMD properties diff --git a/internal/defaults/defaults.go b/internal/defaults/defaults.go new file mode 100644 index 00000000..a14cce33 --- /dev/null +++ b/internal/defaults/defaults.go @@ -0,0 +1,12 @@ +package defaults + +// MLMD type names +const ( + RegisteredModelTypeName = "kf.RegisteredModel" + ModelVersionTypeName = "kf.ModelVersion" + ModelArtifactTypeName = "kf.ModelArtifact" + DocArtifactTypeName = "kf.DocArtifact" + ServingEnvironmentTypeName = "kf.ServingEnvironment" + InferenceServiceTypeName = "kf.InferenceService" + ServeModelTypeName = "kf.ServeModel" +) diff --git a/internal/mapper/mapper.go b/internal/mapper/mapper.go index 201f4819..9f28601a 100644 --- a/internal/mapper/mapper.go +++ b/internal/mapper/mapper.go @@ -3,9 +3,9 @@ package mapper import ( "fmt" - "github.com/kubeflow/model-registry/internal/constants" "github.com/kubeflow/model-registry/internal/converter" "github.com/kubeflow/model-registry/internal/converter/generated" + "github.com/kubeflow/model-registry/internal/defaults" "github.com/kubeflow/model-registry/internal/ml_metadata/proto" "github.com/kubeflow/model-registry/pkg/openapi" ) @@ -28,14 +28,14 @@ func NewMapper(mlmdTypes map[string]int64) *Mapper { func (m *Mapper) MapFromRegisteredModel(registeredModel *openapi.RegisteredModel) (*proto.Context, error) { return m.OpenAPIConverter.ConvertRegisteredModel(&converter.OpenAPIModelWrapper[openapi.RegisteredModel]{ - TypeId: m.MLMDTypes[constants.RegisteredModelTypeName], + TypeId: m.MLMDTypes[defaults.RegisteredModelTypeName], Model: registeredModel, }) } func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registeredModelId string, registeredModelName *string) (*proto.Context, error) { return m.OpenAPIConverter.ConvertModelVersion(&converter.OpenAPIModelWrapper[openapi.ModelVersion]{ - TypeId: m.MLMDTypes[constants.ModelVersionTypeName], + TypeId: m.MLMDTypes[defaults.ModelVersionTypeName], Model: modelVersion, ParentResourceId: ®isteredModelId, ModelName: registeredModelName, @@ -44,7 +44,7 @@ func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registe func (m *Mapper) MapFromModelArtifact(modelArtifact *openapi.ModelArtifact, modelVersionId *string) (*proto.Artifact, error) { return m.OpenAPIConverter.ConvertModelArtifact(&converter.OpenAPIModelWrapper[openapi.ModelArtifact]{ - TypeId: m.MLMDTypes[constants.ModelArtifactTypeName], + TypeId: m.MLMDTypes[defaults.ModelArtifactTypeName], Model: modelArtifact, ParentResourceId: modelVersionId, }) @@ -52,7 +52,7 @@ func (m *Mapper) MapFromModelArtifact(modelArtifact *openapi.ModelArtifact, mode func (m *Mapper) MapFromDocArtifact(docArtifact *openapi.DocArtifact, modelVersionId *string) (*proto.Artifact, error) { return m.OpenAPIConverter.ConvertDocArtifact(&converter.OpenAPIModelWrapper[openapi.DocArtifact]{ - TypeId: m.MLMDTypes[constants.DocArtifactTypeName], + TypeId: m.MLMDTypes[defaults.DocArtifactTypeName], Model: docArtifact, ParentResourceId: modelVersionId, }) @@ -89,14 +89,14 @@ func (m *Mapper) MapFromModelArtifacts(modelArtifacts []openapi.ModelArtifact, m func (m *Mapper) MapFromServingEnvironment(servingEnvironment *openapi.ServingEnvironment) (*proto.Context, error) { return m.OpenAPIConverter.ConvertServingEnvironment(&converter.OpenAPIModelWrapper[openapi.ServingEnvironment]{ - TypeId: m.MLMDTypes[constants.ServingEnvironmentTypeName], + TypeId: m.MLMDTypes[defaults.ServingEnvironmentTypeName], Model: servingEnvironment, }) } func (m *Mapper) MapFromInferenceService(inferenceService *openapi.InferenceService, servingEnvironmentId string) (*proto.Context, error) { return m.OpenAPIConverter.ConvertInferenceService(&converter.OpenAPIModelWrapper[openapi.InferenceService]{ - TypeId: m.MLMDTypes[constants.InferenceServiceTypeName], + TypeId: m.MLMDTypes[defaults.InferenceServiceTypeName], Model: inferenceService, ParentResourceId: &servingEnvironmentId, }) @@ -104,7 +104,7 @@ func (m *Mapper) MapFromInferenceService(inferenceService *openapi.InferenceServ func (m *Mapper) MapFromServeModel(serveModel *openapi.ServeModel, inferenceServiceId string) (*proto.Execution, error) { return m.OpenAPIConverter.ConvertServeModel(&converter.OpenAPIModelWrapper[openapi.ServeModel]{ - TypeId: m.MLMDTypes[constants.ServeModelTypeName], + TypeId: m.MLMDTypes[defaults.ServeModelTypeName], Model: serveModel, ParentResourceId: &inferenceServiceId, }) @@ -113,19 +113,19 @@ func (m *Mapper) MapFromServeModel(serveModel *openapi.ServeModel, inferenceServ // Utilities for MLMD --> OpenAPI mapping, make use of generated Converters func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredModel, error) { - return mapTo(ctx, m.MLMDTypes, constants.RegisteredModelTypeName, m.MLMDConverter.ConvertRegisteredModel) + return mapTo(ctx, m.MLMDTypes, defaults.RegisteredModelTypeName, m.MLMDConverter.ConvertRegisteredModel) } func (m *Mapper) MapToModelVersion(ctx *proto.Context) (*openapi.ModelVersion, error) { - return mapTo(ctx, m.MLMDTypes, constants.ModelVersionTypeName, m.MLMDConverter.ConvertModelVersion) + return mapTo(ctx, m.MLMDTypes, defaults.ModelVersionTypeName, m.MLMDConverter.ConvertModelVersion) } func (m *Mapper) MapToModelArtifact(art *proto.Artifact) (*openapi.ModelArtifact, error) { - return mapTo(art, m.MLMDTypes, constants.ModelArtifactTypeName, m.MLMDConverter.ConvertModelArtifact) + return mapTo(art, m.MLMDTypes, defaults.ModelArtifactTypeName, m.MLMDConverter.ConvertModelArtifact) } func (m *Mapper) MapToDocArtifact(art *proto.Artifact) (*openapi.DocArtifact, error) { - return mapTo(art, m.MLMDTypes, constants.DocArtifactTypeName, m.MLMDConverter.ConvertDocArtifact) + return mapTo(art, m.MLMDTypes, defaults.DocArtifactTypeName, m.MLMDConverter.ConvertDocArtifact) } func (m *Mapper) MapToArtifact(art *proto.Artifact) (*openapi.Artifact, error) { @@ -136,12 +136,12 @@ func (m *Mapper) MapToArtifact(art *proto.Artifact) (*openapi.Artifact, error) { return nil, fmt.Errorf("invalid artifact type, can't map from nil") } switch art.GetType() { - case constants.ModelArtifactTypeName: + case defaults.ModelArtifactTypeName: ma, err := m.MapToModelArtifact(art) return &openapi.Artifact{ ModelArtifact: ma, }, err - case constants.DocArtifactTypeName: + case defaults.DocArtifactTypeName: da, err := m.MapToDocArtifact(art) return &openapi.Artifact{ DocArtifact: da, @@ -152,15 +152,15 @@ func (m *Mapper) MapToArtifact(art *proto.Artifact) (*openapi.Artifact, error) { } func (m *Mapper) MapToServingEnvironment(ctx *proto.Context) (*openapi.ServingEnvironment, error) { - return mapTo(ctx, m.MLMDTypes, constants.ServingEnvironmentTypeName, m.MLMDConverter.ConvertServingEnvironment) + return mapTo(ctx, m.MLMDTypes, defaults.ServingEnvironmentTypeName, m.MLMDConverter.ConvertServingEnvironment) } func (m *Mapper) MapToInferenceService(ctx *proto.Context) (*openapi.InferenceService, error) { - return mapTo(ctx, m.MLMDTypes, constants.InferenceServiceTypeName, m.MLMDConverter.ConvertInferenceService) + return mapTo(ctx, m.MLMDTypes, defaults.InferenceServiceTypeName, m.MLMDConverter.ConvertInferenceService) } func (m *Mapper) MapToServeModel(ex *proto.Execution) (*openapi.ServeModel, error) { - return mapTo(ex, m.MLMDTypes, constants.ServeModelTypeName, m.MLMDConverter.ConvertServeModel) + return mapTo(ex, m.MLMDTypes, defaults.ServeModelTypeName, m.MLMDConverter.ConvertServeModel) } type getTypeIder interface { diff --git a/internal/mapper/mapper_test.go b/internal/mapper/mapper_test.go index 82ce173f..04d55bbc 100644 --- a/internal/mapper/mapper_test.go +++ b/internal/mapper/mapper_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/kubeflow/model-registry/internal/constants" + "github.com/kubeflow/model-registry/internal/defaults" "github.com/kubeflow/model-registry/internal/ml_metadata/proto" "github.com/kubeflow/model-registry/pkg/openapi" "github.com/stretchr/testify/assert" @@ -22,13 +22,13 @@ const ( ) var typesMap = map[string]int64{ - constants.RegisteredModelTypeName: registeredModelTypeId, - constants.ModelVersionTypeName: modelVersionTypeId, - constants.DocArtifactTypeName: docArtifactTypeId, - constants.ModelArtifactTypeName: modelArtifactTypeId, - constants.ServingEnvironmentTypeName: servingEnvironmentTypeId, - constants.InferenceServiceTypeName: inferenceServiceTypeId, - constants.ServeModelTypeName: serveModelTypeId, + defaults.RegisteredModelTypeName: registeredModelTypeId, + defaults.ModelVersionTypeName: modelVersionTypeId, + defaults.DocArtifactTypeName: docArtifactTypeId, + defaults.ModelArtifactTypeName: modelArtifactTypeId, + defaults.ServingEnvironmentTypeName: servingEnvironmentTypeId, + defaults.InferenceServiceTypeName: inferenceServiceTypeId, + defaults.ServeModelTypeName: serveModelTypeId, } func setup(t *testing.T) (*assert.Assertions, *Mapper) { @@ -148,7 +148,7 @@ func TestMapToRegisteredModel(t *testing.T) { assertion, m := setup(t) _, err := m.MapToRegisteredModel(&proto.Context{ TypeId: of(registeredModelTypeId), - Type: of(constants.RegisteredModelTypeName), + Type: of(defaults.RegisteredModelTypeName), }) assertion.Nil(err) } @@ -157,17 +157,17 @@ func TestMapToRegisteredModelInvalid(t *testing.T) { assertion, m := setup(t) _, err := m.MapToRegisteredModel(&proto.Context{ TypeId: of(invalidTypeId), - Type: of("kfmr.OtherEntity"), + Type: of("kf.OtherEntity"), }) assertion.NotNil(err) - assertion.Equal(fmt.Sprintf("invalid entity: expected %s but received kfmr.OtherEntity, please check the provided id", constants.RegisteredModelTypeName), err.Error()) + assertion.Equal(fmt.Sprintf("invalid entity: expected %s but received kf.OtherEntity, please check the provided id", defaults.RegisteredModelTypeName), err.Error()) } func TestMapToModelVersion(t *testing.T) { assertion, m := setup(t) _, err := m.MapToModelVersion(&proto.Context{ TypeId: of(modelVersionTypeId), - Type: of(constants.ModelVersionTypeName), + Type: of(defaults.ModelVersionTypeName), }) assertion.Nil(err) } @@ -176,17 +176,17 @@ func TestMapToModelVersionInvalid(t *testing.T) { assertion, m := setup(t) _, err := m.MapToModelVersion(&proto.Context{ TypeId: of(invalidTypeId), - Type: of("kfmr.OtherEntity"), + Type: of("kf.OtherEntity"), }) assertion.NotNil(err) - assertion.Equal(fmt.Sprintf("invalid entity: expected %s but received kfmr.OtherEntity, please check the provided id", constants.ModelVersionTypeName), err.Error()) + assertion.Equal(fmt.Sprintf("invalid entity: expected %s but received kf.OtherEntity, please check the provided id", defaults.ModelVersionTypeName), err.Error()) } func TestMapToDocArtifact(t *testing.T) { assertion, m := setup(t) _, err := m.MapToArtifact(&proto.Artifact{ TypeId: of(docArtifactTypeId), - Type: of(constants.DocArtifactTypeName), + Type: of(defaults.DocArtifactTypeName), }) assertion.Nil(err) } @@ -195,7 +195,7 @@ func TestMapToModelArtifact(t *testing.T) { assertion, m := setup(t) _, err := m.MapToArtifact(&proto.Artifact{ TypeId: of(modelArtifactTypeId), - Type: of(constants.ModelArtifactTypeName), + Type: of(defaults.ModelArtifactTypeName), }) assertion.Nil(err) } @@ -213,17 +213,17 @@ func TestMapToArtifactInvalid(t *testing.T) { assertion, m := setup(t) _, err := m.MapToArtifact(&proto.Artifact{ TypeId: of(invalidTypeId), - Type: of("kfmr.OtherEntity"), + Type: of("kf.OtherEntity"), }) assertion.NotNil(err) - assertion.Equal("unknown artifact type: kfmr.OtherEntity", err.Error()) + assertion.Equal("unknown artifact type: kf.OtherEntity", err.Error()) } func TestMapToServingEnvironment(t *testing.T) { assertion, m := setup(t) _, err := m.MapToServingEnvironment(&proto.Context{ TypeId: of(servingEnvironmentTypeId), - Type: of(constants.ServingEnvironmentTypeName), + Type: of(defaults.ServingEnvironmentTypeName), }) assertion.Nil(err) } @@ -232,17 +232,17 @@ func TestMapToServingEnvironmentInvalid(t *testing.T) { assertion, m := setup(t) _, err := m.MapToServingEnvironment(&proto.Context{ TypeId: of(invalidTypeId), - Type: of("kfmr.OtherEntity"), + Type: of("kf.OtherEntity"), }) assertion.NotNil(err) - assertion.Equal(fmt.Sprintf("invalid entity: expected %s but received kfmr.OtherEntity, please check the provided id", constants.ServingEnvironmentTypeName), err.Error()) + assertion.Equal(fmt.Sprintf("invalid entity: expected %s but received kf.OtherEntity, please check the provided id", defaults.ServingEnvironmentTypeName), err.Error()) } func TestMapToInferenceService(t *testing.T) { assertion, m := setup(t) _, err := m.MapToInferenceService(&proto.Context{ TypeId: of(inferenceServiceTypeId), - Type: of(constants.InferenceServiceTypeName), + Type: of(defaults.InferenceServiceTypeName), }) assertion.Nil(err) } @@ -251,17 +251,17 @@ func TestMapToInferenceServiceInvalid(t *testing.T) { assertion, m := setup(t) _, err := m.MapToInferenceService(&proto.Context{ TypeId: of(invalidTypeId), - Type: of("kfmr.OtherEntity"), + Type: of("kf.OtherEntity"), }) assertion.NotNil(err) - assertion.Equal(fmt.Sprintf("invalid entity: expected %s but received kfmr.OtherEntity, please check the provided id", constants.InferenceServiceTypeName), err.Error()) + assertion.Equal(fmt.Sprintf("invalid entity: expected %s but received kf.OtherEntity, please check the provided id", defaults.InferenceServiceTypeName), err.Error()) } func TestMapToServeModel(t *testing.T) { assertion, m := setup(t) _, err := m.MapToServeModel(&proto.Execution{ TypeId: of(serveModelTypeId), - Type: of(constants.ServeModelTypeName), + Type: of(defaults.ServeModelTypeName), }) assertion.Nil(err) } @@ -270,10 +270,10 @@ func TestMapToServeModelInvalid(t *testing.T) { assertion, m := setup(t) _, err := m.MapToServeModel(&proto.Execution{ TypeId: of(invalidTypeId), - Type: of("kfmr.OtherEntity"), + Type: of("kf.OtherEntity"), }) assertion.NotNil(err) - assertion.Equal(fmt.Sprintf("invalid entity: expected %s but received kfmr.OtherEntity, please check the provided id", constants.ServeModelTypeName), err.Error()) + assertion.Equal(fmt.Sprintf("invalid entity: expected %s but received kf.OtherEntity, please check the provided id", defaults.ServeModelTypeName), err.Error()) } func TestMapTo(t *testing.T) { diff --git a/internal/mlmdtypes/mlmdtypes.go b/internal/mlmdtypes/mlmdtypes.go index 4a3af7e8..73b40399 100644 --- a/internal/mlmdtypes/mlmdtypes.go +++ b/internal/mlmdtypes/mlmdtypes.go @@ -4,32 +4,44 @@ import ( "context" "fmt" - "github.com/kubeflow/model-registry/internal/apiutils" - "github.com/kubeflow/model-registry/internal/constants" + "github.com/kubeflow/model-registry/internal/defaults" "github.com/kubeflow/model-registry/internal/ml_metadata/proto" "google.golang.org/grpc" ) -var ( - registeredModelTypeName = apiutils.Of(constants.RegisteredModelTypeName) - modelVersionTypeName = apiutils.Of(constants.ModelVersionTypeName) - modelArtifactTypeName = apiutils.Of(constants.ModelArtifactTypeName) - docArtifactTypeName = apiutils.Of(constants.DocArtifactTypeName) - servingEnvironmentTypeName = apiutils.Of(constants.ServingEnvironmentTypeName) - inferenceServiceTypeName = apiutils.Of(constants.InferenceServiceTypeName) - serveModelTypeName = apiutils.Of(constants.ServeModelTypeName) - canAddFields = apiutils.Of(true) -) +type MLMDTypeNamesConfig struct { + RegisteredModelTypeName string + ModelVersionTypeName string + ModelArtifactTypeName string + DocArtifactTypeName string + ServingEnvironmentTypeName string + InferenceServiceTypeName string + ServeModelTypeName string + CanAddFields bool +} + +func NewMLMDTypeNamesConfigFromDefaults() MLMDTypeNamesConfig { + return MLMDTypeNamesConfig{ + RegisteredModelTypeName: defaults.RegisteredModelTypeName, + ModelVersionTypeName: defaults.ModelVersionTypeName, + ModelArtifactTypeName: defaults.ModelArtifactTypeName, + DocArtifactTypeName: defaults.DocArtifactTypeName, + ServingEnvironmentTypeName: defaults.ServingEnvironmentTypeName, + InferenceServiceTypeName: defaults.InferenceServiceTypeName, + ServeModelTypeName: defaults.ServeModelTypeName, + CanAddFields: true, + } +} // Utility method that created the necessary Model Registry's logical-model types // as the necessary MLMD's Context, Artifact, Execution types etc. in the underlying MLMD service -func CreateMLMDTypes(cc grpc.ClientConnInterface) (map[string]int64, error) { +func CreateMLMDTypes(cc grpc.ClientConnInterface, nameConfig MLMDTypeNamesConfig) (map[string]int64, error) { client := proto.NewMetadataStoreServiceClient(cc) registeredModelReq := proto.PutContextTypeRequest{ - CanAddFields: canAddFields, + CanAddFields: &nameConfig.CanAddFields, ContextType: &proto.ContextType{ - Name: registeredModelTypeName, + Name: &nameConfig.RegisteredModelTypeName, Properties: map[string]proto.PropertyType{ "description": proto.PropertyType_STRING, "state": proto.PropertyType_STRING, @@ -38,9 +50,9 @@ func CreateMLMDTypes(cc grpc.ClientConnInterface) (map[string]int64, error) { } modelVersionReq := proto.PutContextTypeRequest{ - CanAddFields: canAddFields, + CanAddFields: &nameConfig.CanAddFields, ContextType: &proto.ContextType{ - Name: modelVersionTypeName, + Name: &nameConfig.ModelVersionTypeName, Properties: map[string]proto.PropertyType{ "description": proto.PropertyType_STRING, "model_name": proto.PropertyType_STRING, @@ -52,9 +64,9 @@ func CreateMLMDTypes(cc grpc.ClientConnInterface) (map[string]int64, error) { } docArtifactReq := proto.PutArtifactTypeRequest{ - CanAddFields: canAddFields, + CanAddFields: &nameConfig.CanAddFields, ArtifactType: &proto.ArtifactType{ - Name: docArtifactTypeName, + Name: &nameConfig.DocArtifactTypeName, Properties: map[string]proto.PropertyType{ "description": proto.PropertyType_STRING, }, @@ -62,9 +74,9 @@ func CreateMLMDTypes(cc grpc.ClientConnInterface) (map[string]int64, error) { } modelArtifactReq := proto.PutArtifactTypeRequest{ - CanAddFields: canAddFields, + CanAddFields: &nameConfig.CanAddFields, ArtifactType: &proto.ArtifactType{ - Name: modelArtifactTypeName, + Name: &nameConfig.ModelArtifactTypeName, Properties: map[string]proto.PropertyType{ "description": proto.PropertyType_STRING, "model_format_name": proto.PropertyType_STRING, @@ -77,9 +89,9 @@ func CreateMLMDTypes(cc grpc.ClientConnInterface) (map[string]int64, error) { } servingEnvironmentReq := proto.PutContextTypeRequest{ - CanAddFields: canAddFields, + CanAddFields: &nameConfig.CanAddFields, ContextType: &proto.ContextType{ - Name: servingEnvironmentTypeName, + Name: &nameConfig.ServingEnvironmentTypeName, Properties: map[string]proto.PropertyType{ "description": proto.PropertyType_STRING, }, @@ -87,9 +99,9 @@ func CreateMLMDTypes(cc grpc.ClientConnInterface) (map[string]int64, error) { } inferenceServiceReq := proto.PutContextTypeRequest{ - CanAddFields: canAddFields, + CanAddFields: &nameConfig.CanAddFields, ContextType: &proto.ContextType{ - Name: inferenceServiceTypeName, + Name: &nameConfig.InferenceServiceTypeName, Properties: map[string]proto.PropertyType{ "description": proto.PropertyType_STRING, "model_version_id": proto.PropertyType_INT, @@ -103,9 +115,9 @@ func CreateMLMDTypes(cc grpc.ClientConnInterface) (map[string]int64, error) { } serveModelReq := proto.PutExecutionTypeRequest{ - CanAddFields: canAddFields, + CanAddFields: &nameConfig.CanAddFields, ExecutionType: &proto.ExecutionType{ - Name: serveModelTypeName, + Name: &nameConfig.ServeModelTypeName, Properties: map[string]proto.PropertyType{ "description": proto.PropertyType_STRING, "model_version_id": proto.PropertyType_INT, @@ -115,47 +127,47 @@ func CreateMLMDTypes(cc grpc.ClientConnInterface) (map[string]int64, error) { registeredModelResp, err := client.PutContextType(context.Background(), ®isteredModelReq) if err != nil { - return nil, fmt.Errorf("error setting up context type %s: %v", *registeredModelTypeName, err) + return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.RegisteredModelTypeName, err) } modelVersionResp, err := client.PutContextType(context.Background(), &modelVersionReq) if err != nil { - return nil, fmt.Errorf("error setting up context type %s: %v", *modelVersionTypeName, err) + return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.ModelVersionTypeName, err) } docArtifactResp, err := client.PutArtifactType(context.Background(), &docArtifactReq) if err != nil { - return nil, fmt.Errorf("error setting up artifact type %s: %v", *docArtifactTypeName, err) + return nil, fmt.Errorf("error setting up artifact type %s: %v", nameConfig.DocArtifactTypeName, err) } modelArtifactResp, err := client.PutArtifactType(context.Background(), &modelArtifactReq) if err != nil { - return nil, fmt.Errorf("error setting up artifact type %s: %v", *modelArtifactTypeName, err) + return nil, fmt.Errorf("error setting up artifact type %s: %v", nameConfig.ModelArtifactTypeName, err) } servingEnvironmentResp, err := client.PutContextType(context.Background(), &servingEnvironmentReq) if err != nil { - return nil, fmt.Errorf("error setting up context type %s: %v", *servingEnvironmentTypeName, err) + return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.ServingEnvironmentTypeName, err) } inferenceServiceResp, err := client.PutContextType(context.Background(), &inferenceServiceReq) if err != nil { - return nil, fmt.Errorf("error setting up context type %s: %v", *inferenceServiceTypeName, err) + return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.InferenceServiceTypeName, err) } serveModelResp, err := client.PutExecutionType(context.Background(), &serveModelReq) if err != nil { - return nil, fmt.Errorf("error setting up execution type %s: %v", *serveModelTypeName, err) + return nil, fmt.Errorf("error setting up execution type %s: %v", nameConfig.ServeModelTypeName, err) } typesMap := map[string]int64{ - constants.RegisteredModelTypeName: registeredModelResp.GetTypeId(), - constants.ModelVersionTypeName: modelVersionResp.GetTypeId(), - constants.DocArtifactTypeName: docArtifactResp.GetTypeId(), - constants.ModelArtifactTypeName: modelArtifactResp.GetTypeId(), - constants.ServingEnvironmentTypeName: servingEnvironmentResp.GetTypeId(), - constants.InferenceServiceTypeName: inferenceServiceResp.GetTypeId(), - constants.ServeModelTypeName: serveModelResp.GetTypeId(), + defaults.RegisteredModelTypeName: registeredModelResp.GetTypeId(), + defaults.ModelVersionTypeName: modelVersionResp.GetTypeId(), + defaults.DocArtifactTypeName: docArtifactResp.GetTypeId(), + defaults.ModelArtifactTypeName: modelArtifactResp.GetTypeId(), + defaults.ServingEnvironmentTypeName: servingEnvironmentResp.GetTypeId(), + defaults.InferenceServiceTypeName: inferenceServiceResp.GetTypeId(), + defaults.ServeModelTypeName: serveModelResp.GetTypeId(), } return typesMap, nil } diff --git a/pkg/core/core.go b/pkg/core/core.go index accc2a3d..dc0ec9fe 100644 --- a/pkg/core/core.go +++ b/pkg/core/core.go @@ -7,29 +7,20 @@ import ( "github.com/golang/glog" "github.com/kubeflow/model-registry/internal/apiutils" - "github.com/kubeflow/model-registry/internal/constants" "github.com/kubeflow/model-registry/internal/converter" "github.com/kubeflow/model-registry/internal/converter/generated" "github.com/kubeflow/model-registry/internal/mapper" "github.com/kubeflow/model-registry/internal/ml_metadata/proto" + "github.com/kubeflow/model-registry/internal/mlmdtypes" "github.com/kubeflow/model-registry/pkg/api" "github.com/kubeflow/model-registry/pkg/openapi" "google.golang.org/grpc" ) -var ( - registeredModelTypeName = apiutils.Of(constants.RegisteredModelTypeName) - modelVersionTypeName = apiutils.Of(constants.ModelVersionTypeName) - modelArtifactTypeName = apiutils.Of(constants.ModelArtifactTypeName) - docArtifactTypeName = apiutils.Of(constants.DocArtifactTypeName) - servingEnvironmentTypeName = apiutils.Of(constants.ServingEnvironmentTypeName) - inferenceServiceTypeName = apiutils.Of(constants.InferenceServiceTypeName) - serveModelTypeName = apiutils.Of(constants.ServeModelTypeName) -) - // ModelRegistryService is the core library of the model registry type ModelRegistryService struct { mlmdClient proto.MetadataStoreServiceClient + nameConfig mlmdtypes.MLMDTypeNamesConfig typesMap map[string]int64 mapper *mapper.Mapper openapiConv *generated.OpenAPIConverterImpl @@ -40,8 +31,8 @@ type ModelRegistryService struct { // // Parameters: // - cc: A gRPC client connection to the underlying MLMD service -func NewModelRegistryService(cc grpc.ClientConnInterface) (api.ModelRegistryApi, error) { - typesMap, err := BuildTypesMap(cc) +func NewModelRegistryService(cc grpc.ClientConnInterface, nameConfig mlmdtypes.MLMDTypeNamesConfig) (api.ModelRegistryApi, error) { + typesMap, err := BuildTypesMap(cc, nameConfig) if err != nil { // early return in case type Ids cannot be retrieved return nil, err } @@ -50,72 +41,73 @@ func NewModelRegistryService(cc grpc.ClientConnInterface) (api.ModelRegistryApi, return &ModelRegistryService{ mlmdClient: client, + nameConfig: nameConfig, typesMap: typesMap, openapiConv: &generated.OpenAPIConverterImpl{}, mapper: mapper.NewMapper(typesMap), }, nil } -func BuildTypesMap(cc grpc.ClientConnInterface) (map[string]int64, error) { +func BuildTypesMap(cc grpc.ClientConnInterface, nameConfig mlmdtypes.MLMDTypeNamesConfig) (map[string]int64, error) { client := proto.NewMetadataStoreServiceClient(cc) registeredModelContextTypeReq := proto.GetContextTypeRequest{ - TypeName: registeredModelTypeName, + TypeName: &nameConfig.RegisteredModelTypeName, } registeredModelResp, err := client.GetContextType(context.Background(), ®isteredModelContextTypeReq) if err != nil { - return nil, fmt.Errorf("error getting context type %s: %v", *registeredModelTypeName, err) + return nil, fmt.Errorf("error getting context type %s: %v", nameConfig.RegisteredModelTypeName, err) } modelVersionContextTypeReq := proto.GetContextTypeRequest{ - TypeName: modelVersionTypeName, + TypeName: &nameConfig.ModelVersionTypeName, } modelVersionResp, err := client.GetContextType(context.Background(), &modelVersionContextTypeReq) if err != nil { - return nil, fmt.Errorf("error getting context type %s: %v", *modelVersionTypeName, err) + return nil, fmt.Errorf("error getting context type %s: %v", nameConfig.ModelVersionTypeName, err) } docArtifactResp, err := client.GetArtifactType(context.Background(), &proto.GetArtifactTypeRequest{ - TypeName: docArtifactTypeName, + TypeName: &nameConfig.DocArtifactTypeName, }) if err != nil { - return nil, fmt.Errorf("error getting artifact type %s: %v", *docArtifactTypeName, err) + return nil, fmt.Errorf("error getting artifact type %s: %v", nameConfig.DocArtifactTypeName, err) } modelArtifactArtifactTypeReq := proto.GetArtifactTypeRequest{ - TypeName: modelArtifactTypeName, + TypeName: &nameConfig.ModelArtifactTypeName, } modelArtifactResp, err := client.GetArtifactType(context.Background(), &modelArtifactArtifactTypeReq) if err != nil { - return nil, fmt.Errorf("error getting artifact type %s: %v", *modelArtifactTypeName, err) + return nil, fmt.Errorf("error getting artifact type %s: %v", nameConfig.ModelArtifactTypeName, err) } servingEnvironmentContextTypeReq := proto.GetContextTypeRequest{ - TypeName: servingEnvironmentTypeName, + TypeName: &nameConfig.ServingEnvironmentTypeName, } servingEnvironmentResp, err := client.GetContextType(context.Background(), &servingEnvironmentContextTypeReq) if err != nil { - return nil, fmt.Errorf("error getting context type %s: %v", *servingEnvironmentTypeName, err) + return nil, fmt.Errorf("error getting context type %s: %v", nameConfig.ServingEnvironmentTypeName, err) } inferenceServiceContextTypeReq := proto.GetContextTypeRequest{ - TypeName: inferenceServiceTypeName, + TypeName: &nameConfig.InferenceServiceTypeName, } inferenceServiceResp, err := client.GetContextType(context.Background(), &inferenceServiceContextTypeReq) if err != nil { - return nil, fmt.Errorf("error getting context type %s: %v", *inferenceServiceTypeName, err) + return nil, fmt.Errorf("error getting context type %s: %v", nameConfig.InferenceServiceTypeName, err) } serveModelExecutionReq := proto.GetExecutionTypeRequest{ - TypeName: serveModelTypeName, + TypeName: &nameConfig.ServeModelTypeName, } serveModelResp, err := client.GetExecutionType(context.Background(), &serveModelExecutionReq) if err != nil { - return nil, fmt.Errorf("error getting execution type %s: %v", *serveModelTypeName, err) + return nil, fmt.Errorf("error getting execution type %s: %v", nameConfig.ServeModelTypeName, err) } typesMap := map[string]int64{ - constants.RegisteredModelTypeName: registeredModelResp.ContextType.GetId(), - constants.ModelVersionTypeName: modelVersionResp.ContextType.GetId(), - constants.DocArtifactTypeName: docArtifactResp.ArtifactType.GetId(), - constants.ModelArtifactTypeName: modelArtifactResp.ArtifactType.GetId(), - constants.ServingEnvironmentTypeName: servingEnvironmentResp.ContextType.GetId(), - constants.InferenceServiceTypeName: inferenceServiceResp.ContextType.GetId(), - constants.ServeModelTypeName: serveModelResp.ExecutionType.GetId(), + nameConfig.RegisteredModelTypeName: registeredModelResp.ContextType.GetId(), + nameConfig.ModelVersionTypeName: modelVersionResp.ContextType.GetId(), + nameConfig.DocArtifactTypeName: docArtifactResp.ArtifactType.GetId(), + nameConfig.ModelArtifactTypeName: modelArtifactResp.ArtifactType.GetId(), + nameConfig.ServingEnvironmentTypeName: servingEnvironmentResp.ContextType.GetId(), + nameConfig.InferenceServiceTypeName: inferenceServiceResp.ContextType.GetId(), + nameConfig.ServeModelTypeName: serveModelResp.ExecutionType.GetId(), } return typesMap, nil } @@ -255,7 +247,7 @@ func (serv *ModelRegistryService) GetRegisteredModelByParams(name *string, exter } getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: registeredModelTypeName, + TypeName: &serv.nameConfig.RegisteredModelTypeName, Options: &proto.ListOperationOptions{ FilterQuery: &filterQuery, }, @@ -286,7 +278,7 @@ func (serv *ModelRegistryService) GetRegisteredModels(listOptions api.ListOption return nil, err } contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: registeredModelTypeName, + TypeName: &serv.nameConfig.RegisteredModelTypeName, Options: listOperationOptions, }) if err != nil { @@ -489,7 +481,7 @@ func (serv *ModelRegistryService) GetModelVersionByParams(versionName *string, r } getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: modelVersionTypeName, + TypeName: &serv.nameConfig.ModelVersionTypeName, Options: &proto.ListOperationOptions{ FilterQuery: &filterQuery, }, @@ -526,7 +518,7 @@ func (serv *ModelRegistryService) GetModelVersions(listOptions api.ListOptions, } contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: modelVersionTypeName, + TypeName: &serv.nameConfig.ModelVersionTypeName, Options: listOperationOptions, }) if err != nil { @@ -789,7 +781,7 @@ func (serv *ModelRegistryService) GetModelArtifactByParams(artifactName *string, } artifactsResponse, err := serv.mlmdClient.GetArtifactsByType(context.Background(), &proto.GetArtifactsByTypeRequest{ - TypeName: modelArtifactTypeName, + TypeName: &serv.nameConfig.ModelArtifactTypeName, Options: &proto.ListOperationOptions{ FilterQuery: &filterQuery, }, @@ -841,7 +833,7 @@ func (serv *ModelRegistryService) GetModelArtifacts(listOptions api.ListOptions, nextPageToken = artifactsResp.NextPageToken } else { artifactsResp, err := serv.mlmdClient.GetArtifactsByType(context.Background(), &proto.GetArtifactsByTypeRequest{ - TypeName: modelArtifactTypeName, + TypeName: &serv.nameConfig.ModelArtifactTypeName, Options: listOperationOptions, }) if err != nil { @@ -963,7 +955,7 @@ func (serv *ModelRegistryService) GetServingEnvironmentByParams(name *string, ex } getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: servingEnvironmentTypeName, + TypeName: &serv.nameConfig.ServingEnvironmentTypeName, Options: &proto.ListOperationOptions{ FilterQuery: &filterQuery, }, @@ -994,7 +986,7 @@ func (serv *ModelRegistryService) GetServingEnvironments(listOptions api.ListOpt return nil, err } contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: servingEnvironmentTypeName, + TypeName: &serv.nameConfig.ServingEnvironmentTypeName, Options: listOperationOptions, }) if err != nil { @@ -1187,7 +1179,7 @@ func (serv *ModelRegistryService) GetInferenceServiceByParams(name *string, serv } getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: inferenceServiceTypeName, + TypeName: &serv.nameConfig.InferenceServiceTypeName, Options: &proto.ListOperationOptions{ FilterQuery: &filterQuery, }, @@ -1233,7 +1225,7 @@ func (serv *ModelRegistryService) GetInferenceServices(listOptions api.ListOptio listOperationOptions.FilterQuery = &query contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ - TypeName: inferenceServiceTypeName, + TypeName: &serv.nameConfig.InferenceServiceTypeName, Options: listOperationOptions, }) if err != nil { @@ -1437,7 +1429,7 @@ func (serv *ModelRegistryService) GetServeModels(listOptions api.ListOptions, in nextPageToken = executionsResp.NextPageToken } else { executionsResp, err := serv.mlmdClient.GetExecutionsByType(context.Background(), &proto.GetExecutionsByTypeRequest{ - TypeName: serveModelTypeName, + TypeName: &serv.nameConfig.ServeModelTypeName, Options: listOperationOptions, }) if err != nil { diff --git a/pkg/core/core_test.go b/pkg/core/core_test.go index c364101f..5eba3392 100644 --- a/pkg/core/core_test.go +++ b/pkg/core/core_test.go @@ -7,6 +7,7 @@ import ( "github.com/kubeflow/model-registry/internal/apiutils" "github.com/kubeflow/model-registry/internal/converter" + "github.com/kubeflow/model-registry/internal/defaults" "github.com/kubeflow/model-registry/internal/ml_metadata/proto" "github.com/kubeflow/model-registry/internal/mlmdtypes" "github.com/kubeflow/model-registry/internal/testutils" @@ -53,7 +54,17 @@ type CoreTestSuite struct { mlmdClient proto.MetadataStoreServiceClient } -var canAddFields = apiutils.Of(true) +// test defaults +var ( + registeredModelTypeName = apiutils.Of(defaults.RegisteredModelTypeName) + modelVersionTypeName = apiutils.Of(defaults.ModelVersionTypeName) + modelArtifactTypeName = apiutils.Of(defaults.ModelArtifactTypeName) + docArtifactTypeName = apiutils.Of(defaults.DocArtifactTypeName) + servingEnvironmentTypeName = apiutils.Of(defaults.ServingEnvironmentTypeName) + inferenceServiceTypeName = apiutils.Of(defaults.InferenceServiceTypeName) + serveModelTypeName = apiutils.Of(defaults.ServeModelTypeName) + canAddFields = apiutils.Of(true) +) func TestRunCoreTestSuite(t *testing.T) { // before all @@ -102,10 +113,11 @@ func (suite *CoreTestSuite) AfterTest(suiteName, testName string) { } func (suite *CoreTestSuite) setupModelRegistryService() *ModelRegistryService { - _, err := mlmdtypes.CreateMLMDTypes(suite.grpcConn) + mlmdtypeNames := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults() + _, err := mlmdtypes.CreateMLMDTypes(suite.grpcConn, mlmdtypeNames) suite.Nilf(err, "error creating MLMD types: %v", err) // setup model registry service - service, err := NewModelRegistryService(suite.grpcConn) + service, err := NewModelRegistryService(suite.grpcConn, mlmdtypeNames) suite.Nilf(err, "error creating core service: %v", err) mrService, ok := service.(*ModelRegistryService) suite.True(ok) @@ -451,9 +463,9 @@ func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInRegisteredM suite.Nil(err) // steps to create model registry service - _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn) + _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn, mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()) suite.NotNil(err) - suite.Regexp("error setting up context type kfmr.RegisteredModel: rpc error: code = AlreadyExists.*", err.Error()) + suite.Regexp("error setting up context type "+*registeredModelTypeName+": rpc error: code = AlreadyExists.*", err.Error()) } func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInModelVersion() { @@ -471,9 +483,9 @@ func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInModelVersio suite.Nil(err) // steps to create model registry service - _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn) + _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn, mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()) suite.NotNil(err) - suite.Regexp("error setting up context type kfmr.ModelVersion: rpc error: code = AlreadyExists.*", err.Error()) + suite.Regexp("error setting up context type "+*modelVersionTypeName+": rpc error: code = AlreadyExists.*", err.Error()) } func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInModelArtifact() { @@ -491,9 +503,9 @@ func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInModelArtifa suite.Nil(err) // steps to create model registry service - _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn) + _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn, mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()) suite.NotNil(err) - suite.Regexp("error setting up artifact type kfmr.ModelArtifact: rpc error: code = AlreadyExists.*", err.Error()) + suite.Regexp("error setting up artifact type "+*modelArtifactTypeName+": rpc error: code = AlreadyExists.*", err.Error()) } func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInServingEnvironment() { @@ -510,9 +522,9 @@ func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInServingEnvi suite.Nil(err) // steps to create model registry service - _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn) + _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn, mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()) suite.NotNil(err) - suite.Regexp("error setting up context type kfmr.ServingEnvironment: rpc error: code = AlreadyExists.*", err.Error()) + suite.Regexp("error setting up context type "+*servingEnvironmentTypeName+": rpc error: code = AlreadyExists.*", err.Error()) } func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInInferenceService() { @@ -530,9 +542,9 @@ func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInInferenceSe suite.Nil(err) // steps to create model registry service - _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn) + _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn, mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()) suite.NotNil(err) - suite.Regexp("error setting up context type kfmr.InferenceService: rpc error: code = AlreadyExists.*", err.Error()) + suite.Regexp("error setting up context type "+*inferenceServiceTypeName+": rpc error: code = AlreadyExists.*", err.Error()) } func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInServeModel() { @@ -550,9 +562,9 @@ func (suite *CoreTestSuite) TestModelRegistryFailureForOmittedFieldInServeModel( suite.Nil(err) // steps to create model registry service - _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn) + _, err = mlmdtypes.CreateMLMDTypes(suite.grpcConn, mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()) suite.NotNil(err) - suite.Regexp("error setting up execution type kfmr.ServeModel: rpc error: code = AlreadyExists.*", err.Error()) + suite.Regexp("error setting up execution type "+*serveModelTypeName+": rpc error: code = AlreadyExists.*", err.Error()) } // REGISTERED MODELS diff --git a/test/robot/MRandLogicalModel.robot b/test/robot/MRandLogicalModel.robot index d5040707..fbfd82f3 100644 --- a/test/robot/MRandLogicalModel.robot +++ b/test/robot/MRandLogicalModel.robot @@ -29,13 +29,13 @@ Verify basic logical mapping between MR and MLMD # RegisteredModel shall result in a MLMD Context ${mlmdProto} Get Context By Single Id ${rId} Log To Console ${mlmdProto} - Should be equal ${mlmdProto.type} kfmr.RegisteredModel + Should be equal ${mlmdProto.type} kf.RegisteredModel Should be equal ${mlmdProto.name} ${name} # ModelVersion shall result in a MLMD Context and parent Context(of RegisteredModel) ${mlmdProto} Get Context By Single Id ${vId} Log To Console ${mlmdProto} - Should be equal ${mlmdProto.type} kfmr.ModelVersion + Should be equal ${mlmdProto.type} kf.ModelVersion Should be equal ${mlmdProto.name} ${rId}:v1 ${mlmdProto} Get Parent Contexts By Context ${vId} Should be equal ${mlmdProto[0].id} ${rId} @@ -44,7 +44,7 @@ Verify basic logical mapping between MR and MLMD ${aNamePrefix} Set Variable ${vId}: ${mlmdProto} Get Artifact By Single Id ${aId} Log To Console ${mlmdProto} - Should be equal ${mlmdProto.type} kfmr.ModelArtifact + Should be equal ${mlmdProto.type} kf.ModelArtifact Should Start With ${mlmdProto.name} ${aNamePrefix} Should be equal ${mlmdProto.uri} s3://12345 ${mlmdProto} Get Artifacts By Context ${vId}