diff --git a/fastrcnn/modeling/clip_rcnn.py b/fastrcnn/modeling/clip_rcnn.py
index b22d19c..5f6502f 100644
--- a/fastrcnn/modeling/clip_rcnn.py
+++ b/fastrcnn/modeling/clip_rcnn.py
@@ -24,7 +24,14 @@ def build_clip_rcnn(clip_type='RN50'):
-    clip_rcnn = CLIP_RCNN(clip_type=clip_type)
+    pooler_res_dict = {
+        "RN50": 14,
+        "RN50x4": 18,
+        "RN50x16": 24,
+        "RN50x64": 28,
+    }
+    pooler_resolution = pooler_res_dict[clip_type]
+    clip_rcnn = CLIP_RCNN(clip_type=clip_type, pooler_resolution=pooler_resolution)
     clip_rcnn.eval()
     return clip_rcnn
@@ -37,12 +44,12 @@ class CLIP_RCNN(nn.Module):
     def __init__(
         self,
         clip_type,
-        softmax_t: float = 0.01,
-        pooler_resolution: int = 14,
+        pooler_resolution,
         pooler_scales: int = 16,
         sampling_ratio: int = 0,
         pooler_type: str = "ROIAlignV2",
         canonical_box_size: int = 224,
+        softmax_t: float = 0.01,
     ):
         super().__init__()
         self.register_buffer("pixel_mean_clip",
@@ -75,11 +82,10 @@ def forward_clip(self, image, boxes, text_prompt):
         features = self.clip_res_c4_backbone(imageList_clip.tensor)
         text_embed = self.get_text_embeddings(text_prompt)
         clip_scores = self.clip_res5_roi_heads(features, boxes, text_embed)
-        return clip_scores.cpu().tolist()
+        return clip_scores.cpu()
 
     def get_text_embeddings(self, vocabulary, prefix_prompt='a '):
-        if not isinstance(vocabulary, list):
-            vocabulary = [vocabulary]
+        vocabulary = vocabulary.split(',')
         texts = [prefix_prompt + x.lower().replace(':', ' ') for x in vocabulary]
         texts_aug = texts + ['background']
         emb = self.text_encoder(texts_aug).permute(1, 0)
@@ -110,7 +116,7 @@ def clip_res5_roi_heads(self, features, boxes, text_embed):
         region_features = F.normalize(region_features, p=2, dim=-1)
         similarity = ((1 / self.softmax_t) * region_features @ text_embed).softmax(dim=-1)
-        clip_scores = similarity[:,0]
+        clip_scores = similarity[:,:-1]
         return clip_scores
diff --git a/fastrcnn/modeling/text_encoder.py b/fastrcnn/modeling/text_encoder.py
index 9c52f81..ede4644 100644
--- a/fastrcnn/modeling/text_encoder.py
+++ b/fastrcnn/modeling/text_encoder.py
@@ -172,15 +172,15 @@ def build_text_encoder(pretrain=True, visual_type="RN50"):
         "visual_type": ["embed_dim", "context_length", "vocab_size", "transformer_width", "transformer_heads", "transformer_layers"],
         "RN50": [1024, 77, 49408, 512, 8, 12],
-        "ViT-B/32": [512, 77, 49408, 512, 8, 12],
+        "RN50x4": [640, 77, 49408, 640, 10, 12],
+        "RN50x16": [768, 77, 49408, 768, 12, 12],
+        "RN50x64": [1024, 77, 49408, 1024, 16, 12],
     }
     text_encoder = CLIPTEXT(**{k: v for k, v in zip(clip_dict['visual_type'], clip_dict[visual_type])})
     if pretrain:
         import clip
-        if visual_type == 'RN50':
-            pretrained_model, _ = clip.load("RN50", device='cpu')
-        elif visual_type == 'ViT-B/32':
-            pretrained_model, _ = clip.load("ViT-B/32", device='cpu')
+        if visual_type in clip_dict:
+            pretrained_model, _ = clip.load(visual_type, device='cpu')
         else:
             raise NotImplementedError