diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 0000000..f82b92d --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,25 @@ +name: Run pytest + +on: [push] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.11 + uses: actions/setup-python@v3 + with: + python-version: "3.11" + - name: Install dependencies + run: | + cd api + python -m pip install --upgrade pip + pip install pipenv + pipenv install --dev + - name: Test with pytest + run: | + cd api + pipenv run pytest diff --git a/api/Pipfile b/api/Pipfile index a1c6c74..aaae373 100644 --- a/api/Pipfile +++ b/api/Pipfile @@ -27,6 +27,7 @@ sqlalchemy = "*" mysql-connector-python = "*" [dev-packages] +pytest = "*" [requires] python_version = "3.11" diff --git a/api/Pipfile.lock b/api/Pipfile.lock index 930c2d4..3df309f 100644 --- a/api/Pipfile.lock +++ b/api/Pipfile.lock @@ -1034,5 +1034,39 @@ "version": "==1.9.2" } }, - "develop": {} + "develop": { + "iniconfig": { + "hashes": [ + "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", + "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374" + ], + "markers": "python_version >= '3.7'", + "version": "==2.0.0" + }, + "packaging": { + "hashes": [ + "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61", + "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f" + ], + "markers": "python_version >= '3.7'", + "version": "==23.1" + }, + "pluggy": { + "hashes": [ + "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12", + "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7" + ], + "markers": "python_version >= '3.8'", + "version": "==1.3.0" + }, + "pytest": { + "hashes": [ + "sha256:1d881c6124e08ff0a1bb75ba3ec0bfd8b5354a01c194ddd5a0a870a48d99b002", + "sha256:a766259cfab564a2ad52cb1aae1b881a75c3eb7e34ca3779697c23ed47c47069" + ], + "index": 
"pypi", + "markers": "python_version >= '3.7'", + "version": "==7.4.2" + } + } } diff --git a/api/src/stampy_chat/followups.py b/api/src/stampy_chat/followups.py index 96567d8..5afce48 100644 --- a/api/src/stampy_chat/followups.py +++ b/api/src/stampy_chat/followups.py @@ -19,25 +19,27 @@ class Followup: # https://nlp.stampy.ai/api/search?query=what%20is%20agi def search_authored(query: str): - multisearch_authored([query]) + return multisearch_authored([query]) -# search with multiple queries, combine results -def multisearch_authored(queries: List[str]): - - followups = {} - for query in queries: +def get_followups(query): + url = 'https://nlp.stampy.ai/api/search?query=' + quote(query) + response = requests.get(url).json() + return [Followup(entry['title'], entry['pageid'], entry['score']) for entry in response] - url = 'https://nlp.stampy.ai/api/search?query=' + quote(query) - response = requests.get(url).json() - for entry in response: - followups[entry['pageid']] = Followup(entry['title'], entry['pageid'], entry['score']) - followups = list(followups.values()) +# search with multiple queries, combine results +def multisearch_authored(queries: List[str]): + # sort the followups from lowest to highest score + followups = [entry for query in queries for entry in get_followups(query)] + followups = sorted(followups, key=lambda entry: entry.score) - followups.sort(key=lambda f: f.score, reverse=True) + # Remove any duplicates by making a map from the pageids. 
def parse_block(match) -> Block:
    """Convert a single search match into a `Block`.

    `match` is a mapping whose 'metadata' entry holds the block's fields.
    'title', 'url' and 'text' are required (a missing one raises `KeyError`);
    all other fields are optional and default to `None`.

    The date is taken from 'date_published', falling back to 'date', and is
    normalised to an ISO 8601 string when it arrives as a date, datetime or
    numeric timestamp. Any other value (e.g. an already formatted string) is
    passed through untouched.
    """
    metadata = match['metadata']

    date = metadata.get('date_published') or metadata.get('date')

    # `datetime.datetime` is a subclass of `datetime.date`, so a single check
    # covers both: plain dates render as YYYY-MM-DD, datetimes keep their time.
    if isinstance(date, datetime.date):
        date = date.isoformat()
    elif isinstance(date, (int, float)):
        # numeric values are treated as POSIX timestamps
        date = datetime.datetime.fromtimestamp(date).isoformat()

    # fall back to the singular 'author' field when 'authors' is missing or empty
    authors = metadata.get('authors')
    if not authors and metadata.get('author'):
        authors = [metadata.get('author')]

    return Block(
        id = metadata.get('hash_id') or metadata.get('id'),
        title = metadata['title'],
        authors = authors,
        date = date,
        url = metadata['url'],
        tags = metadata.get('tags'),
        text = strip_block(metadata['text'])
    )
