Skip to content

Commit

Permalink
Update agent module (#259)
Browse files Browse the repository at this point in the history
* Update agent

* Update application

* Address comment
  • Loading branch information
moria97 authored Oct 30, 2024
1 parent 7de2849 commit 5650acc
Show file tree
Hide file tree
Showing 22 changed files with 492 additions and 196 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
{
"intents": {
"retrieval": "关于一些通用的信息检索,比如搜索旅游攻略、搜索美食攻略、搜索注意事项等信息。",
"agent": "实时性的信息查询,比如查询航班信息、查询高铁信息、查询天气等时效性很强的信息。"
},
"system_prompt": "你是一个旅游小助手,可以帮助用户查询指定时间从A地区到B地区的机票信息,火车票信息以及天气信息等。请严格使用输入的工具,不要虚构任何细节。",
"function_tools": [
{
Expand All @@ -25,7 +21,7 @@
"api_tools": [
{
"name": "search_flight_ticket_api",
"url": "http://127.0.0.1:8070/demo/api/flights",
"url": "http://127.0.0.1:8001/demo/api/flights",
"headers": {
"Authorization": "Bearer YOUR_ACCESS_TOKEN"
},
Expand All @@ -42,14 +38,14 @@
},
"date": {
"type": "str",
"description": "出发时间,如'2024-03-29'"
"description": "出发时间,YYYY-MM-DD格式,如'2024-03-29'"
}
},
"required": ["from_city", "to_city", "date"]
},
{
"name": "search_train_ticket_api",
"url": "http://127.0.0.1:8070/demo/api/trains",
"url": "http://127.0.0.1:8001/demo/api/trains",
"headers": {
"Authorization": "Bearer YOUR_ACCESS_TOKEN"
},
Expand All @@ -66,14 +62,14 @@
},
"date": {
"type": "str",
"description": "出发时间,如'2024-03-29'"
"description": "出发时间,YYYY-MM-DD格式,如'2024-03-29'"
}
},
"required": ["from_city", "to_city", "date"]
},
{
"name": "search_hotels_api",
"url": "http://127.0.0.1:8070/demo/api/hotels",
"url": "http://127.0.0.1:8001/demo/api/hotels",
"headers": {
"Authorization": "Bearer YOUR_ACCESS_TOKEN"
},
Expand All @@ -85,12 +81,16 @@
"type": "str",
"description": "查询的城市,如'北京'、'上海'、'南京''"
},
"date": {
"checkin_date": {
"type": "str",
"description": "入住时间,YYYY-MM-DD格式,如'2024-03-29'"
},
"checkout_date": {
"type": "str",
"description": "出发时间,如'2024-03-29'"
"description": "离店时间,YYYY-MM-DD格式,如'2024-03-31'"
}
},
"required": ["city", "date"]
"required": ["city", "checkin_date", "checkout_date"]
}
]
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
import requests
import os
import logging

logger = logging.getLogger(__name__)


def get_place_weather(city: str) -> str:
print(f"[Agent] Checking realtime weather info for {city}")
logger.info(f"[Agent] Checking realtime weather info for {city}")

"""Get city name and return city weather"""
api_key = os.environ.get("weather_api_key")

# 可以直接赋值给api_key,原始代码的config只有type类型。
base_url = "http://api.openweathermap.org/data/2.5/forecast?"
complete_url = f"{base_url}q={city}&appid={api_key}&lang=zh_cn&units=metric"
print(complete_url)
response = requests.get(complete_url)
logger.info(f"Requesting {complete_url}...")
response = requests.get(complete_url, timeout=5)
weather_data = response.json()

if weather_data["cod"] != "200":
print(f"获取天气信息失败,错误代码:{weather_data['cod']}")
return None
logger.error(
f"获取天气信息失败,错误代码:{weather_data['cod']} 错误信息:{weather_data['message']}"
)
return f"获取天气信息失败,错误代码:{weather_data['cod']} 错误信息:{weather_data['message']}"

element = weather_data["list"][0]

return str(
f"{city}的天气:\n 时间: {element['dt_txt']}\n 温度: {element['main']['temp']} °C\n 天气描述: {element['weather'][0]['description']}\n"
)
return f"""
{city}的天气:
时间: {element['dt_txt']}
温度: {element['main']['temp']} °C
天气描述: {element['weather'][0]['description']}
"""
192 changes: 192 additions & 0 deletions src/pai_rag/app/api/agent_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from datetime import datetime
from fastapi import APIRouter
import logging

from pydantic import BaseModel

logger = logging.getLogger(__name__)

demo_router = APIRouter()


# Mock 数据
flights_data = [
{
"flight_number": "CA123",
"from": "北京",
"to": "上海",
"departure_time": "08:00",
"arrival_time": "10:00",
"price": 1200,
},
{
"flight_number": "MU456",
"from": "北京",
"to": "上海",
"departure_time": "14:00",
"arrival_time": "16:00",
"price": 1300,
},
{
"flight_number": "HU789",
"from": "北京",
"to": "上海",
"departure_time": "18:00",
"arrival_time": "20:00",
"price": 1100,
},
{
"flight_number": "CA234",
"from": "北京",
"to": "上海",
"departure_time": "06:00",
"arrival_time": "08:00",
"price": 1250,
},
{
"flight_number": "MU567",
"from": "北京",
"to": "上海",
"departure_time": "21:00",
"arrival_time": "23:00",
"price": 1350,
},
]

