diff --git a/infra/cluster/job-cluster.yaml b/infra/cluster/job-cluster.yaml index cf8703d54..cff7d4884 100644 --- a/infra/cluster/job-cluster.yaml +++ b/infra/cluster/job-cluster.yaml @@ -14,8 +14,8 @@ cluster_name: levanter-cluster # Configure GCP provider: type: gcp - region: us-central2 - availability_zone: us-central2-b + region: us-west4 + availability_zone: us-west4-a project_id: hai-gcp-models # Maximum Workers (excluding Head Node) @@ -126,6 +126,45 @@ available_node_types: schedulingConfig: preemptible: true + tpu_slice_v5e_16: + min_workers: 0 + max_workers: 1024 + resources: { "CPU": 120, "TPU": 4 } + + node_config: + acceleratorType: v5litepod-16 + runtimeVersion: tpu-ubuntu2204-base + + # [IMPORTANT] Configure all TPU Workers to be Preemptible! + schedulingConfig: + preemptible: true + + tpu_slice_v5e_64: + min_workers: 0 + max_workers: 1024 + resources: { "CPU": 120, "TPU": 4 } + + node_config: + acceleratorType: v5litepod-64 + runtimeVersion: tpu-ubuntu2204-base + + # [IMPORTANT] Configure all TPU Workers to be Preemptible! + schedulingConfig: + preemptible: true + + tpu_slice_v5e_256: + min_workers: 0 + max_workers: 1024 + resources: { "CPU": 120, "TPU": 4 } + + node_config: + acceleratorType: v5litepod-256 + runtimeVersion: tpu-ubuntu2204-base + + # [IMPORTANT] Configure all TPU Workers to be Preemptible! 
+ schedulingConfig: + preemptible: true + docker: image: "ghcr.io/stanford-crfm/levanter-cluster:latest" container_name: "ray_docker" @@ -140,7 +179,7 @@ docker: - -v "/var/run/docker.sock:/var/run/docker.sock" initialization_commands: - - yes | gcloud auth configure-docker us-central2-docker.pkg.dev + - yes | gcloud auth configure-docker us-west4-docker.pkg.dev - "export TPU_WORKER_ID=$(curl -H 'Metadata-Flavor: Google' http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number) || true" - which docker || (curl -fsSL https://get.docker.com -o get-docker.sh; sudo sh get-docker.sh; sudo usermod -aG docker $USER; sudo systemctl restart docker -f) # always run this because ray doesn't run with sudo diff --git a/infra/launch_on_ray.py b/infra/launch_on_ray.py index fa5e81f27..90f2c586a 100755 --- a/infra/launch_on_ray.py +++ b/infra/launch_on_ray.py @@ -27,7 +27,7 @@ def main(): cli.add_arg(parser, config, ["--project"], default=cli.gcloud_config()["project"]) cli.add_arg(parser, config, ["--tpu_type"], required=True) # TODO: bring node_count to Ray - # cli.add_arg(parser, config, ["--node_count"], default=1, type=int) + cli.add_arg(parser, config, ["--node_count"], default=1, type=int) cli.add_arg(parser, config, ["--foreground"], default=False, action="store_true") cli.add_arg(parser, config, ["--retries"], default=10, type=int) cli.add_arg(parser, config, ["--run_id"], default=cli.default_run_id(), type=str) @@ -122,6 +122,7 @@ def main(): env=env, name="levanter", retries=retries, + node_count=args.node_count, ) address = args.address or os.getenv("RAY_ADDRESS") diff --git a/src/levanter/infra/cli_helpers.py b/src/levanter/infra/cli_helpers.py index b92b6efb5..58413ef2b 100644 --- a/src/levanter/infra/cli_helpers.py +++ b/src/levanter/infra/cli_helpers.py @@ -76,6 +76,11 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante "/tmp:/tmp", ] + # optionally add multislice env vars (if set by ray runtime env 
def run_on_pod_multislice(
    remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int
) -> list[ray.ObjectRef]:
    """
    Run a remote function on multiple TPU slices.

    One ``MultisliceActor`` is created per slice (each one claims a
    ``TPU-{tpu_type}-head`` resource). The actors first report their pod name,
    host count and IP address; the IP of the first actor is then handed to every
    slice as the coordinator address, and each actor fans ``remote_fn`` out to
    all hosts of its own slice.

    Args:
        remote_fn: A remote function that takes no arguments
        tpu_type: The type of TPU to run on, e.g. "v4-32"
        num_slices: The number of slices to run (must be at least 1)

    Returns:
        A list of Ray ObjectRefs, one per slice, each resolving to the
        ``_TpuRunResult`` of that slice. If collecting the slice infos fails,
        the (failed) slice-info futures are returned instead, so that calling
        ``ray.get`` on the result surfaces the original error to the caller.

    Raises:
        ValueError: If ``num_slices`` is smaller than 1.
    """
    if num_slices < 1:
        # Without at least one slice there is no coordinator to pick below.
        raise ValueError(f"num_slices must be at least 1, got {num_slices}")

    @ray.remote(resources={f"TPU-{tpu_type}-head": 1})
    class MultisliceActor:
        """Actor pinned to the head of one TPU slice; drives the job on that slice."""

        def __init__(self):
            self.pod_name = ray.util.accelerators.tpu.get_current_pod_name()
            self.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
            self.ip = socket.gethostbyname(socket.gethostname())

        def get_slice_info(self):
            """Return ``(pod_name, num_hosts, ip)`` for the slice this actor runs on."""
            return self.pod_name, self.num_hosts, self.ip

        def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:
            """Run ``remote_fn`` once per host of this slice and summarize the outcome."""
            port = 8081
            # Multislice settings, handed to every host task as runtime-env variables.
            mxla_env = {
                "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
                "MEGASCALE_NUM_SLICES": str(num_slices),
                "MEGASCALE_PORT": f"{port}",
                "MEGASCALE_SLICE_ID": str(slice_id),
            }

            remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, self.num_hosts, env_vars=mxla_env)

            info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
            futures = [remote_fn.remote() for _ in range(self.num_hosts)]
            try:
                out = ray.get(futures)
                logger.info("TPU job finished")
                return TpuSuccess(info, out)
            except Exception as e:
                # One host failed: stop the remaining hosts before reporting.
                for f in futures:
                    try:
                        ray.cancel(f)
                    except Exception:
                        logger.exception("Failed to kill job after primary failure")
                if isinstance(e, RayError):
                    return _handle_ray_error(info, e)
                return TpuFailed(info, e)

    actors = [MultisliceActor.remote() for _ in range(num_slices)]  # type: ignore
    futures = [actor.get_slice_info.remote() for actor in actors]
    try:
        logger.info("Getting slice infos...")
        # also act as a sync step
        slice_infos = ray.get(futures)
        logger.info(f"TPU slice infos {slice_infos}")
    except RayError as e:
        logger.exception(e)
        for actor in actors:
            try:
                # ray.cancel only accepts ObjectRefs; actor handles must be
                # terminated with ray.kill so their slice resources are released.
                ray.kill(actor)
            except Exception:
                logger.exception("Failed to kill actor after primary failure")
        return futures

    # The first slice acts as the coordinator for all slices.
    coordinator_ip = slice_infos[0][2]

    return [actor.do_run.remote(remote_fn, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)]