def join_blocks(blocks: Iterable[Block]) -> List[Block]:
    """Merge blocks that belong to the same source into a single block.

    Blocks are "the same" when their id, title, authors, date, url and tags
    all match (missing values are normalised to "" / [] first). The texts of
    matching blocks are joined with a "....." separator line, in the order in
    which the blocks were provided.

    The merged blocks are returned in order of first appearance, i.e. each
    combined block sits at the minimum index of the blocks it was built from.
    Matching blocks do not have to be adjacent in the input.
    """
    def to_tuple(block):
        return (block.id, block.title or "", block.authors or [], block.date or "", block.url or "", block.tags or "")

    def hashable(value):
        # authors (and possibly tags) are lists, which can't be part of a dict key
        return tuple(value) if isinstance(value, list) else value

    # A dict is used rather than itertools.groupby, as groupby only merges
    # *consecutive* items. Dicts keep insertion order, so the groups come out
    # in the order in which each source was first seen.
    grouped = {}
    for block in blocks:
        fields = to_tuple(block)
        lookup = tuple(hashable(field) for field in fields)
        _, texts = grouped.setdefault(lookup, (fields, []))
        texts.append(block.text)

    return [
        Block(*fields, "\n.....\n".join(texts))
        for fields, texts in grouped.values()
    ]
