Skip to content

Commit

Permalink
Write measure method
Browse files Browse the repository at this point in the history
  • Loading branch information
daffidwilde committed Nov 15, 2023
1 parent 81d8cab commit 8eeb979
Showing 1 changed file with 59 additions and 12 deletions.
71 changes: 59 additions & 12 deletions src/centhesus/mst.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Module for the Maximum Spanning Tree generator."""

import dask
from census21api import CensusAPI
from census21api.constants import DIMENSIONS_BY_POPULATION_TYPE as DIMENSIONS
from mbi import Domain
from scipy import sparse


class MST:
Expand Down Expand Up @@ -141,23 +143,68 @@ def get_marginal(self, clique, flatten=True):
Returns
-------
marginal : numpy.ndarray or pandas.Series
Marginal table. If `flatten` is True, this a flat array.
marginal : numpy.ndarray or pandas.Series or None
Marginal table if the API call succeeds and `None` if not.
On a success, if `flatten` is `True`, this a flat array.
Otherwise, the indexed series is returned.
"""

area_type = self.area_type or "nat"
dimensions = [
col for col in clique if col != area_type
] or self.dimensions[0:1]

marginal = (
self.api.query_table(self.population_type, area_type, dimensions)
.groupby(list(clique))["count"]
.sum()
dimensions = [col for col in clique if col != area_type]
if not dimensions:
dimensions = self.dimensions[0:1]

marginal = self.api.query_table(
self.population_type, area_type, dimensions
)

if flatten is True:
marginal = marginal.to_numpy().flatten()
if marginal is not None:
marginal = marginal.groupby(list(clique))["count"].sum()
if flatten is True:
marginal = marginal.to_numpy().flatten()

return marginal

def measure(self, cliques):
"""
Measure the marginals of a set of cliques.
This function returns a list of "measurements" to be passed to
the `mbi` package. Each measurement consists of a sparse
identity matrix, the marginal table, a nominally small float
representing the "noise" added to the marginal, and the clique
associated with the marginal.
Although we are not applying differential privacy to our tables,
`mbi` requires non-zero noise for each measurement to form the
graphical model.
If a column pair has been blocked by the API, then their
marginal is `None` and we skip over them.
Parameters
----------
cliques : iterable of tuple
The cliques to measure. These cliques should be of the form
`(col,)` or `(col1, col2)`.
Returns
-------
measurements : list of tuple
Measurement tuples for each clique.
"""

tasks = []
for clique in cliques:
marginal = dask.delayed(self.get_marginal)(clique)
tasks.append(marginal)

marginals = dask.compute(*tasks)

measurements = [
(sparse.eye(marginal.size), marginal, 1e-12, clique)
for marginal, clique in zip(marginals, cliques)
if marginal is not None
]

return measurements

0 comments on commit 8eeb979

Please sign in to comment.