forked from fortypercnt/stream-translator

Commit ba34cdd (parent: d3bad77)

Showing 12 changed files with 481 additions and 488 deletions.
.gitignore:
@@ -129,3 +129,4 @@ dmypy.json
 .pyre/

 -Frag*
+*.wav
New file (86 lines) — StreamAudioGetter:

@@ -0,0 +1,86 @@
import queue
import signal
import subprocess
import sys
import threading

import ffmpeg
import numpy as np

from common import SAMPLE_RATE


def _transport(ytdlp_proc, ffmpeg_proc):
    # Pump raw container bytes from yt-dlp's stdout into ffmpeg's stdin.
    while (ytdlp_proc.poll() is None) and (ffmpeg_proc.poll() is None):
        try:
            chunk = ytdlp_proc.stdout.read(1024)
            ffmpeg_proc.stdin.write(chunk)
        except (BrokenPipeError, OSError):
            break  # one side has exited; stop pumping instead of spinning
    ytdlp_proc.kill()
    ffmpeg_proc.kill()


def _open_stream(url: str, direct_url: bool, format: str, cookies: str):
    if direct_url:
        # ffmpeg can read the URL itself: decode to mono 16-bit PCM on stdout.
        try:
            process = (ffmpeg.input(url, loglevel="panic")
                       .output("pipe:", format="s16le", acodec="pcm_s16le", ac=1, ar=SAMPLE_RATE)
                       .run_async(pipe_stdout=True))
        except ffmpeg.Error as e:
            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

        return process, None

    # Otherwise let yt-dlp resolve the page URL and pipe the media into ffmpeg.
    cmd = ['yt-dlp', url, '-f', format, '-o', '-', '-q']
    if cookies:
        cmd.extend(['--cookies', cookies])
    ytdlp_process = subprocess.Popen(cmd, stdout=subprocess.PIPE)

    try:
        ffmpeg_process = (ffmpeg.input("pipe:", loglevel="panic")
                          .output("pipe:", format="s16le", acodec="pcm_s16le", ac=1, ar=SAMPLE_RATE)
                          .run_async(pipe_stdin=True, pipe_stdout=True))
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    thread = threading.Thread(target=_transport, args=(ytdlp_process, ffmpeg_process))
    thread.start()
    return ffmpeg_process, ytdlp_process


class StreamAudioGetter():

    def __init__(self, url: str, direct_url: bool, format: str, cookies: str, frame_duration: float):
        print("Opening stream {}".format(url))
        self.ffmpeg_process, self.ytdlp_process = _open_stream(url, direct_url, format, cookies)
        # Factor 2: the s16le stream is read as bytes, two bytes per int16 sample.
        self.byte_size = round(frame_duration * SAMPLE_RATE * 2)
        signal.signal(signal.SIGINT, self._exit_handler)

    def _exit_handler(self, signum, frame):
        self.ffmpeg_process.kill()
        if self.ytdlp_process:
            self.ytdlp_process.kill()
        sys.exit(0)

    def work(self, output_queue: queue.SimpleQueue[np.ndarray]):
        while self.ffmpeg_process.poll() is None:
            in_bytes = self.ffmpeg_process.stdout.read(self.byte_size)
            if not in_bytes:
                break
            if len(in_bytes) != self.byte_size:
                continue  # drop a trailing partial frame
            # int16 PCM -> float32 in [-1.0, 1.0)
            audio = np.frombuffer(in_bytes, np.int16).flatten().astype(np.float32) / 32768.0
            output_queue.put(audio)

        self.ffmpeg_process.kill()
        if self.ytdlp_process:
            self.ytdlp_process.kill()
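For orientation, a minimal sketch of how this getter is meant to be driven: work() blocks, so it normally runs on its own thread and feeds fixed-duration frames into a queue. The URL, format selector, and frame duration below are illustrative values, not settings from this commit.

    # Hypothetical wiring; assumes StreamAudioGetter is importable as defined above.
    import queue
    import threading

    frame_queue = queue.SimpleQueue()
    getter = StreamAudioGetter(url="https://www.youtube.com/watch?v=XXXX",
                               direct_url=False, format="bestaudio", cookies=None,
                               frame_duration=0.1)  # 0.1 s -> 3200 bytes per read at 16 kHz
    threading.Thread(target=getter.work, args=(frame_queue,), daemon=True).start()

    frame = frame_queue.get()  # float32 array of 1600 samples in [-1.0, 1.0)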
New file (103 lines) — VAD and AudioSlicer:

@@ -0,0 +1,103 @@
import queue
import torch
import warnings

import numpy as np

from common import TranslationTask, SAMPLE_RATE

warnings.filterwarnings("ignore")


def _init_jit_model(model_path: str, device=torch.device('cpu')):
    torch.set_grad_enabled(False)
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model


class VAD:

    def __init__(self):
        self.model = _init_jit_model("silero_vad.jit")

    def is_speech(self, audio, threshold: float = 0.5, sampling_rate: int = 16000):
        if not torch.is_tensor(audio):
            try:
                audio = torch.Tensor(audio)
            except Exception as e:
                raise TypeError("Audio cannot be cast to tensor. Cast it manually") from e
        speech_prob = self.model(audio, sampling_rate).item()
        return speech_prob >= threshold

    def reset_states(self):
        self.model.reset_states()


class AudioSlicer:

    def __init__(self, frame_duration: float, continuous_no_speech_threshold: float, min_audio_length: float,
                 max_audio_length: float, prefix_retention_length: float, vad_threshold: float):
        self.vad = VAD()
        # Convert second-based thresholds into frame counts.
        self.continuous_no_speech_threshold = round(continuous_no_speech_threshold / frame_duration)
        self.min_audio_length = round(min_audio_length / frame_duration)
        self.max_audio_length = round(max_audio_length / frame_duration)
        self.prefix_retention_length = round(prefix_retention_length / frame_duration)
        self.vad_threshold = vad_threshold
        self.sampling_rate = SAMPLE_RATE
        self.audio_buffer = []
        self.prefix_audio_buffer = []
        self.speech_count = 0
        self.no_speech_count = 0
        self.continuous_no_speech_count = 0
        self.frame_duration = frame_duration
        self.counter = 0
        self.last_slice_second = 0.0

    def put(self, audio):
        self.counter += 1
        if self.vad.is_speech(audio, self.vad_threshold, self.sampling_rate):
            self.audio_buffer.append(audio)
            self.speech_count += 1
            self.continuous_no_speech_count = 0
        else:
            if self.speech_count == 0 and self.no_speech_count == 1:
                # Nothing but silence buffered so far: discard it.
                self.slice()
            self.audio_buffer.append(audio)
            self.no_speech_count += 1
            self.continuous_no_speech_count += 1
        if self.speech_count and self.no_speech_count / 4 > self.speech_count:
            # Buffer is mostly silence: drop it rather than transcribe it.
            self.slice()

    def should_slice(self):
        audio_len = len(self.audio_buffer)
        if audio_len < self.min_audio_length:
            return False
        if audio_len > self.max_audio_length:
            return True
        if self.continuous_no_speech_count >= self.continuous_no_speech_threshold:
            return True
        return False

    def slice(self):
        concatenate_buffer = self.prefix_audio_buffer + self.audio_buffer
        concatenate_audio = np.concatenate(concatenate_buffer)
        self.audio_buffer = []
        # Retain a short tail as leading context for the next slice.
        self.prefix_audio_buffer = concatenate_buffer[-self.prefix_retention_length:]
        self.speech_count = 0
        self.no_speech_count = 0
        self.continuous_no_speech_count = 0
        # self.vad.reset_states()
        slice_second = self.counter * self.frame_duration
        last_slice_second = self.last_slice_second
        self.last_slice_second = slice_second
        return concatenate_audio, (last_slice_second, slice_second)

    def work(self, input_queue: queue.SimpleQueue[np.ndarray], output_queue: queue.SimpleQueue[TranslationTask]):
        while True:
            audio = input_queue.get()
            self.put(audio)
            if self.should_slice():
                sliced_audio, time_range = self.slice()
                task = TranslationTask(sliced_audio, time_range)
                output_queue.put(task)
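A quick worked example of the second-to-frame-count conversions in the constructor above, using illustrative parameter values (not defaults from this commit); instantiating it requires silero_vad.jit in the working directory:

    slicer = AudioSlicer(frame_duration=0.1,
                         continuous_no_speech_threshold=0.8,  # round(0.8 / 0.1) = 8 frames
                         min_audio_length=3.0,                # -> 30 frames
                         max_audio_length=30.0,               # -> 300 frames
                         prefix_retention_length=0.8,         # keep last 8 frames as prefix
                         vad_threshold=0.5)
    # A slice is emitted once at least 30 frames are buffered and either
    # 8 consecutive non-speech frames arrive or the buffer exceeds 300 frames.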
New file (72 lines) — Whisper transcriber backends:

@@ -0,0 +1,72 @@
import os
import queue
from scipy.io.wavfile import write as write_audio

import numpy as np
from openai import OpenAI

import filters
from common import TranslationTask, SAMPLE_RATE

TEMP_AUDIO_FILE_NAME = 'temp.wav'


def _filter_text(text, whisper_filters):
    filter_name_list = whisper_filters.split(',')
    for filter_name in filter_name_list:
        filter_func = getattr(filters, filter_name, None)
        if not filter_func:
            raise Exception('Unknown filter: %s' % filter_name)
        text = filter_func(text)
    return text


class OpenaiWhisper():

    def __init__(self, model) -> None:
        print("Loading whisper model: {}".format(model))
        import whisper
        self.model = whisper.load_model(model)

    def transcribe(self, audio: np.ndarray, **transcribe_options) -> str:
        result = self.model.transcribe(audio, without_timestamps=True, **transcribe_options)
        return result.get("text")

    def work(self, input_queue: queue.SimpleQueue[TranslationTask], output_queue: queue.SimpleQueue[TranslationTask],
             whisper_filters, **transcribe_options):
        while True:
            task = input_queue.get()
            task.transcribed_text = _filter_text(self.transcribe(task.audio, **transcribe_options), whisper_filters)
            print(task.transcribed_text)
            output_queue.put(task)


class FasterWhisper(OpenaiWhisper):

    def __init__(self, model) -> None:
        print("Loading faster-whisper model: {}".format(model))
        from faster_whisper import WhisperModel
        self.model = WhisperModel(model)

    def transcribe(self, audio: np.ndarray, **transcribe_options) -> str:
        # faster-whisper yields segments lazily; join their text.
        segments, info = self.model.transcribe(audio, **transcribe_options)
        transcribed_text = ""
        for segment in segments:
            transcribed_text += segment.text
        return transcribed_text


class RemoteOpenaiWhisper(OpenaiWhisper):
    # https://platform.openai.com/docs/api-reference/audio/createTranscription?lang=python

    def __init__(self) -> None:
        self.client = OpenAI()

    def transcribe(self, audio: np.ndarray, **transcribe_options) -> str:
        # The hosted API wants a file, so round-trip through a temporary wav.
        with open(TEMP_AUDIO_FILE_NAME, 'wb') as audio_file:
            write_audio(audio_file, SAMPLE_RATE, audio)
        with open(TEMP_AUDIO_FILE_NAME, 'rb') as audio_file:
            result = self.client.audio.transcriptions.create(model="whisper-1",
                                                             file=audio_file,
                                                             language=transcribe_options['language']).text
        os.remove(TEMP_AUDIO_FILE_NAME)
        return result
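The three classes share one transcribe() interface, so backends are interchangeable. A hedged usage sketch: the model name, language, and filter name are assumptions (emoji_filter is taken from the upstream project's filters module), and RemoteOpenaiWhisper requires OPENAI_API_KEY plus an explicit language option, since it reads transcribe_options['language'] directly.

    import numpy as np

    transcriber = FasterWhisper("medium")      # local, CTranslate2 backend
    # transcriber = OpenaiWhisper("medium")    # local, original whisper
    # transcriber = RemoteOpenaiWhisper()      # hosted API

    audio = np.zeros(SAMPLE_RATE * 5, dtype=np.float32)  # 5 s of silence as stand-in input
    text = transcriber.transcribe(audio, language="ja")
    print(_filter_text(text, "emoji_filter"))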
New file (12 lines) — common.py:

@@ -0,0 +1,12 @@
import numpy as np
from whisper.audio import SAMPLE_RATE


class TranslationTask:

    def __init__(self, audio: np.ndarray, time_range: tuple[float, float]):
        self.audio = audio
        self.transcribed_text = None
        self.translated_text = None
        self.time_range = time_range
        self.start_time = None
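Taken together, the files in this commit form a three-stage pipeline: getter -> slicer -> transcriber, chained over SimpleQueues. A rough assembly sketch under the same assumptions as above (thread wiring and parameter values are illustrative, not part of this diff):

    import queue
    import threading

    frame_q = queue.SimpleQueue()  # getter -> slicer: fixed-duration float32 frames
    task_q = queue.SimpleQueue()   # slicer -> transcriber: one TranslationTask per speech slice

    url = "https://www.youtube.com/watch?v=XXXX"
    getter = StreamAudioGetter(url, direct_url=False, format="bestaudio",
                               cookies=None, frame_duration=0.1)
    slicer = AudioSlicer(frame_duration=0.1, continuous_no_speech_threshold=0.8,
                         min_audio_length=3.0, max_audio_length=30.0,
                         prefix_retention_length=0.8, vad_threshold=0.5)
    transcriber = FasterWhisper("medium")

    threading.Thread(target=getter.work, args=(frame_q,), daemon=True).start()
    threading.Thread(target=slicer.work, args=(frame_q, task_q), daemon=True).start()
    transcriber.work(task_q, queue.SimpleQueue(), "emoji_filter", language="ja")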