Skip to content

Commit

Permalink
Add max_retry_interval to lf.concurrent_execute and `lf.LanguageM…
Browse files Browse the repository at this point in the history
…odel`.

This avoids very long wait time when `exponential_backoff` is set to True.

PiperOrigin-RevId: 689608563
  • Loading branch information
daiyip authored and langfun authors committed Oct 25, 2024
1 parent 9f5149d commit baba539
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 3 deletions.
11 changes: 10 additions & 1 deletion langfun/core/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def with_retry(
max_attempts: int,
retry_interval: int | tuple[int, int] = (5, 60),
exponential_backoff: bool = True,
max_retry_interval: int = 300,
seed: int | None = None,
) -> Callable[..., Any]:
"""Derives a user function with retry on error.
Expand All @@ -133,6 +134,9 @@ def with_retry(
of the tuple.
exponential_backoff: If True, exponential wait time will be applied on top
of the base retry interval.
max_retry_interval: The max retry interval in seconds. This is useful when
the retry interval is exponential, to avoid the wait time to grow
exponentially.
seed: Random seed to generate retry interval. If None, the seed will be
determined based on current time.
Expand All @@ -153,7 +157,7 @@ def base_interval() -> int:
def next_wait_interval(attempt: int) -> float:
if not exponential_backoff:
attempt = 1
return base_interval() * (2 ** (attempt - 1))
return min(max_retry_interval, base_interval() * (2 ** (attempt - 1)))

wait_intervals = []
errors = []
Expand Down Expand Up @@ -193,6 +197,7 @@ def concurrent_execute(
max_attempts: int = 5,
retry_interval: int | tuple[int, int] = (5, 60),
exponential_backoff: bool = True,
max_retry_interval: int = 300,
) -> list[Any]:
"""Executes a function concurrently under current component context.
Expand All @@ -213,6 +218,9 @@ def concurrent_execute(
of the tuple.
exponential_backoff: If True, exponential wait time will be applied on top
of the base retry interval.
max_retry_interval: The max retry interval in seconds. This is useful when
the retry interval is exponential, to avoid the wait time to grow
exponentially.
Returns:
A list of ouputs. Each is the return value of `func` based on the input
Expand All @@ -225,6 +233,7 @@ def concurrent_execute(
max_attempts=max_attempts,
retry_interval=retry_interval,
exponential_backoff=exponential_backoff,
max_retry_interval=max_retry_interval,
)

# NOTE(daiyip): when executor is not specified and max_worker is 1,
Expand Down
11 changes: 10 additions & 1 deletion langfun/core/concurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,19 @@ def foo():
raise ValueError('Intentional error.')

foo_with_retry = concurrent.with_retry(
foo, ValueError, max_attempts=4, retry_interval=1
foo, ValueError, max_attempts=4, retry_interval=1,
)
self.assert_retry(foo_with_retry, 4, [1, 2, 4])

def test_retry_with_max_retry_interval(self):
def foo():
raise ValueError('Intentional error.')

foo_with_retry = concurrent.with_retry(
foo, ValueError, max_attempts=4, retry_interval=1, max_retry_interval=3,
)
self.assert_retry(foo_with_retry, 4, [1, 2, 3])

def test_retry_with_uncaught_exception(self):
def foo():
raise ValueError('Intentional error.')
Expand Down
3 changes: 2 additions & 1 deletion langfun/core/langfunc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def test_call(self):
' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
' max_concurrency=None, timeout=120.0, max_attempts=5,'
' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
' retry_interval=(5, 60), exponential_backoff=True,'
' max_retry_interval=300, debug=False))',
)

l = LangFunc('Hello')
Expand Down
10 changes: 10 additions & 0 deletions langfun/core/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,15 @@ class LanguageModel(component.Component):
)
] = True

max_retry_interval: Annotated[
int,
(
'The max retry interval in seconds. This is useful when the retry '
'interval is exponential, to avoid the wait time to grow '
'exponentially.'
)
] = 300

debug: Annotated[
bool | LMDebugMode,
(
Expand Down Expand Up @@ -587,6 +596,7 @@ def _parallel_execute_with_currency_control(
max_attempts=self.max_attempts,
retry_interval=self.retry_interval,
exponential_backoff=self.exponential_backoff,
max_retry_interval=self.max_retry_interval,
)

def __call__(
Expand Down

0 comments on commit baba539

Please sign in to comment.