Skip to content

Commit

Permalink
Generate TFLite files from Tensorflow models
Browse files Browse the repository at this point in the history
  • Loading branch information
mariecwhite committed Oct 5, 2023
1 parent f2b8505 commit 1ea715d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
def_types.ModelArtifactType.TF_SAVEDMODEL_V2,
def_types.ModelArtifactType.TFLITE_FP32,
],
)
T5_LARGE_FP32_TF_512XI32_BATCHES = utils.build_batch_models(
Expand Down Expand Up @@ -69,6 +70,7 @@
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
def_types.ModelArtifactType.TF_SAVEDMODEL_V2,
def_types.ModelArtifactType.TFLITE_FP32,
],
)
BERT_LARGE_FP32_TF_384XI32_BATCHES = utils.build_batch_models(
Expand Down Expand Up @@ -100,6 +102,8 @@
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
def_types.ModelArtifactType.TF_SAVEDMODEL_V2,
def_types.ModelArtifactType.TFLITE_FP32,
def_types.ModelArtifactType.TFLITE_INT8,
],
)
RESNET50_FP32_TF_224X224X3XF32_BATCHES = utils.build_batch_models(
Expand Down Expand Up @@ -130,6 +134,7 @@
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
def_types.ModelArtifactType.TF_SAVEDMODEL_V2,
def_types.ModelArtifactType.TFLITE_FP32,
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pathlib
import re
import multiprocessing
import numpy as np
import shutil
import sys
import tarfile
Expand Down Expand Up @@ -61,6 +62,47 @@ def _generate_mlir(model_dir: pathlib.Path, saved_model_dir: pathlib.Path):
write_bytecode(str(mlir_path), result)


def _generate_tflite(inputs: Tuple[Any, ...], model_dir: pathlib.Path,
saved_model_dir: pathlib.Path):
converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))

# Generate fp32 model.
try:
tflite_model = converter.convert()
tflite_model_path = model_dir.joinpath("model_fp32.tflite")
with open(tflite_model_path, 'wb') as f:
f.write(tflite_model)
except Exception as e:
print(f"Failed to generate int8 TFLite model. Exception: {e}")

# Generate int8 model.
try:

def representative_examples():
for _ in range(2):
random_inputs = []
for input in inputs:
random_inputs.append(
np.random.uniform(low=input.dtype.min,
high=input.dtype.max,
size=input.shape).astype(
input.dtype.as_numpy_dtype))
yield random_inputs

converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.TFLITE_BUILTINS_INT8
]
converter.representative_dataset = representative_examples
converter.inference_type = tf.int8
tflite_model_int8 = converter.convert()
tflite_model_int8_path = model_dir.joinpath("model_int8.tflite")
with open(tflite_model_int8_path, 'wb') as f:
f.write(tflite_model_int8)
except Exception as e:
print(f"Failed to generate int8 TFLite model. Exception: {e}")


def _generate_artifacts(model: def_types.Model, save_dir: pathlib.Path,
auto_upload: bool):
model_dir = save_dir.joinpath(model.name)
Expand All @@ -87,6 +129,7 @@ def _generate_artifacts(model: def_types.Model, save_dir: pathlib.Path,

saved_model_dir = _generate_saved_model(inputs, model_obj, model_dir)
_generate_mlir(model_dir, saved_model_dir)
_generate_tflite(inputs, model_dir, saved_model_dir)

with tarfile.open(model_dir.joinpath("tf-model.tgz"), "w:gz") as tar:
tar.add(f"{saved_model_dir}/", arcname="")
Expand Down
3 changes: 3 additions & 0 deletions common_benchmark_suite/openxla/benchmark/def_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class ModelFrameworkType(Enum):
"""Type of framework a model is implemented in."""
TF_V1 = "tensorflow_v1"
TF_V2 = "tensorflow_v2"
TFLITE = "tflite"
PYTORCH = "pytorch"
JAX = "jax"
GGML = "ggml"
Expand Down Expand Up @@ -42,6 +43,8 @@ class ModelArtifactType(Enum):
"""Type of derived model artifact."""
TF_SAVEDMODEL_V1 = "tf_savedmodel_v1"
TF_SAVEDMODEL_V2 = "tf_savedmodel_v2"
TFLITE_FP32 = "tflite_fp32"
TFLITE_INT8 = "tflite_int8"
XLA_HLO_DUMP = "xla_hlo_dump"
STABLEHLO_MLIR = "stablehlo_mlir"
LINALG_MLIR = "linalg_mlir"
Expand Down

0 comments on commit 1ea715d

Please sign in to comment.