fixed tests + n_channels for wasserstein, spherical, and resampling
domkirke committed Dec 18, 2023
1 parent cab9f84 commit 562ef5e
Showing 5 changed files with 43 additions and 35 deletions.
15 changes: 9 additions & 6 deletions rave/blocks.py
@@ -249,13 +249,15 @@ def __init__(
data_size: int,
ratios: int,
noise_bands: int,
n_channels: int = 1,
activation: Callable[[int], nn.Module] = lambda dim: nn.LeakyReLU(.2),
):
super().__init__()
net = []
self.n_channels = n_channels
channels = [in_size]
channels.extend((len(ratios) - 1) * [hidden_size])
channels.append(data_size * noise_bands)
channels.append(data_size * noise_bands * n_channels)

for i, r in enumerate(ratios):
net.append(
@@ -280,7 +282,7 @@ def __init__(
def forward(self, x):
amp = mod_sigmoid(self.net(x) - 5)
amp = amp.permute(0, 2, 1)
amp = amp.reshape(amp.shape[0], amp.shape[1], self.data_size, -1)
amp = amp.reshape(amp.shape[0], amp.shape[1], self.n_channels * self.data_size, -1)

ir = amp_to_impulse_response(amp, self.target_size)
noise = torch.rand_like(ir) * 2 - 1
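
For context on the reshape above: the generator now emits one bank of noise-band amplitudes per output audio channel, packed along the feature axis, so the feature dimension unpacks as `n_channels * data_size` by `noise_bands`. A minimal self-contained sketch with illustrative shapes (not the repository code):

```python
import torch

batch, frames = 4, 32
data_size, noise_bands, n_channels = 16, 5, 2

# Stand-in for the network output: (B, data_size * noise_bands * n_channels, T)
amp = torch.rand(batch, data_size * noise_bands * n_channels, frames)

amp = amp.permute(0, 2, 1)  # (B, T, features)
amp = amp.reshape(amp.shape[0], amp.shape[1], n_channels * data_size, -1)

# The trailing axis recovers the noise bands for every (channel, band) pair.
assert amp.shape == (batch, frames, n_channels * data_size, noise_bands)
```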
@@ -683,7 +685,7 @@ def __init__(

if noise_module is not None:
self.waveform_module = waveform_module
self.noise_module = noise_module(out_channels)
self.noise_module = noise_module(out_channels, n_channels = n_channels)
else:
net.append(waveform_module)

@@ -749,9 +751,10 @@ def __init__(
self,
encoder_cls,
noise_augmentation: int = 0,
n_channels: int = 1
):
super().__init__()
self.encoder = encoder_cls()
self.encoder = encoder_cls(n_channels=n_channels)
self.register_buffer("warmed_up", torch.tensor(0))
self.noise_augmentation = noise_augmentation
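
The encoder wrappers in this file receive the encoder as a callable and now forward `n_channels` at construction time; with gin, `encoder_cls` is typically a class whose other arguments are already bound. A hedged sketch of that pattern with stand-in names (`DummyEncoder` is hypothetical, not part of the repository):

```python
import functools

import torch
import torch.nn as nn

class DummyEncoder(nn.Module):
    """Hypothetical encoder standing in for a gin-configured one."""

    def __init__(self, capacity: int, n_channels: int = 1):
        super().__init__()
        self.net = nn.Conv1d(n_channels, capacity, kernel_size=3, padding=1)

    def forward(self, x):
        return self.net(x)

# Bind every argument except n_channels, as a gin configuration would.
encoder_cls = functools.partial(DummyEncoder, capacity=8)
encoder = encoder_cls(n_channels=2)          # what the wrappers now do
print(encoder(torch.randn(1, 2, 64)).shape)  # torch.Size([1, 8, 64])
```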

@@ -829,9 +832,9 @@ def forward(self, x):

class SphericalEncoder(nn.Module):

def __init__(self, encoder_cls: Callable[[], nn.Module]) -> None:
def __init__(self, encoder_cls: Callable[[], nn.Module], n_channels: int = 1) -> None:
super().__init__()
self.encoder = encoder_cls()
self.encoder = encoder_cls(n_channels=n_channels)

def reparametrize(self, z):
norm_z = z / torch.norm(z, p=2, dim=1, keepdim=True)
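
For reference, `reparametrize` projects latents onto the unit hypersphere by normalizing along the latent axis; a short self-contained check with assumed shapes:

```python
import torch

z = torch.randn(8, 16, 64)  # (batch, latent_dim, time) -- illustrative shapes
norm_z = z / torch.norm(z, p=2, dim=1, keepdim=True)

# Every time step of every batch element now has unit L2 norm over latents.
assert torch.allclose(norm_z.norm(p=2, dim=1), torch.ones(8, 64), atol=1e-5)
```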
13 changes: 7 additions & 6 deletions rave/discriminator.py
@@ -53,10 +53,10 @@ def rectified_2d_conv_block(

class EncodecConvNet(nn.Module):

def __init__(self, capacity: int) -> None:
def __init__(self, capacity: int, n_channels: int = 1) -> None:
super().__init__()
self.net = nn.Sequential(
rectified_2d_conv_block(capacity, (9, 3), in_size=2),
rectified_2d_conv_block(capacity, (9, 3), in_size=2*n_channels),
rectified_2d_conv_block(capacity, (9, 3), (2, 1), (1, 1)),
rectified_2d_conv_block(capacity, (9, 3), (2, 1), (1, 2)),
rectified_2d_conv_block(capacity, (9, 3), (2, 1), (1, 4)),
@@ -139,10 +139,10 @@ def forward(self, x):
class MultiScaleSpectralDiscriminator(nn.Module):

def __init__(self, scales: Sequence[int],
convnet: Callable[[], nn.Module]) -> None:
convnet: Callable[[], nn.Module], n_channels: int = 1) -> None:
super().__init__()
self.specs = nn.ModuleList([spectrogram(n) for n in scales])
self.nets = nn.ModuleList([convnet() for _ in scales])
self.nets = nn.ModuleList([convnet(n_channels=n_channels) for _ in scales])

def forward(self, x):
features = []
@@ -156,10 +156,11 @@ def forward(self, x):
class MultiScaleSpectralDiscriminator1d(nn.Module):

def __init__(self, scales: Sequence[int],
convnet: Callable[[int], nn.Module]) -> None:
convnet: Callable[[int], nn.Module],
n_channels: int = 1) -> None:
super().__init__()
self.specs = nn.ModuleList([spectrogram(n) for n in scales])
self.nets = nn.ModuleList([convnet(n + 2) for n in scales])
self.nets = nn.ModuleList([convnet(n + 2, n_channels) for n in scales])

def forward(self, x):
features = []
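
The `in_size=2*n_channels` change above presumably reflects a complex spectrogram contributing two planes (real and imaginary) per audio channel, as in Encodec's multi-scale STFT discriminator. A hedged sketch of that plane layout; the shapes and the use of `torch.stft` here are assumptions, not the repository's spectrogram implementation:

```python
import torch

n_channels, n_fft = 2, 512
audio = torch.randn(3, n_channels, 25600)  # (batch, channels, time)

spec = torch.stft(
    audio.reshape(-1, audio.shape[-1]),  # fold channels into the batch dim
    n_fft=n_fft,
    return_complex=True,
)                                        # (B * C, freq, frames), complex
planes = torch.view_as_real(spec)        # (B * C, freq, frames, 2)
planes = planes.permute(0, 3, 1, 2)      # (B * C, 2, freq, frames)
planes = planes.reshape(3, 2 * n_channels, *planes.shape[-2:])

print(planes.shape[1])  # 2 * n_channels input planes for the first conv block
```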
12 changes: 8 additions & 4 deletions rave/resampler.py
@@ -58,9 +58,13 @@ def __init__(self, target_sr, model_sr):
self.ratio = ratio

def to_model_sampling_rate(self, x):
return self.downsample(x)
x_down = x.reshape(-1, 1, x.shape[-1])
x_down = self.downsample(x_down)
return x_down.reshape(x.shape[0], x.shape[1], -1)

def from_model_sampling_rate(self, x):
x = self.upsample(x) # B x 2 x T
x = x.permute(0, 2, 1).reshape(x.shape[0], -1).unsqueeze(1)
return x
x_up = x.reshape(-1, 1, x.shape[-1])
x_up = self.upsample(x_up) # B x 2 x T
x_up = x_up.permute(0, 2, 1).reshape(x_up.shape[0], -1).unsqueeze(1)
x_up = x_up.reshape(x.shape[0], x.shape[1], -1)
return x_up
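
Both methods above now fold channels into the batch dimension, resample each channel as mono, and unfold back to `(batch, channels, time)`. A minimal round-trip sketch with a stand-in torchaudio resampler (the repository's up/downsamplers are assumed to behave similarly over the last axis):

```python
import torch
import torchaudio

batch, n_channels = 2, 2
resample = torchaudio.transforms.Resample(orig_freq=48000, new_freq=44100)

x = torch.randn(batch, n_channels, 48000)
x_down = x.reshape(-1, 1, x.shape[-1])          # (B * C, 1, T)
x_down = resample(x_down)                       # per-channel mono resampling
x_down = x_down.reshape(batch, n_channels, -1)  # back to (B, C, T')

print(x_down.shape)  # torch.Size([2, 2, 44100])
```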
9 changes: 4 additions & 5 deletions scripts/export.py
@@ -143,15 +143,13 @@ def __init__(self,
self.decoder = pretrained.decoder
x_len = 2**14
x = torch.zeros(1, self.n_channels, x_len)
if self.resampler is not None:
x = self.resampler.to_model_sampling_rate(x)
z = self.encode(x)
ratio_encode = x_len // z.shape[-1]

# configure encoder
if pretrained.input_mode != "pqmf":
if (pretrained.input_mode == "pqmf") or (pretrained.output_mode == "pqmf"):
# scripting fails if cached conv is not initialized
self.pqmf(x)
self.pqmf(torch.zeros(1, 1, x_len))

encode_shape = (pretrained.n_channels, 2**14)
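
The pqmf warm-up call above reflects a general TorchScript constraint: modules whose parameters or caches are materialized only on the first call must be run once on dummy input before scripting. A hedged illustration using a lazy module (`DummyLazy` is a stand-in for this idiom, not the cached_conv API):

```python
import torch
import torch.nn as nn

class DummyLazy(nn.Module):
    """Stand-in for a module whose state appears only after a first call."""

    def __init__(self):
        super().__init__()
        self.conv = nn.LazyConv1d(out_channels=4, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

module = DummyLazy()
module(torch.zeros(1, 1, 2**14))     # warm-up materialises the lazy weights
scripted = torch.jit.script(module)  # scripting now sees a concrete Conv1d
```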

@@ -258,6 +256,8 @@ def encode(self, x):
if self.spectrogram is not None:
x = self.spectrogram(x)[..., :-1]
x = torch.log1p(x).reshape(batch_size + (-1, x.shape[-1]))
else:
raise RuntimeError()
z = self.encoder(x)
z = self.post_process_latent(z)
return z
@@ -298,7 +298,6 @@ def decode(self, z, from_forward: bool = False):
y = torch.cat(y.chunk(self.target_channels, 0), 1)
elif self.target_channels < self.n_channels:
y = y[:, :self.target_channels]
# return y[..., ]
return y

def forward(self, x):
29 changes: 15 additions & 14 deletions tests/test_configs.py
@@ -18,14 +18,20 @@
["v2.gin", "adain.gin"],
["v2.gin", "wasserstein.gin"],
["v2.gin", "spherical.gin"],
# ["v2.gin", "hybrid.gin"], NOT READY YET
["v2.gin", "hybrid.gin"],
["v2_small.gin", "adain.gin"],
["v2_small.gin", "wasserstein.gin"],
["v2_small.gin", "spherical.gin"],
["v2_small.gin", "hybrid.gin"],
["discrete.gin"],
["discrete.gin", "snake.gin"],
["discrete.gin", "snake.gin", "adain.gin"],
["discrete.gin", "snake.gin", "descript_discriminator.gin"],
["discrete.gin", "spectral_discriminator.gin"],
["discrete.gin", "noise.gin"],
["discrete.gin", "hybrid.gin"],
["v3.gin"],
["v3.gin", "hybrid.gin"]
]

configs += [c + ["causal.gin"] for c in configs]
@@ -44,24 +50,19 @@
("stereo" if e[2] else "mono"), configs),
)
def test_config(config, sr, stereo):
if any(map(lambda x: "adain" in x, config)) and stereo:
pytest.skip()

gin.clear_config()
gin.parse_config_files_and_bindings(config, [
f"SAMPLING_RATE={sr}",
"CAPACITY=2",
])

model = rave.RAVE()
n_channels = 2 if stereo else 1
model = rave.RAVE(n_channels=n_channels)

if stereo:
for m in model.modules():
if isinstance(m, rave.blocks.AdaptiveInstanceNormalization):
pytest.skip()

x = torch.randn(1, 1, 2**15)
z = model.encode(x)
x = torch.randn(1, n_channels, 2**15)
z, _ = model.encode(x, return_mb=True)
z, _ = model.encoder.reparametrize(z)[:2]
y = model.decode(z)
score = model.discriminator(y)

@@ -79,7 +80,7 @@ def test_config(config, sr, stereo):
raise ValueError(f"Encoder type {type(model.encoder)} "
"not supported for export.")

x = torch.zeros(1, 1, 2**14)
x = torch.zeros(1, n_channels, 2**14)

model(x)

Expand All @@ -89,12 +90,12 @@ def test_config(config, sr, stereo):

scripted_rave = script_class(
pretrained=model,
stereo=stereo,
channels=n_channels,
)

scripted_rave_resampled = script_class(
pretrained=model,
stereo=stereo,
channels=n_channels,
target_sr=44100,
)

