Skip to content

Commit

Permalink
[tuner] Fix typing issues in libtuner
Browse files Browse the repository at this point in the history
Make libtuner and its test type-check with mypy.

Signed-off-by: Jakub Kuderski <[email protected]>
  • Loading branch information
kuhar committed Nov 14, 2024
1 parent e9ba3ef commit dab54b2
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 38 deletions.
16 changes: 11 additions & 5 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import random
import json
from abc import ABC, abstractmethod
import iree.runtime as ireert
import iree.runtime as ireert # type: ignore
from . import candidate_gen


Expand Down Expand Up @@ -250,10 +250,11 @@ def get_mean_time_us(self) -> Optional[float]:
mean_benchmark = self.find_mean_benchmark(self.result_json)

if mean_benchmark:
real_time = mean_benchmark.get("real_time")
time_unit = mean_benchmark.get("time_unit")
real_time: float | None = mean_benchmark.get("real_time")
time_unit: str | None = mean_benchmark.get("time_unit")

if real_time is not None:
assert time_unit is not None
return self.unit_to_microseconds(real_time, time_unit)

return None
Expand Down Expand Up @@ -549,7 +550,7 @@ def create_worker_context_queue(device_ids: list[int]) -> queue.Queue[tuple[int,
return worker_contexts_queue


def run_command(run_pack: RunPack) -> TaskResult:
def run_command(run_pack: RunPack) -> RunResult:
command = run_pack.command
check = run_pack.check
timeout_seconds = run_pack.timeout_seconds
Expand Down Expand Up @@ -946,6 +947,7 @@ def parse_dispatch_benchmark_results(
continue

res_json = extract_benchmark_from_run_result(benchmark_result.run_result)
assert res_json is not None
res = IREEBenchmarkResult(candidate_id, res_json)
benchmark_time = res.get_mean_time_us()
assert benchmark_time is not None
Expand Down Expand Up @@ -985,7 +987,10 @@ def generate_sample_task_result(
stdout=stdout,
returncode=0,
)
return TaskResult(result=res, candidate_id=candidate_id, device_id=device_id)
run_result = RunResult(res, False)
return TaskResult(
run_result=run_result, candidate_id=candidate_id, device_id=device_id
)


def generate_dryrun_dispatch_benchmark_results(
Expand Down Expand Up @@ -1235,6 +1240,7 @@ def parse_model_benchmark_results(
continue

result_json = extract_benchmark_from_run_result(task_result.run_result)
assert result_json is not None
res = IREEBenchmarkResult(candidate_id, result_json)
benchmark_time = res.get_mean_time_us()
assert benchmark_time is not None
Expand Down
78 changes: 45 additions & 33 deletions tuner/tuner/libtuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import argparse
import pytest
import json
from subprocess import CompletedProcess
from unittest.mock import call, patch, MagicMock
from . import libtuner

Expand All @@ -15,15 +16,15 @@
"""


def test_group_benchmark_results_by_device_id():
def test_group_benchmark_results_by_device_id() -> None:
# Create mock TaskResult objects with device_id attributes
task_result_1 = MagicMock()
task_result_1: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult)
task_result_1.device_id = "device_1"

task_result_2 = MagicMock()
task_result_2: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult)
task_result_2.device_id = "device_2"

task_result_3 = MagicMock()
task_result_3: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult)
task_result_3.device_id = "device_1"

benchmark_results = [task_result_1, task_result_2, task_result_3]
Expand All @@ -40,7 +41,7 @@ def test_group_benchmark_results_by_device_id():
assert grouped_results[1][0].device_id == "device_2"


def test_find_collisions():
def test_find_collisions() -> None:
input = [(1, "abc"), (2, "def"), (3, "abc")]
assert libtuner.find_collisions(input) == (True, [("abc", [1, 3]), ("def", [2])])
input = [(1, "abc"), (2, "def"), (3, "hig")]
Expand All @@ -50,14 +51,14 @@ def test_find_collisions():
)


def test_collision_handler():
def test_collision_handler() -> None:
input = [(1, "abc"), (2, "def"), (3, "abc"), (4, "def"), (5, "hig")]
assert libtuner.collision_handler(input) == (True, [1, 2, 5])
input = [(1, "abc"), (2, "def"), (3, "hig")]
assert libtuner.collision_handler(input) == (False, [])


def test_IREEBenchmarkResult_get():
def test_IREEBenchmarkResult_get() -> None:
# Time is int in us
int_json = [{"aggregate_name": "mean", "real_time": 1, "time_unit": "us"}]

Expand Down Expand Up @@ -108,7 +109,7 @@ def test_IREEBenchmarkResult_get():
assert res.get_mean_time_us() == None

# Invalid json: empty dictionary
res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json={})
res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json=[])
assert res.get_mean_time_us() is None

# Invalid json: invalid time unit
Expand All @@ -131,7 +132,7 @@ def test_IREEBenchmarkResult_get():
assert res.get_mean_time_us() is None


def test_generate_display_BR():
def test_generate_display_BR() -> None:
output = libtuner.generate_display_DBR(1, 3.14)
expected = f"1\tMean Time: 3.1"
assert output == expected, "DispatchBenchmarkResult generates invalid sample string"
Expand All @@ -147,29 +148,38 @@ def test_generate_display_BR():
assert output == expected, "ModelBenchmarkResult generates invalid sample string"


def test_parse_dispatch_benchmark_results():
def make_mock_task_result() -> libtuner.TaskResult:
process: CompletedProcess = MagicMock(spec=CompletedProcess)
run_result = libtuner.RunResult(process, False)
task_result = libtuner.TaskResult(run_result, 0, "")
return task_result


def test_parse_dispatch_benchmark_results() -> None:
base_path = libtuner.Path("/mock/base/dir")
spec_dir = base_path / "specs"
path_config = libtuner.PathConfig()
object.__setattr__(path_config, "specs_dir", spec_dir)

mock_result_1 = MagicMock()
mock_result_1 = make_mock_task_result()
mock_json_1 = {
"benchmarks": [
{"aggregate_name": "mean", "real_time": 100.0, "time_unit": "us"}
]
}
assert mock_result_1.run_result.process_res is not None
mock_result_1.run_result.process_res.stdout = json.dumps(mock_json_1)
mock_result_1.candidate_id = 1
mock_result_2 = MagicMock()
mock_result_2 = make_mock_task_result()
mock_json_2 = {
"benchmarks": [
{"aggregate_name": "mean", "real_time": 200.0, "time_unit": "us"}
]
}
assert mock_result_2.run_result.process_res is not None
mock_result_2.run_result.process_res.stdout = json.dumps(mock_json_2)
mock_result_2.candidate_id = 2
mock_result_3 = MagicMock()
mock_result_3 = make_mock_task_result()
mock_json_3 = {
"benchmarks": [
{
Expand All @@ -179,11 +189,11 @@ def test_parse_dispatch_benchmark_results():
}
]
}
assert mock_result_3.run_result.process_res is not None
mock_result_3.run_result.process_res.stdout = json.dumps(mock_json_3)
mock_result_3.candidate_id = 3
mock_result_4 = MagicMock()
mock_result_4.run_result.process_res = None # Incomplete result
mock_result_4.candidate_id = 4
# Incomplete result.
mock_result_4 = libtuner.TaskResult(libtuner.RunResult(None, True), 4, "4")
benchmark_results = [mock_result_1, mock_result_2, mock_result_3, mock_result_4]

candidate_trackers = []
Expand Down Expand Up @@ -239,7 +249,7 @@ def test_parse_dispatch_benchmark_results():
)


def test_parse_model_benchmark_results():
def test_parse_model_benchmark_results() -> None:
# Setup mock data for candidate_trackers
tracker0 = libtuner.CandidateTracker(0)
tracker0.compiled_model_path = libtuner.Path("/path/to/baseline.vmfb")
Expand All @@ -256,38 +266,40 @@ def test_parse_model_benchmark_results():
candidate_trackers = [tracker0, tracker1, tracker2, tracker3]

# Setup mock data for task results
result1 = MagicMock()
result1 = make_mock_task_result()
result_json_1 = {"benchmarks": [{"real_time": 1.23}]}
assert result1.run_result.process_res is not None
result1.run_result.process_res.stdout = json.dumps(result_json_1)
result1.candidate_id = 1
result1.device_id = "device1"

result2 = MagicMock()
result2 = make_mock_task_result()
result_json_2 = {"benchmarks": [{"real_time": 4.56}]}
assert result2.run_result.process_res is not None
result2.run_result.process_res.stdout = json.dumps(result_json_2)
result2.candidate_id = 2
result2.device_id = "device2"

result3 = MagicMock()
result3 = make_mock_task_result()
result_json_3 = {"benchmarks": [{"real_time": 0.98}]}
assert result3.run_result.process_res is not None
result3.run_result.process_res.stdout = json.dumps(result_json_3)
result3.candidate_id = 0
result3.device_id = "device1"

result4 = MagicMock()
result4 = make_mock_task_result()
result_json_4 = {"benchmarks": [{"real_time": 4.13}]}
assert result4.run_result.process_res is not None
result4.run_result.process_res.stdout = json.dumps(result_json_4)
result4.candidate_id = 0
result4.device_id = "device2"

# Incomplete baseline on device3
result5 = MagicMock()
result5.run_result.process_res = None
result5.candidate_id = 0
result5.device_id = "device3"
result5 = libtuner.TaskResult(libtuner.RunResult(None, True), 0, "device3")

result6 = MagicMock()
result6 = make_mock_task_result()
result_json_6 = {"benchmarks": [{"real_time": 3.38}]}
assert result6.run_result.process_res is not None
result6.run_result.process_res.stdout = json.dumps(result_json_6)
result6.candidate_id = 3
result6.device_id = "device3"
Expand Down Expand Up @@ -347,14 +359,14 @@ def mock_get_mean_time_us(self):
)


def test_extract_driver_names():
def test_extract_driver_names() -> None:
user_devices = ["hip://0", "local-sync://default", "cuda://default"]
expected_output = {"hip", "local-sync", "cuda"}

assert libtuner.extract_driver_names(user_devices) == expected_output


def test_fetch_available_devices_success():
def test_fetch_available_devices_success() -> None:
drivers = ["hip", "local-sync", "cuda"]
mock_devices = {
"hip": [{"path": "ABCD", "device_id": 1}],
Expand Down Expand Up @@ -384,7 +396,7 @@ def get_mock_driver(name):
assert actual_output == expected_output


def test_fetch_available_devices_failure():
def test_fetch_available_devices_failure() -> None:
drivers = ["hip", "local-sync", "cuda"]
mock_devices = {
"hip": [{"path": "ABCD", "device_id": 1}],
Expand Down Expand Up @@ -421,7 +433,7 @@ def get_mock_driver(name):
)


def test_parse_devices():
def test_parse_devices() -> None:
user_devices_str = "hip://0, local-sync://default, cuda://default"
expected_output = ["hip://0", "local-sync://default", "cuda://default"]

Expand All @@ -432,7 +444,7 @@ def test_parse_devices():
mock_handle_error.assert_not_called()


def test_parse_devices_with_invalid_input():
def test_parse_devices_with_invalid_input() -> None:
user_devices_str = "hip://0, local-sync://default, invalid_device, cuda://default"
expected_output = [
"hip://0",
Expand All @@ -452,7 +464,7 @@ def test_parse_devices_with_invalid_input():
)


def test_validate_devices():
def test_validate_devices() -> None:
user_devices = ["hip://0", "local-sync://default"]
user_drivers = {"hip", "local-sync"}

Expand All @@ -469,7 +481,7 @@ def test_validate_devices():
)


def test_validate_devices_with_invalid_device():
def test_validate_devices_with_invalid_device() -> None:
user_devices = ["hip://0", "local-sync://default", "cuda://default"]
user_drivers = {"hip", "local-sync", "cuda"}

Expand Down

0 comments on commit dab54b2

Please sign in to comment.