From 1559a97173900c8c503e462bd0da4d3827bff32d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Capretto?= Date: Mon, 16 Dec 2024 05:41:47 -0600 Subject: [PATCH] Handle multivariate responses with HSGP (#856) * Make HSGP terms aware of multivariate responses * Make sure two dimensional outputs have two dims * Remove redundant classes from checks * Remove prints and add comments * Remove commented code --- bambi/backend/model_components.py | 6 +++--- bambi/backend/terms.py | 33 +++++++++++++++++++------------ 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/bambi/backend/model_components.py b/bambi/backend/model_components.py index 7ba3abee..996bd03c 100644 --- a/bambi/backend/model_components.py +++ b/bambi/backend/model_components.py @@ -53,7 +53,7 @@ def build(self, pymc_backend, bmb_model): self.build_intercept(bmb_model) self.build_offsets() self.build_common_terms(pymc_backend, bmb_model) - self.build_hsgp_terms(pymc_backend) + self.build_hsgp_terms(bmb_model, pymc_backend) self.build_group_specific_terms(pymc_backend, bmb_model) def build_intercept(self, bmb_model): @@ -109,7 +109,7 @@ def build_common_terms(self, pymc_backend, bmb_model): # Add term to linear predictor self.output += pt.dot(data, coefs) - def build_hsgp_terms(self, pymc_backend): + def build_hsgp_terms(self, bmb_model, pymc_backend): """Add HSGP (Hilbert-Space Gaussian Process approximation) terms to the PyMC model. The linear predictor 'X @ b + Z @ u' can be augmented with non-parametric HSGP terms @@ -120,7 +120,7 @@ def build_hsgp_terms(self, pymc_backend): for name, values in hsgp_term.coords.items(): if name not in pymc_backend.model.coords: pymc_backend.model.add_coords({name: values}) - self.output += hsgp_term.build() + self.output += hsgp_term.build(bmb_model) def build_group_specific_terms(self, pymc_backend, bmb_model): """Add group-specific (random or varying) terms to the PyMC model diff --git a/bambi/backend/terms.py b/bambi/backend/terms.py index 8da04ed1..e1265fee 100644 --- a/bambi/backend/terms.py +++ b/bambi/backend/terms.py @@ -12,7 +12,7 @@ make_weighted_distribution, GP_KERNELS, ) -from bambi.families.multivariate import MultivariateFamily, Multinomial, DirichletMultinomial +from bambi.families.multivariate import MultivariateFamily from bambi.families.univariate import Categorical, Cumulative, StoppingRatio from bambi.priors import Prior @@ -234,22 +234,16 @@ def build(self, pymc_backend, bmb_model): # Auxiliary parameters and data kwargs = {"observed": data, "dims": ("__obs__",)} - if isinstance( - self.family, - ( - MultivariateFamily, - Categorical, - Cumulative, - StoppingRatio, - Multinomial, - DirichletMultinomial, - ), - ): + if isinstance(self.family, (MultivariateFamily, Categorical, Cumulative, StoppingRatio)): response_term = bmb_model.response_component.term response_name = response_term.alias or response_term.name dim_name = response_name + "_dim" pymc_backend.model.add_coords({dim_name: response_term.levels}) dims = ("__obs__", dim_name) + + # For multivariate families, the outcome variable has two dimensions too. + if isinstance(self.family, MultivariateFamily): + kwargs["dims"] = dims else: dims = ("__obs__",) @@ -447,7 +441,7 @@ def __init__(self, term): if self.term.by_levels is not None: self.coords[f"{self.term.alias}_by"] = self.coords.pop(f"{self.term.name}_by") - def build(self): + def build(self, spec): # Get the name of the term label = self.name @@ -507,6 +501,19 @@ def build(self): phi = phi.eval() # Build weights coefficient + # Handle the case where the outcome is multivariate + if isinstance(spec.family, (MultivariateFamily, Categorical)): + # Append the dims of the response variables to the coefficient and contribution dims + # In general: + # coeff_dims: ('weights_dim', ) -> ('weights_dim', f'{response}_dim') + # contribution_dims: ('__obs__', ) -> ('__obs__', f'{response}_dim') + response_dims = tuple(spec.response_component.term.coords) + coeff_dims = coeff_dims + response_dims + contribution_dims = contribution_dims + response_dims + + # Append a dimension to sqrt_psd: ('weights_dim', ) -> ('weights_dim', 1) + sqrt_psd = sqrt_psd[:, np.newaxis] + if self.term.centered: coeffs = pm.Normal(f"{label}_weights", sigma=sqrt_psd, dims=coeff_dims) else: