It is based on Flax's ImageNet classification example.
- GPU backend
- TPU backend
- Google Colab notebook
- Install dependencies
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install flax
pip install ml_collections clu
pip install tensorflow tensorflow_datasets tensorboard
pip install tf-models-official
git clone --depth 1 https://github.com/NobuoTsukamoto/jax_examples.git
cd jax_examples
export PYTHONPATH=`pwd`/common:$PYTHONPATH
cd classification/
python main.py \
--task train \
--config configs/_CONFIG_FILE_.py \
--workdir _FULL_PATH_TO_WORKDIR_
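The `--config` flag takes a Python file. Assuming the config files follow the usual ml_collections pattern (a `get_config()` function whose fields can be overridden on the command line with `--config.<field>=<value>`), the override behaviour can be sketched with the standard library alone. The default values and the `apply_overrides` helper below are hypothetical and only illustrate the idea; this is not ml_collections' implementation and it only handles flat (non-nested) fields.

```python
# Stdlib-only illustration of dotted command-line overrides such as
# --config.batch_size=1024. This is NOT how ml_collections implements it.
import ast


def get_config():
    """Hypothetical stand-in for a config file's get_config() function."""
    return {
        "batch_size": 256,                 # hypothetical default
        "gradient_accumulation_steps": 1,  # hypothetical default
    }


def apply_overrides(config, argv):
    """Apply arguments of the form --config.key=value to a copy of config."""
    updated = dict(config)
    prefix = "--config."
    for arg in argv:
        if not arg.startswith(prefix):
            continue  # leave other flags (e.g. --task, --workdir) alone
        key, _, raw = arg[len(prefix):].partition("=")
        if key not in updated:
            raise KeyError(f"unknown config field: {key}")
        try:
            value = ast.literal_eval(raw)  # parse numbers, booleans, etc.
        except (ValueError, SyntaxError):
            value = raw                    # fall back to a plain string
        updated[key] = value
    return updated


config = apply_overrides(
    get_config(),
    ["--task", "train",
     "--config.batch_size=1024",
     "--config.gradient_accumulation_steps=4"],
)
print(config)
```

Rejecting unknown field names, as the sketch does, is a deliberate choice: a mistyped override fails loudly instead of being silently ignored.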
imagenet2012
Model | Backend | Config | Top-1 accuracy | Epochs | Note |
---|---|---|---|---|---|
MobileNet v2 | TPU v2-8 | config | 71.76 % | 500 | |
ResNet50 | TPU v2-8 | config | 76.3 % | 100 | |
ResNet50 Training techniques (ConvNeXt training techniques) | TPU v2-8 | config | 77.96 % | 300 | Override config: `--config.batch_size=1024 --config.gradient_accumulation_steps=4` |
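The last row uses `gradient_accumulation_steps`. As a general illustration of the technique (not this repository's implementation, and making no claim about how it combines that value with `batch_size`), the sketch below shows that averaging gradients over equally sized micro-batches gives the same gradient as one large batch when the loss is a mean. The scalar model and the function names are hypothetical.

```python
# Illustration only: accumulating gradients over equally sized micro-batches
# reproduces the gradient of one large batch when the loss is a mean.


def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, steps):
    """Average the gradients of `steps` equally sized micro-batches."""
    size = len(xs) // steps
    total = 0.0
    for i in range(steps):
        start, stop = i * size, (i + 1) * size
        total += grad_mse(w, xs[start:stop], ys[start:stop])
    return total / steps


xs = [float(i) for i in range(8)]
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5

full = grad_mse(w, xs, ys)                          # one batch of 8 samples
accumulated = accumulated_grad(w, xs, ys, steps=4)  # 4 micro-batches of 2
print(full, accumulated)
```

This is why accumulation is commonly used to train with a larger batch than fits in device memory at once: the optimizer update is applied once per group of micro-batches.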
python main.py --task summarize --config configs/_MODEL_.py