From 0e792d78f8336e48149fcaaab473b5f9ef70a042 Mon Sep 17 00:00:00 2001 From: jeradf Date: Wed, 16 Oct 2024 18:36:39 -0500 Subject: [PATCH 01/24] add endpoint detector --- .../voice-pipeline-agent/minimal_assistant.py | 1 + .../agents/pipeline/endpoint_detector.py | 72 +++++++++++++++++++ .../livekit/agents/pipeline/pipeline_agent.py | 65 +++++++++++++---- 3 files changed, 124 insertions(+), 14 deletions(-) create mode 100644 livekit-agents/livekit/agents/pipeline/endpoint_detector.py diff --git a/examples/voice-pipeline-agent/minimal_assistant.py b/examples/voice-pipeline-agent/minimal_assistant.py index e5ea9b64a..0c9c5464b 100644 --- a/examples/voice-pipeline-agent/minimal_assistant.py +++ b/examples/voice-pipeline-agent/minimal_assistant.py @@ -49,6 +49,7 @@ async def entrypoint(ctx: JobContext): llm=openai.LLM(), tts=openai.TTS(), chat_ctx=initial_ctx, + use_endpoint_detector=True, ) agent.start(ctx.room, participant) diff --git a/livekit-agents/livekit/agents/pipeline/endpoint_detector.py b/livekit-agents/livekit/agents/pipeline/endpoint_detector.py new file mode 100644 index 000000000..5d2e7b56d --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/endpoint_detector.py @@ -0,0 +1,72 @@ +from pprint import pprint +from optimum.onnxruntime import ORTModelForCausalLM +from transformers import AutoTokenizer +import copy +import torch +import string +from .log import logger +import time +PUNCS = string.punctuation.replace("'","") + + +class EndpointDetector: + def __init__(self, model_path='jeradf/opt-125m-eou'): + self.model = ORTModelForCausalLM.from_pretrained(model_path) + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self._eou_index = self.tokenizer.encode("<|im_end|>")[-1] + + def normalize(self, text): + def strip_puncs(text): + return text.translate(str.maketrans('', '', PUNCS)) + return " ".join(strip_puncs(text).lower().split()) + + def apply_chat_template(self, convo): + for msg in convo: + msg['content'] = self.normalize(msg['content']) + + convo_text = self.tokenizer.apply_chat_template( + convo, + add_generation_prompt=False, + add_special_tokens=False, + tokenize=False, + ) + + # remove the EOU token from current utterance + ix = convo_text.rfind('<|im_end|>') + text = convo_text[:ix] + return text + + def tokenize(self, text): + return self.tokenizer( + text, + add_special_tokens=False, + return_tensors='pt', + ) + + def predict(self, utterance, convo=[]): + start_time = time.time() + + convo_copy = copy.deepcopy(convo) + convo_copy.append(dict(role='user', content=utterance)) + + text = self.apply_chat_template(convo_copy) + inputs = self.tokenize(text) + + outputs = self.model(**inputs) + logits = outputs.logits[0, -1, :] + probs = torch.nn.functional.softmax(logits, dim=-1) + result = probs[self._eou_index].item() + + end_time = time.time() + latency = end_time - start_time + + logger.debug( + "EndpointDetector prediction", + extra={ + "probability": round(result, 2), + "utterance": utterance, + "latency": round(latency, 2), + } + ) + return result + diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index ed05b4613..81933feb8 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -152,6 +152,7 @@ def __init__( loop: asyncio.AbstractEventLoop | None = None, # backward compatibility will_synthesize_assistant_reply: WillSynthesizeAssistantReply | None = None, + use_endpoint_detector: bool = False, ) -> None: """ Create a new VoicePipelineAgent. @@ -244,6 +245,7 @@ def __init__( self._validate_reply_if_possible, self._opts.min_endpointing_delay, loop=self._loop, + use_endpoint_detector=use_endpoint_detector, ) self._speech_q: list[SpeechHandle] = [] @@ -893,14 +895,22 @@ class _DeferredReplyValidation: # if the STT gives us punctuation, we can try validate the reply faster. PUNCTUATION = ".!?" PUNCTUATION_REDUCE_FACTOR = 0.75 - + LATE_TRANSCRIPT_TOLERANCE = 1.5 # late compared to end of speech + # When endpoint probability is below this threshold we think the user is not finished speaking + # so we will use a long delay + UNLIKELY_ENDPOINT_THRESHOLD = 0.25 + + # Long delay to use when the model thinks the user is still speaking + UNLIKELY_ENDPOINT_DELAY = 10.0 + def __init__( self, validate_fnc: Callable[[], None], min_endpointing_delay: float, loop: asyncio.AbstractEventLoop | None = None, + use_endpoint_detector: bool = False, ) -> None: self._validate_fnc = validate_fnc self._validating_task: asyncio.Task | None = None @@ -910,17 +920,19 @@ def __init__( self._end_of_speech_delay = min_endpointing_delay self._final_transcript_delay = min_endpointing_delay + 1.0 - + + self._endpoint_detector = None + self._endpoint_probability = None + + if use_endpoint_detector: + from .endpoint_detector import EndpointDetector + self._endpoint_detector = EndpointDetector() + @property def validating(self) -> bool: return self._validating_task is not None and not self._validating_task.done() - - def on_human_final_transcript(self, transcript: str) -> None: - self._last_final_transcript = transcript.strip() # type: ignore - - if self._speaking: - return - + + def get_endpoint_delay(self) -> float: has_recent_end_of_speech = ( time.time() - self._last_recv_end_of_speech_time < self.LATE_TRANSCRIPT_TOLERANCE @@ -935,7 +947,36 @@ def on_human_final_transcript(self, transcript: str) -> None: if self._end_with_punctuation() else 1.0 ) + + if self._endpoint_probability is not None: + if self._endpoint_probability < self.UNLIKELY_ENDPOINT_THRESHOLD: + new_delay = self.UNLIKELY_ENDPOINT_DELAY + + logger.debug( + "Unlikely endpoint detected", + extra={ + "new_delay": new_delay, + "old_delay": delay, + "endpoint_prob": round(self._endpoint_probability, 2), + "transcript": self._last_final_transcript, + } + ) + delay = new_delay + return delay + + def on_human_final_transcript(self, transcript: str) -> None: + self._last_final_transcript = transcript.strip() # type: ignore + + if self._endpoint_detector: + self._endpoint_probability = self._endpoint_detector.predict( + utterance=self._last_final_transcript, + convo=[] + ) + if self._speaking: + return + + delay = self.get_endpoint_delay() self._run(delay) def on_human_start_of_speech(self, ev: vad.VADEvent) -> None: @@ -949,11 +990,7 @@ def on_human_end_of_speech(self, ev: vad.VADEvent) -> None: self._last_recv_end_of_speech_time = time.time() if self._last_final_transcript: - delay = ( - self._end_of_speech_delay * self.PUNCTUATION_REDUCE_FACTOR - if self._end_with_punctuation() - else 1.0 - ) + delay = self.get_endpoint_delay() self._run(delay) async def aclose(self) -> None: From bc0b2863d82c51ab988b7bde2c2f3a628b043bac Mon Sep 17 00:00:00 2001 From: jeradf Date: Wed, 16 Oct 2024 19:10:10 -0500 Subject: [PATCH 02/24] sort imports --- .../livekit/agents/pipeline/endpoint_detector.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/livekit-agents/livekit/agents/pipeline/endpoint_detector.py b/livekit-agents/livekit/agents/pipeline/endpoint_detector.py index 5d2e7b56d..1a003cb9f 100644 --- a/livekit-agents/livekit/agents/pipeline/endpoint_detector.py +++ b/livekit-agents/livekit/agents/pipeline/endpoint_detector.py @@ -1,11 +1,13 @@ -from pprint import pprint -from optimum.onnxruntime import ORTModelForCausalLM -from transformers import AutoTokenizer import copy -import torch import string -from .log import logger import time + +import torch +from optimum.onnxruntime import ORTModelForCausalLM +from transformers import AutoTokenizer + +from .log import logger + PUNCS = string.punctuation.replace("'","") From 6ae0937e611a62054411643a408b0dfae51ac9b6 Mon Sep 17 00:00:00 2001 From: jeradf Date: Wed, 16 Oct 2024 20:12:06 -0500 Subject: [PATCH 03/24] pass convo history to model prediction --- .../livekit/agents/pipeline/endpoint_detector.py | 5 ++++- .../livekit/agents/pipeline/pipeline_agent.py | 12 +++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/livekit-agents/livekit/agents/pipeline/endpoint_detector.py b/livekit-agents/livekit/agents/pipeline/endpoint_detector.py index 1a003cb9f..f5ace7aa6 100644 --- a/livekit-agents/livekit/agents/pipeline/endpoint_detector.py +++ b/livekit-agents/livekit/agents/pipeline/endpoint_detector.py @@ -9,7 +9,7 @@ from .log import logger PUNCS = string.punctuation.replace("'","") - +MAX_HISTORY = 3 class EndpointDetector: def __init__(self, model_path='jeradf/opt-125m-eou'): @@ -24,6 +24,8 @@ def strip_puncs(text): def apply_chat_template(self, convo): for msg in convo: + if msg['role'] not in ['user', 'assistant']: + continue msg['content'] = self.normalize(msg['content']) convo_text = self.tokenizer.apply_chat_template( @@ -50,6 +52,7 @@ def predict(self, utterance, convo=[]): convo_copy = copy.deepcopy(convo) convo_copy.append(dict(role='user', content=utterance)) + convo_copy = convo_copy[-MAX_HISTORY:] text = self.apply_chat_template(convo_copy) inputs = self.tokenize(text) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 81933feb8..75ea75879 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -244,6 +244,7 @@ def __init__( self._deferred_validation = _DeferredReplyValidation( self._validate_reply_if_possible, self._opts.min_endpointing_delay, + self._chat_ctx, # Pass the ChatContext here loop=self._loop, use_endpoint_detector=use_endpoint_detector, ) @@ -900,15 +901,16 @@ class _DeferredReplyValidation: # When endpoint probability is below this threshold we think the user is not finished speaking # so we will use a long delay - UNLIKELY_ENDPOINT_THRESHOLD = 0.25 + UNLIKELY_ENDPOINT_THRESHOLD = 0.15 # Long delay to use when the model thinks the user is still speaking - UNLIKELY_ENDPOINT_DELAY = 10.0 + UNLIKELY_ENDPOINT_DELAY = 5.0 def __init__( self, validate_fnc: Callable[[], None], min_endpointing_delay: float, + chat_ctx: ChatContext, # Add this parameter loop: asyncio.AbstractEventLoop | None = None, use_endpoint_detector: bool = False, ) -> None: @@ -928,6 +930,8 @@ def __init__( from .endpoint_detector import EndpointDetector self._endpoint_detector = EndpointDetector() + self._chat_ctx = chat_ctx # Store the ChatContext + @property def validating(self) -> bool: return self._validating_task is not None and not self._validating_task.done() @@ -968,9 +972,11 @@ def on_human_final_transcript(self, transcript: str) -> None: self._last_final_transcript = transcript.strip() # type: ignore if self._endpoint_detector: + convo = [dict(role=msg.role, content=msg.content) + for msg in self._chat_ctx.messages] self._endpoint_probability = self._endpoint_detector.predict( utterance=self._last_final_transcript, - convo=[] + convo=convo ) if self._speaking: From 0212d6a80034935aa6ef6aac359402496bf20019 Mon Sep 17 00:00:00 2001 From: jeradf Date: Wed, 16 Oct 2024 20:14:50 -0500 Subject: [PATCH 04/24] formatz --- livekit-agents/livekit/agents/pipeline/pipeline_agent.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 75ea75879..6e78f7be4 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -243,9 +243,9 @@ def __init__( self._deferred_validation = _DeferredReplyValidation( self._validate_reply_if_possible, - self._opts.min_endpointing_delay, - self._chat_ctx, # Pass the ChatContext here + self._opts.min_endpointing_delay, loop=self._loop, + chat_ctx=self._chat_ctx, use_endpoint_detector=use_endpoint_detector, ) @@ -910,7 +910,7 @@ def __init__( self, validate_fnc: Callable[[], None], min_endpointing_delay: float, - chat_ctx: ChatContext, # Add this parameter + chat_ctx: ChatContext, loop: asyncio.AbstractEventLoop | None = None, use_endpoint_detector: bool = False, ) -> None: @@ -930,7 +930,7 @@ def __init__( from .endpoint_detector import EndpointDetector self._endpoint_detector = EndpointDetector() - self._chat_ctx = chat_ctx # Store the ChatContext + self._chat_ctx = chat_ctx @property def validating(self) -> bool: From ba9afbc77dfcfd83fac5fde5102bbd2225a4373b Mon Sep 17 00:00:00 2001 From: jeradf Date: Wed, 16 Oct 2024 20:31:31 -0500 Subject: [PATCH 05/24] no need to deepcopy convo --- .../livekit/agents/pipeline/endpoint_detector.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/livekit-agents/livekit/agents/pipeline/endpoint_detector.py b/livekit-agents/livekit/agents/pipeline/endpoint_detector.py index f5ace7aa6..831293cd9 100644 --- a/livekit-agents/livekit/agents/pipeline/endpoint_detector.py +++ b/livekit-agents/livekit/agents/pipeline/endpoint_detector.py @@ -50,11 +50,10 @@ def tokenize(self, text): def predict(self, utterance, convo=[]): start_time = time.time() - convo_copy = copy.deepcopy(convo) - convo_copy.append(dict(role='user', content=utterance)) - convo_copy = convo_copy[-MAX_HISTORY:] + convo.append(dict(role='user', content=utterance)) + convo = convo[-MAX_HISTORY:] - text = self.apply_chat_template(convo_copy) + text = self.apply_chat_template(convo) inputs = self.tokenize(text) outputs = self.model(**inputs) @@ -66,7 +65,7 @@ def predict(self, utterance, convo=[]): latency = end_time - start_time logger.debug( - "EndpointDetector prediction", + "EndpointDetector inference", extra={ "probability": round(result, 2), "utterance": utterance, From 830f4fc012f9d4ca20ba0923ffbed64fa0abea7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Mon, 25 Nov 2024 15:29:33 +0100 Subject: [PATCH 06/24] wip --- .../livekit/agents/ipc/inference_main.py | 76 +++++ .../agents/ipc/inference_proc_executor.py | 306 ++++++++++++++++++ .../agents/ipc/inference_proc_lazy_main.py | 73 +++++ livekit-agents/livekit/agents/ipc/job_main.py | 67 +--- ...c_job_executor.py => job_proc_executor.py} | 44 +-- ...roc_lazy_main.py => job_proc_lazy_main.py} | 32 +- ...job_executor.py => job_thread_executor.py} | 0 .../livekit/agents/ipc/log_queue.py | 110 +++++++ 8 files changed, 587 insertions(+), 121 deletions(-) create mode 100644 livekit-agents/livekit/agents/ipc/inference_main.py create mode 100644 livekit-agents/livekit/agents/ipc/inference_proc_executor.py create mode 100644 livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py rename livekit-agents/livekit/agents/ipc/{proc_job_executor.py => job_proc_executor.py} (91%) rename livekit-agents/livekit/agents/ipc/{proc_lazy_main.py => job_proc_lazy_main.py} (68%) rename livekit-agents/livekit/agents/ipc/{thread_job_executor.py => job_thread_executor.py} (100%) create mode 100644 livekit-agents/livekit/agents/ipc/log_queue.py diff --git a/livekit-agents/livekit/agents/ipc/inference_main.py b/livekit-agents/livekit/agents/ipc/inference_main.py new file mode 100644 index 000000000..1edffd5f9 --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/inference_main.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import asyncio +import contextlib +import socket +import threading +from dataclasses import dataclass +from typing import Any, Callable + +from livekit import rtc + +from .. import utils +from ..job import JobContext, JobProcess +from ..log import logger +from ..utils.aio import duplex_unix +from . import channel, proto + + +async def _async_main( + entrypoint_fnc: Callable, + mp_cch: socket.socket, +) -> None: + cch = await duplex_unix._AsyncDuplex.open(mp_cch) + + exit_flag = asyncio.Event() + no_msg_timeout = utils.aio.sleep(proto.PING_INTERVAL * 5) # missing 5 pings + + @utils.log_exceptions(logger=logger) + async def _read_ipc_task(): + while True: + try: + msg = await channel.arecv_message(cch, proto.IPC_MESSAGES) + except duplex_unix.DuplexClosed: + break + + with contextlib.suppress(utils.aio.SleepFinished): + no_msg_timeout.reset() + + if isinstance(msg, proto.PingRequest): + pong = proto.PongResponse( + last_timestamp=msg.timestamp, timestamp=utils.time_ms() + ) + await channel.asend_message(cch, pong) + + if isinstance(msg, proto.ShutdownRequest): + pass + + async def _self_health_check(): + await no_msg_timeout + print("worker process is not responding.. worker crashed?") + with contextlib.suppress(asyncio.CancelledError): + exit_flag.set() + + read_task = asyncio.create_task(_read_ipc_task(), name="ipc_read") + health_check_task = asyncio.create_task(_self_health_check(), name="health_check") + + def _done_cb(_: asyncio.Task) -> None: + with contextlib.suppress(asyncio.InvalidStateError): + exit_flag.set() + + read_task.add_done_callback(_done_cb) + + await exit_flag.wait() + await utils.aio.gracefully_cancel(read_task, health_check_task) + + with contextlib.suppress(duplex_unix.DuplexClosed): + await cch.aclose() + + +@dataclass +class ProcStartArgs: + initialize_process_fnc: Callable + entrypoint_fnc: Callable + log_cch: socket.socket + mp_cch: socket.socket + asyncio_debug: bool diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py new file mode 100644 index 000000000..429353286 --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import asyncio +import contextlib +import logging +import socket +import sys +import threading +from dataclasses import dataclass +from multiprocessing.context import BaseContext +from typing import Any, Callable + +from .. import utils +from ..log import logger +from ..utils.aio import duplex_unix +from . import channel, inference_proc_lazy_main, proto, inference_main +from .log_queue import LogQueueListener + + +@dataclass +class _ProcOpts: + initialize_process_fnc: Callable + entrypoint_fnc: Callable + mp_ctx: BaseContext + initialize_timeout: float + close_timeout: float + + +class ProcJobExecutor: + def __init__( + self, + *, + initialize_process_fnc: Callable, + entrypoint_fnc: Callable, + initialize_timeout: float, + close_timeout: float, + mp_ctx: BaseContext, + loop: asyncio.AbstractEventLoop, + ) -> None: + self._loop = loop + self._opts = _ProcOpts( + initialize_process_fnc=initialize_process_fnc, + entrypoint_fnc=entrypoint_fnc, + initialize_timeout=initialize_timeout, + close_timeout=close_timeout, + mp_ctx=mp_ctx, + ) + + self._exitcode: int | None = None + self._pid: int | None = None + + self._main_atask: asyncio.Task[None] | None = None + self._closing = False + self._kill_sent = False + self._initialize_fut = asyncio.Future[None]() + + self._lock = asyncio.Lock() + + @property + def exitcode(self) -> int | None: + return self._exitcode + + @property + def killed(self) -> bool: + return self._kill_sent + + @property + def pid(self) -> int | None: + return self._pid + + @property + def started(self) -> bool: + return self._main_atask is not None + + async def start(self) -> None: + """start the job process""" + if self.started: + raise RuntimeError("process already started") + + if self._closing: + raise RuntimeError("process is closed") + + await asyncio.shield(self._start()) + + async def _start(self) -> None: + def _add_proc_ctx_log(record: logging.LogRecord) -> None: + extra = self.logging_extra() + for key, value in extra.items(): + setattr(record, key, value) + + async with self._lock: + self._pong_timeout = utils.aio.sleep(proto.PING_TIMEOUT) + + mp_pch, mp_cch = socket.socketpair() + mp_log_pch, mp_log_cch = socket.socketpair() + + self._pch = await duplex_unix._AsyncDuplex.open(mp_pch) + + log_pch = duplex_unix._Duplex.open(mp_log_pch) + log_listener = LogQueueListener(log_pch, _add_proc_ctx_log) + log_listener.start() + + self._proc_args = inference_main.ProcStartArgs( + initialize_process_fnc=self._opts.initialize_process_fnc, + entrypoint_fnc=self._opts.entrypoint_fnc, + log_cch=mp_log_cch, + mp_cch=mp_cch, + asyncio_debug=self._loop.get_debug(), + ) + + self._proc = self._opts.mp_ctx.Process( # type: ignore + target=inference_proc_lazy_main.proc_main, + args=(self._proc_args,), + name="inference_proc", + ) + + self._proc.start() + mp_log_cch.close() + mp_cch.close() + + self._pid = self._proc.pid + self._join_fut = asyncio.Future[None]() + + def _sync_run(): + self._proc.join() + log_listener.stop() + try: + self._loop.call_soon_threadsafe(self._join_fut.set_result, None) + except RuntimeError: + pass + + thread = threading.Thread(target=_sync_run, name="proc_join_thread") + thread.start() + self._main_atask = asyncio.create_task(self._main_task()) + + async def join(self) -> None: + if not self.started: + raise RuntimeError("process not started") + + async with self._lock: + if self._main_atask: + await asyncio.shield(self._main_atask) + + async def initialize(self) -> None: + await channel.asend_message(self._pch, proto.InitializeRequest()) + + # wait for the process to become ready + try: + init_res = await asyncio.wait_for( + channel.arecv_message(self._pch, proto.IPC_MESSAGES), + timeout=self._opts.initialize_timeout, + ) + assert isinstance( + init_res, proto.InitializeResponse + ), "first message must be InitializeResponse" + except asyncio.TimeoutError: + self._initialize_fut.set_exception( + asyncio.TimeoutError("process initialization timed out") + ) + logger.error( + "initialization timed out, killing job", extra=self.logging_extra() + ) + self._send_kill_signal() + raise + except Exception as e: # should be channel.ChannelClosed most of the time + self._initialize_fut.set_exception(e) + raise + else: + self._initialize_fut.set_result(None) + + async def aclose(self) -> None: + """attempt to gracefully close the job process""" + if not self.started: + return + + self._closing = True + with contextlib.suppress(utils.aio.duplex_unix.DuplexClosed): + await channel.asend_message(self._pch, proto.ShutdownRequest()) + + try: + if self._main_atask: + await asyncio.wait_for( + asyncio.shield(self._main_atask), timeout=self._opts.close_timeout + ) + except asyncio.TimeoutError: + logger.error( + "process did not exit in time, killing job", extra=self.logging_extra() + ) + self._send_kill_signal() + + async with self._lock: + if self._main_atask: + await asyncio.shield(self._main_atask) + + async def kill(self) -> None: + """forcefully kill the job process""" + if not self.started: + raise RuntimeError("process not started") + + self._closing = True + self._send_kill_signal() + + async with self._lock: + if self._main_atask: + await asyncio.shield(self._main_atask) + + def _send_kill_signal(self) -> None: + """forcefully kill the job process""" + try: + if not self._proc.is_alive(): + return + except ValueError: + return + + logger.info("killing job process", extra=self.logging_extra()) + if sys.platform == "win32": + self._proc.terminate() + else: + self._proc.kill() + + self._kill_sent = True + + @utils.log_exceptions(logger=logger) + async def _main_task(self) -> None: + try: + await self._initialize_fut + except asyncio.TimeoutError: + pass # this happens when the initialization takes longer than self._initialize_timeout + except Exception: + pass # initialization failed + + # the process is killed if it doesn't respond to ping requests + ping_task = asyncio.create_task(self._ping_pong_task()) + monitor_task = asyncio.create_task(self._monitor_task()) + + await self._join_fut + self._exitcode = self._proc.exitcode + self._proc.close() + await utils.aio.gracefully_cancel(ping_task, monitor_task) + + with contextlib.suppress(duplex_unix.DuplexClosed): + await self._pch.aclose() + + if self._exitcode != 0 and not self._kill_sent: + logger.error( + f"job process exited with non-zero exit code {self.exitcode}", + extra=self.logging_extra(), + ) + + @utils.log_exceptions(logger=logger) + async def _monitor_task(self) -> None: + while True: + try: + msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES) + except utils.aio.duplex_unix.DuplexClosed: + break + + if isinstance(msg, proto.PongResponse): + delay = utils.time_ms() - msg.timestamp + if delay > proto.HIGH_PING_THRESHOLD * 1000: + logger.warning( + "inference process is unresponsive", + extra={"delay": delay, **self.logging_extra()}, + ) + + with contextlib.suppress(utils.aio.SleepFinished): + self._pong_timeout.reset() + + @utils.log_exceptions(logger=logger) + async def _ping_pong_task(self) -> None: + ping_interval = utils.aio.interval(proto.PING_INTERVAL) + + async def _send_ping_co(): + while True: + await ping_interval.tick() + try: + await channel.asend_message( + self._pch, proto.PingRequest(timestamp=utils.time_ms()) + ) + except utils.aio.duplex_unix.DuplexClosed: + break + + async def _pong_timeout_co(): + while True: + await self._pong_timeout + logger.error( + "inference process is unresponsive, killing proc", + extra=self.logging_extra(), + ) + self._pong_timeout = utils.aio.sleep(proto.PING_TIMEOUT) + + tasks = [ + asyncio.create_task(_send_ping_co()), + asyncio.create_task(_pong_timeout_co()), + ] + try: + await asyncio.gather(*tasks) + finally: + await utils.aio.gracefully_cancel(*tasks) + + def logging_extra(self): + extra: dict[str, Any] = { + "pid": self.pid, + "inference_process": True, + } + return extra diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py new file mode 100644 index 000000000..e4fc7626f --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py @@ -0,0 +1,73 @@ +from multiprocessing import current_process + +if current_process().name == "inference_proc": + import signal + import sys + + # ignore signals in the inference process (the parent process will handle them) + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_IGN) + + def _no_traceback_excepthook(exc_type, exc_val, traceback): + if isinstance(exc_val, KeyboardInterrupt): + return + sys.__excepthook__(exc_type, exc_val, traceback) + + sys.excepthook = _no_traceback_excepthook + + +def proc_main(args) -> None: + # import every package lazily + import asyncio + import logging + + from ..log import logger + from ..utils import aio + from .channel import recv_message, send_message + from .log_queue import LogQueueHandler + from .proto import IPC_MESSAGES, InitializeRequest, InitializeResponse + + root_logger = logging.getLogger() + root_logger.setLevel(logging.NOTSET) + + log_cch = aio.duplex_unix._Duplex.open(args.log_cch) + log_handler = LogQueueHandler(log_cch) + root_logger.addHandler(log_handler) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.set_debug(args.asyncio_debug) + loop.slow_callback_duration = 0.1 # 100ms + aio.debug.hook_slow_callbacks(2.0) + + cch = aio.duplex_unix._Duplex.open(args.mp_cch) + try: + init_req = recv_message(cch, IPC_MESSAGES) + + assert isinstance( + init_req, InitializeRequest + ), "first message must be InitializeRequest" + + pid = current_process().pid + logger.info("initializing process", extra={"pid": pid}) + args.initialize_process_fnc() + logger.info("process initialized", extra={"pid": pid}) + send_message(cch, InitializeResponse()) + + from .inference_main import _async_main + + main_task = loop.create_task( + _async_main(args.entrypoint_fnc, cch.detach()), + name="inference_proc_main", + ) + while not main_task.done(): + try: + loop.run_until_complete(main_task) + except KeyboardInterrupt: + # ignore the keyboard interrupt, we handle the process shutdown ourselves on the worker process + pass + except (aio.duplex_unix.DuplexClosed, KeyboardInterrupt): + pass + finally: + log_handler.close() + loop.run_until_complete(loop.shutdown_default_executor()) diff --git a/livekit-agents/livekit/agents/ipc/job_main.py b/livekit-agents/livekit/agents/ipc/job_main.py index 4e7519400..dc75155b1 100644 --- a/livekit-agents/livekit/agents/ipc/job_main.py +++ b/livekit-agents/livekit/agents/ipc/job_main.py @@ -2,15 +2,10 @@ import asyncio import contextlib -import copy -import logging -import pickle -import queue import socket -import sys import threading from dataclasses import dataclass -from typing import Any, Callable, Optional +from typing import Any, Callable from livekit import rtc @@ -21,62 +16,6 @@ from . import channel, proto -class LogQueueHandler(logging.Handler): - _sentinal = None - - def __init__(self, duplex: utils.aio.duplex_unix._Duplex) -> None: - super().__init__() - self._duplex = duplex - self._send_q = queue.SimpleQueue[Optional[bytes]]() - self._send_thread = threading.Thread( - target=self._forward_logs, name="ipc_log_forwarder" - ) - self._send_thread.start() - - def _forward_logs(self): - while True: - serialized_record = self._send_q.get() - if serialized_record is None: - break - - try: - self._duplex.send_bytes(serialized_record) - except duplex_unix.DuplexClosed: - break - - self._duplex.close() - - def emit(self, record: logging.LogRecord) -> None: - try: - # Check if Python is shutting down - if sys.is_finalizing(): - return - - # from https://github.com/python/cpython/blob/91b7f2e7f6593acefda4fa860250dd87d6f849bf/Lib/logging/handlers.py#L1453 - msg = self.format(record) - record = copy.copy(record) - record.message = msg - record.msg = msg - record.args = None - record.exc_info = None - record.exc_text = None - record.stack_info = None - - # https://websockets.readthedocs.io/en/stable/topics/logging.html#logging-to-json - # webosckets library add "websocket" attribute to log records, which is not pickleable - if hasattr(record, "websocket"): - record.websocket = None - - self._send_q.put_nowait(pickle.dumps(record)) - - except Exception: - self.handleError(record) - - def close(self) -> None: - super().close() - self._send_q.put_nowait(self._sentinal) - - @dataclass class _ShutdownInfo: user_initiated: bool @@ -142,7 +81,7 @@ async def _run_job_task() -> None: async def _warn_not_connected_task(): await asyncio.sleep(10) if not ctx_connect and not ctx_shutdown: - logger.warn( + logger.warning( ( "room not connected after job_entry was called after 10 seconds, " "did you forget to call job_ctx.connect()?" @@ -159,7 +98,7 @@ def log_exception(t: asyncio.Task) -> None: exc_info=t.exception(), ) elif not ctx_connect and not ctx_shutdown: - logger.warn("job task completed without connecting or shutting down") + logger.warning("job task completed without connecting or shutting down") job_entry_task.add_done_callback(log_exception) diff --git a/livekit-agents/livekit/agents/ipc/proc_job_executor.py b/livekit-agents/livekit/agents/ipc/job_proc_executor.py similarity index 91% rename from livekit-agents/livekit/agents/ipc/proc_job_executor.py rename to livekit-agents/livekit/agents/ipc/job_proc_executor.py index 2a956d947..26023924f 100644 --- a/livekit-agents/livekit/agents/ipc/proc_job_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_executor.py @@ -3,7 +3,6 @@ import asyncio import contextlib import logging -import pickle import socket import sys import threading @@ -22,48 +21,7 @@ JobExecutorError_Unresponsive, RunStatus, ) - - -class LogQueueListener: - def __init__( - self, - duplex: utils.aio.duplex_unix._Duplex, - prepare_fnc: Callable[[logging.LogRecord], None], - ): - self._thread: threading.Thread | None = None - self._duplex = duplex - self._prepare_fnc = prepare_fnc - - def start(self) -> None: - self._thread = threading.Thread(target=self._monitor, name="ipc_log_listener") - self._thread.start() - - def stop(self) -> None: - if self._thread is None: - return - - self._duplex.close() - self._thread.join() - self._thread = None - - def handle(self, record: logging.LogRecord) -> None: - self._prepare_fnc(record) - - lger = logging.getLogger(record.name) - if not lger.isEnabledFor(record.levelno): - return - - lger.callHandlers(record) - - def _monitor(self): - while True: - try: - data = self._duplex.recv_bytes() - except utils.aio.duplex_unix.DuplexClosed: - break - - record = pickle.loads(data) - self.handle(record) +from .log_queue_listener import LogQueueListener @dataclass diff --git a/livekit-agents/livekit/agents/ipc/proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py similarity index 68% rename from livekit-agents/livekit/agents/ipc/proc_lazy_main.py rename to livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py index be09e7f5a..6408f6c98 100644 --- a/livekit-agents/livekit/agents/ipc/proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py @@ -1,6 +1,6 @@ -import multiprocessing +from multiprocessing import current_process -if multiprocessing.current_process().name == "job_proc": +if current_process().name == "job_proc": import signal import sys @@ -23,41 +23,45 @@ def proc_main(args) -> None: import asyncio import logging - from .. import utils from ..job import JobProcess from ..log import logger - from . import channel, job_main, proto + from ..utils import aio + from .channel import recv_message, send_message + from .log_queue import LogQueueHandler + from .proto import IPC_MESSAGES, InitializeRequest, InitializeResponse root_logger = logging.getLogger() root_logger.setLevel(logging.NOTSET) - log_cch = utils.aio.duplex_unix._Duplex.open(args.log_cch) - log_handler = job_main.LogQueueHandler(log_cch) + log_cch = aio.duplex_unix._Duplex.open(args.log_cch) + log_handler = LogQueueHandler(log_cch) root_logger.addHandler(log_handler) loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.set_debug(args.asyncio_debug) loop.slow_callback_duration = 0.1 # 100ms - utils.aio.debug.hook_slow_callbacks(2.0) + aio.debug.hook_slow_callbacks(2.0) - cch = utils.aio.duplex_unix._Duplex.open(args.mp_cch) + cch = aio.duplex_unix._Duplex.open(args.mp_cch) try: - init_req = channel.recv_message(cch, proto.IPC_MESSAGES) + init_req = recv_message(cch, IPC_MESSAGES) assert isinstance( - init_req, proto.InitializeRequest + init_req, InitializeRequest ), "first message must be InitializeRequest" job_proc = JobProcess(start_arguments=args.user_arguments) logger.info("initializing process", extra={"pid": job_proc.pid}) args.initialize_process_fnc(job_proc) logger.info("process initialized", extra={"pid": job_proc.pid}) - channel.send_message(cch, proto.InitializeResponse()) + send_message(cch, InitializeResponse()) + + from .job_main import _async_main main_task = loop.create_task( - job_main._async_main(job_proc, args.job_entrypoint_fnc, cch.detach()), - name="job_proc_main", + _async_main(job_proc, args.job_entrypoint_fnc, cch.detach()), + name="inference_proc_main", ) while not main_task.done(): try: @@ -65,7 +69,7 @@ def proc_main(args) -> None: except KeyboardInterrupt: # ignore the keyboard interrupt, we handle the process shutdown ourselves on the worker process pass - except (utils.aio.duplex_unix.DuplexClosed, KeyboardInterrupt): + except (aio.duplex_unix.DuplexClosed, KeyboardInterrupt): pass finally: log_handler.close() diff --git a/livekit-agents/livekit/agents/ipc/thread_job_executor.py b/livekit-agents/livekit/agents/ipc/job_thread_executor.py similarity index 100% rename from livekit-agents/livekit/agents/ipc/thread_job_executor.py rename to livekit-agents/livekit/agents/ipc/job_thread_executor.py diff --git a/livekit-agents/livekit/agents/ipc/log_queue.py b/livekit-agents/livekit/agents/ipc/log_queue.py new file mode 100644 index 000000000..38115cff1 --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/log_queue.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import copy +import logging +import pickle +import queue +import sys +import threading +from typing import Callable, Optional + +from .. import utils +from ..utils.aio import duplex_unix + + +class LogQueueListener: + def __init__( + self, + duplex: utils.aio.duplex_unix._Duplex, + prepare_fnc: Callable[[logging.LogRecord], None], + ): + self._thread: threading.Thread | None = None + self._duplex = duplex + self._prepare_fnc = prepare_fnc + + def start(self) -> None: + self._thread = threading.Thread(target=self._monitor, name="ipc_log_listener") + self._thread.start() + + def stop(self) -> None: + if self._thread is None: + return + + self._duplex.close() + self._thread.join() + self._thread = None + + def handle(self, record: logging.LogRecord) -> None: + self._prepare_fnc(record) + + lger = logging.getLogger(record.name) + if not lger.isEnabledFor(record.levelno): + return + + lger.callHandlers(record) + + def _monitor(self): + while True: + try: + data = self._duplex.recv_bytes() + except utils.aio.duplex_unix.DuplexClosed: + break + + record = pickle.loads(data) + self.handle(record) + + +class LogQueueHandler(logging.Handler): + _sentinal = None + + def __init__(self, duplex: utils.aio.duplex_unix._Duplex) -> None: + super().__init__() + self._duplex = duplex + self._send_q = queue.SimpleQueue[Optional[bytes]]() + self._send_thread = threading.Thread( + target=self._forward_logs, name="ipc_log_forwarder" + ) + self._send_thread.start() + + def _forward_logs(self): + while True: + serialized_record = self._send_q.get() + if serialized_record is None: + break + + try: + self._duplex.send_bytes(serialized_record) + except duplex_unix.DuplexClosed: + break + + self._duplex.close() + + def emit(self, record: logging.LogRecord) -> None: + try: + # Check if Python is shutting down + if sys.is_finalizing(): + return + + # from https://github.com/python/cpython/blob/91b7f2e7f6593acefda4fa860250dd87d6f849bf/Lib/logging/handlers.py#L1453 + msg = self.format(record) + record = copy.copy(record) + record.message = msg + record.msg = msg + record.args = None + record.exc_info = None + record.exc_text = None + record.stack_info = None + + # https://websockets.readthedocs.io/en/stable/topics/logging.html#logging-to-json + # webosckets library add "websocket" attribute to log records, which is not pickleable + if hasattr(record, "websocket"): + record.websocket = None + + self._send_q.put_nowait(pickle.dumps(record)) + + except Exception: + self.handleError(record) + + def close(self) -> None: + super().close() + self._send_q.put_nowait(self._sentinal) From 940eefe2ee954d4339ac95b04f519b447ba572d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 26 Nov 2024 00:21:48 +0100 Subject: [PATCH 07/24] wip --- .../livekit/agents/inference_runner.py | 44 ++++++++++ .../livekit/agents/ipc/inference_main.py | 3 - .../agents/ipc/inference_proc_executor.py | 8 -- livekit-agents/livekit/agents/ipc/proto.py | 42 ++++++++++ .../agents/pipeline/endpoint_detector.py | 27 +++--- livekit-agents/livekit/agents/plugin.py | 1 - .../livekit/agents/utils/inference_runner.py | 0 .../livekit-plugins-eou/CHANGELOG.md | 73 ++++++++++++++++ livekit-plugins/livekit-plugins-eou/README.md | 14 ++++ .../livekit/plugins/eou/__init__.py | 34 ++++++++ .../livekit/plugins/eou/eou.py | 84 +++++++++++++++++++ .../livekit/plugins/eou/log.py | 3 + .../livekit/plugins/eou/version.py | 15 ++++ .../livekit-plugins-eou/package.json | 5 ++ .../livekit-plugins-eou/pyproject.toml | 3 + livekit-plugins/livekit-plugins-eou/setup.py | 57 +++++++++++++ 16 files changed, 388 insertions(+), 25 deletions(-) create mode 100644 livekit-agents/livekit/agents/inference_runner.py create mode 100644 livekit-agents/livekit/agents/utils/inference_runner.py create mode 100644 livekit-plugins/livekit-plugins-eou/CHANGELOG.md create mode 100644 livekit-plugins/livekit-plugins-eou/README.md create mode 100644 livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/__init__.py create mode 100644 livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py create mode 100644 livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/log.py create mode 100644 livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/version.py create mode 100644 livekit-plugins/livekit-plugins-eou/package.json create mode 100644 livekit-plugins/livekit-plugins-eou/pyproject.toml create mode 100644 livekit-plugins/livekit-plugins-eou/setup.py diff --git a/livekit-agents/livekit/agents/inference_runner.py b/livekit-agents/livekit/agents/inference_runner.py new file mode 100644 index 000000000..c8119ca48 --- /dev/null +++ b/livekit-agents/livekit/agents/inference_runner.py @@ -0,0 +1,44 @@ +from __future__ import annotations + + +from abc import ABCMeta, abstractmethod + + +from typing import Type + +import threading + +from abc import ABC, abstractmethod + + +class _RunnerMeta(ABCMeta): + @property + @abstractmethod + def METHOD(cls) -> str: ... + + +# kept private until we stabilize the API (only used for EOU today) +class _InferenceRunner(ABC, metaclass=_RunnerMeta): + registered_runners: dict[str, Type["_InferenceRunner"]] = {} + + @classmethod + def register_runner(cls, runner_class: Type["_InferenceRunner"]) -> None: + if threading.current_thread() != threading.main_thread(): + raise RuntimeError("InferenceRunner must be registered on the main thread") + + if runner_class.METHOD in cls.registered_runners: + raise ValueError( + f"InferenceRunner {runner_class.METHOD} already registered" + ) + + cls.registered_runners[runner_class.METHOD] = runner_class + + @abstractmethod + def initialize(self) -> None: + """Initialize the runner. This is used to load models, etc.""" + ... + + @abstractmethod + def run(self, data: bytes) -> bytes | None: + """Run inference on the given data.""" + ... diff --git a/livekit-agents/livekit/agents/ipc/inference_main.py b/livekit-agents/livekit/agents/ipc/inference_main.py index 1edffd5f9..ca7f6f9c4 100644 --- a/livekit-agents/livekit/agents/ipc/inference_main.py +++ b/livekit-agents/livekit/agents/ipc/inference_main.py @@ -17,7 +17,6 @@ async def _async_main( - entrypoint_fnc: Callable, mp_cch: socket.socket, ) -> None: cch = await duplex_unix._AsyncDuplex.open(mp_cch) @@ -69,8 +68,6 @@ def _done_cb(_: asyncio.Task) -> None: @dataclass class ProcStartArgs: - initialize_process_fnc: Callable - entrypoint_fnc: Callable log_cch: socket.socket mp_cch: socket.socket asyncio_debug: bool diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py index 429353286..ae442bc57 100644 --- a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py @@ -19,8 +19,6 @@ @dataclass class _ProcOpts: - initialize_process_fnc: Callable - entrypoint_fnc: Callable mp_ctx: BaseContext initialize_timeout: float close_timeout: float @@ -30,8 +28,6 @@ class ProcJobExecutor: def __init__( self, *, - initialize_process_fnc: Callable, - entrypoint_fnc: Callable, initialize_timeout: float, close_timeout: float, mp_ctx: BaseContext, @@ -39,8 +35,6 @@ def __init__( ) -> None: self._loop = loop self._opts = _ProcOpts( - initialize_process_fnc=initialize_process_fnc, - entrypoint_fnc=entrypoint_fnc, initialize_timeout=initialize_timeout, close_timeout=close_timeout, mp_ctx=mp_ctx, @@ -101,8 +95,6 @@ def _add_proc_ctx_log(record: logging.LogRecord) -> None: log_listener.start() self._proc_args = inference_main.ProcStartArgs( - initialize_process_fnc=self._opts.initialize_process_fnc, - entrypoint_fnc=self._opts.entrypoint_fnc, log_cch=mp_log_cch, mp_cch=mp_cch, asyncio_debug=self._loop.get_debug(), diff --git a/livekit-agents/livekit/agents/ipc/proto.py b/livekit-agents/livekit/agents/ipc/proto.py index 7dd7c29e3..18a880221 100644 --- a/livekit-agents/livekit/agents/ipc/proto.py +++ b/livekit-agents/livekit/agents/ipc/proto.py @@ -121,6 +121,46 @@ def read(self, b: io.BytesIO) -> None: self.reason = channel.read_string(b) +@dataclass +class InferenceRequest: + """sent by a subprocess to the main process to request inference""" + + MSG_ID: ClassVar[int] = 7 + method: str = "" + request_id: str = "" + data: bytes = b"" + + def write(self, b: io.BytesIO) -> None: + channel.write_string(b, self.method) + channel.write_string(b, self.request_id) + channel.write_bytes(b, self.data) + + def read(self, b: io.BytesIO) -> None: + self.method = channel.read_string(b) + self.request_id = channel.read_string(b) + self.data = channel.read_bytes(b) + + +@dataclass +class InferenceResponse: + """response to an InferenceRequest""" + + MSG_ID: ClassVar[int] = 8 + request_id: str = "" + data: bytes = b"" + error: str = "" + + def write(self, b: io.BytesIO) -> None: + channel.write_string(b, self.request_id) + channel.write_bytes(b, self.data) + channel.write_string(b, self.error) + + def read(self, b: io.BytesIO) -> None: + self.request_id = channel.read_string(b) + self.data = channel.read_bytes(b) + self.error = channel.read_string(b) + + IPC_MESSAGES = { InitializeRequest.MSG_ID: InitializeRequest, InitializeResponse.MSG_ID: InitializeResponse, @@ -129,4 +169,6 @@ def read(self, b: io.BytesIO) -> None: StartJobRequest.MSG_ID: StartJobRequest, ShutdownRequest.MSG_ID: ShutdownRequest, Exiting.MSG_ID: Exiting, + InferenceRequest.MSG_ID: InferenceRequest, + InferenceResponse.MSG_ID: InferenceResponse, } diff --git a/livekit-agents/livekit/agents/pipeline/endpoint_detector.py b/livekit-agents/livekit/agents/pipeline/endpoint_detector.py index 831293cd9..228ff563b 100644 --- a/livekit-agents/livekit/agents/pipeline/endpoint_detector.py +++ b/livekit-agents/livekit/agents/pipeline/endpoint_detector.py @@ -8,25 +8,27 @@ from .log import logger -PUNCS = string.punctuation.replace("'","") +PUNCS = string.punctuation.replace("'", "") MAX_HISTORY = 3 + class EndpointDetector: - def __init__(self, model_path='jeradf/opt-125m-eou'): + def __init__(self, model_path="jeradf/opt-125m-eou"): self.model = ORTModelForCausalLM.from_pretrained(model_path) self.tokenizer = AutoTokenizer.from_pretrained(model_path) self._eou_index = self.tokenizer.encode("<|im_end|>")[-1] def normalize(self, text): def strip_puncs(text): - return text.translate(str.maketrans('', '', PUNCS)) + return text.translate(str.maketrans("", "", PUNCS)) + return " ".join(strip_puncs(text).lower().split()) def apply_chat_template(self, convo): for msg in convo: - if msg['role'] not in ['user', 'assistant']: + if msg["role"] not in ["user", "assistant"]: continue - msg['content'] = self.normalize(msg['content']) + msg["content"] = self.normalize(msg["content"]) convo_text = self.tokenizer.apply_chat_template( convo, @@ -36,7 +38,7 @@ def apply_chat_template(self, convo): ) # remove the EOU token from current utterance - ix = convo_text.rfind('<|im_end|>') + ix = convo_text.rfind("<|im_end|>") text = convo_text[:ix] return text @@ -44,13 +46,13 @@ def tokenize(self, text): return self.tokenizer( text, add_special_tokens=False, - return_tensors='pt', + return_tensors="pt", ) def predict(self, utterance, convo=[]): start_time = time.time() - convo.append(dict(role='user', content=utterance)) + convo.append(dict(role="user", content=utterance)) convo = convo[-MAX_HISTORY:] text = self.apply_chat_template(convo) @@ -63,14 +65,13 @@ def predict(self, utterance, convo=[]): end_time = time.time() latency = end_time - start_time - + logger.debug( - "EndpointDetector inference", + "EndpointDetector inference", extra={ - "probability": round(result, 2), + "probability": round(result, 2), "utterance": utterance, "latency": round(latency, 2), - } + }, ) return result - diff --git a/livekit-agents/livekit/agents/plugin.py b/livekit-agents/livekit/agents/plugin.py index 3554fc337..5aca08a93 100644 --- a/livekit-agents/livekit/agents/plugin.py +++ b/livekit-agents/livekit/agents/plugin.py @@ -13,7 +13,6 @@ class Plugin(ABC): registered_plugins: List["Plugin"] = [] emitter: utils.EventEmitter[EventTypes] = utils.EventEmitter() - lock = threading.Lock() # TODO(theomonnom): make logger mandatory once all plugins have been updated def __init__( diff --git a/livekit-agents/livekit/agents/utils/inference_runner.py b/livekit-agents/livekit/agents/utils/inference_runner.py new file mode 100644 index 000000000..e69de29bb diff --git a/livekit-plugins/livekit-plugins-eou/CHANGELOG.md b/livekit-plugins/livekit-plugins-eou/CHANGELOG.md new file mode 100644 index 000000000..535fb2bec --- /dev/null +++ b/livekit-plugins/livekit-plugins-eou/CHANGELOG.md @@ -0,0 +1,73 @@ +# livekit-plugins-minimal + +## 0.2.0 + +### Minor Changes + +- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) + +### Patch Changes + +- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) + +- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) + +- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO)) + +- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom)) + +- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) + +- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) + +- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom)) + +- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom)) + +## 0.2.0-dev.7 + +### Patch Changes + +- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) + +## 0.2.0-dev.6 + +### Patch Changes + +- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) + +## 0.2.0-dev.5 + +### Patch Changes + +- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) + +## 0.2.0-dev.4 + +### Patch Changes + +- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) + +## 0.2.0-dev.3 + +### Patch Changes + +- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom)) + +## 0.2.0-dev.2 + +### Patch Changes + +- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom)) + +## 0.2.0-dev.1 + +### Minor Changes + +- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) + +## 0.1.1-dev.0 + +### Patch Changes + +- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO)) diff --git a/livekit-plugins/livekit-plugins-eou/README.md b/livekit-plugins/livekit-plugins-eou/README.md new file mode 100644 index 000000000..006906528 --- /dev/null +++ b/livekit-plugins/livekit-plugins-eou/README.md @@ -0,0 +1,14 @@ +# LiveKit Plugins Minimal + +This is a minimal example of a LiveKit plugin for Agents. + +### Developer note + +When copying this directory over to create a new `livekit-plugins` package, make sure it's nested within the `livekit-plugins` folder and that the `"name"` field in `package.json` follows the proper naming convention for CI: + +```json +{ + "name": "livekit-plugins-", + "private": true +} +``` diff --git a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/__init__.py b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/__init__.py new file mode 100644 index 000000000..66362a52f --- /dev/null +++ b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2023 LiveKit, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from livekit.agents import Plugin + +from .log import logger +from .version import __version__ + + + +class EOUPlugin(Plugin): + def __init__(self): + super().__init__(__name__, __version__, __package__, logger) + + def download_files(self) -> None: + from transformers import AutoTokenizer, AutoModelForCausalLM + from .eou import HG_MODEL + + AutoModelForCausalLM.from_pretrained(HG_MODEL, from_tf=True) + AutoTokenizer.from_pretrained(HG_MODEL) + + +Plugin.register_plugin(EOUPlugin()) diff --git a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py new file mode 100644 index 000000000..6e72c7f58 --- /dev/null +++ b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import string +import json +import torch + +from livekit.agents.inference_runner import _InferenceRunner + +from transformers import AutoTokenizer, AutoModelForCausalLM + +HG_MODEL = "livekit/opt-125m-endpoint-detector" +PUNCS = string.punctuation.replace("'", "") +MAX_HISTORY = 3 + + +class _EUORunner(_InferenceRunner): + METHOD = "lk_end_of_utterance" + + def __init__(self) -> None: + pass + + def _normalize(self, text): + def strip_puncs(text): + return text.translate(str.maketrans("", "", PUNCS)) + + return " ".join(strip_puncs(text).lower().split()) + + def _format_chat_ctx(self, chat_ctx: dict): + new_chat_ctx = [] + for msg in chat_ctx: + if msg["role"] not in ["user", "assistant"]: + continue + + content = self._normalize(msg["content"]) + + if not content: + continue + + msg["content"] = content + new_chat_ctx.append(msg) + + convo_text = self._tokenizer.apply_chat_template( + new_chat_ctx, + add_generation_prompt=False, + add_special_tokens=False, + tokenize=False, + ) + + # remove the EOU token from current utterance + ix = convo_text.rfind("<|im_end|>") + text = convo_text[:ix] + return text + + def initialize(self) -> None: + self._model = AutoModelForCausalLM.from_pretrained(HG_MODEL, from_tf=True) + self._tokenizer = AutoTokenizer.from_pretrained(HG_MODEL) + self._eou_index = self._tokenizer.encode("<|im_end|>")[-1] + + def run(self, data: bytes) -> bytes | None: + data_json = json.loads(data) + chat_ctx = data_json.get("chat_ctx", None) + + if not chat_ctx: + raise ValueError("chat_ctx is required on the inference input data") + + chat_ctx = chat_ctx[-MAX_HISTORY:] + + text = self._format_chat_ctx(chat_ctx) + inputs = self._tokenizer( + text, + add_special_tokens=False, + return_tensors="pt", + ) + + outputs = self._model(**inputs) + logits = outputs.logits[0, -1, :] + output_probs = torch.nn.functional.softmax(logits, dim=-1) + eou_probability = output_probs[self._eou_index].item() + return json.dumps({"eou_probability": eou_probability}).encode() + + +class EOU: + def __init__(self) -> None: + pass diff --git a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/log.py b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/log.py new file mode 100644 index 000000000..11cb57b75 --- /dev/null +++ b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/log.py @@ -0,0 +1,3 @@ +import logging + +logger = logging.getLogger("livekit.plugins.eou") diff --git a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/version.py b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/version.py new file mode 100644 index 000000000..eaa4231b0 --- /dev/null +++ b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/version.py @@ -0,0 +1,15 @@ +# Copyright 2023 LiveKit, Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "0.2.0" diff --git a/livekit-plugins/livekit-plugins-eou/package.json b/livekit-plugins/livekit-plugins-eou/package.json new file mode 100644 index 000000000..098a32e95 --- /dev/null +++ b/livekit-plugins/livekit-plugins-eou/package.json @@ -0,0 +1,5 @@ +{ + "name": "livekit-plugins-eou", + "private": true, + "version": "0.2.0" +} diff --git a/livekit-plugins/livekit-plugins-eou/pyproject.toml b/livekit-plugins/livekit-plugins-eou/pyproject.toml new file mode 100644 index 000000000..8cf32563a --- /dev/null +++ b/livekit-plugins/livekit-plugins-eou/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/livekit-plugins/livekit-plugins-eou/setup.py b/livekit-plugins/livekit-plugins-eou/setup.py new file mode 100644 index 000000000..2826e7050 --- /dev/null +++ b/livekit-plugins/livekit-plugins-eou/setup.py @@ -0,0 +1,57 @@ +# Copyright 2023 LiveKit, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib + +import setuptools +import setuptools.command.build_py + +here = pathlib.Path(__file__).parent.resolve() +about = {} +with open(os.path.join(here, "livekit", "plugins", "eou", "version.py"), "r") as f: + exec(f.read(), about) + + +setuptools.setup( + name="livekit-plugins-eou", + version=about["__version__"], + description="End of utterance detection for LiveKit Agents", + long_description=(here / "README.md").read_text(encoding="utf-8"), + long_description_content_type="text/markdown", + url="https://github.com/livekit/agents", + cmdclass={}, + classifiers=[ + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Topic :: Multimedia :: Sound/Audio", + "Topic :: Multimedia :: Video", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3 :: Only", + ], + keywords=["webrtc", "realtime", "audio", "video", "livekit"], + license="Apache-2.0", + packages=setuptools.find_namespace_packages(include=["livekit.*"]), + python_requires=">=3.9.0", + install_requires=["livekit-agents>=0.11"], + package_data={"livekit.plugins.eou": ["py.typed"]}, + project_urls={ + "Documentation": "https://docs.livekit.io", + "Website": "https://livekit.io/", + "Source": "https://github.com/livekit/agents", + }, +) From 047f79d1369c457e7dbe1cdbcc95452b80391a0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 26 Nov 2024 00:22:00 +0100 Subject: [PATCH 08/24] Delete endpoint_detector.py --- .../agents/pipeline/endpoint_detector.py | 77 ------------------- 1 file changed, 77 deletions(-) delete mode 100644 livekit-agents/livekit/agents/pipeline/endpoint_detector.py diff --git a/livekit-agents/livekit/agents/pipeline/endpoint_detector.py b/livekit-agents/livekit/agents/pipeline/endpoint_detector.py deleted file mode 100644 index 228ff563b..000000000 --- a/livekit-agents/livekit/agents/pipeline/endpoint_detector.py +++ /dev/null @@ -1,77 +0,0 @@ -import copy -import string -import time - -import torch -from optimum.onnxruntime import ORTModelForCausalLM -from transformers import AutoTokenizer - -from .log import logger - -PUNCS = string.punctuation.replace("'", "") -MAX_HISTORY = 3 - - -class EndpointDetector: - def __init__(self, model_path="jeradf/opt-125m-eou"): - self.model = ORTModelForCausalLM.from_pretrained(model_path) - self.tokenizer = AutoTokenizer.from_pretrained(model_path) - self._eou_index = self.tokenizer.encode("<|im_end|>")[-1] - - def normalize(self, text): - def strip_puncs(text): - return text.translate(str.maketrans("", "", PUNCS)) - - return " ".join(strip_puncs(text).lower().split()) - - def apply_chat_template(self, convo): - for msg in convo: - if msg["role"] not in ["user", "assistant"]: - continue - msg["content"] = self.normalize(msg["content"]) - - convo_text = self.tokenizer.apply_chat_template( - convo, - add_generation_prompt=False, - add_special_tokens=False, - tokenize=False, - ) - - # remove the EOU token from current utterance - ix = convo_text.rfind("<|im_end|>") - text = convo_text[:ix] - return text - - def tokenize(self, text): - return self.tokenizer( - text, - add_special_tokens=False, - return_tensors="pt", - ) - - def predict(self, utterance, convo=[]): - start_time = time.time() - - convo.append(dict(role="user", content=utterance)) - convo = convo[-MAX_HISTORY:] - - text = self.apply_chat_template(convo) - inputs = self.tokenize(text) - - outputs = self.model(**inputs) - logits = outputs.logits[0, -1, :] - probs = torch.nn.functional.softmax(logits, dim=-1) - result = probs[self._eou_index].item() - - end_time = time.time() - latency = end_time - start_time - - logger.debug( - "EndpointDetector inference", - extra={ - "probability": round(result, 2), - "utterance": utterance, - "latency": round(latency, 2), - }, - ) - return result From 644f958ef0e11b137705823b5d58d93a3d6e3780 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Tue, 26 Nov 2024 00:22:30 +0100 Subject: [PATCH 09/24] Discard changes to examples/voice-pipeline-agent/minimal_assistant.py --- examples/voice-pipeline-agent/minimal_assistant.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/voice-pipeline-agent/minimal_assistant.py b/examples/voice-pipeline-agent/minimal_assistant.py index 34dcf3933..4b94bd5b7 100644 --- a/examples/voice-pipeline-agent/minimal_assistant.py +++ b/examples/voice-pipeline-agent/minimal_assistant.py @@ -50,7 +50,6 @@ async def entrypoint(ctx: JobContext): llm=openai.LLM(), tts=openai.TTS(), chat_ctx=initial_ctx, - use_endpoint_detector=True, ) agent.start(ctx.room, participant) From f7bcfb76e82a339c570029d82594466c4ec053d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 26 Nov 2024 00:52:09 +0100 Subject: [PATCH 10/24] wip --- .../voice-pipeline-agent/minimal_assistant.py | 1 - .../livekit/agents/inference_runner.py | 10 ++----- livekit-agents/livekit/agents/ipc/__init__.py | 10 ++++--- .../livekit/agents/ipc/inference_main.py | 5 ---- .../agents/ipc/inference_proc_executor.py | 21 ++++++------- .../livekit/agents/ipc/job_proc_executor.py | 6 ++-- .../livekit/agents/ipc/proc_pool.py | 6 ++-- livekit-agents/livekit/agents/worker.py | 19 ++++++++++++ livekit-plugins/install_plugins_editable.sh | 4 ++- .../livekit/plugins/eou/__init__.py | 9 ++++-- .../livekit/plugins/eou/eou.py | 30 ++++++++++++------- 11 files changed, 72 insertions(+), 49 deletions(-) diff --git a/examples/voice-pipeline-agent/minimal_assistant.py b/examples/voice-pipeline-agent/minimal_assistant.py index 34dcf3933..4b94bd5b7 100644 --- a/examples/voice-pipeline-agent/minimal_assistant.py +++ b/examples/voice-pipeline-agent/minimal_assistant.py @@ -50,7 +50,6 @@ async def entrypoint(ctx: JobContext): llm=openai.LLM(), tts=openai.TTS(), chat_ctx=initial_ctx, - use_endpoint_detector=True, ) agent.start(ctx.room, participant) diff --git a/livekit-agents/livekit/agents/inference_runner.py b/livekit-agents/livekit/agents/inference_runner.py index c8119ca48..f8be9b8d7 100644 --- a/livekit-agents/livekit/agents/inference_runner.py +++ b/livekit-agents/livekit/agents/inference_runner.py @@ -1,14 +1,8 @@ from __future__ import annotations - -from abc import ABCMeta, abstractmethod - - -from typing import Type - import threading - -from abc import ABC, abstractmethod +from abc import ABC, ABCMeta, abstractmethod +from typing import Type class _RunnerMeta(ABCMeta): diff --git a/livekit-agents/livekit/agents/ipc/__init__.py b/livekit-agents/livekit/agents/ipc/__init__.py index ab04d6b5e..589936600 100644 --- a/livekit-agents/livekit/agents/ipc/__init__.py +++ b/livekit-agents/livekit/agents/ipc/__init__.py @@ -1,17 +1,19 @@ from . import ( channel, + inference_proc_executor, job_executor, - proc_job_executor, + job_proc_executor, + job_thread_executor, proc_pool, proto, - thread_job_executor, ) __all__ = [ "proto", "channel", "proc_pool", - "proc_job_executor", - "thread_job_executor", + "job_proc_executor", + "job_thread_executor", + "inference_proc_executor", "job_executor", ] diff --git a/livekit-agents/livekit/agents/ipc/inference_main.py b/livekit-agents/livekit/agents/ipc/inference_main.py index ca7f6f9c4..f920e9250 100644 --- a/livekit-agents/livekit/agents/ipc/inference_main.py +++ b/livekit-agents/livekit/agents/ipc/inference_main.py @@ -3,14 +3,9 @@ import asyncio import contextlib import socket -import threading from dataclasses import dataclass -from typing import Any, Callable - -from livekit import rtc from .. import utils -from ..job import JobContext, JobProcess from ..log import logger from ..utils.aio import duplex_unix from . import channel, proto diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py index ae442bc57..629d948f8 100644 --- a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py @@ -8,12 +8,12 @@ import threading from dataclasses import dataclass from multiprocessing.context import BaseContext -from typing import Any, Callable +from typing import Any from .. import utils from ..log import logger from ..utils.aio import duplex_unix -from . import channel, inference_proc_lazy_main, proto, inference_main +from . import channel, inference_main, inference_proc_lazy_main, proto from .log_queue import LogQueueListener @@ -24,7 +24,7 @@ class _ProcOpts: close_timeout: float -class ProcJobExecutor: +class InferenceProcExecutor: def __init__( self, *, @@ -67,7 +67,6 @@ def started(self) -> bool: return self._main_atask is not None async def start(self) -> None: - """start the job process""" if self.started: raise RuntimeError("process already started") @@ -150,7 +149,7 @@ async def initialize(self) -> None: asyncio.TimeoutError("process initialization timed out") ) logger.error( - "initialization timed out, killing job", extra=self.logging_extra() + "initialization timed out, killing process", extra=self.logging_extra() ) self._send_kill_signal() raise @@ -161,7 +160,6 @@ async def initialize(self) -> None: self._initialize_fut.set_result(None) async def aclose(self) -> None: - """attempt to gracefully close the job process""" if not self.started: return @@ -176,7 +174,8 @@ async def aclose(self) -> None: ) except asyncio.TimeoutError: logger.error( - "process did not exit in time, killing job", extra=self.logging_extra() + "process did not exit in time, killing process", + extra=self.logging_extra(), ) self._send_kill_signal() @@ -185,7 +184,6 @@ async def aclose(self) -> None: await asyncio.shield(self._main_atask) async def kill(self) -> None: - """forcefully kill the job process""" if not self.started: raise RuntimeError("process not started") @@ -197,14 +195,13 @@ async def kill(self) -> None: await asyncio.shield(self._main_atask) def _send_kill_signal(self) -> None: - """forcefully kill the job process""" try: if not self._proc.is_alive(): return except ValueError: return - logger.info("killing job process", extra=self.logging_extra()) + logger.info("killing process", extra=self.logging_extra()) if sys.platform == "win32": self._proc.terminate() else: @@ -235,7 +232,7 @@ async def _main_task(self) -> None: if self._exitcode != 0 and not self._kill_sent: logger.error( - f"job process exited with non-zero exit code {self.exitcode}", + f"inference process exited with non-zero exit code {self.exitcode}", extra=self.logging_extra(), ) @@ -293,6 +290,6 @@ async def _pong_timeout_co(): def logging_extra(self): extra: dict[str, Any] = { "pid": self.pid, - "inference_process": True, + "inference_proc": True, } return extra diff --git a/livekit-agents/livekit/agents/ipc/job_proc_executor.py b/livekit-agents/livekit/agents/ipc/job_proc_executor.py index 26023924f..938101f08 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_executor.py @@ -14,14 +14,14 @@ from ..job import JobContext, JobProcess, RunningJobInfo from ..log import logger from ..utils.aio import duplex_unix -from . import channel, job_main, proc_lazy_main, proto +from . import channel, job_main, job_proc_lazy_main, proto from .job_executor import ( JobExecutorError_Runtime, JobExecutorError_ShutdownTimeout, JobExecutorError_Unresponsive, RunStatus, ) -from .log_queue_listener import LogQueueListener +from .log_queue import LogQueueListener @dataclass @@ -153,7 +153,7 @@ def _add_proc_ctx_log(record: logging.LogRecord) -> None: ) self._proc = self._opts.mp_ctx.Process( # type: ignore - target=proc_lazy_main.proc_main, + target=job_proc_lazy_main.proc_main, args=(self._proc_args,), name="job_proc", ) diff --git a/livekit-agents/livekit/agents/ipc/proc_pool.py b/livekit-agents/livekit/agents/ipc/proc_pool.py index d707987ab..96b63a1e3 100644 --- a/livekit-agents/livekit/agents/ipc/proc_pool.py +++ b/livekit-agents/livekit/agents/ipc/proc_pool.py @@ -8,7 +8,7 @@ from ..job import JobContext, JobExecutorType, JobProcess, RunningJobInfo from ..log import logger from ..utils import aio -from . import proc_job_executor, thread_job_executor +from . import job_proc_executor, job_thread_executor from .job_executor import JobExecutor EventTypes = Literal[ @@ -95,7 +95,7 @@ async def launch_job(self, info: RunningJobInfo) -> None: async def _proc_watch_task(self) -> None: proc: JobExecutor if self._job_executor_type == JobExecutorType.THREAD: - proc = thread_job_executor.ThreadJobExecutor( + proc = job_thread_executor.ThreadJobExecutor( initialize_process_fnc=self._initialize_process_fnc, job_entrypoint_fnc=self._job_entrypoint_fnc, initialize_timeout=self._initialize_timeout, @@ -103,7 +103,7 @@ async def _proc_watch_task(self) -> None: loop=self._loop, ) elif self._job_executor_type == JobExecutorType.PROCESS: - proc = proc_job_executor.ProcJobExecutor( + proc = job_proc_executor.ProcJobExecutor( initialize_process_fnc=self._initialize_process_fnc, job_entrypoint_fnc=self._job_entrypoint_fnc, initialize_timeout=self._initialize_timeout, diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index a9a6c39b3..a4b504b18 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -43,6 +43,7 @@ from . import http_server, ipc, utils from ._exceptions import AssignmentTimeoutError +from .inference_runner import _InferenceRunner from .job import ( JobAcceptArguments, JobContext, @@ -276,6 +277,19 @@ def __init__( self._main_task: asyncio.Task[None] | None = None + self._inference_executor: ( + ipc.inference_proc_executor.InferenceProcExecutor | None + ) = None + if len(_InferenceRunner.registered_runners) > 0: + self._inference_executor = ( + ipc.inference_proc_executor.InferenceProcExecutor( + initialize_timeout=60.0, + close_timeout=5.0, + mp_ctx=mp_ctx, + loop=self._loop, + ) + ) + async def run(self): if not self._closed: raise Exception("worker is already running") @@ -285,6 +299,11 @@ async def run(self): extra={"version": __version__, "rtc-version": rtc.__version__}, ) + if self._inference_executor is not None: + logger.info("starting inference executor") + await self._inference_executor.start() + await self._inference_executor.initialize() + self._closed = False self._proc_pool.start() self._api = api.LiveKitAPI( diff --git a/livekit-plugins/install_plugins_editable.sh b/livekit-plugins/install_plugins_editable.sh index 0072e5a17..2f97f878a 100755 --- a/livekit-plugins/install_plugins_editable.sh +++ b/livekit-plugins/install_plugins_editable.sh @@ -16,6 +16,8 @@ pip install -e ./livekit-plugins-minimal --config-settings editable_mode=strict pip install -e ./livekit-plugins-nltk --config-settings editable_mode=strict pip install -e ./livekit-plugins-openai --config-settings editable_mode=strict pip install -e ./livekit-plugins-rag --config-settings editable_mode=strict +pip install -e ./livekit-plugins-llama-index --config-settings editable_mode=strict +pip install -e ./livekit-plugins-eou --config-settings editable_mode=strict pip install -e ./livekit-plugins-silero --config-settings editable_mode=strict pip install -e ./livekit-plugins-browser --config-settings editable_mode=strict -pip install -e ./livekit-plugins-llama-index --config-settings editable_mode=strict + diff --git a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/__init__.py b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/__init__.py index 66362a52f..479838eb8 100644 --- a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/__init__.py +++ b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/__init__.py @@ -13,10 +13,13 @@ # limitations under the License. from livekit.agents import Plugin +from livekit.agents.inference_runner import _InferenceRunner +from .eou import EOU, _EUORunner from .log import logger from .version import __version__ +__all__ = ["EOU", "__version__"] class EOUPlugin(Plugin): @@ -24,11 +27,13 @@ def __init__(self): super().__init__(__name__, __version__, __package__, logger) def download_files(self) -> None: - from transformers import AutoTokenizer, AutoModelForCausalLM + from transformers import AutoModelForCausalLM, AutoTokenizer + from .eou import HG_MODEL - AutoModelForCausalLM.from_pretrained(HG_MODEL, from_tf=True) + AutoModelForCausalLM.from_pretrained(HG_MODEL) AutoTokenizer.from_pretrained(HG_MODEL) Plugin.register_plugin(EOUPlugin()) +_InferenceRunner.register_runner(_EUORunner) diff --git a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py index 6e72c7f58..4303f781a 100644 --- a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py +++ b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py @@ -1,16 +1,20 @@ from __future__ import annotations -import string import json -import torch +import string +import numpy as np +from livekit.agents import llm from livekit.agents.inference_runner import _InferenceRunner -from transformers import AutoTokenizer, AutoModelForCausalLM - HG_MODEL = "livekit/opt-125m-endpoint-detector" PUNCS = string.punctuation.replace("'", "") -MAX_HISTORY = 3 +MAX_HISTORY = 4 + + +def _softmax(logits: np.ndarray) -> np.ndarray: + exp_logits = np.exp(logits - np.max(logits)) + return exp_logits / np.sum(exp_logits) class _EUORunner(_InferenceRunner): @@ -52,7 +56,9 @@ def _format_chat_ctx(self, chat_ctx: dict): return text def initialize(self) -> None: - self._model = AutoModelForCausalLM.from_pretrained(HG_MODEL, from_tf=True) + from transformers import AutoModelForCausalLM, AutoTokenizer + + self._model = AutoModelForCausalLM.from_pretrained(HG_MODEL) self._tokenizer = AutoTokenizer.from_pretrained(HG_MODEL) self._eou_index = self._tokenizer.encode("<|im_end|>")[-1] @@ -73,12 +79,16 @@ def run(self, data: bytes) -> bytes | None: ) outputs = self._model(**inputs) - logits = outputs.logits[0, -1, :] - output_probs = torch.nn.functional.softmax(logits, dim=-1) - eou_probability = output_probs[self._eou_index].item() + logits = outputs.logits[0, -1, :].detach().numpy() + output_probs = _softmax(logits) + eou_probability = output_probs[self._eou_index] + return json.dumps({"eou_probability": eou_probability}).encode() class EOU: - def __init__(self) -> None: + def __init__(self): + pass + + def predict_eou(self, chat_ctx: llm.ChatContext) -> float: pass From bc0bca84a17913f73013b916e5b0bc0ab607231d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 26 Nov 2024 14:15:48 +0100 Subject: [PATCH 11/24] wip --- .../voice-pipeline-agent/minimal_assistant.py | 2 +- .../livekit/agents/inference_runner.py | 21 +-- .../livekit/agents/ipc/inference_executor.py | 7 + .../livekit/agents/ipc/inference_main.py | 68 ---------- .../agents/ipc/inference_proc_executor.py | 11 +- .../agents/ipc/inference_proc_lazy_main.py | 116 +++++++++++----- livekit-agents/livekit/agents/ipc/job_main.py | 5 +- .../livekit/agents/ipc/proc_client.py | 127 ++++++++++++++++++ livekit-agents/livekit/agents/ipc/proto.py | 10 +- livekit-agents/livekit/agents/job.py | 35 ++++- livekit-agents/livekit/agents/worker.py | 2 - .../livekit/plugins/eou/eou.py | 63 ++++++--- 12 files changed, 328 insertions(+), 139 deletions(-) create mode 100644 livekit-agents/livekit/agents/ipc/inference_executor.py delete mode 100644 livekit-agents/livekit/agents/ipc/inference_main.py create mode 100644 livekit-agents/livekit/agents/ipc/proc_client.py diff --git a/examples/voice-pipeline-agent/minimal_assistant.py b/examples/voice-pipeline-agent/minimal_assistant.py index 4b94bd5b7..9ecad5056 100644 --- a/examples/voice-pipeline-agent/minimal_assistant.py +++ b/examples/voice-pipeline-agent/minimal_assistant.py @@ -13,7 +13,7 @@ metrics, ) from livekit.agents.pipeline import VoicePipelineAgent -from livekit.plugins import deepgram, openai, silero +from livekit.plugins import deepgram, openai, silero, eou load_dotenv() logger = logging.getLogger("voice-assistant") diff --git a/livekit-agents/livekit/agents/inference_runner.py b/livekit-agents/livekit/agents/inference_runner.py index f8be9b8d7..b34cd993f 100644 --- a/livekit-agents/livekit/agents/inference_runner.py +++ b/livekit-agents/livekit/agents/inference_runner.py @@ -2,30 +2,31 @@ import threading from abc import ABC, ABCMeta, abstractmethod -from typing import Type +from typing import ClassVar, Protocol, Type -class _RunnerMeta(ABCMeta): - @property - @abstractmethod - def METHOD(cls) -> str: ... +class _RunnerMeta(Protocol): + INFERENCE_METHOD: ClassVar[str] + + +_RunnersDict = dict[str, Type["_InferenceRunner"]] # kept private until we stabilize the API (only used for EOU today) -class _InferenceRunner(ABC, metaclass=_RunnerMeta): - registered_runners: dict[str, Type["_InferenceRunner"]] = {} +class _InferenceRunner(ABC, _RunnerMeta): + registered_runners: _RunnersDict = {} @classmethod def register_runner(cls, runner_class: Type["_InferenceRunner"]) -> None: if threading.current_thread() != threading.main_thread(): raise RuntimeError("InferenceRunner must be registered on the main thread") - if runner_class.METHOD in cls.registered_runners: + if runner_class.INFERENCE_METHOD in cls.registered_runners: raise ValueError( - f"InferenceRunner {runner_class.METHOD} already registered" + f"InferenceRunner {runner_class.INFERENCE_METHOD} already registered" ) - cls.registered_runners[runner_class.METHOD] = runner_class + cls.registered_runners[runner_class.INFERENCE_METHOD] = runner_class @abstractmethod def initialize(self) -> None: diff --git a/livekit-agents/livekit/agents/ipc/inference_executor.py b/livekit-agents/livekit/agents/ipc/inference_executor.py new file mode 100644 index 000000000..c83aee64d --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/inference_executor.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from typing import Protocol + + +class InferenceExecutor(Protocol): + async def do_inference(self, method: str, data: bytes) -> bytes | None: ... diff --git a/livekit-agents/livekit/agents/ipc/inference_main.py b/livekit-agents/livekit/agents/ipc/inference_main.py deleted file mode 100644 index f920e9250..000000000 --- a/livekit-agents/livekit/agents/ipc/inference_main.py +++ /dev/null @@ -1,68 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import socket -from dataclasses import dataclass - -from .. import utils -from ..log import logger -from ..utils.aio import duplex_unix -from . import channel, proto - - -async def _async_main( - mp_cch: socket.socket, -) -> None: - cch = await duplex_unix._AsyncDuplex.open(mp_cch) - - exit_flag = asyncio.Event() - no_msg_timeout = utils.aio.sleep(proto.PING_INTERVAL * 5) # missing 5 pings - - @utils.log_exceptions(logger=logger) - async def _read_ipc_task(): - while True: - try: - msg = await channel.arecv_message(cch, proto.IPC_MESSAGES) - except duplex_unix.DuplexClosed: - break - - with contextlib.suppress(utils.aio.SleepFinished): - no_msg_timeout.reset() - - if isinstance(msg, proto.PingRequest): - pong = proto.PongResponse( - last_timestamp=msg.timestamp, timestamp=utils.time_ms() - ) - await channel.asend_message(cch, pong) - - if isinstance(msg, proto.ShutdownRequest): - pass - - async def _self_health_check(): - await no_msg_timeout - print("worker process is not responding.. worker crashed?") - with contextlib.suppress(asyncio.CancelledError): - exit_flag.set() - - read_task = asyncio.create_task(_read_ipc_task(), name="ipc_read") - health_check_task = asyncio.create_task(_self_health_check(), name="health_check") - - def _done_cb(_: asyncio.Task) -> None: - with contextlib.suppress(asyncio.InvalidStateError): - exit_flag.set() - - read_task.add_done_callback(_done_cb) - - await exit_flag.wait() - await utils.aio.gracefully_cancel(read_task, health_check_task) - - with contextlib.suppress(duplex_unix.DuplexClosed): - await cch.aclose() - - -@dataclass -class ProcStartArgs: - log_cch: socket.socket - mp_cch: socket.socket - asyncio_debug: bool diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py index 629d948f8..1bec9ec4c 100644 --- a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py @@ -13,9 +13,11 @@ from .. import utils from ..log import logger from ..utils.aio import duplex_unix -from . import channel, inference_main, inference_proc_lazy_main, proto +from . import channel, inference_proc_lazy_main, proto from .log_queue import LogQueueListener +from ..inference_runner import _InferenceRunner + @dataclass class _ProcOpts: @@ -28,10 +30,10 @@ class InferenceProcExecutor: def __init__( self, *, - initialize_timeout: float, - close_timeout: float, mp_ctx: BaseContext, loop: asyncio.AbstractEventLoop, + initialize_timeout: float = 60.0, + close_timeout: float = 2.5, ) -> None: self._loop = loop self._opts = _ProcOpts( @@ -93,10 +95,11 @@ def _add_proc_ctx_log(record: logging.LogRecord) -> None: log_listener = LogQueueListener(log_pch, _add_proc_ctx_log) log_listener.start() - self._proc_args = inference_main.ProcStartArgs( + self._proc_args = inference_proc_lazy_main.ProcStartArgs( log_cch=mp_log_cch, mp_cch=mp_cch, asyncio_debug=self._loop.get_debug(), + runners=_InferenceRunner.registered_runners, ) self._proc = self._opts.mp_ctx.Process( # type: ignore diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py index e4fc7626f..344196823 100644 --- a/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py @@ -16,16 +16,32 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): sys.excepthook = _no_traceback_excepthook -def proc_main(args) -> None: - # import every package lazily - import asyncio +import asyncio +import socket +from dataclasses import dataclass + +from ..inference_runner import _RunnersDict +from ..log import logger +from ..utils import aio, log_exceptions +from . import proto +from .channel import Message +from .proc_client import _ProcClient + + +@dataclass +class ProcStartArgs: + log_cch: socket.socket + mp_cch: socket.socket + asyncio_debug: bool + runners: _RunnersDict + + +def proc_main(args: ProcStartArgs) -> None: import logging from ..log import logger from ..utils import aio - from .channel import recv_message, send_message from .log_queue import LogQueueHandler - from .proto import IPC_MESSAGES, InitializeRequest, InitializeResponse root_logger = logging.getLogger() root_logger.setLevel(logging.NOTSET) @@ -34,40 +50,74 @@ def proc_main(args) -> None: log_handler = LogQueueHandler(log_cch) root_logger.addHandler(log_handler) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.set_debug(args.asyncio_debug) - loop.slow_callback_duration = 0.1 # 100ms - aio.debug.hook_slow_callbacks(2.0) - - cch = aio.duplex_unix._Duplex.open(args.mp_cch) try: - init_req = recv_message(cch, IPC_MESSAGES) + from .proc_client import _ProcClient + + inf_proc = _InferenceProc(args.runners) - assert isinstance( - init_req, InitializeRequest - ), "first message must be InitializeRequest" + client = _ProcClient( + args.mp_cch, + inf_proc.initialize, + inf_proc.entrypoint, + args.asyncio_debug, + ) pid = current_process().pid logger.info("initializing process", extra={"pid": pid}) - args.initialize_process_fnc() + client.initialize() logger.info("process initialized", extra={"pid": pid}) - send_message(cch, InitializeResponse()) - from .inference_main import _async_main - - main_task = loop.create_task( - _async_main(args.entrypoint_fnc, cch.detach()), - name="inference_proc_main", - ) - while not main_task.done(): - try: - loop.run_until_complete(main_task) - except KeyboardInterrupt: - # ignore the keyboard interrupt, we handle the process shutdown ourselves on the worker process - pass - except (aio.duplex_unix.DuplexClosed, KeyboardInterrupt): - pass + client.run() finally: log_handler.close() - loop.run_until_complete(loop.shutdown_default_executor()) + + +class _InferenceProc: + def __init__(self, runners: _RunnersDict) -> None: + # create an instance of each runner (the ctor must not requires any argument) + self._runners = {name: runner() for name, runner in runners.items()} + + def initialize( + self, init_req: proto.InitializeRequest, client: _ProcClient + ) -> None: + self._client = client + + for runner in self._runners.values(): + logger.debug( + "initializing inference runner", + extra={"runner": runner.__class__.INFERENCE_METHOD}, + ) + runner.initialize() + + @log_exceptions(logger=logger) + async def entrypoint(self, cch: aio.ChanReceiver[Message]) -> None: + async for msg in cch: + if isinstance(msg, proto.InferenceRequest): + await self._handle_inference_request(msg) + + if isinstance(msg, proto.ShutdownRequest): + await self._client.send(proto.Exiting(reason=msg.reason)) + break + + async def _handle_inference_request(self, msg: proto.InferenceRequest) -> None: + loop = asyncio.get_running_loop() + + if msg.method not in self._runners: + logger.warning("unknown inference method", extra={"method": msg.method}) + + try: + data = await loop.run_in_executor( + None, self._runners[msg.method].run, msg.data + ) + await self._client.send( + proto.InferenceResponse( + request_id=msg.request_id, + data=data, + ) + ) + + except Exception as e: + logger.exception("error running inference") + await self._client.send( + proto.InferenceResponse(request_id=msg.request_id, error=str(e)) + ) diff --git a/livekit-agents/livekit/agents/ipc/job_main.py b/livekit-agents/livekit/agents/ipc/job_main.py index dc75155b1..0fe98d468 100644 --- a/livekit-agents/livekit/agents/ipc/job_main.py +++ b/livekit-agents/livekit/agents/ipc/job_main.py @@ -10,7 +10,7 @@ from livekit import rtc from .. import utils -from ..job import JobContext, JobProcess +from ..job import JobContext, JobProcess, _JobContextVar from ..log import logger from ..utils.aio import duplex_unix from . import channel, proto @@ -74,6 +74,8 @@ def _on_ctx_shutdown(reason: str) -> None: @utils.log_exceptions(logger=logger) async def _run_job_task() -> None: utils.http_context._new_session_ctx() + job_ctx_token = _JobContextVar.set(job_ctx) + job_entry_task = asyncio.create_task( job_entrypoint_fnc(job_ctx), name="job_entrypoint" ) @@ -125,6 +127,7 @@ def log_exception(t: asyncio.Task) -> None: logger.exception("error while shutting down the job") await utils.http_context._close_http_ctx() + _JobContextVar.reset(job_ctx_token) exit_proc_fut.set() task = asyncio.create_task(_run_job_task()) diff --git a/livekit-agents/livekit/agents/ipc/proc_client.py b/livekit-agents/livekit/agents/ipc/proc_client.py new file mode 100644 index 000000000..542d8b8bb --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/proc_client.py @@ -0,0 +1,127 @@ +import asyncio +import contextlib +import socket +import sys +from typing import Callable, Coroutine + +from ..log import logger +from ..utils import aio, log_exceptions, time_ms +from . import proto +from .channel import Message, arecv_message, asend_message, recv_message, send_message + + +class _ProcClient: + def __init__( + self, + mp_cch: socket.socket, + initialize_fnc: Callable[[proto.InitializeRequest, "_ProcClient"], None], + entrypoint_fnc: Callable[ + [aio.ChanReceiver[Message]], Coroutine[None, None, None] + ], + asyncio_debug: bool, + ) -> None: + self._mp_cch = mp_cch + self._asyncio_debug = asyncio_debug + self._initialize_fnc = initialize_fnc + self._entrypoint_fnc = entrypoint_fnc + self._initialized = False + + def initialize(self) -> None: + try: + cch = aio.duplex_unix._Duplex.open(self._mp_cch) + self._init_req = recv_message(cch, proto.IPC_MESSAGES) + + assert isinstance( + self._init_req, proto.InitializeRequest + ), "first message must be proto.InitializeRequest" + + self._initialize_fnc(self._init_req, self) + send_message(cch, proto.InitializeResponse()) + self._initialized = True + cch.detach() + except aio.duplex_unix.DuplexClosed as e: + raise RuntimeError("failed to initialize proc_client") from e + + def run(self) -> None: + if not self._initialized: + raise RuntimeError("proc_client not initialized") + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.set_debug(self._asyncio_debug) + loop.slow_callback_duration = 0.1 # 100ms + aio.debug.hook_slow_callbacks(2.0) + + try: + self._task = loop.create_task(self._main_task(), name="proc_client_main") + while not self._task.done(): + try: + loop.run_until_complete(self._task) + except KeyboardInterrupt: + # ignore the keyboard interrupt, we handle the process shutdown ourselves on the worker process + # (See proto.ShutdownRequest) + pass + except KeyboardInterrupt: + pass + finally: + loop.run_until_complete(loop.shutdown_default_executor()) + + async def send(self, msg: Message) -> None: + await asend_message(self._acch, msg) + + async def _main_task(self) -> None: + self._acch = await aio.duplex_unix._AsyncDuplex.open(self._mp_cch) + try: + exit_flag = asyncio.Event() + ping_timeout = aio.sleep(proto.PING_INTERVAL * 5) + + ipc_ch = aio.Chan[Message]() + + @log_exceptions(logger=logger) + async def _read_ipc_task(): + while True: + try: + msg = await arecv_message(self._acch, proto.IPC_MESSAGES) + except aio.duplex_unix.DuplexClosed: + break + + with contextlib.suppress(aio.SleepFinished): + ping_timeout.reset() + + if isinstance(msg, proto.PingRequest): + pong = proto.PongResponse( + last_timestamp=msg.timestamp, timestamp=time_ms() + ) + await asend_message(self._acch, pong) + + ipc_ch.send_nowait(msg) + + async def _self_health_check(): + await ping_timeout + print( + "worker process is not responding.. worker crashed?", + file=sys.stderr, + ) + + read_task = asyncio.create_task(_read_ipc_task(), name="ipc_read") + health_check_task = asyncio.create_task( + _self_health_check(), name="health_check" + ) + entrypoint_task = asyncio.create_task( + self._entrypoint_fnc(ipc_ch), name="entrypoint" + ) + + def _done_cb(_: asyncio.Task) -> None: + with contextlib.suppress(asyncio.InvalidStateError): + exit_flag.set() + + ipc_ch.close() + + read_task.add_done_callback(_done_cb) + health_check_task.add_done_callback(_done_cb) + entrypoint_task.add_done_callback(_done_cb) + + await exit_flag.wait() + await aio.gracefully_cancel(read_task, health_check_task, entrypoint_task) + finally: + await self._acch.aclose() diff --git a/livekit-agents/livekit/agents/ipc/proto.py b/livekit-agents/livekit/agents/ipc/proto.py index 18a880221..14de2881c 100644 --- a/livekit-agents/livekit/agents/ipc/proto.py +++ b/livekit-agents/livekit/agents/ipc/proto.py @@ -147,17 +147,21 @@ class InferenceResponse: MSG_ID: ClassVar[int] = 8 request_id: str = "" - data: bytes = b"" + data: bytes | None = None error: str = "" def write(self, b: io.BytesIO) -> None: channel.write_string(b, self.request_id) - channel.write_bytes(b, self.data) + channel.write_bool(b, self.data is not None) + if self.data is not None: + channel.write_bytes(b, self.data) channel.write_string(b, self.error) def read(self, b: io.BytesIO) -> None: self.request_id = channel.read_string(b) - self.data = channel.read_bytes(b) + has_data = channel.read_bool(b) + if has_data: + self.data = channel.read_bytes(b) self.error = channel.read_string(b) diff --git a/livekit-agents/livekit/agents/job.py b/livekit-agents/livekit/agents/job.py index 471ed86c6..328cf5ffa 100644 --- a/livekit-agents/livekit/agents/job.py +++ b/livekit-agents/livekit/agents/job.py @@ -15,6 +15,7 @@ from __future__ import annotations import asyncio +import contextvars import multiprocessing as mp from dataclasses import dataclass from enum import Enum, unique @@ -25,6 +26,20 @@ from .log import logger +from .ipc.inference_executor import InferenceExecutor + +_JobContextVar = contextvars.ContextVar("agents_job_context") + + +def get_current_job_context() -> JobContext: + ctx = _JobContextVar.get(None) + if ctx is None: + raise RuntimeError( + "no job context found, are you running this code inside a job entrypoint?" + ) + + return ctx + @unique class JobExecutorType(Enum): @@ -249,10 +264,28 @@ def on_track_published(pub: rtc.RemoteTrackPublication, _: rtc.RemoteParticipant class JobProcess: - def __init__(self, *, start_arguments: Any | None = None) -> None: + def __init__( + self, + *, + start_arguments: Any | None = None, + inference_executor: InferenceExecutor | None = None, + ) -> None: self._mp_proc = mp.current_process() self._userdata: dict[str, Any] = {} self._start_arguments = start_arguments + self._inf_executor = inference_executor + + @property + def inference_executor(self) -> InferenceExecutor: + if self._inf_executor is None: + raise ValueError( + ( + "no inference executor is provided for the current JobProcess, did you " + "forgot to register/import plugins necessary for inference?" + ) + ) + + return self._inf_executor @property def pid(self) -> int | None: diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index a4b504b18..8df4a86bf 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -283,8 +283,6 @@ def __init__( if len(_InferenceRunner.registered_runners) > 0: self._inference_executor = ( ipc.inference_proc_executor.InferenceProcExecutor( - initialize_timeout=60.0, - close_timeout=5.0, mp_ctx=mp_ctx, loop=self._loop, ) diff --git a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py index 4303f781a..03887e72c 100644 --- a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py +++ b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py @@ -5,7 +5,9 @@ import numpy as np from livekit.agents import llm +from livekit.agents.job import get_current_job_context from livekit.agents.inference_runner import _InferenceRunner +from livekit.agents.ipc.inference_executor import InferenceExecutor HG_MODEL = "livekit/opt-125m-endpoint-detector" PUNCS = string.punctuation.replace("'", "") @@ -18,10 +20,7 @@ def _softmax(logits: np.ndarray) -> np.ndarray: class _EUORunner(_InferenceRunner): - METHOD = "lk_end_of_utterance" - - def __init__(self) -> None: - pass + INFERENCE_METHOD = "lk_end_of_utterance" def _normalize(self, text): def strip_puncs(text): @@ -32,9 +31,6 @@ def strip_puncs(text): def _format_chat_ctx(self, chat_ctx: dict): new_chat_ctx = [] for msg in chat_ctx: - if msg["role"] not in ["user", "assistant"]: - continue - content = self._normalize(msg["content"]) if not content: @@ -57,9 +53,8 @@ def _format_chat_ctx(self, chat_ctx: dict): def initialize(self) -> None: from transformers import AutoModelForCausalLM, AutoTokenizer - - self._model = AutoModelForCausalLM.from_pretrained(HG_MODEL) - self._tokenizer = AutoTokenizer.from_pretrained(HG_MODEL) + self._model = AutoModelForCausalLM.from_pretrained(HG_MODEL, local_files_only=True) + self._tokenizer = AutoTokenizer.from_pretrained(HG_MODEL, local_files_only=True) self._eou_index = self._tokenizer.encode("<|im_end|>")[-1] def run(self, data: bytes) -> bytes | None: @@ -69,8 +64,6 @@ def run(self, data: bytes) -> bytes | None: if not chat_ctx: raise ValueError("chat_ctx is required on the inference input data") - chat_ctx = chat_ctx[-MAX_HISTORY:] - text = self._format_chat_ctx(chat_ctx) inputs = self._tokenizer( text, @@ -87,8 +80,46 @@ def run(self, data: bytes) -> bytes | None: class EOU: - def __init__(self): - pass + def __init__(self, inference_executor: InferenceExecutor | None = None) -> None: + self._executor = ( + inference_executor or get_current_job_context().proc.inference_executor + ) + + async def predict_eou(self, chat_ctx: llm.ChatContext) -> float: + messages = [] + + for msg in chat_ctx.messages: + if msg.role not in ("user", "assistant"): + continue + + if isinstance(msg.content, str): + messages.append( + { + "role": msg.role, + "content": msg.content, + } + ) + elif isinstance(msg.content, list): + for cnt in msg.content: + if isinstance(cnt, str): + messages.append( + { + "role": msg.role, + "content": cnt, + } + ) + break + + messages = messages[-MAX_HISTORY:] + + json_data = json.dumps({"chat_ctx": messages}).encode() + result = await self._executor.do_inference( + _EUORunner.INFERENCE_METHOD, json_data + ) + + assert ( + result is not None + ), "end_of_utterance prediction should always returns a result" - def predict_eou(self, chat_ctx: llm.ChatContext) -> float: - pass + result_json = json.loads(result.decode()) + return result_json["eou_probability"] From d30ef1efe9aecc314039c0618f3a4dc403f001d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Tue, 26 Nov 2024 14:18:28 +0100 Subject: [PATCH 12/24] Update CHANGELOG.md --- .../livekit-plugins-eou/CHANGELOG.md | 74 +------------------ 1 file changed, 1 insertion(+), 73 deletions(-) diff --git a/livekit-plugins/livekit-plugins-eou/CHANGELOG.md b/livekit-plugins/livekit-plugins-eou/CHANGELOG.md index 535fb2bec..6071b2806 100644 --- a/livekit-plugins/livekit-plugins-eou/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-eou/CHANGELOG.md @@ -1,73 +1 @@ -# livekit-plugins-minimal - -## 0.2.0 - -### Minor Changes - -- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) - -### Patch Changes - -- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) - -- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) - -- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO)) - -- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom)) - -- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) - -- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) - -- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom)) - -- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom)) - -## 0.2.0-dev.7 - -### Patch Changes - -- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) - -## 0.2.0-dev.6 - -### Patch Changes - -- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) - -## 0.2.0-dev.5 - -### Patch Changes - -- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) - -## 0.2.0-dev.4 - -### Patch Changes - -- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) - -## 0.2.0-dev.3 - -### Patch Changes - -- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom)) - -## 0.2.0-dev.2 - -### Patch Changes - -- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom)) - -## 0.2.0-dev.1 - -### Minor Changes - -- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom)) - -## 0.1.1-dev.0 - -### Patch Changes - -- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO)) +# livekit-plugins-eou From a4ca77cf5b4f291dbf9257e45503ce259e3a64b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 26 Nov 2024 14:51:09 +0100 Subject: [PATCH 13/24] wip --- .../livekit/agents/ipc/inference_executor.py | 32 +++++++++++++++++++ .../agents/ipc/inference_proc_executor.py | 14 +++++++- livekit-agents/livekit/agents/ipc/job_main.py | 17 ++++++++-- .../livekit/agents/ipc/job_proc_executor.py | 28 ++++++++++++++++ .../livekit/agents/ipc/proc_pool.py | 5 ++- livekit-agents/livekit/agents/job.py | 20 ++++-------- livekit-agents/livekit/agents/worker.py | 24 +++++++------- .../livekit/plugins/eou/eou.py | 5 ++- 8 files changed, 114 insertions(+), 31 deletions(-) diff --git a/livekit-agents/livekit/agents/ipc/inference_executor.py b/livekit-agents/livekit/agents/ipc/inference_executor.py index c83aee64d..6d18ee882 100644 --- a/livekit-agents/livekit/agents/ipc/inference_executor.py +++ b/livekit-agents/livekit/agents/ipc/inference_executor.py @@ -1,7 +1,39 @@ from __future__ import annotations from typing import Protocol +from ..utils import aio, shortuuid +from . import proto +from ..log import logger +from . import channel +import asyncio class InferenceExecutor(Protocol): async def do_inference(self, method: str, data: bytes) -> bytes | None: ... + + +class _InferenceRunnerClient(InferenceExecutor): + def __init__(self, *, cch: aio.duplex_unix._AsyncDuplex) -> None: + self._cch = cch + self._active_requests: dict[str, asyncio.Future[proto.InferenceResponse]] = {} + + async def do_inference(self, method: str, data: bytes) -> bytes | None: + request_id = shortuuid("INF_") + await channel.asend_message( + self._cch, + proto.InferenceRequest(request_id=request_id, method=method, data=data), + ) + + fut = asyncio.Future[proto.InferenceResponse]() + self._active_requests[request_id] = fut + return (await fut).data + + def _on_inference_response(self, resp: proto.InferenceResponse) -> None: + fut = self._active_requests.pop(resp.request_id, None) + if fut is None: + logger.warning( + "received unexpected inference response", extra={"resp": resp} + ) + return + + fut.set_result(resp) diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py index 1bec9ec4c..513de5fa7 100644 --- a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py @@ -17,6 +17,7 @@ from .log_queue import LogQueueListener from ..inference_runner import _InferenceRunner +from .inference_executor import InferenceExecutor, _InferenceRunnerClient @dataclass @@ -26,7 +27,7 @@ class _ProcOpts: close_timeout: float -class InferenceProcExecutor: +class InferenceProcExecutor(InferenceExecutor): def __init__( self, *, @@ -102,6 +103,8 @@ def _add_proc_ctx_log(record: logging.LogRecord) -> None: runners=_InferenceRunner.registered_runners, ) + self._inf_client = _InferenceRunnerClient(cch=self._pch) + self._proc = self._opts.mp_ctx.Process( # type: ignore target=inference_proc_lazy_main.proc_main, args=(self._proc_args,), @@ -135,6 +138,12 @@ async def join(self) -> None: if self._main_atask: await asyncio.shield(self._main_atask) + async def do_inference(self, method: str, data: bytes) -> bytes | None: + if not self.started: + raise RuntimeError("process not started") + + return await self._inf_client.do_inference(method, data) + async def initialize(self) -> None: await channel.asend_message(self._pch, proto.InitializeRequest()) @@ -258,6 +267,9 @@ async def _monitor_task(self) -> None: with contextlib.suppress(utils.aio.SleepFinished): self._pong_timeout.reset() + if isinstance(msg, proto.InferenceResponse): + self._inf_client._on_inference_response(msg) + @utils.log_exceptions(logger=logger) async def _ping_pong_task(self) -> None: ping_interval = utils.aio.interval(proto.PING_INTERVAL) diff --git a/livekit-agents/livekit/agents/ipc/job_main.py b/livekit-agents/livekit/agents/ipc/job_main.py index 0fe98d468..42b699457 100644 --- a/livekit-agents/livekit/agents/ipc/job_main.py +++ b/livekit-agents/livekit/agents/ipc/job_main.py @@ -15,6 +15,8 @@ from ..utils.aio import duplex_unix from . import channel, proto +from .inference_executor import InferenceExecutor, _InferenceRunnerClient + @dataclass class _ShutdownInfo: @@ -35,6 +37,7 @@ def _start_job( start_req: proto.StartJobRequest, exit_proc_fut: asyncio.Event, cch: utils.aio.duplex_unix._AsyncDuplex, + inference_client: _InferenceRunnerClient, ) -> JobTask: # used to warn users if none of connect/shutdown is called inside the job_entry ctx_connect, ctx_shutdown = False, False @@ -69,6 +72,7 @@ def _on_ctx_shutdown(reason: str) -> None: room=room, on_connect=_on_ctx_connect, on_shutdown=_on_ctx_shutdown, + inference_executor=inference_client, ) @utils.log_exceptions(logger=logger) @@ -146,6 +150,8 @@ async def _async_main( exit_proc_fut = asyncio.Event() no_msg_timeout = utils.aio.sleep(proto.PING_INTERVAL * 5) # missing 5 pings + inference_client = _InferenceRunnerClient(cch=cch) + @utils.log_exceptions(logger=logger) async def _read_ipc_task(): nonlocal job_task @@ -166,18 +172,22 @@ async def _read_ipc_task(): if isinstance(msg, proto.StartJobRequest): assert job_task is None, "job task already running" - job_task = _start_job(proc, job_entrypoint_fnc, msg, exit_proc_fut, cch) + job_task = _start_job( + proc, job_entrypoint_fnc, msg, exit_proc_fut, cch, inference_client + ) if isinstance(msg, proto.ShutdownRequest): if job_task is None: - # there is no running job, we can exit immediately - break + break # there is no running job, we can exit immediately with contextlib.suppress(asyncio.InvalidStateError): job_task.shutdown_fut.set_result( _ShutdownInfo(reason=msg.reason, user_initiated=False) ) + if isinstance(msg, proto.InferenceResponse): + inference_client._on_inference_response(msg) + async def _self_health_check(): await no_msg_timeout print("worker process is not responding.. worker crashed?") @@ -208,6 +218,7 @@ class ProcStartArgs: mp_cch: socket.socket asyncio_debug: bool user_arguments: Any | None = None + inference_runners: dict[str, Callable[[], InferenceExecutor]] | None = None @dataclass diff --git a/livekit-agents/livekit/agents/ipc/job_proc_executor.py b/livekit-agents/livekit/agents/ipc/job_proc_executor.py index 938101f08..25afd8e1e 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_executor.py @@ -12,6 +12,7 @@ from .. import utils from ..job import JobContext, JobProcess, RunningJobInfo +from .inference_executor import InferenceExecutor from ..log import logger from ..utils.aio import duplex_unix from . import channel, job_main, job_proc_lazy_main, proto @@ -40,6 +41,7 @@ def __init__( initialize_process_fnc: Callable[[JobProcess], Any], job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]], initialize_timeout: float, + inference_executor: InferenceExecutor | None, close_timeout: float, mp_ctx: BaseContext, loop: asyncio.AbstractEventLoop, @@ -54,6 +56,7 @@ def __init__( ) self._user_args: Any | None = None + self._inf_excecutor = inference_executor self._running_job: RunningJobInfo | None = None self._exitcode: int | None = None self._pid: int | None = None @@ -64,6 +67,8 @@ def __init__( self._kill_sent = False self._initialize_fut = asyncio.Future[None]() + self._inference_tasks: list[asyncio.Task[None]] = [] + self._lock = asyncio.Lock() @property @@ -332,6 +337,29 @@ async def _monitor_task(self, pong_timeout: utils.aio.Sleep) -> None: "job exiting", extra={"reason": msg.reason, **self.logging_extra()} ) + if isinstance(msg, proto.InferenceRequest): + self._inference_tasks.append( + asyncio.create_task(self._do_inference_task(msg)) + ) + + async def _do_inference_task(self, inf_req: proto.InferenceRequest) -> None: + if self._inf_excecutor is None: + logger.warning("inference request received but no inference executor") + return + + try: + inf_res = await self._inf_excecutor.do_inference( + inf_req.method, inf_req.data + ) + await channel.asend_message( + self._pch, + proto.InferenceResponse(request_id=inf_req.request_id, data=inf_res), + ) + except Exception as e: + logger.exception( + "error handling inference request", extra=self.logging_extra() + ) + @utils.log_exceptions(logger=logger) async def _ping_pong_task(self, pong_timeout: utils.aio.Sleep) -> None: ping_interval = utils.aio.interval(proto.PING_INTERVAL) diff --git a/livekit-agents/livekit/agents/ipc/proc_pool.py b/livekit-agents/livekit/agents/ipc/proc_pool.py index 96b63a1e3..cb978fd99 100644 --- a/livekit-agents/livekit/agents/ipc/proc_pool.py +++ b/livekit-agents/livekit/agents/ipc/proc_pool.py @@ -8,7 +8,7 @@ from ..job import JobContext, JobExecutorType, JobProcess, RunningJobInfo from ..log import logger from ..utils import aio -from . import job_proc_executor, job_thread_executor +from . import job_proc_executor, job_thread_executor, inference_executor from .job_executor import JobExecutor EventTypes = Literal[ @@ -31,6 +31,7 @@ def __init__( num_idle_processes: int, initialize_timeout: float, close_timeout: float, + inference_executor: inference_executor.InferenceExecutor | None, job_executor_type: JobExecutorType, mp_ctx: BaseContext, loop: asyncio.AbstractEventLoop, @@ -41,6 +42,7 @@ def __init__( self._initialize_process_fnc = initialize_process_fnc self._job_entrypoint_fnc = job_entrypoint_fnc self._close_timeout = close_timeout + self._inf_executor = inference_executor self._initialize_timeout = initialize_timeout self._loop = loop @@ -108,6 +110,7 @@ async def _proc_watch_task(self) -> None: job_entrypoint_fnc=self._job_entrypoint_fnc, initialize_timeout=self._initialize_timeout, close_timeout=self._close_timeout, + inference_executor=self._inf_executor, mp_ctx=self._mp_ctx, loop=self._loop, ) diff --git a/livekit-agents/livekit/agents/job.py b/livekit-agents/livekit/agents/job.py index 328cf5ffa..65e90eba7 100644 --- a/livekit-agents/livekit/agents/job.py +++ b/livekit-agents/livekit/agents/job.py @@ -85,6 +85,7 @@ def __init__( room: rtc.Room, on_connect: Callable[[], None], on_shutdown: Callable[[str], None], + inference_executor: InferenceExecutor, ) -> None: self._proc = proc self._info = info @@ -102,6 +103,11 @@ def __init__( ] = [] self._participant_tasks = dict[Tuple[str, Callable], asyncio.Task[None]]() self._room.on("participant_connected", self._participant_available) + self._inf_executor = inference_executor + + @property + def inference_executor(self) -> InferenceExecutor: + return self._inf_executor @property def proc(self) -> JobProcess: @@ -268,24 +274,10 @@ def __init__( self, *, start_arguments: Any | None = None, - inference_executor: InferenceExecutor | None = None, ) -> None: self._mp_proc = mp.current_process() self._userdata: dict[str, Any] = {} self._start_arguments = start_arguments - self._inf_executor = inference_executor - - @property - def inference_executor(self) -> InferenceExecutor: - if self._inf_executor is None: - raise ValueError( - ( - "no inference executor is provided for the current JobProcess, did you " - "forgot to register/import plugins necessary for inference?" - ) - ) - - return self._inf_executor @property def pid(self) -> int | None: diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index 8df4a86bf..3a1e27f71 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -249,6 +249,18 @@ def __init__( # using spawn context for all platforms. We may have further optimizations for # Linux with forkserver, but for now, this is the safest option mp_ctx = mp.get_context("spawn") + + self._inference_executor: ( + ipc.inference_proc_executor.InferenceProcExecutor | None + ) = None + if len(_InferenceRunner.registered_runners) > 0: + self._inference_executor = ( + ipc.inference_proc_executor.InferenceProcExecutor( + mp_ctx=mp_ctx, + loop=self._loop, + ) + ) + self._proc_pool = ipc.proc_pool.ProcPool( initialize_process_fnc=opts.prewarm_fnc, job_entrypoint_fnc=opts.entrypoint_fnc, @@ -257,6 +269,7 @@ def __init__( ), loop=self._loop, job_executor_type=opts.job_executor_type, + inference_executor=self._inference_executor, mp_ctx=mp_ctx, initialize_timeout=opts.initialize_process_timeout, close_timeout=opts.shutdown_process_timeout, @@ -277,17 +290,6 @@ def __init__( self._main_task: asyncio.Task[None] | None = None - self._inference_executor: ( - ipc.inference_proc_executor.InferenceProcExecutor | None - ) = None - if len(_InferenceRunner.registered_runners) > 0: - self._inference_executor = ( - ipc.inference_proc_executor.InferenceProcExecutor( - mp_ctx=mp_ctx, - loop=self._loop, - ) - ) - async def run(self): if not self._closed: raise Exception("worker is already running") diff --git a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py index 03887e72c..d48c2ebcc 100644 --- a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py +++ b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py @@ -53,7 +53,10 @@ def _format_chat_ctx(self, chat_ctx: dict): def initialize(self) -> None: from transformers import AutoModelForCausalLM, AutoTokenizer - self._model = AutoModelForCausalLM.from_pretrained(HG_MODEL, local_files_only=True) + + self._model = AutoModelForCausalLM.from_pretrained( + HG_MODEL, local_files_only=True + ) self._tokenizer = AutoTokenizer.from_pretrained(HG_MODEL, local_files_only=True) self._eou_index = self._tokenizer.encode("<|im_end|>")[-1] From 0d188a8985e22fb4dfc2f1abf5b1f644e474db09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 26 Nov 2024 15:28:33 +0100 Subject: [PATCH 14/24] working version --- .../voice-pipeline-agent/minimal_assistant.py | 4 +- .../livekit/agents/ipc/proc_client.py | 1 + .../livekit/agents/pipeline/pipeline_agent.py | 43 ++++++++++++++++--- .../livekit/plugins/eou/eou.py | 5 +-- 4 files changed, 44 insertions(+), 9 deletions(-) diff --git a/examples/voice-pipeline-agent/minimal_assistant.py b/examples/voice-pipeline-agent/minimal_assistant.py index 9ecad5056..ca179093c 100644 --- a/examples/voice-pipeline-agent/minimal_assistant.py +++ b/examples/voice-pipeline-agent/minimal_assistant.py @@ -13,7 +13,8 @@ metrics, ) from livekit.agents.pipeline import VoicePipelineAgent -from livekit.plugins import deepgram, openai, silero, eou +from livekit.plugins import deepgram, openai, silero +from livekit.plugins import eou load_dotenv() logger = logging.getLogger("voice-assistant") @@ -49,6 +50,7 @@ async def entrypoint(ctx: JobContext): stt=deepgram.STT(model=dg_model), llm=openai.LLM(), tts=openai.TTS(), + eou=eou.EOU(), chat_ctx=initial_ctx, ) diff --git a/livekit-agents/livekit/agents/ipc/proc_client.py b/livekit-agents/livekit/agents/ipc/proc_client.py index 542d8b8bb..86e2856a5 100644 --- a/livekit-agents/livekit/agents/ipc/proc_client.py +++ b/livekit-agents/livekit/agents/ipc/proc_client.py @@ -94,6 +94,7 @@ async def _read_ipc_task(): ) await asend_message(self._acch, pong) + print(msg) ipc_ch.send_nowait(msg) async def _self_health_check(): diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index ccd394652..644e4c251 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -12,6 +12,7 @@ Callable, Literal, Optional, + Protocol, Union, ) @@ -150,6 +151,10 @@ class AgentTranscriptionOptions: representing the hyphenated parts of the word.""" +class _EOUModel(Protocol): + async def predict_eou(self, chat_ctx: ChatContext) -> float: ... + + class VoicePipelineAgent(utils.EventEmitter[EventTypes]): """ A pipeline agent (VAD + STT + LLM + TTS) implementation. @@ -165,6 +170,7 @@ def __init__( stt: stt.STT, llm: LLM, tts: tts.TTS, + eou: _EOUModel | None = None, chat_ctx: ChatContext | None = None, fnc_ctx: FunctionContext | None = None, allow_interruptions: bool = True, @@ -255,6 +261,7 @@ def __init__( ) self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts + self._eou = eou self._chat_ctx = chat_ctx or ChatContext() self._fnc_ctx = fnc_ctx self._started, self._closed = False, False @@ -274,7 +281,8 @@ def __init__( self._deferred_validation = _DeferredReplyValidation( self._validate_reply_if_possible, self._opts.min_endpointing_delay, - loop=self._loop, + eou=self._eou, + chat_ctx=self._chat_ctx, ) self._speech_q: list[SpeechHandle] = [] @@ -1057,18 +1065,28 @@ class _DeferredReplyValidation: LATE_TRANSCRIPT_TOLERANCE = 1.5 # late compared to end of speech + # When endpoint probability is below this threshold we think the user is not finished speaking + # so we will use a long delay + UNLIKELY_ENDPOINT_THRESHOLD = 0.15 + + # Long delay to use when the model thinks the user is still speaking + UNLIKELY_ENDPOINT_DELAY = 5.0 + def __init__( self, validate_fnc: Callable[[], None], min_endpointing_delay: float, - loop: asyncio.AbstractEventLoop | None = None, + eou: _EOUModel | None, + chat_ctx: ChatContext, ) -> None: + self._eou = eou self._validate_fnc = validate_fnc self._validating_task: asyncio.Task | None = None self._last_final_transcript: str = "" self._last_recv_end_of_speech_time: float = 0.0 self._speaking = False + self._chat_ctx = chat_ctx self._end_of_speech_delay = min_endpointing_delay self._final_transcript_delay = min_endpointing_delay + 1.0 @@ -1094,7 +1112,6 @@ def on_human_final_transcript(self, transcript: str) -> None: delay = delay * ( self.PUNCTUATION_REDUCE_FACTOR if self._end_with_punctuation() else 1.0 ) - self._run(delay) def on_human_start_of_speech(self, ev: vad.VADEvent) -> None: @@ -1128,13 +1145,29 @@ def _reset_states(self) -> None: self._last_recv_end_of_speech_time = 0.0 def _run(self, delay: float) -> None: + detect_ctx = self._chat_ctx.copy() + detect_ctx.messages.append( + ChatMessage.create(text=self._last_final_transcript, role="user") + ) + @utils.log_exceptions(logger=logger) - async def _run_task(delay: float) -> None: + async def _run_task(chat_ctx: ChatContext, delay: float) -> None: await asyncio.sleep(delay) + if self._eou is not None: + eou_prob = await self._eou.predict_eou(chat_ctx) + logger.debug( + "eou prediction", + extra={"eou_probability": eou_prob}, + ) + + if eou_prob < self.UNLIKELY_ENDPOINT_THRESHOLD: + # TODO(theomonnom): This is additive with the last delay, need to refactor + await asyncio.sleep(self.UNLIKELY_ENDPOINT_DELAY) + self._reset_states() self._validate_fnc() if self._validating_task is not None: self._validating_task.cancel() - self._validating_task = asyncio.create_task(_run_task(delay)) + self._validating_task = asyncio.create_task(_run_task(detect_ctx, delay)) diff --git a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py index d48c2ebcc..3e7c03db2 100644 --- a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py +++ b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py @@ -78,14 +78,13 @@ def run(self, data: bytes) -> bytes | None: logits = outputs.logits[0, -1, :].detach().numpy() output_probs = _softmax(logits) eou_probability = output_probs[self._eou_index] - - return json.dumps({"eou_probability": eou_probability}).encode() + return json.dumps({"eou_probability": float(eou_probability)}).encode() class EOU: def __init__(self, inference_executor: InferenceExecutor | None = None) -> None: self._executor = ( - inference_executor or get_current_job_context().proc.inference_executor + inference_executor or get_current_job_context().inference_executor ) async def predict_eou(self, chat_ctx: llm.ChatContext) -> float: From ac2ff515b16bcc0fd6c6638b9195e03ab834274d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 26 Nov 2024 15:31:41 +0100 Subject: [PATCH 15/24] Update pipeline_agent.py --- livekit-agents/livekit/agents/pipeline/pipeline_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 644e4c251..432f009c7 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -1095,7 +1095,7 @@ def validating(self) -> bool: return self._validating_task is not None and not self._validating_task.done() def on_human_final_transcript(self, transcript: str) -> None: - self._last_final_transcript = transcript.strip() # type: ignore + self._last_final_transcript += (" " + transcript.strip()) # type: ignore if self._speaking: return From ae247258926cedb1742ff78428b973f995c19010 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 26 Nov 2024 15:34:15 +0100 Subject: [PATCH 16/24] Update inference_executor.py --- livekit-agents/livekit/agents/ipc/inference_executor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/livekit-agents/livekit/agents/ipc/inference_executor.py b/livekit-agents/livekit/agents/ipc/inference_executor.py index 6d18ee882..5493a4e90 100644 --- a/livekit-agents/livekit/agents/ipc/inference_executor.py +++ b/livekit-agents/livekit/agents/ipc/inference_executor.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib from typing import Protocol from ..utils import aio, shortuuid from . import proto @@ -36,4 +37,6 @@ def _on_inference_response(self, resp: proto.InferenceResponse) -> None: ) return - fut.set_result(resp) + print("got response", resp) + with contextlib.suppress(asyncio.CancelledError): + fut.set_result(resp) From 2bea42b2257fe45583ebb6bdd458e5f0e71d3c26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 26 Nov 2024 15:39:08 +0100 Subject: [PATCH 17/24] wip --- livekit-agents/livekit/agents/ipc/inference_executor.py | 2 +- livekit-agents/livekit/agents/ipc/proc_client.py | 3 +-- livekit-agents/livekit/agents/pipeline/pipeline_agent.py | 2 +- livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py | 1 + 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/livekit-agents/livekit/agents/ipc/inference_executor.py b/livekit-agents/livekit/agents/ipc/inference_executor.py index 5493a4e90..707378b1d 100644 --- a/livekit-agents/livekit/agents/ipc/inference_executor.py +++ b/livekit-agents/livekit/agents/ipc/inference_executor.py @@ -38,5 +38,5 @@ def _on_inference_response(self, resp: proto.InferenceResponse) -> None: return print("got response", resp) - with contextlib.suppress(asyncio.CancelledError): + with contextlib.suppress(asyncio.InvalidStateError): fut.set_result(resp) diff --git a/livekit-agents/livekit/agents/ipc/proc_client.py b/livekit-agents/livekit/agents/ipc/proc_client.py index 86e2856a5..89dc4e3c4 100644 --- a/livekit-agents/livekit/agents/ipc/proc_client.py +++ b/livekit-agents/livekit/agents/ipc/proc_client.py @@ -94,13 +94,12 @@ async def _read_ipc_task(): ) await asend_message(self._acch, pong) - print(msg) ipc_ch.send_nowait(msg) async def _self_health_check(): await ping_timeout print( - "worker process is not responding.. worker crashed?", + "worker process is not responding.. worker crashed??", file=sys.stderr, ) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 432f009c7..39f42353d 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -1070,7 +1070,7 @@ class _DeferredReplyValidation: UNLIKELY_ENDPOINT_THRESHOLD = 0.15 # Long delay to use when the model thinks the user is still speaking - UNLIKELY_ENDPOINT_DELAY = 5.0 + UNLIKELY_ENDPOINT_DELAY = 3.0 def __init__( self, diff --git a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py index 3e7c03db2..25a7dea63 100644 --- a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py +++ b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py @@ -68,6 +68,7 @@ def run(self, data: bytes) -> bytes | None: raise ValueError("chat_ctx is required on the inference input data") text = self._format_chat_ctx(chat_ctx) + print(text) inputs = self._tokenizer( text, add_special_tokens=False, From 44cab40ae7436e0d3217b52499bd9ab00828a4e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 27 Nov 2024 14:19:57 +0100 Subject: [PATCH 18/24] share _ProcClient & remove duplicated code --- .../voice-pipeline-agent/minimal_assistant.py | 3 +- .../livekit/agents/inference_runner.py | 2 +- .../livekit/agents/ipc/inference_executor.py | 35 -- .../agents/ipc/inference_proc_executor.py | 5 +- .../agents/ipc/inference_proc_lazy_main.py | 12 +- livekit-agents/livekit/agents/ipc/job_main.py | 268 -------------- .../livekit/agents/ipc/job_proc_executor.py | 4 +- .../livekit/agents/ipc/job_proc_lazy_main.py | 332 +++++++++++++++--- .../livekit/agents/ipc/proc_client.py | 10 +- .../livekit/agents/ipc/proc_pool.py | 2 +- livekit-agents/livekit/agents/job.py | 3 +- .../livekit/agents/pipeline/pipeline_agent.py | 2 +- .../livekit/plugins/eou/eou.py | 2 +- 13 files changed, 311 insertions(+), 369 deletions(-) delete mode 100644 livekit-agents/livekit/agents/ipc/job_main.py diff --git a/examples/voice-pipeline-agent/minimal_assistant.py b/examples/voice-pipeline-agent/minimal_assistant.py index ca179093c..557a08329 100644 --- a/examples/voice-pipeline-agent/minimal_assistant.py +++ b/examples/voice-pipeline-agent/minimal_assistant.py @@ -13,8 +13,7 @@ metrics, ) from livekit.agents.pipeline import VoicePipelineAgent -from livekit.plugins import deepgram, openai, silero -from livekit.plugins import eou +from livekit.plugins import deepgram, eou, openai, silero load_dotenv() logger = logging.getLogger("voice-assistant") diff --git a/livekit-agents/livekit/agents/inference_runner.py b/livekit-agents/livekit/agents/inference_runner.py index b34cd993f..646a03bdd 100644 --- a/livekit-agents/livekit/agents/inference_runner.py +++ b/livekit-agents/livekit/agents/inference_runner.py @@ -1,7 +1,7 @@ from __future__ import annotations import threading -from abc import ABC, ABCMeta, abstractmethod +from abc import ABC, abstractmethod from typing import ClassVar, Protocol, Type diff --git a/livekit-agents/livekit/agents/ipc/inference_executor.py b/livekit-agents/livekit/agents/ipc/inference_executor.py index 707378b1d..c83aee64d 100644 --- a/livekit-agents/livekit/agents/ipc/inference_executor.py +++ b/livekit-agents/livekit/agents/ipc/inference_executor.py @@ -1,42 +1,7 @@ from __future__ import annotations -import contextlib from typing import Protocol -from ..utils import aio, shortuuid -from . import proto -from ..log import logger -from . import channel -import asyncio class InferenceExecutor(Protocol): async def do_inference(self, method: str, data: bytes) -> bytes | None: ... - - -class _InferenceRunnerClient(InferenceExecutor): - def __init__(self, *, cch: aio.duplex_unix._AsyncDuplex) -> None: - self._cch = cch - self._active_requests: dict[str, asyncio.Future[proto.InferenceResponse]] = {} - - async def do_inference(self, method: str, data: bytes) -> bytes | None: - request_id = shortuuid("INF_") - await channel.asend_message( - self._cch, - proto.InferenceRequest(request_id=request_id, method=method, data=data), - ) - - fut = asyncio.Future[proto.InferenceResponse]() - self._active_requests[request_id] = fut - return (await fut).data - - def _on_inference_response(self, resp: proto.InferenceResponse) -> None: - fut = self._active_requests.pop(resp.request_id, None) - if fut is None: - logger.warning( - "received unexpected inference response", extra={"resp": resp} - ) - return - - print("got response", resp) - with contextlib.suppress(asyncio.InvalidStateError): - fut.set_result(resp) diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py index 513de5fa7..b9cfe1a40 100644 --- a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py @@ -11,13 +11,12 @@ from typing import Any from .. import utils +from ..inference_runner import _InferenceRunner from ..log import logger from ..utils.aio import duplex_unix from . import channel, inference_proc_lazy_main, proto -from .log_queue import LogQueueListener - -from ..inference_runner import _InferenceRunner from .inference_executor import InferenceExecutor, _InferenceRunnerClient +from .log_queue import LogQueueListener @dataclass diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py index 344196823..61051708a 100644 --- a/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py @@ -17,6 +17,7 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): import asyncio +import logging import socket from dataclasses import dataclass @@ -25,6 +26,7 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): from ..utils import aio, log_exceptions from . import proto from .channel import Message +from .log_queue import LogQueueHandler from .proc_client import _ProcClient @@ -37,12 +39,6 @@ class ProcStartArgs: def proc_main(args: ProcStartArgs) -> None: - import logging - - from ..log import logger - from ..utils import aio - from .log_queue import LogQueueHandler - root_logger = logging.getLogger() root_logger.setLevel(logging.NOTSET) @@ -63,9 +59,9 @@ def proc_main(args: ProcStartArgs) -> None: ) pid = current_process().pid - logger.info("initializing process", extra={"pid": pid}) + logger.info("initializing inference process", extra={"pid": pid}) client.initialize() - logger.info("process initialized", extra={"pid": pid}) + logger.info("inference process initialized", extra={"pid": pid}) client.run() finally: diff --git a/livekit-agents/livekit/agents/ipc/job_main.py b/livekit-agents/livekit/agents/ipc/job_main.py deleted file mode 100644 index 42b699457..000000000 --- a/livekit-agents/livekit/agents/ipc/job_main.py +++ /dev/null @@ -1,268 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import socket -import threading -from dataclasses import dataclass -from typing import Any, Callable - -from livekit import rtc - -from .. import utils -from ..job import JobContext, JobProcess, _JobContextVar -from ..log import logger -from ..utils.aio import duplex_unix -from . import channel, proto - -from .inference_executor import InferenceExecutor, _InferenceRunnerClient - - -@dataclass -class _ShutdownInfo: - user_initiated: bool - reason: str - - -@dataclass -class JobTask: - job_ctx: JobContext - task: asyncio.Task - shutdown_fut: asyncio.Future[_ShutdownInfo] - - -def _start_job( - proc: JobProcess, - job_entrypoint_fnc: Callable[[JobContext], Any], - start_req: proto.StartJobRequest, - exit_proc_fut: asyncio.Event, - cch: utils.aio.duplex_unix._AsyncDuplex, - inference_client: _InferenceRunnerClient, -) -> JobTask: - # used to warn users if none of connect/shutdown is called inside the job_entry - ctx_connect, ctx_shutdown = False, False - room = rtc.Room() - request_shutdown_fut = asyncio.Future[_ShutdownInfo]() - - @room.on("disconnected") - def _on_room_disconnected(*args): - with contextlib.suppress(asyncio.InvalidStateError): - request_shutdown_fut.set_result( - _ShutdownInfo(user_initiated=False, reason="room disconnected") - ) - - def _on_ctx_connect() -> None: - nonlocal ctx_connect - ctx_connect = True - - def _on_ctx_shutdown(reason: str) -> None: - nonlocal ctx_shutdown - ctx_shutdown = True - - with contextlib.suppress(asyncio.InvalidStateError): - request_shutdown_fut.set_result( - _ShutdownInfo(user_initiated=True, reason=reason) - ) - - info = start_req.running_job - room._info.name = info.job.room.name - job_ctx = JobContext( - proc=proc, - info=info, - room=room, - on_connect=_on_ctx_connect, - on_shutdown=_on_ctx_shutdown, - inference_executor=inference_client, - ) - - @utils.log_exceptions(logger=logger) - async def _run_job_task() -> None: - utils.http_context._new_session_ctx() - job_ctx_token = _JobContextVar.set(job_ctx) - - job_entry_task = asyncio.create_task( - job_entrypoint_fnc(job_ctx), name="job_entrypoint" - ) - - async def _warn_not_connected_task(): - await asyncio.sleep(10) - if not ctx_connect and not ctx_shutdown: - logger.warning( - ( - "room not connected after job_entry was called after 10 seconds, " - "did you forget to call job_ctx.connect()?" - ) - ) - - warn_unconnected_task = asyncio.create_task(_warn_not_connected_task()) - job_entry_task.add_done_callback(lambda _: warn_unconnected_task.cancel()) - - def log_exception(t: asyncio.Task) -> None: - if not t.cancelled() and t.exception(): - logger.error( - "unhandled exception while running the job task", - exc_info=t.exception(), - ) - elif not ctx_connect and not ctx_shutdown: - logger.warning("job task completed without connecting or shutting down") - - job_entry_task.add_done_callback(log_exception) - - shutdown_info = await request_shutdown_fut - logger.debug( - "shutting down job task", - extra={ - "reason": shutdown_info.reason, - "user_initiated": shutdown_info.user_initiated, - }, - ) - await channel.asend_message(cch, proto.Exiting(reason=shutdown_info.reason)) - await room.disconnect() - - try: - shutdown_tasks = [] - for callback in job_ctx._shutdown_callbacks: - shutdown_tasks.append( - asyncio.create_task(callback(), name="job_shutdown_callback") - ) - - await asyncio.gather(*shutdown_tasks) - except Exception: - logger.exception("error while shutting down the job") - - await utils.http_context._close_http_ctx() - _JobContextVar.reset(job_ctx_token) - exit_proc_fut.set() - - task = asyncio.create_task(_run_job_task()) - job_task = JobTask(job_ctx=job_ctx, task=task, shutdown_fut=request_shutdown_fut) - return job_task - - -async def _async_main( - proc: JobProcess, - job_entrypoint_fnc: Callable[[JobContext], Any], - mp_cch: socket.socket, -) -> None: - cch = await duplex_unix._AsyncDuplex.open(mp_cch) - - job_task: JobTask | None = None - exit_proc_fut = asyncio.Event() - no_msg_timeout = utils.aio.sleep(proto.PING_INTERVAL * 5) # missing 5 pings - - inference_client = _InferenceRunnerClient(cch=cch) - - @utils.log_exceptions(logger=logger) - async def _read_ipc_task(): - nonlocal job_task - while True: - try: - msg = await channel.arecv_message(cch, proto.IPC_MESSAGES) - except duplex_unix.DuplexClosed: - break - - with contextlib.suppress(utils.aio.SleepFinished): - no_msg_timeout.reset() - - if isinstance(msg, proto.PingRequest): - pong = proto.PongResponse( - last_timestamp=msg.timestamp, timestamp=utils.time_ms() - ) - await channel.asend_message(cch, pong) - - if isinstance(msg, proto.StartJobRequest): - assert job_task is None, "job task already running" - job_task = _start_job( - proc, job_entrypoint_fnc, msg, exit_proc_fut, cch, inference_client - ) - - if isinstance(msg, proto.ShutdownRequest): - if job_task is None: - break # there is no running job, we can exit immediately - - with contextlib.suppress(asyncio.InvalidStateError): - job_task.shutdown_fut.set_result( - _ShutdownInfo(reason=msg.reason, user_initiated=False) - ) - - if isinstance(msg, proto.InferenceResponse): - inference_client._on_inference_response(msg) - - async def _self_health_check(): - await no_msg_timeout - print("worker process is not responding.. worker crashed?") - with contextlib.suppress(asyncio.CancelledError): - exit_proc_fut.set() - - read_task = asyncio.create_task(_read_ipc_task(), name="ipc_read") - health_check_task = asyncio.create_task(_self_health_check(), name="health_check") - - def _done_cb(task: asyncio.Task) -> None: - with contextlib.suppress(asyncio.InvalidStateError): - exit_proc_fut.set() - - read_task.add_done_callback(_done_cb) - - await exit_proc_fut.wait() - await utils.aio.gracefully_cancel(read_task, health_check_task) - - with contextlib.suppress(duplex_unix.DuplexClosed): - await cch.aclose() - - -@dataclass -class ProcStartArgs: - initialize_process_fnc: Callable[[JobProcess], Any] - job_entrypoint_fnc: Callable[[JobContext], Any] - log_cch: socket.socket - mp_cch: socket.socket - asyncio_debug: bool - user_arguments: Any | None = None - inference_runners: dict[str, Callable[[], InferenceExecutor]] | None = None - - -@dataclass -class ThreadStartArgs: - mp_cch: socket.socket - initialize_process_fnc: Callable[[JobProcess], Any] - job_entrypoint_fnc: Callable[[JobContext], Any] - user_arguments: Any | None - asyncio_debug: bool - join_fnc: Callable[[], None] - - -def thread_main( - args: ThreadStartArgs, -) -> None: - """main function for the job process when using the ThreadedJobRunner""" - tid = threading.get_native_id() - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.set_debug(args.asyncio_debug) - loop.slow_callback_duration = 0.1 # 100ms - - cch = duplex_unix._Duplex.open(args.mp_cch) - try: - init_req = channel.recv_message(cch, proto.IPC_MESSAGES) - assert isinstance( - init_req, proto.InitializeRequest - ), "first message must be InitializeRequest" - job_proc = JobProcess(start_arguments=args.user_arguments) - - logger.debug("initializing job runner", extra={"tid": tid}) - args.initialize_process_fnc(job_proc) - logger.debug("job runner initialized", extra={"tid": tid}) - channel.send_message(cch, proto.InitializeResponse()) - - main_task = loop.create_task( - _async_main(job_proc, args.job_entrypoint_fnc, cch.detach()), - name="job_proc_main", - ) - loop.run_until_complete(main_task) - except duplex_unix.DuplexClosed: - pass - except Exception: - logger.exception("error while running job process", extra={"tid": tid}) - finally: - args.join_fnc() - loop.run_until_complete(loop.shutdown_default_executor()) diff --git a/livekit-agents/livekit/agents/ipc/job_proc_executor.py b/livekit-agents/livekit/agents/ipc/job_proc_executor.py index 25afd8e1e..9e02c4dd7 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_executor.py @@ -12,10 +12,10 @@ from .. import utils from ..job import JobContext, JobProcess, RunningJobInfo -from .inference_executor import InferenceExecutor from ..log import logger from ..utils.aio import duplex_unix from . import channel, job_main, job_proc_lazy_main, proto +from .inference_executor import InferenceExecutor from .job_executor import ( JobExecutorError_Runtime, JobExecutorError_ShutdownTimeout, @@ -355,7 +355,7 @@ async def _do_inference_task(self, inf_req: proto.InferenceRequest) -> None: self._pch, proto.InferenceResponse(request_id=inf_req.request_id, data=inf_res), ) - except Exception as e: + except Exception: logger.exception( "error handling inference request", extra=self.logging_extra() ) diff --git a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py index 6408f6c98..70e723e14 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from multiprocessing import current_process if current_process().name == "job_proc": @@ -16,20 +18,44 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): sys.excepthook = _no_traceback_excepthook -def proc_main(args) -> None: - """main function for the job process when using the ProcessJobRunner""" +import asyncio +import contextlib +import logging +import socket +import threading +from dataclasses import dataclass +from typing import Any, Callable + +from livekit import rtc + +from ..job import JobContext, JobProcess, _JobContextVar +from ..log import logger +from ..utils import aio, http_context, log_exceptions, shortuuid +from .channel import Message +from .inference_executor import InferenceExecutor +from .log_queue import LogQueueHandler +from .proc_client import _ProcClient +from .proto import ( + Exiting, + InferenceRequest, + InferenceResponse, + InitializeRequest, + ShutdownRequest, + StartJobRequest, +) - # import every package lazily - import asyncio - import logging - from ..job import JobProcess - from ..log import logger - from ..utils import aio - from .channel import recv_message, send_message - from .log_queue import LogQueueHandler - from .proto import IPC_MESSAGES, InitializeRequest, InitializeResponse +@dataclass +class ProcStartArgs: + initialize_process_fnc: Callable[[JobProcess], Any] + job_entrypoint_fnc: Callable[[JobContext], Any] + mp_cch: socket.socket + log_cch: socket.socket + asyncio_debug: bool + user_arguments: Any | None = None + +def proc_main(args: ProcStartArgs) -> None: root_logger = logging.getLogger() root_logger.setLevel(logging.NOTSET) @@ -37,40 +63,264 @@ def proc_main(args) -> None: log_handler = LogQueueHandler(log_cch) root_logger.addHandler(log_handler) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.set_debug(args.asyncio_debug) - loop.slow_callback_duration = 0.1 # 100ms - aio.debug.hook_slow_callbacks(2.0) - - cch = aio.duplex_unix._Duplex.open(args.mp_cch) try: - init_req = recv_message(cch, IPC_MESSAGES) + from .proc_client import _ProcClient - assert isinstance( - init_req, InitializeRequest - ), "first message must be InitializeRequest" + job_proc = _JobProc( + args.initialize_process_fnc, args.job_entrypoint_fnc, args.user_arguments + ) - job_proc = JobProcess(start_arguments=args.user_arguments) - logger.info("initializing process", extra={"pid": job_proc.pid}) - args.initialize_process_fnc(job_proc) - logger.info("process initialized", extra={"pid": job_proc.pid}) - send_message(cch, InitializeResponse()) + client = _ProcClient( + args.mp_cch, + job_proc.initialize, + job_proc.entrypoint, + args.asyncio_debug, + ) - from .job_main import _async_main + pid = current_process().pid + logger.info("initializing job process", extra={"pid": pid}) + client.initialize() + logger.info("job process initialized", extra={"pid": pid}) - main_task = loop.create_task( - _async_main(job_proc, args.job_entrypoint_fnc, cch.detach()), - name="inference_proc_main", - ) - while not main_task.done(): - try: - loop.run_until_complete(main_task) - except KeyboardInterrupt: - # ignore the keyboard interrupt, we handle the process shutdown ourselves on the worker process - pass - except (aio.duplex_unix.DuplexClosed, KeyboardInterrupt): - pass + client.run() finally: log_handler.close() - loop.run_until_complete(loop.shutdown_default_executor()) + + +class _InfClient(InferenceExecutor): + def __init__(self, proc_client: _ProcClient) -> None: + self._client = proc_client + self._active_requests: dict[str, asyncio.Future[InferenceResponse]] = {} + + async def do_inference(self, method: str, data: bytes) -> bytes | None: + request_id = shortuuid("INF_") + fut = asyncio.Future[InferenceResponse]() + + await self._client.send( + InferenceRequest(request_id=request_id, method=method, data=data), + ) + + self._active_requests[request_id] = fut + + inf_resp = await fut + if inf_resp.error: + raise RuntimeError(f"inference of {method} failed: {inf_resp.error}") + + return inf_resp.data + + def _on_inference_response(self, resp: InferenceResponse) -> None: + fut = self._active_requests.pop(resp.request_id, None) + if fut is None: + logger.warning( + "received unexpected inference response", extra={"resp": resp} + ) + return + + with contextlib.suppress(asyncio.InvalidStateError): + fut.set_result(resp) + + +@dataclass +class _ShutdownInfo: + user_initiated: bool + reason: str + + +class _JobProc: + def __init__( + self, + initialize_process_fnc: Callable[[JobProcess], Any], + job_entrypoint_fnc: Callable[[JobContext], Any], + user_arguments: Any | None = None, + ) -> None: + self._initialize_process_fnc = initialize_process_fnc + self._job_entrypoint_fnc = job_entrypoint_fnc + self._job_proc = JobProcess(start_arguments=user_arguments) + self._exit_proc_flag = asyncio.Event() + + self._job_task: asyncio.Task | None = None + self._shutdown_fut: asyncio.Future[_ShutdownInfo] = asyncio.Future() + + # used to warn users if both connect and shutdown are not called inside the job_entry + self._ctx_connect_called = False + self._ctx_shutdown_called = False + + @property + def has_running_job(self) -> bool: + return self._job_task is not None + + def initialize(self, init_req: InitializeRequest, client: _ProcClient) -> None: + self._client = client + self._inf_client = _InfClient(client) + self._initialize_process_fnc(self._job_proc) + + @log_exceptions(logger=logger) + async def entrypoint(self, cch: aio.ChanReceiver[Message]) -> None: + @log_exceptions(logger=logger) + async def _read_ipc_task(): + async for msg in cch: + if isinstance(msg, StartJobRequest): + if self.has_running_job: + logger.warning( + "trying to start a new job while one is already running" + ) + continue + + self._start_job(msg) + if isinstance(msg, ShutdownRequest): + if not self.has_running_job: + break # exit immediately + + with contextlib.suppress(asyncio.InvalidStateError): + self._shutdown_fut.set_result( + _ShutdownInfo(reason=msg.reason, user_initiated=False) + ) + + if isinstance(msg, InferenceResponse): + self._inf_client._on_inference_response(msg) + + read_task = asyncio.create_task(_read_ipc_task(), name="job_ipc_read") + + await self._exit_proc_flag.wait() + await aio.gracefully_cancel(read_task) + + def _start_job(self, msg: StartJobRequest) -> None: + self._room = rtc.Room() + + @self._room.on("disconnected") + def _on_room_disconnected(*args): + with contextlib.suppress(asyncio.InvalidStateError): + self._shutdown_fut.set_result( + _ShutdownInfo(user_initiated=False, reason="room disconnected") + ) + + def _on_ctx_connect() -> None: + self._ctx_connect_called = True + + def _on_ctx_shutdown(reason: str) -> None: + self._ctx_shutdown_called = True + + with contextlib.suppress(asyncio.InvalidStateError): + self._shutdown_fut.set_result( + _ShutdownInfo(user_initiated=True, reason=reason) + ) + + self._room._info.name = msg.running_job.job.room.name + + self._job_ctx = JobContext( + proc=self._job_proc, + info=msg.running_job, + room=self._room, + on_connect=_on_ctx_connect, + on_shutdown=_on_ctx_shutdown, + inference_executor=self._inf_client, + ) + + self._job_task = asyncio.create_task(self._run_job_task(), name="job_task") + + def _exit_proc_cb(_: asyncio.Task) -> None: + self._exit_proc_flag.set() + + self._job_task.add_done_callback(_exit_proc_cb) + + async def _run_job_task(self) -> None: + http_context._new_session_ctx() + job_ctx_token = _JobContextVar.set(self._job_ctx) + + job_entry_task = asyncio.create_task( + self._job_entrypoint_fnc(self._job_ctx), name="job_user_entrypoint" + ) + + async def _warn_not_connected_task(): + await asyncio.sleep(10) + if not self._ctx_connect_called and not self._ctx_shutdown_called: + logger.warning( + ( + "The room connection was not established within 10 seconds after calling job_entry. " + "This may indicate that job_ctx.connect() was not called. " + ) + ) + + warn_unconnected_task = asyncio.create_task(_warn_not_connected_task()) + job_entry_task.add_done_callback(lambda _: warn_unconnected_task.cancel()) + + def log_exception(t: asyncio.Task) -> None: + if not t.cancelled() and t.exception(): + logger.error( + "unhandled exception while running the job task", + exc_info=t.exception(), + ) + elif not self._ctx_connect_called and not self._ctx_shutdown_called: + logger.warning( + ( + "The job task completed without establishing a connection or performing a proper shutdown. " + "Ensure that job_ctx.connect()/job_ctx.shutdown() is called and the job is correctly finalized." + ) + ) + + job_entry_task.add_done_callback(log_exception) + + shutdown_info = await self._shutdown_fut + logger.debug( + "shutting down job task", + extra={ + "reason": shutdown_info.reason, + "user_initiated": shutdown_info.user_initiated, + }, + ) + + await self._client.send(Exiting(reason=shutdown_info.reason)) + await self._room.disconnect() + + try: + shutdown_tasks = [] + for callback in self._job_ctx._shutdown_callbacks: + shutdown_tasks.append( + asyncio.create_task(callback(), name="job_shutdown_callback") + ) + + await asyncio.gather(*shutdown_tasks) + except Exception: + logger.exception("error while shutting down the job") + + await http_context._close_http_ctx() + _JobContextVar.reset(job_ctx_token) + + +@dataclass +class ThreadStartArgs: + initialize_process_fnc: Callable[[JobProcess], Any] + job_entrypoint_fnc: Callable[[JobContext], Any] + join_fnc: Callable[[], None] + mp_cch: socket.socket + user_arguments: Any | None + asyncio_debug: bool + + +def thread_main( + args: ThreadStartArgs, +) -> None: + """main function for the job process when using the ThreadedJobRunner""" + tid = threading.get_native_id() + + try: + from .proc_client import _ProcClient + + job_proc = _JobProc( + args.initialize_process_fnc, args.job_entrypoint_fnc, args.user_arguments + ) + + client = _ProcClient( + args.mp_cch, + job_proc.initialize, + job_proc.entrypoint, + args.asyncio_debug, + ) + + logger.info("initializing job runner", extra={"tid": tid}) + client.initialize() + logger.info("job runner initialized", extra={"tid": tid}) + + client.run() + finally: + args.join_fnc() diff --git a/livekit-agents/livekit/agents/ipc/proc_client.py b/livekit-agents/livekit/agents/ipc/proc_client.py index 89dc4e3c4..4b70fd5e2 100644 --- a/livekit-agents/livekit/agents/ipc/proc_client.py +++ b/livekit-agents/livekit/agents/ipc/proc_client.py @@ -89,17 +89,19 @@ async def _read_ipc_task(): ping_timeout.reset() if isinstance(msg, proto.PingRequest): - pong = proto.PongResponse( - last_timestamp=msg.timestamp, timestamp=time_ms() + await asend_message( + self._acch, + proto.PongResponse( + last_timestamp=msg.timestamp, timestamp=time_ms() + ), ) - await asend_message(self._acch, pong) ipc_ch.send_nowait(msg) async def _self_health_check(): await ping_timeout print( - "worker process is not responding.. worker crashed??", + "worker process is not responding.. worker crashed?", file=sys.stderr, ) diff --git a/livekit-agents/livekit/agents/ipc/proc_pool.py b/livekit-agents/livekit/agents/ipc/proc_pool.py index cb978fd99..a58285979 100644 --- a/livekit-agents/livekit/agents/ipc/proc_pool.py +++ b/livekit-agents/livekit/agents/ipc/proc_pool.py @@ -8,7 +8,7 @@ from ..job import JobContext, JobExecutorType, JobProcess, RunningJobInfo from ..log import logger from ..utils import aio -from . import job_proc_executor, job_thread_executor, inference_executor +from . import inference_executor, job_proc_executor, job_thread_executor from .job_executor import JobExecutor EventTypes = Literal[ diff --git a/livekit-agents/livekit/agents/job.py b/livekit-agents/livekit/agents/job.py index 65e90eba7..2e2480d8a 100644 --- a/livekit-agents/livekit/agents/job.py +++ b/livekit-agents/livekit/agents/job.py @@ -24,9 +24,8 @@ from livekit import rtc from livekit.protocol import agent, models -from .log import logger - from .ipc.inference_executor import InferenceExecutor +from .log import logger _JobContextVar = contextvars.ContextVar("agents_job_context") diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 39f42353d..26598147e 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -1095,7 +1095,7 @@ def validating(self) -> bool: return self._validating_task is not None and not self._validating_task.done() def on_human_final_transcript(self, transcript: str) -> None: - self._last_final_transcript += (" " + transcript.strip()) # type: ignore + self._last_final_transcript += " " + transcript.strip() # type: ignore if self._speaking: return diff --git a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py index 25a7dea63..a4cf3f181 100644 --- a/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py +++ b/livekit-plugins/livekit-plugins-eou/livekit/plugins/eou/eou.py @@ -5,9 +5,9 @@ import numpy as np from livekit.agents import llm -from livekit.agents.job import get_current_job_context from livekit.agents.inference_runner import _InferenceRunner from livekit.agents.ipc.inference_executor import InferenceExecutor +from livekit.agents.job import get_current_job_context HG_MODEL = "livekit/opt-125m-endpoint-detector" PUNCS = string.punctuation.replace("'", "") From 0f4e630c13115fee5e4a335f33c0412d30133f1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Wed, 27 Nov 2024 14:21:24 +0100 Subject: [PATCH 19/24] Create wild-walls-occur.md --- .changeset/wild-walls-occur.md | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changeset/wild-walls-occur.md diff --git a/.changeset/wild-walls-occur.md b/.changeset/wild-walls-occur.md new file mode 100644 index 000000000..35ea69cf9 --- /dev/null +++ b/.changeset/wild-walls-occur.md @@ -0,0 +1,6 @@ +--- +"livekit-agents": patch +"livekit-plugins-eou": minor +--- + +feat: inference process & end of utterance plugin From e8da65bf05a0091147b440f0142c59234d6ee9db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 27 Nov 2024 14:22:25 +0100 Subject: [PATCH 20/24] Delete inference_runner.py --- livekit-agents/livekit/agents/utils/inference_runner.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 livekit-agents/livekit/agents/utils/inference_runner.py diff --git a/livekit-agents/livekit/agents/utils/inference_runner.py b/livekit-agents/livekit/agents/utils/inference_runner.py deleted file mode 100644 index e69de29bb..000000000 From 31530964da0ebcf4053b37ddc20763b2092a289c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 27 Nov 2024 17:20:43 +0100 Subject: [PATCH 21/24] wip --- .../agents/ipc/inference_proc_lazy_main.py | 39 +- .../livekit/agents/ipc/job_proc_executor.py | 473 +++--------------- .../livekit/agents/ipc/job_proc_lazy_main.py | 45 +- .../livekit/agents/ipc/proc_client.py | 66 ++- livekit-agents/livekit/agents/ipc/proto.py | 24 +- .../livekit/agents/ipc/supervised_proc.py | 396 +++++++++++++++ 6 files changed, 551 insertions(+), 492 deletions(-) create mode 100644 livekit-agents/livekit/agents/ipc/supervised_proc.py diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py index 61051708a..d65afc273 100644 --- a/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py @@ -17,7 +17,6 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): import asyncio -import logging import socket from dataclasses import dataclass @@ -26,7 +25,6 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): from ..utils import aio, log_exceptions from . import proto from .channel import Message -from .log_queue import LogQueueHandler from .proc_client import _ProcClient @@ -34,38 +32,27 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): class ProcStartArgs: log_cch: socket.socket mp_cch: socket.socket - asyncio_debug: bool runners: _RunnersDict def proc_main(args: ProcStartArgs) -> None: - root_logger = logging.getLogger() - root_logger.setLevel(logging.NOTSET) + from .proc_client import _ProcClient - log_cch = aio.duplex_unix._Duplex.open(args.log_cch) - log_handler = LogQueueHandler(log_cch) - root_logger.addHandler(log_handler) + inf_proc = _InferenceProc(args.runners) - try: - from .proc_client import _ProcClient + client = _ProcClient( + args.mp_cch, + args.log_cch, + inf_proc.initialize, + inf_proc.entrypoint, + ) - inf_proc = _InferenceProc(args.runners) + pid = current_process().pid + logger.info("initializing inference process", extra={"pid": pid}) + client.initialize() + logger.info("inference process initialized", extra={"pid": pid}) - client = _ProcClient( - args.mp_cch, - inf_proc.initialize, - inf_proc.entrypoint, - args.asyncio_debug, - ) - - pid = current_process().pid - logger.info("initializing inference process", extra={"pid": pid}) - client.initialize() - logger.info("inference process initialized", extra={"pid": pid}) - - client.run() - finally: - log_handler.close() + client.run() class _InferenceProc: diff --git a/livekit-agents/livekit/agents/ipc/job_proc_executor.py b/livekit-agents/livekit/agents/ipc/job_proc_executor.py index 1df9c9014..c44b364b7 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_executor.py @@ -1,371 +1,99 @@ from __future__ import annotations import asyncio -import contextlib -import logging +import multiprocessing as mp import socket -import sys -import threading -from dataclasses import dataclass from multiprocessing.context import BaseContext from typing import Any, Awaitable, Callable -import psutil - -from .. import utils from ..job import JobContext, JobProcess, RunningJobInfo from ..log import logger -from ..utils.aio import duplex_unix -from . import channel, job_main, job_proc_lazy_main, proto +from ..utils import aio, log_exceptions +from . import channel, proto from .inference_executor import InferenceExecutor -from .job_executor import ( - JobExecutorError_MemoryLimitExceeded, - JobExecutorError_Runtime, - JobExecutorError_ShutdownTimeout, - JobExecutorError_Unresponsive, - RunStatus, -) -from .log_queue import LogQueueListener - +from .job_proc_lazy_main import ProcStartArgs, proc_main +from .supervised_proc import SupervisedProc -@dataclass -class _ProcOpts: - initialize_process_fnc: Callable[[JobProcess], Any] - job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]] - mp_ctx: BaseContext - initialize_timeout: float - close_timeout: float - job_memory_warn_mb: float - job_memory_limit_mb: float - -class ProcJobExecutor: +class ProcJobExecutor(SupervisedProc): def __init__( self, *, initialize_process_fnc: Callable[[JobProcess], Any], job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]], - initialize_timeout: float, inference_executor: InferenceExecutor | None, + initialize_timeout: float, close_timeout: float, + memory_warn_mb: float, + memory_limit_mb: float, + ping_interval: float, + ping_timeout: float, + high_ping_threshold: float, mp_ctx: BaseContext, loop: asyncio.AbstractEventLoop, - job_memory_warn_mb: float = 0, - job_memory_limit_mb: float = 0, ) -> None: - self._loop = loop - self._opts = _ProcOpts( - initialize_process_fnc=initialize_process_fnc, - job_entrypoint_fnc=job_entrypoint_fnc, + super().__init__( initialize_timeout=initialize_timeout, close_timeout=close_timeout, + memory_warn_mb=memory_warn_mb, + memory_limit_mb=memory_limit_mb, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + high_ping_threshold=high_ping_threshold, mp_ctx=mp_ctx, - job_memory_warn_mb=job_memory_warn_mb, - job_memory_limit_mb=job_memory_limit_mb, + loop=loop, ) - self._user_args: Any | None = None - self._inf_excecutor = inference_executor self._running_job: RunningJobInfo | None = None - self._exitcode: int | None = None - self._pid: int | None = None - self._exception: Exception | None = None - - self._main_atask: asyncio.Task[None] | None = None - self._closing = False - self._kill_sent = False - self._initialize_fut = asyncio.Future[None]() + self._initialize_process_fnc = initialize_process_fnc + self._job_entrypoint_fnc = job_entrypoint_fnc + self._inference_executor = inference_executor self._inference_tasks: list[asyncio.Task[None]] = [] - self._lock = asyncio.Lock() - - @property - def exitcode(self) -> int | None: - return self._exitcode - - @property - def killed(self) -> bool: - return self._kill_sent - - @property - def pid(self) -> int | None: - return self._pid - - @property - def started(self) -> bool: - return self._main_atask is not None - - @property - def start_arguments(self) -> Any | None: - return self._user_args - - @start_arguments.setter - def start_arguments(self, value: Any | None) -> None: - self._user_args = value - @property def running_job(self) -> RunningJobInfo | None: return self._running_job - @property - def exception(self) -> Exception | None: - return self._exception - - @property - def run_status(self) -> RunStatus: - if not self._running_job: - if self.started: - return RunStatus.WAITING_FOR_JOB - else: - return RunStatus.STARTING - - if not self._main_atask: - return RunStatus.STARTING - - if self._main_atask.done(): - if self.exception: - return RunStatus.FINISHED_FAILED - else: - return RunStatus.FINISHED_CLEAN - else: - return RunStatus.RUNNING_JOB - - async def start(self) -> None: - """start the job process""" - if self.started: - raise RuntimeError("process already started") - - if self._closing: - raise RuntimeError("process is closed") - - await asyncio.shield(self._start()) - - async def _start(self) -> None: - def _add_proc_ctx_log(record: logging.LogRecord) -> None: - extra = self.logging_extra() - for key, value in extra.items(): - setattr(record, key, value) - - async with self._lock: - mp_pch, mp_cch = socket.socketpair() - mp_log_pch, mp_log_cch = socket.socketpair() - - self._pch = await duplex_unix._AsyncDuplex.open(mp_pch) - - log_pch = duplex_unix._Duplex.open(mp_log_pch) - log_listener = LogQueueListener(log_pch, _add_proc_ctx_log) - log_listener.start() - - self._proc_args = job_main.ProcStartArgs( - initialize_process_fnc=self._opts.initialize_process_fnc, - job_entrypoint_fnc=self._opts.job_entrypoint_fnc, - log_cch=mp_log_cch, - mp_cch=mp_cch, - asyncio_debug=self._loop.get_debug(), - user_arguments=self._user_args, - ) - - self._proc = self._opts.mp_ctx.Process( # type: ignore - target=job_proc_lazy_main.proc_main, - args=(self._proc_args,), - name="job_proc", - ) - - self._proc.start() - mp_log_cch.close() - mp_cch.close() - - self._pid = self._proc.pid - self._join_fut = asyncio.Future[None]() - - def _sync_run(): - self._proc.join() - log_listener.stop() - try: - self._loop.call_soon_threadsafe(self._join_fut.set_result, None) - except RuntimeError: - pass - - thread = threading.Thread(target=_sync_run, name="proc_join_thread") - thread.start() - self._main_atask = asyncio.create_task(self._main_task()) - - async def join(self) -> None: - """wait for the job process to finish""" - if not self.started: - raise RuntimeError("process not started") - - async with self._lock: - if self._main_atask: - await asyncio.shield(self._main_atask) - - async def initialize(self) -> None: - """initialize the job process, this is calling the user provided initialize_process_fnc - raise asyncio.TimeoutError if initialization times out""" - await channel.asend_message(self._pch, proto.InitializeRequest()) - - # wait for the process to become ready - try: - init_res = await asyncio.wait_for( - channel.arecv_message(self._pch, proto.IPC_MESSAGES), - timeout=self._opts.initialize_timeout, - ) - assert isinstance( - init_res, proto.InitializeResponse - ), "first message must be InitializeResponse" - except asyncio.TimeoutError: - self._initialize_fut.set_exception( - asyncio.TimeoutError("process initialization timed out") - ) - logger.error( - "initialization timed out, killing job", extra=self.logging_extra() - ) - self._send_kill_signal() - raise - except Exception as e: # should be channel.ChannelClosed most of the time - self._exception = JobExecutorError_Runtime() - self._initialize_fut.set_exception(e) - raise - else: - self._initialize_fut.set_result(None) - - async def aclose(self) -> None: - """attempt to gracefully close the job process""" - if not self.started: - return - - self._closing = True - with contextlib.suppress(utils.aio.duplex_unix.DuplexClosed): - await channel.asend_message(self._pch, proto.ShutdownRequest()) - - try: - if self._main_atask: - await asyncio.wait_for( - asyncio.shield(self._main_atask), timeout=self._opts.close_timeout - ) - except asyncio.TimeoutError: - logger.error( - "process did not exit in time, killing job", extra=self.logging_extra() - ) - self._exception = JobExecutorError_ShutdownTimeout() - self._send_kill_signal() - - async with self._lock: - if self._main_atask: - await asyncio.shield(self._main_atask) - - async def kill(self) -> None: - """forcefully kill the job process""" - if not self.started: - raise RuntimeError("process not started") - - self._closing = True - self._send_kill_signal() - - async with self._lock: - if self._main_atask: - await asyncio.shield(self._main_atask) - - async def launch_job(self, info: RunningJobInfo) -> None: - """start/assign a job to the process""" - if self._running_job is not None: - raise RuntimeError("process already has a running job") - - self._running_job = info - start_req = proto.StartJobRequest() - start_req.running_job = info - await channel.asend_message(self._pch, start_req) - - def _send_kill_signal(self) -> None: - """forcefully kill the job process""" - try: - if not self._proc.is_alive(): - return - except ValueError: - return - - logger.info("killing job process", extra=self.logging_extra()) - if sys.platform == "win32": - self._proc.terminate() - else: - self._proc.kill() + def _create_process(self, cch: socket.socket, log_cch: socket.socket) -> mp.Process: + proc_args = ProcStartArgs( + initialize_process_fnc=self._initialize_process_fnc, + job_entrypoint_fnc=self._job_entrypoint_fnc, + log_cch=log_cch, + mp_cch=cch, + user_arguments=self._user_args, + ) - self._kill_sent = True + return self._opts.mp_ctx.Process( # type: ignore + target=proc_main, + args=(proc_args,), + name="job_proc", + ) - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: + @log_exceptions(logger=logger) + async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None: try: - await self._initialize_fut - except asyncio.TimeoutError: - pass # this happens when the initialization takes longer than self._initialize_timeout - except Exception: - pass # initialization failed - - # the process is killed if it doesn't respond to ping requests - pong_timeout = utils.aio.sleep(proto.PING_TIMEOUT) - ping_task = asyncio.create_task(self._ping_pong_task(pong_timeout)) - monitor_task = asyncio.create_task(self._monitor_task(pong_timeout)) - - if self._opts.job_memory_limit_mb > 0 or self._opts.job_memory_warn_mb > 0: - memory_monitor_task = asyncio.create_task(self._memory_monitor_task()) - else: - memory_monitor_task = None - - await self._join_fut - self._exitcode = self._proc.exitcode - self._proc.close() - await utils.aio.gracefully_cancel(ping_task, monitor_task) - - if memory_monitor_task: - await utils.aio.gracefully_cancel(memory_monitor_task) - - with contextlib.suppress(duplex_unix.DuplexClosed): - await self._pch.aclose() - - if self._exitcode != 0 and not self._kill_sent: - self._exception = JobExecutorError_Runtime() - logger.error( - f"job process exited with non-zero exit code {self.exitcode}", - extra=self.logging_extra(), - ) - - @utils.log_exceptions(logger=logger) - async def _monitor_task(self, pong_timeout: utils.aio.Sleep) -> None: - while True: - try: - msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES) - except utils.aio.duplex_unix.DuplexClosed: - break - - if isinstance(msg, proto.PongResponse): - delay = utils.time_ms() - msg.timestamp - if delay > proto.HIGH_PING_THRESHOLD * 1000: - logger.warning( - "job process is unresponsive", - extra={"delay": delay, **self.logging_extra()}, + async for msg in ipc_ch: + if isinstance(msg, proto.InferenceRequest): + self._inference_tasks.append( + asyncio.create_task(self._do_inference_task(msg)) ) - - with contextlib.suppress(utils.aio.SleepFinished): - pong_timeout.reset() - - if isinstance(msg, proto.Exiting): - logger.info( - "job exiting", extra={"reason": msg.reason, **self.logging_extra()} - ) - - if isinstance(msg, proto.InferenceRequest): - self._inference_tasks.append( - asyncio.create_task(self._do_inference_task(msg)) - ) + finally: + await aio.gracefully_cancel(*self._inference_tasks) async def _do_inference_task(self, inf_req: proto.InferenceRequest) -> None: - if self._inf_excecutor is None: + if self._inference_executor is None: logger.warning("inference request received but no inference executor") + await channel.asend_message( + self._pch, + proto.InferenceResponse( + request_id=inf_req.request_id, error="no inference executor" + ), + ) return try: - inf_res = await self._inf_excecutor.do_inference( + inf_res = await self._inference_executor.do_inference( inf_req.method, inf_req.data ) await channel.asend_message( @@ -377,98 +105,19 @@ async def _do_inference_task(self, inf_req: proto.InferenceRequest) -> None: "error handling inference request", extra=self.logging_extra() ) - @utils.log_exceptions(logger=logger) - async def _ping_pong_task(self, pong_timeout: utils.aio.Sleep) -> None: - ping_interval = utils.aio.interval(proto.PING_INTERVAL) - - async def _send_ping_co(): - while True: - await ping_interval.tick() - try: - await channel.asend_message( - self._pch, proto.PingRequest(timestamp=utils.time_ms()) - ) - except utils.aio.duplex_unix.DuplexClosed: - break - - async def _pong_timeout_co(): - await pong_timeout - logger.error("job is unresponsive, killing job", extra=self.logging_extra()) - self._exception = JobExecutorError_Unresponsive() - self._send_kill_signal() - - tasks = [ - asyncio.create_task(_send_ping_co()), - asyncio.create_task(_pong_timeout_co()), - ] - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.gracefully_cancel(*tasks) - - @utils.log_exceptions(logger=logger) - async def _memory_monitor_task(self) -> None: - """Monitor memory usage and kill the process if it exceeds the limit.""" - while not self._closing and not self._kill_sent: - try: - if not self._pid or not self._running_job: - await asyncio.sleep(5) - continue - - # Get process memory info - process = psutil.Process(self._pid) - memory_info = process.memory_info() - memory_mb = memory_info.rss / (1024 * 1024) # Convert to MB - - if ( - self._opts.job_memory_limit_mb > 0 - and memory_mb > self._opts.job_memory_limit_mb - ): - logger.error( - "Job exceeded memory limit, killing job", - extra={ - "memory_usage_mb": memory_mb, - "memory_limit_mb": self._opts.job_memory_limit_mb, - **self.logging_extra(), - }, - ) - self._exception = JobExecutorError_MemoryLimitExceeded() - self._send_kill_signal() - elif ( - self._opts.job_memory_warn_mb > 0 - and memory_mb > self._opts.job_memory_warn_mb - ): - logger.warning( - "Job memory usage is high", - extra={ - "memory_usage_mb": memory_mb, - "memory_warn_mb": self._opts.job_memory_warn_mb, - "memory_limit_mb": self._opts.job_memory_limit_mb, - **self.logging_extra(), - }, - ) - - except (psutil.NoSuchProcess, psutil.AccessDenied) as e: - logger.warning( - "Failed to get memory info for process", - extra=self.logging_extra(), - exc_info=e, - ) - except Exception: - if self._closing or self._kill_sent: - return - - logger.exception( - "Error in memory monitoring task", - extra=self.logging_extra(), - ) + async def launch_job(self, info: RunningJobInfo) -> None: + """start/assign a job to the process""" + if self._running_job is not None: + raise RuntimeError("process already has a running job") - await asyncio.sleep(5) # Check every 5 seconds + self._running_job = info + start_req = proto.StartJobRequest() + start_req.running_job = info + await channel.asend_message(self._pch, start_req) def logging_extra(self): - extra: dict[str, Any] = { - "pid": self.pid, - } + extra = super().logging_extra() + if self._running_job: extra["job_id"] = self._running_job.job.id diff --git a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py index 70e723e14..9be517d75 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py @@ -20,7 +20,6 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): import asyncio import contextlib -import logging import socket import threading from dataclasses import dataclass @@ -33,7 +32,6 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): from ..utils import aio, http_context, log_exceptions, shortuuid from .channel import Message from .inference_executor import InferenceExecutor -from .log_queue import LogQueueHandler from .proc_client import _ProcClient from .proto import ( Exiting, @@ -51,40 +49,29 @@ class ProcStartArgs: job_entrypoint_fnc: Callable[[JobContext], Any] mp_cch: socket.socket log_cch: socket.socket - asyncio_debug: bool user_arguments: Any | None = None def proc_main(args: ProcStartArgs) -> None: - root_logger = logging.getLogger() - root_logger.setLevel(logging.NOTSET) + from .proc_client import _ProcClient - log_cch = aio.duplex_unix._Duplex.open(args.log_cch) - log_handler = LogQueueHandler(log_cch) - root_logger.addHandler(log_handler) + job_proc = _JobProc( + args.initialize_process_fnc, args.job_entrypoint_fnc, args.user_arguments + ) - try: - from .proc_client import _ProcClient + client = _ProcClient( + args.mp_cch, + args.log_cch, + job_proc.initialize, + job_proc.entrypoint, + ) - job_proc = _JobProc( - args.initialize_process_fnc, args.job_entrypoint_fnc, args.user_arguments - ) + pid = current_process().pid + logger.info("initializing job process", extra={"pid": pid}) + client.initialize() + logger.info("job process initialized", extra={"pid": pid}) - client = _ProcClient( - args.mp_cch, - job_proc.initialize, - job_proc.entrypoint, - args.asyncio_debug, - ) - - pid = current_process().pid - logger.info("initializing job process", extra={"pid": pid}) - client.initialize() - logger.info("job process initialized", extra={"pid": pid}) - - client.run() - finally: - log_handler.close() + client.run() class _InfClient(InferenceExecutor): @@ -93,7 +80,7 @@ def __init__(self, proc_client: _ProcClient) -> None: self._active_requests: dict[str, asyncio.Future[InferenceResponse]] = {} async def do_inference(self, method: str, data: bytes) -> bytes | None: - request_id = shortuuid("INF_") + request_id = shortuuid("inference_") fut = asyncio.Future[InferenceResponse]() await self._client.send( diff --git a/livekit-agents/livekit/agents/ipc/proc_client.py b/livekit-agents/livekit/agents/ipc/proc_client.py index 4b70fd5e2..f9153f29c 100644 --- a/livekit-agents/livekit/agents/ipc/proc_client.py +++ b/livekit-agents/livekit/agents/ipc/proc_client.py @@ -1,42 +1,65 @@ +from __future__ import annotations + import asyncio import contextlib +import logging import socket import sys from typing import Callable, Coroutine from ..log import logger from ..utils import aio, log_exceptions, time_ms -from . import proto from .channel import Message, arecv_message, asend_message, recv_message, send_message +from .log_queue import LogQueueHandler +from .proto import ( + IPC_MESSAGES, + InitializeRequest, + InitializeResponse, + PingRequest, + PongResponse, +) class _ProcClient: def __init__( self, mp_cch: socket.socket, - initialize_fnc: Callable[[proto.InitializeRequest, "_ProcClient"], None], - entrypoint_fnc: Callable[ + log_cch: socket.socket | None, + initialize_fnc: Callable[[InitializeRequest, "_ProcClient"], None], + main_task_fnc: Callable[ [aio.ChanReceiver[Message]], Coroutine[None, None, None] ], - asyncio_debug: bool, ) -> None: self._mp_cch = mp_cch - self._asyncio_debug = asyncio_debug + self._log_cch = log_cch self._initialize_fnc = initialize_fnc - self._entrypoint_fnc = entrypoint_fnc + self._main_task_fnc = main_task_fnc self._initialized = False + self._log_handler: LogQueueHandler | None = None + + def initialize_logger(self) -> None: + if self._log_cch is None: + raise RuntimeError("cannot initialize logger without log channel") + + root_logger = logging.getLogger() + root_logger.setLevel(logging.NOTSET) + + log_cch = aio.duplex_unix._Duplex.open(self._log_cch) + self._log_handler = LogQueueHandler(log_cch) + root_logger.addHandler(self._log_handler) def initialize(self) -> None: try: cch = aio.duplex_unix._Duplex.open(self._mp_cch) - self._init_req = recv_message(cch, proto.IPC_MESSAGES) + first_req = recv_message(cch, IPC_MESSAGES) assert isinstance( - self._init_req, proto.InitializeRequest + first_req, InitializeRequest ), "first message must be proto.InitializeRequest" + self._init_req = first_req self._initialize_fnc(self._init_req, self) - send_message(cch, proto.InitializeResponse()) + send_message(cch, InitializeResponse()) self._initialized = True cch.detach() except aio.duplex_unix.DuplexClosed as e: @@ -48,12 +71,12 @@ def run(self) -> None: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - loop.set_debug(self._asyncio_debug) + loop.set_debug(self._init_req.asyncio_debug) loop.slow_callback_duration = 0.1 # 100ms aio.debug.hook_slow_callbacks(2.0) try: - self._task = loop.create_task(self._main_task(), name="proc_client_main") + self._task = loop.create_task(self._monitor_task(), name="proc_client_main") while not self._task.done(): try: loop.run_until_complete(self._task) @@ -64,16 +87,19 @@ def run(self) -> None: except KeyboardInterrupt: pass finally: + if self._log_handler is not None: + self._log_handler.close() + loop.run_until_complete(loop.shutdown_default_executor()) async def send(self, msg: Message) -> None: await asend_message(self._acch, msg) - async def _main_task(self) -> None: + async def _monitor_task(self) -> None: self._acch = await aio.duplex_unix._AsyncDuplex.open(self._mp_cch) try: exit_flag = asyncio.Event() - ping_timeout = aio.sleep(proto.PING_INTERVAL * 5) + ping_timeout = aio.sleep(self._init_req.ping_timeout) ipc_ch = aio.Chan[Message]() @@ -81,17 +107,17 @@ async def _main_task(self) -> None: async def _read_ipc_task(): while True: try: - msg = await arecv_message(self._acch, proto.IPC_MESSAGES) + msg = await arecv_message(self._acch, IPC_MESSAGES) except aio.duplex_unix.DuplexClosed: break with contextlib.suppress(aio.SleepFinished): ping_timeout.reset() - if isinstance(msg, proto.PingRequest): + if isinstance(msg, PingRequest): await asend_message( self._acch, - proto.PongResponse( + PongResponse( last_timestamp=msg.timestamp, timestamp=time_ms() ), ) @@ -109,8 +135,8 @@ async def _self_health_check(): health_check_task = asyncio.create_task( _self_health_check(), name="health_check" ) - entrypoint_task = asyncio.create_task( - self._entrypoint_fnc(ipc_ch), name="entrypoint" + main_task = asyncio.create_task( + self._main_task_fnc(ipc_ch), name="main_task_entrypoint" ) def _done_cb(_: asyncio.Task) -> None: @@ -121,9 +147,9 @@ def _done_cb(_: asyncio.Task) -> None: read_task.add_done_callback(_done_cb) health_check_task.add_done_callback(_done_cb) - entrypoint_task.add_done_callback(_done_cb) + main_task.add_done_callback(_done_cb) await exit_flag.wait() - await aio.gracefully_cancel(read_task, health_check_task, entrypoint_task) + await aio.gracefully_cancel(read_task, health_check_task, main_task) finally: await self._acch.aclose() diff --git a/livekit-agents/livekit/agents/ipc/proto.py b/livekit-agents/livekit/agents/ipc/proto.py index 14de2881c..c878b4f23 100644 --- a/livekit-agents/livekit/agents/ipc/proto.py +++ b/livekit-agents/livekit/agents/ipc/proto.py @@ -9,11 +9,6 @@ from ..job import JobAcceptArguments, RunningJobInfo from . import channel -PING_INTERVAL = 2.5 -PING_TIMEOUT = 90 -HIGH_PING_THRESHOLD = 0.5 -NO_MESSAGE_TIMEOUT = 15.0 - @dataclass class InitializeRequest: @@ -21,6 +16,25 @@ class InitializeRequest: MSG_ID: ClassVar[int] = 0 + asyncio_debug: bool = False + ping_interval: float = 0 + ping_timeout: float = 0 # if no response, process is considered dead + high_ping_threshold: float = ( + 0 # if ping is higher than this, process is considered unresponsive + ) + + def write(self, b: io.BytesIO) -> None: + channel.write_bool(b, self.asyncio_debug) + channel.write_float(b, self.ping_interval) + channel.write_float(b, self.ping_timeout) + channel.write_float(b, self.high_ping_threshold) + + def read(self, b: io.BytesIO) -> None: + self.asyncio_debug = channel.read_bool(b) + self.ping_interval = channel.read_float(b) + self.ping_timeout = channel.read_float(b) + self.high_ping_threshold = channel.read_float(b) + @dataclass class InitializeResponse: diff --git a/livekit-agents/livekit/agents/ipc/supervised_proc.py b/livekit-agents/livekit/agents/ipc/supervised_proc.py new file mode 100644 index 000000000..b4be1c631 --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/supervised_proc.py @@ -0,0 +1,396 @@ +from __future__ import annotations + +import asyncio +import contextlib +import logging +import multiprocessing as mp +import socket +import sys +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass +from multiprocessing.context import BaseContext +from typing import Any + +import psutil + +from ..log import logger +from ..utils import aio, log_exceptions, time_ms +from ..utils.aio import duplex_unix +from . import channel, proto +from .log_queue import LogQueueListener + + +@dataclass +class _ProcOpts: + initialize_timeout: float + close_timeout: float + memory_warn_mb: float + memory_limit_mb: float + ping_interval: float + ping_timeout: float + high_ping_threshold: float + + +class SupervisedProc(ABC): + def __init__( + self, + *, + initialize_timeout: float, + close_timeout: float, + memory_warn_mb: float, + memory_limit_mb: float, + ping_interval: float, + ping_timeout: float, + high_ping_threshold: float, + mp_ctx: BaseContext, + loop: asyncio.AbstractEventLoop, + ) -> None: + self._loop = loop + self._mp_ctx = mp_ctx + self._opts = _ProcOpts( + initialize_timeout=initialize_timeout, + close_timeout=close_timeout, + memory_warn_mb=memory_warn_mb, + memory_limit_mb=memory_limit_mb, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + high_ping_threshold=high_ping_threshold, + ) + + self._user_args: Any | None = None + self._exitcode: int | None = None + self._pid: int | None = None + + self._supervise_atask: asyncio.Task[None] | None = None + self._closing = False + self._kill_sent = False + self._initialize_fut = asyncio.Future[None]() + + self._lock = asyncio.Lock() + + @abstractmethod + def _create_process( + self, cch: socket.socket, log_cch: socket.socket + ) -> mp.Process: ... + + @abstractmethod + async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None: ... + + @property + def exitcode(self) -> int | None: + return self._exitcode + + @property + def killed(self) -> bool: + return self._kill_sent + + @property + def pid(self) -> int | None: + return self._pid + + @property + def started(self) -> bool: + return self._supervise_atask is not None + + async def start(self) -> None: + """start the supervised process""" + if self.started: + raise RuntimeError("process already started") + + if self._closing: + raise RuntimeError("process is closed") + + await asyncio.shield(self._start()) + + async def _start(self) -> None: + def _add_proc_ctx_log(record: logging.LogRecord) -> None: + extra = self.logging_extra() + for key, value in extra.items(): + setattr(record, key, value) + + async with self._lock: + mp_pch, mp_cch = socket.socketpair() + mp_log_pch, mp_log_cch = socket.socketpair() + + self._pch = await duplex_unix._AsyncDuplex.open(mp_pch) + + log_pch = duplex_unix._Duplex.open(mp_log_pch) + log_listener = LogQueueListener(log_pch, _add_proc_ctx_log) + log_listener.start() + + self._proc = self._create_process(mp_cch, mp_log_cch) + self._proc.start() + mp_log_cch.close() + mp_cch.close() + + self._pid = self._proc.pid + self._join_fut = asyncio.Future[None]() + + def _sync_run(): + self._proc.join() + log_listener.stop() + try: + self._loop.call_soon_threadsafe(self._join_fut.set_result, None) + except RuntimeError: + pass + + thread = threading.Thread(target=_sync_run, name="proc_join_thread") + thread.start() + self._supervise_atask = asyncio.create_task(self._supervise_task()) + + async def join(self) -> None: + """wait for the process to finish""" + if not self.started: + raise RuntimeError("process not started") + + async with self._lock: + if self._supervise_atask: + await asyncio.shield(self._supervise_atask) + + async def initialize(self) -> None: + """initialize the process, this is sending a InitializeRequest message and waiting for a + InitializeResponse with a timeout""" + await channel.asend_message(self._pch, proto.InitializeRequest()) + + # wait for the process to become ready + try: + init_res = await asyncio.wait_for( + channel.arecv_message(self._pch, proto.IPC_MESSAGES), + timeout=self._opts.initialize_timeout, + ) + assert isinstance( + init_res, proto.InitializeResponse + ), "first message must be InitializeResponse" + except asyncio.TimeoutError: + self._initialize_fut.set_exception( + asyncio.TimeoutError("process initialization timed out") + ) + logger.error( + "initialization timed out, killing process", extra=self.logging_extra() + ) + self._send_kill_signal() + raise + except Exception as e: # should be channel.ChannelClosed most of the time + self._initialize_fut.set_exception(e) + raise + else: + self._initialize_fut.set_result(None) + + async def aclose(self) -> None: + """attempt to gracefully close the supervised process""" + if not self.started: + return + + self._closing = True + with contextlib.suppress(duplex_unix.DuplexClosed): + await channel.asend_message(self._pch, proto.ShutdownRequest()) + + try: + if self._supervise_atask: + await asyncio.wait_for( + asyncio.shield(self._supervise_atask), + timeout=self._opts.close_timeout, + ) + except asyncio.TimeoutError: + logger.error( + "process did not exit in time, killing process", + extra=self.logging_extra(), + ) + self._send_kill_signal() + + async with self._lock: + if self._supervise_atask: + await asyncio.shield(self._supervise_atask) + + async def kill(self) -> None: + """forcefully kill the supervised process""" + if not self.started: + raise RuntimeError("process not started") + + self._closing = True + self._send_kill_signal() + + async with self._lock: + if self._supervise_atask: + await asyncio.shield(self._supervise_atask) + + def _send_kill_signal(self) -> None: + """forcefully kill the process""" + try: + if not self._proc.is_alive(): + return + except ValueError: + return + + logger.info("killing process", extra=self.logging_extra()) + if sys.platform == "win32": + self._proc.terminate() + else: + self._proc.kill() + + self._kill_sent = True + + @log_exceptions(logger=logger) + async def _supervise_task(self) -> None: + try: + await self._initialize_fut + except asyncio.TimeoutError: + pass # this happens when the initialization takes longer than self._initialize_timeout + except Exception: + pass # initialization failed + + # the process is killed if it doesn't respond to ping requests + pong_timeout = aio.sleep(self._opts.ping_timeout) + + ipc_ch = aio.Chan[channel.Message]() + + main_task = asyncio.create_task(self._main_task(ipc_ch)) + read_ipc_task = asyncio.create_task(self._read_ipc_task(ipc_ch, pong_timeout)) + ping_task = asyncio.create_task(self._ping_pong_task(pong_timeout)) + read_ipc_task.add_done_callback(lambda _: ipc_ch.close()) + + memory_monitor_task: asyncio.Task[None] | None = None + if self._opts.memory_limit_mb > 0 or self._opts.memory_warn_mb > 0: + memory_monitor_task = asyncio.create_task(self._memory_monitor_task()) + + await self._join_fut + self._exitcode = self._proc.exitcode + self._proc.close() + await aio.gracefully_cancel(ping_task, read_ipc_task, main_task) + + if memory_monitor_task is not None: + await aio.gracefully_cancel(memory_monitor_task) + + with contextlib.suppress(duplex_unix.DuplexClosed): + await self._pch.aclose() + + if self._exitcode != 0 and not self._kill_sent: + logger.error( + f"process exited with non-zero exit code {self.exitcode}", + extra=self.logging_extra(), + ) + + @log_exceptions(logger=logger) + async def _read_ipc_task( + self, ipc_ch: aio.Chan[channel.Message], pong_timeout: aio.Sleep + ) -> None: + while True: + try: + msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES) + except duplex_unix.DuplexClosed: + break + + if isinstance(msg, proto.PongResponse): + delay = time_ms() - msg.timestamp + if delay > self._opts.high_ping_threshold * 1000: + logger.warning( + "process is unresponsive", + extra={"delay": delay, **self.logging_extra()}, + ) + + with contextlib.suppress(aio.SleepFinished): + pong_timeout.reset() + + if isinstance(msg, proto.Exiting): + logger.info( + "process exiting", + extra={"reason": msg.reason, **self.logging_extra()}, + ) + + ipc_ch.send_nowait(msg) + + @log_exceptions(logger=logger) + async def _ping_pong_task(self, pong_timeout: aio.Sleep) -> None: + ping_interval = aio.interval(self._opts.ping_interval) + + async def _send_ping_co(): + while True: + await ping_interval.tick() + try: + await channel.asend_message( + self._pch, proto.PingRequest(timestamp=time_ms()) + ) + except duplex_unix.DuplexClosed: + break + + async def _pong_timeout_co(): + await pong_timeout + logger.error( + "process is unresponsive, killing process", extra=self.logging_extra() + ) + self._send_kill_signal() + + tasks = [ + asyncio.create_task(_send_ping_co()), + asyncio.create_task(_pong_timeout_co()), + ] + try: + await asyncio.gather(*tasks) + finally: + await aio.gracefully_cancel(*tasks) + + @log_exceptions(logger=logger) + async def _memory_monitor_task(self) -> None: + """Monitor memory usage and kill the process if it exceeds the limit.""" + while not self._closing and not self._kill_sent: + try: + if not self._pid: + await asyncio.sleep(5) + continue + + # get process memory info + process = psutil.Process(self._pid) + memory_info = process.memory_info() + memory_mb = memory_info.rss / (1024 * 1024) # Convert to MB + + if ( + self._opts.memory_limit_mb > 0 + and memory_mb > self._opts.memory_limit_mb + ): + logger.error( + "process exceeded memory limit, killing process", + extra={ + "memory_usage_mb": memory_mb, + "memory_limit_mb": self._opts.memory_limit_mb, + **self.logging_extra(), + }, + ) + self._send_kill_signal() + elif ( + self._opts.memory_warn_mb > 0 + and memory_mb > self._opts.memory_warn_mb + ): + logger.warning( + "process memory usage is high", + extra={ + "memory_usage_mb": memory_mb, + "memory_warn_mb": self._opts.memory_warn_mb, + "memory_limit_mb": self._opts.memory_limit_mb, + **self.logging_extra(), + }, + ) + + except (psutil.NoSuchProcess, psutil.AccessDenied) as e: + logger.warning( + "Failed to get memory info for process", + extra=self.logging_extra(), + exc_info=e, + ) + except Exception: + if self._closing or self._kill_sent: + return + + logger.exception( + "Error in memory monitoring task", + extra=self.logging_extra(), + ) + + await asyncio.sleep(5) # check every 5 seconds + + def logging_extra(self): + extra: dict[str, Any] = { + "pid": self.pid, + } + + return extra From be961e69dfb23b1ec8d6d5791f13de7d9cd8bb8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 27 Nov 2024 17:23:20 +0100 Subject: [PATCH 22/24] Update job_proc_executor.py --- livekit-agents/livekit/agents/ipc/job_proc_executor.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/livekit-agents/livekit/agents/ipc/job_proc_executor.py b/livekit-agents/livekit/agents/ipc/job_proc_executor.py index c44b364b7..fc7c0e984 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_executor.py @@ -100,9 +100,10 @@ async def _do_inference_task(self, inf_req: proto.InferenceRequest) -> None: self._pch, proto.InferenceResponse(request_id=inf_req.request_id, data=inf_res), ) - except Exception: - logger.exception( - "error handling inference request", extra=self.logging_extra() + except Exception as e: + await channel.asend_message( + self._pch, + proto.InferenceResponse(request_id=inf_req.request_id, error=str(e)), ) async def launch_job(self, info: RunningJobInfo) -> None: From 3fa6ad67f0dd79adb2c862daea09b68789ed2de3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 27 Nov 2024 17:35:10 +0100 Subject: [PATCH 23/24] remove duplicate --- .../agents/ipc/inference_proc_executor.py | 333 ++++-------------- .../livekit/agents/ipc/job_proc_lazy_main.py | 2 +- .../livekit/agents/ipc/supervised_proc.py | 10 +- 3 files changed, 71 insertions(+), 274 deletions(-) diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py index b9cfe1a40..1f07ed07d 100644 --- a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py @@ -2,308 +2,97 @@ import asyncio import contextlib -import logging +import multiprocessing as mp import socket -import sys -import threading -from dataclasses import dataclass from multiprocessing.context import BaseContext -from typing import Any -from .. import utils -from ..inference_runner import _InferenceRunner +from ..inference_runner import _RunnersDict from ..log import logger -from ..utils.aio import duplex_unix -from . import channel, inference_proc_lazy_main, proto -from .inference_executor import InferenceExecutor, _InferenceRunnerClient -from .log_queue import LogQueueListener +from ..utils import aio, log_exceptions, shortuuid +from . import channel, proto +from .inference_proc_lazy_main import ProcStartArgs, proc_main +from .supervised_proc import SupervisedProc -@dataclass -class _ProcOpts: - mp_ctx: BaseContext - initialize_timeout: float - close_timeout: float - - -class InferenceProcExecutor(InferenceExecutor): +class InferenceProcExecutor(SupervisedProc): def __init__( self, *, + runners: _RunnersDict, + initialize_timeout: float, + close_timeout: float, + memory_warn_mb: float, + memory_limit_mb: float, + ping_interval: float, + ping_timeout: float, + high_ping_threshold: float, mp_ctx: BaseContext, loop: asyncio.AbstractEventLoop, - initialize_timeout: float = 60.0, - close_timeout: float = 2.5, ) -> None: - self._loop = loop - self._opts = _ProcOpts( + super().__init__( initialize_timeout=initialize_timeout, close_timeout=close_timeout, + memory_warn_mb=memory_warn_mb, + memory_limit_mb=memory_limit_mb, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + high_ping_threshold=high_ping_threshold, mp_ctx=mp_ctx, + loop=loop, ) - self._exitcode: int | None = None - self._pid: int | None = None - - self._main_atask: asyncio.Task[None] | None = None - self._closing = False - self._kill_sent = False - self._initialize_fut = asyncio.Future[None]() - - self._lock = asyncio.Lock() - - @property - def exitcode(self) -> int | None: - return self._exitcode - - @property - def killed(self) -> bool: - return self._kill_sent - - @property - def pid(self) -> int | None: - return self._pid - - @property - def started(self) -> bool: - return self._main_atask is not None - - async def start(self) -> None: - if self.started: - raise RuntimeError("process already started") - - if self._closing: - raise RuntimeError("process is closed") - - await asyncio.shield(self._start()) - - async def _start(self) -> None: - def _add_proc_ctx_log(record: logging.LogRecord) -> None: - extra = self.logging_extra() - for key, value in extra.items(): - setattr(record, key, value) - - async with self._lock: - self._pong_timeout = utils.aio.sleep(proto.PING_TIMEOUT) - - mp_pch, mp_cch = socket.socketpair() - mp_log_pch, mp_log_cch = socket.socketpair() - - self._pch = await duplex_unix._AsyncDuplex.open(mp_pch) + self._runners = runners + self._active_requests: dict[str, asyncio.Future[proto.InferenceResponse]] = {} - log_pch = duplex_unix._Duplex.open(mp_log_pch) - log_listener = LogQueueListener(log_pch, _add_proc_ctx_log) - log_listener.start() - - self._proc_args = inference_proc_lazy_main.ProcStartArgs( - log_cch=mp_log_cch, - mp_cch=mp_cch, - asyncio_debug=self._loop.get_debug(), - runners=_InferenceRunner.registered_runners, - ) - - self._inf_client = _InferenceRunnerClient(cch=self._pch) - - self._proc = self._opts.mp_ctx.Process( # type: ignore - target=inference_proc_lazy_main.proc_main, - args=(self._proc_args,), - name="inference_proc", - ) - - self._proc.start() - mp_log_cch.close() - mp_cch.close() - - self._pid = self._proc.pid - self._join_fut = asyncio.Future[None]() - - def _sync_run(): - self._proc.join() - log_listener.stop() - try: - self._loop.call_soon_threadsafe(self._join_fut.set_result, None) - except RuntimeError: - pass + def _create_process(self, cch: socket.socket, log_cch: socket.socket) -> mp.Process: + proc_args = ProcStartArgs( + log_cch=log_cch, + mp_cch=cch, + runners=self._runners, + ) - thread = threading.Thread(target=_sync_run, name="proc_join_thread") - thread.start() - self._main_atask = asyncio.create_task(self._main_task()) + return self._opts.mp_ctx.Process( # type: ignore + target=proc_main, + args=(proc_args,), + name="inference_proc", + ) - async def join(self) -> None: - if not self.started: - raise RuntimeError("process not started") + @log_exceptions(logger=logger) + async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None: + async for msg in ipc_ch: + if isinstance(msg, proto.InferenceResponse): + fut = self._active_requests.pop(msg.request_id, None) + if fut is None: + logger.warning( + "received unexpected inference response", + extra={"request_id": msg.request_id}, + ) + return - async with self._lock: - if self._main_atask: - await asyncio.shield(self._main_atask) + with contextlib.suppress(asyncio.InvalidStateError): + fut.set_result(msg) async def do_inference(self, method: str, data: bytes) -> bytes | None: if not self.started: raise RuntimeError("process not started") - return await self._inf_client.do_inference(method, data) - - async def initialize(self) -> None: - await channel.asend_message(self._pch, proto.InitializeRequest()) + request_id = shortuuid("inference_req_") + fut = asyncio.Future[proto.InferenceResponse]() - # wait for the process to become ready - try: - init_res = await asyncio.wait_for( - channel.arecv_message(self._pch, proto.IPC_MESSAGES), - timeout=self._opts.initialize_timeout, - ) - assert isinstance( - init_res, proto.InitializeResponse - ), "first message must be InitializeResponse" - except asyncio.TimeoutError: - self._initialize_fut.set_exception( - asyncio.TimeoutError("process initialization timed out") - ) - logger.error( - "initialization timed out, killing process", extra=self.logging_extra() - ) - self._send_kill_signal() - raise - except Exception as e: # should be channel.ChannelClosed most of the time - self._initialize_fut.set_exception(e) - raise - else: - self._initialize_fut.set_result(None) - - async def aclose(self) -> None: - if not self.started: - return - - self._closing = True - with contextlib.suppress(utils.aio.duplex_unix.DuplexClosed): - await channel.asend_message(self._pch, proto.ShutdownRequest()) - - try: - if self._main_atask: - await asyncio.wait_for( - asyncio.shield(self._main_atask), timeout=self._opts.close_timeout - ) - except asyncio.TimeoutError: - logger.error( - "process did not exit in time, killing process", - extra=self.logging_extra(), - ) - self._send_kill_signal() - - async with self._lock: - if self._main_atask: - await asyncio.shield(self._main_atask) - - async def kill(self) -> None: - if not self.started: - raise RuntimeError("process not started") - - self._closing = True - self._send_kill_signal() - - async with self._lock: - if self._main_atask: - await asyncio.shield(self._main_atask) - - def _send_kill_signal(self) -> None: - try: - if not self._proc.is_alive(): - return - except ValueError: - return - - logger.info("killing process", extra=self.logging_extra()) - if sys.platform == "win32": - self._proc.terminate() - else: - self._proc.kill() - - self._kill_sent = True - - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: - try: - await self._initialize_fut - except asyncio.TimeoutError: - pass # this happens when the initialization takes longer than self._initialize_timeout - except Exception: - pass # initialization failed - - # the process is killed if it doesn't respond to ping requests - ping_task = asyncio.create_task(self._ping_pong_task()) - monitor_task = asyncio.create_task(self._monitor_task()) - - await self._join_fut - self._exitcode = self._proc.exitcode - self._proc.close() - await utils.aio.gracefully_cancel(ping_task, monitor_task) - - with contextlib.suppress(duplex_unix.DuplexClosed): - await self._pch.aclose() - - if self._exitcode != 0 and not self._kill_sent: - logger.error( - f"inference process exited with non-zero exit code {self.exitcode}", - extra=self.logging_extra(), - ) - - @utils.log_exceptions(logger=logger) - async def _monitor_task(self) -> None: - while True: - try: - msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES) - except utils.aio.duplex_unix.DuplexClosed: - break - - if isinstance(msg, proto.PongResponse): - delay = utils.time_ms() - msg.timestamp - if delay > proto.HIGH_PING_THRESHOLD * 1000: - logger.warning( - "inference process is unresponsive", - extra={"delay": delay, **self.logging_extra()}, - ) - - with contextlib.suppress(utils.aio.SleepFinished): - self._pong_timeout.reset() - - if isinstance(msg, proto.InferenceResponse): - self._inf_client._on_inference_response(msg) - - @utils.log_exceptions(logger=logger) - async def _ping_pong_task(self) -> None: - ping_interval = utils.aio.interval(proto.PING_INTERVAL) + await channel.asend_message( + self._pch, + proto.InferenceRequest(request_id=request_id, method=method, data=data), + ) - async def _send_ping_co(): - while True: - await ping_interval.tick() - try: - await channel.asend_message( - self._pch, proto.PingRequest(timestamp=utils.time_ms()) - ) - except utils.aio.duplex_unix.DuplexClosed: - break + self._active_requests[request_id] = fut - async def _pong_timeout_co(): - while True: - await self._pong_timeout - logger.error( - "inference process is unresponsive, killing proc", - extra=self.logging_extra(), - ) - self._pong_timeout = utils.aio.sleep(proto.PING_TIMEOUT) + inf_resp = await fut + if inf_resp.error: + raise RuntimeError(f"inference of {method} failed: {inf_resp.error}") - tasks = [ - asyncio.create_task(_send_ping_co()), - asyncio.create_task(_pong_timeout_co()), - ] - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.gracefully_cancel(*tasks) + return inf_resp.data def logging_extra(self): - extra: dict[str, Any] = { - "pid": self.pid, - "inference_proc": True, - } + extra = super().logging_extra() + extra["inference"] = True return extra diff --git a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py index 9be517d75..75d2ccaee 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py @@ -80,7 +80,7 @@ def __init__(self, proc_client: _ProcClient) -> None: self._active_requests: dict[str, asyncio.Future[InferenceResponse]] = {} async def do_inference(self, method: str, data: bytes) -> bytes | None: - request_id = shortuuid("inference_") + request_id = shortuuid("inference_job_") fut = asyncio.Future[InferenceResponse]() await self._client.send( diff --git a/livekit-agents/livekit/agents/ipc/supervised_proc.py b/livekit-agents/livekit/agents/ipc/supervised_proc.py index b4be1c631..6b00f9388 100644 --- a/livekit-agents/livekit/agents/ipc/supervised_proc.py +++ b/livekit-agents/livekit/agents/ipc/supervised_proc.py @@ -151,7 +151,15 @@ async def join(self) -> None: async def initialize(self) -> None: """initialize the process, this is sending a InitializeRequest message and waiting for a InitializeResponse with a timeout""" - await channel.asend_message(self._pch, proto.InitializeRequest()) + await channel.asend_message( + self._pch, + proto.InitializeRequest( + asyncio_debug=self._loop.get_debug(), + ping_interval=self._opts.ping_interval, + ping_timeout=self._opts.ping_timeout, + high_ping_threshold=self._opts.high_ping_threshold, + ), + ) # wait for the process to become ready try: From 01d02b4b6c5a5a9ea1021422546a95b2df4587be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 27 Nov 2024 17:36:29 +0100 Subject: [PATCH 24/24] wip --- livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py | 2 ++ livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py index d65afc273..c4e949d58 100644 --- a/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py @@ -47,6 +47,8 @@ def proc_main(args: ProcStartArgs) -> None: inf_proc.entrypoint, ) + client.initialize_logger() + pid = current_process().pid logger.info("initializing inference process", extra={"pid": pid}) client.initialize() diff --git a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py index 75d2ccaee..26a207301 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py @@ -66,6 +66,8 @@ def proc_main(args: ProcStartArgs) -> None: job_proc.entrypoint, ) + client.initialize_logger() + pid = current_process().pid logger.info("initializing job process", extra={"pid": pid}) client.initialize()