From 94fe9f1f5683c525b6a3c0f69931f090b8d9c764 Mon Sep 17 00:00:00 2001 From: ZitengXue <2507456075@qq.com> Date: Tue, 10 Oct 2023 23:34:51 +0800 Subject: [PATCH] encoder --- projects/DETR3D/detr3d/detr3d.py | 27 ++++----- projects/DETR3D/detr3d/detr3d_head.py | 3 +- projects/DETR3D/detr3d/detr3d_transformer.py | 3 +- .../transformer/grounding_dino_layers.py | 60 +++++++++---------- 4 files changed, 45 insertions(+), 48 deletions(-) diff --git a/projects/DETR3D/detr3d/detr3d.py b/projects/DETR3D/detr3d/detr3d.py index c57af66..a7e1e5a 100755 --- a/projects/DETR3D/detr3d/detr3d.py +++ b/projects/DETR3D/detr3d/detr3d.py @@ -72,7 +72,7 @@ def __init__(self, self._special_tokens = '. ' self.positional_encoding = SinePositionalEncoding( **positional_encoding_single) - # self.encoder = GroundingDinoTransformerEncoder(**encoder) + self.encoder = GroundingDinoTransformerEncoder(**encoder) # self.level_embed = nn.Parameter( # torch.Tensor(4, 256)) nn.init.constant_(self.text_feat_map.bias.data, 0) @@ -180,20 +180,19 @@ def loss(self, batch_inputs_dict: Dict[List, Tensor], positive_maps.append(positive_map) text_dict = self.language_model(new_text_prompts) - # for key, value in text_dict.items(): - # text_dict[key] = torch.cat([value] * 6, dim=0) + for key, value in text_dict.items(): + text_dict[key] = torch.cat([value] * 6, dim=0) if self.text_feat_map is not None: text_dict['embedded'] = self.text_feat_map(text_dict['embedded']) - memory_text=text_dict['embedded'] ##################################################################### - # encoder_inputs_dict = self.pre_transformer( - # img_feats, batch_data_samples) - - # memory,memory_text = self.forward_encoder( - # **encoder_inputs_dict, text_dict=text_dict)#text和图像特征融合 - # del img_feats - # img_feats = self.restore_img_feats(memory, encoder_inputs_dict['spatial_shapes'], encoder_inputs_dict['level_start_index']) - outs = self.pts_bbox_head(img_feats, batch_input_metas,memory_text, **kwargs)#text_dict + encoder_inputs_dict = self.pre_transformer( + img_feats, batch_data_samples) + + memory = self.forward_encoder( + **encoder_inputs_dict, text_dict=text_dict) + del img_feats + img_feats = self.restore_img_feats(memory, encoder_inputs_dict['spatial_shapes'], encoder_inputs_dict['level_start_index']) + outs = self.pts_bbox_head(img_feats, batch_input_metas, **kwargs)#text_dict loss_inputs = [batch_gt_instances_3d, outs] losses_pts = self.pts_bbox_head.loss_by_feat(*loss_inputs) @@ -491,7 +490,7 @@ def forward_encoder(self, feat: Tensor, feat_mask: Tensor, level_start_index: Tensor, valid_ratios: Tensor, text_dict: Dict) -> Dict: text_token_mask = text_dict['text_token_mask'] - memory, memory_text = self.encoder( + memory, _ = self.encoder( query=feat, # query_pos=feat_pos, key_padding_mask=feat_mask, # for self_attn @@ -510,7 +509,7 @@ def forward_encoder(self, feat: Tensor, feat_mask: Tensor, # memory_text=memory_text, # text_token_mask=text_token_mask) # return encoder_outputs_dict - return memory,memory_text + return memory @staticmethod def get_valid_ratio(mask: Tensor) -> Tensor: """Get the valid radios of feature map in a level. diff --git a/projects/DETR3D/detr3d/detr3d_head.py b/projects/DETR3D/detr3d/detr3d_head.py index ff18334..d4143ad 100755 --- a/projects/DETR3D/detr3d/detr3d_head.py +++ b/projects/DETR3D/detr3d/detr3d_head.py @@ -112,7 +112,7 @@ def init_weights(self): for m in self.cls_branches: nn.init.constant_(m[-1].bias, bias_init) - def forward(self, mlvl_feats: List[Tensor], img_metas: List[Dict],memory_text:Tensor, + def forward(self, mlvl_feats: List[Tensor], img_metas: List[Dict], **kwargs) -> Dict[str, Tensor]: """Forward function. @@ -135,7 +135,6 @@ def forward(self, mlvl_feats: List[Tensor], img_metas: List[Dict],memory_text:Te query_embeds, reg_branches=self.reg_branches if self.with_box_refine else None, img_metas=img_metas, - memory_text=memory_text, **kwargs) hs = hs.permute(0, 2, 1, 3) outputs_classes = [] diff --git a/projects/DETR3D/detr3d/detr3d_transformer.py b/projects/DETR3D/detr3d/detr3d_transformer.py index f5ed263..dfe0765 100755 --- a/projects/DETR3D/detr3d/detr3d_transformer.py +++ b/projects/DETR3D/detr3d/detr3d_transformer.py @@ -75,7 +75,7 @@ def init_weights(self): m.init_weight() xavier_init(self.reference_points, distribution='uniform', bias=0.) - def forward(self, mlvl_feats, query_embed,memory_text, reg_branches=None, **kwargs): + def forward(self, mlvl_feats, query_embed, reg_branches=None, **kwargs): """Forward function for `Detr3DTransformer`. Args: mlvl_feats (list(Tensor)): Input queries from @@ -120,7 +120,6 @@ def forward(self, mlvl_feats, query_embed,memory_text, reg_branches=None, **kwar query_pos=query_pos, reference_points=reference_points, reg_branches=reg_branches, - memory_text=memory_text, **kwargs) inter_references_out = inter_references diff --git a/projects/DETR3D/layers/transformer/grounding_dino_layers.py b/projects/DETR3D/layers/transformer/grounding_dino_layers.py index 2182d33..5116088 100755 --- a/projects/DETR3D/layers/transformer/grounding_dino_layers.py +++ b/projects/DETR3D/layers/transformer/grounding_dino_layers.py @@ -146,10 +146,10 @@ def _init_layers(self) -> None: # DeformableDetrTransformerEncoderLayer(**self.layer_cfg) # for _ in range(self.num_layers) # ]) - self.text_layers = ModuleList([ - DetrTransformerEncoderLayer(**self.text_layer_cfg) - for _ in range(self.num_layers) - ]) + # self.text_layers = ModuleList([ + # DetrTransformerEncoderLayer(**self.text_layer_cfg) + # for _ in range(self.num_layers) + # ]) self.fusion_layers = ModuleList([ SingleScaleBiAttentionBlock(**self.fusion_layer_cfg) for _ in range(self.num_layers) @@ -208,21 +208,21 @@ def forward(self, output = query # reference_points = self.get_encoder_reference_points( # spatial_shapes, valid_ratios, device=query.device) - if self.text_layers: - # generate pos_text - bs, n_text, _ = memory_text.shape - if pos_text is None and position_ids is None: - pos_text = ( - torch.arange(n_text, - device=memory_text.device).float().unsqueeze( - 0).unsqueeze(-1).repeat(bs, 1, 1)) - pos_text = get_text_sine_pos_embed( - pos_text, num_pos_feats=256, exchange_xy=False) - if position_ids is not None: - pos_text = get_text_sine_pos_embed( - position_ids[..., None], - num_pos_feats=256, - exchange_xy=False) + # if self.text_layers: + # # generate pos_text + # bs, n_text, _ = memory_text.shape + # if pos_text is None and position_ids is None: + # pos_text = ( + # torch.arange(n_text, + # device=memory_text.device).float().unsqueeze( + # 0).unsqueeze(-1).repeat(bs, 1, 1)) + # pos_text = get_text_sine_pos_embed( + # pos_text, num_pos_feats=256, exchange_xy=False) + # if position_ids is not None: + # pos_text = get_text_sine_pos_embed( + # position_ids[..., None], + # num_pos_feats=256, + # exchange_xy=False) # main process # for layer_id, layer in enumerate(self.layers): @@ -234,16 +234,16 @@ def forward(self, attention_mask_v=key_padding_mask, attention_mask_l=text_attention_mask, ) - if self.text_layers: - text_num_heads = self.text_layers[ - layer_id].self_attn_cfg.num_heads - memory_text = self.text_layers[layer_id]( - query=memory_text[0], - query_pos=(pos_text if pos_text is not None else None), - attn_mask=~text_self_attention_masks.repeat( - text_num_heads, 1, 1), # note we use ~ for mask here - key_padding_mask=None, - ) + # if self.text_layers: + # text_num_heads = self.text_layers[ + # layer_id].self_attn_cfg.num_heads + # memory_text = self.text_layers[layer_id]( + # query=memory_text[0], + # query_pos=(pos_text if pos_text is not None else None), + # attn_mask=~text_self_attention_masks.repeat( + # text_num_heads, 1, 1), # note we use ~ for mask here + # key_padding_mask=None, + # ) # output = layer( # query=output, # query_pos=query_pos, @@ -268,4 +268,4 @@ def _init_layers(self) -> None: f'{self._get_name()}') self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims, self.embed_dims, 2) - self.norm = nn.LayerNorm(self.embed_dims) + self.norm = nn.LayerNorm(self.embed_dims) \ No newline at end of file