diff --git a/demo_sani.py b/demo_sani.py index 8748b13..b5dd6f4 100644 --- a/demo_sani.py +++ b/demo_sani.py @@ -13,78 +13,6 @@ from fastrcnn.modeling.clip_rcnn import build_clip_rcnn -def load_image(image_path): - # load image - image_pil = Image.open(image_path).convert("RGB") # load image - - transform = T.Compose( - [ - T.RandomResize([800], max_size=1333), - T.ToTensor(), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] - ) - image, _ = transform(image_pil, None) # 3, h, w - return image_pil, image - - -def load_model(model_config_path, model_checkpoint_path, device): - args = SLConfig.fromfile(model_config_path) - args.device = device - model = build_model(args) - checkpoint = torch.load(model_checkpoint_path, map_location="cpu") - load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) - print(load_res) - _ = model.eval() - return model - - - -def show_mask(mask, ax, random_color=False): - if random_color: - color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) - else: - color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) - h, w = mask.shape[-2:] - mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) - ax.imshow(mask_image) - - -def show_box(box, ax, label): - x0, y0 = box[0], box[1] - w, h = box[2] - box[0], box[3] - box[1] - ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)) - ax.text(x0, y0, label) - - -def save_mask_data(output_dir, mask_list, box_list, label_list): - value = 0 # 0 for background - - mask_img = torch.zeros(mask_list.shape[-2:]) - for idx, mask in enumerate(mask_list): - mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1 - plt.figure(figsize=(10, 10)) - plt.imshow(mask_img.numpy()) - plt.axis('off') - plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0) - - json_data = [{ - 'value': value, - 'label': 'background' - }] - for label, box in zip(label_list, box_list): - value += 1 - name, logit = label.split('(') - logit = logit[:-1] # the last is ')' - json_data.append({ - 'value': value, - 'label': name, - 'logit': float(logit), - 'box': box.numpy().tolist(), - }) - with open(os.path.join(output_dir, 'mask.json'), 'w') as f: - json.dump(json_data, f) - def show_anns(anns, ): if len(anns) == 0: return @@ -95,7 +23,7 @@ def show_anns(anns, ): color = [] for ann in sorted_anns: m = ann['segmentation'] - print(ann['stability_score'], ann['predicted_iou']) + # print(ann['stability_score'], ann['predicted_iou']) if ann['predicted_iou'] < 1.0: continue img = np.ones((m.shape[0], m.shape[1], 3)) @@ -103,34 +31,46 @@ def show_anns(anns, ): for i in range(3): img[:,:,i] = color_mask[i] ax.imshow(np.dstack((img, m*0.35))) + box = ann['bbox'] + x0, y0, w, h = box[0], box[1], box[2], box[3] + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0, 0, 0, 0), lw=2)) -def show_anns_with_scores(anns, scores, text_prompt): +def show_anns_with_scores(img_size, anns, scores, text_prompt): if len(anns) == 0: return + text_prompts = text_prompt.split(',') + for ann, score in zip(anns, scores): - ann['clip_score'] = score + ind = score.argmax() + ann['clip_score'] = score[ind] + ann['name'] = text_prompts[ind] sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) ax = plt.gca() ax.set_autoscale_on(False) - + img_h, img_w = img_size for ann in sorted_anns: - m = ann['segmentation'] - - if ann['clip_score'] < 0.5: + box = ann['bbox'] + x0, y0, w, h = box[0], box[1], box[2], box[3] + if w > 0.7 * img_w and h > 0.7 * img_h: + continue + clip_score = ann['clip_score'] + predicted_iou = ann['predicted_iou'] + # if clip_score < 0.5 or predicted_iou < 1.0: + if clip_score < 0.5: continue + # print(ann['clip_score'], ann['predicted_iou']) + m = ann['segmentation'] img = np.ones((m.shape[0], m.shape[1], 3)) color_mask = np.random.random((1, 3)).tolist()[0] for i in range(3): img[:,:,i] = color_mask[i] ax.imshow(np.dstack((img, m*0.35))) - # x, y = ann['point_coords'][0] - # ax.text(x, y, text_prompt) - box = ann['bbox'] - x0, y0 = box[0], box[1] - w, h = box[2] - box[0], box[3] - box[1] - ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0, 0, 0, 0), lw=2)) - ax.text(x0, y0, text_prompt) + + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color_mask, facecolor=(0, 0, 0, 0), lw=2)) + label = ann['name'] + ': {:.2}'.format(clip_score) + ax.text(x0, y0, label, color=color_mask, fontsize='large', fontfamily='sans-serif') + if __name__ == "__main__": parser = argparse.ArgumentParser("Segment-Anything-and-Name-It Demo", add_help=True) @@ -138,8 +78,9 @@ def show_anns_with_scores(anns, scores, text_prompt): "--sam_checkpoint", type=str, default="sam_vit_h_4b8939.pth", help="path to checkpoint file" ) parser.add_argument( - "--clip_type", type=str, default="RN50", help="model type of clip" + "--clip_type", type=str, default="RN50x4", help="model type of clip" ) + parser.add_argument("--input_image", type=str, required=True, help="path to image file") parser.add_argument("--text_prompt", type=str, required=True, help="text prompt") parser.add_argument("--output_dir", "-o", type=str, default="outputs", required=True, help="output directory") @@ -180,20 +121,20 @@ def show_anns_with_scores(anns, scores, text_prompt): masks = mask_generator.generate(image) # prepare boxes - boxes = [ann['bbox'] for ann in masks] - + boxes_xywh = [ann['bbox'] for ann in masks] + boxes_xyxy = [[xywh[0], xywh[1], xywh[0] + xywh[2], xywh[1] + xywh[3]] \ + for xywh in boxes_xywh] # generate scores - scores = clip_rcnn.forward_clip(image, boxes, text_prompt) + scores = clip_rcnn.forward_clip(image, boxes_xyxy, text_prompt) - # import pdb - # pdb.set_trace() # draw output image plt.figure(figsize=(10, 10)) plt.imshow(image) - show_anns_with_scores(masks, scores, text_prompt) + show_anns_with_scores(image.shape[:2], masks, scores, text_prompt) plt.axis('off') + image_name = image_path.split('/')[-1] plt.savefig( - os.path.join(output_dir, "sani_output.jpg"), + os.path.join(output_dir, "sani_output_{}".format(image_name)), bbox_inches="tight", dpi=300, pad_inches=0.0 )