feat: Add Cramér-Rao uncertainties + covariance using autodiff to non-minuit fits by default #2269

Open · wants to merge 17 commits into main
2 changes: 2 additions & 0 deletions .gitignore
@@ -44,3 +44,5 @@ htmlcov

# text editors
.vscode/

venv/
114 changes: 100 additions & 14 deletions src/pyhf/optimize/mixins.py
@@ -8,8 +8,11 @@
from pyhf.optimize.common import shim
from pyhf.tensor.manager import get_backend


log = logging.getLogger(__name__)

__all__ = ("OptimizerMixin",)


class OptimizerMixin:
"""Mixin Class to build optimizers."""
@@ -50,6 +53,7 @@ def _internal_minimize(
do_grad=do_grad,
par_names=par_names,
)

result = self._minimize(
minimizer,
func,
@@ -67,7 +71,18 @@
raise exceptions.FailedMinimization(result)
return result

-def _internal_postprocess(self, fitresult, stitch_pars, return_uncertainties=False):
def _internal_postprocess(
self,
fitresult,
stitch_pars,
using_minuit,
return_uncertainties=False,
uncertainties=None,
hess_inv=None,
calc_correlations=False,
fixed_vals=None,
init_pars=None,
):
"""
Post-process the fit result.

@@ -80,17 +95,29 @@ def _internal_postprocess(self, fitresult, stitch_pars, return_uncertainties=False):
fitted_pars = stitch_pars(tensorlib.astensor(fitresult.x))

# check if uncertainties were provided (and stitch just in case)
-uncertainties = getattr(fitresult, 'unc', None)
# prefer uncertainties from the fit result, falling back to any that were passed in;
# an explicit None check avoids truthiness evaluation on array-like values
fit_unc = getattr(fitresult, 'unc', None)
uncertainties = fit_unc if fit_unc is not None else uncertainties
if uncertainties is not None:
# extract number of fixed parameters
num_fixed_pars = len(fitted_pars) - len(fitresult.x)

-# FIXME: Set uncertainties for fixed parameters to 0 manually
-# https://github.com/scikit-hep/iminuit/issues/762
-# https://github.com/scikit-hep/pyhf/issues/1918
-# https://github.com/scikit-hep/cabinetry/pull/346
-uncertainties = np.where(fitresult.minuit.fixed, 0.0, uncertainties)

# Set uncertainties for fixed parameters to 0 manually
if fixed_vals is not None: # check for fixed vals
if using_minuit:
# See related discussion here:
# https://github.com/scikit-hep/iminuit/issues/762
# https://github.com/scikit-hep/pyhf/issues/1918
# https://github.com/scikit-hep/cabinetry/pull/346
uncertainties = np.where(fitresult.minuit.fixed, 0.0, uncertainties)
else:
# Not using minuit, so `fitresult.minuit.fixed` is not available here: build the mask manually
fixed_bools = [False] * len(init_pars)
for index, _ in fixed_vals:
fixed_bools[index] = True
uncertainties = tensorlib.where(
tensorlib.astensor(fixed_bools, dtype="bool"),
tensorlib.astensor(0.0),
uncertainties,
)
# stitch in zero-uncertainty for fixed values
uncertainties = stitch_pars(
tensorlib.astensor(uncertainties),
@@ -99,24 +126,57 @@
if return_uncertainties:
fitted_pars = tensorlib.stack([fitted_pars, uncertainties], axis=1)
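For illustration, a minimal standalone sketch of the fixed-parameter masking above, using plain NumPy instead of a pyhf tensorlib; the parameter values and fixed indices are hypothetical:

```python
import numpy as np

# hypothetical fit with four parameters, of which 1 and 3 were held fixed
uncertainties = np.array([0.12, 0.05, 0.30, 0.08])
fixed_vals = [(1, 1.0), (3, 0.5)]  # (index, value) pairs, as in pyhf

# build the boolean mask the same way the branch above does
fixed_bools = [False] * len(uncertainties)
for index, _ in fixed_vals:
    fixed_bools[index] = True

# zero out the uncertainty of every fixed parameter
masked = np.where(np.asarray(fixed_bools), 0.0, uncertainties)
print(masked)  # [0.12 0.   0.3  0.  ]
```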

-correlations = getattr(fitresult, 'corr', None)
-if correlations is not None:
cov = getattr(fitresult, 'hess_inv', None)
if cov is None and hess_inv is not None:
cov = hess_inv

# we also need to edit the covariance matrix to zero out the entries for fixed parameters
# NOTE: minuit already does this (https://github.com/scikit-hep/iminuit/issues/762#issuecomment-1207436406)
if fixed_vals is not None and not using_minuit:
fixed_bools = [False] * len(init_pars)
# mark the fixed parameters, mirroring the uncertainty masking above
for index, _ in fixed_vals:
fixed_bools[index] = True
# Convert fixed_bools to a tensor and reshape it into a column vector
fixed_mask = tensorlib.reshape(
tensorlib.astensor(fixed_bools, dtype="bool"), (-1, 1)
)
# Create 2D masks for rows and columns
row_mask = fixed_mask
col_mask = tensorlib.transpose(fixed_mask)

# Use logical OR to combine the masks
final_mask = row_mask | col_mask

# Use tensorlib.where to set elements of the covariance matrix to 0 where the mask is True
cov = tensorlib.where(
final_mask, tensorlib.astensor(0.0), tensorlib.astensor(cov)
)
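The row-and-column masking above can be checked with a small NumPy example, here a hypothetical 3×3 covariance with parameter 1 fixed:

```python
import numpy as np

cov = np.array([[0.04, 0.01, 0.00],
                [0.01, 0.09, 0.02],
                [0.00, 0.02, 0.16]])
fixed_mask = np.array([False, True, False]).reshape(-1, 1)  # column vector

# OR-ing the column vector with its transpose marks the full row and column
final_mask = fixed_mask | fixed_mask.T
cov_masked = np.where(final_mask, 0.0, cov)
print(cov_masked)
# [[0.04 0.   0.  ]
#  [0.   0.   0.  ]
#  [0.   0.   0.16]]
```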

correlations_from_fit = getattr(fitresult, 'corr', None)
if correlations_from_fit is None and calc_correlations:
correlations_from_fit = cov / tensorlib.outer(uncertainties, uncertainties)
correlations_from_fit = tensorlib.where(
tensorlib.isfinite(correlations_from_fit),
correlations_from_fit,
tensorlib.astensor(0.0),
)
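The division above implements `corr_ij = cov_ij / (sigma_i * sigma_j)`; for fixed parameters both sigmas are zero, so the `isfinite` guard maps the resulting `nan`/`inf` entries back to 0. A self-contained sketch, reusing the masked covariance from the previous example:

```python
import numpy as np

cov_masked = np.array([[0.04, 0.0, 0.0],
                       [0.0,  0.0, 0.0],
                       [0.0,  0.0, 0.16]])
unc = np.sqrt(np.diag(cov_masked))  # [0.2, 0.0, 0.4]

with np.errstate(divide="ignore", invalid="ignore"):
    corr = cov_masked / np.outer(unc, unc)  # 0/0 -> nan for fixed parameters
corr = np.where(np.isfinite(corr), corr, 0.0)
print(corr)
# [[1. 0. 0.]
#  [0. 0. 0.]
#  [0. 0. 1.]]
```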

if correlations_from_fit is not None and not using_minuit:
_zeros = tensorlib.zeros(num_fixed_pars)
# NOTE: there may be a more elegant way to do this
stitched_columns = [
stitch_pars(tensorlib.astensor(column), stitch_with=_zeros)
-for column in zip(*correlations)
for column in zip(*correlations_from_fit)
]
stitched_rows = [
stitch_pars(tensorlib.astensor(row), stitch_with=_zeros)
for row in zip(*stitched_columns)
]
-correlations = tensorlib.stack(stitched_rows, axis=1)
correlations_from_fit = tensorlib.stack(stitched_rows, axis=1)

fitresult.x = fitted_pars
fitresult.fun = tensorlib.astensor(fitresult.fun)
fitresult.unc = uncertainties
-fitresult.corr = correlations
fitresult.hess_inv = cov
fitresult.corr = correlations_from_fit

return fitresult

@@ -164,6 +224,10 @@ def minimize(
- minimum (:obj:`float`): if ``return_fitted_val`` flagged, return minimized objective value
- result (:class:`scipy.optimize.OptimizeResult`): if ``return_result_obj`` flagged
"""
# record whether the minimizer is minuit, since minuit already provides
# uncertainties and the autodiff-based calculation below is then skipped
using_minuit = hasattr(self, "name") and self.name == "minuit"

# Configure do_grad based on backend "automagically" if not set by user
tensorlib, _ = get_backend()
do_grad = tensorlib.default_do_grad if do_grad is None else do_grad
@@ -194,8 +258,30 @@
result = self._internal_minimize(
**minimizer_kwargs, options=kwargs, par_names=par_names
)

# compute uncertainties with automatic differentiation
if not using_minuit and tensorlib.name in ['tensorflow', 'jax', 'pytorch']:
# stitch in missing parameters (e.g. fixed parameters)
all_pars = stitch_pars(tensorlib.astensor(result.x))
hess_inv = tensorlib.fisher_cov(pdf, all_pars, data)
uncertainties = tensorlib.sqrt(tensorlib.diagonal(hess_inv))
calc_correlations = True
else:
hess_inv = None
uncertainties = None
calc_correlations = False
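As a sanity check of the recipe above (covariance from the inverse of the negative Hessian of the log-likelihood at the best fit, uncertainties from the square roots of its diagonal), a toy Gaussian likelihood in JAX; the model here is hypothetical, not a pyhf model:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def loglike(pars, data):
    mu, sigma = pars  # toy model: Gaussian with free mean and width
    return jnp.sum(-0.5 * ((data - mu) / sigma) ** 2 - jnp.log(sigma))

data = jnp.array([9.8, 10.1, 10.3, 9.9, 10.0])
mle = jnp.array([jnp.mean(data), jnp.std(data)])  # analytic MLE for this toy

# covariance = inverse of the negative Hessian of log L at the MLE
hess_inv = jnp.linalg.inv(-jax.hessian(loglike)(mle, data))
uncertainties = jnp.sqrt(jnp.diag(hess_inv))
print(uncertainties)  # [sigma/sqrt(n), sigma/sqrt(2n)], the Cramér-Rao values
```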

# uncertainties for fixed parameters are set to 0 inside _internal_postprocess
result = self._internal_postprocess(
-result, stitch_pars, return_uncertainties=return_uncertainties
result,
stitch_pars,
using_minuit,
return_uncertainties=return_uncertainties,
uncertainties=uncertainties,
hess_inv=hess_inv,
calc_correlations=calc_correlations,
fixed_vals=fixed_vals,
init_pars=init_pars,
)

_returns = [result.x]
2 changes: 2 additions & 0 deletions src/pyhf/optimize/opt_minuit.py
@@ -5,6 +5,8 @@
import scipy
import iminuit

__all__ = ("minuit_optimizer",)


class minuit_optimizer(OptimizerMixin):
"""
40 changes: 39 additions & 1 deletion src/pyhf/tensor/jax_backend.py
@@ -2,7 +2,7 @@

config.update('jax_enable_x64', True)

-from jax import Array
from jax import Array, hessian
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy
from jax.scipy import special
@@ -622,3 +622,41 @@ def transpose(self, tensor_in):
.. versionadded:: 0.7.0
"""
return tensor_in.transpose()

def fisher_cov(self, model, pars, data):
"""Calculates the inverse of the Fisher information matrix to estimate the covariance of the maximum likelihood estimate.
See the Cramér-Rao bound for more details on the derivation of this.

Args:
model (:obj:`pyhf.pdf.Model`): The statistical model adhering to the schema ``model.json``.
pars (:obj:`tensor`): The best-fit (MLE) model parameters at which to evaluate the uncertainty.
data (:obj:`tensor`): The observed data.

Returns:
JAX ndarray: The covariance matrix of the maximum likelihood estimate.
"""
return jnp.linalg.inv(
-hessian(lambda pars, data: model.logpdf(pars, data)[0])(pars, data)
)
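A usage sketch for the new backend method, assuming the `fisher_cov` and `diagonal` API added in this PR; the model and observed counts are illustrative:

```python
import pyhf

pyhf.set_backend("jax")
model = pyhf.simplemodels.uncorrelated_background(
    signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
)
data = pyhf.tensorlib.astensor([62.0, 63.0] + model.config.auxdata)

# best-fit parameters, then Cramér-Rao covariance and per-parameter uncertainties
bestfit = pyhf.infer.mle.fit(data, model)
cov = pyhf.tensorlib.fisher_cov(model, bestfit, data)
uncertainties = pyhf.tensorlib.sqrt(pyhf.tensorlib.diagonal(cov))
```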

def diagonal(self, tensor_in):
"""
Return the diagonal elements of the tensor.

Example:
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 0.0], [0.0, 1.0]])
>>> tensor
Array([[1., 0.],
[0., 1.]], dtype=float64)
>>> pyhf.tensorlib.diagonal(tensor)
Array([1., 1.], dtype=float64)

Args:
tensor_in (:obj:`tensor`): The input tensor object.

Returns:
JAX ndarray: The diagonal of the input tensor.
"""
return jnp.diag(tensor_in)
35 changes: 34 additions & 1 deletion src/pyhf/tensor/numpy_backend.py
@@ -3,7 +3,16 @@
from __future__ import annotations

import logging
-from typing import TYPE_CHECKING, Callable, Generic, Mapping, Sequence, TypeVar, Union
from typing import (
TYPE_CHECKING,
Callable,
Generic,
Mapping,
Sequence,
TypeVar,
Union,
Any,
)

import numpy as np

@@ -648,3 +657,27 @@ def transpose(self, tensor_in: Tensor[T]) -> ArrayLike:
.. versionadded:: 0.7.0
"""
return tensor_in.transpose()

def fisher_cov(self, model: Any, pars: Tensor[T], data: Tensor[T]) -> ArrayLike:
# NumPy provides no automatic differentiation, so the Hessian of the
# log-likelihood (and with it the Cramér-Rao covariance) is not available;
# use the jax, pytorch, or tensorflow backend, or the minuit optimizer.
raise NotImplementedError

def diagonal(self, tensor_in: Tensor[T]) -> ArrayLike:
"""Return the diagonal elements of the tensor.

Example:
>>> import pyhf
>>> pyhf.set_backend("numpy")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 0.0], [0.0, 1.0]])
>>> tensor
array([[1., 0.],
[0., 1.]])
>>> pyhf.tensorlib.diagonal(tensor)
array([1., 1.])

Args:
tensor_in (:obj:`tensor`): The input tensor object.

Returns:
:class:`numpy.ndarray`: The diagonal of the input tensor.
"""
return np.diag(tensor_in)
39 changes: 39 additions & 0 deletions src/pyhf/tensor/pytorch_backend.py
@@ -2,6 +2,7 @@

import torch
import torch.autograd
from torch.func import hessian
from torch.distributions.utils import broadcast_all
import logging
import math
@@ -626,3 +627,41 @@ def transpose(self, tensor_in):
.. versionadded:: 0.7.0
"""
return tensor_in.transpose(0, 1)

def fisher_cov(self, model, pars, data):
"""Calculates the inverse of the Fisher information matrix to estimate the covariance of the maximum likelihood estimate.
See the Cramér-Rao bound for more details on the derivation of this.

Args:
model (:obj:`pyhf.pdf.Model`): The statistical model adhering to the schema ``model.json``.
pars (:obj:`tensor`): The best-fit (MLE) model parameters at which to evaluate the uncertainty.
data (:obj:`tensor`): The observed data.

Returns:
PyTorch FloatTensor: The covariance matrix of the maximum likelihood estimate.
"""
return torch.linalg.inv(
-hessian(lambda pars, data: model.logpdf(pars, data)[0])(pars, data)
)
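A minimal sketch of `torch.func.hessian`, the primitive the method above builds on, applied to a toy quadratic rather than a pyhf model:

```python
import torch
from torch.func import hessian

def f(x):
    return (x ** 2).sum()  # Hessian is 2 * identity

x = torch.tensor([1.0, 2.0])
print(hessian(f)(x))
# tensor([[2., 0.],
#         [0., 2.]])
```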

def diagonal(self, tensor_in):
"""
Return the diagonal elements of the tensor.

Example:
>>> import pyhf
>>> pyhf.set_backend("pytorch")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 0.0], [0.0, 1.0]])
>>> tensor
tensor([[1., 0.],
[0., 1.]])
>>> pyhf.tensorlib.diagonal(tensor)
tensor([1., 1.])

Args:
tensor_in (:obj:`tensor`): The input tensor object.

Returns:
PyTorch FloatTensor: The diagonal of the input tensor.
"""
return torch.diagonal(tensor_in)
44 changes: 44 additions & 0 deletions src/pyhf/tensor/tensorflow_backend.py
@@ -723,3 +723,47 @@ def transpose(self, tensor_in):
.. versionadded:: 0.7.0
"""
return tf.transpose(tensor_in)

def fisher_cov(self, model, pars, data):
"""Calculates the inverse of the Fisher information matrix to estimate the covariance of the maximum likelihood estimate.
See the Cramér-Rao bound for more details on the derivation of this.

Args:
model (:obj:`pyhf.pdf.Model`): The statistical model adhering to the schema ``model.json``.
pars (:obj:`tensor`): The best-fit (MLE) model parameters at which to evaluate the uncertainty.
data (:obj:`tensor`): The observed data.

Returns:
TensorFlow Tensor: The covariance matrix of the maximum likelihood estimate.
"""
with tf.GradientTape() as t2:
t2.watch(pars)
with tf.GradientTape() as t1:
t1.watch(pars)
lhood = model.logpdf(pars, data)[0]
g = t1.gradient(lhood, pars)
hess = t2.jacobian(g, pars)
return tf.linalg.inv(-hess)
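The method above obtains the Hessian by nesting two `GradientTape`s: the inner tape yields the gradient and the outer tape takes its Jacobian. A standalone sketch of the same pattern on a toy function (function and values are hypothetical):

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0], dtype=tf.float64)

with tf.GradientTape() as t2:
    t2.watch(x)
    with tf.GradientTape() as t1:
        t1.watch(x)
        f = x[0] ** 2 + 3.0 * x[0] * x[1]  # Hessian is [[2, 3], [3, 0]]
    # the gradient must be computed inside t2's context so t2 can trace it
    g = t1.gradient(f, x)
hess = t2.jacobian(g, x)
print(hess.numpy())  # [[2. 3.]
                     #  [3. 0.]]
```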

def diagonal(self, tensor_in):
"""Return the diagonal elements of the tensor.

Example:
>>> import pyhf
>>> pyhf.set_backend("tensorflow")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 0.0], [0.0, 1.0]])
>>> tensor
<tf.Tensor: shape=(2, 2), dtype=float64, numpy=
array([[1., 0.],
[0., 1.]])>
>>> pyhf.tensorlib.diagonal(tensor)
<tf.Tensor: shape=(2,), dtype=float64, numpy=array([1., 1.])>

Args:
tensor_in (:obj:`tensor`): The input tensor object.

Returns:
TensorFlow Tensor: The diagonal elements of the input tensor.

"""
return tf.linalg.diag_part(tensor_in)