From 1ea715db3ad05a082c786bf3b8e40453c83e45e9 Mon Sep 17 00:00:00 2001 From: mariecwhite Date: Thu, 5 Oct 2023 22:36:01 +0000 Subject: [PATCH 1/2] Generate TFLite files from Tensorflow models --- .../comparative_suite/tf/model_definitions.py | 5 +++ .../tf/scripts/generate_model_artifacts.py | 43 +++++++++++++++++++ .../openxla/benchmark/def_types.py | 3 ++ 3 files changed, 51 insertions(+) diff --git a/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py b/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py index 0f7327a7..5c048aea 100644 --- a/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py +++ b/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py @@ -37,6 +37,7 @@ def_types.ModelArtifactType.STABLEHLO_MLIR, def_types.ModelArtifactType.XLA_HLO_DUMP, def_types.ModelArtifactType.TF_SAVEDMODEL_V2, + def_types.ModelArtifactType.TFLITE_FP32, ], ) T5_LARGE_FP32_TF_512XI32_BATCHES = utils.build_batch_models( @@ -69,6 +70,7 @@ def_types.ModelArtifactType.STABLEHLO_MLIR, def_types.ModelArtifactType.XLA_HLO_DUMP, def_types.ModelArtifactType.TF_SAVEDMODEL_V2, + def_types.ModelArtifactType.TFLITE_FP32, ], ) BERT_LARGE_FP32_TF_384XI32_BATCHES = utils.build_batch_models( @@ -100,6 +102,8 @@ def_types.ModelArtifactType.STABLEHLO_MLIR, def_types.ModelArtifactType.XLA_HLO_DUMP, def_types.ModelArtifactType.TF_SAVEDMODEL_V2, + def_types.ModelArtifactType.TFLITE_FP32, + def_types.ModelArtifactType.TFLITE_INT8, ], ) RESNET50_FP32_TF_224X224X3XF32_BATCHES = utils.build_batch_models( @@ -130,6 +134,7 @@ def_types.ModelArtifactType.STABLEHLO_MLIR, def_types.ModelArtifactType.XLA_HLO_DUMP, def_types.ModelArtifactType.TF_SAVEDMODEL_V2, + def_types.ModelArtifactType.TFLITE_FP32, ], ) diff --git a/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/scripts/generate_model_artifacts.py 
def _generate_tflite(inputs: Tuple[Any, ...], model_dir: pathlib.Path,
                     saved_model_dir: pathlib.Path):
    """Converts a SavedModel into fp32 and int8 TFLite flatbuffers.

    Both conversions are best-effort: a failure is reported on stdout and is
    not propagated, so the int8 conversion is still attempted after an fp32
    failure and the caller keeps running either way.

    Args:
      inputs: Example model inputs. Only metadata is read from each element
        (`dtype.min`, `dtype.max`, `dtype.as_numpy_dtype` and `shape`) in order
        to synthesize random calibration data for quantization.
      model_dir: Directory that receives `model_fp32.tflite` and
        `model_int8.tflite`.
      saved_model_dir: Directory holding the SavedModel to convert.
    """
    converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))

    # Generate fp32 model.
    try:
        fp32_flatbuffer = converter.convert()
        with open(model_dir.joinpath("model_fp32.tflite"), "wb") as f:
            f.write(fp32_flatbuffer)
    except Exception as e:
        # Deliberately broad: conversion is optional for the artifact set.
        print(f"Failed to generate fp32 TFLite model. Exception: {e}")

    def representative_examples():
        # Two calibration batches of uniformly random values spanning each
        # input's full dtype range, cast back to the input's numpy dtype.
        for _ in range(2):
            yield [
                np.random.uniform(low=spec.dtype.min,
                                  high=spec.dtype.max,
                                  size=spec.shape).astype(
                                      spec.dtype.as_numpy_dtype)
                for spec in inputs
            ]

    # Generate int8 model. The same converter instance is reused, now
    # configured for post-training quantization.
    try:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
        converter.representative_dataset = representative_examples
        # NOTE(review): the TF2 converter documents `inference_input_type` /
        # `inference_output_type`; confirm `inference_type` has the intended
        # effect here.
        converter.inference_type = tf.int8
        int8_flatbuffer = converter.convert()
        with open(model_dir.joinpath("model_int8.tflite"), "wb") as f:
            f.write(int8_flatbuffer)
    except Exception as e:
        # Deliberately broad: see the fp32 branch above.
        print(f"Failed to generate int8 TFLite model. Exception: {e}")
