Skip to content

Commit

Permalink
Merge pull request #446 from lnccbrown/444-fix-initialization-with-nans
Browse files Browse the repository at this point in the history
fix key error bug when going through initialization, make safe priors…
  • Loading branch information
AlexanderFengler authored Jun 12, 2024
2 parents 3f7c518 + 353c446 commit 3d2f301
Show file tree
Hide file tree
Showing 7 changed files with 539 additions and 42 deletions.
4 changes: 2 additions & 2 deletions src/hssm/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import os
from typing import NamedTuple, Optional, Union
from typing import NamedTuple, Optional

import pandas as pd

Expand Down Expand Up @@ -34,7 +34,7 @@ class FileMetadata(NamedTuple):
}


def load_data(dataset: Optional[str] = None) -> Union[pd.DataFrame, str]:
def load_data(dataset: Optional[str] = None) -> pd.DataFrame | str:
"""
Load a dataset as a pandas DataFrame.
Expand Down
10 changes: 2 additions & 8 deletions src/hssm/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.05,
},
},
"extra_fields": None,
Expand All @@ -104,7 +103,6 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.05,
},
},
"bounds": {
Expand All @@ -123,7 +121,6 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.05,
},
},
"extra_fields": None,
Expand All @@ -143,7 +140,6 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.05,
},
},
"extra_fields": None,
Expand All @@ -155,7 +151,6 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.05,
},
},
"bounds": {
Expand All @@ -175,7 +170,6 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.05,
},
},
"extra_fields": None,
Expand All @@ -195,7 +189,6 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.05,
},
},
"extra_fields": None,
Expand Down Expand Up @@ -328,6 +321,7 @@ class DefaultConfig(TypedDict):
},
}

# TODO: Initval settings could be specified directly in model config as well.
INITVAL_SETTINGS = {
# logit link function case
# should never use priors with bounds,
Expand All @@ -343,7 +337,7 @@ class DefaultConfig(TypedDict):
},
# identity link function case,
# need to take care of_log__ and _interval__ variables
"None": {
None: {
"t": 0.025,
"t_Intercept": 0.025,
"a": 1.5,
Expand Down
Loading

0 comments on commit 3d2f301

Please sign in to comment.