diff --git a/api/.env.example b/api/.env.example index e22e57b..0b00b8f 100644 --- a/api/.env.example +++ b/api/.env.example @@ -1,3 +1,6 @@ +LOG_LEVEL="INFO" +DISCORD_LOG_LEVEL="INFO" + OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" PINECONE_API_KEY="" # leave blank to use our online API instead LOGGING_URL="" # leave blank if you're not testing logging specifically diff --git a/api/main.py b/api/main.py index a937f39..f18ef68 100644 --- a/api/main.py +++ b/api/main.py @@ -5,7 +5,8 @@ import json import re -from stampy_chat.env import PINECONE_INDEX, FLASK_PORT, log +from stampy_chat import logging +from stampy_chat.env import PINECONE_INDEX, FLASK_PORT from stampy_chat.get_blocks import get_top_k_blocks from stampy_chat.chat import talk_to_robot, talk_to_robot_simple @@ -29,7 +30,7 @@ def stream(src): @cross_origin() def semantic(): query = request.json['query'] - k = request.json['k'] if 'k' in request.json else 20 + k = request.json.get('k', 20) return jsonify([dataclasses.asdict(block) for block in get_top_k_blocks(PINECONE_INDEX, query, k)]) @@ -45,7 +46,7 @@ def chat(): mode = request.json['mode'] history = request.json['history'] - return Response(stream(talk_to_robot(PINECONE_INDEX, query, mode, history, log = log)), mimetype='text/event-stream') + return Response(stream(talk_to_robot(PINECONE_INDEX, query, mode, history)), mimetype='text/event-stream') # ------------- simplified non-streaming chat for internal testing ------------- @@ -64,7 +65,7 @@ def chat_simplified(param=''): def human(id): import requests r = requests.get(f"https://aisafety.info/questions/{id}") - log(f"clicked followup '{json.loads(r.text)['data']['title']}': https://stampy.ai/?state={id}") + logging.info(f"clicked followup '{json.loads(r.text)['data']['title']}': https://stampy.ai/?state={id}") # run a regex to replace all relative links with absolute links. Just doing # a regex for now since we really don't need to parse everything out then diff --git a/api/src/stampy_chat/chat.py b/api/src/stampy_chat/chat.py index 024c0f3..a2c7cf6 100644 --- a/api/src/stampy_chat/chat.py +++ b/api/src/stampy_chat/chat.py @@ -3,11 +3,16 @@ from typing import List, Dict, Callable import openai import re +from sqlalchemy.orm import PropComparator import tiktoken import time from stampy_chat.followups import multisearch_authored from stampy_chat.get_blocks import get_top_k_blocks, Block +from stampy_chat import logging + + +logger = logging.getLogger(__name__) # OpenAI models @@ -21,17 +26,12 @@ # NOTE: All this is approximate, there's bits I'm intentionally not counting. Leave a buffer beyond what you might expect. NUM_TOKENS = 8191 if COMPLETIONS_MODEL == 'gpt-4' else 4095 +TOKENS_BUFFER = 50 # the number of tokens to leave as a buffer when calculating remaining tokens HISTORY_FRACTION = 0.25 # the (approximate) fraction of num_tokens to use for history text before truncating CONTEXT_FRACTION = 0.5 # the (approximate) fraction of num_tokens to use for context text before truncating ENCODER = tiktoken.get_encoding("cl100k_base") -DEBUG_PRINT = True - -def set_debug_print(val: bool): - global DEBUG_PRINT - DEBUG_PRINT = val - # --------------------------------- prompt code -------------------------------- @@ -47,9 +47,9 @@ def cap(text: str, max_tokens: int) -> str: else: return ENCODER.decode(encoded_text[:max_tokens]) + " ..." +Prompt = List[Dict[str, str]] - -def construct_prompt(query: str, mode: str, history: List[Dict[str, str]], context: List[Block]) -> List[Dict[str, str]]: +def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block]) -> Prompt: prompt = [] @@ -142,7 +142,27 @@ def construct_prompt(query: str, mode: str, history: List[Dict[str, str]], conte import time import json -def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K, log: Callable = print): + +def check_openai_moderation(prompt: Prompt, query: str): + prompt_string = '\n\n'.join([message["content"] for message in prompt]) + mod_res = openai.Moderation.create( input = [ query, prompt_string ]) + + if any(map(lambda x: x["flagged"], mod_res["results"])): + logger.moderation_issue(query, prompt_string, mod_res) + + raise ValueError("This conversation was rejected by OpenAI's moderation filter. Sorry.") + + +def remaining_tokens(prompt: Prompt): + # Count number of tokens left for completion (-50 for a buffer) + used_tokens = sum([ + len(ENCODER.encode(message["content"]) + ENCODER.encode(message["role"])) + for message in prompt + ]) + return NUM_TOKENS - used_tokens - TOKENS_BUFFER + + +def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, k: int = STANDARD_K): try: # 1. Find the most relevant blocks from the Alignment Research Dataset yield {"state": "loading", "phase": "semantic"} @@ -156,25 +176,10 @@ def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str, # 3. Run both the standalone query and the full prompt through # moderation to see if it will be accepted by OpenAI's api - - prompt_string = '\n\n'.join([message["content"] for message in prompt]) - mod_res = openai.Moderation.create( input = [ query, prompt_string ]) - - if any(map(lambda x: x["flagged"], mod_res["results"])): - - # this is a biiig ask of a discord webhook - put most important - # info at start such that it's more likely to not be cut off - log('-' * 80) - log("MODERATION REJECTED") - log("MODERATION RESPONSE:\n\n" + json.dumps(mod_res["results"], indent=2)) - log("REJECTED QUERY: " + query) - log("REJECTED PROMPT:\n\n" + prompt_string) - log('-' * 80) - - raise ValueError("This conversation was rejected by OpenAI's moderation filter. Sorry.") + check_openai_moderation(prompt, query) # 4. Count number of tokens left for completion (-50 for a buffer) - max_tokens_completion = NUM_TOKENS - sum([len(ENCODER.encode(message["content"]) + ENCODER.encode(message["role"])) for message in prompt]) - 50 + max_tokens_completion = remaining_tokens(prompt) # 5. Answer the user query yield {"state": "loading", "phase": "llm"} @@ -195,43 +200,39 @@ def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str, t2 = time.time() - print(f'Time to get response: {t2-t1:.2f}s') - - if DEBUG_PRINT: - print('\n' * 10) - print(" ------------------------------ prompt: -----------------------------") + logger.debug(f'Time to get response: {time.time() - t1:.2f}s') + if logger.is_debug(): + logger.debug('\n' * 10) + logger.debug(" ------------------------------ prompt: -----------------------------") for message in prompt: - print(f"----------- {message['role']}: ------------------") - print(message['content']) - - print('\n' * 10) - - print(' ------------------------------ response: -----------------------------') - print(response) + logger.debug("----------- %s: ------------------", message['role']) + logger.debug(message['content']) + logger.debug('\n' * 10) + logger.debug(' ------------------------------ response: -----------------------------') + logger.debug(response) - log(query) - log(response) + logger.interaction(query, response, history, prompt, top_k_blocks) # yield done state, possibly with followup questions fin_json = {'state': 'done'} - followups = multisearch_authored([query, response], DEBUG_PRINT) + followups = multisearch_authored([query, response]) for i, followup in enumerate(followups): fin_json[f'followup_{i}'] = asdict(followup) yield fin_json except Exception as e: - print(e) + logger.error(e) yield {'state': 'error', 'error': str(e)} # convert talk_to_robot_internal from dict generator into json generator -def talk_to_robot(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K, log: Callable = print): - yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k, log)) +def talk_to_robot(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K): + yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k)) # wayyy simplified api -def talk_to_robot_simple(index, query: str, log: Callable = print): +def talk_to_robot_simple(index, query: str): res = {'response': ''} - for block in talk_to_robot_internal(index, query, "default", [], log = log): + for block in talk_to_robot_internal(index, query, "default", []): if block['state'] == 'loading' and block['phase'] == 'semantic' and 'citations' in block: citations = {} for i, c in enumerate(block['citations']): diff --git a/api/src/stampy_chat/env.py b/api/src/stampy_chat/env.py index 211e4a0..4bc2952 100644 --- a/api/src/stampy_chat/env.py +++ b/api/src/stampy_chat/env.py @@ -1,7 +1,6 @@ import os import openai import pinecone -from discord_webhook import DiscordWebhook if os.path.exists('.env'): from dotenv import load_dotenv @@ -25,6 +24,7 @@ PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME", "alignment-search") PINECONE_INDEX = None PINECONE_NAMESPACE = os.environ.get("PINECONE_NAMESPACE", "alignment-search") # "normal" or "finetuned" for the new index, "alignment-search" for the old one + # Only init pinecone if we have an env value for it. if PINECONE_API_KEY: pinecone.init( @@ -34,17 +34,6 @@ PINECONE_INDEX = pinecone.Index(index_name=PINECONE_INDEX_NAME) -# log something only if the logging url is set -def log(*args, end="\n"): - message = " ".join([str(arg) for arg in args]) + end - # print(message) - if DISCORD_LOGGING_URL is not None and DISCORD_LOGGING_URL != "": - while len(message) > 2000 - 8: - m_section, message = message[:2000 - 8], message[2000 - 8:] - m_section = "```\n" + m_section + "\n```" - DiscordWebhook(url=DISCORD_LOGGING_URL, content=m_section).execute() - DiscordWebhook(url=DISCORD_LOGGING_URL, content="```\n" + message + "\n```").execute() - ### MySQL ### user = os.environ.get("CHAT_DB_USER", "user") password = os.environ.get("CHAT_DB_PASSWORD", "we all live in a yellow submarine") diff --git a/api/src/stampy_chat/followups.py b/api/src/stampy_chat/followups.py index 7553e5e..96567d8 100644 --- a/api/src/stampy_chat/followups.py +++ b/api/src/stampy_chat/followups.py @@ -1,7 +1,10 @@ +import requests from dataclasses import dataclass from typing import List from urllib.parse import quote -import requests +from stampy_chat import logging + +logger = logging.getLogger(__name__) SIMILARITY_THRESHOLD = 0.4 # bit of a shot in the dark - play with this later MAX_FOLLOWUPS = 3 @@ -15,11 +18,11 @@ class Followup: # do a search like this: # https://nlp.stampy.ai/api/search?query=what%20is%20agi -def search_authored(query: str, DEBUG_PRINT: bool = False): - multisearch_authored([query], DEBUG_PRINT) +def search_authored(query: str): + multisearch_authored([query]) # search with multiple queries, combine results -def multisearch_authored(queries: List[str], DEBUG_PRINT: bool = False): +def multisearch_authored(queries: List[str]): followups = {} @@ -36,20 +39,17 @@ def multisearch_authored(queries: List[str], DEBUG_PRINT: bool = False): followups = followups[:MAX_FOLLOWUPS] - if DEBUG_PRINT: - print(" ------------------------------ suggested followups: -----------------------------") + if logger.is_debug(): + logger.debug(" ------------------------------ suggested followups: -----------------------------") for followup in followups: if followup.score > SIMILARITY_THRESHOLD: - print(f'{followup.score:.2f} - suggested to user') + logger.debug(f'{followup.score:.2f} - suggested to user') else: - print(f'{followup.score:.2f} - not suggested') - - print(followup.text) - print(followup.pageid) - print() + logger.debug(f'{followup.score:.2f} - not suggested') + logger.debug(followup.text) + logger.debug(followup.pageid) + logger.debug('') followups = [ f for f in followups if f.score > SIMILARITY_THRESHOLD ] return followups - - diff --git a/api/src/stampy_chat/get_blocks.py b/api/src/stampy_chat/get_blocks.py index b75fcca..0119ccb 100644 --- a/api/src/stampy_chat/get_blocks.py +++ b/api/src/stampy_chat/get_blocks.py @@ -1,4 +1,3 @@ -from typing import List, Tuple import dataclasses import datetime import itertools @@ -7,7 +6,12 @@ import regex as re import requests import time +from typing import List, Tuple from stampy_chat.env import PINECONE_NAMESPACE +from stampy_chat import logging + + +logger = logging.getLogger(__name__) # ---------------------------------- constants --------------------------------- @@ -60,7 +64,7 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]: if index is None: - print('Pinecone index not found, performing semantic search on chat.stampy.ai endpoint.') + logger.info('Pinecone index not found, performing semantic search on chat.stampy.ai endpoint.') response = requests.post( "https://chat.stampy.ai:8443/semantic", json = { @@ -78,7 +82,7 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]: query_embedding = get_embedding(user_query) t1 = time.time() - print(f'Time to get embedding: {t1-t:.2f}s') + logger.debug(f'Time to get embedding: {t1-t:.2f}s') query_response = index.query( namespace=PINECONE_NAMESPACE, @@ -115,7 +119,7 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]: t2 = time.time() - print(f'Time to get top-k blocks: {t2-t1:.2f}s') + logger.debug(f'Time to get top-k blocks: {t2-t1:.2f}s') # for all blocks that are "the same" (same title, author, date, url, tags), # combine their text with "....." in between. Return them in order such @@ -130,7 +134,8 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]: for key, group in itertools.groupby(blocks_plus_old_index, key=key): group = list(group) - if len(group) == 0: continue + if not group: + continue # group = group[:3] # limit to a max of 3 blocks from any one source @@ -150,6 +155,5 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]: def strip_block(text: str) -> str: r = re.match(r"^\"(.*)\"\s*-\s*Title:.*$", text, re.DOTALL) if not r: - print("Warning: couldn't strip block") - print(text) + logger.warning("couldn't strip block:\n%s", text) return r.group(1) if r else text diff --git a/api/src/stampy_chat/logging.py b/api/src/stampy_chat/logging.py new file mode 100644 index 0000000..67254af --- /dev/null +++ b/api/src/stampy_chat/logging.py @@ -0,0 +1,75 @@ +import json +from logging import * +from discord_webhook import DiscordWebhook +from stampy_chat.env import LOG_LEVEL, DISCORD_LOG_LEVEL, DISCORD_LOGGING_URL +from stampy_chat.db.session import ItemAdder +from stampy_chat.db.models import Interaction + + +class DiscordHandler(StreamHandler): + def emit(self, record): + # Ignore messages that come from non chat modules + if record.name.startswith('stampy_chat'): + return + + # Ignore messages that have lower levels + if record.levelno < getLevelName(DISCORD_LOG_LEVEL): + return + + self.to_discord(self.format(record)) + + def to_discord(self, message): + if not DISCORD_LOGGING_URL: + return + + while len(message) > 2000 - 8: + m_section, message = message[:2000 - 8], message[2000 - 8:] + m_section = "```\n" + m_section + "\n```" + DiscordWebhook(url=DISCORD_LOGGING_URL, content=m_section).execute() + DiscordWebhook(url=DISCORD_LOGGING_URL, content="```\n" + message + "\n```").execute() + + +class ChatLogger(Logger): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.addHandler(DiscordHandler()) + self.item_adder = ItemAdder() + + def is_debug(self): + return self.isEnabledFor(DEBUG) + + def interaction(self, query, response, history, prompt, blocks): + prompt = [i for i in prompt if i.get('role') == 'system'] + prompt = prompt[0].get('content') if prompt else None + + self.item_adder.add( + Interaction( + # session_id=session_id, + interaction_no=len([i for i in history if i.get('role') == 'user']), + query=query, + prompt=prompt, + response=response, + chunks=",".join(b.id for b in blocks), + ) + ) + self.info('query: %s', query) + self.info('response: %s', response) + + def moderation_issue(self, query, prompt_string, mod_res): + # this is a biiig ask of a discord webhook - put most important + # info at start such that it's more likely to not be cut off + messages = [ + '-' * 80, + "MODERATION REJECTED", + "MODERATION RESPONSE:\n\n" + json.dumps(mod_res["results"], indent=2), + "REJECTED QUERY: " + query, + "REJECTED PROMPT:\n\n " + prompt_string, + '-' * 80, + ] + for message in messages: + self.warn(message) + + +setLoggerClass(ChatLogger) +basicConfig(level=getLevelName(LOG_LEVEL))