Skip to content

Commit

Permalink
py: provide API builders for secure and insecure connections
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed May 3, 2024
1 parent b9c638b commit 4c66394
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 92 deletions.
33 changes: 32 additions & 1 deletion clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import os
from pathlib import Path
from typing import get_args
from warnings import warn

Expand Down Expand Up @@ -36,7 +38,36 @@ def __init__(
"""
# TODO: get args from env
self._author = author
self._api = ModelRegistryAPIClient(server_address, port, user_token, custom_ca)

if not user_token:
# /var/run/secrets/kubernetes.io/serviceaccount/token
sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
if sa_token:
user_token = Path(sa_token).read_bytes()
else:
warn("User access token is missing", stacklevel=2)

root_certs = None
if not custom_ca:
ca_cert = os.environ.get("CERT")
if ca_cert:
root_certs = Path(ca_cert).read_bytes()
elif port == 443:
warn(
"missing CA certificate, which is required for a secure connection",
stacklevel=2,
)
else:
root_certs = custom_ca

if root_certs:
self._api = ModelRegistryAPIClient.secure_connection(
server_address, port, user_token, custom_ca
)
else:
self._api = ModelRegistryAPIClient.insecure_connection(
server_address, port, user_token
)

def _register_model(self, name: str) -> RegisteredModel:
if rm := self._api.get_registered_model_by_params(name):
Expand Down
117 changes: 61 additions & 56 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path
from warnings import warn

Expand All @@ -16,16 +17,20 @@
from .utils import header_adder_interceptor


@dataclass
class ModelRegistryAPIClient:
"""Model registry API."""

def __init__(
self,
store: MLMDStore

@classmethod
def secure_connection(
cls,
server_address: str,
port: int = 443,
user_token: bytes | None = None,
custom_ca: bytes | None = None,
):
) -> ModelRegistryAPIClient:
"""Constructor.
Args:
Expand All @@ -34,37 +39,39 @@ def __init__(
user_token: The PEM-encoded user token as a byte string. Defaults to envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT.
"""
if not user_token:
# /var/run/secrets/kubernetes.io/serviceaccount/token
sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
if sa_token:
user_token = Path(sa_token).read_bytes()
else:
warn("User access token is missing", stacklevel=2)

if port == 443:
if not custom_ca:
ca_cert = os.environ.get("CERT")
if not ca_cert:
msg = "CA certificate must be provided"
raise StoreException(msg)
root_certs = Path(ca_cert).read_bytes()
else:
root_certs = custom_ca
chan_creds = grpc.ssl_channel_credentials(root_certs)

if user_token:
call_creds = grpc.access_token_call_credentials(user_token)
chan_creds = grpc.composite_channel_credentials(
chan_creds,
call_creds,
)

chan = grpc.secure_channel(
f"{server_address}:443",
chan_creds = grpc.ssl_channel_credentials(custom_ca)

if user_token:
chan_creds = grpc.composite_channel_credentials(
chan_creds,
grpc.access_token_call_credentials(user_token),
)
elif user_token:

if port != 443:
warn(f"Using non-standard port for TLS connection {port}", stacklevel=2)

chan = grpc.secure_channel(
f"{server_address}:{port}",
chan_creds,
)

return cls(MLMDStore.from_channel(chan))

@classmethod
def insecure_connection(
cls,
server_address: str,
port: int,
user_token: bytes | None = None,
) -> ModelRegistryAPIClient:
"""Constructor.
Args:
server_address: Server address.
port: Server port.
user_token: The PEM-encoded user token as a byte string. Defaults to envvar KF_PIPELINES_SA_TOKEN_PATH.
"""
if user_token:
chan = grpc.intercept_channel(
grpc.insecure_channel(f"{server_address}:{port}"),
# header key has to be lowercase
Expand All @@ -73,7 +80,7 @@ def __init__(
else:
chan = grpc.insecure_channel(f"{server_address}:{port}")

self._store = MLMDStore.from_channel(chan)
return cls(MLMDStore.from_channel(chan))

def _map(self, py_obj: ProtoBase) -> ProtoType:
"""Map a Python object to a proto object.
Expand All @@ -86,7 +93,7 @@ def _map(self, py_obj: ProtoBase) -> ProtoType:
Returns:
Proto object.
"""
type_id = self._store.get_type_id(
type_id = self.store.get_type_id(
py_obj.get_proto_type(), py_obj.get_proto_type_name()
)
return py_obj.map(type_id)
Expand All @@ -103,9 +110,9 @@ def upsert_registered_model(self, registered_model: RegisteredModel) -> str:
Returns:
ID of the registered model.
"""
id = self._store.put_context(self._map(registered_model))
id = self.store.put_context(self._map(registered_model))
new_py_rm = RegisteredModel.unmap(
self._store.get_context(RegisteredModel.get_proto_type_name(), id)
self.store.get_context(RegisteredModel.get_proto_type_name(), id)
)
id = str(id)
registered_model.id = id
Expand All @@ -124,7 +131,7 @@ def get_registered_model_by_id(self, id: str) -> RegisteredModel | None:
Returns:
Registered model.
"""
proto_rm = self._store.get_context(
proto_rm = self.store.get_context(
RegisteredModel.get_proto_type_name(), id=int(id)
)
if proto_rm is not None:
Expand All @@ -150,7 +157,7 @@ def get_registered_model_by_params(
if name is None and external_id is None:
msg = "Either name or external_id must be provided"
raise StoreException(msg)
proto_rm = self._store.get_context(
proto_rm = self.store.get_context(
RegisteredModel.get_proto_type_name(),
name=name,
external_id=external_id,
Expand All @@ -172,7 +179,7 @@ def get_registered_models(
Registered models.
"""
mlmd_options = options.as_mlmd_list_options() if options else MLMDListOptions()
proto_rms = self._store.get_contexts(
proto_rms = self.store.get_contexts(
RegisteredModel.get_proto_type_name(), mlmd_options
)
return [RegisteredModel.unmap(proto_rm) for proto_rm in proto_rms]
Expand All @@ -194,10 +201,10 @@ def upsert_model_version(
"""
# this is not ideal but we need this info for the prefix
model_version._registered_model_id = registered_model_id
id = self._store.put_context(self._map(model_version))
self._store.put_context_parent(int(registered_model_id), id)
id = self.store.put_context(self._map(model_version))
self.store.put_context_parent(int(registered_model_id), id)
new_py_mv = ModelVersion.unmap(
self._store.get_context(ModelVersion.get_proto_type_name(), id)
self.store.get_context(ModelVersion.get_proto_type_name(), id)
)
id = str(id)
model_version.id = id
Expand All @@ -216,7 +223,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None:
Returns:
Model version.
"""
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(), id=int(model_version_id)
)
if proto_mv is not None:
Expand All @@ -240,7 +247,7 @@ def get_model_versions(
mlmd_options.filter_query = f"parent_contexts_a.id = {registered_model_id}"
return [
ModelVersion.unmap(proto_mv)
for proto_mv in self._store.get_contexts(
for proto_mv in self.store.get_contexts(
ModelVersion.get_proto_type_name(), mlmd_options
)
]
Expand All @@ -267,7 +274,7 @@ def get_model_version_by_params(
StoreException: If neither external ID nor registered model ID and version is provided.
"""
if external_id is not None:
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(), external_id=external_id
)
elif registered_model_id is None or version is None:
Expand All @@ -276,7 +283,7 @@ def get_model_version_by_params(
)
raise StoreException(msg)
else:
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(),
name=f"{registered_model_id}:{version}",
)
Expand Down Expand Up @@ -304,17 +311,17 @@ def upsert_model_artifact(
StoreException: If the model version already has a model artifact.
"""
mv_id = int(model_version_id)
if self._store.get_attributed_artifact(
if self.store.get_attributed_artifact(
ModelArtifact.get_proto_type_name(), mv_id
):
msg = f"Model version with ID {mv_id} already has a model artifact"
raise StoreException(msg)

model_artifact._model_version_id = model_version_id
id = self._store.put_artifact(self._map(model_artifact))
self._store.put_attribution(mv_id, id)
id = self.store.put_artifact(self._map(model_artifact))
self.store.put_attribution(mv_id, id)
new_py_ma = ModelArtifact.unmap(
self._store.get_artifact(ModelArtifact.get_proto_type_name(), id)
self.store.get_artifact(ModelArtifact.get_proto_type_name(), id)
)
id = str(id)
model_artifact.id = id
Expand All @@ -333,9 +340,7 @@ def get_model_artifact_by_id(self, id: str) -> ModelArtifact | None:
Returns:
Model artifact.
"""
proto_ma = self._store.get_artifact(
ModelArtifact.get_proto_type_name(), int(id)
)
proto_ma = self.store.get_artifact(ModelArtifact.get_proto_type_name(), int(id))
if proto_ma is not None:
return ModelArtifact.unmap(proto_ma)

Expand All @@ -357,14 +362,14 @@ def get_model_artifact_by_params(
StoreException: If neither external ID nor model version ID is provided.
"""
if external_id:
proto_ma = self._store.get_artifact(
proto_ma = self.store.get_artifact(
ModelArtifact.get_proto_type_name(), external_id=external_id
)
elif not model_version_id:
msg = "Either model_version_id or external_id must be provided"
raise StoreException(msg)
else:
proto_ma = self._store.get_attributed_artifact(
proto_ma = self.store.get_attributed_artifact(
ModelArtifact.get_proto_type_name(), int(model_version_id)
)
if proto_ma is not None:
Expand All @@ -390,7 +395,7 @@ def get_model_artifacts(
if model_version_id is not None:
mlmd_options.filter_query = f"contexts_a.id = {model_version_id}"

proto_mas = self._store.get_artifacts(
proto_mas = self.store.get_artifacts(
ModelArtifact.get_proto_type_name(), mlmd_options
)
return [ModelArtifact.unmap(proto_ma) for proto_ma in proto_mas]
2 changes: 1 addition & 1 deletion clients/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore:
@pytest.fixture()
def mr_api(store_wrapper: MLMDStore) -> ModelRegistryAPIClient:
mr = object.__new__(ModelRegistryAPIClient)
mr._store = store_wrapper
mr.store = store_wrapper
return mr


Expand Down
Loading

0 comments on commit 4c66394

Please sign in to comment.