diff --git a/.gitignore b/.gitignore
index 760b3ea..9aaab14 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,7 +11,7 @@ nitro/
 
 # C extensions
 *.so
-
+*.wav
 # Distribution / packaging
 .Python
 build/
diff --git a/app.py b/app.py
index fb3d078..73a2f54 100644
--- a/app.py
+++ b/app.py
@@ -1,50 +1,82 @@
-import queue
+import time
 import threading
+import numpy as np
+import whisper
 import sounddevice as sd
-from core.common.const import console
+from queue import Queue
+from rich.console import Console
 
+console = Console()
+model = whisper.load_model("base.en")
 
 
-def listen_for_quit(stop_event):
-    console.print("[yellow]Start speaking! Press 'q' to quit.")
-    while True:
-        if console.read() == "q":
-            console.print("[yellow]Goodbye!")
-            stop_event.set()
-            break
+def transcribe(audio_np: np.ndarray) -> str:
+    result = model.transcribe(audio_np, fp16=False)  # Set fp16=True if using a GPU
+    text = result["text"].strip()
+    return text
 
 
-def conversation(stop_event):
-    q = queue.Queue()
+def record_audio(stop_event, data_queue):
     def callback(indata, frames, time, status):
         if status:
             console.print(status)
-        q.put(bytes(indata))
+        # Put the audio bytes into the queue
+        data_queue.put(bytes(indata))
 
-    with sd.RawInputStream(samplerate=16000, blocksize=8000, callback=callback):
-        console.print("[blue]Start speaking!")
+    # Start recording
+    with sd.RawInputStream(
+        samplerate=16000, dtype="int16", channels=1, callback=callback
+    ):
         while not stop_event.is_set():
-            audio = q.get()
-            console.print(f"[green]You said: {audio}")
+            # Small sleep to prevent this loop from consuming too much CPU
+            time.sleep(0.1)
 
 
 if __name__ == "__main__":
-    console.print("[blue]Welcome to local taking llm!")
+    console.print(
+        "[blue]Press Enter to start speaking. Press Enter again to stop and transcribe."
+    )
+
+    try:
+        while True:
+            # Wait for the user to press Enter to start recording
+            console.input("[blue]Press Enter to start recording...")
+            console.print("[yellow]Recording... Press Enter to stop.")
 
-    user_name = None
-    while not user_name:
-        user_name = console.input("[blue]Your name: ")
+            data_queue = Queue()  # type: ignore[var-annotated]
+            stop_event = threading.Event()
 
-    console.print(f"[cyan]Nice to see you, {user_name}!")
+            # Start recording in a background thread
+            recording_thread = threading.Thread(
+                target=record_audio,
+                args=(
+                    stop_event,
+                    data_queue,
+                ),
+            )
+            recording_thread.start()
+
+            # Wait for the user to press Enter to stop recording
+            input()  # No need to print a message as the previous message indicates to press Enter to stop
+            stop_event.set()
+            recording_thread.join()
 
-    stop_event = threading.Event()
-    conversation_thread = threading.Thread(target=conversation, args=(stop_event,))
-    quit_listener_thread = threading.Thread(target=listen_for_quit, args=(stop_event,))
+            # Combine audio data from queue
+            audio_data = b"".join(list(data_queue.queue))
+            audio_np = (
+                np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
+            )
 
-    conversation_thread.start()
-    quit_listener_thread.start()
+            # Transcribe the recorded audio
+            if audio_np.size > 0:  # Proceed if there's audio data
+                text = transcribe(audio_np)
+                console.print(f"[green]Transcription: {text}")
+            else:
+                console.print(
+                    "[red]No audio recorded. Please ensure your microphone is working."
+                )
 
-    conversation_thread.join()
-    quit_listener_thread.join()
+    except KeyboardInterrupt:
+        console.print("\n[red]Exiting...")
 
