Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize differenced process #446

Merged
merged 8 commits into from
Sep 18, 2024
117 changes: 115 additions & 2 deletions pyrenew/math.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""
Helper functions for doing analytical
and/or numerical calculations about
a given renewal process.
and/or numerical calculations.
"""

from __future__ import annotations

import jax.numpy as jnp
from jax.lax import broadcast_shapes, scan
from jax.typing import ArrayLike

from pyrenew.distutil import validate_discrete_dist_vector
Expand Down Expand Up @@ -172,3 +172,116 @@ def get_asymptotic_growth_rate(
return get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[
0
]


def integrate_discrete(
init_diff_vals: ArrayLike, highest_order_diff_vals: ArrayLike
) -> ArrayLike:
"""
Integrate (de-difference) the differenced process,
obtaining the process values :math:`X(t=0), X(t=1), ... X(t)`
from the :math:`n^{th}` differences and a set of
initial process / difference values
:math:`X(t=0), X^1(t=1), X^2(t=2), ... X^{(n-1)}(t=n-1)`,
where :math:`X^k(t)` is the value of the :math:`n^{th}`
difference at index :math:`t` of the process,
obtaining a sequence of length equal to the length of
the provided `highest_order_diff_vals` vector plus
the order of the process.

Parameters
----------
init_diff_vals : ArrayLike
Values of
:math:`X(t=0), X^1(t=1), X^2(t=2) ... X^{(n-1)}(t=n-1)`.

highest_order_diff_vals : ArrayLike
Array of differences at the highest order of
differencing, i.e. the order of the overall process,
starting with :math:`X^{n}(t=n)`

Returns
-------
ArrayLike
The integrated (de-differenced) sequence of values,
of length n_diffs + order, where n_diffs is the
number of highest_order_diff_vals and order is the
order of the process.
"""
inits_by_order = jnp.atleast_1d(init_diff_vals)
highest_diffs = jnp.atleast_1d(highest_order_diff_vals)
order = inits_by_order.shape[0]
n_diffs = highest_diffs.shape[0]

try:
batch_shape = broadcast_shapes(
highest_diffs.shape[1:], inits_by_order.shape[1:]
)
except Exception as e:
raise ValueError(
"Non-time dimensions "
"(i.e. dimensions after the first) "
"for highest_order_diff_vals and init_diff_vals "
"must be broadcastable together. "
"Got highest_order_diff_vals of shape "
f"{highest_diffs.shape} and "
"init_diff_vals of shape "
f"{inits_by_order.shape}"
) from e

highest_diffs = jnp.broadcast_to(highest_diffs, (n_diffs,) + batch_shape)
inits_by_order = jnp.broadcast_to(inits_by_order, (order,) + batch_shape)

highest_diffs = jnp.concatenate(
[jnp.zeros_like(inits_by_order), highest_diffs],
axis=0,
)

scan_arrays = (
jnp.arange(start=order - 1, stop=-1, step=-1),
jnp.flip(inits_by_order, axis=0),
)

integrated, _ = scan(
f=_integrate_one_step, init=highest_diffs, xs=scan_arrays
)

return integrated


def _integrate_one_step(
current_diffs: ArrayLike,
next_order_and_init: tuple[int, ArrayLike],
) -> tuple[ArrayLike, None]:
"""
Perform one step of integration
(de-differencing) for integrate_discrete().

Helper function passed to :func:`jax.lax.scan()`.

Parameters
----------
current_diffs: ArrayLike
Array of differences at the current
de-differencing order

next_order_and_init: tuple
Tuple containing with two entries.
First entry: the next order of de-differencing
(the current order - 1) as an integer.
Second entry: the initial value at
that the next order of de-differencing
as an ArrayLike of appropriate shape.

Returns
-------
tuple[ArrayLike, None]
A tuple whose first entry contains the
values at the next order of (de)-differencing
and whose second entry is None.
"""
next_order, next_init = next_order_and_init
next_diffs = jnp.cumsum(
current_diffs.at[next_order, ...].set(next_init), axis=0
)
return next_diffs, None
180 changes: 56 additions & 124 deletions pyrenew/process/differencedprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import jax.numpy as jnp
from jax.typing import ArrayLike
from numpyro.contrib.control_flow import scan

from pyrenew.math import integrate_discrete
from pyrenew.metaclass import RandomVariable


Expand All @@ -18,6 +18,15 @@
https://otexts.com/fpp3/stationarity.html
for a discussion of differencing in the
context of discrete timeseries data.

Notes
-----
The order of differencing is the discrete
analogue of the order of a derivative in single
variable calculus. A first difference (derivative)
represents a rate of change. A second difference
(derivative) represents the rate of change of that
rate of change, et cetera.
"""

def __init__(
Expand All @@ -33,7 +42,7 @@
----------
fundamental_process : RandomVariable
Stochastic process for the
differences. Should accept an
differences. Must accept an
`n` argument specifying the number
of samples to draw.
differencing_order : int
Expand All @@ -45,104 +54,49 @@
of change), 2 a process on the
2nd differences (rate of change of
the rate of change), et cetera.
**kwargs :
Additional keyword arguments passed to
the parent class constructor.

Returns
-------
None

