Skip to content

Commit

Permalink
Support Trillium TPUs (#845)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Dec 20, 2024
1 parent 5d04609 commit d9678f5
Show file tree
Hide file tree
Showing 11 changed files with 242 additions and 75 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/scratch

# Configuration for TPU launches/secrets
.config
.levanter.yaml
.levanter.yaml

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
118 changes: 101 additions & 17 deletions config/llama2_7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,108 @@ data:
cache_dir: "gs://levanter-data/tokenized/openwebtext_llama/"
tokenizer: "meta-llama/Llama-2-70b-hf"
model:
activation_function: silu
attn_backend: null
cross_entropy_block_size: null
flash_attention_block_size: null
gradient_checkpointing: true
gradient_checkpointing_block_size: 5
hidden_dim: 4096
initializer_range: 0.02
intermediate_dim: 14336
layer_norm_epsilon: 1.0e-05
num_heads: 32
num_kv_heads: 8
num_layers: 32
reference_checkpoint: meta-llama/Llama-2-7b-hf
rope:
factor: 1.0
theta: 10000
type: default
scan_layers: true
seq_len: 4096
tie_word_embeddings: false
type: llama
# TODO: uncomment this once we resolve the resource exhaustion issue
# initialize_from_hf: "meta-llama/Llama-2-7b-hf"
# use_hf_model_config: true
upcast_attn: false
use_bias: false
use_flash_attention: true
use_layer_norm_weight: true
optimizer:
beta1: 0.9
beta2: 0.95
cooldown: null
cycle_length: 10000
cycles: null
decay: 0.1
default_weight_decay_mask: null
epsilon: 1.0e-08
haps: null
learning_rate: 0.001
lr_schedule: inv
max_grad_norm: 1.0
min_lr_ratio: 0.1
rewarmup: 0.0
type: adam
warmup: 1000
weight_decay: 0.05
weight_decay_modules: null
trainer:
axis_resources: {}
batch_axis: batch
checkpointer:
append_run_id_to_base_path: false
base_path: gs://levanter-checkpoints/checkpoints/llama-8b-tootsie-0.001-19ad63/checkpoints
keep:
- every: 20000
save_interval: 10m
fp8: null
fsdp_axis: embed
id: llama-8b-tootsie-0.001-19ad63
initialize_from: null
jax_config:
jax_softmax_custom_jvp: true
jax_threefry_partitionable: true
load_checkpoint: null
load_checkpoint_path: null
log_dir: logs
max_eval_batches: null
model_axis_size: 1
mp: compute=bfloat16,params=float32,output=bfloat16
num_train_steps: 10000
parameter_axis_resources: {}
per_device_eval_parallelism: 2
per_device_parallelism: 2
profiler: false
profiler_num_steps: 100
profiler_perfetto_link: false
profiler_start_step: 5
ray:
address: null
auto_start_cluster: false
start_workers: false
# replica_dcn_axis_size: 2
# replica_ici_axis_size: 1
require_accelerator: true
seed: 0
shutdown_at_exit: false
steps_per_eval: 10000
tensor_parallel_axes: null
tracker:
entity: null
group: null
id: null
mode: null
name: null
project: levanter
resume: allow
save_code: true
save_xla_dumps: false
tags:
- llama-8b-test
- llama
- 8b
- wsd-s
type: wandb
project: "levanter"
tags: ["openwebtext", "llama"]

mp: p=f32,c=bfloat16
train_batch_size: 256 # set for v4-64 TPU
num_train_steps: 1000
steps_per_eval: 50
tensor_parallel_axes: ["mlp", "heads"]
fsdp_axis: "embed"
batch_axis: "batch"
optimizer:
learning_rate: 1.2E-5 # set low for fine-tuning
weight_decay: 0.1
min_lr_ratio: 0.1
train_batch_size: 1024
wandb: null
use_hf_model_config: false
13 changes: 9 additions & 4 deletions docs/Getting-Started-TPU-VM.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ on your development machine to build and run images on TPUs.
First create a configuration file for future launches in your Levanter directory:

```bash
cat > .config <<EOF
cat > .levanter.yaml <<EOF
env:
WANDB_API_KEY:
WANDB_ENTITY:
Expand All @@ -93,15 +93,18 @@ env:
TPU_MIN_LOG_LEVEL: 0
LIBTPU_INIT_ARGS: <extra args to libtpu> # Optional
# Optional: specific environment variables for TPUs based on the TPU type
accel_env:
v6e:
    # If you are running on a v6e (Trillium), set the following flag; it significantly improves performance
LIBTPU_INIT_ARGS: "--xla_tpu_scoped_vmem_limit_kib=98304"
docker_repository: levanter # default
zone: us-west4-a # if not set, will use your default zone
tpu_name: test-spin-up-32
tpu_type: "v5litepod-16"
vm_image: "tpu-ubuntu2204-base" # default
capacity_type: "preemptible"
autodelete: false
subnetwork: "default" # default
EOF
```

Expand Down Expand Up @@ -155,6 +158,8 @@ a new file:
If you're using `launch.py`, the config will be automatically uploaded as part of your Docker image, so you
can just reference the local config path in your command line:

```bash
python infra/launch.py -- python src/levanter/main/train_lm.py --config_path config/my_config.yaml --trainer.checkpointer.base_path gs://<somewhere>
```

Afterward, you can use the config directly from the TPU VM instance, e.g.:
Expand Down
29 changes: 22 additions & 7 deletions infra/launch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/usr/bin/python

import argparse
import getpass
import subprocess
import sys
import time
from pathlib import Path

Expand All @@ -12,6 +12,14 @@
from levanter.infra.tpus import launch_job


# default: tpu-ubuntu2204-base
TPU_TYPE_TO_VM_IMAGE = {
"v5litepod": "v2-alpha-tpuv5-lite",
"v5p": "v2-alpha-tpuv5",
"v6e": "v2-alpha-tpuv6e",
}


def main():
parser = argparse.ArgumentParser()
config = cli.load_config()
Expand All @@ -28,7 +36,7 @@ def main():
cli.add_arg(parser, config, ["--tpu_name"], required=True)
cli.add_arg(parser, config, ["--tpu_type"], required=True)
cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
cli.add_arg(parser, config, ["--version"], default="tpu-ubuntu2204-base")
cli.add_arg(parser, config, ["--version"], default=None)
cli.add_arg(parser, config, ["--zone"], default=None, type=str, required=False)
cli.add_arg(parser, config, ["--retries"], default=10, type=int)
cli.add_arg(parser, config, ["--run_id"], default=cli.default_run_id(), type=str)
Expand All @@ -37,9 +45,7 @@ def main():
cli.add_arg(parser, config, ["--github_token"], type=str)
cli.add_arg(parser, config, ["--extra_context"], type=Path, required=False, default=None)

parser.add_argument(
"-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"), default=list(config.get("env", {}).items())
)
parser.add_argument("-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"))
parser.add_argument("command", nargs=argparse.REMAINDER)

