Skip to content

Commit

Permalink
add sani
Browse files Browse the repository at this point in the history
  • Loading branch information
peizesun committed Apr 9, 2023
1 parent 0d87ed5 commit b1314d9
Showing 1 changed file with 35 additions and 94 deletions.
129 changes: 35 additions & 94 deletions demo_sani.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,78 +13,6 @@
from fastrcnn.modeling.clip_rcnn import build_clip_rcnn


def load_image(image_path):
# load image
image_pil = Image.open(image_path).convert("RGB") # load image

transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image, _ = transform(image_pil, None) # 3, h, w
return image_pil, image


def load_model(model_config_path, model_checkpoint_path, device):
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
print(load_res)
_ = model.eval()
return model



def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)


def show_box(box, ax, label):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
ax.text(x0, y0, label)


def save_mask_data(output_dir, mask_list, box_list, label_list):
value = 0 # 0 for background

mask_img = torch.zeros(mask_list.shape[-2:])
for idx, mask in enumerate(mask_list):
mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
plt.figure(figsize=(10, 10))
plt.imshow(mask_img.numpy())
plt.axis('off')
plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)

json_data = [{
'value': value,
'label': 'background'
}]
for label, box in zip(label_list, box_list):
value += 1
name, logit = label.split('(')
logit = logit[:-1] # the last is ')'
json_data.append({
'value': value,
'label': name,
'logit': float(logit),
'box': box.numpy().tolist(),
})
with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
json.dump(json_data, f)

def show_anns(anns, ):
if len(anns) == 0:
return
Expand All @@ -95,51 +23,64 @@ def show_anns(anns, ):
color = []
for ann in sorted_anns:
m = ann['segmentation']
print(ann['stability_score'], ann['predicted_iou'])
# print(ann['stability_score'], ann['predicted_iou'])
if ann['predicted_iou'] < 1.0:
continue
img = np.ones((m.shape[0], m.shape[1], 3))
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack((img, m*0.35)))
box = ann['bbox']
x0, y0, w, h = box[0], box[1], box[2], box[3]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0, 0, 0, 0), lw=2))


def show_anns_with_scores(anns, scores, text_prompt):
def show_anns_with_scores(img_size, anns, scores, text_prompt):
if len(anns) == 0:
return
text_prompts = text_prompt.split(',')

for ann, score in zip(anns, scores):
ann['clip_score'] = score
ind = score.argmax()
ann['clip_score'] = score[ind]
ann['name'] = text_prompts[ind]
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)

img_h, img_w = img_size
for ann in sorted_anns:
m = ann['segmentation']

if ann['clip_score'] < 0.5:
box = ann['bbox']
x0, y0, w, h = box[0], box[1], box[2], box[3]
if w > 0.7 * img_w and h > 0.7 * img_h:
continue
clip_score = ann['clip_score']
predicted_iou = ann['predicted_iou']
# if clip_score < 0.5 or predicted_iou < 1.0:
if clip_score < 0.5:
continue
# print(ann['clip_score'], ann['predicted_iou'])
m = ann['segmentation']
img = np.ones((m.shape[0], m.shape[1], 3))
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack((img, m*0.35)))
# x, y = ann['point_coords'][0]
# ax.text(x, y, text_prompt)
box = ann['bbox']
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0, 0, 0, 0), lw=2))
ax.text(x0, y0, text_prompt)

ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color_mask, facecolor=(0, 0, 0, 0), lw=2))
label = ann['name'] + ': {:.2}'.format(clip_score)
ax.text(x0, y0, label, color=color_mask, fontsize='large', fontfamily='sans-serif')


if __name__ == "__main__":
parser = argparse.ArgumentParser("Segment-Anything-and-Name-It Demo", add_help=True)
parser.add_argument(
"--sam_checkpoint", type=str, default="sam_vit_h_4b8939.pth", help="path to checkpoint file"
)
parser.add_argument(
"--clip_type", type=str, default="RN50", help="model type of clip"
"--clip_type", type=str, default="RN50x4", help="model type of clip"
)

parser.add_argument("--input_image", type=str, required=True, help="path to image file")
parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
parser.add_argument("--output_dir", "-o", type=str, default="outputs", required=True, help="output directory")
Expand Down Expand Up @@ -180,20 +121,20 @@ def show_anns_with_scores(anns, scores, text_prompt):
masks = mask_generator.generate(image)

# prepare boxes
boxes = [ann['bbox'] for ann in masks]

boxes_xywh = [ann['bbox'] for ann in masks]
boxes_xyxy = [[xywh[0], xywh[1], xywh[0] + xywh[2], xywh[1] + xywh[3]] \
for xywh in boxes_xywh]
# generate scores
scores = clip_rcnn.forward_clip(image, boxes, text_prompt)
scores = clip_rcnn.forward_clip(image, boxes_xyxy, text_prompt)

# import pdb
# pdb.set_trace()
# draw output image
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_anns_with_scores(masks, scores, text_prompt)
show_anns_with_scores(image.shape[:2], masks, scores, text_prompt)

plt.axis('off')
image_name = image_path.split('/')[-1]
plt.savefig(
os.path.join(output_dir, "sani_output.jpg"),
os.path.join(output_dir, "sani_output_{}".format(image_name)),
bbox_inches="tight", dpi=300, pad_inches=0.0
)

0 comments on commit b1314d9

Please sign in to comment.