diff --git a/python/ctranslate2/extensions.py b/python/ctranslate2/extensions.py
index 5abafb327..120cdba02 100644
--- a/python/ctranslate2/extensions.py
+++ b/python/ctranslate2/extensions.py
@@ -345,7 +345,7 @@ def generator_generate_tokens(
 
 async def generator_async_generate_tokens(
     generator: Generator,
-    start_tokens: Union[List[str], List[List[str]]],
+    prompt: Union[List[str], List[List[str]]],
     max_batch_size: int = 0,
     batch_type: str = "examples",
     *,
@@ -366,7 +366,7 @@ async def generator_async_generate_tokens(
     """Yields tokens asynchronously as they are generated by the model.
 
     Arguments:
-      start_tokens: Batch of start tokens. If the decoder starts from a
+      prompt: Batch of start tokens. If the decoder starts from a
        special start token like <s>, this token should be added to this input.
       max_batch_size: The maximum batch size.
       batch_type: Whether `max_batch_size` is the number of “examples” or “tokens”.
@@ -396,11 +396,11 @@ async def generator_async_generate_tokens(
       Note: This generation method is not compatible with beam search which requires
         a complete decoding.
     """
-    if len(start_tokens) > 0 and isinstance(start_tokens[0], str):
-        start_tokens = [start_tokens]
+    if len(prompt) > 0 and isinstance(prompt[0], str):
+        prompt = [prompt]
     async for step_result in AsyncGenerator(
         generator.generate_batch,
-        start_tokens,
+        prompt,
         max_batch_size=max_batch_size,
         batch_type=batch_type,
         repetition_penalty=repetition_penalty,