Skip to content

Commit

Permalink
feat: Compute the optimal threshold as average between the max good a…
Browse files Browse the repository at this point in the history
…nd min bad score when the F1 is 1 (#23)

* refactor: Compute the optimal threshold as average between the max good and min bad score when the F1 is 1

* build: Upgrade version, update changelog
  • Loading branch information
lorenzomammana authored Apr 22, 2024
1 parent 04e2db7 commit 5fa796d
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 8 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [v0.7.0+obx.1.3.1]

### Updated

- Compute the optimal threshold as average between the max good and min bad score when the F1 is 1

## [v0.7.0+obx.1.3.0]

### Updated
Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
# SPDX-License-Identifier: Apache-2.0

anomalib_version = "0.7.0"
custom_orobix_version = "1.3.0"
custom_orobix_version = "1.3.1"

__version__ = f"{anomalib_version}+obx.{custom_orobix_version}"
20 changes: 17 additions & 3 deletions src/anomalib/utils/metrics/anomaly_score_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,31 @@ def compute(self) -> Tensor:
Value of the F1 score at the optimal threshold.
"""
current_targets = torch.concat(self.target)
current_preds = torch.concat(self.preds)

epsilon = 1e-3

if len(current_targets.unique()) == 1:
if current_targets.max() == 0:
self.value = torch.concat(self.preds).max() + epsilon
self.value = torch.concat(current_preds).max() + epsilon
else:
self.value = torch.concat(self.preds).min()
self.value = torch.concat(current_preds).min() - epsilon
else:
precision, recall, thresholds = super().compute()
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
self.value = thresholds[torch.argmax(f1_score)]
optimal_f1_score = torch.max(f1_score)

if thresholds.nelement() == 1:
# Particular case when f1 score is 1 and the threshold is unique
self.value = thresholds
else:
if optimal_f1_score == 1:
# If there is a good boundary between good and bads we pick the average of the highest good
# and lowest bad
max_good_score = current_preds[torch.where(current_targets == 0)].max()
min_bad_score = current_preds[torch.where(current_targets == 1)].min()
self.value = (max_good_score + min_bad_score) / 2
else:
self.value = thresholds[torch.argmax(f1_score)]

return self.value
18 changes: 14 additions & 4 deletions src/anomalib/utils/metrics/optimal_f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,36 @@ def compute(self) -> Tensor:
recall: torch.Tensor
thresholds: torch.Tensor
current_targets = torch.concat(self.precision_recall_curve.target)
current_preds = torch.concat(self.precision_recall_curve.preds)

epsilon = 1e-3
if len(current_targets.unique()) == 1:
optimal_f1_score = torch.tensor(1.0)

if current_targets.max() == 0:
self.threshold = torch.concat(self.precision_recall_curve.preds).max() + epsilon
self.threshold = current_preds.max() + epsilon
else:
self.threshold = torch.concat(self.precision_recall_curve.preds).min()
self.threshold = current_preds.min() - epsilon

return optimal_f1_score
else:
precision, recall, thresholds = self.precision_recall_curve.compute()
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
optimal_f1_score = torch.max(f1_score)

if thresholds.nelement() == 1:
# Particular case when f1 score is 1 and the threshold is unique
self.threshold = thresholds
else:
self.threshold = thresholds[torch.argmax(f1_score)]
optimal_f1_score = torch.max(f1_score)
if optimal_f1_score == 1:
# If there is a good boundary between good and bads we pick the average of the highest good
# and lowest bad
max_good_score = current_preds[torch.where(current_targets == 0)].max()
min_bad_score = current_preds[torch.where(current_targets == 1)].min()
self.threshold = (max_good_score + min_bad_score) / 2
else:
self.threshold = thresholds[torch.argmax(f1_score)]

return optimal_f1_score

def reset(self) -> None:
Expand Down

0 comments on commit 5fa796d

Please sign in to comment.