Merge pull request #540 from Anhforth/master
Added aquila2 finetuning
ftgreat authored Oct 7, 2023
2 parents bf17821 + 41e724a commit e1e0116
Showing 9 changed files with 521 additions and 27 deletions.
59 changes: 36 additions & 23 deletions flagai/auto_model/auto_loader.py
@@ -4,8 +4,10 @@
import importlib
import os
import copy
from transformers import AutoTokenizer
import transformers
import math
from flagai.model.file_utils import _get_model_id, _get_checkpoint_path, _get_vocab_path, _get_model_files
from flagai.model.aquila2.modeling_aquila import AquilaForCausalLM
import torch

class LazyImport(object):
@@ -169,7 +171,8 @@ def __init__(self,
                 low_cpu_mem_usage=True,
                 lora_dir=None,
                 qlora_dir=None,
                 quantization_config=None,
                 inference_mode=True,
                 model_max_length=None,
                 **kwargs):
        """
        Args:
@@ -205,6 +208,7 @@ def __init__(self,
            print(f"All supported models are {list(MODEL_DICT.keys())}")
            return
        if task_name == "aquila2":
            from flagai.model.aquila2.modeling_aquila import AquilaForCausalLM
            download_path = os.path.join(model_dir, model_name)

            if not os.path.exists(download_path):
@@ -261,28 +265,37 @@ def __init__(self,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch_dtype,
                )
            if inference_mode:
                model = AquilaForCausalLM.from_pretrained(download_path,low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype,
                                                          quantization_config=quantization_config)
                model.eval()
                if not qlora_dir:
                    model.to(device)
                if lora_dir:
                    from flagai.model.tools.peft import PeftModel
                    model = PeftModel.from_pretrained(model, lora_dir)
                    print("lora modules loaded")
                if qlora_dir:
                    from flagai.model.tools.peft import PeftModel
                    model = PeftModel.from_pretrained(model, qlora_dir)
                    print("Qlora modules loaded")
            else:
                # Set RoPE scaling factor
                config = transformers.AutoConfig.from_pretrained(
                    download_path,
                    cache_dir=kwargs['cache_dir'],
                    trust_remote_code=True,
                )
                orig_ctx_len = getattr(config, "max_position_embeddings", None)
                if orig_ctx_len and model_max_length > orig_ctx_len:
                    scaling_factor = float(
                        math.ceil(model_max_length / orig_ctx_len))
                    config.rope_scaling = {"type": "linear", "factor": scaling_factor}
                config.use_cache = False
                model = AquilaForCausalLM.from_pretrained(download_path,
                                                          **kwargs)


            model = AquilaForCausalLM.from_pretrained(download_path,
                                                      low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype,
                                                      quantization_config=quantization_config)

            model.eval()
            # from accelerate import load_checkpoint_and_dispatch
            # model = load_checkpoint_and_dispatch(
            #     model, model_dir+model_name, device_map="balanced", no_split_module_classes=["LlamaDecoderLayer"])
            if not qlora_dir:
                model.to(device)
            if lora_dir:
                from flagai.model.tools.peft import PeftModel
                model = PeftModel.from_pretrained(model, lora_dir)
                print("lora modules loaded")
            if qlora_dir:
                from flagai.model.tools.peft import PeftModel
                model = PeftModel.from_pretrained(model, qlora_dir)
                print("Qlora modules loaded")
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_dir+model_name)
            tokenizer = AutoTokenizer.from_pretrained(download_path)
            self.model = model
            self.tokenizer = tokenizer
        else:
Empty file.
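
For orientation, a minimal usage sketch of the new loading paths added above (a sketch only: the model name, cache_dir value, and the get_model()/get_tokenizer() accessors are assumed from FlagAI's existing AutoLoader API rather than shown in this diff):

from flagai.auto_model.auto_loader import AutoLoader

# Inference path: inference_mode=True (the default) loads the checkpoint in eval
# mode and optionally attaches LoRA/QLoRA adapters via lora_dir / qlora_dir.
loader = AutoLoader(task_name="aquila2",
                    model_name="AquilaChat2-7B",   # assumed model name
                    inference_mode=True)
model = loader.get_model()
tokenizer = loader.get_tokenizer()

# Finetuning path: inference_mode=False. When model_max_length exceeds the
# checkpoint's max_position_embeddings (e.g. 8192 vs. 2048), the loader sets
# linear RoPE scaling with factor = ceil(8192 / 2048) = 4.0 before loading.
loader = AutoLoader(task_name="aquila2",
                    model_name="AquilaChat2-7B",
                    inference_mode=False,
                    model_max_length=8192,
                    cache_dir="./checkpoints")     # read via kwargs['cache_dir']
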
165 changes: 165 additions & 0 deletions flagai/model/aquila2/aquila2_flash_attn_monkey_patch.py
@@ -0,0 +1,165 @@
"""
Copied from https://github.com/lm-sys/FastChat.
Later we will contribute our changes into it.
"""

import warnings
from typing import Optional, Tuple

import torch
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import (
    flash_attn_func,
    flash_attn_varlen_kvpacked_func,
)
from flagai.model.aquila2.modeling_aquila import (
    AquilaAttention,
    AquilaModel,
    rotate_half,
)


def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
    gather_indices = position_ids[:, :, None, None]  # [bsz, seq_len, 1, 1]
    gather_indices = gather_indices.repeat(
        1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
    )
    bsz = gather_indices.shape[0]
    cos, sin = (
        torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
        for x in cos_sin
    )
    q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
    return q, k


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `AquilaAttention`, returning `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()
    kv_heads = getattr(self, "num_key_value_heads", self.num_heads)

    q, k, v = (
        op(hidden_states).view(bsz, q_len, nh, self.head_dim)
        for op, nh in (
            (self.q_proj, self.num_heads),
            (self.k_proj, kv_heads),
            (self.v_proj, kv_heads),
        )
    )
    # shape: (b, s, num_heads, head_dim)

    kv_seq_len = k.shape[1]
    past_kv_len = 0
    if past_key_value is not None:
        past_kv_len = past_key_value[0].shape[1]
        kv_seq_len += past_kv_len

    cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
    q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)

    if past_key_value is not None:
        # reuse k, v
        k = torch.cat([past_key_value[0], k], dim=1)
        v = torch.cat([past_key_value[1], v], dim=1)

    past_key_value = (k, v) if use_cache else None

    key_padding_mask = attention_mask
    # Ideally we could just do this:
    #   q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask[:, -q_len:])
    # but this does not work as Flash attention treats the q seq and kv seq as starting at index 0
    # which then breaks the causality logic. Probably if q_len >> past_kv_len we should
    # just skip flash attention. Leaving this in for now to demonstrate correctness of
    # flash attention information even when q needs padding.
    # TODO(siddartha): delegate back to original implementation on this condition.
    if past_kv_len > 0:
        q = torch.cat(
            (
                torch.full(
                    (bsz, past_kv_len, self.num_heads, self.head_dim),
                    0.0,
                    dtype=q.dtype,
                    device=q.device,
                ),
                q,
            ),
            dim=1,
        )

    if key_padding_mask is None:
        output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
            bsz, q_len + past_kv_len, -1
        )
    else:
        q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask)
        # We can skip concat and call unpad twice but seems better to call unpad only once.
        kv, _, cu_k_lens, max_k = unpad_input(
            torch.stack((k, v), dim=2), key_padding_mask
        )
        output_unpad = flash_attn_varlen_kvpacked_func(
            q,
            kv,
            cu_q_lens,
            cu_k_lens,
            max_s,
            max_k,
            0.0,
            softmax_scale=None,
            causal=True,
        )
        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
        output = pad_input(output_unpad, indices, bsz, q_len + past_kv_len)

    # Need to strip off the zero query outputs.
    if past_kv_len > 0:
        output = output[:, past_kv_len:, ...]

    return self.o_proj(output), None, past_key_value

# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    if past_key_values_length > 0 and attention_mask is not None:
        attention_mask = torch.cat(
            (
                torch.full(
                    (input_shape[0], past_key_values_length),
                    True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                ),
                attention_mask,
            ),
            dim=-1,
        )

    if attention_mask is not None and torch.all(attention_mask):
        return None  # This uses the faster call when training with full samples

    return attention_mask


def replace_aquila_attn_with_flash_attn():
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )

    AquilaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
    AquilaAttention.forward = forward
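
A minimal sketch of how this monkey patch would typically be enabled (assuming flash-attn is installed and the patch is applied before the model is built; the checkpoint path and dtype below are illustrative):

import torch
from flagai.model.aquila2.aquila2_flash_attn_monkey_patch import (
    replace_aquila_attn_with_flash_attn,
)
from flagai.model.aquila2.modeling_aquila import AquilaForCausalLM

# Patch AquilaAttention.forward and AquilaModel._prepare_decoder_attention_mask
# first, so every decoder layer of the model built below uses flash attention.
replace_aquila_attn_with_flash_attn()
model = AquilaForCausalLM.from_pretrained(
    "./checkpoints/aquila2-7b",      # illustrative path
    torch_dtype=torch.bfloat16,      # flash-attn kernels expect fp16/bf16
)
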
74 changes: 74 additions & 0 deletions flagai/model/aquila2/aquila_condense_monkey_patch.py
@@ -0,0 +1,74 @@
"""
Copied from https://github.com/lm-sys/FastChat.
Later we will contribute our changes into it.
"""

from functools import partial

import torch
import transformers
import flagai.model.aquila2.modeling_aquila


class CondenseRotaryEmbedding(torch.nn.Module):
    def __init__(
        self, dim, ratio, max_position_embeddings=2048, base=10000, device=None
    ):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.ratio = ratio
        max_position_embeddings *= ratio
        self.max_seq_len_cached = max_position_embeddings
        # print(f"Monkey Patching condense ratio {ratio}")
        t = (
            torch.arange(
                self.max_seq_len_cached,
                device=self.inv_freq.device,
                dtype=self.inv_freq.dtype,
            )
            / ratio
        )
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer(
            "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
        )

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = (
                torch.arange(
                    self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
                )
                / self.ratio
            )
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer(
                "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False
            )
            self.register_buffer(
                "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False
            )
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def replace_aquila_with_condense(ratio):
    flagai.model.aquila2.modeling_aquila.AquilaRotaryEmbedding = partial(
        CondenseRotaryEmbedding, ratio=ratio
    )
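
Likewise, a hedged sketch of applying the condense patch to stretch the RoPE position range (the ratio and checkpoint path are illustrative):

from flagai.model.aquila2.aquila_condense_monkey_patch import (
    replace_aquila_with_condense,
)
from flagai.model.aquila2.modeling_aquila import AquilaForCausalLM

# Swap in CondenseRotaryEmbedding before construction: positions are divided by
# the ratio, so a model pretrained with 2048 RoPE positions can address ~8192.
replace_aquila_with_condense(ratio=4)
model = AquilaForCausalLM.from_pretrained("./checkpoints/aquila2-7b")
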
Empty file modified flagai/model/aquila2/configuration_aquila.py
100644 → 100755
Empty file.
