Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Observer artifact creation #1345

Merged
merged 1 commit into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 79 additions & 22 deletions skyvern/forge/sdk/artifact/manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import asyncio
import time
from collections import defaultdict
from typing import Literal

import structlog

from skyvern.forge import app
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.id import generate_artifact_id
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverThought

LOG = structlog.get_logger(__name__)

Expand All @@ -16,58 +18,113 @@ class ArtifactManager:
# task_id -> list of aio_tasks for uploading artifacts
upload_aiotasks_map: dict[str, list[asyncio.Task[None]]] = defaultdict(list)

async def create_artifact(
async def _create_artifact(
self,
step: Step,
aio_task_primary_key: str,
artifact_id: str,
artifact_type: ArtifactType,
uri: str,
step_id: str | None = None,
task_id: str | None = None,
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
observer_thought_id: str | None = None,
observer_cruise_id: str | None = None,
organization_id: str | None = None,
data: bytes | None = None,
path: str | None = None,
) -> str:
# TODO (kerem): Which is better?
# current: (disadvantage: we create the artifact_id UUID here)
# 1. generate artifact_id UUID here
# 2. build uri with artifact_id, step_id, task_id, artifact_type
# 3. create artifact in db using artifact_id, step_id, task_id, artifact_type, uri
# 4. store artifact in storage
# alternative: (disadvantage: two db calls)
# 1. create artifact in db without the URI
# 2. build uri with artifact_id, step_id, task_id, artifact_type
# 3. update artifact in db with the URI
# 4. store artifact in storage
if data is None and path is None:
raise ValueError("Either data or path must be provided to create an artifact.")
if data and path:
raise ValueError("Both data and path cannot be provided to create an artifact.")
artifact_id = generate_artifact_id()
uri = app.STORAGE.build_uri(artifact_id, step, artifact_type)
artifact = await app.DATABASE.create_artifact(
artifact_id,
step.step_id,
step.task_id,
artifact_type,
uri,
organization_id=step.organization_id,
step_id=step_id,
task_id=task_id,
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
observer_thought_id=observer_thought_id,
observer_cruise_id=observer_cruise_id,
organization_id=organization_id,
)
if data:
# Fire and forget
aio_task = asyncio.create_task(app.STORAGE.store_artifact(artifact, data))
self.upload_aiotasks_map[step.task_id].append(aio_task)
self.upload_aiotasks_map[aio_task_primary_key].append(aio_task)
elif path:
# Fire and forget
aio_task = asyncio.create_task(app.STORAGE.store_artifact_from_path(artifact, path))
self.upload_aiotasks_map[step.task_id].append(aio_task)
self.upload_aiotasks_map[aio_task_primary_key].append(aio_task)

return artifact_id

async def update_artifact_data(self, artifact_id: str | None, organization_id: str | None, data: bytes) -> None:
async def create_artifact(
self,
step: Step,
artifact_type: ArtifactType,
data: bytes | None = None,
path: str | None = None,
) -> str:
artifact_id = generate_artifact_id()
uri = app.STORAGE.build_uri(artifact_id, step, artifact_type)
return await self._create_artifact(
aio_task_primary_key=step.task_id,
artifact_id=artifact_id,
artifact_type=artifact_type,
uri=uri,
step_id=step.step_id,
task_id=step.task_id,
organization_id=step.organization_id,
data=data,
path=path,
)

async def create_observer_thought_artifact(
self,
observer_thought: ObserverThought,
artifact_type: ArtifactType,
data: bytes | None = None,
path: str | None = None,
) -> str:
artifact_id = generate_artifact_id()
uri = app.STORAGE.build_observer_thought_uri(artifact_id, observer_thought, artifact_type)
return await self._create_artifact(
aio_task_primary_key=observer_thought.observer_cruise_id,
artifact_id=artifact_id,
artifact_type=artifact_type,
uri=uri,
observer_thought_id=observer_thought.observer_thought_id,
observer_cruise_id=observer_thought.observer_cruise_id,
organization_id=observer_thought.organization_id,
data=data,
path=path,
)

async def update_artifact_data(
self,
artifact_id: str | None,
organization_id: str | None,
data: bytes,
primary_key: Literal["task_id", "observer_thought_id"] = "task_id",
) -> None:
if not artifact_id or not organization_id:
return None
artifact = await app.DATABASE.get_artifact_by_id(artifact_id, organization_id)
if not artifact:
return
# Fire and forget
aio_task = asyncio.create_task(app.STORAGE.store_artifact(artifact, data))
self.upload_aiotasks_map[artifact.task_id].append(aio_task)
if primary_key == "task_id":
if not artifact.task_id:
raise ValueError("Task ID is required to update artifact data.")
self.upload_aiotasks_map[artifact.task_id].append(aio_task)
elif primary_key == "observer_thought_id":
if not artifact.observer_thought_id:
raise ValueError("Observer Thought ID is required to update artifact data.")
self.upload_aiotasks_map[artifact.observer_thought_id].append(aio_task)

async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
return await app.STORAGE.retrieve_artifact(artifact)
Expand Down
34 changes: 9 additions & 25 deletions skyvern/forge/sdk/artifact/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,30 +59,14 @@ class Artifact(BaseModel):
def serialize_datetime_to_isoformat(self, value: datetime) -> str:
return value.isoformat()

artifact_id: str = Field(
...,
description="The ID of the task artifact.",
examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
)
task_id: str = Field(
...,
description="The ID of the task this artifact belongs to.",
examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
)
step_id: str = Field(
...,
description="The ID of the task step this artifact belongs to.",
examples=["6bb1801a-fd80-45e8-899a-4dd723cc602e"],
)
artifact_type: ArtifactType = Field(
...,
description="The type of the artifact.",
examples=["screenshot"],
)
uri: str = Field(
...,
description="The URI of the artifact.",
examples=["/Users/skyvern/hello/world.png"],
)
artifact_id: str
artifact_type: ArtifactType
uri: str
task_id: str | None = None
step_id: str | None = None
workflow_run_id: str | None = None
workflow_run_block_id: str | None = None
observer_cruise_id: str | None = None
observer_thought_id: str | None = None
signed_url: str | None = None
organization_id: str | None = None
13 changes: 13 additions & 0 deletions skyvern/forge/sdk/artifact/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought

# TODO: This should be a part of the ArtifactType model
FILE_EXTENTSION_MAP: dict[ArtifactType, str] = {
Expand Down Expand Up @@ -33,6 +34,18 @@ class BaseStorage(ABC):
def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
pass

@abstractmethod
def build_observer_thought_uri(
self, artifact_id: str, observer_thought: ObserverThought, artifact_type: ArtifactType
) -> str:
pass

@abstractmethod
def build_observer_cruise_uri(
self, artifact_id: str, observer_cruise: ObserverCruise, artifact_type: ArtifactType
) -> str:
pass

@abstractmethod
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
pass
Expand Down
13 changes: 13 additions & 0 deletions skyvern/forge/sdk/artifact/storage/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought

LOG = structlog.get_logger()

Expand All @@ -23,6 +24,18 @@ def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"file://{self.artifact_path}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

def build_observer_thought_uri(
self, artifact_id: str, observer_thought: ObserverThought, artifact_type: ArtifactType
) -> str:
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"file://{self.artifact_path}/{settings.ENV}/observers/{observer_thought.observer_cruise_id}/{observer_thought.observer_thought_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

def build_observer_cruise_uri(
self, artifact_id: str, observer_cruise: ObserverCruise, artifact_type: ArtifactType
) -> str:
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"file://{self.artifact_path}/{settings.ENV}/observers/{observer_cruise.observer_cruise_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
file_path = None
try:
Expand Down
13 changes: 13 additions & 0 deletions skyvern/forge/sdk/artifact/storage/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought


class S3Storage(BaseStorage):
Expand All @@ -26,6 +27,18 @@ def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"s3://{self.bucket}/{settings.ENV}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

def build_observer_thought_uri(
self, artifact_id: str, observer_thought: ObserverThought, artifact_type: ArtifactType
) -> str:
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"s3://{self.bucket}/{settings.ENV}/observers/{observer_thought.observer_cruise_id}/{observer_thought.observer_thought_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

def build_observer_cruise_uri(
self, artifact_id: str, observer_cruise: ObserverCruise, artifact_type: ArtifactType
) -> str:
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"s3://{self.bucket}/{settings.ENV}/observers/{observer_cruise.observer_cruise_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
await self.async_client.upload_file(artifact.uri, data)

Expand Down
16 changes: 12 additions & 4 deletions skyvern/forge/sdk/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,20 +183,28 @@ async def create_step(
async def create_artifact(
self,
artifact_id: str,
step_id: str,
task_id: str,
artifact_type: str,
uri: str,
step_id: str | None = None,
task_id: str | None = None,
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
observer_cruise_id: str | None = None,
observer_thought_id: str | None = None,
organization_id: str | None = None,
) -> Artifact:
try:
async with self.Session() as session:
new_artifact = ArtifactModel(
artifact_id=artifact_id,
task_id=task_id,
step_id=step_id,
artifact_type=artifact_type,
uri=uri,
task_id=task_id,
step_id=step_id,
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
observer_cruise_id=observer_cruise_id,
observer_thought_id=observer_thought_id,
organization_id=organization_id,
)
session.add(new_artifact)
Expand Down
4 changes: 2 additions & 2 deletions skyvern/forge/sdk/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ class WorkflowRunBlockModel(Base):
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)


class ObserverCruise(Base):
class ObserverCruiseModel(Base):
__tablename__ = "observer_cruises"

observer_cruise_id = Column(String, primary_key=True, default=generate_observer_cruise_id)
Expand All @@ -516,7 +516,7 @@ class ObserverCruise(Base):
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), nullable=True)


