From 81229e1f323e06202d8ced8c147bb8e7b7bb8bda Mon Sep 17 00:00:00 2001 From: kmonte Date: Wed, 27 Mar 2024 14:56:57 -0700 Subject: [PATCH] Encode producer component id and output key when CWP is created from an OutputChannel PiperOrigin-RevId: 619667393 --- tfx/dsl/compiler/compiler_test.py | 9 +- ...omposable_pipeline_async_input_v2_ir.pbtxt | 4 +- .../composable_pipeline_input_v2_ir.pbtxt | 4 +- .../conditional_pipeline_input_v2_ir.pbtxt | 8 +- ...exec_properties_pipeline_input_v2_ir.pbtxt | 2 +- ...ipeline_with_annotations_input_v2_ir.pbtxt | 2 +- tfx/dsl/components/base/testing/test_node.py | 31 ++ tfx/types/channel.py | 13 +- tfx/types/channel_utils.py | 4 +- tfx/types/channel_utils_test.py | 55 ++- tfx/types/channel_wrapped_placeholder_test.py | 396 ++++++++++++++---- tfx/types/component_spec_test.py | 16 +- ...to_placeholder_future_value_operator.pbtxt | 2 +- 13 files changed, 418 insertions(+), 128 deletions(-) create mode 100644 tfx/dsl/components/base/testing/test_node.py diff --git a/tfx/dsl/compiler/compiler_test.py b/tfx/dsl/compiler/compiler_test.py index 4881063ca35..989d0586ab6 100644 --- a/tfx/dsl/compiler/compiler_test.py +++ b/tfx/dsl/compiler/compiler_test.py @@ -205,10 +205,17 @@ def testCompileAdditionalCustomPropertyNameConflictError(self): def testCompileDynamicExecPropTypeError(self): dsl_compiler = compiler.Compiler() test_pipeline = dynamic_exec_properties_pipeline.create_test_pipeline() + upstream_component = next( + c + for c in test_pipeline.components + if isinstance(c, dynamic_exec_properties_pipeline.UpstreamComponent) + ) downstream_component = next( c for c in test_pipeline.components if isinstance(c, dynamic_exec_properties_pipeline.DownstreamComponent)) - test_wrong_type_channel = channel.Channel(_MyType).future().value + test_wrong_type_channel = ( + channel.OutputChannel(_MyType, upstream_component, "foo").future().value + ) downstream_component.exec_properties["input_num"] = test_wrong_type_channel with self.assertRaisesRegex( ValueError, ".*channel must be of a value artifact type.*" diff --git a/tfx/dsl/compiler/testdata/composable_pipeline_async_input_v2_ir.pbtxt b/tfx/dsl/compiler/testdata/composable_pipeline_async_input_v2_ir.pbtxt index 618c41b36de..6db165707a9 100644 --- a/tfx/dsl/compiler/testdata/composable_pipeline_async_input_v2_ir.pbtxt +++ b/tfx/dsl/compiler/testdata/composable_pipeline_async_input_v2_ir.pbtxt @@ -1886,7 +1886,7 @@ nodes { index_op { expression { placeholder { - key: "blessing" + key: "Evaluator_blessing" } } } @@ -2983,7 +2983,7 @@ nodes { index_op { expression { placeholder { - key: "_infra-validator-pipeline.blessing" + key: "infra-validator-pipeline_blessing" } } } diff --git a/tfx/dsl/compiler/testdata/composable_pipeline_input_v2_ir.pbtxt b/tfx/dsl/compiler/testdata/composable_pipeline_input_v2_ir.pbtxt index b257611d5c0..f1a8897d388 100644 --- a/tfx/dsl/compiler/testdata/composable_pipeline_input_v2_ir.pbtxt +++ b/tfx/dsl/compiler/testdata/composable_pipeline_input_v2_ir.pbtxt @@ -2145,7 +2145,7 @@ nodes { index_op { expression { placeholder { - key: "_Evaluator.blessing" + key: "Evaluator_blessing" } } } @@ -3351,7 +3351,7 @@ nodes { index_op { expression { placeholder { - key: "_infra-validator-pipeline.blessing" + key: "infra-validator-pipeline_blessing" } } } diff --git a/tfx/dsl/compiler/testdata/conditional_pipeline_input_v2_ir.pbtxt b/tfx/dsl/compiler/testdata/conditional_pipeline_input_v2_ir.pbtxt index 34bd7e9a89d..999bd5f99e0 100644 --- a/tfx/dsl/compiler/testdata/conditional_pipeline_input_v2_ir.pbtxt +++ b/tfx/dsl/compiler/testdata/conditional_pipeline_input_v2_ir.pbtxt @@ -1001,7 +1001,7 @@ nodes { index_op { expression { placeholder { - key: "_Evaluator.blessing" + key: "Evaluator_blessing" } } } @@ -1264,7 +1264,7 @@ nodes { index_op { expression { placeholder { - key: "_Evaluator.blessing" + key: "Evaluator_blessing" } } } @@ -1301,7 +1301,7 @@ nodes { index_op { expression { placeholder { - key: "_InfraValidator.blessing" + key: "InfraValidator_blessing" } } } @@ -1333,7 +1333,7 @@ nodes { index_op { expression { placeholder { - key: "model" + key: "Trainer_model" } } } diff --git a/tfx/dsl/compiler/testdata/dynamic_exec_properties_pipeline_input_v2_ir.pbtxt b/tfx/dsl/compiler/testdata/dynamic_exec_properties_pipeline_input_v2_ir.pbtxt index 549dbfecb26..ebfa13e432e 100644 --- a/tfx/dsl/compiler/testdata/dynamic_exec_properties_pipeline_input_v2_ir.pbtxt +++ b/tfx/dsl/compiler/testdata/dynamic_exec_properties_pipeline_input_v2_ir.pbtxt @@ -180,7 +180,7 @@ nodes { index_op { expression { placeholder { - key: "_UpstreamComponent.num" + key: "UpstreamComponent_num" } } } diff --git a/tfx/dsl/compiler/testdata/pipeline_with_annotations_input_v2_ir.pbtxt b/tfx/dsl/compiler/testdata/pipeline_with_annotations_input_v2_ir.pbtxt index 02346d15148..c1d5f170b94 100644 --- a/tfx/dsl/compiler/testdata/pipeline_with_annotations_input_v2_ir.pbtxt +++ b/tfx/dsl/compiler/testdata/pipeline_with_annotations_input_v2_ir.pbtxt @@ -221,7 +221,7 @@ nodes { index_op { expression { placeholder { - key: "_UpstreamComponent.num" + key: "UpstreamComponent_num" } } } diff --git a/tfx/dsl/components/base/testing/test_node.py b/tfx/dsl/components/base/testing/test_node.py new file mode 100644 index 00000000000..8c8ef621ce4 --- /dev/null +++ b/tfx/dsl/components/base/testing/test_node.py @@ -0,0 +1,31 @@ +# Copyright 2024 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module to provide a node for tests.""" + +from tfx.dsl.components.base import base_node + + +class TestNode(base_node.BaseNode): + """Node purely for testing, intentionally empty. + + DO NOT USE in real pipelines. + """ + + inputs = {} + outputs = {} + exec_properties = {} + + def __init__(self, name: str): + super().__init__() + self.with_id(name) diff --git a/tfx/types/channel.py b/tfx/types/channel.py index f6b3fe63466..6826b5577ca 100644 --- a/tfx/types/channel.py +++ b/tfx/types/channel.py @@ -204,7 +204,7 @@ def trigger_by_property(self, *property_keys: str): return self._with_input_trigger(TriggerByProperty(property_keys)) def future(self) -> ChannelWrappedPlaceholder: - return ChannelWrappedPlaceholder(self) + raise NotImplementedError() def __eq__(self, other): return self is other @@ -557,6 +557,11 @@ def set_external(self, predefined_artifact_uris: List[str]) -> None: def set_as_async_channel(self) -> None: self._is_async = True + def future(self) -> ChannelWrappedPlaceholder: + return ChannelWrappedPlaceholder( + self, f'{self.producer_component_id}_{self.output_key}' + ) + @doc_controls.do_not_generate_docs class UnionChannel(BaseChannel): @@ -703,6 +708,9 @@ def trigger_by_property(self, *property_keys: str): 'trigger_by_property is not implemented for PipelineInputChannel.' ) + def future(self) -> ChannelWrappedPlaceholder: + return ChannelWrappedPlaceholder(self, f'{self._output_key}') + class ExternalPipelineChannel(BaseChannel): """Channel subtype that is used to get artifacts from external MLMD db.""" @@ -787,7 +795,8 @@ def set_key(self, key: Optional[str]): Args: key: The new key for the channel. """ - self._key = key + del key # unused. + return def __getitem__(self, index: int) -> ChannelWrappedPlaceholder: if self._index is not None: diff --git a/tfx/types/channel_utils.py b/tfx/types/channel_utils.py index 7523661c466..e928fef245f 100644 --- a/tfx/types/channel_utils.py +++ b/tfx/types/channel_utils.py @@ -211,10 +211,8 @@ def unwrap_simple_channel_placeholder( # proto paths above and been getting default messages all along. If this # sub-message is present, then the whole chain was correct. not index_op.expression.HasField('placeholder') - # ChannelWrappedPlaceholder uses INPUT_ARTIFACT for some reason, and has - # no key when encoded with encode(). + # ChannelWrappedPlaceholder uses INPUT_ARTIFACT for some reason. or cwp.type != placeholder_pb2.Placeholder.Type.INPUT_ARTIFACT - or cwp.key # For the `[0]` part of the desired shape. or index_op.index != 0 ): diff --git a/tfx/types/channel_utils_test.py b/tfx/types/channel_utils_test.py index bb136f05a26..f97e49c726c 100644 --- a/tfx/types/channel_utils_test.py +++ b/tfx/types/channel_utils_test.py @@ -13,7 +13,8 @@ # limitations under the License. """Tests for tfx.utils.channel.""" -import tensorflow as tf +from absl.testing import absltest +from tfx.dsl.components.base.testing import test_node from tfx.dsl.placeholder import placeholder as ph from tfx.types import artifact from tfx.types import channel @@ -25,7 +26,7 @@ class _MyArtifact(artifact.Artifact): TYPE_NAME = 'MyTypeName' -class ChannelUtilsTest(tf.test.TestCase): +class ChannelUtilsTest(absltest.TestCase): def testArtifactCollectionAsChannel(self): instance_a = _MyArtifact() @@ -54,8 +55,16 @@ def testUnwrapChannelDict(self): self.assertDictEqual(result, {'id': [instance_a, instance_b]}) def testGetInidividualChannels(self): - one_channel = channel.Channel(_MyArtifact) - another_channel = channel.Channel(_MyArtifact) + one_channel = channel.OutputChannel( + artifact_type=_MyArtifact, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) + another_channel = channel.OutputChannel( + artifact_type=_MyArtifact, + producer_component=test_node.TestNode('b'), + output_key='bar', + ) result = channel_utils.get_individual_channels(one_channel) self.assertEqual(result, [one_channel]) @@ -65,8 +74,16 @@ def testGetInidividualChannels(self): self.assertEqual(result, [one_channel, another_channel]) def testPredicateDependentChannels(self): - int1 = channel.Channel(type=standard_artifacts.Integer) - int2 = channel.Channel(type=standard_artifacts.Integer) + int1 = channel.OutputChannel( + artifact_type=standard_artifacts.Integer, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) + int2 = channel.OutputChannel( + artifact_type=standard_artifacts.Integer, + producer_component=test_node.TestNode('b'), + output_key='bar', + ) pred1 = int1.future().value == 1 pred2 = int1.future().value == int2.future().value pred3 = ph.logical_not(pred1) @@ -82,7 +99,11 @@ def testPredicateDependentChannels(self): ) def testUnwrapSimpleChannelPlaceholder(self): - int1 = channel.Channel(type=standard_artifacts.Integer) + int1 = channel.OutputChannel( + artifact_type=standard_artifacts.Integer, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) self.assertEqual( channel_utils.unwrap_simple_channel_placeholder(int1.future()[0].value), int1, @@ -93,8 +114,16 @@ def testUnwrapSimpleChannelPlaceholder(self): ) def testUnwrapSimpleChannelPlaceholderRejectsMultiChannel(self): - str1 = channel.Channel(type=standard_artifacts.String) - str2 = channel.Channel(type=standard_artifacts.String) + str1 = channel.OutputChannel( + artifact_type=standard_artifacts.String, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) + str2 = channel.OutputChannel( + artifact_type=standard_artifacts.String, + producer_component=test_node.TestNode('b'), + output_key='bar', + ) with self.assertRaisesRegex(ValueError, '.*placeholder of shape.*'): channel_utils.unwrap_simple_channel_placeholder( str1.future()[0].value + str2.future()[0].value @@ -113,7 +142,11 @@ def testUnwrapSimpleChannelPlaceholderRejectsNoChannel(self): channel_utils.unwrap_simple_channel_placeholder(ph.output('disallowed')) def testUnwrapSimpleChannelPlaceholderRejectsComplexPlaceholders(self): - str1 = channel.Channel(type=standard_artifacts.String) + str1 = channel.OutputChannel( + artifact_type=standard_artifacts.String, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) with self.assertRaisesRegex(ValueError, '.*placeholder of shape.*'): channel_utils.unwrap_simple_channel_placeholder( str1.future()[0].value + 'foo' @@ -125,4 +158,4 @@ def testUnwrapSimpleChannelPlaceholderRejectsComplexPlaceholders(self): if __name__ == '__main__': - tf.test.main() + absltest.main() diff --git a/tfx/types/channel_wrapped_placeholder_test.py b/tfx/types/channel_wrapped_placeholder_test.py index 7ca33c69d54..1ab9933c4e3 100644 --- a/tfx/types/channel_wrapped_placeholder_test.py +++ b/tfx/types/channel_wrapped_placeholder_test.py @@ -18,14 +18,16 @@ from absl.testing import parameterized import tensorflow as tf +from tfx.dsl.components.base.testing import test_node from tfx.dsl.placeholder import placeholder as ph from tfx.proto.orchestration import placeholder_pb2 from tfx.types import channel_utils +from tfx.types import OutputChannel from tfx.types import standard_artifacts from tfx.types.artifact import Artifact from tfx.types.artifact import Property from tfx.types.artifact import PropertyType -from tfx.types.channel import Channel + from google.protobuf import message from google.protobuf import text_format @@ -53,9 +55,13 @@ class _MyType(Artifact): class ChannelWrappedPlaceholderTest(parameterized.TestCase, tf.test.TestCase): def testProtoFutureValueOperator(self): - output_channel = Channel(type=standard_artifacts.Integer) + output_channel = OutputChannel( + artifact_type=standard_artifacts.Integer, + producer_component=test_node.TestNode('producer'), + output_key='num', + ) placeholder = output_channel.future()[0].value - channel_to_key = {output_channel: '_component.num'} + channel_to_key = {output_channel: 'producer_num'} self.assertProtoEquals( channel_utils.encode_placeholder_with_channels( placeholder, lambda k: channel_to_key[k] @@ -66,30 +72,82 @@ def testProtoFutureValueOperator(self): @parameterized.named_parameters( { 'testcase_name': 'two_sides_placeholder', - 'left': Channel(type=_MyType).future().value, - 'right': Channel(type=_MyType).future().value, + 'left': ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('left'), + output_key='l', + ) + .future() + .value + ), + 'right': ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('right'), + output_key='r', + ) + .future() + .value + ), }, { 'testcase_name': 'left_side_placeholder_right_side_string', - 'left': Channel(type=_MyType).future().value, + 'left': ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('left'), + output_key='l', + ) + .future() + .value + ), 'right': '#', }, { 'testcase_name': 'left_side_string_right_side_placeholder', 'left': 'http://', - 'right': Channel(type=_MyType).future().value, + 'right': ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('right'), + output_key='r', + ) + .future() + .value + ), }, ) def testConcat(self, left, right): self.assertIsInstance(left + right, ph.Placeholder) def testJoinWithSelf(self): - left = Channel(type=_MyType).future().value - right = Channel(type=_MyType).future().value + left = ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) + .future() + .value + ) + right = ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) + .future() + .value + ) self.assertIsInstance(ph.join([left, right]), ph.Placeholder) def testEncodeWithKeys(self): - my_channel = Channel(type=_MyType) + my_channel = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) channel_future = my_channel.future()[0].value actual_pb = channel_utils.encode_placeholder_with_channels( channel_future, lambda c: c.type_name @@ -103,7 +161,7 @@ def testEncodeWithKeys(self): index_op { expression { placeholder { - key: "MyTypeName" + key: "producer_foo" } } } @@ -111,7 +169,9 @@ def testEncodeWithKeys(self): } } } - """, placeholder_pb2.PlaceholderExpression()) + """, + placeholder_pb2.PlaceholderExpression(), + ) self.assertProtoEquals(actual_pb, expected_pb) @@ -120,15 +180,39 @@ class PredicateTest(parameterized.TestCase, tf.test.TestCase): @parameterized.named_parameters( { 'testcase_name': 'two_sides_placeholder', - 'left': Channel(type=_MyType).future().value, - 'right': Channel(type=_MyType).future().value, + 'left': ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) + .future() + .value + ), + 'right': ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) + .future() + .value + ), 'expected_op': placeholder_pb2.ComparisonOperator.Operation.LESS_THAN, 'expected_lhs_field': 'operator', 'expected_rhs_field': 'operator', }, { 'testcase_name': 'left_side_placeholder_right_side_int', - 'left': Channel(type=_MyType).future().value, + 'left': ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) + .future() + .value + ), 'right': 1, 'expected_op': placeholder_pb2.ComparisonOperator.Operation.LESS_THAN, 'expected_lhs_field': 'operator', @@ -137,7 +221,15 @@ class PredicateTest(parameterized.TestCase, tf.test.TestCase): }, { 'testcase_name': 'left_side_placeholder_right_side_float', - 'left': Channel(type=_MyType).future().value, + 'left': ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) + .future() + .value + ), 'right': 1.1, 'expected_op': placeholder_pb2.ComparisonOperator.Operation.LESS_THAN, 'expected_lhs_field': 'operator', @@ -146,7 +238,15 @@ class PredicateTest(parameterized.TestCase, tf.test.TestCase): }, { 'testcase_name': 'left_side_placeholder_right_side_string', - 'left': Channel(type=_MyType).future().value, + 'left': ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) + .future() + .value + ), 'right': 'one', 'expected_op': placeholder_pb2.ComparisonOperator.Operation.LESS_THAN, 'expected_lhs_field': 'operator', @@ -154,36 +254,42 @@ class PredicateTest(parameterized.TestCase, tf.test.TestCase): 'expected_rhs_value_type': 'string_value', }, { - 'testcase_name': - 'right_side_placeholder_left_side_int', - 'left': - 1, - 'right': - Channel(type=_MyType).future().value, - 'expected_op': - placeholder_pb2.ComparisonOperator.Operation.GREATER_THAN, - 'expected_lhs_field': - 'operator', - 'expected_rhs_field': - 'value', - 'expected_rhs_value_type': - 'int_value', + 'testcase_name': 'right_side_placeholder_left_side_int', + 'left': 1, + 'right': ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) + .future() + .value + ), + 'expected_op': ( + placeholder_pb2.ComparisonOperator.Operation.GREATER_THAN + ), + 'expected_lhs_field': 'operator', + 'expected_rhs_field': 'value', + 'expected_rhs_value_type': 'int_value', }, { - 'testcase_name': - 'right_side_placeholder_left_side_float', - 'left': - 1.1, - 'right': - Channel(type=_MyType).future().value, - 'expected_op': - placeholder_pb2.ComparisonOperator.Operation.GREATER_THAN, - 'expected_lhs_field': - 'operator', - 'expected_rhs_field': - 'value', - 'expected_rhs_value_type': - 'double_value', + 'testcase_name': 'right_side_placeholder_left_side_float', + 'left': 1.1, + 'right': ( + OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) + .future() + .value + ), + 'expected_op': ( + placeholder_pb2.ComparisonOperator.Operation.GREATER_THAN + ), + 'expected_lhs_field': 'operator', + 'expected_rhs_field': 'value', + 'expected_rhs_value_type': 'double_value', }, ) def testComparison(self, @@ -206,16 +312,32 @@ def testComparison(self, expected_rhs_value_type)) def testEquals(self): - left = Channel(type=_MyType) - right = Channel(type=_MyType) + left = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) + right = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) pred = left.future().value == right.future().value actual_pb = pred.encode() self.assertEqual(actual_pb.operator.compare_op.op, placeholder_pb2.ComparisonOperator.Operation.EQUAL) def testEncode(self): - channel_1 = Channel(type=_MyType) - channel_2 = Channel(type=_MyType) + channel_1 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) + channel_2 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('b'), + output_key='bar', + ) pred = channel_1.future().value > channel_2.future().value actual_pb = pred.encode() expected_pb = text_format.Parse( @@ -229,7 +351,9 @@ def testEncode(self): operator { index_op { expression { - placeholder {} + placeholder { + key: "a_foo" + } } } } @@ -244,7 +368,9 @@ def testEncode(self): operator { index_op { expression { - placeholder {} + placeholder { + key: "b_bar" + } } } } @@ -255,12 +381,22 @@ def testEncode(self): op: GREATER_THAN } } - """, placeholder_pb2.PlaceholderExpression()) + """, + placeholder_pb2.PlaceholderExpression(), + ) self.assertProtoEquals(actual_pb, expected_pb) def testEncodeWithKeys(self): - channel_1 = Channel(type=_MyType) - channel_2 = Channel(type=_MyType) + channel_1 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) + channel_2 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('b'), + output_key='bar', + ) pred = channel_1.future().value > channel_2.future().value channel_to_key_map = { channel_1: 'channel_1_key', @@ -281,7 +417,7 @@ def testEncodeWithKeys(self): index_op { expression { placeholder { - key: "channel_1_key" + key: "a_foo" } } } @@ -298,7 +434,7 @@ def testEncodeWithKeys(self): index_op { expression { placeholder { - key: "channel_2_key" + key: "b_bar" } } } @@ -310,12 +446,22 @@ def testEncodeWithKeys(self): op: GREATER_THAN } } - """, placeholder_pb2.PlaceholderExpression()) + """, + placeholder_pb2.PlaceholderExpression(), + ) self.assertProtoEquals(actual_pb, expected_pb) def testNegation(self): - channel_1 = Channel(type=_MyType) - channel_2 = Channel(type=_MyType) + channel_1 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) + channel_2 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('b'), + output_key='bar', + ) pred = channel_1.future().value < channel_2.future().value not_pred = ph.logical_not(pred) channel_to_key_map = { @@ -340,7 +486,7 @@ def testNegation(self): index_op { expression { placeholder { - key: "channel_1_key" + key: "a_foo" } } } @@ -357,7 +503,7 @@ def testNegation(self): index_op { expression { placeholder { - key: "channel_2_key" + key: "b_bar" } } } @@ -373,13 +519,23 @@ def testNegation(self): op: NOT } } - """, placeholder_pb2.PlaceholderExpression()) + """, + placeholder_pb2.PlaceholderExpression(), + ) self.assertProtoEquals(actual_pb, expected_pb) def testDoubleNegation(self): """Treat `not(not(a))` as `a`.""" - channel_1 = Channel(type=_MyType) - channel_2 = Channel(type=_MyType) + channel_1 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) + channel_2 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('b'), + output_key='bar', + ) pred = channel_1.future().value < channel_2.future().value not_not_pred = ph.logical_not(ph.logical_not(pred)) channel_to_key_map = { @@ -401,7 +557,7 @@ def testDoubleNegation(self): index_op { expression { placeholder { - key: "channel_1_key" + key: "a_foo" } } } @@ -418,7 +574,7 @@ def testDoubleNegation(self): index_op { expression { placeholder { - key: "channel_2_key" + key: "b_bar" } } } @@ -430,13 +586,23 @@ def testDoubleNegation(self): op: LESS_THAN } } - """, placeholder_pb2.PlaceholderExpression()) + """, + placeholder_pb2.PlaceholderExpression(), + ) self.assertProtoEquals(actual_pb, expected_pb) def testComparison_notEqual(self): """Treat `a != b` as `not(a == b)`.""" - channel_1 = Channel(type=_MyType) - channel_2 = Channel(type=_MyType) + channel_1 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) + channel_2 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('b'), + output_key='bar', + ) pred = channel_1.future().value != channel_2.future().value channel_to_key_map = { channel_1: 'channel_1_key', @@ -460,7 +626,7 @@ def testComparison_notEqual(self): index_op { expression { placeholder { - key: "channel_1_key" + key: "a_foo" } } } @@ -477,7 +643,7 @@ def testComparison_notEqual(self): index_op { expression { placeholder { - key: "channel_2_key" + key: "b_bar" } } } @@ -493,13 +659,23 @@ def testComparison_notEqual(self): op: NOT } } - """, placeholder_pb2.PlaceholderExpression()) + """, + placeholder_pb2.PlaceholderExpression(), + ) self.assertProtoEquals(actual_pb, expected_pb) def testComparison_lessThanOrEqual(self): """Treat `a <= b` as `not(a > b)`.""" - channel_1 = Channel(type=_MyType) - channel_2 = Channel(type=_MyType) + channel_1 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) + channel_2 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('b'), + output_key='bar', + ) pred = channel_1.future().value <= channel_2.future().value channel_to_key_map = { channel_1: 'channel_1_key', @@ -523,7 +699,7 @@ def testComparison_lessThanOrEqual(self): index_op { expression { placeholder { - key: "channel_1_key" + key: "a_foo" } } } @@ -540,7 +716,7 @@ def testComparison_lessThanOrEqual(self): index_op { expression { placeholder { - key: "channel_2_key" + key: "b_bar" } } } @@ -556,13 +732,23 @@ def testComparison_lessThanOrEqual(self): op: NOT } } - """, placeholder_pb2.PlaceholderExpression()) + """, + placeholder_pb2.PlaceholderExpression(), + ) self.assertProtoEquals(actual_pb, expected_pb) def testComparison_greaterThanOrEqual(self): """Treat `a >= b` as `not(a < b)`.""" - channel_1 = Channel(type=_MyType) - channel_2 = Channel(type=_MyType) + channel_1 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('a'), + output_key='foo', + ) + channel_2 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('b'), + output_key='bar', + ) pred = channel_1.future().value >= channel_2.future().value channel_to_key_map = { channel_1: 'channel_1_key', @@ -586,7 +772,7 @@ def testComparison_greaterThanOrEqual(self): index_op { expression { placeholder { - key: "channel_1_key" + key: "a_foo" } } } @@ -603,7 +789,7 @@ def testComparison_greaterThanOrEqual(self): index_op { expression { placeholder { - key: "channel_2_key" + key: "b_bar" } } } @@ -619,15 +805,37 @@ def testComparison_greaterThanOrEqual(self): op: NOT } } - """, placeholder_pb2.PlaceholderExpression()) + """, + placeholder_pb2.PlaceholderExpression(), + ) self.assertProtoEquals(actual_pb, expected_pb) def testNestedLogicalOps(self): - channel_11 = Channel(type=_MyType) - channel_12 = Channel(type=_MyType) - channel_21 = Channel(type=_MyType) - channel_22 = Channel(type=_MyType) - channel_3 = Channel(type=_MyType) + channel_11 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('a'), + output_key='1', + ) + channel_12 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('b'), + output_key='2', + ) + channel_21 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('c'), + output_key='3', + ) + channel_22 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('d'), + output_key='4', + ) + channel_3 = OutputChannel( + artifact_type=_MyType, + producer_component=test_node.TestNode('e'), + output_key='5', + ) pred = ph.logical_or( ph.logical_and(channel_11.future().value >= channel_12.future().value, channel_21.future().value < channel_22.future().value), @@ -664,7 +872,7 @@ def testNestedLogicalOps(self): index_op { expression { placeholder { - key: "channel_11_key" + key: "a_1" } } } @@ -681,7 +889,7 @@ def testNestedLogicalOps(self): index_op { expression { placeholder { - key: "channel_12_key" + key: "b_2" } } } @@ -709,7 +917,7 @@ def testNestedLogicalOps(self): index_op { expression { placeholder { - key: "channel_21_key" + key: "c_3" } } } @@ -726,7 +934,7 @@ def testNestedLogicalOps(self): index_op { expression { placeholder { - key: "channel_22_key" + key: "d_4" } } } @@ -757,7 +965,7 @@ def testNestedLogicalOps(self): index_op { expression { placeholder { - key: "channel_3_key" + key: "e_5" } } } @@ -782,7 +990,9 @@ def testNestedLogicalOps(self): op: OR } } - """, placeholder_pb2.PlaceholderExpression()) + """, + placeholder_pb2.PlaceholderExpression(), + ) self.assertProtoEquals(actual_pb, expected_pb) diff --git a/tfx/types/component_spec_test.py b/tfx/types/component_spec_test.py index e58630a5d45..1b9e589f7ae 100644 --- a/tfx/types/component_spec_test.py +++ b/tfx/types/component_spec_test.py @@ -18,7 +18,8 @@ from typing import Dict, List import unittest -import tensorflow as tf +from absl.testing import absltest +from tfx.dsl.components.base.testing import test_node from tfx.dsl.placeholder import placeholder from tfx.proto import example_gen_pb2 from tfx.types import artifact @@ -64,7 +65,7 @@ class _BasicComponentSpec(ComponentSpec): } -class ComponentSpecTest(tf.test.TestCase): +class ComponentSpecTest(absltest.TestCase): # pylint: disable=unused-variable def testComponentSpec_Empty(self): @@ -308,9 +309,6 @@ class _BarArtifact(artifact.Artifact): # Following should pass. channel_parameter.type_check(arg_name, channel.Channel(type=_FooArtifact)) - with self.assertRaisesRegex(TypeError, arg_name): - channel_parameter.type_check(arg_name, 42) # Wrong value. - with self.assertRaisesRegex(TypeError, arg_name): channel_parameter.type_check(arg_name, channel.Channel(type=_BarArtifact)) @@ -361,7 +359,11 @@ def testExecutionParameterTypeCheck(self): with self.assertRaises(json_format.ParseError): proto_parameter.type_check('proto_parameter', {'splits': 42}) - output_channel = channel.Channel(type=_OutputArtifact) + output_channel = channel.OutputChannel( + artifact_type=_OutputArtifact, + producer_component=test_node.TestNode('producer'), + output_key='foo', + ) placeholder_parameter = ExecutionParameter(type=str) placeholder_parameter.type_check( @@ -445,4 +447,4 @@ class SpecWithNonPrimitiveTypes(ComponentSpec): if __name__ == '__main__': - tf.test.main() + absltest.main() diff --git a/tfx/types/testdata/proto_placeholder_future_value_operator.pbtxt b/tfx/types/testdata/proto_placeholder_future_value_operator.pbtxt index 6b260aec6ac..f3dbfaa56af 100644 --- a/tfx/types/testdata/proto_placeholder_future_value_operator.pbtxt +++ b/tfx/types/testdata/proto_placeholder_future_value_operator.pbtxt @@ -8,7 +8,7 @@ operator { index_op { expression { placeholder { - key: "_component.num" + key: "producer_num" } } }