Skip to content

Commit

Permalink
Encode producer component id and output key when CWP is created from …
Browse files Browse the repository at this point in the history
…an OutputChannel

PiperOrigin-RevId: 619667393
  • Loading branch information
kmonte authored and tfx-copybara committed May 20, 2024
1 parent e5d02d1 commit de7c771
Show file tree
Hide file tree
Showing 12 changed files with 50 additions and 40 deletions.
3 changes: 2 additions & 1 deletion tfx/dsl/compiler/node_inputs_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ def _compile_input_spec(
name=channel.pipeline_name,
)
result_input_channel.metadata_connection_config.Pack(config)

elif isinstance(channel, channel_types.ChannelWrappedPlaceholder):
print(f'Channel: {tfx_node.id}.{input_key} was a CWP!')
# Note that this path is *usually* not taken, as most output channels already
# exist in pipeline_ctx.channels, as they are added in after
# compiler._generate_input_spec_for_outputs is called.
Expand Down
9 changes: 6 additions & 3 deletions tfx/dsl/compiler/node_inputs_compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,8 @@ def testCompileConditionals(self):
self.assertEqual(result.inputs[cond_input_key].min_count, 1)
self.assertLen(result.conditionals, 1)
cond = list(result.conditionals.values())[0]
self.assertProtoEquals("""
self.assertProtoEquals(
"""
operator {
compare_op {
op: EQUAL
Expand All @@ -343,7 +344,7 @@ def testCompileConditionals(self):
index_op {
expression {
placeholder {
key: "%s"
key: "CondNode_x"
}
}
}
Expand All @@ -354,7 +355,9 @@ def testCompileConditionals(self):
}
}
}
""" % cond_input_key, cond.placeholder_expression)
""",
cond.placeholder_expression,
)

