loss.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 6 19:48:54 2019

@author: viswanatha
"""
import torch
import torch.nn as nn

from utils import cxcy_to_xy, xy_to_cxcy, cxcy_to_gcxgcy, find_jaccard_overlap

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class MultiBoxLoss(nn.Module):
    """
    The MultiBox loss, a loss function for object detection.

    This is a combination of:
    (1) a localization loss for the predicted locations of the boxes, and
    (2) a confidence loss for the predicted class scores.
    """

    def __init__(self, priors_cxcy, threshold=0.5, neg_pos_ratio=3, alpha=1.0):
        super(MultiBoxLoss, self).__init__()
        self.priors_cxcy = priors_cxcy
        self.priors_xy = cxcy_to_xy(priors_cxcy)
        self.threshold = threshold
        self.neg_pos_ratio = neg_pos_ratio
        self.alpha = alpha

        # Smooth L1 for box regression, as in the SSD paper
        self.smooth_l1 = nn.SmoothL1Loss()
        # Keep per-prior losses un-reduced; hard negative mining needs them individually
        self.cross_entropy = nn.CrossEntropyLoss(reduction="none")

    def forward(self, predicted_locs, predicted_scores, boxes, labels):
        """
        Forward propagation.

        :param predicted_locs: predicted locations/boxes w.r.t the 8732 prior boxes, a tensor of dimensions (N, 8732, 4)
        :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, 8732, n_classes)
        :param boxes: true object bounding boxes in boundary coordinates, a list of N tensors
        :param labels: true object labels, a list of N tensors
        :return: multibox loss, a scalar
        """
        batch_size = predicted_locs.size(0)
        n_priors = self.priors_cxcy.size(0)
        n_classes = predicted_scores.size(2)

        assert n_priors == predicted_locs.size(1) == predicted_scores.size(1)

        true_locs = torch.zeros((batch_size, n_priors, 4), dtype=torch.float).to(device)  # (N, 8732, 4)
        true_classes = torch.zeros((batch_size, n_priors), dtype=torch.long).to(device)  # (N, 8732)

        # For each image
        for i in range(batch_size):
            n_objects = boxes[i].size(0)

            overlap = find_jaccard_overlap(boxes[i], self.priors_xy)  # (n_objects, 8732)

            # For each prior, find the object that has the maximum overlap
            overlap_for_each_prior, object_for_each_prior = overlap.max(dim=0)  # (8732)

            # We don't want a situation where an object is not represented in our positive (non-background) priors -
            # 1. An object might not be the best object for any prior, and is therefore not in object_for_each_prior.
            # 2. All priors with the object may be assigned as background based on the threshold (0.5).

            # To remedy this -
            # First, find the prior that has the maximum overlap for each object.
            _, prior_for_each_object = overlap.max(dim=1)  # (N_o)

            # Then, assign each object to its maximum-overlap prior. (This fixes 1.)
            object_for_each_prior[prior_for_each_object] = torch.LongTensor(range(n_objects)).to(device)

            # To ensure these priors qualify, artificially give them an overlap greater than 0.5. (This fixes 2.)
            overlap_for_each_prior[prior_for_each_object] = 1.0

            # Labels for each prior
            label_for_each_prior = labels[i][object_for_each_prior]  # (8732)
            # Set priors whose overlaps with objects are less than the threshold to be background (no object)
            label_for_each_prior[overlap_for_each_prior < self.threshold] = 0  # (8732)

            # Store
            true_classes[i] = label_for_each_prior

            # Encode center-size object coordinates into the form we regressed predicted boxes to
            true_locs[i] = cxcy_to_gcxgcy(xy_to_cxcy(boxes[i][object_for_each_prior]), self.priors_cxcy)  # (8732, 4)
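
        # Worked example (illustrative, not from the original file): with 2 objects
        # and 4 priors, suppose overlap = [[0.1, 0.4, 0.7, 0.2],
        #                                  [0.3, 0.2, 0.1, 0.6]]
        # overlap.max(dim=0) gives the best object per prior, [1, 0, 0, 1], with
        # overlaps [0.3, 0.4, 0.7, 0.6]; overlap.max(dim=1) gives the best prior
        # per object, [2, 3]. Priors 2 and 3 are then forced to objects 0 and 1
        # with overlap 1.0, so every ground-truth object keeps at least one
        # positive prior, while priors 0 and 1 (overlaps below 0.5) become background.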

        # Identify priors that are positive (object/non-background)
        positive_priors = true_classes != 0  # (N, 8732)

        # LOCALIZATION LOSS

        # Localization loss is computed only over positive (non-background) priors
        loc_loss = self.smooth_l1(predicted_locs[positive_priors], true_locs[positive_priors])  # (), scalar

        # Note: indexing with a boolean mask across multiple dimensions (N & 8732) flattens those dimensions
        # So, if predicted_locs has the shape (N, 8732, 4), predicted_locs[positive_priors] will have (total positives, 4)

        # CONFIDENCE LOSS

        # Confidence loss is computed over positive priors and the most difficult (hardest) negative priors in each image
        # That is, FOR EACH IMAGE,
        # we will take the hardest (neg_pos_ratio * n_positives) negative priors, i.e. those with the maximum loss
        # This is called Hard Negative Mining - it concentrates on the hardest negatives in each image, and also minimizes the pos/neg imbalance

        # Number of positive and hard-negative priors per image
        n_positives = positive_priors.sum(dim=1)  # (N)
        n_hard_negatives = self.neg_pos_ratio * n_positives  # (N)
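
        # For example, with the default neg_pos_ratio=3, an image with 10 positive
        # priors contributes its 30 highest-loss negative priors to the confidence
        # loss; the remaining ~8692 easy negatives are ignored.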

        # First, find the loss for all priors
        conf_loss_all = self.cross_entropy(predicted_scores.view(-1, n_classes), true_classes.view(-1))  # (N * 8732)
        conf_loss_all = conf_loss_all.view(batch_size, n_priors)  # (N, 8732)

        # We already know which priors are positive
        conf_loss_pos = conf_loss_all[positive_priors]  # (sum(n_positives))

        # Next, find which priors are hard-negative
        # To do this, sort ONLY negative priors in each image in order of decreasing loss and take the top n_hard_negatives
        conf_loss_neg = conf_loss_all.clone()  # (N, 8732)
        conf_loss_neg[positive_priors] = 0.0  # (N, 8732), positive priors are ignored (never in top n_hard_negatives)
        conf_loss_neg, _ = conf_loss_neg.sort(dim=1, descending=True)  # (N, 8732), sorted by decreasing hardness
        hardness_ranks = torch.LongTensor(range(n_priors)).unsqueeze(0).expand_as(conf_loss_neg).to(device)  # (N, 8732)
        hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1)  # (N, 8732)
        conf_loss_hard_neg = conf_loss_neg[hard_negatives]  # (sum(n_hard_negatives))
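
        # Illustration (assumed numbers): if an image has n_hard_negatives[i] = 2
        # and its sorted negative losses are [1.9, 1.2, 0.4, ...], the hardness-rank
        # row [0, 1, 2, ...] < 2 keeps exactly the two largest losses, 1.9 and 1.2.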

        # As in the paper, averaged over positive priors only, although computed over both positive and hard-negative priors
        conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()) / n_positives.sum().float()  # (), scalar

        # TOTAL LOSS
        return conf_loss + self.alpha * loc_loss
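

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative example, not part of the original
# file). It fakes 8732 priors and random network outputs purely to show the
# expected shapes and call convention; real priors come from the detector and
# real boxes/labels from the dataset.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n_priors, n_classes, batch_size = 8732, 21, 2

    # Fake priors in center-size (cx, cy, w, h) form, all in (0, 1]
    priors_cxcy = torch.rand(n_priors, 4).clamp(1e-3, 1.0).to(device)

    criterion = MultiBoxLoss(priors_cxcy=priors_cxcy)

    # Fake network outputs: box offsets and per-class scores for every prior
    predicted_locs = torch.randn(batch_size, n_priors, 4).to(device)
    predicted_scores = torch.randn(batch_size, n_priors, n_classes).to(device)

    # One ground-truth box per image, in boundary (x_min, y_min, x_max, y_max)
    # coordinates, with a non-background class label
    boxes = [torch.tensor([[0.1, 0.1, 0.5, 0.5]]).to(device) for _ in range(batch_size)]
    labels = [torch.tensor([1]).to(device) for _ in range(batch_size)]

    loss = criterion(predicted_locs, predicted_scores, boxes, labels)
    print("multibox loss:", loss.item())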