# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file.
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M-FP16.bin FP16
# Get model checkpoints from https://huggingface.co/BlinkDL
# See FILE_FORMAT.md for documentation of the file format.

import argparse
import struct
from typing import Dict

import torch

def parse_args():
    parser = argparse.ArgumentParser(description='Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file')
    parser.add_argument('src_path', help='Path to PyTorch checkpoint file')
    parser.add_argument('dest_path', help='Path to rwkv.cpp checkpoint file, will be overwritten')
    # nargs='?' makes the positional optional so that default='FP16' can actually apply;
    # argparse silently ignores default= on a required positional.
    parser.add_argument('data_type', help='Data type, FP16 or FP32', type=str, choices=['FP16', 'FP32', 'float16', 'float32'], default='FP16', nargs='?')
    return parser.parse_args()

def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
    # Layer indices are consecutive, so count blocks until the next one is missing.
    n_layer: int = 0

    while f'blocks.{n_layer}.ln1.weight' in state_dict:
        n_layer += 1

    assert n_layer > 0

    return n_layer

def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str) -> None:
    emb_weight: torch.Tensor = state_dict['emb.weight']

    n_layer: int = get_layer_count(state_dict)
    n_vocab: int = emb_weight.shape[0]
    n_embed: int = emb_weight.shape[1]

    # The model version is inferred from tensors that exist only in newer architectures.
    is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
    is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict
    is_v6_0: bool = 'blocks.0.att.time_maa_x' in state_dict

    # Useful for debugging: lists all tensor names in the checkpoint.
    print(state_dict.keys())

    if is_v6_0:
        print('Detected RWKV v6.0')
    elif is_v5_2:
        print('Detected RWKV v5.2')
    elif is_v5_1_or_2:
        print('Detected RWKV v5.1')
    else:
        print('Detected RWKV v4')

    with open(dest_path, 'wb') as out_file:
        is_FP16: bool = data_type == 'FP16' or data_type == 'float16'

        out_file.write(struct.pack(
            # '=' selects native byte order with standard sizes and no padding
            '=iiiiii',
            # Magic: 'ggmf' in hex
            0x67676d66,
            # File format version
            101,
            n_vocab,
            n_embed,
            n_layer,
            # Data type flag: 1 = FP16, 0 = FP32
            1 if is_FP16 else 0
        ))
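
        # Each tensor that follows is serialized as one record
        # (summary of the loop below; FILE_FORMAT.md is the authoritative reference):
        #   int32   number of dimensions
        #   int32   length of the UTF-8 encoded key
        #   int32   data type flag (1 = FP16, 0 = FP32)
        #   int32[] dimensions, in reverse (ggml) order
        #   bytes   the key itself
        #   bytes   raw tensor data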

        for k in state_dict.keys():
            tensor: torch.Tensor = state_dict[k].float()

            # Drop singleton dimensions from the time mixing/decay parameters.
            if '.time_' in k:
                tensor = tensor.squeeze()

            if is_v6_0:
                if '.time_faaaa' in k:
                    tensor = tensor.unsqueeze(-1)
            elif is_v5_1_or_2:
                # Pre-compute the decay and bonus values in the form the inference code expects.
                if '.time_decay' in k:
                    if is_v5_2:
                        tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
                    else:
                        tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)

                if '.time_first' in k:
                    tensor = torch.exp(tensor).reshape(-1, 1, 1)

                if '.time_faaaa' in k:
                    tensor = tensor.unsqueeze(-1)
            else:
                if '.time_decay' in k:
                    tensor = -torch.exp(tensor)

            # Keep 1-dim vectors and small matrices in FP32
            if is_FP16 and len(tensor.shape) > 1 and '.time_' not in k:
                tensor = tensor.half()

            shape = tensor.shape

            print(f'Writing {k}, shape {shape}, type {tensor.dtype}')

            k_encoded: bytes = k.encode('utf-8')

            out_file.write(struct.pack(
                '=iii',
                len(shape),
                len(k_encoded),
                1 if tensor.dtype == torch.float16 else 0
            ))

            # Dimension order is reversed here:
            # * PyTorch shape is (x rows, y columns)
            # * ggml shape is (y elements in a row, x elements in a column)
            # Both shapes represent the same tensor.
            for dim in reversed(tensor.shape):
                out_file.write(struct.pack('=i', dim))

            out_file.write(k_encoded)

            # Raw data in C (row-major) order, native endianness.
            tensor.numpy().tofile(out_file)

def main() -> None:
    args = parse_args()

    print(f'Reading {args.src_path}')

    state_dict: Dict[str, torch.Tensor] = torch.load(args.src_path, map_location='cpu')

    write_state_dict(state_dict, args.dest_path, args.data_type)

    print('Done')

if __name__ == "__main__":
    main()
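
# A minimal sketch of reading the header back to sanity-check a converted file,
# assuming only the format written above (the output path is a placeholder):
#
#   import struct
#   with open('rwkv.cpp-169M-FP16.bin', 'rb') as f:
#       magic, version, n_vocab, n_embed, n_layer, is_fp16 = struct.unpack('=iiiiii', f.read(24))
#       assert magic == 0x67676d66
#       print(n_vocab, n_embed, n_layer, 'FP16' if is_fp16 else 'FP32')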