support minicpm3.0 (#605)
Co-authored-by: Casper <[email protected]>
LDLINGLINGLING and casper-hansen authored Nov 14, 2024
1 parent 0187ac1 commit b42e3c3
Showing 5 changed files with 169 additions and 1 deletion.
3 changes: 2 additions & 1 deletion awq/models/__init__.py
```diff
@@ -25,4 +25,5 @@
 from .deepseek_v2 import DeepseekV2AWQForCausalLM
 from .minicpm import MiniCPMAWQForCausalLM
 from .internlm2 import InternLM2AWQForCausalLM
-from .qwen2vl import Qwen2VLAWQForCausalLM
+from .minicpm3 import MiniCPM3AWQForCausalLM
+from .qwen2vl import Qwen2VLAWQForCausalLM
```
1 change: 1 addition & 0 deletions awq/models/auto.py
```diff
@@ -35,6 +35,7 @@
     "deepseek_v2": DeepseekV2AWQForCausalLM,
     "minicpm": MiniCPMAWQForCausalLM,
     "internlm2": InternLM2AWQForCausalLM,
+    "minicpm3": MiniCPM3AWQForCausalLM,
     "qwen2_vl": Qwen2VLAWQForCausalLM,
 }
```

1 change: 1 addition & 0 deletions awq/models/base.py
```diff
@@ -85,6 +85,7 @@
     "cohere": "AutoModelForCausalLM",
     "deepseek_v2": "AutoModelForCausalLM",
     "minicpm": "AutoModelForCausalLM",
+    "minicpm3": "AutoModelForCausalLM",
     "internlm2": "AutoModelForCausalLM",
     "qwen2_vl": "AutoModelForVision2Seq",
 }
```
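Together, these two entries are what route a MiniCPM3 checkpoint: the model-type map in `awq/models/auto.py` selects `MiniCPM3AWQForCausalLM` from the checkpoint's `config.model_type`, while the mapping in `awq/models/base.py` tells AutoAWQ to load the underlying weights through `AutoModelForCausalLM`. A minimal sketch of the resulting dispatch (illustrative only; assumes the `openbmb/MiniCPM3-4B` checkpoint is reachable and reuses the loading flags from the docs example below):

```python
from awq import AutoAWQForCausalLM

# config.model_type == "minicpm3" resolves to MiniCPM3AWQForCausalLM via the map above
model = AutoAWQForCausalLM.from_pretrained(
    "openbmb/MiniCPM3-4B",   # assumed checkpoint
    low_cpu_mem_usage=True,
    use_cache=False,
    safetensors=False,
)
print(type(model).__name__)  # expected: MiniCPM3AWQForCausalLM
```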
69 changes: 69 additions & 0 deletions awq/models/minicpm3.py
```python
from .base import BaseAWQForCausalLM


class MiniCPM3AWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MiniCPMDecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def get_model_layers(model):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        layers = []

        # attention: query up-projection, scaled against the q_a_layernorm output
        layers.append(
            dict(
                prev_op=module.self_attn.q_a_layernorm,
                layers=[module.self_attn.q_b_proj],
                inp=input_feat["self_attn.q_b_proj"],
                module2inspect=module.self_attn.q_b_proj,
                kwargs=module_kwargs,
            )
        )

        # attention: key/value up-projection, scaled against the kv_a_layernorm output
        layers.append(
            dict(
                prev_op=module.self_attn.kv_a_layernorm,
                layers=[module.self_attn.kv_b_proj],
                inp=input_feat["self_attn.kv_b_proj"],
                module2inspect=module.self_attn.kv_b_proj,
                kwargs=module_kwargs,
            )
        )

        # mlp: second linear (down_proj), scaled against the up_proj output
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        # mlp: gate/up projections, scaled against the post_attention_layernorm output
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        return layers
```
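As a quick check (illustrative, not part of the commit), the module paths that `get_layers_for_scaling` references can be resolved against a real MiniCPM3 decoder layer; this sketch assumes the `openbmb/MiniCPM3-4B` checkpoint and `trust_remote_code=True`:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "openbmb/MiniCPM3-4B", trust_remote_code=True, torch_dtype=torch.float16
)
layer = model.model.layers[0]
assert type(layer).__name__ == "MiniCPMDecoderLayer"  # matches layer_type above

# Every attribute path used in the scaling groups should resolve on the layer
for path in [
    "self_attn.q_a_layernorm", "self_attn.q_b_proj",
    "self_attn.kv_a_layernorm", "self_attn.kv_b_proj",
    "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
    "post_attention_layernorm",
]:
    obj = layer
    for attr in path.split("."):
        obj = getattr(obj, attr)  # raises AttributeError if a path is wrong
print("all scaling paths resolved")
```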
96 changes: 96 additions & 0 deletions docs/examples.md
@@ -424,6 +424,102 @@

### Another Custom Quantizer (MiniCPM3 Example)

Here we introduce another custom quantizer, contributed by the MiniCPM team at OpenBMB. Only the weight-clipping step (`_compute_best_clip`) is overridden to make quantization work for MiniCPM3.

```python
import torch
from transformers import AutoTokenizer

from awq import AutoAWQForCausalLM
from awq.quantize.quantizer import AwqQuantizer, clear_memory

class CPM3AwqQuantizer(AwqQuantizer):
    @torch.no_grad()
    def _compute_best_clip(
        self,
        w: torch.Tensor,
        input_feat: torch.Tensor,
        n_grid=20,
        max_shrink=0.5,
        n_sample_token=512,
    ):
        assert w.dim() == 2
        org_w_shape = w.shape
        # w [co, ci] -> [co, 1, n_group, group size]
        # input_feat [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)

        # Compute input feature step size (minimum 1)
        step_size = max(1, input_feat.shape[1] // n_sample_token)
        input_feat = input_feat[:, ::step_size]

        w = w.reshape(org_w_shape[0], 1, -1, group_size)

        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64  # prevent OOM
        if org_w_shape[0] % oc_batch_size != 0:
            oc_batch_size = org_w_shape[0]
        assert org_w_shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []

        for i_b in range(org_w_shape[0] // oc_batch_size):
            w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            input_feat = input_feat.to(w.device)
            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = -max_val
                cur_w = torch.clamp(w, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)[0]
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]
            best_max_val_all.append(best_max_val)

        best_max_val = torch.cat(best_max_val_all, dim=0)

        clear_memory(input_feat)
        clear_memory(org_out)

        return best_max_val.squeeze(1)

model_path = 'openbmb/MiniCPM3-4B'
quant_path = 'minicpm3-4b-awq'
quant_config = { "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
    model_path, low_cpu_mem_usage=True, use_cache=False, safetensors=False
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config, quantizer_cls=CPM3AwqQuantizer)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
```
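For a quick sanity check of the saved artifact (illustrative only; assumes a CUDA device and the `quant_path` above), the model can be loaded back with `from_quantized` and asked for a short completion; the general inference pattern is covered in the next section:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = 'minicpm3-4b-awq'

model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

tokens = tokenizer("What is AWQ quantization?", return_tensors="pt").input_ids.cuda()
output = model.generate(tokens, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```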

## Basic Inference

### Inference With GPU
