diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index ee331a2a6..368b97d44 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -517,52 +517,48 @@ def tune( with ir.Context() as ctx: tuner_context = TunerContext(ctx, tune_logger) - with parse_mlir(mlir_text, tuner_context) as mlir_module: - # Save the input file as the first candidate. - with open(path.join(output, f"0.mlir"), "w") as f: - f.write(mlir_text) - - dispatch_tuner_registry = DispatchTunerRegistry() - dispatch_tuner_registry.register( - [ - MmtTuner(), - ConvTuner(), - ContractionTuner(lhs_dims, rhs_dims, tile_dims), - BatchMmtTuner(), - BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), - ] - ) - - walk_result: OpWalkResult = walk_mlir_op( - mlir_module, dispatch_tuner_registry - ) + mlir_module = parse_mlir(mlir_text, tuner_context) + # Save the input file as the first candidate. + with open(path.join(output, f"0.mlir"), "w") as f: + f.write(mlir_text) + + dispatch_tuner_registry = DispatchTunerRegistry() + dispatch_tuner_registry.register( + [ + MmtTuner(), + ConvTuner(), + ContractionTuner(lhs_dims, rhs_dims, tile_dims), + BatchMmtTuner(), + BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), + ] + ) - dispatch_tuner = walk_result.dispatch_tuner - assert dispatch_tuner, "No suitable dispatch tuner found" - problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) - tune_logger.debug(str(problem_size)) - configs = [] - for i, config in enumerate( - generate_solutions(tune_logger, problem_size, num_subgroups) - ): - if i >= limit: - break - tune_logger.info(f"Solution #{i+1}: {config}") - configs.append(config) - tf_mlir = dispatch_tuner.apply_params( - problem_size, mlir_template, config - ) - - with open(path.join(output, f"{i+1}.mlir"), "w") as f: - f.write(tf_mlir.modified) - with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: - f.write(tf_mlir.embeddable) - - with open(path.join(output, "configs.pkl"), "wb") as file: - pickle.dump(configs, file) - - tune_logger.info(f"Generated {len(configs)} candidates") - tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") + walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry) + + dispatch_tuner = walk_result.dispatch_tuner + assert dispatch_tuner, "No suitable dispatch tuner found" + problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) + tune_logger.debug(str(problem_size)) + configs = [] + for i, config in enumerate( + generate_solutions(tune_logger, problem_size, num_subgroups) + ): + if i >= limit: + break + tune_logger.info(f"Solution #{i+1}: {config}") + configs.append(config) + tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) + + with open(path.join(output, f"{i+1}.mlir"), "w") as f: + f.write(tf_mlir.modified) + with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: + f.write(tf_mlir.embeddable) + + with open(path.join(output, "configs.pkl"), "wb") as file: + pickle.dump(configs, file) + + tune_logger.info(f"Generated {len(configs)} candidates") + tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") def main():