Skip to content

Commit

Permalink
lint, optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
Luisotee committed Nov 27, 2024
1 parent 41fc958 commit ad2f818
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 174 deletions.
292 changes: 120 additions & 172 deletions apps/ai_api/eda_ai_api/api/routes/supervisor.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
import os
import tempfile
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, File, UploadFile, Form
from typing import Optional
from fastapi import APIRouter, File, Form, UploadFile
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from onboarding.crew import OnboardingCrew
from opportunity_finder.crew import OpportunityFinderCrew
from proposal_writer.crew import ProposalWriterCrew

from eda_ai_api.models.supervisor import SupervisorRequest, SupervisorResponse
from eda_ai_api.models.supervisor import SupervisorResponse
from eda_ai_api.utils.audio_converter import convert_ogg
from eda_ai_api.utils.transcriber import transcribe_audio

import tempfile

router = APIRouter()

# Maps accepted upload MIME types to the file suffix used when writing the
# payload to a temporary file before transcription.
ALLOWED_FORMATS = {
    "audio/mpeg": "mp3",
    "audio/mp4": "mp4",
    "audio/mpga": "mpga",
    "audio/wav": "wav",
    "audio/webm": "webm",
    "audio/ogg": "ogg",
}

# Setup LLM and prompt
llm = ChatGroq(
model_name="llama3-groq-70b-8192-tool-use-preview",
Expand All @@ -36,7 +44,6 @@
Return only one word (discovery/proposal/onboarding/heartbeat):"""


TOPIC_EXTRACTOR_TEMPLATE = """
Extract up to 5 most relevant topics for grant opportunity research from the user message.
Return only a comma-separated list of topics (maximum 5), no other text.
Expand All @@ -54,7 +61,7 @@
Output:"""

# Create prompt templates
# Create prompt templates and chains
router_prompt = PromptTemplate(input_variables=["message"], template=ROUTER_TEMPLATE)
topic_prompt = PromptTemplate(
input_variables=["message"], template=TOPIC_EXTRACTOR_TEMPLATE
Expand All @@ -63,189 +70,130 @@
input_variables=["message"], template=PROPOSAL_EXTRACTOR_TEMPLATE
)

# Create LLM chains
router_chain = LLMChain(llm=llm, prompt=router_prompt)
topic_chain = LLMChain(llm=llm, prompt=topic_prompt)
proposal_chain = LLMChain(llm=llm, prompt=proposal_prompt)


def detect_content_type(file: UploadFile) -> Optional[str]:
    """Best-effort MIME type detection for an uploaded file.

    Prefers the client-supplied ``content_type``, then a ``mime_type``
    attribute (some upload wrappers expose one), and finally falls back
    to mapping the filename extension.

    Returns:
        The detected MIME type, or ``None`` when nothing matches.
    """
    if hasattr(file, "content_type") and file.content_type:
        return file.content_type

    if hasattr(file, "mime_type") and file.mime_type:
        return file.mime_type

    # FastAPI allows filename to be None/empty; guard before splitext,
    # which would raise TypeError on None.
    if not file.filename:
        return None

    ext = os.path.splitext(file.filename)[1].lower()
    return {
        ".mp3": "audio/mpeg",
        ".mp4": "audio/mp4",
        ".mpeg": "audio/mpeg",
        ".mpga": "audio/mpga",
        ".m4a": "audio/mp4",
        ".wav": "audio/wav",
        ".webm": "audio/webm",
        ".ogg": "audio/ogg",
    }.get(ext)


async def process_audio(audio: UploadFile) -> str:
    """Transcribe an uploaded audio file and return the text.

    OGG payloads are converted to MP3 first; every other format is
    written as-is to a temporary file. The temporary file is always
    removed afterwards, even when transcription fails.
    """
    mime = detect_content_type(audio)
    payload = await audio.read()
    tmp_path = ""

    try:
        # Fall back to MP3 semantics when detection found nothing.
        if not mime:
            mime = "audio/mpeg"

        if mime == "audio/ogg":
            tmp_path = convert_ogg(payload, output_format="mp3")
        else:
            suffix = f".{ALLOWED_FORMATS.get(mime, 'mp3')}"
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
                handle.write(payload)
                tmp_path = handle.name

        return transcribe_audio(tmp_path)
    finally:
        # Clean up the on-disk copy regardless of success.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)


def extract_topics(message: str) -> List[str]:
    """Ask the topic LLM chain for up to five research topics.

    Falls back to a default topic pair when the chain yields nothing usable.
    """
    raw = topic_chain.run(message=message)
    cleaned = [piece.strip() for piece in raw.split(",") if piece.strip()]
    return cleaned[:5] or ["AI", "Technology"]


def extract_proposal_details(message: str) -> tuple[str, str]:
    """Pull (community_project, grant_call) out of the proposal chain output.

    The chain is expected to emit a pipe-separated pair; missing pieces
    default to ``"unknown"``.
    """
    parts = [piece.strip() for piece in proposal_chain.run(message=message).split("|")]
    # Pad so that a short/empty response still yields two values.
    parts += ["unknown", "unknown"]
    return parts[0], parts[1]


def process_decision(decision: str, message: str) -> Dict[str, Any]:
    """Dispatch a routing decision to the matching crew and return its output.

    Unknown decisions produce an error payload instead of raising.
    """
    print("\n==================================================")
    print(f" DECISION: {decision}")
    print("==================================================\n")

    if decision == "heartbeat":
        return {"is_alive": True}

    if decision == "onboarding":
        return OnboardingCrew().crew().kickoff()

    if decision == "discovery":
        topics = extract_topics(message)
        print("\n==================================================")
        print(f" EXTRACTED TOPICS: {topics}")
        print("==================================================\n")
        finder = OpportunityFinderCrew().crew()
        return finder.kickoff(inputs={"topics": ", ".join(topics)})

    if decision == "proposal":
        project, grant = extract_proposal_details(message)
        print(f" PROJECT NAME: {project}")
        print(f" GRANT PROGRAM: {grant}")
        writer = ProposalWriterCrew(community_project=project, grant_call=grant)
        return writer.crew().kickoff()

    return {"error": f"Unknown decision type: {decision}"}


@router.post("/supervisor", response_model=SupervisorResponse)
async def supervisor_route(
    message: Optional[str] = Form(None), audio: Optional[UploadFile] = File(None)
) -> SupervisorResponse:
    """Main route handler for supervisor API.

    Accepts either a text ``message`` or an ``audio`` upload. Audio is
    transcribed first via ``process_audio``; the resulting text is routed
    through the router LLM chain and the decision is dispatched to the
    matching crew by ``process_decision``.

    NOTE(review): this span previously contained leftover pre-refactor
    code (a nested ALLOWED_FORMATS with duplicate dict keys, a duplicate
    detect_content_type, and a second copy of the decision-handling
    logic); this is the consolidated handler that delegates to the
    module-level helpers instead.
    """
    try:
        if audio:
            transcription = await process_audio(audio)
            print("==================================================\n")
            print(f" TRANSCRIPTION: {transcription}")
            print("==================================================\n")
            decision = router_chain.run(message=transcription).strip().lower()
            result = process_decision(decision, transcription)
        elif message:
            decision = router_chain.run(message=message).strip().lower()
            result = process_decision(decision, message)
        else:
            return SupervisorResponse(
                result="Error: Neither message nor audio provided"
            )

        return SupervisorResponse(result=str(result))

    except Exception as e:
        # Surface failures in the response body rather than a bare 500.
        return SupervisorResponse(result=f"Error processing request: {str(e)}")
1 change: 1 addition & 0 deletions apps/ai_api/eda_ai_api/models/supervisor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional

from fastapi import UploadFile
from pydantic import BaseModel

Expand Down
3 changes: 2 additions & 1 deletion apps/ai_api/eda_ai_api/utils/audio_converter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from pathlib import Path
import tempfile
from pathlib import Path
from typing import Literal

import ffmpeg

AudioFormat = Literal["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
Expand Down
1 change: 0 additions & 1 deletion apps/ai_api/eda_ai_api/utils/transcriber.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from groq import Groq
from loguru import logger

Expand Down

0 comments on commit ad2f818

Please sign in to comment.