"""sklearn_metrics_mask.py

Score example pairs of binary masks pixel by pixel with scikit-learn
classification metrics (accuracy, precision, recall, F1) and print the
results for each pair.
"""
from sklearn.metrics import (
accuracy_score,
f1_score,
precision_score,
recall_score,
mean_absolute_error,
)
import numpy as np
# Example (description, ground-truth mask, predicted mask) triples. Every
# mask is a 3x3 grid of 0/1 values, written below as plain nested lists and
# converted to NumPy arrays by the surrounding comprehension.
test_cases = [
    (label, np.array(truth_rows), np.array(pred_rows))
    for label, truth_rows, pred_rows in (
        (
            "完全重疊 (Complete Overlap)",
            [[1, 1, 0], [1, 1, 0], [0, 0, 0]],
            [[1, 1, 0], [1, 1, 0], [0, 0, 0]],
        ),
        (
            "部分重疊 (Partial Overlap)",
            [[1, 1, 0], [1, 1, 0], [0, 0, 0]],
            [[0, 1, 1], [1, 0, 0], [0, 0, 1]],
        ),
        (
            "不重疊 (No Overlap)",
            [[1, 1, 0], [1, 1, 0], [0, 0, 0]],
            [[0, 0, 1], [0, 0, 1], [1, 1, 0]],
        ),
        # NOTE(review): despite the label, these two masks share pixel (1, 1),
        # so this case is a one-pixel overlap rather than pure edge contact.
        (
            "邊界接觸 (Touching at Edges)",
            [[1, 1, 0], [1, 1, 0], [0, 0, 0]],
            [[0, 0, 0], [0, 1, 1], [0, 1, 1]],
        ),
        (
            "小遮罩在大遮罩內 (Small Mask Inside Large Mask)",
            [[0, 0, 0], [0, 1, 0], [0, 0, 0]],
            [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
        ),
        # NOTE(review): these two checkerboards are exact complements and share
        # no foreground pixel; "overlap" here means interleaving, not
        # intersection.
        (
            "交錯重疊 (Interleaved Overlap)",
            [[1, 0, 1], [0, 1, 0], [1, 0, 1]],
            [[0, 1, 0], [1, 0, 1], [0, 1, 0]],
        ),
        (
            "不同形狀 (Different Shapes)",
            [[1, 1, 0], [1, 1, 0], [0, 0, 0]],
            [[1, 0, 0], [1, 0, 0], [1, 1, 1]],
        ),
        (
            "相似形狀但位置偏移 (Similar Shapes but Offset)",
            [[0, 1, 1], [0, 1, 1], [0, 0, 0]],
            [[1, 1, 0], [1, 1, 0], [0, 0, 0]],
        ),
        (
            "大面積交疊 (Large Area Overlap)",
            [[1, 1, 1], [1, 1, 1], [0, 0, 0]],
            [[1, 1, 0], [1, 1, 1], [1, 0, 0]],
        ),
        (
            "一個遮罩全為零 (One Mask All Zero)",
            [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
            [[1, 1, 1], [1, 0, 0], [0, 0, 1]],
        ),
    )
]
def _mask_metrics(true_mask, pred_mask):
    """Return pixel-wise (accuracy, precision, recall, f1) for two masks.

    Both masks are flattened and compared pixel by pixel, with the foreground
    treated as the positive class (``average="binary"``). A metric whose
    denominator is zero (e.g. recall when ``true_mask`` has no foreground
    pixels) is reported as 0 because ``zero_division=0`` is passed.

    Args:
        true_mask: Ground-truth mask (array-like). Any non-zero value counts
            as foreground.
        pred_mask: Predicted mask with the same shape as ``true_mask``.

    Returns:
        A tuple of four floats: ``(accuracy, precision, recall, f1)``.

    Raises:
        ValueError: If the two masks do not have the same shape.
    """
    true_arr = np.asarray(true_mask)
    pred_arr = np.asarray(pred_mask)
    # Flattening masks of different shapes (e.g. 3x3 vs 9x1) would silently
    # pair unrelated pixels, so require matching shapes up front.
    if true_arr.shape != pred_arr.shape:
        raise ValueError(
            f"mask shapes differ: {true_arr.shape} != {pred_arr.shape}"
        )
    # Map every non-zero value to 1. With average="binary", sklearn scores
    # pos_label=1, so masks stored as e.g. 0/255 would otherwise raise a
    # ValueError or be scored against a label that never occurs.
    true_flat = (true_arr != 0).astype(int).ravel()
    pred_flat = (pred_arr != 0).astype(int).ravel()
    accuracy = accuracy_score(true_flat, pred_flat)
    precision = precision_score(
        true_flat, pred_flat, zero_division=0, average="binary"
    )
    recall = recall_score(true_flat, pred_flat, zero_division=0, average="binary")
    f1 = f1_score(true_flat, pred_flat, zero_division=0, average="binary")
    return accuracy, precision, recall, f1


# Print the four pixel-wise metrics for every example pair.
for description, true_mask, pred_mask in test_cases:
    accuracy, precision, recall, f1 = _mask_metrics(true_mask, pred_mask)
    print(f"{description}:")
    print(f" Accuracy: {accuracy:.2f}")
    print(f" Precision: {precision:.2f}")
    print(f" Recall: {recall:.2f}")
    print(f" F1-Score: {f1:.2f}")
    print()