Skip to content

Commit

Permalink
feat: add function docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
vndee committed Mar 29, 2024
1 parent 1af6236 commit a3ae582
Showing 1 changed file with 66 additions and 33 deletions.
99 changes: 66 additions & 33 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
console = Console()
stt = whisper.load_model("base.en")
tts = TextToSpeechService()

template = """
You are a helpful assistant.You are polite and respectful. Your response should less than 20 words
The conversation transcript is as follows
You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less
than 20 words.
The conversation transcript is as follows:
{history}
And here is the user follow-up: {input}
Your answer:
And here is the user's follow-up: {input}
Your response:
"""
PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
chain = ConversationChain(
Expand All @@ -30,73 +35,102 @@
)


def record_audio(stop_event, data_queue):
"""
Captures audio data from the user's microphone and adds it to a queue for further processing.
Args:
stop_event (threading.Event): An event that, when set, signals the function to stop recording.
data_queue (queue.Queue): A queue to which the recorded audio data will be added.
Returns:
None
"""
def callback(indata, frames, time, status):
if status:
console.print(status)
data_queue.put(bytes(indata))

with sd.RawInputStream(
samplerate=16000, dtype="int16", channels=1, callback=callback
):
while not stop_event.is_set():
time.sleep(0.1)


def transcribe(audio_np: np.ndarray) -> str:
"""
Transcribes the given audio data using the Whisper speech recognition model.
Args:
audio_np (numpy.ndarray): The audio data to be transcribed.
Returns:
str: The transcribed text.
"""
result = stt.transcribe(audio_np, fp16=False) # Set fp16=True if using a GPU
text = result["text"].strip()
return text


def get_llm_response(text: str) -> str:
"""
Generates a response to the given text using the Llama-2 language model.
Args:
text (str): The input text to be processed.
Returns:
str: The generated response.
"""
response = chain.predict(input=text)
# trim the ai prefix if it exists
if response.startswith("Assistant:"):
response = response[len("Assistant:") :].strip()

return response


def record_audio(stop_event, data_queue):
def callback(indata, frames, time, status):
if status:
console.print(status)
# Put the audio bytes into the queue
data_queue.put(bytes(indata))
def play_audio(sample_rate, audio_array):
"""
Plays the given audio data using the sounddevice library.
# Start recording
with sd.RawInputStream(
samplerate=16000, dtype="int16", channels=1, callback=callback
):
while not stop_event.is_set():
# Small sleep to prevent this loop from consuming too much CPU
time.sleep(0.1)
Args:
sample_rate (int): The sample rate of the audio data.
audio_array (numpy.ndarray): The audio data to be played.
Returns:
None
"""
sd.play(audio_array, sample_rate)
sd.wait()


if __name__ == "__main__":
console.print("[cyan]Assistant started! Press Ctrl+C to exit.")

try:
while True:
# Wait for the user to press Enter to start recording
console.input(
"Press Enter to start recording, then press Enter again to stop."
)

data_queue = Queue() # type: ignore[var-annotated]
stop_event = threading.Event()

# Start recording in a background thread
recording_thread = threading.Thread(
target=record_audio,
args=(
stop_event,
data_queue,
),
args=(stop_event, data_queue),
)
recording_thread.start()

# Wait for the user to press Enter to stop recording
input() # No need to print a message as the previous message indicates to press Enter to stop
input()
stop_event.set()
recording_thread.join()

# Combine audio data from queue
audio_data = b"".join(list(data_queue.queue))
audio_np = (
np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
)

# Transcribe the recorded audio
if audio_np.size > 0: # Proceed if there's audio data
if audio_np.size > 0:
with console.status("Transcribing...", spinner="earth"):
text = transcribe(audio_np)
console.print(f"[yellow]You: {text}")
Expand All @@ -106,8 +140,7 @@ def callback(indata, frames, time, status):
sample_rate, audio_array = tts.long_form_synthesize(response)

console.print(f"[cyan]Assistant: {response}")
sd.play(audio_array, sample_rate)
sd.wait()
play_audio(sample_rate, audio_array)
else:
console.print(
"[red]No audio recorded. Please ensure your microphone is working."
Expand Down

0 comments on commit a3ae582

Please sign in to comment.