forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
summarize.py
370 lines (312 loc) · 14.8 KB
/
summarize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import json
import os
import numpy as np
import torch
from datasets import load_dataset, load_metric
from transformers import AutoModelForCausalLM, AutoTokenizer
import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm.logger import logger
from build import get_engine_name # isort:skip
def TRTOPT(args, config):
    """Create a TensorRT-LLM generation session from a serialized OPT engine.

    Args:
        args: Parsed command-line namespace. ``args.engine_dir`` (directory
            holding the engine file) and ``args.log_level`` are read.
        config: Parsed engine ``config.json``. The ``builder_config`` section
            supplies precision, tensor-parallel size and model dimensions;
            the ``plugin_config`` section supplies the plugin switches.

    Returns:
        The ``tensorrt_llm.runtime.GenerationSession`` built from the engine
        file that belongs to the current MPI rank.

    Raises:
        AssertionError: If the tensor-parallel size recorded in ``config``
            differs from the MPI world size of the running job.
    """
    builder_config = config['builder_config']
    dtype = builder_config['precision']
    world_size = builder_config['tensor_parallel']

    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the exception type is kept so existing callers see the same failure.
    runtime_world_size = tensorrt_llm.mpi_world_size()
    if world_size != runtime_world_size:
        raise AssertionError(
            f'Engine world size ({world_size}) != Runtime world size ({runtime_world_size})'
        )

    plugin_config = config['plugin_config']
    use_gpt_attention_plugin = bool(plugin_config['gpt_attention_plugin'])

    # Head count and hidden size are divided by the tensor-parallel size,
    # i.e. the values handed to the runtime are per-rank.
    num_heads = builder_config['num_heads'] // world_size
    hidden_size = builder_config['hidden_size'] // world_size
    vocab_size = builder_config['vocab_size']
    num_layers = builder_config['num_layers']
    remove_input_padding = plugin_config['remove_input_padding']

    model_config = tensorrt_llm.runtime.ModelConfig(
        vocab_size=vocab_size,
        num_layers=num_layers,
        num_heads=num_heads,
        num_kv_heads=num_heads,
        hidden_size=hidden_size,
        gpt_attention_plugin=use_gpt_attention_plugin,
        remove_input_padding=remove_input_padding,
        dtype=dtype)

    runtime_rank = tensorrt_llm.mpi_rank()
    runtime_mapping = tensorrt_llm.Mapping(world_size,
                                           runtime_rank,
                                           tp_size=world_size)
    # Select the device for this rank from its index within the node.
    torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

    engine_name = get_engine_name('opt', dtype, world_size, runtime_rank)
    serialize_path = os.path.join(args.engine_dir, engine_name)

    tensorrt_llm.logger.set_level(args.log_level)

    with open(serialize_path, 'rb') as f:
        engine_buffer = f.read()

    return tensorrt_llm.runtime.GenerationSession(model_config, engine_buffer,
                                                  runtime_mapping)