highspeed_trains_data = [
{
"train_number": "G1234",
"from": "北京",
"to": "上海",
"departure_time": "09:00",
"arrival_time": "11:30",
"price": 800,
},
{
"train_number": "G5678",
"from": "北京",
"to": "上海",
"departure_time": "15:00",
"arrival_time": "17:30",
"price": 850,
},
{
"train_number": "G9101",
"from": "北京",
"to": "上海",
"departure_time": "18:30",
"arrival_time": "21:00",
"price": 780,
},
{
"train_number": "G1123",
"from": "北京",
"to": "上海",
"departure_time": "07:00",
"arrival_time": "09:30",
"price": 820,
},
{
"train_number": "G4578",
"from": "北京",
"to": "上海",
"departure_time": "22:00",
"arrival_time": "00:30",
"price": 870,
},
]

hotels_data = [
{
"hotel_name": "万豪酒店",
"city": "上海",
"price_per_night": 600,
},
{
"hotel_name": "希尔顿酒店",
"city": "上海",
"price_per_night": 850,
},
{
"hotel_name": "洲际酒店",
"city": "上海",
"price_per_night": 700,
},
{
"hotel_name": "皇冠假日酒店",
"city": "上海",
"price_per_night": 750,
},
{
"hotel_name": "如家酒店",
"city": "上海",
"price_per_night": 300,
},
]


@demo_router.get("/flights")
async def get_flights(date: str, to_city: str, from_city: str):
try:
_ = datetime.strptime(date, "%Y-%m-%d")
except Exception as _:
return {
"error": f"Invalid date format '{date}'. Please provide a date in YYYY-MM-DD format."
}

raw_fights = [
flight
for flight in flights_data
if flight["from"] == from_city and flight["to"] == to_city
]

for flight in raw_fights:
flight["date"] = date

return raw_fights


@demo_router.get("/trains")
async def get_trains(date: str, to_city: str, from_city: str):
try:
_ = datetime.strptime(date, "%Y-%m-%d")
except Exception as _:
return {
"error": f"Invalid date format '{date}'. Please provide a date in YYYY-MM-DD format."
}

raw_trains = [
train
for train in highspeed_trains_data
if train["from"] == from_city and train["to"] == to_city
]

for train in raw_trains:
train["date"] = date

return raw_trains


class HotelInput(BaseModel):
checkin_date: str
checkout_date: str
city: str


@demo_router.post("/hotels")
async def get_hotels(input: HotelInput):
try:
_ = datetime.strptime(input.checkin_date, "%Y-%m-%d")
_ = datetime.strptime(input.checkout_date, "%Y-%m-%d")
except Exception as _:
return {
"error": f"Invalid date format '{input}'. Please provide a date in YYYY-MM-DD format."
}

hotels = [hotel for hotel in hotels_data if hotel["city"] == input.city]

for hotel in hotels:
hotel["checkin_date"] = input.checkin_date
hotel["checkout_date"] = input.checkout_date

return hotels
9 changes: 8 additions & 1 deletion src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,14 @@ async def aquery_retrieval(query: RetrievalQuery):

@router.post("/query/agent")
async def aquery_agent(query: RagQuery):
return await rag_service.aquery_agent(query)
response = await rag_service.aquery_agent(query)
if not query.stream:
return response
else:
return StreamingResponse(
response,
media_type="text/event-stream",
)


@router.post("/config/agent")
Expand Down
8 changes: 4 additions & 4 deletions src/pai_rag/app/api/service.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from fastapi import APIRouter, FastAPI
from fastapi import FastAPI
from pai_rag.core.rag_config_manager import RagConfigManager
from pai_rag.core.rag_service import rag_service
from pai_rag.app.api import query
from pai_rag.app.api import agent_demo
from pai_rag.app.api.middleware import init_middleware
from pai_rag.app.api.error_handler import config_app_errors


def init_router(app: FastAPI):
api_router = APIRouter()
api_router.include_router(query.router, tags=["RagQuery"])
app.include_router(api_router, prefix="/service")
app.include_router(query.router, prefix="/service", tags=["RAG"])
app.include_router(agent_demo.demo_router, tags=["AgentDemo"], prefix="/demo/api")


def configure_app(app: FastAPI, rag_configuration: RagConfigManager):
Expand Down
17 changes: 12 additions & 5 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ def query(
else:
full_content = ""
for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True):
chunk_response = dotdict(json.loads(chunk))
if not chunk.startswith("data:"):
continue
chunk_response = dotdict(json.loads(chunk[5:]))
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(
Expand All @@ -249,7 +251,9 @@ def query_search(
else:
full_content = ""
for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True):
chunk_response = dotdict(json.loads(chunk))
if not chunk.startswith("data:"):
continue
chunk_response = dotdict(json.loads(chunk[5:]))
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(text, chunk_response, stream=stream)
Expand All @@ -275,7 +279,9 @@ def query_data_analysis(
else:
full_content = ""
for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True):
chunk_response = dotdict(json.loads(chunk))
if not chunk.startswith("data:"):
continue
chunk_response = dotdict(json.loads(chunk[5:]))
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(text, chunk_response, stream=stream)
Expand Down Expand Up @@ -308,7 +314,9 @@ def query_llm(
else:
full_content = ""
for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True):
chunk_response = dotdict(json.loads(chunk))
if not chunk.startswith("data:"):
continue
chunk_response = dotdict(json.loads(chunk[5:]))
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(
Expand Down Expand Up @@ -448,7 +456,6 @@ def get_config(self):
r = requests.get(self.config_url, timeout=DEFAULT_CLIENT_TIME_OUT)
if r.status_code != HTTPStatus.OK:
raise RagApiError(code=r.status_code, msg=r.text)

config = RagConfig.model_validate_json(json_data=r.text)
return config

Expand Down
Loading

0 comments on commit 5650acc

Please sign in to comment.