get_fft_feature.py
import numpy as np
import torch
from torch.autograd import Variable
from PIL import Image
from torchvision import transforms

def RGB2gray(rgb):
    # Convert a batch of RGB images (B, 3, H, W) to grayscale using the
    # standard ITU-R BT.601 luma weights; pass single-channel input through.
    if rgb.size(1) == 3:
        r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray
    elif rgb.size(1) == 1:
        return rgb[:, 0, :, :]

def shift(x):
    # Swap the quadrants of the last two dimensions so the zero-frequency
    # component moves to the centre of the spectrum (like np.fft.fftshift).
    # zeros_like keeps the input's device and dtype.
    out = torch.zeros_like(x)
    H, W = x.size(-2), x.size(-1)
    out[:, :H // 2, :W // 2] = x[:, H // 2:, W // 2:]
    out[:, :H // 2, W // 2:] = x[:, H // 2:, :W // 2]
    out[:, H // 2:, :W // 2] = x[:, :H // 2, W // 2:]
    out[:, H // 2:, W // 2:] = x[:, :H // 2, :W // 2]
    return out

def azimuthalAverage(image, center=None):
    """
    Calculate the azimuthally averaged radial profile.

    image  - The 2D image.
    center - The [x, y] pixel coordinates used as the center. The default is
             None, which uses the center of the image (including fractional
             pixels).
    """
    # Build a radius map: the distance of every pixel from the center.
    H, W = image.shape[0], image.shape[1]
    y, x = np.indices([H, W])
    if center is None:
        center = [W / 2, H / 2]
    radius = np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2)
    radius = radius.astype(int).ravel()

    # Average the pixel values that fall into each integer radius bin.
    nr = np.bincount(radius)
    tbin = np.bincount(radius, image.ravel())
    radial_prof = tbin / (nr + 1e-10)
    return radial_prof[1:-2]  # Drop the last two extremely high-frequency bins to avoid noise.
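
# Illustrative note (an assumption for clarity, not part of the original file):
# for a constant image such as np.ones((8, 8)), every kept radius bin averages
# identical pixel values, so azimuthalAverage returns a profile of (approximately)
# ones. For a real magnitude spectrum, the profile describes how energy decays
# from low to high spatial frequency.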

def get_fft_feature(x):
    x_rgb = x.detach()
    epsilon = 1e-8
    x_gray = RGB2gray(x_rgb)
    # Legacy FFT API (PyTorch < 1.8); returns a (B, H, W, 2) real/imaginary tensor.
    fft = torch.rfft(x_gray, 2, onesided=False)
    fft += epsilon
    # Log-magnitude spectrum, shifted so the DC component sits at the centre.
    magnitude_spectrum = torch.log(torch.sqrt(fft[:, :, :, 0] ** 2 + fft[:, :, :, 1] ** 2 + 1e-10) + 1e-10)
    magnitude_spectrum = shift(magnitude_spectrum)
    magnitude_spectrum = magnitude_spectrum.cpu().numpy()
    # Reduce each 2D spectrum to a 1D azimuthally averaged radial profile.
    out = []
    for i in range(magnitude_spectrum.shape[0]):
        out.append(torch.from_numpy(azimuthalAverage(magnitude_spectrum[i])).float().unsqueeze(0))
    out = torch.cat(out, dim=0)
    # Min-max normalise each profile to [0, 1].
    out = (out - torch.min(out, dim=1, keepdim=True)[0]) / (torch.max(out, dim=1, keepdim=True)[0] - torch.min(out, dim=1, keepdim=True)[0])
    out = Variable(out, requires_grad=True).to(x.device)
    return out
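
# Minimal usage sketch (not part of the original file): the batch size, image
# size, and random input below are assumptions chosen purely for illustration.
if __name__ == "__main__":
    # A fake batch of 4 RGB images, 128x128, with values in [0, 1].
    dummy = torch.rand(4, 3, 128, 128)
    feat = get_fft_feature(dummy)
    # Each row is the min-max normalised 1D radial FFT profile of one image.
    print(feat.shape)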