Skip to content

Commit

Permalink
Add input len 2048 test for bs4 tp1 8b f16 non-decomposed
Browse files Browse the repository at this point in the history
Signed-off-by: aviator19941 <[email protected]>
  • Loading branch information
aviator19941 committed Dec 3, 2024
1 parent 7eccb76 commit 55c82a0
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions sharktank/tests/models/llama/benchmark_amdgpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def setUp(self):
self.prefill_args_bs4_128_in_tokens_f16 = (
self.artifacts_dir / "prefill_args_bs4_128"
)
self.prefill_args_bs4_2048_in_tokens_f16 = (
self.artifacts_dir / "prefill_args_bs4_2048"
)
self.decode_args_f16 = self.artifacts_dir / "decode_args"
self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8"
self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8"
Expand All @@ -129,6 +132,14 @@ def setUp(self):
f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/cs_f16.npy",
"--benchmark_repetitions=3",
]
self.iree_run_prefill_nondecomposed_args_fp16_2048 = [
"--function=prefill_bs4",
f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/tokens_2048.npy",
f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/seq_lens.npy",
f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/seq_block_ids.npy",
f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/cs_f16.npy",
"--benchmark_repetitions=3",
]
self.iree_run_decode_args = [
"--function=decode_bs4",
f"--input=@{self.decode_args_f16}/tokens.npy",
Expand Down Expand Up @@ -196,8 +207,42 @@ def testBenchmark8B_f16_Decomposed(self):
)

@skipif_run_quick_llama_test
def testBenchmark8B_f16_Non_Decomposed_Prefill(self):
output_file_name = self.dir_path_8b / "f16_torch_prefill"
def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_128(self):
    """Benchmark Llama 8B f16 prefill (non-decomposed SDPA) at input length 128.

    Exports the prefill-only MLIR with the torch attention kernel, compiles
    it to a VMFB for the target HIP device, then runs the IREE benchmark
    with the bs4/128-token prefill inputs prepared in ``setUp``.
    """
    output_file_name = self.dir_path_8b / "f16_torch_prefill_128"
    output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
        suffix=".mlir", prefix=output_file_name
    )
    output_json = self.llama8b_f16_torch_sdpa_artifacts.create_file(
        suffix=".json", prefix=output_file_name
    )
    output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file(
        suffix=".vmfb", prefix=output_file_name
    )
    # Use the non-decomposed (torch SDPA) attention kernel for this variant.
    self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch"
    export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir(
        mlir_path=output_mlir,
        json_path=output_json,
        # Only the prefill entry point is benchmarked here; skip decode export.
        skip_decode=True,
    )
    # Fail fast if export did not succeed, instead of compiling and
    # benchmarking a missing or stale MLIR file (return code 0 == success;
    # previously this value was captured but never checked).
    self.assertEqual(export_return_code, 0, "export_to_mlir failed")
    self.llama8b_f16_torch_sdpa_artifacts.compile_to_vmfb(
        mlir_path=str(output_mlir),
        vmfb_path=output_vmfb,
        hal_dump_path=output_file_name,
        cwd=self.repo_root,
        args=self.compile_args,
    )
    # Benchmark the prefill function with the 128-token bs4 inputs.
    self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
        hip_device_id=self.iree_device,
        vmfb_name=output_vmfb,
        irpa_path=self.irpa_path,
        args=self.iree_run_prefill_nondecomposed_args_fp16,
        cwd=self.repo_root,
    )

@skipif_run_quick_llama_test
def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_2048(self):
output_file_name = self.dir_path_8b / "f16_torch_prefill_2048"
output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
suffix=".mlir", prefix=output_file_name
)
Expand Down

0 comments on commit 55c82a0

Please sign in to comment.