make sure flash attn is available before attempting to patch
winglian committed Nov 21, 2024
1 parent 71b7a09 commit 86c6474
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions src/axolotl/utils/models.py
@@ -2,6 +2,7 @@
 
 # pylint: disable=too-many-lines
 import gc
+import importlib
 import logging
 import math
 import os
@@ -451,27 +452,34 @@ def patch_attention(self) -> None:
 
         replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
 
+    @property
+    def flash_attn(self) -> bool:
+        """Check if flash attention is installed"""
+        return importlib.util.find_spec("flash_attn") is not None
+
     def patch_loss_llama(self) -> None:
         """
         Patch loss functions
         """
-        from axolotl.monkeypatch.llama_attn_hijack_flash import (
-            patch_llama_cross_entropy,
-            patch_llama_rms_norm,
-        )
+        if self.flash_attn:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_llama_cross_entropy,
+                patch_llama_rms_norm,
+            )
 
-        if self.cfg.flash_attn_cross_entropy:
+        if self.cfg.flash_attn_cross_entropy and self.flash_attn:
             patch_llama_cross_entropy()
-        if self.cfg.flash_attn_rms_norm:
+        elif self.cfg.unsloth_cross_entropy_loss:
+            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
+
+            integrate_cross_entropy_loss_patch(model_type="llama")
+
+        if self.cfg.flash_attn_rms_norm and self.flash_attn:
             patch_llama_rms_norm()
         elif self.cfg.unsloth_rms_norm:
             from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
 
             patch_unsloth_layernorm()
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
 
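For context, the availability check the commit adds is built on importlib.util.find_spec, which looks up a module's import spec without importing the module, so it never raises ImportError when the package is missing. Below is a minimal standalone sketch of that pattern; the has_flash_attn helper name and the print calls are illustrative only and are not part of the commit.

import importlib.util


def has_flash_attn() -> bool:
    """Return True only when the flash_attn package is importable."""
    # find_spec() locates the module without importing it, so this check is
    # safe to run even when flash_attn is not installed.
    return importlib.util.find_spec("flash_attn") is not None


if has_flash_attn():
    # Safe to import flash_attn-backed patches here.
    print("flash_attn available; flash-attention patches can be applied")
else:
    # Skip the flash-attention patches and fall back to alternatives.
    print("flash_attn not installed; skipping flash-attention patches")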
