Skip to content

Commit

Permalink
fix: make some small corrections to the python example_1 and test out…
Browse files Browse the repository at this point in the history
… the built-in autocorrelation observer
  • Loading branch information
denehoffman committed Dec 20, 2024
1 parent 40457b8 commit 5e31018
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 19 deletions.
57 changes: 42 additions & 15 deletions python_examples/example_1/example_1_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@
from corner import corner


class AutoCorrCallback(ld.MCMCObserver):
def __init__(self, nll: ld.NLL, ncheck=20, dact=0.01, nact=50, discard=0.5):
# This custom observer differs from the one provided by `laddu`. Rather than tracing the
# walker positions and calculating the IAT, this first projects the current walker positions
# onto the two constituent waves and uses those to calculate a different IAT. This converges
# better because there is an implicit symmetry in the problem (only the absolute phase between
# the waves matters, not the sign of that phase) so walkers can bounce between two equivalent
# positions in the fit space which are very separate in the parameter space. It also
# demonstrates how to write and use a custom observer.
class CustomAutocorrelationObserver(ld.MCMCObserver):
def __init__(self, nll: ld.NLL, ncheck=20, dact=0.05, nact=20, discard=0.5):
self.nll = nll
self.ncheck = ncheck
self.dact = dact
Expand Down Expand Up @@ -39,7 +46,7 @@ def callback(self, step: int, ensemble: ld.Ensemble) -> tuple[ld.Ensemble, bool]
self.s0s.append(s0s)
self.d2s.append(d2s)
if step % self.ncheck == 0:
logger.info('Checking Autocorrelation')
logger.info('Checking Autocorrelation (custom)')
logger.info(
f'Chain dimensions: {ensemble.dimension[0]} walkers, {ensemble.dimension[1]} steps, {ensemble.dimension[2]} parameters'
)
Expand All @@ -64,6 +71,7 @@ def callback(self, step: int, ensemble: ld.Ensemble) -> tuple[ld.Ensemble, bool]
logger.info(
f'Δτ/τ = {abs(self.latest_tau - tau) / tau} (converges if < {self.dact})'
)
logger.info('End of custom Autocorrelation check')
converged = (tau * self.nact < step) and (
abs(self.latest_tau - tau) / tau < self.dact
)
Expand Down Expand Up @@ -160,7 +168,13 @@ def main():
bin_out = pickle.load(bin_out_file)
ensemble = bin_out['ensemble']
tau = bin_out['tau']
flat_chain = ensemble.get_flat_chain(burn=int(tau * 3))
taus = bin_out['taus']
(_, n_steps, _) = ensemble.dimension
n_steps_burned = n_steps - int(tau * 10) # 210
requested_steps = 100
excess_steps = n_steps_burned - requested_steps # 110
thin = 1 if excess_steps < 0 else n_steps_burned // requested_steps
flat_chain = ensemble.get_flat_chain(burn=int(tau * 3), thin=thin)
tot = []
s0s = []
d2s = []
Expand All @@ -171,29 +185,42 @@ def main():
else:
p0 = best.x + np.random.normal(0, scale=0.01, size=(100, len(best.x)))
nll_clone = nll
acc = AutoCorrCallback(nll_clone)
ensemble = nll.mcmc(p0, 30000, observers=[acc])
tau = acc.latest_tau
bin_out = {'ensemble': ensemble, 'tau': tau}
tot = np.array(acc.tot).reshape(-1)
s0s = np.array(acc.s0s).reshape(-1)
d2s = np.array(acc.d2s).reshape(-1)
flat_chain = ensemble.get_flat_chain(burn=int(tau * 3))
caco = CustomAutocorrelationObserver(nll_clone)
aco = ld.AutocorrelationObserver(n_check=10, terminate=False, verbose=True)
ensemble = nll.mcmc(p0, 30000, observers=[caco, aco])
tau = caco.latest_tau
taus = aco.taus
bin_out = {'ensemble': ensemble, 'tau': tau, 'taus': taus}
tot = np.array(caco.tot).reshape(-1)
s0s = np.array(caco.s0s).reshape(-1)
d2s = np.array(caco.d2s).reshape(-1)
(_, n_steps, _) = ensemble.dimension
n_steps_burned = n_steps - int(tau * 10) # 210
requested_steps = 100
excess_steps = n_steps_burned - requested_steps # 110
thin = 1 if excess_steps < 0 else n_steps_burned // requested_steps
flat_chain = ensemble.get_flat_chain(burn=int(tau * 3), thin=thin)
with open(f'bin_{ibin}_mcmc.pkl', 'wb') as bin_out_file:
pickle.dump(bin_out, bin_out_file)

chain = ensemble.get_chain(burn=int(tau * 3)).transpose(1, 0, 2)
chain = ensemble.get_chain(burn=int(tau * 10), thin=thin).transpose(1, 0, 2)
_, axes = plt.subplots(3, figsize=(10, 7), sharex=True)
labels = ['$S_0^+$ real', '$D_2^+$ real', '$D_2^+$ imag']
for i in range(3):
ax = axes[i]
ax.plot(chain[:, :, i], 'k', alpha=0.3)
ax.set_xlim(0, len(chain))
ax.plot(np.arange(len(chain)) + int(tau * 10), chain[:, :, i], 'k', alpha=0.3)
ax.set_xlim(int(tau * 10), len(chain) + int(tau * 10))
ax.set_ylabel(labels[i])
ax.yaxis.set_label_coords(-0.1, 0.5)
axes[-1].set_xlabel('step number')
plt.savefig(f'mcmc_plots/trace_{ibin}.svg')
plt.close()
plt.plot(np.arange(len(taus)) * 10, taus)
plt.xlabel('Step')
plt.ylabel(r'Mean $\tau$')
plt.tight_layout()
plt.savefig(f'mcmc_plots/iat_{ibin}.svg')
plt.close()
corner(
flat_chain,
truths=best.x,
Expand Down
7 changes: 3 additions & 4 deletions src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2476,11 +2476,10 @@ pub(crate) mod laddu {
if let Ok(Some(observer_arg)) = kwargs.get_item("observers") {
if let Ok(observer_list) = observer_arg.downcast::<PyList>() {
for item in observer_list.iter() {
if let Ok(observer) = item.extract::<PyMCMCObserver>() {
// TODO: fix this
observers.push(Arc::new(RwLock::new(observer)));
} else if let Ok(observer) = item.downcast::<AutocorrelationObserver>() {
if let Ok(observer) = item.downcast::<AutocorrelationObserver>() {
observers.push(observer.borrow().0.clone());
} else if let Ok(observer) = item.extract::<PyMCMCObserver>() {
observers.push(Arc::new(RwLock::new(observer)));
}
}
} else if let Ok(single_observer) =
Expand Down

0 comments on commit 5e31018

Please sign in to comment.