Fixes #221: workaround that disables SDPA attention in the latest version of openai-whisper (20240930), which prevents accessing attention weights
Jeronymous committed Nov 4, 2024
1 parent ee35e7c commit e495276
Showing 1 changed file with 10 additions and 7 deletions.
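For context: openai-whisper release 20240930 switched its attention layers to PyTorch's scaled_dot_product_attention (SDPA), which does not materialize the attention weight matrix, so the forward hooks whisper_timestamped uses for word timestamps no longer receive cross-attention weights. The snippet below is a minimal sketch of the idea behind a disable_sdpa() context manager, assuming a class-level use_sdpa switch modeled on openai-whisper's whisper/model.py; it is an illustration, not the library's exact code.

# Minimal sketch of what disable_sdpa() does; the MultiHeadAttention stub and
# the use_sdpa flag are assumptions modeled on openai-whisper's whisper/model.py.
from contextlib import contextmanager

class MultiHeadAttention:
    # Class-level switch: when True, attention goes through
    # F.scaled_dot_product_attention and no weight matrix is exposed to hooks.
    use_sdpa = True

@contextmanager
def disable_sdpa():
    """Temporarily force the manual attention path so weights are materialized."""
    prev = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        MultiHeadAttention.use_sdpa = prev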
17 changes: 10 additions & 7 deletions whisper_timestamped/transcribe.py
@@ -3,7 +3,7 @@
 __author__ = "Jérôme Louradour"
 __credits__ = ["Jérôme Louradour"]
 __license__ = "GPLv3"
-__version__ = "1.15.5"
+__version__ = "1.15.6"
 
 # Set some environment variables
 import os
@@ -899,8 +899,9 @@ def hook_output_logits(layer, ins, outs):
     if compute_word_confidence or no_speech_threshold is not None:
         all_hooks.append(model.decoder.ln.register_forward_hook(hook_output_logits))
 
-    with disable_sdpa():
-        transcription = model.transcribe(audio, **whisper_options)
+    with torch.no_grad():
+        with disable_sdpa():
+            transcription = model.transcribe(audio, **whisper_options)
 
 finally:
 
@@ -1062,8 +1063,9 @@ def hook_output_logits(layer, ins, outs):
 
 try:
     model.alignment_heads = alignment_heads  # Avoid exception "AttributeError: 'WhisperUntied' object has no attribute 'alignment_heads'. Did you mean: 'set_alignment_heads'?"
-    with disable_sdpa():
-        transcription = model.transcribe(audio, **whisper_options)
+    with torch.no_grad():
+        with disable_sdpa():
+            transcription = model.transcribe(audio, **whisper_options)
 finally:
     for hook in all_hooks:
         hook.remove()
@@ -1238,8 +1240,9 @@ def hook(layer, ins, outs, index=j):
 i_start = len(sot_sequence)
 
 with torch.no_grad():
-    logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0))
-    logprobs = F.log_softmax(logprobs, dim=-1)
+    with disable_sdpa():
+        logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0))
+        logprobs = F.log_softmax(logprobs, dim=-1)
 
 end_token = tokenizer.timestamp_begin + round(min(N_FRAMES * HOP_LENGTH, end_sample - start_sample) // AUDIO_SAMPLES_PER_TOKEN)
 tokens = tokens[i_start:] + [end_token]
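For reference, here is a hedged usage sketch of the pattern this commit adopts: register forward hooks on the decoder's cross-attention blocks, then run inference inside both torch.no_grad() and disable_sdpa() so the hooks actually receive attention weights. It assumes openai-whisper 20240930, where whisper.model exposes disable_sdpa and attention modules return an (output, attention weights) pair; the model size and "audio.wav" path are placeholders.

# Illustrative only; "tiny" and "audio.wav" are placeholders.
import torch
import whisper
from whisper.model import disable_sdpa  # added in openai-whisper 20240930

model = whisper.load_model("tiny")
attention_weights = []

def save_attention(layer, ins, outs):
    # In openai-whisper, attention modules return (output, qk); qk holds the
    # attention weights and is None when the SDPA fast path is used.
    attention_weights.append(outs[-1])

hooks = [block.cross_attn.register_forward_hook(save_attention)
         for block in model.decoder.blocks]
try:
    with torch.no_grad():        # transcription needs no gradients
        with disable_sdpa():     # force the attention path that returns weights
            result = model.transcribe("audio.wav")
finally:
    for hook in hooks:
        hook.remove()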
