-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update agent * Update application * Address comment
- Loading branch information
Showing
22 changed files
with
492 additions
and
196 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
24 changes: 16 additions & 8 deletions
24
example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/tools.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,34 @@ | ||
import requests | ||
import os | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def get_place_weather(city: str) -> str: | ||
print(f"[Agent] Checking realtime weather info for {city}") | ||
logger.info(f"[Agent] Checking realtime weather info for {city}") | ||
|
||
"""Get city name and return city weather""" | ||
api_key = os.environ.get("weather_api_key") | ||
|
||
# 可以直接赋值给api_key,原始代码的config只有type类型。 | ||
base_url = "http://api.openweathermap.org/data/2.5/forecast?" | ||
complete_url = f"{base_url}q={city}&appid={api_key}&lang=zh_cn&units=metric" | ||
print(complete_url) | ||
response = requests.get(complete_url) | ||
logger.info(f"Requesting {complete_url}...") | ||
response = requests.get(complete_url, timeout=5) | ||
weather_data = response.json() | ||
|
||
if weather_data["cod"] != "200": | ||
print(f"获取天气信息失败,错误代码:{weather_data['cod']}") | ||
return None | ||
logger.error( | ||
f"获取天气信息失败,错误代码:{weather_data['cod']} 错误信息:{weather_data['message']}" | ||
) | ||
return f"获取天气信息失败,错误代码:{weather_data['cod']} 错误信息:{weather_data['message']}" | ||
|
||
element = weather_data["list"][0] | ||
|
||
return str( | ||
f"{city}的天气:\n 时间: {element['dt_txt']}\n 温度: {element['main']['temp']} °C\n 天气描述: {element['weather'][0]['description']}\n" | ||
) | ||
return f""" | ||
{city}的天气: | ||
时间: {element['dt_txt']} | ||
温度: {element['main']['temp']} °C | ||
天气描述: {element['weather'][0]['description']} | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
from datetime import datetime | ||
from fastapi import APIRouter | ||
import logging | ||
|
||
from pydantic import BaseModel | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
demo_router = APIRouter() | ||
|
||
|
||
# Mock 数据 | ||
flights_data = [ | ||
{ | ||
"flight_number": "CA123", | ||
"from": "北京", | ||
"to": "上海", | ||
"departure_time": "08:00", | ||
"arrival_time": "10:00", | ||
"price": 1200, | ||
}, | ||
{ | ||
"flight_number": "MU456", | ||
"from": "北京", | ||
"to": "上海", | ||
"departure_time": "14:00", | ||
"arrival_time": "16:00", | ||
"price": 1300, | ||
}, | ||
{ | ||
"flight_number": "HU789", | ||
"from": "北京", | ||
"to": "上海", | ||
"departure_time": "18:00", | ||
"arrival_time": "20:00", | ||
"price": 1100, | ||
}, | ||
{ | ||
"flight_number": "CA234", | ||
"from": "北京", | ||
"to": "上海", | ||
"departure_time": "06:00", | ||
"arrival_time": "08:00", | ||
"price": 1250, | ||
}, | ||
{ | ||
"flight_number": "MU567", | ||
"from": "北京", | ||
"to": "上海", | ||
"departure_time": "21:00", | ||
"arrival_time": "23:00", | ||
"price": 1350, | ||
}, | ||
] | ||
|
||
highspeed_trains_data = [ | ||
{ | ||
"train_number": "G1234", | ||
"from": "北京", | ||
"to": "上海", | ||
"departure_time": "09:00", | ||
"arrival_time": "11:30", | ||
"price": 800, | ||
}, | ||
{ | ||
"train_number": "G5678", | ||
"from": "北京", | ||
"to": "上海", | ||
"departure_time": "15:00", | ||
"arrival_time": "17:30", | ||
"price": 850, | ||
}, | ||
{ | ||
"train_number": "G9101", | ||
"from": "北京", | ||
"to": "上海", | ||
"departure_time": "18:30", | ||
"arrival_time": "21:00", | ||
"price": 780, | ||
}, | ||
{ | ||
"train_number": "G1123", | ||
"from": "北京", | ||
"to": "上海", | ||
"departure_time": "07:00", | ||
"arrival_time": "09:30", | ||
"price": 820, | ||
}, | ||
{ | ||
"train_number": "G4578", | ||
"from": "北京", | ||
"to": "上海", | ||
"departure_time": "22:00", | ||
"arrival_time": "00:30", | ||
"price": 870, | ||
}, | ||
] | ||
|
||
hotels_data = [ | ||
{ | ||
"hotel_name": "万豪酒店", | ||
"city": "上海", | ||
"price_per_night": 600, | ||
}, | ||
{ | ||
"hotel_name": "希尔顿酒店", | ||
"city": "上海", | ||
"price_per_night": 850, | ||
}, | ||
{ | ||
"hotel_name": "洲际酒店", | ||
"city": "上海", | ||
"price_per_night": 700, | ||
}, | ||
{ | ||
"hotel_name": "皇冠假日酒店", | ||
"city": "上海", | ||
"price_per_night": 750, | ||
}, | ||
{ | ||
"hotel_name": "如家酒店", | ||
"city": "上海", | ||
"price_per_night": 300, | ||
}, | ||
] | ||
|
||
|
||
@demo_router.get("/flights") | ||
async def get_flights(date: str, to_city: str, from_city: str): | ||
try: | ||
_ = datetime.strptime(date, "%Y-%m-%d") | ||
except Exception as _: | ||
return { | ||
"error": f"Invalid date format '{date}'. Please provide a date in YYYY-MM-DD format." | ||
} | ||
|
||
raw_fights = [ | ||
flight | ||
for flight in flights_data | ||
if flight["from"] == from_city and flight["to"] == to_city | ||
] | ||
|
||
for flight in raw_fights: | ||
flight["date"] = date | ||
|
||
return raw_fights | ||
|
||
|
||
@demo_router.get("/trains") | ||
async def get_trains(date: str, to_city: str, from_city: str): | ||
try: | ||
_ = datetime.strptime(date, "%Y-%m-%d") | ||
except Exception as _: | ||
return { | ||
"error": f"Invalid date format '{date}'. Please provide a date in YYYY-MM-DD format." | ||
} | ||
|
||
raw_trains = [ | ||
train | ||
for train in highspeed_trains_data | ||
if train["from"] == from_city and train["to"] == to_city | ||
] | ||
|
||
for train in raw_trains: | ||
train["date"] = date | ||
|
||
return raw_trains | ||
|
||
|
||
class HotelInput(BaseModel): | ||
checkin_date: str | ||
checkout_date: str | ||
city: str | ||
|
||
|
||
@demo_router.post("/hotels") | ||
async def get_hotels(input: HotelInput): | ||
try: | ||
_ = datetime.strptime(input.checkin_date, "%Y-%m-%d") | ||
_ = datetime.strptime(input.checkout_date, "%Y-%m-%d") | ||
except Exception as _: | ||
return { | ||
"error": f"Invalid date format '{input}'. Please provide a date in YYYY-MM-DD format." | ||
} | ||
|
||
hotels = [hotel for hotel in hotels_data if hotel["city"] == input.city] | ||
|
||
for hotel in hotels: | ||
hotel["checkin_date"] = input.checkin_date | ||
hotel["checkout_date"] = input.checkout_date | ||
|
||
return hotels |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.