feat: support for custom toolchoice in LLMs (#1102)

jayeshp19 and theomonnom authored Nov 26, 2024
Co-authored-by: Théo Monnom <[email protected]>
1 parent 28f1ab6 commit 0194e36
Showing 9 changed files with 247 additions and 14 deletions.
8 changes: 8 additions & 0 deletions .changeset/real-phones-cheat.md
@@ -0,0 +1,8 @@
+---
+"livekit-plugins-anthropic": patch
+"livekit-plugins-openai": patch
+"livekit-agents": patch
+"livekit-plugins-llama-index": patch
+---
+
+support for custom tool use in LLMs
2 changes: 2 additions & 0 deletions livekit-agents/livekit/agents/llm/__init__.py
@@ -25,6 +25,7 @@
     CompletionUsage,
     LLMCapabilities,
     LLMStream,
+    ToolChoice,
 )

 __all__ = [
@@ -52,4 +53,5 @@
     "LLMCapabilities",
     "FallbackAdapter",
     "AvailabilityChangedEvent",
+    "ToolChoice",
 ]
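With ToolChoice exported from the llm package, callers can import it directly. A minimal sketch; the get_weather name is a hypothetical function assumed to be registered in a FunctionContext:

from livekit.agents.llm import ToolChoice

# Force the model to call one specific function by name.
choice = ToolChoice(type="function", name="get_weather")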
11 changes: 9 additions & 2 deletions livekit-agents/livekit/agents/llm/fallback_adapter.py
@@ -4,15 +4,15 @@
 import dataclasses
 import time
 from dataclasses import dataclass
-from typing import AsyncIterable, Literal
+from typing import AsyncIterable, Literal, Union

 from livekit.agents._exceptions import APIConnectionError, APIError

 from ..log import logger
 from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
 from .chat_context import ChatContext
 from .function_context import FunctionContext
-from .llm import LLM, ChatChunk, LLMStream
+from .llm import LLM, ChatChunk, LLMStream, ToolChoice

 DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions(
     max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
@@ -66,6 +66,8 @@ def chat(
         temperature: float | None = None,
         n: int | None = 1,
         parallel_tool_calls: bool | None = None,
+        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
+        | None = None,
     ) -> "LLMStream":
         return FallbackLLMStream(
             llm=self,
@@ -75,6 +77,7 @@ def chat(
             temperature=temperature,
             n=n,
             parallel_tool_calls=parallel_tool_calls,
+            tool_choice=tool_choice,
         )


@@ -89,6 +92,8 @@ def __init__(
         temperature: float | None,
         n: int | None,
         parallel_tool_calls: bool | None,
+        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
+        | None = None,
     ) -> None:
         super().__init__(
             llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
@@ -97,6 +102,7 @@ def __init__(
         self._temperature = temperature
         self._n = n
         self._parallel_tool_calls = parallel_tool_calls
+        self._tool_choice = tool_choice

     async def _try_generate(
         self, *, llm: LLM, recovering: bool = False
@@ -108,6 +114,7 @@ async def _try_generate(
                 temperature=self._temperature,
                 n=self._n,
                 parallel_tool_calls=self._parallel_tool_calls,
+                tool_choice=self._tool_choice,
                 conn_options=dataclasses.replace(
                     self._conn_options,
                     max_retry=self._fallback_adapter._max_retry_per_llm,
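FallbackLLMStream stores the requested tool_choice and replays it on every _try_generate attempt, so a forced function call survives a provider failover. A hedged usage sketch; the list-of-LLMs constructor shape and the chat_ctx/fnc_ctx values are assumptions, not shown in this diff:

from livekit.agents.llm import FallbackAdapter, ToolChoice
from livekit.plugins import anthropic, openai

# Assumed constructor shape: an ordered list of LLMs to try on failure.
fallback = FallbackAdapter([openai.LLM(), anthropic.LLM()])

# tool_choice is forwarded unchanged to whichever LLM ends up serving
# the request, including after a failover.
stream = fallback.chat(
    chat_ctx=chat_ctx,  # an existing llm.ChatContext
    fnc_ctx=fnc_ctx,    # an existing llm.FunctionContext with tools registered
    tool_choice=ToolChoice(type="function", name="get_weather"),
)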
18 changes: 17 additions & 1 deletion livekit-agents/livekit/agents/llm/llm.py
@@ -5,7 +5,15 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from types import TracebackType
-from typing import Any, AsyncIterable, AsyncIterator, Generic, Literal, TypeVar, Union
+from typing import (
+    Any,
+    AsyncIterable,
+    AsyncIterator,
+    Generic,
+    Literal,
+    TypeVar,
+    Union,
+)

 from livekit import rtc
 from livekit.agents._exceptions import APIConnectionError, APIError
@@ -51,6 +59,12 @@ class ChatChunk:
     usage: CompletionUsage | None = None


+@dataclass
+class ToolChoice:
+    type: Literal["function"]
+    name: str
+
+
 TEvent = TypeVar("TEvent")


@@ -78,6 +92,8 @@ def chat(
         temperature: float | None = None,
         n: int | None = None,
         parallel_tool_calls: bool | None = None,
+        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
+        | None = None,
     ) -> "LLMStream": ...

     @property
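The Union in the abstract chat() signature lets callers keep the provider-neutral string modes ("auto", "required", "none") or pin one function via the new ToolChoice dataclass. A sketch against this API; the openai plugin choice, the Tools class, and the get_weather tool are illustrative assumptions:

import asyncio

from livekit.agents import llm
from livekit.agents.llm import ToolChoice
from livekit.plugins import openai  # any plugin implementing llm.LLM


class Tools(llm.FunctionContext):
    # Hypothetical tool; any ai_callable-decorated method works here.
    @llm.ai_callable(description="Look up the current weather for a location")
    async def get_weather(self, location: str) -> str:
        return f"sunny in {location}"


async def main() -> None:
    model = openai.LLM()
    chat_ctx = llm.ChatContext().append(role="user", text="Weather in Tokyo?")

    # A ToolChoice value pins one specific function by name; the strings
    # "auto", "required" and "none" keep their provider-neutral meaning.
    stream = model.chat(
        chat_ctx=chat_ctx,
        fnc_ctx=Tools(),
        tool_choice=ToolChoice(type="function", name="get_weather"),
    )
    async for chunk in stream:
        print(chunk)


asyncio.run(main())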
livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py
@@ -19,7 +19,16 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Awaitable, List, Tuple, get_args, get_origin
+from typing import (
+    Any,
+    Awaitable,
+    List,
+    Literal,
+    Tuple,
+    Union,
+    get_args,
+    get_origin,
+)

 import httpx
 from livekit import rtc
@@ -30,6 +39,7 @@
     llm,
     utils,
 )
+from livekit.agents.llm import ToolChoice
 from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

 import anthropic
@@ -45,6 +55,8 @@ class LLMOptions:
     model: str | ChatModels
     user: str | None
     temperature: float | None
+    parallel_tool_calls: bool | None
+    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] | None


 class LLM(llm.LLM):
@@ -57,6 +69,8 @@ def __init__(
         user: str | None = None,
         client: anthropic.AsyncClient | None = None,
         temperature: float | None = None,
+        parallel_tool_calls: bool | None = None,
+        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
     ) -> None:
         """
         Create a new instance of Anthropic LLM.
@@ -71,7 +85,13 @@
         if api_key is None:
             raise ValueError("Anthropic API key is required")

-        self._opts = LLMOptions(model=model, user=user, temperature=temperature)
+        self._opts = LLMOptions(
+            model=model,
+            user=user,
+            temperature=temperature,
+            parallel_tool_calls=parallel_tool_calls,
+            tool_choice=tool_choice,
+        )
         self._client = client or anthropic.AsyncClient(
             api_key=api_key,
             base_url=base_url,
@@ -95,9 +115,15 @@ def chat(
         temperature: float | None = None,
         n: int | None = 1,
         parallel_tool_calls: bool | None = None,
+        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
+        | None = None,
     ) -> "LLMStream":
         if temperature is None:
             temperature = self._opts.temperature
+        if parallel_tool_calls is None:
+            parallel_tool_calls = self._opts.parallel_tool_calls
+        if tool_choice is None:
+            tool_choice = self._opts.tool_choice

         opts: dict[str, Any] = dict()
         if fnc_ctx and len(fnc_ctx.ai_functions) > 0:
@@ -106,9 +132,20 @@
             fncs_desc.append(_build_function_description(fnc))

             opts["tools"] = fncs_desc
-
-        if fnc_ctx and parallel_tool_calls is not None:
-            opts["parallel_tool_calls"] = parallel_tool_calls
+            if tool_choice is not None:
+                anthropic_tool_choice: dict[str, Any] = {"type": "auto"}
+                if isinstance(tool_choice, ToolChoice):
+                    if tool_choice.type == "function":
+                        anthropic_tool_choice = {
+                            "type": "tool",
+                            "name": tool_choice.name,
+                        }
+                elif isinstance(tool_choice, str):
+                    if tool_choice == "required":
+                        anthropic_tool_choice = {"type": "any"}
+                if parallel_tool_calls is not None and parallel_tool_calls is False:
+                    anthropic_tool_choice["disable_parallel_tool_use"] = True
+                opts["tool_choice"] = anthropic_tool_choice

         latest_system_message = _latest_system_message(chat_ctx)
         anthropic_ctx = _build_anthropic_context(chat_ctx.messages, id(self))
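The Anthropic mapping reduces to a small pure function: ToolChoice(type="function", ...) becomes {"type": "tool", "name": ...}, "required" becomes {"type": "any"}, and "auto" (plus, in this version, "none") fall through to Anthropic's default, with disable_parallel_tool_use riding along on the same dict. A sketch mirroring the logic above, for illustration only:

from __future__ import annotations

from typing import Any, Literal, Union

from livekit.agents.llm import ToolChoice


def to_anthropic_tool_choice(
    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
    parallel_tool_calls: bool | None,
) -> dict[str, Any]:
    # Default: let the model decide ("none" also lands here in this version).
    choice: dict[str, Any] = {"type": "auto"}
    if isinstance(tool_choice, ToolChoice) and tool_choice.type == "function":
        # Pin one tool: Anthropic spells this {"type": "tool", "name": ...}.
        choice = {"type": "tool", "name": tool_choice.name}
    elif tool_choice == "required":
        # "required" maps to Anthropic's {"type": "any"}.
        choice = {"type": "any"}
    if parallel_tool_calls is False:
        choice["disable_parallel_tool_use"] = True
    return choice


assert to_anthropic_tool_choice("required", False) == {
    "type": "any",
    "disable_parallel_tool_use": True,
}
assert to_anthropic_tool_choice(ToolChoice(type="function", name="get_weather"), None) == {
    "type": "tool",
    "name": "get_weather",
}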
livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py
@@ -1,9 +1,12 @@
 from __future__ import annotations

+from typing import Literal, Union
+
 from livekit.agents import (
     APIConnectionError,
     llm,
 )
+from livekit.agents.llm import ToolChoice
 from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

 from llama_index.core.chat_engine.types import (
@@ -33,6 +36,8 @@ def chat(
         temperature: float | None = None,
         n: int | None = 1,
         parallel_tool_calls: bool | None = None,
+        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
+        | None = None,
     ) -> "LLMStream":
         if fnc_ctx is not None:
             logger.warning("fnc_ctx is currently not supported with llama_index.LLM")
livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py
@@ -18,11 +18,12 @@
 import json
 import uuid
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Literal, MutableSet
+from typing import Any, Callable, Dict, Literal, MutableSet, Union

 import httpx
 from livekit import rtc
 from livekit.agents import llm, utils
+from livekit.agents.llm import ToolChoice
 from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

 from openai import AsyncAssistantEventHandler, AsyncClient
@@ -172,6 +173,8 @@ def chat(
         temperature: float | None = None,
         n: int | None = None,
         parallel_tool_calls: bool | None = None,
+        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
+        | None = None,
     ):
         if n is not None:
             logger.warning("OpenAI Assistants does not support the 'n' parameter")