From 3c86777802b6f37b1860dae06ed5e01ce02df373 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Sun, 8 Dec 2024 05:19:00 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=84=20synced=20local=20'skyvern/'=20wi?= =?UTF-8?q?th=20remote=20'skyvern/'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit > [!IMPORTANT] > Add support for creating, executing, and completing observer cruises, including database schema updates and integration into existing execution and service logic. > > - **Database**: > - Add `prompt` and `url` columns to `observer_cruises` table in `2024_12_08_0513-51399a87a9bc_add_prompt_and_url_to_observer_cruises_.py`. > - Add `ObserverCruiseModel` and `ObserverThoughtModel` to `models.py`. > - Implement `create_observer_cruise`, `get_observer_cruise`, `update_observer_cruise`, and `create_observer_thought` in `client.py`. > - **Execution**: > - Modify `execute_cruise()` in `cloud_async_executor.py` to handle observer cruises using `observer_cruise_id`. > - Update job submission logic to use `observer_cruise_id` instead of `workflow_run_id`, `user_prompt`, and `url`. > - **Routes**: > - Update `run_observer()` in `observer.py` to return `ObserverCruise` and use `intialize_observer_cruise()`. > - **Services**: > - Add `intialize_observer_cruise()` in `observer_service.py` to create observer cruises. > - **Scripts**: > - Update `run_observer.py` to handle observer cruises, including validation and execution logic. > - Modify `run_observer_wrapper.sh` to pass `observer_cruise_id` to `run_observer.py`. > - **Exceptions**: > - Change `InvalidUrl` to inherit from `SkyvernHTTPException` in `exceptions.py`. > - **Validators**: > - Remove `validate_url()` from `validators.py`. > > This description was created by [Ellipsis](https://www.ellipsis.dev?ref=Skyvern-AI%2Fskyvern-cloud&utm_source=github&utm_medium=referral) for 3f4fc732e85013716095e19841d2c03033ebd90e. It will automatically update as commits are pushed. --- skyvern/exceptions.py | 2 +- skyvern/forge/agent.py | 2 +- skyvern/forge/sdk/artifact/manager.py | 38 ++++++--- skyvern/forge/sdk/core/validators.py | 13 +--- skyvern/forge/sdk/db/client.py | 103 +++++++++++++++++++++++++ skyvern/forge/sdk/db/models.py | 2 + skyvern/forge/sdk/schemas/observers.py | 6 +- skyvern/forge/sdk/workflow/service.py | 2 +- 8 files changed, 143 insertions(+), 25 deletions(-) diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index 96e7f8000..522209b75 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -504,7 +504,7 @@ def __init__(self, message: str) -> None: super().__init__(message) -class InvalidUrl(SkyvernException): +class InvalidUrl(SkyvernHTTPException): def __init__(self, url: str) -> None: super().__init__(f"Invalid URL: {url}. Skyvern supports HTTP and HTTPS urls with max 2083 character length.") diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index dfc83863a..aa9d5c5b0 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -1536,7 +1536,7 @@ async def clean_up_task( await self.cleanup_browser_and_create_artifacts(close_browser_on_completion, last_step, task) # Wait for all tasks to complete before generating the links for the artifacts - await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks([task.task_id]) + await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_task(task.task_id) if need_call_webhook: await self.execute_task_webhook(task=task, last_step=last_step, api_key=api_key) diff --git a/skyvern/forge/sdk/artifact/manager.py b/skyvern/forge/sdk/artifact/manager.py index ec182468c..ea2188258 100644 --- a/skyvern/forge/sdk/artifact/manager.py +++ b/skyvern/forge/sdk/artifact/manager.py @@ -135,28 +135,48 @@ async def get_share_link(self, artifact: Artifact) -> str | None: async def get_share_links(self, artifacts: list[Artifact]) -> list[str] | None: return await app.STORAGE.get_share_links(artifacts) - async def wait_for_upload_aiotasks(self, primary_keys: list[str]) -> None: + async def wait_for_upload_aiotasks_for_task(self, task_id: str) -> None: + try: + st = time.time() + async with asyncio.timeout(30): + await asyncio.gather( + *[aio_task for aio_task in self.upload_aiotasks_map[task_id] if not aio_task.done()] + ) + LOG.info( + f"S3 upload tasks for task_id={task_id} completed in {time.time() - st:.2f}s", + task_id=task_id, + duration=time.time() - st, + ) + except asyncio.TimeoutError: + LOG.error( + f"Timeout (30s) while waiting for upload tasks for task_id={task_id}", + task_id=task_id, + ) + + del self.upload_aiotasks_map[task_id] + + async def wait_for_upload_aiotasks_for_tasks(self, task_ids: list[str]) -> None: try: st = time.time() async with asyncio.timeout(30): await asyncio.gather( *[ aio_task - for primary_key in primary_keys - for aio_task in self.upload_aiotasks_map[primary_key] + for task_id in task_ids + for aio_task in self.upload_aiotasks_map[task_id] if not aio_task.done() ] ) LOG.info( - f"S3 upload aio tasks for primary_keys={primary_keys} completed in {time.time() - st:.2f}s", - primary_keys=primary_keys, + f"S3 upload tasks for task_ids={task_ids} completed in {time.time() - st:.2f}s", + task_ids=task_ids, duration=time.time() - st, ) except asyncio.TimeoutError: LOG.error( - f"Timeout (30s) while waiting for upload aio tasks for primary_keys={primary_keys}", - primary_keys=primary_keys, + f"Timeout (30s) while waiting for upload tasks for task_ids={task_ids}", + task_ids=task_ids, ) - for primary_key in primary_keys: - del self.upload_aiotasks_map[primary_key] + for task_id in task_ids: + del self.upload_aiotasks_map[task_id] diff --git a/skyvern/forge/sdk/core/validators.py b/skyvern/forge/sdk/core/validators.py index 7da7c9903..054455a31 100644 --- a/skyvern/forge/sdk/core/validators.py +++ b/skyvern/forge/sdk/core/validators.py @@ -1,7 +1,7 @@ import ipaddress from urllib.parse import urlparse -from pydantic import HttpUrl, ValidationError, parse_obj_as +from pydantic import HttpUrl, ValidationError from skyvern.config import settings from skyvern.exceptions import InvalidUrl @@ -27,17 +27,6 @@ def prepend_scheme_and_validate_url(url: str) -> str: return url -def validate_url(url: str) -> str: - try: - if url: - # Use parse_obj_as to validate the string as an HttpUrl - parse_obj_as(HttpUrl, url) - return url - except ValidationError: - # Handle the validation error - raise InvalidUrl(url=url) - - def is_blocked_host(host: str) -> bool: try: ip = ipaddress.ip_address(host) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 24c4be0c8..ccd9e3940 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -19,6 +19,8 @@ BitwardenCreditCardDataParameterModel, BitwardenLoginCredentialParameterModel, BitwardenSensitiveInformationParameterModel, + ObserverCruiseModel, + ObserverThoughtModel, OrganizationAuthTokenModel, OrganizationModel, OutputParameterModel, @@ -50,6 +52,7 @@ convert_to_workflow_run_parameter, ) from skyvern.forge.sdk.models import Step, StepStatus +from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverCruiseStatus, ObserverThought from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken from skyvern.forge.sdk.schemas.task_generations import TaskGeneration from skyvern.forge.sdk.schemas.tasks import OrderBy, ProxyLocation, SortDirection, Task, TaskStatus @@ -1703,3 +1706,103 @@ async def delete_task_actions(self, organization_id: str, task_id: str) -> None: ) await session.execute(stmt) await session.commit() + + async def get_observer_cruise( + self, observer_cruise_id: str, organization_id: str | None = None + ) -> ObserverCruise | None: + async with self.Session() as session: + if observer_cruise := ( + await session.scalars( + select(ObserverCruiseModel) + .filter_by(observer_cruise_id=observer_cruise_id) + .filter_by(organization_id=organization_id) + ) + ).first(): + return ObserverCruise.model_validate(observer_cruise) + return None + + async def get_observer_thought( + self, observer_thought_id: str, organization_id: str | None = None + ) -> ObserverThought | None: + async with self.Session() as session: + if observer_thought := ( + await session.scalars( + select(ObserverThoughtModel) + .filter_by(observer_thought_id=observer_thought_id) + .filter_by(organization_id=organization_id) + ) + ).first(): + return ObserverThought.model_validate(observer_thought) + return None + + async def create_observer_cruise( + self, + workflow_run_id: str | None = None, + workflow_id: str | None = None, + prompt: str | None = None, + url: str | None = None, + organization_id: str | None = None, + ) -> ObserverCruise: + async with self.Session() as session: + new_observer_cruise = ObserverCruiseModel( + workflow_run_id=workflow_run_id, + workflow_id=workflow_id, + prompt=prompt, + url=url, + organization_id=organization_id, + ) + session.add(new_observer_cruise) + await session.commit() + await session.refresh(new_observer_cruise) + return ObserverCruise.model_validate(new_observer_cruise) + + async def create_observer_thought( + self, + observer_cruise_id: str, + workflow_run_id: str | None = None, + workflow_id: str | None = None, + workflow_run_block_id: str | None = None, + user_input: str | None = None, + observation: str | None = None, + thought: str | None = None, + answer: str | None = None, + organization_id: str | None = None, + ) -> ObserverThought: + async with self.Session() as session: + new_observer_thought = ObserverThoughtModel( + observer_cruise_id=observer_cruise_id, + workflow_run_id=workflow_run_id, + workflow_id=workflow_id, + workflow_run_block_id=workflow_run_block_id, + user_input=user_input, + observation=observation, + thought=thought, + answer=answer, + organization_id=organization_id, + ) + session.add(new_observer_thought) + await session.commit() + await session.refresh(new_observer_thought) + return ObserverThought.model_validate(new_observer_thought) + + async def update_observer_cruise( + self, + observer_cruise_id: str, + status: ObserverCruiseStatus | None = None, + organization_id: str | None = None, + ) -> ObserverCruise: + async with self.Session() as session: + observer_cruise = ( + await session.scalars( + select(ObserverCruiseModel) + .filter_by(observer_cruise_id=observer_cruise_id) + .filter_by(organization_id=organization_id) + ) + ).first() + if observer_cruise: + if status: + observer_cruise.status = status + await session.commit() + await session.refresh(observer_cruise) + return ObserverCruise.model_validate(observer_cruise) + raise NotFoundError(f"ObserverCruise {observer_cruise_id} not found") diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 6f97c330f..19e1ba016 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -515,6 +515,8 @@ class ObserverCruiseModel(Base): organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=True) workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), nullable=True) workflow_id = Column(String, ForeignKey("workflows.workflow_id"), nullable=True) + prompt = Column(UnicodeText, nullable=True) + url = Column(String, nullable=True) class ObserverThoughtModel(Base): diff --git a/skyvern/forge/sdk/schemas/observers.py b/skyvern/forge/sdk/schemas/observers.py index 1fa9f0732..a8ada1359 100644 --- a/skyvern/forge/sdk/schemas/observers.py +++ b/skyvern/forge/sdk/schemas/observers.py @@ -1,7 +1,7 @@ from datetime import datetime from enum import StrEnum -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, HttpUrl class ObserverCruiseStatus(StrEnum): @@ -23,12 +23,16 @@ class ObserverCruise(BaseModel): organization_id: str | None = None workflow_run_id: str | None = None workflow_id: str | None = None + prompt: str | None = None + url: HttpUrl | None = None created_at: datetime modified_at: datetime class ObserverThought(BaseModel): + model_config = ConfigDict(from_attributes=True) + observer_thought_id: str observer_cruise_id: str organization_id: str | None = None diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 838f7b2f7..7575a28a9 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -859,7 +859,7 @@ async def clean_up_workflow( ) LOG.info("Persisted browser session for workflow run", workflow_run_id=workflow_run.workflow_run_id) - await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks(all_workflow_task_ids) + await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids) try: async with asyncio.timeout(SAVE_DOWNLOADED_FILES_TIMEOUT):