This repository has been archived by the owner on Sep 9, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 122
/
models.py
1230 lines (1076 loc) · 44.5 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -*- coding: utf-8 -*-
# Torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
from torch.nn import init
# utils
import math
import os
import datetime
import numpy as np
import joblib
from tqdm import tqdm
from utils import grouper, sliding_window, count_sliding_window, camel_to_snake
def get_model(name, **kwargs):
    """
    Instantiate and obtain a model with adequate hyperparameters

    Args:
        name: string of the model name ("nn", "hamida", "lee", "chen", "li",
            "hu", "he", "luo", "sharma", "liu", "boulch" or "mou")
        kwargs: hyperparameters. "n_classes", "n_bands" and "ignored_labels"
            are required; every other key missing from kwargs is filled in
            with a model-specific default. The learning rate may be passed
            as either "learning_rate" or "lr" ("learning_rate" wins if both
            are given); both keys are set in the returned dict.
    Returns:
        model: PyTorch network, moved to kwargs["device"]
        optimizer: PyTorch optimizer
        criterion: PyTorch loss function, or a (classification loss,
            reconstruction loss) tuple for the semi-supervised models
            ("liu", "boulch")
        kwargs: hyperparameters with sane defaults
    Raises:
        KeyError: if `name` is unknown or a required hyperparameter is missing
    """
    device = kwargs.setdefault("device", torch.device("cpu"))
    n_classes = kwargs["n_classes"]
    n_bands = kwargs["n_bands"]
    # Per-class loss weights: ignored labels get weight 0 so they never
    # contribute to the loss. A caller-supplied "weights" entry takes priority.
    weights = torch.ones(n_classes)
    weights[torch.LongTensor(kwargs["ignored_labels"])] = 0.0
    weights = weights.to(device)
    weights = kwargs.setdefault("weights", weights)

    def _learning_rate(default):
        # Branches used to read either "learning_rate" or "lr", so a value
        # passed under the other key was silently ignored. Accept both
        # (preferring "learning_rate") and record the value under both keys.
        lr = kwargs.get("learning_rate", kwargs.get("lr", default))
        kwargs.setdefault("learning_rate", lr)
        kwargs.setdefault("lr", lr)
        return lr

    if name == "nn":
        kwargs.setdefault("patch_size", 1)
        center_pixel = True
        model = Baseline(n_bands, n_classes, kwargs.setdefault("dropout", False))
        lr = _learning_rate(0.0001)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss(weight=kwargs["weights"])
        kwargs.setdefault("epoch", 100)
        kwargs.setdefault("batch_size", 100)
    elif name == "hamida":
        patch_size = kwargs.setdefault("patch_size", 5)
        center_pixel = True
        model = HamidaEtAl(n_bands, n_classes, patch_size=patch_size)
        lr = _learning_rate(0.01)
        optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.0005)
        kwargs.setdefault("batch_size", 100)
        criterion = nn.CrossEntropyLoss(weight=kwargs["weights"])
    elif name == "lee":
        kwargs.setdefault("epoch", 200)
        patch_size = kwargs.setdefault("patch_size", 5)
        center_pixel = False
        model = LeeEtAl(n_bands, n_classes)
        lr = _learning_rate(0.001)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss(weight=kwargs["weights"])
    elif name == "chen":
        patch_size = kwargs.setdefault("patch_size", 27)
        center_pixel = True
        model = ChenEtAl(n_bands, n_classes, patch_size=patch_size)
        lr = _learning_rate(0.003)
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss(weight=kwargs["weights"])
        kwargs.setdefault("epoch", 400)
        kwargs.setdefault("batch_size", 100)
    elif name == "li":
        patch_size = kwargs.setdefault("patch_size", 5)
        center_pixel = True
        model = LiEtAl(n_bands, n_classes, n_planes=16, patch_size=patch_size)
        lr = _learning_rate(0.01)
        optimizer = optim.SGD(
            model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005
        )
        kwargs.setdefault("epoch", 200)
        criterion = nn.CrossEntropyLoss(weight=kwargs["weights"])
        # NOTE: a MultiStepLR(milestones=[epoch // 2, (5 * epoch) // 6],
        # gamma=0.1) schedule was once considered here.
    elif name == "hu":
        kwargs.setdefault("patch_size", 1)
        center_pixel = True
        model = HuEtAl(n_bands, n_classes)
        # From what I infer from the paper (Eq.7 and Algorithm 1), it is standard SGD with lr = 0.01
        lr = _learning_rate(0.01)
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss(weight=kwargs["weights"])
        kwargs.setdefault("epoch", 100)
        kwargs.setdefault("batch_size", 100)
    elif name == "he":
        # We train our model by AdaGrad [18] algorithm, in which
        # the base learning rate is 0.01. In addition, we set the batch
        # as 40, weight decay as 0.01 for all the layers
        # The input of our network is the HSI 3D patch in the size of 7×7×Band
        kwargs.setdefault("patch_size", 7)
        kwargs.setdefault("batch_size", 40)
        lr = _learning_rate(0.01)
        center_pixel = True
        model = HeEtAl(n_bands, n_classes, patch_size=kwargs["patch_size"])
        # For Adagrad, we need to load the model on GPU before creating the optimizer
        model = model.to(device)
        optimizer = optim.Adagrad(model.parameters(), lr=lr, weight_decay=0.01)
        criterion = nn.CrossEntropyLoss(weight=kwargs["weights"])
    elif name == "luo":
        # All the experiments are settled by the learning rate of 0.1,
        # the decay term of 0.09 and batch size of 100.
        kwargs.setdefault("patch_size", 3)
        kwargs.setdefault("batch_size", 100)
        lr = _learning_rate(0.1)
        center_pixel = True
        model = LuoEtAl(n_bands, n_classes, patch_size=kwargs["patch_size"])
        optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.09)
        criterion = nn.CrossEntropyLoss(weight=kwargs["weights"])
    elif name == "sharma":
        # We train our S-CNN from scratch using stochastic gradient descent with
        # momentum set to 0.9, weight decay of 0.0005, and with a batch size
        # of 60. We initialize an equal learning rate for all trainable layers
        # to 0.05, which is manually decreased by a factor of 10 when the validation
        # error stopped decreasing. Prior to the termination the learning rate was
        # reduced two times at 15th and 25th epoch. [...]
        # We trained the network for 30 epochs
        kwargs.setdefault("batch_size", 60)
        epoch = kwargs.setdefault("epoch", 30)
        lr = _learning_rate(0.05)
        center_pixel = True
        # We assume patch_size = 64
        kwargs.setdefault("patch_size", 64)
        model = SharmaEtAl(n_bands, n_classes, patch_size=kwargs["patch_size"])
        optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.0005)
        criterion = nn.CrossEntropyLoss(weight=kwargs["weights"])
        # Only build the scheduler when the caller did not provide one
        # (setdefault would construct it eagerly either way).
        if "scheduler" not in kwargs:
            kwargs["scheduler"] = optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=[epoch // 2, (5 * epoch) // 6], gamma=0.1
            )
    elif name == "liu":
        kwargs["supervision"] = "semi"
        # "The learning rate is set to 0.001 empirically. The number of epochs is set to be 40."
        kwargs.setdefault("epoch", 40)
        lr = _learning_rate(0.001)
        center_pixel = True
        patch_size = kwargs.setdefault("patch_size", 9)
        model = LiuEtAl(n_bands, n_classes, patch_size)
        optimizer = optim.SGD(model.parameters(), lr=lr)
        # "The unsupervised cost is the squared error of the difference"
        # The reconstruction target is the spectrum of the patch's center pixel.
        criterion = (
            nn.CrossEntropyLoss(weight=kwargs["weights"]),
            lambda rec, data: F.mse_loss(
                rec, data[:, :, :, patch_size // 2, patch_size // 2].squeeze()
            ),
        )
    elif name == "boulch":
        kwargs["supervision"] = "semi"
        kwargs.setdefault("patch_size", 1)
        kwargs.setdefault("epoch", 100)
        lr = _learning_rate(0.001)
        center_pixel = True
        model = BoulchEtAl(n_bands, n_classes)
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = (
            nn.CrossEntropyLoss(weight=kwargs["weights"]),
            lambda rec, data: F.mse_loss(rec, data.squeeze()),
        )
    elif name == "mou":
        kwargs.setdefault("patch_size", 1)
        center_pixel = True
        kwargs.setdefault("epoch", 100)
        # "The RNN was trained with the Adadelta algorithm [...] We made use of a
        # fairly high learning rate of 1.0 instead of the relatively low
        # default of 0.002 to train the network"
        lr = _learning_rate(1.0)
        model = MouEtAl(n_bands, n_classes)
        # For Adadelta, we need to load the model on GPU before creating the optimizer
        model = model.to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss(weight=kwargs["weights"])
    else:
        raise KeyError("{} model is unknown.".format(name))

    model = model.to(device)
    epoch = kwargs.setdefault("epoch", 100)
    # Build the default scheduler only when none was provided (by the caller
    # or a branch above); setdefault() would construct an unused scheduler
    # bound to the optimizer.
    if "scheduler" not in kwargs:
        kwargs["scheduler"] = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.1, patience=epoch // 4, verbose=True
        )
    kwargs.setdefault("batch_size", 100)
    kwargs.setdefault("supervision", "full")
    kwargs.setdefault("flip_augmentation", False)
    kwargs.setdefault("radiation_augmentation", False)
    kwargs.setdefault("mixture_augmentation", False)
    kwargs["center_pixel"] = center_pixel
    return model, optimizer, criterion, kwargs
