Skip to content

Commit

Permalink
fix broken callable key typing + add test
Browse files Browse the repository at this point in the history
  • Loading branch information
northernSage committed Feb 11, 2024
1 parent a1bc730 commit fdfa9cb
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 17 deletions.
38 changes: 22 additions & 16 deletions src/cachelib/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(
password: _t.Optional[str] = None,
db: int = 0,
default_timeout: int = 300,
key_prefix: _t.Optional[str] = None,
**kwargs: _t.Any
key_prefix: _t.Optional[_t.Union[str, _t.Callable[[], str]]] = None,
**kwargs: _t.Any,
):
BaseCache.__init__(self, default_timeout)
if host is None:
Expand All @@ -57,7 +57,7 @@ def __init__(
self._read_client = self._write_client = host
self.key_prefix = key_prefix or ""

def _get_prefix(self):
def _get_prefix(self) -> str:
return (
self.key_prefix if isinstance(self.key_prefix, str) else self.key_prefix()
)
Expand All @@ -74,11 +74,13 @@ def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
return timeout

def get(self, key: str) -> _t.Any:
return self.serializer.loads(self._read_client.get(self._get_prefix() + key))
return self.serializer.loads(
self._read_client.get(f"{self._get_prefix()}{key}")
)

def get_many(self, *keys: str) -> _t.List[_t.Any]:
if self.key_prefix:
prefixed_keys = [self._get_prefix() + key for key in keys]
prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys]
else:
prefixed_keys = list(keys)
return [self.serializer.loads(x) for x in self._read_client.mget(prefixed_keys)]
Expand All @@ -87,20 +89,24 @@ def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.A
timeout = self._normalize_timeout(timeout)
dump = self.serializer.dumps(value)
if timeout == -1:
result = self._write_client.set(name=self._get_prefix() + key, value=dump)
result = self._write_client.set(
name=f"{self._get_prefix()}{key}", value=dump
)
else:
result = self._write_client.setex(
name=self._get_prefix() + key, value=dump, time=timeout
name=f"{self._get_prefix()}{key}", value=dump, time=timeout
)
return result

def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
timeout = self._normalize_timeout(timeout)
dump = self.serializer.dumps(value)
created = self._write_client.setnx(name=self._get_prefix() + key, value=dump)
created = self._write_client.setnx(
name=f"{self._get_prefix()}{key}", value=dump
)
# handle case where timeout is explicitly set to zero
if created and timeout != -1:
self._write_client.expire(name=self._get_prefix() + key, time=timeout)
self._write_client.expire(name=f"{self._get_prefix()}{key}", time=timeout)
return created

def set_many(
Expand All @@ -114,27 +120,27 @@ def set_many(
for key, value in mapping.items():
dump = self.serializer.dumps(value)
if timeout == -1:
pipe.set(name=self._get_prefix() + key, value=dump)
pipe.set(name=f"{self._get_prefix()}{key}", value=dump)
else:
pipe.setex(name=self._get_prefix() + key, value=dump, time=timeout)
pipe.setex(name=f"{self._get_prefix()}{key}", value=dump, time=timeout)
results = pipe.execute()
return [k for k, was_set in zip(mapping.keys(), results) if was_set]

def delete(self, key: str) -> bool:
return bool(self._write_client.delete(self._get_prefix() + key))
return bool(self._write_client.delete(f"{self._get_prefix()}{key}"))

def delete_many(self, *keys: str) -> _t.List[_t.Any]:
if not keys:
return []
if self.key_prefix:
prefixed_keys = [self._get_prefix() + key for key in keys]
prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys]
else:
prefixed_keys = [k for k in keys]
self._write_client.delete(*prefixed_keys)
return [k for k in prefixed_keys if not self.has(k)]

def has(self, key: str) -> bool:
return bool(self._read_client.exists(self._get_prefix() + key))
return bool(self._read_client.exists(f"{self._get_prefix()}{key}"))

def clear(self) -> bool:
status = 0
Expand All @@ -147,7 +153,7 @@ def clear(self) -> bool:
return bool(status)

def inc(self, key: str, delta: int = 1) -> _t.Any:
return self._write_client.incr(name=self._get_prefix() + key, amount=delta)
return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=delta)

def dec(self, key: str, delta: int = 1) -> _t.Any:
return self._write_client.incr(name=self._get_prefix() + key, amount=-delta)
return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=-delta)
11 changes: 10 additions & 1 deletion tests/test_redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def _factory(self, *args, **kwargs):
request.cls.cache_factory = _factory


def my_callable_key() -> str:
return "bacon"


@pytest.mark.usefixtures("redis_server")
class TestRedisCache(CommonTests, ClearTests, HasTests):
pass
def test_callable_key(self):
cache = self.cache_factory()
assert cache.set(my_callable_key, "sausages")
assert cache.get(my_callable_key) == "sausages"
assert cache.set(lambda: "spam", "sausages")
assert cache.get(lambda: "spam") == "sausages"

0 comments on commit fdfa9cb

Please sign in to comment.