diff --git a/docker-compose.yml b/docker-compose.yml
index 3d1a16eed..7f2ab3c40 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,5 +1,3 @@
-version: '3.8'
-
 services:
   postgres:
     image: postgres:14-alpine
@@ -23,7 +21,6 @@ services:
   skyvern:
     image: public.ecr.aws/skyvern/skyvern:latest
     restart: on-failure
-    # comment out if you want to externally call skyvern API
     ports:
       - 8000:8000
     volumes:
@@ -31,21 +28,28 @@ services:
       - ./videos:/data/videos
       - ./har:/data/har
       - ./.streamlit:/app/.streamlit
+      - ./skyvern/config.py:/app/skyvern/config.py
+      - ./skyvern/forge/sdk/api/llm/models.py:/app/skyvern/forge/sdk/api/llm/models.py
+      - ./skyvern/forge/sdk/api/llm/config_registry.py:/app/skyvern/forge/sdk/api/llm/config_registry.py
+      - ./skyvern/forge/sdk/api/llm/api_handler_factory.py:/app/skyvern/forge/sdk/api/llm/api_handler_factory.py
     environment:
       - DATABASE_STRING=postgresql+psycopg://skyvern:skyvern@postgres:5432/skyvern
       - BROWSER_TYPE=chromium-headful
-      - ENABLE_OPENAI=true
-      - OPENAI_API_KEY=
-      # If you want to use other LLM provider, like azure and anthropic:
-      # - ENABLE_ANTHROPIC=true
-      # - LLM_KEY=ANTHROPIC_CLAUDE3_OPUS
-      # - ANTHROPIC_API_KEY=
-      # - ENABLE_AZURE=true
-      # - LLM_KEY=AZURE_OPENAI
-      # - AZURE_DEPLOYMENT=
-      # - AZURE_API_KEY=
-      # - AZURE_API_BASE=
-      # - AZURE_API_VERSION=
+      # Enable Ollama
+      - ENABLE_OLLAMA=true
+      - LLM_KEY=OLLAMA_TEXT
+      - OLLAMA_API_BASE=http://192.168.1.3:11434
+      - OLLAMA_TEXT_MODEL=llama3.1:70b
+      - OLLAMA_VISION_MODEL=moondream:latest
+      - OLLAMA_CONTEXT_WINDOW=8192
+      # Disable other providers
+      - ENABLE_OPENAI=false
+      - ENABLE_ANTHROPIC=false
+      - ENABLE_AZURE=false
+      - ENABLE_AZURE_GPT4O_MINI=false
+      - ENABLE_BEDROCK=false
+      - LITELLM_VERBOSE=true
+      - LOG_LEVEL=DEBUG
     depends_on:
       postgres:
         condition: service_healthy
diff --git a/skyvern/config.py b/skyvern/config.py
index fb6356f85..ed577b169 100644
--- a/skyvern/config.py
+++ b/skyvern/config.py
@@ -65,8 +65,8 @@ class Settings(BaseSettings):
     # browser settings
     BROWSER_LOCALE: str = "en-US"
     BROWSER_TIMEZONE: str = "America/New_York"
-    BROWSER_WIDTH: int = 1920
-    BROWSER_HEIGHT: int = 1080
+    BROWSER_WIDTH: int = 1366
+    BROWSER_HEIGHT: int = 768
 
     # Workflow constant parameters
     WORKFLOW_DOWNLOAD_DIRECTORY_PARAMETER_KEY: str = "SKYVERN_DOWNLOAD_DIRECTORY"
@@ -93,7 +93,7 @@ class Settings(BaseSettings):
     LLM_KEY: str = "OPENAI_GPT4O"
     SECONDARY_LLM_KEY: str | None = None
     # COMMON
-    LLM_CONFIG_TIMEOUT: int = 300
+    LLM_CONFIG_TIMEOUT: int = 60
     LLM_CONFIG_MAX_TOKENS: int = 4096
     LLM_CONFIG_TEMPERATURE: float = 0
     # LLM PROVIDER SPECIFIC
@@ -125,6 +125,14 @@ class Settings(BaseSettings):
 
     SVG_MAX_LENGTH: int = 100000
 
