Merge pull request #92 from StampyAI/mysql-logging
Log to mysql
mruwnik authored Sep 29, 2023
2 parents 6da4851 + 5e91e67 commit aa07d73
Showing 7 changed files with 156 additions and 83 deletions.
3 changes: 3 additions & 0 deletions api/.env.example
@@ -1,3 +1,6 @@
+LOG_LEVEL="INFO"
+DISCORD_LOG_LEVEL="INFO"
+
 OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
 PINECONE_API_KEY="" # leave blank to use our online API instead
 LOGGING_URL="" # leave blank if you're not testing logging specifically
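Note: LOG_LEVEL and DISCORD_LOG_LEVEL are presumably consumed by the new stampy_chat.logging module this PR introduces (the seventh changed file, whose diff did not load on this page). A minimal sketch of what such a module might look like, assuming it wraps the standard library logging package. The custom helpers called elsewhere in this diff (is_debug, interaction, moderation_issue, plus the module-level info used in main.py) are stubbed or omitted here; per the PR title, the real interaction logging should end up in MySQL.

    # sketch of api/src/stampy_chat/logging.py, NOT the committed code
    import logging
    import os

    class ChatLogger(logging.Logger):
        def is_debug(self) -> bool:
            return self.isEnabledFor(logging.DEBUG)

        def interaction(self, query, response, history, prompt, blocks):
            # the committed version presumably also persists this to MySQL
            self.info("interaction: %s -> %s", query, response)

        def moderation_issue(self, query, prompt_string, mod_res):
            self.warning("moderation flagged query: %s", query)

    logging.setLoggerClass(ChatLogger)

    def getLogger(name):
        logger = logging.getLogger(name)
        logger.setLevel(os.environ.get("LOG_LEVEL", "INFO"))
        return logger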
9 changes: 5 additions & 4 deletions api/main.py
@@ -5,7 +5,8 @@
 import json
 import re
 
-from stampy_chat.env import PINECONE_INDEX, FLASK_PORT, log
+from stampy_chat import logging
+from stampy_chat.env import PINECONE_INDEX, FLASK_PORT
 from stampy_chat.get_blocks import get_top_k_blocks
 from stampy_chat.chat import talk_to_robot, talk_to_robot_simple

@@ -29,7 +30,7 @@ def stream(src):
 @cross_origin()
 def semantic():
     query = request.json['query']
-    k = request.json['k'] if 'k' in request.json else 20
+    k = request.json.get('k', 20)
     return jsonify([dataclasses.asdict(block) for block in get_top_k_blocks(PINECONE_INDEX, query, k)])


@@ -45,7 +46,7 @@ def chat():
     mode = request.json['mode']
     history = request.json['history']
 
-    return Response(stream(talk_to_robot(PINECONE_INDEX, query, mode, history, log = log)), mimetype='text/event-stream')
+    return Response(stream(talk_to_robot(PINECONE_INDEX, query, mode, history)), mimetype='text/event-stream')
 
 
 # ------------- simplified non-streaming chat for internal testing -------------
@@ -64,7 +65,7 @@ def chat_simplified(param=''):
 def human(id):
     import requests
     r = requests.get(f"https://aisafety.info/questions/{id}")
-    log(f"clicked followup '{json.loads(r.text)['data']['title']}': https://stampy.ai/?state={id}")
+    logging.info(f"clicked followup '{json.loads(r.text)['data']['title']}': https://stampy.ai/?state={id}")
 
     # run a regex to replace all relative links with absolute links. Just doing
     # a regex for now since we really don't need to parse everything out then
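With request.json.get, k is now an optional field of the POST body instead of a hard lookup. A usage sketch; the /semantic path and port are assumed from the chat.stampy.ai fallback in get_blocks.py below, and the payload shape from this handler:

    import requests

    r = requests.post(
        "https://chat.stampy.ai:8443/semantic",
        json={"query": "what is agi", "k": 10},  # "k" may be omitted, the server defaults to 20
    )
    blocks = r.json()  # JSON list of serialized Block dataclasses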
93 changes: 47 additions & 46 deletions api/src/stampy_chat/chat.py
@@ -3,11 +3,16 @@
 from typing import List, Dict, Callable
 import openai
 import re
+from sqlalchemy.orm import PropComparator
 import tiktoken
 import time
 
 from stampy_chat.followups import multisearch_authored
 from stampy_chat.get_blocks import get_top_k_blocks, Block
+from stampy_chat import logging
+
+
+logger = logging.getLogger(__name__)
 
 
 # OpenAI models
@@ -21,17 +26,12 @@
 # NOTE: All this is approximate, there's bits I'm intentionally not counting. Leave a buffer beyond what you might expect.
 NUM_TOKENS = 8191 if COMPLETIONS_MODEL == 'gpt-4' else 4095
+TOKENS_BUFFER = 50 # the number of tokens to leave as a buffer when calculating remaining tokens
 HISTORY_FRACTION = 0.25 # the (approximate) fraction of num_tokens to use for history text before truncating
 CONTEXT_FRACTION = 0.5 # the (approximate) fraction of num_tokens to use for context text before truncating
 
 ENCODER = tiktoken.get_encoding("cl100k_base")
 
-DEBUG_PRINT = True
-
-def set_debug_print(val: bool):
-    global DEBUG_PRINT
-    DEBUG_PRINT = val
-
 # --------------------------------- prompt code --------------------------------


@@ -47,9 +47,9 @@ def cap(text: str, max_tokens: int) -> str:
     else: return ENCODER.decode(encoded_text[:max_tokens]) + " ..."
 
 
+Prompt = List[Dict[str, str]]
 
 
-def construct_prompt(query: str, mode: str, history: List[Dict[str, str]], context: List[Block]) -> List[Dict[str, str]]:
+def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block]) -> Prompt:
 
     prompt = []

@@ -142,7 +142,27 @@ def construct_prompt(query: str, mode: str, history: List[Dict[str, str]], conte
 import time
 import json
 
-def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K, log: Callable = print):
+
+def check_openai_moderation(prompt: Prompt, query: str):
+    prompt_string = '\n\n'.join([message["content"] for message in prompt])
+    mod_res = openai.Moderation.create( input = [ query, prompt_string ])
+
+    if any(map(lambda x: x["flagged"], mod_res["results"])):
+        logger.moderation_issue(query, prompt_string, mod_res)
+
+        raise ValueError("This conversation was rejected by OpenAI's moderation filter. Sorry.")
+
+
+def remaining_tokens(prompt: Prompt):
+    # Count number of tokens left for completion (-50 for a buffer)
+    used_tokens = sum([
+        len(ENCODER.encode(message["content"]) + ENCODER.encode(message["role"]))
+        for message in prompt
+    ])
+    return NUM_TOKENS - used_tokens - TOKENS_BUFFER
+
+
+def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, k: int = STANDARD_K):
     try:
         # 1. Find the most relevant blocks from the Alignment Research Dataset
         yield {"state": "loading", "phase": "semantic"}
@@ -156,25 +176,10 @@ def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str,
 
         # 3. Run both the standalone query and the full prompt through
         # moderation to see if it will be accepted by OpenAI's api
-
-        prompt_string = '\n\n'.join([message["content"] for message in prompt])
-        mod_res = openai.Moderation.create( input = [ query, prompt_string ])
-
-        if any(map(lambda x: x["flagged"], mod_res["results"])):
-
-            # this is a biiig ask of a discord webhook - put most important
-            # info at start such that it's more likely to not be cut off
-            log('-' * 80)
-            log("MODERATION REJECTED")
-            log("MODERATION RESPONSE:\n\n" + json.dumps(mod_res["results"], indent=2))
-            log("REJECTED QUERY: " + query)
-            log("REJECTED PROMPT:\n\n" + prompt_string)
-            log('-' * 80)
-
-            raise ValueError("This conversation was rejected by OpenAI's moderation filter. Sorry.")
+        check_openai_moderation(prompt, query)
 
         # 4. Count number of tokens left for completion (-50 for a buffer)
-        max_tokens_completion = NUM_TOKENS - sum([len(ENCODER.encode(message["content"]) + ENCODER.encode(message["role"])) for message in prompt]) - 50
+        max_tokens_completion = remaining_tokens(prompt)
 
         # 5. Answer the user query
         yield {"state": "loading", "phase": "llm"}
@@ -195,43 +200,39 @@ def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str,
 
 
         t2 = time.time()
-        print(f'Time to get response: {t2-t1:.2f}s')
-
-        if DEBUG_PRINT:
-            print('\n' * 10)
-            print(" ------------------------------ prompt: -----------------------------")
+        logger.debug(f'Time to get response: {time.time() - t1:.2f}s')
+        if logger.is_debug():
+            logger.debug('\n' * 10)
+            logger.debug(" ------------------------------ prompt: -----------------------------")
             for message in prompt:
-                print(f"----------- {message['role']}: ------------------")
-                print(message['content'])
-
-            print('\n' * 10)
-
-            print(' ------------------------------ response: -----------------------------')
-            print(response)
+                logger.debug("----------- %s: ------------------", message['role'])
+                logger.debug(message['content'])
+            logger.debug('\n' * 10)
+            logger.debug(' ------------------------------ response: -----------------------------')
+            logger.debug(response)
 
-        log(query)
-        log(response)
+        logger.interaction(query, response, history, prompt, top_k_blocks)
 
         # yield done state, possibly with followup questions
         fin_json = {'state': 'done'}
-        followups = multisearch_authored([query, response], DEBUG_PRINT)
+        followups = multisearch_authored([query, response])
         for i, followup in enumerate(followups):
            fin_json[f'followup_{i}'] = asdict(followup)
         yield fin_json
 
     except Exception as e:
-        print(e)
+        logger.error(e)
         yield {'state': 'error', 'error': str(e)}

 # convert talk_to_robot_internal from dict generator into json generator
-def talk_to_robot(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K, log: Callable = print):
-    yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k, log))
+def talk_to_robot(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K):
+    yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k))
 
 # wayyy simplified api
-def talk_to_robot_simple(index, query: str, log: Callable = print):
+def talk_to_robot_simple(index, query: str):
     res = {'response': ''}
 
-    for block in talk_to_robot_internal(index, query, "default", [], log = log):
+    for block in talk_to_robot_internal(index, query, "default", []):
         if block['state'] == 'loading' and block['phase'] == 'semantic' and 'citations' in block:
             citations = {}
             for i, c in enumerate(block['citations']):
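The moderation and token-counting code that used to sit inline in talk_to_robot_internal now lives in the check_openai_moderation and remaining_tokens helpers, and the log callable is gone from all three entry points. A quick worked example of the remaining_tokens arithmetic, using a made-up two-message prompt and the non-gpt-4 budget:

    import tiktoken

    ENCODER = tiktoken.get_encoding("cl100k_base")
    NUM_TOKENS = 4095    # the budget when COMPLETIONS_MODEL isn't 'gpt-4'
    TOKENS_BUFFER = 50

    prompt = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is AGI?"},
    ]
    used_tokens = sum(
        len(ENCODER.encode(m["content"]) + ENCODER.encode(m["role"]))
        for m in prompt
    )
    # about 13 tokens used, leaving roughly 4030 for the completion
    max_tokens_completion = NUM_TOKENS - used_tokens - TOKENS_BUFFER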
13 changes: 1 addition & 12 deletions api/src/stampy_chat/env.py
@@ -1,7 +1,6 @@
 import os
 import openai
 import pinecone
-from discord_webhook import DiscordWebhook
 
 if os.path.exists('.env'):
     from dotenv import load_dotenv
@@ -25,6 +24,7 @@
 PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME", "alignment-search")
 PINECONE_INDEX = None
 PINECONE_NAMESPACE = os.environ.get("PINECONE_NAMESPACE", "alignment-search") # "normal" or "finetuned" for the new index, "alignment-search" for the old one
+
 # Only init pinecone if we have an env value for it.
 if PINECONE_API_KEY:
     pinecone.init(
@@ -34,17 +34,6 @@
 
     PINECONE_INDEX = pinecone.Index(index_name=PINECONE_INDEX_NAME)
 
-# log something only if the logging url is set
-def log(*args, end="\n"):
-    message = " ".join([str(arg) for arg in args]) + end
-    # print(message)
-    if DISCORD_LOGGING_URL is not None and DISCORD_LOGGING_URL != "":
-        while len(message) > 2000 - 8:
-            m_section, message = message[:2000 - 8], message[2000 - 8:]
-            m_section = "```\n" + m_section + "\n```"
-            DiscordWebhook(url=DISCORD_LOGGING_URL, content=m_section).execute()
-        DiscordWebhook(url=DISCORD_LOGGING_URL, content="```\n" + message + "\n```").execute()
-
 ### MySQL ###
 user = os.environ.get("CHAT_DB_USER", "user")
 password = os.environ.get("CHAT_DB_PASSWORD", "we all live in a yellow submarine")
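The ad-hoc Discord webhook logger is removed from env.py along with its import. This diff doesn't show where (or whether) Discord delivery lands in the new logging module, but the removed chunking logic would translate naturally into a logging.Handler. A hypothetical sketch, assuming discord_webhook stays the delivery mechanism; the 2000-character cap and the 8 characters reserved for the code fence come from the removed code:

    # hypothetical handler, not part of this commit
    import logging
    from discord_webhook import DiscordWebhook

    CHUNK = 2000 - 8  # Discord's message limit, minus the backtick fence wrapper

    class DiscordHandler(logging.Handler):
        def __init__(self, url):
            super().__init__()
            self.url = url

        def emit(self, record):
            message = self.format(record)
            while message:  # post long messages as several fenced chunks
                section, message = message[:CHUNK], message[CHUNK:]
                DiscordWebhook(url=self.url, content="```\n" + section + "\n```").execute()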
28 changes: 14 additions & 14 deletions api/src/stampy_chat/followups.py
@@ -1,7 +1,10 @@
-import requests
 from dataclasses import dataclass
 from typing import List
 from urllib.parse import quote
+import requests
+from stampy_chat import logging
+
+logger = logging.getLogger(__name__)
 
 SIMILARITY_THRESHOLD = 0.4 # bit of a shot in the dark - play with this later
 MAX_FOLLOWUPS = 3
@@ -15,11 +18,11 @@ class Followup:
 # do a search like this:
 # https://nlp.stampy.ai/api/search?query=what%20is%20agi
 
-def search_authored(query: str, DEBUG_PRINT: bool = False):
-    multisearch_authored([query], DEBUG_PRINT)
+def search_authored(query: str):
+    multisearch_authored([query])
 
 # search with multiple queries, combine results
-def multisearch_authored(queries: List[str], DEBUG_PRINT: bool = False):
+def multisearch_authored(queries: List[str]):
 
     followups = {}
@@ -36,20 +39,17 @@ def multisearch_authored(queries: List[str], DEBUG_PRINT: bool = False):
 
     followups = followups[:MAX_FOLLOWUPS]
 
-    if DEBUG_PRINT:
-        print(" ------------------------------ suggested followups: -----------------------------")
+    if logger.is_debug():
+        logger.debug(" ------------------------------ suggested followups: -----------------------------")
         for followup in followups:
             if followup.score > SIMILARITY_THRESHOLD:
-                print(f'{followup.score:.2f} - suggested to user')
+                logger.debug(f'{followup.score:.2f} - suggested to user')
             else:
-                print(f'{followup.score:.2f} - not suggested')
-
-            print(followup.text)
-            print(followup.pageid)
-            print()
+                logger.debug(f'{followup.score:.2f} - not suggested')
+            logger.debug(followup.text)
+            logger.debug(followup.pageid)
+            logger.debug('')
 
     followups = [ f for f in followups if f.score > SIMILARITY_THRESHOLD ]
 
     return followups
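The followup pipeline itself is unchanged by this commit, only its logging is: collect candidates per query, keep the first MAX_FOLLOWUPS (the ranking step is elided above), then drop anything at or below SIMILARITY_THRESHOLD. A self-contained illustration of that cutoff, with invented texts, pageids, and scores; the Followup fields mirror how they are used above:

    from dataclasses import dataclass

    @dataclass
    class Followup:
        text: str
        pageid: str
        score: float

    SIMILARITY_THRESHOLD = 0.4
    MAX_FOLLOWUPS = 3

    candidates = [
        Followup("What is AGI?", "1234", 0.72),
        Followup("Why might AI be dangerous?", "5678", 0.55),
        Followup("Who works on alignment?", "9012", 0.31),
    ]
    followups = sorted(candidates, key=lambda f: f.score, reverse=True)[:MAX_FOLLOWUPS]
    followups = [f for f in followups if f.score > SIMILARITY_THRESHOLD]
    # the 0.31 candidate falls below the 0.4 threshold and is dropped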


18 changes: 11 additions & 7 deletions api/src/stampy_chat/get_blocks.py
@@ -1,4 +1,3 @@
-from typing import List, Tuple
 import dataclasses
 import datetime
 import itertools
@@ -7,7 +6,12 @@
 import regex as re
 import requests
 import time
+from typing import List, Tuple
 from stampy_chat.env import PINECONE_NAMESPACE
+from stampy_chat import logging
+
+
+logger = logging.getLogger(__name__)
 
 # ---------------------------------- constants ---------------------------------
@@ -60,7 +64,7 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
 
     if index is None:
 
-        print('Pinecone index not found, performing semantic search on chat.stampy.ai endpoint.')
+        logger.info('Pinecone index not found, performing semantic search on chat.stampy.ai endpoint.')
         response = requests.post(
             "https://chat.stampy.ai:8443/semantic",
             json = {
@@ -78,7 +82,7 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
     query_embedding = get_embedding(user_query)
 
     t1 = time.time()
-    print(f'Time to get embedding: {t1-t:.2f}s')
+    logger.debug(f'Time to get embedding: {t1-t:.2f}s')
 
     query_response = index.query(
         namespace=PINECONE_NAMESPACE,
@@ -115,7 +119,7 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
 
     t2 = time.time()
 
-    print(f'Time to get top-k blocks: {t2-t1:.2f}s')
+    logger.debug(f'Time to get top-k blocks: {t2-t1:.2f}s')
 
     # for all blocks that are "the same" (same title, author, date, url, tags),
     # combine their text with "....." in between. Return them in order such
@@ -130,7 +134,8 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
 
     for key, group in itertools.groupby(blocks_plus_old_index, key=key):
         group = list(group)
-        if len(group) == 0: continue
+        if not group:
+            continue
 
         # group = group[:3] # limit to a max of 3 blocks from any one source
 
@@ -150,6 +155,5 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
 def strip_block(text: str) -> str:
     r = re.match(r"^\"(.*)\"\s*-\s*Title:.*$", text, re.DOTALL)
     if not r:
-        print("Warning: couldn't strip block")
-        print(text)
+        logger.warning("couldn't strip block:\n%s", text)
     return r.group(1) if r else text
[Diff for the seventh changed file, presumably the new stampy_chat logging module, did not load.]
