From 592658feb3b53f31511a59729c35e11520dcae7e Mon Sep 17 00:00:00 2001
From: marshrossney <17361029+marshrossney@users.noreply.github.com>
Date: Mon, 26 Apr 2021 16:31:14 +0100
Subject: [PATCH 1/4] layer-wise histograms of weights and field variables

---
 anvil/benchmark_config/free_scalar_sample.yml |   3 +
 anvil/models.py                               |  30 +++--
 anvil/plot.py                                 | 115 +++++-------------
 anvil/sample.py                               |  29 ++++-
 examples/runcards/report.md                   |   3 +
 5 files changed, 87 insertions(+), 93 deletions(-)

diff --git a/anvil/benchmark_config/free_scalar_sample.yml b/anvil/benchmark_config/free_scalar_sample.yml
index 96cc6a2..884ea0a 100644
--- a/anvil/benchmark_config/free_scalar_sample.yml
+++ b/anvil/benchmark_config/free_scalar_sample.yml
@@ -16,6 +16,9 @@ template_text: |
   # Eigenvalues of kinetic operator
   {@plot_kinetic_eigenvalues@}
   {@table_kinetic_eigenvalues@}
+  # Layer-wise breakdown
+  {@plot_layerwise_histograms@}
+  {@plot_layerwise_weights@}
 
 actions_:
   - report(main=True)

diff --git a/anvil/models.py b/anvil/models.py
index 465e29c..c8b7a2f 100644
--- a/anvil/models.py
+++ b/anvil/models.py
@@ -4,14 +4,15 @@
 Module containing reportengine actions which return callable objects that execute
 normalising flows constructed from multiple layers via function composition.
 """
+import torch
 from functools import partial
 
+from reportengine import collect
 from anvil.core import Sequential
-
 import anvil.layers as layers
 
 
-def coupling_pair(coupling_layer, size_half, **layer_spec):
+def coupling_block(coupling_layer, size_half, **layer_spec):
     """Helper function which returns a callable object that performs a coupling
     transformation on both even and odd lattice sites."""
     coupling_transformation = partial(coupling_layer, size_half, **layer_spec)
@@ -31,7 +32,7 @@ def real_nvp(
     """Action that returns a callable object that performs a sequence of `n_affine`
     affine coupling transformations on both partitions of the input vector."""
     blocks = [
-        coupling_pair(
+        coupling_block(
             layers.AffineLayer,
             size_half,
             hidden_shape=hidden_shape,
@@ -53,7 +54,7 @@ def nice(
     """Action that returns a callable object that performs a sequence of `n_affine`
     affine coupling transformations on both partitions of the input vector."""
     blocks = [
-        coupling_pair(
+        coupling_block(
             layers.AdditiveLayer,
             size_half,
             hidden_shape=hidden_shape,
@@ -74,10 +75,10 @@ def rational_quadratic_spline(
     activation="tanh",
     z2_equivar_spline=False,
 ):
-    """Action that returns a callable object that performs a pair of circular spline
+    """Action that returns a callable object that performs a block of circular spline
     transformations, one on each half of the input vector."""
     blocks = [
-        coupling_pair(
+        coupling_block(
             layers.RationalQuadraticSplineLayer,
             size_half,
             interval=interval,
@@ -96,12 +97,25 @@ def rational_quadratic_spline(
 
 
 def spline_affine(real_nvp, rational_quadratic_spline):
-    return Sequential(rational_quadratic_spline, real_nvp)
+    return Sequential(*rational_quadratic_spline, *real_nvp)
 
 
 def affine_spline(real_nvp, rational_quadratic_spline):
-    return Sequential(real_nvp, rational_quadratic_spline)
+    return Sequential(*real_nvp, *rational_quadratic_spline)
+
+# TODO replace this
+_loaded_model = collect("loaded_model", ("training_context",))
+
+
+def model_weights(_loaded_model):
+    model = _loaded_model[0]
+    for block in model:
+        params = {
+            key: tensor.flatten().numpy() for key, tensor in block.state_dict().items()
+        }
+        if len(params) > 1:  # only want coupling layers
+            yield params
 
 
 MODEL_OPTIONS = {
     "nice": nice,
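For reference, the new `model_weights` generator just flattens each block's `state_dict` into one 1-D array per parameter tensor, skipping blocks that expose at most one entry. A minimal sketch of the same idea on a toy `torch.nn` block (the layer sizes and parameter names below are illustrative, not anvil's actual coupling layers):

```python
# Toy stand-in for one coupling block; anvil's blocks differ in structure.
import torch.nn as nn

block = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 4))

# Same dict comprehension as model_weights above: one flat array per tensor.
params = {
    key: tensor.flatten().numpy() for key, tensor in block.state_dict().items()
}
print({key: arr.shape for key, arr in params.items()})
# {'0.weight': (32,), '0.bias': (8,), '2.weight': (32,), '2.bias': (4,)}
```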
diff --git a/anvil/plot.py b/anvil/plot.py
index 9e9efc4..a74cabc 100644
--- a/anvil/plot.py
+++ b/anvil/plot.py
@@ -7,6 +7,7 @@
 import torch
 import numpy as np
 import matplotlib.pyplot as plt
+import matplotlib as mpl
 from matplotlib.ticker import MaxNLocator
 
 from reportengine.figure import figure, figuregen
@@ -15,93 +16,35 @@
 from anvil.observables import cosh_shift
 
 
-def field_component(i, x_base, phi_model, base_neg, model_neg):
-    fig, ax = plt.subplots()
-
-    ax.hist(x_base, bins=50, density=True, histtype="step", label="base")
-    ax.hist(phi_model, bins=50, density=True, histtype="step", label="model, full")
-    ax.hist(
-        base_neg, bins=50, density=True, histtype="step", label="model, $M_{base} < 0$"
-    )
-    ax.hist(
-        model_neg, bins=50, density=True, histtype="step", label="model, $M_{mod} < 0$"
-    )
-    ax.set_title(f"Coordinate {i}")
-    ax.legend()
-    fig.tight_layout()
-    return fig
-
-
-def field_components(loaded_model, base_dist, lattice_size):
-    """Plot the distributions of base coordinates 'x' and output coordinates 'phi' and,
-    if known, plot the pdf of the target distribution."""
-    sample_size = 10000
-
-    # Generate a large sample from the base distribution and pass it through the trained model
-    with torch.no_grad():
-        x_base, _ = base_dist(sample_size)
-        sign = x_base.sum(dim=1).sign()
-        neg = (sign < 0).nonzero().squeeze()
-        phi_model, model_log_density = loaded_model(x_base, 0, neg)
-
-    base_neg = phi_model[neg]
-
-    sign = phi_model.sum(dim=1).sign()
-    neg = (sign < 0).nonzero().squeeze()
-    model_neg = phi_model[neg]
-
-    # Convert to shape (n_coords, sample_size * lattice_size)
-    # NOTE: this is all pointless for the 1-component scalar
-    x_base = x_base.reshape(sample_size * lattice_size, -1).transpose(0, 1)
-    phi_model = phi_model.reshape(sample_size * lattice_size, -1).transpose(0, 1)
-
-    base_neg = base_neg.reshape(1, -1)
-    model_neg = model_neg.reshape(1, -1)
+# TODO: subplots for different neural networks
+def plot_layer_weights(model_weights):
+    for weights in model_weights:
+        fig, ax = plt.subplots()
+        labels = list(weights.keys())
+        data = weights.values()
+        ax.hist(data, bins=50, stacked=True, label=labels)
+        fig.legend()
+        yield fig
 
-    for i in range(x_base.shape[0]):
-        yield field_component(i, x_base[i], phi_model[i], base_neg[i], model_neg[i])
-
-
-_plot_field_components = collect("field_components", ("training_context",))
-
-
-def example_configs(loaded_model, base_dist, training_geometry):
-    sample_size = 10
-
-    # Generate a large sample from the base distribution and pass it through the trained model
-    with torch.no_grad():
-        x_base, _ = base_dist(sample_size)
-        sign = x_base.sum(dim=1).sign()
-        neg = (sign < 0).nonzero().squeeze()
-        phi_model, model_log_density = loaded_model(x_base, 0, neg)
-
-    L = int(np.sqrt(phi_model.shape[1]))
-
-    phi_true = np.zeros((4, L, L))
-    phi_true[:, training_geometry.checkerboard] = phi_model[:4, : L ** 2 // 2]
-    phi_true[:, ~training_geometry.checkerboard] = phi_model[:4, L ** 2 // 2 :]
-
-    fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
-    for i, ax in enumerate(axes.flatten()):
-        conf = ax.imshow(phi_true[i])
-        fig.colorbar(conf, ax=ax)
-
-    fig.suptitle("Example configurations")
-
-    return fig
-
-
-_plot_example_configs = collect("example_configs", ("training_context",))
+@figuregen
+def plot_layerwise_weights(plot_layer_weights):
+    yield from plot_layer_weights
 
 
-@figure
-def plot_example_configs(_plot_example_configs):
-    return _plot_example_configs[0]
+def plot_layer_histogram(layerwise_configs):
+    for v in layerwise_configs:
+        v = v.numpy()
+        v_pos = v[v.sum(axis=1) > 0].flatten()
+        v_neg = v[v.sum(axis=1) < 0].flatten()
+        fig, ax = plt.subplots()
+        ax.hist([v_pos, v_neg], bins=50, density=True, histtype="step")
+        yield fig
 
 
 @figuregen
-def plot_field_components(_plot_field_components):
-    yield from _plot_field_components[0]
+def plot_layerwise_histograms(plot_layer_histogram):
+    yield from plot_layer_histogram
 
 
 @figure
@@ -172,9 +115,12 @@ def plot_two_point_correlator(two_point_correlator):
     """
     corr = two_point_correlator.mean(axis=-1)
-    std = two_point_correlator.std(axis=-1)
+    error = two_point_correlator.std(axis=-1)
+    fractional_error = np.abs(error / corr)
+    L = corr.shape[0]
 
-    fractional_std = std / abs(corr)
+    corr = np.roll(corr, (-L // 2 - 1, -L // 2 - 1), (0, 1))
+    fractional_error = np.roll(fractional_error, (-L // 2 - 1, -L // 2 - 1), (0, 1))
 
     fig, (ax_mean, ax_std) = plt.subplots(1, 2, figsize=(13, 6), sharey=True)
     ax_std.set_title(r"$\sigma_G / G$")
@@ -182,9 +128,10 @@ def plot_two_point_correlator(two_point_correlator):
     ax_mean.set_xlabel("$x$")
     ax_std.set_xlabel("$x$")
     ax_mean.set_ylabel("$t$")
+    norm = mpl.colors.LogNorm()
 
-    im1 = ax_mean.imshow(corr)
-    im2 = ax_std.imshow(fractional_std)
+    im1 = ax_mean.imshow(corr, norm=norm)
+    im2 = ax_std.imshow(fractional_error, norm=norm)
 
     ax_mean.yaxis.set_major_locator(MaxNLocator(integer=True))
     ax_mean.xaxis.set_major_locator(MaxNLocator(integer=True))
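The layer-weight figures produced above are stacked histograms over every parameter tensor in a coupling block. A standalone sketch with random data (the parameter names are placeholders, not the keys anvil's blocks actually produce):

```python
# Stacked histogram of per-parameter weight distributions, as in plot_layer_weights.
import numpy as np
import matplotlib.pyplot as plt

weights = {
    "h_odd.0.weight": np.random.randn(800),  # hypothetical parameter names
    "h_odd.0.bias": np.random.randn(100),
}
fig, ax = plt.subplots()
ax.hist(list(weights.values()), bins=50, stacked=True, label=list(weights.keys()))
fig.legend()
fig.savefig("layer_weights_example.png")
```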
diff --git a/anvil/sample.py b/anvil/sample.py
index 466972b..3402877 100644
--- a/anvil/sample.py
+++ b/anvil/sample.py
@@ -120,7 +120,9 @@ def metropolis_hastings(
             history.append(0)
 
     tau = calc_tau_chain(history)
-    log.info(f"Integrated autocorrelation time from preliminary sampling phase: {tau:.2g}")
+    log.info(
+        f"Integrated autocorrelation time from preliminary sampling phase: {tau:.2g}"
+    )
     sample_interval = ceil(2 * tau)  # update sample interval
 
     log.info(f"Using sampling interval: {sample_interval}")
@@ -180,3 +182,28 @@ def tau_chain(_metropolis_hastings):
 
 def acceptance(_metropolis_hastings):
     return _metropolis_hastings[0][2]
+
+
+# TODO: figure out how to name each coupling block
+@torch.no_grad()
+def yield_configs_layerwise(loaded_model, base_dist, metropolis_hastings):
+    v, _ = base_dist(BATCH_SIZE)
+    yield v
+
+    negative_mag = (v.sum(dim=1).sign() < 0).nonzero().squeeze()
+
+    for block in loaded_model:
+        v, _ = block(v, 0, negative_mag)
+        # only want coupling layers
+        if len([tensor for tensor in block.state_dict().values()]) > 1:
+            yield v
+
+    v = metropolis_hastings[0]
+    yield v
+
+
+_layerwise_configs = collect("yield_configs_layerwise", ("training_context",))
+
+
+def layerwise_configs(_layerwise_configs):
+    return _layerwise_configs[0]

diff --git a/examples/runcards/report.md b/examples/runcards/report.md
index ab9f770..edab757 100644
--- a/examples/runcards/report.md
+++ b/examples/runcards/report.md
@@ -16,3 +16,6 @@
 {@plot_magnetization_series@}
 {@plot_magnetization_autocorr@}
 {@plot_magnetization_integrated_autocorr@}
+## Layer-wise breakdown
+{@plot_layerwise_histograms@}
+{@plot_layerwise_weights@}
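For reference, the `negative_mag` index that `yield_configs_layerwise` passes to each block simply picks out the configurations whose summed field (total magnetisation) is negative. On a toy batch the indexing behaves like this (sketch, independent of anvil):

```python
import torch

# Three toy configurations of a two-site lattice.
v = torch.tensor([[0.5, 0.2], [-0.9, 0.1], [0.3, -0.8]])

# Indices of configurations with negative total magnetisation.
negative_mag = (v.sum(dim=1).sign() < 0).nonzero().squeeze()
print(negative_mag)  # tensor([1, 2])
```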
From 1548be08d3f25d2cf603b428458169207a539989 Mon Sep 17 00:00:00 2001
From: marshrossney <17361029+marshrossney@users.noreply.github.com>
Date: Mon, 26 Apr 2021 16:43:39 +0100
Subject: [PATCH 2/4] basic reportengine loop over layers

---
 anvil/benchmark_config/free_scalar_sample.yml |  4 +-
 anvil/config.py                               | 19 ++++++--
 anvil/observables.py                          |  8 +++-
 anvil/plot.py                                 | 37 ++++++++++++-------
 anvil/sample.py                               | 24 ++++++------
 anvil/table.py                                |  7 +++-
 examples/runcards/layerwise.md                | 30 +++++++++++++++
 examples/runcards/layerwise.yml               | 19 ++++++++++
 examples/runcards/report.md                   |  4 +-
 examples/runcards/report.yml                  |  3 +-
 examples/runcards/sample.yml                  |  1 +
 11 files changed, 116 insertions(+), 40 deletions(-)
 create mode 100644 examples/runcards/layerwise.md
 create mode 100644 examples/runcards/layerwise.yml

diff --git a/anvil/benchmark_config/free_scalar_sample.yml b/anvil/benchmark_config/free_scalar_sample.yml
index 884ea0a..870457a 100644
--- a/anvil/benchmark_config/free_scalar_sample.yml
+++ b/anvil/benchmark_config/free_scalar_sample.yml
@@ -1,5 +1,6 @@
 training_output: /tmp/del_me_anvil_benchmark
 cp_id: -1
+layer_id: -1
 
 sample_size: 100000
 thermalization: 10
@@ -16,9 +17,6 @@ template_text: |
   # Eigenvalues of kinetic operator
   {@plot_kinetic_eigenvalues@}
   {@table_kinetic_eigenvalues@}
-  # Layer-wise breakdown
-  {@plot_layerwise_histograms@}
-  {@plot_layerwise_weights@}
 
 actions_:
   - report(main=True)

diff --git a/anvil/config.py b/anvil/config.py
index 92f9c31..7acf830 100644
--- a/anvil/config.py
+++ b/anvil/config.py
@@ -13,6 +13,7 @@
 from anvil.checkpoint import TrainingOutput
 from anvil.models import MODEL_OPTIONS
 from anvil.distributions import BASE_OPTIONS, TARGET_OPTIONS
+import anvil.sample as sample
 
 from random import randint
 from sys import maxsize
@@ -121,6 +122,17 @@ def produce_checkpoint(self, cp_id=None, training_output=None):
         # get index from training_output class
         return training_output.checkpoints[training_output.cp_ids.index(cp_id)]
 
+    @element_of("layer_ids")
+    def parse_layer_id(self, layer_id: int = -1):
+        return layer_id
+
+    @explicit_node
+    def produce_configs(self, layer_id):
+        if layer_id == -1:
+            return sample.configs_from_metropolis
+        else:
+            return sample.configs_from_model
+
     def produce_training_context(self, training_output):
         """Given a training output produce the context of that training"""
         # NOTE: This seems a bit hacky, exposing the entire training configuration
@@ -183,12 +195,11 @@ def parse_bootstrap_sample_size(self, n_boot: int):
         log.warning(f"Using user specified bootstrap sample size: {n_boot}")
         return n_boot
 
-    def produce_bootstrap_seed(
-        self, manual_bootstrap_seed: (int, type(None)) = None):
+    def produce_bootstrap_seed(self, manual_bootstrap_seed: (int, type(None)) = None):
         if manual_bootstrap_seed is None:
             return randint(0, maxsize)
         # numpy is actually this strict but let's keep it sensible.
-        if (manual_bootstrap_seed < 0) or (manual_bootstrap_seed > 2**32):
+        if (manual_bootstrap_seed < 0) or (manual_bootstrap_seed > 2 ** 32):
             raise ConfigError("Seed is outside of appropriate range: [0, 2 ** 32]")
         return manual_bootstrap_seed
 
@@ -213,4 +224,4 @@ def produce_use_multiprocessing(self):
         """Don't use mp on MacOS"""
         if platform.system() == "Darwin":
             return False
-        return True
\ No newline at end of file
+        return True
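The new `produce_configs` rule dispatches on `layer_id`: `-1` selects the full Metropolis-Hastings sample, any other value selects the field variables captured after an intermediate coupling block. Stripped of the reportengine wiring, the logic amounts to the sketch below (reportengine resolves the chosen function's own inputs separately):

```python
# Plain-Python sketch of the dispatch performed by produce_configs.
import anvil.sample as sample

def choose_configs_provider(layer_id):
    if layer_id == -1:
        return sample.configs_from_metropolis  # configurations from the MCMC chain
    return sample.configs_from_model  # output captured after coupling block `layer_id`
```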
diff --git a/anvil/observables.py b/anvil/observables.py
index 237de2f..fa4165e 100644
--- a/anvil/observables.py
+++ b/anvil/observables.py
@@ -12,6 +12,7 @@
 log = logging.getLogger(__name__)
 
+
 def cosh_shift(x, xi, A, c):
     return A * np.cosh(-x / xi) + c
 
@@ -42,8 +43,11 @@ def fit_zero_momentum_correlator(zero_momentum_correlator, training_geometry):
 
 
 def correlation_length_from_fit(fit_zero_momentum_correlator):
-    popt, pcov, _ = fit_zero_momentum_correlator
-    return popt[0], np.sqrt(pcov[0, 0])
+    if fit_zero_momentum_correlator is not None:
+        popt, pcov, _ = fit_zero_momentum_correlator
+        return popt[0], np.sqrt(pcov[0, 0])
+    else:
+        return None, None
 
 
 def autocorrelation(chain):

diff --git a/anvil/plot.py b/anvil/plot.py
index a74cabc..4544e2d 100644
--- a/anvil/plot.py
+++ b/anvil/plot.py
@@ -32,19 +32,28 @@ def plot_layerwise_weights(plot_layer_weights):
     yield from plot_layer_weights
 
 
-def plot_layer_histogram(layerwise_configs):
-    for v in layerwise_configs:
-        v = v.numpy()
-        v_pos = v[v.sum(axis=1) > 0].flatten()
-        v_neg = v[v.sum(axis=1) < 0].flatten()
-        fig, ax = plt.subplots()
-        ax.hist([v_pos, v_neg], bins=50, density=True, histtype="step")
-        yield fig
+@figure
+def plot_layer_histogram(configs):
+    v = configs.numpy()
+    v_pos = v[v.sum(axis=1) > 0].flatten()
+    v_neg = v[v.sum(axis=1) < 0].flatten()
+    fig, ax = plt.subplots()
+    ax.hist([v_pos, v_neg], bins=50, density=True, histtype="step")
+    return fig
 
 
-@figuregen
-def plot_layerwise_histograms(plot_layer_histogram):
-    yield from plot_layer_histogram
+@figure
+def plot_correlation_length(table_correlation_length):
+    fig, ax = plt.subplots()
+    ax.errorbar(
+        x=table_correlation_length.index,
+        y=table_correlation_length.value,
+        yerr=table_correlation_length.error,
+        linestyle="",
+        marker="o",
+    )
+    ax.set_xticklabels(table_correlation_length.index, rotation=45)
+    return fig
 
 
 @figure
@@ -63,14 +72,14 @@ def plot_zero_momentum_correlator(
 
     if fit_zero_momentum_correlator is not None:
         popt, pcov, t0 = fit_zero_momentum_correlator
-        shift = popt[2]
+        xi, A, shift = popt
 
         t = np.linspace(t0, T - t0, 100)
         ax.plot(
             t,
-            cosh_shift(t - T // 2, *popt) - popt[2],
+            cosh_shift(t - T // 2, *popt) - shift,
             "r--",
-            label=r"fit $A \cosh(-(t - T/2) / \xi) + c$",
+            label=r"fit $A \cosh(-(t - T/2) / \xi) + c$" + "\n" + fr"$\xi = ${xi:.2f}",
         )
     ax.errorbar(
         x=np.arange(T),

diff --git a/anvil/sample.py b/anvil/sample.py
index 3402877..0bba426 100644
--- a/anvil/sample.py
+++ b/anvil/sample.py
@@ -172,7 +172,7 @@ def metropolis_hastings(
 _metropolis_hastings = collect("metropolis_hastings", ("training_context",))
 
 
-def configs(_metropolis_hastings):
+def configs_from_metropolis(_metropolis_hastings):
     return _metropolis_hastings[0][0]
 
 
@@ -186,24 +186,26 @@ def acceptance(_metropolis_hastings):
 
 
 # TODO: figure out how to name each coupling block
 @torch.no_grad()
-def yield_configs_layerwise(loaded_model, base_dist, metropolis_hastings):
-    v, _ = base_dist(BATCH_SIZE)
-    yield v
+def yield_configs_layerwise(loaded_model, base_dist, sample_size, layer_id):
+    v, _ = base_dist(sample_size)
+    if layer_id == 0:
+        return v
 
     negative_mag = (v.sum(dim=1).sign() < 0).nonzero().squeeze()
 
+    i = 1
     for block in loaded_model:
         v, _ = block(v, 0, negative_mag)
         # only want coupling layers
         if len([tensor for tensor in block.state_dict().values()]) > 1:
-            yield v
-
-    v = metropolis_hastings[0]
-    yield v
+            if i == layer_id:
+                return v
+            else:
+                i += 1
 
 
-_layerwise_configs = collect("yield_configs_layerwise", ("training_context",))
+_configs_from_model = collect("yield_configs_layerwise", ("training_context",))
 
 
-def layerwise_configs(_layerwise_configs):
-    return _layerwise_configs[0]
+def configs_from_model(_configs_from_model):
+    return _configs_from_model[0]
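The xi value now printed on the zero-momentum correlator plot comes from the `cosh_shift` fit defined in observables.py. A self-contained sketch of that fit on synthetic data (the lattice size, noise level and starting values below are made up for illustration):

```python
import numpy as np
from scipy.optimize import curve_fit

def cosh_shift(x, xi, A, c):
    # Same functional form as anvil.observables.cosh_shift.
    return A * np.cosh(-x / xi) + c

T, xi_true = 32, 4.0
t = np.arange(T)
G = cosh_shift(t - T // 2, xi_true, 0.1, 0.0) + 1e-4 * np.random.randn(T)

popt, pcov = curve_fit(cosh_shift, t - T // 2, G, p0=(5.0, 0.05, 0.0))
xi, A, shift = popt
print(f"xi = {xi:.2f} +/- {np.sqrt(pcov[0, 0]):.2f}")
```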
diff --git a/anvil/table.py b/anvil/table.py
index 5713114..2501495 100644
--- a/anvil/table.py
+++ b/anvil/table.py
@@ -42,6 +42,9 @@ def table_fit(fit_zero_momentum_correlator, training_geometry):
             index=["xi_fit", "m_fit"],
         )
         return df
+    else:
+        # TODO should fail better than this
+        return pd.DataFrame([])
 
 
 @table
@@ -100,7 +103,7 @@ def table_correlation_length(
 
     df = pd.DataFrame(
         res,
-        columns=["Mean", "Standard deviation"],
+        columns=["value", "error"],
         index=[
             "Estimate from fit",
             "Estimate using arcosh",
@@ -108,7 +111,7 @@ def table_correlation_length(
             "Low momentum estimate",
         ],
     )
-    df["No. correlation lengths"] = training_geometry.length / df["Mean"]
+    df["No. correlation lengths"] = training_geometry.length / df["value"]
 
     return df

diff --git a/examples/runcards/layerwise.md b/examples/runcards/layerwise.md
new file mode 100644
index 0000000..38cc9b8
--- /dev/null
+++ b/examples/runcards/layerwise.md
@@ -0,0 +1,30 @@
+Layer-wise breakdown
+====================
+
+Model weights
+-------------
+{@plot_layerwise_weights@}
+
+Field variables
+---------------
+{@with layer_ids@}
+{@plot_layer_histogram@}
+{@endwith@}
+
+Correlation function
+--------------------
+{@with layer_ids@}
+{@plot_two_point_correlator@}
+{@endwith@}
+
+Correlation function at zero momentum
+-------------------------------------
+{@with layer_ids@}
+{@plot_zero_momentum_correlator@}
+{@endwith@}
+
+Correlation length
+------------------
+{@with layer_ids@}
+{@plot_correlation_length@}
+{@endwith@}

diff --git a/examples/runcards/layerwise.yml b/examples/runcards/layerwise.yml
new file mode 100644
index 0000000..530e77c
--- /dev/null
+++ b/examples/runcards/layerwise.yml
@@ -0,0 +1,19 @@
+training_output: train
+cp_id: -1
+layer_ids: [1, 2, 3, -1]
+
+sample_size: 10000
+thermalization: 1000
+sample_interval: 1
+
+bootstrap_sample_size: 100
+
+meta:
+  author: Author
+  title: Training Report
+  keywords: [example]
+
+template: layerwise.md
+
+actions_:
+  - report(main=True)

diff --git a/examples/runcards/report.md b/examples/runcards/report.md
index edab757..9f46acc 100644
--- a/examples/runcards/report.md
+++ b/examples/runcards/report.md
@@ -10,12 +10,10 @@
 {@plot_zero_momentum_correlator@}
 ## Correlation length
 {@plot_effective_pole_mass@}
+{@plot_correlation_length@}
 {@table_correlation_length@}
 ## Magnetisation
 {@table_magnetization@}
 {@plot_magnetization_series@}
 {@plot_magnetization_autocorr@}
 {@plot_magnetization_integrated_autocorr@}
-## Layer-wise breakdown
-{@plot_layerwise_histograms@}
-{@plot_layerwise_weights@}

diff --git a/examples/runcards/report.yml b/examples/runcards/report.yml
index 25bcd47..342851c 100644
--- a/examples/runcards/report.yml
+++ b/examples/runcards/report.yml
@@ -1,11 +1,12 @@
 training_output: train
 cp_id: -1
+layer_id: -1
 
 sample_size: 10000
 thermalization: 10
 sample_interval: 1
 
-bootstrap_sample_size: 100
+bootstrap_sample_size: 1000
 
 meta:
   author: Author

diff --git a/examples/runcards/sample.yml b/examples/runcards/sample.yml
index c0fee7a..79742b3 100644
--- a/examples/runcards/sample.yml
+++ b/examples/runcards/sample.yml
@@ -1,5 +1,6 @@
 training_output: train
 cp_id: -1
+layer_id: -1
 
 sample_size: 10000
 thermalization: 1000
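After the rename in anvil/table.py, `table_correlation_length` exposes `value` and `error` columns, which is what the new `plot_correlation_length` errorbar plot reads. A sketch of the table's shape with made-up numbers (assuming a lattice length of 32):

```python
import pandas as pd

df = pd.DataFrame(
    [[4.2, 0.3], [4.0, 0.5]],  # made-up correlation length estimates
    columns=["value", "error"],
    index=["Estimate from fit", "Estimate using arcosh"],
)
df["No. correlation lengths"] = 32 / df["value"]
print(df)
```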
From f215adb44f8b5e8cf52d2258917740063581c2e6 Mon Sep 17 00:00:00 2001
From: marshrossney <17361029+marshrossney@users.noreply.github.com>
Date: Mon, 26 Apr 2021 16:55:20 +0100
Subject: [PATCH 3/4] fix broken test due to class name change

---
 anvil/tests/test_distributions.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/anvil/tests/test_distributions.py b/anvil/tests/test_distributions.py
index 939e92a..b219bc9 100644
--- a/anvil/tests/test_distributions.py
+++ b/anvil/tests/test_distributions.py
@@ -5,7 +5,7 @@
 from math import sqrt
 
 import numpy as np
-from anvil.distributions import NormalDist
+from anvil.distributions import Gaussian
 
 MEAN = 0
 SIGMA = 1
@@ -20,7 +20,7 @@ def test_normal_distribution():
     """
     lattice_size = 5
 
-    generator = NormalDist(lattice_size, sigma=SIGMA, mean=MEAN)
+    generator = Gaussian(lattice_size, sigma=SIGMA, mean=MEAN)
     sample_pt, _ = generator(N_SAMPLE)
     sample_np = sample_pt.detach().numpy()
     np.testing.assert_allclose(

From 68a58d37a8216bc59429aed399b626a91d29737c Mon Sep 17 00:00:00 2001
From: marshrossney <17361029+marshrossney@users.noreply.github.com>
Date: Mon, 26 Apr 2021 17:08:19 +0100
Subject: [PATCH 4/4] Revert "fix broken test due to class name change"

This reverts commit f215adb44f8b5e8cf52d2258917740063581c2e6.

Reverting update to fix broken test, since made in other PR
---
 anvil/tests/test_distributions.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/anvil/tests/test_distributions.py b/anvil/tests/test_distributions.py
index b219bc9..939e92a 100644
--- a/anvil/tests/test_distributions.py
+++ b/anvil/tests/test_distributions.py
@@ -5,7 +5,7 @@
 from math import sqrt
 
 import numpy as np
-from anvil.distributions import Gaussian
+from anvil.distributions import NormalDist
 
 MEAN = 0
 SIGMA = 1
@@ -20,7 +20,7 @@ def test_normal_distribution():
     """
     lattice_size = 5
 
-    generator = Gaussian(lattice_size, sigma=SIGMA, mean=MEAN)
+    generator = NormalDist(lattice_size, sigma=SIGMA, mean=MEAN)
     sample_pt, _ = generator(N_SAMPLE)
     sample_np = sample_pt.detach().numpy()
     np.testing.assert_allclose(