Encode producer component id and output key when CWP is created from an OutputChannel

PiperOrigin-RevId: 619667393
kmonte authored and tfx-copybara committed Jun 10, 2024
1 parent 810e466 commit 07ea4ff
Showing 15 changed files with 202 additions and 95 deletions.
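At a high level, a ChannelWrappedPlaceholder (CWP) created from an OutputChannel now carries a default key that encodes the producer component id and output key, so conditional predicates resolve against a dedicated input key instead of reusing a data-input key. A minimal sketch of the assumed key format follows; the helper name is illustrative, and the example names are taken from the testdata further down.

# Hedged sketch of the assumed key format; "Evaluator"/"blessing" come from
# the testdata below, and cwp_key is an illustrative helper, not TFX API.
def cwp_key(producer_id: str, output_key: str) -> str:
    return f"_{producer_id}.{output_key}"

assert cwp_key("Evaluator", "blessing") == "_Evaluator.blessing"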
4 changes: 2 additions & 2 deletions tfx/dsl/compiler/compiler.py
@@ -214,8 +214,8 @@ def _compile_node(

# Step 3: Node inputs
node_inputs_compiler.compile_node_inputs(
pipeline_ctx, tfx_node, node.inputs)

pipeline_ctx, tfx_node, node.inputs
)
# Step 4: Node outputs
if (isinstance(tfx_node, base_component.BaseComponent) or
compiler_utils.is_importer(tfx_node)):
5 changes: 5 additions & 0 deletions tfx/dsl/compiler/compiler_utils.py
@@ -204,6 +204,11 @@ def node_context_name(pipeline_context_name: str, node_id: str):

def implicit_channel_key(channel: types.BaseChannel):
"""Key of a channel to the node that consumes the channel as input."""
if (
isinstance(channel, channel_types.ChannelWrappedPlaceholder)
and channel.key
):
return channel.key
if isinstance(channel, channel_types.PipelineInputChannel):
channel = cast(channel_types.PipelineInputChannel, channel)
return f"_{channel.pipeline.id}.{channel.output_key}"
18 changes: 15 additions & 3 deletions tfx/dsl/compiler/node_inputs_compiler.py
@@ -421,20 +421,32 @@ def _compile_conditionals(
contexts = context.dsl_context_registry.get_contexts(tfx_node)
except ValueError:
return

for dsl_context in contexts:
if not isinstance(dsl_context, conditional.CondContext):
continue
cond_context = cast(conditional.CondContext, dsl_context)
for channel in channel_utils.get_dependent_channels(cond_context.predicate):
# The channels here always come from a CWP, whose key is now set by
# default when it wraps an OutputChannel, so we must re-create the input
# key whenever an output channel is used; otherwise `get_input_key` may
# return the wrong key (e.g. if the producer component is also used as a
# data input to this component).
# Note that this can leave several inputs with identical artifact queries
# under the hood, which should be optimized away if we run into
# performance issues.
if isinstance(channel, channel_types.OutputChannel):
input_key = compiler_utils.implicit_channel_key(channel)
else:
input_key = context.get_node_context(tfx_node).get_input_key(channel)
_compile_input_spec(
pipeline_ctx=context,
tfx_node=tfx_node,
input_key=context.get_node_context(tfx_node).get_input_key(channel),
input_key=input_key,
channel=channel,
hidden=False,
min_count=1,
result=result)
result=result,
)
cond_id = context.get_conditional_id(cond_context)
expr = channel_utils.encode_placeholder_with_channels(
cond_context.predicate, context.get_node_context(tfx_node).get_input_key
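The comment above is the crux: when the predicate's channel is an OutputChannel, the conditional gets its own input key derived from the producer id and output key, even if the same output is already wired in as a regular data input. A hedged sketch of the resulting (deliberately duplicated) NodeInputs entries follows, using plain dicts in place of the proto and key names taken from the testdata below.

# Hedged sketch: Evaluator.outputs["blessing"] used both as a data input
# ("blessing") and inside a conditional predicate ("_Evaluator.blessing").
# Plain dicts stand in for the NodeInputs proto.
node_inputs = {
    "blessing": {"producer": "Evaluator", "output_key": "blessing"},
    "_Evaluator.blessing": {"producer": "Evaluator", "output_key": "blessing"},
}
# Both entries describe the same artifact query; the comment above accepts
# this duplication unless it becomes a performance problem.
assert node_inputs["blessing"] == node_inputs["_Evaluator.blessing"]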
9 changes: 6 additions & 3 deletions tfx/dsl/compiler/node_inputs_compiler_test.py
@@ -577,7 +577,8 @@ def testCompileConditionals(self):
self.assertEqual(result.inputs[cond_input_key].min_count, 1)
self.assertLen(result.conditionals, 1)
cond = list(result.conditionals.values())[0]
self.assertProtoEquals("""
self.assertProtoEquals(
"""
operator {
compare_op {
op: EQUAL
@@ -594,7 +595,7 @@ def testCompileConditionals(self):
index_op {
expression {
placeholder {
key: "%s"
key: "_CondNode.x"
}
}
}
@@ -605,7 +606,9 @@
}
}
}
""" % cond_input_key, cond.placeholder_expression)
""",
cond.placeholder_expression,
)

def testCompileInputsForDynamicProperties(self):
producer = DummyNode('Producer')
@@ -1942,6 +1942,43 @@ nodes {
}
}
inputs {
inputs {
key: "_Evaluator.blessing"
value {
channels {
producer_node_query {
id: "Evaluator"
}
context_queries {
type {
name: "pipeline"
}
name {
field_value {
string_value: "composable-pipeline"
}
}
}
context_queries {
type {
name: "node"
}
name {
field_value {
string_value: "composable-pipeline.Evaluator"
}
}
}
artifact_query {
type {
name: "ModelBlessing"
}
}
output_key: "blessing"
}
min_count: 1
}
}
inputs {
key: "blessing"
value {
@@ -2109,7 +2146,7 @@ nodes {
index_op {
expression {
placeholder {
key: "blessing"
key: "_Evaluator.blessing"
}
}
}
51 changes: 50 additions & 1 deletion tfx/dsl/compiler/testdata/conditional_pipeline_input_v2_ir.pbtxt
@@ -1202,6 +1202,55 @@ nodes {
min_count: 1
}
}
inputs {
key: "_Trainer.model"
value {
channels {
producer_node_query {
id: "Trainer"
}
context_queries {
type {
name: "pipeline"
}
name {
field_value {
string_value: "cond"
}
}
}
context_queries {
type {
name: "pipeline_run"
}
name {
runtime_parameter {
name: "pipeline-run-id"
type: STRING
}
}
}
context_queries {
type {
name: "node"
}
name {
field_value {
string_value: "cond.Trainer"
}
}
}
artifact_query {
type {
name: "Model"
base_type: MODEL
}
}
output_key: "model"
}
min_count: 1
}
}
inputs {
key: "model"
value {
@@ -1333,7 +1382,7 @@ nodes {
index_op {
expression {
placeholder {
key: "model"
key: "_Trainer.model"
}
}
}
@@ -1,3 +1,9 @@
# proto-file: tfx/proto/orchestration/pipeline.proto
# proto-message: Pipeline
#
# This file contains the IR of an example pipeline
# tfx/dsl/compiler/testdata/consumer_pipeline_with_tags.py

pipeline_info {
id: "consumer-pipeline"
}
66 changes: 32 additions & 34 deletions tfx/orchestration/kubeflow/v2/compiler_utils_test.py
@@ -266,36 +266,38 @@ def setUp(self):

@parameterized.named_parameters(
{
'testcase_name':
'two_sides_placeholder',
'predicate':
_TEST_CHANNEL.future()[0].property('int1') <
_TEST_CHANNEL.future()[0].property('int2'),
'expected_cel':
'(inputs.artifacts[\'key\'].artifacts[0].metadata[\'int1\'] < '
'inputs.artifacts[\'key\'].artifacts[0].metadata[\'int2\'])',
'testcase_name': 'two_sides_placeholder',
'predicate': _TEST_CHANNEL.future()[0].property(
'int1'
) < _TEST_CHANNEL.future()[0].property('int2'),
'expected_cel': (
"(inputs.artifacts['_producer.foo'].artifacts[0].metadata['int1'] < "
"inputs.artifacts['_producer.foo'].artifacts[0].metadata['int2'])"
),
},
{
'testcase_name':
'left_side_placeholder_right_side_int',
'predicate':
_TEST_CHANNEL.future()[0].property('int') < 1,
'expected_cel':
'(inputs.artifacts[\'key\'].artifacts[0].metadata[\'int\'] < 1.0)',
'testcase_name': 'left_side_placeholder_right_side_int',
'predicate': _TEST_CHANNEL.future()[0].property('int') < 1,
'expected_cel': (
"(inputs.artifacts['_producer.foo'].artifacts[0].metadata['int']"
' < 1.0)'
),
},
{
'testcase_name': 'left_side_placeholder_right_side_float',
'predicate': _TEST_CHANNEL.future()[0].property('float') < 1.1,
'expected_cel':
'(inputs.artifacts[\'key\'].artifacts[0].metadata[\'float\'] < '
'1.1)',
'expected_cel': (
"(inputs.artifacts['_producer.foo'].artifacts[0].metadata['float']"
' < 1.1)'
),
},
{
'testcase_name': 'left_side_placeholder_right_side_string',
'predicate': _TEST_CHANNEL.future()[0].property('str') == 'test_str',
'expected_cel':
'(inputs.artifacts[\'key\'].artifacts[0].metadata[\'str\'] == '
'\'test_str\')',
'expected_cel': (
"(inputs.artifacts['_producer.foo'].artifacts[0].metadata['str']"
" == 'test_str')"
),
},
)
def testComparison(self, predicate, expected_cel):
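The updated expectations key the CEL artifact references by '_producer.foo' rather than the 'key' entry from channel_to_key_map, consistent with deriving the key from _TEST_CHANNEL's producer id and output key. A hedged sketch of that derivation follows; "producer" and "foo" are read off the expected strings above, not asserted from the test setup.

# Hedged sketch of how the expected CEL references above appear to be built;
# the producer id and output key here are illustrative assumptions.
producer_id, output_key = "producer", "foo"
cel_ref = f"inputs.artifacts['_{producer_id}.{output_key}'].artifacts[0]"
assert cel_ref == "inputs.artifacts['_producer.foo'].artifacts[0]"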
@@ -310,8 +312,9 @@ def testComparison(self, predicate, expected_cel):

def testArtifactUri(self):
predicate = _TEST_CHANNEL.future()[0].uri == 'test_str'
expected_cel = ('(inputs.artifacts[\'key\'].artifacts[0].uri == '
'\'test_str\')')
expected_cel = (
"(inputs.artifacts['_producer.foo'].artifacts[0].uri == 'test_str')"
)
channel_to_key_map = {
_TEST_CHANNEL: 'key',
}
@@ -323,8 +326,10 @@ def testArtifactUri(self):

def testNegation(self):
predicate = _TEST_CHANNEL.future()[0].property('int') != 1
expected_cel = ('!((inputs.artifacts[\'key\'].artifacts[0]'
'.metadata[\'int\'] == 1.0))')
expected_cel = (
"!((inputs.artifacts['_producer.foo'].artifacts[0]"
".metadata['int'] == 1.0))"
)
channel_to_key_map = {
_TEST_CHANNEL: 'key',
}
@@ -337,8 +342,9 @@ def testNegation(self):
def testConcat(self):
predicate = _TEST_CHANNEL.future()[0].uri + 'something' == 'test_str'
expected_cel = (
'((inputs.artifacts[\'key\'].artifacts[0].uri + \'something\') == '
'\'test_str\')')
"((inputs.artifacts['_producer.foo'].artifacts[0].uri + 'something') =="
" 'test_str')"
)
channel_to_key_map = {
_TEST_CHANNEL: 'key',
}
@@ -360,14 +366,6 @@ def testUnsupportedOperator(self):
ValueError, 'Got unsupported placeholder operator base64_encode_op.'):
compiler_utils.placeholder_to_cel(placeholder_pb)

def testPlaceholderWithoutKey(self):
predicate = _TEST_CHANNEL.future()[0].uri == 'test_str'
placeholder_pb = predicate.encode()
with self.assertRaisesRegex(
ValueError,
'Only supports accessing placeholders with a key on KFPv2.'):
compiler_utils.placeholder_to_cel(placeholder_pb)


if __name__ == '__main__':
tf.test.main()
@@ -35,7 +35,7 @@ inputs {
}
}
trigger_policy {
condition: "!((inputs.artifacts['input1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
condition: "!((inputs.artifacts['_producer_task_1.output1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
}
component_ref {
name: "DummyConsumerComponent"
@@ -35,7 +35,7 @@ inputs {
}
}
trigger_policy {
condition: "!((inputs.artifacts['input1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
condition: "!((inputs.artifacts['_producer_task_1.output1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
}
component_ref {
name: "DummyConsumerComponent"
@@ -715,21 +715,23 @@ def testConditionals(self):

with self.subTest('blessed == 1'):
node_inputs = pipeline_pb2.NodeInputs(
inputs={'x': x},
inputs={'_foo.x': x},
input_graphs={'graph_1': graph_1},
conditionals={'cond_1': cond_1})
conditionals={'cond_1': cond_1},
)

result = node_inputs_resolver.resolve(self._mlmd_handle, node_inputs)
self.assertEqual(result, [{'x': [a1]}, {'x': [a4]}])
self.assertEqual(result, [{'_foo.x': [a1]}, {'_foo.x': [a4]}])

with self.subTest('blessed == 1 and tag == foo'):
node_inputs = pipeline_pb2.NodeInputs(
inputs={'x': x},
inputs={'_foo.x': x},
input_graphs={'graph_1': graph_1},
conditionals={'cond_1': cond_1, 'cond_2': cond_2})
conditionals={'cond_1': cond_1, 'cond_2': cond_2},
)

result = node_inputs_resolver.resolve(self._mlmd_handle, node_inputs)
self.assertEqual(result, [{'x': [a1]}])
self.assertEqual(result, [{'_foo.x': [a1]}])

def testConditionals_FalseCondAlwaysReturnsEmpty(self):
a = self.create_artifacts(1)
@@ -778,7 +780,7 @@ def testConditionals_FalseCondAlwaysReturnsEmpty(self):
node_inputs = NodeInputs(
inputs={
'a': x1,
'b': x2,
'_foo.x': x2,
},
conditionals={'cond': cond},
)
Expand Down