Commit

Create fused LlamaLikeModel (#152)

casper-hansen authored Nov 4, 2023
1 parent 84a2686 commit 8110e02
Showing 8 changed files with 343 additions and 289 deletions.
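The fused path added by this commit is exercised when a quantized model is loaded with layer fusing enabled, which calls each model class's fuse_layers() and swaps the decoder stack for the new LlamaLikeModel. A minimal usage sketch follows; the loader arguments (fuse_layers=True) and the checkpoint path are assumptions about the AutoAWQ API of this period, not something defined in this diff.

# Minimal usage sketch (assumed API surface, not part of this commit).
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "path/to/llama-awq-checkpoint"  # placeholder checkpoint directory

# fuse_layers=True triggers LlamaAWQForCausalLM.fuse_layers(), which installs
# the fused LlamaLikeBlock/LlamaLikeModel modules introduced below.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))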
147 changes: 60 additions & 87 deletions awq/models/aquila.py
@@ -1,40 +1,41 @@
-## Reference from llama.py
+import tqdm
+from typing import List, Tuple
 from .base import BaseAWQForCausalLM
+from awq.utils.fused_utils import fuse_qkv
+from awq.modules.fused.block import LlamaLikeBlock
+from awq.modules.fused.model import LlamaLikeModel
 from transformers.models.llama.modeling_llama import (
-    LlamaDecoderLayer as AquilaDecoderLayer,
-    LlamaForCausalLM as AquilaForCausalLM,
-    LlamaAttention as AquilaAttention,
-    LlamaRMSNorm as AquilaRMSNorm,
-    LlamaMLP as AquilaMLP
+    LlamaDecoderLayer as OldAquilaDecoderLayer,
+    LlamaForCausalLM as OldAquilaForCausalLM
 )
+from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.norm import FasterTransformerRMSNorm

 class AquilaAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "AquilaDecoderLayer"
     max_new_tokens_key = "max_position_embeddings"

     @staticmethod
-    def fuse_layers(model: AquilaForCausalLM):
+    def fuse_layers(model: OldAquilaForCausalLM):
         fuser = AquilaFuser(model)
-        fuser.fuse_attention()
-        fuser.fuse_rmsnorm()
-        fuser.fuse_mlp()
+        fuser.fuse_transformer()

     @staticmethod
-    def get_model_layers(model: AquilaForCausalLM):
+    def get_model_layers(model: OldAquilaForCausalLM):
         return model.model.layers

     @staticmethod
-    def get_act_for_scaling(module: AquilaDecoderLayer):
+    def get_act_for_scaling(module: OldAquilaDecoderLayer):
         return dict(
             is_scalable=False
         )

     @staticmethod
-    def move_embed(model: AquilaForCausalLM, device: str):
+    def move_embed(model: OldAquilaForCausalLM, device: str):
         model.model.embed_tokens = model.model.embed_tokens.to(device)

     @staticmethod
-    def get_layers_for_scaling(module: AquilaDecoderLayer, input_feat, module_kwargs):
+    def get_layers_for_scaling(module: OldAquilaDecoderLayer, input_feat, module_kwargs):
         layers = []

         # attention input
@@ -72,85 +73,57 @@ def get_layers_for_scaling(module: AquilaDecoderLayer, input_feat, module_kwargs

         return layers

-import torch
-from typing import List, Tuple, Union
-from awq.utils.utils import set_module_name
-from awq.modules.fused.mlp import QuantLlamaMLP
-from awq.modules.fused.attn import QuantAttentionFused
-from awq.modules.fused.norm import FasterTransformerRMSNorm
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-
 class AquilaFuser:
-    def __init__(self, model):
+    def __init__(self, model: OldAquilaForCausalLM):
         self.model = model

-        self.attention_modules: List[Tuple[str, AquilaAttention]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if "AquilaAttention".lower() in module.__class__.__name__.lower()
-        ]
-
-        self.rmsnorm_modules: List[Tuple[str, AquilaRMSNorm]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if "AquilaRMSNorm".lower() in module.__class__.__name__.lower()
-        ]
-
-        self.mlp_modules: List[Tuple[str, AquilaMLP]] = [
+        self.aquila_blocks: List[Tuple[str, OldAquilaDecoderLayer]] = [
             (name, module) for name, module in self.model.named_modules()
-            if "AquilaMLP".lower() in module.__class__.__name__.lower()
+            if 'AquilaDecoderLayer'.lower() in module.__class__.__name__.lower()
         ]

-    def fuse_attention(self):
-        for name, module in self.attention_modules:
-            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
-            attn = QuantAttentionFused(
-                module.hidden_size,
-                module.num_heads,
-                module.num_key_value_heads,
-                qkv_layer,
-                module.o_proj,
-                next(iter(qkv_layer.state_dict().values())).device,
-                self.model.config.max_new_tokens
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldAquilaDecoderLayer
+        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
+            device = next(iter(module.state_dict().values())).device
+            qkv = fuse_qkv(
+                module,
+                module.self_attn.q_proj,
+                module.self_attn.k_proj,
+                module.self_attn.v_proj
             )
-            set_module_name(self.model, name, attn)
-
-    def _fuse_qkv(self, module: AquilaAttention):
-        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
-        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
-
-        if isinstance(q_proj, WQLinear_GEMV):
-            q_linear = WQLinear_GEMV
-        else:
-            q_linear = WQLinear_GEMM
-
-        qkv_layer = q_linear(
-            q_proj.w_bit,
-            q_proj.group_size,
-            q_proj.in_features,
-            q_proj.out_features + k_proj.out_features + v_proj.out_features,
-            q_proj.bias is not None,
-            next(iter(module.state_dict().values())).device
-        )
-
-        if isinstance(qkv_layer, WQLinear_GEMV):
-            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
-            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
-            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
-            qkv_layer.split_k_iters = q_proj.split_k_iters
-        else:
-            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
-            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
-            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+            mlp = QuantLlamaMLP(
+                module.mlp.gate_proj,
+                module.mlp.down_proj,
+                module.mlp.up_proj
+            )
+            norm_1 = FasterTransformerRMSNorm(
+                module.input_layernorm.weight,
+                module.input_layernorm.variance_epsilon
+            )
+            norm_2 = FasterTransformerRMSNorm(
+                module.post_attention_layernorm.weight,
+                module.post_attention_layernorm.variance_epsilon
+            )
+            blocks.append(LlamaLikeBlock(
+                hidden_size=self.model.config.hidden_size,
+                n_heads=self.model.config.num_attention_heads,
+                n_kv_heads=self.model.config.num_key_value_heads,
+                qkv_layer=qkv,
+                o_proj=module.self_attn.o_proj,
+                mlp=mlp,
+                norm_1=norm_1,
+                norm_2=norm_2,
+                dev=device,
+                max_seq_len=self.model.config.max_new_tokens
+            ))

-        qkv_layer.bias = bias
-
-        return qkv_layer
-
-    def fuse_rmsnorm(self):
-        for name, module in self.rmsnorm_modules:
-            norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon)
-            set_module_name(self.model, name, norm)
-
-    def fuse_mlp(self):
-        for name, module in self.mlp_modules:
-            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
-            set_module_name(self.model, name, mlp)
+        self.model.model = LlamaLikeModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.model.embed_tokens,
+            self.model.model.norm,
+        )
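Both fusers now delegate QKV fusion to a shared fuse_qkv helper imported from awq.utils.fused_utils instead of a per-model _fuse_qkv method. The sketch below reconstructs what that helper presumably does from the _fuse_qkv code deleted above; the real fused_utils implementation may differ in detail.

# Sketch of the shared QKV-fusing helper, reconstructed from the removed
# per-model _fuse_qkv methods shown above; illustrative, not the actual
# awq.utils.fused_utils.fuse_qkv source.
import torch
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

def fuse_qkv(module, q_proj, k_proj, v_proj):
    # Concatenate any biases in Q/K/V order.
    bias = (
        torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0)
        if q_proj.bias is not None else None
    )

    # GEMV layers concatenate along dim 0 and carry split_k_iters; GEMM along dim 1.
    q_linear = WQLinear_GEMV if isinstance(q_proj, WQLinear_GEMV) else WQLinear_GEMM

    qkv_layer = q_linear(
        q_proj.w_bit,
        q_proj.group_size,
        q_proj.in_features,
        q_proj.out_features + k_proj.out_features + v_proj.out_features,
        q_proj.bias is not None,
        next(iter(module.state_dict().values())).device,
    )

    if isinstance(qkv_layer, WQLinear_GEMV):
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
        qkv_layer.split_k_iters = q_proj.split_k_iters
    else:
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)

    qkv_layer.bias = bias
    return qkv_layer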
145 changes: 62 additions & 83 deletions awq/models/llama.py
@@ -1,33 +1,41 @@
+import tqdm
+from typing import List, Tuple
 from .base import BaseAWQForCausalLM
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
+from awq.utils.fused_utils import fuse_qkv
+from awq.modules.fused.block import LlamaLikeBlock
+from awq.modules.fused.model import LlamaLikeModel
+from transformers.models.llama.modeling_llama import (
+    LlamaDecoderLayer as OldLlamaDecoderLayer,
+    LlamaForCausalLM as OldLlamaForCausalLM
+)
+from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.norm import FasterTransformerRMSNorm

 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "LlamaDecoderLayer"
     max_new_tokens_key = "max_position_embeddings"

     @staticmethod
-    def fuse_layers(model: LlamaForCausalLM):
+    def fuse_layers(model: OldLlamaForCausalLM):
         fuser = LlamaFuser(model)
-        fuser.fuse_attention()
-        fuser.fuse_rmsnorm()
-        fuser.fuse_mlp()
+        fuser.fuse_transformer()

     @staticmethod
-    def get_model_layers(model: LlamaForCausalLM):
+    def get_model_layers(model: OldLlamaForCausalLM):
         return model.model.layers

     @staticmethod
-    def get_act_for_scaling(module: LlamaDecoderLayer):
+    def get_act_for_scaling(module: OldLlamaDecoderLayer):
         return dict(
             is_scalable=False
         )

     @staticmethod
-    def move_embed(model: LlamaForCausalLM, device: str):
+    def move_embed(model: OldLlamaForCausalLM, device: str):
         model.model.embed_tokens = model.model.embed_tokens.to(device)

     @staticmethod
-    def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
+    def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
         layers = []

         # attention input
@@ -65,86 +73,57 @@ def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs)

         return layers

-import torch
-from typing import List, Tuple, Union
-from awq.utils.utils import set_module_name
-from awq.modules.fused.mlp import QuantLlamaMLP
-from awq.modules.fused.attn import QuantAttentionFused
-from awq.modules.fused.norm import FasterTransformerRMSNorm
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
-
 class LlamaFuser:
-    def __init__(self, model):
+    def __init__(self, model: OldLlamaForCausalLM):
         self.model = model

-        self.attention_modules: List[Tuple[str, LlamaAttention]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, LlamaAttention)
-        ]
-
-        self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, LlamaRMSNorm)
-        ]
-
-        self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
+        self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
             (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, LlamaMLP)
+            if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower()
         ]

-    def fuse_attention(self):
-        for name, module in self.attention_modules:
-            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
-            attn = QuantAttentionFused(
-                module.hidden_size,
-                module.num_heads,
-                module.num_key_value_heads,
-                qkv_layer,
-                module.o_proj,
-                next(iter(qkv_layer.state_dict().values())).device,
-                self.model.config.max_new_tokens
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldLlamaDecoderLayer
+        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
+            device = next(iter(module.state_dict().values())).device
+            qkv = fuse_qkv(
+                module,
+                module.self_attn.q_proj,
+                module.self_attn.k_proj,
+                module.self_attn.v_proj
             )
-            set_module_name(self.model, name, attn)
-
-    def _fuse_qkv(self, module: LlamaAttention):
-        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
-        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
-
-        if isinstance(q_proj, WQLinear_GEMV):
-            q_linear = WQLinear_GEMV
-        else:
-            q_linear = WQLinear_GEMM
-
-        qkv_layer = q_linear(
-            q_proj.w_bit,
-            q_proj.group_size,
-            q_proj.in_features,
-            q_proj.out_features + k_proj.out_features + v_proj.out_features,
-            q_proj.bias is not None,
-            next(iter(module.state_dict().values())).device
-        )
-
-        if isinstance(qkv_layer, WQLinear_GEMV):
-            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
-            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
-            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
-            qkv_layer.split_k_iters = q_proj.split_k_iters
-        else:
-            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
-            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
-            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+            mlp = QuantLlamaMLP(
+                module.mlp.gate_proj,
+                module.mlp.down_proj,
+                module.mlp.up_proj
+            )
+            norm_1 = FasterTransformerRMSNorm(
+                module.input_layernorm.weight,
+                module.input_layernorm.variance_epsilon
+            )
+            norm_2 = FasterTransformerRMSNorm(
+                module.post_attention_layernorm.weight,
+                module.post_attention_layernorm.variance_epsilon
+            )
+            blocks.append(LlamaLikeBlock(
+                hidden_size=self.model.config.hidden_size,
+                n_heads=self.model.config.num_attention_heads,
+                n_kv_heads=self.model.config.num_key_value_heads,
+                qkv_layer=qkv,
+                o_proj=module.self_attn.o_proj,
+                mlp=mlp,
+                norm_1=norm_1,
+                norm_2=norm_2,
+                dev=device,
+                max_seq_len=self.model.config.max_new_tokens
+            ))

-        qkv_layer.bias = bias
-
-        return qkv_layer
-
-    def fuse_rmsnorm(self):
-        for name, module in self.rmsnorm_modules:
-            norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon)
-            set_module_name(self.model, name, norm)
-
-    def fuse_mlp(self):
-        for name, module in self.mlp_modules:
-            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
-            set_module_name(self.model, name, mlp)
+        self.model.model = LlamaLikeModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.model.embed_tokens,
+            self.model.model.norm,
+        )
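Each fuser ends by replacing the inner model.model with a LlamaLikeModel built from the vocab size, the fused blocks, the original token embedding, and the final RMSNorm; the outer *ForCausalLM wrapper and its lm_head stay in place. The sketch below only illustrates the forward structure implied by those constructor arguments and is not the actual awq.modules.fused.model.LlamaLikeModel implementation.

# Illustrative approximation of a LlamaLikeModel-style wrapper, inferred from
# the constructor arguments used above; not the real AutoAWQ module.
import torch.nn as nn

class LlamaLikeModelSketch(nn.Module):
    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.blocks = nn.ModuleList(blocks)  # fused LlamaLikeBlocks
        self.embedding = embedding           # original embed_tokens
        self.norm = norm                     # original final RMSNorm

    def forward(self, input_ids):
        h = self.embedding(input_ids)
        # Each fused block applies: input norm -> fused QKV attention -> residual
        # -> post-attention norm -> MLP -> residual (details live in LlamaLikeBlock).
        for block in self.blocks:
            h = block(h)
        # Final norm; the lm_head projection remains on the outer *ForCausalLM model.
        return self.norm(h)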