Skip to content

Commit

Permalink
Added plotting for hypercylindrical distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianPfaff committed Nov 24, 2024
1 parent 8ee5175 commit ad6dd23
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,33 @@
from math import pi
from typing import Union

import matplotlib.pyplot as plt
import pyrecest.backend
import scipy.integrate
import scipy.optimize
from matplotlib import cm

# pylint: disable=redefined-builtin,no-name-in-module,no-member
# pylint: disable=no-name-in-module,no-member
from pyrecest.backend import (
allclose,
any,
arange,
array,
column_stack,
concatenate,
cos,
empty,
full,
int32,
int64,
isnan,
linspace,
meshgrid,
mod,
ndim,
ones,
sin,
sqrt,
tile,
vstack,
Expand Down Expand Up @@ -324,3 +331,96 @@ def mode_numerical(self, starting_point=None):
@property
def input_dim(self):
return self.dim

def plot(self, *args, **kwargs):
if self.bound_dim != 1:
raise NotImplementedError("Plotting is only supported for bound_dim == 1.")

lin_size = 3
if self.lin_dim == 1:
# Creates a three-dimensional surface plot
step = 2 * pi / 100
# sigma is the standard deviation of the linear variable
sigma = sqrt(self.linear_covariance())[0, 0]
m = self.mode()
# Create grid over periodic variable (0 to 2*pi)
x_vals = arange(0, 2 * pi + step, step)
# Create grid over linear variable (mean +/- lin_size * sigma)
theta_vals = linspace(
m[self.bound_dim] - lin_size * sigma,
m[self.bound_dim] + lin_size * sigma,
100,
)
x, theta = meshgrid(x_vals, theta_vals)

# Evaluate pdf at the grid points
points = vstack((x.ravel(), theta.ravel())).T
f = self.pdf(points)
f = f.reshape(x.shape)

# Now plot
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.plot_surface(x, theta, f, *args, **kwargs)
ax.set_xlim([0, 2 * pi])
ax.set_xlabel("Periodic Variable (Radians)")
ax.set_ylabel("Linear Variable")
ax.set_zlabel("PDF Value")
plt.show()
elif self.lin_dim == 2:
raise NotImplementedError("Plotting not supported for lin_dim == 2.")
else:
raise NotImplementedError("Plotting not supported for this lin_dim.")

# pylint: disable=too-many-locals
def plot_cylinder(self, limits_linear=None):
assert (
self.bound_dim == 1 and self.lin_dim == 1
), "plot_cylinder is only implemented for bound_dim == 1 and lin_dim == 1."

if limits_linear is None:
scale_lin = 3
m = self.linear_mean()
P = self.linear_covariance()
if not isnan(m).any() and not isnan(P).any():
limits_linear = [
m[0] - scale_lin * sqrt(P[0, 0]),
m[0] + scale_lin * sqrt(P[0, 0]),
]
else:
# Sample to find a suitable range
s = self.sample(100)
# s is an array of shape (dim, N)
limits_linear = [
s[self.bound_dim :, :].min(), # noqa: E203
s[self.bound_dim :, :].max(), # noqa: E203
]

phi = linspace(0.0, 2 * pi, 100)
lin = linspace(limits_linear[0], limits_linear[1], 100)
Phi, L = meshgrid(phi, lin)
points = vstack([Phi.ravel(), L.ravel()]).T
C = self.pdf(points)
C = C.reshape(Phi.shape)

X = cos(Phi)
Y = sin(Phi)
Z = L

# Now plot using surf
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
# Normalize C to 0-1
norm = plt.Normalize(C.min(), C.max())
# Map normalized values to colors
colors = cm.viridis(norm(C))

surf = ax.plot_surface(
X, Y, Z, facecolors=colors, linewidth=0, antialiased=False, shade=False
)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Linear Variable")
plt.show()

return surf
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
from math import pi

import matplotlib
import numpy.testing as npt

# pylint: disable=no-name-in-module,no-member
Expand Down Expand Up @@ -58,6 +59,19 @@ def test_condition_on_periodic(self):
atol=1e-10,
)

@unittest.skipIf(
pyrecest.backend.__name__ in ("pyrecest.pytorch", "pyrecest.jax"),
reason="Not supported on this backend",
)
def test_plot(self):
matplotlib.use("Agg")
matplotlib.pyplot.close("all")
hwn = PartiallyWrappedNormalDistribution(
array([1.0, 2.0]), array([[2.0, 0.3], [0.3, 1.0]]), 1
)
hwn.plot()
hwn.plot_cylinder()

@unittest.skipIf(
pyrecest.backend.__name__ in ("pyrecest.pytorch", "pyrecest.jax"),
reason="Not supported on this backend",
Expand Down

0 comments on commit ad6dd23

Please sign in to comment.