Measure the training throughput on TPU v3
Requirements
You need access to TPU resources on Google Cloud Platform (GCP).
Steps
Set up TPU instance and VM
- Follow the instructions here to set up a GCP VM and a TPU instance.
- The VM machine type used in our experiments is n1-highmem-8 (8 vCPUs, 52 GB memory).
- The VM disk type is Standard persistent disk with 120 GB capacity.
- The type of TPU instance is v3-8, with software version pytorch-1.7.
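The docker image pulled in the next section is generally larger than 20 GB, so it can be worth confirming that the VM disk has room before pulling it. Below is a minimal bash sketch under that assumption; the function name has_free_space and the 20 GB threshold are illustrative choices, not part of the HFTA repo. It relies only on the POSIX output format of df.

```shell
# Sketch: check that the filesystem holding PATH has at least MIN_GB free.
# has_free_space is a hypothetical helper, not part of the HFTA repo.
# Usage: has_free_space PATH MIN_GB  (exit status 0 if enough space)
has_free_space() {
  local path="${1:-/}"
  local min_gb="${2:-20}"
  local avail_kb
  # `df -Pk` prints POSIX-format output in 1024-byte blocks;
  # column 4 of the second line is the available space.
  avail_kb=$(df -Pk "$path" 2>/dev/null | awk 'NR == 2 { print $4 }')
  [ -n "$avail_kb" ] || return 1        # df failed (e.g. path missing)
  [ "$avail_kb" -ge $(( min_gb * 1024 * 1024 )) ]
}

# Example: warn if the root disk cannot hold a ~20 GB image.
if ! has_free_space / 20; then
  echo "warning: less than 20 GB free on /; the docker image may not fit" >&2
fi
```

Docker usually stores images under /var/lib/docker, which typically lives on the root disk, so checking / is a reasonable default; with the 120 GB disk described above, the check should normally pass on a fresh VM.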
Prepare codebase and download docker image
- Start the VM and TPU instance you just created, and make sure the VM knows the TPU's IP address (see instructions here).
- Clone this repo and check out the MLSys21 release branch:
git clone https://github.com/UofT-EcoSystem/hfta.git
cd hfta; git checkout releases/mlsys21
- Download and enter the docker image (the image is generally larger than 20 GB):
bash docker/launch_xla.sh
- Install the basic requirements for HFTA:
pip install -e .[xla]
- Install the additional requirements for benchmarking:
pip install plyfile
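The step of making sure the VM knows the TPU IP address can be sketched as follows. This assumes the XRT-based PyTorch/XLA runtime used with pytorch-1.7 TPU nodes, which locates the TPU through the XRT_TPU_CONFIG environment variable in the form tpu_worker;0;<TPU IP>:8470. The helper make_xrt_tpu_config is hypothetical, and whether the variable must be exported on the VM or inside the container depends on docker/launch_xla.sh, which this sketch does not inspect.

```shell
# Sketch: build the XRT_TPU_CONFIG value PyTorch/XLA (XRT runtime) uses to
# reach a TPU node. make_xrt_tpu_config is a hypothetical helper.
# Usage: make_xrt_tpu_config TPU_IP
make_xrt_tpu_config() {
  local ip="$1"
  if [ -z "$ip" ]; then
    echo "usage: make_xrt_tpu_config <tpu-ip>" >&2
    return 1
  fi
  # Port 8470 is the TPU node's gRPC endpoint in the XRT setup.
  printf 'tpu_worker;0;%s:8470\n' "$ip"
}

# Only export when the caller has supplied the TPU's internal IP,
# e.g. TPU_IP_ADDRESS=10.x.y.z (taken from the GCP console).
if [ -n "${TPU_IP_ADDRESS:-}" ]; then
  export XRT_TPU_CONFIG="$(make_xrt_tpu_config "$TPU_IP_ADDRESS")"
fi
```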
Run experiments
- Prepare datasets
  - Under the root directory of the repo, run:
    source datasets/prepare_datasets.sh
  - Download a dataset by calling the helper functions defined in prepare_datasets.sh. For example, run prepare_bert for the BERT experiment.
- Prepare the experiment workflow helper functions
  - Under the root directory of the repo, run:
    source benchmarks/workflow.sh xla v3 ./MLSys21/benchmarks
  - This command sets the target device to TPU v3 and the output directory to ./MLSys21/benchmarks. You can change the output directory to any path you like.
- Run experiments by calling the workflow helper functions
  - For example, to run the BERT experiment, call the bash function:
    workflow_bert
  - The workflow functions are defined in the _workflow_<modelname>.sh files under <repo root>/benchmarks.
  - The functions are generally named workflow_<modelname>.
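Because the workflow scripts follow the _workflow_<modelname>.sh naming convention, the available model names can be listed straight from the benchmarks directory. This is a minimal bash sketch that assumes the convention holds for every script; list_workflow_models is a hypothetical helper, not part of the repo.

```shell
# Sketch: print the <modelname> part of every _workflow_<modelname>.sh
# file in DIR, one per line, sorted. Hypothetical helper.
# Usage: list_workflow_models [DIR]
list_workflow_models() {
  local dir="${1:-benchmarks}"
  local f name
  for f in "$dir"/_workflow_*.sh; do
    [ -e "$f" ] || continue        # glob matched nothing
    name="${f##*/}"                # strip the directory part
    name="${name#_workflow_}"      # strip the "_workflow_" prefix
    name="${name%.sh}"             # strip the ".sh" suffix
    printf '%s\n' "$name"
  done | sort
}
```

For example, running list_workflow_models benchmarks from the repo root should print names such as bert, each of which maps to a workflow_<modelname> function. Alternatively, after sourcing benchmarks/workflow.sh, declare -F | grep workflow_ lists the workflow functions that are actually defined in the current shell.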
Plot speedup curves
- After the workflow experiment finishes, run the corresponding bash function to process the output and plot the speedup curves.
  - For example, run plot_bert for the BERT experiment.
  - The plot functions are also defined in the _workflow_<modelname>.sh files.
- Finally, you should be able to see the .csv and .png files under the output directory (./MLSys21/benchmarks).
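The run-then-plot sequence, followed by a check for the expected output files, can be wrapped in one bash function. This is a sketch under stated assumptions: run_and_plot and has_plot_outputs are hypothetical helpers, and they assume benchmarks/workflow.sh has already been sourced so that workflow_<modelname> and plot_<modelname> exist as shell functions.

```shell
# Sketch: succeed only if DIR (searched recursively) holds at least one
# .csv and one .png file. Hypothetical helper.
has_plot_outputs() {
  local dir="$1"
  [ -d "$dir" ] || return 1
  [ -n "$(find "$dir" -type f -name '*.csv' -print -quit)" ] &&
    [ -n "$(find "$dir" -type f -name '*.png' -print -quit)" ]
}

# Sketch: run workflow_MODEL, then plot_MODEL, then check OUTDIR.
# Usage: run_and_plot MODEL OUTDIR   (hypothetical helper)
run_and_plot() {
  local model="$1" outdir="$2"
  # `declare -F name` succeeds only if a shell function of that name exists.
  if ! declare -F "workflow_$model" >/dev/null ||
     ! declare -F "plot_$model" >/dev/null; then
    echo "workflow_$model/plot_$model not defined;" \
         "source benchmarks/workflow.sh first" >&2
    return 1
  fi
  "workflow_$model" && "plot_$model" || return 1
  has_plot_outputs "$outdir"
}
```

For the BERT example above, this would be run_and_plot bert ./MLSys21/benchmarks. Checking that the functions exist up front avoids a long-running workflow being started with a mistyped model name.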