diff --git a/examples/perceiver_detect_box.ipynb b/examples/perceiver_detect_box.ipynb
new file mode 100644
index 0000000..c9a0cff
--- /dev/null
+++ b/examples/perceiver_detect_box.ipynb
@@ -0,0 +1,288 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "perceiver detect box.ipynb",
+      "provenance": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "1EmMgQSSra9M"
+      },
+      "source": [
+        "!pip install perceiver-pytorch\n"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "YNCh_-_1rhPX"
+      },
+      "source": [
+        "import torch\n",
+        "import cv2\n",
+        "import numpy as np\n",
+        "import torch.optim as optim\n",
+        "import torchvision\n",
+        "\n",
+        "from google.colab.patches import cv2_imshow\n",
+        "from perceiver_pytorch import Perceiver\n",
+        "from torch.utils.data import Dataset, DataLoader\n",
+        "\n",
+        "\n",
+        "##################\n",
+        "batch_size = 128\n",
+        "size = 64       # image width and height, in pixels\n",
+        "objects = 1     # number of boxes per image\n",
+        "##################"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "Kn5CTHlhrmdM"
+      },
+      "source": [
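+        "# a deliberately small Perceiver configuration, so it trains quickly on a Colab GPU;\n",
+        "# the classification head is repurposed as a regressor emitting 4 box values per object\n",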
+        "model = Perceiver(\n",
+        "    input_channels = 3,        # number of channels for each token of the input\n",
+        "    input_axis = 2,            # number of axes for input data (2 for images, 3 for video)\n",
+        "    num_freq_bands = 4,        # number of freq bands for the position encoding; K bands yield (2 * K + 1) features per axis\n",
+        "    max_freq = 5.,             # maximum frequency, a hyperparameter depending on how fine the data is\n",
+        "    depth = 2,                 # depth of the net\n",
+        "    num_latents = 32,          # number of latents (also called induced set points or centroids in other papers)\n",
+        "    cross_dim = 32,            # cross attention dimension\n",
+        "    latent_dim = 32,           # latent dimension\n",
+        "    cross_heads = 1,           # number of heads for cross attention; the paper uses 1\n",
+        "    latent_heads = 4,          # number of heads for latent self attention; the paper uses 8\n",
+        "    cross_dim_head = 32,\n",
+        "    latent_dim_head = 32,\n",
+        "    num_classes = 4*objects,   # number of outputs: 4 box values (x, y, w, h) per object\n",
+        "    attn_dropout = 0.5,\n",
+        "    ff_dropout = 0.5,\n",
+        "    weight_tie_layers = True   # whether to weight-tie the layers (optional, as indicated in the paper's diagram)\n",
+        ")"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "c2WaA0Vhrqhv"
+      },
+      "source": [
+        "# training the model on a very easy object detection problem:\n",
+        "# locate a single randomly placed, randomly sized, randomly colored rectangle\n",
+        "class CustomDataset(Dataset):\n",
+        "    def __init__(self):\n",
+        "        pass\n",
+        "\n",
+        "    def __len__(self):\n",
+        "        return 320000\n",
+        "\n",
+        "    def __getitem__(self, idx):\n",
+        "        image = np.zeros((size, size, 3), np.uint8)\n",
+        "        labels = []\n",
+        "\n",
+        "        for i in range(objects):\n",
+        "            # with 2% probability leave the image empty, so the model also sees the background case\n",
+        "            if np.random.rand() > 0.02:\n",
+        "                point_x = int(np.random.rand() * size)\n",
+        "                point_y = int(np.random.rand() * size)\n",
+        "                size_x = int(np.random.rand() * size)\n",
+        "                size_y = int(np.random.rand() * size)\n",
+        "                r, g, b = int(np.random.rand() * 255), int(np.random.rand() * 255), int(np.random.rand() * 255)\n",
+        "                try:\n",
+        "                    cv2.rectangle(image, (int(point_x - size_x/2), int(point_y - size_y/2)), (int(point_x + size_x/2), int(point_y + size_y/2)), (r, g, b), -1)\n",
+        "                    # labels are the box center and size, normalized to [0, 1]\n",
+        "                    labels.append((point_x/size, point_y/size, size_x/size, size_y/size))\n",
+        "                except Exception as e:\n",
+        "                    print(e)\n",
+        "                    labels.append((0, 0, 0, 0))\n",
+        "            else:\n",
+        "                labels.append((0, 0, 0, 0))\n",
+        "\n",
+        "        labels = torch.as_tensor(labels, dtype=torch.float32)\n",
+        "        image = torch.as_tensor(image, dtype=torch.float32)\n",
+        "\n",
+        "        return (image, labels)\n",
+        "\n",
+        "\n",
+        "criterion = torch.nn.MSELoss().cuda()\n",
+        "optimizer = optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)\n",
+        "\n",
+        "dataset = CustomDataset()\n",
+        "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)\n",
+        "\n",
+        "model.cuda()\n",
+        "model.train()\n",
+        "\n",
+        "for epoch in range(20):  # loop over the dataset multiple times\n",
+        "\n",
+        "    running_loss = 0.0\n",
+        "\n",
+        "    for i, batch in enumerate(dataloader):\n",
+        "        img, labels = batch\n",
+        "        img = img.cuda()\n",
+        "\n",
+        "        # flatten the (batch, objects, 4) labels to (batch, 4*objects) to match the model output\n",
+        "        labels = labels.cuda().reshape(-1, 4*objects)\n",
+        "\n",
+        "        optimizer.zero_grad()\n",
+        "        out = model(img)\n",
+        "\n",
+        "        loss = criterion(out, labels)\n",
+        "        loss.backward()\n",
+        "        optimizer.step()\n",
+        "\n",
+        "        # print statistics\n",
+        "        running_loss += loss.item()\n",
+        "        if i % 20 == 19:  # print every 20 mini-batches\n",
+        "            print('[%d, %6d] loss: %.6f' %\n",
+        "                  (epoch + 1, i + 1, running_loss / 20))\n",
+        "            running_loss = 0.0\n"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 298
+        },
+        "id": "ovpU5I04r3mP",
+        "outputId": "a070ef95-75eb-4b77-831f-a786fe5a7ea8"
+      },
+      "source": [
+        "# testing the trained model\n",
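+        "# draw one random box, run a single forward pass, then draw the predicted\n",
+        "# box on a second canvas so the two images can be compared side by side\n",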
+        "import time\n",
+        "\n",
+        "model = model.eval()\n",
+        "model = model.cuda()\n",
+        "\n",
+        "np.random.seed(np.random.randint(0, 99999))\n",
+        "\n",
+        "image = np.zeros((size, size, 3), np.uint8)\n",
+        "\n",
+        "point_x = int(np.random.rand() * size)\n",
+        "point_y = int(np.random.rand() * size)\n",
+        "size_x = int(np.random.rand() * size)\n",
+        "size_y = int(np.random.rand() * size)\n",
+        "r, g, b = int(np.random.rand() * 255), int(np.random.rand() * 255), int(np.random.rand() * 255)\n",
+        "\n",
+        "cv2.rectangle(image, (int(point_x - size_x/2), int(point_y - size_y/2)), (int(point_x + size_x/2), int(point_y + size_y/2)), (r, g, b), -1)\n",
+        "\n",
+        "print(\"input image:\")\n",
+        "cv2_imshow(image)\n",
+        "print(\"input box corners (x1, y1), (x2, y2):\")\n",
+        "print((int(point_x - size_x/2), int(point_y - size_y/2)), (int(point_x + size_x/2), int(point_y + size_y/2)))\n",
+        "\n",
+        "image_tensor = torch.as_tensor(image, dtype=torch.float32)\n",
+        "image_tensor = torch.unsqueeze(image_tensor, 0).cuda()\n",
+        "\n",
+        "# a second, nearly black canvas (pixel value 1) for the predicted box\n",
+        "image2 = np.ones((size, size, 3), np.uint8)\n",
+        "\n",
+        "# time a single forward pass\n",
+        "t1 = time.time()\n",
+        "out = model(image_tensor)\n",
+        "out = out[0].detach().cpu().numpy()\n",
+        "\n",
+        "# the model outputs normalized (x, y, w, h); scale back to pixels\n",
+        "out_point_x = int(out[0] * size)\n",
+        "out_point_y = int(out[1] * size)\n",
+        "out_size_x = int(out[2] * size)\n",
+        "out_size_y = int(out[3] * size)\n",
+        "\n",
+        "cv2.rectangle(image2, (int(out_point_x - out_size_x/2), int(out_point_y - out_size_y/2)), (int(out_point_x + out_size_x/2), int(out_point_y + out_size_y/2)), (r, g, b), -1)\n",
+        "\n",
+        "print(\"-\"*50)\n",
+        "print(\"predicted output:\")\n",
+        "cv2_imshow(image2)\n",
+        "print(\"predicted box corners (x1, y1), (x2, y2):\")\n",
+        "print((int(out_point_x - out_size_x/2), int(out_point_y - out_size_y/2)), (int(out_point_x + out_size_x/2), int(out_point_y + out_size_y/2)))\n",
+        "print(\"-\"*50)\n",
+        "print(f'inference time was: {(time.time() - t1) * 1000} ms')\n"
+      ],
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "input image:\n"
+          ],
+          "name": "stdout"
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAIAAAAlC+aJAAAAWUlEQVR4nO3PAQmAQAAEQTWNKSxgXguYwjiGkP9FmAlwxy4LAAAAAAC/s+7XPefpOY8Rs9uI0ZkE1ATUBNQE1ATUBNQE1ATUBNQE1ATUBNQE1ATUBNQE8NELYUwDRMSJFboAAAAASUVORK5CYII=\n",
+            "text/plain": [
+              ""
+            ]
+          },
+          "metadata": {
+            "tags": []
+          }
+        },
+        {
+          "output_type": "stream",
+          "text": [
+            "input box corners (x1, y1), (x2, y2):\n",
+            "(-18, 29) (34, 62)\n",
+            "--------------------------------------------------\n",
+            "predicted output:\n"
+          ],
+          "name": "stdout"
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAIAAAAlC+aJAAAAdElEQVR4nO3RAQ2EAAwAMYYaXGAAuxh4F+8GERAakquAbZfNzCxftuoD7ipAK0ArQCtAK0ArQCtAK0ArQCtAK0ArQCtAK0ArQJvt/L227H/sj8/8/AcK0ArQCtAK0ArQCtAK0ArQCtAK0ArQCtAK0ArQCtAuyfADhXI8ZDMAAAAASUVORK5CYII=\n",
+            "text/plain": [
+              ""
+            ]
+          },
+          "metadata": {
+            "tags": []
+          }
+        },
+        {
+          "output_type": "stream",
+          "text": [
+            "predicted box corners (x1, y1), (x2, y2):\n",
+            "(-11, 32) (35, 70)\n",
+            "--------------------------------------------------\n",
+            "inference time was: 11.744499206542969 ms\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    }
+  ]
+}
\ No newline at end of file