Skip to content

Commit

Permalink
moved merge_dict to param.py to avoid circular import
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Nov 28, 2023
1 parent 06795f3 commit 0dc986a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
13 changes: 12 additions & 1 deletion src/hssm/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import bambi as bmb
import numpy as np
import pandas as pd
from deepcopy import deepcopy
from formulae import design_matrices

from .link import Link
from .prior import Prior, get_default_prior, get_hddm_default_prior
from .utils import merge_dicts

# PEP604 union operator "|" not supported by pylint
# Fall back to old syntax
Expand Down Expand Up @@ -648,3 +648,14 @@ def _make_default_prior(bounds: tuple[float, float]) -> bmb.Prior:
return bmb.Prior("TruncatedNormal", mu=lower, lower=lower, sigma=2.0)
else:
return bmb.Prior(name="Uniform", lower=lower, upper=upper)


def merge_dicts(dict1: dict, dict2: dict) -> dict:
"""Recursively merge two dictionaries."""
merged = deepcopy(dict1)
for key, value in dict2.items():
if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
merged[key] = merge_dicts(merged[key], value)
else:
merged[key] = value
return merged
12 changes: 0 additions & 12 deletions src/hssm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""

import logging
from copy import deepcopy
from typing import Any, Iterable, Literal, NewType

import bambi as bmb
Expand Down Expand Up @@ -54,17 +53,6 @@ def download_hf(path: str):
return hf_hub_download(repo_id=REPO_ID, filename=path)


def merge_dicts(dict1: dict, dict2: dict) -> dict:
"""Recursively merge two dictionaries."""
merged = deepcopy(dict1)
for key, value in dict2.items():
if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
merged[key] = merge_dicts(merged[key], value)
else:
merged[key] = value
return merged


def make_alias_dict_from_parent(parent: Param) -> dict[str, str]:
"""Make aliases from the parent parameter.
Expand Down

0 comments on commit 0dc986a

Please sign in to comment.