experimental/tflite/run_benchmarks_android.py | 173 ++++++++++++++++++ .../tflite/set_android_scaling_governor.sh | 51 ++++++ experimental/tflite/setup_venv.sh | 32 ++++ 12 files changed, 880 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/run_tflite_benchmark.yml create mode 100755 experimental/iree/benchmark_iree.sh create mode 100755 experimental/iree/set_android_scaling_governor.sh create mode 100644 experimental/tflite/benchmark_lib.py create mode 100755 experimental/tflite/benchmark_tflite.sh create mode 100644 experimental/tflite/requirements.txt create mode 100755 experimental/tflite/run_benchmarks.py create mode 100755 experimental/tflite/run_benchmarks_android.py create mode 100755 experimental/tflite/set_android_scaling_governor.sh create mode 100644 experimental/tflite/setup_venv.sh diff --git a/.github/workflows/run_tflite_benchmark.yml b/.github/workflows/run_tflite_benchmark.yml new file mode 100644 index 00000000..46839d08 --- /dev/null +++ b/.github/workflows/run_tflite_benchmark.yml @@ -0,0 +1,123 @@ +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# TFLite Benchmarks Workflow. + +name: TFLite Benchmarks + +on: + workflow_dispatch: + pull_request: + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). 
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +env: + GCS_DIR: gs://openxla-github-actions-${{ github.event_name == 'pull_request' && 'presubmit' || 'postsubmit' }}-artifacts/${{ github.run_id }}/${{ github.run_attempt }} + +jobs: + setup: + runs-on: ubuntu-22.04 + outputs: + runner-group: ${{ steps.configure.outputs.runner-group }} + benchmark-gcs-dir: ${{ steps.configure.outputs.benchmark-gcs-dir }} + steps: + - name: "Checking out PR repository" + uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - name: "Configuring CI options" + id: configure + env: + RUNNER_GROUP: ${{ github.event_name == 'pull_request' && 'presubmit' || 'postsubmit' }} + run: | + # Just informative logging. There should only be two commits in the + # history here, but limiting the depth helps when copying from a local + # repo instead of using checkout, e.g. with + # https://github.com/nektos/act where there will be more. + git log --oneline --graph --max-count=3 + # Workflow jobs can't access `env` in `runs-on`, so we need to make + # `runner-group` a job output variable. + echo "runner-group=${RUNNER_GROUP}" > "${GITHUB_OUTPUT}" + + # For presubmit testing, the result artifacts are uploaded to the + # temporary workflow GCS dir. In postsubmit, the result artifacts are + # uploaded to the comparative benchmark GCS dir. 
+ if [[ "${RUNNER_GROUP}" == "presubmit" ]]; then + BENCHMARK_GCS_DIR="${GCS_DIR}/comparative-benchmark-artifacts" + else + BENCHMARK_GCS_DIR="gs://comparative-benchmark-artifacts/$(date +'%Y-%m-%d').$(date +'%s')" + fi + echo "benchmark-gcs-dir=${BENCHMARK_GCS_DIR}" >> "${GITHUB_OUTPUT}" + + benchmark_on_c2-standard-16: + needs: [setup] + runs-on: + - self-hosted # must come first + - runner-group=${{ needs.setup.outputs.runner-group }} + - environment=prod + - machine-type=c2-standard-16 + env: + BENCHMARK_GCS_DIR: ${{ needs.setup.outputs.benchmark-gcs-dir }} + RESULTS_DIR: results-dir + TARGET_DEVICE: c2-standard-16 + TFLITE_TOOL_DIR: tool-dir + steps: + - name: "Checking out PR repository" + uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - name: "Setup" + id: setup + run: | + echo "results-gcs-dir=${BENCHMARK_GCS_DIR}/${TARGET_DEVICE}-results" >> "${GITHUB_OUTPUT}" + mkdir "${RESULTS_DIR}" + mkdir "${TFLITE_TOOL_DIR}" + - name: "Benchmarking TFLite" + env: + TFLITE_RESULTS_JSON: tflite.json + RESULTS_GCS_DIR: ${{ steps.setup.outputs.results-gcs-dir }} + run: | + RESULTS_PATH="${RESULTS_DIR}/${TFLITE_RESULTS_JSON}" + docker run --mount="type=bind,src="${PWD}",target=/work" --workdir="/work" \ + "gcr.io/iree-oss/openxla-benchmark/base@sha256:1bf3e319465ec8fb465baae3f6ba9a5b09cb84a5349a675c671a552fc77f2251" \ + ./experimental/tflite/benchmark_tflite.sh \ + "${TARGET_DEVICE}" \ + "${TFLITE_TOOL_DIR}" \ + "${RESULTS_PATH}" + gcloud storage cp "${RESULTS_PATH}" "${RESULTS_GCS_DIR}/" + + benchmark_on_pixel-6-pro: + needs: [setup] + runs-on: + - self-hosted # must come first + - runner-group=${{ needs.setup.outputs.runner-group }} + - environment=prod + - machine-type=pixel-6-pro + env: + BENCHMARK_GCS_DIR: ${{ needs.setup.outputs.benchmark-gcs-dir }} + RESULTS_DIR: results-dir + TARGET_DEVICE: pixel-6-pro + TFLITE_TOOL_DIR: tool-dir + steps: + - name: "Checking out PR repository" + uses: 
actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - name: "Setup" + id: setup + run: | + echo "results-gcs-dir=${BENCHMARK_GCS_DIR}/${TARGET_DEVICE}-results" >> "${GITHUB_OUTPUT}" + mkdir "${RESULTS_DIR}" + mkdir "${TFLITE_TOOL_DIR}" + - name: "Benchmarking TFLite" + env: + TFLITE_RESULTS_JSON: tflite.json + RESULTS_GCS_DIR: ${{ steps.setup.outputs.results-gcs-dir }} + run: | + ./experimental/iree/benchmark_iree.sh + + #RESULTS_PATH="${RESULTS_DIR}/${TFLITE_RESULTS_JSON}" + #./experimental/tflite/benchmark_tflite.sh "${TARGET_DEVICE}" "${TFLITE_TOOL_DIR}" "${RESULTS_PATH}" + #gcloud storage cp "${RESULTS_PATH}" "${RESULTS_GCS_DIR}/" diff --git a/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py b/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py index 5c048aea..79737c32 100644 --- a/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py +++ b/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py @@ -10,7 +10,7 @@ from openxla.benchmark import def_types from openxla.benchmark.comparative_suite import utils -PARENT_GCS_DIR = "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/" +PARENT_GCS_DIR = "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230829_1696537918/" ARTIFACTS_DIR_URL_TEMPLATE = string.Template(PARENT_GCS_DIR + "${name}") # T5-Large models. diff --git a/experimental/ggml/benchmark_ggml.sh b/experimental/ggml/benchmark_ggml.sh index de2778b9..5cdcf44d 100755 --- a/experimental/ggml/benchmark_ggml.sh +++ b/experimental/ggml/benchmark_ggml.sh @@ -88,8 +88,8 @@ if [[ "${TARGET_DEVICE_NAME}" =~ ^(pixel-4|pixel-6-pro|moto-edge-x30)$ ]]; then adb shell "su root sh /data/local/tmp/set_android_scaling_governor.sh performance" else BENCHMARK_SCRIPT="run_benchmarks.py" - # c2-standard-16 has 16 cores. 
- THREADS="1,8,16" + # c2-standard-16 has 8 cores. + THREADS="1,8" args+=( --threads "${THREADS}" diff --git a/experimental/iree/benchmark_iree.sh b/experimental/iree/benchmark_iree.sh new file mode 100755 index 00000000..a843f122 --- /dev/null +++ b/experimental/iree/benchmark_iree.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +# This is a temporary hack to run IREE benchmarks on pixel-6-pro since +# it's currently not working in the IREE repo. + +ROOT_DIR=/tmp/iree-benchmarks +TD="$(cd $(dirname $0) && pwd)" + +rm -rf "${ROOT_DIR}" +mkdir "${ROOT_DIR}" +pushd "${ROOT_DIR}" + +# Download benchmark tool. +gsutil cp "gs://iree-github-actions-presubmit-artifacts/6464567954/1/benchmark-tools/android-armv8.2-a-benchmark-tools.tar" . +tar -xf "android-armv8.2-a-benchmark-tools.tar" +adb push "android-armv8.2-a-benchmark-tools-dir/build/tools/iree-benchmark-module" "/data/local/tmp" +adb shell "chmod +x /data/local/tmp/iree-benchmark-module" + +# Download vmfb's. + + +# Setup environment. +adb push "${TD}/set_android_scaling_governor.sh" "/data/local/tmp" +adb shell "chmod +x /data/local/tmp/set_android_scaling_governor.sh" +adb shell "su root sh /data/local/tmp/set_android_scaling_governor.sh performance" + +# Benchmark. 
+ITERATIONS=10 +gsutil cp "gs://iree-github-actions-presubmit-artifacts/6464567954/1/e2e-test-artifacts/iree_module_BertLarge_Fp32_Batch1_tflite___armv8.2-a-generic-linux_android29-llvm_cpu__experimental-flags_data-tiling_ukernel_/module.vmfb" "BertLarge_Batch1.vmfb" +adb push "BertLarge_Batch1.vmfb" "/data/local/tmp" +adb shell "taskset f0 /data/local/tmp/iree-benchmark-module --function=main --input=1x384xi32=0 --input=1x384xi32=0 --device_allocator=caching --task_topology_group_count=4 --device=local-task --module=/data/local/tmp/BertLarge_Batch1.vmfb --time_unit=ns --benchmark_format=json --benchmark_out_format=json --print_statistics=true --benchmark_repetitions=${ITERATIONS}" +adb shell "rm /data/local/tmp/BertLarge_Batch1.vmfb" +rm "BertLarge_Batch1.vmfb" + +gsutil cp "gs://iree-github-actions-presubmit-artifacts/6464567954/1/e2e-test-artifacts/iree_module_BertLarge_Fp32_Batch16_tflite___armv8.2-a-generic-linux_android29-llvm_cpu__experimental-flags_data-tiling_ukernel_/module.vmfb" "BertLarge_Batch16.vmfb" +adb push "BertLarge_Batch16.vmfb" "/data/local/tmp" +adb shell "taskset f0 /data/local/tmp/iree-benchmark-module --function=main --input=16x384xi32=0 --input=16x384xi32=0 --device_allocator=caching --task_topology_group_count=4 --device=local-task --module=/data/local/tmp/BertLarge_Batch16.vmfb --time_unit=ns --benchmark_format=json --benchmark_out_format=json --print_statistics=true --benchmark_repetitions=${ITERATIONS}" +adb shell "rm /data/local/tmp/BertLarge_Batch16.vmfb" +rm "BertLarge_Batch16.vmfb" + +gsutil cp "gs://iree-github-actions-presubmit-artifacts/6464567954/1/e2e-test-artifacts/iree_module_BertLarge_Fp32_Batch24_tflite___armv8.2-a-generic-linux_android29-llvm_cpu__experimental-flags_data-tiling_ukernel_/module.vmfb" "BertLarge_Batch24.vmfb" +adb push "BertLarge_Batch24.vmfb" "/data/local/tmp" +adb shell "taskset f0 /data/local/tmp/iree-benchmark-module --function=main --input=24x384xi32=0 --input=24x384xi32=0 --device_allocator=caching 
--task_topology_group_count=4 --device=local-task --module=/data/local/tmp/BertLarge_Batch24.vmfb --time_unit=ns --benchmark_format=json --benchmark_out_format=json --print_statistics=true --benchmark_repetitions=${ITERATIONS}" +adb shell "rm /data/local/tmp/BertLarge_Batch24.vmfb" +rm "BertLarge_Batch24.vmfb" + +gsutil cp "gs://iree-github-actions-presubmit-artifacts/6464567954/1/e2e-test-artifacts/iree_module_BertLarge_Fp32_Batch32_tflite___armv8.2-a-generic-linux_android29-llvm_cpu__experimental-flags_data-tiling_ukernel_/module.vmfb" "BertLarge_Batch32.vmfb" +adb push "BertLarge_Batch32.vmfb" "/data/local/tmp" +adb shell "taskset f0 /data/local/tmp/iree-benchmark-module --function=main --input=32x384xi32=0 --input=32x384xi32=0 --device_allocator=caching --task_topology_group_count=4 --device=local-task --module=/data/local/tmp/BertLarge_Batch32.vmfb --time_unit=ns --benchmark_format=json --benchmark_out_format=json --print_statistics=true --benchmark_repetitions=${ITERATIONS}" +adb shell "rm /data/local/tmp/BertLarge_Batch32.vmfb" +rm "BertLarge_Batch32.vmfb" + +adb shell "rm -rf /data/local/tmp/*" + +popd +rm -rf "${ROOT_DIR}" + + + + + diff --git a/experimental/iree/set_android_scaling_governor.sh b/experimental/iree/set_android_scaling_governor.sh new file mode 100755 index 00000000..9f51e273 --- /dev/null +++ b/experimental/iree/set_android_scaling_governor.sh @@ -0,0 +1,51 @@ +#!/bin/sh + +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Runs on an android device itself to set the frequency scaling governor for all +# CPUs (default performance). + +################################### WARNING #################################### +# This will overheat the phone if it's not on a cooling plate, resulting in # +# thermal throttling. 
To prevent anything catching on fire, the actual CPU # +# frequencies will be throttled to below the maximum, skewing your results. # +################################################################################ + +set -euo pipefail + +GOVERNOR="${1:-performance}" + +echo "CPU info (before changing governor):" +echo 'cpu\tgovernor\tcur\tmin\tmax' +echo "------------------------------------------------" +for i in `cat /sys/devices/system/cpu/present | tr '-' ' ' | xargs seq`; do \ + echo "cpu${i}" | paste \ + - \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/scaling_governor" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_cur_freq" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_min_freq" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_max_freq"; \ +done + +echo "Setting CPU frequency governor to ${GOVERNOR}" + +for i in `cat /sys/devices/system/cpu/present | tr '-' ' ' | xargs seq`; do \ + echo "${GOVERNOR}" > \ + "/sys/devices/system/cpu/cpu${i?}/cpufreq/scaling_governor"; \ +done + +echo "CPU info (after changing governor):" +echo 'cpu\tgovernor\tcur\tmin\tmax' +echo "------------------------------------------------" +for i in `cat /sys/devices/system/cpu/present | tr '-' ' ' | xargs seq`; do \ + echo "cpu${i}" | paste \ + - \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/scaling_governor" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_cur_freq" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_min_freq" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_max_freq"; \ +done diff --git a/experimental/tflite/benchmark_lib.py b/experimental/tflite/benchmark_lib.py new file mode 100644 index 00000000..48bd217f --- /dev/null +++ b/experimental/tflite/benchmark_lib.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import dataclasses
import json
import pathlib
import re
import subprocess
import sys
from typing import Sequence, List

