diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index 7f6ff55a..79ca150e 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -25,4 +25,5 @@
 from .deepseek_v2 import DeepseekV2AWQForCausalLM
 from .minicpm import MiniCPMAWQForCausalLM
 from .internlm2 import InternLM2AWQForCausalLM
-from .qwen2vl import Qwen2VLAWQForCausalLM
\ No newline at end of file
+from .minicpm3 import MiniCPM3AWQForCausalLM
+from .qwen2vl import Qwen2VLAWQForCausalLM
diff --git a/awq/models/auto.py b/awq/models/auto.py
index df67844a..5f6378f7 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -35,6 +35,7 @@
     "deepseek_v2": DeepseekV2AWQForCausalLM,
     "minicpm": MiniCPMAWQForCausalLM,
     "internlm2": InternLM2AWQForCausalLM,
+    "minicpm3": MiniCPM3AWQForCausalLM,
     "qwen2_vl": Qwen2VLAWQForCausalLM,
 }
 
diff --git a/awq/models/base.py b/awq/models/base.py
index 71f45d1d..a5fbf4c3 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -85,6 +85,7 @@
     "cohere": "AutoModelForCausalLM",
     "deepseek_v2": "AutoModelForCausalLM",
     "minicpm": "AutoModelForCausalLM",
+    "minicpm3": "AutoModelForCausalLM",
     "internlm2": "AutoModelForCausalLM",
     "qwen2_vl": "AutoModelForVision2Seq",
 }
diff --git a/awq/models/minicpm3.py b/awq/models/minicpm3.py
new file mode 100644
index 00000000..91bcb244
--- /dev/null
+++ b/awq/models/minicpm3.py
@@ -0,0 +1,69 @@
+from .base import BaseAWQForCausalLM
+
+class MiniCPM3AWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = "MiniCPMDecoderLayer"
+    max_seq_len_key = "max_position_embeddings"
+
+    @staticmethod
+    def get_model_layers(model):
+        # MiniCPM3 keeps its decoder blocks on model.model.layers
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module):
+        return dict(is_scalable=False)
+
+    @staticmethod
+    def move_embed(model, device: str):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(module, input_feat, module_kwargs):
+        layers = []
+
+        # attention: low-rank query projection (q_a_layernorm -> q_b_proj)
+        layers.append(
+            dict(
+                prev_op=module.self_attn.q_a_layernorm,
+                layers=[
+                    module.self_attn.q_b_proj,
+                ],
+                inp=input_feat["self_attn.q_b_proj"],
+                module2inspect=module.self_attn.q_b_proj,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # attention: low-rank key/value projection (kv_a_layernorm -> kv_b_proj)
+        layers.append(
+            dict(
+                prev_op=module.self_attn.kv_a_layernorm,
+                layers=[
+                    module.self_attn.kv_b_proj,
+                ],
+                inp=input_feat["self_attn.kv_b_proj"],
+                module2inspect=module.self_attn.kv_b_proj,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # mlp: down projection, scaled by the output of up_proj
+        layers.append(
+            dict(
+                prev_op=module.mlp.up_proj,
+                layers=[module.mlp.down_proj],
+                inp=input_feat["mlp.down_proj"],
+            )
+        )
+
+        # mlp: gate and up projections, scaled by the post-attention layernorm
+        layers.append(
+            dict(
+                prev_op=module.post_attention_layernorm,
+                layers=[module.mlp.gate_proj, module.mlp.up_proj],
+                inp=input_feat["mlp.gate_proj"],
+                module2inspect=module.mlp,
+            )
+        )
+
+        return layers
\ No newline at end of file
diff --git a/docs/examples.md b/docs/examples.md
index 6032b212..de4cd78a 100644
--- a/docs/examples.md
+++ b/docs/examples.md
@@ -424,6 +424,102 @@
 model.model.config.use_cache = model.model.generation_config.use_cache = True
 model.save_quantized(quant_path, safetensors=True, shard_size="4GB")
 ```
+### Another Custom Quantizer (MiniCPM3 Example)
+
+Here we introduce another custom quantizer from the MiniCPM team at OpenBMB. It only modifies the
+weight clipping mechanism, clamping the input-feature sampling step to a minimum of 1, so that quantization works for MiniCPM3.
+ +```python +import torch +from transformers import AutoTokenizer + +from awq import AutoAWQForCausalLM +from awq.quantize.quantizer import AwqQuantizer, clear_memory + +class CPM3AwqQuantizer(AwqQuantizer): + @torch.no_grad() + def _compute_best_clip( + self, + w: torch.Tensor, + input_feat: torch.Tensor, + n_grid=20, + max_shrink=0.5, + n_sample_token=512, + ): + assert w.dim() == 2 + org_w_shape = w.shape + # w [co, ci] -> [co, 1, n_group, group size] + # input_feat [n_token, ci] -> [1, n_token, n_group, group size] + group_size = self.group_size if self.group_size > 0 else org_w_shape[1] + input_feat = input_feat.view(-1, input_feat.shape[-1]) + input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size) + + # Compute input feature step size (minimum 1) + step_size = max(1, input_feat.shape[1] // n_sample_token) + input_feat = input_feat[:, ::step_size] + + w = w.reshape(org_w_shape[0], 1, -1, group_size) + + oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64 # prevent OOM + if org_w_shape[0] % oc_batch_size != 0: + oc_batch_size = org_w_shape[0] + assert org_w_shape[0] % oc_batch_size == 0 + w_all = w + best_max_val_all = [] + + for i_b in range(org_w_shape[0] // oc_batch_size): + w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size] + + org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1 + + best_max_val = org_max_val.clone() + min_errs = torch.ones_like(org_max_val) * 1e9 + input_feat = input_feat.to(w.device) + org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group + + for i_s in range(int(max_shrink * n_grid)): + max_val = org_max_val * (1 - i_s / n_grid) + min_val = -max_val + cur_w = torch.clamp(w, min_val, max_val) + q_w = self.pseudo_quantize_tensor(cur_w)[0] + cur_out = (input_feat * q_w).sum(dim=-1) + + # co, 1, n_group, 1 + err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape) + del cur_w + del cur_out + cur_best_idx = err < min_errs + min_errs[cur_best_idx] = err[cur_best_idx] + best_max_val[cur_best_idx] = max_val[cur_best_idx] + best_max_val_all.append(best_max_val) + + best_max_val = torch.cat(best_max_val_all, dim=0) + + clear_memory(input_feat) + clear_memory(org_out) + + return best_max_val.squeeze(1) + +model_path = 'openbmb/MiniCPM3-4B' +quant_path = 'minicpm3-4b-awq' +quant_config = { "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM" } + +# Load model +model = AutoAWQForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, use_cache=False, safetensors=False +) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + +# Quantize +model.quantize(tokenizer, quant_config=quant_config, quantizer_cls=CPM3AwqQuantizer) + +# Save quantized model +model.save_quantized(quant_path) +tokenizer.save_pretrained(quant_path) + +print(f'Model is quantized and saved at "{quant_path}"') +``` + ## Basic Inference ### Inference With GPU
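
To sanity-check the result of the MiniCPM3 docs example added above, the quantized checkpoint can be loaded back for a quick generation test. The sketch below is a minimal illustration under a few assumptions: the checkpoint was saved to `minicpm3-4b-awq` as in the example, a CUDA GPU is available, and `fuse_layers=False` is passed because this patch does not register fused modules for MiniCPM3.

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "minicpm3-4b-awq"  # path used by the quantization example above

# Load the AWQ-quantized model. fuse_layers=False because no fused modules are
# registered for MiniCPM3; trust_remote_code=True for the custom architecture.
model = AutoAWQForCausalLM.from_quantized(
    quant_path, fuse_layers=False, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Simple completion-style prompt; switch to the tokenizer's chat template if preferred.
tokens = tokenizer(
    "AWQ is a weight-only quantization method that", return_tensors="pt"
).input_ids.cuda()

# Stream a short continuation to confirm the quantized weights produce coherent text.
model.generate(tokens, streamer=streamer, max_new_tokens=64)
```

If the continuation looks coherent, the checkpoint can then be used with the standard inference flow described in the Basic Inference section below.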