Skip to content

Commit

Permalink
Adding new build capabilities to maxdiffusion (#103)
Browse files Browse the repository at this point in the history
* Nightly Docker Image build
* Updated flow to generate this nightly image
  • Loading branch information
parambole authored Sep 16, 2024
1 parent 6234383 commit e68bbc0
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 21 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/UploadDockerImages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ jobs:
runs-on: ["self-hosted", "e2", "cpu"]
steps:
- uses: actions/checkout@v3
- name: build jax stable stack image
- name: build maxdiffusion jax stable stack image
run: |
bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=jax0.4.30-rev1 MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt DELETE_LOCAL_IMAGE=true
bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=jax0.4.30-rev1 MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt DELETE_LOCAL_IMAGE=true
- name: build maxdiffusion jax nightly image
run: |
bash docker_maxdiffusion_image_upload.sh MODE=nightly PROJECT_ID=tpu-prod-env-multipod CLOUD_IMAGE_NAME=maxdiffusion-jax-nightly IMAGE_TAG=auto MAXDIFFUSION_REQUIREMENTS_FILE=requirements.txt DELETE_LOCAL_IMAGE=true
53 changes: 34 additions & 19 deletions docker_maxdiffusion_image_upload.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
# (minutes). However, if you are simply changing local code and not updating dependencies, uploading just takes a few seconds.

# Example command:
# bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=latest MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
# bash docker_maxdiffusion_image_upload.sh MODE=stable PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=latest MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt

# MODE selects the build flavor: {stable|nightly}. It is optional and defaults to stable.

set -e

Expand All @@ -34,11 +36,6 @@ for ARGUMENT in "$@"; do
echo "$KEY"="$VALUE"
done

if [[ ! -v BASEIMAGE ]]; then
echo "Erroring out because BASEIMAGE is unset, please set it!"
exit 1
fi

if [[ ! -v PROJECT_ID ]]; then
echo "Erroring out because PROJECT_ID is unset, please set it!"
exit 1
Expand All @@ -59,32 +56,50 @@ if [[ ! -v MAXDIFFUSION_REQUIREMENTS_FILE ]]; then
exit 1
fi

# Fall back to the stable build when the caller did not pass MODE=<value>.
# The value of the variable must be tested here: a bare word such as `-z MODE` checks the
# literal string "MODE", which is never empty, so the default would never be applied.
# `${MODE:-}` keeps the test safe even if the script is later run with `set -u`.
if [[ -z "${MODE:-}" ]]; then
  export MODE=stable
  echo "Default MODE=${MODE}"
fi

# Default: Don't delete local image
DELETE_LOCAL_IMAGE="${DELETE_LOCAL_IMAGE:-false}"

gcloud auth configure-docker us-docker.pkg.dev --quiet

COMMIT_HASH=$(git rev-parse --short HEAD)

echo "Building JAX Stable Stack MaxDiffusion at commit hash ${COMMIT_HASH} . . ."

IMAGE_DATE=$(date +%Y-%m-%d)

FULL_IMAGE_TAG=${IMAGE_TAG}-${IMAGE_DATE}

docker build --no-cache \
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \
--build-arg COMMIT_HASH=${COMMIT_HASH} \
--build-arg MAXDIFFUSION_REQUIREMENTS_FILE=${MAXDIFFUSION_REQUIREMENTS_FILE} \
--network=host \
-t us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG} \
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile .
IMAGE=us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${IMAGE_TAG}-${IMAGE_DATE}

# Build the image for the requested MODE and tag it as ${IMAGE}.
#   MODE=nightly -> maxdiffusion_tpu.Dockerfile; only the requirements file is passed as a build arg.
#   anything else -> maxdiffusion_jax_stable_stack_tpu.Dockerfile; BASEIMAGE is required and is
#                    passed together with the commit hash and the requirements file.
# All expansions are quoted so that values containing spaces or glob characters reach
# `docker build` as single arguments.
if [[ "${MODE}" == "nightly" ]]; then
  echo "Building MaxDiffusion with JAX and JAXLIB nightly at commit hash ${COMMIT_HASH} . . ."
  docker build --no-cache \
    --build-arg MAXDIFFUSION_REQUIREMENTS_FILE="${MAXDIFFUSION_REQUIREMENTS_FILE}" \
    --network=host \
    -t "${IMAGE}" \
    -f maxdiffusion_tpu.Dockerfile .
else
  echo "Building JAX Stable Stack MaxDiffusion at commit hash ${COMMIT_HASH} . . ."
  # BASEIMAGE is only needed on this path, so it is validated here rather than up front.
  if [[ ! -v BASEIMAGE ]]; then
    echo "Erroring out because BASEIMAGE is unset, please set it!"
    exit 1
  fi
  docker build --no-cache \
    --build-arg JAX_STABLE_STACK_BASEIMAGE="${BASEIMAGE}" \
    --build-arg COMMIT_HASH="${COMMIT_HASH}" \
    --build-arg MAXDIFFUSION_REQUIREMENTS_FILE="${MAXDIFFUSION_REQUIREMENTS_FILE}" \
    --network=host \
    -t "${IMAGE}" \
    -f maxdiffusion_jax_stable_stack_tpu.Dockerfile .
