create cruise related artifact in cruise api #1355

Merged
merged 2 commits on Dec 9, 2024
37 changes: 37 additions & 0 deletions (new Alembic migration, revision c502ecf908c6)
@@ -0,0 +1,37 @@
"""add created_at and modified_at to observer tables;

Revision ID: c502ecf908c6
Revises: dc2a8facf0d7
Create Date: 2024-12-09 00:40:30.098534+00:00

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "c502ecf908c6"
down_revision: Union[str, None] = "dc2a8facf0d7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("observer_cruises", sa.Column("created_at", sa.DateTime(), nullable=False))
op.add_column("observer_cruises", sa.Column("modified_at", sa.DateTime(), nullable=False))
op.add_column("observer_thoughts", sa.Column("created_at", sa.DateTime(), nullable=False))
op.add_column("observer_thoughts", sa.Column("modified_at", sa.DateTime(), nullable=False))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("observer_thoughts", "modified_at")
op.drop_column("observer_thoughts", "created_at")
op.drop_column("observer_cruises", "modified_at")
op.drop_column("observer_cruises", "created_at")
# ### end Alembic commands ###
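A caveat on the migration above: `op.add_column` with `nullable=False` and no server default fails when the table already contains rows, so this migration implicitly assumes the observer tables are empty at upgrade time. If they might not be, a backfill-safe variant looks like this sketch (assuming PostgreSQL; this is not what the PR ships):

import sqlalchemy as sa

from alembic import op


def upgrade() -> None:
    # Hypothetical backfill-safe variant: add as nullable, backfill, then tighten.
    op.add_column("observer_cruises", sa.Column("created_at", sa.DateTime(), nullable=True))
    op.execute("UPDATE observer_cruises SET created_at = NOW() WHERE created_at IS NULL")
    op.alter_column("observer_cruises", "created_at", nullable=False)

An equivalent one-step approach is to pass server_default=sa.func.now() to add_column and drop the default afterwards.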
152 changes: 79 additions & 73 deletions skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -20,6 +20,7 @@
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought

LOG = structlog.get_logger()

@@ -58,6 +59,8 @@ def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler:
async def llm_api_handler_with_router_and_fallback(
prompt: str,
step: Step | None = None,
observer_cruise: ObserverCruise | None = None,
observer_thought: ObserverThought | None = None,
screenshots: list[bytes] | None = None,
parameters: dict[str, Any] | None = None,
) -> dict[str, Any]:
@@ -76,32 +79,29 @@ async def llm_api_handler_with_router_and_fallback(
if parameters is None:
parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)

if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_PROMPT,
data=prompt.encode("utf-8"),
)
for screenshot in screenshots or []:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.SCREENSHOT_LLM,
data=screenshot,
)

await app.ARTIFACT_MANAGER.create_llm_artifact(
data=prompt.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT,
screenshots=screenshots,
step=step,
observer_cruise=observer_cruise,
observer_thought=observer_thought,
)
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_REQUEST,
data=json.dumps(
{
"model": llm_key,
"messages": messages,
**parameters,
}
).encode("utf-8"),
)

await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
"model": llm_key,
"messages": messages,
**parameters,
}
).encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
step=step,
observer_cruise=observer_cruise,
observer_thought=observer_thought,
)
try:
response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
LOG.info("LLM API call successful", llm_key=llm_key, model=llm_config.model_name)
@@ -122,12 +122,14 @@ async def llm_api_handler_with_router_and_fallback(
)
raise LLMProviderError(llm_key) from e

await app.ARTIFACT_MANAGER.create_llm_artifact(
data=response.model_dump_json(indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
step=step,
observer_cruise=observer_cruise,
observer_thought=observer_thought,
)
if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE,
data=response.model_dump_json(indent=2).encode("utf-8"),
)
llm_cost = litellm.completion_cost(completion_response=response)
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
@@ -140,12 +142,13 @@
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
)
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
step=step,
observer_cruise=observer_cruise,
observer_thought=observer_thought,
)
return parsed_response

return llm_api_handler_with_router_and_fallback
@@ -162,6 +165,8 @@ def get_llm_api_handler(llm_key: str, base_parameters: dict[str, Any] | None = None
async def llm_api_handler(
prompt: str,
step: Step | None = None,
observer_cruise: ObserverCruise | None = None,
observer_thought: ObserverThought | None = None,
screenshots: list[bytes] | None = None,
parameters: dict[str, Any] | None = None,
) -> dict[str, Any]:
@@ -173,37 +178,33 @@ async def llm_api_handler(
if llm_config.litellm_params: # type: ignore
active_parameters.update(llm_config.litellm_params) # type: ignore

if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_PROMPT,
data=prompt.encode("utf-8"),
)
for screenshot in screenshots or []:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.SCREENSHOT_LLM,
data=screenshot,
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=prompt.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT,
screenshots=screenshots,
step=step,
observer_cruise=observer_cruise,
observer_thought=observer_thought,
)

# TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
if not llm_config.supports_vision:
screenshots = None

messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_REQUEST,
data=json.dumps(
{
"model": llm_config.model_name,
"messages": messages,
# we're not using active_parameters here because it may contain sensitive information
**parameters,
}
).encode("utf-8"),
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
"model": llm_config.model_name,
"messages": messages,
# we're not using active_parameters here because it may contain sensitive information
**parameters,
}
).encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
step=step,
observer_cruise=observer_cruise,
observer_thought=observer_thought,
)
t_llm_request = time.perf_counter()
try:
# TODO (kerem): add a timeout to this call
@@ -231,12 +232,16 @@ async def llm_api_handler(
except Exception as e:
LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
raise LLMProviderError(llm_key) from e

await app.ARTIFACT_MANAGER.create_llm_artifact(
data=response.model_dump_json(indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
step=step,
observer_cruise=observer_cruise,
observer_thought=observer_thought,
)

if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE,
data=response.model_dump_json(indent=2).encode("utf-8"),
)
llm_cost = litellm.completion_cost(completion_response=response)
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
@@ -249,12 +254,13 @@
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
)
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
step=step,
observer_cruise=observer_cruise,
observer_thought=observer_thought,
)
return parsed_response

return llm_api_handler
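Taken together, the step-gated create_artifact blocks in both handlers collapse into single create_llm_artifact calls that also cover cruises and thoughts. A call-site sketch (the llm_key value and the surrounding variables are illustrative assumptions, not code from this diff):

# Hypothetical call site: passing observer_cruise instead of step routes the
# LLM_PROMPT / LLM_REQUEST / LLM_RESPONSE / LLM_RESPONSE_PARSED artifacts to the cruise.
handler = LLMAPIHandlerFactory.get_llm_api_handler("OPENAI_GPT4V")
parsed_response = await handler(
    prompt=prompt,
    observer_cruise=observer_cruise,
    screenshots=screenshots,
)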
3 changes: 3 additions & 0 deletions skyvern/forge/sdk/api/llm/models.py
@@ -4,6 +4,7 @@
from litellm import AllowedFailsPolicy

from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought
from skyvern.forge.sdk.settings_manager import SettingsManager


@@ -78,6 +79,8 @@ def __call__(
self,
prompt: str,
step: Step | None = None,
observer_cruise: ObserverCruise | None = None,
observer_thought: ObserverThought | None = None,
screenshots: list[bytes] | None = None,
parameters: dict[str, Any] | None = None,
) -> Awaitable[dict[str, Any]]: ...
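Because __call__ here defines the shared handler interface, both the router-backed and plain handlers type-check against one signature. A one-line illustration (the annotation is an example, not code from this diff):

# Hypothetical: annotating the factory result with the interface type lets a
# static type checker catch call sites that pass the observer arguments wrong.
handler: LLMAPIHandler = LLMAPIHandlerFactory.get_llm_api_handler("OPENAI_GPT4V")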
74 changes: 72 additions & 2 deletions skyvern/forge/sdk/artifact/manager.py
@@ -9,7 +9,7 @@
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.id import generate_artifact_id
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverThought
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought

LOG = structlog.get_logger(__name__)

@@ -103,12 +103,78 @@ async def create_observer_thought_artifact(
path=path,
)

async def create_observer_cruise_artifact(
self,
observer_cruise: ObserverCruise,
artifact_type: ArtifactType,
data: bytes | None = None,
path: str | None = None,
) -> str:
artifact_id = generate_artifact_id()
uri = app.STORAGE.build_observer_cruise_uri(artifact_id, observer_cruise, artifact_type)
return await self._create_artifact(
aio_task_primary_key=observer_cruise.observer_cruise_id,
artifact_id=artifact_id,
artifact_type=artifact_type,
uri=uri,
observer_cruise_id=observer_cruise.observer_cruise_id,
organization_id=observer_cruise.organization_id,
data=data,
path=path,
)

async def create_llm_artifact(
self,
data: bytes,
artifact_type: ArtifactType,
screenshots: list[bytes] | None = None,
step: Step | None = None,
observer_thought: ObserverThought | None = None,
observer_cruise: ObserverCruise | None = None,
) -> None:
Contributor review comment: Consider adding a check to ensure that at least one of step, observer_cruise, or observer_thought is provided. If none are provided, raise a ValueError to prevent silent failures.
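A sketch of that guard, placed at the top of the method (hypothetical; not part of the merged diff):

if step is None and observer_cruise is None and observer_thought is None:
    # Fail loudly instead of returning after creating no artifact at all.
    raise ValueError("create_llm_artifact requires a step, observer_cruise, or observer_thought")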

if step:
await self.create_artifact(
step=step,
artifact_type=artifact_type,
data=data,
)
for screenshot in screenshots or []:
await self.create_artifact(
step=step,
artifact_type=ArtifactType.SCREENSHOT_LLM,
data=screenshot,
)
elif observer_cruise:
await self.create_observer_cruise_artifact(
observer_cruise=observer_cruise,
artifact_type=artifact_type,
data=data,
)
for screenshot in screenshots or []:
await self.create_observer_cruise_artifact(
observer_cruise=observer_cruise,
artifact_type=ArtifactType.SCREENSHOT_LLM,
data=screenshot,
)
elif observer_thought:
await self.create_observer_thought_artifact(
observer_thought=observer_thought,
artifact_type=artifact_type,
data=data,
)
for screenshot in screenshots or []:
await self.create_observer_thought_artifact(
observer_thought=observer_thought,
artifact_type=ArtifactType.SCREENSHOT_LLM,
data=screenshot,
)

async def update_artifact_data(
self,
artifact_id: str | None,
organization_id: str | None,
data: bytes,
primary_key: Literal["task_id", "observer_thought_id"] = "task_id",
primary_key: Literal["task_id", "observer_thought_id", "observer_cruise_id"] = "task_id",
) -> None:
if not artifact_id or not organization_id:
return None
@@ -125,6 +191,10 @@ async def update_artifact_data(
if not artifact.observer_thought_id:
raise ValueError("Observer Thought ID is required to update artifact data.")
self.upload_aiotasks_map[artifact.observer_thought_id].append(aio_task)
elif primary_key == "observer_cruise_id":
if not artifact.observer_cruise_id:
raise ValueError("Observer Cruise ID is required to update artifact data.")
self.upload_aiotasks_map[artifact.observer_cruise_id].append(aio_task)

async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
return await app.STORAGE.retrieve_artifact(artifact)
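With observer_cruise_id accepted as a primary_key, a caller that rewrites the data of a cruise-scoped artifact would look roughly like this (a usage sketch; the artifact and new_data variables are illustrative):

# Hypothetical caller: primary_key decides which upload-tracking map the
# background upload task is appended to.
await app.ARTIFACT_MANAGER.update_artifact_data(
    artifact_id=artifact.artifact_id,
    organization_id=artifact.organization_id,
    data=new_data,
    primary_key="observer_cruise_id",
)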