Add callable key support for RedisCache (#348)
* Redis(Cache) now (re)supports a function as key_prefix

* add changelog

* fix broken callable key typing + add test

---------

Co-authored-by: David BASCOULÈS <[email protected]>
northernSage and dbascoules authored Feb 11, 2024
1 parent 187dc56 commit bf1f8fd
Showing 3 changed files with 45 additions and 19 deletions.
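
For context (not part of this commit), a minimal usage sketch of the feature being added: assuming a Redis server reachable on localhost:6379 and the redis client package installed, a zero-argument callable can now be passed as key_prefix. The function name make_prefix below is hypothetical.

# Hypothetical usage sketch, not part of the diff: assumes a local Redis
# server on localhost:6379 and the redis client package installed.
from cachelib import RedisCache

def make_prefix() -> str:
    # Invoked by the cache each time a key name is built.
    return "myapp:v2:"

cache = RedisCache(host="localhost", port=6379, key_prefix=make_prefix)
cache.set("greeting", "hello")  # stored under the Redis key "myapp:v2:greeting"
assert cache.get("greeting") == "hello"

A plain string key_prefix keeps working as before; the new _get_prefix() helper simply returns it unchanged when it is not callable.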
7 changes: 7 additions & 0 deletions CHANGES.rst
@@ -1,3 +1,10 @@
+Version 0.12.0
+--------------
+
+Unreleased
+
+- ``RedisCache`` now supports callables as keys
+
 Version 0.11.0
 --------------

46 changes: 28 additions & 18 deletions src/cachelib/redis.py
@@ -37,8 +37,8 @@ def __init__(
         password: _t.Optional[str] = None,
         db: int = 0,
         default_timeout: int = 300,
-        key_prefix: _t.Optional[str] = None,
-        **kwargs: _t.Any
+        key_prefix: _t.Optional[_t.Union[str, _t.Callable[[], str]]] = None,
+        **kwargs: _t.Any,
     ):
         BaseCache.__init__(self, default_timeout)
         if host is None:
@@ -57,6 +57,11 @@ def __init__(
             self._read_client = self._write_client = host
         self.key_prefix = key_prefix or ""

+    def _get_prefix(self) -> str:
+        return (
+            self.key_prefix if isinstance(self.key_prefix, str) else self.key_prefix()
+        )
+
     def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
         """Normalize timeout by setting it to default of 300 if
         not defined (None) or -1 if explicitly set to zero.
@@ -69,11 +74,13 @@ def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
         return timeout

     def get(self, key: str) -> _t.Any:
-        return self.serializer.loads(self._read_client.get(self.key_prefix + key))
+        return self.serializer.loads(
+            self._read_client.get(f"{self._get_prefix()}{key}")
+        )

     def get_many(self, *keys: str) -> _t.List[_t.Any]:
         if self.key_prefix:
-            prefixed_keys = [self.key_prefix + key for key in keys]
+            prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys]
         else:
             prefixed_keys = list(keys)
         return [self.serializer.loads(x) for x in self._read_client.mget(prefixed_keys)]
@@ -82,20 +89,24 @@ def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
         timeout = self._normalize_timeout(timeout)
         dump = self.serializer.dumps(value)
         if timeout == -1:
-            result = self._write_client.set(name=self.key_prefix + key, value=dump)
+            result = self._write_client.set(
+                name=f"{self._get_prefix()}{key}", value=dump
+            )
         else:
             result = self._write_client.setex(
-                name=self.key_prefix + key, value=dump, time=timeout
+                name=f"{self._get_prefix()}{key}", value=dump, time=timeout
             )
         return result

     def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
         timeout = self._normalize_timeout(timeout)
         dump = self.serializer.dumps(value)
-        created = self._write_client.setnx(name=self.key_prefix + key, value=dump)
+        created = self._write_client.setnx(
+            name=f"{self._get_prefix()}{key}", value=dump
+        )
         # handle case where timeout is explicitly set to zero
         if created and timeout != -1:
-            self._write_client.expire(name=self.key_prefix + key, time=timeout)
+            self._write_client.expire(name=f"{self._get_prefix()}{key}", time=timeout)
         return created

     def set_many(
@@ -109,41 +120,40 @@ def set_many(
         for key, value in mapping.items():
             dump = self.serializer.dumps(value)
             if timeout == -1:
-                pipe.set(name=self.key_prefix + key, value=dump)
+                pipe.set(name=f"{self._get_prefix()}{key}", value=dump)
             else:
-                pipe.setex(name=self.key_prefix + key, value=dump, time=timeout)
+                pipe.setex(name=f"{self._get_prefix()}{key}", value=dump, time=timeout)
         results = pipe.execute()
-        res = zip(mapping.keys(), results)  # noqa: B905
-        return [k for k, was_set in res if was_set]
+        return [k for k, was_set in zip(mapping.keys(), results) if was_set]

     def delete(self, key: str) -> bool:
-        return bool(self._write_client.delete(self.key_prefix + key))
+        return bool(self._write_client.delete(f"{self._get_prefix()}{key}"))

     def delete_many(self, *keys: str) -> _t.List[_t.Any]:
         if not keys:
             return []
         if self.key_prefix:
-            prefixed_keys = [self.key_prefix + key for key in keys]
+            prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys]
         else:
             prefixed_keys = [k for k in keys]
         self._write_client.delete(*prefixed_keys)
         return [k for k in prefixed_keys if not self.has(k)]

     def has(self, key: str) -> bool:
-        return bool(self._read_client.exists(self.key_prefix + key))
+        return bool(self._read_client.exists(f"{self._get_prefix()}{key}"))

     def clear(self) -> bool:
         status = 0
         if self.key_prefix:
-            keys = self._read_client.keys(self.key_prefix + "*")
+            keys = self._read_client.keys(self._get_prefix() + "*")
             if keys:
                 status = self._write_client.delete(*keys)
         else:
             status = self._write_client.flushdb()
         return bool(status)

     def inc(self, key: str, delta: int = 1) -> _t.Any:
-        return self._write_client.incr(name=self.key_prefix + key, amount=delta)
+        return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=delta)

     def dec(self, key: str, delta: int = 1) -> _t.Any:
-        return self._write_client.incr(name=self.key_prefix + key, amount=-delta)
+        return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=-delta)
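
One consequence of this change worth noting (the illustration below is not part of the commit): _get_prefix() re-invokes the callable on every cache operation, so the prefix can change between calls. A sketch of a rolling, date-based namespace, with the function name daily_prefix chosen purely for illustration:

# Hypothetical sketch, not part of the diff: the prefix is recomputed on each
# access because RedisCache._get_prefix() calls the callable every time.
from datetime import date
from cachelib import RedisCache

def daily_prefix() -> str:
    # Keys written tomorrow land in a fresh namespace automatically.
    return f"reports:{date.today().isoformat()}:"

cache = RedisCache(host="localhost", key_prefix=daily_prefix)
cache.set("summary", {"total": 42})  # e.g. stored as "reports:2024-02-11:summary"

Note that clear() matches _get_prefix() + "*", so with a rolling prefix it only deletes keys under the currently returned prefix, not those written under earlier values.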
11 changes: 10 additions & 1 deletion tests/test_redis_cache.py
@@ -38,6 +38,15 @@ def _factory(self, *args, **kwargs):
     request.cls.cache_factory = _factory


+def my_callable_key() -> str:
+    return "bacon"
+
+
 @pytest.mark.usefixtures("redis_server")
 class TestRedisCache(CommonTests, ClearTests, HasTests):
-    pass
+    def test_callable_key(self):
+        cache = self.cache_factory()
+        assert cache.set(my_callable_key, "sausages")
+        assert cache.get(my_callable_key) == "sausages"
+        assert cache.set(lambda: "spam", "sausages")
+        assert cache.get(lambda: "spam") == "sausages"
