Skip to content

Commit

Permalink
Initialize db_cache engine in wrapper (#546)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Nov 6, 2024
1 parent c6ae204 commit dfa90ca
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 38 deletions.
59 changes: 22 additions & 37 deletions prediction_market_agent_tooling/tools/caches/db_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,33 +91,22 @@ def decorator(func: FunctionT) -> FunctionT:

api_keys = api_keys if api_keys is not None else APIKeys()

sqlalchemy_db_url = api_keys.SQLALCHEMY_DB_URL
if sqlalchemy_db_url is None:
logger.warning(
f"SQLALCHEMY_DB_URL not provided in the environment, skipping function caching."
)
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
# If caching is disabled, just call the function and return it
if not api_keys.ENABLE_CACHE:
return func(*args, **kwargs)

engine = (
create_engine(
sqlalchemy_db_url.get_secret_value(),
engine = create_engine(
api_keys.sqlalchemy_db_url.get_secret_value(),
# Use custom json serializer and deserializer, because otherwise, for example `datetime` serialization would fail.
json_serializer=json_serializer,
json_deserializer=json_deserializer,
)
if sqlalchemy_db_url is not None
else None
)

# Create table if it doesn't exist
if engine is not None:
# Create table if it doesn't exist
SQLModel.metadata.create_all(engine)

@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
# If caching is disabled, just call the function and return it
if not api_keys.ENABLE_CACHE:
return func(*args, **kwargs)

# Convert *args and **kwargs to a single dictionary, where we have names for arguments passed as args as well.
signature = inspect.signature(func)
bound_arguments = signature.bind(*args, **kwargs)
Expand Down Expand Up @@ -162,25 +151,21 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
if return_type is not None and contains_pydantic_model(return_type):
is_pydantic_model = True

# If postgres access was specified, try to find a hit
if engine is not None:
with Session(engine) as session:
# Try to get cached result
statement = (
select(FunctionCache)
.where(
FunctionCache.function_name == function_name,
FunctionCache.full_function_name == full_function_name,
FunctionCache.args_hash == args_hash,
)
.order_by(desc(FunctionCache.created_at))
with Session(engine) as session:
# Try to get cached result
statement = (
select(FunctionCache)
.where(
FunctionCache.function_name == function_name,
FunctionCache.full_function_name == full_function_name,
FunctionCache.args_hash == args_hash,
)
if max_age is not None:
cutoff_time = utcnow() - max_age
statement = statement.where(FunctionCache.created_at >= cutoff_time)
cached_result = session.exec(statement).first()
else:
cached_result = None
.order_by(desc(FunctionCache.created_at))
)
if max_age is not None:
cutoff_time = utcnow() - max_age
statement = statement.where(FunctionCache.created_at >= cutoff_time)
cached_result = session.exec(statement).first()

if cached_result:
logger.info(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "prediction-market-agent-tooling"
version = "0.56.0"
version = "0.56.1"
description = "Tools to benchmark, deploy and monitor prediction market agents."
authors = ["Gnosis"]
readme = "README.md"
Expand Down

0 comments on commit dfa90ca

Please sign in to comment.