Skip to content

Commit

Permalink
add provider checking based on model name and provider
Browse files Browse the repository at this point in the history
  • Loading branch information
PCSwingle committed Apr 18, 2024
1 parent e51fdfb commit 1edafc3
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 19 deletions.
3 changes: 2 additions & 1 deletion mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
validate_and_format_path,
)
from mentat.interval import parse_intervals, split_intervals_from_path
from mentat.llm_api_handler import get_max_tokens
from mentat.llm_api_handler import api_guard, get_max_tokens
from mentat.session_context import SESSION_CONTEXT
from mentat.session_stream import SessionStream
from mentat.utils import get_relative_path, mentat_dir_path
Expand Down Expand Up @@ -59,6 +59,7 @@ def __init__(
self.include_files: Dict[Path, List[CodeFeature]] = {}
self.ignore_files: Set[Path] = set()

@api_guard
async def refresh_daemon(self):
"""Call before interacting with context to ensure daemon is up to date."""

Expand Down
78 changes: 60 additions & 18 deletions mentat/llm_api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
from dotenv import load_dotenv
from openai.types.chat.completion_create_params import ResponseFormat
from spice import EmbeddingResponse, Spice, SpiceMessage, SpiceResponse, StreamingSpiceResponse, TranscriptionResponse
from spice.errors import APIConnectionError, NoAPIKeyError
from spice.errors import APIConnectionError, AuthenticationError, InvalidProviderError, NoAPIKeyError
from spice.models import WHISPER_1
from spice.providers import OPEN_AI
from spice.spice import UnknownModelError, get_model_from_name
from spice.spice import UnknownModelError, get_model_from_name, get_provider_from_name

from mentat.errors import MentatError, ReturnToUser
from mentat.session_context import SESSION_CONTEXT
Expand Down Expand Up @@ -58,13 +57,20 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> RetType:
assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!"
try:
return await func(*args, **kwargs)
except AuthenticationError:
raise MentatError("Authentication error: Check your api key and try again.")
except APIConnectionError:
raise MentatError("API connection error: please check your internet connection and try again.")
raise MentatError("API connection error: Check your internet connection and try again.")
except UnknownModelError:
SESSION_CONTEXT.get().stream.send(
"Unknown model. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()
except InvalidProviderError:
SESSION_CONTEXT.get().stream.send(
"Unknown provider. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()

return async_wrapper # pyright: ignore[reportReturnType]
else:
Expand All @@ -73,13 +79,20 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> RetType:
assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!"
try:
return func(*args, **kwargs)
except AuthenticationError:
raise MentatError("Authentication error: Check your api key and try again.")
except APIConnectionError:
raise MentatError("API connection error: please check your internet connection and try again.")
raise MentatError("API connection error: Check your internet connection and try again.")
except UnknownModelError:
SESSION_CONTEXT.get().stream.send(
"Unknown model. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()
except InvalidProviderError:
SESSION_CONTEXT.get().stream.send(
"Unknown provider. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()

return sync_wrapper

Expand Down Expand Up @@ -142,19 +155,48 @@ async def initialize_client(self):
if not load_dotenv(mentat_dir_path / ".env") and not load_dotenv(ctx.cwd / ".env"):
load_dotenv()

try:
self.spice.load_provider(OPEN_AI)
except NoAPIKeyError:
from mentat.session_input import collect_user_input

ctx.stream.send(
"No OpenAI api key detected. To avoid entering your api key on startup, create a .env file in"
" ~/.mentat/.env or in your workspace root.",
style="warning",
)
ctx.stream.send("Enter your api key:", style="info")
key = (await collect_user_input(log_input=False)).data
os.environ["OPENAI_API_KEY"] = key
provider = get_model_from_name(ctx.config.model).provider
if ctx.config.provider is not None:
try:
provider = get_provider_from_name(ctx.config.provider)
except InvalidProviderError:
ctx.stream.send(
f"Unknown provider {ctx.config.provider}. Use /config provider <provider> to set your provider.",
style="warning",
)
elif provider is None:
ctx.stream.send(f"Unknown model {ctx.config.model}. Use /config provider <provider> to set your provider.")

if provider is not None:
try:
self.spice.load_provider(provider)
except NoAPIKeyError:
from mentat.session_input import collect_user_input

match provider.name:
case "open_ai":
env_variable = "OPENAI_API_KEY"
case "anthropic":
env_variable = "ANTHROPIC_API_KEY"
case "azure":
if os.getenv("AZURE_OPENAI_ENDPOINT") is None:
raise MentatError(
f"No AZURE_OPENAI_ENDPOINT detected. Create a .env file in ~/.mentat/.env or in your workspace root with your Azure endpoint."
)
env_variable = "AZURE_OPENAI_KEY"
case _:
raise MentatError(
f"No api key detected for provider {provider.name}. Create a .env file in ~/.mentat/.env or in your workspace root with your api key"
)

ctx.stream.send(
f"No {provider.name} api key detected. To avoid entering your api key on startup, create a .env file in"
" ~/.mentat/.env or in your workspace root.",
style="warning",
)
ctx.stream.send("Enter your api key:", style="info")
key = (await collect_user_input(log_input=False)).data
os.environ[env_variable] = key

@overload
async def call_llm_api(
Expand Down

0 comments on commit 1edafc3

Please sign in to comment.