Commit 58b8c3f (1 parent: a0816d9)
PiperOrigin-RevId: 559532322 Change-Id: I04d210926dbd37a8b3118104433a8f8b416cfd9b
Showing 4 changed files with 266 additions and 0 deletions.
@@ -0,0 +1,163 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.keras import constraints
from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer

from .qlayers import get_auto_range_constraint_initializer
from .quantizers import get_quantizer


# QKeras needs to support more layers for matrix multiplication and shift
# operations, such as in Transformer. All such layers should be placed here.


class QScaleShift(tf.keras.layers.Layer, PrunableLayer):
  """Quantized scale and shift layer.

  output = scale * x + bias, where scale and bias each have shape (1, 1).

  QScaleShift is similar to the special case of QDepthwiseConv2D where
  kernel_size=(1, 1). However, there are several differences:

  1) There is no concept of padding or striding in QScaleShift, since it
     is not a conv layer;
  2) QDepthwiseConv2D expects min_ndim=4 for the input shape, while the
     QScaleShift input can have any shape;
  3) In QDepthwiseConv2D each output channel has its own weight value,
     while QScaleShift shares the same weight across the entire input
     tensor;
  4) Since it is not a conv operation, the hardware implementations of
     QScaleShift and QDepthwiseConv2D are fundamentally different.
     Therefore it makes sense to separate them as two different types of
     layers.
  """

  def __init__(self,
               weight_quantizer=None,
               bias_quantizer=None,
               use_bias=True,
               activation=None,
               weight_initializer="he_normal",
               weight_regularizer=None,
               bias_initializer="zeros",
               bias_regularizer=None,
               **kwargs):
    super().__init__(**kwargs)

    self.use_bias = use_bias
    self.weight_regularizer = weight_regularizer
    self.bias_regularizer = bias_regularizer

    self.weight_quantizer = weight_quantizer
    self.bias_quantizer = bias_quantizer

    self.weight_quantizer_internal = get_quantizer(self.weight_quantizer)
    self.bias_quantizer_internal = get_quantizer(self.bias_quantizer)

    _, self.weight_initializer = get_auto_range_constraint_initializer(
        self.weight_quantizer_internal, None, weight_initializer)

    _, self.bias_initializer = get_auto_range_constraint_initializer(
        self.bias_quantizer_internal, None, bias_initializer)

    # Optimize the parameter set to "auto" scaling mode if possible.
    if hasattr(self.weight_quantizer_internal, "_set_trainable_parameter"):
      self.weight_quantizer_internal._set_trainable_parameter()
    if hasattr(self.bias_quantizer_internal, "_set_trainable_parameter"):
      self.bias_quantizer_internal._set_trainable_parameter()

    self.quantizers = [self.weight_quantizer_internal,
                       self.bias_quantizer_internal]

    self.activation = get_quantizer(activation)

  def build(self, input_shape):
    self.weight = self.add_weight(
        name="weight", shape=(1, 1), dtype="float32",
        initializer=self.weight_initializer,
        regularizer=self.weight_regularizer, trainable=True)

    if self.use_bias:
      self.bias = self.add_weight(
          name="bias", shape=(1, 1), dtype="float32",
          initializer=self.bias_initializer, regularizer=self.bias_regularizer,
          trainable=True)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    quantized_weight = (
        self.weight_quantizer_internal(self.weight)
        if self.weight_quantizer_internal is not None else self.weight)

    outputs = tf.math.multiply(inputs, quantized_weight)

    if self.use_bias:
      quantized_bias = (
          self.bias_quantizer_internal(self.bias)
          if self.bias_quantizer_internal is not None else self.bias)

      outputs = quantized_bias + outputs

    return self.activation(outputs) if self.activation is not None else outputs

  def get_config(self):
    config = {
        "weight_quantizer": constraints.serialize(
            self.weight_quantizer_internal),
        "bias_quantizer": constraints.serialize(
            self.bias_quantizer_internal),
        "weight_initializer": constraints.serialize(self.weight_initializer),
        "bias_initializer": constraints.serialize(self.bias_initializer),
        "activation": constraints.serialize(self.activation),
        "use_bias": self.use_bias,
        "weight_regularizer": constraints.serialize(self.weight_regularizer),
        "bias_regularizer": constraints.serialize(self.bias_regularizer),
    }
    base_config = super().get_config()
    base_config.update(config)
    return base_config

  def get_quantization_config(self):
    return {
        "weight_quantizer": str(self.weight_quantizer_internal),
        "bias_quantizer": str(self.bias_quantizer_internal),
        "activation": str(self.activation),
    }

  def get_quantizers(self):
    return self.quantizers

  def get_prunable_weights(self):
    # self.bias is None when use_bias=False, so exclude it in that case.
    return [w for w in (self.weight, self.bias) if w is not None]
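For reference, here is a minimal usage sketch of the new layer. It is not part of the commit; it assumes qkeras is installed and exports QScaleShift, and it reuses the quantizer settings and weight values from the test file below.

import numpy as np
import tensorflow as tf
from qkeras import QScaleShift

layer = QScaleShift(weight_quantizer="quantized_bits(8,2,alpha=1.0)",
                    bias_quantizer="quantized_bits(8,2,alpha=1.0)",
                    activation="quantized_bits(8,4,alpha=1.0)")
x = tf.constant([[1.0, 1.0], [2.0, 2.0]])
_ = layer(x)  # The first call builds the (1, 1) weight and bias variables.
layer.set_weights([np.array([[1.0]]), np.array([[4.0]])])
print(layer(x).numpy())  # Expected [[5. 5.] [6. 6.]], matching the test below.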
@@ -0,0 +1,100 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the QScaleShift layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile

import numpy as np
from numpy.testing import assert_allclose
from numpy.testing import assert_equal
import pytest
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

from qkeras import QScaleShift
from qkeras.utils import load_qmodel


def create_qmac_model(layer_cls,
                      kwargs=None,
                      input_data=None,
                      weight_data=None):
  """Create a QMAC model for test purposes."""
  layer = layer_cls(**kwargs)
  x = Input(shape=input_data.shape[1:], dtype=input_data.dtype)
  y = layer(x)
  layer.set_weights(weight_data)

  return Model(x, y)


@pytest.mark.parametrize(
    'layer_kwargs, input_data, weight_data, bias_data, expected_output',
    [
        (
            {
                'weight_quantizer': 'quantized_bits(8,2,alpha=1.0)',
                'bias_quantizer': 'quantized_bits(8,2,alpha=1.0)',
                'activation': 'quantized_bits(8,4,alpha=1.0)'
            },
            np.array([[1, 1], [2, 2]], dtype=K.floatx()),
            np.array([[1.0]]),
            np.array([[4.0]]),
            np.array([[5, 5], [6, 6]], dtype=K.floatx())),
    ])
def test_qmac(layer_kwargs, input_data, weight_data, bias_data,
              expected_output):
  model = create_qmac_model(
      layer_cls=QScaleShift,
      kwargs=layer_kwargs,
      input_data=input_data,
      weight_data=[weight_data, bias_data])

  actual_output = model.predict(input_data)
  assert_allclose(actual_output, expected_output, rtol=1e-4)

  # Test model saving and loading.
  fd, fname = tempfile.mkstemp('.h5')
  model.save(fname)

  # Load the model.
  loaded_model = load_qmodel(fname)

  # Clean up the h5 file after loading the model.
  os.close(fd)
  os.remove(fname)

  # Compare the weights of the original and loaded models.
  model_weights = model.weights
  loaded_model_weights = loaded_model.weights

  assert_equal(len(model_weights), len(loaded_model_weights))
  for i, model_weight in enumerate(model_weights):
    assert_equal(model_weight.numpy(), loaded_model_weights[i].numpy())

  # Check that the loaded model produces the same predictions as the original.
  loaded_model_output = loaded_model.predict(input_data)
  assert_equal(actual_output, loaded_model_output)


if __name__ == '__main__':
  pytest.main([__file__])
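As a hand-check of the expected output in the parametrization above (an informal sketch, assuming qkeras' usual quantized_bits convention of 1 sign bit, `integer` integer bits, and the remaining bits as fraction):

# quantized_bits(8,2,alpha=1.0) has step 2**-5 = 1/32 and saturates at
# 4 - 1/32, so the weight 1.0 is exact while the bias 4.0 clips to 3.96875.
pre_act = np.array([[1.0, 1.0], [2.0, 2.0]]) * 1.0 + (4.0 - 1.0 / 32)
# The activation quantized_bits(8,4,alpha=1.0) has step 1/8, which rounds
# 4.96875 -> 5.0 and 5.96875 -> 6.0, reproducing the expected [[5, 5], [6, 6]].
post_act = np.round(pre_act * 8) / 8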