Skip to content

Commit

Permalink
fix example in README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Sep 30, 2023
1 parent 8ed4d5e commit a7c6f6d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,20 @@ nll = NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)

# setup fit
# setup fit: initial values of parameters and a suitable optimizer
init_values = model.parameter_values
optimizer = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 10})
optimizer = JaxOptimizer.make(name="ScipyMinimize", settings={"method": "trust-constr"})

# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)

print(values)
# -> {'mu': Array([1.39171364], dtype=float64),
# 'sigma': Array([0.00867292], dtype=float64)}
# -> {'mu': Array([1.4], dtype=float64),
# 'sigma': Array([4.04723836e-14], dtype=float64)}

# eval model with fitted values/parameters
print(model.update(values=values).evaluate().expectation())
# -> Array([64.0038656], dtype=float64)
# -> Array([64.], dtype=float64)


# gradients - of "prefit" model:
Expand All @@ -95,7 +95,7 @@ print(fast_grad_nll_prefit({"sigma": jnp.array([0.2])}))
postfit_nll = NLL(model=model.update(values=values), observation=jnp.array([64.0]))
fast_grad_nll_postfit = eqx.filter_grad(eqx.filter_jit(postfit_nll))
print(fast_grad_nll_postfit({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.49084036], dtype=float64)}
# -> {'sigma': Array([0.5030303], dtype=float64)}
```

## Contributing
Expand Down
12 changes: 10 additions & 2 deletions src/dilax/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from collections.abc import Hashable
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable, cast

import equinox as eqx
import jax
import jaxopt

from dilax.util import Sentinel, _NoValue


class JaxOptimizer(eqx.Module):
"""
Expand All @@ -32,8 +34,14 @@ def __init__(self, name: str, _settings: tuple[tuple[str, Hashable], ...]) -> No

@classmethod
def make(
cls: type[JaxOptimizer], name: str, settings: dict[str, Hashable]
cls: type[JaxOptimizer],
name: str,
settings: dict[str, Hashable] | Sentinel = _NoValue,
) -> JaxOptimizer:
if settings is _NoValue:
settings = {}
if TYPE_CHECKING:
settings = cast(dict[str, Hashable], settings)
return cls(name=name, _settings=tuple(settings.items()))

@property
Expand Down

0 comments on commit a7c6f6d

Please sign in to comment.