add multislice support in ray #771

Merged
merged 11 commits on Oct 30, 2024
Changes from 1 commit
3 changes: 2 additions & 1 deletion infra/launch_on_ray.py
@@ -27,7 +27,7 @@ def main():
cli.add_arg(parser, config, ["--project"], default=cli.gcloud_config()["project"])
cli.add_arg(parser, config, ["--tpu_type"], required=True)
# TODO: bring node_count to Ray
# cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
cli.add_arg(parser, config, ["--foreground"], default=False, action="store_true")
cli.add_arg(parser, config, ["--retries"], default=10, type=int)
cli.add_arg(parser, config, ["--run_id"], default=cli.default_run_id(), type=str)
@@ -122,6 +122,7 @@ def main():
env=env,
name="levanter",
retries=retries,
node_count=args.node_count,
)

address = args.address or os.getenv("RAY_ADDRESS")
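With this change, the launcher exposes --node_count and forwards it into the Ray job config as node_count. A hypothetical invocation (script and flags are from this diff; the TPU type and count are example values, and other flags are elided):

python infra/launch_on_ray.py --tpu_type v4-32 --node_count 2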
31 changes: 31 additions & 0 deletions src/levanter/infra/cli_helpers.py
@@ -59,6 +59,37 @@ def get_git_commit():
return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()


class DockerRunCommand:
def __init__(self, image_id, command, *, foreground, env, name="levanter"):
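"""Compose a `docker run` command in three parts (base flags, env flags, image + command), so extra env vars can be appended after construction via add_env()."""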
self.base_part = [
"docker",
"run",
"-t" if foreground else "-d",
f"--name={name}",
"--privileged",
"--shm-size=32gb",
"--net=host",
"--init",
"--mount",
"type=volume,source=levanter,target=/home/levanter",
"-v",
"/tmp:/tmp",
]

self.env_part = []
self.add_env(env)

self.cmd_part = [image_id, *command]

def add_env(self, env):
for k, v in env.items():
self.env_part.extend(["-e", k + f"={str(v)}"])

@property
def full_cmd(self):
return self.base_part + self.env_part + self.cmd_part


def make_docker_run_command(image_id, command, *, foreground, env, name="levanter"):
docker_command = [
"docker",
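For reference, a minimal sketch of how the new class composes a command (the image id, command, and env values below are made-up examples, not from the PR):

from levanter.infra.cli_helpers import DockerRunCommand

cmd = DockerRunCommand(
    "gcr.io/my-project/levanter:latest",
    ["python", "-m", "levanter.main.train_lm"],
    foreground=True,
    env={"WANDB_API_KEY": "xxx"},
)
cmd.add_env({"MEGASCALE_SLICE_ID": "0"})  # env vars can be appended after construction
# cmd.full_cmd is now ["docker", "run", "-t", "--name=levanter", ..., "-e", "WANDB_API_KEY=xxx",
# "-e", "MEGASCALE_SLICE_ID=0", "gcr.io/my-project/levanter:latest", "python", "-m", "levanter.main.train_lm"]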
234 changes: 211 additions & 23 deletions src/levanter/infra/ray_tpu.py
@@ -3,6 +3,7 @@
import logging
import multiprocessing
import os
import socket
import subprocess
import tempfile
import time
@@ -16,7 +17,7 @@
from ray.exceptions import NodeDiedError, RayError, RaySystemError, RayTaskError, WorkerCrashedError
from ray.remote_function import RemoteFunction

from levanter.infra.cli_helpers import make_docker_run_command
from levanter.infra.cli_helpers import DockerRunCommand
from levanter.utils.ray_utils import ser_exc_info


@@ -62,21 +63,28 @@ class TpuRunError(_TpuRunResult):
error: Exception


def run_on_pod(remote_fn: RemoteFunction | Callable, tpu_type: str) -> ray.ObjectRef:
def run_on_pod(docker_cmd: DockerRunCommand, name: str, tpu_type: str) -> ray.ObjectRef:
"""
Run a docker command on a TPU pod.

Args:
remote_fn: A remote function that takes no arguments
docker_cmd: A DockerRunCommand object that holds a docker command to run
name: docker container name
tpu_type: The type of TPU to run on, e.g. "v4-32"

Returns:
A Ray ObjectRef that represents the result of the function
"""

@ray.remote(resources={f"TPU-{tpu_type}-head": 1})
def do_run(remote_fn) -> _TpuRunResult:
def do_run(docker_cmd: DockerRunCommand, name: str) -> _TpuRunResult:
num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> 4

def _run_docker():
run_docker(docker_cmd=docker_cmd.full_cmd, name=name)

remote_fn = ray.remote(_run_docker)

remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, num_hosts)

info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
@@ -93,10 +101,87 @@ def do_run(remote_fn) -> _TpuRunResult:
logger.exception("Failed to kill job after primary failure")
return _handle_ray_error(info, e)

return do_run.remote(remote_fn)
return do_run.remote(docker_cmd, name)


def run_on_pod_multislice(docker_cmd: DockerRunCommand, name: str, tpu_type: str, num_slices: int) -> ray.ObjectRef:
"""
Run a docker command on multiple TPU slices.

Args:
docker_cmd: A DockerRunCommand object that holds a docker command to run
name: docker container name
tpu_type: The type of TPU to run on, e.g. "v4-32"
num_slices: The number of slices to run

def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):
Returns:
A list of Ray ObjectRefs, one per slice, each representing that slice's result
"""

@ray.remote(resources={f"TPU-{tpu_type}-head": 1})
class MultisliceActor:
def __init__(self):
self.pod_name = ray.util.accelerators.tpu.get_current_pod_name()
self.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
self.ip = socket.gethostbyname(socket.gethostname())

def get_slice_info(self):
return self.pod_name, self.num_hosts, self.ip

def do_run(self, docker_cmd, name, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:
port = 8081
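# MXLA multislice settings: every slice gets the same coordinator address and port; only slice_id differs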
mxla_env = {
"MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
"MEGASCALE_NUM_SLICES": str(num_slices),
"MEGASCALE_PORT": f"{port}",
"MEGASCALE_SLICE_ID": str(slice_id),
}

docker_cmd.add_env(mxla_env)

def _run_docker():
run_docker(docker_cmd=docker_cmd.full_cmd, name=name)

remote_fn = ray.remote(_run_docker)

remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, self.num_hosts, env_vars=mxla_env)

info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
futures = [remote_fn.remote() for _ in range(self.num_hosts)]
try:
out = ray.get(futures)
logger.info("TPU job finished")
return TpuSuccess(info, out)
except RayError as e:
for f in futures:
try:
ray.cancel(f)
except Exception:
logger.exception("Failed to kill job after primary failure")
return _handle_ray_error(info, e)

actors = [MultisliceActor.remote() for _ in range(num_slices)] # type: ignore
info = _TpuInfo("get_slice_info", "ACTIVE", "TPU")
futures = [actor.get_slice_info.remote() for actor in actors]
try:
logger.info("Getting slice infos...")
# also acts as a sync step
slice_infos = ray.get(futures)
logger.info(f"TPU slice infos {slice_infos}")
except RayError as e:
for actor in actors:
try:
ray.kill(actor)
except Exception:
logger.exception("Failed to kill actor after primary failure")
return [_handle_ray_error(info, e)]

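# slice_infos entries are (pod_name, num_hosts, ip); the first slice's host serves as the MXLA coordinator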
coordinator_ip = slice_infos[0][2]

return [actor.do_run.remote(docker_cmd, name, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)]


def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):
"""
Redecorate a remote function to run on a TPU pod.

@@ -112,17 +197,21 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):

tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> my-tpu
num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators() # -> 8
remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": num_tpus_per_host})
remote_fn = remote_fn.options(
runtime_env=runtime_env,
resources={tpu_name: 1, "TPU": num_tpus_per_host},
)
logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host")
return remote_fn, tpu_name


def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_retries_failure=10):
def run_on_pod_resumable(docker_cmd, name, tpu_type, max_retries_preemption=1e6, max_retries_failure=10):
"""
Repeatedly run a docker command on a TPU pod until it succeeds or a maximum number of retries is reached.

Args:
remote_fn: A remote function that takes no arguments
docker_cmd: A DockerRunCommand object that holds a docker command to run
name: docker container name
tpu_type: The type of TPU to run on, e.g. "v4-32"
max_retries_preemption: The maximum number of times to retry if the job is preempted
max_retries_failure: The maximum number of times to retry if the job fails
@@ -141,7 +230,7 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_retries_failure=10):
attempt += 1
problem = None
try:
out = ray.get(run_on_pod(remote_fn, tpu_type))
out = ray.get(run_on_pod(docker_cmd, name, tpu_type))
except ray.exceptions.RayTaskError as e:
problem = e
if "preempted" in str(e):
@@ -185,26 +274,123 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_retries_failure=10):
raise RuntimeError("Failed too many times") from problem


def run_on_pod_multislice_resumable(
docker_cmd, name, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10
):
"""
Repeatedly run a docker command on multiple TPU slices until it succeeds or a maximum number of retries is reached.

Args:
docker_cmd: A DockerRunCommand object that holds a docker command to run
name: docker container name
tpu_type: The type of TPU to run on, e.g. "v4-32"
num_slices: The number of slices to run
max_retries_preemption: The maximum number of times to retry if the job is preempted
max_retries_failure: The maximum number of times to retry if the job fails

Returns:
A list with each slice's result (not ObjectRefs)

"""
num_failures = 0
num_preemptions = 0
attempt = 0
problem: Exception | None = None

while num_failures < max_retries_failure and num_preemptions < max_retries_preemption:
logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}")
attempt += 1
problem = None
try:
outs = ray.get(run_on_pod_multislice(docker_cmd, name, tpu_type, num_slices))
except ray.exceptions.RayTaskError as e:
problem = e
if "preempted" in str(e):
num_preemptions += 1
logger.warning(f"Preempted {num_preemptions} times, {e}")
else:
num_failures += 1
logger.warning(f"Failed {num_failures} times")
continue
except Exception as e:
problem = e
num_failures += 1
if num_failures >= max_retries_failure:
logger.exception("Failed too many times", exc_info=e)
raise e
else:
logger.warning(f"Failed {num_failures} times", exc_info=e)
continue

if all(isinstance(out, TpuSuccess) for out in outs):
results = [out.result for out in outs]
logger.info("Success")
return results
elif any(isinstance(out, TpuPreempted) for out in outs):
out = None
for o in outs:
if isinstance(o, TpuPreempted):
out = o
assert out is not None
problem = out.error
num_preemptions += 1
logger.warning(f"Preempted {num_preemptions} times. {problem}", exc_info=problem)
elif any(isinstance(out, TpuFailed) for out in outs):
num_preemptions += 1
logger.warning(f"TPU node failure. Treating as preempted: {num_preemptions} times")
elif any(isinstance(out, TpuRunError) for out in outs):
out = None
for o in outs:
if isinstance(o, TpuRunError):
out = o
assert out is not None
problem = out.error
num_failures += 1
logger.warning(f"Failed {num_failures} times", exc_info=problem)
else:
raise RuntimeError(f"Unexpected result: {out}")

if num_preemptions >= max_retries_preemption:
raise RuntimeError("Preempted too many times") from problem
elif num_failures >= max_retries_failure:
raise RuntimeError("Failed too many times") from problem


def _run_command(*args, **kwargs):
return subprocess.check_call(args, **kwargs)


def run_docker_on_pod(image_id: str, command: Sequence[str], *, tpu_type: str, env: dict, name="levanter", retries=10):
env = _massage_env(env)
def run_docker(docker_cmd, name="levanter"):
_kill_old_container(name)
try:
return _run_command(*docker_cmd)
except subprocess.CalledProcessError as e:
logger.exception("Failed to run docker command")
raise e

docker_cmd = make_docker_run_command(image_id, command, env=env, foreground=True, name=name)

def run_docker():
_kill_old_container(name)
try:
return _run_command(*docker_cmd)
except subprocess.CalledProcessError as e:
logger.exception("Failed to run docker command")
raise e
def run_docker_on_pod(
image_id: str, command: Sequence[str], *, tpu_type: str, num_slices: int, env: dict, name="levanter", retries=10
):
env = _massage_env(env)

run_on_pod_resumable(
ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
)
docker_cmd = DockerRunCommand(image_id, command, env=env, foreground=True, name=name)

if num_slices == 1:
run_on_pod_resumable(
docker_cmd, name=name, tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
)
else:
run_on_pod_multislice_resumable(
docker_cmd,
name=name,
tpu_type=tpu_type,
num_slices=num_slices,
max_retries_failure=retries,
max_retries_preemption=10000,
)


def _kill_old_container(name):
@@ -343,6 +529,7 @@ class RunDockerOnPodConfig:
env: dict = dataclasses.field(default_factory=dict)
name: str = "levanter"
retries: int = 10
node_count: int = 1


def submit_tpu_job_on_ray(config: RunDockerOnPodConfig, ray_address: str, run_id: Optional[str] = None):
@@ -411,6 +598,7 @@ def main(args: RunDockerOnPodConfig):
tpu_type=args.tpu_type,
env=args.env,
name=args.name,
num_slices=args.node_count,
)


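Taken together, the multislice path works like this: one MultisliceActor is scheduled per slice (each holding a TPU-{tpu_type}-head resource), every actor reports its pod name, host count, and IP, the first slice's IP becomes the coordinator, and each slice's docker command is started with the MEGASCALE_* variables shown above. A minimal sketch of the per-slice environment, mirroring do_run() in this diff (the helper name and example IP are illustrative, not part of the PR):

def mxla_env_for(slice_id: int, coordinator_ip: str, num_slices: int) -> dict:
    # Same coordinator address and port on every slice; only MEGASCALE_SLICE_ID differs.
    port = 8081
    return {
        "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
        "MEGASCALE_NUM_SLICES": str(num_slices),
        "MEGASCALE_PORT": str(port),
        "MEGASCALE_SLICE_ID": str(slice_id),
    }

# For a 2-slice job whose first slice sits at 10.0.0.2:
# mxla_env_for(0, "10.0.0.2", 2) -> {..., "MEGASCALE_SLICE_ID": "0"}
# mxla_env_for(1, "10.0.0.2", 2) -> {..., "MEGASCALE_SLICE_ID": "1"}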