Maintain pytest
RattataKing committed Aug 20, 2024
1 parent 5abe57f commit 3c7a57d
Showing 2 changed files with 110 additions and 297 deletions.
101 changes: 13 additions & 88 deletions tuning/libtuner.py
@@ -169,7 +169,7 @@ def get_model_benchmark_command(


@dataclass
-class TaskTuple:
+class TaskPack:
args: argparse.Namespace
candidate_id: int
command: list[str]
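# Editor's note (hedged): the rest of this dataclass is hidden by the collapsed
# diff region. run_command_wrapper below reads task_tuple.command_need_device_id,
# so the definition plausibly continues with something like:
#     command_need_device_id: bool = False
# Any further fields are not shown in this diff and are not reproduced here.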
@@ -228,15 +228,15 @@ def get_median_time(self) -> Optional[float]:
return None


-def generate_sample_DBR(
+def generate_display_DBR(
candidate_id: int = 0, mean_time: float = random.uniform(100.0, 500.0)
) -> str:
"""Generate dispatch_benchmark_result string for displaying"""
# time unit is implicit and dependent on the output of iree-benchmark-module
return f"{candidate_id}\tMean Time: {mean_time:.1f}\n"


-def generate_sample_MBR(
+def generate_display_MBR(
candidate_vmfb_path_str: str = "baseline.vmfb",
device_id: str = "0",
t1: float = random.uniform(100.0, 500.0),
@@ -253,80 +253,6 @@ def generate_sample_MBR(
return head_str + res_str
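# Editor's sketch (hedged): mirrors the call sites in parse_model_benchmark_results
# below; the collapsed body assembles head_str and res_str. t1's random default,
# like mean_time above, is fixed at function-definition time.
#     generate_display_MBR(
#         candidate_vmfb_path_str="baseline.vmfb",
#         device_id="0",
#         t1=123.4,
#     )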


-# @dataclass
-# class ModelBenchmarkResult:
-# result_str: Optional[str] = None
-
-# def get_tokens(self) -> list[str]:
-# # e.g. ['Benchmarking:', '/sdxl-scripts/tuning/tuning_2024_07_19_08_55/unet_candidate_12.vmfb', 'on', 'device', '4', 'BM_main/process_time/real_time_median', '65.3', 'ms', '66.7', 'ms', '5', 'items_per_second=15.3201/s']
-# if self.result_str is None:
-# return []
-# try:
-# return self.result_str.split()
-# except:
-# return []
-
-# def get_model_candidate_path(self) -> Optional[str]:
-# if len(self.get_tokens()) < 2:
-# return None
-# return self.get_tokens()[1]
-
-# def get_candidate_id(self) -> Optional[int]:
-# if self.get_model_candidate_path():
-# try:
-# path_str = self.get_model_candidate_path()
-# return int(path_str.split("_")[-1].split(".")[0]) if path_str else None
-# except ValueError:
-# return None
-# return None
-
-# def get_device_id(self) -> Optional[int]:
-# if len(self.get_tokens()) < 5:
-# return None
-# try:
-# return int(self.get_tokens()[4])
-# except ValueError:
-# return None
-
-# def get_benchmark_time(self) -> Optional[int | float]:
-# if len(self.get_tokens()) < 7:
-# return None
-# try:
-# return float(self.get_tokens()[6])
-# except ValueError:
-# return None
-
-# def get_calibrated_result_str(self, change: float) -> Optional[str]:
-# if self.result_str is None:
-# return self.result_str
-
-# benchmark_time = self.get_benchmark_time()
-# if benchmark_time is None:
-# return self.result_str
-
-# # Format the change to be added to the string
-# percentage_change = change * 100
-# change_str = f"({percentage_change:+.3f}%)"
-
-# # Use regex to find and replace the old benchmark time with the new one
-# new_result_str = re.sub(
-# r"(\d+(\.\d+)?)\s*ms",
-# lambda m: f"{self.get_benchmark_time()} ms {change_str}",
-# self.result_str,
-# count=1,
-# )
-
-# return new_result_str
-
-# def generate_sample_result(
-# self,
-# candidate_vmfb_path_str: str = "unet_baseline.vmfb",
-# device_id: int = 0,
-# t1: float = random.uniform(100.0, 500.0), # time in ms
-# ) -> str:
-# return f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}\nBM_run_forward/process_time/real_time_median\t {t1:.3g} ms\t {(t1+1):.3g} ms\t 5 items_per_second={t1/200:5f}/s\n\n"


def extract_driver_names(user_devices: list[str]) -> set[str]:
"""Extract driver names from the user devices"""
return {device.split("://")[0] for device in user_devices}
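# Editor's example: the driver name is the scheme before "://" in each device URI.
#     extract_driver_names(["hip://0", "hip://1", "local-sync://default"])
#     -> {"hip", "local-sync"}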
@@ -604,7 +530,7 @@ def run_command(
return result


-def run_command_wrapper(task_tuple: TaskTuple) -> TaskResult:
+def run_command_wrapper(task_tuple: TaskPack) -> TaskResult:
"""pool.imap_unordered can't iterate an iterable of iterables as input; this function helps divide the arguments"""
if task_tuple.command_need_device_id:
# the worker searches for the special symbol and substitutes the correct device_id
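# Editor's sketch (hedged): the wrapper lets a worker pool dispatch packed
# arguments, roughly:
#     with multiprocessing.Pool(
#         len(args.devices), init_worker_context, (worker_context_queue,)
#     ) as pool:
#         results = list(pool.imap_unordered(run_command_wrapper, task_list))
# init_worker_context is an assumption; only create_worker_context_queue
# appears in this diff.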
@@ -849,7 +775,7 @@ def compile_dispatches(
return []

task_list = [
-TaskTuple(
+TaskPack(
args,
candidate_id=i,
command=tuning_client.get_dispatch_compile_command(candidate_trackers[i]),
@@ -933,7 +859,7 @@ def parse_dispatch_benchmark_results(
mlir_path = candidate_trackers[candidate_id].dispatch_mlir_path
spec_path = candidate_trackers[candidate_id].spec_path
assert mlir_path is not None and spec_path is not None
-dump_list.append(generate_sample_DBR(candidate_id, benchmark_time))
+dump_list.append(generate_display_DBR(candidate_id, benchmark_time))

benchmark_result_configs.append(
(
@@ -1015,7 +941,7 @@ def benchmark_dispatches(
else:
# Benchmarking dispatch candidates
task_list = [
-TaskTuple(
+TaskPack(
args,
candidate_id=i,
command=tuning_client.get_dispatch_benchmark_command(
@@ -1096,7 +1022,7 @@ def compile_models(
return []

task_list = [
-TaskTuple(
+TaskPack(
args,
candidate_id=i,
command=tuning_client.get_model_compile_command(candidate_trackers[i]),
@@ -1183,7 +1109,6 @@ def group_benchmark_results_by_device_id(


def parse_model_benchmark_results(
-path_config: PathConfig,
candidate_trackers: list[CandidateTracker],
candidate_results: list[TaskResult],
baseline_results: list[TaskResult],
@@ -1232,7 +1157,7 @@ def parse_model_benchmark_results(
candidate_id
].compiled_model_path
assert baseline_vmfb_path is not None
-dump_str = generate_sample_MBR(
+dump_str = generate_display_MBR(
candidate_vmfb_path_str=baseline_vmfb_path.as_posix(),
device_id=device_id,
t1=benchmark_time,
@@ -1259,7 +1184,7 @@
# Collect candidate dump str
candidate_vmfb_path = candidate_trackers[candidate_id].compiled_model_path
assert candidate_vmfb_path is not None
-dump_str = generate_sample_MBR(
+dump_str = generate_display_MBR(
candidate_vmfb_path_str=candidate_vmfb_path.as_posix(),
device_id=device_id,
t1=benchmark_time,
@@ -1302,7 +1227,7 @@ def benchmark_models(
# Benchmarking model candidates
worker_context_queue = create_worker_context_queue(args.devices)
benchmark_task_list = [
-TaskTuple(
+TaskPack(
args,
candidate_id=i,
command=tuning_client.get_model_benchmark_command(
@@ -1326,7 +1251,7 @@ def benchmark_models(
candidate_trackers[0].compiled_model_path = path_config.model_baseline_vmfb
worker_context_queue = create_worker_context_queue(args.devices)
baseline_task_list = [
-TaskTuple(
+TaskPack(
args,
candidate_id=0,
command=tuning_client.get_model_benchmark_command(
@@ -1345,7 +1270,7 @@
)

dump_list = parse_model_benchmark_results(
-path_config, candidate_trackers, candidate_results, baseline_results
+candidate_trackers, candidate_results, baseline_results
)

append_to_file(
