Ensure model is in eval mode when sampling
siranipour committed Aug 19, 2024
1 parent 09190ee commit e07a578
Showing 1 changed file with 2 additions and 1 deletion.
sigpt/model/sample.py (3 changes: 2 additions & 1 deletion)

@@ -12,6 +12,7 @@ def generate_tokens(
 ) -> torch.Tensor:
     generations = 0
     with torch.no_grad():
+        model.eval()
         while generations < max_samples:
             logits = model(idx) # (B, T, vocab_size)
             # We use the last token for prediction
@@ -40,4 +41,4 @@ def generate(
     encoded_prompt = torch.tile(encoded_prompt, (batches, 1))
 
     generated_tokens = generate_tokens(model, encoded_prompt, max_samples, k)
-    return [encoder.decode(i.tolist()) for i in generated_tokens]
\ No newline at end of file
+    return [encoder.decode(i.tolist()) for i in generated_tokens]
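
Why this one-line change matters: torch.no_grad() only disables gradient tracking, while model.eval() is what switches modules such as Dropout and BatchNorm into their deterministic inference behavior. A minimal sketch of the distinction, using a hypothetical toy model rather than the actual sigpt model:

import torch
import torch.nn as nn

# Toy model for illustration only (not the sigpt architecture).
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

with torch.no_grad():
    model.train()             # dropout stays active despite no_grad()
    a, b = model(x), model(x)
    print(torch.equal(a, b))  # usually False: random dropout masks differ

    model.eval()              # dropout disabled: forward passes are deterministic
    a, b = model(x), model(x)
    print(torch.equal(a, b))  # True

For a GPT-style sampler, the practical effect of the added model.eval() is that dropout no longer perturbs the logits, so the only randomness left in generation comes from the sampling step itself.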
