diff --git a/pyannote/audio/utils/receptive_field.py b/pyannote/audio/utils/receptive_field.py index 324c60f16..0e484e4ad 100644 --- a/pyannote/audio/utils/receptive_field.py +++ b/pyannote/audio/utils/receptive_field.py @@ -20,8 +20,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. - -import math from typing import List @@ -52,9 +50,7 @@ def conv1d_num_frames( ------ https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d """ - return math.floor( - 1 + (num_samples + 2 * padding - dilation * (kernel_size - 1) - 1) / stride - ) + return 1 + (num_samples + 2 * padding - dilation * (kernel_size - 1) - 1) // stride def multi_conv_num_frames( @@ -105,6 +101,7 @@ def multi_conv_receptive_field_size( dilation: List[int] = None, ) -> int: receptive_field_size = num_frames + for k, s, d in reversed(list(zip(kernel_size, stride, dilation))): receptive_field_size = conv1d_receptive_field_size( num_frames=receptive_field_size, @@ -140,7 +137,7 @@ def conv1d_receptive_field_center( """ effective_kernel_size = 1 + (kernel_size - 1) * dilation - return frame * stride + (effective_kernel_size - 1) / 2 - padding + return frame * stride + (effective_kernel_size - 1) // 2 - padding def multi_conv_receptive_field_center(