Skip to content

Commit

Permalink
update versions, run through tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFengler committed Dec 28, 2024
1 parent 9be5bb8 commit 9a24cb2
Show file tree
Hide file tree
Showing 12 changed files with 8,563 additions and 7,154 deletions.
1,022 changes: 511 additions & 511 deletions docs/tutorials/blackbox_contribution_onnx_example.ipynb

Large diffs are not rendered by default.

123 changes: 51 additions & 72 deletions docs/tutorials/compile_logp.ipynb

Large diffs are not rendered by default.

129 changes: 32 additions & 97 deletions docs/tutorials/initial_values.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,16 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Setting PyTensor floatX type to float32.\n",
"Setting \"jax_enable_x64\" to False. If this is not intended, please set `jax` to False.\n"
]
}
],
"outputs": [],
"source": [
"import warnings\n",
"\n",
"import numpy as np\n",
"\n",
"import hssm\n",
"\n",
"hssm.set_floatX(\"float32\")\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")"
]
},
Expand Down Expand Up @@ -69,7 +59,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -85,14 +75,14 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'z_interval__': array(0., dtype=float32), 'a_interval__': array(-1.1175871e-07, dtype=float32), 't_log__': array(0.6931472, dtype=float32), 'v': array(0., dtype=float32)}\n"
"Model initialized successfully.\n"
]
}
],
Expand All @@ -101,6 +91,7 @@
" data=cav_data,\n",
" model=\"ddm\",\n",
" loglik_kind=\"approx_differentiable\",\n",
" model_config={\"backend\": \"pytensor\"},\n",
")"
]
},
Expand All @@ -113,19 +104,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'z': array(0.5, dtype=float32),\n",
" 'a': array(1.5, dtype=float32),\n",
" 't': array(0.025, dtype=float32),\n",
" 'v': array(0., dtype=float32)}"
"{'a': array(1.5), 'z': array(0.5), 't': array(0.025), 'v': array(0.)}"
]
},
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -149,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -161,94 +149,31 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"Only 10 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.\n"
"Using default initvals. \n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b83aab5a5b545c69100111c82f52123",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a73c530664cf480aa8fdd770c87611f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5bfd44ee15814452bc334013fb08ace3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "30db31b9dc60490390fd583306a5247d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"There were 40 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"The number of samples is too small to check convergence reliably.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CLEANING RESULTS\n",
"MAIN CLEANUP LOOP\n",
"RUNNING COMPONENT <bambi.backend.model_components.DistributionalComponent object at 0x2a253b7d0>\n",
"PERFORMING PREDICTION\n"
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [a, z, t, v]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4ca23b047ac7493989f9d17dc7260684",
"model_id": "e3e5c918329f4fcdbdb91096727351e3",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -281,10 +206,20 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 4 chains for 100 tune and 100 draw iterations (400 + 400 draws total) took 34 seconds.\n",
"The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details\n",
"The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details\n",
"100%|██████████| 400/400 [00:09<00:00, 43.43it/s]\n"
]
}
],
"source": [
"idata = model.sample(draws=10, tune=1, initvals=my_initvals)"
"idata = model.sample(draws=100, tune=100, sampler=\"mcmc\") # initvals=my_initvals)"
]
},
{
Expand Down
402 changes: 203 additions & 199 deletions docs/tutorials/jax_callable_contribution_onnx_example.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 9a24cb2

Please sign in to comment.