From e68bbc03432ca49d9f57b796a4420039c615d93c Mon Sep 17 00:00:00 2001 From: Param Bole Date: Mon, 16 Sep 2024 14:00:51 -0700 Subject: [PATCH] Adding new build capabilities to maxdiffusion (#103) * Nightly Docker Image build * Updated flow to generate this nightly image --- .github/workflows/UploadDockerImages.yml | 7 ++- docker_maxdiffusion_image_upload.sh | 53 ++++++++++------- maxdiffusion_tpu.Dockerfile | 72 ++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 21 deletions(-) create mode 100644 maxdiffusion_tpu.Dockerfile diff --git a/.github/workflows/UploadDockerImages.yml b/.github/workflows/UploadDockerImages.yml index ef891344..d0d39cba 100644 --- a/.github/workflows/UploadDockerImages.yml +++ b/.github/workflows/UploadDockerImages.yml @@ -27,6 +27,9 @@ jobs: runs-on: ["self-hosted", "e2", "cpu"] steps: - uses: actions/checkout@v3 - - name: build jax stable stack image + - name: build maxdiffusion jax stable stack image run: | - bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=jax0.4.30-rev1 MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt DELETE_LOCAL_IMAGE=true \ No newline at end of file + bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=jax0.4.30-rev1 MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt DELETE_LOCAL_IMAGE=true + - name: build maxdiffusion jax nightly image + run: | + bash docker_maxdiffusion_image_upload.sh MODE=nightly PROJECT_ID=tpu-prod-env-multipod CLOUD_IMAGE_NAME=maxdiffusion-jax-nightly IMAGE_TAG=auto MAXDIFFUSION_REQUIREMENTS_FILE=requirements.txt DELETE_LOCAL_IMAGE=true \ No newline at end of file diff --git a/docker_maxdiffusion_image_upload.sh b/docker_maxdiffusion_image_upload.sh index 43f1c105..da7700d0 100644 --- a/docker_maxdiffusion_image_upload.sh +++ b/docker_maxdiffusion_image_upload.sh @@ -21,7 +21,9 @@ # (minutes). However, if you are simply changing local code and not updating dependencies, uploading just takes a few seconds. # Example command: -# bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=latest MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt +# bash docker_maxdiffusion_image_upload.sh MODE=stable PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=latest MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt + +# You need to specify a MODE {stable|nightly}, default value stable. set -e @@ -34,11 +36,6 @@ for ARGUMENT in "$@"; do echo "$KEY"="$VALUE" done -if [[ ! -v BASEIMAGE ]]; then - echo "Erroring out because BASEIMAGE is unset, please set it!" - exit 1 -fi - if [[ ! -v PROJECT_ID ]]; then echo "Erroring out because PROJECT_ID is unset, please set it!" exit 1 @@ -59,6 +56,11 @@ if [[ ! -v MAXDIFFUSION_REQUIREMENTS_FILE ]]; then exit 1 fi +if [[ -z MODE ]]; then + export MODE=stable + echo "Default MODE=${MODE}" +fi + # Default: Don't delete local image DELETE_LOCAL_IMAGE="${DELETE_LOCAL_IMAGE:-false}" @@ -66,25 +68,38 @@ gcloud auth configure-docker us-docker.pkg.dev --quiet COMMIT_HASH=$(git rev-parse --short HEAD) -echo "Building JAX Stable Stack MaxDiffusion at commit hash ${COMMIT_HASH} . . ." IMAGE_DATE=$(date +%Y-%m-%d) -FULL_IMAGE_TAG=${IMAGE_TAG}-${IMAGE_DATE} - -docker build --no-cache \ - --build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \ - --build-arg COMMIT_HASH=${COMMIT_HASH} \ - --build-arg MAXDIFFUSION_REQUIREMENTS_FILE=${MAXDIFFUSION_REQUIREMENTS_FILE} \ - --network=host \ - -t us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG} \ - -f maxdiffusion_jax_stable_stack_tpu.Dockerfile . +IMAGE=us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${IMAGE_TAG}-${IMAGE_DATE} + +if [[ "${MODE}" == "nightly" ]]; then + echo "Building MaxDiffusion with JAX and JAXLIB nightly at commit hash ${COMMIT_HASH} . . ." + docker build --no-cache \ + --build-arg MAXDIFFUSION_REQUIREMENTS_FILE=${MAXDIFFUSION_REQUIREMENTS_FILE} \ + --network=host \ + -t ${IMAGE} \ + -f maxdiffusion_tpu.Dockerfile . +else + echo "Building JAX Stable Stack MaxDiffusion at commit hash ${COMMIT_HASH} . . ." + if [[ ! -v BASEIMAGE ]]; then + echo "Erroring out because BASEIMAGE is unset, please set it!" + exit 1 + fi + docker build --no-cache \ + --build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \ + --build-arg COMMIT_HASH=${COMMIT_HASH} \ + --build-arg MAXDIFFUSION_REQUIREMENTS_FILE=${MAXDIFFUSION_REQUIREMENTS_FILE} \ + --network=host \ + -t ${IMAGE} \ + -f maxdiffusion_jax_stable_stack_tpu.Dockerfile . +fi -docker push us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG} +docker push ${IMAGE} -echo "All done, check out your artifacts at: us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG}" +echo "All done, check out your artifacts at: ${IMAGE}" if [ "$DELETE_LOCAL_IMAGE" == "true" ]; then - docker rmi us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG} + docker rmi ${IMAGE} echo "Local image deleted." fi \ No newline at end of file diff --git a/maxdiffusion_tpu.Dockerfile b/maxdiffusion_tpu.Dockerfile new file mode 100644 index 00000000..af647547 --- /dev/null +++ b/maxdiffusion_tpu.Dockerfile @@ -0,0 +1,72 @@ +# Use Python 3.10-slim-bullseye as the base image +FROM python:3.10-slim-bullseye + +# Environment variable for no-cache-dir and pip root user warning +ENV PIP_NO_CACHE_DIR=1 +ENV PIP_ROOT_USER_ACTION=ignore + +# Set environment variables for Google Cloud SDK and Python 3.10 +ENV PYTHON_VERSION=3.10 +ENV CLOUD_SDK_VERSION=latest + +# Set DEBIAN_FRONTEND to noninteractive to avoid frontend errors +ENV DEBIAN_FRONTEND=noninteractive + +# Upgrade pip to the latest version +RUN python -m pip install --upgrade pip --no-warn-script-location + +# Install system dependencies +RUN apt-get update && apt-get install -y apt-utils git curl gnupg procps iproute2 ethtool && rm -rf /var/lib/apt/lists/* + +# Add the Google Cloud SDK package repository +RUN curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \ + echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee /etc/apt/sources.list.d/google-cloud-sdk.list + +# Install the Google Cloud SDK +RUN apt-get update && apt-get install -y google-cloud-sdk && rm -rf /var/lib/apt/lists/* + +# Install cloud-accelerator-diagnostics +RUN pip install cloud-accelerator-diagnostics + +# Install cloud-tpu-diagnostics +RUN pip install cloud-tpu-diagnostics + +# Install gcsfs +RUN pip install gcsfs + +# Install google-cloud-storage +RUN pip install google-cloud-storage + +# Install jax-nightly +RUN pip install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + +# Install jaxlib-nightly +RUN pip install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + +# Install libtpu-nightly +RUN pip install libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -U --pre + +# Installing nightly tensorboard plugin profile +RUN pip install tbp-nightly --upgrade + + +# Set the working directory in the container +WORKDIR /deps + +# Copy all files from local workspace into docker container +COPY . . +RUN ls . + +ARG MAXDIFFUSION_REQUIREMENTS_FILE + +# Install Maxdiffusion requirements +RUN if [ ! -z "${MAXDIFFUSION_REQUIREMENTS_FILE}" ]; then \ + echo "Using MaxDiffusion requirements: ${MAXDIFFUSION_REQUIREMENTS_FILE}" && \ + pip install -r /deps/${MAXDIFFUSION_REQUIREMENTS_FILE}; \ + fi + +# Install MaxDiffusion +RUN pip install . + +# Cleanup +RUN rm -rf /root/.cache/pip \ No newline at end of file