From 6342c72b1070879b400db590418d32840843173c Mon Sep 17 00:00:00 2001 From: Fengler Date: Tue, 7 May 2024 21:39:28 +0200 Subject: [PATCH] make mypy happy --- src/hssm/hssm.py | 15 ++++++++------- tests/slow/test_mcmc.py | 7 ++++++- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 2cd81fdc..2f2b96a3 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -500,13 +500,14 @@ def sample( # The parent was previously not part of deterministics --> compute it via # posterior_predictive (works because it acts as the 'mu' parameter # in the GLM as far as bambi is concerned) - if self._parent not in self._inference_obj.posterior.data_vars.keys(): - self.sample_posterior_predictive(self._inference_obj, kind="mean") - # rename 'rt,response_mean' to 'v' so in the traces everything - # looks the way it should - self._inference_obj.rename_vars( - {"rt,response_mean": self._parent}, inplace=True - ) + if self._inference_obj is not None: + if self._parent not in self._inference_obj.posterior.data_vars.keys(): + self.sample_posterior_predictive(self._inference_obj, kind="mean") + # rename 'rt,response_mean' to 'v' so in the traces everything + # looks the way it should + self._inference_obj.rename_vars( + {"rt,response_mean": self._parent}, inplace=True + ) return self.traces def sample_posterior_predictive( diff --git a/tests/slow/test_mcmc.py b/tests/slow/test_mcmc.py index 1d99b291..96f86a0d 100644 --- a/tests/slow/test_mcmc.py +++ b/tests/slow/test_mcmc.py @@ -224,7 +224,12 @@ def test_reg_models_v_a(data_ddm_reg, loglik_kind, backend, sampler, step, expec # Only runs once if loglik_kind == "analytical" and sampler is None: - assert model._get_deterministic_var_names(model.traces) == ["~a", "~v"] + assert len(model._get_deterministic_var_names(model.traces)) == len( + ["~a", "~v"] + ) + assert set(model._get_deterministic_var_names(model.traces)) == set( + ["~a", "~v"] + ) # test summary: summary = model.summary() assert summary.shape[0] == 8