Skip to content

Commit

Permalink
[Done] Improve celery integration (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw authored Sep 23, 2024
1 parent 55d3f05 commit 4270d4b
Show file tree
Hide file tree
Showing 14 changed files with 308 additions and 110 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ jobs:
ports:
- 9000:9000

rabbitmq:
image: "rabbitmq:3-alpine"
env:
RABBITMQ_DEFAULT_USER: "cleanpython"
RABBITMQ_DEFAULT_PASS: "cleanpython"
RABBITMQ_DEFAULT_VHOST: "cleanpython"
ports:
- "5672:5672"

steps:
- uses: actions/checkout@v3

Expand Down
6 changes: 4 additions & 2 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Changelog of clean-python

## 0.16.6 (unreleased)
## 0.17.0 (unreleased)
----------------------

- Nothing changed yet.
- Added a `celery.CeleryConfig` with an `apply` method that properly sets up celery
without making the tasks depend on the config. Also added integration tests that
confirm the forwarding of context (tenant and correlation id).


## 0.16.5 (2024-09-12)
Expand Down
1 change: 1 addition & 0 deletions clean_python/celery/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .base_task import * # NOQA
from .celery_task_logger import * # NOQA
from .config import * # NOQA
from .kubernetes import * # NOQA
36 changes: 15 additions & 21 deletions clean_python/celery/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,42 @@
from uuid import uuid4

from celery import Task
from celery.worker.request import Request as CeleryRequest

from clean_python import ctx
from clean_python import Json
from clean_python import Id
from clean_python import Tenant
from clean_python import ValueObject

__all__ = ["BaseTask"]


HEADER_FIELD = "clean_python_context"


class TaskHeaders(ValueObject):
tenant: Tenant | None
correlation_id: UUID | None
tenant_id: Id | None = None
# avoid conflict with celery's own correlation_id:
x_correlation_id: UUID | None = None

@classmethod
def from_kwargs(cls, kwargs: Json) -> tuple["TaskHeaders", Json]:
if HEADER_FIELD in kwargs:
kwargs = kwargs.copy()
headers = kwargs.pop(HEADER_FIELD)
return TaskHeaders(**headers), kwargs
else:
return TaskHeaders(tenant=None, correlation_id=None), kwargs
def from_celery_request(cls, request: CeleryRequest) -> "TaskHeaders":
return cls(**request.headers)


class BaseTask(Task):
def apply_async(self, args=None, kwargs=None, **options):
# include correlation_id and tenant in the kwargs
# and NOT the headers as that is buggy in celery
# see https://github.com/celery/celery/issues/4875
kwargs = {} if kwargs is None else kwargs.copy()
kwargs[HEADER_FIELD] = TaskHeaders(
tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
options["headers"] = TaskHeaders(
tenant_id=ctx.tenant.id if ctx.tenant else None,
x_correlation_id=ctx.correlation_id or uuid4(),
).model_dump(mode="json")
return super().apply_async(args, kwargs, **options)

def __call__(self, *args, **kwargs):
return copy_context().run(self._call_with_context, *args, **kwargs)

def _call_with_context(self, *args, **kwargs):
headers, kwargs = TaskHeaders.from_kwargs(kwargs)
ctx.tenant = headers.tenant
ctx.correlation_id = headers.correlation_id
headers = TaskHeaders.from_celery_request(self.request)
ctx.tenant = (
Tenant(id=headers.tenant_id, name="") if headers.tenant_id else None
)
ctx.correlation_id = headers.x_correlation_id or uuid4()
return super().__call__(*args, **kwargs)
12 changes: 6 additions & 6 deletions clean_python/celery/celery_task_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,17 @@ def stop(self, task: Task, state: str, result: Any = None):
request = None

try:
headers, kwargs = TaskHeaders.from_kwargs(request.kwargs)
headers = TaskHeaders.from_celery_request(request)
except (AttributeError, TypeError):
headers = kwargs = None # type: ignore
headers = None

try:
tenant_id = headers.tenant.id # type: ignore
tenant_id = headers.tenant_id # type: ignore
except AttributeError:
tenant_id = None

try:
correlation_id = headers.correlation_id
correlation_id = headers.x_correlation_id # type: ignore
except AttributeError:
correlation_id = None

Expand All @@ -86,8 +86,8 @@ def stop(self, task: Task, state: str, result: Any = None):
argsrepr = None

try:
kwargsrepr = json.dumps(kwargs)
except TypeError:
kwargsrepr = json.dumps(request.kwargs)
except (AttributeError, TypeError):
kwargsrepr = None

log_dict = {
Expand Down
31 changes: 31 additions & 0 deletions clean_python/celery/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from celery import Celery
from celery import current_app

from clean_python import Json
from clean_python import ValueObject
from clean_python.celery import BaseTask

__all__ = ["CeleryConfig"]


class CeleryConfig(ValueObject):
    """Celery settings collected in a single value object.

    Only ``broker_url`` is required; every other field carries a default.
    The instance itself is handed to ``config_from_object`` by :meth:`apply`.
    """

    timezone: str = "Europe/Amsterdam"
    broker_url: str
    broker_transport_options: Json = {"socket_timeout": 2}
    broker_connection_retry_on_startup: bool = True
    result_backend: str | None = None
    worker_prefetch_multiplier: int = 1
    task_always_eager: bool = False
    task_eager_propagates: bool = False
    task_acks_late: bool = True
    task_default_queue: str = "default"
    task_default_priority: int = 0
    task_queue_max_priority: int = 10
    task_track_started: bool = True

    def apply(self, strict_typing: bool = True) -> Celery:
        """Configure a Celery app with these settings and return it.

        Args:
            strict_typing: value assigned to the app's ``strict_typing``
                attribute.

        Returns:
            The app that was configured: ``current_app`` when it is truthy,
            otherwise a freshly constructed ``Celery()``.
        """
        if current_app:
            app = current_app
        else:
            app = Celery()
        # Make BaseTask the task class of the app.
        app.task_cls = BaseTask
        app.strict_typing = strict_typing
        app.config_from_object(self)
        return app
21 changes: 19 additions & 2 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
version: "3.8"

services:

postgres:
Expand All @@ -18,8 +16,27 @@ services:
MINIO_ROOT_PASSWORD: cleanpython
ports:
- "9000:9000"
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres -d postgres"]
interval: 2s
retries: 10
timeout: 1s

fluentbit:
image: fluent/fluent-bit:1.9
ports:
- "24224:24224"

rabbitmq:
image: "rabbitmq:3-alpine"
environment:
RABBITMQ_DEFAULT_USER: "cleanpython"
RABBITMQ_DEFAULT_PASS: "cleanpython"
RABBITMQ_DEFAULT_VHOST: "cleanpython"
ports:
- "5672:5672"
healthcheck:
test: rabbitmq-diagnostics check_port_connectivity
interval: 10s
timeout: 1s
retries: 5
22 changes: 22 additions & 0 deletions integration_tests/celery_example/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
from pathlib import Path

from clean_python.celery import CeleryConfig
from clean_python.celery import CeleryTaskLogger
from clean_python.celery import set_task_logger
from clean_python.testing.debugger import setup_debugger

from .logger import MultilineJsonFileGateway
from .tasks import sleep_task # NOQA

# Celery app used by the integration tests; the broker URL carries the
# credentials and vhost in the form user:password@host/vhost.
app = CeleryConfig(
    result_backend="rpc://",
    broker_url="amqp://cleanpython:cleanpython@localhost/cleanpython",
).apply()

# the file path is set from the test fixture
if logging_path := os.environ.get("CLEAN_PYTHON_TEST_LOGGING"):
    _log_gateway = MultilineJsonFileGateway(Path(logging_path))
    set_task_logger(CeleryTaskLogger(_log_gateway))

# Opt-in debugging: only attach a debugger when a port is given.
if debug_port := os.environ.get("CLEAN_PYTHON_TEST_DEBUG"):
    setup_debugger(port=int(debug_port))
34 changes: 34 additions & 0 deletions integration_tests/celery_example/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import json
from pathlib import Path

from clean_python import Filter
from clean_python import Json
from clean_python import PageOptions
from clean_python import SyncGateway

__all__ = ["MultilineJsonFileGateway"]


class MultilineJsonFileGateway(SyncGateway):
    """Gateway that persists items as one JSON document per line in a file.

    The file may be shared with another process (for instance one process
    appending records while another reads and clears them), so no method
    relies on a separate "does the file exist" check before acting on it.
    """

    def __init__(self, path: Path) -> None:
        """Remember *path*; the file itself is created lazily by :meth:`add`."""
        self.path = path

    def clear(self):
        """Remove the backing file; a file that is already gone is fine."""
        # missing_ok closes the window between exists() and unlink() in which
        # another process could remove the file.
        self.path.unlink(missing_ok=True)

    def filter(
        self, filters: list[Filter], params: PageOptions | None = None
    ) -> list[Json]:
        """Return every stored item, in the order it was written.

        Filtering and paging are not supported: both arguments must be empty.
        Returns an empty list when the file does not exist.
        """
        assert not filters
        assert not params
        try:
            with self.path.open("r", encoding="utf-8") as f:
                return [json.loads(line) for line in f]
        except FileNotFoundError:
            # Nothing has been written yet (or the file was just cleared).
            return []

    def add(self, item: Json) -> Json:
        """Append *item* as a single JSON line and return it unchanged."""
        with self.path.open("a", encoding="utf-8") as f:
            # One write call keeps the record and its newline together.
            f.write(json.dumps(item) + "\n")
        return item
34 changes: 34 additions & 0 deletions integration_tests/celery_example/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import time

from celery import shared_task
from celery import Task
from celery.exceptions import Ignore
from celery.exceptions import Reject

from clean_python import ctx


@shared_task(bind=True, name="testing")
def sleep_task(self: Task, seconds: float, return_value=None, event="success"):
    """Test task whose behaviour is selected by *event* (case-insensitive).

    Events:
        success: sleep for ``int(seconds)`` and return ``{"value": return_value}``.
        crash: make the process segfault.
        ignore: raise ``celery.exceptions.Ignore``.
        reject: raise ``celery.exceptions.Reject``.
        retry: request a retry after *seconds* (at most one retry).
        context: return the tenant id and correlation id found in ``ctx``.

    Raises:
        ValueError: if *event* is none of the above.
    """
    event = event.lower()
    if event == "success":
        # NOTE: int() truncates, so fractional seconds are dropped.
        time.sleep(int(seconds))
    elif event == "crash":
        import ctypes

        ctypes.string_at(0)  # segfault
    elif event == "ignore":
        raise Ignore()
    elif event == "reject":
        raise Reject()
    elif event == "retry":
        raise self.retry(countdown=seconds, max_retries=1)
    elif event == "context":
        # The tenant may be absent; report None instead of raising
        # AttributeError on ``None.id``.
        tenant = ctx.tenant
        return {
            "tenant_id": tenant.id if tenant else None,
            "correlation_id": str(ctx.correlation_id),
        }
    else:
        raise ValueError(f"Unknown event '{event}'")

    return {"value": return_value}
36 changes: 36 additions & 0 deletions integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import io
import multiprocessing
import os
import signal
import subprocess
import time
from pathlib import Path
from urllib.error import URLError
from urllib.request import urlopen

Expand All @@ -13,6 +16,8 @@
import uvicorn
from botocore.exceptions import ClientError

from .celery_example import MultilineJsonFileGateway


def pytest_sessionstart(session):
"""
Expand Down Expand Up @@ -102,6 +107,37 @@ async def fastapi_example_app():
p.terminate()


@pytest.fixture(scope="session")
def celery_worker(tmp_path_factory):
    """Run a celery worker subprocess for the duration of the test session.

    Yields a ``MultilineJsonFileGateway`` on the log file whose path is passed
    to the worker through the ``CLEAN_PYTHON_TEST_LOGGING`` environment
    variable. On teardown the worker receives SIGQUIT and is reaped.
    """
    log_file = str(tmp_path_factory.mktemp("pytest-celery") / "celery.log")
    p = subprocess.Popen(
        [
            "celery",
            "-A",
            "integration_tests.celery_example",
            "worker",
            "-c",
            "1",
            # "-P",  enable when using the debugger
            # "solo"
        ],
        start_new_session=True,
        # The output is never read; DEVNULL (unlike an undrained PIPE) cannot
        # fill up and block the worker.
        stdout=subprocess.DEVNULL,
        # optionally add "CLEAN_PYTHON_TEST_DEBUG": "5679" to enable debugging
        # The override comes last so an inherited CLEAN_PYTHON_TEST_LOGGING
        # cannot point the worker at a different file than the one yielded.
        env={**os.environ, "CLEAN_PYTHON_TEST_LOGGING": log_file},
    )
    try:
        yield MultilineJsonFileGateway(Path(log_file))
    finally:
        p.send_signal(signal.SIGQUIT)
        # Reap the child so it does not linger as a zombie; force-kill it if
        # it does not exit in time.
        try:
            p.wait(timeout=10)
        except subprocess.TimeoutExpired:
            p.kill()
            p.wait()


@pytest.fixture
def celery_task_logs(celery_worker):
    """Return the ``celery_worker`` fixture value after calling ``clear()`` on it."""
    task_logs = celery_worker
    task_logs.clear()
    return task_logs


@pytest.fixture(scope="session")
def s3_settings(s3_url):
minio_settings = {
Expand Down
Loading

0 comments on commit 4270d4b

Please sign in to comment.