From ad6dd2327a36a2eaa1e6123b3da80ff71ff4cb15 Mon Sep 17 00:00:00 2001 From: Florian Pfaff Date: Sun, 24 Nov 2024 12:41:24 +0100 Subject: [PATCH] Added plotting for hypercylindrical distributions --- .../abstract_hypercylindrical_distribution.py | 100 ++++++++++++++++++ ..._abstract_hypercylindrical_distribution.py | 14 +++ 2 files changed, 114 insertions(+) diff --git a/pyrecest/distributions/cart_prod/abstract_hypercylindrical_distribution.py b/pyrecest/distributions/cart_prod/abstract_hypercylindrical_distribution.py index 18957fd7..7720e776 100644 --- a/pyrecest/distributions/cart_prod/abstract_hypercylindrical_distribution.py +++ b/pyrecest/distributions/cart_prod/abstract_hypercylindrical_distribution.py @@ -2,26 +2,33 @@ from math import pi from typing import Union +import matplotlib.pyplot as plt import pyrecest.backend import scipy.integrate import scipy.optimize +from matplotlib import cm # pylint: disable=redefined-builtin,no-name-in-module,no-member # pylint: disable=no-name-in-module,no-member from pyrecest.backend import ( allclose, any, + arange, array, column_stack, concatenate, + cos, empty, full, int32, int64, isnan, + linspace, + meshgrid, mod, ndim, ones, + sin, sqrt, tile, vstack, @@ -324,3 +331,96 @@ def mode_numerical(self, starting_point=None): @property def input_dim(self): return self.dim + + def plot(self, *args, **kwargs): + if self.bound_dim != 1: + raise NotImplementedError("Plotting is only supported for bound_dim == 1.") + + lin_size = 3 + if self.lin_dim == 1: + # Creates a three-dimensional surface plot + step = 2 * pi / 100 + # sigma is the standard deviation of the linear variable + sigma = sqrt(self.linear_covariance())[0, 0] + m = self.mode() + # Create grid over periodic variable (0 to 2*pi) + x_vals = arange(0, 2 * pi + step, step) + # Create grid over linear variable (mean +/- lin_size * sigma) + theta_vals = linspace( + m[self.bound_dim] - lin_size * sigma, + m[self.bound_dim] + lin_size * sigma, + 100, + ) + x, theta = meshgrid(x_vals, theta_vals) + + # Evaluate pdf at the grid points + points = vstack((x.ravel(), theta.ravel())).T + f = self.pdf(points) + f = f.reshape(x.shape) + + # Now plot + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + ax.plot_surface(x, theta, f, *args, **kwargs) + ax.set_xlim([0, 2 * pi]) + ax.set_xlabel("Periodic Variable (Radians)") + ax.set_ylabel("Linear Variable") + ax.set_zlabel("PDF Value") + plt.show() + elif self.lin_dim == 2: + raise NotImplementedError("Plotting not supported for lin_dim == 2.") + else: + raise NotImplementedError("Plotting not supported for this lin_dim.") + + # pylint: disable=too-many-locals + def plot_cylinder(self, limits_linear=None): + assert ( + self.bound_dim == 1 and self.lin_dim == 1 + ), "plot_cylinder is only implemented for bound_dim == 1 and lin_dim == 1." + + if limits_linear is None: + scale_lin = 3 + m = self.linear_mean() + P = self.linear_covariance() + if not isnan(m).any() and not isnan(P).any(): + limits_linear = [ + m[0] - scale_lin * sqrt(P[0, 0]), + m[0] + scale_lin * sqrt(P[0, 0]), + ] + else: + # Sample to find a suitable range + s = self.sample(100) + # s is an array of shape (dim, N) + limits_linear = [ + s[self.bound_dim :, :].min(), # noqa: E203 + s[self.bound_dim :, :].max(), # noqa: E203 + ] + + phi = linspace(0.0, 2 * pi, 100) + lin = linspace(limits_linear[0], limits_linear[1], 100) + Phi, L = meshgrid(phi, lin) + points = vstack([Phi.ravel(), L.ravel()]).T + C = self.pdf(points) + C = C.reshape(Phi.shape) + + X = cos(Phi) + Y = sin(Phi) + Z = L + + # Now plot using surf + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + # Normalize C to 0-1 + norm = plt.Normalize(C.min(), C.max()) + # Map normalized values to colors + colors = cm.viridis(norm(C)) + + surf = ax.plot_surface( + X, Y, Z, facecolors=colors, linewidth=0, antialiased=False, shade=False + ) + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Linear Variable") + plt.show() + + return surf diff --git a/pyrecest/tests/distributions/test_abstract_hypercylindrical_distribution.py b/pyrecest/tests/distributions/test_abstract_hypercylindrical_distribution.py index 557e5d16..1b114c11 100644 --- a/pyrecest/tests/distributions/test_abstract_hypercylindrical_distribution.py +++ b/pyrecest/tests/distributions/test_abstract_hypercylindrical_distribution.py @@ -1,6 +1,7 @@ import unittest from math import pi +import matplotlib import numpy.testing as npt # pylint: disable=no-name-in-module,no-member @@ -58,6 +59,19 @@ def test_condition_on_periodic(self): atol=1e-10, ) + @unittest.skipIf( + pyrecest.backend.__name__ in ("pyrecest.pytorch", "pyrecest.jax"), + reason="Not supported on this backend", + ) + def test_plot(self): + matplotlib.use("Agg") + matplotlib.pyplot.close("all") + hwn = PartiallyWrappedNormalDistribution( + array([1.0, 2.0]), array([[2.0, 0.3], [0.3, 1.0]]), 1 + ) + hwn.plot() + hwn.plot_cylinder() + @unittest.skipIf( pyrecest.backend.__name__ in ("pyrecest.pytorch", "pyrecest.jax"), reason="Not supported on this backend",