diff --git a/sharktank/sharktank/evaluate/perplexity_vmfb.py b/sharktank/sharktank/evaluate/perplexity_vmfb.py index 92313b32d..f1900e6e3 100644 --- a/sharktank/sharktank/evaluate/perplexity_vmfb.py +++ b/sharktank/sharktank/evaluate/perplexity_vmfb.py @@ -176,7 +176,7 @@ def get_prompts(self): s.replace("\n", "").rstrip() for s in test_prompts if s != "" and len(s.split()) >= 20 and s.count("=") < 2 - ][0:4] + ] logger.info(f" num_test_prompts: {len(test_prompts)}") @@ -208,7 +208,7 @@ def prefill_vmfb(self, token_batch, i): ) seq_block_ids = self.batch.pad_block_ids() - prefill_logits = self.runner.ctx.modules.module.prefill_bs4( + prefill_logits = self.runner.ctx.modules.module[f"prefill_bs{self.bs}"]( token_batch, self.seq_lens_batch, seq_block_ids, @@ -239,7 +239,7 @@ def decode_vmfb(self, token_batch, i): self.batch.allocate_seq_block_ids() seq_block_ids = self.batch.pad_block_ids() - decode_logits = self.runner.ctx.modules.module.decode_bs4( + decode_logits = self.runner.ctx.modules.module[f"decode_bs{self.bs}"]( token_batch, self.seq_lens_batch, start_positions,