Skip to content

Commit

Permalink
Change the jax random number generator key for 32 bit systems
Browse files Browse the repository at this point in the history
Resolves #832. In theory we can do `2**31` and not `2**31-1` if we would like, the latter seems to run without issue as well, I just kept with the `-1` to be consistent with the current implementation
  • Loading branch information
tjburch authored Aug 21, 2024
1 parent d574614 commit 79176a4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion bambi/backend/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _run_mcmc(
random_seed = random_seed[0]
np.random.seed(random_seed)

jax_seed = jax.random.PRNGKey(np.random.randint(2**32 - 1))
jax_seed = jax.random.PRNGKey(np.random.randint(2**31 - 1))

bx_model = bx.Model.from_pymc(self.model)
bx_sampler = operator.attrgetter(sampler_backend)(
Expand Down

0 comments on commit 79176a4

Please sign in to comment.