feat: inference process & end of utterance plugin #1133

Open: wants to merge 30 commits into base: main from theo/inference-proc

Commits (30):
- 0e792d7 add endpoint detector (jeradf, Oct 16, 2024)
- bc0b286 sort imports (jeradf, Oct 17, 2024)
- 6ae0937 pass convo history to model prediction (jeradf, Oct 17, 2024)
- 0212d6a formatz (jeradf, Oct 17, 2024)
- ba9afbc no need to deepcopy convo (jeradf, Oct 17, 2024)
- 830f4fc wip (theomonnom, Nov 25, 2024)
- b26526f Merge branch 'turn_detector' into theo/inference-proc (theomonnom, Nov 25, 2024)
- 940eefe wip (theomonnom, Nov 25, 2024)
- 047f79d Delete endpoint_detector.py (theomonnom, Nov 25, 2024)
- 644f958 Discard changes to examples/voice-pipeline-agent/minimal_assistant.py (theomonnom, Nov 25, 2024)
- f7bcfb7 wip (theomonnom, Nov 25, 2024)
- bd43f51 Merge branch 'theo/inference-proc' of https://github.com/livekit/agen… (theomonnom, Nov 25, 2024)
- bc0bca8 wip (theomonnom, Nov 26, 2024)
- d30ef1e Update CHANGELOG.md (theomonnom, Nov 26, 2024)
- a4ca77c wip (theomonnom, Nov 26, 2024)
- 53e06ed Merge branch 'theo/inference-proc' of https://github.com/livekit/agen… (theomonnom, Nov 26, 2024)
- 6c79095 Merge branch 'main' into theo/inference-proc (theomonnom, Nov 26, 2024)
- 0d188a8 working version (theomonnom, Nov 26, 2024)
- ac2ff51 Update pipeline_agent.py (theomonnom, Nov 26, 2024)
- ae24725 Update inference_executor.py (theomonnom, Nov 26, 2024)
- 2bea42b wip (theomonnom, Nov 26, 2024)
- 44cab40 share _ProcClient & remove duplicated code (theomonnom, Nov 27, 2024)
- 68304f9 Merge branch 'main' into theo/inference-proc (theomonnom, Nov 27, 2024)
- 0f4e630 Create wild-walls-occur.md (theomonnom, Nov 27, 2024)
- e8da65b Delete inference_runner.py (theomonnom, Nov 27, 2024)
- f6345e5 Merge branch 'theo/inference-proc' of https://github.com/livekit/agen… (theomonnom, Nov 27, 2024)
- 3153096 wip (theomonnom, Nov 27, 2024)
- be961e6 Update job_proc_executor.py (theomonnom, Nov 27, 2024)
- 3fa6ad6 remove duplicate (theomonnom, Nov 27, 2024)
- 01d02b4 wip (theomonnom, Nov 27, 2024)
6 changes: 6 additions & 0 deletions .changeset/wild-walls-occur.md
@@ -0,0 +1,6 @@
---
"livekit-agents": patch
"livekit-plugins-eou": minor
---

feat: inference process & end of utterance plugin
3 changes: 2 additions & 1 deletion examples/voice-pipeline-agent/minimal_assistant.py
@@ -13,7 +13,7 @@
     metrics,
 )
 from livekit.agents.pipeline import VoicePipelineAgent
-from livekit.plugins import deepgram, openai, silero
+from livekit.plugins import deepgram, eou, openai, silero

 load_dotenv()
 logger = logging.getLogger("voice-assistant")
@@ -49,6 +49,7 @@ async def entrypoint(ctx: JobContext):
         stt=deepgram.STT(model=dg_model),
         llm=openai.LLM(),
         tts=openai.TTS(),
+        eou=eou.EOU(),
         chat_ctx=initial_ctx,
     )
39 changes: 39 additions & 0 deletions livekit-agents/livekit/agents/inference_runner.py
@@ -0,0 +1,39 @@
from __future__ import annotations

import threading
from abc import ABC, abstractmethod
from typing import ClassVar, Protocol, Type


class _RunnerMeta(Protocol):
    INFERENCE_METHOD: ClassVar[str]


_RunnersDict = dict[str, Type["_InferenceRunner"]]


# kept private until we stabilize the API (only used for EOU today)
class _InferenceRunner(ABC, _RunnerMeta):
    registered_runners: _RunnersDict = {}

    @classmethod
    def register_runner(cls, runner_class: Type["_InferenceRunner"]) -> None:
        if threading.current_thread() != threading.main_thread():
            raise RuntimeError("InferenceRunner must be registered on the main thread")

        if runner_class.INFERENCE_METHOD in cls.registered_runners:
            raise ValueError(
                f"InferenceRunner {runner_class.INFERENCE_METHOD} already registered"
            )

        cls.registered_runners[runner_class.INFERENCE_METHOD] = runner_class

    @abstractmethod
    def initialize(self) -> None:
        """Initialize the runner. This is used to load models, etc."""
        ...

    @abstractmethod
    def run(self, data: bytes) -> bytes | None:
        """Run inference on the given data."""
        ...
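To make the contract concrete, a registration could look like the sketch below. The _EchoRunner class and its "lk_echo" method name are hypothetical (and _InferenceRunner is private API), so treat this as illustration only:

from livekit.agents.inference_runner import _InferenceRunner


class _EchoRunner(_InferenceRunner):
    # dispatch key that InferenceRequest.method will carry (hypothetical value)
    INFERENCE_METHOD = "lk_echo"

    def initialize(self) -> None:
        # a real runner loads model weights here, inside the inference process
        pass

    def run(self, data: bytes) -> bytes | None:
        # trivial inference: echo the payload back unchanged
        return data


# must run on the main thread, before the inference process is started
_InferenceRunner.register_runner(_EchoRunner)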
10 changes: 6 additions & 4 deletions livekit-agents/livekit/agents/ipc/__init__.py
@@ -1,17 +1,19 @@
 from . import (
     channel,
+    inference_proc_executor,
     job_executor,
-    proc_job_executor,
+    job_proc_executor,
+    job_thread_executor,
     proc_pool,
     proto,
-    thread_job_executor,
 )

 __all__ = [
     "proto",
     "channel",
     "proc_pool",
-    "proc_job_executor",
-    "thread_job_executor",
+    "job_proc_executor",
+    "job_thread_executor",
+    "inference_proc_executor",
     "job_executor",
 ]
7 changes: 7 additions & 0 deletions livekit-agents/livekit/agents/ipc/inference_executor.py
@@ -0,0 +1,7 @@
from __future__ import annotations

from typing import Protocol


class InferenceExecutor(Protocol):
    async def do_inference(self, method: str, data: bytes) -> bytes | None: ...
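Any component holding an InferenceExecutor can await inference without knowing the call crosses a process boundary. A minimal sketch of a caller, assuming a JSON wire format and a method name ("lk_end_of_utterance") that are not defined in this file:

import json

from livekit.agents.ipc.inference_executor import InferenceExecutor


async def predict_eou(executor: InferenceExecutor, chat_history: list[dict]) -> float:
    # the payload encoding is runner-defined; JSON is only an assumption here
    payload = json.dumps({"chat_ctx": chat_history}).encode()
    result = await executor.do_inference("lk_end_of_utterance", payload)
    if result is None:
        return 0.0
    return float(json.loads(result)["eou_probability"])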
98 changes: 98 additions & 0 deletions livekit-agents/livekit/agents/ipc/inference_proc_executor.py
@@ -0,0 +1,98 @@
from __future__ import annotations

import asyncio
import contextlib
import multiprocessing as mp
import socket
from multiprocessing.context import BaseContext

from ..inference_runner import _RunnersDict
from ..log import logger
from ..utils import aio, log_exceptions, shortuuid
from . import channel, proto
from .inference_proc_lazy_main import ProcStartArgs, proc_main
from .supervised_proc import SupervisedProc


class InferenceProcExecutor(SupervisedProc):
    def __init__(
        self,
        *,
        runners: _RunnersDict,
        initialize_timeout: float,
        close_timeout: float,
        memory_warn_mb: float,
        memory_limit_mb: float,
        ping_interval: float,
        ping_timeout: float,
        high_ping_threshold: float,
        mp_ctx: BaseContext,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        super().__init__(
            initialize_timeout=initialize_timeout,
            close_timeout=close_timeout,
            memory_warn_mb=memory_warn_mb,
            memory_limit_mb=memory_limit_mb,
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            high_ping_threshold=high_ping_threshold,
            mp_ctx=mp_ctx,
            loop=loop,
        )

        self._runners = runners
        self._active_requests: dict[str, asyncio.Future[proto.InferenceResponse]] = {}

    def _create_process(self, cch: socket.socket, log_cch: socket.socket) -> mp.Process:
        proc_args = ProcStartArgs(
            log_cch=log_cch,
            mp_cch=cch,
            runners=self._runners,
        )

        return self._opts.mp_ctx.Process(  # type: ignore
            target=proc_main,
            args=(proc_args,),
            name="inference_proc",
        )

    @log_exceptions(logger=logger)
    async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None:
        async for msg in ipc_ch:
            if isinstance(msg, proto.InferenceResponse):
                fut = self._active_requests.pop(msg.request_id, None)
                if fut is None:
                    logger.warning(
                        "received unexpected inference response",
                        extra={"request_id": msg.request_id},
                    )
                    continue  # a stray response must not kill the IPC loop

                with contextlib.suppress(asyncio.InvalidStateError):
                    fut.set_result(msg)

    async def do_inference(self, method: str, data: bytes) -> bytes | None:
        if not self.started:
            raise RuntimeError("process not started")

        request_id = shortuuid("inference_req_")
        fut = asyncio.Future[proto.InferenceResponse]()

        # register the future before sending, so a fast response can't be dropped
        self._active_requests[request_id] = fut

        await channel.asend_message(
            self._pch,
            proto.InferenceRequest(request_id=request_id, method=method, data=data),
        )

        inf_resp = await fut
        if inf_resp.error:
            raise RuntimeError(f"inference of {method} failed: {inf_resp.error}")

        return inf_resp.data

    def logging_extra(self):
        extra = super().logging_extra()
        extra["inference"] = True
        return extra
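The executor correlates requests and responses with per-request futures: do_inference registers a future under a fresh request_id, the child echoes that id back, and _main_task resolves the matching future. Stripped of the IPC details, the pattern is roughly this (a standalone sketch, not this module's API):

import asyncio
from typing import Awaitable, Callable

_pending: dict[str, asyncio.Future[bytes]] = {}


async def request(request_id: str, send: Callable[[str], Awaitable[None]]) -> bytes:
    fut: asyncio.Future[bytes] = asyncio.get_running_loop().create_future()
    _pending[request_id] = fut  # register before sending so a fast reply isn't lost
    await send(request_id)
    return await fut


def on_response(request_id: str, data: bytes) -> None:
    # called from the read loop; wakes up the caller awaiting the future
    fut = _pending.pop(request_id, None)
    if fut is not None and not fut.done():
        fut.set_result(data)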
108 changes: 108 additions & 0 deletions livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py
@@ -0,0 +1,108 @@
from multiprocessing import current_process

if current_process().name == "inference_proc":
    import signal
    import sys

    # ignore signals in the inference process (the parent process will handle them)
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    signal.signal(signal.SIGTERM, signal.SIG_IGN)

    def _no_traceback_excepthook(exc_type, exc_val, traceback):
        if isinstance(exc_val, KeyboardInterrupt):
            return
        sys.__excepthook__(exc_type, exc_val, traceback)

    sys.excepthook = _no_traceback_excepthook


import asyncio
import socket
from dataclasses import dataclass

from ..inference_runner import _RunnersDict
from ..log import logger
from ..utils import aio, log_exceptions
from . import proto
from .channel import Message
from .proc_client import _ProcClient


@dataclass
class ProcStartArgs:
    log_cch: socket.socket
    mp_cch: socket.socket
    runners: _RunnersDict


def proc_main(args: ProcStartArgs) -> None:
    from .proc_client import _ProcClient

    inf_proc = _InferenceProc(args.runners)

    client = _ProcClient(
        args.mp_cch,
        args.log_cch,
        inf_proc.initialize,
        inf_proc.entrypoint,
    )

    client.initialize_logger()

    pid = current_process().pid
    logger.info("initializing inference process", extra={"pid": pid})
    client.initialize()
    logger.info("inference process initialized", extra={"pid": pid})

    client.run()


class _InferenceProc:
    def __init__(self, runners: _RunnersDict) -> None:
        # create an instance of each runner (the ctor must not require any arguments)
        self._runners = {name: runner() for name, runner in runners.items()}

    def initialize(
        self, init_req: proto.InitializeRequest, client: _ProcClient
    ) -> None:
        self._client = client

        for runner in self._runners.values():
            logger.debug(
                "initializing inference runner",
                extra={"runner": runner.__class__.INFERENCE_METHOD},
            )
            runner.initialize()

    @log_exceptions(logger=logger)
    async def entrypoint(self, cch: aio.ChanReceiver[Message]) -> None:
        async for msg in cch:
            if isinstance(msg, proto.InferenceRequest):
                await self._handle_inference_request(msg)

            if isinstance(msg, proto.ShutdownRequest):
                await self._client.send(proto.Exiting(reason=msg.reason))
                break

    async def _handle_inference_request(self, msg: proto.InferenceRequest) -> None:
        loop = asyncio.get_running_loop()

        if msg.method not in self._runners:
            logger.warning("unknown inference method", extra={"method": msg.method})
            # answer with an error instead of falling through to a KeyError below
            await self._client.send(
                proto.InferenceResponse(
                    request_id=msg.request_id,
                    error=f"unknown inference method: {msg.method}",
                )
            )
            return

        try:
            data = await loop.run_in_executor(
                None, self._runners[msg.method].run, msg.data
            )
            await self._client.send(
                proto.InferenceResponse(
                    request_id=msg.request_id,
                    data=data,
                )
            )

        except Exception as e:
            logger.exception("error running inference")
            await self._client.send(
                proto.InferenceResponse(request_id=msg.request_id, error=str(e))
            )
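The module-level guard at the top is load-bearing: with a spawn context the child re-imports this module to find proc_main, so the current_process().name check runs before anything else in the worker and detaches it from SIGINT/SIGTERM, leaving shutdown to the supervising parent. The same pattern in isolation (hypothetical names, standard library only):

from multiprocessing import current_process

# runs at import time in the child too, because spawn re-imports this module
if current_process().name == "my_worker":
    import signal

    signal.signal(signal.SIGINT, signal.SIG_IGN)  # the parent owns Ctrl-C handling
    signal.signal(signal.SIGTERM, signal.SIG_IGN)  # the parent owns termination


def worker_main() -> None:
    print("worker running as", current_process().name)


if __name__ == "__main__":
    from multiprocessing import get_context

    proc = get_context("spawn").Process(target=worker_main, name="my_worker")
    proc.start()
    proc.join()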