def testCompileInputsForDynamicProperties(self):
producer = DummyNode('Producer')
Expand Down
1 change: 1 addition & 0 deletions tfx/dsl/compiler/placeholder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def resolve_placeholder_expression(
debug_str(expression),
err.placeholder,
)
logging.warning("Context: %s", context)
return None
except Exception as e:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2109,7 +2109,7 @@ nodes {
index_op {
expression {
placeholder {
key: "blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -3318,7 +3318,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_infra-validator-pipeline.blessing"
key: "infra-validator-pipeline_blessing"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2368,7 +2368,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -3686,7 +3686,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_infra-validator-pipeline.blessing"
key: "infra-validator-pipeline_blessing"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -1264,7 +1264,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -1301,7 +1301,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_InfraValidator.blessing"
key: "InfraValidator_blessing"
}
}
}
Expand Down Expand Up @@ -1333,7 +1333,7 @@ nodes {
index_op {
expression {
placeholder {
key: "model"
key: "Trainer_model"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_UpstreamComponent.num"
key: "UpstreamComponent_num"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_UpstreamComponent.num"
key: "UpstreamComponent_num"
}
}
}
Expand Down
7 changes: 5 additions & 2 deletions tfx/types/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,9 @@ def set_as_async_channel(self) -> None:
self._is_async = True

def future(self) -> ChannelWrappedPlaceholder:
return ChannelWrappedPlaceholder(self)
return ChannelWrappedPlaceholder(
self, key=f'{self.producer_component_id}_{self.output_key}'
)


@doc_controls.do_not_generate_docs
Expand Down Expand Up @@ -793,7 +795,8 @@ def set_key(self, key: Optional[str]):
Args:
key: The new key for the channel.
"""
self._key = key
del key # unused.
return

def __getitem__(self, index: int) -> ChannelWrappedPlaceholder:
if self._index is not None:
Expand Down
4 changes: 1 addition & 3 deletions tfx/types/channel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,8 @@ def unwrap_simple_channel_placeholder(
# proto paths above and been getting default messages all along. If this
# sub-message is present, then the whole chain was correct.
not index_op.expression.HasField('placeholder')
# ChannelWrappedPlaceholder uses INPUT_ARTIFACT for some reason, and has
# no key when encoded with encode().
# ChannelWrappedPlaceholder uses INPUT_ARTIFACT for some reason.
or cwp.type != placeholder_pb2.Placeholder.Type.INPUT_ARTIFACT
or cwp.key
# For the `[0]` part of the desired shape.
or index_op.index != 0
):
Expand Down
44 changes: 24 additions & 20 deletions tfx/types/channel_wrapped_placeholder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def testProtoFutureValueOperator(self):
output_key='num',
)
placeholder = output_channel.future()[0].value
channel_to_key = {output_channel: '_component.num'}
channel_to_key = {output_channel: 'producer_num'}
self.assertProtoEquals(
channel_utils.encode_placeholder_with_channels(
placeholder, lambda k: channel_to_key[k]
Expand Down Expand Up @@ -161,7 +161,7 @@ def testEncodeWithKeys(self):
index_op {
expression {
placeholder {
key: "MyTypeName"
key: "producer_foo"
}
}
}
Expand Down Expand Up @@ -351,7 +351,9 @@ def testEncode(self):
operator {
index_op {
expression {
placeholder {}
placeholder {
key: "a_foo"
}
}
}
}
Expand All @@ -366,7 +368,9 @@ def testEncode(self):
operator {
index_op {
expression {
placeholder {}
placeholder {
key: "b_bar"
}
}
}
}
Expand Down Expand Up @@ -413,7 +417,7 @@ def testEncodeWithKeys(self):
index_op {
expression {
placeholder {
key: "channel_1_key"
key: "a_foo"
}
}
}
Expand All @@ -430,7 +434,7 @@ def testEncodeWithKeys(self):
index_op {
expression {
placeholder {
key: "channel_2_key"
key: "b_bar"
}
}
}
Expand Down Expand Up @@ -482,7 +486,7 @@ def testNegation(self):
index_op {
expression {
placeholder {
key: "channel_1_key"
key: "a_foo"
}
}
}
Expand All @@ -499,7 +503,7 @@ def testNegation(self):
index_op {
expression {
placeholder {
key: "channel_2_key"
key: "b_bar"
}
}
}
Expand Down Expand Up @@ -553,7 +557,7 @@ def testDoubleNegation(self):
index_op {
expression {
placeholder {
key: "channel_1_key"
key: "a_foo"
}
}
}
Expand All @@ -570,7 +574,7 @@ def testDoubleNegation(self):
index_op {
expression {
placeholder {
key: "channel_2_key"
key: "b_bar"
}
}
}
Expand Down Expand Up @@ -622,7 +626,7 @@ def testComparison_notEqual(self):
index_op {
expression {
placeholder {
key: "channel_1_key"
key: "a_foo"
}
}
}
Expand All @@ -639,7 +643,7 @@ def testComparison_notEqual(self):
index_op {
expression {
placeholder {
key: "channel_2_key"
key: "b_bar"
}
}
}
Expand Down Expand Up @@ -695,7 +699,7 @@ def testComparison_lessThanOrEqual(self):
index_op {
expression {
placeholder {
key: "channel_1_key"
key: "a_foo"
}
}
}
Expand All @@ -712,7 +716,7 @@ def testComparison_lessThanOrEqual(self):
index_op {
expression {
placeholder {
key: "channel_2_key"
key: "b_bar"
}
}
}
Expand Down Expand Up @@ -768,7 +772,7 @@ def testComparison_greaterThanOrEqual(self):
index_op {
expression {
placeholder {
key: "channel_1_key"
key: "a_foo"
}
}
}
Expand All @@ -785,7 +789,7 @@ def testComparison_greaterThanOrEqual(self):
index_op {
expression {
placeholder {
key: "channel_2_key"
key: "b_bar"
}
}
}
Expand Down Expand Up @@ -868,7 +872,7 @@ def testNestedLogicalOps(self):
index_op {
expression {
placeholder {
key: "channel_11_key"
key: "a_1"
}
}
}
Expand All @@ -885,7 +889,7 @@ def testNestedLogicalOps(self):
index_op {
expression {
placeholder {
key: "channel_12_key"
key: "b_2"
}
}
}
Expand Down Expand Up @@ -913,7 +917,7 @@ def testNestedLogicalOps(self):
index_op {
expression {
placeholder {
key: "channel_21_key"
key: "c_3"
}
}
}
Expand All @@ -930,7 +934,7 @@ def testNestedLogicalOps(self):
index_op {
expression {
placeholder {
key: "channel_22_key"
key: "d_4"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ operator {
index_op {
expression {
placeholder {
key: "_component.num"
key: "producer_num"
}
}
}
Expand Down

0 comments on commit de7c771

Please sign in to comment.