[tuner] Fix module management (#581)
Module does not support context management
kuhar authored Nov 21, 2024
1 parent dbff2e5 commit 5348a11
Showing 1 changed file with 41 additions and 45 deletions.
tuner/tuner/candidate_gen.py (86 changes: 41 additions & 45 deletions)
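
Why the old form failed: parse_mlir returns a plain ir.Module, and the MLIR Python Module type does not implement the context-manager protocol, so entering "with parse_mlir(...) as mlir_module:" raises before the body ever runs. A minimal self-contained sketch of the failure, using a hypothetical FakeModule as a stand-in for ir.Module (the exact exception type depends on the Python version):

    # FakeModule is a hypothetical stand-in for the ir.Module returned by
    # parse_mlir: neither defines __enter__/__exit__.
    class FakeModule:
        pass

    try:
        with FakeModule() as mlir_module:  # same failure as `with parse_mlir(...)`
            pass
    except (AttributeError, TypeError) as e:
        # Python 3.11+ raises TypeError; older versions raise AttributeError.
        print(f"rejected as a context manager: {e}")

    # The fix in this commit: bind the module with a plain assignment.
    mlir_module = FakeModule()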
@@ -517,52 +517,48 @@ def tune(
 
     with ir.Context() as ctx:
         tuner_context = TunerContext(ctx, tune_logger)
-        with parse_mlir(mlir_text, tuner_context) as mlir_module:
-            # Save the input file as the first candidate.
-            with open(path.join(output, f"0.mlir"), "w") as f:
-                f.write(mlir_text)
-
-            dispatch_tuner_registry = DispatchTunerRegistry()
-            dispatch_tuner_registry.register(
-                [
-                    MmtTuner(),
-                    ConvTuner(),
-                    ContractionTuner(lhs_dims, rhs_dims, tile_dims),
-                    BatchMmtTuner(),
-                    BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
-                ]
-            )
-
-            walk_result: OpWalkResult = walk_mlir_op(
-                mlir_module, dispatch_tuner_registry
-            )
-
-            dispatch_tuner = walk_result.dispatch_tuner
-            assert dispatch_tuner, "No suitable dispatch tuner found"
-            problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
-            tune_logger.debug(str(problem_size))
-            configs = []
-            for i, config in enumerate(
-                generate_solutions(tune_logger, problem_size, num_subgroups)
-            ):
-                if i >= limit:
-                    break
-                tune_logger.info(f"Solution #{i+1}: {config}")
-                configs.append(config)
-                tf_mlir = dispatch_tuner.apply_params(
-                    problem_size, mlir_template, config
-                )
-
-                with open(path.join(output, f"{i+1}.mlir"), "w") as f:
-                    f.write(tf_mlir.modified)
-                with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
-                    f.write(tf_mlir.embeddable)
-
-            with open(path.join(output, "configs.pkl"), "wb") as file:
-                pickle.dump(configs, file)
-
-            tune_logger.info(f"Generated {len(configs)} candidates")
-            tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
+        mlir_module = parse_mlir(mlir_text, tuner_context)
+        # Save the input file as the first candidate.
+        with open(path.join(output, f"0.mlir"), "w") as f:
+            f.write(mlir_text)
+
+        dispatch_tuner_registry = DispatchTunerRegistry()
+        dispatch_tuner_registry.register(
+            [
+                MmtTuner(),
+                ConvTuner(),
+                ContractionTuner(lhs_dims, rhs_dims, tile_dims),
+                BatchMmtTuner(),
+                BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
+            ]
+        )
+
+        walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry)
+
+        dispatch_tuner = walk_result.dispatch_tuner
+        assert dispatch_tuner, "No suitable dispatch tuner found"
+        problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
+        tune_logger.debug(str(problem_size))
+        configs = []
+        for i, config in enumerate(
+            generate_solutions(tune_logger, problem_size, num_subgroups)
+        ):
+            if i >= limit:
+                break
+            tune_logger.info(f"Solution #{i+1}: {config}")
+            configs.append(config)
+            tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config)
+
+            with open(path.join(output, f"{i+1}.mlir"), "w") as f:
+                f.write(tf_mlir.modified)
+            with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
+                f.write(tf_mlir.embeddable)
+
+        with open(path.join(output, "configs.pkl"), "wb") as file:
+            pickle.dump(configs, file)
+
+        tune_logger.info(f"Generated {len(configs)} candidates")
+        tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
 
 
 def main():
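
For reference, the fixed pattern reduced to a standalone sketch; the import path is an assumption (the upstream MLIR Python bindings expose mlir.ir, and IREE ships an equivalent ir module), and "module {}" is placeholder input:

    from mlir import ir  # assumed import path; adjust to your bindings

    with ir.Context():
        # Module.parse picks up the active Context and returns a plain
        # ir.Module; it is bound by assignment, never entered with `with`.
        mlir_module = ir.Module.parse("module {}")
        print(mlir_module)

All uses of the module stay inside the with ir.Context() block, which is what makes the plain assignment safe: the module is only valid while its owning context is alive.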
