Skip to content

Commit

Permalink
Remove private Keras imports.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565742783
Change-Id: I970b9ee988f5f520d6d12b42fcfa8249f7fadc2e
  • Loading branch information
fchollet authored and copybara-github committed Sep 15, 2023
1 parent 9237aab commit 446d03b
Showing 1 changed file with 65 additions and 13 deletions.
78 changes: 65 additions & 13 deletions qkeras/qconvolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from __future__ import print_function
import warnings

from keras.utils import conv_utils
import tensorflow as tf
from tensorflow.keras import constraints
from tensorflow.keras import initializers
Expand All @@ -43,6 +42,59 @@
from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer


def deconv_output_length(
input_length,
filter_size,
padding,
output_padding=None,
stride=0,
dilation=1,
):
"""Determines output length of a transposed convolution given input length.
Args:
input_length: Integer.
filter_size: Integer.
padding: one of `"same"`, `"valid"`, `"full"`.
output_padding: Integer, amount of padding along the output dimension.
Can be set to `None` in which case the output length is inferred.
stride: Integer.
dilation: Integer.
Returns:
The output length (integer).
"""
assert padding in {"same", "valid", "full"}
if input_length is None:
return None

# Get the dilated kernel size
filter_size = filter_size + (filter_size - 1) * (dilation - 1)
pad = 0
length = 0

# Infer length if output padding is None, else compute the exact length
if output_padding is None:
if padding == "valid":
length = input_length * stride + max(filter_size - stride, 0)
elif padding == "full":
length = input_length * stride - (stride + filter_size - 2)
elif padding == "same":
length = input_length * stride
else:
if padding == "same":
pad = filter_size // 2
elif padding == "valid":
pad = 0
elif padding == "full":
pad = filter_size - 1

length = (
(input_length - 1) * stride + filter_size - 2 * pad + output_padding
)
return length


class QConv1D(Conv1D, PrunableLayer):
"""1D convolution layer (e.g. spatial convolution over images)."""

Expand Down Expand Up @@ -448,18 +500,18 @@ def call(self, inputs):
out_pad_h, out_pad_w = self.output_padding

# Infer the dynamic output shape:
out_height = conv_utils.deconv_output_length(height,
kernel_h,
padding=self.padding,
output_padding=out_pad_h,
stride=stride_h,
dilation=self.dilation_rate[0])
out_width = conv_utils.deconv_output_length(width,
kernel_w,
padding=self.padding,
output_padding=out_pad_w,
stride=stride_w,
dilation=self.dilation_rate[1])
out_height = deconv_output_length(height,
kernel_h,
padding=self.padding,
output_padding=out_pad_h,
stride=stride_h,
dilation=self.dilation_rate[0])
out_width = deconv_output_length(width,
kernel_w,
padding=self.padding,
output_padding=out_pad_w,
stride=stride_w,
dilation=self.dilation_rate[1])
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_height, out_width)
else:
Expand Down

0 comments on commit 446d03b

Please sign in to comment.