-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support building GPU docker image for MaxDiffusion Model (#121)
* adding gpu docker file * add more gpu dependency files, unify working directory to match xpk setup, add hardware gpu option in yml, add jax multi-host support for gpu * fix indentation * reformatting * add gpu_multi_process_run.sh, unify working directory, update requirement to fix import error, add jax[cuda] install instruction more non-pinned mode when device is GPU * reformatting * resolve comments * delete gpu pinned mode
- Loading branch information
Showing
14 changed files
with
367 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
#!/bin/bash
#
# Launcher preamble: enables strict mode, verifies that every required
# environment variable is set and non-empty, and exports the settings that
# child processes read from the environment.
#
# Required environment:
#   NNODES, NODE_RANK, JAX_COORDINATOR_PORT, JAX_COORDINATOR_ADDRESS,
#   GPUS_PER_NODE, COMMAND

# Abort on command failure, on use of an unset variable, and on a failure in
# any stage of a pipeline.
set -euo pipefail

# Each expansion aborts the script with the given message when the variable is
# unset or empty.
: "${NNODES:?Must set NNODES}"
: "${NODE_RANK:?Must set NODE_RANK}"
: "${JAX_COORDINATOR_PORT:?Must set JAX_COORDINATOR_PORT}"
: "${JAX_COORDINATOR_ADDRESS:?Must set JAX_COORDINATOR_ADDRESS}"
: "${GPUS_PER_NODE:?Must set GPUS_PER_NODE}"
: "${COMMAND:?Must set COMMAND}"

# Mark the already-validated values for export so child processes inherit them.
export GPUS_PER_NODE JAX_COORDINATOR_PORT JAX_COORDINATOR_ADDRESS
|
||
#######################################
# Exports the NCCL/XLA environment for the GPUDirect transport selected by
# USE_GPUDIRECT ("tcpx" or "fastrak"). Any other value, including unset or
# empty, leaves the environment untouched.
# Globals:
#   USE_GPUDIRECT   (read, optional)  transport selector
#   LD_LIBRARY_PATH (read, optional; written) extended with CUDA compat
#                   directories and /usr/local/tcpx/lib64
#   Exports the CUDA_*, NCCL_*, NVTE_*, TF_CPP_* and XLA_* variables below.
# Arguments: none
# Outputs:   one status line on stdout naming the selected transport
# Returns:   0
#######################################
set_nccl_gpudirect_tcpx_specific_configuration() {
  # Default to empty so the function is safe under `set -u` when the caller
  # never set USE_GPUDIRECT.
  local -r gpudirect_mode=${USE_GPUDIRECT:-}

  if [[ "$gpudirect_mode" != "tcpx" && "$gpudirect_mode" != "fastrak" ]]; then
    echo "NOT using GPUDirect"
    return 0
  fi

  # Settings shared by both GPUDirect transports.
  export CUDA_DEVICE_MAX_CONNECTIONS=1
  export NCCL_CROSS_NIC=0
  export NCCL_DEBUG=INFO
  export NCCL_DYNAMIC_CHUNK_SIZE=524288
  export NCCL_NET_GDR_LEVEL=PIX
  export NCCL_NVLS_ENABLE=0
  export NCCL_P2P_NET_CHUNKSIZE=524288
  export NCCL_P2P_NVL_CHUNKSIZE=1048576
  export NCCL_P2P_PCI_CHUNKSIZE=524288
  export NCCL_PROTO=Simple
  export NCCL_SOCKET_IFNAME=eth0
  export NVTE_FUSED_ATTN=1
  export TF_CPP_MAX_LOG_LEVEL=100
  export TF_CPP_VMODULE=profile_guided_latency_estimator=10
  export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85

  # Build the library search path as: CUDA compat directories, then the
  # inherited LD_LIBRARY_PATH (if any), then the TCPX library directory.
  # Empty components are never emitted, because the dynamic loader treats an
  # empty entry as the current working directory.
  local -a lib_dirs=()
  local compat_dir
  for compat_dir in /usr/local/cuda-*/compat; do
    # An unmatched glob stays literal; the -d test filters it out.
    if [[ -d "$compat_dir" ]]; then
      lib_dirs+=("$compat_dir")
    fi
  done
  if [[ -n "${LD_LIBRARY_PATH:-}" ]]; then
    lib_dirs+=("$LD_LIBRARY_PATH")
  fi
  lib_dirs+=("/usr/local/tcpx/lib64")

  local joined_lib_path
  # IFS is changed only inside the command substitution's subshell.
  joined_lib_path=$(IFS=:; printf '%s' "${lib_dirs[*]}")
  export LD_LIBRARY_PATH="$joined_lib_path"

  case "$gpudirect_mode" in
    tcpx)
      echo "Using GPUDirect-TCPX"
      export NCCL_ALGO=Ring
      export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION
      export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0
      export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
      export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
      export NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191"
      export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4
      export NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177"
      export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000
      export NCCL_MAX_NCHANNELS=12
      export NCCL_MIN_NCHANNELS=12
      export NCCL_NSOCKS_PERTHREAD=4
      export NCCL_P2P_PXN_LEVEL=0
      export NCCL_SOCKET_NTHREADS=1
      ;;
    fastrak)
      echo "Using GPUDirect-TCPFasTrak"
      export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
      export NCCL_ALGO=Ring,Tree
      export NCCL_BUFFSIZE=8388608
      export NCCL_FASTRAK_CTRL_DEV=eth0
      export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0
      export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0
      export NCCL_FASTRAK_IFNAME=eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8
      export NCCL_FASTRAK_NUM_FLOWS=2
      export NCCL_FASTRAK_USE_LLCM=1
      export NCCL_FASTRAK_USE_SNAP=1
      export NCCL_MIN_NCHANNELS=4
      export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto
      export NCCL_TUNER_CONFIG_PATH=/usr/local/nvidia/lib64/a3plus_tuner_config.textproto
      export NCCL_TUNER_PLUGIN=libnccl-tuner.so
      ;;
  esac
}
|
||
# Log the library search path inherited from the environment, before the
# GPUDirect configuration below may extend it. The `:-` default keeps this
# from aborting under `set -u` when LD_LIBRARY_PATH is not set at all.
echo "LD_LIBRARY_PATH ${LD_LIBRARY_PATH:-}"

