Merge pull request #7 from annahedstroem/new-feature-batch-size
include batch_size as arg in meta_evaluation.py
annahedstroem authored Mar 27, 2024
2 parents 95cdb98 + 436efaf commit dcb86c2
Showing 8 changed files with 24 additions and 1 deletion.
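
The change threads a single `batch_size` argument from the benchmarking entry point down to the underlying Quantus metric calls. A minimal usage sketch, mirroring the updated test in tests/test_benchmarking.py; the class name `MetaEvaluationBenchmarking` and the pre-built `master`, `estimators` and `experimental_settings` objects are assumptions, not shown in this diff:

    from metaquantus import MetaEvaluationBenchmarking

    # `master` (a MetaEvaluation instance), `estimators` and
    # `experimental_settings` are assumed to be set up as in the tests.
    benchmark = MetaEvaluationBenchmarking(
        master=master,
        estimators=estimators,
        experimental_settings=experimental_settings,
        keep_results=True,
        channel_first=True,
        softmax=False,
        batch_size=64,  # new argument: forwarded to every Quantus metric call
        device="cpu",
        save=False,
    )
    results = benchmark()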
5 changes: 5 additions & 0 deletions metaquantus/benchmarking.py
@@ -29,6 +29,7 @@ def __init__(
         keep_results: bool = False,
         channel_first: Optional[bool] = True,
         softmax: Optional[bool] = False,
+        batch_size: Optional[int] = 64,
         device: Optional[str] = None,
     ):
         """
@@ -56,6 +57,8 @@ def __init__(
             Indicates if channels is first.
         softmax: bool
             Indicates if the softmax (or logits) are used.
+        batch_size: int
+            The batch size to run Quantus evaluation with.
         device: torch.device
             The device used, to enable GPUs.
@@ -73,6 +76,7 @@ def __init__(
         self.write_to_file = write_to_file
         self.channel_first = channel_first
         self.softmax = softmax
+        self.batch_size = batch_size
         self.device = device
         self.name = self.master.fname
@@ -127,6 +131,7 @@ def __call__(self, *args, **kwargs) -> dict:
                     s_batch=settings_data["s_batch"],
                     channel_first=self.channel_first,
                     softmax=self.softmax,
+                    batch_size=self.batch_size,
                     device=self.device,
                     score_direction=estimator["score_direction"],
                 )
7 changes: 7 additions & 0 deletions metaquantus/meta_evaluation.py
@@ -130,6 +130,7 @@ def __call__(
         channel_first: Optional[bool] = True,
         softmax: Optional[bool] = False,
         device: Optional[str] = None,
+        batch_size: Optional[int] = 64,
         model_predict_kwargs: Optional[Dict[str, Any]] = {},
         score_direction: Optional[str] = None,
     ):
@@ -154,6 +155,8 @@ def __call__(
             Indicates if channels is first.
         softmax: bool
             Indicates if the softmax (or logits) are used.
+        batch_size: int
+            The batch size to run Quantus evaluation with.
         device: torch.device
             The device used, to enable GPUs.
         model_predict_kwargs: dict
@@ -179,6 +182,7 @@ def __call__(
             channel_first=channel_first,
             softmax=softmax,
             device=device,
+            batch_size=batch_size,
         )

         # Run inference.
@@ -255,6 +259,7 @@ def run_perturbation_analysis(
         s_batch: Union[np.array, None] = None,
         channel_first: Optional[bool] = True,
         softmax: Optional[bool] = False,
+        batch_size: Optional[int] = 64,
         device: Optional[str] = None,
         model_predict_kwargs: Optional[Dict[str, Any]] = {},
     ):
@@ -352,6 +357,7 @@ def run_perturbation_analysis(
                 },
                 model_predict_kwargs=model_predict_kwargs,
                 softmax=softmax,
+                batch_size=batch_size,
                 device=device,
             )
@@ -397,6 +403,7 @@ def run_perturbation_analysis(
                 model_predict_kwargs=model_predict_kwargs,
                 softmax=softmax,
                 device=device,
+                batch_size=batch_size,
             )

             self.results_eval_scores_perturbed[test_name][i] = scores_perturbed
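
The `batch_size` forwarded here is what the docstrings describe as "The batch size to run Quantus evaluation with": it ends up in each Quantus metric's `__call__`, where it controls how many samples are scored per forward pass. A minimal sketch of the equivalent direct call, assuming a recent Quantus release (metrics accept `batch_size` in `__call__`) and predefined `model`, `x_batch`, `y_batch` and `a_batch`:

    import quantus

    metric = quantus.MaxSensitivity(nr_samples=10)
    scores = metric(
        model=model,
        x_batch=x_batch,
        y_batch=y_batch,
        a_batch=a_batch,
        explain_func=quantus.explain,
        explain_func_kwargs={"method": "Saliency"},
        channel_first=True,
        softmax=False,
        batch_size=50,  # same value MetaEvaluation now forwards
        device="cpu",
    )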
1 change: 1 addition & 0 deletions metaquantus/perturbation_tests/base.py
@@ -34,6 +34,7 @@ def __call__(
         explain_func_kwargs: Optional[Dict[str, Any]],
         model_predict_kwargs: Optional[Dict],
         softmax: Optional[bool],
+        batch_size: Optional[int],
         device: Optional[str],
     ) -> Union[int, float, list, dict, None]:
         raise NotImplementedError
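
Because the abstract `__call__` above defines the shared interface for all perturbation tests, any custom test must now accept the extra keyword as well. A minimal sketch of a conforming subclass; `PerturbationTestBase` is a hypothetical stand-in for the base class defined in this file, and the signature is abbreviated with *args/**kwargs:

    from typing import Optional

    class MyPerturbationTest(PerturbationTestBase):  # base class name assumed
        """Hypothetical test illustrating the extended interface."""

        def __call__(
            self,
            *args,
            softmax: Optional[bool] = False,
            batch_size: Optional[int] = 64,
            device: Optional[str] = None,
            **kwargs,
        ):
            # A real test forwards `batch_size` into the wrapped Quantus
            # estimator call, as ipt.py and mpt.py below now do.
            ...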
4 changes: 4 additions & 0 deletions metaquantus/perturbation_tests/ipt.py
@@ -52,6 +52,7 @@ def __call__(
         explain_func: Optional[Callable],
         model_predict_kwargs: Optional[Dict],
         softmax: Optional[bool],
+        batch_size: Optional[int],
         device: Optional[str],
     ) -> Tuple[dict, np.ndarray, dict]:
         """
@@ -83,6 +84,8 @@ def __call__(
             Extra kwargs when running model.predict.
         softmax: bool
             Indicates if the softmax (or logits) are used.
+        batch_size: int
+            The batch size to run Quantus evaluation with.
         device: torch.device
             The device used, to enable GPUs.
@@ -175,6 +178,7 @@ def __call__(
                 },
                 model_predict_kwargs=model_predict_kwargs,
                 channel_first=channel_first,
+                batch_size=batch_size,
                 softmax=softmax,
                 device=device,
             )
4 changes: 4 additions & 0 deletions metaquantus/perturbation_tests/mpt.py
@@ -59,6 +59,7 @@ def __call__(
         explain_func: Optional[Callable],
         model_predict_kwargs: Optional[Dict],
         softmax: Optional[bool],
+        batch_size: Optional[int],
         device: Optional[str],
     ) -> Tuple[dict, np.ndarray, dict]:
         """
@@ -90,6 +91,8 @@ def __call__(
             Extra kwargs when running model.predict.
         softmax: bool
             Indicates if the softmax (or logits) are used.
+        batch_size: int
+            The batch size to run Quantus evaluation with.
         device: torch.device
             The device used, to enable GPUs.
@@ -184,6 +187,7 @@ def __call__(
                 model_predict_kwargs=model_predict_kwargs,
                 channel_first=channel_first,
                 softmax=softmax,
+                batch_size=batch_size,
                 device=device,
             )
2 changes: 1 addition & 1 deletion requirements_test.txt
@@ -2,6 +2,6 @@
 coverage>=7.2.3
 black==22.10.0
 numpy>=1.20.3
-pytest>=6.2.5
+pytest<=7.4.4 # pytest-lazyfixture issue https://github.com/annahedstroem/MetaQuantus/actions/runs/8449892962/job/23144947207
 pytest-cov>=3.0.0
 pytest-lazy-fixture>=0.6.3
1 change: 1 addition & 0 deletions tests/test_benchmarking.py
@@ -86,6 +86,7 @@ def test_benchmarking_mnist(
         keep_results=True,
         channel_first=True,
         softmax=False,
+        batch_size=64,
         device=device,
         path=os.getcwd()+"tests/assets/results/",
         save=False,
1 change: 1 addition & 0 deletions tests/test_meta_evaluation.py
@@ -100,6 +100,7 @@ def test_meta_evaluation_mnist(
         s_batch=s_batch,
         channel_first=True,
         softmax=False,
+        batch_size=50,
         device=device,
         score_direction=score_direction,
     )
