Skip to content

Commit

Permalink
Store example index when auditing.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 687164141
  • Loading branch information
daiyip authored and langfun authors committed Oct 18, 2024
1 parent 7d1ffee commit 6821206
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 16 deletions.
6 changes: 6 additions & 0 deletions langfun/core/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,5 +921,11 @@ def _progress_control(
raise ValueError(f'Unsupported progress bar type: {progress_bar}')


def get_executor(
resource_id: str,
max_workers: int | None = None) -> concurrent.futures.ThreadPoolExecutor:
"""Gets a thread pool executor associated with a resource id."""
return _executor_pool.get(resource_id, max_workers)

# The global executor pool based on resource IDs.
_executor_pool = ExecutorPool()
18 changes: 11 additions & 7 deletions langfun/core/eval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,7 @@ def _dryrun(
)
error = e

copy.audit(example, output_message, error, dryrun=True)
copy.audit(1, example, output_message, error, dryrun=True)
result = copy.finalize()

if verbose:
Expand Down Expand Up @@ -1124,19 +1124,20 @@ def _run(
with lf.use_settings(debug=debug, cache=self.cache):
self._reset()

def _process(example: Any):
def _process(idx_and_example: Any):
# NOTE(daiyip): set the `input` symbol of the globals to None, so LLM
# generated code with calls to `input` will raise an error, thus not
# blocking the evaluation.
_, example = idx_and_example
with lf_coding.context(input=None):
output_message = self.process(example, **(self.additional_args or {}))
self.process_output(example, output_message)
return output_message

try:
for example, message, error in lf.concurrent_map(
for (idx, example), message, error in lf.concurrent_map(
_process,
examples,
enumerate(examples),
max_workers=self.max_workers,
show_progress=progress_bar or False,
status_fn=self._status,
Expand All @@ -1148,7 +1149,7 @@ def _process(example: Any):
if isinstance(error, lf_structured.MappingError)
else None
)
self.audit(example, message, error)
self.audit(idx + 1, example, message, error)
finally:
# Save cache upon completion or interruption.
if self.dir and self.cache:
Expand Down Expand Up @@ -1437,6 +1438,7 @@ def _format_rate(self, rate: float) -> str:

def audit(
self,
example_idx: int,
example: Any,
message: lf.Message | None,
error: Exception | None = None,
Expand All @@ -1445,6 +1447,7 @@ def audit(
"""Audits the example against the output. Subclasses should override.
Args:
example_idx: 1-based index of the example in its dataset.
example: The input object.
message: The entire message returned by the LM, which could be used to
trace the LM input, response and parsed structure. If error is raised
Expand All @@ -1465,7 +1468,7 @@ def audit(
else:
assert message is not None
output = message.text if self.schema is None else message.result
self.audit_processed(example, output, message, dryrun=dryrun)
self.audit_processed(example_idx, example, output, message, dryrun=dryrun)

# Audit usage.
if message is not None:
Expand All @@ -1482,7 +1485,8 @@ def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
self._num_usages += 1

def audit_processed(
self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
self, example_idx: int, example: Any, output: Any, message: lf.Message,
dryrun: bool = False
) -> None:
"""Audits a successfully processed example. Subclass should override."""

Expand Down
17 changes: 9 additions & 8 deletions langfun/core/eval/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def answer(self, output: Any, example: Any) -> Any:
"""Returns the answer from the structure output."""

@property
def matches(self) -> list[tuple[Any, Any, lf.Message]]:
"""Returns the matches examples, outputs and the output messages."""
def matches(self) -> list[tuple[int, Any, Any, lf.Message]]:
"""Returns the matches IDs, examples, outputs and the output messages."""
return self._matches

@property
Expand All @@ -57,7 +57,7 @@ def match_rate(self) -> float:
return self.num_matches / self.num_completed

@property
def mismatches(self) -> list[tuple[Any, Any, lf.Message]]:
def mismatches(self) -> list[tuple[int, Any, Any, lf.Message]]:
"""Returns the mismatches examples, outputs and output messages."""
return self._mismatches

Expand Down Expand Up @@ -87,7 +87,8 @@ def _reset(self) -> None:
self._mismatches = []

def audit_processed(
self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
self, example_idx: int, example: Any, output: Any, message: lf.Message,
dryrun: bool = False
) -> None:
groundtruth = self.groundtruth(example)
answer = self.answer(output, example)
Expand All @@ -107,9 +108,9 @@ def audit_processed(
)

if self.match(answer, groundtruth):
self._matches.append((example, output, message))
self._matches.append((example_idx, example, output, message))
else:
self._mismatches.append((example, output, message))
self._mismatches.append((example_idx, example, output, message))

def match(self, answer: Any, groundtruth: Any) -> bool:
"""Matches answer against the groundtruth. Subclasses can override."""
Expand Down Expand Up @@ -247,7 +248,7 @@ def _maybe_html(v, root_indent: int):
# Fall back to the default format.
return None

for i, (example, output, message) in enumerate(self.matches):
for i, (_, example, output, message) in enumerate(self.matches):
bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
input_str = lf.repr_utils.escape_quoted(
Expand Down Expand Up @@ -282,7 +283,7 @@ def _render_mismatches(self, s: io.StringIO) -> None:
'</tr>'
)

for i, (example, output, message) in enumerate(self.mismatches):
for i, (_, example, output, message) in enumerate(self.mismatches):
bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
input_str = pg.format(example, verbose=False, max_bytes_len=32)
Expand Down
4 changes: 3 additions & 1 deletion langfun/core/eval/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ def _reset(self) -> None:
self._scored = []

def audit_processed(
self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
self, example_idx: int, example: Any, output: Any, message: lf.Message,
dryrun: bool = False
) -> None:
del example_idx
score = self.score(example, output)

if dryrun:
Expand Down

0 comments on commit 6821206

Please sign in to comment.