diff --git a/api/src/stampy_chat/get_blocks.py b/api/src/stampy_chat/get_blocks.py
index 0908132..91fe738 100644
--- a/api/src/stampy_chat/get_blocks.py
+++ b/api/src/stampy_chat/get_blocks.py
@@ -6,7 +6,8 @@
 import regex as re
 import requests
 import time
-from typing import List, Tuple
+from itertools import groupby
+from typing import Iterable, List, Tuple
 
 from stampy_chat.env import PINECONE_NAMESPACE, REMOTE_CHAT_INSTANCE, EMBEDDING_MODEL
 from stampy_chat import logging
@@ -48,6 +49,51 @@ def get_embedding(text: str) -> np.ndarray:
             time.sleep(min(max_wait_time, 2 ** attempt))
 
 
+def parse_block(match) -> Block:
+    metadata = match['metadata']
+
+    date = metadata.get('date_published') or metadata.get('date')
+
+    if isinstance(date, datetime.date):
+        date = date.isoformat()
+    elif isinstance(date, datetime.datetime):
+        date = date.date().isoformat()
+    elif isinstance(date, (int, float)):
+        date = datetime.datetime.fromtimestamp(date).isoformat()
+
+    authors = metadata.get('authors')
+    if not authors and metadata.get('author'):
+        authors = [metadata.get('author')]
+
+    return Block(
+        id = metadata.get('hash_id') or metadata.get('id'),
+        title = metadata['title'],
+        authors = authors,
+        date = date,
+        url = metadata['url'],
+        tags = metadata.get('tags'),
+        text = strip_block(metadata['text'])
+    )
+
+
+def join_blocks(blocks: Iterable[Block]) -> List[Block]:
+    # for all blocks that are "the same" (same title, author, date, url, tags),
+    # combine their text with "....." in between. Return them in order such
+    # that the combined block has the minimum index of the blocks combined.
+
+    def to_tuple(block):
+        return (block.id, block.title or "", block.authors or [], block.date or "", block.url or "", block.tags or "")
+
+    def merge_texts(blocks):
+        return "\n.....\n".join(sorted(block.text for block in blocks))
+
+    unified_blocks = [
+        Block(*key, merge_texts(group))
+        for key, group in groupby(blocks, key=to_tuple)
+    ]
+    return sorted(unified_blocks, key=to_tuple)
+
+
 # Get the k blocks most semantically similar to the query using Pinecone.
 def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
 
@@ -69,7 +115,7 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
             }
         )
 
-        return [Block(**block) for block in response.json()]
+        return [parse_block({'metadata': block}) for block in response.json()]
 
     # print time
     t = time.time()
@@ -87,63 +133,12 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
         include_metadata=True,
         vector=query_embedding
     )
-    blocks = []
-    for match in query_response['matches']:
-        metadata = match['metadata']
-
-        date = metadata.get('date_published') or metadata.get('date')
-
-        if isinstance(date, datetime.date):
-            date = date.isoformat()
-        elif isinstance(date, datetime.datetime):
-            date = date.date().isoformat()
-        elif isinstance(date, float):
-            date = datetime.datetime.fromtimestamp(date).date().isoformat()
-
-        authors = metadata.get('authors')
-        if not authors and metadata.get('author'):
-            authors = [metadata.get('author')]
-
-        blocks.append(Block(
-            id = metadata.get('hash_id'),
-            title = metadata['title'],
-            authors = authors,
-            date = date,
-            url = metadata['url'],
-            tags = metadata.get('tags'),
-            text = strip_block(metadata['text'])
-        ))
-
+    blocks = [parse_block(match) for match in query_response['matches']]
     t2 = time.time()
     logger.debug(f'Time to get top-k blocks: {t2-t1:.2f}s')
 
 
-    # for all blocks that are "the same" (same title, author, date, url, tags),
-    # combine their text with "....." in between. Return them in order such
-    # that the combined block has the minimum index of the blocks combined.
-
-    key = lambda bi: (bi[0].id, bi[0].title or "", bi[0].authors or [], bi[0].date or "", bi[0].url or "", bi[0].tags or "")
-
-    blocks_plus_old_index = [(block, i) for i, block in enumerate(blocks)]
-    blocks_plus_old_index.sort(key=key)
-
-    unified_blocks: List[Tuple[Block, int]] = []
-
-    for key, group in itertools.groupby(blocks_plus_old_index, key=key):
-        group = list(group)
-        if not group:
-            continue
-
-        # group = group[:3] # limit to a max of 3 blocks from any one source
-
-        text = "\n.....\n".join([block[0].text for block in group])
-
-        min_index = min([block[1] for block in group])
-
-        unified_blocks.append((Block(*key, text), min_index))
-
-    unified_blocks.sort(key=lambda bi: bi[1])
-    return [block for block, _ in unified_blocks]
+    return join_blocks(blocks)
 
 
 # we add the title and authors inside the contents of the block, so that
