Classification models

These models are based on Flax's ImageNet classification example.

Requirements

  • GPU backend
  • TPU backend
    • Google Colab notebook
    • Install dependencies
      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      pip install flax
      pip install ml_collections clu
      pip install tensorflow tensorflow_datasets tensorboard
      pip install tf-models-official
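
After installation, it can help to confirm which backend JAX actually picked up before starting a long training run. The snippet below is a minimal sketch, not part of this repository; it only uses the standard `jax.devices()`, `jax.device_count()`, and `jax.default_backend()` calls. On a TPU v2-8 one would expect eight devices to be listed, while a machine without an accelerator falls back to a single CPU device.

```python
import jax

# List every device JAX can see on the active backend.
devices = jax.devices()

# Name of the platform JAX selected by default (for example CPU, GPU, or TPU).
backend = jax.default_backend()

print(f"backend: {backend}")
print(f"device count: {jax.device_count()}")
for device in devices:
    print(device)
```

If the reported backend is not the one you intended, the install step above is the first thing to recheck.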
      

Preparation

git clone --depth 1 https://github.com/NobuoTsukamoto/jax_examples.git
cd jax_examples
export PYTHONPATH=`pwd`/common:$PYTHONPATH
cd classification/

Train

Running locally

python main.py \
    --task train \
    --config configs/_CONFIG_FILE_.py \
    --workdir /full/path/to/workdir

Models

imagenet2012

| Model | Backend | Config | Top-1 accuracy | Epochs | Note |
|---|---|---|---|---|---|
| MobileNet v2 | TPU v2-8 | config | 71.76 % | 500 | |
| ResNet50 | TPU v2-8 | config | 76.3 % | 100 | |
| ResNet50 with training techniques (ConvNeXt training techniques) | TPU v2-8 | config | 77.96 % | 300 | Override config: --config.batch_size=1024 --config.gradient_accumulation_steps=4 |
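
The override flags in the last row combine a batch size with gradient accumulation. The sketch below illustrates the general idea in plain Python and is not this repository's training loop. It assumes that `batch_size` is the per-step batch and that gradients are averaged over the accumulation steps before each optimizer update; under that assumption one update sees `batch_size * gradient_accumulation_steps` samples. The toy data, the `grad_mse` helper, and the 1-D model are hypothetical and exist only to show that averaging equally sized micro-batch gradients reproduces the full-batch gradient of a mean loss.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# Values taken from the override flags in the table above.
batch_size = 1024
gradient_accumulation_steps = 4

# Assumption: batch_size is the per-step batch, so one optimizer update sees:
effective_batch_size = batch_size * gradient_accumulation_steps

# Hypothetical toy data: 8 samples split into 4 micro-batches of 2.
xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.0, 2.1, 2.9, 4.2, 5.1, 5.8, 7.2, 8.1]
w = 0.3
micro = len(xs) // gradient_accumulation_steps

# Accumulate one gradient per micro-batch, then average before "updating".
accumulated = 0.0
for i in range(gradient_accumulation_steps):
    start, stop = i * micro, (i + 1) * micro
    accumulated += grad_mse(w, xs[start:stop], ys[start:stop])
accumulated /= gradient_accumulation_steps

# Gradient computed over all samples at once, for comparison.
full = grad_mse(w, xs, ys)

print(effective_batch_size)
print(abs(accumulated - full) < 1e-9)
```

The equivalence holds here because every micro-batch has the same size; with unequal micro-batches a weighted average would be needed.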

Model summary

python main.py --task summarize --config configs/_MODEL_.py