class Baseline(nn.Module):
    """
    Baseline network: a 4-layer fully connected classifier.

    Layer widths are input_channels -> 2048 -> 4096 -> 2048 -> n_classes,
    with a ReLU after each hidden layer and, when `dropout` is set, a
    dropout (p=0.5) after each hidden activation.
    """

    @staticmethod
    def weight_init(m):
        # He (Kaiming) normal weights and zero biases for fully connected layers
        if not isinstance(m, nn.Linear):
            return
        init.kaiming_normal_(m.weight)
        init.zeros_(m.bias)

    def __init__(self, input_channels, n_classes, dropout=False):
        super(Baseline, self).__init__()
        self.use_dropout = dropout
        if dropout:
            self.dropout = nn.Dropout(p=0.5)
        widths = (input_channels, 2048, 4096, 2048, n_classes)
        # Registers fc1 .. fc4 in order
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, "fc{}".format(idx), nn.Linear(fan_in, fan_out))
        self.apply(self.weight_init)

    def forward(self, x):
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden(x))
            if self.use_dropout:
                x = self.dropout(x)
        return self.fc4(x)
class HuEtAl(nn.Module):
    """
    Deep Convolutional Neural Networks for Hyperspectral Image Classification
    Wei Hu, Yangyu Huang, Li Wei, Fan Zhang and Hengchao Li
    Journal of Sensors, Volume 2015 (2015)
    https://www.hindawi.com/journals/js/2015/258619/
    """

    @staticmethod
    def weight_init(m):
        # [All the trainable parameters in our CNN should be initialized to
        # be a random value between −0.05 and 0.05.]
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            init.uniform_(m.weight, -0.05, 0.05)
            init.zeros_(m.bias)

    def _get_final_flattened_size(self):
        """Number of values produced by conv + pool for one spectrum."""
        with torch.no_grad():
            probe = self.pool(self.conv(torch.zeros(1, 1, self.input_channels)))
            return probe.numel()

    def __init__(self, input_channels, n_classes, kernel_size=None, pool_size=None):
        super(HuEtAl, self).__init__()
        # [In our experiments, k1 is better to be [ceil](n1/9)]
        kernel_size = (
            math.ceil(input_channels / 9) if kernel_size is None else kernel_size
        )
        # The authors recommend that k2's value is chosen so that the pooled
        # features have 30~40 values; ceil(kernel_size/5) gives the same values
        # as in the paper so let's assume it's okay
        pool_size = math.ceil(kernel_size / 5) if pool_size is None else pool_size
        self.input_channels = input_channels
        # [The first hidden convolution layer C1 filters the n1 x 1 input data with 20 kernels of size k1 x 1]
        self.conv = nn.Conv1d(1, 20, kernel_size)
        self.pool = nn.MaxPool1d(pool_size)
        self.features_size = self._get_final_flattened_size()
        # [n4 is set to be 100]
        self.fc1 = nn.Linear(self.features_size, 100)
        self.fc2 = nn.Linear(100, n_classes)
        self.apply(self.weight_init)

    def forward(self, x):
        # Drop up to two trailing singleton axes, then add a channel axis
        x = x.squeeze(dim=-1).squeeze(dim=-1).unsqueeze(1)
        # [In our design architecture, we choose the hyperbolic tangent function tanh(u)]
        x = torch.tanh(self.pool(self.conv(x)))
        x = torch.tanh(self.fc1(x.view(-1, self.features_size)))
        return self.fc2(x)
