diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e11d09a..868a6a5 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -41,19 +41,20 @@ jobs:
         if: |
           matrix.python-version == 3.11
         run: |
-          python -m pip install ".[dev]"
+          python -m pip install -e ".[dev]"
           python -m ruff check .
           python -m ruff format --check .
 
-      # - name: Generate Report
-      #   run: |
-      #     coverage run -m pytest
-      #     coverage xml
-
-      # - name: Upload Coverage to Codecov
-      #   uses: codecov/codecov-action@v3
-      #   with:
-      #     file: ./coverage.xml
-      #     flags: unittests
-      #     verbose: true
-      #     token: ${{secrets.CODECOV_TOKEN}}
+      - name: Generate Report
+        run: |
+          coverage run -m pytest
+          coverage xml
+          coverage report
+
+      - name: Upload Coverage to Codecov
+        uses: codecov/codecov-action@v4
+        with:
+          files: ./coverage.xml
+          flags: unittests
+          verbose: true
+          token: ${{secrets.CODECOV_TOKEN}}
diff --git a/src/parliai_public/__init__.py b/src/parliai_public/__init__.py
index 0e1ac55..46f6c29 100644
--- a/src/parliai_public/__init__.py
+++ b/src/parliai_public/__init__.py
@@ -1,7 +1,13 @@
 """Using LLMs to capture coverage of organisations, people or themes in UK political debate."""
 
+from . import dates
+from .readers import Debates, WrittenAnswers
+
 __version__ = "0.0.1"
 
 __all__ = [
     "__version__",
+    "Debates",
+    "WrittenAnswers",
+    "dates",
 ]
diff --git a/src/parliai_public/_config/base.toml b/src/parliai_public/_config/base.toml
index 24b784f..5d16ea9 100644
--- a/src/parliai_public/_config/base.toml
+++ b/src/parliai_public/_config/base.toml
@@ -1,4 +1,5 @@
 urls = []
 keywords = ["Office for National Statistics", "ONS"]
-
+prompt = ""
+llm_name = ""
 outdir = ""
diff --git a/src/parliai_public/readers/base.py b/src/parliai_public/readers/base.py
index ea17da8..b65d4ed 100644
--- a/src/parliai_public/readers/base.py
+++ b/src/parliai_public/readers/base.py
@@ -46,7 +46,7 @@ class BaseReader(metaclass=abc.ABCMeta):
         Key terms to filter content on. By default, we look for any
         mention of `Office for National Statistics` or `ONS`.
     dates : list[dt.date], optional
-        List of dates from which to pull entries. The `parliai.dates`
+        List of dates from which to pull entries. The `parliai_public.dates`
         module may be of help. If not specified, only yesterday is used.
     outdir : str, default="out"
         Location of a directory in which to write outputs.
@@ -342,6 +342,8 @@ def _read_contents(self, soup: BeautifulSoup) -> dict:
 
     def instantiate_llm(self) -> None:
         """Instantiate LLM object per user specification."""
+        # Temporary override to default to Gemma (known/tested LLM)
+        self.llm_name = "gemma"
         self.llm = ChatOllama(model=self.llm_name, temperature=0)
         return None
 
diff --git a/src/parliai_public/readers/theyworkforyou.py b/src/parliai_public/readers/theyworkforyou.py
index db4255d..53c90ca 100644
--- a/src/parliai_public/readers/theyworkforyou.py
+++ b/src/parliai_public/readers/theyworkforyou.py
@@ -30,7 +30,7 @@ class Debates(BaseReader):
         Key terms to filter content on. By default, we look for any
         mention of `Office for National Statistics` or `ONS`.
     dates : list[dt.date], optional
-        List of dates from which to pull entries. The `parliai.dates`
+        List of dates from which to pull entries. The `parliai_public.dates`
         module may be of help. If not specified, only yesterday is used.
     outdir : str, default="out"
         Location of a directory in which to write outputs.
@@ -417,7 +417,7 @@ class WrittenAnswers(Debates):
         Key terms to filter content on. By default, we look for any
         mention of `Office for National Statistics` or `ONS`.
     dates : list[dt.date], optional
-        List of dates from which to pull entries. The `parliai.dates`
+        List of dates from which to pull entries. The `parliai_public.dates`
         module may be of help. If not specified, only yesterday is used.
     outdir : str, default="out"
         Location of a directory in which to write outputs.
