forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
weight.py
463 lines (396 loc) · 21.5 KB
/
weight.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from operator import attrgetter
import numpy as np
import torch
import tensorrt_llm
from tensorrt_llm._utils import pad_vocab_size
from tensorrt_llm.models import GPTNeoXForCausalLM
# Scale factor on the +8 offset applied to the unpacked zero points in
# preprocess_groupwise_weight_params; the unpacked weights are likewise
# shifted by -8 in the loader, moving 4-bit values from [0, 15] to [-8, 7].
UINT4_TO_INT4_FLAG = 1
# Subtracted from that same zero-point offset in
# preprocess_groupwise_weight_params.
# NOTE(review): presumably compensates for GPTQ checkpoints storing zero
# points minus one — confirm against the quantizer that produced them.
GPTQ_FLAG = 1
# Number of input rows that share one scale / zero point: the loader expects
# scales and qzeros with (input_dim // GROUP_SIZE) rows. Must match the
# quantization group size the checkpoint was produced with.
GROUP_SIZE = 128
def numpy_split(v, tp_size, idx, dim=0):
    """Return shard ``idx`` of ``v`` cut into ``tp_size`` equal parts along ``dim``.

    With a single shard the input object itself is returned (no copy);
    otherwise the selected shard is returned as a C-contiguous array.
    ``np.split`` raises if ``v.shape[dim]`` is not divisible by ``tp_size``.
    """
    if tp_size != 1:
        shards = np.split(v, tp_size, axis=dim)
        return np.ascontiguousarray(shards[idx])
    return v
def torch_split(v, tp_size, idx, dim=0):
    """Return a contiguous copy of shard ``idx`` when ``v`` is cut along ``dim``.

    Each shard holds ``v.shape[dim] // tp_size`` entries. With a single shard
    the input tensor itself is returned unchanged.
    """
    if tp_size == 1:
        return v
    shard_len = v.shape[dim] // tp_size
    pieces = v.split(shard_len, dim=dim)
    return pieces[idx].contiguous()
def unpack_int32_into_int8(w_packed):
    """Split every byte of a packed 2-D tensor into its two 4-bit halves.

    The input (e.g. int32 or float32 storage holding packed 4-bit values) is
    reinterpreted as raw bytes. Each byte yields its low nibble followed by
    its high nibble. The results are unsigned values in [0, 15], stored as
    int8 in a new tensor with twice as many columns as there are bytes per row.
    """
    as_bytes = w_packed.contiguous().view(torch.uint8)
    n_rows, n_byte_cols = as_bytes.shape[0], as_bytes.shape[1]
    # Every column is written below, so no zero-initialisation is needed.
    nibbles = torch.empty(n_rows, 2 * n_byte_cols, dtype=torch.int8)
    nibbles[:, 0::2] = as_bytes & 0x0F  # low nibble first
    nibbles[:, 1::2] = as_bytes >> 4  # then the high nibble
    return nibbles.contiguous()
def preprocess_groupwise_weight_params(qweight_unpacked_int8, scales_fp16,
                                       qzeros_unpacked_int8):
    """Convert unpacked group-wise int4 parameters into the plugin's layout.

    Args:
        qweight_unpacked_int8: int4 weight values held one per int8 element.
        scales_fp16: per-group scales.
        qzeros_unpacked_int8: unpacked per-group zero points.

    Returns:
        A tuple of numpy arrays:
        - the packed and interleaved weight, reinterpreted as float32;
        - the scales, unchanged apart from being made contiguous;
        - the zero points pre-multiplied by the scales, as float16.
    """
    ft_ops = torch.ops.fastertransformer
    pack_to_int4 = ft_ops.pack_int8_tensor_to_packed_int4
    interleave_for_gemm = ft_ops.preprocess_weights_for_mixed_gemm

    packed_weight = pack_to_int4(qweight_unpacked_int8)
    qweight_interleaved = interleave_for_gemm(
        packed_weight, torch.quint4x2).view(torch.float32)

    # Fold the zero point into the scale: zeros * scales.
    zero_offset = 8 * UINT4_TO_INT4_FLAG - GPTQ_FLAG
    zeros_x_scales_fp16 = (
        (zero_offset - qzeros_unpacked_int8) * scales_fp16).half()

    weight_np = qweight_interleaved.contiguous().numpy()
    scales_np = scales_fp16.contiguous().numpy()
    zeros_np = zeros_x_scales_fp16.contiguous().numpy()
    return weight_np, scales_np, zeros_np
