Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-1806: Implementation of Doubly Truncated Power Law and Lower Truncated Power Law #1807

Merged
merged 33 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
febd260
implementation of DoublyTruncatedPowerLaw
Qazalbash May 30, 2024
dbb0d62
implementation of LowerTruncatedPowerLaw
Qazalbash May 30, 2024
7985364
chore: mathematical description in docstrings
Qazalbash Jun 23, 2024
2fae5e9
chore: mathematical details of `LowerTruncatedPowerLaw`
Qazalbash Jun 23, 2024
63da106
chore: Fix bug in DoublyTruncatedPowerLaw cdf and icdf calculation
Qazalbash Jun 24, 2024
10d8006
chore: Refactor mean and variance calculation by using kth-moment in …
Qazalbash Jun 24, 2024
d7d6c7e
chore: Refactor mean and variance calculation in LowerTruncatedPowerLaw
Qazalbash Jun 24, 2024
36cce79
chore: masking in icdf of LowerTruncatedPowerLaw
Qazalbash Jun 24, 2024
d8eb1e5
chore: entropy of LowerTruncatedPowerLaw
Qazalbash Jun 24, 2024
1192ca7
chore: `lax.sqaure` replaced with `jnp.sqaure`
Qazalbash Jul 6, 2024
0e4d35f
chore: moments and entropy were extra and removed
Qazalbash Jul 12, 2024
6fb9536
chore: unit tests
Qazalbash Jul 13, 2024
6743dde
fix: nan gradients fixed, values still diverging
Qazalbash Jul 13, 2024
a88c331
Updated UpperTruncatedPowerLaw with adequate derivations, including f…
InfinityMod Aug 15, 2024
b2db5f8
Changed constrains of alpha of LowerTruncatedPowerLaw to the smaller …
InfinityMod Aug 15, 2024
12d78f3
chore: code and docstring formated
Qazalbash Aug 17, 2024
2e8e33e
chore: equation refactor and simplified
Qazalbash Aug 18, 2024
ce3c53f
chore: equation refactor and simplified
Qazalbash Aug 18, 2024
aba947b
chore: use numpy arrays and numpy constants
Qazalbash Sep 5, 2024
96d5f6c
chore: high precision computation enable for powerlaws
Qazalbash Sep 5, 2024
bee7535
Merge branch 'master' into powerlaw-dist
Qazalbash Sep 5, 2024
cc3f81e
chore: `__name__` attribute calls removed
Qazalbash Sep 5, 2024
ef2a1f3
Merge branch 'master' into powerlaw-dist
Qazalbash Sep 5, 2024
3df15f2
chore: powerlaws shifted with truncated distributions
Qazalbash Sep 5, 2024
465253c
chore: spelling mistakes fixed with code spell checker pre-commit hook
Qazalbash Sep 5, 2024
0589bea
fix typo: perforance->perforamce->performance
Qazalbash Sep 10, 2024
21b50b5
Merge branch 'master' into powerlaw-dist
Qazalbash Sep 10, 2024
f1da2d5
chore: explicit enabling/disabling of 64bit floating point numbers
Qazalbash Sep 11, 2024
42ed59d
chore: disable everytime and enable x64 for power laws
Qazalbash Sep 11, 2024
2671edb
chore: disable x64 for every test
Qazalbash Sep 11, 2024
6d126ae
chore: linked explanation in comments for disabling x64 for future re…
Qazalbash Sep 11, 2024
34d9f82
chore: high precision test handeled efficiently for DoublyTruncatedPo…
Qazalbash Sep 17, 2024
41f6aa3
chore: high precision exception handled in test_log_prob_gradient
Qazalbash Sep 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,11 @@ repos:
- id: check-yaml
- id: check-added-large-files
exclude: notebooks/*

- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
hooks:
- id: codespell
stages: [commit, commit-msg]
args:
[--ignore-words-list, "Teh,aas", --check-filenames, --skip, "*.ipynb"]
75 changes: 39 additions & 36 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ Design Choices:
Future Work:

- Right now the jax, jaxlib, and numpyro versions are manually specified, so they have to be updated every NumPyro release. There are two ways forward for this:
1. If there is a CI/CD in place to build and push images to a repository like Dockerhub, then the jax, jaxlib, and numpyro versions can be passed in as environment variables (for example, if something like [Drone CI](http://plugins.drone.io/drone-plugins/drone-docker/) is used). If implemented this way, the jax/jaxlib/numpyro versions will be ephemereal (not stored in source code).
1. If there is a CI/CD in place to build and push images to a repository like Dockerhub, then the jax, jaxlib, and numpyro versions can be passed in as environment variables (for example, if something like [Drone CI](http://plugins.drone.io/drone-plugins/drone-docker/) is used). If implemented this way, the jax/jaxlib/numpyro versions will be ephemeral (not stored in source code).
2. Alternative, one can create a Python script that will modify the Dockerfiles upon release accordingly (using a hook of some sort).
16 changes: 16 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,14 @@ VonMises
Truncated Distributions
-----------------------

DoublyTruncatedPowerLaw
^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.truncated.DoublyTruncatedPowerLaw
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

LeftTruncatedDistribution
^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.truncated.LeftTruncatedDistribution
Expand All @@ -670,6 +678,14 @@ LeftTruncatedDistribution
:show-inheritance:
:member-order: bysource

LowerTruncatedPowerLaw
^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.truncated.LowerTruncatedPowerLaw
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

RightTruncatedDistribution
^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.truncated.RightTruncatedDistribution
Expand Down
2 changes: 1 addition & 1 deletion examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

All models have discrete latent variables. Under the hood, we enumerate over
(marginalize out) those discrete latent sites in inference. Those models have different
complexity so they are great refererences for those who are new to Pyro/NumPyro
complexity so they are great references for those who are new to Pyro/NumPyro
enumeration mechanism. We recommend readers compare the implementations with the
corresponding plate diagrams in [1] to see how concise a Pyro/NumPyro program is.

Expand Down
2 changes: 1 addition & 1 deletion examples/ar2.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def run_inference(model, args, rng_key, y):


def main(args):
# generate artifical dataset
# generate artificial dataset
num_data = args.num_data
rng_key = jax.random.PRNGKey(0)
t = jnp.arange(0, num_data)
Expand Down
2 changes: 1 addition & 1 deletion examples/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def predict(model, args, samples, rng_key, y, n_seasons):


def main(args):
# generate artifical dataset
# generate artificial dataset
rng_key, _ = random.split(random.PRNGKey(0))
T = args.T
t = jnp.linspace(0, T + args.future, (T + args.future) * N_POINTS_PER_UNIT)
Expand Down
2 changes: 1 addition & 1 deletion examples/mortality.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
dimensions of the age, space and time variables. This allows us to efficiently broadcast arrays
in the likelihood.

As written above, the model includes a lot of centred random effects. The NUTS alogrithm benefits
As written above, the model includes a lot of centred random effects. The NUTS algorithm benefits
from a non-centred reparamatrisation to overcome difficult posterior geometries [2]. Rather than
manually writing out the non-centred parametrisation, we make use of the NumPyro's automatic
reparametrisation in :class:`~numpyro.infer.reparam.LocScaleReparam`.
Expand Down
26 changes: 15 additions & 11 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@
from numpyro.distributions.mixtures import Mixture, MixtureGeneral, MixtureSameFamily
from numpyro.distributions.transforms import biject_to
from numpyro.distributions.truncated import (
DoublyTruncatedPowerLaw,
LeftTruncatedDistribution,
LowerTruncatedPowerLaw,
RightTruncatedDistribution,
TruncatedCauchy,
TruncatedDistribution,
Expand All @@ -122,6 +124,7 @@
"Binomial",
"BinomialLogits",
"BinomialProbs",
"CAR",
"Categorical",
"CategoricalLogits",
"CategoricalProbs",
Expand All @@ -132,9 +135,10 @@
"DirichletMultinomial",
"DiscreteUniform",
"Distribution",
"DoublyTruncatedPowerLaw",
"EulerMaruyama",
"Exponential",
"ExpandedDistribution",
"Exponential",
"FoldedDistribution",
"Gamma",
"GammaPoisson",
Expand All @@ -152,29 +156,29 @@
"Independent",
"InverseGamma",
"Kumaraswamy",
"LKJ",
"LKJCholesky",
"Laplace",
"LeftTruncatedDistribution",
"LKJ",
"LKJCholesky",
"Logistic",
"LogNormal",
"LogUniform",
"MatrixNormal",
"LowerTruncatedPowerLaw",
"LowRankMultivariateNormal",
"MaskedDistribution",
"MatrixNormal",
"Mixture",
"MixtureSameFamily",
"MixtureGeneral",
"MixtureSameFamily",
"Multinomial",
"MultinomialLogits",
"MultinomialProbs",
"MultivariateNormal",
"CAR",
"MultivariateStudentT",
"LowRankMultivariateNormal",
"Normal",
"NegativeBinomialProbs",
"NegativeBinomialLogits",
"NegativeBinomial2",
"NegativeBinomialLogits",
"NegativeBinomialProbs",
"Normal",
"OrderedLogistic",
"Pareto",
"Poisson",
Expand All @@ -199,7 +203,7 @@
"Wishart",
"WishartCholesky",
"ZeroInflatedDistribution",
"ZeroInflatedPoisson",
"ZeroInflatedNegativeBinomial2",
"ZeroInflatedPoisson",
"ZeroSumNormal",
]
12 changes: 6 additions & 6 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ class LKJCholesky(Distribution):
r"""
LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is
controlled by ``concentration`` parameter :math:`\eta` to make the probability of the
correlation matrix :math:`M` generated from a Cholesky factor propotional to
correlation matrix :math:`M` generated from a Cholesky factor proportional to
:math:`\det(M)^{\eta - 1}`. Because of that, when ``concentration == 1``, we have a
uniform distribution over Cholesky factors of correlation matrices.

Expand Down Expand Up @@ -1048,7 +1048,7 @@ def __init__(

# We construct base distributions to generate samples for each method.
# The purpose of this base distribution is to generate a distribution for
# correlation matrices which is propotional to `det(M)^{\eta - 1}`.
# correlation matrices which is proportional to `det(M)^{\eta - 1}`.
# (note that this is not a unique way to define base distribution)
# Both of the following methods have marginal distribution of each off-diagonal
# element of sampled correlation matrices is Beta(eta + (D-2) / 2, eta + (D-2) / 2)
Expand Down Expand Up @@ -1150,12 +1150,12 @@ def log_prob(self, value):
# Generally, for a D dimensional matrix, we have:
# Jacobian = L22^(D-2) * L33^(D-3) * ... * Ldd^0
#
# From [1], we know that probability of a correlation matrix is propotional to
# From [1], we know that probability of a correlation matrix is proportional to
# determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
# On the other hand, Jabobian of the transformation from Cholesky factor to
# correlation matrix is:
# prod(L_ii ^ (D - i))
# So the probability of a Cholesky factor is propotional to
# So the probability of a Cholesky factor is proportional to
# prod(L_ii ^ (2 * concentration - 2 + D - i)) =: prod(L_ii ^ order_i)
# with order_i = 2 * concentration - 2 + D - i,
# i = 2..D (we omit the element i = 1 because L_11 = 1)
Expand Down Expand Up @@ -1286,7 +1286,7 @@ def entropy(self):
def _batch_solve_triangular(A, B):
"""
Extende solve_triangular for the case that B.ndim > A.ndim.
This is achived by first flattening the leading B.ndim - A.ndim dimensions of B and then
This is achieved by first flattening the leading B.ndim - A.ndim dimensions of B and then
moving the first dimension to the end.


Expand Down Expand Up @@ -1720,7 +1720,7 @@ def log_prob(self, value):
D_rsqrt[..., None, :] * D_rsqrt[..., None]
)

# TODO: look into sparse eignvalue methods
# TODO: look into sparse eigenvalue methods
if isinstance(adj_matrix_scaled, np.ndarray):
lam = np.linalg.eigvalsh(adj_matrix_scaled)
else:
Expand Down
4 changes: 2 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,10 @@ def infer_shapes(cls, *args, **kwargs):

def cdf(self, value):
"""
The cummulative distribution function of this distribution.
The cumulative distribution function of this distribution.

:param value: samples from this distribution.
:return: output of the cummulative distribution function evaluated at `value`.
:return: output of the cumulative distribution function evaluated at `value`.
"""
raise NotImplementedError

Expand Down
4 changes: 2 additions & 2 deletions numpyro/distributions/gof.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def unif01_goodness_of_fit(samples, *, plot=False):

def exp_goodness_of_fit(samples, plot=False):
"""
Transform exponentially distribued samples to Uniform(0,1) distribution and
Transform exponentially distributed samples to Uniform(0,1) distribution and
assess goodness of fit via binned Pearson's chi^2 test.

:param numpy.ndarray samples: A vector of real-valued samples from a
Expand Down Expand Up @@ -353,7 +353,7 @@ def _chi2sf(x, s):
F(x; s) = \frac{ \gamma( x/2, s/2 ) }{ \Gamma(s/2) },

with :math:`\gamma` is the incomplete gamma function defined above.
Therefore, the survival probability is givne by:
Therefore, the survival probability is given by:

.. math::
1 - \frac{ \gamma( x/2, s/2 ) }{ \Gamma(s/2) }.
Expand Down
6 changes: 3 additions & 3 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def _matrix_forward_shape(shape, offset=0):
N = shape[-1]
D = round((0.25 + 2 * N) ** 0.5 - 0.5)
if D * (D + 1) // 2 != N:
raise ValueError("Input is not a flattend lower-diagonal number")
raise ValueError("Input is not a flattened lower-diagonal number")
D = D - offset
return shape[:-1] + (D, D)

Expand Down Expand Up @@ -447,7 +447,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):

class CorrCholeskyTransform(ParameterFreeTransform):
r"""
Transforms a uncontrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
Transforms a unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower
triangular matrix with positive diagonals and unit Euclidean norm for each row.
The transform is processed as follows:
Expand Down Expand Up @@ -655,7 +655,7 @@ def __eq__(self, other):

class L1BallTransform(ParameterFreeTransform):
r"""
Transforms a uncontrained real vector :math:`x` into the unit L1 ball.
Transforms a unconstrained real vector :math:`x` into the unit L1 ball.
"""

domain = constraints.real_vector
Expand Down
Loading
Loading