class HamidaEtAl(nn.Module):
    """
    3-D Deep Learning Approach for Remote Sensing Image Classification
    Amina Ben Hamida, Alexandre Benoit, Patrick Lambert, Chokri Ben Amar
    IEEE TGRS, 2018
    https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8344565
    """

    @staticmethod
    def weight_init(m):
        # He (Kaiming) normal weights and zero biases
        if isinstance(m, (nn.Linear, nn.Conv3d)):
            init.kaiming_normal_(m.weight)
            init.zeros_(m.bias)

    def __init__(self, input_channels, n_classes, patch_size=5, dilation=1):
        super(HamidaEtAl, self).__init__()
        self.patch_size = patch_size
        self.input_channels = input_channels
        # Dilation is applied along the first (spectral) kernel axis only
        dilation = (dilation, 1, 1)
        # The first layer is a (3,3,3) kernel sized Conv characterized
        # by a stride equal to 1 and number of neurons equal to 20.
        # 3x3 patches are zero-padded; otherwise conv2's 3x3 spatial kernel
        # would not fit the 1x1 map left by conv1.
        first_padding = 1 if patch_size == 3 else 0
        self.conv1 = nn.Conv3d(
            1, 20, (3, 3, 3), stride=(1, 1, 1), dilation=dilation, padding=first_padding
        )
        # Next pooling is applied using a layer identical to the previous one
        # with the difference of a 1D kernel size (1,1,3) and a larger stride
        # equal to 2 in order to reduce the spectral dimension
        self.pool1 = nn.Conv3d(
            20, 20, (3, 1, 1), dilation=dilation, stride=(2, 1, 1), padding=(1, 0, 0)
        )
        # Then, a duplicate of the first and second layers is created with
        # 35 hidden neurons per layer.
        self.conv2 = nn.Conv3d(
            20, 35, (3, 3, 3), dilation=dilation, stride=(1, 1, 1), padding=(1, 0, 0)
        )
        self.pool2 = nn.Conv3d(
            35, 35, (3, 1, 1), dilation=dilation, stride=(2, 1, 1), padding=(1, 0, 0)
        )
        # Finally, the 1D spatial dimension is progressively reduced
        # thanks to the use of two Conv layers, 35 neurons each,
        # with respective kernel sizes of (1,1,3) and (1,1,2) and strides
        # respectively equal to (1,1,1) and (1,1,2)
        self.conv3 = nn.Conv3d(
            35, 35, (3, 1, 1), dilation=dilation, stride=(1, 1, 1), padding=(1, 0, 0)
        )
        self.conv4 = nn.Conv3d(
            35, 35, (2, 1, 1), dilation=dilation, stride=(2, 1, 1), padding=(1, 0, 0)
        )
        self.features_size = self._get_final_flattened_size()
        # The architecture ends with a fully connected layer where the number
        # of neurons is equal to the number of input classes.
        self.fc = nn.Linear(self.features_size, n_classes)
        self.apply(self.weight_init)

    def _get_final_flattened_size(self):
        """Flattened feature count after conv4 for a single input patch."""
        stages = (self.conv1, self.pool1, self.conv2, self.pool2, self.conv3, self.conv4)
        with torch.no_grad():
            x = torch.zeros(
                (1, 1, self.input_channels, self.patch_size, self.patch_size)
            )
            for stage in stages:
                x = stage(x)
        return x[0].numel()

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = F.relu(self.conv4(F.relu(self.conv3(x))))
        return self.fc(x.view(-1, self.features_size))
