diff --git a/langfun/core/structured/schema.py b/langfun/core/structured/schema.py index ffc55948..c1c70eee 100644 --- a/langfun/core/structured/schema.py +++ b/langfun/core/structured/schema.py @@ -188,13 +188,14 @@ def _add_dependency(cls_or_classes): if isinstance(cls_or_classes, type): cls_or_classes = [cls_or_classes] for cls in cls_or_classes: - if cls not in seen: + if cls not in dependencies: dependencies.append(cls) - seen.add(cls) def _fill_dependencies(vs: pg.typing.ValueSpec, include_subclasses: bool): if isinstance(vs, pg.typing.Object): if issubclass(vs.cls, pg.Object) and vs.cls not in seen: + seen.add(vs.cls) + # Add base classes as dependencies. for base_cls in vs.cls.__bases__: # We only keep track of user-defined symbolic classes. diff --git a/langfun/core/structured/schema_test.py b/langfun/core/structured/schema_test.py index a9a03d92..98a40e79 100644 --- a/langfun/core/structured/schema_test.py +++ b/langfun/core/structured/schema_test.py @@ -32,6 +32,10 @@ class Itinerary(pg.Object): hotel: pg.typing.Str['.*Hotel'] | None +class Node(pg.Object): + children: list['Node'] + + class SchemaTest(unittest.TestCase): def assert_schema(self, annotation, spec): @@ -201,6 +205,12 @@ class B(A): with self.assertRaisesRegex(TypeError, 'Unsupported spec type'): schema_lib.class_dependencies((Foo, 1)) + def test_class_dependencies_recursive(self): + self.assertEqual( + schema_lib.class_dependencies(Node), + [Node] + ) + def test_class_dependencies_from_value(self): class Foo(pg.Object): x: int diff --git a/requirements.txt b/requirements.txt index a37358bd..65b8debb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ jinja2>=3.1.2 openai>=0.18.1 -pyglove>=0.4.3 +pyglove>=0.4.4.dev20230922 termcolor==1.1.0 tqdm>=4.64.1 \ No newline at end of file