fi

docker push us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG}
docker push ${IMAGE}

echo "All done, check out your artifacts at: us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG}"
echo "All done, check out your artifacts at: ${IMAGE}"

if [ "$DELETE_LOCAL_IMAGE" == "true" ]; then
docker rmi us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG}
docker rmi ${IMAGE}
echo "Local image deleted."
fi
72 changes: 72 additions & 0 deletions maxdiffusion_tpu.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Base image: the official Python 3.10 image, Debian bullseye "slim" variant.
FROM python:3.10-slim-bullseye

# Keep apt/debconf non-interactive during the build only. An ARG is visible to every later RUN
# in this stage but, unlike ENV, is not baked into the environment of containers started from
# the image.
ARG DEBIAN_FRONTEND=noninteractive

# PIP_NO_CACHE_DIR / PIP_ROOT_USER_ACTION: pip keeps no cache in the image and does not warn
# about running as root.
# PYTHON_VERSION / CLOUD_SDK_VERSION: nothing in this Dockerfile reads them; they are kept in the
# image environment in case tooling outside this file does.
ENV PIP_NO_CACHE_DIR=1 \
    PIP_ROOT_USER_ACTION=ignore \
    PYTHON_VERSION=3.10 \
    CLOUD_SDK_VERSION=latest

# Upgrade pip to the latest version
RUN python -m pip install --upgrade pip --no-warn-script-location

# Install system dependencies. update + install + list cleanup share one layer so the package
# index is never stale and is not shipped in the image.
RUN apt-get update && apt-get install -y \
      apt-utils \
      curl \
      ethtool \
      git \
      gnupg \
      iproute2 \
      procps \
    && rm -rf /var/lib/apt/lists/*

# Add the Google Cloud SDK package repository.
# The signing key is downloaded to a file before being dearmored instead of being piped into gpg:
# RUN uses /bin/sh without pipefail, so a failed curl inside a pipeline would not fail this step.
RUN curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg -o /tmp/cloud.google.key && \
    gpg --output /usr/share/keyrings/cloud.google.gpg --dearmor /tmp/cloud.google.key && \
    rm -f /tmp/cloud.google.key && \
    echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" \
      > /etc/apt/sources.list.d/google-cloud-sdk.list

# Install the Google Cloud SDK from the repository added above
RUN apt-get update && apt-get install -y google-cloud-sdk && rm -rf /var/lib/apt/lists/*

# Diagnostics helper: cloud-accelerator-diagnostics
RUN pip install \
      cloud-accelerator-diagnostics

# Diagnostics helper: cloud-tpu-diagnostics
RUN pip install \
      cloud-tpu-diagnostics

# Filesystem interface for GCS: gcsfs
RUN pip install \
      gcsfs

# Google Cloud Storage client library
RUN pip install \
      google-cloud-storage

# jax nightly: pre-releases allowed, upgraded in place, with the jax nightly release page
# as an extra place to look for packages
RUN pip install --pre --upgrade \
      jax \
      --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html

# jaxlib nightly: same flags, resolved against the jaxlib nightly release page
RUN pip install --pre --upgrade \
      jaxlib \
      --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# libtpu nightly: resolved against the libtpu release page
RUN pip install \
      libtpu-nightly \
      --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
      --upgrade --pre

# Nightly build of the TensorBoard profile plugin
RUN pip install --upgrade \
      tbp-nightly


# Set the working directory in the container (created if it does not exist)
WORKDIR /deps

# Copy the whole build context into /deps and list it in the build log
COPY . .
RUN ls .

# Requirements file to install, given as a path relative to the build context root.
# Supplied by the caller via --build-arg; empty when not provided.
ARG MAXDIFFUSION_REQUIREMENTS_FILE

# Install MaxDiffusion requirements. Skipped when no requirements file was passed.
RUN if [ -n "${MAXDIFFUSION_REQUIREMENTS_FILE}" ]; then \
      echo "Using MaxDiffusion requirements: ${MAXDIFFUSION_REQUIREMENTS_FILE}" && \
      pip install -r "/deps/${MAXDIFFUSION_REQUIREMENTS_FILE}"; \
    fi

# Install MaxDiffusion from the copied sources
RUN pip install .

# No trailing "rm -rf /root/.cache/pip" step: PIP_NO_CACHE_DIR=1 is set for this image, so pip
# writes no cache, and a delete in a later layer could not shrink the earlier layers anyway.

0 comments on commit e68bbc0

Please sign in to comment.