Skip to content

Commit

Permalink
run pre-commit again
Browse files Browse the repository at this point in the history
Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 committed Dec 13, 2024
1 parent 028b4a9 commit 2876528
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
20 changes: 10 additions & 10 deletions tuner/tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None:
td_spec_module = tuner.get_td_spec(ir_module, compilation_info)
assert td_spec_module

named_sequence_ops: list[transform.NamedSequenceOp] = (
op_matchers.get_ops_from_module(
module=td_spec_module,
fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp),
)
named_sequence_ops: list[
transform.NamedSequenceOp
] = op_matchers.get_ops_from_module(
module=td_spec_module,
fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp),
)
apply_config_sequence = None
matcher_sequence = None
Expand Down Expand Up @@ -176,11 +176,11 @@ def test_get_td_spec_convolution(tuner_ctx: common.TunerContext) -> None:
td_spec_module = tuner.get_td_spec(ir_module, compilation_info)
assert td_spec_module

named_sequence_ops: list[transform.NamedSequenceOp] = (
op_matchers.get_ops_from_module(
module=td_spec_module,
fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp),
)
named_sequence_ops: list[
transform.NamedSequenceOp
] = op_matchers.get_ops_from_module(
module=td_spec_module,
fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp),
)
apply_config_sequence = None
matcher_sequence = None
Expand Down
19 changes: 10 additions & 9 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,9 +1017,10 @@ def parse_dispatch_benchmark_results(
benchmark_time = res.get_mean_time_us()
assert benchmark_time is not None
candidate_trackers[candidate_id].first_benchmark_time = benchmark_time
candidate_trackers[candidate_id].spec_path = (
path_config.specs_dir
/ path_config.get_candidate_spec_filename(candidate_id)
candidate_trackers[
candidate_id
].spec_path = path_config.specs_dir / path_config.get_candidate_spec_filename(
candidate_id
)
mlir_path = candidate_trackers[candidate_id].dispatch_mlir_path
spec_path = candidate_trackers[candidate_id].spec_path
Expand Down Expand Up @@ -1283,9 +1284,9 @@ def parse_model_benchmark_results(
]

dump_list = []
incomplete_list: list[tuple[int, Optional[str]]] = (
[]
) # format: [(candidate_id, device_id)]
incomplete_list: list[
tuple[int, Optional[str]]
] = [] # format: [(candidate_id, device_id)]

baseline_time = None
for same_device_results in grouped_benchmark_results:
Expand Down Expand Up @@ -1337,9 +1338,9 @@ def parse_model_benchmark_results(
calibrated_benchmark_diff = (
benchmark_time - baseline_time
) / baseline_time
candidate_trackers[candidate_id].calibrated_benchmark_diff = (
calibrated_benchmark_diff
)
candidate_trackers[
candidate_id
].calibrated_benchmark_diff = calibrated_benchmark_diff
else:
calibrated_benchmark_diff = None

Expand Down

0 comments on commit 2876528

Please sign in to comment.