
Commit

fix for token bug that skips EOS (#815)
This will actually append EOS if the tokenizer doesn't add it automatically (HF tokenizers don't do this by default).
ahmeda14960 authored Nov 20, 2024
2 parents 80b2296 + 4eb4281 commit 8509037
Showing 1 changed file with 15 additions and 2 deletions.
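In short, the patch probes the tokenizer once and only appends `eos_token` by hand when the tokenizer won't do it during tokenization. A minimal standalone sketch of that check, using the standard `transformers` API (the function names here are illustrative, not part of the patch):

```python
from transformers import AutoTokenizer, PreTrainedTokenizerBase


def needs_manual_eos(tokenizer: PreTrainedTokenizerBase) -> bool:
    # Same probe the patch adds in mk_chat_sft_dataset: tokenize a throwaway
    # string and check whether EOS was already appended for us.
    probe_ids = tokenizer("hi there")["input_ids"]
    return probe_ids[-1] != tokenizer.eos_token_id


def append_eos_if_needed(targets: list[str], tokenizer: PreTrainedTokenizerBase) -> list[str]:
    # Same idea as the patch's change to preprocess_chat_example: add the EOS
    # string to each target only when the tokenizer won't do it itself.
    if needs_manual_eos(tokenizer):
        return [t + tokenizer.eos_token for t in targets]
    return targets


if __name__ == "__main__":
    # GPT-2 is just an illustrative choice; its tokenizer does not append EOS.
    tok = AutoTokenizer.from_pretrained("gpt2")
    print(needs_manual_eos(tok))                     # True
    print(append_eos_if_needed(["Hi there!"], tok))  # ['Hi there!<|endoftext|>']
```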
17 changes: 15 additions & 2 deletions src/levanter/data/text.py
@@ -929,15 +929,24 @@ def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]:
)


def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict:
def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase, should_append_eos: bool) -> dict:
"""
Preprocess chat examples to match the format of preprocess_supervised_example.
Returns a dict with input_ids and sources_len like the supervised case.
Args:
batch: List of dicts with input/output pairs
tokenizer: HuggingFace tokenizer
should_append_eos: Whether we need to manually add EOS (True if tokenizer doesn't do it automatically)
"""
# Get sources (inputs) and targets (outputs) from the batch
sources = [example["input"] for example in batch]
targets = [example["output"] for example in batch]

# Add EOS only if needed (tokenizer doesn't do it automatically)
if should_append_eos:
targets = [t + tokenizer.eos_token for t in targets]

# Tokenize sources alone first to get the source lengths
sources_tokenized = tokenizer(sources, padding=False, truncation=True)

@@ -965,9 +974,13 @@ def mk_chat_sft_dataset(
# Set up example structure matching supervised case
output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)}

input_ids = tokenizer("hi there")["input_ids"]
should_append_eos = input_ids[-1] != tokenizer.eos_token_id
logger.info(f"Manual EOS Needed: {should_append_eos}")

# Process the dataset
dataset = source.map_batches(
lambda ex: preprocess_chat_example(ex, tokenizer),
lambda ex: preprocess_chat_example(ex, tokenizer, should_append_eos),
batch_size=128,
num_cpus=num_cpus_used_by_tokenizer(tokenizer),
output_exemplar=output_exemplar,
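For completeness, a hedged usage sketch of the updated function, based only on the signature and docstring visible in this diff; the exact structure of the returned input_ids isn't shown in this hunk, so the final inspection is a print rather than an assertion:

```python
from transformers import AutoTokenizer

from levanter.data.text import preprocess_chat_example

tok = AutoTokenizer.from_pretrained("gpt2")  # GPT-2's tokenizer does not append EOS itself

# Mirror the probe done in mk_chat_sft_dataset above.
should_append_eos = tok("hi there")["input_ids"][-1] != tok.eos_token_id

batch = [{"input": "Say hi.", "output": "Hi there!"}]
processed = preprocess_chat_example(batch, tok, should_append_eos)

# Per the docstring, this holds input_ids and sources_len; with the fix,
# each tokenized example should now end with tok.eos_token_id.
print(processed)
```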
