Skip to content

Commit

Permalink
Added tests. Fixed issues
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexKoff88 committed Nov 20, 2024
1 parent a804972 commit 8256243
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 8 deletions.
18 changes: 17 additions & 1 deletion tools/who_what_benchmark/tests/test_cli_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def run_wwb(args):
logger.info(" ".join(["TRANSFOREMRS_VERBOSITY=debug wwb"] + args))
result = subprocess.run(["wwb"] + args, capture_output=True, text=True)
logger.info(result)
print(" ".join(["TRANSFOREMRS_VERBOSITY=debug wwb"] + args))
return result


Expand Down Expand Up @@ -132,9 +131,26 @@ def test_image_model_genai(model_id, model_type):
output_dir,
]
result = run_wwb(wwb_args)
assert result.returncode == 0
assert os.path.exists(os.path.join(output_dir, "target"))
assert os.path.exists(os.path.join(output_dir, "target.csv"))

# test w/o models
wwb_args = [
"--target-data",
os.path.join(output_dir, "target.csv"),
"--num-samples",
"1",
"--gt-data",
GT_FILE,
"--device",
"CPU",
"--model-type",
model_type,
]
result = run_wwb(wwb_args)
assert result.returncode == 0

try:
os.remove(GT_FILE)
except OSError:
Expand Down
21 changes: 19 additions & 2 deletions tools/who_what_benchmark/tests/test_cli_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ def test_text_target_model():

@pytest.fixture
def test_text_gt_data():
with tempfile.NamedTemporaryFile(suffix=".csv") as tmpfile:
temp_file_name = tmpfile.name
temp_file_name = tempfile.NamedTemporaryFile(suffix=".csv").name

result = run_wwb(
[
Expand Down Expand Up @@ -107,6 +106,8 @@ def test_text_output_directory():
[
"--base-model",
base_model_path,
"--gt-data",
os.path.join(temp_dir, "gt.csv"),
"--target-model",
target_model_path,
"--num-samples",
Expand All @@ -123,6 +124,22 @@ def test_text_output_directory():
assert os.path.exists(os.path.join(temp_dir, "metrics.csv"))
assert os.path.exists(os.path.join(temp_dir, "target.csv"))

# test measurement w/o models
result = run_wwb(
[
"--gt-data",
os.path.join(temp_dir, "gt.csv"),
"--target-data",
os.path.join(temp_dir, "target.csv"),
"--num-samples",
"2",
"--device",
"CPU",
]
)
assert result.returncode == 0
assert "Metrics for model" in result.stderr


def test_text_verbose():
result = run_wwb(
Expand Down
22 changes: 22 additions & 0 deletions tools/who_what_benchmark/tests/test_cli_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_vlm_basic(model_id, model_type):
)
assert result.returncode == 0

# Collect reference with HF model
wwb_args = [
"--base-model",
model_id,
Expand All @@ -53,6 +54,7 @@ def test_vlm_basic(model_id, model_type):
result = run_wwb(wwb_args)
assert result.returncode == 0

# test Optimum
wwb_args = [
"--target-model",
MODEL_PATH,
Expand All @@ -68,6 +70,7 @@ def test_vlm_basic(model_id, model_type):
result = run_wwb(wwb_args)
assert result.returncode == 0

# test GenAI
wwb_args = [
"--target-model",
MODEL_PATH,
Expand All @@ -80,6 +83,25 @@ def test_vlm_basic(model_id, model_type):
"--model-type",
model_type,
"--genai",
"--output",
"target",
]
result = run_wwb(wwb_args)
assert result.returncode == 0

# test w/o models
wwb_args = [
"--target-data",
"target/target.csv",
"--num-samples",
"1",
"--gt-data",
GT_FILE,
"--device",
"CPU",
"--model-type",
model_type,
"--genai",
]
result = run_wwb(wwb_args)
assert result.returncode == 0
Expand Down
7 changes: 2 additions & 5 deletions tools/who_what_benchmark/whowhatbench/wwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,12 +392,9 @@ def parse_args():
def check_args(args):
if args.base_model is None and args.gt_data is None:
raise ValueError("Wether --base-model or --gt-data should be provided")
if args.target_model is None and args.target_data is None:
if args.target_model is None and args.gt_data is None and args.target_data:
raise ValueError(
"Wether --target-model or --target-data should be provided")
if args.target_model is None and args.gt_data is None:
raise ValueError(
"Wether --target-model or --gt-data should be provided")
"Wether --target-model, --target-data or --gt-data should be provided")


def load_tokenizer(args):
Expand Down

0 comments on commit 8256243

Please sign in to comment.