# Add common_benchmark_suite dir to the search path.
sys.path.insert(
    0, str(pathlib.Path(__file__).parents[2] / "common_benchmark_suite"))
from openxla.benchmark import def_types, devices
from openxla.benchmark.comparative_suite.tf import benchmark_definitions as tf_benchmark_definitions

# Add comparative_benchmark dir to the search path.
sys.path.insert(
    0, str(pathlib.Path(__file__).parents[2] / "comparative_benchmark"))
import utils

ALL_DEVICE_NAMES = [device.name for device in devices.ALL_DEVICES]
TFLITE_FP32_FILENAME = "model_fp32.tflite"

# Patterns matched against the benchmark tool's combined stdout/stderr. Raw
# strings keep `\d` and `\(` as regex escapes rather than (invalid) string
# escapes; the compiled patterns are unchanged.
MIN_MAX_LATENCY_REGEXP = re.compile(
    r"INFO: count=\d+ first=\d+ curr=\d+ min=(.*) max=(.*) avg=(.*) std=(.*)")
AVG_LATENCY_REGEXP = re.compile(
    r"INFO: Inference timings in us: .* Inference \(avg\): (.*)")
PEAK_MEMORY_REGEXP = re.compile(
    r"INFO: Overall peak memory footprint \(MB\) via periodic monitoring: (.*)")


