
Commit

format
ionic-bond committed Dec 27, 2023
1 parent 3fa3631 commit b567e08
Showing 6 changed files with 15 additions and 15 deletions.
audio_getter.py (2 changes: 1 addition & 1 deletion)
@@ -6,6 +6,7 @@
 
 import ffmpeg
 import numpy as np
+import sounddevice as sd
 
 from common import SAMPLE_RATE
 
@@ -90,7 +91,6 @@ def work(self, output_queue: queue.SimpleQueue[np.array]):
 class DeviceAudioGetter():
 
     def __init__(self, device_index: int, frame_duration: float) -> None:
-        import sounddevice as sd
         if device_index:
             sd.default.device[0] = device_index
             sd.default.dtype[0] = np.float32
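The change above hoists the sounddevice import to module level, so the constructor in the second hunk can use it directly. A minimal sketch of that device-configuration pattern; the 16 kHz rate and the record_frame helper are assumptions for illustration, not part of this commit:

import numpy as np
import sounddevice as sd

SAMPLE_RATE = 16000  # stand-in for the constant imported from common; 16000 matches the default in audio_slicer.py

def configure_input_device(device_index: int) -> None:
    # Mirrors the constructor logic in the hunk above: override the default
    # input device only when an explicit index is given, and capture float32 samples.
    if device_index:
        sd.default.device[0] = device_index
    sd.default.dtype[0] = np.float32

def record_frame(frame_duration: float) -> np.ndarray:
    # Illustrative helper: record one frame at the assumed sample rate.
    frames = int(frame_duration * SAMPLE_RATE)
    return sd.rec(frames, samplerate=SAMPLE_RATE, channels=1, blocking=True).flatten()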
audio_slicer.py (4 changes: 2 additions & 2 deletions)
@@ -21,7 +21,7 @@ class VAD:
     def __init__(self):
         self.model = _init_jit_model("silero_vad.jit")
 
-    def is_speech(self, audio, threshold: float = 0.5, sampling_rate: int = 16000):
+    def is_speech(self, audio: np.array, threshold: float = 0.5, sampling_rate: int = 16000):
         if not torch.is_tensor(audio):
             try:
                 audio = torch.Tensor(audio)
@@ -55,7 +55,7 @@ def __init__(self, frame_duration: float, continuous_no_speech_threshold: float,
         self.counter = 0
         self.last_slice_second = 0.0
 
-    def put(self, audio):
+    def put(self, audio: np.array):
         self.counter += 1
         if self.vad.is_speech(audio, self.vad_threshold, self.sampling_rate):
             self.audio_buffer.append(audio)
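As a quick illustration of the annotated is_speech signature, the sketch below feeds a silent numpy buffer through the VAD; the import path and buffer contents are assumptions, not part of this commit:

import numpy as np
from audio_slicer import VAD  # assumes audio_slicer.py is importable as-is

vad = VAD()  # loads silero_vad.jit, as in the constructor shown above
frame = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
# threshold and sampling_rate match the defaults in the new signature;
# pure silence should score well below 0.5, so this is expected to print False.
print(vad.is_speech(frame, threshold=0.5, sampling_rate=16000))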
audio_transcriber.py (10 changes: 5 additions & 5 deletions)
@@ -3,6 +3,8 @@
 from scipy.io.wavfile import write as write_audio
 
 import numpy as np
+import whisper
+from faster_whisper import WhisperModel
 from openai import OpenAI
 
 import filters
@@ -11,7 +13,7 @@
 TEMP_AUDIO_FILE_NAME = 'temp.wav'
 
 
-def _filter_text(text, whisper_filters):
+def _filter_text(text: str, whisper_filters: str):
     filter_name_list = whisper_filters.split(',')
     for filter_name in filter_name_list:
         filter = getattr(filters, filter_name)
@@ -23,9 +25,8 @@ def _filter_text(text, whisper_filters):
 
 class OpenaiWhisper():
 
-    def __init__(self, model) -> None:
+    def __init__(self, model: str) -> None:
         print("Loading whisper model: {}".format(model))
-        import whisper
         self.model = whisper.load_model(model)
 
     def transcribe(self, audio: np.array, **transcribe_options) -> str:
@@ -48,9 +49,8 @@ def work(self, input_queue: queue.SimpleQueue[TranslationTask],
 
 class FasterWhisper(OpenaiWhisper):
 
-    def __init__(self, model) -> None:
+    def __init__(self, model: str) -> None:
         print("Loading faster-whisper model: {}".format(model))
-        from faster_whisper import WhisperModel
         self.model = WhisperModel(model)
 
     def transcribe(self, audio: np.array, **transcribe_options) -> str:
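With whisper and faster_whisper now imported at module level, both backends share the constructor and transcribe signature shown above. A hedged usage sketch; the model name, audio buffer, and language option are placeholders rather than values from this commit:

import numpy as np
from audio_transcriber import OpenaiWhisper, FasterWhisper  # assumes the module is importable

audio = np.zeros(16000, dtype=np.float32)  # placeholder buffer: 1 s of silence at 16 kHz

# Either backend takes a model name string; FasterWhisper reuses OpenaiWhisper's interface.
transcriber = FasterWhisper("small")  # or OpenaiWhisper("small")
text = transcriber.transcribe(audio, language="ja")  # extra options are forwarded as **transcribe_options
print(text)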
filters.py (4 changes: 2 additions & 2 deletions)
@@ -1,13 +1,13 @@
 import re
 
 
-def emoji_filter(text):
+def emoji_filter(text: str):
     return re.sub(
         r'[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF]+',
         '', text)
 
 
-def japanese_stream_filter(text):
+def japanese_stream_filter(text: str):
     for filter_pattern in [
             r'【.+】', r'ご視聴ありがとうございました', r'チャンネル登録をお願いいたします', r'ご視聴いただきありがとうございます', r'チャンネル登録してね',
             r'字幕視聴ありがとうございました', r'動画をご覧頂きましてありがとうございました', r'次の動画でお会いしましょう', r'最後までご視聴頂きありがとうございました',
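For reference, these filters are the callables that _filter_text in audio_transcriber.py looks up by name via getattr. A small usage sketch; the sample strings are illustrative:

from filters import emoji_filter, japanese_stream_filter  # assumes filters.py is importable

print(emoji_filter("hello 😀🚀 world"))  # emoji in the listed Unicode ranges are stripped
# The visible hunk suggests common stream sign-off phrases are filtered out entirely.
print(japanese_stream_filter("ご視聴ありがとうございました"))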
gpt_translator.py (8 changes: 4 additions & 4 deletions)
@@ -9,7 +9,7 @@
 from common import TranslationTask
 
 
-def _translate_by_gpt(client, translation_task, assistant_prompt, model, history_messages=[]):
+def _translate_by_gpt(client: OpenAI, translation_task: TranslationTask, assistant_prompt: str, model: str, history_messages: list=[]):
     # https://platform.openai.com/docs/api-reference/chat/create?lang=python
     try:
         system_prompt = "You are a translation engine."
@@ -33,14 +33,14 @@ def _translate_by_gpt(client, translation_task, assistant_prompt, model, history
 
 class ParallelTranslator():
 
-    def __init__(self, prompt, model, timeout):
+    def __init__(self, prompt: str, model: str, timeout: int):
         self.prompt = prompt
         self.model = model
         self.timeout = timeout
         self.client = OpenAI()
         self.processing_queue = deque()
 
-    def trigger(self, translation_task):
+    def trigger(self, translation_task: TranslationTask):
         self.processing_queue.append(translation_task)
         translation_task.start_time = datetime.utcnow()
         thread = threading.Thread(target=_translate_by_gpt,
@@ -73,7 +73,7 @@ def work(self, input_queue: queue.SimpleQueue[TranslationTask],
 
 class SerialTranslator():
 
-    def __init__(self, prompt, model, timeout, history_size):
+    def __init__(self, prompt: str, model: str, timeout: int, history_size: int):
         self.prompt = prompt
         self.model = model
         self.timeout = timeout
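A hedged sketch of driving ParallelTranslator through the annotated signatures above; the prompt, model name, and timeout are placeholders, and TranslationTask's constructor is not visible in this diff:

from common import TranslationTask             # imported in the hunk above
from gpt_translator import ParallelTranslator  # assumes the module is importable

# OPENAI_API_KEY must be set for the OpenAI() client created in __init__.
translator = ParallelTranslator(prompt="Translate the Japanese text into English.",
                                model="gpt-3.5-turbo",
                                timeout=10)

# Build a TranslationTask per common.py (its fields are not shown in this diff), then:
# translator.trigger(task)  # timestamps the task and runs _translate_by_gpt on a thread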
result_exporter.py (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ def _send_to_cqhttp(url: str, token: str, text: str):
     requests.post(url, headers=headers, data=data)
 
 
-def _sec2str(second):
+def _sec2str(second: float):
     dt = datetime.utcfromtimestamp(second)
     result = dt.strftime('%H:%M:%S')
     result += ',' + str(round(second * 10 % 10))
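The visible lines of _sec2str build an SRT-style timestamp with a tenths-of-a-second suffix. A self-contained reproduction of those lines, with the return statement assumed since it falls below the hunk:

from datetime import datetime

def _sec2str(second: float):
    dt = datetime.utcfromtimestamp(second)
    result = dt.strftime('%H:%M:%S')
    result += ',' + str(round(second * 10 % 10))
    return result  # assumed; the return is outside the visible hunk

print(_sec2str(75.3))  # -> '00:01:15,3'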
