Skip to content

Commit

Permalink
Implement marginal getter
Browse files Browse the repository at this point in the history
  • Loading branch information
daffidwilde committed Nov 14, 2023
1 parent 315393c commit 81f54bc
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/centhesus/mst.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,43 @@ def get_domain(self):
domain = Domain.fromdict({**area_type_domain, **dimension_domain})

return domain

def get_marginal(self, clique, flatten=True):
    """
    Retrieve the marginal table for a clique from the API.

    The table is queried for the clique's non-area columns, grouped by
    every column in the clique, and the ``count`` column is summed
    within each group. Only the marginal itself is returned.

    Parameters
    ----------
    clique : tuple of str
        Tuple defining the columns of the clique to be measured.
        Should be of the form `(col,)` or `(col1, col2)`.
    flatten : bool
        Whether the marginal should be flattened or not. Default is
        `True` to work with `mbi`. Flattened marginals are NumPy
        arrays rather than Pandas series. Any truthy value (for
        example `numpy.True_`) requests flattening.

    Returns
    -------
    marginal : numpy.ndarray or pandas.Series
        Marginal table. If `flatten` is truthy, this is a flat array.
        Otherwise, the indexed series is returned.
    """

    # Fall back to the "nat" area type when none has been set.
    area_type = self.area_type or "nat"

    # The area type is passed to the query separately, so it is left
    # out of the dimensions. A clique made up only of the area type
    # would leave nothing to request, so the first configured
    # dimension is used instead.
    dimensions = [col for col in clique if col != area_type]
    if not dimensions:
        dimensions = self.dimensions[0:1]

    # NOTE(review): presumably `query_table` returns a pandas DataFrame
    # with a "count" column -- confirm against the API wrapper.
    table = self.api.query_table(
        self.population_type, area_type, dimensions
    )
    marginal = table.groupby(list(clique))["count"].sum()

    # Truthiness rather than `is True`, so that NumPy booleans and
    # other truthy flags are honoured instead of silently ignored.
    if flatten:
        marginal = marginal.to_numpy().flatten()

    return marginal

0 comments on commit 81f54bc

Please sign in to comment.