def download_artifacts(benchmarks: Sequence[def_types.BenchmarkCase],
                       root_dir: pathlib.Path,
                       verbose: bool = False):
    """Downloads the fp32 TFLite model of every benchmark.

    Each model is stored at `<root_dir>/<model name>/model_fp32.tflite`.

    Args:
      benchmarks: Benchmark cases whose models should be fetched.
      root_dir: Local directory under which the per-model directories live.
      verbose: Forwarded to `utils.download_files`.

    Raises:
      ValueError: If a model has no artifact URL or does not list
        `TFLITE_FP32` among its exported model types.
    """
    download_list = []
    for benchmark in benchmarks:
        model = benchmark.model
        has_tflite = (def_types.ModelArtifactType.TFLITE_FP32
                      in model.exported_model_types)
        if model.artifacts_dir_url is None or not has_tflite:
            raise ValueError(
                f"TFLite FP32 model isn't provided by '{model.name}'.")
        model_url = model.artifacts_dir_url + "/" + TFLITE_FP32_FILENAME
        model_path = root_dir / model.name / TFLITE_FP32_FILENAME
        download_list.append((model_url, model_path))

    utils.download_files(download_list, verbose=verbose)


def configure_parser(parser: argparse.ArgumentParser):
    """Registers the command-line options used by the TFLite runners.

    Args:
      parser: Parser to extend in place. Callers may add their own options
        before or after calling this function.
    """
    parser.add_argument("-o",
                        "--output",
                        type=pathlib.Path,
                        required=True,
                        help="JSON file path to merge the results.")
    parser.add_argument("-name",
                        "--benchmark_name",
                        required=True,
                        help="The unique id that defines a benchmark.")
    parser.add_argument("-device",
                        "--target_device",
                        dest="target_device_name",
                        type=str,
                        required=True,
                        choices=ALL_DEVICE_NAMES,
                        help="The target device to benchmark.")
    parser.add_argument("--tflite-benchmark-tool",
                        "--tflite_benchmark_tool",
                        type=pathlib.Path,
                        required=True,
                        help="The path to the TFLite `benchmark_model` tool.")
    parser.add_argument("-t",
                        "--threads",
                        type=str,
                        default="1,4",
                        help="A comma-delimited list of threads.")
    parser.add_argument("-iter",
                        "--iterations",
                        type=int,
                        default=10,
                        help="The number of iterations to benchmark.")
    parser.add_argument("--root-dir",
                        "--root_dir",
                        type=pathlib.Path,
                        default=pathlib.Path("/tmp/openxla-benchmark/tflite"),
                        help="Root directory stores benchmark artifacts.")
    parser.add_argument(
        "--no-download",
        "--no_download",
        action="store_true",
        help="Don't automatically download benchmark artifacts.")
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Show verbose messages.")


