Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callable key support for RedisCache #348

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
Version 0.12.0
--------------

Unreleased

- ``RedisCache`` now supports callables as keys

Version 0.11.0
--------------

Expand Down
46 changes: 28 additions & 18 deletions src/cachelib/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(
password: _t.Optional[str] = None,
db: int = 0,
default_timeout: int = 300,
key_prefix: _t.Optional[str] = None,
**kwargs: _t.Any
key_prefix: _t.Optional[_t.Union[str, _t.Callable[[], str]]] = None,
**kwargs: _t.Any,
):
BaseCache.__init__(self, default_timeout)
if host is None:
Expand All @@ -57,6 +57,11 @@ def __init__(
self._read_client = self._write_client = host
self.key_prefix = key_prefix or ""

def _get_prefix(self) -> str:
return (
self.key_prefix if isinstance(self.key_prefix, str) else self.key_prefix()
)

def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
"""Normalize timeout by setting it to default of 300 if
not defined (None) or -1 if explicitly set to zero.
Expand All @@ -69,11 +74,13 @@ def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
return timeout

def get(self, key: str) -> _t.Any:
return self.serializer.loads(self._read_client.get(self.key_prefix + key))
return self.serializer.loads(
self._read_client.get(f"{self._get_prefix()}{key}")
)

def get_many(self, *keys: str) -> _t.List[_t.Any]:
if self.key_prefix:
prefixed_keys = [self.key_prefix + key for key in keys]
prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys]
else:
prefixed_keys = list(keys)
return [self.serializer.loads(x) for x in self._read_client.mget(prefixed_keys)]
Expand All @@ -82,20 +89,24 @@ def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.A
timeout = self._normalize_timeout(timeout)
dump = self.serializer.dumps(value)
if timeout == -1:
result = self._write_client.set(name=self.key_prefix + key, value=dump)
result = self._write_client.set(
name=f"{self._get_prefix()}{key}", value=dump
)
else:
result = self._write_client.setex(
name=self.key_prefix + key, value=dump, time=timeout
name=f"{self._get_prefix()}{key}", value=dump, time=timeout
)
return result

def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
timeout = self._normalize_timeout(timeout)
dump = self.serializer.dumps(value)
created = self._write_client.setnx(name=self.key_prefix + key, value=dump)
created = self._write_client.setnx(
name=f"{self._get_prefix()}{key}", value=dump
)
# handle case where timeout is explicitly set to zero
if created and timeout != -1:
self._write_client.expire(name=self.key_prefix + key, time=timeout)
self._write_client.expire(name=f"{self._get_prefix()}{key}", time=timeout)
return created

def set_many(
Expand All @@ -109,41 +120,40 @@ def set_many(
for key, value in mapping.items():
dump = self.serializer.dumps(value)
if timeout == -1:
pipe.set(name=self.key_prefix + key, value=dump)
pipe.set(name=f"{self._get_prefix()}{key}", value=dump)
else:
pipe.setex(name=self.key_prefix + key, value=dump, time=timeout)
pipe.setex(name=f"{self._get_prefix()}{key}", value=dump, time=timeout)
results = pipe.execute()
res = zip(mapping.keys(), results) # noqa: B905
return [k for k, was_set in res if was_set]
return [k for k, was_set in zip(mapping.keys(), results) if was_set]

def delete(self, key: str) -> bool:
return bool(self._write_client.delete(self.key_prefix + key))
return bool(self._write_client.delete(f"{self._get_prefix()}{key}"))

def delete_many(self, *keys: str) -> _t.List[_t.Any]:
if not keys:
return []
if self.key_prefix:
prefixed_keys = [self.key_prefix + key for key in keys]
prefixed_keys = [f"{self._get_prefix()}{key}" for key in keys]
else:
prefixed_keys = [k for k in keys]
self._write_client.delete(*prefixed_keys)
return [k for k in prefixed_keys if not self.has(k)]

def has(self, key: str) -> bool:
return bool(self._read_client.exists(self.key_prefix + key))
return bool(self._read_client.exists(f"{self._get_prefix()}{key}"))

def clear(self) -> bool:
status = 0
if self.key_prefix:
keys = self._read_client.keys(self.key_prefix + "*")
keys = self._read_client.keys(self._get_prefix() + "*")
if keys:
status = self._write_client.delete(*keys)
else:
status = self._write_client.flushdb()
return bool(status)

def inc(self, key: str, delta: int = 1) -> _t.Any:
return self._write_client.incr(name=self.key_prefix + key, amount=delta)
return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=delta)

def dec(self, key: str, delta: int = 1) -> _t.Any:
return self._write_client.incr(name=self.key_prefix + key, amount=-delta)
return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=-delta)
11 changes: 10 additions & 1 deletion tests/test_redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def _factory(self, *args, **kwargs):
request.cls.cache_factory = _factory


def my_callable_key() -> str:
return "bacon"


@pytest.mark.usefixtures("redis_server")
class TestRedisCache(CommonTests, ClearTests, HasTests):
pass
def test_callable_key(self):
cache = self.cache_factory()
assert cache.set(my_callable_key, "sausages")
assert cache.get(my_callable_key) == "sausages"
assert cache.set(lambda: "spam", "sausages")
assert cache.get(lambda: "spam") == "sausages"
Loading