diff --git a/tfx/dsl/input_resolution/ops/latest_policy_model_op.py b/tfx/dsl/input_resolution/ops/latest_policy_model_op.py index 70e7dfcb9c7..ac061466fb6 100644 --- a/tfx/dsl/input_resolution/ops/latest_policy_model_op.py +++ b/tfx/dsl/input_resolution/ops/latest_policy_model_op.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for LatestPolicyModel operator.""" + import collections import enum -from typing import Dict, List +from typing import Dict, List, Optional, Tuple from tfx import types from tfx.dsl.input_resolution import resolver_op @@ -24,6 +25,7 @@ from tfx.orchestration.portable.mlmd import event_lib from tfx.orchestration.portable.mlmd import filter_query_builder as q from tfx.types import artifact_utils +from tfx.types import external_artifact_utils from tfx.utils import typing_utils from ml_metadata.proto import metadata_store_pb2 @@ -204,6 +206,33 @@ def _build_result_dictionary( return result +def _dedpupe_model_artifacts( + models: Optional[List[artifact_utils.Artifact]], +) -> Tuple[List[artifact_utils.Artifact], List[int]]: + """Dedupes a list of Model artifacts.""" + if not models: + return [], [] + + model_by_external_id = {} + model_by_id = {} + + for m in models: + if m.external_id: + model_by_external_id[m.external_id] = m + else: + model_by_id[m.id] = m + + deduped_models = list(model_by_external_id.values()) + list( + model_by_id.values() + ) + model_artifact_ids = [ + external_artifact_utils.get_id_from_external_id(i) + for i in model_by_external_id.keys() + ] + list(model_by_id.keys()) + + return (deduped_models, model_artifact_ids) + + class LatestPolicyModel( resolver_op.ResolverOp, canonical_name='tfx.LatestPolicyModel', @@ -325,6 +354,25 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap): if self.policy == Policy.LATEST_EXPORTED: return {ops_utils.MODEL_KEY: [models[0]]} + are_models_external = [m.is_external for m in models] + if any(are_models_external) and not all(are_models_external): + raise exceptions.InvalidArgument( + 'Inputs to the LastestPolicyModel are from both current pipeline and' + ' external pipeline. LastestPolicyModel does not support such usage.' + ) + if all(are_models_external): + pipeline_assets = set([ + external_artifact_utils.get_pipeline_asset_from_external_id( + m.mlmd_artifact.external_id + ) + for m in models + ]) + if len(pipeline_assets) != 1: + raise exceptions.InvalidArgument( + 'Input models to the LastestPolicyModel are from multiple' + ' pipelines. LastestPolicyModel does not support such usage.' + ) + # If ModelBlessing and/or ModelInfraBlessing artifacts were included in # input_dict, then we will only consider those child artifacts. specifies_child_artifacts = ( @@ -334,7 +382,17 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap): input_child_artifacts = input_dict.get( ops_utils.MODEL_BLESSSING_KEY, [] ) + input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, []) - input_child_artifact_ids = set([a.id for a in input_child_artifacts]) + + input_child_artifact_ids = set() + for a in input_child_artifacts: + if a.is_external: + input_child_artifact_ids.add( + external_artifact_utils.get_id_from_external_id( + a.mlmd_artifact.external_id + ) + ) + else: + input_child_artifact_ids.add(a.id) # If the ModelBlessing and ModelInfraBlessing lists are empty, then no # child artifacts can be considered and we raise a SkipSignal. This can @@ -362,8 +420,8 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap): # There could be multiple events with the same execution ID but different # artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we - # keep the values of model_artifact_ids_by_execution_id as sets. - model_artifact_ids = sorted(set(m.id for m in models)) + # need to deduplicate the Model artifacts. + deduped_models, model_artifact_ids = _dedpupe_model_artifacts(models) downstream_artifact_type_names_filter_query = q.to_sql_string([ ops_utils.MODEL_BLESSING_TYPE_NAME, @@ -407,10 +465,13 @@ def event_filter(event): else: return event_lib.is_valid_output_event(event) - mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store) + mlmd_resolver = metadata_resolver.MetadataResolver( + self.context.store, + mlmd_connection_manager=self.context.mlmd_connection_manager, + ) # Populate the ModelRelations associated with each Model artifact and its # children. - model_relations_by_model_artifact_id = collections.defaultdict( + model_relations_by_model_identifier = collections.defaultdict( ModelRelations ) artifact_type_by_name: Dict[str, metadata_store_pb2.ArtifactType] = {} @@ -419,34 +480,35 @@ def event_filter(event): # fetching downstream artifacts, because # `get_downstream_artifacts_by_artifact_ids()` supports at most 100 ids # as starting artifact ids. - for id_index in range(0, len(model_artifact_ids), ops_utils.BATCH_SIZE): - batch_model_artifact_ids = model_artifact_ids[ + for id_index in range(0, len(deduped_models), ops_utils.BATCH_SIZE): + batch_model_artifacts = deduped_models[ id_index : id_index + ops_utils.BATCH_SIZE ] # Set `max_num_hops` to 50, which should be enough for this use case. - batch_downstream_artifacts_and_types_by_model_ids = ( - mlmd_resolver.get_downstream_artifacts_by_artifact_ids( - batch_model_artifact_ids, + batch_downstream_artifacts_and_types_by_model_identifier = ( + mlmd_resolver.get_downstream_artifacts_by_artifacts( + batch_model_artifacts, max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS, filter_query=filter_query, event_filter=event_filter, ) ) + for ( - model_artifact_id, + model_identifier, artifacts_and_types, - ) in batch_downstream_artifacts_and_types_by_model_ids.items(): + ) in batch_downstream_artifacts_and_types_by_model_identifier.items(): for downstream_artifact, artifact_type in artifacts_and_types: artifact_type_by_name[artifact_type.name] = artifact_type - model_relations = model_relations_by_model_artifact_id[ - model_artifact_id - ] - model_relations.add_downstream_artifact(downstream_artifact) + model_relations_by_model_identifier[ + model_identifier + ].add_downstream_artifact(downstream_artifact) # Find the latest model and ModelRelations that meets the Policy. result = {} for model in models: - model_relations = model_relations_by_model_artifact_id[model.id] + identifier = external_artifact_utils.identifier(model) + model_relations = model_relations_by_model_identifier[identifier] if model_relations.meets_policy(self.policy): result[ops_utils.MODEL_KEY] = [model] break diff --git a/tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py b/tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py index 20083c3a624..45cc8d37b5a 100644 --- a/tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py +++ b/tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for tfx.dsl.input_resolution.ops.latest_policy_model_op.""" +import os from typing import Dict, List, Optional +from unittest import mock from absl.testing import parameterized import tensorflow as tf @@ -22,6 +24,7 @@ from tfx.dsl.input_resolution.ops import ops from tfx.dsl.input_resolution.ops import ops_utils from tfx.dsl.input_resolution.ops import test_utils +from tfx.orchestration import metadata from tfx.orchestration.portable.input_resolution import exceptions from ml_metadata.proto import metadata_store_pb2 @@ -146,6 +149,7 @@ def _run_latest_policy_model(self, *args, **kwargs): args=args, kwargs=kwargs, store=self.store, + mlmd_handle_like=self.mlmd_cm, ) def setUp(self): @@ -158,6 +162,7 @@ def setUp(self): self.artifacts = [self.model_1, self.model_2, self.model_3] + def assertDictKeysEmpty( self, output_dict: Dict[str, List[types.Artifact]], diff --git a/tfx/dsl/input_resolution/ops/test_utils.py b/tfx/dsl/input_resolution/ops/test_utils.py index 55d5811b931..1ab3ce09088 100644 --- a/tfx/dsl/input_resolution/ops/test_utils.py +++ b/tfx/dsl/input_resolution/ops/test_utils.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Testing utility for builtin resolver ops.""" -from typing import Type, Any, Dict, List, Optional, Sequence, Tuple, Union, Mapping +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union from unittest import mock from absl.testing import parameterized - from tfx import types from tfx.dsl.compiler import compiler_context from tfx.dsl.compiler import node_inputs_compiler @@ -27,6 +26,7 @@ from tfx.dsl.input_resolution import resolver_op from tfx.dsl.input_resolution.ops import ops_utils from tfx.orchestration import pipeline +from tfx.orchestration import mlmd_connection_manager as mlmd_cm from tfx.proto.orchestration import pipeline_pb2 from tfx.types import artifact as tfx_artifact from tfx.types import artifact_utils @@ -201,6 +201,7 @@ def prepare_tfx_artifact( properties: Optional[Dict[str, Union[int, str]]] = None, custom_properties: Optional[Dict[str, Union[int, str]]] = None, state: metadata_store_pb2.Artifact.State = metadata_store_pb2.Artifact.State.LIVE, + connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None, ) -> types.Artifact: """Adds a single artifact to MLMD and returns the TFleX Artifact object.""" mlmd_artifact = self.put_artifact( @@ -208,8 +209,11 @@ def prepare_tfx_artifact( properties=properties, custom_properties=custom_properties, state=state, + connection_config=connection_config, ) - artifact_type = self.store.get_artifact_type(artifact.TYPE_NAME) + + store = self.get_store(connection_config) + artifact_type = store.get_artifact_type(artifact.TYPE_NAME) return artifact_utils.deserialize_artifact(artifact_type, mlmd_artifact) def unwrap_tfx_artifacts( @@ -222,10 +226,13 @@ def build_node_context( self, pipeline_name: str, node_id: str, + connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None, ): """Returns a "node" Context with name "pipeline_name.node_id.""" context = self.put_context( - context_type='node', context_name=f'{pipeline_name}.{node_id}' + context_type='node', + context_name=f'{pipeline_name}.{node_id}', + connection_config=connection_config, ) return context @@ -233,20 +240,24 @@ def create_examples( self, spans_and_versions: Sequence[Tuple[int, int]], contexts: Sequence[metadata_store_pb2.Context] = (), + connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None, ) -> List[types.Artifact]: """Build Examples artifacts and add an ExampleGen execution to MLMD.""" examples = [] for span, version in spans_and_versions: examples.append( self.prepare_tfx_artifact( - Examples, properties={'span': span, 'version': version} - ) + Examples, + properties={'span': span, 'version': version}, + connection_config=connection_config, + ), ) self.put_execution( 'ExampleGen', inputs={}, outputs={'examples': self.unwrap_tfx_artifacts(examples)}, contexts=contexts, + connection_config=connection_config, ) return examples @@ -254,9 +265,12 @@ def transform_examples( self, examples: List[types.Artifact], contexts: Sequence[metadata_store_pb2.Context] = (), + connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None, ) -> types.Artifact: inputs = {'examples': self.unwrap_tfx_artifacts(examples)} - transform_graph = self.prepare_tfx_artifact(TransformGraph) + transform_graph = self.prepare_tfx_artifact( + TransformGraph, connection_config=connection_config + ) self.put_execution( 'Transform', inputs=inputs, @@ -264,6 +278,7 @@ def transform_examples( 'transform_graph': self.unwrap_tfx_artifacts([transform_graph]) }, contexts=contexts, + connection_config=connection_config, ) return transform_graph @@ -273,6 +288,7 @@ def train_on_examples( examples: List[types.Artifact], transform_graph: Optional[types.Artifact] = None, contexts: Sequence[metadata_store_pb2.Context] = (), + connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None, ): """Add an Execution to MLMD where a Trainer trains on the examples.""" inputs = {'examples': self.unwrap_tfx_artifacts(examples)} @@ -283,6 +299,7 @@ def train_on_examples( inputs=inputs, outputs={'model': self.unwrap_tfx_artifacts([model])}, contexts=contexts, + connection_config=connection_config, ) def evaluator_bless_model( @@ -291,10 +308,13 @@ def evaluator_bless_model( blessed: bool = True, baseline_model: Optional[types.Artifact] = None, contexts: Sequence[metadata_store_pb2.Context] = (), + connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None, ) -> types.Artifact: """Add an Execution to MLMD where the Evaluator blesses the model.""" model_blessing = self.prepare_tfx_artifact( - ModelBlessing, custom_properties={'blessed': int(blessed)} + ModelBlessing, + custom_properties={'blessed': int(blessed)}, + connection_config=connection_config, ) inputs = {'model': self.unwrap_tfx_artifacts([model])} @@ -306,6 +326,7 @@ def evaluator_bless_model( inputs=inputs, outputs={'blessing': self.unwrap_tfx_artifacts([model_blessing])}, contexts=contexts, + connection_config=connection_config, ) return model_blessing @@ -315,6 +336,7 @@ def infra_validator_bless_model( model: types.Artifact, blessed: bool = True, contexts: Sequence[metadata_store_pb2.Context] = (), + connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None, ) -> types.Artifact: """Add an Execution to MLMD where the InfraValidator blesses the model.""" if blessed: @@ -322,7 +344,9 @@ def infra_validator_bless_model( else: custom_properties = {'blessing_status': 'INFRA_NOT_BLESSED'} model_infra_blessing = self.prepare_tfx_artifact( - ModelInfraBlessing, custom_properties=custom_properties + ModelInfraBlessing, + custom_properties=custom_properties, + connection_config=connection_config, ) self.put_execution( @@ -330,6 +354,7 @@ def infra_validator_bless_model( inputs={'model': self.unwrap_tfx_artifacts([model])}, outputs={'result': self.unwrap_tfx_artifacts([model_infra_blessing])}, contexts=contexts, + connection_config=connection_config, ) return model_infra_blessing @@ -339,15 +364,19 @@ def push_model( model: types.Artifact, model_push: Optional[types.Artifact] = None, contexts: Sequence[metadata_store_pb2.Context] = (), + connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None, ): """Add an Execution to MLMD where the Pusher pushes the model.""" if model_push is None: - model_push = self.prepare_tfx_artifact(ModelPush) + model_push = self.prepare_tfx_artifact( + ModelPush, connection_config=connection_config + ) self.put_execution( 'ServomaticPusher', inputs={'model_export': self.unwrap_tfx_artifacts([model])}, outputs={'model_push': self.unwrap_tfx_artifacts([model_push])}, contexts=contexts, + connection_config=connection_config, ) return model_push @@ -370,6 +399,7 @@ def strict_run_resolver_op( args: Tuple[Any, ...], kwargs: Mapping[str, Any], store: Optional[mlmd.MetadataStore] = None, + mlmd_handle_like: Optional[mlmd_cm.HandleLike] = None, ): """Runs ResolverOp with strict type checking.""" if len(args) != len(op_type.arg_data_types): @@ -396,7 +426,8 @@ def strict_run_resolver_op( context = resolver_op.Context( store=store if store is not None - else mock.MagicMock(spec=mlmd.MetadataStore) + else mock.MagicMock(spec=mlmd.MetadataStore), + mlmd_handle_like=mlmd_handle_like, ) op.set_context(context) result = op.apply(*args) diff --git a/tfx/dsl/input_resolution/resolver_op.py b/tfx/dsl/input_resolution/resolver_op.py index 8594d93b6db..964016a5a54 100644 --- a/tfx/dsl/input_resolution/resolver_op.py +++ b/tfx/dsl/input_resolution/resolver_op.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for ResolverOp and its related definitions.""" + from __future__ import annotations import abc -from typing import Any, Generic, Literal, Mapping, Optional, Sequence, Set, Type, TypeVar, Union +from typing import Any, Generic, Literal, Mapping, Optional, Sequence, Set, Type, TypeVar, Union, cast import attr from tfx import types +from tfx.orchestration import mlmd_connection_manager as mlmd_cm from tfx.proto.orchestration import pipeline_pb2 from tfx.utils import json_utils from tfx.utils import typing_utils @@ -28,11 +30,31 @@ # Mark frozen as context instance may be used across multiple operator # invocations. -@attr.s(auto_attribs=True, frozen=True, kw_only=True) class Context: """Context for running ResolverOp.""" - # MetadataStore for MLMD read access. - store: mlmd.MetadataStore + + def __init__( + self, + store=mlmd.MetadataStore, + mlmd_handle_like: Optional[mlmd_cm.HandleLike] = None, + ): + # TODO(b/302730333) We could remove self._store, and only use + # self._mlmd_handle_like. Keeping it for now to preserve backward + # compatibility with other resolve ops. + self._store = store + self._mlmd_handle_like = mlmd_handle_like + + @property + def store(self): + return self._store + + @property + def mlmd_connection_manager(self): + if isinstance(self._mlmd_handle_like, mlmd_cm.MLMDConnectionManager): + return cast(mlmd_cm.MLMDConnectionManager, self._mlmd_handle_like) + else: + return None + # TODO(jjong): Add more context such as current pipeline, current pipeline # run, and current running node information. diff --git a/tfx/orchestration/portable/input_resolution/input_graph_resolver.py b/tfx/orchestration/portable/input_resolution/input_graph_resolver.py index 5c6e04a9a94..667b224a7f7 100644 --- a/tfx/orchestration/portable/input_resolution/input_graph_resolver.py +++ b/tfx/orchestration/portable/input_resolution/input_graph_resolver.py @@ -29,14 +29,14 @@ import collections import dataclasses import functools -from typing import Union, Sequence, Mapping, Tuple, List, Iterable, Callable +from typing import Callable, Iterable, List, Mapping, Sequence, Tuple, Union from tfx import types from tfx.dsl.components.common import resolver from tfx.dsl.input_resolution import resolver_op from tfx.dsl.input_resolution.ops import ops from tfx.orchestration import data_types_utils -from tfx.orchestration import metadata +from tfx.orchestration import mlmd_connection_manager as mlmd_cm from tfx.orchestration.portable.input_resolution import exceptions from tfx.proto.orchestration import pipeline_pb2 from tfx.utils import topsort @@ -52,8 +52,12 @@ @dataclasses.dataclass class _Context: - mlmd_handle: metadata.Metadata input_graph: pipeline_pb2.InputGraph + mlmd_handle_like: mlmd_cm.HandleLike + + @property + def mlmd_handle(self): + return mlmd_cm.get_handle(self.mlmd_handle_like) def _topologically_sorted_node_ids( @@ -131,7 +135,12 @@ def _evaluate_op_node( f'nodes[{node_id}] has unknown op_type {op_node.op_type}.') from e if issubclass(op_type, resolver_op.ResolverOp): op: resolver_op.ResolverOp = op_type.create(**kwargs) - op.set_context(resolver_op.Context(store=ctx.mlmd_handle.store)) + op.set_context( + resolver_op.Context( + store=mlmd_cm.get_handle(ctx.mlmd_handle_like).store, + mlmd_handle_like=ctx.mlmd_handle_like, + ) + ) return op.apply(*args) elif issubclass(op_type, resolver.ResolverStrategy): if len(args) != 1: @@ -207,7 +216,7 @@ def new_graph_fn(data: Mapping[str, _Data]): def build_graph_fn( - mlmd_handle: metadata.Metadata, + handle_like: mlmd_cm.HandleLike, input_graph: pipeline_pb2.InputGraph, ) -> Tuple[_GraphFn, List[str]]: """Build a functional interface for the `input_graph`. @@ -222,7 +231,7 @@ def build_graph_fn( z = graph_fn({'x': inputs['x'], 'y': inputs['y']}) Args: - mlmd_handle: A `Metadata` instance. + handle_like: A `mlmd_cm.HandleLike` instance. input_graph: An `pipeline_pb2.InputGraph` proto. Returns: @@ -235,7 +244,7 @@ def build_graph_fn( f'result_node {input_graph.result_node} does not exist in input_graph. ' f'Valid node ids: {list(input_graph.nodes.keys())}') - context = _Context(mlmd_handle=mlmd_handle, input_graph=input_graph) + context = _Context(mlmd_handle_like=handle_like, input_graph=input_graph) input_key_to_node_id = {} for node_id in input_graph.nodes: diff --git a/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver.py b/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver.py index c0e069f31f5..2aa52031d9e 100644 --- a/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver.py +++ b/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver.py @@ -13,9 +13,12 @@ # limitations under the License. """Metadata resolver for reasoning about metadata information.""" -from typing import Callable, Dict, List, Optional, Tuple +import collections +from typing import Callable, Dict, List, Optional, Tuple, Union +from tfx.orchestration import mlmd_connection_manager as mlmd_cm from tfx.orchestration.portable.input_resolution.mlmd_resolver import metadata_resolver_utils +from tfx.types import external_artifact_utils import ml_metadata as mlmd from ml_metadata.proto import metadata_store_pb2 @@ -53,8 +56,148 @@ class MetadataResolver: ) """ - def __init__(self, store: mlmd.MetadataStore): + def __init__( + self, + store: mlmd.MetadataStore, + mlmd_connection_manager: Optional[mlmd_cm.MLMDConnectionManager] = None, + ): self._store = store + self._mlmd_connection_manager = mlmd_connection_manager + + # TODO(b/302730333) Write a function get_upstream_artifacts_by_artifacts(), + # which is similar to get_downstream_artifacts_by_artifacts(). + + # TODO(b/302730333) Write unit tests for the new functions. + + def get_downstream_artifacts_by_artifacts( + self, + artifacts: List[metadata_store_pb2.Artifact], + max_num_hops: int = _MAX_NUM_HOPS, + filter_query: str = '', + event_filter: Optional[Callable[[metadata_store_pb2.Event], bool]] = None, + ) -> Dict[ + Union[str, int], + List[Tuple[metadata_store_pb2.Artifact, metadata_store_pb2.ArtifactType]], + ]: + """Given a list of artifacts, get their provenance successor artifacts. + + For each artifact matched by a given `artifact_id`, treat it as a starting + artifact and get artifacts that are connected to them within `max_num_hops` + via a path in the downstream direction like: + artifact_i -> INPUT_event -> execution_j -> OUTPUT_event -> artifact_k. + + A hop is defined as a jump to the next node following the path of node + -> event -> next_node. + For example, in the lineage graph artifact_1 -> event -> execution_1 + -> event -> artifact_2: + artifact_2 is 2 hops away from artifact_1, and execution_1 is 1 hop away + from artifact_1. + + Args: + artifacts: a list of starting artifacts. At most 100 ids are supported. + Returns empty result if `artifact_ids` is empty. + max_num_hops: maximum number of hops performed for downstream tracing. + `max_num_hops` cannot exceed 100 nor be negative. + filter_query: a query string filtering downstream artifacts by their own + attributes or the attributes of immediate neighbors. Please refer to + go/mlmd-filter-query-guide for more detailed guidance. Note: if + `filter_query` is specified and `max_num_hops` is 0, it's equivalent + to getting filtered artifacts by artifact ids with `get_artifacts()`. + event_filter: an optional callable object for filtering events in the + paths towards the downstream artifacts. Only an event with + `event_filter(event)` evaluated to True will be considered as valid + and kept in the path. + + Returns: + Mapping of artifact ids to a list of downstream artifacts. + """ + if not artifacts: + return {} + + # Precondition check. + if len(artifacts) > _MAX_NUM_STARTING_NODES: + raise ValueError( + 'Number of artifacts is larger than supported value of %d.' + % _MAX_NUM_STARTING_NODES + ) + if max_num_hops > _MAX_NUM_HOPS or max_num_hops < 0: + raise ValueError( + 'Number of hops %d is larger than supported value of %d or is' + ' negative.' % (max_num_hops, _MAX_NUM_HOPS) + ) + + internal_artifact_ids = [a.id for a in artifacts if not a.external_id] + external_artifact_ids = [a.external_id for a in artifacts if a.external_id] + + if not external_artifact_ids: + return self.get_downstream_artifacts_by_artifact_ids( + internal_artifact_ids, max_num_hops, filter_query, event_filter + ) + + if not self._mlmd_connection_manager: + raise ValueError( + 'mlmd_connection_manager is not initialized. There are external' + 'artifacts, so we need it to query the external MLMD instance.' + ) + + store_by_pipeline_asset: Dict[str, mlmd.MetadataStore] = {} + external_ids_by_pipeline_asset: Dict[str, List[str]] = ( + collections.defaultdict(list) + ) + for external_id in external_artifact_ids: + connection_config = ( + external_artifact_utils.get_external_connection_config(external_id) + ) + store = self._mlmd_connection_manager.get_mlmd_handle( + connection_config + ).store + pipeline_asset = ( + external_artifact_utils.get_pipeline_asset_from_external_id( + external_id + ) + ) + external_ids_by_pipeline_asset[pipeline_asset].append(external_id) + store_by_pipeline_asset[pipeline_asset] = store + + result = {} + # Gets artifacts from each external store. + for pipeline_asset, external_ids in external_ids_by_pipeline_asset.items(): + store = store_by_pipeline_asset[pipeline_asset] + external_id_by_id = { + external_artifact_utils.get_id_from_external_id(e): e + for e in external_ids + } + artifacts_and_types_by_artifact_id = ( + self.get_downstream_artifacts_by_artifact_ids( + list(external_id_by_id.keys()), + max_num_hops, + filter_query, + event_filter, + store, + ) + ) + + pipeline_owner = pipeline_asset.split('/')[0] + pipeline_name = pipeline_asset.split('/')[1] + artifacts_by_external_id = {} + for ( + artifact_id, + artifacts_and_types, + ) in artifacts_and_types_by_artifact_id.items(): + external_id = external_id_by_id[artifact_id] + imported_artifacts_and_types = [] + for a, t in artifacts_and_types: + imported_artifact = external_artifact_utils.cold_import_artifacts( + t, [a], pipeline_owner, pipeline_name + )[0] + imported_artifacts_and_types.append( + (imported_artifact.mlmd_artifact, imported_artifact.artifact_type) + ) + artifacts_by_external_id[external_id] = imported_artifacts_and_types + + result.update(artifacts_by_external_id) + + return result def get_downstream_artifacts_by_artifact_ids( self, @@ -62,6 +205,7 @@ def get_downstream_artifacts_by_artifact_ids( max_num_hops: int = _MAX_NUM_HOPS, filter_query: str = '', event_filter: Optional[Callable[[metadata_store_pb2.Event], bool]] = None, + store: Optional[mlmd.MetadataStore] = None, ) -> Dict[ int, List[Tuple[metadata_store_pb2.Artifact, metadata_store_pb2.ArtifactType]], @@ -94,34 +238,45 @@ def get_downstream_artifacts_by_artifact_ids( paths towards the downstream artifacts. Only an event with `event_filter(event)` evaluated to True will be considered as valid and kept in the path. + store: A metadata_store.MetadataStore instance. Returns: Mapping of artifact ids to a list of downstream artifacts. """ # Precondition check. - if len(artifact_ids) > _MAX_NUM_STARTING_NODES: - raise ValueError('Number of artifact ids is larger than supported.') if not artifact_ids: return {} + + if len(artifact_ids) > _MAX_NUM_STARTING_NODES: + raise ValueError( + 'Number of artifact ids is larger than supported value of %d.' + % _MAX_NUM_STARTING_NODES + ) if max_num_hops > _MAX_NUM_HOPS or max_num_hops < 0: raise ValueError( - 'Number of hops is larger than supported or is negative.' + 'Number of hops %d is larger than supported value of %d or is' + ' negative.' % (max_num_hops, _MAX_NUM_HOPS) ) + if store is None: + store = self._store + if store is None: + raise ValueError('MetadataStore provided to MetadataResolver is None.') + artifact_ids_str = ','.join(str(id) for id in artifact_ids) # If `max_num_hops` is set to 0, we don't need the graph traversal. if max_num_hops == 0: if not filter_query: - artifacts = self._store.get_artifacts_by_id(artifact_ids) + artifacts = store.get_artifacts_by_id(artifact_ids) else: - artifacts = self._store.get_artifacts( + artifacts = store.get_artifacts( list_options=mlmd.ListOptions( filter_query=f'id IN ({artifact_ids_str}) AND ({filter_query})', limit=_MAX_NUM_STARTING_NODES, ) ) artifact_type_ids = [a.type_id for a in artifacts] - artifact_types = self._store.get_artifact_types_by_id(artifact_type_ids) + artifact_types = store.get_artifact_types_by_id(artifact_type_ids) artifact_type_by_id = {t.id: t for t in artifact_types} return { artifact.id: [(artifact, artifact_type_by_id[artifact.type_id])] @@ -140,7 +295,7 @@ def get_downstream_artifacts_by_artifact_ids( _EVENTS_FIELD_MASK_PATH, _ARTIFACT_TYPES_MASK_PATH, ] - lineage_graph = self._store.get_lineage_subgraph( + lineage_graph = store.get_lineage_subgraph( query_options=options, field_mask_paths=field_mask_paths, ) @@ -175,7 +330,7 @@ def get_downstream_artifacts_by_artifact_ids( ) artifact_ids_str = ','.join(str(id) for id in candidate_artifact_ids) # Send a call to metadata_store to get filtered downstream artifacts. - artifacts = self._store.get_artifacts( + artifacts = store.get_artifacts( list_options=mlmd.ListOptions( filter_query=f'id IN ({artifact_ids_str}) AND ({filter_query})' ) diff --git a/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py b/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py index cad7d29c250..fee73bda28d 100644 --- a/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py +++ b/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py @@ -341,7 +341,7 @@ def _join_artifacts( def _resolve_input_graph_ref( - mlmd_handle: metadata.Metadata, + handle_like: mlmd_cm.HandleLike, node_inputs: pipeline_pb2.NodeInputs, input_key: str, resolved: Dict[str, List[_Entry]], @@ -352,12 +352,12 @@ def _resolve_input_graph_ref( (i.e. `InputGraphRef` with the same `graph_id`). Args: - mlmd_handle: A `Metadata` instance. + handle_like: A `mlmd_cm.HandleLike` instance. node_inputs: A `NodeInputs` proto. input_key: A target input key whose corresponding `InputSpec` has an - `InputGraphRef`. + `InputGraphRef`. resolved: A dict that contains the already resolved inputs, and to which the - resolved result would be written from this function. + resolved result would be written from this function. """ graph_id = node_inputs.inputs[input_key].input_graph_ref.graph_id input_graph = node_inputs.input_graphs[graph_id] @@ -372,7 +372,8 @@ def _resolve_input_graph_ref( } graph_fn, graph_input_keys = input_graph_resolver.build_graph_fn( - mlmd_handle, node_inputs.input_graphs[graph_id]) + handle_like, node_inputs.input_graphs[graph_id] + ) for partition, input_dict in _join_artifacts(resolved, graph_input_keys): result = graph_fn(input_dict) if graph_output_type == _DataType.ARTIFACT_LIST: @@ -514,9 +515,7 @@ def resolve( (partition_utils.NO_PARTITION, _filter_live(artifacts)) ] elif input_spec.input_graph_ref.graph_id: - _resolve_input_graph_ref( - mlmd_cm.get_handle(handle_like), node_inputs, input_key, - resolved) + _resolve_input_graph_ref(handle_like, node_inputs, input_key, resolved) elif input_spec.mixed_inputs.input_keys: _resolve_mixed_inputs(node_inputs, input_key, resolved) elif input_spec.HasField('static_inputs'): diff --git a/tfx/types/external_artifact_utils.py b/tfx/types/external_artifact_utils.py new file mode 100644 index 00000000000..be106311e16 --- /dev/null +++ b/tfx/types/external_artifact_utils.py @@ -0,0 +1,35 @@ +# Copyright 2024 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Third party version of external_artifact_utils.py.""" + + +def get_artifact_id_from_external_id(external_id: str): + del external_id + + +def get_pipeline_asset_from_external_id( + external_id: str, +): + del external_id + + +def get_external_connection_config( + external_id: str, +): + del external_id + + +def identifier(artifact): + return artifact.id