Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle required arguments in nested pydantic models #163

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

Xallt
Copy link

@Xallt Xallt commented Sep 15, 2024

from pydantic import BaseModel, Field
import tyro

class A(BaseModel):
    x: str

class B(BaseModel):
    a: A = Field(default_factory=A)

if __name__ == "__main__":
    tyro.cli(B)

This particular example failed on me with

Traceback (most recent call last):
  File "/home/xallt/progs/test.py", line 11, in <module>
    tyro.cli(B)
  File "/home/xallt/clones/tyro/src/tyro/_cli.py", line 207, in cli
    output = _cli_impl(
  File "/home/xallt/clones/tyro/src/tyro/_cli.py", line 332, in _cli_impl
    if not _fields.is_nested_type(cast(type, f), default_instance_internal):
  File "/home/xallt/clones/tyro/src/tyro/_unsafe_cache.py", line 33, in wrapped_f
    out = f(*args, **kwargs)
  File "/home/xallt/clones/tyro/src/tyro/_fields.py", line 248, in is_nested_type
    _try_field_list_from_callable(typ, default_instance),
  File "/home/xallt/clones/tyro/src/tyro/_fields.py", line 417, in _try_field_list_from_callable
    return _field_list_from_pydantic(cls, default_instance)
  File "/home/xallt/clones/tyro/src/tyro/_fields.py", line 689, in _field_list_from_pydantic
    default, is_default_from_default_instance = _get_pydantic_v2_field_default(
  File "/home/xallt/clones/tyro/src/tyro/_fields.py", line 1186, in _get_pydantic_v2_field_default
    return field.get_default(call_default_factory=True), False
  File "/media/xallt/HardDrive/.virtualenv/python310/lib/python3.10/site-packages/pydantic/fields.py", line 556, in get_default
    return self.default_factory()
  File "/media/xallt/HardDrive/.virtualenv/python310/lib/python3.10/site-packages/pydantic/main.py", line 193, in __init__
    self.__pydantic_validator__.validate_python(data, self_instance=self)
pydantic_core._pydantic_core.ValidationError: 1 validation error for A
x
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.8/v/missing

I dug around in the code and realized that this is because of the way that the default values for pydantic fields are retrieved in https://github.com/brentyi/tyro/blob/main/src/tyro/_fields.py#L1186
By wrapping in a ValidationError try-catch, we can check if the specified field did try to initialize but crashed due to lack of constructors for some fields, therefore it is in fact a "required" field

@Xallt Xallt changed the title Handle default arguments in nested models Handle default arguments in nested pydantic models Sep 15, 2024
@Xallt Xallt changed the title Handle default arguments in nested pydantic models Handle required arguments in nested pydantic models Sep 15, 2024
@brentyi
Copy link
Owner

brentyi commented Sep 15, 2024

Hi @Xallt, thanks for the PR!

To clarify, isn't this use of default_factory= invalid?

class B(BaseModel):
    a: A = Field(default_factory=A)

default_factory= should take an instance of Callable[[], Any], but here A would expect an argument.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants