From fdfa9cbb466593623766f76b5fb36808976ab643 Mon Sep 17 00:00:00 2001 From: northernSage Date: Sat, 10 Feb 2024 13:28:25 -0300 Subject: [PATCH] fix broken callable key typing + add test --- src/cachelib/redis.py | 38 ++++++++++++++++++++++---------------- tests/test_redis_cache.py | 11 ++++++++++- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/cachelib/redis.py b/src/cachelib/redis.py index 3a6b093b..8e38115b 100644 --- a/src/cachelib/redis.py +++ b/src/cachelib/redis.py @@ -37,8 +37,8 @@ def __init__( password: _t.Optional[str] = None, db: int = 0, default_timeout: int = 300, - key_prefix: _t.Optional[str] = None, - **kwargs: _t.Any + key_prefix: _t.Optional[_t.Union[str, _t.Callable[[], str]]] = None, + **kwargs: _t.Any, ): BaseCache.__init__(self, default_timeout) if host is None: @@ -57,7 +57,7 @@ def __init__( self._read_client = self._write_client = host self.key_prefix = key_prefix or "" - def _get_prefix(self): + def _get_prefix(self) -> str: return ( self.key_prefix if isinstance(self.key_prefix, str) else self.key_prefix() ) @@ -74,11 +74,13 @@ def _normalize_timeout(self, timeout: _t.Optional[int]) -> int: return timeout def get(self, key: str) -> _t.Any: - return self.serializer.loads(self._read_client.get(self._get_prefix() + key)) + return self.serializer.loads( + self._read_client.get(f"{self._get_prefix()}{key}") + ) def get_many(self, *keys: str) -> _t.List[_t.Any]: if self.key_prefix: - prefixed_keys = [self._get_prefix() + key for key in keys] + prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys] else: prefixed_keys = list(keys) return [self.serializer.loads(x) for x in self._read_client.mget(prefixed_keys)] @@ -87,20 +89,24 @@ def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.A timeout = self._normalize_timeout(timeout) dump = self.serializer.dumps(value) if timeout == -1: - result = self._write_client.set(name=self._get_prefix() + key, value=dump) + result = self._write_client.set( + name=f"{self._get_prefix()}{key}", value=dump + ) else: result = self._write_client.setex( - name=self._get_prefix() + key, value=dump, time=timeout + name=f"{self._get_prefix()}{key}", value=dump, time=timeout ) return result def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any: timeout = self._normalize_timeout(timeout) dump = self.serializer.dumps(value) - created = self._write_client.setnx(name=self._get_prefix() + key, value=dump) + created = self._write_client.setnx( + name=f"{self._get_prefix()}{key}", value=dump + ) # handle case where timeout is explicitly set to zero if created and timeout != -1: - self._write_client.expire(name=self._get_prefix() + key, time=timeout) + self._write_client.expire(name=f"{self._get_prefix()}{key}", time=timeout) return created def set_many( @@ -114,27 +120,27 @@ def set_many( for key, value in mapping.items(): dump = self.serializer.dumps(value) if timeout == -1: - pipe.set(name=self._get_prefix() + key, value=dump) + pipe.set(name=f"{self._get_prefix()}{key}", value=dump) else: - pipe.setex(name=self._get_prefix() + key, value=dump, time=timeout) + pipe.setex(name=f"{self._get_prefix()}{key}", value=dump, time=timeout) results = pipe.execute() return [k for k, was_set in zip(mapping.keys(), results) if was_set] def delete(self, key: str) -> bool: - return bool(self._write_client.delete(self._get_prefix() + key)) + return bool(self._write_client.delete(f"{self._get_prefix()}{key}")) def delete_many(self, *keys: str) -> _t.List[_t.Any]: if not keys: return [] if self.key_prefix: - prefixed_keys = [self._get_prefix() + key for key in keys] + prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys] else: prefixed_keys = [k for k in keys] self._write_client.delete(*prefixed_keys) return [k for k in prefixed_keys if not self.has(k)] def has(self, key: str) -> bool: - return bool(self._read_client.exists(self._get_prefix() + key)) + return bool(self._read_client.exists(f"{self._get_prefix()}{key}")) def clear(self) -> bool: status = 0 @@ -147,7 +153,7 @@ def clear(self) -> bool: return bool(status) def inc(self, key: str, delta: int = 1) -> _t.Any: - return self._write_client.incr(name=self._get_prefix() + key, amount=delta) + return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=delta) def dec(self, key: str, delta: int = 1) -> _t.Any: - return self._write_client.incr(name=self._get_prefix() + key, amount=-delta) + return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=-delta) diff --git a/tests/test_redis_cache.py b/tests/test_redis_cache.py index e6e10d67..26d178b6 100644 --- a/tests/test_redis_cache.py +++ b/tests/test_redis_cache.py @@ -38,6 +38,15 @@ def _factory(self, *args, **kwargs): request.cls.cache_factory = _factory +def my_callable_key() -> str: + return "bacon" + + @pytest.mark.usefixtures("redis_server") class TestRedisCache(CommonTests, ClearTests, HasTests): - pass + def test_callable_key(self): + cache = self.cache_factory() + assert cache.set(my_callable_key, "sausages") + assert cache.get(my_callable_key) == "sausages" + assert cache.set(lambda: "spam", "sausages") + assert cache.get(lambda: "spam") == "sausages"