Skip to content

Commit

Permalink
Fixes for memory (#2)
Browse files Browse the repository at this point in the history
* Fix: call() and __call__() should return _T

* Fix: allow missing argument for memory.cache

* Fix static analyzer
  • Loading branch information
yorickvP authored Sep 18, 2024
1 parent a21340d commit bbeb6ce
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions src/joblib-stubs/memory.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,12 @@ class MemorizedFunc(Logger, Generic[_P, _T]):
def call_and_shelve(
self, *args: _P.args, **kwargs: _P.kwargs
) -> MemorizedResult[_T] | NotMemorizedResult[_T]: ...
def __call__(
self, *args: _P.args, **kwargs: _P.kwargs
) -> MemorizedResult[_T] | NotMemorizedResult[_T]: ...
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...
def check_call_in_cache(self, *args: _P.args, **kwargs: _P.kwargs) -> bool: ...
def clear(self, warn: bool = ...) -> None: ...
def call(
self, *args: _P.args, **kwargs: _P.kwargs
) -> tuple[MemorizedResult[_T] | NotMemorizedResult[_T], dict[str, Any]]: ...
) -> tuple[_T, dict[str, Any]]: ...

class AsyncMemorizedFunc(MemorizedFunc[_P, AnyAwaitable[_T]], Generic[_P, _T]):
func: AnyAwaitableCallable[_P, _T] # pyright: ignore[reportIncompatibleMethodOverride]
Expand All @@ -166,15 +164,13 @@ class AsyncMemorizedFunc(MemorizedFunc[_P, AnyAwaitable[_T]], Generic[_P, _T]):
timestamp: float | None = ...,
cache_validation_callback: Callable[..., Any] | None = ...,
) -> None: ...
async def __call__( # type: ignore[override]
self, *args: _P.args, **kwargs: _P.kwargs
) -> MemorizedResult[_T] | NotMemorizedResult[_T]: ...
async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...
async def call_and_shelve( # type: ignore[override]
self, *args: _P.args, **kwargs: _P.kwargs
) -> MemorizedResult[_T] | NotMemorizedResult[_T]: ...
async def call( # type: ignore[override]
self, *args: _P.args, **kwargs: _P.kwargs
) -> tuple[MemorizedResult[_T] | NotMemorizedResult[_T], dict[str, Any]]: ...
) -> tuple[_T, dict[str, Any]]: ...

class Memory(Logger):
mmap_mode: MmapMode
Expand All @@ -198,7 +194,7 @@ class Memory(Logger):
@overload
def cache(
self,
func: None,
func: None = ...,
ignore: list[str] | None = ...,
verbose: int | None = ...,
mmap_mode: MmapMode | bool = ...,
Expand Down Expand Up @@ -234,11 +230,11 @@ class Memory(Logger):
@overload
def eval(
self, func: AnyAwaitableCallable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> Coroutine[Any, Any, _T | NotMemorizedResult[_T] | MemorizedResult[_T]]: ...
) -> Coroutine[Any, Any, _T]: ...
@overload
def eval(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T | NotMemorizedResult[_T] | MemorizedResult[_T]: ...
) -> _T: ...

def expires_after(
days: int = ...,
Expand Down

0 comments on commit bbeb6ce

Please sign in to comment.