class ObserverThought(Base):
class ObserverThoughtModel(Base):
__tablename__ = "observer_thoughts"

observer_thought_id = Column(String, primary_key=True, default=generate_observer_thought_id)
Expand Down
4 changes: 4 additions & 0 deletions skyvern/forge/sdk/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = Fal
uri=artifact_model.uri,
task_id=artifact_model.task_id,
step_id=artifact_model.step_id,
workflow_run_id=artifact_model.workflow_run_id,
workflow_run_block_id=artifact_model.workflow_run_block_id,
observer_cruise_id=artifact_model.observer_cruise_id,
observer_thought_id=artifact_model.observer_thought_id,
created_at=artifact_model.created_at,
modified_at=artifact_model.modified_at,
organization_id=artifact_model.organization_id,
Expand Down
44 changes: 44 additions & 0 deletions skyvern/forge/sdk/schemas/observers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from datetime import datetime
from enum import StrEnum

from pydantic import BaseModel, ConfigDict


class ObserverCruiseStatus(StrEnum):
created = "created"
queued = "queued"
running = "running"
failed = "failed"
terminated = "terminated"
canceled = "canceled"
timed_out = "timed_out"
completed = "completed"


class ObserverCruise(BaseModel):
model_config = ConfigDict(from_attributes=True)

observer_cruise_id: str
status: ObserverCruiseStatus
organization_id: str | None = None
workflow_run_id: str | None = None
workflow_id: str | None = None

created_at: datetime
modified_at: datetime


class ObserverThought(BaseModel):
observer_thought_id: str
observer_cruise_id: str
organization_id: str | None = None
workflow_run_id: str | None = None
workflow_run_block_id: str | None = None
workflow_id: str | None = None
user_input: str | None = None
observation: str | None = None
thought: str | None = None
answer: str | None = None

created_at: datetime
modified_at: datetime
Loading