Skip to content

Commit

Permalink
Write test for calculating importances
Browse files Browse the repository at this point in the history
Weird indexing problems with Dask... will need to improve the `measure`
test(s)...
  • Loading branch information
daffidwilde committed Nov 15, 2023
1 parent 49bc5e9 commit 220dee2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/strategies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Custom strategies for testing the package."""

import itertools
import math

import numpy as np
import pandas as pd
Expand All @@ -10,6 +11,7 @@
POPULATION_TYPES,
)
from hypothesis import strategies as st
from mbi import Domain


@st.composite
Expand Down Expand Up @@ -79,3 +81,24 @@ def st_single_marginals(draw, kind=None):
marginal["count"] = counts

return population_type, area_type, dimensions, clique, marginal


@st.composite
def st_importances(draw):
"""Create a domain and set of importances for a test."""

population_type, area_type, dimensions = draw(st_api_parameters())
num = len(dimensions) + 1

sizes = draw(st.lists(st.integers(2, 10), min_size=num, max_size=num))
domain = Domain.fromdict(dict(zip((area_type, *dimensions), sizes)))

importances = draw(
st.lists(
st.floats(max_value=0, allow_infinity=False, allow_nan=False),
min_size=math.comb(num, 2),
max_size=math.comb(num, 2),
)
)

return population_type, area_type, dimensions, domain, importances
27 changes: 27 additions & 0 deletions tests/test_mst.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the `centhesus.mst` module."""

import itertools
import string
from unittest import mock

Expand All @@ -20,6 +21,7 @@
from .strategies import (
st_api_parameters,
st_feature_metadata_parameters,
st_importances,
st_single_marginals,
)

Expand Down Expand Up @@ -292,3 +294,28 @@ def test_calculate_importance_of_pair_failed_call(params):
interim.project.assert_not_called()
interim.project.return_value.datavector.assert_not_called()
get_marginal.assert_called_once_with(clique)


@settings(deadline=None)
@given(st_importances())
def test_calculate_importances(params):
"""Test that a set of importances can be calculated."""

population_type, area_type, dimensions, domain, importances = params
mst = mocked_mst(population_type, area_type, dimensions, domain=domain)

with mock.patch("centhesus.mst.MST._calculate_importance_of_pair") as calc:
calc.side_effect = importances
weights = mst._calculate_importances("interim")

pairs = list(itertools.combinations(domain, 2))
calc.call_count == len(pairs)
call_args = [call.args for call in calc.call_args_list]
assert set(call_args) == set(("interim", pair) for pair in pairs)

assert isinstance(weights, dict)
assert set(weights.keys()) == set(pairs)

pairs_execution_order = [pair for _, pair in call_args]
for pair, importance in zip(pairs_execution_order, importances):
assert weights[pair] == importance

0 comments on commit 220dee2

Please sign in to comment.