Skip to content

Commit

Permalink
Merge pull request #40 from hoechenberger/rng
Browse files Browse the repository at this point in the history
ENH: Allow JSON serialization of RNG
  • Loading branch information
hoechenberger authored Dec 24, 2019
2 parents a0fb0e6 + 02a51e9 commit 9abd28b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
v2019.4
-------
* Allow JSON serialization of random number generator

v2019.3
-------
* Allow to pass a prior when instantiating `QuestPlusWeibull`
Expand Down
10 changes: 10 additions & 0 deletions questplus/qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,10 @@ def to_json(self) -> str:
self_copy.prior = self_copy.prior.to_dict()
self_copy.posterior = self_copy.posterior.to_dict()
self_copy.likelihoods = self_copy.likelihoods.to_dict()

if self_copy._rng is not None: # NumPy RandomState cannot be serialized.
self_copy._rng = self_copy._rng.get_state()

return json_tricks.dumps(self_copy, allow_nan=True)

@staticmethod
Expand All @@ -412,6 +416,12 @@ def from_json(data: str):
loaded.prior = xr.DataArray.from_dict(loaded.prior)
loaded.posterior = xr.DataArray.from_dict(loaded.posterior)
loaded.likelihoods = xr.DataArray.from_dict(loaded.likelihoods)

if loaded._rng is not None:
state = deepcopy(loaded._rng)
loaded._rng = np.random.RandomState()
loaded._rng.set_state(state)

return loaded

def __eq__(self, other):
Expand Down
31 changes: 31 additions & 0 deletions questplus/tests/test_qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,36 @@ def test_json():
q_loaded.update(stim=q_loaded.next_stim, outcome=dict(response='Correct'))


def test_json_rng():
threshold = np.arange(-40, 0 + 1)
slope, guess, lapse = 3.5, 0.5, 0.02
contrasts = threshold.copy()

stim_domain = dict(intensity=contrasts)
param_domain = dict(threshold=threshold, slope=slope,
lower_asymptote=guess, lapse_rate=lapse)
outcome_domain = dict(response=['Correct', 'Incorrect'])
f = 'weibull'
scale = 'dB'
stim_selection_method = 'min_n_entropy'
param_estimation_method = 'mode'
random_seed = 5
stim_selection_options = dict(n=3, random_seed=random_seed)

q = QuestPlus(stim_domain=stim_domain, param_domain=param_domain,
outcome_domain=outcome_domain, func=f, stim_scale=scale,
stim_selection_method=stim_selection_method,
param_estimation_method=param_estimation_method,
stim_selection_options=stim_selection_options)

q2 = QuestPlus.from_json(q.to_json())

rand = q._rng.random_sample(10)
rand2 = q2._rng.random_sample(10)

assert np.allclose(rand, rand2)


def test_marginal_posterior():
contrasts = np.arange(-40, 0 + 1)
slope = np.arange(2, 5 + 1)
Expand Down Expand Up @@ -688,6 +718,7 @@ def test_weibull_prior():
test_weibull()
test_eq()
test_json()
test_json_rng()
test_marginal_posterior()
test_prior_for_unknown_parameter()
test_prior_for_parameter_subset()
Expand Down

0 comments on commit 9abd28b

Please sign in to comment.