Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kubeflow: make MLMD type names (and prefix) pluggable #19

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clients/python/src/model_registry/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def get_proto_type_name(cls) -> str:
"""Name of the proto type.

Returns:
Name of the class prefixed with `kfmr.`
Name of the class prefixed with `kf.`
"""
return f"kfmr.{cls.__name__}"
return f"kf.{cls.__name__}"

@property
@abstractmethod
Expand Down
5 changes: 3 additions & 2 deletions cmd/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ func runProxyServer(cmd *cobra.Command, args []string) error {
defer conn.Close()
glog.Infof("connected to MLMD server")

_, err = mlmdtypes.CreateMLMDTypes(conn)
mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()
_, err = mlmdtypes.CreateMLMDTypes(conn, mlmdTypeNamesConfig)
if err != nil {
return fmt.Errorf("error creating MLMD types: %v", err)
}
service, err := core.NewModelRegistryService(conn)
service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig)
if err != nil {
return fmt.Errorf("error creating core service: %v", err)
}
Expand Down
12 changes: 0 additions & 12 deletions internal/constants/constants.go

This file was deleted.

20 changes: 10 additions & 10 deletions internal/converter/mlmd_converter_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"strings"
"testing"

"github.com/kubeflow/model-registry/internal/constants"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/internal/ml_metadata/proto"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -208,7 +208,7 @@ func TestMapRegisteredModelType(t *testing.T) {

typeName := MapRegisteredModelType(&openapi.RegisteredModel{})
assertion.NotNil(typeName)
assertion.Equal(constants.RegisteredModelTypeName, *typeName)
assertion.Equal(defaults.RegisteredModelTypeName, *typeName)
}

