Skip to content

Commit

Permalink
tests for get_blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik committed Sep 29, 2023
1 parent 4e0956c commit d71e696
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 55 deletions.
105 changes: 50 additions & 55 deletions api/src/stampy_chat/get_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import regex as re
import requests
import time
from typing import List, Tuple
from itertools import groupby
from typing import Iterable, List, Tuple
from stampy_chat.env import PINECONE_NAMESPACE, REMOTE_CHAT_INSTANCE, EMBEDDING_MODEL
from stampy_chat import logging

Expand Down Expand Up @@ -48,6 +49,51 @@ def get_embedding(text: str) -> np.ndarray:
time.sleep(min(max_wait_time, 2 ** attempt))


def parse_block(match) -> Block:
    """Build a Block from a single match's metadata dict.

    Dates are normalised to ISO-format strings; a singular `author` field is
    used as a fallback when the plural `authors` is missing or empty.
    """
    metadata = match['metadata']

    date = metadata.get('date_published') or metadata.get('date')

    # datetime.datetime is a subclass of datetime.date, so this single branch
    # handles both - datetimes keep their time component in the ISO string.
    # (An explicit `elif isinstance(date, datetime.datetime)` branch here
    # would be unreachable for exactly that reason.)
    if isinstance(date, datetime.date):
        date = date.isoformat()
    elif isinstance(date, (int, float)):
        # unix timestamp; converted in local time - NOTE(review): confirm tz
        date = datetime.datetime.fromtimestamp(date).isoformat()

    authors = metadata.get('authors')
    if not authors and metadata.get('author'):
        authors = [metadata.get('author')]

    return Block(
        id = metadata.get('hash_id') or metadata.get('id'),
        title = metadata['title'],
        authors = authors,
        date = date,
        url = metadata['url'],
        tags = metadata.get('tags'),
        text = strip_block(metadata['text'])
    )


def join_blocks(blocks: Iterable[Block]) -> List[Block]:
    """Combine blocks that come from the same source document.

    Blocks are considered the same when they share id, title, authors, date,
    url and tags.  All such blocks (adjacent or not) are merged into a single
    Block whose text is the individual texts, in their original order, joined
    with "....." separators.  Merged blocks are returned in order of first
    occurrence, so each combined block keeps the minimum index of the blocks
    it was combined from.
    """
    def to_tuple(block):
        return (block.id, block.title or "", block.authors or [], block.date or "", block.url or "", block.tags or "")

    def hashable(identity):
        # authors (and possibly tags) are lists, which cannot be dict keys
        return tuple(tuple(part) if isinstance(part, list) else part for part in identity)

    # A dict preserves insertion order, which yields the minimum-index
    # ordering for free; keep the original identity tuple around so the
    # merged Block is rebuilt with the exact original field values.
    grouped = {}
    for block in blocks:
        identity = to_tuple(block)
        grouped.setdefault(hashable(identity), (identity, []))[1].append(block.text)

    return [
        Block(*identity, "\n.....\n".join(texts))
        for identity, texts in grouped.values()
    ]


# Get the k blocks most semantically similar to the query using Pinecone.
def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:

Expand All @@ -69,7 +115,7 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
}
)

return [Block(**block) for block in response.json()]
return [parse_block({'metadata': block}) for block in response.json()]

# print time
t = time.time()
Expand All @@ -87,63 +133,12 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
include_metadata=True,
vector=query_embedding
)
blocks = []
for match in query_response['matches']:
metadata = match['metadata']

date = metadata.get('date_published') or metadata.get('date')

if isinstance(date, datetime.date):
date = date.isoformat()
elif isinstance(date, datetime.datetime):
date = date.date().isoformat()
elif isinstance(date, float):
date = datetime.datetime.fromtimestamp(date).date().isoformat()

authors = metadata.get('authors')
if not authors and metadata.get('author'):
authors = [metadata.get('author')]

blocks.append(Block(
id = metadata.get('hash_id'),
title = metadata['title'],
authors = authors,
date = date,
url = metadata['url'],
tags = metadata.get('tags'),
text = strip_block(metadata['text'])
))

blocks = [parse_block(match) for match in query_response['matches']]
t2 = time.time()

logger.debug(f'Time to get top-k blocks: {t2-t1:.2f}s')

# for all blocks that are "the same" (same title, author, date, url, tags),
# combine their text with "....." in between. Return them in order such
# that the combined block has the minimum index of the blocks combined.

key = lambda bi: (bi[0].id, bi[0].title or "", bi[0].authors or [], bi[0].date or "", bi[0].url or "", bi[0].tags or "")

blocks_plus_old_index = [(block, i) for i, block in enumerate(blocks)]
blocks_plus_old_index.sort(key=key)

unified_blocks: List[Tuple[Block, int]] = []

for key, group in itertools.groupby(blocks_plus_old_index, key=key):
group = list(group)
if not group:
continue

# group = group[:3] # limit to a max of 3 blocks from any one source

text = "\n.....\n".join([block[0].text for block in group])

min_index = min([block[1] for block in group])

unified_blocks.append((Block(*key, text), min_index))

unified_blocks.sort(key=lambda bi: bi[1])
return [block for block, _ in unified_blocks]
return join_blocks(blocks)


