diff --git a/src/hssm/param.py b/src/hssm/param.py index 606d7342..9a4b95d7 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -6,11 +6,11 @@ import bambi as bmb import numpy as np import pandas as pd +from deepcopy import deepcopy from formulae import design_matrices from .link import Link from .prior import Prior, get_default_prior, get_hddm_default_prior -from .utils import merge_dicts # PEP604 union operator "|" not supported by pylint # Fall back to old syntax @@ -648,3 +648,14 @@ def _make_default_prior(bounds: tuple[float, float]) -> bmb.Prior: return bmb.Prior("TruncatedNormal", mu=lower, lower=lower, sigma=2.0) else: return bmb.Prior(name="Uniform", lower=lower, upper=upper) + + +def merge_dicts(dict1: dict, dict2: dict) -> dict: + """Recursively merge two dictionaries.""" + merged = deepcopy(dict1) + for key, value in dict2.items(): + if key in merged and isinstance(merged[key], dict) and isinstance(value, dict): + merged[key] = merge_dicts(merged[key], value) + else: + merged[key] = value + return merged diff --git a/src/hssm/utils.py b/src/hssm/utils.py index 2adcc1e6..5534aa28 100644 --- a/src/hssm/utils.py +++ b/src/hssm/utils.py @@ -10,7 +10,6 @@ """ import logging -from copy import deepcopy from typing import Any, Iterable, Literal, NewType import bambi as bmb @@ -54,17 +53,6 @@ def download_hf(path: str): return hf_hub_download(repo_id=REPO_ID, filename=path) -def merge_dicts(dict1: dict, dict2: dict) -> dict: - """Recursively merge two dictionaries.""" - merged = deepcopy(dict1) - for key, value in dict2.items(): - if key in merged and isinstance(merged[key], dict) and isinstance(value, dict): - merged[key] = merge_dicts(merged[key], value) - else: - merged[key] = value - return merged - - def make_alias_dict_from_parent(parent: Param) -> dict[str, str]: """Make aliases from the parent parameter.