diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py
index 2c14dee..9a14a44 100644
--- a/whisper_timestamped/transcribe.py
+++ b/whisper_timestamped/transcribe.py
@@ -3,7 +3,7 @@
 __author__ = "Jérôme Louradour"
 __credits__ = ["Jérôme Louradour"]
 __license__ = "GPLv3"
-__version__ = "1.15.1"
+__version__ = "1.15.2"
 
 # Set some environment variables
 import os
@@ -277,9 +277,9 @@ def transcribe_timestamped(
         compression_ratio_threshold=compression_ratio_threshold,
     )
 
-    if vad:
+    if vad is not None:
         audio = get_audio_tensor(audio)
-        audio, vad_segments, convert_timestamps = remove_non_speech(audio, method=vad, sample_rate=SAMPLE_RATE, plot=plot_word_alignment)
+        audio, vad_segments, convert_timestamps = remove_non_speech(audio, method=vad, sample_rate=SAMPLE_RATE, plot=plot_word_alignment, avoid_empty_speech=True)
     else:
         vad_segments = None
 
@@ -1856,8 +1856,8 @@ def check_vad_method(method, with_version=False):
     """
     if method in [True, "True", "true"]:
         return check_vad_method("silero") # default method
-    elif method in [False, "False", "false"]:
-        return False
+    elif method in [None, False, "False", "false", "None", "none"]:
+        return None
     elif not isinstance(method, str) and hasattr(method, '__iter__'):
         # list of explicit timestamps
         checked_pairs = []
@@ -2063,6 +2063,7 @@ def remove_non_speech(audio,
     dilatation=0.5,
     sample_rate=SAMPLE_RATE,
     method="silero",
+    avoid_empty_speech=False,
     plot=False,
     ):
     """
@@ -2083,6 +2084,8 @@
             how much (in sec) to enlarge each speech segment detected by the VAD
         method: str
             method to use to remove non-speech segments
+        avoid_empty_speech: bool
+            if True, avoid returning an empty speech segment
         plot: bool or str
             if True, plot the result.
             If a string, save the plot to the given file
@@ -2100,7 +2103,10 @@
 
     segments = [(seg["start"], seg["end"]) for seg in segments]
     if len(segments) == 0:
-        segments = [(0, audio.shape[-1])]
+        if avoid_empty_speech:
+            segments = [(0, audio.shape[-1])]
+        else:
+            return torch.Tensor([]), [], lambda t, t2 = None: t if t2 is None else [t, t2]
 
     audio_speech = torch.cat([audio[..., s:e] for s,e in segments], dim=-1)
 
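
Note: the two return lambdas were transposed in the flattened version of this diff. The pass-through converter (lambda t, t2 = None: t if t2 is None else [t, t2]) belongs in the new empty-speech early return, where do_convert_timestamps would fail on an empty segment list; the final return of remove_non_speech must keep do_convert_timestamps so that timestamps from the VAD-cropped audio are remapped to the original audio, so that hunk carries no change and is dropped.

For context, here is a minimal sketch (not part of the patch) of how the patched remove_non_speech() behaves when the VAD detects no speech at all. It assumes whisper-timestamped is installed with the Silero VAD dependency available, and uses the module-level names remove_non_speech and SAMPLE_RATE exactly as they appear in the diff above.

    import torch
    from whisper_timestamped.transcribe import remove_non_speech, SAMPLE_RATE

    silence = torch.zeros(5 * SAMPLE_RATE)  # 5 seconds of pure silence at 16 kHz

    # Default (avoid_empty_speech=False): the new early return yields empty audio,
    # no segments, and a pass-through timestamp converter.
    audio, segments, convert = remove_non_speech(silence, method="silero")
    print(audio.numel(), segments)  # expected: 0 []

    # With avoid_empty_speech=True (what transcribe_timestamped now passes),
    # the whole input is kept as a single pseudo-speech segment instead,
    # so downstream decoding always has audio to work on.
    audio, segments, convert = remove_non_speech(silence, method="silero", avoid_empty_speech=True)
    print(segments)  # expected: [(0.0, 5.0)]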