From 05dc4254e2ddbc5296af922b0d78e62078202b08 Mon Sep 17 00:00:00 2001
From: George Batchkala
Date: Thu, 22 Jun 2023 16:31:15 +0100
Subject: [PATCH 1/2] find indices of maximum class scores instead of sorting (descending) and taking the 0th element

---
 dsmil.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dsmil.py b/dsmil.py
index 47c4bdc..c563081 100644
--- a/dsmil.py
+++ b/dsmil.py
@@ -50,8 +50,8 @@ def forward(self, feats, c): # N x K, N x C
         Q = self.q(feats).view(feats.shape[0], -1) # N x Q, unsorted
 
         # handle multiple classes without for loop
-        _, m_indices = torch.sort(c, 0, descending=True) # sort class scores along the instance dimension, m_indices in shape N x C
-        m_feats = torch.index_select(feats, dim=0, index=m_indices[0, :]) # select critical instances, m_feats in shape C x K
+        _, m_indices = torch.max(c, dim=0) # find indices of the maximum class scores along the instance dimension, m_indices in shape C
+        m_feats = torch.index_select(feats, dim=0, index=m_indices) # select critical instances, m_feats in shape C x K
         q_max = self.q(m_feats) # compute queries of critical instances, q_max in shape C x Q
         A = torch.mm(Q, q_max.transpose(0, 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores
         A = F.softmax( A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0) # normalize attention scores, A in shape N x C,

From 0da1493973da811e84bea0aea7e2c03cff35b8a5 Mon Sep 17 00:00:00 2001
From: George Batchkala
Date: Thu, 22 Jun 2023 16:32:30 +0100
Subject: [PATCH 2/2] simplify the selection: slice instead of torch.index_select()

---
 dsmil.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dsmil.py b/dsmil.py
index c563081..eacca65 100644
--- a/dsmil.py
+++ b/dsmil.py
@@ -51,7 +51,7 @@ def forward(self, feats, c): # N x K, N x C
 
         # handle multiple classes without for loop
         _, m_indices = torch.max(c, dim=0) # find indices of the maximum class scores along the instance dimension, m_indices in shape C
-        m_feats = torch.index_select(feats, dim=0, index=m_indices) # select critical instances, m_feats in shape C x K
+        m_feats = feats[m_indices, :] # select critical instances, m_feats in shape C x K
         q_max = self.q(m_feats) # compute queries of critical instances, q_max in shape C x Q
         A = torch.mm(Q, q_max.transpose(0, 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores
         A = F.softmax( A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0) # normalize attention scores, A in shape N x C,
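
Note (not part of the patches above): a minimal sketch checking that the new selection reproduces the old behaviour, assuming c is an N x C tensor of per-instance class scores and feats is the N x K instance feature matrix, as in dsmil.py. The tensor sizes below are illustrative. Tie-breaking aside (torch.max returns the first maximal index; exact ties are effectively impossible for float scores), both variants pick the top-scoring instance per class, but torch.max avoids the O(N log N) full sort when only the argmax is needed.

    import torch

    N, K, C = 8, 16, 2          # instances, feature dim, classes (illustrative)
    feats = torch.randn(N, K)   # N x K instance features
    c = torch.randn(N, C)       # N x C instance class scores

    # old selection: full descending sort, then take the 0th row of indices
    _, sorted_idx = torch.sort(c, 0, descending=True)                       # N x C
    m_feats_old = torch.index_select(feats, dim=0, index=sorted_idx[0, :])  # C x K

    # new selection: argmax along the instance dimension, then a plain slice
    _, m_indices = torch.max(c, dim=0)  # shape C, top-scoring instance per class
    m_feats_new = feats[m_indices, :]   # C x K

    assert torch.equal(m_feats_old, m_feats_new)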