diff --git a/langfun/core/langfunc_test.py b/langfun/core/langfunc_test.py index 4a3a3312..a8ac5bdf 100644 --- a/langfun/core/langfunc_test.py +++ b/langfun/core/langfunc_test.py @@ -97,7 +97,8 @@ def test_call(self): "LangFunc(template_str='Hello', clean=True, " 'lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0, ' 'max_tokens=1024, n=1, top_k=40, top_p=None, random_seed=None), ' - 'cache=None, timeout=120.0, max_attempts=5, debug=False), ' + 'cache=None, timeout=120.0, max_attempts=5, retry_interval=(5, 60), ' + 'exponential_backoff=True, debug=False), ' 'input_transform=None, output_transform=None)', ) diff --git a/langfun/core/language_model.py b/langfun/core/language_model.py index 815d982a..648ee8a0 100644 --- a/langfun/core/language_model.py +++ b/langfun/core/language_model.py @@ -181,6 +181,27 @@ class LanguageModel(component.Component): ), ] = 5 + retry_interval: Annotated[ + int | tuple[int, int], + ( + 'An integer as a constant wait time in seconds before next retry, ' + 'or a tuple of two integers representing the range of wait time, ' + 'based on which the next wait time will be randomly chosen.' + ) + ] = (5, 60) + + exponential_backoff: Annotated[ + bool, + ( + 'If True, the wait time among multiple attempts will exponentially ' + 'grow. If `retry_interval` is an integer, the wait time for the ' + 'k\'th attempt will be `retry_interval * 2 ^ (k - 1)` seconds. If ' + '`retry_interval` is a tuple, the wait time range for the k\'th ' + 'attempt will be `(retry_interval[0] * 2 ^ (k - 1), ' + 'retry_interval[1] * 2 ^ (k - 1))` seconds.' 
+ ) + ] = True + debug: Annotated[ bool | LMDebugMode, ( diff --git a/langfun/core/llms/llama_cpp.py b/langfun/core/llms/llama_cpp.py index 276a20ec..5faed47a 100644 --- a/langfun/core/llms/llama_cpp.py +++ b/langfun/core/llms/llama_cpp.py @@ -71,6 +71,6 @@ def _complete_fn(cur_prompts): _complete_fn, retry_on_errors=(), max_attempts=self.max_attempts, - retry_interval=(1, 60), - exponential_backoff=True, + retry_interval=self.retry_interval, + exponential_backoff=self.exponential_backoff, )(prompts) diff --git a/langfun/core/llms/openai.py b/langfun/core/llms/openai.py index cfe47b02..13166fcf 100644 --- a/langfun/core/llms/openai.py +++ b/langfun/core/llms/openai.py @@ -169,8 +169,8 @@ def _open_ai_completion(prompts): openai_error.RateLimitError, ), max_attempts=self.max_attempts, - retry_interval=(1, 60), - exponential_backoff=True, + retry_interval=self.retry_interval, + exponential_backoff=self.exponential_backoff, )(prompts) def _chat_complete_batch( @@ -204,8 +204,8 @@ def _open_ai_chat_completion(prompt): openai_error.RateLimitError, ), max_attempts=self.max_attempts, - retry_interval=(1, 60), - exponential_backoff=True, + retry_interval=self.retry_interval, + exponential_backoff=self.exponential_backoff, )