class LeeEtAl(nn.Module):
    """
    CONTEXTUAL DEEP CNN BASED HYPERSPECTRAL CLASSIFICATION
    Hyungtae Lee and Heesung Kwon
    IGARSS 2016

    Fully convolutional: the output keeps the input's spatial size, i.e.
    forward() returns (N, n_classes, H, W) for an (N, 1, in_channels, H, W)
    input (the spatial convs use 'same' padding).
    """

    @staticmethod
    def weight_init(m):
        # He (Kaiming) uniform weights and zero biases
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d):
            init.kaiming_uniform_(m.weight)
            init.zeros_(m.bias)

    def __init__(self, in_channels, n_classes):
        super(LeeEtAl, self).__init__()
        # The first convolutional layer applied to the input hyperspectral
        # image uses an inception module that locally convolves the input
        # image with two convolutional filters with different sizes
        # (1x1xB and 3x3xB where B is the number of spectral bands)
        self.conv_3x3 = nn.Conv3d(
            1, 128, (in_channels, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1)
        )
        self.conv_1x1 = nn.Conv3d(
            1, 128, (in_channels, 1, 1), stride=(1, 1, 1), padding=0
        )
        # We use two modules from the residual learning approach
        # Residual block 1
        self.conv1 = nn.Conv2d(256, 128, (1, 1))
        self.conv2 = nn.Conv2d(128, 128, (1, 1))
        self.conv3 = nn.Conv2d(128, 128, (1, 1))
        # Residual block 2
        self.conv4 = nn.Conv2d(128, 128, (1, 1))
        self.conv5 = nn.Conv2d(128, 128, (1, 1))
        # The layer combination in the last three convolutional layers
        # is the same as the fully connected layers of Alexnet
        self.conv6 = nn.Conv2d(128, 128, (1, 1))
        self.conv7 = nn.Conv2d(128, 128, (1, 1))
        self.conv8 = nn.Conv2d(128, n_classes, (1, 1))
        self.lrn1 = nn.LocalResponseNorm(256)
        self.lrn2 = nn.LocalResponseNorm(128)
        # The 7 th and 8 th convolutional layers have dropout in training
        self.dropout = nn.Dropout(p=0.5)
        self.apply(self.weight_init)

    def forward(self, x):
        # Inception module. Both branches span the whole spectral axis, so
        # (assuming the input has exactly in_channels bands) the depth
        # dimension of their output is 1.
        x_3x3 = self.conv_3x3(x)
        x_1x1 = self.conv_1x1(x)
        x = torch.cat([x_3x3, x_1x1], dim=1)
        # Remove only the singleton spectral dimension. A bare torch.squeeze()
        # also dropped the batch axis for a batch of 1 (and spatial axes for
        # 1x1 patches), which broke the following 2D layers.
        x = torch.squeeze(x, dim=2)
        # Local Response Normalization
        x = F.relu(self.lrn1(x))
        # First convolution
        x = self.conv1(x)
        # Local Response Normalization
        x = F.relu(self.lrn2(x))
        # First residual block
        x_res = F.relu(self.conv2(x))
        x_res = self.conv3(x_res)
        x = F.relu(x + x_res)
        # Second residual block
        x_res = F.relu(self.conv4(x))
        x_res = self.conv5(x_res)
        x = F.relu(x + x_res)
        x = F.relu(self.conv6(x))
        x = self.dropout(x)
        x = F.relu(self.conv7(x))
        x = self.dropout(x)
        x = self.conv8(x)
        return x
class ChenEtAl(nn.Module):
    """
    DEEP FEATURE EXTRACTION AND CLASSIFICATION OF HYPERSPECTRAL IMAGES BASED ON
    CONVOLUTIONAL NEURAL NETWORKS
    Yushi Chen, Hanlu Jiang, Chunyang Li, Xiuping Jia and Pedram Ghamisi
    IEEE Transactions on Geoscience and Remote Sensing (TGRS), 2017
    """

    @staticmethod
    def weight_init(m):
        # In the beginning, the weights are randomly initialized
        # with standard deviation 0.001
        if isinstance(m, (nn.Linear, nn.Conv3d)):
            init.normal_(m.weight, std=0.001)
            init.zeros_(m.bias)

    def __init__(self, input_channels, n_classes, patch_size=27, n_planes=32):
        super(ChenEtAl, self).__init__()
        self.input_channels = input_channels
        self.n_planes = n_planes
        self.patch_size = patch_size
        # Every conv uses a 32 (spectral) x 4 x 4 (spatial) kernel; pooling
        # halves the spatial size only.
        kernel = (32, 4, 4)
        spatial_pool = (1, 2, 2)
        self.conv1 = nn.Conv3d(1, n_planes, kernel)
        self.pool1 = nn.MaxPool3d(spatial_pool)
        self.conv2 = nn.Conv3d(n_planes, n_planes, kernel)
        self.pool2 = nn.MaxPool3d(spatial_pool)
        self.conv3 = nn.Conv3d(n_planes, n_planes, kernel)
        self.features_size = self._get_final_flattened_size()
        self.fc = nn.Linear(self.features_size, n_classes)
        self.dropout = nn.Dropout(p=0.5)
        self.apply(self.weight_init)

    def _get_final_flattened_size(self):
        """Flattened feature count after conv3 for a single input patch."""
        shape = (1, 1, self.input_channels, self.patch_size, self.patch_size)
        with torch.no_grad():
            out = self.conv3(self.pool2(self.conv2(self.pool1(self.conv1(torch.zeros(shape))))))
        return out[0].numel()

    def forward(self, x):
        # conv -> ReLU -> pool -> dropout, twice
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            x = self.dropout(pool(F.relu(conv(x))))
        x = self.dropout(F.relu(self.conv3(x)))
        return self.fc(x.view(-1, self.features_size))
