diff --git a/langfun/core/concurrent.py b/langfun/core/concurrent.py index 6a96fa4..56e10c1 100644 --- a/langfun/core/concurrent.py +++ b/langfun/core/concurrent.py @@ -921,5 +921,11 @@ def _progress_control( raise ValueError(f'Unsupported progress bar type: {progress_bar}') +def get_executor( + resource_id: str, + max_workers: int | None = None) -> concurrent.futures.ThreadPoolExecutor: + """Gets a thread pool executor associated with a resource id.""" + return _executor_pool.get(resource_id, max_workers) + # The global executor pool based on resource IDs. _executor_pool = ExecutorPool() diff --git a/langfun/core/eval/base.py b/langfun/core/eval/base.py index ac54db6..c159098 100644 --- a/langfun/core/eval/base.py +++ b/langfun/core/eval/base.py @@ -1087,7 +1087,7 @@ def _dryrun( ) error = e - copy.audit(example, output_message, error, dryrun=True) + copy.audit(1, example, output_message, error, dryrun=True) result = copy.finalize() if verbose: @@ -1124,19 +1124,20 @@ def _run( with lf.use_settings(debug=debug, cache=self.cache): self._reset() - def _process(example: Any): + def _process(idx_and_example: Any): # NOTE(daiyip): set the `input` symbol of the globals to None, so LLM # generated code with calls to `input` will raise an error, thus not # blocking the evaluation. + _, example = idx_and_example with lf_coding.context(input=None): output_message = self.process(example, **(self.additional_args or {})) self.process_output(example, output_message) return output_message try: - for example, message, error in lf.concurrent_map( + for (idx, example), message, error in lf.concurrent_map( _process, - examples, + enumerate(examples), max_workers=self.max_workers, show_progress=progress_bar or False, status_fn=self._status, @@ -1148,7 +1149,7 @@ def _process(example: Any): if isinstance(error, lf_structured.MappingError) else None ) - self.audit(example, message, error) + self.audit(idx + 1, example, message, error) finally: # Save cache upon completion or interruption. if self.dir and self.cache: @@ -1437,6 +1438,7 @@ def _format_rate(self, rate: float) -> str: def audit( self, + example_idx: int, example: Any, message: lf.Message | None, error: Exception | None = None, @@ -1445,6 +1447,7 @@ def audit( """Audits the example against the output. Subclasses should override. Args: + example_idx: 1-based index of the example in its dataset. example: The input object. message: The entire message returned by the LM, which could be used to trace the LM input, response and parsed structure. If error is raised @@ -1465,7 +1468,7 @@ def audit( else: assert message is not None output = message.text if self.schema is None else message.result - self.audit_processed(example, output, message, dryrun=dryrun) + self.audit_processed(example_idx, example, output, message, dryrun=dryrun) # Audit usage. if message is not None: @@ -1482,7 +1485,8 @@ def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None: self._num_usages += 1 def audit_processed( - self, example: Any, output: Any, message: lf.Message, dryrun: bool = False + self, example_idx: int, example: Any, output: Any, message: lf.Message, + dryrun: bool = False ) -> None: """Audits a successfully processed example. Subclass should override.""" diff --git a/langfun/core/eval/matching.py b/langfun/core/eval/matching.py index ee1dd31..5fad74b 100644 --- a/langfun/core/eval/matching.py +++ b/langfun/core/eval/matching.py @@ -41,8 +41,8 @@ def answer(self, output: Any, example: Any) -> Any: """Returns the answer from the structure output.""" @property - def matches(self) -> list[tuple[Any, Any, lf.Message]]: - """Returns the matches examples, outputs and the output messages.""" + def matches(self) -> list[tuple[int, Any, Any, lf.Message]]: + """Returns the matches IDs, examples, outputs and the output messages.""" return self._matches @property @@ -57,7 +57,7 @@ def match_rate(self) -> float: return self.num_matches / self.num_completed @property - def mismatches(self) -> list[tuple[Any, Any, lf.Message]]: + def mismatches(self) -> list[tuple[int, Any, Any, lf.Message]]: """Returns the mismatches examples, outputs and output messages.""" return self._mismatches @@ -87,7 +87,8 @@ def _reset(self) -> None: self._mismatches = [] def audit_processed( - self, example: Any, output: Any, message: lf.Message, dryrun: bool = False + self, example_idx: int, example: Any, output: Any, message: lf.Message, + dryrun: bool = False ) -> None: groundtruth = self.groundtruth(example) answer = self.answer(output, example) @@ -107,9 +108,9 @@ def audit_processed( ) if self.match(answer, groundtruth): - self._matches.append((example, output, message)) + self._matches.append((example_idx, example, output, message)) else: - self._mismatches.append((example, output, message)) + self._mismatches.append((example_idx, example, output, message)) def match(self, answer: Any, groundtruth: Any) -> bool: """Matches answer against the groundtruth. Subclasses can override.""" @@ -247,7 +248,7 @@ def _maybe_html(v, root_indent: int): # Fall back to the default format. return None - for i, (example, output, message) in enumerate(self.matches): + for i, (_, example, output, message) in enumerate(self.matches): bgcolor = 'white' if i % 2 == 0 else '#DDDDDD' s.write(f'