Skip to content

Commit

Permalink
type hinting
Browse files Browse the repository at this point in the history
  • Loading branch information
blaisewf committed Dec 3, 2024
1 parent 2a763e1 commit 5df0934
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 119 deletions.
32 changes: 16 additions & 16 deletions rvc/lib/algorithm/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ class MultiHeadAttention(torch.nn.Module):

def __init__(
self,
channels,
out_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=True,
block_length=None,
proximal_bias=False,
proximal_init=False,
channels: int,
out_channels: int,
n_heads: int,
p_dropout: float = 0.0,
window_size: int = None,
heads_share: bool = True,
block_length: int = None,
proximal_bias: bool = False,
proximal_init: bool = False,
):
super().__init__()
assert (
Expand Down Expand Up @@ -201,13 +201,13 @@ class FFN(torch.nn.Module):

def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation=None,
causal=False,
in_channels = int,
out_channels = int,
filter_channels = int,
kernel_size = int,
p_dropout: float = 0.0,
activation: str = None,
causal: bool = False,
):
super().__init__()
self.padding_fn = self._causal_padding if causal else self._same_padding
Expand Down
12 changes: 5 additions & 7 deletions rvc/lib/algorithm/discriminators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
Defaults to False.
"""

def __init__(self, version, use_spectral_norm=False):
def __init__(self, version: str, use_spectral_norm: bool = False):
super(MultiPeriodDiscriminator, self).__init__()
periods = (
[2, 3, 5, 7, 11, 17] if version == "v1" else [2, 3, 5, 7, 11, 17, 23, 37]
Expand Down Expand Up @@ -59,7 +59,7 @@ class DiscriminatorS(torch.nn.Module):
convolutional layers that are applied to the input signal.
"""

