Skip to content

Commit

Permalink
added HDDMnnRL demo
Browse files Browse the repository at this point in the history
  • Loading branch information
Krishn_bera committed May 20, 2022
1 parent 74726ee commit 0c79cc3
Show file tree
Hide file tree
Showing 23 changed files with 90,731 additions and 2,310 deletions.
1,658 changes: 0 additions & 1,658 deletions hddm/examples/TEST_RLHDDM_NN.ipynb

This file was deleted.

652 changes: 0 additions & 652 deletions hddm/examples/Test_HDDMnn.ipynb

This file was deleted.

Binary file not shown.
487 changes: 487 additions & 0 deletions hddm/examples/demo_HDDMnnRL/demo_HDDMnnRL.ipynb

Large diffs are not rendered by default.

243 changes: 243 additions & 0 deletions hddm/examples/demo_HDDMnnRL/demo_HDDMnnRL.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
Tutorial for analyzing instrumental learning data with the HDDMnnRL module
==========================================================================

This is a tutorial for using the HDDMrl module to simultaneously
estimate reinforcement learning parameters and decision parameters
within a fully hierarchical bayesian estimation framework, including
steps for sampling, assessing convergence, model fit, parameter
recovery, and posterior predictive checks (model validation). The module
uses the reinforcement learning drift diffusion model (RLDDM), a
reinforcement learning model that replaces the standard “softmax” choice
function with a drift diffusion process. The softmax and drift diffusion
process is equivalent for capturing choice proportions, but the DDM also
takes RT distributions into account; options are provided to also only
fit RL parameters without RT. The RLDDM estimates trial-by-trial drift
rate as a scaled difference in expected rewards (expected reward for
upper bound alternative minus expected reward for lower bound
alternative). Expected rewards are updated with a delta learning rule
using either a single learning rate or with separate learning rates for
positive and negative prediction errors. The model also includes the
standard DDM-parameters. The RLDDM is described in detail in `Pedersen,
Frank & Biele
(2017). <http://ski.clps.brown.edu/papers/PedersenEtAl_RLDDM.pdf>`__
(Note this approach differs from Frank et al (2015) who used HDDM to fit
instrumental learning but did not allow for simultaneous estimation of
learning parameters).

.. code:: ipython3
import hddm
import pickle
import pandas as pd
Load the data
^^^^^^^^^^^^^

.. code:: ipython3
with open('./angle_d1_c3_s20_t500.pickle', 'rb') as handle:
datafile = pickle.load(handle)
# Here, datafile is saved as a list of datasets. We pick the first dataset.
dataset = datafile[0]
.. parsed-literal::
1
.. code:: ipython3
# Reformat the dataset as a dataframe
data = hddm.utils.get_dataset_as_dataframe_rlssm(dataset)
Initialize the HDDMnnRL model and sample
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: ipython3
# Specify number of samples and burnins
nsamples = 100
nburn = 50
.. code:: ipython3
m = hddm.HDDMnnRL(data, model='angle', rl_rule='RWupdate', non_centered=True, include=['z', 'theta', 'rl_alpha'], p_outlier = 0.0)
m.sample(nsamples, burn=nburn, dbname='traces.db', db='pickle')
.. parsed-literal::
Printing model specifications --
ssm: angle
rl rule: RWupdate
using non-centered dist.: False
Using default priors: Uninformative
Supplied model_config specifies params_std_upper for z as None.
Changed to 10
Supplied model_config specifies params_std_upper for rl_alpha as None.
Changed to 10
[-----------------101%-----------------] 101 of 100 complete in 256.1 sec
.. parsed-literal::
<pymc.MCMC.MCMC at 0x7f5b5c224f10>
Save the model
^^^^^^^^^^^^^^

.. code:: ipython3
# Save the model
m.save('rlssm_model')
.. code:: ipython3
# Load the model
# model = hddm.load('rlssm_model')
Check the posterior results
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: ipython3
m.plot_posteriors()
.. parsed-literal::
Plotting v
Plotting v_std
Plotting a
Plotting a_std
Plotting z
Plotting z_std
Plotting t
Plotting t_std
Plotting theta
Plotting theta_std
Plotting rl_alpha
Plotting rl_alpha_std
.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_1.png



.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_2.png



.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_3.png



.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_4.png



.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_5.png



.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_6.png



.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_7.png



.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_8.png



.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_9.png



.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_10.png



.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_11.png



.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_12_12.png


.. code:: ipython3
# Load the trace
with open('./traces.db', 'rb') as handle:
tracefile = pickle.load(handle)
.. code:: ipython3
# Re-format traces as a dataframe
traces = hddm.utils.get_traces_rlssm(tracefile)
.. code:: ipython3
model_ssm = 'angle'
model_rl = 'RWupdate'
config_ssm = hddm.model_config.model_config[model_ssm]
config_rl = hddm.model_config_rl.model_config_rl[model_rl]
.. code:: ipython3
hddm.plotting.plot_posterior_pairs_rlssm(tracefile, config_ssm['params'] + config_rl['params'])
Posterior Predictive Checks
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: ipython3
num_posterior_samples = 3
p_lower = {0: 0.15, 1:0.30, 2:0.45}
p_upper = {0: 0.85, 1:0.70, 2:0.55}
ppc_sdata = hddm.plotting.gen_ppc_rlssm(model_ssm, config_ssm, model_rl, config_rl, data, traces, num_posterior_samples, p_lower, p_upper, save_data=True, save_name='ppc_data')
.. parsed-literal::
100%|██████████| 3/3 [05:49<00:00, 116.55s/it]
.. parsed-literal::
ppc data saved at ./ppc_data.csv
.. code:: ipython3
# Load the saved ppc data
# ppc_sdata = pd.read_csv('./ppc_data.csv')
.. code:: ipython3
_ = hddm.plotting.plot_ppc_choice_rlssm(data, ppc_sdata, 40, 10)
.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_20_0.png


.. code:: ipython3
_ = hddm.plotting.plot_ppc_rt_rlssm(data, ppc_sdata, 40, 0.06)
.. image:: demo_HDDMnnRL_files/demo_HDDMnnRL_21_0.png


Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 0c79cc3

Please sign in to comment.