Notes
-----
The order of differencing is the discrete
analogue of the order of a derivative in single
variable calculus. A first difference (derivative)
represents a rate of change. A second difference
(derivative) represents the rate of change of that
rate of change, et cetera.
"""
self.assert_valid_differencing_order(differencing_order)
self.differencing_order = differencing_order
self.fundamental_process = fundamental_process
self.differencing_order = differencing_order
super().__init__(**kwargs)

def integrate(
self, init_diff_vals: ArrayLike, highest_order_diff_vals: ArrayLike
):
@staticmethod
def assert_valid_differencing_order(differencing_order: any):
"""
Integrate (de-difference) the differenced process,
obtaining the process values :math:`X(t=0), X(t=1), ... X(t)`
from the :math:`n^{th}` differences and a set of
initial process / difference values
:math:`X(t=0), X^1(t=1), X^2(t=2), ... X^{(n-1)}(t=n-1)`,
where :math:`X^k(t)` is the value of the :math:`n^{th}`
difference at index :math:`t` of the process,
obtaining a sequence of length equal to the length of
the provided `highest_order_diff_vals` vector plus
the order of the process.

Parameters
----------
init_diff_vals : ArrayLike
Values of
:math:`X(t=0), X^1(t=1), X^2(t=2) ... X^{(n-1)}(t=n-1)`.

highest_order_diff_vals : ArrayLike
Array of differences at the highest order of
differencing, i.e. the order of the overall process,
starting with :math:`X^{n}(t=n)`

To be valid, a differencing order must
be an integer and must be strictly positive.
This function raises a value error if its
argument is not a valid differencing order.
Parameter
---------
differcing_order : any
Potential differencing order to validate.
Returns
-------
The integrated (de-differenced) sequence of values,
of length n_diffs + order, where n_diffs is the
number of highest_order_diff_vals and order is the
order of the process.
None or raises a ValueError
"""
init_arr = jnp.atleast_1d(init_diff_vals)
diff_arr = jnp.atleast_1d(highest_order_diff_vals)
if not init_arr.ndim == 1:
raise ValueError(
"init_diff_vals must be 1-dimensional "
"array or a scalar. "
f"Got {init_diff_vals}"
)
if not diff_arr.ndim == 1:
if not isinstance(differencing_order, int):
raise ValueError(
"highest_order_diff_vals must be a "
"1-dimensional array or a scalar "
f"Got {highest_order_diff_vals}"
"differencing_order must be an integer. "
f"got type {type(differencing_order)} "
f"and value {differencing_order}"
)
n_inits = init_arr.size
if not n_inits == self.differencing_order:
if not differencing_order >= 1:
raise ValueError(
"Must have exactly as many "
"initial difference values as "
"the differencing order, given "
"in the sequence :math:`X(t=0), X^1(t=1),` "
"et cetera. "
f"Got {n_inits} values "
"for a process of order "
f"{self.differencing_order}"
"differencing_order must be an integer "
"greater than or equal to 1. Got "
f"{differencing_order}"
)

def _integrate_one_step(diffs, scanned):
# numpydoc ignore=GL08
order, init = scanned
new_diffs = jnp.cumsum(diffs.at[order].set(init))
return (new_diffs, None)

integrated, _ = scan(
_integrate_one_step,
init=jnp.pad(diff_arr, (self.differencing_order, 0)),
xs=(
jnp.flip(jnp.arange(self.differencing_order)),
jnp.flip(init_arr),
),
)

return integrated
def validate(self):
"""
Empty validation method.
"""
pass

Check warning on line 99 in pyrenew/process/differencedprocess.py

View check run for this annotation

Codecov / codecov/patch

pyrenew/process/differencedprocess.py#L99

Added line #L99 was not covered by tests

def sample(
self,
Expand All @@ -161,22 +115,22 @@
initial values for the :math:`0^{th}` through
:math:`(n-1)^{st}` differences, passed as the
``init_diff_vals`` argument to
:meth:`DifferencedProcess.integrate()`
:func:`integrate_discrete()`

n : int
Number of values to sample. Will sample
``n - self.differencing_order`` values from
:code:`n - differencing_order` values from
:meth:`self.fundamental_process` to ensure
that the de-differenced output is of length
``n``.
:code`n`.

*args :
Additional positional arguments passed to
:meth:`self.fundamental_process.sample()`

fundamental_process_init_vals : ArrayLike
Initial values for the fundamental process.
Passed as the ``init_vals`` keyword argument
Passed as the :arg:`init_vals` keyword argument
to :meth:`self.fundamental_process.sample()`.

**kwargs : dict, optional
Expand All @@ -193,7 +147,22 @@
if n < 1:
raise ValueError("n must be positive. " f"Got {n}")

init_vals = jnp.atleast_1d(init_vals)
n_inits = init_vals.shape[0]

if not n_inits == self.differencing_order:
raise ValueError(
"Must have exactly as many "
"initial difference values as "
"the differencing order, given "
"in the sequence :math:`X(t=0), X^1(t=1),` "
"et cetera. "
f"Got {n_inits} values "
"for a process of order "
f"{self.differencing_order}."
)
n_diffs = n - self.differencing_order

if n_diffs > 0:
diff_samp = self.fundamental_process.sample(
*args,
Expand All @@ -204,42 +173,5 @@
diffs = diff_samp
else:
diffs = jnp.array([])
integrated_ts = self.integrate(init_vals, diffs)[:n]
integrated_ts = integrate_discrete(init_vals, diffs)[:n]
return integrated_ts

@staticmethod
def validate():
"""
Validates input parameters, implementation pending.
"""
return None

@staticmethod
def assert_valid_differencing_order(differencing_order: any):
"""
To be valid, a differencing order must
be an integer and must be strictly positive.
This function raises a value error if its
argument is not a valid differencing order.

Parameter
---------
differcing_order : any
Potential differencing order to validate.

Returns
-------
None or raises a ValueError
"""
if not isinstance(differencing_order, int):
raise ValueError(
"differencing_order must be an integer. "
f"got type {type(differencing_order)} "
f"and value {differencing_order}"
)
if not differencing_order >= 1:
raise ValueError(
"differencing_order must be an integer "
"greater than or equal to 1. Got "
f"{differencing_order}"
)
Loading