Skip to content

Commit

Permalink
NF: Add JSON dump and load, and support "equals" operator
Browse files Browse the repository at this point in the history
  • Loading branch information
hoechenberger committed May 23, 2019
1 parent 14527e6 commit 4ee190b
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 0 deletions.
64 changes: 64 additions & 0 deletions questplus/qp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional, Sequence
import xarray as xr
import numpy as np
import json_tricks
from copy import deepcopy

from questplus import psychometric_function
Expand Down Expand Up @@ -294,6 +295,69 @@ def param_estimate(self) -> dict:

return param_estimates

def to_json(self) -> str:
self_copy = deepcopy(self)
self_copy.prior = self_copy.prior.to_dict()
self_copy.posterior = self_copy.posterior.to_dict()
self_copy.likelihoods = self_copy.likelihoods.to_dict()
return json_tricks.dumps(self_copy)

@staticmethod
def from_json(data: str):
loaded = json_tricks.loads(data)
loaded.prior = xr.DataArray.from_dict(loaded.prior)
loaded.posterior = xr.DataArray.from_dict(loaded.posterior)
loaded.likelihoods = xr.DataArray.from_dict(loaded.likelihoods)
return loaded

def __eq__(self, other):
if not self.likelihoods.equals(other.likelihoods):
return False

if not self.prior.equals(other.prior):
return False

if not self.posterior.equals(other.posterior):
return False

for param_name in self.param_domain.keys():
if not np.array_equal(self.param_domain[param_name],
other.param_domain[param_name]):
return False

for stim_property in self.stim_domain.keys():
if not np.array_equal(self.stim_domain[stim_property],
other.stim_domain[stim_property]):
return False

for outcome_name in self.outcome_domain.keys():
if not np.array_equal(self.outcome_domain[outcome_name],
other.outcome_domain[outcome_name]):
return False

if self.stim_selection != other.stim_selection:
return False

if self.stim_selection_options != other.stim_selection_options:
return False

if self.stim_scale != other.stim_scale:
return False

if self.stim_history != other.stim_history:
return False

if self.resp_history != other.resp_history:
return False

if self.param_estimation_method != other.param_estimation_method:
return False

if self.func != other.func:
return False

return True


class QuestPlusWeibull(QuestPlus):
def __init__(self, *,
Expand Down
68 changes: 68 additions & 0 deletions questplus/tests/test_qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,78 @@ def test_weibull():
expected_mode_threshold)


def test_eq():
threshold = np.arange(-40, 0 + 1)
slope, guess, lapse = 3.5, 0.5, 0.02
contrasts = threshold.copy()

stim_domain = dict(intensity=contrasts)
param_domain = dict(threshold=threshold, slope=slope,
lower_asymptote=guess, lapse_rate=lapse)
outcome_domain = dict(response=['Correct', 'Incorrect'])

f = 'weibull'
scale = 'dB'
stim_selection_method = 'min_entropy'
param_estimation_method = 'mode'

q1 = QuestPlus(stim_domain=stim_domain, param_domain=param_domain,
outcome_domain=outcome_domain, func=f, stim_scale=scale,
stim_selection_method=stim_selection_method,
param_estimation_method=param_estimation_method)

q2 = QuestPlus(stim_domain=stim_domain, param_domain=param_domain,
outcome_domain=outcome_domain, func=f, stim_scale=scale,
stim_selection_method=stim_selection_method,
param_estimation_method=param_estimation_method)

# Add some random responses.
q1.update(stim=q1.next_stim, outcome=dict(response='Correct'))
q1.update(stim=q1.next_stim, outcome=dict(response='Incorrect'))
q2.update(stim=q2.next_stim, outcome=dict(response='Correct'))
q2.update(stim=q2.next_stim, outcome=dict(response='Incorrect'))

assert q1 == q2


def test_json():
threshold = np.arange(-40, 0 + 1)
slope, guess, lapse = 3.5, 0.5, 0.02
contrasts = threshold.copy()

stim_domain = dict(intensity=contrasts)
param_domain = dict(threshold=threshold, slope=slope,
lower_asymptote=guess, lapse_rate=lapse)
outcome_domain = dict(response=['Correct', 'Incorrect'])

f = 'weibull'
scale = 'dB'
stim_selection_method = 'min_entropy'
param_estimation_method = 'mode'

q = QuestPlus(stim_domain=stim_domain, param_domain=param_domain,
outcome_domain=outcome_domain, func=f, stim_scale=scale,
stim_selection_method=stim_selection_method,
param_estimation_method=param_estimation_method)

# Add some random responses.
q.update(stim=q.next_stim, outcome=dict(response='Correct'))
q.update(stim=q.next_stim, outcome=dict(response='Incorrect'))

q_dumped = q.to_json()
q_loaded = QuestPlus.from_json(q_dumped)

assert q_loaded == q

q_loaded.update(stim=q_loaded.next_stim, outcome=dict(response='Correct'))


if __name__ == '__main__':
test_threshold()
test_threshold_slope()
test_threshold_slope_lapse()
test_mean_sd_lapse()
test_spatial_contrast_sensitivity()
test_weibull()
test_eq()
test_json()
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ install_requires =
numpy
scipy
xarray
json_tricks

[bdist_wheel]
universal = 1
Expand Down

0 comments on commit 4ee190b

Please sign in to comment.