Skip to content

Commit

Permalink
Configure base job template during worker startup (PrefectHQ#10798)
Browse files Browse the repository at this point in the history
On "prefect worker start", allow users to pass in the base job
template instead of using the default.

Related to PrefectHQ#9576
  • Loading branch information
jawnsy authored Sep 22, 2023
1 parent 4e7550d commit 56f4220
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 2 deletions.
15 changes: 15 additions & 0 deletions src/prefect/cli/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import threading
from enum import Enum
Expand Down Expand Up @@ -101,6 +102,15 @@ async def start(
help="Install policy to use workers from Prefect integration packages.",
case_sensitive=False,
),
base_job_template: typer.FileText = typer.Option(
None,
"--base-job-template",
help=(
"The path to a JSON file containing the base job template to use. If"
" unspecified, Prefect will use the default base job template for the given"
" worker type. If the work pool already exists, this will be ignored."
),
),
):
"""
Start a worker process to poll a work pool for flow runs.
Expand Down Expand Up @@ -129,13 +139,18 @@ async def start(
worker_process_id, f"the {worker_type} worker", app.console.print
)

template_contents = None
if base_job_template is not None:
template_contents = json.load(fp=base_job_template)

async with worker_cls(
name=worker_name,
work_pool_name=work_pool_name,
work_queues=work_queues,
limit=limit,
prefetch_seconds=prefetch_seconds,
heartbeat_interval_seconds=PREFECT_WORKER_HEARTBEAT_SECONDS.value(),
base_job_template=template_contents,
) as worker:
app.console.print(f"Worker {worker.name!r} started!", style="green")
async with anyio.create_task_group() as tg:
Expand Down
19 changes: 17 additions & 2 deletions src/prefect/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ def __init__(
create_pool_if_not_found: bool = True,
limit: Optional[int] = None,
heartbeat_interval_seconds: Optional[int] = None,
*,
base_job_template: Optional[Dict[str, Any]] = None,
):
"""
Base class for all Prefect workers.
Expand All @@ -344,6 +346,8 @@ def __init__(
ensure that work pools are not created accidentally.
limit: The maximum number of flow runs this worker should be running at
a given time.
base_job_template: If creating the work pool, provide the base job
template to use. Logs a warning if the pool already exists.
"""
if name and ("/" in name or "%" in name):
raise ValueError("Worker name cannot contain '/' or '%'")
Expand All @@ -352,6 +356,7 @@ def __init__(

self.is_setup = False
self._create_pool_if_not_found = create_pool_if_not_found
self._base_job_template = base_job_template
self._work_pool_name = work_pool_name
self._work_queues: Set[str] = set(work_queues) if work_queues else set()

Expand Down Expand Up @@ -656,12 +661,22 @@ async def _update_local_work_pool_info(self):
)
except ObjectNotFound:
if self._create_pool_if_not_found:
work_pool = await self._client.create_work_pool(
work_pool=WorkPoolCreate(name=self._work_pool_name, type=self.type)
wp = WorkPoolCreate(
name=self._work_pool_name,
type=self.type,
)
if self._base_job_template is not None:
wp.base_job_template = self._base_job_template

work_pool = await self._client.create_work_pool(work_pool=wp)
self._logger.info(f"Work pool {self._work_pool_name!r} created.")
else:
self._logger.warning(f"Work pool {self._work_pool_name!r} not found!")
if self._base_job_template is not None:
self._logger.warning(
"Ignoring supplied base job template because the work pool"
" already exists"
)
return

# if the remote config type changes (or if it's being loaded for the
Expand Down
50 changes: 50 additions & 0 deletions tests/cli/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import signal
import sys
import tempfile
from pathlib import Path
from unittest.mock import ANY

import anyio
Expand Down Expand Up @@ -136,6 +137,51 @@ async def test_start_worker_creates_work_pool(prefect_client: PrefectClient):
assert work_pool.default_queue_id is not None


@pytest.mark.usefixtures("use_hosted_api_server")
async def test_start_worker_creates_work_pool_with_base_config(
prefect_client: PrefectClient,
):
await run_sync_in_worker_thread(
invoke_and_assert,
command=[
"worker",
"start",
"--run-once",
"--pool",
"my-cool-pool",
"--type",
"process",
"--base-job-template",
Path(__file__).parent / "base-job-templates" / "process-worker.json",
],
expected_code=0,
expected_output_contains=["Worker", "stopped!", "Worker", "started!"],
)

work_pool = await prefect_client.read_work_pool("my-cool-pool")
assert work_pool is not None
assert work_pool.name == "my-cool-pool"
assert work_pool.default_queue_id is not None
assert work_pool.base_job_template == {
"job_configuration": {"command": "{{ command }}", "name": "{{ name }}"},
"variables": {
"properties": {
"command": {
"description": "Command to run.",
"title": "Command",
"type": "string",
},
"name": {
"description": "Description.",
"title": "Name",
"type": "string",
},
},
"type": "object",
},
}


@pytest.mark.usefixtures("use_hosted_api_server")
def test_start_worker_with_work_queue_names(monkeypatch, process_work_pool):
mock_worker = MagicMock()
Expand All @@ -161,6 +207,7 @@ def test_start_worker_with_work_queue_names(monkeypatch, process_work_pool):
prefetch_seconds=ANY,
limit=None,
heartbeat_interval_seconds=30,
base_job_template=None,
)


Expand Down Expand Up @@ -189,6 +236,7 @@ def test_start_worker_with_prefetch_seconds(monkeypatch):
prefetch_seconds=30,
limit=None,
heartbeat_interval_seconds=30,
base_job_template=None,
)


Expand Down Expand Up @@ -216,6 +264,7 @@ def test_start_worker_with_prefetch_seconds_from_setting_by_default(monkeypatch)
prefetch_seconds=100,
limit=None,
heartbeat_interval_seconds=30,
base_job_template=None,
)


Expand Down Expand Up @@ -244,6 +293,7 @@ def test_start_worker_with_limit(monkeypatch):
prefetch_seconds=10,
limit=5,
heartbeat_interval_seconds=30,
base_job_template=None,
)


Expand Down

0 comments on commit 56f4220

Please sign in to comment.