# Patch overview (leading commentary; it precedes the first "diff --git" header
# and is not part of any hunk).
#
# NOTE(review): the hunks below appear to have had their line breaks and
#   indentation collapsed (e.g. the first hunk header declares 74 new lines but
#   its content sits on one physical line), so the patch presumably cannot be
#   applied as-is -- restore the original layout before use.
#
# langfun/core/concurrent.py
#   * Adds a `Progress` dataclass. It holds `total` plus private counters for
#     succeeded/failed jobs, the last error, the summed job duration and the
#     most recently recorded job. Read-only properties expose `succeeded`,
#     `failed`, `completed` (succeeded + failed), `last_error`, `job`,
#     `success_rate`, `failure_rate` and `avg_duration`; the two rates and the
#     average return 0.0 while `completed` is 0. `update(job)` records one
#     finished job: it counts a success when `job.error is None`, otherwise a
#     failure (also storing the error), and adds `job.elapse` to the duration.
#   * Adds a `status_fn` parameter (default None) to `concurrent_map`. It takes
#     a `Progress` and returns a dict that is rendered as "k: v" pairs, joined
#     by ", " and wrapped in "[...]", as the progress-bar description. When it
#     is falsy, a default lambda reports 'Succeeded' and 'Failed' as
#     "<pct>% (<count>/<completed>)".
#   * Replaces the closure `update_progress(success, failure, last_error)` and
#     its local counters with a `Progress` instance plus
#     `update_progress_bar(progress)`. The new postfix always carries
#     'AvgDuration' and, when an error was recorded, 'LastError' (its repr cut
#     to the first 64 characters plus '...' when 64 characters or longer).
#   * Renames the tqdm handle from `progress` to `progress_bar`.
#
# langfun/core/concurrent_test.py
#   * Adds `ProgressTest.test_progress`, reflows one comprehension, and adds
#     `test_concurrent_map_with_showing_progress_and_status_fn`, which passes a
#     custom `status_fn`.
#
# NOTE(review): `progress.update(job)` and `update_progress_bar(progress)` are
#   placed after each `yield`, whereas the removed `update_progress(...)` call
#   ran before it. The bar therefore advances only once the consumer asks for
#   the next item -- confirm this ordering is intended.
# NOTE(review): `postfix` is created with 'AvgDuration' already in it, so the
#   added `if postfix:` check is always true.
# NOTE(review): the hunk at -417 keeps `last_error = job.error` as context while
#   removing the `update_progress(...)` call that read it; this looks like a
#   dead assignment -- verify against the full file.
# NOTE(review): `ProgressTest.test_progress` runs `fun(1)`, which calls
#   `time.sleep(1)`, so the test takes at least about one second.
diff --git a/langfun/core/concurrent.py b/langfun/core/concurrent.py index ae159417..86d9724d 100644 --- a/langfun/core/concurrent.py +++ b/langfun/core/concurrent.py @@ -250,6 +250,74 @@ def elapse(self) -> float: return self.end_time - self.start_time +@dataclasses.dataclass +class Progress: + """Concurrent processing progress.""" + total: int + + _succeeded: int = 0 + _failed: int = 0 + _last_error: Exception | None = None + _total_duration: float = 0.0 + _job: Job | None = None + + @property + def succeeded(self) -> int: + """Returns number of succeeded jobs.""" + return self._succeeded + + @property + def failed(self) -> int: + """Returns number of failed jobs.""" + return self._failed + + @property + def completed(self) -> int: + """Returns number of completed jobs.""" + return self.succeeded + self.failed + + @property + def last_error(self) -> Exception | None: + """Returns last error.""" + return self._last_error + + @property + def job(self) -> Job | None: + """Returns current job.""" + return self._job + + @property + def success_rate(self) -> float: + """Returns success rate.""" + if self.completed == 0: + return 0.0 + return self.succeeded / self.completed + + @property + def failure_rate(self) -> float: + """Returns failure rate.""" + if self.completed == 0: + return 0.0 + return self.failed / self.completed + + @property + def avg_duration(self) -> float: + """Returns average duration each job worked.""" + if self.completed == 0: + return 0.0 + return self._total_duration / self.completed + + def update(self, job: Job) -> None: + """Mark a job as completed.""" + self._job = job + if job.error is None: + self._succeeded += 1 + else: + self._failed += 1 + self._last_error = job.error + self._total_duration += job.elapse + + def concurrent_map( func: Callable[[Any], Any], parallel_inputs: Iterable[Any], @@ -258,6 +326,7 @@ def concurrent_map( max_workers: int = 32, ordered: bool = False, show_progress: bool = False, + status_fn: Callable[[Progress], 
dict[str, Any]] | None = None, timeout: int | None = None, silence_on_errors: Union[ Type[Exception], Tuple[Type[Exception], ...], None @@ -283,6 +352,11 @@ def concurrent_map( the order of the elements in `parallel_inputs`. Otherwise, elements that are finished earlier will be delivered first. show_progress: If True, show progress on console. + status_fn: An optional callable object that receives a + `lf.concurrent.Progress` object and returns a dict of kv pairs as + the status to include in the progress bar. Applicable only when + `show_progress` is set to True. If None, the default status_fn will be + used, which outputs the success and failure rate. timeout: The timeout in seconds for processing each input. It is the total processing time for each input, even multiple retries take place. If None, there is no timeout. @@ -316,6 +390,13 @@ def concurrent_map( exponential_backoff=exponential_backoff, ) + status_fn = status_fn or (lambda p: { # pylint: disable=g-long-lambda + 'Succeeded': '%.2f%% (%d/%d)' % ( + p.success_rate * 100, p.succeeded, p.completed), + 'Failed': '%.2f%% (%d/%d)' % ( + p.failure_rate * 100, p.failed, p.completed), + }) + executor = executor or concurrent.futures.ThreadPoolExecutor( max_workers=max_workers) @@ -332,28 +413,31 @@ def concurrent_map( future_to_job[future] = job total += 1 - progress = tqdm.tqdm(total=total) if show_progress else None - def update_progress( - success: int, - failure: int, - last_error: Exception | None = None) -> None: - if progress is not None: - completed = success + failure - description = 'Success: %.2f%% (%d/%d), Failure: %.2f%% (%d/%d)' % ( - success * 100.0 / completed, success, completed, - failure * 100.0 / completed, failure, completed - ) - progress.set_description(description) - - if last_error is not None: - error_text = repr(last_error) + progress = Progress(total=total) + progress_bar = tqdm.tqdm(total=total) if show_progress else None + + def update_progress_bar(progress: Progress) -> None: + if 
progress_bar is not None: + status = status_fn(progress) + description = ', '.join([ + f'{k}: {v}' for k, v in status.items() + ]) + if description: + progress_bar.set_description(f'[{description}]') + + postfix = { + 'AvgDuration': '%.2f seconds' % progress.avg_duration + } + if progress.last_error is not None: + error_text = repr(progress.last_error) if len(error_text) >= 64: error_text = error_text[:64] + '...' - postfix = {'LastError': error_text} - progress.set_postfix(postfix) - progress.update(1) + postfix['LastError'] = error_text + + if postfix: + progress_bar.set_postfix(postfix) + progress_bar.update(1) - success, failure, last_error = 0, 0, None if ordered: for future in pending_futures: job = future_to_job[future] @@ -361,21 +445,17 @@ def update_progress( try: _ = future.result(timeout=wait_time) if job.error is not None: - last_error = job.error - failure += 1 if not ( silence_on_errors and isinstance(job.error, silence_on_errors)): raise job.error - else: - success += 1 except concurrent.futures.TimeoutError: future.cancel() last_error = TimeoutError( f'Execution time ({job.elapse}) exceeds {timeout} seconds.') job.mark_canceled(last_error) - failure += 1 - update_progress(success, failure, last_error) yield job.arg, job.result, job.error + progress.update(job) + update_progress_bar(progress) else: while pending_futures: completed_batch = set() @@ -385,16 +465,13 @@ def update_progress( job = future_to_job[future] del future_to_job[future] if job.error is not None: - last_error = job.error - failure += 1 if not ( silence_on_errors and isinstance(job.error, silence_on_errors) ): raise job.error # pylint: disable=g-doc-exception - else: - success += 1 - update_progress(success, failure, last_error) yield job.arg, job.result, job.error + progress.update(job) + update_progress_bar(progress) completed_batch.add(future) except concurrent.futures.TimeoutError: pass @@ -417,16 +494,14 @@ def update_progress( if job.error is not None: last_error = job.error 
- failure += 1 - else: - success += 1 - update_progress(success, failure, last_error) yield job.arg, job.result, job.error + progress.update(job) + update_progress_bar(progress) else: remaining_futures.append(future) pending_futures = remaining_futures - if progress is not None: - progress.close() + if progress_bar is not None: + progress_bar.close() diff --git a/langfun/core/concurrent_test.py b/langfun/core/concurrent_test.py index d5a9cd61..4dfd5841 100644 --- a/langfun/core/concurrent_test.py +++ b/langfun/core/concurrent_test.py @@ -162,6 +162,46 @@ def fun(a): self.assertEqual(concurrent.concurrent_execute(fun, [A(1), A(2)]), [2, 4]) +class ProgressTest(unittest.TestCase): + + def test_progress(self): + p = concurrent.Progress(total=10) + self.assertEqual(p.total, 10) + self.assertEqual(p.succeeded, 0) + self.assertEqual(p.failed, 0) + self.assertEqual(p.completed, 0) + self.assertEqual(p.success_rate, 0) + self.assertEqual(p.failure_rate, 0) + self.assertEqual(p.avg_duration, 0) + + def fun(x): + time.sleep(x) + return x + + def fun2(unused_x): + raise ValueError('Intentional error.') + + job1 = concurrent.Job(fun, 1) + job2 = concurrent.Job(fun2, 2) + job1() + job2() + + p.update(job1) + self.assertEqual(p.succeeded, 1) + self.assertEqual(p.failed, 0) + self.assertEqual(p.completed, 1) + self.assertEqual(p.success_rate, 1) + self.assertEqual(p.failure_rate, 0) + self.assertGreater(p.avg_duration, 0.5) + + p.update(job2) + self.assertEqual(p.succeeded, 1) + self.assertEqual(p.failed, 1) + self.assertEqual(p.completed, 2) + self.assertEqual(p.success_rate, 0.5) + self.assertEqual(p.failure_rate, 0.5) + + class ConcurrentMapTest(unittest.TestCase): def test_concurrent_map_raise_on_error(self): error = ValueError() @@ -351,8 +391,7 @@ def fun(x): self.assertEqual( [ - (i, o) - for i, o, _ in concurrent.concurrent_map( + (i, o) for i, o, _ in concurrent.concurrent_map( fun, [1, 2, 3], timeout=1.5, max_workers=1, show_progress=True ) ], @@ -363,6 +402,27 @@ def 
fun(x): ], ) + def test_concurrent_map_with_showing_progress_and_status_fn(self): + def fun(x): + if x == 2: + raise ValueError('Intentional error.') + time.sleep(x) + return x + + self.assertEqual( + [ + (i, o) for i, o, _ in concurrent.concurrent_map( + fun, [1, 2, 3], timeout=1.5, max_workers=1, + show_progress=True, status_fn=lambda p: dict(x=1, y=1) + ) + ], + [ + (1, 1), + (2, pg.MISSING_VALUE), + (3, pg.MISSING_VALUE), + ], + ) + if __name__ == '__main__': unittest.main()