diff --git a/enterprise/signals/gp_signals.py b/enterprise/signals/gp_signals.py index 443927d9..0caf179e 100644 --- a/enterprise/signals/gp_signals.py +++ b/enterprise/signals/gp_signals.py @@ -13,9 +13,9 @@ from sksparse.cholmod import cholesky from enterprise.signals import parameter, selections, signal_base, utils +from enterprise.signals.signal_base import KernelMatrix from enterprise.signals.parameter import function from enterprise.signals.selections import Selection -from enterprise.signals.utils import KernelMatrix # logging.basicConfig(format="%(levelname)s: %(name)s: %(message)s", level=logging.INFO) logger = logging.getLogger(__name__) @@ -225,11 +225,11 @@ def get_timing_model_basis(use_svd=False, normed=True): return utils.unnormed_tm_basis() -def TimingModel(coefficients=False, name="linear_timing_model", use_svd=False, normed=True): +def TimingModel(coefficients=False, name="linear_timing_model", use_svd=False, normed=True, prior_variance=1e40): """Class factory for marginalized linear timing model signals.""" basis = get_timing_model_basis(use_svd, normed) - prior = utils.tm_prior() + prior = utils.tm_prior(variance=prior_variance) BaseClass = BasisGP(prior, basis, coefficients=coefficients, name=name) @@ -848,6 +848,7 @@ def MNMMNF(self, T): # we're ignoring logdet = True for two-dimensional cases, but OK def solve(self, right, left_array=None, logdet=False): + # compute generalized version of r+ N^-1 r if right.ndim == 1 and left_array is right: res = right @@ -856,11 +857,13 @@ def solve(self, right, left_array=None, logdet=False): MNr = self.MNr(res) ret = rNr - np.dot(MNr, self.cf(MNr)) return (ret, logdet_N + self.cf.logdet() + self.Mprior) if logdet else ret + # compute generalized version of T+ N^-1 r elif right.ndim == 1 and left_array is not None and left_array.ndim == 2: res, T = right, left_array TNr = self.Nmat.solve(res, left_array=T) return TNr - np.tensordot(self.MNMMNF(T), self.MNr(res), (0, 0)) + # compute generalized version of T+ N^-1 T elif right.ndim == 2 and left_array is right: T = right diff --git a/enterprise/signals/signal_base.py b/enterprise/signals/signal_base.py index e0fd3d32..35aaa37f 100644 --- a/enterprise/signals/signal_base.py +++ b/enterprise/signals/signal_base.py @@ -24,7 +24,6 @@ from enterprise.signals.parameter import Function # noqa: F401 from enterprise.signals.parameter import function # noqa: F401 from enterprise.signals.parameter import ConstantParameter -from enterprise.signals.utils import KernelMatrix from enterprise import __version__ from sys import version @@ -1212,6 +1211,89 @@ def solve(self, other, left_array=None, logdet=False): return (ret, self._get_logdet()) if logdet else ret +class KernelMatrix(np.ndarray): + def __new__(cls, init): + if isinstance(init, int): + ret = np.zeros(init, "d").view(cls) + else: + ret = init.view(cls) + + if ret.ndim == 2: + ret._cliques = -1 * np.ones(ret.shape[0]) + ret._clcount = 0 + + return ret + + # see PTA._setcliques + def _setcliques(self, idxs): + allidx = set(self._cliques[idxs]) + maxidx = max(allidx) + + if maxidx == -1: + self._cliques[idxs] = self._clcount + self._clcount = self._clcount + 1 + else: + self._cliques[idxs] = maxidx + if len(allidx) > 1: + self._cliques[np.in1d(self._cliques, allidx)] = maxidx + + def add(self, other, idx): + if other.ndim == 2 and self.ndim == 1: + self = KernelMatrix(np.diag(self)) + + if self.ndim == 1: + self[idx] += other + else: + if other.ndim == 1: + self[idx, idx] += other + else: + self._setcliques(idx) + idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx) + self[idx] += other + + return self + + def set(self, other, idx): + if other.ndim == 2 and self.ndim == 1: + self = KernelMatrix(np.diag(self)) + + if self.ndim == 1: + self[idx] = other + else: + if other.ndim == 1: + self[idx, idx] = other + else: + self._setcliques(idx) + idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx) + self[idx] = other + + return self + + def inv(self, logdet=False): + if self.ndim == 1: + inv = 1.0 / self + + if logdet: + return inv, np.sum(np.log(self)) + else: + return inv + else: + try: + cf = sl.cho_factor(self) + inv = sl.cho_solve(cf, np.identity(cf[0].shape[0])) + if logdet: + ld = 2.0 * np.sum(np.log(np.diag(cf[0]))) + except np.linalg.LinAlgError: + u, s, v = np.linalg.svd(self) + inv = np.dot(u / s, u.T) + if logdet: + ld = np.sum(np.log(s)) + if logdet: + return inv, ld + else: + return inv + + class ShermanMorrison(object): """Custom container class for Sherman-morrison array inversion.""" diff --git a/enterprise/signals/utils.py b/enterprise/signals/utils.py index df949af0..797b0871 100644 --- a/enterprise/signals/utils.py +++ b/enterprise/signals/utils.py @@ -18,6 +18,7 @@ import enterprise from enterprise import constants as const from enterprise import signals as sigs # noqa: F401 +from enterprise.signals.signal_base import ShermanMorrison from enterprise.signals.gp_bases import ( # noqa: F401 createfourierdesignmatrix_dm, createfourierdesignmatrix_env, @@ -31,6 +32,73 @@ logger = logging.getLogger(__name__) +def simulate(pta, params, sparse_cholesky=True): + """Simulate residuals for all pulsars in `pta` by sampling all white-noise + and GP objects for parameters `params`. Requires GPs to have `combine=False`, + and will run faster with GP ECORR. If `pta` includes a `TimingModel`, that + should be created with a small `prior_variance`. This function can be used + with `utils.set_residuals` to replace residuals in a `Pulsar` object. + Note that any PTA built from that `Pulsar` may nevertheless cache residuals + internally, so it is safer to rebuild the PTA with the modified `Pulsar`.""" + + delays, ndiags, fmats, phis = ( + pta.get_delay(params=params), + pta.get_ndiag(params=params), + pta.get_basis(params=params), + pta.get_phi(params=params), + ) + + gpresiduals = [] + if pta._commonsignals: + if sparse_cholesky: + cf = cholesky(sps.csc_matrix(phis)) + gp = np.zeros(phis.shape[0]) + gp[cf.P()] = np.dot(cf.L().toarray(), np.random.randn(phis.shape[0])) + else: + gp = np.dot(sl.cholesky(phis, lower=True), np.random.randn(phis.shape[0])) + + i = 0 + for fmat in fmats: + j = i + fmat.shape[1] + gpresiduals.append(np.dot(fmat, gp[i:j])) + i = j + + assert len(gp) == i + else: + for fmat, phi in zip(fmats, phis): + if phi is None: + gpresiduals.append(0) + elif phi.ndim == 1: + gpresiduals.append(np.dot(fmat, np.sqrt(phi) * np.random.randn(phi.shape[0]))) + else: + raise NotImplementedError + + whiteresiduals = [] + for delay, ndiag in zip(delays, ndiags): + if ndiag is None: + whiteresiduals.append(0) + elif isinstance(ndiag, ShermanMorrison): + # this code is very slow... + n = np.diag(ndiag._nvec) + for j, s in zip(ndiag._jvec, ndiag._slices): + n[s, s] += j + whiteresiduals.append(delay + np.dot(sl.cholesky(n, lower=True), np.random.randn(n.shape[0]))) + elif ndiag.ndim == 1: + whiteresiduals.append(delay + np.sqrt(ndiag) * np.random.randn(ndiag.shape[0])) + else: + raise NotImplementedError + + return [np.array(g + w) for g, w in zip(gpresiduals, whiteresiduals)] + + +def set_residuals(psr, y): + if isinstance(psr, list): + for p, r in zip(psr, y): + p._residuals[p._isort] = r + else: + psr._residuals[psr._isort] = y + + class ConditionalGP: def __init__(self, pta, phiinv_method="cliques"): """This class allows the computation of conditional means and @@ -208,89 +276,6 @@ def get_coefficients(pta, params, n=1, phiinv_method="cliques", variance=True, c return ret[0] if n == 1 else ret -class KernelMatrix(np.ndarray): - def __new__(cls, init): - if isinstance(init, int): - ret = np.zeros(init, "d").view(cls) - else: - ret = init.view(cls) - - if ret.ndim == 2: - ret._cliques = -1 * np.ones(ret.shape[0]) - ret._clcount = 0 - - return ret - - # see PTA._setcliques - def _setcliques(self, idxs): - allidx = set(self._cliques[idxs]) - maxidx = max(allidx) - - if maxidx == -1: - self._cliques[idxs] = self._clcount - self._clcount = self._clcount + 1 - else: - self._cliques[idxs] = maxidx - if len(allidx) > 1: - self._cliques[np.in1d(self._cliques, allidx)] = maxidx - - def add(self, other, idx): - if other.ndim == 2 and self.ndim == 1: - self = KernelMatrix(np.diag(self)) - - if self.ndim == 1: - self[idx] += other - else: - if other.ndim == 1: - self[idx, idx] += other - else: - self._setcliques(idx) - idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx) - self[idx] += other - - return self - - def set(self, other, idx): - if other.ndim == 2 and self.ndim == 1: - self = KernelMatrix(np.diag(self)) - - if self.ndim == 1: - self[idx] = other - else: - if other.ndim == 1: - self[idx, idx] = other - else: - self._setcliques(idx) - idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx) - self[idx] = other - - return self - - def inv(self, logdet=False): - if self.ndim == 1: - inv = 1.0 / self - - if logdet: - return inv, np.sum(np.log(self)) - else: - return inv - else: - try: - cf = sl.cho_factor(self) - inv = sl.cho_solve(cf, np.identity(cf[0].shape[0])) - if logdet: - ld = 2.0 * np.sum(np.log(np.diag(cf[0]))) - except np.linalg.LinAlgError: - u, s, v = np.linalg.svd(self) - inv = np.dot(u / s, u.T) - if logdet: - ld = np.sum(np.log(s)) - if logdet: - return inv, ld - else: - return inv - - def create_stabletimingdesignmatrix(designmat, fastDesign=True): """ Stabilize the timing-model design matrix. @@ -885,8 +870,8 @@ def svd_tm_basis(Mmat): @function -def tm_prior(weights): - return weights * 1e40 +def tm_prior(weights, variance=1e40): + return weights * variance # Physical ephemeris model utility functions diff --git a/tests/test_utils.py b/tests/test_utils.py index f8268cf3..e0aca890 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,8 +14,8 @@ import enterprise.constants as const from enterprise.pulsar import Pulsar +from enterprise.signals import utils, parameter, signal_base, white_signals, gp_signals from enterprise.signals import anis_coefficients as anis -from enterprise.signals import utils from tests.enterprise_test_data import datadir @@ -26,6 +26,7 @@ def setUpClass(cls): # initialize Pulsar class cls.psr = Pulsar(datadir + "/B1855+09_NANOGrav_9yv1.gls.par", datadir + "/B1855+09_NANOGrav_9yv1.tim") + cls.psr2 = Pulsar(datadir + "/J1909-3744_NANOGrav_9yv1.gls.par", datadir + "/J1909-3744_NANOGrav_9yv1.tim") cls.F, _ = utils.createfourierdesignmatrix_red(cls.psr.toas, nmodes=30) @@ -35,6 +36,28 @@ def setUpClass(cls): cls.Mm = utils.create_stabletimingdesignmatrix(cls.psr.Mmat) + def test_simulate(self): + ef = white_signals.MeasurementNoise() + + ec = gp_signals.EcorrBasisModel() + + pl = utils.powerlaw(log10_A=parameter.Uniform(-16, -13), gamma=parameter.Uniform(1, 7)) + orf = utils.hd_orf() + crn = gp_signals.FourierBasisCommonGP(pl, orf, components=20, name="GW") + + m = ef + ec + crn + + pta = signal_base.PTA([m(self.psr), m(self.psr2)]) + + ys = utils.simulate(pta, params=parameter.sample(pta.params)) + + msg = "Simulated residuals shape incorrect" + assert ys[0].shape == self.psr.residuals.shape, msg + assert ys[1].shape == self.psr2.residuals.shape, msg + + msg = "Simulated residuals shape not a number" + assert np.all(~np.isnan(np.concatenate(ys))), msg + def test_createstabletimingdesignmatrix(self): """Timing model design matrix shape."""