Skip to content

Commit

Permalink
Merge sync remote-tracking branch 'upstream/main' into tarilabs-20241…
Browse files Browse the repository at this point in the history
…028-sync
  • Loading branch information
tarilabs committed Oct 28, 2024
2 parents 279811d + 2422e85 commit b862a97
Show file tree
Hide file tree
Showing 16 changed files with 595 additions and 242 deletions.
244 changes: 125 additions & 119 deletions clients/python/poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ nest-asyncio = "^1.6.0"
# necessary for modern type annotations using pydantic on 3.9
eval-type-backport = "^0.2.0"

huggingface-hub = { version = ">=0.20.1,<0.26.0", optional = true }
huggingface-hub = { version = ">=0.20.1,<0.27.0", optional = true }

[tool.poetry.extras]
hf = ["huggingface-hub"]
Expand All @@ -45,7 +45,7 @@ sphinx-autobuild = ">=2021.3.14,<2025.0.0"
pytest = ">=7.4.2,<9.0.0"
coverage = { extras = ["toml"], version = "^7.3.2" }
pytest-cov = ">=4.1,<6.0"
ruff = ">=0.5.2,<0.7.0"
ruff = ">=0.5.2,<0.8.0"
mypy = "^1.7.0"
pytest-asyncio = ">=0.23.7,<0.25.0"
requests = "^2.32.2"
Expand Down
4 changes: 2 additions & 2 deletions csi/GET_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ We assume all [prerequisites](#prerequisites) are satisfied at this point.

3. Setup local deployment of *Kserve* using the provided *Kserve quick installation* script
```bash
curl -s "https://raw.githubusercontent.com/kserve/kserve/release-0.12/hack/quick_install.sh" | bash
curl -s "https://raw.githubusercontent.com/kserve/kserve/release-0.14/hack/quick_install.sh" | bash
```

4. Install *model registry* in the local cluster
Expand Down Expand Up @@ -257,5 +257,5 @@ EOF
If you do not have DNS, you can still curl with the ingress gateway external IP using the HOST Header.
```bash
SERVICE_HOSTNAME=$(kubectl get inferenceservice iris-model -n kserve-test -o jsonpath='{.status.url}' | cut -d "/" -f 3)
curl -v -H "Host: ${SERVICE_HOSTNAME}" -H "Content-Type: application/json" "http://${INGRESS_HOST}:${INGRESS_PORT}/v1/models/iris-v1:predict" -d @/tmp/iris-input.json
curl -v -H "Host: ${SERVICE_HOSTNAME}" -H "Content-Type: application/json" "http://${INGRESS_HOST}:${INGRESS_PORT}/v1/models/iris-model:predict" -d @/tmp/iris-input.json
```
9 changes: 6 additions & 3 deletions csi/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ sequenceDiagram
U->>+MR: Register ML Model
MR-->>-U: Indexed Model
U->>U: Create InferenceService CR
Note right of U: The InferenceService should<br/>point to the model registry<br/>indexed model, e.g.,:<br/> model-registry://<model>/<version>
Note right of U: The InferenceService should<br/>point to the model registry<br/>indexed model, e.g.,:<br/> model-registry://<model-registry-url>/<model>/<version>
KC->>KC: React to InferenceService creation
KC->>+MD: Create Model Deployment
MD->>+MRSI: Initialization (Download Model)
Expand Down Expand Up @@ -66,14 +66,17 @@ Which wil create the executable under `bin/mr-storage-initializer`.

You can run `main.go` (without building the executable) by running:
```bash
./bin/mr-storage-initializer "model-registry://model/version" "./"
./bin/mr-storage-initializer "model-registry://model-registry-url/model/version" "./"
```

or directly running the `main.go` skipping the previous step:
```bash
make SOURCE_URI=model-registry://model/version DEST_PATH=./ run
make SOURCE_URI=model-registry://model-registry-url/model/version DEST_PATH=./ run
```

> [!NOTE]
> `model-registry-url` is optional, if not provided the value of `MODEL_REGISTRY_BASE_URL` env variable will be used.
> [!NOTE]
> A Model Registry service should be up and running at `localhost:8080`.
Expand Down
2 changes: 1 addition & 1 deletion csi/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.21

require (
github.com/kserve/kserve v0.13.1
github.com/kubeflow/model-registry v0.2.8-alpha
github.com/kubeflow/model-registry v0.2.9
)

require (
Expand Down
4 changes: 2 additions & 2 deletions csi/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kserve/kserve v0.13.1 h1:MRszrN5pf1nNzBBoyTeBsoIYcbWvuve5G1pBwdKj9dI=
github.com/kserve/kserve v0.13.1/go.mod h1:l6fHejIVM3RYO9cD9Q0gQ4eriCz3lQaFIdcT05rMUbs=
github.com/kubeflow/model-registry v0.2.8-alpha h1:M1rdHHTlQ/QKZv3Wi9EZU+2heAe3YbaOLjb8uq9FggI=
github.com/kubeflow/model-registry v0.2.8-alpha/go.mod h1:luHw6hNjX7oim7PFiYxOEkSf3EGqZ/qZ5cUji5fKlA4=
github.com/kubeflow/model-registry v0.2.9 h1:+uOUeJwo9yzUObfqdgmQmGhrWWVi0mKGK8hK5UFBnFM=
github.com/kubeflow/model-registry v0.2.9/go.mod h1:QXBKUJhldJj6mB81+OBUIUYtajUzQw72zIoPbCSDriM=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
Expand Down
6 changes: 5 additions & 1 deletion csi/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"log"
"os"

"github.com/kubeflow/model-registry/csi/pkg/modelregistry"
"github.com/kubeflow/model-registry/csi/pkg/storage"
"github.com/kubeflow/model-registry/pkg/openapi"
)
Expand Down Expand Up @@ -38,7 +39,10 @@ func main() {
cfg := openapi.NewConfiguration()
cfg.Host = baseUrl
cfg.Scheme = scheme
provider, err := storage.NewModelRegistryProvider(cfg)

apiClient := modelregistry.NewAPIClient(cfg, sourceUri)

provider, err := storage.NewModelRegistryProvider(apiClient)
if err != nil {
log.Fatalf("Error initiliazing model registry provider: %v", err)
}
Expand Down
5 changes: 5 additions & 0 deletions csi/pkg/constants/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package constants

import kserve "github.com/kserve/kserve/pkg/agent/storage"

const MR kserve.Protocol = "model-registry://"
41 changes: 41 additions & 0 deletions csi/pkg/modelregistry/api_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package modelregistry

import (
"context"
"log"
"strings"

"github.com/kubeflow/model-registry/csi/pkg/constants"
"github.com/kubeflow/model-registry/pkg/openapi"
)

func NewAPIClient(cfg *openapi.Configuration, storageUri string) *openapi.APIClient {
client := openapi.NewAPIClient(cfg)

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(constants.MR))

tokens := strings.SplitN(mrUri, "/", 3)

if len(tokens) < 2 {
return client
}

newCfg := openapi.NewConfiguration()
newCfg.Host = tokens[0]
newCfg.Scheme = cfg.Scheme

newClient := openapi.NewAPIClient(newCfg)

if len(tokens) == 2 {
// Check if the model registry service is available
_, _, err := newClient.ModelRegistryServiceAPI.GetRegisteredModels(context.Background()).Execute()
if err != nil {
log.Printf("Falling back to base url %s for model registry service", cfg.Host)

return client
}
}

return newClient
}
161 changes: 114 additions & 47 deletions csi/pkg/storage/modelregistry_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,79 +2,81 @@ package storage

import (
"context"
"errors"
"fmt"
"log"
"regexp"
"strings"

kserve "github.com/kserve/kserve/pkg/agent/storage"
"github.com/kubeflow/model-registry/csi/pkg/constants"
"github.com/kubeflow/model-registry/pkg/openapi"
)

const MR kserve.Protocol = "model-registry://"
var (
_ kserve.Provider = (*ModelRegistryProvider)(nil)
ErrInvalidMRURI = errors.New("invalid model registry URI, use like model-registry://{dnsName}/{registeredModelName}/{versionName}")
ErrNoVersionAssociated = errors.New("no versions associated to registered model")
ErrNoArtifactAssociated = errors.New("no artifacts associated to model version")
ErrNoModelArtifact = errors.New("no model artifact found for model version")
ErrModelArtifactEmptyURI = errors.New("model artifact has empty URI")
ErrNoStorageURI = errors.New("there is no storageUri supplied")
ErrNoProtocolInSTorageURI = errors.New("there is no protocol specified for the storageUri")
ErrProtocolNotSupported = errors.New("protocol not supported for storageUri")
ErrFetchingModelVersion = errors.New("error fetching model version")
ErrFetchingModelVersions = errors.New("error fetching model versions")
)

type ModelRegistryProvider struct {
Client *openapi.APIClient
Providers map[kserve.Protocol]kserve.Provider
}

func NewModelRegistryProvider(cfg *openapi.Configuration) (*ModelRegistryProvider, error) {
client := openapi.NewAPIClient(cfg)

func NewModelRegistryProvider(client *openapi.APIClient) (*ModelRegistryProvider, error) {
return &ModelRegistryProvider{
Client: client,
Providers: map[kserve.Protocol]kserve.Provider{},
}, nil
}

var _ kserve.Provider = (*ModelRegistryProvider)(nil)

// storageUri formatted like model-registry://{registeredModelName}/{versionName}
// storageUri formatted like model-registry://{modelRegistryUrl}/{registeredModelName}/{versionName}
func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, storageUri string) error {
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s", modelName, storageUri, modelDir)

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(MR))
tokens := strings.SplitN(mrUri, "/", 2)
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s",
modelName,
storageUri,
modelDir,
)