query_response['matches']] t2 = time.time() logger.debug(f'Time to get top-k blocks: {t2-t1:.2f}s') - # for all blocks that are "the same" (same title, author, date, url, tags), - # combine their text with "....." in between. Return them in order such - # that the combined block has the minimum index of the blocks combined. - - key = lambda bi: (bi[0].id, bi[0].title or "", bi[0].authors or [], bi[0].date or "", bi[0].url or "", bi[0].tags or "") - - blocks_plus_old_index = [(block, i) for i, block in enumerate(blocks)] - blocks_plus_old_index.sort(key=key) - - unified_blocks: List[Tuple[Block, int]] = [] - - for key, group in itertools.groupby(blocks_plus_old_index, key=key): - group = list(group) - if not group: - continue - - # group = group[:3] # limit to a max of 3 blocks from any one source - - text = "\n.....\n".join([block[0].text for block in group]) - - min_index = min([block[1] for block in group]) - - unified_blocks.append((Block(*key, text), min_index)) - - unified_blocks.sort(key=lambda bi: bi[1]) - return [block for block, _ in unified_blocks] + return join_blocks(blocks) # we add the title and authors inside the contents of the block, so that diff --git a/api/src/stampy_chat/logging.py b/api/src/stampy_chat/logging.py index ab126a1..ec3ab5b 100644 --- a/api/src/stampy_chat/logging.py +++ b/api/src/stampy_chat/logging.py @@ -41,7 +41,7 @@ def __init__(self, *args, **kwargs): def is_debug(self): return self.isEnabledFor(DEBUG) - def interaction(self, session_id, query, response, history, prompt, blocks): + def interaction(self, session_id: str, query: str, response: str, history, prompt, blocks): prompt = [i for i in prompt if i.get('role') == 'system'] prompt = prompt[0].get('content') if prompt else None diff --git a/api/tests/stampy_chat/test_followups.py b/api/tests/stampy_chat/test_followups.py new file mode 100644 index 0000000..2f7c54b --- /dev/null +++ b/api/tests/stampy_chat/test_followups.py @@ -0,0 +1,67 @@ +import pytest +from 
unittest.mock import patch, Mock + +from stampy_chat.followups import Followup, search_authored, multisearch_authored + + +@pytest.mark.parametrize("query, expected_result", [ + ("what is agi", [Followup("agi title", "agi", 0.5)],), + ("what is ai", [Followup("ai title", "ai", 0.5)],)]) +def test_search_authored(query, expected_result): + response = Mock(json=lambda: [ + {'title': r.text, 'pageid': r.pageid, 'score': r.score} + for r in expected_result + ]) + + with patch('requests.get', return_value=response): + assert search_authored(query) == expected_result + + +@patch('stampy_chat.followups.logger') +def test_multisearch_authored(_logger): + results = [ + {'pageid': '1', 'title': f'result 1', 'score': 0.423}, + {'pageid': '2', 'title': f'result 2', 'score': 0.623}, + {'pageid': '3', 'title': f'this should be skipped', 'score': 0.323}, + {'pageid': '4', 'title': f'this should also be skipped', 'score': 0.1}, + {'pageid': '5', 'title': f'result 5', 'score': 0.543}, + ] + + response = Mock(json=lambda: results) + with patch('requests.get', return_value=response): + assert multisearch_authored(["what is this?", "how about this?"]) == [ + Followup('result 2', '2', 0.623), + Followup('result 5', '5', 0.543), + Followup('result 1', '1', 0.423), + ] + + +@patch('stampy_chat.followups.logger') +def test_multisearch_authored_duplicates(_logger): + results = { + 'query1': [ + {'pageid': '1', 'title': f'result 1', 'score': 0.423}, + {'pageid': '2', 'title': f'result 2', 'score': 0.623}, + {'pageid': '3', 'title': f'this should be skipped', 'score': 0.323}, + {'pageid': '4', 'title': f'this should also be skipped', 'score': 0.1}, + {'pageid': '5', 'title': f'result 5', 'score': 0.543}, + ], + 'query2': [ + {'pageid': '1', 'title': f'result 1', 'score': 0.723}, + {'pageid': '2', 'title': f'this should be skipped', 'score': 0.323}, + {'pageid': '5', 'title': f'this should also be skipped', 'score': 0.1}, + ], + 'query3': [ + {'pageid': '5', 'title': f'result 5', 'score': 
0.511}, + ], + } + def getter(url): + query = url.split('query=')[-1] + return Mock(json=lambda: results[query]) + + with patch('requests.get', getter): + assert multisearch_authored(["query1", "query2", "query3"]) == [ + Followup('result 1', '1', 0.723), + Followup('result 2', '2', 0.623), + Followup('result 5', '5', 0.543), + ] diff --git a/api/tests/stampy_chat/test_get_blocks.py b/api/tests/stampy_chat/test_get_blocks.py new file mode 100644 index 0000000..842f60e --- /dev/null +++ b/api/tests/stampy_chat/test_get_blocks.py @@ -0,0 +1,159 @@ +import pytest +from datetime import datetime +from unittest.mock import patch, Mock, call +from stampy_chat.get_blocks import Block, get_top_k_blocks, parse_block, join_blocks + + +@pytest.mark.parametrize('match_override, block_override', ( + ({}, {}), + + # Check dates + ({'date_published': '2023-01-02T03:04:05'}, {'date': '2023-01-02T03:04:05'}), + ( + {'date_published': datetime.fromisoformat('2023-01-02T03:04:05')}, + {'date': '2023-01-02T03:04:05'} + ), + ( + {'date_published': datetime.fromisoformat('2023-01-02T03:04:05').date()}, + {'date': '2023-01-02'} + ), + ( + {'date_published': datetime.fromisoformat('2023-01-02T03:04:05').timestamp()}, + {'date': '2023-01-02T03:04:05'} + ), + ( + {'date_published': int(datetime.fromisoformat('2023-01-02T03:04:05').timestamp())}, + {'date': '2023-01-02T03:04:05'} + ), + + # Check authors + ({'author': 'mr blobby'}, {'authors': ['mr blobby']}), + ({'authors': ['mr blobby', 'John Snow']}, {'authors': ['mr blobby', 'John Snow']}), + ( + {'authors': ['mr blobby', 'John Snow'], 'author': 'your momma'}, + {'authors': ['mr blobby', 'John Snow']} + ), +)) +def test_parse_block(match_override, block_override): + match = dict({ + "hash_id": "1", + "title": "Block", + "text": "text", + "date_published": "2021-12-30", + "authors": [], + "url": "http://test.com", + "tags": "tag", + }, **match_override) + + expected_block_data = dict({ + "id": "1", + "title": "Block", + "text": "text", + 
"date": "2021-12-30", + "authors": [], + "url": "http://test.com", + "tags": "tag", + }, **block_override) + + assert parse_block({'metadata': match}) == Block(**expected_block_data) + + +@pytest.mark.parametrize("blocks, expected", [ + ([], []), + ( + [Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1')], + [Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1')] + ), + ( + [ + Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1'), + Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text2') + ], + [Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1\n.....\ntext2')] + ), + ( + [ + Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1'), + Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text2'), + Block('id2', 'title2', ['author2'], 'date2', 'url2', 'tags2', 'text2'), + Block('id3', 'title3', ['author3'], 'date3', 'url3', 'tags3', 'text3'), + ], + [ + Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1\n.....\ntext2'), + Block('id2', 'title2', ['author2'], 'date2', 'url2', 'tags2', 'text2'), + Block('id3', 'title3', ['author3'], 'date3', 'url3', 'tags3', 'text3'), + ] + ), +]) +def test_join_blocks(blocks, expected): + assert list(join_blocks(blocks)) == expected + + +def test_join_blocks_different_blocks(): + blocks = [ + Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1'), + Block('id2', 'title2', ['author2'], 'date2', 'url2', 'tags2', 'text2') + ] + assert list(join_blocks(blocks)) == [ + Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1'), + Block('id2', 'title2', ['author2'], 'date2', 'url2', 'tags2', 'text2') + ] + + +def test_get_top_k_blocks_no_index(): + response = Mock() + response.json.return_value = [ + { + "hash_id": f"{i}", + "title": f"Block {i}", + "text": f"text {i}", + "date_published": f"2021-12-0{i}", + "authors": [], + "url": f"http://test.com/{i}", + "tags": f"tag{i}", + } for i 
def test_emit_ignore_internal():
    """A record named 'stampy_chat.bla' must not be forwarded to Discord."""
    handler = DiscordHandler()
    record = Mock(name="stampy_chat.bla")
    # `name` passed to the Mock constructor only names the mock itself, so the
    # record's logger name has to be set as a plain attribute afterwards
    record.name = 'stampy_chat.bla'

    with patch.object(handler, 'to_discord') as sender:
        assert handler.emit(record) is None
        # This has to be *called* - `assert sender.assert_not_called` only checks
        # that the method object is truthy, which is always the case
        sender.assert_not_called()
def test_to_discord_no_url():
    """No webhook should be created when DISCORD_LOGGING_URL is not set."""
    handler = DiscordHandler()

    with patch('stampy_chat.logging.DISCORD_LOGGING_URL', None):
        with patch('stampy_chat.logging.DiscordWebhook') as discord:
            handler.to_discord("bla bla bla")
            # This has to be *called* - `assert discord.assert_not_called` only
            # checks that the method object is truthy, which is always the case
            discord.assert_not_called()
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris + """) + assert discord.call_args_list == [ + call(url='http://example.org', content='```\n\n Lorem ipsum d\n```'), + call(url='http://example.org', content='```\nolor sit amet, consectetur adi\n```'), + call(url='http://example.org', content='```\npiscing elit, sed do eiusmod t\n```'), + call(url='http://example.org', content='```\nempor incididunt ut\n \n```'), + call(url='http://example.org', content='```\n labore et dolore magna a\n```'), + call(url='http://example.org', content='```\nliqua. Ut enim ad minim veniam\n```'), + call(url='http://example.org', content='```\n, quis nostrud exercitation ul\n```'), + call(url='http://example.org', content='```\nlamco laboris\n \n```'), + ] + + +def test_ChatLogger_is_debug(): + logger = ChatLogger('tester') + logger.setLevel(DEBUG) + assert logger.is_debug() + + +@pytest.mark.parametrize('level', (WARN, ERROR, INFO)) +def test_ChatLogger_is_debug_false(level): + logger = ChatLogger('tester') + logger.setLevel(level) + assert not logger.is_debug() + + +def test_ChatLogger_interaction(): + history = [ + {"role": "user", "content": "Die monster. You don’t belong in this world!"}, + {"role": "assistant", "content": "It was not by my hand[1] I am once again given flesh. I was called here by humans who wished to pay me tribute."}, + {"role": "user", "content": "Tribute!?! You steal men's souls and make them your slaves!"}, + {"role": "assistant", "content": "Perhaps the same could be said[321] of all religions..."}, + {"role": "user", "content": "Your words are as empty as your soul! Mankind ill needs a savior such as you!"}, + {"role": "assistant", "content": "What is a man? A[4234] miserable little pile of secrets. But enough talk... 
Have at you!"}, + ] + blocks = [ + Block( + id=str(i), + url=f"http://bla.bla/{i}", + tags=[], + title=f"Block{i}", + authors=[f"Author{i}"], + date=f"2021-01-0{i + 1}", + text=f"Block text {i}" + ) for i in range(5) + ] + response = "This is the response from the LLM to the user's query" + prompt = [ + {'content': "This is where the system prompt would go", 'role': 'system'}, + {'content': 'Q: Die monster. You don’t belong in this world!', 'role': 'user'}, + {'content': 'It was not by my hand[x] I am once again given flesh. I was called', 'role': 'assistant'}, + {'content': "Q: Tribute!?! You steal men's souls and make them your slaves!", 'role': 'user'}, + {'content': 'Perhaps the same could be said[x] of all religions...', 'role': 'assistant'}, + {'content': 'Q: Your words are as empty as your soul! Mankind ill needs a savior such as you!', 'role': 'user'}, + {'content': 'What is a man? A[x] miserable little pile of secrets', 'role': 'assistant'}, + { + 'content': ( + 'In your answer, please cite any claims you make back to each ' + 'source using the format: [a], [b], etc. If you use multiple ' + 'sources to make a claim cite all of them. For example: "AGI is ' + 'concerning [c, d, e]."\n' + '\n' + 'Q: to be or not to be?' + ), + 'role': 'user' + }, + ] + + logger = ChatLogger('tester') + with patch.object(logger, 'item_adder') as adder: + logger.interaction("session id", "what is this?", response, history, prompt, blocks) + interaction = adder.add.call_args_list[0][0][0] + assert interaction.session_id == 'session id' + assert interaction.interaction_no == 3 + assert interaction.query == 'what is this?' + assert interaction.response == "This is the response from the LLM to the user's query" + assert interaction.prompt == "This is where the system prompt would go" + assert interaction.chunks == ','.join(b.id for b in blocks)