diff --git a/api/Pipfile b/api/Pipfile index 41db44b..f5cec86 100644 --- a/api/Pipfile +++ b/api/Pipfile @@ -15,6 +15,7 @@ werkzeug = "==2.0.3" flask-cors = "*" # ---- ---- openai = "==0.27.2" +litellm = "==0.1.619" numpy = "==1.24.2" tenacity = "==8.2.2" tiktoken = "*" diff --git a/api/src/stampy_chat/chat.py b/api/src/stampy_chat/chat.py index 024c0f3..99f3363 100644 --- a/api/src/stampy_chat/chat.py +++ b/api/src/stampy_chat/chat.py @@ -2,6 +2,8 @@ from dataclasses import asdict from typing import List, Dict, Callable import openai +import litellm +import uuid import re import tiktoken import time @@ -26,6 +28,9 @@ ENCODER = tiktoken.get_encoding("cl100k_base") +# initialize a budget manager to control costs for gpt-4/other llms +budget_manager = litellm.BudgetManager(project_name="stampy_chat") + DEBUG_PRINT = True def set_debug_print(val: bool): @@ -142,7 +147,7 @@ def construct_prompt(query: str, mode: str, history: List[Dict[str, str]], conte import time import json -def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K, log: Callable = print): +def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K, log: Callable = print, session_id: str = ""): try: # 1. 
Find the most relevant blocks from the Alignment Research Dataset yield {"state": "loading", "phase": "semantic"} @@ -181,7 +186,11 @@ def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str, t1 = time.time() response = '' - for chunk in openai.ChatCompletion.create( + # check if budget exceeded for session + if budget_manager.get_current_cost(user=session_id) > budget_manager.get_total_budget(session_id): + raise Exception(f"Exceeded the maximum budget for this session") + + for chunk in litellm.completion( model=COMPLETIONS_MODEL, messages=prompt, max_tokens=max_tokens_completion, @@ -189,6 +198,7 @@ def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str, temperature=0, # may or may not be a good idea ): res = chunk["choices"][0]["delta"] + budget_manager.update_cost(completion_obj=response, user=session_id) if res is not None and res.get("content") is not None: response += res["content"] yield {"state": "streaming", "content": res["content"]} @@ -225,13 +235,16 @@ def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str, # convert talk_to_robot_internal from dict generator into json generator def talk_to_robot(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K, log: Callable = print): - yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k, log)) + session_id = str(uuid.uuid4()) + budget_manager.create_budget(total_budget=10, user=session_id) # init $10 budget + yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k, log, session_id=session_id)) # wayyy simplified api def talk_to_robot_simple(index, query: str, log: Callable = print): res = {'response': ''} - - for block in talk_to_robot_internal(index, query, "default", [], log = log): + session_id = str(uuid.uuid4()) + budget_manager.create_budget(total_budget=10, user=session_id) # init $10 budget + for block in
talk_to_robot_internal(index, query, "default", [], log = log, session_id=session_id): if block['state'] == 'loading' and block['phase'] == 'semantic' and 'citations' in block: citations = {} for i, c in enumerate(block['citations']):