diff --git a/api/tests/stampy_chat/test_get_blocks.py b/api/tests/stampy_chat/test_get_blocks.py
new file mode 100644
index 0000000..842f60e
--- /dev/null
+++ b/api/tests/stampy_chat/test_get_blocks.py
@@ -0,0 +1,159 @@
+import pytest
+from datetime import datetime
+from unittest.mock import patch, Mock, call
+from stampy_chat.get_blocks import Block, get_top_k_blocks, parse_block, join_blocks
+
+
+@pytest.mark.parametrize('match_override, block_override', (
+    ({}, {}),
+
+    # Check dates
+    ({'date_published': '2023-01-02T03:04:05'}, {'date': '2023-01-02T03:04:05'}),
+    (
+        {'date_published': datetime.fromisoformat('2023-01-02T03:04:05')},
+        {'date': '2023-01-02T03:04:05'}
+    ),
+    (
+        {'date_published': datetime.fromisoformat('2023-01-02T03:04:05').date()},
+        {'date': '2023-01-02'}
+    ),
+    (
+        {'date_published': datetime.fromisoformat('2023-01-02T03:04:05').timestamp()},
+        {'date': '2023-01-02T03:04:05'}
+    ),
+    (
+        {'date_published': int(datetime.fromisoformat('2023-01-02T03:04:05').timestamp())},
+        {'date': '2023-01-02T03:04:05'}
+    ),
+
+    # Check authors
+    ({'author': 'mr blobby'}, {'authors': ['mr blobby']}),
+    ({'authors': ['mr blobby', 'John Snow']}, {'authors': ['mr blobby', 'John Snow']}),
+    (
+        {'authors': ['mr blobby', 'John Snow'], 'author': 'your momma'},
+        {'authors': ['mr blobby', 'John Snow']}
+    ),
+))
+def test_parse_block(match_override, block_override):
+    match = dict({
+        "hash_id": "1",
+        "title": "Block",
+        "text": "text",
+        "date_published": "2021-12-30",
+        "authors": [],
+        "url": "http://test.com",
+        "tags": "tag",
+    }, **match_override)
+
+    expected_block_data = dict({
+        "id": "1",
+        "title": "Block",
+        "text": "text",
+        "date": "2021-12-30",
+        "authors": [],
+        "url": "http://test.com",
+        "tags": "tag",
+    }, **block_override)
+
+    assert parse_block({'metadata': match}) == Block(**expected_block_data)
+
+
+@pytest.mark.parametrize("blocks, expected", [
+    ([], []),
+    (
+        [Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1')],
+        [Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1')]
+    ),
+    (
+        [
+            Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1'),
+            Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text2')
+        ],
+        [Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1\n.....\ntext2')]
+    ),
+    (
+        [
+            Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1'),
+            Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text2'),
+            Block('id2', 'title2', ['author2'], 'date2', 'url2', 'tags2', 'text2'),
+            Block('id3', 'title3', ['author3'], 'date3', 'url3', 'tags3', 'text3'),
+        ],
+        [
+            Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1\n.....\ntext2'),
+            Block('id2', 'title2', ['author2'], 'date2', 'url2', 'tags2', 'text2'),
+            Block('id3', 'title3', ['author3'], 'date3', 'url3', 'tags3', 'text3'),
+        ]
+    ),
+])
+def test_join_blocks(blocks, expected):
+    assert list(join_blocks(blocks)) == expected
+
+
+def test_join_blocks_different_blocks():
+    blocks = [
+        Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1'),
+        Block('id2', 'title2', ['author2'], 'date2', 'url2', 'tags2', 'text2')
+    ]
+    assert list(join_blocks(blocks)) == [
+        Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1'),
+        Block('id2', 'title2', ['author2'], 'date2', 'url2', 'tags2', 'text2')
+    ]
+
+
+def test_get_top_k_blocks_no_index():
+    response = Mock()
+    response.json.return_value = [
+        {
+            "hash_id": f"{i}",
+            "title": f"Block {i}",
+            "text": f"text {i}",
+            "date_published": f"2021-12-0{i}",
+            "authors": [],
+            "url": f"http://test.com/{i}",
+            "tags": f"tag{i}",
+        } for i in range(5)
+    ]
+    with patch('stampy_chat.get_blocks.requests.post', return_value=response):
+        assert get_top_k_blocks(None, "bla bla bla", 10) == [
+            Block(
+                id=f"{i}",
+                title=f"Block {i}",
+                text=f"text {i}",
+                date=f"2021-12-0{i}",
+                authors=[],
+                url=f"http://test.com/{i}",
+                tags=f"tag{i}"
+            ) for i in range(5)
+        ]
+
+
+@patch('stampy_chat.get_blocks.get_embedding')
+def test_get_top_k_blocks(_mock_embedder):
+    index = Mock()
+    index.query.return_value = {
+        'matches': [
+            {
+                'metadata': {
+                    "hash_id": f"{i}",
+                    "title": f"Block {i}",
+                    "text": f"text {i}",
+                    "date_published": f"2021-12-0{i}",
+                    "authors": [],
+                    "url": f"http://test.com/{i}",
+                    "tags": f"tag{i}",
+                }
+            } for i in range(5)
+        ]
+    }
+
+    assert get_top_k_blocks(index, "bla bla bla", 10) == [
+        Block(
+            id=f"{i}",
+            title=f"Block {i}",
+            text=f"text {i}",
+            date=f"2021-12-0{i}",
+            authors=[],
+            url=f"http://test.com/{i}",
+            tags=f"tag{i}"
+        ) for i in range(5)
+    ]