class LiEtAl(nn.Module):
    """
    SPECTRAL–SPATIAL CLASSIFICATION OF HYPERSPECTRAL IMAGERY
    WITH 3D CONVOLUTIONAL NEURAL NETWORK
    Ying Li, Haokui Zhang and Qiang Shen
    MDPI Remote Sensing, 2017
    http://www.mdpi.com/2072-4292/9/1/67
    """

    @staticmethod
    def weight_init(m):
        # Xavier (Glorot) uniform weights, zero biases
        if isinstance(m, (nn.Linear, nn.Conv3d)):
            init.xavier_uniform_(m.weight.data)
            init.constant_(m.bias.data, 0)

    def __init__(self, input_channels, n_classes, n_planes=2, patch_size=5):
        super(LiEtAl, self).__init__()
        self.input_channels = input_channels
        self.n_planes = n_planes
        self.patch_size = patch_size
        # The proposed 3D-CNN model has two 3D convolution layers (C1 and C2)
        # and a fully-connected layer (F1).
        # The spatial kernel size is fixed to 3 x 3 and the spectral depth is
        # 7 for C1 and 3 for C2 (Pavia University / Indian Pines settings).
        self.conv1 = nn.Conv3d(1, n_planes, (7, 3, 3), padding=(1, 0, 0))
        # C2 has twice as many kernels as C1
        self.conv2 = nn.Conv3d(n_planes, 2 * n_planes, (3, 3, 3), padding=(1, 0, 0))
        self.features_size = self._get_final_flattened_size()
        self.fc = nn.Linear(self.features_size, n_classes)
        self.apply(self.weight_init)

    def _get_final_flattened_size(self):
        """Flattened feature count after C2 for a single input patch."""
        shape = (1, 1, self.input_channels, self.patch_size, self.patch_size)
        with torch.no_grad():
            out = self.conv2(self.conv1(torch.zeros(shape)))
        return out[0].numel()

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))
        return self.fc(x.view(-1, self.features_size))
class HeEtAl(nn.Module):
    """
    MULTI-SCALE 3D DEEP CONVOLUTIONAL NEURAL NETWORK FOR HYPERSPECTRAL
    IMAGE CLASSIFICATION
    Mingyi He, Bo Li, Huahui Chen
    IEEE International Conference on Image Processing (ICIP) 2017
    https://ieeexplore.ieee.org/document/8297014/
    """

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d):
            # In-place initializer; the un-suffixed `init.kaiming_uniform` is
            # a deprecated alias that only adds a warning.
            init.kaiming_uniform_(m.weight)
            init.zeros_(m.bias)

    def __init__(self, input_channels, n_classes, patch_size=7):
        super(HeEtAl, self).__init__()
        self.input_channels = input_channels
        self.patch_size = patch_size
        self.conv1 = nn.Conv3d(1, 16, (11, 3, 3), stride=(3, 1, 1))
        # Multi-scale blocks: four parallel spectral convolutions whose
        # paddings keep the output shape equal to the input shape, so their
        # results can be summed.
        self.conv2_1 = nn.Conv3d(16, 16, (1, 1, 1), padding=(0, 0, 0))
        self.conv2_2 = nn.Conv3d(16, 16, (3, 1, 1), padding=(1, 0, 0))
        self.conv2_3 = nn.Conv3d(16, 16, (5, 1, 1), padding=(2, 0, 0))
        self.conv2_4 = nn.Conv3d(16, 16, (11, 1, 1), padding=(5, 0, 0))
        self.conv3_1 = nn.Conv3d(16, 16, (1, 1, 1), padding=(0, 0, 0))
        self.conv3_2 = nn.Conv3d(16, 16, (3, 1, 1), padding=(1, 0, 0))
        self.conv3_3 = nn.Conv3d(16, 16, (5, 1, 1), padding=(2, 0, 0))
        self.conv3_4 = nn.Conv3d(16, 16, (11, 1, 1), padding=(5, 0, 0))
        self.conv4 = nn.Conv3d(16, 16, (3, 2, 2))
        # NOTE(review): this layer is not used by forward(). It used to be a
        # MaxPool2d with a 3-D kernel, which would fail on the 5-D activations
        # of this network; MaxPool3d matches them.
        self.pooling = nn.MaxPool3d((3, 2, 2), stride=(3, 2, 2))
        # the ratio of dropout is 0.6 in our experiments
        self.dropout = nn.Dropout(p=0.6)
        self.features_size = self._get_final_flattened_size()
        self.fc = nn.Linear(self.features_size, n_classes)
        self.apply(self.weight_init)

    @staticmethod
    def _sum_branches(x, branches):
        """Apply each parallel branch to `x` and sum the results."""
        return sum(branch(x) for branch in branches)

    def _get_final_flattened_size(self):
        """Flattened feature count after conv4 for a single input patch."""
        with torch.no_grad():
            x = torch.zeros(
                (1, 1, self.input_channels, self.patch_size, self.patch_size)
            )
            x = self.conv1(x)
            x = self._sum_branches(
                x, (self.conv2_1, self.conv2_2, self.conv2_3, self.conv2_4)
            )
            x = self._sum_branches(
                x, (self.conv3_1, self.conv3_2, self.conv3_3, self.conv3_4)
            )
            x = self.conv4(x)
        return x[0].numel()

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(
            self._sum_branches(
                x, (self.conv2_1, self.conv2_2, self.conv2_3, self.conv2_4)
            )
        )
        x = F.relu(
            self._sum_branches(
                x, (self.conv3_1, self.conv3_2, self.conv3_3, self.conv3_4)
            )
        )
        x = F.relu(self.conv4(x))
        x = x.view(-1, self.features_size)
        x = self.dropout(x)
        x = self.fc(x)
        return x
