Skip to content

Commit

Permalink
Compatibility with older pytorch versions
Browse files Browse the repository at this point in the history
  • Loading branch information
jbojar committed Aug 11, 2021
1 parent 2bfba87 commit 0002278
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions torch_dct/_dct.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,30 @@
import numpy as np
import torch
import torch.nn as nn
import torch.fft

try:
# PyTorch 1.7.0 and newer versions
import torch.fft

def dct1_rfft_impl(x):
return torch.view_as_real(torch.fft.rfft(x, dim=1))

def dct_fft_impl(v):
return torch.view_as_real(torch.fft.fft(v, dim=1))

def idct_irfft_impl(V):
return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
except ImportError:
# PyTorch 1.6.0 and older versions
def dct1_rfft_impl(x):
return torch.rfft(x, 1)

def dct_fft_impl(v):
return torch.rfft(v, 1, onesided=False)

def idct_irfft_impl(V):
return torch.irfft(V, 1, onesided=False)



def dct1(x):
Expand All @@ -13,8 +36,9 @@ def dct1(x):
"""
x_shape = x.shape
x = x.view(-1, x_shape[-1])
x = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)

return torch.view_as_real(torch.fft.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1), dim=1))[:, :, 0].view(*x_shape)
return dct1_rfft_impl(x)[:, :, 0].view(*x_shape)


def idct1(X):
Expand Down Expand Up @@ -47,7 +71,7 @@ def dct(x, norm=None):

v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
Vc = dct_fft_impl(v)

k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
W_r = torch.cos(k)
Expand Down Expand Up @@ -99,7 +123,7 @@ def idct(X, norm=None):

V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)

v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
v = idct_irfft_impl(V)
x = v.new_zeros(v.shape)
x[:, ::2] += v[:, :N - (N // 2)]
x[:, 1::2] += v.flip([1])[:, :N // 2]
Expand Down

0 comments on commit 0002278

Please sign in to comment.