diff --git a/sigpt/model/sample.py b/sigpt/model/sample.py
index 7918fc9..686505c 100644
--- a/sigpt/model/sample.py
+++ b/sigpt/model/sample.py
@@ -12,6 +12,7 @@ def generate_tokens(
 ) -> torch.Tensor:
     generations = 0
     with torch.no_grad():
+        model.eval()
         while generations < max_samples:
             logits = model(idx)  # (B, T, vocab_size)
             # We use the last token for prediction
@@ -40,4 +41,4 @@ def generate(
     encoded_prompt = torch.tile(encoded_prompt, (batches, 1))
 
     generated_tokens = generate_tokens(model, encoded_prompt, max_samples, k)
-    return [encoder.decode(i.tolist()) for i in generated_tokens]
\ No newline at end of file
+    return [encoder.decode(i.tolist()) for i in generated_tokens]
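For context on the first hunk: torch.no_grad() only disables gradient tracking; it does not take the model out of training mode, so layers such as dropout would still behave stochastically during sampling. model.eval() is what switches those layers to inference behavior, which is why the hunk adds it alongside the existing no_grad block. A minimal sketch of the distinction, using a toy nn.Sequential model rather than anything from the sigpt codebase:

    import torch
    import torch.nn as nn

    # Toy stand-in for any module whose behavior differs between
    # train and eval mode (dropout here; batchnorm is another example).
    model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
    x = torch.randn(1, 8)

    with torch.no_grad():
        a = model(x)  # no_grad() alone: dropout is still active
        b = model(x)  # so repeated calls give different outputs

    model.eval()  # switch dropout to its deterministic inference behavior
    with torch.no_grad():
        c = model(x)
        d = model(x)  # now repeated calls are identical

    print(torch.equal(a, b))  # almost certainly False (random dropout masks)
    print(torch.equal(c, d))  # True

Placing model.eval() inside the no_grad block, as the diff does, works because the two mechanisms are independent; it only needs to run before the first forward pass of the sampling loop. The second hunk changes no behavior, it just adds the missing newline at the end of the file.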