- 03/07/2023: Project page built
- 11/07/2023: PyTorch version uploaded
Recent complicated problems require large-scale datasets and complex model architectures; however, training such large networks is difficult due to their high computational cost. Significant efforts have been made to make training more efficient, such as momentum, learning rate scheduling, weight regularization, and meta-learning. Based on our observations of 1) the high correlation between past and future weights, 2) the conditions under which weight prediction is beneficial, and 3) the feasibility of weight prediction, we propose a more general framework that intermittently skips a handful of epochs by periodically forecasting near-future weights, i.e., a Weight Nowcaster Network (WNN). As an add-on module, WNN predicts future weights to make the learning process faster, regardless of task and architecture. Experimental results show that WNN can significantly reduce the actual training time, with only a marginal additional cost for training WNN itself. We validate the generalization capability of WNN on various tasks and demonstrate that it works well even on unseen tasks.
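In our own notation (a sketch based on the description above, not the paper's exact formulation; $h$ is the number of past weight snapshots fed to WNN and $k$ is the forecast horizon), the nowcasting step can be read as

```math
\hat{w}_{t+k} \approx w_t + f_\theta\left(w_{t-h+1}, \ldots, w_t\right),
```

where $f_\theta$ is WNN; training then resumes from $\hat{w}_{t+k}$, so epochs $t+1, \ldots, t+k$ are skipped.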
WNN is composed of a simple two-stream network built from fully-connected layers and an activation network. The feature vectors from the two streams are merged into a single feature vector, which is passed through a fully-connected layer. The predicted future weight parameters are obtained by adding this output to the input weight parameters.
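For illustration only, here is a minimal Keras sketch of such a two-stream design. The layer widths, the history length, and the concatenation used to unify the two streams are our assumptions, not the exact architecture shipped in this repo:

```python
import tensorflow as tf

def build_wnn_sketch(history_len=5, hidden_dim=32):
    """Toy two-stream nowcaster: each stream is a small fully-connected
    network over a single weight's past values; the two stream features
    are unified and mapped to a residual update of the latest weight."""
    past_weights = tf.keras.Input(shape=(history_len,), name="past_weights")

    # Stream 1: fully-connected layers on the raw weight history.
    s1 = tf.keras.layers.Dense(hidden_dim, activation="relu")(past_weights)
    s1 = tf.keras.layers.Dense(hidden_dim, activation="relu")(s1)

    # Stream 2: a second small network (the "activation network" above);
    # here it simply processes the same history with a different nonlinearity.
    s2 = tf.keras.layers.Dense(hidden_dim, activation="tanh")(past_weights)
    s2 = tf.keras.layers.Dense(hidden_dim, activation="tanh")(s2)

    # Unify the two feature vectors and pass them through a final FC layer.
    merged = tf.keras.layers.Concatenate()([s1, s2])
    delta = tf.keras.layers.Dense(1, name="predicted_delta")(merged)

    # Residual connection: predicted future weight = latest input weight + delta.
    last_weight = tf.keras.layers.Lambda(lambda x: x[:, -1:])(past_weights)
    future_weight = tf.keras.layers.Add(name="predicted_weight")([last_weight, delta])

    return tf.keras.Model(past_weights, future_weight)
```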
| Library | Known Working | Known Not Working |
| --- | --- | --- |
| tensorflow | 2.3.0, 2.9.0 | <= 2.0 |
We provide simple plug-in source code that can be added to your training script via a callback extending `tf.keras.callbacks.Callback`:
```python
import tensorflow as tf
import tensorflow.keras
from WNN import *

# ...

model.fit(..., callbacks=[WeightForecasting()])

# ...
```
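For reference, below is a rough sketch of how such a callback could be structured. The actual `WeightForecasting` class in `WNN.py` may differ; the `forecaster` callable, `history_len`, and `period` arguments here are placeholders of ours:

```python
import tensorflow as tf


class WeightForecastingSketch(tf.keras.callbacks.Callback):
    """Illustrative only: records recent weights and, every `period` epochs,
    replaces the model weights with a forecast of their near-future values."""

    def __init__(self, forecaster, history_len=5, period=10):
        super().__init__()
        self.forecaster = forecaster      # callable: list of weight snapshots -> forecast
        self.history_len = history_len
        self.period = period
        self.history = []

    def on_epoch_end(self, epoch, logs=None):
        # Keep a sliding window of the most recent weight snapshots.
        self.history.append(self.model.get_weights())
        self.history = self.history[-self.history_len:]

        # Periodically jump ahead to the forecasted near-future weights.
        if (epoch + 1) % self.period == 0 and len(self.history) == self.history_len:
            self.model.set_weights(self.forecaster(self.history))
```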
Pre-trained weights of WNN are included. The 'NWNN_XXX_13.h5' files in this repo are the pre-trained weights for each mathematical operation type (Conv, FC, Bias).
Training without WNN on CIFAR10:

```bash
python CIFAR10_without_WNN.py
```

Training with WNN on CIFAR10:

```bash
python CIFAR10_with_WNN.py
```
We also provide a PyTorch implementation of WNN through this link: PyTorch Implementation Link
We would like to thank Anthony Rigoli at the University of Strasbourg for implementing the PyTorch source code and sharing it.
```bibtex
@inproceedings{jang2023learning,
  title={Learning to Boost Training by Periodic Nowcasting Near Future Weights},
  author={Jang, Jinhyeok and Yun, Woo-han and Kim, Won Hwa and Yoon, Youngwoo and Kim, Jaehong and Lee, Jaeyeon and Han, ByungOk},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2023}
}
```