I am getting the mentioned error in this part of the code:
```python
if epoch >= TMA_epoch:  # start TMA training
    loss_s2s = 0
    for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
        loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
    loss_s2s /= texts.size(0)

    d_loss = self._dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
```
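For what it's worth, the masked cross-entropy loop itself runs fine on dummy tensors (all shapes below are made up for illustration, not taken from the repo):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch of 2, max text length 5, 10 output classes.
s2s_pred = torch.randn(2, 5, 10)       # per-step logits from the decoder
texts = torch.randint(0, 10, (2, 5))   # padded target token ids
input_lengths = torch.tensor([5, 3])   # true length of each sequence

loss_s2s = 0
for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
    # Only the first _text_length steps contribute; padding is ignored.
    loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
loss_s2s /= texts.size(0)
```

So the failure seems to come from the discriminator/WavLM path, not from this loop.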
More specifically, the error is raised in the `forward` method of the `WavLMLoss` class:
```python
def forward(self, wav, y_rec):
    with torch.no_grad():
        wav_16 = self.resample(wav)
        wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
    y_rec_16 = self.resample(y_rec)
    y_rec_embeddings = self.wavlm(input_values=y_rec_16.squeeze(), output_hidden_states=True)
    y_rec_embeddings = y_rec_embeddings.hidden_states

    floss = 0
    for er, eg in zip(wav_embeddings, y_rec_embeddings):
        floss += torch.mean(torch.abs(er - eg))

    return floss.mean()
```
The call `self.wavlm(input_values=y_rec_16.squeeze(), output_hidden_states=True)` raises the exact same error, and I don't know why.
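One way to narrow it down is to test the feature-matching arithmetic separately from the WavLM call, using fake hidden states (the layer count and shapes here are arbitrary stand-ins for what `wavlm(...).hidden_states` returns):

```python
import torch

# Stand-ins for wavlm(...).hidden_states: a tuple of per-layer features,
# each of shape (batch, frames, hidden_dim). Sizes are arbitrary.
wav_embeddings = tuple(torch.randn(2, 50, 768) for _ in range(13))
y_rec_embeddings = tuple(torch.randn(2, 50, 768) for _ in range(13))

floss = 0
for er, eg in zip(wav_embeddings, y_rec_embeddings):
    floss += torch.mean(torch.abs(er - eg))  # layer-wise L1 feature distance
```

If this runs but the real `forward` doesn't, the problem is most likely the shape of `y_rec_16.squeeze()` going into WavLM (HuggingFace WavLM expects `input_values` of shape `(batch, sequence_length)`), which is why the installed `transformers`/`torch` versions matter.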
What are the dependency versions for this project?