Skip to content

Commit

Permalink
Observer artifact creation (#1345)
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng authored Dec 7, 2024
1 parent 7591873 commit 127f25c
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 53 deletions.
101 changes: 79 additions & 22 deletions skyvern/forge/sdk/artifact/manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import asyncio
import time
from collections import defaultdict
from typing import Literal

import structlog

from skyvern.forge import app
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.id import generate_artifact_id
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverThought

LOG = structlog.get_logger(__name__)

Expand All @@ -16,58 +18,113 @@ class ArtifactManager:
# task_id -> list of aio_tasks for uploading artifacts
upload_aiotasks_map: dict[str, list[asyncio.Task[None]]] = defaultdict(list)

async def create_artifact(
async def _create_artifact(
self,
step: Step,
aio_task_primary_key: str,
artifact_id: str,
artifact_type: ArtifactType,
uri: str,
step_id: str | None = None,
task_id: str | None = None,
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
observer_thought_id: str | None = None,
observer_cruise_id: str | None = None,
organization_id: str | None = None,
data: bytes | None = None,
path: str | None = None,
) -> str:
# TODO (kerem): Which is better?
# current: (disadvantage: we create the artifact_id UUID here)
# 1. generate artifact_id UUID here
# 2. build uri with artifact_id, step_id, task_id, artifact_type
# 3. create artifact in db using artifact_id, step_id, task_id, artifact_type, uri
# 4. store artifact in storage
# alternative: (disadvantage: two db calls)
# 1. create artifact in db without the URI
# 2. build uri with artifact_id, step_id, task_id, artifact_type
# 3. update artifact in db with the URI
# 4. store artifact in storage
if data is None and path is None:
raise ValueError("Either data or path must be provided to create an artifact.")
if data and path:
raise ValueError("Both data and path cannot be provided to create an artifact.")
artifact_id = generate_artifact_id()
uri = app.STORAGE.build_uri(artifact_id, step, artifact_type)
artifact = await app.DATABASE.create_artifact(
artifact_id,
step.step_id,
step.task_id,
artifact_type,
uri,
organization_id=step.organization_id,
step_id=step_id,
task_id=task_id,
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
observer_thought_id=observer_thought_id,
observer_cruise_id=observer_cruise_id,
organization_id=organization_id,
)
if data:
# Fire and forget
aio_task = asyncio.create_task(app.STORAGE.store_artifact(artifact, data))
self.upload_aiotasks_map[step.task_id].append(aio_task)
self.upload_aiotasks_map[aio_task_primary_key].append(aio_task)
elif path:
# Fire and forget
aio_task = asyncio.create_task(app.STORAGE.store_artifact_from_path(artifact, path))
self.upload_aiotasks_map[step.task_id].append(aio_task)
self.upload_aiotasks_map[aio_task_primary_key].append(aio_task)

return artifact_id

async def update_artifact_data(self, artifact_id: str | None, organization_id: str | None, data: bytes) -> None:
async def create_artifact(
    self,
    step: Step,
    artifact_type: ArtifactType,
    data: bytes | None = None,
    path: str | None = None,
) -> str:
    """Create an artifact attached to a task step.

    A fresh artifact ID is generated, the storage backend builds the
    step-scoped URI, and the shared ``_create_artifact`` helper is awaited
    with the step's identifiers. The step's ``task_id`` is also passed as the
    key under which the helper tracks the background upload task.

    Args:
        step: Step the artifact belongs to; supplies the step, task and
            organization IDs.
        artifact_type: Kind of artifact being created.
        data: Artifact content held in memory, if any.
        path: Path of a file holding the artifact content, if any.

    Returns:
        The ID of the newly created artifact.
    """
    new_artifact_id = generate_artifact_id()
    artifact_uri = app.STORAGE.build_uri(new_artifact_id, step, artifact_type)
    owning_task_id = step.task_id
    created_id = await self._create_artifact(
        aio_task_primary_key=owning_task_id,
        artifact_id=new_artifact_id,
        artifact_type=artifact_type,
        uri=artifact_uri,
        step_id=step.step_id,
        task_id=owning_task_id,
        organization_id=step.organization_id,
        data=data,
        path=path,
    )
    return created_id

async def create_observer_thought_artifact(
    self,
    observer_thought: ObserverThought,
    artifact_type: ArtifactType,
    data: bytes | None = None,
    path: str | None = None,
) -> str:
    """Create an artifact attached to an observer thought.

    A fresh artifact ID is generated, the storage backend builds the
    thought-scoped URI, and the shared ``_create_artifact`` helper is awaited
    with the thought's identifiers. The background upload task is tracked
    under the thought's ``observer_cruise_id`` (not the thought ID).

    Args:
        observer_thought: Thought the artifact belongs to; supplies the
            thought, cruise and organization IDs.
        artifact_type: Kind of artifact being created.
        data: Artifact content held in memory, if any.
        path: Path of a file holding the artifact content, if any.

    Returns:
        The ID of the newly created artifact.
    """
    new_artifact_id = generate_artifact_id()
    artifact_uri = app.STORAGE.build_observer_thought_uri(new_artifact_id, observer_thought, artifact_type)
    cruise_id = observer_thought.observer_cruise_id
    created_id = await self._create_artifact(
        aio_task_primary_key=cruise_id,
        artifact_id=new_artifact_id,
        artifact_type=artifact_type,
        uri=artifact_uri,
        observer_thought_id=observer_thought.observer_thought_id,
        observer_cruise_id=cruise_id,
        organization_id=observer_thought.organization_id,
        data=data,
        path=path,
    )
    return created_id

async def update_artifact_data(
    self,
    artifact_id: str | None,
    organization_id: str | None,
    data: bytes,
    primary_key: Literal["task_id", "observer_thought_id", "observer_cruise_id"] = "task_id",
) -> None:
    """Upload new content for an existing artifact in the background.

    The artifact is looked up by ID and organization; if either identifier is
    missing or the artifact is not found, the call is a silent no-op.
    Otherwise a fire-and-forget upload task is started and registered exactly
    once in ``upload_aiotasks_map`` under the artifact attribute named by
    ``primary_key``.

    Args:
        artifact_id: ID of the artifact to update.
        organization_id: Organization that owns the artifact.
        data: New content to store for the artifact.
        primary_key: Name of the artifact attribute whose value is used as
            the tracking key. ``"observer_cruise_id"`` matches the key used
            by ``create_observer_thought_artifact``.

    Raises:
        ValueError: If ``primary_key`` is not supported, or the artifact has
            no value for the selected attribute. No upload is started in
            either case.
    """
    if not artifact_id or not organization_id:
        return None
    artifact = await app.DATABASE.get_artifact_by_id(artifact_id, organization_id)
    if not artifact:
        return None

    missing_id_errors = {
        "task_id": "Task ID is required to update artifact data.",
        "observer_thought_id": "Observer Thought ID is required to update artifact data.",
        "observer_cruise_id": "Observer Cruise ID is required to update artifact data.",
    }
    if primary_key not in missing_id_errors:
        raise ValueError(f"Unsupported primary key for artifact upload tracking: {primary_key}")

    # Resolve the tracking key BEFORE starting the upload, so a failed
    # validation never leaves a running upload task that nobody tracks.
    aio_task_key = getattr(artifact, primary_key)
    if not aio_task_key:
        raise ValueError(missing_id_errors[primary_key])

    # Fire and forget
    aio_task = asyncio.create_task(app.STORAGE.store_artifact(artifact, data))
    self.upload_aiotasks_map[aio_task_key].append(aio_task)

async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
    """Return the stored content of ``artifact`` as reported by the storage backend."""
    content = await app.STORAGE.retrieve_artifact(artifact)
    return content
Expand Down
34 changes: 9 additions & 25 deletions skyvern/forge/sdk/artifact/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,30 +59,14 @@ class Artifact(BaseModel):
def serialize_datetime_to_isoformat(self, value: datetime) -> str:
return value.isoformat()

artifact_id: str = Field(
...,
description="The ID of the task artifact.",
examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
)
task_id: str = Field(
...,
description="The ID of the task this artifact belongs to.",
examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
)
step_id: str = Field(
...,
description="The ID of the task step this artifact belongs to.",
examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
)
artifact_type: ArtifactType = Field(
...,
description="The type of the artifact.",
examples=["screenshot"],
)
uri: str = Field(
...,
description="The URI of the artifact.",
examples=["/Users/skyvern/hello/world.png"],
)
artifact_id: str
artifact_type: ArtifactType
uri: str
task_id: str | None = None
step_id: str | None = None
workflow_run_id: str | None = None
workflow_run_block_id: str | None = None
observer_cruise_id: str | None = None
observer_thought_id: str | None = None
signed_url: str | None = None
organization_id: str | None = None
13 changes: 13 additions & 0 deletions skyvern/forge/sdk/artifact/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought

# TODO: This should be a part of the ArtifactType model
FILE_EXTENTSION_MAP: dict[ArtifactType, str] = {
Expand Down Expand Up @@ -33,6 +34,18 @@ class BaseStorage(ABC):
def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
pass

# Build the storage URI for an artifact that belongs to an observer thought.
# Abstract: each concrete storage backend supplies its own implementation and
# returns the URI as a string.
@abstractmethod
def build_observer_thought_uri(
self, artifact_id: str, observer_thought: ObserverThought, artifact_type: ArtifactType
) -> str:
pass

# Build the storage URI for an artifact that belongs to an observer cruise.
# Abstract: each concrete storage backend supplies its own implementation and
# returns the URI as a string.
@abstractmethod
def build_observer_cruise_uri(
self, artifact_id: str, observer_cruise: ObserverCruise, artifact_type: ArtifactType
) -> str:
pass

# Store `data` as the content of `artifact`.
# Abstract: implemented by each concrete storage backend.
@abstractmethod
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
pass
Expand Down
13 changes: 13 additions & 0 deletions skyvern/forge/sdk/artifact/storage/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought

LOG = structlog.get_logger()

Expand All @@ -23,6 +24,18 @@ def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"file://{self.artifact_path}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

# Build the local `file://` URI for an observer-thought artifact.
# Layout: <artifact_path>/<ENV>/observers/<observer_cruise_id>/<observer_thought_id>/
#         <utc-iso-timestamp>_<artifact_id>_<artifact_type>.<ext>
# Raises KeyError if `artifact_type` has no entry in FILE_EXTENTSION_MAP.
def build_observer_thought_uri(
self, artifact_id: str, observer_thought: ObserverThought, artifact_type: ArtifactType
) -> str:
# File extension is looked up from the artifact type.
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"file://{self.artifact_path}/{settings.ENV}/observers/{observer_thought.observer_cruise_id}/{observer_thought.observer_thought_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

# Build the local `file://` URI for an observer-cruise artifact.
# Layout: <artifact_path>/<ENV>/observers/<observer_cruise_id>/
#         <utc-iso-timestamp>_<artifact_id>_<artifact_type>.<ext>
# Raises KeyError if `artifact_type` has no entry in FILE_EXTENTSION_MAP.
def build_observer_cruise_uri(
self, artifact_id: str, observer_cruise: ObserverCruise, artifact_type: ArtifactType
) -> str:
# File extension is looked up from the artifact type.
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"file://{self.artifact_path}/{settings.ENV}/observers/{observer_cruise.observer_cruise_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
file_path = None
try:
Expand Down
13 changes: 13 additions & 0 deletions skyvern/forge/sdk/artifact/storage/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought


class S3Storage(BaseStorage):
Expand All @@ -26,6 +27,18 @@ def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"s3://{self.bucket}/{settings.ENV}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

# Build the `s3://` URI for an observer-thought artifact.
# Layout: <bucket>/<ENV>/observers/<observer_cruise_id>/<observer_thought_id>/
#         <utc-iso-timestamp>_<artifact_id>_<artifact_type>.<ext>
# Raises KeyError if `artifact_type` has no entry in FILE_EXTENTSION_MAP.
def build_observer_thought_uri(
self, artifact_id: str, observer_thought: ObserverThought, artifact_type: ArtifactType
) -> str:
# File extension is looked up from the artifact type.
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"s3://{self.bucket}/{settings.ENV}/observers/{observer_thought.observer_cruise_id}/{observer_thought.observer_thought_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

# Build the `s3://` URI for an observer-cruise artifact.
# Layout: <bucket>/<ENV>/observers/<observer_cruise_id>/
#         <utc-iso-timestamp>_<artifact_id>_<artifact_type>.<ext>
# Raises KeyError if `artifact_type` has no entry in FILE_EXTENTSION_MAP.
def build_observer_cruise_uri(
self, artifact_id: str, observer_cruise: ObserverCruise, artifact_type: ArtifactType
) -> str:
# File extension is looked up from the artifact type.
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"s3://{self.bucket}/{settings.ENV}/observers/{observer_cruise.observer_cruise_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

# Upload `data` through the async client, using the artifact's own URI as the destination.
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
await self.async_client.upload_file(artifact.uri, data)

Expand Down
16 changes: 12 additions & 4 deletions skyvern/forge/sdk/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,20 +183,28 @@ async def create_step(
async def create_artifact(
self,
artifact_id: str,
step_id: str,
task_id: str,
artifact_type: str,
uri: str,
step_id: str | None = None,
task_id: str | None = None,
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
observer_cruise_id: str | None = None,
observer_thought_id: str | None = None,
organization_id: str | None = None,
) -> Artifact:
try:
async with self.Session() as session:
new_artifact = ArtifactModel(
artifact_id=artifact_id,
task_id=task_id,
step_id=step_id,
artifact_type=artifact_type,
uri=uri,
task_id=task_id,
step_id=step_id,
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
observer_cruise_id=observer_cruise_id,
observer_thought_id=observer_thought_id,
organization_id=organization_id,
)
session.add(new_artifact)
Expand Down
4 changes: 2 additions & 2 deletions skyvern/forge/sdk/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ class WorkflowRunBlockModel(Base):
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)


class ObserverCruise(Base):
class ObserverCruiseModel(Base):
__tablename__ = "observer_cruises"

observer_cruise_id = Column(String, primary_key=True, default=generate_observer_cruise_id)
Expand All @@ -516,7 +516,7 @@ class ObserverCruise(Base):
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), nullable=True)


class ObserverThought(Base):
class ObserverThoughtModel(Base):
__tablename__ = "observer_thoughts"

observer_thought_id = Column(String, primary_key=True, default=generate_observer_thought_id)
Expand Down
4 changes: 4 additions & 0 deletions skyvern/forge/sdk/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = Fal
uri=artifact_model.uri,
task_id=artifact_model.task_id,
step_id=artifact_model.step_id,
workflow_run_id=artifact_model.workflow_run_id,
workflow_run_block_id=artifact_model.workflow_run_block_id,
observer_cruise_id=artifact_model.observer_cruise_id,
observer_thought_id=artifact_model.observer_thought_id,
created_at=artifact_model.created_at,
modified_at=artifact_model.modified_at,
organization_id=artifact_model.organization_id,
Expand Down
44 changes: 44 additions & 0 deletions skyvern/forge/sdk/schemas/observers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from datetime import datetime
from enum import StrEnum

from pydantic import BaseModel, ConfigDict


# Status values an observer cruise can carry.
# StrEnum: each member is a string equal to its own name.
class ObserverCruiseStatus(StrEnum):
created = "created"
queued = "queued"
running = "running"
failed = "failed"
terminated = "terminated"
canceled = "canceled"
timed_out = "timed_out"
completed = "completed"


# Pydantic schema for an observer cruise.
# Only the cruise ID, status and the two timestamps are required; the
# organization and workflow references default to None.
class ObserverCruise(BaseModel):
# from_attributes=True lets pydantic build this model by reading attributes
# off an arbitrary object instead of requiring a dict.
model_config = ConfigDict(from_attributes=True)

observer_cruise_id: str
status: ObserverCruiseStatus
organization_id: str | None = None
workflow_run_id: str | None = None
workflow_id: str | None = None

created_at: datetime
modified_at: datetime


class ObserverThought(BaseModel):
    """Pydantic schema for an observer thought.

    Every thought carries its own ID and the ID of the observer cruise it
    belongs to. The organization/workflow references and the free-text
    fields (``user_input``, ``observation``, ``thought``, ``answer``) are
    optional and default to ``None``.
    """

    # Same configuration as ObserverCruise: allow the model to be built by
    # reading attributes off an object rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)

    observer_thought_id: str
    observer_cruise_id: str
    organization_id: str | None = None
    workflow_run_id: str | None = None
    workflow_run_block_id: str | None = None
    workflow_id: str | None = None
    user_input: str | None = None
    observation: str | None = None
    thought: str | None = None
    answer: str | None = None

    created_at: datetime
    modified_at: datetime

0 comments on commit 127f25c

Please sign in to comment.