set_nccl_gpudirect_tcpx_specific_configuration
|
||
#######################################
# Polls a set of background children of this shell until all have finished
# successfully. Exits the whole script with the child's exit code as soon as
# any child is found to have failed.
#
# Liveness and exit status are tracked separately, so a child that genuinely
# exits with status 127 (e.g. "command not found") is reported as a failure
# instead of being mistaken for "still running".
# Globals:
#   WAIT_POLL_INTERVAL_SECONDS (read, optional) seconds between polls,
#                              default 5
# Arguments: PIDs of background jobs started by the current shell
# Outputs:   status lines on stdout
# Returns:   0 when every PID exited with status 0; otherwise does not return
#            (calls exit with the failing status)
#######################################
wait_all_success_or_exit() {
  # https://www.baeldung.com/linux/background-process-get-exit-code
  local -a pending=("$@")
  local -a still_running
  local -r poll_interval=${WAIT_POLL_INTERVAL_SECONDS:-5}
  local pid
  local code

  while (( ${#pending[@]} > 0 )); do
    still_running=()
    for pid in "${pending[@]}"; do
      # Signal 0 only probes for existence; nothing is delivered.
      if kill -0 "$pid" 2>/dev/null; then
        still_running+=("$pid")
        continue
      fi
      # The child is gone: collect its status in this shell (not a subshell).
      # `|| code=$?` keeps `set -e` from aborting before the failure is logged.
      code=0
      wait "$pid" || code=$?
      if (( code != 0 )); then
        echo "PID $pid failed with exit code $code"
        exit "$code"
      fi
    done

    if (( ${#still_running[@]} == 0 )); then
      echo "All pids succeeded"
      break
    fi
    # Finished PIDs are dropped so they are not waited on again.
    pending=("${still_running[@]}")
    sleep "$poll_interval"
  done
}

#######################################
# Reports the exit status of a background child without blocking.
# Arguments: $1 - PID of a background job
# Outputs:   the exit code on stdout, or 127 while /proc/<pid> still exists.
#            NOTE: 127 is therefore ambiguous with a child that really exited
#            127; wait_all_success_or_exit no longer relies on this helper
#            for that reason. Kept for callers that use it directly.
# Returns:   0
#######################################
non_blocking_wait() {
  # https://www.baeldung.com/linux/background-process-get-exit-code
  local pid=$1
  local code=127 # special code to indicate not-finished
  if [[ ! -d "/proc/$pid" ]]; then
    code=0
    wait "$pid" || code=$?
  fi
  printf '%s\n' "$code"
}
|
||
#######################################
# Resolves JAX_COORDINATOR_ADDRESS to an IP address with nslookup, retrying
# until the name resolves or the attempt budget is exhausted.
#
# The function is safe to call with `set -e -o pipefail` active: a failed
# lookup is treated as "not resolved yet" rather than aborting the script.
# Globals:
#   JAX_COORDINATOR_ADDRESS          (read)    host name to resolve
#   JAX_COORDINATOR_IP               (written, exported) resolved address
#   COORDINATOR_LOOKUP_MAX_ATTEMPTS  (read, optional) default 500
#   COORDINATOR_LOOKUP_DELAY_SECONDS (read, optional) default 1
# Arguments: none
# Outputs:   progress on stdout; failure messages on stderr
# Returns:   0 once resolved, 1 if every attempt failed
#######################################
resolve_coordinator_ip() {
  local -r max_coordinator_lookups=${COORDINATOR_LOOKUP_MAX_ATTEMPTS:-500}
  local -r retry_delay=${COORDINATOR_LOOKUP_DELAY_SECONDS:-1}
  local lookup_attempt
  local coordinator_ip_address

  echo "Coordinator Address $JAX_COORDINATOR_ADDRESS"

  for (( lookup_attempt = 1; lookup_attempt <= max_coordinator_lookups; lookup_attempt++ )); do
    # Take the first "Address: " line (the pattern requires a space after the
    # colon). `|| true` keeps whatever was captured while stopping a failed
    # nslookup from tripping errexit/pipefail.
    coordinator_ip_address=$(
      nslookup "$JAX_COORDINATOR_ADDRESS" 2>/dev/null \
        | awk '/^Address: / { print $2; exit }'
    ) || true

    if [[ -n "$coordinator_ip_address" ]]; then
      echo "Coordinator IP address: $coordinator_ip_address"
      export JAX_COORDINATOR_IP="$coordinator_ip_address"
      return 0
    fi

    echo "Failed to recognize coordinator address $JAX_COORDINATOR_ADDRESS on attempt $lookup_attempt, retrying..." >&2
    sleep "$retry_delay"
  done

  echo "Failed to resolve coordinator address after $max_coordinator_lookups attempts." >&2
  return 1
}
|
||
# Resolving coordinator IP
# errexit is switched off around the call, so a resolution failure is not
# fatal here: the launch below proceeds either way.
set +e
resolve_coordinator_ip
set -e

# Run the user-supplied command in the background and track its PID.
# COMMAND is quoted so that eval receives the string exactly as provided;
# unquoted, it would be word-split and glob-expanded first, which collapses
# whitespace and can expand wildcards meant for the command itself.
# NOTE(review): eval executes COMMAND as shell code — it must come from a
# trusted source.
PIDS=()
eval "${COMMAND}" &
PID=$!
PIDS+=("$PID")

wait_all_success_or_exit "${PIDS[@]}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# syntax=docker/dockerfile:1
# Note: This pulls in the latest of jax:base
ARG BASEIMAGE=ghcr.io/nvidia/jax:base
FROM $BASEIMAGE

# Stopgap measure to circumvent gpg key setup issue.
RUN echo "deb [trusted=yes] https://developer.download.nvidia.com/devtools/repos/ubuntu2204/amd64/ /" > /etc/apt/sources.list.d/devtools-ubuntu2204-amd64.list

# Install dependencies for adjusting network rto, DNS utilities, and the tools
# needed to fetch and convert the Google Cloud apt key below.
# Everything is installed in the same layer as `apt-get update` so that a
# cached layer with a stale package index is never reused for an install.
RUN apt-get update && apt-get install -y iproute2 ethtool lsof dnsutils curl gnupg

# Add the Google Cloud SDK package repository
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
# Store the repository key as a dedicated keyring (apt-key is deprecated).
RUN curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg

# Install the Google Cloud SDK
RUN apt-get update && apt-get install -y google-cloud-sdk

# Set environment variables for Google Cloud SDK
ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"

# Upgrade libcusparse to work with Jax
RUN apt-get update && apt-get install -y libcusparse-12-3

# Build arguments are mirrored into ENV so they are visible to the RUN steps
# below and inside the final image.
ARG MODE
ENV ENV_MODE=$MODE

ARG JAX_VERSION
ENV ENV_JAX_VERSION=$JAX_VERSION

ARG DEVICE
ENV ENV_DEVICE=$DEVICE

RUN mkdir -p /deps

# Set the working directory in the container
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .
RUN ls .

RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
# The pip download cache is kept in a BuildKit cache mount so it is reused
# across builds without being baked into an image layer.
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.