From e07a5788bfe3fcda97b8d8daa50f5722fc92d206 Mon Sep 17 00:00:00 2001
From: siranipour
Date: Mon, 19 Aug 2024 19:18:38 +0100
Subject: [PATCH] Ensure model is in eval mode when sampling

---
 sigpt/model/sample.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sigpt/model/sample.py b/sigpt/model/sample.py
index 7918fc9..686505c 100644
--- a/sigpt/model/sample.py
+++ b/sigpt/model/sample.py
@@ -12,6 +12,7 @@ def generate_tokens(
 ) -> torch.Tensor:
     generations = 0
     with torch.no_grad():
+        model.eval()
         while generations < max_samples:
             logits = model(idx)  # (B, T, vocab_size)
             # We use the last token for prediction
@@ -40,4 +41,4 @@ def generate(
     encoded_prompt = torch.tile(encoded_prompt, (batches, 1))
 
     generated_tokens = generate_tokens(model, encoded_prompt, max_samples, k)
-    return [encoder.decode(i.tolist()) for i in generated_tokens]
\ No newline at end of file
+    return [encoder.decode(i.tolist()) for i in generated_tokens]
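
Note (not part of the patch): the fix matters because PyTorch modules default to training mode, where train-only layers such as dropout stay active and make sampled generations noisy and nondeterministic. Below is a minimal, self-contained sketch of the effect using a hypothetical toy model with a Dropout layer; the real sigpt model is assumed to contain similar train-only layers.

import torch
import torch.nn as nn

# Hypothetical toy model standing in for the sigpt model; the Dropout
# layer is what makes train vs. eval mode observable here.
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 8))
x = torch.randn(1, 8)

model.train()  # modules start in training mode by default
with torch.no_grad():
    a = model(x)
    b = model(x)
print(torch.allclose(a, b))  # False: dropout randomizes each forward pass

model.eval()  # disables dropout; batchnorm-style layers use running stats
with torch.no_grad():
    a = model(x)
    b = model(x)
print(torch.allclose(a, b))  # True: inference is deterministic

One could also call model.eval() once before entering torch.no_grad(); placing it inside the context, as the patch does, is functionally equivalent, since eval() only flips module training flags and is unrelated to gradient tracking.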