Skip to content

Commit

Permalink
no-op
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 660551866
  • Loading branch information
tfx-copybara committed Aug 9, 2024
1 parent 3db64bc commit 47008b6
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,8 @@ def _generate_tasks_for_node(
execution_type=node.node_info.type,
contexts=resolved_info.contexts,
input_and_params=unprocessed_inputs,
pipeline=self._pipeline,
node_id=node.node_info.id,
)

for execution in executions:
Expand Down
8 changes: 8 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,3 +1673,11 @@ def get_pipeline_and_node(
'pipeline nodes are supported for external executions.'
)
return (pipeline_state.pipeline, node)


def get_pipeline(
    mlmd_handle: metadata.Metadata, pipeline_id: str
) -> pipeline_pb2.Pipeline:
  """Returns the pipeline proto stored with the latest run of `pipeline_id`.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline_id: Id of the pipeline whose proto should be loaded.

  Returns:
    The `Pipeline` proto recorded for the most recent execution of the
    pipeline, as exposed by `PipelineView`.
  """
  # PipelineView.load resolves the latest pipeline run; we only need its IR.
  return PipelineView.load(mlmd_handle, pipeline_id).pipeline
2 changes: 2 additions & 0 deletions tfx/orchestration/experimental/core/sync_pipeline_task_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,8 @@ def _generate_tasks_from_resolved_inputs(
execution_type=node.node_info.type,
contexts=resolved_info.contexts,
input_and_params=resolved_info.input_and_params,
pipeline=self._pipeline,
node_id=node.node_info.id,
)

