Skip to content

Commit

Permalink
py: add user auth using SA
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella Basso do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Apr 24, 2024
1 parent 192e0df commit 6383837
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 82 deletions.
6 changes: 4 additions & 2 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
author: str,
server_address: str,
port: int = 443,
user_token: str | None = None,
custom_ca: str | None = None,
):
"""Constructor.
Expand All @@ -27,11 +28,12 @@ def __init__(
author: Name of the author.
server_address: Server address.
port: Server port. Defaults to 443.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT.
user_token: The user access token, sent to the server as a bearer token. Defaults to the contents of the file at the path in envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to contents of path on envvar CERT.
"""
# TODO: get args from env
self._author = author
self._api = ModelRegistryAPIClient(server_address, port, custom_ca)
self._api = ModelRegistryAPIClient(server_address, port, user_token, custom_ca)

def _register_model(self, name: str) -> RegisteredModel:
if rm := self._api.get_registered_model_by_params(name):
Expand Down
40 changes: 33 additions & 7 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import os
from pathlib import Path

from ml_metadata.proto import MetadataStoreClientConfig
import grpc

from .exceptions import StoreException
from .store import MLMDStore, ProtoType
from .types import ListOptions, ModelArtifact, ModelVersion, RegisteredModel
from .types.base import ProtoBase
from .types.options import MLMDListOptions
from .utils import header_adder_interceptor


class ModelRegistryAPIClient:
Expand All @@ -21,16 +22,27 @@ def __init__(
self,
server_address: str,
port: int = 443,
user_token: str | None = None,
custom_ca: str | None = None,
):
"""Constructor.
Args:
server_address: Server address.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT.
port: Server port. Defaults to 443.
user_token: The user access token, sent to the server as a bearer token. Defaults to the contents of the file at the path in envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT.
"""
config = MetadataStoreClientConfig()
if not user_token:
# /var/run/secrets/kubernetes.io/serviceaccount/token
sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
if not sa_token:
msg = "Access token must be provided"
raise StoreException(msg)
token = Path(sa_token).read_bytes()
else:
token = user_token

