diff --git a/experiments/run.py b/experiments/run.py
index 68a09a2..40b1983 100755
--- a/experiments/run.py
+++ b/experiments/run.py
@@ -727,19 +727,14 @@ def benchmark(tmp_dir, logs_dir, prefix, plots_dir, num_parallel=1):
     rel_err_example = rel_errs_example[example_binary]
     print(f"Average Rel Error for {prefix}example.exe: {rel_err_example}")
 
-    budgets = []
-    runtimes = []
-    errors = []
-
-    optimized_binaries = []
+    data_tuples = []
 
     args_list = [(cost, tmp_dir, prefix) for cost in costs]
 
     if num_parallel == 1:
         for args in args_list:
             cost, output_binary = process_cost(args)
-            budgets.append(cost)
-            optimized_binaries.append(output_binary)
+            data_tuples.append((cost, output_binary))
     else:
         with ProcessPoolExecutor(max_workers=num_parallel) as executor:
             future_to_cost = {executor.submit(process_cost, args): args[0] for args in args_list}
@@ -747,30 +742,37 @@ def benchmark(tmp_dir, logs_dir, prefix, plots_dir, num_parallel=1):
                 cost = future_to_cost[future]
                 try:
                     cost_result, output_binary = future.result()
-                    budgets.append(cost_result)
-                    optimized_binaries.append(output_binary)
+                    data_tuples.append((cost_result, output_binary))
                 except Exception as exc:
                     print(f"Cost {cost} generated an exception: {exc}")
 
-    # Now measure runtimes serially
-    for cost, output_binary in zip(budgets, optimized_binaries):
+    data_tuples_sorted = sorted(data_tuples, key=lambda x: x[0])
+    sorted_budgets, sorted_optimized_binaries = zip(*data_tuples_sorted) if data_tuples_sorted else ([], [])
+
+    # Measure runtimes serially based on sorted budgets
+    sorted_runtimes = []
+    for cost, output_binary in zip(sorted_budgets, sorted_optimized_binaries):
         avg_runtime = measure_runtime(tmp_dir, prefix, output_binary, NUM_RUNS)
         if avg_runtime is not None:
-            runtimes.append(avg_runtime)
+            sorted_runtimes.append(avg_runtime)
         else:
             print(f"Skipping cost {cost} due to runtime measurement failure.")
-            runtimes.append(None)
+            sorted_runtimes.append(None)
+
+    errors_dict = get_avg_rel_error(tmp_dir, prefix, golden_values_file, sorted_optimized_binaries)
+    sorted_errors = []
+    for binary in sorted_optimized_binaries:
+        sorted_errors.append(errors_dict.get(binary))
+        print(f"Average rel error for {binary}: {errors_dict.get(binary)}")
 
-    errors_dict = get_avg_rel_error(tmp_dir, prefix, golden_values_file, optimized_binaries)
-    errors = []
-    for binary in optimized_binaries:
-        errors.append(errors_dict[binary])
-        print(f"Average rel error for {binary}: {errors_dict[binary]}")
+    sorted_budgets = list(sorted_budgets)
+    sorted_runtimes = list(sorted_runtimes)
+    sorted_errors = list(sorted_errors)
 
     data = {
-        "budgets": budgets,
-        "runtimes": runtimes,
-        "errors": errors,
+        "budgets": sorted_budgets,
+        "runtimes": sorted_runtimes,
+        "errors": sorted_errors,
         "original_runtime": original_runtime,
         "original_error": rel_err_example,
     }
@@ -782,9 +784,9 @@ def benchmark(tmp_dir, logs_dir, prefix, plots_dir, num_parallel=1):
     plot_results(
         plots_dir,
         prefix,
-        budgets,
-        runtimes,
-        errors,
+        sorted_budgets,
+        sorted_runtimes,
+        sorted_errors,
         original_runtime=original_runtime,
         original_error=rel_err_example,
     )
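
A minimal standalone sketch (not part of the patch) of the sort-and-unzip pattern the new code relies on: `(cost, binary)` pairs collected in arbitrary completion order are sorted by cost, then transposed back into parallel lists. The sample values are hypothetical.

```python
# Hypothetical sample data: pairs may arrive out of order when run in parallel.
data_tuples = [(3.0, "bin_c.exe"), (1.0, "bin_a.exe"), (2.0, "bin_b.exe")]

# Sort by the cost budget (first element of each pair).
data_tuples_sorted = sorted(data_tuples, key=lambda x: x[0])

# zip(*pairs) transposes the list of pairs into two tuples; the conditional
# guards the empty case, where unpacking zip() would raise a ValueError.
sorted_budgets, sorted_optimized_binaries = (
    zip(*data_tuples_sorted) if data_tuples_sorted else ([], [])
)

print(list(sorted_budgets))             # [1.0, 2.0, 3.0]
print(list(sorted_optimized_binaries))  # ['bin_a.exe', 'bin_b.exe', 'bin_c.exe']
```

The later `list(...)` conversions in the patch normalize the result type, since `zip` yields tuples while the empty-case fallback yields lists.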