diff --git a/langfun/core/concurrent.py b/langfun/core/concurrent.py index 62055ad..f63f305 100644 --- a/langfun/core/concurrent.py +++ b/langfun/core/concurrent.py @@ -101,7 +101,7 @@ def with_retry( Sequence[Union[Type[Exception], Tuple[Exception, str]]], ], max_attempts: int, - retry_interval: int | tuple[int, int] = (1, 60), + retry_interval: int | tuple[int, int] = (5, 60), exponential_backoff: bool = True, seed: int | None = None, ) -> Callable[..., Any]: @@ -135,10 +135,10 @@ def base_interval() -> int: assert isinstance(retry_interval, int) return retry_interval - def next_wait_interval(last_wait_interval: int | None) -> int: - if last_wait_interval is None or not exponential_backoff: - return base_interval() - return last_wait_interval * 2 + def next_wait_interval(attempt: int) -> float: + if not exponential_backoff: + attempt = 1 + return base_interval() * (2 ** (attempt - 1)) wait_interval = None wait_intervals = [] @@ -150,7 +150,7 @@ def next_wait_interval(last_wait_interval: int | None) -> int: # Branch when errors are met for retry. errors.append(error_context.error) if len(errors) < max_attempts: - wait_interval = next_wait_interval(wait_interval) + wait_interval = next_wait_interval(len(errors)) wait_intervals.append(wait_interval) pg.logging.warning( @@ -175,7 +175,7 @@ def concurrent_execute( None, ] = None, max_attempts: int = 5, - retry_interval: int | tuple[int, int] = (1, 60), + retry_interval: int | tuple[int, int] = (5, 60), exponential_backoff: bool = True, ) -> list[Any]: """Executes a function concurrently under current component context. 
@@ -221,14 +221,33 @@ class Job: arg: Any result: Any = pg.MISSING_VALUE error: Exception | None = None + start_time: float | None = None + end_time: float | None = None def __call__(self) -> Any: + self.start_time = time.time() try: self.result = self.func(self.arg) return self.result except Exception as e: # pylint: disable=broad-exception-caught self.error = e return e + finally: + self.end_time = time.time() + + def mark_canceled(self, error: Exception) -> None: + """Marks the job as canceled.""" + self.error = error + self.end_time = time.time() + + @property + def elapse(self) -> float: + """Returns the running time in seconds since the job started.""" + if self.start_time is None: + return 0.0 + if self.end_time is None: + return time.time() - self.start_time + return self.end_time - self.start_time def concurrent_map( @@ -249,7 +268,7 @@ def concurrent_map( None, ] = None, max_attempts: int = 5, - retry_interval: int | tuple[int, int] = (1, 60), + retry_interval: int | tuple[int, int] = (5, 60), exponential_backoff: bool = True, ) -> Iterator[tuple[Any, Any, Exception | None]]: """Maps inputs to outptus via func concurrently under current context. 
@@ -297,13 +316,6 @@ def concurrent_map( exponential_backoff=exponential_backoff, ) - start_time = time.time() - - def remaining_time(): - if timeout is None: - return None - return time.time() - start_time - executor = executor or concurrent.futures.ThreadPoolExecutor( max_workers=max_workers) @@ -330,13 +342,13 @@ def update_progress(success: int, failure: int) -> None: % (success * 100.0 / completed, success, completed, failure * 100.0 / completed, failure, completed)) - remaining_futures = [] success, failure = 0, 0 if ordered: - for i, future in enumerate(pending_futures): + for future in pending_futures: + job = future_to_job[future] + wait_time = (timeout - job.elapse) if timeout else None try: - _ = future.result(timeout=remaining_time()) - job = future_to_job[future] + _ = future.result(timeout=wait_time) if job.error is not None: failure += 1 if not ( @@ -344,37 +356,56 @@ def update_progress(success: int, failure: int) -> None: raise job.error else: success += 1 - update_progress(success, failure) - del future_to_job[future] - yield job.arg, job.result, job.error except concurrent.futures.TimeoutError: - remaining_futures = pending_futures[i:] - break - else: - for future in concurrent.futures.as_completed( - pending_futures, timeout=remaining_time() - ): - job = future_to_job[future] - del future_to_job[future] - if job.error is not None: + future.cancel() + job.mark_canceled( + TimeoutError(f'Execution time ({job.elapse}) ' + f'exceeds {timeout} seconds.')) failure += 1 - if not ( - silence_on_errors and isinstance(job.error, silence_on_errors) - ): - raise job.error # pylint: disable=g-doc-exception - else: - success += 1 update_progress(success, failure) yield job.arg, job.result, job.error - remaining_futures = future_to_job - - # Flush pending requests. 
- for future in remaining_futures: - job = future_to_job[future] - if not future.done(): - future.cancel() - job.error = TimeoutError(f'Execution time exceeds {timeout} seconds.') - yield job.arg, job.result, job.error + else: + while pending_futures: + completed_batch = set() + try: + for future in concurrent.futures.as_completed( + pending_futures, timeout=timeout): + job = future_to_job[future] + del future_to_job[future] + if job.error is not None: + failure += 1 + if not ( + silence_on_errors and isinstance(job.error, silence_on_errors) + ): + raise job.error # pylint: disable=g-doc-exception + else: + success += 1 + update_progress(success, failure) + yield job.arg, job.result, job.error + completed_batch.add(future) + except concurrent.futures.TimeoutError: + pass + + remaining_futures = [] + + # When timeout is None, all futures shall be completed through the loop + # above. + if timeout is not None: + for future in pending_futures: + if future in completed_batch: + continue + job = future_to_job[future] + if job.elapse > timeout: + if not future.done(): + future.cancel() + job.mark_canceled( + TimeoutError(f'Execution time ({job.elapse}) ' + f'exceeds {timeout} seconds.')) + yield job.arg, job.result, job.error + else: + remaining_futures.append(future) + + pending_futures = remaining_futures if progress is not None: progress.close() diff --git a/langfun/core/concurrent_test.py b/langfun/core/concurrent_test.py index f730175..0e5d119 100644 --- a/langfun/core/concurrent_test.py +++ b/langfun/core/concurrent_test.py @@ -283,25 +283,64 @@ def fun(x): with self.assertRaises(ValueError): _ = next(it) - def test_concurrent_map_with_timeout(self): + def test_concurrent_map_with_order_and_timeout(self): def fun(x): time.sleep(3 - x) return x - with component.context(y=0.9): - self.assertEqual( - [ - (i, o) - for i, o, _ in concurrent.concurrent_map( - fun, [1, 2, 3], ordered=True, timeout=2 - ) - ], - [ - (1, pg.MISSING_VALUE), - (2, pg.MISSING_VALUE), - (3, 3), - 
], - ) + self.assertEqual( + [ + (i, o) + for i, o, _ in concurrent.concurrent_map( + fun, [1, 2, 3], ordered=True, timeout=1.5 + ) + ], + [ + (1, pg.MISSING_VALUE), + (2, 2), + (3, 3), + ], + ) + + def test_concurent_map_unordered_with_timeout(self): + def fun(x): + time.sleep(x) + return x + + self.assertEqual( + [ + (i, o) + for i, o, _ in concurrent.concurrent_map( + fun, [5, 2, 1, 4], timeout=3 + ) + ], + [ + (1, 1), + (2, 2), + (5, pg.MISSING_VALUE), + (4, pg.MISSING_VALUE), + ], + ) + + def test_concurent_map_unordered_with_timeout_less_worker(self): + def fun(x): + time.sleep(x) + return x + + self.assertEqual( + [ + (i, o) + for i, o, _ in concurrent.concurrent_map( + fun, [5, 2, 1, 4], timeout=3, max_workers=1 + ) + ], + [ + (5, pg.MISSING_VALUE), + (2, 2), + (1, 1), + (4, pg.MISSING_VALUE), + ], + ) def test_concurrent_map_with_showing_progress(self): def fun(x):