The following details how to achieve 57%+ MFU on TPU v5e with the 16B and 32B model configurations.
| Model Size | Hardware | TFLOP/sec/chip | MFU |
| --- | --- | --- | --- |
| 16B | 2x v5e-256 | 114 | 57.8% |
| 32B | 2x v5e-256 | 113 | 57.3% |
Although these results were produced on only 2 pods, we expect that scaling the batch size linearly with the number of pods will yield nearly constant MFU, because less than 1% of the time is spent on data center networking (DCN).
Depending on your organization's setup, these instructions may vary. Here we assume a vanilla GCE setup; feel free to reach out with questions if your organization has a unique setup.
- Create a custom MTU network to optimize performance, and add firewall rules for it. If you are unable to complete this step, you may skip it; it increases MFU by ~2% but is not required.

  Create a network with an MTU of 8896 bytes and set up firewall rules. (Creating a network requires the `compute.networks.create` permission in your project.)

  ```sh
  gcloud compute networks create mtu9k --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional
  gcloud compute firewall-rules create mtu9kfw --network mtu9k --allow tcp,icmp,udp --project=${PROJECT}
  ```
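  As an optional sanity check (not part of the original steps), you can confirm the network was created with the expected MTU:

  ```sh
  # Should print 8896.
  gcloud compute networks describe mtu9k --project=${PROJECT} --format="value(mtu)"
  ```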
  When you create your TPUs, indicate that they should be part of this network using the `--network` flag (`--network=mtu9k`). Below is an example of a queued-resources request:

  ```sh
  gcloud alpha compute tpus queued-resources create ${QR_ID} --node-prefix=${TPU_NAME} --node-count=${NUM_SLICES} --accelerator-type=${ACCELERATOR_TYPE} --runtime-version=${RUNTIME_VERSION} --network=mtu9k --project=${PROJECT} --zone=${ZONE}
  ```

  Note: If you want to use only one slice, replace `--node-prefix` with `--node-id` and remove `--node-count`, as in the sketch below.
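  For example, a single-slice request might look like this (a sketch following the note above; variables are as defined for the multi-slice command):

  ```sh
  gcloud alpha compute tpus queued-resources create ${QR_ID} --node-id=${TPU_NAME} --accelerator-type=${ACCELERATOR_TYPE} --runtime-version=${RUNTIME_VERSION} --network=mtu9k --project=${PROJECT} --zone=${ZONE}
  ```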
- Install MaxText on your runner.

  ```sh
  # Install MaxText
  git clone git@github.com:google/maxtext.git
  ```
- Download the dataset and set up GCS paths (see the sketch below). If you have not downloaded the dataset before, we recommend following the MaxText repository README.
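  A minimal sketch of setting up the GCS paths, assuming you already have (or create) buckets for run output and for the dataset; the bucket names below are placeholders:

  ```sh
  # Placeholders -- replace with your own buckets/paths.
  export MAXTEXT_OUTPUT_PATH=gs://your-output-bucket/maxtext-runs
  export MAXTEXT_DATASET_PATH=gs://your-dataset-bucket/maxtext-dataset

  # If a bucket does not exist yet, create it (requires storage permissions), e.g.:
  gsutil mb gs://your-output-bucket
  ```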
- Install MaxText dependencies and run 16b.sh or 32b.sh on each worker.

  ```sh
  bash setup.sh && bash MaxText/configs/16b.sh ${YOUR_RUN_NAME} ${MAXTEXT_OUTPUT_PATH} ${MAXTEXT_DATASET_PATH}
  ```
  We recommend the orchestration tool multihost_runner.py to get code up and running quickly for fast experimentation, or multihost_job.py for longer training runs. If you use these tools, pass the command above as the input to `--COMMAND`, e.g.:

  ```sh
  python3 multihost_runner.py --TPU_PREFIX=${TPU_PREFIX} --COMMAND="bash setup.sh && bash MaxText/configs/16b.sh ${YOUR_RUN_NAME} ${MAXTEXT_OUTPUT_PATH} ${MAXTEXT_DATASET_PATH}"
  ```
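  The 32B configuration is run the same way, swapping in 32b.sh; for example:

  ```sh
  python3 multihost_runner.py --TPU_PREFIX=${TPU_PREFIX} --COMMAND="bash setup.sh && bash MaxText/configs/32b.sh ${YOUR_RUN_NAME} ${MAXTEXT_OUTPUT_PATH} ${MAXTEXT_DATASET_PATH}"
  ```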
  Note that these configurations do three things:

  - Set various XLA flags as `LIBTPU_INIT_ARGS` to optimize performance.
  - Run rto_setup.sh, which optimizes communication protocols for performance. (This script only needs to be run once on each worker.)
  - Run train.py with specific hyperparameters (batch size, embedding size, etc.), as sketched below.

  These configurations are coded to run for only 150 steps without checkpointing; this can be changed by editing test_tflops_16b_params.sh and test_tflops_32b_params.sh.
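  To illustrate the last point: the config scripts ultimately invoke train.py with key=value overrides on top of MaxText's base config. The sketch below is illustrative only; the real hyperparameter values and XLA settings live in 16b.sh/32b.sh and the test_tflops_*_params.sh scripts, and the values shown here are placeholders.

  ```sh
  # Illustrative sketch only -- see MaxText/configs/16b.sh for the actual invocation and values.
  python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=${YOUR_RUN_NAME} \
    base_output_directory=${MAXTEXT_OUTPUT_PATH} \
    dataset_path=${MAXTEXT_DATASET_PATH} \
    steps=150 \
    enable_checkpointing=false \
    per_device_batch_size=6   # plus the model-size hyperparameters set in the config script
  ```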
We have found that for these small models, a fairly large batch size gives the best MFU (these configurations correspond to a per-pod batch size of 3.1 million tokens for 16B and 2.1 million tokens for 32B). This provides adequate scalability for the 16B and 32B models to converge in 15 days while staying within the ~8M-token global batch size budget. You can also lower the batch size slightly with a fairly modest performance degradation.
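For example, at 2 pods the 16B configuration's 3.1 million tokens per pod corresponds to a global batch of roughly 6.2 million tokens, and the 32B configuration's 2.1 million tokens per pod to roughly 4.2 million tokens, both within the ~8M-token budget.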