From 756727752c4493395119a0bb33f8b2cc205b3f26 Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Wed, 13 Mar 2024 11:03:12 -0700
Subject: [PATCH] [example] Adds document about how to trace gpt2 model (#3028)

---
 .../inference/nlp/TextGeneration.java    | 15 ++++
 examples/src/main/python/trace_gpt2.py   | 79 ++++++++++++++++++++
 2 files changed, 94 insertions(+)
 create mode 100644 examples/src/main/python/trace_gpt2.py

diff --git a/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java b/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java
index acbaa152f8c..59cba679ba2 100644
--- a/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java
+++ b/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java
@@ -59,6 +59,7 @@ public static String generateTextWithPyTorchGreedy()
         SearchConfig config = new SearchConfig();
         config.setMaxSeqLength(60);
 
+        // You can use src/main/python/trace_gpt2.py to trace the GPT-2 model
         String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip";
 
         Criteria<NDList, CausalLMOutput> criteria =
@@ -160,6 +161,20 @@ public static String[] generateTextWithOnnxRuntimeBeam()
         long padTokenId = 220;
         config.setPadTokenId(padTokenId);
 
+        // The model is converted with optimum:
+        // https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-using-past-keysvalues-in-the-decoder
+        /*
+         * optimum-cli export onnx --model gpt2 gpt2_onnx/
+         *
+         * from transformers import AutoTokenizer
+         * from optimum.onnxruntime import ORTModelForCausalLM
+         *
+         * tokenizer = AutoTokenizer.from_pretrained("./gpt2_onnx/")
+         * model = ORTModelForCausalLM.from_pretrained("./gpt2_onnx/")
+         * inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")
+         * gen_tokens = model.generate(**inputs)
+         * print(tokenizer.batch_decode(gen_tokens))
+         */
         String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_onnx.zip";
 
         Criteria<NDList, CausalLMOutput> criteria =
diff --git a/examples/src/main/python/trace_gpt2.py b/examples/src/main/python/trace_gpt2.py
new file mode 100644
index 00000000000..33c3badb08d
--- /dev/null
+++ b/examples/src/main/python/trace_gpt2.py
@@ -0,0 +1,79 @@
+import torch
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
+
+model_name = 'gpt2-large'
+tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+
+# use the EOS token as the PAD token to avoid warnings
+model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torchscript=True)
+
+# %% model_inputs
+model_inputs = {}
+
+# load a pre-computed past_key_values nested tuple (the file must already exist)
+model_inputs['past_key_values'] = torch.load(
+    "../data/nested_tuple_" + model_name + ".pt")
+past_seq = model_inputs['past_key_values'][0][0].shape[-2]
+model_inputs['input_ids'] = torch.tensor([[404]])
+model_inputs['position_ids'] = torch.tensor([[past_seq]])
+# the attention_mask covers the past sequence plus the one new input token
+model_inputs['attention_mask'] = torch.ones(past_seq + 1, dtype=torch.int64)
+
+model_inputs['use_cache'] = True
+model_inputs['token_type_ids'] = None
+
+model_inputs['return_dict'] = False
+model_inputs['output_attentions'] = False
+model_inputs['output_hidden_states'] = False
+
+# run a single forward pass to verify the inputs work for text generation
+outputs = model(**model_inputs)
+
+# %% Wrapper class of GPT2LMHeadModel
+class Tracable(torch.nn.Module):
+    def __init__(self, config: dict):
+        super().__init__()
+        self.model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torchscript=True)
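+        # Note: torchscript=True makes the Hugging Face model return plain
+        # tuples instead of ModelOutput objects, which torch.jit.trace needs.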
+        self.config = {'use_cache': config.get('use_cache', True),
+                       'token_type_ids': config.get('token_type_ids', None),
+                       'return_dict': config.get('return_dict', False),
+                       'output_attentions': config.get('output_attentions', False),
+                       'output_hidden_states': config.get('output_hidden_states', True)}
+
+    def forward(self, my_input_ids, position_ids, attention_mask, past_key_values):
+        return self.model(input_ids=my_input_ids,
+                          position_ids=position_ids,
+                          attention_mask=attention_mask,
+                          past_key_values=past_key_values,
+                          **self.config)  # outputs are plain tuples in torchscript mode
+
+# %% create the wrapper and the example inputs
+config = {}
+tracable = Tracable(config)
+inputs = (model_inputs['input_ids'],
+          model_inputs['position_ids'],
+          model_inputs['attention_mask'],
+          model_inputs['past_key_values'])
+
+output = tracable(*inputs)
+
+# %% trace
+tracable.eval()
+
+traced_model = torch.jit.trace(tracable, inputs)
+torch.jit.save(traced_model, "../traced_GPT2_hidden.pt")
+
+out1 = traced_model(*inputs)
+
+# %% load back
+loaded_model = torch.jit.load("../traced_GPT2_hidden.pt")
+out2 = loaded_model(*inputs)
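+
+# %% sanity check
+# Illustrative extra check (an addition beyond the minimal tracing flow): the
+# traced module and the reloaded module should reproduce the eager wrapper's
+# logits; the 1e-4 tolerance is an assumption, not a guarantee.
+assert torch.allclose(output[0], out1[0], atol=1e-4)
+assert torch.allclose(out1[0], out2[0], atol=1e-4)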