Skip to content

Commit

Permalink
feat(csi): support multiple model registries (#508)
Browse files Browse the repository at this point in the history
* feat(csi): support multiple model registries

Signed-off-by: Alessio Pragliola <[email protected]>

* chore(csi): add info about the uri in readme

Signed-off-by: Alessio Pragliola <[email protected]>

* fix(csi): remove unnecessary comment

Signed-off-by: Alessio Pragliola <[email protected]>

* chore(csi): cleanup after each test + improve log message after uri parsing

Signed-off-by: Alessio Pragliola <[email protected]>

* chore(csi): remove unnecessary kind external config in CI

Signed-off-by: Alessio Pragliola <[email protected]>

---------

Signed-off-by: Alessio Pragliola <[email protected]>
  • Loading branch information
Al-Pragliola authored Oct 28, 2024
1 parent 505b00e commit 2422e85
Show file tree
Hide file tree
Showing 8 changed files with 433 additions and 112 deletions.
9 changes: 6 additions & 3 deletions csi/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ sequenceDiagram
U->>+MR: Register ML Model
MR-->>-U: Indexed Model
U->>U: Create InferenceService CR
Note right of U: The InferenceService should<br/>point to the model registry<br/>indexed model, e.g.,:<br/> model-registry://<model>/<version>
Note right of U: The InferenceService should<br/>point to the model registry<br/>indexed model, e.g.,:<br/> model-registry://<model-registry-url>/<model>/<version>
KC->>KC: React to InferenceService creation
KC->>+MD: Create Model Deployment
MD->>+MRSI: Initialization (Download Model)
Expand Down Expand Up @@ -66,14 +66,17 @@ Which wil create the executable under `bin/mr-storage-initializer`.

You can run `main.go` (without building the executable) by running:
```bash
./bin/mr-storage-initializer "model-registry://model/version" "./"
./bin/mr-storage-initializer "model-registry://model-registry-url/model/version" "./"
```

or directly running the `main.go` skipping the previous step:
```bash
make SOURCE_URI=model-registry://model/version DEST_PATH=./ run
make SOURCE_URI=model-registry://model-registry-url/model/version DEST_PATH=./ run
```

> [!NOTE]
> `model-registry-url` is optional, if not provided the value of `MODEL_REGISTRY_BASE_URL` env variable will be used.
> [!NOTE]
> A Model Registry service should be up and running at `localhost:8080`.
Expand Down
6 changes: 5 additions & 1 deletion csi/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"log"
"os"

"github.com/kubeflow/model-registry/csi/pkg/modelregistry"
"github.com/kubeflow/model-registry/csi/pkg/storage"
"github.com/kubeflow/model-registry/pkg/openapi"
)
Expand Down Expand Up @@ -38,7 +39,10 @@ func main() {
cfg := openapi.NewConfiguration()
cfg.Host = baseUrl
cfg.Scheme = scheme
provider, err := storage.NewModelRegistryProvider(cfg)

apiClient := modelregistry.NewAPIClient(cfg, sourceUri)

provider, err := storage.NewModelRegistryProvider(apiClient)
if err != nil {
log.Fatalf("Error initiliazing model registry provider: %v", err)
}
Expand Down
5 changes: 5 additions & 0 deletions csi/pkg/constants/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package constants

import kserve "github.com/kserve/kserve/pkg/agent/storage"

const MR kserve.Protocol = "model-registry://"
41 changes: 41 additions & 0 deletions csi/pkg/modelregistry/api_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package modelregistry

import (
"context"
"log"
"strings"

"github.com/kubeflow/model-registry/csi/pkg/constants"
"github.com/kubeflow/model-registry/pkg/openapi"
)

func NewAPIClient(cfg *openapi.Configuration, storageUri string) *openapi.APIClient {
client := openapi.NewAPIClient(cfg)

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(constants.MR))

tokens := strings.SplitN(mrUri, "/", 3)

if len(tokens) < 2 {
return client
}

newCfg := openapi.NewConfiguration()
newCfg.Host = tokens[0]
newCfg.Scheme = cfg.Scheme

newClient := openapi.NewAPIClient(newCfg)

if len(tokens) == 2 {
// Check if the model registry service is available
_, _, err := newClient.ModelRegistryServiceAPI.GetRegisteredModels(context.Background()).Execute()
if err != nil {
log.Printf("Falling back to base url %s for model registry service", cfg.Host)

return client
}
}

return newClient
}
161 changes: 114 additions & 47 deletions csi/pkg/storage/modelregistry_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,79 +2,81 @@ package storage

