diff --git a/README.md b/README.md
index 01713f4..8c35b04 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@ We introduce GUI action dataset **Act2Cap** as well as an effective framework: *
 - 19 Jun 2024: We release our paper on arXiv.
 - 15 Aug 2024: The automatically collected datasets and the human demonstration datasets are available.
-
+- 22 Sep 2024: We release the pipeline for the cursor detection and key-frame extraction modules; it currently supports 10 frames sampled from a video.
 ---
 - Download the **ACT2CAP** dataset, which consists of 10-frame GUI screenshot sequences depicting atomic actions. **[Download link here](https://drive.google.com/file/d/18cL3ByBkEMI-eTKrelaEXWeiF3QwZAAl/view?usp=drive_link)**.
@@ -38,3 +38,19 @@ We introduce GUI action dataset **Act2Cap** as well as an effective framework: *
 Here `a` and `b` denote the start and end frame indices, respectively; `x` denotes the folder index. The terms `Prompt` and `Crop` refer to the screenshot with a visual prompt and the cropped detail image generated by the cursor detection module. If you are interested in the original images, you can substitute them with `frame_idx`.
+---
+- Download the **cursor detection and key-frame extraction checkpoints**. **[Download link here](https://drive.google.com/file/d/1ChrpBuPL7W84mKNsSsbueff5EGlyB3h2/view?usp=sharing)**
+
+- Install the supporting packages:
+  ```
+  pip install -r requirements.txt
+  ```
+
+- Run the inference code as below; the visual prompts and cropped images will be generated in the `frames_sample` folder:
+  ```
+  cd model
+  python run_model.py \
+    --frame_extract_model_path /path/to/checkpoint_key_frames \
+    --yolo_model_path /path/to/Yolo_best \
+    --images_path /path/to/frames_sample
+  ```
\ No newline at end of file
diff --git a/frames_sample/frame_0.png b/frames_sample/frame_0.png
new file mode 100644
index 0000000..df58835
Binary files /dev/null and b/frames_sample/frame_0.png differ
diff --git a/frames_sample/frame_1.png b/frames_sample/frame_1.png
new file mode 100644
index 0000000..aae0fe6
Binary files /dev/null and b/frames_sample/frame_1.png differ
diff --git a/frames_sample/frame_2.png b/frames_sample/frame_2.png
new file mode 100644
index 0000000..e5550a7
Binary files /dev/null and b/frames_sample/frame_2.png differ
diff --git a/frames_sample/frame_3.png b/frames_sample/frame_3.png
new file mode 100644
index 0000000..30f64e1
Binary files /dev/null and b/frames_sample/frame_3.png differ
diff --git a/frames_sample/frame_4.png b/frames_sample/frame_4.png
new file mode 100644
index 0000000..e987286
Binary files /dev/null and b/frames_sample/frame_4.png differ
diff --git a/frames_sample/frame_5.png b/frames_sample/frame_5.png
new file mode 100644
index 0000000..5223b46
Binary files /dev/null and b/frames_sample/frame_5.png differ
diff --git a/frames_sample/frame_6.png b/frames_sample/frame_6.png
new file mode 100644
index 0000000..4dd181f
Binary files /dev/null and b/frames_sample/frame_6.png differ
diff --git a/frames_sample/frame_7.png b/frames_sample/frame_7.png
new file mode 100644
index 0000000..537aee5
Binary files /dev/null and b/frames_sample/frame_7.png differ
diff --git a/frames_sample/frame_8.png b/frames_sample/frame_8.png
new file mode 100644
index 0000000..8cd2c16
Binary files /dev/null and b/frames_sample/frame_8.png differ
diff --git a/frames_sample/frame_9.png b/frames_sample/frame_9.png
new file mode 100644
index 0000000..6ae2b12
Binary files /dev/null and b/frames_sample/frame_9.png differ
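The `frames_sample` images above are the pipeline's inputs. After `run_model.py` finishes, the same folder also contains the `{idx}_crop.png` and `{idx}_prompt.png` files written by the cursor detection module, and the script prints the predicted start/end frame indices. A minimal sketch for loading those outputs (the index values below are placeholders, not real predictions):

```python
# Minimal sketch: load the artefacts that run_model.py leaves in frames_sample/.
# The start/end indices are the ones printed by the script; 2 and 7 are placeholders.
from pathlib import Path
from PIL import Image

frames_dir = Path("frames_sample")
start_idx, end_idx = 2, 7  # replace with the printed start_frame_index / end_frame_index

for idx in (start_idx, end_idx):
    crop = Image.open(frames_dir / f"{idx}_crop.png")      # detail crop around the cursor (up to 256x256)
    prompt = Image.open(frames_dir / f"{idx}_prompt.png")  # full frame with the green crop box drawn
    frame = Image.open(frames_dir / f"frame_{idx}.png")    # original screenshot
    print(idx, frame.size, crop.size, prompt.size)
```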
diff --git a/model/base_model.py b/model/base_model.py
new file mode 100644
index 0000000..b3b3654
--- /dev/null
+++ b/model/base_model.py
@@ -0,0 +1,204 @@
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import open_clip
+from PIL import Image, ImageDraw
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from ultralytics import YOLO
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Frozen CLIP image encoder used as the per-frame feature extractor.
+vit_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B/32', pretrained='openai')
+
+
+def smooth_labels(binary_labels, sigma):
+    # Smooth binary key-frame labels with a 1-D Gaussian kernel (used at training time).
+    size = 2 * sigma + 1
+    kernel = torch.exp(-(torch.arange(size) - sigma) ** 2 / (2 * sigma ** 2)).to(device)
+    kernel = kernel / kernel.sum()
+
+    padded_labels = F.pad(binary_labels.unsqueeze(0), (0, 0))
+    kernel = kernel.view(1, 1, -1)
+
+    smoothed_labels = F.conv1d(padded_labels.unsqueeze(0), kernel, padding=sigma)[0][0]
+    return smoothed_labels
+
+
+def loss_function(predicted_probs, labels):
+    loss = nn.BCELoss()
+    return loss(predicted_probs, labels)
+
+
+# CLIP normalisation constants.
+mean = (0.48145466, 0.4578275, 0.40821073)
+std = (0.26862954, 0.26130258, 0.27577711)
+image_transform = transforms.Compose([
+    transforms.Resize(
+        (224, 224),
+        interpolation=InterpolationMode.BICUBIC
+    ),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=mean, std=std),
+])
+
+
+class MLPProjector(nn.Module):
+    """Two-layer MLP that projects CLIP features (512-d) down to the attention width (256-d)."""
+
+    def __init__(self, input_size, hidden_size, output_size):
+        super(MLPProjector, self).__init__()
+        self.fc1 = nn.Linear(input_size, hidden_size)
+        self.fc2 = nn.Linear(hidden_size, output_size)
+        nn.init.kaiming_uniform_(self.fc1.weight, a=0.01, nonlinearity='relu')
+        nn.init.constant_(self.fc1.bias, 0)
+        nn.init.kaiming_uniform_(self.fc2.weight, a=0.01, nonlinearity='relu')
+        nn.init.constant_(self.fc2.bias, 0)
+
+    def forward(self, x):
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+class SelfAttentionBlock(nn.Module):
+    """Pre-norm self-attention block with a GELU feed-forward layer and residual connections."""
+
+    def __init__(self, input_size, num_heads=8):
+        super(SelfAttentionBlock, self).__init__()
+        self.normlayer = nn.LayerNorm(normalized_shape=(input_size,))
+        self.self_attention = nn.MultiheadAttention(embed_dim=input_size, num_heads=num_heads)
+        self.feedforward = nn.Sequential(
+            nn.Linear(input_size, 4 * input_size),
+            nn.GELU(),
+            nn.Linear(4 * input_size, input_size)
+        )
+
+    def forward(self, x):
+        residual = x
+        out = self.normlayer(x)
+        out, _ = self.self_attention(out, out, out)
+        out += residual
+        residual = out
+        out = self.normlayer(out)
+        out = self.feedforward(out)
+        out += residual
+        return out
+
+
+class KeyFrameExtractor_v4(nn.Module):
+    """
+    Key-frame extractor: a frozen CLIP image encoder, an MLP projector, and a stack of
+    self-attention layers that score each of the 10 frames in a sequence.
+    """
+
+    def __init__(self, num_classes=10, num_layers=2):
+        super(KeyFrameExtractor_v4, self).__init__()
+        self.clip_encoder = vit_model
+
+        # The CLIP encoder stays frozen; only the projector and attention layers are trained.
+        for param in self.clip_encoder.parameters():
+            param.requires_grad = False
+
+        self.attention_layers = nn.ModuleList([
+            SelfAttentionBlock(input_size=256) for _ in range(num_layers)
+        ])
+        self.position_embedding = nn.Parameter(torch.randn(10, 1, 256))
+        self.mlp_projector = MLPProjector(input_size=512, hidden_size=512 * 4, output_size=256)
+        self.normlayer = nn.LayerNorm(normalized_shape=(10,))
+
+    def forward(self, images):
+        # images: (batch, 10, 3, 224, 224); the shapes below assume batch size 1.
+        flattened_images = [self.clip_encoder.encode_image(im.unsqueeze(0))
+                            for image in images for im in image]
+        features = torch.stack(flattened_images)            # (10, 1, 512)
+        projected_features = self.mlp_projector(features)   # (10, 1, 256)
+        projected_features += self.position_embedding
+        out = projected_features
+        for layer in self.attention_layers:
+            out = layer(out)
+
+        out = out.permute(1, 0, 2)   # (1, 10, 256)
+        out = out.mean(dim=2)        # (1, 10): one score per frame
+
+        return out
+
+
+class Cursor_detector:
+    """Runs YOLO cursor detection on the 10 sampled frames and writes crop/prompt images."""
+
+    def __init__(self, check_point_path, video_dir):
+        super(Cursor_detector, self).__init__()
+        self.detection_model = YOLO(check_point_path)
+        self.video_dir = video_dir
+
+    def detect(self):
+        # For each frame, save a crop of up to 256x256 around the detected cursor
+        # ("{j}_crop.png") and the full frame with the crop region outlined in green
+        # ("{j}_prompt.png").
+        x, y = None, None  # last known cursor centre; reused when detection fails
+        for j in range(10):
+            image_path = f'{self.video_dir}/frame_{j}.png'
+            results = self.detection_model(image_path)
+            img = Image.open(image_path)
+            width, height = img.size
+            img.close()
+
+            for result in results:
+                if result.boxes.xywh.size(0) > 0:
+                    # Take the first detected box as the cursor position.
+                    xywh_tensor = result.boxes.xywh
+                    x, y = xywh_tensor[0][0].item(), xywh_tensor[0][1].item()
+
+            if x is None:
+                # No cursor detected in any frame so far: fall back to the image centre.
+                x, y = width / 2, height / 2
+
+            x1 = max(0, x - 128)
+            y1 = max(0, y - 128)
+            x2 = min(x1 + 256, width)
+            y2 = min(y1 + 256, height)
+
+            image1 = Image.open(image_path).convert('RGB')
+            start_crop = image1.crop((x1, y1, x2, y2))
+            start_crop.save(self.video_dir + f'/{j}_crop.png')
+
+            # Draw the crop region on the full frame as a visual prompt.
+            draw = ImageDraw.Draw(image1)
+            draw.rectangle([x1, y1, x2, y2], outline='green', width=3)
+            image1.save(self.video_dir + f'/{j}_prompt.png')
+            image1.close()
+
+
+class ImageReader:
+    """Reads the 10 cropped cursor images from a folder and stacks them into a tensor."""
+
+    def __init__(self, root_dir, transform=image_transform):
+        self.root_dir = root_dir
+        self.transform = transform
+        self.image_paths = self._get_image_paths()
+
+    def _get_image_paths(self):
+        image_paths = []
+        for i in range(10):
+            image_path = os.path.join(self.root_dir, f'{i}_crop.png')
+            if os.path.exists(image_path):
+                image_paths.append(image_path)
+        return image_paths
+
+    def read_images(self):
+        images = []
+        for image_path in self.image_paths:
+            image = Image.open(image_path).convert('RGB')
+            if self.transform:
+                image = self.transform(image)
+            images.append(image)
+        return torch.stack(images)
+
+
+class VideoReader:
+    # Placeholder for sampling 10 frames directly from a video.
+    pass
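`KeyFrameExtractor_v4` encodes each of the 10 cropped frames with the frozen CLIP encoder, projects the features to 256-d, adds a learned position embedding, runs the sequence through the self-attention stack, and averages over the feature dimension to get one score per frame; the two highest-scoring frames are later taken as the action's start and end. A small shape-check sketch with random weights and random input (assumes `model/` is the working directory so `base_model` is importable; creating the model downloads the CLIP ViT-B/32 weights on first use):

```python
# Shape check for KeyFrameExtractor_v4: 10 preprocessed frames in, 10 per-frame scores out.
# Random weights and random input, so the scores are meaningless; this only illustrates
# the tensor shapes the model expects (batch size 1, as assumed in forward()).
import torch
from base_model import KeyFrameExtractor_v4

model = KeyFrameExtractor_v4().eval()
frames = torch.randn(1, 10, 3, 224, 224)   # (batch=1, frames=10, C, H, W), as produced by image_transform

with torch.no_grad():
    scores = model(frames)                  # (1, 10): one score per frame

_, indices = torch.topk(scores, 2)          # the two candidate key frames
print(scores.shape, sorted(indices[0].tolist()))
```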
diff --git a/model/run_model.py b/model/run_model.py
new file mode 100644
index 0000000..702b129
--- /dev/null
+++ b/model/run_model.py
@@ -0,0 +1,65 @@
+import argparse
+
+import torch
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+
+from base_model import KeyFrameExtractor_v4, Cursor_detector, ImageReader
+
+# CLIP normalisation constants (same transform as in base_model.py).
+mean = (0.48145466, 0.4578275, 0.40821073)
+std = (0.26862954, 0.26130258, 0.27577711)
+image_transform = transforms.Compose([
+    transforms.Resize(
+        (224, 224),
+        interpolation=InterpolationMode.BICUBIC
+    ),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=mean, std=std),
+])
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Process paths')
+    parser.add_argument('--frame_extract_model_path', type=str,
+                        help='Path to the key-frame extraction checkpoint')
+    parser.add_argument('--yolo_model_path', type=str, help='Path to the YOLO cursor detection checkpoint')
+    parser.add_argument('--images_path', type=str, help='Path to the folder with frame_0.png ... frame_9.png')
+    args = parser.parse_args()
+
+    frame_extract_model_path = args.frame_extract_model_path
+    yolo_model_path = args.yolo_model_path
+    images_path = args.images_path
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print('Using device:', device)
+
+    # Load the key-frame extractor; the checkpoint may have been saved from a DataParallel
+    # model, hence the 'module.' prefix stripping.
+    model = KeyFrameExtractor_v4()
+    loaded_dict = torch.load(frame_extract_model_path, map_location=device)
+    model.load_state_dict({k.replace('module.', ''): v for k, v in loaded_dict.items()})
+    model = model.to(device).eval()
+
+    try:
+        detector = Cursor_detector(yolo_model_path, images_path)
+        print('detector_load_successful')
+    except Exception as err:
+        print(f'error in loading checkpoint: {err}')
+        raise
+
+    # Write {j}_crop.png and {j}_prompt.png for every frame.
+    detector.detect()
+
+    # Score the cropped frames and take the two highest-scoring ones as start/end key frames.
+    image_reader = ImageReader(images_path, transform=image_transform)
+    images_tensor = image_reader.read_images()
+    with torch.no_grad():
+        output = model(images_tensor.unsqueeze(0).to(device))
+    _, indices = torch.topk(output, 2)
+
+    start, end = indices[0]
+    s = min(int(start), int(end))
+    e = max(int(start), int(end))
+    return s, e
+
+
+if __name__ == "__main__":
+    s, e = main()
+    print(f'start_frame_index: {s}', f'end_frame_index: {e}')
\ No newline at end of file
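`run_model.py` assumes `--images_path` already contains 10 frames named `frame_0.png` through `frame_9.png`; the `VideoReader` class in `base_model.py` is still a stub, so sampling frames from a recording has to happen outside the pipeline for now. A hypothetical helper for that step (uses OpenCV, which is not in `requirements.txt`; the function name and the uniform sampling strategy are assumptions, not part of the released code):

```python
# Hypothetical pre-processing step: sample 10 evenly spaced frames from a screen recording
# into frame_0.png ... frame_9.png, the layout run_model.py expects. Requires opencv-python.
import os
import cv2

def sample_frames(video_path, out_dir, num_frames=10):
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    for j in range(num_frames):
        # Seek to the j-th of num_frames uniformly spaced positions.
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(j * (total - 1) / (num_frames - 1)))
        ok, frame = cap.read()
        if ok:
            cv2.imwrite(os.path.join(out_dir, f'frame_{j}.png'), frame)
    cap.release()

sample_frames('recording.mp4', 'frames_sample')  # 'recording.mp4' is a placeholder path
```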
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..99f4c5e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+open_clip_torch
+ultralytics
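The two entries in `requirements.txt` pull in `torch`, `torchvision`, and `Pillow` as transitive dependencies. An optional sanity check of the environment after installation (assumes nothing beyond the two listed packages):

```python
# Optional environment check: confirms the two direct dependencies import cleanly and
# reports whether CUDA is visible (the pipeline falls back to CPU if it is not).
import torch
import open_clip
import ultralytics

print('torch:', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('open_clip_torch:', open_clip.__version__)
print('ultralytics:', ultralytics.__version__)
```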