Skip to content

Commit

Permalink
Let resolver op be able to get external artifacts.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623548131
  • Loading branch information
tfx-copybara committed Jun 5, 2024
1 parent baab834 commit 759e6ed
Show file tree
Hide file tree
Showing 8 changed files with 376 additions and 58 deletions.
98 changes: 80 additions & 18 deletions tfx/dsl/input_resolution/ops/latest_policy_model_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for LatestPolicyModel operator."""

import collections
import enum
from typing import Dict, List
from typing import Dict, List, Optional, Tuple

from tfx import types
from tfx.dsl.input_resolution import resolver_op
Expand All @@ -24,6 +25,7 @@
from tfx.orchestration.portable.mlmd import event_lib
from tfx.orchestration.portable.mlmd import filter_query_builder as q
from tfx.types import artifact_utils
from tfx.types import external_artifact_utils
from tfx.utils import typing_utils

from ml_metadata.proto import metadata_store_pb2
Expand Down Expand Up @@ -204,6 +206,33 @@ def _build_result_dictionary(
return result


def _dedpupe_model_artifacts(
models: Optional[List[artifact_utils.Artifact]],
) -> Tuple[List[artifact_utils.Artifact], List[int]]:
"""Dedupes a list of Model artifacts."""
if not models:
return [], []

model_by_external_id = {}
model_by_id = {}

for m in models:
if m.external_id:
model_by_external_id[m.external_id] = m
else:
model_by_id[m.id] = m

deduped_models = list(model_by_external_id.values()) + list(
model_by_id.values()
)
model_artifact_ids = [
external_artifact_utils.get_id_from_external_id(i)
for i in model_by_external_id.keys()
] + list(model_by_id.keys())

return (deduped_models, model_artifact_ids)


class LatestPolicyModel(
resolver_op.ResolverOp,
canonical_name='tfx.LatestPolicyModel',
Expand Down Expand Up @@ -325,6 +354,25 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):
if self.policy == Policy.LATEST_EXPORTED:
return {ops_utils.MODEL_KEY: [models[0]]}

are_models_external = [m.is_external for m in models]
if any(are_models_external) and not all(are_models_external):
raise exceptions.InvalidArgument(
'Inputs to the LastestPolicyModel are from both current pipeline and'
' external pipeline. LastestPolicyModel does not support such usage.'
)
if all(are_models_external):
pipeline_assets = set([
external_artifact_utils.get_pipeline_asset_from_external_id(
m.mlmd_artifact.external_id
)
for m in models
])
if len(pipeline_assets) != 1:
raise exceptions.InvalidArgument(
'Input models to the LastestPolicyModel are from multiple'
' pipelines. LastestPolicyModel does not support such usage.'
)

# If ModelBlessing and/or ModelInfraBlessing artifacts were included in
# input_dict, then we will only consider those child artifacts.
specifies_child_artifacts = (
Expand All @@ -334,7 +382,17 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):
input_child_artifacts = input_dict.get(
ops_utils.MODEL_BLESSSING_KEY, []
) + input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, [])
input_child_artifact_ids = set([a.id for a in input_child_artifacts])

input_child_artifact_ids = set()
for a in input_child_artifacts:
if a.is_external:
input_child_artifact_ids.add(
external_artifact_utils.get_id_from_external_id(
a.mlmd_artifact.external_id
)
)
else:
input_child_artifact_ids.add(a.id)

# If the ModelBlessing and ModelInfraBlessing lists are empty, then no
# child artifacts can be considered and we raise a SkipSignal. This can
Expand Down Expand Up @@ -362,8 +420,8 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):

# There could be multiple events with the same execution ID but different
# artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we
# keep the values of model_artifact_ids_by_execution_id as sets.
model_artifact_ids = sorted(set(m.id for m in models))
# need to deduplicate the Model artifacts.
deduped_models, model_artifact_ids = _dedpupe_model_artifacts(models)

downstream_artifact_type_names_filter_query = q.to_sql_string([
ops_utils.MODEL_BLESSING_TYPE_NAME,
Expand Down Expand Up @@ -407,10 +465,13 @@ def event_filter(event):
else:
return event_lib.is_valid_output_event(event)

mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store)
mlmd_resolver = metadata_resolver.MetadataResolver(
self.context.store,
mlmd_connection_manager=self.context.mlmd_connection_manager,
)
# Populate the ModelRelations associated with each Model artifact and its
# children.
model_relations_by_model_artifact_id = collections.defaultdict(
model_relations_by_model_identifier = collections.defaultdict(
ModelRelations
)
artifact_type_by_name: Dict[str, metadata_store_pb2.ArtifactType] = {}
Expand All @@ -419,34 +480,35 @@ def event_filter(event):
# fetching downstream artifacts, because
# `get_downstream_artifacts_by_artifact_ids()` supports at most 100 ids
# as starting artifact ids.
for id_index in range(0, len(model_artifact_ids), ops_utils.BATCH_SIZE):
batch_model_artifact_ids = model_artifact_ids[
for id_index in range(0, len(deduped_models), ops_utils.BATCH_SIZE):
batch_model_artifacts = deduped_models[
id_index : id_index + ops_utils.BATCH_SIZE
]
# Set `max_num_hops` to 50, which should be enough for this use case.
batch_downstream_artifacts_and_types_by_model_ids = (
mlmd_resolver.get_downstream_artifacts_by_artifact_ids(
batch_model_artifact_ids,
batch_downstream_artifacts_and_types_by_model_identifier = (
mlmd_resolver.get_downstream_artifacts_by_artifacts(
batch_model_artifacts,
max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS,
filter_query=filter_query,
event_filter=event_filter,
)
)

for (
model_artifact_id,
model_identifier,
artifacts_and_types,
) in batch_downstream_artifacts_and_types_by_model_ids.items():
) in batch_downstream_artifacts_and_types_by_model_identifier.items():
for downstream_artifact, artifact_type in artifacts_and_types:
artifact_type_by_name[artifact_type.name] = artifact_type
model_relations = model_relations_by_model_artifact_id[
model_artifact_id
]
model_relations.add_downstream_artifact(downstream_artifact)
model_relations_by_model_identifier[
model_identifier
].add_downstream_artifact(downstream_artifact)

# Find the latest model and ModelRelations that meets the Policy.
result = {}
for model in models:
model_relations = model_relations_by_model_artifact_id[model.id]
identifier = external_artifact_utils.identifier(model)
model_relations = model_relations_by_model_identifier[identifier]
if model_relations.meets_policy(self.policy):
result[ops_utils.MODEL_KEY] = [model]
break
Expand Down
5 changes: 5 additions & 0 deletions tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.dsl.input_resolution.ops.latest_policy_model_op."""
import os
from typing import Dict, List, Optional
from unittest import mock

from absl.testing import parameterized
import tensorflow as tf
Expand All @@ -22,6 +24,7 @@
from tfx.dsl.input_resolution.ops import ops
from tfx.dsl.input_resolution.ops import ops_utils
from tfx.dsl.input_resolution.ops import test_utils
from tfx.orchestration import metadata
from tfx.orchestration.portable.input_resolution import exceptions

from ml_metadata.proto import metadata_store_pb2
Expand Down Expand Up @@ -146,6 +149,7 @@ def _run_latest_policy_model(self, *args, **kwargs):
args=args,
kwargs=kwargs,
store=self.store,
mlmd_handle_like=self.mlmd_cm,
)

def setUp(self):
Expand All @@ -158,6 +162,7 @@ def setUp(self):

self.artifacts = [self.model_1, self.model_2, self.model_3]


def assertDictKeysEmpty(
self,
output_dict: Dict[str, List[types.Artifact]],
Expand Down
Loading

0 comments on commit 759e6ed

Please sign in to comment.