def _cancel_multislice_futures(futures):
    """Best-effort cancel of every future in ``futures``; failures are logged, not raised."""
    for future in futures:
        try:
            ray.cancel(future)
        except Exception:
            logger.exception("Failed to kill job after primary failure")


def run_on_pod_multislice_resumable(
    remote_fn, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10
):
    """
    Repeatedly run a function on multiple TPU slices until every slice succeeds
    or a maximum number of retries is reached.

    Args:
        remote_fn: A remote function that takes no arguments
        tpu_type: The type of TPU to run on, e.g. "v4-32"
        num_slices: The number of slices to run
        max_retries_preemption: The maximum number of times to retry if the job is preempted
        max_retries_failure: The maximum number of times to retry if the job fails

    Returns:
        A list with one entry per slice, each entry being the ``result`` of that
        slice's successful run (not ObjectRefs)

    Raises:
        RuntimeError: If the job was preempted or failed too many times, or if a
            slice returned a result of an unexpected type.
        Exception: A non-``RayTaskError`` exception raised while waiting for the
            slices is re-raised once the failure budget is exhausted.
    """
    num_failures = 0
    num_preemptions = 0
    attempt = 0
    problem: Exception | None = None

    while num_failures < max_retries_failure and num_preemptions < max_retries_preemption:
        logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}")
        attempt += 1
        problem = None
        futures = run_on_pod_multislice(remote_fn, tpu_type, num_slices)
        try:
            outs = ray.get(futures)
        except ray.exceptions.RayTaskError as e:
            _cancel_multislice_futures(futures)
            problem = e
            # Ray reports preemption through the error text.
            if "preempted" in str(e).lower():
                num_preemptions += 1
                logger.warning(f"Preempted {num_preemptions} times, {e}")
            else:
                num_failures += 1
                logger.warning(f"Failed {num_failures} times", exc_info=e)
            continue
        except Exception as e:
            _cancel_multislice_futures(futures)
            problem = e
            num_failures += 1
            if num_failures >= max_retries_failure:
                logger.exception("Failed too many times", exc_info=e)
                raise
            logger.warning(f"Failed {num_failures} times", exc_info=e)
            continue

        if all(isinstance(out, TpuSuccess) for out in outs):
            logger.info("Success")
            return [out.result for out in outs]

        # Classify the attempt by the most "retryable" outcome among the slices:
        # preemption first, then node failure, then an error in the user code.
        preempted = [out for out in outs if isinstance(out, TpuPreempted)]
        run_errors = [out for out in outs if isinstance(out, TpuRunError)]

        if preempted:
            problem = preempted[-1].error
            num_preemptions += 1
            logger.warning(f"Preempted {num_preemptions} times. {problem}", exc_info=problem)
        elif any(isinstance(out, TpuFailed) for out in outs):
            num_preemptions += 1
            logger.warning(f"TPU node failure. Treating as preempted: {num_preemptions} times")
        elif run_errors:
            # A run error counts against the failure budget only, never as a preemption.
            problem = run_errors[-1].error
            num_failures += 1
            logger.warning(f"Failed {num_failures} times", exc_info=problem)
        else:
            raise RuntimeError(f"Unexpected result: {outs}")

    if num_preemptions >= max_retries_preemption:
        raise RuntimeError("Preempted too many times") from problem
    elif num_failures >= max_retries_failure:
        raise RuntimeError("Failed too many times") from problem
RunDockerOnPodConfig): tpu_type=args.tpu_type, env=args.env, name=args.name, + retries=args.retries, + num_slices=args.node_count, )