-
Notifications
You must be signed in to change notification settings - Fork 2
/
IrrelevantPowerRatio.py
34 lines (30 loc) · 1.45 KB
/
IrrelevantPowerRatio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
tr = torch
import torch.nn.functional as F
import numpy as np
import torch.fft
class IrrelevantPowerRatio(nn.Module):
    """Fraction of spectral magnitude that lies outside a frequency band.

    The computation reuses the code from Gideon2021:
    Gideon, John, and Simon Stent. "The way to my heart is through
    contrastive learning: Remote photoplethysmography from unlabelled
    video." Proceedings of the IEEE/CVF international conference on
    computer vision. 2021.

    Args:
        Fs: Sampling frequency of the signals; ``Fs / 2`` is used as the
            upper end of the frequency axis (presumably in Hz -- confirm
            against the caller).
        high_pass: Lower edge of the band of interest. It is divided by 60
            before being compared with the frequency axis (presumably
            beats per minute -- confirm against the caller).
        low_pass: Upper edge of the band of interest, in the same unit as
            ``high_pass``.
    """

    def __init__(self, Fs, high_pass, low_pass):
        super().__init__()
        self.Fs = Fs
        self.high_pass = high_pass
        self.low_pass = low_pass

    def forward(self, preds):
        """Return the out-of-band magnitude ratio of each signal.

        Args:
            preds: Real tensor whose last dimension is the one transformed
                by the FFT. Any number of leading dimensions is accepted
                (including none).

        Returns:
            Tensor with the shape of ``preds`` minus its last dimension.
            Each entry is ``out_of_band / (in_band + out_of_band)``, or 1
            when the total magnitude is not strictly positive.
        """
        # Magnitude of every one-sided FFT bin (real/imag pair -> 2-norm).
        spectrum = torch.view_as_real(
            torch.fft.rfft(preds, dim=-1, norm='forward')
        )
        magnitude = torch.linalg.norm(spectrum, dim=-1)

        # Frequency axis from 0 to Fs / 2, built on the input's device so
        # the masks below do not have to cross devices.
        # NOTE(review): linspace matches the true rfft bin frequencies only
        # for even-length signals; kept as in the reference implementation
        # so results stay comparable -- consider torch.fft.rfftfreq.
        nyquist = self.Fs / 2
        freqs = torch.linspace(
            0, nyquist, steps=magnitude.shape[-1], device=magnitude.device
        )
        in_band = torch.logical_and(
            freqs >= self.high_pass / 60, freqs <= self.low_pass / 60
        )
        out_of_band = torch.logical_not(in_band)

        # `...` keeps every leading dimension, so 1-D, 2-D and N-D inputs
        # are all handled the same way.
        use_energy = magnitude[..., in_band].sum(dim=-1)
        zero_energy = magnitude[..., out_of_band].sum(dim=-1)
        denom = use_energy + zero_energy

        # Vectorized guard against division by zero. The denominator is
        # replaced by 1 where it is unusable so that neither the forward
        # value nor the gradient of the discarded branch becomes NaN.
        has_energy = denom > 0
        ones = torch.ones_like(denom)
        safe_denom = torch.where(has_energy, denom, ones)
        return torch.where(has_energy, zero_energy / safe_denom, ones)