Skip to content

Commit

Permalink
Add benchmark support for e2e tests (#183)
Browse files Browse the repository at this point in the history
Signed-off-by: erman-gurses <[email protected]>
  • Loading branch information
erman-gurses authored Oct 3, 2024
1 parent 9ed388a commit 0f00c6d
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
if: "contains(matrix.os, 'mi300') && !cancelled()"
run: |
export WAVE_RUN_E2E_TESTS=1
pytest -n 4 ./tests/kernel/wave/
pytest -n 4 --capture=tee-sys ./tests/kernel/wave/
- name: Run LIT tests
if: ${{ !cancelled() }}
Expand Down
27 changes: 26 additions & 1 deletion shark_turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
import torch.fx as fx
import shark_turbine.kernel.lang as tkl


import tempfile
from ...support.conversions import TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM
from iree.compiler.dialects.transform import (
interpreter as transform_interpreter,
any_op_t,
Expand Down Expand Up @@ -372,7 +375,29 @@ def compile_and_invoke(
_invoke(ctx.vm_context, device, func, kernel_inputs, kernel_outputs)

if run_bench:
inputs = [inp.numpy() for inp in kernel_inputs]
bench_with_constant_weights = config.get("bench_with_constant_weights", False)
tempfiles = []
inputs = []
if bench_with_constant_weights:
for inp in kernel_inputs:
inputs.append(
"x".join(
[str(x) for x in inp.shape]
+ [TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM[inp.dtype]]
)
)
else:
for inp in kernel_inputs:
tf = tempfile.NamedTemporaryFile()
torch.save(inp, tf)
tempfiles.append(tf)
inputs.append("@" + tf.name)

benchmark_results = bench.benchmark_module(
mod,
entry_function=func_name,
)

benchmark_results = bench.benchmark_module(
mod,
entry_function=func_name,
Expand Down
8 changes: 8 additions & 0 deletions tests/kernel/wave/wave_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def test(
},
canonicalize=True,
run=True,
run_bench=True,
run_config=config,
):
test(a, b)
Expand Down Expand Up @@ -214,6 +215,7 @@ def test(
},
canonicalize=True,
run=True,
run_bench=True,
run_config=config,
):
test(a, b)
Expand Down Expand Up @@ -270,6 +272,7 @@ def test(
},
canonicalize=True,
run=True,
run_bench=True,
run_config=config,
):
test(a, b)
Expand Down Expand Up @@ -326,6 +329,7 @@ def test(
},
canonicalize=True,
run=True,
run_bench=True,
run_config=config,
):
test(a, b, c)
Expand Down Expand Up @@ -401,6 +405,7 @@ def repeat(
},
canonicalize=True,
run=True,
run_bench=True,
run_config=config,
):
test(a, b, c)
Expand Down Expand Up @@ -505,6 +510,7 @@ def test(
},
canonicalize=True,
run=True,
run_bench=True,
run_config=config,
):
test(a, b)
Expand Down Expand Up @@ -635,6 +641,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
},
canonicalize=True,
run=True,
run_bench=True,
run_config=config,
):
gpu_func(x, we, out)
Expand Down Expand Up @@ -949,6 +956,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
},
canonicalize=True,
run=True,
run_bench=True,
run_config=config,
):
conv(x, we, out)
Expand Down
1 change: 1 addition & 0 deletions tests/kernel/wave/wave_gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
hyperparams,
canonicalize=True,
run=True,
run_bench=True,
run_config=config,
schedule=enable_scheduling,
):
Expand Down

0 comments on commit 0f00c6d

Please sign in to comment.