args = parser.parse_args()
Expand All @@ -57,8 +63,14 @@ def main():
retries = args.retries
tpu_name = args.tpu_name
tpu_type = args.tpu_type

tpu_gen = tpu_type.split("-")[0]
version = args.version or TPU_TYPE_TO_VM_IMAGE.get(tpu_gen, "tpu-ubuntu2204-base")

if not args.version:
print(f"Using default version: {version}", file=sys.stderr)

node_count = args.node_count
version = args.version
zone = args.zone
run_id = args.run_id
registry = args.docker_registry
Expand All @@ -73,7 +85,10 @@ def main():
raise ValueError("Zone must be specified or set in gcloud config.")

region = "-".join(zone.split("-")[:-1])
env = {k: v for k, v in args.env}

env = config.env_for_accel(tpu_type)
for key, value in args.env or []:
env[key] = value

if "WANDB_PROJECT" not in env:
env["WANDB_PROJECT"] = "levanter"
Expand Down
9 changes: 6 additions & 3 deletions infra/launch_on_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ def main():
cli.add_arg(parser, config, ["--extra_context"], type=Path, required=False, default=None)
cli.add_arg(parser, config, ["--zone"], default=None, type=str, required=False)

parser.add_argument(
"-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"), default=list(config.get("env", {}).items())
)
parser.add_argument("-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"))

parser.add_argument("command", nargs=argparse.REMAINDER)

args = parser.parse_args()
Expand All @@ -62,6 +61,10 @@ def main():
github_token = args.github_token
extra_context = args.extra_context

env = config.env_for_accel(tpu_type)
for key, value in args.env or []:
env[key] = value

if zone is None:
zone = cli.gcloud_config()["zone"]

Expand Down
Loading

0 comments on commit d9678f5

Please sign in to comment.