diff --git a/tests/common.py b/tests/common.py
new file mode 100644
index 0000000..a43136c
--- /dev/null
+++ b/tests/common.py
@@ -0,0 +1,162 @@
+"""
+Common strategies and utilities used across multiple test modules.
+
+Any real-world details or samples used as constants were correct when
+taken on 2024-03-06.
+"""
+
+import datetime as dt
+import string
+
+from dateutil import relativedelta as rd
+from hypothesis import strategies as st
+from langchain_community.chat_models import ChatOllama
+
+from parliai_public.readers.base import BaseReader
+
+
+class ToyReader(BaseReader):
+    """A toy class to allow testing our abstract base class."""
+
+    def retrieve_latest_entries(self):
+        """Allow testing with toy method."""
+
+    @staticmethod
+    def _read_metadata(url, soup):
+        """Allow testing with toy static method."""
+
+    @staticmethod
+    def _read_contents(soup):
+        """Allow testing with toy static method."""
+
+    def render(self, response, page):
+        """Allow testing with toy method."""
+
+    def _summary_template(self):
+        """Allow testing with toy method."""
+
+
+def where_what(reader):
+    """Get the right location and class for testing a reader."""
+
+    what = reader
+    if reader is ToyReader:
+        what = BaseReader
+
+    where = ".".join((what.__module__, what.__name__))
+
+    return where, what
+
+
+def default_llm() -> ChatOllama:
+    """Instantiate default LLM object for use in testing."""
+
+    llm = ChatOllama(
+        model="gemma",
+        temperature=0,
+        # max_output_tokens=2048,
+    )
+
+    return llm
+
+
+MPS_SAMPLE = [
+    (
+        "Bob Seely",
+        "Conservative, Isle of Wight",
+        "https://www.theyworkforyou.com/mp/25645/bob_seely/isle_of_wight",
+    ),
+    (
+        "Mark Logan",
+        "Conservative, Bolton North East",
+        "https://www.theyworkforyou.com/mp/25886/mark_logan/bolton_north_east",
+    ),
+    (
+        "Nigel Huddleston",
+        "Conservative, Mid Worcestershire",
+        "https://www.theyworkforyou.com/mp/25381/nigel_huddleston/mid_worcestershire",
+    ),
+    (
+        "Heather Wheeler",
+        "Conservative, South Derbyshire",
+        "https://www.theyworkforyou.com/mp/24769/heather_wheeler/south_derbyshire",
+    ),
+    (
+        "Ian Paisley Jnr",
+        "DUP, North Antrim",
+        "https://www.theyworkforyou.com/mp/13852/ian_paisley_jnr/north_antrim",
+    ),
+    (
+        "Matthew Offord",
+        "Conservative, Hendon",
+        "https://www.theyworkforyou.com/mp/24955/matthew_offord/hendon",
+    ),
+    (
+        "John Howell",
+        "Conservative, Henley",
+        "https://www.theyworkforyou.com/mp/14131/john_howell/henley",
+    ),
+    (
+        "Robert Goodwill",
+        "Conservative, Scarborough and Whitby",
+        "https://www.theyworkforyou.com/mp/11804/robert_goodwill/scarborough_and_whitby",
+    ),
+    (
+        "Naseem Shah",
+        "Labour, Bradford West",
+        "https://www.theyworkforyou.com/mp/25385/naseem_shah/bradford_west",
+    ),
+    (
+        "Amy Callaghan",
+        "Scottish National Party, East Dunbartonshire",
+        "https://www.theyworkforyou.com/mp/25863/amy_callaghan/east_dunbartonshire",
+    ),
+]
+
+GOV_DEPARTMENTS = [
+    "Attorney General's Office",
+    "Cabinet Office",
+    "Department for Business and Trade",
+    "Department for Culture, Media and Sport",
+    "Department for Education",
+    "Department for Energy Security and Net Zero",
+    "Department for Environment, Food and Rural Affairs",
+    "Department for Levelling Up, Housing and Communities",
"Department for Science, Innovation and Technology", + "Department for Transport", + "Department for Work and Pensions", + "Department of Health and Social Care", + "Export Credits Guarantee Department", + "Foreign, Commonwealth and Development Office", + "HM Treasury", + "Home Office", + "Ministry of Defence", + "Ministry of Justice", + "Northern Ireland Office", + "Office of the Advocate General for Scotland", + "Office of the Leader of the House of Commons", + "Office of the Leader of the House of Lords", + "Office of the Secretary of State for Scotland", + "Office of the Secretary of State for Wales", +] + +SEARCH_TERMS = ( + "ONS", + "Office for National Statistics", + "National Statistician", +) + +TODAY = dt.date.today() +ST_DATES = st.dates(TODAY - rd.relativedelta(years=4), TODAY) + +ST_FREE_TEXT = st.text( + string.ascii_letters + string.digits + ".:;!?-", min_size=1 +) + +MODEL_NAMES = ["llama3", "mistral", "openhermes"] + +GEMMA_PREAMBLES = [ + "Sure! Here is the text you are looking for: \nMy right honourable friend...", + "Sure - here is the quote: My right honourable friend...", + "Sure!The following contains references to your search terms:My right honourable friend...", +] diff --git a/tests/readers/__init__.py b/tests/readers/__init__.py new file mode 100644 index 0000000..d373dfe --- /dev/null +++ b/tests/readers/__init__.py @@ -0,0 +1 @@ +"""Unit tests for the reader classes.""" diff --git a/tests/readers/base/__init__.py b/tests/readers/base/__init__.py new file mode 100644 index 0000000..99eeea0 --- /dev/null +++ b/tests/readers/base/__init__.py @@ -0,0 +1 @@ +"""Tests for the BaseReader class.""" diff --git a/tests/readers/base/strategies.py b/tests/readers/base/strategies.py new file mode 100644 index 0000000..c76440a --- /dev/null +++ b/tests/readers/base/strategies.py @@ -0,0 +1,37 @@ +"""Composite strategies for testing the base reader.""" + +from hypothesis import strategies as st +from langchain.docstore.document import Document + +from ...common import SEARCH_TERMS, ST_FREE_TEXT + + +@st.composite +def st_terms_and_texts(draw, terms=SEARCH_TERMS): + """Create a possibly term-ridden string.""" + + term = draw(st.lists(st.sampled_from(terms), max_size=1)) + string = draw(ST_FREE_TEXT) + add_in = draw(st.booleans()) + + text = " ".join((string, *term)) if add_in else string + + return term, text + + +@st.composite +def st_chunks_contains_responses(draw): + """Create a set of chunks, booleans, and responses for a test.""" + + chunks = draw( + st.lists( + ST_FREE_TEXT.map(lambda x: Document(page_content=x)), + min_size=1, + max_size=5, + ) + ) + + contains = [True, *(draw(st.booleans()) for _ in chunks[1:])] + responses = [draw(ST_FREE_TEXT) for con in contains if con is True] + + return chunks, contains, responses diff --git a/tests/readers/base/test_examples.py b/tests/readers/base/test_examples.py new file mode 100644 index 0000000..2553027 --- /dev/null +++ b/tests/readers/base/test_examples.py @@ -0,0 +1,31 @@ +"""Example tests for the base reader class.""" + +import requests +from bs4 import BeautifulSoup + +from ...common import ToyReader + + +def test_does_not_match_for_extra_abbreviations(): + """Ensure the string checker does not flag ONS+ abbreviations.""" + + reader = ToyReader(urls=[], terms=["ONS"]) + strings = ( + "The ONSR is the Only National Sandwich Ranking.", + "I AM UNLUCKY! 
+    )
+
+    for string in strings:
+        assert not reader.check_contains_terms(string)
+
+
+def test_81_add_ons_not_matched():
+    """Ensure the example from #81 does not match."""
+
+    reader = ToyReader([], terms=["ONS"])
+    url = "https://theyworkforyou.com/wrans/?id=2024-04-12.21381.h"
+
+    response = requests.get(url)
+    soup = BeautifulSoup(response.content, "html.parser")
+
+    assert not reader.check_contains_terms(soup.get_text())
diff --git a/tests/readers/base/test_reader.py b/tests/readers/base/test_reader.py
new file mode 100644
index 0000000..f3b305d
--- /dev/null
+++ b/tests/readers/base/test_reader.py
@@ -0,0 +1,411 @@
+"""Unit tests for the `base` module."""
+
+import datetime as dt
+import json
+import os
+import pathlib
+import re
+import shutil
+import string
+import warnings
+from unittest import mock
+
+import pytest
+from bs4 import BeautifulSoup
+from hypothesis import example, given, provisional, settings
+from hypothesis import strategies as st
+from langchain.docstore.document import Document
+
+from ...common import (
+    GEMMA_PREAMBLES,
+    MODEL_NAMES,
+    ST_DATES,
+    ST_FREE_TEXT,
+    ToyReader,
+    default_llm,
+)
+from .strategies import st_chunks_contains_responses, st_terms_and_texts
+
+settings.register_profile("ci", deadline=None)
+settings.load_profile("ci")
+
+
+@given(ST_FREE_TEXT, st.dictionaries(ST_FREE_TEXT, ST_FREE_TEXT))
+def test_load_config_from_path(path, config):
+    """Test a dictionary can be "loaded" from a given path."""
+
+    with mock.patch("parliai_public.readers.base.toml.load") as load:
+        load.return_value = config
+        loaded = ToyReader._load_config(path)
+
+    assert isinstance(loaded, dict)
+    assert loaded == config
+
+    load.assert_called_once_with(path)
+
+
+def test_load_config_default():
+    """Test that the default config file can be loaded correctly."""
+
+    # TODO: hard-coding the keywords here is not ideal for flexible
+    # keyword support
+    expected = {
+        "urls": [],
+        "keywords": ["Office for National Statistics", "ONS"],
+        "prompt": "",
+        "outdir": "",
+        "llm_name": "",
+    }
+    config = ToyReader._load_config()
+
+    assert config == expected
+
+
+@given(st_terms_and_texts())
+@example((["ONS"], "Have you heard of the ONS?"))
+@example((["ONS"], "ONS numbers are reliable."))
+@example((["ONS"], "Mentions of other departments are like onions."))
+def test_check_contains_terms(term_text):
+    """Check the term checker works as it should."""
+
+    term, text = term_text
+
+    reader = ToyReader(urls=[], terms=term)
+    contains = reader.check_contains_terms(text)
+
+    if term:
+        assert (term[0] in text) is contains
+    else:
+        assert contains is False
+
+
+@given(
+    st.lists(ST_DATES, min_size=1, max_size=14),
+    st.sampled_from(("gemma", "chat-bison")),
+)
+def test_make_outdir(date_list, llm_name):
+    """Check the output directory builder works as it should."""
+
+    where = os.path.join("~", ".parliai_public", "test")
+    tmpdir = pathlib.Path(where).expanduser()
+    tmpdir.mkdir(parents=True, exist_ok=True)
+
+    reader = ToyReader(
+        urls=[], dates=date_list, outdir=tmpdir, llm_name=llm_name
+    )
+
+    with mock.patch(
+        "parliai_public.readers.base.BaseReader._tag_outdir"
+    ) as determiner:
+        determiner.side_effect = lambda x: x
+        reader.make_outdir()
+
+    outdir, *others = list(tmpdir.glob("**/*"))
+    start, end, *llm_parts = outdir.name.split(".")
+
+    assert others == []
+    assert dt.datetime.strptime(start, "%Y-%m-%d").date() == min(date_list)
+    assert dt.datetime.strptime(end, "%Y-%m-%d").date() == max(date_list)
+    assert ".".join(llm_parts) == llm_name
+
+    shutil.rmtree(tmpdir)
+
+
+@given(
+    st.booleans(),
+    (
+        st.lists(st.booleans())
+        .map(lambda bools: [*sorted(bools, reverse=True), False])
+        .map(lambda bools: bools[: bools.index(False) + 1])
+    ),
+)
+def test_tag_outdir(exist, exists):
+    """Check the out directory tagger works."""
+
+    reader = ToyReader(urls=[])
+
+    with mock.patch(
+        "parliai_public.readers.base.os.path.exists"
+    ) as exists_checker:
+        exists_checker.side_effect = [exist, *exists]
+        outdir = reader._tag_outdir("out")
+
+    out, *tags = outdir.split(".")
+    assert out == "out"
+
+    if exist:
+        tag = tags[0]
+        assert tags == [tag]
+        assert tag == str(exists.index(False) + 1)
+    else:
+        assert not tags
+
+
+@given(provisional.urls(), ST_FREE_TEXT, st.booleans())
+def test_get_with_check(url, content, contains):
+    """Test the soup getter method."""
+
+    reader = ToyReader(urls=[])
+
+    page = mock.MagicMock()
+    page.content = content
+    with (
+        mock.patch("parliai_public.readers.base.requests.get") as get,
+        mock.patch(
+            "parliai_public.readers.base.BaseReader.check_contains_terms"
+        ) as check,
+    ):
+        get.return_value = page
+        check.return_value = contains
+        soup = reader.get(url)
+
+    if contains is True:
+        assert isinstance(soup, BeautifulSoup)
+        assert soup.get_text() == content
+    else:
+        assert soup is None
+
+    get.assert_called_once_with(url)
+    check.assert_called_once_with(content)
+
+
+@given(provisional.urls(), ST_FREE_TEXT)
+def test_get_without_check(url, content):
+    """Test the soup getter method when ignoring the checker."""
+
+    reader = ToyReader(urls=[])
+
+    page = mock.MagicMock()
+    page.content = content
+    with (
+        mock.patch("parliai_public.readers.base.requests.get") as get,
+        mock.patch(
+            "parliai_public.readers.base.BaseReader.check_contains_terms"
+        ) as check,
+    ):
+        get.return_value = page
+        soup = reader.get(url, check=False)
+
+    assert isinstance(soup, BeautifulSoup)
+    assert soup.get_text() == content
+
+    get.assert_called_once_with(url)
+    check.assert_not_called()
+
+
+@given(provisional.urls(), st.sampled_from((None, "soup")))
+def test_read(url, soup):
+    """
+    Test the logic of the generic read method.
+
+    Since the method relies on two abstract methods, we mock them here,
+    and just test that the correct order of events passes.
+    """
+
+    reader = ToyReader(urls=[])
+
+    with (
+        mock.patch("parliai_public.readers.base.BaseReader.get") as get,
+        mock.patch(__name__ + ".ToyReader._read_metadata") as read_metadata,
+        mock.patch(__name__ + ".ToyReader._read_contents") as read_contents,
+    ):
+        get.return_value = soup
+        read_metadata.return_value = {"metadata": "foo"}
+        read_contents.return_value = {"contents": "bar"}
+        page = reader.read(url)
+
+    if soup is None:
+        assert page is None
+        read_metadata.assert_not_called()
+        read_contents.assert_not_called()
+    else:
+        assert page == {"metadata": "foo", "contents": "bar"}
+        read_metadata.assert_called_once_with(url, soup)
+        read_contents.assert_called_once_with(soup)
+
+
+@given(
+    st.sampled_from((None, "cat", "dog", "fish", "bird")),
+    ST_DATES.map(dt.date.isoformat),
+    st.text(string.ascii_lowercase),
+)
+def test_save(cat, date, code):
+    """
+    Test the method for saving dictionaries to JSON.
+
+    We cannot use `given` and the `tmp_path` fixture at once, so we
+    create a test directory and delete it with each run. It's not
+    perfect, but it works. If this test fails, manually delete the
+    `~/.parliai_public/test` directory to ensure there are no strange side
+    effects.
+ """ + + where = os.path.join( + "~", ".parliai_public", "test", ".".join(map(str, (cat, date, code))) + ) + tmpdir = pathlib.Path(where).expanduser() + tmpdir.mkdir(parents=True, exist_ok=True) + + idx = ".".join((date, code, "h")) + content = {"cat": cat, "idx": idx, "date": date} + + reader = ToyReader(urls=[], outdir=tmpdir) + reader.save(content) + + items = list(tmpdir.glob("**/*")) + + if cat is None: + assert len(items) == 2 + assert items == [tmpdir / "data", tmpdir / "data" / f"{idx}.json"] + else: + assert len(items) == 3 + assert items == [ + tmpdir / "data", + tmpdir / "data" / cat, + tmpdir / "data" / cat / f"{idx}.json", + ] + + with open(items[-1], "r") as f: + assert content == json.load(f) + + shutil.rmtree(tmpdir) + + +@given(st_chunks_contains_responses()) +def test_analyse(params): + """Test the logic of the analyse method.""" + + chunks, contains, responses = params + reader = ToyReader(urls=[], llm=default_llm) + + with ( + mock.patch( + "parliai_public.readers.base.BaseReader._split_text_into_chunks" + ) as splitter, + mock.patch( + "parliai_public.readers.base.BaseReader.check_contains_terms" + ) as checker, + mock.patch( + "parliai_public.readers.base.BaseReader._analyse_chunk" + ) as analyser, + ): + splitter.return_value = chunks + checker.side_effect = contains + analyser.side_effect = responses + response = reader.analyse({"text": "foo"}) + + assert isinstance(response, dict) and "response" in response + assert response["response"].split("\n\n") == responses + + splitter.assert_called_once_with("foo") + + assert checker.call_count == len(chunks) + assert [call.args for call in checker.call_args_list] == [ + (chunk.page_content,) for chunk in chunks + ] + + assert analyser.call_count == sum(contains) + filtered = filter(lambda x: x[1], zip(chunks, contains)) + for (chunk, contain), call in zip(filtered, analyser.call_args_list): + assert contain + assert call.args == (chunk,) + + +@given( + st.text( + ["\n", " ", *string.ascii_letters], + min_size=100, + max_size=500, + ), + st.sampled_from((100, 250, 500)), + st.sampled_from((0, 5, 10)), +) +def test_split_text_into_chunks(text, size, overlap): + """ + Test the text splitter method. + + Currently, we do not do any rigorous testing since we are using the + wrong splitter. For details, see: + + https://github.com/datasciencecampus/parli-ai/issues/72 + """ + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + chunks = ToyReader._split_text_into_chunks( + text, sep="\n", size=size, overlap=overlap + ) + + assert isinstance(chunks, list) + assert all(isinstance(chunk, Document) for chunk in chunks) + + +@given( + ST_FREE_TEXT.map(lambda x: Document(page_content=x)), + st.text(" ", max_size=5), +) +@pytest.mark.skip("""Temporarily skipping. Test hangs. Actual LLM call.""") +def test_analyse_chunk(chunk, pad): + """ + Test the chunk analyser. + + This function actually interacts with the LLM, so we mock that part + and check the logic holds up. 
+ """ + + reader = ToyReader(urls=[], prompt="{text}", llm=default_llm()) + + with mock.patch("parliai_public.readers.base.ChatOllama") as chat: + chat.return_value.invoke.return_value.content = f"response{pad}" + response = reader._analyse_chunk(chunk) + + assert response == "response" + + chat.assert_called_once() + chat.return_value.invoke.assert_called_once_with(chunk.page_content) + + +@given( + st.lists(ST_DATES, min_size=1, unique=True), + st.lists(provisional.urls(), min_size=1, max_size=5, unique=True), +) +def test_make_header(dates, urls): + """Test the summary header looks right.""" + + reader = ToyReader(dates=dates, urls=urls) + + header = reader.make_header() + + publication, period, _, _, source, _, *links = header.split("\n") + + assert re.search(r"\w+, \d{1,2} \w+ \d{4}$", publication) is not None + assert period.startswith("Period covered: ") + + _, period = period.split(": ") + date_regex = r"\w+, \d{1,2} \w+ \d{4}" + if len(dates) == 1: + assert re.match(rf"{date_regex}$", period) is not None + else: + assert re.match(rf"{date_regex} to {date_regex}$", period) is not None + + assert str(reader._source) in source + + for url, link in zip(urls, links): + assert link.startswith("- ") + assert url in link + + +@given(st.one_of(st.sampled_from(MODEL_NAMES))) +def test_instantiate_llm(llm_name): + """Test that all model requests other than gemma revert to default.""" + + reader = ToyReader(urls=[], llm_name=llm_name) + _ = reader.instantiate_llm() + assert reader.llm == default_llm() + + +@given(st.sampled_from(GEMMA_PREAMBLES)) +def test_clean_response(response): + """Test 'Sure...: ' gemma preamble is consistently removed.""" + + reader = ToyReader(urls=[]) + assert reader.clean_response(response) == "My right honourable friend..." diff --git a/tests/readers/test_creation.py b/tests/readers/test_creation.py new file mode 100644 index 0000000..db67f77 --- /dev/null +++ b/tests/readers/test_creation.py @@ -0,0 +1,266 @@ +"""Unit tests for instantiation methods of our readers.""" + +import datetime as dt +from unittest import mock + +import pytest +from hypothesis import HealthCheck, given, provisional, settings +from hypothesis import strategies as st + +from parliai_public.readers import Debates, WrittenAnswers + +from ..common import ST_DATES, ST_FREE_TEXT, TODAY, ToyReader, where_what + +ST_OPTIONAL_STRINGS = st.one_of((st.just(None), ST_FREE_TEXT)) +YESTERDAY = TODAY - dt.timedelta(days=1) + + +@settings(suppress_health_check=(HealthCheck.too_slow,)) +@given( + st.sampled_from((ToyReader, Debates, WrittenAnswers)), + st.lists(provisional.urls(), max_size=5), + st.one_of((st.just(None), st.lists(ST_FREE_TEXT, max_size=5))), + st.one_of(st.just(None), st.lists(ST_DATES, min_size=1, max_size=5)), + ST_FREE_TEXT, + ST_OPTIONAL_STRINGS, +) +def test_init(reader_class, urls, terms, dates, outdir, prompt): + """Test instantiation occurs correctly.""" + + where, what = where_what(reader_class) + if reader_class is WrittenAnswers: + urls = reader_class._supported_urls + + config = { + "prompt": "", + "llm_name": "gemma", + } + with mock.patch(f"{where}._load_config") as load: + load.return_value = config + reader = reader_class(urls, terms, dates, outdir, prompt) + + default_terms = ["Office for National Statistics", "ONS"] + assert isinstance(reader, what) + assert reader.urls == urls + assert reader.terms == default_terms if not terms else terms + assert reader.dates == [YESTERDAY] if dates is None else dates + assert reader.outdir == outdir + assert reader.prompt == ("" if 
+    assert reader.prompt == ("" if prompt is None else prompt)
+    assert reader.llm_name == "gemma"
+
+    load.assert_called_once_with()
+
+
+@pytest.mark.skip("Skipping - requires diagnostics re keywords")
+@given(
+    st.sampled_from((ToyReader, Debates, WrittenAnswers)),
+    ST_OPTIONAL_STRINGS,
+    st.lists(provisional.urls(), max_size=5),
+    st.lists(ST_FREE_TEXT, max_size=5),
+    ST_FREE_TEXT,
+)
+def test_from_toml_no_dates(reader_class, path, urls, terms, text):
+    """
+    Test that an instance can be made from a configuration file.
+
+    In this test, we do not configure any of the date parameters, so
+    every reader instance should have the same `dates` attribute:
+    yesterday.
+    """
+
+    where, what = where_what(reader_class)
+    if reader_class is WrittenAnswers:
+        urls = reader_class._supported_urls
+
+    with (
+        mock.patch(f"{where}._load_config") as loader,
+        mock.patch("parliai_public.dates.list_dates") as lister,
+    ):
+        loader.return_value = {
+            "urls": urls,
+            "terms": terms,
+            "outdir": text,
+            "prompt": text,
+            "llm_name": "gemma",
+        }
+        reader = reader_class.from_toml(path)
+
+    assert isinstance(reader, what)
+    assert reader.urls == urls
+    assert reader.terms == terms
+    assert reader.dates == [YESTERDAY]
+    assert reader.outdir == text
+    assert reader.prompt == text
+    assert reader.llm_name == "gemma"
+
+    assert loader.return_value["dates"] is None
+    assert loader.call_count == 2
+    assert loader.call_args_list == [mock.call(path), mock.call()]
+
+    lister.assert_not_called()
+
+
+@given(
+    st.sampled_from((ToyReader, Debates, WrittenAnswers)),
+    ST_DATES.map(dt.date.isoformat),
+)
+def test_from_toml_with_start(reader_class, start):
+    """
+    Check the config constructor works with a start date.
+
+    The actual date list construction is mocked here.
+    """
+
+    where, what = where_what(reader_class)
+
+    with (
+        mock.patch(f"{where}._load_config") as loader,
+        mock.patch("parliai_public.dates.list_dates") as lister,
+    ):
+        loader.return_value = {
+            "urls": [],
+            "start": start,
+            "prompt": "",
+            "llm_name": "gemma",
+        }
+        lister.return_value = ["dates"]
+        reader = reader_class.from_toml()
+
+    assert isinstance(reader, what)
+    assert reader.dates == ["dates"]
+
+    assert "start" not in loader.return_value
+    assert loader.return_value.get("dates") == ["dates"]
+
+    lister.assert_called_once_with(start, None, None, "%Y-%m-%d")
+
+
+@given(
+    st.sampled_from((ToyReader, Debates, WrittenAnswers)),
+    ST_DATES.map(dt.date.isoformat),
+)
+def test_from_toml_with_end(reader_class, end):
+    """Check the config constructor works with an end date."""
+
+    where, what = where_what(reader_class)
+
+    with (
+        mock.patch(f"{where}._load_config") as loader,
+        mock.patch("parliai_public.dates.list_dates") as lister,
+    ):
+        loader.return_value = {
+            "urls": [],
+            "end": end,
+            "prompt": "",
+            "llm_name": "gemma",
+        }
+        lister.return_value = ["dates"]
+        reader = reader_class.from_toml()
+
+    assert isinstance(reader, what)
+    assert reader.dates == ["dates"]
+
+    assert "end" not in loader.return_value
+    assert loader.return_value.get("dates") == ["dates"]
+
+    lister.assert_called_once_with(None, end, None, "%Y-%m-%d")
+
+
+@given(
+    st.sampled_from((ToyReader, Debates, WrittenAnswers)),
+    st.tuples(ST_DATES, ST_DATES).map(
+        lambda dates: sorted(map(dt.date.isoformat, dates))
+    ),
+)
+def test_from_toml_with_endpoints(reader_class, endpoints):
+    """Check the config constructor works with two endpoints."""
+
+    where, what = where_what(reader_class)
+    start, end = endpoints
+
+    with (
+        mock.patch(f"{where}._load_config") as loader,
mock.patch("parliai_public.dates.list_dates") as lister, + ): + loader.return_value = { + "urls": [], + "start": start, + "end": end, + "prompt": "", + "llm_name": "gemma", + } + lister.return_value = ["dates"] + reader = reader_class.from_toml() + + assert isinstance(reader, what) + assert reader.dates == ["dates"] + + assert "start" not in loader.return_value + assert "end" not in loader.return_value + assert loader.return_value.get("dates") == ["dates"] + + lister.assert_called_once_with(start, end, None, "%Y-%m-%d") + + +@given( + st.sampled_from((ToyReader, Debates, WrittenAnswers)), + st.integers(1, 14), +) +def test_from_toml_with_window(reader_class, window): + """Check the config constructor works with a window.""" + + where, what = where_what(reader_class) + + with ( + mock.patch(f"{where}._load_config") as loader, + mock.patch("parliai_public.dates.list_dates") as lister, + ): + loader.return_value = { + "urls": [], + "window": window, + "prompt": "", + "llm_name": "gemma", + } + lister.return_value = ["dates"] + reader = reader_class.from_toml() + + assert isinstance(reader, what) + assert reader.dates == ["dates"] + + assert "end" not in loader.return_value + assert loader.return_value.get("dates") == ["dates"] + + lister.assert_called_once_with(None, None, window, "%Y-%m-%d") + + +@given( + st.sampled_from((ToyReader, Debates, WrittenAnswers)), + ST_DATES.map(dt.date.isoformat), + st.integers(1, 14), +) +def test_from_toml_with_end_and_window(reader_class, end, window): + """Check the config constructor works with an end and a window.""" + + where, what = where_what(reader_class) + + with ( + mock.patch(f"{where}._load_config") as loader, + mock.patch("parliai_public.dates.list_dates") as lister, + ): + loader.return_value = { + "urls": [], + "end": end, + "window": window, + "prompt": "", + "llm_name": "gemma", + } + lister.return_value = ["dates"] + reader = reader_class.from_toml() + + assert isinstance(reader, what) + assert reader.dates == ["dates"] + + assert "end" not in loader.return_value + assert "window" not in loader.return_value + assert loader.return_value.get("dates") == ["dates"] + + lister.assert_called_once_with(None, end, window, "%Y-%m-%d") diff --git a/tests/readers/theyworkforyou/__init__.py b/tests/readers/theyworkforyou/__init__.py new file mode 100644 index 0000000..7b94998 --- /dev/null +++ b/tests/readers/theyworkforyou/__init__.py @@ -0,0 +1 @@ +"""Unit tests for the debates reader.""" diff --git a/tests/readers/theyworkforyou/strategies.py b/tests/readers/theyworkforyou/strategies.py new file mode 100644 index 0000000..eb7b2d5 --- /dev/null +++ b/tests/readers/theyworkforyou/strategies.py @@ -0,0 +1,242 @@ +"""Test strategies for the `Debates` class.""" + +import re +import string + +from bs4 import BeautifulSoup, NavigableString, Tag +from hypothesis import provisional +from hypothesis import strategies as st + +from ...common import GOV_DEPARTMENTS, MPS_SAMPLE, ST_DATES, ST_FREE_TEXT + + +@st.composite +def st_title_blocks(draw, date=None): + """Create text for a title block in a parliamentary entry.""" + + if date is None: + date = draw(ST_DATES) + + title = draw(ST_FREE_TEXT) + extra = draw(ST_FREE_TEXT) + + block = ": ".join((title, date.strftime("%d %b %Y"), extra)) + + return block + + +@st.composite +def st_indices(draw, date=None): + """Create an index for a parliamentary entry.""" + + if date is None: + date = draw(ST_DATES) + + prefix = draw(st.text(alphabet="abc", max_size=1)) + body = draw(st.integers(0, 10).map(str)) + + idx = 
".".join((date.strftime("%Y-%m-%d"), prefix, body, "h")) + + return idx + + +@st.composite +def st_metadatas(draw): + """Create a metadata block for our parliamentary summary tests.""" + + date = draw(ST_DATES) + block = draw(st_title_blocks(date)) + idx = draw(st_indices(date)) + + cat = draw(st.sampled_from(("lords", "debates", "whall"))) + url = "/".join((draw(provisional.urls()), cat, f"?id={idx}")) + + return block, date, idx, cat, url + + +@st.composite +def st_lead_metadatas(draw): + """Create a lead block for a written answer test.""" + + date = draw(ST_DATES) + recipient = draw(st.sampled_from(GOV_DEPARTMENTS)) + + lead = ( + f"{recipient} written question " + f"- answered on {date.strftime('%d %B %Y')}" + ) + + return lead, recipient, date + + +@st.composite +def st_speeches(draw): + """Create a speech and its details for a parliamentary test.""" + + speaker, position, url = draw(st.sampled_from(MPS_SAMPLE)) + speech = draw(ST_FREE_TEXT) + + return speech, speaker, position, url + + +@st.composite +def st_daily_boards(draw): + """Create some HTML soup to simulate a daily board.""" + + date = draw(st.dates()).strftime("%Y-%m-%d") + url = f"https://theyworkforyou.com/debates/?d={date}" + + st_href = st.text( + string.digits + string.ascii_letters, min_size=1, max_size=5 + ).map(lambda x: f"/debates/{x}.h") + + hrefs = draw(st.lists(st_href, min_size=1, max_size=10)) + tags = [ + f'' for href in hrefs + ] + soup = BeautifulSoup("\n".join(tags), "html.parser") + + return url, hrefs, soup + + +def extract_href(url): + """Extract just the hyperlink reference from a URL.""" + + match = re.search(r"(?<=.com)\/\w+\/\d+(?=\/)", url) + + if match is None: + return url + + return match.group() + + +def format_speech_block(name, pos, href, text): + """Get a speech block into HTML format.""" + + html = '
{text}