Commit

encoder
ZitengXue committed Oct 10, 2023
1 parent 6406279 commit 94fe9f1
Showing 4 changed files with 45 additions and 48 deletions.
27 changes: 13 additions & 14 deletions projects/DETR3D/detr3d/detr3d.py
@@ -72,7 +72,7 @@ def __init__(self,
self._special_tokens = '. '
self.positional_encoding = SinePositionalEncoding(
**positional_encoding_single)
# self.encoder = GroundingDinoTransformerEncoder(**encoder)
self.encoder = GroundingDinoTransformerEncoder(**encoder)
# self.level_embed = nn.Parameter(
# torch.Tensor(4, 256))
nn.init.constant_(self.text_feat_map.bias.data, 0)
@@ -180,20 +180,19 @@ def loss(self, batch_inputs_dict: Dict[List, Tensor],
positive_maps.append(positive_map)

text_dict = self.language_model(new_text_prompts)
# for key, value in text_dict.items():
# text_dict[key] = torch.cat([value] * 6, dim=0)
for key, value in text_dict.items():
text_dict[key] = torch.cat([value] * 6, dim=0)
if self.text_feat_map is not None:
text_dict['embedded'] = self.text_feat_map(text_dict['embedded'])
memory_text=text_dict['embedded']
#####################################################################
# encoder_inputs_dict = self.pre_transformer(
# img_feats, batch_data_samples)

# memory,memory_text = self.forward_encoder(
# **encoder_inputs_dict, text_dict=text_dict)  # fuse text and image features
# del img_feats
# img_feats = self.restore_img_feats(memory, encoder_inputs_dict['spatial_shapes'], encoder_inputs_dict['level_start_index'])
outs = self.pts_bbox_head(img_feats, batch_input_metas,memory_text, **kwargs)#text_dict
encoder_inputs_dict = self.pre_transformer(
img_feats, batch_data_samples)

memory = self.forward_encoder(
**encoder_inputs_dict, text_dict=text_dict)
del img_feats
img_feats = self.restore_img_feats(memory, encoder_inputs_dict['spatial_shapes'], encoder_inputs_dict['level_start_index'])
outs = self.pts_bbox_head(img_feats, batch_input_metas, **kwargs)#text_dict
loss_inputs = [batch_gt_instances_3d, outs]
losses_pts = self.pts_bbox_head.loss_by_feat(*loss_inputs)

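As context for the change above: the re-enabled loop repeats every tensor in text_dict along the batch dimension, presumably so each camera view of a sample gets its own copy of the text features (the factor 6 suggests six surround-view images per sample, e.g. nuScenes, but that is an assumption, not something stated in the diff). A minimal sketch of the effect:

import torch

num_views = 6                                       # assumed: one copy per camera view
text_dict = {'embedded': torch.randn(1, 12, 256)}   # (batch, text tokens, channels)
for key, value in text_dict.items():
    text_dict[key] = torch.cat([value] * num_views, dim=0)
print(text_dict['embedded'].shape)                  # torch.Size([6, 12, 256])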
@@ -491,7 +490,7 @@ def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
level_start_index: Tensor, valid_ratios: Tensor,
text_dict: Dict) -> Dict:
text_token_mask = text_dict['text_token_mask']
memory, memory_text = self.encoder(
memory, _ = self.encoder(
query=feat,
# query_pos=feat_pos,
key_padding_mask=feat_mask, # for self_attn
@@ -510,7 +509,7 @@ def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
# memory_text=memory_text,
# text_token_mask=text_token_mask)
# return encoder_outputs_dict
return memory,memory_text
return memory
@staticmethod
def get_valid_ratio(mask: Tensor) -> Tensor:
"""Get the valid radios of feature map in a level.
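The new loss() path above appears to flatten the multi-level image features (pre_transformer), fuse them with the text features in the encoder (forward_encoder), and then rebuild multi-level feature maps from the flattened memory before calling the 3D bbox head. restore_img_feats itself is not part of this diff, so the helper below is only a sketch of what such a function typically does in the Deformable-DETR layout, assuming memory has shape (bs, sum(H_l * W_l), C):

import torch

def restore_img_feats_sketch(memory, spatial_shapes, level_start_index):
    # Illustration only: split the flattened encoder output back into a list
    # of per-level feature maps shaped (bs, C, H_l, W_l).
    bs, _, channels = memory.shape
    mlvl_feats = []
    for level, (h, w) in enumerate(spatial_shapes.tolist()):
        start = int(level_start_index[level])
        feat = memory[:, start:start + h * w, :]          # (bs, H_l * W_l, C)
        mlvl_feats.append(feat.permute(0, 2, 1).reshape(bs, channels, h, w))
    return mlvl_feats

# Toy check with two feature levels.
bs, c = 2, 256
spatial_shapes = torch.tensor([[16, 16], [8, 8]])
level_start_index = torch.tensor([0, 16 * 16])
memory = torch.randn(bs, 16 * 16 + 8 * 8, c)
feats = restore_img_feats_sketch(memory, spatial_shapes, level_start_index)
assert [tuple(f.shape) for f in feats] == [(bs, c, 16, 16), (bs, c, 8, 8)]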
3 changes: 1 addition & 2 deletions projects/DETR3D/detr3d/detr3d_head.py
@@ -112,7 +112,7 @@ def init_weights(self):
for m in self.cls_branches:
nn.init.constant_(m[-1].bias, bias_init)

def forward(self, mlvl_feats: List[Tensor], img_metas: List[Dict],memory_text:Tensor,
def forward(self, mlvl_feats: List[Tensor], img_metas: List[Dict],
**kwargs) -> Dict[str, Tensor]:
"""Forward function.
@@ -135,7 +135,6 @@ def forward(self, mlvl_feats: List[Tensor], img_metas: List[Dict],memory_text:Tensor,
query_embeds,
reg_branches=self.reg_branches if self.with_box_refine else None,
img_metas=img_metas,
memory_text=memory_text,
**kwargs)
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
3 changes: 1 addition & 2 deletions projects/DETR3D/detr3d/detr3d_transformer.py
@@ -75,7 +75,7 @@ def init_weights(self):
m.init_weight()
xavier_init(self.reference_points, distribution='uniform', bias=0.)

def forward(self, mlvl_feats, query_embed,memory_text, reg_branches=None, **kwargs):
def forward(self, mlvl_feats, query_embed, reg_branches=None, **kwargs):
"""Forward function for `Detr3DTransformer`.
Args:
mlvl_feats (list(Tensor)): Input queries from
@@ -120,7 +120,6 @@ def forward(self, mlvl_feats, query_embed,memory_text, reg_branches=None, **kwargs):
query_pos=query_pos,
reference_points=reference_points,
reg_branches=reg_branches,
memory_text=memory_text,
**kwargs)

inter_references_out = inter_references
60 changes: 30 additions & 30 deletions projects/DETR3D/layers/transformer/grounding_dino_layers.py
@@ -146,10 +146,10 @@ def _init_layers(self) -> None:
# DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
# for _ in range(self.num_layers)
# ])
self.text_layers = ModuleList([
DetrTransformerEncoderLayer(**self.text_layer_cfg)
for _ in range(self.num_layers)
])
# self.text_layers = ModuleList([
# DetrTransformerEncoderLayer(**self.text_layer_cfg)
# for _ in range(self.num_layers)
# ])
self.fusion_layers = ModuleList([
SingleScaleBiAttentionBlock(**self.fusion_layer_cfg)
for _ in range(self.num_layers)
@@ -208,21 +208,21 @@ def forward(self,
output = query
# reference_points = self.get_encoder_reference_points(
# spatial_shapes, valid_ratios, device=query.device)
if self.text_layers:
# generate pos_text
bs, n_text, _ = memory_text.shape
if pos_text is None and position_ids is None:
pos_text = (
torch.arange(n_text,
device=memory_text.device).float().unsqueeze(
0).unsqueeze(-1).repeat(bs, 1, 1))
pos_text = get_text_sine_pos_embed(
pos_text, num_pos_feats=256, exchange_xy=False)
if position_ids is not None:
pos_text = get_text_sine_pos_embed(
position_ids[..., None],
num_pos_feats=256,
exchange_xy=False)
# if self.text_layers:
# # generate pos_text
# bs, n_text, _ = memory_text.shape
# if pos_text is None and position_ids is None:
# pos_text = (
# torch.arange(n_text,
# device=memory_text.device).float().unsqueeze(
# 0).unsqueeze(-1).repeat(bs, 1, 1))
# pos_text = get_text_sine_pos_embed(
# pos_text, num_pos_feats=256, exchange_xy=False)
# if position_ids is not None:
# pos_text = get_text_sine_pos_embed(
# position_ids[..., None],
# num_pos_feats=256,
# exchange_xy=False)

# main process
# for layer_id, layer in enumerate(self.layers):
@@ -234,16 +234,16 @@ def forward(self,
attention_mask_v=key_padding_mask,
attention_mask_l=text_attention_mask,
)
if self.text_layers:
text_num_heads = self.text_layers[
layer_id].self_attn_cfg.num_heads
memory_text = self.text_layers[layer_id](
query=memory_text[0],
query_pos=(pos_text if pos_text is not None else None),
attn_mask=~text_self_attention_masks.repeat(
text_num_heads, 1, 1), # note we use ~ for mask here
key_padding_mask=None,
)
# if self.text_layers:
# text_num_heads = self.text_layers[
# layer_id].self_attn_cfg.num_heads
# memory_text = self.text_layers[layer_id](
# query=memory_text[0],
# query_pos=(pos_text if pos_text is not None else None),
# attn_mask=~text_self_attention_masks.repeat(
# text_num_heads, 1, 1), # note we use ~ for mask here
# key_padding_mask=None,
# )
# output = layer(
# query=output,
# query_pos=query_pos,
@@ -268,4 +268,4 @@ def _init_layers(self) -> None:
f'{self._get_name()}')
self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims,
self.embed_dims, 2)
self.norm = nn.LayerNorm(self.embed_dims)
self.norm = nn.LayerNorm(self.embed_dims)
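With the text self-attention layers (and the text positional embedding that fed them) commented out, the encoder's per-layer work reduces to the image-text fusion blocks, so memory_text is only touched by those blocks. A rough, self-contained sketch of what such a fusion-only pass could look like is below; ToyBiAttention is a made-up stand-in for mmdet's SingleScaleBiAttentionBlock (which is bidirectional and more involved), and all shapes are illustrative:

import torch
import torch.nn as nn

class ToyBiAttention(nn.Module):
    # Stand-in for SingleScaleBiAttentionBlock: a single cross-attention step
    # in which image tokens attend to text tokens.
    def __init__(self, dim=256, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, visual_feature, lang_feature):
        fused, _ = self.attn(visual_feature, lang_feature, lang_feature)
        return visual_feature + fused, lang_feature

# Fusion-only loop, roughly mirroring the encoder forward after this commit:
# each layer fuses image tokens with text tokens; the per-layer text
# self-attention (and the deformable image self-attention) stays disabled.
fusion_layers = nn.ModuleList([ToyBiAttention() for _ in range(6)])
output = torch.randn(2, 16 * 16 + 8 * 8, 256)   # flattened image tokens
memory_text = torch.randn(2, 12, 256)           # text token features
for layer in fusion_layers:
    output, memory_text = layer(output, memory_text)
memory = output  # what forward_encoder now returns to the caller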
