From 16f7fbda9b14c39e411ce175ef2ee2fb9e02b0ad Mon Sep 17 00:00:00 2001
From: Louka Ewington-Pitsos
Date: Thu, 18 Jul 2024 21:08:31 +0000
Subject: [PATCH 1/2] stop dropping samples every batch

---
 sae/data.py                      | 35 +++++++++++++++++++++++++------
 tests/test_chunk_and_tokenize.py | 36 ++++++++++++++++++++++++++++++++
 2 files changed, 65 insertions(+), 6 deletions(-)
 create mode 100644 tests/test_chunk_and_tokenize.py

diff --git a/sae/data.py b/sae/data.py
index 21859ad..23ad7ec 100644
--- a/sae/data.py
+++ b/sae/data.py
@@ -1,5 +1,6 @@
 """Tools for tokenizing and manipulating text datasets."""
 
+import os
 import math
 from multiprocessing import cpu_count
 from typing import TypeVar, Union
@@ -20,6 +21,7 @@ def chunk_and_tokenize(
     max_seq_len: int = 2048,
     return_final_batch: bool = False,
     load_from_cache_file: bool = True,
+    batch_size: int = 2048,
 ) -> T:
     """Perform GPT-style chunking and tokenization on a dataset.
...
         The chunked and tokenized dataset.
     """
 
-    def _tokenize_fn(x: dict[str, list]):
+    def _tokenize_fn(x: dict[str, list], leftovers: list=[]):
         chunk_size = min(tokenizer.model_max_length, max_seq_len)
         sep = tokenizer.eos_token or "<|endoftext|>"
         joined_text = sep.join([""] + x[text_key])
@@ -67,10 +69,30 @@ def _tokenize_fn(x: dict[str, list]):
         ]
 
         output = {"input_ids": chunks}
 
-        if not return_final_batch:
-            # We know that the last sample will almost always be less than the max
-            # number of tokens, and we don't want to pad, so we just drop it.
-            output = {k: v[:-1] for k, v in output.items()}
+        if (not return_final_batch) and len(output["input_ids"][-1]) != chunk_size:
+            # We do not pad, so if the final chunk is shorter than chunk_size we
+            # either extend it with stored leftovers or set it aside as a
+            # leftover for a later batch.
+            final_chunk = output["input_ids"].pop()
+
+            while len(final_chunk) < chunk_size:
+                if len(leftovers) == 0:
+                    leftovers.append(final_chunk)
+                    break
+
+                leftover = leftovers.pop()
+                final_chunk.extend([tokenizer.eos_token_id] + leftover)
+            else:
+                new_leftover = final_chunk[chunk_size:]
+                final_chunk = final_chunk[:chunk_size]
+                output["input_ids"].append(final_chunk)
+
+                if len(new_leftover) > 0:
+                    leftovers.append(new_leftover)
+
+            output = {k: v[:len(output['input_ids'])] for k, v in output.items()}
+
+
 
         output_batch_size = len(output["input_ids"])
@@ -89,10 +111,11 @@ def _tokenize_fn(x: dict[str, list]):
         # since we always throw away the last element of the batch we
         # want to keep the batch size as large as possible
         batched=True,
-        batch_size=2048,
+        batch_size=batch_size,
         num_proc=num_proc,
         remove_columns=get_columns_all_equal(data),
         load_from_cache_file=load_from_cache_file,
+        fn_kwargs={} if return_final_batch else {"leftovers": []}
     )
     return data.with_format(format, columns=["input_ids"])
diff --git a/tests/test_chunk_and_tokenize.py b/tests/test_chunk_and_tokenize.py
new file mode 100644
index 0000000..5fc35ee
--- /dev/null
+++ b/tests/test_chunk_and_tokenize.py
@@ -0,0 +1,36 @@
+import pytest
+from transformers import GPT2TokenizerFast
+from datasets import Dataset
+from sae.data import chunk_and_tokenize # Replace 'mymodule' with the actual module name
+
+@pytest.fixture
+def setup_data():
+    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+    data = Dataset.from_dict({
+        "text": ["This is a very short sentence.",] * 2000 + \
+        ["This is a sentence which is just a little bit longer, you better have a way of dealing with it homeslice."] * 3
+
+    })
+    return tokenizer, data
+
+def test_chunk_and_tokenize(setup_data):
+    tokenizer, data = setup_data
+
+    # Perform chunking and tokenization
+    max_seq_len = 10  # Setting a small max_seq_len for testing overflow
+    tokenized_data = chunk_and_tokenize(
+        data,
+        tokenizer,
+        max_seq_len=max_seq_len,
+        num_proc=2,
+        batch_size=32,
+    )
+
+    # Verify the output
+    input_ids = tokenized_data["input_ids"]
+    input_id_lengths = [len(ids) for ids in input_ids]
+
+    assert all([l == max_seq_len for l in input_id_lengths]), f"All input_ids should have length max_seq_len, got {input_id_lengths}"
+    assert len(input_ids[-1]) <= max_seq_len, "Last input_ids should be <= max_seq_len"
+    assert len(input_ids) >= 1610, f"Expected at least 1610 input_ids, got {len(input_ids)}"
+

From 9e43ce3e39dcba003df96af8c9449bc5b5937b83 Mon Sep 17 00:00:00 2001
From: Louka Ewington-Pitsos
Date: Thu, 18 Jul 2024 21:11:06 +0000
Subject: [PATCH 2/2] cleanup

---
 sae/data.py                      | 2 --
 tests/test_chunk_and_tokenize.py | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/sae/data.py b/sae/data.py
index 23ad7ec..649ba9d 100644
--- a/sae/data.py
+++ b/sae/data.py
@@ -1,6 +1,5 @@
 """Tools for tokenizing and manipulating text datasets."""
 
-import os
 import math
 from multiprocessing import cpu_count
 from typing import TypeVar, Union
@@ -93,7 +92,6 @@ def _tokenize_fn(x: dict[str, list], leftovers: list=[]):
 
         output = {k: v[:len(output['input_ids'])] for k, v in output.items()}
 
-
         output_batch_size = len(output["input_ids"])
 
         if output_batch_size == 0:
diff --git a/tests/test_chunk_and_tokenize.py b/tests/test_chunk_and_tokenize.py
index 5fc35ee..b872171 100644
--- a/tests/test_chunk_and_tokenize.py
+++ b/tests/test_chunk_and_tokenize.py
@@ -1,7 +1,7 @@
 import pytest
 from transformers import GPT2TokenizerFast
 from datasets import Dataset
-from sae.data import chunk_and_tokenize # Replace 'mymodule' with the actual module name
+from sae.data import chunk_and_tokenize
 
 @pytest.fixture
 def setup_data():
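
For context, a minimal usage sketch of chunk_and_tokenize as patched above. It mirrors the setup in the new test; the "gpt2" tokenizer, the toy dataset, and the printed summary are illustrative assumptions rather than part of the patch.

    from datasets import Dataset
    from transformers import GPT2TokenizerFast

    from sae.data import chunk_and_tokenize

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    data = Dataset.from_dict({"text": ["This is a very short sentence."] * 2000})

    # batch_size is the keyword added in patch 1; undersized final chunks from each
    # mapped batch are now carried over as "leftovers" instead of being dropped.
    tokenized = chunk_and_tokenize(
        data,
        tokenizer,
        max_seq_len=10,
        batch_size=32,
    )

    # Every emitted row has exactly max_seq_len input_ids.
    print(len(tokenized), len(tokenized[0]["input_ids"]))

Previously the shorter-than-max_seq_len chunk at the end of every mapped batch was discarded; with the leftover handling above, the same data yields more rows, which is what the new test's >= 1610 assertion checks.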