Skip to content

Commit

Permalink
Implement importance calculator; index task execution
Browse files Browse the repository at this point in the history
  • Loading branch information
daffidwilde committed Nov 15, 2023
1 parent 220dee2 commit 5e56172
Showing 1 changed file with 52 additions and 7 deletions.
59 changes: 52 additions & 7 deletions src/centhesus/mst.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Module for the Maximum Spanning Tree generator."""

import itertools

import dask
import numpy as np
from census21api import CensusAPI
Expand Down Expand Up @@ -183,6 +185,8 @@ def measure(self, cliques):
If a column pair has been blocked by the API, then their
marginal is `None` and we skip over them.
We use `dask` to compute these marginals in parallel.
Parameters
----------
cliques : iterable of tuple
Expand All @@ -197,14 +201,14 @@ def measure(self, cliques):

tasks = []
for clique in cliques:
marginal = dask.delayed(self.get_marginal)(clique)
tasks.append(marginal)
get_marginal = dask.delayed(lambda x: (x, self.get_marginal(x)))
tasks.append(get_marginal(clique))

marginals = dask.compute(*tasks)
indexed_marginals = dask.compute(*tasks)

measurements = [
(sparse.eye(marginal.size), marginal, 1e-12, clique)
for marginal, clique in zip(marginals, cliques)
for clique, marginal in indexed_marginals
if marginal is not None
]

Expand Down Expand Up @@ -250,13 +254,54 @@ def _calculate_importance_of_pair(self, interim, pair):
Returns
-------
weight : float
pair : tuple of str
Assessed column pair.
weight : float or None
Importance of the pair given as the negative of the L1 norm
between the observed and estimated marginals for the pair.
If the API call fails, this is `None`.
"""

estimate = interim.project(pair).datavector()
weight = None
marginal = self.get_marginal(pair)
weight = -np.linalg.norm(marginal - estimate, 1)
if marginal is not None:
estimate = interim.project(pair).datavector()
weight = -np.linalg.norm(marginal - estimate, 1)

return weight

def _calculate_importances(self, interim):
    """
    Determine every column pair's importance given an interim model.

    We use `dask` to compute these importances in parallel.

    Parameters
    ----------
    interim : mbi.GraphicalModel
        Interim model based on one-way marginals only.

    Returns
    -------
    weights : dict
        Dictionary mapping column pairs to their weight. If a column
        pair is blocked by the API, it is skipped.
    """

    def indexed_importance(pair):
        # Return the pair alongside its importance so each result can be
        # matched to its input without relying on task ordering.
        return pair, self._calculate_importance_of_pair(interim, pair)

    # Wrap the helper once: the delayed wrapper is the same for every
    # pair, so there is no need to rebuild it on each iteration.
    calculate_importance = dask.delayed(indexed_importance)
    tasks = [
        calculate_importance(pair)
        for pair in itertools.combinations(self.domain.attrs, 2)
    ]

    indexed_importances = dask.compute(*tasks)

    # An importance of `None` marks a pair that must be left out.
    return {
        pair: importance
        for pair, importance in indexed_importances
        if importance is not None
    }

0 comments on commit 5e56172

Please sign in to comment.