From ba34cdd3ad77869d36bd57be5688b64a24ddd35f Mon Sep 17 00:00:00 2001
From: ionic-bond
Date: Thu, 21 Dec 2023 19:45:33 +0800
Subject: [PATCH] Refactor: multi-thread optimization

---
 .gitignore           |   1 +
 README.md            |   1 -
 audio_getter.py      |  86 ++++++++++
 audio_slicer.py      | 103 ++++++++++++
 audio_transcriber.py |  72 +++++++++
 common.py            |  12 ++
 gpt_translator.py    | 110 +++++++++++++
 openai_api.py        | 131 ---------------
 requirements.txt     |   7 +-
 result_exporter.py   |  41 +++++
 translator.py        | 376 ++++++-------------------------------------
 vad.py               |  29 ----
 12 files changed, 481 insertions(+), 488 deletions(-)
 create mode 100644 audio_getter.py
 create mode 100644 audio_slicer.py
 create mode 100644 audio_transcriber.py
 create mode 100644 common.py
 create mode 100644 gpt_translator.py
 delete mode 100644 openai_api.py
 create mode 100644 result_exporter.py
 delete mode 100644 vad.py

diff --git a/.gitignore b/.gitignore
index b11a713..36ef0c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -129,3 +129,4 @@ dmypy.json
 .pyre/
 
 --Frag*
+*.wav
diff --git a/README.md b/README.md
index 7e58bfa..d566c3e 100644
--- a/README.md
+++ b/README.md
@@ -40,7 +40,6 @@ By default, the URL can be of the form ```twitch.tv/forsen``` and yt-dlp is used
 | `--model` | small | Select model size. See [here](https://github.com/openai/whisper#available-models-and-languages) for available models. |
 | `--task` | translate | Whether to transcribe the audio (keep original language) or translate to english. |
 | `--language` | auto | Language spoken in the stream. See [here](https://github.com/openai/whisper#available-models-and-languages) for available languages. |
-| `--history_buffer_size` | 0 | Times of previous audio/text to use for conditioning the model. Set to 0 to just use audio from the last processing. Note that this can easily lead to repetition/loops if the chosen language/model settings do not produce good results to begin with. |
 | `--beam_size` | 5 | Number of beams in beam search. Set to 0 to use greedy algorithm instead (faster but less accurate). |
 | `--best_of` | 5 | Number of candidates when sampling with non-zero temperature. |
 | `--direct_url` | | Set this flag to pass the URL directly to ffmpeg. Otherwise, yt-dlp is used to obtain the stream URL. 
| diff --git a/audio_getter.py b/audio_getter.py new file mode 100644 index 0000000..b5c1275 --- /dev/null +++ b/audio_getter.py @@ -0,0 +1,86 @@ +import queue +import signal +import subprocess +import sys +import threading + +import ffmpeg +import numpy as np + +from common import SAMPLE_RATE + + +def _transport(ytdlp_proc, ffmpeg_proc): + while (ytdlp_proc.poll() is None) and (ffmpeg_proc.poll() is None): + try: + chunk = ytdlp_proc.stdout.read(1024) + ffmpeg_proc.stdin.write(chunk) + except (BrokenPipeError, OSError): + pass + ytdlp_proc.kill() + ffmpeg_proc.kill() + + +def _open_stream(url: str, direct_url: bool, format: str, cookies: str): + if direct_url: + try: + process = (ffmpeg.input( + url, loglevel="panic").output("pipe:", + format="s16le", + acodec="pcm_s16le", + ac=1, + ar=SAMPLE_RATE).run_async(pipe_stdout=True)) + except ffmpeg.Error as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return process, None + + cmd = ['yt-dlp', url, '-f', format, '-o', '-', '-q'] + if cookies: + cmd.extend(['--cookies', cookies]) + ytdlp_process = subprocess.Popen(cmd, stdout=subprocess.PIPE) + + try: + ffmpeg_process = (ffmpeg.input("pipe:", loglevel="panic").output("pipe:", + format="s16le", + acodec="pcm_s16le", + ac=1, + ar=SAMPLE_RATE).run_async( + pipe_stdin=True, + pipe_stdout=True)) + except ffmpeg.Error as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + thread = threading.Thread(target=_transport, args=(ytdlp_process, ffmpeg_process)) + thread.start() + return ffmpeg_process, ytdlp_process + + +class StreamAudioGetter(): + + def __init__(self, url: str, direct_url: bool, format: str, cookies: str, frame_duration: float): + print("Opening stream {}".format(url)) + self.ffmpeg_process, self.ytdlp_process = _open_stream(url, direct_url, format, cookies) + self.byte_size = round(frame_duration * SAMPLE_RATE * 2) # Factor 2 comes from reading the int16 stream as bytes + signal.signal(signal.SIGINT, self._exit_handler) + + def _exit_handler(self, signum, frame): + self.ffmpeg_process.kill() + if self.ytdlp_process: + self.ytdlp_process.kill() + sys.exit(0) + + def work(self, output_queue: queue.SimpleQueue[np.array]): + while self.ffmpeg_process.poll() is None: + in_bytes = self.ffmpeg_process.stdout.read(self.byte_size) + if not in_bytes: + break + if len(in_bytes) != self.byte_size: + continue + audio = np.frombuffer(in_bytes, np.int16).flatten().astype(np.float32) / 32768.0 + output_queue.put(audio) + + self.ffmpeg_process.kill() + if self.ytdlp_process: + self.ytdlp_process.kill() + diff --git a/audio_slicer.py b/audio_slicer.py new file mode 100644 index 0000000..ef426af --- /dev/null +++ b/audio_slicer.py @@ -0,0 +1,103 @@ +import queue +import torch +import warnings + +import numpy as np + +from common import TranslationTask, SAMPLE_RATE + +warnings.filterwarnings("ignore") + + +def _init_jit_model(model_path: str, device=torch.device('cpu')): + torch.set_grad_enabled(False) + model = torch.jit.load(model_path, map_location=device) + model.eval() + return model + + +class VAD: + + def __init__(self): + self.model = _init_jit_model("silero_vad.jit") + + def is_speech(self, audio, threshold: float = 0.5, sampling_rate: int = 16000): + if not torch.is_tensor(audio): + try: + audio = torch.Tensor(audio) + except: + raise TypeError("Audio cannot be casted to tensor. 
Cast it manually") + speech_prob = self.model(audio, sampling_rate).item() + return speech_prob >= threshold + + def reset_states(self): + self.model.reset_states() + + +class AudioSlicer: + + def __init__(self, frame_duration: float, continuous_no_speech_threshold: float, min_audio_length: float, + max_audio_length: float, prefix_retention_length: float, vad_threshold: float): + self.vad = VAD() + self.continuous_no_speech_threshold = round(continuous_no_speech_threshold / frame_duration) + self.min_audio_length = round(min_audio_length / frame_duration) + self.max_audio_length = round(max_audio_length / frame_duration) + self.prefix_retention_length = round(prefix_retention_length / frame_duration) + self.vad_threshold = vad_threshold + self.sampling_rate = SAMPLE_RATE + self.audio_buffer = [] + self.prefix_audio_buffer = [] + self.speech_count = 0 + self.no_speech_count = 0 + self.continuous_no_speech_count = 0 + self.frame_duration = frame_duration + self.counter = 0 + self.last_slice_second = 0.0 + + def put(self, audio): + self.counter += 1 + if self.vad.is_speech(audio, self.vad_threshold, self.sampling_rate): + self.audio_buffer.append(audio) + self.speech_count += 1 + self.continuous_no_speech_count = 0 + else: + if self.speech_count == 0 and self.no_speech_count == 1: + self.slice() + self.audio_buffer.append(audio) + self.no_speech_count += 1 + self.continuous_no_speech_count += 1 + if self.speech_count and self.no_speech_count / 4 > self.speech_count: + self.slice() + + def should_slice(self): + audio_len = len(self.audio_buffer) + if audio_len < self.min_audio_length: + return False + if audio_len > self.max_audio_length: + return True + if self.continuous_no_speech_count >= self.continuous_no_speech_threshold: + return True + return False + + def slice(self): + concatenate_buffer = self.prefix_audio_buffer + self.audio_buffer + concatenate_audio = np.concatenate(concatenate_buffer) + self.audio_buffer = [] + self.prefix_audio_buffer = concatenate_buffer[-self.prefix_retention_length:] + self.speech_count = 0 + self.no_speech_count = 0 + self.continuous_no_speech_count = 0 + # self.vad.reset_states() + slice_second = self.counter * self.frame_duration + last_slice_second = self.last_slice_second + self.last_slice_second = slice_second + return concatenate_audio, (last_slice_second, slice_second) + + def work(self, input_queue: queue.SimpleQueue[np.array], output_queue: queue.SimpleQueue[TranslationTask]): + while True: + audio = input_queue.get() + self.put(audio) + if self.should_slice(): + sliced_audio, time_range = self.slice() + task = TranslationTask(sliced_audio, time_range) + output_queue.put(task) diff --git a/audio_transcriber.py b/audio_transcriber.py new file mode 100644 index 0000000..81e9e66 --- /dev/null +++ b/audio_transcriber.py @@ -0,0 +1,72 @@ +import os +import queue +from scipy.io.wavfile import write as write_audio + +import numpy as np +from openai import OpenAI + +import filters +from common import TranslationTask, SAMPLE_RATE + +TEMP_AUDIO_FILE_NAME = 'temp.wav' + + +def _filter_text(text, whisper_filters): + filter_name_list = whisper_filters.split(',') + for filter_name in filter_name_list: + filter = getattr(filters, filter_name) + if not filter: + raise Exception('Unknown filter: %s' % filter_name) + text = filter(text) + return text + + +class OpenaiWhisper(): + + def __init__(self, model) -> None: + print("Loading whisper model: {}".format(model)) + import whisper + self.model = whisper.load_model(model) + + def transcribe(self, audio: np.array, 
**transcribe_options) -> str:
+        result = self.model.transcribe(audio,
+                                       without_timestamps=True,
+                                       **transcribe_options)
+        return result.get("text")
+
+    def work(self, input_queue: queue.SimpleQueue[TranslationTask], output_queue: queue.SimpleQueue[TranslationTask], whisper_filters, **transcribe_options):
+        while True:
+            task = input_queue.get()
+            task.transcribed_text = _filter_text(self.transcribe(task.audio, **transcribe_options), whisper_filters)
+            print(task.transcribed_text)
+            output_queue.put(task)
+
+
+class FasterWhisper(OpenaiWhisper):
+
+    def __init__(self, model) -> None:
+        print("Loading faster-whisper model: {}".format(model))
+        from faster_whisper import WhisperModel
+        self.model = WhisperModel(model)
+
+    def transcribe(self, audio: np.array, **transcribe_options) -> str:
+        segments, info = self.model.transcribe(audio, **transcribe_options)
+        transcribed_text = ""
+        for segment in segments:
+            transcribed_text += segment.text
+        return transcribed_text
+
+
+class RemoteOpenaiWhisper(OpenaiWhisper):
+    # https://platform.openai.com/docs/api-reference/audio/createTranscription?lang=python
+
+    def __init__(self) -> None:
+        self.client = OpenAI()
+
+    def transcribe(self, audio: np.array, **transcribe_options) -> str:
+        with open(TEMP_AUDIO_FILE_NAME, 'wb') as audio_file:
+            write_audio(audio_file, SAMPLE_RATE, audio)
+        with open(TEMP_AUDIO_FILE_NAME, 'rb') as audio_file:
+            result = self.client.audio.transcriptions.create(model="whisper-1", file=audio_file, language=transcribe_options['language']).text
+        os.remove(TEMP_AUDIO_FILE_NAME)
+        return result
\ No newline at end of file
diff --git a/common.py b/common.py
new file mode 100644
index 0000000..9b71bd0
--- /dev/null
+++ b/common.py
@@ -0,0 +1,12 @@
+import numpy as np
+from whisper.audio import SAMPLE_RATE
+
+
+class TranslationTask:
+
+    def __init__(self, audio: np.array, time_range: tuple[float, float]):
+        self.audio = audio
+        self.transcribed_text = None
+        self.translated_text = None
+        self.time_range = time_range
+        self.start_time = None
\ No newline at end of file
diff --git a/gpt_translator.py b/gpt_translator.py
new file mode 100644
index 0000000..3355b90
--- /dev/null
+++ b/gpt_translator.py
@@ -0,0 +1,110 @@
+import queue
+import threading
+import time
+from collections import deque
+from datetime import datetime, timedelta
+
+from openai import OpenAI
+
+from common import TranslationTask
+
+
+def _translate_by_gpt(client,
+                      translation_task,
+                      assistant_prompt,
+                      model,
+                      history_messages=[]):
+    # https://platform.openai.com/docs/api-reference/chat/create?lang=python
+    system_prompt = "You are a translation engine."
+    messages = [{"role": "system", "content": system_prompt}]
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": assistant_prompt})
+    messages.append({"role": "user", "content": translation_task.transcribed_text})
+    completion = client.chat.completions.create(
+        model=model,
+        temperature=0,
+        max_tokens=1000,
+        top_p=1,
+        frequency_penalty=1,
+        presence_penalty=1,
+        messages=messages,
+    )
+    translation_task.translated_text = completion.choices[0].message.content
+
+
+class ParallelTranslator():
+
+    def __init__(self, prompt, model, timeout):
+        self.prompt = prompt
+        self.model = model
+        self.timeout = timeout
+        self.client = OpenAI()
+        self.processing_queue = deque()
+
+    def trigger(self, translation_task):
+        self.processing_queue.append(translation_task)
+        translation_task.start_time = datetime.utcnow()
+        thread = threading.Thread(target=_translate_by_gpt,
+                                  args=(self.client, translation_task, self.prompt, self.model))
+        thread.daemon = True
+        thread.start()
+
+    def get_results(self):
+        results = []
+        while self.processing_queue and (self.processing_queue[0].translated_text or
+                                         datetime.utcnow() - self.processing_queue[0].start_time
+                                         > timedelta(seconds=self.timeout)):
+            task = self.processing_queue.popleft()
+            results.append(task)
+            if not task.translated_text:
+                print("Translation timeout or failed: {}".format(task.transcribed_text))
+        return results
+
+    def work(self, input_queue: queue.SimpleQueue[TranslationTask], output_queue: queue.SimpleQueue[TranslationTask]):
+        while True:
+            if not input_queue.empty():
+                task = input_queue.get()
+                self.trigger(task)
+            finished_tasks = self.get_results()
+            for task in finished_tasks:
+                output_queue.put(task)
+            time.sleep(0.1)
+
+
+class SerialTranslator():
+
+    def __init__(self, prompt, model, timeout, history_size):
+        self.prompt = prompt
+        self.model = model
+        self.timeout = timeout
+        self.history_size = history_size
+        self.client = OpenAI()
+        self.history_messages = []
+
+    def work(self, input_queue: queue.SimpleQueue[TranslationTask], output_queue: queue.SimpleQueue[TranslationTask]):
+        current_task = None
+        while True:
+            if current_task:
+                if current_task.translated_text or datetime.utcnow(
+                ) - current_task.start_time > timedelta(seconds=self.timeout):
+                    if current_task.translated_text:
+                        # self.history_messages.append({"role": "user", "content": current_task.transcribed_text})
+                        self.history_messages.append({
+                            "role": "assistant",
+                            "content": current_task.translated_text
+                        })
+                        while (len(self.history_messages) > self.history_size):
+                            self.history_messages.pop(0)
+                    else:
+                        print("Translation timeout or failed: {}".format(current_task.transcribed_text))
+                    output_queue.put(current_task)
+                    current_task = None
+
+            if current_task is None and not input_queue.empty():
+                current_task = input_queue.get()
+                current_task.start_time = datetime.utcnow()
+                thread = threading.Thread(target=_translate_by_gpt,
+                                          args=(self.client, current_task, self.prompt, self.model, self.history_messages))
+                thread.daemon = True
+                thread.start()
+            time.sleep(0.1)
diff --git a/openai_api.py b/openai_api.py
deleted file mode 100644
index 33720c5..0000000
--- a/openai_api.py
+++ /dev/null
@@ -1,131 +0,0 @@
-import threading
-import time
-from collections import deque
-from datetime import datetime, timedelta
-
-import openai
-
-
-def translate_by_gpt(translation_task,
-                     openai_api_key,
-                     assistant_prompt,
-                     model,
-                     history_messages=[]):
-    # https://platform.openai.com/docs/api-reference/chat/create?lang=python
-    openai.api_key = 
openai_api_key - system_prompt = "You are a translation engine." - messages = [{"role": "system", "content": system_prompt}] - messages.extend(history_messages) - messages.append({"role": "user", "content": assistant_prompt}) - messages.append({"role": "user", "content": translation_task.input_text}) - completion = openai.ChatCompletion.create( - model=model, - temperature=0, - max_tokens=1000, - top_p=1, - frequency_penalty=1, - presence_penalty=1, - messages=messages, - ) - translation_task.output_text = completion.choices[0].message['content'] - - -class TranslationTask: - - def __init__(self, text, time_range=(0.0, 0.0)): - self.input_text = text - self.output_text = None - self.time_range = time_range - self.start_time = datetime.utcnow() - - -class ParallelTranslator(): - - def __init__(self, openai_api_key, prompt, model, timeout): - self.openai_api_key = openai_api_key - self.prompt = prompt - self.model = model - self.timeout = timeout - self.processing_queue = deque() - - def put(self, translation_task): - self.processing_queue.append(translation_task) - thread = threading.Thread(target=translate_by_gpt, - args=(translation_task, self.openai_api_key, self.prompt, - self.model)) - thread.start() - - def get_results(self): - results = [] - while len( - self.processing_queue) and (self.processing_queue[0].output_text or - datetime.utcnow() - self.processing_queue[0].start_time - > timedelta(seconds=self.timeout)): - task = self.processing_queue.popleft() - results.append(task) - if not task.output_text: - print("Translation timeout or failed: {}".format(task.input_text)) - return results - - -class SerialTranslator(): - - def __init__(self, openai_api_key, prompt, model, timeout, history_size): - self.openai_api_key = openai_api_key - self.prompt = prompt - self.model = model - self.timeout = timeout - self.history_size = history_size - self.history_messages = [] - self.input_queue = deque() - self.output_queue = deque() - - self.running = True - self.loop_thread = threading.Thread(target=self._run_loop) - self.loop_thread.start() - - def __del__(self): - self.running = False - self.loop_thread.join() - - def _run_loop(self): - current_task = None - while (self.running): - if current_task: - if current_task.output_text or datetime.utcnow( - ) - current_task.start_time > timedelta(seconds=self.timeout): - if current_task.output_text: - # self.history_messages.append({"role": "user", "content": current_task.input_text}) - self.history_messages.append({ - "role": "assistant", - "content": current_task.output_text - }) - while (len(self.history_messages) > self.history_size): - self.history_messages.pop(0) - self.output_queue.append(current_task) - current_task = None - if current_task is None and len(self.input_queue): - current_task = self.input_queue.popleft() - current_task.start_time = datetime.utcnow() - thread = threading.Thread(target=translate_by_gpt, - args=(current_task, self.openai_api_key, self.prompt, - self.model, self.history_messages)) - thread.start() - time.sleep(0.1) - - def put(self, translation_task): - self.input_queue.append(translation_task) - - def get_results(self): - results = [] - while len(self.output_queue): - task = self.output_queue.popleft() - results.append(task) - if not task.output_text: - print("Translation timeout or failed: {}".format(task.input_text)) - return results - - -def whisper_transcribe(audio_file, openai_api_key): - openai.api_key = openai_api_key - return openai.Audio.transcribe("whisper-1", audio_file).get('text', '') diff --git 
a/requirements.txt b/requirements.txt index 6ec3623..5f7a8dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,9 @@ numpy tqdm more-itertools ---extra-index-url https://download.pytorch.org/whl/cu113 torch ffmpeg-python==0.2.0 -openai-whisper -faster-whisper -openai==0.28 +openai-whisper==20231117 +faster-whisper==0.10.0 +openai==1.6.0 yt-dlp diff --git a/result_exporter.py b/result_exporter.py new file mode 100644 index 0000000..03bdedb --- /dev/null +++ b/result_exporter.py @@ -0,0 +1,41 @@ +import queue +import requests +from datetime import datetime + +from common import TranslationTask + + +def _send_to_cqhttp(url: str, token: str, text: str): + headers = {'Authorization': 'Bearer {}'.format(token)} if token else None + data = {'message': text} + requests.post(url, headers=headers, data=data) + + +def _sec2str(second): + dt = datetime.utcfromtimestamp(second) + return dt.strftime('%H:%M:%S') + + +class ResultExporter(): + + def __init__(self, output_timestamps: bool, cqhttp_url: str, cqhttp_token: str) -> None: + self.output_timestamps = output_timestamps + self.cqhttp_url = cqhttp_url + self.cqhttp_token = cqhttp_token + + def work(self, input_queue: queue.SimpleQueue[TranslationTask]): + while True: + task = input_queue.get() + timestamp_text = '{}-{}'.format(_sec2str(task.time_range[0]), _sec2str( + task.time_range[1])) + text_to_send = task.transcribed_text + if self.output_timestamps: + text_to_send = timestamp_text + '\n' + text_to_send + if task.translated_text: + text_to_print = task.translated_text + if self.output_timestamps: + text_to_print = timestamp_text + ' ' + text_to_print + print('\033[1m{}\033[0m'.format(text_to_print)) + text_to_send += '\n{}'.format(task.translated_text) + if self.cqhttp_url: + _send_to_cqhttp(self.cqhttp_url, self.cqhttp_token, text_to_send) diff --git a/translator.py b/translator.py index 3cfd41c..20a743f 100644 --- a/translator.py +++ b/translator.py @@ -1,334 +1,71 @@ import argparse import os -import requests -import signal +import queue import sys -import subprocess -import tempfile import threading -from datetime import datetime -from scipy.io.wavfile import write as write_audio +import time -import ffmpeg -import numpy as np -from whisper.audio import SAMPLE_RATE +from audio_getter import StreamAudioGetter +from audio_slicer import AudioSlicer +from audio_transcriber import OpenaiWhisper, FasterWhisper, RemoteOpenaiWhisper +from gpt_translator import ParallelTranslator, SerialTranslator +from result_exporter import ResultExporter -import filters -from openai_api import TranslationTask, ParallelTranslator, SerialTranslator, whisper_transcribe -from vad import VAD - -class RingBuffer: - - def __init__(self, size): - self.size = size - self.data = [] - self.full = False - self.cur = 0 - - def append(self, x): - if self.size <= 0: - return - if self.full: - self.data[self.cur] = x - self.cur = (self.cur + 1) % self.size - else: - self.data.append(x) - if len(self.data) == self.size: - self.full = True - - def get_all(self): - """ Get all elements in chronological order from oldest to newest. 
""" - all_data = [] - for i in range(len(self.data)): - idx = (i + self.cur) % self.size - all_data.append(self.data[idx]) - return all_data - - def has_repetition(self): - prev = None - for elem in self.data: - if elem == prev: - return True - prev = elem - return False - - def clear(self): - self.data = [] - self.full = False - self.cur = 0 - - -def open_stream(stream, direct_url, format, cookies): - if direct_url: - try: - process = (ffmpeg.input( - stream, loglevel="panic").output("pipe:", - format="s16le", - acodec="pcm_s16le", - ac=1, - ar=SAMPLE_RATE).run_async(pipe_stdout=True)) - except ffmpeg.Error as e: - raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e - - return process, None - - def writer(ytdlp_proc, ffmpeg_proc): - while (ytdlp_proc.poll() is None) and (ffmpeg_proc.poll() is None): - try: - chunk = ytdlp_proc.stdout.read(1024) - ffmpeg_proc.stdin.write(chunk) - except (BrokenPipeError, OSError): - pass - ytdlp_proc.kill() - ffmpeg_proc.kill() - - cmd = ['yt-dlp', stream, '-f', format, '-o', '-', '-q'] - if cookies: - cmd.extend(['--cookies', cookies]) - ytdlp_process = subprocess.Popen(cmd, stdout=subprocess.PIPE) - - try: - ffmpeg_process = (ffmpeg.input("pipe:", loglevel="panic").output("pipe:", - format="s16le", - acodec="pcm_s16le", - ac=1, - ar=SAMPLE_RATE).run_async( - pipe_stdin=True, - pipe_stdout=True)) - except ffmpeg.Error as e: - raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e - - thread = threading.Thread(target=writer, args=(ytdlp_process, ffmpeg_process)) +def _start_daemon_thread(func, *args, **kwargs): + thread = threading.Thread(target=func, args=args, kwargs=kwargs) + thread.daemon = True thread.start() - return ffmpeg_process, ytdlp_process - - -def send_to_cqhttp(url, token, text): - headers = {'Authorization': 'Bearer {}'.format(token)} if token else None - data = {'message': text} - requests.post(url, headers=headers, data=data) - - -def filter_text(text, whisper_filters): - filter_name_list = whisper_filters.split(',') - for filter_name in filter_name_list: - filter = getattr(filters, filter_name) - if not filter: - raise Exception('Unknown filter: %s' % filter_name) - text = filter(text) - return text - - -class StreamSlicer: - - def __init__(self, frame_duration, continuous_no_speech_threshold, min_audio_length, - max_audio_length, prefix_retention_length, vad_threshold, sampling_rate): - self.vad = VAD() - self.continuous_no_speech_threshold = round(continuous_no_speech_threshold / frame_duration) - self.min_audio_length = round(min_audio_length / frame_duration) - self.max_audio_length = round(max_audio_length / frame_duration) - self.prefix_retention_length = round(prefix_retention_length / frame_duration) - self.vad_threshold = vad_threshold - self.sampling_rate = sampling_rate - self.audio_buffer = [] - self.prefix_audio_buffer = [] - self.speech_count = 0 - self.no_speech_count = 0 - self.continuous_no_speech_count = 0 - self.frame_duration = frame_duration - self.counter = 0 - self.last_slice_second = 0.0 - - def put(self, audio): - self.counter += 1 - if self.vad.is_speech(audio, self.vad_threshold, self.sampling_rate): - self.audio_buffer.append(audio) - self.speech_count += 1 - self.continuous_no_speech_count = 0 - else: - if self.speech_count == 0 and self.no_speech_count == 1: - self.slice() - self.audio_buffer.append(audio) - self.no_speech_count += 1 - self.continuous_no_speech_count += 1 - if self.speech_count and self.no_speech_count / 4 > self.speech_count: - self.slice() - - def 
should_slice(self): - audio_len = len(self.audio_buffer) - if audio_len < self.min_audio_length: - return False - if audio_len > self.max_audio_length: - return True - if self.continuous_no_speech_count >= self.continuous_no_speech_threshold: - return True - return False - - def slice(self): - concatenate_buffer = self.prefix_audio_buffer + self.audio_buffer - concatenate_audio = np.concatenate(concatenate_buffer) - self.audio_buffer = [] - self.prefix_audio_buffer = concatenate_buffer[-self.prefix_retention_length:] - self.speech_count = 0 - self.no_speech_count = 0 - self.continuous_no_speech_count = 0 - # self.vad.reset_states() - slice_second = self.counter * self.frame_duration - last_slice_second = self.last_slice_second - self.last_slice_second = slice_second - return concatenate_audio, (last_slice_second, slice_second) - - -def sec2str(second): - dt = datetime.utcfromtimestamp(second) - return dt.strftime('%H:%M:%S') def main(url, format, direct_url, cookies, frame_duration, continuous_no_speech_threshold, min_audio_length, max_audio_length, prefix_retention_length, vad_threshold, model, - language, use_faster_whisper, use_whisper_api, whisper_filters, output_timestamps, - history_buffer_size, gpt_translation_prompt, gpt_translation_history_size, openai_api_key, - gpt_model, gpt_translation_timeout, cqhttp_url, cqhttp_token, **decode_options): - - n_bytes = round(frame_duration * SAMPLE_RATE * - 2) # Factor 2 comes from reading the int16 stream as bytes - history_audio_buffer = RingBuffer(history_buffer_size + 1) - history_text_buffer = RingBuffer(history_buffer_size) - stream_slicer = StreamSlicer(frame_duration=frame_duration, - continuous_no_speech_threshold=continuous_no_speech_threshold, - min_audio_length=min_audio_length, - max_audio_length=max_audio_length, - prefix_retention_length=prefix_retention_length, - vad_threshold=vad_threshold, - sampling_rate=SAMPLE_RATE) - - if use_faster_whisper: - print("Loading faster whisper model: {}".format(model)) - from faster_whisper import WhisperModel - model = WhisperModel(model) - elif not use_whisper_api: - print("Loading whisper model: {}".format(model)) - import whisper - model = whisper.load_model(model) - - translator = None - if gpt_translation_prompt and openai_api_key: + use_faster_whisper, use_whisper_api, whisper_filters, output_timestamps, + gpt_translation_prompt, gpt_translation_history_size, openai_api_key, + gpt_model, gpt_translation_timeout, cqhttp_url, cqhttp_token, **transcribe_options): + + if openai_api_key: + os.environ['OPENAI_API_KEY'] = openai_api_key + + # Reverse order initialization + result_exporter = ResultExporter(output_timestamps, cqhttp_url, cqhttp_token) + gpt_translator = None + if gpt_translation_prompt: if gpt_translation_history_size == 0: - translator = ParallelTranslator(openai_api_key=openai_api_key, - prompt=gpt_translation_prompt, - model=gpt_model, - timeout=gpt_translation_timeout) + gpt_translator = ParallelTranslator(prompt=gpt_translation_prompt, + model=gpt_model, + timeout=gpt_translation_timeout) else: - translator = SerialTranslator(openai_api_key=openai_api_key, - prompt=gpt_translation_prompt, - model=gpt_model, - timeout=gpt_translation_timeout, - history_size=gpt_translation_history_size) - - print("Opening stream...") - ffmpeg_process, ytdlp_process = open_stream(url, direct_url, format, cookies) - - def handler(signum, frame): - ffmpeg_process.kill() - if ytdlp_process: - ytdlp_process.kill() - sys.exit(0) - - signal.signal(signal.SIGINT, handler) - - while 
ffmpeg_process.poll() is None: - # Read audio from ffmpeg stream - in_bytes = ffmpeg_process.stdout.read(n_bytes) - if not in_bytes: - break - - audio = np.frombuffer(in_bytes, np.int16).flatten().astype(np.float32) / 32768.0 - stream_slicer.put(audio) - - if stream_slicer.should_slice(): - # Decode the audio - sliced_audio, time_range = stream_slicer.slice() - history_audio_buffer.append(sliced_audio) - clear_buffers = False - if use_faster_whisper: - segments, info = model.transcribe(sliced_audio, language=language, **decode_options) - decoded_text = "" - previous_segment = "" - for segment in segments: - if segment.text != previous_segment: - decoded_text += segment.text - previous_segment = segment.text - - new_prefix = decoded_text - - elif use_whisper_api: - with tempfile.NamedTemporaryFile(mode='wb+', suffix='.wav') as audio_file: - write_audio(audio_file, SAMPLE_RATE, sliced_audio) - decoded_text = whisper_transcribe(audio_file, openai_api_key) - new_prefix = decoded_text - - else: - result = model.transcribe(np.concatenate(history_audio_buffer.get_all()), - prefix="".join(history_text_buffer.get_all()), - language=language, - without_timestamps=True, - **decode_options) - - decoded_text = result.get("text") - new_prefix = "" - for segment in result["segments"]: - if segment["temperature"] < 0.5 and segment["no_speech_prob"] < 0.6: - new_prefix += segment["text"] - else: - # Clear history if the translation is unreliable, otherwise prompting on this leads to - # repetition and getting stuck. - clear_buffers = True - - history_text_buffer.append(new_prefix) - - if clear_buffers or history_text_buffer.has_repetition(): - history_audio_buffer.clear() - history_text_buffer.clear() - - decoded_text = filter_text(decoded_text, whisper_filters) - if decoded_text.strip(): - timestamp_text = '{}-{} '.format(sec2str(time_range[0]), sec2str( - time_range[1])) if output_timestamps else '' - print('{}{}'.format(timestamp_text, decoded_text)) - if translator: - translation_task = TranslationTask(decoded_text, time_range) - translator.put(translation_task) - elif cqhttp_url: - send_to_cqhttp(cqhttp_url, cqhttp_token, decoded_text) - else: - print('skip...') - - if translator: - for task in translator.get_results(): - if cqhttp_url: - timestamp_text = '{}-{}\n'.format(sec2str( - task.time_range[0]), sec2str( - task.time_range[1])) if output_timestamps else '' - if task.output_text: - send_to_cqhttp( - cqhttp_url, cqhttp_token, - '{}{}\n{}'.format(timestamp_text, task.input_text, task.output_text)) - else: - send_to_cqhttp(cqhttp_url, cqhttp_token, - '{}{}'.format(timestamp_text, task.input_text)) - if task.output_text: - timestamp_text = '{}-{} '.format(sec2str( - task.time_range[0]), sec2str( - task.time_range[1])) if output_timestamps else '' - print('\033[1m{}{}\033[0m'.format(timestamp_text, task.output_text)) - + gpt_translator = SerialTranslator(prompt=gpt_translation_prompt, + model=gpt_model, + timeout=gpt_translation_timeout, + history_size=gpt_translation_history_size) + if use_faster_whisper: + audio_transcriber = FasterWhisper(model) + elif use_whisper_api: + audio_transcriber = RemoteOpenaiWhisper() + else: + audio_transcriber = OpenaiWhisper(model) + audio_slicer = AudioSlicer(frame_duration, continuous_no_speech_threshold, min_audio_length, max_audio_length, prefix_retention_length, vad_threshold) + audio_getter = StreamAudioGetter(url, direct_url, format, cookies, frame_duration) + + getter_to_slicer_queue = queue.SimpleQueue() + slicer_to_transcriber_queue = queue.SimpleQueue() + 
transcriber_to_translator_queue = queue.SimpleQueue() + translator_to_exporter_queue = queue.SimpleQueue() if gpt_translator else transcriber_to_translator_queue + + _start_daemon_thread(result_exporter.work, translator_to_exporter_queue) + if gpt_translator: + _start_daemon_thread(gpt_translator.work, transcriber_to_translator_queue, translator_to_exporter_queue) + _start_daemon_thread(audio_transcriber.work, slicer_to_transcriber_queue, transcriber_to_translator_queue, whisper_filters, **transcribe_options) + _start_daemon_thread(audio_slicer.work, getter_to_slicer_queue, slicer_to_transcriber_queue) + audio_getter.work(output_queue=getter_to_slicer_queue) + + # Wait for others process finish. + while (not getter_to_slicer_queue.empty() or not slicer_to_transcriber_queue.empty() or not transcriber_to_translator_queue.empty() or not translator_to_exporter_queue.empty()): + time.sleep(5) print("Stream ended") - ffmpeg_process.kill() - if ytdlp_process: - ytdlp_process.kill() - def cli(): parser = argparse.ArgumentParser(description="Parameters for translator.py") @@ -398,13 +135,6 @@ def cli(): help='Language spoken in the stream. ' 'Default option is to auto detect the spoken language. ' 'See https://github.com/openai/whisper for available languages.') - parser.add_argument('--history_buffer_size', - type=int, - default=0, - help='Times of previous audio/text to use for conditioning the model. ' - 'Set to 0 to just use audio from the last processing. ' - 'Note that this can easily lead to repetition/loops if the chosen ' - 'language/model settings do not produce good results to begin with.') parser.add_argument('--beam_size', type=int, default=5, @@ -489,7 +219,7 @@ def cli(): if (args['use_whisper_api'] or args['gpt_translation_prompt']) and not args['openai_api_key']: print("Please fill in the OpenAI API key when enabling GPT translation or Whisper API") sys.exit(0) - + if args['language'] == 'auto': args['language'] = None diff --git a/vad.py b/vad.py deleted file mode 100644 index f425131..0000000 --- a/vad.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -import warnings - -warnings.filterwarnings("ignore") - - -class VAD: - - def __init__(self): - self.model = init_jit_model("silero_vad.jit") - - def is_speech(self, audio, threshold: float = 0.5, sampling_rate: int = 16000): - if not torch.is_tensor(audio): - try: - audio = torch.Tensor(audio) - except: - raise TypeError("Audio cannot be casted to tensor. Cast it manually") - speech_prob = self.model(audio, sampling_rate).item() - return speech_prob >= threshold - - def reset_states(self): - self.model.reset_states() - - -def init_jit_model(model_path: str, device=torch.device('cpu')): - torch.set_grad_enabled(False) - model = torch.jit.load(model_path, map_location=device) - model.eval() - return model
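
Reviewer note (illustration only, not part of the patch): after this refactor the pipeline is a chain of workers connected by queue.SimpleQueue objects, with every stage except the audio getter running in a daemon thread. Below is a minimal wiring sketch that mirrors the new main() in translator.py but drops the GPT translation stage; the stream URL, model size, filter name, and the numeric parameters are placeholder assumptions, not the CLI defaults, and start_daemon is a local stand-in for the patch's _start_daemon_thread helper.

    import queue
    import threading

    from audio_getter import StreamAudioGetter
    from audio_slicer import AudioSlicer            # needs silero_vad.jit in the working directory
    from audio_transcriber import OpenaiWhisper
    from result_exporter import ResultExporter

    def start_daemon(func, *args, **kwargs):
        # Same pattern as _start_daemon_thread() in the patched translator.py.
        thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
        thread.start()

    raw_audio = queue.SimpleQueue()      # StreamAudioGetter -> AudioSlicer (np.array frames)
    sliced_audio = queue.SimpleQueue()   # AudioSlicer -> whisper (TranslationTask)
    transcripts = queue.SimpleQueue()    # whisper -> ResultExporter (no translator in this sketch)

    exporter = ResultExporter(output_timestamps=True, cqhttp_url=None, cqhttp_token=None)
    transcriber = OpenaiWhisper("small")                      # assumed model size
    slicer = AudioSlicer(frame_duration=0.1,                  # placeholder tuning values
                         continuous_no_speech_threshold=0.8,
                         min_audio_length=3.0,
                         max_audio_length=30.0,
                         prefix_retention_length=0.8,
                         vad_threshold=0.5)
    getter = StreamAudioGetter("https://www.twitch.tv/forsen", direct_url=False,
                               format="bestaudio", cookies=None, frame_duration=0.1)

    # Consumers first, producer last; the getter blocks the main thread until the stream ends.
    start_daemon(exporter.work, transcripts)
    start_daemon(transcriber.work, sliced_audio, transcripts, "emoji_filter", language=None)
    start_daemon(slicer.work, raw_audio, sliced_audio)
    getter.work(output_queue=raw_audio)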