From 9475b128424db41b2c3247717df0274a62995a05 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Tue, 19 Sep 2023 18:15:42 -0700 Subject: [PATCH] `lf.structured.class_dependencies` to support cyclic dependencies. PiperOrigin-RevId: 566806521 --- langfun/core/structured/schema.py | 5 +++-- langfun/core/structured/schema_test.py | 10 ++++++++++ requirements.txt | 2 +- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/langfun/core/structured/schema.py b/langfun/core/structured/schema.py index ffc55948..c1c70eee 100644 --- a/langfun/core/structured/schema.py +++ b/langfun/core/structured/schema.py @@ -188,13 +188,14 @@ def _add_dependency(cls_or_classes): if isinstance(cls_or_classes, type): cls_or_classes = [cls_or_classes] for cls in cls_or_classes: - if cls not in seen: + if cls not in dependencies: dependencies.append(cls) - seen.add(cls) def _fill_dependencies(vs: pg.typing.ValueSpec, include_subclasses: bool): if isinstance(vs, pg.typing.Object): if issubclass(vs.cls, pg.Object) and vs.cls not in seen: + seen.add(vs.cls) + # Add base classes as dependencies. for base_cls in vs.cls.__bases__: # We only keep track of user-defined symbolic classes. diff --git a/langfun/core/structured/schema_test.py b/langfun/core/structured/schema_test.py index a9a03d92..98a40e79 100644 --- a/langfun/core/structured/schema_test.py +++ b/langfun/core/structured/schema_test.py @@ -32,6 +32,10 @@ class Itinerary(pg.Object): hotel: pg.typing.Str['.*Hotel'] | None +class Node(pg.Object): + children: list['Node'] + + class SchemaTest(unittest.TestCase): def assert_schema(self, annotation, spec): @@ -201,6 +205,12 @@ class B(A): with self.assertRaisesRegex(TypeError, 'Unsupported spec type'): schema_lib.class_dependencies((Foo, 1)) + def test_class_dependencies_recursive(self): + self.assertEqual( + schema_lib.class_dependencies(Node), + [Node] + ) + def test_class_dependencies_from_value(self): class Foo(pg.Object): x: int diff --git a/requirements.txt b/requirements.txt index e61adf17..c728f8e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ jinja2>=3.1.2 openai>=0.18.1 -pyglove>=0.4.3 +pyglove>=0.4.4.dev20230922 termcolor==1.1.0