diff --git a/.changeset/ninety-fans-jump.md b/.changeset/ninety-fans-jump.md new file mode 100644 index 000000000..2dd38c8a4 --- /dev/null +++ b/.changeset/ninety-fans-jump.md @@ -0,0 +1,5 @@ +--- +"livekit-agents": patch +--- + +add job_memory_warn_mb and job_memory_limit_mb options; a job process that exceeds the memory limit is killed diff --git a/livekit-agents/livekit/agents/ipc/job_executor.py b/livekit-agents/livekit/agents/ipc/job_executor.py index 19704791a..7d546bd2e 100644 --- a/livekit-agents/livekit/agents/ipc/job_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_executor.py @@ -58,3 +58,7 @@ class JobExecutorError_Unresponsive(JobExecutorError): class JobExecutorError_Runtime(JobExecutorError): pass + + +class JobExecutorError_MemoryLimitExceeded(JobExecutorError): + pass diff --git a/livekit-agents/livekit/agents/ipc/proc_job_executor.py b/livekit-agents/livekit/agents/ipc/proc_job_executor.py index 2a956d947..a42f11ce4 100644 --- a/livekit-agents/livekit/agents/ipc/proc_job_executor.py +++ b/livekit-agents/livekit/agents/ipc/proc_job_executor.py @@ -11,12 +11,15 @@ from multiprocessing.context import BaseContext from typing import Any, Awaitable, Callable +import psutil + from .. import utils from ..job import JobContext, JobProcess, RunningJobInfo from ..log import logger from ..utils.aio import duplex_unix from . 
import channel, job_main, proc_lazy_main, proto from .job_executor import ( + JobExecutorError_MemoryLimitExceeded, JobExecutorError_Runtime, JobExecutorError_ShutdownTimeout, JobExecutorError_Unresponsive, @@ -73,6 +76,8 @@ class _ProcOpts: mp_ctx: BaseContext initialize_timeout: float close_timeout: float + job_memory_warn_mb: float + job_memory_limit_mb: float class ProcJobExecutor: @@ -85,6 +90,8 @@ def __init__( close_timeout: float, mp_ctx: BaseContext, loop: asyncio.AbstractEventLoop, + job_memory_warn_mb: float = 0, + job_memory_limit_mb: float = 0, ) -> None: self._loop = loop self._opts = _ProcOpts( @@ -93,6 +100,8 @@ def __init__( initialize_timeout=initialize_timeout, close_timeout=close_timeout, mp_ctx=mp_ctx, + job_memory_warn_mb=job_memory_warn_mb, + job_memory_limit_mb=job_memory_limit_mb, ) self._user_args: Any | None = None @@ -335,11 +344,19 @@ async def _main_task(self) -> None: ping_task = asyncio.create_task(self._ping_pong_task(pong_timeout)) monitor_task = asyncio.create_task(self._monitor_task(pong_timeout)) + if self._opts.job_memory_limit_mb > 0 or self._opts.job_memory_warn_mb > 0: + memory_monitor_task = asyncio.create_task(self._memory_monitor_task()) + else: + memory_monitor_task = None + await self._join_fut self._exitcode = self._proc.exitcode self._proc.close() await utils.aio.gracefully_cancel(ping_task, monitor_task) + if memory_monitor_task: + await utils.aio.gracefully_cancel(memory_monitor_task) + with contextlib.suppress(duplex_unix.DuplexClosed): await self._pch.aclose() @@ -403,6 +420,65 @@ async def _pong_timeout_co(): finally: await utils.aio.gracefully_cancel(*tasks) + @utils.log_exceptions(logger=logger) + async def _memory_monitor_task(self) -> None: + """Monitor memory usage and kill the process if it exceeds the limit.""" + while not self._closing and not self._kill_sent: + try: + if not self._pid or not self._running_job: + await asyncio.sleep(5) + continue + + # Get process memory info + process = 
psutil.Process(self._pid) + memory_info = process.memory_info() + memory_mb = memory_info.rss / (1024 * 1024) # Convert to MB + + if ( + self._opts.job_memory_limit_mb > 0 + and memory_mb > self._opts.job_memory_limit_mb + ): + logger.error( + "Job exceeded memory limit, killing job", + extra={ + "memory_usage_mb": memory_mb, + "memory_limit_mb": self._opts.job_memory_limit_mb, + **self.logging_extra(), + }, + ) + self._exception = JobExecutorError_MemoryLimitExceeded() + self._send_kill_signal() + elif ( + self._opts.job_memory_warn_mb > 0 + and memory_mb > self._opts.job_memory_warn_mb + ): + logger.warning( + "Job memory usage is high", + extra={ + "memory_usage_mb": memory_mb, + "memory_warn_mb": self._opts.job_memory_warn_mb, + "memory_limit_mb": self._opts.job_memory_limit_mb, + **self.logging_extra(), + }, + ) + + except (psutil.NoSuchProcess, psutil.AccessDenied) as e: + logger.warning( + "Failed to get memory info for process", + extra=self.logging_extra(), + exc_info=e, + ) + except Exception: + if self._closing or self._kill_sent: + return + + logger.exception( + "Error in memory monitoring task", + extra=self.logging_extra(), + ) + + await asyncio.sleep(5) # Check every 5 seconds + def logging_extra(self): extra: dict[str, Any] = { "pid": self.pid, diff --git a/livekit-agents/livekit/agents/ipc/proc_pool.py b/livekit-agents/livekit/agents/ipc/proc_pool.py index d707987ab..a67a42d1d 100644 --- a/livekit-agents/livekit/agents/ipc/proc_pool.py +++ b/livekit-agents/livekit/agents/ipc/proc_pool.py @@ -34,6 +34,8 @@ def __init__( job_executor_type: JobExecutorType, mp_ctx: BaseContext, loop: asyncio.AbstractEventLoop, + job_memory_warn_mb: float = 0, + job_memory_limit_mb: float = 0, ) -> None: super().__init__() self._job_executor_type = job_executor_type @@ -43,7 +45,8 @@ def __init__( self._close_timeout = close_timeout self._initialize_timeout = initialize_timeout self._loop = loop - + self._job_memory_limit_mb = job_memory_limit_mb + 
self._job_memory_warn_mb = job_memory_warn_mb self._num_idle_processes = num_idle_processes self._init_sem = asyncio.Semaphore(MAX_CONCURRENT_INITIALIZATIONS) self._proc_needed_sem = asyncio.Semaphore(num_idle_processes) @@ -110,6 +113,8 @@ async def _proc_watch_task(self) -> None: close_timeout=self._close_timeout, mp_ctx=self._mp_ctx, loop=self._loop, + job_memory_warn_mb=self._job_memory_warn_mb, + job_memory_limit_mb=self._job_memory_limit_mb, ) else: raise ValueError(f"unsupported job executor: {self._job_executor_type}") diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index a9a6c39b3..e9e0899f7 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -158,6 +158,15 @@ class WorkerOptions: Defaults to 0.75 on "production" mode, and is disabled in "development" mode. """ + + job_memory_warn_mb: float = 300 + """Memory warning threshold in MB. If the job process exceeds this limit, a warning will be logged.""" + job_memory_limit_mb: float = 0 + """Maximum memory usage for a job in MB, the job process will be killed if it exceeds this limit. + Defaults to 0 (disabled). 
+ """ + + """Number of idle processes to keep warm.""" num_idle_processes: int | _WorkerEnvOption[int] = _WorkerEnvOption( dev_default=0, prod_default=3 ) @@ -234,6 +243,15 @@ def __init__( "api_secret is required, or add LIVEKIT_API_SECRET in your environment" ) + if ( + opts.job_memory_limit_mb > 0 + and opts.job_executor_type != JobExecutorType.PROCESS + ): + logger.warning( + "max_job_memory_usage is only supported for process-based job executors, " + "ignoring max_job_memory_usage" + ) + self._opts = opts self._loop = loop or asyncio.get_event_loop() @@ -259,6 +277,8 @@ def __init__( mp_ctx=mp_ctx, initialize_timeout=opts.initialize_process_timeout, close_timeout=opts.shutdown_process_timeout, + job_memory_warn_mb=opts.job_memory_warn_mb, + job_memory_limit_mb=opts.job_memory_limit_mb, ) self._proc_pool.on("process_started", self._on_process_started) self._proc_pool.on("process_closed", self._on_process_closed)