From 6386946c0126f6119b06dace4caee62e82f6fa04 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 10:36:28 +0000 Subject: [PATCH 01/50] Add basic gitignore --- .gitignore | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d7d195a --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +.coverage +.hypothesis + +__pycache__ + +*.egg-info + +scrap \ No newline at end of file From 7133debdd77615d74c73e7756c611b0df30e7a11 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 10:36:49 +0000 Subject: [PATCH 02/50] Add API dependency --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 546f243..d6618a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,8 @@ readme = "README.md" requires-python = ">=3.7" license = {text = "MIT License"} dependencies = [ + "census21api@git+https://github.com/MichaelaLawrenceONS/Census21_CACD_Wrapper@dev-0.0.1", + "dask[complete]", "numpy", "pandas", "private-pgm@git+https://github.com/ryan112358/private-pgm", @@ -30,7 +32,7 @@ test = [ "pytest-randomly", ] lint = [ - "black>=22.6.0,<23", + "black<24", "ruff>=0.1.1", ] dev = [ From c8aafe0b150145b273b802fd2be164019874bc56 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 10:58:29 +0000 Subject: [PATCH 03/50] Write instantiation test for MST --- tests/test_mst.py | 64 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 tests/test_mst.py diff --git a/tests/test_mst.py b/tests/test_mst.py new file mode 100644 index 0000000..243002d --- /dev/null +++ b/tests/test_mst.py @@ -0,0 +1,64 @@ +"""Tests for the `centhesus.mst` module.""" + +from unittest import mock + +from census21api import CensusAPI +from census21api.constants import ( + AREA_TYPES_BY_POPULATION_TYPE, + DIMENSIONS_BY_POPULATION_TYPE, + POPULATION_TYPES, +) +from hypothesis import given +from hypothesis import strategies as st + +from centhesus import MST + + +@st.composite +def st_api_parameters(draw): + """Create a valid set of Census API parameters.""" + + population_type = draw(st.sampled_from(POPULATION_TYPES)) + area_type = draw( + st.sampled_from(AREA_TYPES_BY_POPULATION_TYPE[population_type]) + ) + dimensions = draw( + st.one_of( + ( + st.just(None), + st.sets( + st.sampled_from( + DIMENSIONS_BY_POPULATION_TYPE[population_type] + ), + min_size=1, + ), + ) + ) + ) + + return population_type, area_type, dimensions + + +@given(st_api_parameters()) +def test_init(params): + """Test instantiation of the MST class.""" + + population_type, area_type, dimensions = params + + with mock.patch("centhesus.mst.MST.get_domain") as get_domain: + get_domain.return_value = "domain" + mst = MST(population_type, area_type, dimensions) + + assert isinstance(mst, MST) + assert mst.population_type == population_type + assert mst.area_type == area_type + + if dimensions is None: + assert mst.dimensions == DIMENSIONS_BY_POPULATION_TYPE[population_type] + else: + assert mst.dimensions == dimensions + + assert isinstance(mst.api, CensusAPI) + assert mst.domain == "domain" + + get_domain.assert_called_once_with() From 37e4abf0c5a0be4a16dad2510b996152a4bd621d Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 12:27:20 +0000 Subject: [PATCH 04/50] Implement instantiation of MST class --- src/centhesus/__init__.py | 4 ++- src/centhesus/mst.py | 68 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 src/centhesus/mst.py diff --git a/src/centhesus/__init__.py b/src/centhesus/__init__.py index cb55d3c..ada8767 100644 --- a/src/centhesus/__init__.py +++ b/src/centhesus/__init__.py @@ -1,5 +1,7 @@ """Synthesising the 2021 England and Wales Census with public data.""" +from .mst import MST + __version__ = "0.0.1" -__all__ = ["__version__"] +__all__ = ["MST", "__version__"] diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py new file mode 100644 index 0000000..76d1502 --- /dev/null +++ b/src/centhesus/mst.py @@ -0,0 +1,68 @@ +"""Module for the Maximum Spanning Tree generator.""" + +from census21api import CensusAPI +from census21api.constants import DIMENSIONS_BY_POPULATION_TYPE as DIMENSIONS + + +class MST: + """ + Data synthesiser based on the Maximum Spanning Tree (MST) method. + + This class uses the principles of the + [MST method](https://doi.org/10.29012/jpc.778) that won the 2018 + NIST Differential Privacy Synthetic Data Challenge. The original + method makes use of a formal privacy framework to protect the + confidentiality of the dataset being synthesised. In our case, we + use the publicly available tables to create our synthetic data. + These tables have undergone stringent statistical disclosure control + to make them safe to be in the public domain. + + As such, we adapt MST by removing the formal privacy mechanisms. We + do not add noise to the public tables, and we use Kruskal's + algorithm to find the true maximum spanning tree of the feature + graph. We still make use of the Private-PGM method to generate the + graphical model and subsequent synthetic data with a nominal amount + of noise (1e-10). + + The public tables are drawn from the ONS "Create a custom dataset" + API, which is accessed via the `census21api` package. See + `census21api.constants` for details of available population types, + area types, and dimensions. + + Parameters + ---------- + population_type : str + Population type to synthesise. Defaults to usual residents in + households (`"UR_HH"`). + area_type : str, optional + Area type to synthesise. If you wish to include an area type + column (like local authority) in the final dataset, include it + here. The lowest recommended level is MSOA because of issues + handling too-large marginal tables. + dimensions : list of str, optional + Dimensions to synthesise. All features (other than an area type) + you would like in the final dataset. If not specified, all + available dimensions will be included. + + Attributes + ---------- + api : census21api.CensusAPI + Client instance to connect to the 2021 Census API. + domain : mbi.Domain + Dictionary-like object defining the domain size of every column + in the synthetic dataset (area type and dimensions). + """ + + def __init__( + self, population_type="UR_HH", area_type=None, dimensions=None + ): + + self.population_type = population_type + self.area_type = area_type + self.dimensions = dimensions or DIMENSIONS[self.population_type] + + self.api = CensusAPI() + self.domain = self.get_domain() + + def get_domain(self): + """Retrieve domain metadata from the API.""" \ No newline at end of file From 325e1e33707c467635cfb44637dda16d7bdfc5a0 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 13:33:38 +0000 Subject: [PATCH 05/50] Separate instantiation tests --- tests/test_mst.py | 63 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/tests/test_mst.py b/tests/test_mst.py index 243002d..8196206 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -20,20 +20,13 @@ def st_api_parameters(draw): population_type = draw(st.sampled_from(POPULATION_TYPES)) area_type = draw( - st.sampled_from(AREA_TYPES_BY_POPULATION_TYPE[population_type]) + st.sampled_from(AREA_TYPES_BY_POPULATION_TYPE[population_type]), ) dimensions = draw( - st.one_of( - ( - st.just(None), - st.sets( - st.sampled_from( - DIMENSIONS_BY_POPULATION_TYPE[population_type] - ), - min_size=1, - ), - ) - ) + st.sets( + st.sampled_from(DIMENSIONS_BY_POPULATION_TYPE[population_type]), + min_size=1, + ).map(sorted), ) return population_type, area_type, dimensions @@ -52,11 +45,49 @@ def test_init(params): assert isinstance(mst, MST) assert mst.population_type == population_type assert mst.area_type == area_type + assert mst.dimensions == dimensions - if dimensions is None: - assert mst.dimensions == DIMENSIONS_BY_POPULATION_TYPE[population_type] - else: - assert mst.dimensions == dimensions + assert isinstance(mst.api, CensusAPI) + assert mst.domain == "domain" + + get_domain.assert_called_once_with() + + +@given(st_api_parameters()) +def test_init_none_area_type(params): + """Test instantiation of the MST class when area type is None.""" + + population_type, _, dimensions = params + + with mock.patch("centhesus.mst.MST.get_domain") as get_domain: + get_domain.return_value = "domain" + mst = MST(population_type, None, dimensions) + + assert isinstance(mst, MST) + assert mst.population_type == population_type + assert mst.area_type is None + assert mst.dimensions == dimensions + + assert isinstance(mst.api, CensusAPI) + assert mst.domain == "domain" + + get_domain.assert_called_once_with() + + +@given(st_api_parameters()) +def test_init_none_dimensions(params): + """Test instantiation of the MST class when dimensions is None.""" + + population_type, area_type, _ = params + + with mock.patch("centhesus.mst.MST.get_domain") as get_domain: + get_domain.return_value = "domain" + mst = MST(population_type, area_type, None) + + assert isinstance(mst, MST) + assert mst.population_type == population_type + assert mst.area_type == area_type + assert mst.dimensions == DIMENSIONS_BY_POPULATION_TYPE[population_type] assert isinstance(mst.api, CensusAPI) assert mst.domain == "domain" From 58f9717bb9fca4da5273f6e198ee64c960b66e32 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 13:42:04 +0000 Subject: [PATCH 06/50] Write base test for domain feature getter --- tests/test_mst.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_mst.py b/tests/test_mst.py index 8196206..2bae379 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -2,6 +2,7 @@ from unittest import mock +import pandas as pd from census21api import CensusAPI from census21api.constants import ( AREA_TYPES_BY_POPULATION_TYPE, @@ -32,6 +33,22 @@ def st_api_parameters(draw): return population_type, area_type, dimensions +@st.composite +def st_feature_metadata_parameters(draw): + """Create a parameter set and feature metadata for a test.""" + + population_type, area_type, dimensions = draw(st_api_parameters()) + + feature = draw(st.sampled_from(("area-types", "dimensions"))) + items = [area_type] if feature == "area-types" else dimensions + metadata = pd.DataFrame( + ((item, draw(st.integers())) for item in items), + columns=("id", "total_count"), + ) + + return population_type, area_type, dimensions, feature, metadata + + @given(st_api_parameters()) def test_init(params): """Test instantiation of the MST class.""" @@ -93,3 +110,27 @@ def test_init_none_dimensions(params): assert mst.domain == "domain" get_domain.assert_called_once_with() + + +@given(st_feature_metadata_parameters()) +def test_get_domain_of_feature(params): + """Test the domain of a feature can be retrieved correctly.""" + + population_type, area_type, dimensions, feature, metadata = params + + with mock.patch("centhesus.mst.MST.get_domain") as get_domain: + get_domain.return_value = "domain" + mst = MST(population_type, area_type, dimensions) + + with mock.patch("centhesus.mst.CensusAPI.query_feature") as query: + query.return_value = metadata + domain = mst._get_domain_of_feature(feature) + + assert isinstance(domain, dict) + + items = [area_type] if feature == "area-types" else dimensions + assert list(domain.keys()) == items + assert list(domain.values()) == metadata["total_count"].to_list() + + query.assert_called_once_with(population_type, feature, *items) + get_domain.assert_called_once_with() From 83086bf3420452a414fa9df465fa714f294de9dc Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 14:14:31 +0000 Subject: [PATCH 07/50] Factor test strategies out into their own module. --- tests/__init__.py | 1 + tests/strategies.py | 43 +++++++++++++++++++++++++++ tests/test_mst.py | 71 ++++++++++++++++++++++----------------------- 3 files changed, 78 insertions(+), 37 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/strategies.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..2af27a9 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Make the testing modules importable.""" diff --git a/tests/strategies.py b/tests/strategies.py new file mode 100644 index 0000000..57efca4 --- /dev/null +++ b/tests/strategies.py @@ -0,0 +1,43 @@ +"""Custom strategies for testing the package.""" + +import pandas as pd +from census21api.constants import ( + AREA_TYPES_BY_POPULATION_TYPE, + DIMENSIONS_BY_POPULATION_TYPE, + POPULATION_TYPES, +) +from hypothesis import strategies as st + + +@st.composite +def st_api_parameters(draw): + """Create a valid set of Census API parameters.""" + + population_type = draw(st.sampled_from(POPULATION_TYPES)) + area_type = draw( + st.sampled_from(AREA_TYPES_BY_POPULATION_TYPE[population_type]), + ) + dimensions = draw( + st.sets( + st.sampled_from(DIMENSIONS_BY_POPULATION_TYPE[population_type]), + min_size=1, + ).map(sorted), + ) + + return population_type, area_type, dimensions + + +@st.composite +def st_feature_metadata_parameters(draw): + """Create a parameter set and feature metadata for a test.""" + + population_type, area_type, dimensions = draw(st_api_parameters()) + + feature = draw(st.sampled_from(("area-types", "dimensions"))) + items = [area_type] if feature == "area-types" else dimensions + metadata = pd.DataFrame( + ((item, draw(st.integers())) for item in items), + columns=("id", "total_count"), + ) + + return population_type, area_type, dimensions, feature, metadata diff --git a/tests/test_mst.py b/tests/test_mst.py index 2bae379..f7762a9 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -2,51 +2,17 @@ from unittest import mock -import pandas as pd +import pytest from census21api import CensusAPI from census21api.constants import ( - AREA_TYPES_BY_POPULATION_TYPE, DIMENSIONS_BY_POPULATION_TYPE, - POPULATION_TYPES, ) from hypothesis import given from hypothesis import strategies as st from centhesus import MST - -@st.composite -def st_api_parameters(draw): - """Create a valid set of Census API parameters.""" - - population_type = draw(st.sampled_from(POPULATION_TYPES)) - area_type = draw( - st.sampled_from(AREA_TYPES_BY_POPULATION_TYPE[population_type]), - ) - dimensions = draw( - st.sets( - st.sampled_from(DIMENSIONS_BY_POPULATION_TYPE[population_type]), - min_size=1, - ).map(sorted), - ) - - return population_type, area_type, dimensions - - -@st.composite -def st_feature_metadata_parameters(draw): - """Create a parameter set and feature metadata for a test.""" - - population_type, area_type, dimensions = draw(st_api_parameters()) - - feature = draw(st.sampled_from(("area-types", "dimensions"))) - items = [area_type] if feature == "area-types" else dimensions - metadata = pd.DataFrame( - ((item, draw(st.integers())) for item in items), - columns=("id", "total_count"), - ) - - return population_type, area_type, dimensions, feature, metadata +from .strategies import st_api_parameters, st_feature_metadata_parameters @given(st_api_parameters()) @@ -119,7 +85,6 @@ def test_get_domain_of_feature(params): population_type, area_type, dimensions, feature, metadata = params with mock.patch("centhesus.mst.MST.get_domain") as get_domain: - get_domain.return_value = "domain" mst = MST(population_type, area_type, dimensions) with mock.patch("centhesus.mst.CensusAPI.query_feature") as query: @@ -134,3 +99,35 @@ def test_get_domain_of_feature(params): query.assert_called_once_with(population_type, feature, *items) get_domain.assert_called_once_with() + + +@given(st_feature_metadata_parameters()) +def test_get_domain_of_feature_none_area_type(params): + """Test the feature domain getter when area type is None.""" + + population_type, _, dimensions, _, metadata = params + + with mock.patch("centhesus.mst.MST.get_domain") as get_domain: + mst = MST(population_type, None, dimensions) + + with mock.patch("centhesus.mst.CensusAPI.query_feature") as query: + domain = mst._get_domain_of_feature("area-types") + + assert isinstance(domain, dict) + assert domain == {} + + get_domain.assert_called_once_with() + query.assert_not_called() + + +@given(st_api_parameters(), st.text()) +def test_get_domain_of_feature_raises_error(params, feature): + """Test the domain getter raises an error for invalid features.""" + + with mock.patch("centhesus.mst.MST.get_domain") as get_domain: + mst = MST(*params) + + with pytest.raises(ValueError): + mst._get_domain_of_feature(feature) + + get_domain.assert_called_once_with() From e71b47ec29eb971b8c550d80ae3d7acd63c785f8 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 14:23:37 +0000 Subject: [PATCH 08/50] Factor out mocked MST creation --- tests/test_mst.py | 51 ++++++++++++++++++++--------------------------- 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/tests/test_mst.py b/tests/test_mst.py index f7762a9..2f21a81 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -15,15 +15,25 @@ from .strategies import st_api_parameters, st_feature_metadata_parameters +def mocked_mst(population_type, area_type, dimensions, return_value=None): + """Create an instance of MST with mocked `get_domain`.""" + + with mock.patch("centhesus.mst.MST.get_domain") as get_domain: + get_domain.return_value = return_value + mst = MST(population_type, area_type, dimensions) + + get_domain.assert_called_once_with() + + return mst + + @given(st_api_parameters()) def test_init(params): """Test instantiation of the MST class.""" population_type, area_type, dimensions = params - with mock.patch("centhesus.mst.MST.get_domain") as get_domain: - get_domain.return_value = "domain" - mst = MST(population_type, area_type, dimensions) + mst = mocked_mst(population_type, area_type, dimensions) assert isinstance(mst, MST) assert mst.population_type == population_type @@ -31,9 +41,7 @@ def test_init(params): assert mst.dimensions == dimensions assert isinstance(mst.api, CensusAPI) - assert mst.domain == "domain" - - get_domain.assert_called_once_with() + assert mst.domain is None @given(st_api_parameters()) @@ -42,9 +50,7 @@ def test_init_none_area_type(params): population_type, _, dimensions = params - with mock.patch("centhesus.mst.MST.get_domain") as get_domain: - get_domain.return_value = "domain" - mst = MST(population_type, None, dimensions) + mst = mocked_mst(population_type, None, dimensions) assert isinstance(mst, MST) assert mst.population_type == population_type @@ -52,9 +58,7 @@ def test_init_none_area_type(params): assert mst.dimensions == dimensions assert isinstance(mst.api, CensusAPI) - assert mst.domain == "domain" - - get_domain.assert_called_once_with() + assert mst.domain is None @given(st_api_parameters()) @@ -63,9 +67,7 @@ def test_init_none_dimensions(params): population_type, area_type, _ = params - with mock.patch("centhesus.mst.MST.get_domain") as get_domain: - get_domain.return_value = "domain" - mst = MST(population_type, area_type, None) + mst = mocked_mst(population_type, area_type, None) assert isinstance(mst, MST) assert mst.population_type == population_type @@ -73,9 +75,7 @@ def test_init_none_dimensions(params): assert mst.dimensions == DIMENSIONS_BY_POPULATION_TYPE[population_type] assert isinstance(mst.api, CensusAPI) - assert mst.domain == "domain" - - get_domain.assert_called_once_with() + assert mst.domain is None @given(st_feature_metadata_parameters()) @@ -84,8 +84,7 @@ def test_get_domain_of_feature(params): population_type, area_type, dimensions, feature, metadata = params - with mock.patch("centhesus.mst.MST.get_domain") as get_domain: - mst = MST(population_type, area_type, dimensions) + mst = mocked_mst(population_type, area_type, dimensions) with mock.patch("centhesus.mst.CensusAPI.query_feature") as query: query.return_value = metadata @@ -98,7 +97,6 @@ def test_get_domain_of_feature(params): assert list(domain.values()) == metadata["total_count"].to_list() query.assert_called_once_with(population_type, feature, *items) - get_domain.assert_called_once_with() @given(st_feature_metadata_parameters()) @@ -107,8 +105,7 @@ def test_get_domain_of_feature_none_area_type(params): population_type, _, dimensions, _, metadata = params - with mock.patch("centhesus.mst.MST.get_domain") as get_domain: - mst = MST(population_type, None, dimensions) + mst = mocked_mst(population_type, None, dimensions) with mock.patch("centhesus.mst.CensusAPI.query_feature") as query: domain = mst._get_domain_of_feature("area-types") @@ -116,7 +113,6 @@ def test_get_domain_of_feature_none_area_type(params): assert isinstance(domain, dict) assert domain == {} - get_domain.assert_called_once_with() query.assert_not_called() @@ -124,10 +120,7 @@ def test_get_domain_of_feature_none_area_type(params): def test_get_domain_of_feature_raises_error(params, feature): """Test the domain getter raises an error for invalid features.""" - with mock.patch("centhesus.mst.MST.get_domain") as get_domain: - mst = MST(*params) + mst = mocked_mst(*params) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^Feature"): mst._get_domain_of_feature(feature) - - get_domain.assert_called_once_with() From bb4a2df2917cb20cb68ab385c33ea8cf5c32b0df Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 14:43:30 +0000 Subject: [PATCH 09/50] Add metadata id check to domain getter test --- tests/test_mst.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mst.py b/tests/test_mst.py index 2f21a81..289da0c 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -93,7 +93,7 @@ def test_get_domain_of_feature(params): assert isinstance(domain, dict) items = [area_type] if feature == "area-types" else dimensions - assert list(domain.keys()) == items + assert list(domain.keys()) == metadata["id"].to_list() == items assert list(domain.values()) == metadata["total_count"].to_list() query.assert_called_once_with(population_type, feature, *items) From 14387f0e5c459b7df8dbd8c646fac54feda3d6c6 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 14:50:52 +0000 Subject: [PATCH 10/50] Write CI workflow --- .github/workflows/tests.yml | 41 +++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 .github/workflows/tests.yml diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..d973808 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,41 @@ +name: CI + +on: + pull_request: + push: + branches: + - main + - "dev*" + +jobs: + test: + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash + strategy: + matrix: + os: [ubuntu-latest, windows-latest] + python-version: [3.8, 3.11] + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + - name: Update pip and install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install ".[test]" pytest-sugar + - name: Run tests + run: | + python -m pytest --cov=centhesus --cov-fail-under=100 tests + - name: Install and run linters + if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.11 + run: | + python -m pip install ".[lint]" + python -m black --check . + python -m ruff . From 61cca254f5ac68941f993e88d29357e2d24814fa Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 14:54:58 +0000 Subject: [PATCH 11/50] Implement domain feature getter in MST --- src/centhesus/mst.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 76d1502..0506631 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -64,5 +64,27 @@ def __init__( self.api = CensusAPI() self.domain = self.get_domain() + def _get_domain_of_feature(self, feature): + """Retrieve the domain for items in a feature of the API.""" + + if feature == "area-types" and self.area_type is None: + return {} + elif feature == "area-types": + items = [self.area_type] + elif feature == "dimensions": + items = self.dimensions + else: + raise ValueError( + "Feature must be one of 'area-types' or 'dimensions', " + f"not '{feature}'" + ) + + metadata = self.api.query_feature( + self.population_type, feature, *items + ) + domain = dict(metadata[["id", "total_count"]].to_dict("split")["data"]) + + return domain + def get_domain(self): - """Retrieve domain metadata from the API.""" \ No newline at end of file + """Retrieve domain metadata from the API.""" From 7d3a6aeb77a91dd3ff1cfcb095b4ccb126d368d4 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 14:59:05 +0000 Subject: [PATCH 12/50] Update black and reformat code base --- src/centhesus/mst.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 0506631..965ef10 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -56,7 +56,6 @@ class MST: def __init__( self, population_type="UR_HH", area_type=None, dimensions=None ): - self.population_type = population_type self.area_type = area_type self.dimensions = dimensions or DIMENSIONS[self.population_type] From b840fd058a3463496f210fae6b9b8ce2a41b16b9 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 15:27:10 +0000 Subject: [PATCH 13/50] Write overall domain getter test --- tests/test_mst.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_mst.py b/tests/test_mst.py index 289da0c..3a6b90a 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -1,5 +1,6 @@ """Tests for the `centhesus.mst` module.""" +import string from unittest import mock import pytest @@ -9,6 +10,7 @@ ) from hypothesis import given from hypothesis import strategies as st +from mbi import Domain from centhesus import MST @@ -124,3 +126,34 @@ def test_get_domain_of_feature_raises_error(params, feature): with pytest.raises(ValueError, match="^Feature"): mst._get_domain_of_feature(feature) + + +@given( + st_api_parameters(), + st.dictionaries( + st.text(string.ascii_uppercase, min_size=1), st.integers() + ), + st.dictionaries( + st.text(string.ascii_lowercase, min_size=1), st.integers() + ), +) +def test_get_domain(params, area_type_domain, dimensions_domain): + """Test the domain getter can process metadata correctly.""" + + mst = mocked_mst(*params) + + with mock.patch("centhesus.mst.MST._get_domain_of_feature") as feature: + feature.side_effect = [area_type_domain, dimensions_domain] + domain = mst.get_domain() + + assert isinstance(domain, Domain) + assert domain.attrs == ( + *area_type_domain.keys(), + *dimensions_domain.keys(), + ) + + assert feature.call_count == 2 + assert [call.args for call in feature.call_args_list] == [ + ("area-types",), + ("dimensions",), + ] From 85f006d41d2993236c93d4cdf7d7da0722023653 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 15:27:31 +0000 Subject: [PATCH 14/50] Implement basic domain getter --- src/centhesus/mst.py | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 965ef10..85c3ca3 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -2,6 +2,7 @@ from census21api import CensusAPI from census21api.constants import DIMENSIONS_BY_POPULATION_TYPE as DIMENSIONS +from mbi import Domain class MST: @@ -64,7 +65,25 @@ def __init__( self.domain = self.get_domain() def _get_domain_of_feature(self, feature): - """Retrieve the domain for items in a feature of the API.""" + """ + Retrieve the domain for items in a feature of the API. + + Parameters + ---------- + feature : {"area-types", "dimensions"} + Feature of the API from which to call. + + Raises + ------ + ValueError + If `feature` is invalid. + + Returns + ------- + domain : dict + Dictionary containing the domain metadata. Empty if + `feature` is `"area-types"` and `self.area_type` is `None`. + """ if feature == "area-types" and self.area_type is None: return {} @@ -86,4 +105,19 @@ def _get_domain_of_feature(self, feature): return domain def get_domain(self): - """Retrieve domain metadata from the API.""" + """ + Retrieve domain metadata from the API. + + Returns + ------- + domain : mbi.Domain + Dictionary-like object defining the domain size of every column + in the synthetic dataset (area type and dimensions). + """ + + area_type_domain = self._get_domain_of_feature("area-types") + dimension_domain = self._get_domain_of_feature("dimensions") + + domain = Domain.fromdict({**area_type_domain, **dimension_domain}) + + return domain From 315393ccf08063507ceaacf37a0c99a93ded3378 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 18:32:31 +0000 Subject: [PATCH 15/50] Write tests for marginal getter --- tests/strategies.py | 40 ++++++++++++++++++++++++++++++++++++++++ tests/test_mst.py | 28 +++++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/tests/strategies.py b/tests/strategies.py index 57efca4..6c190fc 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -6,7 +6,9 @@ DIMENSIONS_BY_POPULATION_TYPE, POPULATION_TYPES, ) +from hypothesis import assume from hypothesis import strategies as st +from hypothesis.extra.pandas import column, data_frames @st.composite @@ -41,3 +43,41 @@ def st_feature_metadata_parameters(draw): ) return population_type, area_type, dimensions, feature, metadata + + +@st.composite +def st_marginals(draw): + """Create a marginal table and its parameters for a test.""" + + population_type, area_type, dimensions = draw(st_api_parameters()) + + clique = draw( + st.sets( + st.sampled_from((area_type, *dimensions)), min_size=1, max_size=2 + ).map(tuple) + ) + + columns = [ + *( + column( + col, + elements=st.integers(min_value=-1, max_value=5), + unique=True, + ) + for col in clique + ), + column( + "count", + elements=st.integers(min_value=0, max_value=10), + ), + ] + + marginal = ( + draw(data_frames(columns)) + .sort_values(list(clique)) + .reset_index(drop=True) + ) + + assume(len(marginal)) + + return population_type, area_type, dimensions, clique, marginal diff --git a/tests/test_mst.py b/tests/test_mst.py index 3a6b90a..fcfa5a8 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -3,6 +3,8 @@ import string from unittest import mock +import numpy as np +import pandas as pd import pytest from census21api import CensusAPI from census21api.constants import ( @@ -14,7 +16,11 @@ from centhesus import MST -from .strategies import st_api_parameters, st_feature_metadata_parameters +from .strategies import ( + st_api_parameters, + st_feature_metadata_parameters, + st_marginals, +) def mocked_mst(population_type, area_type, dimensions, return_value=None): @@ -157,3 +163,23 @@ def test_get_domain(params, area_type_domain, dimensions_domain): ("area-types",), ("dimensions",), ] + + +@given(st_marginals(), st.booleans()) +def test_get_marginal(params, flatten): + """Test that a marginal table can be processed correctly.""" + + population_type, area_type, dimensions, clique, table = params + mst = mocked_mst(population_type, area_type, dimensions) + + with mock.patch("centhesus.mst.CensusAPI.query_table") as query: + query.return_value = table + marginal = mst.get_marginal(clique, flatten) + + if flatten: + assert isinstance(marginal, np.ndarray) + assert (marginal == table["count"]).all() + else: + assert isinstance(marginal, pd.Series) + assert marginal.name == "count" + assert (marginal.reset_index() == table).all().all() From 81f54bca5a4b4c31184816427ba5e93f50c98f69 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 14 Nov 2023 18:32:55 +0000 Subject: [PATCH 16/50] Implement marginal getter --- src/centhesus/mst.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 85c3ca3..7cd6a34 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -121,3 +121,43 @@ def get_domain(self): domain = Domain.fromdict({**area_type_domain, **dimension_domain}) return domain + + def get_marginal(self, clique, flatten=True): + """ + Retrieve the marginal table for a clique from the API. + + This function also returns the metadata to "measure" the + marginal in the package that underpins the synthesis, `mbi`. + + Parameters + ---------- + clique : tuple of str + Tuple defining the columns of the clique to be measured. + Should be of the form `(col,)` or `(col1, col2)`. + flatten : bool + Whether the marginal should be flattened or not. Default is + `True` to work with `mbi`. Flattened marginals are NumPy + arrays rather than Pandas series. + + Returns + ------- + marginal : numpy.ndarray or pandas.Series + Marginal table. If `flatten` is True, this a flat array. + Otherwise, the indexed series is returned. + """ + + area_type = self.area_type or "nat" + dimensions = [ + col for col in clique if col != area_type + ] or self.dimensions[0:1] + + marginal = ( + self.api.query_table(self.population_type, area_type, dimensions) + .groupby(list(clique))["count"] + .sum() + ) + + if flatten is True: + marginal = marginal.to_numpy().flatten() + + return marginal From 81d8cabbb5accedccbba98320953f39387fbf9aa Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 11:11:04 +0000 Subject: [PATCH 17/50] Write tests for measurement and failed marginal --- tests/strategies.py | 29 +++++--------------------- tests/test_mst.py | 50 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 26 deletions(-) diff --git a/tests/strategies.py b/tests/strategies.py index 6c190fc..cc015db 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -6,9 +6,7 @@ DIMENSIONS_BY_POPULATION_TYPE, POPULATION_TYPES, ) -from hypothesis import assume from hypothesis import strategies as st -from hypothesis.extra.pandas import column, data_frames @st.composite @@ -46,7 +44,7 @@ def st_feature_metadata_parameters(draw): @st.composite -def st_marginals(draw): +def st_single_marginals(draw): """Create a marginal table and its parameters for a test.""" population_type, area_type, dimensions = draw(st_api_parameters()) @@ -57,27 +55,10 @@ def st_marginals(draw): ).map(tuple) ) - columns = [ - *( - column( - col, - elements=st.integers(min_value=-1, max_value=5), - unique=True, - ) - for col in clique - ), - column( - "count", - elements=st.integers(min_value=0, max_value=10), - ), - ] - - marginal = ( - draw(data_frames(columns)) - .sort_values(list(clique)) - .reset_index(drop=True) + counts = draw(st.lists(st.integers(0), min_size=1, max_size=10)) + marginal = pd.DataFrame( + ((*([i] * len(clique)), count) for i, count in enumerate(counts)), + columns=(*clique, "count"), ) - assume(len(marginal)) - return population_type, area_type, dimensions, clique, marginal diff --git a/tests/test_mst.py b/tests/test_mst.py index fcfa5a8..045df5c 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -13,13 +13,14 @@ from hypothesis import given from hypothesis import strategies as st from mbi import Domain +from scipy import sparse from centhesus import MST from .strategies import ( st_api_parameters, st_feature_metadata_parameters, - st_marginals, + st_single_marginals, ) @@ -165,7 +166,7 @@ def test_get_domain(params, area_type_domain, dimensions_domain): ] -@given(st_marginals(), st.booleans()) +@given(st_single_marginals(), st.booleans()) def test_get_marginal(params, flatten): """Test that a marginal table can be processed correctly.""" @@ -183,3 +184,48 @@ def test_get_marginal(params, flatten): assert isinstance(marginal, pd.Series) assert marginal.name == "count" assert (marginal.reset_index() == table).all().all() + + query.assert_called_once() + + +@given(st_single_marginals(), st.booleans()) +def test_get_marginal_failed_call(params, flatten): + """Test that a failed call can be processed still.""" + + population_type, area_type, dimensions, clique, _ = params + mst = mocked_mst(population_type, area_type, dimensions) + + with mock.patch("centhesus.mst.CensusAPI.query_table") as query: + query.return_value = None + marginal = mst.get_marginal(clique, flatten) + + assert marginal is None + + query.assert_called_once() + + +@given(st_single_marginals(), st.integers(1, 5)) +def test_measure(params, num_cliques): + """Test a set of cliques can be measured.""" + + population_type, area_type, dimensions, clique, table = params + mst = mocked_mst(population_type, area_type, dimensions) + + with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal: + get_marginal.return_value = table + measurements = mst.measure([clique] * num_cliques) + + assert isinstance(measurements, list) + assert len(measurements) == num_cliques + + for measurement in measurements: + assert isinstance(measurement, tuple) + assert len(measurement) == 4 + + ident, marg, noise, cliq = measurement + assert isinstance(ident, sparse._dia.dia_matrix) + assert ident.shape == (marg.size,) * 2 + assert ident.sum() == marg.size + assert marg.equals(table) + assert noise == 1e-12 + assert cliq == clique From 8eeb979c1c1bc51508ddc583c76fe118711a87c5 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 11:11:37 +0000 Subject: [PATCH 18/50] Write measure method --- src/centhesus/mst.py | 71 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 12 deletions(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 7cd6a34..2db9023 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -1,8 +1,10 @@ """Module for the Maximum Spanning Tree generator.""" +import dask from census21api import CensusAPI from census21api.constants import DIMENSIONS_BY_POPULATION_TYPE as DIMENSIONS from mbi import Domain +from scipy import sparse class MST: @@ -141,23 +143,68 @@ def get_marginal(self, clique, flatten=True): Returns ------- - marginal : numpy.ndarray or pandas.Series - Marginal table. If `flatten` is True, this a flat array. + marginal : numpy.ndarray or pandas.Series or None + Marginal table if the API call succeeds and `None` if not. + On a success, if `flatten` is `True`, this a flat array. Otherwise, the indexed series is returned. """ area_type = self.area_type or "nat" - dimensions = [ - col for col in clique if col != area_type - ] or self.dimensions[0:1] - - marginal = ( - self.api.query_table(self.population_type, area_type, dimensions) - .groupby(list(clique))["count"] - .sum() + dimensions = [col for col in clique if col != area_type] + if not dimensions: + dimensions = self.dimensions[0:1] + + marginal = self.api.query_table( + self.population_type, area_type, dimensions ) - if flatten is True: - marginal = marginal.to_numpy().flatten() + if marginal is not None: + marginal = marginal.groupby(list(clique))["count"].sum() + if flatten is True: + marginal = marginal.to_numpy().flatten() return marginal + + def measure(self, cliques): + """ + Measure the marginals of a set of cliques. + + This function returns a list of "measurements" to be passed to + the `mbi` package. Each measurement consists of a sparse + identity matrix, the marginal table, a nominally small float + representing the "noise" added to the marginal, and the clique + associated with the marginal. + + Although we are not applying differential privacy to our tables, + `mbi` requires non-zero noise for each measurement to form the + graphical model. + + If a column pair has been blocked by the API, then their + marginal is `None` and we skip over them. + + Parameters + ---------- + cliques : iterable of tuple + The cliques to measure. These cliques should be of the form + `(col,)` or `(col1, col2)`. + + Returns + ------- + measurements : list of tuple + Measurement tuples for each clique. + """ + + tasks = [] + for clique in cliques: + marginal = dask.delayed(self.get_marginal)(clique) + tasks.append(marginal) + + marginals = dask.compute(*tasks) + + measurements = [ + (sparse.eye(marginal.size), marginal, 1e-12, clique) + for marginal, clique in zip(marginals, cliques) + if marginal is not None + ] + + return measurements From 8fc60fd9351b1d1e80ebec1f40c7bdbebe3f9b9b Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 12:10:38 +0000 Subject: [PATCH 19/50] Write tests for model fitter (rm measure deadline) --- tests/strategies.py | 15 ++++++++++++--- tests/test_mst.py | 28 ++++++++++++++++++++++++---- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/tests/strategies.py b/tests/strategies.py index cc015db..f23ab36 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -1,5 +1,8 @@ """Custom strategies for testing the package.""" +import itertools + +import numpy as np import pandas as pd from census21api.constants import ( AREA_TYPES_BY_POPULATION_TYPE, @@ -55,10 +58,16 @@ def st_single_marginals(draw): ).map(tuple) ) - counts = draw(st.lists(st.integers(0), min_size=1, max_size=10)) + num_uniques = [draw(st.integers(2, 5)) for _ in clique] + num_rows = int(np.prod(num_uniques)) + counts = draw( + st.lists(st.integers(0, 100), min_size=num_rows, max_size=num_rows) + ) + marginal = pd.DataFrame( - ((*([i] * len(clique)), count) for i, count in enumerate(counts)), - columns=(*clique, "count"), + itertools.product(*(range(num_unique) for num_unique in num_uniques)), + columns=clique, ) + marginal["count"] = counts return population_type, area_type, dimensions, clique, marginal diff --git a/tests/test_mst.py b/tests/test_mst.py index 045df5c..fd07a3c 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -10,9 +10,9 @@ from census21api.constants import ( DIMENSIONS_BY_POPULATION_TYPE, ) -from hypothesis import given +from hypothesis import given, settings from hypothesis import strategies as st -from mbi import Domain +from mbi import Domain, GraphicalModel from scipy import sparse from centhesus import MST @@ -24,11 +24,11 @@ ) -def mocked_mst(population_type, area_type, dimensions, return_value=None): +def mocked_mst(population_type, area_type, dimensions, domain=None): """Create an instance of MST with mocked `get_domain`.""" with mock.patch("centhesus.mst.MST.get_domain") as get_domain: - get_domain.return_value = return_value + get_domain.return_value = domain mst = MST(population_type, area_type, dimensions) get_domain.assert_called_once_with() @@ -204,6 +204,7 @@ def test_get_marginal_failed_call(params, flatten): query.assert_called_once() +@settings(deadline=None) @given(st_single_marginals(), st.integers(1, 5)) def test_measure(params, num_cliques): """Test a set of cliques can be measured.""" @@ -229,3 +230,22 @@ def test_measure(params, num_cliques): assert marg.equals(table) assert noise == 1e-12 assert cliq == clique + + +@given(st_single_marginals(), st.integers(1, 5)) +def test_fit_model(params, iters): + """Test that a model can be fitted to some measurements.""" + + population_type, area_type, dimensions, clique, table = params + domain = Domain.fromdict(table.drop("count", axis=1).nunique().to_dict()) + table = table["count"] + mst = mocked_mst(population_type, area_type, dimensions, domain=domain) + + measurements = [(sparse.eye(table.size), table, 1e-12, clique)] + model = mst.fit_model(measurements, iters) + + assert isinstance(model, GraphicalModel) + assert model.domain == mst.domain + assert model.cliques == [clique] + assert model.elimination_order == list(clique) + assert model.total == table.sum() or 1 From 1e95c247e3f8a678b45adf78988ad8e136643728 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 12:12:03 +0000 Subject: [PATCH 20/50] Implement model fitter --- src/centhesus/mst.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 2db9023..24ebcf5 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -3,7 +3,7 @@ import dask from census21api import CensusAPI from census21api.constants import DIMENSIONS_BY_POPULATION_TYPE as DIMENSIONS -from mbi import Domain +from mbi import Domain, FactoredInference from scipy import sparse @@ -208,3 +208,26 @@ def measure(self, cliques): ] return measurements + + def fit_model(self, measurements, iters=5000): + """ + Fit a graphical model to some measurements. + + Parameters + ---------- + measurements : list of tuple + Measurement tuples associated with some cliques to fit. + iters : int + Number of iterations to use when fitting the model. Default + is 5000. + + Returns + ------- + model : mbi.GraphicalModel + Fitted graphical model. + """ + + engine = FactoredInference(self.domain, iters=iters) + model = engine.estimate(measurements) + + return model From c444138ae15d884a228d90413aa546b5246778dc Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 12:54:15 +0000 Subject: [PATCH 21/50] Write pair importance calculator test --- tests/strategies.py | 12 ++++++++++-- tests/test_mst.py | 24 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/strategies.py b/tests/strategies.py index f23ab36..0aa56ab 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -47,14 +47,22 @@ def st_feature_metadata_parameters(draw): @st.composite -def st_single_marginals(draw): +def st_single_marginals(draw, kind=None): """Create a marginal table and its parameters for a test.""" population_type, area_type, dimensions = draw(st_api_parameters()) + min_size, max_size = 1, 2 + if kind == "single": + max_size = 1 + if kind == "pair": + min_size = 2 + clique = draw( st.sets( - st.sampled_from((area_type, *dimensions)), min_size=1, max_size=2 + st.sampled_from((area_type, *dimensions)), + min_size=min_size, + max_size=max_size, ).map(tuple) ) diff --git a/tests/test_mst.py b/tests/test_mst.py index fd07a3c..37920b7 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -249,3 +249,27 @@ def test_fit_model(params, iters): assert model.cliques == [clique] assert model.elimination_order == list(clique) assert model.total == table.sum() or 1 + + +@given(st_single_marginals(kind="pair")) +def test_calculate_importance_of_pair(params): + """Test the importance of a column pair can be calculated.""" + + population_type, area_type, dimensions, clique, table = params + table = table["count"] + mst = mocked_mst(population_type, area_type, dimensions) + + interim = mock.MagicMock() + interim.project.return_value.datavector.return_value = table.sample( + frac=1.0 + ).reset_index(drop=True) + with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal: + get_marginal.return_value = table + weight = mst._calculate_importance_of_pair(interim, clique) + + assert isinstance(weight, float) + assert weight <= 0 + + interim.project.assert_called_once_with(clique) + interim.project.return_value.datavector.assert_called_once_with() + get_marginal.assert_called_once_with(clique) From ce6e283f62dca55fb55d22151c9aa9192136a713 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 12:54:50 +0000 Subject: [PATCH 22/50] Implement calculator of pair importance --- src/centhesus/mst.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 24ebcf5..f753e5d 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -1,6 +1,7 @@ """Module for the Maximum Spanning Tree generator.""" import dask +import numpy as np from census21api import CensusAPI from census21api.constants import DIMENSIONS_BY_POPULATION_TYPE as DIMENSIONS from mbi import Domain, FactoredInference @@ -231,3 +232,31 @@ def fit_model(self, measurements, iters=5000): model = engine.estimate(measurements) return model + + def _calculate_importance_of_pair(self, interim, pair): + """ + Determine the importance of a column pair with an interim model. + + Importance is defined as the negative L1 norm between the + observed marginal table for the column pair and that estimated + by our interim model. + + Parameters + ---------- + interim : mbi.GraphicalModel + Interim model based on one-way marginals only. + pair : tuple of str + Column pair to be assessed. + + Returns + ------- + weight : float + Importance of the pair given as the negative of the L1 norm + between the observed and estimated marginals for the pair. + """ + + estimate = interim.project(pair).datavector() + marginal = self.get_marginal(pair) + weight = -np.linalg.norm(marginal - estimate, 1) + + return weight From 49bc5e97e22720265a7740d69b0cc23ae1507da9 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 13:21:37 +0000 Subject: [PATCH 23/50] Write test for failed call in importance checker --- tests/test_mst.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_mst.py b/tests/test_mst.py index 37920b7..5641c54 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -273,3 +273,22 @@ def test_calculate_importance_of_pair(params): interim.project.assert_called_once_with(clique) interim.project.return_value.datavector.assert_called_once_with() get_marginal.assert_called_once_with(clique) + + +@given(st_single_marginals(kind="pair")) +def test_calculate_importance_of_pair_failed_call(params): + """Test that a failed call doesn't stop importance processing.""" + + population_type, area_type, dimensions, clique, _ = params + mst = mocked_mst(population_type, area_type, dimensions) + + interim = mock.MagicMock() + with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal: + get_marginal.return_value = None + weight = mst._calculate_importance_of_pair(interim, clique) + + assert weight is None + + interim.project.assert_not_called() + interim.project.return_value.datavector.assert_not_called() + get_marginal.assert_called_once_with(clique) From 220dee268eb67db8a4e25412ef9b81bda8b246cd Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 14:54:08 +0000 Subject: [PATCH 24/50] Write test for calculating importances Weird indexing problems with Dask... will need to improve the `measure` test(s)... --- tests/strategies.py | 23 +++++++++++++++++++++++ tests/test_mst.py | 27 +++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/tests/strategies.py b/tests/strategies.py index 0aa56ab..673c246 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -1,6 +1,7 @@ """Custom strategies for testing the package.""" import itertools +import math import numpy as np import pandas as pd @@ -10,6 +11,7 @@ POPULATION_TYPES, ) from hypothesis import strategies as st +from mbi import Domain @st.composite @@ -79,3 +81,24 @@ def st_single_marginals(draw, kind=None): marginal["count"] = counts return population_type, area_type, dimensions, clique, marginal + + +@st.composite +def st_importances(draw): + """Create a domain and set of importances for a test.""" + + population_type, area_type, dimensions = draw(st_api_parameters()) + num = len(dimensions) + 1 + + sizes = draw(st.lists(st.integers(2, 10), min_size=num, max_size=num)) + domain = Domain.fromdict(dict(zip((area_type, *dimensions), sizes))) + + importances = draw( + st.lists( + st.floats(max_value=0, allow_infinity=False, allow_nan=False), + min_size=math.comb(num, 2), + max_size=math.comb(num, 2), + ) + ) + + return population_type, area_type, dimensions, domain, importances diff --git a/tests/test_mst.py b/tests/test_mst.py index 5641c54..8de36d6 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -1,5 +1,6 @@ """Tests for the `centhesus.mst` module.""" +import itertools import string from unittest import mock @@ -20,6 +21,7 @@ from .strategies import ( st_api_parameters, st_feature_metadata_parameters, + st_importances, st_single_marginals, ) @@ -292,3 +294,28 @@ def test_calculate_importance_of_pair_failed_call(params): interim.project.assert_not_called() interim.project.return_value.datavector.assert_not_called() get_marginal.assert_called_once_with(clique) + + +@settings(deadline=None) +@given(st_importances()) +def test_calculate_importances(params): + """Test that a set of importances can be calculated.""" + + population_type, area_type, dimensions, domain, importances = params + mst = mocked_mst(population_type, area_type, dimensions, domain=domain) + + with mock.patch("centhesus.mst.MST._calculate_importance_of_pair") as calc: + calc.side_effect = importances + weights = mst._calculate_importances("interim") + + pairs = list(itertools.combinations(domain, 2)) + calc.call_count == len(pairs) + call_args = [call.args for call in calc.call_args_list] + assert set(call_args) == set(("interim", pair) for pair in pairs) + + assert isinstance(weights, dict) + assert set(weights.keys()) == set(pairs) + + pairs_execution_order = [pair for _, pair in call_args] + for pair, importance in zip(pairs_execution_order, importances): + assert weights[pair] == importance From 5e56172428d5b76aee7972744f5a9558bf0b57e8 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 14:55:06 +0000 Subject: [PATCH 25/50] Implement importance calculator; idx task execut'n --- src/centhesus/mst.py | 59 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 7 deletions(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index f753e5d..5faa0a4 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -1,5 +1,7 @@ """Module for the Maximum Spanning Tree generator.""" +import itertools + import dask import numpy as np from census21api import CensusAPI @@ -183,6 +185,8 @@ def measure(self, cliques): If a column pair has been blocked by the API, then their marginal is `None` and we skip over them. + We use `dask` to compute these marginals in parallel. + Parameters ---------- cliques : iterable of tuple @@ -197,14 +201,14 @@ def measure(self, cliques): tasks = [] for clique in cliques: - marginal = dask.delayed(self.get_marginal)(clique) - tasks.append(marginal) + get_marginal = dask.delayed(lambda x: (x, self.get_marginal(x))) + tasks.append(get_marginal(clique)) - marginals = dask.compute(*tasks) + indexed_marginals = dask.compute(*tasks) measurements = [ (sparse.eye(marginal.size), marginal, 1e-12, clique) - for marginal, clique in zip(marginals, cliques) + for clique, marginal in indexed_marginals if marginal is not None ] @@ -250,13 +254,54 @@ def _calculate_importance_of_pair(self, interim, pair): Returns ------- - weight : float + pair : tuple of str + Assessed column pair. + weight : float or None Importance of the pair given as the negative of the L1 norm between the observed and estimated marginals for the pair. + If the API call fails, this is `None`. """ - estimate = interim.project(pair).datavector() + weight = None marginal = self.get_marginal(pair) - weight = -np.linalg.norm(marginal - estimate, 1) + if marginal is not None: + estimate = interim.project(pair).datavector() + weight = -np.linalg.norm(marginal - estimate, 1) return weight + + def _calculate_importances(self, interim): + """ + Determine every column pair's importance given an interim model. + + We use `dask` to compute these importances in parallel. + + Parameters + ---------- + interim : mbi.GraphicalModel + Interim model based on one-way marginals only. + + Returns + ------- + weights : dict + Dictionary mapping column pairs to their weight. If a column + pair is blocked by the API, it is skipped. + """ + + pairs = list(itertools.combinations(self.domain.attrs, 2)) + tasks = [] + for pair in pairs: + calculate_importance = dask.delayed( + lambda x: (x, self._calculate_importance_of_pair(interim, x)) + ) + tasks.append(calculate_importance(pair)) + + indexed_importances = dask.compute(*tasks) + + weights = { + pair: importance + for pair, importance in indexed_importances + if importance is not None + } + + return weights From 28a65423591ad54ef2bf62860dbf54188f1cfb85 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 15:14:04 +0000 Subject: [PATCH 26/50] Write test finding MST --- tests/test_mst.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test_mst.py b/tests/test_mst.py index 8de36d6..5d4a700 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -4,6 +4,7 @@ import string from unittest import mock +import networkx as nx import numpy as np import pandas as pd import pytest @@ -270,7 +271,7 @@ def test_calculate_importance_of_pair(params): weight = mst._calculate_importance_of_pair(interim, clique) assert isinstance(weight, float) - assert weight <= 0 + assert weight >= 0 interim.project.assert_called_once_with(clique) interim.project.return_value.datavector.assert_called_once_with() @@ -319,3 +320,19 @@ def test_calculate_importances(params): pairs_execution_order = [pair for _, pair in call_args] for pair, importance in zip(pairs_execution_order, importances): assert weights[pair] == importance + + +@given(st_importances()) +def test_find_maximum_spanning_tree(params): + """Test an MST can be found from a set of importances.""" + + *api_params, domain, importances = params + mst = mocked_mst(*api_params, domain=domain) + weights = dict(zip(itertools.combinations(domain, 2), importances)) + + tree = mst._find_maximum_spanning_tree(weights) + + assert isinstance(tree, nx.Graph) + assert set(tree.edges).issubset(weights.keys()) + for edge in tree.edges: + assert tree.edges[edge]["weight"] == -weights[edge] From 26aff3f8c648116525db45b2c11c5fd1489c9063 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 15:16:01 +0000 Subject: [PATCH 27/50] Implement method to find MST --- src/centhesus/mst.py | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 5faa0a4..6988eee 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -3,6 +3,7 @@ import itertools import dask +import networkx as nx import numpy as np from census21api import CensusAPI from census21api.constants import DIMENSIONS_BY_POPULATION_TYPE as DIMENSIONS @@ -241,9 +242,9 @@ def _calculate_importance_of_pair(self, interim, pair): """ Determine the importance of a column pair with an interim model. - Importance is defined as the negative L1 norm between the - observed marginal table for the column pair and that estimated - by our interim model. + Importance is defined as the L1 norm between the observed + marginal table for the column pair and that estimated by our + interim model. Parameters ---------- @@ -257,16 +258,16 @@ def _calculate_importance_of_pair(self, interim, pair): pair : tuple of str Assessed column pair. weight : float or None - Importance of the pair given as the negative of the L1 norm - between the observed and estimated marginals for the pair. - If the API call fails, this is `None`. + Importance of the pair given as the L1 norm between the + observed and estimated marginals for the pair. If the API + call fails, this is `None`. """ weight = None marginal = self.get_marginal(pair) if marginal is not None: estimate = interim.project(pair).datavector() - weight = -np.linalg.norm(marginal - estimate, 1) + weight = np.linalg.norm(marginal - estimate, 1) return weight @@ -305,3 +306,30 @@ def _calculate_importances(self, interim): } return weights + + def _find_maximum_spanning_tree(self, weights): + """ + Find the maximum spanning tree given a set of edge importances. + + To find the tree, we use Kruskal's algorithm to find the minimum + spanning tree with negative weights. + + Parameters + ---------- + weights : dict + Dictionary mapping edges (column pairs) to their importance. + + Returns + ------- + tree : nx.Graph + Maximum spanning tree of all column pairs. + """ + + graph = nx.Graph() + graph.add_nodes_from(self.domain) + for edge, weight in weights.items(): + graph.add_edge(*edge, weight=-weight) + + tree = nx.minimum_spanning_tree(graph) + + return tree From 7e1948a7f41c2772a5e609ff79a7177de41f4ce2 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 15:49:47 +0000 Subject: [PATCH 28/50] Write tests for selection method --- tests/strategies.py | 32 +++++++++++++++++++++++++++++--- tests/test_mst.py | 24 ++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/tests/strategies.py b/tests/strategies.py index 673c246..8eb4e2a 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -3,6 +3,7 @@ import itertools import math +import networkx as nx import numpy as np import pandas as pd from census21api.constants import ( @@ -84,15 +85,25 @@ def st_single_marginals(draw, kind=None): @st.composite -def st_importances(draw): - """Create a domain and set of importances for a test.""" +def st_domains(draw): + """Create a domain and its parameters for a test.""" population_type, area_type, dimensions = draw(st_api_parameters()) + num = len(dimensions) + 1 - sizes = draw(st.lists(st.integers(2, 10), min_size=num, max_size=num)) domain = Domain.fromdict(dict(zip((area_type, *dimensions), sizes))) + return population_type, area_type, dimensions, domain + + +@st.composite +def st_importances(draw): + """Create a domain and set of importances for a test.""" + + population_type, area_type, dimensions, domain = draw(st_domains()) + + num = len(domain) importances = draw( st.lists( st.floats(max_value=0, allow_infinity=False, allow_nan=False), @@ -102,3 +113,18 @@ def st_importances(draw): ) return population_type, area_type, dimensions, domain, importances + + +@st.composite +def st_subgraphs(draw): + """Create a subgraph and its parameters for a test.""" + + population_type, area_type, dimensions, domain = draw(st_domains()) + + edges = draw( + st.sets(st.sampled_from(list(itertools.combinations(domain, 2)))) + ) + graph = nx.Graph() + graph.add_edges_from(edges) + + return population_type, area_type, dimensions, domain, graph diff --git a/tests/test_mst.py b/tests/test_mst.py index 5d4a700..96369b4 100644 --- a/tests/test_mst.py +++ b/tests/test_mst.py @@ -24,6 +24,7 @@ st_feature_metadata_parameters, st_importances, st_single_marginals, + st_subgraphs, ) @@ -333,6 +334,29 @@ def test_find_maximum_spanning_tree(params): tree = mst._find_maximum_spanning_tree(weights) assert isinstance(tree, nx.Graph) + assert set(tree.nodes) == set(domain) assert set(tree.edges).issubset(weights.keys()) for edge in tree.edges: assert tree.edges[edge]["weight"] == -weights[edge] + +@given(st_subgraphs()) +def test_select(params): + """Test that a set of two-way cliques can be found correctly.""" + + *api_params, domain, tree = params + mst = mocked_mst(*api_params, domain=domain) + + with mock.patch("centhesus.mst.MST.fit_model") as fit, mock.patch("centhesus.mst.MST._calculate_importances") as calc, mock.patch("centhesus.mst.MST._find_maximum_spanning_tree") as find: + fit.return_value = "interim" + calc.return_value = "weights" + find.return_value = tree + cliques = mst.select("measurements") + + possible_edges = [set(pair) for pair in itertools.combinations(domain, 2)] + assert isinstance(cliques, list) + for clique in cliques: + assert set(clique) in possible_edges + + fit.assert_called_once_with("measurements", iters=1000) + calc.assert_called_once_with("interim") + find.assert_called_once_with("weights") From 3b6f469fdb7c7f3634ea8c7cbb654b025d87efd6 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 15:50:01 +0000 Subject: [PATCH 29/50] Implement selection method --- src/centhesus/mst.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 6988eee..06922c7 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -333,3 +333,40 @@ def _find_maximum_spanning_tree(self, weights): tree = nx.minimum_spanning_tree(graph) return tree + + def select(self, measurements): + """ + Select the most informative two-way cliques. + + To determine how informative a column pair is, we first create + an interim graphical model from all observed one-way marginals. + Then, each column pair's importance is defined as the L1 + difference between its observed two-way marginal and the + estimated marginal from the interim model. + + With all the importances calculated, we model the column pairs + as a weighted graph where columns are nodes and an edge + represents the importance of the column pair at its endpoints. + In this way, the smallest set of the most informative column + pairs is given as the maximum spanning tree of this graph. + + The selected two-way cliques are the edges of this tree. + + Parameters + ---------- + measurements : list of tuple + One-way marginal measurements with which to fit an interim + graphical model. + + Returns + ------- + cliques : list of tuple + Edges of the maximum spanning tree of our weighted graph. + """ + + interim = self.fit_model(measurements, iters=1000) + weights = self._calculate_importances(interim) + tree = self._find_maximum_spanning_tree(weights) + + return list(tree.edges) + \ No newline at end of file From 732dae0cdc7c9f4213d15b594a4aaad5eb9deaac Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 16:07:53 +0000 Subject: [PATCH 30/50] Format code base --- src/centhesus/mst.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 06922c7..312b697 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -369,4 +369,3 @@ def select(self, measurements): tree = self._find_maximum_spanning_tree(weights) return list(tree.edges) - \ No newline at end of file From 7f92d1c14bd3b59c1ab434490c5b8ab3943b7c91 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 16:08:07 +0000 Subject: [PATCH 31/50] Separate out tests --- tests/mst/__init__.py | 1 + tests/mst/test_fit_model.py | 27 +++ tests/mst/test_get_domain.py | 94 +++++++++ tests/mst/test_init.py | 64 +++++++ tests/mst/test_measure.py | 77 ++++++++ tests/mst/test_select.py | 126 ++++++++++++ tests/strategies.py | 17 +- tests/test_mst.py | 362 ----------------------------------- 8 files changed, 405 insertions(+), 363 deletions(-) create mode 100644 tests/mst/__init__.py create mode 100644 tests/mst/test_fit_model.py create mode 100644 tests/mst/test_get_domain.py create mode 100644 tests/mst/test_init.py create mode 100644 tests/mst/test_measure.py create mode 100644 tests/mst/test_select.py delete mode 100644 tests/test_mst.py diff --git a/tests/mst/__init__.py b/tests/mst/__init__.py new file mode 100644 index 0000000..ead63ef --- /dev/null +++ b/tests/mst/__init__.py @@ -0,0 +1 @@ +"""MST-level tests.""" \ No newline at end of file diff --git a/tests/mst/test_fit_model.py b/tests/mst/test_fit_model.py new file mode 100644 index 0000000..fab69e7 --- /dev/null +++ b/tests/mst/test_fit_model.py @@ -0,0 +1,27 @@ +"""Unit test(s) for the model fitting in `centhesus.MST`.""" + +from hypothesis import given +from hypothesis import strategies as st +from mbi import Domain, GraphicalModel +from scipy import sparse + +from ..strategies import mocked_mst, st_single_marginals + + +@given(st_single_marginals(), st.integers(1, 5)) +def test_fit_model(params, iters): + """Test that a model can be fitted to some measurements.""" + + population_type, area_type, dimensions, clique, table = params + domain = Domain.fromdict(table.drop("count", axis=1).nunique().to_dict()) + table = table["count"] + mst = mocked_mst(population_type, area_type, dimensions, domain=domain) + + measurements = [(sparse.eye(table.size), table, 1e-12, clique)] + model = mst.fit_model(measurements, iters) + + assert isinstance(model, GraphicalModel) + assert model.domain == mst.domain + assert model.cliques == [clique] + assert model.elimination_order == list(clique) + assert model.total == table.sum() or 1 diff --git a/tests/mst/test_get_domain.py b/tests/mst/test_get_domain.py new file mode 100644 index 0000000..ad35bf3 --- /dev/null +++ b/tests/mst/test_get_domain.py @@ -0,0 +1,94 @@ +"""Unit tests for getting domain in `centhesus.MST`.""" + +import string +from unittest import mock + +import pytest +from hypothesis import given +from hypothesis import strategies as st +from mbi import Domain + +from ..strategies import ( + mocked_mst, + st_api_parameters, + st_feature_metadata_parameters, +) + + +@given(st_feature_metadata_parameters()) +def test_get_domain_of_feature(params): + """Test the domain of a feature can be retrieved correctly.""" + + population_type, area_type, dimensions, feature, metadata = params + + mst = mocked_mst(population_type, area_type, dimensions) + + with mock.patch("centhesus.mst.CensusAPI.query_feature") as query: + query.return_value = metadata + domain = mst._get_domain_of_feature(feature) + + assert isinstance(domain, dict) + + items = [area_type] if feature == "area-types" else dimensions + assert list(domain.keys()) == metadata["id"].to_list() == items + assert list(domain.values()) == metadata["total_count"].to_list() + + query.assert_called_once_with(population_type, feature, *items) + + +@given(st_feature_metadata_parameters()) +def test_get_domain_of_feature_none_area_type(params): + """Test the feature domain getter when area type is None.""" + + population_type, _, dimensions, _, metadata = params + + mst = mocked_mst(population_type, None, dimensions) + + with mock.patch("centhesus.mst.CensusAPI.query_feature") as query: + domain = mst._get_domain_of_feature("area-types") + + assert isinstance(domain, dict) + assert domain == {} + + query.assert_not_called() + + +@given(st_api_parameters(), st.text()) +def test_get_domain_of_feature_raises_error(params, feature): + """Test the domain getter raises an error for invalid features.""" + + mst = mocked_mst(*params) + + with pytest.raises(ValueError, match="^Feature"): + mst._get_domain_of_feature(feature) + + +@given( + st_api_parameters(), + st.dictionaries( + st.text(string.ascii_uppercase, min_size=1), st.integers() + ), + st.dictionaries( + st.text(string.ascii_lowercase, min_size=1), st.integers() + ), +) +def test_get_domain(params, area_type_domain, dimensions_domain): + """Test the domain getter can process metadata correctly.""" + + mst = mocked_mst(*params) + + with mock.patch("centhesus.mst.MST._get_domain_of_feature") as feature: + feature.side_effect = [area_type_domain, dimensions_domain] + domain = mst.get_domain() + + assert isinstance(domain, Domain) + assert domain.attrs == ( + *area_type_domain.keys(), + *dimensions_domain.keys(), + ) + + assert feature.call_count == 2 + assert [call.args for call in feature.call_args_list] == [ + ("area-types",), + ("dimensions",), + ] diff --git a/tests/mst/test_init.py b/tests/mst/test_init.py new file mode 100644 index 0000000..e5f2ec2 --- /dev/null +++ b/tests/mst/test_init.py @@ -0,0 +1,64 @@ +"""Tests for the `centhesus.mst` module.""" + + +from census21api import CensusAPI +from census21api.constants import DIMENSIONS_BY_POPULATION_TYPE +from hypothesis import given + +from centhesus import MST + +from ..strategies import ( + mocked_mst, + st_api_parameters, +) + + +@given(st_api_parameters()) +def test_init(params): + """Test instantiation of the MST class.""" + + population_type, area_type, dimensions = params + + mst = mocked_mst(population_type, area_type, dimensions) + + assert isinstance(mst, MST) + assert mst.population_type == population_type + assert mst.area_type == area_type + assert mst.dimensions == dimensions + + assert isinstance(mst.api, CensusAPI) + assert mst.domain is None + + +@given(st_api_parameters()) +def test_init_none_area_type(params): + """Test instantiation of the MST class when area type is None.""" + + population_type, _, dimensions = params + + mst = mocked_mst(population_type, None, dimensions) + + assert isinstance(mst, MST) + assert mst.population_type == population_type + assert mst.area_type is None + assert mst.dimensions == dimensions + + assert isinstance(mst.api, CensusAPI) + assert mst.domain is None + + +@given(st_api_parameters()) +def test_init_none_dimensions(params): + """Test instantiation of the MST class when dimensions is None.""" + + population_type, area_type, _ = params + + mst = mocked_mst(population_type, area_type, None) + + assert isinstance(mst, MST) + assert mst.population_type == population_type + assert mst.area_type == area_type + assert mst.dimensions == DIMENSIONS_BY_POPULATION_TYPE[population_type] + + assert isinstance(mst.api, CensusAPI) + assert mst.domain is None diff --git a/tests/mst/test_measure.py b/tests/mst/test_measure.py new file mode 100644 index 0000000..fb3fae9 --- /dev/null +++ b/tests/mst/test_measure.py @@ -0,0 +1,77 @@ +"""Unit tests for the measurement methods in `centhesus.MST`.""" + +from unittest import mock + +import numpy as np +import pandas as pd +from hypothesis import given, settings +from hypothesis import strategies as st +from scipy import sparse + +from ..strategies import mocked_mst, st_single_marginals + + +@given(st_single_marginals(), st.booleans()) +def test_get_marginal(params, flatten): + """Test that a marginal table can be processed correctly.""" + + population_type, area_type, dimensions, clique, table = params + mst = mocked_mst(population_type, area_type, dimensions) + + with mock.patch("centhesus.mst.CensusAPI.query_table") as query: + query.return_value = table + marginal = mst.get_marginal(clique, flatten) + + if flatten: + assert isinstance(marginal, np.ndarray) + assert (marginal == table["count"]).all() + else: + assert isinstance(marginal, pd.Series) + assert marginal.name == "count" + assert (marginal.reset_index() == table).all().all() + + query.assert_called_once() + + +@given(st_single_marginals(), st.booleans()) +def test_get_marginal_failed_call(params, flatten): + """Test that a failed call can be processed still.""" + + population_type, area_type, dimensions, clique, _ = params + mst = mocked_mst(population_type, area_type, dimensions) + + with mock.patch("centhesus.mst.CensusAPI.query_table") as query: + query.return_value = None + marginal = mst.get_marginal(clique, flatten) + + assert marginal is None + + query.assert_called_once() + + +@settings(deadline=None) +@given(st_single_marginals(), st.integers(1, 5)) +def test_measure(params, num_cliques): + """Test a set of cliques can be measured.""" + + population_type, area_type, dimensions, clique, table = params + mst = mocked_mst(population_type, area_type, dimensions) + + with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal: + get_marginal.return_value = table + measurements = mst.measure([clique] * num_cliques) + + assert isinstance(measurements, list) + assert len(measurements) == num_cliques + + for measurement in measurements: + assert isinstance(measurement, tuple) + assert len(measurement) == 4 + + ident, marg, noise, cliq = measurement + assert isinstance(ident, sparse._dia.dia_matrix) + assert ident.shape == (marg.size,) * 2 + assert ident.sum() == marg.size + assert marg.equals(table) + assert noise == 1e-12 + assert cliq == clique diff --git a/tests/mst/test_select.py b/tests/mst/test_select.py new file mode 100644 index 0000000..c5ffbac --- /dev/null +++ b/tests/mst/test_select.py @@ -0,0 +1,126 @@ +"""Unit tests for the selection methods in `centhesus.MST`.""" + +import itertools +from unittest import mock + +import networkx as nx +from hypothesis import given, settings + +from ..strategies import ( + mocked_mst, + st_importances, + st_single_marginals, + st_subgraphs, +) + + +@given(st_single_marginals(kind="pair")) +def test_calculate_importance_of_pair(params): + """Test the importance of a column pair can be calculated.""" + + population_type, area_type, dimensions, clique, table = params + table = table["count"] + mst = mocked_mst(population_type, area_type, dimensions) + + interim = mock.MagicMock() + interim.project.return_value.datavector.return_value = table.sample( + frac=1.0 + ).reset_index(drop=True) + with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal: + get_marginal.return_value = table + weight = mst._calculate_importance_of_pair(interim, clique) + + assert isinstance(weight, float) + assert weight >= 0 + + interim.project.assert_called_once_with(clique) + interim.project.return_value.datavector.assert_called_once_with() + get_marginal.assert_called_once_with(clique) + + +@given(st_single_marginals(kind="pair")) +def test_calculate_importance_of_pair_failed_call(params): + """Test that a failed call doesn't stop importance processing.""" + + population_type, area_type, dimensions, clique, _ = params + mst = mocked_mst(population_type, area_type, dimensions) + + interim = mock.MagicMock() + with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal: + get_marginal.return_value = None + weight = mst._calculate_importance_of_pair(interim, clique) + + assert weight is None + + interim.project.assert_not_called() + interim.project.return_value.datavector.assert_not_called() + get_marginal.assert_called_once_with(clique) + + +@settings(deadline=None) +@given(st_importances()) +def test_calculate_importances(params): + """Test that a set of importances can be calculated.""" + + population_type, area_type, dimensions, domain, importances = params + mst = mocked_mst(population_type, area_type, dimensions, domain=domain) + + with mock.patch("centhesus.mst.MST._calculate_importance_of_pair") as calc: + calc.side_effect = importances + weights = mst._calculate_importances("interim") + + pairs = list(itertools.combinations(domain, 2)) + calc.call_count == len(pairs) + call_args = [call.args for call in calc.call_args_list] + assert set(call_args) == set(("interim", pair) for pair in pairs) + + assert isinstance(weights, dict) + assert set(weights.keys()) == set(pairs) + + pairs_execution_order = [pair for _, pair in call_args] + for pair, importance in zip(pairs_execution_order, importances): + assert weights[pair] == importance + + +@given(st_importances()) +def test_find_maximum_spanning_tree(params): + """Test an MST can be found from a set of importances.""" + + *api_params, domain, importances = params + mst = mocked_mst(*api_params, domain=domain) + weights = dict(zip(itertools.combinations(domain, 2), importances)) + + tree = mst._find_maximum_spanning_tree(weights) + + assert isinstance(tree, nx.Graph) + assert set(tree.nodes) == set(domain) + assert set(tree.edges).issubset(weights.keys()) + for edge in tree.edges: + assert tree.edges[edge]["weight"] == -weights[edge] + + +@given(st_subgraphs()) +def test_select(params): + """Test that a set of two-way cliques can be found correctly.""" + + *api_params, domain, tree = params + mst = mocked_mst(*api_params, domain=domain) + + with mock.patch("centhesus.mst.MST.fit_model") as fit, mock.patch( + "centhesus.mst.MST._calculate_importances" + ) as calc, mock.patch( + "centhesus.mst.MST._find_maximum_spanning_tree" + ) as find: + fit.return_value = "interim" + calc.return_value = "weights" + find.return_value = tree + cliques = mst.select("measurements") + + possible_edges = [set(pair) for pair in itertools.combinations(domain, 2)] + assert isinstance(cliques, list) + for clique in cliques: + assert set(clique) in possible_edges + + fit.assert_called_once_with("measurements", iters=1000) + calc.assert_called_once_with("interim") + find.assert_called_once_with("weights") diff --git a/tests/strategies.py b/tests/strategies.py index 8eb4e2a..17a54cc 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -2,6 +2,7 @@ import itertools import math +from unittest import mock import networkx as nx import numpy as np @@ -14,6 +15,20 @@ from hypothesis import strategies as st from mbi import Domain +from centhesus import MST + + +def mocked_mst(population_type, area_type, dimensions, domain=None): + """Create an instance of MST with mocked `get_domain`.""" + + with mock.patch("centhesus.mst.MST.get_domain") as get_domain: + get_domain.return_value = domain + mst = MST(population_type, area_type, dimensions) + + get_domain.assert_called_once_with() + + return mst + @st.composite def st_api_parameters(draw): @@ -89,7 +104,7 @@ def st_domains(draw): """Create a domain and its parameters for a test.""" population_type, area_type, dimensions = draw(st_api_parameters()) - + num = len(dimensions) + 1 sizes = draw(st.lists(st.integers(2, 10), min_size=num, max_size=num)) domain = Domain.fromdict(dict(zip((area_type, *dimensions), sizes))) diff --git a/tests/test_mst.py b/tests/test_mst.py deleted file mode 100644 index 96369b4..0000000 --- a/tests/test_mst.py +++ /dev/null @@ -1,362 +0,0 @@ -"""Tests for the `centhesus.mst` module.""" - -import itertools -import string -from unittest import mock - -import networkx as nx -import numpy as np -import pandas as pd -import pytest -from census21api import CensusAPI -from census21api.constants import ( - DIMENSIONS_BY_POPULATION_TYPE, -) -from hypothesis import given, settings -from hypothesis import strategies as st -from mbi import Domain, GraphicalModel -from scipy import sparse - -from centhesus import MST - -from .strategies import ( - st_api_parameters, - st_feature_metadata_parameters, - st_importances, - st_single_marginals, - st_subgraphs, -) - - -def mocked_mst(population_type, area_type, dimensions, domain=None): - """Create an instance of MST with mocked `get_domain`.""" - - with mock.patch("centhesus.mst.MST.get_domain") as get_domain: - get_domain.return_value = domain - mst = MST(population_type, area_type, dimensions) - - get_domain.assert_called_once_with() - - return mst - - -@given(st_api_parameters()) -def test_init(params): - """Test instantiation of the MST class.""" - - population_type, area_type, dimensions = params - - mst = mocked_mst(population_type, area_type, dimensions) - - assert isinstance(mst, MST) - assert mst.population_type == population_type - assert mst.area_type == area_type - assert mst.dimensions == dimensions - - assert isinstance(mst.api, CensusAPI) - assert mst.domain is None - - -@given(st_api_parameters()) -def test_init_none_area_type(params): - """Test instantiation of the MST class when area type is None.""" - - population_type, _, dimensions = params - - mst = mocked_mst(population_type, None, dimensions) - - assert isinstance(mst, MST) - assert mst.population_type == population_type - assert mst.area_type is None - assert mst.dimensions == dimensions - - assert isinstance(mst.api, CensusAPI) - assert mst.domain is None - - -@given(st_api_parameters()) -def test_init_none_dimensions(params): - """Test instantiation of the MST class when dimensions is None.""" - - population_type, area_type, _ = params - - mst = mocked_mst(population_type, area_type, None) - - assert isinstance(mst, MST) - assert mst.population_type == population_type - assert mst.area_type == area_type - assert mst.dimensions == DIMENSIONS_BY_POPULATION_TYPE[population_type] - - assert isinstance(mst.api, CensusAPI) - assert mst.domain is None - - -@given(st_feature_metadata_parameters()) -def test_get_domain_of_feature(params): - """Test the domain of a feature can be retrieved correctly.""" - - population_type, area_type, dimensions, feature, metadata = params - - mst = mocked_mst(population_type, area_type, dimensions) - - with mock.patch("centhesus.mst.CensusAPI.query_feature") as query: - query.return_value = metadata - domain = mst._get_domain_of_feature(feature) - - assert isinstance(domain, dict) - - items = [area_type] if feature == "area-types" else dimensions - assert list(domain.keys()) == metadata["id"].to_list() == items - assert list(domain.values()) == metadata["total_count"].to_list() - - query.assert_called_once_with(population_type, feature, *items) - - -@given(st_feature_metadata_parameters()) -def test_get_domain_of_feature_none_area_type(params): - """Test the feature domain getter when area type is None.""" - - population_type, _, dimensions, _, metadata = params - - mst = mocked_mst(population_type, None, dimensions) - - with mock.patch("centhesus.mst.CensusAPI.query_feature") as query: - domain = mst._get_domain_of_feature("area-types") - - assert isinstance(domain, dict) - assert domain == {} - - query.assert_not_called() - - -@given(st_api_parameters(), st.text()) -def test_get_domain_of_feature_raises_error(params, feature): - """Test the domain getter raises an error for invalid features.""" - - mst = mocked_mst(*params) - - with pytest.raises(ValueError, match="^Feature"): - mst._get_domain_of_feature(feature) - - -@given( - st_api_parameters(), - st.dictionaries( - st.text(string.ascii_uppercase, min_size=1), st.integers() - ), - st.dictionaries( - st.text(string.ascii_lowercase, min_size=1), st.integers() - ), -) -def test_get_domain(params, area_type_domain, dimensions_domain): - """Test the domain getter can process metadata correctly.""" - - mst = mocked_mst(*params) - - with mock.patch("centhesus.mst.MST._get_domain_of_feature") as feature: - feature.side_effect = [area_type_domain, dimensions_domain] - domain = mst.get_domain() - - assert isinstance(domain, Domain) - assert domain.attrs == ( - *area_type_domain.keys(), - *dimensions_domain.keys(), - ) - - assert feature.call_count == 2 - assert [call.args for call in feature.call_args_list] == [ - ("area-types",), - ("dimensions",), - ] - - -@given(st_single_marginals(), st.booleans()) -def test_get_marginal(params, flatten): - """Test that a marginal table can be processed correctly.""" - - population_type, area_type, dimensions, clique, table = params - mst = mocked_mst(population_type, area_type, dimensions) - - with mock.patch("centhesus.mst.CensusAPI.query_table") as query: - query.return_value = table - marginal = mst.get_marginal(clique, flatten) - - if flatten: - assert isinstance(marginal, np.ndarray) - assert (marginal == table["count"]).all() - else: - assert isinstance(marginal, pd.Series) - assert marginal.name == "count" - assert (marginal.reset_index() == table).all().all() - - query.assert_called_once() - - -@given(st_single_marginals(), st.booleans()) -def test_get_marginal_failed_call(params, flatten): - """Test that a failed call can be processed still.""" - - population_type, area_type, dimensions, clique, _ = params - mst = mocked_mst(population_type, area_type, dimensions) - - with mock.patch("centhesus.mst.CensusAPI.query_table") as query: - query.return_value = None - marginal = mst.get_marginal(clique, flatten) - - assert marginal is None - - query.assert_called_once() - - -@settings(deadline=None) -@given(st_single_marginals(), st.integers(1, 5)) -def test_measure(params, num_cliques): - """Test a set of cliques can be measured.""" - - population_type, area_type, dimensions, clique, table = params - mst = mocked_mst(population_type, area_type, dimensions) - - with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal: - get_marginal.return_value = table - measurements = mst.measure([clique] * num_cliques) - - assert isinstance(measurements, list) - assert len(measurements) == num_cliques - - for measurement in measurements: - assert isinstance(measurement, tuple) - assert len(measurement) == 4 - - ident, marg, noise, cliq = measurement - assert isinstance(ident, sparse._dia.dia_matrix) - assert ident.shape == (marg.size,) * 2 - assert ident.sum() == marg.size - assert marg.equals(table) - assert noise == 1e-12 - assert cliq == clique - - -@given(st_single_marginals(), st.integers(1, 5)) -def test_fit_model(params, iters): - """Test that a model can be fitted to some measurements.""" - - population_type, area_type, dimensions, clique, table = params - domain = Domain.fromdict(table.drop("count", axis=1).nunique().to_dict()) - table = table["count"] - mst = mocked_mst(population_type, area_type, dimensions, domain=domain) - - measurements = [(sparse.eye(table.size), table, 1e-12, clique)] - model = mst.fit_model(measurements, iters) - - assert isinstance(model, GraphicalModel) - assert model.domain == mst.domain - assert model.cliques == [clique] - assert model.elimination_order == list(clique) - assert model.total == table.sum() or 1 - - -@given(st_single_marginals(kind="pair")) -def test_calculate_importance_of_pair(params): - """Test the importance of a column pair can be calculated.""" - - population_type, area_type, dimensions, clique, table = params - table = table["count"] - mst = mocked_mst(population_type, area_type, dimensions) - - interim = mock.MagicMock() - interim.project.return_value.datavector.return_value = table.sample( - frac=1.0 - ).reset_index(drop=True) - with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal: - get_marginal.return_value = table - weight = mst._calculate_importance_of_pair(interim, clique) - - assert isinstance(weight, float) - assert weight >= 0 - - interim.project.assert_called_once_with(clique) - interim.project.return_value.datavector.assert_called_once_with() - get_marginal.assert_called_once_with(clique) - - -@given(st_single_marginals(kind="pair")) -def test_calculate_importance_of_pair_failed_call(params): - """Test that a failed call doesn't stop importance processing.""" - - population_type, area_type, dimensions, clique, _ = params - mst = mocked_mst(population_type, area_type, dimensions) - - interim = mock.MagicMock() - with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal: - get_marginal.return_value = None - weight = mst._calculate_importance_of_pair(interim, clique) - - assert weight is None - - interim.project.assert_not_called() - interim.project.return_value.datavector.assert_not_called() - get_marginal.assert_called_once_with(clique) - - -@settings(deadline=None) -@given(st_importances()) -def test_calculate_importances(params): - """Test that a set of importances can be calculated.""" - - population_type, area_type, dimensions, domain, importances = params - mst = mocked_mst(population_type, area_type, dimensions, domain=domain) - - with mock.patch("centhesus.mst.MST._calculate_importance_of_pair") as calc: - calc.side_effect = importances - weights = mst._calculate_importances("interim") - - pairs = list(itertools.combinations(domain, 2)) - calc.call_count == len(pairs) - call_args = [call.args for call in calc.call_args_list] - assert set(call_args) == set(("interim", pair) for pair in pairs) - - assert isinstance(weights, dict) - assert set(weights.keys()) == set(pairs) - - pairs_execution_order = [pair for _, pair in call_args] - for pair, importance in zip(pairs_execution_order, importances): - assert weights[pair] == importance - - -@given(st_importances()) -def test_find_maximum_spanning_tree(params): - """Test an MST can be found from a set of importances.""" - - *api_params, domain, importances = params - mst = mocked_mst(*api_params, domain=domain) - weights = dict(zip(itertools.combinations(domain, 2), importances)) - - tree = mst._find_maximum_spanning_tree(weights) - - assert isinstance(tree, nx.Graph) - assert set(tree.nodes) == set(domain) - assert set(tree.edges).issubset(weights.keys()) - for edge in tree.edges: - assert tree.edges[edge]["weight"] == -weights[edge] - -@given(st_subgraphs()) -def test_select(params): - """Test that a set of two-way cliques can be found correctly.""" - - *api_params, domain, tree = params - mst = mocked_mst(*api_params, domain=domain) - - with mock.patch("centhesus.mst.MST.fit_model") as fit, mock.patch("centhesus.mst.MST._calculate_importances") as calc, mock.patch("centhesus.mst.MST._find_maximum_spanning_tree") as find: - fit.return_value = "interim" - calc.return_value = "weights" - find.return_value = tree - cliques = mst.select("measurements") - - possible_edges = [set(pair) for pair in itertools.combinations(domain, 2)] - assert isinstance(cliques, list) - for clique in cliques: - assert set(clique) in possible_edges - - fit.assert_called_once_with("measurements", iters=1000) - calc.assert_called_once_with("interim") - find.assert_called_once_with("weights") From 05fc4664ae3f8a3b20fb07dfc83db5b05604899b Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Wed, 15 Nov 2023 16:10:17 +0000 Subject: [PATCH 32/50] Format tests --- tests/mst/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mst/__init__.py b/tests/mst/__init__.py index ead63ef..ca47d11 100644 --- a/tests/mst/__init__.py +++ b/tests/mst/__init__.py @@ -1 +1 @@ -"""MST-level tests.""" \ No newline at end of file +"""MST-level tests.""" From 8694b4b8f59d19b716a7ed3338e89c12a4e739f5 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Thu, 16 Nov 2023 11:26:34 +0000 Subject: [PATCH 33/50] Write test for column synthesiser --- tests/mst/test_generate.py | 42 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/mst/test_generate.py diff --git a/tests/mst/test_generate.py b/tests/mst/test_generate.py new file mode 100644 index 0000000..5856b5f --- /dev/null +++ b/tests/mst/test_generate.py @@ -0,0 +1,42 @@ +"""Unit tests for the generation methods of `centhesus.MST`.""" + +import dask +import dask.array as da +import numpy as np +from hypothesis import assume, given, settings +from hypothesis import strategies as st +from hypothesis.extra.numpy import arrays + +from centhesus import MST + + +@settings(deadline=None) +@given( + arrays( + float, + st.integers(2, 10), + elements=st.one_of((st.just(0), st.floats(1, 50))), + ), + st.integers(10, 100), +) +def test_synthesise_column(marginal, total): + """Test a column can be synthesised from a marginal.""" + + assume(marginal.sum()) + + prng = da.random.default_rng(0) + column = MST._synthesise_column(marginal, total, prng) + + assert isinstance(column, da.Array) + assert column.shape == (total,) + assert column.dtype == int + + uniques, counts = dask.compute(*da.unique(column, return_counts=True)) + if len(uniques) == marginal.size: + assert np.array_equal(uniques, np.arange(marginal.size)) + assert np.all(counts - marginal * total / marginal.sum() <= 1) + else: + assert set(uniques).issubset(range(marginal.size)) + assert np.all( + counts - marginal[uniques] * total / marginal[uniques].sum() <= 1 + ) From 85307beae888625309523d489e5f0271fd0f42b9 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Thu, 16 Nov 2023 11:39:35 +0000 Subject: [PATCH 34/50] Implement column synthesiser --- src/centhesus/mst.py | 60 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 312b697..f4b87b8 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -3,6 +3,7 @@ import itertools import dask +import dask.array as da import networkx as nx import numpy as np from census21api import CensusAPI @@ -369,3 +370,62 @@ def select(self, measurements): tree = self._find_maximum_spanning_tree(weights) return list(tree.edges) + + @staticmethod + def _synthesise_column(marginal, total, prng): + """ + Sample a column of given length based on a marginal. + + Columns are synthesised to match the distribution of the + marginal very closely. The process for synthesising the column + is as follows: + + 1. Normalise the marginal against the total, and then separate + its integer and fractional components. + 2. If there are insufficient integer counts, distribute the + additional elements among the integer counts randomly + using the fractional component as a weight. In this way, the + difference between the normalised marginal and the final + counts in the synthetic data will be at most one. + 3. Create an array by repeating the index of the marginal + according to the adjusted integer counts. + 4. Permute the array to give a synthetic column. + + Parameters + ---------- + marginal : np.ndarray + Marginal counts from which to synthesise the column. + total : int + Number of elements in the synthesised column. + prng : dask.array.random.Generator + Pseudo-random number generator. We use this to distribute + additional elements in the synthetic column, and to shuffle + its elements after creation. + + Returns + ------- + column : dask.array.Array + Synthetic column closely matching the distribution of the + marginal. + """ + + marginal = marginal.copy() + marginal *= total / marginal.sum() + fractions, integers = np.modf(marginal) + + integers = integers.astype(int) + extra = total - integers.sum() + if extra > 0: + idx = prng.choice( + marginal.size, extra, False, fractions / fractions.sum() + ).compute() + integers[idx] += 1 + + uniques = da.arange(marginal.size) + repeats = ( + da.repeat(uniques[i : i + 1], count) + for i, count in enumerate(integers) + ) + values = da.concatenate(repeats) + + return prng.permutation(values) From d7283222aa90335c6d71e978d8a9cd5103839ed8 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Thu, 16 Nov 2023 17:26:35 +0000 Subject: [PATCH 35/50] Write tests for dependent column synthesiser --- tests/mst/test_generate.py | 39 +++++++++++++++++++++++++++++++++++--- tests/strategies.py | 19 +++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/tests/mst/test_generate.py b/tests/mst/test_generate.py index 5856b5f..a811592 100644 --- a/tests/mst/test_generate.py +++ b/tests/mst/test_generate.py @@ -1,14 +1,20 @@ """Unit tests for the generation methods of `centhesus.MST`.""" +from unittest import mock + import dask import dask.array as da +import dask.dataframe as dd import numpy as np +import pandas as pd from hypothesis import assume, given, settings from hypothesis import strategies as st from hypothesis.extra.numpy import arrays from centhesus import MST +from ..strategies import st_group_marginals + @settings(deadline=None) @given( @@ -27,11 +33,13 @@ def test_synthesise_column(marginal, total): prng = da.random.default_rng(0) column = MST._synthesise_column(marginal, total, prng) - assert isinstance(column, da.Array) - assert column.shape == (total,) + assert isinstance(column, dd.Series) + assert dask.compute(*column.shape) == (total,) assert column.dtype == int - uniques, counts = dask.compute(*da.unique(column, return_counts=True)) + uniques, counts = dask.compute( + *da.unique(column.to_dask_array(lengths=True), return_counts=True) + ) if len(uniques) == marginal.size: assert np.array_equal(uniques, np.arange(marginal.size)) assert np.all(counts - marginal * total / marginal.sum() <= 1) @@ -40,3 +48,28 @@ def test_synthesise_column(marginal, total): assert np.all( counts - marginal[uniques] * total / marginal[uniques].sum() <= 1 ) + + +@given(st_group_marginals()) +def test_synthesise_group(params): + """Test that a dependent column can be synthesised in groups.""" + + group, marginal = params + + column, prng = "foo", da.random.default_rng(0) + with mock.patch("centhesus.mst.MST._synthesise_column") as synth: + synth.return_value.compute.return_value = marginal + synthetic = ( + group.copy() + .groupby("a") + .apply(MST._synthesise_column_in_group, column, [[]], prng) + ) + + assert isinstance(synthetic, pd.DataFrame) + assert synthetic.shape[0] == group.shape[0] + assert synthetic.columns.to_list() == [*group.columns.to_list(), column] + + assert np.array_equal(synthetic[column], marginal) + + synth.assert_called_once_with([], group.shape[0], prng, 1e6) + synth.return_value.compute.called_once_with() diff --git a/tests/strategies.py b/tests/strategies.py index 17a54cc..e3f5e9e 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -143,3 +143,22 @@ def st_subgraphs(draw): graph.add_edges_from(edges) return population_type, area_type, dimensions, domain, graph + + +@st.composite +def st_group_marginals(draw): + """Create a group and matching marginal for a test.""" + + num_rows_in_group = draw(st.integers(10, 50)) + group = pd.DataFrame({"a": [0] * num_rows_in_group}) + object.__setattr__(group, "name", 0) + + marginal = draw( + st.lists( + st.integers(0, 3), + min_size=num_rows_in_group, + max_size=num_rows_in_group, + ) + ) + + return group, marginal From 5e048f14b57a637dfef33358924ab888c9901427 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Thu, 16 Nov 2023 17:26:53 +0000 Subject: [PATCH 36/50] Implement column synthesisers. --- src/centhesus/mst.py | 61 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index f4b87b8..d814565 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -4,6 +4,7 @@ import dask import dask.array as da +import dask.dataframe as dd import networkx as nx import numpy as np from census21api import CensusAPI @@ -372,7 +373,7 @@ def select(self, measurements): return list(tree.edges) @staticmethod - def _synthesise_column(marginal, total, prng): + def _synthesise_column(marginal, total, prng, chunksize=1e6): """ Sample a column of given length based on a marginal. @@ -380,7 +381,7 @@ def _synthesise_column(marginal, total, prng): marginal very closely. The process for synthesising the column is as follows: - 1. Normalise the marginal against the total, and then separate + 1. Scale the marginal against the total, and then separate its integer and fractional components. 2. If there are insufficient integer counts, distribute the additional elements among the integer counts randomly @@ -401,15 +402,16 @@ def _synthesise_column(marginal, total, prng): Pseudo-random number generator. We use this to distribute additional elements in the synthetic column, and to shuffle its elements after creation. + chunksize : int or float + Target size of a chunk or partition in the column. Returns ------- - column : dask.array.Array + column : dask.dataframe.Series Synthetic column closely matching the distribution of the marginal. """ - marginal = marginal.copy() marginal *= total / marginal.sum() fractions, integers = np.modf(marginal) @@ -421,11 +423,52 @@ def _synthesise_column(marginal, total, prng): ).compute() integers[idx] += 1 - uniques = da.arange(marginal.size) + uniques = np.arange(integers.size) repeats = ( - da.repeat(uniques[i : i + 1], count) - for i, count in enumerate(integers) + unique * da.ones(shape=count, dtype=int) + for unique, count in zip(uniques, integers) ) - values = da.concatenate(repeats) - return prng.permutation(values) + values = da.concatenate(repeats).rechunk(chunksize) + column = dd.from_dask_array(prng.permutation(values)) + + return column + + @staticmethod + def _synthesise_column_in_group( + group, column, marginal, prng, chunksize=1e6 + ): + """ + Synthesise a column inside a group-by operation. + + This operation is used for synthesising columns that depend on + those that have already been synthesised. By performing this + synthesis in a group-by operation, we ensure a close matching to + the marginal distribution estimated by the graphical model given + what has already been synthesised. + + Parameters + ---------- + group : pandas.DataFrame + Group data frame on which to operate. + column : str + Name of column to be synthesised. + marginal : np.ndarray + Marginal estimated from the graphical model for the column + and all the columns it depends on. + prng : dask.array.random.Generator + Pseudo-random number generator. Used to synthesise the + column within this group. + + Returns + ------- + group : + Group with new synthetic column. + """ + + idx = group.name + group[column] = MST._synthesise_column( + marginal[idx], group.shape[0], prng, chunksize + ).compute() + + return group From 7813d538e0d07ef92d193671d8a317b3c4a573f8 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Thu, 16 Nov 2023 17:27:39 +0000 Subject: [PATCH 37/50] Fix group/marginal strategies --- tests/strategies.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/strategies.py b/tests/strategies.py index e3f5e9e..1a8c9a0 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -150,8 +150,7 @@ def st_group_marginals(draw): """Create a group and matching marginal for a test.""" num_rows_in_group = draw(st.integers(10, 50)) - group = pd.DataFrame({"a": [0] * num_rows_in_group}) - object.__setattr__(group, "name", 0) + group = pd.DataFrame({"a": [0] * num_rows_in_group}).groupby("a") marginal = draw( st.lists( From 3c432dbd739bef1b62d12b8363ada903a2fb74b6 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Thu, 16 Nov 2023 17:44:22 +0000 Subject: [PATCH 38/50] Fix dependent column test and strategy (for real) --- tests/mst/test_generate.py | 6 +++++- tests/strategies.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/mst/test_generate.py b/tests/mst/test_generate.py index a811592..759bf8e 100644 --- a/tests/mst/test_generate.py +++ b/tests/mst/test_generate.py @@ -52,7 +52,11 @@ def test_synthesise_column(marginal, total): @given(st_group_marginals()) def test_synthesise_group(params): - """Test that a dependent column can be synthesised in groups.""" + """ + Test that a dependent column can be synthesised in groups. + + We only test the case where there is a single group currently. + """ group, marginal = params diff --git a/tests/strategies.py b/tests/strategies.py index 1a8c9a0..f59b170 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -150,7 +150,7 @@ def st_group_marginals(draw): """Create a group and matching marginal for a test.""" num_rows_in_group = draw(st.integers(10, 50)) - group = pd.DataFrame({"a": [0] * num_rows_in_group}).groupby("a") + group = pd.DataFrame({"a": [0] * num_rows_in_group}) marginal = draw( st.lists( From 051dc7ff401088d43890e5a09468e0ce3538860d Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Fri, 17 Nov 2023 11:16:45 +0000 Subject: [PATCH 39/50] Improve dependent column test --- tests/mst/test_generate.py | 39 ++++++++++++++++++++++---------------- tests/strategies.py | 13 ++++++++----- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/tests/mst/test_generate.py b/tests/mst/test_generate.py index 759bf8e..6567d87 100644 --- a/tests/mst/test_generate.py +++ b/tests/mst/test_generate.py @@ -13,7 +13,7 @@ from centhesus import MST -from ..strategies import st_group_marginals +from ..strategies import st_existing_new_columns @settings(deadline=None) @@ -50,30 +50,37 @@ def test_synthesise_column(marginal, total): ) -@given(st_group_marginals()) +@given(st_existing_new_columns()) def test_synthesise_group(params): - """ - Test that a dependent column can be synthesised in groups. + """Test that a dependent column can be synthesised in groups.""" - We only test the case where there is a single group currently. - """ - - group, marginal = params + existing, new = params + num_groups = existing["a"].nunique() column, prng = "foo", da.random.default_rng(0) + empty_marginal = [[]] * num_groups + with mock.patch("centhesus.mst.MST._synthesise_column") as synth: - synth.return_value.compute.return_value = marginal + synth.return_value.compute.return_value = new synthetic = ( - group.copy() + existing.copy() .groupby("a") - .apply(MST._synthesise_column_in_group, column, [[]], prng) + .apply( + MST._synthesise_column_in_group, column, empty_marginal, prng + ) ) assert isinstance(synthetic, pd.DataFrame) - assert synthetic.shape[0] == group.shape[0] - assert synthetic.columns.to_list() == [*group.columns.to_list(), column] + assert synthetic.shape[0] == existing.shape[0] + assert synthetic.columns.to_list() == [*existing.columns.to_list(), column] - assert np.array_equal(synthetic[column], marginal) + assert np.array_equal(synthetic[column], new * num_groups) - synth.assert_called_once_with([], group.shape[0], prng, 1e6) - synth.return_value.compute.called_once_with() + assert synth.call_count == num_groups + for i, call in enumerate(synth.call_args_list): + assert call.args == ([], (existing["a"] == i).sum(), prng, 1e6) + + assert synth.return_value.compute.call_count == num_groups + assert ( + synth.return_value.compute.call_args_list == [mock.call()] * num_groups + ) diff --git a/tests/strategies.py b/tests/strategies.py index f59b170..1cc3e6b 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -146,13 +146,16 @@ def st_subgraphs(draw): @st.composite -def st_group_marginals(draw): - """Create a group and matching marginal for a test.""" +def st_existing_new_columns(draw): + """Create an existing column and a new one for a test.""" + num_groups = draw(st.integers(1, 3)) num_rows_in_group = draw(st.integers(10, 50)) - group = pd.DataFrame({"a": [0] * num_rows_in_group}) + existing = pd.DataFrame( + {"a": [i for i in range(num_groups) for _ in range(num_rows_in_group)]} + ) - marginal = draw( + new = draw( st.lists( st.integers(0, 3), min_size=num_rows_in_group, @@ -160,4 +163,4 @@ def st_group_marginals(draw): ) ) - return group, marginal + return existing, new From e09338ba1d9928089388bd540d8c543e8556856d Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Thu, 23 Nov 2023 10:41:55 +0000 Subject: [PATCH 40/50] Write test for setting up generation --- tests/mst/test_generate.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/tests/mst/test_generate.py b/tests/mst/test_generate.py index 6567d87..b1e8d3c 100644 --- a/tests/mst/test_generate.py +++ b/tests/mst/test_generate.py @@ -16,6 +16,31 @@ from ..strategies import st_existing_new_columns +@given( + st.floats(1, 100), + st.lists(st.text(), min_size=1, max_size=10), + st.lists(st.tuples(st.text(), st.text()), min_size=1, max_size=5), + st.one_of((st.just(None), st.integers(1, 100))), + st.integers(0, 10), +) +def test_setup_generate(total, elimination_order, cliques_, nrows, seed): + """Test that generation can be set up correctly.""" + + model = mock.MagicMock() + model.total = total + model.elimination_order = elimination_order + model.cliques = cliques_ + + nrows, prng, cliques, column, order = MST._setup_generate(model, nrows, seed) + + assert isinstance(nrows, int) + assert nrows == total or int(model.total) + assert isinstance(prng, da.random.Generator) + assert cliques == [set(clique) for clique in cliques_] + assert column == elimination_order[-1] + assert order == elimination_order[-2::-1] + + @settings(deadline=None) @given( arrays( @@ -51,7 +76,7 @@ def test_synthesise_column(marginal, total): @given(st_existing_new_columns()) -def test_synthesise_group(params): +def test_synthesise_column_in_group(params): """Test that a dependent column can be synthesised in groups.""" existing, new = params @@ -61,7 +86,7 @@ def test_synthesise_group(params): empty_marginal = [[]] * num_groups with mock.patch("centhesus.mst.MST._synthesise_column") as synth: - synth.return_value.compute.return_value = new + synth.return_value = new synthetic = ( existing.copy() .groupby("a") @@ -80,7 +105,4 @@ def test_synthesise_group(params): for i, call in enumerate(synth.call_args_list): assert call.args == ([], (existing["a"] == i).sum(), prng, 1e6) - assert synth.return_value.compute.call_count == num_groups - assert ( - synth.return_value.compute.call_args_list == [mock.call()] * num_groups - ) + assert synth.call_count == num_groups From 5573af3a60662bd9dc85170bde790604b98130dc Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Thu, 23 Nov 2023 10:53:54 +0000 Subject: [PATCH 41/50] Dump commit. Moving offline. --- src/centhesus/mst.py | 183 ++++++++++++++++++++++++++++++++++---- tests/mst/test_measure.py | 2 +- 2 files changed, 167 insertions(+), 18 deletions(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index d814565..5e827e1 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -204,14 +204,14 @@ def measure(self, cliques): tasks = [] for clique in cliques: - get_marginal = dask.delayed(lambda x: (x, self.get_marginal(x))) - tasks.append(get_marginal(clique)) + marginal = dask.delayed(self.get_marginal)(clique) + tasks.append(marginal) - indexed_marginals = dask.compute(*tasks) + marginals = dask.compute(*tasks) measurements = [ - (sparse.eye(marginal.size), marginal, 1e-12, clique) - for clique, marginal in indexed_marginals + (sparse.eye(marginal.size), marginal, 1, clique) + for clique, marginal in zip(cliques, marginals) if marginal is not None ] @@ -294,16 +294,16 @@ def _calculate_importances(self, interim): pairs = list(itertools.combinations(self.domain.attrs, 2)) tasks = [] for pair in pairs: - calculate_importance = dask.delayed( - lambda x: (x, self._calculate_importance_of_pair(interim, x)) + importance = dask.delayed(self._calculate_importance_of_pair)( + interim, pair ) - tasks.append(calculate_importance(pair)) + tasks.append(importance) - indexed_importances = dask.compute(*tasks) + importances = dask.compute(*tasks) weights = { pair: importance - for pair, importance in indexed_importances + for pair, importance in zip(pairs, importances) if importance is not None } @@ -373,7 +373,41 @@ def select(self, measurements): return list(tree.edges) @staticmethod - def _synthesise_column(marginal, total, prng, chunksize=1e6): + def _setup_generate(model, nrows, seed): + """ + Set everything up for the generation of the synthetic data. + + Parameters + ---------- + model : mbi.GraphicalModel + Model from which the synthetic data will be drawn. + nrows : int or None + Number of rows in the synthetic data. Inferred from `model` + if `None`. + seed : int or None + Pseudo-random seed. If `None`, randomness not reproducible. + + Returns + ------- + nrows : int + Number of rows to generate. + prng : dask.array.random.Generator + Pseudo-random number generator. + cliques : list of set + Cliques identified by the graphical model. + order : list of str + Order in which to synthesise the columns. + """ + + nrows = int(model.total) if nrows is None else nrows + prng = da.random.default_rng(seed) + cliques = [set(clique) for clique in model.cliques] + column, *order = model.elimination_order[::-1] + + return nrows, prng, cliques, column, order + + @staticmethod + def _synthesise_column(marginal, nrows, prng, chunksize=1e6): """ Sample a column of given length based on a marginal. @@ -381,8 +415,8 @@ def _synthesise_column(marginal, total, prng, chunksize=1e6): marginal very closely. The process for synthesising the column is as follows: - 1. Scale the marginal against the total, and then separate - its integer and fractional components. + 1. Scale the marginal against the total count required, and then + separate its integer and fractional components. 2. If there are insufficient integer counts, distribute the additional elements among the integer counts randomly using the fractional component as a weight. In this way, the @@ -396,7 +430,7 @@ def _synthesise_column(marginal, total, prng, chunksize=1e6): ---------- marginal : np.ndarray Marginal counts from which to synthesise the column. - total : int + nrows : int Number of elements in the synthesised column. prng : dask.array.random.Generator Pseudo-random number generator. We use this to distribute @@ -412,11 +446,11 @@ def _synthesise_column(marginal, total, prng, chunksize=1e6): marginal. """ - marginal *= total / marginal.sum() + marginal *= nrows / marginal.sum() fractions, integers = np.modf(marginal) integers = integers.astype(int) - extra = total - integers.sum() + extra = nrows - integers.sum() if extra > 0: idx = prng.choice( marginal.size, extra, False, fractions / fractions.sum() @@ -469,6 +503,121 @@ def _synthesise_column_in_group( idx = group.name group[column] = MST._synthesise_column( marginal[idx], group.shape[0], prng, chunksize - ).compute() + ) return group + + def _synthesise_first_column(self, model, column, nrows, prng): + """ + Sample the first column from the model as a data frame. + + Parameters + ---------- + model : mbi.GraphicalModel + Model from which to synthesise the column. + column : str + Name of the column to synthesise. + nrows : int + Number of rows to generate. + prng : dask.array.random.Generator + Pseudo-random number generator. + + Returns + ------- + data : dask.dataframe.DataFrame + Data frame containing the first synthetic column. + """ + + marginal = model.project([column]).datavector(flatten=False) + data = self._synthesise_column(marginal, nrows, prng).to_frame( + name=column + ) + + return data + + @staticmethod + def _find_prerequisite_columns(column, cliques, used): + """ + Find the columns that inform the synthesis of a new column. + + Parameters + ---------- + column : str + Name of column to be synthesised. + cliques : list of set + Cliques identified by the graphical model. + used : set of str + Names of columns that have already been synthesised. + + Returns + ------- + prerequisites : tuple of str + All columns needed to synthesise the new column. + """ + + member_of_cliques = [clique for clique in cliques if column in clique] + prerequisites = used.intersection(set.union(*member_of_cliques)) + + return tuple(prerequisites) + + def generate(self, model, nrows=None, seed=None): + """ + Generate a synthetic dataset from the estimated model. + + Columns are synthesised in the order determined by the graphical + model. With each column after the first, we search for all the + columns on which it depends according to the model that have + been synthesised already. + + Parameters + ---------- + model : mbi.GraphicalModel + Model from which to draw synthetic data. This model should + be fit to all the marginal tables you care about. + nrows : int, optional + Number of rows in the synthetic dataset. If not specified, + the length of the dataset is inferred from the model. + seed : int, optional + Seed for pseudo-random number generation. If not specified, + the results will not be reproducible. + + Returns + ------- + data : dask.dataframe.DataFrame + Data frame containing the synthetic data. We use Dask to + allow for larger-than-memory datasets. As such, it is lazily + executed. + """ + + nrows, prng, cliques, column, order = self._setup_generate( + model, nrows, seed + ) + data = self._synthesise_first_column(model, column, nrows, prng) + used = {column} + + for column in order: + clique = self._find_prerequisite_columns(column, cliques, used) + used.add(column) + + marginal = model.project(clique + (column,)).datavector( + flatten=False + ) + + if len(clique) >= 1: + data = ( + data.groupby(list(clique)) + .apply( + self._synthesise_column_in_group, + column, + marginal, + prng, + meta={**data.dtypes, column: int}, + ) + .reset_index(drop=True) + ) + else: + data[column] = self._synthesise_column(marginal, nrows, prng) + + data = data.repartition("100MB") + + return data diff --git a/tests/mst/test_measure.py b/tests/mst/test_measure.py index fb3fae9..d8a1f79 100644 --- a/tests/mst/test_measure.py +++ b/tests/mst/test_measure.py @@ -73,5 +73,5 @@ def test_measure(params, num_cliques): assert ident.shape == (marg.size,) * 2 assert ident.sum() == marg.size assert marg.equals(table) - assert noise == 1e-12 + assert noise == 1 assert cliq == clique From 7ed16728d214f1fd7b4ef47c2c8fd5c05fa27b7e Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Fri, 24 Nov 2023 10:03:52 +0000 Subject: [PATCH 42/50] Write test for synthesising first column --- src/centhesus/mst.py | 5 +++-- tests/mst/test_generate.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 5e827e1..6aac88d 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -507,7 +507,8 @@ def _synthesise_column_in_group( return group - def _synthesise_first_column(self, model, column, nrows, prng): + @staticmethod + def _synthesise_first_column(model, column, nrows, prng): """ Sample the first column from the model as a data frame. @@ -529,7 +530,7 @@ def _synthesise_first_column(self, model, column, nrows, prng): """ marginal = model.project([column]).datavector(flatten=False) - data = self._synthesise_column(marginal, nrows, prng).to_frame( + data = MST._synthesise_column(marginal, nrows, prng).to_frame( name=column ) diff --git a/tests/mst/test_generate.py b/tests/mst/test_generate.py index b1e8d3c..130a0c2 100644 --- a/tests/mst/test_generate.py +++ b/tests/mst/test_generate.py @@ -31,7 +31,9 @@ def test_setup_generate(total, elimination_order, cliques_, nrows, seed): model.elimination_order = elimination_order model.cliques = cliques_ - nrows, prng, cliques, column, order = MST._setup_generate(model, nrows, seed) + nrows, prng, cliques, column, order = MST._setup_generate( + model, nrows, seed + ) assert isinstance(nrows, int) assert nrows == total or int(model.total) @@ -106,3 +108,33 @@ def test_synthesise_column_in_group(params): assert call.args == ([], (existing["a"] == i).sum(), prng, 1e6) assert synth.call_count == num_groups + + +@settings(deadline=None) +@given( + arrays( + int, + st.integers(2, 10), + elements=st.integers(0, 50), + ), + st.text(min_size=1), + st.integers(10, 100), +) +def test_synthesise_first_column(values, column, nrows): + """Test that a single column frame can be created.""" + + prng = da.random.default_rng(0) + model = mock.MagicMock() + model.project.return_value.datavector.return_value = "marginal" + + with mock.patch("centhesus.mst.MST._synthesise_column") as synth: + synth.return_value = dd.from_array(values) + first = MST._synthesise_first_column(model, column, nrows, prng) + + assert isinstance(first, dd.DataFrame) + assert first.columns.to_list() == [column] + assert np.array_equal(first[column].compute(), values) + + model.project.assert_called_once_with([column]) + model.project.return_value.datavector.called_once_with(flatten=False) + synth.assert_called_once_with("marginal", nrows, prng) From 963e6406d63a62d919833441148626ef1a6afad6 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Fri, 24 Nov 2023 11:00:42 +0000 Subject: [PATCH 43/50] Write test for finding prerequisite columns --- tests/mst/test_generate.py | 20 +++++++++++++++++++- tests/strategies.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/tests/mst/test_generate.py b/tests/mst/test_generate.py index 130a0c2..59f4bc4 100644 --- a/tests/mst/test_generate.py +++ b/tests/mst/test_generate.py @@ -13,7 +13,7 @@ from centhesus import MST -from ..strategies import st_existing_new_columns +from ..strategies import st_existing_new_columns, st_prerequisite_columns @given( @@ -138,3 +138,21 @@ def test_synthesise_first_column(values, column, nrows): model.project.assert_called_once_with([column]) model.project.return_value.datavector.called_once_with(flatten=False) synth.assert_called_once_with("marginal", nrows, prng) + + +@given(st_prerequisite_columns()) +def test_find_prerequisite_columns(params): + """Test we can find all the columns on which another depends.""" + + column, cliques, used = params + + prerequisites = MST._find_prerequisite_columns(column, cliques, used) + + expected = set( + other + for clique in cliques + for other in clique + if column in clique and other != column and other in used + ) + assert isinstance(prerequisites, tuple) + assert set(prerequisites) == expected diff --git a/tests/strategies.py b/tests/strategies.py index 1cc3e6b..07f7a47 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -12,6 +12,7 @@ DIMENSIONS_BY_POPULATION_TYPE, POPULATION_TYPES, ) +from hypothesis import assume from hypothesis import strategies as st from mbi import Domain @@ -164,3 +165,36 @@ def st_existing_new_columns(draw): ) return existing, new + + +@st.composite +def st_prerequisite_columns(draw): + """Create a column, set of cliques and a used set for a test.""" + + columns = draw( + st.sets( + st.sampled_from(DIMENSIONS_BY_POPULATION_TYPE["UR_HH"]), min_size=2 + ).map(list) + ) + column = draw(st.sampled_from(columns)) + + combinations = [ + *itertools.combinations(columns, 2), + *itertools.combinations(columns, 3), + ] + + cliques = draw( + st.lists( + st.sampled_from(combinations).map(set), min_size=len(columns) - 1 + ) + ) + assume(any(column in clique for clique in cliques)) + + used = draw( + st.sets( + st.sampled_from([col for col in columns if col != column]), + min_size=1, + ) + ) + + return column, cliques, used From 1a2681ef5a1ffedeb7b0ce65ca78259e15c64305 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Mon, 27 Nov 2023 11:08:56 +0000 Subject: [PATCH 44/50] Test and implement generation method --- src/centhesus/mst.py | 79 ++++++++++++++++++++------------- tests/mst/test_generate.py | 91 +++++++++++++++++++++++++++++++++++--- 2 files changed, 132 insertions(+), 38 deletions(-) diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py index 6aac88d..28fc146 100644 --- a/src/centhesus/mst.py +++ b/src/centhesus/mst.py @@ -469,11 +469,11 @@ def _synthesise_column(marginal, nrows, prng, chunksize=1e6): return column @staticmethod - def _synthesise_column_in_group( - group, column, marginal, prng, chunksize=1e6 + def _synthesise_column_in_group_by_partition( + partition, clique, column, marginal, prng, chunksize=1e6 ): """ - Synthesise a column inside a group-by operation. + Synthesise a column inside a groupby-apply over the partitions. This operation is used for synthesising columns that depend on those that have already been synthesised. By performing this @@ -481,31 +481,48 @@ def _synthesise_column_in_group( the marginal distribution estimated by the graphical model given what has already been synthesised. + Mapping the groupby-apply operation across the partitions allows + even very large datasets (10M+) to be created without + out-of-memory errors. + Parameters ---------- - group : pandas.DataFrame - Group data frame on which to operate. + partition : dask.dataframe.DataFrame + Partition of the data frame on which to operate. + clique : list of str + Prerequisite columns by which to group the operation. column : str - Name of column to be synthesised. + Name of the column to be synthesised. marginal : np.ndarray Marginal estimated from the graphical model for the column - and all the columns it depends on. + and its prerequisites. prng : dask.array.random.Generator Pseudo-random number generator. Used to synthesise the - column within this group. + column within groups in this partition. Returns ------- - group : - Group with new synthetic column. + partition : dask.dataframe.DataFrame + Partition with new synthetic column. """ - idx = group.name - group[column] = MST._synthesise_column( - marginal[idx], group.shape[0], prng, chunksize + def synthesise_in_group(group): + """Synthesise a column within a groupby-apply operation.""" + + idx = group.name + group[column] = MST._synthesise_column( + marginal[idx], group.shape[0], prng, chunksize + ) + + return group + + partition = ( + partition.groupby(list(clique)) + .apply(synthesise_in_group) + .reset_index(drop=True) ) - return group + return partition @staticmethod def _synthesise_first_column(model, column, nrows, prng): @@ -561,7 +578,8 @@ def _find_prerequisite_columns(column, cliques, used): return tuple(prerequisites) - def generate(self, model, nrows=None, seed=None): + @staticmethod + def generate(model, nrows=None, seed=None): """ Generate a synthetic dataset from the estimated model. @@ -590,34 +608,33 @@ def generate(self, model, nrows=None, seed=None): executed. """ - nrows, prng, cliques, column, order = self._setup_generate( + nrows, prng, cliques, column, order = MST._setup_generate( model, nrows, seed ) - data = self._synthesise_first_column(model, column, nrows, prng) + data = MST._synthesise_first_column(model, column, nrows, prng) used = {column} for column in order: - clique = self._find_prerequisite_columns(column, cliques, used) + prerequisites = MST._find_prerequisite_columns( + column, cliques, used + ) used.add(column) - marginal = model.project(clique + (column,)).datavector( + marginal = model.project(prerequisites + (column,)).datavector( flatten=False ) - if len(clique) >= 1: - data = ( - data.groupby(list(clique)) - .apply( - self._synthesise_column_in_group, - column, - marginal, - prng, - meta={**data.dtypes, column: int}, - ) - .reset_index(drop=True) + if len(prerequisites) >= 1: + data = data.map_partitions( + MST._synthesise_column_in_group_by_partition, + prerequisites, + column, + marginal, + prng, + meta={**data.dtypes, column: int}, ) else: - data[column] = self._synthesise_column(marginal, nrows, prng) + data[column] = MST._synthesise_column(marginal, nrows, prng) data = data.repartition("100MB") diff --git a/tests/mst/test_generate.py b/tests/mst/test_generate.py index 59f4bc4..40ba1aa 100644 --- a/tests/mst/test_generate.py +++ b/tests/mst/test_generate.py @@ -78,7 +78,7 @@ def test_synthesise_column(marginal, total): @given(st_existing_new_columns()) -def test_synthesise_column_in_group(params): +def test_synthesise_column_in_group_by_partition(params): """Test that a dependent column can be synthesised in groups.""" existing, new = params @@ -89,12 +89,12 @@ def test_synthesise_column_in_group(params): with mock.patch("centhesus.mst.MST._synthesise_column") as synth: synth.return_value = new - synthetic = ( - existing.copy() - .groupby("a") - .apply( - MST._synthesise_column_in_group, column, empty_marginal, prng - ) + synthetic = MST._synthesise_column_in_group_by_partition( + existing.copy(), + ["a"], + column, + empty_marginal, + prng, ) assert isinstance(synthetic, pd.DataFrame) @@ -156,3 +156,80 @@ def test_find_prerequisite_columns(params): ) assert isinstance(prerequisites, tuple) assert set(prerequisites) == expected + + +@given( + st.integers(1, 100), + st.lists(st.text(), min_size=2, max_size=10, unique=True), +) +def test_generate(nrows, params): + """Test that generation can be executed correctly.""" + + column, *order = params + + prng = da.random.default_rng(0) + + data = mock.MagicMock() + data.dtypes = {"data": "dtypes"} + data.map_partitions.return_value = data + data.repartition.return_value = data + + marginal = mock.MagicMock() + + model = mock.MagicMock() + model.project.return_value.datavector.return_value = marginal + + with mock.patch("centhesus.mst.MST._setup_generate") as setup, mock.patch( + "centhesus.mst.MST._synthesise_first_column" + ) as first, mock.patch( + "centhesus.mst.MST._find_prerequisite_columns" + ) as find, mock.patch( + "centhesus.mst.MST._synthesise_column" + ) as synth: + setup.return_value = (nrows, prng, "cliques", column, order) + first.return_value = data + find.return_value = ("prerequisites",) + synth.return_value = "independent" + + synthetic = MST.generate(model, nrows) + + setup.assert_called_once_with(model, nrows, None) + first.assert_called_once_with(model, column, nrows, prng) + + used = {column} + num_subsequent_columns = len(order) + assert find.call_count == num_subsequent_columns + for call, col in zip(find.call_args_list, order): + assert tuple(call.args[:-1]) == (col, "cliques") + used.add(col) + + assert used == set((column, *order)) + + assert model.project.call_count == num_subsequent_columns + for call, col in zip(model.project.call_args_list, order): + assert call.args == (("prerequisites", col),) + + assert ( + model.project.return_value.datavector.call_count + == num_subsequent_columns + ) + for call in model.project.return_value.datavector.call_args_list: + assert call.args == () + assert call.kwargs == {"flatten": False} + + assert data.map_partitions.call_count == num_subsequent_columns + for call, col in zip(data.map_partitions.call_args_list, order): + assert call.args == ( + MST._synthesise_column_in_group_by_partition, + ("prerequisites",), + col, + marginal, + prng, + ) + assert call.kwargs == {"meta": {"data": "dtypes", col: int}} + + synth.assert_not_called() + + data.repartition.assert_called_once_with("100MB") + + assert synthetic is data From e537472c264b5736e2f4c175357223c74e0a384e Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Mon, 27 Nov 2023 12:20:59 +0000 Subject: [PATCH 45/50] Write test for independent columns in generate --- tests/mst/test_generate.py | 58 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/mst/test_generate.py b/tests/mst/test_generate.py index 40ba1aa..bebc564 100644 --- a/tests/mst/test_generate.py +++ b/tests/mst/test_generate.py @@ -233,3 +233,61 @@ def test_generate(nrows, params): data.repartition.assert_called_once_with("100MB") assert synthetic is data + + +@given( + st.integers(1, 100), + st.lists(st.text(), min_size=2, max_size=10, unique=True), +) +def test_generate_with_extra_independents(nrows, params): + """Test generation executes with multiple independent columns.""" + + column, *order = params + + prng = da.random.default_rng(0) + + data = mock.MagicMock() + data.repartition.return_value = data + + model = mock.MagicMock() + marginal = mock.MagicMock() + model.project.return_value.datavector.return_value = marginal + + with mock.patch("centhesus.mst.MST._setup_generate") as setup, mock.patch( + "centhesus.mst.MST._synthesise_first_column" + ) as first, mock.patch( + "centhesus.mst.MST._find_prerequisite_columns" + ) as find, mock.patch( + "centhesus.mst.MST._synthesise_column" + ) as synth: + setup.return_value = (nrows, prng, "cliques", column, order) + first.return_value = data + find.return_value = () + synth.return_value = "independent" + + synthetic = MST.generate(model, nrows) + + setup.assert_called_once_with(model, nrows, None) + first.assert_called_once_with(model, column, nrows, prng) + + num_subsequent_columns = len(order) + assert model.project.call_count == num_subsequent_columns + for call, col in zip(model.project.call_args_list, order): + assert call.args == ((col,),) + + assert ( + model.project.return_value.datavector.call_count + == num_subsequent_columns + ) + for call in model.project.return_value.datavector.call_args_list: + assert call.args == () + assert call.kwargs == {"flatten": False} + + assert synth.call_count == num_subsequent_columns + for call, col in zip(synth.call_args_list, order): + assert call.args == (marginal, nrows, prng) + assert hasattr(data, col) + + data.repartition.assert_called_once_with("100MB") + + assert synthetic is data From bf341dcbbc74b898bd3cbc3ab609e34f3987b2ce Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Mon, 27 Nov 2023 15:34:08 +0000 Subject: [PATCH 46/50] Skip patched delayed tests for 3.8 Python 3.8 and its dependencies seem to ignore or strip out the patched functions in two tests: `test_measure` and `test_calculate_importances`. For now, I'm skipping over them and hopefully the fact they pass for 3.11, and the forthcoming integration tests, will be sufficient. --- tests/mst/test_measure.py | 3 +++ tests/mst/test_select.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/tests/mst/test_measure.py b/tests/mst/test_measure.py index d8a1f79..41e2767 100644 --- a/tests/mst/test_measure.py +++ b/tests/mst/test_measure.py @@ -1,9 +1,11 @@ """Unit tests for the measurement methods in `centhesus.MST`.""" +import platform from unittest import mock import numpy as np import pandas as pd +import pytest from hypothesis import given, settings from hypothesis import strategies as st from scipy import sparse @@ -49,6 +51,7 @@ def test_get_marginal_failed_call(params, flatten): query.assert_called_once() +@pytest.mark.skipif(tuple(map(int, platform.python_version_tuple())) > (3, 8)) @settings(deadline=None) @given(st_single_marginals(), st.integers(1, 5)) def test_measure(params, num_cliques): diff --git a/tests/mst/test_select.py b/tests/mst/test_select.py index c5ffbac..2ca27cc 100644 --- a/tests/mst/test_select.py +++ b/tests/mst/test_select.py @@ -1,9 +1,11 @@ """Unit tests for the selection methods in `centhesus.MST`.""" import itertools +import platform from unittest import mock import networkx as nx +import pytest from hypothesis import given, settings from ..strategies import ( @@ -57,6 +59,7 @@ def test_calculate_importance_of_pair_failed_call(params): get_marginal.assert_called_once_with(clique) +@pytest.mark.skipif(tuple(map(int, platform.python_version_tuple())) > (3, 8)) @settings(deadline=None) @given(st_importances()) def test_calculate_importances(params): @@ -66,6 +69,7 @@ def test_calculate_importances(params): mst = mocked_mst(population_type, area_type, dimensions, domain=domain) with mock.patch("centhesus.mst.MST._calculate_importance_of_pair") as calc: + mst._calculate_importances = calc calc.side_effect = importances weights = mst._calculate_importances("interim") From fc341afa7e9a2f2162b2099f3699b229ed6fa970 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 28 Nov 2023 09:16:27 +0000 Subject: [PATCH 47/50] Fix skipif expressions to skip 3.8 and below --- tests/mst/test_measure.py | 5 ++++- tests/mst/test_select.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/mst/test_measure.py b/tests/mst/test_measure.py index 41e2767..4df0f3a 100644 --- a/tests/mst/test_measure.py +++ b/tests/mst/test_measure.py @@ -51,7 +51,10 @@ def test_get_marginal_failed_call(params, flatten): query.assert_called_once() -@pytest.mark.skipif(tuple(map(int, platform.python_version_tuple())) > (3, 8)) +@pytest.mark.skipif( + tuple(map(int, platform.python_version_tuple())) < (3, 9), + reason="Requires Python 3.9+", +) @settings(deadline=None) @given(st_single_marginals(), st.integers(1, 5)) def test_measure(params, num_cliques): diff --git a/tests/mst/test_select.py b/tests/mst/test_select.py index 2ca27cc..6b7fd96 100644 --- a/tests/mst/test_select.py +++ b/tests/mst/test_select.py @@ -59,7 +59,10 @@ def test_calculate_importance_of_pair_failed_call(params): get_marginal.assert_called_once_with(clique) -@pytest.mark.skipif(tuple(map(int, platform.python_version_tuple())) > (3, 8)) +@pytest.mark.skipif( + tuple(map(int, platform.python_version_tuple())) < (3, 9), + reason="Requires Python 3.9+", +) @settings(deadline=None) @given(st_importances()) def test_calculate_importances(params): @@ -69,7 +72,6 @@ def test_calculate_importances(params): mst = mocked_mst(population_type, area_type, dimensions, domain=domain) with mock.patch("centhesus.mst.MST._calculate_importance_of_pair") as calc: - mst._calculate_importances = calc calc.side_effect = importances weights = mst._calculate_importances("interim") From baf01f761d399f82f77e5a3740686401445d1cf2 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 28 Nov 2023 09:19:45 +0000 Subject: [PATCH 48/50] Separate CI workflow to check coverage on 3.11+ --- .github/workflows/tests.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d973808..b53462f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -30,7 +30,12 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install ".[test]" pytest-sugar - - name: Run tests + - name: Run tests (<3.9) + if: matrix.python-version < 3.9 + run: | + python -m pytest --cov=centhesus tests + - name: Run tests (>=3.11) + if: matrix.python-version >= 3.11 run: | python -m pytest --cov=centhesus --cov-fail-under=100 tests - name: Install and run linters From 034fdca3af4d3668abf68394e44088756cad8b48 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 28 Nov 2023 09:39:02 +0000 Subject: [PATCH 49/50] Separate python version components in CI workflow --- .github/workflows/tests.yml | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b53462f..67313c5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,30 +16,34 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: [3.8, 3.11] + py-major: [3] + py-minor: [8, 11] + env: + python-version: | + ${{ format('{0}.{1}', matrix.py-major, matrix.py-minor) }} steps: - name: Checkout repository uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} + - name: Set up Python ${{ env.python-version }} uses: actions/setup-python@v4 with: - python-version: ${{ matrix.python-version }} + python-version: ${{ env.python-version }} cache: "pip" - name: Update pip and install dependencies run: | python -m pip install --upgrade pip python -m pip install ".[test]" pytest-sugar - name: Run tests (<3.9) - if: matrix.python-version < 3.9 + if: matrix.py-minor < 9 run: | python -m pytest --cov=centhesus tests - name: Run tests (>=3.11) - if: matrix.python-version >= 3.11 + if: matrix.py-minor >= 11 run: | python -m pytest --cov=centhesus --cov-fail-under=100 tests - name: Install and run linters - if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.11 + if: matrix.os == 'ubuntu-latest' && matrix.py-minor == 11 run: | python -m pip install ".[lint]" python -m black --check . From 02496f606e0fc81ef16dfcbcf6a7a6e828799182 Mon Sep 17 00:00:00 2001 From: Henry Wilde Date: Tue, 28 Nov 2023 09:53:19 +0000 Subject: [PATCH 50/50] Force synchronous (serial) computation in tests --- tests/mst/test_measure.py | 5 ++++- tests/mst/test_select.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/mst/test_measure.py b/tests/mst/test_measure.py index 4df0f3a..ad8cfdd 100644 --- a/tests/mst/test_measure.py +++ b/tests/mst/test_measure.py @@ -3,6 +3,7 @@ import platform from unittest import mock +import dask import numpy as np import pandas as pd import pytest @@ -63,7 +64,9 @@ def test_measure(params, num_cliques): population_type, area_type, dimensions, clique, table = params mst = mocked_mst(population_type, area_type, dimensions) - with mock.patch("centhesus.mst.MST.get_marginal") as get_marginal: + with mock.patch( + "centhesus.mst.MST.get_marginal" + ) as get_marginal, dask.config.set(scheduler="synchronous"): get_marginal.return_value = table measurements = mst.measure([clique] * num_cliques) diff --git a/tests/mst/test_select.py b/tests/mst/test_select.py index 6b7fd96..004c8d3 100644 --- a/tests/mst/test_select.py +++ b/tests/mst/test_select.py @@ -4,6 +4,7 @@ import platform from unittest import mock +import dask import networkx as nx import pytest from hypothesis import given, settings @@ -71,7 +72,9 @@ def test_calculate_importances(params): population_type, area_type, dimensions, domain, importances = params mst = mocked_mst(population_type, area_type, dimensions, domain=domain) - with mock.patch("centhesus.mst.MST._calculate_importance_of_pair") as calc: + with mock.patch( + "centhesus.mst.MST._calculate_importance_of_pair" + ) as calc, dask.config.set(scheduler="synchronous"): calc.side_effect = importances weights = mst._calculate_importances("interim")