
Releases: vballoli/vit-flax

0.0.3 release v2

12 Oct 00:55
b33a178

Vision Transformer in JAX/Flax

This repository implements the Vision Transformer (ViT) in Flax, introduced in an ICLR 2021 paper submission, with further explanation by Yannic Kilcher. This repository is heavily inspired by lucidrain's implementation.

Install

pip install vit-flax

Usage

import jax
from jax import numpy as jnp
from flax import nn
from vit_flax import ViT

# Build the module with fixed hyperparameters (pre-Linen Flax API).
rng = jax.random.PRNGKey(0)
module = ViT.partial(patch_size=32, dim=1024, depth=6, num_heads=8, dense_dims=(2048, 2048), img_size=256, num_classes=10)

# Initialize parameters from the expected input shape (batch, height, width, channels).
_, initial_params = module.init_by_shape(
  rng, [((1, 256, 256, 3), jnp.float32)]
)
model = nn.Model(module, initial_params)

# Forward pass on a random image batch; returns one row of class scores per image.
img = jax.random.uniform(rng, (1, 256, 256, 3))
output = model(img)
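
The call returns one row of raw class scores per image. A minimal sketch of turning them into predictions, assuming the classification head emits unnormalized logits over num_classes classes:

probs = jax.nn.softmax(output, axis=-1)  # class probabilities, shape (1, 10)
pred = jnp.argmax(probs, axis=-1)        # predicted class index per image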

The examples directory contains code to train ViT on the CIFAR datasets.
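
For context, a rough sketch of what a single training step could look like with the pre-Linen flax.optim API used above; the loss helper, optimizer choice, and batch layout are illustrative assumptions rather than code from the repository:

from flax import optim

def cross_entropy(logits, labels, num_classes=10):
  # Mean cross-entropy over the batch; labels are integer class ids.
  one_hot = jax.nn.one_hot(labels, num_classes)
  return -jnp.mean(jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))

optimizer = optim.Adam(learning_rate=1e-3).create(model)

@jax.jit
def train_step(optimizer, images, labels):
  def loss_fn(model):
    logits = model(images)
    return cross_entropy(logits, labels)
  loss, grads = jax.value_and_grad(loss_fn)(optimizer.target)
  return optimizer.apply_gradient(grads), loss

In this API, optimizer.target holds the current nn.Model, so gradients are taken with respect to the model itself and apply_gradient returns an updated optimizer.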

Docs and references

Documentation for all the modules can be viewed here.

Note

This repository is still in its early stages. Feel free to contact me or raise issues/PRs with suggestions, improvements, or bug reports.

Help needed

A recent commit introduces code for training CIFAR models in the examples directory. If you're using this code and have the resources to run it, I'd be happy to include your training reports here and give appropriate credit.

0.0.3 release

12 Oct 00:51
2159e89

Vision Transformer in JAX/Flax

This repository implements the Vision Transformer (ViT) in Flax, introduced in an ICLR 2021 paper submission, with further explanation by Yannic Kilcher. This repository is heavily inspired by lucidrain's implementation.

Install

pip install vit-flax

Usage

import jax
from jax import numpy as jnp
from flax import nn
from vit_flax import ViT

rng = jax.random.PRNGKey(0)
module = ViT.partial(patch_size=32, dim=1024, depth=6, num_heads=8, dense_dims=(2048, 2048), img_size=256, num_classes=10)
_, initial_params = module.init_by_shape(
  rng, [((1, 256, 256, 3), jnp.float32)]
)
model = nn.Model(module, initial_params)

img = jax.random.uniform(rng, (1,256,256,3))
output = model(img)

The examples directory contains code to train ViT on the CIFAR datasets.

Docs and references

Documentation for all the modules can be viewed here.

Note

This repository is still in its early stages. Feel free to contact me or raise issues/PRs with suggestions, improvements, or bug reports.

Help needed

A recent commit introduces code for training CIFAR models in the examples directory. If you're using this code and have the resources to run it, I'd be happy to include your training reports here and give appropriate credit.

Initial release

05 Oct 16:43
5459e26

Vision Transformer in Flax

This repository implements the Vision Transformer (ViT) in Flax, introduced in an ICLR 2021 paper submission, with further explanation by Yannic Kilcher. This repository is heavily inspired by lucidrain's implementation.

Install

pip install vit-flax

Usage

import jax
from jax import numpy as jnp
from flax import nn
from vit_flax import ViT

rng = jax.random.PRNGKey(0)
module = ViT.partial(patch_size=32, dim=1024, depth=6, num_heads=8, dense_dims=(2048, 2048), img_size=256, num_classes=10)
_, initial_params = module.init_by_shape(
  rng, [((1, 256, 256, 3), jnp.float32)]
)
model = nn.Model(module, initial_params)

img = jax.random.uniform(rng, (1,256,256,3))
output = model(img)
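
For repeated inference, the forward pass can be wrapped in jax.jit; a small optional sketch, not part of the release itself:

apply_fn = jax.jit(lambda model, x: model(x))
output = apply_fn(model, img)  # same result as model(img), compiled with XLA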

Note: This repository is still in its early stages. Feel free to contact me or raise issues/PRs with suggestions, improvements, or bug reports.