import (
"context"
"errors"
"fmt"
"log"
"regexp"
"strings"

kserve "github.com/kserve/kserve/pkg/agent/storage"
"github.com/kubeflow/model-registry/csi/pkg/constants"
"github.com/kubeflow/model-registry/pkg/openapi"
)

const MR kserve.Protocol = "model-registry://"
var (
_ kserve.Provider = (*ModelRegistryProvider)(nil)
ErrInvalidMRURI = errors.New("invalid model registry URI, use like model-registry://{dnsName}/{registeredModelName}/{versionName}")
ErrNoVersionAssociated = errors.New("no versions associated to registered model")
ErrNoArtifactAssociated = errors.New("no artifacts associated to model version")
ErrNoModelArtifact = errors.New("no model artifact found for model version")
ErrModelArtifactEmptyURI = errors.New("model artifact has empty URI")
ErrNoStorageURI = errors.New("there is no storageUri supplied")
ErrNoProtocolInSTorageURI = errors.New("there is no protocol specified for the storageUri")
ErrProtocolNotSupported = errors.New("protocol not supported for storageUri")
ErrFetchingModelVersion = errors.New("error fetching model version")
ErrFetchingModelVersions = errors.New("error fetching model versions")
)

type ModelRegistryProvider struct {
Client *openapi.APIClient
Providers map[kserve.Protocol]kserve.Provider
}

func NewModelRegistryProvider(cfg *openapi.Configuration) (*ModelRegistryProvider, error) {
client := openapi.NewAPIClient(cfg)

func NewModelRegistryProvider(client *openapi.APIClient) (*ModelRegistryProvider, error) {
return &ModelRegistryProvider{
Client: client,
Providers: map[kserve.Protocol]kserve.Provider{},
}, nil
}

var _ kserve.Provider = (*ModelRegistryProvider)(nil)

// storageUri formatted like model-registry://{registeredModelName}/{versionName}
// storageUri formatted like model-registry://{modelRegistryUrl}/{registeredModelName}/{versionName}
func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, storageUri string) error {
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s", modelName, storageUri, modelDir)

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(MR))
tokens := strings.SplitN(mrUri, "/", 2)
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s",
modelName,
storageUri,
modelDir,
)

if len(tokens) == 0 || len(tokens) > 2 {
return fmt.Errorf("invalid model registry URI, use like model-registry://{registeredModelName}/{versionName}")
registeredModelName, versionName, err := p.parseModelVersion(storageUri)
if err != nil {
return err
}

registeredModelName := tokens[0]
var versionName *string
if len(tokens) == 2 {
versionName = &tokens[1]
}
log.Printf("Parsed storageUri=%s as: modelRegistryUrl=%s, registeredModelName=%s, versionName=%v",
storageUri,
p.Client.GetConfig().Host,
registeredModelName,
versionName,
)

log.Printf("Fetching model: registeredModelName=%s, versionName=%v", registeredModelName, versionName)

// Fetch the registered model
model, _, err := p.Client.ModelRegistryServiceAPI.FindRegisteredModel(context.Background()).Name(registeredModelName).Execute()
if err != nil {
return err
}

// Fetch model version by name or latest if not specified
var version *openapi.ModelVersion
if versionName != nil {
version, _, err = p.Client.ModelRegistryServiceAPI.FindModelVersion(context.Background()).Name(*versionName).ParentResourceId(*model.Id).Execute()
if err != nil {
return err
}
} else {
versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
SortOrder(openapi.SORTORDER_DESC).
Execute()
if err != nil {
return err
}
log.Printf("Fetching model version: model=%v", model)

if versions.Size == 0 {
return fmt.Errorf("no versions associated to registered model %s", registeredModelName)
}
version = &versions.Items[0]
// Fetch model version by name or latest if not specified
version, err := p.fetchModelVersion(versionName, registeredModelName, model)
if err != nil {
return err
}

log.Printf("Fetching model artifacts: version=%v", version)

artifacts, _, err := p.Client.ModelRegistryServiceAPI.GetModelVersionArtifacts(context.Background(), *version.Id).
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
SortOrder(openapi.SORTORDER_DESC).
Expand All @@ -84,20 +86,20 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
}

if artifacts.Size == 0 {
return fmt.Errorf("no artifacts associated to model version %s", *version.Id)
return fmt.Errorf("%w %s", ErrNoArtifactAssociated, *version.Id)
}

modelArtifact := artifacts.Items[0].ModelArtifact
if modelArtifact == nil {
return fmt.Errorf("no model artifact found for model version %s", *version.Id)
return fmt.Errorf("%w %s", ErrNoModelArtifact, *version.Id)
}

// Call appropriate kserve provider based on the indexed model artifact URI
if modelArtifact.Uri == nil {
return fmt.Errorf("model artifact %s has empty URI", *modelArtifact.Id)
return fmt.Errorf("%w %s", ErrModelArtifactEmptyURI, *modelArtifact.Id)
}

protocol, err := extractProtocol(*modelArtifact.Uri)
protocol, err := p.extractProtocol(*modelArtifact.Uri)
if err != nil {
return err
}
Expand All @@ -110,19 +112,84 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
return provider.DownloadModel(modelDir, "", *modelArtifact.Uri)
}

