Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTQ][OV] Sequential models support for Engine by TensorCollectorAdapter and by base TensorStatistic #18

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def run_benchmark(path_to_model: str, shape: Optional[List[int]] = None, verbose
# >> output_names = [output.name for output in sess.get_outputs()]
# >> for data_item in val_loader:
# >> sess.run(output_names, input_feed=transform_fn(data_item))

input_name = model.graph.input[0].name


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
{
"compression": {
"algorithms": [
{
"name": "DefaultQuantization",
"params": {
"preset": "performance",
"stat_subset_size": 3
}
}
],
"dump_intermediate_model": true
},
"engine": {
"datasets": [
{
"metrics": [
{
"type": "wer"
}
],
"name": "LibriSpeech_test_clean_wav",
"data_source": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/librispeech/test/LibriSpeech/test-clean.wav",

"annotation_conversion": {
"converter": "librispeech",
"data_dir": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/librispeech/test/LibriSpeech/test-clean.wav"
},
"preprocessing": [
{
"int16mode": true,
"type": "audio_normalization"
},
{
"duration": "512 samples",
"overlap": "192 samples",
"type": "clip_audio"
},
{
"base": 512,
"type": "hanning_window"
},
{
"fftbase": 512,
"magnitude_squared": true,
"skip_channels": true,
"type": "audio_spectrogram"
},
{
"base": 257,
"filterbank_channel_count": 40,
"lower_frequency_limit": 20,
"sample_rate": 16000,
"type": "audio_triangle_filtering",
"upper_frequency_limit": 4000
},
{
"filterbank_channel_count": 40,
"numceps": 26,
"type": "audio_dct"
},
{
"context": 9,
"numceps": 26,
"type": "clip_cepstrum"
},
{
"step": 16,
"type": "pack_cepstrum"
}
],
"reader": "wav_reader"
}
],
"launchers": [
{
"adapter": {
"beam_size": 32,
"lm_alpha": 0.75,
"lm_beta": 1.05,
"lm_file": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/model_attributes/mozilla-deepspeech-0.6.1/lm.binary",
"lm_oov_score": -1000,
"lm_vocabulary_length": 4463723,
"lm_vocabulary_offset": 941235601,
"logarithmic_prob": false,
"probability_out": "logits",
"type": "ctc_beam_search_decoder_with_lm"
},
"framework": "dlsdk",
"inputs": [
{
"layout": "NHWC",
"name": "input_node",
"type": "INPUT"
},
{
"name": "previous_state_c",
"type": "LSTM_INPUT",
"value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/BlockLSTM/TensorIterator.2"
},
{
"name": "previous_state_h",
"type": "LSTM_INPUT",
"value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/BlockLSTM/TensorIterator.1"
}
]
},
{
"adapter": {
"beam_size": 32,
"lm_alpha": 0.75,
"lm_beta": 1.05,
"lm_file": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/model_attributes/mozilla-deepspeech-0.6.1/lm.binary",
"lm_oov_score": -1000,
"lm_vocabulary_length": 4463723,
"lm_vocabulary_offset": 941235601,
"logarithmic_prob": false,
"probability_out": "logits",
"type": "ctc_beam_search_decoder_with_lm"
},
"framework": "openvino",
"inputs": [
{
"layout": "NHWC",
"name": "input_node",
"type": "INPUT"
},
{
"name": "previous_state_c",
"type": "LSTM_INPUT",
"value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd:0"
},
{
"name": "previous_state_h",
"type": "LSTM_INPUT",
"value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd_1:0"
}
]
}
]
},
"model": {
"model": "/mnt/omz/cv_bench_cache/ww18_weekly_23.0.0-10862-40bf400b189-API2.0/mozilla-deepspeech-0.6.1/tf/tf_frozen/FP16/1/dldt/mozilla-deepspeech-0.6.1.xml",
"model_name": "mozilla-deepspeech-0.6.1",
"weights": "/mnt/omz/cv_bench_cache/ww18_weekly_23.0.0-10862-40bf400b189-API2.0/mozilla-deepspeech-0.6.1/tf/tf_frozen/FP16/1/dldt/mozilla-deepspeech-0.6.1.bin"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
import subprocess
from typing import Dict

import numpy as np
import openvino.runtime as ov
from openvino.tools.accuracy_checker.evaluators.quantization_model_evaluator import create_model_evaluator
from openvino.tools.pot.configs.config import Config

import nncf

# Model passed to `omz_downloader` / `omz_converter` by name.
model_name = "mozilla-deepspeech-0.6.1"
# The download cache, the converted model and the dataset config are all resolved relative to this script.
cache_dir = os.path.dirname(__file__)
dataset_config = os.path.join(cache_dir, "accuracy_checker.json")

# The commands are passed as argument lists and run without a shell, so paths that contain
# spaces or shell metacharacters reach the tools verbatim.
command = ["omz_downloader", "--name", model_name, "--cache_dir", cache_dir]
cmd_output = subprocess.call(command)  # nosec

model_dir = os.path.join(cache_dir, model_name)
# The conversion is skipped when the output directory already exists.
if not os.path.exists(model_dir):
    command = ["omz_converter", "--name", model_name, "-o", model_dir]
    cmd_output = subprocess.call(command)  # nosec

xml_path = os.path.join(model_dir, f"public/{model_name}/FP16/{model_name}.xml")
ov_model = ov.Core().read_model(xml_path)

# The "engine" section of the accuracy checker config describes the dataset and the launchers.
config = Config.read_config(dataset_config)
config.configure_params()
accuracy_checker_config = config.engine

model_evaluator = create_model_evaluator(accuracy_checker_config)
model_evaluator.load_network([{"model": ov_model}])
model_evaluator.select_dataset("")


def sequence_transform_fn(data_item):
    """
    Quantization transform function for sequential data.

    Takes the batch input and the batch annotation out of the data item, lets the model evaluator
    build the per-step inputs from them and yields these inputs one by one, re-keyed by the tensor
    names known to the launcher.

    :param data_item: Data item produced by the data loader during iteration.
    :return: Generator over the preprocessed inputs, one dict (tensor name -> value) per step.
    """
    _, annotations, inputs, _ = data_item
    step_inputs, _, _ = model_evaluator._get_batch_input(inputs, annotations)
    for step_input in step_inputs:
        yield {
            model_evaluator.launcher.input_to_tensor_name[input_name]: input_value
            for input_name, input_value in step_input.items()
        }


def create_model_input_fn(model_inputs, model_outputs):
    """
    Completes the inputs of one sequence step with the state inputs of the model.

    The state inputs are built by the launcher's `_fill_lstm_inputs` from the model outputs of the
    previous step and are merged into `model_inputs` in place.

    :param model_inputs: Preprocessed inputs of the current step, as yielded by `sequence_transform_fn`.
    :param model_outputs: Outputs of the target model from the previous step. None on the first step.
    :return: `model_inputs` itself, updated with the state inputs.
    """
    model_inputs.update(model_evaluator.launcher._fill_lstm_inputs(model_outputs))
    return model_inputs


dataset = nncf.RecurentDataset(model_evaluator.dataset, sequence_transform_fn, create_model_input_fn)

# Sanity check for the user: run the first sequence through the model step by step,
# the same way the statistics aggregator does, feeding the outputs of each step into the next one.
# An `ov.Model` has to be compiled before it can be used for inference.
compiled_model = ov.Core().compile_model(ov_model, "CPU")
output = None
# `get_inference_data()` is only relied upon to be iterable, hence the explicit `iter()`.
# The items it yields are `Sequence` wrappers, not raw data items, so the tokens and the
# filled inputs are obtained through the wrapper instead of calling the user functions directly.
sequence = next(iter(dataset.get_inference_data()))
for sequence_item in sequence.get_tokens_iter():
    model_input = sequence.fill_inputs(sequence_item, output)
    # NOTE(review): `_fill_lstm_inputs` presumably looks the outputs up by name - confirm that
    # the result returned by the compiled model supports that for the OpenVINO version in use.
    output = compiled_model(model_input)

quantized_model = nncf.quantize(ov_model, dataset, subset_size=3)
1 change: 1 addition & 0 deletions nncf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from nncf.common.logging.logger import set_log_level
from nncf.config import NNCFConfig
from nncf.data import Dataset
from nncf.data import RecurentDataset
from nncf.parameters import DropType
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
Expand Down
43 changes: 41 additions & 2 deletions nncf/common/tensor_statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from itertools import islice
from typing import Any, Dict, TypeVar

Expand All @@ -21,6 +22,7 @@
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.data.dataset import Dataset
from nncf.data.dataset import RecurentDataset

TensorType = TypeVar("TensorType")
TModel = TypeVar("TModel")
Expand All @@ -31,10 +33,13 @@ class StatisticsAggregator(ABC):
Base class for statistics collection.
"""

STACK_AXIS = 0

def __init__(self, dataset: Dataset):
    """
    :param dataset: Dataset the inference data is taken from. When it is a `RecurentDataset`,
        the aggregator works in the sequential mode.
    """
    self._is_sequential = isinstance(dataset, RecurentDataset)
    self.statistic_points = StatisticPointsContainer()
    self.stat_subset_size = None
    self.dataset = dataset

def collect_statistics(self, model: TModel) -> None:
"""
Expand All @@ -46,19 +51,44 @@ def collect_statistics(self, model: TModel) -> None:
model_transformer = ModelTransformerFactory.create(model)

merged_statistics = self._get_merged_statistic_points(self.statistic_points, model)
if self._is_sequential:
merged_statistics = self._adapt_collectors(merged_statistics, self.STACK_AXIS)

transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
model_with_outputs = model_transformer.transform(transformation_layout)
engine = EngineFactory.create(model_with_outputs)
infer_fn = self._infer_sequential if self._is_sequential else self._infer

for input_data in tqdm(
islice(self.dataset.get_inference_data(), self.stat_subset_size),
total=self.stat_subset_size,
desc="Statistics collection",
):
outputs = engine.infer(input_data)
processed_outputs = self._process_outputs(outputs)
processed_outputs = infer_fn(engine, input_data)
self._register_statistics(processed_outputs, merged_statistics)

def _infer(self, engine, input_data):
    """
    Runs a single inference and converts its raw outputs with `_process_outputs`.

    :param engine: Engine used for the inference.
    :param input_data: Input data passed to `engine.infer`.
    :return: Processed model outputs.
    """
    return self._process_outputs(engine.infer(input_data))

def _infer_sequential(self, engine, sequence):
    """
    Infers a whole sequence token by token and stacks the processed outputs of all steps.

    The raw outputs of every step are handed to `sequence.fill_inputs` together with the next token;
    None is passed for the first token.

    :param engine: Engine used for the inference.
    :param sequence: Object providing `get_tokens_iter()` and `fill_inputs(token, model_output)`.
    :return: Dict that maps every output name to its per-step values stacked along `STACK_AXIS`.
    """
    previous_output = None
    collected = defaultdict(list)
    for token in sequence.get_tokens_iter():
        previous_output = engine.infer(sequence.fill_inputs(token, previous_output))
        for name, value in self._process_outputs(previous_output).items():
            collected[name].append(value)

    # One stacked tensor per output name.
    processor = self._get_tensor_processor()
    return {name: processor.stack(values, axis=self.STACK_AXIS) for name, values in collected.items()}

def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
"""
Register statistic points for statistics collection and recalculates the maximum number samples
Expand Down Expand Up @@ -115,6 +145,10 @@ def _get_merged_statistic_points(
:return: Merged statistic points container bounded with given statistic point container.
"""

@staticmethod
def _adapt_collectors(statistic_points: StatisticPointsContainer, stack_axis: int):
    """
    Called on the merged statistic points before statistics collection on a sequential dataset.
    This base implementation returns the given container unchanged and does not use `stack_axis`.

    :param statistic_points: Merged statistic points container.
    :param stack_axis: Axis along which the outputs of the sequence steps are stacked.
    :return: The given statistic points container.
    """
    return statistic_points

@staticmethod
@abstractmethod
def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
Expand All @@ -124,3 +158,8 @@ def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
:param outputs: raw model outputs
:return: processed model outputs in Dict[str, NNCFTensor] format
"""

@staticmethod
@abstractmethod
def _get_tensor_processor():
    """
    Returns the tensor processor used to stack the per-step outputs of a sequence.

    :return: Object providing a `stack(values, axis=...)` method.
    """
2 changes: 2 additions & 0 deletions nncf/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@
# limitations under the License.

from nncf.data.dataset import Dataset
from nncf.data.dataset import RecurentDataset
from nncf.data.dataset import Sequence
29 changes: 28 additions & 1 deletion nncf/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Generic, Iterable, List, Optional, TypeVar
from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar

from nncf.common.utils.api_marker import api

Expand Down Expand Up @@ -115,3 +115,30 @@ def _get_iterator_for_iter(
if idx == indices[pos]:
pos = pos + 1
yield transform_func(data_item)


@api(canonical_alias="nncf.RecurentDataset")
class RecurentDataset(Dataset):
    """
    Dataset for sequential models: every item of the data source is wrapped into a `Sequence`
    that carries the item together with the two user functions.
    """

    def __init__(self, data_source: Iterable, get_token_from_sequence_fn, fill_sequential_inputs_fn):
        """
        :param data_source: Iterable over the raw sequences.
        :param get_token_from_sequence_fn: Function that receives a raw sequence and returns
            its tokens.
        :param fill_sequential_inputs_fn: Function that receives a token and the model outputs
            and returns the model inputs.
        """
        super().__init__(
            data_source,
            lambda data_item: Sequence(data_item, get_token_from_sequence_fn, fill_sequential_inputs_fn),
        )


class Sequence:
    """
    A raw sequence bundled with the two user functions needed to infer it step by step:
    one that splits the sequence into tokens and one that builds the model inputs for a token.
    """

    def __init__(
        self,
        raw_sequence: DataItem,
        get_tokens_from_sequence_func: Callable[[DataItem], Iterable[ModelInput]],
        fill_sequential_inputs_fn: Callable[[ModelInput, Any], ModelInput],
    ):
        """
        :param raw_sequence: Raw sequence as it comes from the data source.
        :param get_tokens_from_sequence_func: Function that receives the raw sequence and returns
            its tokens.
        :param fill_sequential_inputs_fn: Function that receives a token and the model outputs
            and returns the model inputs.
        """
        self._raw_sequence = raw_sequence
        self._get_tokens_from_sequence_func = get_tokens_from_sequence_func
        self._fill_sequential_inputs_fn = fill_sequential_inputs_fn

    def get_tokens_iter(self) -> Iterable[ModelInput]:
        """
        :return: Result of the token function applied to the raw sequence.
        """
        return self._get_tokens_from_sequence_func(self._raw_sequence)

    def fill_inputs(self, token: ModelInput, model_outputs: Any) -> ModelInput:
        """
        :param token: Token of the sequence.
        :param model_outputs: Model outputs handed over to the fill function together with the token.
        :return: Result of the fill function applied to the token and the model outputs.
        """
        return self._fill_sequential_inputs_fn(token, model_outputs)
Loading