Skip to content

Commit

Permalink
Encode producer component id and output key when CWP is created from …
Browse files Browse the repository at this point in the history
…an OutputChannel

PiperOrigin-RevId: 619667393
  • Loading branch information
kmonte authored and tfx-copybara committed Apr 17, 2024
1 parent ef4dd95 commit 81229e1
Show file tree
Hide file tree
Showing 13 changed files with 418 additions and 128 deletions.
9 changes: 8 additions & 1 deletion tfx/dsl/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,17 @@ def testCompileAdditionalCustomPropertyNameConflictError(self):
def testCompileDynamicExecPropTypeError(self):
dsl_compiler = compiler.Compiler()
test_pipeline = dynamic_exec_properties_pipeline.create_test_pipeline()
upstream_component = next(
c
for c in test_pipeline.components
if isinstance(c, dynamic_exec_properties_pipeline.UpstreamComponent)
)
downstream_component = next(
c for c in test_pipeline.components
if isinstance(c, dynamic_exec_properties_pipeline.DownstreamComponent))
test_wrong_type_channel = channel.Channel(_MyType).future().value
test_wrong_type_channel = (
channel.OutputChannel(_MyType, upstream_component, "foo").future().value
)
downstream_component.exec_properties["input_num"] = test_wrong_type_channel
with self.assertRaisesRegex(
ValueError, ".*channel must be of a value artifact type.*"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1886,7 +1886,7 @@ nodes {
index_op {
expression {
placeholder {
key: "blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -2983,7 +2983,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_infra-validator-pipeline.blessing"
key: "infra-validator-pipeline_blessing"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2145,7 +2145,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -3351,7 +3351,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_infra-validator-pipeline.blessing"
key: "infra-validator-pipeline_blessing"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -1264,7 +1264,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -1301,7 +1301,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_InfraValidator.blessing"
key: "InfraValidator_blessing"
}
}
}
Expand Down Expand Up @@ -1333,7 +1333,7 @@ nodes {
index_op {
expression {
placeholder {
key: "model"
key: "Trainer_model"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_UpstreamComponent.num"
key: "UpstreamComponent_num"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_UpstreamComponent.num"
key: "UpstreamComponent_num"
}
}
}
Expand Down
31 changes: 31 additions & 0 deletions tfx/dsl/components/base/testing/test_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module to provide a node for tests."""

from tfx.dsl.components.base import base_node


class TestNode(base_node.BaseNode):
"""Node purely for testing, intentionally empty.
DO NOT USE in real pipelines.
"""

inputs = {}
outputs = {}
exec_properties = {}

def __init__(self, name: str):
super().__init__()
self.with_id(name)
13 changes: 11 additions & 2 deletions tfx/types/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def trigger_by_property(self, *property_keys: str):
return self._with_input_trigger(TriggerByProperty(property_keys))

def future(self) -> ChannelWrappedPlaceholder:
return ChannelWrappedPlaceholder(self)
raise NotImplementedError()

def __eq__(self, other):
return self is other
Expand Down Expand Up @@ -557,6 +557,11 @@ def set_external(self, predefined_artifact_uris: List[str]) -> None:
def set_as_async_channel(self) -> None:
self._is_async = True

def future(self) -> ChannelWrappedPlaceholder:
return ChannelWrappedPlaceholder(
self, f'{self.producer_component_id}_{self.output_key}'
)


@doc_controls.do_not_generate_docs
class UnionChannel(BaseChannel):
Expand Down Expand Up @@ -703,6 +708,9 @@ def trigger_by_property(self, *property_keys: str):
'trigger_by_property is not implemented for PipelineInputChannel.'
)

def future(self) -> ChannelWrappedPlaceholder:
return ChannelWrappedPlaceholder(self, f'{self._output_key}')


class ExternalPipelineChannel(BaseChannel):
"""Channel subtype that is used to get artifacts from external MLMD db."""
Expand Down Expand Up @@ -787,7 +795,8 @@ def set_key(self, key: Optional[str]):
Args:
key: The new key for the channel.
"""
self._key = key
del key # unused.
return

def __getitem__(self, index: int) -> ChannelWrappedPlaceholder:
if self._index is not None:
Expand Down
4 changes: 1 addition & 3 deletions tfx/types/channel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,8 @@ def unwrap_simple_channel_placeholder(
# proto paths above and been getting default messages all along. If this
# sub-message is present, then the whole chain was correct.
not index_op.expression.HasField('placeholder')
# ChannelWrappedPlaceholder uses INPUT_ARTIFACT for some reason, and has
# no key when encoded with encode().
# ChannelWrappedPlaceholder uses INPUT_ARTIFACT for some reason.
or cwp.type != placeholder_pb2.Placeholder.Type.INPUT_ARTIFACT
or cwp.key
# For the `[0]` part of the desired shape.
or index_op.index != 0
):
Expand Down
55 changes: 44 additions & 11 deletions tfx/types/channel_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.
"""Tests for tfx.utils.channel."""

import tensorflow as tf
from absl.testing import absltest
from tfx.dsl.components.base.testing import test_node
from tfx.dsl.placeholder import placeholder as ph
from tfx.types import artifact
from tfx.types import channel
Expand All @@ -25,7 +26,7 @@ class _MyArtifact(artifact.Artifact):
TYPE_NAME = 'MyTypeName'


class ChannelUtilsTest(tf.test.TestCase):
class ChannelUtilsTest(absltest.TestCase):

def testArtifactCollectionAsChannel(self):
instance_a = _MyArtifact()
Expand Down Expand Up @@ -54,8 +55,16 @@ def testUnwrapChannelDict(self):
self.assertDictEqual(result, {'id': [instance_a, instance_b]})

def testGetInidividualChannels(self):
one_channel = channel.Channel(_MyArtifact)
another_channel = channel.Channel(_MyArtifact)
one_channel = channel.OutputChannel(
artifact_type=_MyArtifact,
producer_component=test_node.TestNode('a'),
output_key='foo',
)
another_channel = channel.OutputChannel(
artifact_type=_MyArtifact,
producer_component=test_node.TestNode('b'),
output_key='bar',
)

result = channel_utils.get_individual_channels(one_channel)
self.assertEqual(result, [one_channel])
Expand All @@ -65,8 +74,16 @@ def testGetInidividualChannels(self):
self.assertEqual(result, [one_channel, another_channel])

def testPredicateDependentChannels(self):
int1 = channel.Channel(type=standard_artifacts.Integer)
int2 = channel.Channel(type=standard_artifacts.Integer)
int1 = channel.OutputChannel(
artifact_type=standard_artifacts.Integer,
producer_component=test_node.TestNode('a'),
output_key='foo',
)
int2 = channel.OutputChannel(
artifact_type=standard_artifacts.Integer,
producer_component=test_node.TestNode('b'),
output_key='bar',
)
pred1 = int1.future().value == 1
pred2 = int1.future().value == int2.future().value
pred3 = ph.logical_not(pred1)
Expand All @@ -82,7 +99,11 @@ def testPredicateDependentChannels(self):
)

def testUnwrapSimpleChannelPlaceholder(self):
int1 = channel.Channel(type=standard_artifacts.Integer)
int1 = channel.OutputChannel(
artifact_type=standard_artifacts.Integer,
producer_component=test_node.TestNode('a'),
output_key='foo',
)
self.assertEqual(
channel_utils.unwrap_simple_channel_placeholder(int1.future()[0].value),
int1,
Expand All @@ -93,8 +114,16 @@ def testUnwrapSimpleChannelPlaceholder(self):
)

def testUnwrapSimpleChannelPlaceholderRejectsMultiChannel(self):
str1 = channel.Channel(type=standard_artifacts.String)
str2 = channel.Channel(type=standard_artifacts.String)
str1 = channel.OutputChannel(
artifact_type=standard_artifacts.String,
producer_component=test_node.TestNode('a'),
output_key='foo',
)
str2 = channel.OutputChannel(
artifact_type=standard_artifacts.String,
producer_component=test_node.TestNode('b'),
output_key='bar',
)
with self.assertRaisesRegex(ValueError, '.*placeholder of shape.*'):
channel_utils.unwrap_simple_channel_placeholder(
str1.future()[0].value + str2.future()[0].value
Expand All @@ -113,7 +142,11 @@ def testUnwrapSimpleChannelPlaceholderRejectsNoChannel(self):
channel_utils.unwrap_simple_channel_placeholder(ph.output('disallowed'))

def testUnwrapSimpleChannelPlaceholderRejectsComplexPlaceholders(self):
str1 = channel.Channel(type=standard_artifacts.String)
str1 = channel.OutputChannel(
artifact_type=standard_artifacts.String,
producer_component=test_node.TestNode('a'),
output_key='foo',
)
with self.assertRaisesRegex(ValueError, '.*placeholder of shape.*'):
channel_utils.unwrap_simple_channel_placeholder(
str1.future()[0].value + 'foo'
Expand All @@ -125,4 +158,4 @@ def testUnwrapSimpleChannelPlaceholderRejectsComplexPlaceholders(self):


if __name__ == '__main__':
tf.test.main()
absltest.main()
Loading

0 comments on commit 81229e1

Please sign in to comment.