diff --git a/train_gpt2.py b/train_gpt2.py index 403f213..83e0023 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -145,7 +145,7 @@ def from_pretrained(cls, model_type): config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints # create a from-scratch initialized minGPT model config = GPTConfig(**config_args) - model = GPT(config) + model = cls(config) sd = model.state_dict() sd_keys = sd.keys() sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param