merge_lora.py
# -*- coding:utf-8 -*-
# @project: ChatGPT
# @filename: merge_lora
# @author: 刘聪NLP
# @zhihu: https://www.zhihu.com/people/LiuCongNLP
# @contact: [email protected]
# @time: 2023/8/6 16:13
"""
文件说明:
"""
import argparse

import torch
from peft import PeftModel

from model import MODE


def set_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--ori_model_dir', default="ChatGLM2-6B", type=str,
                        help='Path to the original (base) model directory')
    parser.add_argument('--model_dir', default="output-glm2/epoch-2-step-3900/", type=str,
                        help='Path to the trained LoRA checkpoint; the merged model is saved here as well')
    parser.add_argument('--mode', default="glm2", type=str,
                        help='Model type key in MODE, e.g. "glm2"')
    return parser.parse_args()


if __name__ == '__main__':
    args = set_args()
    # Load the base model in fp16 and attach the trained LoRA adapter weights.
    base_model = MODE[args.mode]["model"].from_pretrained(args.ori_model_dir, torch_dtype=torch.float16)
    lora_model = PeftModel.from_pretrained(base_model, args.model_dir, torch_dtype=torch.float16)
    lora_model.to("cpu")
    # Fold the LoRA weights into the base model, then save the merged checkpoint
    # (sharded at 2GB) into the same directory as the LoRA checkpoint.
    model = lora_model.merge_and_unload()
    model.save_pretrained(args.model_dir, max_shard_size="2GB")
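
# Usage sketch (not part of the original script; the loading code below is an
# assumption based on the standard ChatGLM2-6B API, and the merged directory may
# only contain model weights and config -- tokenizer / remote-code files may need
# to be copied over from --ori_model_dir first):
#
#   from transformers import AutoModel, AutoTokenizer
#
#   merged_dir = "output-glm2/epoch-2-step-3900/"
#   tokenizer = AutoTokenizer.from_pretrained(merged_dir, trust_remote_code=True)
#   model = AutoModel.from_pretrained(merged_dir, trust_remote_code=True).half().cuda().eval()
#   response, history = model.chat(tokenizer, "Hello", history=[])
#   print(response)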