Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Allow to use */** arguments with non-standard names #127

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions fast_depends/core/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def build_call_model(
custom_fields: Dict[str, CustomField] = {}
positional_args: List[str] = []
keyword_args: List[str] = []
var_positional_arg: Optional[str] = None
var_keyword_arg: Optional[str] = None

for param_name, param in typed_params.parameters.items():
dep: Optional[Depends] = None
Expand Down Expand Up @@ -117,10 +119,12 @@ def build_call_model(
annotation = param.annotation

default: Any
if param_name == "args":
if param.kind == inspect.Parameter.VAR_POSITIONAL:
default = ()
elif param_name == "kwargs":
var_positional_arg = param_name
elif param.kind == inspect.Parameter.VAR_KEYWORD:
default = {}
var_keyword_arg = param_name
elif param.default is inspect.Parameter.empty:
default = Ellipsis
else:
Expand Down Expand Up @@ -180,7 +184,7 @@ def build_call_model(
else:
if param.kind is param.KEYWORD_ONLY:
keyword_args.append(param_name)
elif param_name not in ("args", "kwargs"):
elif param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
positional_args.append(param_name)

func_model = create_model( # type: ignore[call-overload]
Expand Down Expand Up @@ -210,6 +214,8 @@ def build_call_model(
custom_fields=custom_fields,
positional_args=positional_args,
keyword_args=keyword_args,
var_positional_arg=var_positional_arg,
var_keyword_arg=var_keyword_arg,
extra_dependencies=[
build_call_model(
d.dependency,
Expand Down
25 changes: 17 additions & 8 deletions fast_depends/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class CallModel(Generic[P, T]):
custom_fields: Dict[str, CustomField]
keyword_args: Tuple[str, ...]
positional_args: Tuple[str, ...]
var_positional_arg: Optional[str]
var_keyword_arg: Optional[str]

# Dependencies and custom fields
use_cache: bool
Expand All @@ -82,6 +84,8 @@ class CallModel(Generic[P, T]):
"alias_arguments",
"keyword_args",
"positional_args",
"var_positional_arg",
"var_keyword_arg",
"dependencies",
"extra_dependencies",
"sorted_dependencies",
Expand Down Expand Up @@ -152,6 +156,8 @@ def __init__(
extra_dependencies: Optional[Iterable["CallModel[..., Any]"]] = None,
keyword_args: Optional[List[str]] = None,
positional_args: Optional[List[str]] = None,
var_positional_arg: Optional[str] = None,
var_keyword_arg: Optional[str] = None,
custom_fields: Optional[Dict[str, CustomField]] = None,
):
self.call = call
Expand All @@ -164,6 +170,8 @@ def __init__(

self.keyword_args = tuple(keyword_args or ())
self.positional_args = tuple(positional_args or ())
self.var_positional_arg = var_positional_arg
self.var_keyword_arg = var_keyword_arg
self.response_model = response_model
self.use_cache = use_cache
self.cast = cast
Expand Down Expand Up @@ -241,8 +249,8 @@ def _solve(
if (v := kwargs.pop(arg, Parameter.empty)) is not Parameter.empty:
kw[arg] = v

if "kwargs" in self.alias_arguments:
kw["kwargs"] = kwargs
if self.var_keyword_arg is not None:
kw[self.var_keyword_arg] = kwargs
else:
kw.update(kwargs)

Expand All @@ -253,8 +261,8 @@ def _solve(
break

keyword_args: Iterable[str]
if has_args := "args" in self.alias_arguments:
kw["args"] = args
if self.var_positional_arg is not None:
kw[self.var_positional_arg] = args
keyword_args = self.keyword_args

else:
Expand All @@ -281,21 +289,22 @@ def _solve(
arg: getattr(casted_model, arg, solved_kw.get(arg))
for arg in keyword_args
}
kwargs_.update(getattr(casted_model, "kwargs", {}))
if self.var_keyword_arg:
kwargs_.update(getattr(casted_model, self.var_keyword_arg, {}))

if has_args:
if self.var_positional_arg is not None:
args_ = [
getattr(casted_model, arg, solved_kw.get(arg))
for arg in self.positional_args
]
args_.extend(getattr(casted_model, "args", ()))
args_.extend(getattr(casted_model, self.var_positional_arg, ()))
else:
args_ = ()

else:
kwargs_ = {arg: solved_kw.get(arg) for arg in keyword_args}

if has_args:
if self.var_positional_arg is not None:
args_ = tuple(map(solved_kw.get, self.positional_args))
else:
args_ = ()
Expand Down
30 changes: 30 additions & 0 deletions tests/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,33 @@ def extra_func(n): ...

assert set(model.params.keys()) == {"a", "b"}
assert set(model.flat_params.keys()) == {"a", "b", "c", "m", "n"}


def test_args_kwargs_params():
def func1(m): ...

def func2(c, b=Depends(func1), d=CustomField()): # noqa: B008
...

def func3(b): ...

def default_var_names(a, *args, b, m=Depends(func2), k=Depends(func3), **kwargs):
return a, args, b, kwargs

def custom_var_names(a, *args_, b, m=Depends(func2), k=Depends(func3), **kwargs_):
return a, args_, b, kwargs_

def extra_func(n): ...

model1 = build_call_model(default_var_names, extra_dependencies=(Depends(extra_func),))

assert set(model1.params.keys()) == {"a", "args", "b", "kwargs"}
assert set(model1.flat_params.keys()) == {"a", "args", "b", "kwargs", "c", "m", "n"}

model2 = build_call_model(custom_var_names, extra_dependencies=(Depends(extra_func),))

assert set(model2.params.keys()) == {"a", "args_", "b", "kwargs_"}
assert set(model2.flat_params.keys()) == {"a", "args_", "b", "kwargs_", "c", "m", "n"}

assert default_var_names(1, *('a'), b=2, **{'kw': 'kw'}) == (1, ('a',), 2, {'kw': 'kw'})
assert custom_var_names(1, *('a'), b=2, **{'kw': 'kw'}) == (1, ('a',), 2, {'kw': 'kw'})