diff --git a/langfun/core/language_model.py b/langfun/core/language_model.py index 8feead6..65a5ada 100644 --- a/langfun/core/language_model.py +++ b/langfun/core/language_model.py @@ -17,6 +17,7 @@ import contextlib import dataclasses import enum +import threading import time from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union from langfun.core import component @@ -728,6 +729,7 @@ class _UsageTracker: def __init__(self, model_ids: set[str] | None): self.model_ids = model_ids + self._lock = threading.Lock() self.usages = { m: LMSamplingUsage(0, 0, 0, 0) for m in model_ids } if model_ids else {} @@ -735,10 +737,11 @@ def __init__(self, model_ids: set[str] | None): def track(self, model_id: str, usage: LMSamplingUsage): if self.model_ids is not None and model_id not in self.model_ids: return - if not isinstance(usage, UsageNotAvailable) and model_id in self.usages: - self.usages[model_id] += usage - else: - self.usages[model_id] = usage + with self._lock: + if not isinstance(usage, UsageNotAvailable) and model_id in self.usages: + self.usages[model_id] += usage + else: + self.usages[model_id] = usage @contextlib.contextmanager