def _parse_metrics(result_text: str, iterations: int) -> dict:
    """Extracts latency and peak-memory metrics from the tool's output.

    Metrics whose line is absent from `result_text` stay `None`.
    """
    metrics = {
        "min_latency_ms": None,
        "max_latency_ms": None,
        "mean_latency_ms": None,
        "stddev_latency_ms": None,
        "benchmark_iterations": iterations,
        "device_memory_peak_mb": None,
    }

    # The timing summary is labelled "in us"; scale to milliseconds.
    match = AVG_LATENCY_REGEXP.search(result_text)
    if match:
        metrics["mean_latency_ms"] = float(match.group(1)) * 1e-3

    # NOTE(review): search() takes the first matching line. If the tool prints
    # a "count=..." line for warm-up runs before the measured runs, these
    # statistics describe the warm-up -- confirm against real tool output.
    # Values are assumed to be microseconds like the summary line.
    match = MIN_MAX_LATENCY_REGEXP.search(result_text)
    if match:
        metrics["min_latency_ms"] = float(match.group(1)) * 1e-3
        metrics["max_latency_ms"] = float(match.group(2)) * 1e-3
        metrics["stddev_latency_ms"] = float(match.group(4)) * 1e-3

    match = PEAK_MEMORY_REGEXP.search(result_text)
    if match:
        metrics["device_memory_peak_mb"] = float(match.group(1))

    return metrics


def benchmark(benchmark_command: List[str], benchmark_definition: dict,
              iterations: int, verbose: bool) -> utils.BenchmarkResult:
    """Runs one benchmark command and parses its output into a result.

    Args:
      benchmark_command: Command line to execute, one argument per element.
      benchmark_definition: Metadata stored verbatim as the result definition.
      iterations: Iteration count recorded as `benchmark_iterations`; it is
        not passed to the command by this function.
      verbose: Whether to echo the command and its full output.

    Returns:
      A `utils.BenchmarkResult` whose metrics sit under the `compiler_level`
      key. Metrics that could not be parsed are `None`.
    """
    if verbose:
        print(f"Run command: {benchmark_command}")

    result = subprocess.run(benchmark_command,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT)
    # Tool output is only scanned with regexes, so an undecodable byte should
    # not abort the whole run.
    result_text = result.stdout.decode("utf-8", errors="replace")

    if verbose:
        print(result_text)

    if result.returncode != 0:
        # Stay best-effort (metrics are simply None), but make the failure
        # visible instead of silently recording an empty result.
        print(
            f"Warning: benchmark command exited with status "
            f"{result.returncode}; metrics may be missing.",
            file=sys.stderr)
        if not verbose:
            print(result_text, file=sys.stderr)

    return utils.BenchmarkResult(
        definition=benchmark_definition,
        metrics={
            "compiler_level": _parse_metrics(result_text, iterations),
        },
    )
