diff --git a/langfun/core/eval/v2/checkpointing.py b/langfun/core/eval/v2/checkpointing.py index 8c6f163..d09b000 100644 --- a/langfun/core/eval/v2/checkpointing.py +++ b/langfun/core/eval/v2/checkpointing.py @@ -92,7 +92,7 @@ def save_state(example: Example): ) ) writer.add(example) - del writer + writer.close() runner.background_run(save_state, example) def _file_prefix_and_ext(self, filename: str) -> tuple[str, str]: @@ -128,6 +128,8 @@ def on_run_abort( ) -> None: with self._lock: if self._sequence_writer is not None: + for writer in self._sequence_writer.values(): + writer.close() self._sequence_writer.clear() def on_run_complete( @@ -174,6 +176,9 @@ def on_experiment_complete( assert experiment.id in self._sequence_writer with self._lock: if self._sequence_writer is not None: + # Make sure the writer is closed without delay so the file will be + # available immediately. + self._sequence_writer[experiment.id].close() del self._sequence_writer[experiment.id] def on_example_complete( @@ -207,9 +212,13 @@ def add(self, example: Example): return self._sequence_writer.add(example_blob) - def __del__(self): + def close(self): # Make sure there is no write in progress. with self._lock: - assert self._sequence_writer is not None + if self._sequence_writer is None: + return self._sequence_writer.close() self._sequence_writer = None + + def __del__(self): + self.close()