diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f670f14b..0fa9cf554 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Improved handling of Covalent version mismatches between client and executor environments +- `get_result(wait=True)` will wait as long as needed ### Removed diff --git a/covalent/_dispatcher_plugins/local.py b/covalent/_dispatcher_plugins/local.py index 9857342cf..bc29af30c 100644 --- a/covalent/_dispatcher_plugins/local.py +++ b/covalent/_dispatcher_plugins/local.py @@ -129,8 +129,6 @@ def dispatch( Wrapper function which takes the inputs of the workflow as arguments """ - multistage = get_config("sdk.multistage_dispatch") == "true" - # Extract triggers here if "triggers" in orig_lattice.metadata: triggers_data = orig_lattice.metadata.pop("triggers") @@ -155,14 +153,7 @@ def wrapper(*args, **kwargs) -> str: The dispatch id of the workflow. """ - if multistage: - dispatch_id = LocalDispatcher.register(orig_lattice, dispatcher_addr)( - *args, **kwargs - ) - else: - dispatch_id = LocalDispatcher.submit(orig_lattice, dispatcher_addr)( - *args, **kwargs - ) + dispatch_id = LocalDispatcher.register(orig_lattice, dispatcher_addr)(*args, **kwargs) if triggers_data: LocalDispatcher.register_triggers(triggers_data, dispatch_id) @@ -237,61 +228,6 @@ def wrapper(*args, **kwargs) -> str: return wrapper - @staticmethod - def submit( - orig_lattice: Lattice, - dispatcher_addr: str = None, - ) -> Callable: - """ - Wrapping the dispatching functionality to allow input passing - and server address specification. - - Afterwards, send the lattice to the dispatcher server and return - the assigned dispatch id. - - Args: - orig_lattice: The lattice/workflow to send to the dispatcher server. - dispatcher_addr: The address of the dispatcher server. If None then then defaults to the address set in Covalent's config. - - Returns: - Wrapper function which takes the inputs of the workflow as arguments - """ - - if dispatcher_addr is None: - dispatcher_addr = format_server_url() - - @wraps(orig_lattice) - def wrapper(*args, **kwargs) -> str: - """ - Send the lattice to the dispatcher server and return - the assigned dispatch id. - - Args: - *args: The inputs of the workflow. - **kwargs: The keyword arguments of the workflow. - - Returns: - The dispatch id of the workflow. - """ - - if not isinstance(orig_lattice, Lattice): - message = f"Dispatcher expected a Lattice, received {type(orig_lattice)} instead." - app_log.error(message) - raise TypeError(message) - - lattice = deepcopy(orig_lattice) - - lattice.build_graph(*args, **kwargs) - - # Serialize the transport graph to JSON - json_lattice = lattice.serialize_to_json() - endpoint = "/api/v2/dispatches/submit" - r = APIClient(dispatcher_addr).post(endpoint, data=json_lattice) - r.raise_for_status() - return r.content.decode("utf-8").strip().replace('"', "") - - return wrapper - @staticmethod def start( dispatch_id: str, diff --git a/covalent/_results_manager/results_manager.py b/covalent/_results_manager/results_manager.py index 4c751206a..31d0f76ae 100644 --- a/covalent/_results_manager/results_manager.py +++ b/covalent/_results_manager/results_manager.py @@ -19,12 +19,11 @@ import contextlib import os +import time from pathlib import Path -from typing import Dict, List, Optional +from typing import List, Optional from furl import furl -from requests.adapters import HTTPAdapter -from urllib3.util import Retry from .._api.apiclient import CovalentAPIClient from .._serialize.common import load_asset @@ -40,9 +39,9 @@ from .._shared_files.exceptions import MissingLatticeRecordError from .._shared_files.schemas.asset import AssetSchema from .._shared_files.schemas.result import ResultSchema +from .._shared_files.util_classes import RESULT_STATUS, Status from .._shared_files.utils import copy_file_locally, format_server_url from .result import Result -from .wait import EXTREME app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -139,12 +138,20 @@ def cancel(dispatch_id: str, task_ids: List[int] = None, dispatcher_addr: str = # Multi-part +def _query_dispatch_status(dispatch_id: str, api_client: CovalentAPIClient): + endpoint = "/api/v2/dispatches" + resp = api_client.get(endpoint, params={"dispatch_id": dispatch_id, "status_only": True}) + resp.raise_for_status() + dispatches = resp.json()["dispatches"] + if len(dispatches) == 0: + raise MissingLatticeRecordError + + return dispatches[0]["status"] + + def _get_result_export_from_dispatcher( - dispatch_id: str, - wait: bool = False, - status_only: bool = False, - dispatcher_addr: str = None, -) -> Dict: + dispatch_id: str, api_client: CovalentAPIClient +) -> ResultSchema: """ Internal function to get the results of a dispatch from the server without checking if it is ready to read. @@ -161,24 +168,13 @@ def _get_result_export_from_dispatcher( MissingLatticeRecordError: If the result is not found. """ - if dispatcher_addr is None: - dispatcher_addr = format_server_url() - - retries = int(EXTREME) if wait else 5 - - adapter = HTTPAdapter(max_retries=Retry(total=retries, backoff_factor=1)) - api_client = CovalentAPIClient(dispatcher_addr, adapter=adapter, auto_raise=False) - endpoint = f"/api/v2/dispatches/{dispatch_id}" - response = api_client.get( - endpoint, - params={"wait": wait, "status_only": status_only}, - ) + response = api_client.get(endpoint) if response.status_code == 404: raise MissingLatticeRecordError response.raise_for_status() export = response.json() - return export + return ResultSchema.model_validate(export) # Function to download default assets @@ -346,11 +342,17 @@ def from_dispatch_id( wait: bool = False, dispatcher_addr: str = None, ) -> "ResultManager": - export = _get_result_export_from_dispatcher( - dispatch_id, wait, status_only=False, dispatcher_addr=dispatcher_addr - ) + if dispatcher_addr is None: + dispatcher_addr = format_server_url() - manifest = ResultSchema.model_validate(export["result_export"]) + api_client = CovalentAPIClient(dispatcher_addr) + if wait: + status = Status(_query_dispatch_status(dispatch_id, api_client)) + while not RESULT_STATUS.is_terminal(status): + time.sleep(1) + status = Status(_query_dispatch_status(dispatch_id, api_client)) + + manifest = _get_result_export_from_dispatcher(dispatch_id, api_client) # sort the nodes manifest.lattice.transport_graph.nodes.sort(key=lambda x: x.id) @@ -408,14 +410,15 @@ def _get_result_multistage( """ + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + api_client = CovalentAPIClient(dispatcher_addr) try: if status_only: - return _get_result_export_from_dispatcher( - dispatch_id=dispatch_id, - wait=wait, - status_only=status_only, - dispatcher_addr=dispatcher_addr, - ) + status = _query_dispatch_status(dispatch_id, api_client) + return {"id": dispatch_id, "status": status} + rm = get_result_manager(dispatch_id, results_dir, wait, dispatcher_addr) _get_default_assets(rm) @@ -496,23 +499,14 @@ def get_result( The Result object from the Covalent server """ - max_attempts = int(os.getenv("COVALENT_GET_RESULT_RETRIES", 10)) - num_attempts = 0 - while num_attempts < max_attempts: - try: - return _get_result_multistage( - dispatch_id=dispatch_id, - wait=wait, - dispatcher_addr=dispatcher_addr, - status_only=status_only, - results_dir=results_dir, - workflow_output=workflow_output, - intermediate_outputs=intermediate_outputs, - sublattice_results=sublattice_results, - qelectron_db=qelectron_db, - ) - - except RecursionError as re: - app_log.error(re) - num_attempts += 1 - raise RuntimeError("Timed out waiting for result. Please retry or check dispatch.") + return _get_result_multistage( + dispatch_id=dispatch_id, + wait=wait, + dispatcher_addr=dispatcher_addr, + status_only=status_only, + results_dir=results_dir, + workflow_output=workflow_output, + intermediate_outputs=intermediate_outputs, + sublattice_results=sublattice_results, + qelectron_db=qelectron_db, + ) diff --git a/covalent/_shared_files/defaults.py b/covalent/_shared_files/defaults.py index aef61086d..2780d7c53 100644 --- a/covalent/_shared_files/defaults.py +++ b/covalent/_shared_files/defaults.py @@ -67,9 +67,6 @@ def get_default_sdk_config(): + "/covalent/dispatches" ), "task_packing": "true" if os.environ.get("COVALENT_ENABLE_TASK_PACKING") else "false", - "multistage_dispatch": ( - "false" if os.environ.get("COVALENT_DISABLE_MULTISTAGE_DISPATCH") == "1" else "true" - ), "results_dir": os.environ.get( "COVALENT_RESULTS_DIR" ) # COVALENT_RESULTS_DIR is where the client downloads workflow artifacts during get_result() which is different from COVALENT_DATA_DIR diff --git a/covalent/triggers/base.py b/covalent/triggers/base.py index 2eb49a434..341cd7c74 100644 --- a/covalent/triggers/base.py +++ b/covalent/triggers/base.py @@ -15,8 +15,6 @@ # limitations under the License. -import asyncio -import json from abc import abstractmethod import requests @@ -108,17 +106,12 @@ def _get_status(self) -> Status: """ if self.use_internal_funcs: - from covalent_dispatcher._service.app import export_result + from covalent_dispatcher._service.app import get_dispatches_bulk - response = asyncio.run_coroutine_threadsafe( - export_result(self.lattice_dispatch_id, status_only=True), - self.event_loop, - ).result() - - if isinstance(response, dict): - return response["status"] - - return json.loads(response.body.decode()).get("status") + response = get_dispatches_bulk( + dispatch_id=[self.lattice_dispatch_id], status_only=True + ) + return response.dispatches[0].status from .. import get_result diff --git a/covalent_dispatcher/_dal/controller.py b/covalent_dispatcher/_dal/controller.py index 3e682b979..53928a2fe 100644 --- a/covalent_dispatcher/_dal/controller.py +++ b/covalent_dispatcher/_dal/controller.py @@ -17,10 +17,12 @@ from __future__ import annotations -from typing import Generic, Type, TypeVar +from typing import Generic, List, Optional, Sequence, Type, TypeVar, Union from sqlalchemy import select, update -from sqlalchemy.orm import Session, load_only +from sqlalchemy.engine import Row +from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import Select, desc from .._db import models @@ -50,11 +52,16 @@ def get( cls, session: Session, *, + stmt: Optional[Select] = None, fields: list, equality_filters: dict, membership_filters: dict, for_update: bool = False, - ): + sort_fields: List[str] = [], + reverse: bool = True, + offset: int = 0, + max_items: Optional[int] = None, + ) -> Union[Sequence[Row], Sequence[T]]: """Bulk ORM-enabled SELECT. Args: @@ -64,19 +71,40 @@ def get( membership_filters: Dict{field_name: value_list} for_update: Whether to lock the selected rows + Returns: + A list of SQLAlchemy Rows or whole ORM entities depending + on whether only a subset of fields is specified. + """ - stmt = select(cls.model) + if stmt is None: + if len(fields) > 0: + entities = [getattr(cls.model, attr) for attr in fields] + stmt = select(*entities) + else: + stmt = select(cls.model) + for attr, val in equality_filters.items(): stmt = stmt.where(getattr(cls.model, attr) == val) for attr, vals in membership_filters.items(): stmt = stmt.where(getattr(cls.model, attr).in_(vals)) - if len(fields) > 0: - attrs = [getattr(cls.model, f) for f in fields] - stmt = stmt.options(load_only(*attrs)) if for_update: stmt = stmt.with_for_update() - - return session.scalars(stmt).all() + for attr in sort_fields: + if reverse: + stmt = stmt.order_by(desc(getattr(cls.model, attr))) + else: + stmt = stmt.order_by(getattr(cls.model, attr)) + + stmt = stmt.offset(offset) + if max_items: + stmt = stmt.limit(max_items) + + if len(fields) == 0: + # Return whole ORM entities + return session.scalars(stmt).all() + else: + # Return a named tuple containing the selected cols + return session.execute(stmt).all() @classmethod def get_by_primary_key( diff --git a/covalent_dispatcher/_dal/result.py b/covalent_dispatcher/_dal/result.py index a9378558c..489efe11d 100644 --- a/covalent_dispatcher/_dal/result.py +++ b/covalent_dispatcher/_dal/result.py @@ -21,6 +21,7 @@ from datetime import datetime from typing import Any, Dict, List +from sqlalchemy import select from sqlalchemy.orm import Session from covalent._shared_files import logger @@ -45,6 +46,41 @@ class ResultMeta(Record[models.Lattice]): model = models.Lattice + @classmethod + def get_toplevel_dispatches( + cls, + session: Session, + *, + fields: list, + equality_filters: dict, + membership_filters: dict, + for_update: bool = False, + sort_fields: List[str] = [], + reverse: bool = True, + offset: int = 0, + max_items: int = 10, + ): + if len(fields) > 0: + entities = [getattr(cls.model, attr) for attr in fields] + stmt = select(*entities) + else: + stmt = select(cls.model) + + stmt = stmt.where(models.Lattice.root_dispatch_id == models.Lattice.dispatch_id) + + return cls.get( + session=session, + stmt=stmt, + fields=fields, + equality_filters=equality_filters, + membership_filters=membership_filters, + for_update=for_update, + sort_fields=sort_fields, + reverse=reverse, + offset=offset, + max_items=max_items, + ) + class ResultAsset(Record[models.LatticeAsset]): model = models.LatticeAsset @@ -175,7 +211,7 @@ def _update_dispatch( with self.session() as session: electron_rec = Electron.get_db_records( session, - keys={"id", "parent_lattice_id"}, + keys=ELECTRON_KEYS, equality_filters={"id": self._electron_id}, membership_filters={}, )[0] @@ -343,7 +379,7 @@ def _get_incomplete_nodes(self): A dictionary {"failed": [node_ids], "cancelled": [node_ids]} """ with self.session() as session: - query_keys = {"parent_lattice_id", "node_id", "name", "status"} + query_keys = {"id", "parent_lattice_id", "node_id", "name", "status"} records = Electron.get_db_records( session, keys=query_keys, diff --git a/covalent_dispatcher/_service/app.py b/covalent_dispatcher/_service/app.py index 03e71186d..3d582a7bd 100644 --- a/covalent_dispatcher/_service/app.py +++ b/covalent_dispatcher/_service/app.py @@ -20,11 +20,11 @@ import asyncio import json from contextlib import asynccontextmanager -from typing import List, Optional, Union -from uuid import UUID +from typing import List, Union -from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi import APIRouter, FastAPI, HTTPException, Query from fastapi.responses import JSONResponse +from typing_extensions import Annotated import covalent_dispatcher.entry_point as dispatcher from covalent._shared_files import logger @@ -38,7 +38,13 @@ from .._db.datastore import workflow_db from .._db.dispatchdb import DispatchDB from .heartbeat import Heartbeat -from .models import DispatchStatusSetSchema, ExportResponseSchema, TargetDispatchStatus +from .models import ( + BulkDispatchGetSchema, + BulkGetMetadata, + DispatchStatusSetSchema, + DispatchSummary, + TargetDispatchStatus, +) app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -98,31 +104,6 @@ async def cancel_all_with_status(status: RESULT_STATUS): await dispatcher.cancel_running_dispatch(dispatch_id) -@router.post("/dispatches/submit") -async def submit(request: Request) -> UUID: - """ - Function to accept the submit request of - new dispatch and return the dispatch id - back to the client. - - Args: - None - - Returns: - dispatch_id: The dispatch id in a json format - returned as a Fast API Response object. - """ - try: - data = await request.json() - data = json.dumps(data).encode("utf-8") - return await dispatcher.make_dispatch(data) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Failed to submit workflow: {e}", - ) from e - - async def start(dispatch_id: str): """Start a previously registered (re-)dispatch. @@ -266,74 +247,74 @@ async def set_dispatch_status(dispatch_id: str, desired_status: DispatchStatusSe return await cancel(dispatch_id, desired_status.task_ids) -@router.get("/dispatches/{dispatch_id}") -async def export_result( - dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False -) -> ExportResponseSchema: - """Export all metadata about a registered dispatch +@router.get("/dispatches", response_model_exclude_unset=True) +def get_dispatches_bulk( + dispatch_id: Annotated[Union[List[str], None], Query()] = None, + status: Annotated[Union[List[str], None], Query()] = None, + latest: bool = True, + offset: int = 0, + count: int = 10, + status_only: bool = False, +) -> BulkDispatchGetSchema: + dispatch_controller = Result.meta_type - Args: - `dispatch_id`: The dispatch's unique id. - - Returns: - { - id: `dispatch_id`, - status: status, - result_export: manifest for the result - } + if status_only: + fields = ["dispatch_id", "status"] + else: + fields = [ + "dispatch_id", + "root_dispatch_id", + "status", + "name", + "electron_num", + "completed_electron_num", + "created_at", + "updated_at", + "completed_at", + ] + + summaries = [] + with workflow_db.session() as session: + in_filters = {} + if dispatch_id is not None: + in_filters["dispatch_id"] = dispatch_id + if status is not None: + in_filters["status"] = status - The manifest `result_export` has the same schema as that which is - submitted to `/register`. + results = dispatch_controller.get( + session, + fields=fields, + equality_filters={"is_active": True}, + membership_filters=in_filters, + sort_fields=["created_at"], + reverse=latest, + offset=offset, + max_items=count, + ) + for res in results: + dispatch_id = res.dispatch_id + summary = DispatchSummary.model_validate(res) + summaries.append(summary) - """ - loop = asyncio.get_running_loop() - return await loop.run_in_executor( - None, - _export_result_sync, - dispatch_id, - wait, - status_only, - ) + bulk_meta = BulkGetMetadata(count=len(results)) + return BulkDispatchGetSchema(dispatches=summaries, metadata=bulk_meta) -def _export_result_sync( - dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False -) -> ExportResponseSchema: +@router.get("/dispatches/{dispatch_id}") +def export_manifest(dispatch_id: str) -> ResultSchema: result_object = _try_get_result_object(dispatch_id) if not result_object: return JSONResponse( status_code=404, content={"message": f"The requested dispatch ID {dispatch_id} was not found."}, ) - status = str(result_object.get_value("status", refresh=False)) - if not wait or status in [ - str(RESULT_STATUS.COMPLETED), - str(RESULT_STATUS.FAILED), - str(RESULT_STATUS.CANCELLED), - ]: - output = { - "id": dispatch_id, - "status": status, - } - if not status_only: - output["result_export"] = export_result_manifest(dispatch_id) - - return output - - response = JSONResponse( - status_code=503, - content={"message": "Result not ready to read yet. Please wait for a couple of seconds."}, - headers={"Retry-After": "2"}, - ) - return response + return export_result_manifest(dispatch_id) def _try_get_result_object(dispatch_id: str) -> Union[Result, None]: try: - res = get_result_object( - dispatch_id, bare=True, keys=["id", "dispatch_id", "status"], lattice_keys=["id"] - ) + res = get_result_object(dispatch_id, bare=True) except KeyError: res = None return res diff --git a/covalent_dispatcher/_service/models.py b/covalent_dispatcher/_service/models.py index 18a33a071..f751c1bc2 100644 --- a/covalent_dispatcher/_service/models.py +++ b/covalent_dispatcher/_service/models.py @@ -17,12 +17,13 @@ """FastAPI models for /api/v1/resultv2 endpoints""" import re +from datetime import datetime from enum import Enum from typing import List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict -from covalent._shared_files.schemas.result import ResultSchema +from covalent._shared_files.schemas.result import StatusEnum range_regex = "bytes=([0-9]+)-([0-9]*)" range_pattern = re.compile(range_regex) @@ -56,12 +57,6 @@ class ElectronAssetKey(str, Enum): qelectron_db = "qelectron_db" -class ExportResponseSchema(BaseModel): - id: str - status: str - result_export: Optional[ResultSchema] = None - - class AssetRepresentation(str, Enum): string = "string" b64pickle = "object" @@ -78,3 +73,26 @@ class DispatchStatusSetSchema(BaseModel): # For cancellation, an optional list of task ids to cancel task_ids: Optional[List] = [] + + +class BulkGetMetadata(BaseModel): + count: int + + +class DispatchSummary(BaseModel): + model_config = ConfigDict(from_attributes=True) + + dispatch_id: str + root_dispatch_id: Optional[str] = None + status: StatusEnum + name: Optional[str] = None + electron_num: Optional[int] = None + completed_electron_num: Optional[int] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + +class BulkDispatchGetSchema(BaseModel): + dispatches: List[DispatchSummary] + metadata: BulkGetMetadata diff --git a/tests/covalent_dispatcher_tests/_dal/result_test.py b/tests/covalent_dispatcher_tests/_dal/result_test.py index 5b2ec19fa..cb858ebb5 100644 --- a/tests/covalent_dispatcher_tests/_dal/result_test.py +++ b/tests/covalent_dispatcher_tests/_dal/result_test.py @@ -551,3 +551,84 @@ def test_result_filters_parent_electron_updates(test_db, mocker): assert third_update assert subl_node.get_value("output").get_deserialized() == 42 + + +def test_result_controller_bulk_get(test_db, mocker): + record_1 = models.Lattice( + dispatch_id="dispatch_1", + root_dispatch_id="dispatch_1", + name="dispatch_1", + status="NEW_OBJECT", + electron_num=5, + completed_electron_num=0, + ) + + record_2 = models.Lattice( + dispatch_id="dispatch_2", + root_dispatch_id="dispatch_2", + name="dispatch_2", + status="NEW_OBJECT", + electron_num=25, + completed_electron_num=0, + ) + + record_3 = models.Lattice( + dispatch_id="dispatch_3", + root_dispatch_id="dispatch_2", + name="dispatch_3", + status="COMPLETED", + electron_num=25, + completed_electron_num=25, + ) + + with test_db.session() as session: + session.add(record_1) + session.add(record_2) + session.add(record_3) + session.commit() + + dispatch_controller = Result.meta_type + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + ) + assert len(results) == 3 + + with test_db.session() as session: + results = dispatch_controller.get_toplevel_dispatches( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + ) + assert len(results) == 2 + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + sort_fields=["name"], + reverse=False, + max_items=1, + ) + assert len(results) == 1 + assert results[0].dispatch_id == "dispatch_1" + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + sort_fields=["name"], + max_items=2, + offset=1, + ) + assert len(results) == 2 + assert results[0].dispatch_id == "dispatch_2" diff --git a/tests/covalent_dispatcher_tests/_service/app_test.py b/tests/covalent_dispatcher_tests/_service/app_test.py index 4615e35c5..d02382b98 100644 --- a/tests/covalent_dispatcher_tests/_service/app_test.py +++ b/tests/covalent_dispatcher_tests/_service/app_test.py @@ -103,30 +103,6 @@ def test_db_file(): return MockDataStore(db_URL="sqlite+pysqlite:////tmp/testdb.sqlite") -@pytest.mark.asyncio -async def test_submit(mocker, client): - """Test the submit endpoint.""" - mock_data = json.dumps({}).encode("utf-8") - run_dispatcher_mock = mocker.patch( - "covalent_dispatcher.entry_point.make_dispatch", return_value=DISPATCH_ID - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - response = client.post("/api/v2/dispatches/submit", data=mock_data) - assert response.json() == DISPATCH_ID - run_dispatcher_mock.assert_called_once_with(mock_data) - - -@pytest.mark.asyncio -async def test_submit_exception(mocker, client): - """Test the submit endpoint.""" - mock_data = json.dumps({}).encode("utf-8") - mocker.patch("covalent_dispatcher.entry_point.make_dispatch", side_effect=Exception("mock")) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - response = client.post("/api/v2/dispatches/submit", data=mock_data) - assert response.status_code == 400 - assert response.json()["detail"] == "Failed to submit workflow: mock" - - def test_cancel_dispatch(mocker, app, client): """ Test cancelling dispatch @@ -262,8 +238,8 @@ def test_start(mocker, app, client): assert resp.json() == dispatch_id -def test_export_result_nowait(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" +def test_export_manifest(mocker, app, client, mock_manifest): + dispatch_id = "test_export_manifest" mock_result_object = MagicMock() mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) mocker.patch( @@ -274,29 +250,11 @@ def test_export_result_nowait(mocker, app, client, mock_manifest): ) mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") resp = client.get(f"/api/v2/dispatches/{dispatch_id}") - assert resp.status_code == 200 - assert resp.json()["id"] == dispatch_id - assert resp.json()["status"] == str(RESULT_STATUS.NEW_OBJECT) - assert resp.json()["result_export"] == json.loads(mock_manifest.json()) - - -def test_export_result_wait_not_ready(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" - mock_result_object = MagicMock() - mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.RUNNING)) - mocker.patch( - "covalent_dispatcher._service.app._try_get_result_object", return_value=mock_result_object - ) - mock_export = mocker.patch( - "covalent_dispatcher._service.app.export_result_manifest", return_value=mock_manifest - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - resp = client.get(f"/api/v2/dispatches/{dispatch_id}", params={"wait": True}) - assert resp.status_code == 503 + assert resp.json() == json.loads(mock_manifest.json()) -def test_export_result_bad_dispatch_id(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" +def test_export_manifest_bad_dispatch_id(mocker, app, client, mock_manifest): + dispatch_id = "test_export_manifest" mock_result_object = MagicMock() mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) mocker.patch("covalent_dispatcher._service.app._try_get_result_object", return_value=None) diff --git a/tests/covalent_tests/dispatcher_plugins/local_test.py b/tests/covalent_tests/dispatcher_plugins/local_test.py index d3c09c316..e10c83d31 100644 --- a/tests/covalent_tests/dispatcher_plugins/local_test.py +++ b/tests/covalent_tests/dispatcher_plugins/local_test.py @@ -74,42 +74,6 @@ def workflow(a, b): mock_print.assert_called_once_with(message) -def test_dispatcher_dispatch_single(mocker): - """test dispatching a lattice with submit api""" - - @ct.electron - def task(a, b, c): - return a + b + c - - @ct.lattice - def workflow(a, b): - return task(a, b, c=4) - - # test when api raises an implicit error - - dispatch_id = "test_dispatcher_dispatch_single" - # multistage = False - mocker.patch("covalent._dispatcher_plugins.local.get_config", return_value=False) - - mock_submit_callable = MagicMock(return_value=dispatch_id) - mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.submit", - return_value=mock_submit_callable, - ) - - mock_reg_tr = mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" - ) - mock_start = mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.start", return_value=dispatch_id - ) - - assert dispatch_id == LocalDispatcher.dispatch(workflow)(1, 2) - - mock_submit_callable.assert_called() - mock_start.assert_called() - - def test_dispatcher_dispatch_multi(mocker): """test dispatching a lattice with multistage api""" @@ -131,12 +95,6 @@ def workflow(a, b): return_value=mock_register_callable, ) - mock_submit_callable = MagicMock(return_value=dispatch_id) - mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.submit", - return_value=mock_submit_callable, - ) - mock_reg_tr = mocker.patch( "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" ) @@ -146,7 +104,6 @@ def workflow(a, b): assert dispatch_id == LocalDispatcher.dispatch(workflow)(1, 2) - mock_submit_callable.assert_not_called() mock_register_callable.assert_called() mock_start.assert_called() @@ -172,12 +129,6 @@ def workflow(a, b): return_value=mock_register_callable, ) - mock_submit_callable = MagicMock(return_value=dispatch_id) - mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.submit", - return_value=mock_submit_callable, - ) - mock_reg_tr = mocker.patch( "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" ) @@ -190,41 +141,6 @@ def workflow(a, b): mock_start.assert_not_called() -def test_dispatcher_submit_api(mocker): - """test dispatching a lattice with submit api""" - - @ct.electron - def task(a, b, c): - return a + b + c - - @ct.lattice - def workflow(a, b): - return task(a, b, c=4) - - # test when api raises an implicit error - r = Response() - r.status_code = 404 - r.url = "http://dummy" - r.reason = "dummy reason" - - mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) - - with pytest.raises(HTTPError, match="404 Client Error: dummy reason for url: http://dummy"): - dispatch_id = LocalDispatcher.submit(workflow)(1, 2) - assert dispatch_id is None - - # test when api doesn't raise an implicit error - r = Response() - r.status_code = 201 - r.url = "http://dummy" - r._content = b"abcde" - - mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) - - dispatch_id = LocalDispatcher.submit(workflow)(1, 2) - assert dispatch_id == "abcde" - - def test_dispatcher_start(mocker): """Test starting a dispatch""" diff --git a/tests/covalent_tests/results_manager_tests/results_manager_test.py b/tests/covalent_tests/results_manager_tests/results_manager_test.py index 06ba8ece0..72c0260d9 100644 --- a/tests/covalent_tests/results_manager_tests/results_manager_test.py +++ b/tests/covalent_tests/results_manager_tests/results_manager_test.py @@ -22,9 +22,9 @@ from unittest.mock import MagicMock import pytest -from requests import Response import covalent as ct +from covalent._api.apiclient import CovalentAPIClient from covalent._results_manager.results_manager import ( MissingLatticeRecordError, Result, @@ -105,24 +105,23 @@ def test_cancel_with_multiple_task_ids(mocker): def test_result_export(mocker): + import json + with tempfile.TemporaryDirectory() as staging_dir: test_manifest = get_test_manifest(staging_dir) dispatch_id = "test_result_export" - mock_body = {"id": "test_result_export", "status": "COMPLETED"} - + mock_response_body = json.loads(test_manifest.model_dump_json()) mock_client = MagicMock() - mock_response = Response() + mock_response = MagicMock() mock_response.status_code = 200 - mock_response.json = MagicMock(return_value=mock_body) + mock_response.json = MagicMock(return_value=mock_response_body) mocker.patch("covalent._api.apiclient.requests.Session.get", return_value=mock_response) - + apiclient = CovalentAPIClient("http://localhost:48008") endpoint = f"/api/v2/dispatches/{dispatch_id}" - assert mock_body == _get_result_export_from_dispatcher( - dispatch_id, wait=False, status_only=True - ) + assert test_manifest == _get_result_export_from_dispatcher(dispatch_id, apiclient) def test_result_manager_assets_local_copies(): @@ -176,11 +175,7 @@ def test_get_result(mocker): # local file copies from server_dir to results_dir. manifest = get_test_manifest(server_dir) - mock_result_export = { - "id": dispatch_id, - "status": "COMPLETED", - "result_export": manifest.dict(), - } + mock_result_export = manifest mocker.patch( "covalent._results_manager.results_manager._get_result_export_from_dispatcher", return_value=mock_result_export, @@ -208,17 +203,9 @@ def test_get_result_sublattice(mocker): # Sublattice manifest sub_manifest = get_test_manifest(server_dir_sub) - mock_result_export = { - "id": dispatch_id, - "status": "COMPLETED", - "result_export": manifest.dict(), - } + mock_result_export = manifest - mock_subresult_export = { - "id": sub_dispatch_id, - "status": "COMPLETED", - "result_export": sub_manifest.dict(), - } + mock_subresult_export = sub_manifest exports = {dispatch_id: mock_result_export, sub_dispatch_id: mock_subresult_export} @@ -277,10 +264,10 @@ def test_get_result_RecursionError(mocker): def test_get_status_only(mocker): """Check get_result when status_only=True""" - dispatch_id = "test_get_result_st" + dispatch_id = "test_get_status_only" mock_get_result_export = mocker.patch( - "covalent._results_manager.results_manager._get_result_export_from_dispatcher", - return_value={"id": dispatch_id, "status": "RUNNING"}, + "covalent._results_manager.results_manager._query_dispatch_status", + return_value="RUNNING", ) status_report = get_result(dispatch_id, status_only=True) diff --git a/tests/covalent_tests/triggers/base_test.py b/tests/covalent_tests/triggers/base_test.py index 0ca0c8d7c..b70aee295 100644 --- a/tests/covalent_tests/triggers/base_test.py +++ b/tests/covalent_tests/triggers/base_test.py @@ -17,7 +17,6 @@ from unittest import mock import pytest -from fastapi.responses import JSONResponse from covalent.triggers import BaseTrigger @@ -46,7 +45,6 @@ def test_register(mocker): @pytest.mark.parametrize( "use_internal_func, mock_status", [ - (True, JSONResponse("mock")), (True, {"status": "mocked-status"}), (False, {"status": "mocked-status"}), ], @@ -61,27 +59,15 @@ def test_get_status(mocker, use_internal_func, mock_status): base_trigger.use_internal_funcs = use_internal_func if use_internal_func: - mocker.patch("covalent_dispatcher._service.app.export_result") - - mock_fut_res = mock.Mock() - mock_fut_res.result.return_value = mock_status - mock_run_coro = mocker.patch( - "covalent.triggers.base.asyncio.run_coroutine_threadsafe", return_value=mock_fut_res + mock_bulk_get_res = mock.Mock() + mock_bulk_get_res.dispatches = [mock.Mock()] + mock_bulk_get_res.dispatches[0].status = mock_status["status"] + mocker.patch( + "covalent_dispatcher._service.app.get_dispatches_bulk", return_value=mock_bulk_get_res ) - if not isinstance(mock_status, dict): - mock_json_loads = mocker.patch( - "covalent.triggers.base.json.loads", return_value={"status": "mocked-status"} - ) - status = base_trigger._get_status() - mock_run_coro.assert_called_once() - mock_fut_res.result.assert_called_once() - - if not isinstance(mock_status, dict): - mock_json_loads.assert_called_once() - else: mock_get_status = mocker.patch("covalent.get_result", return_value=mock_status) status = base_trigger._get_status()