# PhysNetModel.py (forked from zhaodongsun/contrast-phys)
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------------------------------------------------------------------------------------------------
# PhysNet model
#
# the output is an ST-rPPG block rather than a single rPPG signal: one temporal signal per spatial position of an
# S x S grid, plus their spatial average (see forward()).
# -------------------------------------------------------------------------------------------------------------------

class PhysNet(nn.Module):

    def __init__(self, S=2, in_ch=3):
        super().__init__()

        self.S = S  # S is the spatial dimension of the ST-rPPG block

        self.start = nn.Sequential(
            nn.Conv3d(in_channels=in_ch, out_channels=32, kernel_size=(1, 5, 5), stride=1, padding=(0, 2, 2)),
            nn.BatchNorm3d(32),
            nn.ELU()
        )

        # 1x: spatial downsampling only; temporal length unchanged
        self.loop1 = nn.Sequential(
            nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0),
            nn.Conv3d(in_channels=32, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU()
        )

        # encoder: each stage halves both the temporal and the spatial resolution
        self.encoder1 = nn.Sequential(
            nn.AvgPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU()
        )

        self.encoder2 = nn.Sequential(
            nn.AvgPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU()
        )

        # spatial downsampling only; temporal length stays at T/4
        self.loop4 = nn.Sequential(
            nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU()
        )

        # decoder: temporal-only convolutions; the upsampling back to the initial temporal length happens in forward()
        self.decoder1 = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(64),
            nn.ELU()
        )

        self.decoder2 = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(64),
            nn.ELU()
        )

        # pool to an S x S spatial grid, then project to a single rPPG channel
        self.end = nn.Sequential(
            nn.AdaptiveAvgPool3d((None, S, S)),
            nn.Conv3d(in_channels=64, out_channels=1, kernel_size=(1, 1, 1), stride=1, padding=(0, 0, 0))
        )

    def forward(self, x):
        # normalize each clip over its (C, T, H, W) dimensions
        means = torch.mean(x, dim=(2, 3, 4), keepdim=True)
        stds = torch.std(x, dim=(2, 3, 4), keepdim=True)
        x = (x - means) / stds  # (B, C, T, 128, 128)

        parity = []  # record odd temporal lengths so the decoder can restore them exactly
        x = self.start(x)  # (B, 32, T, 128, 128)
        x = self.loop1(x)  # (B, 64, T, 64, 64)

        parity.append(x.size(2) % 2)
        x = self.encoder1(x)  # (B, 64, T/2, 32, 32)

        parity.append(x.size(2) % 2)
        x = self.encoder2(x)  # (B, 64, T/4, 16, 16)

        x = self.loop4(x)  # (B, 64, T/4, 8, 8)

        x = F.interpolate(x, scale_factor=(2, 1, 1))  # (B, 64, T/2, 8, 8)
        x = self.decoder1(x)  # (B, 64, T/2, 8, 8)
        x = F.pad(x, (0, 0, 0, 0, 0, parity[-1]), mode='replicate')  # re-append a frame if T/2 was odd

        x = F.interpolate(x, scale_factor=(2, 1, 1))  # (B, 64, T, 8, 8)
        x = self.decoder2(x)  # (B, 64, T, 8, 8)
        x = F.pad(x, (0, 0, 0, 0, 0, parity[-2]), mode='replicate')  # re-append a frame if T was odd

        x = self.end(x)  # (B, 1, T, S, S), ST-rPPG block

        # collect the rPPG signal at each spatial position, plus their average
        x_list = []
        for a in range(self.S):
            for b in range(self.S):
                x_list.append(x[:, :, :, a, b])  # (B, 1, T)

        x = sum(x_list) / (self.S * self.S)  # (B, 1, T)
        X = torch.cat(x_list + [x], 1)  # (B, M, T) with M = S*S + 1, all spatial signals stacked on the channel dim
        return X
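
# -------------------------------------------------------------------------------------------------------------------
# Sketch of the parity bookkeeping used in forward() above (a hypothetical demo added for illustration; the helper
# name _parity_demo is not part of the original code). Average pooling truncates an odd temporal length T to
# floor(T/2), so a plain 2x upsample comes back one frame short. forward() records T % 2 before each temporal
# downsample and replicate-pads that many frames after the matching upsample, restoring the exact input length.
# -------------------------------------------------------------------------------------------------------------------
def _parity_demo(T=75):
    t = torch.randn(1, 1, T, 1, 1)
    parity = T % 2                                                    # 1 here: the length is odd
    down = F.avg_pool3d(t, kernel_size=(2, 1, 1), stride=(2, 1, 1))   # temporal length floor(T/2) = 37
    up = F.interpolate(down, scale_factor=(2, 1, 1))                  # length 74, one frame short of T
    up = F.pad(up, (0, 0, 0, 0, 0, parity), mode='replicate')         # back to length 75
    assert up.size(2) == T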
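
# -------------------------------------------------------------------------------------------------------------------
# Minimal usage sketch (added for illustration, not part of the original training code): run a random clip shaped
# (B, C, T, 128, 128) through the model and check that the ST-rPPG output carries S*S + 1 temporal signals
# (one per spatial position, plus their average). The clip length 60 is an arbitrary choice; the parity padding in
# forward() also handles lengths that are not divisible by 4.
# -------------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    model = PhysNet(S=2, in_ch=3)
    clip = torch.randn(1, 3, 60, 128, 128)  # 60 frames, e.g. 2 s at 30 fps
    with torch.no_grad():
        out = model(clip)
    print(out.shape)  # torch.Size([1, 5, 60]): 4 spatial rPPG signals + their average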