Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691236408
  • Loading branch information
tensorflower-gardener committed Oct 30, 2024
1 parent eb2c1b3 commit 38f781c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 15 deletions.
33 changes: 18 additions & 15 deletions official/modeling/hyperparams/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
import functools
import inspect
import types
import typing
from typing import Any, List, Mapping, Optional, Type, Union

Expand Down Expand Up @@ -58,8 +59,13 @@ def _wrapper(self, *args, **kwargs): # pylint: disable=unused-argument


def _is_optional(field):
return typing.get_origin(field) is Union and type(None) in typing.get_args(
field)
# Two styles of annotating optional fields:
# Optional[T] -> typing.Union
# T | None -> types.UnionType
is_union = typing.get_origin(field) in (Union, types.UnionType)
# An optional field is a union of a type, and NoneType.
args = typing.get_args(field)
return is_union and len(args) == 2 and type(None) in args


@dataclasses.dataclass
Expand Down Expand Up @@ -189,6 +195,9 @@ def _get_subconfig_type(
if not subconfig_type:
subconfig_type = Config

def _is_subtype(x, target) -> bool:
return isinstance(x, type) and issubclass(x, target)

annotations = cls._get_annotations()
if k in annotations:
# Directly Config subtype.
Expand All @@ -198,22 +207,16 @@ def _get_subconfig_type(
traverse_in = True
while traverse_in:
i += 1
if (isinstance(type_annotation, type) and
issubclass(type_annotation, Config)):
if _is_subtype(type_annotation, Config):
subconfig_type = type_annotation
break
else:
# Check if the field is a sequence of subtypes.
field_type = typing.get_origin(type_annotation)
if (isinstance(field_type, type) and
issubclass(field_type, cls.SEQUENCE_TYPES)):
element_type = typing.get_args(type_annotation)[0]
subconfig_type = (
element_type if issubclass(element_type, params_dict.ParamsDict)
else subconfig_type)
break
elif _is_optional(type_annotation):
# Strip the `Optional` annotation and process the subtype.
# If the field is a sequence of sub-config types or an Optional
# sub-config, then strip the container and process the sub-config.
is_sequence = _is_subtype(
typing.get_origin(type_annotation), cls.SEQUENCE_TYPES
)
if is_sequence or _is_optional(type_annotation):
type_annotation = typing.get_args(type_annotation)[0]
continue
traverse_in = False
Expand Down
20 changes: 20 additions & 0 deletions official/modeling/hyperparams/base_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ class DumpConfig6(base_config.Config):
test_config1: Optional[DumpConfig1] = None


@dataclasses.dataclass
class ModernOptionalConfig(base_config.Config):
leaf: DumpConfig1 | None = None
leaves: tuple[DumpConfig1 | None, ...] = tuple()


class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):

def assertHasSameTypes(self, c, d, msg=''):
Expand Down Expand Up @@ -422,6 +428,20 @@ def test_correctly_display_optional_field(self):
"DumpConfig6(test_config1=DumpConfig1(a=1, b='abc'))")
self.assertIsInstance(c.test_config1, DumpConfig1)

def test_modern_optional_syntax(self):
config = ModernOptionalConfig()
self.assertIsNone(config.leaf)
self.assertEqual(config.leaves, tuple())

replaced = config.replace(leaf={'a': 2}, leaves=({'a': 3}, {'b': 'foo'}))
self.assertEqual(replaced.leaf.a, 2)
self.assertEqual(replaced.leaf.b, 'text')
self.assertLen(replaced.leaves, 2)
self.assertEqual(replaced.leaves[0].a, 3)
self.assertEqual(replaced.leaves[0].b, 'text')
self.assertEqual(replaced.leaves[1].a, 1)
self.assertEqual(replaced.leaves[1].b, 'foo')


if __name__ == '__main__':
tf.test.main()

0 comments on commit 38f781c

Please sign in to comment.