Adds JAX IDAKLU solver integration #481

Status: Open
Wants to merge 24 commits into develop (base) from jaxify-idaklu-implementation

Commits (24)
1e08d6e
working commit, adds [jax] optional dependency to pybamm, example ida…
BradyPlanden Aug 20, 2024
79234f9
adds experimental sub-directory, sync commit
BradyPlanden Aug 20, 2024
7d8aa9c
Merge branch 'refs/heads/develop' into jaxify-idaklu-implementation
BradyPlanden Aug 20, 2024
ab861a9
Merge branch 'refs/heads/develop' into jaxify-idaklu-implementation
BradyPlanden Aug 28, 2024
f2bf143
feat: adds Jax functionality via idaklu_jax solver, with example and …
BradyPlanden Aug 31, 2024
025c7d8
Merge branch 'refs/heads/develop' into jaxify-idaklu-implementation
BradyPlanden Sep 2, 2024
d3a18d5
Post merge fixes, update base_model solver property
BradyPlanden Sep 2, 2024
eb08c00
examples: update benchmarking script
BradyPlanden Sep 2, 2024
92265b6
Merge branch 'develop' into jaxify-idaklu-implementation
BradyPlanden Sep 6, 2024
f1d03a9
Merge branch 'refs/heads/develop' into jaxify-idaklu-implementation
BradyPlanden Sep 24, 2024
3ecd5dd
updts docstrings, jax arg to models, adds tests.
BradyPlanden Sep 24, 2024
9794910
add coverage, refactor BaseModel solver setter
BradyPlanden Sep 25, 2024
df1ba0c
adds coverage
BradyPlanden Sep 25, 2024
d770315
convert BaseModel.calculate_sensitivites to property
BradyPlanden Sep 25, 2024
daae4ba
add changelog entry
BradyPlanden Sep 25, 2024
424699b
feat: removes Jax arg from BaseModel, BaseProblem.model is copied ins…
BradyPlanden Sep 30, 2024
2695fd7
Merge branch 'develop' into jaxify-idaklu-implementation
BradyPlanden Oct 11, 2024
23b8e3d
fix: cost shape for CMAES, adds example comments for jax solver
BradyPlanden Oct 21, 2024
bba2e56
Merge branch 'develop' into jaxify-idaklu-implementation
BradyPlanden Oct 21, 2024
f7672ce
fix: update model reparameterisation as model is now copied within Ba…
BradyPlanden Oct 21, 2024
0549d14
Merge branch 'develop' into jaxify-idaklu-implementation
BradyPlanden Oct 23, 2024
44e3095
Merge branch 'develop' into jaxify-idaklu-implementation
BradyPlanden Oct 23, 2024
e1b99d7
Merge branch 'develop' into jaxify-idaklu-implementation
BradyPlanden Nov 2, 2024
8f11efe
refactor: Jax implementation, FittingProblem.evaluate, adds Fisher In…
BradyPlanden Nov 4, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@

## Features

- [#481](https://github.com/pybop-team/PyBOP/pull/481) - Adds experimental support for PyBaMM's jaxified IDAKLU solver, including the Jax-specific cost functions `pybop.JaxSumSquaredError` and `pybop.JaxLogNormalLikelihood`. Adds the `jax` optional dependency to the PyBaMM requirement.
- [#452](https://github.com/pybop-team/PyBOP/issues/452) - Extends `cell_mass` and `approximate_capacity` for half-cell models.
- [#544](https://github.com/pybop-team/PyBOP/issues/544) - Allows iterative plotting using `StandardPlot`.
- [#541](https://github.com/pybop-team/PyBOP/pull/541) - Adds `ScaledLogLikelihood` and `BaseMetaLikelihood` classes.
92 changes: 46 additions & 46 deletions examples/notebooks/single_pulse_circuit_model.ipynb

Large diffs are not rendered by default.

66 changes: 66 additions & 0 deletions examples/scripts/jax-solver-example.py
@@ -0,0 +1,66 @@
import numpy as np
import pybamm

import pybop

# Parameter set and model definition
parameter_set = pybop.ParameterSet.pybamm("Chen2020")

# The IDAKLU solver and its jaxified version perform very well on the DFN,
# with and without gradient calculations
solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6)
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=solver)

# Fitting parameters
parameters = pybop.Parameters(
pybop.Parameter(
"Negative electrode active material volume fraction",
initial_value=0.55,
bounds=[0.5, 0.8],
),
pybop.Parameter(
"Positive electrode active material volume fraction",
initial_value=0.55,
bounds=[0.5, 0.8],
),
)

# Define test protocol and generate data
t_eval = np.linspace(0, 600, 600)
values = model.predict(
initial_state={"Initial open-circuit voltage [V]": 4.2}, t_eval=t_eval
)

# Form dataset
dataset = pybop.Dataset(
{
"Time [s]": values["Time [s]"].data,
"Current function [A]": values["Current [A]"].data,
"Voltage [V]": values["Voltage [V]"].data,
}
)

# Construct the Problem
problem = pybop.FittingProblem(model, parameters, dataset)

# By selecting a Jax-based cost function, the IDAKLU solver will be
# jaxified (wrapped in a Jax-compiled expression) and used for optimisation
cost = pybop.JaxLogNormalLikelihood(problem, sigma0=2e-3)

# Gradient-based optimiser; change to `pybop.AdamW` for an alternative
optim = pybop.IRPropMin(
cost,
max_unchanged_iterations=20,
max_iterations=100,
)

results = optim.run()

# Plot convergence
pybop.plot.convergence(optim)

# Plot parameter trace
pybop.plot.parameters(optim)

# Plot the Voronoi optimiser surface
pybop.plot.surface(optim)
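
The comment on the optimiser above points to `pybop.AdamW` as an alternative. A minimal sketch of that variant (assuming `pybop.AdamW` accepts the same keyword arguments as `pybop.IRPropMin`):

# Gradient-based alternative: AdamW consumes the Jax gradient that the
# cost returns when called with calculate_grad=True
optim = pybop.AdamW(
    cost,
    max_unchanged_iterations=20,
    max_iterations=100,
)
results = optim.run()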
106 changes: 106 additions & 0 deletions examples/scripts/jaxified-idaklu-benchmarks.py
@@ -0,0 +1,106 @@
import time

import numpy as np
import pybamm

import pybop

n = 1 # Number of solves
solvers = [
pybamm.CasadiSolver(mode="fast with events", atol=1e-6, rtol=1e-6),
pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6),
]

# Parameter set and model definition
parameter_set = pybop.ParameterSet.pybamm("Chen2020")
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=solvers[0])

# Fitting parameters
parameters = pybop.Parameters(
pybop.Parameter(
"Negative electrode active material volume fraction", initial_value=0.55
),
pybop.Parameter(
"Positive electrode active material volume fraction", initial_value=0.55
),
)

# Define test protocol and generate data
t_eval = np.linspace(0, 100, 1000)
values = model.predict(
initial_state={"Initial open-circuit voltage [V]": 4.2}, t_eval=t_eval
)

# Form dataset
dataset = pybop.Dataset(
{
"Time [s]": values["Time [s]"].data,
"Current function [A]": values["Current [A]"].data,
"Voltage [V]": values["Voltage [V]"].data,
}
)


# Create inputs function for benchmarking
def inputs():
return {
"Negative electrode active material volume fraction": 0.55
+ np.random.normal(0, 0.01),
"Positive electrode active material volume fraction": 0.55
+ np.random.normal(0, 0.01),
}


# Iterate over the solvers and print benchmarks
for solver in solvers:
# Setup Fitting Problem
model.solver = solver
problem = pybop.FittingProblem(model, parameters, dataset)
cost = pybop.SumSquaredError(problem)

start_time = time.time()
for _i in range(n):
out = problem.model.simulate(inputs=inputs(), t_eval=t_eval)
print(f"({solver.name}) Time model.simulate: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
out = problem.model.simulateS1(inputs=inputs(), t_eval=t_eval)
print(f"({solver.name}) Time model.simulateS1: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
out = problem.evaluate(inputs=inputs())
print(f"({solver.name}) Time problem.evaluate: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
out = problem.evaluateS1(inputs=inputs())
print(f"({solver.name}) Time problem.evaluateS1: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
out = cost(inputs(), calculate_grad=False)
print(f"({solver.name}) Time PyBOP Cost w/o grad: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
out = cost(inputs(), calculate_grad=True)
print(f"({solver.name}) Time PyBOP Cost w/grad: {time.time() - start_time:.4f}")

# Recreate for Jax IDAKLU solver
ida_solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6)
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=ida_solver)
problem = pybop.FittingProblem(model, parameters, dataset)
cost = pybop.JaxSumSquaredError(problem)

# Jaxified benchmarks
start_time = time.time()
for _i in range(n):
out = cost(inputs(), calculate_grad=False)
print(f"Time Jax SumSquaredError w/o grad: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
out = cost(inputs(), calculate_grad=True)
print(f"Time Jax SumSquaredError w/ grad: {time.time() - start_time:.4f}")
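
The loops above time each call with `time.time()` deltas. A small helper (a sketch, not part of this PR, standard library only) based on the higher-resolution `time.perf_counter` would tighten the repeated pattern:

import time
from typing import Callable

def benchmark(label: str, fn: Callable[[], object], n: int = 1) -> None:
    # Run fn() n times and report the total wall-clock time
    start = time.perf_counter()
    for _ in range(n):
        fn()
    print(f"{label}: {time.perf_counter() - start:.4f}")

# Example: benchmark("problem.evaluate", lambda: problem.evaluate(inputs=inputs()), n)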
11 changes: 7 additions & 4 deletions examples/scripts/maximum_likelihood.py
@@ -1,4 +1,5 @@
import numpy as np
+import pybamm

import pybop

@@ -7,10 +8,12 @@
parameter_set.update(
{
"Negative electrode active material volume fraction": 0.63,
-"Positive electrode active material volume fraction": 0.51,
+"Positive electrode active material volume fraction": 0.62,
}
)
-model = pybop.lithium_ion.SPM(parameter_set=parameter_set)
+options = {"max_num_steps": int(1e6), "max_error_test_failures": 60}
+solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6, options=options)
+model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=solver)

# Fitting parameters
parameters = pybop.Parameters(
@@ -57,8 +60,8 @@ def noise(sigma):
signal = ["Voltage [V]", "Bulk open-circuit voltage [V]"]
# Generate problem, cost function, and optimisation class
problem = pybop.FittingProblem(model, parameters, dataset, signal=signal)
-likelihood = pybop.GaussianLogLikelihood(problem, sigma0=sigma * 4)
-optim = pybop.IRPropMin(
+likelihood = pybop.JaxGaussianLogLikelihoodKnownSigma(problem, sigma0=sigma)
+optim = pybop.XNES(
likelihood,
max_unchanged_iterations=20,
min_iterations=20,
5 changes: 5 additions & 0 deletions pybop/__init__.py
@@ -117,6 +117,11 @@
)
from .costs._weighted_cost import WeightedCost

#
# Experimental
#
from .experimental import BaseJaxCost, JaxSumSquaredError, JaxLogNormalLikelihood, JaxGaussianLogLikelihoodKnownSigma

#
# Optimiser classes
#
1 change: 1 addition & 0 deletions pybop/experimental/__init__.py
@@ -0,0 +1 @@
from .jax_costs import BaseJaxCost, JaxLogNormalLikelihood, JaxSumSquaredError, JaxGaussianLogLikelihoodKnownSigma
166 changes: 166 additions & 0 deletions pybop/experimental/jax_costs.py
@@ -0,0 +1,166 @@
from typing import Union

import jax
import jax.numpy as jnp
import numpy as np
from pybamm import IDAKLUSolver

from pybop import BaseCost, BaseLikelihood, BaseProblem, Inputs


class BaseJaxCost(BaseCost):
"""
Base class for Jax-based cost functions. Wraps the model's IDAKLU
solver in a Jax expression for automatic differentiation.
"""

def __init__(self, problem: BaseProblem):
super().__init__(problem)
self.model = self.problem.model
self.n_data = self.problem.n_data
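# If the model uses PyBaMM's IDAKLU solver, wrap it in a Jax
# expression (jaxify) so the cost can be differentiated with Jax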
if isinstance(self.model.solver, IDAKLUSolver):
self.model.jaxify_solver(t_eval=self.problem.domain_data)

def __call__(
self,
inputs: Inputs,
calculate_grad: bool = False,
apply_transform: bool = False,
) -> Union[np.ndarray, tuple[float, np.ndarray]]:
"""
Computes the cost, and optionally its gradient, for the given inputs.

Parameters
----------
inputs : Inputs
The parameter inputs at which to evaluate the cost.
calculate_grad : bool, optional
Whether to also return the gradient with respect to the inputs.
apply_transform : bool, optional
Whether to apply a parameter transformation to the inputs.

Returns
-------
np.ndarray or tuple[float, np.ndarray]
The cost, and the gradient if `calculate_grad` is True.
"""
inputs = self.parameters.verify(inputs)
if calculate_grad != self.model.calculate_sensitivities:
self._update_solver_sensitivities(calculate_grad)

if calculate_grad:
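# value_and_grad computes the cost and its gradient together via Jax autodiff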
y, dy = jax.value_and_grad(self.evaluate)(inputs)
return y, np.asarray(
list(dy.values())
) # Convert grad to numpy for optimisers
else:
return np.asarray(self.evaluate(inputs))

def _update_solver_sensitivities(self, calculate_grad: bool) -> None:
"""
Updates the solver's sensitivity calculation based on the gradient requirement.

Parameters
----------
calculate_grad : bool
Whether gradient calculation is required.
"""

self.model.jaxify_solver(
t_eval=self.problem.domain_data, calculate_sensitivities=calculate_grad
)

@staticmethod
def check_sigma0(sigma0):
if not isinstance(sigma0, (int, float)) or sigma0 <= 0:
raise ValueError("sigma0 must be a positive number")
return float(sigma0)

def observed_fisher(self, inputs: Inputs):
"""
Compute the observed Fisher information matrix (FIM)
for the given inputs. This uses the squared gradient,
as the Hessian is not available.
"""
_, grad = self.__call__(inputs, calculate_grad=True)
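# Empirical (outer-product) approximation: the FIM diagonal is grad**2 / n_data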
return jnp.square(grad) / self.n_data


class JaxSumSquaredError(BaseJaxCost):
"""
Jax-based Sum of Squared Error cost function.
"""

def __init__(self, problem: BaseProblem):
super().__init__(problem)

def evaluate(self, inputs):
# Calculate residuals and error
y = self.problem.evaluate(inputs)
r = jnp.asarray([y[s] - self._target[s] for s in self.signal])
return jnp.sum(r**2)


class JaxLogNormalLikelihood(BaseJaxCost, BaseLikelihood):
"""
A log-normal likelihood function. This assumes the observed
data are sampled from a log-normal distribution.

Parameters
----------
problem : BaseProblem
The problem to fit, of type `pybop.BaseProblem`.
sigma0 : float, optional
The standard deviation of the measured data.
"""

def __init__(self, problem: BaseProblem, sigma0=0.02):
super().__init__(problem)
self.sigma = self.check_sigma0(sigma0)
self.sigma2 = jnp.square(self.sigma)
self._offset = 0.5 * self.n_data * jnp.log(2 * jnp.pi)
self._target_as_array = jnp.asarray([self._target[s] for s in self.signal])
self._log_target_sum = jnp.sum(jnp.log(self._target_as_array))
self._precompute()

def _precompute(self):
self._constant_term = (
-self._offset - self.n_data * jnp.log(self.sigma) - self._log_target_sum
)

def evaluate(self, inputs):
"""
Evaluates the log-normal likelihood.
"""
y = self.problem.evaluate(inputs)
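# Together with the `_precompute` constant, this computes
#   log L = -(n/2) log(2*pi) - n log(sigma) - sum(log t)
#           - sum((log y - log t)**2) / (2 * sigma**2)
# where t is the observed target and y is the model prediction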
e = jnp.asarray([jnp.log(y[s]) - jnp.log(self._target[s]) for s in self.signal])
likelihood = self._constant_term - jnp.sum(jnp.square(e)) / (2 * self.sigma2)
return likelihood


class JaxGaussianLogLikelihoodKnownSigma(BaseJaxCost, BaseLikelihood):
"""
A Jax implementation of the Gaussian log-likelihood function.
This assumes the observed data are sampled from a Gaussian
distribution with known noise standard deviation `sigma0`.

Parameters
----------
problem : BaseProblem
The problem to fit, of type `pybop.BaseProblem`.
sigma0 : float, optional
The standard deviation of the measured data.
"""

def __init__(self, problem: BaseProblem, sigma0=0.02):
super().__init__(problem)
self.sigma = self.check_sigma0(sigma0)
self.sigma2 = jnp.square(self.sigma)
self._offset = -0.5 * self.n_data * jnp.log(2 * jnp.pi * self.sigma2)
self._multip = -1 / (2.0 * self.sigma2)

def evaluate(self, inputs):
"""
Evaluates the Gaussian log-likelihood.
"""
y = self.problem.evaluate(inputs)
e = jnp.asarray([y[s] - self._target[s] for s in self.signal])
likelihood = jnp.sum(self._offset + self._multip * jnp.sum(jnp.square(e)))
return likelihood

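A short usage sketch of these experimental costs (assuming `problem` is a `pybop.FittingProblem` built on a model with an IDAKLU solver, and `inputs` is a parameter dictionary as in the scripts above):

cost = pybop.JaxSumSquaredError(problem)

# Cost only: the IDAKLU solver is jaxified once at construction
value = cost(inputs)

# Cost and gradient: BaseJaxCost.__call__ returns a (value, gradient) tuple
# computed via jax.value_and_grad
value, grad = cost(inputs, calculate_grad=True)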