From e5dfc275872a4fdc3dc4905106f255ba5273854a Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Thu, 18 Jul 2024 17:40:13 -0700 Subject: [PATCH] Add lock for usage tracking. This avoids racing requests to be counted just once. PiperOrigin-RevId: 653809899 --- langfun/core/language_model.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/langfun/core/language_model.py b/langfun/core/language_model.py index 8feead6..65a5ada 100644 --- a/langfun/core/language_model.py +++ b/langfun/core/language_model.py @@ -17,6 +17,7 @@ import contextlib import dataclasses import enum +import threading import time from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union from langfun.core import component @@ -728,6 +729,7 @@ class _UsageTracker: def __init__(self, model_ids: set[str] | None): self.model_ids = model_ids + self._lock = threading.Lock() self.usages = { m: LMSamplingUsage(0, 0, 0, 0) for m in model_ids } if model_ids else {} @@ -735,10 +737,11 @@ def __init__(self, model_ids: set[str] | None): def track(self, model_id: str, usage: LMSamplingUsage): if self.model_ids is not None and model_id not in self.model_ids: return - if not isinstance(usage, UsageNotAvailable) and model_id in self.usages: - self.usages[model_id] += usage - else: - self.usages[model_id] = usage + with self._lock: + if not isinstance(usage, UsageNotAvailable) and model_id in self.usages: + self.usages[model_id] += usage + else: + self.usages[model_id] = usage @contextlib.contextmanager