+OUTPUT_PATH="$(realpath ${OUTPUT_PATH})" +"${TD}/../../comparative_benchmark/scripts/create_results_json.sh" "${OUTPUT_PATH}" + +declare -a args=( + --tflite_benchmark_tool "${TOOL_DIR}/tflite_benchmark_model" + --output "${OUTPUT_PATH}" + --target_device "${TARGET_DEVICE}" + --verbose +) + +# Download TFLite benchmark tool depending on target. +if [[ "${TARGET_DEVICE}" =~ ^(pixel-4|pixel-6-pro|moto-edge-x30)$ ]]; then + wget -O "${TOOL_DIR}/tflite_benchmark_model" https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_arm_benchmark_model_plus_flex + + # Setup mobile device for benchmarking. + adb push "${TD}/set_android_scaling_governor.sh" "/data/local/tmp" + adb shell "chmod +x /data/local/tmp/set_android_scaling_governor.sh" + adb shell "su root sh /data/local/tmp/set_android_scaling_governor.sh performance" + + BENCHMARK_SCRIPT="run_benchmarks_android.py" + + # Pixel 6 has a maximum of 8 cores. + THREADS="1,4" + TASKSETS="80,f0" + + args+=( + --threads "${THREADS}" + --tasksets "${TASKSETS}" + --iterations 10 + ) + + declare -a BENCHMARK_NAMES=( + #"models/RESNET50_FP32_TF_.+_BATCH(1|8)/.+" + "models/BERT_LARGE_FP32_TF_.+_BATCH(1|16|24|32)/.+" + #"models/T5_LARGE_FP32_TF_.+_BATCH(1|16)/.+" + #"models/EFFICIENTNETB7_FP32_TF_.+_BATCH1/.+" + ) +else + wget -O "${TOOL_DIR}/tflite_benchmark_model" https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_x86-64_benchmark_model_plus_flex + chmod +x "${TOOL_DIR}/tflite_benchmark_model" + + BENCHMARK_SCRIPT="run_benchmarks.py" + + # c2-standard-16 has 8 cores. 
+ THREADS="1,8" + args+=( + --threads "${THREADS}" + --iterations 20 + ) + + declare -a BENCHMARK_NAMES=( + #"models/RESNET50_FP32_TF_.+_BATCH(1|64|128)/.+" + "models/BERT_LARGE_FP32_TF_.+_BATCH(1|16|24|32)/.+" + #"models/T5_LARGE_FP32_TF_.+_BATCH(1|16|24|32)/.+" + #"models/EFFICIENTNETB7_FP32_TF_.+_BATCH(1|64|128)/.+" + ) +fi + +for i in ${!BENCHMARK_NAMES[@]}; do + args+=( + --benchmark_name "${BENCHMARK_NAMES[$i]}" + ) + "${TD}/${BENCHMARK_SCRIPT}" "${args[@]}" +done diff --git a/experimental/tflite/requirements.txt b/experimental/tflite/requirements.txt new file mode 100644 index 00000000..945b4703 --- /dev/null +++ b/experimental/tflite/requirements.txt @@ -0,0 +1,2 @@ +numpy +requests diff --git a/experimental/tflite/run_benchmarks.py b/experimental/tflite/run_benchmarks.py new file mode 100755 index 00000000..9550caa3 --- /dev/null +++ b/experimental/tflite/run_benchmarks.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import dataclasses +import json +import pathlib +import re +import subprocess +import sys +from typing import Sequence + +import benchmark_lib + +# Add common_benchmark_suite dir to the search path. +sys.path.insert( + 0, str(pathlib.Path(__file__).parents[2] / "common_benchmark_suite")) +from openxla.benchmark import def_types, devices +from openxla.benchmark.comparative_suite.tf import benchmark_definitions as tf_benchmark_definitions + +# Add common_benchmark_suite dir to the search path. 
sys.path.insert(
    0, str(pathlib.Path(__file__).parents[2] / "comparative_benchmark"))
import utils

ALL_DEVICE_NAMES = [device.name for device in devices.ALL_DEVICES]
TFLITE_FP32_FILENAME = "model_fp32.tflite"

# Raw strings keep `\d` and `\(` as regex escapes rather than (invalid) string
# escapes; the compiled patterns are unchanged.
LATENCY_REGEXP = re.compile(
    r"INFO: count=\d+ first=\d+ curr=\d+ min=(.*) max=(.*) avg=(.*) std=(.*)")
PEAK_MEMORY_REGEXP = re.compile(
    r"INFO: Overall peak memory footprint \(MB\) via periodic monitoring: (.*)")


def _run(
    benchmark: def_types.BenchmarkCase,
    target_device: def_types.DeviceSpec,
    iterations: int,
    num_threads: str,
    tflite_benchmark_tool: pathlib.Path,
    tflite_model_path: pathlib.Path,
    verbose: bool,
) -> utils.BenchmarkResult:
    """Benchmarks one model at one thread count with the TFLite tool.

    Args:
      benchmark: Benchmark case; its model supplies the recorded metadata.
      target_device: Device whose name is recorded in the result definition.
      iterations: Value passed to the tool as `--num_runs`.
      num_threads: Value passed to the tool as `--num_threads`.
      tflite_benchmark_tool: Path of the benchmark executable.
      tflite_model_path: Path of the `.tflite` model passed as `--graph`.
      verbose: Forwarded to `benchmark_lib.benchmark`.

    Returns:
      The result produced by `benchmark_lib.benchmark`.
    """
    model = benchmark.model
    benchmark_definition = {
        "benchmark_name": benchmark.name,
        "framework": str(model.model_impl.framework_type),
        "data_type": model.model_parameters["data_type"],
        "batch_size": model.model_parameters["batch_size"],
        "compiler": "TFLite",
        "device": target_device.name,
        "num_threads": num_threads,
        "num_iterations": iterations,
        "tags": model.model_impl.tags + model.tags,
    }
    cmd = [
        tflite_benchmark_tool,
        f"--graph={tflite_model_path}",
        f"--num_runs={iterations}",
        f"--num_threads={num_threads}",
        "--report_peak_memory_footprint=true",
    ]

    return benchmark_lib.benchmark(cmd, benchmark_definition, iterations,
                                   verbose)


def _parse_arguments() -> argparse.Namespace:
    """Parses the command line using the options from `benchmark_lib`."""
    parser = argparse.ArgumentParser(description="Run TFLite benchmarks.")
    benchmark_lib.configure_parser(parser)
    return parser.parse_args()


def main(
    benchmark_name: str,
    target_device_name: str,
    output: pathlib.Path,
    root_dir: pathlib.Path,
    threads: str,
    tflite_benchmark_tool: pathlib.Path,
    iterations: int,
    no_download: bool,
    verbose: bool,
):
    """Runs every matching benchmark at every requested thread count.

    Args:
      benchmark_name: Regular expression, anchored at both ends, selecting
        benchmarks by name.
      target_device_name: Name of a device in `devices.ALL_DEVICES`.
      output: Results file that each result is appended to.
      root_dir: Directory holding (or receiving) the model artifacts.
      threads: Comma-delimited thread counts, one run per entry.
      tflite_benchmark_tool: Path of the benchmark executable.
      iterations: Number of benchmark iterations per run.
      no_download: Skip downloading artifacts and use what is in `root_dir`.
      verbose: Print each result and forward verbosity to helpers.

    Raises:
      ValueError: If no benchmark matches, the device is unknown, or a model
        file is missing from `root_dir`.
    """
    name_pattern = re.compile(f"^{benchmark_name}$")
    all_benchmarks = tf_benchmark_definitions.ALL_BENCHMARKS
    benchmarks = [
        benchmark for benchmark in all_benchmarks
        if name_pattern.match(benchmark.name)
    ]

    if not benchmarks:
        all_benchmark_names = "\n".join(
            benchmark.name for benchmark in all_benchmarks)
        raise ValueError(f'No benchmark matches "{benchmark_name}".'
                         f' Available benchmarks:\n{all_benchmark_names}')

    try:
        target_device = next(device for device in devices.ALL_DEVICES
                             if device.name == target_device_name)
    except StopIteration:
        # `from None` keeps the internal StopIteration out of the traceback.
        raise ValueError(
            f'Target device "{target_device_name}" is not defined.'
            f' Available device options:\n{ALL_DEVICE_NAMES}') from None

    if not no_download:
        benchmark_lib.download_artifacts(benchmarks=benchmarks,
                                         root_dir=root_dir,
                                         verbose=verbose)

    thread_counts = threads.split(",")
    for benchmark in benchmarks:
        tflite_model_path = (root_dir / benchmark.model.name /
                             TFLITE_FP32_FILENAME)
        if not tflite_model_path.exists():
            raise ValueError(f"TFLite model not found: '{tflite_model_path}'.")

        for num_threads in thread_counts:
            result = _run(benchmark=benchmark,
                          target_device=target_device,
                          iterations=iterations,
                          num_threads=num_threads,
                          tflite_benchmark_tool=tflite_benchmark_tool,
                          tflite_model_path=tflite_model_path,
                          verbose=verbose)
            if verbose:
                print(json.dumps(dataclasses.asdict(result), indent=2))

            utils.append_benchmark_result(output, result)


if __name__ == "__main__":
    main(**vars(_parse_arguments()))
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import dataclasses +import json +import pathlib +import re +import subprocess +import sys +from typing import Sequence + +import benchmark_lib + +# Add common_benchmark_suite dir to the search path. +sys.path.insert( + 0, str(pathlib.Path(__file__).parents[2] / "common_benchmark_suite")) +from openxla.benchmark import def_types, devices +from openxla.benchmark.comparative_suite.tf import benchmark_definitions as tf_benchmark_definitions + +# Add common_benchmark_suite dir to the search path. +sys.path.insert( + 0, str(pathlib.Path(__file__).parents[2] / "comparative_benchmark")) +import utils + +ALL_DEVICE_NAMES = [device.name for device in devices.ALL_DEVICES] +TFLITE_FP32_FILENAME = "model_fp32.tflite" + +LATENCY_REGEXP = re.compile( + "INFO: count=\d+ first=\d+ curr=\d+ min=(.*) max=(.*) avg=(.*) std=(.*)") +PEAK_MEMORY_REGEXP = re.compile( + "INFO: Overall peak memory footprint \(MB\) via periodic monitoring: (.*)") + + +def _run( + benchmark: def_types.BenchmarkCase, + target_device: def_types.DeviceSpec, + iterations: int, + num_threads: str, + taskset: str, + tflite_benchmark_tool: pathlib.Path, + tflite_model_path: pathlib.Path, + verbose: bool, +) -> utils.BenchmarkResult: + model = benchmark.model + data_type = model.model_parameters["data_type"] + batch_size = model.model_parameters["batch_size"] + benchmark_definition = { + "benchmark_name": benchmark.name, + "framework": str(model.model_impl.framework_type), + "data_type": data_type, + "batch_size": batch_size, + "compiler": "TFLite", + "device": target_device.name, + "num_threads": num_threads, + "num_iterations": iterations, + "tags": model.model_impl.tags + model.tags, + } + cmd = [ + "adb", + "shell", + "taskset", + taskset, + tflite_benchmark_tool, + f"--graph={tflite_model_path}", + f"--num_runs={iterations}", + f"--num_threads={num_threads}", + f"--report_peak_memory_footprint=true", + ] + + return 
benchmark_lib.benchmark(cmd, benchmark_definition, iterations, verbose) + + +def _parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run TFLite benchmarks.") + parser.add_argument( + "--tasksets", + type=str, + default="f0", + help= + "A comma-separated list of tasksets to run under each thread configuration." + ) + benchmark_lib.configure_parser(parser) + return parser.parse_args() + + +def main( + benchmark_name: str, + target_device_name: str, + output: pathlib.Path, + root_dir: pathlib.Path, + threads: str, + tasksets: str, + tflite_benchmark_tool: pathlib.Path, + iterations: int, + no_download: bool, + verbose: bool, +): + name_pattern = re.compile(f"^{benchmark_name}$") + all_benchmarks = tf_benchmark_definitions.ALL_BENCHMARKS + benchmarks = [ + benchmark for benchmark in all_benchmarks + if name_pattern.match(benchmark.name) + ] + + if not benchmarks: + all_benchmark_names = "\n".join( + benchmark.name for benchmark in all_benchmarks) + raise ValueError(f'No benchmark matches "{benchmark_name}".' + f' Available benchmarks:\n{all_benchmark_names}') + + try: + target_device = next(device for device in devices.ALL_DEVICES + if device.name == target_device_name) + except StopIteration: + raise ValueError(f'Target device "{target_device_name}" is not defined.' + f' Available device options:\n{ALL_DEVICE_NAMES}') + + if not no_download: + benchmark_lib.download_artifacts(benchmarks=benchmarks, + root_dir=root_dir, + verbose=verbose) + + threads = threads.split(",") + tasksets = tasksets.split(",") + if len(threads) != len(tasksets): + raise ValueError( + "The number of tasksets specified must be equal to the number of threads." + ) + + # Push artifacts to the Android device. 
+ subprocess.run(["adb", "push", tflite_benchmark_tool, "/data/local/tmp"]) + subprocess.run([ + "adb", "shell", "chmod", "+x", + f"/data/local/tmp/{tflite_benchmark_tool.name}" + ]) + + for benchmark in benchmarks: + tflite_model_path = root_dir / benchmark.model.name / TFLITE_FP32_FILENAME + if not tflite_model_path.exists(): + raise ValueError(f"TFLite model not found: '{tflite_model_path}'.") + subprocess.run(["adb", "push", tflite_model_path, "/data/local/tmp"]) + + for taskset, num_threads in zip(tasksets, threads): + result = _run(benchmark=benchmark, + target_device=target_device, + iterations=iterations, + num_threads=num_threads, + taskset=taskset, + tflite_benchmark_tool=pathlib.Path( + f"/data/local/tmp/{tflite_benchmark_tool.name}"), + tflite_model_path=pathlib.Path( + f"/data/local/tmp/{tflite_model_path.name}"), + verbose=verbose) + if verbose: + print(json.dumps(dataclasses.asdict(result), indent=2)) + + utils.append_benchmark_result(output, result) + + subprocess.run( + ["adb", "shell", "rm", f"/data/local/tmp/{tflite_model_path.name}"]) + + subprocess.run( + ["adb", "shell", "rm", f"/data/local/tmp/{tflite_benchmark_tool.name}"]) + + +if __name__ == "__main__": + main(**vars(_parse_arguments())) diff --git a/experimental/tflite/set_android_scaling_governor.sh b/experimental/tflite/set_android_scaling_governor.sh new file mode 100755 index 00000000..9f51e273 --- /dev/null +++ b/experimental/tflite/set_android_scaling_governor.sh @@ -0,0 +1,51 @@ +#!/bin/sh + +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Runs on an android device itself to set the frequency scaling governor for all +# CPUs (default performance). 
+ +################################### WARNING #################################### +# This will overheat the phone if it's not on a cooling plate, resulting in # +# thermal throttling. To prevent anything catching on fire, the actual CPU # +# frequencies will be throttled to below the maximum, skewing your results. # +################################################################################ + +set -euo pipefail + +GOVERNOR="${1:-performance}" + +echo "CPU info (before changing governor):" +echo 'cpu\tgovernor\tcur\tmin\tmax' +echo "------------------------------------------------" +for i in `cat /sys/devices/system/cpu/present | tr '-' ' ' | xargs seq`; do \ + echo "cpu${i}" | paste \ + - \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/scaling_governor" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_cur_freq" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_min_freq" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_max_freq"; \ +done + +echo "Setting CPU frequency governor to ${GOVERNOR}" + +for i in `cat /sys/devices/system/cpu/present | tr '-' ' ' | xargs seq`; do \ + echo "${GOVERNOR}" > \ + "/sys/devices/system/cpu/cpu${i?}/cpufreq/scaling_governor"; \ +done + +echo "CPU info (after changing governor):" +echo 'cpu\tgovernor\tcur\tmin\tmax' +echo "------------------------------------------------" +for i in `cat /sys/devices/system/cpu/present | tr '-' ' ' | xargs seq`; do \ + echo "cpu${i}" | paste \ + - \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/scaling_governor" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_cur_freq" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_min_freq" \ + "/sys/devices/system/cpu/cpu${i}/cpufreq/cpuinfo_max_freq"; \ +done diff --git a/experimental/tflite/setup_venv.sh b/experimental/tflite/setup_venv.sh new file mode 100644 index 00000000..e7bd3960 --- /dev/null +++ b/experimental/tflite/setup_venv.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# +# Copyright 2023 The OpenXLA Authors +# +# Licensed under the Apache 
License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# Sets up a virtual environment suitable for running TFLite benchmarks. +# +# Environment variables: +# VENV_DIR=tflite-benchmarks.venv +# PYTHON=/usr/bin/python3.10 + +set -euo pipefail + +TD="$(cd $(dirname $0) && pwd)" +VENV_DIR="${VENV_DIR:-tflite-benchmarks.venv}" +PYTHON="${PYTHON:-"$(which python3)"}" + +echo "Setting up venv dir: ${VENV_DIR}" + +"${PYTHON}" -m venv "${VENV_DIR}" || echo "Could not create venv." +source "${VENV_DIR}/bin/activate" || echo "Could not activate venv" + +# Upgrade pip and install requirements. 'python' is used here in order to +# reference to the python executable from the venv. +python -m pip install --upgrade pip || echo "Could not upgrade pip" +python -m pip install --upgrade -r "${TD}/requirements.txt" + +echo "Activate venv with:" +echo " source ${VENV_DIR}/bin/activate"