Convert a PyTorch model and train it in JavaScript using ONNX Runtime Web.
Try it yourself at https://juharris.github.io/train-pytorch-in-js.
Steps:
- Define your PyTorch model. You probably already did this.
- Use the new utility method to export an ONNX gradient graph for the model.
- Set up an optimizer graph.
- Load the graphs in JavaScript (this project uses TypeScript).
- Use the graphs to train the model.
Details:
Define your PyTorch model:
You probably already did this. Here's our simple example:
```python
import torch


class MyModel(torch.nn.Module):
    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 num_classes: int):
        super(MyModel, self).__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out
```
You can train it in Python to get some good initial weights, but that's not required to export it and then train it in JavaScript.
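For intuition about what the exported graph computes, here is a dependency-free sketch of `MyModel.forward` for a single input vector. It follows `torch.nn.Linear`'s convention of storing one weight row per output feature and computing `x @ weight.T + bias`. The helper names and the tiny weights are made up for illustration and are not taken from this repository.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # one row per output feature, like torch.nn.Linear.weight


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """Computes x @ weight.T + bias for a single example."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(x: Vector) -> Vector:
    """Element-wise max(0, x), like torch.nn.ReLU."""
    return [max(0.0, v) for v in x]


def my_model_forward(x: Vector, w1: Matrix, b1: Vector,
                     w2: Matrix, b2: Vector) -> Vector:
    """fc1 -> ReLU -> fc2, in the same order as MyModel.forward."""
    out = linear(x, w1, b1)
    out = relu(out)
    return linear(out, w2, b2)


# Made-up weights for input_size=2, hidden_size=2, num_classes=1.
w1 = [[1.0, 0.0], [0.0, -1.0]]
b1 = [0.0, 0.0]
w2 = [[2.0, 3.0]]
b2 = [0.5]
print(my_model_forward([1.0, 2.0], w1, b1, w2, b2))  # prints [2.5]
```

The hidden unit that receives a negative value (-2.0) is zeroed by the ReLU, so only the first hidden unit contributes to the output.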
Export the gradient graph:
We're going to create an ONNX graph that can compute gradients when given training data.
You can follow along here or see the full example in example.py or mnist/example.py.
I did this in Windows Subsystem for Linux (WSL).
- PyTorch
If you don't already have PyTorch installed, see pytorch.org for how to install it on your system. For example:
```bash
conda install pytorch torchvision torchaudio cpuonly -c pytorch
```
- ONNX Runtime
See onnxruntime.ai for all installation options. The utility method we'll use is new in version 1.11, so you'll need at least that version. Make sure that the version you use is the same as the version of ONNX Runtime Web that you'll use later. This repository includes a pre-built ONNX Runtime Web for version 1.11, so we'll use that version for our Python onnxruntime dependencies.
Example:
```bash
pip install onnx 'onnxruntime==1.11.*' 'onnxruntime-training==1.11.*'
```
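The Python packages should match the ONNX Runtime Web build you use later, so a quick check of the installed versions can save some confusion. This is a minimal sketch that uses Python's standard `importlib.metadata`. The helper names are hypothetical, and the `1.11` target is simply the version this tutorial uses.

```python
from importlib.metadata import PackageNotFoundError, version


def same_minor_version(installed: str, expected: str = "1.11") -> bool:
    """True if `installed` (e.g. "1.11.1") is in the `expected` major.minor series."""
    return installed.split(".")[:2] == expected.split(".")[:2]


def check_onnxruntime_packages(expected: str = "1.11") -> None:
    """Prints the installed version of each Python package this tutorial relies on."""
    for package in ("onnxruntime", "onnxruntime-training"):
        try:
            installed = version(package)
        except PackageNotFoundError:
            print(f"{package}: not installed")
            continue
        ok = same_minor_version(installed, expected)
        print(f"{package}: {installed} ({'OK' if ok else 'expected ' + expected + '.*'})")


if __name__ == "__main__":
    check_onnxruntime_packages()
```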
With the dependencies installed, export the gradient graph:
```python
import torch
from onnxruntime.training.experimental import export_gradient_graph

# We need a custom loss function to load the graph in an InferenceSession in ONNX Runtime Web.
# You can still make the gradient graph with torch.nn.CrossEntropyLoss() and this part will work,
# but you'll get problems later when trying to use the graph in JavaScript.
def binary_cross_entropy_loss(output, target):
    return -torch.sum(target * torch.log2(output[:, 0]) +
                      (1 - target) * torch.log2(output[:, 1]))


loss_fn = binary_cross_entropy_loss

input_size = 10
num_classes = 2
model = MyModel(input_size=input_size, hidden_size=5, num_classes=num_classes)

# File path for where to save the ONNX graph.
gradient_graph_path = 'gradient_graph.onnx'

# We need example input for the ONNX model.
# It doesn't matter what values are filled in, but the dimensions need to be correct.
batch_size = 32
example_input = torch.randn(
    batch_size, input_size, requires_grad=True)
example_labels = torch.randint(0, num_classes, (batch_size,))

export_gradient_graph(
    model, loss_fn, example_input, example_labels, gradient_graph_path)
```
You now have an ONNX graph at gradient_graph.onnx.
If you want to validate it, see orttraining_test_experimental_gradient_graph.py for examples.
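To see what `binary_cross_entropy_loss` computes, here is a pure-Python mirror of it for a batch given as nested lists. It is a hypothetical helper for illustration only. It keeps the same pairing of columns as the PyTorch version, `target * log2(output[:, 0]) + (1 - target) * log2(output[:, 1])`, and then sums over the batch and negates the result. Because `log2` is only defined for positive values, the sketch assumes each row of `output` holds probabilities (for example, after a softmax).

```python
import math
from typing import Sequence


def binary_cross_entropy_loss(output: Sequence[Sequence[float]],
                              target: Sequence[int]) -> float:
    """Pure-Python mirror of the PyTorch loss above, summed over the batch.

    Assumes every value in `output` is a probability in (0, 1]:
    math.log2 raises ValueError for zero or negative inputs.
    """
    total = 0.0
    for row, t in zip(output, target):
        # Same pairing as the PyTorch version: column 0 is weighted by target,
        # column 1 by (1 - target).
        total += t * math.log2(row[0]) + (1 - t) * math.log2(row[1])
    return -total


# Two examples, both with target 1, so only column 0 contributes.
print(binary_cross_entropy_loss([[0.5, 0.5], [0.25, 0.75]], [1, 1]))  # prints 3.0
```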
Set up an optimizer graph:
We'll run another ONNX graph to compute the weight updates. This repository has an example for an Adam optimizer in optim/adam.py.
The optimizer graph is kept separate from the gradient graph for a few reasons:
- You can easily swap the optimizer for a different optimizer while using the same gradient graph.
- Historically, putting the model's gradient graph and the optimizer graph together was too complex to support many different types of optimizers.
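The first point can be illustrated in plain Python. When the gradient computation and the update rule are separate pieces, the same gradients can be fed to different optimizers. In this toy sketch, all names are hypothetical: `compute_grads` stands in for the gradient graph, each class with a `step` method stands in for an optimizer graph, and the objective f(p) = p² is made up.

```python
from typing import Callable, List

Params = List[float]


class SGD:
    """Plain gradient descent: p <- p - lr * g. Keeps no state."""

    def __init__(self, lr: float = 0.1):
        self.lr = lr

    def step(self, params: Params, grads: Params) -> Params:
        return [p - self.lr * g for p, g in zip(params, grads)]


class MomentumSGD:
    """Gradient descent with momentum: keeps one velocity value per parameter."""

    def __init__(self, lr: float = 0.1, momentum: float = 0.9):
        self.lr = lr
        self.momentum = momentum
        self.velocity: Params = []

    def step(self, params: Params, grads: Params) -> Params:
        if not self.velocity:
            self.velocity = [0.0] * len(params)
        self.velocity = [self.momentum * v + g for v, g in zip(self.velocity, grads)]
        return [p - self.lr * v for p, v in zip(params, self.velocity)]


def train(params: Params, compute_grads: Callable[[Params], Params],
          optimizer, steps: int) -> Params:
    # compute_grads stands in for the gradient graph, optimizer.step for the optimizer graph.
    for _ in range(steps):
        params = optimizer.step(params, compute_grads(params))
    return params


def grad_of_square(params: Params) -> Params:
    """Gradient of the made-up objective f(p) = p**2."""
    return [2 * p for p in params]


# Same gradient function, two interchangeable optimizers.
print(train([1.0], grad_of_square, SGD(), steps=2))
print(train([1.0], grad_of_square, MomentumSGD(), steps=2))
```

Swapping the optimizer touches only the `optimizer` argument. `grad_of_square` is unchanged, which is the property that keeping the two ONNX graphs separate preserves.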
Export the optimizer graph:
```python
import onnx

from optim.adam import AdamOnnxGraphBuilder

# Build an ONNX graph that computes Adam updates for the model's parameters.
optimizer = AdamOnnxGraphBuilder(model.named_parameters())
onnx_optimizer = optimizer.export()
onnx.save(onnx_optimizer, 'optimizer_graph.onnx')
```
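For reference, here is the standard Adam update rule (Kingma and Ba) in pure Python for a flat list of parameters. It's only a sketch for intuition. We're assuming `AdamOnnxGraphBuilder` implements this usual form, but its default hyperparameters, bias correction, or state handling may differ. Check the builder's source for the exact computation.

```python
import math
from typing import List


class Adam:
    """Standard Adam (Kingma & Ba) for a flat list of parameters."""

    def __init__(self, lr: float = 1e-3, beta1: float = 0.9,
                 beta2: float = 0.999, eps: float = 1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m: List[float] = []  # first-moment (mean) estimates
        self.v: List[float] = []  # second-moment (uncentered variance) estimates
        self.t = 0                # number of steps taken so far

    def step(self, params: List[float], grads: List[float]) -> List[float]:
        if not self.m:
            self.m = [0.0] * len(params)
            self.v = [0.0] * len(params)
        self.t += 1
        updated = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            # Bias correction compensates for m and v starting at zero.
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            updated.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return updated


# On the first step, each parameter moves by roughly lr, whatever the gradient's scale.
adam = Adam()
print(adam.step([1.0, 1.0], [0.5, -100.0]))
```

Unlike the gradients, the moment estimates `m` and `v` and the step count `t` persist between calls. This state is part of what an optimizer has to carry from one training step to the next.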
Those were just examples that you could follow in your own project. This browser example project will load a model that classifies digits from the MNIST dataset.
Next, we'll prepare the model's gradient graph and optimizer graph for the example JavaScript project. Go to the export folder:
```bash
cd export
```
To export the MNIST example:
```bash
python -m mnist.example
```
(Optional) Train the model in Python to verify that it should work:
```bash
python -m mnist.train
```
Load the graphs in JavaScript:
We'll use ONNX Runtime Web to load the gradient graph.
At this time (May 2022), this only works with custom ONNX Runtime Web builds that have training operators enabled, but the required files are included in this repository.
The officially published ONNX Runtime Web doesn't support some of the gradient operators in our exported gradient graph, such as GatherGrad, when using an InferenceSession.
For your convenience, we included a build of ONNX Runtime Web with training operators enabled for ONNX Runtime version 1.11. You can see other versions here.
If you would like to build it yourself, here are some commands that should help, assuming you're using Linux and have CMake and conda set up:
```bash
conda create --name ort-dev python=3.8 numpy h5py
conda activate ort-dev
conda install -c anaconda libstdcxx-ng
conda install pytorch torchvision torchaudio cpuonly -c pytorch
pip install flake8 pytest
# This is a specific commit that should work. You can try other versions, but this tutorial
# works best if the version matches the onnxruntime and onnxruntime-training versions you installed for Python earlier.
commit="2dfd81b9bb097c90388010e5b7d298498274f8d9"
git clone --recursive git@github.com:microsoft/onnxruntime.git
cd onnxruntime
git checkout ${commit}
git submodule update --init --recursive
pip install -r requirements-dev.txt
```
For the build command, follow the instructions at ONNX Runtime Web, which currently link to more specific build instructions.
When you get to the "Build ONNX Runtime WebAssembly" step, you'll need to add `--enable_training --enable_training_ops` to the build command.
For example:
```bash
./build.sh --build_wasm --parallel $(expr `nproc` - 1) --enable_training --enable_training_ops --skip_submodule_sync --skip_tests
./build.sh --build_wasm --enable_wasm_simd --parallel $(expr `nproc` - 1) --enable_training --enable_training_ops --skip_submodule_sync --skip_tests
./build.sh --build_wasm --enable_wasm_threads --parallel $(expr `nproc` - 1) --enable_training --enable_training_ops --skip_submodule_sync --skip_tests
./build.sh --build_wasm --enable_wasm_simd --enable_wasm_threads --parallel $(expr `nproc` - 1) --enable_training --enable_training_ops --skip_submodule_sync --skip_tests
cp build/Linux/Debug/ort-wasm*.wasm js/web/dist/
cp build/Linux/Debug/ort-wasm*.js js/web/lib/wasm/binding/
cd js/web
NODE_OPTIONS=--max-old-space-size=4096 npm run build
```
You might get some errors, but if you see ort.js and ort-web.js in the dist/ folder, then it should work.
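If you'd like to script that last check, here is a small sketch. The two file names come from the note above. The helper name and the `js/web/dist` default path (relative to the onnxruntime root) are assumptions based on the commands used here.

```python
from pathlib import Path
from typing import List

# Bundle names from the note above.
EXPECTED_BUNDLES = ("ort.js", "ort-web.js")


def missing_build_outputs(dist_dir: str = "js/web/dist") -> List[str]:
    """Returns the expected ONNX Runtime Web bundles that are not in dist_dir."""
    dist = Path(dist_dir)
    return [name for name in EXPECTED_BUNDLES if not (dist / name).is_file()]


if __name__ == "__main__":
    # Run from the onnxruntime root directory.
    missing = missing_build_outputs()
    if missing:
        print("Missing: " + ", ".join(missing))
    else:
        print("Found ort.js and ort-web.js; the build should be usable.")
```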
- (If you built ONNX Runtime Web yourself) Put the files from the ONNX Runtime Web build (ort.js and others, such as the wasm files, if needed) in `training/public/onnxruntime_web_build_inference_with_training_ops/`:
```bash
# In the onnxruntime root directory, do:
rm <your workspace>/train-pytorch-in-js/training/public/onnxruntime_web_build_inference_with_training_ops/*.{js,wasm}
cp js/web/dist/* js/web/lib/wasm/binding/* <your workspace>/train-pytorch-in-js/training/public/onnxruntime_web_build_inference_with_training_ops
# Get the declaration files.
cp js/common/dist/lib/*.d.ts <your workspace>/train-pytorch-in-js/training/src/ort
```
- Copy some files to `training/public/`:
```bash
cp *_graph.onnx training/public
```
Copy the MNIST data (from the repository root):
```bash
cp -R data training/public/
```
- Go to the `training` folder:
```bash
cd training
```
- Run:
```bash
yarn install
```
- Run:
```bash
yarn start
```
Your browser should open. Click "TRAIN" to train the model.
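If the page doesn't seem to find the graphs or the data, check that the files you copied ended up where they're expected. The hypothetical helper below is meant to be run from the `training` folder. It checks `public/` for `*_graph.onnx` files, a `data` folder, and `ort.js` in the ONNX Runtime Web folder. Those locations are assumptions based on the copy commands in this README, so adjust them if your layout differs.

```python
from pathlib import Path
from typing import List

# Folder name from the copy step above; expecting ort.js inside it is an assumption.
ORT_WEB_DIR = "onnxruntime_web_build_inference_with_training_ops"


def check_public_folder(training_dir: str = ".") -> List[str]:
    """Returns a list of problems found in <training_dir>/public (empty if none)."""
    public = Path(training_dir) / "public"
    if not public.is_dir():
        return [f"{public} does not exist"]
    problems = []
    if not any(public.glob("*_graph.onnx")):
        problems.append("no *_graph.onnx files in public/")
    if not (public / "data").is_dir():
        problems.append("public/data/ is missing")
    if not (public / ORT_WEB_DIR / "ort.js").is_file():
        problems.append(f"public/{ORT_WEB_DIR}/ort.js is missing")
    return problems


if __name__ == "__main__":
    # Run from the training folder.
    for problem in check_public_folder():
        print(problem)
```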