class LuoEtAl(nn.Module):
    """
    HSI-CNN: A Novel Convolution Neural Network for Hyperspectral Image
    Yanan Luo, Jie Zou, Chengfei Yao, Tao Li, Gang Bai
    International Conference on Pattern Recognition 2018
    """

    @staticmethod
    def weight_init(m):
        # He (Kaiming) uniform weights and zero biases
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            init.kaiming_uniform_(m.weight)
            init.zeros_(m.bias)

    def __init__(self, input_channels, n_classes, patch_size=3, n_planes=90):
        super(LuoEtAl, self).__init__()
        self.input_channels = input_channels
        self.patch_size = patch_size
        self.n_planes = n_planes
        # the 8-neighbor pixels [...] are fed into the Conv1 convolved by n1 kernels
        # and s1 stride. Conv1 results are feature vectors each with height of and
        # the width is 1. After reshape layer, the feature vectors becomes an image-like
        # 2-dimension data.
        # Conv2 has 64 kernels size of 3x3, with stride s2.
        # After that, the 64 results are drawn into a vector as the input of the fully
        # connected layer FC1 which has n4 nodes.
        # In the four datasets, the kernel height nk1 is 24 and stride s1, s2 is 9 and 1
        # The number of conv1 kernels must equal n_planes: forward() reshapes
        # the conv1 output using self.n_planes (it was hard-coded to 90).
        self.conv1 = nn.Conv3d(1, n_planes, (24, 3, 3), padding=0, stride=(9, 1, 1))
        self.conv2 = nn.Conv2d(1, 64, (3, 3), stride=(1, 1))
        self.features_size = self._get_final_flattened_size()
        self.fc1 = nn.Linear(self.features_size, 1024)
        self.fc2 = nn.Linear(1024, n_classes)
        self.apply(self.weight_init)

    def _to_image(self, x):
        """Reshape the conv1 output into a 1-channel 2D map with n_planes columns."""
        # NOTE(review): conv1's output is laid out as (b, n_planes, depth, h, w);
        # view() reinterprets that memory as rows of n_planes values, which mixes
        # values from different kernels instead of stacking per-kernel vectors.
        # Kept as-is for compatibility with existing checkpoints — confirm
        # against the paper.
        b = x.size(0)
        return x.view(b, 1, -1, self.n_planes)

    def _get_final_flattened_size(self):
        """Flattened feature count after conv2 for a single input patch."""
        with torch.no_grad():
            x = torch.zeros(
                (1, 1, self.input_channels, self.patch_size, self.patch_size)
            )
            x = self.conv2(self._to_image(self.conv1(x)))
        return x[0].numel()

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(self._to_image(x)))
        x = x.view(-1, self.features_size)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
