
Commit

use cached_property instead
winglian committed Nov 21, 2024
1 parent 838b74d commit ccf0a3c
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions src/axolotl/utils/models.py
@@ -2,10 +2,12 @@
 
 # pylint: disable=too-many-lines
 import gc
+import importlib
 import logging
 import math
 import os
 import types
+from functools import cached_property
 from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
 
 import addict
@@ -409,7 +411,7 @@ def apply_patches(self) -> None:
             )
 
         if self.cfg.is_llama_derived_model:
-            self.patch_loss()
+            self.patch_loss_llama()
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
 
@@ -451,27 +453,34 @@ def patch_attention(self) -> None:
 
             replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
 
-    def patch_loss(self) -> None:
+    @cached_property
+    def flash_attn(self) -> bool:
+        """Check if flash attention is installed"""
+        return importlib.util.find_spec("flash_attn") is not None
+
+    def patch_loss_llama(self) -> None:
         """
         Patch loss functions
         """
-        from axolotl.monkeypatch.llama_attn_hijack_flash import (
-            patch_llama_cross_entropy,
-            patch_llama_rms_norm,
-        )
+        if self.flash_attn:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_llama_cross_entropy,
+                patch_llama_rms_norm,
+            )
 
-        if self.cfg.flash_attn_cross_entropy:
+        if self.cfg.flash_attn_cross_entropy and self.flash_attn:
             patch_llama_cross_entropy()
-        if self.cfg.flash_attn_rms_norm:
+        elif self.cfg.unsloth_cross_entropy_loss:
+            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
+
+            integrate_cross_entropy_loss_patch(model_type="llama")
+
+        if self.cfg.flash_attn_rms_norm and self.flash_attn:
             patch_llama_rms_norm()
         elif self.cfg.unsloth_rms_norm:
             from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
 
             patch_unsloth_layernorm()
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
 
@@ -481,6 +490,7 @@ def patch_llama_derived_model(self) -> None:
         """
         Modify all llama derived models in one block
         """
+        self.patch_loss_llama()
 
         if self.cfg.flash_attention:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -528,16 +538,6 @@ def patch_llama_derived_model(self) -> None:
                 "Shifted-sparse attention not currently implemented without flash attention."
             )
 
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
-
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
-
-            patch_self_attn_lora()
-
     def set_auto_model_loader(self) -> None:
         """set self.AutoModelLoader
         - default value: AutoModelForCausalLM (set at __init__)
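Aside: a minimal, self-contained sketch of the pattern this commit adopts, assuming the goal is simply to compute the flash_attn availability check once per loader instance. The Loader class, its method body, and the prints below are illustrative stand-ins, not axolotl's actual code.

# Sketch of the cached_property pattern introduced by this commit.
# "Loader" and its patch_loss_llama body are hypothetical stand-ins.
import importlib.util
from functools import cached_property


class Loader:
    @cached_property
    def flash_attn(self) -> bool:
        """Check once whether the flash_attn package is importable."""
        return importlib.util.find_spec("flash_attn") is not None

    def patch_loss_llama(self) -> None:
        # The availability check runs at most once per instance;
        # later accesses reuse the cached boolean.
        if self.flash_attn:
            print("flash_attn found; flash-attention loss patches would apply here")
        else:
            print("flash_attn not installed; skipping flash-attention patches")


if __name__ == "__main__":
    loader = Loader()
    loader.patch_loss_llama()
    loader.patch_loss_llama()  # reuses the cached result; find_spec is not called again

Compared with a plain property, functools.cached_property stores the computed value in the instance's __dict__ on first access, so repeated calls to the patch methods do not repeat the import-machinery lookup.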