result.extend(
Expand Down
61 changes: 52 additions & 9 deletions tfx/orchestration/experimental/core/task_gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from tfx.orchestration import metadata
from tfx.orchestration import node_proto_view
from tfx.orchestration.experimental.core import constants
from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import mlmd_state
from tfx.orchestration.experimental.core import task as task_lib
from tfx.orchestration import mlmd_connection_manager as mlmd_cm
Expand Down Expand Up @@ -548,21 +549,41 @@ def register_executions_from_existing_executions(
contexts = metadata_handle.store.get_contexts_by_execution(
existing_executions[0].id
)
return execution_lib.put_executions(
executions = execution_lib.put_executions(
metadata_handle,
new_executions,
contexts,
input_artifacts_maps=input_artifacts,
)

pipeline_asset = metadata_handle.store.pipeline_asset
if pipeline_asset:
env.get_env().create_pipeline_run_node_executions(
pipeline_asset.owner,
pipeline_asset.name,
pipeline,
node.node_info.id,
executions,
)
else:
logging.warning(
'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
' node executions.',
pipeline_asset,
)
return executions


# TODO(b/349654866): make pipeline and node_id non-optional.
def register_executions(
metadata_handle: metadata.Metadata,
execution_type: metadata_store_pb2.ExecutionType,
contexts: Sequence[metadata_store_pb2.Context],
input_and_params: Sequence[InputAndParam],
pipeline: Optional[pipeline_pb2.Pipeline] = None,
node_id: Optional[str] = None,
) -> Sequence[metadata_store_pb2.Execution]:
"""Registers multiple executions in MLMD.
"""Registers multiple executions in storage backends.
Along with the execution:
- the input artifacts will be linked to the executions.
Expand All @@ -575,6 +596,8 @@ def register_executions(
input_and_params: A list of InputAndParams, which includes input_dicts
(dictionaries of artifacts. One execution will be registered for each of
the input_dict) and corresponding exec_properties.
pipeline: Optional. The pipeline proto.
node_id: Optional. The node id of the executions to be registered.
Returns:
A list of MLMD executions that are registered in MLMD, with id populated.
Expand Down Expand Up @@ -603,21 +626,41 @@ def register_executions(
executions.append(execution)

if len(executions) == 1:
return [
new_executions = [
execution_lib.put_execution(
metadata_handle,
executions[0],
contexts,
input_artifacts=input_and_params[0].input_artifacts,
)
]
else:
new_executions = execution_lib.put_executions(
metadata_handle,
executions,
contexts,
[
input_and_param.input_artifacts
for input_and_param in input_and_params
],
)

return execution_lib.put_executions(
metadata_handle,
executions,
contexts,
[input_and_param.input_artifacts for input_and_param in input_and_params],
)
pipeline_asset = metadata_handle.store.pipeline_asset
if pipeline_asset and pipeline and node_id:
env.get_env().create_pipeline_run_node_executions(
pipeline_asset.owner,
pipeline_asset.name,
pipeline,
node_id,
new_executions,
)
else:
logging.warning(
'Skipping creating pipeline run node executions for pipeline asset %s.',
pipeline_asset,
)

return new_executions


def update_external_artifact_type(
Expand Down
7 changes: 5 additions & 2 deletions tfx/orchestration/portable/execution_publish_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def publish_cached_executions(
output_artifacts_maps: Optional[
Sequence[typing_utils.ArtifactMultiMap]
] = None,
) -> None:
) -> Sequence[metadata_store_pb2.Execution]:
"""Marks an existing execution as using cached outputs from a previous execution.
Args:
Expand All @@ -46,11 +46,14 @@ def publish_cached_executions(
executions: Executions that will be published as CACHED executions.
output_artifacts_maps: A list of output artifacts of the executions. Each
artifact will be linked with the execution through an event of type OUTPUT
Returns:
    A list of MLMD executions that are published to MLMD, with id populated.
"""
for execution in executions:
execution.last_known_state = metadata_store_pb2.Execution.CACHED

execution_lib.put_executions(
return execution_lib.put_executions(
metadata_handle,
executions,
contexts,
Expand Down
22 changes: 20 additions & 2 deletions tfx/orchestration/portable/importer_node_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from tfx.dsl.components.common import importer
from tfx.orchestration import data_types_utils
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.orchestration.portable import data_types
from tfx.orchestration.portable import execution_publish_utils
from tfx.orchestration.portable import inputs_utils
Expand Down Expand Up @@ -57,7 +59,7 @@ def run(
Args:
mlmd_connection: ML metadata connection.
pipeline_node: The specification of the node that this launcher lauches.
pipeline_node: The specification of the node that this launcher launches.
pipeline_info: The information of the pipeline that this node runs in.
pipeline_runtime_spec: The runtime information of the pipeline that this
node runs in.
Expand All @@ -78,13 +80,29 @@ def run(
inputs_utils.resolve_parameters_with_schema(
node_parameters=pipeline_node.parameters))

# 3. Registers execution in metadata.
# 3. Registers execution in storage backend.
execution = execution_publish_utils.register_execution(
metadata_handle=m,
execution_type=pipeline_node.node_info.type,
contexts=contexts,
exec_properties=exec_properties,
)
pipeline_asset = m.store.pipeline_asset
if pipeline_asset:
env.get_env().create_pipeline_run_node_executions(
pipeline_asset.owner,
pipeline_asset.name,
pstate.get_pipeline(m, pipeline_info.id),
pipeline_node.node_info.id,
[execution],
)
else:
logging.warning(
'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
' node execution %s.',
pipeline_asset,
execution,
)

# 4. Generate output artifacts to represent the imported artifacts.
output_key = cast(str, exec_properties[importer.OUTPUT_KEY_KEY])
Expand Down
46 changes: 39 additions & 7 deletions tfx/orchestration/portable/partial_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tfx.dsl.compiler import constants
from tfx.orchestration import metadata
from tfx.orchestration import node_proto_view
from tfx.orchestration.experimental.core import env
from tfx.orchestration.portable import execution_publish_utils
from tfx.orchestration.portable.mlmd import context_lib
from tfx.orchestration.portable.mlmd import execution_lib
Expand Down Expand Up @@ -599,6 +600,8 @@ def __init__(
for node in node_proto_view.get_view_for_all_in(new_pipeline_run_ir)
}

self._pipeline = new_pipeline_run_ir

def _get_base_pipeline_run_context(
self, base_run_id: Optional[str] = None
) -> metadata_store_pb2.Context:
Expand Down Expand Up @@ -788,7 +791,12 @@ def _cache_and_publish(
contexts=[self._new_pipeline_run_context] + node_contexts,
)
)
if not prev_cache_executions:

# If there are no previous attempts to cache and publish, we will create new
# cache executions.
create_new_cache_executions: bool = not prev_cache_executions

if create_new_cache_executions:
new_cached_executions = []
for e in existing_executions:
new_cached_executions.append(
Expand Down Expand Up @@ -820,12 +828,36 @@ def _cache_and_publish(
execution_lib.get_output_artifacts(self._mlmd, e.id)
for e in existing_executions
]
execution_publish_utils.publish_cached_executions(
self._mlmd,
contexts=cached_execution_contexts,
executions=new_cached_executions,
output_artifacts_maps=output_artifacts_maps,
)

if create_new_cache_executions:
new_executions = execution_publish_utils.publish_cached_executions(
self._mlmd,
contexts=cached_execution_contexts,
executions=new_cached_executions,
output_artifacts_maps=output_artifacts_maps,
)
pipeline_asset = self._mlmd.store.pipeline_asset
if pipeline_asset:
env.get_env().create_pipeline_run_node_executions(
pipeline_asset.owner,
pipeline_asset.name,
self._pipeline,
node.node_info.id,
new_executions,
)
else:
logging.warning(
'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
' node executions.',
pipeline_asset,
)
else:
execution_publish_utils.publish_cached_executions(
self._mlmd,
contexts=cached_execution_contexts,
executions=new_cached_executions,
output_artifacts_maps=output_artifacts_maps,
)

def put_parent_context(self):
"""Puts a ParentContext edge in MLMD."""
Expand Down
37 changes: 35 additions & 2 deletions tfx/orchestration/portable/resolver_node_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import grpc
from tfx.orchestration import data_types_utils
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.orchestration.portable import data_types
from tfx.orchestration.portable import execution_publish_utils
from tfx.orchestration.portable import inputs_utils
Expand Down Expand Up @@ -86,6 +88,22 @@ def run(
contexts=contexts,
exec_properties=exec_properties,
)
pipeline_asset = m.store.pipeline_asset
if pipeline_asset:
env.get_env().create_pipeline_run_node_executions(
pipeline_asset.owner,
pipeline_asset.name,
pstate.get_pipeline(m, pipeline_info.id),
pipeline_node.node_info.id,
[execution],
)
else:
logging.warning(
'Pipeline asset %s not found in MLMD. Unable to create pipeline'
' run node execution %s.',
pipeline_asset,
execution,
)
execution_publish_utils.publish_failed_execution(
metadata_handle=m,
contexts=contexts,
Expand All @@ -103,14 +121,29 @@ def run(
if isinstance(resolved_inputs, inputs_utils.Skip):
return data_types.ExecutionInfo()

# 3. Registers execution in metadata.
# 3. Registers execution in storage backends.
execution = execution_publish_utils.register_execution(
metadata_handle=m,
execution_type=pipeline_node.node_info.type,
contexts=contexts,
exec_properties=exec_properties,
)

pipeline_asset = m.store.pipeline_asset
if pipeline_asset:
env.get_env().create_pipeline_run_node_executions(
pipeline_asset.owner,
pipeline_asset.name,
pstate.get_pipeline(m, pipeline_info.id),
pipeline_node.node_info.id,
[execution],
)
else:
logging.warning(
'Pipeline asset %s not found in MLMD. Unable to create pipeline'
' run node execution %s.',
pipeline_asset,
execution,
)
# TODO(b/197741942): Support len > 1.
if len(resolved_inputs) > 1:
execution_publish_utils.publish_failed_execution(
Expand Down

0 comments on commit 47008b6

Please sign in to comment.