diff --git a/src/magentic/chat.py b/src/magentic/chat.py
index 2e0dbab7..13e4e6aa 100644
--- a/src/magentic/chat.py
+++ b/src/magentic/chat.py
@@ -90,23 +90,37 @@ def add_assistant_message(self: Self, content: Any) -> Self:
         """Add an assistant message to the chat."""
         return self.add_message(AssistantMessage(content=content))
 
-    def submit(self: Self) -> Self:
+    def submit(self: Self, num_retries: int | None = None) -> Self:
         """Request an LLM message to be added to the chat."""
-        output_message: AssistantMessage[Any] = self.model.complete(
-            messages=self._messages,
-            functions=self._functions,
-            output_types=self._output_types,
-        )
-        return self.add_message(output_message)
+        retries = num_retries if num_retries is not None else 0
+        for attempt in range(retries + 1):
+            try:
+                output_message: AssistantMessage[Any] = self.model.complete(
+                    messages=self._messages,
+                    functions=self._functions,
+                    output_types=self._output_types,
+                )
+                return self.add_message(output_message)
+            except Exception as e:
+                if attempt == retries:
+                    raise
+                self._messages.append(UserMessage(content=f"Error: {str(e)}"))
 
-    async def asubmit(self: Self) -> Self:
+    async def asubmit(self: Self, num_retries: int | None = None) -> Self:
         """Async version of `submit`."""
-        output_message: AssistantMessage[Any] = await self.model.acomplete(
-            messages=self._messages,
-            functions=self._functions,
-            output_types=self._output_types,
-        )
-        return self.add_message(output_message)
+        retries = num_retries if num_retries is not None else 0
+        for attempt in range(retries + 1):
+            try:
+                output_message: AssistantMessage[Any] = await self.model.acomplete(
+                    messages=self._messages,
+                    functions=self._functions,
+                    output_types=self._output_types,
+                )
+                return self.add_message(output_message)
+            except Exception as e:
+                if attempt == retries:
+                    raise
+                self._messages.append(UserMessage(content=f"Error: {str(e)}"))
 
     def exec_function_call(self: Self) -> Self:
         """If the last message is a function call, execute it and add the result."""
diff --git a/src/magentic/prompt_function.py b/src/magentic/prompt_function.py
index 179c4ab8..efc12ea8 100644
--- a/src/magentic/prompt_function.py
+++ b/src/magentic/prompt_function.py
@@ -38,6 +38,7 @@ def __init__(
         functions: list[Callable[..., Any]] | None = None,
         stop: list[str] | None = None,
         model: ChatModel | None = None,
+        num_retries: int | None = None,
     ):
         self._signature = inspect.Signature(
             parameters=parameters,
@@ -49,6 +50,7 @@
         self._model = model
 
         self._return_types = list(split_union_type(return_type))
+        self._num_retries = num_retries
 
     @property
     def functions(self) -> list[Callable[..., Any]]:
@@ -78,13 +80,21 @@ class PromptFunction(BasePromptFunction[P, R], Generic[P, R]):
 
     def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
         """Query the LLM with the formatted prompt template."""
-        message = self.model.complete(
-            messages=[UserMessage(content=self.format(*args, **kwargs))],
-            functions=self._functions,
-            output_types=self._return_types,
-            stop=self._stop,
-        )
-        return message.content
+        retries = self._num_retries if self._num_retries is not None else 0
+        messages = [UserMessage(content=self.format(*args, **kwargs))]
+        for attempt in range(retries + 1):
+            try:
+                message = self.model.complete(
+                    messages=messages,
+                    functions=self._functions,
+                    output_types=self._return_types,
+                    stop=self._stop,
+                )
+                return message.content
+            except Exception as e:
+                if attempt == retries:
+                    raise
+                messages.append(UserMessage(content=f"Error: {str(e)}"))
 
 
 class AsyncPromptFunction(BasePromptFunction[P, R], Generic[P, R]):
@@ -92,13 +102,21 @@ class AsyncPromptFunction(BasePromptFunction[P, R], Generic[P, R]):
 
     async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
         """Asynchronously query the LLM with the formatted prompt template."""
-        message = await self.model.acomplete(
-            messages=[UserMessage(content=self.format(*args, **kwargs))],
-            functions=self._functions,
-            output_types=self._return_types,
-            stop=self._stop,
-        )
-        return message.content
+        retries = self._num_retries if self._num_retries is not None else 0
+        messages = [UserMessage(content=self.format(*args, **kwargs))]
+        for attempt in range(retries + 1):
+            try:
+                message = await self.model.acomplete(
+                    messages=messages,
+                    functions=self._functions,
+                    output_types=self._return_types,
+                    stop=self._stop,
+                )
+                return message.content
+            except Exception as e:
+                if attempt == retries:
+                    raise
+                messages.append(UserMessage(content=f"Error: {str(e)}"))
 
 
 class PromptDecorator(Protocol):
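
Below is a minimal usage sketch of the new num_retries parameter on Chat.submit, not part of the diff itself. It assumes the existing Chat(messages=...) constructor, the .messages property, and the top-level UserMessage export, none of which are shown above.

# Hypothetical example: retry a Chat.submit call up to 2 extra times.
from magentic import UserMessage
from magentic.chat import Chat

chat = Chat(messages=[UserMessage(content="Count the words in this sentence.")])
# On each failed attempt the error is appended to the conversation as a
# UserMessage("Error: ...") before retrying; the last failure is re-raised.
chat = chat.submit(num_retries=2)
print(chat.messages[-1].content)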