-
Notifications
You must be signed in to change notification settings - Fork 0
/
JSCE.py
31 lines (28 loc) · 939 Bytes
/
JSCE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import math
import torch.nn as nn
from fl_module import FL1
from af_module import AF
class JSCE(nn.Module):
    """Four FL1 stages interleaved with three SNR-conditioned AF stages,
    followed by per-sample power normalization of the final output.

    The name presumably stands for "joint source-channel encoder"
    (TODO confirm).

    NOTE(review): the meaning of the positional FL1/AF constructor
    arguments lives in fl_module / af_module, which are not visible here.
    The FL1 arguments look like (kernel, in_ch, out_ch, stride, padding),
    since each stage's 3rd argument matches the next stage's 2nd argument,
    and the AF first argument matches the preceding FL1's 3rd argument --
    TODO confirm.
    """

    def __init__(self):
        super().__init__()
        self.fl1 = FL1(5, 1, 3, 1, 0)
        self.af1 = AF(3, 1)
        self.fl2 = FL1(3, 3, 16, 2, 1)
        self.af2 = AF(16, 1)
        self.fl3 = FL1(3, 16, 32, 2, 1)
        self.af3 = AF(32, 1)
        self.fl4 = FL1(3, 32, 64, 2, 1)

    def forward(self, x, SNR):
        """Encode ``x`` and power-normalize the result per sample.

        Args:
            x: input batch; the first dimension is treated as the batch
                dimension by the normalization step. Exact expected shape
                depends on FL1 (presumably (B, C, H, W) -- TODO confirm).
            SNR: forwarded unchanged to every AF stage; its type/units are
                defined by af_module (presumably a channel SNR in dB --
                TODO confirm).

        Returns:
            The last FL1 stage's output rescaled so that, for every sample,
            the sum of squared elements equals the number of elements per
            sample (i.e. average per-element power of 1). An all-zero sample
            is returned as zeros instead of NaN.

        Side effects:
            Sets ``self.norm`` to the per-sample L2 norm of the pre-scaled
            output, with shape (batch, 1, ..., 1).
        """
        # NOTE(review): sub-stages are invoked via .forward(), which bypasses
        # any hooks registered on them. If FL1/AF are nn.Modules, calling
        # self.fl1(x) etc. is the idiomatic form; kept as-is to preserve
        # current behavior.
        x = self.fl1.forward(x)
        x = self.af1.forward(x, SNR)
        x = self.fl2.forward(x)
        x = self.af2.forward(x, SNR)
        x = self.fl3.forward(x)
        x = self.af3.forward(x, SNR)
        x = self.fl4.forward(x)

        # Power (energy) normalization module: scale each sample to
        # sum(x**2) == number of elements per sample.
        reduce_dims = tuple(range(1, x.dim()))
        power = torch.sum(x * x, dim=reduce_dims, keepdim=True)
        # Raw per-sample norm, kept as an attribute for callers that read
        # model.norm after a forward pass.
        self.norm = torch.sqrt(power)
        # Previously hard-coded as 64*4*4, which is only right when the last
        # stage emits (64, 4, 4) per sample.
        num_elements = x.shape[1:].numel()
        # Clamp before sqrt so an all-zero sample yields 0 / tiny = 0 instead
        # of 0 * (1 / 0) = NaN; identical to sqrt(power) otherwise.
        safe_norm = torch.sqrt(power.clamp_min(torch.finfo(power.dtype).tiny))
        return x / safe_norm * math.sqrt(num_elements)