From f98d1eb49555eb77a0f1690ac90345ee81d8e6c2 Mon Sep 17 00:00:00 2001 From: Casey Jao Date: Wed, 15 May 2024 21:18:37 -0400 Subject: [PATCH] API: refactor and fix `get_result(wait=True)` The previous `GET /dispatches/{dispatch_id}` endpoint was trying to do too much. Its responsibilities are now separated into two endpoints: * `GET /dispatches`: bulk query dispatch summaries (including status) with options to filter by `dispatch_id`, sort chronologically, and also limit the output to status only. * `GET /dispatches/{dispatch_id}`: download manifest To achieve the desired behavior of `get_result(id, wait=True)`, the client 1. Polls the dispatch status by querying the first endpoint. 2. Downloads the manifest after the dispatch has reached a final status. The server no longer returns 503 errors when the dispatch is not yet "ready". A 503 status code is not entirely accurate here because it is intended to convey temporary service unavailablity resulting from server overload or rate limiting. However, the fact that the workflow is still running does not indicate any fault of the server. These changes will allow `get_result(dispatch_id, wait=True)` to wait as long as required instead of erroring out after some time. Supporting improvements: DAL: Add sorting and pagination to Controller DAL: improve bulk get when retrieving only some columns Directly select the specified columns instead of retrieving the whole ORM entities and deferring column loading using load_only --- CHANGELOG.md | 1 + covalent/_results_manager/results_manager.py | 100 +++++++------- covalent/triggers/base.py | 17 +-- covalent_dispatcher/_dal/controller.py | 46 +++++-- covalent_dispatcher/_dal/result.py | 40 +++++- covalent_dispatcher/_service/app.py | 123 ++++++++++-------- covalent_dispatcher/_service/models.py | 50 +++++-- .../_dal/result_test.py | 81 ++++++++++++ .../_service/app_test.py | 28 +--- .../results_manager_test.py | 41 ++---- tests/covalent_tests/triggers/base_test.py | 24 +--- 11 files changed, 347 insertions(+), 204 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9411e401d..0042dcd2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Improved handling of Covalent version mismatches between client and executor environments +- `get_result(wait=True)` will wait as long as needed ### Removed diff --git a/covalent/_results_manager/results_manager.py b/covalent/_results_manager/results_manager.py index 4c751206a..75ba88633 100644 --- a/covalent/_results_manager/results_manager.py +++ b/covalent/_results_manager/results_manager.py @@ -19,12 +19,11 @@ import contextlib import os +import time from pathlib import Path -from typing import Dict, List, Optional +from typing import List, Optional from furl import furl -from requests.adapters import HTTPAdapter -from urllib3.util import Retry from .._api.apiclient import CovalentAPIClient from .._serialize.common import load_asset @@ -40,9 +39,9 @@ from .._shared_files.exceptions import MissingLatticeRecordError from .._shared_files.schemas.asset import AssetSchema from .._shared_files.schemas.result import ResultSchema +from .._shared_files.util_classes import RESULT_STATUS, Status from .._shared_files.utils import copy_file_locally, format_server_url from .result import Result -from .wait import EXTREME app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -139,12 +138,20 @@ def cancel(dispatch_id: str, task_ids: List[int] = None, dispatcher_addr: str = # Multi-part +def _query_dispatch_status(dispatch_id: str, api_client: CovalentAPIClient): + endpoint = "/api/v2/dispatches" + resp = api_client.get(endpoint, params={"dispatch_id": dispatch_id, "status_only": True}) + resp.raise_for_status() + dispatches = resp.json()["dispatches"] + if len(dispatches) == 0: + raise MissingLatticeRecordError + + return dispatches[0]["status"] + + def _get_result_export_from_dispatcher( - dispatch_id: str, - wait: bool = False, - status_only: bool = False, - dispatcher_addr: str = None, -) -> Dict: + dispatch_id: str, api_client: CovalentAPIClient +) -> ResultSchema: """ Internal function to get the results of a dispatch from the server without checking if it is ready to read. @@ -161,24 +168,21 @@ def _get_result_export_from_dispatcher( MissingLatticeRecordError: If the result is not found. """ - if dispatcher_addr is None: - dispatcher_addr = format_server_url() + # if dispatcher_addr is None: + # dispatcher_addr = format_server_url() - retries = int(EXTREME) if wait else 5 + # retries = int(EXTREME) if wait else 5 - adapter = HTTPAdapter(max_retries=Retry(total=retries, backoff_factor=1)) - api_client = CovalentAPIClient(dispatcher_addr, adapter=adapter, auto_raise=False) + # adapter = HTTPAdapter(max_retries=Retry(total=retries, backoff_factor=1)) + # api_client = CovalentAPIClient(dispatcher_addr, adapter=adapter, auto_raise=False) endpoint = f"/api/v2/dispatches/{dispatch_id}" - response = api_client.get( - endpoint, - params={"wait": wait, "status_only": status_only}, - ) + response = api_client.get(endpoint) if response.status_code == 404: raise MissingLatticeRecordError response.raise_for_status() export = response.json() - return export + return ResultSchema.model_validate(export) # Function to download default assets @@ -346,11 +350,17 @@ def from_dispatch_id( wait: bool = False, dispatcher_addr: str = None, ) -> "ResultManager": - export = _get_result_export_from_dispatcher( - dispatch_id, wait, status_only=False, dispatcher_addr=dispatcher_addr - ) + if dispatcher_addr is None: + dispatcher_addr = format_server_url() - manifest = ResultSchema.model_validate(export["result_export"]) + api_client = CovalentAPIClient(dispatcher_addr) + if wait: + status = Status(_query_dispatch_status(dispatch_id, api_client)) + while not RESULT_STATUS.is_terminal(status): + time.sleep(1) + status = Status(_query_dispatch_status(dispatch_id, api_client)) + + manifest = _get_result_export_from_dispatcher(dispatch_id, api_client) # sort the nodes manifest.lattice.transport_graph.nodes.sort(key=lambda x: x.id) @@ -408,14 +418,15 @@ def _get_result_multistage( """ + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + api_client = CovalentAPIClient(dispatcher_addr) try: if status_only: - return _get_result_export_from_dispatcher( - dispatch_id=dispatch_id, - wait=wait, - status_only=status_only, - dispatcher_addr=dispatcher_addr, - ) + status = _query_dispatch_status(dispatch_id, api_client) + return {"id": dispatch_id, "status": status} + rm = get_result_manager(dispatch_id, results_dir, wait, dispatcher_addr) _get_default_assets(rm) @@ -496,23 +507,14 @@ def get_result( The Result object from the Covalent server """ - max_attempts = int(os.getenv("COVALENT_GET_RESULT_RETRIES", 10)) - num_attempts = 0 - while num_attempts < max_attempts: - try: - return _get_result_multistage( - dispatch_id=dispatch_id, - wait=wait, - dispatcher_addr=dispatcher_addr, - status_only=status_only, - results_dir=results_dir, - workflow_output=workflow_output, - intermediate_outputs=intermediate_outputs, - sublattice_results=sublattice_results, - qelectron_db=qelectron_db, - ) - - except RecursionError as re: - app_log.error(re) - num_attempts += 1 - raise RuntimeError("Timed out waiting for result. Please retry or check dispatch.") + return _get_result_multistage( + dispatch_id=dispatch_id, + wait=wait, + dispatcher_addr=dispatcher_addr, + status_only=status_only, + results_dir=results_dir, + workflow_output=workflow_output, + intermediate_outputs=intermediate_outputs, + sublattice_results=sublattice_results, + qelectron_db=qelectron_db, + ) diff --git a/covalent/triggers/base.py b/covalent/triggers/base.py index 2eb49a434..341cd7c74 100644 --- a/covalent/triggers/base.py +++ b/covalent/triggers/base.py @@ -15,8 +15,6 @@ # limitations under the License. -import asyncio -import json from abc import abstractmethod import requests @@ -108,17 +106,12 @@ def _get_status(self) -> Status: """ if self.use_internal_funcs: - from covalent_dispatcher._service.app import export_result + from covalent_dispatcher._service.app import get_dispatches_bulk - response = asyncio.run_coroutine_threadsafe( - export_result(self.lattice_dispatch_id, status_only=True), - self.event_loop, - ).result() - - if isinstance(response, dict): - return response["status"] - - return json.loads(response.body.decode()).get("status") + response = get_dispatches_bulk( + dispatch_id=[self.lattice_dispatch_id], status_only=True + ) + return response.dispatches[0].status from .. import get_result diff --git a/covalent_dispatcher/_dal/controller.py b/covalent_dispatcher/_dal/controller.py index 3e682b979..53928a2fe 100644 --- a/covalent_dispatcher/_dal/controller.py +++ b/covalent_dispatcher/_dal/controller.py @@ -17,10 +17,12 @@ from __future__ import annotations -from typing import Generic, Type, TypeVar +from typing import Generic, List, Optional, Sequence, Type, TypeVar, Union from sqlalchemy import select, update -from sqlalchemy.orm import Session, load_only +from sqlalchemy.engine import Row +from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import Select, desc from .._db import models @@ -50,11 +52,16 @@ def get( cls, session: Session, *, + stmt: Optional[Select] = None, fields: list, equality_filters: dict, membership_filters: dict, for_update: bool = False, - ): + sort_fields: List[str] = [], + reverse: bool = True, + offset: int = 0, + max_items: Optional[int] = None, + ) -> Union[Sequence[Row], Sequence[T]]: """Bulk ORM-enabled SELECT. Args: @@ -64,19 +71,40 @@ def get( membership_filters: Dict{field_name: value_list} for_update: Whether to lock the selected rows + Returns: + A list of SQLAlchemy Rows or whole ORM entities depending + on whether only a subset of fields is specified. + """ - stmt = select(cls.model) + if stmt is None: + if len(fields) > 0: + entities = [getattr(cls.model, attr) for attr in fields] + stmt = select(*entities) + else: + stmt = select(cls.model) + for attr, val in equality_filters.items(): stmt = stmt.where(getattr(cls.model, attr) == val) for attr, vals in membership_filters.items(): stmt = stmt.where(getattr(cls.model, attr).in_(vals)) - if len(fields) > 0: - attrs = [getattr(cls.model, f) for f in fields] - stmt = stmt.options(load_only(*attrs)) if for_update: stmt = stmt.with_for_update() - - return session.scalars(stmt).all() + for attr in sort_fields: + if reverse: + stmt = stmt.order_by(desc(getattr(cls.model, attr))) + else: + stmt = stmt.order_by(getattr(cls.model, attr)) + + stmt = stmt.offset(offset) + if max_items: + stmt = stmt.limit(max_items) + + if len(fields) == 0: + # Return whole ORM entities + return session.scalars(stmt).all() + else: + # Return a named tuple containing the selected cols + return session.execute(stmt).all() @classmethod def get_by_primary_key( diff --git a/covalent_dispatcher/_dal/result.py b/covalent_dispatcher/_dal/result.py index a9378558c..489efe11d 100644 --- a/covalent_dispatcher/_dal/result.py +++ b/covalent_dispatcher/_dal/result.py @@ -21,6 +21,7 @@ from datetime import datetime from typing import Any, Dict, List +from sqlalchemy import select from sqlalchemy.orm import Session from covalent._shared_files import logger @@ -45,6 +46,41 @@ class ResultMeta(Record[models.Lattice]): model = models.Lattice + @classmethod + def get_toplevel_dispatches( + cls, + session: Session, + *, + fields: list, + equality_filters: dict, + membership_filters: dict, + for_update: bool = False, + sort_fields: List[str] = [], + reverse: bool = True, + offset: int = 0, + max_items: int = 10, + ): + if len(fields) > 0: + entities = [getattr(cls.model, attr) for attr in fields] + stmt = select(*entities) + else: + stmt = select(cls.model) + + stmt = stmt.where(models.Lattice.root_dispatch_id == models.Lattice.dispatch_id) + + return cls.get( + session=session, + stmt=stmt, + fields=fields, + equality_filters=equality_filters, + membership_filters=membership_filters, + for_update=for_update, + sort_fields=sort_fields, + reverse=reverse, + offset=offset, + max_items=max_items, + ) + class ResultAsset(Record[models.LatticeAsset]): model = models.LatticeAsset @@ -175,7 +211,7 @@ def _update_dispatch( with self.session() as session: electron_rec = Electron.get_db_records( session, - keys={"id", "parent_lattice_id"}, + keys=ELECTRON_KEYS, equality_filters={"id": self._electron_id}, membership_filters={}, )[0] @@ -343,7 +379,7 @@ def _get_incomplete_nodes(self): A dictionary {"failed": [node_ids], "cancelled": [node_ids]} """ with self.session() as session: - query_keys = {"parent_lattice_id", "node_id", "name", "status"} + query_keys = {"id", "parent_lattice_id", "node_id", "name", "status"} records = Electron.get_db_records( session, keys=query_keys, diff --git a/covalent_dispatcher/_service/app.py b/covalent_dispatcher/_service/app.py index 03e71186d..3ebc7da8a 100644 --- a/covalent_dispatcher/_service/app.py +++ b/covalent_dispatcher/_service/app.py @@ -20,11 +20,12 @@ import asyncio import json from contextlib import asynccontextmanager -from typing import List, Optional, Union +from typing import List, Union from uuid import UUID -from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi import APIRouter, FastAPI, HTTPException, Query, Request from fastapi.responses import JSONResponse +from typing_extensions import Annotated import covalent_dispatcher.entry_point as dispatcher from covalent._shared_files import logger @@ -38,7 +39,16 @@ from .._db.datastore import workflow_db from .._db.dispatchdb import DispatchDB from .heartbeat import Heartbeat -from .models import DispatchStatusSetSchema, ExportResponseSchema, TargetDispatchStatus +from .models import ( + BulkDispatchGetSchema, + BulkGetMetadata, + Cursor, + DispatchLinks, + DispatchStatusSetSchema, + DispatchSummary, + Link, + TargetDispatchStatus, +) app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -266,74 +276,77 @@ async def set_dispatch_status(dispatch_id: str, desired_status: DispatchStatusSe return await cancel(dispatch_id, desired_status.task_ids) -@router.get("/dispatches/{dispatch_id}") -async def export_result( - dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False -) -> ExportResponseSchema: - """Export all metadata about a registered dispatch - - Args: - `dispatch_id`: The dispatch's unique id. +@router.get("/dispatches", response_model_exclude_unset=True) +def get_dispatches_bulk( + dispatch_id: Annotated[Union[List[str], None], Query()] = None, + status: Annotated[Union[List[str], None], Query()] = None, + latest: bool = True, + offset: int = 0, + count: int = 10, + status_only: bool = False, +) -> BulkDispatchGetSchema: + dispatch_controller = Result.meta_type - Returns: - { - id: `dispatch_id`, - status: status, - result_export: manifest for the result - } + if status_only: + fields = ["dispatch_id", "status"] + else: + fields = [ + "dispatch_id", + "root_dispatch_id", + "status", + "name", + "electron_num", + "completed_electron_num", + "created_at", + "updated_at", + "completed_at", + ] + + summaries = [] + with workflow_db.session() as session: + in_filters = {} + if dispatch_id is not None: + in_filters["dispatch_id"] = dispatch_id + if status is not None: + in_filters["status"] = status - The manifest `result_export` has the same schema as that which is - submitted to `/register`. + results = dispatch_controller.get( + session, + fields=fields, + equality_filters={"is_active": True}, + membership_filters=in_filters, + sort_fields=["created_at"], + reverse=latest, + offset=offset, + max_items=count, + ) + for res in results: + dispatch_id = res.dispatch_id + summary = DispatchSummary.model_validate(res) + if not status_only: + links = DispatchLinks(manifest=Link(href=f"./{dispatch_id}")) + summary.links = links + summaries.append(summary) - """ - loop = asyncio.get_running_loop() - return await loop.run_in_executor( - None, - _export_result_sync, - dispatch_id, - wait, - status_only, - ) + bulk_meta = BulkGetMetadata(count=len(results), links=Cursor()) + return BulkDispatchGetSchema(dispatches=summaries, metadata=bulk_meta) -def _export_result_sync( - dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False -) -> ExportResponseSchema: +@router.get("/dispatches/{dispatch_id}") +def export_manifest(dispatch_id: str) -> ResultSchema: result_object = _try_get_result_object(dispatch_id) if not result_object: return JSONResponse( status_code=404, content={"message": f"The requested dispatch ID {dispatch_id} was not found."}, ) - status = str(result_object.get_value("status", refresh=False)) - if not wait or status in [ - str(RESULT_STATUS.COMPLETED), - str(RESULT_STATUS.FAILED), - str(RESULT_STATUS.CANCELLED), - ]: - output = { - "id": dispatch_id, - "status": status, - } - if not status_only: - output["result_export"] = export_result_manifest(dispatch_id) - - return output - - response = JSONResponse( - status_code=503, - content={"message": "Result not ready to read yet. Please wait for a couple of seconds."}, - headers={"Retry-After": "2"}, - ) - return response + return export_result_manifest(dispatch_id) def _try_get_result_object(dispatch_id: str) -> Union[Result, None]: try: - res = get_result_object( - dispatch_id, bare=True, keys=["id", "dispatch_id", "status"], lattice_keys=["id"] - ) + res = get_result_object(dispatch_id, bare=True) except KeyError: res = None return res diff --git a/covalent_dispatcher/_service/models.py b/covalent_dispatcher/_service/models.py index 18a33a071..530880b08 100644 --- a/covalent_dispatcher/_service/models.py +++ b/covalent_dispatcher/_service/models.py @@ -17,12 +17,13 @@ """FastAPI models for /api/v1/resultv2 endpoints""" import re +from datetime import datetime from enum import Enum from typing import List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict -from covalent._shared_files.schemas.result import ResultSchema +from covalent._shared_files.schemas.result import StatusEnum range_regex = "bytes=([0-9]+)-([0-9]*)" range_pattern = re.compile(range_regex) @@ -56,12 +57,6 @@ class ElectronAssetKey(str, Enum): qelectron_db = "qelectron_db" -class ExportResponseSchema(BaseModel): - id: str - status: str - result_export: Optional[ResultSchema] = None - - class AssetRepresentation(str, Enum): string = "string" b64pickle = "object" @@ -78,3 +73,42 @@ class DispatchStatusSetSchema(BaseModel): # For cancellation, an optional list of task ids to cancel task_ids: Optional[List] = [] + + +class Link(BaseModel): + href: str + + +class Cursor(BaseModel): + back: Optional[Link] = None + forward: Optional[Link] = None + + +class BulkGetMetadata(BaseModel): + count: int + links: Cursor + + +class DispatchLinks(BaseModel): + manifest: Link + + +class DispatchSummary(BaseModel): + model_config = ConfigDict(from_attributes=True) + + dispatch_id: str + root_dispatch_id: Optional[str] = None + status: StatusEnum + name: Optional[str] = None + electron_num: Optional[int] = None + completed_electron_num: Optional[int] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + links: Optional[DispatchLinks] = None + + +class BulkDispatchGetSchema(BaseModel): + dispatches: List[DispatchSummary] + metadata: BulkGetMetadata diff --git a/tests/covalent_dispatcher_tests/_dal/result_test.py b/tests/covalent_dispatcher_tests/_dal/result_test.py index 5b2ec19fa..cb858ebb5 100644 --- a/tests/covalent_dispatcher_tests/_dal/result_test.py +++ b/tests/covalent_dispatcher_tests/_dal/result_test.py @@ -551,3 +551,84 @@ def test_result_filters_parent_electron_updates(test_db, mocker): assert third_update assert subl_node.get_value("output").get_deserialized() == 42 + + +def test_result_controller_bulk_get(test_db, mocker): + record_1 = models.Lattice( + dispatch_id="dispatch_1", + root_dispatch_id="dispatch_1", + name="dispatch_1", + status="NEW_OBJECT", + electron_num=5, + completed_electron_num=0, + ) + + record_2 = models.Lattice( + dispatch_id="dispatch_2", + root_dispatch_id="dispatch_2", + name="dispatch_2", + status="NEW_OBJECT", + electron_num=25, + completed_electron_num=0, + ) + + record_3 = models.Lattice( + dispatch_id="dispatch_3", + root_dispatch_id="dispatch_2", + name="dispatch_3", + status="COMPLETED", + electron_num=25, + completed_electron_num=25, + ) + + with test_db.session() as session: + session.add(record_1) + session.add(record_2) + session.add(record_3) + session.commit() + + dispatch_controller = Result.meta_type + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + ) + assert len(results) == 3 + + with test_db.session() as session: + results = dispatch_controller.get_toplevel_dispatches( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + ) + assert len(results) == 2 + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + sort_fields=["name"], + reverse=False, + max_items=1, + ) + assert len(results) == 1 + assert results[0].dispatch_id == "dispatch_1" + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + sort_fields=["name"], + max_items=2, + offset=1, + ) + assert len(results) == 2 + assert results[0].dispatch_id == "dispatch_2" diff --git a/tests/covalent_dispatcher_tests/_service/app_test.py b/tests/covalent_dispatcher_tests/_service/app_test.py index 4615e35c5..1b353afe2 100644 --- a/tests/covalent_dispatcher_tests/_service/app_test.py +++ b/tests/covalent_dispatcher_tests/_service/app_test.py @@ -262,8 +262,8 @@ def test_start(mocker, app, client): assert resp.json() == dispatch_id -def test_export_result_nowait(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" +def test_export_manifest(mocker, app, client, mock_manifest): + dispatch_id = "test_export_manifest" mock_result_object = MagicMock() mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) mocker.patch( @@ -274,29 +274,11 @@ def test_export_result_nowait(mocker, app, client, mock_manifest): ) mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") resp = client.get(f"/api/v2/dispatches/{dispatch_id}") - assert resp.status_code == 200 - assert resp.json()["id"] == dispatch_id - assert resp.json()["status"] == str(RESULT_STATUS.NEW_OBJECT) - assert resp.json()["result_export"] == json.loads(mock_manifest.json()) - - -def test_export_result_wait_not_ready(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" - mock_result_object = MagicMock() - mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.RUNNING)) - mocker.patch( - "covalent_dispatcher._service.app._try_get_result_object", return_value=mock_result_object - ) - mock_export = mocker.patch( - "covalent_dispatcher._service.app.export_result_manifest", return_value=mock_manifest - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - resp = client.get(f"/api/v2/dispatches/{dispatch_id}", params={"wait": True}) - assert resp.status_code == 503 + assert resp.json() == json.loads(mock_manifest.json()) -def test_export_result_bad_dispatch_id(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" +def test_export_manifest_bad_dispatch_id(mocker, app, client, mock_manifest): + dispatch_id = "test_export_manifest" mock_result_object = MagicMock() mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) mocker.patch("covalent_dispatcher._service.app._try_get_result_object", return_value=None) diff --git a/tests/covalent_tests/results_manager_tests/results_manager_test.py b/tests/covalent_tests/results_manager_tests/results_manager_test.py index 06ba8ece0..72c0260d9 100644 --- a/tests/covalent_tests/results_manager_tests/results_manager_test.py +++ b/tests/covalent_tests/results_manager_tests/results_manager_test.py @@ -22,9 +22,9 @@ from unittest.mock import MagicMock import pytest -from requests import Response import covalent as ct +from covalent._api.apiclient import CovalentAPIClient from covalent._results_manager.results_manager import ( MissingLatticeRecordError, Result, @@ -105,24 +105,23 @@ def test_cancel_with_multiple_task_ids(mocker): def test_result_export(mocker): + import json + with tempfile.TemporaryDirectory() as staging_dir: test_manifest = get_test_manifest(staging_dir) dispatch_id = "test_result_export" - mock_body = {"id": "test_result_export", "status": "COMPLETED"} - + mock_response_body = json.loads(test_manifest.model_dump_json()) mock_client = MagicMock() - mock_response = Response() + mock_response = MagicMock() mock_response.status_code = 200 - mock_response.json = MagicMock(return_value=mock_body) + mock_response.json = MagicMock(return_value=mock_response_body) mocker.patch("covalent._api.apiclient.requests.Session.get", return_value=mock_response) - + apiclient = CovalentAPIClient("http://localhost:48008") endpoint = f"/api/v2/dispatches/{dispatch_id}" - assert mock_body == _get_result_export_from_dispatcher( - dispatch_id, wait=False, status_only=True - ) + assert test_manifest == _get_result_export_from_dispatcher(dispatch_id, apiclient) def test_result_manager_assets_local_copies(): @@ -176,11 +175,7 @@ def test_get_result(mocker): # local file copies from server_dir to results_dir. manifest = get_test_manifest(server_dir) - mock_result_export = { - "id": dispatch_id, - "status": "COMPLETED", - "result_export": manifest.dict(), - } + mock_result_export = manifest mocker.patch( "covalent._results_manager.results_manager._get_result_export_from_dispatcher", return_value=mock_result_export, @@ -208,17 +203,9 @@ def test_get_result_sublattice(mocker): # Sublattice manifest sub_manifest = get_test_manifest(server_dir_sub) - mock_result_export = { - "id": dispatch_id, - "status": "COMPLETED", - "result_export": manifest.dict(), - } + mock_result_export = manifest - mock_subresult_export = { - "id": sub_dispatch_id, - "status": "COMPLETED", - "result_export": sub_manifest.dict(), - } + mock_subresult_export = sub_manifest exports = {dispatch_id: mock_result_export, sub_dispatch_id: mock_subresult_export} @@ -277,10 +264,10 @@ def test_get_result_RecursionError(mocker): def test_get_status_only(mocker): """Check get_result when status_only=True""" - dispatch_id = "test_get_result_st" + dispatch_id = "test_get_status_only" mock_get_result_export = mocker.patch( - "covalent._results_manager.results_manager._get_result_export_from_dispatcher", - return_value={"id": dispatch_id, "status": "RUNNING"}, + "covalent._results_manager.results_manager._query_dispatch_status", + return_value="RUNNING", ) status_report = get_result(dispatch_id, status_only=True) diff --git a/tests/covalent_tests/triggers/base_test.py b/tests/covalent_tests/triggers/base_test.py index 0ca0c8d7c..b70aee295 100644 --- a/tests/covalent_tests/triggers/base_test.py +++ b/tests/covalent_tests/triggers/base_test.py @@ -17,7 +17,6 @@ from unittest import mock import pytest -from fastapi.responses import JSONResponse from covalent.triggers import BaseTrigger @@ -46,7 +45,6 @@ def test_register(mocker): @pytest.mark.parametrize( "use_internal_func, mock_status", [ - (True, JSONResponse("mock")), (True, {"status": "mocked-status"}), (False, {"status": "mocked-status"}), ], @@ -61,27 +59,15 @@ def test_get_status(mocker, use_internal_func, mock_status): base_trigger.use_internal_funcs = use_internal_func if use_internal_func: - mocker.patch("covalent_dispatcher._service.app.export_result") - - mock_fut_res = mock.Mock() - mock_fut_res.result.return_value = mock_status - mock_run_coro = mocker.patch( - "covalent.triggers.base.asyncio.run_coroutine_threadsafe", return_value=mock_fut_res + mock_bulk_get_res = mock.Mock() + mock_bulk_get_res.dispatches = [mock.Mock()] + mock_bulk_get_res.dispatches[0].status = mock_status["status"] + mocker.patch( + "covalent_dispatcher._service.app.get_dispatches_bulk", return_value=mock_bulk_get_res ) - if not isinstance(mock_status, dict): - mock_json_loads = mocker.patch( - "covalent.triggers.base.json.loads", return_value={"status": "mocked-status"} - ) - status = base_trigger._get_status() - mock_run_coro.assert_called_once() - mock_fut_res.result.assert_called_once() - - if not isinstance(mock_status, dict): - mock_json_loads.assert_called_once() - else: mock_get_status = mocker.patch("covalent.get_result", return_value=mock_status) status = base_trigger._get_status()