diff --git a/dask_cuda/benchmarks/local_cudf_groupby.py b/dask_cuda/benchmarks/local_cudf_groupby.py index 4e9dea94e..2f07e3df7 100644 --- a/dask_cuda/benchmarks/local_cudf_groupby.py +++ b/dask_cuda/benchmarks/local_cudf_groupby.py @@ -139,7 +139,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results): key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}" ) print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}") - if args.protocol == "ucx": + if args.protocol in ["ucx", "ucxx"]: print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}") print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}") print_key_value(key="NVLink", value=f"{args.enable_nvlink}") diff --git a/dask_cuda/benchmarks/local_cudf_merge.py b/dask_cuda/benchmarks/local_cudf_merge.py index f26a26ae9..ba3a9d56d 100644 --- a/dask_cuda/benchmarks/local_cudf_merge.py +++ b/dask_cuda/benchmarks/local_cudf_merge.py @@ -217,7 +217,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results): ) print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}") print_key_value(key="Frac-match", value=f"{args.frac_match}") - if args.protocol == "ucx": + if args.protocol in ["ucx", "ucxx"]: print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}") print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}") print_key_value(key="NVLink", value=f"{args.enable_nvlink}") diff --git a/dask_cuda/benchmarks/local_cudf_shuffle.py b/dask_cuda/benchmarks/local_cudf_shuffle.py index 51ba48f93..a3492b664 100644 --- a/dask_cuda/benchmarks/local_cudf_shuffle.py +++ b/dask_cuda/benchmarks/local_cudf_shuffle.py @@ -146,7 +146,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results): key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}" ) print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}") - if args.protocol == "ucx": + if args.protocol in ["ucx", "ucxx"]: print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}") print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}") print_key_value(key="NVLink", value=f"{args.enable_nvlink}") diff --git a/dask_cuda/benchmarks/local_cupy.py b/dask_cuda/benchmarks/local_cupy.py index 1c1d12d30..22c51556f 100644 --- a/dask_cuda/benchmarks/local_cupy.py +++ b/dask_cuda/benchmarks/local_cupy.py @@ -193,7 +193,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results): ) print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}") print_key_value(key="Protocol", value=f"{args.protocol}") - if args.protocol == "ucx": + if args.protocol in ["ucx", "ucxx"]: print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}") print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}") print_key_value(key="NVLink", value=f"{args.enable_nvlink}") diff --git a/dask_cuda/benchmarks/local_cupy_map_overlap.py b/dask_cuda/benchmarks/local_cupy_map_overlap.py index f40318559..8250c9f9f 100644 --- a/dask_cuda/benchmarks/local_cupy_map_overlap.py +++ b/dask_cuda/benchmarks/local_cupy_map_overlap.py @@ -78,7 +78,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results): ) print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}") print_key_value(key="Protocol", value=f"{args.protocol}") - if args.protocol == "ucx": + if args.protocol in ["ucx", "ucxx"]: print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}") print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}") print_key_value(key="NVLink", value=f"{args.enable_nvlink}") diff --git a/dask_cuda/benchmarks/utils.py b/dask_cuda/benchmarks/utils.py index d3ce666b2..51fae7201 100644 --- a/dask_cuda/benchmarks/utils.py +++ b/dask_cuda/benchmarks/utils.py @@ -73,7 +73,7 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[] cluster_args.add_argument( "-p", "--protocol", - choices=["tcp", "ucx"], + choices=["tcp", "ucx", "ucxx"], default="tcp", type=str, help="The communication protocol to use.", diff --git a/dask_cuda/initialize.py b/dask_cuda/initialize.py index 0b9c92a59..571a46a55 100644 --- a/dask_cuda/initialize.py +++ b/dask_cuda/initialize.py @@ -5,7 +5,6 @@ import numba.cuda import dask -import distributed.comm.ucx from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context from .utils import get_ucx_config @@ -23,12 +22,21 @@ def _create_cuda_context_handler(): numba.cuda.current_context() -def _create_cuda_context(): +def _create_cuda_context(protocol="ucx"): + if protocol not in ["ucx", "ucxx"]: + return try: # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA # context directly from the UCX module, thus avoiding a similar warning there. try: - distributed.comm.ucx.init_once() + if protocol == "ucx": + import distributed.comm.ucx + + distributed.comm.ucx.init_once() + elif protocol == "ucxx": + import distributed_ucxx.ucxx + + distributed_ucxx.ucxx.init_once() except ModuleNotFoundError: # UCX initialization has to be delegated to Distributed, it will take care # of setting correct environment variables and importing `ucp` after that. @@ -39,20 +47,35 @@ def _create_cuda_context(): os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0] ) ctx = has_cuda_context() - if ( - ctx.has_context - and not distributed.comm.ucx.cuda_context_created.has_context - ): - distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid()) + if protocol == "ucx": + if ( + ctx.has_context + and not distributed.comm.ucx.cuda_context_created.has_context + ): + distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid()) + elif protocol == "ucxx": + if ( + ctx.has_context + and not distributed_ucxx.ucxx.cuda_context_created.has_context + ): + distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid()) _create_cuda_context_handler() - if not distributed.comm.ucx.cuda_context_created.has_context: - ctx = has_cuda_context() - if ctx.has_context and ctx.device_info != cuda_visible_device: - distributed.comm.ucx._warn_cuda_context_wrong_device( - cuda_visible_device, ctx.device_info, os.getpid() - ) + if protocol == "ucx": + if not distributed.comm.ucx.cuda_context_created.has_context: + ctx = has_cuda_context() + if ctx.has_context and ctx.device_info != cuda_visible_device: + distributed.comm.ucx._warn_cuda_context_wrong_device( + cuda_visible_device, ctx.device_info, os.getpid() + ) + elif protocol == "ucxx": + if not distributed_ucxx.ucxx.cuda_context_created.has_context: + ctx = has_cuda_context() + if ctx.has_context and ctx.device_info != cuda_visible_device: + distributed_ucxx.ucxx._warn_cuda_context_wrong_device( + cuda_visible_device, ctx.device_info, os.getpid() + ) except Exception: logger.error("Unable to start CUDA Context", exc_info=True) @@ -64,6 +87,7 @@ def initialize( enable_infiniband=None, enable_nvlink=None, enable_rdmacm=None, + protocol="ucx", ): """Create CUDA context and initialize UCX-Py, depending on user parameters. @@ -118,7 +142,7 @@ def initialize( dask.config.set({"distributed.comm.ucx": ucx_config}) if create_cuda_context: - _create_cuda_context() + _create_cuda_context(protocol=protocol) @click.command() @@ -127,6 +151,12 @@ def initialize( default=False, help="Create CUDA context", ) +@click.option( + "--protocol", + default=None, + type=str, + help="Communication protocol, such as: 'tcp', 'tls', 'ucx' or 'ucxx'.", +) @click.option( "--enable-tcp-over-ucx/--disable-tcp-over-ucx", default=False, @@ -150,10 +180,11 @@ def initialize( def dask_setup( service, create_cuda_context, + protocol, enable_tcp_over_ucx, enable_infiniband, enable_nvlink, enable_rdmacm, ): if create_cuda_context: - _create_cuda_context() + _create_cuda_context(protocol=protocol) diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index d0ea92748..7a5c8c13d 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -319,8 +319,11 @@ def __init__( if enable_tcp_over_ucx or enable_infiniband or enable_nvlink: if protocol is None: protocol = "ucx" - elif protocol != "ucx": - raise TypeError("Enabling InfiniBand or NVLink requires protocol='ucx'") + elif protocol not in ["ucx", "ucxx"]: + raise TypeError( + "Enabling InfiniBand or NVLink requires protocol='ucx' or " + "protocol='ucxx'" + ) self.host = kwargs.get("host", None) @@ -371,7 +374,7 @@ def __init__( ) + ["dask_cuda.initialize"] self.new_spec["options"]["preload_argv"] = self.new_spec["options"].get( "preload_argv", [] - ) + ["--create-cuda-context"] + ) + ["--create-cuda-context", "--protocol", protocol] self.cuda_visible_devices = CUDA_VISIBLE_DEVICES self.scale(n_workers) diff --git a/dask_cuda/utils.py b/dask_cuda/utils.py index f16ad18a2..ff4dbbae3 100644 --- a/dask_cuda/utils.py +++ b/dask_cuda/utils.py @@ -287,7 +287,7 @@ def get_preload_options( if create_cuda_context: preload_options["preload_argv"].append("--create-cuda-context") - if protocol == "ucx": + if protocol in ["ucx", "ucxx"]: initialize_ucx_argv = [] if enable_tcp_over_ucx: initialize_ucx_argv.append("--enable-tcp-over-ucx") @@ -625,6 +625,10 @@ def get_worker_config(dask_worker): import ucp ret["ucx-transports"] = ucp.get_active_transports() + elif scheme == "ucxx": + import ucxx + + ret["ucx-transports"] = ucxx.get_active_transports() # comm timeouts ret["distributed.comm.timeouts"] = dask.config.get("distributed.comm.timeouts")