diff --git a/src/centhesus/mst.py b/src/centhesus/mst.py
index f753e5d..5faa0a4 100644
--- a/src/centhesus/mst.py
+++ b/src/centhesus/mst.py
@@ -1,5 +1,7 @@
 """Module for the Maximum Spanning Tree generator."""
 
+import itertools
+
 import dask
 import numpy as np
 from census21api import CensusAPI
@@ -183,6 +185,8 @@ def measure(self, cliques):
         If a column pair has been blocked by the API, then their
         marginal is `None` and we skip over them.
 
+        We use `dask` to compute these marginals in parallel.
+
         Parameters
         ----------
         cliques : iterable of tuple
@@ -197,14 +201,13 @@ def measure(self, cliques):
 
-        tasks = []
-        for clique in cliques:
-            marginal = dask.delayed(self.get_marginal)(clique)
-            tasks.append(marginal)
+        # One delayed wrapper; each task returns (clique, marginal).
+        get_marginal = dask.delayed(lambda x: (x, self.get_marginal(x)))
+        tasks = [get_marginal(clique) for clique in cliques]
 
-        marginals = dask.compute(*tasks)
+        indexed_marginals = dask.compute(*tasks)
 
         measurements = [
             (sparse.eye(marginal.size), marginal, 1e-12, clique)
-            for marginal, clique in zip(marginals, cliques)
+            for clique, marginal in indexed_marginals
             if marginal is not None
         ]
 
@@ -250,13 +253,51 @@ def _calculate_importance_of_pair(self, interim, pair):
 
         Returns
         -------
-        weight : float
+        weight : float or None
             Importance of the pair given as the negative of the L1 norm
             between the observed and estimated marginals for the pair.
+            If the API call fails, this is `None`.
         """
 
-        estimate = interim.project(pair).datavector()
+        weight = None
         marginal = self.get_marginal(pair)
-        weight = -np.linalg.norm(marginal - estimate, 1)
+        if marginal is not None:
+            estimate = interim.project(pair).datavector()
+            weight = -np.linalg.norm(marginal - estimate, 1)
 
         return weight
+
+    def _calculate_importances(self, interim):
+        """
+        Determine every column pair's importance given an interim model.
+
+        We use `dask` to compute these importances in parallel.
+
+        Parameters
+        ----------
+        interim : mbi.GraphicalModel
+            Interim model based on one-way marginals only.
+
+        Returns
+        -------
+        weights : dict
+            Dictionary mapping column pairs to their weight. If a column
+            pair is blocked by the API, it is skipped.
+        """
+
+        # One delayed wrapper; each task returns (pair, importance).
+        calculate_importance = dask.delayed(
+            lambda x: (x, self._calculate_importance_of_pair(interim, x))
+        )
+        pairs = itertools.combinations(self.domain.attrs, 2)
+        tasks = [calculate_importance(pair) for pair in pairs]
+
+        indexed_importances = dask.compute(*tasks)
+
+        weights = {
+            pair: importance
+            for pair, importance in indexed_importances
+            if importance is not None
+        }
+
+        return weights