diff --git a/CHANGES.rst b/CHANGES.rst index 9eee1eb1..11650e47 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,10 @@ +Version 0.12.0 +-------------- + +Unreleased + +- ``RedisCache`` now supports callables as keys + Version 0.11.0 -------------- diff --git a/src/cachelib/redis.py b/src/cachelib/redis.py index 8a0b0da8..8e38115b 100644 --- a/src/cachelib/redis.py +++ b/src/cachelib/redis.py @@ -37,8 +37,8 @@ def __init__( password: _t.Optional[str] = None, db: int = 0, default_timeout: int = 300, - key_prefix: _t.Optional[str] = None, - **kwargs: _t.Any + key_prefix: _t.Optional[_t.Union[str, _t.Callable[[], str]]] = None, + **kwargs: _t.Any, ): BaseCache.__init__(self, default_timeout) if host is None: @@ -57,6 +57,11 @@ def __init__( self._read_client = self._write_client = host self.key_prefix = key_prefix or "" + def _get_prefix(self) -> str: + return ( + self.key_prefix if isinstance(self.key_prefix, str) else self.key_prefix() + ) + def _normalize_timeout(self, timeout: _t.Optional[int]) -> int: """Normalize timeout by setting it to default of 300 if not defined (None) or -1 if explicitly set to zero. @@ -69,11 +74,13 @@ def _normalize_timeout(self, timeout: _t.Optional[int]) -> int: return timeout def get(self, key: str) -> _t.Any: - return self.serializer.loads(self._read_client.get(self.key_prefix + key)) + return self.serializer.loads( + self._read_client.get(f"{self._get_prefix()}{key}") + ) def get_many(self, *keys: str) -> _t.List[_t.Any]: if self.key_prefix: - prefixed_keys = [self.key_prefix + key for key in keys] + prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys] else: prefixed_keys = list(keys) return [self.serializer.loads(x) for x in self._read_client.mget(prefixed_keys)] @@ -82,20 +89,24 @@ def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.A timeout = self._normalize_timeout(timeout) dump = self.serializer.dumps(value) if timeout == -1: - result = self._write_client.set(name=self.key_prefix + key, value=dump) + result = self._write_client.set( + name=f"{self._get_prefix()}{key}", value=dump + ) else: result = self._write_client.setex( - name=self.key_prefix + key, value=dump, time=timeout + name=f"{self._get_prefix()}{key}", value=dump, time=timeout ) return result def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any: timeout = self._normalize_timeout(timeout) dump = self.serializer.dumps(value) - created = self._write_client.setnx(name=self.key_prefix + key, value=dump) + created = self._write_client.setnx( + name=f"{self._get_prefix()}{key}", value=dump + ) # handle case where timeout is explicitly set to zero if created and timeout != -1: - self._write_client.expire(name=self.key_prefix + key, time=timeout) + self._write_client.expire(name=f"{self._get_prefix()}{key}", time=timeout) return created def set_many( @@ -109,33 +120,32 @@ def set_many( for key, value in mapping.items(): dump = self.serializer.dumps(value) if timeout == -1: - pipe.set(name=self.key_prefix + key, value=dump) + pipe.set(name=f"{self._get_prefix()}{key}", value=dump) else: - pipe.setex(name=self.key_prefix + key, value=dump, time=timeout) + pipe.setex(name=f"{self._get_prefix()}{key}", value=dump, time=timeout) results = pipe.execute() - res = zip(mapping.keys(), results) # noqa: B905 - return [k for k, was_set in res if was_set] + return [k for k, was_set in zip(mapping.keys(), results) if was_set] def delete(self, key: str) -> bool: - return bool(self._write_client.delete(self.key_prefix + key)) + return bool(self._write_client.delete(f"{self._get_prefix()}{key}")) def delete_many(self, *keys: str) -> _t.List[_t.Any]: if not keys: return [] if self.key_prefix: - prefixed_keys = [self.key_prefix + key for key in keys] + prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys] else: prefixed_keys = [k for k in keys] self._write_client.delete(*prefixed_keys) return [k for k in prefixed_keys if not self.has(k)] def has(self, key: str) -> bool: - return bool(self._read_client.exists(self.key_prefix + key)) + return bool(self._read_client.exists(f"{self._get_prefix()}{key}")) def clear(self) -> bool: status = 0 if self.key_prefix: - keys = self._read_client.keys(self.key_prefix + "*") + keys = self._read_client.keys(self._get_prefix() + "*") if keys: status = self._write_client.delete(*keys) else: @@ -143,7 +153,7 @@ def clear(self) -> bool: return bool(status) def inc(self, key: str, delta: int = 1) -> _t.Any: - return self._write_client.incr(name=self.key_prefix + key, amount=delta) + return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=delta) def dec(self, key: str, delta: int = 1) -> _t.Any: - return self._write_client.incr(name=self.key_prefix + key, amount=-delta) + return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=-delta) diff --git a/tests/test_redis_cache.py b/tests/test_redis_cache.py index e6e10d67..26d178b6 100644 --- a/tests/test_redis_cache.py +++ b/tests/test_redis_cache.py @@ -38,6 +38,15 @@ def _factory(self, *args, **kwargs): request.cls.cache_factory = _factory +def my_callable_key() -> str: + return "bacon" + + @pytest.mark.usefixtures("redis_server") class TestRedisCache(CommonTests, ClearTests, HasTests): - pass + def test_callable_key(self): + cache = self.cache_factory() + assert cache.set(my_callable_key, "sausages") + assert cache.get(my_callable_key) == "sausages" + assert cache.set(lambda: "spam", "sausages") + assert cache.get(lambda: "spam") == "sausages"