diff --git a/src/prefect/cli/worker.py b/src/prefect/cli/worker.py index 981ad02bf53b..8603125d58b0 100644 --- a/src/prefect/cli/worker.py +++ b/src/prefect/cli/worker.py @@ -1,3 +1,4 @@ +import json import os import threading from enum import Enum @@ -101,6 +102,15 @@ async def start( help="Install policy to use workers from Prefect integration packages.", case_sensitive=False, ), + base_job_template: typer.FileText = typer.Option( + None, + "--base-job-template", + help=( + "The path to a JSON file containing the base job template to use. If" + " unspecified, Prefect will use the default base job template for the given" + " worker type. If the work pool already exists, this will be ignored." + ), + ), ): """ Start a worker process to poll a work pool for flow runs. @@ -129,6 +139,10 @@ async def start( worker_process_id, f"the {worker_type} worker", app.console.print ) + template_contents = None + if base_job_template is not None: + template_contents = json.load(fp=base_job_template) + async with worker_cls( name=worker_name, work_pool_name=work_pool_name, @@ -136,6 +150,7 @@ async def start( limit=limit, prefetch_seconds=prefetch_seconds, heartbeat_interval_seconds=PREFECT_WORKER_HEARTBEAT_SECONDS.value(), + base_job_template=template_contents, ) as worker: app.console.print(f"Worker {worker.name!r} started!", style="green") async with anyio.create_task_group() as tg: diff --git a/src/prefect/workers/base.py b/src/prefect/workers/base.py index 3e1028ddd8d8..f1d3b49e3c32 100644 --- a/src/prefect/workers/base.py +++ b/src/prefect/workers/base.py @@ -325,6 +325,8 @@ def __init__( create_pool_if_not_found: bool = True, limit: Optional[int] = None, heartbeat_interval_seconds: Optional[int] = None, + *, + base_job_template: Optional[Dict[str, Any]] = None, ): """ Base class for all Prefect workers. @@ -344,6 +346,8 @@ def __init__( ensure that work pools are not created accidentally. limit: The maximum number of flow runs this worker should be running at a given time. + base_job_template: If creating the work pool, provide the base job + template to use. Logs a warning if the pool already exists. """ if name and ("/" in name or "%" in name): raise ValueError("Worker name cannot contain '/' or '%'") @@ -352,6 +356,7 @@ def __init__( self.is_setup = False self._create_pool_if_not_found = create_pool_if_not_found + self._base_job_template = base_job_template self._work_pool_name = work_pool_name self._work_queues: Set[str] = set(work_queues) if work_queues else set() @@ -656,12 +661,22 @@ async def _update_local_work_pool_info(self): ) except ObjectNotFound: if self._create_pool_if_not_found: - work_pool = await self._client.create_work_pool( - work_pool=WorkPoolCreate(name=self._work_pool_name, type=self.type) + wp = WorkPoolCreate( + name=self._work_pool_name, + type=self.type, ) + if self._base_job_template is not None: + wp.base_job_template = self._base_job_template + + work_pool = await self._client.create_work_pool(work_pool=wp) self._logger.info(f"Work pool {self._work_pool_name!r} created.") else: self._logger.warning(f"Work pool {self._work_pool_name!r} not found!") + if self._base_job_template is not None: + self._logger.warning( + "Ignoring supplied base job template because the work pool" + " already exists" + ) return # if the remote config type changes (or if it's being loaded for the diff --git a/tests/cli/test_worker.py b/tests/cli/test_worker.py index fc2ae66c2125..588d31e0a757 100644 --- a/tests/cli/test_worker.py +++ b/tests/cli/test_worker.py @@ -2,6 +2,7 @@ import signal import sys import tempfile +from pathlib import Path from unittest.mock import ANY import anyio @@ -136,6 +137,51 @@ async def test_start_worker_creates_work_pool(prefect_client: PrefectClient): assert work_pool.default_queue_id is not None +@pytest.mark.usefixtures("use_hosted_api_server") +async def test_start_worker_creates_work_pool_with_base_config( + prefect_client: PrefectClient, +): + await run_sync_in_worker_thread( + invoke_and_assert, + command=[ + "worker", + "start", + "--run-once", + "--pool", + "my-cool-pool", + "--type", + "process", + "--base-job-template", + Path(__file__).parent / "base-job-templates" / "process-worker.json", + ], + expected_code=0, + expected_output_contains=["Worker", "stopped!", "Worker", "started!"], + ) + + work_pool = await prefect_client.read_work_pool("my-cool-pool") + assert work_pool is not None + assert work_pool.name == "my-cool-pool" + assert work_pool.default_queue_id is not None + assert work_pool.base_job_template == { + "job_configuration": {"command": "{{ command }}", "name": "{{ name }}"}, + "variables": { + "properties": { + "command": { + "description": "Command to run.", + "title": "Command", + "type": "string", + }, + "name": { + "description": "Description.", + "title": "Name", + "type": "string", + }, + }, + "type": "object", + }, + } + + @pytest.mark.usefixtures("use_hosted_api_server") def test_start_worker_with_work_queue_names(monkeypatch, process_work_pool): mock_worker = MagicMock() @@ -161,6 +207,7 @@ def test_start_worker_with_work_queue_names(monkeypatch, process_work_pool): prefetch_seconds=ANY, limit=None, heartbeat_interval_seconds=30, + base_job_template=None, ) @@ -189,6 +236,7 @@ def test_start_worker_with_prefetch_seconds(monkeypatch): prefetch_seconds=30, limit=None, heartbeat_interval_seconds=30, + base_job_template=None, ) @@ -216,6 +264,7 @@ def test_start_worker_with_prefetch_seconds_from_setting_by_default(monkeypatch) prefetch_seconds=100, limit=None, heartbeat_interval_seconds=30, + base_job_template=None, ) @@ -244,6 +293,7 @@ def test_start_worker_with_limit(monkeypatch): prefetch_seconds=10, limit=5, heartbeat_interval_seconds=30, + base_job_template=None, )