Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
blahBlahhhJ committed Oct 28, 2024
1 parent 4cc74de commit dc87e85
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
4 changes: 2 additions & 2 deletions infra/cluster/job-cluster.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ available_node_types:
tpu_slice_v5e_16:
min_workers: 0
max_workers: 1024
resources: { "CPU": 120, "TPU": 4 }
resources: { "CPU": 120, "TPU": 8 }

node_config:
acceleratorType: v5litepod-16
Expand All @@ -142,7 +142,7 @@ available_node_types:
tpu_slice_v5e_256:
min_workers: 0
max_workers: 1024
resources: { "CPU": 120, "TPU": 4 }
resources: { "CPU": 120, "TPU": 8 }

node_config:
acceleratorType: v5litepod-256
Expand Down
14 changes: 11 additions & 3 deletions src/levanter/infra/ray_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,20 +161,20 @@ def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResu
return TpuFailed(info, e)

actors = [MultisliceActor.remote() for _ in range(num_slices)] # type: ignore
info = _TpuInfo("get_slice_info", "ACTIVE", "TPU")
futures = [actor.get_slice_info.remote() for actor in actors]
try:
logger.info("Getting slice infos...")
# also act as a sync step
slice_infos = ray.get(futures)
logger.info(f"TPU slice infos {slice_infos}")
except RayError as e:
logger.exception(e)
for actor in actors:
try:
ray.cancel(actor)
except Exception:
logger.exception("Failed to kill actor after primary failure")
return [_handle_ray_error(info, e)]
return futures

coordinator_ip = slice_infos[0][2]

Expand All @@ -197,6 +197,7 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):

tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> my-tpu
num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators() # -> 8

remote_fn = remote_fn.options(
runtime_env=runtime_env,
resources={tpu_name: 1, "TPU": num_tpus_per_host},
Expand Down Expand Up @@ -299,9 +300,13 @@ def run_on_pod_multislice_resumable(
logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}")
attempt += 1
problem = None
futures = run_on_pod_multislice(remote_fn, tpu_type, num_slices)
try:
outs = ray.get(run_on_pod_multislice(remote_fn, tpu_type, num_slices))
outs = ray.get(futures)
except ray.exceptions.RayTaskError as e:
for f in futures:
ray.cancel(f)
logger.info(f"Cancelling {f}")
problem = e
if "preempted" in str(e).lower():
num_preemptions += 1
Expand All @@ -311,6 +316,9 @@ def run_on_pod_multislice_resumable(
logger.warning(f"Failed {num_failures} times", exc_info=e)
continue
except Exception as e:
for f in futures:
ray.cancel(f)
logger.info(f"Cancelling {f}")
problem = e
num_failures += 1
if num_failures >= max_retries_failure:
Expand Down

0 comments on commit dc87e85

Please sign in to comment.