Merge pull request #540 from Anhforth/master
Added aquila2 finetuning
Showing 9 changed files with 521 additions and 27 deletions.
flagai/model/aquila2/aquila2_flash_attn_monkey_patch.py
165 changes: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
"""
Copied from https://github.com/lm-sys/FastChat.
Later we will contribute our changes into it.
"""

import warnings
from typing import Optional, Tuple

import torch
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import (
    flash_attn_func,
    flash_attn_varlen_kvpacked_func,
)
from flagai.model.aquila2.modeling_aquila import (
    AquilaAttention,
    AquilaModel,
    rotate_half,
)


def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
    gather_indices = position_ids[:, :, None, None]  # [bsz, seq_len, 1, 1]
    gather_indices = gather_indices.repeat(
        1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
    )
    bsz = gather_indices.shape[0]
    cos, sin = (
        torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
        for x in cos_sin
    )
    q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
    return q, k


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `AquilaAttention`, returning `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()
    kv_heads = getattr(self, "num_key_value_heads", self.num_heads)

    q, k, v = (
        op(hidden_states).view(bsz, q_len, nh, self.head_dim)
        for op, nh in (
            (self.q_proj, self.num_heads),
            (self.k_proj, kv_heads),
            (self.v_proj, kv_heads),
        )
    )
    # shape: (b, s, num_heads, head_dim)

    kv_seq_len = k.shape[1]
    past_kv_len = 0
    if past_key_value is not None:
        past_kv_len = past_key_value[0].shape[1]
        kv_seq_len += past_kv_len

    cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
    q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)

    if past_key_value is not None:
        # reuse k, v
        k = torch.cat([past_key_value[0], k], dim=1)
        v = torch.cat([past_key_value[1], v], dim=1)

    past_key_value = (k, v) if use_cache else None

    key_padding_mask = attention_mask
    # Ideally we could just do this:
    # q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask[:, -q_len:])
    # but this does not work as Flash attention treats the q seq and kv seq as starting at index 0
    # which then breaks the causality logic. Probably if q_len >> past_kv_len we should
    # just skip flash attention. Leaving this in for now to demonstrate correctness of
    # flash attention information even when q needs padding.
    # TODO(siddartha): delegate back to original implementation on this condition.
    if past_kv_len > 0:
        q = torch.cat(
            (
                torch.full(
                    (bsz, past_kv_len, self.num_heads, self.head_dim),
                    0.0,
                    dtype=q.dtype,
                    device=q.device,
                ),
                q,
            ),
            dim=1,
        )

    if key_padding_mask is None:
        output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
            bsz, q_len + past_kv_len, -1
        )
    else:
        q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask)
        # We can skip concat and call unpad twice but seems better to call unpad only once.
        kv, _, cu_k_lens, max_k = unpad_input(
            torch.stack((k, v), dim=2), key_padding_mask
        )
        output_unpad = flash_attn_varlen_kvpacked_func(
            q,
            kv,
            cu_q_lens,
            cu_k_lens,
            max_s,
            max_k,
            0.0,
            softmax_scale=None,
            causal=True,
        )
        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
        output = pad_input(output_unpad, indices, bsz, q_len + past_kv_len)

    # Need to strip off the zero query outputs.
    if past_kv_len > 0:
        output = output[:, past_kv_len:, ...]

    return self.o_proj(output), None, past_key_value


# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    if past_key_values_length > 0 and attention_mask is not None:
        attention_mask = torch.cat(
            (
                torch.full(
                    (input_shape[0], past_key_values_length),
                    True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                ),
                attention_mask,
            ),
            dim=-1,
        )

    if attention_mask is not None and torch.all(attention_mask):
        return None  # This uses the faster call when training with full samples

    return attention_mask


def replace_aquila_attn_with_flash_attn():
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )

    AquilaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
    AquilaAttention.forward = forward
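
As a usage sketch (not shown in this diff), the patch is meant to be called before the Aquila2 model is constructed, so that the replaced `AquilaAttention.forward` and `AquilaModel._prepare_decoder_attention_mask` are already in place when the layers are built; how the model itself is loaded is left out here because the commit view does not show it.

    # Minimal sketch: enable the FlashAttention patch before building the model.
    # The module path follows the file added in this commit.
    from flagai.model.aquila2.aquila2_flash_attn_monkey_patch import (
        replace_aquila_attn_with_flash_attn,
    )

    replace_aquila_attn_with_flash_attn()

    # ...then construct or load the Aquila2 model as usual, so the patched
    # attention forward and decoder mask preparation take effect.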
@@ -0,0 +1,74 @@
"""
Copied from https://github.com/lm-sys/FastChat.
Later we will contribute our changes into it.
"""

from functools import partial

import torch
import transformers
import flagai.model.aquila2.modeling_aquila


class CondenseRotaryEmbedding(torch.nn.Module):
    def __init__(
        self, dim, ratio, max_position_embeddings=2048, base=10000, device=None
    ):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.ratio = ratio
        max_position_embeddings *= ratio
        self.max_seq_len_cached = max_position_embeddings
        # print(f"Monkey Patching condense ratio {ratio}")
        t = (
            torch.arange(
                self.max_seq_len_cached,
                device=self.inv_freq.device,
                dtype=self.inv_freq.dtype,
            )
            / ratio
        )
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer(
            "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
        )

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = (
                torch.arange(
                    self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
                )
                / self.ratio
            )
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer(
                "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False
            )
            self.register_buffer(
                "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False
            )
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def replace_aquila_with_condense(ratio):
    flagai.model.aquila2.modeling_aquila.AquilaRotaryEmbedding = partial(
        CondenseRotaryEmbedding, ratio=ratio
    )
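
Similarly, a minimal usage sketch for the condense patch (not shown in this diff): calling `replace_aquila_with_condense(ratio)` swaps the module-level `AquilaRotaryEmbedding` for a `CondenseRotaryEmbedding` partial, so any Aquila2 model built afterwards interpolates RoPE positions by `ratio`, stretching the usable context from `max_position_embeddings` to roughly `ratio * max_position_embeddings`. The module name `aquila2_condense_monkey_patch` below is an assumption, since the diff view does not show this file's path.

    # Minimal sketch: the module name is assumed (this file's path is not
    # shown in the diff view); adjust it to wherever the patch file lives.
    from flagai.model.aquila2 import aquila2_condense_monkey_patch  # assumed module name

    # Condense RoPE positions by 4x before constructing the model, so a
    # 2048-position base model can be finetuned on sequences of roughly 8192 tokens.
    aquila2_condense_monkey_patch.replace_aquila_with_condense(ratio=4)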