if len(tokens) == 0 || len(tokens) > 2 {
return fmt.Errorf("invalid model registry URI, use like model-registry://{registeredModelName}/{versionName}")
registeredModelName, versionName, err := p.parseModelVersion(storageUri)
if err != nil {
return err
}

registeredModelName := tokens[0]
var versionName *string
if len(tokens) == 2 {
versionName = &tokens[1]
}
log.Printf("Parsed storageUri=%s as: modelRegistryUrl=%s, registeredModelName=%s, versionName=%v",
storageUri,
p.Client.GetConfig().Host,
registeredModelName,
versionName,
)

log.Printf("Fetching model: registeredModelName=%s, versionName=%v", registeredModelName, versionName)

// Fetch the registered model
model, _, err := p.Client.ModelRegistryServiceAPI.FindRegisteredModel(context.Background()).Name(registeredModelName).Execute()
if err != nil {
return err
}

// Fetch model version by name or latest if not specified
var version *openapi.ModelVersion
if versionName != nil {
version, _, err = p.Client.ModelRegistryServiceAPI.FindModelVersion(context.Background()).Name(*versionName).ParentResourceId(*model.Id).Execute()
if err != nil {
return err
}
} else {
versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
SortOrder(openapi.SORTORDER_DESC).
Execute()
if err != nil {
return err
}
log.Printf("Fetching model version: model=%v", model)

if versions.Size == 0 {
return fmt.Errorf("no versions associated to registered model %s", registeredModelName)
}
version = &versions.Items[0]
// Fetch model version by name or latest if not specified
version, err := p.fetchModelVersion(versionName, registeredModelName, model)
if err != nil {
return err
}

log.Printf("Fetching model artifacts: version=%v", version)

artifacts, _, err := p.Client.ModelRegistryServiceAPI.GetModelVersionArtifacts(context.Background(), *version.Id).
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
SortOrder(openapi.SORTORDER_DESC).
Expand All @@ -84,20 +86,20 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
}

if artifacts.Size == 0 {
return fmt.Errorf("no artifacts associated to model version %s", *version.Id)
return fmt.Errorf("%w %s", ErrNoArtifactAssociated, *version.Id)
}

modelArtifact := artifacts.Items[0].ModelArtifact
if modelArtifact == nil {
return fmt.Errorf("no model artifact found for model version %s", *version.Id)
return fmt.Errorf("%w %s", ErrNoModelArtifact, *version.Id)
}

// Call appropriate kserve provider based on the indexed model artifact URI
if modelArtifact.Uri == nil {
return fmt.Errorf("model artifact %s has empty URI", *modelArtifact.Id)
return fmt.Errorf("%w %s", ErrModelArtifactEmptyURI, *modelArtifact.Id)
}

protocol, err := extractProtocol(*modelArtifact.Uri)
protocol, err := p.extractProtocol(*modelArtifact.Uri)
if err != nil {
return err
}
Expand All @@ -110,19 +112,84 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
return provider.DownloadModel(modelDir, "", *modelArtifact.Uri)
}

