[example] Adds document about how to trace gpt2 model (#3028)
frankfliu authored Mar 13, 2024
1 parent 8c5ed49 commit 7567277
Showing 2 changed files with 88 additions and 0 deletions.
@@ -59,6 +59,7 @@ public static String generateTextWithPyTorchGreedy()
SearchConfig config = new SearchConfig();
config.setMaxSeqLength(60);

// You can use src/main/python/trace_gpt2.py to trace the GPT-2 model
String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip";

Criteria<NDList, CausalLMOutput> criteria =
@@ -160,6 +161,20 @@ public static String[] generateTextWithOnnxRuntimeBeam()
long padTokenId = 220;
config.setPadTokenId(padTokenId);

// The model is converted with optimum:
// https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-using-past-keysvalues-in-the-decoder
/*
* optimum-cli export onnx --model gpt2 gpt2_onnx/
*
* from transformers import AutoTokenizer
* from optimum.onnxruntime import ORTModelForCausalLM
*
* tokenizer = AutoTokenizer.from_pretrained("./gpt2_onnx/")
* model = ORTModelForCausalLM.from_pretrained("./gpt2_onnx/")
* inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")
* gen_tokens = model.generate(**inputs)
* print(tokenizer.batch_decode(gen_tokens))
*/
String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_onnx.zip";

Criteria<NDList, CausalLMOutput> criteria =
examples/src/main/python/trace_gpt2.py (73 additions, 0 deletions)
@@ -0,0 +1,73 @@
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_name = 'gpt2-large'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torchscript=True)
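# torchscript=True makes the model return plain tuples instead of ModelOutput
# objects, which is what torch.jit.trace requires.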

# %% model_inputs
model_inputs = {}
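# The script loads a pre-generated nested tuple of per-layer (key, value)
# tensors from disk below. If you do not have that file, this fallback (an
# assumption, not part of the original script) creates it by running the
# model once on a short prompt:
import os

pkv_path = "../data/nested_tuple_" + model_name + ".pt"
if not os.path.exists(pkv_path):
    os.makedirs("../data", exist_ok=True)
    prompt = tokenizer("My name is Arthur and I live in", return_tensors="pt")
    with torch.no_grad():
        # torchscript=True -> tuple output: (logits, past_key_values)
        logits, past_key_values = model(prompt.input_ids, use_cache=True)
    torch.save(past_key_values, pkv_path)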

model_inputs['past_key_values'] = torch.load(
    "../data/nested_tuple_" + model_name + ".pt")
past_seq = model_inputs['past_key_values'][0][0].shape[-2]
model_inputs['input_ids'] = torch.tensor([[404]])
model_inputs['position_ids'] = torch.tensor([[past_seq]])
# attention_mask length = past sequence length + new input length (past_seq + 1)
model_inputs['attention_mask'] = torch.ones(past_seq + 1, dtype=torch.int64)

model_inputs['use_cache'] = True
model_inputs['token_type_ids'] = None

model_inputs['return_dict'] = False
model_inputs['output_attentions'] = False
model_inputs['output_hidden_states'] = False

# Smoke-test one forward pass with the assembled inputs
outputs = model(**model_inputs)
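# With use_cache=True and torchscript=True the call returns a tuple:
# outputs[0] is the logits (batch, seq_len, vocab_size), outputs[1] the
# updated past_key_values.
print(outputs[0].shape)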

# %% Wrapper class of GPT2LMHeadModel
from typing import Tuple
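# torch.jit.trace only accepts positional tensor (and nested-tuple) arguments,
# so the wrapper pins the non-tensor generation flags in self.config and
# exposes a tensor-only forward signature.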

class Tracable(torch.nn.Module):

    def __init__(self, config: dict):
        super().__init__()
        self.model = GPT2LMHeadModel.from_pretrained(model_name,
                                                     pad_token_id=tokenizer.eos_token_id,
                                                     torchscript=True)
        self.config = {'use_cache': config.get('use_cache', True),
                       'token_type_ids': config.get('token_type_ids', None),
                       'return_dict': config.get('return_dict', False),
                       'output_attentions': config.get('output_attentions', False),
                       'output_hidden_states': config.get('output_hidden_states', True)}

    def forward(self, my_input_ids, position_ids, attention_mask, past_key_values):
        return self.model(input_ids=my_input_ids,
                          position_ids=position_ids,
                          attention_mask=attention_mask,
                          past_key_values=past_key_values,
                          **self.config)  # return_dict=False -> plain tuple output

# %% instantiate the wrapper and build example inputs
config = {}
tracable = Tracable(config)
example_inputs = (model_inputs['input_ids'],
                  model_inputs['position_ids'],
                  model_inputs['attention_mask'],
                  model_inputs['past_key_values'])

output = tracable(*example_inputs)

# %% trace
tracable.eval()

traced_model = torch.jit.trace(tracable, example_inputs)
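# Note: tracing records operations for this exact input structure; Python
# control flow is frozen, so the traced module is only valid for inputs with
# the same nesting, dtypes, and batch layout.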
torch.jit.save(traced_model, "../traced_GPT2_hidden.pt")

out1 = traced_model(*example_inputs)

# %% load back
loaded_model = torch.jit.load("../traced_GPT2_hidden.pt")
out2 = loaded_model(*example_inputs)
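# Sanity checks (not in the original script): the traced and reloaded models
# should reproduce the eager wrapper's logits on the same inputs.
print(torch.allclose(output[0], out1[0], atol=1e-4))  # expect True
print(torch.allclose(out1[0], out2[0]))               # expect True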
