Skip to content

Commit

Permalink
Add dim name to families with categorical response levels
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Jul 4, 2024
1 parent d2904f1 commit 20acf42
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
GP_KERNELS,
)
from bambi.families.multivariate import MultivariateFamily, Multinomial, DirichletMultinomial
from bambi.families.univariate import Categorical
from bambi.families.univariate import Categorical, Cumulative, StoppingRatio
from bambi.priors import Prior


Expand Down Expand Up @@ -234,7 +234,17 @@ def build(self, pymc_backend, bmb_model):
# Auxiliary parameters and data
kwargs = {"observed": data, "dims": ("__obs__",)}

if isinstance(self.family, (MultivariateFamily, Categorical)):
if isinstance(
self.family,
(
MultivariateFamily,
Categorical,
Cumulative,
StoppingRatio,
Multinomial,
DirichletMultinomial,
),
):
response_term = bmb_model.response_component.term
response_name = response_term.alias or response_term.name
dim_name = response_name + "_dim"
Expand Down Expand Up @@ -386,10 +396,6 @@ def robustify_dims(self, pymc_backend, kwargs):
# In this case, we add extra dimensions to avoid having shape mismatch between the data
# and the shape implied by the `dims` we pass.

# Don't do it for the Multinomial families (it's an exception)
if isinstance(self.family, (Multinomial, DirichletMultinomial)):
return kwargs

if (
self.term.is_censored
or self.term.is_truncated
Expand Down

0 comments on commit 20acf42

Please sign in to comment.