From 10244a06a0f18bfb99362d7d51bbe858dbdf159b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20BASCOULE=CC=80S?= Date: Mon, 8 Jan 2024 15:22:08 +0100 Subject: [PATCH 1/3] Redis(Cache) now (re)supports function as a key_prefix --- src/cachelib/redis.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/cachelib/redis.py b/src/cachelib/redis.py index 8a0b0da8..3a6b093b 100644 --- a/src/cachelib/redis.py +++ b/src/cachelib/redis.py @@ -57,6 +57,11 @@ def __init__( self._read_client = self._write_client = host self.key_prefix = key_prefix or "" + def _get_prefix(self): + return ( + self.key_prefix if isinstance(self.key_prefix, str) else self.key_prefix() + ) + def _normalize_timeout(self, timeout: _t.Optional[int]) -> int: """Normalize timeout by setting it to default of 300 if not defined (None) or -1 if explicitly set to zero. @@ -69,11 +74,11 @@ def _normalize_timeout(self, timeout: _t.Optional[int]) -> int: return timeout def get(self, key: str) -> _t.Any: - return self.serializer.loads(self._read_client.get(self.key_prefix + key)) + return self.serializer.loads(self._read_client.get(self._get_prefix() + key)) def get_many(self, *keys: str) -> _t.List[_t.Any]: if self.key_prefix: - prefixed_keys = [self.key_prefix + key for key in keys] + prefixed_keys = [self._get_prefix() + key for key in keys] else: prefixed_keys = list(keys) return [self.serializer.loads(x) for x in self._read_client.mget(prefixed_keys)] @@ -82,20 +87,20 @@ def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.A timeout = self._normalize_timeout(timeout) dump = self.serializer.dumps(value) if timeout == -1: - result = self._write_client.set(name=self.key_prefix + key, value=dump) + result = self._write_client.set(name=self._get_prefix() + key, value=dump) else: result = self._write_client.setex( - name=self.key_prefix + key, value=dump, time=timeout + name=self._get_prefix() + key, value=dump, time=timeout ) return result def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any: timeout = self._normalize_timeout(timeout) dump = self.serializer.dumps(value) - created = self._write_client.setnx(name=self.key_prefix + key, value=dump) + created = self._write_client.setnx(name=self._get_prefix() + key, value=dump) # handle case where timeout is explicitly set to zero if created and timeout != -1: - self._write_client.expire(name=self.key_prefix + key, time=timeout) + self._write_client.expire(name=self._get_prefix() + key, time=timeout) return created def set_many( @@ -109,33 +114,32 @@ def set_many( for key, value in mapping.items(): dump = self.serializer.dumps(value) if timeout == -1: - pipe.set(name=self.key_prefix + key, value=dump) + pipe.set(name=self._get_prefix() + key, value=dump) else: - pipe.setex(name=self.key_prefix + key, value=dump, time=timeout) + pipe.setex(name=self._get_prefix() + key, value=dump, time=timeout) results = pipe.execute() - res = zip(mapping.keys(), results) # noqa: B905 - return [k for k, was_set in res if was_set] + return [k for k, was_set in zip(mapping.keys(), results) if was_set] def delete(self, key: str) -> bool: - return bool(self._write_client.delete(self.key_prefix + key)) + return bool(self._write_client.delete(self._get_prefix() + key)) def delete_many(self, *keys: str) -> _t.List[_t.Any]: if not keys: return [] if self.key_prefix: - prefixed_keys = [self.key_prefix + key for key in keys] + prefixed_keys = [self._get_prefix() + key for key in keys] else: prefixed_keys = [k for k in keys] self._write_client.delete(*prefixed_keys) return [k for k in prefixed_keys if not self.has(k)] def has(self, key: str) -> bool: - return bool(self._read_client.exists(self.key_prefix + key)) + return bool(self._read_client.exists(self._get_prefix() + key)) def clear(self) -> bool: status = 0 if self.key_prefix: - keys = self._read_client.keys(self.key_prefix + "*") + keys = self._read_client.keys(self._get_prefix() + "*") if keys: status = self._write_client.delete(*keys) else: @@ -143,7 +147,7 @@ def clear(self) -> bool: return bool(status) def inc(self, key: str, delta: int = 1) -> _t.Any: - return self._write_client.incr(name=self.key_prefix + key, amount=delta) + return self._write_client.incr(name=self._get_prefix() + key, amount=delta) def dec(self, key: str, delta: int = 1) -> _t.Any: - return self._write_client.incr(name=self.key_prefix + key, amount=-delta) + return self._write_client.incr(name=self._get_prefix() + key, amount=-delta) From a1bc730574699b14e9a224d6826738bc7c7e0b03 Mon Sep 17 00:00:00 2001 From: northernSage Date: Sat, 10 Feb 2024 13:10:14 -0300 Subject: [PATCH 2/3] add changelog --- CHANGES.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 9eee1eb1..91aa867a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,10 @@ +Version 0.12.0 +-------------- + +Unreleased + +- ``RedisCache`` now supports callables as key prefix + Version 0.11.0 -------------- From c2311f92ba583ec0b95219c686251905fd750ad1 Mon Sep 17 00:00:00 2001 From: northernSage Date: Sat, 10 Feb 2024 13:28:25 -0300 Subject: [PATCH 3/3] fix broken callable key typing + add test --- CHANGES.rst | 2 +- src/cachelib/redis.py | 38 ++++++++++++++++++++++---------------- tests/test_redis_cache.py | 11 ++++++++++- 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 91aa867a..11650e47 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,7 +3,7 @@ Version 0.12.0 Unreleased -- ``RedisCache`` now supports callables as key prefix +- ``RedisCache`` now supports callables as keys Version 0.11.0 -------------- diff --git a/src/cachelib/redis.py b/src/cachelib/redis.py index 3a6b093b..8e38115b 100644 --- a/src/cachelib/redis.py +++ b/src/cachelib/redis.py @@ -37,8 +37,8 @@ def __init__( password: _t.Optional[str] = None, db: int = 0, default_timeout: int = 300, - key_prefix: _t.Optional[str] = None, - **kwargs: _t.Any + key_prefix: _t.Optional[_t.Union[str, _t.Callable[[], str]]] = None, + **kwargs: _t.Any, ): BaseCache.__init__(self, default_timeout) if host is None: @@ -57,7 +57,7 @@ def __init__( self._read_client = self._write_client = host self.key_prefix = key_prefix or "" - def _get_prefix(self): + def _get_prefix(self) -> str: return ( self.key_prefix if isinstance(self.key_prefix, str) else self.key_prefix() ) @@ -74,11 +74,13 @@ def _normalize_timeout(self, timeout: _t.Optional[int]) -> int: return timeout def get(self, key: str) -> _t.Any: - return self.serializer.loads(self._read_client.get(self._get_prefix() + key)) + return self.serializer.loads( + self._read_client.get(f"{self._get_prefix()}{key}") + ) def get_many(self, *keys: str) -> _t.List[_t.Any]: if self.key_prefix: - prefixed_keys = [self._get_prefix() + key for key in keys] + prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys] else: prefixed_keys = list(keys) return [self.serializer.loads(x) for x in self._read_client.mget(prefixed_keys)] @@ -87,20 +89,24 @@ def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.A timeout = self._normalize_timeout(timeout) dump = self.serializer.dumps(value) if timeout == -1: - result = self._write_client.set(name=self._get_prefix() + key, value=dump) + result = self._write_client.set( + name=f"{self._get_prefix()}{key}", value=dump + ) else: result = self._write_client.setex( - name=self._get_prefix() + key, value=dump, time=timeout + name=f"{self._get_prefix()}{key}", value=dump, time=timeout ) return result def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any: timeout = self._normalize_timeout(timeout) dump = self.serializer.dumps(value) - created = self._write_client.setnx(name=self._get_prefix() + key, value=dump) + created = self._write_client.setnx( + name=f"{self._get_prefix()}{key}", value=dump + ) # handle case where timeout is explicitly set to zero if created and timeout != -1: - self._write_client.expire(name=self._get_prefix() + key, time=timeout) + self._write_client.expire(name=f"{self._get_prefix()}{key}", time=timeout) return created def set_many( @@ -114,27 +120,27 @@ def set_many( for key, value in mapping.items(): dump = self.serializer.dumps(value) if timeout == -1: - pipe.set(name=self._get_prefix() + key, value=dump) + pipe.set(name=f"{self._get_prefix()}{key}", value=dump) else: - pipe.setex(name=self._get_prefix() + key, value=dump, time=timeout) + pipe.setex(name=f"{self._get_prefix()}{key}", value=dump, time=timeout) results = pipe.execute() return [k for k, was_set in zip(mapping.keys(), results) if was_set] def delete(self, key: str) -> bool: - return bool(self._write_client.delete(self._get_prefix() + key)) + return bool(self._write_client.delete(f"{self._get_prefix()}{key}")) def delete_many(self, *keys: str) -> _t.List[_t.Any]: if not keys: return [] if self.key_prefix: - prefixed_keys = [self._get_prefix() + key for key in keys] + prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys] else: prefixed_keys = [k for k in keys] self._write_client.delete(*prefixed_keys) return [k for k in prefixed_keys if not self.has(k)] def has(self, key: str) -> bool: - return bool(self._read_client.exists(self._get_prefix() + key)) + return bool(self._read_client.exists(f"{self._get_prefix()}{key}")) def clear(self) -> bool: status = 0 @@ -147,7 +153,7 @@ def clear(self) -> bool: return bool(status) def inc(self, key: str, delta: int = 1) -> _t.Any: - return self._write_client.incr(name=self._get_prefix() + key, amount=delta) + return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=delta) def dec(self, key: str, delta: int = 1) -> _t.Any: - return self._write_client.incr(name=self._get_prefix() + key, amount=-delta) + return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=-delta) diff --git a/tests/test_redis_cache.py b/tests/test_redis_cache.py index e6e10d67..26d178b6 100644 --- a/tests/test_redis_cache.py +++ b/tests/test_redis_cache.py @@ -38,6 +38,15 @@ def _factory(self, *args, **kwargs): request.cls.cache_factory = _factory +def my_callable_key() -> str: + return "bacon" + + @pytest.mark.usefixtures("redis_server") class TestRedisCache(CommonTests, ClearTests, HasTests): - pass + def test_callable_key(self): + cache = self.cache_factory() + assert cache.set(my_callable_key, "sausages") + assert cache.get(my_callable_key) == "sausages" + assert cache.set(lambda: "spam", "sausages") + assert cache.get(lambda: "spam") == "sausages"