class SharmaEtAl(nn.Module):
    """
    HYPERSPECTRAL CNN FOR IMAGE CLASSIFICATION & BAND SELECTION, WITH APPLICATION
    TO FACE RECOGNITION
    Vivek Sharma, Ali Diba, Tinne Tuytelaars, Luc Van Gool
    Technical Report, KU Leuven/ETH Zürich
    """

    @staticmethod
    def weight_init(m):
        # He (Kaiming) normal weights and zero biases
        if isinstance(m, (nn.Linear, nn.Conv3d)):
            init.kaiming_normal_(m.weight)
            init.zeros_(m.bias)

    def __init__(self, input_channels, n_classes, patch_size=64):
        super(SharmaEtAl, self).__init__()
        self.input_channels = input_channels
        self.patch_size = patch_size
        # An input image of size 263x263 pixels is fed to conv1
        # with 96 kernels of size 6x6x96 with a stride of 2 pixels
        self.conv1 = nn.Conv3d(1, 96, (input_channels, 6, 6), stride=(1, 2, 2))
        self.conv1_bn = nn.BatchNorm3d(96)
        self.pool1 = nn.MaxPool3d((1, 2, 2))
        # 256 kernels of size 3x3x256 with a stride of 2 pixels
        self.conv2 = nn.Conv3d(1, 256, (96, 3, 3), stride=(1, 2, 2))
        self.conv2_bn = nn.BatchNorm3d(256)
        self.pool2 = nn.MaxPool3d((1, 2, 2))
        # 512 kernels of size 3x3x512 with a stride of 1 pixel
        self.conv3 = nn.Conv3d(1, 512, (256, 3, 3), stride=(1, 1, 1))
        # Considering those large kernel values, I assume they actually merge the
        # 3D tensors at each step
        self.features_size = self._get_final_flattened_size()
        # The fc1 has 1024 outputs, where dropout was applied after
        # fc1 with a rate of 0.5
        self.fc1 = nn.Linear(self.features_size, 1024)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(1024, n_classes)
        self.apply(self.weight_init)

    @staticmethod
    def _merge_channels(x):
        """Fold the channel axis into depth: (b, t, c, w, h) -> (b, 1, t * c, w, h)."""
        b, t, c, w, h = x.size()
        return x.view(b, 1, t * c, w, h)

    def _features(self, x):
        """Convolutional trunk shared by forward() and the size probe."""
        x = self.pool1(F.relu(self.conv1_bn(self.conv1(x))))
        x = self._merge_channels(x)
        x = self.pool2(F.relu(self.conv2_bn(self.conv2(x))))
        x = self._merge_channels(x)
        return F.relu(self.conv3(x))

    def _get_final_flattened_size(self):
        """Number of features the trunk produces for a single input patch."""
        # Probe in eval mode: torch.no_grad() alone does not stop BatchNorm from
        # updating its running statistics in training mode. The previous mode
        # is restored afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x = torch.zeros(
                    (1, 1, self.input_channels, self.patch_size, self.patch_size)
                )
                x = self._features(x)
        finally:
            self.train(was_training)
        return x[0].numel()

    def forward(self, x):
        x = self._features(x)
        x = x.view(-1, self.features_size)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x
class LiuEtAl(nn.Module):
    """
    A semi-supervised convolutional neural network for hyperspectral image classification
    Bing Liu, Xuchu Yu, Pengqiang Zhang, Xiong Tan, Anzhu Yu, Zhixiang Xue
    Remote Sensing Letters, 2017

    forward() returns (class logits, reconstruction); the reconstruction has
    `input_channels` values per sample.
    """

    @staticmethod
    def weight_init(m):
        # He (Kaiming) normal weights and zero biases
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            init.kaiming_normal_(m.weight)
            init.zeros_(m.bias)

    def __init__(self, input_channels, n_classes, patch_size=9):
        super(LiuEtAl, self).__init__()
        self.input_channels = input_channels
        self.patch_size = patch_size
        self.aux_loss_weight = 1
        # "W1 is a 3x3xB1 kernel [...] B1 is the number of the output bands for the convolutional
        # "and pooling layer" -> actually 3x3 2D convolutions with B1 outputs
        # "the value of B1 is set to be 80"
        self.conv1 = nn.Conv2d(input_channels, 80, (3, 3))
        self.pool1 = nn.MaxPool2d((2, 2))
        self.conv1_bn = nn.BatchNorm2d(80)
        self.features_sizes = self._get_sizes()
        self.fc_enc = nn.Linear(self.features_sizes[2], n_classes)
        # Decoder: mirrors the encoder sizes, with skip connections from the
        # pooled and convolved maps (see forward)
        self.fc1_dec = nn.Linear(self.features_sizes[2], self.features_sizes[2])
        self.fc1_dec_bn = nn.BatchNorm1d(self.features_sizes[2])
        self.fc2_dec = nn.Linear(self.features_sizes[2], self.features_sizes[1])
        self.fc2_dec_bn = nn.BatchNorm1d(self.features_sizes[1])
        self.fc3_dec = nn.Linear(self.features_sizes[1], self.features_sizes[0])
        self.fc3_dec_bn = nn.BatchNorm1d(self.features_sizes[0])
        self.fc4_dec = nn.Linear(self.features_sizes[0], input_channels)
        self.apply(self.weight_init)

    def _get_sizes(self):
        """Return flattened sizes (after conv1, after pool1, encoder output)."""
        # Probe in eval mode and without autograd so the BatchNorm running
        # statistics are left untouched; the previous mode is restored.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x = torch.zeros(
                    (1, self.input_channels, self.patch_size, self.patch_size)
                )
                x = F.relu(self.conv1_bn(self.conv1(x)))
                size0 = x[0].numel()
                x = self.pool1(x)
                size1 = x[0].numel()
        finally:
            self.train(was_training)
        # The encoder output is the pooled map (forward only applies a ReLU),
        # so its flattened size equals size1.
        size2 = size1
        return size0, size1, size2

    def forward(self, x):
        # Drop the singleton axis of (N, 1, bands, H, W) inputs without also
        # dropping the batch axis when N == 1 (a bare squeeze() did, which made
        # BatchNorm2d reject the 3-D tensor).
        x = x.squeeze(1)
        x_conv1 = self.conv1_bn(self.conv1(x))
        x = x_conv1
        x_pool1 = self.pool1(x)
        x = x_pool1
        x_enc = F.relu(x).view(-1, self.features_sizes[2])
        x = x_enc
        x_classif = self.fc_enc(x)
        # NOTE: fc1_dec_bn is defined but unused; an earlier variant computed
        # F.relu(self.fc1_dec_bn(self.fc1_dec(x) + x_enc)).
        x = F.relu(self.fc1_dec(x))
        x = F.relu(
            self.fc2_dec_bn(self.fc2_dec(x) + x_pool1.view(-1, self.features_sizes[1]))
        )
        x = F.relu(
            self.fc3_dec_bn(self.fc3_dec(x) + x_conv1.view(-1, self.features_sizes[0]))
        )
        x = self.fc4_dec(x)
        return x_classif, x
