feat: record audio from console
vndee committed Mar 27, 2024
1 parent 82f7c30 commit f3f6258
Showing 6 changed files with 266 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -11,7 +11,7 @@ nitro/

 # C extensions
 *.so
-
+*.wav
 # Distribution / packaging
 .Python
 build/
90 changes: 61 additions & 29 deletions app.py
@@ -1,50 +1,82 @@
-import queue
+import time
+import threading
+import numpy as np
 import whisper
 import sounddevice as sd
+from core.common.const import console
+from queue import Queue
-from rich.console import Console

-console = Console()
+model = whisper.load_model("base.en")

-def listen_for_quit(stop_event):
-    console.print("[yellow]Start speaking! Press 'q' to quit.")
-    while True:
-        if console.read() == "q":
-            console.print("[yellow]Goodbye!")
-            stop_event.set()
-            break

+def transcribe(audio_np: np.ndarray) -> str:
+    result = model.transcribe(audio_np, fp16=False)  # Set fp16=True if using a GPU
+    text = result["text"].strip()
+    return text

-def conversation(stop_event):
-    q = queue.Queue()

+def record_audio(stop_event, data_queue):
     def callback(indata, frames, time, status):
         if status:
             console.print(status)
-        q.put(bytes(indata))
+        # Put the audio bytes into the queue
+        data_queue.put(bytes(indata))

-    with sd.RawInputStream(samplerate=16000, blocksize=8000, callback=callback):
-        console.print("[blue]Start speaking!")
+    # Start recording
+    with sd.RawInputStream(
+        samplerate=16000, dtype="int16", channels=1, callback=callback
+    ):
         while not stop_event.is_set():
-            audio = q.get()
-            console.print(f"[green]You said: {audio}")
+            # Small sleep to prevent this loop from consuming too much CPU
+            time.sleep(0.1)


 if __name__ == "__main__":
console.print("[blue]Welcome to local taking llm!")
+    console.print(
+        "[blue]Press Enter to start speaking. Press Enter again to stop and transcribe."
+    )

+    try:
+        while True:
+            # Wait for the user to press Enter to start recording
+            input("[blue]Press Enter to start recording...")
+            console.print("[yellow]Recording... Press Enter to stop.")

-    user_name = None
-    while not user_name:
-        user_name = console.input("[blue]Your name: ")
+            data_queue = Queue()  # type: ignore[var-annotated]
+            stop_event = threading.Event()

-    console.print(f"[cyan]Nice to see you, {user_name}!")
+            # Start recording in a background thread
+            recording_thread = threading.Thread(
+                target=record_audio,
+                args=(
+                    stop_event,
+                    data_queue,
+                ),
+            )
+            recording_thread.start()

+            # Wait for the user to press Enter to stop recording
+            input()  # No need to print a message as the previous message indicates to press Enter to stop
+            stop_event.set()
+            recording_thread.join()

-    stop_event = threading.Event()
-    conversation_thread = threading.Thread(target=conversation, args=(stop_event,))
-    quit_listener_thread = threading.Thread(target=listen_for_quit, args=(stop_event,))
+            # Combine audio data from queue
+            audio_data = b"".join(list(data_queue.queue))
+            audio_np = (
+                np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
+            )

-    conversation_thread.start()
-    quit_listener_thread.start()
+            # Transcribe the recorded audio
+            if audio_np.size > 0:  # Proceed if there's audio data
+                text = transcribe(audio_np)
+                console.print(f"[green]Transcription: {text}")
+            else:
+                console.print(
+                    "[red]No audio recorded. Please ensure your microphone is working."
+                )

-    conversation_thread.join()
-    quit_listener_thread.join()
+    except KeyboardInterrupt:
+        console.print("\n[red]Exiting...")

-    console.print("[blue]Session ended!")
+    console.print("[blue]Session ended.")
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -14,6 +14,8 @@ torch = "^2.2.1"
 openai-whisper = {git = "https://github.com/openai/whisper.git"}
 sounddevice = "^0.4.6"
 suno-bark = {git = "https://github.com/suno-ai/bark.git"}
+speechrecognition = "^3.10.1"
+pyaudio = "^0.2.14"


 [tool.poetry.group.dev.dependencies]
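
Editor's note: the two new runtime dependencies can be added with Poetry, e.g. poetry add speechrecognition pyaudio. PyAudio builds against PortAudio, so on Linux the PortAudio development headers (for example, portaudio19-dev on Debian/Ubuntu) usually need to be installed first.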
5 changes: 3 additions & 2 deletions stt.py
@@ -10,8 +10,9 @@


 class SpeechToTextService:
-    def __init__(self):
-        self.model = whisper.load_model("base")
+    def __init__(self, device: str = "cpu"):
+        self.device = device
+        self.model = whisper.load_model("tiny.en")
         self.options = whisper.DecodingOptions()

@staticmethod
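Editor's note: a hedged construction sketch of the updated service follows; everything beyond the constructor sits outside this hunk, so no other methods are shown, and the variable name is illustrative:

from stt import SpeechToTextService

# The device argument is stored on the instance; "tiny.en" is the small
# English-only Whisper checkpoint, trading some accuracy for faster loading.
stt_service = SpeechToTextService(device="cpu")
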
173 changes: 173 additions & 0 deletions test.py
@@ -0,0 +1,173 @@
#! python3.7

import argparse
import os
import numpy as np
import speech_recognition as sr
import whisper
import torch

from datetime import datetime, timedelta
from queue import Queue
from time import sleep
from sys import platform


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="medium",
        help="Model to use",
        choices=["tiny", "base", "small", "medium", "large"],
    )
    parser.add_argument(
        "--non_english", action="store_true", help="Don't use the english model."
    )
    parser.add_argument(
        "--energy_threshold",
        default=1000,
        help="Energy level for mic to detect.",
        type=int,
    )
    parser.add_argument(
        "--record_timeout",
        default=2,
        help="How real time the recording is in seconds.",
        type=float,
    )
    parser.add_argument(
        "--phrase_timeout",
        default=3,
        help="How much empty space between recordings before we "
        "consider it a new line in the transcription.",
        type=float,
    )
    if "linux" in platform:
        parser.add_argument(
            "--default_microphone",
            default="pulse",
            help="Default microphone name for SpeechRecognition. "
            "Run this with 'list' to view available Microphones.",
            type=str,
        )
    args = parser.parse_args()

    # The last time a recording was retrieved from the queue.
    phrase_time = None
    # Thread safe Queue for passing data from the threaded recording callback.
    data_queue = Queue()
    # We use SpeechRecognizer to record our audio because it has a nice feature where it can detect when speech ends.
    recorder = sr.Recognizer()
    recorder.energy_threshold = args.energy_threshold
    # Definitely do this, dynamic energy compensation lowers the energy threshold dramatically to a point where the SpeechRecognizer never stops recording.
    recorder.dynamic_energy_threshold = False

    # Important for linux users.
    # Prevents permanent application hang and crash by using the wrong Microphone
    if "linux" in platform:
        mic_name = args.default_microphone
        if not mic_name or mic_name == "list":
            print("Available microphone devices are: ")
            for index, name in enumerate(sr.Microphone.list_microphone_names()):
                print(f'Microphone with name "{name}" found')
            return
        else:
            for index, name in enumerate(sr.Microphone.list_microphone_names()):
                if mic_name in name:
                    source = sr.Microphone(sample_rate=16000, device_index=index)
                    break
    else:
        source = sr.Microphone(sample_rate=16000)

    # Load / Download model
    model = args.model
    if args.model != "large" and not args.non_english:
        model = model + ".en"
    audio_model = whisper.load_model(model)

    record_timeout = args.record_timeout
    phrase_timeout = args.phrase_timeout

    transcription = [""]

    with source:
        recorder.adjust_for_ambient_noise(source)

    def record_callback(_, audio: sr.AudioData) -> None:
        """
        Threaded callback function to receive audio data when recordings finish.
        audio: An AudioData containing the recorded bytes.
        """
        # Grab the raw bytes and push it into the thread safe queue.
        data = audio.get_raw_data()
        data_queue.put(data)

    # Create a background thread that will pass us raw audio bytes.
    # We could do this manually but SpeechRecognizer provides a nice helper.
    recorder.listen_in_background(
        source, record_callback, phrase_time_limit=record_timeout
    )

    # Cue the user that we're ready to go.
    print("Model loaded.\n")

    while True:
        try:
            now = datetime.utcnow()
            # Pull raw recorded audio from the queue.
            if not data_queue.empty():
                phrase_complete = False
                # If enough time has passed between recordings, consider the phrase complete.
                # Clear the current working audio buffer to start over with the new data.
                if phrase_time and now - phrase_time > timedelta(
                    seconds=phrase_timeout
                ):
                    phrase_complete = True
                # This is the last time we received new audio data from the queue.
                phrase_time = now

                # Combine audio data from queue
                audio_data = b"".join(data_queue.queue)
                data_queue.queue.clear()
                # Convert the in-RAM buffer to something the model can use directly
                # without a temp file: reinterpret the raw bytes as 16-bit integers,
                # then normalize to float32 in [-1.0, 1.0] by dividing by 32768.
                audio_np = (
                    np.frombuffer(audio_data, dtype=np.int16).astype(np.float32)
                    / 32768.0
                )

                # Read the transcription.
                result = audio_model.transcribe(
                    audio_np, fp16=torch.cuda.is_available()
                )
                text = result["text"].strip()

                # If we detected a pause between recordings, add a new item to our transcription.
                # Otherwise edit the existing one.
                if phrase_complete:
                    transcription.append(text)
                else:
                    transcription[-1] = text

                # Clear the console to reprint the updated transcription.
                os.system("cls" if os.name == "nt" else "clear")
                for line in transcription:
                    print(line)
                # Flush stdout.
                print("", end="", flush=True)
            else:
                # Infinite loops are bad for processors, must sleep.
                sleep(0.25)
        except KeyboardInterrupt:
            break

    print("\n\nTranscription:")
    for line in transcription:
        print(line)


if __name__ == "__main__":
    main()
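
Editor's note: the script above is self-contained; with the dependencies installed it runs directly from the console, for example: python test.py --model tiny --energy_threshold 1000 (on Linux, pass --default_microphone list first to enumerate the available input devices).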
32 changes: 26 additions & 6 deletions tts.py
@@ -1,16 +1,36 @@
+import torch
 import scipy
+import warnings
 from transformers import AutoProcessor, BarkModel

+warnings.filterwarnings(
+    "ignore",
+    message="torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.",
+)


 class TextToSpeechService:
-    def __init__(self):
-        self.processor = AutoProcessor.from_pretrained("suno/bark")
-        self.model = BarkModel.from_pretrained("suno/bark")
+    def __init__(self, device: str = "cpu"):
+        self.device = device
+        self.processor = AutoProcessor.from_pretrained("suno/bark-small")
+        self.model = BarkModel.from_pretrained("suno/bark-small")
+        self.model.to(self.device)

+    def synthesize(self, text: str, voice_preset: str = "v2/en_speaker_9"):
+        inputs = self.processor(text, voice_preset=voice_preset, return_tensors="pt")
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            audio_array = self.model.generate(**inputs)

-    def synthesize(self, text: str, voice_preset: str = "v2/en_speaker_6"):
-        inputs = self.processor(text, voice_preset=voice_preset)
-        audio_array = self.model.generate(**inputs)
         audio_array = audio_array.cpu().numpy().squeeze()

         sample_rate = self.model.generation_config.sample_rate
         scipy.io.wavfile.write("bark_out.wav", rate=sample_rate, data=audio_array)


+tts = TextToSpeechService()
+tts.synthesize(
+    "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. "
+    "Disabling parallelism to avoid deadlocks..."
+)
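
Editor's note: to hear bark_out.wav right after it is written, a minimal playback sketch (sounddevice is already a project dependency; wavfile.read returns the sample rate and the sample array written above). This is editor-provided, not part of the commit:

import sounddevice as sd
from scipy.io import wavfile

sample_rate, data = wavfile.read("bark_out.wav")  # the file written by synthesize()
sd.play(data, samplerate=sample_rate)
sd.wait()  # block until playback finishes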