-    console.print("[blue]Session ended!")
+    console.print("[blue]Session ended.")
diff --git a/pyproject.toml b/pyproject.toml
index d70e69f..0267004 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,8 @@ torch = "^2.2.1"
 openai-whisper = {git = "https://github.com/openai/whisper.git"}
 sounddevice = "^0.4.6"
 suno-bark = {git = "https://github.com/suno-ai/bark.git"}
+speechrecognition = "^3.10.1"
+pyaudio = "^0.2.14"
 
 [tool.poetry.group.dev.dependencies]
diff --git a/stt.py b/stt.py
index 51c213b..bc41a3b 100644
--- a/stt.py
+++ b/stt.py
@@ -10,8 +10,9 @@ class SpeechToTextService:
-    def __init__(self):
-        self.model = whisper.load_model("base")
+    def __init__(self, device: str = "cpu"):
+        self.device = device
+        self.model = whisper.load_model("tiny.en")
         self.options = whisper.DecodingOptions()
 
     @staticmethod
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..31c4fee
--- /dev/null
+++ b/test.py
@@ -0,0 +1,173 @@
+#! python3.7
+
+import argparse
+import os
+import numpy as np
+import speech_recognition as sr
+import whisper
+import torch
+
+from datetime import datetime, timedelta
+from queue import Queue
+from time import sleep
+from sys import platform
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        default="medium",
+        help="Model to use",
+        choices=["tiny", "base", "small", "medium", "large"],
+    )
+    parser.add_argument(
+        "--non_english", action="store_true", help="Don't use the English model."
+    )
+    parser.add_argument(
+        "--energy_threshold",
+        default=1000,
+        help="Energy level for mic to detect.",
+        type=int,
+    )
+    parser.add_argument(
+        "--record_timeout",
+        default=2,
+        help="How real-time the recording is, in seconds.",
+        type=float,
+    )
+    parser.add_argument(
+        "--phrase_timeout",
+        default=3,
+        help="How much empty space between recordings before we "
+        "consider it a new line in the transcription.",
+        type=float,
+    )
+    if "linux" in platform:
+        parser.add_argument(
+            "--default_microphone",
+            default="pulse",
+            help="Default microphone name for SpeechRecognition. "
+            "Run this with 'list' to view available Microphones.",
+            type=str,
+        )
+    args = parser.parse_args()
+
+    # The last time a recording was retrieved from the queue.
+    phrase_time = None
+    # Thread-safe Queue for passing data from the threaded recording callback.
+    data_queue = Queue()
+    # We use SpeechRecognizer to record our audio because it has a nice feature where it can detect when speech ends.
+    recorder = sr.Recognizer()
+    recorder.energy_threshold = args.energy_threshold
+    # Definitely do this, dynamic energy compensation lowers the energy threshold dramatically to a point where the SpeechRecognizer never stops recording.
+    recorder.dynamic_energy_threshold = False
+
+    # Important for Linux users.
+    # Prevents a permanent application hang or crash caused by using the wrong microphone.
+    if "linux" in platform:
+        mic_name = args.default_microphone
+        if not mic_name or mic_name == "list":
+            print("Available microphone devices are: ")
+            for index, name in enumerate(sr.Microphone.list_microphone_names()):
+                print(f'Microphone with name "{name}" found')
+            return
+        else:
+            for index, name in enumerate(sr.Microphone.list_microphone_names()):
+                if mic_name in name:
+                    source = sr.Microphone(sample_rate=16000, device_index=index)
+                    break
+    else:
+        source = sr.Microphone(sample_rate=16000)
+
+    # Load / Download model
+    model = args.model
+    if args.model != "large" and not args.non_english:
+        model = model + ".en"
+    audio_model = whisper.load_model(model)
+
+    record_timeout = args.record_timeout
+    phrase_timeout = args.phrase_timeout
+
+    transcription = [""]
+
+    with source:
+        recorder.adjust_for_ambient_noise(source)
+
+    def record_callback(_, audio: sr.AudioData) -> None:
+        """
+        Threaded callback function to receive audio data when recordings finish.
+        audio: An AudioData containing the recorded bytes.
+        """
+        # Grab the raw bytes and push it into the thread safe queue.
+        data = audio.get_raw_data()
+        data_queue.put(data)
+
+    # Create a background thread that will pass us raw audio bytes.
+    # We could do this manually but SpeechRecognizer provides a nice helper.
+    recorder.listen_in_background(
+        source, record_callback, phrase_time_limit=record_timeout
+    )
+
+    # Cue the user that we're ready to go.
+    print("Model loaded.\n")
+
+    while True:
+        try:
+            now = datetime.utcnow()
+            # Pull raw recorded audio from the queue.
+            if not data_queue.empty():
+                phrase_complete = False
+                # If enough time has passed between recordings, consider the phrase complete.
+                # Clear the current working audio buffer to start over with the new data.
+                if phrase_time and now - phrase_time > timedelta(
+                    seconds=phrase_timeout
+                ):
+                    phrase_complete = True
+                # This is the last time we received new audio data from the queue.
+                phrase_time = now
+
+                # Combine audio data from queue
+                audio_data = b"".join(data_queue.queue)
+                data_queue.queue.clear()
+
+                # Convert the in-RAM buffer to something the model can use directly without needing a temp file.
+                # Convert data from 16-bit wide integers to floating point with a width of 32 bits,
+                # scaling by 32768 so the samples fall in the range [-1.0, 1.0].
+                audio_np = (
+                    np.frombuffer(audio_data, dtype=np.int16).astype(np.float32)
+                    / 32768.0
+                )
+
+                # Read the transcription.
+                result = audio_model.transcribe(
+                    audio_np, fp16=torch.cuda.is_available()
+                )
+                text = result["text"].strip()
+
+                # If we detected a pause between recordings, add a new item to our transcription.
+                # Otherwise edit the existing one.
+                if phrase_complete:
+                    transcription.append(text)
+                else:
+                    transcription[-1] = text
+
+                # Clear the console to reprint the updated transcription.
+                os.system("cls" if os.name == "nt" else "clear")
+                for line in transcription:
+                    print(line)
+                # Flush stdout.
+                print("", end="", flush=True)
+            else:
+                # Infinite loops are bad for processors, must sleep.
+                sleep(0.25)
+        except KeyboardInterrupt:
+            break
+
+    print("\n\nTranscription:")
+    for line in transcription:
+        print(line)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tts.py b/tts.py
index fe86bce..ccab19d 100644
--- a/tts.py
+++ b/tts.py
@@ -1,16 +1,36 @@
+import torch
 import scipy
