diff --git a/.github/workflows/run_fast_tests.yml b/.github/workflows/run_fast_tests.yml
index 431a3167..e9f7370d 100644
--- a/.github/workflows/run_fast_tests.yml
+++ b/.github/workflows/run_fast_tests.yml
@@ -13,7 +13,7 @@ jobs:
     strategy:
       fail-fast: true
       matrix:
-        python-version: ["3.10", "3.11"]
+        python-version: ["3.10", "3.11", "3.12"]

     steps:
       - name: Checkout repository
diff --git a/.github/workflows/run_slow_tests.yml b/.github/workflows/run_slow_tests.yml
index 64542474..72e4c961 100644
--- a/.github/workflows/run_slow_tests.yml
+++ b/.github/workflows/run_slow_tests.yml
@@ -13,7 +13,7 @@ jobs:
     strategy:
       fail-fast: true
       matrix:
-        python-version: ["3.10", "3.11"]
+        python-version: ["3.10", "3.11", "3.12"]

     steps:
       - name: Checkout repository
diff --git a/docs/tutorials/pymc.ipynb b/docs/tutorials/pymc.ipynb
index b55cf395..ceb7d6ac 100644
--- a/docs/tutorials/pymc.ipynb
+++ b/docs/tutorials/pymc.ipynb
@@ -659,7 +659,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.8"
+   "version": "3.11.9"
   }
  },
  "nbformat": 4,
diff --git a/pyproject.toml b/pyproject.toml
index f9a635f7..a238ce85 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,16 +15,17 @@ repository = "https://github.com/lnccbrown/HSSM"
 keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"]

 [tool.poetry.dependencies]
-python = ">=3.10,<3.12"
+python = ">=3.10,<3.13"
 pymc = ">=5.16.2,<5.17.0"
 arviz = "^0.19.0"
 onnx = "^1.16.0"
-ssm-simulators = "^0.7.2"
+ssm-simulators = "^0.7.5"
 huggingface-hub = "^0.24.6"
 bambi = ">=0.14.0,<0.15.0"
 numpyro = "^0.15.2"
 hddm-wfpt = "^0.1.4"
 seaborn = "^0.13.2"
+tqdm = "^4.66.0"
 jax = { version = "^0.4.25", extras = ["cuda12"], optional = true }
 numpy = ">=1.26.4,<2.0.0"
diff --git a/src/hssm/config.py b/src/hssm/config.py
index 43167ec6..ad5ac3b2 100644
--- a/src/hssm/config.py
+++ b/src/hssm/config.py
@@ -22,6 +22,7 @@ class Config:
     model_name: SupportedModels | str
     loglik_kind: LoglikKind
     response: list[str] | None = None
+    choices: list[int] | None = None
     list_params: list[str] | None = None
     description: str | None = None
     loglik: LogLik | None = None
@@ -63,6 +64,7 @@ def from_defaults(
                 model_name,
                 loglik_kind=kind,
                 response=default_config["response"],
+                choices=default_config["choices"],
                 list_params=default_config["list_params"],
                 description=default_config["description"],
                 **loglik_config,
@@ -90,6 +92,7 @@ def from_defaults(
                 model_name,
                 loglik_kind=loglik_kind,
                 response=default_config["response"],
+                choices=default_config["choices"],
                 list_params=default_config["list_params"],
                 description=default_config["description"],
                 **loglik_config,
@@ -98,6 +101,7 @@ def from_defaults(
                 model_name,
                 loglik_kind=loglik_kind,
                 response=default_config["response"],
+                choices=default_config["choices"],
                 list_params=default_config["list_params"],
                 description=default_config["description"],
             )
@@ -117,18 +121,33 @@ def update_loglik(self, loglik: Any | None) -> None:

         self.loglik = loglik

+    def update_choices(self, choices: list[int] | None) -> None:
+        """Update the choices from user input.
+
+        Parameters
+        ----------
+        choices : list[int]
+            A list of choices.
+        """
+        if choices is None:
+            return
+
+        self.choices = choices
+
     def update_config(self, user_config: ModelConfig) -> None:
         """Update the object from a ModelConfig object.

         Parameters
         ----------
-        loglik : optional
-            A user-defined log-likelihood function.
+        user_config : ModelConfig
+            A user-specified ModelConfig used to update self.
""" if user_config.response is not None: self.response = user_config.response if user_config.list_params is not None: self.list_params = user_config.list_params + if user_config.choices is not None: + self.choices = user_config.choices if ( self.loglik_kind == "approx_differentiable" @@ -146,6 +165,8 @@ def validate(self) -> None: raise ValueError("Please provide `response` via `model_config`.") if self.list_params is None: raise ValueError("Please provide `list_params` via `model_config`.") + if self.choices is None: + raise ValueError("Please provide `choices` via `model_config`.") if self.loglik is None: raise ValueError("Please provide a log-likelihood function via `loglik`.") if self.loglik_kind == "approx_differentiable" and self.backend is None: @@ -170,6 +191,7 @@ class ModelConfig: response: list[str] | None = None list_params: list[str] | None = None + choices: list[int] | None = None default_priors: dict[str, ParamSpec] = field(default_factory=dict) bounds: dict[str, tuple[float, float]] = field(default_factory=dict) backend: Literal["jax", "pytensor"] | None = None diff --git a/src/hssm/defaults.py b/src/hssm/defaults.py index 5e6d090d..b46f7782 100644 --- a/src/hssm/defaults.py +++ b/src/hssm/defaults.py @@ -14,8 +14,14 @@ ddm_params, ddm_sdv_bounds, ddm_sdv_params, + lba2_bounds, + lba2_params, + lba3_bounds, + lba3_params, logp_ddm, logp_ddm_sdv, + logp_lba2, + logp_lba3, ) from .likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox, logp_full_ddm from .param import ParamSpec, _make_default_prior @@ -32,6 +38,8 @@ "weibull", "race_no_bias_angle_4", "ddm_seq2_no_bias", + "lba3", + "lba2", ] LoglikKind = Literal["analytical", "approx_differentiable", "blackbox"] @@ -72,6 +80,7 @@ class DefaultConfig(TypedDict): response: list[str] list_params: list[str] + choices: list[int] description: Optional[str] likelihoods: LoglikConfigs @@ -82,6 +91,7 @@ class DefaultConfig(TypedDict): "ddm": { "response": ["rt", "response"], "list_params": ddm_params, + "choices": [-1, 1], "description": "The Drift Diffusion Model (DDM)", "likelihoods": { "analytical": { @@ -130,6 +140,7 @@ class DefaultConfig(TypedDict): "ddm_sdv": { "response": ["rt", "response"], "list_params": ddm_sdv_params, + "choices": [-1, 1], "description": "The Drift Diffusion Model (DDM) with standard deviation for v", "likelihoods": { "analytical": { @@ -179,6 +190,7 @@ class DefaultConfig(TypedDict): "full_ddm": { "response": ["rt", "response"], "list_params": ["v", "a", "z", "t", "sz", "sv", "st"], + "choices": [-1, 1], "description": "The full Drift Diffusion Model (DDM)", "likelihoods": { "blackbox": { @@ -195,9 +207,40 @@ class DefaultConfig(TypedDict): } }, }, + "lba2": { + "response": ["rt", "response"], + "list_params": lba2_params, + "choices": [0, 1], + "description": "Linear Ballistic Accumulator 2 Choices (LBA2)", + "likelihoods": { + "analytical": { + "loglik": logp_lba2, + "backend": None, + "default_priors": {}, + "bounds": lba2_bounds, + "extra_fields": None, + } + }, + }, + "lba3": { + "response": ["rt", "response"], + "list_params": lba3_params, + "choices": [0, 1, 2], + "description": "Linear Ballistic Accumulator 3 Choices (LBA3)", + "likelihoods": { + "analytical": { + "loglik": logp_lba3, + "backend": None, + "default_priors": {}, + "bounds": lba3_bounds, + "extra_fields": None, + } + }, + }, "angle": { "response": ["rt", "response"], "list_params": ["v", "a", "z", "t", "theta"], + "choices": [-1, 1], "description": None, "likelihoods": { "approx_differentiable": { @@ -218,6 +261,7 @@ class 
     "levy": {
         "response": ["rt", "response"],
         "list_params": ["v", "a", "z", "alpha", "t"],
+        "choices": [-1, 1],
         "description": None,
         "likelihoods": {
             "approx_differentiable": {
@@ -238,6 +282,7 @@ class DefaultConfig(TypedDict):
     "ornstein": {
         "response": ["rt", "response"],
         "list_params": ["v", "a", "z", "g", "t"],
+        "choices": [-1, 1],
         "description": None,
         "likelihoods": {
             "approx_differentiable": {
@@ -258,6 +303,7 @@ class DefaultConfig(TypedDict):
     "weibull": {
         "response": ["rt", "response"],
         "list_params": ["v", "a", "z", "t", "alpha", "beta"],
+        "choices": [-1, 1],
         "description": None,
         "likelihoods": {
             "approx_differentiable": {
@@ -279,6 +325,7 @@ class DefaultConfig(TypedDict):
     "race_no_bias_angle_4": {
         "response": ["rt", "response"],
         "list_params": ["v0", "v1", "v2", "v3", "a", "z", "t", "theta"],
+        "choices": [0, 1, 2, 3],
         "description": None,
         "likelihoods": {
             "approx_differentiable": {
@@ -302,6 +349,7 @@ class DefaultConfig(TypedDict):
     "ddm_seq2_no_bias": {
         "response": ["rt", "response"],
         "list_params": ["vh", "vl1", "vl2", "a", "t"],
+        "choices": [0, 1, 2, 3],
         "description": None,
         "likelihoods": {
             "approx_differentiable": {
diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py
index 93ba9276..36ff6d2e 100644
--- a/src/hssm/hssm.py
+++ b/src/hssm/hssm.py
@@ -7,6 +7,7 @@
 """

 import logging
+import typing
 from copy import deepcopy
 from inspect import isclass
 from os import PathLike
@@ -45,6 +46,7 @@
     _make_default_prior,
 )
 from hssm.utils import (
+    _compute_log_likelihood,
     _get_alias_dict,
     _print_prior,
     _process_param_in_kwargs,
@@ -236,7 +238,7 @@ def __init__(
         self,
         data: pd.DataFrame,
         model: SupportedModels | str = "ddm",
-        choices: int | list[int] = 2,
+        choices: list[int] | None = None,
         include: list[dict | Param] | None = None,
         model_config: ModelConfig | dict | None = None,
         loglik: (
@@ -282,11 +284,46 @@ def __init__(
         self.model_config = Config.from_defaults(model, loglik_kind)
         # Update defaults with user-provided config, if any
         if model_config is not None:
+            if isinstance(model_config, dict):
+                if "choices" not in model_config:
+                    if choices is not None:
+                        model_config["choices"] = choices
+                else:
+                    if choices is not None:
+                        _logger.info(
+                            "`choices` was provided both in model_config and as "
+                            "a direct argument; using the one in model_config.\n"
+                            "We recommend providing choices via model_config."
+                        )
+            elif isinstance(model_config, ModelConfig):
+                if model_config.choices is None:
+                    if choices is not None:
+                        model_config.choices = choices
+                else:
+                    if choices is not None:
+                        _logger.info(
+                            "`choices` was provided both in model_config and as "
+                            "a direct argument; using the one in model_config.\n"
+                            "We recommend providing choices via model_config."
+                        )
+
             self.model_config.update_config(
                 model_config
                 if isinstance(model_config, ModelConfig)
                 else ModelConfig(**model_config)  # also serves as dict validation
             )
+        else:
+            if model not in typing.get_args(SupportedModels):
+                if choices is not None:
+                    self.model_config.update_choices(choices)
+            else:
+                if choices is not None:
+                    _logger.info(
+                        "The model string is in SupportedModels; ignoring the "
+                        "`choices` argument."
+                    )

         # Update loglik with user-provided value
         self.model_config.update_loglik(loglik)
@@ -296,26 +333,17 @@ def __init__(

         # Set up shortcuts so old code will work
         self.response = self.model_config.response
         self.list_params = self.model_config.list_params
+        self.choices = self.model_config.choices
         self.model_name = self.model_config.model_name
         self.loglik = self.model_config.loglik
         self.loglik_kind = self.model_config.loglik_kind
         self.extra_fields = self.model_config.extra_fields

-        if isinstance(choices, int):
-            if choices == 2:
-                self.n_choices = 2
-                self.choices = [-1, 1]
-            elif choices > 2:
-                self.n_choices = choices
-                self.choices = list(range(choices))
-            else:
-                raise ValueError("choices must be greater than 1.")
-        elif isinstance(choices, list):
-            self.n_choices = len(choices)
-            self.choices = choices
-        else:
-            raise ValueError("choices must be an integer or a list of integers.")
+        assert (
+            self.choices is not None
+        ), "choices must be provided either in model_config or as an argument."
+        self.n_choices = len(self.choices)

         self._pre_check_data_sanity()

         # Go-NoGo
@@ -405,7 +433,7 @@ def __init__(
             self.formula,
             data=self.data,
             family=self.family,
-            priors=self.priors,
+            priors=self.priors,  # center_predictors=False
             extra_namespace=self.additional_namespace,
             **other_kwargs,
         )
@@ -425,6 +453,11 @@ def __init__(
             vector_only=True,
         )

+        # Reset rvs_to_initial_values so every RV maps to None;
+        # otherwise PyMC complains when we later ask it to compute likelihoods.
+        self.pymc_model.rvs_to_initial_values.update(
+            {key_: None for key_ in self.pymc_model.rvs_to_initial_values.keys()}
+        )
         _logger.info("Model initialized successfully.")

     def find_MAP(self, **kwargs):
@@ -594,7 +627,7 @@ def sample(
                 "pymc" if sampler == "mcmc" else sampler.split("_")[1]
             )

-        # Don't compute likelihood directly through pymc sampler
+        # Decide whether the log-likelihood should be computed
         compute_likelihood = True
         if "idata_kwargs" in kwargs:
             if "log_likelihood" in kwargs["idata_kwargs"]:
@@ -613,22 +646,12 @@ def sample(
             **kwargs,
         )

+        # Separate out log-likelihood computation
         if compute_likelihood:
-            with self.pymc_model:
-                pm.compute_log_likelihood(self._inference_obj)
+            self.log_likelihood(self._inference_obj, inplace=True)

         # Subset data vars in posterior
-        if self._inference_obj is not None:
-            vars_to_keep = set(
-                [var.name for var in getattr(self, "pymc_model").free_RVs]
-            ).intersection(set(list(self._inference_obj["posterior"].data_vars.keys())))
-
-            setattr(
-                self._inference_obj,
-                "posterior",
-                self._inference_obj["posterior"][list(vars_to_keep)],
-            )
-
+        self._clean_posterior_group(idata=self._inference_obj)
         return self.traces

     def vi(
@@ -684,18 +707,7 @@ def vi(
         self._inference_obj_vi = self._vi_approx.sample(draws)

         # Post-processing
-        if self._inference_obj_vi is not None:
-            vars_to_keep = set(
-                [var.name for var in self.pymc_model.free_RVs]
-            ).intersection(
-                set(list(self._inference_obj_vi["posterior"].data_vars.keys()))
-            )
-
-            setattr(
-                self._inference_obj_vi,
-                "posterior",
-                self._inference_obj_vi["posterior"][list(vars_to_keep)],
-            )
+        self._clean_posterior_group(idata=self._inference_obj_vi)

         # Return the InferenceData object if return_idata is True
         if return_idata:
@@ -703,6 +715,142 @@ def vi(
         # Otherwise return the approximation object directly
         return self.vi_approx

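The two methods added below decouple log-likelihood computation from sampling. A minimal usage sketch under the new API (the dataset and the sampler settings are placeholders):

```python
import hssm

model = hssm.HSSM(data=data, model="ddm")  # `data` assumed to hold rt/response columns
# Skip the in-sampler likelihood computation ...
idata = model.sample(draws=500, tune=500, idata_kwargs={"log_likelihood": False})
# ... then recompute it afterwards from the stored posterior.
model.log_likelihood(idata, inplace=True)
```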
+    def _clean_posterior_group(self, idata: az.InferenceData | None = None):
+        """Clean up the posterior group of the InferenceData object.
+
+        Parameters
+        ----------
+        idata : az.InferenceData
+            The InferenceData object to clean up. A ValueError is raised if it
+            is None or has no posterior group.
+        """
+        # Logic behind which variables to keep:
+        # We essentially want to get rid of all the trial-wise variables.
+        # We drop distributional components IF they are deterministics
+        # (in which case they will systematically be trial-wise),
+        # and we keep distributional components IF they are
+        # basic random variables (in which case they should never
+        # appear trial-wise).
+        if idata is None:
+            raise ValueError(
+                "The InferenceData object is None. Cannot clean up the posterior group."
+            )
+        elif not hasattr(idata, "posterior"):
+            raise ValueError(
+                "The InferenceData object does not have a posterior group. "
+                + "Cannot clean up the posterior group."
+            )
+
+        vars_to_keep = set(idata["posterior"].data_vars.keys()).difference(
+            set(
+                key_
+                for key_ in self.model.distributional_components.keys()
+                if key_ in [var_.name for var_ in self.pymc_model.deterministics]
+            )
+        )
+        vars_to_keep_clean = [var_ for var_ in vars_to_keep if "_mean" not in var_]
+
+        setattr(
+            idata,
+            "posterior",
+            idata["posterior"][vars_to_keep_clean],
+        )
+
+    def log_likelihood(
+        self,
+        idata: az.InferenceData | None = None,
+        data: pd.DataFrame | None = None,
+        inplace: bool = True,
+        keep_likelihood_params: bool = False,
+    ) -> az.InferenceData | None:
+        """Compute the log likelihood of the model.
+
+        Parameters
+        ----------
+        idata : optional
+            The `InferenceData` object returned by `HSSM.sample()`. If not
+            provided, the `InferenceData` from the last call to `sample()`
+            is used.
+        data : optional
+            A pandas DataFrame with values for the predictors that are used to obtain
+            out-of-sample predictions. If omitted, the original dataset is used.
+        inplace : optional
+            If `True`, will modify idata in-place and append a `log_likelihood` group
+            to `idata`. Otherwise, it will return a copy of idata with the predictions
+            added, by default True.
+        keep_likelihood_params : optional
+            If `True`, the trial-wise likelihood parameters that are computed
+            en route to the log likelihood are kept in the `idata` object.
+            Defaults to False. See also the method `add_likelihood_parameters_to_idata`.
+
+        Returns
+        -------
+        az.InferenceData | None
+            InferenceData or None
+        """
+        if self._inference_obj is None and idata is None:
+            raise ValueError(
+                "The model has not been sampled yet, "
+                + "and no idata object was provided."
+            )
+
+        if idata is None:
+            # Guaranteed to exist by the check above
+            idata = self._inference_obj
+
+        # Actual likelihood computation
+        idata = _compute_log_likelihood(self.model, idata, data, inplace)
+
+        # Clean up posterior
+        if not keep_likelihood_params:
+            self._clean_posterior_group(idata=idata)
+
+        if inplace:
+            return None
+        else:
+            return idata
+
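The `data` argument above enables out-of-sample pointwise log-likelihoods; a short sketch (`heldout` is a hypothetical split of the original data that still contains the response columns):

```python
heldout = data.sample(frac=0.2, random_state=0)  # hypothetical held-out rows
idata_oos = model.log_likelihood(idata, data=heldout, inplace=False)
# idata_oos["log_likelihood"] now holds pointwise values for `heldout`
```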
+    def add_likelihood_parameters_to_idata(
+        self,
+        idata: az.InferenceData | None = None,
+        inplace: bool = False,
+    ) -> az.InferenceData | None:
+        """Add likelihood parameters to the InferenceData object.
+
+        Parameters
+        ----------
+        idata : az.InferenceData
+            The InferenceData object returned by HSSM.sample(). If None,
+            the traces stored on the model are used.
+        inplace : bool
+            If True, the likelihood parameters are added to idata in-place.
+            Otherwise, a copy of idata with the likelihood parameters added
+            is returned. Defaults to False.
+
+        Returns
+        -------
+        az.InferenceData | None
+            InferenceData or None
+        """
+        if idata is None:
+            if self._inference_obj is None:
+                raise ValueError("No idata provided and model not yet sampled!")
+            else:
+                idata = self.model._compute_likelihood_params(  # pylint: disable=protected-access
+                    deepcopy(self._inference_obj)
+                    if not inplace
+                    else self._inference_obj
+                )
+        else:
+            idata = self.model._compute_likelihood_params(  # pylint: disable=protected-access
+                deepcopy(idata) if not inplace else idata
+            )
+        return idata
+
     def sample_posterior_predictive(
         self,
         idata: az.InferenceData | None = None,
@@ -2118,5 +2266,4 @@ def _set_missing_data_and_deadline(
             "`missing_data` and `deadline` are both set to True, but you have no "
             + "missing data and/or no rts exceeding the deadline."
         )
-
     return network
diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py
index 445ad80c..d626d26f 100644
--- a/src/hssm/likelihoods/analytical.py
+++ b/src/hssm/likelihoods/analytical.py
@@ -374,7 +374,7 @@ def logp_ddm_sdv(
 DDM: Type[pm.Distribution] = make_distribution(
     "ddm",
     logp_ddm,
-    list_params=["v", "a", "z", "t"],
+    list_params=ddm_params,
     bounds=ddm_bounds,
 )

@@ -384,3 +384,155 @@ def logp_ddm_sdv(
     list_params=ddm_sdv_params,
     bounds=ddm_sdv_bounds,
 )
+
+# LBA
+
+
+def _pt_normpdf(t):
+    """Standard normal density."""
+    return (1 / pt.sqrt(2 * pt.pi)) * pt.exp(-(t**2) / 2)
+
+
+def _pt_normcdf(t):
+    """Standard normal CDF."""
+    return (1 / 2) * (1 + pt.erf(t / pt.sqrt(2)))
+
+
+def _pt_tpdf(t, A, b, v, s):
+    """First-passage density of a single LBA accumulator."""
+    g = (b - A - t * v) / (t * s)
+    h = (b - t * v) / (t * s)
+    f = (
+        -v * _pt_normcdf(g)
+        + s * _pt_normpdf(g)
+        + v * _pt_normcdf(h)
+        - s * _pt_normpdf(h)
+    ) / A
+    return f
+
+
+def _pt_tcdf(t, A, b, v, s):
+    """First-passage CDF of a single LBA accumulator."""
+    e1 = ((b - A - t * v) / A) * _pt_normcdf((b - A - t * v) / (t * s))
+    e2 = ((b - t * v) / A) * _pt_normcdf((b - t * v) / (t * s))
+    e3 = ((t * s) / A) * _pt_normpdf((b - A - t * v) / (t * s))
+    e4 = ((t * s) / A) * _pt_normpdf((b - t * v) / (t * s))
+    F = 1 + e1 - e2 + e3 - e4
+    return F
+
+
+def _pt_lba3_ll(t, ch, A, b, v0, v1, v2):
+    s = 0.1  # fixed drift-rate noise
+    __min = pt.exp(LOGP_LB)
+    __max = pt.exp(-LOGP_LB)
+    running_idx = pt.arange(t.shape[0])
+
+    # Defective densities: accumulator i finishes at t while the others have not.
+    like_1 = (
+        _pt_tpdf(t, A, b, v0, s)
+        * (1 - _pt_tcdf(t, A, b, v1, s))
+        * (1 - _pt_tcdf(t, A, b, v2, s))
+    )
+    like_2 = (
+        (1 - _pt_tcdf(t, A, b, v0, s))
+        * _pt_tpdf(t, A, b, v1, s)
+        * (1 - _pt_tcdf(t, A, b, v2, s))
+    )
+    like_3 = (
+        (1 - _pt_tcdf(t, A, b, v0, s))
+        * (1 - _pt_tcdf(t, A, b, v1, s))
+        * _pt_tpdf(t, A, b, v2, s)
+    )
+
+    like = pt.stack([like_1, like_2, like_3], axis=-1)
+
+    # Debugging hint: to inspect `like`, wrap it with
+    # pytensor.printing.Print("like") and RETURN the wrapped value;
+    # otherwise the print op is pruned from the graph.
+
+    # Condition on at least one accumulator having a positive drift
+    prob_neg = _pt_normcdf(-v0 / s) * _pt_normcdf(-v1 / s) * _pt_normcdf(-v2 / s)
+    return pt.log(pt.clip(like[running_idx, ch] / (1 - prob_neg), __min, __max))
+
+
+def _pt_lba2_ll(t, ch, A, b, v0, v1):
+    s = 0.1  # fixed drift-rate noise
+    __min = pt.exp(LOGP_LB)
+    __max = pt.exp(-LOGP_LB)
+    running_idx = pt.arange(t.shape[0])
+
+    # Defective densities: accumulator i finishes at t while the other has not.
+    like_1 = _pt_tpdf(t, A, b, v0, s) * (1 - _pt_tcdf(t, A, b, v1, s))
+    like_2 = (1 - _pt_tcdf(t, A, b, v0, s)) * _pt_tpdf(t, A, b, v1, s)
+
+    like = pt.stack([like_1, like_2], axis=-1)
+
+    # Condition on at least one accumulator having a positive drift
+    prob_neg = _pt_normcdf(-v0 / s) * _pt_normcdf(-v1 / s)
+    return pt.log(pt.clip(like[running_idx, ch] / (1 - prob_neg), __min, __max))
+
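The accumulator-level pieces above follow the standard LBA first-passage density and CDF; a NumPy/SciPy mirror of `_pt_tpdf` is useful for spot-checking the pytensor graph (a sketch, assuming SciPy is available):

```python
import numpy as np
from scipy.stats import norm

def lba_tpdf_np(t, A, b, v, s):
    """NumPy mirror of _pt_tpdf for sanity checks."""
    g = (b - A - t * v) / (t * s)
    h = (b - t * v) / (t * s)
    return (-v * norm.cdf(g) + s * norm.pdf(g) + v * norm.cdf(h) - s * norm.pdf(h)) / A
```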
+
+def logp_lba2(
+    data: np.ndarray,
+    A: float,
+    b: float,
+    v0: float,
+    v1: float,
+) -> np.ndarray:
+    """Compute the log-likelihood of the LBA model with 2 drift rates."""
+    data = pt.reshape(data, (-1, 2)).astype(pytensor.config.floatX)
+    rt = pt.abs(data[:, 0])
+    response = data[:, 1]
+    response_int = pt.cast(response, "int32")
+    logp = _pt_lba2_ll(rt, response_int, A, b, v0, v1).squeeze()
+    checked_logp = check_parameters(logp, b > A, msg="b > A")
+    return checked_logp
+
+
+def logp_lba3(
+    data: np.ndarray,
+    A: float,
+    b: float,
+    v0: float,
+    v1: float,
+    v2: float,
+) -> np.ndarray:
+    """Compute the log-likelihood of the LBA model with 3 drift rates."""
+    data = pt.reshape(data, (-1, 2)).astype(pytensor.config.floatX)
+    rt = pt.abs(data[:, 0])
+    response = data[:, 1]
+    response_int = pt.cast(response, "int32")
+    logp = _pt_lba3_ll(rt, response_int, A, b, v0, v1, v2).squeeze()
+    checked_logp = check_parameters(logp, b > A, msg="b > A")
+    return checked_logp
+
+
+lba2_params = ["A", "b", "v0", "v1"]
+lba3_params = ["A", "b", "v0", "v1", "v2"]
+
+lba2_bounds = {
+    "A": (0.0, inf),
+    "b": (0.2, inf),
+    "v0": (0.0, inf),
+    "v1": (0.0, inf),
+}
+
+lba3_bounds = {
+    "A": (0.0, inf),
+    "b": (0.2, inf),
+    "v0": (0.0, inf),
+    "v1": (0.0, inf),
+    "v2": (0.0, inf),
+}
+
+LBA2: Type[pm.Distribution] = make_distribution(
+    "lba2",
+    logp_lba2,
+    list_params=lba2_params,
+    bounds=lba2_bounds,
+)
+
+LBA3: Type[pm.Distribution] = make_distribution(
+    "lba3",
+    logp_lba3,
+    list_params=lba3_params,
+    bounds=lba3_bounds,
+)
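End to end, the new distributions are reachable through the `lba2`/`lba3` model strings registered in `defaults.py`; a minimal sketch mirroring the slow test added further below:

```python
import hssm

data = hssm.simulate_data(
    model="lba2", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0), size=500
)
model = hssm.HSSM(model="lba2", data=data)  # choices [0, 1] come from the defaults
idata = model.sample(sampler="nuts_numpyro", draws=100, tune=100, chains=1)
```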
diff --git a/src/hssm/utils.py b/src/hssm/utils.py
index cb2a154f..79adf6cb 100644
--- a/src/hssm/utils.py
+++ b/src/hssm/utils.py
@@ -9,17 +9,24 @@
 _parse_bambi().
 """

+import itertools
 import logging
+from copy import deepcopy
 from typing import Any, Literal, cast

+import arviz as az
 import bambi as bmb
 import jax
 import numpy as np
 import pandas as pd
+import pymc as pm
 import pytensor
+import pytensor.tensor as pt
 import xarray as xr
 from bambi.terms import CommonTerm, GroupSpecificTerm, HSGPTerm, OffsetTerm
+from bambi.utils import get_aliased_name, response_evaluate_new_data
 from huggingface_hub import hf_hub_download
+from tqdm import tqdm

 from .param import Param

@@ -146,6 +153,202 @@ def _get_alias_dict(
     return alias_dict


+def _compute_log_likelihood(
+    model: bmb.Model,
+    idata: az.InferenceData,
+    data: pd.DataFrame | None,
+    inplace: bool = True,
+) -> az.InferenceData | None:
+    """Compute the model's log-likelihood.
+
+    Parameters
+    ----------
+    idata : InferenceData
+        The `InferenceData` instance returned by `.fit()`.
+    data : pandas.DataFrame or None
+        An optional data frame with values for the predictors and the response on which
+        the model's log-likelihood function is evaluated.
+        If omitted, the original dataset is used.
+    inplace : bool
+        If `True`, it will modify `idata` in-place. Otherwise, it will return a copy of
+        `idata` with the `log_likelihood` group added.
+
+    Returns
+    -------
+    InferenceData or None
+    """
+    # These are not formal parameters because it does not make sense to
+    # 1. compute the log-likelihood omitting
+    #    the group-specific components of the model, or
+    # 2. compute the log-likelihood on unseen groups.
+    include_group_specific = True
+    sample_new_groups = False
+
+    # Get the aliased response name
+    response_aliased_name = get_aliased_name(model.response_component.term)
+
+    if not inplace:
+        idata = deepcopy(idata)
+
+    # Populate the posterior in the InferenceData object
+    # with the likelihood parameters
+    idata = model._compute_likelihood_params(  # pylint: disable=protected-access
+        idata, data, include_group_specific, sample_new_groups
+    )
+
+    required_kwargs = {"model": model, "posterior": idata["posterior"], "data": data}
+    log_likelihood_out = log_likelihood(model.family, **required_kwargs).to_dataset(
+        name=response_aliased_name
+    )
+
+    # Drop the existing log_likelihood group if it exists
+    if "log_likelihood" in idata:
+        _logger.info("Replacing existing log_likelihood group in idata.")
+        del idata["log_likelihood"]
+
+    # Assign the log-likelihood group to the InferenceData object
+    idata.add_groups({"log_likelihood": log_likelihood_out})
+    setattr(
+        idata,
+        "log_likelihood",
+        idata["log_likelihood"].assign_attrs(
+            modeling_interface="bambi", modeling_interface_version=bmb.__version__
+        ),
+    )
+    return idata
+
+
+def log_likelihood(
+    family: bmb.Family,
+    model: bmb.Model,
+    posterior: xr.Dataset,
+    data: pd.DataFrame | None = None,
+    **kwargs,
+) -> xr.DataArray:
+    """Evaluate the model log-likelihood.
+
+    This is a variation on the `bambi.utils.log_likelihood` function that
+    loops over the chains and draws to evaluate the log-likelihood for each,
+    instead of attempting to batch the computation as is done in the original.
+
+    Parameters
+    ----------
+    family : bambi.Family
+        The family of the response distribution.
+    model : bambi.Model
+        The model.
+    posterior : xr.Dataset
+        The xarray dataset that contains the draws for
+        all the parameters in the posterior.
+        It must contain the parameters that are needed
+        in the distribution of the response, or
+        the parameters from which they can be derived.
+    kwargs :
+        Parameters that are used to get draws but do
+        not appear in the posterior object, or
+        other configuration parameters.
+        For instance, the 'n' in binomial and multinomial models.
+
+    Returns
+    -------
+    xr.DataArray
+        A data array with the value of the log-likelihood
+        for each chain, draw, and value of the response variable.
+    """
+    # Child classes pass "y_values" through the "y" kwarg
+    y_values = kwargs.pop("y", None)

+    # Get the values of the outcome variable
+    if y_values is None:  # when it's not handled by the specific family
+        if data is None:
+            y_values = np.squeeze(model.response_component.term.data)
+        else:
+            y_values = response_evaluate_new_data(model, data)
+
+    response_dist = get_response_dist(model.family)
+    response_term = model.response_component.term
+    kwargs, coords = family._make_dist_kwargs_and_coords(model, posterior, **kwargs)
+
+    # If it's multivariate, it's going to have a fourth coord,
+    # but we actually don't need it.
We just need "chain", "draw", "__obs__" + coords = dict(list(coords.items())[:3]) + + n_chains = len(coords["chain"]) + n_draws = len(coords["draw"]) + output_array = np.zeros((n_chains, n_draws, len(y_values))) + kwargs_prep = {key_: val[0][0] for key_, val in kwargs.items()} + shape_dict = {key_: val.shape for key_, val in kwargs_prep.items()} + pt_dict = { + key_: (pt.vector(key_, shape=((1,) if val[0] == 1 else (None,)))) + for key_, val in shape_dict.items() + } + + # Compile likelihood function + if not response_term.is_constrained: + rv_logp = pm.logp(response_dist.dist(**pt_dict), y_values) + logp_compiled = pm.compile_pymc( + [val for key_, val in pt_dict.items()], + rv_logp, + allow_input_downcast=True, + ) + else: + # Bounds are scalars, we can safely pick them from the first row + lower, upper = response_term.data[0, 1:] + lower = lower if lower != -np.inf else None + upper = upper if upper != np.inf else None + + # Finally evaluate logp + rv_logp = pm.logp( + pm.Truncated.dist( + response_dist.dist(**kwargs_prep), lower=lower, upper=upper + ), + y_values, + ) + logp_compiled = pm.compile_pymc( + [val for key_, val in pt_dict.items()], rv_logp, allow_input_downcast=True + ) + + # Loop through chain and draws + for ids in tqdm( + list(itertools.product(coords["chain"].values, coords["draw"].values)) + ): + kwargs_tmp = { + key_: ( + val[ids[0], ids[1], ...] + if (val.shape[0] == n_chains and val.shape[1] == n_draws) + else val[0, 0, ...] + ) + for key_, val in kwargs.items() + } + + output_array[ids[0], ids[1], :] = logp_compiled(**kwargs_tmp) + + # output_array + return xr.DataArray(output_array, coords=coords) + + +def get_response_dist(family: bmb.Family) -> pm.Distribution: + """Get the PyMC distribution for the response. + + Parameters + ---------- + family : bambi.Family + The family for which the response distribution is wanted + + Returns + ------- + pm.Distribution + The response distribution + """ + mapping = {"Cumulative": pm.Categorical, "StoppingRatio": pm.Categorical} + + if family.likelihood.dist: + dist = family.likelihood.dist + elif family.likelihood.name in mapping: + dist = mapping[family.likelihood.name] + else: + dist = getattr(pm, family.likelihood.name) + return dist + + def set_floatX(dtype: Literal["float32", "float64"], update_jax: bool = True): """Set float types for pytensor and Jax. diff --git a/tests/slow/test_mcmc.py b/tests/slow/test_mcmc.py index c92e5843..0dd9c4f7 100644 --- a/tests/slow/test_mcmc.py +++ b/tests/slow/test_mcmc.py @@ -7,6 +7,8 @@ import hssm import numpy as np import pymc as pm +from copy import deepcopy +import xarray as xr from hssm.utils import _rearrange_data @@ -114,16 +116,52 @@ def sample(model, sampler, step): def run_sample(model, sampler, step, expected): - if expected == True: + """Run the sample function and check if the expected error is raised.""" + if expected is True: sample(model, sampler, step) assert isinstance(model.traces, az.InferenceData) + + # make sure log_likelihood computations check out + traces_copy = deepcopy(model.traces) + del traces_copy["log_likelihood"] + + # recomputing log-likelihood yields same results? 
+        model.log_likelihood(traces_copy, inplace=True)
+        assert isinstance(traces_copy, az.InferenceData)
+        assert "log_likelihood" in traces_copy.groups()
+        for group_ in traces_copy.groups():
+            xr.testing.assert_equal(traces_copy[group_], model.traces[group_])
+
     else:
         with pytest.raises(expected):
             sample(model, sampler, step)


+# Basic tests for LBA likelihood
+def test_lba_sampling():
+    """Test if sampling works for the available LBA models."""
+    lba2_data_out = hssm.simulate_data(
+        model="lba2", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0), size=500
+    )
+
+    lba3_data_out = hssm.simulate_data(
+        model="lba3", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0, v2=1.0), size=500
+    )
+
+    lba2_model = hssm.HSSM(model="lba2", data=lba2_data_out)
+
+    lba3_model = hssm.HSSM(model="lba3", data=lba3_data_out)
+
+    traces_2 = lba2_model.sample(sampler="nuts_numpyro", draws=100, tune=100, chains=1)
+    traces_3 = lba3_model.sample(sampler="nuts_numpyro", draws=100, tune=100, chains=1)
+
+    assert isinstance(traces_2, az.InferenceData)
+    assert isinstance(traces_3, az.InferenceData)
+
+
 @pytest.mark.parametrize(parameter_names, parameter_grid)
 def test_simple_models(data_ddm, loglik_kind, backend, sampler, step, expected):
+    """Test simple models."""
     print("PYMC VERSION: ")
     print(pm.__version__)
     print("TEST INPUTS WERE: ")
@@ -133,6 +171,9 @@ def test_simple_models(data_ddm, loglik_kind, backend, sampler, step, expected):
     model = hssm.HSSM(
         data_ddm, loglik_kind=loglik_kind, model_config={"backend": backend}
     )
+    assert np.all(
+        [val_ is None for key_, val_ in model.pymc_model.rvs_to_initial_values.items()]
+    )
     run_sample(model, sampler, step, expected)

     # Only runs once
@@ -154,6 +195,7 @@ def test_simple_models(data_ddm, loglik_kind, backend, sampler, step, expected):
 @pytest.mark.parametrize(parameter_names, parameter_grid)
 def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected):
+    """Test regression models."""
     print("PYMC VERSION: ")
     print(pm.__version__)
     print("TEST INPUTS WERE: ")
@@ -174,6 +216,9 @@ def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected)
         model_config={"backend": backend},
         v=param_reg,
     )
+    assert np.all(
+        [val_ is None for key_, val_ in model.pymc_model.rvs_to_initial_values.items()]
+    )
     run_sample(model, sampler, step, expected)

     # Only runs once
@@ -190,6 +235,7 @@ def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected)
 @pytest.mark.parametrize(parameter_names, parameter_grid)
 def test_reg_models_v_a(data_ddm_reg_va, loglik_kind, backend, sampler, step, expected):
+    """Test regression models with multiple parameters (v, a)."""
     print("PYMC VERSION: ")
     print(pm.__version__)
     print("TEST INPUTS WERE: ")
@@ -224,6 +270,9 @@ def test_reg_models_v_a(data_ddm_reg_va, loglik_kind, backend, sampler, step, ex
         v=param_reg_v,
         a=param_reg_a,
     )
+    assert np.all(
+        [val_ is None for key_, val_ in model.pymc_model.rvs_to_initial_values.items()]
+    )
     print(model.params["a"])
     run_sample(model, sampler, step, expected)

@@ -269,6 +318,7 @@ def test_reg_models_v_a(data_ddm_reg_va, loglik_kind, backend, sampler, step, ex
 def test_simple_models_missing_data(
     data_ddm_missing, loglik_kind, backend, sampler, step, expected, cpn
 ):
+    """Test simple models with missing data (e.g., under a deadline)."""
     print("PYMC VERSION: ")
     print(pm.__version__)
     print("TEST INPUTS WERE: ")
@@ -282,6 +332,9 @@ def test_simple_models_missing_data(
         missing_data=True,
         loglik_missing_data=cpn,
     )
+    assert np.all(
+        [val_ is None for key_, val_ in model.pymc_model.rvs_to_initial_values.items()]
+    )
     run_sample(model, sampler, step, expected)


@@ -289,6 +342,7 @@ def test_simple_models_missing_data(
 def test_reg_models_missing_data(
     data_ddm_reg_missing, loglik_kind, backend, sampler, step, expected, cpn
 ):
+    """Test regression models with missing data (e.g., under a deadline)."""
     print("PYMC VERSION: ")
     print(pm.__version__)
     print("TEST INPUTS WERE: ")
@@ -311,6 +365,9 @@ def test_reg_models_missing_data(
         missing_data=True,
         loglik_missing_data=cpn,
     )
+    assert np.all(
+        [val_ is None for key_, val_ in model.pymc_model.rvs_to_initial_values.items()]
+    )
     run_sample(model, sampler, step, expected)


@@ -318,6 +375,7 @@ def test_reg_models_missing_data(
 def test_simple_models_deadline(
     data_ddm_deadline, loglik_kind, backend, sampler, step, expected, opn
 ):
+    """Test simple models with a deadline."""
     print("PYMC VERSION: ")
     print(pm.__version__)
     print("TEST INPUTS WERE: ")
@@ -330,6 +388,9 @@ def test_simple_models_deadline(
         deadline=True,
         loglik_missing_data=opn,
     )
+    assert np.all(
+        [val_ is None for key_, val_ in model.pymc_model.rvs_to_initial_values.items()]
+    )
     run_sample(model, sampler, step, expected)


@@ -337,6 +398,7 @@ def test_simple_models_deadline(
 def test_reg_models_deadline(
     data_ddm_reg_deadline, loglik_kind, backend, sampler, step, expected, opn
 ):
+    """Test regression models with a deadline."""
     print("PYMC VERSION: ")
     print(pm.__version__)
     print("TEST INPUTS WERE: ")
@@ -359,4 +421,7 @@ def test_reg_models_deadline(
         deadline=True,
         loglik_missing_data=opn,
     )
+    assert np.all(
+        [val_ is None for key_, val_ in model.pymc_model.rvs_to_initial_values.items()]
+    )
     run_sample(model, sampler, step, expected)
diff --git a/tests/test_data_sanity.py b/tests/test_data_sanity.py
index b8376195..b7476b75 100644
--- a/tests/test_data_sanity.py
+++ b/tests/test_data_sanity.py
@@ -55,7 +55,7 @@ def test_data_sanity_check(data_ddm, cpn, caplog):

     with pytest.raises(
         ValueError,
-        match=r"Invalid responses found in your dataset: \[0\]",
+        match=r"Invalid responses found in your dataset: \[0, 2\]",
     ):
         data_ddm_miscoded = data_ddm.copy()
         data_ddm_miscoded["response"] = np.random.choice([0, 1, 2], data_ddm.shape[0])
@@ -64,20 +64,17 @@ def test_data_sanity_check(data_ddm, cpn, caplog):

     # Case 6: raise warning if there are missing responses in data
     data_ddm_miscoded = data_ddm.copy()
-    data_ddm_miscoded["response"] = np.random.choice([1, 2], data_ddm.shape[0])
+    data_ddm_miscoded["response"] = np.random.choice([1], data_ddm.shape[0])

-    hssm.HSSM(data=data_ddm_miscoded, model="ddm", choices=[1, 2, 3])
+    hssm.HSSM(data=data_ddm_miscoded, model="ddm")

     print("THE CAPLOG RECORDS")
     print([record.msg for record in caplog.records])
-    assert (
-        "You set choices to be [1, 2, 3], but [3] are missing from your dataset."
-        in [
-            record.msg % record.args if record.args else record.msg
-            for record in caplog.records
-        ]
-    )
+    assert (
+        "You set choices to be [-1, 1], but [-1] are missing from your dataset."
+        in [
+            record.msg % record.args if record.args else record.msg
+            for record in caplog.records
+        ]
+    )

     # Case 7: if deadline or missing_data is True, data should contain missing values
     with pytest.raises(
diff --git a/tests/test_hssm.py b/tests/test_hssm.py
index c7890b69..2d930a41 100644
--- a/tests/test_hssm.py
+++ b/tests/test_hssm.py
@@ -5,6 +5,7 @@
 import hssm
 from hssm import HSSM
 from hssm.likelihoods import DDM, logp_ddm
+from copy import deepcopy

 hssm.set_floatX("float32", update_jax=True)

@@ -114,6 +115,7 @@ def test_custom_model(data_ddm):
         model="custom",
         model_config={
             "list_params": ["v", "a", "z", "t"],
+            "choices": [-1, 1],
             "bounds": {
                 "v": (-3.0, 3.0),
                 "a": (0.3, 2.5),
@@ -279,6 +281,40 @@ def test_resampling(data_ddm):
     assert sample_1 is not sample_2


+def test_add_likelihood_parameters_to_idata(data_ddm):
+    """Test if the likelihood parameters are added to the InferenceData object."""
+    model = HSSM(data=data_ddm)
+    sample_1 = model.sample(draws=10, chains=1, tune=10)
+    sample_1_copy = deepcopy(sample_1)
+    model.add_likelihood_parameters_to_idata(inplace=True)
+
+    # Get distributional components (make sure to take the right aliases)
+    distributional_component_names = [
+        key_ if key_ not in model._aliases else model._aliases[key_]
+        for key_ in model.model.distributional_components.keys()
+    ]
+
+    # Check that after computing the likelihood parameters,
+    # all respective parameters appear in the InferenceData object
+    assert np.all(
+        [
+            component_ in model.traces.posterior.data_vars
+            for component_ in distributional_component_names
+        ]
+    )
+
+    # Check that before computing the likelihood parameters,
+    # at least one parameter is missing (in the simplest case
+    # this is the {parent}_mean parameter if nothing received a regression)
+    assert not np.all(
+        [
+            component_ in sample_1_copy.posterior.data_vars
+            for component_ in distributional_component_names
+        ]
+    )
+
+
 # Setting any parameter to a fixed value should work:
 def test_model_creation_constant_parameter(data_ddm):
     for param_name in ["v", "a", "z", "t"]:
diff --git a/tests/test_likelihoods_lba.py b/tests/test_likelihoods_lba.py
new file mode 100644
index 00000000..1508b095
--- /dev/null
+++ b/tests/test_likelihoods_lba.py
@@ -0,0 +1,143 @@
+"""Unit testing for LBA likelihood functions."""
+
+from pathlib import Path
+from itertools import product
+
+import numpy as np
+import pandas as pd
+import pymc as pm
+import pytensor
+import pytensor.tensor as pt
+import pytest
+import arviz as az
+from pytensor.compile.nanguardmode import NanGuardMode
+
+import hssm
+
+# pylint: disable=C0413
+from hssm.likelihoods.analytical import logp_lba2, logp_lba3
+from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox
+from hssm.distribution_utils import make_likelihood_callable
+
+hssm.set_floatX("float32")
+
+CLOSE_TOLERANCE = 1e-4
+
+
+def test_lba2_basic():
+    size = 1000
+
+    lba_data_out = hssm.simulate_data(
+        model="lba2", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0), size=size
+    )
+
+    # Test if vectorization is ok across parameters
+    out_A_vec = logp_lba2(
+        lba_data_out.values, A=np.array([0.2] * size), b=0.5, v0=1.0, v1=1.0
+    ).eval()
+    out_base = logp_lba2(lba_data_out.values, A=0.2, b=0.5, v0=1.0, v1=1.0).eval()
+    assert np.allclose(out_A_vec, out_base, atol=CLOSE_TOLERANCE)
+
+    out_b_vec = logp_lba2(
+        lba_data_out.values,
+        A=np.array([0.2] * size),
+        b=np.array([0.5] * size),
+        v0=1.0,
+        v1=1.0,
+    ).eval()
+    assert np.allclose(out_b_vec, out_base, atol=CLOSE_TOLERANCE)
+
+    out_v_vec = logp_lba2(
+        lba_data_out.values,
+        A=np.array([0.2] * size),
+        b=np.array([0.5] * size),
+        v0=np.array([1.0] * size),
+        v1=np.array([1.0] * size),
+    ).eval()
+    assert np.allclose(out_v_vec, out_base, atol=CLOSE_TOLERANCE)
+
+    # Test A > b leads to error
+    with pytest.raises(pm.logprob.utils.ParameterValueError):
+        logp_lba2(
+            lba_data_out.values, A=np.array([0.6] * 1000), b=0.5, v0=1.0, v1=1.0
+        ).eval()
+
+    with pytest.raises(pm.logprob.utils.ParameterValueError):
+        logp_lba2(lba_data_out.values, A=0.6, b=0.5, v0=1.0, v1=1.0).eval()
+
+    with pytest.raises(pm.logprob.utils.ParameterValueError):
+        logp_lba2(
+            lba_data_out.values, A=0.6, b=np.array([0.5] * 1000), v0=1.0, v1=1.0
+        ).eval()
+
+    with pytest.raises(pm.logprob.utils.ParameterValueError):
+        logp_lba2(
+            lba_data_out.values,
+            A=np.array([0.6] * 1000),
+            b=np.array([0.5] * 1000),
+            v0=1.0,
+            v1=1.0,
+        ).eval()
+
+
+def test_lba3_basic():
+    size = 1000
+
+    lba_data_out = hssm.simulate_data(
+        model="lba3", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0, v2=1.0), size=size
+    )
+
+    # Test if vectorization is ok across parameters
+    out_A_vec = logp_lba3(
+        lba_data_out.values, A=np.array([0.2] * size), b=0.5, v0=1.0, v1=1.0, v2=1.0
+    ).eval()
+
+    out_base = logp_lba3(
+        lba_data_out.values, A=0.2, b=0.5, v0=1.0, v1=1.0, v2=1.0
+    ).eval()
+
+    assert np.allclose(out_A_vec, out_base, atol=CLOSE_TOLERANCE)
+
+    out_b_vec = logp_lba3(
+        lba_data_out.values,
+        A=np.array([0.2] * size),
+        b=np.array([0.5] * size),
+        v0=1.0,
+        v1=1.0,
+        v2=1.0,
+    ).eval()
+    assert np.allclose(out_b_vec, out_base, atol=CLOSE_TOLERANCE)
+
+    out_v_vec = logp_lba3(
+        lba_data_out.values,
+        A=np.array([0.2] * size),
+        b=np.array([0.5] * size),
+        v0=np.array([1.0] * size),
+        v1=np.array([1.0] * size),
+        v2=np.array([1.0] * size),
+    ).eval()
+    assert np.allclose(out_v_vec, out_base, atol=CLOSE_TOLERANCE)
+
+    # Test A > b leads to error
+    with pytest.raises(pm.logprob.utils.ParameterValueError):
+        logp_lba3(
+            lba_data_out.values, A=np.array([0.6] * 1000), b=0.5, v0=1.0, v1=1.0, v2=1.0
+        ).eval()
+
+    with pytest.raises(pm.logprob.utils.ParameterValueError):
+        logp_lba3(lba_data_out.values, b=0.5, A=0.6, v0=1.0, v1=1.0, v2=1.0).eval()
+
+    with pytest.raises(pm.logprob.utils.ParameterValueError):
+        logp_lba3(
+            lba_data_out.values, A=0.6, b=np.array([0.5] * 1000), v0=1.0, v1=1.0, v2=1.0
+        ).eval()
+
+    with pytest.raises(pm.logprob.utils.ParameterValueError):
+        logp_lba3(
+            lba_data_out.values,
+            A=np.array([0.6] * 1000),
+            b=np.array([0.5] * 1000),
+            v0=1.0,
+            v1=1.0,
+            v2=1.0,
+        ).eval()
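For custom models, `choices` must now be supplied explicitly (see `Config.validate` above); a minimal sketch based on the updated `test_custom_model`, where `my_loglik` and the `z`/`t` bounds are hypothetical placeholders:

```python
import hssm

model = hssm.HSSM(
    data=data,  # assumed to contain "rt" and "response" columns
    model="custom",
    loglik=my_loglik,  # hypothetical user-defined log-likelihood
    loglik_kind="analytical",
    model_config={
        "list_params": ["v", "a", "z", "t"],
        "choices": [-1, 1],
        "bounds": {"v": (-3.0, 3.0), "a": (0.3, 2.5), "z": (0.1, 0.9), "t": (0.0, 2.0)},
    },
)
```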