Skip to content

Commit

Permalink
k
Browse files Browse the repository at this point in the history
  • Loading branch information
pablonyx committed Jan 6, 2025
1 parent 2b8d3a6 commit 5628d0c
Showing 1 changed file with 35 additions and 16 deletions.
51 changes: 35 additions & 16 deletions backend/onyx/redis/redis_pool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import functools
import json
import ssl
import threading
from collections.abc import Callable
from typing import Any
Expand Down Expand Up @@ -198,14 +199,12 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)


# # Usage example
# redis_pool = RedisPool()
# redis_client = redis_pool.get_client()
SSL_CERT_REQS_MAP = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,
"required": ssl.CERT_REQUIRED,
}

# # Example of setting and getting a value
# redis_client.set('key', 'value')
# value = redis_client.get('key')
# print(value.decode()) # Output: 'value'

_async_redis_connection: aioredis.Redis | None = None
_async_lock = asyncio.Lock()
Expand All @@ -224,15 +223,35 @@ async def get_async_redis_connection() -> aioredis.Redis:
async with _async_lock:
# Double-check inside the lock to avoid race conditions
if _async_redis_connection is None:
scheme = "rediss" if REDIS_SSL else "redis"
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"

# Create a new Redis connection (or connection pool) from the URL
_async_redis_connection = aioredis.from_url(
url,
password=REDIS_PASSWORD,
max_connections=REDIS_POOL_MAX_CONNECTIONS,
)
# Load env vars or your config variables

connection_kwargs = {
"host": REDIS_HOST,
"port": REDIS_PORT,
"db": REDIS_DB_NUMBER,
"password": REDIS_PASSWORD,
"max_connections": REDIS_POOL_MAX_CONNECTIONS,
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
"socket_keepalive": True,
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}

if REDIS_SSL:
ssl_context = ssl.create_default_context()

if REDIS_SSL_CA_CERTS:
ssl_context.load_verify_locations(REDIS_SSL_CA_CERTS)
ssl_context.check_hostname = False

# Map your string to the proper ssl.CERT_* constant
ssl_context.verify_mode = SSL_CERT_REQS_MAP.get(
REDIS_SSL_CERT_REQS, ssl.CERT_NONE
)

connection_kwargs["ssl"] = ssl_context

# Create a new Redis connection (or connection pool) with SSL configuration
_async_redis_connection = aioredis.Redis(**connection_kwargs)

# Return the established connection (or pool) for all future operations
return _async_redis_connection
Expand Down

0 comments on commit 5628d0c

Please sign in to comment.