def __init__(self, use_spectral_norm=False):
def __init__(self, use_spectral_norm: bool =False):
super(DiscriminatorS, self).__init__()
norm_f = spectral_norm if use_spectral_norm else weight_norm
self.convs = torch.nn.ModuleList(
Expand Down Expand Up @@ -103,14 +103,12 @@ class DiscriminatorP(torch.nn.Module):
Args:
period (int): Period of the discriminator.
kernel_size (int): Kernel size of the convolutional layers.
Defaults to 5.
kernel_size (int): Kernel size of the convolutional layers. Defaults to 5.
stride (int): Stride of the convolutional layers. Defaults to 3.
use_spectral_norm (bool): Whether to use spectral normalization.
Defaults to False.
use_spectral_norm (bool): Whether to use spectral normalization. Defaults to False.
"""

def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
def __init__(self, period: int, kernel_size: int = 5, stride: int = 3, use_spectral_norm: bool = False):
super(DiscriminatorP, self).__init__()
self.period = period
norm_f = spectral_norm if use_spectral_norm else weight_norm
Expand Down
46 changes: 23 additions & 23 deletions rvc/lib/algorithm/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ class Encoder(torch.nn.Module):

def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=10,
hidden_channels: int,
filter_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int = 1,
p_dropout: float = 0.0,
window_size: int = 10,
):
super().__init__()
self.hidden_channels = hidden_channels
Expand Down Expand Up @@ -101,15 +101,15 @@ class TextEncoder(torch.nn.Module):

def __init__(
self,
out_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
embedding_dim,
f0=True,
out_channels: int,
hidden_channels: int,
filter_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: float,
embedding_dim: int,
f0: bool = True,
):
super(TextEncoder, self).__init__()
self.out_channels = out_channels
Expand Down Expand Up @@ -166,13 +166,13 @@ class PosteriorEncoder(torch.nn.Module):

def __init__(
self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
in_channels: int,
out_channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
n_layers: int,
gin_channels: int = 0,
):
super(PosteriorEncoder, self).__init__()
self.in_channels = in_channels
Expand Down
15 changes: 7 additions & 8 deletions rvc/lib/algorithm/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@ class Generator(torch.nn.Module):

def __init__(
self,
initial_channel,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=0,
initial_channel: int,
resblock_kernel_sizes: list,
resblock_dilation_sizes: list,
upsample_rates: list,
upsample_initial_channel: int,
upsample_kernel_sizes: list,
gin_channels: int = 0,
):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
Expand Down
10 changes: 5 additions & 5 deletions rvc/lib/algorithm/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ class WaveNet(torch.nn.Module):

def __init__(
self,
hidden_channels,
kernel_size,
hidden_channels: int,
kernel_size: int,
dilation_rate,
n_layers,
gin_channels=0,
p_dropout=0,
n_layers: int,
gin_channels: int = 0,
p_dropout: int = 0,
):
super().__init__()
assert kernel_size % 2 == 1, "Kernel size must be odd for proper padding."
Expand Down
2 changes: 1 addition & 1 deletion rvc/lib/algorithm/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class LayerNorm(torch.nn.Module):
eps (float, optional): Epsilon value for numerical stability. Defaults to 1e-5.
"""

def __init__(self, channels, eps=1e-5):
def __init__(self, channels: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.gamma = torch.nn.Parameter(torch.ones(channels))
Expand Down
35 changes: 17 additions & 18 deletions rvc/lib/algorithm/nsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ class SourceModuleHnNSF(torch.nn.Module):

def __init__(
self,
sample_rate,
harmonic_num=0,
sine_amp=0.1,
add_noise_std=0.003,
voiced_threshod=0,
is_half=True,
sample_rate: int,
harmonic_num: int = 0,
sine_amp: float = 0.1,
add_noise_std: float = 0.003,
voiced_threshod: float = 0,
is_half: bool = True,
):
super(SourceModuleHnNSF, self).__init__()

Expand Down Expand Up @@ -69,16 +69,15 @@ class GeneratorNSF(torch.nn.Module):

def __init__(
self,
initial_channel,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels,
sr,
is_half=False,
initial_channel : int,
resblock_kernel_sizes: list,
resblock_dilation_sizes: list,
upsample_rates: list,
upsample_initial_channel: int,
upsample_kernel_sizes: list,
gin_channels: int,
sr: int,
is_half: bool = False,
):
super(GeneratorNSF, self).__init__()

Expand Down Expand Up @@ -157,13 +156,13 @@ def forward(self, x, f0, g: Optional[torch.Tensor] = None):
for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
x = ups(x)
x += noise_convs(har_source)
x += noise_convs(har_source)

xs = sum(
self.resblocks[j](x)
for j in range(i * self.num_kernels, (i + 1) * self.num_kernels)
)
x = xs / self.num_kernels
x = xs / self.num_kernels

x = torch.nn.functional.leaky_relu(x)
x = torch.tanh(self.conv_post(x))
Expand Down
38 changes: 19 additions & 19 deletions rvc/lib/algorithm/residuals.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple
import torch
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm
Expand Down Expand Up @@ -27,7 +27,7 @@ def apply_mask(tensor, mask):


class ResBlockBase(torch.nn.Module):
def __init__(self, channels, kernel_size, dilations):
def __init__(self, channels: int, kernel_size: int, dilations: Tuple[int]):
super(ResBlockBase, self).__init__()
self.convs1 = torch.nn.ModuleList(
[create_conv1d_layer(channels, kernel_size, d) for d in dilations]
Expand Down Expand Up @@ -55,7 +55,7 @@ def remove_weight_norm(self):


class ResBlock(ResBlockBase):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int] = (1, 3, 5)):
super(ResBlock, self).__init__(channels, kernel_size, dilation)


Expand Down Expand Up @@ -95,13 +95,13 @@ class ResidualCouplingBlock(torch.nn.Module):

def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
n_flows=4,
gin_channels=0,
channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
n_layers: int,
n_flows: int = 4,
gin_channels: int = 0,
):
super(ResidualCouplingBlock, self).__init__()
self.channels = channels
Expand Down Expand Up @@ -176,14 +176,14 @@ class ResidualCouplingLayer(torch.nn.Module):

def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=0,
gin_channels=0,
mean_only=False,
channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
n_layers: int,
p_dropout: float = 0,
gin_channels: int = 0,
mean_only: bool = False,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
Expand All @@ -210,7 +210,7 @@ def __init__(
self.post.weight.data.zero_()
self.post.bias.data.zero_()

def forward(self, x, x_mask, g=None, reverse=False):
def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False):
"""Forward pass.
Args:
Expand Down
41 changes: 19 additions & 22 deletions rvc/lib/algorithm/synthesizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,25 @@ class Synthesizer(torch.nn.Module):

def __init__(
self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
spk_embed_dim,
gin_channels,
sr,
use_f0,
text_enc_hidden_dim=768,
spec_channels: int,
segment_size: int,
inter_channels: int,
hidden_channels: int,
filter_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: float,
resblock_kernel_sizes: list,
resblock_dilation_sizes: list,
upsample_rates: list,
upsample_initial_channel: int,
upsample_kernel_sizes: list,
spk_embed_dim: int,
gin_channels: int,
sr: int,
use_f0: bool,
text_enc_hidden_dim: int = 768,
**kwargs,
):
super().__init__()
Expand All @@ -79,7 +78,6 @@ def __init__(
if use_f0:
self.dec = GeneratorNSF(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
Expand All @@ -92,7 +90,6 @@ def __init__(
else:
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
Expand Down

0 comments on commit 5df0934

Please sign in to comment.