diff --git a/.github/workflows/build_triton_and_ft.yml b/.github/workflows/build_triton_and_ft.yml index 2893e1567ff..beeff41a592 100644 --- a/.github/workflows/build_triton_and_ft.yml +++ b/.github/workflows/build_triton_and_ft.yml @@ -1,4 +1,4 @@ -name: Build Triton Server and FasterTransformers +name: Build Triton Server on: workflow_dispatch: @@ -6,36 +6,28 @@ on: triton: description: 'triton branch version' required: true - default: 'r23.04' - fastertransformer: - description: 'fastertransformer branch/tag version' - required: true - default: 'main' - is_llama_build: - description: 'whether to build custom llama source' - required: false - type: boolean - default: false + default: 'r23.10' jobs: build-triton: if: github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: repository: triton-inference-server/server ref: ${{ github.event.inputs.triton }} - name: Set up Python3 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.x' - name: Build Triton Binary shell: 'script -q -e -c "bash --noprofile --norc -eo pipefail {0}"' run: | - python3 build.py --enable-logging --enable-metrics --enable-stats --enable-cpu-metrics --endpoint http + pip3 install requests + python3 build.py --enable-logging --enable-metrics --enable-stats --enable-cpu-metrics --enable-gpu --endpoint http - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v2 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -45,60 +37,3 @@ jobs: aws s3 cp build/install/lib/libtritonserver.so s3://djl-ai/publish/tritonserver/${{ github.event.inputs.triton }}/ aws s3 cp build/install/bin/tritonserver s3://djl-ai/publish/tritonserver/${{ github.event.inputs.triton }}/ aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/tritonserver/${{ github.event.inputs.triton }}/*" - - create-runner: - if: github.repository == 'deepjavalibrary/djl' - runs-on: [ self-hosted, scheduler ] - steps: - - name: Create new CPU instance - id: create_cpu - run: | - cd /home/ubuntu/djl_benchmark_script/scripts - token=$( curl -X POST -H "Authorization: token ${{ secrets.ACTION_RUNNER_PERSONAL_TOKEN }}" \ - https://api.github.com/repos/deepjavalibrary/djl/actions/runners/registration-token \ - --fail \ - | jq '.token' | tr -d '"' ) - ./start_instance.sh action_cpu $token djl - outputs: - cpu_instance_id: ${{ steps.create_cpu.outputs.action_cpu_instance_id }} - - - build-fastertransformer: - if: github.repository == 'deepjavalibrary/djl' - runs-on: [ self-hosted, cpu ] - container: deepjavalibrary/djl-serving:fastertransformer-nightly - timeout-minutes: 60 - needs: create-runner - steps: - - uses: actions/checkout@v3 - - name: Build FasterTransformers - run: | - tools/scripts/build_ft_deps.sh ${{ github.event.inputs.fastertransformer }} ${{ github.event.inputs.triton }} ${{ github.event.inputs.is_llama_build }} - - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v2 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: us-east-2 - - name: Copy files to S3 with the AWS CLI - if: github.event.inputs.is_llama_build == 'false' - run: | - aws s3 sync /tmp/binaries/ s3://djl-ai/publish/fastertransformer/${{ github.event.inputs.fastertransformer }}/ - aws cloudfront 
create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/fastertransformer/${{ github.event.inputs.fastertransformer }}/*" - - name: Copy files for llama build to S3 with AWS CLI - if: github.event.inputs.is_llama_build == 'true' - run: | - echo "pushing binaries to ft/llama" - aws s3 sync /tmp/binaries/ s3://djl-ai/publish/fastertransformer/llama/ - aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/fastertransformer-llama/${{ github.event.inputs.fastertransformer }}/*" - - stop-runner: - if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} - runs-on: [ self-hosted, scheduler ] - needs: [ create-runner, build-fastertransformer] - steps: - - name: Stop all instances - run: | - cd /home/ubuntu/djl_benchmark_script/scripts - instance_id=${{ needs.create-runner.outputs.cpu_instance_id }} - ./stop_instance.sh $instance_id diff --git a/.github/workflows/codeql-analysis-java.yml b/.github/workflows/codeql-analysis-java.yml index 107efb92286..19fe93e1737 100644 --- a/.github/workflows/codeql-analysis-java.yml +++ b/.github/workflows/codeql-analysis-java.yml @@ -34,20 +34,20 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - name: Init gradle run: ./gradlew --no-daemon clean # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -58,7 +58,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v2 + uses: github/codeql-action/autobuild@v3 # ℹī¸ Command-line programs to run using the OS shell. 
# 📚 https://git.io/JvXDl @@ -72,4 +72,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/continuous.yml b/.github/workflows/continuous.yml index 5cfa1503ed5..31b759cfb2e 100644 --- a/.github/workflows/continuous.yml +++ b/.github/workflows/continuous.yml @@ -10,6 +10,15 @@ on: - "**.js" - "**.css" - "android/**" + push: + paths-ignore: + - "**.md" + - "**.ipynb" + - "**.json" + - "**.html" + - "**.js" + - "**.css" + - "android/**" jobs: build: @@ -17,17 +26,17 @@ jobs: runs-on: ${{ matrix.operating-system }} strategy: matrix: - operating-system: [ ubuntu-latest, macos-12 ] + operating-system: [ ubuntu-latest, macos-13 ] steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -41,6 +50,9 @@ jobs: filters: | src: - 'extensions/sentencepiece/**' + - name: install libomp on macos + if: ${{ runner.os == 'macOS' }} + run: brew install libomp - name: Compile Sentencepiece JNI if: steps.sentencepiece_changes.outputs.src == 'true' run: ./gradlew :extensions:sentencepiece:compileJNI @@ -106,14 +118,14 @@ jobs: if: github.repository == 'deepjavalibrary/djl' runs-on: windows-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} diff --git a/.github/workflows/docker_publish.yml b/.github/workflows/docker_publish.yml index d183a0e2787..a85ea409f4e 100644 --- a/.github/workflows/docker_publish.yml +++ b/.github/workflows/docker_publish.yml @@ -15,11 +15,11 @@ jobs: if: github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Docker buildx uses: docker/setup-buildx-action@v2 - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -56,4 +56,4 @@ jobs: context: . 
file: docker/spark/Dockerfile build-args: DJL_VERSION=${DJL_VERSION} - tags: deepjavalibrary/djl-spark:${{ env.DJL_VERSION }}-cpu \ No newline at end of file + tags: deepjavalibrary/djl-spark:${{ env.DJL_VERSION }}-cpu diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 45e2177466d..de303ed5486 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -3,7 +3,6 @@ name: Docs on: pull_request: paths: - - "**.ipynb" - "docs/mkdocs.yml" # Publish docs weekly schedule: @@ -15,13 +14,13 @@ jobs: if: github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest steps: - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - name: Set up Python3 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.x' - name: Install CN fonts @@ -34,14 +33,10 @@ jobs: cd IJava/ ./gradlew installKernel - name: checkout repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: add mybinder link run: | python3 tools/scripts/add_online_runner.py - - name: run Notebooks - run: | - cd jupyter - bash test_notebook.sh - name: clone demos run: | cd docs @@ -50,13 +45,18 @@ jobs: run: | cd docs git clone https://github.com/deepjavalibrary/djl-serving.git serving + - name: run Notebooks + run: | + cd docs/demos/jupyter + bash test_notebook.sh - name: build docs run: | cd docs + export DJL_DISABLE_PROGRESS_BAR=true mkdocs build --site-dir ../../site - name: Configure AWS Credentials if: github.event_name != 'pull_request' - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/native_build_mxnet_osx.yml b/.github/workflows/native_build_mxnet_osx.yml index 4728ea1755b..1d8d9245f9a 100644 --- a/.github/workflows/native_build_mxnet_osx.yml +++ b/.github/workflows/native_build_mxnet_osx.yml @@ -9,7 +9,7 @@ jobs: steps: - name: Checkout Apache MXNet repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: apache/incubator-mxnet ref: 1.9.1 diff --git a/.github/workflows/native_jni_s3_paddle.yml b/.github/workflows/native_jni_s3_paddle.yml index 3cb9a62f7c9..752ba39c685 100644 --- a/.github/workflows/native_jni_s3_paddle.yml +++ b/.github/workflows/native_jni_s3_paddle.yml @@ -14,13 +14,13 @@ jobs: operating-system: [ macos-latest, windows-latest ] steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 - - uses: actions/cache@v3 + java-version: 17 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -55,11 +55,11 @@ jobs: ln -sf /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -81,7 +81,7 @@ jobs: runs-on: ubuntu-latest needs: [ build-paddle-jni-cpu, build-paddle-jni-linux ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Download compiledJNI Mac uses: actions/download-artifact@v3 with: @@ -98,7 +98,7 @@ 
jobs: name: jnilib-Linux path: jnilib - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/native_jni_s3_pytorch.yml b/.github/workflows/native_jni_s3_pytorch.yml index 7bd2edc1677..1f1a220e8c2 100644 --- a/.github/workflows/native_jni_s3_pytorch.yml +++ b/.github/workflows/native_jni_s3_pytorch.yml @@ -15,13 +15,13 @@ jobs: runs-on: macos-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 - - uses: actions/cache@v3 + java-version: 17 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -36,7 +36,7 @@ jobs: ./gradlew :engines:pytorch:pytorch-native:compileJNI -Ppt_version=$PYTORCH_VERSION ./gradlew -Pjni -Ppt_version=$PYTORCH_VERSION :integration:test "-Dai.djl.default_engine=PyTorch" - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -55,11 +55,11 @@ jobs: container: nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04 steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -91,7 +91,7 @@ jobs: if [[ "$PYTORCH_VERSION" == "1.12.1" ]]; then ./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcu10 -Ppt_version=$PYTORCH_VERSION; fi if [[ "$PYTORCH_VERSION" == "1.11.0" ]]; then ./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcu10 -Ppt_version=$PYTORCH_VERSION; fi - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v3 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -116,11 +116,11 @@ jobs: ln -s /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -137,9 +137,10 @@ jobs: ./gradlew :engines:pytorch:pytorch-native:compileJNI -Pprecxx11 -Ppt_version=$PYTORCH_VERSION ./gradlew -Pjni -Ppt_version=$PYTORCH_VERSION :integration:test "-Dai.djl.default_engine=PyTorch" ./gradlew :engines:pytorch:pytorch-native:cleanJNI + rm -rf ~/.djl.ai ./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcu11 -Pprecxx11 -Ppt_version=$PYTORCH_VERSION - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v2 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -155,13 +156,13 @@ jobs: if: github.repository == 'deepjavalibrary/djl' runs-on: windows-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 
with: distribution: 'corretto' - java-version: 11 - - uses: actions/cache@v3 + java-version: 17 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -190,7 +191,7 @@ jobs: set "PATH=%CUDA_PATH%\bin;%CUDA_PATH%\libnvvp;%PATH%" gradlew :engines:pytorch:pytorch-native:cleanJNI :engines:pytorch:pytorch-native:compileJNI -Pcu11 -Ppt_version=${{ github.event.inputs.pt_version }} - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -205,16 +206,16 @@ jobs: build-pytorch-jni-arm64-macos: if: github.repository == 'deepjavalibrary/djl' - runs-on: [ self-hosted, ARM64, macOS ] + runs-on: macos-latest-xlarge steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -229,7 +230,7 @@ jobs: ./gradlew :engines:pytorch:pytorch-native:compileJNI -Ppt_version=$PYTORCH_VERSION ./gradlew -Pjni -Ppt_version=$PYTORCH_VERSION :integration:test "-Dai.djl.default_engine=PyTorch" - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -259,6 +260,7 @@ jobs: aarch64_instance_id: ${{ steps.create_aarch64.outputs.action_graviton_instance_id }} build-pytorch-jni-aarch64: + if: github.repository == 'deepjavalibrary/djl' runs-on: [ self-hosted, aarch64 ] container: amazonlinux:2 timeout-minutes: 30 @@ -268,21 +270,21 @@ jobs: run: | yum -y update yum -y groupinstall "Development Tools" - yum -y install patch git cmake3 python3-devel java-11-amazon-corretto + yum -y install patch git cmake3 python3-devel java-17-amazon-corretto-devel ln -sf /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade - uses: actions/checkout@v3 - name: Release JNI prep run: | - export JAVA_HOME=/usr/lib/jvm/java-11-amazon-corretto.aarch64 - export PATH=$PATH:$JAVA_HOME + export JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.aarch64 + export PATH=$PATH:$JAVA_HOME/bin PYTORCH_VERSION=${{ github.event.inputs.pt_version }} export PYTORCH_VERSION=${PYTORCH_VERSION:-$(cat gradle.properties | awk -F '=' '/pytorch_version/ {print $2}')} echo $PYTORCH_VERSION ./gradlew :engines:pytorch:pytorch-native:compileJNI -Pprecxx11 -Ppt_version=$PYTORCH_VERSION ./gradlew -Pjni -Ppt_version=$PYTORCH_VERSION :integration:test "-Dai.djl.default_engine=PyTorch" - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v2 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/native_jni_s3_pytorch_android.yml b/.github/workflows/native_jni_s3_pytorch_android.yml index c3834057c9e..0376856bb33 100644 --- a/.github/workflows/native_jni_s3_pytorch_android.yml +++ b/.github/workflows/native_jni_s3_pytorch_android.yml @@ -7,32 +7,33 @@ on: jobs: build-pytorch-jni-android: + if: 
github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest env: NDK_VERSION: "21.1.6352462" steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 - - uses: actions/cache@v3 + java-version: 17 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} restore-keys: | ${{ runner.os }}-gradle- - name: Install NDK - run: echo "y" | sudo ${ANDROID_HOME}/tools/bin/sdkmanager --install "ndk;${NDK_VERSION}" + run: echo "y" | sudo ${ANDROID_HOME}/cmdline-tools/latest/bin/sdkmanager --install "ndk;${NDK_VERSION}" - name: build android run: | export ANDROID_NDK=${ANDROID_SDK_ROOT}/ndk/${NDK_VERSION} PYTORCH_VERSION=${PYTORCH_VERSION:-$(cat gradle.properties | awk -F '=' '/pytorch_version/ {print $2}')} ./gradlew :engines:pytorch:pytorch-native:compileAndroidJNI -Ppt_version=${PYTORCH_VERSION} - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/native_jni_s3_tensorrt.yml b/.github/workflows/native_jni_s3_tensorrt.yml index cf39c6e070a..1711a2d4e82 100644 --- a/.github/workflows/native_jni_s3_tensorrt.yml +++ b/.github/workflows/native_jni_s3_tensorrt.yml @@ -6,17 +6,12 @@ on: jobs: build-tensorrt-jni-linux: runs-on: ubuntu-latest - container: deepjavalibrary/ubuntu18.04:tensorrt-cuda116 + container: deepjavalibrary/ubuntu20.04:tensorrt-cuda122 steps: - name: Install Environment run: pip3 install awscli --upgrade - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 - with: - distribution: 'corretto' - java-version: 11 - - uses: actions/cache@v3 + - uses: actions/checkout@v4 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -25,7 +20,7 @@ jobs: - name: Release JNI prep run: ./gradlew :engines:tensorrt:compileJNI - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/native_publish_mxnet.yml b/.github/workflows/native_publish_mxnet.yml index b6ff595e9cf..99940cbf075 100644 --- a/.github/workflows/native_publish_mxnet.yml +++ b/.github/workflows/native_publish_mxnet.yml @@ -13,14 +13,14 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} diff --git a/.github/workflows/native_publish_paddle.yml b/.github/workflows/native_publish_paddle.yml index bfcdbd42708..a724418f25d 100644 --- a/.github/workflows/native_publish_paddle.yml +++ b/.github/workflows/native_publish_paddle.yml @@ -13,14 +13,14 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 
- - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} diff --git a/.github/workflows/native_publish_pytorch.yml b/.github/workflows/native_publish_pytorch.yml index f3ef6496a44..8dd65271ffc 100644 --- a/.github/workflows/native_publish_pytorch.yml +++ b/.github/workflows/native_publish_pytorch.yml @@ -13,14 +13,14 @@ jobs: runs-on: macos-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} diff --git a/.github/workflows/native_publish_tensorflow.yml b/.github/workflows/native_publish_tensorflow.yml index b4d95d768ba..82ef95f1875 100644 --- a/.github/workflows/native_publish_tensorflow.yml +++ b/.github/workflows/native_publish_tensorflow.yml @@ -13,14 +13,14 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} diff --git a/.github/workflows/native_publish_tflite.yml b/.github/workflows/native_publish_tflite.yml index c88d2f00491..cca397a0487 100644 --- a/.github/workflows/native_publish_tflite.yml +++ b/.github/workflows/native_publish_tflite.yml @@ -13,14 +13,14 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} diff --git a/.github/workflows/native_s3_fasttext.yml b/.github/workflows/native_s3_fasttext.yml index 948ebabf478..8dd34e6d3c6 100644 --- a/.github/workflows/native_s3_fasttext.yml +++ b/.github/workflows/native_s3_fasttext.yml @@ -7,13 +7,13 @@ jobs: build-fasttext-jni-osx: runs-on: macos-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 - - uses: actions/cache@v3 + java-version: 17 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -24,7 +24,7 @@ jobs: ./gradlew :extensions:fasttext:compileJNI ./gradlew -Pjni 
:extensions:fasttext:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -46,11 +46,11 @@ jobs: ln -sf /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -63,7 +63,7 @@ jobs: ./gradlew :extensions:fasttext:compileJNI ./gradlew -Pjni :extensions:fasttext:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v2 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -75,16 +75,16 @@ jobs: build-fasttext-jni-arm64-osx: if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} - runs-on: [ self-hosted, ARM64, macOS ] + runs-on: macos-latest-xlarge steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -95,7 +95,7 @@ jobs: ./gradlew :extensions:fasttext:compileJNI ./gradlew -Pjni :extensions:fasttext:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/native_s3_huggingface.yml b/.github/workflows/native_s3_huggingface.yml index d9ce2d29197..e1ec9752438 100644 --- a/.github/workflows/native_s3_huggingface.yml +++ b/.github/workflows/native_s3_huggingface.yml @@ -7,16 +7,16 @@ jobs: build-tokenizers-jni-osx: runs-on: macos-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions-rs/toolchain@v1 with: toolchain: stable - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -27,7 +27,7 @@ jobs: ./gradlew :extensions:tokenizers:compileJNI ./gradlew -Pjni :extensions:tokenizers:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -45,7 +45,7 @@ jobs: - name: Install Environment run: | yum -y update - yum -y install centos-release-scl-rh epel-release + yum -y install centos-release-scl-rh epel-release perl-core yum -y install devtoolset-7 git patch cmake3 libstdc++-static ln -s /usr/bin/cmake3 /usr/bin/cmake curl https://sh.rustup.rs -sSf | sh -s -- -y @@ -54,11 +54,11 @@ jobs: with: toolchain: stable - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: 
distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -72,7 +72,7 @@ jobs: ./gradlew :extensions:tokenizers:compileJNI PYTORCH_PRECXX11=true ./gradlew -Pjni :extensions:tokenizers:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v2 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -92,13 +92,13 @@ jobs: - uses: actions-rs/toolchain@v1 with: toolchain: stable - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 - - uses: actions/cache@v3 + java-version: 17 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -110,7 +110,7 @@ jobs: call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" amd64 gradlew :extensions:tokenizer:compileJNI - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -124,19 +124,19 @@ jobs: build-tokenizers-jni-arm64-osx: if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} - runs-on: [ self-hosted, ARM64, macOS ] + runs-on: macos-latest-xlarge steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - uses: actions-rs/toolchain@v1 with: toolchain: stable - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -147,12 +147,13 @@ jobs: ./gradlew :extensions:tokenizers:compileJNI ./gradlew -Pjni :extensions:tokenizers:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} aws-region: us-east-2 - name: Copy files to S3 with the AWS CLI + shell: bash run: | TOKENIZERS_VERSION="$(cat gradle.properties | awk -F '=' '/tokenizers_version/ {print $2}')" aws s3 sync extensions/tokenizers/jnilib s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/ @@ -179,24 +180,23 @@ jobs: runs-on: [ self-hosted, aarch64 ] timeout-minutes: 30 needs: create-aarch64-runner - container: centos:centos7 + container: amazonlinux:2 steps: - name: Install Environment run: | yum -y update - yum -y install centos-release-scl-rh epel-release - yum -y install devtoolset-7 git patch cmake3 libstdc++-static + yum -y groupinstall "Development Tools" + yum -y install patch perl-IPC-Cmd cmake3 ln -s /usr/bin/cmake3 /usr/bin/cmake - curl https://sh.rustup.rs -sSf | sh -s -- -y pip3 install awscli --upgrade - uses: actions-rs/toolchain@v1 with: toolchain: stable - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - uses: actions/cache@v3 @@ -207,12 +207,10 @@ 
jobs: ${{ runner.os }}-gradle- - name: Release JNI prep run: | - source "$HOME/.cargo/env" - export PATH=$PATH:/opt/rh/devtoolset-7/root/usr/bin ./gradlew :extensions:tokenizers:compileJNI PYTORCH_PRECXX11=true ./gradlew -Pjni :extensions:tokenizers:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v2 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/native_s3_llama.yml b/.github/workflows/native_s3_llama.yml new file mode 100644 index 00000000000..8172fc4bbef --- /dev/null +++ b/.github/workflows/native_s3_llama.yml @@ -0,0 +1,204 @@ +name: Native S3 llama.cpp + +on: + workflow_dispatch: + +jobs: + build-llamacpp-jni-osx: + runs-on: macos-latest + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + distribution: 'corretto' + java-version: 17 + - uses: actions/cache@v4 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} + restore-keys: | + ${{ runner.os }}-gradle- + - name: Release JNI prep + run: | + ./gradlew :engines:llama:compileJNI + ./gradlew -Pjni :engines:llama:test -Dnightly=true + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Copy files to S3 with the AWS CLI + run: | + LLAMACPP_VERSION="$(cat gradle.properties | awk -F '=' '/llamacpp_version/ {print $2}')" + aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" + + build-llamacpp-jni-linux: + runs-on: ubuntu-latest + container: centos:centos7 + steps: + - name: Install Environment + run: | + yum -y update + yum -y install centos-release-scl-rh epel-release perl-core + yum -y install devtoolset-7 git patch cmake3 libstdc++-static + ln -s /usr/bin/cmake3 /usr/bin/cmake + pip3 install awscli --upgrade + - uses: actions/checkout@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v3 + with: + distribution: 'corretto' + java-version: 17 + - name: Release JNI prep + run: | + export PATH=$PATH:/opt/rh/devtoolset-7/root/usr/bin + ./gradlew :engines:llama:compileJNI + ./gradlew -Pjni :engines:llama:test -Dnightly=true + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v2 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Copy files to S3 with the AWS CLI + run: | + LLAMACPP_VERSION="$(cat gradle.properties | awk -F '=' '/llamacpp_version/ {print $2}')" + aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" + + build-llamacpp-jni-windows: + runs-on: windows-latest + steps: + - name: Install Environment + run: | + choco install -y mingw + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + distribution: 'corretto' + java-version: 17 + - uses: actions/cache@v4 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} + restore-keys: | + ${{ runner.os }}-gradle- + 
- name: Release CPU JNI + shell: cmd + run: | + call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" amd64 + gradlew :engines:llama:compileJNI + gradlew -Pjni :engines:llama:test -Dnightly=true + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Copy files to S3 with the AWS CLI + shell: bash + run: | + LLAMACPP_VERSION="$(cat gradle.properties | awk -F '=' '/llamacpp_version/ {print $2}')" + aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" + + build-llamacpp-jni-arm64-osx: + if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} + runs-on: macos-latest-xlarge + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: 17 + distribution: corretto + architecture: aarch64 + - uses: actions/cache@v4 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} + restore-keys: | + ${{ runner.os }}-gradle- + - name: Release JNI prep + run: | + ./gradlew :engines:llama:compileJNI + ./gradlew -Pjni :engines:llama:test -Dnightly=true + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Copy files to S3 with the AWS CLI + run: | + LLAMACPP_VERSION="$(cat gradle.properties | awk -F '=' '/llamacpp_version/ {print $2}')" + aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" + + create-aarch64-runner: + if: github.repository == 'deepjavalibrary/djl' + runs-on: [ self-hosted, scheduler ] + steps: + - name: Create new Graviton instance + id: create_aarch64 + run: | + cd /home/ubuntu/djl_benchmark_script/scripts + token=$( curl -X POST -H "Authorization: token ${{ secrets.ACTION_RUNNER_PERSONAL_TOKEN }}" \ + https://api.github.com/repos/deepjavalibrary/djl/actions/runners/registration-token \ + --fail \ + | jq '.token' | tr -d '"' ) + ./start_instance.sh action_graviton $token djl + outputs: + aarch64_instance_id: ${{ steps.create_aarch64.outputs.action_graviton_instance_id }} + + build-llamacpp-jni-aarch64: + if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} + runs-on: [ self-hosted, aarch64 ] + timeout-minutes: 30 + needs: create-aarch64-runner + container: amazonlinux:2 + steps: + - name: Install Environment + run: | + yum -y update + yum -y groupinstall "Development Tools" + yum -y install patch perl-IPC-Cmd cmake3 + ln -s /usr/bin/cmake3 /usr/bin/cmake + pip3 install awscli --upgrade + - uses: actions/checkout@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v3 + with: + java-version: 17 + distribution: corretto + architecture: aarch64 + - name: Release JNI prep + run: | + ./gradlew :engines:llama:compileJNI + ./gradlew -Pjni :engines:llama:test -Dnightly=true + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v2 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ 
secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Copy files to S3 with the AWS CLI + run: | + LLAMACPP_VERSION="$(cat gradle.properties | awk -F '=' '/llamacpp_version/ {print $2}')" + aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" + + stop-runners: + if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} + runs-on: [ self-hosted, scheduler ] + needs: [ create-aarch64-runner, build-llamacpp-jni-aarch64 ] + steps: + - name: Stop all instances + run: | + cd /home/ubuntu/djl_benchmark_script/scripts + instance_id=${{ needs.create-aarch64-runner.outputs.aarch64_instance_id }} + ./stop_instance.sh $instance_id diff --git a/.github/workflows/native_s3_pytorch.yml b/.github/workflows/native_s3_pytorch.yml index fa3800e45da..54f4dba3119 100644 --- a/.github/workflows/native_s3_pytorch.yml +++ b/.github/workflows/native_s3_pytorch.yml @@ -7,21 +7,21 @@ jobs: build: runs-on: macos-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} restore-keys: | ${{ runner.os }}-gradle- - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/native_s3_pytorch_android.yml b/.github/workflows/native_s3_pytorch_android.yml index 3e03319be02..8887704a91e 100644 --- a/.github/workflows/native_s3_pytorch_android.yml +++ b/.github/workflows/native_s3_pytorch_android.yml @@ -10,17 +10,17 @@ jobs: matrix: format: ["armeabi-v7a", "arm64-v8a", "x86" ,"x86_64"] env: - PYTORCH_VERSION: "2.0.1" + PYTORCH_VERSION: "2.1.2" NDK_VERSION: "21.1.6352462" steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - name: Set up Python3 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.x' - name: clone pytorch @@ -29,7 +29,7 @@ jobs: - name: install Python Dependencies run: pip install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions - name: Install NDK - run: echo "y" | sudo ${ANDROID_HOME}/tools/bin/sdkmanager --install "ndk;${NDK_VERSION}" + run: echo "y" | sudo ${ANDROID_HOME}/cmdline-tools/latest/bin/sdkmanager --install "ndk;${NDK_VERSION}" - name: build android run: | export ANDROID_NDK=${ANDROID_SDK_ROOT}/ndk/${NDK_VERSION} @@ -39,7 +39,7 @@ jobs: cd build_android zip -r ${{ matrix.format }}_native.zip install/include lib - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -48,8 +48,3 @@ jobs: run: | aws s3 cp 
android_pytorch_tmp/build_android/${{ matrix.format }}_native.zip s3://djl-ai/publish/pytorch/${PYTORCH_VERSION}/android_native/ aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/pytorch/${PYTORCH_VERSION}/android_native*" -# - name: Upload pytorch src -# uses: actions/upload-artifact@v3 -# with: -# name: pytorch-src-${{ matrix.format }} -# path: android_pytorch_tmp diff --git a/.github/workflows/native_s3_sentencepiece.yml b/.github/workflows/native_s3_sentencepiece.yml index 1e217ed218d..134cf012a30 100644 --- a/.github/workflows/native_s3_sentencepiece.yml +++ b/.github/workflows/native_s3_sentencepiece.yml @@ -8,13 +8,13 @@ jobs: if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} runs-on: macos-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 - - uses: actions/cache@v3 + java-version: 17 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -25,7 +25,7 @@ jobs: ./gradlew :extensions:sentencepiece:compileJNI ./gradlew -Pjni :extensions:sentencepiece:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -49,12 +49,12 @@ jobs: ln -sf /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 - - uses: actions/cache@v3 + java-version: 17 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -66,7 +66,7 @@ jobs: ./gradlew :extensions:sentencepiece:compileJNI ./gradlew -Pjni :extensions:sentencepiece:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v2 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -81,13 +81,13 @@ jobs: if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} runs-on: windows-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 - - uses: actions/cache@v3 + java-version: 17 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -98,7 +98,7 @@ jobs: ./gradlew :extensions:sentencepiece:compileJNI ./gradlew -Pjni :extensions:sentencepiece:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -112,19 +112,19 @@ jobs: build-sentencepiece-jni-arm64-osx: if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} - runs-on: [ self-hosted, ARM64, macOS ] + runs-on: macos-latest-xlarge steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: 
actions/setup-java@v4 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - uses: actions-rs/toolchain@v1 with: toolchain: stable - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -135,7 +135,7 @@ jobs: ./gradlew :extensions:sentencepiece:compileJNI ./gradlew -Pjni :extensions:sentencepiece:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -164,13 +164,13 @@ jobs: with: toolchain: stable - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -182,7 +182,7 @@ jobs: ./gradlew :extensions:sentencepiece:compileJNI ./gradlew -Pjni :extensions:sentencepiece:test - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v2 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/native_s3_tensorflow.yml b/.github/workflows/native_s3_tensorflow.yml index 3a119bbc7a1..4d88337b34f 100644 --- a/.github/workflows/native_s3_tensorflow.yml +++ b/.github/workflows/native_s3_tensorflow.yml @@ -7,21 +7,21 @@ jobs: upload: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} restore-keys: | ${{ runner.os }}-gradle- - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/native_s3_tflite.yml b/.github/workflows/native_s3_tflite.yml index a8544baf669..8298880e14c 100644 --- a/.github/workflows/native_s3_tflite.yml +++ b/.github/workflows/native_s3_tflite.yml @@ -7,23 +7,23 @@ jobs: build-osx: runs-on: macos-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Get TFLITE_VERSION run: | TFLITE_VERSION="$(cat gradle.properties | awk -F '=' '/tflite_version/ {print $2}')" echo "TFLITE_VERSION=${TFLITE_VERSION}" >> $GITHUB_ENV - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: repository: tensorflow/tensorflow ref: v${{ env.TFLITE_VERSION }} submodules: 'recursive' - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - name: Set up Python3 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.x' - name: install Python Dependencies @@ -33,7 +33,7 @@ 
jobs: cd tensorflow bazel build -c opt //tensorflow/lite/java:tensorflowlitelib //tensorflow/lite/delegates/flex:delegate - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -53,7 +53,7 @@ jobs: run: | yum -y update yum -y groupinstall "Development Tools" - yum -y install patch cmake3 unzip which java-11-amazon-corretto + yum -y install patch cmake3 unzip which java-17-amazon-corretto-devel ln -sf /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade pip3 install numpy --upgrade @@ -70,12 +70,12 @@ jobs: - name: build package run: | cd tensorflow - export JAVA_HOME=/usr/lib/jvm/java-11-amazon-corretto.x86_64/ + export JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/ curl -L https://github.com/bazelbuild/bazel/releases/download/3.7.2/bazel-3.7.2-installer-linux-x86_64.sh -o bazel.sh --retry 10 bash bazel.sh bazel build -c opt //tensorflow/lite/java:tensorflowlitelib //tensorflow/lite/delegates/flex:delegate - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v2 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/native_s3_xgboost.yml b/.github/workflows/native_s3_xgboost.yml index 3d92e1bd3a8..2e83554a253 100644 --- a/.github/workflows/native_s3_xgboost.yml +++ b/.github/workflows/native_s3_xgboost.yml @@ -34,28 +34,28 @@ jobs: run: | yum -y update yum -y install centos-release-scl-rh epel-release - yum -y install devtoolset-7 git patch libstdc++-static curl python3-devel + yum -y install devtoolset-8 git patch libstdc++-static curl python3-devel curl -L -o cmake.tar.gz https://github.com/Kitware/CMake/releases/download/v3.27.0-rc2/cmake-3.27.0-rc2-linux-aarch64.tar.gz tar xvfz cmake.tar.gz ln -sf $PWD/cmake-3.*/bin/cmake /usr/bin/cmake cmake --version pip3 install awscli --upgrade - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - name: Release JNI prep run: | XGBOOST_VERSION=${{ github.event.inputs.xgb_version }} XGBOOST_VERSION=${XGBOOST_VERSION:-$(cat gradle.properties | awk -F '=' '/xgboost_version/ {print $2}')} git clone https://github.com/dmlc/xgboost --recursive -b v"$XGBOOST_VERSION" - export PATH=$PATH:/opt/rh/devtoolset-7/root/usr/bin + export PATH=$PATH:/opt/rh/devtoolset-8/root/usr/bin cd xgboost/jvm-packages python3 create_jni.py cd ../.. 
- name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v2 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/nightly_android.yml b/.github/workflows/nightly_android.yml index 541d2ba9275..b46528bc29f 100644 --- a/.github/workflows/nightly_android.yml +++ b/.github/workflows/nightly_android.yml @@ -9,17 +9,18 @@ on: jobs: build: + if: github.repository == 'deepjavalibrary/djl' runs-on: macos-latest strategy: matrix: api-level: [ 26 ] steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - name: Gradle cache uses: gradle/gradle-build-action@v2 - name: run tests diff --git a/.github/workflows/nightly_publish.yml b/.github/workflows/nightly_publish.yml index 64ac23e852c..8d2fcc146b6 100644 --- a/.github/workflows/nightly_publish.yml +++ b/.github/workflows/nightly_publish.yml @@ -16,17 +16,17 @@ jobs: runs-on: ${{ matrix.operating-system }} strategy: matrix: - operating-system: [ macos-12, ubuntu-latest ] + operating-system: [ macos-13, ubuntu-latest ] steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -34,6 +34,9 @@ jobs: ${{ runner.os }}-gradle- - name: check disk space run: df -h + - name: install libomp on macos + if: ${{ runner.os == 'macOS' }} + run: brew install libomp - name: Build with Gradle run: ./gradlew -Dnightly=true build :jacoco:testCodeCoverageReport - name: Upload test results @@ -51,14 +54,14 @@ jobs: operating-system: [ macos-latest, ubuntu-latest, windows-latest ] steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -75,14 +78,14 @@ jobs: operating-system: [ macos-latest, ubuntu-latest, windows-latest ] steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -102,11 +105,12 @@ jobs: run: | yum -y update yum install -y tar gzip + # checkout@v4 requires GLIBC 2.27 - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: - java-version: 11 + java-version: 17 distribution: corretto 
architecture: aarch64 - uses: actions/cache@v3 @@ -119,14 +123,15 @@ jobs: run: | ./gradlew :integration:test "-Dai.djl.default_engine=PyTorch" ./gradlew :integration:clean - ./gradlew :integration:test "-Dai.djl.default_engine=OnnxRuntime" - ./gradlew :integration:clean + # OnnxRuntime 1.17.1 requires GLIBC 2.27 + # ./gradlew :integration:test "-Dai.djl.default_engine=OnnxRuntime" + # ./gradlew :integration:clean - test-cuda-118: + test-cuda-121: if: github.repository == 'deepjavalibrary/djl' runs-on: [ self-hosted, gpu ] container: - image: nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu18.04 + image: nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu20.04 options: --gpus all --runtime=nvidia timeout-minutes: 30 needs: create-runners @@ -136,13 +141,13 @@ jobs: apt-get update apt-get install -y software-properties-common wget locales libfontconfig1 locale-gen en_US.UTF-8 - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: - java-version: 11 + java-version: 17 distribution: corretto - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -163,16 +168,16 @@ jobs: publish: if: github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest - needs: [ build, test-pytorch, test-tensorflow, test-aarch64, test-cuda-118 ] + needs: [ build, test-pytorch, test-tensorflow, test-aarch64, test-cuda-121 ] steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} @@ -181,10 +186,8 @@ jobs: - name: Publish to snapshot repository if: ${{ github.event.inputs.mode == '' || github.event.inputs.mode == 'snapshot' }} run: | - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.11.0 -Psnapshot - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.12.1 -Psnapshot ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.13.1 -Psnapshot - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.0.1 -Psnapshot + ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.1.2 -Psnapshot ./gradlew clean engines:ml:xgboost:publish -Pgpu -Psnapshot ./gradlew clean publish -Psnapshot cd bom @@ -197,10 +200,8 @@ jobs: - name: Publish to staging repository if: ${{ github.event.inputs.mode == 'staging' }} run: | - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.11.0 -P${{ github.event.inputs.mode }} - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.12.1 -P${{ github.event.inputs.mode }} ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.13.1 -P${{ github.event.inputs.mode }} - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.0.1 -P${{ github.event.inputs.mode }} + ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.1.2 -P${{ github.event.inputs.mode }} ./gradlew clean engines:ml:xgboost:publish -Pgpu -P${{ github.event.inputs.mode }} ./gradlew clean publish -P${{ github.event.inputs.mode }} cd bom @@ -211,7 +212,7 @@ jobs: ORG_GRADLE_PROJECT_ossrhUsername: ${{ 
secrets.ORG_GRADLE_PROJECT_ossrhUsername }} ORG_GRADLE_PROJECT_ossrhPassword: ${{ secrets.ORG_GRADLE_PROJECT_ossrhPassword }} - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -246,7 +247,7 @@ jobs: stop-runners: if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} runs-on: [ self-hosted, scheduler ] - needs: [ create-runners, test-aarch64, test-cuda-118 ] + needs: [ create-runners, test-aarch64, test-cuda-121 ] steps: - name: Stop all instances run: | diff --git a/.github/workflows/no_response.yml b/.github/workflows/no_response.yml index 893ac1eac93..75c1c07ad54 100644 --- a/.github/workflows/no_response.yml +++ b/.github/workflows/no_response.yml @@ -11,6 +11,7 @@ on: jobs: noResponse: + if: github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest steps: - uses: lee-dohm/no-response@v0.5.0 diff --git a/.github/workflows/publish-job-success.yml b/.github/workflows/publish-job-success.yml new file mode 100644 index 00000000000..ec75dd63595 --- /dev/null +++ b/.github/workflows/publish-job-success.yml @@ -0,0 +1,37 @@ +name: Publish Job Success Metric to CloudWatch + +on: + workflow_run: + workflows: "*" + types: + - completed + branches: + - master + +permissions: + id-token: write + contents: read + +jobs: + publish-job-success-to-cloudwatch: + if: ${{ github.event.workflow_run.event == 'schedule' }} + runs-on: ubuntu-latest + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::185921645874:role/djl-github-cloudwatch-ci-metrics + aws-region: us-west-2 + - name: Publish Job Success Metric + env: + WORKFLOW_NAME: ${{ github.event.workflow_run.display_title }} + REPO_NAME: ${{ github.event.workflow_run.repository.name }} + CONCLUSION: ${{ github.event.workflow_run.conclusion }} + run: | + workflow_name=$(echo "$WORKFLOW_NAME" | tr -d ' ') + metric_name="${REPO_NAME}-${workflow_name}-Failure" + failedBuild=$([ "$CONCLUSION" == "success" ]; echo $?) 
+ aws cloudwatch put-metric-data --namespace GithubCI \ + --metric-name "$metric_name" \ + --value $failedBuild \ + --unit Count diff --git a/.github/workflows/publish_android_packages.yml b/.github/workflows/publish_android_packages.yml index fea44330f94..4ae8bcefcb2 100644 --- a/.github/workflows/publish_android_packages.yml +++ b/.github/workflows/publish_android_packages.yml @@ -12,16 +12,17 @@ on: jobs: release-android: + if: github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} diff --git a/.github/workflows/serving_publish.yml b/.github/workflows/serving_publish.yml index 256a1c2eabb..380666487a7 100644 --- a/.github/workflows/serving_publish.yml +++ b/.github/workflows/serving_publish.yml @@ -24,24 +24,24 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: repository: deepjavalibrary/djl-serving ref: ${{ github.event.inputs.serving-branch }} - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.gradle/caches key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} restore-keys: | ${{ runner.os }}-gradle- - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v1-node16 + uses: aws-actions/configure-aws-credentials@v4 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} @@ -74,6 +74,19 @@ jobs: aws s3 cp benchmark/build/distributions/*.deb s3://djl-ai/publish/djl-bench/${DJL_VERSION}/ aws s3 cp benchmark/build/distributions/*.zip s3://djl-ai/publish/djl-bench/${DJL_VERSION}/ aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/djl-bench/${DJL_VERSION}/*" + - name: Copy awscurl snapshot artifacts to S3 + if: ${{ github.event.inputs.mode == '' || github.event.inputs.mode == 'snapshot' }} + run: | + ./gradlew :awscurl:jar + aws s3 cp awscurl/build/awscurl s3://djl-ai/publish/awscurl/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/awscurl/*" + - name: Copy awscurl staging artifacts to S3 + if: ${{ github.event.inputs.mode == 'staging' }} + run: | + ./gradlew :awscurl:jar + DJL_VERSION=$(cat gradle.properties | awk -F '=' '/djl_version/ {print $2}') + aws s3 cp awscurl/build/awscurl s3://djl-ai/publish/${DJL_VERSION}/awscurl/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/awscurl/${DJL_VERSION}/*" - name: Publish to snapshot repository if: ${{ github.event.inputs.mode == '' || github.event.inputs.mode == 'snapshot' }} run: ./gradlew publish -Psnapshot --refresh-dependencies diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 5b627cfa60b..21634a08872 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,4 +1,5 @@ -## Code of Conduct +# Code of Conduct + This project has 
adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact opensource-codeofconduct@amazon.com with any additional questions or comments. diff --git a/README.md b/README.md index 30975aec73a..29235fb6e21 100644 --- a/README.md +++ b/README.md @@ -81,34 +81,18 @@ The following pseudocode demonstrates running training: - [Documentation](docs/README.md#documentation) - [DJL's D2L Book](https://d2l.djl.ai/) -- [JavaDoc API Reference](https://javadoc.djl.ai/) +- [JavaDoc API Reference](https://djl.ai/website/javadoc.html) ## Release Notes +* [0.27.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.27.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.27.0)) +* [0.26.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.26.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.26.0)) +* [0.25.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.25.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.25.0)) +* [0.24.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.24.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.24.0)) * [0.23.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.23.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.23.0)) -* [0.22.1](https://github.com/deepjavalibrary/djl/releases/tag/v0.22.1) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.22.1)) -* [0.21.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.21.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.21.0)) -* [0.20.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.20.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.20.0)) -* [0.19.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.19.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.19.0)) -* [0.18.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.18.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.18.0)) -* [0.17.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.17.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.17.0)) -* [0.16.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.16.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.16.0)) -* [0.15.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.15.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.15.0)) -* [0.14.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.14.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.14.0)) -* [0.13.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.13.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.13.0)) -* [0.12.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.12.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.12.0)) -* [0.11.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.11.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.11.0)) -* [0.10.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.10.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.10.0)) -* [0.9.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.9.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.9.0)) -* [0.8.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.8.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.8.0)) -* [0.6.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.6.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.6.0)) -* 
[0.5.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.5.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.5.0)) -* [0.4.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.4.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.4.0)) -* [0.3.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.3.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.3.0)) -* [0.2.1](https://github.com/deepjavalibrary/djl/releases/tag/v0.2.1) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.2.1)) -* [0.2.0 Initial release](https://github.com/deepjavalibrary/djl/releases/tag/v0.2.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.2.0)) - -The release of DJL 0.24.0 is planned for August or September 2023. +* [+23 releases](https://github.com/deepjavalibrary/djl/releases) + +The release of DJL 0.28.0 is planned for May 2024. ## Building From Source diff --git a/android/README.md b/android/README.md index 739cd86093b..4eecc64bc43 100644 --- a/android/README.md +++ b/android/README.md @@ -16,7 +16,7 @@ In gradle, you can add the 5 modules in your dependencies: ```groovy dependencies { - implementation platform("ai.djl:bom:0.23.0") + implementation platform("ai.djl:bom:0.27.0") implementation "ai.djl:api" implementation "ai.djl.android:core" diff --git a/android/core/build.gradle b/android/core/build.gradle index 11b8c473c00..fc549d5afb2 100644 --- a/android/core/build.gradle +++ b/android/core/build.gradle @@ -109,3 +109,7 @@ dependencies { androidTestImplementation 'androidx.test.ext:junit:1.1.5' androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1' } + +configurations.configureEach { + exclude group: "org.apache.commons", module: "commons-compress" +} diff --git a/android/gradle.properties b/android/gradle.properties index 68ad6c12151..4c93af45329 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -17,5 +17,5 @@ org.gradle.jvmargs=-Xmx1536m android.useAndroidX=true # Automatically convert third-party libraries to use AndroidX android.enableJetifier=true -djl_version=0.23.0 -pytorch_version=1.13.1 +djl_version=0.27.0 +pytorch_version=2.1.1 diff --git a/android/pytorch-native/README.md b/android/pytorch-native/README.md index 6a955a9e4ce..fb435e9aa7b 100644 --- a/android/pytorch-native/README.md +++ b/android/pytorch-native/README.md @@ -124,7 +124,7 @@ cd .. 
./gradlew compileAndroidJNI -Ppt_version=${PYTORCH_VERSION} ``` -`jnilib/0.23.0/android` folder will be created after build, and shared library will be uploaded to S3 in CI build +`jnilib/0.27.0/android` folder will be created after build, and shared library will be uploaded to S3 in CI build ## Build PyTorch android library (.aar) and publish to Sonatype snapshot repo @@ -138,7 +138,7 @@ cd ../../../android # To avoid download jni from S3, manually copy them mkdir -p pytorch-native/jnilib -cp -r ../engines/pytorch/pytorch-native/jnilib/0.23.0/android/* pytorch-native/jnilib +cp -r ../engines/pytorch/pytorch-native/jnilib/0.27.0/android/* pytorch-native/jnilib ./gradlew :pytorch-native:assemble # publish to local maven repo (~/.m2 folder) diff --git a/api/README.md b/api/README.md index 85ad22f0188..fa29d149706 100644 --- a/api/README.md +++ b/api/README.md @@ -35,7 +35,7 @@ You can pull the DJL API from the central Maven repository by including the foll ai.djl api - 0.23.0 + 0.27.0 ``` @@ -45,7 +45,7 @@ For testing the current nightly build, use the following: ai.djl api - 0.24.0-SNAPSHOT + 0.28.0-SNAPSHOT ``` diff --git a/api/build.gradle b/api/build.gradle index 707b20b81d8..729e6e58964 100644 --- a/api/build.gradle +++ b/api/build.gradle @@ -1,7 +1,9 @@ dependencies { api "com.google.code.gson:gson:${gson_version}" api "net.java.dev.jna:jna:${jna_version}" - api "org.apache.commons:commons-compress:${commons_compress_version}" + api ("org.apache.commons:commons-compress:${commons_compress_version}") { + exclude group: "org.apache.commons", module: "commons-lang3" + } api "org.slf4j:slf4j-api:${slf4j_version}" testImplementation("org.testng:testng:${testng_version}") { diff --git a/api/src/main/java/ai/djl/BaseModel.java b/api/src/main/java/ai/djl/BaseModel.java index 572ab65508c..db2f0d3dd70 100644 --- a/api/src/main/java/ai/djl/BaseModel.java +++ b/api/src/main/java/ai/djl/BaseModel.java @@ -339,8 +339,12 @@ protected Path paramPathResolver(String prefix, Map options) throws I protected boolean readParameters(Path paramFile, Map options) throws IOException, MalformedModelException { logger.debug("Try to load model from {}", paramFile); - try (DataInputStream dis = - new DataInputStream(new BufferedInputStream(Files.newInputStream(paramFile)))) { + return readParameters(Files.newInputStream(paramFile), options); + } + + protected boolean readParameters(InputStream paramStream, Map options) + throws IOException, MalformedModelException { + try (DataInputStream dis = new DataInputStream(new BufferedInputStream(paramStream))) { byte[] buf = new byte[4]; dis.readFully(buf); if (!"DJL@".equals(new String(buf, StandardCharsets.US_ASCII))) { diff --git a/api/src/main/java/ai/djl/Device.java b/api/src/main/java/ai/djl/Device.java index ce9b29ae5ba..597d7d9be02 100644 --- a/api/src/main/java/ai/djl/Device.java +++ b/api/src/main/java/ai/djl/Device.java @@ -14,11 +14,17 @@ import ai.djl.engine.Engine; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; /** * The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code @@ -30,7 +36,7 @@ * @see The D2L chapter * on GPU devices */ -public final class Device { +public class Device { private static final Map CACHE = new 
ConcurrentHashMap<>(); @@ -39,8 +45,8 @@ public final class Device { private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)"); - private String deviceType; - private int deviceId; + protected String deviceType; + protected int deviceId; /** * Creates a {@code Device} with basic information. @@ -101,6 +107,13 @@ public static Device fromName(String deviceName, Engine engine) { return engine.defaultDevice(); } + if (deviceName.contains("+")) { + String[] split = deviceName.split("\\+"); + List subDevices = + Arrays.stream(split).map(n -> fromName(n, engine)).collect(Collectors.toList()); + return new MultiDevice(subDevices); + } + Matcher matcher = DEVICE_NAME.matcher(deviceName); if (matcher.matches()) { String deviceType = matcher.group(1); @@ -150,6 +163,15 @@ public boolean isGpu() { return Type.GPU.equals(deviceType); } + /** + * Returns the sub devices if present (such as a {@link MultiDevice}), otherwise this. + * + * @return the sub devices if present (such as a {@link MultiDevice}), otherwise this. + */ + public List getDevices() { + return Collections.singletonList(this); + } + /** {@inheritDoc} */ @Override public String toString() { @@ -214,4 +236,88 @@ public interface Type { String CPU = "cpu"; String GPU = "gpu"; } + + /** A combined {@link Device} representing the composition of multiple other devices. */ + public static class MultiDevice extends Device { + + List devices; + + /** + * Constructs a {@link MultiDevice} with a range of new devices. + * + * @param deviceType the type of the sub-devices + * @param startInclusive the start (inclusive) of the devices range + * @param endExclusive the end (exclusive) of the devices range + */ + public MultiDevice(String deviceType, int startInclusive, int endExclusive) { + this( + IntStream.range(startInclusive, endExclusive) + .mapToObj(i -> Device.of(deviceType, i)) + .collect(Collectors.toList())); + } + + /** + * Constructs a {@link MultiDevice} from sub devices. + * + * @param devices the sub devices + */ + public MultiDevice(Device... devices) { + this(Arrays.asList(devices)); + } + + /** + * Constructs a {@link MultiDevice} from sub devices. 
+ * + * @param devices the sub devices + */ + public MultiDevice(List&lt;Device&gt; devices) { + super(null, -1); + devices.sort( + Comparator.comparing(Device::getDeviceType, String.CASE_INSENSITIVE_ORDER) + .thenComparingInt(Device::getDeviceId)); + this.deviceType = + String.join( + "+", + (Iterable&lt;String&gt;) + () -> + devices.stream() + .map(d -> d.getDeviceType() + d.getDeviceId()) + .iterator()); + this.devices = devices; + } + + /** {@inheritDoc} */ + @Override + public List&lt;Device&gt; getDevices() { + return devices; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + MultiDevice that = (MultiDevice) o; + return Objects.equals(devices, that.devices); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), devices); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return deviceType + "()"; + } + } } diff --git a/api/src/main/java/ai/djl/Model.java b/api/src/main/java/ai/djl/Model.java index 1903ca392dc..dab2568949f 100644 --- a/api/src/main/java/ai/djl/Model.java +++ b/api/src/main/java/ai/djl/Model.java @@ -233,6 +233,36 @@ default String getProperty(String key, String defValue) { */ void setProperty(String key, String value); + /** + * Returns the property of the model based on property name. + * + * @param key the name of the property + * @param defValue the default value if key not found + * @return the value of the property + */ + default int intProperty(String key, int defValue) { + String value = getProperty(key); + if (value == null || value.isEmpty()) { + return defValue; + } + return Integer.parseInt(value); + } + + /** + * Returns the property of the model based on property name. + * + * @param key the name of the property + * @param defValue the default value if key not found + * @return the value of the property + */ + default long longProperty(String key, long defValue) { + String value = getProperty(key); + if (value == null || value.isEmpty()) { + return defValue; + } + return Long.parseLong(value); + } + /** * Gets the {@link NDManager} from the model.
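Illustrative only: a minimal sketch of how the new `MultiDevice` composite added above could be inspected, using nothing beyond the constructors and accessors shown in this patch. The class name `MultiDeviceSketch` is hypothetical.

```java
import ai.djl.Device;

public final class MultiDeviceSketch {

    public static void main(String[] args) {
        // Compose gpu0 and gpu1 into a single logical device (end index is exclusive).
        Device multi = new Device.MultiDevice("gpu", 0, 2);

        // For a MultiDevice, getDevices() returns the sorted sub-devices;
        // for a plain Device it returns a singleton list containing itself.
        for (Device device : multi.getDevices()) {
            System.out.println(device.getDeviceType() + device.getDeviceId());
        }
    }
}
```

The same composite is produced by `Device.fromName` when a device name such as "gpu0+gpu1" is passed, per the parsing change earlier in this patch.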
* diff --git a/api/src/main/java/ai/djl/inference/Predictor.java b/api/src/main/java/ai/djl/inference/Predictor.java index 853b30d7a5e..045a26b0d04 100644 --- a/api/src/main/java/ai/djl/inference/Predictor.java +++ b/api/src/main/java/ai/djl/inference/Predictor.java @@ -17,6 +17,7 @@ import ai.djl.inference.streaming.StreamingBlock; import ai.djl.inference.streaming.StreamingTranslator; import ai.djl.inference.streaming.StreamingTranslator.StreamOutput; +import ai.djl.metric.Dimension; import ai.djl.metric.Metrics; import ai.djl.metric.Unit; import ai.djl.ndarray.LazyNDArray; @@ -33,7 +34,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.lang.reflect.Array; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -60,14 +63,13 @@ * * * * @param the input type @@ -95,6 +97,7 @@ public class Predictor implements AutoCloseable { protected Metrics metrics; protected Block block; protected ParameterStore parameterStore; + protected Dimension dimension; /** * Creates a new instance of {@code BasePredictor} with the given {@link Model} and {@link @@ -117,6 +120,7 @@ public Predictor(Model model, Translator translator, Device device, boolea this.translator = translator; block = model.getBlock(); parameterStore = new ParameterStore(manager, copy); + dimension = new Dimension("Model", model.getProperty("metric_dimension", "model")); } /** @@ -151,42 +155,46 @@ protected NDList predictInternal(TranslatorContext ctx, NDList ndList) * @return a list of output objects defined by the user * @throws TranslateException if an error occurs during prediction */ - @SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches"}) + @SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches", "unchecked"}) public List batchPredict(List inputs) throws TranslateException { - long begin = System.nanoTime(); try (PredictorContext context = new PredictorContext()) { if (!prepared) { translator.prepare(context); prepared = true; } - Batchifier batchifier = translator.getBatchifier(); - if (batchifier == null) { + Translator batchTranslator = translator.toBatchTranslator(); + if (batchTranslator == null) { List ret = new ArrayList<>(inputs.size()); for (I input : inputs) { timestamp = System.nanoTime(); - begin = timestamp; + long begin = timestamp; NDList ndList = translator.processInput(context, input); - preprocessEnd(ndList); + preprocessEnd(ndList, 1); NDList result = predictInternal(context, ndList); - predictEnd(result); + predictEnd(result, 1); ret.add(translator.processOutput(context, result)); - postProcessEnd(begin); + postProcessEnd(begin, 1); } return ret; } + int batchSize = inputs.size(); + I[] empty = (I[]) Array.newInstance(inputs.get(0).getClass(), 0); + I[] in = inputs.toArray(empty); + timestamp = System.nanoTime(); - NDList inputBatch = processInputs(context, inputs); - preprocessEnd(inputBatch); + long begin = timestamp; + NDList ndList = batchTranslator.processInput(context, in); + preprocessEnd(ndList, batchSize); - NDList result = predictInternal(context, inputBatch); - predictEnd(result); + NDList result = predictInternal(context, ndList); + predictEnd(result, batchSize); - List ret = processOutputs(context, result); - postProcessEnd(begin); - return ret; + O[] ret = batchTranslator.processOutput(context, result); + postProcessEnd(begin, batchSize); + return Arrays.asList(ret); } catch (TranslateException e) { throw e; } catch (Exception e) { @@ -300,43 
+308,34 @@ private NDList processInputs(TranslatorContext ctx, List inputs) throws Excep return translator.getBatchifier().batchify(preprocessed); } - @SuppressWarnings("PMD.SignatureDeclareThrowsException") - private List processOutputs(TranslatorContext ctx, NDList list) throws Exception { - NDList[] unbatched = translator.getBatchifier().unbatchify(list); - List outputs = new ArrayList<>(unbatched.length); - for (NDList output : unbatched) { - outputs.add(translator.processOutput(ctx, output)); - } - return outputs; - } - - private void preprocessEnd(NDList list) { + private void preprocessEnd(NDList list, int batchSize) { if (metrics != null) { waitToRead(list); long tmp = System.nanoTime(); - long duration = (tmp - timestamp) / 1000; + long duration = (tmp - timestamp) / 1000 / batchSize; timestamp = tmp; - metrics.addMetric("Preprocess", duration, Unit.MICROSECONDS); + metrics.addMetric("Preprocess", duration, Unit.MICROSECONDS, dimension); } } - private void predictEnd(NDList list) { + private void predictEnd(NDList list, int batchSize) { if (metrics != null) { waitToRead(list); long tmp = System.nanoTime(); - long duration = (tmp - timestamp) / 1000; + long duration = (tmp - timestamp) / 1000 / batchSize; timestamp = tmp; - metrics.addMetric("Inference", duration, Unit.MICROSECONDS); + metrics.addMetric("Inference", duration, Unit.MICROSECONDS, dimension); } } - private void postProcessEnd(long begin) { + private void postProcessEnd(long begin, int batchSize) { if (metrics != null) { long tmp = System.nanoTime(); - long duration = (tmp - timestamp) / 1000; + long duration = (tmp - timestamp) / 1000 / batchSize; timestamp = tmp; - metrics.addMetric("Postprocess", duration, Unit.MICROSECONDS); - metrics.addMetric("Total", (tmp - begin) / 1000, Unit.MICROSECONDS); + metrics.addMetric("Postprocess", duration, Unit.MICROSECONDS, dimension); + long prediction = (tmp - begin) / 1000; + metrics.addMetric("Prediction", prediction, Unit.MICROSECONDS, dimension); } } diff --git a/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java b/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java index d83c4678f33..d5fdfda878b 100644 --- a/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java +++ b/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java @@ -14,13 +14,10 @@ import ai.djl.ndarray.BytesSupplier; -import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; /** @@ -29,16 +26,14 @@ */ public class PublisherBytesSupplier implements BytesSupplier { - private final List allData; - private final AtomicBoolean completed; private Consumer subscriber; - private final AtomicInteger dataPushed; + private CountDownLatch latch; + private CompletableFuture future; /** Constructs a {@link PublisherBytesSupplier}. 
*/ public PublisherBytesSupplier() { - allData = new ArrayList<>(); - completed = new AtomicBoolean(); - dataPushed = new AtomicInteger(); + latch = new CountDownLatch(1); + future = new CompletableFuture<>(); } /** @@ -48,13 +43,24 @@ public PublisherBytesSupplier() { * @param lastChunk true if this is the last chunk */ public void appendContent(byte[] data, boolean lastChunk) { - synchronized (allData) { - allData.add(data); + if (subscriber == null) { + try { + if (!latch.await(2, TimeUnit.MINUTES)) { + throw new IllegalStateException("Wait for subscriber timeout."); + } + if (subscriber == null) { + // workaround Spotbugs + throw new IllegalStateException("subscriber is not set."); + } + } catch (InterruptedException e) { + throw new IllegalStateException("Append content interrupted.", e); + } } + subscriber.accept(data); if (lastChunk) { - completed.set(true); + subscriber.accept(null); + future.complete(null); } - pushData(); } /** @@ -62,69 +68,21 @@ public void appendContent(byte[] data, boolean lastChunk) { * * @param subscriber a consumer function that will receive bytes when new daata is added and * null when completed + * @return a {@code CompletableFuture} object */ - public void subscribe(Consumer subscriber) { + public CompletableFuture subscribe(Consumer subscriber) { if (this.subscriber != null) { throw new IllegalStateException( "The PublisherBytesSupplier only allows a single Subscriber"); } this.subscriber = subscriber; - pushData(); - } - - private void pushData() { - if (subscriber == null) { - return; - } - - int dataAvailable; - synchronized (allData) { - dataAvailable = allData.size(); - } - - int sent = dataPushed.getAndSet(dataAvailable); - if (sent < dataAvailable) { - synchronized (this) { - for (; sent < dataAvailable; sent++) { - subscriber.accept(allData.get(sent)); - } - if (completed.get()) { - subscriber.accept(null); - } - } - } - } - - /** Waits until completed before passing thread (BLOCKS THREAD!). */ - @SuppressWarnings("PMD.EmptyControlStatement") - public void waitToRead() { - // Block until complete!!! - while (!completed.get()) { - // Do nothing - } - } - - /** {@inheritDoc} */ - @Override - public byte[] getAsBytes() { - if (!completed.get()) { - throw new IllegalStateException( - "PublisherByteSupplier must be completely filled before reading."); - } - - try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { - for (byte[] data : allData) { - bos.write(data); - } - return bos.toByteArray(); - } catch (IOException e) { - throw new AssertionError("Failed to read BytesSupplier", e); - } + latch.countDown(); + return future; } /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { - return ByteBuffer.wrap(getAsBytes()); + throw new UnsupportedOperationException("Not supported."); } } diff --git a/api/src/main/java/ai/djl/metric/Metric.java b/api/src/main/java/ai/djl/metric/Metric.java index b3c172ce8b3..4743421a1f6 100644 --- a/api/src/main/java/ai/djl/metric/Metric.java +++ b/api/src/main/java/ai/djl/metric/Metric.java @@ -105,6 +105,16 @@ private Metric( this.dimensions = dimensions; } + /** + * Returns a copy of the metric with a new name. + * + * @param name the new metric name + * @return a copy of the metric + */ + public Metric copyOf(String name) { + return new Metric(name, value, unit, timestamp, dimensions); + } + /** * Returns the name of the {@code Metric}. 
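The `PublisherBytesSupplier` rewrite above switches from buffering chunks to a push-based contract: a subscriber must be attached before content is appended, each chunk is delivered directly to the consumer, and a `null` chunk signals completion. A minimal sketch of that contract follows; it assumes the returned future is a `CompletableFuture<Void>` (the generic parameters are not visible in this extract), and the class name is hypothetical.

```java
import ai.djl.inference.streaming.PublisherBytesSupplier;

import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;

public final class StreamingSketch {

    public static void main(String[] args) {
        PublisherBytesSupplier supplier = new PublisherBytesSupplier();
        ByteArrayOutputStream collected = new ByteArrayOutputStream();

        // Subscribe before appending; appendContent() now waits (up to 2 minutes)
        // for a subscriber instead of buffering. A null chunk marks the end of the stream.
        CompletableFuture<Void> done =
                supplier.subscribe(
                        chunk -> {
                            if (chunk != null) {
                                collected.writeBytes(chunk);
                            }
                        });

        supplier.appendContent("hello ".getBytes(StandardCharsets.UTF_8), false);
        supplier.appendContent("world".getBytes(StandardCharsets.UTF_8), true);

        done.join();
        System.out.println(collected.toString(StandardCharsets.UTF_8));
    }
}
```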
* diff --git a/api/src/main/java/ai/djl/metric/Metrics.java b/api/src/main/java/ai/djl/metric/Metrics.java index bdbd0d3e732..19fcc0f681d 100644 --- a/api/src/main/java/ai/djl/metric/Metrics.java +++ b/api/src/main/java/ai/djl/metric/Metrics.java @@ -103,9 +103,10 @@ public void addMetric(String name, Number value) { * @param name the metric name * @param value the metric value * @param unit the metric unit + * @param dimensions the metric dimensions */ - public void addMetric(String name, Number value, Unit unit) { - addMetric(new Metric(name, value, unit)); + public void addMetric(String name, Number value, Unit unit, Dimension... dimensions) { + addMetric(new Metric(name, value, unit, dimensions)); } /** @@ -172,7 +173,8 @@ public Metric percentile(String metricName, int percentile) { List list = new ArrayList<>(metric); list.sort(Comparator.comparingDouble(Metric::getValue)); int index = metric.size() * percentile / 100; - return list.get(index); + Metric m = list.get(index); + return m.copyOf(m.getMetricName() + "_p" + percentile); } /** diff --git a/api/src/main/java/ai/djl/metric/Unit.java b/api/src/main/java/ai/djl/metric/Unit.java index 81d45e63185..3703d2a312e 100644 --- a/api/src/main/java/ai/djl/metric/Unit.java +++ b/api/src/main/java/ai/djl/metric/Unit.java @@ -41,6 +41,7 @@ public enum Unit { GIGABITS_PER_SECOND("Gigabits/Second"), TERABITS_PER_SECOND("Terabits/Second"), COUNT_PER_SECOND("Count/Second"), + COUNT_PER_ITEM("Count/Item"), NONE("None"); private static final ConcurrentHashMap MAP = new ConcurrentHashMap<>(); diff --git a/api/src/main/java/ai/djl/modality/Classifications.java b/api/src/main/java/ai/djl/modality/Classifications.java index 84025ce07e1..070c0372a7a 100644 --- a/api/src/main/java/ai/djl/modality/Classifications.java +++ b/api/src/main/java/ai/djl/modality/Classifications.java @@ -53,7 +53,7 @@ public class Classifications implements JsonSerializable, Ensembleable probabilities; - private int topK; + protected int topK; /** * Constructs a {@code Classifications} using a parallel list of classNames and probabilities. @@ -88,10 +88,18 @@ public Classifications(List classNames, NDArray probabilities) { */ public Classifications(List classNames, NDArray probabilities, int topK) { this.classNames = classNames; - NDArray array = probabilities.toType(DataType.FLOAT64, false); - this.probabilities = - Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList()); - array.close(); + if (probabilities.getDataType() == DataType.FLOAT32) { + // Avoid converting float32 to float64 as this is not supported on MPS device + this.probabilities = new ArrayList<>(); + for (float prob : probabilities.toFloatArray()) { + this.probabilities.add((double) prob); + } + } else { + NDArray array = probabilities.toType(DataType.FLOAT64, false); + this.probabilities = + Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList()); + array.close(); + } this.topK = topK; } diff --git a/api/src/main/java/ai/djl/modality/Input.java b/api/src/main/java/ai/djl/modality/Input.java index ecd0679661b..45c6f8161f7 100644 --- a/api/src/main/java/ai/djl/modality/Input.java +++ b/api/src/main/java/ai/djl/modality/Input.java @@ -37,6 +37,7 @@ public class Input { protected Map properties; protected PairList content; + private boolean cancelled; /** Constructs a new {@code Input} instance. */ public Input() { @@ -44,6 +45,24 @@ public Input() { content = new PairList<>(); } + /** + * Returns {@code true} if the input is cancelled. 
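The Predictor and Metrics changes above record "Preprocess", "Inference" and "Postprocess" per item (duration divided by batch size), tag every metric with the "Model" dimension, report end-to-end latency as "Prediction", and rename the result of `percentile()` with a `_p<percentile>` suffix. A short sketch of how a caller would observe this, assuming an already-created `Predictor<Image, Classifications>`; the method and class names are hypothetical.

```java
import ai.djl.inference.Predictor;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.translate.TranslateException;

import java.util.List;

public final class PredictorMetricsSketch {

    public static void timeBatch(Predictor<Image, Classifications> predictor, List<Image> images)
            throws TranslateException {
        Metrics metrics = new Metrics();
        predictor.setMetrics(metrics);

        predictor.batchPredict(images);

        // percentile() now returns a renamed copy, e.g. "Inference_p90",
        // and each sample carries the "Model" dimension set by the predictor.
        Metric p90 = metrics.percentile("Inference", 90);
        System.out.println(p90.getMetricName() + " = " + p90.getValue());
    }
}
```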
+ * + * @return {@code true} if the input is cancelled. + */ + public boolean isCancelled() { + return cancelled; + } + + /** + * Sets the cancelled status. + * + * @param cancelled the cancelled status + */ + public void setCancelled(boolean cancelled) { + this.cancelled = cancelled; + } + /** * Returns the properties of the input. * diff --git a/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java b/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java index c7c676fca01..c7d5414da28 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java +++ b/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java @@ -43,7 +43,7 @@ public class CategoryMask implements JsonSerializable { .registerTypeAdapter(CategoryMask.class, new SegmentationSerializer()) .create(); - private List classes; + private transient List classes; private int[][] mask; /** diff --git a/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java b/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java index 2fd90fe39ec..9d58575af59 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java +++ b/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java @@ -48,7 +48,7 @@ public DetectedObjects( List classNames, List probabilities, List boundingBoxes) { super(classNames, probabilities); this.boundingBoxes = boundingBoxes; - setTopK(Integer.MAX_VALUE); + this.topK = Integer.MAX_VALUE; } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java index a4ebfcb9df1..c31353766d3 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java @@ -160,7 +160,7 @@ protected double overlap(double x1, double w1, double x2, double w2) { return right - left; } - private DetectedObjects processFromBoxOutput(NDList list) { + protected DetectedObjects processFromBoxOutput(NDList list) { float[] flattened = list.get(0).toFloatArray(); ArrayList intermediateResults = new ArrayList<>(); int sizeClasses = classes.size(); @@ -280,7 +280,7 @@ public YoloV5Translator build() { } } - private static final class IntermediateResult { + protected static final class IntermediateResult { /** * A sortable score for how good the recognition is relative to others. Higher should be diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java new file mode 100644 index 00000000000..faf31ab3188 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java @@ -0,0 +1,130 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.modality.cv.translator; + +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.ArgumentsUtil; + +import java.util.ArrayList; +import java.util.Map; + +/** + * A translator for YoloV8 models. This was tested with ONNX exported Yolo models. For details check + * here: https://github.com/ultralytics/ultralytics + */ +public class YoloV8Translator extends YoloV5Translator { + + private int maxBoxes; + + /** + * Constructs an ImageTranslator with the provided builder. + * + * @param builder the data to build with + */ + protected YoloV8Translator(Builder builder) { + super(builder); + maxBoxes = builder.maxBox; + } + + /** + * Creates a builder to build a {@code YoloV8Translator} with specified arguments. + * + * @param arguments arguments to specify builder options + * @return a new builder + */ + public static YoloV8Translator.Builder builder(Map arguments) { + YoloV8Translator.Builder builder = new YoloV8Translator.Builder(); + builder.configPreProcess(arguments); + builder.configPostProcess(arguments); + + return builder; + } + + /** {@inheritDoc} */ + @Override + protected DetectedObjects processFromBoxOutput(NDList list) { + NDArray rawResult = list.get(0); + NDArray reshapedResult = rawResult.transpose(); + Shape shape = reshapedResult.getShape(); + float[] buf = reshapedResult.toFloatArray(); + int numberRows = Math.toIntExact(shape.get(0)); + int nClasses = Math.toIntExact(shape.get(1)); + int padding = nClasses - classes.size(); + if (padding != 0 && padding != 4) { + throw new IllegalStateException( + "Expected classes: " + (nClasses - 4) + ", got " + classes.size()); + } + + ArrayList intermediateResults = new ArrayList<>(); + // reverse order search in heap; searches through #maxBoxes for optimization when set + for (int i = numberRows - 1; i > numberRows - maxBoxes; --i) { + int index = i * nClasses; + float maxClassProb = -1f; + int maxIndex = -1; + for (int c = 4; c < nClasses; c++) { + float classProb = buf[index + c]; + if (classProb > maxClassProb) { + maxClassProb = classProb; + maxIndex = c; + } + } + maxIndex -= padding; + + if (maxClassProb > threshold) { + float xPos = buf[index]; // center x + float yPos = buf[index + 1]; // center y + float w = buf[index + 2]; + float h = buf[index + 3]; + Rectangle rect = + new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h); + intermediateResults.add( + new IntermediateResult( + classes.get(maxIndex), maxClassProb, maxIndex, rect)); + } + } + return nms(intermediateResults); + } + + /** The builder for {@link YoloV8Translator}. */ + public static class Builder extends YoloV5Translator.Builder { + + private int maxBox = 8400; + + /** + * Builds the translator. 
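A sketch of loading a model with the `YoloV8Translator` above via the `YoloV8TranslatorFactory` defined later in this patch. It assumes a locally exported ONNX model file, the OnnxRuntime engine, and the pre/post-processing argument names inherited from the YoloV5 builder ("width", "height", "resize", "threshold"); only "maxBox" is confirmed by this patch, the rest of the values are hypothetical.

```java
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.translator.YoloV8TranslatorFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

import java.nio.file.Paths;

public final class YoloV8Sketch {

    public static void main(String[] args) throws Exception {
        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .setTypes(Image.class, DetectedObjects.class)
                        .optModelPath(Paths.get("yolov8n.onnx")) // hypothetical local export
                        .optEngine("OnnxRuntime")
                        .optTranslatorFactory(new YoloV8TranslatorFactory())
                        .optArgument("width", 640)
                        .optArgument("height", 640)
                        .optArgument("resize", true)
                        .optArgument("threshold", 0.3)
                        .optArgument("maxBox", 1000)
                        .build();

        try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
                Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            Image image = ImageFactory.getInstance().fromFile(Paths.get("dog.jpg"));
            DetectedObjects detections = predictor.predict(image);
            System.out.println(detections);
        }
    }
}
```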
+ * + * @return the new translator + */ + @Override + public YoloV8Translator build() { + if (pipeline == null) { + addTransform( + array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255)); + } + validate(); + return new YoloV8Translator(this); + } + + /** {@inheritDoc} */ + @Override + protected void configPostProcess(Map arguments) { + super.configPostProcess(arguments); + maxBox = ArgumentsUtil.intValue(arguments, "maxBox", 8400); + } + } +} diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java new file mode 100644 index 00000000000..b5a4db00d28 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.translate.Translator; + +import java.io.Serializable; +import java.util.Map; + +/** A translatorFactory that creates a {@link YoloV8Translator} instance. */ +public class YoloV8TranslatorFactory extends ObjectDetectionTranslatorFactory + implements Serializable { + + private static final long serialVersionUID = 1L; + + /** {@inheritDoc} */ + @Override + protected Translator buildBaseTranslator( + Model model, Map arguments) { + return YoloV8Translator.builder(arguments).build(); + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/Decoder.java b/api/src/main/java/ai/djl/modality/nlp/Decoder.java index e8081666950..c422665b147 100644 --- a/api/src/main/java/ai/djl/modality/nlp/Decoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/Decoder.java @@ -42,6 +42,7 @@ public abstract class Decoder extends AbstractBlock { * @param block the block to be used to decode * @param version the version to use for parameter and metadata serialization */ + @SuppressWarnings("this-escape") public Decoder(byte version, Block block) { super(version); this.block = addChildBlock("Block", block); diff --git a/api/src/main/java/ai/djl/modality/nlp/Encoder.java b/api/src/main/java/ai/djl/modality/nlp/Encoder.java index 4c5a4469388..221626d7559 100644 --- a/api/src/main/java/ai/djl/modality/nlp/Encoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/Encoder.java @@ -40,6 +40,7 @@ public abstract class Encoder extends AbstractBlock { * @param version the version to use for parameter and metadata serialization * @param block the encoder block */ + @SuppressWarnings("this-escape") public Encoder(byte version, Block block) { super(version); this.block = addChildBlock("Block", block); diff --git a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java index 58cc67867c7..24abcb77bb8 100644 --- a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java @@ -46,6 +46,7 
@@ public class EncoderDecoder extends AbstractBlock { * @param encoder the {@link Encoder} * @param decoder the {@link Decoder} */ + @SuppressWarnings("this-escape") public EncoderDecoder(Encoder encoder, Decoder decoder) { super(VERSION); this.encoder = addChildBlock("Encoder", encoder); diff --git a/api/src/main/java/ai/djl/modality/nlp/TextPrompt.java b/api/src/main/java/ai/djl/modality/nlp/TextPrompt.java new file mode 100644 index 00000000000..dd1cef113bd --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/TextPrompt.java @@ -0,0 +1,94 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp; + +import ai.djl.modality.Input; +import ai.djl.translate.TranslateException; +import ai.djl.util.JsonUtils; + +import com.google.gson.JsonElement; +import com.google.gson.JsonParseException; + +/** The input container for NLP text prompt. */ +public final class TextPrompt { + + private String text; + private String[] batch; + + private TextPrompt(String text) { + this.text = text; + } + + private TextPrompt(String[] batch) { + this.batch = batch; + } + + /** + * Returns if the prompt is a batch. + * + * @return {@code true} if the prompt is a batch + */ + public boolean isBatch() { + return batch != null; + } + + /** + * Returns the single prompt. + * + * @return the single prompt + */ + public String getText() { + return text; + } + + /** + * Returns the batch prompt. + * + * @return the batch prompt + */ + public String[] getBatch() { + return batch; + } + + /** + * Returns the {@code TextPrompt} from the {@link Input}. 
+ * + * @param input the input object + * @return the {@code TextPrompt} from the {@link Input} + * @throws TranslateException if the input is invalid + */ + public static TextPrompt parseInput(Input input) throws TranslateException { + String contentType = input.getProperty("Content-Type", null); + String text = input.getData().getAsString(); + if (!"application/json".equals(contentType)) { + return new TextPrompt(text); + } + + try { + JsonElement element = JsonUtils.GSON.fromJson(text, JsonElement.class); + if (element != null && element.isJsonObject()) { + element = element.getAsJsonObject().get("inputs"); + } + if (element == null) { + throw new TranslateException("Missing \"inputs\" in json."); + } else if (element.isJsonArray()) { + String[] batch = JsonUtils.GSON.fromJson(element, String[].class); + return new TextPrompt(batch); + } else { + return new TextPrompt(element.getAsString()); + } + } catch (JsonParseException e) { + throw new TranslateException("Input is not a valid json.", e); + } + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java index af153cb0b23..a65e9cebb4f 100644 --- a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java +++ b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java @@ -38,6 +38,7 @@ public class TrainableTextEmbedding extends AbstractBlock implements TextEmbeddi * * @param wordEmbedding the word embedding to embed each word */ + @SuppressWarnings("this-escape") public TrainableTextEmbedding(TrainableWordEmbedding wordEmbedding) { this.trainableWordEmbedding = addChildBlock("trainableWordEmbedding", wordEmbedding); } diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java b/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java index 78f40c5b2f2..07c63428ff6 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java @@ -35,15 +35,13 @@ * policy is setting several thresholds. */ public abstract class SeqBatchScheduler { + private static final Logger logger = LoggerFactory.getLogger(SeqBatchScheduler.class); Predictor predictor; SeqBatcher seqBatcher; - NDManager manager; - SearchConfig config; - Map results; /** @@ -101,7 +99,7 @@ public boolean incrementForward(int count) throws TranslateException { * @return the output token ids * @throws TranslateException if forward fails */ - abstract NDArray inferenceCall() throws TranslateException; + protected abstract NDArray inferenceCall() throws TranslateException; /** * Adds new batch. diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java index cf267318588..7590244e4a4 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java @@ -134,9 +134,9 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException { * Generates text using beam search. 
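The `TextPrompt.parseInput` method introduced above accepts either a plain-text body (treated as a single prompt) or a JSON body whose "inputs" field is a string or a string array. A minimal sketch of both cases, built only on the `Input` and `TextPrompt` APIs visible in this patch; the class name is hypothetical.

```java
import ai.djl.modality.Input;
import ai.djl.modality.nlp.TextPrompt;
import ai.djl.translate.TranslateException;

public final class TextPromptSketch {

    public static void main(String[] args) throws TranslateException {
        // A JSON request body with a batch of prompts; without the JSON
        // Content-Type the raw body would be returned as a single prompt.
        Input input = new Input();
        input.addProperty("Content-Type", "application/json");
        input.add("{\"inputs\": [\"What is DJL?\", \"Translate: hello\"]}");

        TextPrompt prompt = TextPrompt.parseInput(input);
        if (prompt.isBatch()) {
            for (String text : prompt.getBatch()) {
                System.out.println(text);
            }
        } else {
            System.out.println(prompt.getText());
        }
    }
}
```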
* * @param inputIds input tokens ids - * @see Beam Search * @return the output token ids stored as NDArray and the endPosition of each sentence * @throws TranslateException if failed run forward + * @see Beam Search */ @SuppressWarnings("try") public NDArray beamSearch(NDArray inputIds) throws TranslateException { @@ -261,9 +261,9 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException { * Generates text using contrastive search. * * @param inputIds input token ids - * @see Contrastive Search * @return the output token ids stored as NDArray * @throws TranslateException if forward failed + * @see Contrastive Search */ @SuppressWarnings("try") public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException { diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java new file mode 100644 index 00000000000..e62167a34b2 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.translator; + +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.BytesSupplier; +import ai.djl.ndarray.NDList; +import ai.djl.translate.Batchifier; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.JsonUtils; +import ai.djl.util.PairList; +import ai.djl.util.StringPair; + +import com.google.gson.JsonElement; +import com.google.gson.JsonParseException; + +/** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */ +public class CrossEncoderServingTranslator implements NoBatchifyTranslator { + + private Translator translator; + private Translator batchTranslator; + + /** + * Constructs a {@code CrossEncoderServingTranslator} instance. 
+ * + * @param translator a {@code Translator} processes question answering input + */ + public CrossEncoderServingTranslator(Translator translator) { + this.translator = translator; + this.batchTranslator = translator.toBatchTranslator(); + } + + /** {@inheritDoc} */ + @Override + public void prepare(TranslatorContext ctx) throws Exception { + translator.prepare(ctx); + batchTranslator.prepare(ctx); + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, Input input) throws Exception { + PairList content = input.getContent(); + if (content.isEmpty()) { + throw new TranslateException("Input data is empty."); + } + + String contentType = input.getProperty("Content-Type", null); + StringPair pair; + if ("application/json".equals(contentType)) { + String json = input.getData().getAsString(); + try { + JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class); + if (element.isJsonArray()) { + ctx.setAttachment("batch", Boolean.TRUE); + StringPair[] inputs = JsonUtils.GSON.fromJson(json, StringPair[].class); + return batchTranslator.processInput(ctx, inputs); + } + + pair = JsonUtils.GSON.fromJson(json, StringPair.class); + if (pair.getKey() == null || pair.getValue() == null) { + throw new TranslateException("Missing key or value in json."); + } + } catch (JsonParseException e) { + throw new TranslateException("Input is not a valid json.", e); + } + } else { + String key = input.getAsString("key"); + String value = input.getAsString("value"); + if (key == null || value == null) { + throw new TranslateException("Missing key or value in input."); + } + pair = new StringPair(key, value); + } + + NDList ret = translator.processInput(ctx, pair); + Batchifier batchifier = translator.getBatchifier(); + if (batchifier != null) { + NDList[] batch = {ret}; + return batchifier.batchify(batch); + } + return ret; + } + + /** {@inheritDoc} */ + @Override + public Output processOutput(TranslatorContext ctx, NDList list) throws Exception { + Output output = new Output(); + output.addProperty("Content-Type", "application/json"); + if (ctx.getAttachment("batch") != null) { + output.add(BytesSupplier.wrapAsJson(batchTranslator.processOutput(ctx, list))); + } else { + Batchifier batchifier = translator.getBatchifier(); + if (batchifier != null) { + list = batchifier.unbatchify(list)[0]; + } + output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list))); + } + return output; + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java index 27e343120c4..cb265087557 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java @@ -15,6 +15,7 @@ import ai.djl.modality.Classifications; import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.modality.nlp.TextPrompt; import ai.djl.ndarray.BytesSupplier; import ai.djl.ndarray.NDList; import ai.djl.translate.Batchifier; @@ -22,7 +23,6 @@ import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; -import ai.djl.util.JsonUtils; /** * A {@link Translator} that can handle generic text classification {@link Input} and {@link @@ -57,14 +57,13 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception throw new TranslateException("Input 
data is empty."); } - String contentType = input.getProperty("Content-Type", null); - String text = input.getData().getAsString(); - if ("application/json".equals(contentType)) { + TextPrompt prompt = TextPrompt.parseInput(input); + if (prompt.isBatch()) { ctx.setAttachment("batch", Boolean.TRUE); - String[] inputs = JsonUtils.GSON.fromJson(text, String[].class); - return batchTranslator.processInput(ctx, inputs); + return batchTranslator.processInput(ctx, prompt.getBatch()); } - NDList ret = translator.processInput(ctx, text); + + NDList ret = translator.processInput(ctx, prompt.getText()); Batchifier batchifier = translator.getBatchifier(); if (batchifier != null) { NDList[] batch = {ret}; diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java index 110f9e09fe5..c1e98ac0256 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java @@ -14,6 +14,7 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.modality.nlp.TextPrompt; import ai.djl.ndarray.BytesSupplier; import ai.djl.ndarray.NDList; import ai.djl.translate.Batchifier; @@ -21,10 +22,14 @@ import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; -import ai.djl.util.JsonUtils; +import ai.djl.util.Utils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; /** A {@link Translator} that can handle generic text embedding {@link Input} and {@link Output}. */ -public class TextEmbeddingServingTranslator implements NoBatchifyTranslator { +public class TextEmbeddingServingTranslator implements Translator { private Translator translator; private Translator batchTranslator; @@ -53,14 +58,13 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception throw new TranslateException("Input data is empty."); } - String contentType = input.getProperty("Content-Type", null); - String text = input.getData().getAsString(); - if ("application/json".equals(contentType)) { + TextPrompt prompt = TextPrompt.parseInput(input); + if (prompt.isBatch()) { ctx.setAttachment("batch", Boolean.TRUE); - String[] inputs = JsonUtils.GSON.fromJson(text, String[].class); - return batchTranslator.processInput(ctx, inputs); + return batchTranslator.processInput(ctx, prompt.getBatch()); } - NDList ret = translator.processInput(ctx, text); + + NDList ret = translator.processInput(ctx, prompt.getText()); Batchifier batchifier = translator.getBatchifier(); if (batchifier != null) { NDList[] batch = {ret}; @@ -85,4 +89,64 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception } return output; } + + /** {@inheritDoc} */ + @Override + public Translator toBatchTranslator(Batchifier batchifier) { + return new NoBatchifyTranslator() { + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("PMD.SignatureDeclareThrowsException") + public NDList processInput(TranslatorContext ctx, Input[] inputs) throws Exception { + List prompts = new ArrayList<>(inputs.length); + int[] mapping = new int[inputs.length]; + for (int i = 0; i < inputs.length; ++i) { + TextPrompt prompt = TextPrompt.parseInput(inputs[i]); + if (prompt.isBatch()) { + String[] batch = prompt.getBatch(); + mapping[i] = batch.length; + prompts.addAll(Arrays.asList(batch)); + } else { + mapping[i] = 
-1; + prompts.add(prompt.getText()); + } + } + ctx.setAttachment("mapping", mapping); + return batchTranslator.processInput(ctx, prompts.toArray(Utils.EMPTY_ARRAY)); + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings({"PMD.SignatureDeclareThrowsException", "unchecked"}) + public Output[] processOutput(TranslatorContext ctx, NDList list) throws Exception { + NDList[] unbatched = batchifier.unbatchify(list); + int[] mapping = (int[]) ctx.getAttachment("mapping"); + Object[] encodings = (Object[]) ctx.getAttachment("encodings"); + Output[] ret = new Output[mapping.length]; + int index = 0; + for (int i = 0; i < ret.length; ++i) { + Output output = new Output(); + output.addProperty("Content-Type", "application/json"); + if (mapping[i] == -1) { + // non-batching + ctx.setAttachment("encoding", encodings[index]); + float[] embedding = translator.processOutput(ctx, unbatched[index]); + ++index; + output.add(BytesSupplier.wrapAsJson(embedding)); + } else { + float[][] embeddings = new float[mapping[i]][]; + for (int j = 0; j < mapping[i]; ++j) { + ctx.setAttachment("encoding", encodings[index]); + embeddings[j] = translator.processOutput(ctx, unbatched[index]); + ++index; + } + output.add(BytesSupplier.wrapAsJson(embeddings)); + } + ret[i] = output; + } + return ret; + } + }; + } } diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java index 6f97964351f..e9c5751a324 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java @@ -14,6 +14,7 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.modality.nlp.TextPrompt; import ai.djl.ndarray.BytesSupplier; import ai.djl.ndarray.NDList; import ai.djl.translate.Batchifier; @@ -21,7 +22,6 @@ import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; -import ai.djl.util.JsonUtils; /** * A {@link Translator} that can handle generic token classification {@link Input} and {@link @@ -56,14 +56,13 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception throw new TranslateException("Input data is empty."); } - String contentType = input.getProperty("Content-Type", null); - String text = input.getData().getAsString(); - if ("application/json".equals(contentType)) { + TextPrompt prompt = TextPrompt.parseInput(input); + if (prompt.isBatch()) { ctx.setAttachment("batch", Boolean.TRUE); - String[] inputs = JsonUtils.GSON.fromJson(text, String[].class); - return batchTranslator.processInput(ctx, inputs); + return batchTranslator.processInput(ctx, prompt.getBatch()); } - NDList ret = translator.processInput(ctx, text); + + NDList ret = translator.processInput(ctx, prompt.getText()); Batchifier batchifier = translator.getBatchifier(); if (batchifier != null) { NDList[] batch = {ret}; diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 29a57739aa3..c3df1ef3301 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -53,6 +53,7 @@ public abstract class BaseNDManager implements NDManager { protected AtomicBoolean closed = new AtomicBoolean(false); protected AtomicBoolean capped = new AtomicBoolean(false); + 
@SuppressWarnings("this-escape") protected BaseNDManager(NDManager parent, Device device) { this.parent = parent; this.device = device == null ? defaultDevice() : device; diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 385c32e88e3..2b4a9df095a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -2344,6 +2344,24 @@ default boolean allClose(NDArray other, double rtol, double atol, boolean equalN */ NDArray atan(); + /** + * Returns the element-wise arc-tangent of this/other choosing the quadrant correctly. + * + *

<p>Examples
+     *
+     * <pre>
+     * jshell> NDArray x = manager.create(new float[] {0f, 1f});
+     * jshell> NDArray y = manager.create(new float[] {0f, -6f});
+     * jshell> x.atan2(y);
+     * ND: (2) cpu() float64
+     * [0.    , 2.9764]
+     * </pre>
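+     *
+     * <p>For the sample values above, {@code atan2(1, -6)} equals {@code pi - atan(1 / 6)}, which is
+     * approximately {@code 2.9764}; the signs of both operands pick the quadrant, unlike
+     * {@link #atan()} applied to the plain ratio.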
+ * + * @param other The other {@code NDArray} + * @return the result {@code NDArray} + */ + NDArray atan2(NDArray other); + /** * Returns the hyperbolic sine of this {@code NDArray} element-wise. * @@ -3375,6 +3393,48 @@ NDArray stft( boolean normalize, boolean returnComplex); + /** + * Computes the two-dimensional Discrete Fourier Transform. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @param axes Axes over which to compute the 2D-FFT. + * @return The truncated or zero-padded input, transformed along the axes. + */ + NDArray fft2(long[] sizes, long[] axes); + + /** + * Computes the two-dimensional Discrete Fourier Transform along the last 2 axes. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @return The truncated or zero-padded input, transformed along the last two axes + */ + default NDArray fft2(long[] sizes) { + return fft2(sizes, new long[] {-2, -1}); + } + + /** + * Computes the two-dimensional inverse Discrete Fourier Transform. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @param axes Axes over which to compute the 2D-Inverse-FFT. + * @return The truncated or zero-padded input, transformed along the axes. + */ + NDArray ifft2(long[] sizes, long[] axes); + + /** + * Computes the two-dimensional inverse Discrete Fourier Transform along the last 2 axes. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @return The truncated or zero-padded input, transformed along the axes. + */ + default NDArray ifft2(long[] sizes) { + return ifft2(sizes, new long[] {-2, -1}); + } + /** * Reshapes this {@code NDArray} to the given {@link Shape}. * @@ -4922,6 +4982,22 @@ default NDArray countNonzero(int axis) { */ NDArray erfinv(); + /** + * Returns element-wise gauss error function of the {@code NDArray}. + * + *

<p>Examples
+     *
+     * <pre>
+     * jshell> NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
+     * jshell> array.erf();
+     * ND: (3) cpu() float32
+     * [0., 0.5, -1]
+     * </pre>
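+     *
+     * <p>As the example shows, {@code erf} maps the real line onto (-1, 1): {@code erf(0) = 0},
+     * {@code erf(0.4769)} is roughly {@code 0.5}, and {@code erf} tends to {@code -1} at negative
+     * infinity.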
+ * + * @return The gauss error of the {@code NDArray}, element-wise + */ + NDArray erf(); + /** {@inheritDoc} */ @Override default List getResourceNDArrays() { diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 59047e688c8..9a4ad8db93a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -726,6 +726,12 @@ public NDArray atan() { return getAlternativeArray().atan(); } + /** {@inheritDoc} */ + @Override + public NDArray atan2(NDArray other) { + return getAlternativeArray().atan2(other); + } + /** {@inheritDoc} */ @Override public NDArray sinh() { @@ -906,6 +912,18 @@ public NDArray stft( throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray fft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray reshape(Shape shape) { @@ -1188,6 +1206,12 @@ public NDArray erfinv() { return getAlternativeArray().erfinv(); } + /** {@inheritDoc} */ + @Override + public NDArray erf() { + return getAlternativeArray().erf(); + } + /** {@inheritDoc} */ @Override public NDArray inverse() { diff --git a/api/src/main/java/ai/djl/ndarray/NDArrays.java b/api/src/main/java/ai/djl/ndarray/NDArrays.java index 304b803939c..0e1c0922a7b 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrays.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrays.java @@ -1996,4 +1996,23 @@ public static NDArray logicalXor(NDArray a, NDArray b) { public static NDArray erfinv(NDArray input) { return input.erfinv(); } + + /** + * Returns element-wise gauss error function of the {@code NDArray}. + * + *

<p>Examples
+     *
+     * <pre>
+     * jshell> NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
+     * jshell> array.erf();
+     * ND: (3) cpu() float32
+     * [0., 0.5, -1]
+     * </pre>
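+     *
+     * <p>This static form is a thin convenience wrapper; it simply delegates to {@link NDArray#erf()}
+     * on the given input.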
+ * + * @param input The input {@code NDArray} + * @return The gauss error of the {@code NDArray}, element-wise + */ + public static NDArray erf(NDArray input) { + return input.erf(); + } } diff --git a/api/src/main/java/ai/djl/ndarray/NDList.java b/api/src/main/java/ai/djl/ndarray/NDList.java index e48c243a3ec..f0069d3f3f3 100644 --- a/api/src/main/java/ai/djl/ndarray/NDList.java +++ b/api/src/main/java/ai/djl/ndarray/NDList.java @@ -100,12 +100,12 @@ public static NDList decode(NDManager manager, byte[] byteArray) { try { if (byteArray[0] == 'P' && byteArray[1] == 'K') { return decodeNumpy(manager, new ByteArrayInputStream(byteArray)); - } else if (byteArray[0] == (byte) 0x39 + } else if (byteArray[0] == (byte) 0x93 && byteArray[1] == 'N' && byteArray[2] == 'U' && byteArray[3] == 'M') { return new NDList( - NDSerializer.decode(manager, new ByteArrayInputStream(byteArray))); + NDSerializer.decodeNumpy(manager, new ByteArrayInputStream(byteArray))); } else if (byteArray[8] == '{') { return decodeSafetensors(manager, new ByteArrayInputStream(byteArray)); } @@ -144,11 +144,11 @@ public static NDList decode(NDManager manager, InputStream is) { if (magic[0] == 'P' && magic[1] == 'K') { // assume this is npz file return decodeNumpy(manager, pis); - } else if (magic[0] == (byte) 0x39 + } else if (magic[0] == (byte) 0x93 && magic[1] == 'N' && magic[2] == 'U' && magic[3] == 'M') { - return new NDList(NDSerializer.decode(manager, pis)); + return new NDList(NDSerializer.decodeNumpy(manager, pis)); } else if (magic[8] == '{') { return decodeSafetensors(manager, pis); } diff --git a/api/src/main/java/ai/djl/ndarray/NDScope.java b/api/src/main/java/ai/djl/ndarray/NDScope.java index 8b0deb23132..764c829c2e6 100644 --- a/api/src/main/java/ai/djl/ndarray/NDScope.java +++ b/api/src/main/java/ai/djl/ndarray/NDScope.java @@ -30,6 +30,7 @@ public class NDScope implements AutoCloseable { private IdentityHashMap resources; /** Constructs a new {@code NDScope} instance. */ + @SuppressWarnings("this-escape") public NDScope() { resources = new IdentityHashMap<>(); SCOPE_STACK.get().addLast(this); diff --git a/api/src/main/java/ai/djl/ndarray/NDSerializer.java b/api/src/main/java/ai/djl/ndarray/NDSerializer.java index f42be7f087f..26dde5d92ff 100644 --- a/api/src/main/java/ai/djl/ndarray/NDSerializer.java +++ b/api/src/main/java/ai/djl/ndarray/NDSerializer.java @@ -83,6 +83,16 @@ static void encode(NDArray array, OutputStream os) throws IOException { Shape shape = array.getShape(); dos.write(shape.getEncoded()); + if (array.getDataType() == DataType.STRING) { + String[] data = array.toStringArray(); + dos.writeInt(data.length); + for (String str : data) { + dos.writeUTF(str); + } + dos.flush(); + return; + } + ByteBuffer bb = array.toByteBuffer(); dos.write(bb.order() == ByteOrder.BIG_ENDIAN ? 
'>' : '<'); int length = bb.remaining(); @@ -167,6 +177,17 @@ static NDArray decode(NDManager manager, ByteBuffer bb) { // Shape Shape shape = Shape.decode(bb); + if (dataType == DataType.STRING) { + int size = bb.getInt(); + String[] data = new String[size]; + for (int i = 0; i < size; ++i) { + data[i] = readUTF(bb); + } + NDArray array = manager.create(data, StandardCharsets.UTF_8, shape); + array.setName(name); + return array; + } + // Data ByteOrder order; if (version > 2) { diff --git a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java index b12ac5dd07d..07e56a5ca04 100644 --- a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java +++ b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java @@ -434,7 +434,12 @@ default NDArray toTensor() { if (dim == 3) { result = result.expandDims(0); } - result = result.div(255.0).transpose(0, 3, 1, 2); + // For Apple Silicon MPS it is important not to switch to 64-bit float here + if (result.getDataType() == DataType.FLOAT32) { + result = result.div(255.0f).transpose(0, 3, 1, 2); + } else { + result = result.div(255.0).transpose(0, 3, 1, 2); + } if (dim == 3) { result = result.squeeze(0); } diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 3d58d501293..7ace6880c56 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -105,7 +105,7 @@ * further refine these elements, use {@link Block#freezeParameters(boolean)} to unfreeze them. * * @see this + * href="http://docs.djl.ai/docs/demos/jupyter/tutorial/01_create_your_first_network.html">this * tutorial on creating your first network * @see The * D2L chapter on blocks and pred) { + for (Parameter parameter : getParameters().values()) { + if (pred.test(parameter)) { + parameter.freeze(freeze); + } + } + } + /** * Validates that actual layout matches the expected layout. * diff --git a/api/src/main/java/ai/djl/nn/ParallelBlock.java b/api/src/main/java/ai/djl/nn/ParallelBlock.java index 4ebe1e8119b..269e52b6b22 100644 --- a/api/src/main/java/ai/djl/nn/ParallelBlock.java +++ b/api/src/main/java/ai/djl/nn/ParallelBlock.java @@ -62,6 +62,7 @@ public ParallelBlock(Function, NDList> function) { * @param function the function to define how the parallel branches are combined * @param blocks the blocks that form each of the parallel branches */ + @SuppressWarnings("this-escape") public ParallelBlock(Function, NDList> function, List blocks) { super(VERSION); this.function = function; @@ -74,6 +75,7 @@ public ParallelBlock(Function, NDList> function, List blocks * @param blocks the array of blocks to add * @return this block */ + @SuppressWarnings("this-escape") public final ParallelBlock addAll(Block... 
blocks) { return addAll(Arrays.asList(blocks)); } diff --git a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java index f862ee13274..a049c20e2b7 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java @@ -89,6 +89,7 @@ public abstract class Convolution extends AbstractBlock { * * @param builder the {@code Builder} that has the necessary configurations */ + @SuppressWarnings("this-escape") public Convolution(ConvolutionBuilder builder) { super(VERSION); kernelShape = builder.kernelShape; diff --git a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java index 667de724e2a..419780a98d1 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java @@ -62,6 +62,7 @@ public abstract class Deconvolution extends AbstractBlock { * * @param builder the {@code Builder} that has the necessary configurations */ + @SuppressWarnings("this-escape") public Deconvolution(DeconvolutionBuilder builder) { kernelShape = builder.kernelShape; stride = builder.stride; diff --git a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java index d2e0acf8e46..c1c27f57935 100644 --- a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java +++ b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java @@ -38,6 +38,7 @@ public class ConstantEmbedding extends AbstractBlock implements AbstractIndexedE * * @param embedding the value to return for all embeddings */ + @SuppressWarnings("this-escape") public ConstantEmbedding(NDArray embedding) { this.embedding = embedding; freezeParameters(true); diff --git a/api/src/main/java/ai/djl/nn/core/Embedding.java b/api/src/main/java/ai/djl/nn/core/Embedding.java index d6a937fe9a0..ab6167ced2f 100644 --- a/api/src/main/java/ai/djl/nn/core/Embedding.java +++ b/api/src/main/java/ai/djl/nn/core/Embedding.java @@ -49,6 +49,7 @@ public abstract class Embedding extends AbstractBlock implements AbstractInde protected Parameter embedding; + @SuppressWarnings("this-escape") protected Embedding(BaseBuilder baseBuilder) { super(VERSION); embeddingSize = baseBuilder.embeddingSize; @@ -91,6 +92,7 @@ protected Embedding(NDArray embedding) { * @param embedding the embedding array * @param format whether to compute row sparse gradient in the backward calculation */ + @SuppressWarnings("this-escape") protected Embedding(NDArray embedding, SparseFormat format) { super(VERSION); numEmbeddings = Math.toIntExact(embedding.getShape().get(0)); diff --git a/api/src/main/java/ai/djl/nn/core/Linear.java b/api/src/main/java/ai/djl/nn/core/Linear.java index 530344a8858..d10c0a91eb8 100644 --- a/api/src/main/java/ai/djl/nn/core/Linear.java +++ b/api/src/main/java/ai/djl/nn/core/Linear.java @@ -62,6 +62,7 @@ public class Linear extends AbstractBlock { private Parameter weight; private Parameter bias; + @SuppressWarnings("this-escape") protected Linear(Builder builder) { super(VERSION); units = builder.units; diff --git a/api/src/main/java/ai/djl/nn/core/Prelu.java b/api/src/main/java/ai/djl/nn/core/Prelu.java index 8fcb9971330..e70d06a448b 100644 --- a/api/src/main/java/ai/djl/nn/core/Prelu.java +++ b/api/src/main/java/ai/djl/nn/core/Prelu.java @@ -41,6 +41,7 @@ public class Prelu extends AbstractBlock { private Parameter alpha; /** Creates a Parametric ReLU 
Block. */ + @SuppressWarnings("this-escape") public Prelu() { super(VERSION); alpha = diff --git a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java index 5d69284132e..42ab1036aa8 100644 --- a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java @@ -66,6 +66,7 @@ public class LayerNorm extends AbstractBlock { protected Parameter gamma; protected Parameter beta; + @SuppressWarnings("this-escape") protected LayerNorm(Builder builder) { epsilon = builder.epsilon; scale = builder.scale; diff --git a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java index 3c9bb3f89d7..981e4954e7c 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java +++ b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java @@ -58,6 +58,7 @@ public abstract class RecurrentBlock extends AbstractBlock { * * @param builder the {@code Builder} that has the necessary configurations */ + @SuppressWarnings("this-escape") public RecurrentBlock(BaseBuilder builder) { super(VERSION); stateSize = builder.stateSize; diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java index a0b49b9430d..cb02a2f4074 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java @@ -46,6 +46,7 @@ public class BertMaskedLanguageModelBlock extends AbstractBlock { * @param bertBlock the bert block to create the task for * @param hiddenActivation the activation to use for the hidden layer */ + @SuppressWarnings("this-escape") public BertMaskedLanguageModelBlock( BertBlock bertBlock, Function hiddenActivation) { super(VERSION); diff --git a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java index 549d05b629e..4c3bbdb55b8 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java @@ -29,6 +29,7 @@ public class BertNextSentenceBlock extends AbstractBlock { private Linear binaryClassifier; /** Creates a next sentence block. 
*/ + @SuppressWarnings("this-escape") public BertNextSentenceBlock() { binaryClassifier = addChildBlock( diff --git a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java index d196ace2782..8d9cec6c01e 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java @@ -36,6 +36,7 @@ public class BertPretrainingBlock extends AbstractBlock { * * @param builder a builder with a bert configuration */ + @SuppressWarnings("this-escape") public BertPretrainingBlock(final BertBlock.Builder builder) { this.bertBlock = addChildBlock("Bert", builder.build()); this.mlBlock = diff --git a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java index 3b530808bdf..451709c3f74 100644 --- a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java @@ -31,6 +31,7 @@ public class PointwiseFeedForwardBlock extends SequentialBlock { * @param activationFunction the activation function to use for the hidden layers (not applied * to output) */ + @SuppressWarnings("this-escape") public PointwiseFeedForwardBlock( List hiddenSizes, int outputSize, diff --git a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java index bc251d42e86..f01cb1adc33 100644 --- a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java @@ -51,6 +51,7 @@ public class TransformerEncoderBlock extends AbstractBlock { * @param dropoutProbability dropout probability * @param activationFunction activation function */ + @SuppressWarnings("this-escape") public TransformerEncoderBlock( int embeddingSize, int headCount, diff --git a/api/src/main/java/ai/djl/repository/AbstractRepository.java b/api/src/main/java/ai/djl/repository/AbstractRepository.java index 3b83c359aad..c28a3b16887 100644 --- a/api/src/main/java/ai/djl/repository/AbstractRepository.java +++ b/api/src/main/java/ai/djl/repository/AbstractRepository.java @@ -14,13 +14,10 @@ import ai.djl.util.Hex; import ai.djl.util.Progress; +import ai.djl.util.TarUtils; import ai.djl.util.Utils; import ai.djl.util.ZipUtils; -import org.apache.commons.compress.archivers.tar.TarArchiveEntry; -import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; -import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; -import org.apache.commons.compress.utils.CloseShieldFilterInputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -212,9 +209,9 @@ protected void save(InputStream is, Path tmp, Artifact.Item item, Progress progr if ("zip".equals(extension)) { ZipUtils.unzip(pis, dir); } else if ("tgz".equals(extension)) { - untar(pis, dir, true); + TarUtils.untar(pis, dir, true); } else if ("tar".equals(extension)) { - untar(pis, dir, false); + TarUtils.untar(pis, dir, false); } else { throw new IOException("File type is not supported: " + extension); } @@ -233,36 +230,6 @@ protected void save(InputStream is, Path tmp, Artifact.Item item, Progress progr pis.validateChecksum(item); } - private void untar(InputStream is, Path dir, boolean gzip) throws IOException { - InputStream bis; - if (gzip) { - bis = new GzipCompressorInputStream(new 
BufferedInputStream(is)); - } else { - bis = new BufferedInputStream(is); - } - bis = new CloseShieldFilterInputStream(bis); - try (TarArchiveInputStream tis = new TarArchiveInputStream(bis)) { - TarArchiveEntry entry; - while ((entry = tis.getNextTarEntry()) != null) { - String entryName = entry.getName(); - if (entryName.contains("..")) { - throw new IOException("Malicious zip entry: " + entryName); - } - Path file = dir.resolve(entryName).toAbsolutePath(); - if (entry.isDirectory()) { - Files.createDirectories(file); - } else { - Path parentFile = file.getParent(); - if (parentFile == null) { - throw new AssertionError("Parent path should never be null: " + file); - } - Files.createDirectories(parentFile); - Files.copy(tis, file, StandardCopyOption.REPLACE_EXISTING); - } - } - } - } - private static Map parseQueryString(URI uri) { try { Map map = new ConcurrentHashMap<>(); diff --git a/api/src/main/java/ai/djl/repository/Artifact.java b/api/src/main/java/ai/djl/repository/Artifact.java index 4908871497b..a7edceac7e7 100644 --- a/api/src/main/java/ai/djl/repository/Artifact.java +++ b/api/src/main/java/ai/djl/repository/Artifact.java @@ -240,15 +240,13 @@ public void setMetadata(Metadata metadata) { */ public URI getResourceUri() { URI uri = metadata.getRepositoryUri(); - if (properties != null) { - for (String values : properties.values()) { - uri = uri.resolve(values + '/'); - } + if (version != null) { + uri = uri.resolve(version + '/'); } - if (version == null) { - return uri; + if (name != null && !name.isEmpty()) { + uri = uri.resolve(name + '/'); } - return uri.resolve(version + '/'); + return uri; } /** diff --git a/api/src/main/java/ai/djl/repository/JarRepository.java b/api/src/main/java/ai/djl/repository/JarRepository.java index b1b443d01f0..1f610f99bb2 100644 --- a/api/src/main/java/ai/djl/repository/JarRepository.java +++ b/api/src/main/java/ai/djl/repository/JarRepository.java @@ -42,13 +42,16 @@ public class JarRepository extends AbstractRepository { private String artifactId; private String modelName; private String queryString; + private String originalUri; private Metadata metadata; private boolean resolved; - JarRepository(String name, URI uri, String fileName, String queryString) { + JarRepository(String name, URI uri, String fileName, URI realUri) { super(name, uri); - this.queryString = queryString; + this.uri = realUri; + queryString = uri.getRawQuery(); + originalUri = uri.toString(); modelName = arguments.get("model_name"); artifactId = arguments.get("artifact_id"); if (artifactId == null) { @@ -123,8 +126,14 @@ private synchronized Metadata getMetadata() { metadata = new Metadata.MatchAllMetadata(); metadata.setArtifactId(artifactId); metadata.setArtifacts(Collections.singletonList(artifact)); - String hash = - Utils.hash(queryString == null ? uri.toString() : uri.toString() + queryString); + String hashKey; + if (Boolean.parseBoolean(arguments.get("ignore_real_uri"))) { + hashKey = originalUri; + } else { + hashKey = queryString == null ? 
uri.toString() : uri.toString() + queryString; + } + + String hash = Utils.hash(hashKey); MRL mrl = model(Application.UNDEFINED, DefaultModelZoo.GROUP_ID, hash); metadata.setRepositoryUri(mrl.toURI()); diff --git a/api/src/main/java/ai/djl/repository/RemoteRepository.java b/api/src/main/java/ai/djl/repository/RemoteRepository.java index 6b01ce14ef8..52f87afbf82 100644 --- a/api/src/main/java/ai/djl/repository/RemoteRepository.java +++ b/api/src/main/java/ai/djl/repository/RemoteRepository.java @@ -75,7 +75,7 @@ public Metadata locate(MRL mrl) throws IOException { Metadata metadata = JsonUtils.GSON_PRETTY.fromJson(reader, Metadata.class); metadata.init(arguments); Date lastUpdated = metadata.getLastUpdated(); - if (Boolean.getBoolean("offline") + if (Utils.isOfflineMode() || System.currentTimeMillis() - lastUpdated.getTime() < ONE_DAY) { metadata.setRepositoryUri(mrlUri); return metadata; diff --git a/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java b/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java index 71f394d6d14..eb5ab68a703 100644 --- a/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java +++ b/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java @@ -144,7 +144,6 @@ private static final class JarRepositoryFactory implements RepositoryFactory { @Override public Repository newInstance(String name, URI uri) { String p = uri.getPath(); - String queryString = uri.getRawQuery(); if (p.startsWith("/")) { p = p.substring(1); } @@ -152,20 +151,22 @@ public Repository newInstance(String name, URI uri) { if (u == null) { throw new IllegalArgumentException("Resource not found: " + uri); } + + URI realUri; try { - uri = u.toURI(); + // resolve real uri: jar:file:/path/my_lib.jar!/model.zip + realUri = u.toURI(); } catch (URISyntaxException e) { throw new IllegalArgumentException("Resource not found: " + uri, e); } - Path path = Paths.get(parseFilePath(uri)); + Path path = Paths.get(parseFilePath(realUri)); String fileName = path.toFile().getName(); - if (!FilenameUtils.isArchiveFile(fileName)) { - throw new IllegalArgumentException("Only archive file is supported for res URL."); + if (FilenameUtils.isArchiveFile(fileName)) { + fileName = FilenameUtils.getNamePart(fileName); } - fileName = FilenameUtils.getNamePart(fileName); - return new JarRepository(name, uri, fileName, queryString); + return new JarRepository(name, uri, fileName, realUri); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java b/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java index 527871067fa..676bab73d75 100644 --- a/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java +++ b/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java @@ -29,6 +29,7 @@ public class DefaultModelZoo extends ModelZoo { private static final Logger logger = LoggerFactory.getLogger(DefaultModelZoo.class); /** Constructs a new {@code LocalModelZoo} instance. 
*/ + @SuppressWarnings("this-escape") public DefaultModelZoo() { String locations = System.getProperty("ai.djl.repository.zoo.location"); if (locations != null) { @@ -41,6 +42,7 @@ public DefaultModelZoo() { * * @param locations a comma separated urls where the models to be loaded from */ + @SuppressWarnings("this-escape") public DefaultModelZoo(String locations) { parseLocation(locations); } diff --git a/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java b/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java index 50b219be509..e903a1677b3 100644 --- a/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java +++ b/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java @@ -32,6 +32,7 @@ public abstract class ModelZoo { private static final Map MODEL_ZOO_MAP = new ConcurrentHashMap<>(); + private static ModelZooResolver resolver; private Map modelLoaders = new ConcurrentHashMap<>(); @@ -86,6 +87,15 @@ protected final void addModel(ModelLoader loader) { modelLoaders.put(loader.getArtifactId(), loader); } + /** + * Sets the {@code ModelZooResolver}. + * + * @param resolver the {@code ModelZooResolver} + */ + public static void setModelZooResolver(ModelZooResolver resolver) { + ModelZoo.resolver = resolver; + } + /** * Refreshes model zoo. * @@ -112,7 +122,14 @@ public static Collection listModelZoo() { * @return the {@code ModelZoo} with the {@code groupId} */ public static ModelZoo getModelZoo(String groupId) { - return MODEL_ZOO_MAP.get(groupId); + ModelZoo zoo = MODEL_ZOO_MAP.get(groupId); + if (zoo == null && resolver != null) { + zoo = resolver.resolve(groupId); + if (zoo != null) { + MODEL_ZOO_MAP.putIfAbsent(groupId, zoo); + } + } + return zoo; } /** diff --git a/api/src/main/java/ai/djl/repository/zoo/ModelZooResolver.java b/api/src/main/java/ai/djl/repository/zoo/ModelZooResolver.java new file mode 100644 index 00000000000..897e122f191 --- /dev/null +++ b/api/src/main/java/ai/djl/repository/zoo/ModelZooResolver.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.repository.zoo; + +/** An interface that resolves external ModelZoo. */ +public interface ModelZooResolver { + + /** + * Returns {@link ModelZoo} based on model zoo group ID. + * + * @param groupId the model zoo group ID. 
+ * @return the resolved {@code ModelZoo} + */ + ModelZoo resolve(String groupId); +} diff --git a/api/src/main/java/ai/djl/training/ParameterStore.java b/api/src/main/java/ai/djl/training/ParameterStore.java index 7029282c46e..15c83bde8ca 100644 --- a/api/src/main/java/ai/djl/training/ParameterStore.java +++ b/api/src/main/java/ai/djl/training/ParameterStore.java @@ -14,6 +14,7 @@ package ai.djl.training; import ai.djl.Device; +import ai.djl.Device.MultiDevice; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.nn.Parameter; @@ -64,6 +65,10 @@ public void setParameterServer(ParameterServer parameterServer, Device[] devices this.parameterServer = parameterServer; deviceMap.clear(); for (int i = 0; i < devices.length; ++i) { + if (devices[i] instanceof MultiDevice) { + throw new IllegalArgumentException( + "The parameter store does not support MultiDevices"); + } if (deviceMap.put(devices[i], i) != null) { throw new IllegalArgumentException("Duplicated devices are not allowed."); } diff --git a/api/src/main/java/ai/djl/training/Trainer.java b/api/src/main/java/ai/djl/training/Trainer.java index eab6ba07f2a..6d79dde3eec 100644 --- a/api/src/main/java/ai/djl/training/Trainer.java +++ b/api/src/main/java/ai/djl/training/Trainer.java @@ -52,14 +52,12 @@ * * * * @see The guide on memory @@ -88,6 +86,7 @@ public class Trainer implements AutoCloseable { * @param model the model the trainer will train on * @param trainingConfig the configuration used by the trainer */ + @SuppressWarnings("this-escape") public Trainer(Model model, TrainingConfig trainingConfig) { this.model = model; manager = model.getNDManager().newSubManager(); diff --git a/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java b/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java index c9a5fdf7036..8610f9e92bb 100644 --- a/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java +++ b/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java @@ -77,9 +77,22 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { Pair update = accuracyHelper(labels, predictions); - totalInstances.compute(key, (k, v) -> v + update.getKey()); - correctInstances.compute(key, (k, v) -> v + update.getValue().sum().getLong()); + NDArray value = update.getValue(); + NDArray sum = value.sum(); + long correct = sum.getLong(); + for (String key : keys) { + totalInstances.compute(key, (k, v) -> v + update.getKey()); + correctInstances.compute(key, (k, v) -> v + correct); + } + value.close(); + sum.close(); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java b/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java index 4af9e5de3d1..ab2d554142d 100644 --- a/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java +++ b/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java @@ -63,10 +63,18 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { 
NDArray boundingBoxError = evaluate(labels, predictions); float update = boundingBoxError.sum().getFloat(); - totalInstances.compute(key, (k, v) -> v + boundingBoxError.size()); - ssdBoxPredictionError.compute(key, (k, v) -> v + update); + for (String key : keys) { + totalInstances.compute(key, (k, v) -> v + boundingBoxError.size()); + ssdBoxPredictionError.compute(key, (k, v) -> v + update); + } } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/training/evaluator/Evaluator.java b/api/src/main/java/ai/djl/training/evaluator/Evaluator.java index 6d2c5995601..c373471f6cf 100644 --- a/api/src/main/java/ai/djl/training/evaluator/Evaluator.java +++ b/api/src/main/java/ai/djl/training/evaluator/Evaluator.java @@ -74,6 +74,25 @@ public String getName() { */ public abstract void addAccumulator(String key); + /** + * Updates the evaluator with the given keys based on a {@link NDList} of labels and + * predictions. + * + *

This is a synchronized operation. You should only call it at the end of a batch or epoch. + * + *

This is an alternative to @{link {@link #updateAccumulator(String, NDList, NDList)}} that + * may be more efficient when updating multiple accumulators at once. + * + * @param keys the keys of all the accumulators to update + * @param labels a {@code NDList} of labels + * @param predictions a {@code NDList} of predictions + */ + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { + for (String key : keys) { + updateAccumulator(key, labels, predictions); + } + } + /** * Updates the evaluator with the given key based on a {@link NDList} of labels and predictions. * diff --git a/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java b/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java index a7fe08b610e..aa12cae628c 100644 --- a/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java +++ b/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java @@ -67,6 +67,12 @@ public void updateAccumulator(String key, NDList labels, NDList predictions) { evaluator.updateAccumulator(key, getLabels(labels), getPredictions(predictions)); } + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { + evaluator.updateAccumulators(keys, getLabels(labels), getPredictions(predictions)); + } + /** {@inheritDoc} */ @Override public void resetAccumulator(String key) { diff --git a/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java new file mode 100644 index 00000000000..6c013c37715 --- /dev/null +++ b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java @@ -0,0 +1,281 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.training.listener; + +import ai.djl.training.Trainer; +import ai.djl.training.TrainingResult; + +import java.time.Duration; + +/** + * Listener that allows the training to be stopped early if the validation loss is not improving, or + * if time has expired.
+ * + *

<p>Usage: Add this listener to the training config, and add it as the last one.
+ *
+ * <pre>
+ *  new DefaultTrainingConfig(...)
+ *        .addTrainingListeners(EarlyStoppingListener.builder()
+ *                .optEpochPatience(1)
+ *                .optEarlyStopPctImprovement(1)
+ *                .optMaxDuration(Duration.ofMinutes(42))
+ *                .optMinEpochs(1)
+ *                .build()
+ *        );
+ * </pre>
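+ *
+ * <p>With the settings above, training always runs at least one epoch and is then stopped the first
+ * time an epoch fails to improve the validation loss (or the training loss, when no validation loss
+ * is reported) by at least 1%, or once 42 minutes have elapsed; both conditions are checked at the
+ * end of each epoch.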
+ * + *

<p>Then surround the fit with a try catch that catches the {@link
+ * EarlyStoppingListener.EarlyStoppedException}.
+ * Example:
+ *
+ * <pre>
+ * try {
+ *   EasyTrain.fit(trainer, 5, trainDataset, testDataset);
+ * } catch (EarlyStoppingListener.EarlyStoppedException e) {
+ *   // handle early stopping
+ *   log.info("Stopped early at epoch {} because: {}", e.getStopEpoch(), e.getMessage());
+ * }
+ * </pre>
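+ *
+ * <p>{@code EarlyStoppedException} is a {@code RuntimeException}, so no {@code throws} declaration is
+ * needed around the call to {@code fit}; the exception message carries the stop reason and
+ * {@code getStopEpoch()} reports the epoch at which training ended.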
+ * + *
+ * Note: Ensure that Metrics are set on the trainer. + */ +public final class EarlyStoppingListener implements TrainingListener { + private final double objectiveSuccess; + + private final int minEpochs; + private final long maxMillis; + private final double earlyStopPctImprovement; + private final int epochPatience; + + private long startTimeMills; + private double prevLoss; + private int numberOfEpochsWithoutImprovements; + + private EarlyStoppingListener( + double objectiveSuccess, + int minEpochs, + long maxMillis, + double earlyStopPctImprovement, + int earlyStopPatience) { + this.objectiveSuccess = objectiveSuccess; + this.minEpochs = minEpochs; + this.maxMillis = maxMillis; + this.earlyStopPctImprovement = earlyStopPctImprovement; + this.epochPatience = earlyStopPatience; + } + + /** {@inheritDoc} */ + @Override + public void onEpoch(Trainer trainer) { + int currentEpoch = trainer.getTrainingResult().getEpoch(); + // stopping criteria + final double loss = getLoss(trainer.getTrainingResult()); + if (currentEpoch >= minEpochs) { + if (loss < objectiveSuccess) { + throw new EarlyStoppedException( + currentEpoch, + String.format( + "validation loss %s < objectiveSuccess %s", + loss, objectiveSuccess)); + } + long elapsedMillis = System.currentTimeMillis() - startTimeMills; + if (elapsedMillis >= maxMillis) { + throw new EarlyStoppedException( + currentEpoch, + String.format("%s ms elapsed >= %s maxMillis", elapsedMillis, maxMillis)); + } + // consider early stopping? + if (Double.isFinite(prevLoss)) { + double goalImprovement = prevLoss * (100 - earlyStopPctImprovement) / 100.0; + boolean improved = loss <= goalImprovement; // false if any NANs + if (improved) { + numberOfEpochsWithoutImprovements = 0; + } else { + numberOfEpochsWithoutImprovements++; + if (numberOfEpochsWithoutImprovements >= epochPatience) { + throw new EarlyStoppedException( + currentEpoch, + String.format( + "failed to achieve %s%% improvement %s times in a row", + earlyStopPctImprovement, epochPatience)); + } + } + } + } + if (Double.isFinite(loss)) { + prevLoss = loss; + } + } + + private static double getLoss(TrainingResult trainingResult) { + Float vLoss = trainingResult.getValidateLoss(); + if (vLoss != null) { + return vLoss; + } + Float tLoss = trainingResult.getTrainLoss(); + if (tLoss == null) { + return Double.NaN; + } + return tLoss; + } + + /** {@inheritDoc} */ + @Override + public void onTrainingBatch(Trainer trainer, BatchData batchData) { + // do nothing + } + + /** {@inheritDoc} */ + @Override + public void onValidationBatch(Trainer trainer, BatchData batchData) { + // do nothing + } + + /** {@inheritDoc} */ + @Override + public void onTrainingBegin(Trainer trainer) { + this.startTimeMills = System.currentTimeMillis(); + this.prevLoss = Double.NaN; + this.numberOfEpochsWithoutImprovements = 0; + } + + /** {@inheritDoc} */ + @Override + public void onTrainingEnd(Trainer trainer) { + // do nothing + } + + /** + * Creates a builder to build a {@link EarlyStoppingListener}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** A builder for a {@link EarlyStoppingListener}. */ + public static final class Builder { + private final double objectiveSuccess; + private int minEpochs; + private long maxMillis; + private double earlyStopPctImprovement; + private int epochPatience; + + /** Constructs a {@link Builder} with default values. 
*/ + public Builder() { + this.objectiveSuccess = 0; + this.minEpochs = 0; + this.maxMillis = Long.MAX_VALUE; + this.earlyStopPctImprovement = 0; + this.epochPatience = 0; + } + + /** + * Set the minimum # epochs, defaults to 0. + * + * @param minEpochs the minimum # epochs + * @return this builder + */ + public Builder optMinEpochs(int minEpochs) { + this.minEpochs = minEpochs; + return this; + } + + /** + * Set the maximum duration a training run should take, defaults to Long.MAX_VALUE in ms. + * + * @param duration the maximum duration a training run should take + * @return this builder + */ + public Builder optMaxDuration(Duration duration) { + this.maxMillis = duration.toMillis(); + return this; + } + + /** + * Set the maximum # milliseconds a training run should take, defaults to Long.MAX_VALUE. + * + * @param maxMillis the maximum # milliseconds a training run should take + * @return this builder + */ + public Builder optMaxMillis(int maxMillis) { + this.maxMillis = maxMillis; + return this; + } + + /** + * Consider early stopping if not x% improvement, defaults to 0. + * + * @param earlyStopPctImprovement the percentage improvement to consider early stopping, + * must be between 0 and 100. + * @return this builder + */ + public Builder optEarlyStopPctImprovement(double earlyStopPctImprovement) { + this.earlyStopPctImprovement = earlyStopPctImprovement; + return this; + } + + /** + * Stop if insufficient improvement for x epochs in a row, defaults to 0. + * + * @param epochPatience the number of epochs without improvement to consider stopping, must + * be greater than 0. + * @return this builder + */ + public Builder optEpochPatience(int epochPatience) { + this.epochPatience = epochPatience; + return this; + } + + /** + * Builds a {@link EarlyStoppingListener} with the specified values. + * + * @return a new {@link EarlyStoppingListener} + */ + public EarlyStoppingListener build() { + return new EarlyStoppingListener( + objectiveSuccess, minEpochs, maxMillis, earlyStopPctImprovement, epochPatience); + } + } + + /** + * Thrown when training is stopped early, the message will contain the reason why it is stopped + * early. + */ + public static class EarlyStoppedException extends RuntimeException { + private static final long serialVersionUID = 1L; + private final int stopEpoch; + + /** + * Constructs an {@link EarlyStoppedException} with the specified message and epoch. + * + * @param stopEpoch the epoch at which training was stopped early + * @param message the message/reason why training was stopped early + */ + public EarlyStoppedException(int stopEpoch, String message) { + super(message); + this.stopEpoch = stopEpoch; + } + + /** + * Gets the epoch at which training was stopped early. + * + * @return the epoch at which training was stopped early. 
+ */ + public int getStopEpoch() { + return stopEpoch; + } + } +} diff --git a/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java b/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java index 1dbfe4117cd..2556a026259 100644 --- a/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java +++ b/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java @@ -144,9 +144,7 @@ private void updateEvaluators(Trainer trainer, BatchData batchData, String[] acc for (Device device : batchData.getLabels().keySet()) { NDList labels = batchData.getLabels().get(device); NDList predictions = batchData.getPredictions().get(device); - for (String accumulator : accumulators) { - evaluator.updateAccumulator(accumulator, labels, predictions); - } + evaluator.updateAccumulators(accumulators, labels, predictions); } } } diff --git a/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java b/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java index 2a46416190a..2e2cdcb8c86 100644 --- a/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java +++ b/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java @@ -80,10 +80,10 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override - public void updateAccumulator(String key, NDList labels, NDList predictions) { + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { for (int i = 0; i < components.size(); i++) { Pair inputs = inputForComponent(i, labels, predictions); - components.get(i).updateAccumulator(key, inputs.getKey(), inputs.getValue()); + components.get(i).updateAccumulators(keys, inputs.getKey(), inputs.getValue()); } } diff --git a/api/src/main/java/ai/djl/training/loss/Loss.java b/api/src/main/java/ai/djl/training/loss/Loss.java index a661a3e9a0e..bcf39d23b39 100644 --- a/api/src/main/java/ai/djl/training/loss/Loss.java +++ b/api/src/main/java/ai/djl/training/loss/Loss.java @@ -385,10 +385,18 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { // this is a synchronized operation, only call it at end of batch or epoch float update = evaluate(labels, predictions).sum().getFloat(); - totalInstances.compute(key, (k, v) -> v + 1); - totalLoss.compute(key, (k, v) -> v + update); + for (String key : keys) { + totalInstances.compute(key, (k, v) -> v + 1); + totalLoss.compute(key, (k, v) -> v + update); + } } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/training/tracker/LinearTracker.java b/api/src/main/java/ai/djl/training/tracker/LinearTracker.java index 986117d2b65..08bee48da87 100644 --- a/api/src/main/java/ai/djl/training/tracker/LinearTracker.java +++ b/api/src/main/java/ai/djl/training/tracker/LinearTracker.java @@ -12,7 +12,6 @@ */ package ai.djl.training.tracker; -import ai.djl.training.tracker.WarmUpTracker.Builder; import ai.djl.util.Preconditions; /** diff --git a/api/src/main/java/ai/djl/training/util/ProgressBar.java b/api/src/main/java/ai/djl/training/util/ProgressBar.java index 6300116dc5b..ae36640f01d 100644 --- a/api/src/main/java/ai/djl/training/util/ProgressBar.java +++ b/api/src/main/java/ai/djl/training/util/ProgressBar.java @@ -29,10 +29,14 @@ public final class ProgressBar implements 
Progress { private long progress; private int currentPercent; private char progressChar = getProgressChar(); + private boolean disableProgressBar; /** Creates an instance of {@code ProgressBar} with a maximum value of 1. */ public ProgressBar() { max = 1; + disableProgressBar = + Boolean.parseBoolean(Utils.getEnvOrSystemProperty("DJL_DISABLE_PROGRESS_BAR")) + || Boolean.getBoolean("disableProgressBar"); } /** @@ -43,6 +47,7 @@ public ProgressBar() { * @param max the maximum value */ public ProgressBar(String message, long max) { + this(); reset(message, max); } @@ -55,6 +60,7 @@ public ProgressBar(String message, long max) { * @param trailingMessage the trailing message to be shown */ public ProgressBar(String message, long max, String trailingMessage) { + this(); reset(message, max); this.trailingMessage = trailingMessage; } @@ -91,7 +97,7 @@ public void increment(long increment) { @Override @SuppressWarnings("PMD.SystemPrintln") public void update(long progress, String additionalMessage) { - if (Boolean.getBoolean("disableProgressBar") || max <= 1) { + if (disableProgressBar || max <= 1) { return; } diff --git a/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java b/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java index 3f3bb1b2d6e..f026bd431c9 100644 --- a/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java +++ b/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java @@ -29,10 +29,17 @@ public final class PaddingStackBatchifier implements Batchifier { private static final long serialVersionUID = 1L; + @SuppressWarnings("serial") private List arraysToPad; + + @SuppressWarnings("serial") private List dimsToPad; + private transient List paddingSuppliers; + + @SuppressWarnings("serial") private List paddingSizes; + private boolean includeValidLengths; private PaddingStackBatchifier(Builder builder) { diff --git a/api/src/main/java/ai/djl/util/Ec2Utils.java b/api/src/main/java/ai/djl/util/Ec2Utils.java index 178c3d7efe7..5408182964f 100644 --- a/api/src/main/java/ai/djl/util/Ec2Utils.java +++ b/api/src/main/java/ai/djl/util/Ec2Utils.java @@ -97,7 +97,7 @@ public static String readMetadata(String key) { * @param engine the default engine name */ public static void callHome(String engine) { - if (Boolean.getBoolean("offline") + if (Utils.isOfflineMode() || Boolean.parseBoolean(Utils.getEnvOrSystemProperty("OPT_OUT_TRACKING")) || System.currentTimeMillis() - lastCheckIn < ONE_DAY) { return; diff --git a/api/src/main/java/ai/djl/util/StringPair.java b/api/src/main/java/ai/djl/util/StringPair.java new file mode 100644 index 00000000000..a42e739614b --- /dev/null +++ b/api/src/main/java/ai/djl/util/StringPair.java @@ -0,0 +1,27 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util; + +/** A class containing the string key-value pair. */ +public class StringPair extends Pair { + + /** + * Constructs a {@code Pair} instance with key and value. 
+ * + * @param key the key + * @param value the value + */ + public StringPair(String key, String value) { + super(key, value); + } +} diff --git a/api/src/main/java/ai/djl/util/TarUtils.java b/api/src/main/java/ai/djl/util/TarUtils.java new file mode 100644 index 00000000000..d4a6e42b230 --- /dev/null +++ b/api/src/main/java/ai/djl/util/TarUtils.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util; + +import org.apache.commons.compress.archivers.tar.TarArchiveEntry; +import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; +import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; +import org.apache.commons.io.input.CloseShieldInputStream; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; + +/** Utilities for working with tar files. */ +public final class TarUtils { + + private TarUtils() {} + + /** + * Un-compresses a tar ball from an InputStream. + * + * @param is the InputStream + * @param dir the target directory + * @param gzip if the tar ball is gzip compressed + * @throws IOException for failures to untar the input stream + */ + public static void untar(InputStream is, Path dir, boolean gzip) throws IOException { + InputStream bis; + if (gzip) { + bis = new GzipCompressorInputStream(new BufferedInputStream(is)); + } else { + bis = new BufferedInputStream(is); + } + bis = CloseShieldInputStream.wrap(bis); + try (TarArchiveInputStream tis = new TarArchiveInputStream(bis)) { + TarArchiveEntry entry; + while ((entry = tis.getNextEntry()) != null) { + String entryName = ZipUtils.removeLeadingFileSeparator(entry.getName()); + if (entryName.contains("..")) { + throw new IOException("Malicious tar entry: " + entryName); + } + Path file = dir.resolve(entryName).toAbsolutePath(); + if (entry.isDirectory()) { + Files.createDirectories(file); + } else { + Path parentFile = file.getParent(); + if (parentFile == null) { + throw new AssertionError("Parent path should never be null: " + file); + } + Files.createDirectories(parentFile); + Files.copy(tis, file, StandardCopyOption.REPLACE_EXISTING); + } + } + } + } +} diff --git a/api/src/main/java/ai/djl/util/Utils.java b/api/src/main/java/ai/djl/util/Utils.java index c8e1bd514ac..270958d5b40 100644 --- a/api/src/main/java/ai/djl/util/Utils.java +++ b/api/src/main/java/ai/djl/util/Utils.java @@ -357,6 +357,20 @@ public static Path getCacheDir() { return Paths.get(cacheDir); } + /** + * Returns whether offline mode is enabled. + * + * @return true if offline mode is enabled + */ + public static boolean isOfflineMode() { + String mode = getenv("DJL_OFFLINE", System.getProperty("ai.djl.offline")); + if (mode != null) { + return Boolean.parseBoolean(mode); + } + // backward compatible + return Boolean.getBoolean("offline"); + } + /** * Returns nested model directory if the directory contains only one subdirectory.
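A minimal usage sketch for the two helpers added above, the new `Utils.isOfflineMode()` check and `TarUtils.untar(...)`; the archive path and class name below are hypothetical and not part of this change:

```java
import ai.djl.util.TarUtils;
import ai.djl.util.Utils;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public final class OfflineModeSketch {

    public static void main(String[] args) throws IOException {
        // Offline mode can be switched on via the DJL_OFFLINE environment variable or the
        // "ai.djl.offline" system property; the legacy "offline" property is still honored.
        System.setProperty("ai.djl.offline", "true");

        if (Utils.isOfflineMode()) {
            // Hypothetical local archive used instead of a remote download while offline.
            Path archive = Paths.get("/tmp/my-model.tar.gz");
            Path target = Files.createTempDirectory("model");
            try (InputStream is = Files.newInputStream(archive)) {
                TarUtils.untar(is, target, true); // true: the tar ball is gzip compressed
            }
        }
    }
}
```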
* @@ -481,7 +495,7 @@ public static InputStream openUrl(String url) throws IOException { */ public static InputStream openUrl(URL url) throws IOException { String protocol = url.getProtocol(); - if (Boolean.getBoolean("offline") + if (isOfflineMode() && ("http".equalsIgnoreCase(protocol) || "https".equalsIgnoreCase(protocol))) { throw new IOException("Offline model is enabled."); } diff --git a/api/src/main/java/ai/djl/util/ZipUtils.java b/api/src/main/java/ai/djl/util/ZipUtils.java index f1a4889af0b..7c8c298a6cb 100644 --- a/api/src/main/java/ai/djl/util/ZipUtils.java +++ b/api/src/main/java/ai/djl/util/ZipUtils.java @@ -52,7 +52,7 @@ public static void unzip(InputStream is, Path dest) throws IOException { ZipEntry entry; Set set = new HashSet<>(); while ((entry = zis.getNextEntry()) != null) { - String name = entry.getName(); + String name = removeLeadingFileSeparator(entry.getName()); if (name.contains("..")) { throw new IOException("Malicious zip entry: " + name); } @@ -121,6 +121,16 @@ private static void addToZip(Path root, Path file, ZipOutputStream zos) throws I } } + static String removeLeadingFileSeparator(String name) { + int index = 0; + for (; index < name.length(); index++) { + if (name.charAt(index) != File.separatorChar) { + break; + } + } + return name.substring(index); + } + private static final class ValidationInputStream extends FilterInputStream { private static final int ZIP64_LOCSIG = 0x07064b50; // "PK\006\007" @@ -223,7 +233,7 @@ private End findEND(ByteBuffer bb) throws IOException { // Let's do some extra verification, we don't care about the // performance in this situation. int cenpos = end.endpos - end.cenlen; - int locpos = cenpos - end.cenoff; + int locpos = Math.toIntExact(cenpos - end.cenoff); if (cenpos < 0 || locpos < 0 || bb.getInt(cenpos) != CENSIG @@ -243,7 +253,7 @@ private End findEND(ByteBuffer bb) throws IOException { // end64 candidate found, int cenlen64 = Math.toIntExact(bb.getLong(relativePos + 40)); - int cenoff64 = Math.toIntExact(bb.getLong(relativePos + 48)); + long cenoff64 = bb.getLong(relativePos + 48); // double-check if (cenlen64 != end.cenlen && end.cenlen > 0 || cenoff64 != end.cenoff && end.cenoff > 0) { @@ -303,7 +313,7 @@ private List initCEN(byte[] header) throws IOException { private static final class End { int cenlen; // 4 bytes - int cenoff; // 4 bytes + long cenoff; // 4 bytes int endpos; // 4 bytes } } diff --git a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java index b0b8e3e4247..1de074ea6c8 100644 --- a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java +++ b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java @@ -22,7 +22,11 @@ import org.slf4j.LoggerFactory; import java.io.File; +import java.io.IOException; +import java.io.InputStream; import java.lang.management.MemoryUsage; +import java.util.ArrayList; +import java.util.List; import java.util.Locale; import java.util.regex.Pattern; @@ -33,6 +37,8 @@ public final class CudaUtils { private static final CudaLibrary LIB = loadLibrary(); + private static String[] gpuInfo; + private CudaUtils() {} /** @@ -49,7 +55,15 @@ public static boolean hasCuda() { * * @return the number of GPUs available in the system */ + @SuppressWarnings("PMD.NonThreadSafeSingleton") public static int getGpuCount() { + if (Boolean.getBoolean("ai.djl.util.cuda.fork")) { + if (gpuInfo == null) { + gpuInfo = execute(-1); // NOPMD + } + return Integer.parseInt(gpuInfo[0]); + } + if (LIB == null) { return 0; } @@ -79,7 +93,19 @@ public static int 
getGpuCount() { * * @return the version of CUDA runtime */ + @SuppressWarnings("PMD.NonThreadSafeSingleton") public static int getCudaVersion() { + if (Boolean.getBoolean("ai.djl.util.cuda.fork")) { + if (gpuInfo == null) { + gpuInfo = execute(-1); + } + int version = Integer.parseInt(gpuInfo[1]); + if (version == -1) { + throw new IllegalArgumentException("No cuda device found."); + } + return version; + } + if (LIB == null) { throw new IllegalStateException("No cuda library is loaded."); } @@ -95,9 +121,6 @@ public static int getCudaVersion() { * @return the version string of CUDA runtime */ public static String getCudaVersionString() { - if (LIB == null) { - throw new IllegalStateException("No cuda library is loaded."); - } int version = getCudaVersion(); int major = version / 1000; int minor = (version / 10) % 10; @@ -111,6 +134,14 @@ public static String getCudaVersionString() { * @return the CUDA compute capability */ public static String getComputeCapability(int device) { + if (Boolean.getBoolean("ai.djl.util.cuda.fork")) { + String[] ret = execute(device); + if (ret.length != 3) { + throw new IllegalArgumentException(ret[0]); + } + return ret[0]; + } + if (LIB == null) { throw new IllegalStateException("No cuda library is loaded."); } @@ -137,6 +168,16 @@ public static MemoryUsage getGpuMemory(Device device) { throw new IllegalArgumentException("Only GPU device is allowed."); } + if (Boolean.getBoolean("ai.djl.util.cuda.fork")) { + String[] ret = execute(device.getDeviceId()); + if (ret.length != 3) { + throw new IllegalArgumentException(ret[0]); + } + long total = Long.parseLong(ret[1]); + long used = Long.parseLong(ret[2]); + return new MemoryUsage(-1, used, used, total); + } + if (LIB == null) { throw new IllegalStateException("No GPU device detected."); } @@ -155,8 +196,42 @@ public static MemoryUsage getGpuMemory(Device device) { return new MemoryUsage(-1, committed, committed, total[0]); } + /** + * The main entrypoint for querying CUDA information from the command line. + * + * @param args the command line arguments.
+ */ + @SuppressWarnings("PMD.SystemPrintln") + public static void main(String[] args) { + int gpuCount = getGpuCount(); + if (args.length == 0) { + if (gpuCount <= 0) { + System.out.println("0,-1"); + return; + } + int cudaVersion = getCudaVersion(); + System.out.println(gpuCount + "," + cudaVersion); + return; + } + try { + int deviceId = Integer.parseInt(args[0]); + if (deviceId < 0 || deviceId >= gpuCount) { + System.out.println("Invalid device: " + deviceId); + return; + } + MemoryUsage mem = getGpuMemory(Device.gpu(deviceId)); + String cc = getComputeCapability(deviceId); + System.out.println(cc + ',' + mem.getMax() + ',' + mem.getUsed()); + } catch (NumberFormatException e) { + System.out.println("Invalid device: " + args[0]); + } + } + private static CudaLibrary loadLibrary() { try { + if (Boolean.getBoolean("ai.djl.util.cuda.fork")) { + return null; + } if (System.getProperty("os.name").startsWith("Win")) { String path = Utils.getenv("PATH"); if (path == null) { @@ -187,15 +262,40 @@ private static CudaLibrary loadLibrary() { } catch (UnsatisfiedLinkError e) { logger.debug("cudart library not found."); logger.trace("", e); - return null; - } catch (IncompatibleClassChangeError e) { + } catch (LinkageError e) { logger.warn("You have a conflict version of JNA in the classpath."); logger.debug("", e); - return null; } catch (SecurityException e) { logger.warn("Access denied during loading cudart library."); logger.trace("", e); - return null; + } + return null; + } + + private static String[] execute(int deviceId) { + try { + String javaHome = System.getProperty("java.home"); + String classPath = System.getProperty("java.class.path"); + String os = System.getProperty("os.name"); + List cmd = new ArrayList<>(4); + if (os.startsWith("Win")) { + cmd.add(javaHome + "\\bin\\java.exe"); + } else { + cmd.add(javaHome + "/bin/java"); + } + cmd.add("-cp"); + cmd.add(classPath); + cmd.add("ai.djl.util.cuda.CudaUtils"); + if (deviceId >= 0) { + cmd.add(String.valueOf(deviceId)); + } + Process ps = new ProcessBuilder(cmd).redirectErrorStream(true).start(); + try (InputStream is = ps.getInputStream()) { + String line = Utils.toString(is).trim(); + return line.split(","); + } + } catch (IOException e) { + throw new IllegalArgumentException("Failed to get GPU information", e); } } diff --git a/api/src/test/java/ai/djl/DeviceTest.java b/api/src/test/java/ai/djl/DeviceTest.java index 92a0474c6e7..a69a502739b 100644 --- a/api/src/test/java/ai/djl/DeviceTest.java +++ b/api/src/test/java/ai/djl/DeviceTest.java @@ -13,6 +13,7 @@ package ai.djl; +import ai.djl.Device.MultiDevice; import ai.djl.engine.Engine; import org.testng.Assert; @@ -37,6 +38,9 @@ public void testDevice() { System.setProperty("test_key", "test"); Engine.debugEnvironment(); + + Assert.assertEquals(1, Device.cpu().getDevices().size()); + Assert.assertEquals(2, new MultiDevice(Device.gpu(1), Device.gpu(2)).getDevices().size()); } @Test @@ -54,5 +58,9 @@ public void testDeviceName() { Device defaultDevice = Engine.getInstance().defaultDevice(); Assert.assertEquals(Device.fromName(""), defaultDevice); Assert.assertEquals(Device.fromName(null), defaultDevice); + + Assert.assertEquals( + Device.fromName("gpu1+gpu2"), new MultiDevice(Device.gpu(2), Device.gpu(1))); + Assert.assertEquals(Device.fromName("gpu1+gpu2"), new MultiDevice("gpu", 1, 3)); } } diff --git a/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java b/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java index
8c140688124..a8b2bdfab62 100644 --- a/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java +++ b/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java @@ -15,32 +15,38 @@ import org.testng.Assert; import org.testng.annotations.Test; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; public class PublisherBytesSupplierTest { @Test - public void test() { + public void test() throws ExecutionException, InterruptedException { AtomicInteger contentCount = new AtomicInteger(); PublisherBytesSupplier supplier = new PublisherBytesSupplier(); - // Add to supplier without subscriber - supplier.appendContent(new byte[] {1}, false); - Assert.assertEquals(contentCount.get(), 0); + new Thread( + () -> { + // Add to supplier without subscriber + supplier.appendContent(new byte[] {1}, false); + // Add to supplier with subscriber + supplier.appendContent(new byte[] {1}, true); + }) + .start(); // Subscribing with data should trigger subscriptions - supplier.subscribe( - d -> { - if (d == null) { - // Do nothing on completion - return; - } - contentCount.getAndIncrement(); - }); - Assert.assertEquals(contentCount.get(), 1); + CompletableFuture future = + supplier.subscribe( + d -> { + if (d == null) { + // Do nothing on completion + return; + } + contentCount.getAndIncrement(); + }); - // Add to supplier with subscriber - supplier.appendContent(new byte[] {1}, true); + future.get(); Assert.assertEquals(contentCount.get(), 2); } } diff --git a/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java new file mode 100644 index 00000000000..8fbbae7301b --- /dev/null +++ b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.translate.BasicTranslator; +import ai.djl.translate.Translator; + +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.InputStream; +import java.net.URL; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +public class YoloV8TranslatorFactoryTest { + + private YoloV8TranslatorFactory factory; + + @BeforeClass + public void setUp() { + factory = new YoloV8TranslatorFactory(); + } + + @Test + public void testGetSupportedTypes() { + Assert.assertEquals(factory.getSupportedTypes().size(), 5); + } + + @Test + public void testNewInstance() { + Map arguments = new HashMap<>(); + try (Model model = Model.newInstance("test")) { + Translator translator1 = + factory.newInstance(Image.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator1 instanceof YoloV8Translator); + + Translator translator2 = + factory.newInstance(Path.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator2 instanceof BasicTranslator); + + Translator translator3 = + factory.newInstance(URL.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator3 instanceof BasicTranslator); + + Translator translator4 = + factory.newInstance(InputStream.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator4 instanceof BasicTranslator); + + Translator translator5 = + factory.newInstance(Input.class, Output.class, model, arguments); + Assert.assertTrue(translator5 instanceof ImageServingTranslator); + + Assert.assertThrows( + IllegalArgumentException.class, + () -> factory.newInstance(Image.class, Output.class, model, arguments)); + } + } +} diff --git a/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java b/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java index 0e38c2d8be6..98ba896883b 100644 --- a/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java +++ b/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java @@ -97,6 +97,23 @@ public void testNDSerializer() throws IOException { } } + @Test + public void testStringTensor() { + try (NDManager manager = NDManager.newBaseManager("PyTorch")) { + NDArray array = manager.create("hello"); + byte[] buf = array.encode(); + NDArray decoded = NDArray.decode(manager, buf); + Assert.assertTrue(decoded.getShape().isScalar()); + + array = manager.create(new String[] {"hello", "world"}); + buf = array.encode(); + decoded = NDArray.decode(manager, buf); + Assert.assertEquals(decoded.getShape(), array.getShape()); + Assert.assertEquals(decoded.toStringArray()[1], "world"); + Assert.assertEquals(decoded, array); + } + } + private static byte[] encode(NDArray array) throws IOException { try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { NDSerializer.encodeAsNumpy(array, bos); @@ -107,7 +124,7 @@ private static byte[] encode(NDArray array) throws IOException { private static NDArray decode(NDManager manager, byte[] data) throws IOException { try (ByteArrayInputStream bis = new ByteArrayInputStream(data)) { - return NDSerializer.decodeNumpy(manager, bis); + return NDList.decode(manager, bis).get(0); } } diff --git a/api/src/test/java/ai/djl/repository/JarRepositoryTest.java b/api/src/test/java/ai/djl/repository/JarRepositoryTest.java index 4599214fab5..c1370d1da69 100644 
--- a/api/src/test/java/ai/djl/repository/JarRepositoryTest.java +++ b/api/src/test/java/ai/djl/repository/JarRepositoryTest.java @@ -45,7 +45,7 @@ public void testResource() throws IOException { URL[] url = {jarFile.toUri().toURL()}; try { Thread.currentThread().setContextClassLoader(new URLClassLoader(url)); - Repository repo = Repository.newInstance("test", "jar:///test.zip?hash=1"); + Repository repo = Repository.newInstance("test", "jar:///test.zip"); Assert.assertEquals("test", repo.getName()); Assert.assertTrue(repo.isRemote()); @@ -55,6 +55,12 @@ public void testResource() throws IOException { Artifact artifact = repo.resolve(list.get(0), null); repo.prepare(artifact); Assert.assertEquals(1, artifact.getFiles().size()); + + repo = Repository.newInstance("test", "jar:///test.zip?ignore_real_uri=true"); + list = repo.getResources(); + artifact = repo.resolve(list.get(0), null); + Path p = repo.getResourceDirectory(artifact); + Assert.assertFalse(Files.exists(p)); } finally { Thread.currentThread().setContextClassLoader(null); } diff --git a/api/src/test/java/ai/djl/repository/ZooTest.java b/api/src/test/java/ai/djl/repository/ZooTest.java index 2b44f967144..29fc10391aa 100644 --- a/api/src/test/java/ai/djl/repository/ZooTest.java +++ b/api/src/test/java/ai/djl/repository/ZooTest.java @@ -17,6 +17,7 @@ import ai.djl.modality.Output; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; import org.testng.Assert; import org.testng.annotations.Test; @@ -48,4 +49,11 @@ public void testInvalidCriteria() Criteria criteria = Criteria.builder().build(); criteria.loadModel(); } + + @Test + public void testModelZooResolver() { + ModelZoo.setModelZooResolver(groupId -> null); + ModelZoo zoo = ModelZoo.getModelZoo("unknown"); + Assert.assertNull(zoo); + } } diff --git a/api/src/test/java/ai/djl/translate/TranslatorTest.java b/api/src/test/java/ai/djl/translate/TranslatorTest.java index 5ce63faa29d..1a636af7787 100644 --- a/api/src/test/java/ai/djl/translate/TranslatorTest.java +++ b/api/src/test/java/ai/djl/translate/TranslatorTest.java @@ -83,6 +83,10 @@ public void testBatchTranslator() throws IOException, ModelException, TranslateE Predictor predictor = model.newPredictor()) { Classifications[] res = predictor.predict(inputs); Assert.assertEquals(res.length, 2); + int intValue = model.intProperty("something", -1); + Assert.assertEquals(intValue, -1); + long longValue = model.longProperty("something", -1L); + Assert.assertEquals(longValue, -1L); } } } diff --git a/api/src/test/java/ai/djl/util/SecurityManagerTest.java b/api/src/test/java/ai/djl/util/SecurityManagerTest.java index fd9b5db72bc..1e9eb17f63c 100644 --- a/api/src/test/java/ai/djl/util/SecurityManagerTest.java +++ b/api/src/test/java/ai/djl/util/SecurityManagerTest.java @@ -74,8 +74,11 @@ public void checkPermission(Permission perm) { } }; System.setSecurityManager(sm); - - Assert.assertFalse(CudaUtils.hasCuda()); - Assert.assertEquals(CudaUtils.getGpuCount(), 0); + try { + Assert.assertFalse(CudaUtils.hasCuda()); + Assert.assertEquals(CudaUtils.getGpuCount(), 0); + } finally { + System.setSecurityManager(null); + } } } diff --git a/api/src/test/java/ai/djl/util/ZipUtilsTest.java b/api/src/test/java/ai/djl/util/ZipUtilsTest.java index 4340019de55..387715bbd44 100644 --- a/api/src/test/java/ai/djl/util/ZipUtilsTest.java +++ b/api/src/test/java/ai/djl/util/ZipUtilsTest.java @@ -45,6 +45,19 @@ public void testEmptyZipFile() throws IOException { } } + @Test + 
public void testOffendingTar() throws IOException { + Path path = Paths.get("src/test/resources/offending.tar"); + Path output = Paths.get("build/output"); + Path file = output.resolve("tmp/empty.txt"); + Utils.deleteQuietly(file); + Files.createDirectories(output); + try (InputStream is = Files.newInputStream(path)) { + TarUtils.untar(is, output, false); + } + Assert.assertTrue(Files.exists(file)); + } + @Test public void testInvalidZipFile() throws IOException { ByteArrayOutputStream bos = new ByteArrayOutputStream(); diff --git a/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java b/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java index de1c5cb4a20..a6ad7e52122 100644 --- a/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java +++ b/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java @@ -20,8 +20,6 @@ import org.testng.annotations.Test; import java.lang.management.MemoryUsage; -import java.util.Arrays; -import java.util.List; public class CudaUtilsTest { @@ -30,6 +28,9 @@ public class CudaUtilsTest { @Test public void testCudaUtils() { if (!CudaUtils.hasCuda()) { + Assert.assertThrows(CudaUtils::getCudaVersionString); + Assert.assertThrows(() -> CudaUtils.getComputeCapability(0)); + Assert.assertThrows(() -> CudaUtils.getGpuMemory(Device.gpu())); return; } // Possible to have CUDA and not have a GPU. @@ -37,16 +38,24 @@ public void testCudaUtils() { return; } - int cudaVersion = CudaUtils.getCudaVersion(); + String cudaVersion = CudaUtils.getCudaVersionString(); String smVersion = CudaUtils.getComputeCapability(0); MemoryUsage memoryUsage = CudaUtils.getGpuMemory(Device.gpu()); logger.info("CUDA runtime version: {}, sm: {}", cudaVersion, smVersion); logger.info("Memory usage: {}", memoryUsage); - Assert.assertTrue(cudaVersion >= 9020, "cuda 9.2+ required."); + Assert.assertNotNull(cudaVersion); + Assert.assertNotNull(smVersion); + } - List supportedSm = Arrays.asList("37", "52", "60", "61", "70", "75"); - Assert.assertTrue(supportedSm.contains(smVersion), "Unsupported cuda sm: " + smVersion); + @Test + public void testCudaUtilsWithFork() { + System.setProperty("ai.djl.util.cuda.fork", "true"); + try { + testCudaUtils(); + } finally { + System.clearProperty("ai.djl.util.cuda.fork"); + } } } diff --git a/api/src/test/resources/offending.tar b/api/src/test/resources/offending.tar new file mode 100644 index 00000000000..3a767ae55ac Binary files /dev/null and b/api/src/test/resources/offending.tar differ diff --git a/apt.txt b/apt.txt index 7083f85c374..c89953ff1f9 100644 --- a/apt.txt +++ b/apt.txt @@ -1 +1 @@ -openjdk-11-jdk +openjdk-17-jdk diff --git a/basicdataset/README.md b/basicdataset/README.md index 37bab679551..1c9ac977198 100644 --- a/basicdataset/README.md +++ b/basicdataset/README.md @@ -29,7 +29,7 @@ You can pull the module from the central Maven repository by including the follo ai.djl basicdataset - 0.23.0 + 0.27.0 ``` diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java index a92a9b6a3d4..deef04907be 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java @@ -30,6 +30,7 @@ import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; import java.util.Map; /** @@ -118,8 +119,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException { byte[] buf = Utils.toByteArray(is); 
try (NDArray array = manager.create( - new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), DataType.UINT8)) { - array.set(buf); + ByteBuffer.wrap(buf), + new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), + DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } @@ -132,8 +134,8 @@ private NDArray readLabel(Artifact.Item item) throws IOException { } byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) { - array.set(buf); + try (NDArray array = + manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java index 164ba9876cb..5503e721caa 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java @@ -30,6 +30,7 @@ import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; import java.util.Map; /** @@ -111,8 +112,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException { } byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(new Shape(length, 28, 28, 1), DataType.UINT8)) { - array.set(buf); + try (NDArray array = + manager.create( + ByteBuffer.wrap(buf), new Shape(length, 28, 28, 1), DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } @@ -123,10 +125,9 @@ private NDArray readLabel(Artifact.Item item) throws IOException { if (is.skip(8) != 8) { throw new AssertionError("Failed skip data."); } - byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) { - array.set(buf); + try (NDArray array = + manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java index 42fc1744451..b04ae800a10 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java @@ -44,6 +44,7 @@ public ListFeatures(int initialCapacity) { * * @param source the source list */ + @SuppressWarnings("this-escape") public ListFeatures(List source) { super(source.size()); addAll(source); diff --git a/basicdataset/src/main/resources/imagenet/extract_imagenet.py b/basicdataset/src/main/resources/imagenet/extract_imagenet.py index c618fe05e44..2f161b5757a 100644 --- a/basicdataset/src/main/resources/imagenet/extract_imagenet.py +++ b/basicdataset/src/main/resources/imagenet/extract_imagenet.py @@ -14,6 +14,7 @@ _VAL_TAR = 'ILSVRC2012_img_val.tar' _VAL_TAR_SHA1 = '5f3f73da3395154b60528b2b2a2caf2374f5f178' + def download(url, path=None, overwrite=False, sha1_hash=None): """Download an given URL Parameters @@ -42,26 +43,29 @@ def download(url, path=None, overwrite=False, sha1_hash=None): else: fname = path - if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): + if overwrite or not os.path.exists(fname) or ( + sha1_hash and not check_sha1(fname, sha1_hash)): dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) if not os.path.exists(dirname): os.makedirs(dirname) - print('Downloading %s from %s...'%(fname, 
url)) + print('Downloading %s from %s...' % (fname, url)) r = requests.get(url, stream=True) if r.status_code != 200: - raise RuntimeError("Failed downloading url %s"%url) + raise RuntimeError("Failed downloading url %s" % url) total_length = r.headers.get('content-length') with open(fname, 'wb') as f: - if total_length is None: # no content length header + if total_length is None: # no content length header for chunk in r.iter_content(chunk_size=1024): - if chunk: # filter out keep-alive new chunks + if chunk: # filter out keep-alive new chunks f.write(chunk) else: total_length = int(total_length) for chunk in tqdm(r.iter_content(chunk_size=1024), total=int(total_length / 1024. + 0.5), - unit='KB', unit_scale=False, dynamic_ncols=True): + unit='KB', + unit_scale=False, + dynamic_ncols=True): f.write(chunk) if sha1_hash and not check_sha1(fname, sha1_hash): @@ -72,25 +76,34 @@ def download(url, path=None, overwrite=False, sha1_hash=None): return fname + def parse_args(): parser = argparse.ArgumentParser( description='Setup the ImageNet dataset.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--download-dir', required=True, - help="The directory that contains downloaded tar files") + parser.add_argument( + '--download-dir', + required=True, + help="The directory that contains downloaded tar files") parser.add_argument('--target-dir', help="The directory to store extracted images") - parser.add_argument('--checksum', action='store_true', + parser.add_argument('--checksum', + action='store_true', help="If check integrity before extracting.") - parser.add_argument('--with-rec', action='store_true', + parser.add_argument('--with-rec', + action='store_true', help="If build image record files.") - parser.add_argument('--num-thread', type=int, default=1, - help="Number of threads to use when building image record file.") + parser.add_argument( + '--num-thread', + type=int, + default=1, + help="Number of threads to use when building image record file.") args = parser.parse_args() if args.target_dir is None: args.target_dir = args.download_dir return args + def check_sha1(filename, sha1_hash): """Check whether the sha1 hash of the file content matches the expected hash. 
@@ -116,11 +129,13 @@ def check_sha1(filename, sha1_hash): return sha1.hexdigest() == sha1_hash + def check_file(filename, checksum, sha1): if not os.path.exists(filename): - raise ValueError('File not found: '+filename) + raise ValueError('File not found: ' + filename) if checksum and not check_sha1(filename, sha1): - raise ValueError('Corrupted file: '+filename) + raise ValueError('Corrupted file: ' + filename) + def build_rec_process(img_dir, train=False, num_thread=1): rec_dir = os.path.abspath(os.path.join(img_dir, '../rec')) @@ -141,14 +156,8 @@ def build_rec_process(img_dir, train=False, num_thread=1): # execution import sys cmd = [ - sys.executable, - script_path, - rec_dir, - img_dir, - '--recursive', - '--pass-through', - '--pack-label', - '--num-thread', + sys.executable, script_path, rec_dir, img_dir, '--recursive', + '--pass-through', '--pack-label', '--num-thread', str(num_thread) ] subprocess.call(cmd) @@ -156,87 +165,75 @@ def build_rec_process(img_dir, train=False, num_thread=1): os.remove(lst_path) print('ImageRecord file for ' + prefix + ' has been built!') + +def is_within_directory(directory, target): + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + prefix = os.path.commonprefix([abs_directory, abs_target]) + return prefix == abs_directory + + +def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + tar.extractall(path, members, numeric_owner=numeric_owner) + + def extract_train(tar_fname, target_dir, with_rec=False, num_thread=1): os.makedirs(target_dir) with tarfile.open(tar_fname) as tar: - print("Extracting "+tar_fname+"...") + print("Extracting " + tar_fname + "...") # extract each class one-by-one pbar = tqdm(total=len(tar.getnames())) for class_tar in tar: - pbar.set_description('Extract '+class_tar.name) - tar.extract(class_tar, target_dir) + pbar.set_description('Extract ' + class_tar.name) class_fname = os.path.join(target_dir, class_tar.name) + if not is_within_directory(target_dir, class_fname): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extract(class_tar, target_dir) class_dir = os.path.splitext(class_fname)[0] os.mkdir(class_dir) with tarfile.open(class_fname) as f: - def is_within_directory(directory, target): - - abs_directory = os.path.abspath(directory) - abs_target = os.path.abspath(target) - - prefix = os.path.commonprefix([abs_directory, abs_target]) - - return prefix == abs_directory - - def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - - for member in tar.getmembers(): - member_path = os.path.join(path, member.name) - if not is_within_directory(path, member_path): - raise Exception("Attempted Path Traversal in Tar File") - - tar.extractall(path, members, numeric_owner=numeric_owner) - - safe_extract(f, class_dir) + os.remove(class_fname) pbar.update(1) pbar.close() if with_rec: build_rec_process(target_dir, True, num_thread) + def extract_val(tar_fname, target_dir, with_rec=False, num_thread=1): os.makedirs(target_dir) print('Extracting ' + tar_fname) with tarfile.open(tar_fname) as tar: - def is_within_directory(directory, target): - - abs_directory = os.path.abspath(directory) - abs_target = os.path.abspath(target) - - prefix = os.path.commonprefix([abs_directory, abs_target]) - - return prefix == abs_directory - - def safe_extract(tar, path=".", members=None, 
*, numeric_owner=False): - - for member in tar.getmembers(): - member_path = os.path.join(path, member.name) - if not is_within_directory(path, member_path): - raise Exception("Attempted Path Traversal in Tar File") - - tar.extractall(path, members, numeric_owner=numeric_owner) - - safe_extract(tar, target_dir) + # build rec file before images are moved into subfolders if with_rec: build_rec_process(target_dir, False, num_thread) # move images to proper subfolders - val_maps_file = os.path.join(os.path.dirname(__file__), 'imagenet_val_maps.pklz') + val_maps_file = os.path.join(os.path.dirname(__file__), + 'imagenet_val_maps.pklz') with gzip.open(val_maps_file, 'rb') as f: dirs, mappings = pickle.load(f) for d in dirs: os.makedirs(os.path.join(target_dir, d)) for m in mappings: - os.rename(os.path.join(target_dir, m[0]), os.path.join(target_dir, m[1], m[0])) + os.rename(os.path.join(target_dir, m[0]), + os.path.join(target_dir, m[1], m[0])) + def main(): args = parse_args() target_dir = os.path.expanduser(args.target_dir) if os.path.exists(target_dir): - raise ValueError('Target dir ['+target_dir+'] exists. Remove it first') + raise ValueError('Target dir [' + target_dir + + '] exists. Remove it first') download_dir = os.path.expanduser(args.download_dir) train_tar_fname = os.path.join(download_dir, _TRAIN_TAR) @@ -247,8 +244,11 @@ def main(): build_rec = args.with_rec if build_rec: os.makedirs(os.path.join(target_dir, 'rec')) - extract_train(train_tar_fname, os.path.join(target_dir, 'train'), build_rec, args.num_thread) - extract_val(val_tar_fname, os.path.join(target_dir, 'val'), build_rec, args.num_thread) + extract_train(train_tar_fname, os.path.join(target_dir, 'train'), + build_rec, args.num_thread) + extract_val(val_tar_fname, os.path.join(target_dir, 'val'), build_rec, + args.num_thread) + if __name__ == '__main__': main() diff --git a/basicdataset/src/test/resources/mlrepo/dataset/cv/ai/djl/basicdataset/mnist/metadata.json b/basicdataset/src/test/resources/mlrepo/dataset/cv/ai/djl/basicdataset/mnist/metadata.json index 5e5c1b81a95..0b5f61d1d32 100644 --- a/basicdataset/src/test/resources/mlrepo/dataset/cv/ai/djl/basicdataset/mnist/metadata.json +++ b/basicdataset/src/test/resources/mlrepo/dataset/cv/ai/djl/basicdataset/mnist/metadata.json @@ -19,23 +19,23 @@ "snapshot": false, "files": { "train_data": { - "uri": "https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/train-images-idx3-ubyte.gz", - "sha1Hash": "6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d", - "size": 9912422 + "uri": "https://mlrepo.djl.ai/dataset/cv/ai/djl/basicdataset/mnist/1.0/train-images-idx3-ubyte.gz", + "sha1Hash": "0e0d45c28981154deda73aabc437dc09aa5a4fd1", + "size": 9822052 }, "train_labels": { - "uri": "https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/train-labels-idx1-ubyte.gz", - "sha1Hash": "2a80914081dc54586dbdf242f9805a6b8d2a15fc", - "size": 28881 + "uri": "https://mlrepo.djl.ai/dataset/cv/ai/djl/basicdataset/mnist/1.0/train-labels-idx1-ubyte.gz", + "sha1Hash": "af3fbf34a4396c1ee1a6128dfde57812d8abe06e", + "size": 28902 }, "test_data": { - "uri": "https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/t10k-images-idx3-ubyte.gz", - "sha1Hash": "c3a25af1f52dad7f726cce8cacb138654b760d48", - "size": 1648877 + "uri": "https://mlrepo.djl.ai/dataset/cv/ai/djl/basicdataset/mnist/1.0/t10k-images-idx3-ubyte.gz", + "sha1Hash": "5a939b565aa3e5063d816efc7f3dfb721135648d", + "size": 1634335 }, "test_labels": { - "uri": 
"https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/t10k-labels-idx1-ubyte.gz", - "sha1Hash": "763e7fa3757d93b0cdec073cef058b2004252c17", + "uri": "https://mlrepo.djl.ai/dataset/cv/ai/djl/basicdataset/mnist/1.0/t10k-labels-idx1-ubyte.gz", + "sha1Hash": "0e4e66587e3a14f5775793e2ae10d1c48be8ae46", "size": 4542 } } diff --git a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/wikitext-2/metadata.json b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/wikitext-2/metadata.json index f9c64dc8028..1f31ac3afcd 100644 --- a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/wikitext-2/metadata.json +++ b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/wikitext-2/metadata.json @@ -20,10 +20,10 @@ "name": "wikitext-2", "files": { "wikitext-2": { - "uri": "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip", - "sha1Hash": "3c914d17d80b1459be871a5039ac23e752a53cbe", + "uri": "https://mlrepo.djl.ai/dataset/nlp/ai/djl/basicdataset/wikitext-2/1.0/wikitext-2-v1.zip", + "sha1Hash": "46965bdeca1d8165e688598752ca467bb5bee018", "name": "", - "size": 4475746 + "size": 4475596 } } } diff --git a/bom/README.md b/bom/README.md index 44519846712..ecb4f092234 100644 --- a/bom/README.md +++ b/bom/README.md @@ -22,7 +22,7 @@ will need to mention the type as pom and the scope as import) as the following: ai.djl bom - 0.23.0 + 0.27.0 pom import @@ -38,7 +38,7 @@ will need to mention the type as pom and the scope as import) as the following: ai.djl bom - 0.23.0 + 0.27.0 pom import @@ -65,7 +65,7 @@ will need to mention the type as pom and the scope as import) as the following: - First you need add BOM into your build.gradle file as the following: ``` - implementation platform("ai.djl:bom:0.23.0") + implementation platform("ai.djl:bom:0.27.0") ``` - Then you import the desired DJL modules into to you pom.xml file (no version is needed): diff --git a/bom/build.gradle b/bom/build.gradle index 4708978b5b5..0c509740f92 100644 --- a/bom/build.gradle +++ b/bom/build.gradle @@ -19,7 +19,6 @@ dependencies { api "ai.djl:basicdataset:${version}" api "ai.djl:model-zoo:${version}" api "ai.djl:djl-zero:${version}" - api "ai.djl:serving:${version}" api "ai.djl.android:core:${version}" api "ai.djl.android:onnxruntime:${version}" api "ai.djl.android:pytorch-native:${version}" @@ -28,6 +27,7 @@ dependencies { api "ai.djl.fasttext:fasttext-engine:${version}" api "ai.djl.hadoop:hadoop:${version}" api "ai.djl.huggingface:tokenizers:${version}" + api "ai.djl.llama:llama:${version}" api "ai.djl.ml.lightgbm:lightgbm:${version}" api "ai.djl.ml.xgboost:xgboost-gpu:${version}" api "ai.djl.ml.xgboost:xgboost:${version}" @@ -43,6 +43,9 @@ dependencies { api "ai.djl.pytorch:pytorch-model-zoo:${version}" api "ai.djl.sentencepiece:sentencepiece:${version}" api "ai.djl.spark:spark_2.12:${version}" + api "ai.djl.serving:prometheus:${version}" + api "ai.djl.serving:serving:${version}" + api "ai.djl.serving:wlm:${version}" api "ai.djl.tablesaw:tablesaw:${version}" api "ai.djl.tensorflow:tensorflow-api:${version}" api "ai.djl.tensorflow:tensorflow-engine:${version}" @@ -115,15 +118,12 @@ publishing { addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu", "win-x86_64", "${pytorch_version}") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu-precxx11", "linux-x86_64", "${pytorch_version}") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu-precxx11", "linux-aarch64", 
"${pytorch_version}") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu116", "linux-x86_64", "1.12.1") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu116", "win-x86_64", "1.12.1") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu116-precxx11", "linux-x86_64", "1.12.1") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121", "linux-x86_64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121", "win-x86_64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121-precxx11", "linux-x86_64", "${pytorch_version}") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117", "linux-x86_64", "1.13.1") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117", "win-x86_64", "1.13.1") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117-precxx11", "linux-x86_64", "1.13.1") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu118", "linux-x86_64", "${pytorch_version}") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu118", "win-x86_64", "${pytorch_version}") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu118-precxx11", "linux-x86_64", "${pytorch_version}") addDependency(dependencies, "ai.djl.tensorflow", "tensorflow-native-cpu", "osx-x86_64", "${tensorflow_version}") addDependency(dependencies, "ai.djl.tensorflow", "tensorflow-native-cpu", "linux-x86_64", "${tensorflow_version}") addDependency(dependencies, "ai.djl.tensorflow", "tensorflow-native-cpu", "win-x86_64", "${tensorflow_version}") diff --git a/build.gradle b/build.gradle index f98b86c4e51..ca6f7e68133 100644 --- a/build.gradle +++ b/build.gradle @@ -44,6 +44,7 @@ configure(javaProjects()) { targetCompatibility = JavaVersion.VERSION_11 options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static,-removal" << "-Werror" } + javadoc.options.addStringOption("Xdoclint:none", "-quiet") apply plugin: 'eclipse' @@ -88,7 +89,7 @@ configure(javaProjects()) { systemProperty "disableProgressBar", "true" systemProperty "nightly", System.getProperty("nightly", "false") if (gradle.startParameter.offline) { - systemProperty "offline", "true" + systemProperty "ai.djl.offline", "true" } // This is used to avoid overriding on default engine for modules: // mxnet-engine, mxnet-model-zoo, api (MockEngine), basicdataset, fasttext, etc diff --git a/djl-zero/README.md b/djl-zero/README.md index 2d2c473cc88..91c84554c58 100644 --- a/djl-zero/README.md +++ b/djl-zero/README.md @@ -49,6 +49,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl djl-zero - 0.23.0 + 0.27.0 ``` diff --git a/docker/README.md b/docker/README.md index 5b5bd01be2b..0df33be9f83 100644 --- a/docker/README.md +++ b/docker/README.md @@ -1,10 +1,12 @@ # Docker Resources + DJL provides docker files that you can use to setup containers with the appropriate environment for certain platforms. We recommend setting up a docker container with the provided Dockerfile when developing for the following platforms and/or engines. ## Windows + You can use the [docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/windows/Dockerfile) provided by us. Please note that this docker will only work with Windows server 2019 by default. 
If you want it to work with other versions of Windows, you need to pass the version as an argument as follows: @@ -14,19 +16,20 @@ docker build --build-arg version= ``` ## TensorRT + You can use the [docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/tensorrt/Dockerfile) provided by us. This docker file is a modification of the one provided by NVIDIA in -[TensorRT](https://github.com/NVIDIA/TensorRT/blob/8.4.1/docker/ubuntu-18.04.Dockerfile) to include JDK11. -By default this sets up a container using Ubuntu 18.04 and CUDA 11.6.2. You can build the container with other versions as follows, +[TensorRT](https://github.com/NVIDIA/TensorRT/blob/8.4.1/docker/ubuntu-18.04.Dockerfile) to include JDK17. +By default this sets up a container using Ubuntu 20.04 and CUDA 12.2.2. You can build the container with other versions as follows, but keep in mind the TensorRT software requirements outlined [here](https://github.com/NVIDIA/TensorRT#prerequisites): ```bash docker build --build-arg OS_VERSION= --build-arg CUDA_VERSION= ``` -To run the container, we recommend using `nvidia-docker run ...` to ensure cuda driver and runtime are compatible. +To run the container, we recommend using `nvidia-docker run ...` to ensure the CUDA driver and runtime are compatible. -We recommend that you follow the setup steps in the [TensorRT guide](https://github.com/NVIDIA/TensorRT) if you -need access to the full suite of tools TensorRT provides, such as `trtexec` which can convert onnx models to -uff tensorrt models. When following that guide, make sure to use the DJL provided -[docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/tensorrt/Dockerfile) to enable JDK11 in the docker container. +We recommend that you follow the setup steps in the [TensorRT guide](https://github.com/NVIDIA/TensorRT) if you +need access to the full suite of tools TensorRT provides, such as `trtexec` which can convert ONNX models to +UFF TensorRT models. When following that guide, make sure to use the DJL provided +[docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/tensorrt/Dockerfile) to enable JDK17 in the docker container. diff --git a/docker/spark/Dockerfile b/docker/spark/Dockerfile index b715899e2f1..b777d5a69ed 100644 --- a/docker/spark/Dockerfile +++ b/docker/spark/Dockerfile @@ -13,7 +13,7 @@ FROM 314815235551.dkr.ecr.us-east-2.amazonaws.com/sagemaker-spark-processing:3.3 LABEL maintainer="djl-dev@amazon.com" # Install dependencies -ARG DJL_VERSION=0.23.0 +ARG DJL_VERSION=0.24.0 ARG JNA_VERSION=5.13.0 ARG JAVACV_VERSION=1.5.9 ARG JAVACPP_VERSION=1.5.9 diff --git a/docker/tensorrt/Dockerfile b/docker/tensorrt/Dockerfile index 3a99bb9cb5d..a92dad12f4d 100644 --- a/docker/tensorrt/Dockerfile +++ b/docker/tensorrt/Dockerfile @@ -14,15 +14,43 @@ # See the License for the specific language governing permissions and # limitations under the License.
# -ARG CUDA_VERSION=11.6.2 -ARG OS_VERSION=18.04 -FROM nvidia/cuda:${CUDA_VERSION}-cudnn8-devel-ubuntu${OS_VERSION} +ARG CUDA_VERSION=12.2.2 -ENV TRT_VERSION 8.4.1.5 +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 + +ENV NV_CUDNN_VERSION 8.9.6.50 +ENV NV_CUDNN_PACKAGE_NAME "libcudnn8" + +ARG CUDA_VERSION_MAJOR_MINOR=12.2 + +ENV NV_CUDNN_PACKAGE "libcudnn8=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}" +ENV NV_CUDNN_PACKAGE_DEV "libcudnn8-dev=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}" + +ENV TRT_VERSION 9.2.0.5 SHELL ["/bin/bash", "-c"] RUN mkdir -p /workspace -# Install Required Libraries +RUN apt-get update && apt-get install -y --no-install-recommends \ + ${NV_CUDNN_PACKAGE} \ + ${NV_CUDNN_PACKAGE_DEV} \ + && apt-mark hold ${NV_CUDNN_PACKAGE_NAME} \ + && rm -rf /var/lib/apt/lists/* + +# Setup user account +ARG uid=1000 +ARG gid=1000 +RUN groupadd -r -f -g ${gid} djl && useradd -o -r -l -u ${uid} -g ${gid} -ms /bin/bash djl +RUN usermod -aG sudo djl +RUN echo 'djl:djl' | chpasswd +RUN mkdir -p /workspace && chown djl /workspace + +# Required to build Ubuntu 20.04 without user prompts with DLFW container +ENV DEBIAN_FRONTEND=noninteractive + +# Update CUDA signing key +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub + +# Install required libraries RUN apt-get update && apt-get install -y software-properties-common RUN add-apt-repository ppa:ubuntu-toolchain-r/test RUN apt-get update && apt-get install -y --no-install-recommends \ @@ -42,7 +70,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ fakeroot \ dh-make \ build-essential \ - openjdk-11-jdk && \ + openjdk-17-jdk &&\ apt-get clean -y && rm -rf /var/lib/apt/lists/* # Install python3 @@ -53,17 +81,24 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ python3-wheel &&\ cd /usr/local/bin &&\ ln -s /usr/bin/python3 python &&\ - ln -s /usr/bin/pip3 pip && \ + ln -s /usr/bin/pip3 pip &&\ apt-get clean -y && rm -rf /var/lib/apt/lists/* # Install TensorRT -RUN v="${TRT_VERSION%.*}-1+cuda${CUDA_VERSION%.*}" &&\ - apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub &&\ - apt-get update &&\ - sudo apt-get install libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} \ - libnvinfer-dev=${v} libnvonnxparsers-dev=${v} libnvparsers-dev=${v} libnvinfer-plugin-dev=${v} \ - python3-libnvinfer=${v}; \ - apt-get clean -y && rm -rf /var/lib/apt/lists/* +RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/9.2.0/tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-11.8.tar.gz \ + && tar -xf tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-11.8.tar.gz \ + && cp -a TensorRT-9.2.0.5/lib/*.so* /usr/lib/x86_64-linux-gnu \ + && pip install TensorRT-9.2.0.5/python/tensorrt-9.2.0.post11.dev5-cp38-none-linux_x86_64.whl ;\ +elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/9.2.0/tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-12.2.tar.gz \ + && tar -xf tensorrt-9.2.0.5.linux.x86_64-gnu.cuda-12.2.tar.gz \ + && cp -a TensorRT-9.2.0.5/lib/*.so* /usr/lib/x86_64-linux-gnu \ + && pip install TensorRT-9.2.0.5/python/tensorrt-9.2.0.post12.dev5-cp38-none-linux_x86_64.whl ;\ +else \ + echo "Invalid CUDA_VERSION"; \ + exit 1; \ +fi # Install Cmake RUN cd /tmp && \ @@ -72,19 +107,16 @@ RUN cd /tmp && \ ./cmake-3.14.4-Linux-x86_64.sh
--prefix=/usr/local --exclude-subdir --skip-license && \ rm ./cmake-3.14.4-Linux-x86_64.sh -RUN cd /usr/local/bin && \ - wget https://ngc.nvidia.com/downloads/ngccli_cat_linux.zip && \ - unzip ngccli_cat_linux.zip && \ - chmod u+x ngc-cli/ngc && \ - rm ngccli_cat_linux.zip ngc-cli.md5 && \ - echo "no-apikey\nascii\n" | ngc-cli/ngc config set - +# Download NGC client +RUN cd /usr/local/bin && wget --content-disposition https://api.ngc.nvidia.com/v2/resources/nvidia/ngc-apps/ngc_cli/versions/3.38.0/files/ngccli_linux.zip -O ngccli_linux.zip && unzip ngccli_linux.zip && chmod u+x ngc-cli/ngc && rm ngccli_linux.zip ngc-cli.md5 && echo "no-apikey\nascii\n" | ngc-cli/ngc config set # Set environment and working directory ENV TRT_LIBPATH /usr/lib/x86_64-linux-gnu ENV TRT_OSSPATH /workspace/TensorRT ENV PATH="${PATH}:/usr/local/bin/ngc-cli" ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${TRT_OSSPATH}/build/out:${TRT_LIBPATH}" +ENV JAVA_HOME=/usr/lib/jvm/java-17-openjdk-amd64 WORKDIR /workspace +USER djl RUN ["/bin/bash"] diff --git a/docker/windows/Dockerfile b/docker/windows/Dockerfile index 31567b3168b..10989e8a4c8 100644 --- a/docker/windows/Dockerfile +++ b/docker/windows/Dockerfile @@ -11,4 +11,4 @@ RUN powershell -Command \ Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://chocolatey.org/install.ps1')); \ choco feature disable --name showDownloadProgress -RUN choco install -y openjdk11 +RUN choco install -y openjdk17 diff --git a/docs/README.md b/docs/README.md index cdd02661c78..7749d39eb5f 100644 --- a/docs/README.md +++ b/docs/README.md @@ -2,7 +2,7 @@ This folder contains examples and documentation for the Deep Java Library (DJL) project. -## [JavaDoc API Reference](https://javadoc.djl.ai/) +## [JavaDoc API Reference](https://djl.ai/website/javadoc.html) Note: when searching in JavaDoc, if your access is denied, please try removing the string `undefined` in the url. 
@@ -20,14 +20,14 @@ Note: when searching in JavaDoc, if your access is denied, please try removing t - [Troubleshooting](development/troubleshooting.md) - [Inference Optimization](development/inference_performance_optimization.md) -## [Jupyter notebook tutorials](../jupyter/README.md) +## [Jupyter notebook tutorials](http://docs.djl.ai/docs/demos/jupyter/index.html) -- **[Beginner Jupyter Tutorial](../jupyter/tutorial/README.md)** -- [Run object detection with model zoo](../jupyter/object_detection_with_model_zoo.ipynb) -- [Load pre-trained PyTorch model](../jupyter/load_pytorch_model.ipynb) -- [Load pre-trained Apache MXNet model](../jupyter/load_mxnet_model.ipynb) -- [Transfer learning example](../jupyter/transfer_learning_on_cifar10.ipynb) -- [Question answering example](../jupyter/BERTQA.ipynb) +- **[Beginner Jupyter Tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/index.html)** +- [Run object detection with model zoo](http://docs.djl.ai/docs/demos/jupyter/object_detection_with_model_zoo.html) +- [Load pre-trained PyTorch model](http://docs.djl.ai/docs/demos/jupyter/load_pytorch_model.html) +- [Load pre-trained Apache MXNet model](http://docs.djl.ai/docs/demos/jupyter/load_mxnet_model.html) +- [Transfer learning example](http://docs.djl.ai/docs/demos/jupyter/transfer_learning_on_cifar10.html) +- [Question answering example](http://docs.djl.ai/docs/demos/jupyter/BERTQA.html) ## [API Examples](../examples/README.md) diff --git a/docs/development/cache_management.md b/docs/development/cache_management.md index b0b56460e54..2bdacb9a699 100644 --- a/docs/development/cache_management.md +++ b/docs/development/cache_management.md @@ -30,10 +30,10 @@ ONNXRuntime will extract native libraries into system default temporary-file dir ### Huggingface tokenizer -If the `TOKENIZERS_CACHE` environment variable is set, Huggingface tokenizer will store cache files in it. +If the `HF_HOME` or `HF_HUB_CACHE` environment variable is set, Huggingface tokenizer will store cache files in it. It is the responsibility of the user to make sure this path is correct. Otherwise, we try to use the default cache directory as defined for each OS: -- macOS: `/Users/{user}/Library/Caches/huggingface/tokenizers` -- linux: `/home/{user}/.cache/huggingface/tokenizers` -- windows: `C:\Users\{user}\AppData\Local\huggingface\tokenizers` +- macOS: `/Users/{user}/.cache/huggingface/hub` +- linux: `/home/{user}/.cache/huggingface/hub` +- windows: `C:\Users\{user}\.cache\huggingface\hub` diff --git a/docs/development/example_dataset.md b/docs/development/example_dataset.md index 35e071f728b..2f9fb456e02 100644 --- a/docs/development/example_dataset.md +++ b/docs/development/example_dataset.md @@ -1,4 +1,4 @@ -## Example CSV Dataset +# Custom CSV Dataset Example If the provided Datasets don't meet your requirements, you can also easily extend our dataset to create your own customized dataset. @@ -24,8 +24,8 @@ api group: 'org.apache.commons', name: 'commons-csv', version: '1.7' In order to extend the dataset, the following dependencies are required: ``` -api "ai.djl:api:0.23.0" -api "ai.djl:basicdataset:0.23.0" +api "ai.djl:api:0.27.0" +api "ai.djl:basicdataset:0.27.0" ``` There are four parts we need to implement for CSVDataset. 
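A minimal sketch of the four parts such a customized dataset needs (shown here as prepare, get, availableSize, and a builder), assuming a two-column CSV with hypothetical `feature` and `label` columns; the class name and file path are illustrative only, not the repository's actual example:

```java
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.util.Progress;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

import java.io.IOException;
import java.io.Reader;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;

public final class CsvExampleDataset extends RandomAccessDataset {

    private List<CSVRecord> rows;

    private CsvExampleDataset(Builder builder) {
        super(builder);
    }

    /** Part 1: read the CSV file once, before the dataset is used. */
    @Override
    public void prepare(Progress progress) throws IOException {
        if (rows != null) {
            return;
        }
        try (Reader reader = Files.newBufferedReader(Paths.get("data/example.csv"));
                CSVParser parser =
                        CSVParser.parse(reader, CSVFormat.DEFAULT.withFirstRecordAsHeader())) {
            rows = parser.getRecords();
        }
    }

    /** Part 2: turn one CSV row into a Record of data and label NDArrays. */
    @Override
    public Record get(NDManager manager, long index) {
        CSVRecord row = rows.get(Math.toIntExact(index));
        NDList data = new NDList(manager.create(Float.parseFloat(row.get("feature"))));
        NDList label = new NDList(manager.create(Float.parseFloat(row.get("label"))));
        return new Record(data, label);
    }

    /** Part 3: report how many rows are available. */
    @Override
    protected long availableSize() {
        return rows.size();
    }

    /** Part 4: a builder that carries the sampling configuration. */
    public static final class Builder extends RandomAccessDataset.BaseBuilder<Builder> {

        @Override
        protected Builder self() {
            return this;
        }

        public CsvExampleDataset build() {
            return new CsvExampleDataset(this);
        }
    }
}
```

It would then typically be constructed with something like `new CsvExampleDataset.Builder().setSampling(32, true).build()` and prepared before iterating over batches.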
diff --git a/docs/development/external_libraries.md b/docs/development/external_libraries.md index 7f57fec3165..701fb9d0a03 100644 --- a/docs/development/external_libraries.md +++ b/docs/development/external_libraries.md @@ -1,5 +1,4 @@ - -## DJL external dependencies +# DJL external dependencies This document contains external libraries that DJL depends on and their versions. diff --git a/docs/development/inference_performance_optimization.md b/docs/development/inference_performance_optimization.md index 27bccfd3f3e..0fdc67d999c 100644 --- a/docs/development/inference_performance_optimization.md +++ b/docs/development/inference_performance_optimization.md @@ -85,6 +85,23 @@ You can enable it by setting the environment variable: You might see an exception if a data type or operator is not supported with the oneDNN device. +#### oneDNN(MKLDNN) tuning on AWS Graviton3 +AWS Graviton3(E) (e.g. c7g/m7g/r7g, c7gn and Hpc7g instances) supports BF16 format for ML acceleration. This can be enabled in oneDNN by setting the below environment variable +``` +grep -q bf16 /proc/cpuinfo && export DNNL_DEFAULT_FPMATH_MODE=BF16 +``` +To avoid redundant primitive creation latency overhead, enable primitive caching by setting the LRU cache capacity. Please note this caching feature increases the memory footprint. It is recommended to tune the capacity to an optimal value for a given use case. + +``` +export LRU_CACHE_CAPACITY=1024 +``` + +In addition to avoiding the redundant allocations, tensor memory allocation latencies can be optimized with Linux transparent huge pages (THP). To enable THP allocations, set the following torch environment variable. +``` +export THP_MEM_ALLOC_ENABLE=1 +``` +Please refer to [PyTorch Graviton tutorial](https://pytorch.org/tutorials/recipes/inference_tuning_on_aws_graviton.html) for more details on how to achieve the best PyTorch inference performance on AWS Graviton3 instances. + #### CuDNN acceleration PyTorch has a special flag that is used for a CNN or related network speed up. If your input size won't change frequently, you may benefit from enabling this configuration in your model: diff --git a/docs/development/profiler.md b/docs/development/profiler.md index 6db5739483c..4a2a9f626e4 100644 --- a/docs/development/profiler.md +++ b/docs/development/profiler.md @@ -1,4 +1,4 @@ -## Profiler (Experimental) +# Engine Profiler Support Currently, DJL supports experimental profilers for developers that investigate the performance of operator execution as well as memory consumption. diff --git a/docs/development/setup.md b/docs/development/setup.md index e4eb73b2501..fb290eb0e3a 100644 --- a/docs/development/setup.md +++ b/docs/development/setup.md @@ -10,13 +10,13 @@ you can use the $JAVA_HOME environment variable to control which version of Java For ubuntu: ```bash -sudo apt-get install openjdk-11-jdk +sudo apt-get install openjdk-17-jdk ``` For centos ```bash -sudo yum install java-11-openjdk +sudo yum install java-17-openjdk ``` For Mac: @@ -24,7 +24,7 @@ For Mac: ```bash brew tap homebrew/cask-versions brew update -brew install --cask temurin11 +brew install --cask zulu17 ``` You can also download and install [Oracle JDK](https://www.oracle.com/technetwork/java/javase/overview/index.html) diff --git a/docs/get.md b/docs/get.md index 8c4c34502ad..2c6e8b99968 100644 --- a/docs/get.md +++ b/docs/get.md @@ -99,7 +99,7 @@ dependencies { implementation platform("ai.djl:bom:-SNAPSHOT") } ``` -Currently, the ` = 0.21.0`. +Currently, the ` = 0.28.0`. 
This snapshot version is the same as the custom DJL repository. You also need to change directory to `djl/bom`. Then build and publish it to maven local same as what was done in `djl`. diff --git a/docs/hybrid_engine.md b/docs/hybrid_engine.md index 58bdbe69cb4..cc6ec9400d2 100644 --- a/docs/hybrid_engine.md +++ b/docs/hybrid_engine.md @@ -21,17 +21,17 @@ to run in a hybrid mode: To use it along with Apache MXNet for additional API support, add the following two dependencies: ``` -runtimeOnly "ai.djl.mxnet:mxnet-engine:0.23.0" +runtimeOnly "ai.djl.mxnet:mxnet-engine:0.27.0" ``` You can also use PyTorch or TensorFlow Engine as the supplemental engine by adding their corresponding dependencies. ``` -runtimeOnly "ai.djl.pytorch:pytorch-engine:0.23.0" +runtimeOnly "ai.djl.pytorch:pytorch-engine:0.27.0" ``` ``` -runtimeOnly "ai.djl.tensorflow:tensorflow-engine:0.23.0" +runtimeOnly "ai.djl.tensorflow:tensorflow-engine:0.27.0" ``` ## How Hybrid works diff --git a/docs/interactive_tool.md b/docs/interactive_tool.md index ed102fedc8d..d7d267db710 100644 --- a/docs/interactive_tool.md +++ b/docs/interactive_tool.md @@ -63,7 +63,7 @@ After that, click `run` and you should see the following result: Finally, you can get the running project setup by clicking `Get Template`. This will bring you a gradle project that can be used in your local machine. -## [Java Jupyter Notebook](../jupyter/README.md) +## [Java Jupyter Notebook](http://docs.djl.ai/docs/demos/jupyter/index.html) Wait a second, are we talking about hosting Jupyter Notebook in python? No, it’s Java 11, only. @@ -71,9 +71,9 @@ No, it’s Java 11, only. ![jupyter](https://djl-ai.s3.amazonaws.com/web-data/images/jupyter.gif) Inspired by Spencer Park’s [IJava project](https://github.com/SpencerPark/IJava), we integrated DJL with Jupyter Notebooks. -For more information on the simple setup, follow the instructions in [DJL Jupyter notebooks](../jupyter/README.md#setup). +For more information on the simple setup, follow the instructions in [DJL Jupyter notebooks](http://docs.djl.ai/docs/demos/jupyter/index.html#setup). After that, use the Jupyter Notebook freely in your hosted server. You can do all kinds of work, like block building and plotting a graph. -There are [tutorials and instructions](../jupyter/README.md#djl---jupyter-notebooks) to guide you how you can run training and/or inference with Java. +There are [tutorials and instructions](http://docs.djl.ai/docs/demos/jupyter/index.html#djl---jupyter-notebooks) to guide you how you can run training and/or inference with Java. ## About Future Lab diff --git a/docs/load_model.md b/docs/load_model.md index 621d7514605..3c0afec26e9 100644 --- a/docs/load_model.md +++ b/docs/load_model.md @@ -181,7 +181,7 @@ Here is a few tips you can use to help you debug model loading issue: See [here](development/configure_logging.md#configure-logging-level) for how to enable debug log #### List models programmatically in your code -You can use [ModelZoo.listModels()](https://javadoc.io/static/ai.djl/api/0.23.0/ai/djl/repository/zoo/ModelZoo.html#listModels--) API to query available models. +You can use [ModelZoo.listModels()](https://javadoc.io/static/ai.djl/api/0.27.0/ai/djl/repository/zoo/ModelZoo.html#listModels--) API to query available models. 
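
As a rough illustration of that call (not part of the documented examples), a minimal sketch might look like the code below; the class name and the printed format are placeholders.

```java
import ai.djl.Application;
import ai.djl.repository.Artifact;
import ai.djl.repository.zoo.ModelZoo;

import java.util.List;
import java.util.Map;

public final class ListModels {

    public static void main(String[] args) throws Exception {
        // Query every model zoo on the classpath, grouped by application (CV, NLP, ...)
        Map<Application, List<Artifact>> models = ModelZoo.listModels();
        for (Map.Entry<Application, List<Artifact>> entry : models.entrySet()) {
            for (Artifact artifact : entry.getValue()) {
                System.out.println(entry.getKey().getPath() + '/' + artifact.getName());
            }
        }
    }
}
```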
#### List available models using DJL command line diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index c911bf43b2d..ef7c46d331a 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -61,15 +61,15 @@ nav: - 'docs/faq.md' - Tutorials: - Beginner Tutorial: - - 'jupyter/tutorial/01_create_your_first_network.ipynb' - - 'jupyter/tutorial/02_train_your_first_model.ipynb' - - 'jupyter/tutorial/03_image_classification_with_your_model.ipynb' + - 'docs/demos/jupyter/tutorial/01_create_your_first_network.ipynb' + - 'docs/demos/jupyter/tutorial/02_train_your_first_model.ipynb' + - 'docs/demos/jupyter/tutorial/03_image_classification_with_your_model.ipynb' - 'docs/d2l.md' - - 'jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb' - - 'jupyter/transfer_learning_on_cifar10.ipynb' + - 'docs/demos/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb' + - 'docs/demos/jupyter/transfer_learning_on_cifar10.ipynb' - Load your own BERT: - - BERT with MXNet: 'jupyter/mxnet/load_your_own_mxnet_bert.ipynb' - - BERT with PyTorch: 'jupyter/pytorch/load_your_own_pytorch_bert.ipynb' + - BERT with MXNet: 'docs/demos/jupyter/mxnet/load_your_own_mxnet_bert.ipynb' + - BERT with PyTorch: 'docs/demos/jupyter/pytorch/load_your_own_pytorch_bert.ipynb' - Guides: - Models: - 'docs/load_model.md' @@ -97,25 +97,25 @@ nav: - PyTorch NDArray Operators: 'docs/pytorch/pytorch-djl-ndarray-cheatsheet.md' - PyTorch Model Zoo: 'engines/pytorch/pytorch-model-zoo/README.md' - Import PyTorch Model: 'docs/pytorch/how_to_convert_your_model_to_torchscript.md' - - Load a PyTorch Model: 'jupyter/load_pytorch_model.ipynb' + - Load a PyTorch Model: 'docs/demos/jupyter/load_pytorch_model.ipynb' - TensorFlow: - Overview: 'engines/tensorflow/README.md' - TensorFlow Engine: 'engines/tensorflow/tensorflow-engine/README.md' - TensorFlow Model Zoo: 'engines/tensorflow/tensorflow-model-zoo/README.md' - Import TensorFlow Model: 'docs/tensorflow/how_to_import_tensorflow_models_in_DJL.md' - - Load a TensorFlow Model: 'jupyter/tensorflow/pneumonia_detection.ipynb' + - Load a TensorFlow Model: 'docs/demos/jupyter/tensorflow/pneumonia_detection.ipynb' - Apache MXNet: - Overview: 'engines/mxnet/README.md' - MXNet Engine: 'engines/mxnet/mxnet-engine/README.md' - MXNet Model Zoo: 'engines/mxnet/mxnet-model-zoo/README.md' - Import Gluon Model: 'docs/mxnet/how_to_convert_your_model_to_symbol.md' - - Load a MXNet Model: 'jupyter/load_mxnet_model.ipynb' + - Load a MXNet Model: 'docs/demos/jupyter/load_mxnet_model.ipynb' - Backend Optimizer for MXNet: 'docs/mxnet/mxnet_backend_optimizer.md' - Hybrid engines: - Hybrid engine overview: 'docs/hybrid_engine.md' - ONNX Runtime: - Overview: 'engines/onnxruntime/onnxruntime-engine/README.md' - - Load a ONNX Model: 'jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb' + - Load a ONNX Model: 'docs/demos/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb' - PaddlePaddle: - Overview: 'engines/paddlepaddle/README.md' - PaddlePaddle Engine: 'engines/paddlepaddle/paddlepaddle-engine/README.md' @@ -124,11 +124,11 @@ nav: - English: 'docs/paddlepaddle/how_to_create_paddlepaddle_model.md' - 中文: 'docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md' - Facemask detection using PaddlePaddle: - - English: 'jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb' - - 中文: 'jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb' + - English: 'docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb' + - 中文: 
'docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb' - PaddleOCR example: - - English: 'jupyter/paddlepaddle/paddle_ocr_java.ipynb' - - 中文: 'jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb' + - English: 'docs/demos/jupyter/paddlepaddle/paddle_ocr_java.ipynb' + - 中文: 'docs/demos/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb' - XGBoost: 'engines/ml/xgboost/README.md' - LightGBM: 'engines/ml/lightgbm/README.md' - TensorRT: 'engines/tensorrt/README.md' @@ -153,15 +153,49 @@ nav: - 'docs/serving/serving/docs/inference.md' - 'docs/serving/serving/docs/modes.md' - 'docs/serving/serving/docs/console.md' - - 'docs/serving/serving/docs/configuration.md' - - 'docs/serving/serving/docs/configurations.md' - - 'docs/serving/serving/docs/workflows.md' + - Configuration: + - 'docs/serving/serving/docs/configuration.md' + - 'docs/serving/serving/docs/configurations_global.md' + - 'docs/serving/serving/docs/configurations.md' + - 'docs/serving/serving/docs/workflows.md' + - 'docs/serving/serving/docs/configurations_model.md' - 'docs/serving/serving/docs/architecture.md' - HTTP API: - 'docs/serving/serving/docs/inference_api.md' - 'docs/serving/serving/docs/management_api.md' - 'docs/serving/serving/docs/plugin_management.md' - 'docs/serving/wlm/README.md' + - Large Model Inference: + - 'docs/serving/serving/docs/lmi/README.md' + - User Guides: + - 'docs/serving/serving/docs/lmi/user_guides/README.md' + - 'docs/serving/serving/docs/lmi/user_guides/starting-guide.md' + - 'docs/serving/serving/docs/lmi/user_guides/deepspeed_user_guide.md' + - 'docs/serving/serving/docs/lmi/user_guides/lmi-dist_user_guide.md' + - 'docs/serving/serving/docs/lmi/user_guides/vllm_user_guide.md' + - 'docs/serving/serving/docs/lmi/user_guides/tnx_user_guide.md' + - 'docs/serving/serving/docs/lmi/user_guides/trt_llm_user_guide.md' + - 'docs/serving/serving/docs/lmi/user_guides/hf_accelerate.md' + - 'docs/serving/serving/docs/lmi/user_guides/lmi_input_output_schema.md' + - 'docs/serving/serving/docs/lmi/user_guides/chat_input_output_schema.md' + - Deployment Guides: + - 'docs/serving/serving/docs/lmi/deployment_guide/README.md' + - 'docs/serving/serving/docs/lmi/deployment_guide/model-artifacts.md' + - 'docs/serving/serving/docs/lmi/deployment_guide/instance-type-selection.md' + - 'docs/serving/serving/docs/lmi/deployment_guide/backend-selection.md' + - 'docs/serving/serving/docs/lmi/deployment_guide/configurations.md' + - 'docs/serving/serving/docs/lmi/deployment_guide/deploying-your-endpoint.md' + - 'docs/serving/serving/docs/lmi/deployment_guide/benchmarking-your-endpoint.md' + - 'docs/serving/serving/docs/lmi/deployment_guide/testing-custom-script.md' + - Tutorials: + - 'docs/serving/serving/docs/lmi/tutorials/seq_scheduler_tutorial.md' + - 'docs/serving/serving/docs/lmi/tutorials/trtllm_aot_tutorial.md' + - 'docs/serving/serving/docs/lmi/tutorials/trtllm_manual_convert_tutorial.md' + - 'docs/serving/serving/docs/lmi/tutorials/tnx_aot_tutorial.md' + - Conceptual Guides: + - 'docs/serving/serving/docs/lmi/conceptual_guide/lmi_engine.md' + - SageMaker LMI containers resources: + - 'docs/demos/aws/sagemaker/large-model-inference/README.md' - Demos: - Demos: 'docs/demos/README.md' - AWS: diff --git a/docs/mxnet/how_to_convert_your_model_to_symbol.md b/docs/mxnet/how_to_convert_your_model_to_symbol.md index be178afe437..57a5b8a9b05 100644 --- a/docs/mxnet/how_to_convert_your_model_to_symbol.md +++ b/docs/mxnet/how_to_convert_your_model_to_symbol.md @@ -1,4 +1,4 @@ -## How to convert your Gluon model to an MXNet 
Symbol +# How to convert your Gluon model to an MXNet Symbol DJL currently supports symbolic model loading from MXNet. A gluon [HybridBlock](https://mxnet.apache.org/api/python/docs/api/gluon/hybrid_block.html) can be converted into a symbol for loading by doing as follows: diff --git a/docs/paddlepaddle/how_to_create_paddlepaddle_model.md b/docs/paddlepaddle/how_to_create_paddlepaddle_model.md index 042acbd2d61..b78d4406946 100644 --- a/docs/paddlepaddle/how_to_create_paddlepaddle_model.md +++ b/docs/paddlepaddle/how_to_create_paddlepaddle_model.md @@ -157,5 +157,5 @@ predictor.predict(list); As mentioned, you need to find out what is the input for the model, like images usually interpret as NCHW (batch_size, channel, height, width). -However, usage like this is really basic, you can write a `Translator` in DJL for it. You can find some code examples [here](../../jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb). +However, usage like this is really basic, you can write a `Translator` in DJL for it. You can find some code examples [here](http://docs.djl.ai/docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.html). diff --git a/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md b/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md index 74e5dec634f..5f79d713783 100644 --- a/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md +++ b/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md @@ -156,4 +156,4 @@ predictor.predict(list); 在čŋ™é‡ŒīŧŒäŊ éœ€čĻįŸĨé“æ¨Ąåž‹įš„čž“å…Ĩ输å‡ēæ ŧåŧ, 比åĻ‚回į‰‡įģå¸¸čĄ¨čžžæˆ NCHW (扚大小, RGB通道, éĢ˜åēĻ, åŽŊåēĻ)įš„多įģ´įŸŠé˜ĩ。 -č™Ŋį„ļčŋ™æ ˇå¯äģĨčŽŠæ¨Ąåž‹čˇ‘čĩˇæĨ, äŊ†æ˜¯æœ€åĨŊčŋ˜æ˜¯įģ“合 DJL įš„ `Translator` class äŊŋį”¨ã€‚äŊ å¯äģĨ在 [čŋ™é‡Œ](../../jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb) 扞到一äē›į¤ē例äģŖį ã€‚ +č™Ŋį„ļčŋ™æ ˇå¯äģĨčŽŠæ¨Ąåž‹čˇ‘čĩˇæĨ, äŊ†æ˜¯æœ€åĨŊčŋ˜æ˜¯įģ“合 DJL įš„ `Translator` class äŊŋį”¨ã€‚äŊ å¯äģĨ在 [čŋ™é‡Œ](http://docs.djl.ai/docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.html) 扞到一äē›į¤ē例äģŖį ã€‚ diff --git a/docs/pytorch/how_to_convert_your_model_to_torchscript.md b/docs/pytorch/how_to_convert_your_model_to_torchscript.md index 4dd4b3102d7..f90ee468764 100644 --- a/docs/pytorch/how_to_convert_your_model_to_torchscript.md +++ b/docs/pytorch/how_to_convert_your_model_to_torchscript.md @@ -1,4 +1,4 @@ -## How to convert your PyTorch model to TorchScript +# How to convert your PyTorch model to TorchScript There are two ways to convert your model to TorchScript: tracing and scripting. We will only demonstrate the first one, tracing, but you can find information about scripting from the PyTorch documentation. diff --git a/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md b/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md index 7416ec50bab..37d24276d82 100644 --- a/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md +++ b/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md @@ -1,4 +1,4 @@ -## PyTorch NDArray operators +# PyTorch NDArray operators In the following examples, we assume diff --git a/docs/quick_start.md b/docs/quick_start.md index f352a39156a..b7072a50a59 100644 --- a/docs/quick_start.md +++ b/docs/quick_start.md @@ -1,7 +1,7 @@ # Quick start Deep Java Library (DJL) is designed to be easy to get started with and simple to use. 
-The easiest way to learn DJL is to read the [beginner tutorial](../jupyter/tutorial/README.md) or +The easiest way to learn DJL is to read the [beginner tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/README.md) or our [examples](../examples/README.md). You can also view our 1.5 hour long (in 8 x ~10 minute segments) DJL 101 tutorial video series: @@ -22,7 +22,7 @@ See [DJL Future Labs](interactive_tool.md) ## Beginner tutorial -To get started, we recommend that you follow our short [beginner tutorial](../jupyter/tutorial/README.md). It takes you through some of the basics of deep learning to create a model, train your model, and run inference using your trained model. +To get started, we recommend that you follow our short [beginner tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/index.html). It takes you through some of the basics of deep learning to create a model, train your model, and run inference using your trained model. ## Run examples @@ -33,10 +33,10 @@ All of our examples are executed by a simple command. For detailed command line - [Train your first model](../examples/docs/train_mnist_mlp.md) - [Single-shot Object Detection inference example](../examples/docs/object_detection.md) - [More examples](https://github.com/deepjavalibrary/djl/tree/master/examples) -- [Jupyter examples](../jupyter/README.md) +- [Jupyter examples](http://docs.djl.ai/docs/demos/jupyter/index.html) ## Other resources -- [JavaDoc API Reference](https://javadoc.djl.ai/) +- [JavaDoc API Reference](https://djl.ai/website/javadoc.html) - [Contributor Documentation](development/README.md) - [FAQ](faq.md) diff --git a/docs/telemetry.md b/docs/telemetry.md index d6ff9b20bc1..256adf00a49 100644 --- a/docs/telemetry.md +++ b/docs/telemetry.md @@ -20,5 +20,5 @@ System.setProperty("OPT_OUT_TRACKING", "true") Usage tracking is also disable in `offline` mode: ```java -System.setProperty("offline", "true") +System.setProperty("ai.djl.offline", "true") ``` diff --git a/engines/llama/.gitignore b/engines/llama/.gitignore new file mode 100644 index 00000000000..3428b3b2f53 --- /dev/null +++ b/engines/llama/.gitignore @@ -0,0 +1,3 @@ +jnilib/ +llama.cpp/ +models/ diff --git a/engines/llama/CMakeLists.txt b/engines/llama/CMakeLists.txt new file mode 100644 index 00000000000..d1fc8131db8 --- /dev/null +++ b/engines/llama/CMakeLists.txt @@ -0,0 +1,23 @@ +cmake_minimum_required(VERSION 3.12 FATAL_ERROR) + +project(djl_llama CXX) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(BUILD_SHARED_LIBS ON) + +set(JAVA_AWT_LIBRARY NotNeeded) +set(JAVA_AWT_INCLUDE_PATH NotNeeded) +find_package(JNI REQUIRED) + +add_subdirectory(llama.cpp) +include(build-args.cmake) +add_library(djl_llama SHARED src/main/native/ai_djl_llama.cpp) + +target_include_directories(djl_llama PRIVATE + ${JNI_INCLUDE_DIRS} + src/main/native + llama.cpp + llama.cpp/common + build/include) +target_link_libraries(djl_llama PRIVATE common llama ${LLAMA_EXTRA_LIBS}) +target_compile_features(djl_llama PRIVATE cxx_std_11) diff --git a/engines/llama/build-args.cmake b/engines/llama/build-args.cmake new file mode 100644 index 00000000000..dee0db659cd --- /dev/null +++ b/engines/llama/build-args.cmake @@ -0,0 +1,639 @@ +if (APPLE) + set(LLAMA_METAL_DEFAULT ON) +else() + set(LLAMA_METAL_DEFAULT OFF) +endif() + +# general +option(LLAMA_NATIVE "llama: enable -march=native flag" ON) + +# instruction set specific +if (LLAMA_NATIVE) + set(INS_ENB OFF) +else() + set(INS_ENB ON) +endif() + +option(LLAMA_AVX "llama: enable AVX" ${INS_ENB}) +option(LLAMA_AVX2 
"llama: enable AVX2" ${INS_ENB}) +option(LLAMA_AVX512 "llama: enable AVX512" OFF) +option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) +option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) +option(LLAMA_FMA "llama: enable FMA" ${INS_ENB}) +# in MSVC F16C is implied with AVX2/AVX512 +if (NOT MSVC) + option(LLAMA_F16C "llama: enable F16C" ${INS_ENB}) +endif() + +# 3rd party libs +option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) +option(LLAMA_BLAS "llama: use BLAS" OFF) +set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") +option(LLAMA_CUBLAS "llama: use CUDA" OFF) +#option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF) +option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) +option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF) +set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") +set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") +option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) +set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") +set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING + "llama: max. batch size for using peer access") +option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) +option(LLAMA_CLBLAST "llama: use CLBlast" OFF) +option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT}) +option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF) +option(LLAMA_MPI "llama: use MPI" OFF) +option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) + + +# +# Compile flags +# + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED true) +set(CMAKE_C_STANDARD 11) +set(CMAKE_C_STANDARD_REQUIRED true) +set(THREADS_PREFER_PTHREAD_FLAG ON) +find_package(Threads REQUIRED) +include(CheckCXXCompilerFlag) + +# enable libstdc++ assertions for debug builds +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + add_compile_definitions($<$:_GLIBCXX_ASSERTIONS>) +endif() + +if (NOT MSVC) + if (LLAMA_SANITIZE_THREAD) + add_compile_options(-fsanitize=thread) + link_libraries(-fsanitize=thread) + endif() + + if (LLAMA_SANITIZE_ADDRESS) + add_compile_options(-fsanitize=address -fno-omit-frame-pointer) + link_libraries(-fsanitize=address) + endif() + + if (LLAMA_SANITIZE_UNDEFINED) + add_compile_options(-fsanitize=undefined) + link_libraries(-fsanitize=undefined) + endif() +endif() + +if (APPLE AND LLAMA_ACCELERATE) + find_library(ACCELERATE_FRAMEWORK Accelerate) + if (ACCELERATE_FRAMEWORK) + message(STATUS "Accelerate framework found") + + add_compile_definitions(GGML_USE_ACCELERATE) + add_compile_definitions(ACCELERATE_NEW_LAPACK) + add_compile_definitions(ACCELERATE_LAPACK_ILP64) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) + else() + message(WARNING "Accelerate framework not found") + endif() +endif() + +if (LLAMA_METAL) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + + message(STATUS "Metal framework found") + set(GGML_HEADERS_METAL ggml-metal.h) + set(GGML_SOURCES_METAL ggml-metal.m) + + add_compile_definitions(GGML_USE_METAL) + if (LLAMA_METAL_NDEBUG) + add_compile_definitions(GGML_METAL_NDEBUG) + endif() + + # get full path to the file + #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/") + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} + ${FOUNDATION_LIBRARY} + ${METAL_FRAMEWORK} + 
${METALKIT_FRAMEWORK} + ) +endif() +if (LLAMA_BLAS) + if (LLAMA_STATIC) + set(BLA_STATIC ON) + endif() + if ($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22) + set(BLA_SIZEOF_INTEGER 8) + endif() + + set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) + find_package(BLAS) + + if (BLAS_FOUND) + message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") + + if ("${BLAS_INCLUDE_DIRS}" STREQUAL "") + # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. + # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 + find_package(PkgConfig REQUIRED) + if (${LLAMA_BLAS_VENDOR} MATCHES "Generic") + pkg_check_modules(DepBLAS REQUIRED blas) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS") + pkg_check_modules(DepBLAS REQUIRED openblas) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME") + pkg_check_modules(DepBLAS REQUIRED blis) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS") + pkg_check_modules(DepBLAS REQUIRED blas-atlas) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS") + pkg_check_modules(DepBLAS REQUIRED flexiblas_api) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel") + # all Intel* libraries share the same include path + pkg_check_modules(DepBLAS REQUIRED mkl-sdl) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC") + # this doesn't provide pkg-config + # suggest to assign BLAS_INCLUDE_DIRS on your own + if ("${NVHPC_VERSION}" STREQUAL "") + message(WARNING "Better to set NVHPC_VERSION") + else() + set(DepBLAS_FOUND ON) + set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") + endif() + endif() + if (DepBLAS_FOUND) + set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) + else() + message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" + " detected by pkgconfig, trying to find cblas.h from possible paths...") + find_path(BLAS_INCLUDE_DIRS + NAMES cblas.h + HINTS + /usr/include + /usr/local/include + /usr/include/openblas + /opt/homebrew/opt/openblas/include + /usr/local/opt/openblas/include + /usr/include/x86_64-linux-gnu/openblas/include + ) + endif() + endif() + + message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") + add_compile_options(${BLAS_LINKER_FLAGS}) + add_compile_definitions(GGML_USE_OPENBLAS) + if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel")) + add_compile_definitions(GGML_BLAS_USE_MKL) + endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) + + else() + message(WARNING "BLAS not found, please refer to " + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct LLAMA_BLAS_VENDOR") + endif() +endif() + +if (LLAMA_QKK_64) + add_compile_definitions(GGML_QKK_64) +endif() + +if (LLAMA_CUBLAS) + cmake_minimum_required(VERSION 3.17) + + find_package(CUDAToolkit) + if (CUDAToolkit_FOUND) + message(STATUS "cuBLAS found") + + enable_language(CUDA) + + set(GGML_HEADERS_CUDA ggml-cuda.h) + set(GGML_SOURCES_CUDA ggml-cuda.cu) + + add_compile_definitions(GGML_USE_CUBLAS) +# if (LLAMA_CUDA_CUBLAS) +# add_compile_definitions(GGML_CUDA_CUBLAS) +# endif() + if (LLAMA_CUDA_FORCE_DMMV) + add_compile_definitions(GGML_CUDA_FORCE_DMMV) + endif() + if (LLAMA_CUDA_FORCE_MMQ) + add_compile_definitions(GGML_CUDA_FORCE_MMQ) + endif() + add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) + add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) + if (DEFINED LLAMA_CUDA_DMMV_Y) + 
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_DMMV_Y}) # for backwards compatibility + endif() + if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) + add_compile_definitions(GGML_CUDA_F16) + endif() + add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) + add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE}) + + if (LLAMA_STATIC) + if (WIN32) + # As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt) + else () + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + endif() + else() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) + endif() + + if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + # 52 == lowest CUDA 12 standard + # 60 == f16 CUDA intrinsics + # 61 == integer CUDA intrinsics + # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster + if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) + set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics + else() + set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics + #set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work + endif() + endif() + message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") + + else() + message(WARNING "cuBLAS not found") + endif() +endif() + +if (LLAMA_MPI) + cmake_minimum_required(VERSION 3.10) + find_package(MPI) + if (MPI_C_FOUND) + message(STATUS "MPI found") + set(GGML_HEADERS_MPI ggml-mpi.h) + set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h) + add_compile_definitions(GGML_USE_MPI) + add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) + if (NOT MSVC) + add_compile_options(-Wno-cast-qual) + endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES}) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS}) + # Even if you're only using the C header, C++ programs may bring in MPI + # C++ functions, so more linkage is needed + if (MPI_CXX_FOUND) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_CXX_LIBRARIES}) + endif() + else() + message(WARNING "MPI not found") + endif() +endif() + +if (LLAMA_CLBLAST) + find_package(CLBlast) + if (CLBlast_FOUND) + message(STATUS "CLBlast found") + + set(GGML_HEADERS_OPENCL ggml-opencl.h) + set(GGML_SOURCES_OPENCL ggml-opencl.cpp) + + add_compile_definitions(GGML_USE_CLBLAST) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) + else() + message(WARNING "CLBlast not found") + endif() +endif() + +if (LLAMA_HIPBLAS) + list(APPEND CMAKE_PREFIX_PATH /opt/rocm) + + if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang") + endif() + if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++") + endif() + + find_package(hip) + find_package(hipblas) + find_package(rocblas) + + if (${hipblas_FOUND} AND ${hip_FOUND}) + message(STATUS "HIP and hipBLAS found") + add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS) + add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) + if (BUILD_SHARED_LIBS) + set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + if (LLAMA_CUDA_FORCE_DMMV) + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV) + endif() + if (LLAMA_CUDA_FORCE_MMQ) + 
target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_MMQ) + endif() + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) + target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) + set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) + target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas) + + if (LLAMA_STATIC) + message(FATAL_ERROR "Static linking not supported for HIP/ROCm") + endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm) + else() + message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm") + endif() +endif() + +function(get_flags CCID CCVER) + set(C_FLAGS "") + set(CXX_FLAGS "") + + if (CCID MATCHES "Clang") + set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return) + set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi) + + if ( + (CCID STREQUAL "Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR + (CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0) + ) + set(C_FLAGS ${C_FLAGS} -Wdouble-promotion) + endif() + elseif (CCID STREQUAL "GNU") + set(C_FLAGS -Wdouble-promotion) + set(CXX_FLAGS -Wno-array-bounds) + + if (CCVER VERSION_GREATER_EQUAL 7.1.0) + set(CXX_FLAGS ${CXX_FLAGS} -Wno-format-truncation) + endif() + if (CCVER VERSION_GREATER_EQUAL 8.1.0) + set(CXX_FLAGS ${CXX_FLAGS} -Wextra-semi) + endif() + endif() + + set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE) + set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE) +endfunction() + +if (LLAMA_ALL_WARNINGS) + if (NOT MSVC) + set(WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function) + set(C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes + -Werror=implicit-int -Werror=implicit-function-declaration) + set(CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn) + + set(C_FLAGS ${WARNING_FLAGS} ${C_FLAGS}) + set(CXX_FLAGS ${WARNING_FLAGS} ${CXX_FLAGS}) + + get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}) + + add_compile_options("$<$:${C_FLAGS};${GF_C_FLAGS}>" + "$<$:${CXX_FLAGS};${GF_CXX_FLAGS}>") + else() + # todo : msvc + set(C_FLAGS "") + set(CXX_FLAGS "") + endif() +endif() + +if (LLAMA_CUBLAS) + set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math) + if (NOT MSVC) + set(CUDA_FLAGS ${CUDA_FLAGS} -Wno-pedantic) + endif() + + if (LLAMA_ALL_WARNINGS AND NOT MSVC) + set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c) + if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "") + set(NVCC_CMD ${NVCC_CMD} -ccbin ${CMAKE_CUDA_HOST_COMPILER}) + endif() + + execute_process( + COMMAND ${NVCC_CMD} -Xcompiler --version + OUTPUT_VARIABLE CUDA_CCFULLVER + ERROR_QUIET + ) + + if (NOT CUDA_CCFULLVER MATCHES clang) + set(CUDA_CCID "GNU") + execute_process( + COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion" + OUTPUT_VARIABLE CUDA_CCVER + ERROR_QUIET + ) + else() + if (CUDA_CCFULLVER MATCHES Apple) + set(CUDA_CCID "AppleClang") + else() + set(CUDA_CCID "Clang") + endif() + string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER}) + endif() + + message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}") + + get_flags(${CUDA_CCID} ${CUDA_CCVER}) + list(JOIN GF_CXX_FLAGS " " CUDA_CXX_FLAGS) # pass host compiler flags as a single argument + if (NOT CUDA_CXX_FLAGS STREQUAL "") + set(CUDA_FLAGS ${CUDA_FLAGS} -Xcompiler ${CUDA_CXX_FLAGS}) + endif() + endif() + + 
add_compile_options("$<$:${CUDA_FLAGS}>") +endif() + +if (WIN32) + add_compile_definitions(_CRT_SECURE_NO_WARNINGS) + + if (BUILD_SHARED_LIBS) + set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) + endif() +endif() + +if (LLAMA_LTO) + include(CheckIPOSupported) + check_ipo_supported(RESULT result OUTPUT output) + if (result) + set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) + else() + message(WARNING "IPO is not supported: ${output}") + endif() +endif() + +# this version of Apple ld64 is buggy +execute_process( + COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v + ERROR_VARIABLE output + OUTPUT_QUIET +) +if (output MATCHES "dyld-1015\.7") + add_compile_definitions(HAVE_BUGGY_APPLE_LINKER) +endif() + +# Architecture specific +# TODO: probably these flags need to be tweaked on some architectures +# feel free to update the Makefile for your architecture and send a pull request or issue +message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") +if (MSVC) + string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR) + message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}") +else () + set(CMAKE_GENERATOR_PLATFORM_LWR "") +endif () + +if (NOT MSVC) + if (LLAMA_STATIC) + add_link_options(-static) + if (MINGW) + add_link_options(-static-libgcc -static-libstdc++) + endif() + endif() + if (LLAMA_GPROF) + add_compile_options(-pg) + endif() +endif() + +if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") OR ("${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "arm64")) + message(STATUS "ARM detected") + if (MSVC) + add_compile_definitions(__ARM_NEON) + add_compile_definitions(__ARM_FEATURE_FMA) + add_compile_definitions(__ARM_FEATURE_DOTPROD) + # add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) # MSVC doesn't support vdupq_n_f16, vld1q_f16, vst1q_f16 + add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead + else() + check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E) + if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "") + add_compile_options(-mfp16-format=ieee) + endif() + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") + # Raspberry Pi 1, Zero + add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access) + endif() + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") + # Raspberry Pi 2 + add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations) + endif() + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") + # Raspberry Pi 3, 4, Zero 2 (32-bit) + add_compile_options(-mno-unaligned-access) + endif() + endif() +elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "^(x86_64|i686|amd64|x64)$" ) + message(STATUS "x86 detected") + if (MSVC) + # instruction set detection for MSVC only + if (LLAMA_NATIVE) + include(${llama.cpp_SOURCE_DIR}/cmake/FindSIMD.cmake) + endif () + if (LLAMA_AVX512) + add_compile_options($<$:/arch:AVX512>) + add_compile_options($<$:/arch:AVX512>) + # MSVC has no compile-time flags enabling specific + # AVX512 extensions, neither it defines the + # macros corresponding to the extensions. + # Do it manually. 
+ if (LLAMA_AVX512_VBMI) + add_compile_definitions($<$:__AVX512VBMI__>) + add_compile_definitions($<$:__AVX512VBMI__>) + endif() + if (LLAMA_AVX512_VNNI) + add_compile_definitions($<$:__AVX512VNNI__>) + add_compile_definitions($<$:__AVX512VNNI__>) + endif() + elseif (LLAMA_AVX2) + add_compile_options($<$:/arch:AVX2>) + add_compile_options($<$:/arch:AVX2>) + elseif (LLAMA_AVX) + add_compile_options($<$:/arch:AVX>) + add_compile_options($<$:/arch:AVX>) + endif() + else() + if (LLAMA_NATIVE) + add_compile_options(-march=native) + endif() + if (LLAMA_F16C) + add_compile_options(-mf16c) + endif() + if (LLAMA_FMA) + add_compile_options(-mfma) + endif() + if (LLAMA_AVX) + add_compile_options(-mavx) + endif() + if (LLAMA_AVX2) + add_compile_options(-mavx2) + endif() + if (LLAMA_AVX512) + add_compile_options(-mavx512f) + add_compile_options(-mavx512bw) + endif() + if (LLAMA_AVX512_VBMI) + add_compile_options(-mavx512vbmi) + endif() + if (LLAMA_AVX512_VNNI) + add_compile_options(-mavx512vnni) + endif() + endif() +elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") + message(STATUS "PowerPC detected") + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") + add_compile_options(-mcpu=powerpc64le) + else() + add_compile_options(-mcpu=native -mtune=native) + #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) + endif() +else() + message(STATUS "Unknown architecture") +endif() + +if (MINGW) + # Target Windows 8 for PrefetchVirtualMemory + add_compile_definitions(_WIN32_WINNT=0x602) +endif() + +# +# POSIX conformance +# + +# clock_gettime came in POSIX.1b (1993) +# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional +# posix_memalign came in POSIX.1-2001 / SUSv3 +# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985) +add_compile_definitions(_XOPEN_SOURCE=600) + +# Somehow in OpenBSD whenever POSIX conformance is specified +# some string functions rely on locale_t availability, +# which was introduced in POSIX.1-2008, forcing us to go higher +if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") + remove_definitions(-D_XOPEN_SOURCE=600) + add_compile_definitions(_XOPEN_SOURCE=700) +endif() + +# Data types, macros and functions related to controlling CPU affinity and +# some memory allocation are available on Linux through GNU extensions in libc +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + add_compile_definitions(_GNU_SOURCE) +endif() + +# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1, +# and on macOS its availability depends on enabling Darwin extensions +# similarly on DragonFly, enabling BSD extensions is necessary +if ( + CMAKE_SYSTEM_NAME MATCHES "Darwin" OR + CMAKE_SYSTEM_NAME MATCHES "iOS" OR + CMAKE_SYSTEM_NAME MATCHES "tvOS" OR + CMAKE_SYSTEM_NAME MATCHES "DragonFly" +) + add_compile_definitions(_DARWIN_C_SOURCE) +endif() + +# alloca is a non-standard interface that is not visible on BSDs when +# POSIX conformance is specified, but not all of them provide a clean way +# to enable it in such cases +if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD") + add_compile_definitions(__BSD_VISIBLE) +endif() +if (CMAKE_SYSTEM_NAME MATCHES "NetBSD") + add_compile_definitions(_NETBSD_SOURCE) +endif() +if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") + add_compile_definitions(_BSD_SOURCE) +endif() diff --git a/engines/llama/build.cmd b/engines/llama/build.cmd new file mode 100644 index 00000000000..83ccf65198c --- /dev/null +++ b/engines/llama/build.cmd @@ -0,0 +1,23 @@ +@rem https://chocolatey.org/docs/installation#install-with-cmdexe +@rem to install 
rust java etc.. +@rem choco install jdk17 -y + +set VERSION="%1" + +if exist "llama.cpp" ( + echo Found "llama.cpp" +) else ( + git clone https://github.com/ggerganov/llama.cpp.git -b %VERSION% +) + +if exist build rd /q /s build +md build\classes +cd build +javac -classpath "%2" -sourcepath ..\src\main\java\ ..\src\main\java\ai\djl\llama\jni\LlamaLibrary.java -h include -d classes +cmake .. +cmake --build . --config Release + +@rem for nightly ci +md jnilib\win-x86_64 +copy Release\djl_llama.dll jnilib\win-x86_64\ +copy bin\Release\llama.dll jnilib\win-x86_64\ diff --git a/engines/llama/build.gradle b/engines/llama/build.gradle new file mode 100644 index 00000000000..73feb62fc5e --- /dev/null +++ b/engines/llama/build.gradle @@ -0,0 +1,108 @@ +import java.util.zip.GZIPInputStream + +group "ai.djl.llama" + +dependencies { + api project(":api") + + testImplementation project(":testing") + testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" +} + +compileJava.dependsOn(processResources) + +processResources { + outputs.dir file("${project.projectDir}/build/classes/java/main/native/lib") + doLast { + def url = "https://publish.djl.ai/llama/${llamacpp_version}/jnilib/${djl_version}" + def files = new String[]{ + "linux-x86_64/libdjl_llama.so", + "linux-x86_64/libllama.so", + "linux-aarch64/libdjl_llama.so", + "linux-aarch64/libllama.so", + "osx-x86_64/libdjl_llama.dylib", + "osx-x86_64/libllama.dylib", + "osx-x86_64/ggml-metal.metal", + "osx-aarch64/libdjl_llama.dylib", + "osx-aarch64/libllama.dylib", + "osx-aarch64/ggml-metal.metal", + "win-x86_64/djl_llama.dll", + "win-x86_64/llama.dll", + } + def jnilibDir = "${project.projectDir}/jnilib/${djl_version}" + files.each { entry -> + def file = new File("${jnilibDir}/${entry}") + if (file.exists()) { + project.logger.lifecycle("prebuilt or cached file found for ${entry}") + } else if (!project.hasProperty("jni")) { + project.logger.lifecycle("Downloading ${url}/${entry}") + file.getParentFile().mkdirs() + def downloadPath = new URL("${url}/${entry}") + downloadPath.withInputStream { i -> file.withOutputStream { it << i } } + } + } + copy { + from jnilibDir + into "${project.projectDir}/build/classes/java/main/native/lib" + } + + // write properties + def propFile = file("${project.projectDir}/build/classes/java/main/native/lib/llama.properties") + propFile.text = "version=${llamacpp_version}-${version}\n" + + url = "https://mlrepo.djl.ai/model/nlp/text_generation/ai/djl/huggingface/gguf/models.json.gz" + def prefix = "${project.projectDir}/build/classes/java/main/nlp/text_generation" + def file = new File("${prefix}/ai.djl.huggingface.gguf.json") + if (file.exists()) { + project.logger.lifecycle("gguf index file already exists") + } else { + project.logger.lifecycle("Downloading gguf index file") + file.getParentFile().mkdirs() + def downloadPath = new URL(url) + downloadPath.withInputStream { i -> file.withOutputStream { it << new GZIPInputStream(i) } } + } + } +} + +publishing { + publications { + maven(MavenPublication) { + pom { + name = "DJL NLP utilities for Llama.cpp" + description = "Deep Java Library (DJL) NLP utilities for llama.cpp" + url = "http://www.djl.ai/engines/${project.name}" + } + } + } +} + +apply from: file("${rootProject.projectDir}/tools/gradle/cpp-formatter.gradle") + +tasks.register('compileJNI') { + doFirst { + def cp = configurations.runtimeClasspath.resolve().stream().map {f->f.toString()}.toList() + if (System.properties['os.name'].toLowerCase(Locale.ROOT).contains("mac") + || 
System.properties['os.name'].toLowerCase(Locale.ROOT).contains("linux")) { + def arch = System.properties["os.arch"] == "amd64" ? "x86_64" : System.properties["os.arch"] + exec { + commandLine "bash", "build.sh", llamacpp_version, arch, String.join(":", cp) + } + } else { + exec { + commandLine "${project.projectDir}/build.cmd", llamacpp_version, String.join(";", cp) + } + } + + // for ci to upload to S3 + def ciDir = "${project.projectDir}/jnilib/${djl_version}/" + copy { + from "${project.projectDir}/build/jnilib" + into ciDir + } + delete System.getProperty("user.home") + "/.djl.ai/llama" + } +} + +clean.doFirst { + delete System.getProperty("user.home") + "/.djl.ai/llama" +} diff --git a/engines/llama/build.sh b/engines/llama/build.sh new file mode 100755 index 00000000000..1b6e7d4e1fa --- /dev/null +++ b/engines/llama/build.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +set -e +WORK_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +NUM_PROC=1 +if [[ -n $(command -v nproc) ]]; then + NUM_PROC=$(nproc) +elif [[ -n $(command -v sysctl) ]]; then + NUM_PROC=$(sysctl -n hw.ncpu) +fi +PLATFORM=$(uname | tr '[:upper:]' '[:lower:]') + +VERSION=$1 +ARCH=$2 +CLASSPATH=$3 + +pushd $WORK_DIR +if [ ! -d "llama.cpp" ]; then + git clone https://github.com/ggerganov/llama.cpp.git -b $VERSION +fi + +if [ ! -d "build" ]; then + mkdir build +fi +cd build + +rm -rf classes +mkdir classes +javac -classpath $CLASSPATH -sourcepath ../src/main/java/:../../../api/src/main/java ../src/main/java/ai/djl/llama/jni/LlamaLibrary.java -h include -d classes +cmake .. +cmake --build . --config Release -- -j "${NUM_PROC}" + +popd + +# for nightly ci +if [[ $PLATFORM == 'darwin' ]]; then + mkdir -p build/jnilib/osx-$ARCH + cp -f build/libdjl_llama.dylib build/jnilib/osx-$ARCH/ + cp -f build/llama.cpp/libllama.dylib build/jnilib/osx-$ARCH/ + cp -f llama.cpp/ggml-metal.metal build/jnilib/osx-$ARCH/ +elif [[ $PLATFORM == 'linux' ]]; then + mkdir -p build/jnilib/linux-$ARCH + cp -f build/libdjl_llama.so build/jnilib/linux-$ARCH/ + cp -f build/llama.cpp/libllama.so build/jnilib/linux-$ARCH/ +fi diff --git a/engines/llama/gradlew b/engines/llama/gradlew new file mode 120000 index 00000000000..343e0d2caa4 --- /dev/null +++ b/engines/llama/gradlew @@ -0,0 +1 @@ +../../gradlew \ No newline at end of file diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java new file mode 100644 index 00000000000..75fdf5a5d8c --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java @@ -0,0 +1,110 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ + +package ai.djl.llama.engine; + +import ai.djl.Device; +import ai.djl.Model; +import ai.djl.engine.Engine; +import ai.djl.engine.EngineException; +import ai.djl.llama.jni.LibUtils; +import ai.djl.ndarray.NDManager; +import ai.djl.util.Platform; +import ai.djl.util.passthrough.PassthroughNDManager; + +/** The {@code LlamaEngine} is an implementation of the {@link Engine} based on the llama.cpp. */ +public final class LlamaEngine extends Engine { + + public static final String ENGINE_NAME = "Llama"; + static final int RANK = 10; + + private Engine alternativeEngine; + private boolean initialized; + + private LlamaEngine() { + try { + LibUtils.loadLibrary(); + } catch (EngineException e) { // NOPMD + throw e; + } catch (Throwable t) { + throw new EngineException("Failed to load llama.cpp native library", t); + } + } + + static Engine newInstance() { + return new LlamaEngine(); + } + + /** {@inheritDoc} */ + @Override + public Engine getAlternativeEngine() { + if (!initialized && !Boolean.getBoolean("ai.djl.llama.disable_alternative")) { + Engine engine = Engine.getInstance(); + if (engine.getRank() < getRank()) { + // alternativeEngine should not have the same rank as Llama + alternativeEngine = engine; + } + initialized = true; + } + return alternativeEngine; + } + + /** {@inheritDoc} */ + @Override + public String getEngineName() { + return ENGINE_NAME; + } + + /** {@inheritDoc} */ + @Override + public int getRank() { + return RANK; + } + + /** {@inheritDoc} */ + @Override + public String getVersion() { + Platform platform = Platform.detectPlatform("llama"); + return platform.getVersion(); + } + + /** {@inheritDoc} */ + @Override + public boolean hasCapability(String capability) { + return false; + } + + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device) { + return new LlamaModel(name, newBaseManager(device)); + } + + /** {@inheritDoc} */ + @Override + public NDManager newBaseManager() { + return newBaseManager(null); + } + + /** {@inheritDoc} */ + @Override + public NDManager newBaseManager(Device device) { + return PassthroughNDManager.INSTANCE; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return getEngineName() + ':' + getVersion() + ", " + getEngineName() + ':' + getVersion(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java new file mode 100644 index 00000000000..ca5cc646498 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.engine; + +import ai.djl.engine.Engine; +import ai.djl.engine.EngineProvider; + +/** {@code LlamaEngineProvider} is the Llama implementation of {@link EngineProvider}. 
*/ +public class LlamaEngineProvider implements EngineProvider { + + /** {@inheritDoc} */ + @Override + public String getEngineName() { + return LlamaEngine.ENGINE_NAME; + } + + /** {@inheritDoc} */ + @Override + public int getEngineRank() { + return LlamaEngine.RANK; + } + + /** {@inheritDoc} */ + @Override + public Engine getEngine() { + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = LlamaEngine.newInstance(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java new file mode 100644 index 00000000000..4b4d332fc9f --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java @@ -0,0 +1,430 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.engine; + +import ai.djl.llama.jni.InputParameters; + +import com.google.gson.annotations.SerializedName; + +import java.util.Map; + +/** A class hold input data for Llama model. */ +public class LlamaInput { + + private String inputs; + private String prefix; + private String suffix; + private Parameters parameters; + + /** + * Returns the input prompt. + * + * @return the input prompt + */ + public String getInputs() { + return inputs; + } + + /** + * Sets the input prompt. + * + * @param inputs the input prompt + */ + public void setInputs(String inputs) { + this.inputs = inputs; + } + + /** + * Returns the prompt prefix. + * + * @return the prompt prefix + */ + public String getPrefix() { + return prefix; + } + + /** + * Sets the prompt prefix. + * + * @param prefix the prompt prefix + */ + public void setPrefix(String prefix) { + this.prefix = prefix; + } + + /** + * Returns the prompt suffix. + * + * @return the prompt suffix + */ + public String getSuffix() { + return suffix; + } + + /** + * Sets the prompt suffix. + * + * @param suffix the prompt suffix + */ + public void setSuffix(String suffix) { + this.suffix = suffix; + } + + /** + * Returns the input parameters. + * + * @return the input parameters + */ + public Parameters getParameters() { + if (parameters == null) { + parameters = new Parameters(); + } + return parameters; + } + + /** + * Sets the input parameters. + * + * @param parameters the input parameters + */ + public void setParameters(Parameters parameters) { + this.parameters = parameters; + } + + /** The input parameters class. 
*/ + public static final class Parameters { + + @SerializedName("max_new_tokens") + private int nPredict; + + @SerializedName("number_keep") + private int nKeep; + + @SerializedName("number_probabilities") + private int nProbs; + + @SerializedName("top_k") + private int topK; + + @SerializedName("top_p") + private float topP; + + @SerializedName("tfs_z") + private float tfsZ; + + @SerializedName("typical_p") + private float typicalP; + + @SerializedName("temperature") + private float temperature; + + @SerializedName("repeat_penalty") + private float repeatPenalty; + + @SerializedName("repeat_last_n") + private int repeatLastN; + + @SerializedName("frequency_penalty") + private float frequencyPenalty; + + @SerializedName("presence_penalty") + private float presencePenalty; + + @SerializedName("penalize_nl") + private boolean penalizeNl; + + @SerializedName("ignore_eos") + private boolean ignoreEos; + + @SerializedName("mirostat") + private int mirostat; + + @SerializedName("mirostat_tau") + private float mirostatTau; + + @SerializedName("mirostat_eta") + private float mirostatEta; + + @SerializedName("number_beams") + private int nBeams; + + @SerializedName("seed") + private int seed; + + @SerializedName("logit_bias") + private Map logitBias; + + @SerializedName("grammar") + private String grammar; + + @SerializedName("anti_prompt") + private String[] antiPrompt; + + /** + * Sets the max new tokens. + * + * @param maxNewTokens the max new tokens + */ + public void setMaxNewTokens(int maxNewTokens) { + this.nPredict = maxNewTokens; + } + + /** + * Sets the number of keep. + * + * @param nKeep the number of keep + */ + public void setNumberKeep(int nKeep) { + this.nKeep = nKeep; + } + + /** + * Sets the number of probabilities. + * + * @param nProbs the number of probabilities + */ + public void setNumberProbabilities(int nProbs) { + this.nProbs = nProbs; + } + + /** + * Sets the top K. + * + * @param topK the top K + */ + public void setTopK(int topK) { + this.topK = topK; + } + + /** + * Sets the top P. + * + * @param topP the top P + */ + public void setTopP(float topP) { + this.topP = topP; + } + + /** + * Sets the tfs Z. + * + * @param tfsZ the tfs Z + */ + public void setTfsZ(float tfsZ) { + this.tfsZ = tfsZ; + } + + /** + * Sets the typical P. + * + * @param typicalP the typical P + */ + public void setTypicalP(float typicalP) { + this.typicalP = typicalP; + } + + /** + * Sets the temperature. + * + * @param temperature the temperature + */ + public void setTemperature(float temperature) { + this.temperature = temperature; + } + + /** + * Sets the repeat penalty. + * + * @param repeatPenalty the repeat penalty + */ + public void setRepeatPenalty(float repeatPenalty) { + this.repeatPenalty = repeatPenalty; + } + + /** + * Sets the repeat last N. + * + * @param repeatLastN the repeat last N + */ + public void setRepeatLastN(int repeatLastN) { + this.repeatLastN = repeatLastN; + } + + /** + * Sets the frequency penalty. + * + * @param frequencyPenalty the frequency penalty + */ + public void setFrequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + /** + * Sets the presence penalty. + * + * @param presencePenalty the presence penalty + */ + public void setPresencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + /** + * Sets the penalize nl. + * + * @param penalizeNl the penalize nl + */ + public void setPenalizeNl(boolean penalizeNl) { + this.penalizeNl = penalizeNl; + } + + /** + * Sets if ignore EOS. 
+ * + * @param ignoreEos if ignore EOS + */ + public void setIgnoreEos(boolean ignoreEos) { + this.ignoreEos = ignoreEos; + } + + /** + * Sets the mirostat. + * + * @param mirostat the mirostat + */ + public void setMirostat(int mirostat) { + this.mirostat = mirostat; + } + + /** + * Sets the mirostat TAU. + * + * @param mirostatTau the mirostat TAU + */ + public void setMirostatTau(float mirostatTau) { + this.mirostatTau = mirostatTau; + } + + /** + * Sets the mirostat ETA. + * + * @param mirostatEta the mirostat ETA + */ + public void setMirostatEta(float mirostatEta) { + this.mirostatEta = mirostatEta; + } + + /** + * Sets the number of beams. + * + * @param nBeams the number of beams + */ + public void setNumberBeams(int nBeams) { + this.nBeams = nBeams; + } + + /** + * Sets the seed. + * + * @param seed the seed + */ + public void setSeed(int seed) { + this.seed = seed; + } + + /** + * Sets the logit bias. + * + * @param logitBias the logit bias + */ + public void setLogitBias(Map logitBias) { + this.logitBias = logitBias; + } + + /** + * Sets the grammar template. + * + * @param grammar the grammar template + */ + public void setGrammar(String grammar) { + this.grammar = grammar; + } + + /** + * Sets the anti prompt. + * + * @param antiPrompt the anti prompt + */ + public void setAntiPrompt(String[] antiPrompt) { + this.antiPrompt = antiPrompt; + } + + /** + * Returns the {@link InputParameters} object. + * + * @return the {@link InputParameters} object + */ + public InputParameters toInputParameters() { + setDefaultValue(); + return new InputParameters( + nPredict, + nKeep, + nProbs, + topK, + topP, + tfsZ, + typicalP, + temperature, + repeatPenalty, + repeatLastN, + frequencyPenalty, + presencePenalty, + penalizeNl, + ignoreEos, + mirostat, + mirostatTau, + mirostatEta, + nBeams, + seed, + logitBias, + grammar, + antiPrompt); + } + + private void setDefaultValue() { + if (nPredict == 0) { + nPredict = -1; + } + if (topK == 0) { + topK = 40; + } + if (topP == 0) { + topP = 0.95f; + } + if (tfsZ == 0) { + tfsZ = 1f; + } + if (typicalP == 0) { + typicalP = 1f; + } + if (temperature == 0) { + temperature = 0.8f; + } + if (repeatPenalty == 0) { + repeatPenalty = 1.10f; + } + if (repeatLastN == 0) { + repeatLastN = 64; + } + } + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java new file mode 100644 index 00000000000..0ff3c6d70c0 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java @@ -0,0 +1,112 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.llama.engine; + +import ai.djl.BaseModel; +import ai.djl.Model; +import ai.djl.llama.jni.LlamaLibrary; +import ai.djl.llama.jni.ModelParameters; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.nn.Blocks; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + +/** {@code LlamaModel} is the llama.cpp implementation of {@link Model}. */ +public class LlamaModel extends BaseModel { + + private long handle = -1; + + /** + * Constructs a new Model on a given device. + * + * @param name the model name + * @param manager the {@link NDManager} to holds the NDArray + */ + LlamaModel(String name, NDManager manager) { + super(name); + this.manager = manager; + this.manager.setName("llamaModel"); + dataType = DataType.FLOAT32; + } + + /** {@inheritDoc} */ + @Override + public void load(Path modelPath, String prefix, Map options) throws IOException { + setModelDir(modelPath); + wasLoaded = true; + if (block != null) { + throw new UnsupportedOperationException("Llama does not support dynamic blocks"); + } + + if (prefix == null) { + prefix = modelName; + } + + // search for .onnx file with prefix, folder name or "model.onnx" + Path modelFile = findModelFile(prefix, modelDir.toFile().getName(), "model.gguf"); + if (modelFile == null) { + throw new FileNotFoundException(".gguf file not found in: " + modelPath); + } + + ModelParameters param = new ModelParameters(options); + handle = LlamaLibrary.loadModel(modelFile.toString(), param); + block = Blocks.identityBlock(); + } + + long getHandle() { + return handle; + } + + private Path findModelFile(String... prefixes) { + if (Files.isRegularFile(modelDir)) { + Path file = modelDir; + modelDir = modelDir.getParent(); + String fileName = file.toFile().getName(); + if (fileName.endsWith(".gguf")) { + modelName = fileName.substring(0, fileName.length() - 5); + } else { + modelName = fileName; + } + return file; + } + for (String prefix : prefixes) { + Path modelFile = modelDir.resolve(prefix); + if (Files.isRegularFile(modelFile)) { + return modelFile; + } + if (!prefix.endsWith(".gguf")) { + modelFile = modelDir.resolve(prefix + ".gguf"); + if (Files.isRegularFile(modelFile)) { + return modelFile; + } + } + } + return null; + } + + /** {@inheritDoc} */ + @Override + public void close() { + if (handle == -1) { + return; + } + LlamaLibrary.delete(handle); + handle = -1; + super.close(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java new file mode 100644 index 00000000000..c8d3692b160 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java @@ -0,0 +1,107 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */
+package ai.djl.llama.engine;
+
+import ai.djl.inference.streaming.IteratorBytesSupplier;
+import ai.djl.llama.jni.InputParameters;
+import ai.djl.llama.jni.LlamaLibrary;
+import ai.djl.llama.jni.Token;
+import ai.djl.llama.jni.TokenIterator;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.ndarray.BytesSupplier;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+import ai.djl.util.JsonUtils;
+
+import java.util.Iterator;
+
+/** Built-in {@code Translator} that provides preprocessing and postprocessing for llama.cpp. */
+public class LlamaTranslator<I, O> implements NoBatchifyTranslator<I, O> {
+
+    private long handle;
+
+    /** {@inheritDoc} */
+    @Override
+    public void prepare(TranslatorContext ctx) {
+        LlamaModel model = (LlamaModel) ctx.getModel();
+        handle = model.getHandle();
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public NDList processInput(TranslatorContext ctx, I input) {
+        if (input instanceof String) {
+            ctx.setAttachment("out", generate((String) input));
+        } else if (input instanceof LlamaInput) {
+            ctx.setAttachment("out", generate((LlamaInput) input));
+        } else if (input instanceof Input) {
+            String prompt = ((Input) input).getData().getAsString();
+            TokenIterator it = generate(prompt);
+            Output output = new Output();
+            output.add(new IteratorBytesSupplier(new OutputIterator(it)));
+            ctx.setAttachment("out", output);
+        }
+        return new NDList();
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    @SuppressWarnings("unchecked")
+    public O processOutput(TranslatorContext ctx, NDList list) {
+        return (O) ctx.getAttachment("out");
+    }
+
+    private TokenIterator generate(String input) {
+        LlamaInput in = JsonUtils.GSON.fromJson(input, LlamaInput.class);
+        return generate(in);
+    }
+
+    private TokenIterator generate(LlamaInput in) {
+        InputParameters param = in.getParameters().toInputParameters();
+        String prefix = in.getPrefix();
+        String suffix = in.getSuffix();
+        String inputs = in.getInputs();
+        if (prefix != null && suffix != null) {
+            LlamaLibrary.infill(handle, prefix, suffix, param);
+        } else if (inputs != null && !inputs.isEmpty()) {
+            LlamaLibrary.generate(handle, inputs, param);
+        } else {
+            throw new IllegalArgumentException("Unsupported input format");
+        }
+        return new TokenIterator(handle);
+    }
+
+    private static final class OutputIterator implements Iterator<BytesSupplier> {
+
+        private TokenIterator it;
+
+        public OutputIterator(TokenIterator it) {
+            this.it = it;
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public boolean hasNext() {
+            return it.hasNext();
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public BytesSupplier next() {
+            Token token = it.next();
+            return BytesSupplier.wrap(JsonUtils.GSON.toJson(token) + "\n");
+        }
+    }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java
new file mode 100644
index 00000000000..089b5055b51
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file.
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.engine; + +import ai.djl.Model; +import ai.djl.llama.jni.TokenIterator; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; +import ai.djl.util.Pair; + +import java.io.Serializable; +import java.lang.reflect.Type; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** A {@link TranslatorFactory} that creates a {@link LlamaTranslator} instance. */ +public class LlamaTranslatorFactory implements TranslatorFactory, Serializable { + + private static final long serialVersionUID = 1L; + + private static final Set> SUPPORTED_TYPES = new HashSet<>(); + + static { + SUPPORTED_TYPES.add(new Pair<>(String.class, TokenIterator.class)); + SUPPORTED_TYPES.add(new Pair<>(LlamaInput.class, TokenIterator.class)); + SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); + } + + /** {@inheritDoc} */ + @Override + public Set> getSupportedTypes() { + return SUPPORTED_TYPES; + } + + /** {@inheritDoc} */ + @Override + public boolean isSupported(Class input, Class output) { + return true; + } + + /** {@inheritDoc} */ + @Override + public Translator newInstance( + Class input, Class output, Model model, Map arguments) { + return new LlamaTranslator<>(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java b/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java new file mode 100644 index 00000000000..226e7a6ddb8 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains classes to interface with the underlying Llama Engine. */ +package ai.djl.llama.engine; diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java b/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java new file mode 100644 index 00000000000..d13abc5ef90 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java @@ -0,0 +1,314 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.jni; + +import java.util.Map; + +/** A class holds input parameters. 
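+ * <p>Instances are normally obtained from {@code LlamaInput.Parameters#toInputParameters()},
+ * which applies defaults for unset fields (for example temperature 0.8, top-p 0.95 and repeat
+ * penalty 1.10) before the parameters are handed to the native layer.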
*/ +@SuppressWarnings({"PMD.UnusedPrivateField", "PMD.UnusedAssignment"}) +public class InputParameters { + + private int nPredict; + private int nKeep; + private int nProbs; + private int topK; + private float topP; + private float tfsZ; + private float typicalP; + private float temperature; + private float repeatPenalty; + private int repeatLastN; + private float frequencyPenalty; + private float presencePenalty; + private boolean penalizeNl; + private boolean ignoreEos; + private int mirostat; + private float mirostatTau; + private float mirostatEta; + private int nBeams; + private int seed; + private Map logitBias; + private String grammar; + private String[] antiPrompt; + + /** + * Constructs new {@code InputParameters} instance. + * + * @param nPredict the max new tokens + * @param nKeep the number of keep + * @param nProbs the number of probabilities + * @param topK the top K + * @param topP the top P + * @param tfsZ the tfs Z + * @param typicalP the typical P + * @param temperature the temperature + * @param repeatPenalty the repeat penalty + * @param repeatLastN the repeat last N + * @param frequencyPenalty the frequency penalty + * @param presencePenalty the presence penalty + * @param penalizeNl the penalize nl + * @param ignoreEos the ignore EOS + * @param mirostat the mirostat + * @param mirostatTau the mirostat TAU + * @param mirostatEta the mirostat ETA + * @param nBeams the number of beams + * @param seed the seed + * @param logitBias the logit bias + * @param grammar the grammar + * @param antiPrompt the anti prompt + */ + public InputParameters( + int nPredict, + int nKeep, + int nProbs, + int topK, + float topP, + float tfsZ, + float typicalP, + float temperature, + float repeatPenalty, + int repeatLastN, + float frequencyPenalty, + float presencePenalty, + boolean penalizeNl, + boolean ignoreEos, + int mirostat, + float mirostatTau, + float mirostatEta, + int nBeams, + int seed, + Map logitBias, + String grammar, + String[] antiPrompt) { + this.nPredict = nPredict; + this.nKeep = nKeep; + this.nProbs = nProbs; + this.topK = topK; + this.topP = topP; + this.tfsZ = tfsZ; + this.typicalP = typicalP; + this.temperature = temperature; + this.repeatPenalty = repeatPenalty; + this.repeatLastN = repeatLastN; + this.frequencyPenalty = frequencyPenalty; + this.presencePenalty = presencePenalty; + this.penalizeNl = penalizeNl; + this.ignoreEos = ignoreEos; + this.mirostat = mirostat; + this.mirostatTau = mirostatTau; + this.mirostatEta = mirostatEta; + this.nBeams = nBeams; + this.seed = seed; + this.logitBias = logitBias; + this.grammar = grammar; + this.antiPrompt = antiPrompt; + } + + /** + * Returns the max new tokens. + * + * @return the max new tokens + */ + public int getMaxNewTokens() { + return nPredict; + } + + /** + * Returns the number of keep. + * + * @return the number of keep + */ + public int getNumberKeep() { + return nKeep; + } + + /** + * Returns the number of probabilities. + * + * @return the number of probabilities + */ + public int getNumberProbabilities() { + return nProbs; + } + + /** + * Returns the top K. + * + * @return the top K + */ + public int getTopK() { + return topK; + } + + /** + * Return the top P. + * + * @return the top P + */ + public float getTopP() { + return topP; + } + + /** + * Return the TfsZ. + * + * @return the TfsZ + */ + public float getTfsZ() { + return tfsZ; + } + + /** + * Return the typical P. + * + * @return the typical P + */ + public float getTypicalP() { + return typicalP; + } + + /** + * Return the temperature. 
+ * + * @return the temperature + */ + public float getTemperature() { + return temperature; + } + + /** + * Return the repeat penalty. + * + * @return the repeat penalty + */ + public float getRepeatPenalty() { + return repeatPenalty; + } + + /** + * Return the repeat last N. + * + * @return the repeat last N + */ + public int getRepeatLastN() { + return repeatLastN; + } + + /** + * Return the frequency penalty. + * + * @return the frequency penalty + */ + public float getFrequencyPenalty() { + return frequencyPenalty; + } + + /** + * Return the presence penalty. + * + * @return the presence penalty + */ + public float getPresencePenalty() { + return presencePenalty; + } + + /** + * Return the penalize NL. + * + * @return the penalize NL + */ + public boolean isPenalizeNl() { + return penalizeNl; + } + + /** + * Returns {@code true} if ignore EOS. + * + * @return {@code true} if ignore EOS + */ + public boolean isIgnoreEos() { + return ignoreEos; + } + + /** + * Returns the mirostat. + * + * @return the mirostat + */ + public int getMirostat() { + return mirostat; + } + + /** + * Returns the mirostat TAU. + * + * @return the mirostat TAU + */ + public float getMirostatTau() { + return mirostatTau; + } + + /** + * Returns the mirostat ETA. + * + * @return the mirostat ETA + */ + public float getMirostatEta() { + return mirostatEta; + } + + /** + * Returns the number of beams. + * + * @return the number of beams + */ + public int getNumberBeams() { + return nBeams; + } + + /** + * Returns the seed. + * + * @return the seed + */ + public int getSeed() { + return seed; + } + + /** + * Returns the logit bias. + * + * @return the logit bias + */ + public Map getLogitBias() { + return logitBias; + } + + /** + * Returns the grammar template. + * + * @return the grammar template + */ + public String getGrammar() { + return grammar; + } + + /** + * Returns the anti-prompt. + * + * @return the anti-prompt + */ + public String[] getAntiPrompt() { + return antiPrompt; + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java b/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java new file mode 100644 index 00000000000..d51a4fe2e5e --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java @@ -0,0 +1,99 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.jni; + +import ai.djl.util.ClassLoaderUtils; +import ai.djl.util.Platform; +import ai.djl.util.Utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.ArrayList; +import java.util.List; + +/** Utilities for finding the llama.cpp native binary on the System. 
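+ * <p>{@link #loadLibrary()} extracts the bundled {@code llama} and {@code djl_llama} native
+ * libraries (plus {@code ggml-metal.metal} on macOS) from the classpath into the engine cache
+ * directory and loads them via {@code System.load}, or through the helper class named by the
+ * {@code ai.djl.llama.native_helper} system property when that property is set.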
*/ +public final class LibUtils { + + private static final Logger logger = LoggerFactory.getLogger(LibUtils.class); + + private static final String LIB_NAME = System.mapLibraryName("djl_llama"); + private static final String LLAMA_NAME = System.mapLibraryName("llama"); + + private LibUtils() {} + + /** Loads llama.cpp native library. */ + public static void loadLibrary() { + List libs = new ArrayList<>(3); + libs.add(LLAMA_NAME); + libs.add(LIB_NAME); + if (System.getProperty("os.name").startsWith("Mac")) { + libs.add("ggml-metal.metal"); + } + Path dir = copyJniLibraryFromClasspath(libs.toArray(new String[0])); + logger.debug("Loading llama.cpp library from: {}", dir); + + for (int i = 0; i < 2; ++i) { + String lib = libs.get(i); + String path = dir.resolve(lib).toString(); + logger.debug("Loading native library: {}", path); + String nativeHelper = System.getProperty("ai.djl.llama.native_helper"); + if (nativeHelper != null && !nativeHelper.isEmpty()) { + ClassLoaderUtils.nativeLoad(nativeHelper, path); + } else { + System.load(path); // NOPMD + } + } + } + + private static Path copyJniLibraryFromClasspath(String... libs) { + Path cacheDir = Utils.getEngineCacheDir("llama"); + Platform platform = Platform.detectPlatform("llama"); + String classifier = platform.getClassifier(); + String version = platform.getVersion(); + Path dir = cacheDir.resolve(version + '-' + classifier); + Path path = dir.resolve(LIB_NAME); + logger.debug("Using cache dir: {}", dir); + if (Files.exists(path)) { + return dir.toAbsolutePath(); + } + + Path tmp = null; + try { + Files.createDirectories(cacheDir); + tmp = Files.createTempDirectory(cacheDir, "tmp"); + + for (String libName : libs) { + String libPath = "native/lib/" + classifier + "/" + libName; + logger.info("Extracting {} to cache ...", libPath); + try (InputStream is = ClassLoaderUtils.getResourceAsStream(libPath)) { + Path target = tmp.resolve(libName); + Files.copy(is, target, StandardCopyOption.REPLACE_EXISTING); + } + } + Utils.moveQuietly(tmp, dir); + return dir.toAbsolutePath(); + } catch (IOException e) { + throw new IllegalStateException("Cannot copy jni files", e); + } finally { + if (tmp != null) { + Utils.deleteQuietly(tmp); + } + } + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java b/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java new file mode 100644 index 00000000000..5d40fa29830 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.jni; + +/** Native library for llama.cpp. 
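+ * <p>A sketch of the generation round trip through these natives (the model path, {@code options}
+ * and {@code param} are placeholders; in practice tokens are consumed through
+ * {@code ai.djl.llama.jni.TokenIterator} rather than by calling {@code getNext} directly):
+ *
+ * <pre>
+ * long handle = LlamaLibrary.loadModel("/path/to/model.gguf", new ModelParameters(options));
+ * LlamaLibrary.generate(handle, "Hello", param);
+ * Token token = LlamaLibrary.getNext(handle, 0, 0);
+ * LlamaLibrary.delete(handle);
+ * </pre>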
*/ +@SuppressWarnings("MissingJavadocMethod") +public final class LlamaLibrary { + + private LlamaLibrary() {} + + public static native long loadModel(String filePath, ModelParameters param); + + public static native void generate(long handle, String prompt, InputParameters param); + + public static native void infill( + long handle, String prefix, String suffix, InputParameters param); + + public static native Token getNext(long handle, long count, long pos); + + public static native float[] embed(long handle, String prompt); + + public static native int[] encode(long handle, String prompt); + + public static native byte[] decodeBytes(long handle, int[] tokens); + + public static native void delete(long handle); +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java b/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java new file mode 100644 index 00000000000..e3e440474a8 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java @@ -0,0 +1,114 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.jni; + +import java.util.Map; + +/** A class holds llama.cpp model loading parameters. */ +@SuppressWarnings("PMD.SingularField") +public final class ModelParameters { + + private int nThreads; + private int nCtx; + private int nBatch; + private int nGpuLayers; + private int mainGpu; + private float ropeFreqBase; + private float ropeFreqScale; + private boolean mulMatQ; + private boolean f16Kv; + private boolean logitsAll; + private boolean vocabOnly; + private boolean useMmap; + private boolean useMlock; + private boolean embedding; + private boolean memoryF16; + private boolean memTest; + private boolean numa; + private boolean verbosePrompt; + private float[] tensorSplit; + private String loraAdapter; + private String loraBase; + + /** + * Constructs a new {@code ModelParameters} instance. 
+ * + * @param options the model loading options + */ + public ModelParameters(Map options) { + nThreads = intValue(options, "number_threads", Runtime.getRuntime().availableProcessors()); + nCtx = intValue(options, "max_context_length", 512); + nBatch = intValue(options, "max_rolling_batch", 512); + nGpuLayers = intValue(options, "number_gpu_layers", -1); + mainGpu = intValue(options, "tensor_parallel_degree", 0); + ropeFreqBase = floatValue(options, "rope_freq_base"); + ropeFreqScale = floatValue(options, "ropeFreqScale"); + f16Kv = booleanValue(options, "f16_kv"); + mulMatQ = booleanValue(options, "mulmat_q", true); + logitsAll = booleanValue(options, "logits_all"); + vocabOnly = booleanValue(options, "vocab_only"); + useMmap = booleanValue(options, "use_mmap", true); + useMlock = booleanValue(options, "use_mlock"); + embedding = booleanValue(options, "embedding"); + memoryF16 = booleanValue(options, "memory_f16", true); + memTest = booleanValue(options, "mem_test"); + numa = booleanValue(options, "numa"); + verbosePrompt = booleanValue(options, "verbose_prompt"); + String val = stringValue(options, "tensor_split"); + if (val != null && !val.isEmpty()) { + String[] tokens = val.split(","); + tensorSplit = new float[tokens.length]; + for (int i = 0; i < tokens.length; ++i) { + tensorSplit[i] = Float.parseFloat(tokens[i].trim()); + } + } + loraAdapter = stringValue(options, "lora_adapter"); + loraBase = stringValue(options, "loraBase"); + } + + private static int intValue(Map arguments, String key, int def) { + Object value = arguments.get(key); + if (value == null) { + return def; + } + return (int) Double.parseDouble(value.toString()); + } + + private static float floatValue(Map arguments, String key) { + Object value = arguments.get(key); + if (value == null) { + return 0f; + } + return (float) Double.parseDouble(value.toString()); + } + + private static boolean booleanValue(Map arguments, String key) { + return booleanValue(arguments, key, false); + } + + private static boolean booleanValue(Map arguments, String key, boolean def) { + Object value = arguments.get(key); + if (value == null) { + return def; + } + return Boolean.parseBoolean(value.toString()); + } + + private static String stringValue(Map arguments, String key) { + Object value = arguments.get(key); + if (value == null) { + return null; + } + return value.toString(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/Token.java b/engines/llama/src/main/java/ai/djl/llama/jni/Token.java new file mode 100644 index 00000000000..b8d74306b56 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/Token.java @@ -0,0 +1,87 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.jni; + +import ai.djl.util.JsonUtils; + +import java.nio.charset.StandardCharsets; +import java.util.Map; + +/** The output token class. 
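+ * <p>A token carries the sampled token id, its decoded UTF-8 text and optional per-token
+ * probabilities; {@link #toString()} renders it as a single line of JSON, matching the
+ * one-JSON-object-per-line form used by the streaming output path.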
*/ +public final class Token { + + private int token; + private String text; + private Map probabilities; + transient long count; + transient long pos; + transient boolean hasNext; + + /** + * Constructs a new {@code Token} instance. + * + * @param token the token id + * @param generated the token text + * @param probabilities the token probabilities + * @param count the generated token count + * @param pos the token index + * @param hasNext has more tokens + */ + public Token( + int token, + byte[] generated, + Map probabilities, + long count, + long pos, + boolean hasNext) { + this.token = token; + this.text = new String(generated, StandardCharsets.UTF_8); + this.probabilities = probabilities; + this.count = count; + this.pos = pos; + this.hasNext = hasNext; + } + + /** + * Returns the token id. + * + * @return the token id + */ + public int getToken() { + return token; + } + + /** + * Returns the token text. + * + * @return the token text + */ + public String getText() { + return text; + } + + /** + * Returns the token probabilities. + * + * @return the token probabilities + */ + public Map getProbabilities() { + return probabilities; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return JsonUtils.GSON.toJson(this) + '\n'; + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java b/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java new file mode 100644 index 00000000000..cab6575d8f7 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.jni; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicBoolean; + +/** A iterator class holds generated tokens. */ +public class TokenIterator implements Iterator { + + private static final Logger logger = LoggerFactory.getLogger(TokenIterator.class); + + private static AtomicBoolean active = new AtomicBoolean(); + + private long handle; + private long count; + private long pos; + private boolean hasNext; + + /** + * Constructs a new {@code TokenIterator} instance. 
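+     * <p>Only one iterator is expected to be active at a time (the active flag is static);
+     * starting a new one while another is still running logs a warning and resets the previous
+     * inference. A typical consumption loop, given an iterator {@code it}:
+     *
+     * <pre>
+     * while (it.hasNext()) {
+     *     Token token = it.next();
+     *     System.out.print(token.getText());
+     * }
+     * </pre>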
+ * + * @param handle the llama.cpp handle + */ + public TokenIterator(long handle) { + this.handle = handle; + hasNext = true; + if (!active.compareAndSet(false, true)) { + active.set(true); + logger.warn("Previous inference has been reset"); + } + } + + /** {@inheritDoc} */ + @Override + public boolean hasNext() { + return hasNext; + } + + /** {@inheritDoc} */ + @Override + public Token next() { + if (!hasNext) { + throw new NoSuchElementException(); + } + Token token = LlamaLibrary.getNext(handle, count, pos); + count = token.count; + pos = token.pos; + hasNext = token.hasNext; + if (!hasNext) { + active.set(false); + } + return token; + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java b/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java new file mode 100644 index 00000000000..6f429aceda2 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains classes to interface with the native llama.cpp code. */ +package ai.djl.llama.jni; diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java new file mode 100644 index 00000000000..91b6e55050a --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java @@ -0,0 +1,176 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.zoo; + +import ai.djl.Application; +import ai.djl.repository.Repository; +import ai.djl.repository.zoo.ModelLoader; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.util.ClassLoaderUtils; +import ai.djl.util.JsonUtils; +import ai.djl.util.Utils; + +import com.google.gson.reflect.TypeToken; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.Reader; +import java.io.Writer; +import java.lang.reflect.Type; +import java.net.URI; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.zip.GZIPInputStream; + +/** LlamaModelZoo is a repository that contains llama.cpp models. 
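+ * <p>Models are exposed under the group id {@code ai.djl.huggingface.gguf}. The model index is
+ * downloaded from {@code https://mlrepo.djl.ai/} on first use, cached on disk and refreshed at
+ * most once a day; in offline mode, or when the download fails, the cached or bundled index is
+ * used instead.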
*/ +public class LlamaModelZoo extends ModelZoo { + + private static final Logger logger = LoggerFactory.getLogger(LlamaModelZoo.class); + + private static final String REPO = "https://mlrepo.djl.ai/"; + private static final Repository REPOSITORY = Repository.newInstance("gguf", REPO); + private static final String GROUP_ID = "ai.djl.huggingface.gguf"; + + private static final long ONE_DAY = Duration.ofDays(1).toMillis(); + + private volatile boolean initialized; // NOPMD + + LlamaModelZoo() {} + + /** {@inheritDoc} */ + @Override + public String getGroupId() { + return GROUP_ID; + } + + /** {@inheritDoc} */ + @Override + public Set getSupportedEngines() { + return Collections.singleton("Llama"); + } + + /** {@inheritDoc} */ + @Override + public Collection getModelLoaders() { + init(); + return super.getModelLoaders(); + } + + /** {@inheritDoc} */ + @Override + public ModelLoader getModelLoader(String name) { + init(); + return super.getModelLoader(name); + } + + private void init() { + if (!initialized) { + synchronized (LlamaModelZoo.class) { + if (!initialized) { + Application app = Application.NLP.TEXT_GENERATION; + Map map = listModels(app); + for (Map.Entry entry : map.entrySet()) { + String artifactId = entry.getKey(); + Map gguf = entry.getValue().getGguf(); + if (gguf != null) { + for (String key : gguf.keySet()) { + addModel(REPOSITORY.model(app, GROUP_ID, artifactId, "0.0.1", key)); + } + } + } + initialized = true; + } + } + } + } + + private Map listModels(Application app) { + try { + String path = "model/" + app.getPath() + "/ai/djl/huggingface/gguf/"; + Path dir = Utils.getCacheDir().resolve("cache/repo/" + path); + if (Files.notExists(dir)) { + Files.createDirectories(dir); + } else if (!Files.isDirectory(dir)) { + logger.warn("Failed initialize cache directory: {}", dir); + return Collections.emptyMap(); + } + Type type = new TypeToken>() {}.getType(); + + Path file = dir.resolve("models.json"); + if (Files.exists(file)) { + long lastModified = Files.getLastModifiedTime(file).toMillis(); + if (Utils.isOfflineMode() || System.currentTimeMillis() - lastModified < ONE_DAY) { + try (Reader reader = Files.newBufferedReader(file)) { + return JsonUtils.GSON.fromJson(reader, type); + } + } + } + + URL url = URI.create(REPO).resolve(path + "models.json.gz").toURL(); + Path tmp = Files.createTempFile(dir, "models", ".tmp"); + try (GZIPInputStream gis = new GZIPInputStream(Utils.openUrl(url))) { + String json = Utils.toString(gis); + try (Writer writer = Files.newBufferedWriter(tmp)) { + writer.write(json); + } + Utils.moveQuietly(tmp, file); + return JsonUtils.GSON.fromJson(json, type); + } catch (IOException e) { + logger.warn("Failed to download Huggingface gguf index: {}", app); + if (Files.exists(file)) { + try (Reader reader = Files.newBufferedReader(file)) { + return JsonUtils.GSON.fromJson(reader, type); + } + } + + String resource = app.getPath() + "/" + GROUP_ID + ".json"; + try (InputStream is = ClassLoaderUtils.getResourceAsStream(resource)) { + String json = Utils.toString(is); + try (Writer writer = Files.newBufferedWriter(tmp)) { + writer.write(json); + } + Utils.moveQuietly(tmp, file); + return JsonUtils.GSON.fromJson(json, type); + } + } finally { + Utils.deleteQuietly(tmp); + } + } catch (IOException e) { + logger.warn("Failed load gguf index file", e); + } + + return Collections.emptyMap(); + } + + private static final class ModelDetail { + + private Map gguf; + + public Map getGguf() { + return gguf; + } + + public void setGguf(Map gguf) { + this.gguf = gguf; + } 
+ } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java new file mode 100644 index 00000000000..ba2b04722c1 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java @@ -0,0 +1,29 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.zoo; + +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.repository.zoo.ZooProvider; + +/** + * An Huggingface llama.cpp model zoo provider implements the {@link + * ai.djl.repository.zoo.ZooProvider} interface. + */ +public class LlamaZooProvider implements ZooProvider { + + /** {@inheritDoc} */ + @Override + public ModelZoo getModelZoo() { + return new LlamaModelZoo(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java b/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java new file mode 100644 index 00000000000..a9c1df64cd0 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains the built-in {@link ai.djl.llama.zoo.LlamaModelZoo}. */ +package ai.djl.llama.zoo; diff --git a/engines/llama/src/main/javadoc/overview.html b/engines/llama/src/main/javadoc/overview.html new file mode 100644 index 00000000000..05dec7d0bd4 --- /dev/null +++ b/engines/llama/src/main/javadoc/overview.html @@ -0,0 +1,14 @@ + + + + + +

This document is the API specification for the Deep Java Library (DJL) Llama Engine.

+ +
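Taken together, the engine provider, translators and JNI bindings added above plug into the standard DJL predictor API. The sketch below is illustrative only and not part of this patch: the model path and prompt are placeholders, exceptions are simply propagated, and a local GGUF model file is assumed. It uses the String-to-TokenIterator pairing registered by LlamaTranslatorFactory to stream generated tokens:

```java
import ai.djl.inference.Predictor;
import ai.djl.llama.engine.LlamaTranslatorFactory;
import ai.djl.llama.jni.TokenIterator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

import java.nio.file.Paths;

public final class LlamaExample {

    public static void main(String[] args) throws Exception {
        Criteria<String, TokenIterator> criteria =
                Criteria.builder()
                        .setTypes(String.class, TokenIterator.class)
                        .optEngine("Llama") // engine name reported by the Llama engine
                        .optModelPath(Paths.get("/path/to/model.gguf")) // placeholder path
                        .optTranslatorFactory(new LlamaTranslatorFactory())
                        .build();

        try (ZooModel<String, TokenIterator> model = criteria.loadModel();
                Predictor<String, TokenIterator> predictor = model.newPredictor()) {
            // A raw String input is parsed as LlamaInput JSON: "inputs" plus optional "parameters".
            String prompt =
                    "{\"inputs\": \"What is Deep Java Library?\","
                            + " \"parameters\": {\"max_new_tokens\": 64, \"temperature\": 0.7}}";
            TokenIterator it = predictor.predict(prompt);
            while (it.hasNext()) {
                System.out.print(it.next().getText());
            }
        }
    }
}
```

Passing a LlamaInput object directly, or a server-side Input/Output pair, follows the same path, since those type pairs are also registered by the factory; infill requests set prefix and suffix instead of inputs.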
+ + + diff --git a/engines/llama/src/main/native/ai_djl_llama.cpp b/engines/llama/src/main/native/ai_djl_llama.cpp new file mode 100644 index 00000000000..1d6072751f2 --- /dev/null +++ b/engines/llama/src/main/native/ai_djl_llama.cpp @@ -0,0 +1,1025 @@ +#include +#include +#include +#include + +#include "ai_djl_llama_jni_LlamaLibrary.h" +#include "common.h" +#include "grammar-parser.h" +#include "llama.h" +#include "sampling.h" + +// classes +static jclass c_lib_utils = 0; +static jclass c_model_params = 0; +static jclass c_input_params = 0; +static jclass c_token = 0; +static jclass c_standard_charsets = 0; +static jclass c_string = 0; +static jclass c_hash_map = 0; +static jclass c_map = 0; +static jclass c_set = 0; +static jclass c_entry = 0; +static jclass c_integer = 0; +static jclass c_float = 0; +static jclass c_logger = 0; +static jclass c_engine_exception = 0; + +// constructors +static jmethodID cc_token = 0; +static jmethodID cc_hash_map = 0; +static jmethodID cc_integer = 0; +static jmethodID cc_float = 0; + +// methods +static jmethodID m_get_bytes = 0; +static jmethodID m_entry_set = 0; +static jmethodID m_set_iterator = 0; +static jmethodID m_iterator_has_next = 0; +static jmethodID m_iterator_next = 0; +static jmethodID m_entry_key = 0; +static jmethodID m_entry_value = 0; +static jmethodID m_map_put = 0; +static jmethodID m_int_value = 0; +static jmethodID m_float_value = 0; +static jmethodID m_log_debug = 0; +static jmethodID m_log_info = 0; +static jmethodID m_log_warn = 0; +static jmethodID m_log_error = 0; + +// fields +static jfieldID f_logger = 0; +// inference parameters +static jfieldID f_n_predict = 0; +static jfieldID f_n_keep = 0; +static jfieldID f_n_probs = 0; +static jfieldID f_logit_bias = 0; +static jfieldID f_top_k = 0; +static jfieldID f_top_p = 0; +static jfieldID f_tfs_z = 0; +static jfieldID f_typical_p = 0; +static jfieldID f_temperature = 0; +static jfieldID f_repeat_penalty = 0; +static jfieldID f_repeat_last_n = 0; +static jfieldID f_frequency_penalty = 0; +static jfieldID f_presence_penalty = 0; +static jfieldID f_penalize_nl = 0; +static jfieldID f_ignore_eos = 0; +static jfieldID f_mirostat = 0; +static jfieldID f_mirostat_tau = 0; +static jfieldID f_mirostat_eta = 0; +static jfieldID f_n_beams = 0; +static jfieldID f_grammar = 0; +static jfieldID f_antiprompt = 0; +static jfieldID f_infer_seed = 0; +// model parameters +static jfieldID f_n_threads = 0; +static jfieldID f_n_ctx = 0; +static jfieldID f_n_batch = 0; +static jfieldID f_n_gpu_layers = 0; +static jfieldID f_main_gpu = 0; +static jfieldID f_tensor_split = 0; +static jfieldID f_rope_freq_base = 0; +static jfieldID f_rope_freq_scale = 0; +static jfieldID f_mul_mat_q = 0; +static jfieldID f_f16_kv = 0; +static jfieldID f_logits_all = 0; +static jfieldID f_vocab_only = 0; +static jfieldID f_use_mmap = 0; +static jfieldID f_use_mlock = 0; +static jfieldID f_embedding = 0; +static jfieldID f_lora_adapter = 0; +static jfieldID f_lora_base = 0; +static jfieldID f_memory_f16 = 0; +static jfieldID f_mem_test = 0; +static jfieldID f_numa = 0; +static jfieldID f_verbose_prompt = 0; +// log level +static jfieldID f_utf_8 = 0; +// objects +static jobject o_utf_8 = 0; +static jobject o_logger = 0; + +static JavaVM *g_vm = nullptr; + +static void null_log_callback(enum ggml_log_level level, const char *text, void *user_data) {} + +JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { + JNIEnv *env = 0; + + if (JNI_OK != vm->GetEnv((void **) &env, JNI_VERSION_1_1)) { + return JNI_ERR; + } + 
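+  // The remainder of JNI_OnLoad disables llama.cpp's default logging, installs the no-op log
+  // callback defined above, and resolves and caches global JNI references (classes, constructors,
+  // method and field IDs) up front so the token-generation hot path avoids repeated reflective
+  // lookups.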
+ log_disable(); + llama_log_set(null_log_callback, nullptr); + + // find classes + c_input_params = env->FindClass("ai/djl/llama/jni/InputParameters"); + c_model_params = env->FindClass("ai/djl/llama/jni/ModelParameters"); + c_lib_utils = env->FindClass("ai/djl/llama/jni/LibUtils"); + c_token = env->FindClass("ai/djl/llama/jni/Token"); + c_engine_exception = env->FindClass("ai/djl/engine/EngineException"); + c_logger = env->FindClass("org/slf4j/Logger"); + c_standard_charsets = env->FindClass("java/nio/charset/StandardCharsets"); + c_string = env->FindClass("java/lang/String"); + c_hash_map = env->FindClass("java/util/HashMap"); + c_map = env->FindClass("java/util/Map"); + c_set = env->FindClass("java/util/Set"); + c_entry = env->FindClass("java/util/Map$Entry"); + c_integer = env->FindClass("java/lang/Integer"); + c_float = env->FindClass("java/lang/Float"); + + // create references + c_input_params = (jclass) env->NewGlobalRef(c_input_params); + c_model_params = (jclass) env->NewGlobalRef(c_model_params); + c_lib_utils = (jclass) env->NewGlobalRef(c_lib_utils); + c_token = (jclass) env->NewGlobalRef(c_token); + c_engine_exception = (jclass) env->NewGlobalRef(c_engine_exception); + c_logger = (jclass) env->NewGlobalRef(c_logger); + c_string = (jclass) env->NewGlobalRef(c_string); + c_hash_map = (jclass) env->NewGlobalRef(c_hash_map); + c_map = (jclass) env->NewGlobalRef(c_map); + c_set = (jclass) env->NewGlobalRef(c_set); + c_entry = (jclass) env->NewGlobalRef(c_entry); + c_integer = (jclass) env->NewGlobalRef(c_integer); + c_float = (jclass) env->NewGlobalRef(c_float); + + // find constructors + cc_token = env->GetMethodID(c_token, "", "(I[BLjava/util/Map;JJZ)V"); + cc_hash_map = env->GetMethodID(c_hash_map, "", "()V"); + cc_integer = env->GetMethodID(c_integer, "", "(I)V"); + cc_float = env->GetMethodID(c_float, "", "(F)V"); + + // find methods + m_get_bytes = env->GetMethodID(c_string, "getBytes", "(Ljava/lang/String;)[B"); + m_entry_set = env->GetMethodID(c_map, "entrySet", "()Ljava/util/Set;"); + m_entry_key = env->GetMethodID(c_entry, "getKey", "()Ljava/lang/Object;"); + m_entry_value = env->GetMethodID(c_entry, "getValue", "()Ljava/lang/Object;"); + m_map_put = env->GetMethodID(c_map, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"); + m_int_value = env->GetMethodID(c_integer, "intValue", "()I"); + m_float_value = env->GetMethodID(c_float, "floatValue", "()F"); + m_log_debug = env->GetMethodID(c_logger, "debug", "(Ljava/lang/String;)V"); + m_log_info = env->GetMethodID(c_logger, "info", "(Ljava/lang/String;)V"); + m_log_warn = env->GetMethodID(c_logger, "warn", "(Ljava/lang/String;)V"); + m_log_error = env->GetMethodID(c_logger, "error", "(Ljava/lang/String;)V"); + + // find fields + f_logger = env->GetStaticFieldID(c_lib_utils, "logger", "Lorg/slf4j/Logger;"); + + f_n_predict = env->GetFieldID(c_input_params, "nPredict", "I"); + f_n_keep = env->GetFieldID(c_input_params, "nKeep", "I"); + f_n_probs = env->GetFieldID(c_input_params, "nProbs", "I"); + f_logit_bias = env->GetFieldID(c_input_params, "logitBias", "Ljava/util/Map;"); + f_top_k = env->GetFieldID(c_input_params, "topK", "I"); + f_top_p = env->GetFieldID(c_input_params, "topP", "F"); + f_tfs_z = env->GetFieldID(c_input_params, "tfsZ", "F"); + f_typical_p = env->GetFieldID(c_input_params, "typicalP", "F"); + f_temperature = env->GetFieldID(c_input_params, "temperature", "F"); + f_repeat_penalty = env->GetFieldID(c_input_params, "repeatPenalty", "F"); + f_repeat_last_n = env->GetFieldID(c_input_params, 
"repeatLastN", "I"); + f_frequency_penalty = env->GetFieldID(c_input_params, "frequencyPenalty", "F"); + f_presence_penalty = env->GetFieldID(c_input_params, "presencePenalty", "F"); + f_penalize_nl = env->GetFieldID(c_input_params, "penalizeNl", "Z"); + f_ignore_eos = env->GetFieldID(c_input_params, "ignoreEos", "Z"); + f_mirostat = env->GetFieldID(c_input_params, "mirostat", "I"); + f_mirostat_tau = env->GetFieldID(c_input_params, "mirostatTau", "F"); + f_mirostat_eta = env->GetFieldID(c_input_params, "mirostatEta", "F"); + f_n_beams = env->GetFieldID(c_input_params, "nBeams", "I"); + f_grammar = env->GetFieldID(c_input_params, "grammar", "Ljava/lang/String;"); + f_antiprompt = env->GetFieldID(c_input_params, "antiPrompt", "[Ljava/lang/String;"); + f_infer_seed = env->GetFieldID(c_input_params, "seed", "I"); + + f_n_threads = env->GetFieldID(c_model_params, "nThreads", "I"); + f_n_ctx = env->GetFieldID(c_model_params, "nCtx", "I"); + f_n_batch = env->GetFieldID(c_model_params, "nBatch", "I"); + f_n_gpu_layers = env->GetFieldID(c_model_params, "nGpuLayers", "I"); + f_main_gpu = env->GetFieldID(c_model_params, "mainGpu", "I"); + f_tensor_split = env->GetFieldID(c_model_params, "tensorSplit", "[F"); + f_rope_freq_base = env->GetFieldID(c_model_params, "ropeFreqBase", "F"); + f_rope_freq_scale = env->GetFieldID(c_model_params, "ropeFreqScale", "F"); + f_mul_mat_q = env->GetFieldID(c_model_params, "mulMatQ", "Z"); + f_f16_kv = env->GetFieldID(c_model_params, "f16Kv", "Z"); + f_logits_all = env->GetFieldID(c_model_params, "logitsAll", "Z"); + f_vocab_only = env->GetFieldID(c_model_params, "vocabOnly", "Z"); + f_use_mmap = env->GetFieldID(c_model_params, "useMmap", "Z"); + f_use_mlock = env->GetFieldID(c_model_params, "useMlock", "Z"); + f_embedding = env->GetFieldID(c_model_params, "embedding", "Z"); + f_lora_adapter = env->GetFieldID(c_model_params, "loraAdapter", "Ljava/lang/String;"); + f_lora_base = env->GetFieldID(c_model_params, "loraBase", "Ljava/lang/String;"); + f_memory_f16 = env->GetFieldID(c_model_params, "memoryF16", "Z"); + f_mem_test = env->GetFieldID(c_model_params, "memTest", "Z"); + f_numa = env->GetFieldID(c_model_params, "numa", "Z"); + f_verbose_prompt = env->GetFieldID(c_model_params, "verbosePrompt", "Z"); + + f_utf_8 = env->GetStaticFieldID(c_standard_charsets, "UTF_8", "Ljava/nio/charset/Charset;"); + o_utf_8 = env->NewStringUTF("UTF-8"); + o_utf_8 = (jobject) env->NewGlobalRef(o_utf_8); + o_logger = env->GetStaticObjectField(c_lib_utils, f_logger); + o_logger = (jobject) env->NewGlobalRef(o_logger); + + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + return JNI_ERR; + } + + return JNI_VERSION_1_1; +} + +JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) { + JNIEnv *env = 0; + + if (JNI_OK != vm->GetEnv((void **) &env, JNI_VERSION_1_1)) { + return; + } + + env->DeleteGlobalRef(c_input_params); + env->DeleteGlobalRef(c_model_params); + env->DeleteGlobalRef(c_token); + env->DeleteGlobalRef(c_string); + env->DeleteGlobalRef(c_hash_map); + env->DeleteGlobalRef(c_map); + env->DeleteGlobalRef(c_set); + env->DeleteGlobalRef(c_entry); + env->DeleteGlobalRef(c_integer); + env->DeleteGlobalRef(c_float); + env->DeleteGlobalRef(c_logger); + env->DeleteGlobalRef(c_engine_exception); + + env->DeleteGlobalRef(o_utf_8); +} + +static void log(JNIEnv *env, enum ggml_log_level level, const char *text) { + jstring java_text = env->NewStringUTF(text); + + switch (level) { + case GGML_LOG_LEVEL_ERROR: + env->CallVoidMethod(o_logger, m_log_error, java_text); + break; 
+ case GGML_LOG_LEVEL_WARN: + env->CallVoidMethod(o_logger, m_log_warn, java_text); + break; + case GGML_LOG_LEVEL_INFO: + env->CallVoidMethod(o_logger, m_log_info, java_text); + break; + default: + env->CallVoidMethod(o_logger, m_log_debug, java_text); + break; + } + env->DeleteLocalRef(java_text); +} + +static void log(JNIEnv *env, enum ggml_log_level level, std::string text) { log(env, level, text.c_str()); } + +static std::string parse_jstring(JNIEnv *env, jstring java_string) { + const jbyteArray string_bytes = (jbyteArray) env->CallObjectMethod(java_string, m_get_bytes, o_utf_8); + + size_t length = (size_t) env->GetArrayLength(string_bytes); + jbyte *byte_elements = env->GetByteArrayElements(string_bytes, nullptr); + + std::string string = std::string((char *) byte_elements, length); + + env->ReleaseByteArrayElements(string_bytes, byte_elements, JNI_ABORT); + env->DeleteLocalRef(string_bytes); + + return string; +} + +static int parse_jinteger(JNIEnv *env, jobject java_integer) { + if (!java_integer) return 0; + return env->CallIntMethod(java_integer, m_int_value); +} + +static float parse_jfloat(JNIEnv *env, jobject java_float) { + if (!java_float) return 0; + return env->CallFloatMethod(java_float, m_float_value); +} + +static jbyteArray parse_jbytes(JNIEnv *env, std::string string) { + jsize len = string.size(); + jbyteArray bytes = env->NewByteArray(len); + env->SetByteArrayRegion(bytes, 0, len, reinterpret_cast(string.c_str())); + return bytes; +} + +// completion token output with probabilities +struct completion_token_output { + struct token_prob { + llama_token tok; + float prob; + }; + + std::vector probs; + llama_token tok; +}; + +static size_t common_part(const std::vector &a, const std::vector &b) { + size_t i; + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) { + } + return i; +} + +enum stop_type { + STOP_FULL, + STOP_PARTIAL, +}; + +static bool ends_with(const std::string &str, const std::string &suffix) { + return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); +} + +static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { + if (!text.empty() && !stop.empty()) { + const char text_last_char = text.back(); + for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { + if (stop[char_index] == text_last_char) { + const std::string current_partial = stop.substr(0, char_index + 1); + if (ends_with(text, current_partial)) { + return text.size() - char_index - 1; + } + } + } + } + return std::string::npos; +} + +template +static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) { + std::string ret; + for (; begin != end; ++begin) { + ret += llama_token_to_piece(ctx, *begin); + } + return ret; +} + +// format incomplete utf-8 multibyte character for output +static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) { + std::string out = token == -1 ? 
"" : llama_token_to_piece(ctx, token); + // if the size is 1 and first bit is 1, meaning it's a partial character + // (size > 1 meaning it's already a known token) + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { + std::stringstream ss; + ss << std::hex << (out[0] & 0xff); + std::string res(ss.str()); + out = "byte: \\x" + res; + } + return out; +} + +struct jllama_context { + bool has_next_token = false; + std::string generated_text; + std::vector generated_token_probs; + + size_t num_prompt_tokens = 0; + size_t num_tokens_predicted = 0; + size_t n_past = 0; + size_t n_remain = 0; + + std::string prompt; + std::vector embd; + std::vector last_n_tokens; + + llama_model *model = nullptr; + llama_context *ctx = nullptr; + gpt_params params; + llama_sampling_context ctx_sampling; + int n_ctx; + + grammar_parser::parse_state parsed_grammar; + llama_grammar *grammar = nullptr; + + bool truncated = false; + bool stopped_eos = false; + bool stopped_word = false; + bool stopped_limit = false; + std::string stopping_word; + int32_t multibyte_pending = 0; + + std::mutex mutex; + + std::unique_lock lock() { return std::unique_lock(mutex); } + + ~jllama_context() { + if (ctx) { + llama_free(ctx); + ctx = nullptr; + } + if (model) { + llama_free_model(model); + model = nullptr; + } + if (grammar) { + llama_grammar_free(grammar); + grammar = nullptr; + } + } + + void rewind() { + params.antiprompt.clear(); + params.sparams.grammar.clear(); + num_prompt_tokens = 0; + num_tokens_predicted = 0; + generated_text = ""; + generated_text.reserve(n_ctx); + generated_token_probs.clear(); + truncated = false; + stopped_eos = false; + stopped_word = false; + stopped_limit = false; + stopping_word = ""; + multibyte_pending = 0; + n_remain = 0; + n_past = 0; + + if (grammar != nullptr) { + llama_grammar_free(grammar); + grammar = nullptr; + ctx_sampling = *llama_sampling_init(params.sparams); + } + } + + bool loadModel(const gpt_params ¶ms_) { + params = params_; + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (model == nullptr) { + return false; + } + n_ctx = llama_n_ctx(ctx); + last_n_tokens.resize(n_ctx); + std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); + return true; + } + + std::vector tokenize(std::string prompt, bool add_bos) const { + return ::llama_tokenize(ctx, prompt, add_bos); + } + + bool loadGrammar(JNIEnv *env) { + if (!params.sparams.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.sparams.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + log(env, GGML_LOG_LEVEL_ERROR, "grammar parse error"); + return false; + } + grammar_parser::print_grammar(stderr, parsed_grammar); + + { + auto it = params.sparams.logit_bias.find(llama_token_eos(model)); + if (it != params.sparams.logit_bias.end() && it->second == -INFINITY) { + log(env, GGML_LOG_LEVEL_WARN, "EOS token is disabled, which will cause most grammars to fail"); + } + } + + std::vector grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + } + ctx_sampling = *llama_sampling_init(params.sparams); + return true; + } + + void loadInfill(JNIEnv *env) { + bool suff_rm_leading_spc = true; + if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { + params.input_suffix.erase(0, 1); + suff_rm_leading_spc = false; + } + + auto prefix_tokens = tokenize(params.input_prefix, false); + auto suffix_tokens = 
tokenize(params.input_suffix, false); + const int space_token = 29871; + if (suff_rm_leading_spc && suffix_tokens[0] == space_token) { + suffix_tokens.erase(suffix_tokens.begin()); + } + prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); + prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS + prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model)); + prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + prefix_tokens.push_back(llama_token_middle(model)); + auto prompt_tokens = prefix_tokens; + + num_prompt_tokens = prompt_tokens.size(); + + if (params.n_keep < 0) { + params.n_keep = (int) num_prompt_tokens; + } + params.n_keep = std::min(params.n_ctx - 4, params.n_keep); + + // if input prompt is too big, truncate like normal + if (num_prompt_tokens >= (size_t) params.n_ctx) { + // todo we probably want to cut from both sides + const int n_left = (params.n_ctx - params.n_keep) / 2; + std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; + new_tokens.insert( + new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); + std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + + log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left)); + + truncated = true; + prompt_tokens = new_tokens; + } else { + const size_t ps = num_prompt_tokens; + std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + } + + // compare the evaluated prompt with the new prompt + n_past = common_part(embd, prompt_tokens); + embd = prompt_tokens; + + if (n_past == num_prompt_tokens) { + // we have to evaluate at least 1 token to generate logits. + n_past--; + } + + // since #3228 we now have to manually manage the KV cache + llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + + has_next_token = true; + } + + void loadPrompt(JNIEnv *env) { + auto prompt_tokens = tokenize(prompt, true); // always add BOS + + num_prompt_tokens = prompt_tokens.size(); + + if (params.n_keep < 0) { + params.n_keep = (int) num_prompt_tokens; + } + params.n_keep = std::min(n_ctx - 4, params.n_keep); + + // if input prompt is too big, truncate like normal + if (num_prompt_tokens >= (size_t) n_ctx) { + const int n_left = (n_ctx - params.n_keep) / 2; + std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; + new_tokens.insert( + new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); + std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + + log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left)); + + truncated = true; + prompt_tokens = new_tokens; + } else { + const size_t ps = num_prompt_tokens; + std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + } + + // compare the evaluated prompt with the new prompt + n_past = common_part(embd, prompt_tokens); + + embd = prompt_tokens; + if (n_past == num_prompt_tokens) { + // we have to evaluate at least 1 token to generate logits. 
+ n_past--; + } + + // since #3228 we now have to manually manage the KV cache + llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + + has_next_token = true; + } + + void beginCompletion() { + // number of tokens to keep when resetting context + n_remain = params.n_predict; + llama_set_rng_seed(ctx, params.seed); + } + + completion_token_output nextToken(JNIEnv *env) { + completion_token_output result; + result.tok = -1; + + if (embd.size() >= (size_t) n_ctx) { + // Shift context + + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left / 2; + + llama_kv_cache_seq_rm(ctx, 0, params.n_keep + 1, params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + + for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++) { + embd[i - n_discard] = embd[i]; + } + embd.resize(embd.size() - n_discard); + + n_past -= n_discard; + + truncated = true; + log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left)); + } + + bool tg = true; + while (n_past < embd.size()) { + int n_eval = (int) embd.size() - n_past; + tg = n_eval == 1; + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + + if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0))) { + log(env, GGML_LOG_LEVEL_ERROR, "failed to eval n_eval=" + std::to_string(n_eval)); + has_next_token = false; + return result; + } + n_past += n_eval; + } + + if (params.n_predict == 0) { + has_next_token = false; + result.tok = llama_token_eos(model); + return result; + } + + { + // out of user input, sample next token + result.tok = llama_sampling_sample(&ctx_sampling, ctx, NULL); + + llama_token_data_array candidates_p = {ctx_sampling.cur.data(), ctx_sampling.cur.size(), false}; + + const int32_t n_probs = params.sparams.n_probs; + if (params.sparams.temp <= 0 && n_probs > 0) { + // For llama_sample_token_greedy we need to sort candidates + llama_sample_softmax(ctx, &candidates_p); + } + + for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) { + result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); + } + + llama_sampling_accept(&ctx_sampling, ctx, result.tok, true); + if (tg) { + num_tokens_predicted++; + } + } + + // add it to the context + embd.push_back(result.tok); + // decrement remaining sampling budget + --n_remain; + + if (!embd.empty() && embd.back() == llama_token_eos(model)) { + // stopping_word = llama_token_to_piece(ctx, embd.back()); + has_next_token = false; + stopped_eos = true; + return result; + } + + has_next_token = params.n_predict == -1 || n_remain != 0; + return result; + } + + size_t findStoppingStrings(const std::string &text, const size_t last_token_size, const stop_type type) { + size_t stop_pos = std::string::npos; + for (const std::string &word : params.antiprompt) { + size_t pos; + if (type == STOP_FULL) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + pos = text.find(word, from_pos); + } else { + pos = find_partial_stop_string(word, text); + } + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (type == STOP_FULL) { + stopping_word = word; + stopped_word = true; + has_next_token = false; + } + stop_pos = pos; + } + } + return stop_pos; + } + + completion_token_output doCompletion(JNIEnv *env) { + auto token_with_probs = nextToken(env); + + const std::string token_text = token_with_probs.tok == -1 ? 
"" : llama_token_to_piece(ctx, token_with_probs.tok); + generated_text += token_text; + + if (params.sparams.n_probs > 0) { + generated_token_probs.push_back(token_with_probs); + } + + if (multibyte_pending > 0) { + multibyte_pending -= token_text.size(); + } else if (token_text.size() == 1) { + const char c = token_text[0]; + // 2-byte characters: 110xxxxx 10xxxxxx + if ((c & 0xE0) == 0xC0) { + multibyte_pending = 1; + // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + } else if ((c & 0xF0) == 0xE0) { + multibyte_pending = 2; + // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + } else if ((c & 0xF8) == 0xF0) { + multibyte_pending = 3; + } else { + multibyte_pending = 0; + } + } + + if (multibyte_pending > 0 && !has_next_token) { + has_next_token = true; + n_remain++; + } + + if (!has_next_token && n_remain == 0) { + stopped_limit = true; + } + + return token_with_probs; + } + + std::vector getEmbedding(JNIEnv *env) { + static const int n_embd = llama_n_embd(model); + if (!params.embedding) { + log(env, GGML_LOG_LEVEL_ERROR, "embedding disabled"); + return std::vector(n_embd, 0.0f); + } + const float *data = llama_get_embeddings(ctx); + std::vector embedding(data, data + n_embd); + return embedding; + } +}; + +static gpt_params parse_model_params(JNIEnv *env, jobject jparams, jstring java_file_path) { + gpt_params params; + + params.model = parse_jstring(env, java_file_path); + params.n_threads = env->GetIntField(jparams, f_n_threads); + params.n_ctx = env->GetIntField(jparams, f_n_ctx); + params.n_batch = env->GetIntField(jparams, f_n_batch); + params.n_gpu_layers = env->GetIntField(jparams, f_n_gpu_layers); + params.main_gpu = env->GetIntField(jparams, f_main_gpu); + params.rope_freq_base = env->GetFloatField(jparams, f_rope_freq_base); + params.rope_freq_scale = env->GetFloatField(jparams, f_rope_freq_scale); + params.mul_mat_q = env->GetBooleanField(jparams, f_mul_mat_q); + params.embedding = env->GetBooleanField(jparams, f_embedding); + params.escape = env->GetIntField(jparams, f_n_predict); + params.use_mmap = env->GetBooleanField(jparams, f_use_mmap); + params.use_mlock = env->GetBooleanField(jparams, f_use_mlock); + params.numa = env->GetBooleanField(jparams, f_numa); + params.verbose_prompt = env->GetBooleanField(jparams, f_verbose_prompt); + + if (params.model_alias == "unknown") { + params.model_alias = params.model; + } + + return params; +} + +static void setup_infer_params(JNIEnv *env, jllama_context *llama, jobject jparams) { + auto ¶ms = llama->params; + + params.seed = env->GetIntField(jparams, f_infer_seed); + params.n_predict = env->GetIntField(jparams, f_n_predict); + params.n_keep = env->GetIntField(jparams, f_n_keep); + + auto &sparams = params.sparams; + + sparams.top_k = env->GetIntField(jparams, f_top_k); + sparams.top_p = env->GetFloatField(jparams, f_top_p); + sparams.tfs_z = env->GetFloatField(jparams, f_tfs_z); + sparams.typical_p = env->GetFloatField(jparams, f_typical_p); + sparams.temp = env->GetFloatField(jparams, f_temperature); + sparams.penalty_repeat = env->GetFloatField(jparams, f_repeat_penalty); + sparams.n_prev = env->GetIntField(jparams, f_repeat_last_n); + sparams.penalty_freq = env->GetFloatField(jparams, f_frequency_penalty); + sparams.penalty_present = env->GetFloatField(jparams, f_presence_penalty); + sparams.penalize_nl = env->GetBooleanField(jparams, f_penalize_nl); + sparams.mirostat = env->GetIntField(jparams, f_mirostat); + sparams.mirostat_tau = env->GetFloatField(jparams, f_mirostat_tau); + sparams.mirostat_eta = 
env->GetFloatField(jparams, f_mirostat_eta); + sparams.n_probs = env->GetIntField(jparams, f_n_probs); + + jstring j_grammar = (jstring) env->GetObjectField(jparams, f_grammar); + if (j_grammar != nullptr) { + sparams.grammar = parse_jstring(env, j_grammar); + env->DeleteLocalRef(j_grammar); + if (!llama->loadGrammar(env)) { + env->ThrowNew(c_engine_exception, "could not load grammar"); + } + } + + sparams.logit_bias.clear(); + jboolean ignore_eos = env->GetBooleanField(jparams, f_ignore_eos); + if (ignore_eos) { + sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY; + } + + jobject logit_bias = env->GetObjectField(jparams, f_logit_bias); + if (logit_bias != nullptr) { + jobject entry_set = env->CallObjectMethod(logit_bias, m_entry_set); + jobject iterator = env->CallObjectMethod(entry_set, m_set_iterator); + while (env->CallBooleanMethod(iterator, m_iterator_has_next)) { + jobject entry = env->CallObjectMethod(iterator, m_iterator_next); + jobject key = env->CallObjectMethod(entry, m_entry_key); + jobject value = env->CallObjectMethod(entry, m_entry_value); + + int tok = parse_jinteger(env, key); + float bias = parse_jfloat(env, value); + sparams.logit_bias[tok] = bias; + + env->DeleteLocalRef(entry); + env->DeleteLocalRef(key); + env->DeleteLocalRef(value); + } + } + + params.antiprompt.clear(); + jobjectArray antiprompt = (jobjectArray) env->GetObjectField(jparams, f_antiprompt); + if (antiprompt != nullptr) { + jsize array_length = env->GetArrayLength(antiprompt); + for (jsize i = 0; i < array_length; i++) { + jstring java_string = (jstring) env->GetObjectArrayElement(antiprompt, i); + if (java_string != nullptr) { + std::string string = parse_jstring(env, java_string); + params.antiprompt.push_back(string); + env->DeleteLocalRef(java_string); + } + } + } + + llama->ctx_sampling = *llama_sampling_init(params.sparams); +} + +static void setup_answering(JNIEnv *env, jllama_context *llama, jstring prompt, jobject params) { + llama->prompt = parse_jstring(env, prompt); + llama->params.input_prefix = ""; + llama->params.input_suffix = ""; + setup_infer_params(env, llama, params); +} + +static void setup_infilling(JNIEnv *env, jllama_context *llama, jstring prefix, jstring suffix, jobject params) { + llama->prompt = ""; + llama->params.input_prefix = parse_jstring(env, prefix); + llama->params.input_suffix = parse_jstring(env, suffix); + setup_infer_params(env, llama, params); +} + +JNIEXPORT jlong JNICALL Java_ai_djl_llama_jni_LlamaLibrary_loadModel( + JNIEnv *env, jclass clazz, jstring file_path, jobject jparams) { + gpt_params params = parse_model_params(env, jparams, file_path); + + jllama_context *llama = new jllama_context; + llama_backend_init(false); + + if (!llama->loadModel(params)) { + env->ThrowNew(c_engine_exception, "could not load model from given file path"); + return 0; + } + + return reinterpret_cast(llama); +} + +JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_generate( + JNIEnv *env, jclass clazz, jlong handle, jstring prompt, jobject params) { + auto *llama = reinterpret_cast(handle); + + llama->rewind(); + llama_reset_timings(llama->ctx); + setup_answering(env, llama, prompt, params); + + llama->loadPrompt(env); + llama->beginCompletion(); +} + +JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_infill( + JNIEnv *env, jclass clazz, jlong handle, jstring prefix, jstring suffix, jobject params) { + auto *llama = reinterpret_cast(handle); + + llama->rewind(); + + llama_reset_timings(llama->ctx); + + setup_infilling(env, llama, prefix, suffix, 
params); + + llama->loadInfill(env); + llama->beginCompletion(); +} + +JNIEXPORT jobject JNICALL Java_ai_djl_llama_jni_LlamaLibrary_getNext( + JNIEnv *env, jclass clazz, jlong handle, jlong sent_count, jlong sent_token_probs_index) { + auto *llama = reinterpret_cast(handle); + + completion_token_output token_with_probs; + while (llama->has_next_token) { + token_with_probs = llama->doCompletion(env); + if (token_with_probs.tok >= 0 && llama->multibyte_pending <= 0) { + break; + } + } + const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok); + + size_t pos = std::min((size_t) sent_count, llama->generated_text.size()); + + const std::string str_test = llama->generated_text.substr(pos); + bool is_stop_full = false; + size_t stop_pos = llama->findStoppingStrings(str_test, token_text.size(), STOP_FULL); + if (stop_pos != std::string::npos) { + is_stop_full = true; + llama->generated_text.erase(llama->generated_text.begin() + pos + stop_pos, llama->generated_text.end()); + pos = std::min((size_t) sent_count, llama->generated_text.size()); + } else { + is_stop_full = false; + stop_pos = llama->findStoppingStrings(str_test, token_text.size(), STOP_PARTIAL); + } + + std::string to_send; + if (stop_pos == std::string::npos || + // Send rest of the text if we are at the end of the generation + (!llama->has_next_token && !is_stop_full && stop_pos > 0)) { + to_send = llama->generated_text.substr(pos, std::string::npos); + + sent_count += to_send.size(); + std::vector probs_output = {}; + + if (llama->params.sparams.n_probs > 0) { + const std::vector to_send_toks = llama_tokenize(llama->ctx, to_send, false); + size_t probs_pos = std::min((size_t) sent_token_probs_index, llama->generated_token_probs.size()); + size_t probs_stop_pos = + std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size()); + if (probs_pos < probs_stop_pos) { + probs_output = std::vector( + llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos); + } + sent_token_probs_index = probs_stop_pos; + } + } else { + to_send = ""; + } + + jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map); + for (const auto &tp : token_with_probs.probs) { + jobject jtoken = env->NewObject(c_integer, cc_integer, tp.tok); + jobject jprob = env->NewObject(c_float, cc_float, tp.prob); + env->CallObjectMethod(o_probabilities, m_map_put, jtoken, jprob); + } + + jbyteArray jbytes = parse_jbytes(env, to_send); + return env->NewObject(c_token, cc_token, token_with_probs.tok, jbytes, o_probabilities, sent_count, + sent_token_probs_index, llama->has_next_token); +} + +JNIEXPORT jfloatArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_embed( + JNIEnv *env, jclass clazz, jlong handle, jstring java_prompt) { + auto *llama = reinterpret_cast(handle); + + llama->rewind(); + llama_reset_timings(llama->ctx); + llama->prompt = parse_jstring(env, java_prompt); + llama->params.n_predict = 0; + llama->loadPrompt(env); + llama->beginCompletion(); + llama->doCompletion(env); + + static const int n_embd = llama_n_embd(llama->model); + const float *data = llama_get_embeddings(llama->ctx); + std::vector embedding(data, data + n_embd); + + jfloatArray java_embedding = env->NewFloatArray(embedding.size()); + env->SetFloatArrayRegion(java_embedding, 0, embedding.size(), reinterpret_cast(embedding.data())); + + return java_embedding; +} + +JNIEXPORT jintArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_encode( + JNIEnv *env, jclass clazz, jlong handle, jstring 
jprompt) { + auto *llama = reinterpret_cast(handle); + + std::string prompt = parse_jstring(env, jprompt); + std::vector tokens = llama->tokenize(prompt, false); + + jintArray java_tokens = env->NewIntArray(tokens.size()); + env->SetIntArrayRegion(java_tokens, 0, tokens.size(), reinterpret_cast(tokens.data())); + + return java_tokens; +} + +JNIEXPORT jbyteArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_decodeBytes( + JNIEnv *env, jclass clazz, jlong handle, jintArray java_tokens) { + auto *llama = reinterpret_cast(handle); + + jsize length = env->GetArrayLength(java_tokens); + jint *elements = env->GetIntArrayElements(java_tokens, nullptr); + std::vector tokens(elements, elements + length); + std::string text = tokens_to_str(llama->ctx, tokens.cbegin(), tokens.cend()); + + env->ReleaseIntArrayElements(java_tokens, elements, 0); + + return parse_jbytes(env, text); +} + +JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_delete(JNIEnv *env, jclass clazz, jlong handle) { + auto *llama = reinterpret_cast(handle); + delete llama; +} diff --git a/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider b/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider new file mode 100644 index 00000000000..d2f8ca8e42c --- /dev/null +++ b/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider @@ -0,0 +1 @@ +ai.djl.llama.engine.LlamaEngineProvider diff --git a/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider b/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider new file mode 100644 index 00000000000..92f6245340f --- /dev/null +++ b/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider @@ -0,0 +1 @@ +ai.djl.llama.zoo.LlamaZooProvider diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java new file mode 100644 index 00000000000..429cd569392 --- /dev/null +++ b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java @@ -0,0 +1,101 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
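For orientation, the bindings above are driven from Java as a small streaming loop: load a model handle, prime a prompt, then keep calling getNext until the native side reports that generation has finished. The following Java sketch is only illustrative; the ModelParameters class name, the Token accessors for the returned offsets, and hasNext() are assumptions, since the Java wrappers are defined elsewhere in this change.

```java
import ai.djl.llama.engine.LlamaInput.Parameters;
import ai.djl.llama.jni.InputParameters;
import ai.djl.llama.jni.LlamaLibrary;
import ai.djl.llama.jni.ModelParameters; // assumed class name, not shown in this hunk
import ai.djl.llama.jni.Token;

public final class StreamingSketch {

    public static void main(String[] args) {
        // loadModel(...) returns the address of the native jllama_context as a handle
        long handle = LlamaLibrary.loadModel("model.gguf", new ModelParameters());

        // build inference parameters the same way LlamaInputTest does
        Parameters parameters = new Parameters();
        parameters.setMaxNewTokens(64);
        InputParameters inferParams = parameters.toInputParameters();

        // generate(...) rewinds the context, tokenizes the prompt and begins completion
        LlamaLibrary.generate(handle, "What is Deep Java Library?", inferParams);

        long sentCount = 0;      // bytes of generated_text already returned
        long sentProbsIndex = 0; // token probabilities already returned
        Token token;
        do {
            // getNext(...) samples until a complete UTF-8 sequence is available,
            // applies stop strings, and returns the next chunk plus updated offsets
            token = LlamaLibrary.getNext(handle, sentCount, sentProbsIndex);
            sentCount = token.getSentCount();                // assumed accessor
            sentProbsIndex = token.getSentTokenProbsIndex(); // assumed accessor
            System.out.print(token.getText());
        } while (token.hasNext());                           // assumed accessor

        // releases llama_context, model and grammar via ~jllama_context()
        LlamaLibrary.delete(handle);
    }
}
```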
+ */ +package ai.djl.llama.engine; + +import ai.djl.llama.engine.LlamaInput.Parameters; +import ai.djl.llama.jni.InputParameters; +import ai.djl.util.JsonUtils; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.io.Reader; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Map; + +public class LlamaInputTest { + + @Test + public void testInputParameters() throws IOException { + Path file = Paths.get("src/test/resources/inputs.json"); + try (Reader reader = Files.newBufferedReader(file)) { + LlamaInput in = JsonUtils.GSON.fromJson(reader, LlamaInput.class); + checkParameters(in); + } + + Parameters param = new Parameters(); + LlamaInput in = new LlamaInput(); + in.setInputs("prompt"); + in.setPrefix("prefix"); + in.setSuffix("suffix"); + in.setParameters(param); + param.setMaxNewTokens(2); + param.setNumberKeep(2); + param.setNumberProbabilities(2); + param.setTopK(2); + param.setTopP(2f); + param.setTfsZ(2f); + param.setTypicalP(2f); + param.setTemperature(2f); + param.setRepeatPenalty(2f); + param.setRepeatLastN(2); + param.setFrequencyPenalty(2f); + param.setFrequencyPenalty(2f); + param.setPresencePenalty(2f); + param.setPenalizeNl(true); + param.setIgnoreEos(true); + param.setMirostat(2); + param.setMirostatTau(2f); + param.setMirostatEta(2f); + param.setNumberBeams(5); + param.setSeed(2); + Map logitBias = Map.of(2, 0.4f, 3, 0.5f); + param.setLogitBias(logitBias); + param.setGrammar("grammar"); + param.setAntiPrompt(new String[] {"User: "}); + checkParameters(in); + } + + private void checkParameters(LlamaInput in) { + InputParameters param = in.getParameters().toInputParameters(); + Assert.assertEquals(param.getMaxNewTokens(), 2); + Assert.assertEquals(param.getNumberKeep(), 2); + Assert.assertEquals(param.getNumberProbabilities(), 2); + Assert.assertEquals(param.getTopK(), 2); + Assert.assertEquals(param.getTopP(), 2f); + Assert.assertEquals(param.getTfsZ(), 2f); + Assert.assertEquals(param.getTypicalP(), 2f); + Assert.assertEquals(param.getTemperature(), 2f); + Assert.assertEquals(param.getRepeatPenalty(), 2f); + Assert.assertEquals(param.getRepeatLastN(), 2); + Assert.assertEquals(param.getFrequencyPenalty(), 2f); + Assert.assertEquals(param.getFrequencyPenalty(), 2f); + Assert.assertEquals(param.getPresencePenalty(), 2f); + Assert.assertTrue(param.isPenalizeNl()); + Assert.assertTrue(param.isIgnoreEos()); + Assert.assertEquals(param.getMirostat(), 2); + Assert.assertEquals(param.getMirostatTau(), 2f); + Assert.assertEquals(param.getMirostatEta(), 2f); + Assert.assertEquals(param.getNumberBeams(), 5); + Assert.assertEquals(param.getSeed(), 2); + Map logitBias = param.getLogitBias(); + Assert.assertNotNull(logitBias); + Assert.assertEquals(logitBias.size(), 2); + Assert.assertEquals(logitBias.get(2), 0.4f); + Assert.assertNotNull(param.getGrammar()); + Assert.assertNotNull(param.getAntiPrompt()[0], "User: "); + } +} diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java new file mode 100644 index 00000000000..7b372ee4258 --- /dev/null +++ b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java @@ -0,0 +1,143 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.engine; + +import ai.djl.ModelException; +import ai.djl.engine.Engine; +import ai.djl.engine.StandardCapabilities; +import ai.djl.inference.Predictor; +import ai.djl.llama.jni.Token; +import ai.djl.llama.jni.TokenIterator; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.NDManager; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.testing.TestRequirements; +import ai.djl.training.util.DownloadUtils; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.net.URI; +import java.nio.file.Path; +import java.nio.file.Paths; + +public class LlamaTest { + + private static final Logger logger = LoggerFactory.getLogger(LlamaTest.class); + + @BeforeClass + public void setUp() { + System.setProperty("DJL_CACHE_DIR", "build/cache"); + } + + @AfterClass + public void tierDown() { + System.clearProperty("DJL_CACHE_DIR"); + } + + @Test + public void testLlamaVersion() { + Engine engine = Engine.getEngine("Llama"); + Assert.assertEquals(engine.getVersion(), "b1696-" + Engine.getDjlVersion()); + Assert.assertNotNull(engine.toString()); + Assert.assertEquals(engine.getRank(), 10); + Assert.assertFalse(engine.hasCapability(StandardCapabilities.CUDA)); + Assert.assertNull(engine.getAlternativeEngine()); + try (NDManager manager = engine.newBaseManager()) { + Assert.assertNotNull(manager); + } + } + + @Test + public void testLlama() throws TranslateException, ModelException, IOException { + TestRequirements.nightly(); + downloadModel(); + Path path = Paths.get("models"); + Criteria criteria = + Criteria.builder() + .setTypes(String.class, TokenIterator.class) + .optModelPath(path) + .optModelName("tinyllama-1.1b-1t-openorca.Q4_K_M") + .optEngine("Llama") + .optOption("number_gpu_layers", "43") + .optTranslatorFactory(new LlamaTranslatorFactory()) + .build(); + + String prompt = + "{\"inputs\": \"<|im_start|>system\n" + + "{system_message}<|im_end|>\n" + + "<|im_start|>user\n" + + "{prompt}<|im_end|>\n" + + "<|im_start|>assistant\", \"parameters\": {\"max_new_tokens\": 10}}"; + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + TokenIterator it = predictor.predict(prompt); + StringBuilder sb = new StringBuilder(); + while (it.hasNext()) { + Token token = it.next(); + Assert.assertNotNull(token.getText()); + Assert.assertTrue(token.getToken() >= 0); + Assert.assertNotNull(token.getProbabilities()); + sb.append(token.getText()); + logger.info("{}", token); + } + Assert.assertTrue(sb.length() > 1); + } + } + + @Test + public void testLlamaInfill() throws TranslateException, ModelException, IOException { + TestRequirements.nightly(); + downloadModel(); + Path path = Paths.get("models/tinyllama-1.1b-1t-openorca.Q4_K_M.gguf"); + Criteria criteria = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(path) + .optOption("number_gpu_layers", "43") + 
.optEngine("Llama") + .optTranslatorFactory(new LlamaTranslatorFactory()) + .build(); + + String prompt = + "{\n" + + " \"prefix\":\"def remove_non_ascii(s: str) -> str:\n\",\n" + + " \"suffix\":\"\n return result\n\",\n" + + " \"parameters\":{\n" + + " \"max_new_tokens\": 10" + + " }\n" + + "}"; + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + Input in = new Input(); + in.add("data", prompt); + Output out = predictor.predict(in); + Assert.assertNotNull(out.getData().getAsString()); + } + } + + private void downloadModel() throws IOException { + String url = + "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_K_M.gguf?download=true"; + Path dir = Paths.get("models/tinyllama-1.1b-1t-openorca.Q4_K_M.gguf"); + DownloadUtils.download(URI.create(url).toURL(), dir, null); + } +} diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java b/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java new file mode 100644 index 00000000000..b2ee786419f --- /dev/null +++ b/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains test classes for llama engine. */ +package ai.djl.llama.engine; diff --git a/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java b/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java new file mode 100644 index 00000000000..fab7bacb9e3 --- /dev/null +++ b/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.llama.zoo; + +import ai.djl.repository.zoo.ModelLoader; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.util.Utils; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.nio.file.Paths; +import java.util.Collection; + +public class LlamaModelZooTest { + + @Test + public void testLlamaModelZoo() { + System.setProperty("DJL_CACHE_DIR", "build/cache"); + Utils.deleteQuietly(Paths.get("build/cache/cache")); + try { + ModelZoo zoo = ModelZoo.getModelZoo("ai.djl.huggingface.gguf"); + Collection<ModelLoader> models = zoo.getModelLoaders(); + Assert.assertFalse(models.isEmpty()); + Assert.assertEquals(zoo.getSupportedEngines().size(), 1); + ModelLoader loader = zoo.getModelLoader("TinyLlama/TinyLlama-1.1B-Chat-v0.6"); + Assert.assertNotNull(loader); + + ModelZoo llamaModelZoo = new LlamaModelZoo(); + Assert.assertFalse(llamaModelZoo.getModelLoaders().isEmpty()); + } finally { + System.clearProperty("DJL_CACHE_DIR"); + } + } + + @Test + public void testOffLine() { + System.setProperty("DJL_CACHE_DIR", "build/cache"); + System.setProperty("ai.djl.offline", "true"); + Utils.deleteQuietly(Paths.get("build/cache/cache")); + try { + // static variables cannot be initialized properly if LlamaModelZoo() is used directly + ModelZoo.getModelZoo("ai.djl.huggingface.gguf"); + + ModelZoo zoo = new LlamaModelZoo(); + Assert.assertFalse(zoo.getModelLoaders().isEmpty()); + } finally { + System.clearProperty("DJL_CACHE_DIR"); + System.clearProperty("ai.djl.offline"); + } + } +} diff --git a/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java b/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java new file mode 100644 index 00000000000..145b2ddcca9 --- /dev/null +++ b/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains test classes for llama model zoo.
*/ +package ai.djl.llama.zoo; diff --git a/engines/llama/src/test/resources/inputs.json b/engines/llama/src/test/resources/inputs.json new file mode 100644 index 00000000000..ab77386e1b6 --- /dev/null +++ b/engines/llama/src/test/resources/inputs.json @@ -0,0 +1,33 @@ +{ + "prefix": "def remove_non_ascii(s: str) -> str:", + "suffix": " return result", + "parameters": { + "max_new_tokens": 2, + "number_keep": 2, + "number_probabilities": 2, + "top_k": 2, + "top_p": 2, + "tfs_z": 2, + "typical_p": 2, + "temperature": 2, + "repeat_penalty": 2, + "repeat_last_n": 2, + "frequency_penalty": 2, + "presence_penalty": 2, + "penalize_nl": true, + "ignore_eos": true, + "mirostat": 2, + "mirostat_tau": 2, + "mirostat_eta": 2, + "number_beams": 5, + "seed": 2, + "logit_bias": { + "2": 0.4, + "5": 0.6 + }, + "grammar": "root ::= (expr \"=\" term \"\\n\")+\nexpr ::= term ([-+*/] term)*\nterm ::= [0-9]", + "anti_prompt": [ + "User: " + ] + } +} diff --git a/engines/ml/lightgbm/README.md b/engines/ml/lightgbm/README.md index 3ea950c8935..b74fae73082 100644 --- a/engines/ml/lightgbm/README.md +++ b/engines/ml/lightgbm/README.md @@ -36,13 +36,13 @@ LightGBM can only run on top of the Linux/Mac/Windows machine using x86_64. ## Installation You can pull the LightGBM engine from the central Maven repository by including the following dependency: -- ai.djl.ml.lightgbm:lightgbm:0.23.0 +- ai.djl.ml.lightgbm:lightgbm:0.27.0 ```xml ai.djl.ml.lightgbm lightgbm - 0.23.0 + 0.27.0 runtime ``` diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java index a253ce3d246..f8c84c753ef 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java @@ -18,8 +18,6 @@ /** {@code LgbmEngineProvider} is the LightGBM implementation of {@link EngineProvider}. */ public class LgbmEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -35,11 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { - synchronized (LgbmEngineProvider.class) { - engine = LgbmEngine.newInstance(); - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = LgbmEngine.newInstance(); } } diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java index 0bb92645a89..826b1a0f900 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java @@ -46,6 +46,7 @@ public class LgbmSymbolBlock extends AbstractSymbolBlock implements AutoCloseabl * @param iterations the number of iterations the model was trained for * @param handle the Booster handle */ + @SuppressWarnings("this-escape") public LgbmSymbolBlock(LgbmNDManager manager, int iterations, SWIGTYPE_p_p_void handle) { this.handle = new AtomicReference<>(handle); this.iterations = iterations; diff --git a/engines/ml/xgboost/README.md b/engines/ml/xgboost/README.md index d69f1830193..d10f770c956 100644 --- a/engines/ml/xgboost/README.md +++ b/engines/ml/xgboost/README.md @@ -37,13 +37,13 @@ XGBoost can only run on top of the Linux/Mac machine. 
User can build from source ## Installation You can pull the XGBoost engine from the central Maven repository by including the following dependency: -- ai.djl.ml.xgboost:xgboost:0.23.0 +- ai.djl.ml.xgboost:xgboost:0.27.0 ```xml ai.djl.ml.xgboost xgboost - 0.23.0 + 0.27.0 runtime ``` diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java index 19cba32cc71..5859f3f344d 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java @@ -18,8 +18,6 @@ /** {@code XgbEngineProvider} is the XGBoost implementation of {@link EngineProvider}. */ public class XgbEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -35,11 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { - synchronized (XgbEngineProvider.class) { - engine = XgbEngine.newInstance(); - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = XgbEngine.newInstance(); } } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java index bf41acb9b6c..1b3c5ae277f 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java @@ -80,6 +80,8 @@ private Path findModelFile(String prefix) { String fileName = file.toFile().getName(); if (fileName.endsWith(".json")) { modelName = fileName.substring(0, fileName.length() - 5); + } else if (fileName.endsWith(".xgb")) { + modelName = fileName.substring(0, fileName.length() - 4); } else { modelName = fileName; } @@ -90,13 +92,22 @@ private Path findModelFile(String prefix) { } Path modelFile = modelDir.resolve(prefix); if (Files.notExists(modelFile) || !Files.isRegularFile(modelFile)) { - if (prefix.endsWith(".json")) { + if (prefix.endsWith(".json") || prefix.endsWith(".xgb")) { return null; } modelFile = modelDir.resolve(prefix + ".json"); - if (Files.notExists(modelFile) || !Files.isRegularFile(modelFile)) { - return null; + if (Files.isRegularFile(modelFile)) { + return modelFile; + } + modelFile = modelDir.resolve(prefix + ".xgb"); + if (Files.isRegularFile(modelFile)) { + return modelFile; + } + modelFile = modelDir.resolve("model.xgb"); + if (Files.isRegularFile(modelFile)) { + return modelFile; } + return null; } return modelFile; } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java index 3b56cbca241..81f9708e72b 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java @@ -39,6 +39,7 @@ public class XgbNDManager extends BaseNDManager { private static final XgbNDManager SYSTEM_MANAGER = new SystemManager(); private float missingValue = Float.NaN; + private int nthread = 1; private XgbNDManager(NDManager parent, Device device) { super(parent, device); @@ -57,6 +58,15 @@ public void setMissingValue(float missingValue) { this.missingValue = missingValue; } + /** + * Sets the default number of threads. 
+ * + * @param nthread the default number of threads + */ + public void setNthread(int nthread) { + this.nthread = nthread; + } + /** {@inheritDoc} */ @Override public ByteBuffer allocateDirect(int capacity) { @@ -166,7 +176,7 @@ public NDArray createCSR(Buffer buffer, long[] indptr, long[] indices, Shape sha int[] intIndices = Arrays.stream(indices).mapToInt(Math::toIntExact).toArray(); float[] data = new float[buffer.remaining()]; ((FloatBuffer) buffer).get(data); - long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data); + long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data, missingValue, nthread); return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.CSR); } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java index 1e2bcddd999..43a9e129dea 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java @@ -45,6 +45,7 @@ public class XgbSymbolBlock extends AbstractSymbolBlock implements AutoCloseable * @param manager the manager to use for the block * @param handle the Booster handle */ + @SuppressWarnings("this-escape") public XgbSymbolBlock(XgbNDManager manager, long handle) { this.handle = new AtomicReference<>(handle); this.manager = manager; diff --git a/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java b/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java index fefbe7f0716..eb071552fd0 100644 --- a/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java +++ b/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java @@ -67,9 +67,12 @@ public static long createDMatrix(ColumnBatch columnBatch, float missing, int nth return handles[0]; } - public static long createDMatrixCSR(long[] indptr, int[] indices, float[] array) { + public static long createDMatrixCSR( + long[] indptr, int[] indices, float[] array, float missing, int nthread) { long[] handles = new long[1]; - checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(indptr, indices, array, 0, handles)); + checkCall( + XGBoostJNI.XGDMatrixCreateFromCSR( + indptr, indices, array, 0, missing, nthread, handles)); return handles[0]; } diff --git a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java index 0b09ed6807c..acbfa998867 100644 --- a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java +++ b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java @@ -53,7 +53,7 @@ public void downloadXGBoostModel() throws IOException { @Test public void testVersion() { Engine engine = Engine.getEngine("XGBoost"); - Assert.assertEquals("1.7.5", engine.getVersion()); + Assert.assertEquals("2.0.3", engine.getVersion()); } /* @@ -93,6 +93,7 @@ public void testNDArray() { try (XgbNDManager manager = (XgbNDManager) XgbNDManager.getSystemManager().newSubManager()) { manager.setMissingValue(Float.NaN); + manager.setNthread(1); NDArray zeros = manager.zeros(new Shape(1, 2)); Assert.expectThrows(UnsupportedOperationException.class, zeros::toFloatArray); diff --git a/engines/mxnet/jnarator/build.gradle b/engines/mxnet/jnarator/build.gradle index b9cc0d4cd5f..b9fd8ceab14 100644 --- a/engines/mxnet/jnarator/build.gradle +++ b/engines/mxnet/jnarator/build.gradle @@ -17,6 +17,11 @@ dependencies { checkstyleMain.source = 'src/main/java' 
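The new nthread setting travels from XgbNDManager into XGDMatrixCreateFromCSR together with the configured missing value. A minimal sketch of the updated CSR path, using illustrative matrix values; only the setters and the createCSR overload shown above are assumed:

```java
import ai.djl.ml.xgboost.XgbNDManager;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;

import java.nio.FloatBuffer;

public final class CsrSketch {

    public static void main(String[] args) {
        try (XgbNDManager manager = (XgbNDManager) XgbNDManager.getSystemManager().newSubManager()) {
            manager.setMissingValue(Float.NaN); // value XGBoost should treat as missing
            manager.setNthread(1);              // threads used when building the DMatrix

            // 2x3 sparse matrix in CSR form:
            // row 0 -> (0,1)=7.0 ; row 1 -> (1,0)=1.0, (1,2)=2.0
            FloatBuffer data = FloatBuffer.wrap(new float[] {7f, 1f, 2f});
            long[] indptr = {0, 1, 3};
            long[] indices = {1, 0, 2};

            NDArray csr = manager.createCSR(data, indptr, indices, new Shape(2, 3));
            // the returned array wraps a DMatrix created with the configured
            // missing value and thread count
        }
    }
}
```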
pmdMain.source = 'src/main/java' +compileJava { + options.compilerArgs.clear() + options.compilerArgs << "--release" << "11" << "-proc:none" << "-Xlint:all,-options,-static" +} + jar { manifest { attributes ( diff --git a/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java b/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java index 3105ec9cd48..ba3e18fea3b 100644 --- a/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java +++ b/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java @@ -276,6 +276,7 @@ public void writeNativeSize() throws IOException { writer.append(" public NativeSizeByReference() {\n"); writer.append(" this(new NativeSize(0));\n"); writer.append(" }\n\n"); + writer.append(" @SuppressWarnings(\"this-escape\")\n"); writer.append(" public NativeSizeByReference(NativeSize value) {\n"); writer.append(" super(NativeSize.SIZE);\n"); writer.append(" setValue(value);\n"); diff --git a/engines/mxnet/mxnet-engine/README.md b/engines/mxnet/mxnet-engine/README.md index cef559f1e31..92f94848550 100644 --- a/engines/mxnet/mxnet-engine/README.md +++ b/engines/mxnet/mxnet-engine/README.md @@ -7,7 +7,7 @@ This module contains the Deep Java Library (DJL) EngineProvider for Apache MXNet We don't recommend that developers use classes in this module directly. Use of these classes will couple your code with Apache MXNet and make switching between engines difficult. Even so, developers are not restricted from using engine-specific features. For more information, -see [NDManager#invoke()](https://javadoc.io/static/ai.djl/api/0.23.0/ai/djl/ndarray/NDManager.html#invoke-java.lang.String-ai.djl.ndarray.NDArray:A-ai.djl.ndarray.NDArray:A-ai.djl.util.PairList-). +see [NDManager#invoke()](https://javadoc.io/static/ai.djl/api/0.27.0/ai/djl/ndarray/NDManager.html#invoke-java.lang.String-ai.djl.ndarray.NDArray:A-ai.djl.ndarray.NDArray:A-ai.djl.util.PairList-). ## Documentation @@ -33,7 +33,7 @@ You can pull the MXNet engine from the central Maven repository by including the ai.djl.mxnet mxnet-engine - 0.23.0 + 0.27.0 runtime ``` diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java index 62398b1868e..b1ca8e49aa4 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java @@ -63,6 +63,7 @@ public class CachedOp extends NativeResource { * @param dataIndices the input data names required by the model and their corresponding * location */ + @SuppressWarnings("this-escape") public CachedOp( Pointer handle, MxNDManager manager, diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java index f30a6a89252..5f45116f615 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java @@ -18,8 +18,6 @@ /** {@code MxEngineProvider} is the MXNet implementation of {@link EngineProvider}. 
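The engine provider changes in this patch (LightGBM and XGBoost above, MXNet, ONNX Runtime and PaddlePaddle below) all replace the volatile field plus double-checked locking with the initialization-on-demand holder idiom. A generic sketch of the idiom, with hypothetical class names, shows why no explicit synchronization is needed:

```java
// The JVM initializes a class at most once and publishes the result safely to
// all threads, so the nested holder gives lazy, thread-safe creation without a
// volatile field or a synchronized block.
public final class ExpensiveSingleton {

    private ExpensiveSingleton() {
        // imagine loading native libraries here, as Engine.newInstance() does
    }

    public static ExpensiveSingleton getInstance() {
        return InstanceHolder.INSTANCE; // holder class is loaded on first call only
    }

    private static final class InstanceHolder {
        static final ExpensiveSingleton INSTANCE = new ExpensiveSingleton();
    }
}
```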
*/ public class MxEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -35,11 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { - synchronized (MxEngineProvider.class) { - engine = MxEngine.newInstance(); - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = MxEngine.newInstance(); } } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 87ccba78e96..8b884b3993a 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -888,6 +888,13 @@ public NDArray atan() { return manager.invoke("_npi_arctan", this, null); } + /** {@inheritDoc} */ + @Override + public NDArray atan2(NDArray other) { + other = manager.from(other); + return manager.invoke("_npi_arctan2", new NDArray[] {this, other}, null); + } + /** {@inheritDoc} */ @Override public NDArray sinh() { @@ -1153,6 +1160,18 @@ public NDArray stft( throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray fft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray reshape(Shape shape) { @@ -1601,6 +1620,12 @@ public NDArray erfinv() { return manager.invoke("erfinv", this, null); } + /** {@inheritDoc} */ + @Override + public NDArray erf() { + return manager.invoke("erf", this, null); + } + /** {@inheritDoc} */ @Override public NDArray norm(boolean keepDims) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java index 36bead164e4..952ca2f0995 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java @@ -40,6 +40,7 @@ public class MxParameterServer extends NativeResource implements Parame * * @param optimizer the optimizer to use for the parameter server updates */ + @SuppressWarnings("this-escape") public MxParameterServer(Optimizer optimizer) { super(createdKVStore()); callback = new OptimizerCallback(optimizer); diff --git a/engines/mxnet/mxnet-model-zoo/README.md b/engines/mxnet/mxnet-model-zoo/README.md index c4f44fe358c..8c03913c776 100644 --- a/engines/mxnet/mxnet-model-zoo/README.md +++ b/engines/mxnet/mxnet-model-zoo/README.md @@ -27,7 +27,7 @@ You can pull the MXNet engine from the central Maven repository by including the ai.djl.mxnet mxnet-model-zoo - 0.23.0 + 0.27.0 ``` diff --git a/engines/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/mxnet/yolo/metadata.json b/engines/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/mxnet/yolo/metadata.json index a5c3a140933..81a0fdc944a 100644 --- a/engines/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/mxnet/yolo/metadata.json +++ 
b/engines/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/mxnet/yolo/metadata.json @@ -55,7 +55,7 @@ { "version": "0.0.1", "snapshot": false, - "name": "yolo", + "name": "yolo3_darknet_voc_416", "properties": { "dataset": "voc", "version": "3", @@ -80,11 +80,13 @@ }, "symbol": { "uri": "0.0.1/yolo3_darknet53_voc-416x416/yolo-symbol.json", + "name": "yolo3_darknet_voc_416-symbol.json", "sha1Hash": "488dfc61afdb9022901673c048e3773041a20669", "size": 216997 }, "parameters": { "uri": "0.0.1/yolo3_darknet53_voc-416x416/yolo-0000.params.gz", + "name": "yolo3_darknet_voc_416-0000.params", "sha1Hash": "e71611a6eda9d475b941a3c57d6e447e54e22b6d", "size": 228664813 } @@ -93,7 +95,7 @@ { "version": "0.0.1", "snapshot": false, - "name": "yolo", + "name": "yolo3_mobilenet_voc_320", "properties": { "dataset": "voc", "version": "3", @@ -118,11 +120,13 @@ }, "symbol": { "uri": "0.0.1/yolo3_mobilenet1.0_voc-320x320/yolo-symbol.json", + "name": "yolo3_mobilenet_voc_320-symbol.json", "sha1Hash": "367e425d3ffa1fc06355dc88b96f5c0c408e224c", "size": 147800 }, "parameters": { "uri": "0.0.1/yolo3_mobilenet1.0_voc-320x320/yolo-0000.params.gz", + "name": "yolo3_mobilenet_voc_320-0000.params", "sha1Hash": "69f6935e53f69560ced1718bfa73935f9db7412d", "size": 89818905 } @@ -131,7 +135,7 @@ { "version": "0.0.1", "snapshot": false, - "name": "yolo", + "name": "yolo3_mobilenet_voc_41", "properties": { "dataset": "voc", "version": "3", @@ -156,11 +160,13 @@ }, "symbol": { "uri": "0.0.1/yolo3_mobilenet1.0_voc-416x416/yolo-symbol.json", + "name": "yolo3_mobilenet_voc_416-symbol.json", "sha1Hash": "1f537495fd8ad952d4c7a3bc3160583a55269469", "size": 147800 }, "parameters": { "uri": "0.0.1/yolo3_mobilenet1.0_voc-416x416/yolo-0000.params.gz", + "name": "yolo3_mobilenet_voc_416-0000.params", "sha1Hash": "3a5bedb5122c970375d4ee10a78e990832fda1cb", "size": 89818919 } @@ -169,7 +175,7 @@ { "version": "0.0.1", "snapshot": false, - "name": "yolo", + "name": "yolo3_darknet_coco_320", "properties": { "dataset": "coco", "version": "3", @@ -194,11 +200,13 @@ }, "symbol": { "uri": "0.0.1/yolo3_darknet53_coco-320x320/yolo-symbol.json", + "name": "yolo3_darknet_coco_320-symbol.json", "sha1Hash": "17e60b0b141d81fb5534dec02252fdf9364a1087", "size": 217009 }, "parameters": { "uri": "0.0.1/yolo3_darknet53_coco-320x320/yolo-0000.params.gz", + "name": "yolo3_darknet_coco_320-0000.params", "sha1Hash": "06c5ddb4c6daf1839fed15d5566e49968edf60b5", "size": 229889985 } @@ -207,7 +215,7 @@ { "version": "0.0.1", "snapshot": false, - "name": "yolo", + "name": "yolo3_darknet_coco_416", "properties": { "dataset": "coco", "version": "3", @@ -232,11 +240,13 @@ }, "symbol": { "uri": "0.0.1/yolo3_darknet53_coco-416x416/yolo-symbol.json", + "name": "yolo3_darknet_coco_416-symbol.json", "sha1Hash": "ccb6cc9e479e12992059f3196ce55cda9bfb6d3e", "size": 217009 }, "parameters": { "uri": "0.0.1/yolo3_darknet53_coco-416x416/yolo-0000.params.gz", + "name": "yolo3_darknet_coco_416-0000.params", "sha1Hash": "b290675ce6b79eb35fc315c475d82423fa7621c1", "size": 229889985 } @@ -245,7 +255,7 @@ { "version": "0.0.1", "snapshot": false, - "name": "yolo", + "name": "yolo3_darknet_coco_608", "properties": { "dataset": "coco", "version": "3", @@ -270,11 +280,13 @@ }, "symbol": { "uri": "0.0.1/yolo3_darknet53_coco-608x608/yolo-symbol.json", + "name": "yolo3_darknet_coco_608-symbol.json", "sha1Hash": "a7cb07555e06571007516298dc1f238bc90baf72", "size": 217009 }, "parameters": { "uri": "0.0.1/yolo3_darknet53_coco-608x608/yolo-0000.params.gz", + "name": 
"yolo3_darknet_coco_608-0000.params", "sha1Hash": "2efd6cd89723913d96b66642a225ea56e03e7fa2", "size": 229889985 } @@ -283,7 +295,7 @@ { "version": "0.0.1", "snapshot": false, - "name": "yolo", + "name": "yolo3_mobilenet_coco_320", "properties": { "dataset": "coco", "version": "3", @@ -308,11 +320,13 @@ }, "symbol": { "uri": "0.0.1/yolo3_mobilenet1.0_coco-320x320/yolo-symbol.json", + "name": "yolo3_mobilenet_coco_320-symbol.json", "sha1Hash": "8ac07f8169228b5e720804f36a4dadb37817f4c3", "size": 147812 }, "parameters": { "uri": "0.0.1/yolo3_mobilenet1.0_coco-320x320/yolo-0000.params.gz", + "name": "yolo3_mobilenet_coco_320-0000.params", "sha1Hash": "d9fa1ad5413abb8f8df81ba729fa7a115836f833", "size": 91257892 } @@ -321,7 +335,7 @@ { "version": "0.0.1", "snapshot": false, - "name": "yolo", + "name": "yolo3_mobilenet_coco_416", "properties": { "dataset": "coco", "version": "3", @@ -346,11 +360,13 @@ }, "symbol": { "uri": "0.0.1/yolo3_mobilenet1.0_coco-416x416/yolo-symbol.json", + "name": "yolo3_mobilenet_coco_416-symbol.json", "sha1Hash": "c6a85feca8d849fed6a82a6e70cdc351ec36027f", "size": 147812 }, "parameters": { "uri": "0.0.1/yolo3_mobilenet1.0_coco-416x416/yolo-0000.params.gz", + "name": "yolo3_mobilenet_coco_416-0000.params", "sha1Hash": "061e21037dcd5ac011190585437f4fbab4952a3b", "size": 91257867 } @@ -359,7 +375,7 @@ { "version": "0.0.1", "snapshot": false, - "name": "yolo", + "name": "yolo3_mobilenet_coco_608", "properties": { "dataset": "coco", "version": "3", @@ -384,11 +400,13 @@ }, "symbol": { "uri": "0.0.1/yolo3_mobilenet1.0_coco-608x608/yolo-symbol.json", + "name": "yolo3_mobilenet_coco_608-symbol.json", "sha1Hash": "10e47405a1744788ccb533bca20b2608770eeec3", "size": 147812 }, "parameters": { "uri": "0.0.1/yolo3_mobilenet1.0_coco-608x608/yolo-0000.params.gz", + "name": "yolo3_mobilenet_coco_608-0000.params", "sha1Hash": "f8fd4e8955ee90d4060d2544ed285b232c8085da", "size": 91257867 } diff --git a/engines/mxnet/native/build.gradle b/engines/mxnet/native/build.gradle index 3f8ee285054..dc9d6e5e12d 100644 --- a/engines/mxnet/native/build.gradle +++ b/engines/mxnet/native/build.gradle @@ -89,6 +89,7 @@ flavorNames.each { flavor -> } from file("${BINARY_ROOT}/${flavor}/${osName}") archiveClassifier = "${osName}-x86_64" + archiveBaseName = "mxnet-native-${flavor}" manifest { attributes("Automatic-Module-Name": "ai.djl.mxnet_native_${flavor}_${osName}") diff --git a/engines/onnxruntime/onnxruntime-android/README.md b/engines/onnxruntime/onnxruntime-android/README.md index e304e78d5c3..6e00ea2af60 100644 --- a/engines/onnxruntime/onnxruntime-android/README.md +++ b/engines/onnxruntime/onnxruntime-android/README.md @@ -6,13 +6,13 @@ This module contains the DJL ONNX Runtime engine for Android. ## Installation You can pull the ONNX Runtime for Android from the central Maven repository by including the following dependency: -- ai.djl.android:onnxruntime:0.23.0 +- ai.djl.android:onnxruntime:0.27.0 ```xml ai.djl.android onnxruntime - 0.23.0 + 0.27.0 runtime ``` diff --git a/engines/onnxruntime/onnxruntime-engine/README.md b/engines/onnxruntime/onnxruntime-engine/README.md index c287819d23f..36a2f1a3cd1 100644 --- a/engines/onnxruntime/onnxruntime-engine/README.md +++ b/engines/onnxruntime/onnxruntime-engine/README.md @@ -37,13 +37,13 @@ for the official ONNX Runtime project. 
## Installation You can pull the ONNX Runtime engine from the central Maven repository by including the following dependency: -- ai.djl.onnxruntime:onnxruntime-engine:0.23.0 +- ai.djl.onnxruntime:onnxruntime-engine:0.27.0 ```xml ai.djl.onnxruntime onnxruntime-engine - 0.23.0 + 0.27.0 runtime ``` @@ -61,7 +61,7 @@ Maven: ai.djl.onnxruntime onnxruntime-engine - 0.23.0 + 0.27.0 runtime @@ -73,7 +73,7 @@ Maven: com.microsoft.onnxruntime onnxruntime_gpu - 1.14.0 + 1.17.1 runtime ``` @@ -81,10 +81,10 @@ Maven: Gradle: ```groovy -implementation("ai.djl.onnxruntime:onnxruntime-engine:0.23.0") { +implementation("ai.djl.onnxruntime:onnxruntime-engine:0.27.0") { exclude group: "com.microsoft.onnxruntime", module: "onnxruntime" } -implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.14.0" +implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.17.1" ``` #### Enable TensorRT execution diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index 89599722435..43312fb18e8 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -97,7 +97,7 @@ public int getRank() { /** {@inheritDoc} */ @Override public String getVersion() { - return "1.15.1"; + return "1.17.1"; } /** {@inheritDoc} */ diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java index c673b3dcbf1..005c0fa25f1 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java @@ -18,8 +18,6 @@ /** {@code OrtEngineProvider} is the ONNX Runtime implementation of {@link EngineProvider}. 
*/ public class OrtEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -35,11 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { - synchronized (OrtEngineProvider.class) { - engine = OrtEngine.newInstance(); - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = OrtEngine.newInstance(); } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java index 86877e47a21..e8b6008cf7e 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java @@ -70,12 +70,14 @@ public void load(Path modelPath, String prefix, Map options) throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks"); } - if (prefix == null) { - prefix = modelName; + Path modelFile; + if (prefix != null) { + modelFile = findModelFile(prefix); + } else { + // search for .onnx file with folder name or "model.onnx" + modelFile = findModelFile(modelName, modelDir.toFile().getName(), "model.onnx"); } - // search for .onnx file with prefix, folder name or "model.onnx" - Path modelFile = findModelFile(prefix, modelDir.toFile().getName(), "model.onnx"); if (modelFile == null) { throw new FileNotFoundException(".onnx file not found in: " + modelPath); } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java index aa54b43f376..4e8df210d40 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java @@ -59,6 +59,7 @@ public class OrtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable * @param session the {@link OrtSession} contains the model information * @param manager the {@link NDManager} to holds the NDArray */ + @SuppressWarnings("this-escape") public OrtSymbolBlock(OrtSession session, OrtNDManager manager) { this.session = session; this.manager = manager; diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java index 9d8037cfa8b..d61cb81f1ee 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java @@ -31,6 +31,7 @@ public class OrtModelZoo extends ModelZoo { OrtModelZoo() { addModel(REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolo5s", "0.0.1")); + addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1")); addModel(REPOSITORY.model(Tabular.SOFTMAX_REGRESSION, GROUP_ID, "iris_flowers", "0.0.1")); } diff --git a/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java 
b/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java index b3d8225a898..c16070161e7 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java +++ b/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java @@ -84,14 +84,16 @@ public void testOrt() throws TranslateException, ModelException, IOException { Model m = Model.newInstance("model", "OnnxRuntime"); Path path = model.getModelPath(); - Assert.assertThrows(() -> m.load(path, null)); Assert.assertThrows(() -> m.load(path, "invalid.onnx")); - Path modelFile = path.resolve(model.getName() + ".onnx"); - m.load(modelFile); - + m.load(path, null); m.close(); + Model m2 = Model.newInstance("model", "OnnxRuntime"); + Path modelFile = path.resolve(model.getName() + ".onnx"); + m2.load(modelFile); + m2.close(); + // Test load model from stream Model stream = Model.newInstance("model", "OnnxRuntime"); try (InputStream is = Files.newInputStream(modelFile)) { diff --git a/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json b/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json new file mode 100644 index 00000000000..1e0169a2561 --- /dev/null +++ b/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json @@ -0,0 +1,40 @@ +{ + "metadataVersion": "0.2", + "resourceType": "model", + "application": "cv/object_detection", + "groupId": "ai.djl.onnxruntime", + "artifactId": "yolov8n", + "name": "yolov8n", + "description": "YoloV8 Model", + "website": "http://www.djl.ai/engines/onnxruntime/model-zoo", + "licenses": { + "license": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + "artifacts": [ + { + "version": "0.0.1", + "snapshot": false, + "name": "yolov8n", + "arguments": { + "width": 640, + "height": 640, + "resize": true, + "rescale": true, + "optApplyRatio": true, + "threshold": 0.6, + "translatorFactory": "ai.djl.modality.cv.translator.YoloV8TranslatorFactory" + }, + "files": { + "model": { + "uri": "0.0.1/yolov8n.zip", + "name": "", + "sha1Hash": "9fbad7f706713843cbb8c8d6a56c81a640ec6fa2", + "size": 11053839 + } + } + } + ] +} diff --git a/engines/paddlepaddle/paddlepaddle-engine/README.md b/engines/paddlepaddle/paddlepaddle-engine/README.md index 9e65fb76601..0e9643bda1a 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/README.md +++ b/engines/paddlepaddle/paddlepaddle-engine/README.md @@ -30,7 +30,7 @@ You can pull the PaddlePaddle engine from the central Maven repository by includ ai.djl.paddlepaddle paddlepaddle-engine - 0.23.0 + 0.27.0 runtime ``` diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java index e2b5bdd35a0..59e5cd90724 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java @@ -18,8 +18,6 @@ /** {@code PpEngineProvider} is the PaddlePaddle implementation of {@link EngineProvider}. 
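The new `yolov8n` entries above register a YoloV8 detection model with the ONNX Runtime model zoo (and, later in this diff, with the PyTorch zoo), wired to `YoloV8TranslatorFactory` with a 640x640 input and a 0.6 threshold. A rough end-to-end usage sketch follows; the image URL is a placeholder and the engine/group-id wiring is inferred from the metadata above, so treat this as an assumption-laden illustration rather than documented API for this artifact.

```java
import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;

import java.io.IOException;

public final class YoloV8Sketch {

    public static void main(String[] args) throws ModelException, IOException, TranslateException {
        // Placeholder input image; substitute any local file or URL.
        Image img = ImageFactory.getInstance()
                .fromUrl("https://resources.djl.ai/images/dog_bike_car.jpg");

        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.OBJECT_DETECTION)
                        .setTypes(Image.class, DetectedObjects.class)
                        .optGroupId("ai.djl.onnxruntime") // group id from the metadata above
                        .optArtifactId("yolov8n")
                        .optEngine("OnnxRuntime")
                        .build();

        try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
                Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            DetectedObjects detections = predictor.predict(img);
            System.out.println(detections);
        }
    }
}
```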
*/ public class PpEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -35,11 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { - synchronized (PpEngineProvider.class) { - engine = PpEngine.newInstance(); - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = PpEngine.newInstance(); } } diff --git a/engines/paddlepaddle/paddlepaddle-model-zoo/README.md b/engines/paddlepaddle/paddlepaddle-model-zoo/README.md index e2c9cf6036c..09aef220bf9 100644 --- a/engines/paddlepaddle/paddlepaddle-model-zoo/README.md +++ b/engines/paddlepaddle/paddlepaddle-model-zoo/README.md @@ -26,7 +26,7 @@ from the central Maven repository by including the following dependency: ai.djl.paddlepaddle paddlepaddle-model-zoo - 0.23.0 + 0.27.0 ``` diff --git a/engines/paddlepaddle/paddlepaddle-native/build.gradle b/engines/paddlepaddle/paddlepaddle-native/build.gradle index 74a573debad..de1ea58da2b 100644 --- a/engines/paddlepaddle/paddlepaddle-native/build.gradle +++ b/engines/paddlepaddle/paddlepaddle-native/build.gradle @@ -213,6 +213,7 @@ flavorNames.each { flavor -> } from file("${BINARY_ROOT}/${flavor}/${osName}") archiveClassifier = "${osName}-x86_64" + archiveBaseName = "paddlepaddle-native-${flavor}" manifest { attributes("Automatic-Module-Name": "ai.djl.paddlepaddle_native_${flavor}_${osName}") diff --git a/engines/pytorch/pytorch-engine/README.md b/engines/pytorch/pytorch-engine/README.md index ef74cf98808..a0d246626c6 100644 --- a/engines/pytorch/pytorch-engine/README.md +++ b/engines/pytorch/pytorch-engine/README.md @@ -24,13 +24,13 @@ The javadocs output is built in the `build/doc/javadoc` folder. 
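The ONNX Runtime and PaddlePaddle provider changes above replace a double-checked lock that was missing its inner null check with the initialization-on-demand holder idiom, while later hunks in this diff instead add the missing inner check to the PyTorch and TensorFlow providers. For reference, here is a minimal stand-alone sketch of both idioms; it uses a hypothetical `Engine`-like type and is not DJL code.

```java
// Generic illustration of the two lazy-singleton idioms used in this diff.
final class LazySingletons {

    interface Engine {}

    static final class MyEngine implements Engine {}

    // 1. Initialization-on-demand holder: the JVM initializes the nested holder
    //    class exactly once, on first access, with no explicit locking.
    static Engine holderInstance() {
        return Holder.INSTANCE;
    }

    private static final class Holder {
        static final Engine INSTANCE = new MyEngine();
    }

    // 2. Correct double-checked locking: the field must be volatile and the null
    //    check must be repeated inside the synchronized block, otherwise two
    //    threads racing past the first check can each create an instance.
    private static volatile Engine engine;

    static Engine doubleChecked() {
        if (engine == null) {
            synchronized (LazySingletons.class) {
                if (engine == null) {
                    engine = new MyEngine();
                }
            }
        }
        return engine;
    }
}
```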
## Installation You can pull the PyTorch engine from the central Maven repository by including the following dependency: -- ai.djl.pytorch:pytorch-engine:0.23.0 +- ai.djl.pytorch:pytorch-engine:0.27.0 ```xml ai.djl.pytorch pytorch-engine - 0.23.0 + 0.27.0 runtime ``` @@ -46,6 +46,11 @@ The following table illustrates which pytorch version that DJL supports: | PyTorch engine version | PyTorch native library version | |------------------------|-------------------------------------------| +| pytorch-engine:0.28.0 | 1.13.1, **2.1.2** | +| pytorch-engine:0.27.0 | 1.13.1, **2.1.1** | +| pytorch-engine:0.26.0 | 1.13.1, 2.0.1, **2.1.1** | +| pytorch-engine:0.25.0 | 1.11.0, 1.12.1, **1.13.1**, 2.0.1 | +| pytorch-engine:0.24.0 | 1.11.0, 1.12.1, **1.13.1**, 2.0.1 | | pytorch-engine:0.23.0 | 1.11.0, 1.12.1, **1.13.1**, 2.0.1 | | pytorch-engine:0.22.1 | 1.11.0, 1.12.1, **1.13.1**, 2.0.0 | | pytorch-engine:0.21.0 | 1.11.0, 1.12.1, **1.13.1** | @@ -110,21 +115,21 @@ export PYTORCH_FLAVOR=cpu ### macOS For macOS, you can use the following library: -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cpu:2.0.1:osx-x86_64 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.27.0 +- ai.djl.pytorch:pytorch-native-cpu:2.1.1:osx-x86_64 ```xml ai.djl.pytorch pytorch-native-cpu osx-x86_64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.27.0 runtime ``` @@ -134,21 +139,21 @@ For macOS, you can use the following library: ### macOS M1 For macOS M1, you can use the following library: -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cpu:2.0.1:osx-aarch64 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.27.0 +- ai.djl.pytorch:pytorch-native-cpu:2.1.1:osx-aarch64 ```xml ai.djl.pytorch pytorch-native-cpu osx-aarch64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.27.0 runtime ``` @@ -159,29 +164,29 @@ installed on your GPU machine, you can use one of the following library: #### Linux GPU -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cu118:2.0.1:linux-x86_64 - CUDA 11.8 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.27.0 +- ai.djl.pytorch:pytorch-native-cu121:2.1.1:linux-x86_64 - CUDA 12.1 ```xml ai.djl.pytorch - pytorch-native-cu118 + pytorch-native-cu121 linux-x86_64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.27.0 runtime ``` ### Linux CPU -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cpu:2.0.1:linux-x86_64 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.27.0 +- ai.djl.pytorch:pytorch-native-cpu:2.1.1:linux-x86_64 ```xml @@ -189,20 +194,20 @@ installed on your GPU machine, you can use one of the following library: pytorch-native-cpu linux-x86_64 runtime - 2.0.1 + 2.1.1 ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.27.0 runtime ``` ### For aarch64 build -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.0.1:linux-aarch64 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.27.0 +- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.1.1:linux-aarch64 ```xml @@ -210,12 +215,12 @@ installed on your GPU machine, you can use one of the following library: pytorch-native-cpu-precxx11 linux-aarch64 runtime - 2.0.1 + 2.1.1 ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.27.0 runtime ``` @@ -225,22 +230,22 @@ installed on your GPU machine, you can use one of the following library: We also provide packages for the system like CentOS 7/Ubuntu 14.04 with GLIBC >= 2.17. 
All the package were built with GCC 7, we provided a newer `libstdc++.so.6.24` in the package that contains `CXXABI_1.3.9` to use the package successfully. -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cu118-precxx11:2.0.1:linux-x86_64 - CUDA 11.8 -- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.0.1:linux-x86_64 - CPU +- ai.djl.pytorch:pytorch-jni:2.1.1-0.27.0 +- ai.djl.pytorch:pytorch-native-cu121-precxx11:2.1.1:linux-x86_64 - CUDA 12.1 +- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.1.1:linux-x86_64 - CPU ```xml ai.djl.pytorch - pytorch-native-cu118-precxx11 + pytorch-native-cu121-precxx11 linux-x86_64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.27.0 runtime ``` @@ -250,13 +255,13 @@ All the package were built with GCC 7, we provided a newer `libstdc++.so.6.24` i ai.djl.pytorch pytorch-native-cpu-precxx11 linux-x86_64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.27.0 runtime ``` @@ -271,29 +276,29 @@ For the Windows platform, you can choose between CPU and GPU. #### Windows GPU -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cu118:2.0.1:win-x86_64 - CUDA 11.8 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.27.0 +- ai.djl.pytorch:pytorch-native-cu121:2.1.1:win-x86_64 - CUDA 12.1 ```xml ai.djl.pytorch - pytorch-native-cu118 + pytorch-native-cu121 win-x86_64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.27.0 runtime ``` ### Windows CPU -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cpu:2.0.1:win-x86_64 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.27.0 +- ai.djl.pytorch:pytorch-native-cpu:2.1.1:win-x86_64 ```xml @@ -301,12 +306,12 @@ For the Windows platform, you can choose between CPU and GPU. 
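The coordinate bumps in this README move the PyTorch engine to libtorch 2.1.1 with CUDA 12.1 native builds. To confirm at runtime which engine and native version the loader actually resolved, a small sketch using the `Engine` API referenced elsewhere in this diff (for example in the TensorRT tests) could look like the following; the printed format is illustrative only.

```java
import ai.djl.Device;
import ai.djl.engine.Engine;

public final class EngineInfoSketch {

    public static void main(String[] args) {
        // Resolves the PyTorch engine through the EngineProvider SPI and reports
        // the native library version the loader picked (e.g. 2.1.1) and the
        // default device (CPU or GPU).
        Engine engine = Engine.getEngine("PyTorch");
        Device device = engine.defaultDevice();
        System.out.println("engine=" + engine.getEngineName()
                + " version=" + engine.getVersion()
                + " device=" + device);
    }
}
```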
pytorch-native-cpu win-x86_64 runtime - 2.0.1 + 2.1.1 ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.27.0 runtime ``` diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java index 57ae6c09d34..42ca3c5b8a5 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java @@ -37,7 +37,9 @@ public int getEngineRank() { public Engine getEngine() { if (engine == null) { synchronized (PtEngineProvider.class) { - engine = PtEngine.newInstance(); + if (engine == null) { + engine = PtEngine.newInstance(); + } } } return engine; diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index e72e98c9495..e409918a091 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -18,6 +18,7 @@ import ai.djl.Model; import ai.djl.ndarray.types.DataType; import ai.djl.nn.Parameter; +import ai.djl.nn.Parameter.Type; import ai.djl.pytorch.jni.JniUtils; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -32,6 +33,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.function.Predicate; @@ -64,13 +66,17 @@ public void load(Path modelPath, String prefix, Map options) throws IOException, MalformedModelException { setModelDir(modelPath); wasLoaded = true; - if (prefix == null) { + + Path modelFile; + if (prefix != null) { + modelFile = findModelFile(prefix); + } else { + // search for .pt file with modelName, folder name or "model.pt" + modelFile = findModelFile(modelName, modelDir.toFile().getName(), "model.pt"); prefix = modelName; } if (block == null) { - // search for .pt file with prefix, folder name or "model.pt" - Path modelFile = findModelFile(prefix, modelDir.toFile().getName(), "model.pt"); if (modelFile == null) { String fileName = prefix.endsWith(".pt") ? 
prefix : prefix + ".pt"; throw new FileNotFoundException(fileName + " file not found in: " + modelDir); @@ -131,7 +137,8 @@ public void load(Path modelPath, String prefix, Map options) /** {@inheritDoc} */ @Override - public void load(InputStream modelStream, Map options) throws IOException { + public void load(InputStream modelStream, Map options) + throws IOException, MalformedModelException { boolean mapLocation = false; if (options != null) { mapLocation = Boolean.parseBoolean((String) options.get("mapLocation")); @@ -145,11 +152,26 @@ public void load(InputStream modelStream, Map options) throws IOExcep * @param modelStream the stream of the model file * @param mapLocation force load to specified device if true * @throws IOException model loading error + * @throws MalformedModelException if model file is corrupted */ - public void load(InputStream modelStream, boolean mapLocation) throws IOException { - modelDir = Files.createTempDirectory("pt-model"); - modelDir.toFile().deleteOnExit(); - block = JniUtils.loadModule((PtNDManager) manager, modelStream, mapLocation, false); + public void load(InputStream modelStream, boolean mapLocation) + throws IOException, MalformedModelException { + wasLoaded = true; + if (block == null) { + modelDir = Files.createTempDirectory("pt-model"); + modelDir.toFile().deleteOnExit(); + block = JniUtils.loadModule((PtNDManager) manager, modelStream, mapLocation, false); + + /* + * By default, the parameters are frozen, since the previous version before adding this + * trainParam, they were frozen due to the setting JITCallGuard guard, which disables + * autograd. Also, the pretrained parameters usually should not be updated too much. It + * is safe to freeze it. Users may unfreeze it and set their learning rate small. + */ + block.freezeParameters(true); + } else { + readParameters(modelStream, Collections.emptyMap()); + } } private Path findModelFile(String... 
prefixes) { @@ -189,7 +211,9 @@ public Trainer newTrainer(TrainingConfig trainingConfig) { } if (wasLoaded) { // Unfreeze parameters if training directly - block.freezeParameters(false); + block.freezeParameters( + false, + p -> p.getType() != Type.RUNNING_MEAN && p.getType() != Type.RUNNING_VAR); } for (Pair> pair : initializer) { if (pair.getKey() != null && pair.getValue() != null) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 9e36ec35884..551a16d0359 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -60,6 +60,7 @@ public class PtNDArray extends NativeResource implements NDArray { * @param manager the manager to attach the new array to * @param handle the pointer to the native PyTorch memory */ + @SuppressWarnings("this-escape") public PtNDArray(PtNDManager manager, long handle) { super(handle); this.manager = manager; @@ -76,6 +77,7 @@ public PtNDArray(PtNDManager manager, long handle) { * @param handle the pointer to the native PyTorch memory * @param data the direct buffer of the data */ + @SuppressWarnings("this-escape") public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { super(handle); this.manager = manager; @@ -93,10 +95,12 @@ public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { * @param strs the string array * @param shape the {@link Shape} of the {@link NDArray} */ + @SuppressWarnings("this-escape") public PtNDArray(PtNDManager manager, String[] strs, Shape shape) { super(-1L); this.manager = manager; this.strs = strs; + this.sparseFormat = SparseFormat.DENSE; this.shape = shape; this.dataType = DataType.STRING; NDScope.register(this); @@ -222,6 +226,10 @@ public NDArray stopGradient() { /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { + if (getDataType() == DataType.STRING) { + throw new UnsupportedOperationException( + "toByteBuffer is not supported for String tensor."); + } return JniUtils.getByteBuffer(this); } @@ -426,6 +434,9 @@ public boolean contentEquals(NDArray other) { if (getDataType() != other.getDataType()) { return false; } + if (getDataType() == DataType.STRING) { + return Arrays.equals(toStringArray(), other.toStringArray()); + } return JniUtils.contentEqual(this, manager.from(other)); } @@ -888,6 +899,12 @@ public PtNDArray atan() { return JniUtils.atan(this); } + /** {@inheritDoc} */ + @Override + public PtNDArray atan2(NDArray other) { + return JniUtils.atan2(this, manager.from(other)); + } + /** {@inheritDoc} */ @Override public PtNDArray sinh() { @@ -1097,6 +1114,18 @@ public NDArray stft( this, nFft, hopLength, (PtNDArray) window, center, normalize, returnComplex); } + /** {@inheritDoc} */ + @Override + public NDArray fft2(long[] sizes, long[] axes) { + return JniUtils.fft2(this, sizes, axes); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + return JniUtils.ifft2(this, sizes, axes); + } + /** {@inheritDoc} */ @Override public PtNDArray reshape(Shape shape) { @@ -1539,6 +1568,12 @@ public PtNDArray erfinv() { return JniUtils.erfinv(this); } + /** {@inheritDoc} */ + @Override + public PtNDArray erf() { + return JniUtils.erf(this); + } + /** {@inheritDoc} */ @Override public PtNDArray inverse() { diff --git 
a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java index fa4ee81f26c..b7f92cbd1c3 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java @@ -13,6 +13,7 @@ package ai.djl.pytorch.engine; import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.NDUtils; @@ -24,6 +25,8 @@ import ai.djl.nn.recurrent.RNN; import ai.djl.pytorch.jni.JniUtils; +import java.util.Arrays; +import java.util.Comparator; import java.util.List; /** {@code PtNDArrayEx} is the PyTorch implementation of the {@link NDArrayEx}. */ @@ -760,7 +763,152 @@ public NDList multiBoxDetection( float nmsThreshold, boolean forceSuppress, int nmsTopK) { - throw new UnsupportedOperationException("Not implemented"); + assert (inputs.size() == 3); + + NDArray clsProb = inputs.get(0); + NDArray locPred = inputs.get(1); + NDArray anchors = inputs.get(2).reshape(new Shape(-1, 4)); + + NDManager ndManager = array.getManager(); + + NDArray variances = ndManager.create(new float[] {0.1f, 0.1f, 0.2f, 0.2f}); + + assert (variances.size() == 4); // << "Variance size must be 4"; + final int numClasses = (int) clsProb.size(1); + final int numAnchors = (int) clsProb.size(2); + final int numBatches = (int) clsProb.size(0); + + final float[] pAnchor = anchors.toFloatArray(); + + // [id, prob, xmin, ymin, xmax, ymax] + // TODO Move to NDArray-based implementation + NDList batchOutputs = new NDList(); + for (int nbatch = 0; nbatch < numBatches; ++nbatch) { + float[][] outputs = new float[numAnchors][6]; + final float[] pClsProb = clsProb.get(nbatch).toFloatArray(); + final float[] pLocPred = locPred.get(nbatch).toFloatArray(); + + for (int i = 0; i < numAnchors; ++i) { + // find the predicted class id and probability + float score = -1; + int id = 0; + for (int j = 1; j < numClasses; ++j) { + float temp = pClsProb[j * numAnchors + i]; + if (temp > score) { + score = temp; + id = j; + } + } + + if (id > 0 && score < threshold) { + id = 0; + } + + // [id, prob, xmin, ymin, xmax, ymax] + outputs[i][0] = id - 1; + outputs[i][1] = score; + int offset = i * 4; + float[] pAnchorRow4 = new float[4]; + pAnchorRow4[0] = pAnchor[offset]; + pAnchorRow4[1] = pAnchor[offset + 1]; + pAnchorRow4[2] = pAnchor[offset + 2]; + pAnchorRow4[3] = pAnchor[offset + 3]; + float[] pLocPredRow4 = new float[4]; + pLocPredRow4[0] = pLocPred[offset]; + pLocPredRow4[1] = pLocPred[offset + 1]; + pLocPredRow4[2] = pLocPred[offset + 2]; + pLocPredRow4[3] = pLocPred[offset + 3]; + float[] outRowLast4 = + transformLocations( + pAnchorRow4, + pLocPredRow4, + clip, + variances.toFloatArray()[0], + variances.toFloatArray()[1], + variances.toFloatArray()[2], + variances.toFloatArray()[3]); + outputs[i][2] = outRowLast4[0]; + outputs[i][3] = outRowLast4[1]; + outputs[i][4] = outRowLast4[2]; + outputs[i][5] = outRowLast4[3]; + } + + outputs = + Arrays.stream(outputs) + .filter(o -> o[0] >= 0) + .sorted(Comparator.comparing(o -> -o[1])) + .toArray(float[][]::new); + + // apply nms + for (int i = 0; i < outputs.length; ++i) { + for (int j = i + 1; j < outputs.length; ++j) { + if (outputs[i][0] == outputs[j][0]) { + float[] outputsIRow4 = new float[4]; + float[] outputsJRow4 = new float[4]; + outputsIRow4[0] = outputs[i][2]; + outputsIRow4[1] = outputs[i][3]; + 
outputsIRow4[2] = outputs[i][4]; + outputsIRow4[3] = outputs[i][5]; + outputsJRow4[0] = outputs[j][2]; + outputsJRow4[1] = outputs[j][3]; + outputsJRow4[2] = outputs[j][4]; + outputsJRow4[3] = outputs[j][5]; + float iou = calculateOverlap(outputsIRow4, outputsJRow4); + if (iou >= nmsThreshold) { + outputs[j][0] = -1; + } + } + } + } + batchOutputs.add(ndManager.create(outputs)); + } // end iter batch + + NDArray pOutNDArray = NDArrays.stack(batchOutputs); + NDList resultNDList = new NDList(); + resultNDList.add(pOutNDArray); + assert (resultNDList.size() == 1); + return resultNDList; + } + + private float[] transformLocations( + final float[] anchors, + final float[] locPred, + final boolean clip, + final float vx, + final float vy, + final float vw, + final float vh) { + float[] outRowLast4 = new float[4]; + // transform predictions to detection results + float al = anchors[0]; + float at = anchors[1]; + float ar = anchors[2]; + float ab = anchors[3]; + float aw = ar - al; + float ah = ab - at; + float ax = (al + ar) / 2.f; + float ay = (at + ab) / 2.f; + float px = locPred[0]; + float py = locPred[1]; + float pw = locPred[2]; + float ph = locPred[3]; + float ox = px * vx * aw + ax; + float oy = py * vy * ah + ay; + float ow = (float) (Math.exp(pw * vw) * aw / 2); + float oh = (float) (Math.exp(ph * vh) * ah / 2); + outRowLast4[0] = clip ? Math.max(0f, Math.min(1f, ox - ow)) : (ox - ow); + outRowLast4[1] = clip ? Math.max(0f, Math.min(1f, oy - oh)) : (oy - oh); + outRowLast4[2] = clip ? Math.max(0f, Math.min(1f, ox + ow)) : (ox + ow); + outRowLast4[3] = clip ? Math.max(0f, Math.min(1f, oy + oh)) : (oy + oh); + return outRowLast4; + } + + private float calculateOverlap(final float[] a, final float[] b) { + float w = Math.max(0f, Math.min(a[2], b[2]) - Math.max(a[0], b[0])); + float h = Math.max(0f, Math.min(a[3], b[3]) - Math.max(a[1], b[1])); + float i = w * h; + float u = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - i; + return u <= 0.f ? 
0f : (i / u); } /** {@inheritDoc} */ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java index 8bc28a2c21b..7075cb05efa 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java @@ -67,6 +67,7 @@ public class PtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable * @param manager the manager to use for the block * @param handle the module handle */ + @SuppressWarnings("this-escape") public PtSymbolBlock(PtNDManager manager, long handle) { this(manager); this.handle = new AtomicReference<>(handle); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index aad38ae8f0c..40a6a0065bc 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -1040,6 +1040,18 @@ public static PtNDArray stft( return new PtNDArray(ndArray.getManager(), handle); } + public static PtNDArray fft2(PtNDArray ndArray, long[] sizes, long[] axes) { + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchFft2(ndArray.getHandle(), sizes, axes)); + } + + public static PtNDArray ifft2(PtNDArray ndArray, long[] sizes, long[] axes) { + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchIfft2(ndArray.getHandle(), sizes, axes)); + } + public static PtNDArray real(PtNDArray ndArray) { long handle = PyTorchLibrary.LIB.torchViewAsReal(ndArray.getHandle()); if (handle == -1) { @@ -1145,6 +1157,12 @@ public static PtNDArray atan(PtNDArray ndArray) { ndArray.getManager(), PyTorchLibrary.LIB.torchAtan(ndArray.getHandle())); } + public static PtNDArray atan2(PtNDArray self, PtNDArray other) { + return new PtNDArray( + self.getManager(), + PyTorchLibrary.LIB.torchAtan2(self.getHandle(), other.getHandle())); + } + public static PtNDArray sqrt(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle())); @@ -1334,6 +1352,11 @@ public static PtNDArray erfinv(PtNDArray ndArray) { ndArray.getManager(), PyTorchLibrary.LIB.torchErfinv(ndArray.getHandle())); } + public static PtNDArray erf(PtNDArray ndArray) { + return new PtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchErf(ndArray.getHandle())); + } + public static PtNDArray inverse(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchInverse(ndArray.getHandle())); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java index 9d422463910..83e0f5b5b95 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java @@ -65,6 +65,7 @@ public final class LibUtils { private static final Pattern VERSION_PATTERN = Pattern.compile("(\\d+\\.\\d+\\.\\d+(-[a-z]+)?)(-SNAPSHOT)?(-\\d+)?"); + private static final Pattern LIB_PATTERN = Pattern.compile("(.*\\.(so(\\.\\d+)*|dll|dylib))"); private static LibTorch libTorch; @@ -106,10 +107,19 @@ public static String getLibtorchPath() { private static void loadLibTorch(LibTorch libTorch) { Path 
libDir = libTorch.dir.toAbsolutePath(); - if ("1.8.1".equals(getVersion()) && System.getProperty("os.name").startsWith("Mac")) { - // PyTorch 1.8.1 libtorch_cpu.dylib cannot be loaded individually - return; + if (Files.exists(libDir.resolve("libstdc++.so.6"))) { + String libstd = Utils.getEnvOrSystemProperty("LIBSTDCXX_LIBRARY_PATH"); + if (libstd != null) { + try { + logger.info("Loading libstdc++.so.6 from: {}", libstd); + System.load(libstd); + } catch (UnsatisfiedLinkError e) { + logger.warn("Failed Loading libstdc++.so.6 from: {}", libstd); + } + } } + String libExclusion = Utils.getEnvOrSystemProperty("PYTORCH_LIBRARY_EXCLUSION", ""); + Set exclusion = new HashSet<>(Arrays.asList(libExclusion.split(","))); boolean isCuda = libTorch.flavor.contains("cu"); List deferred = Arrays.asList( @@ -120,6 +130,7 @@ private static void loadLibTorch(LibTorch libTorch) { System.mapLibraryName("torch_cuda_cpp"), System.mapLibraryName("torch_cuda_cu"), System.mapLibraryName("torch_cuda"), + System.mapLibraryName("nvfuser_codegen"), System.mapLibraryName("torch")); Set loadLater = new HashSet<>(deferred); @@ -128,12 +139,16 @@ private static void loadLibTorch(LibTorch libTorch) { paths.filter( path -> { String name = path.getFileName().toString(); - if (!isCuda + if (!LIB_PATTERN.matcher(name).matches() + || exclusion.contains(name)) { + return false; + } else if (!isCuda && name.contains("nvrtc") && name.contains("cudart") && name.contains("nvTools")) { return false; - } else if (name.startsWith("libarm_compute-")) { + } else if (name.startsWith("libarm_compute-") + || name.startsWith("libopenblasp")) { rank.put(path, 2); return true; } else if (name.startsWith("libarm_compute_")) { @@ -219,10 +234,21 @@ private static Path findJniLibrary(LibTorch libTorch) { String djlVersion = libTorch.apiVersion; String flavor = libTorch.flavor; + // Looking for JNI in libTorch.dir first + Path libDir = libTorch.dir.toAbsolutePath(); + Path path = libDir.resolve(djlVersion + '-' + JNI_LIB_NAME); + if (Files.exists(path)) { + return path; + } + path = libDir.resolve(JNI_LIB_NAME); + if (Files.exists(path)) { + return path; + } + // always use cache dir, cache dir might be different from libTorch.dir Path cacheDir = Utils.getEngineCacheDir("pytorch"); Path dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier); - Path path = dir.resolve(djlVersion + '-' + JNI_LIB_NAME); + path = dir.resolve(djlVersion + '-' + JNI_LIB_NAME); if (Files.exists(path)) { return path; } @@ -349,8 +375,9 @@ private static void loadNativeLibrary(String path) { String nativeHelper = System.getProperty("ai.djl.pytorch.native_helper"); if (nativeHelper != null && !nativeHelper.isEmpty()) { ClassLoaderUtils.nativeLoad(nativeHelper, path); + } else { + System.load(path); // NOPMD } - System.load(path); // NOPMD } private static LibTorch downloadPyTorch(Platform platform) { @@ -541,8 +568,10 @@ private static final class LibTorch { if (flavor == null || flavor.isEmpty()) { if (CudaUtils.getGpuCount() > 0) { flavor = "cu" + CudaUtils.getCudaVersionString() + "-precxx11"; - } else { + } else if ("linux".equals(platform.getOsPrefix())) { flavor = "cpu-precxx11"; + } else { + flavor = "cpu"; } } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index c0f7b553ab2..54fc5419145 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ 
b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -273,6 +273,10 @@ native long torchStft( boolean normalize, boolean returnComplex); + native long torchFft2(long handle, long[] sizes, long[] axes); + + native long torchIfft2(long handle, long[] sizes, long[] axes); + native long torchViewAsReal(long handle); native long torchViewAsComplex(long handle); @@ -332,6 +336,8 @@ native long[] torchUnique( native long torchAtan(long handle); + native long torchAtan2(long self, long other); + native long torchSqrt(long handle); native long torchSinh(long handle); @@ -405,6 +411,8 @@ native long tensorUniform( native long torchErfinv(long handle); + native long torchErf(long handle); + native long torchInverse(long self); native long torchNNInterpolate(long handle, long[] size, int mode, boolean alignCorners); diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/LibUtilsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java similarity index 73% rename from engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/LibUtilsTest.java rename to engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java index 617d2cfb809..f6cfda91106 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/LibUtilsTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java @@ -18,17 +18,21 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -public class LibUtilsTest { +// Ensure this test run first +public class ALibUtilsTest { @BeforeClass public void setup() { - System.setProperty( - "ai.djl.pytorch.native_helper", "ai.djl.pytorch.integration.LibUtilsTest"); + System.setProperty("ai.djl.pytorch.native_helper", ALibUtilsTest.class.getName()); + System.setProperty("STDCXX_LIBRARY_PATH", "/usr/lib/non-exists"); + System.setProperty("PYTORCH_PRECXX11", "true"); } @AfterClass public void teardown() { System.clearProperty("ai.djl.pytorch.native_helper"); + System.clearProperty("LIBSTDCXX_LIBRARY_PATH"); + System.clearProperty("PYTORCH_PRECXX11"); } @Test diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java index 8b4e2326f26..e8f6e5d405f 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java @@ -13,6 +13,7 @@ package ai.djl.pytorch.integration; import ai.djl.Device; +import ai.djl.modality.Classifications; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; @@ -21,6 +22,9 @@ import org.testng.SkipException; import org.testng.annotations.Test; +import java.util.Arrays; +import java.util.List; + public class MpsTest { @Test @@ -36,4 +40,39 @@ public void testMps() { Assert.assertEquals(array.getDevice().getDeviceType(), "mps"); } } + + private static boolean checkMpsCompatible() { + return "aarch64".equals(System.getProperty("os.arch")) + && System.getProperty("os.name").startsWith("Mac"); + } + + @Test + public void testToTensorMPS() { + if (!checkMpsCompatible()) { + throw new SkipException("MPS toTensor test requires Apple Silicon macOS."); + } + + // Test that toTensor does not fail on MPS (e.g. 
due to use of float64 for division) + try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) { + NDArray array = manager.create(127f).reshape(1, 1, 1, 1); + NDArray tensor = array.getNDArrayInternal().toTensor(); + Assert.assertEquals(tensor.toFloatArray(), new float[] {127f / 255f}); + } + } + + @Test + public void testClassificationsMPS() { + if (!checkMpsCompatible()) { + throw new SkipException("MPS classification test requires Apple Silicon macOS."); + } + + // Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to + // float64) + try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) { + List names = Arrays.asList("First", "Second", "Third", "Fourth", "Fifth"); + NDArray tensor = manager.create(new float[] {0f, 0.125f, 1f, 0.5f, 0.25f}); + Classifications classifications = new Classifications(names, tensor); + Assert.assertEquals(classifications.best().getClassName(), "Third"); + } + } } diff --git a/engines/pytorch/pytorch-jni/build.gradle b/engines/pytorch/pytorch-jni/build.gradle index 450c832e803..c2b0ee9dc7b 100644 --- a/engines/pytorch/pytorch-jni/build.gradle +++ b/engines/pytorch/pytorch-jni/build.gradle @@ -24,7 +24,13 @@ processResources { "osx-x86_64/cpu/libdjl_torch.dylib", "win-x86_64/cpu/djl_torch.dll" ] - if (ptVersion.startsWith("2.0.")) { + if (ptVersion.startsWith("2.1.")) { + files.add("linux-aarch64/cpu-precxx11/libdjl_torch.so") + files.add("linux-x86_64/cu121/libdjl_torch.so") + files.add("linux-x86_64/cu121-precxx11/libdjl_torch.so") + files.add("win-x86_64/cu121/djl_torch.dll") + files.add("osx-aarch64/cpu/libdjl_torch.dylib") + } else if (ptVersion.startsWith("2.0.")) { files.add("linux-aarch64/cpu-precxx11/libdjl_torch.so") files.add("linux-x86_64/cu118/libdjl_torch.so") files.add("linux-x86_64/cu118-precxx11/libdjl_torch.so") diff --git a/engines/pytorch/pytorch-model-zoo/README.md b/engines/pytorch/pytorch-model-zoo/README.md index 8d3113842e1..41f677fdd6c 100644 --- a/engines/pytorch/pytorch-model-zoo/README.md +++ b/engines/pytorch/pytorch-model-zoo/README.md @@ -25,7 +25,7 @@ You can pull the PyTorch engine from the central Maven repository by including t ai.djl.pytorch pytorch-model-zoo - 0.23.0 + 0.27.0 ``` diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java index ea70871eff0..abb820cced9 100644 --- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java +++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java @@ -38,6 +38,7 @@ public class PtModelZoo extends ModelZoo { REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov5s", "0.0.1")); + addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1")); addModel(REPOSITORY.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa", "0.0.1")); addModel(REPOSITORY.model(NLP.SENTIMENT_ANALYSIS, GROUP_ID, "distilbert", "0.0.1")); addModel(REPOSITORY.model(CV.IMAGE_GENERATION, GROUP_ID, "biggan-deep", "0.0.1")); diff --git a/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/yolov8n/metadata.json b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/yolov8n/metadata.json new file 
mode 100644 index 00000000000..399b79b4889 --- /dev/null +++ b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/yolov8n/metadata.json @@ -0,0 +1,40 @@ +{ + "metadataVersion": "0.2", + "resourceType": "model", + "application": "cv/object_detection", + "groupId": "ai.djl.pytorch", + "artifactId": "yolov8n", + "name": "yolov8n", + "description": "YoloV8 Model", + "website": "http://www.djl.ai/engines/onnxruntime/model-zoo", + "licenses": { + "license": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + "artifacts": [ + { + "version": "0.0.1", + "snapshot": false, + "name": "yolov8n", + "arguments": { + "width": 640, + "height": 640, + "resize": true, + "rescale": true, + "optApplyRatio": true, + "threshold": 0.6, + "translatorFactory": "ai.djl.modality.cv.translator.YoloV8TranslatorFactory" + }, + "files": { + "model": { + "uri": "0.0.1/yolov8n.zip", + "name": "", + "sha1Hash": "a868778452ef8d6d2f9cb7109a9e14a64e851d48", + "size": 11183356 + } + } + } + ] +} diff --git a/engines/pytorch/pytorch-native/CMakeLists.txt b/engines/pytorch/pytorch-native/CMakeLists.txt index 4453186be6f..c53d71dc93e 100644 --- a/engines/pytorch/pytorch-native/CMakeLists.txt +++ b/engines/pytorch/pytorch-native/CMakeLists.txt @@ -60,11 +60,12 @@ if(USE_CUDA) endif() add_library(djl_torch SHARED ${SOURCE_FILES}) +set_property(TARGET djl_torch PROPERTY CXX_STANDARD 17) + # build host if(NOT BUILD_ANDROID) target_link_libraries(djl_torch "${TORCH_LIBRARIES}") target_include_directories(djl_torch PUBLIC build/include ${JNI_INCLUDE_DIRS} ${UTILS_INCLUDE_DIR}) - set_property(TARGET djl_torch PROPERTY CXX_STANDARD 14) # We have to kill the default rpath and use current dir set(CMAKE_SKIP_RPATH TRUE) if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") diff --git a/engines/pytorch/pytorch-native/build.gradle b/engines/pytorch/pytorch-native/build.gradle index b4a195e109f..99a658bf3ed 100644 --- a/engines/pytorch/pytorch-native/build.gradle +++ b/engines/pytorch/pytorch-native/build.gradle @@ -24,6 +24,8 @@ if (project.hasProperty("cu11")) { FLAVOR = "cu117" } else if (VERSION.startsWith("2.0.")) { FLAVOR = "cu118" + } else if (VERSION.startsWith("2.1.")) { + FLAVOR = "cu121" } else { throw new GradleException("Unsupported PyTorch version: ${VERSION}") } @@ -88,15 +90,17 @@ def prepareNativeLib(String binaryRoot, String ver) { def officialPytorchUrl = "https://download.pytorch.org/libtorch" def aarch64PytorchUrl = "https://djl-ai.s3.amazonaws.com/publish/pytorch" - String cu11 + String cuda if (ver.startsWith("1.11.")) { - cu11 = "cu113" + cuda = "cu113" } else if (ver.startsWith("1.12.")) { - cu11 = "cu116" + cuda = "cu116" } else if (ver.startsWith("1.13.")) { - cu11 = "cu117" + cuda = "cu117" } else if (ver.startsWith("2.0.")) { - cu11 = "cu118" + cuda = "cu118" + } else if (ver.startsWith("2.1.")) { + cuda = "cu121" } else { throw new GradleException("Unsupported PyTorch version: ${ver}") } @@ -105,10 +109,10 @@ def prepareNativeLib(String binaryRoot, String ver) { "cpu/libtorch-cxx11-abi-shared-with-deps-${ver}%2Bcpu.zip" : "cpu/linux-x86_64", "cpu/libtorch-macos-${ver}.zip" : "cpu/osx-x86_64", "cpu/libtorch-win-shared-with-deps-${ver}%2Bcpu.zip" : "cpu/win-x86_64", - "${cu11}/libtorch-cxx11-abi-shared-with-deps-${ver}%2B${cu11}.zip": "${cu11}/linux-x86_64", - "${cu11}/libtorch-win-shared-with-deps-${ver}%2B${cu11}.zip" : "${cu11}/win-x86_64", + "${cuda}/libtorch-cxx11-abi-shared-with-deps-${ver}%2B${cuda}.zip": 
"${cuda}/linux-x86_64", + "${cuda}/libtorch-win-shared-with-deps-${ver}%2B${cuda}.zip" : "${cuda}/win-x86_64", "cpu/libtorch-shared-with-deps-${ver}%2Bcpu.zip" : "cpu-precxx11/linux-x86_64", - "${cu11}/libtorch-shared-with-deps-${ver}%2B${cu11}.zip" : "${cu11}-precxx11/linux-x86_64" + "${cuda}/libtorch-shared-with-deps-${ver}%2B${cuda}.zip" : "${cuda}-precxx11/linux-x86_64" ] def aarch64Files = [ @@ -138,17 +142,12 @@ def copyNativeLibToOutputDir(Map fileStoreMap, String binaryRoot from zipTree(file) into outputDir } - // CPU dependencies - copy { - from("${outputDir}/libtorch/lib/") { - include "libc10.*", "c10.dll", "libiomp5*.*", "libarm_compute*.*", "libgomp*.*", "libnvfuser_codegen.so", "libtorch.*", "libtorch_cpu.*", "torch.dll", "torch_cpu.dll", "fbgemm.dll", "asmjit.dll", "uv.dll", "nvfuser_codegen.dll" - } - into("${outputDir}/native/lib") - } - // GPU dependencies + delete "${outputDir}/libtorch/lib/*.lib" + delete "${outputDir}/libtorch/lib/*.a" + copy { from("${outputDir}/libtorch/lib/") { - include "libtorch_cuda*.so", "torch_cuda*.dll", "libc10_cuda.so", "c10_cuda.dll", "libcaffe2_nvrtc.so", "libnvrtc*.so.*", "libcudart*.*", "*nvToolsExt*.*", "cudnn*.dll", "caffe2_nvrtc.dll", "nvrtc64*.dll", "uv.dll", "libcublas*", "zlibwapi.dll" + include "libarm_compute*", "libc10_cuda.so", "libc10.*", "libcaffe2_nvrtc.so", "libcu*", "libgfortran-*", "libgomp*", "libiomp*", "libnv*", "libopenblasp-*", "libtorch_cpu.*", "libtorch_cuda*.so", "libtorch.*", "asmjit.dll", "c10_cuda.dll", "c10.dll", "caffe2_nvrtc.dll", "cu*.dll", "fbgemm.dll", "nv*.dll", "torch_cpu.dll", "torch_cuda*.dll", "torch.dll", "uv.dll", "zlibwapi.dll" } into("${outputDir}/native/lib") } @@ -287,9 +286,9 @@ tasks.register('uploadS3') { "${BINARY_ROOT}/cpu/win-x86_64/native/lib/", "${BINARY_ROOT}/cpu-precxx11/linux-aarch64/native/lib/", "${BINARY_ROOT}/cpu-precxx11/linux-x86_64/native/lib/", - "${BINARY_ROOT}/cu118/linux-x86_64/native/lib/", - "${BINARY_ROOT}/cu118/win-x86_64/native/lib/", - "${BINARY_ROOT}/cu118-precxx11/linux-x86_64/native/lib/" + "${BINARY_ROOT}/cu121/linux-x86_64/native/lib/", + "${BINARY_ROOT}/cu121/win-x86_64/native/lib/", + "${BINARY_ROOT}/cu121-precxx11/linux-x86_64/native/lib/" ] uploadDirs.each { item -> fileTree(item).files.name.each { diff --git a/engines/pytorch/pytorch-native/build.sh b/engines/pytorch/pytorch-native/build.sh index 78c59d6bf2a..ae0456bec62 100755 --- a/engines/pytorch/pytorch-native/build.sh +++ b/engines/pytorch/pytorch-native/build.sh @@ -23,22 +23,22 @@ ARCH=$4 if [[ ! -d "libtorch" ]]; then if [[ $PLATFORM == 'linux' ]]; then - if [[ ! "$FLAVOR" =~ ^(cpu|cu102|cu113|cu116|cu117|cu118)$ ]]; then + if [[ ! "$FLAVOR" =~ ^(cpu|cu102|cu113|cu116|cu117|cu118|cu121)$ ]]; then echo "$FLAVOR is not supported." 
exit 1 fi if [[ $ARCH == 'aarch64' ]]; then - curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch${AARCH64_CXX11ABI}-shared-with-deps-${VERSION}-aarch64.zip | jar xv + curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch${AARCH64_CXX11ABI}-shared-with-deps-${VERSION}-aarch64.zip | jar xv > /dev/null else - curl -s https://download.pytorch.org/libtorch/${FLAVOR}/libtorch${CXX11ABI}-shared-with-deps-${VERSION}%2B${FLAVOR}.zip | jar xv + curl -s https://download.pytorch.org/libtorch/${FLAVOR}/libtorch${CXX11ABI}-shared-with-deps-${VERSION}%2B${FLAVOR}.zip | jar xv > /dev/null fi elif [[ $PLATFORM == 'darwin' ]]; then if [[ $ARCH == 'aarch64' ]]; then - curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch-macos-${VERSION}-aarch64.zip | jar xv + curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch-macos-${VERSION}-aarch64.zip | jar xv > /dev/null else - curl -s https://download.pytorch.org/libtorch/cpu/libtorch-macos-${VERSION}.zip | jar xv + curl -s https://download.pytorch.org/libtorch/cpu/libtorch-macos-${VERSION}.zip | jar xv > /dev/null fi else echo "$PLATFORM is not supported." @@ -62,6 +62,12 @@ mkdir classes javac -sourcepath ../../pytorch-engine/src/main/java/ ../../pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java -h include -d classes cmake -DCMAKE_PREFIX_PATH=libtorch -DPT_VERSION=${PT_VERSION} -DUSE_CUDA=$USE_CUDA .. cmake --build . --config Release -- -j "${NUM_PROC}" +if [[ "$FLAVOR" = cu* ]]; then + # avoid link with libcudart.so.11.0 + sed -i -r "s/\/usr\/local\/cuda(.{5})?\/lib64\/lib(cudart|nvrtc).so//g" CMakeFiles/djl_torch.dir/link.txt + rm libdjl_torch.so + . CMakeFiles/djl_torch.dir/link.txt +fi if [[ $PLATFORM == 'darwin' ]]; then install_name_tool -add_rpath @loader_path libdjl_torch.dylib diff --git a/engines/pytorch/pytorch-native/build_android.sh b/engines/pytorch/pytorch-native/build_android.sh index b37dd96a86d..72050b20a85 100755 --- a/engines/pytorch/pytorch-native/build_android.sh +++ b/engines/pytorch/pytorch-native/build_android.sh @@ -20,7 +20,7 @@ if [[ ! 
-d libtorch_android/"$FLAVOR" ]]; then mkdir -p libtorch_android/"$FLAVOR" cd libtorch_android/"$FLAVOR" echo "Downloading https://publish.djl.ai/pytorch/$VERSION/android_native/${FLAVOR}_native.zip" - curl -s "https://publish.djl.ai/pytorch/$VERSION/android_native/${FLAVOR}_native.zip" | jar xv + curl -s "https://publish.djl.ai/pytorch/$VERSION/android_native/${FLAVOR}_native.zip" | jar xv > /dev/null mv install/include include cd - fi diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc index 5a65e1eca69..08932098da9 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc @@ -34,6 +34,28 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft( API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft2( + JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const std::vector sizes = djl::utils::jni::GetVecFromJLongArray(env, js); + const std::vector axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes); + const auto* result_ptr = new torch::Tensor(torch::fft_fft2(*tensor_ptr, sizes, axes)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIfft2( + JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const std::vector sizes = djl::utils::jni::GetVecFromJLongArray(env, js); + const std::vector axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes); + const auto* result_ptr = new torch::Tensor(torch::fft_ifft2(*tensor_ptr, sizes, axes)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchStft(JNIEnv* env, jobject jthis, jlong jhandle, jlong jn_fft, jlong jhop_length, jlong jwindow, jboolean jcenter, jboolean jnormalize, jboolean jreturn_complex) { #ifdef V1_11_X diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc index 28e40e916be..ccf2616dc65 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc @@ -355,6 +355,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan(JNIEnv* API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan2( +JNIEnv* env, jobject jthis, jlong jself, jlong jother) { + API_BEGIN() + const auto* self_ptr = reinterpret_cast(jself); + const auto* other_ptr = reinterpret_cast(jother); + const auto* result_ptr = new torch::Tensor(self_ptr->atan2(*other_ptr)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSqrt(JNIEnv* env, jobject jthis, jlong jhandle) { API_BEGIN() const auto* tensor_ptr = reinterpret_cast(jhandle); @@ -496,6 +506,14 @@ JNIEXPORT jlong JNICALL 
Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErfinv(JNIEn API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErf(JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const auto* result_ptr = new torch::Tensor(tensor_ptr->erf()); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchInverse(JNIEnv* env, jobject jthis, jlong jself) { API_BEGIN() const auto* self_ptr = reinterpret_cast(jself); diff --git a/engines/tensorflow/tensorflow-api/README.md b/engines/tensorflow/tensorflow-api/README.md index fd2741dc9e4..12766d87669 100644 --- a/engines/tensorflow/tensorflow-api/README.md +++ b/engines/tensorflow/tensorflow-api/README.md @@ -16,6 +16,6 @@ You can pull the TensorFlow core java API from the central Maven repository by i ai.djl.tensorflow tensorflow-api - 0.23.0 + 0.27.0 ``` diff --git a/engines/tensorflow/tensorflow-engine/README.md b/engines/tensorflow/tensorflow-engine/README.md index 57bcdda98d7..17573ed7127 100644 --- a/engines/tensorflow/tensorflow-engine/README.md +++ b/engines/tensorflow/tensorflow-engine/README.md @@ -28,13 +28,13 @@ The javadocs output is built in the `build/doc/javadoc` folder. You can pull the TensorFlow engine from the central Maven repository by including the following dependency: -- ai.djl.tensorflow:tensorflow-engine:0.23.0 +- ai.djl.tensorflow:tensorflow-engine:0.27.0 ```xml ai.djl.tensorflow tensorflow-engine - 0.23.0 + 0.27.0 runtime ``` diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java index d964ea5c295..ad440a47951 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java @@ -37,7 +37,9 @@ public int getEngineRank() { public Engine getEngine() { if (engine == null) { synchronized (TfEngineProvider.class) { - engine = TfEngine.newInstance(); + if (engine == null) { + engine = TfEngine.newInstance(); + } } } return engine; diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 07c31bacd99..419be4c09f6 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -457,6 +457,12 @@ public NDArray erfinv() { return manager.opExecutor("Erfinv").addInput(this).buildSingletonOrThrow(); } + /** {@inheritDoc} */ + @Override + public NDArray erf() { + return manager.opExecutor("Erf").addInput(this).buildSingletonOrThrow(); + } + /** {@inheritDoc} */ @Override public NDArray norm(boolean keepDims) { @@ -911,6 +917,12 @@ public NDArray atan() { return manager.opExecutor("Atan").addInput(this).buildSingletonOrThrow(); } + /** {@inheritDoc} */ + @Override + public NDArray atan2(NDArray other) { + return manager.opExecutor("Atan2").addInput(this).addInput(other).buildSingletonOrThrow(); + } + /** {@inheritDoc} */ @Override public NDArray sinh() { @@ -1172,6 +1184,18 @@ public NDArray stft( throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + 
public NDArray fft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray reshape(Shape shape) { diff --git a/engines/tensorflow/tensorflow-model-zoo/README.md b/engines/tensorflow/tensorflow-model-zoo/README.md index b34154fa126..663f3ff840a 100644 --- a/engines/tensorflow/tensorflow-model-zoo/README.md +++ b/engines/tensorflow/tensorflow-model-zoo/README.md @@ -26,7 +26,7 @@ from the central Maven repository by including the following dependency: ai.djl.tensorflow tensorflow-model-zoo - 0.23.0 + 0.27.0 ``` diff --git a/engines/tensorflow/tensorflow-native/build.gradle b/engines/tensorflow/tensorflow-native/build.gradle index 8138d93334d..56cd6eed9e2 100644 --- a/engines/tensorflow/tensorflow-native/build.gradle +++ b/engines/tensorflow/tensorflow-native/build.gradle @@ -153,6 +153,7 @@ flavorNames.each { flavor -> } from file("${BINARY_ROOT}/${flavor}/${osName}") archiveClassifier = "${osName}-x86_64" + archiveBaseName = "tensorflow-native-${flavor}" manifest { attributes("Automatic-Module-Name": "ai.djl.tensorflow_native_${flavor}_${osName}") diff --git a/engines/tensorrt/CMakeLists.txt b/engines/tensorrt/CMakeLists.txt index 21c1e64d96e..6c56505d6ef 100644 --- a/engines/tensorrt/CMakeLists.txt +++ b/engines/tensorrt/CMakeLists.txt @@ -7,10 +7,10 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(JAVA_AWT_LIBRARY NotNeeded) set(JAVA_AWT_INCLUDE_PATH NotNeeded) find_package(JNI REQUIRED) -#find_library(TRT_ONNX_PARSER -# NAMES nvonnxparser -# PATH_SUFFIXES lib -# REQUIRED) +find_library(TRT_ONNX_PARSER + NAMES nvonnxparser + PATH_SUFFIXES lib + REQUIRED) find_path(UTILS_INCLUDE_DIR NAMES djl/utils.h @@ -37,4 +37,4 @@ target_include_directories(djl_trt PUBLIC main/native trt/include build/include) -target_link_libraries(djl_trt nvonnxparser nvparsers) +target_link_libraries(djl_trt nvonnxparser) diff --git a/engines/tensorrt/README.md b/engines/tensorrt/README.md index 6373386479e..f3844b18aa0 100644 --- a/engines/tensorrt/README.md +++ b/engines/tensorrt/README.md @@ -28,13 +28,13 @@ The javadocs output is generated in the `build/doc/javadoc` folder. ## Installation You can pull the TensorRT engine from the central Maven repository by including the following dependency: -- ai.djl.tensorrt:tensorrt:0.23.0 +- ai.djl.tensorrt:tensorrt:0.27.0 ```xml ai.djl.tensorrt tensorrt - 0.23.0 + 0.27.0 runtime ``` diff --git a/engines/tensorrt/build.sh b/engines/tensorrt/build.sh index c2ad26c00a2..2f31d1146bf 100755 --- a/engines/tensorrt/build.sh +++ b/engines/tensorrt/build.sh @@ -8,7 +8,8 @@ VERSION="$(cat ../../gradle.properties | awk -F '=' '/trt_version/ {print $2}')" if [ ! -d "trt" ]; then - git clone https://github.com/NVIDIA/TensorRT.git -b $VERSION trt + git clone --recurse-submodules https://github.com/NVIDIA/TensorRT.git -b v$VERSION trt + cp -f trt/parsers/onnx/NvOnnxParser.h trt/include fi if [ ! 
-d "build" ]; diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java index 05a7eceeb41..d92ed9e449d 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java @@ -18,8 +18,6 @@ /** {@code TrtEngineProvider} is the TensorRT implementation of {@link EngineProvider}. */ public class TrtEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -35,11 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { - synchronized (TrtEngineProvider.class) { - engine = TrtEngine.newInstance(); - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = TrtEngine.newInstance(); } } diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtModel.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtModel.java index 44047e0e614..6a8ddb3a54c 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtModel.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtModel.java @@ -62,7 +62,10 @@ public void load(Path modelPath, String prefix, Map options) throws I if (modelFile == null) { modelFile = findModelFile(modelDir.toFile().getName()); if (modelFile == null) { - throw new FileNotFoundException(prefix + ".* file not found in: " + modelDir); + modelFile = findModelFile("model.onnx"); + if (modelFile == null) { + throw new FileNotFoundException(prefix + ".* file not found in: " + modelDir); + } } } String filePath = modelFile.toString(); diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java index 96066b380e1..d800ca13369 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java @@ -26,9 +26,9 @@ public void getVersion() { try { Engine engine = Engine.getEngine("TensorRT"); version = engine.getVersion(); - } catch (Exception ignore) { + } catch (Throwable ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } - Assert.assertEquals(version, "8.4.1"); + Assert.assertEquals(version, "9.2.0"); } } diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java index 24d734af54c..09001f0e2da 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java @@ -28,7 +28,7 @@ public void testNDArray() { Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Exception ignore) { + } catch (Throwable ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java index 105e057ba0a..2e3215cf464 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java +++ 
b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java @@ -49,7 +49,7 @@ public void testTrtOnnx() throws ModelException, IOException, TranslateException Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Exception ignore) { + } catch (Throwable ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -70,12 +70,12 @@ public void testTrtOnnx() throws ModelException, IOException, TranslateException } } - @Test + @Test(enabled = false) public void testTrtUff() throws ModelException, IOException, TranslateException { Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Exception ignore) { + } catch (Throwable ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -107,12 +107,12 @@ public void testTrtUff() throws ModelException, IOException, TranslateException } } - @Test + @Test(enabled = false) public void testSerializedEngine() throws ModelException, IOException, TranslateException { Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Exception ignore) { + } catch (Throwable ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } Device device = engine.defaultDevice(); diff --git a/engines/tflite/tflite-engine/README.md b/engines/tflite/tflite-engine/README.md index b1dd8fc9778..6a285b50f4e 100644 --- a/engines/tflite/tflite-engine/README.md +++ b/engines/tflite/tflite-engine/README.md @@ -24,13 +24,13 @@ The javadocs output is built in the `build/doc/javadoc` folder. ## Installation You can pull the TensorFlow Lite engine from the central Maven repository by including the following dependency: -- ai.djl.tflite:tflite-engine:0.23.0 +- ai.djl.tflite:tflite-engine:0.27.0 ```xml ai.djl.tflite tflite-engine - 0.23.0 + 0.27.0 runtime ``` diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java index aa0fdb73d21..fb61551a3bf 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java @@ -18,8 +18,6 @@ /** {@code TfLiteEngineProvider} is the TFLite implementation of {@link EngineProvider}. 
*/ public class TfLiteEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -35,11 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { - synchronized (TfLiteEngineProvider.class) { - engine = TfLiteEngine.newInstance(); - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = TfLiteEngine.newInstance(); } } diff --git a/engines/tflite/tflite-native/build.gradle b/engines/tflite/tflite-native/build.gradle index eb045331c12..3e2a6008f38 100644 --- a/engines/tflite/tflite-native/build.gradle +++ b/engines/tflite/tflite-native/build.gradle @@ -155,6 +155,7 @@ flavorNames.each { flavor -> from file("src/main/resources") from file("${project.buildDir}/classes/java/main") archiveClassifier = "${osName}" + archiveBaseName = "tflite-native-${flavor}" manifest { attributes("Automatic-Module-Name": "ai.djl.tflite_native_${flavor}_${osName}") diff --git a/examples/docs/image_classification.md b/examples/docs/image_classification.md index 1f515f9680f..c8f331320a8 100644 --- a/examples/docs/image_classification.md +++ b/examples/docs/image_classification.md @@ -6,7 +6,7 @@ In this example, you learn how to implement inference code with Deep Java Librar The image classification example code can be found at [ImageClassification.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java). -You can also use the [Jupyter notebook tutorial](../../jupyter/tutorial/03_image_classification_with_your_model.ipynb). +You can also use the [Jupyter notebook tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/03_image_classification_with_your_model.html). The Jupyter notebook explains the key concepts in detail. ## Setup Guide diff --git a/examples/docs/object_detection.md b/examples/docs/object_detection.md index 7d0898128b9..84286fb6e00 100644 --- a/examples/docs/object_detection.md +++ b/examples/docs/object_detection.md @@ -7,7 +7,7 @@ In this example, you learn how to implement inference code with a [ModelZoo mode The source code can be found at [ObjectDetection.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ObjectDetection.java). -You can also use the [Jupyter notebook tutorial](../../jupyter/object_detection_with_model_zoo.ipynb). +You can also use the [Jupyter notebook tutorial](http://docs.djl.ai/docs/demos/jupyter/object_detection_with_model_zoo.html). The Jupyter notebook explains the key concepts in detail. ## Setup guide diff --git a/examples/docs/stable_diffusion.md b/examples/docs/stable_diffusion.md index 7eb544646ee..be3cbb48d6e 100644 --- a/examples/docs/stable_diffusion.md +++ b/examples/docs/stable_diffusion.md @@ -1,4 +1,4 @@ -## Stable Diffusion in DJL +# Stable Diffusion in DJL [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) is an open-source model developed by Stability.ai. It aimed to produce images (artwork, pictures, etc.) 
based on diff --git a/examples/docs/train_cifar10_resnet.md b/examples/docs/train_cifar10_resnet.md index cfaf03f8a61..1cdfcb495c2 100644 --- a/examples/docs/train_cifar10_resnet.md +++ b/examples/docs/train_cifar10_resnet.md @@ -5,7 +5,7 @@ In this example, you learn how to train the [CIFAR-10](https://www.cs.toronto.ed You can find the example source code in: [TrainResnetWithCifar10.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java). -You can also find the Jupyter notebook tutorial [here](../../jupyter/transfer_learning_on_cifar10.ipynb). +You can also find the Jupyter notebook tutorial [here](http://docs.djl.ai/docs/demos/jupyter/transfer_learning_on_cifar10.html). The Jupyter notebook explains the key concepts in detail. ## Setup guide diff --git a/examples/docs/train_mnist_mlp.md b/examples/docs/train_mnist_mlp.md index 72b591d062a..40a32ca365f 100644 --- a/examples/docs/train_mnist_mlp.md +++ b/examples/docs/train_mnist_mlp.md @@ -6,7 +6,7 @@ In this example, you learn how to train the MNIST dataset with Deep Java Library The source code for this example can be found at [TrainMnist.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/TrainMnist.java). -You can also use the [Jupyter notebook tutorial](../../jupyter/tutorial/02_train_your_first_model.ipynb). +You can also use the [Jupyter notebook tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/02_train_your_first_model.html). The Jupyter notebook explains the key concepts in detail. ## Setup guide diff --git a/examples/pom.xml b/examples/pom.xml index 9eb2ee32fa0..e6a09987174 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -5,12 +5,12 @@ ai.djl examples - 0.24.0-SNAPSHOT + 0.28.0-SNAPSHOT 11 11 - 0.24.0-SNAPSHOT + 0.28.0-SNAPSHOT ai.djl.examples.inference.ObjectDetection @@ -41,7 +41,7 @@ org.apache.logging.log4j log4j-slf4j-impl - 2.18.0 + 2.21.0 ai.djl diff --git a/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java b/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java index b667cd29f90..093e159bebb 100644 --- a/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java +++ b/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java @@ -34,9 +34,8 @@ *

 * See:
 *
- *   • the jupyter
- *     demo with more information about BERT.
+ *   • the jupyter demo with more
+ *     information about BERT.
 *
  • the docs * for information about running this example. diff --git a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java new file mode 100644 index 00000000000..3d2cfb26409 --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java @@ -0,0 +1,86 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference; + +import ai.djl.ModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.translator.YoloV8TranslatorFactory; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** An example of inference using an yolov8 model. */ +public final class Yolov8Detection { + + private static final Logger logger = LoggerFactory.getLogger(Yolov8Detection.class); + + private Yolov8Detection() {} + + public static void main(String[] args) throws IOException, ModelException, TranslateException { + DetectedObjects detection = Yolov8Detection.predict(); + logger.info("{}", detection); + } + + public static DetectedObjects predict() throws IOException, ModelException, TranslateException { + Path imgPath = Paths.get("src/test/resources/yolov8_test.jpg"); + Image img = ImageFactory.getInstance().fromFile(imgPath); + + Criteria criteria = + Criteria.builder() + .setTypes(Image.class, DetectedObjects.class) + .optModelUrls("djl://ai.djl.onnxruntime/yolov8n") + .optEngine("OnnxRuntime") + .optArgument("width", 640) + .optArgument("height", 640) + .optArgument("resize", true) + .optArgument("toTensor", true) + .optArgument("applyRatio", true) + .optArgument("threshold", 0.6f) + // for performance optimization maxBox parameter can reduce number of + // considered boxes from 8400 + .optArgument("maxBox", 1000) + .optTranslatorFactory(new YoloV8TranslatorFactory()) + .optProgress(new ProgressBar()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + Path outputPath = Paths.get("build/output"); + Files.createDirectories(outputPath); + + DetectedObjects detection = predictor.predict(img); + if (detection.getNumberOfObjects() > 0) { + img.drawBoundingBoxes(detection); + Path output = outputPath.resolve("yolov8_detected.png"); + try (OutputStream os = Files.newOutputStream(output)) { + img.save(os, "png"); + } + logger.info("Detected object saved in: {}", output); + } + return detection; + } + } +} diff --git a/examples/src/main/java/ai/djl/examples/inference/face/FaceDetectionTranslator.java 
b/examples/src/main/java/ai/djl/examples/inference/face/FaceDetectionTranslator.java index c5a04065d5f..088558d6b0b 100644 --- a/examples/src/main/java/ai/djl/examples/inference/face/FaceDetectionTranslator.java +++ b/examples/src/main/java/ai/djl/examples/inference/face/FaceDetectionTranslator.java @@ -40,8 +40,6 @@ public class FaceDetectionTranslator implements Translator CHW RGB -> BGR // The network by default takes float32 @@ -78,6 +78,10 @@ public NDList processInput(TranslatorContext ctx, Image input) { /** {@inheritDoc} */ @Override public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { + + int width = (int) ctx.getAttachment("width"); + int height = (int) ctx.getAttachment("height"); + NDManager manager = ctx.getNDManager(); double scaleXY = variance[0]; double scaleWH = variance[1]; diff --git a/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java b/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java index acbaa152f8c..59cba679ba2 100644 --- a/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java +++ b/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java @@ -59,6 +59,7 @@ public static String generateTextWithPyTorchGreedy() SearchConfig config = new SearchConfig(); config.setMaxSeqLength(60); + // You can use src/main/python/trace_gpt2.py to trace gpt2 model String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip"; Criteria criteria = @@ -160,6 +161,20 @@ public static String[] generateTextWithOnnxRuntimeBeam() long padTokenId = 220; config.setPadTokenId(padTokenId); + // The model is converted optimum: + // https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-using-past-keysvalues-in-the-decoder + /* + * optimum-cli export onnx --model gpt2 gpt2_onnx/ + * + * from transformers import AutoTokenizer + * from optimum.onnxruntime import ORTModelForCausalLM + * + * tokenizer = AutoTokenizer.from_pretrained("./gpt2_onnx/") + * model = ORTModelForCausalLM.from_pretrained("./gpt2_onnx/") + * inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt") + * gen_tokens = model.generate(**inputs) + * print(tokenizer.batch_decode(gen_tokens)) + */ String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_onnx.zip"; Criteria criteria = diff --git a/examples/src/main/python/trace_gpt2.py b/examples/src/main/python/trace_gpt2.py new file mode 100644 index 00000000000..33c3badb08d --- /dev/null +++ b/examples/src/main/python/trace_gpt2.py @@ -0,0 +1,73 @@ +import torch +from transformers import GPT2LMHeadModel, GPT2Tokenizer + +model_name = 'gpt2-large' +tokenizer = GPT2Tokenizer.from_pretrained(model_name) + +# add the EOS token as PAD token to avoid warnings +model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torchscript=True) + +# %% model_inputs +output_attentions = False +output_hidden_states = False +model_inputs = {} + +model_inputs['past_key_values'] = torch.load( + "../data/nested_tuple_" + model_name + ".pt") +past_seq = model_inputs['past_key_values'][0][0].shape[-2] +model_inputs['input_ids'] = torch.tensor([[404]]) +model_inputs['position_ids'] = torch.tensor([[past_seq]]) +# |attention_mask| = `len(past_key_values) + len(input_ids)` +model_inputs['attention_mask'] = torch.ones(past_seq + 1, dtype=torch.int64) + +model_inputs['use_cache'] = True +model_inputs['token_type_ids'] = None + +model_inputs['return_dict'] = False 
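+# Note (comment added for clarity, not in the original script): return_dict=False makes
+# the model return plain tuples instead of a ModelOutput object, which is what
+# torch.jit.trace (used further below) expects to record.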
+model_inputs['output_attentions'] = False +model_inputs['output_hidden_states'] = False + +# This is a testing of text generation +outputs = model(**model_inputs) + +# %% Wrapper class of GPT2LMHeadModel +from typing import Tuple + +class Tracable(torch.nn.Module): + def __init__(self, config: dict): + super().__init__() + self.model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torchscript=True) + self.config = {'use_cache': config.get('use_cache', True), + 'token_type_ids': config.get('token_type_ids', None), + 'return_dict': config.get('return_dict', False), + 'output_attentions': config.get('output_attentions', False), + 'output_hidden_states': config.get('output_hidden_states', True)} + + def forward(self, my_input_ids, position_ids, attention_mask, past_key_values): + return self.model(input_ids=my_input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + **self.config) # return_tensor = True + +# %% create class +config = {} +tracable = Tracable(config) +input = (model_inputs['input_ids'], + model_inputs['position_ids'], + model_inputs['attention_mask'], + model_inputs['past_key_values']) + +output = tracable(*input) + +# %% trace +tracable.eval() + +traced_model = torch.jit.trace(tracable, input) +torch.jit.save(traced_model, "../traced_GPT2_hidden.pt") + +out1 = traced_model(*input) + +# %% load back +loaded_model = torch.jit.load("../traced_GPT2_hidden.pt") +out2 = loaded_model(*input) diff --git a/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java b/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java new file mode 100644 index 00000000000..35e3fc434aa --- /dev/null +++ b/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.examples.inference; + +import ai.djl.ModelException; +import ai.djl.modality.Classifications; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.testing.TestRequirements; +import ai.djl.translate.TranslateException; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; + +public class Yolov8DetectionTest { + + @Test + public void testYolov8Detection() throws ModelException, TranslateException, IOException { + TestRequirements.engine("MXNet", "PyTorch"); + + DetectedObjects result = Yolov8Detection.predict(); + + Assert.assertTrue(result.getNumberOfObjects() >= 1); + Classifications.Classification obj = result.best(); + String className = obj.getClassName(); + Assert.assertEquals(className, "dog"); + Assert.assertTrue(obj.getProbability() > 0.6); + } +} diff --git a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java index 2a61e25862e..1a5699836c8 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java @@ -27,7 +27,6 @@ public class TrainPikachuTest { @Test public void testDetection() throws IOException, MalformedModelException, TranslateException { - TestRequirements.engine("MXNet"); TestRequirements.nightly(); String[] args; diff --git a/examples/src/test/java/ai/djl/testing/TestRequirements.java b/examples/src/test/java/ai/djl/testing/TestRequirements.java index e8c9bd4bdda..01eef756201 100644 --- a/examples/src/test/java/ai/djl/testing/TestRequirements.java +++ b/examples/src/test/java/ai/djl/testing/TestRequirements.java @@ -14,6 +14,7 @@ import ai.djl.engine.Engine; import ai.djl.engine.EngineException; +import ai.djl.util.Utils; import org.testng.SkipException; @@ -45,7 +46,7 @@ public static void weekly() { /** Requires a test not be run in offline mode. */ public static void notOffline() { - if (Boolean.getBoolean("offline")) { + if (Utils.isOfflineMode()) { throw new SkipException("This test can not run while offline"); } } diff --git a/examples/src/test/resources/yolov8_synset.txt b/examples/src/test/resources/yolov8_synset.txt new file mode 100644 index 00000000000..ffba2064933 --- /dev/null +++ b/examples/src/test/resources/yolov8_synset.txt @@ -0,0 +1,84 @@ +# Classes for coco dataset on which yelov8 is trained +# source config https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco.yaml. 
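+# One class label per line; the line order must match the class indices produced by the model.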
+# COCO dataset website: https://cocodataset.org/#home +# Ultralytics Coco doc page: https://docs.ultralytics.com/datasets/detect/coco/ +person +bicycle +car +motorbike +aeroplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +sofa +pottedplant +bed +diningtable +toilet +tvmonitor +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush \ No newline at end of file diff --git a/examples/src/test/resources/yolov8_test.jpg b/examples/src/test/resources/yolov8_test.jpg new file mode 100644 index 00000000000..01e43374348 Binary files /dev/null and b/examples/src/test/resources/yolov8_test.jpg differ diff --git a/examples/src/test/resources/yolov8n.onnx b/examples/src/test/resources/yolov8n.onnx new file mode 100644 index 00000000000..430f7f2beb0 Binary files /dev/null and b/examples/src/test/resources/yolov8n.onnx differ diff --git a/extensions/audio/README.md b/extensions/audio/README.md index 7e2c89692bc..6ec5ade8feb 100644 --- a/extensions/audio/README.md +++ b/extensions/audio/README.md @@ -23,6 +23,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.audio audio - 0.23.0 + 0.27.0 ``` diff --git a/extensions/aws-ai/README.md b/extensions/aws-ai/README.md index 829df0bb0ca..95f7bf2568a 100644 --- a/extensions/aws-ai/README.md +++ b/extensions/aws-ai/README.md @@ -58,6 +58,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.aws aws-ai - 0.23.0 + 0.27.0 ``` diff --git a/extensions/fasttext/README.md b/extensions/fasttext/README.md index 6f5a25064ea..16003dd3927 100644 --- a/extensions/fasttext/README.md +++ b/extensions/fasttext/README.md @@ -34,7 +34,7 @@ You can pull the fastText engine from the central Maven repository by including ai.djl.fasttext fasttext-engine - 0.23.0 + 0.27.0 ``` diff --git a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java index 5b421ff431f..4395ddf1a6c 100644 --- a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java +++ b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java @@ -41,6 +41,7 @@ import java.io.IOException; import java.io.InputStream; +import java.net.URI; import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; @@ -129,7 +130,9 @@ public void testWord2Vec() throws IOException, MalformedModelException, ModelNot public void testBlazingText() throws IOException, ModelException { TestRequirements.nightly(); - URL url = new URL("https://resources.djl.ai/test-models/blazingtext_classification.bin"); + URL url = + URI.create("https://resources.djl.ai/test-models/blazingtext_classification.bin") + .toURL(); Path path = Paths.get("build/tmp/model"); Path modelFile = path.resolve("text_classification.bin"); if (!Files.exists(modelFile)) { diff --git a/extensions/hadoop/README.md b/extensions/hadoop/README.md index b3c4ebcc762..38ed91747c8 100644 --- a/extensions/hadoop/README.md +++ 
b/extensions/hadoop/README.md @@ -52,6 +52,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.hadoop hadoop - 0.23.0 + 0.27.0 ``` diff --git a/extensions/opencv/README.md b/extensions/opencv/README.md index d6c58f518dc..c8f88a80475 100644 --- a/extensions/opencv/README.md +++ b/extensions/opencv/README.md @@ -23,6 +23,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.opencv opencv - 0.23.0 + 0.27.0 ``` diff --git a/extensions/sentencepiece/README.md b/extensions/sentencepiece/README.md index 4308308111f..2dba43c86a9 100644 --- a/extensions/sentencepiece/README.md +++ b/extensions/sentencepiece/README.md @@ -23,6 +23,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.sentencepiece sentencepiece - 0.23.0 + 0.27.0 ``` diff --git a/extensions/spark/README.md b/extensions/spark/README.md index 02ebcc07a1d..957a3f8a3ff 100644 --- a/extensions/spark/README.md +++ b/extensions/spark/README.md @@ -34,7 +34,7 @@ You can pull the module from the central Maven repository by including the follo ai.djl.spark spark_2.12 - 0.23.0 + 0.27.0 ``` diff --git a/extensions/spark/setup/djl_spark/util/files_util.py b/extensions/spark/setup/djl_spark/util/files_util.py index 5e31fc9e177..dd9224000cf 100644 --- a/extensions/spark/setup/djl_spark/util/files_util.py +++ b/extensions/spark/setup/djl_spark/util/files_util.py @@ -70,6 +70,21 @@ def download_and_extract(url, path): :param url: The url of the tar file. :param path: The path to the file to download to. """ + + def is_within_directory(directory, target): + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + prefix = os.path.commonprefix([abs_directory, abs_target]) + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + if not os.path.exists(path): os.makedirs(path) if not os.listdir(path): @@ -78,9 +93,9 @@ def download_and_extract(url, path): if url.startswith("s3://"): s3_download(url, tmp_file) with tarfile.open(name=tmp_file, mode="r:gz") as t: - t.extractall(path=path) + safe_extract(t, path=path) elif url.startswith("http://") or url.startswith("https://"): with urlopen(url) as response, open(tmp_file, 'wb') as f: shutil.copyfileobj(response, f) with tarfile.open(name=tmp_file, mode="r:gz") as t: - t.extractall(path=path) + safe_extract(t, path=path) diff --git a/extensions/tablesaw/README.md b/extensions/tablesaw/README.md index 010c6395eb9..8e092a3df61 100644 --- a/extensions/tablesaw/README.md +++ b/extensions/tablesaw/README.md @@ -25,6 +25,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.tablesaw tablesaw - 0.23.0 + 0.27.0 ``` diff --git a/extensions/timeseries/README.md b/extensions/timeseries/README.md index 9706c9334a4..f5629124a76 100644 --- a/extensions/timeseries/README.md +++ b/extensions/timeseries/README.md @@ -245,6 +245,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.timeseries timeseries - 0.23.0 + 0.27.0 ``` diff --git a/extensions/timeseries/docs/forecast_with_M5_data.md b/extensions/timeseries/docs/forecast_with_M5_data.md index a4f1a24a1d9..4eb1587a66c 100644 --- 
a/extensions/timeseries/docs/forecast_with_M5_data.md +++ b/extensions/timeseries/docs/forecast_with_M5_data.md @@ -1,5 +1,7 @@ # Forecast the future in a timeseries data with Deep Java Library (DJL) + ## -- Demonstration on M5forecasting and airpassenger datasests + Junyuan Zhang, Kexin Feng Time series data are commonly seen in the world. They can contain valued information that helps forecast for the future, monitor the status of a procedure and feedforward a control. Generic applications includes the following: sales forecasting, stock market analysis, yield projections, process and quality control, and many many more. See [link1](https://www.itl.nist.gov/div898/handbook/pmc/section4/pmc41.htm) and [link2](https://www.influxdata.com/time-series-forecasting-methods/#:~:text=Time%20series%20forecasting%20means%20to,on%20what%20has%20already%20happened) for further examples of timeseries data. @@ -54,7 +56,7 @@ repositories { } dependencies { implementation "org.apache.logging.log4j:log4j-slf4j-impl:2.17.1" - implementation platform("ai.djl:bom:0.23.0") + implementation platform("ai.djl:bom:0.27.0") implementation "ai.djl:api" implementation "ai.djl.timeseries" runtimeOnly "ai.djl.mxnet:mxnet-engine" diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java index 5b642285c3e..9edb45ff5f0 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java @@ -94,15 +94,23 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { Pair update = evaluateHelper(labels, predictions); - totalInstances.compute(key, (k, v) -> v + update.getKey()); - totalLoss.compute( - key, - (k, v) -> { - try (NDArray array = update.getValue().sum()) { - return v + array.getFloat(); - } - }); + for (String key : keys) { + totalInstances.compute(key, (k, v) -> v + update.getKey()); + totalLoss.compute( + key, + (k, v) -> { + try (NDArray array = update.getValue().sum()) { + return v + array.getFloat(); + } + }); + } } /** {@inheritDoc} */ diff --git a/extensions/tokenizers/README.md b/extensions/tokenizers/README.md index 1b85625572c..fc700007baf 100644 --- a/extensions/tokenizers/README.md +++ b/extensions/tokenizers/README.md @@ -23,7 +23,7 @@ You can pull the module from the central Maven repository by including the follo ai.djl.huggingface tokenizers - 0.23.0 + 0.27.0 ``` diff --git a/extensions/tokenizers/build.cmd b/extensions/tokenizers/build.cmd index 3a481d33bab..d83f2c1ed74 100644 --- a/extensions/tokenizers/build.cmd +++ b/extensions/tokenizers/build.cmd @@ -3,7 +3,7 @@ @rem choco install rust -y @rem choco install jdk8 -y -set VERSION=python-v"%1" +set VERSION=v"%1" if exist "tokenizers" ( echo Found "tokenizers" diff --git a/extensions/tokenizers/build.sh b/extensions/tokenizers/build.sh index 4ba45a09965..229e8124914 100755 --- a/extensions/tokenizers/build.sh +++ b/extensions/tokenizers/build.sh @@ -10,7 +10,7 @@ elif [[ -n $(command -v sysctl) ]]; then fi PLATFORM=$(uname | tr '[:upper:]' '[:lower:]') -VERSION=python-v$1 +VERSION=v$1 ARCH=$2 pushd $WORK_DIR diff --git a/extensions/tokenizers/rust/Cargo.toml 
b/extensions/tokenizers/rust/Cargo.toml index f6b846f636c..17bdd47b132 100644 --- a/extensions/tokenizers/rust/Cargo.toml +++ b/extensions/tokenizers/rust/Cargo.toml @@ -5,8 +5,8 @@ authors = ["Frank Liu "] edition = "2018" [dependencies] -jni = "0.19.0" -tokenizers = { path = "../tokenizers/tokenizers", version = "*" } +jni = "0.21.1" +tokenizers = { path = "../tokenizers/tokenizers", version = "*", features = ["http"] } [target.'cfg(target_os = "linux")'.dependencies] openssl = { version = "0.10", features = ["vendored"] } diff --git a/extensions/tokenizers/rust/src/lib.rs b/extensions/tokenizers/rust/src/lib.rs index d1c0c455c19..3352f98aa8a 100644 --- a/extensions/tokenizers/rust/src/lib.rs +++ b/extensions/tokenizers/rust/src/lib.rs @@ -15,25 +15,29 @@ extern crate tokenizers as tk; use std::str::FromStr; + +use jni::objects::{ + JClass, JLongArray, JMethodID, JObject, JObjectArray, JString, JValue, ReleaseMode, +}; +use jni::sys::{jboolean, jint, jlong, jsize, jvalue, JNI_TRUE}; +use jni::JNIEnv; +use tk::models::bpe::BPE; use tk::tokenizer::{EncodeInput, Encoding}; use tk::utils::padding::{PaddingParams, PaddingStrategy}; use tk::utils::truncation::{TruncationParams, TruncationStrategy}; use tk::Tokenizer; use tk::{FromPretrainedParameters, Offsets}; -use tk::models::bpe::BPE; - -use jni::objects::{JClass, JMethodID, JObject, JString, JValue, ReleaseMode}; -use jni::sys::{jboolean, jint, jlong, jlongArray, jobjectArray, jsize, jstring, JNI_TRUE}; -use jni::JNIEnv; #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_createTokenizer( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_createTokenizer< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, input: JString, ) -> jlong { let identifier: String = env - .get_string(input) + .get_string(&input) .expect("Couldn't get java string!") .into(); @@ -50,13 +54,15 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_createTokenizerFromString( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_createTokenizerFromString< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, json: JString, ) -> jlong { let data: String = env - .get_string(json) + .get_string(&json) .expect("Couldn't get java string!") .into(); @@ -72,19 +78,21 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ // Tokenizer using BPE model #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_createBpeTokenizer( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_createBpeTokenizer< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, vocabulary: JString, merges: JString, ) -> jlong { let vocabulary: String = env - .get_string(vocabulary) + .get_string(&vocabulary) .expect("Couldn't get java string!") .into(); let merges: String = env - .get_string(merges) + .get_string(&merges) .expect("Couldn't get java string!") .into(); @@ -99,7 +107,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ #[no_mangle] pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_deleteTokenizer( - _env: JNIEnv, + _: JNIEnv, _: JObject, handle: jlong, ) { @@ -107,8 +115,8 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } #[no_mangle] 
-pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_encode( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_encode<'local>( + mut env: JNIEnv<'local>, _: JObject, handle: jlong, input: JString, @@ -116,7 +124,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ ) -> jlong { let tokenizer = cast_handle::(handle); let sequence: String = env - .get_string(input) + .get_string(&input) .expect("Couldn't get java string!") .into(); @@ -134,8 +142,10 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_encodeDual( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_encodeDual< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, handle: jlong, text: JString, @@ -144,11 +154,11 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ ) -> jlong { let tokenizer = cast_handle::(handle); let sequence1: String = env - .get_string(text) + .get_string(&text) .expect("Couldn't get text string!") .into(); let sequence2: String = env - .get_string(text_pair) + .get_string(&text_pair) .expect("Couldn't get text_pair string!") .into(); @@ -167,20 +177,22 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_encodeList( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_encodeList< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, handle: jlong, - inputs: jobjectArray, + inputs: JObjectArray<'local>, add_special_tokens: jboolean, ) -> jlong { let tokenizer = cast_handle::(handle); - let len = env.get_array_length(inputs).unwrap(); + let len = env.get_array_length(&inputs).unwrap(); let mut array: Vec = Vec::new(); for i in 0..len { - let item = env.get_object_array_element(inputs, i).unwrap().into(); + let item = env.get_object_array_element(&inputs, i).unwrap().into(); let value: String = env - .get_string(item) + .get_string(&item) .expect("Couldn't get java string!") .into(); array.push(value); @@ -200,20 +212,22 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_batchEncode( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_batchEncode< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, handle: jlong, - inputs: jobjectArray, + inputs: JObjectArray<'local>, add_special_tokens: jboolean, -) -> jlongArray { +) -> JLongArray<'local> { let tokenizer = cast_handle::(handle); - let len = env.get_array_length(inputs).unwrap(); + let len = env.get_array_length(&inputs).unwrap(); let mut array: Vec = Vec::new(); for i in 0..len { - let item = env.get_object_array_element(inputs, i).unwrap().into(); + let item = env.get_object_array_element(&inputs, i).unwrap().into(); let value: String = env - .get_string(item) + .get_string(&item) .expect("Couldn't get java string!") .into(); array.push(value); @@ -229,31 +243,33 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ let size = handles.len() as jsize; let ret = env.new_long_array(size).unwrap(); - env.set_long_array_region(ret, 0, &handles).unwrap(); + env.set_long_array_region(&ret, 0, 
&handles).unwrap(); ret } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_batchEncodePair( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_batchEncodePair< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, handle: jlong, - text: jobjectArray, - text_pair: jobjectArray, + text: JObjectArray<'local>, + text_pair: JObjectArray<'local>, add_special_tokens: jboolean, -) -> jlongArray { +) -> JLongArray<'local> { let tokenizer = cast_handle::(handle); - let len = env.get_array_length(text).unwrap(); + let len = env.get_array_length(&text).unwrap(); let mut array: Vec = Vec::new(); for i in 0..len { - let item1 = env.get_object_array_element(text, i).unwrap().into(); - let item2 = env.get_object_array_element(text_pair, i).unwrap().into(); + let item1 = env.get_object_array_element(&text, i).unwrap().into(); + let item2 = env.get_object_array_element(&text_pair, i).unwrap().into(); let sequence1: String = env - .get_string(item1) + .get_string(&item1) .expect("Couldn't get text string!") .into(); let sequence2: String = env - .get_string(item2) + .get_string(&item2) .expect("Couldn't get text_pair string!") .into(); @@ -273,13 +289,13 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ let size = handles.len() as jsize; let ret = env.new_long_array(size).unwrap(); - env.set_long_array_region(ret, 0, &handles).unwrap(); + env.set_long_array_region(&ret, 0, &handles).unwrap(); ret } #[no_mangle] pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_deleteEncoding( - _env: JNIEnv, + _: JNIEnv, _: JObject, handle: jlong, ) { @@ -287,11 +303,13 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTokenIds( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTokenIds< + 'local, +>( + env: JNIEnv<'local>, _: JObject, handle: jlong, -) -> jlongArray { +) -> JLongArray<'local> { let encoding = cast_handle::(handle); let ids = encoding.get_ids(); let len = ids.len() as jsize; @@ -301,17 +319,19 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ long_ids.push(*i as jlong) } - let array: jlongArray = env.new_long_array(len).unwrap(); - env.set_long_array_region(array, 0, &long_ids).unwrap(); + let array = env.new_long_array(len).unwrap(); + env.set_long_array_region(&array, 0, &long_ids).unwrap(); array } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTypeIds( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTypeIds< + 'local, +>( + env: JNIEnv<'local>, _: JObject, handle: jlong, -) -> jlongArray { +) -> JLongArray<'local> { let encoding = cast_handle::(handle); let type_ids = encoding.get_type_ids(); let len = type_ids.len() as jsize; @@ -320,17 +340,19 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ long_ids.push(*i as jlong) } - let array: jlongArray = env.new_long_array(len).unwrap(); - env.set_long_array_region(array, 0, &long_ids).unwrap(); + let array = env.new_long_array(len).unwrap(); + env.set_long_array_region(&array, 0, &long_ids).unwrap(); array } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getWordIds( - env: JNIEnv, +pub extern "system" fn 
Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getWordIds< + 'local, +>( + env: JNIEnv<'local>, _: JObject, handle: jlong, -) -> jlongArray { +) -> JLongArray<'local> { let encoding = cast_handle::(handle); let word_ids = encoding.get_word_ids(); let len = word_ids.len() as jsize; @@ -343,38 +365,42 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } } - let array: jlongArray = env.new_long_array(len).unwrap(); - env.set_long_array_region(array, 0, &long_ids).unwrap(); + let array = env.new_long_array(len).unwrap(); + env.set_long_array_region(&array, 0, &long_ids).unwrap(); array } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTokens( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTokens< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, handle: jlong, -) -> jobjectArray { +) -> JObjectArray<'local> { let encoding = cast_handle::(handle); let tokens = encoding.get_tokens(); let len = tokens.len() as jsize; - let array: jobjectArray = env + let array = env .new_object_array(len, "java/lang/String", JObject::null()) .unwrap(); for (i, token) in tokens.iter().enumerate() { let item: JString = env.new_string(&token).unwrap(); - env.set_object_array_element(array, i as jsize, item) + env.set_object_array_element(&array, i as jsize, item) .unwrap(); } array } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getAttentionMask( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getAttentionMask< + 'local, +>( + env: JNIEnv<'local>, _: JObject, handle: jlong, -) -> jlongArray { +) -> JLongArray<'local> { let encoding = cast_handle::(handle); let attention_masks = encoding.get_attention_mask(); let len = attention_masks.len() as jsize; @@ -383,17 +409,19 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ long_ids.push(*i as jlong) } - let array: jlongArray = env.new_long_array(len).unwrap(); - env.set_long_array_region(array, 0, &long_ids).unwrap(); + let array = env.new_long_array(len).unwrap(); + env.set_long_array_region(&array, 0, &long_ids).unwrap(); array } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getSpecialTokenMask( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getSpecialTokenMask< + 'local, +>( + env: JNIEnv<'local>, _: JObject, handle: jlong, -) -> jlongArray { +) -> JLongArray<'local> { let encoding = cast_handle::(handle); let special_token_masks = encoding.get_special_tokens_mask(); let len = special_token_masks.len() as jsize; @@ -402,22 +430,24 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ long_ids.push(*i as jlong) } - let array: jlongArray = env.new_long_array(len).unwrap(); - env.set_long_array_region(array, 0, &long_ids).unwrap(); + let array = env.new_long_array(len).unwrap(); + env.set_long_array_region(&array, 0, &long_ids).unwrap(); array } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTokenCharSpans( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTokenCharSpans< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, handle: jlong, -) -> jobjectArray { +) -> JObjectArray<'local> { let encoding = cast_handle::(handle); let tokens = encoding.get_tokens(); let len = 
tokens.len() as jsize; - let array: jobjectArray = env + let array = env .new_object_array( len, "ai/djl/huggingface/tokenizers/jni/CharSpan", @@ -427,22 +457,22 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ for (i, _) in tokens.iter().enumerate() { let opt_offsets: Option<(usize, Offsets)> = encoding.token_to_chars(i); match &opt_offsets { - Some((_, offsets)) => { + Some((_, offsets)) => unsafe { let class_id = "ai/djl/huggingface/tokenizers/jni/CharSpan"; let method_id = ""; let params = "(II)V"; let cls: JClass = env.find_class(class_id).unwrap(); - let constructor: JMethodID = env.get_method_id(cls, method_id, params).unwrap(); - let offsets_vec: Vec = vec![ - JValue::Int((*offsets).0 as jint), - JValue::Int((*offsets).1 as jint), + let constructor: JMethodID = env.get_method_id(&cls, method_id, params).unwrap(); + let offsets_vec: Vec = vec![ + JValue::Int((*offsets).0 as jint).as_jni(), + JValue::Int((*offsets).1 as jint).as_jni(), ]; let obj = env - .new_object_unchecked(cls, constructor, &offsets_vec[..]) + .new_object_unchecked(&cls, constructor, &offsets_vec[..]) .unwrap(); - env.set_object_array_element(array, i as jsize, obj) + env.set_object_array_element(&array, i as jsize, obj) .unwrap(); - } + }, None => {} } } @@ -450,38 +480,38 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getOverflowing( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getOverflowing< + 'local, +>( + env: JNIEnv<'local>, _: JObject, handle: jlong, -) -> jlongArray { +) -> JLongArray<'local> { let encoding = cast_handle::(handle); let handles = encoding - .get_overflowing() - .clone() - .into_iter() - .map(|c| to_handle(c)) - .collect::>(); + .get_overflowing() + .clone() + .into_iter() + .map(|c| to_handle(c)) + .collect::>(); let size = handles.len() as jsize; let ret = env.new_long_array(size).unwrap(); - env.set_long_array_region(ret, 0, &handles).unwrap(); + env.set_long_array_region(&ret, 0, &handles).unwrap(); ret } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_decode( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_decode<'local>( + mut env: JNIEnv<'local>, _: JObject, handle: jlong, - ids: jlongArray, + ids: JLongArray<'local>, skip_special_tokens: jboolean, -) -> jstring { +) -> JString<'local> { let tokenizer = cast_handle::(handle); - let long_ids = env - .get_long_array_elements(ids, ReleaseMode::NoCopyBack) - .unwrap(); + let long_ids = unsafe { env.get_array_elements(&ids, ReleaseMode::NoCopyBack) }.unwrap(); let long_ids_ptr = long_ids.as_ptr(); - let len = long_ids.size().unwrap() as usize; + let len = long_ids.len(); let mut decode_ids: Vec = Vec::new(); for i in 0..len { unsafe { @@ -490,63 +520,71 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } } let decoding: String = tokenizer - .decode(decode_ids, skip_special_tokens == JNI_TRUE) + .decode(&*decode_ids, skip_special_tokens == JNI_TRUE) .unwrap(); let ret = env - .new_string(decoding) - .expect("Couldn't create java string!") - .into_inner(); + .new_string(&decoding) + .expect("Couldn't create java string!"); ret } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_batchDecode( - env: JNIEnv, +pub extern "system" fn 
Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_batchDecode< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, handle: jlong, - batch_ids: jobjectArray, + batch_ids: JObjectArray<'local>, skip_special_tokens: jboolean, -) -> jobjectArray { +) -> JObjectArray<'local> { let tokenizer = cast_handle::(handle); - let batch_len = env.get_array_length(batch_ids).unwrap(); + let batch_len = env.get_array_length(&batch_ids).unwrap(); let mut batch_decode_input: Vec> = Vec::new(); - for i in 0..batch_len { - let item = env.get_object_array_element(batch_ids, i).unwrap(); - let sequence_ids = env - .get_long_array_elements(*item, ReleaseMode::NoCopyBack) - .unwrap(); - let sequence_ids_ptr = sequence_ids.as_ptr(); - let sequence_len = sequence_ids.size().unwrap() as usize; - let mut decode_ids: Vec = Vec::new(); - for i in 0..sequence_len { - unsafe { + unsafe { + for i in 0..batch_len { + let item: JLongArray<'local> = + JLongArray::from(env.get_object_array_element(&batch_ids, i).unwrap()); + let sequence_ids = env + .get_array_elements(&item, ReleaseMode::NoCopyBack) + .unwrap(); + let sequence_ids_ptr = sequence_ids.as_ptr(); + let sequence_len = sequence_ids.len(); + let mut decode_ids: Vec = Vec::new(); + for i in 0..sequence_len { let val = sequence_ids_ptr.add(i); decode_ids.push(*val as u32); } + batch_decode_input.push(decode_ids); } - batch_decode_input.push(decode_ids); + } + let mut references: Vec<&[u32]> = Vec::new(); + for reference in batch_decode_input.iter() { + references.push(reference); } let decoding: Vec = tokenizer - .decode_batch(batch_decode_input, skip_special_tokens == JNI_TRUE) + .decode_batch(&references, skip_special_tokens == JNI_TRUE) .unwrap(); - let ret: jobjectArray = env + let ret = env .new_object_array(batch_len, "java/lang/String", JObject::null()) .unwrap(); for (i, decode) in decoding.iter().enumerate() { let item: JString = env.new_string(&decode).unwrap(); - env.set_object_array_element(ret, i as jsize, item) + env.set_object_array_element(&ret, i as jsize, item) .unwrap(); } ret } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTruncationStrategy( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTruncationStrategy< + 'local, +>( + env: JNIEnv<'local>, _: JObject, handle: jlong, -) -> jstring { +) -> JString<'local> { let tokenizer = cast_handle::(handle); let truncation = tokenizer.get_truncation(); let strategy = match truncation { @@ -556,18 +594,19 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ let ret = env .new_string(strategy.to_string()) - .expect("Couldn't create java string!") - .into_inner(); + .expect("Couldn't create java string!"); ret } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getPaddingStrategy( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getPaddingStrategy< + 'local, +>( + env: JNIEnv<'local>, _: JObject, handle: jlong, -) -> jstring { +) -> JString<'local> { let tokenizer = cast_handle::(handle); let padding = tokenizer.get_padding(); let strategy = match padding { @@ -580,15 +619,14 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ let ret = env .new_string(strategy) - .expect("Couldn't create java string!") - .into_inner(); + .expect("Couldn't create java string!"); ret } #[no_mangle] pub extern "system" fn 
Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getMaxLength( - _env: JNIEnv, + _: JNIEnv, _: JObject, handle: jlong, ) -> jint { @@ -613,7 +651,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ #[no_mangle] pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getStride( - _env: JNIEnv, + _: JNIEnv, _: JObject, handle: jlong, ) -> jint { @@ -628,7 +666,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ #[no_mangle] pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getPadToMultipleOf( - _env: JNIEnv, + _: JNIEnv, _: JObject, handle: jlong, ) -> jint { @@ -642,8 +680,10 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_setPadding( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_setPadding< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, handle: jlong, max_length: jint, @@ -651,7 +691,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ pad_to_multiple_of: jint, ) { let strategy: String = env - .get_string(padding_strategy) + .get_string(&padding_strategy) .expect("Couldn't get java string!") .into(); let len = max_length as usize; @@ -663,7 +703,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ let res_pad_to_multiple_of = match pad_to_multiple_of as usize { 0 => None, - val => Some(val) + val => Some(val), }; let tokenizer = cast_handle::(handle); @@ -683,7 +723,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ #[no_mangle] pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_disablePadding( - _env: JNIEnv, + _: JNIEnv, _: JObject, handle: jlong, ) { @@ -692,8 +732,10 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } #[no_mangle] -pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_setTruncation( - env: JNIEnv, +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_setTruncation< + 'local, +>( + mut env: JNIEnv<'local>, _: JObject, handle: jlong, truncation_max_length: jint, @@ -701,7 +743,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ truncation_stride: jint, ) { let strategy: String = env - .get_string(truncation_strategy) + .get_string(&truncation_strategy) .expect("Couldn't get java string!") .into(); let res_strategy = match strategy.as_ref() { @@ -724,18 +766,18 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ max_length: truncation_max_length as usize, ..Default::default() }; - tokenizer.with_truncation(Some(truncation_params)); + let _ = tokenizer.with_truncation(Some(truncation_params)); } } #[no_mangle] pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_disableTruncation( - _env: JNIEnv, + _: JNIEnv, _: JObject, handle: jlong, ) { let tokenizer = cast_handle::(handle); - tokenizer.with_truncation(None); + let _ = tokenizer.with_truncation(None); } fn to_handle(val: T) -> jlong { diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java index e58d6ada5ee..887f01646dc 100644 --- 
a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java @@ -27,6 +27,7 @@ public class Encoding { private long[] specialTokenMask; private CharSpan[] charTokenSpans; private Encoding[] overflowing; + private boolean exceedMaxLength; protected Encoding( long[] ids, @@ -36,6 +37,7 @@ protected Encoding( long[] attentionMask, long[] specialTokenMask, CharSpan[] charTokenSpans, + boolean exceedMaxLength, Encoding[] overflowing) { this.ids = ids; this.typeIds = typeIds; @@ -44,6 +46,7 @@ protected Encoding( this.attentionMask = attentionMask; this.specialTokenMask = specialTokenMask; this.charTokenSpans = charTokenSpans; + this.exceedMaxLength = exceedMaxLength; this.overflowing = overflowing; } @@ -127,6 +130,15 @@ public CharSpan[] getCharTokenSpans() { return charTokenSpans; } + /** + * Returns if tokens exceed max length. + * + * @return {@code true} if tokens exceed max length + */ + public boolean exceedMaxLength() { + return exceedMaxLength; + } + /** * Returns an array of overflowing encodings. * diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index f75342b7cb8..7dd3a9d9552 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -32,6 +32,7 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -44,6 +45,8 @@ public final class HuggingFaceTokenizer extends NativeResource implements private static final Logger logger = LoggerFactory.getLogger(HuggingFaceTokenizer.class); private boolean addSpecialTokens; + private boolean withOverflowingTokens; + private Locale doLowerCase; private TruncationStrategy truncation; private PaddingStrategy padding; private int maxLength; @@ -53,9 +56,8 @@ public final class HuggingFaceTokenizer extends NativeResource implements private HuggingFaceTokenizer(long handle, Map options) { super(handle); - String val = TokenizersLibrary.LIB.getTruncationStrategy(handle); - truncation = TruncationStrategy.fromValue(val); - val = TokenizersLibrary.LIB.getPaddingStrategy(handle); + truncation = TruncationStrategy.LONGEST_FIRST; + String val = TokenizersLibrary.LIB.getPaddingStrategy(handle); padding = PaddingStrategy.fromValue(val); maxLength = TokenizersLibrary.LIB.getMaxLength(handle); stride = TokenizersLibrary.LIB.getStride(handle); @@ -64,6 +66,8 @@ private HuggingFaceTokenizer(long handle, Map options) { if (options != null) { val = options.getOrDefault("addSpecialTokens", "true"); addSpecialTokens = Boolean.parseBoolean(val); + val = options.getOrDefault("withOverflowingTokens", "false"); + withOverflowingTokens = Boolean.parseBoolean(val); modelMaxLength = ArgumentsUtil.intValue(options, "modelMaxLength", 512); if (options.containsKey("truncation")) { truncation = TruncationStrategy.fromValue(options.get("truncation")); @@ -74,6 +78,12 @@ private HuggingFaceTokenizer(long handle, Map options) { maxLength = ArgumentsUtil.intValue(options, "maxLength", maxLength); stride = ArgumentsUtil.intValue(options, "stride", stride); padToMultipleOf = ArgumentsUtil.intValue(options, "padToMultipleOf", padToMultipleOf); + String lowerCase = 
options.getOrDefault("doLowerCase", "false"); + if ("true".equals(lowerCase)) { + this.doLowerCase = Locale.getDefault(); + } else if (!"false".equals(lowerCase)) { + this.doLowerCase = Locale.forLanguageTag(lowerCase); + } } else { addSpecialTokens = true; modelMaxLength = 512; @@ -203,11 +213,15 @@ public void close() { * @param text the input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence */ - public Encoding encode(String text, boolean addSpecialTokens) { + public Encoding encode(String text, boolean addSpecialTokens, boolean withOverflowingTokens) { + if (doLowerCase != null) { + text = text.toLowerCase(doLowerCase); + } long encoding = TokenizersLibrary.LIB.encode(getHandle(), text, addSpecialTokens); - return toEncoding(encoding); + return toEncoding(encoding, withOverflowingTokens); } /** @@ -217,7 +231,7 @@ public Encoding encode(String text, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence */ public Encoding encode(String text) { - return encode(text, addSpecialTokens); + return encode(text, addSpecialTokens, withOverflowingTokens); } /** @@ -227,12 +241,19 @@ public Encoding encode(String text) { * @param textPair the second input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence */ - public Encoding encode(String text, String textPair, boolean addSpecialTokens) { + public Encoding encode( + String text, String textPair, boolean addSpecialTokens, boolean withOverflowingTokens) { + if (doLowerCase != null) { + text = text.toLowerCase(doLowerCase); + textPair = textPair.toLowerCase(doLowerCase); + } + long encoding = TokenizersLibrary.LIB.encodeDual(getHandle(), text, textPair, addSpecialTokens); - return toEncoding(encoding); + return toEncoding(encoding, withOverflowingTokens); } /** @@ -243,7 +264,7 @@ public Encoding encode(String text, String textPair, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence */ public Encoding encode(String text, String textPair) { - return encode(text, textPair, addSpecialTokens); + return encode(text, textPair, addSpecialTokens, withOverflowingTokens); } /** @@ -252,11 +273,13 @@ public Encoding encode(String text, String textPair) { * @param inputs the input sentences * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentences */ - public Encoding encode(List inputs, boolean addSpecialTokens) { + public Encoding encode( + List inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { String[] array = inputs.toArray(Utils.EMPTY_ARRAY); - return encode(array, addSpecialTokens); + return encode(array, addSpecialTokens, withOverflowingTokens); } /** @@ -266,7 +289,7 @@ public Encoding encode(List inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentences */ public Encoding encode(List inputs) { - return encode(inputs, addSpecialTokens); + return encode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -275,11 +298,18 @@ public Encoding encode(List inputs) { * @param inputs the input sentences * @param 
addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentences */ - public Encoding encode(String[] inputs, boolean addSpecialTokens) { + public Encoding encode( + String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { + if (doLowerCase != null) { + for (int i = 0; i < inputs.length; ++i) { + inputs[i] = inputs[i].toLowerCase(doLowerCase); + } + } long encoding = TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens); - return toEncoding(encoding); + return toEncoding(encoding, withOverflowingTokens); } /** @@ -289,7 +319,7 @@ public Encoding encode(String[] inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentences */ public Encoding encode(String[] inputs) { - return encode(inputs, addSpecialTokens); + return encode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -298,11 +328,13 @@ public Encoding encode(String[] inputs) { * @param inputs the batch of input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence in batch */ - public Encoding[] batchEncode(List inputs, boolean addSpecialTokens) { + public Encoding[] batchEncode( + List inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { String[] array = inputs.toArray(Utils.EMPTY_ARRAY); - return batchEncode(array, addSpecialTokens); + return batchEncode(array, addSpecialTokens, withOverflowingTokens); } /** @@ -312,7 +344,7 @@ public Encoding[] batchEncode(List inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence in batch */ public Encoding[] batchEncode(List inputs) { - return batchEncode(inputs, addSpecialTokens); + return batchEncode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -321,13 +353,20 @@ public Encoding[] batchEncode(List inputs) { * @param inputs the batch of input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence in batch */ - public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) { + public Encoding[] batchEncode( + String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { + if (doLowerCase != null) { + for (int i = 0; i < inputs.length; ++i) { + inputs[i] = inputs[i].toLowerCase(doLowerCase); + } + } long[] encodings = TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens); Encoding[] ret = new Encoding[encodings.length]; for (int i = 0; i < encodings.length; ++i) { - ret[i] = toEncoding(encodings[i]); + ret[i] = toEncoding(encodings[i], withOverflowingTokens); } return ret; } @@ -339,7 +378,7 @@ public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence in batch */ public Encoding[] batchEncode(String[] inputs) { - return batchEncode(inputs, addSpecialTokens); + return batchEncode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -348,17 +387,29 @@ public Encoding[] batchEncode(String[] inputs) { * @param inputs the batch of input text pair * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * 
model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input text pair in batch */ - public Encoding[] batchEncode(PairList inputs, boolean addSpecialTokens) { + public Encoding[] batchEncode( + PairList inputs, + boolean addSpecialTokens, + boolean withOverflowingTokens) { String[] text = inputs.keyArray(Utils.EMPTY_ARRAY); String[] textPair = inputs.valueArray(Utils.EMPTY_ARRAY); + if (doLowerCase != null) { + for (int i = 0; i < text.length; ++i) { + text[i] = text[i].toLowerCase(doLowerCase); + } + for (int i = 0; i < textPair.length; ++i) { + textPair[i] = textPair[i].toLowerCase(doLowerCase); + } + } long[] encodings = TokenizersLibrary.LIB.batchEncodePair( getHandle(), text, textPair, addSpecialTokens); Encoding[] ret = new Encoding[encodings.length]; for (int i = 0; i < encodings.length; ++i) { - ret[i] = toEncoding(encodings[i]); + ret[i] = toEncoding(encodings[i], withOverflowingTokens); } return ret; } @@ -370,7 +421,7 @@ public Encoding[] batchEncode(PairList inputs, boolean addSpecia * @return the {@code Encoding} of the input text pair in batch */ public Encoding[] batchEncode(PairList inputs) { - return batchEncode(inputs, addSpecialTokens); + return batchEncode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -431,6 +482,53 @@ public void enableBatch() { } } + /** + * Returns the truncation policy. + * + * @return the truncation policy + */ + public String getTruncation() { + return truncation.name(); + } + + /** + * Returns the padding policy. + * + * @return the padding policy + */ + public String getPadding() { + return padding.name(); + } + + /** + * Returns the max token length. + * + * @return the max token length + */ + public int getMaxLength() { + return maxLength; + } + + /** + * Returns the stride to use in overflow overlap when truncating sequences longer than the model + * supports. + * + * @return the stride to use in overflow overlap when truncating sequences longer than the model + * supports + */ + public int getStride() { + return stride; + } + + /** + * Returns the padToMultipleOf for padding. + * + * @return the padToMultipleOf for padding + */ + public int getPadToMultipleOf() { + return padToMultipleOf; + } + /** * Creates a builder to build a {@code HuggingFaceTokenizer}. 
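Illustrative sketch (not taken from this patch): the new read-only accessors above expose how a tokenizer ended up configured. A minimal example, assuming a local directory containing tokenizer.json; the path is a placeholder, and optDoLowerCase/optWithOverflowingTokens are the builder options introduced elsewhere in this diff.

    import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

    import java.io.IOException;
    import java.nio.file.Paths;

    public class TokenizerSettingsSketch {
        public static void main(String[] args) throws IOException {
            // Placeholder path; any directory with a tokenizer.json would do.
            try (HuggingFaceTokenizer tokenizer =
                    HuggingFaceTokenizer.builder()
                            .optTokenizerPath(Paths.get("/tmp/all-MiniLM-L6-v2"))
                            .optDoLowerCase(true)           // lowercase input with the default Locale
                            .optWithOverflowingTokens(true) // materialize overflowing encodings
                            .build()) {
                // Accessors added by this change:
                System.out.println(tokenizer.getTruncation());      // e.g. LONGEST_FIRST
                System.out.println(tokenizer.getPadding());
                System.out.println(tokenizer.getMaxLength());
                System.out.println(tokenizer.getStride());
                System.out.println(tokenizer.getPadToMultipleOf());
            }
        }
    }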
* @@ -503,7 +601,7 @@ private void updateTruncationAndPadding() { } } - private Encoding toEncoding(long encoding) { + private Encoding toEncoding(long encoding, boolean withOverflowingTokens) { long[] ids = TokenizersLibrary.LIB.getTokenIds(encoding); long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding); String[] tokens = TokenizersLibrary.LIB.getTokens(encoding); @@ -511,11 +609,17 @@ private Encoding toEncoding(long encoding) { long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding); long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding); CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding); - long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding); - Encoding[] overflowing = new Encoding[overflowingHandles.length]; - for (int i = 0; i < overflowingHandles.length; ++i) { - overflowing[i] = toEncoding(overflowingHandles[i]); + long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding); + boolean exceedMaxLength = overflowingHandles.length > 0; + Encoding[] overflowing; + if (withOverflowingTokens) { + overflowing = new Encoding[overflowingHandles.length]; + for (int i = 0; i < overflowingHandles.length; ++i) { + overflowing[i] = toEncoding(overflowingHandles[i], true); + } + } else { + overflowing = new Encoding[0]; } TokenizersLibrary.LIB.deleteEncoding(encoding); @@ -527,6 +631,7 @@ private Encoding toEncoding(long encoding) { attentionMask, specialTokenMask, charSpans, + exceedMaxLength, overflowing); } @@ -651,6 +756,17 @@ public Builder optAddSpecialTokens(boolean addSpecialTokens) { return this; } + /** + * Sets if add special tokens. + * + * @param withOverflowingTokens true to return overflowing tokens + * @return this builder + */ + public Builder optWithOverflowingTokens(boolean withOverflowingTokens) { + options.put("withOverflowingTokens", String.valueOf(withOverflowingTokens)); + return this; + } + /** * Enables or Disables default truncation behavior for the tokenizer. * @@ -738,6 +854,28 @@ public Builder optStride(int stride) { return this; } + /** + * Sets the doLowerCase for the tokenizer. + * + * @param doLowerCase {@code true} to enable convert to lowercase + * @return this builder + */ + public Builder optDoLowerCase(boolean doLowerCase) { + options.put("doLowerCase", String.valueOf(doLowerCase)); + return this; + } + + /** + * Sets the doLowerCase for the tokenizer with specific locale. + * + * @param locale the locale to use when converting to lowercase + * @return this builder + */ + public Builder optDoLowerCase(String locale) { + options.put("doLowerCase", locale); + return this; + } + /** * Configures the builder with the arguments. 
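Illustrative fragment (not taken from this patch): with the reworked toEncoding above, exceedMaxLength() is set whenever the native encoding reports overflow, while the overflowing encodings themselves are only materialized when withOverflowingTokens is enabled. This assumes a tokenizer built with optWithOverflowingTokens(true) as in the previous sketch; veryLongText is any hypothetical input longer than maxLength.

    // veryLongText: assumed to be a String longer than the tokenizer's maxLength.
    Encoding encoding = tokenizer.encode(veryLongText);
    if (encoding.exceedMaxLength()) {
        // The main encoding was truncated to at most maxLength tokens.
        System.out.println("kept " + encoding.getIds().length + " tokens");
        // Only populated when withOverflowingTokens is true; otherwise an empty array.
        for (Encoding overflow : encoding.getOverflowing()) {
            System.out.println("overflow chunk of " + overflow.getIds().length + " tokens");
        }
    }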
* @@ -787,7 +925,7 @@ public HuggingFaceTokenizer build() throws IOException { return managed(HuggingFaceTokenizer.newInstance(vocab, merges, options)); } throw new IOException("tokenizer.json file not found."); - } else if (Files.exists(tokenizerPath)) { + } else if (!Files.exists(tokenizerPath)) { throw new IOException("Tokenizer file not exits: " + tokenizerPath); } return managed(HuggingFaceTokenizer.newInstance(tokenizerPath, options)); diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/LibUtils.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/LibUtils.java index 7276b30ae9d..5bc71de4c0e 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/LibUtils.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/LibUtils.java @@ -72,8 +72,9 @@ private static void loadLibrary() { String nativeHelper = System.getProperty("ai.djl.huggingface.native_helper"); if (nativeHelper != null && !nativeHelper.isEmpty()) { ClassLoaderUtils.nativeLoad(nativeHelper, path); + } else { + System.load(path); // NOPMD } - System.load(path); // NOPMD } } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java new file mode 100644 index 00000000000..6f43c7cb480 --- /dev/null +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.huggingface.translator; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.translate.Batchifier; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslateException; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.PairList; +import ai.djl.util.StringPair; + +import java.util.Arrays; + +/** The translator for Huggingface cross encoder model. 
*/ +public class CrossEncoderBatchTranslator implements NoBatchifyTranslator { + + private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; + private Batchifier batchifier; + + CrossEncoderBatchTranslator( + HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { + this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; + this.batchifier = batchifier; + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, StringPair[] inputs) + throws TranslateException { + NDManager manager = ctx.getNDManager(); + PairList list = new PairList<>(Arrays.asList(inputs)); + Encoding[] encodings = tokenizer.batchEncode(list); + NDList[] batch = new NDList[encodings.length]; + for (int i = 0; i < encodings.length; ++i) { + batch[i] = encodings[i].toNDList(manager, includeTokenTypes); + } + return batchifier.batchify(batch); + } + + /** {@inheritDoc} */ + @Override + public float[][] processOutput(TranslatorContext ctx, NDList list) { + NDList[] batch = batchifier.unbatchify(list); + float[][] ret = new float[batch.length][]; + for (int i = 0; i < batch.length; ++i) { + NDArray logits = list.get(0); + NDArray result = logits.getNDArrayInternal().sigmoid(); + ret[i] = result.toFloatArray(); + } + return ret; + } +} diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java new file mode 100644 index 00000000000..c3f4db0cc17 --- /dev/null +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java @@ -0,0 +1,169 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.huggingface.translator; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.translate.ArgumentsUtil; +import ai.djl.translate.Batchifier; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.StringPair; + +import java.io.IOException; +import java.util.Map; + +/** The translator for Huggingface cross encoder model. 
*/ +public class CrossEncoderTranslator implements Translator { + + private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; + private boolean sigmoid; + private Batchifier batchifier; + + CrossEncoderTranslator( + HuggingFaceTokenizer tokenizer, + boolean includeTokenTypes, + boolean sigmoid, + Batchifier batchifier) { + this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; + this.sigmoid = sigmoid; + this.batchifier = batchifier; + } + + /** {@inheritDoc} */ + @Override + public Batchifier getBatchifier() { + return batchifier; + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, StringPair input) { + Encoding encoding = tokenizer.encode(input.getKey(), input.getValue()); + ctx.setAttachment("encoding", encoding); + return encoding.toNDList(ctx.getNDManager(), includeTokenTypes); + } + + /** {@inheritDoc} */ + @Override + public float[] processOutput(TranslatorContext ctx, NDList list) { + NDArray logits = list.get(0); + if (sigmoid) { + logits = logits.getNDArrayInternal().sigmoid(); + } + return logits.toFloatArray(); + } + + /** {@inheritDoc} */ + @Override + public CrossEncoderBatchTranslator toBatchTranslator(Batchifier batchifier) { + tokenizer.enableBatch(); + return new CrossEncoderBatchTranslator(tokenizer, includeTokenTypes, batchifier); + } + + /** + * Creates a builder to build a {@code CrossEncoderTranslator}. + * + * @param tokenizer the tokenizer + * @return a new builder + */ + public static Builder builder(HuggingFaceTokenizer tokenizer) { + return new Builder(tokenizer); + } + + /** + * Creates a builder to build a {@code CrossEncoderTranslator}. + * + * @param tokenizer the tokenizer + * @param arguments the models' arguments + * @return a new builder + */ + public static Builder builder(HuggingFaceTokenizer tokenizer, Map arguments) { + Builder builder = builder(tokenizer); + builder.configure(arguments); + + return builder; + } + + /** The builder for question answering translator. */ + public static final class Builder { + + private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; + private boolean sigmoid; + private Batchifier batchifier = Batchifier.STACK; + + Builder(HuggingFaceTokenizer tokenizer) { + this.tokenizer = tokenizer; + } + + /** + * Sets if include token types for the {@link Translator}. + * + * @param includeTokenTypes true to include token types + * @return this builder + */ + public Builder optIncludeTokenTypes(boolean includeTokenTypes) { + this.includeTokenTypes = includeTokenTypes; + return this; + } + + /** + * Sets if apply sigmoid for the {@link Translator}. + * + * @param sigmoid true to apply sigmoid + * @return this builder + */ + public Builder optSigmoid(boolean sigmoid) { + this.sigmoid = sigmoid; + return this; + } + + /** + * Sets the {@link Batchifier} for the {@link Translator}. + * + * @param batchifier true to include token types + * @return this builder + */ + public Builder optBatchifier(Batchifier batchifier) { + this.batchifier = batchifier; + return this; + } + + /** + * Configures the builder with the model arguments. 
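Illustrative sketch (not taken from this patch): one way the CrossEncoderTranslator defined above could score a query/passage pair with a locally traced cross-encoder model. The model directory and sentences are placeholders, and the Criteria/Predictor calls are standard DJL APIs rather than part of this change.

    import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
    import ai.djl.huggingface.translator.CrossEncoderTranslator;
    import ai.djl.inference.Predictor;
    import ai.djl.repository.zoo.Criteria;
    import ai.djl.repository.zoo.ZooModel;
    import ai.djl.util.StringPair;

    import java.nio.file.Paths;

    public class CrossEncoderSketch {
        public static void main(String[] args) throws Exception {
            // Placeholder directory holding both the traced model and tokenizer.json.
            HuggingFaceTokenizer tokenizer =
                    HuggingFaceTokenizer.builder()
                            .optTokenizerPath(Paths.get("/tmp/ms-marco-MiniLM-L-6-v2"))
                            .build();
            CrossEncoderTranslator translator =
                    CrossEncoderTranslator.builder(tokenizer)
                            .optSigmoid(true) // map logits into [0, 1] relevance scores
                            .build();
            Criteria<StringPair, float[]> criteria =
                    Criteria.builder()
                            .setTypes(StringPair.class, float[].class)
                            .optModelPath(Paths.get("/tmp/ms-marco-MiniLM-L-6-v2"))
                            .optTranslator(translator)
                            .build();
            try (ZooModel<StringPair, float[]> model = criteria.loadModel();
                    Predictor<StringPair, float[]> predictor = model.newPredictor()) {
                float[] score =
                        predictor.predict(
                                new StringPair(
                                        "How many people live in Berlin?",
                                        "Berlin has a population of about 3.7 million."));
                System.out.println(score[0]);
            }
        }
    }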
+ * + * @param arguments the model arguments + */ + public void configure(Map arguments) { + optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); + optSigmoid(ArgumentsUtil.booleanValue(arguments, "sigmoid", true)); + String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack"); + optBatchifier(Batchifier.fromString(batchifierStr)); + } + + /** + * Builds the translator. + * + * @return the new translator + * @throws IOException if I/O error occurs + */ + public CrossEncoderTranslator build() throws IOException { + return new CrossEncoderTranslator(tokenizer, includeTokenTypes, sigmoid, batchifier); + } + } +} diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java new file mode 100644 index 00000000000..f4f9af02c4b --- /dev/null +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java @@ -0,0 +1,80 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.huggingface.translator; + +import ai.djl.Model; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.nlp.translator.CrossEncoderServingTranslator; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; +import ai.djl.util.Pair; +import ai.djl.util.StringPair; + +import java.io.IOException; +import java.io.Serializable; +import java.lang.reflect.Type; +import java.nio.file.Path; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** A {@link TranslatorFactory} that creates a {@link CrossEncoderTranslatorFactory} instance. 
*/ +public class CrossEncoderTranslatorFactory implements TranslatorFactory, Serializable { + + private static final long serialVersionUID = 1L; + + private static final Set> SUPPORTED_TYPES = new HashSet<>(); + + static { + SUPPORTED_TYPES.add(new Pair<>(StringPair.class, float[].class)); + SUPPORTED_TYPES.add(new Pair<>(StringPair[].class, float[][].class)); + SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); + } + + /** {@inheritDoc} */ + @Override + public Set> getSupportedTypes() { + return SUPPORTED_TYPES; + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public Translator newInstance( + Class input, Class output, Model model, Map arguments) + throws TranslateException { + Path modelPath = model.getModelPath(); + try { + HuggingFaceTokenizer tokenizer = + HuggingFaceTokenizer.builder(arguments) + .optTokenizerPath(modelPath) + .optManager(model.getNDManager()) + .build(); + CrossEncoderTranslator translator = + CrossEncoderTranslator.builder(tokenizer, arguments).build(); + if (input == StringPair.class && output == float[].class) { + return (Translator) translator; + } else if (input == StringPair[].class && output == float[][].class) { + return (Translator) translator.toBatchTranslator(); + } else if (input == Input.class && output == Output.class) { + return (Translator) new CrossEncoderServingTranslator(translator); + } + throw new IllegalArgumentException("Unsupported input/output types."); + } catch (IOException e) { + throw new TranslateException("Failed to load tokenizer.", e); + } + } +} diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java index 43b120cac43..ee4e9cf9601 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java @@ -29,15 +29,21 @@ public class FillMaskBatchTranslator implements NoBatchifyTranslator { private String maskToken; private long maskTokenId; private int topK; + private boolean includeTokenTypes; private Batchifier batchifier; FillMaskTranslator( - HuggingFaceTokenizer tokenizer, String maskToken, int topK, Batchifier batchifier) { + HuggingFaceTokenizer tokenizer, + String maskToken, + int topK, + boolean includeTokenTypes, + Batchifier batchifier) { this.tokenizer = tokenizer; this.maskToken = maskToken; this.topK = topK; + this.includeTokenTypes = includeTokenTypes; this.batchifier = batchifier; - Encoding encoding = tokenizer.encode(maskToken, false); + Encoding encoding = tokenizer.encode(maskToken, false, false); maskTokenId = encoding.getIds()[0]; } @@ -61,7 +67,7 @@ public NDList processInput(TranslatorContext ctx, String input) throws Translate long[] indices = encoding.getIds(); int maskIndex = getMaskIndex(indices, maskToken, maskTokenId); ctx.setAttachment("maskIndex", maskIndex); - return encoding.toNDList(ctx.getNDManager(), false); + return encoding.toNDList(ctx.getNDManager(), includeTokenTypes); } /** {@inheritDoc} */ @@ -75,7 +81,8 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) { @Override public FillMaskBatchTranslator toBatchTranslator(Batchifier batchifier) { tokenizer.enableBatch(); - return new FillMaskBatchTranslator(tokenizer, maskToken, topK, batchifier); + return new FillMaskBatchTranslator( + tokenizer, maskToken, topK, includeTokenTypes, batchifier); } static 
int getMaskIndex(long[] indices, String maskToken, long maskTokenId) @@ -139,6 +146,7 @@ public static final class Builder { private HuggingFaceTokenizer tokenizer; private String maskedToken = "[MASK]"; private int topK = 5; + private boolean includeTokenTypes; private Batchifier batchifier = Batchifier.STACK; Builder(HuggingFaceTokenizer tokenizer) { @@ -167,6 +175,17 @@ public Builder optTopK(int topK) { return this; } + /** + * Sets if include token types for the {@link Translator}. + * + * @param includeTokenTypes true to include token types + * @return this builder + */ + public Builder optIncludeTokenTypes(boolean includeTokenTypes) { + this.includeTokenTypes = includeTokenTypes; + return this; + } + /** * Sets the {@link Batchifier} for the {@link Translator}. * @@ -186,6 +205,7 @@ public Builder optBatchifier(Batchifier batchifier) { public void configure(Map arguments) { optMaskToken(ArgumentsUtil.stringValue(arguments, "maskToken", "[MASK]")); optTopK(ArgumentsUtil.intValue(arguments, "topK", 5)); + optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack"); optBatchifier(Batchifier.fromString(batchifierStr)); } @@ -197,7 +217,8 @@ public void configure(Map arguments) { * @throws IOException if I/O error occurs */ public FillMaskTranslator build() throws IOException { - return new FillMaskTranslator(tokenizer, maskedToken, topK, batchifier); + return new FillMaskTranslator( + tokenizer, maskedToken, topK, includeTokenTypes, batchifier); } } } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationBatchTranslator.java index 6c9beda2531..c4252bbf48e 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationBatchTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationBatchTranslator.java @@ -32,12 +32,19 @@ public class TextClassificationBatchTranslator implements NoBatchifyTranslator { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier; private PretrainedConfig config; - TextClassificationBatchTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) { + TextClassificationBatchTranslator( + HuggingFaceTokenizer tokenizer, + boolean includeTokenTypes, + Batchifier batchifier, + PretrainedConfig config) { this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; this.batchifier = batchifier; + this.config = config; } /** {@inheritDoc} */ @@ -56,7 +63,7 @@ public NDList processInput(TranslatorContext ctx, String[] inputs) { Encoding[] encodings = tokenizer.batchEncode(inputs); NDList[] batch = new NDList[encodings.length]; for (int i = 0; i < encodings.length; ++i) { - batch[i] = encodings[i].toNDList(manager, false); + batch[i] = encodings[i].toNDList(manager, includeTokenTypes); } return batchifier.batchify(batch); } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslator.java index d624b69d700..48d190020e2 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslator.java +++ 
b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslator.java @@ -35,11 +35,14 @@ public class TextClassificationTranslator implements Translator { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier; private PretrainedConfig config; - TextClassificationTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) { + TextClassificationTranslator( + HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; this.batchifier = batchifier; } @@ -63,7 +66,7 @@ public void prepare(TranslatorContext ctx) throws IOException { @Override public NDList processInput(TranslatorContext ctx, String input) { Encoding encoding = tokenizer.encode(input); - return encoding.toNDList(ctx.getNDManager(), false); + return encoding.toNDList(ctx.getNDManager(), includeTokenTypes); } /** {@inheritDoc} */ @@ -76,7 +79,8 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) { @Override public TextClassificationBatchTranslator toBatchTranslator(Batchifier batchifier) { tokenizer.enableBatch(); - return new TextClassificationBatchTranslator(tokenizer, batchifier); + return new TextClassificationBatchTranslator( + tokenizer, includeTokenTypes, batchifier, config); } static Classifications toClassifications(PretrainedConfig config, NDList list) { @@ -127,12 +131,24 @@ public static Builder builder(HuggingFaceTokenizer tokenizer, Map arg public static final class Builder { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier = Batchifier.STACK; Builder(HuggingFaceTokenizer tokenizer) { this.tokenizer = tokenizer; } + /** + * Sets if include token types for the {@link Translator}. + * + * @param includeTokenTypes true to include token types + * @return this builder + */ + public Builder optIncludeTokenTypes(boolean includeTokenTypes) { + this.includeTokenTypes = includeTokenTypes; + return this; + } + /** * Sets the {@link Batchifier} for the {@link Translator}. * @@ -150,6 +166,7 @@ public Builder optBatchifier(Batchifier batchifier) { * @param arguments the model arguments */ public void configure(Map arguments) { + optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack"); optBatchifier(Batchifier.fromString(batchifierStr)); } @@ -161,7 +178,7 @@ public void configure(Map arguments) { * @throws IOException if I/O error occurs */ public TextClassificationTranslator build() throws IOException { - return new TextClassificationTranslator(tokenizer, batchifier); + return new TextClassificationTranslator(tokenizer, includeTokenTypes, batchifier); } } } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingBatchTranslator.java index 8f74a0c8bd9..13cf218821c 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingBatchTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingBatchTranslator.java @@ -21,6 +21,8 @@ import ai.djl.translate.NoBatchifyTranslator; import ai.djl.translate.TranslatorContext; +import java.util.Arrays; + /** The translator for Huggingface text embedding model. 
*/ public class TextEmbeddingBatchTranslator implements NoBatchifyTranslator { @@ -28,27 +30,30 @@ public class TextEmbeddingBatchTranslator implements NoBatchifyTranslator { private boolean normalize; private String pooling; private boolean includeTokenTypes; + private String dense; + private String denseActivation; + private String layerNorm; + private NDList denseModel; + private NDList layerNormModel; TextEmbeddingTranslator( HuggingFaceTokenizer tokenizer, Batchifier batchifier, String pooling, boolean normalize, - boolean includeTokenTypes) { + boolean includeTokenTypes, + String dense, + String denseActivation, + String layerNorm) { this.tokenizer = tokenizer; this.batchifier = batchifier; this.pooling = pooling; this.normalize = normalize; this.includeTokenTypes = includeTokenTypes; + this.dense = dense; + this.denseActivation = denseActivation; + this.layerNorm = layerNorm; } /** {@inheritDoc} */ @@ -56,6 +73,34 @@ public Batchifier getBatchifier() { return batchifier; } + /** {@inheritDoc} */ + @Override + public void prepare(TranslatorContext ctx) throws Exception { + NDManager manager = ctx.getPredictorManager().newSubManager(Device.cpu()); + if (dense != null) { + Path file = Paths.get(dense); + if (!file.isAbsolute()) { + file = ctx.getModel().getModelPath().resolve(file); + } + if (Files.exists(file)) { + try (InputStream is = Files.newInputStream(file)) { + denseModel = NDList.decode(manager, is); + } + } + } + if (layerNorm != null) { + Path file = Paths.get(layerNorm); + if (!file.isAbsolute()) { + file = ctx.getModel().getModelPath().resolve(file); + } + if (Files.exists(file)) { + try (InputStream is = Files.newInputStream(file)) { + layerNormModel = NDList.decode(manager, is); + } + } + } + } + /** {@inheritDoc} */ @Override public NDList processInput(TranslatorContext ctx, String input) { @@ -70,6 +115,25 @@ public float[] processOutput(TranslatorContext ctx, NDList list) { Encoding encoding = (Encoding) ctx.getAttachment("encoding"); NDManager manager = ctx.getNDManager(); NDArray embeddings = processEmbedding(manager, list, encoding, pooling); + embeddings = embeddings.toDevice(Device.cpu(), false); + if (denseModel != null) { + NDArray weight = denseModel.get("linear.weight"); + NDArray bias = denseModel.get("linear.bias"); + embeddings = embeddings.getNDArrayInternal().linear(embeddings, weight, bias).get(0); + if ("Tanh".equals(denseActivation)) { + embeddings = embeddings.tanh(); + } + } + if (layerNormModel != null) { + NDArray weight = layerNormModel.get("norm.weight"); + NDArray bias = layerNormModel.get("norm.bias"); + Shape shape = weight.getShape(); + embeddings = + embeddings + .getNDArrayInternal() + .layerNorm(embeddings, shape, weight, bias, 1e-5f) + .get(0); + } if (normalize) { embeddings = embeddings.normalize(2, 0); } @@ -81,7 +145,8 @@ public float[] processOutput(TranslatorContext ctx, NDList list) { @Override public TextEmbeddingBatchTranslator toBatchTranslator(Batchifier batchifier) { tokenizer.enableBatch(); - return new TextEmbeddingBatchTranslator(tokenizer, batchifier, pooling, normalize); + return new TextEmbeddingBatchTranslator( + tokenizer, batchifier, pooling, normalize, includeTokenTypes); } static NDArray processEmbedding( @@ -113,7 +178,7 @@ private static NDArray meanPool(NDArray embeddings, NDArray attentionMask, boole long[] shape = embeddings.getShape().getShape(); attentionMask = attentionMask.expandDims(-1).broadcast(shape); NDArray inputAttentionMaskSum = attentionMask.sum(AXIS); - NDArray clamp = 
inputAttentionMaskSum.clip(1e-9, 1e12); + NDArray clamp = inputAttentionMaskSum.clip(1e-9f, 1e12f); NDArray prod = embeddings.mul(attentionMask); NDArray sum = prod.sum(AXIS); if (sqrt) { @@ -175,6 +240,9 @@ public static final class Builder { private boolean normalize = true; private String pooling = "mean"; private boolean includeTokenTypes; + private String dense; + private String denseActivation; + private String layerNorm; Builder(HuggingFaceTokenizer tokenizer) { this.tokenizer = tokenizer; @@ -233,6 +301,39 @@ public Builder optIncludeTokenTypes(boolean includeTokenTypes) { return this; } + /** + * Sets the dense layer model file for the {@link Translator}. + * + * @param dense path to dense layer model file + * @return this builder + */ + public Builder optDense(String dense) { + this.dense = dense; + return this; + } + + /** + * Sets the dense activation function for the {@link Translator}. + * + * @param denseActivation path to dense layer + * @return this builder + */ + public Builder optDenseActivation(String denseActivation) { + this.denseActivation = denseActivation; + return this; + } + + /** + * Sets the LayerNorm model for the {@link Translator}. + * + * @param layerNorm path to LayerNorm model + * @return this builder + */ + public Builder optLayerNorm(String layerNorm) { + this.layerNorm = layerNorm; + return this; + } + /** * Configures the builder with the model arguments. * @@ -244,6 +345,9 @@ public void configure(Map arguments) { optNormalize(ArgumentsUtil.booleanValue(arguments, "normalize", true)); optPoolingMode(ArgumentsUtil.stringValue(arguments, "pooling", "mean")); optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); + optDense(ArgumentsUtil.stringValue(arguments, "dense")); + optDenseActivation(ArgumentsUtil.stringValue(arguments, "denseActivation")); + optLayerNorm(ArgumentsUtil.stringValue(arguments, "layerNorm")); } /** @@ -254,7 +358,14 @@ public void configure(Map arguments) { */ public TextEmbeddingTranslator build() throws IOException { return new TextEmbeddingTranslator( - tokenizer, batchifier, pooling, normalize, includeTokenTypes); + tokenizer, + batchifier, + pooling, + normalize, + includeTokenTypes, + dense, + denseActivation, + layerNorm); } } } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationBatchTranslator.java index 2ae45438ccd..f5fab3d7e75 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationBatchTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationBatchTranslator.java @@ -32,12 +32,19 @@ public class TokenClassificationBatchTranslator implements NoBatchifyTranslator { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier; private PretrainedConfig config; - TokenClassificationBatchTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) { + TokenClassificationBatchTranslator( + HuggingFaceTokenizer tokenizer, + boolean includeTokenTypes, + Batchifier batchifier, + PretrainedConfig config) { this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; this.batchifier = batchifier; + this.config = config; } /** {@inheritDoc} */ @@ -58,7 +65,7 @@ public NDList processInput(TranslatorContext ctx, String[] inputs) { ctx.setAttachment("encodings", encodings); NDList[] batch = new 
NDList[encodings.length]; for (int i = 0; i < encodings.length; ++i) { - batch[i] = encodings[i].toNDList(manager, false); + batch[i] = encodings[i].toNDList(manager, includeTokenTypes); } return batchifier.batchify(batch); } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationTranslator.java index b1106c244a6..c9390e12cc8 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationTranslator.java @@ -36,11 +36,14 @@ public class TokenClassificationTranslator implements Translator { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier; private PretrainedConfig config; - TokenClassificationTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) { + TokenClassificationTranslator( + HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; this.batchifier = batchifier; } @@ -65,7 +68,7 @@ public void prepare(TranslatorContext ctx) throws IOException { public NDList processInput(TranslatorContext ctx, String input) { Encoding encoding = tokenizer.encode(input); ctx.setAttachment("encoding", encoding); - return encoding.toNDList(ctx.getNDManager(), false); + return encoding.toNDList(ctx.getNDManager(), includeTokenTypes); } /** {@inheritDoc} */ @@ -79,7 +82,8 @@ public NamedEntity[] processOutput(TranslatorContext ctx, NDList list) { @Override public TokenClassificationBatchTranslator toBatchTranslator(Batchifier batchifier) { tokenizer.enableBatch(); - return new TokenClassificationBatchTranslator(tokenizer, batchifier); + return new TokenClassificationBatchTranslator( + tokenizer, includeTokenTypes, batchifier, config); } /** @@ -139,12 +143,24 @@ static NamedEntity[] toNamedEntities(Encoding encoding, NDList list, PretrainedC public static final class Builder { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier = Batchifier.STACK; Builder(HuggingFaceTokenizer tokenizer) { this.tokenizer = tokenizer; } + /** + * Sets if include token types for the {@link Translator}. + * + * @param includeTokenTypes true to include token types + * @return this builder + */ + public Builder optIncludeTokenTypes(boolean includeTokenTypes) { + this.includeTokenTypes = includeTokenTypes; + return this; + } + /** * Sets the {@link Batchifier} for the {@link Translator}. 
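Illustrative sketch (not taken from this patch): the includeTokenTypes flag threaded through the translators above is typically supplied as a model argument, which the builders pick up in configure(...) via ArgumentsUtil; when true, each Encoding.toNDList call also emits token_type_ids. This assumes the existing TokenClassificationTranslatorFactory and a placeholder model directory.

    import ai.djl.huggingface.translator.TokenClassificationTranslatorFactory;
    import ai.djl.modality.nlp.translator.NamedEntity;
    import ai.djl.repository.zoo.Criteria;

    import java.nio.file.Paths;

    public class NerCriteriaSketch {
        public static Criteria<String, NamedEntity[]> criteria() {
            return Criteria.builder()
                    .setTypes(String.class, NamedEntity[].class)
                    .optModelPath(Paths.get("/tmp/bert-ner"))    // placeholder model directory
                    .optTranslatorFactory(new TokenClassificationTranslatorFactory())
                    .optArgument("includeTokenTypes", "true")    // include token_type_ids in model inputs
                    .optArgument("batchifier", "stack")
                    .build();
        }
    }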
* @@ -162,6 +178,7 @@ public Builder optBatchifier(Batchifier batchifier) { * @param arguments the model arguments */ public void configure(Map arguments) { + optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack"); optBatchifier(Batchifier.fromString(batchifierStr)); } @@ -173,7 +190,7 @@ public void configure(Map arguments) { * @throws IOException if I/O error occurs */ public TokenClassificationTranslator build() throws IOException { - return new TokenClassificationTranslator(tokenizer, batchifier); + return new TokenClassificationTranslator(tokenizer, includeTokenTypes, batchifier); } } } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/zoo/HfModelZoo.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/zoo/HfModelZoo.java index 9ee8fc19cf8..f79a9a6090a 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/zoo/HfModelZoo.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/zoo/HfModelZoo.java @@ -54,7 +54,7 @@ public class HfModelZoo extends ModelZoo { private static final long ONE_DAY = Duration.ofDays(1).toMillis(); - private boolean initialized; + private volatile boolean initialized; // NOPMD HfModelZoo() {} @@ -86,13 +86,17 @@ public ModelLoader getModelLoader(String name) { private void init() { if (!initialized) { - Version version = new Version(Engine.getDjlVersion()); - addModels(NLP.FILL_MASK, version); - addModels(NLP.QUESTION_ANSWER, version); - addModels(NLP.TEXT_CLASSIFICATION, version); - addModels(NLP.TEXT_EMBEDDING, version); - addModels(NLP.TOKEN_CLASSIFICATION, version); - initialized = true; + synchronized (HfModelZoo.class) { + if (!initialized) { + Version version = new Version(Engine.getDjlVersion()); + addModels(NLP.FILL_MASK, version); + addModels(NLP.QUESTION_ANSWER, version); + addModels(NLP.TEXT_CLASSIFICATION, version); + addModels(NLP.TEXT_EMBEDDING, version); + addModels(NLP.TOKEN_CLASSIFICATION, version); + initialized = true; + } + } } } @@ -123,7 +127,7 @@ private Map> listModels(Application app) { if (Files.notExists(dir)) { Files.createDirectories(dir); } else if (!Files.isDirectory(dir)) { - logger.warn("Failed initialize cache directory: " + dir); + logger.warn("Failed initialize cache directory: {}", dir); return Collections.emptyMap(); } Type type = new TypeToken>>() {}.getType(); @@ -131,8 +135,7 @@ private Map> listModels(Application app) { Path file = dir.resolve("models.json"); if (Files.exists(file)) { long lastModified = Files.getLastModifiedTime(file).toMillis(); - if (Boolean.getBoolean("offline") - || System.currentTimeMillis() - lastModified < ONE_DAY) { + if (Utils.isOfflineMode() || System.currentTimeMillis() - lastModified < ONE_DAY) { try (Reader reader = Files.newBufferedReader(file)) { return JsonUtils.GSON.fromJson(reader, type); } diff --git a/extensions/tokenizers/src/main/python/fill_mask_converter.py b/extensions/tokenizers/src/main/python/fill_mask_converter.py index ab0b5f4447c..ff9de4bea2e 100644 --- a/extensions/tokenizers/src/main/python/fill_mask_converter.py +++ b/extensions/tokenizers/src/main/python/fill_mask_converter.py @@ -59,5 +59,6 @@ def encode_inputs(self, tokenizer): text = self.inputs.replace("[MASK]", tokenizer.mask_token) return tokenizer.encode_plus(text, return_tensors='pt') - def get_extra_arguments(self, hf_pipeline, model_id: str) -> dict: + def get_extra_arguments(self, hf_pipeline, model_id: str, + temp_dir: str) -> dict: return 
{"maskToken": hf_pipeline.tokenizer.mask_token} diff --git a/extensions/tokenizers/src/main/python/huggingface_converter.py b/extensions/tokenizers/src/main/python/huggingface_converter.py index f3b85c241ec..6f9bea2c884 100644 --- a/extensions/tokenizers/src/main/python/huggingface_converter.py +++ b/extensions/tokenizers/src/main/python/huggingface_converter.py @@ -10,7 +10,6 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. - import logging import os.path import shutil @@ -41,9 +40,20 @@ def save_model(self, model_info, args: Namespace, temp_dir: str): if not os.path.exists(temp_dir): os.makedirs(temp_dir) - hf_pipeline = self.load_model(model_id) - # Save tokenizer.json to temp dir - self.save_tokenizer(hf_pipeline, temp_dir) + try: + hf_pipeline = self.load_model(model_id) + except Exception as e: + logging.warning(f"Failed to load model: {model_id}.") + logging.warning(e, exc_info=True) + return False, "Failed to load model", -1 + + try: + # Save tokenizer.json to temp dir + self.save_tokenizer(hf_pipeline, temp_dir) + except Exception as e: + logging.warning(f"Failed to save tokenizer: {model_id}.") + logging.warning(e, exc_info=True) + return False, "Failed to save tokenizer", -1 # Save config.json just for reference config = hf_hub_download(repo_id=model_id, filename="config.json") @@ -112,7 +122,7 @@ def jit_trace_model(self, hf_pipeline, model_id: str, temp_dir: str, logging.info(f"Saving torchscript model: {model_name}.pt ...") model_file = os.path.join(temp_dir, f"{model_name}.pt") script_module.save(model_file) - except (RuntimeError, ValueError) as e: + except Exception as e: logging.warning(f"Failed to trace model: {model_id}.") logging.warning(e, exc_info=True) return None @@ -131,14 +141,15 @@ def save_to_model_zoo(self, model_info, output_dir: str, temp_dir: str, # Save serving.properties serving_file = os.path.join(temp_dir, "serving.properties") - arguments = self.get_extra_arguments(hf_pipeline, model_id) + arguments = self.get_extra_arguments(hf_pipeline, model_id, temp_dir) + if include_types: + arguments["includeTokenTypes"] = "true" + arguments["translatorFactory"] = self.translator + with open(serving_file, 'w') as f: f.write(f"engine=PyTorch\n" f"option.modelName={model_name}\n" - f"option.mapLocation=true\n" - f"translatorFactory={self.translator}\n") - if include_types: - f.write(f"includeTokenTypes={include_types}\n") + f"option.mapLocation=true\n") for k, v in arguments.items(): f.write(f"{k}={v}\n") @@ -149,10 +160,11 @@ def save_to_model_zoo(self, model_info, output_dir: str, temp_dir: str, zip_dir(temp_dir, zip_file) # Save metadata.json + arguments["engine"] = "PyTorch" sha1 = sha1_sum(zip_file) file_size = os.path.getsize(zip_file) - metadata = HuggingfaceMetadata(model_info, self.application, - self.translator, sha1, file_size) + metadata = HuggingfaceMetadata(model_info, self.application, sha1, + file_size, arguments) metadata_file = os.path.join(repo_dir, "metadata.json") metadata.save_metadata(metadata_file) @@ -194,7 +206,8 @@ def verify_jit_model(self, hf_pipeline, model_file: str, return self.verify_jit_output(hf_pipeline, encoding, out) - def get_extra_arguments(self, hf_pipeline, model_id: str) -> dict: + def get_extra_arguments(self, hf_pipeline, model_id: str, + temp_dir: str) -> dict: return {} def verify_jit_output(self, 
hf_pipeline, encoding, out): diff --git a/extensions/tokenizers/src/main/python/huggingface_models.py b/extensions/tokenizers/src/main/python/huggingface_models.py index 3418815d5c4..5b1c6debe5d 100644 --- a/extensions/tokenizers/src/main/python/huggingface_models.py +++ b/extensions/tokenizers/src/main/python/huggingface_models.py @@ -16,7 +16,7 @@ from argparse import Namespace from typing import List -from huggingface_hub import HfApi, ModelSearchArguments +from huggingface_hub import HfApi from huggingface_hub import hf_hub_download from huggingface_hub.hf_api import ModelInfo @@ -27,7 +27,7 @@ "ForMultipleChoice": "text-classification", "ForMaskedLM": "fill-mask", } -LANGUAGES = ModelSearchArguments().language +LANGUAGES = HfApi().get_model_tags()["language"] def get_lang_tags(model_info): @@ -56,23 +56,32 @@ def __init__(self, output_dir: str): self.temp_dir = f"{self.output_dir}/tmp" def list_models(self, args: Namespace) -> List[dict]: + import_all = os.environ.get("HF_IMPORT_ALL") + api = HfApi() if args.model_name: - models = api.list_models(filter="pytorch", - search=args.model_name, - sort="downloads", - direction=-1, - limit=args.limit) - if not models: - logging.warning(f"no model found: {args.model_name}.") + all_models = api.list_models(search=args.model_name, + sort="downloads", + direction=-1, + limit=args.limit) + import_all = True else: - models = api.list_models(filter=f"{args.category},pytorch", - sort="downloads", - direction=-1, - limit=args.limit) - if not models: + all_models = api.list_models(filter=args.category, + sort="downloads", + direction=-1, + limit=args.limit) + models = [ + model for model in all_models + if 'pytorch' in model.tags or 'safetensors' in model.tags + ] + if not models: + if args.model_name: + logging.warning(f"no model found: {args.model_name}.") + else: logging.warning(f"no model matches category: {args.category}.") + return [] + ret = [] for model_info in models: model_id = model_info.modelId @@ -83,7 +92,7 @@ def list_models(self, args: Namespace) -> List[dict]: continue languages = get_lang_tags(model_info) - if "en" not in languages: + if "en" not in languages and not import_all: logging.warning(f"Skip non-English model: {model_id}.") continue @@ -94,6 +103,12 @@ def list_models(self, args: Namespace) -> List[dict]: logging.info(f"Skip converted model: {model_id}.") continue + if model_info.downloads < 50 and not import_all: + logging.info( + f"Skip model {model_info.modelId}, downloads {model_info.downloads} < 50" + ) + continue + try: config = hf_hub_download(repo_id=model_id, filename="config.json") diff --git a/extensions/tokenizers/src/main/python/metadata.py b/extensions/tokenizers/src/main/python/metadata.py index b4b61f115d2..94982cdc30c 100644 --- a/extensions/tokenizers/src/main/python/metadata.py +++ b/extensions/tokenizers/src/main/python/metadata.py @@ -16,15 +16,15 @@ class HuggingfaceMetadata: - def __init__(self, model_info, application: str, translator: str, - sha1: str, file_size: int): + def __init__(self, model_info, application: str, sha1: str, file_size: int, + arguments: dict): self.model_info = model_info self.artifact_id = model_info.modelId self.model_name = model_info.modelId.split("/")[-1] self.application = application - self.translator = translator self.sha1 = sha1 self.file_size = file_size + self.arguments = arguments def save_metadata(self, metadata_file: str): properties = get_lang_tags(self.model_info) @@ -57,10 +57,7 @@ def save_metadata(self, metadata_file: str): "snapshot": False, "name": 
self.model_name, "properties": properties, - "arguments": { - "engine": "PyTorch", - "translatorFactory": self.translator - }, + "arguments": self.arguments, "options": { "mapLocation": True }, diff --git a/extensions/tokenizers/src/main/python/model_zoo_importer.py b/extensions/tokenizers/src/main/python/model_zoo_importer.py index 9ed32ec58ef..0ed67bd1018 100644 --- a/extensions/tokenizers/src/main/python/model_zoo_importer.py +++ b/extensions/tokenizers/src/main/python/model_zoo_importer.py @@ -49,9 +49,17 @@ def main(): model_info = model["model_info"] converter = SUPPORTED_TASK[task] - result, reason, size = converter.save_model(model_info, args, temp_dir) - if not result: - logging.error(f"{model_info.modelId}: {reason}") + try: + result, reason, size = converter.save_model( + model_info, args, temp_dir) + if not result: + logging.error(f"{model_info.modelId}: {reason}") + except Exception as e: + logging.warning(f"Failed to convert model: {model_info.modelId}.") + logging.warning(e, exc_info=True) + result = False + reason = "Failed to convert model" + size = -1 huggingface_models.update_progress(model_info, converter.application, result, reason, size, args.cpu_only) diff --git a/extensions/tokenizers/src/main/python/requirements.txt b/extensions/tokenizers/src/main/python/requirements.txt index bf197b644ea..05ce0bc4833 100644 --- a/extensions/tokenizers/src/main/python/requirements.txt +++ b/extensions/tokenizers/src/main/python/requirements.txt @@ -1,4 +1,4 @@ huggingface_hub transformers -torch==1.11.0 +torch protobuf==3.20.2 diff --git a/extensions/tokenizers/src/main/python/sentence_similarity_converter.py b/extensions/tokenizers/src/main/python/sentence_similarity_converter.py index c96975ab51b..7d9b02b59b8 100644 --- a/extensions/tokenizers/src/main/python/sentence_similarity_converter.py +++ b/extensions/tokenizers/src/main/python/sentence_similarity_converter.py @@ -13,6 +13,7 @@ import json import logging import os +import shutil import requests import torch @@ -57,25 +58,118 @@ def verify_jit_output(self, hf_pipeline, encoding, out): return True, None - def get_extra_arguments(self, hf_pipeline, model_id: str) -> dict: + def get_extra_arguments(self, hf_pipeline, model_id: str, + temp_dir: str) -> dict: args = {"padding": "true"} + for config_name in [ + 'sentence_bert_config.json', 'sentence_roberta_config.json', + 'sentence_distilbert_config.json', + 'sentence_camembert_config.json', + 'sentence_albert_config.json', + 'sentence_xlm-roberta_config.json', + 'sentence_xlnet_config.json' + ]: + try: + file = hf_hub_download(repo_id=model_id, filename=config_name) + with open(file) as f: + config = json.load(f) + if config.get("max_seq_length"): + args["maxLength"] = config.get("max_seq_length") + if config.get("do_lower_case"): + args["doLowerCase"] = config.get("do_lower_case") + + break + except requests.exceptions.HTTPError: + pass + + if not "maxLength" in args: + config = hf_pipeline.model.config + tokenizer = hf_pipeline.tokenizer + if hasattr(config, "max_position_embeddings") and hasattr( + tokenizer, "model_max_length"): + max_seq_length = min(config.max_position_embeddings, + tokenizer.model_max_length) + args["maxLength"] = str(max_seq_length) + + pooling_path = None + dense_path = None + layer_norm_path = None + normalize = False try: - file = hf_hub_download(repo_id=model_id, - filename="1_Pooling/config.json") - if os.path.exists(file): - with open(file, "r") as f: - pooling = json.load(f) - if pooling.get("pooling_mode_cls_token"): - args["pooling"] = "cls" - 
elif pooling.get("pooling_mode_max_tokens"): - args["pooling"] = "max" - elif pooling.get("pooling_mode_mean_sqrt_len_tokens"): - args["pooling"] = "mean_sqrt_len" - elif pooling.get("pooling_mode_weightedmean_tokens"): - args["pooling"] = "weightedmean" - elif pooling.get("pooling_mode_lasttoken"): - args["pooling"] = "lasttoken" + file = hf_hub_download(repo_id=model_id, filename="modules.json") + with open(file, "r") as f: + modules = json.load(f) + + for module in modules: + module_type = module.get("type") + if module_type == "sentence_transformers.models.Pooling": + pooling_path = module["path"] + elif module_type == "sentence_transformers.models.Dense": + dense_path = module["path"] + elif module_type == "sentence_transformers.models.LayerNorm": + layer_norm_path = module["path"] + elif module_type == "sentence_transformers.models.Normalize": + normalize = "true" + elif module_type != "sentence_transformers.models.Transformer": + logging.warning(f"Unexpected module: {module_type}.") except requests.exceptions.HTTPError: - logging.warning(f"{model_id}: 1_Pooling/config.json not found.") + logging.warning(f"{model_id}: modules.json not found.") + + if pooling_path: + try: + file = hf_hub_download(repo_id=model_id, + filename=f"{pooling_path}/config.json") + if os.path.exists(file): + with open(file, "r") as f: + pooling = json.load(f) + if pooling.get("pooling_mode_cls_token"): + args["pooling"] = "cls" + elif pooling.get("pooling_mode_max_tokens"): + args["pooling"] = "max" + elif pooling.get("pooling_mode_mean_sqrt_len_tokens"): + args["pooling"] = "mean_sqrt_len" + elif pooling.get("pooling_mode_weightedmean_tokens"): + args["pooling"] = "weightedmean" + elif pooling.get("pooling_mode_lasttoken"): + args["pooling"] = "lasttoken" + except requests.exceptions.HTTPError: + logging.warning( + f"{model_id}: {pooling_path}/config.json not found.") + + if dense_path: + try: + file = hf_hub_download(repo_id=model_id, + filename=f"{dense_path}/config.json") + with open(file, "r") as f: + dense = json.load(f) + activation = dense.get("activation_function") + if activation == "torch.nn.modules.activation.Tanh": + args["denseActivation"] = "Tanh" + elif activation != "torch.nn.modules.linear.Identity": + logging.warning( + f"Unexpected activation function: {activation}.") + self.save_module_weight(model_id, temp_dir, dense_path, + "linear") + args["dense"] = "linear.safetensors" + except requests.exceptions.HTTPError: + logging.debug(f"{model_id}: {dense_path} not found.") + + if layer_norm_path: + try: + self.save_module_weight(model_id, temp_dir, layer_norm_path, + "norm") + args["layerNorm"] = "norm.safetensors" + except requests.exceptions.HTTPError: + logging.warning(f"{model_id}: {layer_norm_path} not found.") + + if not normalize: + args["normalize"] = "false" return args + + @staticmethod + def save_module_weight(model_id: str, temp_dir: str, layer: str, + name: str): + file = hf_hub_download(repo_id=model_id, + filename=f"{layer}/model.safetensors") + shutil.copyfile(file, os.path.join(temp_dir, f"{name}.safetensors")) diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java new file mode 100644 index 00000000000..f3ee102e325 --- /dev/null +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java @@ -0,0 +1,204 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. 
All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.huggingface.tokenizers; + +import ai.djl.Model; +import ai.djl.ModelException; +import ai.djl.huggingface.translator.CrossEncoderTranslatorFactory; +import ai.djl.inference.Predictor; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.nn.Block; +import ai.djl.nn.LambdaBlock; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.TranslateException; +import ai.djl.util.JsonUtils; +import ai.djl.util.StringPair; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +public class CrossEncoderTranslatorTest { + + @Test + public void testCrossEncoderTranslator() + throws ModelException, IOException, TranslateException { + String text1 = "Sentence 1"; + String text2 = "Sentence 2"; + Block block = + new LambdaBlock( + a -> { + NDManager manager = a.getManager(); + NDArray array = manager.create(new float[] {-0.7329f}); + return new NDList(array); + }, + "model"); + Path modelDir = Paths.get("build/model"); + Files.createDirectories(modelDir); + + Criteria criteria = + Criteria.builder() + .setTypes(StringPair.class, float[].class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + StringPair input = new StringPair(text1, text2); + float[] res = predictor.predict(input); + Assert.assertEquals(res[0], 0.32456556f, 0.0001); + } + + Criteria criteria2 = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel model = criteria2.loadModel(); + Predictor predictor = model.newPredictor()) { + Input input = new Input(); + input.add("key", text1); + input.add("value", text2); + Output res = predictor.predict(input); + float[] buf = (float[]) res.getData().getAsObject(); + Assert.assertEquals(buf[0], 0.32455865, 0.0001); + + Assert.assertThrows(TranslateException.class, () -> predictor.predict(new Input())); + + Assert.assertThrows( + TranslateException.class, + () -> { + Input req = new Input(); + req.add("something", "false"); + predictor.predict(req); + }); + + Assert.assertThrows( + TranslateException.class, + () -> { + Input req = new Input(); + req.addProperty("Content-Type", "application/json"); + req.add("Invalid json"); + predictor.predict(req); + }); + + Assert.assertThrows( + TranslateException.class, + 
() -> { + Input req = new Input(); + req.addProperty("Content-Type", "application/json"); + req.add(JsonUtils.GSON.toJson(new StringPair(text1, null))); + predictor.predict(req); + }); + } + + try (Model model = Model.newInstance("test")) { + model.setBlock(block); + Map options = new HashMap<>(); + options.put("hasParameter", "false"); + model.load(modelDir, "test", options); + + CrossEncoderTranslatorFactory factory = new CrossEncoderTranslatorFactory(); + Map arguments = new HashMap<>(); + + Assert.assertThrows( + TranslateException.class, + () -> factory.newInstance(String.class, Integer.class, model, arguments)); + + arguments.put("tokenizer", "bert-base-cased"); + + Assert.assertThrows( + IllegalArgumentException.class, + () -> factory.newInstance(String.class, Integer.class, model, arguments)); + } + } + + @Test + public void testCrossEncoderBatchTranslator() + throws ModelException, IOException, TranslateException { + StringPair pair1 = new StringPair("Sentence 1", "Sentence 2"); + StringPair pair2 = new StringPair("Sentence 3", "Sentence 4"); + + Block block = + new LambdaBlock( + a -> { + NDManager manager = a.getManager(); + NDArray array = manager.create(new float[][] {{-0.7329f}, {-0.7329f}}); + return new NDList(array); + }, + "model"); + Path modelDir = Paths.get("build/model"); + Files.createDirectories(modelDir); + + Criteria criteria = + Criteria.builder() + .setTypes(StringPair[].class, float[][].class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + StringPair[] inputs = {pair1, pair2}; + float[][] res = predictor.predict(inputs); + Assert.assertEquals(res[1][0], 0.32455865, 0.0001); + } + + Criteria criteria2 = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel model = criteria2.loadModel(); + Predictor predictor = model.newPredictor()) { + Input input = new Input(); + input.add(JsonUtils.GSON.toJson(new StringPair[] {pair1, pair2})); + input.addProperty("Content-Type", "application/json"); + Output out = predictor.predict(input); + float[][] buf = (float[][]) out.getData().getAsObject(); + Assert.assertEquals(buf[0][0], 0.32455865, 0.0001); + } + } +} diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java index dcbef24748d..a3c2e51cbe1 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java @@ -26,6 +26,7 @@ import java.nio.file.Paths; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -40,7 +41,17 @@ public void testTokenizer() throws IOException { "[CLS]", "Hello", ",", "y", "'", "all", "!", "How", "are", "you", "[UNK]", "?", "[SEP]" }; - try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) { + try 
(HuggingFaceTokenizer tokenizer = + HuggingFaceTokenizer.builder() + .optTokenizerName("bert-base-cased") + .optTruncation(false) + .build()) { + Assert.assertEquals(tokenizer.getTruncation(), "DO_NOT_TRUNCATE"); + Assert.assertEquals(tokenizer.getPadding(), "DO_NOT_PAD"); + Assert.assertEquals(tokenizer.getMaxLength(), -1); + Assert.assertEquals(tokenizer.getStride(), 0); + Assert.assertEquals(tokenizer.getPadToMultipleOf(), 0); + List ret = tokenizer.tokenize(input); Assert.assertEquals(ret.toArray(Utils.EMPTY_ARRAY), expected); Encoding encoding = tokenizer.encode(input); @@ -115,6 +126,43 @@ public void testTokenizer() throws IOException { Assert.assertEquals(encodings.length, 2); Assert.assertEquals(encodings[0].getIds(), ids); } + + Assert.assertThrows( + () -> { + Path file = Paths.get("build/tokenizer/non-exists.json"); + HuggingFaceTokenizer.builder().optTokenizerPath(file).build(); + }); + } + + @Test + public void testDoLowerCase() throws IOException { + String input = "Hello, y'all! How are you 😁 ?"; + String[] inputs = {"Hello, y'all!", "How are you 😁 ?"}; + try (HuggingFaceTokenizer tokenizer = + HuggingFaceTokenizer.builder() + .optTokenizerName("bert-base-cased") + .optAddSpecialTokens(false) + .optDoLowerCase(true) + .build()) { + Encoding encoding = tokenizer.encode(inputs); + String sentence = tokenizer.buildSentence(Arrays.asList(encoding.getTokens())); + Assert.assertEquals(sentence, "hello , y ' all ! how are you [UNK] ?"); + + encoding = tokenizer.encode(input); + Assert.assertEquals(encoding.getTokens().length, 11); + + encoding = tokenizer.encode(input, "How are you my friend"); + Assert.assertEquals(encoding.getTokens().length, 16); + + Encoding[] encodings = tokenizer.batchEncode(inputs); + Assert.assertEquals(encodings.length, 2); + + PairList batch = new PairList<>(2); + batch.add("Hello", "How are you"); + batch.add("Hi, you all", "I'm fine."); + encodings = tokenizer.batchEncode(batch); + Assert.assertEquals(encodings.length, 2); + } } @Test @@ -200,7 +248,10 @@ public void testMaxModelLengthTruncationAndAllPaddings() throws IOException { stringBuilder.append(repeat); } List inputs = Arrays.asList(stringBuilder.toString(), "This is a short sentence"); - try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) { + Map options = new ConcurrentHashMap<>(); + options.put("tokenizer", "bert-base-cased"); + options.put("truncation", "false"); + try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(options).build()) { int[] expectedNumberOfIdsNoTruncationNoPadding = new int[] {numRepeats * 2 + 2, 7}; Encoding[] encodings = tokenizer.batchEncode(inputs); for (int i = 0; i < encodings.length; ++i) { @@ -209,10 +260,7 @@ public void testMaxModelLengthTruncationAndAllPaddings() throws IOException { } } - Map options = new ConcurrentHashMap<>(); - options.put("tokenizer", "bert-base-cased"); - options.put("truncation", "true"); - try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(options).build()) { + try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) { int[] expectedSize = new int[] {512, 7}; Encoding[] encodings = tokenizer.batchEncode(inputs); for (int i = 0; i < encodings.length; ++i) { @@ -220,8 +268,11 @@ public void testMaxModelLengthTruncationAndAllPaddings() throws IOException { } } - options.put("padding", "true"); - try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(options).build()) { + try (HuggingFaceTokenizer tokenizer = + 
HuggingFaceTokenizer.builder() + .optTokenizerName("bert-base-cased") + .optPadding(true) + .build()) { Encoding[] encodings = tokenizer.batchEncode(inputs); for (Encoding encoding : encodings) { Assert.assertEquals(encoding.getIds().length, 512); @@ -294,6 +345,7 @@ public void testTruncationStride() throws IOException { HuggingFaceTokenizer.builder() .optTokenizerName("bert-base-cased") .optAddSpecialTokens(false) + .optWithOverflowingTokens(true) .optTruncation(true) .optMaxLength(3) .optStride(1) @@ -316,13 +368,16 @@ public void testTruncationStride() throws IOException { HuggingFaceTokenizer.builder() .optTokenizerName("bert-base-cased") .optAddSpecialTokens(false) + .optWithOverflowingTokens(true) .optTruncation(true) .optMaxLength(8) .optStride(2) .build()) { String text = "Hello there my friend I am happy to see you"; String textPair = "How are you my friend"; - Encoding[] overflowing = tokenizer.encode(text, textPair).getOverflowing(); + Encoding encoding = tokenizer.encode(text, textPair); + Assert.assertTrue(encoding.exceedMaxLength()); + Encoding[] overflowing = encoding.getOverflowing(); int expectedNumberOfOverflowEncodings = 7; Assert.assertEquals(overflowing.length, expectedNumberOfOverflowEncodings); @@ -367,6 +422,7 @@ public void testTruncationAndPaddingForPairInputs() throws IOException { .optTokenizerName("bert-base-cased") .optTruncateSecondOnly() .optMaxLength(8) + .optDoLowerCase(Locale.ROOT.toLanguageTag()) .build()) { Encoding encoding = tokenizer.encode(text, textPair); Assert.assertEquals(encoding.getIds().length, 8); @@ -452,13 +508,13 @@ public void testBatchProcessing() throws IOException { Assert.assertEquals(outputs, outputsWithSpecialTokens); // encode with special tokens, decode with special tokens - encodings = tokenizer.batchEncode(inputs, true); + encodings = tokenizer.batchEncode(inputs, true, false); batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new); outputs = tokenizer.batchDecode(batchIds, false); Assert.assertEquals(outputs, outputsWithSpecialTokens); // encode without special tokens, decode without special tokens - encodings = tokenizer.batchEncode(inputs, false); + encodings = tokenizer.batchEncode(inputs, false, false); batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new); outputs = tokenizer.batchDecode(batchIds, true); Assert.assertEquals(outputs, outputsWithoutSpecialTokens); diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java index 91a96fd3ec8..b6ed53492a0 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java @@ -37,7 +37,9 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; public class TextEmbeddingTranslatorTest { @@ -157,6 +159,16 @@ public void testTextEmbeddingTranslator() float[] res = JsonUtils.GSON.fromJson(out.getAsString(0), float[].class); Assert.assertEquals(res.length, 384); Assertions.assertAlmostEquals(res[0], 0.05103); + + input = new Input(); + Map map = new HashMap<>(); + map.put("inputs", text); + input.add(JsonUtils.GSON.toJson(map)); + input.addProperty("Content-Type", "application/json"); + out = 
predictor.predict(input); + res = (float[]) out.getData().getAsObject(); + Assert.assertEquals(res.length, 384); + Assertions.assertAlmostEquals(res[0], 0.05103); } try (Model model = Model.newInstance("test")) { @@ -237,6 +249,81 @@ public void testTextEmbeddingBatchTranslator() float[][] res = (float[][]) out.getData().getAsObject(); Assert.assertEquals(res[0].length, 384); Assertions.assertAlmostEquals(res[0][0], 0.05103); + + input = new Input(); + Map map = new HashMap<>(); + map.put("inputs", text); + input.add(JsonUtils.GSON.toJson(map)); + input.addProperty("Content-Type", "application/json"); + out = predictor.predict(input); + res = (float[][]) out.getData().getAsObject(); + Assert.assertEquals(res[0].length, 384); + Assertions.assertAlmostEquals(res[0][0], 0.05103); + + Assert.assertThrows( + () -> { + Input empty = new Input(); + empty.add(JsonUtils.GSON.toJson(new HashMap<>())); + empty.addProperty("Content-Type", "application/json"); + predictor.predict(empty); + }); + + Assert.assertThrows( + () -> { + Input empty = new Input(); + empty.add("{ \"invalid json\""); + empty.addProperty("Content-Type", "application/json"); + predictor.predict(empty); + }); + } + } + + @Test + public void testTextEmbeddingTranslatorServingBatch() + throws ModelException, IOException, TranslateException { + String[] text = {"This is an example sentence", "This is the second sentence"}; + + Block block = + new LambdaBlock( + a -> { + NDManager manager = a.getManager(); + NDArray arr = manager.ones(new Shape(4, 7, 384)); + arr.setName("last_hidden_state"); + return new NDList(arr); + }, + "model"); + Path modelDir = Paths.get("build/model"); + Files.createDirectories(modelDir); + + Criteria criteria = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-uncased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new TextEmbeddingTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + Input input1 = new Input(); + input1.add(JsonUtils.GSON.toJson(text)); + input1.addProperty("Content-Type", "application/json"); + + Input input2 = new Input(); + Map map = new HashMap<>(); + map.put("inputs", text); + input2.add(JsonUtils.GSON.toJson(map)); + input2.addProperty("Content-Type", "application/json"); + List batchInput = Arrays.asList(input1, input2); + + List batchOutput = predictor.batchPredict(batchInput); + Assert.assertEquals(batchOutput.size(), 2); + float[][] res = (float[][]) batchOutput.get(0).getData().getAsObject(); + Assert.assertEquals(res[0].length, 384); + Assertions.assertAlmostEquals(res[0][0], 0.05103); } } } diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/zoo/ModelZooTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/zoo/ModelZooTest.java index e585a219f17..5c3ed6c3ed0 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/zoo/ModelZooTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/zoo/ModelZooTest.java @@ -103,21 +103,21 @@ public void testFutureVersion() throws IOException { @Test public void testOffLine() throws IOException { System.setProperty("DJL_CACHE_DIR", "build/cache"); - System.setProperty("offline", "true"); + System.setProperty("ai.djl.offline", "true"); try { Utils.deleteQuietly(Paths.get("build/cache")); // static variables cannot not be initialized properly if directly use new HfModelZoo() 
ModelZoo.getModelZoo("ai.djl.huggingface.pytorch"); ModelZoo zoo = new HfModelZoo(); - Assert.assertTrue(zoo.getModelLoaders().size() > 0); + Assert.assertFalse(zoo.getModelLoaders().isEmpty()); Set engines = zoo.getSupportedEngines(); Assert.assertEquals(engines.size(), 1); Assert.assertEquals(engines.iterator().next(), "PyTorch"); } finally { System.clearProperty("DJL_CACHE_DIR"); - System.clearProperty("offline"); + System.clearProperty("ai.djl.offline"); } } } diff --git a/gradle.properties b/gradle.properties index 23a6019761a..998189900e6 100644 --- a/gradle.properties +++ b/gradle.properties @@ -11,40 +11,41 @@ systemProp.org.gradle.internal.http.connectionTimeout=60000 # FIXME: Workaround gradle publish issue: https://github.com/gradle/gradle/issues/11308 systemProp.org.gradle.internal.publish.checksums.insecure=true -djl_version=0.24.0 +djl_version=0.28.0 mxnet_version=1.9.1 -pytorch_version=2.0.1 +pytorch_version=2.1.2 tensorflow_version=2.10.1 tflite_version=2.6.2 -trt_version=8.4.1 -onnxruntime_version=1.15.1 +trt_version=9.2.0 +onnxruntime_version=1.17.1 paddlepaddle_version=2.3.2 sentencepiece_version=0.1.97 -tokenizers_version=0.13.3 +tokenizers_version=0.15.2 +llamacpp_version=b1696 fasttext_version=0.9.2 -xgboost_version=1.7.5 +xgboost_version=2.0.3 lightgbm_version=3.2.110 rapis_version=22.12.0 -commons_cli_version=1.5.0 -commons_compress_version=1.23.0 +commons_cli_version=1.6.0 +commons_compress_version=1.26.1 commons_csv_version=1.10.0 -commons_logging_version=1.2 +commons_logging_version=1.3.1 gson_version=2.10.1 -jna_version=5.13.0 +jna_version=5.14.0 slf4j_version=1.7.36 -log4j_slf4j_version=2.20.0 -awssdk_version=2.20.121 -hadoop_version=3.3.5 +log4j_slf4j_version=2.23.1 +awssdk_version=2.25.17 +hadoop_version=3.3.6 javacpp_version=1.5.9 javacv_version=1.5.9 ffmpeg_version=6.0-1.5.9 -protobuf_version=3.23.3 +protobuf_version=3.25.3 tablesaw_version=0.43.1 spark_version=3.3.2 openpnp_opencv_version=4.7.0-0 antlr_version=4.11.1 -testng_version=7.8.0 +testng_version=7.9.0 junit_version=4.13.2 -mockito_version=5.3.1 +mockito_version=5.11.0 diff --git a/index1.0.html b/index1.0.html index 1a7d9841065..98b4a3f2911 100644 --- a/index1.0.html +++ b/index1.0.html @@ -59,7 +59,7 @@
  • JavaDoc
  • Demos
  • Blogs
-  • Tutorial
+  • Tutorial
  • Examples
  • Slack
@@ -73,7 +73,7 @@
  • JavaDoc
  • Demos
  • Blogs
-  • Tutorial
+  • Tutorial
  • Examples
  • Slack diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java index b5907925ee4..008d652dc82 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java @@ -31,6 +31,7 @@ import ai.djl.nn.LambdaBlock; import ai.djl.nn.SequentialBlock; import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelZoo; import ai.djl.repository.zoo.ZooModel; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.EasyTrain; @@ -123,10 +124,8 @@ private TrainingConfig setupTrainingConfig() { } private ZooModel getModel() throws IOException, ModelException { - // SSD-pikachu model only available in MXNet - // TODO: Add PyTorch model to model zoo - TestUtils.requiresEngine("MXNet"); - + TestUtils.requiresEngine( + ModelZoo.getModelZoo("ai.djl.zoo").getSupportedEngines().toArray(String[]::new)); Criteria criteria = Criteria.builder() .optApplication(Application.CV.OBJECT_DETECTION) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java index 410b4009a6d..04779187267 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java @@ -22,6 +22,7 @@ import org.testng.annotations.Test; import java.util.stream.DoubleStream; +import java.util.stream.IntStream; public class NDArrayNumericOpTest { @@ -499,6 +500,42 @@ public void testAtan() { } } + @Test + public void testAtan2() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + double[] x1 = {1.0, -1.0, -1.0, 0.0, 0.0, 0.0}; + NDArray array = manager.create(x1); + double[] y1 = {1.0, 0.0, -1.0, 1.0, -1.0, 0.0}; + NDArray other = manager.create(y1); + double[] output = + IntStream.range(0, x1.length) + .mapToDouble(i -> Math.atan2(x1[i], y1[i])) + .toArray(); + NDArray expected = manager.create(output); + Assertions.assertAlmostEquals(array.atan2(other), expected); + // test multi-dim + double[] x2 = {-1.0, -0.5, 0, 0.5, 1.0}; + array = manager.create(x2, new Shape(5, 1)); + double[] y2 = {-2.0, 3.0, 6.0, 0.0, -0.3}; + other = manager.create(y2, new Shape(5, 1)); + output = + IntStream.range(0, x2.length) + .mapToDouble(i -> Math.atan2(x2[i], y2[i])) + .toArray(); + expected = manager.create(output, new Shape(5, 1)); + Assertions.assertAlmostEquals(array.atan2(other), expected); + // test scalar + array = manager.create(0f); + other = manager.create(0f); + expected = manager.create(0f); + Assertions.assertAlmostEquals(array.atan2(other), expected); + // test zero-dim + array = manager.create(new Shape(1, 0)); + other = manager.create(new Shape(1, 0)); + Assert.assertEquals(array.atan2(other), array); + } + } + @Test public void testToDegrees() { try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java index 6788a405f22..66bb136ab37 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java 
+++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java @@ -875,6 +875,40 @@ public void testErfinv() { } } + @Test + public void testErf() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + // test 1-D + NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY}); + NDArray expected = manager.create(new float[] {0f, 0.5f, -1f}); + Assertions.assertAlmostEquals(NDArrays.erf(array), expected); + // test 3-D + array = + manager.create( + new float[] { + Float.NEGATIVE_INFINITY, + -0.8134f, + -0.4769f, + -0.2253f, + 0f, + 0.2253f, + 0.4769f, + 0.8134f, + Float.POSITIVE_INFINITY + }) + .reshape(3, 1, 3); + expected = manager.linspace(-1.0f, 1.0f, 9).reshape(3, 1, 3); + Assertions.assertAlmostEquals(array.erf(), expected); + // test scalar + array = manager.create(Float.POSITIVE_INFINITY); + expected = manager.create(1f); + Assertions.assertAlmostEquals(array.erf(), expected); + // test zero-dim + array = manager.create(new Shape(2, 0)); + Assertions.assertAlmostEquals(array.erf(), array); + } + } + @Test public void testInverse() { try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { @@ -1053,4 +1087,58 @@ public void testStft() { Assertions.assertAlmostEquals(result.real().flatten(), expected); } } + + @Test + public void testFft2() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + NDArray array = + manager.create( + new float[][] { + {1f, 6.6f, 4.315f, 2.0f}, + {16.9f, 6.697f, 2.399f, 67.9f}, + {0f, 5f, 67.09f, 9.87f} + }); + NDArray result = array.fft2(new long[] {3, 4}, new long[] {0, 1}); + result = result.real().flatten(1, 2); // flatten complex numbers + NDArray expected = + manager.create( + new float[][] { + {189.771f, 0f, -55.904f, 61.473f, -6.363f, 0f, -55.904f, -61.473f}, + { + -74.013f, + -10.3369f, + 71.7653f, + -108.2964f, + -1.746f, + 93.1133f, + -25.8063f, + -33.0234f + }, + { + -74.013f, 10.3369f, -25.8063f, 33.0234f, -1.746f, -93.1133f, + 71.7653f, 108.2964f + } + }); + Assertions.assertAlmostEquals(result, expected); + } + } + + @Test + public void testIfft2() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + NDArray array = + manager.create( + new float[][] { + {1f, 6.6f, 4.315f, 2.0f}, + {16.9f, 6.697f, 2.399f, 67.9f}, + {0f, 5f, 67.09f, 9.87f} + }); + long[] sizes = {3, 4}; + long[] axes = {0, 1}; + NDArray fft2 = array.fft2(sizes, axes); + NDArray actual = fft2.ifft2(sizes, axes).real(); + NDArray expected = array.toType(DataType.COMPLEX64, true).real(); + Assertions.assertAlmostEquals(expected, actual); + } + } } diff --git a/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java b/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java index ca680129062..3ace9c2bdf5 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java @@ -27,6 +27,7 @@ import org.testng.annotations.Test; import java.io.IOException; +import java.nio.file.Files; import java.nio.file.Paths; public class ModelTest { @@ -37,7 +38,9 @@ public void testModelSaveAndLoad() throws IOException, MalformedModelException { block.add(Conv2d.builder().setKernelShape(new Shape(1, 1)).setFilters(10).build()); block.add(BatchNorm.builder().build()); try (Model saveModel = Model.newInstance("saveModel", TestUtils.getEngine()); - Model loadModel = Model.newInstance("loadModel", 
TestUtils.getEngine())) { + Model loadModel = Model.newInstance("loadModel", TestUtils.getEngine()); + Model loadStreamModel = + Model.newInstance("loadStreamModel", TestUtils.getEngine()); ) { block.initialize(saveModel.getNDManager(), DataType.FLOAT32, new Shape(1, 3, 32, 32)); ParameterList savedParameters = block.getParameters(); saveModel.setBlock(block); @@ -48,6 +51,13 @@ public void testModelSaveAndLoad() throws IOException, MalformedModelException { loadModel.load(Paths.get("build/tmp/test/models"), "saveAndLoad"); ParameterList loadedParameters = loadModel.getBlock().getParameters(); compareParameters(savedParameters, loadedParameters); + + loadStreamModel.setBlock(block); + loadStreamModel.load( + Files.newInputStream( + Paths.get("build/tmp/test/models/saveAndLoad-0000.params"))); + loadedParameters = loadStreamModel.getBlock().getParameters(); + compareParameters(savedParameters, loadedParameters); } } diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java new file mode 100644 index 00000000000..9aee2661411 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java @@ -0,0 +1,190 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.integration.tests.training.listener; + +import ai.djl.Model; +import ai.djl.basicdataset.cv.classification.Mnist; +import ai.djl.basicmodelzoo.basic.Mlp; +import ai.djl.integration.util.TestUtils; +import ai.djl.metric.Metrics; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Activation; +import ai.djl.training.DefaultTrainingConfig; +import ai.djl.training.EasyTrain; +import ai.djl.training.Trainer; +import ai.djl.training.TrainingResult; +import ai.djl.training.dataset.Dataset; +import ai.djl.training.listener.EarlyStoppingListener; +import ai.djl.training.listener.TrainingListener; +import ai.djl.training.loss.Loss; +import ai.djl.training.optimizer.Optimizer; +import ai.djl.training.tracker.Tracker; +import ai.djl.translate.TranslateException; + +import org.testng.Assert; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.time.Duration; + +public class EarlyStoppingListenerTest { + + private final Optimizer sgd = + Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build(); + + private NDManager manager; + private Mnist testMnistDataset; + private Mnist trainMnistDataset; + + @BeforeTest + public void setUp() throws IOException, TranslateException { + manager = NDManager.newBaseManager(TestUtils.getEngine()); + testMnistDataset = + Mnist.builder() + .optUsage(Dataset.Usage.TEST) + .optManager(manager) + .optLimit(8) + .setSampling(8, false) + .build(); + testMnistDataset.prepare(); + + trainMnistDataset = + Mnist.builder() + .optUsage(Dataset.Usage.TRAIN) + .optManager(manager) + .optLimit(16) + .setSampling(8, false) + .build(); + trainMnistDataset.prepare(); + } + + @AfterTest + public void closeResources() { + manager.close(); + } + + @Test + public void testEarlyStoppingStopsOnEpoch2() throws Exception { + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); + + try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { + model.setBlock(mlpModel); + + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + .addTrainingListeners( + EarlyStoppingListener.builder() + .optEpochPatience(1) + .optEarlyStopPctImprovement(99) + .optMaxDuration(Duration.ofMinutes(1)) + .optMinEpochs(1) + .build()); + + try (Trainer trainer = model.newTrainer(config)) { + trainer.initialize(new Shape(1, 784)); + Metrics metrics = new Metrics(); + trainer.setMetrics(metrics); + + try { + // Set epoch to 5 as we expect the early stopping to stop after the second epoch + EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset); + } catch (EarlyStoppingListener.EarlyStoppedException e) { + Assert.assertEquals( + e.getMessage(), "failed to achieve 99.0% improvement 1 times in a row"); + Assert.assertEquals(e.getStopEpoch(), 2); + } + + TrainingResult trainingResult = trainer.getTrainingResult(); + Assert.assertEquals(trainingResult.getEpoch(), 2); + } + } + } + + @Test + public void testEarlyStoppingStopsOnEpoch3AsMinEpochsIs3() throws Exception { + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); + + try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { + model.setBlock(mlpModel); + + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + 
.addTrainingListeners( + EarlyStoppingListener.builder() + .optEpochPatience(1) + .optEarlyStopPctImprovement(50) + .optMaxMillis(60_000) + .optMinEpochs(3) + .build()); + + try (Trainer trainer = model.newTrainer(config)) { + trainer.initialize(new Shape(1, 784)); + Metrics metrics = new Metrics(); + trainer.setMetrics(metrics); + + try { + // Set epoch to 5 as we expect the early stopping to stop after the second epoch + EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset); + } catch (EarlyStoppingListener.EarlyStoppedException e) { + Assert.assertEquals( + e.getMessage(), "failed to achieve 50.0% improvement 1 times in a row"); + Assert.assertEquals(e.getStopEpoch(), 3); + } + + TrainingResult trainingResult = trainer.getTrainingResult(); + Assert.assertEquals(trainingResult.getEpoch(), 3); + } + } + } + + @Test + public void testEarlyStoppingStopsOnEpoch1AsMaxDurationIs1ms() throws Exception { + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); + + try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { + model.setBlock(mlpModel); + + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + .addTrainingListeners( + EarlyStoppingListener.builder().optMaxMillis(1).build()); + + try (Trainer trainer = model.newTrainer(config)) { + trainer.initialize(new Shape(1, 784)); + Metrics metrics = new Metrics(); + trainer.setMetrics(metrics); + + try { + // Set epoch to 10 as we expect the early stopping to stop after the second + // epoch + EasyTrain.fit(trainer, 10, trainMnistDataset, testMnistDataset); + } catch (EarlyStoppingListener.EarlyStoppedException e) { + Assert.assertTrue(e.getMessage().contains("ms elapsed >=")); + Assert.assertTrue(e.getStopEpoch() < 10); // Stop epoch is before 10 + } + + TrainingResult trainingResult = trainer.getTrainingResult(); + Assert.assertTrue(trainingResult.getEpoch() < 10); // Stop epoch is before 10 + } + } + } +} diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java new file mode 100644 index 00000000000..88680e5fe89 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests using the listeners {@link ai.djl.training}. 
*/ +package ai.djl.integration.tests.training.listener; diff --git a/jacoco/build.gradle b/jacoco/build.gradle index d9196393283..fa570c50a3b 100644 --- a/jacoco/build.gradle +++ b/jacoco/build.gradle @@ -10,6 +10,7 @@ repositories { dependencies { jacocoAggregation project(":api") jacocoAggregation project(":basicdataset") + jacocoAggregation project(":engines:llama") jacocoAggregation project(":engines:ml:xgboost") jacocoAggregation project(":engines:ml:lightgbm") jacocoAggregation project(":engines:mxnet:mxnet-engine") @@ -39,7 +40,9 @@ dependencies { jacocoAggregation project(":extensions:tokenizers") jacocoAggregation project(":extensions:tablesaw") jacocoAggregation project(":extensions:timeseries") - jacocoAggregation project(":extensions:spark") + if (JavaVersion.current() < JavaVersion.VERSION_19) { + jacocoAggregation project(":extensions:spark") + } jacocoAggregation project(":integration") jacocoAggregation project(":model-zoo") } diff --git a/jupyter/BERTQA.ipynb b/jupyter/BERTQA.ipynb deleted file mode 100644 index 4ec97cbd838..00000000000 --- a/jupyter/BERTQA.ipynb +++ /dev/null @@ -1,214 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# DJL BERT Inference Demo\n", - "\n", - "## Introduction\n", - "\n", - "In this tutorial, you walk through running inference using DJL on a [BERT](https://towardsdatascience.com/bert-explained-state-of-the-art-language-model-for-nlp-f8b21a9b6270) QA model trained with MXNet and PyTorch. \n", - "You can provide a question and a paragraph containing the answer to the model. The model is then able to find the best answer from the answer paragraph.\n", - "\n", - "Example:\n", - "```text\n", - "Q: When did BBC Japan start broadcasting?\n", - "```\n", - "\n", - "Answer paragraph:\n", - "```text\n", - "BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006.\n", - "It ceased operations after its Japanese distributor folded.\n", - "```\n", - "And it picked the right answer:\n", - "```text\n", - "A: December 2004\n", - "```\n", - "\n", - "One of the most powerful features of DJL is that it's engine agnostic. Because of this, you can run different backend engines seamlessly. We showcase BERT QA first with an MXNet pre-trained model, then with a PyTorch model." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. To install the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0\n", - "%maven ai.djl.pytorch:pytorch-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import java packages by running the following:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.engine.*;\n", - "import ai.djl.modality.nlp.qa.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.repository.zoo.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now that all of the prerequisites are complete, start writing code to run inference with this example.\n", - "\n", - "\n", - "## Load the model and input\n", - "\n", - "**First, load the input**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var question = \"When did BBC Japan start broadcasting?\";\n", - "var resourceDocument = \"BBC Japan was a general entertainment Channel.\\n\" +\n", - " \"Which operated between December 2004 and April 2006.\\n\" +\n", - " \"It ceased operations after its Japanese distributor folded.\";\n", - "\n", - "QAInput input = new QAInput(question, resourceDocument);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then load the model and vocabulary. Create a variable `model` by using the `ModelZoo` as shown in the following code." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria criteria = Criteria.builder()\n", - " .optApplication(Application.NLP.QUESTION_ANSWER)\n", - " .setTypes(QAInput.class, String.class)\n", - " .optEngine(\"MXNet\") // For DJL to use MXNet engine\n", - " .optProgress(new ProgressBar()).build();\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Run inference\n", - "Once the model is loaded, you can call `Predictor` and run inference as follows" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor predictor = model.newPredictor();\n", - "String answer = predictor.predict(input);\n", - "answer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Running inference on DJL is that easy. Now, let's try the PyTorch engine by specifying PyTorch engine in Criteria.optEngine(\"PyTorch\"). Let's rerun the inference code." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var question = \"When did BBC Japan start broadcasting?\";\n", - "var resourceDocument = \"BBC Japan was a general entertainment Channel.\\n\" +\n", - " \"Which operated between December 2004 and April 2006.\\n\" +\n", - " \"It ceased operations after its Japanese distributor folded.\";\n", - "\n", - "QAInput input = new QAInput(question, resourceDocument);\n", - "\n", - "Criteria criteria = Criteria.builder()\n", - " .optApplication(Application.NLP.QUESTION_ANSWER)\n", - " .setTypes(QAInput.class, String.class)\n", - " .optFilter(\"modelType\", \"distilbert\")\n", - " .optEngine(\"PyTorch\") // Use PyTorch engine\n", - " .optProgress(new ProgressBar()).build();\n", - "ZooModel model = criteria.loadModel();\n", - "Predictor predictor = model.newPredictor();\n", - "String answer = predictor.predict(input);\n", - "answer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "Suprisingly, there are no differences between the PyTorch code snippet and MXNet code snippet. \n", - "This is power of DJL. We define a unified API where you can switch to different backend engines on the fly.\n", - "Next chapter: Inference with your own BERT: [MXNet](mxnet/load_your_own_mxnet_bert.ipynb) [PyTorch](pytorch/load_your_own_pytorch_bert.ipynb)." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/Dockerfile b/jupyter/Dockerfile deleted file mode 100644 index 9c79ec3e54a..00000000000 --- a/jupyter/Dockerfile +++ /dev/null @@ -1,24 +0,0 @@ -FROM ubuntu:18.04 - -RUN apt-get update || true -RUN apt-get install -y openjdk-11-jdk-headless -RUN apt-get install -y python3-pip git -RUN pip3 install jupyter -RUN apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y locales \ - && sed -i -e 's/# en_US.UTF-8 UTF-8/en_US.UTF-8 UTF-8/' /etc/locale.gen \ - && dpkg-reconfigure --frontend=noninteractive locales \ - && update-locale LANG=en_US.UTF-8 -RUN apt-get install -y curl - -RUN git clone https://github.com/frankfliu/IJava.git -RUN cd IJava/ && ./gradlew installKernel && cd .. && rm -rf IJava/ -RUN rm -rf ~/.gradle - -WORKDIR /home/jupyter - -ENV LANG en_US.UTF-8 -ENV LC_ALL en_US.UTF-8 - -EXPOSE 8888 -ENTRYPOINT ["jupyter", "notebook", "--ip=0.0.0.0", "--no-browser", "--allow-root", "--NotebookApp.token=''", "--NotebookApp.password=''"] diff --git a/jupyter/README.md b/jupyter/README.md index 17b0a9c9405..1b9a2584238 100644 --- a/jupyter/README.md +++ b/jupyter/README.md @@ -1,83 +1,3 @@ # DJL - Jupyter notebooks -## Overview - -This folder contains tutorials that illustrate how to accomplish basic AI tasks with Deep Java Library (DJL). 
- -## [Beginner Tutorial](tutorial/README.md) - -## More Tutorial Notebooks - -- [Run object detection with model zoo](object_detection_with_model_zoo.ipynb) -- [Load pre-trained PyTorch model](load_pytorch_model.ipynb) -- [Load pre-trained Apache MXNet model](load_mxnet_model.ipynb) -- [Transfer learning example](transfer_learning_on_cifar10.ipynb) -- [Question answering example](BERTQA.ipynb) - -You can run our notebook online: [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/deepjavalibrary/djl/master?filepath=jupyter) - -## Setup - -### JDK 11 (not jre) - -JDK 11 (or above are required) to run the examples provided in this folder. - -to confirm the java path is configured properly: - -```bash -java --list-modules | grep "jdk.jshell" - -> jdk.jshell@12.0.1 -``` - -### Install jupyter notebook on python3 - -```bash -pip3 install jupyter -``` - -### Install IJava kernel for jupyter - -```bash -git clone https://github.com/frankfliu/IJava.git -cd IJava/ -./gradlew installKernel -``` - -## Start jupyter notebook - -```bash -jupyter notebook -``` - -## Docker setup - -You may want to use docker for simple installation or you are using Windows. - -### Run docker image - -```sh -cd jupyter -docker run -itd -p 127.0.0.1:8888:8888 -v $PWD:/home/jupyter deepjavalibrary/jupyter -``` - -You can open the `http://localhost:8888` to see the hosted instance on docker. - -### Build docker image by yourself - -You can read [Dockerfile](https://github.com/deepjavalibrary/djl/blob/master/jupyter/Dockerfile) for detail. To build docker image: - -```sh -cd jupyter -docker build -t deepjavalibrary/jupyter . -``` - -### Run docker compose - -```sh -cd jupyter -docker-compose build -docker-compose up -d -``` - -You can open the `http://localhost:8888` to see the hosted instance on docker compose. +The jupyter notebook documentation and examples have been moved to the [DJL Demos repo](http://docs.djl.ai/docs/demos/jupyter/index.html). \ No newline at end of file diff --git a/jupyter/docker-compose.yml b/jupyter/docker-compose.yml deleted file mode 100644 index e8e4d2f83b8..00000000000 --- a/jupyter/docker-compose.yml +++ /dev/null @@ -1,12 +0,0 @@ -version: "2.4" -services: - deepjavalibrary_container: - build: - context: . - dockerfile: Dockerfile - ports: - - 8888:8888 - volumes: - - ./:/home/jupyter - restart: always - diff --git a/jupyter/load_mxnet_model.ipynb b/jupyter/load_mxnet_model.ipynb deleted file mode 100644 index f90091d1ef4..00000000000 --- a/jupyter/load_mxnet_model.ipynb +++ /dev/null @@ -1,190 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load MXNet model\n", - "\n", - "In this tutorial, you learn how to load an existing MXNet model and use it to run a prediction task.\n", - "\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. For more information on installing the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:model-zoo:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.awt.image.*;\n", - "import java.nio.file.*;\n", - "import ai.djl.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.modality.cv.translator.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1: Prepare your MXNet model\n", - "\n", - "This tutorial assumes that you have a MXNet model trained using Python. A MXNet symbolic model usually contains the following files:\n", - "* Symbol file: {MODEL_NAME}-symbol.json - a json file that contains network information about the model\n", - "* Parameters file: {MODEL_NAME}-{EPOCH}.params - a binary file that stores the parameter weight and bias\n", - "* Synset file: synset.txt - an optional text file that stores classification classes labels\n", - "\n", - "This tutorial uses a pre-trained MXNet `resnet18_v1` model." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We use `DownloadUtils` for downloading files from internet." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-symbol.json\", \"build/resnet/resnet18_v1-symbol.json\", new ProgressBar());\n", - "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-0000.params.gz\", \"build/resnet/resnet18_v1-0000.params\", new ProgressBar());\n", - "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset.txt\", \"build/resnet/synset.txt\", new ProgressBar());\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2: Load your model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Path modelDir = Paths.get(\"build/resnet\");\n", - "Model model = Model.newInstance(\"resnet\");\n", - "model.load(modelDir, \"resnet18_v1\");" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3: Create a `Translator`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Pipeline pipeline = new Pipeline();\n", - "pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());\n", - "Translator translator = ImageClassificationTranslator.builder()\n", - " .setPipeline(pipeline)\n", - " .optSynsetArtifactName(\"synset.txt\")\n", - " .optApplySoftmax(true)\n", - " .build();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 4: Load image for classification" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - 
"outputs": [], - "source": [ - "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/kitten.jpg\");\n", - "img.getWrappedImage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 5: Run inference" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor predictor = model.newPredictor(translator);\n", - "Classifications classifications = predictor.predict(img);\n", - "\n", - "classifications" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Now, you can load any MXNet symbolic model and run inference.\n", - "\n", - "You might also want to check out [load_pytorch_model.ipynb](https://github.com/deepjavalibrary/djl/blob/master/jupyter/load_pytorch_model.ipynb) which demonstrates loading a local model using the ModelZoo API." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/load_pytorch_model.ipynb b/jupyter/load_pytorch_model.ipynb deleted file mode 100644 index bf4e3db3e3f..00000000000 --- a/jupyter/load_pytorch_model.ipynb +++ /dev/null @@ -1,232 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "# Load PyTorch model\n", - "\n", - "In this tutorial, you learn how to load an existing PyTorch model and use it to run a prediction task.\n", - "\n", - "We will run the inference in DJL way with [example](https://pytorch.org/hub/pytorch_vision_resnet/) on the pytorch official website.\n", - "\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. For more information on installing the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.nio.file.*;\n", - "import java.awt.image.*;\n", - "import ai.djl.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.modality.cv.translator.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1: Prepare your model\n", - "\n", - "This tutorial assumes that you have a TorchScript model.\n", - "DJL only supports the TorchScript format for loading models from PyTorch, so other models will need to be [converted](https://github.com/deepjavalibrary/djl/blob/master/docs/pytorch/how_to_convert_your_model_to_torchscript.md).\n", - "A TorchScript model includes the model structure and all of the parameters.\n", - "\n", - "We will be using a pre-trained `resnet18` model. 
First, use the `DownloadUtils` to download the model files and save them in the `build/pytorch_models` folder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz\", \"build/pytorch_models/resnet18/resnet18.pt\", new ProgressBar());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In order to do image classification, you will also need the synset.txt which stores the classification class labels. We will need the synset containing the Imagenet labels with which resnet18 was originally trained." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt\", \"build/pytorch_models/resnet18/synset.txt\", new ProgressBar());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2: Create a Translator\n", - "\n", - "We will create a transformation pipeline which maps the transforms shown in the [PyTorch example](https://pytorch.org/hub/pytorch_vision_resnet/).\n", - "```python\n", - "...\n", - "preprocess = transforms.Compose([\n", - " transforms.Resize(256),\n", - " transforms.CenterCrop(224),\n", - " transforms.ToTensor(),\n", - " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", - "])\n", - "...\n", - "```\n", - "\n", - "Then, we will use this pipeline to create the [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Translator translator = ImageClassificationTranslator.builder()\n", - " .addTransform(new Resize(256))\n", - " .addTransform(new CenterCrop(224, 224))\n", - " .addTransform(new ToTensor())\n", - " .addTransform(new Normalize(\n", - " new float[] {0.485f, 0.456f, 0.406f},\n", - " new float[] {0.229f, 0.224f, 0.225f}))\n", - " .optApplySoftmax(true)\n", - " .build();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3: Load your model\n", - "\n", - "Next, we add some search criteria to find the resnet18 model and load it. In this case, we need to tell `Criteria` where to locate the model by calling `.optModelPath()` API." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria criteria = Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelPath(Paths.get(\"build/pytorch_models/resnet18\"))\n", - " .optOption(\"mapLocation\", \"true\") // this model requires mapLocation for GPU\n", - " .optTranslator(translator)\n", - " .optProgress(new ProgressBar()).build();\n", - "\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 4: Load image for classification\n", - "\n", - "We will use a sample dog image to run our prediction on." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var img = ImageFactory.getInstance().fromUrl(\"https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg\");\n", - "img.getWrappedImage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 5: Run inference\n", - "\n", - "Lastly, we will need to create a predictor using our model and translator. Once we have a predictor, we simply need to call the predict method on our test image." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor predictor = model.newPredictor();\n", - "Classifications classifications = predictor.predict(img);\n", - "\n", - "classifications" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Now, you can load any TorchScript model and run inference using it.\n", - "\n", - "You might also want to check out [load_mxnet_model.ipynb](https://github.com/deepjavalibrary/djl/blob/master/jupyter/load_mxnet_model.ipynb) which demonstrates loading a local model directly instead of through the Model Zoo API." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/mxnet/load_your_own_mxnet_bert.ipynb b/jupyter/mxnet/load_your_own_mxnet_bert.ipynb deleted file mode 100644 index 9691a4d683a..00000000000 --- a/jupyter/mxnet/load_your_own_mxnet_bert.ipynb +++ /dev/null @@ -1,485 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load your own MXNet BERT model\n", - "\n", - "In the previous [example](../BERTQA.ipynb), you run BERT inference with the model from Model Zoo. You can also load the model on your own pre-trained BERT and use custom classes as the input and output.\n", - "\n", - "In general, the MXNet BERT model requires these three inputs:\n", - "\n", - "- word indices: The index of each word in a sentence\n", - "- word types: The type index of the word.\n", - "- valid length: The actual length of the question and resource document tokens\n", - "\n", - "We will dive deep into these details later." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. To install the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "There are dependencies we will use." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import java packages" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;\n", - "import java.util.stream.*;\n", - "\n", - "import ai.djl.*;\n", - "import ai.djl.util.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.modality.nlp.*;\n", - "import ai.djl.modality.nlp.qa.*;\n", - "import ai.djl.mxnet.zoo.nlp.qa.*;\n", - "import ai.djl.modality.nlp.bert.*;\n", - "\n", - "import com.google.gson.annotations.SerializedName;\n", - "import java.nio.charset.StandardCharsets;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Reuse the previous input**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var question = \"When did BBC Japan start broadcasting?\";\n", - "var resourceDocument = \"BBC Japan was a general entertainment Channel.\\n\" +\n", - " \"Which operated between December 2004 and April 2006.\\n\" +\n", - " \"It ceased operations after its Japanese distributor folded.\";\n", - "\n", - "QAInput input = new QAInput(question, resourceDocument);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dive deep into Translator\n", - "\n", - "Inference in deep learning is the process of predicting the output for a given input based on a pre-defined model.\n", - "DJL abstracts away the whole process for ease of use. It can load the model, perform inference on the input, and provide\n", - "output. DJL also allows you to provide user-defined inputs. The workflow looks like the following:\n", - "\n", - "![https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n", - "\n", - "The red block (\"Images\") in the workflow is the input that DJL expects from you. The green block (\"Images\n", - "bounding box\") is the output that you expect. Because DJL does not know which input to expect and which output format that you prefer, DJL provides the [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) interface so you can define your own\n", - "input and output.\n", - "\n", - "The `Translator` interface encompasses the two white blocks: Pre-processing and Post-processing. The pre-processing\n", - "component converts the user-defined input objects into an NDList, so that the [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html) in DJL can understand the\n", - "input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the\n", - "`Predictor`. The post-processing block allows you to convert the output from the `Predictor` to the desired output\n", - "format." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Pre-processing\n", - "\n", - "Now, you need to convert the sentences into tokens. We provide a powerful tool [`BertTokenizer`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/bert/BertTokenizer.html) that you can use to convert questions and answers into tokens, and batchify your sequence together. Once you have properly formatted tokens, you can use [`Vocabulary`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/Vocabulary.html) to map your token to BERT index.\n", - "\n", - "The following code block demonstrates tokenizing the question and answer defined earlier into BERT-formatted tokens." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var tokenizer = new BertTokenizer();\n", - "List tokenQ = tokenizer.tokenize(question.toLowerCase());\n", - "List tokenA = tokenizer.tokenize(resourceDocument.toLowerCase());\n", - "\n", - "System.out.println(\"Question Token: \" + tokenQ);\n", - "System.out.println(\"Answer Token: \" + tokenA);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`BertTokenizer` can also help you batchify questions and resource documents together by calling `encode()`.\n", - "The output contains information that BERT ingests.\n", - "\n", - "- getTokens: It returns a list of strings, including the question, resource document and special word to let the model tell which part is the question and which part is the resource document. Because MXNet BERT was trained with a fixed sequence length, you see the `[PAD]` in the tokens as well.\n", - "- getTokenTypes: It returns a list of type indices of the word to indicate the location of the resource document. All Questions will be labelled with 0 and all resource documents will be labelled with 1.\n", - "\n", - " [Question tokens...DocResourceTokens...padding tokens] => [000000...11111....0000]\n", - " \n", - "\n", - "- getValidLength: It returns the actual length of the question and tokens, which are required by MXNet BERT.\n", - "- getAttentionMask: It returns the mask for the model to indicate which part should be paid attention to and which part is the padding. It is required by PyTorch BERT.\n", - "\n", - " [Question tokens...DocResourceTokens...padding tokens] => [111111...11111....0000]\n", - " \n", - "MXNet BERT was trained with fixed sequence length 384, so we need to pass that in when we encode the question and resource doc. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BertToken token = tokenizer.encode(question.toLowerCase(), resourceDocument.toLowerCase(), 384);\n", - "System.out.println(\"Encoded tokens: \" + token.getTokens());\n", - "System.out.println(\"Encoded token type: \" + token.getTokenTypes());\n", - "System.out.println(\"Valid length: \" + token.getValidLength());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Normally, words and sentences are represented as indices instead of tokens for training. \n", - "They typically work like a vector in a n-dimensional space. 
In this case, you need to map them into indices.\n", - "DJL provides `Vocabulary` to take care of you vocabulary mapping.\n", - "\n", - "Assume your vocab.json is of the following format\n", - "```\n", - "{'token_to_idx':{'\"slots\": 19832,...}, 'idx_to_token':[\"[UNK]\", \"[PAD]\", ...]}\n", - "```\n", - "We provide the `vocab.json` from our pre-trained BERT for demonstration." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/mxnet/bertqa/vocab.json\", \"build/mxnet/bertqa/vocab.json\", new ProgressBar());" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class VocabParser {\n", - " @SerializedName(\"idx_to_token\")\n", - " List idx2token;\n", - "\n", - " public static List parseToken(URL file) {\n", - " try (InputStream is = file.openStream();\n", - " Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {\n", - " return JsonUtils.GSON.fromJson(reader, VocabParser.class).idx2token;\n", - " } catch (IOException e) {\n", - " throw new IllegalArgumentException(\"Invalid url: \" + file, e);\n", - " }\n", - " }\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "URL url = Paths.get(\"build/mxnet/bertqa/vocab.json\").toUri().toURL();\n", - "var vocabulary = DefaultVocabulary.builder()\n", - " .optMinFrequency(1)\n", - " .addFromCustomizedFile(url, VocabParser::parseToken)\n", - " .optUnknownToken(\"[UNK]\")\n", - " .build();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can easily convert the token to the index using `vocabulary.getIndex(token)` and the other way around using `vocabulary.getToken(index)`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "long index = vocabulary.getIndex(\"car\");\n", - "String token = vocabulary.getToken(2482);\n", - "System.out.println(\"The index of the car is \" + index);\n", - "System.out.println(\"The token of the index 2482 is \" + token);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To properly convert them into `float[]` for `NDArray` creation, use the following helper function:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "/**\n", - " * Convert a List of Number to float array.\n", - " *\n", - " * @param list the list to be converted\n", - " * @return float array\n", - " */\n", - "public static float[] toFloatArray(List list) {\n", - " float[] ret = new float[list.size()];\n", - " int idx = 0;\n", - " for (Number n : list) {\n", - " ret[idx++] = n.floatValue();\n", - " }\n", - " return ret;\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now that you have everything you need, you can create an NDList and populate all of the inputs you formatted earlier. You're done with pre-processing! \n", - "\n", - "#### Construct `Translator`\n", - "\n", - "You need to do this processing within an implementation of the `Translator` interface. `Translator` is designed to do pre-processing and post-processing. You must define the input and output objects. 
It contains the following two methods to override:\n", - "- `public NDList processInput(TranslatorContext ctx, I)`\n", - "- `public String processOutput(TranslatorContext ctx, O)`\n", - "\n", - "Every translator takes in input and returns output in the form of generic objects. In this case, the translator takes input in the form of `QAInput` (I) and returns output as a `String` (O). `QAInput` is just an object that holds the question and the resource document; we have prepared this input class for you." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Armed with the needed knowledge, you can write an implementation of the `Translator` interface. `BertTranslator` uses the code snippets explained previously to implement the `processInput` method. For more information, see [`NDManager`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/ndarray/NDManager.html).\n", - "\n", - "```\n", - "manager.create(Number[] data, Shape)\n", - "manager.create(Number[] data)\n", - "```\n", - "\n", - "The `Shape` for `data0` and `data1` is sequence_length. For `data2` the `Shape` is just 1." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "public class BertTranslator implements NoBatchifyTranslator<QAInput, String> {\n", - "    private List<String> tokens;\n", - "    private Vocabulary vocabulary;\n", - "    private BertTokenizer tokenizer;\n", - "    \n", - "    @Override\n", - "    public void prepare(TranslatorContext ctx) throws IOException {\n", - "        URL path = Paths.get(\"build/mxnet/bertqa/vocab.json\").toUri().toURL();\n", - "        vocabulary =\n", - "                DefaultVocabulary.builder()\n", - "                        .optMinFrequency(1)\n", - "                        .addFromCustomizedFile(path, VocabParser::parseToken)\n", - "                        .optUnknownToken(\"[UNK]\")\n", - "                        .build();\n", - "        tokenizer = new BertTokenizer();\n", - "    }\n", - "    \n", - "    @Override\n", - "    public NDList processInput(TranslatorContext ctx, QAInput input) {\n", - "        BertToken token =\n", - "                tokenizer.encode(\n", - "                        input.getQuestion().toLowerCase(),\n", - "                        input.getParagraph().toLowerCase(),\n", - "                        384);\n", - "        // get the encoded tokens that will be used in processOutput\n", - "        tokens = token.getTokens();\n", - "        // map the tokens (String) to indices (long)\n", - "        List<Long> indices =\n", - "                token.getTokens().stream().map(vocabulary::getIndex).collect(Collectors.toList());\n", - "        float[] indexesFloat = toFloatArray(indices);\n", - "        float[] types = toFloatArray(token.getTokenTypes());\n", - "        int validLength = token.getValidLength();\n", - "\n", - "        NDManager manager = ctx.getNDManager();\n", - "        NDArray data0 = manager.create(indexesFloat);\n", - "        data0.setName(\"data0\");\n", - "        NDArray data1 = manager.create(types);\n", - "        data1.setName(\"data1\");\n", - "        NDArray data2 = manager.create(new float[] {validLength});\n", - "        data2.setName(\"data2\");\n", - "        return new NDList(data0, data1, data2);\n", - "    }\n", - "\n", - "    @Override\n", - "    public String processOutput(TranslatorContext ctx, NDList list) {\n", - "        NDArray array = list.singletonOrThrow();\n", - "        NDList output = array.split(2, 2);\n", - "        // Get the formatted logits result\n", - "        NDArray startLogits = output.get(0).reshape(new Shape(1, -1));\n", - "        NDArray endLogits = output.get(1).reshape(new Shape(1, -1));\n", - "        int startIdx = (int) startLogits.argMax(1).getLong();\n", - "        int endIdx = (int) endLogits.argMax(1).getLong();\n", - "        return tokens.subList(startIdx, endIdx + 1).toString();\n", - "    }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Congrats! 
You have created your first Translator! We have pre-filled the `processOutput()` function to process the `NDList` and return it in a desired format. `processInput()` and `processOutput()` offer the flexibility to get the predictions from the model in any format you desire. \n", - "\n", - "With the Translator implemented, you need to bring up the predictor that uses your `Translator` to start making predictions. You can find the usage for `Predictor` in the [Predictor Javadoc](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html). Create a translator and use the `question` and `resourceDocument` provided previously." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/mxnet/bertqa/0.0.1/static_bert_qa-symbol.json\", \"build/mxnet/bertqa/bertqa-symbol.json\", new ProgressBar());\n", - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/mxnet/bertqa/0.0.1/static_bert_qa-0002.params.gz\", \"build/mxnet/bertqa/bertqa-0000.params\", new ProgressBar());" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BertTranslator translator = new BertTranslator();\n", - "Criteria criteria = Criteria.builder()\n", - " .setTypes(QAInput.class, String.class)\n", - " .optModelPath(Paths.get(\"build/mxnet/bertqa/\")) // Search for models in the build/mxnet/bert folder\n", - " .optTranslator(translator)\n", - " .optProgress(new ProgressBar()).build();\n", - "\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String predictResult = null;\n", - "QAInput input = new QAInput(question, resourceDocument);\n", - "\n", - "// Create a Predictor and use it to predict the output\n", - "try (Predictor predictor = model.newPredictor(translator)) {\n", - " predictResult = predictor.predict(input);\n", - "}\n", - "\n", - "System.out.println(question);\n", - "System.out.println(predictResult);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Based on the input, the following result will be shown:\n", - "```\n", - "[december, 2004]\n", - "```\n", - "That's it! \n", - "\n", - "You can try with more questions and answers. Here are the samples:\n", - "\n", - "**Answer Material**\n", - "\n", - "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. 
The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.\n", - "\n", - "\n", - "**Question**\n", - "\n", - "Q: When were the Normans in Normandy?\n", - "A: 10th and 11th centuries\n", - "\n", - "Q: In what country is Normandy located?\n", - "A: france\n", - "\n", - "For the full source code,see the [DJL repo](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java) and translator implementation [MXNet](https://github.com/deepjavalibrary/djl/blob/master/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/qa/MxBertQATranslator.java) [PyTorch](https://github.com/deepjavalibrary/djl/blob/master/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/qa/PtBertQATranslator.java)." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/object_detection_with_model_zoo.ipynb b/jupyter/object_detection_with_model_zoo.ipynb deleted file mode 100644 index 9435b9de7aa..00000000000 --- a/jupyter/object_detection_with_model_zoo.ipynb +++ /dev/null @@ -1,159 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Object detection with model zoo model\n", - "\n", - "In this tutorial, you learn how to use a built-in model zoo model (SSD) to achieve an [object detection](https://en.wikipedia.org/wiki/Object_detection) task.\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. To install Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.output.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.mxnet.zoo.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.training.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1: Load image" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/dog_bike_car.jpg\");\n", - "img.getWrappedImage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2: Load model zoo model\n", - "\n", - "In this example, you load a SSD (Single Shot MultiBox Detector) model from the MXNet model zoo.\n", - "For more information about model zoo, see the [Model Zoo Documentation](https://github.com/deepjavalibrary/djl/blob/master/docs/model-zoo.md) " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria = Criteria.builder()\n", - " .setTypes(Image.class, DetectedObjects.class)\n", - " .optArtifactId(\"ssd\")\n", - " .optProgress(new ProgressBar())\n", - " .build();\n", - "var model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3: Create Predictor and detect an object in the image" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var detections = model.newPredictor().predict(img);\n", - "\n", - "detections" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Check detected result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "img.drawBoundingBoxes(detections);\n", - "img.getWrappedImage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Using the model zoo model provided, you can run inference with just the following lines of code:\n", - "\n", - "```\n", - "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/dog_bike_car.jpg\");\n", - "var criteria = Criteria.builder()\n", - " .setTypes(Image.class, DetectedObjects.class)\n", - " .optArtifactId(\"ssd\")\n", - " .build();\n", - "var model = criteria.loadModel();\n", - "var detections = model.newPredictor().predict(img);\n", - "```\n", - "\n", - "You can find full SsdExample source code [here](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/object_detection.md).\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb 
b/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb deleted file mode 100644 index d068a97e78b..00000000000 --- a/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb +++ /dev/null @@ -1,224 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Classification on Iris dataset with sklearn and DJL\n", - "\n", - "In this notebook, you will try to use a pre-trained sklearn model to run on DJL for a general classification task. The model was trained with [Iris flower dataset](https://en.wikipedia.org/wiki/Iris_flower_data_set).\n", - "\n", - "## Background \n", - "\n", - "### Iris Dataset\n", - "\n", - "The dataset contains a set of 150 records under five attributes - sepal length, sepal width, petal length, petal width and species.\n", - "\n", - "Iris setosa | Iris versicolor | Iris virginica\n", - ":-------------------------:|:-------------------------:|:-------------------------:\n", - "![](https://upload.wikimedia.org/wikipedia/commons/5/56/Kosaciec_szczecinkowaty_Iris_setosa.jpg) | ![](https://upload.wikimedia.org/wikipedia/commons/4/41/Iris_versicolor_3.jpg) | ![](https://upload.wikimedia.org/wikipedia/commons/9/9f/Iris_virginica.jpg) \n", - "\n", - "The chart above shows three different kinds of the Iris flowers. \n", - "\n", - "We will use sepal length, sepal width, petal length, petal width as the feature and species as the label to train the model.\n", - "\n", - "### Sklearn Model\n", - "\n", - "You can find more information [here](http://onnx.ai/sklearn-onnx/). You can use the sklearn built-in iris dataset to load the data. Then we defined a [RandomForestClassifer](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html) to train the model. After that, we convert the model to onnx format for DJL to run inference. The following code is a sample classification setup using sklearn:\n", - "\n", - "```python\n", - "# Train a model.\n", - "from sklearn.datasets import load_iris\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.ensemble import RandomForestClassifier\n", - "iris = load_iris()\n", - "X, y = iris.data, iris.target\n", - "X_train, X_test, y_train, y_test = train_test_split(X, y)\n", - "clr = RandomForestClassifier()\n", - "clr.fit(X_train, y_train)\n", - "```\n", - "\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. To install the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md).\n", - "\n", - "These are dependencies we will use. To enhance the NDArray operation capability, we are importing ONNX Runtime and PyTorch Engine at the same time. Please find more information [here](https://github.com/deepjavalibrary/djl/blob/master/docs/hybrid_engine.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.onnxruntime:onnxruntime-engine:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "import java.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1 create a Translator\n", - "\n", - "Inference in machine learning is the process of predicting the output for a given input based on a pre-defined model.\n", - "DJL abstracts away the whole process for ease of use. It can load the model, perform inference on the input, and provide\n", - "output. DJL also allows you to provide user-defined inputs. The workflow looks like the following:\n", - "\n", - "![https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n", - "\n", - "The [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) interface encompasses the two white blocks: Pre-processing and Post-processing. The pre-processing\n", - "component converts the user-defined input objects into an NDList, so that the [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html) in DJL can understand the\n", - "input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the\n", - "`Predictor`. The post-processing block allows you to convert the output from the `Predictor` to the desired output\n", - "format.\n", - "\n", - "In our use case, we use a class namely `IrisFlower` as our input class type. We will use [`Classifications`](https://javadoc.io/doc/ai.djl/api/0.23.0/ai/djl/modality/Classifications.html) as our output class type." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "public static class IrisFlower {\n", - "\n", - "    public float sepalLength;\n", - "    public float sepalWidth;\n", - "    public float petalLength;\n", - "    public float petalWidth;\n", - "\n", - "    public IrisFlower(float sepalLength, float sepalWidth, float petalLength, float petalWidth) {\n", - "        this.sepalLength = sepalLength;\n", - "        this.sepalWidth = sepalWidth;\n", - "        this.petalLength = petalLength;\n", - "        this.petalWidth = petalWidth;\n", - "    }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's create a translator:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "public static class MyTranslator implements NoBatchifyTranslator<IrisFlower, Classifications> {\n", - "\n", - "    private final List<String> synset;\n", - "\n", - "    public MyTranslator() {\n", - "        // species name\n", - "        synset = Arrays.asList(\"setosa\", \"versicolor\", \"virginica\");\n", - "    }\n", - "\n", - "    @Override\n", - "    public NDList processInput(TranslatorContext ctx, IrisFlower input) {\n", - "        float[] data = {input.sepalLength, input.sepalWidth, input.petalLength, input.petalWidth};\n", - "        NDArray array = ctx.getNDManager().create(data, new Shape(1, 4));\n", - "        return new NDList(array);\n", - "    }\n", - "\n", - "    @Override\n", - "    public Classifications processOutput(TranslatorContext ctx, NDList list) {\n", - "        float[] data = list.get(1).toFloatArray();\n", - "        List<Double> probabilities = new ArrayList<>(data.length);\n", - "        for (float f : data) {\n", - "            probabilities.add((double) f);\n", - "        }\n", - "        return new Classifications(synset, probabilities);\n", - "    }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2 Prepare your model\n", - "\n", - "We will load a pretrained sklearn model into DJL. We defined a [`ModelZoo`](https://javadoc.io/doc/ai.djl/api/0.23.0/ai/djl/repository/zoo/ModelZoo.html) concept to allow users to load models from a variety of locations, such as a remote URL, local files, or the DJL pretrained model zoo. We need to define a [`Criteria`](https://javadoc.io/doc/ai.djl/api/0.23.0/ai/djl/repository/zoo/Criteria.html) class to help the model zoo locate the model and attach the translator. In this example, we download a compressed ONNX model from S3." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String modelUrl = \"https://mlrepo.djl.ai/model/tabular/softmax_regression/ai/djl/onnxruntime/iris_flowers/0.0.1/iris_flowers.zip\";\n", - "Criteria<IrisFlower, Classifications> criteria = Criteria.builder()\n", - "        .setTypes(IrisFlower.class, Classifications.class)\n", - "        .optModelUrls(modelUrl)\n", - "        .optTranslator(new MyTranslator())\n", - "        .optEngine(\"OnnxRuntime\") // use OnnxRuntime engine by default\n", - "        .build();\n", - "ZooModel<IrisFlower, Classifications> model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3 Run inference\n", - "\n", - "You just need to create a `Predictor` from the model to run inference."
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor<IrisFlower, Classifications> predictor = model.newPredictor();\n", - "IrisFlower info = new IrisFlower(1.0f, 2.0f, 3.0f, 4.0f);\n", - "predictor.predict(info);" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb b/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb deleted file mode 100644 index 1249ee12e2f..00000000000 --- a/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb +++ /dev/null @@ -1,369 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Face Mask Detection using PaddlePaddle\n", - "\n", - "In this tutorial, we will be using a pretrained PaddlePaddle model from [PaddleHub](https://github.com/PaddlePaddle/PaddleHub/tree/release/v1.5/demo/mask_detection/cpp) to do mask detection on a sample image. To complete this procedure, two steps need to be done:\n", - "\n", - "- Recognize faces in the image (whether wearing a mask or not) using a face object detection model\n", - "- Classify whether each face is wearing a mask or not\n", - "\n", - "These two steps will involve two Paddle models. We will implement the corresponding preprocessing and postprocessing logic for them.\n", - "\n", - "## Import dependencies and classes\n", - "\n", - "PaddlePaddle is one of the deep learning engines that requires DJL hybrid mode to run inference. It does not contain NDArray operations itself and needs a supplemental DL framework to help with that, so we also import the PyTorch engine here to do the processing work." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "\n", - "// second engine to do preprocessing and postprocessing\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.output.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.modality.cv.translator.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.Shape;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "\n", - "import java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Face Detection model\n", - "\n", - "Now we can start working on the first model. 
The model can do face detection, but it requires some additional processing before we feed an image into it:\n", - "\n", - "- Resize: shrink the image by a certain ratio before feeding it in\n", - "- Normalize the image with a scale\n", - "\n", - "Fortunately, DJL offers a [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) interface that can help you with this processing. The rough Translator architecture looks like the following:\n", - "\n", - "![](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n", - "\n", - "In the following sections, we will implement a `FaceTranslator` class to do the work.\n", - "\n", - "### Preprocessing\n", - "\n", - "In this stage, we will load an image and do some preprocessing work on it. Let's load the image first and take a look at it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String url = \"https://raw.githubusercontent.com/PaddlePaddle/PaddleHub/release/v1.5/demo/mask_detection/python/images/mask.jpg\";\n", - "Image img = ImageFactory.getInstance().fromUrl(url);\n", - "img.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then, let's try to apply some transformations to it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "NDList processImageInput(NDManager manager, Image input, float shrink) {\n", - "    NDArray array = input.toNDArray(manager);\n", - "    Shape shape = array.getShape();\n", - "    array = NDImageUtils.resize(\n", - "                    array, (int) (shape.get(1) * shrink), (int) (shape.get(0) * shrink));\n", - "    array = array.transpose(2, 0, 1).flip(0); // HWC -> CHW BGR -> RGB\n", - "    NDArray mean = manager.create(new float[] {104f, 117f, 123f}, new Shape(3, 1, 1));\n", - "    array = array.sub(mean).mul(0.007843f); // normalization\n", - "    array = array.expandDims(0); // make batch dimension\n", - "    return new NDList(array);\n", - "}\n", - "\n", - "processImageInput(NDManager.newBaseManager(), img, 0.5f);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As you can see above, we convert the image to an NDArray with the shape (number_of_batches, channel (RGB), height, width). This is the required input for the model to run object detection.\n", - "\n", - "### Postprocessing\n", - "\n", - "For postprocessing, the output is in the shape (number_of_boxes, (class_id, probability, xmin, ymin, xmax, ymax)). We can store it in the prebuilt DJL [`DetectedObjects`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/cv/output/DetectedObjects.html) class for further processing. Let's assume we have an inference output of ((1, 0.99, 0.2, 0.4, 0.5, 0.8)) and try to draw this box out."
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "DetectedObjects processImageOutput(NDList list, List<String> className, float threshold) {\n", - "    NDArray result = list.singletonOrThrow();\n", - "    float[] probabilities = result.get(\":,1\").toFloatArray();\n", - "    List<String> names = new ArrayList<>();\n", - "    List<Double> prob = new ArrayList<>();\n", - "    List<BoundingBox> boxes = new ArrayList<>();\n", - "    for (int i = 0; i < probabilities.length; i++) {\n", - "        if (probabilities[i] >= threshold) {\n", - "            float[] array = result.get(i).toFloatArray();\n", - "            names.add(className.get((int) array[0]));\n", - "            prob.add((double) probabilities[i]);\n", - "            boxes.add(\n", - "                    new Rectangle(\n", - "                            array[2], array[3], array[4] - array[2], array[5] - array[3]));\n", - "        }\n", - "    }\n", - "    return new DetectedObjects(names, prob, boxes);\n", - "}\n", - "\n", - "NDArray tempOutput = NDManager.newBaseManager().create(new float[]{1f, 0.99f, 0.1f, 0.1f, 0.2f, 0.2f}, new Shape(1, 6));\n", - "DetectedObjects testBox = processImageOutput(new NDList(tempOutput), Arrays.asList(\"Not Face\", \"Face\"), 0.7f);\n", - "Image newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(testBox);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Create Translator and run inference\n", - "\n", - "After this step, you should understand how preprocessing and postprocessing work in DJL. Now, let's do something real and put them together in a single piece:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class FaceTranslator implements NoBatchifyTranslator<Image, DetectedObjects> {\n", - "\n", - "    private float shrink;\n", - "    private float threshold;\n", - "    private List<String> className;\n", - "\n", - "    FaceTranslator(float shrink, float threshold) {\n", - "        this.shrink = shrink;\n", - "        this.threshold = threshold;\n", - "        className = Arrays.asList(\"Not Face\", \"Face\");\n", - "    }\n", - "\n", - "    @Override\n", - "    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {\n", - "        return processImageOutput(list, className, threshold);\n", - "    }\n", - "\n", - "    @Override\n", - "    public NDList processInput(TranslatorContext ctx, Image input) {\n", - "        return processImageInput(ctx.getNDManager(), input, shrink);\n", - "    }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To run inference with this model, we need to load the model from the Paddle model zoo. To load a model in DJL, you need to specify a [`Criteria`](https://javadoc.io/doc/ai.djl/api/0.23.1/ai/djl/repository/zoo/Criteria.html). `Criteria` is used to identify where to load the model from and which `Translator` should be applied to it. 
Then, all we need to do is get a [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html) from the model and use it to run inference:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria<Image, DetectedObjects> criteria = Criteria.builder()\n", - "        .setTypes(Image.class, DetectedObjects.class)\n", - "        .optModelUrls(\"djl://ai.djl.paddlepaddle/face_detection/0.0.1/mask_detection\")\n", - "        .optFilter(\"flavor\", \"server\")\n", - "        .optTranslator(new FaceTranslator(0.5f, 0.7f))\n", - "        .build();\n", - "   \n", - "var model = criteria.loadModel();\n", - "var predictor = model.newPredictor();\n", - "\n", - "DetectedObjects inferenceResult = predictor.predict(img);\n", - "newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(inferenceResult);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As you can see above, it brings you three face detections.\n", - "\n", - "## Mask Classification model\n", - "\n", - "\n", - "So, once we have the image locations ready, we can crop the image and feed it to the Mask Classification model for further processing.\n", - "\n", - "### Crop the image\n", - "\n", - "The output box locations are values from 0 to 1 that can be mapped to actual pixel locations if we simply multiply them by the width/height. For better accuracy on the cropped image, we extend the detection box to a square. Let's try to get a cropped image:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int[] extendSquare(\n", - "        double xmin, double ymin, double width, double height, double percentage) {\n", - "    double centerx = xmin + width / 2;\n", - "    double centery = ymin + height / 2;\n", - "    double maxDist = Math.max(width / 2, height / 2) * (1 + percentage);\n", - "    return new int[] {\n", - "        (int) (centerx - maxDist), (int) (centery - maxDist), (int) (2 * maxDist)\n", - "    };\n", - "}\n", - "\n", - "Image getSubImage(Image img, BoundingBox box) {\n", - "    Rectangle rect = box.getBounds();\n", - "    int width = img.getWidth();\n", - "    int height = img.getHeight();\n", - "    int[] squareBox =\n", - "            extendSquare(\n", - "                    rect.getX() * width,\n", - "                    rect.getY() * height,\n", - "                    rect.getWidth() * width,\n", - "                    rect.getHeight() * height,\n", - "                    0.18);\n", - "    return img.getSubImage(squareBox[0], squareBox[1], squareBox[2], squareBox[2]);\n", - "}\n", - "\n", - "List<DetectedObjects.DetectedObject> faces = inferenceResult.items();\n", - "getSubImage(img, faces.get(2).getBoundingBox()).getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Prepare Translator and load the model\n", - "\n", - "For the face classification model, we can use the DJL prebuilt [`ImageClassificationTranslator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/cv/translator/ImageClassificationTranslator.html) with a few transformations. This Translator provides a basic image translation process and can be extended with additional standard processing steps. So in our case, we don't have to create another `Translator`; we can just leverage this prebuilt one."
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria = Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelUrls(\"djl://ai.djl.paddlepaddle/mask_classification/0.0.1/mask_classification\")\n", - " .optFilter(\"flavor\", \"server\")\n", - " .optTranslator(\n", - " ImageClassificationTranslator.builder()\n", - " .addTransform(new Resize(128, 128))\n", - " .addTransform(new ToTensor()) // HWC -> CHW div(255)\n", - " .addTransform(\n", - " new Normalize(\n", - " new float[] {0.5f, 0.5f, 0.5f},\n", - " new float[] {1.0f, 1.0f, 1.0f}))\n", - " .addTransform(nd -> nd.flip(0)) // RGB -> GBR\n", - " .build())\n", - " .build();\n", - "\n", - "var classifyModel = criteria.loadModel();\n", - "var classifier = classifyModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Run inference\n", - "\n", - "So all we need to do is to apply the previous implemented functions and apply them all together. We firstly crop the image and then use it for inference. After these steps, we create a new DetectedObjects with new Classification classes:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "List names = new ArrayList<>();\n", - "List prob = new ArrayList<>();\n", - "List rect = new ArrayList<>();\n", - "for (DetectedObjects.DetectedObject face : faces) {\n", - " Image subImg = getSubImage(img, face.getBoundingBox());\n", - " Classifications classifications = classifier.predict(subImg);\n", - " names.add(classifications.best().getClassName());\n", - " prob.add(face.getProbability());\n", - " rect.add(face.getBoundingBox());\n", - "}\n", - "\n", - "newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n", - "newImage.getWrappedImage();" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb b/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb deleted file mode 100644 index 46c86461bdb..00000000000 --- a/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb +++ /dev/null @@ -1,352 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# į”¨éŖ›æ§ŗ+ DJL å¯ĻäŊœäēē臉åŖįŊŠčž¨č­˜\n", - "åœ¨é€™å€‹æ•™å­¸ä¸­æˆ‘å€‘å°‡æœƒåą•į¤ē刊į”¨ PaddleHub 下čŧ‰é č¨“įˇ´åĨŊįš„ PaddlePaddle æ¨Ąåž‹ä¸Ļ針對į¯„äž‹į…§į‰‡åšäēē臉åŖįŊŠčž¨č­˜ã€‚這個į¯„äž‹į¸Ŋå…ąæœƒåˆ†æˆå…Šå€‹æ­Ĩ驟:\n", - "\n", - "- į”¨č‡‰éƒ¨æĒĸæ¸Ŧæ¨Ąåž‹č­˜åˆĨ圖į‰‡ä¸­įš„äēē臉(į„ĄčĢ–是åĻ有戴åŖįŊŠ) \n", - "- įĸēčĒåœ–į‰‡ä¸­įš„č‡‰æ˜¯åĻ有戴åŖįŊŠ\n", - "\n", - "這兊個æ­Ĩ銟會包åĢäŊŋį”¨å…Šå€‹ Paddle æ¨Ąåž‹īŧŒæˆ‘們會在æŽĨ下來įš„內厚äģ‹į´šå…Šå€‹æ¨Ąåž‹å°æ‡‰éœ€čĻåšįš„å‰åžŒč™•į†é‚čŧ¯\n", - "\n", - "## 導å…Ĩį›¸é—œį’°åĸƒäžčŗ´åŠå­éĄžåˆĨ\n", - "在這個䞋子中įš„å‰č™•į†éŖ›æ§ŗæˇąåēĻå­¸įŋ’åŧ•æ“Žéœ€čĻæ­é… DJL æˇˇåˆæ¨Ąåŧé€˛čĄŒæˇąåēĻå­¸įŋ’推į†īŧŒåŽŸå› æ˜¯åŧ•æ“ŽæœŦčēĢæ˛’æœ‰åŒ…åĢ NDArray 操äŊœīŧŒå› æ­¤éœ€čĻč—‰į”¨å…ļäģ–åŧ•æ“Žįš„ NDArray 操äŊœčƒŊ力䞆厌成。這邊我們導å…Ĩ PyTorch 䞆做協同įš„å‰č™•į†åˇĨäŊœ:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo 
snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "\n", - "// second engine to do preprocessing and postprocessing\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.output.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.modality.cv.translator.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.Shape;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "\n", - "import java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 臉部åĩæ¸Ŧæ¨Ąåž‹\n", - "įžåœ¨æˆ‘們可äģĨé–‹å§‹č™•į†įŦŦä¸€å€‹æ¨Ąåž‹īŧŒåœ¨å°‡åœ–į‰‡čŧ¸å…Ĩ臉部æĒĸæ¸Ŧæ¨Ąåž‹å‰æˆ‘å€‘åŋ…須先做一äē›é č™•į†:\n", - "â€ĸ\tčĒŋ整圖į‰‡å°ē寸: äģĨį‰šåŽšæ¯”äž‹į¸Žå°åœ–į‰‡\n", - "â€ĸ\tį”¨ä¸€å€‹æ•¸å€ŧ對į¸Žå°åžŒåœ–į‰‡æ­ŖčĻåŒ–\n", - "對開į™ŧ者來čĒĒåĨŊæļˆæ¯æ˜¯īŧŒDJL 提䞛äē† Translator äģ‹éĸ來åšĢ劊開į™ŧ做這æ¨Ŗįš„預處į†. 一個比čŧƒį˛—į•Ĩįš„ Translator æžļ構åĻ‚下:\n", - "\n", - "![](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n", - "\n", - "在æŽĨ下來įš„æŽĩčŊīŧŒæˆ‘們會刊į”¨ä¸€å€‹ FaceTranslator å­éĄžåˆĨå¯ĻäŊœäž†åŽŒæˆåˇĨäŊœ\n", - "### 預處į†\n", - "在這個階æŽĩæˆ‘å€‘æœƒčŽ€å–ä¸€åŧĩ圖į‰‡ä¸Ļ且對å…ļ做一äē›äē‹å…ˆįš„預處į†īŧŒčŽ“我們先į¤ēį¯„čŽ€å–ä¸€åŧĩ圖į‰‡:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String url = \"https://raw.githubusercontent.com/PaddlePaddle/PaddleHub/release/v1.5/demo/mask_detection/python/images/mask.jpg\";\n", - "Image img = ImageFactory.getInstance().fromUrl(url);\n", - "img.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "æŽĨ著īŧŒčŽ“我們čŠĻč‘—å°åœ–į‰‡åšä¸€äē›é č™•į†įš„čŊ‰æ›:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "NDList processImageInput(NDManager manager, Image input, float shrink) {\n", - " NDArray array = input.toNDArray(manager);\n", - " Shape shape = array.getShape();\n", - " array = NDImageUtils.resize(\n", - " array, (int) (shape.get(1) * shrink), (int) (shape.get(0) * shrink));\n", - " array = array.transpose(2, 0, 1).flip(0); // HWC -> CHW BGR -> RGB\n", - " NDArray mean = manager.create(new float[] {104f, 117f, 123f}, new Shape(3, 1, 1));\n", - " array = array.sub(mean).mul(0.007843f); // normalization\n", - " array = array.expandDims(0); // make batch dimension\n", - " return new NDList(array);\n", - "}\n", - "\n", - "processImageInput(NDManager.newBaseManager(), img, 0.5f);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "åĻ‚上čŋ°æ‰€čĻ‹īŧŒæˆ‘å€‘åˇ˛įļ“把圖į‰‡čŊ‰æˆåĻ‚下å°ē寸įš„ NDArray: (æŠĢ量, 通道(RGB), éĢ˜åēĻ, å¯ŦåēĻ). 這是į‰ŠäģļæĒĸæ¸Ŧæ¨Ąåž‹čŧ¸å…Ĩįš„æ ŧåŧ\n", - "### åžŒč™•į†\n", - "į•ļæˆ‘å€‘åšåžŒč™•į†æ™‚, æ¨Ąåž‹čŧ¸å‡ēįš„æ ŧåŧæ˜¯ (number_of_boxes, (class_id, probability, xmin, ymin, xmax, ymax)). 我們可äģĨ將å…ļ存å…Ĩ預先åģēįĢ‹åĨŊįš„ DJL å­éĄžåˆĨ DetectedObjects äģĨäžŋ做垌įēŒæ“äŊœ. 
æˆ‘å€‘å‡č¨­æœ‰ä¸€įĩ„推čĢ–垌įš„čŧ¸å‡ē是 ((1, 0.99, 0.2, 0.4, 0.5, 0.8)) ä¸Ļ且čŠĻč‘—æŠŠäēēåƒæĄ†éĄ¯į¤ē在圖į‰‡ä¸Š" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "DetectedObjects processImageOutput(NDList list, List className, float threshold) {\n", - " NDArray result = list.singletonOrThrow();\n", - " float[] probabilities = result.get(\":,1\").toFloatArray();\n", - " List names = new ArrayList<>();\n", - " List prob = new ArrayList<>();\n", - " List boxes = new ArrayList<>();\n", - " for (int i = 0; i < probabilities.length; i++) {\n", - " if (probabilities[i] >= threshold) {\n", - " float[] array = result.get(i).toFloatArray();\n", - " names.add(className.get((int) array[0]));\n", - " prob.add((double) probabilities[i]);\n", - " boxes.add(\n", - " new Rectangle(\n", - " array[2], array[3], array[4] - array[2], array[5] - array[3]));\n", - " }\n", - " }\n", - " return new DetectedObjects(names, prob, boxes);\n", - "}\n", - "\n", - "NDArray tempOutput = NDManager.newBaseManager().create(new float[]{1f, 0.99f, 0.1f, 0.1f, 0.2f, 0.2f}, new Shape(1, 6));\n", - "DetectedObjects testBox = processImageOutput(new NDList(tempOutput), Arrays.asList(\"Not Face\", \"Face\"), 0.7f);\n", - "Image newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(testBox);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### į”Ÿæˆä¸€å€‹įŋģč­¯å™¨ä¸ĻåŸˇčĄŒæŽ¨į†äģģ務\n", - "透過這個æ­Ĩ驟īŧŒäŊ æœƒį†č§Ŗ DJL 中įš„å‰åžŒč™•į†åĻ‚äŊ•é‹äŊœīŧŒįžåœ¨čŽ“我們把前數įš„嚞個æ­ĨéŠŸä¸˛åœ¨ä¸€čĩˇä¸Ļ對įœŸå¯Ļ圖į‰‡é€˛čĄŒæ“äŊœ:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class FaceTranslator implements NoBatchifyTranslator {\n", - "\n", - " private float shrink;\n", - " private float threshold;\n", - " private List className;\n", - "\n", - " FaceTranslator(float shrink, float threshold) {\n", - " this.shrink = shrink;\n", - " this.threshold = threshold;\n", - " className = Arrays.asList(\"Not Face\", \"Face\");\n", - " }\n", - "\n", - " @Override\n", - " public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {\n", - " return processImageOutput(list, className, threshold);\n", - " }\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, Image input) {\n", - " return processImageInput(ctx.getNDManager(), input, shrink);\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "čĻåŸˇčĄŒé€™å€‹äēē臉æĒĸæ¸Ŧ推į†īŧŒæˆ‘們åŋ…須先垞 DJL įš„ Paddle Model Zoo čŽ€å–æ¨Ąåž‹īŧŒåœ¨čŽ€å–æ¨Ąåž‹äš‹å‰æˆ‘å€‘åŋ…須指厚åĨŊ `Crieteria` . `Crieteria` 是į”¨äž†įĸēčĒčĻåžžå“Ēé‚ŠčŽ€å–æ¨Ąåž‹č€ŒåžŒåŸˇčĄŒ `Translator` äž†é€˛čĄŒæ¨Ąåž‹å°Žå…Ĩ. 
æŽĨ著īŧŒæˆ‘們åĒčĻåˆŠį”¨ `Predictor` å°ąå¯äģĨé–‹å§‹é€˛čĄŒæŽ¨čĢ–" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria criteria = Criteria.builder()\n", - " .setTypes(Image.class, DetectedObjects.class)\n", - " .optModelUrls(\"djl://ai.djl.paddlepaddle/face_detection/0.0.1/mask_detection\")\n", - " .optFilter(\"flavor\", \"server\")\n", - " .optTranslator(new FaceTranslator(0.5f, 0.7f))\n", - " .build();\n", - " \n", - "var model = criteria.loadModel();\n", - "var predictor = model.newPredictor();\n", - "\n", - "DetectedObjects inferenceResult = predictor.predict(img);\n", - "newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(inferenceResult);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "åĻ‚圖į‰‡æ‰€į¤ēīŧŒé€™å€‹æŽ¨čĢ–æœå‹™åˇ˛įļ“可äģĨæ­Ŗįĸēįš„čž¨č­˜å‡ē圖į‰‡ä¸­įš„三åŧĩäēē臉\n", - "## åŖįŊŠåˆ†éĄžæ¨Ąåž‹\n", - "一æ—Ļ有äē†åœ–į‰‡įš„åē§æ¨™īŧŒæˆ‘å€‘å°ąå¯äģĨ將圖į‰‡čŖå‰Ē到遊į•ļ大小ä¸Ļ且將å…ļå‚ŗįĩĻåŖįŊŠåˆ†éĄžæ¨Ąåž‹åšåžŒįēŒįš„推čĢ–\n", - "### 圖į‰‡čŖå‰Ē\n", - "åœ–ä¸­æ–šæĄ†äŊįŊŽįš„數å€ŧį¯„圍垞0到1, åĒčĻå°‡é€™å€‹æ•¸å€ŧ䚘上圖į‰‡įš„镡å¯Ŧæˆ‘å€‘å°ąå¯äģĨå°‡æ–šæĄ†å°æ‡‰åˆ°åœ–į‰‡ä¸­įš„æē–įĸēäŊįŊŽ. į‚ēäē†äŊŋčŖå‰Ē垌įš„圖į‰‡æœ‰æ›´åĨŊįš„į˛žįĸēåēĻīŧŒæˆ‘們將圖į‰‡čŖå‰Ē成斚åŊĸīŧŒčŽ“我們į¤ēį¯„一下:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int[] extendSquare(\n", - " double xmin, double ymin, double width, double height, double percentage) {\n", - " double centerx = xmin + width / 2;\n", - " double centery = ymin + height / 2;\n", - " double maxDist = Math.max(width / 2, height / 2) * (1 + percentage);\n", - " return new int[] {\n", - " (int) (centerx - maxDist), (int) (centery - maxDist), (int) (2 * maxDist)\n", - " };\n", - "}\n", - "\n", - "Image getSubImage(Image img, BoundingBox box) {\n", - " Rectangle rect = box.getBounds();\n", - " int width = img.getWidth();\n", - " int height = img.getHeight();\n", - " int[] squareBox =\n", - " extendSquare(\n", - " rect.getX() * width,\n", - " rect.getY() * height,\n", - " rect.getWidth() * width,\n", - " rect.getHeight() * height,\n", - " 0.18);\n", - " return img.getSubImage(squareBox[0], squareBox[1], squareBox[2], squareBox[2]);\n", - "}\n", - "\n", - "List faces = inferenceResult.items();\n", - "getSubImage(img, faces.get(2).getBoundingBox()).getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### äē‹å…ˆæē–å‚™ Translator ä¸ĻčŽ€å–æ¨Ąåž‹\n", - "在äŊŋį”¨č‡‰éƒ¨æĒĸæ¸Ŧæ¨Ąåž‹įš„時候īŧŒæˆ‘們可äģĨ刊į”¨ DJL 預先åģēåĨŊįš„ `ImageClassificationTranslator` ä¸Ļ且加上一äē›čŊ‰æ›ã€‚這個 Translator 提䞛äē†ä¸€äē›åŸēį¤Žįš„圖į‰‡įŋģč­¯č™•į†ä¸Ļ且同時包åĢ一äē›é€˛éšŽįš„標æē–化圖į‰‡č™•į†ã€‚äģĨ這個䞋子䞆čĒĒ, 我們不需čĻéĄå¤–åģēįĢ‹æ–°įš„ `Translator` 而äŊŋį”¨é å…ˆåģēįĢ‹įš„å°ąå¯äģĨ" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria = Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelUrls(\"djl://ai.djl.paddlepaddle/mask_classification/0.0.1/mask_classification\")\n", - " .optFilter(\"flavor\", \"server\")\n", - " .optTranslator(\n", - " ImageClassificationTranslator.builder()\n", - " .addTransform(new Resize(128, 128))\n", - " .addTransform(new ToTensor()) // HWC -> CHW div(255)\n", - " .addTransform(\n", - " new Normalize(\n", - " new float[] {0.5f, 0.5f, 0.5f},\n", - " new float[] {1.0f, 1.0f, 1.0f}))\n", - " .addTransform(nd -> nd.flip(0)) // RGB -> GBR\n", - " 
.build())\n", - " .build();\n", - "\n", - "var classifyModel = criteria.loadModel();\n", - "var classifier = classifyModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### åŸˇčĄŒæŽ¨čĢ–äģģ務\n", - "最垌īŧŒčĻåŽŒæˆä¸€å€‹åŖįŊŠč­˜åˆĨįš„äģģ務īŧŒæˆ‘們åĒ需čĻå°‡ä¸Ščŋ°įš„æ­Ĩ銟合在一čĩˇåŗ可。我們先將圖į‰‡åščŖå‰Ē垌ä¸Ļ對å…ļ做上čŋ°įš„推čĢ–操äŊœīŧŒįĩæŸäš‹åžŒå†į”Ÿæˆä¸€å€‹æ–°įš„åˆ†éĄžå­éĄžåˆĨ `DetectedObjects`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "List names = new ArrayList<>();\n", - "List prob = new ArrayList<>();\n", - "List rect = new ArrayList<>();\n", - "for (DetectedObjects.DetectedObject face : faces) {\n", - " Image subImg = getSubImage(img, face.getBoundingBox());\n", - " Classifications classifications = classifier.predict(subImg);\n", - " names.add(classifications.best().getClassName());\n", - " prob.add(face.getProbability());\n", - " rect.add(face.getBoundingBox());\n", - "}\n", - "\n", - "newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n", - "newImage.getWrappedImage();" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/paddlepaddle/paddle_ocr_java.ipynb b/jupyter/paddlepaddle/paddle_ocr_java.ipynb deleted file mode 100644 index da8527020ab..00000000000 --- a/jupyter/paddlepaddle/paddle_ocr_java.ipynb +++ /dev/null @@ -1,313 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# PaddleOCR DJL example\n", - "\n", - "In this tutorial, we will be using pretrained PaddlePaddle model from [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) to do Optical character recognition (OCR) from the given image. There are three models involved in this tutorial:\n", - "\n", - "- Word detection model: used to detect the word block from the image\n", - "- Word direction model: used to find if the text needs to rotate\n", - "- Word recognition model: Used to recognize test from the word block\n", - "\n", - "## Import dependencies and classes\n", - "\n", - "PaddlePaddle is one of the Deep Engines that requires DJL hybrid mode to run inference. Itself does not contains NDArray operations and needs a supplemental DL framework to help with that. So we import Pytorch DL engine as well in here to do the processing works." 
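As a quick sanity check of the hybrid setup, the following small sketch (assuming both engine dependencies from the next cell are on the classpath) lists the engines DJL can see:

```
import ai.djl.engine.Engine;

System.out.println("Default engine:    " + Engine.getInstance().getEngineName());
System.out.println("Available engines: " + Engine.getAllEngines());
// If needed, the default engine can be forced with the JVM flag -Dai.djl.default_engine=PaddlePaddle.
```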
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "\n", - "// second engine to do preprocessing and postprocessing\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.inference.Predictor;\n", - "import ai.djl.modality.Classifications;\n", - "import ai.djl.modality.cv.Image;\n", - "import ai.djl.modality.cv.ImageFactory;\n", - "import ai.djl.modality.cv.output.*;\n", - "import ai.djl.modality.cv.util.NDImageUtils;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.DataType;\n", - "import ai.djl.ndarray.types.Shape;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.paddlepaddle.zoo.cv.objectdetection.PpWordDetectionTranslator;\n", - "import ai.djl.paddlepaddle.zoo.cv.imageclassification.PpWordRotateTranslator;\n", - "import ai.djl.paddlepaddle.zoo.cv.wordrecognition.PpWordRecognitionTranslator;\n", - "import ai.djl.translate.*;\n", - "import java.util.concurrent.ConcurrentHashMap;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## the Image\n", - "Firstly, let's take a look at our sample image, a flight ticket:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String url = \"https://resources.djl.ai/images/flight_ticket.jpg\";\n", - "Image img = ImageFactory.getInstance().fromUrl(url);\n", - "img.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Word detection model\n", - "\n", - "In our word detection model, we load the model exported from [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-detection-model-to-inference-model). After that, we can spawn a DJL Predictor from it called detector." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria1 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, DetectedObjects.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/det_db.zip\")\n", - " .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap()))\n", - " .build();\n", - "var detectionModel = criteria1.loadModel();\n", - "var detector = detectionModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then, we can detect the word block from it. The original output from the model is a bitmap that marked all word regions. The `PpWordDetectionTranslator` convert the output bitmap into a rectangle bounded box for us to crop the image." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var detectedObj = detector.predict(img);\n", - "Image newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(detectedObj);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As you can see above, the word block are very narrow and does not include the whole body of all words. Let's try to extend it a bit for a better result. 
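Before extending the boxes, it can help to print the raw detections to see how tight they are. A small sketch, assuming the `detectedObj` result from the cell above:

```
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import java.util.List;

List<DetectedObjects.DetectedObject> rawBoxes = detectedObj.items();
for (DetectedObjects.DetectedObject obj : rawBoxes) {
    Rectangle r = obj.getBoundingBox().getBounds();
    // coordinates are fractions of the image size, so a small height means a very thin strip
    System.out.printf("x=%.3f y=%.3f w=%.3f h=%.3f%n",
            r.getX(), r.getY(), r.getWidth(), r.getHeight());
}
```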
`extendRect` extend the box height and width to a certain scale. `getSubImage` will crop the image and extract the word block." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Image getSubImage(Image img, BoundingBox box) {\n", - " Rectangle rect = box.getBounds();\n", - " double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());\n", - " int width = img.getWidth();\n", - " int height = img.getHeight();\n", - " int[] recovered = {\n", - " (int) (extended[0] * width),\n", - " (int) (extended[1] * height),\n", - " (int) (extended[2] * width),\n", - " (int) (extended[3] * height)\n", - " };\n", - " return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]);\n", - "}\n", - "\n", - "double[] extendRect(double xmin, double ymin, double width, double height) {\n", - " double centerx = xmin + width / 2;\n", - " double centery = ymin + height / 2;\n", - " if (width > height) {\n", - " width += height * 2.0;\n", - " height *= 3.0;\n", - " } else {\n", - " height += width * 2.0;\n", - " width *= 3.0;\n", - " }\n", - " double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;\n", - " double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;\n", - " double newWidth = newX + width > 1 ? 1 - newX : width;\n", - " double newHeight = newY + height > 1 ? 1 - newY : height;\n", - " return new double[] {newX, newY, newWidth, newHeight};\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's try to extract one block out:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "List boxes = detectedObj.items();\n", - "var sample = getSubImage(img, boxes.get(5).getBoundingBox());\n", - "sample.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Word Direction model\n", - "\n", - "This model is exported from [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-angle-classification-model-to-inference-model) that can help to identify if the image is required to rotate. The following code will load this model and create a rotateClassifier." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria2 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/cls.zip\")\n", - " .optTranslator(new PpWordRotateTranslator())\n", - " .build();\n", - "var rotateModel = criteria2.loadModel();\n", - "var rotateClassifier = rotateModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Word Recgonition model\n", - "\n", - "The word recognition model is exported from [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-recognition-model-to-inference-model) that can recognize the text on the image. 
Let's load this model as well.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria3 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, String.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/rec_crnn.zip\")\n", - " .optTranslator(new PpWordRecognitionTranslator())\n", - " .build();\n", - "var recognitionModel = criteria3.loadModel();\n", - "var recognizer = recognitionModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then we can try to play with these two models on the previous cropped image:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "System.out.println(rotateClassifier.predict(sample));\n", - "recognizer.predict(sample);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, let's run these models on the whole image and see the outcome. DJL offers a rich image toolkit that allows you to draw the text on image and display them." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Image rotateImg(Image image) {\n", - " try (NDManager manager = NDManager.newBaseManager()) {\n", - " NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);\n", - " return ImageFactory.getInstance().fromNDArray(rotated);\n", - " }\n", - "}\n", - "\n", - "List names = new ArrayList<>();\n", - "List prob = new ArrayList<>();\n", - "List rect = new ArrayList<>();\n", - "\n", - "for (int i = 0; i < boxes.size(); i++) {\n", - " Image subImg = getSubImage(img, boxes.get(i).getBoundingBox());\n", - " if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {\n", - " subImg = rotateImg(subImg);\n", - " }\n", - " Classifications.Classification result = rotateClassifier.predict(subImg).best();\n", - " if (\"Rotate\".equals(result.getClassName()) && result.getProbability() > 0.8) {\n", - " subImg = rotateImg(subImg);\n", - " }\n", - " String name = recognizer.predict(subImg);\n", - " names.add(name);\n", - " prob.add(-1.0);\n", - " rect.add(boxes.get(i).getBoundingBox());\n", - "}\n", - "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n", - "newImage.getWrappedImage();" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb b/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb deleted file mode 100644 index 2419baf89c7..00000000000 --- a/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb +++ /dev/null @@ -1,309 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# PaddleOCR在DJL 上įš„å¯Ļįž\n", - "在這個教į¨‹čŖĄīŧŒæˆ‘å€‘æœƒåą•į¤ē刊į”¨ PaddleOCR 下čŧ‰é č¨“įˇ´åĨŊæ–‡å­—č™•į†æ¨Ąåž‹ä¸Ļ對指厚įš„į…§į‰‡é€˛čĄŒæ–‡å­¸æ–‡å­—æĒĸæ¸Ŧ (OCR)。這個教į¨‹į¸Ŋå…ąæœƒåˆ†æˆä¸‰å€‹éƒ¨åˆ†:\n", - "\n", - "- æ–‡å­—å€åĄŠæĒĸæ¸Ŧ: 垞圖į‰‡æĒĸæ¸Ŧå‡ēæ–‡å­—å€åĄŠ\n", - "- æ–‡å­—č§’åēĻæĒĸæ¸Ŧ: įĸēčĒæ–‡å­—是åĻ需čĻæ—‹čŊ‰\n", - "- æ–‡å­—č­˜åˆĨ: įĸēčĒå€åĄŠå…§įš„文字\n", - "\n", - "## 導å…Ĩį›¸é—œį’°åĸƒäžčŗ´åŠå­éĄžåˆĨ\n", - 
"在這個䞋子中įš„å‰č™•į†éŖ›æ§ŗæˇąåēĻå­¸įŋ’åŧ•æ“Žéœ€čĻæ­é…DJLæˇˇåˆæ¨Ąåŧé€˛čĄŒæˇąåēĻå­¸įŋ’推į†īŧŒåŽŸå› æ˜¯åŧ•æ“ŽæœŦčēĢæ˛’æœ‰åŒ…åĢND數įĩ„操äŊœīŧŒå› æ­¤éœ€čĻč—‰į”¨å…ļäģ–åŧ•æ“Žįš„數įĩ„操äŊœčƒŊ力䞆厌成。這邊我們導å…ĨPytorch䞆做協同įš„å‰č™•į†åˇĨäŊœ:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "\n", - "// second engine to do preprocessing and postprocessing\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.inference.Predictor;\n", - "import ai.djl.modality.Classifications;\n", - "import ai.djl.modality.cv.Image;\n", - "import ai.djl.modality.cv.ImageFactory;\n", - "import ai.djl.modality.cv.output.*;\n", - "import ai.djl.modality.cv.util.NDImageUtils;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.DataType;\n", - "import ai.djl.ndarray.types.Shape;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.paddlepaddle.zoo.cv.objectdetection.PpWordDetectionTranslator;\n", - "import ai.djl.paddlepaddle.zoo.cv.imageclassification.PpWordRotateTranslator;\n", - "import ai.djl.paddlepaddle.zoo.cv.wordrecognition.PpWordRecognitionTranslator;\n", - "import ai.djl.translate.*;\n", - "import java.util.concurrent.ConcurrentHashMap;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 圖į‰‡čŽ€å–\n", - "éĻ–å…ˆčŽ“æˆ‘å€‘čŧ‰å…Ĩ這æŦĄæ•™į¨‹æœƒį”¨åˆ°įš„抟įĨ¨į¯„䞋圖į‰‡:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String url = \"https://resources.djl.ai/images/flight_ticket.jpg\";\n", - "Image img = ImageFactory.getInstance().fromUrl(url);\n", - "img.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## æ–‡å­—å€åĄŠæĒĸæ¸Ŧ\n", - "我們éĻ–先垞 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-detection-model-to-inference-model) 開į™ŧåĨ—äģļä¸­čŽ€å–æ–‡å­—æĒĸæ¸Ŧįš„æ¨Ąåž‹īŧŒäš‹åžŒæˆ‘們可äģĨį”Ÿæˆä¸€å€‹DJL `Predictor` ä¸Ļ將å…ļå‘Ŋ名į‚ē `detector`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria1 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, DetectedObjects.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/det_db.zip\")\n", - " .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap()))\n", - " .build();\n", - "var detectionModel = criteria1.loadModel();\n", - "var detector = detectionModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "æŽĨč‘—æˆ‘å€‘æĒĸæ¸Ŧå‡ē圖į‰‡ä¸­įš„æ–‡å­—å€åĄŠīŧŒé€™å€‹æ¨Ąåž‹įš„原始čŧ¸å‡ē是åĢ有標č¨ģ所有文字區域įš„圖įŽ—æŗ•(Bitmap)īŧŒæˆ‘們可äģĨ刊į”¨`PpWordDetectionTranslator` å‡Ŋåŧå°‡åœ–įŽ—æŗ•įš„čŧ¸å‡ēčŊ‰æˆé•ˇæ–šåŊĸįš„æ–šæĄ†äž†čŖå‰Ē圖į‰‡" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var detectedObj = detector.predict(img);\n", - "Image newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(detectedObj);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": 
{}, - "source": [ - "åĻ‚上所į¤ēīŧŒæ‰€æ¨™č¨ģįš„æ–‡å­—å€åĄŠéƒŊ非常įĒ„īŧŒä¸”æ˛’æœ‰åŒ…äŊæ‰€æœ‰åŽŒæ•´įš„æ–‡å­—å€åĄŠã€‚čŽ“æˆ‘å€‘å˜—čŠĻäŊŋį”¨`extendRect`å‡Ŋåŧäž†æ“´åą•æ–‡å­—æĄ†įš„镡å¯Ŧ到需čĻįš„大小, 再刊į”¨ `getSubImage` čŖå‰Ēä¸Ļæ“ˇå–å‡ēæ–‡å­å€åĄŠã€‚" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Image getSubImage(Image img, BoundingBox box) {\n", - " Rectangle rect = box.getBounds();\n", - " double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());\n", - " int width = img.getWidth();\n", - " int height = img.getHeight();\n", - " int[] recovered = {\n", - " (int) (extended[0] * width),\n", - " (int) (extended[1] * height),\n", - " (int) (extended[2] * width),\n", - " (int) (extended[3] * height)\n", - " };\n", - " return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]);\n", - "}\n", - "\n", - "double[] extendRect(double xmin, double ymin, double width, double height) {\n", - " double centerx = xmin + width / 2;\n", - " double centery = ymin + height / 2;\n", - " if (width > height) {\n", - " width += height * 2.0;\n", - " height *= 3.0;\n", - " } else {\n", - " height += width * 2.0;\n", - " width *= 3.0;\n", - " }\n", - " double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;\n", - " double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;\n", - " double newWidth = newX + width > 1 ? 1 - newX : width;\n", - " double newHeight = newY + height > 1 ? 1 - newY : height;\n", - " return new double[] {newX, newY, newWidth, newHeight};\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "čŽ“æˆ‘å€‘čŧ¸å‡ēå…ļä¸­ä¸€å€‹æ–‡å­—å€åĄŠ" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "List boxes = detectedObj.items();\n", - "var sample = getSubImage(img, boxes.get(5).getBoundingBox());\n", - "sample.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## æ–‡å­—č§’åēĻæĒĸæ¸Ŧ\n", - "我們垞 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-angle-classification-model-to-inference-model) čŧ¸å‡ēé€™å€‹æ¨Ąåž‹ä¸ĻįĸēčĒåœ–į‰‡åŠæ–‡å­—是åĻ需čĻæ—‹čŊ‰ã€‚äģĨ下įš„äģŖįĸŧæœƒčŽ€å…Ĩé€™å€‹æ¨Ąåž‹ä¸Ļį”Ÿæˆa `rotateClassifier` å­éĄžåˆĨ" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria2 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/cls.zip\")\n", - " .optTranslator(new PpWordRotateTranslator())\n", - " .build();\n", - "var rotateModel = criteria2.loadModel();\n", - "var rotateClassifier = rotateModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## æ–‡å­—č­˜åˆĨ\n", - "\n", - "我們垞 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-recognition-model-to-inference-model) čŧ¸å‡ēé€™å€‹æ¨Ąåž‹ä¸Ļ識åˆĨ圖į‰‡ä¸­įš„文字, 我們一æ¨Ŗäģŋ造上čŋ°įš„æ­ĨéŠŸčŽ€å–é€™å€‹æ¨Ąåž‹\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria3 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, String.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/rec_crnn.zip\")\n", - " 
.optTranslator(new PpWordRecognitionTranslator())\n", - " .build();\n", - "var recognitionModel = criteria3.loadModel();\n", - "var recognizer = recognitionModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "æŽĨč‘—æˆ‘å€‘å¯äģĨčŠĻ著åĨ—į”¨é€™å…Šå€‹æ¨Ąåž‹åœ¨å…ˆå‰å‰ĒčŖåĨŊįš„æ–‡å­—å€åĄŠä¸Š" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "System.out.println(rotateClassifier.predict(sample));\n", - "recognizer.predict(sample);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "最垌我們把這äē›æ¨Ąåž‹ä¸˛é€Ŗ在一čĩˇä¸ĻåĨ—į”¨åœ¨æ•´åŧĩ圖į‰‡ä¸Šįœ‹įœ‹įĩæžœæœƒåĻ‚äŊ•ã€‚DJL提䞛äē†čąå¯Œįš„åŊąåƒåˇĨå…ˇåŒ…čŽ“äŊ å¯äģĨ垞圖į‰‡ä¸­æ“ˇå–å‡ē文字ä¸Ļ且厌įžŽå‘ˆįž" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Image rotateImg(Image image) {\n", - " try (NDManager manager = NDManager.newBaseManager()) {\n", - " NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);\n", - " return ImageFactory.getInstance().fromNDArray(rotated);\n", - " }\n", - "}\n", - "\n", - "List names = new ArrayList<>();\n", - "List prob = new ArrayList<>();\n", - "List rect = new ArrayList<>();\n", - "\n", - "for (int i = 0; i < boxes.size(); i++) {\n", - " Image subImg = getSubImage(img, boxes.get(i).getBoundingBox());\n", - " if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {\n", - " subImg = rotateImg(subImg);\n", - " }\n", - " Classifications.Classification result = rotateClassifier.predict(subImg).best();\n", - " if (\"Rotate\".equals(result.getClassName()) && result.getProbability() > 0.8) {\n", - " subImg = rotateImg(subImg);\n", - " }\n", - " String name = recognizer.predict(subImg);\n", - " names.add(name);\n", - " prob.add(-1.0);\n", - " rect.add(boxes.get(i).getBoundingBox());\n", - "}\n", - "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n", - "newImage.getWrappedImage();" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/pytorch/load_your_own_pytorch_bert.ipynb b/jupyter/pytorch/load_your_own_pytorch_bert.ipynb deleted file mode 100644 index 3c52ee599b0..00000000000 --- a/jupyter/pytorch/load_your_own_pytorch_bert.ipynb +++ /dev/null @@ -1,441 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load your own PyTorch BERT model\n", - "\n", - "In the previous [example](https://github.com/deepjavalibrary/djl/blob/master/jupyter/BERTQA.ipynb), you run BERT inference with the model from Model Zoo. You can also load the model on your own pre-trained BERT and use custom classes as the input and output.\n", - "\n", - "In general, the PyTorch BERT model from [HuggingFace](https://github.com/huggingface/transformers) requires these three inputs:\n", - "\n", - "- word indices: The index of each word in a sentence\n", - "- word types: The type index of the word.\n", - "- attention mask: The mask indicates to the model which tokens should be attended to, and which should not after batching sequence together.\n", - "\n", - "We will dive deep into these details later." 
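For orientation, here is a toy sketch of what those three inputs look like once turned into `NDArray`s; the index values below are made up, and only the real tokenizer and vocabulary introduced later produce usable ids:

```
import ai.djl.ndarray.*;

try (NDManager manager = NDManager.newBaseManager()) {
    // hypothetical ids only; real values come from the tokenizer and vocabulary below
    NDArray indices       = manager.create(new long[] {101, 2043, 2106, 102, 4035, 102, 0, 0});
    NDArray tokenTypes    = manager.create(new long[] {0, 0, 0, 0, 1, 1, 0, 0}); // 0 = question, 1 = paragraph
    NDArray attentionMask = manager.create(new long[] {1, 1, 1, 1, 1, 1, 0, 0}); // 1 = real token, 0 = padding
    System.out.println(new NDList(indices, tokenTypes, attentionMask));
}
```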
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. To install the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "There are dependencies we will use." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0\n", - "%maven ai.djl.pytorch:pytorch-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import java packages" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;\n", - "import java.util.stream.*;\n", - "\n", - "import ai.djl.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.modality.nlp.*;\n", - "import ai.djl.modality.nlp.qa.*;\n", - "import ai.djl.modality.nlp.bert.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Reuse the previous input**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var question = \"When did BBC Japan start broadcasting?\";\n", - "var resourceDocument = \"BBC Japan was a general entertainment Channel.\\n\" +\n", - " \"Which operated between December 2004 and April 2006.\\n\" +\n", - " \"It ceased operations after its Japanese distributor folded.\";\n", - "\n", - "QAInput input = new QAInput(question, resourceDocument);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dive deep into Translator\n", - "\n", - "Inference in deep learning is the process of predicting the output for a given input based on a pre-defined model.\n", - "DJL abstracts away the whole process for ease of use. It can load the model, perform inference on the input, and provide\n", - "output. DJL also allows you to provide user-defined inputs. The workflow looks like the following:\n", - "\n", - "![https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n", - "\n", - "The red block (\"Images\") in the workflow is the input that DJL expects from you. The green block (\"Images\n", - "bounding box\") is the output that you expect. Because DJL does not know which input to expect and which output format that you prefer, DJL provides the [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) interface so you can define your own\n", - "input and output.\n", - "\n", - "The `Translator` interface encompasses the two white blocks: Pre-processing and Post-processing. The pre-processing\n", - "component converts the user-defined input objects into an NDList, so that the [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html) in DJL can understand the\n", - "input and make its prediction. 
Similarly, the post-processing block receives an NDList as the output from the\n", - "`Predictor`. The post-processing block allows you to convert the output from the `Predictor` to the desired output\n", - "format." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Pre-processing\n", - "\n", - "Now, you need to convert the sentences into tokens. We provide a powerful tool [`BertTokenizer`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/bert/BertTokenizer.html) that you can use to convert questions and answers into tokens, and batchify your sequence together. Once you have properly formatted tokens, you can use [`Vocabulary`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/Vocabulary.html) to map your token to BERT index.\n", - "\n", - "The following code block demonstrates tokenizing the question and answer defined earlier into BERT-formatted tokens." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var tokenizer = new BertTokenizer();\n", - "List tokenQ = tokenizer.tokenize(question.toLowerCase());\n", - "List tokenA = tokenizer.tokenize(resourceDocument.toLowerCase());\n", - "\n", - "System.out.println(\"Question Token: \" + tokenQ);\n", - "System.out.println(\"Answer Token: \" + tokenA);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`BertTokenizer` can also help you batchify questions and resource documents together by calling `encode()`.\n", - "The output contains information that BERT ingests.\n", - "\n", - "- getTokens: It returns a list of strings including the question, resource document and special word to let the model tell which part is the question and which part is the resource document. Because PyTorch BERT was trained with varioue sequence length, you don't pad the tokens.\n", - "- getTokenTypes: It returns a list of type indices of the word to indicate the location of the resource document. All Questions will be labelled with 0 and all resource documents will be labelled with 1.\n", - "\n", - " [Question tokens...DocResourceTokens...padding tokens] => [000000...11111....0000]\n", - " \n", - "\n", - "- getValidLength: It returns the actual length of the question and resource document tokens tokens, which are required by MXNet BERT.\n", - "- getAttentionMask: It returns the mask for the model to indicate which part should be paid attention to and which part is the padding. It is required by PyTorch BERT.\n", - "\n", - " [Question tokens...DocResourceTokens...padding tokens] => [111111...11111....0000]\n", - " \n", - "PyTorch BERT was trained with varioue sequence length, so we don't need to pad the tokens." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BertToken token = tokenizer.encode(question.toLowerCase(), resourceDocument.toLowerCase());\n", - "System.out.println(\"Encoded tokens: \" + token.getTokens());\n", - "System.out.println(\"Encoded token type: \" + token.getTokenTypes());\n", - "System.out.println(\"Valid length: \" + token.getValidLength());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Normally, words and sentences are represented as indices instead of tokens for training. \n", - "They typically work like a vector in a n-dimensional space. 
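The next cell prints the tokens, token types, and valid length; the attention mask described above can be inspected the same way. A small sketch, assuming the `tokenizer`, `question`, and `resourceDocument` defined earlier:

```
BertToken token = tokenizer.encode(question.toLowerCase(), resourceDocument.toLowerCase());
System.out.println("Attention mask: " + token.getAttentionMask());
```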
In this case, you need to map them into indices.\n", - "DJL provides `Vocabulary` to take care of you vocabulary mapping.\n", - "\n", - "The bert vocab from Huggingface is of the following format.\n", - "```\n", - "[PAD]\n", - "[unused0]\n", - "[unused1]\n", - "[unused2]\n", - "[unused3]\n", - "[unused4]\n", - "[unused5]\n", - "[unused6]\n", - "[unused7]\n", - "[unused8]\n", - "...\n", - "```\n", - "We provide the `bert-base-uncased-vocab.txt` from our pre-trained BERT for demonstration." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/pytorch/bertqa/0.0.1/bert-base-uncased-vocab.txt.gz\", \"build/pytorch/bertqa/vocab.txt\", new ProgressBar());" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var path = Paths.get(\"build/pytorch/bertqa/vocab.txt\");\n", - "var vocabulary = DefaultVocabulary.builder()\n", - " .optMinFrequency(1)\n", - " .addFromTextFile(path)\n", - " .optUnknownToken(\"[UNK]\")\n", - " .build();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can easily convert the token to the index using `vocabulary.getIndex(token)` and the other way around using `vocabulary.getToken(index)`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "long index = vocabulary.getIndex(\"car\");\n", - "String token = vocabulary.getToken(2482);\n", - "System.out.println(\"The index of the car is \" + index);\n", - "System.out.println(\"The token of the index 2482 is \" + token);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To properly convert them into `float[]` for `NDArray` creation, here is the helper function:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now that you have everything you need, you can create an NDList and populate all of the inputs you formatted earlier. You're done with pre-processing! \n", - "\n", - "#### Construct `Translator`\n", - "\n", - "You need to do this processing within an implementation of the `Translator` interface. `Translator` is designed to do pre-processing and post-processing. You must define the input and output objects. It contains the following two override classes:\n", - "- `public NDList processInput(TranslatorContext ctx, I)`\n", - "- `public String processOutput(TranslatorContext ctx, O)`\n", - "\n", - "Every translator takes in input and returns output in the form of generic objects. In this case, the translator takes input in the form of `QAInput` (I) and returns output as a `String` (O). `QAInput` is just an object that holds questions and answer; We have prepared the Input class for you." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Armed with the needed knowledge, you can write an implementation of the `Translator` interface. `BertTranslator` uses the code snippets explained previously to implement the `processInput`method. 
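Picking up the helper function mentioned above for turning the token lists into `float[]` before `NDArray` creation, here is a minimal sketch; the original cell is not shown, so the exact shape of this helper is an assumption:

```
float[] toFloatArray(List<? extends Number> list) {
    float[] ret = new float[list.size()];
    int idx = 0;
    for (Number n : list) {
        ret[idx++] = n.floatValue();
    }
    return ret;
}

float[] tokenTypeIds = toFloatArray(token.getTokenTypes()); // example: token types from encode()
```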
For more information, see [`NDManager`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/ndarray/NDManager.html).\n", - "\n", - "```\n", - "manager.create(Number[] data, Shape)\n", - "manager.create(Number[] data)\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "public class BertTranslator implements Translator {\n", - " private List tokens;\n", - " private Vocabulary vocabulary;\n", - " private BertTokenizer tokenizer;\n", - " \n", - " @Override\n", - " public void prepare(TranslatorContext ctx) throws IOException {\n", - " Path path = Paths.get(\"build/pytorch/bertqa/vocab.txt\");\n", - " vocabulary = DefaultVocabulary.builder()\n", - " .optMinFrequency(1)\n", - " .addFromTextFile(path)\n", - " .optUnknownToken(\"[UNK]\")\n", - " .build();\n", - " tokenizer = new BertTokenizer();\n", - " }\n", - " \n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, QAInput input) {\n", - " BertToken token =\n", - " tokenizer.encode(\n", - " input.getQuestion().toLowerCase(),\n", - " input.getParagraph().toLowerCase());\n", - " // get the encoded tokens that would be used in precessOutput\n", - " tokens = token.getTokens();\n", - " NDManager manager = ctx.getNDManager();\n", - " // map the tokens(String) to indices(long)\n", - " long[] indices = tokens.stream().mapToLong(vocabulary::getIndex).toArray();\n", - " long[] attentionMask = token.getAttentionMask().stream().mapToLong(i -> i).toArray();\n", - " long[] tokenType = token.getTokenTypes().stream().mapToLong(i -> i).toArray();\n", - " NDArray indicesArray = manager.create(indices);\n", - " NDArray attentionMaskArray =\n", - " manager.create(attentionMask);\n", - " NDArray tokenTypeArray = manager.create(tokenType);\n", - " // The order matters\n", - " return new NDList(indicesArray, attentionMaskArray, tokenTypeArray);\n", - " }\n", - " \n", - " @Override\n", - " public String processOutput(TranslatorContext ctx, NDList list) {\n", - " NDArray startLogits = list.get(0);\n", - " NDArray endLogits = list.get(1);\n", - " int startIdx = (int) startLogits.argMax().getLong();\n", - " int endIdx = (int) endLogits.argMax().getLong();\n", - " return tokens.subList(startIdx, endIdx + 1).toString();\n", - " }\n", - " \n", - " @Override\n", - " public Batchifier getBatchifier() {\n", - " return Batchifier.STACK;\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Congrats! You have created your first Translator! We have pre-filled the `processOutput()` function to process the `NDList` and return it in a desired format. `processInput()` and `processOutput()` offer the flexibility to get the predictions from the model in any format you desire. \n", - "\n", - "With the Translator implemented, you need to bring up the predictor that uses your `Translator` to start making predictions. You can find the usage for `Predictor` in the [Predictor Javadoc](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html). Create a translator and use the `question` and `resourceDocument` provided previously." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/pytorch/bertqa/0.0.1/trace_bertqa.pt.gz\", \"build/pytorch/bertqa/bertqa.pt\", new ProgressBar());" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BertTranslator translator = new BertTranslator();\n", - "\n", - "Criteria criteria = Criteria.builder()\n", - " .setTypes(QAInput.class, String.class)\n", - " .optModelPath(Paths.get(\"build/pytorch/bertqa/\")) // search in local folder\n", - " .optTranslator(translator)\n", - " .optProgress(new ProgressBar()).build();\n", - "\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String predictResult = null;\n", - "QAInput input = new QAInput(question, resourceDocument);\n", - "\n", - "// Create a Predictor and use it to predict the output\n", - "try (Predictor predictor = model.newPredictor(translator)) {\n", - " predictResult = predictor.predict(input);\n", - "}\n", - "\n", - "System.out.println(question);\n", - "System.out.println(predictResult);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Based on the input, the following result will be shown:\n", - "```\n", - "[december, 2004]\n", - "```\n", - "That's it! \n", - "\n", - "You can try with more questions and answers. Here are the samples:\n", - "\n", - "**Answer Material**\n", - "\n", - "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.\n", - "\n", - "\n", - "**Question**\n", - "\n", - "Q: When were the Normans in Normandy?\n", - "A: 10th and 11th centuries\n", - "\n", - "Q: In what country is Normandy located?\n", - "A: france\n", - "\n", - "For the full source code, see the [DJL repo](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java) and translator implementation [MXNet](https://github.com/deepjavalibrary/djl/blob/master/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/qa/MxBertQATranslator.java) [PyTorch](https://github.com/deepjavalibrary/djl/blob/master/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/qa/PtBertQATranslator.java)." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb b/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb deleted file mode 100644 index 2edbc6c195f..00000000000 --- a/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb +++ /dev/null @@ -1,473 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Rank Classification using BERT on Amazon Review dataset\n", - "\n", - "## Introduction\n", - "\n", - "In this tutorial, you learn how to train a rank classification model using [Transfer Learning](https://en.wikipedia.org/wiki/Transfer_learning). We will use a pretrained DistilBert model to train on the Amazon review dataset.\n", - "\n", - "## About the dataset and model\n", - "\n", - "[Amazon Customer Review dataset](https://s3.amazonaws.com/amazon-reviews-pds/readme.html) consists of all different valid reviews from amazon.com. We will use the \"Digital_software\" category that consists of 102k valid reviews. As for the pre-trained model, use the DistilBERT[[1]](https://arxiv.org/abs/1910.01108) model. It's a light-weight BERT model already trained on [Wikipedia text corpora](https://en.wikipedia.org/wiki/List_of_text_corpora), a much larger dataset consisting of over millions text. The DistilBERT served as a base layer and we will add some more classification layers to output as rankings (1 - 5).\n", - "\n", - "\n", - "
    [Image: Amazon Review example]
    \n", - "\n", - "We will use review body as our data input and ranking as label.\n", - "\n", - "\n", - "## Pre-requisites\n", - "This tutorial assumes you have the following knowledge. Follow the READMEs and tutorials if you are not familiar with:\n", - "1. How to setup and run [Java Kernel in Jupyter Notebook](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)\n", - "2. Basic components of Deep Java Library, and how to [train your first model](https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/02_train_your_first_model.ipynb).\n", - "\n", - "\n", - "## Getting started\n", - "Load the Deep Java Libarary and its dependencies from Maven. In here, you can choose between MXNet or PyTorch. MXNet is enabled by default. You can uncomment PyTorch dependencies and comment MXNet ones to switch to PyTorch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:basicdataset:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "\n", - "// PyTorch\n", - "// %maven ai.djl.pytorch:pytorch-model-zoo:0.23.0" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's import the necessary modules:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.basicdataset.tabular.*;\n", - "import ai.djl.basicdataset.tabular.utils.*;\n", - "import ai.djl.basicdataset.utils.*;\n", - "import ai.djl.engine.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.metric.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.nlp.*;\n", - "import ai.djl.modality.nlp.bert.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.nn.*;\n", - "import ai.djl.nn.core.*;\n", - "import ai.djl.nn.norm.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.training.*;\n", - "import ai.djl.training.dataset.*;\n", - "import ai.djl.training.evaluator.*;\n", - "import ai.djl.training.listener.*;\n", - "import ai.djl.training.loss.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.translate.*;\n", - "import java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;\n", - "import org.apache.commons.csv.*;\n", - "\n", - "System.out.println(\"You are using: \" + Engine.getInstance().getEngineName() + \" Engine\");" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Prepare Dataset\n", - "\n", - "First step is to prepare the dataset for training. Since the original data was in TSV format, we can use CSVDataset to be the dataset container. We will also need to specify how do we want to preprocess the raw data. For BERT model, the input data are required to be tokenized and mapped into indices based on the inputs. In DJL, we defined an interface called Fearurizer, it is designed to allow user customize operation on each selected row/column of a dataset. In our case, we would like to clean and tokenize our sentencies. So let's try to implement it to deal with customer review sentencies." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "final class BertFeaturizer implements Featurizer {\n", - "\n", - " private final BertFullTokenizer tokenizer;\n", - " private final int maxLength; // the cut-off length\n", - "\n", - " public BertFeaturizer(BertFullTokenizer tokenizer, int maxLength) {\n", - " this.tokenizer = tokenizer;\n", - " this.maxLength = maxLength;\n", - " }\n", - "\n", - " /** {@inheritDoc} */\n", - " @Override\n", - " public void featurize(DynamicBuffer buf, String input) {\n", - " Vocabulary vocab = tokenizer.getVocabulary();\n", - " // convert sentence to tokens (toLowerCase for uncased model)\n", - " List tokens = tokenizer.tokenize(input.toLowerCase());\n", - " // trim the tokens to maxLength\n", - " tokens = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;\n", - " // BERT embedding convention \"[CLS] Your Sentence [SEP]\"\n", - " buf.put(vocab.getIndex(\"[CLS]\"));\n", - " tokens.forEach(token -> buf.put(vocab.getIndex(token)));\n", - " buf.put(vocab.getIndex(\"[SEP]\"));\n", - " }\n", - "\n", - " /** {@inheritDoc} */\n", - " @Override\n", - " public int dataRequired() {\n", - " throw new IllegalStateException(\"BertFeaturizer only support featurize, not deFeaturize\");\n", - " }\n", - "\n", - " /** {@inheritDoc} */\n", - " @Override\n", - " public Object deFeaturize(float[] data) {\n", - " throw new IllegalStateException(\"BertFeaturizer only support featurize, not deFeaturize\");\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Once we got this part done, we can apply the `BertFeaturizer` into our Dataset. We take `review_body` column and apply the Featurizer. We also pick `star_rating` as our label set. Since we go for batch input, we need to tell the dataset to pad our data if it is less than the `maxLength` we defined. `PaddingStackBatchifier` will do the work for you." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {\n", - " String amazonReview =\n", - " \"https://mlrepo.djl.ai/dataset/nlp/ai/djl/basicdataset/amazon_reviews/1.0/amazon_reviews_us_Digital_Software_v1_00.tsv.gz\";\n", - " float paddingToken = tokenizer.getVocabulary().getIndex(\"[PAD]\");\n", - " return CsvDataset.builder()\n", - " .optCsvUrl(amazonReview) // load from Url\n", - " .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // Setting TSV loading format\n", - " .setSampling(batchSize, true) // make sample size and random access\n", - " .optLimit(limit)\n", - " .addFeature(\n", - " new Feature(\n", - " \"review_body\", new BertFeaturizer(tokenizer, maxLength)))\n", - " .addLabel(\n", - " new Feature(\n", - " \"star_rating\", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))\n", - " .optDataBatchifier(\n", - " PaddingStackBatchifier.builder()\n", - " .optIncludeValidLengths(false)\n", - " .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))\n", - " .build()) // define how to pad dataset to a fix length\n", - " .build();\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Construct your model\n", - "\n", - "We will load our pretrained model and prepare the classification. First construct the `criteria` to specify where to load the embedding (DistiledBERT), then call `loadModel` to download that embedding with pre-trained weights. 
Since this model is built without classification layer, we need to add a classification layer to the end of the model and train it. After you are done modifying the block, set it back to model using `setBlock`.\n", - "\n", - "### Load the word embedding\n", - "\n", - "We will download our word embedding and load it to memory (this may take a while)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// MXNet base model\n", - "String modelUrls = \"https://resources.djl.ai/test-models/distilbert.zip\";\n", - "if (\"PyTorch\".equals(Engine.getInstance().getEngineName())) {\n", - " modelUrls = \"https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip\";\n", - "}\n", - "\n", - "Criteria criteria = Criteria.builder()\n", - " .optApplication(Application.NLP.WORD_EMBEDDING)\n", - " .setTypes(NDList.class, NDList.class)\n", - " .optModelUrls(modelUrls)\n", - " .optProgress(new ProgressBar())\n", - " .build();\n", - "ZooModel embedding = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Create classification layers\n", - "\n", - "Then let's build a simple MLP layer to classify the ranks. We set the output of last FullyConnected (Linear) layer to 5 to get the predictions for star 1 to 5. Then all we need to do is to load the block into the model. Before applying the classification layer, we also need to add text embedding to the front. In our case, we just create a Lambda function that do the followings:\n", - "\n", - "1. batch_data (batch size, token indices) -> batch_data + max_length (size of the token indices)\n", - "2. generate embedding" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor embedder = embedding.newPredictor();\n", - "Block classifier = new SequentialBlock()\n", - " // text embedding layer\n", - " .add(\n", - " ndList -> {\n", - " NDArray data = ndList.singletonOrThrow();\n", - " NDList inputs = new NDList();\n", - " long batchSize = data.getShape().get(0);\n", - " float maxLength = data.getShape().get(1);\n", - "\n", - " if (\"PyTorch\".equals(Engine.getInstance().getEngineName())) {\n", - " inputs.add(data.toType(DataType.INT64, false));\n", - " inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));\n", - " inputs.add(data.getManager().arange(maxLength)\n", - " .toType(DataType.INT64, false)\n", - " .broadcast(data.getShape()));\n", - " } else {\n", - " inputs.add(data);\n", - " inputs.add(data.getManager().full(new Shape(batchSize), maxLength));\n", - " }\n", - " // run embedding\n", - " try {\n", - " return embedder.predict(inputs);\n", - " } catch (TranslateException e) {\n", - " throw new IllegalArgumentException(\"embedding error\", e);\n", - " }\n", - " })\n", - " // classification layer\n", - " .add(Linear.builder().setUnits(768).build()) // pre classifier\n", - " .add(Activation::relu)\n", - " .add(Dropout.builder().optRate(0.2f).build())\n", - " .add(Linear.builder().setUnits(5).build()) // 5 star rating\n", - " .addSingleton(nd -> nd.get(\":,0\")); // Take [CLS] as the head\n", - "Model model = Model.newInstance(\"AmazonReviewRatingClassification\");\n", - "model.setBlock(classifier);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Start Training\n", - "\n", - "Finally, we can start building our training pipeline to train the model.\n", - "\n", - "### Creating Training and Testing dataset\n", - "\n", - "Firstly, we 
need to create a vocabulary that is used to map each token to an index, such as \"hello\" to 1121 (1121 being the index of \"hello\" in the dictionary). Then we simply feed the vocabulary to the tokenizer that is used to tokenize the sentences. Finally, we just need to split the dataset according to the train/validation ratio.\n", - "\n", - "Note: we set the cut-off length to 64, which means only the first 64 tokens from each review will be used. You can increase this value to achieve better accuracy." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Prepare the vocabulary\n", - "DefaultVocabulary vocabulary = DefaultVocabulary.builder()\n", - "    .addFromTextFile(embedding.getArtifact(\"vocab.txt\"))\n", - "    .optUnknownToken(\"[UNK]\")\n", - "    .build();\n", - "// Prepare dataset\n", - "int maxTokenLength = 64; // cutoff tokens length\n", - "int batchSize = 8;\n", - "int limit = Integer.MAX_VALUE;\n", - "// int limit = 512; // uncomment for quick testing\n", - "\n", - "BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);\n", - "CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);\n", - "// split data with 7:3 train:valid ratio\n", - "RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);\n", - "RandomAccessDataset trainingSet = datasets[0];\n", - "RandomAccessDataset validationSet = datasets[1];" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Setup Trainer and training config\n", - "\n", - "Then, we need to set up our trainer. We configure the accuracy evaluator and the loss function. The model training logs will be saved to `build/model`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "SaveModelTrainingListener listener = new SaveModelTrainingListener(\"build/model\");\n", - "        listener.setSaveModelCallback(\n", - "            trainer -> {\n", - "                TrainingResult result = trainer.getTrainingResult();\n", - "                Model model = trainer.getModel();\n", - "                // track for accuracy and loss\n", - "                float accuracy = result.getValidateEvaluation(\"Accuracy\");\n", - "                model.setProperty(\"Accuracy\", String.format(\"%.5f\", accuracy));\n", - "                model.setProperty(\"Loss\", String.format(\"%.5f\", result.getValidateLoss()));\n", - "            });\n", - "DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type\n", - "        .addEvaluator(new Accuracy())\n", - "        .optDevices(Engine.getInstance().getDevices(1)) // train using single GPU\n", - "        .addTrainingListeners(TrainingListener.Defaults.logging(\"build/model\"))\n", - "        .addTrainingListeners(listener);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Start training\n", - "\n", - "We will start our training process. Training on a GPU takes approximately 10 minutes. On a CPU, it will take more than 2 hours to finish."
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int epoch = 2;\n", - "\n", - "Trainer trainer = model.newTrainer(config);\n", - "trainer.setMetrics(new Metrics());\n", - "Shape encoderInputShape = new Shape(batchSize, maxTokenLength);\n", - "// initialize trainer with proper input shape\n", - "trainer.initialize(encoderInputShape);\n", - "EasyTrain.fit(trainer, epoch, trainingSet, validationSet);\n", - "System.out.println(trainer.getTrainingResult());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Save the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model.save(Paths.get(\"build/model\"), \"amazon-review.param\");" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Verify the model\n", - "\n", - "We can create a predictor from the model to run inference on our customized dataset. Firstly, we can create a `Translator` for the model to do preprocessing and post processing. Similar to what we have done before, we need to tokenize the input sentence and get the output ranking." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class MyTranslator implements Translator {\n", - "\n", - " private BertFullTokenizer tokenizer;\n", - " private Vocabulary vocab;\n", - " private List ranks;\n", - "\n", - " public MyTranslator(BertFullTokenizer tokenizer) {\n", - " this.tokenizer = tokenizer;\n", - " vocab = tokenizer.getVocabulary();\n", - " ranks = Arrays.asList(\"1\", \"2\", \"3\", \"4\", \"5\");\n", - " }\n", - "\n", - " @Override\n", - " public Batchifier getBatchifier() { return Batchifier.STACK; }\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, String input) {\n", - " List tokens = tokenizer.tokenize(input);\n", - " float[] indices = new float[tokens.size() + 2];\n", - " indices[0] = vocab.getIndex(\"[CLS]\");\n", - " for (int i = 0; i < tokens.size(); i++) {\n", - " indices[i+1] = vocab.getIndex(tokens.get(i));\n", - " }\n", - " indices[indices.length - 1] = vocab.getIndex(\"[SEP]\");\n", - " return new NDList(ctx.getNDManager().create(indices));\n", - " }\n", - "\n", - " @Override\n", - " public Classifications processOutput(TranslatorContext ctx, NDList list) {\n", - " return new Classifications(ranks, list.singletonOrThrow().softmax(0));\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, we can create a `Predictor` to run the inference. 
Let's try with a random customer review:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String review = \"It works great, but it takes too long to update itself and slows the system\";\n", - "Predictor predictor = model.newPredictor(new MyTranslator(tokenizer));\n", - "\n", - "predictor.predict(review)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/tensorflow/pneumonia_detection.ipynb b/jupyter/tensorflow/pneumonia_detection.ipynb deleted file mode 100644 index c790ad13f55..00000000000 --- a/jupyter/tensorflow/pneumonia_detection.ipynb +++ /dev/null @@ -1,243 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Detecting Pneumonia from X-ray images using Deep Java Library" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*Disclaimer: this blog post is intended for educational purposes only. The application was developed using experimental code. The result should not be used for any medical diagnoses of pneumonia. This content has not been reviewed or approved by any scientists or medical professionals.*\n", - "\n", - "## Introduction\n", - "In this example, we demonstrate how deep learning (DL) can be used to detect pneumonia from chest X-ray images. This work is inspired by the [Chest X-ray Images Challenge](https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia) on Kaggle and a related [paper](https://www.cell.com/cell/fulltext/S0092-8674\\(18\\)30154-5). In this notebook, we illustrates how artificial intelligence can assist clinical decision making with focus on enterprise deployment. This work leverages a model trained using Keras and TensorFlow with [this Kaggle kernel](https://www.kaggle.com/aakashnain/beating-everything-with-depthwise-convolution). In this blog post, we will focus on generating predictions with this model using [Deep Java Library](https://djl.ai/) (DJL), an open source library to build and deploy DL in Java." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. 
To install the Java Kernel, see the [documentation](https://docs.djl.ai/jupyter/index.html).\n", - "\n", - "These are the dependencies we will use:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.tensorflow:tensorflow-api:0.23.0\n", - "%maven ai.djl.tensorflow:tensorflow-engine:0.23.0\n", - "%maven ai.djl.tensorflow:tensorflow-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%loadFromPOM\n", - "\n", - " com.google.protobuf\n", - " protobuf-java\n", - " 3.19.2\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import java packages" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.util.*;\n", - "import java.net.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### set the model URL" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var modelUrl = \"https://resources.djl.ai/demo/pneumonia-detection-model/saved_model.zip\";" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Dive deep into Translator\n", - "\n", - "To successfully run inference, we need to define some preprocessing and post processing logic to achieve the best \n", - "prediction result and understandable output." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class MyTranslator implements Translator {\n", - "\n", - " private static final List CLASSES = Arrays.asList(\"Normal\", \"Pneumonia\");\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, Image input) {\n", - " NDManager manager = ctx.getNDManager();\n", - " NDArray array = input.toNDArray(manager, Image.Flag.COLOR);\n", - " array = NDImageUtils.resize(array, 224).div(255.0f);\n", - " return new NDList(array);\n", - " }\n", - "\n", - " @Override\n", - " public Classifications processOutput(TranslatorContext ctx, NDList list) {\n", - " NDArray probabilities = list.singletonOrThrow();\n", - " return new Classifications(CLASSES, probabilities);\n", - " }\n", - "\n", - " @Override\n", - " public Batchifier getBatchifier() {\n", - " return Batchifier.STACK;\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As you can see above, the translator resizes the image to 224x224 and normalizes the image by dividing by 255 before feeding it into the model. When doing inference, you need to follow the same pre-processing procedure as was used during training. In this case, we need to match the Keras training code. After running prediction, the model outputs probabilities of each class as an [NDArray](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/ndarray/NDArray.html). 
We need to tell the predictor to translate it back to classes, namely “Normal” or \"Pneumonia\".\n", - "\n", - "Until this point, all preparation work is done, we can start working on the prediction logic." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Predict using DJL\n", - "\n", - "### Load the image\n", - "We are going to load an CT scanned image of an infected lung from internet " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var imagePath = \"https://resources.djl.ai/images/chest_xray.jpg\";\n", - "var image = ImageFactory.getInstance().fromUrl(imagePath);\n", - "image.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load your model\n", - "Next, we will download the model from `modelUrl`. This will download the model into the DJL cache location" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria criteria =\n", - " Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelUrls(modelUrl)\n", - " .optTranslator(new MyTranslator())\n", - " .optProgress(new ProgressBar())\n", - " .build();\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Run inference\n", - "Lastly, we will need to create a predictor using our model and translator. Once we have a predictor, we simply need to call the predict method on our test image." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor predictor = model.newPredictor();\n", - "Classifications classifications = predictor.predict(image);\n", - "\n", - "classifications" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/tensorflow/rank_classification_using_BERT_on_Amazon_Review.ipynb b/jupyter/tensorflow/rank_classification_using_BERT_on_Amazon_Review.ipynb deleted file mode 100644 index 1b4647919c1..00000000000 --- a/jupyter/tensorflow/rank_classification_using_BERT_on_Amazon_Review.ipynb +++ /dev/null @@ -1,267 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Rank Classification using BERT on Amazon Review\n", - "\n", - "## Introduction\n", - "\n", - "In this tutorial, you learn how to use a pre-trained Tensorflow model to classifiy a Amazon Review rank. The model was refined on Amazon Review dataset with a pretrained DistilBert model.\n", - "\n", - "### About the dataset and model\n", - "\n", - "[Amazon Customer Review dataset](https://s3.amazonaws.com/amazon-reviews-pds/readme.html) consists of all different valid reviews from amazon.com. We will use the \"Digital_software\" category that consists of 102k valid reviews. As for the pre-trained model, use the DistilBERT[[1]](https://arxiv.org/abs/1910.01108) model. 
It's a light-weight BERT model already trained on [Wikipedia text corpora](https://en.wikipedia.org/wiki/List_of_text_corpora), a much larger dataset consisting of millions of words of text. DistilBERT serves as the base layer, and we will add classification layers on top that output the rankings (1 - 5).\n", - "\n", - "\n", - "
    Amazon Review example
    \n", - "\n", - "\n", - "## Pre-requisites\n", - "This tutorial assumes you have the following knowledge. Follow the READMEs and tutorials if you are not familiar with:\n", - "1. How to setup and run [Java Kernel in Jupyter Notebook](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)\n", - "2. Basic components of Deep Java Library, and how to [train your first model](https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/02_train_your_first_model.ipynb).\n", - "\n", - "\n", - "## Getting started\n", - "Load the Deep Java Libarary and its dependencies from Maven. In here, you can choose between MXNet or PyTorch. MXNet is enabled by default. You can uncomment PyTorch dependencies and comment MXNet ones to switch to PyTorch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.tensorflow:tensorflow-engine:0.23.0\n", - "%maven ai.djl.tensorflow:tensorflow-api:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%loadFromPOM\n", - "\n", - " com.google.protobuf\n", - " protobuf-java\n", - " 3.19.2\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's import the necessary modules:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.engine.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.nlp.*;\n", - "import ai.djl.modality.nlp.bert.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.util.*;\n", - "\n", - "import java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Prepare your model files\n", - "\n", - "You can download pre-trained Tensorflow model from: https://resources.djl.ai/demo/tensorflow/amazon_review_rank_classification.zip." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String modelUrl = \"https://resources.djl.ai/demo/tensorflow/amazon_review_rank_classification.zip\";\n", - "DownloadUtils.download(modelUrl, \"build/amazon_review_rank_classification.zip\", new ProgressBar());\n", - "Path zipFile = Paths.get(\"build/amazon_review_rank_classification.zip\");\n", - "\n", - "Path modelDir = Paths.get(\"build/saved_model\");\n", - "if (Files.notExists(modelDir)) {\n", - " ZipUtils.unzip(Files.newInputStream(zipFile), modelDir); \n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create Translator\n", - "\n", - "Inference in deep learning is the process of predicting the output for a given input based on a pre-defined model.\n", - "DJL abstracts away the whole process for ease of use. It can load the model, perform inference on the input, and provide output.\n", - "\n", - "The [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) interface is used to: Pre-processing and Post-processing. 
The pre-processing\n", - "component converts the user-defined input objects into an NDList, so that the [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html) in DJL can understand the\n", - "input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the\n", - "`Predictor`. The post-processing block allows you to convert the output from the `Predictor` to the desired output\n", - "format.\n", - "\n", - "### Pre-processing\n", - "\n", - "Now, you need to convert the sentences into tokens. We provide a powerful tool [`BertTokenizer`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/bert/BertTokenizer.html) that you can use to convert questions and answers into tokens, and batchify your sequence together. Once you have properly formatted tokens, you can use [`Vocabulary`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/Vocabulary.html) to map your token to BERT index.\n", - "\n", - "The following code block demonstrates tokenizing the question and answer defined earlier into BERT-formatted tokens.\n", - "\n", - "In the zip file, we also bundled the BERT `vocab.txt` file." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Prepare the vocabulary\n", - "Path vocabFile = modelDir.resolve(\"vocab.txt\");\n", - "DefaultVocabulary vocabulary = DefaultVocabulary.builder()\n", - " .optMinFrequency(1)\n", - " .addFromTextFile(vocabFile)\n", - " .optUnknownToken(\"[UNK]\")\n", - " .build();\n", - "BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);\n", - "int maxTokenLength = 64; // cutoff tokens length\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class MyTranslator implements Translator {\n", - "\n", - " private BertFullTokenizer tokenizer;\n", - " private Vocabulary vocab;\n", - " private List ranks;\n", - " private int length;\n", - "\n", - " public MyTranslator(BertFullTokenizer tokenizer, int length) {\n", - " this.tokenizer = tokenizer;\n", - " this.length = length;\n", - " vocab = tokenizer.getVocabulary();\n", - " ranks = Arrays.asList(\"1\", \"2\", \"3\", \"4\", \"5\");\n", - " }\n", - "\n", - " @Override\n", - " public Batchifier getBatchifier() {\n", - " return Batchifier.STACK;\n", - " }\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, String input) {\n", - " List tokens = tokenizer.tokenize(input);\n", - " long[] indices = new long[length];\n", - " long[] mask = new long[length];\n", - " long[] segmentIds = new long[length];\n", - " int size = Math.min(length, tokens.size());\n", - " for (int i = 0; i < size; i++) {\n", - " indices[i + 1] = vocab.getIndex(tokens.get(i));\n", - " }\n", - " Arrays.fill(mask, 0, size, 1);\n", - " NDManager m = ctx.getNDManager();\n", - " return new NDList(m.create(indices), m.create(mask), m.create(segmentIds));\n", - " }\n", - "\n", - " @Override\n", - " public Classifications processOutput(TranslatorContext ctx, NDList list) {\n", - " return new Classifications(ranks, list.singletonOrThrow().softmax(0));\n", - " }\n", - "}\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load your model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "MyTranslator translator = new MyTranslator(tokenizer, maxTokenLength);\n", - "\n", - "Criteria criteria = Criteria.builder()\n", - " 
.setTypes(String.class, Classifications.class)\n", - "        .optModelPath(modelDir) // load model from the model directory\n", - "        .optTranslator(translator) // use the custom translator\n", - "        .build();\n", - "\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Run inference\n", - "\n", - "Lastly, we will need to create a predictor using our model and translator. Once we have a predictor, we simply need to call the predict method on our test input." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String review = \"It works great, but it takes too long to update itself and slows the system\";\n", - "\n", - "Predictor predictor = model.newPredictor();\n", - "Classifications classifications = predictor.predict(review);\n", - "\n", - "classifications" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/tensorflow_lite/inference_with_tensorflow_lite.ipynb b/jupyter/tensorflow_lite/inference_with_tensorflow_lite.ipynb deleted file mode 100644 index 3fb55f9799a..00000000000 --- a/jupyter/tensorflow_lite/inference_with_tensorflow_lite.ipynb +++ /dev/null @@ -1,156 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Inference with Tensorflow Lite\n", - "\n", - "In this tutorial, you learn how to load an existing TensorFlow Lite model and use it to run a prediction task.\n", - "\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of the Java Kernel. For more information on installing the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)."
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:model-zoo:0.23.0\n", - "%maven ai.djl.tflite:tflite-engine:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "\n", - "// Use secondary engine to help pre-processing and post-processing\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.awt.image.*;\n", - "import java.nio.file.*;\n", - "import ai.djl.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.modality.cv.translator.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1: Load your Tensorflow Lite mode from DJL model zoo" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria criteria = Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optEngine(\"TFLite\")\n", - " .optFilter(\"dataset\", \"aiyDish\")\n", - " .build();\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2: Create a Predictor" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor predictor = model.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3: Load image for classification" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/sachertorte.jpg\");\n", - "\n", - "img.getWrappedImage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 4: Run inference" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Classifications classifications = predictor.predict(img);\n", - "\n", - "classifications" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Now, you can load Tensorflow Lite model and run inference.\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/test_notebook.sh b/jupyter/test_notebook.sh deleted file mode 100755 index a4cd2166e9e..00000000000 --- a/jupyter/test_notebook.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env bash - -# test_notebook.sh [filename] -# If no filename is passed, it runs all files in current directory and subdirectories - -set -e - -function run_test { - base=$(basename $1) - # Workaround on crashes - if [[ "$base" == transfer_learning_on_cifar10* || "$base" == rank_classification_using_BERT* 
]]; then - jupyter nbconvert --to notebook --inplace $1 - else - jupyter nbconvert --to notebook --execute --ExecutePreprocessor.timeout=600 --inplace $1 - fi -} - -if [[ $# -eq 0 ]]; then - for f in {**,.}/*.ipynb - do - dir=$(dirname f) - run_test "$f" - done -else - run_test $1 -fi diff --git a/jupyter/transfer_learning_on_cifar10.ipynb b/jupyter/transfer_learning_on_cifar10.ipynb deleted file mode 100644 index 663a9eafc7f..00000000000 --- a/jupyter/transfer_learning_on_cifar10.ipynb +++ /dev/null @@ -1,285 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Transfer Learning on CIFAR-10 Dataset\n", - "\n", - "\n", - "## Introduction\n", - "\n", - "In this tutorial, you learn how to train an image classification model using [Transfer Learning](https://en.wikipedia.org/wiki/Transfer_learning). Transfer learning is a popular machine learning technique that uses a model trained on one problem and applies it to a second related problem. Compared to training from scratch or designing a model for your specific problem, transfer learning can leverage the features already learned on a similar problem and produce a more robust model in a much shorter time.\n", - "\n", - "Train your model with the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset which consists of 60,000 32x32 color images in 10 classes. As for the pre-trained model, use the ResNet50v1[1] model. It's a 50 layer deep model already trained on [ImageNet](http://www.image-net.org/), a much larger dataset consisting of over 1.2 million images in 1000 classes. Modify it to classify 10 classes from the CIFAR-10 dataset.\n", - "\n", - "![The CIFAR-10 Dataset](https://resources.djl.ai/images/cifar-10.png)\n", - "
    the CIFAR10 dataset
    \n", - "\n", - "\n", - "## Pre-requisites\n", - "This tutorial assumes you have the following knowledge. Follow the READMEs and tutorials if you are not familiar with:\n", - "1. How to setup and run [Java Kernel in Jupyter Notebook](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)\n", - "2. Basic components of Deep Java Library, and how to [train your first model](https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/02_train_your_first_model.ipynb).\n", - "\n", - "\n", - "## Getting started\n", - "Load the Deep Java Libarary and its dependencies from Maven:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:basicdataset:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's import the necessary modules." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.basicdataset.cv.classification.*;\n", - "import ai.djl.engine.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.nn.*;\n", - "import ai.djl.nn.core.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.training.*;\n", - "import ai.djl.training.dataset.*;\n", - "import ai.djl.training.initializer.*;\n", - "import ai.djl.training.listener.*;\n", - "import ai.djl.training.loss.*;\n", - "import ai.djl.training.evaluator.*;\n", - "import ai.djl.training.optimizer.*;\n", - "import ai.djl.training.tracker.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.translate.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;\n", - "import java.util.concurrent.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Construct your model\n", - "\n", - "Load the pre-trained ResNet50V1 model. You can find it in the [Model Zoo](https://github.com/deepjavalibrary/djl/blob/master/docs/model-zoo.md). First construct the `criteria` to specify which ResNet model to load, then call `loadModel` to get a ResNet50V1 model with pre-trained weights. Note this model was trained on ImageNet with 1000 classes; the last layer is a Linear layer with 1000 output channels. Because you are repurposing it on CIFAR10 with 10 classes, you need to remove the last layer and add a new Linear layer with 10 output channels. After you are done modifying the block, set it back to model using `setBlock`." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// load model and change last layer\n", - "Criteria criteria = Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optProgress(new ProgressBar())\n", - " .optArtifactId(\"resnet\")\n", - " .optFilter(\"layers\", \"50\")\n", - " .optFilter(\"flavor\", \"v1\").build();\n", - "Model model = criteria.loadModel();\n", - "SequentialBlock newBlock = new SequentialBlock();\n", - "SymbolBlock block = (SymbolBlock) model.getBlock();\n", - "block.removeLastBlock();\n", - "newBlock.add(block);\n", - "newBlock.add(Blocks.batchFlattenBlock());\n", - "newBlock.add(Linear.builder().setUnits(10).build());\n", - "model.setBlock(newBlock);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Prepare Dataset\n", - "\n", - "After you have the model, the next step is to prepare the dataset for training. You can construct a CIFAR10 builder with your own specifications. You have the options to get the train or test dataset, specify desired batch size, specify whether to shuffle your data during training, and most importantly, specify the pre-process pipeline. \n", - "\n", - "A pipeline consists of a series of transformations to apply on the input data before feeding it to the model. \n", - "\n", - "For example, `ToTensor` can be used to transform colored image NDArrays with shape (32, 32, 3) and values from 0 to 256 to NDArrays with shape (3, 32, 32) and values from 0 to 1. This operation is transposing image data from channels last to channels first format, which is more suitable for GPU computation. \n", - "\n", - "The `Normalize` transformation can normalize input data according to their mean and standard deviation values. This will make different features have similar range and help our model perform better." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int batchSize = 32;\n", - "int limit = Integer.MAX_VALUE; // change this to a small value for a dry run\n", - "// int limit = 160; // limit 160 records in the dataset for a dry run\n", - "Pipeline pipeline = new Pipeline(\n", - " new ToTensor(),\n", - " new Normalize(new float[] {0.4914f, 0.4822f, 0.4465f}, new float[] {0.2023f, 0.1994f, 0.2010f}));\n", - "Cifar10 trainDataset = \n", - " Cifar10.builder()\n", - " .setSampling(batchSize, true)\n", - " .optUsage(Dataset.Usage.TRAIN)\n", - " .optLimit(limit)\n", - " .optPipeline(pipeline)\n", - " .build();\n", - "trainDataset.prepare(new ProgressBar());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Set up training configuration\n", - "\n", - "You are leveraging a pre-trained model, so you can expect the model to converge quickly. You will only train only ten epochs. As the model converges, you need to reduce the learning rate to get better results. You can use a `Tracker` to reduce the learning rate by 0.1 after two, five, and eight epochs. \n", - "\n", - "Deep Java Library supports training on multiple GPUs. You can use `setDevices` and pass an array of devices you want the model to be trained on. For example, `new Device[]{Device.gpu(0), Device.gpu(1)}` for training on GPU0 and GPU1. You can also call `Engine.getInstancec().getDevices(4)` and pass the number of GPUs you want to train. It will start with GPU0, and use CPU if no GPU is available. 
To learn more about multi-GPU training, read our multi-GPU [documentation](https://github.com/deepjavalibrary/djl/tree/master/examples/docs).\n", - "\n", - "To complete the training configuration set up, use the `Adam` optimizer, `SoftmaxCrossEntropyLoss`, and `Accuracy` for classification problems." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())\n", - " //softmaxCrossEntropyLoss is a standard loss for classification problems\n", - " .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is\n", - " .optDevices(Engine.getInstance().getDevices(1)) // Limit your GPU, using more GPU actually will slow down coverging\n", - " .addTrainingListeners(TrainingListener.Defaults.logging());\n", - "\n", - "// Now that we have our training configuration, we should create a new trainer for our model\n", - "Trainer trainer = model.newTrainer(config);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train your model\n", - "Now you can start training. This procedure is similar to the one in [Train Your First Model](https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/02_train_your_first_model.ipynb). Training requires the following steps:\n", - "1. Initialize a new trainer using the training config you just set up\n", - "2. Initialize the weights in trainer\n", - "3. Using a `for` loop to iterate through the whole dataset 10 times (epochs), resetting the evaluators at the end of each epoch\n", - "4. During each epoch, using a `for` loop to iterate through the dataset in batches and train batch by batch while printing the training accuracy on the progress bar." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int epoch = 10;\n", - "Shape inputShape = new Shape(1, 3, 32, 32);\n", - "trainer.initialize(inputShape);" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for (int i = 0; i < epoch; ++i) {\n", - " int index = 0;\n", - " for (Batch batch : trainer.iterateDataset(trainDataset)) {\n", - " EasyTrain.trainBatch(trainer, batch);\n", - " trainer.step();\n", - " batch.close();\n", - " }\n", - "\n", - " // reset training and validation evaluators at end of epoch\n", - " trainer.notifyListeners(listener -> listener.onEpoch(trainer));\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Save your model\n", - "\n", - "Finally, you can save your model after training is done and use it for inference." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Path modelDir = Paths.get(\"build/resnet\");\n", - "Files.createDirectories(modelDir);\n", - "\n", - "model.setProperty(\"Epoch\", String.valueOf(epoch));\n", - "model.save(modelDir, \"resnet\");" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## What's next\n", - "\n", - "1. Try inference using the model you just trained. You can find an airplane image in [test resources](https://github.com/deepjavalibrary/djl/blob/master/examples/src/test/resources/airplane1.png) and follow the inference tutorials in the [Jupyter module](https://github.com/deepjavalibrary/djl/tree/master/jupyter).\n", - "\n", - "2. 
Follow the complete example with multi-GPU support, a validation dataset, and the fit API in the [examples module](https://github.com/deepjavalibrary/djl/tree/master/examples/docs).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## References\n", - "[1] [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)\n", - "\n", - "[2] [Gluon CV model zoo](https://gluon-cv.mxnet.io/model_zoo/classification.html) for pre-trained ResNet50 models" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/tutorial/01_create_your_first_network.ipynb b/jupyter/tutorial/01_create_your_first_network.ipynb deleted file mode 100644 index 293fde5fec4..00000000000 --- a/jupyter/tutorial/01_create_your_first_network.ipynb +++ /dev/null @@ -1,206 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Create your first deep learning neural network\n", - "\n", - "## Introduction\n", - "\n", - "This is the first part of our [beginner tutorial series](https://github.com/deepjavalibrary/djl/tree/master/jupyter/tutorial) that will take you through creating, training, and running inference on a neural network. In this part, you will learn how to use the built-in `Block` to create your first neural network - a Multilayer Perceptron.\n", - "\n", - "## Step 1: Setup development environment\n", - "\n", - "### Installation\n", - "\n", - "This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the [Jupyter README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Add the snapshot repository to get the DJL snapshot artifacts\n", - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "// Add the maven dependencies\n", - "%maven ai.djl:api:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.nn.*;\n", - "import ai.djl.nn.core.*;\n", - "import ai.djl.training.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Neural Network\n", - "\n", - "A neural network is a black box function. Instead of coding this function yourself, you provide many sample input/output pairs for this function. Then, we try to train the network to learn how to best approximate the observed behavior of the function given only these input/output pairs. A better model with more data can more accurately approximate the function.\n", - "\n", - "## Application\n", - "\n", - "The first thing to figure out when trying to build a neural network, like building most functions, is what your function signature is. What are your input types and output types? Because most models use relatively consistent signatures, we refer to them as [Applications](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/Application.html). 
Within the Applications interface, you can find a list of some of the more common model applications used in deep learning.\n", - "\n", - "In this tutorial, we will focus on the image classification application. It is one of the most common first applications and has a significant history with deep learning. In image classification, the input is a single image and it is classified based on the main subject of the image into a number of different possible classes. The classes for the image depend on the specific data you are training with." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Application application = Application.CV.IMAGE_CLASSIFICATION;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dataset\n", - "\n", - "Once you have figured out what application you want to learn, next you need to collect the data you are training with and form it into a dataset. Often, trying to collect and clean up the data is the most troublesome task in the deep learning process. \n", - "\n", - "Using a dataset can either involve collecting custom data from various sources or using one of the many datasets freely available online. The custom data may better suit your use case, but a free dataset is often faster and easier to use. You can read our [dataset guide](http://docs.djl.ai/docs/dataset.html) to learn more about datasets.\n", - "\n", - "### MNIST\n", - "\n", - "The dataset we will be using is [MNIST](https://en.wikipedia.org/wiki/MNIST_database), a database of handwritten digits. Each image contains a black and white digit from 0 to 9 in a 28x28 image. It is commonly used when getting started with deep learning because it is small and fast to train.\n", - "\n", - "![Mnist Image](https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png)\n", - "\n", - "Once you understand your dataset, you should create an implementation of the [Dataset class](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Dataset.html). In this case, we provide the MNIST dataset built-in to make it easy for you to use it.\n", - "\n", - "## Multilayer Perceptron\n", - "\n", - "Now that we have our dataset, we can choose a model to train with it. For this tutorial, we will build one of the simplest and oldest deep learning networks: a Multilayer Perceptron (MLP).\n", - "\n", - "The MLP is organized into layers. The first layer is the input layer which contains your input data and the last layer is the output layer which produces the final result of the network. Between them are layers referred to as hidden layers. Having more hidden layers and larger hidden layers allows the MLP to represent more complex functions.\n", - "\n", - "The example below contains an input of size 3, a single hidden layer of size 3, and an output of size 2. The number and sizes of the hidden layers are usually determined through experimentation. Between each pair of layers is a linear operation (sometimes called a FullyConnected operation because each number in the input is connected to each number in the output by a matrix multiplication). Not pictured, there is also a non-linear activation function after each linear operation. 
For more information, see the [Multilayer Perceptron chapter of the D2l DJL book](https://d2l.djl.ai/chapter_multilayer-perceptrons/index.html).\n", - "\n", - "![MLP Image](https://upload.wikimedia.org/wikipedia/commons/c/c2/MultiLayerNeuralNetworkBigger_english.png)\n", - "\n", - "\n", - "## Step 2: Determine your input and output size\n", - "\n", - "The MLP model uses a one dimensional vector as the input and the output. You should determine the appropriate size of this vector based on your input data and what you will use the output of the model for.\n", - "\n", - "Our input vector will have size `28x28` because the MNIST input images have a height and width of 28 and it takes only a single number to represent each pixel. For a color image, you would need to further multiply this by `3` for the RGB channels.\n", - "\n", - "Our output vector has size `10` because there are `10` possible classes (0 to 9) for each image." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "long inputSize = 28*28;\n", - "long outputSize = 10;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3: Create a **SequentialBlock**\n", - "\n", - "### NDArray\n", - "\n", - "The core data type used for working with deep learning is the [NDArray](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/ndarray/NDArray.html). An NDArray represents a multidimensional, fixed-size homogeneous array. It has very similar behavior to the Numpy python package with the addition of efficient computing. We also have a helper class, the [NDList](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/ndarray/NDList.html) which is a list of NDArrays which can have different sizes and data types.\n", - "\n", - "### Block API\n", - "\n", - "In DJL, [Blocks](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/nn/Block.html) serve a purpose similar to functions that convert an input `NDList` to an output `NDList`. They can represent single operations, parts of a neural network, and even the whole neural network. What makes blocks special is that they contain a number of parameters that are used in their function and are trained during deep learning. As these parameters are trained, the function represented by the blocks get more and more accurate.\n", - "\n", - "When building these block functions, the easiest way is to use composition. Similar to how functions are built by calling other functions, blocks can be built by combining other blocks. We refer to the containing block as the parent and the sub-blocks as the children.\n", - "\n", - "\n", - "We provide several helpers to make it easy to build common block composition structures. For the MLP we will use the [SequentialBlock](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/nn/SequentialBlock.html), a container block whose children form a chain of blocks where each child block feeds its output to the next child block in a sequence.\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "SequentialBlock block = new SequentialBlock();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 4: Add blocks to SequentialBlock\n", - "\n", - "An MLP is organized into several layers. Each layer is composed of a [Linear Block](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/nn/core/Linear.html) and a non-linear activation function. 
If we just had two linear blocks in a row, it would be the same as a combined linear block ($f(x) = W_2(W_1x) = (W_2W_1)x = W_{combined}x$). An activation is used to intersperse between the linear blocks to allow them to represent non-linear functions. We will use the popular [ReLU](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/nn/Activation.html#reluBlock()) as our activation function.\n", - "\n", - "The first layer and last layers have fixed sizes depending on your desired input and output size. However, you are free to choose the number and sizes of the middle layers in the network. We will create a smaller MLP with two middle layers that gradually decrease the size. Typically, you would experiment with different values to see what works the best on your data set." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "block.add(Blocks.batchFlattenBlock(inputSize));\n", - "block.add(Linear.builder().setUnits(128).build());\n", - "block.add(Activation::relu);\n", - "block.add(Linear.builder().setUnits(64).build());\n", - "block.add(Activation::relu);\n", - "block.add(Linear.builder().setUnits(outputSize).build());\n", - "\n", - "block" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Now that you've successfully created your first neural network, you can use this network to train your model.\n", - "\n", - "Next chapter: [Train your first model](02_train_your_first_model.ipynb)\n", - "\n", - "You can find the complete source code for this tutorial in the [model zoo](https://github.com/deepjavalibrary/djl/blob/master/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/Mlp.java)." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/tutorial/02_train_your_first_model.ipynb b/jupyter/tutorial/02_train_your_first_model.ipynb deleted file mode 100644 index 4905dadfbb5..00000000000 --- a/jupyter/tutorial/02_train_your_first_model.ipynb +++ /dev/null @@ -1,243 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Train your first model\n", - "\n", - "This is the second of our [beginner tutorial series](https://github.com/deepjavalibrary/djl/tree/master/jupyter/tutorial) that will take you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to train an image classification model that can recognize handwritten digits.\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the [Jupyter README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Add the snapshot repository to get the DJL snapshot artifacts\n", - "%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "// Add the maven dependencies\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:basicdataset:0.23.0\n", - "%maven ai.djl:model-zoo:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.nio.file.*;\n", - "\n", - "import ai.djl.*;\n", - "import ai.djl.basicdataset.cv.classification.Mnist;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.training.*;\n", - "import ai.djl.training.dataset.*;\n", - "import ai.djl.training.initializer.*;\n", - "import ai.djl.training.loss.*;\n", - "import ai.djl.training.listener.*;\n", - "import ai.djl.training.evaluator.*;\n", - "import ai.djl.training.optimizer.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.basicmodelzoo.cv.classification.*;\n", - "import ai.djl.basicmodelzoo.basic.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 1: Prepare MNIST dataset for training\n", - "\n", - "In order to train, you must create a [Dataset class](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Dataset.html) to contain your training data. A dataset is a collection of sample input/output pairs for the function represented by your neural network. Each single input/output is represented by a [Record](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Record.html). Each record could have multiple arrays of inputs or outputs such as an image question and answer dataset where the input is both an image and a question about the image while the output is the answer to the question.\n", - "\n", - "Because data learning is highly parallelizable, training is often done not with a single record at a time, but a [Batch](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Batch.html). This can lead to significant performance gains, especially when working with images\n", - "\n", - "## Sampler\n", - "\n", - "Then, we must decide the parameters for loading data from the dataset. The only parameter we need for MNIST is the choice of [Sampler](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Sampler.html). The sampler decides which and how many element from datasets are part of each batch when iterating through it. We will have it randomly shuffle the elements for the batch and use a batchSize of 32. The batchSize is usually the largest power of 2 that fits within memory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int batchSize = 32;\n", - "Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();\n", - "mnist.prepare(new ProgressBar());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 2: Create your Model\n", - "\n", - "Next we will build a model. A [Model](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/Model.html) contains a neural network [Block](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/nn/Block.html) along with additional artifacts used for the training process. It possesses additional information about the inputs, outputs, shapes, and data types you will use. 
Generally, you will use the Model once you have fully completed your Block.\n", - "\n", - "In this part of the tutorial, we will use the built-in Multilayer Perceptron Block from the Model Zoo. To learn how to build it from scratch, see the previous tutorial: [Create Your First Network](01_create_your_first_network.ipynb).\n", - "\n", - "Because images in the MNIST dataset are 28x28 grayscale images, we will create an MLP block with 28 x 28 input. The output will be 10 because there are 10 possible classes (0 to 9) each image could be. For the hidden layers, we have chosen `new int[] {128, 64}` by experimenting with different values." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Model model = Model.newInstance(\"mlp\");\n", - "model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 3: Create a Trainer\n", - "\n", - "Now, you can create a [`Trainer`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/Trainer.html) to train your model. The trainer is the main class to orchestrate the training process. Usually, it will be opened using a try-with-resources statement and closed after training is over.\n", - "\n", - "The trainer takes an existing model and attempts to optimize the parameters inside the model's Block to best match the dataset. Most optimization is based upon [Stochastic Gradient Descent](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) (SGD).\n", - "\n", - "## Step 3.1: Set up your training configurations\n", - "\n", - "Before you create your trainer, you will need a [training configuration](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/DefaultTrainingConfig.html) that describes how to train your model.\n", - "\n", - "The following are a few common items you may need to configure your training:\n", - "\n", - "* **REQUIRED** [`Loss`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/loss/Loss.html) function: A loss function is used to measure how well our model matches the dataset. Because a lower value of the function is better, it's called the \"loss\" function. The Loss is the only required argument for the training configuration.\n", - "* [`Evaluator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/evaluator/Evaluator.html) function: An evaluator function is also used to measure how well our model matches the dataset. Unlike the loss, evaluators are only there for people to look at and are not used for optimizing the model. Since many losses are not as intuitive to interpret, adding other evaluators such as Accuracy can help you understand how your model is doing. If you know of any useful evaluators, we recommend adding them.\n", - "* [`Training Listeners`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/listener/TrainingListener.html): The training listener adds additional functionality to the training process through a listener interface. This can include showing training progress, stopping early if training becomes undefined, or recording performance metrics. We offer several easy sets of [default listeners](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/listener/TrainingListener.Defaults.html).\n", - "\n", - "You can also configure other options such as the Device, Initializer, and Optimizer. See [more details](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/TrainingConfig.html)."
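The Device, Initializer, and Optimizer mentioned above are left at their defaults in this tutorial. As a minimal sketch (not part of the original notebook; the fixed learning rate and the CPU-only device array are illustrative assumptions), a configuration that sets the optimizer and devices explicitly might look like this:

```java
import ai.djl.Device;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;

// Sketch only: same required Loss as the tutorial, plus an explicit SGD
// optimizer with a fixed learning rate and an explicit device selection.
DefaultTrainingConfig config =
        new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .optOptimizer(
                        Optimizer.sgd()
                                .setLearningRateTracker(Tracker.fixed(0.1f))
                                .build())
                .optDevices(new Device[] {Device.cpu()})
                .addTrainingListeners(TrainingListener.Defaults.logging());
```

The Initializer can be overridden in a similar way, but its exact builder overloads vary between DJL versions, so it is omitted from this sketch.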
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())\n", - " //softmaxCrossEntropyLoss is a standard loss for classification problems\n", - " .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is\n", - " .addTrainingListeners(TrainingListener.Defaults.logging());\n", - "\n", - "// Now that we have our training configuration, we should create a new trainer for our model\n", - "Trainer trainer = model.newTrainer(config);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 4: Initialize Training\n", - "\n", - "Before training your model, you have to initialize all of the parameters with starting values. You can use the trainer for this initialization by passing in the input shape.\n", - "\n", - "* The first axis of the input shape is the batch size. This won't impact the parameter initialization, so you can use 1 here.\n", - "* The second axis of the input shape of the MLP is the number of pixels in the input image." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.initialize(new Shape(1, 28 * 28));" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 5: Train your model\n", - "\n", - "Now, we can train the model.\n", - "\n", - "Training is usually organized into epochs, where each epoch trains the model on every item in the dataset once. This is slightly faster than sampling items from the dataset at random.\n", - "\n", - "We will use the EasyTrain class to, as the name promises, make training easy. If you want to see more details about how the training loop works, see [the EasyTrain class](https://github.com/deepjavalibrary/djl/blob/master/api/src/main/java/ai/djl/training/EasyTrain.java) or [read our Dive into Deep Learning book](https://d2l.djl.ai)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Deep learning is typically trained in epochs where each epoch trains the model on each item in the dataset once.\n", - "int epoch = 2;\n", - "\n", - "EasyTrain.fit(trainer, epoch, mnist, null);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 6: Save your model\n", - "\n", - "Once your model is trained, you should save it so that it can be reloaded later. You can also add metadata to it, such as the training accuracy and the number of epochs trained, that can be used when loading the model or when examining it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Path modelDir = Paths.get(\"build/mlp\");\n", - "Files.createDirectories(modelDir);\n", - "\n", - "model.setProperty(\"Epoch\", String.valueOf(epoch));\n", - "\n", - "model.save(modelDir, \"mlp\");\n", - "\n", - "model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Summary\n", - "\n", - "Now, you've successfully trained a model that can recognize handwritten digits. You'll learn how to apply this model in the next chapter: [Run image classification with your model](03_image_classification_with_your_model.ipynb).\n", - "\n", - "You can find the complete source code for this tutorial in the [examples project](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/TrainMnist.java)."
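One thing the training step above does not show is validation: `EasyTrain.fit` is called with `null` as its last argument. As a rough sketch (not from the original tutorial; it reuses the notebook's `batchSize`, `epoch`, `mnist`, and `trainer` variables and assumes the MNIST test split is exposed through the same builder), a validation dataset could be passed like this:

```java
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.training.EasyTrain;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.util.ProgressBar;

// Sketch only: build a second Mnist dataset from the TEST split and hand it to
// EasyTrain.fit so each epoch also reports validation metrics.
Mnist validation = Mnist.builder()
        .optUsage(Dataset.Usage.TEST)
        .setSampling(batchSize, true)
        .build();
validation.prepare(new ProgressBar());

EasyTrain.fit(trainer, epoch, mnist, validation);
System.out.println("Validation accuracy: "
        + trainer.getTrainingResult().getValidateEvaluation("Accuracy"));
```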
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/tutorial/03_image_classification_with_your_model.ipynb b/jupyter/tutorial/03_image_classification_with_your_model.ipynb deleted file mode 100644 index f8d42d7972e..00000000000 --- a/jupyter/tutorial/03_image_classification_with_your_model.ipynb +++ /dev/null @@ -1,214 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Inference with your model\n", - "\n", - "This is the third and final tutorial of our [beginner tutorial series](https://github.com/deepjavalibrary/djl/tree/master/jupyter/tutorial) that will take you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to execute your image classification model for a production system.\n", - "\n", - "In the [previous tutorial](02_train_your_first_model.ipynb), you successfully trained your model. Now, we will learn how to implement a `Translator` to convert between POJO and `NDArray` as well as a `Predictor` to run inference.\n", - "\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the [Jupyter README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Add the snapshot repository to get the DJL snapshot artifacts\n", - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "// Add the maven dependencies\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:model-zoo:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.awt.image.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;\n", - "import java.util.stream.*;\n", - "import ai.djl.*;\n", - "import ai.djl.basicmodelzoo.basic.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.util.NDImageUtils;\n", - "import ai.djl.translate.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1: Load your handwritten digit image\n", - "\n", - "We will start by loading the image that we want to run our model to classify." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/0.png\");\n", - "img.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2: Load your model\n", - "\n", - "Next, we need to load the model to run inference with. This model should have been saved to the `build/mlp` directory when running the [previous tutorial](02_train_your_first_model.ipynb)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Path modelDir = Paths.get(\"build/mlp\");\n", - "Model model = Model.newInstance(\"mlp\");\n", - "model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));\n", - "model.load(modelDir);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In addition to loading a local model, you can also find pretrained models within our [model zoo](http://docs.djl.ai/docs/model-zoo.html). See more options in our [model loading documentation](http://docs.djl.ai/docs/load_model.html).\n", - "\n", - "## Step 3: Create a `Translator`\n", - "\n", - "The [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) is used to encapsulate the pre-processing and post-processing functionality of your application. The input to processInput and the output of processOutput should be single data items, not batches." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Translator<Image, Classifications> translator = new Translator<Image, Classifications>() {\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, Image input) {\n", - " // Convert Image to NDArray\n", - " NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);\n", - " return new NDList(NDImageUtils.toTensor(array));\n", - " }\n", - "\n", - " @Override\n", - " public Classifications processOutput(TranslatorContext ctx, NDList list) {\n", - " // Create a Classifications with the output probabilities\n", - " NDArray probabilities = list.singletonOrThrow().softmax(0);\n", - " List<String> classNames = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());\n", - " return new Classifications(classNames, probabilities);\n", - " }\n", - " \n", - " @Override\n", - " public Batchifier getBatchifier() {\n", - " // The Batchifier describes how to combine a batch together\n", - " // Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array\n", - " return Batchifier.STACK;\n", - " }\n", - "};" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 4: Create Predictor\n", - "\n", - "Using the translator, we will create a new [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html). The predictor is the main class to orchestrate the inference process. During inference, a trained model is used to predict values, often for production use cases. The predictor is NOT thread-safe, so if you want to do prediction in parallel, you should call newPredictor multiple times to create a predictor object for each thread." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var predictor = model.newPredictor(translator);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 5: Run inference\n", - "\n", - "With our predictor, we can simply call the [predict](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html#predict(I)) method to run inference. For better performance, you can also call [batchPredict](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html#batchPredict(java.util.List)) with a list of input items. Afterwards, the same predictor should be used for further inference calls.
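The paragraph above points at `batchPredict` but the notebook never calls it. As a minimal sketch (not part of the original tutorial; it simply reuses the `img` and `predictor` variables defined in the steps above), a batched call looks like this:

```java
import java.util.Arrays;
import java.util.List;

import ai.djl.modality.Classifications;

// Sketch only: the predictor groups the inputs into one batch via the
// Translator's Batchifier, which is usually faster than looping over predict.
List<Classifications> results = predictor.batchPredict(Arrays.asList(img, img));
System.out.println(results.get(0).best());
```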
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var classifications = predictor.predict(img);\n", - "\n", - "classifications" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Now, you've successfully built a model, trained it, and run inference. Congratulations on finishing the [beginner tutorial series](https://github.com/deepjavalibrary/djl/tree/master/jupyter/tutorial). After this, you should read our other [examples](https://github.com/deepjavalibrary/djl/tree/master/examples) and [jupyter notebooks](https://github.com/deepjavalibrary/djl/tree/master/jupyter) to learn more about DJL.\n", - "\n", - "You can find the complete source code for this tutorial in the [examples project](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java)." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/tutorial/README.md b/jupyter/tutorial/README.md deleted file mode 100644 index 4c53b0f41e8..00000000000 --- a/jupyter/tutorial/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# DJL - Beginner Tutorial - -Our beginner tutorial takes you through creating your first network, training it, and using it in a real system. This is a good place to start if you are new to DJL or to deep learning. - -1. [Create your first neural network](01_create_your_first_network.ipynb) -2. [Train your first model](02_train_your_first_model.ipynb) -3. 
[Run image classification with your first model](03_image_classification_with_your_model.ipynb) diff --git a/model-zoo/README.md b/model-zoo/README.md index 11ae15c5505..8dbf702eae5 100644 --- a/model-zoo/README.md +++ b/model-zoo/README.md @@ -33,7 +33,7 @@ You can pull the model zoo from the central Maven repository by including the fo <dependency> <groupId>ai.djl</groupId> <artifactId>model-zoo</artifactId> - <version>0.23.0</version> + <version>0.27.0</version> </dependency> ``` @@ -61,7 +61,7 @@ The following is an example of the criteria to find a Resnet50-v1 model that has .optFilter("dataset", "cifar10") .build(); - ZooModel<Image, Classifications> ssd = criteria.loadModel()); + ZooModel<Image, Classifications> ssd = criteria.loadModel(); ``` If you already know which `ModelLoader` to use, you can simply do the following: diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java index 7cdfc040c12..543ab5f1f21 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java @@ -43,8 +43,8 @@ public String getGroupId() { public Set<String> getSupportedEngines() { Set<String> set = new HashSet<>(); set.add("MXNet"); + set.add("PyTorch"); // TODO Currently WIP in supporting these two engines in the basic model zoo - // set.add("PyTorch"); // set.add("TensorFlow"); return set; } diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/Mlp.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/Mlp.java index 05f171c1ec2..2869e42f55d 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/Mlp.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/Mlp.java @@ -56,6 +56,7 @@ public Mlp(int input, int output, int[] hidden) { * @param hidden the sizes of all of the hidden layers * @param activation the activation function to use */ + @SuppressWarnings("this-escape") public Mlp(int input, int output, int[] hidden, Function<NDList, NDList> activation) { add(Blocks.batchFlattenBlock(input)); for (int hiddenSize : hidden) { diff --git a/settings.gradle b/settings.gradle index 75a1f854ef8..ff6967fc308 100644 --- a/settings.gradle +++ b/settings.gradle @@ -2,6 +2,7 @@ rootProject.name = 'djl' include ':api' include ':basicdataset' include ':djl-zero' +include ':engines:llama' include ':engines:ml:xgboost' include ':engines:ml:lightgbm' include ':engines:mxnet:jnarator' @@ -34,7 +35,9 @@ include ':extensions:sentencepiece' include ':extensions:tokenizers' include ':extensions:tablesaw' include ':extensions:timeseries' -include ':extensions:spark' +if (JavaVersion.current() < JavaVersion.VERSION_21) { + include ':extensions:spark' +} include ':integration' include ':model-zoo' include ':testing' diff --git a/testing/src/main/java/ai/djl/testing/TestRequirements.java b/testing/src/main/java/ai/djl/testing/TestRequirements.java index bf57d64bd7c..32f242589b9 100644 --- a/testing/src/main/java/ai/djl/testing/TestRequirements.java +++ b/testing/src/main/java/ai/djl/testing/TestRequirements.java @@ -13,6 +13,7 @@ package ai.djl.testing; import ai.djl.engine.Engine; +import ai.djl.util.Utils; import org.testng.SkipException; @@ -45,7 +46,7 @@ public static void weekly() { /** Requires a test not be run in offline mode.
*/ public static void notOffline() { - if (Boolean.getBoolean("offline")) { + if (Utils.isOfflineMode()) { throw new SkipException("This test can not run while offline"); } } diff --git a/tools/conf/checkstyle.xml b/tools/conf/checkstyle.xml index 4dd8c7bb0d5..b1473792b6c 100644 --- a/tools/conf/checkstyle.xml +++ b/tools/conf/checkstyle.xml @@ -156,7 +156,6 @@ value="Override, Test, Before, After, BeforeClass, AfterClass"/> --> - - - + + diff --git a/tools/gradle/check.gradle b/tools/gradle/check.gradle index 4bf78c6b336..c0f2e013bb9 100644 --- a/tools/gradle/check.gradle +++ b/tools/gradle/check.gradle @@ -35,7 +35,7 @@ tasks.withType(Pmd) { apply plugin: "checkstyle" checkstyle { - toolVersion = "8.26" + toolVersion = "10.14.2" ignoreFailures = false checkstyleTest.enabled = true configProperties = [ diff --git a/tools/gradle/publish.gradle b/tools/gradle/publish.gradle index 663f847e95b..0baa3d5a2c1 100644 --- a/tools/gradle/publish.gradle +++ b/tools/gradle/publish.gradle @@ -1,7 +1,8 @@ -configure([ +def projects = [ project(':api'), project(':basicdataset'), project(':djl-zero'), + project(':engines:llama'), project(':engines:ml:xgboost'), project(':engines:ml:lightgbm'), project(':engines:mxnet:mxnet-engine'), @@ -27,8 +28,13 @@ configure([ project(':extensions:tablesaw'), project(':extensions:timeseries'), project(':extensions:tokenizers'), - project(':extensions:spark'), - project(':model-zoo')]) { + project(':model-zoo') +] +if (JavaVersion.current() < JavaVersion.VERSION_21) { + projects.add(project(':extensions:spark')) +} + +configure(projects) { apply plugin: "maven-publish" apply plugin: "signing" diff --git a/tools/scripts/build_ft_deps.sh b/tools/scripts/build_ft_deps.sh deleted file mode 100755 index 4d3cb94a103..00000000000 --- a/tools/scripts/build_ft_deps.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env bash - -set -ex - -FT_VERSION=$1 -NVIDIA_TRITON_SERVER_VERSION=$2 -IS_LLAMA_BUILD=$3 - -apt-get update && apt-get install -y rapidjson-dev - -pushd /tmp - -git clone https://github.com/NVIDIA/FasterTransformer.git -b ${FT_VERSION} - -export FT_DIR=/tmp/FasterTransformer -mkdir -p /tmp/binaries - -# Build FasterTransformer Triton library -if [ "$IS_LLAMA_BUILD" = "false" ] ; then - git clone https://github.com/triton-inference-server/fastertransformer_backend.git -else - echo "cloning forked FT backend repo with llama support" - git clone https://github.com/rohithkrn/fastertransformer_backend.git -b llama_void_main -fi -mkdir -p fastertransformer_backend/build -cd fastertransformer_backend/build -cmake \ - -D CMAKE_EXPORT_COMPILE_COMMANDS=1 \ - -D CMAKE_BUILD_TYPE=Release \ - -D ENABLE_FP8=OFF \ - -D CMAKE_INSTALL_PREFIX=/opt/tritonserver \ - -D TRITON_COMMON_REPO_TAG="${NVIDIA_TRITON_SERVER_VERSION}" \ - -D TRITON_CORE_REPO_TAG="${NVIDIA_TRITON_SERVER_VERSION}" \ - -D TRITON_BACKEND_REPO_TAG="${NVIDIA_TRITON_SERVER_VERSION}" \ - .. -make -j$(nproc) install -cp /opt/tritonserver/backends/fastertransformer/*.so /tmp/binaries/ -cd ../../ - -# Build FasterTransformer TH Ops library -mkdir -p FasterTransformer/build -cd FasterTransformer/build -git submodule init && git submodule update -cmake -DCMAKE_BUILD_TYPE=Release -DSM=70,75,80,86 -DBUILD_PYT=ON -DBUILD_MULTI_GPU=ON .. 
-make -j$(nproc) -cp lib/libth_transformer.so /tmp/binaries/ -cd ../../ - -popd diff --git a/website/js/index.js b/website/js/index.js index 70b313acea3..605e1c04228 100644 --- a/website/js/index.js +++ b/website/js/index.js @@ -27,7 +27,7 @@ let app = new Vue({ }, { name: 'Tutorial', - url: 'https://docs.djl.ai/jupyter/tutorial/index.html' + url: 'https://docs.djl.ai/docs/demos/jupyter/tutorial/index.html' }, { name: 'Examples',

+ The Llama Engine module contains the Llama.cpp implementation of the DJL EngineProvider. + See here for more details. +
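The fragment above introduces the new `engines/llama` module that `settings.gradle` and `tools/gradle/publish.gradle` now register. As a hypothetical sketch (the engine name "Llama" is an assumption; this diff does not show how the EngineProvider names itself), downstream code could probe for the engine at runtime before relying on it:

```java
import ai.djl.engine.Engine;

// Sketch only: check whether a llama.cpp-backed engine is on the classpath.
// The name "Llama" is assumed for illustration and may differ in practice.
if (Engine.getAllEngines().contains("Llama")) {
    Engine llama = Engine.getEngine("Llama");
    System.out.println("Llama engine loaded, version " + llama.getVersion());
} else {
    System.out.println("Llama engine not found; add the :engines:llama module as a dependency.");
}
```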