diff --git a/infra/cluster/job-cluster.yaml b/infra/cluster/job-cluster.yaml index 2f9e29b75..d692f8e16 100644 --- a/infra/cluster/job-cluster.yaml +++ b/infra/cluster/job-cluster.yaml @@ -129,7 +129,7 @@ available_node_types: tpu_slice_v5e_16: min_workers: 0 max_workers: 1024 - resources: { "CPU": 120, "TPU": 4 } + resources: { "CPU": 120, "TPU": 8 } node_config: acceleratorType: v5litepod-16 @@ -142,7 +142,7 @@ available_node_types: tpu_slice_v5e_256: min_workers: 0 max_workers: 1024 - resources: { "CPU": 120, "TPU": 4 } + resources: { "CPU": 120, "TPU": 8 } node_config: acceleratorType: v5litepod-256 diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index d14ddf5d3..5b8f56e73 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -161,7 +161,6 @@ def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResu return TpuFailed(info, e) actors = [MultisliceActor.remote() for _ in range(num_slices)] # type: ignore - info = _TpuInfo("get_slice_info", "ACTIVE", "TPU") futures = [actor.get_slice_info.remote() for actor in actors] try: logger.info("Getting slice infos...") @@ -169,12 +168,13 @@ def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResu slice_infos = ray.get(futures) logger.info(f"TPU slice infos {slice_infos}") except RayError as e: + logger.exception(e) for actor in actors: try: ray.cancel(actor) except Exception: logger.exception("Failed to kill actor after primary failure") - return [_handle_ray_error(info, e)] + return futures coordinator_ip = slice_infos[0][2] @@ -197,6 +197,7 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env): tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> my-tpu num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators() # -> 8 + remote_fn = remote_fn.options( runtime_env=runtime_env, resources={tpu_name: 1, "TPU": num_tpus_per_host}, @@ -299,9 +300,13 @@ def run_on_pod_multislice_resumable( logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}") attempt += 1 problem = None + futures = run_on_pod_multislice(remote_fn, tpu_type, num_slices) try: - outs = ray.get(run_on_pod_multislice(remote_fn, tpu_type, num_slices)) + outs = ray.get(futures) except ray.exceptions.RayTaskError as e: + for f in futures: + ray.cancel(f) + logger.info(f"Cancelling {f}") problem = e if "preempted" in str(e).lower(): num_preemptions += 1 @@ -311,6 +316,9 @@ def run_on_pod_multislice_resumable( logger.warning(f"Failed {num_failures} times", exc_info=e) continue except Exception as e: + for f in futures: + ray.cancel(f) + logger.info(f"Cancelling {f}") problem = e num_failures += 1 if num_failures >= max_retries_failure: