From ccf0a3c2be19b540a5d934a624bf8ec9d0c71892 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 21 Nov 2024 12:03:34 -0500
Subject: [PATCH] use cached_property instead

---
 src/axolotl/utils/models.py | 44 ++++++++++++++++++++++----------------------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 082df7c27..5bba615e1 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -2,10 +2,12 @@
 # pylint: disable=too-many-lines
 
 import gc
+import importlib
 import logging
 import math
 import os
 import types
+from functools import cached_property
 from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
 
 import addict
@@ -409,7 +411,7 @@ def apply_patches(self) -> None:
             )
 
         if self.cfg.is_llama_derived_model:
-            self.patch_loss()
+            self.patch_loss_llama()
 
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -451,27 +453,34 @@ def patch_attention(self) -> None:
 
                 replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
 
-    def patch_loss(self) -> None:
+    @cached_property
+    def flash_attn(self) -> bool:
+        """Check if flash attention is installed"""
+        return importlib.util.find_spec("flash_attn") is not None
+
+    def patch_loss_llama(self) -> None:
         """
         Patch loss functions
         """
-        from axolotl.monkeypatch.llama_attn_hijack_flash import (
-            patch_llama_cross_entropy,
-            patch_llama_rms_norm,
-        )
+        if self.flash_attn:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_llama_cross_entropy,
+                patch_llama_rms_norm,
+            )
 
-        if self.cfg.flash_attn_cross_entropy:
+        if self.cfg.flash_attn_cross_entropy and self.flash_attn:
             patch_llama_cross_entropy()
-        if self.cfg.flash_attn_rms_norm:
+        elif self.cfg.unsloth_cross_entropy_loss:
+            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
+
+            integrate_cross_entropy_loss_patch(model_type="llama")
+
+        if self.cfg.flash_attn_rms_norm and self.flash_attn:
             patch_llama_rms_norm()
         elif self.cfg.unsloth_rms_norm:
             from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
 
             patch_unsloth_layernorm()
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
 
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -481,6 +490,7 @@ def patch_llama_derived_model(self) -> None:
         """
         Modify all llama derived models in one block
        """
+        self.patch_loss_llama()
 
         if self.cfg.flash_attention:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -528,16 +538,6 @@ def patch_llama_derived_model(self) -> None:
                 "Shifted-sparse attention not currently implemented without flash attention."
             )
 
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
-
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
-
-            patch_self_attn_lora()
-
     def set_auto_model_loader(self) -> None:
         """set self.AutoModelLoader
         - default value: AutoModelForCausalLM (set at __init__)
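
Note (appended for illustration, not part of the patch): a minimal standalone sketch of the functools.cached_property pattern the diff adopts for the flash_attn check. The LoaderSketch class name is hypothetical; only the flash_attn property mirrors the patched code. The point is that importlib.util.find_spec("flash_attn") runs once per instance, and later accesses reuse the cached result instead of re-probing the environment.

import importlib.util
from functools import cached_property


class LoaderSketch:
    """Illustrative stand-in for the model loader; not the axolotl class itself."""

    @cached_property
    def flash_attn(self) -> bool:
        # Evaluated only on first access; the result is stored in the
        # instance __dict__ and returned directly on later accesses.
        print("checking for flash_attn ...")
        return importlib.util.find_spec("flash_attn") is not None


loader = LoaderSketch()
print(loader.flash_attn)  # performs the find_spec lookup
print(loader.flash_attn)  # served from the cache; no second lookup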