diff --git a/baseplate/clients/redis_cluster.py b/baseplate/clients/redis_cluster.py
new file mode 100644
index 000000000..f68aa7138
--- /dev/null
+++ b/baseplate/clients/redis_cluster.py
@@ -0,0 +1,450 @@
+import logging
+import random
+
+from datetime import timedelta
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import rediscluster
+
+from rediscluster.pipeline import ClusterPipeline
+
+from baseplate import Span
+from baseplate.clients import ContextFactory
+from baseplate.lib import config
+from baseplate.lib import metrics
+
+
+logger = logging.getLogger(__name__)
+randomizer = random.SystemRandom()
+
+
+# Read commands that take a single key as their first parameter
+SINGLE_KEY_READ_COMMANDS = frozenset(
+    [
+        "BITCOUNT",
+        "BITPOS",
+        "GEODIST",
+        "GEOHASH",
+        "GEOPOS",
+        "GEORADIUS",
+        "GEORADIUSBYMEMBER",
+        "GET",
+        "GETBIT",
+        "GETRANGE",
+        "HEXISTS",
+        "HGET",
+        "HGETALL",
+        "HKEYS",
+        "HLEN",
+        "HMGET",
+        "HSTRLEN",
+        "HVALS",
+        "LINDEX",
+        "LLEN",
+        "LRANGE",
+        "PTTL",
+        "SCARD",
+        "SISMEMBER",
+        "SMEMBERS",
+        "SRANDMEMBER",
+        "STRLEN",
+        "TTL",
+        "ZCARD",
+        "ZCOUNT",
+        "ZRANGE",
+        "ZSCORE",
+    ]
+)
+
+# Read commands that take a list of keys as parameters
+MULTI_KEY_READ_COMMANDS = frozenset(["EXISTS", "MGET", "SDIFF", "SINTER", "SUNION"])
+
+# Write commands that take a single key as their first parameter
+SINGLE_KEY_WRITE_COMMANDS = frozenset(
+    [
+        "EXPIRE",
+        "EXPIREAT",
+        "HINCRBY",
+        "HINCRBYFLOAT",
+        "HDEL",
+        "HMSET",
+        "HSET",
+        "HSETNX",
+        "LPUSH",
+        "LREM",
+        "LPOP",
+        "LSET",
+        "LTRIM",
+        "RPOP",
+        "RPUSH",
+        "SADD",
+        "SET",
+        "SETNX",
+        "SPOP",
+        "SREM",
+        "ZADD",
+        "ZINCRBY",
+        "ZPOPMAX",
+        "ZPOPMIN",
+        "ZREM",
+        "ZREMRANGEBYLEX",
+        "ZREMRANGEBYRANK",
+        "ZREMRANGEBYSCORE",
+    ]
+)
+
+# Write commands that take a list of keys as arguments
+MULTI_KEY_WRITE_COMMANDS = frozenset(["DEL"])
+
+# These are a special case of multi-key write commands that take arguments in the form
+# of key value [key value ...]
+MULTI_KEY_BATCH_WRITE_COMMANDS = frozenset(["MSET", "MSETNX"])
+
+
+class HotKeyTracker:
+    """Track key usage to help identify hot keys within Redis.
+
+    Whenever we send a read command to Redis we have a (very low but configurable)
+    chance to increment a counter associated with that key in Redis. Over time this
+    should allow us to find keys that are disproportionately represented by querying
+    the sorted set "baseplate-hot-key-tracker-reads" in Redis. A second sorted set,
+    "baseplate-hot-key-tracker-writes", tracks write frequency in the same way.
+
+    Read and write tracking have separately configurable sample rates, which means
+    we can enable tracking for reads without enabling it for writes, or use different
+    rates for each. This is useful when the number of reads is much higher than the
+    number of writes to a cluster.
+
+    This feature can be turned off by setting the sample rates to zero, and should
+    probably only be enabled if we're actively debugging an issue or looking for a
+    regression.
+
+    The tracker sorted sets have a TTL of 24 hours to ensure that older key counts
+    don't interfere with new debugging sessions. This means that a sorted set and
+    its contents will disappear 24 hours after this feature is disabled and we stop
+    writing to it.
+    """
+
+    def __init__(
+        self,
+        redis_client: rediscluster.RedisCluster,
+        track_reads_sample_rate: float,
+        track_writes_sample_rate: float,
+    ):
+        self.redis_client = redis_client
+        self.track_reads_sample_rate = track_reads_sample_rate
+        self.track_writes_sample_rate = track_writes_sample_rate
+
+        self.reads_sorted_set_name = "baseplate-hot-key-tracker-reads"
+        self.writes_sorted_set_name = "baseplate-hot-key-tracker-writes"
+
+    def should_track_key_reads(self) -> bool:
+        return randomizer.random() < self.track_reads_sample_rate
+
+    def should_track_key_writes(self) -> bool:
+        return randomizer.random() < self.track_writes_sample_rate
+
+    def increment_keys_read_counter(self, key_list: List[str], ignore_errors: bool = True) -> None:
+        self._increment_hot_key_counter(key_list, self.reads_sorted_set_name, ignore_errors)
+
+    def increment_keys_written_counter(
+        self, key_list: List[str], ignore_errors: bool = True
+    ) -> None:
+        self._increment_hot_key_counter(key_list, self.writes_sorted_set_name, ignore_errors)
+
+    def _increment_hot_key_counter(
+        self, key_list: List[str], set_name: str, ignore_errors: bool = True
+    ) -> None:
+        if len(key_list) == 0:
+            return
+
+        try:
+            with self.redis_client.pipeline(set_name) as pipe:
+                for key in key_list:
+                    pipe.zincrby(set_name, 1, key)
+                # Reset the TTL for the sorted set
+                pipe.expire(set_name, timedelta(hours=24))
+                pipe.execute()
+        except Exception as e:
+            # We don't want to disrupt this request even if key tracking fails,
+            # so just log it.
+            logger.exception(e)
+            if not ignore_errors:
+                raise
+
+    def maybe_track_key_usage(self, args: List[str]) -> None:
+        """Probabilistically track usage of the keys in this command.
+
+        If we have enabled key usage tracking *and* this command is within the
+        percentage of commands we want to track, then write it to a sorted set
+        so we can keep track of the most accessed keys.
+        """
+        if len(args) == 0:
+            return
+
+        command = args[0]
+
+        if self.should_track_key_reads():
+            if command in SINGLE_KEY_READ_COMMANDS:
+                self.increment_keys_read_counter([args[1]])
+            elif command in MULTI_KEY_READ_COMMANDS:
+                self.increment_keys_read_counter(args[1:])
+
+        if self.should_track_key_writes():
+            if command in SINGLE_KEY_WRITE_COMMANDS:
+                self.increment_keys_written_counter([args[1]])
+            elif command in MULTI_KEY_WRITE_COMMANDS:
+                self.increment_keys_written_counter(args[1:])
+            elif command in MULTI_KEY_BATCH_WRITE_COMMANDS:
+                # These commands follow key value [key value ...] format
+                self.increment_keys_written_counter(args[1::2])
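The sampling above is easiest to see with both rates pinned to 1 so that every command is tracked. A minimal sketch using `fakeredis` (added to the test requirements in this diff) in place of a real cluster, mirroring the unit tests at the end of the patch:

```python
import fakeredis

from baseplate.clients.redis_cluster import HotKeyTracker

# Sample rates of 1.0 mean every read and every write is tracked.
tracker = HotKeyTracker(fakeredis.FakeStrictRedis(), 1, 1)

tracker.maybe_track_key_usage(["GET", "foo"])
tracker.maybe_track_key_usage(["SET", "foo", "bar"])

# Tracked keys land in sorted sets, scored by access count.
reads = tracker.redis_client.zrangebyscore(
    "baseplate-hot-key-tracker-reads", "-inf", "+inf", withscores=True
)
assert reads == [(b"foo", 1.0)]
```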
+ """ + + def __init__( + self, + redis_client: rediscluster.RedisCluster, + track_reads_sample_rate: float, + track_writes_sample_rate: float, + ): + self.redis_client = redis_client + self.track_reads_sample_rate = track_reads_sample_rate + self.track_writes_sample_rate = track_writes_sample_rate + + self.reads_sorted_set_name = "baseplate-hot-key-tracker-reads" + self.writes_sorted_set_name = "baseplate-hot-key-tracker-writes" + + def should_track_key_reads(self) -> bool: + return randomizer.random() < self.track_reads_sample_rate + + def should_track_key_writes(self) -> bool: + return randomizer.random() < self.track_writes_sample_rate + + def increment_keys_read_counter(self, key_list: List[str], ignore_errors: bool = True) -> None: + self._increment_hot_key_counter(key_list, self.reads_sorted_set_name, ignore_errors) + + def increment_keys_written_counter( + self, key_list: List[str], ignore_errors: bool = True + ) -> None: + self._increment_hot_key_counter(key_list, self.writes_sorted_set_name, ignore_errors) + + def _increment_hot_key_counter( + self, key_list: List[str], set_name: str, ignore_errors: bool = True + ) -> None: + if len(key_list) == 0: + return + + try: + with self.redis_client.pipeline(set_name) as pipe: + for key in key_list: + pipe.zincrby(set_name, 1, key) + # Reset the TTL for the sorted set + pipe.expire(set_name, timedelta(hours=24)) + pipe.execute() + except Exception as e: + # We don't want to disrupt this request even if key tracking fails, so just + # log it. + logger.exception(e) + if not ignore_errors: + raise + + def maybe_track_key_usage(self, args: List[str]) -> None: + """Probabilistically track usage of the keys in this command. + + If we have enabled key usage tracing *and* this command is withing the + percentage of commands we want to track, then write it to a sorted set + so we can keep track of the most accessed keys. + """ + if len(args) == 0: + return + + command = args[0] + + if self.should_track_key_reads(): + if command in SINGLE_KEY_READ_COMMANDS: + self.increment_keys_read_counter([args[1]]) + elif command in MULTI_KEY_READ_COMMANDS: + self.increment_keys_read_counter(args[1:]) + + if self.should_track_key_writes(): + if command in SINGLE_KEY_WRITE_COMMANDS: + self.increment_keys_written_counter([args[1]]) + elif command in MULTI_KEY_WRITE_COMMANDS: + self.increment_keys_written_counter(args[1:]) + elif command in MULTI_KEY_BATCH_WRITE_COMMANDS: + # These commands follow key value [key value...] format + self.increment_keys_written_counter(args[1::2]) + + +# We want to be able to combine blocking behaviour with the ability to read from replicas +# Unfortunately this is not provide as-is so we combine two connection pool classes to provide +# the desired behaviour. +class ClusterWithReadReplicasBlockingConnectionPool(rediscluster.ClusterBlockingConnectionPool): + # pylint: disable=arguments-differ + def get_node_by_slot(self, slot: int, read_command: bool = False) -> Dict[str, Any]: + """Get a node from the slot. + + If the command is a read command we'll try to return a random node. + If there are no replicas or this isn't a read command we'll return the primary. 
+ """ + if read_command: + return random.choice(self.nodes.slots[slot]) + + # This isn't a read command, so return the primary (first node) + return self.nodes.slots[slot][0] + + +def cluster_pool_from_config( + app_config: config.RawConfig, prefix: str = "rediscluster.", **kwargs: Any +) -> rediscluster.ClusterConnectionPool: + """Make a ClusterConnectionPool from a configuration dictionary. + + The keys useful to :py:func:`cluster_pool_from_config` should be prefixed, e.g. + ``rediscluster.url``, ``rediscluster.max_connections``, etc. The ``prefix`` argument + specifies the prefix used to filter keys. Each key is mapped to a + corresponding keyword argument on the :py:class:`rediscluster.ClusterConnectionPool` + constructor. + Supported keys: + * ``url`` (required): a URL like ``redis://localhost/0``. + * ``max_connections``: an integer maximum number of connections in the pool + * ``max_connections_per_node``: Boolean, whether max_connections should be applied + globally (False) or per node (True). + * ``skip_full_coverage_check``: Skips the check of cluster-require-full-coverage + config, useful for clusters without the CONFIG command (like aws) + * ``nodemanager_follow_cluster``: Tell the node manager to reuse the last set of + nodes it was operating on when intializing. + * ``read_from_replicas``: (Boolean) Whether the client should send all read queries to + replicas instead of just the primary + * ``timeout``: how long to wait for sockets to connect. e.g. + ``200 milliseconds`` (:py:func:`~baseplate.lib.config.Timespan`) + * ``track_key_reads_sample_rate``: If greater than zero, which percentage of requests will + be inspected to keep track of hot key usage within Redis when reading. + Every command inspected will result in a write to a sorted set + (baseplate-hot-key-tracker-reads) for tracking. + * ``track_key_writes_sample_rate``: If greater than zero, which percentage of requests will + be inspected to keep track of hot key usage within Redis when writing. + Every command inspected will result in a write to a sorted set + (baseplate-hot-key-tracker-reads) for tracking. 
+ + """ + assert prefix.endswith(".") + + parser = config.SpecParser( + { + "url": config.String, + "max_connections": config.Optional(config.Integer, default=50), + "max_connections_per_node": config.Optional(config.Boolean, default=False), + "timeout": config.Optional(config.Timespan, default=None), + "read_from_replicas": config.Optional(config.Boolean, default=True), + "skip_full_coverage_check": config.Optional(config.Boolean, default=True), + "nodemanager_follow_cluster": config.Optional(config.Boolean, default=None), + "decode_responses": config.Optional(config.Boolean, default=True), + "track_key_reads_sample_rate": config.Optional(config.Float, default=0), + "track_key_writes_sample_rate": config.Optional(config.Float, default=0), + } + ) + + options = parser.parse(prefix[:-1], app_config) + + # We're explicitly setting a default here because of https://github.com/Grokzen/redis-py-cluster/issues/435 + kwargs.setdefault("max_connections", options.max_connections) + + kwargs.setdefault("decode_responses", options.decode_responses) + + if options.nodemanager_follow_cluster is not None: + kwargs.setdefault("nodemanager_follow_cluster", options.nodemanager_follow_cluster) + if options.skip_full_coverage_check is not None: + kwargs.setdefault("skip_full_coverage_check", options.skip_full_coverage_check) + if options.timeout is not None: + kwargs.setdefault("timeout", options.timeout.total_seconds()) + + if options.read_from_replicas: + connection_pool = ClusterWithReadReplicasBlockingConnectionPool.from_url( + options.url, **kwargs + ) + else: + connection_pool = rediscluster.ClusterBlockingConnectionPool.from_url(options.url, **kwargs) + + connection_pool.track_key_reads_sample_rate = options.track_key_reads_sample_rate + connection_pool.track_key_writes_sample_rate = options.track_key_writes_sample_rate + + connection_pool.read_from_replicas = options.read_from_replicas + connection_pool.skip_full_coverage_check = options.skip_full_coverage_check + + return connection_pool + + +class ClusterRedisClient(config.Parser): + """Configure a clustered Redis client. + + This is meant to be used with + :py:meth:`baseplate.Baseplate.configure_context`. + See :py:func:`cluster_pool_from_config` for available configuration settings. + """ + + def __init__(self, **kwargs: Any): + self.kwargs = kwargs + + def parse(self, key_path: str, raw_config: config.RawConfig) -> "ClusterRedisContextFactory": + connection_pool = cluster_pool_from_config(raw_config, f"{key_path}.", **self.kwargs) + return ClusterRedisContextFactory(connection_pool) + + +class ClusterRedisContextFactory(ContextFactory): + """Cluster Redis client context factory. + + This factory will attach a + :py:class:`~baseplate.clients.redis.MonitoredClusterRedisConnection` to an + attribute on the :py:class:`~baseplate.RequestContext`. When Redis commands + are executed via this connection object, they will use connections from the + provided :py:class:`rediscluster.ClusterConnectionPool` and automatically record + diagnostic information. + :param connection_pool: A connection pool. 
+ """ + + def __init__(self, connection_pool: rediscluster.ClusterConnectionPool): + self.connection_pool = connection_pool + + def report_runtime_metrics(self, batch: metrics.Client) -> None: + if not isinstance(self.connection_pool, rediscluster.ClusterBlockingConnectionPool): + return + + size = self.connection_pool.max_connections + open_connections = len(self.connection_pool._connections) + + batch.gauge("pool.size").replace(size) + batch.gauge("pool.open_connections").replace(open_connections) + + def make_object_for_context(self, name: str, span: Span) -> "MonitoredClusterRedisConnection": + return MonitoredClusterRedisConnection( + name, + span, + self.connection_pool, + getattr(self.connection_pool, "track_key_reads_sample_rate", 0), + getattr(self.connection_pool, "track_key_writes_sample_rate", 0), + ) + + +class MonitoredClusterRedisConnection(rediscluster.RedisCluster): + """Cluster Redis connection that collects diagnostic information. + + This connection acts like :py:class:`rediscluster.Redis` except that all + operations are automatically wrapped with diagnostic collection. + The interface is the same as that class except for the + :py:meth:`~baseplate.clients.redis.MonitoredClusterRedisConnection.pipeline` + method. + """ + + def __init__( + self, + context_name: str, + server_span: Span, + connection_pool: rediscluster.ClusterConnectionPool, + track_key_reads_sample_rate: float = 0, + track_key_writes_sample_rate: float = 0, + ): + self.context_name = context_name + self.server_span = server_span + self.track_key_reads_sample_rate = track_key_reads_sample_rate + self.track_key_writes_sample_rate = track_key_writes_sample_rate + self.hot_key_tracker = HotKeyTracker( + self, self.track_key_reads_sample_rate, self.track_key_writes_sample_rate + ) + + super().__init__( + connection_pool=connection_pool, + read_from_replicas=connection_pool.read_from_replicas, + skip_full_coverage_check=connection_pool.skip_full_coverage_check, + ) + + def execute_command(self, *args: Any, **kwargs: Any) -> Any: + command = args[0] + trace_name = f"{self.context_name}.{command}" + + with self.server_span.make_child(trace_name): + res = super().execute_command(command, *args[1:], **kwargs) + + self.hot_key_tracker.maybe_track_key_usage(list(args)) + + return res + + # pylint: disable=arguments-differ + def pipeline(self, name: str) -> "MonitoredClusterRedisPipeline": + """Create a pipeline. + + This returns an object on which you can call the standard Redis + commands. Execution will be deferred until ``execute`` is called. This + is useful for saving round trips even in a clustered environment . + :param name: The name to attach to diagnostics for this pipeline. 
+ """ + return MonitoredClusterRedisPipeline( + f"{self.context_name}.pipeline_{name}", + self.server_span, + self.connection_pool, + self.response_callbacks, + read_from_replicas=self.read_from_replicas, + hot_key_tracker=self.hot_key_tracker, + ) + + # No transaction support in redis-py-cluster + def transaction(self, *args: Any, **kwargs: Any) -> Any: + """Not currently implemented.""" + raise NotImplementedError + + +# pylint: disable=abstract-method +class MonitoredClusterRedisPipeline(ClusterPipeline): + def __init__( + self, + trace_name: str, + server_span: Span, + connection_pool: rediscluster.ClusterConnectionPool, + response_callbacks: Dict, + hot_key_tracker: Optional[HotKeyTracker], + **kwargs: Any, + ): + self.trace_name = trace_name + self.server_span = server_span + self.hot_key_tracker = hot_key_tracker + super().__init__(connection_pool, response_callbacks, **kwargs) + + def execute_command(self, *args: Any, **kwargs: Any) -> Any: + res = super().execute_command(*args, **kwargs) + + if self.hot_key_tracker is not None: + self.hot_key_tracker.maybe_track_key_usage(list(args)) + + return res + + # pylint: disable=arguments-differ + def execute(self, **kwargs: Any) -> Any: + with self.server_span.make_child(self.trace_name): + return super().execute(**kwargs) diff --git a/docker-compose.yml b/docker-compose.yml index 96110f4a2..bf32bb41d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,7 +13,7 @@ services: - "memcached" - "redis" - "zookeeper" - + - "redis-cluster-node" cassandra: image: "cassandra:3.11" environment: @@ -25,3 +25,5 @@ services: image: "redis:4.0.9" zookeeper: image: "zookeeper:3.4.10" + redis-cluster-node: + image: docker.io/grokzen/redis-cluster:6.2.0 diff --git a/docs/api/baseplate/clients/index.rst b/docs/api/baseplate/clients/index.rst index 13f010aba..49449989c 100644 --- a/docs/api/baseplate/clients/index.rst +++ b/docs/api/baseplate/clients/index.rst @@ -13,6 +13,7 @@ Instrumented Client Libraries baseplate.clients.kombu: Client for publishing to queues baseplate.clients.memcache: Memcached Client baseplate.clients.redis: Redis Client + baseplate.clients.redis_cluster: Redis Cluster Client baseplate.clients.requests: Requests (HTTP) Client baseplate.clients.sqlalchemy: SQL Client for relational databases (e.g. PostgreSQL) baseplate.clients.thrift: Thrift client for RPC to other backend services diff --git a/docs/api/baseplate/clients/redis_cluster.rst b/docs/api/baseplate/clients/redis_cluster.rst new file mode 100644 index 000000000..9518f86a9 --- /dev/null +++ b/docs/api/baseplate/clients/redis_cluster.rst @@ -0,0 +1,145 @@ +``baseplate.clients.redis_cluster`` +=================================== + +`Redis`_ is an in-memory data structure store used where speed is necessary but +complexity is beyond simple key-value operations. (If you're just doing +caching, prefer :doc:`memcached `). `Redis-py-cluster`_ is a Python +client library that supports interacting with Redis when operating in cluster mode. + +.. _`Redis`: https://redis.io/ +.. _`redis-py-cluster`: https://github.com/Grokzen/redis-py + +.. automodule:: baseplate.clients.redis_cluster + +.. versionadded:: 2.1 + +Example +------- + +To integrate redis-py-cluster with your application, add the appropriate client +declaration to your context configuration:: + + baseplate.configure_context( + app_config, + { + ... + "foo": ClusterRedisClient(), + ... + } + ) + +configure it in your application's configuration file: + +.. code-block:: ini + + [app:main] + + ... 
+
+   # required: what redis instance to connect to
+   foo.url = redis://localhost:6379/0
+
+   # optional: the maximum size of the connection pool
+   foo.max_connections = 99
+
+   # optional: how long to wait for a connection to establish
+   foo.timeout = 3 seconds
+
+   # optional: whether read requests may be directed to replicas
+   # instead of just the primary
+   foo.read_from_replicas = true
+   ...
+
+and then use the attached :py:class:`~redis.Redis`-like object in
+requests::
+
+   def my_method(request):
+       request.foo.ping()
+
+Configuration
+-------------
+
+.. autoclass:: ClusterRedisClient
+
+.. autofunction:: cluster_pool_from_config
+
+Classes
+-------
+
+.. autoclass:: ClusterRedisContextFactory
+   :members:
+
+.. autoclass:: MonitoredClusterRedisConnection
+   :members:
+
+Runtime Metrics
+---------------
+
+In addition to request-level metrics reported through spans, this wrapper
+reports connection pool statistics periodically via the :ref:`runtime-metrics`
+system. All metrics are tagged with ``client``, the name given to
+:py:meth:`~baseplate.Baseplate.configure_context` when registering this context
+factory.
+
+The following metrics are reported:
+
+``runtime.pool.size``
+    The size limit for the connection pool.
+``runtime.pool.open_connections``
+    How many connections have currently been established.
+
+
+Hot Key Tracking
+----------------
+
+Optionally, the client can help track key usage across the Redis cluster to
+help you identify "hot" keys (keys that are read from or written to much more
+frequently than other keys). This is particularly useful in clusters with
+``noeviction`` set as the eviction policy, since Redis lacks a built-in
+mechanism to help you track hot keys in this case.
+
+Since tracking every single key used is expensive, the tracker works by
+sampling a small percentage of reads and/or writes, which can be configured
+on your client:
+
+.. code-block:: ini
+
+   [app:main]
+
+   ...
+   # Note that by default the sample rate will be zero for both reads and writes
+
+   # optional: Sample keys for 1% of read operations
+   foo.track_key_reads_sample_rate = 0.01
+
+   # optional: Sample keys for 10% of write operations
+   foo.track_key_writes_sample_rate = 0.1
+
+   ...
+
+The keys tracked will be written to a sorted set in the Redis cluster itself,
+which we can query at any time to see what keys are read from or written to
+more often than others. Keys used for writes will be stored in
+``baseplate-hot-key-tracker-writes`` and keys used for reads will be stored in
+``baseplate-hot-key-tracker-reads``. Here's an example of how you can query the
+top 10 keys in each category with their associated scores:
+
+.. code-block:: console
+
+   > ZREVRANGEBYSCORE baseplate-hot-key-tracker-reads +inf 0 WITHSCORES LIMIT 0 10
+
+   > ZREVRANGEBYSCORE baseplate-hot-key-tracker-writes +inf 0 WITHSCORES LIMIT 0 10
+
+Note that due to how the sampling works the scores are only meaningful in a
+relative sense (by comparing one key's access frequency to others in the list);
+they can't be used to infer absolute key access rates.
+
+Both tracker sets have a default TTL of 24 hours, so once they stop being
+written to (for instance, if key tracking is disabled) they will clean up
+after themselves in 24 hours, allowing us to start fresh the next time we
+want to enable key tracking.
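The same queries can be issued from Python through the attached client; a sketch, assuming `request.foo` is the client configured in the docs above (`zrevrangebyscore` is the standard redis-py sorted-set call):

```python
top_reads = request.foo.zrevrangebyscore(
    "baseplate-hot-key-tracker-reads", "+inf", 0, start=0, num=10, withscores=True
)
top_writes = request.foo.zrevrangebyscore(
    "baseplate-hot-key-tracker-writes", "+inf", 0, start=0, num=10, withscores=True
)
```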
diff --git a/requirements-transitive.txt b/requirements-transitive.txt
index a354f149c..0bb5c3239 100644
--- a/requirements-transitive.txt
+++ b/requirements-transitive.txt
@@ -56,6 +56,7 @@ pyramid==1.10.5
 python-json-logger==2.0.1
 reddit-cqlmapper==0.3.0
 redis==3.5.3
+redis-py-cluster==2.1.2
 regex==2020.11.13
 requests==2.25.1
 sentry-sdk==0.20.1
diff --git a/requirements.txt b/requirements.txt
index 972763bba..175e7e0e2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,3 +15,4 @@ sphinx-autodoc-typehints==1.11.1
 sphinxcontrib-spelling==7.1.0
 webtest==2.0.35
 wheel==0.36.2
+fakeredis==1.5.0
diff --git a/setup.cfg b/setup.cfg
index b4193df9a..2d9a3aab3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -102,6 +102,9 @@ ignore_missing_imports = True
 [mypy-pythonjsonlogger.*]
 ignore_missing_imports = True
 
+[mypy-rediscluster.*]
+ignore_missing_imports = True
+
 [mypy-sqlalchemy.*]
 ignore_missing_imports = True
diff --git a/setup.py b/setup.py
index 922c5d5c0..61767ff5c 100644
--- a/setup.py
+++ b/setup.py
@@ -12,6 +12,7 @@
     "memcache": ["pymemcache>=1.3.0,<1.4.4"],
     "pyramid": ["pyramid>=1.9.0,<2.0"],
     "redis": ["redis>=2.10.0,<4.0.0"],
+    "redis-py-cluster": ["redis-py-cluster>=2.1.2,<3.0.0"],
     "refcycle": ["objgraph>=3.0,<4.0"],
     "requests": ["advocate>=1.0.0,<2.0"],
     "sentry": ["sentry-sdk>=0.19,<1.0"],
diff --git a/tests/integration/redis_cluster_tests.py b/tests/integration/redis_cluster_tests.py
new file mode 100644
index 000000000..8d450f876
--- /dev/null
+++ b/tests/integration/redis_cluster_tests.py
@@ -0,0 +1,158 @@
+import unittest
+
+try:
+    import rediscluster
+except ImportError:
+    raise unittest.SkipTest("redis-py-cluster is not installed")
+
+from baseplate import Baseplate
+from baseplate.clients.redis_cluster import cluster_pool_from_config
+from baseplate.clients.redis_cluster import ClusterRedisClient
+from baseplate.lib.config import ConfigurationError
+
+from . import TestBaseplateObserver, get_endpoint_or_skip_container
+
+redis_endpoint = get_endpoint_or_skip_container("redis-cluster-node", 7000)
+
+
+# This belongs in the unit tests section, but the client class attempts to
+# initialise the list of nodes when being instantiated, so it's simpler to test
+# here with a Redis cluster available.
+class ClusterPoolFromConfigTests(unittest.TestCase):
+    def test_empty_config(self):
+        with self.assertRaises(ConfigurationError):
+            cluster_pool_from_config({})
+
+    def test_basic_url(self):
+        pool = cluster_pool_from_config({"rediscluster.url": f"redis://{redis_endpoint}/0"})
+
+        self.assertEqual(pool.nodes.startup_nodes[0]["host"], "redis-cluster-node")
+        self.assertEqual(pool.nodes.startup_nodes[0]["port"], "7000")
+
+    def test_timeouts(self):
+        pool = cluster_pool_from_config(
+            {
+                "rediscluster.url": f"redis://{redis_endpoint}/0",
+                "rediscluster.timeout": "30 seconds",
+            }
+        )
+
+        self.assertEqual(pool.timeout, 30)
+
+    def test_max_connections(self):
+        pool = cluster_pool_from_config(
+            {
+                "rediscluster.url": f"redis://{redis_endpoint}/0",
+                "rediscluster.max_connections": "300",
+            }
+        )
+
+        self.assertEqual(pool.max_connections, 300)
+
+    def test_max_connections_default(self):
+        # https://github.com/Grokzen/redis-py-cluster/issues/435
+        pool = cluster_pool_from_config({"rediscluster.url": f"redis://{redis_endpoint}/0"})
+
+        self.assertEqual(pool.max_connections, 50)
+
+    def test_kwargs_passthrough(self):
+        pool = cluster_pool_from_config(
+            {"rediscluster.url": f"redis://{redis_endpoint}/0"}, example="present"
+        )
+
+        self.assertEqual(pool.connection_kwargs["example"], "present")
+
+    def test_alternate_prefix(self):
+        pool = cluster_pool_from_config(
+            {"noodle.url": f"redis://{redis_endpoint}/0"}, prefix="noodle."
+        )
+
+        self.assertEqual(pool.nodes.startup_nodes[0]["host"], "redis-cluster-node")
+        self.assertEqual(pool.nodes.startup_nodes[0]["port"], "7000")
+
+    def test_only_primary_available(self):
+        pool = cluster_pool_from_config({"rediscluster.url": f"redis://{redis_endpoint}/0"})
+        node_list = [pool.get_node_by_slot(slot=1, read_command=False) for _ in range(0, 100)]
+
+        # The primary is on port 7000 so that's the only port we expect to see
+        self.assertTrue(all(node["port"] == 7000 for node in node_list))
+
+    def test_read_from_replicas(self):
+        pool = cluster_pool_from_config({"rediscluster.url": f"redis://{redis_endpoint}/0"})
+
+        node_list = [pool.get_node_by_slot(slot=1, read_command=True) for _ in range(0, 100)]
+
+        # Both replicas and primary are available, so we expect to see some non-primaries here
+        self.assertTrue(any(node["port"] != 7000 for node in node_list))
+
+
+class RedisClusterIntegrationTests(unittest.TestCase):
+    def setUp(self):
+        self.baseplate_observer = TestBaseplateObserver()
+
+        baseplate = Baseplate(
+            {
+                "rediscluster.url": f"redis://{redis_endpoint}/0",
+                "rediscluster.timeout": "1 second",
+                "rediscluster.max_connections": "4",
+            }
+        )
+        baseplate.register(self.baseplate_observer)
+        baseplate.configure_context({"rediscluster": ClusterRedisClient()})
+
+        self.context = baseplate.make_context_object()
+        self.server_span = baseplate.make_server_span(self.context, "test")
+
+    def test_simple_command(self):
+        with self.server_span:
+            result = self.context.rediscluster.ping()
+
+        self.assertTrue(result)
+
+        server_span_observer = self.baseplate_observer.get_only_child()
+        span_observer = server_span_observer.get_only_child()
+        self.assertEqual(span_observer.span.name, "rediscluster.PING")
+        self.assertTrue(span_observer.on_start_called)
+        self.assertTrue(span_observer.on_finish_called)
+        self.assertIsNone(span_observer.on_finish_exc_info)
+
+    def test_error(self):
+        with self.server_span:
+            with self.assertRaises(rediscluster.RedisClusterException):
+                self.context.rediscluster.execute_command("crazycommand")
+
+        server_span_observer = self.baseplate_observer.get_only_child()
+        span_observer = server_span_observer.get_only_child()
+        self.assertTrue(span_observer.on_start_called)
+        self.assertTrue(span_observer.on_finish_called)
+        self.assertIsNotNone(span_observer.on_finish_exc_info)
+
+    def test_lock(self):
+        with self.server_span:
+            with self.context.rediscluster.lock("foo-lock"):
+                pass
+
+        server_span_observer = self.baseplate_observer.get_only_child()
+
+        self.assertGreater(len(server_span_observer.children), 0)
+        for span_observer in server_span_observer.children:
+            self.assertTrue(span_observer.on_start_called)
+            self.assertTrue(span_observer.on_finish_called)
+
+    def test_pipeline(self):
+        with self.server_span:
+            with self.context.rediscluster.pipeline("foo") as pipeline:
+                pipeline.set("foo", "bar")
+                pipeline.get("foo")
+                pipeline.get("foo")
+                pipeline.get("foo")
+                pipeline.get("foo")
+                pipeline.get("foo")
+                pipeline.delete("foo")
+                pipeline.execute()
+
+        server_span_observer = self.baseplate_observer.get_only_child()
+        span_observer = server_span_observer.get_only_child()
+        self.assertEqual(span_observer.span.name, "rediscluster.pipeline_foo")
+        self.assertTrue(span_observer.on_start_called)
+        self.assertTrue(span_observer.on_finish_called)
+        self.assertIsNone(span_observer.on_finish_exc_info)
diff --git a/tests/unit/clients/redis_cluster_tests.py b/tests/unit/clients/redis_cluster_tests.py
new file mode 100644
index 000000000..a9e3b2d3e
--- /dev/null
+++ b/tests/unit/clients/redis_cluster_tests.py
@@ -0,0 +1,106 @@
+import unittest
+
+import fakeredis
+
+from baseplate.clients.redis_cluster import HotKeyTracker
+
+
+class HotKeyTrackerTests(unittest.TestCase):
+    def setUp(self):
+        self.rc = fakeredis.FakeStrictRedis()
+
+    def test_increment_reads_once(self):
+        tracker = HotKeyTracker(self.rc, 1, 1)
+        tracker.increment_keys_read_counter(["foo"], ignore_errors=False)
+
+        self.assertEqual(
+            tracker.redis_client.zrangebyscore(
+                "baseplate-hot-key-tracker-reads", "-inf", "+inf", withscores=True
+            ),
+            [(b"foo", float(1))],
+        )
+
+    def test_increment_several_reads(self):
+        tracker = HotKeyTracker(self.rc, 1, 1)
+        for _ in range(5):
+            tracker.increment_keys_read_counter(["foo"], ignore_errors=False)
+
+        tracker.increment_keys_read_counter(["bar"], ignore_errors=False)
+
+        self.assertEqual(
+            tracker.redis_client.zrangebyscore(
+                "baseplate-hot-key-tracker-reads", "-inf", "+inf", withscores=True
+            ),
+            [(b"bar", float(1)), (b"foo", float(5))],
+        )
+
+    def test_reads_disabled_tracking(self):
+        tracker = HotKeyTracker(self.rc, 0, 0)
+        for _ in range(5):
+            tracker.maybe_track_key_usage(["GET", "foo"])
+
+        self.assertEqual(
+            tracker.redis_client.zrangebyscore(
+                "baseplate-hot-key-tracker-reads", "-inf", "+inf", withscores=True
+            ),
+            [],
+        )
+
+    def test_reads_enabled_tracking(self):
+        tracker = HotKeyTracker(self.rc, 1, 1)
+        for _ in range(5):
+            tracker.maybe_track_key_usage(["GET", "foo"])
+
+        self.assertEqual(
+            tracker.redis_client.zrangebyscore(
+                "baseplate-hot-key-tracker-reads", "-inf", "+inf", withscores=True
+            ),
+            [(b"foo", float(5))],
+        )
+
+    def test_writes_enabled_tracking(self):
+        tracker = HotKeyTracker(self.rc, 1, 1)
+        for _ in range(5):
+            tracker.maybe_track_key_usage(["SET", "foo", "bar"])
+
+        self.assertEqual(
+            tracker.redis_client.zrangebyscore(
+                "baseplate-hot-key-tracker-writes", "-inf", "+inf", withscores=True
+            ),
+            [(b"foo", float(5))],
+        )
+
+    def test_writes_disabled_tracking(self):
+        tracker = HotKeyTracker(self.rc, 0, 0)
+        for _ in range(5):
+            tracker.maybe_track_key_usage(["SET", "foo", "bar"])
+
+        self.assertEqual(
+            tracker.redis_client.zrangebyscore(
+                "baseplate-hot-key-tracker-writes", "-inf", "+inf", withscores=True
+            ),
+            [],
+        )
+
+    def test_write_multikey_commands(self):
+        tracker = HotKeyTracker(self.rc, 1, 1)
+
+        tracker.maybe_track_key_usage(["DEL", "foo", "bar"])
+
+        self.assertEqual(
+            tracker.redis_client.zrangebyscore(
+                "baseplate-hot-key-tracker-writes", "-inf", "+inf", withscores=True
+            ),
+            [(b"bar", float(1)), (b"foo", float(1))],
+        )
+
+    def test_write_batchkey_commands(self):
+        tracker = HotKeyTracker(self.rc, 1, 1)
+
+        tracker.maybe_track_key_usage(["MSET", "foo", "bar", "baz", "wednesday"])
+
+        self.assertEqual(
+            tracker.redis_client.zrangebyscore(
+                "baseplate-hot-key-tracker-writes", "-inf", "+inf", withscores=True
+            ),
+            [(b"baz", float(1)), (b"foo", float(1))],
+        )