Add Exact Deduplication Feature to Preprocessing Pipeline #2072

Merged (12 commits) on Dec 2, 2024
3 changes: 3 additions & 0 deletions docs/config.qmd
@@ -161,6 +161,9 @@ datasets:
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true

# Deduplicates datasets and test_datasets with identical entries.
dataset_exact_deduplication: true

# A list of one or more datasets to eval the model with.
# You can use either test_datasets, or val_set_size, but not both.
test_datasets:
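As context for the new option, the example configs added below exercise it by listing the same dataset twice; a minimal (hypothetical, not part of this PR) fragment looks like the following, where exact-duplicate rows in the merged dataset are dropped before training:

```yaml
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca

# Remove exact-duplicate rows from the merged datasets (and test_datasets).
dataset_exact_deduplication: true
```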
95 changes: 95 additions & 0 deletions examples/llama-3/lora-1b-deduplicate-dpo.yml
@@ -0,0 +1,95 @@
base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

chat_template: llama3
rl: dpo
datasets:
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant

dataset_exact_deduplication: true
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/lora-out

sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
76 changes: 76 additions & 0 deletions examples/llama-3/lora-1b-deduplicate-sft.yml
@@ -0,0 +1,76 @@
base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/lora-out

dataset_exact_deduplication: true
test_value: true

sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
2 changes: 1 addition & 1 deletion src/axolotl/cli/__init__.py
@@ -136,7 +136,7 @@ def check_remote_config(config: Union[str, Path]):
        with open(output_path, "wb") as file:
            file.write(content)
        LOG.info(
-            f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
+            f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
        )
        return output_path

1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -606,6 +606,7 @@ class Config:
        default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
    )
    dataset_processes: Optional[int] = Field(default=os.cpu_count())
    dataset_exact_deduplication: Optional[bool] = None
    dataset_keep_in_memory: Optional[bool] = None
    dataloader_pin_memory: Optional[bool] = None
    dataloader_num_workers: Optional[int] = None
7 changes: 6 additions & 1 deletion src/axolotl/utils/data/rl.py
@@ -13,7 +13,7 @@
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
-from axolotl.utils.data.utils import md5
+from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
@@ -145,4 +145,9 @@ def load_split(dataset_cfgs, _cfg):
        if eval_dataset and not eval_is_preprocessed:
            _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)

    if cfg.dataset_exact_deduplication:
        train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
            train_dataset=train_dataset, eval_dataset=eval_dataset
        )

    return train_dataset, eval_dataset
21 changes: 14 additions & 7 deletions src/axolotl/utils/data/sft.py
@@ -44,7 +44,7 @@
    UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
-from axolotl.utils.data.utils import md5
+from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
@@ -136,8 +136,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
        # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
        train_dataset = train_dataset.with_format("torch")
        eval_dataset = None
        if cfg.dataset_exact_deduplication:
            LOG.info("Deduplication not available for pretrained datasets")
        return train_dataset, eval_dataset, cfg.max_steps, prompters

    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
        if total_eval_steps == 0:
@@ -178,7 +179,7 @@ def load_tokenized_prepared_datasets(
+ "|".join(
sorted(
[
f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
f"{d.path}: {d.type}: {d.shards}: {d.conversation}{d.split}"
olivermolenschot marked this conversation as resolved.
Show resolved Hide resolved
for d in cfg_datasets
]
)
@@ -571,7 +572,8 @@ def load_prepare_datasets(
        )
        train_fingerprint = md5(to_hash_train)
        test_fingerprint = md5(to_hash_test)

        if cfg.dataset_exact_deduplication:
            _, _, dataset = deduplicate_and_log_datasets(dataset=dataset)
        dataset = dataset.train_test_split(
            test_size=val_set_size,
            shuffle=False,
@@ -583,12 +585,17 @@
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
elif split == "test":
if cfg.dataset_exact_deduplication:
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
else:
eval_dataset = dataset
train_dataset = None
eval_dataset = dataset
else:
train_dataset = dataset
if cfg.dataset_exact_deduplication:
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
else:
train_dataset = dataset
eval_dataset = None

return train_dataset, eval_dataset, prompters


98 changes: 98 additions & 0 deletions src/axolotl/utils/data/utils.py
@@ -1,10 +1,108 @@
"""data handling helpers"""

import hashlib
import logging

from datasets import Dataset

LOG = logging.getLogger("axolotl")


def md5(to_hash: str, encoding: str = "utf-8") -> str:
try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
except TypeError:
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec


def sha256(to_hash: str, encoding: str = "utf-8") -> str:
return hashlib.sha256(to_hash.encode(encoding)).hexdigest()


def deduplicate_dataset(
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
) -> Dataset:
unique_indices = []

for idx, row in enumerate(dataset):
row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
if row_hash not in seen_hashes:
seen_hashes[row_hash] = [idx]
unique_indices.append(idx)
else:
# Check for collision by looking up the original dataset indices
original_indices = seen_hashes[row_hash]
is_duplicate = False
for original_idx in original_indices:
if (
not idx == original_idx
and original_idx < len(dataset)
and str(dataset[original_idx]) == str(row)
):
is_duplicate = True
break
# Check in the other dataset if provided
if other_dataset is not None:
if original_idx < len(other_dataset) and str(
other_dataset[original_idx]
) == str(row):
is_duplicate = True
break
if not is_duplicate:
seen_hashes[row_hash].append(idx)
unique_indices.append(idx)
continue
return dataset.select(unique_indices)


def deduplicate_and_log_datasets(
*,
train_dataset: Dataset = None,
eval_dataset: Dataset = None,
dataset: Dataset = None,
) -> tuple[Dataset, Dataset, Dataset]:
"""
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.

Returns:
tuple: Deduplicated train, eval, and additional datasets.
"""
seen_hashes: dict[str, list[int]] = {}

# Handle cases where datasets are None
if train_dataset is not None:
LOG.info(
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
)
train_dataset = deduplicate_dataset(
dataset=train_dataset, seen_hashes=seen_hashes
)
LOG.info(
f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
)
else:
LOG.info("Train dataset is None. Skipping deduplication.")

if eval_dataset is not None:
LOG.info(
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
)
eval_dataset = deduplicate_dataset(
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
)
LOG.info(
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
)
else:
LOG.info("Eval dataset is None. Skipping deduplication.")

if dataset is not None and (eval_dataset is None and train_dataset is None):
LOG.info(
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
)
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
LOG.info(
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
)

return train_dataset, eval_dataset, dataset
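For readers skimming the diff, here is a minimal standalone sketch (not part of the PR) of how the new helper behaves; the toy rows and the expected count are illustrative only, and it mirrors the single-dataset call pattern used by `load_prepare_datasets`:

```python
from datasets import Dataset

from axolotl.utils.data.utils import deduplicate_and_log_datasets

# Toy dataset containing one exact-duplicate row.
ds = Dataset.from_list(
    [
        {"instruction": "say hi", "output": "hi"},
        {"instruction": "say hi", "output": "hi"},  # exact duplicate of the row above
        {"instruction": "say bye", "output": "bye"},
    ]
)

# Same call pattern as the SFT path: dedupe a single merged dataset.
_, _, deduped = deduplicate_and_log_datasets(dataset=ds)
print(len(deduped))  # expected: 2
```

Rows are keyed by the SHA-256 of their string form, and a hash hit is confirmed with a direct string comparison to rule out collisions, so only byte-identical rows are removed.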