add e2e tests for Unsloth qlora and test the builds #2093

Merged · 15 commits · Nov 30, 2024 · Changes from all commits
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -67,6 +67,7 @@ jobs:
         run: |
           pip3 show torch
           pip3 install -U -e .
+          python scripts/unsloth_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt

       - name: Run tests

[Review comment from a collaborator on the added line] Should this be added to pytest-sdist and to the nightly workflow too?
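For context: scripts/unsloth_install.py writes a single pip command to stdout, and piping it through sh executes that command, so the version-selection logic lives in one script shared by this workflow and both Dockerfiles. The emitted line looks something like pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[cu121-ampere-torch240]==2024.11.9" (that extras tag is illustrative only; the real one depends on the detected torch version, CUDA version, and GPU generation).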
2 changes: 2 additions & 0 deletions cicd/Dockerfile.jinja
@@ -37,6 +37,8 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi

+RUN python scripts/unsloth_install.py | sh
+
 # So we can test the Docker image
 RUN pip install -r requirements-dev.txt -r requirements-tests.txt
2 changes: 2 additions & 0 deletions docker/Dockerfile
@@ -26,6 +26,8 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi

+RUN python scripts/unsloth_install.py | sh
+
 # So we can test the Docker image
 RUN pip install pytest
7 changes: 5 additions & 2 deletions scripts/unsloth_install.py
@@ -8,7 +8,10 @@

 v = V(torch.__version__)
 cuda = str(torch.version.cuda)
-is_ampere = torch.cuda.get_device_capability()[0] >= 8
+try:
+    is_ampere = torch.cuda.get_device_capability()[0] >= 8
+except RuntimeError:
+    is_ampere = False
 if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
     raise RuntimeError(f"CUDA = {cuda} not supported!")
 if v <= V("2.1.0"):
@@ -29,5 +32,5 @@
     raise RuntimeError(f"Torch = {v} too new!")
 x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
 print(
-    f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
+    f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
 )
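The try/except is the point of this change: torch.cuda.get_device_capability() raises RuntimeError when no CUDA device is visible, which is exactly the situation inside docker build and on CPU-only CI runners, so the script previously crashed before it could print anything. Pinning unsloth-zoo and unsloth to fixed releases instead of git HEAD also makes the CI builds reproducible. A minimal standalone sketch of the guarded check (illustrative, not part of the repo):

import torch

try:
    # (major, minor) compute capability of device 0; Ampere and newer report major >= 8
    is_ampere = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
    # No visible GPU (e.g. `docker build`, CPU-only runner): fall back to the
    # non-Ampere wheel selection instead of failing the whole build.
    is_ampere = False

print(f"is_ampere={is_ampere}")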
34 changes: 26 additions & 8 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -4,7 +4,6 @@

 import logging
 import warnings
-from functools import partial
 from typing import List, Optional, Tuple, Union

 import torch
@@ -94,14 +93,33 @@ def replace_llama_qkv_with_fused(model):
         set_module_name(model, name, qkv)


-def patch_llama_cross_entropy():
-    from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-    LOG.info("patching with flash_attn.losses.cross_entropy")
-    transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-        CrossEntropyLoss, inplace_backward=True
+def patch_fa_llama_cross_entropy():
+    LOG.info(
+        "patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
     )
+    from flash_attn.ops.triton.cross_entropy import (
+        cross_entropy_loss as flash_attn_cross_entropy_loss,
+    )
+
+    def fa2_fixed_cross_entropy(
+        source,
+        target,
+        num_items_in_batch: int = None,
+        ignore_index: int = -100,
+        **kwargs,
+    ):  # pylint: disable=unused-argument
+        reduction = "sum" if num_items_in_batch is not None else "mean"
+        loss, _ = flash_attn_cross_entropy_loss(
+            source, target, ignore_index=ignore_index
+        )
+        if reduction == "sum":
+            loss = loss.sum() / num_items_in_batch
+        else:
+            loss = loss.sum() / (target != ignore_index).sum()
+        return loss
+
+    transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy


 def patch_llama_rms_norm():
     try:
@@ -147,7 +165,7 @@ def replace_llama_attn_with_flash_attn(

     # skip only if explicitly disabled
     if cross_entropy:
-        patch_llama_cross_entropy()
+        patch_fa_llama_cross_entropy()

     # skip only if explicitly disabled
     if rms_norm:
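The new patch targets transformers.loss.loss_utils.fixed_cross_entropy, the helper recent transformers versions route causal-LM loss through, instead of the old module-level CrossEntropyLoss symbol, and it reproduces both reduction paths: divide by num_items_in_batch when gradient accumulation supplies a token count ("sum"), otherwise a mean over non-ignored tokens. A hedged sanity check of the kernel's semantics (assumes a CUDA GPU and a flash-attn build that ships the Triton cross-entropy op; tolerances are illustrative):

import torch
import torch.nn.functional as F
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

logits = torch.randn(8, 128, device="cuda")
labels = torch.randint(0, 128, (8,), device="cuda")
labels[0] = -100  # one masked position, excluded from the loss

# the Triton op returns per-token losses (plus a z-loss term, ignored here)
per_token, _ = cross_entropy_loss(logits, labels, ignore_index=-100)

# "mean" path from the patch: sum over tokens / count of non-ignored tokens
mean_loss = per_token.sum() / (labels != -100).sum()
reference = F.cross_entropy(logits, labels, ignore_index=-100)
torch.testing.assert_close(mean_loss, reference, rtol=1e-3, atol=1e-3)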
46 changes: 23 additions & 23 deletions src/axolotl/utils/models.py
@@ -2,10 +2,12 @@

 # pylint: disable=too-many-lines
 import gc
+import importlib
 import logging
 import math
 import os
 import types
+from functools import cached_property
 from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401

 import addict
@@ -409,7 +411,7 @@ def apply_patches(self) -> None:
         )

         if self.cfg.is_llama_derived_model:
-            self.patch_loss()
+            self.patch_loss_llama()
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

@@ -451,27 +453,34 @@ def patch_attention(self) -> None:

         replace_stablelm_attn_with_flash_attn(self.cfg.base_model)

-    def patch_loss(self) -> None:
+    @cached_property
+    def has_flash_attn(self) -> bool:
+        """Check if flash attention is installed"""
+        return importlib.util.find_spec("flash_attn") is not None
+
+    def patch_loss_llama(self) -> None:
         """
         Patch loss functions
         """
-        from axolotl.monkeypatch.llama_attn_hijack_flash import (
-            patch_llama_cross_entropy,
-            patch_llama_rms_norm,
-        )
+        if self.has_flash_attn:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_fa_llama_cross_entropy,
+                patch_llama_rms_norm,
+            )

-        if self.cfg.flash_attn_cross_entropy:
-            patch_llama_cross_entropy()
-        if self.cfg.flash_attn_rms_norm:
+        if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
+            patch_fa_llama_cross_entropy()
+        elif self.cfg.unsloth_cross_entropy_loss:
+            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
+
+            integrate_cross_entropy_loss_patch(model_type="llama")
+
+        if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
             patch_llama_rms_norm()
         elif self.cfg.unsloth_rms_norm:
             from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm

             patch_unsloth_layernorm()
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

@@ -481,6 +490,7 @@ def patch_llama_derived_model(self) -> None:
         """
         Modify all llama derived models in one block
         """
+        self.patch_loss_llama()

         if self.cfg.flash_attention:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -528,16 +538,6 @@ def patch_llama_derived_model(self) -> None:
                 "Shifted-sparse attention not currently implemented without flash attention."
             )

-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
-
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
-
-            patch_self_attn_lora()
-
     def set_auto_model_loader(self) -> None:
         """set self.AutoModelLoader
         - default value: AutoModelForCausalLM (set at __init__)
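The has_flash_attn guard is a standard optional-dependency probe: importlib.util.find_spec locates the package without importing it, and cached_property memoizes the answer per loader instance. Note also that the unsloth cross-entropy and RMSNorm patches become elif fallbacks, so they can no longer be stacked on top of the flash-attn variants, and patch_llama_derived_model now delegates to patch_loss_llama rather than duplicating those branches. A standalone sketch of the probe pattern (class name is illustrative):

import importlib.util
from functools import cached_property


class DependencyProbe:
    """Illustrative stand-in for the loader's optional-dependency check."""

    @cached_property
    def has_flash_attn(self) -> bool:
        # find_spec searches sys.path for the package without importing it
        return importlib.util.find_spec("flash_attn") is not None


probe = DependencyProbe()
print(probe.has_flash_attn)  # first access computes; later accesses hit the cache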
47 changes: 30 additions & 17 deletions tests/e2e/patched/test_fa_xentropy.py
@@ -4,11 +4,11 @@

 import logging
 import os
-import unittest
 from importlib import reload
 from pathlib import Path

 import pytest
+from tbparse import SummaryReader
 from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.cli import load_datasets
@@ -17,7 +17,7 @@
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
+from ..utils import most_recent_subdir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -31,18 +31,20 @@ def reload_transformers():
     reload(transformers.models.llama.modeling_llama)


-class TestFAXentropyLlama(unittest.TestCase):
+class TestFAXentropyLlama:
     """
     Test case for Llama models using LoRA w multipack
     """

-    @with_temp_dir
-    def test_lora_packing_fa_cross_entropy(self, temp_dir):
+    @pytest.mark.parametrize(
+        "gradient_accumulation_steps",
+        [1, 4],
+    )
+    def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
                 "sequence_len": 1024,
                 "sample_packing": True,
                 "flash_attention": True,
@@ -55,25 +57,29 @@ def test_lora_packing_fa_cross_entropy(self, temp_dir):
                 "lora_target_linear": True,
                 "val_set_size": 0.2,
                 "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
+                    "pad_token": "<|endoftext|>",
                 },
+                "chat_template": "chatml",
                 "datasets": [
                     {
-                        "path": "mhenrichsen/alpaca_2k_test",
-                        "type": "alpaca",
+                        "path": "mlabonne/FineTome-100k",
+                        "field_messages": "conversations",
+                        "message_field_content": "value",
+                        "message_field_role": "from",
+                        "type": "chat_template",
+                        "split": "train[:2%]",
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 10,
-                "save_steps": 10,
-                "micro_batch_size": 8,
-                "gradient_accumulation_steps": 1,
+                "max_steps": 5,
+                "save_steps": 5,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": gradient_accumulation_steps,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
+                "optimizer": "adamw_8bit",
                 "lr_scheduler": "cosine",
+                "use_tensorboard": True,
             }
         )
         if is_torch_bf16_gpu_available():
@@ -87,3 +93,10 @@

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+        tb_log_path = most_recent_subdir(temp_dir + "/runs")
+        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
+        reader = SummaryReader(event_file)
+        df = reader.scalars  # pylint: disable=invalid-name
+        df = df[(df.tag == "train/train_loss")]  # pylint: disable=invalid-name
+        assert df.value.values[-1] < 1.5, "Loss is too high"
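The final assertion reads the last train/train_loss scalar back out of the TensorBoard event file that use_tensorboard: True produces, turning the e2e run into a loose convergence check rather than a pure smoke test. most_recent_subdir is imported from tests/e2e/utils; a hypothetical equivalent, to show what the helper has to do (the real implementation may differ):

import os


def most_recent_subdir(path: str) -> str:
    """Return the subdirectory of path with the newest modification time."""
    subdirs = [
        os.path.join(path, name)
        for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    ]
    return max(subdirs, key=os.path.getmtime)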