+import warnings
 from transformers import AutoProcessor, BarkModel
 
+warnings.filterwarnings(
+    "ignore",
+    message="torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.",
+)
+
 
 class TextToSpeechService:
-    def __init__(self):
-        self.processor = AutoProcessor.from_pretrained("suno/bark")
-        self.model = BarkModel.from_pretrained("suno/bark")
+    def __init__(self, device: str = "cpu"):
+        self.device = device
+        self.processor = AutoProcessor.from_pretrained("suno/bark-small")
+        self.model = BarkModel.from_pretrained("suno/bark-small")
+        self.model.to(self.device)
+
+    def synthesize(self, text: str, voice_preset: str = "v2/en_speaker_9"):
+        inputs = self.processor(text, voice_preset=voice_preset, return_tensors="pt")
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            audio_array = self.model.generate(**inputs)
 
-    def synthesize(self, text: str, voice_preset: str = "v2/en_speaker_6"):
-        inputs = self.processor(text, voice_preset=voice_preset)
-        audio_array = self.model.generate(**inputs)
         audio_array = audio_array.cpu().numpy().squeeze()
         sample_rate = self.model.generation_config.sample_rate
         scipy.io.wavfile.write("bark_out.wav", rate=sample_rate, data=audio_array)
+
+
+tts = TextToSpeechService()
+tts.synthesize(
+    "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. "
+    "Disabling parallelism to avoid deadlocks..."
+)