Skip to content

Commit

Permalink
Fix issue #10
Browse files Browse the repository at this point in the history
  • Loading branch information
awolverp committed Sep 4, 2024
1 parent e1f04d3 commit e216621
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
20 changes: 19 additions & 1 deletion cachebox/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._cachebox import BaseCacheImpl, FIFOCache
from collections import namedtuple
import functools
import inspect
import typing

Expand Down Expand Up @@ -179,6 +180,8 @@ def make_typed_key(args: tuple, kwds: dict):

CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "length", "cachememory"])

_NOT_SETTED = object()


class _cached_wrapper(typing.Generic[VT]):
def __init__(
Expand All @@ -197,7 +200,10 @@ def __init__(
self.__reuse = clear_reuse
self._hits = 0
self._misses = 0
self.__doc__ = getattr(func, "__doc__", None)

self.instance = _NOT_SETTED

functools.update_wrapper(self, func)

def cache_info(self) -> CacheInfo:
return CacheInfo(
Expand All @@ -212,7 +218,16 @@ def cache_clear(self) -> None:
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: {self.func}>"

if not typing.TYPE_CHECKING:

def __get__(self, instance, *args):
self.instance = instance
return self.__call__

def __call__(self, *args, **kwds) -> VT:
if self.instance is not _NOT_SETTED:
args = (self.instance, *args)

if kwds.pop("cachebox__ignore", False):
return self.func(*args, **kwds)

Expand All @@ -231,6 +246,9 @@ def __call__(self, *args, **kwds) -> VT:

class _async_cached_wrapper(_cached_wrapper[VT]):
async def __call__(self, *args, **kwds) -> VT:
if self.instance is not _NOT_SETTED:
args = (_NOT_SETTED, *args)

if kwds.pop("cachebox__ignore", False):
return await self.func(*args, **kwds)

Expand Down
12 changes: 8 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,13 @@ def test_async_cached():

def test_cachedmethod():
class TestCachedMethod:
def __init__(self, num) -> None:
self.num = num

@cachedmethod(None)
def method(self, a: int, b: str):
return str(a) + b
def method(self, char: str):
assert type(self) is TestCachedMethod
return char * self.num

cls = TestCachedMethod()
assert cls.method(1, "2") == "12"
cls = TestCachedMethod(10)
assert cls.method("a") == ("a" * 10)

0 comments on commit e216621

Please sign in to comment.