func TestMapModelVersionProperties(t *testing.T) {
Expand Down Expand Up @@ -236,7 +236,7 @@ func TestMapModelVersionType(t *testing.T) {

typeName := MapModelVersionType(&openapi.ModelVersion{})
assertion.NotNil(typeName)
assertion.Equal(constants.ModelVersionTypeName, *typeName)
assertion.Equal(defaults.ModelVersionTypeName, *typeName)
}

func TestMapModelVersionName(t *testing.T) {
Expand Down Expand Up @@ -287,7 +287,7 @@ func TestMapModelArtifactType(t *testing.T) {

typeName := MapModelArtifactType(&openapi.ModelArtifact{})
assertion.NotNil(typeName)
assertion.Equal(constants.ModelArtifactTypeName, *typeName)
assertion.Equal(defaults.ModelArtifactTypeName, *typeName)
}

func TestMapModelArtifactName(t *testing.T) {
Expand Down Expand Up @@ -346,7 +346,7 @@ func TestMapDocArtifactType(t *testing.T) {

typeName := MapModelArtifactType(&openapi.ModelArtifact{})
assertion.NotNil(typeName)
assertion.Equal(constants.ModelArtifactTypeName, *typeName)
assertion.Equal(defaults.ModelArtifactTypeName, *typeName)
}

func TestMapDocArtifactName(t *testing.T) {
Expand Down Expand Up @@ -577,13 +577,13 @@ func TestMapArtifactType(t *testing.T) {
assertion := setup(t)

artifactType, err := MapArtifactType(&proto.Artifact{
Type: of(constants.ModelArtifactTypeName),
Type: of(defaults.ModelArtifactTypeName),
})
assertion.Nil(err)
assertion.Equal("model-artifact", artifactType)

artifactType, err = MapArtifactType(&proto.Artifact{
Type: of(constants.DocArtifactTypeName),
Type: of(defaults.DocArtifactTypeName),
})
assertion.Nil(err)
assertion.Equal("doc-artifact", artifactType)
Expand Down Expand Up @@ -659,15 +659,15 @@ func TestMapServingEnvironmentType(t *testing.T) {

typeName := MapServingEnvironmentType(&openapi.ServingEnvironment{})
assertion.NotNil(typeName)
assertion.Equal(constants.ServingEnvironmentTypeName, *typeName)
assertion.Equal(defaults.ServingEnvironmentTypeName, *typeName)
}

func TestMapInferenceServiceType(t *testing.T) {
assertion := setup(t)

typeName := MapInferenceServiceType(&openapi.InferenceService{})
assertion.NotNil(typeName)
assertion.Equal(constants.InferenceServiceTypeName, *typeName)
assertion.Equal(defaults.InferenceServiceTypeName, *typeName)
}

func TestMapInferenceServiceProperties(t *testing.T) {
Expand Down Expand Up @@ -710,7 +710,7 @@ func TestMapServeModelType(t *testing.T) {

typeName := MapServeModelType(&openapi.ServeModel{})
assertion.NotNil(typeName)
assertion.Equal(constants.ServeModelTypeName, *typeName)
assertion.Equal(defaults.ServeModelTypeName, *typeName)
}

func TestMapServeModelProperties(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions internal/converter/mlmd_openapi_converter_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"fmt"
"strings"

"github.com/kubeflow/model-registry/internal/constants"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/internal/ml_metadata/proto"
"github.com/kubeflow/model-registry/pkg/openapi"
)
Expand Down Expand Up @@ -87,9 +87,9 @@ func MapArtifactType(source *proto.Artifact) (string, error) {
return "", fmt.Errorf("artifact type is nil")
}
switch *source.Type {
case constants.ModelArtifactTypeName:
case defaults.ModelArtifactTypeName:
return "model-artifact", nil
case constants.DocArtifactTypeName:
case defaults.DocArtifactTypeName:
return "doc-artifact", nil
default:
return "", fmt.Errorf("invalid artifact type found: %v", source.Type)
Expand Down
16 changes: 8 additions & 8 deletions internal/converter/openapi_mlmd_converter_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"strconv"

"github.com/google/uuid"
"github.com/kubeflow/model-registry/internal/constants"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/internal/ml_metadata/proto"
"github.com/kubeflow/model-registry/pkg/openapi"
"google.golang.org/protobuf/types/known/structpb"
Expand Down Expand Up @@ -144,7 +144,7 @@ func MapRegisteredModelProperties(source *openapi.RegisteredModel) (map[string]*

// MapRegisteredModelType return RegisteredModel corresponding MLMD context type
func MapRegisteredModelType(_ *openapi.RegisteredModel) *string {
return of(constants.RegisteredModelTypeName)
return of(defaults.RegisteredModelTypeName)
}

// MODEL VERSION
Expand Down Expand Up @@ -194,7 +194,7 @@ func MapModelVersionProperties(source *OpenAPIModelWrapper[openapi.ModelVersion]

// MapModelVersionType return ModelVersion corresponding MLMD context type
func MapModelVersionType(_ *openapi.ModelVersion) *string {
return of(constants.ModelVersionTypeName)
return of(defaults.ModelVersionTypeName)
}

// MapModelVersionName maps the user-provided name into MLMD one, i.e., prefixing it with
Expand Down Expand Up @@ -222,7 +222,7 @@ func MapOpenAPIArtifactState(source *openapi.ArtifactState) (*proto.Artifact_Sta

// get DocArtifact MLMD type name
func MapDocArtifactType(_ *openapi.DocArtifact) *string {
return of(constants.DocArtifactTypeName)
return of(defaults.DocArtifactTypeName)
}

func MapDocArtifactProperties(source *openapi.DocArtifact) (map[string]*proto.Value, error) {
Expand Down Expand Up @@ -307,7 +307,7 @@ func MapModelArtifactProperties(source *openapi.ModelArtifact) (map[string]*prot

// MapModelArtifactType return ModelArtifact corresponding MLMD context type
func MapModelArtifactType(_ *openapi.ModelArtifact) *string {
return of(constants.ModelArtifactTypeName)
return of(defaults.ModelArtifactTypeName)
}

// MapModelArtifactName maps the user-provided name into MLMD one, i.e., prefixing it with
Expand All @@ -328,7 +328,7 @@ func MapModelArtifactName(source *OpenAPIModelWrapper[openapi.ModelArtifact]) *s

// MapServingEnvironmentType return ServingEnvironment corresponding MLMD context type
func MapServingEnvironmentType(_ *openapi.ServingEnvironment) *string {
return of(constants.ServingEnvironmentTypeName)
return of(defaults.ServingEnvironmentTypeName)
}

// MapServingEnvironmentProperties maps ServingEnvironment fields to specific MLMD properties
Expand All @@ -350,7 +350,7 @@ func MapServingEnvironmentProperties(source *openapi.ServingEnvironment) (map[st

// MapInferenceServiceType return InferenceService corresponding MLMD context type
func MapInferenceServiceType(_ *openapi.InferenceService) *string {
return of(constants.InferenceServiceTypeName)
return of(defaults.InferenceServiceTypeName)
}

// MapInferenceServiceProperties maps InferenceService fields to specific MLMD properties
Expand Down Expand Up @@ -436,7 +436,7 @@ func MapInferenceServiceName(source *OpenAPIModelWrapper[openapi.InferenceServic

// MapServeModelType return ServeModel corresponding MLMD context type
func MapServeModelType(_ *openapi.ServeModel) *string {
return of(constants.ServeModelTypeName)
return of(defaults.ServeModelTypeName)
}

// MapServeModelProperties maps ServeModel fields to specific MLMD properties
Expand Down
12 changes: 12 additions & 0 deletions internal/defaults/defaults.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package defaults

// MLMD type names
const (
RegisteredModelTypeName = "kf.RegisteredModel"
ModelVersionTypeName = "kf.ModelVersion"
ModelArtifactTypeName = "kf.ModelArtifact"
DocArtifactTypeName = "kf.DocArtifact"
ServingEnvironmentTypeName = "kf.ServingEnvironment"
InferenceServiceTypeName = "kf.InferenceService"
ServeModelTypeName = "kf.ServeModel"
)
34 changes: 17 additions & 17 deletions internal/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package mapper
import (
"fmt"

"github.com/kubeflow/model-registry/internal/constants"
"github.com/kubeflow/model-registry/internal/converter"
"github.com/kubeflow/model-registry/internal/converter/generated"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/internal/ml_metadata/proto"
"github.com/kubeflow/model-registry/pkg/openapi"
)
Expand All @@ -28,14 +28,14 @@ func NewMapper(mlmdTypes map[string]int64) *Mapper {

func (m *Mapper) MapFromRegisteredModel(registeredModel *openapi.RegisteredModel) (*proto.Context, error) {
return m.OpenAPIConverter.ConvertRegisteredModel(&converter.OpenAPIModelWrapper[openapi.RegisteredModel]{
TypeId: m.MLMDTypes[constants.RegisteredModelTypeName],
TypeId: m.MLMDTypes[defaults.RegisteredModelTypeName],
Model: registeredModel,
})
}

func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registeredModelId string, registeredModelName *string) (*proto.Context, error) {
return m.OpenAPIConverter.ConvertModelVersion(&converter.OpenAPIModelWrapper[openapi.ModelVersion]{
TypeId: m.MLMDTypes[constants.ModelVersionTypeName],
TypeId: m.MLMDTypes[defaults.ModelVersionTypeName],
Model: modelVersion,
ParentResourceId: &registeredModelId,
ModelName: registeredModelName,
Expand All @@ -44,15 +44,15 @@ func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registe

func (m *Mapper) MapFromModelArtifact(modelArtifact *openapi.ModelArtifact, modelVersionId *string) (*proto.Artifact, error) {
return m.OpenAPIConverter.ConvertModelArtifact(&converter.OpenAPIModelWrapper[openapi.ModelArtifact]{
TypeId: m.MLMDTypes[constants.ModelArtifactTypeName],
TypeId: m.MLMDTypes[defaults.ModelArtifactTypeName],
Model: modelArtifact,
ParentResourceId: modelVersionId,
})
}

func (m *Mapper) MapFromDocArtifact(docArtifact *openapi.DocArtifact, modelVersionId *string) (*proto.Artifact, error) {
return m.OpenAPIConverter.ConvertDocArtifact(&converter.OpenAPIModelWrapper[openapi.DocArtifact]{
TypeId: m.MLMDTypes[constants.DocArtifactTypeName],
TypeId: m.MLMDTypes[defaults.DocArtifactTypeName],
Model: docArtifact,
ParentResourceId: modelVersionId,
})
Expand Down Expand Up @@ -89,22 +89,22 @@ func (m *Mapper) MapFromModelArtifacts(modelArtifacts []openapi.ModelArtifact, m

func (m *Mapper) MapFromServingEnvironment(servingEnvironment *openapi.ServingEnvironment) (*proto.Context, error) {
return m.OpenAPIConverter.ConvertServingEnvironment(&converter.OpenAPIModelWrapper[openapi.ServingEnvironment]{
TypeId: m.MLMDTypes[constants.ServingEnvironmentTypeName],
TypeId: m.MLMDTypes[defaults.ServingEnvironmentTypeName],
Model: servingEnvironment,
})
}

func (m *Mapper) MapFromInferenceService(inferenceService *openapi.InferenceService, servingEnvironmentId string) (*proto.Context, error) {
return m.OpenAPIConverter.ConvertInferenceService(&converter.OpenAPIModelWrapper[openapi.InferenceService]{
TypeId: m.MLMDTypes[constants.InferenceServiceTypeName],
TypeId: m.MLMDTypes[defaults.InferenceServiceTypeName],
Model: inferenceService,
ParentResourceId: &servingEnvironmentId,
})
}

func (m *Mapper) MapFromServeModel(serveModel *openapi.ServeModel, inferenceServiceId string) (*proto.Execution, error) {
return m.OpenAPIConverter.ConvertServeModel(&converter.OpenAPIModelWrapper[openapi.ServeModel]{
TypeId: m.MLMDTypes[constants.ServeModelTypeName],
TypeId: m.MLMDTypes[defaults.ServeModelTypeName],
Model: serveModel,
ParentResourceId: &inferenceServiceId,
})
Expand All @@ -113,19 +113,19 @@ func (m *Mapper) MapFromServeModel(serveModel *openapi.ServeModel, inferenceServ
// Utilities for MLMD --> OpenAPI mapping, make use of generated Converters

func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredModel, error) {
return mapTo(ctx, m.MLMDTypes, constants.RegisteredModelTypeName, m.MLMDConverter.ConvertRegisteredModel)
return mapTo(ctx, m.MLMDTypes, defaults.RegisteredModelTypeName, m.MLMDConverter.ConvertRegisteredModel)
}

func (m *Mapper) MapToModelVersion(ctx *proto.Context) (*openapi.ModelVersion, error) {
return mapTo(ctx, m.MLMDTypes, constants.ModelVersionTypeName, m.MLMDConverter.ConvertModelVersion)
return mapTo(ctx, m.MLMDTypes, defaults.ModelVersionTypeName, m.MLMDConverter.ConvertModelVersion)
}

func (m *Mapper) MapToModelArtifact(art *proto.Artifact) (*openapi.ModelArtifact, error) {
return mapTo(art, m.MLMDTypes, constants.ModelArtifactTypeName, m.MLMDConverter.ConvertModelArtifact)
return mapTo(art, m.MLMDTypes, defaults.ModelArtifactTypeName, m.MLMDConverter.ConvertModelArtifact)
}

func (m *Mapper) MapToDocArtifact(art *proto.Artifact) (*openapi.DocArtifact, error) {
return mapTo(art, m.MLMDTypes, constants.DocArtifactTypeName, m.MLMDConverter.ConvertDocArtifact)
return mapTo(art, m.MLMDTypes, defaults.DocArtifactTypeName, m.MLMDConverter.ConvertDocArtifact)
}

func (m *Mapper) MapToArtifact(art *proto.Artifact) (*openapi.Artifact, error) {
Expand All @@ -136,12 +136,12 @@ func (m *Mapper) MapToArtifact(art *proto.Artifact) (*openapi.Artifact, error) {
return nil, fmt.Errorf("invalid artifact type, can't map from nil")
}
switch art.GetType() {
case constants.ModelArtifactTypeName:
case defaults.ModelArtifactTypeName:
ma, err := m.MapToModelArtifact(art)
return &openapi.Artifact{
ModelArtifact: ma,
}, err
case constants.DocArtifactTypeName:
case defaults.DocArtifactTypeName:
da, err := m.MapToDocArtifact(art)
return &openapi.Artifact{
DocArtifact: da,
Expand All @@ -152,15 +152,15 @@ func (m *Mapper) MapToArtifact(art *proto.Artifact) (*openapi.Artifact, error) {
}

func (m *Mapper) MapToServingEnvironment(ctx *proto.Context) (*openapi.ServingEnvironment, error) {
return mapTo(ctx, m.MLMDTypes, constants.ServingEnvironmentTypeName, m.MLMDConverter.ConvertServingEnvironment)
return mapTo(ctx, m.MLMDTypes, defaults.ServingEnvironmentTypeName, m.MLMDConverter.ConvertServingEnvironment)
}

func (m *Mapper) MapToInferenceService(ctx *proto.Context) (*openapi.InferenceService, error) {
return mapTo(ctx, m.MLMDTypes, constants.InferenceServiceTypeName, m.MLMDConverter.ConvertInferenceService)
return mapTo(ctx, m.MLMDTypes, defaults.InferenceServiceTypeName, m.MLMDConverter.ConvertInferenceService)
}

func (m *Mapper) MapToServeModel(ex *proto.Execution) (*openapi.ServeModel, error) {
return mapTo(ex, m.MLMDTypes, constants.ServeModelTypeName, m.MLMDConverter.ConvertServeModel)
return mapTo(ex, m.MLMDTypes, defaults.ServeModelTypeName, m.MLMDConverter.ConvertServeModel)
}

type getTypeIder interface {
Expand Down
Loading
Loading