Skip to content

Commit

Permalink
Support building GPU docker image for MaxDiffusion Model (#121)
Browse files Browse the repository at this point in the history
* adding gpu docker file

* add more gpu dependency files, unify working directory to match xpk setup, add hardware gpu option in yml, add jax multi-host support for gpu

* fix identation

* reformatting

* add gpu_multi_process_run.sh, unify working directory, update requirement to fix import error, add jax[cuda] install instruction more non-pinned mode when device is GPU

* reformatting

* resolve comments

* delete gpu pinned mode
  • Loading branch information
sizhit2 authored Oct 29, 2024
1 parent 7ca0857 commit 51e1db1
Show file tree
Hide file tree
Showing 14 changed files with 367 additions and 69 deletions.
49 changes: 32 additions & 17 deletions docker_build_dependency_image.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ if [[ -z ${MODE} ]]; then
echo "Default MODE=${MODE}"
fi

if [[ -z ${DEVICE} ]]; then
export DEVICE=tpu
echo "Default DEVICE=${DEVICE}"
fi
echo "DEVICE=${DEVICE}"

if [[ -z ${JAX_VERSION+x} ]] ; then
export JAX_VERSION=NONE
echo "Default JAX_VERSION=${JAX_VERSION}"
Expand All @@ -55,22 +61,31 @@ COMMIT_HASH=$(git rev-parse --short HEAD)

echo "Building MaxDiffusion with MODE=${MODE} at commit hash ${COMMIT_HASH} . . ."

if [[ "${MODE}" == "stable_stack" ]]; then
if [[ ! -v BASEIMAGE ]]; then
echo "Erroring out because BASEIMAGE is unset, please set it!"
exit 1
if [[ ${DEVICE} == "gpu" ]]; then
if [[ ${MODE} == "pinned" ]]; then
export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-10-17
else
export BASEIMAGE=ghcr.io/nvidia/jax:base
fi
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
else
if [[ "${MODE}" == "stable_stack" ]]; then
if [[ ! -v BASEIMAGE ]]; then
echo "Erroring out because BASEIMAGE is unset, please set it!"
exit 1
fi
docker build --no-cache \
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \
--build-arg COMMIT_HASH=${COMMIT_HASH} \
--network=host \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile .
else
docker build --no-cache \
--network=host \
--build-arg MODE=${MODE} \
--build-arg JAX_VERSION=${JAX_VERSION} \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_dependencies.Dockerfile .
fi
docker build --no-cache \
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \
--build-arg COMMIT_HASH=${COMMIT_HASH} \
--network=host \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile .
else
docker build --no-cache \
--network=host \
--build-arg MODE=${MODE} \
--build-arg JAX_VERSION=${JAX_VERSION} \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_dependencies.Dockerfile .
fi
156 changes: 156 additions & 0 deletions gpu_multi_process_run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#! /bin/bash
set -e
set -u
set -o pipefail

: "${NNODES:?Must set NNODES}"
: "${NODE_RANK:?Must set NODE_RANK}"
: "${JAX_COORDINATOR_PORT:?Must set JAX_COORDINATOR_PORT}"
: "${JAX_COORDINATOR_ADDRESS:?Must set JAX_COORDINATOR_ADDRESS}"
: "${GPUS_PER_NODE:?Must set GPUS_PER_NODE}"
: "${COMMAND:?Must set COMMAND}"


export GPUS_PER_NODE=$GPUS_PER_NODE
export JAX_COORDINATOR_PORT=$JAX_COORDINATOR_PORT
export JAX_COORDINATOR_ADDRESS=$JAX_COORDINATOR_ADDRESS

set_nccl_gpudirect_tcpx_specific_configuration() {
if [[ "$USE_GPUDIRECT" == "tcpx" ]] || [[ "$USE_GPUDIRECT" == "fastrak" ]]; then
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_CROSS_NIC=0
export NCCL_DEBUG=INFO
export NCCL_DYNAMIC_CHUNK_SIZE=524288
export NCCL_NET_GDR_LEVEL=PIX
export NCCL_NVLS_ENABLE=0
export NCCL_P2P_NET_CHUNKSIZE=524288
export NCCL_P2P_NVL_CHUNKSIZE=1048576
export NCCL_P2P_PCI_CHUNKSIZE=524288
export NCCL_PROTO=Simple
export NCCL_SOCKET_IFNAME=eth0
export NVTE_FUSED_ATTN=1
export TF_CPP_MAX_LOG_LEVEL=100
export TF_CPP_VMODULE=profile_guided_latency_estimator=10
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
shopt -s globstar nullglob
IFS=:$IFS
set -- /usr/local/cuda-*/compat
export LD_LIBRARY_PATH="${1+:"$*"}:${LD_LIBRARY_PATH}:/usr/local/tcpx/lib64"
IFS=${IFS#?}
shopt -u globstar nullglob

if [[ "$USE_GPUDIRECT" == "tcpx" ]]; then
echo "Using GPUDirect-TCPX"
export NCCL_ALGO=Ring
export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION
export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0
export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
export NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191"
export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4
export NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177"
export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000
export NCCL_MAX_NCHANNELS=12
export NCCL_MIN_NCHANNELS=12
export NCCL_NSOCKS_PERTHREAD=4
export NCCL_P2P_PXN_LEVEL=0
export NCCL_SOCKET_NTHREADS=1
elif [[ "$USE_GPUDIRECT" == "fastrak" ]]; then
echo "Using GPUDirect-TCPFasTrak"
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export NCCL_ALGO=Ring,Tree
export NCCL_BUFFSIZE=8388608
export NCCL_FASTRAK_CTRL_DEV=eth0
export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0
export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0
export NCCL_FASTRAK_IFNAME=eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8
export NCCL_FASTRAK_NUM_FLOWS=2
export NCCL_FASTRAK_USE_LLCM=1
export NCCL_FASTRAK_USE_SNAP=1
export NCCL_MIN_NCHANNELS=4
export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto
export NCCL_TUNER_CONFIG_PATH=/usr/local/nvidia/lib64/a3plus_tuner_config.textproto
export NCCL_TUNER_PLUGIN=libnccl-tuner.so
fi
else
echo "NOT using GPUDirect"
fi
}

echo "LD_LIBRARY_PATH ${LD_LIBRARY_PATH}"

set_nccl_gpudirect_tcpx_specific_configuration

wait_all_success_or_exit() {
# https://www.baeldung.com/linux/background-process-get-exit-code
local pids=("$@")
while [[ ${#pids[@]} -ne 0 ]]; do
all_success="true"
for pid in "${pids[@]}"; do
code=$(non_blocking_wait "$pid")
if [[ $code -ne 127 ]]; then
if [[ $code -ne 0 ]]; then
echo "PID $pid failed with exit code $code"
exit "$code"
fi
else
all_success="false"
fi
done
if [[ $all_success == "true" ]]; then
echo "All pids succeeded"
break
fi
sleep 5
done
}
non_blocking_wait() {
# https://www.baeldung.com/linux/background-process-get-exit-code
local pid=$1
local code=127 # special code to indicate not-finished
if [[ ! -d "/proc/$pid" ]]; then
wait "$pid"
code=$?
fi
echo $code
}

resolve_coordinator_ip() {
local lookup_attempt=1
local max_coordinator_lookups=500
local coordinator_found=false
local coordinator_ip_address=""

echo "Coordinator Address $JAX_COORDINATOR_ADDRESS"

while [[ "$coordinator_found" = false && $lookup_attempt -le $max_coordinator_lookups ]]; do
coordinator_ip_address=$(nslookup "$JAX_COORDINATOR_ADDRESS" 2>/dev/null | awk '/^Address: / { print $2 }' | head -n 1)
if [[ -n "$coordinator_ip_address" ]]; then
coordinator_found=true
echo "Coordinator IP address: $coordinator_ip_address"
export JAX_COORDINATOR_IP=$coordinator_ip_address
return 0
else
echo "Failed to recognize coordinator address $JAX_COORDINATOR_ADDRESS on attempt $lookup_attempt, retrying..."
((lookup_attempt++))
sleep 1
fi
done

if [[ "$coordinator_found" = false ]]; then
echo "Failed to resolve coordinator address after $max_coordinator_lookups attempts."
return 1
fi
}

# Resolving coordinator IP
set +e
resolve_coordinator_ip
set -e

PIDS=()
eval ${COMMAND} &
PID=$!
PIDS+=($PID)

wait_all_success_or_exit "${PIDS[@]}"
2 changes: 1 addition & 1 deletion maxdiffusion_dependencies.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ ARG JAX_VERSION
ENV ENV_JAX_VERSION=$JAX_VERSION

# Set the working directory in the container
WORKDIR /app
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .
Expand Down
50 changes: 50 additions & 0 deletions maxdiffusion_gpu_dependencies.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# syntax=docker/dockerfile:experimental
# Note: This pulls in the lastest of jax:base
ARG BASEIMAGE=ghcr.io/nvidia/jax:base
FROM $BASEIMAGE

# Stopgaps measure to circumvent gpg key setup issue.
RUN echo "deb [trusted=yes] https://developer.download.nvidia.com/devtools/repos/ubuntu2204/amd64/ /" > /etc/apt/sources.list.d/devtools-ubuntu2204-amd64.list

# Install dependencies for adjusting network rto
RUN apt-get update && apt-get install -y iproute2 ethtool lsof

# Install DNS util dependencies
RUN apt-get install -y dnsutils

# Add the Google Cloud SDK package repository
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -

# Install the Google Cloud SDK
RUN apt-get update && apt-get install -y google-cloud-sdk

# Set environment variables for Google Cloud SDK
ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"

# Upgrade libcusprase to work with Jax
RUN apt-get update && apt-get install -y libcusparse-12-3

ARG MODE
ENV ENV_MODE=$MODE

ARG JAX_VERSION
ENV ENV_JAX_VERSION=$JAX_VERSION

ARG DEVICE
ENV ENV_DEVICE=$DEVICE

RUN mkdir -p /deps

# Set the working directory in the container
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .
RUN ls .

RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}


WORKDIR /deps
4 changes: 2 additions & 2 deletions maxdiffusion_jax_stable_stack_tpu.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ ARG COMMIT_HASH

ENV COMMIT_HASH=$COMMIT_HASH

RUN mkdir -p /app
RUN mkdir -p /deps

# Set the working directory in the container
WORKDIR /app
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .
Expand Down
4 changes: 2 additions & 2 deletions maxdiffusion_runner.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ ARG BASEIMAGE=maxdiffusion_base_image
FROM $BASEIMAGE

# Set the working directory in the container
WORKDIR /app
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .

WORKDIR /app
WORKDIR /deps
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ git+https://github.com/mlperf/logging.git
opencv-python-headless==4.10.0.84
orbax-checkpoint>=0.5.20
tokenizers==0.20.0
huggingface_hub==0.24.7
Loading

0 comments on commit 51e1db1

Please sign in to comment.