# we add the title and authors inside the contents of the block, so that
Expand Down
159 changes: 159 additions & 0 deletions api/tests/stampy_chat/test_get_blocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import pytest
from datetime import datetime
from unittest.mock import patch, Mock, call
from stampy_chat.get_blocks import Block, get_top_k_blocks, parse_block, join_blocks


@pytest.mark.parametrize('match_override, block_override', (
    ({}, {}),
    # Check dates
    ({'date_published': '2023-01-02T03:04:05'}, {'date': '2023-01-02T03:04:05'}),
    (
        {'date_published': datetime.fromisoformat('2023-01-02T03:04:05')},
        {'date': '2023-01-02T03:04:05'}
    ),
    (
        {'date_published': datetime.fromisoformat('2023-01-02T03:04:05').date()},
        {'date': '2023-01-02'}
    ),
    (
        {'date_published': datetime.fromisoformat('2023-01-02T03:04:05').timestamp()},
        {'date': '2023-01-02T03:04:05'}
    ),
    (
        {'date_published': int(datetime.fromisoformat('2023-01-02T03:04:05').timestamp())},
        {'date': '2023-01-02T03:04:05'}
    ),
    # Check authors
    ({'author': 'mr blobby'}, {'authors': ['mr blobby']}),
    ({'authors': ['mr blobby', 'John Snow']}, {'authors': ['mr blobby', 'John Snow']}),
    (
        {'authors': ['mr blobby', 'John Snow'], 'author': 'your momma'},
        {'authors': ['mr blobby', 'John Snow']}
    ),
))
def test_parse_block(match_override, block_override):
    """parse_block normalises dates and authors from pinecone metadata."""
    # Base metadata entry; each case overrides individual fields.
    metadata = {
        "hash_id": "1",
        "title": "Block",
        "text": "text",
        "date_published": "2021-12-30",
        "authors": [],
        "url": "http://test.com",
        "tags": "tag",
        **match_override,
    }

    # The Block that parse_block should produce for that metadata.
    expected = {
        "id": "1",
        "title": "Block",
        "text": "text",
        "date": "2021-12-30",
        "authors": [],
        "url": "http://test.com",
        "tags": "tag",
        **block_override,
    }

    assert parse_block({'metadata': metadata}) == Block(**expected)


@pytest.mark.parametrize("blocks, expected", [
([], []),
(
[Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1')],
[Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1')]
),
(
[
Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1'),
Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text2')
],
[Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1\n.....\ntext2')]
),
(
[
Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1'),
Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text2'),
Block('id2', 'title2', ['author2'], 'date2', 'url2', 'tags2', 'text2'),
Block('id3', 'title3', ['author3'], 'date3', 'url3', 'tags3', 'text3'),
],
[
Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1\n.....\ntext2'),
Block('id2', 'title2', ['author2'], 'date2', 'url2', 'tags2', 'text2'),
Block('id3', 'title3', ['author3'], 'date3', 'url3', 'tags3', 'text3'),
]
),
])
def test_join_blocks(blocks, expected):
assert list(join_blocks(blocks)) == expected


def test_join_blocks_different_blocks():
    """Blocks with different identity fields are returned unmerged."""
    first = Block('id1', 'title1', ['author1'], 'date1', 'url1', 'tags1', 'text1')
    second = Block('id2', 'title2', ['author2'], 'date2', 'url2', 'tags2', 'text2')

    assert list(join_blocks([first, second])) == [first, second]


def test_get_top_k_blocks_no_index():
    """Without a pinecone index, blocks are fetched over HTTP and parsed."""
    def item(i):
        return {
            "hash_id": f"{i}",
            "title": f"Block {i}",
            "text": f"text {i}",
            "date_published": f"2021-12-0{i}",
            "authors": [],
            "url": f"http://test.com/{i}",
            "tags": f"tag{i}",
        }

    response = Mock()
    response.json.return_value = [item(i) for i in range(5)]

    expected = [
        Block(
            id=f"{i}",
            title=f"Block {i}",
            text=f"text {i}",
            date=f"2021-12-0{i}",
            authors=[],
            url=f"http://test.com/{i}",
            tags=f"tag{i}",
        )
        for i in range(5)
    ]

    with patch('stampy_chat.get_blocks.requests.post', return_value=response):
        assert get_top_k_blocks(None, "bla bla bla", 10) == expected


@patch('stampy_chat.get_blocks.get_embedding')
def test_get_top_k_blocks(_mock_embedder):
    """With a pinecone index, query matches are parsed into Blocks."""
    def metadata(i):
        return {
            "hash_id": f"{i}",
            "title": f"Block {i}",
            "text": f"text {i}",
            "date_published": f"2021-12-0{i}",
            "authors": [],
            "url": f"http://test.com/{i}",
            "tags": f"tag{i}",
        }

    index = Mock()
    index.query.return_value = {
        'matches': [{'metadata': metadata(i)} for i in range(5)]
    }

    expected = [
        Block(
            id=f"{i}",
            title=f"Block {i}",
            text=f"text {i}",
            date=f"2021-12-0{i}",
            authors=[],
            url=f"http://test.com/{i}",
            tags=f"tag{i}",
        )
        for i in range(5)
    ]

    assert get_top_k_blocks(index, "bla bla bla", 10) == expected

0 comments on commit d71e696

Please sign in to comment.