From a7c6f6d4d84c7241d4aab0e22797ab4e83032fb9 Mon Sep 17 00:00:00 2001
From: Peter Fackeldey
Date: Sat, 30 Sep 2023 13:38:52 +0200
Subject: [PATCH] fix example in README.md

---
 README.md              | 12 ++++++------
 src/dilax/optimizer.py | 12 ++++++++++--
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 287fb27..7aa24a6 100644
--- a/README.md
+++ b/README.md
@@ -70,20 +70,20 @@ nll = NLL(model=model, observation=jnp.array([64.0]))
 # jit it!
 fast_nll = eqx.filter_jit(nll)
 
-# setup fit
+# setup fit: initial values of parameters and a suitable optimizer
 init_values = model.parameter_values
-optimizer = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 10})
+optimizer = JaxOptimizer.make(name="ScipyMinimize", settings={"method": "trust-constr"})
 
 # fit
 values, state = optimizer.fit(fun=fast_nll, init_values=init_values)
 
 print(values)
-# -> {'mu': Array([1.39171364], dtype=float64),
-#     'sigma': Array([0.00867292], dtype=float64)}
+# -> {'mu': Array([1.4], dtype=float64),
+#     'sigma': Array([4.04723836e-14], dtype=float64)}
 
 # eval model with fitted values/parameters
 print(model.update(values=values).evaluate().expectation())
-# -> Array([64.0038656], dtype=float64)
+# -> Array([64.], dtype=float64)
 
 # gradients - of "prefit" model:
@@ -95,7 +95,7 @@ print(fast_grad_nll_prefit({"sigma": jnp.array([0.2])}))
 postfit_nll = NLL(model=model.update(values=values), observation=jnp.array([64.0]))
 fast_grad_nll_postfit = eqx.filter_grad(eqx.filter_jit(postfit_nll))
 print(fast_grad_nll_postfit({"sigma": jnp.array([0.2])}))
-# -> {'sigma': Array([0.49084036], dtype=float64)}
+# -> {'sigma': Array([0.5030303], dtype=float64)}
 ```
 
 ## Contributing

diff --git a/src/dilax/optimizer.py b/src/dilax/optimizer.py
index effe79b..b267022 100644
--- a/src/dilax/optimizer.py
+++ b/src/dilax/optimizer.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
 from collections.abc import Hashable
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, cast
 
 import equinox as eqx
 import jax
 import jaxopt
 
+from dilax.util import Sentinel, _NoValue
+
 
 class JaxOptimizer(eqx.Module):
     """
@@ -32,8 +34,14 @@ def __init__(self, name: str, _settings: tuple[tuple[str, Hashable], ...]) -> No
 
     @classmethod
     def make(
-        cls: type[JaxOptimizer], name: str, settings: dict[str, Hashable]
+        cls: type[JaxOptimizer],
+        name: str,
+        settings: dict[str, Hashable] | Sentinel = _NoValue,
     ) -> JaxOptimizer:
+        if settings is _NoValue:
+            settings = {}
+        if TYPE_CHECKING:
+            settings = cast(dict[str, Hashable], settings)
         return cls(name=name, _settings=tuple(settings.items()))
 
     @property
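
Note on the `make` change above: `settings` becomes optional via a unique sentinel default rather than a shared mutable `{}` default, and a `TYPE_CHECKING`-guarded `cast` narrows the `dict | Sentinel` union for static type checkers at zero runtime cost. Below is a minimal standalone sketch of the same sentinel-default pattern, with local stand-ins for `Sentinel` and `_NoValue` (assumed here to behave like the `dilax.util` objects the patch imports):

```python
from __future__ import annotations

from collections.abc import Hashable
from typing import TYPE_CHECKING, cast


class Sentinel:
    """Unique marker type for 'argument not provided'."""

    def __repr__(self) -> str:
        return "_NoValue"


_NoValue = Sentinel()


class JaxOptimizer:
    def __init__(self, name: str, _settings: tuple[tuple[str, Hashable], ...]) -> None:
        self.name = name
        self._settings = _settings

    @classmethod
    def make(
        cls,
        name: str,
        settings: dict[str, Hashable] | Sentinel = _NoValue,
    ) -> JaxOptimizer:
        # Identity check against the unique sentinel: a fresh dict is
        # created per call, sidestepping the mutable-default pitfall.
        if settings is _NoValue:
            settings = {}
        # Never executed at runtime; tells the type checker that
        # `settings` is definitely a dict from here on.
        if TYPE_CHECKING:
            settings = cast(dict[str, Hashable], settings)
        return cls(name=name, _settings=tuple(settings.items()))


# `settings` may now be omitted entirely:
opt_default = JaxOptimizer.make(name="LBFGS")
opt_custom = JaxOptimizer.make(name="ScipyMinimize", settings={"method": "trust-constr"})
```

Storing the settings as a tuple of items rather than a dict keeps the object hashable, which matters for an `eqx.Module` that may be treated as a static pytree leaf under `jax.jit`.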