class BoulchEtAl(nn.Module):
    """
    Autoencodeurs pour la visualisation d'images hyperspectrales
    (Autoencoders for the visualization of hyperspectral images)
    A.Boulch, N. Audebert, D. Dubucq
    GRETSI 2017

    forward() takes (N, input_channels) spectra and returns
    (class logits, reconstructed spectra of length input_channels).
    """

    @staticmethod
    def weight_init(m):
        # He (Kaiming) normal weights and zero biases
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            init.kaiming_normal_(m.weight)
            init.zeros_(m.bias)

    def __init__(self, input_channels, n_classes, planes=16):
        super(BoulchEtAl, self).__init__()
        self.input_channels = input_channels
        self.aux_loss_weight = 0.1
        encoder_modules = []
        n = input_channels
        # Number of feature maps produced by the last stage built so far
        channels = 1
        # Each stage halves the spectral length (MaxPool1d(2)) until 1 remains
        while n > 1:
            if n == input_channels:
                p1, p2 = 1, 2 * planes
            elif n == input_channels // 2:
                p1, p2 = 2 * planes, planes
            else:
                p1, p2 = planes, planes
            encoder_modules.extend(
                [
                    nn.Conv1d(p1, p2, 3, padding=1),
                    nn.MaxPool1d(2),
                    nn.ReLU(inplace=True),
                    nn.BatchNorm1d(p2),
                ]
            )
            channels = p2
            n = n // 2
        # Use the actual number of feature maps: hard-coding `planes` broke
        # input_channels of 2 or 3, where only the first (2 * planes) stage exists.
        encoder_modules.append(nn.Conv1d(channels, 3, 3, padding=1))
        encoder_modules.append(nn.Tanh())
        self.encoder = nn.Sequential(*encoder_modules)
        self.features_sizes = self._get_sizes()
        self.classifier = nn.Linear(self.features_sizes, n_classes)
        self.regressor = nn.Linear(self.features_sizes, input_channels)
        self.apply(self.weight_init)

    def _get_sizes(self):
        """Flattened size of the encoder output for one spectrum."""
        # Probe in eval mode so the BatchNorm layers' running statistics are
        # not updated (no_grad() alone does not prevent that); restore after.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x = torch.zeros((10, 1, self.input_channels))
                x = self.encoder(x)
        finally:
            self.train(was_training)
        _, c, w = x.size()
        return c * w

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.encoder(x)
        x = x.view(-1, self.features_sizes)
        x_classif = self.classifier(x)
        x = self.regressor(x)
        return x_classif, x
class MouEtAl(nn.Module):
    """
    Deep recurrent neural networks for hyperspectral image classification
    Lichao Mou, Pedram Ghamisi, Xiao Xang Zhu
    https://ieeexplore.ieee.org/document/7914752/

    Each spectral band is fed to the GRU as one time step.
    """

    @staticmethod
    def weight_init(m):
        # "All weight matrices in our RNN and bias vectors are initialized with
        # a uniform distribution, and the values of these weight matrices and
        # bias vectors are initialized in the range [−0.1,0.1]"
        if isinstance(m, (nn.Linear, nn.GRU)):
            # nn.GRU exposes weight_ih_l0 / weight_hh_l0 / bias_* rather than
            # `.weight`/`.bias`, so initialize every parameter owned directly
            # by the module.
            for param in m.parameters(recurse=False):
                init.uniform_(param, -0.1, 0.1)

    def __init__(self, input_channels, n_classes):
        # The proposed network model uses a single recurrent layer that adopts our modified GRUs of size 64 with sigmoid gate activation and PRetanh activation functions for hidden representations
        # NOTE: nn.GRU below is PyTorch's standard GRU, not the modified variant.
        super(MouEtAl, self).__init__()
        self.input_channels = input_channels
        self.gru = nn.GRU(1, 64, 1, bidirectional=False)  # TODO: try to change this ?
        self.gru_bn = nn.BatchNorm1d(64 * input_channels)
        self.tanh = nn.Tanh()
        self.fc = nn.Linear(64 * input_channels, n_classes)
        # weight_init used to be defined but never applied
        self.apply(self.weight_init)

    def forward(self, x):
        # Flatten all non-batch axes, e.g. (N, bands) or (N, bands, 1, 1) ->
        # (N, bands). The former squeeze() also removed the batch axis when N == 1.
        x = x.reshape(x.size(0), -1)
        # GRU input is (seq_len, batch, features): (bands, N, 1)
        x = x.t().unsqueeze(-1)
        x = self.gru(x)[0]
        # x is in C, N, 64, we permute back to N, 64, C
        x = x.permute(1, 2, 0).contiguous()
        x = x.view(x.size(0), -1)
        x = self.gru_bn(x)
        x = self.tanh(x)
        x = self.fc(x)
        return x
def train(
net,
optimizer,