example of object detection on very simple image #21

Open · wants to merge 1 commit into main
288 changes: 288 additions & 0 deletions examples/perceiver_detect_box.ipynb
@@ -0,0 +1,288 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "perceiver detect box.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "1EmMgQSSra9M"
},
"source": [
"!pip install perceiver-pytorch\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "YNCh_-_1rhPX"
},
"source": [
"import torch\n",
"import cv2\n",
"import numpy as np\n",
"import torch.optim as optim\n",
"import torchvision\n",
"\n",
"from google.colab.patches import cv2_imshow\n",
"from perceiver_pytorch import Perceiver\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data import DataLoader\n",
"\n",
"\n",
"##################\n",
"batch_size = 128\n",
"size = 64\n",
"objects = 1\n",
"##################"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Kn5CTHlhrmdM"
},
"source": [
"model = Perceiver(\n",
" input_channels = 3, # number of channels for each token of the input\n",
" input_axis = 2, # number of axis for input data (2 for images, 3 for video)\n",
" num_freq_bands = 4, # number of freq bands, with original value (2 * K + 1)\n",
" max_freq = 5., # maximum frequency, hyperparameter depending on how fine the data is\n",
" depth = 2, # depth of net\n",
" num_latents = 32, # number of latents, or induced set points, or centroids. different papers giving it different names\n",
" cross_dim = 32, # cross attention dimension\n",
" latent_dim = 32, # latent dimension\n",
" cross_heads = 1, # number of heads for cross attention. paper said 1\n",
" latent_heads = 4, # number of heads for latent self attention, 8\n",
" cross_dim_head = 32,\n",
" latent_dim_head = 32,\n",
" num_classes = 4*objects, # output number of classes\n",
" attn_dropout = 0.5,\n",
" ff_dropout = 0.5,\n",
" weight_tie_layers = True # whether to weight tie layers (optional, as indicated in the diagram)\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "c2WaA0Vhrqhv"
},
"source": [
"# training the model on very easy object detection problem\n",
"class CustomDataset(Dataset):\n",
" def __init__(self): \n",
" pass\n",
" \n",
" def __len__(self):\n",
" return 320000\n",
"\n",
" def __getitem__(self, idx):\n",
" image = np.zeros((size,size,3), np.uint8)\n",
" labels = []\n",
"\n",
" for i in range(objects):\n",
" if (np.random.rand() > 0.02):\n",
" point_x = int(np.random.rand() * size)\n",
" point_y = int(np.random.rand() * size)\n",
" size_x = int(np.random.rand() * size)\n",
" size_y = int(np.random.rand() * size) \n",
" r, g, b = int(np.random.rand()*255), int(np.random.rand()*255), int(np.random.rand() * 255) \n",
" try:\n",
" cv2.rectangle(image, (int(point_x - size_x/2), int(point_y - size_y/2)), (int(point_x + size_x/2), int(point_y + size_y/2)), (r, g, b), -1) \n",
" labels.append((point_x/size, point_y/size, size_x/size, size_y/size))\n",
" except Exception as e:\n",
" print(e)\n",
" labels.append((0,0,0,0))\n",
" else: \n",
" labels.append((0,0,0,0)) \n",
" \n",
" labels = torch.as_tensor(labels, dtype=torch.float32)\n",
" image = torch.as_tensor(image, dtype=torch.float32)\n",
"\n",
" return (image, labels)\n",
"\n",
"\n",
"\n",
"criterion = torch.nn.MSELoss().cuda()\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)\n",
"\n",
"dataset = CustomDataset()\n",
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)\n",
"\n",
"model.cuda()\n",
"model.train()\n",
"\n",
"\n",
"for epoch in range(20): # loop over the dataset multiple times\n",
"\n",
" running_loss = 0.0\n",
"\n",
" for i, batch in enumerate(dataloader):\n",
"\n",
" img, labels = batch \n",
" \n",
" img = torch.as_tensor(img, dtype=torch.float32).cuda()\n",
"\n",
" labels = torch.as_tensor(labels, dtype=torch.float32).cuda() \n",
" labels = labels.flatten()\n",
" labels = torch.reshape(labels, (batch_size, 4*objects))\n",
" \n",
" optimizer.zero_grad()\n",
" out = model(img)\n",
" \n",
" \n",
" loss = criterion(out, labels) \n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # print statistics\n",
" running_loss += loss.item()\n",
" if i % 20 == 19: # print every 2000 mini-batches\n",
" print('[%d, %6d] loss: %.6f' %\n",
" (epoch + 1, i + 1, running_loss / 20))\n",
" running_loss = 0.0\n",
" "
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 298
},
"id": "ovpU5I04r3mP",
"outputId": "a070ef95-75eb-4b77-831f-a786fe5a7ea8"
},
"source": [
"# testing the trained model\n",
"\n",
"model = model.eval()\n",
"import time\n",
"\n",
"\n",
"\n",
"np.random.seed(np.random.randint(0, 99999))\n",
"\n",
"image = np.zeros((size,size, 3), np.uint8)\n",
"labels = []\n",
"model = model.cuda()\n",
"\n",
"point_x = int(np.random.rand() * size)\n",
"point_y = int(np.random.rand() * size)\n",
"size_x = int(np.random.rand() * size)\n",
"size_y = int(np.random.rand() * size) \n",
"r, g, b = int(np.random.rand()*255), int(np.random.rand()*255), int(np.random.rand() * 255) \n",
"\n",
"cv2.rectangle(image, (int(point_x - size_x/2), int(point_y - size_y/2)), (int(point_x + size_x/2), int(point_y + size_y/2)), (r, g, b), -1) \n",
"\n",
"print(\"input image:\")\n",
"cv2_imshow(image)\n",
"print(\"input box location (x, y), (w, h):\")\n",
"print((int(point_x - size_x/2), int(point_y - size_y/2)), (int(point_x + size_x/2), int(point_y + size_y/2)))\n",
"\n",
"image_tensor = torch.as_tensor(image, dtype=torch.float32)\n",
"\n",
"image_tensor = torch.unsqueeze(image_tensor, 0).cuda() \n",
"\n",
"image2 = np.ones((size,size, 3), np.uint8)\n",
"\n",
"t1 = time.time()\n",
"for i in range(1):\n",
" out = model(image_tensor)\n",
" out = out[0].detach().cpu().numpy()\n",
" \n",
"out_point_x = int(out[0] * size)\n",
"out_point_y = int(out[1] * size)\n",
"out_size_x = int(out[2] * size)\n",
"out_size_y = int(out[3] * size)\n",
"\n",
"cv2.rectangle(image2, (int(out_point_x - out_size_x/2), int(out_point_y - out_size_y/2)), (int(out_point_x + out_size_x/2), int(out_point_y + out_size_y/2)), (r, g, b), -1) \n",
"\n",
"print(\"-\"*50)\n",
"print(\"predicted output:\")\n",
"cv2_imshow(image2)\n",
"print(\"predicted box location (x, y), (w, h):\")\n",
"print((int(out_point_x - out_size_x/2), int(out_point_y - out_size_y/2)), (int(out_point_x + out_size_x/2), int(out_point_y + out_size_y/2)))\n",
"print(\"-\"*50)\n",
"print(f'inference time was: {(time.time() - t1) * 1000} ms')\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"input image:\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAIAAAAlC+aJAAAAWUlEQVR4nO3PAQmAQAAEQTWNKSxgXguYwjiGkP9FmAlwxy4LAAAAAAC/s+7XPefpOY8Rs9uI0ZkE1ATUBNQE1ATUBNQE1ATUBNQE1ATUBNQE1ATUBNQE8NELYUwDRMSJFboAAAAASUVORK5CYII=\n",
"text/plain": [
"<PIL.Image.Image image mode=RGB size=64x64 at 0x7F9A1C5A8E50>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"input box location (x, y), (w, h):\n",
"(-18, 29) (34, 62)\n",
"--------------------------------------------------\n",
"predicted output:\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAIAAAAlC+aJAAAAdElEQVR4nO3RAQ2EAAwAMYYaXGAAuxh4F+8GERAakquAbZfNzCxftuoD7ipAK0ArQCtAK0ArQCtAK0ArQCtAK0ArQCtAK0ArQJvt/L227H/sj8/8/AcK0ArQCtAK0ArQCtAK0ArQCtAK0ArQCtAK0ArQCtAuyfADhXI8ZDMAAAAASUVORK5CYII=\n",
"text/plain": [
"<PIL.Image.Image image mode=RGB size=64x64 at 0x7F9A1C5A8E50>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"predicted box location (x, y), (w, h):\n",
"(-11, 32) (35, 70)\n",
"--------------------------------------------------\n",
"inference time was: 11.744499206542969 ms\n"
],
"name": "stdout"
}
]
}
]
}