+    # Add these new settings for Ollama
+    ENABLE_OLLAMA: bool = False
+    OLLAMA_API_BASE: str = "http://192.168.1.3:11434"
+    OLLAMA_MODEL: str = "llama2"  # Default model, can be changed
+    OLLAMA_TEXT_MODEL: str = "command-r:latest"  # Model for text-only requests
+    OLLAMA_VISION_MODEL: str = "llava:34b"  # Model for vision requests
+    OLLAMA_CONTEXT_WINDOW: int = 8192
+
     def is_cloud_environment(self) -> bool:
         """
         :return: True if env is not local, else False
diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py
index d73519b55..c5523d87f 100644
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -171,92 +171,230 @@ async def llm_api_handler(
                 parameters = LLMAPIHandlerFactory.get_api_parameters()
 
             active_parameters.update(parameters)
-            if llm_config.litellm_params:  # type: ignore
-                active_parameters.update(llm_config.litellm_params)  # type: ignore
+            if llm_config.litellm_params:
+                active_parameters.update(llm_config.litellm_params)
 
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_PROMPT,
-                    data=prompt.encode("utf-8"),
-                )
-                for screenshot in screenshots or []:
+            # Get timeout and max_retries from settings
+            timeout = active_parameters.pop('timeout', SettingsManager.get_settings().LLM_CONFIG_TIMEOUT)
+            max_retries = active_parameters.pop('max_retries', 3)
+
+            # Handle Ollama-specific options
+            ollama_options = {
+                "num_ctx": SettingsManager.get_settings().OLLAMA_CONTEXT_WINDOW,
+                "temperature": SettingsManager.get_settings().LLM_CONFIG_TEMPERATURE,
+                "num_predict": SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS,
+            }
+
+            # Format messages for Ollama
+            async def format_ollama_messages(messages: list) -> str:
+                formatted_prompt = ""
+                for message in messages:
+                    if message["role"] == "user":
+                        content = message["content"]
+                        if isinstance(content, list):
+                            # Handle multi-modal content
+                            for item in content:
+                                if item["type"] == "text":
+                                    formatted_prompt += f"{item['text']}\n"
+                        else:
+                            formatted_prompt += f"{content}\n"
+                return formatted_prompt.strip()
+
+            try:
+                if llm_config.model_name.startswith("ollama/") and screenshots:
+                    # Use vision model for image-related prompts
+                    vision_config = LLMConfigRegistry.get_config("OLLAMA_VISION")
+                    text_config = LLMConfigRegistry.get_config("OLLAMA_TEXT")
+
+                    # First, process the image with vision model
+                    vision_messages = await llm_messages_builder(prompt, screenshots, vision_config.add_assistant_prefix)
+                    vision_prompt = await format_ollama_messages(vision_messages)
+
+                    LOG.info("Calling Vision LLM API", model=vision_config.model_name)
+                    vision_response = await litellm.acompletion(
+                        model=vision_config.model_name,
+                        messages=[{"role": "user", "content": vision_prompt}],
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        options=ollama_options
+                    )
+
+                    # Log the raw vision response
+                    LOG.info("Raw Vision LLM Response",
+                        model=vision_config.model_name,
+                        response=vision_response.choices[0].message.content if vision_response.choices else "No response")
+
+                    # Extract vision analysis from response
+                    vision_analysis = vision_response.choices[0].message.content
+
+                    # Then, process with text model
+                    text_prompt = f"Vision Analysis:\n{vision_analysis}\n\nOriginal Prompt:\n{prompt}"
+                    text_messages = await llm_messages_builder(text_prompt, None, text_config.add_assistant_prefix)
+                    formatted_text_prompt = await format_ollama_messages(text_messages)
+
+                    LOG.info("Calling Text LLM API", model=text_config.model_name)
+                    response = await litellm.acompletion(
+                        model=text_config.model_name,
+                        messages=[{"role": "user", "content": formatted_text_prompt}],
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        options=ollama_options
+                    )
+
+                    # Log the raw text response
+                    LOG.info("Raw Text LLM Response",
+                        model=text_config.model_name,
+                        response=response.choices[0].message.content if response.choices else "No response")
+                else:
+                    # Use regular handling for non-Ollama models or text-only requests
+                    messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
+
+                    if llm_config.model_name.startswith("ollama/"):
+                        # Format messages for Ollama
+                        formatted_prompt = await format_ollama_messages(messages)
+                        LOG.info("Full prompt being sent to LLM",
+                            llm_key=llm_key,
+                            model=llm_config.model_name,
+                            messages=formatted_prompt)
+                        response = await litellm.acompletion(
+                            model=llm_config.model_name,
+                            messages=[{"role": "user", "content": formatted_prompt}],
+                            timeout=timeout,
+                            max_retries=max_retries,
+                            options=ollama_options
+                        )
+
+                        # Log the raw response
+                        LOG.info("Raw LLM Response",
+                            model=llm_config.model_name,
+                            response=response.choices[0].message.content if response.choices else "No response")
+                    else:
+                        # Regular handling for other models
+                        LOG.info("Full prompt being sent to LLM",
+                            llm_key=llm_key,
+                            model=llm_config.model_name,
+                            messages=json.dumps(messages, indent=2))
+                        response = await litellm.acompletion(
+                            model=llm_config.model_name,
+                            messages=messages,
+                            timeout=timeout,
+                            max_retries=max_retries,
+                            **active_parameters
+                        )
+
+                        # Log the raw response
+                        LOG.info("Raw LLM Response",
+                            model=llm_config.model_name,
+                            response=response.choices[0].message.content if response.choices else "No response")
+
+                # Create artifact for raw LLM response
+                if step:
                     await app.ARTIFACT_MANAGER.create_artifact(
                         step=step,
-                        artifact_type=ArtifactType.SCREENSHOT_LLM,
-                        data=screenshot,
+                        artifact_type=ArtifactType.LLM_RESPONSE,
+                        data=response.model_dump_json(indent=2).encode("utf-8"),
                     )
 
-            # TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
-            if not llm_config.supports_vision:
-                screenshots = None
+                # Handle cost calculation if needed
+                if not llm_config.skip_cost_calculation:
+                    llm_cost = litellm.completion_cost(completion_response=response)
+                    prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
+                    completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
+                    await app.DATABASE.update_step(
+                        task_id=step.task_id,
+                        step_id=step.step_id,
+                        organization_id=step.organization_id,
+                        incremental_cost=llm_cost,
+                        incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
+                        incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
+                    )
+
+                # Process the response content
+                if response and response.choices:
+                    content = response.choices[0].message.content
+
+                    # Clean up the response content and extract JSON if present
+                    def extract_json_from_content(content: str) -> str:
+                        content = content.strip()
+                        # Look for JSON code block
+                        if "```json" in content:
+                            # Extract content between ```json and ```
+                            parts = content.split("```json")
+                            if len(parts) > 1:
+                                json_part = parts[1].split("```")[0]
+                                return json_part.strip()
+                        # Look for just ``` code block
+                        elif "```" in content:
+                            # Extract content between ``` and ```
+                            parts = content.split("```")
+                            if len(parts) > 1:
+                                return parts[1].strip()
+                        # If no code blocks found, return the original content
+                        return content
+
+                    try:
+                        # Parse JSON if expected
+                        if "JSON" in prompt:
+                            # Extract and clean JSON content
+                            json_content = extract_json_from_content(content)
+                            LOG.debug("Attempting to parse JSON response", content=json_content)
+
+                            try:
+                                parsed_content = json.loads(json_content)
+                            except json.JSONDecodeError:
+                                # If JSON parsing fails, try to clean the content further
+                                cleaned_content = json_content.strip()
+                                if cleaned_content.startswith('json'):
+                                    cleaned_content = cleaned_content.split('\n', 1)[1]
+                                LOG.debug("Retrying JSON parse with cleaned content", content=cleaned_content)
+                                parsed_content = json.loads(cleaned_content)
+
+                            # Ensure response has required keys for action parsing
+                            if "actions" not in parsed_content and "action" in parsed_content:
+                                # Handle case where LLM returns single action
+                                parsed_content = {"actions": [parsed_content]}
+                            elif "actions" not in parsed_content and not any(key in parsed_content for key in ["action", "actions"]):
+                                # If no action-related keys found, wrap the entire response
+                                if "confidence_float" in parsed_content and "shape" in parsed_content:
+                                    # Special case for SVG shape analysis
+                                    return parsed_content
+                                else:
+                                    LOG.warning("Response missing actions key, wrapping content", content=parsed_content)
+                                    parsed_content = {"actions": [{"action": "UNKNOWN", "data": parsed_content}]}
+
+                            # Create artifact for parsed response
+                            if step:
+                                await app.ARTIFACT_MANAGER.create_artifact(
+                                    step=step,
+                                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                                    data=json.dumps(parsed_content, indent=2).encode("utf-8"),
+                                )
+
+                            return parsed_content
+
+                        # Handle non-JSON responses
+                        result = {"content": content}
+                        if step:
+                            await app.ARTIFACT_MANAGER.create_artifact(
+                                step=step,
+                                artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                                data=json.dumps(result, indent=2).encode("utf-8"),
+                            )
+                        return result
+
+                    except json.JSONDecodeError as e:
+                        LOG.error("Failed to parse JSON response",
+                            content=content,
+                            error=str(e),
+                            raw_content=content)
+                        return {"error": "Invalid JSON response", "raw_content": content}
+
+                LOG.error("Empty response from LLM")
+                return {"error": "Empty response from LLM"}
-            messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_REQUEST,
-                    data=json.dumps(
-                        {
-                            "model": llm_config.model_name,
-                            "messages": messages,
-                            # we're not using active_parameters here because it may contain sensitive information
-                            **parameters,
-                        }
-                    ).encode("utf-8"),
-                )
-            t_llm_request = time.perf_counter()
-            try:
-                # TODO (kerem): add a timeout to this call
-                # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
-                # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
-                LOG.info("Calling LLM API", llm_key=llm_key, model=llm_config.model_name)
-                response = await litellm.acompletion(
-                    model=llm_config.model_name,
-                    messages=messages,
-                    timeout=SettingsManager.get_settings().LLM_CONFIG_TIMEOUT,
-                    **active_parameters,
-                )
-                LOG.info("LLM API call successful", llm_key=llm_key, model=llm_config.model_name)
-            except litellm.exceptions.APIError as e:
-                raise LLMProviderErrorRetryableTask(llm_key) from e
-            except CancelledError:
-                t_llm_cancelled = time.perf_counter()
-                LOG.error(
-                    "LLM request got cancelled",
-                    llm_key=llm_key,
-                    model=llm_config.model_name,
-                    duration=t_llm_cancelled - t_llm_request,
-                )
-                raise LLMProviderError(llm_key)
             except Exception as e:
-                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
+                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key, error=str(e))
                 raise LLMProviderError(llm_key) from e
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE,
-                    data=response.model_dump_json(indent=2).encode("utf-8"),
-                )
-            llm_cost = litellm.completion_cost(completion_response=response)
-            prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
-            completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
-            await app.DATABASE.update_step(
-                task_id=step.task_id,
-                step_id=step.step_id,
-                organization_id=step.organization_id,
-                incremental_cost=llm_cost,
-                incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
-                incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
-            )
-            parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
-                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
-                )
-            return parsed_response
 
         return llm_api_handler
diff --git a/skyvern/forge/sdk/api/llm/config_registry.py b/skyvern/forge/sdk/api/llm/config_registry.py
index 1507ebabf..b5b856adf 100644
--- a/skyvern/forge/sdk/api/llm/config_registry.py
+++ b/skyvern/forge/sdk/api/llm/config_registry.py
@@ -43,7 +43,7 @@ def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
         return cls._configs[llm_key]
 
 
-# if none of the LLM providers are enabled, raise an error
+# if none of the LLM providers (including Ollama) are enabled, raise an error
 if not any(
     [
         SettingsManager.get_settings().ENABLE_OPENAI,
@@ -51,6 +51,7 @@ def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
         SettingsManager.get_settings().ENABLE_AZURE,
         SettingsManager.get_settings().ENABLE_AZURE_GPT4O_MINI,
         SettingsManager.get_settings().ENABLE_BEDROCK,
+        SettingsManager.get_settings().ENABLE_OLLAMA,
     ]
 ):
     raise NoProviderEnabledError()
@@ -216,3 +217,35 @@ def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
         add_assistant_prefix=False,
     ),
 )
+
+# Ollama configurations
+if SettingsManager.get_settings().ENABLE_OLLAMA:
+    # Register vision model
+    LLMConfigRegistry.register_config(
+        "OLLAMA_VISION",
+        LLMConfig(
+            f"ollama/{SettingsManager.get_settings().OLLAMA_VISION_MODEL}",
+            [],  # No API key required
+            litellm_params=LiteLLMParams(
+                api_base=SettingsManager.get_settings().OLLAMA_API_BASE,
+            ),
+            supports_vision=True,
+            add_assistant_prefix=False,
+            skip_cost_calculation=True,
+        ),
+    )
+
+    # Register text model
+    LLMConfigRegistry.register_config(
+        "OLLAMA_TEXT",
+        LLMConfig(
f"ollama/{SettingsManager.get_settings().OLLAMA_TEXT_MODEL}", + [], # No API key required + litellm_params=LiteLLMParams( + api_base=SettingsManager.get_settings().OLLAMA_API_BASE, + ), + supports_vision=False, + add_assistant_prefix=False, + skip_cost_calculation=True, + ), + ) diff --git a/skyvern/forge/sdk/api/llm/models.py b/skyvern/forge/sdk/api/llm/models.py index 015ad5745..5a82dfc8d 100644 --- a/skyvern/forge/sdk/api/llm/models.py +++ b/skyvern/forge/sdk/api/llm/models.py @@ -33,7 +33,8 @@ def get_missing_env_vars(self) -> list[str]: @dataclass(frozen=True) class LLMConfig(LLMConfigBase): - litellm_params: Optional[LiteLLMParams] = field(default=None) + litellm_params: Optional[LiteLLMParams] = None + skip_cost_calculation: bool = False @dataclass(frozen=True)