Refactor: multi-thread optimization
ionic-bond committed Dec 21, 2023
1 parent d3bad77 commit ba34cdd
Showing 12 changed files with 481 additions and 488 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -129,3 +129,4 @@ dmypy.json
.pyre/

--Frag*
*.wav
1 change: 0 additions & 1 deletion README.md
@@ -40,7 +40,6 @@ By default, the URL can be of the form ```twitch.tv/forsen``` and yt-dlp is used
| `--model` | small | Select model size. See [here](https://github.com/openai/whisper#available-models-and-languages) for available models. |
| `--task` | translate | Whether to transcribe the audio (keep original language) or translate to English. |
| `--language` | auto | Language spoken in the stream. See [here](https://github.com/openai/whisper#available-models-and-languages) for available languages. |
| `--history_buffer_size` | 0 | Number of previous audio/text segments to use for conditioning the model. Set to 0 to use only the audio from the current processing step. Note that this can easily lead to repetition/loops if the chosen language/model settings do not produce good results to begin with. |
| `--beam_size` | 5 | Number of beams in beam search. Set to 0 to use greedy algorithm instead (faster but less accurate). |
| `--best_of` | 5 | Number of candidates when sampling with non-zero temperature. |
| `--direct_url` | | Set this flag to pass the URL directly to ffmpeg. Otherwise, yt-dlp is used to obtain the stream URL. |
86 changes: 86 additions & 0 deletions audio_getter.py
@@ -0,0 +1,86 @@
import queue
import signal
import subprocess
import sys
import threading

import ffmpeg
import numpy as np

from common import SAMPLE_RATE


def _transport(ytdlp_proc, ffmpeg_proc):
    while (ytdlp_proc.poll() is None) and (ffmpeg_proc.poll() is None):
        try:
            chunk = ytdlp_proc.stdout.read(1024)
            if not chunk:  # EOF from yt-dlp: stop instead of spinning on empty reads
                break
            ffmpeg_proc.stdin.write(chunk)
        except (BrokenPipeError, OSError):
            break
ytdlp_proc.kill()
ffmpeg_proc.kill()


def _open_stream(url: str, direct_url: bool, format: str, cookies: str):
if direct_url:
try:
process = (ffmpeg.input(
url, loglevel="panic").output("pipe:",
format="s16le",
acodec="pcm_s16le",
ac=1,
ar=SAMPLE_RATE).run_async(pipe_stdout=True))
except ffmpeg.Error as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

return process, None

cmd = ['yt-dlp', url, '-f', format, '-o', '-', '-q']
if cookies:
cmd.extend(['--cookies', cookies])
ytdlp_process = subprocess.Popen(cmd, stdout=subprocess.PIPE)

try:
ffmpeg_process = (ffmpeg.input("pipe:", loglevel="panic").output("pipe:",
format="s16le",
acodec="pcm_s16le",
ac=1,
ar=SAMPLE_RATE).run_async(
pipe_stdin=True,
pipe_stdout=True))
except ffmpeg.Error as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

thread = threading.Thread(target=_transport, args=(ytdlp_process, ffmpeg_process))
thread.start()
return ffmpeg_process, ytdlp_process


class StreamAudioGetter():

def __init__(self, url: str, direct_url: bool, format: str, cookies: str, frame_duration: float):
print("Opening stream {}".format(url))
self.ffmpeg_process, self.ytdlp_process = _open_stream(url, direct_url, format, cookies)
self.byte_size = round(frame_duration * SAMPLE_RATE * 2) # Factor 2 comes from reading the int16 stream as bytes
signal.signal(signal.SIGINT, self._exit_handler)

def _exit_handler(self, signum, frame):
self.ffmpeg_process.kill()
if self.ytdlp_process:
self.ytdlp_process.kill()
sys.exit(0)

    def work(self, output_queue: queue.SimpleQueue[np.ndarray]):
while self.ffmpeg_process.poll() is None:
in_bytes = self.ffmpeg_process.stdout.read(self.byte_size)
if not in_bytes:
break
if len(in_bytes) != self.byte_size:
continue
audio = np.frombuffer(in_bytes, np.int16).flatten().astype(np.float32) / 32768.0
output_queue.put(audio)

self.ffmpeg_process.kill()
if self.ytdlp_process:
self.ytdlp_process.kill()
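
A quick worked example of the byte_size math above (a minimal sketch; the 0.1 s frame_duration is an assumed value, not fixed by this commit):

# Assuming frame_duration = 0.1 s, with SAMPLE_RATE = 16000 Hz and 2 bytes per int16 sample:
byte_size = round(0.1 * 16000 * 2)  # = 3200 bytes, i.e. 1600 samples: one 100 ms frame per queue item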

103 changes: 103 additions & 0 deletions audio_slicer.py
@@ -0,0 +1,103 @@
import queue
import warnings

import numpy as np
import torch

from common import TranslationTask, SAMPLE_RATE

warnings.filterwarnings("ignore")


def _init_jit_model(model_path: str, device=torch.device('cpu')):
torch.set_grad_enabled(False)
model = torch.jit.load(model_path, map_location=device)
model.eval()
return model


class VAD:

def __init__(self):
self.model = _init_jit_model("silero_vad.jit")

def is_speech(self, audio, threshold: float = 0.5, sampling_rate: int = 16000):
if not torch.is_tensor(audio):
try:
audio = torch.Tensor(audio)
            except Exception as e:
                raise TypeError("Audio cannot be cast to tensor. Cast it manually") from e
speech_prob = self.model(audio, sampling_rate).item()
return speech_prob >= threshold

def reset_states(self):
self.model.reset_states()


class AudioSlicer:

def __init__(self, frame_duration: float, continuous_no_speech_threshold: float, min_audio_length: float,
max_audio_length: float, prefix_retention_length: float, vad_threshold: float):
self.vad = VAD()
self.continuous_no_speech_threshold = round(continuous_no_speech_threshold / frame_duration)
self.min_audio_length = round(min_audio_length / frame_duration)
self.max_audio_length = round(max_audio_length / frame_duration)
self.prefix_retention_length = round(prefix_retention_length / frame_duration)
self.vad_threshold = vad_threshold
self.sampling_rate = SAMPLE_RATE
self.audio_buffer = []
self.prefix_audio_buffer = []
self.speech_count = 0
self.no_speech_count = 0
self.continuous_no_speech_count = 0
self.frame_duration = frame_duration
self.counter = 0
self.last_slice_second = 0.0

def put(self, audio):
self.counter += 1
if self.vad.is_speech(audio, self.vad_threshold, self.sampling_rate):
self.audio_buffer.append(audio)
self.speech_count += 1
self.continuous_no_speech_count = 0
        else:
            if self.speech_count == 0 and self.no_speech_count == 1:
                # Nothing but silence buffered so far: flush it, keeping only the retained prefix
                self.slice()
            self.audio_buffer.append(audio)
            self.no_speech_count += 1
            self.continuous_no_speech_count += 1
        if self.speech_count and self.no_speech_count / 4 > self.speech_count:
            # Accumulated silence outweighs speech four to one: cut the segment early
            self.slice()

def should_slice(self):
audio_len = len(self.audio_buffer)
if audio_len < self.min_audio_length:
return False
if audio_len > self.max_audio_length:
return True
if self.continuous_no_speech_count >= self.continuous_no_speech_threshold:
return True
return False

def slice(self):
concatenate_buffer = self.prefix_audio_buffer + self.audio_buffer
concatenate_audio = np.concatenate(concatenate_buffer)
self.audio_buffer = []
self.prefix_audio_buffer = concatenate_buffer[-self.prefix_retention_length:]
self.speech_count = 0
self.no_speech_count = 0
self.continuous_no_speech_count = 0
# self.vad.reset_states()
slice_second = self.counter * self.frame_duration
last_slice_second = self.last_slice_second
self.last_slice_second = slice_second
return concatenate_audio, (last_slice_second, slice_second)

    def work(self, input_queue: queue.SimpleQueue[np.ndarray], output_queue: queue.SimpleQueue[TranslationTask]):
while True:
audio = input_queue.get()
self.put(audio)
if self.should_slice():
sliced_audio, time_range = self.slice()
task = TranslationTask(sliced_audio, time_range)
output_queue.put(task)
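
To make the frame-count conversions concrete, here is a construction sketch (all threshold values are assumptions for illustration, not defaults from this commit):

from audio_slicer import AudioSlicer

# Instantiating AudioSlicer loads silero_vad.jit from the working directory.
slicer = AudioSlicer(frame_duration=0.1,
                     continuous_no_speech_threshold=0.8,  # -> 8 frames
                     min_audio_length=3.0,                # -> 30 frames
                     max_audio_length=30.0,               # -> 300 frames
                     prefix_retention_length=0.8,         # -> 8 frames carried into the next slice
                     vad_threshold=0.5)
# should_slice() then fires once the buffer exceeds 300 frames, or once it holds
# at least 30 frames and the last 8 frames in a row were judged non-speech.
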
72 changes: 72 additions & 0 deletions audio_transcriber.py
@@ -0,0 +1,72 @@
import os
import queue

import numpy as np
from openai import OpenAI
from scipy.io.wavfile import write as write_audio

import filters
from common import TranslationTask, SAMPLE_RATE

TEMP_AUDIO_FILE_NAME = 'temp.wav'


def _filter_text(text, whisper_filters):
    filter_name_list = whisper_filters.split(',')
    for filter_name in filter_name_list:
        # Default of None is needed: a plain getattr would raise AttributeError before the check below
        filter_func = getattr(filters, filter_name, None)
        if not filter_func:
            raise Exception('Unknown filter: %s' % filter_name)
        text = filter_func(text)
    return text


class OpenaiWhisper():

def __init__(self, model) -> None:
print("Loading whisper model: {}".format(model))
import whisper
self.model = whisper.load_model(model)

    def transcribe(self, audio: np.ndarray, **transcribe_options) -> str:
result = self.model.transcribe(audio,
without_timestamps=True,
**transcribe_options)
return result.get("text")

def work(self, input_queue: queue.SimpleQueue[TranslationTask], output_queue: queue.SimpleQueue[TranslationTask], whisper_filters, **transcribe_options):
while True:
task = input_queue.get()
task.transcribed_text = _filter_text(self.transcribe(task.audio, **transcribe_options), whisper_filters)
print(task.transcribed_text)
output_queue.put(task)


class FasterWhisper(OpenaiWhisper):

def __init__(self, model) -> None:
print("Loading faster-whisper model: {}".format(model))
from faster_whisper import WhisperModel
self.model = WhisperModel(model)

    def transcribe(self, audio: np.ndarray, **transcribe_options) -> str:
segments, info = self.model.transcribe(audio, **transcribe_options)
transcribed_text = ""
for segment in segments:
transcribed_text += segment.text
return transcribed_text


class RemoteOpenaiWhisper(OpenaiWhisper):
# https://platform.openai.com/docs/api-reference/audio/createTranscription?lang=python

def __init__(self) -> None:
self.client = OpenAI()

    def transcribe(self, audio: np.ndarray, **transcribe_options) -> str:
        # The API expects an audio file, so dump the float32 buffer to a temporary WAV first
        with open(TEMP_AUDIO_FILE_NAME, 'wb') as audio_file:
            write_audio(audio_file, SAMPLE_RATE, audio)
with open(TEMP_AUDIO_FILE_NAME, 'rb') as audio_file:
result = self.client.audio.transcriptions.create(model="whisper-1", file=audio_file, language=transcribe_options['language']).text
os.remove(TEMP_AUDIO_FILE_NAME)
return result
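
The three classes share one interface, so a caller can pick a backend at startup. A minimal selection sketch (the flag names here are hypothetical, not part of this diff):

from audio_transcriber import FasterWhisper, OpenaiWhisper, RemoteOpenaiWhisper

use_faster_whisper, use_whisper_api = True, False  # hypothetical flags for illustration
if use_faster_whisper:
    transcriber = FasterWhisper('small')
elif use_whisper_api:
    transcriber = RemoteOpenaiWhisper()  # requires OPENAI_API_KEY in the environment
else:
    transcriber = OpenaiWhisper('small')
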
12 changes: 12 additions & 0 deletions common.py
@@ -0,0 +1,12 @@
import numpy as np
from whisper.audio import SAMPLE_RATE


class TranslationTask:

    def __init__(self, audio: np.ndarray, time_range: tuple[float, float]):
self.audio = audio
self.transcribed_text = None
self.translated_text = None
self.time_range = time_range
self.start_time = None
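
Since every stage exposes a blocking work() loop over a SimpleQueue, the whole pipeline can run as a chain of threads. A condensed wiring sketch (queue names, parameter values, and the 'emoji_filter' name are assumptions; the actual main.py changes are not shown in this view):

import queue
import threading

from audio_getter import StreamAudioGetter
from audio_slicer import AudioSlicer
from audio_transcriber import FasterWhisper

frame_queue: queue.SimpleQueue = queue.SimpleQueue()  # raw PCM frames from the getter
task_queue: queue.SimpleQueue = queue.SimpleQueue()   # TranslationTask with sliced audio
text_queue: queue.SimpleQueue = queue.SimpleQueue()   # TranslationTask with transcribed_text

# Construct the getter on the main thread: its __init__ installs a SIGINT handler.
getter = StreamAudioGetter('https://twitch.tv/forsen', direct_url=False,
                           format='bestaudio', cookies=None, frame_duration=0.1)
slicer = AudioSlicer(0.1, 0.8, 3.0, 30.0, 0.8, 0.5)
transcriber = FasterWhisper('small')

threading.Thread(target=slicer.work, args=(frame_queue, task_queue), daemon=True).start()
# 'emoji_filter' assumes a matching function in filters.py; language is forwarded to the model.
threading.Thread(target=transcriber.work, args=(task_queue, text_queue, 'emoji_filter'),
                 kwargs={'language': 'ja'}, daemon=True).start()
# A result handler draining text_queue is omitted here; the getter blocks on the main thread.
getter.work(frame_queue)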