diff --git a/dsmil.py b/dsmil.py index 80ff71c..d30276a 100644 --- a/dsmil.py +++ b/dsmil.py @@ -49,8 +49,8 @@ def forward(self, feats, c): # N x K, N x C Q = self.q(feats).view(feats.shape[0], -1) # N x Q, unsorted # handle multiple classes without for loop - _, m_indices = torch.sort(c, 0, descending=True) # sort class scores along the instance dimension, m_indices in shape N x C - m_feats = torch.index_select(feats, dim=0, index=m_indices[0, :]) # select critical instances, m_feats in shape C x K + _, m_indices = torch.max(c, dim=0) # sort class scores along the instance dimension, m_indices in shape N x C + m_feats = feats[m_indices, :] # select critical instances, m_feats in shape C x K q_max = self.q(m_feats) # compute queries of critical instances, q_max in shape C x Q A = torch.mm(Q, q_max.transpose(0, 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores A = F.softmax( A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0) # normalize attention scores, A in shape N x C,