forked from polimi-ispl/GAN-image-detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gan_vs_real_detector.py
134 lines (100 loc) · 4.65 KB
/
gan_vs_real_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
from collections import OrderedDict
import numpy as np
import torch
torch.manual_seed(21)
import random
random.seed(21)
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import albumentations as A
import albumentations.pytorch as Ap
from utils import architectures
from utils.python_patch_extractor.PatchExtractor import PatchExtractor
from PIL import Image
import argparse
class Detector:
def __init__(self):
# You need to download the weight.zip file from here https://www.dropbox.com/s/g1z2u8wl6srjh6v/weigths.zip?dl=0
# and uncompress it into the main folder.
self.weights_path_list = [os.path.join('weights', f'method_{x}.pth') for x in 'ABCDE']
# GPU configuration if available
self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
self.nets = []
for i, l in enumerate('ABCDE'):
# Instantiate and load network
network_class = getattr(architectures, 'EfficientNetB4')
net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
print(f'Loading model {l}...')
state_tmp = torch.load(self.weights_path_list[i], map_location='cpu')
if 'net' not in state_tmp.keys():
state = OrderedDict({'net': OrderedDict()})
[state['net'].update({'model.{}'.format(k): v}) for k, v in state_tmp.items()]
else:
state = state_tmp
incomp_keys = net.load_state_dict(state['net'], strict=True)
print(incomp_keys)
print('Model loaded!\n')
self.nets += [net]
net_normalizer = net.get_normalizer() # pick normalizer from last network
transform = [
A.Normalize(mean=net_normalizer.mean, std=net_normalizer.std),
Ap.transforms.ToTensorV2()
]
self.trans = A.Compose(transform)
self.cropper = A.RandomCrop(width=128, height=128, always_apply=True, p=1.)
self.criterion = torch.nn.CrossEntropyLoss(reduction='none')
def synth_real_detector(self, img_path: str, n_patch: int = 200):
# Load image:
img = np.asarray(Image.open(img_path))
# Opt-out if image is non conforming
if img.shape == ():
print('{} None dimension'.format(img_path))
return None
if img.shape[0] < 128 or img.shape[1] < 128:
print('Too small image')
return None
if img.ndim != 3:
print('RGB images only')
return None
if img.shape[2] > 3:
print('Omitting alpha channel')
img = img[:, :, :3]
print('Computing scores...')
img_net_scores = []
for net_idx, net in enumerate(self.nets):
if net_idx == 0:
# only for detector A, extract N = 200 random patches per image
patch_list = [self.cropper(image=img)['image'] for _ in range(n_patch)]
else:
# for detectors B, C, D, E, extract patches aligned with the 8 x 8 pixel grid:
# we want more or less 200 patches per img
stride_0 = ((((img.shape[0] - 128) // 20) + 7) // 8) * 8
stride_1 = (((img.shape[1] - 128) // 10 + 7) // 8) * 8
pe = PatchExtractor(dim=(128, 128, 3), stride=(stride_0, stride_1, 3))
patches = pe.extract(img)
patch_list = list(patches.reshape((patches.shape[0] * patches.shape[1], 128, 128, 3)))
# Normalization
transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]
# Compute scores
transf_patch_tensor = torch.stack(transf_patch_list, dim=0).to(self.device)
with torch.no_grad():
patch_scores = net(transf_patch_tensor).cpu().numpy()
patch_predictions = np.argmax(patch_scores, axis=1)
maj_voting = np.any(patch_predictions).astype(int)
scores_maj_voting = patch_scores[:, maj_voting]
img_net_scores.append(np.nanmax(scores_maj_voting) if maj_voting == 1 else -np.nanmax(scores_maj_voting))
# final score is the average among the 5 scores returned by the detectors
img_score = np.mean(img_net_scores)
return img_score
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--img_path', help='Path to the test image', required=True)
args = parser.parse_args()
img_path = args.img_path
detector = Detector()
score = detector.synth_real_detector(img_path)
print('Image Score: {}'.format(score))
return 0
if __name__ == '__main__':
main()