func extractProtocol(storageURI string) (kserve.Protocol, error) {
// Possible URIs:
// (1) model-registry://{modelName}
// (2) model-registry://{modelName}/{modelVersion}
// (3) model-registry://{modelRegistryUrl}/{modelName}
// (4) model-registry://{modelRegistryUrl}/{modelName}/{modelVersion}
func (p *ModelRegistryProvider) parseModelVersion(storageUri string) (string, *string, error) {
var versionName *string

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(constants.MR))

tokens := strings.SplitN(mrUri, "/", 3)

if len(tokens) == 0 || len(tokens) > 3 {
return "", nil, ErrInvalidMRURI
}

// Check if the first token is the host and remove it so that we reduce cases (3) and (4) to (1) and (2)
if len(tokens) >= 2 && p.Client.GetConfig().Host == tokens[0] {
tokens = tokens[1:]
}

registeredModelName := tokens[0]

if len(tokens) == 2 {
versionName = &tokens[1]
}

return registeredModelName, versionName, nil
}

func (p *ModelRegistryProvider) fetchModelVersion(
versionName *string,
registeredModelName string,
model *openapi.RegisteredModel,
) (*openapi.ModelVersion, error) {
if versionName != nil {
version, _, err := p.Client.ModelRegistryServiceAPI.
FindModelVersion(context.Background()).
Name(*versionName).
ParentResourceId(*model.Id).
Execute()
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersion, err)
}

return version, nil
}

versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
// OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). not supported
SortOrder(openapi.SORTORDER_DESC).
Execute()
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersions, err)
}

if versions.Size == 0 {
return nil, fmt.Errorf("%w %s", ErrNoVersionAssociated, registeredModelName)
}

return &versions.Items[0], nil
}

func (*ModelRegistryProvider) extractProtocol(storageURI string) (kserve.Protocol, error) {
if storageURI == "" {
return "", fmt.Errorf("there is no storageUri supplied")
return "", ErrNoStorageURI
}

if !regexp.MustCompile("\\w+?://").MatchString(storageURI) {
return "", fmt.Errorf("there is no protocol specified for the storageUri")
if !regexp.MustCompile(`\w+?://`).MatchString(storageURI) {
return "", ErrNoProtocolInSTorageURI
}

for _, prefix := range kserve.SupportedProtocols {
if strings.HasPrefix(storageURI, string(prefix)) {
return prefix, nil
}
}
return "", fmt.Errorf("protocol not supported for storageUri")

return "", ErrProtocolNotSupported
}
Loading

0 comments on commit b862a97

Please sign in to comment.