Dl/ov/tiny gpt2 example callbacks #20
Open
wants to merge 9 commits into base: develop
@@ -104,6 +104,7 @@ def run_benchmark(path_to_model: str, shape: Optional[List[int]] = None, verbose
    # >> output_names = [output.name for output in sess.get_outputs()]
    # >> for data_item in val_loader:
    # >>     sess.run(output_names, input_feed=transform_fn(data_item))

    input_name = model.graph.input[0].name


@@ -0,0 +1,147 @@
{
  "compression": {
    "algorithms": [
      {
        "name": "DefaultQuantization",
        "params": {
          "preset": "performance",
          "stat_subset_size": 3
        }
      }
    ],
    "dump_intermediate_model": true
  },
  "engine": {
    "datasets": [
      {
        "metrics": [
          {
            "type": "wer"
          }
        ],
        "name": "LibriSpeech_test_clean_wav",
        "data_source": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/librispeech/test/LibriSpeech/test-clean.wav",
        "annotation_conversion": {
          "converter": "librispeech",
          "data_dir": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/librispeech/test/LibriSpeech/test-clean.wav"
        },
        "preprocessing": [
          {
            "int16mode": true,
            "type": "audio_normalization"
          },
          {
            "duration": "512 samples",
            "overlap": "192 samples",
            "type": "clip_audio"
          },
          {
            "base": 512,
            "type": "hanning_window"
          },
          {
            "fftbase": 512,
            "magnitude_squared": true,
            "skip_channels": true,
            "type": "audio_spectrogram"
          },
          {
            "base": 257,
            "filterbank_channel_count": 40,
            "lower_frequency_limit": 20,
            "sample_rate": 16000,
            "type": "audio_triangle_filtering",
            "upper_frequency_limit": 4000
          },
          {
            "filterbank_channel_count": 40,
            "numceps": 26,
            "type": "audio_dct"
          },
          {
            "context": 9,
            "numceps": 26,
            "type": "clip_cepstrum"
          },
          {
            "step": 16,
            "type": "pack_cepstrum"
          }
        ],
        "reader": "wav_reader"
      }
    ],
    "launchers": [
      {
        "adapter": {
          "beam_size": 32,
          "lm_alpha": 0.75,
          "lm_beta": 1.05,
          "lm_file": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/model_attributes/mozilla-deepspeech-0.6.1/lm.binary",
          "lm_oov_score": -1000,
          "lm_vocabulary_length": 4463723,
          "lm_vocabulary_offset": 941235601,
          "logarithmic_prob": false,
          "probability_out": "logits",
          "type": "ctc_beam_search_decoder_with_lm"
        },
        "framework": "dlsdk",
        "inputs": [
          {
            "layout": "NHWC",
            "name": "input_node",
            "type": "INPUT"
          },
          {
            "name": "previous_state_c",
            "type": "LSTM_INPUT",
            "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/BlockLSTM/TensorIterator.2"
          },
          {
            "name": "previous_state_h",
            "type": "LSTM_INPUT",
            "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/BlockLSTM/TensorIterator.1"
          }
        ]
      },
      {
        "adapter": {
          "beam_size": 32,
          "lm_alpha": 0.75,
          "lm_beta": 1.05,
          "lm_file": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/model_attributes/mozilla-deepspeech-0.6.1/lm.binary",
          "lm_oov_score": -1000,
          "lm_vocabulary_length": 4463723,
          "lm_vocabulary_offset": 941235601,
          "logarithmic_prob": false,
          "probability_out": "logits",
          "type": "ctc_beam_search_decoder_with_lm"
        },
        "framework": "openvino",
        "inputs": [
          {
            "layout": "NHWC",
            "name": "input_node",
            "type": "INPUT"
          },
          {
            "name": "previous_state_c",
            "type": "LSTM_INPUT",
            "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd:0"
          },
          {
            "name": "previous_state_h",
            "type": "LSTM_INPUT",
            "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd_1:0"
          }
        ]
      }
    ]
  },
  "model": {
    "model": "/mnt/omz/cv_bench_cache/ww18_weekly_23.0.0-10862-40bf400b189-API2.0/mozilla-deepspeech-0.6.1/tf/tf_frozen/FP16/1/dldt/mozilla-deepspeech-0.6.1.xml",
    "model_name": "mozilla-deepspeech-0.6.1",
    "weights": "/mnt/omz/cv_bench_cache/ww18_weekly_23.0.0-10862-40bf400b189-API2.0/mozilla-deepspeech-0.6.1/tf/tf_frozen/FP16/1/dldt/mozilla-deepspeech-0.6.1.bin"
  }
}
@@ -0,0 +1,69 @@
import os
import subprocess

import openvino.runtime as ov
from openvino.tools.accuracy_checker.evaluators.quantization_model_evaluator import create_model_evaluator
from openvino.tools.pot.configs.config import Config

import nncf

model_name = "mozilla-deepspeech-0.6.1"
cache_dir = os.path.dirname(__file__)
dataset_config = os.path.join(cache_dir, "accuracy_checker.json")

command = f"omz_downloader --name {model_name} --cache_dir {cache_dir}"
cmd_output = subprocess.call(command, shell=True) # nosec

model_dir = os.path.join(cache_dir, model_name)
if not os.path.exists(model_dir):
    command = f"omz_converter --name {model_name} -o {os.path.join(cache_dir, model_name)}"
    cmd_output = subprocess.call(command, shell=True)  # nosec

xml_path = os.path.join(model_dir, f"public/{model_name}/FP16/{model_name}.xml")
ov_model = ov.Core().read_model(xml_path)

config = Config.read_config(dataset_config)
config.configure_params()
accuracy_checker_config = config.engine

model_evaluator = create_model_evaluator(accuracy_checker_config)
model_evaluator.load_network([{"model": ov_model}])
model_evaluator.select_dataset("")


def sequence_transform_fn(data_item):
    """
    Quantization transform function. Extracts and preprocesses sequential input data from the dataloader
    for quantization and returns an iterable over the preprocessed elements of the fed data item.

    :param data_item: Data item produced by the DataLoader during iteration.
    :return: Iterable over the preprocessed elements of the fed data item.
    """
    return data_item


def get_custom_forward(model, callback):
    def custom_forward(data_item):
        def iter_through_sequence():
            _, batch_annotation, batch_input, _ = data_item
            filled_inputs, _, _ = model_evaluator._get_batch_input(batch_input, batch_annotation)
            for filled_input in filled_inputs:
                input_data = {}
                for name, value in filled_input.items():
                    input_data[model_evaluator.launcher.input_to_tensor_name[name]] = value
                yield input_data

        model_outputs = None
        for model_inputs in iter_through_sequence():
            state_inputs = model_evaluator.launcher._fill_lstm_inputs(model_outputs)
            model_inputs.update(state_inputs)
            model_outputs = model(model_inputs)
            callback(model_outputs)

    return custom_forward


dataset = nncf.CustomInferenceDataset(model_evaluator.dataset, sequence_transform_fn, get_custom_forward)


quantized_model = nncf.quantize(ov_model, dataset, subset_size=3)
Empty file.
64 changes: 64 additions & 0 deletions examples/post_training_quantization/openvino/tiny_gpt2/main.py
Owner Author:

@AlexKoff88 - main file
@@ -0,0 +1,64 @@
from optimum.intel.openvino import OVModelForCausalLM
from transformers import AutoTokenizer

import nncf

GENERATION_LENGTH = 20


def fix_ov_model_names_duplicates(ov_model):
    names = set()
    for op in ov_model.get_ops():
        friendly_name = op.get_friendly_name()
        while True:
            if friendly_name not in names:
                break
            friendly_name += "_"
        names.add(friendly_name)
        op.set_friendly_name(friendly_name)


model_id = "hf-internal-testing/tiny-random-gpt2"
# model_id = "hf-internal-testing/tiny-random-GPTNeoModel"
# model_id = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt")

model_with_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)


def set_ov_model_in_hf_model(hf_model, ov_model):
    hf_model.model = ov_model
    hf_model.request = ov_model.create_infer_request()

Comment: I assume that ov_model has a type of ov::Model. If so, .create_infer_request() works only for CompiledModel.

Owner Author: You are right.
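A minimal sketch of what the corrected helper could look like, assuming the standard openvino.runtime API and an illustrative CPU target (the device name and variable names are assumptions, not part of the original change):

    import openvino.runtime as ov

    def set_ov_model_in_hf_model(hf_model, ov_model):
        # An ov.Model must be compiled before an infer request can be created
        compiled_model = ov.Core().compile_model(ov_model, "CPU")
        hf_model.model = ov_model
        hf_model.request = compiled_model.create_infer_request()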



def get_custom_forward(ov_model, callback_fn):
    hf_model = model_with_pkv
    set_ov_model_in_hf_model(hf_model, ov_model)

    def _callback_fn(info):
        outputs = {k: v for k, v in zip(info["infer_request"].model_outputs, info["infer_request"].outputs)}

Comment: Does the InferRequest object have .model_outputs property?
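If .model_outputs turns out not to be available, a hedged alternative sketch (assuming the standard openvino.runtime API, where InferRequest.results maps each model output to its computed value) could be:

    def _callback_fn(info):
        request = info["infer_request"]
        # results is a dict of {ConstOutput: ndarray} from the last inference
        outputs = {out.get_any_name(): value for out, value in request.results.items()}
        callback_fn(outputs)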

        callback_fn(outputs)

    hf_model.request.set_callback(_callback_fn, {"infer_request": hf_model.request})

    def custom_forward(dataitem):
        hf_model.generate(**dataitem, min_length=GENERATION_LENGTH, max_length=GENERATION_LENGTH, num_beams=1)

    return custom_forward


def transform_fn(data_item):
    return data_item


dataset = nncf.CustomInferenceDataset([tokens] * 10, transform_fn, get_custom_forward)

I don't think we should make get_custom_forward a part of Dataset API. I propose:

  • rename it to get_forward_fn(model: ov.Model, output_processing_callback: Callable) -> Callable
  • make it an optional argument of nncf.quantize() API
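A rough sketch of how the proposed optional argument might look at the call site (hypothetical keyword name and signature, not the current nncf.quantize() API):

    # get_forward_fn(model, output_processing_callback) -> Callable, as proposed above
    dataset = nncf.Dataset([tokens] * 10, transform_fn)
    quantized_model = nncf.quantize(
        model_with_pkv.model,
        dataset,
        forward_fn=get_forward_fn,  # hypothetical optional argument
        subset_size=3,
    )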

@alexsu52 commented on May 31, 2023:

I absolutely agree that it should not be part of Dataset API.

Comments from my side:
I have some concerns about get_forward_fn:

  1. output_processing_callback is not needed for Torch and Keras TF models. It can confuse developers because they will call an output_processing_callback that does not actually do anything.
  2. The signature of output_processing_callback is not clear across different frameworks.

Proposal:

  1. Introduce get_forward_fn(model: ov.Model) -> Callable for Torch and Keras TF, and get_forward_fn(model: ov.Model, statistic_aggregator: StatisticsAggregator) -> Callable for OpenVINO, ONNX and TF.

Pros:

  • This addresses 1 by explicitly introducing different signatures for different frameworks, because different frameworks collect statistics using different approaches.

  • This addresses 2 because methods of a class can be easily documented, plus you get sugar from the IDE. It can also provide several interfaces to register model outputs: statistic_collector.register_model_output(name, tensor) and statistic_collector.register_model_outputs(outputs: Dict[str, tensor]).

    Requested changes:

def get_custom_forward(ov_model, statistic_aggregator):
    hf_model = model_with_pkv
    set_ov_model_in_hf_model(hf_model, ov_model)

    def _callback_fn(info):
        outputs = {k.key.get_any_name(): v.value for k, v in zip(info["infer_request"].model_outputs, info["infer_request"].outputs)}
        statistic_aggregator.register_model_outputs(outputs)
  2. Introduce different classes to join the framework model and the custom forward function for each framework. For example, nncf.OVModelWithCustomForward(model: ov.Model, get_forward_fn: Callable) for OV.

Pros:

  • nncf.quantize and nncf.quantize_with_accuracy_control can be used without extending their signatures.
  • The class explicitly specifies the signature of get_forward_fn for the framework model.
  • Easy reuse in other algorithms
ov_model_with_custom_forward = nncf.OVModelWithCustomForward(model_with_pkv.model, get_forward_fn)
quantized_model_with_custom_forward = nncf.quantize(ov_model_with_custom_forward, dataset, subset_size=3)
  3. IMHO: rename get_forward_fn -> make_forward_fn

To tell the truth, I am still skeptical about the whole approach of collecting recurrent states and how applicable it is to other models. I am now looking at the Whisper notebook, and I would not use this API there, since the proposed API requires much more effort and code rewriting.



# Fix duplicated names in the OV model:
fix_ov_model_names_duplicates(model_with_pkv.model)
quantized_model = nncf.quantize(model_with_pkv.model, dataset, subset_size=3)

model_with_pkv.model = quantized_model
model_with_pkv.request = None
1 change: 1 addition & 0 deletions nncf/__init__.py
@@ -14,6 +14,7 @@
from nncf.common.logging.logger import disable_logging
from nncf.common.logging.logger import set_log_level
from nncf.config import NNCFConfig
from nncf.data import CustomInferenceDataset
from nncf.data import Dataset
from nncf.parameters import DropType
from nncf.parameters import ModelType
2 changes: 1 addition & 1 deletion nncf/common/factory.py
@@ -35,7 +35,7 @@ def create(model: TModel) -> NNCFGraph:
        from nncf.onnx.graph.nncf_graph_builder import GraphConverter

        return GraphConverter.create_nncf_graph(model)
-   if model_backend == BackendType.OPENVINO:
+   if model_backend in [BackendType.OPENVINO]:
        from nncf.openvino.graph.nncf_graph_builder import GraphConverter

        return GraphConverter.create_nncf_graph(model)