if port == 443:
if not custom_ca:
ca_cert = os.environ.get("CERT")
Expand All @@ -40,11 +52,25 @@ def __init__(
root_certs = Path(ca_cert).read_bytes()
else:
root_certs = custom_ca
channel_credentials = grpc.ssl_channel_credentials(root_certs)

config.ssl_config.custom_ca = root_certs
config.host = server_address
config.port = port
self._store = MLMDStore(config)
call_credentials = grpc.access_token_call_credentials(token)
composite_credentials = grpc.composite_channel_credentials(
channel_credentials,
call_credentials,
)
channel = grpc.secure_channel(
f"{server_address}:443",
composite_credentials,
)
self._store = MLMDStore.from_channel(channel)
else:
intercepted = grpc.intercept_channel(
grpc.insecure_channel(f"{server_address}:{port}"),
# header key has to be lowercase
header_adder_interceptor("authorization", f"Bearer {token}"),
)
self._store = MLMDStore.from_channel(intercepted)

def _map(self, py_obj: ProtoBase) -> ProtoType:
"""Map a Python object to a proto object.
Expand Down
65 changes: 46 additions & 19 deletions clients/python/src/model_registry/store/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from typing import ClassVar

from grpc import Channel
from ml_metadata import errors
from ml_metadata.metadata_store import ListOptions, MetadataStore
from ml_metadata.proto import (
Expand All @@ -14,6 +16,7 @@
MetadataStoreClientConfig,
ParentContext,
)
from ml_metadata.proto.metadata_store_service_pb2_grpc import MetadataStoreServiceStub

from model_registry.exceptions import (
DuplicateException,
Expand All @@ -25,19 +28,43 @@
from .base import ProtoType


@dataclass
class MLMDStore:
"""MLMD storage backend."""

store: MetadataStore
# cache for MLMD type IDs
_type_ids: ClassVar[dict[str, int]] = {}

def __init__(self, config: MetadataStoreClientConfig):
@classmethod
def from_config(cls, host: str, port: int):
"""Constructor.
Args:
config: MLMD config.
host: MLMD store server host.
port: MLMD store server port.
"""
self._mlmd_store = MetadataStore(config)
return cls(
MetadataStore(
MetadataStoreClientConfig(
host=host,
port=port,
)
)
)

@classmethod
def from_channel(cls, chan: Channel):
    """Build a store whose MLMD stub is bound to an existing gRPC channel.

    Args:
        chan: gRPC channel to the MLMD store.

    Returns:
        A new instance wrapping the configured ``MetadataStore``.
    """
    # MetadataStore is first built from a fixed localhost:8080 client config;
    # its service stub is then replaced by one created on `chan`.
    # NOTE(review): this overwrites a private attribute of ml_metadata's
    # MetadataStore and presumably leaves the placeholder connection unused --
    # confirm against the installed ml_metadata version.
    placeholder_config = MetadataStoreClientConfig(host="localhost", port=8080)
    mlmd = MetadataStore(placeholder_config)
    mlmd._metadata_store_stub = MetadataStoreServiceStub(chan)
    return cls(mlmd)

def get_type_id(self, pt: type[ProtoType], type_name: str) -> int:
"""Get backend ID for a type.
Expand All @@ -59,7 +86,7 @@ def get_type_id(self, pt: type[ProtoType], type_name: str) -> int:
pt_name = pt.__name__.lower()

try:
_type = getattr(self._mlmd_store, f"get_{pt_name}_type")(type_name)
_type = getattr(self.store, f"get_{pt_name}_type")(type_name)
except errors.NotFoundError as e:
msg = f"{pt_name} type {type_name} does not exist"
raise TypeNotFoundException(msg) from e
Expand All @@ -85,7 +112,7 @@ def put_artifact(self, artifact: Artifact) -> int:
StoreException: If the artifact isn't properly formed.
"""
try:
return self._mlmd_store.put_artifacts([artifact])[0]
return self.store.put_artifacts([artifact])[0]
except errors.AlreadyExistsError as e:
msg = f"Artifact {artifact.name} already exists"
raise DuplicateException(msg) from e
Expand All @@ -111,7 +138,7 @@ def put_context(self, context: Context) -> int:
StoreException: If the context isn't properly formed.
"""
try:
return self._mlmd_store.put_contexts([context])[0]
return self.store.put_contexts([context])[0]
except errors.AlreadyExistsError as e:
msg = f"Context {context.name} already exists"
raise DuplicateException(msg) from e
Expand Down Expand Up @@ -152,12 +179,12 @@ def get_context(
StoreException: Invalid arguments.
"""
if name is not None:
return self._mlmd_store.get_context_by_type_and_name(ctx_type_name, name)
return self.store.get_context_by_type_and_name(ctx_type_name, name)

if id is not None:
contexts = self._mlmd_store.get_contexts_by_id([id])
contexts = self.store.get_contexts_by_id([id])
elif external_id is not None:
contexts = self._mlmd_store.get_contexts_by_external_ids([external_id])
contexts = self.store.get_contexts_by_external_ids([external_id])
else:
msg = "Either id, name or external_id must be provided"
raise StoreException(msg)
Expand Down Expand Up @@ -186,7 +213,7 @@ def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context
# TODO: should we make options optional?
# if options is not None:
try:
contexts = self._mlmd_store.get_contexts(options)
contexts = self.store.get_contexts(options)
except errors.InvalidArgumentError as e:
msg = f"Invalid arguments for get_contexts: {e}"
raise StoreException(msg) from e
Expand All @@ -199,7 +226,7 @@ def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context
# contexts = self._mlmd_store.get_contexts_by_type(ctx_type_name)

if not contexts and ctx_type_name not in [
t.name for t in self._mlmd_store.get_context_types()
t.name for t in self.store.get_context_types()
]:
msg = f"Context type {ctx_type_name} does not exist"
raise TypeNotFoundException(msg)
Expand All @@ -218,7 +245,7 @@ def put_context_parent(self, parent_id: int, child_id: int):
ServerException: If there was an error putting the parent context.
"""
try:
self._mlmd_store.put_parent_contexts(
self.store.put_parent_contexts(
[ParentContext(parent_id=parent_id, child_id=child_id)]
)
except errors.AlreadyExistsError as e:
Expand All @@ -240,7 +267,7 @@ def put_attribution(self, context_id: int, artifact_id: int):
"""
attribution = Attribution(context_id=context_id, artifact_id=artifact_id)
try:
self._mlmd_store.put_attributions_and_associations([attribution], [])
self.store.put_attributions_and_associations([attribution], [])
except errors.InvalidArgumentError as e:
if "artifact" in str(e).lower():
msg = f"Artifact with ID {artifact_id} does not exist"
Expand Down Expand Up @@ -277,12 +304,12 @@ def get_artifact(
StoreException: Invalid arguments.
"""
if name is not None:
return self._mlmd_store.get_artifact_by_type_and_name(art_type_name, name)
return self.store.get_artifact_by_type_and_name(art_type_name, name)

if id is not None:
artifacts = self._mlmd_store.get_artifacts_by_id([id])
artifacts = self.store.get_artifacts_by_id([id])
elif external_id is not None:
artifacts = self._mlmd_store.get_artifacts_by_external_ids([external_id])
artifacts = self.store.get_artifacts_by_external_ids([external_id])
else:
msg = "Either id, name or external_id must be provided"
raise StoreException(msg)
Expand All @@ -304,7 +331,7 @@ def get_attributed_artifact(self, art_type_name: str, ctx_id: int) -> Artifact:
Artifact.
"""
try:
artifacts = self._mlmd_store.get_artifacts_by_context(ctx_id)
artifacts = self.store.get_artifacts_by_context(ctx_id)
except errors.InternalError as e:
msg = f"Couldn't get artifacts by context {ctx_id}"
raise ServerException(msg) from e
Expand All @@ -330,7 +357,7 @@ def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifa
StoreException: Invalid arguments.
"""
try:
artifacts = self._mlmd_store.get_artifacts(options)
artifacts = self.store.get_artifacts(options)
except errors.InvalidArgumentError as e:
msg = f"Invalid arguments for get_artifacts: {e}"
raise StoreException(msg) from e
Expand All @@ -340,7 +367,7 @@ def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifa

artifacts = self._filter_type(art_type_name, artifacts)
if not artifacts and art_type_name not in [
t.name for t in self._mlmd_store.get_artifact_types()
t.name for t in self.store.get_artifact_types()
]:
msg = f"Artifact type {art_type_name} does not exist"
raise TypeNotFoundException(msg)
Expand Down
84 changes: 84 additions & 0 deletions clients/python/src/model_registry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from __future__ import annotations

import os
from collections import namedtuple
from typing import Callable

import grpc
from attr import dataclass
from typing_extensions import overload

from ._utils import required_args
Expand Down Expand Up @@ -90,3 +94,83 @@ def s3_uri_from(
# https://alexwlchan.net/2020/s3-keys-are-not-file-paths/ nor do they resolve to valid URls
# FIXME: is this safe?
return f"s3://{bucket}/{path}?endpoint={endpoint}&defaultRegion={region}"


# https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/generic_client_interceptor.py
@dataclass
class GenericClientInterceptor(
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
):
    """Client interceptor that funnels all four RPC arities through one callable.

    ``fn`` is invoked as ``fn(client_call_details, request_iterator,
    request_streaming, response_streaming)`` and must return a 3-tuple
    ``(call_details, request_iterator, postprocess)``. When ``postprocess`` is
    falsy the continuation's result is returned unchanged; otherwise the result
    of ``postprocess(result)`` is returned.
    """

    fn: Callable

    def _dispatch(
        self,
        continuation,
        client_call_details,
        request_or_iterator,
        request_streaming,
        response_streaming,
    ):
        """Run ``fn`` and then ``continuation`` for any request/response arity."""
        if request_streaming:
            requests = request_or_iterator
        else:
            # A unary request is presented to ``fn`` as a one-item iterator.
            requests = iter((request_or_iterator,))

        details, requests, postprocess = self.fn(
            client_call_details, requests, request_streaming, response_streaming
        )

        # Unwrap the single request again for unary-request calls.
        outgoing = requests if request_streaming else next(requests)
        result = continuation(details, outgoing)
        if postprocess:
            return postprocess(result)
        return result

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return self._dispatch(continuation, client_call_details, request, False, False)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        return self._dispatch(continuation, client_call_details, request, False, True)

    def intercept_stream_unary(
        self, continuation, client_call_details, request_iterator
    ):
        return self._dispatch(
            continuation, client_call_details, request_iterator, True, False
        )

    def intercept_stream_stream(
        self, continuation, client_call_details, request_iterator
    ):
        return self._dispatch(
            continuation, client_call_details, request_iterator, True, True
        )


# https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/header_manipulator_client_interceptor.py
# we need to subclass ClientCallDetails to add a constructor (it's ABC)
class ClientCallDetails(
    namedtuple("ClientCallDetails", ["method", "timeout", "metadata", "credentials"]),
    grpc.ClientCallDetails,
):
    """Constructible call-details record for use inside client interceptors.

    ``grpc.ClientCallDetails`` is an ABC, so this subclass mixes in a
    namedtuple to obtain a constructor taking ``method``, ``timeout``,
    ``metadata`` and ``credentials`` (in that order).
    """


def header_adder_interceptor(header, value):
    """Create a client interceptor that appends one metadata entry to each call.

    Args:
        header: Metadata key to add.
        value: Metadata value to add.

    Returns:
        A ``GenericClientInterceptor`` that copies the call's existing metadata,
        appends ``(header, value)``, and passes the request iterator through
        unchanged with no response post-processing.
    """
    extra_entry = (header, value)

    def _append_header(
        client_call_details,
        request_iterator,
        request_streaming,
        response_streaming,
    ):
        existing = client_call_details.metadata
        # Metadata may be None or empty; in that case start from a fresh list.
        augmented = [*existing, extra_entry] if existing else [extra_entry]
        updated_details = ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            augmented,
            client_call_details.credentials,
        )
        return updated_details, request_iterator, None

    return GenericClientInterceptor(_append_header)
Loading

0 comments on commit 6383837

Please sign in to comment.