diff --git a/audio_getter.py b/audio_getter.py
index a771964..ee75199 100644
--- a/audio_getter.py
+++ b/audio_getter.py
@@ -6,6 +6,7 @@
 import ffmpeg
 import numpy as np
+import sounddevice as sd
 
 from common import SAMPLE_RATE
 
 
@@ -90,7 +91,6 @@ def work(self, output_queue: queue.SimpleQueue[np.array]):
 class DeviceAudioGetter():
 
     def __init__(self, device_index: int, frame_duration: float) -> None:
-        import sounddevice as sd
         if device_index:
             sd.default.device[0] = device_index
         sd.default.dtype[0] = np.float32
diff --git a/audio_slicer.py b/audio_slicer.py
index 5a24c5b..bd372ac 100644
--- a/audio_slicer.py
+++ b/audio_slicer.py
@@ -21,7 +21,7 @@ class VAD:
     def __init__(self):
         self.model = _init_jit_model("silero_vad.jit")
 
-    def is_speech(self, audio, threshold: float = 0.5, sampling_rate: int = 16000):
+    def is_speech(self, audio: np.array, threshold: float = 0.5, sampling_rate: int = 16000):
         if not torch.is_tensor(audio):
             try:
                 audio = torch.Tensor(audio)
@@ -55,7 +55,7 @@ def __init__(self, frame_duration: float, continuous_no_speech_threshold: float,
         self.counter = 0
         self.last_slice_second = 0.0
 
-    def put(self, audio):
+    def put(self, audio: np.array):
         self.counter += 1
         if self.vad.is_speech(audio, self.vad_threshold, self.sampling_rate):
             self.audio_buffer.append(audio)
diff --git a/audio_transcriber.py b/audio_transcriber.py
index 13cceaa..3c31307 100644
--- a/audio_transcriber.py
+++ b/audio_transcriber.py
@@ -3,6 +3,8 @@
 from scipy.io.wavfile import write as write_audio
 import numpy as np
+import whisper
+from faster_whisper import WhisperModel
 from openai import OpenAI
 
 import filters
 
@@ -11,7 +13,7 @@
 TEMP_AUDIO_FILE_NAME = 'temp.wav'
 
 
-def _filter_text(text, whisper_filters):
+def _filter_text(text: str, whisper_filters: str):
     filter_name_list = whisper_filters.split(',')
     for filter_name in filter_name_list:
         filter = getattr(filters, filter_name)
@@ -23,9 +25,8 @@ def _filter_text(text, whisper_filters):
 
 class OpenaiWhisper():
 
-    def __init__(self, model) -> None:
+    def __init__(self, model: str) -> None:
         print("Loading whisper model: {}".format(model))
-        import whisper
         self.model = whisper.load_model(model)
 
     def transcribe(self, audio: np.array, **transcribe_options) -> str:
@@ -48,9 +49,8 @@ def work(self, input_queue: queue.SimpleQueue[TranslationTask],
 
 class FasterWhisper(OpenaiWhisper):
 
-    def __init__(self, model) -> None:
+    def __init__(self, model: str) -> None:
         print("Loading faster-whisper model: {}".format(model))
-        from faster_whisper import WhisperModel
         self.model = WhisperModel(model)
 
     def transcribe(self, audio: np.array, **transcribe_options) -> str:
diff --git a/filters.py b/filters.py
index 1ae4626..78ffceb 100644
--- a/filters.py
+++ b/filters.py
@@ -1,13 +1,13 @@
 import re
 
 
-def emoji_filter(text):
+def emoji_filter(text: str):
     return re.sub(
         r'[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF]+',
         '', text)
 
 
-def japanese_stream_filter(text):
+def japanese_stream_filter(text: str):
     for filter_pattern in [
             r'【.+】', r'ご視聴ありがとうございました', r'チャンネル登録をお願いいたします', r'ご視聴いただきありがとうございます',
             r'チャンネル登録してね', r'字幕視聴ありがとうございました', r'動画をご覧頂きましてありがとうございました', r'次の動画でお会いしましょう', r'最後までご視聴頂きありがとうございました',
diff --git a/gpt_translator.py b/gpt_translator.py
index 6f6bab7..acfa820 100644
--- a/gpt_translator.py
+++ b/gpt_translator.py
@@ -9,7 +9,7 @@
 from common import TranslationTask
 
 
-def _translate_by_gpt(client, translation_task, assistant_prompt, model, history_messages=[]):
+def _translate_by_gpt(client: OpenAI, translation_task: TranslationTask, assistant_prompt: str, model: str, history_messages: list = []):
     # https://platform.openai.com/docs/api-reference/chat/create?lang=python
     try:
         system_prompt = "You are a translation engine."
@@ -33,14 +33,14 @@ def _translate_by_gpt(client, translation_task, assistant_prompt, model, history
 
 class ParallelTranslator():
 
-    def __init__(self, prompt, model, timeout):
+    def __init__(self, prompt: str, model: str, timeout: int):
         self.prompt = prompt
         self.model = model
         self.timeout = timeout
         self.client = OpenAI()
         self.processing_queue = deque()
 
-    def trigger(self, translation_task):
+    def trigger(self, translation_task: TranslationTask):
         self.processing_queue.append(translation_task)
         translation_task.start_time = datetime.utcnow()
         thread = threading.Thread(target=_translate_by_gpt,
@@ -73,7 +73,7 @@ def work(self, input_queue: queue.SimpleQueue[TranslationTask],
 
 class SerialTranslator():
 
-    def __init__(self, prompt, model, timeout, history_size):
+    def __init__(self, prompt: str, model: str, timeout: int, history_size: int):
         self.prompt = prompt
         self.model = model
         self.timeout = timeout
diff --git a/result_exporter.py b/result_exporter.py
index 654566e..3f381e2 100644
--- a/result_exporter.py
+++ b/result_exporter.py
@@ -11,7 +11,7 @@ def _send_to_cqhttp(url: str, token: str, text: str):
     requests.post(url, headers=headers, data=data)
 
 
-def _sec2str(second):
+def _sec2str(second: float):
     dt = datetime.utcfromtimestamp(second)
     result = dt.strftime('%H:%M:%S')
     result += ',' + str(round(second * 10 % 10))