Skip to content

Commit

Permalink
Enable TLS auth on py client (#64)
Browse files Browse the repository at this point in the history
* py: enable TLS auth by default

Signed-off-by: Isabella Basso do Amaral <[email protected]>

* py: add user auth using SA

Signed-off-by: Isabella Basso do Amaral <[email protected]>

* py: provide API builders for secure and insecure connections

Signed-off-by: Isabella do Amaral <[email protected]>

---------

Signed-off-by: Isabella Basso do Amaral <[email protected]>
Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa authored May 9, 2024
1 parent e37b07f commit c4402c9
Show file tree
Hide file tree
Showing 10 changed files with 320 additions and 133 deletions.
4 changes: 3 additions & 1 deletion clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ This library provides a high level interface for interacting with a model regist
```py
from model_registry import ModelRegistry

registry = ModelRegistry(server_address="server-address", port=9090, author="author")
registry = ModelRegistry("server-address", author="Ada Lovelace") # Defaults to a secure connection via port 443

# registry = ModelRegistry("server-address", 1234, author="Ada Lovelace", is_secure=False) # To use MR without TLS

model = registry.register_model(
"my-model", # model name
Expand Down
54 changes: 42 additions & 12 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import os
from pathlib import Path
from typing import get_args
from warnings import warn

Expand All @@ -17,27 +19,55 @@ class ModelRegistry:
def __init__(
    self,
    server_address: str,
    port: int = 443,
    *,
    author: str,
    is_secure: bool = True,
    user_token: bytes | None = None,
    custom_ca: bytes | None = None,
):
    """Constructor.

    Args:
        server_address: Server address.
        port: Server port. Defaults to 443.

    Keyword Args:
        author: Name of the author.
        is_secure: Whether to use a secure connection. Defaults to True.
        user_token: The user access token as a byte string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
        custom_ca: The PEM-encoded root certificates as a byte string. Defaults to contents of path on envvar CERT.

    Raises:
        StoreException: If a custom CA is provided without a secure connection.
    """
    # TODO: get remaining args from env
    self._author = author

    if not user_token:
        # /var/run/secrets/kubernetes.io/serviceaccount/token
        sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
        if sa_token:
            # Token files may end with a newline; surrounding whitespace is
            # never part of the token and is not a legal header value.
            user_token = Path(sa_token).read_bytes().strip()
        if not user_token:
            warn("User access token is missing", stacklevel=2)

    if is_secure:
        root_ca = None
        if not custom_ca:
            if ca_path := os.getenv("CERT"):
                root_ca = Path(ca_path).read_bytes()
            # otherwise the client might have a default CA setup
        else:
            root_ca = custom_ca

        self._api = ModelRegistryAPIClient.secure_connection(
            server_address, port, user_token, root_ca
        )
    elif custom_ca:
        msg = "Custom CA provided without secure connection"
        raise StoreException(msg)
    else:
        self._api = ModelRegistryAPIClient.insecure_connection(
            server_address, port, user_token
        )

def _register_model(self, name: str) -> RegisteredModel:
if rm := self._api.get_registered_model_by_params(name):
Expand Down
118 changes: 75 additions & 43 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,79 @@

from __future__ import annotations

from ml_metadata.proto import MetadataStoreClientConfig
from dataclasses import dataclass

import grpc

from .exceptions import StoreException
from .store import MLMDStore, ProtoType
from .types import ListOptions, ModelArtifact, ModelVersion, RegisteredModel
from .types.base import ProtoBase
from .types.options import MLMDListOptions
from .utils import header_adder_interceptor


@dataclass
class ModelRegistryAPIClient:
"""Model registry API."""

def __init__(
self,
store: MLMDStore

@classmethod
def secure_connection(
    cls,
    server_address: str,
    port: int = 443,
    user_token: bytes | str | None = None,
    custom_ca: bytes | None = None,
) -> ModelRegistryAPIClient:
    """Build a client that talks to the server over a TLS channel.

    Args:
        server_address: Server address.
        port: Server port. Defaults to 443.
        user_token: The user access token, as a byte string or text.
        custom_ca: The PEM-encoded root certificates as a byte string. Defaults to GRPC_DEFAULT_SSL_ROOTS_FILE_PATH, then system default.

    Returns:
        A client backed by a store created from the secure channel.

    Raises:
        StoreException: If no user token is provided.
    """
    if not user_token:
        msg = "user token must be provided for secure connection"
        raise StoreException(msg)

    # gRPC formats the access token into the "authorization" header as text;
    # passing bytes through would render it as "Bearer b'...'".
    access_token = (
        user_token.decode("utf-8") if isinstance(user_token, bytes) else user_token
    )

    chan = grpc.secure_channel(
        f"{server_address}:{port}",
        grpc.composite_channel_credentials(
            # custom_ca = None will get the default root certificates
            grpc.ssl_channel_credentials(custom_ca),
            grpc.access_token_call_credentials(access_token),
        ),
    )

    return cls(MLMDStore.from_channel(chan))

@classmethod
def insecure_connection(
    cls,
    server_address: str,
    port: int,
    user_token: bytes | str | None = None,
) -> ModelRegistryAPIClient:
    """Build a client that talks to the server over a plaintext channel.

    Args:
        server_address: Server address.
        port: Server port.
        user_token: The user access token, as a byte string or text. When
            provided, it is sent as a bearer token on every call.

    Returns:
        A client backed by a store created from the insecure channel.
    """
    chan = grpc.insecure_channel(f"{server_address}:{port}")
    if user_token:
        # Interpolating bytes into the f-string would yield "Bearer b'...'",
        # so decode to text before building the header value.
        bearer = (
            user_token.decode("utf-8")
            if isinstance(user_token, bytes)
            else user_token
        )
        chan = grpc.intercept_channel(
            chan,
            # header key has to be lowercase
            header_adder_interceptor("authorization", f"Bearer {bearer}"),
        )

    return cls(MLMDStore.from_channel(chan))

def _map(self, py_obj: ProtoBase) -> ProtoType:
"""Map a Python object to a proto object.
Expand All @@ -53,7 +87,7 @@ def _map(self, py_obj: ProtoBase) -> ProtoType:
Returns:
Proto object.
"""
type_id = self._store.get_type_id(
type_id = self.store.get_type_id(
py_obj.get_proto_type(), py_obj.get_proto_type_name()
)
return py_obj.map(type_id)
Expand All @@ -70,9 +104,9 @@ def upsert_registered_model(self, registered_model: RegisteredModel) -> str:
Returns:
ID of the registered model.
"""
id = self._store.put_context(self._map(registered_model))
id = self.store.put_context(self._map(registered_model))
new_py_rm = RegisteredModel.unmap(
self._store.get_context(RegisteredModel.get_proto_type_name(), id)
self.store.get_context(RegisteredModel.get_proto_type_name(), id)
)
id = str(id)
registered_model.id = id
Expand All @@ -91,7 +125,7 @@ def get_registered_model_by_id(self, id: str) -> RegisteredModel | None:
Returns:
Registered model.
"""
proto_rm = self._store.get_context(
proto_rm = self.store.get_context(
RegisteredModel.get_proto_type_name(), id=int(id)
)
if proto_rm is not None:
Expand All @@ -117,7 +151,7 @@ def get_registered_model_by_params(
if name is None and external_id is None:
msg = "Either name or external_id must be provided"
raise StoreException(msg)
proto_rm = self._store.get_context(
proto_rm = self.store.get_context(
RegisteredModel.get_proto_type_name(),
name=name,
external_id=external_id,
Expand All @@ -139,7 +173,7 @@ def get_registered_models(
Registered models.
"""
mlmd_options = options.as_mlmd_list_options() if options else MLMDListOptions()
proto_rms = self._store.get_contexts(
proto_rms = self.store.get_contexts(
RegisteredModel.get_proto_type_name(), mlmd_options
)
return [RegisteredModel.unmap(proto_rm) for proto_rm in proto_rms]
Expand All @@ -161,10 +195,10 @@ def upsert_model_version(
"""
# this is not ideal but we need this info for the prefix
model_version._registered_model_id = registered_model_id
id = self._store.put_context(self._map(model_version))
self._store.put_context_parent(int(registered_model_id), id)
id = self.store.put_context(self._map(model_version))
self.store.put_context_parent(int(registered_model_id), id)
new_py_mv = ModelVersion.unmap(
self._store.get_context(ModelVersion.get_proto_type_name(), id)
self.store.get_context(ModelVersion.get_proto_type_name(), id)
)
id = str(id)
model_version.id = id
Expand All @@ -183,7 +217,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None:
Returns:
Model version.
"""
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(), id=int(model_version_id)
)
if proto_mv is not None:
Expand All @@ -207,7 +241,7 @@ def get_model_versions(
mlmd_options.filter_query = f"parent_contexts_a.id = {registered_model_id}"
return [
ModelVersion.unmap(proto_mv)
for proto_mv in self._store.get_contexts(
for proto_mv in self.store.get_contexts(
ModelVersion.get_proto_type_name(), mlmd_options
)
]
Expand All @@ -234,7 +268,7 @@ def get_model_version_by_params(
StoreException: If neither external ID nor registered model ID and version is provided.
"""
if external_id is not None:
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(), external_id=external_id
)
elif registered_model_id is None or version is None:
Expand All @@ -243,7 +277,7 @@ def get_model_version_by_params(
)
raise StoreException(msg)
else:
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(),
name=f"{registered_model_id}:{version}",
)
Expand Down Expand Up @@ -271,17 +305,17 @@ def upsert_model_artifact(
StoreException: If the model version already has a model artifact.
"""
mv_id = int(model_version_id)
if self._store.get_attributed_artifact(
if self.store.get_attributed_artifact(
ModelArtifact.get_proto_type_name(), mv_id
):
msg = f"Model version with ID {mv_id} already has a model artifact"
raise StoreException(msg)

model_artifact._model_version_id = model_version_id
id = self._store.put_artifact(self._map(model_artifact))
self._store.put_attribution(mv_id, id)
id = self.store.put_artifact(self._map(model_artifact))
self.store.put_attribution(mv_id, id)
new_py_ma = ModelArtifact.unmap(
self._store.get_artifact(ModelArtifact.get_proto_type_name(), id)
self.store.get_artifact(ModelArtifact.get_proto_type_name(), id)
)
id = str(id)
model_artifact.id = id
Expand All @@ -300,9 +334,7 @@ def get_model_artifact_by_id(self, id: str) -> ModelArtifact | None:
Returns:
Model artifact.
"""
proto_ma = self._store.get_artifact(
ModelArtifact.get_proto_type_name(), int(id)
)
proto_ma = self.store.get_artifact(ModelArtifact.get_proto_type_name(), int(id))
if proto_ma is not None:
return ModelArtifact.unmap(proto_ma)

Expand All @@ -324,14 +356,14 @@ def get_model_artifact_by_params(
StoreException: If neither external ID nor model version ID is provided.
"""
if external_id:
proto_ma = self._store.get_artifact(
proto_ma = self.store.get_artifact(
ModelArtifact.get_proto_type_name(), external_id=external_id
)
elif not model_version_id:
msg = "Either model_version_id or external_id must be provided"
raise StoreException(msg)
else:
proto_ma = self._store.get_attributed_artifact(
proto_ma = self.store.get_attributed_artifact(
ModelArtifact.get_proto_type_name(), int(model_version_id)
)
if proto_ma is not None:
Expand All @@ -357,7 +389,7 @@ def get_model_artifacts(
if model_version_id is not None:
mlmd_options.filter_query = f"contexts_a.id = {model_version_id}"

proto_mas = self._store.get_artifacts(
proto_mas = self.store.get_artifacts(
ModelArtifact.get_proto_type_name(), mlmd_options
)
return [ModelArtifact.unmap(proto_ma) for proto_ma in proto_mas]
Loading

0 comments on commit c4402c9

Please sign in to comment.