
Commit

add key frame model
frank6200db committed Sep 22, 2024
1 parent ec0b496 commit d58e18a
Showing 14 changed files with 288 additions and 1 deletion.
18 changes: 17 additions & 1 deletion README.md
@@ -12,7 +12,7 @@ We introduce GUI action dataset **Act2Cap** as well as an effective framework: *

- 19 Jun 2024: We release our paper on arXiv.
- 15 Aug 2024: The automatically collected datasets and the human demonstration datasets are available.
- 22 Sep 2024: We release the pipeline for the cursor detection and key frame extraction modules; currently, 10 frames sampled from a video are supported.
---

- Download the **ACT2CAP** dataset, which consists of 10-frame GUI screenshot sequences depicting atomic actions. **[Download link here](https://drive.google.com/file/d/18cL3ByBkEMI-eTKrelaEXWeiF3QwZAAl/view?usp=drive_link)**.
@@ -38,3 +38,19 @@ We introduce GUI action dataset **Act2Cap** as well as an effective framework: *
where `a` and `b` denote the start and end frame indices, respectively, and `x` denotes the folder index.
The terms `Prompt` and `Crop` refer to the screenshot with a visual prompt and the cropped detail image generated by the cursor detection module.
If you are interested in the original images, you can substitute them with `frame_idx`.
---
- Download the **cursor detection and key frame extraction checkpoints** from the **[download link here](https://drive.google.com/file/d/1ChrpBuPL7W84mKNsSsbueff5EGlyB3h2/view?usp=sharing)**.
- Install the supporting packages:
```
pip install -r requirements.txt
```
- Run the inference code as below; the visual prompts and cropped images will be generated in the folder `frames_sample`:
```
cd model
python run_model.py \
--frame_extract_model_path /path/to/checkpoint_key_frames \
--yolo_model_path /path/to/Yolo_best \
--images_path /path/to/frames_sample
```
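The script prints the predicted start and end key frame indices (the index values below are hypothetical):
```
start_frame_index: 2 end_frame_index: 7
```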
Binary file added frames_sample/frame_0.png
Binary file added frames_sample/frame_1.png
Binary file added frames_sample/frame_2.png
Binary file added frames_sample/frame_3.png
Binary file added frames_sample/frame_4.png
Binary file added frames_sample/frame_5.png
Binary file added frames_sample/frame_6.png
Binary file added frames_sample/frame_7.png
Binary file added frames_sample/frame_8.png
Binary file added frames_sample/frame_9.png
204 changes: 204 additions & 0 deletions model/base_model.py
@@ -0,0 +1,204 @@
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import open_clip
from PIL import Image, ImageDraw
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from ultralytics import YOLO

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B/32', pretrained='openai')
# print(vit_model)  # uncomment to inspect the CLIP backbone

def smooth_labels(binary_labels, sigma):
    # Build a 1-D Gaussian kernel of size 2*sigma + 1, normalized to sum to 1.
    size = 2 * sigma + 1
    kernel = torch.exp(-(torch.arange(size) - sigma)**2 / (2 * sigma**2)).to(device)
    kernel = kernel / kernel.sum()

    # Convolve the label vector with the kernel; padding=sigma keeps the
    # output the same length as the input.
    labels = binary_labels.view(1, 1, -1)
    kernel = kernel.view(1, 1, -1)

    smoothed_labels = F.conv1d(labels, kernel, padding=sigma)[0][0]
    return smoothed_labels
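# Usage sketch (hypothetical): smooth a 10-frame binary key-frame target so
# that probability mass spreads onto neighboring frames.
#   labels = torch.zeros(10, device=device)
#   labels[3] = labels[7] = 1.0
#   soft = smooth_labels(labels, sigma=1)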

def loss_function(predicted_probs, labels):
    # BCELoss expects probabilities in [0, 1] on both sides.
    loss = nn.BCELoss()
    return loss(predicted_probs, labels)

# CLIP's image normalization statistics.
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
image_transform = transforms.Compose([
    transforms.Resize(
        (224, 224),
        interpolation=InterpolationMode.BICUBIC
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
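# Usage sketch (hypothetical path): preprocess one cropped frame for CLIP.
#   img = Image.open('frames_sample/0_crop.png').convert('RGB')
#   x = image_transform(img)   # tensor of shape (3, 224, 224)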

class MLPProjector(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLPProjector, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        nn.init.kaiming_uniform_(self.fc1.weight, a=0.01, nonlinearity='relu')
        nn.init.constant_(self.fc1.bias, 0)
        nn.init.kaiming_uniform_(self.fc2.weight, a=0.01, nonlinearity='relu')
        nn.init.constant_(self.fc2.bias, 0)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
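# Usage sketch: project CLIP's 512-d image features to the 256-d width used
# by the attention blocks (hidden_size=512*4 matches the model below).
#   proj = MLPProjector(input_size=512, hidden_size=512*4, output_size=256)
#   feats = torch.randn(10, 1, 512)   # (frames, batch, clip_dim)
#   proj(feats).shape                 # torch.Size([10, 1, 256])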




class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block; note that a single LayerNorm module is shared by both sublayers."""
    def __init__(self, input_size, num_heads=8):
        super(SelfAttentionBlock, self).__init__()
        self.normlayer = nn.LayerNorm(normalized_shape=(input_size,))
        self.self_attention = nn.MultiheadAttention(embed_dim=input_size, num_heads=num_heads)
        self.feedforward = nn.Sequential(
            nn.Linear(input_size, 4 * input_size),
            nn.GELU(),
            nn.Linear(4 * input_size, input_size)
        )

    def forward(self, x):
        # Self-attention sublayer with residual connection.
        residual = x
        out = self.normlayer(x)
        out, _ = self.self_attention(out, out, out)
        out += residual

        # Feed-forward sublayer with residual connection.
        residual = out
        out = self.normlayer(out)
        out = self.feedforward(out)
        out += residual
        return out
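# Usage sketch: nn.MultiheadAttention defaults to sequence-first inputs,
# i.e. (seq_len, batch, embed_dim).
#   block = SelfAttentionBlock(input_size=256)
#   x = torch.randn(10, 1, 256)   # 10 frames, batch of 1
#   block(x).shape                # torch.Size([10, 1, 256])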


class KeyFrameExtractor_v4(nn.Module):
    """
    A relatively simple key frame extractor: a frozen CLIP image encoder,
    an MLP projector down to 256-d, a learned position embedding over the
    10 frames, and a small stack of self-attention blocks. The output is
    one score per frame.
    """
    def __init__(self, num_classes=10, num_layers=2):
        super(KeyFrameExtractor_v4, self).__init__()
        self.clip_encoder = vit_model

        # The CLIP backbone stays frozen; only the projector and
        # attention layers are trained.
        for param in self.clip_encoder.parameters():
            param.requires_grad = False

        self.attention_layers = nn.ModuleList([
            SelfAttentionBlock(input_size=256) for _ in range(num_layers)
        ])
        self.position_embedding = nn.Parameter(torch.randn(10, 1, 256))
        self.mlp_projector = MLPProjector(input_size=512, hidden_size=512*4, output_size=256)
        self.normlayer = nn.LayerNorm(normalized_shape=(10,))  # defined but unused in forward

    def forward(self, images):
        # images: (B, 10, 3, 224, 224); the learned position embedding assumes B == 1.
        flattened_images = [self.clip_encoder.encode_image(im.unsqueeze(0)) for image in images for im in image]
        features = torch.stack(flattened_images)           # (B*10, 1, 512)
        projected_features = self.mlp_projector(features)  # (B*10, 1, 256)
        projected_features = projected_features + self.position_embedding
        out = projected_features
        for layer in self.attention_layers:
            out = layer(out)                               # sequence-first: (10, 1, 256)

        out = out.permute(1, 0, 2)  # (1, 10, 256)
        out = out.mean(dim=2)       # (1, 10): one score per frame

        return out
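# Usage sketch (hypothetical paths): score the 10 crops and pick the top two frames.
#   model = KeyFrameExtractor_v4().eval()
#   crops = torch.stack([image_transform(Image.open(f'frames_sample/{j}_crop.png').convert('RGB'))
#                        for j in range(10)])    # (10, 3, 224, 224)
#   with torch.no_grad():
#       scores = model(crops.unsqueeze(0))       # (1, 10)
#   torch.topk(scores, 2).indices                # indices of the two key frames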


class Cursor_detector:
    def __init__(self, check_point_path, video_dir):
        self.detection_model = YOLO(check_point_path)
        self.video_dir = video_dir

    def detect(self):
        # For frames where YOLO finds no cursor, reuse the last detected
        # position; fall back to the image center if nothing has been
        # detected yet (the original code assumed a prior detection existed).
        x = y = None
        for j in range(10):
            image_path = f'{self.video_dir}/frame_{j}.png'
            results = self.detection_model(image_path)
            image1 = Image.open(image_path).convert('RGB')
            width, height = image1.size

            for result in results:
                if result.boxes.xywh.size(0) > 0:
                    # Use the first detected box; xywh holds the box center.
                    xywh_tensor = result.boxes.xywh
                    x, y = xywh_tensor[0][0].item(), xywh_tensor[0][1].item()
            if x is None:
                x, y = width / 2, height / 2

            # Crop a 256x256 window around the cursor, clamped to the image bounds.
            x1, y1 = max(0, x - 128), max(0, y - 128)
            x2, y2 = min(x1 + 256, width), min(y1 + 256, height)
            crop = image1.crop((x1, y1, x2, y2))
            crop.save(self.video_dir + f'/{j}_crop.png')

            # Draw the same window on the full frame as a green visual prompt.
            draw = ImageDraw.Draw(image1)
            draw.rectangle([x1, y1, x2, y2], outline='green', width=3)
            image1.save(self.video_dir + f'/{j}_prompt.png')
            image1.close()
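# Usage sketch (hypothetical checkpoint path): annotate the 10 sampled frames.
#   detector = Cursor_detector('/path/to/Yolo_best', 'frames_sample')
#   detector.detect()   # writes {j}_crop.png and {j}_prompt.png for j in 0..9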

class ImageReader:
    """Reads the `{i}_crop.png` images written by Cursor_detector and stacks them into one tensor."""
    def __init__(self, root_dir, transform=image_transform):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = self._get_image_paths()

    def _get_image_paths(self):
        image_paths = []
        for i in range(10):
            image_path = os.path.join(self.root_dir, f'{i}_crop.png')
            if os.path.exists(image_path):
                image_paths.append(image_path)
        return image_paths

    def read_images(self):
        images = []
        for image_path in self.image_paths:
            image = Image.open(image_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            images.append(image)
        return torch.stack(images)
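# Usage sketch: stack the crops into the tensor expected by KeyFrameExtractor_v4.
#   reader = ImageReader('frames_sample')
#   images = reader.read_images()   # (N, 3, 224, 224); N <= 10, depending on which crops exist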

class VideoReader:
    # Placeholder for sampling frames directly from a video file.
    pass
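# A possible VideoReader sketch (an assumption -- OpenCV is not in
# requirements.txt): sample 10 evenly spaced frames from a video into
# frame_{j}.png, matching the layout Cursor_detector.detect() expects.
#
#   import cv2
#
#   def sample_frames(video_path, out_dir, num_frames=10):
#       cap = cv2.VideoCapture(video_path)
#       total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
#       for j in range(num_frames):
#           cap.set(cv2.CAP_PROP_POS_FRAMES, j * total // num_frames)
#           ok, frame = cap.read()
#           if ok:
#               cv2.imwrite(f'{out_dir}/frame_{j}.png', frame)
#       cap.release()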

65 changes: 65 additions & 0 deletions model/run_model.py
@@ -0,0 +1,65 @@
import argparse

import torch

# base_model defines the identical CLIP preprocessing, so reuse its
# image_transform instead of duplicating it here.
from base_model import KeyFrameExtractor_v4, Cursor_detector, ImageReader, image_transform

def main():
    parser = argparse.ArgumentParser(description='Process paths')
    parser.add_argument('--frame_extract_model_path', type=str, required=True, help='Path to the key frame extraction checkpoint')
    parser.add_argument('--yolo_model_path', type=str, required=True, help='Path to the YOLO cursor detection checkpoint')
    parser.add_argument('--images_path', type=str, required=True, help='Path to the sampled frames')
    args = parser.parse_args()

    frame_extract_model_path = args.frame_extract_model_path
    yolo_model_path = args.yolo_model_path
    images_path = args.images_path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('Using device:', device)

    model = KeyFrameExtractor_v4()
    # Strip the 'module.' prefix left over from DataParallel training.
    loaded_dict = torch.load(frame_extract_model_path, map_location='cpu')
    model.load_state_dict({k.replace('module.', ''): v for k, v in loaded_dict.items()})
    model = model.to(device).eval()

    detector = Cursor_detector(yolo_model_path, images_path)
    print('detector loaded successfully')

    detector.detect()


    image_reader = ImageReader(images_path, transform=image_transform)
    images_tensor = image_reader.read_images()
    with torch.no_grad():
        output = model(images_tensor.unsqueeze(0).to(device))

    # The two highest-scoring frames are taken as the action's start and end.
    _, indices = torch.topk(output, 2)
    start, end = indices[0]
    s = min(int(start), int(end))
    e = max(int(start), int(end))
    return s, e

if __name__ == "__main__":
    s, e = main()
    print(f'start_frame_index: {s}', f'end_frame_index: {e}')

2 changes: 2 additions & 0 deletions requirements.txt
@@ -0,0 +1,2 @@
open_clip_torch
ultralytics
