From 91104073f0c58d3b3e675f3ad5afd1a8472e0403 Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Mon, 28 Aug 2023 14:36:37 +0200
Subject: [PATCH] Restore the original batch_id before calling the callback

---
Reviewer notes (this section is not part of the commit):

* This copy of the patch had lost its line breaks, its indentation and
  every angle-bracketed span. The layout below was rebuilt so that each
  hunk matches the line counts in its "@@" header. The exact indentation
  is a best guess: TODO confirm against the upstream commit, or apply
  with whitespace differences ignored.
* The restored template arguments are "typename Options", "size_t" and
  "bool(GenerationStepResult)". The last one follows from the lambda
  signature and from the test callback returning a boolean; the first two
  are assumptions. TODO confirm against Batch::example_index and the
  declaration of the options' callback member.
* The author's e-mail address on the "From:" line was stripped as well and
  has not been reinvented here.
* restore_batch_ids_in_callback() captures example_index by reference, so
  the returned options must not outlive that vector. Both call sites pass
  batch.example_index and hand the result straight to generate() /
  translate(); this is safe provided those calls do not keep the callback
  after they return (presumably the case; verify).

 include/ctranslate2/generation.h | 16 ++++++++++++++++
 python/tests/test_translator.py  | 30 ++++++++++++++++++++++++++++++
 src/generator.cc                 |  4 +++-
 src/translator.cc                |  4 +++-
 4 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/include/ctranslate2/generation.h b/include/ctranslate2/generation.h
index 9bd4e70c1..f09aeef45 100644
--- a/include/ctranslate2/generation.h
+++ b/include/ctranslate2/generation.h
@@ -111,6 +111,22 @@ namespace ctranslate2 {
     }
   };
 
+  template <typename Options>
+  Options restore_batch_ids_in_callback(Options options, const std::vector<size_t>& example_index) {
+    if (options.callback) {
+      std::function<bool(GenerationStepResult)> wrapped_callback =
+        [&example_index, callback = std::move(options.callback)]
+        (GenerationStepResult step_result) {
+          step_result.batch_id = example_index[step_result.batch_id];
+          return callback(std::move(step_result));
+        };
+
+      options.callback = std::move(wrapped_callback);
+    }
+
+    return options;
+  }
+
   class ResolveEndToken {
   private:
     const Vocabulary& _vocabulary;
diff --git a/python/tests/test_translator.py b/python/tests/test_translator.py
index f89537b00..c64189226 100644
--- a/python/tests/test_translator.py
+++ b/python/tests/test_translator.py
@@ -228,6 +228,36 @@ def _callback(step_result):
     assert len(hypotheses) == 3
 
 
+def test_callback_batch_id():
+    # The method will internally sort the input from longest to shortest,
+    # but we check that the returned batch ids match the user input.
+
+    source = [
+        ["ن"] * 1,
+        ["ن"] * 2,
+        ["ن"] * 3,
+    ]
+
+    target_prefix = [
+        ["a"],
+        ["b"],
+        ["c"],
+    ]
+
+    def _callback(step_result):
+        assert step_result.token == target_prefix[step_result.batch_id][0]
+        return True
+
+    translator = _get_transliterator()
+    translator.translate_batch(
+        source,
+        target_prefix,
+        max_batch_size=2,
+        beam_size=1,
+        callback=_callback,
+    )
+
+
 def test_file_translation(tmp_dir):
     input_path = str(tmp_dir.join("input.txt"))
     output_path = str(tmp_dir.join("output.txt"))
diff --git a/src/generator.cc b/src/generator.cc
index 8f07ce094..8d8c9a709 100644
--- a/src/generator.cc
+++ b/src/generator.cc
@@ -15,7 +15,9 @@ namespace ctranslate2 {
       batch_type,
       [options](models::SequenceGeneratorReplica& generator, const Batch& batch) {
         spdlog::debug("Running batch generation on {} examples", batch.num_examples());
-        auto results = generator.generate(batch.get_stream(0), options);
+        auto results = generator.generate(
+          batch.get_stream(0),
+          restore_batch_ids_in_callback(options, batch.example_index));
         spdlog::debug("Finished batch generation");
         return results;
       });
diff --git a/src/translator.cc b/src/translator.cc
index 011e145f5..e1d76ed07 100644
--- a/src/translator.cc
+++ b/src/translator.cc
@@ -184,7 +184,9 @@ namespace ctranslate2 {
                                    const Batch& batch,
                                    const TranslationOptions& options) {
     spdlog::debug("Running batch translation on {} examples", batch.num_examples());
-    auto results = model.translate(batch.get_stream(0), batch.get_stream(1), options);
+    auto results = model.translate(batch.get_stream(0),
+                                   batch.get_stream(1),
+                                   restore_batch_ids_in_callback(options, batch.example_index));
     spdlog::debug("Finished batch translation");
     return results;
   }