From 86c6474d6b58392386352548b371bae3e6f5ee19 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 21 Nov 2024 12:50:10 -0500 Subject: [PATCH] make sure flash attn is available before attempting to patch --- src/axolotl/utils/models.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index a4f71733fe..6285d99a6a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -2,6 +2,7 @@ # pylint: disable=too-many-lines import gc +import importlib import logging import math import os @@ -451,27 +452,34 @@ def patch_attention(self) -> None: replace_stablelm_attn_with_flash_attn(self.cfg.base_model) + @property + def flash_attn(self) -> bool: + """Check if flash attention is installed""" + return importlib.util.find_spec("flash_attn") is not None + def patch_loss_llama(self) -> None: """ Patch loss functions """ - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - patch_llama_cross_entropy, - patch_llama_rms_norm, - ) + if self.flash_attn: + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + patch_llama_cross_entropy, + patch_llama_rms_norm, + ) - if self.cfg.flash_attn_cross_entropy: + if self.cfg.flash_attn_cross_entropy and self.flash_attn: patch_llama_cross_entropy() - if self.cfg.flash_attn_rms_norm: + elif self.cfg.unsloth_cross_entropy_loss: + from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch + + integrate_cross_entropy_loss_patch(model_type="llama") + + if self.cfg.flash_attn_rms_norm and self.flash_attn: patch_llama_rms_norm() elif self.cfg.unsloth_rms_norm: from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm patch_unsloth_layernorm() - if self.cfg.unsloth_cross_entropy_loss: - from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch - - integrate_cross_entropy_loss_patch(model_type="llama") if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora