Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce utils.simulate() #348

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions enterprise/signals/gp_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from sksparse.cholmod import cholesky

from enterprise.signals import parameter, selections, signal_base, utils
from enterprise.signals.signal_base import KernelMatrix
from enterprise.signals.parameter import function
from enterprise.signals.selections import Selection
from enterprise.signals.utils import KernelMatrix

# logging.basicConfig(format="%(levelname)s: %(name)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -225,11 +225,11 @@ def get_timing_model_basis(use_svd=False, normed=True):
return utils.unnormed_tm_basis()


def TimingModel(coefficients=False, name="linear_timing_model", use_svd=False, normed=True):
def TimingModel(coefficients=False, name="linear_timing_model", use_svd=False, normed=True, prior_variance=1e40):
"""Class factory for marginalized linear timing model signals."""

basis = get_timing_model_basis(use_svd, normed)
prior = utils.tm_prior()
prior = utils.tm_prior(variance=prior_variance)

BaseClass = BasisGP(prior, basis, coefficients=coefficients, name=name)

Expand Down Expand Up @@ -848,6 +848,7 @@ def MNMMNF(self, T):

# we're ignoring logdet = True for two-dimensional cases, but OK
def solve(self, right, left_array=None, logdet=False):
# compute generalized version of r+ N^-1 r
if right.ndim == 1 and left_array is right:
res = right

Expand All @@ -856,11 +857,13 @@ def solve(self, right, left_array=None, logdet=False):
MNr = self.MNr(res)
ret = rNr - np.dot(MNr, self.cf(MNr))
return (ret, logdet_N + self.cf.logdet() + self.Mprior) if logdet else ret
# compute generalized version of T+ N^-1 r
elif right.ndim == 1 and left_array is not None and left_array.ndim == 2:
res, T = right, left_array

TNr = self.Nmat.solve(res, left_array=T)
return TNr - np.tensordot(self.MNMMNF(T), self.MNr(res), (0, 0))
# compute generalized version of T+ N^-1 T
elif right.ndim == 2 and left_array is right:
T = right

Expand Down
84 changes: 83 additions & 1 deletion enterprise/signals/signal_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from enterprise.signals.parameter import Function # noqa: F401
from enterprise.signals.parameter import function # noqa: F401
from enterprise.signals.parameter import ConstantParameter
from enterprise.signals.utils import KernelMatrix

from enterprise import __version__
from sys import version
Expand Down Expand Up @@ -1212,6 +1211,89 @@ def solve(self, other, left_array=None, logdet=False):
return (ret, self._get_logdet()) if logdet else ret


class KernelMatrix(np.ndarray):
def __new__(cls, init):
if isinstance(init, int):
ret = np.zeros(init, "d").view(cls)
else:
ret = init.view(cls)

if ret.ndim == 2:
ret._cliques = -1 * np.ones(ret.shape[0])
ret._clcount = 0

return ret

# see PTA._setcliques
def _setcliques(self, idxs):
allidx = set(self._cliques[idxs])
maxidx = max(allidx)

if maxidx == -1:
self._cliques[idxs] = self._clcount
self._clcount = self._clcount + 1
else:
self._cliques[idxs] = maxidx
if len(allidx) > 1:
self._cliques[np.in1d(self._cliques, allidx)] = maxidx

def add(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] += other
else:
if other.ndim == 1:
self[idx, idx] += other
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] += other

return self

def set(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] = other
else:
if other.ndim == 1:
self[idx, idx] = other
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] = other

return self

def inv(self, logdet=False):
if self.ndim == 1:
inv = 1.0 / self

if logdet:
return inv, np.sum(np.log(self))
else:
return inv
else:
try:
cf = sl.cho_factor(self)
inv = sl.cho_solve(cf, np.identity(cf[0].shape[0]))
if logdet:
ld = 2.0 * np.sum(np.log(np.diag(cf[0])))
except np.linalg.LinAlgError:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a # pragma: no cover here?

u, s, v = np.linalg.svd(self)
inv = np.dot(u / s, u.T)
if logdet:
ld = np.sum(np.log(s))
if logdet:
return inv, ld
else:
return inv


class ShermanMorrison(object):
"""Custom container class for Sherman-morrison array inversion."""

Expand Down
155 changes: 70 additions & 85 deletions enterprise/signals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import enterprise
from enterprise import constants as const
from enterprise import signals as sigs # noqa: F401
from enterprise.signals.signal_base import ShermanMorrison
from enterprise.signals.gp_bases import ( # noqa: F401
createfourierdesignmatrix_dm,
createfourierdesignmatrix_env,
Expand All @@ -31,6 +32,73 @@
logger = logging.getLogger(__name__)


def simulate(pta, params, sparse_cholesky=True):
"""Simulate residuals for all pulsars in `pta` by sampling all white-noise
and GP objects for parameters `params`. Requires GPs to have `combine=False`,
and will run faster with GP ECORR. If `pta` includes a `TimingModel`, that
should be created with a small `prior_variance`. This function can be used
with `utils.set_residuals` to replace residuals in a `Pulsar` object.
Note that any PTA built from that `Pulsar` may nevertheless cache residuals
internally, so it is safer to rebuild the PTA with the modified `Pulsar`."""

delays, ndiags, fmats, phis = (
pta.get_delay(params=params),
pta.get_ndiag(params=params),
pta.get_basis(params=params),
pta.get_phi(params=params),
)

gpresiduals = []
if pta._commonsignals:
if sparse_cholesky:
cf = cholesky(sps.csc_matrix(phis))
gp = np.zeros(phis.shape[0])
gp[cf.P()] = np.dot(cf.L().toarray(), np.random.randn(phis.shape[0]))
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a #pragma: no cover

gp = np.dot(sl.cholesky(phis, lower=True), np.random.randn(phis.shape[0]))

i = 0
for fmat in fmats:
j = i + fmat.shape[1]
gpresiduals.append(np.dot(fmat, gp[i:j]))
i = j

assert len(gp) == i
else:
for fmat, phi in zip(fmats, phis):
if phi is None:
gpresiduals.append(0)
elif phi.ndim == 1:
gpresiduals.append(np.dot(fmat, np.sqrt(phi) * np.random.randn(phi.shape[0])))
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a # pragma: no cover here

raise NotImplementedError

whiteresiduals = []
for delay, ndiag in zip(delays, ndiags):
if ndiag is None:
whiteresiduals.append(0)
elif isinstance(ndiag, ShermanMorrison):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a good idea to check for an instance of ShermanMorrison. When 'fastshermanmorrison' is used, this will be a different type. Instead, perhaps duck typing can be used, like:

if all(hasattr(ndiag, attr) for attr in ['_nvec', '_jvec', '_slices']):

# this code is very slow...
n = np.diag(ndiag._nvec)
for j, s in zip(ndiag._jvec, ndiag._slices):
n[s, s] += j
whiteresiduals.append(delay + np.dot(sl.cholesky(n, lower=True), np.random.randn(n.shape[0])))
elif ndiag.ndim == 1:
whiteresiduals.append(delay + np.sqrt(ndiag) * np.random.randn(ndiag.shape[0]))
else:
raise NotImplementedError

return [np.array(g + w) for g, w in zip(gpresiduals, whiteresiduals)]


def set_residuals(psr, y):
if isinstance(psr, list):
for p, r in zip(psr, y):
p._residuals[p._isort] = r
else:
psr._residuals[psr._isort] = y


class ConditionalGP:
def __init__(self, pta, phiinv_method="cliques"):
"""This class allows the computation of conditional means and
Expand Down Expand Up @@ -208,89 +276,6 @@ def get_coefficients(pta, params, n=1, phiinv_method="cliques", variance=True, c
return ret[0] if n == 1 else ret


class KernelMatrix(np.ndarray):
def __new__(cls, init):
if isinstance(init, int):
ret = np.zeros(init, "d").view(cls)
else:
ret = init.view(cls)

if ret.ndim == 2:
ret._cliques = -1 * np.ones(ret.shape[0])
ret._clcount = 0

return ret

# see PTA._setcliques
def _setcliques(self, idxs):
allidx = set(self._cliques[idxs])
maxidx = max(allidx)

if maxidx == -1:
self._cliques[idxs] = self._clcount
self._clcount = self._clcount + 1
else:
self._cliques[idxs] = maxidx
if len(allidx) > 1:
self._cliques[np.in1d(self._cliques, allidx)] = maxidx

def add(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] += other
else:
if other.ndim == 1:
self[idx, idx] += other
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] += other

return self

def set(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] = other
else:
if other.ndim == 1:
self[idx, idx] = other
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] = other

return self

def inv(self, logdet=False):
if self.ndim == 1:
inv = 1.0 / self

if logdet:
return inv, np.sum(np.log(self))
else:
return inv
else:
try:
cf = sl.cho_factor(self)
inv = sl.cho_solve(cf, np.identity(cf[0].shape[0]))
if logdet:
ld = 2.0 * np.sum(np.log(np.diag(cf[0])))
except np.linalg.LinAlgError:
u, s, v = np.linalg.svd(self)
inv = np.dot(u / s, u.T)
if logdet:
ld = np.sum(np.log(s))
if logdet:
return inv, ld
else:
return inv


def create_stabletimingdesignmatrix(designmat, fastDesign=True):
"""
Stabilize the timing-model design matrix.
Expand Down Expand Up @@ -885,8 +870,8 @@ def svd_tm_basis(Mmat):


@function
def tm_prior(weights):
return weights * 1e40
def tm_prior(weights, variance=1e40):
return weights * variance


# Physical ephemeris model utility functions
Expand Down
25 changes: 24 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import enterprise.constants as const
from enterprise.pulsar import Pulsar
from enterprise.signals import utils, parameter, signal_base, white_signals, gp_signals
from enterprise.signals import anis_coefficients as anis
from enterprise.signals import utils
from tests.enterprise_test_data import datadir


Expand All @@ -26,6 +26,7 @@ def setUpClass(cls):

# initialize Pulsar class
cls.psr = Pulsar(datadir + "/B1855+09_NANOGrav_9yv1.gls.par", datadir + "/B1855+09_NANOGrav_9yv1.tim")
cls.psr2 = Pulsar(datadir + "/J1909-3744_NANOGrav_9yv1.gls.par", datadir + "/J1909-3744_NANOGrav_9yv1.tim")

cls.F, _ = utils.createfourierdesignmatrix_red(cls.psr.toas, nmodes=30)

Expand All @@ -35,6 +36,28 @@ def setUpClass(cls):

cls.Mm = utils.create_stabletimingdesignmatrix(cls.psr.Mmat)

def test_simulate(self):
ef = white_signals.MeasurementNoise()

ec = gp_signals.EcorrBasisModel()

pl = utils.powerlaw(log10_A=parameter.Uniform(-16, -13), gamma=parameter.Uniform(1, 7))
orf = utils.hd_orf()
crn = gp_signals.FourierBasisCommonGP(pl, orf, components=20, name="GW")

m = ef + ec + crn

pta = signal_base.PTA([m(self.psr), m(self.psr2)])

ys = utils.simulate(pta, params=parameter.sample(pta.params))

msg = "Simulated residuals shape incorrect"
assert ys[0].shape == self.psr.residuals.shape, msg
assert ys[1].shape == self.psr2.residuals.shape, msg

msg = "Simulated residuals shape not a number"
assert np.all(~np.isnan(np.concatenate(ys))), msg

def test_createstabletimingdesignmatrix(self):
"""Timing model design matrix shape."""

Expand Down