Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
shaunyuan22 authored Sep 18, 2023
1 parent 1a2f023 commit f1c7a8c
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions mmdet/models/roi_heads/feature_imitation_roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(self,
num_save_feats=300,
enc_output_dim=512,
proj_output_dim=128,
sim_mode='inner',
temperature=0.07,
ins_quality_assess_cfg=dict(
cls_score=0.00,
Expand Down Expand Up @@ -59,7 +58,6 @@ def __init__(self,
assert self.num_con_queue >= con_sampler_cfg['num']
self.con_sampler_cfg = con_sampler_cfg
self.con_sample_num = self.con_sampler_cfg['num']
self.sim_mode = sim_mode # unused now
self.temperature = temperature
self.iq_cls_score = ins_quality_assess_cfg['cls_score']
self.hq_score = ins_quality_assess_cfg['hq_score']
Expand All @@ -70,8 +68,6 @@ def __init__(self,
self._mkdir(con_queue_dir, num_gpus)
self.con_queue_dir = con_queue_dir
self.num_classes = num_classes
self.iq_score_json_pth = os.path.join(
self.con_queue_dir, 'state.json')
if aug_roi_extractor is None:
aug_roi_extractor = dict(
type='SingleRoIExtractor',
Expand Down Expand Up @@ -396,23 +392,38 @@ def _bbox_forward_train(self, x, assign_results, sampling_results,
roi_labels = torch.cat(
[proposal_labels[i], hq_labels, aug_gt_labels], dim=-1)
assert roi_labels.size(0) == con_roi_feats.size(0)
# sample_inds = (num_gts, self.con_sample_num)
# pos_signs = (num_gts, self.con_sample_num)
# for dense ground-truth situation, only a part of gt will be processed,
# which resembles the way of gt being handled in bbox_sampler
num_actual_gts = sampling_results[i].pos_is_gt.sum()
pos_assigned_gt_inds = sampling_results[i].pos_assigned_gt_inds
pos_is_gt = sampling_results[i].pos_is_gt.bool()
pos_assigned_actual_gt_inds = pos_assigned_gt_inds[pos_is_gt]
iq_scores = iq_scores[pos_assigned_actual_gt_inds]
iq_signs = iq_signs[pos_assigned_actual_gt_inds]
ex_pos_nums = ex_pos_nums[pos_assigned_actual_gt_inds]
labels = gt_labels[i][pos_assigned_actual_gt_inds]
sample_inds, pos_signs = self._sample(
iq_signs, ex_pos_nums, gt_labels[i], roi_labels, is_hq, aug_gt_ind, cur_sample_num)
iq_signs, ex_pos_nums, labels, roi_labels, is_hq, aug_gt_ind, cur_sample_num)
# anchor_feature: (num_gts, 256, 7, 7)
# contrast_feature: (num_gts, self.con_sample_num, 256, 7, 7)
anchor_feature = con_roi_feats[sampling_results[i].pos_inds[:num_gts]]
anchor_feature = con_roi_feats[:num_actual_gts]
contrast_feature = con_roi_feats[sample_inds]
assert anchor_feature.size(0) == contrast_feature.size(0)
iq_loss_weights = torch.ones_like(iq_scores)
for j, weight in enumerate(self.iq_loss_weights):
iq_loss_weights[iq_signs == j] *= weight
cur_signs = torch.nonzero(iq_signs == j).view(-1)
iq_loss_weights[cur_signs] = weight * iq_loss_weights[cur_signs]
loss = self.contrast_forward(anchor_feature, contrast_feature,
pos_signs, iq_loss_weights)
contrast_loss = self.contrast_loss_weights * loss
con_losses = con_losses + contrast_loss

# save high-quality features at last
# for dense ground-truth situation
pro_counts = pro_counts[pos_assigned_actual_gt_inds]
hq_inds = torch.nonzero((iq_scores >= self.hq_score) & \
(pro_counts >= self.hq_pro_counts_thr),
as_tuple=False).view(-1) # (N, )
# high-quality proposals: high instance quality scores and
# sufficient numbers of proposals
if len(hq_inds) > 0:
Expand Down Expand Up @@ -441,9 +452,6 @@ def contrast_forward(self, anchor_feature, contrast_feature,
(num_gts, self.con_sample_num)
loss_weights: loss weights of each gt (num_gts, )
"""
# anchor_feature = anchor_feature.flatten(start_dim=1)
# contrast_feature = contrast_feature.flatten(start_dim=2)

anchor_feature = anchor_feature.view(anchor_feature.size()[:-2]) # [num_gts, 256]
contrast_feature = contrast_feature.view(contrast_feature.size()[:-2]) # [num_gts, self.con_sample_num, 256]
for fc in self.fc_enc:
Expand Down

0 comments on commit f1c7a8c

Please sign in to comment.