replace unbiased with correction (NVIDIA#10555)
* replace unbiased with correction

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* Apply isort and black reformatting

Signed-off-by: nithinraok <[email protected]>

---------

Signed-off-by: Nithin Rao Koluguri <nithinraok>
Signed-off-by: nithinraok <[email protected]>
Co-authored-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: nithinraok <[email protected]>
nithinraok and nithinraok authored Sep 20, 2024
1 parent d2d2aa0 commit 44d2ae7
Showing 1 changed file with 30 additions and 15 deletions.
45 changes: 30 additions & 15 deletions nemo/collections/asr/parts/submodules/tdnn_attention.py
@@ -32,7 +32,7 @@ class StatsPoolLayer(nn.Module):
pool_mode: Type of pool mode. Supported modes are 'xvector' (mean and standard deviation) and 'tap' (time
average pooling, i.e., mean)
eps: Epsilon, minimum value before taking the square root, when using 'xvector' mode.
biased: Whether to use the biased estimator for the standard deviation when using 'xvector' mode. The default
unbiased: Whether to use the unbiased estimator for the standard deviation when using 'xvector' mode. The default
for torch.Tensor.std() is True.
Returns:
@@ -42,15 +42,15 @@ class StatsPoolLayer(nn.Module):
ValueError if an unsupported pooling mode is specified.
"""

def __init__(self, feat_in: int, pool_mode: str = 'xvector', eps: float = 1e-10, biased: bool = True):
def __init__(self, feat_in: int, pool_mode: str = 'xvector', eps: float = 1e-10, unbiased: bool = True):
super().__init__()
supported_modes = {"xvector", "tap"}
if pool_mode not in supported_modes:
raise ValueError(f"Pool mode must be one of {supported_modes}; got '{pool_mode}'")
self.pool_mode = pool_mode
self.feat_in = feat_in
self.eps = eps
self.biased = biased
self.unbiased = unbiased
if self.pool_mode == 'xvector':
# Mean + std
self.feat_in *= 2
@@ -59,7 +59,8 @@ def forward(self, encoder_output, length=None):
if length is None:
mean = encoder_output.mean(dim=-1) # Time Axis
if self.pool_mode == 'xvector':
std = encoder_output.std(dim=-1)
correction = 1 if self.unbiased else 0
std = encoder_output.std(dim=-1, correction=correction).clamp(min=self.eps)
pooled = torch.cat([mean, std], dim=-1)
else:
pooled = mean
@@ -71,12 +72,13 @@ def forward(self, encoder_output, length=None):
# Re-scale to get padded means
means = means * (encoder_output.shape[-1] / length).unsqueeze(-1)
if self.pool_mode == "xvector":
correction = 1 if self.unbiased else 0
stds = (
encoder_output.sub(means.unsqueeze(-1))
.masked_fill(mask, 0.0)
.pow(2.0)
.sum(-1) # [B, D, T] -> [B, D]
.div(length.view(-1, 1).sub(1 if self.biased else 0))
.div(length.view(-1, 1).sub(correction))
.clamp(min=self.eps)
.sqrt()
)
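A minimal, hedged sketch (not part of this diff) of the equivalence the commit relies on: torch's boolean `unbiased` flag maps onto the integer `correction` argument, with `unbiased=True` corresponding to `correction=1` (Bessel's correction) and `unbiased=False` to `correction=0`; the masked branch above divides by `length - correction` for the same reason.

import torch

x = torch.randn(4, 16, 100)  # [batch, features, time]

# correction=1 applies Bessel's correction, i.e. the old unbiased=True behaviour;
# correction=0 is the biased estimator, i.e. the old unbiased=False behaviour.
std_unbiased = x.std(dim=-1, correction=1)
std_biased = x.std(dim=-1, correction=0)

# Manual check against the divisor used in the masked branch (T - correction):
mean = x.mean(dim=-1, keepdim=True)
manual = x.sub(mean).pow(2).sum(-1).div(x.shape[-1] - 1).sqrt()
print(torch.allclose(std_unbiased, manual, atol=1e-5))  # expected: True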
@@ -104,7 +106,7 @@ def make_seq_mask_like(

def lens_to_mask(lens: List[int], max_len: int, device: str = None):
"""
outputs masking labels for list of lengths of audio features, with max length of any
outputs masking labels for list of lengths of audio features, with max length of any
mask as max_len
input:
lens: list of lens
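An illustrative sketch of the masking idea described in this docstring (assumed shapes; the actual lens_to_mask in this file may differ in detail): lengths become a boolean time mask by comparing a time-index ramp against each length.

import torch

lens = torch.tensor([3, 5])
max_len = 5
mask = torch.arange(max_len).unsqueeze(0) < lens.unsqueeze(1)  # [B, max_len]
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])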
@@ -124,8 +126,8 @@ def get_statistics_with_mask(x: torch.Tensor, m: torch.Tensor, dim: int = 2, eps
"""
compute mean and standard deviation of input(x) provided with its masking labels (m)
input:
x: feature input
m: averaged mask labels
x: feature input
m: averaged mask labels
output:
mean: mean of input features
std: standard deviation of input features
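A hedged sketch of the masked-statistics idea (toy shapes, not the verbatim NeMo code): with the mask m normalized so it sums to one over the time axis, the mean and standard deviation reduce to weighted sums.

import torch

B, D, T = 2, 4, 7
x = torch.randn(B, D, T)
lens = torch.tensor([5, 7])
mask = (torch.arange(T).unsqueeze(0) < lens.unsqueeze(1)).unsqueeze(1)  # [B, 1, T]
m = mask.float() / lens.view(-1, 1, 1)                                  # "averaged" mask labels

mean = (x * m).sum(dim=2)                                                # [B, D]
std = ((x - mean.unsqueeze(2)).pow(2) * m).sum(dim=2).clamp(min=1e-10).sqrt()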
@@ -146,7 +148,7 @@ class TDNNModule(nn.Module):
stride: stride for conv layer
padding: padding for conv layer (default None: chooses padding value such that input and output feature shape matches)
output:
tdnn layer output
tdnn layer output
"""

def __init__(
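A hedged note on the padding=None default described above (an assumed formula, not quoted from the file): with stride 1, keeping the output length equal to the input length amounts to "same" padding.

# For stride 1, output length equals input length when each side is padded by:
kernel_size, dilation = 5, 2
padding = (kernel_size - 1) // 2 * dilation  # here: 4 frames on each side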
@@ -183,7 +185,7 @@ class MaskedSEModule(nn.Module):
"""
Squeeze and Excite module implementation with conv1d layers
input:
inp_filters: input filter channel size
inp_filters: input filter channel size
se_filters: intermediate squeeze and excite channel output and input size
out_filters: output filter channel size
kernel_size: kernel_size for both conv1d layers
@@ -196,10 +198,20 @@ def __init__(self, inp_filters: int, se_filters: int, out_filters: int, kernel_size: int = 1, dilation: int = 1):
def __init__(self, inp_filters: int, se_filters: int, out_filters: int, kernel_size: int = 1, dilation: int = 1):
super().__init__()
self.se_layer = nn.Sequential(
nn.Conv1d(inp_filters, se_filters, kernel_size=kernel_size, dilation=dilation,),
nn.Conv1d(
inp_filters,
se_filters,
kernel_size=kernel_size,
dilation=dilation,
),
nn.ReLU(),
nn.BatchNorm1d(se_filters),
nn.Conv1d(se_filters, out_filters, kernel_size=kernel_size, dilation=dilation,),
nn.Conv1d(
se_filters,
out_filters,
kernel_size=kernel_size,
dilation=dilation,
),
nn.Sigmoid(),
)
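A hedged sketch of the squeeze-and-excite pattern this module implements (toy shapes; the real forward also masks padded frames before pooling): a time-pooled context vector passes through the two conv1d layers above and gates the input channels.

import torch
import torch.nn as nn

B, C, T = 2, 32, 50
x = torch.randn(B, C, T)
se = nn.Sequential(
    nn.Conv1d(C, 8, kernel_size=1),
    nn.ReLU(),
    nn.BatchNorm1d(8),
    nn.Conv1d(8, C, kernel_size=1),
    nn.Sigmoid(),
)
context = x.mean(dim=-1, keepdim=True)  # squeeze: pool over time -> [B, C, 1]
scale = se(context)                     # excite: per-channel gates in (0, 1)
out = x * scale                         # re-weight the input channels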

@@ -220,7 +232,7 @@ class TDNNSEModule(nn.Module):
Modified building SE_TDNN group module block from ECAPA implementation for faster training and inference
Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)
inputs:
inp_filters: input filter channel size
inp_filters: input filter channel size
out_filters: output filter channel size
group_scale: scale value to group wider conv channels (default: 8)
se_channels: squeeze and excite output channel size (default: 1024/8 = 128)
@@ -276,7 +288,7 @@ class AttentivePoolLayer(nn.Module):
inp_filters: input feature channel length from encoder
attention_channels: intermediate attention channel size
kernel_size: kernel_size for TDNN and attention conv1d layers (default: 1)
dilation: dilation size for TDNN and attention conv1d layers (default: 1)
dilation: dilation size for TDNN and attention conv1d layers (default: 1)
"""

def __init__(
@@ -295,7 +307,10 @@ def __init__(
TDNNModule(inp_filters * 3, attention_channels, kernel_size=kernel_size, dilation=dilation),
nn.Tanh(),
nn.Conv1d(
in_channels=attention_channels, out_channels=inp_filters, kernel_size=kernel_size, dilation=dilation,
in_channels=attention_channels,
out_channels=inp_filters,
kernel_size=kernel_size,
dilation=dilation,
),
)
self.eps = eps
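A hedged sketch of the attentive statistics pooling idea behind AttentivePoolLayer (toy attention head; the real layer feeds an [x, mean, std] context of size inp_filters * 3 into the TDNN above): per-frame scores are softmax-normalized over time and weight the pooled mean and standard deviation.

import torch
import torch.nn as nn

B, D, T = 2, 8, 50
feats = torch.randn(B, D, T)
attn = nn.Sequential(nn.Conv1d(D, 16, kernel_size=1), nn.Tanh(), nn.Conv1d(16, D, kernel_size=1))
alpha = torch.softmax(attn(feats), dim=-1)                   # attention weights over time
mean = (alpha * feats).sum(dim=-1)                           # [B, D]
std = ((alpha * (feats - mean.unsqueeze(-1)).pow(2)).sum(-1)).clamp(min=1e-4).sqrt()
pooled = torch.cat([mean, std], dim=-1)                      # [B, 2 * D]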
