Skip to content

Commit

Permalink
Add support custom provider instance
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanSerafini committed Aug 18, 2023
1 parent e6279dc commit faf655c
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 19 deletions.
3 changes: 2 additions & 1 deletion fast_depends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from fast_depends.dependencies import dependency_provider
from fast_depends.dependencies import dependency_provider, Provider
from fast_depends.use import Depends, inject

__all__ = (
"Depends",
"Provider",
"dependency_provider",
"inject",
)
3 changes: 2 additions & 1 deletion fast_depends/dependencies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fast_depends.dependencies.model import Depends
from fast_depends.dependencies.provider import dependency_provider
from fast_depends.dependencies.provider import dependency_provider, Provider

__all__ = (
"Depends",
"Provider",
"dependency_provider",
)
6 changes: 5 additions & 1 deletion fast_depends/dependencies/provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Protocol


class HasDependencyOverrides(Protocol):
dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]]


class Provider:
Expand Down
41 changes: 26 additions & 15 deletions fast_depends/use.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from fast_depends.core import CallModel, build_call_model
from fast_depends.dependencies import dependency_provider, model
from fast_depends.dependencies.provider import HasDependencyOverrides

P = ParamSpec("P")
T = TypeVar("T")
Expand Down Expand Up @@ -40,7 +41,9 @@ def __call__(
def inject( # pragma: no covers
func: None,
*,
dependency_overrides_provider: Optional[Any] = dependency_provider,
dependency_overrides_provider: Optional[
HasDependencyOverrides
] = dependency_provider,
extra_dependencies: Sequence[model.Depends] = (),
wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> _InjectWrapper[P, T]:
Expand All @@ -51,7 +54,9 @@ def inject( # pragma: no covers
def inject( # pragma: no covers
func: Callable[P, T],
*,
dependency_overrides_provider: Optional[Any] = dependency_provider,
dependency_overrides_provider: Optional[
HasDependencyOverrides
] = dependency_provider,
extra_dependencies: Sequence[model.Depends] = (),
wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> Callable[P, T]:
Expand All @@ -61,7 +66,9 @@ def inject( # pragma: no covers
def inject(
func: Optional[Union[Callable[P, T], Callable[P, Awaitable[T]]]] = None,
*,
dependency_overrides_provider: Optional[Any] = dependency_provider,
dependency_overrides_provider: Optional[
HasDependencyOverrides
] = dependency_provider,
extra_dependencies: Sequence[model.Depends] = (),
wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> Union[Union[Callable[P, T], Callable[P, Awaitable[T]]], _InjectWrapper[P, T],]:
Expand All @@ -78,23 +85,23 @@ def inject(
return decorator(func)


def _resolve_overrides_provider(
dependency_overrides_provider: Optional[HasDependencyOverrides],
) -> Optional[dict[Callable[..., Any], Callable[..., Any]]]:
if not dependency_overrides_provider:
return None

return getattr(dependency_overrides_provider, "dependency_overrides", None)


def _wrap_inject(
dependency_overrides_provider: Optional[Any],
dependency_overrides_provider: Optional[HasDependencyOverrides],
wrap_model: Callable[
[CallModel[P, T]],
CallModel[P, T],
],
extra_dependencies: Sequence[model.Depends],
) -> _InjectWrapper[P, T]:
if (
dependency_overrides_provider
and getattr(dependency_overrides_provider, "dependency_overrides", None)
is not None
):
overrides = dependency_overrides_provider.dependency_overrides
else:
overrides = None

def func_wrapper(
func: Union[Callable[P, T], Callable[P, Awaitable[T]]],
model: Optional[CallModel[P, T]] = None,
Expand All @@ -117,7 +124,9 @@ async def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
r = await real_model.asolve(
*args,
stack=stack,
dependency_overrides=overrides,
dependency_overrides=_resolve_overrides_provider(
dependency_overrides_provider
),
cache_dependencies={},
**kwargs,
)
Expand All @@ -132,7 +141,9 @@ def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
r = real_model.solve(
*args,
stack=stack,
dependency_overrides=overrides,
dependency_overrides=_resolve_overrides_provider(
dependency_overrides_provider
),
cache_dependencies={},
**kwargs,
)
Expand Down
88 changes: 87 additions & 1 deletion tests/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from fast_depends import Depends, dependency_provider, inject
from fast_depends import Depends, dependency_provider, inject, Provider


@pytest.fixture
Expand Down Expand Up @@ -50,6 +50,48 @@ def func(d=Depends(base_dep)):
assert not mock.original.called


def test_sync_override_custom_provider(provider: Provider):
custom_provider = Provider()
mock = Mock()

def base_dep(): # pragma: no cover
mock.original()
return 1

def override_dep():
mock.override()
return 2

custom_provider.override(base_dep, override_dep)

def func(d=Depends(base_dep)):
assert d == 2

func = inject(func, dependency_overrides_provider=custom_provider)
func()

mock.override.assert_called_once()
assert not mock.original.called
assert not provider.dependency_overrides


def test_sync_not_override_custom_provider(provider: Provider):
custom_provider = Provider()
mock = Mock()

def base_dep(): # pragma: no cover
mock.original()
return 1

def func(d=Depends(base_dep)):
assert d == 1

func = inject(func, dependency_overrides_provider=custom_provider)
func()

mock.original.assert_called_once()


def test_sync_by_async_override(provider):
def base_dep(): # pragma: no cover
return 1
Expand Down Expand Up @@ -91,6 +133,50 @@ async def func(d=Depends(base_dep)):
assert not mock.original.called


@pytest.mark.asyncio
async def test_async_override_custom_provider(provider: Provider):
custom_provider = Provider()
mock = Mock()

async def base_dep(): # pragma: no cover
mock.original()
return 1

async def override_dep():
mock.override()
return 2

custom_provider.override(base_dep, override_dep)

async def func(d=Depends(base_dep)):
assert d == 2

func = inject(func, dependency_overrides_provider=custom_provider)
await func()

mock.override.assert_called_once()
assert not mock.original.called
assert not provider.dependency_overrides


@pytest.mark.asyncio
async def test_not_async_override_custom_provider(provider: Provider):
custom_provider = Provider()
mock = Mock()

async def base_dep(): # pragma: no cover
mock.original()
return 1

async def func(d=Depends(base_dep)):
assert d == 1

func = inject(func, dependency_overrides_provider=custom_provider)
await func()

mock.original.assert_called_once()


@pytest.mark.asyncio
async def test_async_by_sync_override(provider):
mock = Mock()
Expand Down

0 comments on commit faf655c

Please sign in to comment.