func extractProtocol(storageURI string) (kserve.Protocol, error) {
// Possible URIs:
// (1) model-registry://{modelName}
// (2) model-registry://{modelName}/{modelVersion}
// (3) model-registry://{modelRegistryUrl}/{modelName}
// (4) model-registry://{modelRegistryUrl}/{modelName}/{modelVersion}
func (p *ModelRegistryProvider) parseModelVersion(storageUri string) (string, *string, error) {
var versionName *string

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(constants.MR))

tokens := strings.SplitN(mrUri, "/", 3)

if len(tokens) == 0 || len(tokens) > 3 {
return "", nil, ErrInvalidMRURI
}

// Check if the first token is the host and remove it so that we reduce cases (3) and (4) to (1) and (2)
if len(tokens) >= 2 && p.Client.GetConfig().Host == tokens[0] {
tokens = tokens[1:]
}

registeredModelName := tokens[0]

if len(tokens) == 2 {
versionName = &tokens[1]
}

return registeredModelName, versionName, nil
}

func (p *ModelRegistryProvider) fetchModelVersion(
versionName *string,
registeredModelName string,
model *openapi.RegisteredModel,
) (*openapi.ModelVersion, error) {
if versionName != nil {
version, _, err := p.Client.ModelRegistryServiceAPI.
FindModelVersion(context.Background()).
Name(*versionName).
ParentResourceId(*model.Id).
Execute()
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersion, err)
}

return version, nil
}

versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
// OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). not supported
SortOrder(openapi.SORTORDER_DESC).
Execute()
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersions, err)
}

if versions.Size == 0 {
return nil, fmt.Errorf("%w %s", ErrNoVersionAssociated, registeredModelName)
}

return &versions.Items[0], nil
}

func (*ModelRegistryProvider) extractProtocol(storageURI string) (kserve.Protocol, error) {
if storageURI == "" {
return "", fmt.Errorf("there is no storageUri supplied")
return "", ErrNoStorageURI
}

if !regexp.MustCompile("\\w+?://").MatchString(storageURI) {
return "", fmt.Errorf("there is no protocol specified for the storageUri")
if !regexp.MustCompile(`\w+?://`).MatchString(storageURI) {
return "", ErrNoProtocolInSTorageURI
}

for _, prefix := range kserve.SupportedProtocols {
if strings.HasPrefix(storageURI, string(prefix)) {
return prefix, nil
}
}
return "", fmt.Errorf("protocol not supported for storageUri")

return "", ErrProtocolNotSupported
}
8 changes: 5 additions & 3 deletions csi/scripts/install_modelregistry.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ if ! kubectl get namespace "$namespace" &> /dev/null; then
fi
# Apply model-registry kustomize manifests
echo Using model registry image: $image
cd $MR_ROOT/manifests/kustomize/base && kustomize edit set image kubeflow/model-registry:latest=${image} && cd -
cd $MR_ROOT/manifests/kustomize/base && kustomize edit set image kubeflow/model-registry:latest=${image} && \
kustomize edit set namespace $namespace && cd -
cd $MR_ROOT/manifests/kustomize/overlays/db && kustomize edit set namespace $namespace && cd -
kubectl -n $namespace apply -k "$MR_ROOT/manifests/kustomize/overlays/db"

# Wait for model registry deployment
modelregistry=$(kubectl get pod -n kubeflow --selector="component=model-registry-server" --output jsonpath='{.items[0].metadata.name}')
kubectl wait --for=condition=Ready pod/$modelregistry -n $namespace --timeout=6m
modelregistry=$(kubectl get pod -n $namespace --selector="component=model-registry-server" --output jsonpath='{.items[0].metadata.name}')
kubectl wait --for=condition=Ready pod/$modelregistry -n $namespace --timeout=6m
Loading

0 comments on commit 2422e85

Please sign in to comment.