From dab54b26d4fec418ecf1bbb98c04fb686586a086 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 14 Nov 2024 17:53:14 -0500 Subject: [PATCH] [tuner] Fix typing issues in libtuner Make libtuner and its test type-check with mypy. Signed-off-by: Jakub Kuderski --- tuner/tuner/libtuner.py | 16 +++++--- tuner/tuner/libtuner_test.py | 78 +++++++++++++++++++++--------------- 2 files changed, 56 insertions(+), 38 deletions(-) diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 91c7b417a..3aa932dc4 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -38,7 +38,7 @@ import random import json from abc import ABC, abstractmethod -import iree.runtime as ireert +import iree.runtime as ireert # type: ignore from . import candidate_gen @@ -250,10 +250,11 @@ def get_mean_time_us(self) -> Optional[float]: mean_benchmark = self.find_mean_benchmark(self.result_json) if mean_benchmark: - real_time = mean_benchmark.get("real_time") - time_unit = mean_benchmark.get("time_unit") + real_time: float | None = mean_benchmark.get("real_time") + time_unit: str | None = mean_benchmark.get("time_unit") if real_time is not None: + assert time_unit is not None return self.unit_to_microseconds(real_time, time_unit) return None @@ -549,7 +550,7 @@ def create_worker_context_queue(device_ids: list[int]) -> queue.Queue[tuple[int, return worker_contexts_queue -def run_command(run_pack: RunPack) -> TaskResult: +def run_command(run_pack: RunPack) -> RunResult: command = run_pack.command check = run_pack.check timeout_seconds = run_pack.timeout_seconds @@ -946,6 +947,7 @@ def parse_dispatch_benchmark_results( continue res_json = extract_benchmark_from_run_result(benchmark_result.run_result) + assert res_json is not None res = IREEBenchmarkResult(candidate_id, res_json) benchmark_time = res.get_mean_time_us() assert benchmark_time is not None @@ -985,7 +987,10 @@ def generate_sample_task_result( stdout=stdout, returncode=0, ) - return TaskResult(result=res, candidate_id=candidate_id, device_id=device_id) + run_result = RunResult(res, False) + return TaskResult( + run_result=run_result, candidate_id=candidate_id, device_id=device_id + ) def generate_dryrun_dispatch_benchmark_results( @@ -1235,6 +1240,7 @@ def parse_model_benchmark_results( continue result_json = extract_benchmark_from_run_result(task_result.run_result) + assert result_json is not None res = IREEBenchmarkResult(candidate_id, result_json) benchmark_time = res.get_mean_time_us() assert benchmark_time is not None diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py index 36bda3bd5..11af59af4 100644 --- a/tuner/tuner/libtuner_test.py +++ b/tuner/tuner/libtuner_test.py @@ -7,6 +7,7 @@ import argparse import pytest import json +from subprocess import CompletedProcess from unittest.mock import call, patch, MagicMock from . import libtuner @@ -15,15 +16,15 @@ """ -def test_group_benchmark_results_by_device_id(): +def test_group_benchmark_results_by_device_id() -> None: # Create mock TaskResult objects with device_id attributes - task_result_1 = MagicMock() + task_result_1: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult) task_result_1.device_id = "device_1" - task_result_2 = MagicMock() + task_result_2: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult) task_result_2.device_id = "device_2" - task_result_3 = MagicMock() + task_result_3: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult) task_result_3.device_id = "device_1" benchmark_results = [task_result_1, task_result_2, task_result_3] @@ -40,7 +41,7 @@ def test_group_benchmark_results_by_device_id(): assert grouped_results[1][0].device_id == "device_2" -def test_find_collisions(): +def test_find_collisions() -> None: input = [(1, "abc"), (2, "def"), (3, "abc")] assert libtuner.find_collisions(input) == (True, [("abc", [1, 3]), ("def", [2])]) input = [(1, "abc"), (2, "def"), (3, "hig")] @@ -50,14 +51,14 @@ def test_find_collisions(): ) -def test_collision_handler(): +def test_collision_handler() -> None: input = [(1, "abc"), (2, "def"), (3, "abc"), (4, "def"), (5, "hig")] assert libtuner.collision_handler(input) == (True, [1, 2, 5]) input = [(1, "abc"), (2, "def"), (3, "hig")] assert libtuner.collision_handler(input) == (False, []) -def test_IREEBenchmarkResult_get(): +def test_IREEBenchmarkResult_get() -> None: # Time is int in us int_json = [{"aggregate_name": "mean", "real_time": 1, "time_unit": "us"}] @@ -108,7 +109,7 @@ def test_IREEBenchmarkResult_get(): assert res.get_mean_time_us() == None # Invalid json: empty dictionary - res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json={}) + res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json=[]) assert res.get_mean_time_us() is None # Invalid json: invalid time unit @@ -131,7 +132,7 @@ def test_IREEBenchmarkResult_get(): assert res.get_mean_time_us() is None -def test_generate_display_BR(): +def test_generate_display_BR() -> None: output = libtuner.generate_display_DBR(1, 3.14) expected = f"1\tMean Time: 3.1" assert output == expected, "DispatchBenchmarkResult generates invalid sample string" @@ -147,29 +148,38 @@ def test_generate_display_BR(): assert output == expected, "ModelBenchmarkResult generates invalid sample string" -def test_parse_dispatch_benchmark_results(): +def make_mock_task_result() -> libtuner.TaskResult: + process: CompletedProcess = MagicMock(spec=CompletedProcess) + run_result = libtuner.RunResult(process, False) + task_result = libtuner.TaskResult(run_result, 0, "") + return task_result + + +def test_parse_dispatch_benchmark_results() -> None: base_path = libtuner.Path("/mock/base/dir") spec_dir = base_path / "specs" path_config = libtuner.PathConfig() object.__setattr__(path_config, "specs_dir", spec_dir) - mock_result_1 = MagicMock() + mock_result_1 = make_mock_task_result() mock_json_1 = { "benchmarks": [ {"aggregate_name": "mean", "real_time": 100.0, "time_unit": "us"} ] } + assert mock_result_1.run_result.process_res is not None mock_result_1.run_result.process_res.stdout = json.dumps(mock_json_1) mock_result_1.candidate_id = 1 - mock_result_2 = MagicMock() + mock_result_2 = make_mock_task_result() mock_json_2 = { "benchmarks": [ {"aggregate_name": "mean", "real_time": 200.0, "time_unit": "us"} ] } + assert mock_result_2.run_result.process_res is not None mock_result_2.run_result.process_res.stdout = json.dumps(mock_json_2) mock_result_2.candidate_id = 2 - mock_result_3 = MagicMock() + mock_result_3 = make_mock_task_result() mock_json_3 = { "benchmarks": [ { @@ -179,11 +189,11 @@ def test_parse_dispatch_benchmark_results(): } ] } + assert mock_result_3.run_result.process_res is not None mock_result_3.run_result.process_res.stdout = json.dumps(mock_json_3) mock_result_3.candidate_id = 3 - mock_result_4 = MagicMock() - mock_result_4.run_result.process_res = None # Incomplete result - mock_result_4.candidate_id = 4 + # Incomplete result. + mock_result_4 = libtuner.TaskResult(libtuner.RunResult(None, True), 4, "4") benchmark_results = [mock_result_1, mock_result_2, mock_result_3, mock_result_4] candidate_trackers = [] @@ -239,7 +249,7 @@ def test_parse_dispatch_benchmark_results(): ) -def test_parse_model_benchmark_results(): +def test_parse_model_benchmark_results() -> None: # Setup mock data for candidate_trackers tracker0 = libtuner.CandidateTracker(0) tracker0.compiled_model_path = libtuner.Path("/path/to/baseline.vmfb") @@ -256,38 +266,40 @@ def test_parse_model_benchmark_results(): candidate_trackers = [tracker0, tracker1, tracker2, tracker3] # Setup mock data for task results - result1 = MagicMock() + result1 = make_mock_task_result() result_json_1 = {"benchmarks": [{"real_time": 1.23}]} + assert result1.run_result.process_res is not None result1.run_result.process_res.stdout = json.dumps(result_json_1) result1.candidate_id = 1 result1.device_id = "device1" - result2 = MagicMock() + result2 = make_mock_task_result() result_json_2 = {"benchmarks": [{"real_time": 4.56}]} + assert result2.run_result.process_res is not None result2.run_result.process_res.stdout = json.dumps(result_json_2) result2.candidate_id = 2 result2.device_id = "device2" - result3 = MagicMock() + result3 = make_mock_task_result() result_json_3 = {"benchmarks": [{"real_time": 0.98}]} + assert result3.run_result.process_res is not None result3.run_result.process_res.stdout = json.dumps(result_json_3) result3.candidate_id = 0 result3.device_id = "device1" - result4 = MagicMock() + result4 = make_mock_task_result() result_json_4 = {"benchmarks": [{"real_time": 4.13}]} + assert result4.run_result.process_res is not None result4.run_result.process_res.stdout = json.dumps(result_json_4) result4.candidate_id = 0 result4.device_id = "device2" # Incomplete baseline on device3 - result5 = MagicMock() - result5.run_result.process_res = None - result5.candidate_id = 0 - result5.device_id = "device3" + result5 = libtuner.TaskResult(libtuner.RunResult(None, True), 0, "device3") - result6 = MagicMock() + result6 = make_mock_task_result() result_json_6 = {"benchmarks": [{"real_time": 3.38}]} + assert result6.run_result.process_res is not None result6.run_result.process_res.stdout = json.dumps(result_json_6) result6.candidate_id = 3 result6.device_id = "device3" @@ -347,14 +359,14 @@ def mock_get_mean_time_us(self): ) -def test_extract_driver_names(): +def test_extract_driver_names() -> None: user_devices = ["hip://0", "local-sync://default", "cuda://default"] expected_output = {"hip", "local-sync", "cuda"} assert libtuner.extract_driver_names(user_devices) == expected_output -def test_fetch_available_devices_success(): +def test_fetch_available_devices_success() -> None: drivers = ["hip", "local-sync", "cuda"] mock_devices = { "hip": [{"path": "ABCD", "device_id": 1}], @@ -384,7 +396,7 @@ def get_mock_driver(name): assert actual_output == expected_output -def test_fetch_available_devices_failure(): +def test_fetch_available_devices_failure() -> None: drivers = ["hip", "local-sync", "cuda"] mock_devices = { "hip": [{"path": "ABCD", "device_id": 1}], @@ -421,7 +433,7 @@ def get_mock_driver(name): ) -def test_parse_devices(): +def test_parse_devices() -> None: user_devices_str = "hip://0, local-sync://default, cuda://default" expected_output = ["hip://0", "local-sync://default", "cuda://default"] @@ -432,7 +444,7 @@ def test_parse_devices(): mock_handle_error.assert_not_called() -def test_parse_devices_with_invalid_input(): +def test_parse_devices_with_invalid_input() -> None: user_devices_str = "hip://0, local-sync://default, invalid_device, cuda://default" expected_output = [ "hip://0", @@ -452,7 +464,7 @@ def test_parse_devices_with_invalid_input(): ) -def test_validate_devices(): +def test_validate_devices() -> None: user_devices = ["hip://0", "local-sync://default"] user_drivers = {"hip", "local-sync"} @@ -469,7 +481,7 @@ def test_validate_devices(): ) -def test_validate_devices_with_invalid_device(): +def test_validate_devices_with_invalid_device() -> None: user_devices = ["hip://0", "local-sync://default", "cuda://default"] user_drivers = {"hip", "local-sync", "cuda"}