Red velvet cake is my all-time favorite cake. It’s perfect for any holiday or celebration, and it always looks so elegant!
Research supported with Cloud TPUs from Google's TPU Research Cloud (TRC)
This repository contains all the code (training, data loading) for the prompting vision language model discussed in the master's thesis "Multitask Multilingual Vision Language: Image and Text Generation".
To broaden the language coverage, with Vietnamese as our main focus, we have translated 6 common English datasets for image captioning and visual question answering into Vietnamese using Google Translate: CC3M, COCO, VQAv2, OK-VQA, TextCaps, and TextVQA. Collectively, these datasets contain millions of image-text pairs in each language (English and Vietnamese). They are available at huggingface.co/dinhanhx.
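For a quick look at the translated data, something like the following may work; this is only a sketch, assuming the dataset repositories can be resolved directly by the Hugging Face datasets library (if not, clone them with git-lfs as shown in the data setup below):
from datasets import load_dataset

# Hypothetical quick look at one of the translated datasets.
# If load_dataset cannot resolve the repository layout automatically,
# clone the repositories with git-lfs as described in the data setup below.
dataset = load_dataset("dinhanhx/coco-2017-vi", split="train")
print(dataset[0])  # field names depend on the individual dataset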
We have proposed a prompting vision language model (this repository), which is trained on these datasets. The model can caption images and answer questions about images. It is modular and has 3 components: a frozen image model (ConvNeXt V2), a frozen text model (BLOOMZ), and a Cutie model. The Cutie model queries useful information from the visual features (produced by the image model) for the text model. Our concept is heavily inspired by the BLIP-2 architecture; the docstring in velvet/model/cutie.py describes the differences.
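For orientation, here is a minimal, hypothetical sketch of the information flow; the class, method, and dimension choices below are illustrative only and do not match the actual API in velvet/model/:
import torch
import torch.nn as nn


class CutieSketch(nn.Module):
    """Illustrative Q-Former-like module: learned queries cross-attend to
    frozen image features and emit a short visual prefix for the text model.
    This is a simplified stand-in, not the actual Cutie implementation."""

    def __init__(self, num_queries=32, vision_dim=1024, text_dim=1024, num_heads=16):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, vision_dim))
        self.cross_attention = nn.MultiheadAttention(vision_dim, num_heads, batch_first=True)
        self.to_text_space = nn.Linear(vision_dim, text_dim)

    def forward(self, image_features):  # image_features: (batch, num_patches, vision_dim)
        batch_size = image_features.size(0)
        queries = self.queries.unsqueeze(0).expand(batch_size, -1, -1)
        attended, _ = self.cross_attention(queries, image_features, image_features)
        return self.to_text_space(attended)  # (batch, num_queries, text_dim)


# Conceptual flow (the frozen parts are never updated during training):
#   image -> frozen ConvNeXt V2 -> patch features
#   patch features -> Cutie (trainable) -> visual prefix embeddings
#   [visual prefix ; instruction tokens] -> frozen BLOOMZ -> generated text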
If you use this source code, the model weights, or the underlying ideas, please cite it as below.
@software{dinhanhx_Velvet_2023,
    title = {{Velvet}},
    author = {dinhanhx},
    year = 2023,
    month = sep,
    url = {https://github.com/dinhanhx/velvet},
    license = {MIT}
}
It's unfortunately not good. Our largest model achieves CIDEr of Crossmodal-3600 of 0.435 in English and 0.318 in Vietnamese. It also gets 0.3404 F1 and 0.1135 BLEU on the private test of EVJVQA.
See this pip requirements file to set up on TPU v3-8 or v4-8.
⚠ Please take a look at the file configs/data_dir.toml first.
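The keys inside configs/data_dir.toml are specific to this repository, so check the file itself; purely as an illustration, a TOML config can be read from Python 3.11+ like this (older versions can use the tomli package instead):
import tomllib  # Python 3.11+; on older versions: `import tomli as tomllib`

with open("configs/data_dir.toml", "rb") as f:
    data_dir_config = tomllib.load(f)

# Key names depend on the file itself; this just prints whatever it declares.
for key, value in data_dir_config.items():
    print(key, "->", value)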
Given that the folder to download data into is /mnt/storage/data/, run the following commands in this folder,
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/datasets/dinhanhx/evjvqa
git clone https://huggingface.co/datasets/dinhanhx/TextVQA-vi
git clone https://huggingface.co/datasets/dinhanhx/TextCaps-vi
git clone https://huggingface.co/datasets/dinhanhx/OK-VQA-multilang
git clone https://huggingface.co/datasets/dinhanhx/VQAv2-vi
git clone https://huggingface.co/datasets/dinhanhx/cc_sbu_align_multilang
git clone https://huggingface.co/datasets/dinhanhx/coco-2017-vi
git clone https://huggingface.co/datasets/dinhanhx/gcc-vi
git clone https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K
Now set up LLaVA-CC3M-Pretrain-595K,
cd LLaVA-CC3M-Pretrain-595K/
mkdir images
unzip -q images.zip -d images
Now set up coco-2017-images,
mkdir coco-2017-images
cd coco-2017-images
curl -JO http://images.cocodataset.org/zips/train2017.zip
unzip -q train2017.zip
Now set up OpenImages,
mkdir OpenImages
cd OpenImages
curl -JO https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
unzip -q train_val_images.zip
Now run some Python scripts to make the data usable by velvet. Don't forget to go back to the root of this repository first,
python scripts/<do_something>.py
where <do_something>.py is a file in scripts/.
⚠ Please read the files (including scripts/readme.md) before running, especially things related to paths.
⚠ Our training code can be used on GPUs too; it is not TPU-exclusive. Please read the pl.Trainer setup in velvet_exp/run_all.py carefully before running on any hardware.
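As a rough, hypothetical illustration of the kind of pl.Trainer arguments to check (the authoritative settings, including precision, strategy, logging, and checkpointing, live in velvet_exp/run_all.py):
import pytorch_lightning as pl

# Illustrative only: pick the accelerator and device count that match your hardware.
trainer = pl.Trainer(
    accelerator="tpu",  # or "gpu" / "cpu"
    devices=8,          # e.g. 8 TPU cores on a v3-8, or your number of GPUs
    max_epochs=1,
)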
For TPU v3-8,
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
For TPU v4-8,
export PJRT_DEVICE=TPU
# Replace <name> with a config name
python -m velvet_exp.run_all configs/experiments/<name>.json
# Kill leftover processes
pgrep -f "python -m velvet_exp.run_all" | xargs kill -9
During and after training, PyTorch Lightning produces a ckpt file (which can be opened by Torch). It contains everything, including the model weights. To extract only the model weights, please refer to the following code (which is similar to velvet_exp/run_all.py).
Click me
import json
import torch
from transformers.models.bert import BertConfig
from transformers.models.bloom import BloomConfig
from transformers.models.convnextv2.modeling_convnextv2 import ConvNextV2Config
from velvet_exp.run_all import Wrapper
# Load the experiment config that was used for training
experiment_config = json.load(open("configs/experiments/big_model.json", "r"))

# Image encoder config; fall back to the default ConvNeXt V2 checkpoint
image_model_name = experiment_config.get("image_model_name", None)
if image_model_name is None:
    image_model_name = "facebook/convnextv2-base-22k-224"
image_config = ConvNextV2Config.from_pretrained(image_model_name)

# Text model config; fall back to the default BLOOMZ checkpoint
bloom_model_name = experiment_config.get("bloom_model_name", None)
if bloom_model_name is None:
    bloom_model_name = "bigscience/bloomz-560m"
bloom_config = BloomConfig.from_pretrained(bloom_model_name)

# Config for the Cutie module (a BERT-style decoder with cross-attention)
bert_config = BertConfig(
    hidden_size=1024,
    num_hidden_layers=6,
    num_attention_heads=16,
    is_decoder=True,
    add_cross_attention=True,
)

# Rebuild the LightningModule wrapper with the same configs as training
wrapper = Wrapper(
    experiment_config=experiment_config,
    image_config=image_config,  # type: ignore
    bert_config=bert_config,
    bloom_config=bloom_config,  # type: ignore
    bloom_name=bloom_model_name,
    learning_rate=experiment_config["learning_rate"],
    warmup_ratio=experiment_config["warmup_ratio"],
    use_lrs=experiment_config["use_learning_rate_scheduler"],
    warmup_steps=experiment_config.get("warmup_steps", None),
)

# Load the full Lightning checkpoint, then save only the VisualBloom weights
checkpoint = torch.load("big_model_logs/lightning_logs/version_0/checkpoints/last.ckpt")
wrapper.load_state_dict(checkpoint["state_dict"])
torch.save(wrapper.visual_bloom.state_dict(), "big_model_logs/lightning_logs/version_0/checkpoints/visual_bloom.torch")
⚠ Please make sure that you have done the extraction or you have obtained the model weights here.
The following code doesn't use a GPU or TPU. To change the image, please look at the variable url. To change the prompt, please read the last for loop.
Click me
import requests
import torch
from PIL import Image
from transformers.models.bert import BertConfig
from transformers.models.bloom import BloomConfig, BloomTokenizerFast
from transformers.models.convnext import ConvNextImageProcessor
from transformers.models.convnextv2.modeling_convnextv2 import (
    ConvNextV2Config,
    ConvNextV2Model,
)
from velvet.collator import ImageFeatureCollator
from velvet.model import VisualBloom
image_model_name = "facebook/convnextv2-large-22k-224"
image_config = ConvNextV2Config.from_pretrained(image_model_name)
image_processor = ConvNextImageProcessor.from_pretrained(image_model_name)
image_model = ConvNextV2Model.from_pretrained(image_model_name)
image_feature_collator = ImageFeatureCollator(image_processor, image_model)
bloom_model_name = "bigscience/bloomz-1b7"
bloom_config = BloomConfig.from_pretrained(bloom_model_name)
tokenizer = BloomTokenizerFast.from_pretrained(bloom_model_name)
tokenizer.padding_side = "right"
bert_config = BertConfig(
hidden_size=1024,
num_hidden_layers=6,
num_attention_heads=16,
is_decoder=True,
add_cross_attention=True,
)
visual_bloom = VisualBloom(
image_config, bert_config, bloom_config, bloom_model_name, use_frozen_bloom=False
)
visual_bloom.load_state_dict(
torch.load("big_model_logs/lightning_logs/version_0/checkpoints/visual_bloom.torch")
)
visual_bloom = visual_bloom.eval()
url = "https://i.imgur.com/Y2vIAJp.jpg"
language_list = ["en", "vi"]
for language in language_list:
    # To ask a question instead of captioning, use an instruction like:
    # instruction = (
    #     f"Generate answer in {language}: what is the color of the sky?"
    #     if language == "en"
    #     else f"Generate answer in {language}: màu sắc của bầu trời là gì?"
    # )
    instruction = f"Generate caption in {language}:"
    print(instruction)
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    image_features, image_attentions = image_feature_collator([image])
    instruction_inputs = tokenizer([instruction], return_tensors="pt")
    language_output = visual_bloom.generate(
        image_features,
        image_attentions,
        instruction_inputs["input_ids"],
        instruction_inputs["attention_mask"],
    )
    human_output = tokenizer.decode(language_output[0], skip_special_tokens=True)
    print(human_output.split(".")[0])