def main(args):
    """Summarize CNN/DailyMail test articles and report ROUGE scores.

    Depending on the flags in ``args`` the articles are summarized with the
    TensorRT-LLM engine (``--test_trt_llm``), with the Hugging Face model
    (``--test_hf``, rank 0 only), or with both. One sample summary per
    backend is logged first; then up to ``args.max_ite`` batches of
    ``args.batch_size`` articles are summarized, timed with the profiler and
    scored per beam with ROUGE. Results are logged on rank 0.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``log_level``, ``test_hf``, ``test_trt_llm``,
            ``hf_model_location``, ``dataset_path``, ``engine_dir``,
            ``batch_size``, ``top_k``, ``num_beams``, ``data_type``,
            ``max_ite``, ``check_accuracy`` and
            ``tensorrt_llm_rouge1_threshold``.

    Raises:
        AssertionError: If ``args.check_accuracy`` is set and the ROUGE-1
            score of beam 0 for TensorRT-LLM is not above
            ``args.tensorrt_llm_rouge1_threshold``.
    """
    runtime_rank = tensorrt_llm.mpi_rank()
    logger.set_level(args.log_level)

    test_hf = args.test_hf and runtime_rank == 0  # only run hf on rank 0
    test_trt_llm = args.test_trt_llm
    hf_model_location = args.hf_model_location

    tokenizer = AutoTokenizer.from_pretrained(hf_model_location,
                                              padding_side='left')
    # The tokenizer is given the EOS token as its padding token.
    tokenizer.pad_token = tokenizer.eos_token

    dataset_cnn = load_dataset("ccdv/cnn_dailymail",
                               '3.0.0',
                               cache_dir=args.dataset_path)

    config_path = os.path.join(args.engine_dir, 'config.json')
    with open(config_path, 'r') as f:
        config = json.load(f)

    max_batch_size = args.batch_size

    # runtime parameters
    # repetition_penalty = 1
    top_k = args.top_k
    output_len = 100  # number of new tokens generated per article
    test_token_num = 923  # prompts are cut to their last 923 tokens
    # top_p = 0.0
    # random_seed = 5
    temperature = 1
    num_beams = args.num_beams

    # model hyper-parameters
    pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0]
    end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0]

    if test_trt_llm:
        tensorrt_llm_gpt = TRTOPT(args, config)

    if test_hf:
        model = AutoModelForCausalLM.from_pretrained(hf_model_location)
        model.cuda()
        if args.data_type == 'fp16':
            model.half()

    def build_prompts(datapoint):
        """Return one summarization prompt per article in ``datapoint``."""
        # Append the "TL;DR" cue, strip surrounding whitespace, then re-attach
        # detached "n't" contractions.
        return [(article + ' TL;DR: ').strip().replace(" n't", "n't")
                for article in datapoint['article']]

    def summarize_tensorrt_llm(datapoint):
        """Summarize a batch with the TensorRT-LLM session.

        Returns a pair ``(texts, tokens)``: ``texts[batch][beam]`` is the
        decoded summary and ``tokens`` the generated ids as nested lists.
        Both are empty lists on ranks that are not the first pipeline rank.
        """
        prompts = build_prompts(datapoint)
        batch_size = len(prompts)

        line_encoded = []
        input_lengths = []
        for prompt in prompts:
            input_id = tokenizer.encode(prompt,
                                        return_tensors='pt',
                                        add_special_tokens=False).type(
                                            torch.int32)
            # Keep only the trailing ``test_token_num`` tokens.
            input_id = input_id[:, -test_token_num:]
            line_encoded.append(input_id)
            input_lengths.append(input_id.shape[-1])

        # do padding, should move outside the profiling to prevent the overhead
        max_length = max(input_lengths)
        if tensorrt_llm_gpt.remove_input_padding:
            # The encoded prompts are already int32 tensors; move them to the
            # GPU directly instead of re-wrapping them with torch.tensor(),
            # which copies again and emits a copy-construct UserWarning.
            line_encoded = [t.cuda() for t in line_encoded]
        else:
            # Right-pad every prompt with ``pad_id`` to the longest prompt in
            # the batch so they can be stacked into one tensor.
            for i in range(batch_size):
                pad_size = max_length - input_lengths[i]
                pad = torch.full((1, pad_size), pad_id, dtype=torch.int32)
                line_encoded[i] = torch.cat([line_encoded[i], pad], dim=-1)

            line_encoded = torch.cat(line_encoded, dim=0).cuda()
            input_lengths = torch.tensor(input_lengths,
                                         dtype=torch.int32).cuda()

        sampling_config = tensorrt_llm.runtime.SamplingConfig(
            end_id=end_id, pad_id=pad_id, top_k=top_k, num_beams=num_beams)

        with torch.no_grad():
            tensorrt_llm_gpt.setup(batch_size,
                                   max_context_length=max_length,
                                   max_new_tokens=output_len,
                                   beam_width=num_beams)

            if tensorrt_llm_gpt.remove_input_padding:
                output_ids = tensorrt_llm_gpt.decode_batch(
                    line_encoded, sampling_config)
            else:
                output_ids = tensorrt_llm_gpt.decode(
                    line_encoded,
                    input_lengths,
                    sampling_config,
                )

            torch.cuda.synchronize()

        # Extract a list of tensors of shape beam_width x output_ids.
        if tensorrt_llm_gpt.mapping.is_first_pp_rank():
            # Decode only what follows each prompt's own length.
            output_beams_list = [
                tokenizer.batch_decode(output_ids[batch_idx, :,
                                                  input_lengths[batch_idx]:],
                                       skip_special_tokens=True)
                for batch_idx in range(batch_size)
            ]
            return output_beams_list, output_ids[:, :, max_length:].tolist()
        return [], []

    def summarize_hf(datapoint):
        """Summarize a batch with the Hugging Face model.

        Returns a pair ``(texts, tokens)``: ``texts[beam][batch]`` is the
        decoded summary and ``tokens`` the generated ids as nested lists.
        """
        prompts = build_prompts(datapoint)
        batch_size = len(prompts)
        if batch_size > 1:
            logger.warning(
                f"HF does not support batch_size > 1 to verify correctness due to padding. Current batch size is {batch_size}"
            )

        line_encoded = tokenizer(prompts,
                                 return_tensors='pt',
                                 padding=True,
                                 truncation=True)["input_ids"].type(torch.int64)

        line_encoded = line_encoded[:, -test_token_num:]
        line_encoded = line_encoded.cuda()
        # Prompt length after truncation; generated tokens start here.
        context_len = line_encoded.shape[1]

        with torch.no_grad():
            output = model.generate(line_encoded,
                                    max_length=context_len + output_len,
                                    top_k=top_k,
                                    temperature=temperature,
                                    eos_token_id=tokenizer.eos_token_id,
                                    pad_token_id=tokenizer.pad_token_id,
                                    num_beams=num_beams,
                                    num_return_sequences=num_beams,
                                    early_stopping=True)

        tokens_list = output[:, context_len:].tolist()
        output = output.reshape([batch_size, num_beams, -1])
        output_lines_list = [
            tokenizer.batch_decode(output[:, i, context_len:],
                                   skip_special_tokens=True)
            for i in range(num_beams)
        ]

        return output_lines_list, tokens_list

    # Sample run on the first test article; the result is logged.
    if test_trt_llm:
        datapoint = dataset_cnn['test'][0:1]
        summary, _ = summarize_tensorrt_llm(datapoint)
        if runtime_rank == 0:
            logger.info(
                "---------------------------------------------------------")
            logger.info("TensorRT-LLM Generated : ")
            logger.info(f" Article : {datapoint['article']}")
            logger.info(f"\n Highlights : {datapoint['highlights']}")
            logger.info(f"\n Summary : {summary}")
            logger.info(
                "---------------------------------------------------------")

    if test_hf:
        datapoint = dataset_cnn['test'][0:1]
        summary, _ = summarize_hf(datapoint)
        logger.info("---------------------------------------------------------")
        logger.info("HF Generated : ")
        logger.info(f" Article : {datapoint['article']}")
        logger.info(f"\n Highlights : {datapoint['highlights']}")
        logger.info(f"\n Summary : {summary}")
        logger.info("---------------------------------------------------------")

    # One ROUGE metric object per beam and per backend.
    metric_tensorrt_llm = [load_metric("rouge") for _ in range(num_beams)]
    metric_hf = [load_metric("rouge") for _ in range(num_beams)]
    for i in range(num_beams):
        metric_tensorrt_llm[i].seed = 0
        metric_hf[i].seed = 0

    ite_count = 0
    data_point_idx = 0
    while (data_point_idx < len(dataset_cnn['test'])) and (ite_count <
                                                           args.max_ite):
        if runtime_rank == 0:
            logger.debug(
                f"run data_point {data_point_idx} ~ {data_point_idx + max_batch_size}"
            )
        datapoint = dataset_cnn['test'][data_point_idx:(data_point_idx +
                                                        max_batch_size)]

        if test_trt_llm:
            profiler.start('tensorrt_llm')
            summary_tensorrt_llm, tokens_tensorrt_llm = summarize_tensorrt_llm(
                datapoint)
            profiler.stop('tensorrt_llm')

        if test_hf:
            profiler.start('hf')
            summary_hf, tokens_hf = summarize_hf(datapoint)
            profiler.stop('hf')

        if runtime_rank == 0:
            if test_trt_llm:
                # TensorRT-LLM summaries are indexed [batch][beam].
                for batch_idx, beams in enumerate(summary_tensorrt_llm):
                    for beam_idx in range(num_beams):
                        metric_tensorrt_llm[beam_idx].add_batch(
                            predictions=[beams[beam_idx]],
                            references=[datapoint['highlights'][batch_idx]])
            if test_hf:
                # HF summaries are indexed [beam][batch].
                for beam_idx in range(num_beams):
                    for batch_idx, text in enumerate(summary_hf[beam_idx]):
                        metric_hf[beam_idx].add_batch(
                            predictions=[text],
                            references=[datapoint['highlights'][batch_idx]])

            logger.debug('-' * 100)
            logger.debug(f"Article : {datapoint['article']}")
            if test_trt_llm:
                logger.debug(f'TensorRT-LLM Summary: {summary_tensorrt_llm}')
            if test_hf:
                logger.debug(f'HF Summary: {summary_hf}')
            logger.debug(f"highlights : {datapoint['highlights']}")

        data_point_idx += max_batch_size
        ite_count += 1

    if runtime_rank == 0:
        if test_trt_llm:
            np.random.seed(0)  # rouge score use sampling to compute the score
            logger.info(
                f'TensorRT-LLM (total latency: {profiler.elapsed_time_in_sec("tensorrt_llm")} sec)'
            )
            for beam_idx in range(num_beams):
                logger.info(f"TensorRT-LLM beam {beam_idx} result")
                computed_metrics_tensorrt_llm = metric_tensorrt_llm[
                    beam_idx].compute()
                for key in computed_metrics_tensorrt_llm.keys():
                    logger.info(
                        f' {key} : {computed_metrics_tensorrt_llm[key].mid[2]*100}'
                    )

                if args.check_accuracy and beam_idx == 0:
                    rouge1 = computed_metrics_tensorrt_llm['rouge1'].mid[
                        2] * 100
                    # Explicit raise instead of ``assert`` so the accuracy
                    # gate is not stripped under ``python -O``. Written as
                    # ``not (a > b)`` so a NaN score also fails the check.
                    if not rouge1 > args.tensorrt_llm_rouge1_threshold:
                        raise AssertionError(
                            f'TensorRT-LLM rouge1 ({rouge1}) is not above the '
                            f'threshold ({args.tensorrt_llm_rouge1_threshold})'
                        )
        if test_hf:
            np.random.seed(0)  # rouge score use sampling to compute the score
            logger.info(
                f'Hugging Face (total latency: {profiler.elapsed_time_in_sec("hf")} sec)'
            )
            for beam_idx in range(num_beams):
                logger.info(f"HF beam {beam_idx} result")
                computed_metrics_hf = metric_hf[beam_idx].compute()
                for key in computed_metrics_hf.keys():
                    logger.info(
                        f' {key} : {computed_metrics_hf[key].mid[2]*100}')
if __name__ == '__main__':
    # Script entry point: declare the command-line options in one table
    # (registration order is kept), parse them and hand them to main().
    option_table = [
        ('--hf_model_location', dict(type=str, default='opt-350m')),
        ('--test_hf', dict(action='store_true')),
        ('--test_trt_llm', dict(action='store_true')),
        ('--data_type',
         dict(type=str, choices=['fp32', 'fp16'], default='fp32')),
        ('--dataset_path', dict(type=str, default="")),
        ('--log_level', dict(type=str, default='info')),
        ('--engine_dir', dict(type=str, default='gpt_outputs')),
        ('--batch_size', dict(type=int, default=1)),
        ('--max_ite', dict(type=int, default=20)),
        ('--check_accuracy', dict(action='store_true')),
        ('--tensorrt_llm_rouge1_threshold', dict(type=float, default=15.0)),
        ('--num_beams', dict(type=int, default=1)),
        ('--top_k', dict(type=int, default=1)),
    ]

    parser = argparse.ArgumentParser()
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    args = parser.parse_args()
    main(args)