def _get_hf_tensor(state_dict, key):
    """Return ``state_dict[key]``, failing with a KeyError that names the key."""
    tensor = state_dict.get(key)
    if tensor is None:
        raise KeyError(f"'{key}' is missing from the HF GPT-NeoX state dict")
    return tensor


def _pad_rows_to_multiple(v, tp_size):
    """Zero-pad dim 0 of the 2-D array ``v`` so it divides evenly by ``tp_size``."""
    if v.shape[0] % tp_size == 0:
        return v
    vocab_size_padded = pad_vocab_size(v.shape[0], tp_size)
    pad_width = vocab_size_padded - v.shape[0]
    return np.pad(v, ((0, pad_width), (0, 0)),
                  'constant',
                  constant_values=0)


def _load_qkv(state_dict, prefix, qkv, num_heads, hidden_size, head_size,
              torch_dtype, rank, tp_size):
    """Load one layer's unquantized fused QKV weight and bias into ``qkv``."""
    qkv_weights = _get_hf_tensor(state_dict,
                                 prefix + "attention.query_key_value.weight")
    qkv_bias = _get_hf_tensor(state_dict,
                              prefix + "attention.query_key_value.bias")
    # [num_heads x (q|k|v), hidden_size] ->
    # [(num_heads x q)|(num_heads x k)|(num_heads x v), hidden_size]
    qkv_weights = qkv_weights.view(
        num_heads, 3, head_size * qkv_weights.size()[-1]).permute(
            1, 0, 2).reshape(hidden_size * 3, hidden_size)
    qkv_bias = qkv_bias.view(num_heads, 3, head_size).permute(
        1, 0, 2).reshape(hidden_size * 3)
    qkv_weights = qkv_weights.to(torch_dtype).cpu().numpy()
    qkv_bias = qkv_bias.to(torch_dtype).cpu().numpy()
    if tp_size > 1:
        # Shard q, k and v separately along their output rows.
        qkv_weights = numpy_split(qkv_weights.reshape(3, hidden_size,
                                                      hidden_size),
                                  tp_size, rank, dim=1).reshape(
                                      3 * (hidden_size // tp_size),
                                      hidden_size)
        qkv_bias = numpy_split(qkv_bias.reshape(3, hidden_size),
                               tp_size, rank, dim=1).reshape(
                                   3 * (hidden_size // tp_size))
    qkv.weight.value = np.ascontiguousarray(qkv_weights)
    qkv.bias.value = np.ascontiguousarray(qkv_bias)


def _load_gptq_qkv(state_dict, prefix, qkv, num_heads, hidden_size, head_size,
                   rank, tp_size):
    """Load one layer's GPTQ-quantized fused QKV projection into ``qkv``."""
    name = prefix + "attention.query_key_value"
    qweight_int32 = _get_hf_tensor(state_dict, name + ".qweight")
    # .cpu(): these end up in .numpy() and are combined with CPU tensors.
    scales_fp16 = _get_hf_tensor(state_dict, name + ".scales").cpu()
    qzeros_int32 = _get_hf_tensor(state_dict, name + ".qzeros")
    biases_fp16 = _get_hf_tensor(state_dict, name + ".bias").cpu()
    num_groups = hidden_size // GROUP_SIZE

    # [hidden_size // 8, hidden_size * 3] -> [hidden_size * 3, hidden_size],
    # shifting the unsigned 4-bit values into [-8, 7].
    qweight_unpacked_int8 = unpack_int32_into_int8(
        qweight_int32.T).contiguous() - 8
    # [num_groups, hidden_size * 3 // 8] -> [num_groups, hidden_size * 3]
    qzeros_unpacked_int8 = unpack_int32_into_int8(qzeros_int32)

    # Reorder rows from [num_heads x (q|k|v)] to [(all q)|(all k)|(all v)].
    qweight_unpacked_int8 = qweight_unpacked_int8.view(
        num_heads, 3,
        head_size * qweight_unpacked_int8.size()[-1]).permute(1, 0, 2).reshape(
            hidden_size * 3, hidden_size).contiguous()
    # [hidden_size * 3, num_groups]
    scales_fp16 = scales_fp16.T.contiguous().view(
        num_heads, 3, head_size * num_groups).permute(1, 0, 2).reshape(
            hidden_size * 3, num_groups).contiguous()
    # [hidden_size * 3, num_groups]
    qzeros_unpacked_int8 = qzeros_unpacked_int8.T.contiguous().view(
        num_heads, 3, head_size * num_groups).permute(1, 0, 2).reshape(
            hidden_size * 3, num_groups).contiguous()
    biases_fp16 = biases_fp16.view(num_heads, 3, head_size).permute(
        1, 0, 2).reshape(hidden_size * 3).numpy()

    if tp_size > 1:
        # Shard q, k and v separately along their output rows.
        qweight_unpacked_int8 = torch_split(
            qweight_unpacked_int8.reshape(3, hidden_size, hidden_size),
            tp_size, rank, dim=1).reshape(3 * hidden_size // tp_size,
                                          hidden_size)
        scales_fp16 = torch_split(
            scales_fp16.reshape(3, hidden_size, num_groups),
            tp_size, rank, dim=1).reshape(3 * hidden_size // tp_size,
                                          num_groups)
        qzeros_unpacked_int8 = torch_split(
            qzeros_unpacked_int8.reshape(3, hidden_size, num_groups),
            tp_size, rank, dim=1).reshape(3 * hidden_size // tp_size,
                                          num_groups)
        biases_fp16 = numpy_split(biases_fp16.reshape(3, hidden_size),
                                  tp_size, rank, dim=1).reshape(
                                      3 * hidden_size // tp_size)

    # The preprocessing step receives [k, n] (input x output) layouts.
    qweight_fp32, scales_out, zeros_out = preprocess_groupwise_weight_params(
        qweight_unpacked_int8.T.contiguous(), scales_fp16.T.contiguous(),
        qzeros_unpacked_int8.T.contiguous())
    qkv.qweight.value = qweight_fp32
    qkv.scale.value = scales_out
    qkv.zero.value = zeros_out
    qkv.bias.value = biases_fp16


def _load_gptq_linear(state_dict, name, linear, split_dim, rank, tp_size):
    """Load one GPTQ-quantized, non-QKV linear layer into ``linear``.

    ``name`` is the HF tensor prefix (e.g. ``"...attention.dense"``).
    ``split_dim`` is the dimension of the unpacked [k, n] weight sharded across
    ranks. With 1 (output columns) the bias is sharded too. With 0 (input
    rows) rank 0 keeps the full bias and the other ranks get zeros.
    """
    qweight_int32 = _get_hf_tensor(state_dict, name + ".qweight")
    # .cpu(): these end up in .numpy() and are combined with CPU tensors.
    scales_fp16 = _get_hf_tensor(state_dict, name + ".scales").cpu()
    qzeros_int32 = _get_hf_tensor(state_dict, name + ".qzeros")
    biases_fp16 = _get_hf_tensor(state_dict, name + ".bias").cpu().numpy()

    # [k // 8, n] -> [n, k] with values in [-8, 7] -> [k, n]
    qweight_unpacked_int8 = (unpack_int32_into_int8(qweight_int32.T).contiguous()
                             - 8).T.contiguous()
    # [k // GROUP_SIZE, n // 8] -> [k // GROUP_SIZE, n]
    qzeros_unpacked_int8 = unpack_int32_into_int8(qzeros_int32)

    if tp_size > 1:
        qweight_unpacked_int8 = torch_split(qweight_unpacked_int8, tp_size,
                                            rank, dim=split_dim)
        scales_fp16 = torch_split(scales_fp16, tp_size, rank, dim=split_dim)
        qzeros_unpacked_int8 = torch_split(qzeros_unpacked_int8, tp_size,
                                           rank, dim=split_dim)
        if split_dim == 1:
            # [n] -> [n // tp_size]
            biases_fp16 = numpy_split(biases_fp16, tp_size, rank, dim=0)
        elif rank > 0:
            # NOTE(review): presumably the plugin adds the bias before the
            # cross-rank reduction, so only rank 0 may contribute it — confirm.
            biases_fp16 = np.zeros_like(biases_fp16)

    qweight_fp32, scales_out, zeros_out = preprocess_groupwise_weight_params(
        qweight_unpacked_int8, scales_fp16, qzeros_unpacked_int8)
    linear.qweight.value = qweight_fp32
    linear.scale.value = scales_out
    linear.zero.value = zeros_out
    linear.bias.value = biases_fp16


def load_from_hf_gpt_neox(tensorrt_llm_gpt_neox: GPTNeoXForCausalLM,
                          hf_gpt_neox,
                          fp16=False,
                          rank=0,
                          tp_size=1,
                          use_weight_only_groupwise_quant_matmul_plugin=False):
    """Copy Hugging Face GPT-NeoX weights into a TensorRT-LLM GPT-NeoX model.

    Args:
        tensorrt_llm_gpt_neox: destination model; parameter ``.value``s are
            assigned in place.
        hf_gpt_neox: source HF model; only ``state_dict()`` and ``config``
            (``num_hidden_layers``, ``num_attention_heads``, ``hidden_size``)
            are read.
        fp16: cast unquantized tensors to float16 instead of float32.
        rank: tensor-parallel rank whose shard is loaded.
        tp_size: tensor-parallel world size.
        use_weight_only_groupwise_quant_matmul_plugin: read GPTQ tensors
            (``qweight``/``scales``/``qzeros``) for the attention and MLP
            linears instead of dense weights.

    Raises:
        KeyError: if an expected tensor is missing from the state dict.
    """
    use_gptq = use_weight_only_groupwise_quant_matmul_plugin
    # (HF tensor name, TensorRT-LLM attribute path) within one layer.
    block_names = [
        ("input_layernorm.weight", "input_layernorm.weight"),
        ("input_layernorm.bias", "input_layernorm.bias"),
        ("post_attention_layernorm.weight", "post_attention_layernorm.weight"),
        ("post_attention_layernorm.bias", "post_attention_layernorm.bias"),
    ]
    if not use_gptq:
        block_names += [
            ("attention.dense.weight", "attention.dense.weight"),
            ("attention.dense.bias", "attention.dense.bias"),
            ("mlp.dense_h_to_4h.weight", "mlp.fc.weight"),
            ("mlp.dense_h_to_4h.bias", "mlp.fc.bias"),
            ("mlp.dense_4h_to_h.weight", "mlp.proj.weight"),
            ("mlp.dense_4h_to_h.bias", "mlp.proj.bias"),
        ]
    # Tensor-parallel split dimension per HF tensor; unlisted ones are copied
    # whole to every rank.
    tp_split_dims = {
        # [n=hidden_size, k=hidden_size] -> [n, k // tp_size]
        "attention.dense.weight": 1,
        # [hidden_size * 4, hidden_size] -> [hidden_size * 4 // tp_size, ...]
        "mlp.dense_h_to_4h.weight": 0,
        # [hidden_size * 4] -> [hidden_size * 4 // tp_size]
        "mlp.dense_h_to_4h.bias": 0,
        # [hidden_size, hidden_size * 4] -> [..., hidden_size * 4 // tp_size]
        "mlp.dense_4h_to_h.weight": 1,
    }

    if not use_gptq:
        tensorrt_llm.logger.info('Loading weights from HF GPT-NeoX...')
    else:
        tensorrt_llm.logger.info(
            'Loading weights from GPTQ quantized HF GPT-NeoX...')

    tik = time.time()
    torch_dtype = torch.float16 if fp16 else torch.float32
    state_dict = hf_gpt_neox.state_dict()

    # [vocab_size, hidden_size]
    v = _get_hf_tensor(state_dict, 'gpt_neox.embed_in.weight').to(
        torch_dtype).cpu().numpy()
    if tensorrt_llm_gpt_neox._use_parallel_embedding:
        sharding_dim = tensorrt_llm_gpt_neox._embedding_sharding_dim
        if sharding_dim == 0:
            # Same zero-padding as lm_head below so np.split does not fail on
            # a vocab size that is not a multiple of tp_size.
            v = _pad_rows_to_multiple(v, tp_size)
        v = numpy_split(v, tp_size, rank, sharding_dim)
    tensorrt_llm_gpt_neox.embedding.weight.value = v

    config = hf_gpt_neox.config
    num_heads = config.num_attention_heads
    hidden_size = config.hidden_size
    head_size = hidden_size // num_heads

    for layer_idx in range(config.num_hidden_layers):
        prefix = f"gpt_neox.layers.{layer_idx}."
        trt_layer = tensorrt_llm_gpt_neox.layers[layer_idx]

        for hf_attr, trt_attr in block_names:
            v = _get_hf_tensor(state_dict, prefix + hf_attr).to(
                torch_dtype).cpu().numpy()
            param = attrgetter(trt_attr)(trt_layer)
            split_dim = tp_split_dims.get(hf_attr)
            if tp_size > 1 and split_dim is not None:
                v = numpy_split(v, tp_size, rank, dim=split_dim)
            param.value = v

        if not use_gptq:
            _load_qkv(state_dict, prefix, trt_layer.attention.qkv, num_heads,
                      hidden_size, head_size, torch_dtype, rank, tp_size)
        else:
            _load_gptq_qkv(state_dict, prefix, trt_layer.attention.qkv,
                           num_heads, hidden_size, head_size, rank, tp_size)
            # Input (k) dimension sharded.
            _load_gptq_linear(state_dict, prefix + "attention.dense",
                              trt_layer.attention.dense, 0, rank, tp_size)
            # Output (n) dimension sharded.
            _load_gptq_linear(state_dict, prefix + "mlp.dense_h_to_4h",
                              trt_layer.mlp.fc, 1, rank, tp_size)
            # Input (k) dimension sharded.
            _load_gptq_linear(state_dict, prefix + "mlp.dense_4h_to_h",
                              trt_layer.mlp.proj, 0, rank, tp_size)

    v = _get_hf_tensor(state_dict, 'gpt_neox.final_layer_norm.weight')
    tensorrt_llm_gpt_neox.ln_f.weight.value = v.to(torch_dtype).cpu().numpy()
    v = _get_hf_tensor(state_dict, 'gpt_neox.final_layer_norm.bias')
    tensorrt_llm_gpt_neox.ln_f.bias.value = v.to(torch_dtype).cpu().numpy()

    v = _get_hf_tensor(state_dict, 'embed_out.weight').to(
        torch_dtype).cpu().numpy()
    if tp_size > 1:
        # [vocab_size, hidden_size] -> [padded_vocab_size // tp_size, hidden_size]
        v = numpy_split(_pad_rows_to_multiple(v, tp_size), tp_size, rank,
                        dim=0)
    tensorrt_llm_gpt_neox.lm_head.weight.value = v

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')