Update async_generate_token
jgcb00 committed Sep 6, 2023
1 parent 362ff5d commit e529676
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions python/ctranslate2/extensions.py
@@ -345,7 +345,7 @@ def generator_generate_tokens(
 
 async def generator_async_generate_tokens(
     generator: Generator,
-    start_tokens: Union[List[str], List[List[str]]],
+    prompt: Union[List[str], List[List[str]]],
     max_batch_size: int = 0,
     batch_type: str = "examples",
     *,
@@ -366,7 +366,7 @@ async def generator_async_generate_tokens(
     """Yields tokens asynchronously as they are generated by the model.
 
     Arguments:
-      start_tokens: Batch of start tokens. If the decoder starts from a
+      prompt: Batch of start tokens. If the decoder starts from a
         special start token like <s>, this token should be added to this input.
       max_batch_size: The maximum batch size.
       batch_type: Whether `max_batch_size` is the number of "examples" or "tokens".
@@ -396,11 +396,11 @@ async def generator_async_generate_tokens(
     Note:
       This generation method is not compatible with beam search which requires a complete decoding.
     """
-    if len(start_tokens) > 0 and isinstance(start_tokens[0], str):
-        start_tokens = [start_tokens]
+    if len(prompt) > 0 and isinstance(prompt[0], str):
+        prompt = [prompt]
     async for step_result in AsyncGenerator(
         generator.generate_batch,
-        start_tokens,
+        prompt,
         max_batch_size=max_batch_size,
         batch_type=batch_type,
         repetition_penalty=repetition_penalty,
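For context, here is a minimal usage sketch of the renamed argument through the async_generate_tokens method that extensions.py attaches to Generator. Only the prompt argument and the positional signature come from the diff above; the model directory, the example tokens, and the max_length keyword are illustrative assumptions.

import asyncio

import ctranslate2

async def main():
    # Hypothetical path to a converted CTranslate2 generator model.
    generator = ctranslate2.Generator("ct2_model_dir")

    # After this commit, the start tokens are passed as `prompt` rather than
    # `start_tokens`. A single example may be a flat list of tokens; the
    # function wraps it into a batch of one (see the isinstance check above).
    prompt = ["<s>", "▁Hello", "▁world"]

    # Tokens are yielded asynchronously as the model produces them.
    async for step_result in generator.async_generate_tokens(prompt, max_length=32):
        print(step_result.token, end="", flush=True)

asyncio.run(main())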
