feat: add invalidate_cache for @cached decorator #927

Open · wants to merge 4 commits into master
16 changes: 14 additions & 2 deletions aiocache/decorators.py
@@ -15,6 +15,10 @@ class cached:
Caches the functions return value into a key generated with module_name, function_name
and args. The cache is available in the function object as ``<function_name>.cache``.

To invalidate a cached value, await the ``invalidate_cache`` method of the function object with
the same args and kwargs that were used to build the cache key:
``await <function_name>.invalidate_cache(*args, **kwargs)``.

In some cases you will need to send more args to configure the cache object.
An example would be endpoint and port for the Redis cache. You can send those args as
kwargs and they will be propagated accordingly.
@@ -77,17 +81,20 @@ def __init__(
self.alias = alias
self.cache = None

self._func = None
self._cache = cache
self._serializer = serializer
self._namespace = namespace
self._plugins = plugins
self._kwargs = kwargs

def __call__(self, f):
self._func = f

if self.alias:
self.cache = caches.get(self.alias)
for arg in ("serializer", "namespace", "plugins"):
if getattr(self, f'_{arg}', None) is not None:
if getattr(self, f"_{arg}", None) is not None:
logger.warning(f"Using cache alias; ignoring {arg!r} argument.")
else:
self.cache = _get_cache(
@@ -103,6 +110,7 @@ async def wrapper(*args, **kwargs):
return await self.decorator(f, *args, **kwargs)

wrapper.cache = self.cache
wrapper.invalidate_cache = self.invalidate_cache
return wrapper

async def decorator(
@@ -157,6 +165,10 @@ async def set_in_cache(self, key, value):
except Exception:
logger.exception("Couldn't set %s in key %s, unexpected error", value, key)

async def invalidate_cache(self, *args, **kwargs):
key = self.get_cache_key(self._func, args, kwargs)
return await self.cache.delete(key)


class cached_stampede(cached):
"""
@@ -330,7 +342,7 @@ def __call__(self, f):
if self.alias:
self.cache = caches.get(self.alias)
for arg in ("serializer", "namespace", "plugins"):
if getattr(self, f'_{arg}', None) is not None:
if getattr(self, f"_{arg}", None) is not None:
logger.warning(f"Using cache alias; ignoring {arg!r} argument.")
else:
self.cache = _get_cache(
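For context, a minimal usage sketch of the API this diff adds, assuming a build of aiocache that includes this branch and the default in-memory cache; `get_profile` and its argument are illustrative names, while `cached`, `ttl`, and `invalidate_cache` come from the diff above:

```python
import asyncio

from aiocache import cached


@cached(ttl=3600)
async def get_profile(user_id: str):
    # Stand-in for an expensive lookup; the result is cached per user_id.
    return {"user_id": user_id}


async def main():
    await get_profile("42")                    # cache miss: runs the function and stores the result
    await get_profile("42")                    # cache hit: served from the cache
    await get_profile.invalidate_cache("42")   # deletes the entry built from the same args
    await get_profile("42")                    # miss again: the entry was removed


asyncio.run(main())
```

Because the key is rebuilt from the call arguments, `invalidate_cache` must be awaited with the same args and kwargs as the original call, which is what the tests below exercise.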
63 changes: 61 additions & 2 deletions tests/ut/test_decorators.py
@@ -233,6 +233,64 @@ async def bar():

assert foo.cache != bar.cache

async def test_invalidate_cache_exists(self):
@cached()
async def foo():
"""Dummy function."""

assert callable(foo.invalidate_cache)

async def test_invalidate_cache(self):
cache_misses = 0

@cached(ttl=60 * 60)
async def foo(return_value: str):
nonlocal cache_misses
cache_misses += 1
return return_value

await foo("hello") # increments cache_misses since it's not cached
assert cache_misses == 1

await foo("hello") # doesn't increment cache_misses since it's cached
assert cache_misses == 1

await foo.invalidate_cache("hello")
await foo("hello") # increments cache_misses since the cache was invalidated
assert cache_misses == 2

await foo("hello") # doesn't increment cache_misses since it's cached
assert cache_misses == 2

async def test_invalidate_cache_diff_args(self):
"""
Tests that the invalidate_cache invalidates the cache for the correct arguments.
"""

cache_misses = 0

@cached(ttl=60 * 60)
async def foo(return_value: str):
nonlocal cache_misses
cache_misses += 1
return return_value

await foo("hello") # increments cache_misses since "hello" is not cached
assert cache_misses == 1

await foo("world") # increments cache_misses since "world" is not cached
assert cache_misses == 2

await foo.invalidate_cache("world")
await foo("hello") # doesn't increment cache_misses since "hello" is still cached
await foo("hello")
await foo("hello")
await foo("hello")
assert cache_misses == 2

await foo("world")
assert cache_misses == 3


class TestCachedStampede:
@pytest.fixture
@@ -476,8 +534,9 @@ async def test_cache_write_doesnt_wait_for_future(self, mocker, decorator, decor
mocker.spy(decorator, "set_in_cache")
with patch.object(decorator, "get_from_cache", autospec=True, return_value=[None, None]):
with patch("aiocache.decorators.asyncio.ensure_future", autospec=True):
await decorator_call(1, keys=["a", "b"], value="value",
aiocache_wait_for_write=False)
await decorator_call(
1, keys=["a", "b"], value="value", aiocache_wait_for_write=False
)

decorator.set_in_cache.assert_not_awaited()
decorator.set_in_cache.assert_called_once_with({"a": ANY, "b": ANY}, stub_dict, ANY, ANY)
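Finally, a standalone sketch of the decorator pattern the implementation relies on; the class, attribute, and helper names here are illustrative, not aiocache's. The decorator object stores the wrapped function in `self._func` and hands its own bound `invalidate_cache` to the wrapper, so invalidation can rebuild exactly the key a cached call would use:

```python
import functools


class invalidating_cached:
    """Toy sketch (not aiocache code) of the pattern in this PR: the decorator
    instance keeps the wrapped function and exposes a bound async invalidation
    method on the wrapper, so both sides derive the cache key the same way."""

    def __init__(self):
        self._func = None
        self._store = {}  # stand-in for the real cache backend

    def __call__(self, f):
        self._func = f

        @functools.wraps(f)
        async def wrapper(*args, **kwargs):
            key = self._key(args, kwargs)
            if key not in self._store:
                self._store[key] = await f(*args, **kwargs)
            return self._store[key]

        # Mirrors `wrapper.invalidate_cache = self.invalidate_cache` in the diff.
        wrapper.invalidate_cache = self.invalidate_cache
        return wrapper

    def _key(self, args, kwargs):
        # Roughly analogous to get_cache_key: module + qualified name + call arguments.
        return (self._func.__module__, self._func.__qualname__, args, frozenset(kwargs.items()))

    async def invalidate_cache(self, *args, **kwargs):
        # Rebuild the same key the wrapper would use and drop the entry.
        return self._store.pop(self._key(args, kwargs), None) is not None
```

Exposing the bound method on the wrapper keeps the public surface on the decorated function itself (`foo.invalidate_cache(...)` next to `foo.cache`), rather than requiring callers to reach into the decorator instance.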