Skip to content

Commit

Permalink
'jax.tree_util' -> 'jtu'
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Dec 1, 2023
1 parent 7a0b815 commit 280645f
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions src/dilax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from dilax.parameter import Parameter
from dilax.util import Sentinel, _NoValue, deep_update
Expand All @@ -31,7 +32,7 @@ def add(self, process: str, expectation: jax.Array) -> Result:
return self

def expectation(self) -> jax.Array:
return cast(jax.Array, sum(jax.tree_util.tree_leaves(self.expectations)))
return cast(jax.Array, sum(jtu.tree_leaves(self.expectations)))


def _is_parameter(leaf: Any) -> bool:
Expand Down Expand Up @@ -120,7 +121,7 @@ def __init__(

@property
def parameter_values(self) -> dict:
return jax.tree_util.tree_map(
return jtu.tree_map(
lambda l: l.value, # noqa: E741
self.parameters,
is_leaf=_is_parameter,
Expand All @@ -135,7 +136,7 @@ def _constraint(param: Parameter) -> jax.Array:
return next(iter(param.constraints)).logpdf(param.value)
return jnp.array([0.0])

return jax.tree_util.tree_map(
return jtu.tree_map(
_constraint,
self.parameters,
is_leaf=_is_parameter,
Expand All @@ -160,9 +161,7 @@ def update(

# patch original parameters with new ones
_updates = deep_update(
jax.tree_util.tree_map(
lambda _: None, self.parameters, is_leaf=_is_parameter
),
jtu.tree_map(lambda _: None, self.parameters, is_leaf=_is_parameter),
values,
)

Expand All @@ -171,7 +170,7 @@ def _update_params(update: jax.Array | None, param: Parameter) -> Parameter:
return param
return param.update(value=update)

new_parameters = jax.tree_util.tree_map(
new_parameters = jtu.tree_map(
_update_params,
_updates,
self.parameters,
Expand All @@ -183,12 +182,17 @@ def _update_params(update: jax.Array | None, param: Parameter) -> Parameter:
)

def nll_boundary_penalty(self) -> jax.Array:
params = jax.tree_util.tree_leaves(self.parameters, is_leaf=_is_parameter)

return sum(
jax.tree_util.tree_map(
lambda p: p.boundary_penalty, params, is_leaf=_is_parameter
)
return cast(
jax.Array,
sum(
jtu.tree_leaves(
jtu.tree_map(
lambda p: p.boundary_penalty,
self.parameters,
is_leaf=_is_parameter,
)
)
),
)

@abc.abstractmethod
Expand Down

0 comments on commit 280645f

Please sign in to comment.