make sure to call unsloth patches even if not packed #2097

Closed · wants to merge 1 commit
src/axolotl/utils/models.py (44 changes: 22 additions & 22 deletions)
@@ -2,10 +2,12 @@

 # pylint: disable=too-many-lines
 import gc
+import importlib
 import logging
 import math
 import os
 import types
+from functools import cached_property
 from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401

 import addict
@@ -409,7 +411,7 @@ def apply_patches(self) -> None:
            )

        if self.cfg.is_llama_derived_model:
-            self.patch_loss()
+            self.patch_loss_llama()
            if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
                from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

@@ -451,27 +453,34 @@ def patch_attention(self) -> None:

        replace_stablelm_attn_with_flash_attn(self.cfg.base_model)

-    def patch_loss(self) -> None:
+    @cached_property
+    def flash_attn(self) -> bool:
+        """Check if flash attention is installed"""
+        return importlib.util.find_spec("flash_attn") is not None
+
+    def patch_loss_llama(self) -> None:
        """
        Patch loss functions
        """
-        from axolotl.monkeypatch.llama_attn_hijack_flash import (
-            patch_llama_cross_entropy,
-            patch_llama_rms_norm,
-        )
+        if self.flash_attn:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_llama_cross_entropy,
+                patch_llama_rms_norm,
+            )

-        if self.cfg.flash_attn_cross_entropy:
+        if self.cfg.flash_attn_cross_entropy and self.flash_attn:
            patch_llama_cross_entropy()
-        if self.cfg.flash_attn_rms_norm:
+        elif self.cfg.unsloth_cross_entropy_loss:
+            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
+
+            integrate_cross_entropy_loss_patch(model_type="llama")
+
+        if self.cfg.flash_attn_rms_norm and self.flash_attn:
            patch_llama_rms_norm()
        elif self.cfg.unsloth_rms_norm:
            from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm

            patch_unsloth_layernorm()
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

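The new flash_attn helper in this hunk only checks whether the flash_attn package is importable before the flash-attn based patches are selected. A minimal standalone sketch of the same detection pattern, assuming nothing beyond the standard library (the class and names below are illustrative, not part of the PR):

import importlib.util
from functools import cached_property


class PatcherSketch:
    """Illustrative stand-in for the model-loader class touched in this diff."""

    @cached_property
    def flash_attn(self) -> bool:
        # find_spec returns None when the package cannot be found,
        # so this is a cheap availability check that never imports flash_attn.
        # cached_property means the lookup runs at most once per instance.
        return importlib.util.find_spec("flash_attn") is not None


if __name__ == "__main__":
    sketch = PatcherSketch()
    print("flash_attn available:", sketch.flash_attn)
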
@@ -481,6 +490,7 @@ def patch_llama_derived_model(self) -> None:
        """
        Modify all llama derived models in one block
        """
+        self.patch_loss_llama()

        if self.cfg.flash_attention:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -528,16 +538,6 @@ def patch_llama_derived_model(self) -> None:
                    "Shifted-sparse attention not currently implemented without flash attention."
                )

-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
-
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
-
-            patch_self_attn_lora()
-
    def set_auto_model_loader(self) -> None:
        """set self.AutoModelLoader
        - default value: AutoModelForCausalLM (set at __init__)
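Taken together, the diff moves the unsloth cross-entropy and LoRA patches into the renamed patch_loss_llama, which apply_patches now calls for any llama-derived model, so they no longer depend on reaching patch_llama_derived_model. The reworked method prefers the flash-attn kernels when they are requested and installed and otherwise can still fall back to the unsloth patches. A rough sketch of that selection order, using a plain dict in place of the real cfg object (illustrative only, not the PR's API):

from typing import Dict, List


def select_loss_patches(cfg: Dict[str, bool], flash_attn_installed: bool) -> List[str]:
    """Mirror the if/elif ordering of the new patch_loss_llama (sketch only)."""
    patches: List[str] = []

    # Cross-entropy: the flash-attn kernel wins when requested and installed,
    # otherwise the unsloth patch can still be applied.
    if cfg.get("flash_attn_cross_entropy") and flash_attn_installed:
        patches.append("flash_attn_cross_entropy")
    elif cfg.get("unsloth_cross_entropy_loss"):
        patches.append("unsloth_cross_entropy_loss")

    # RMSNorm follows the same pattern.
    if cfg.get("flash_attn_rms_norm") and flash_attn_installed:
        patches.append("flash_attn_rms_norm")
    elif cfg.get("unsloth_rms_norm"):
        patches.append("unsloth_rms_norm")

    # The LoRA QKV/O patch is applied independently of the loss patches.
    if cfg.get("unsloth_lora_qkv") or cfg.get("unsloth_lora_o"):
        patches.append("unsloth_self_attn_lora")

    return patches


# Example: the unsloth patches are selected even when flash-attn is absent.
print(select_loss_patches(
    {"unsloth_cross_entropy_loss": True, "unsloth_rms_norm": True},
    flash_attn_installed=False,
))
# -> ['unsloth_cross_entropy_loss', 'unsloth_rms_norm']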