From 47a40483399017654efc2be98d7f0abfaaee8ea7 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Wed, 14 Feb 2024 20:22:14 -0800 Subject: [PATCH 1/3] Timescale normalisation and other fixes --- docs/methods.md | 22 +- docs/usage.md | 5 - evaluation/evaluate_accuracy.py | 813 -------------------------------- requirements.txt | 1 + tests/TODO | 5 + tests/test_cli.py | 16 +- tests/test_evaluation.py | 2 +- tests/test_functions.py | 129 +++-- tests/test_inference.py | 131 ++--- tests/test_provenance.py | 18 +- tests/utility_functions.py | 33 ++ tsdate/__init__.py | 1 + tsdate/approx.py | 268 ++++++++++- tsdate/cli.py | 39 +- tsdate/core.py | 302 ++++++------ tsdate/evaluation.py | 438 ++++++++++++++++- tsdate/hypergeo.py | 133 ++++++ tsdate/mixture.py | 267 ----------- tsdate/normalisation.py | 415 ++++++++++++++++ tsdate/util.py | 341 ++++++++------ tsdate/variational.py | 410 ++++++++++------ 21 files changed, 2055 insertions(+), 1734 deletions(-) delete mode 100644 evaluation/evaluate_accuracy.py create mode 100644 tests/TODO delete mode 100644 tsdate/mixture.py create mode 100644 tsdate/normalisation.py diff --git a/docs/methods.md b/docs/methods.md index ee471c35..a1382485 100644 --- a/docs/methods.md +++ b/docs/methods.md @@ -29,11 +29,11 @@ each timepoint). Continuous-time approaches approximate the posterior by a continuous univariate distribution (e.g. a gamma distribution). -In tests, we find that the continuous-time `variational_gamma` approach is -the most accurate (but can suffer from {ref}`numerical instability`). -The discrete-time `inside_outside` approach is slightly less accurate, especially for older times, -but is more numerically robust, and the discrete-time `maximization` approach is -always stable but is the least accurate. +In tests, we find that the continuous-time `variational_gamma` approach is the +most accurate. The discrete-time `inside_outside` approach is slightly less +accurate, especially for older times, but is more numerically robust, and the +discrete-time `maximization` approach is always stable but is the least +accurate. Changing the method is very simple: @@ -43,13 +43,13 @@ import tskit import tsdate input_ts = tskit.load("data/basic_example.trees") -ts = tsdate.date(input_ts, method="variational_gamma", population_size=100, mutation_rate=1e-8) +ts = tsdate.date(input_ts, method="variational_gamma", mutation_rate=1e-8) ``` Alternatively each method can be called directly as a separate function: ```{code-cell} ipython3 -ts = tsdate.variational_gamma(input_ts, population_size=100, mutation_rate=1e-8) +ts = tsdate.variational_gamma(input_ts, mutation_rate=1e-8) ``` Currently the default is `inside_outside`, but this may change in future releases. @@ -127,13 +127,6 @@ local estimates to each gamma distribution are iteratively refined until they converge to a stable solution. This comes under a class of approaches sometimes known as "loopy belief propagation". -:::{todo} -Add details about [numerical instability](sec_usage_real_data_stability), -describing expected errors (e.g. about non-convergence of a hypergeometric series), -and detailing potential workarounds using the `max_shape` option to constrain the -gamma variance. -::: - :::{note} As a result of testing, the default priors used for this method are identical for all nodes (i.e. 
a "global" prior is used), based on a composite @@ -166,6 +159,5 @@ ts = tsdate.date( input_ts, method="variational_gamma", progress=True, - population_size=100, mutation_rate=1e-8) ``` diff --git a/docs/usage.md b/docs/usage.md index 99b6b517..548eb1bb 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -308,11 +308,6 @@ The {func}`tsdate.preprocess_ts()` function can help remove topology from these regions. See the documentation for that function for details on how to increase or decrease its stringency. -The [`variational_gamma`](sec_methods_continuous_time_vgamma) method is more prone to -instability, and switching to another method may help. Note, however, that this is usually -a sign that you should re-inspect the original tree sequence, which is likely to -have poorly inferred topologies. - (sec_usage_real_data_simplify)= ### Simplification and unary nodes diff --git a/evaluation/evaluate_accuracy.py b/evaluation/evaluate_accuracy.py deleted file mode 100644 index e9f58ae5..00000000 --- a/evaluation/evaluate_accuracy.py +++ /dev/null @@ -1,813 +0,0 @@ -import argparse -import sys - -sys.path.insert(1, "../tsdate") -import tsdate # NOQA: E402 -from tsdate.date import ( # NOQA: E402, F401 - SpansBySamples, - ConditionalCoalescentTimes, - fill_prior, - Likelihoods, - InOutAlgorithms, - NodeGridValues, - posterior_mean_var, - constrain_ages_topo, - LogLikelihoods, -) -import msprime # NOQA: E402 -import numpy as np # NOQA: E402 -import scipy # NOQA: E402 -import matplotlib.pyplot as plt # NOQA: E402 -from tqdm import tqdm # NOQA: E402 -import tsinfer # NOQA: E402 -from sklearn.metrics import mean_squared_log_error # NOQA: E402 - -from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes # NOQA: E402 - - -def get_prior_results(): - def evaluate_prior(ts, Ne, prior_distr, progress=False): - fixed_node_set = set(ts.samples()) - num_samples = len(fixed_node_set) - - span_data = SpansBySamples(ts, fixed_node_set, progress=progress) - base_priors = ConditionalCoalescentTimes(None, prior_distr) - base_priors.add(len(fixed_node_set), False) - mixture_prior = base_priors.get_mixture_prior_params(span_data) - confidence_intervals = np.zeros((ts.num_nodes - ts.num_samples, 4)) - - if prior_distr == "lognorm": - lognorm_func = scipy.stats.lognorm - for node in np.arange(num_samples, ts.num_nodes): - confidence_intervals[node - num_samples, 0] = np.sum( - span_data.get_weights(node)[num_samples].descendant_tips - * span_data.get_weights(node)[num_samples].weight - ) - confidence_intervals[node - num_samples, 1] = ( - 2 - * Ne - * lognorm_func.mean( - s=np.sqrt(mixture_prior[node, 1]), - scale=np.exp(mixture_prior[node, 0]), - ) - ) - confidence_intervals[node - num_samples, 2:4] = ( - 2 - * Ne - * lognorm_func.ppf( - [0.025, 0.975], - s=np.sqrt(mixture_prior[node, 1]), - scale=np.exp(mixture_prior[node, 0]), - ) - ) - elif prior_distr == "gamma": - gamma_func = scipy.stats.gamma - for node in np.arange(ts.num_samples, ts.num_nodes): - confidence_intervals[node - num_samples, 0] = np.sum( - span_data.get_weights(node)[ts.num_samples].descendant_tips - * span_data.get_weights(node)[ts.num_samples].weight - ) - confidence_intervals[node - num_samples, 1] = ( - 2 - * Ne - * gamma_func.mean( - mixture_prior[node, 0], scale=1 / mixture_prior[node, 1] - ) - ) - confidence_intervals[node - num_samples, 2:4] = ( - 2 - * Ne - * gamma_func.ppf( - [0.025, 0.975], - mixture_prior[node, 0], - scale=1 / mixture_prior[node, 1], - ) - ) - return confidence_intervals - - all_results = { - i: { - i: [] - for i 
in [ - "in_range", - "expectations", - "real_ages", - "ts_size", - "upper_bound", - "lower_bound", - "num_tips", - ] - } - for i in ["Lognormal_0", "Lognormal_1e-8", "Gamma_0", "Gamma_1e-8"] - } - - for prior, (prior_distr, rec_rate) in tqdm( - zip( - all_results.keys(), - [("lognorm", 0), ("lognorm", 1e-8), ("gamma", 0), ("gamma", 1e-8)], - ), - desc="Evaluating Priors", - total=4, - ): - for i in range(1, 11): - Ne = 10000 - ts = msprime.simulate( - sample_size=100, - length=5e5, - Ne=Ne, - mutation_rate=1e-8, - recombination_rate=rec_rate, - random_seed=i, - ) - - confidence_intervals = evaluate_prior(ts, Ne, prior_distr) - all_results[prior]["in_range"].append( - np.sum( - np.logical_and( - ts.tables.nodes.time[ts.num_samples :] - < confidence_intervals[:, 3], - ts.tables.nodes.time[ts.num_samples :] - > confidence_intervals[:, 2], - ) - ) - ) - all_results[prior]["lower_bound"].append(confidence_intervals[:, 2]) - all_results[prior]["upper_bound"].append(confidence_intervals[:, 3]) - all_results[prior]["expectations"].append(confidence_intervals[:, 1]) - all_results[prior]["num_tips"].append(confidence_intervals[:, 0]) - all_results[prior]["real_ages"].append( - ts.tables.nodes.time[ts.num_samples :] - ) - all_results[prior]["ts_size"].append(ts.num_nodes - ts.num_samples) - - return all_results - - -def make_prior_plot(all_results): - fig, ax = plt.subplots(2, 2, figsize=(16, 12), sharex=True, sharey=True) - axes = ax.ravel() - plt.xscale("log") - plt.yscale("log") - plt.xlim(1.9, 110) - plt.ylim(1e-3, 4e5) - - for index, ((name, result), mixtures) in enumerate( - zip(all_results.items(), [False, False, False, False]) - ): - num_tips_all = np.concatenate(result["num_tips"]).ravel() - num_tips_all_int = num_tips_all.astype(int) - only_mixtures = np.full(len(num_tips_all), True) - if mixtures: - only_mixtures = np.where((num_tips_all - num_tips_all_int) != 0)[0] - - upper_bound_all = np.concatenate(result["upper_bound"]).ravel()[only_mixtures] - lower_bound_all = np.concatenate(result["lower_bound"]).ravel()[only_mixtures] - expectations_all = np.concatenate(result["expectations"]).ravel()[only_mixtures] - - real_ages_all = np.concatenate(result["real_ages"]).ravel()[only_mixtures] - num_tips_all = num_tips_all[only_mixtures] - yerr = [expectations_all - lower_bound_all, upper_bound_all - expectations_all] - - axes[index].errorbar( - num_tips_all, - expectations_all, - ls="none", - yerr=yerr, - elinewidth=0.1, - alpha=0.2, - color="grey", - label="95% credible interval of the prior", - ) - - axes[index].scatter( - num_tips_all, real_ages_all, s=1, alpha=0.5, color="blue", label="True Time" - ) - axes[index].scatter( - num_tips_all, - expectations_all, - s=1, - color="red", - label="expected time", - alpha=0.5, - ) - coverage = np.sum( - np.logical_and( - real_ages_all < upper_bound_all, real_ages_all > lower_bound_all - ) - ) / len(expectations_all) - axes[index].text( - 0.35, - 0.25, - "Overall Coverage Probability:" + f"{coverage:.3f}", - size=10, - ha="center", - va="center", - transform=axes[index].transAxes, - ) - less5_tips = np.where(num_tips_all < 5)[0] - coverage = np.sum( - np.logical_and( - real_ages_all[less5_tips] < upper_bound_all[less5_tips], - (real_ages_all[less5_tips] > lower_bound_all[less5_tips]), - ) - / len(expectations_all[less5_tips]) - ) - axes[index].text( - 0.35, - 0.21, - "<10 Tips Coverage Probability:" + f"{coverage:.3f}", - size=10, - ha="center", - va="center", - transform=axes[index].transAxes, - ) - mrcas = np.where(num_tips_all == 100)[0] - coverage = 
np.sum( - np.logical_and( - real_ages_all[mrcas] < upper_bound_all[mrcas], - (real_ages_all[mrcas] > lower_bound_all[mrcas]), - ) - / len(expectations_all[mrcas]) - ) - axes[index].text( - 0.35, - 0.17, - "MRCA Coverage Probability:" + f"{coverage:.3f}", - size=10, - ha="center", - va="center", - transform=axes[index].transAxes, - ) - axes[index].set_title( - "Evaluating Conditional Coalescent Using " - + name.split("_")[0] - + " Prior: \n 10 Samples of n=1000, \ - length=500kb, mu=1e-8, p=" - + name.split("_")[1] - ) - axins = zoomed_inset_axes(axes[index], 2.7, loc=7) - axins.errorbar( - num_tips_all, - expectations_all, - ls="none", - yerr=yerr, - elinewidth=0.5, - alpha=0.1, - color="grey", - solid_capstyle="projecting", - capsize=5, - label="95% credible interval of the prior", - ) - axins.scatter( - num_tips_all, real_ages_all, s=2, color="blue", alpha=0.5, label="True Time" - ) - axins.scatter( - num_tips_all, - expectations_all, - s=2, - color="red", - label="expected time", - alpha=0.5, - ) - x1, x2, y1, y2 = 90, 105, 5e3, 3e5 - axins.set_xlim(x1, x2) - axins.set_ylim(y1, y2) - axins.set_xscale("log") - axins.set_yscale("log") - plt.yticks(visible=False) - plt.xticks(visible=False) - from mpl_toolkits.axes_grid1.inset_locator import mark_inset - - mark_inset(axes[index], axins, loc1=2, loc2=1, fc="none", ec="0.5") - lgnd = axes[3].legend(loc=4, prop={"size": 12}, bbox_to_anchor=(1, -0.3)) - lgnd.legendHandles[0]._sizes = [30] - lgnd.legendHandles[1]._sizes = [30] - lgnd.legendHandles[2]._linewidths = [2] - fig.text(0.5, 0.04, "Number of Tips", ha="center", size=15) - fig.text( - 0.04, - 0.5, - "Expectation of the Prior Distribution on Node Age", - va="center", - rotation="vertical", - size=15, - ) - - plt.savefig("evaluation/evaluating_conditional_coalescent_prior", dpi=300) - - -def evaluate_tsdate_accuracy( - parameter, - parameters_arr, - node_mut=False, - inferred=True, - prior_distr="lognorm", - progress=True, -): - Ne = 10000 - if node_mut and inferred: - raise ValueError("cannot evaluate node accuracy on inferred tree sequence") - mutation_rate = 1e-8 - recombination_rate = 1e-8 - all_results = { - i: {i: [] for i in ["io", "max", "true_times"]} - for i in list(map(str, parameters_arr)) - } - - random_seeds = range(1, 6) - - if inferred: - inferred_progress = "using tsinfer" - else: - inferred_progress = "true topology" - if node_mut: - node_mut_progress = "comparing true and estimated node times" - else: - node_mut_progress = "comparing true and estimated mutation times" - for _, param in tqdm( - enumerate(parameters_arr), - desc="Testing " - + parameter - + " " - + inferred_progress - + ". 
Evaluation by " - + node_mut_progress, - total=len(parameters_arr), - disable=not progress, - ): - for random_seed in random_seeds: - if parameter == "sample_size": - sample_size = param - else: - sample_size = 100 - ts = msprime.simulate( - sample_size=sample_size, - Ne=Ne, - length=1e6, - mutation_rate=mutation_rate, - recombination_rate=recombination_rate, - random_seed=random_seed, - ) - - if parameter == "length": - ts = msprime.simulate( - sample_size=sample_size, - Ne=Ne, - length=param, - mutation_rate=mutation_rate, - recombination_rate=recombination_rate, - random_seed=random_seed, - ) - if parameter == "mutation_rate": - mutated_ts = msprime.mutate(ts, rate=param, random_seed=random_seed) - else: - mutated_ts = msprime.mutate( - ts, rate=mutation_rate, random_seed=random_seed - ) - if inferred: - sample_data = tsinfer.formats.SampleData.from_tree_sequence( - mutated_ts, use_times=False - ) - target_ts = tsinfer.infer(sample_data).simplify() - else: - target_ts = mutated_ts - - if parameter == "mutation_rate": - io_dated = tsdate.date( - target_ts, - mutation_rate=param, - Ne=Ne, - progress=False, - method="inside_outside", - ) - max_dated = tsdate.date( - target_ts, - mutation_rate=param, - Ne=Ne, - progress=False, - method="maximization", - ) - elif parameter == "timepoints": - prior = tsdate.build_prior_grid( - target_ts, - timepoints=param, - approximate_prior=True, - prior_distribution=prior_distr, - progress=False, - ) - io_dated = tsdate.date( - target_ts, - mutation_rate=mutation_rate, - prior=prior, - Ne=Ne, - progress=False, - method="inside_outside", - ) - max_dated = tsdate.date( - target_ts, - mutation_rate=mutation_rate, - prior=prior, - Ne=Ne, - progress=False, - method="maximization", - ) - else: - io_dated = tsdate.date( - target_ts, - mutation_rate=mutation_rate, - Ne=Ne, - progress=False, - method="inside_outside", - ) - max_dated = tsdate.date( - target_ts, - mutation_rate=mutation_rate, - Ne=Ne, - progress=False, - method="maximization", - ) - if node_mut and not inferred: - all_results[str(param)]["true_times"].append( - mutated_ts.tables.nodes.time[ts.num_samples :] - ) - all_results[str(param)]["io"].append( - io_dated.tables.nodes.time[ts.num_samples :] - ) - all_results[str(param)]["max"].append( - max_dated.tables.nodes.time[ts.num_samples :] - ) - else: - all_results[str(param)]["true_times"].append( - mutated_ts.tables.nodes.time[mutated_ts.tables.mutations.node] - ) - all_results[str(param)]["io"].append( - io_dated.tables.nodes.time[io_dated.tables.mutations.node] - ) - all_results[str(param)]["max"].append( - max_dated.tables.nodes.time[max_dated.tables.mutations.node] - ) - - return all_results, prior_distr, inferred, node_mut - - -def plot_tsdate_accuracy( - all_results, parameter, parameter_arr, prior_distr, inferred, node_mut -): - f, axes = plt.subplots(3, 2, figsize=(16, 12), sharex=True, sharey=True) - axes[0, 0].set_xscale("log") - axes[0, 0].set_yscale("log") - axes[0, 0].set_xlim(2e-1, 2e5) - axes[0, 0].set_ylim(2e-1, 2e5) - - for index, param in enumerate(parameter_arr): - true_ages = np.concatenate(all_results[param]["true_times"]) - maximized = np.concatenate(all_results[param]["max"]) - inside_outside = np.concatenate(all_results[param]["io"]) - - axes[index, 0].scatter( - true_ages, inside_outside, alpha=0.2, s=10, label="Inside-Outside" - ) - axes[index, 1].scatter(true_ages, maximized, alpha=0.2, s=10, label="Maximized") - axes[index, 0].plot(plt.xlim(), plt.ylim(), ls="--", c=".3") - axes[index, 1].plot(plt.xlim(), plt.ylim(), 
ls="--", c=".3") - - axes[index, 0].text( - 0.05, - 0.9, - "RMSLE: " + f"{mean_squared_log_error(true_ages, inside_outside):.2f}", - transform=axes[index, 0].transAxes, - size=15, - ) - axes[index, 1].text( - 0.05, - 0.9, - "RMSLE: " + f"{mean_squared_log_error(true_ages, maximized):.2f}", - transform=axes[index, 1].transAxes, - size=15, - ) - axes[index, 0].text( - 0.05, - 0.8, - "Pearson's r: " - + f"{scipy.stats.pearsonr(true_ages, inside_outside)[0]:.2f}", - transform=axes[index, 0].transAxes, - size=15, - ) - axes[index, 1].text( - 0.05, - 0.8, - "Pearson's r: " + f"{scipy.stats.pearsonr(true_ages, maximized)[0]:.2f}", - transform=axes[index, 1].transAxes, - size=15, - ) - axes[index, 0].text( - 0.05, - 0.7, - "Spearman's Rho: " - + f"{scipy.stats.spearmanr(true_ages, inside_outside)[0]:.2f}", - transform=axes[index, 0].transAxes, - size=15, - ) - axes[index, 1].text( - 0.05, - 0.7, - "Spearman's Rho: " - + f"{scipy.stats.spearmanr(true_ages, maximized)[0]:.2f}", - transform=axes[index, 1].transAxes, - size=15, - ) - axes[index, 0].text( - 0.05, - 0.6, - "Bias:" + f"{np.mean(true_ages) - np.mean(inside_outside):.2f}", - transform=axes[index, 0].transAxes, - size=15, - ) - axes[index, 1].text( - 0.05, - 0.6, - "Bias:" + f"{np.mean(true_ages) - np.mean(maximized):.2f}", - transform=axes[index, 1].transAxes, - size=15, - ) - axes[index, 1].text( - 1.04, - 0.8, - parameter + ": " + str(param), - rotation=90, - color="Red", - transform=axes[index, 1].transAxes, - size=20, - ) - - axes[0, 0].set_title("Inside-Outside", size=20) - axes[0, 1].set_title("Maximization", size=20) - - f.text(0.5, 0.05, "True Time", ha="center", size=25) - f.text(0.04, 0.5, "Estimated Time", va="center", rotation="vertical", size=25) - - if inferred: - inferred = "Inferred" - else: - inferred = "True Topologies" - - if node_mut: - node_mut = "Nodes" - else: - node_mut = "Mutations" - - if parameter == "Mut Rate": - plt.suptitle( - "Evaluating " - + parameter - + ": " - + inferred - + " " - + node_mut - + " vs. True " - + node_mut - + ". \n Inside-Outside Algorithm and Maximization. \n" - + prior_distr - + " Prior, n=100, Length=1Mb, Rec Rate=1e-8", - y=0.99, - size=21, - ) - elif parameter == "Sample Size": - plt.suptitle( - "Evaluating " - + parameter - + ": " - + inferred - + " " - + node_mut - + " vs. True " - + node_mut - + ". \n Inside-Outside Algorithm and Maximization. \n" - + prior_distr - + " Prior, Length=1Mb, Mut Rate=1e-8, Rec Rate=1e-8", - y=0.99, - size=21, - ) - elif parameter == "Length": - plt.suptitle( - "Evaluating " - + parameter - + ": " - + inferred - + " " - + node_mut - + " vs. True " - + node_mut - + ". \n Inside-Outside Algorithm and Maximization. \n" - + prior_distr - + " Prior, n=100, Mut Rate=1e-8, Rec Rate=1e-8", - y=0.99, - size=21, - ) - elif parameter == "Timepoints": - plt.suptitle( - "Evaluating " - + parameter - + ": " - + inferred - + " " - + node_mut - + " vs. True " - + node_mut - + ". \n Inside-Outside Algorithm and Maximization. 
\n" - + prior_distr - + " Prior, n=100, length=1Mb, Mut Rate=1e-8, Rec Rate=1e-8", - y=0.99, - size=21, - ) - # plt.tight_layout() - plt.savefig( - "evaluation/" - + parameter - + "_" - + inferred - + "_" - + node_mut - + "_" - + prior_distr - + "_accuracy", - dpi=300, - bbox_inches="tight", - ) - - -def run_eval(args): - if args.prior: - all_results = get_prior_results() - make_prior_plot(all_results) - if args.sample_size: - samplesize_inf, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "sample_size", [50, 250, 500], inferred=True, progress=True - ) - plot_tsdate_accuracy( - samplesize_inf, - "Sample Size", - ["50", "250", "500"], - prior_distr, - inferred, - node_mut, - ) - samplesize_inf_node, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "sample_size", [50, 250, 500], inferred=False, node_mut=True, progress=True - ) - plot_tsdate_accuracy( - samplesize_inf_node, - "Sample Size", - ["50", "250", "500"], - prior_distr, - inferred, - node_mut, - ) - samplesize_node, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "sample_size", [50, 250, 500], inferred=False, progress=True - ) - plot_tsdate_accuracy( - samplesize_node, - "Sample Size", - ["50", "250", "500"], - prior_distr, - inferred, - node_mut, - ) - if args.mutation_rate: - mutrate_inf, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "mutation_rate", [1e-09, 1e-08, 1e-07], inferred=True, progress=True - ) - plot_tsdate_accuracy( - mutrate_inf, - "Mut Rate", - ["1e-09", "1e-08", "1e-07"], - prior_distr, - inferred, - node_mut, - ) - mutrate_inf_node, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "mutation_rate", - [1e-09, 1e-08, 1e-07], - inferred=False, - node_mut=True, - progress=True, - ) - plot_tsdate_accuracy( - mutrate_inf_node, - "Mut Rate", - ["1e-09", "1e-08", "1e-07"], - prior_distr, - inferred, - node_mut, - ) - mutrate_node, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "mutation_rate", [1e-09, 1e-08, 1e-07], inferred=False, progress=True - ) - plot_tsdate_accuracy( - mutrate_node, - "Mut Rate", - ["1e-09", "1e-08", "1e-07"], - prior_distr, - inferred, - node_mut, - ) - - if args.length: - length_inf, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "length", [5e4, 5e5, 5e6], inferred=True, progress=True - ) - plot_tsdate_accuracy( - length_inf, - "Length", - ["50000.0", "500000.0", "5000000.0"], - prior_distr, - inferred, - node_mut, - ) - length_inf_node, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "length", [5e4, 5e5, 5e6], inferred=False, node_mut=True, progress=True - ) - plot_tsdate_accuracy( - length_inf_node, - "Length", - ["50000.0", "500000.0", "5000000.0"], - prior_distr, - inferred, - node_mut, - ) - length_node, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "length", [5e4, 5e5, 5e6], inferred=False, progress=True - ) - plot_tsdate_accuracy( - length_node, - "Length", - ["50000.0", "500000.0", "5000000.0"], - prior_distr, - inferred, - node_mut, - ) - - if args.timepoints: - timepoints_inf, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "timepoints", [5, 10, 50], inferred=True, progress=True - ) - plot_tsdate_accuracy( - timepoints_inf, - "Timepoints", - ["5", "10", "50"], - prior_distr, - inferred, - node_mut, - ) - timepoints_inf_node, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "timepoints", [5, 10, 50], inferred=False, node_mut=True, progress=True - ) - plot_tsdate_accuracy( - timepoints_inf_node, - "Timepoints", - ["5", "10", "50"], - 
prior_distr, - inferred, - node_mut, - ) - timepoints_node, prior_distr, inferred, node_mut = evaluate_tsdate_accuracy( - "timepoints", [5, 10, 50], inferred=False, progress=True - ) - plot_tsdate_accuracy( - timepoints_node, - "Timepoints", - ["5", "10", "50"], - prior_distr, - inferred, - node_mut, - ) - - -def main(): - parser = argparse.ArgumentParser(description="Evaluate tsdate.") - parser.add_argument("--prior", action="store_true", help="Evaluate the prior") - parser.add_argument( - "--sample-size", - action="store_true", - help="Evaluate effect of variable sample size", - ) - parser.add_argument( - "--mutation-rate", - action="store_true", - help="Evaluate effect of variable mutation rate", - ) - parser.add_argument( - "--length", action="store_true", help="Evaluate effect of variable length" - ) - parser.add_argument( - "--timepoints", - action="store_true", - help="Evaluate effect of variable numbers of timepoints", - ) - args = parser.parse_args() - run_eval(args) - - -if __name__ == "__main__": - main() diff --git a/requirements.txt b/requirements.txt index 08e557c5..b8f421fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,5 +15,6 @@ pytest-cov mpmath numdifftools setuptools>=45 +matplotlib twine build diff --git a/tests/TODO b/tests/TODO new file mode 100644 index 00000000..a0824b4a --- /dev/null +++ b/tests/TODO @@ -0,0 +1,5 @@ +1. mutation posteriors +2. poisson changepoint algorithm +3. fixed changepoint algorithm +4. interval/edge intersection algorithm +5. metadata posteriors match posterior array diff --git a/tests/test_cli.py b/tests/test_cli.py index c7661adb..9067b9bf 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -248,7 +248,7 @@ def test_verbosity(self, tmp_path, caplog, flag, log_status): ) def test_no_progress(self, method, tmp_path, capfd): input_ts = msprime.simulate(4, random_seed=123) - params = f"-m 0.1 --method {method}" + params = f"-m 0.1 --method {method} --normalisation-intervals 0" self.run_tsdate_cli(tmp_path, input_ts, f"{self.popsize} {params}") (out, err) = capfd.readouterr() assert out == "" @@ -257,7 +257,7 @@ def test_no_progress(self, method, tmp_path, capfd): def test_progress(self, tmp_path, capfd): input_ts = msprime.simulate(4, random_seed=123) - params = "--method inside_outside --progress" + params = "--method inside_outside --progress --normalisation-intervals 0" self.run_tsdate_cli(tmp_path, input_ts, f"{self.popsize} {params}") (out, err) = capfd.readouterr() assert out == "" @@ -268,7 +268,6 @@ def test_progress(self, tmp_path, capfd): "Find Mixture Priors", "Inside", "Outside", - "Constrain Ages", ) for match in desc: assert match in err @@ -277,7 +276,8 @@ def test_progress(self, tmp_path, capfd): def test_iterative_progress(self, tmp_path, capfd): input_ts = msprime.simulate(4, random_seed=123) - params = "--method variational_gamma --mutation-rate 1e-8 --progress" + params = "--method variational_gamma --mutation-rate 1e-8 " + params += "--progress --normalisation-intervals 0" self.run_tsdate_cli(tmp_path, input_ts, f"{self.popsize} {params}") (out, err) = capfd.readouterr() assert out == "" @@ -315,7 +315,7 @@ def ts_equal(self, ts1, ts2, times_equal=False): col_t2 = getattr(t2.nodes, column_name) assert np.array_equal(col_t1, col_t2) for column_name in t1.mutations.column_names: - if column_name not in ["time"]: + if column_name not in ["time", "metadata", "metadata_offset"]: col_t1 = getattr(t1.mutations, column_name) col_t2 = getattr(t2.mutations, column_name) assert np.array_equal(col_t1, col_t2) @@ 
-337,8 +337,12 @@ def verify(self, tmp_path, input_ts, params=None): def compare_python_api(self, tmp_path, input_ts, params, Ne, mutation_rate, method): output_ts = self.run_tsdate_cli(tmp_path, input_ts, params) + popsize = None if method == "variational_gamma" else Ne dated_ts = tsdate.date( - input_ts, population_size=Ne, mutation_rate=mutation_rate, method=method + input_ts, + population_size=popsize, + mutation_rate=mutation_rate, + method=method, ) assert np.array_equal(dated_ts.nodes_time, output_ts.nodes_time) diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index ead30be1..10f7d9a0 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -70,7 +70,7 @@ def _clade_dict(tree): assert ts.sequence_length == other.sequence_length assert ts.num_samples == other.num_samples out = np.zeros((ts.num_nodes, other.num_nodes)) - for (interval, query_tree, target_tree) in ts.coiterate(other): + for interval, query_tree, target_tree in ts.coiterate(other): query = _clade_dict(query_tree) target = _clade_dict(target_tree) span = interval.right - interval.left diff --git a/tests/test_functions.py b/tests/test_functions.py index 9ab14f2c..5e784d96 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -36,10 +36,10 @@ import tsinfer import tskit import utility_functions +from utility_functions import constrain_ages_topo import tsdate from tsdate import base -from tsdate.core import constrain_ages_topo from tsdate.core import date from tsdate.core import DiscreteTimeMethod from tsdate.core import InOutAlgorithms @@ -53,9 +53,9 @@ from tsdate.prior import MixturePrior from tsdate.prior import PriorParams from tsdate.prior import SpansBySamples +from tsdate.util import constrain_ages from tsdate.util import nodes_time_unconstrained from tsdate.util import split_disjoint_nodes -from tsdate.util import split_root_nodes class TestBasicFunctions: @@ -787,23 +787,6 @@ def test_log_tri_functions(self): ) -class TestVariational: - # TODO - needs a few more tests in here - def test_variational_nosize(self): - ts = utility_functions.two_tree_mutation_ts() - with pytest.raises(ValueError, match="Must specify population size"): - tsdate.variational_gamma(ts, mutation_rate=1) - - def test_variational_toomanysizes(self): - ts = utility_functions.two_tree_mutation_ts() - Ne = 1 - priors = tsdate.build_prior_grid(ts, Ne, np.array([0, 1.2, 2])) - with pytest.raises(ValueError, match="Cannot specify"): - tsdate.variational_gamma( - ts, mutation_rate=1, population_size=Ne, priors=priors - ) - - class TestNodeGridValuesClass: def test_init(self): num_nodes = 5 @@ -1639,11 +1622,8 @@ def test_node_metadata_simulated_tree(self): for met in tskit.unpack_bytes(metadata, metadata_offset) if len(met.decode()) > 0 ] - assert np.allclose(unconstrained_mn, mn_post[larger_ts.num_samples :]) - assert np.all( - dated_ts.tables.nodes.time[larger_ts.num_samples :] - >= mn_post[larger_ts.num_samples :] - ) + assert np.allclose(unconstrained_mn, mn_post) + assert np.all(dated_ts.tables.nodes.time >= mn_post) class TestConstrainAgesTopo: @@ -1658,7 +1638,7 @@ def test_constrain_ages_topo(self): ts = utility_functions.two_tree_ts() post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0]) eps = 1e-6 - constrained_ages = constrain_ages_topo(ts, post_mn, eps) + constrained_ages = constrain_ages(ts, post_mn, eps, 0) assert np.array_equal( np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages ) @@ -1673,7 +1653,7 @@ def test_constrain_ages_topo_node_order_bug(self): ts = ts.subset([0, 1, 5, 
3, 4, 2]) # alter the node order post_mn = np.array([3.0, 0.0, 0.0, 0.0, 0.0, 0.0]) eps = 1e-6 - constrained_ages = constrain_ages_topo(ts, post_mn, eps) + constrained_ages = constrain_ages(ts, post_mn, eps, 0) tables = ts.dump_tables() tables.nodes.time = constrained_ages tables.sort() @@ -1682,7 +1662,7 @@ def test_constrain_ages_topo_unary_nodes_unordered(self): ts = utility_functions.single_tree_ts_with_unary() post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 0.5, 5.0, 1.0]) eps = 1e-6 - constrained_ages = constrain_ages_topo(ts, post_mn, eps) + constrained_ages = constrain_ages(ts, post_mn, eps, 0) assert np.allclose( np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 2.000002, 5.0, 5.000001]), constrained_ages, @@ -1692,7 +1672,7 @@ def test_constrain_ages_topo_part_dangling(self): ts = utility_functions.two_tree_ts_n2_part_dangling() post_mn = np.array([1.0, 0.0, 0.0, 0.1, 0.05]) eps = 1e-6 - constrained_ages = constrain_ages_topo(ts, post_mn, eps) + constrained_ages = constrain_ages(ts, post_mn, eps, 0) assert np.allclose( np.array([1.0, 0.0, 0.0, 1.000001, 1.000002]), constrained_ages ) @@ -1701,18 +1681,73 @@ def test_constrain_ages_topo_sample_as_parent(self): ts = utility_functions.single_tree_ts_n3_sample_as_parent() post_mn = np.array([0.0, 0.0, 0.0, 3.0, 1.0]) eps = 1e-6 - constrained_ages = constrain_ages_topo(ts, post_mn, eps) + constrained_ages = constrain_ages(ts, post_mn, eps, 0) assert np.allclose(np.array([0.0, 0.0, 0.0, 3.0, 3.000001]), constrained_ages) def test_two_tree_ts_n3_non_contemporaneous(self): ts = utility_functions.two_tree_ts_n3_non_contemporaneous() post_mn = np.array([0.0, 0.0, 3.0, 4.0, 0.1, 4.1]) eps = 1e-6 - constrained_ages = constrain_ages_topo(ts, post_mn, eps) + constrained_ages = constrain_ages(ts, post_mn, eps, 0) assert np.allclose( np.array([0.0, 0.0, 3.0, 4.0, 4.000001, 4.1]), constrained_ages ) + def test_constrain_ages_backcompat(self): + """ + Test constrain ages (without least squares correction) against + the original implementation + """ + ts = msprime.sim_ancestry( + 10, + population_size=1e4, + recombination_rate=1e-8, + sequence_length=1e6, + random_seed=1, + ) + ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1) + sample_data = tsinfer.SampleData.from_tree_sequence(ts) + inf_ts = tsinfer.infer(sample_data).simplify() + noise = np.random.uniform(0, 0.1, size=inf_ts.num_nodes) + nodes_time = inf_ts.nodes_time + noise + eps = 1e-6 + blen = nodes_time[inf_ts.edges_parent] - nodes_time[inf_ts.edges_child] + assert np.any(blen < 0) + constr_1 = constrain_ages_topo(inf_ts, nodes_time, eps) + constr_2 = constrain_ages(inf_ts, nodes_time, eps, 0) + assert np.allclose(constr_1, constr_2, rtol=eps) + + def test_constrain_ages_leastsquare(self): + """ + Test that constrain ages with least squares correction has positive + branch lengths + """ + ts = msprime.sim_ancestry( + 10, + population_size=1e4, + recombination_rate=1e-8, + sequence_length=1e6, + random_seed=1, + ) + ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1) + sample_data = tsinfer.SampleData.from_tree_sequence(ts) + inf_ts = tsinfer.infer(sample_data).simplify() + noise = np.random.uniform(0, 0.5, size=inf_ts.num_nodes) + nodes_time = inf_ts.nodes_time + noise + eps = 1e-6 + blen = nodes_time[inf_ts.edges_parent] - nodes_time[inf_ts.edges_child] + assert np.any(blen < 0) # ensure negative branches + constr_1 = constrain_ages(inf_ts, nodes_time, eps, 0) # no least squares + blen_1 = constr_1[inf_ts.edges_parent] - constr_1[inf_ts.edges_child] + assert np.all(blen_1 > 0) + constr_2 = 
constrain_ages(inf_ts, nodes_time, eps, 100) + blen_2 = constr_2[inf_ts.edges_parent] - constr_2[inf_ts.edges_child] + assert np.all(blen_2 > 0) + # check that r2 is improved by the least squares step + r2_1 = np.corrcoef(constr_1, nodes_time).flatten()[1] ** 2 + r2_2 = np.corrcoef(constr_2, nodes_time).flatten()[1] ** 2 + assert r2_2 > r2_1 + class TestPreprocessTs(unittest.TestCase): """ @@ -2288,20 +2323,20 @@ def test_split_disjoint_nodes(self): assert split_ts.num_edges == inferred_ts.num_edges assert split_ts.num_nodes > inferred_ts.num_nodes - def test_split_root_nodes(self): - ts = msprime.sim_ancestry( - 10, - population_size=1e4, - recombination_rate=1e-8, - sequence_length=1e6, - random_seed=1, - ) - ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1) - sample_data = tsinfer.SampleData.from_tree_sequence(ts) - inferred_ts = tsinfer.infer(sample_data).simplify() - split_ts = split_root_nodes(inferred_ts) - split_root_nodes(ts) - assert not self.childset_changes_with_root(inferred_ts) - assert self.childset_changes_with_root(split_ts) - assert split_ts.num_edges > inferred_ts.num_edges - assert split_ts.num_nodes > inferred_ts.num_nodes + # def test_split_root_nodes(self): + # ts = msprime.sim_ancestry( + # 10, + # population_size=1e4, + # recombination_rate=1e-8, + # sequence_length=1e6, + # random_seed=1, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1) + # sample_data = tsinfer.SampleData.from_tree_sequence(ts) + # inferred_ts = tsinfer.infer(sample_data).simplify() + # split_ts = split_root_nodes(inferred_ts) + # split_root_nodes(ts) + # assert not self.childset_changes_with_root(inferred_ts) + # assert self.childset_changes_with_root(split_ts) + # assert split_ts.num_edges > inferred_ts.num_edges + # assert split_ts.num_nodes > inferred_ts.num_nodes diff --git a/tests/test_inference.py b/tests/test_inference.py index 04aa79a5..200a43c5 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -33,6 +33,8 @@ from tsdate.base import LIN from tsdate.base import LOG from tsdate.demography import PopulationSizeHistory +from tsdate.evaluation import remove_edges +from tsdate.evaluation import unsupported_edges class TestPrebuilt: @@ -50,11 +52,14 @@ def test_no_population_size(self): with pytest.raises(ValueError, match="Must specify population size"): tsdate.date(ts, mutation_rate=None) - @pytest.mark.parametrize("method", ["maximization", "variational_gamma"]) - def test_no_mutation(self, method): + def test_no_mutation(self): ts = utility_functions.two_tree_mutation_ts() with pytest.raises(ValueError, match="method requires mutation rate"): - tsdate.date(ts, method=method, population_size=1, mutation_rate=None) + tsdate.date( + ts, method="maximization", population_size=1, mutation_rate=None + ) + with pytest.raises(ValueError, match="method requires mutation rate"): + tsdate.date(ts, method="variational_gamma", mutation_rate=None) def test_not_needed_population_size(self): ts = utility_functions.two_tree_mutation_ts() @@ -144,21 +149,18 @@ def test_discretised_posteriors(self): assert np.isclose(np.sum(posteriors[node.id]), 1) def test_variational_posteriors(self): + """ + There are no time-gridded posteriors returned by variational gamma, + so output is None + """ ts = utility_functions.two_tree_mutation_ts() ts, posteriors = tsdate.date( ts, mutation_rate=1e-2, - population_size=1, method="variational_gamma", return_posteriors=True, ) - assert len(posteriors) == ts.num_nodes - ts.num_samples + 1 - assert len(posteriors["parameter"]) == 2 - 
for node in ts.nodes(): - if not node.is_sample(): - assert node.id in posteriors - assert len(posteriors[node.id]) == 2 - assert np.all(posteriors[node.id] > 0) + assert posteriors is None def test_marginal_likelihood(self): ts = utility_functions.two_tree_mutation_ts() @@ -330,11 +332,11 @@ def test_non_contemporaneous(self): with pytest.raises(ValueError, match="noncontemporaneous"): tsdate.date(ts, population_size=1, mutation_rate=2) - def test_no_mutation_times(self): + def test_mutation_times(self): ts = msprime.simulate(20, Ne=1, mutation_rate=1, random_seed=12) assert np.all(ts.tables.mutations.time > 0) dated = tsdate.date(ts, population_size=1, mutation_rate=1) - assert np.all(np.isnan(dated.tables.mutations.time)) + assert np.all(~np.isnan(dated.tables.mutations.time)) @pytest.mark.skip("YAN to fix") def test_truncated_ts(self): @@ -406,84 +408,25 @@ class TestVariational: Tests for tsdate with variational algorithm """ - def test_simple_sim_1_tree(self): - ts = msprime.simulate(8, mutation_rate=5, random_seed=2) - tsdate.date(ts, mutation_rate=5, population_size=1, method="variational_gamma") - - def test_simple_sim_multi_tree(self): - ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) - tsdate.date(ts, mutation_rate=5, population_size=1, method="variational_gamma") - - def test_invalid_priors(self): - ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) - priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma") - grid = priors.make_parameter_grid(population_size=1) - grid.grid_data[:] = [1.0, 0.0] # noninformative prior - with pytest.raises(ValueError, match="Non-positive shape/rate"): - tsdate.date( - ts, - mutation_rate=5, - method="variational_gamma", - priors=grid, - ) - - def test_custom_priors(self): - ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) - priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma") - grid = priors.make_parameter_grid(population_size=1) - grid.grid_data[:] += 1.0 - tsdate.date( - ts, - mutation_rate=5, - method="variational_gamma", - priors=grid, - ) - - def test_prior_mixture_dim(self): - ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) - priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma") - grid = priors.make_parameter_grid(population_size=1) - tsdate.date( - ts, - mutation_rate=5, - method="variational_gamma", - priors=grid, - prior_mixture_dim=2, - ) - - def test_bad_arguments(self): - ts = utility_functions.two_tree_mutation_ts() - with pytest.raises(ValueError, match="Maximum number of EP iterations"): - tsdate.date( - ts, - mutation_rate=5, - population_size=1, - method="variational_gamma", - max_iterations=-1, - ) - with pytest.raises(ValueError, match="must be a nonnegative integer"): - tsdate.date( - ts, - mutation_rate=5, - population_size=1, - method="variational_gamma", - prior_mixture_dim=0.1, - ) - - def test_match_central_moments(self): - ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) - ts0 = tsdate.date( - ts, - mutation_rate=5, - population_size=1, - method="variational_gamma", - match_central_moments=False, - ) - ts1 = tsdate.date( - ts, - mutation_rate=5, - population_size=1, - method="variational_gamma", - match_central_moments=True, - ) - assert np.any(np.not_equal(ts0.nodes_time, ts1.nodes_time)) + ts = msprime.sim_ancestry( + samples=10, + recombination_rate=1e-8, + sequence_length=1e5, + population_size=1e4, + random_seed=2, + ) + ts = 
msprime.sim_mutations( + ts, + rate=1e-8, + ) + + def test_binary(self): + tsdate.date(self.ts, mutation_rate=1e-8, method="variational_gamma") + + def test_polytomy(self): + pts = remove_edges(self.ts, unsupported_edges(self.ts)).simplify() + tsdate.date(pts, mutation_rate=1e-8, method="variational_gamma") + + def test_inferred(self): + its = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(self.ts)).simplify() + tsdate.date(its, mutation_rate=1e-8, method="variational_gamma") diff --git a/tests/test_provenance.py b/tests/test_provenance.py index 45db3b94..e94a8782 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -121,9 +121,14 @@ def test_preprocess_interval_recorded(self): @pytest.mark.parametrize("method", tsdate.core.estimation_methods.keys()) def test_named_methods(self, method): - ts = utility_functions.single_tree_ts_n2() - dated_ts = tsdate.date(ts, method=method, mutation_rate=0.1, population_size=10) - dated_ts2 = getattr(tsdate, method)(ts, mutation_rate=0.1, population_size=10) + ts = utility_functions.single_tree_ts_mutation_n3() + popsize = None if method == "variational_gamma" else 10 + dated_ts = tsdate.date( + ts, method=method, mutation_rate=0.1, population_size=popsize + ) + dated_ts2 = getattr(tsdate, method)( + ts, mutation_rate=0.1, population_size=popsize + ) rec = json.loads(dated_ts.provenance(-1).record) assert rec["parameters"]["command"] == method rec = json.loads(dated_ts2.provenance(-1).record) @@ -131,16 +136,17 @@ def test_named_methods(self, method): @pytest.mark.parametrize("method", tsdate.core.estimation_methods.keys()) def test_identical_methods(self, method): - ts = utility_functions.single_tree_ts_n2() + ts = utility_functions.single_tree_ts_mutation_n3() + popsize = None if method == "variational_gamma" else 10 dated_ts = tsdate.date( ts, method=method, mutation_rate=0.1, - population_size=10, + population_size=popsize, record_provenance=False, ) dated_ts2 = getattr(tsdate, method)( - ts, mutation_rate=0.1, population_size=10, record_provenance=False + ts, mutation_rate=0.1, population_size=popsize, record_provenance=False ) assert dated_ts.num_provenances == ts.num_provenances assert dated_ts == dated_ts2 diff --git a/tests/utility_functions.py b/tests/utility_functions.py index 7df8c1c9..f605583e 100644 --- a/tests/utility_functions.py +++ b/tests/utility_functions.py @@ -23,11 +23,13 @@ A collection of utilities to edit and construct tree sequences for testing purposes """ import io +import itertools import msprime import numpy as np import tskit from scipy.special import comb +from tqdm import tqdm def add_grand_mrca(ts): @@ -1115,3 +1117,34 @@ def tau_var(i, n): def conditional_coalescent_variance(n_tips): """Variance calculation for prior, slow but clear version""" return np.array([tau_var(i, n_tips) for i in range(n_tips + 1)]) + + +def constrain_ages_topo(ts, node_times, epsilon, progress=False): + """ + If node_times violate the topology in ts, return increased node_times so that each + node is guaranteed to be older than any of its children. + + Used to check back-compatibility. + """ + edges_parent = ts.edges_parent + edges_child = ts.edges_child + + new_node_times = np.copy(node_times) + # Traverse through the ARG, ensuring children come before parents. 
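+    # (Edge tables are sorted by parent time, with each parent's edges
+    # contiguous, so by the time a parent's group of edges is reached every
+    # child has already been assigned its final constrained time.)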
+ # This can be done by iterating over groups of edges with the same parent + new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1 + for edges_start, edges_end in tqdm( + zip( + itertools.chain([0], new_parent_edge_idx), + itertools.chain(new_parent_edge_idx, [len(edges_parent)]), + ), + desc="Constrain Ages", + total=len(new_parent_edge_idx) + 1, + disable=not progress, + ): + parent = edges_parent[edges_start] + child_ids = edges_child[edges_start:edges_end] # May contain dups + oldest_child_time = np.max(new_node_times[child_ids]) + if oldest_child_time >= new_node_times[parent]: + new_node_times[parent] = oldest_child_time + epsilon + return new_node_times diff --git a/tsdate/__init__.py b/tsdate/__init__.py index ba79af6f..220c1c27 100644 --- a/tsdate/__init__.py +++ b/tsdate/__init__.py @@ -24,6 +24,7 @@ from .core import inside_outside # NOQA: F401 from .core import maximization # NOQA: F401 from .core import variational_gamma # NOQA: F401 +from .normalisation import normalise_tree_sequence as normalise # NOQA: F401 from .prior import parameter_grid as build_parameter_grid # NOQA: F401 from .prior import prior_grid as build_prior_grid # NOQA: F401 from .provenance import __version__ # NOQA: F401 diff --git a/tsdate/approx.py b/tsdate/approx.py index f94df90c..cb832217 100644 --- a/tsdate/approx.py +++ b/tsdate/approx.py @@ -24,6 +24,8 @@ Tools for approximating combinations of Gamma variates with Gamma distributions """ from math import exp +from math import inf +from math import lgamma from math import log import numba @@ -35,8 +37,8 @@ # TODO: these are reasonable defaults but could # be set via a control dict -_KLMIN_MAXITER = 100 -_KLMIN_TOL = np.sqrt(np.finfo(np.float64).eps) +_KLMIN_MAXITT = 100 +_KLMIN_RELTOL = np.sqrt(np.finfo(np.float64).eps) # shorthand for numba readonly array types, [type][dimension][constness] @@ -101,8 +103,8 @@ def approximate_gamma_kl(x, logx): delta = np.inf # determine convergence when the change in alpha falls below # some small value (e.g. 
square root of machine precision) - while np.abs(delta) > alpha * _KLMIN_TOL: - if itt > _KLMIN_MAXITER: + while np.abs(delta) > np.abs(alpha) * _KLMIN_RELTOL: + if itt > _KLMIN_MAXITT: raise KLMinimizationFailed("Maximum iterations reached in KL minimization") delta = hypergeo._digamma(alpha) - np.log(alpha) + np.log(x) - logx delta /= hypergeo._trigamma(alpha) - 1 / alpha @@ -126,6 +128,44 @@ def approximate_gamma_mom(mean, variance): return shape - 1.0, rate +@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f)) +def approximate_gamma_iqr(q1, q2, x1, x2): + """Find gamma natural parameters that match empirical quantiles""" + if not (q2 > q1 and x2 > x1): + raise KLMinimizationFailed("Quantiles must be sorted") + # find starting value from asymptotic solutions + # if x2 / x1 < log(1 - q2) / log(1 - q1): + # y1 = hypergeo._erf_inv(2 * q1 - 1) * sqrt(2) + # y2 = hypergeo._erf_inv(2 * q2 - 1) * sqrt(2) + # alpha = (y1 * x2 - y2 * x1) ** 2 / (x1 - x2) ** 2 + # else: + alpha = log(q2 / q1) / log(x2 / x1) + # refine with newton iteration + delta = inf + itt = 0 + while abs(delta) > abs(alpha) * _KLMIN_RELTOL: + if itt > _KLMIN_MAXITT: + raise KLMinimizationFailed( + "Maximum iterations reached in quantile matching" + ) + y1 = hypergeo._gammainc_inv(alpha, q1) + y2 = hypergeo._gammainc_inv(alpha, q2) + obj = y2 / y1 - x2 / x1 + inv_1 = -exp(y1 + log(y1) * (1 - alpha) + lgamma(alpha)) + inv_2 = -exp(y2 + log(y2) * (1 - alpha) + lgamma(alpha)) + # print(itt, alpha, y1, y2) #DEBUG + gra_1 = hypergeo._gammainc_der(alpha, y1) * inv_1 + gra_2 = hypergeo._gammainc_der(alpha, y2) * inv_2 + gra = (gra_2 * y1 - gra_1 * y2) / y1**2 + delta = -obj / gra + alpha += delta + itt += 1 + if not alpha > 0: + raise KLMinimizationFailed("Negative shape parameter") + beta = hypergeo._gammainc_inv(alpha, q1) / x1 + return alpha - 1, beta + + @numba.njit(_unituple(_f, 2)(_f1r, _f1r)) def average_gammas(alpha, beta): """ @@ -226,7 +266,7 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): # mn_i = shape / rate # va_i = shape / rate**2 # ln_i = hypergeo._digamma(shape) - log(rate) -# return logl, mn_i, va_i, ln_i +# return logl, mn_i, ln_i, va_i # # a = y_ij + 1 # b = a_i + y_ij + 1 @@ -278,7 +318,7 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): mn_i = s / r va_i = s / r**2 ln_i = hypergeo._digamma(s) - log(r) - return logl, mn_i, va_i, ln_i + return logl, mn_i, ln_i, va_i hyperu = hypergeo._hyperu_laplace f0, d0 = hyperu(a + 0, b + 0, z) @@ -406,8 +446,9 @@ def _hyperu_valid_parameterization(t_j, a_i, b_i, y, mu): """Uses shape / rate parameterization""" a = y + 1 b = a_i + y + 1 - z = t_j * (mu + b_i) - if z <= 0.0: + if t_j < 0.0: + return False + if mu + b_i <= 0.0: return False if not (b > a > 0.0): return False @@ -579,3 +620,214 @@ def truncated_projection(bounds_i, pars_i, min_kl): proj_i = approximate_gamma_mom(t_i, va_t_i) return logconst, np.array(proj_i) + + +# --- mutation posteriors from node posteriors --- # + + +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f)) +def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): + r""" + Calculate gamma sufficient statistics for the PDF proportional to: + + ..math:: + + p(x) = \int_0^\infty \int_0^{t_i} Unif(x | t_i, t_j) + Ga(t_i | a_i, b_i) Ga(t_j | a_j b_j) Po(y | \mu_ij (t_i - t_j)) dt_j dt_i + + which models the time :math:`x` of a mutation uniformly distributed between + parent age :math:`t_i` and child age :math:`t_j`, on a branch with + :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. + + Returns E[x], E[\log x], V[x]. 
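+
+    For a mutation placed uniformly on the branch, the conditional moments
+    are :math:`E[x | t_i, t_j] = (t_i + t_j)/2` and
+    :math:`E[x^2 | t_i, t_j] = (t_i^2 + t_i t_j + t_j^2)/3`; the required
+    expectations of :math:`t_i^2`, :math:`t_i t_j` and :math:`t_j^2` are
+    obtained from `moments` with shifted shape parameters, and
+    :math:`E[\log x]` is approximated by a second-order expansion of the
+    logarithm around the mean.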
+ """ + + f, t_i, _, _, t_j, _, _ = moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) + f_ii, _, _, _, _, _, _ = moments(a_i + 2, b_i, a_j, b_j, y_ij, mu_ij) + f_ij, _, _, _, _, _, _ = moments(a_i + 1, b_i, a_j + 1, b_j, y_ij, mu_ij) + f_jj, _, _, _, _, _, _ = moments(a_i, b_i, a_j + 2, b_j, y_ij, mu_ij) + mn_m = t_i / 2 + t_j / 2 + sq_m = 1 / 3 * (np.exp(f_ii - f) + np.exp(f_ij - f) + np.exp(f_jj - f)) + va_m = sq_m - mn_m**2 + ln_m = log(mn_m) - va_m / 2 / mn_m**2 if mn_m > 0 else -np.inf + + return mn_m, ln_m, va_m + + +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) +def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): + r""" + Calculate gamma sufficient statistics for the PDF proportional to: + + ..math:: + + p(x) = \int_{t_j}^\infty Unif(x | t_i, t_j) + Ga(t_i | a_i, b_i) Po(y | \mu_ij (t_i - t_j)) dt_i + + which models the time :math:`x` of a mutation uniformly distributed between + parent age :math:`t_i` and child age :math:`t_j`, on a branch with + :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. + + Returns E[x], E[\log x], V[x]. + """ + + _, mn_i, _, va_i = rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) + mn_m = mn_i / 2 + t_j / 2 + sq_m = (va_i + mn_i**2 + mn_i * t_j + t_j**2) / 3 + va_m = sq_m - mn_m**2 + ln_m = log(mn_m) - va_m / 2 / mn_m**2 if mn_m > 0 else -np.inf + + return mn_m, ln_m, va_m + + +@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) +def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): + r""" + Calculate gamma sufficient statistics for the PDF proportional to: + + ..math:: + + p(x) = \int_0^{t_i} Unif(x | t_i, t_j) + Ga(t_j | a_j, b_j) Po(y | \mu_ij (t_i - t_j)) dt_j + + which models the time :math:`x` of a mutation uniformly distributed between + parent age :math:`t_i` and child age :math:`t_j`, on a branch with + :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. + + Returns E[x], E[\log x], V[x]. + """ + + _, mn_j, _, va_j = leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) + mn_m = mn_j / 2 + t_i / 2 + sq_m = (va_j + mn_j**2 + mn_j * t_i + t_i**2) / 3 + va_m = sq_m - mn_m**2 + ln_m = log(mn_m) - va_m / 2 / mn_m**2 if mn_m > 0 else -np.inf + + return mn_m, ln_m, va_m + + +@numba.njit(_f1r(_f1r, _f1r, _f1r, _b)) +def mutation_gamma_projection(pars_i, pars_j, pars_ij, min_kl): + r""" + Match a gamma distribution via KL minimization to the potential function + + ..math:: + + p(x) = \int_0^\infty \int_0^{t_i} Unif(x | t_i, t_j) + Ga(t_i | a_i, b_i) Ga(t_j | a_j b_j) Po(y | \mu_ij (t_i - t_j)) dt_j dt_i + + which models the time :math:`x` of a mutation uniformly distributed between + parent age :math:`t_i` and child age :math:`t_j`, on a branch with + :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. 
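+
+    If the branch parameters do not admit a valid parameterization, or the
+    resulting moments fail the validity check, NaN parameters are returned
+    and the corresponding message is skipped.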
+ + TODO: params + + :return: gamma parameters for mutation age + """ + + # switch from natural to canonical parameterization + a_i, b_i = pars_i + a_j, b_j = pars_j + y_ij, mu_ij = pars_ij + a_i += 1 + a_j += 1 + + if not _hyp2f1_valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij): + return np.full(2, np.nan) + + t_m, ln_t_m, va_t_m = mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij) + + valid = _valid_moments(t_m, ln_t_m, va_t_m) + if not valid: + return np.full(2, np.nan) + + if min_kl: + proj_m = approximate_gamma_kl(t_m, ln_t_m) + else: + proj_m = approximate_gamma_mom(t_m, va_t_m) + + return np.array(proj_m) + + +@numba.njit(_f1r(_f, _f1r, _f1r, _b)) +def mutation_leafward_projection(t_i, pars_j, pars_ij, min_kl): + r""" + Match a gamma distribution via KL minimization to the potential function + + ..math:: + + p(x) = \int_0^{t_i} Unif(x | t_i, t_j) + Ga(t_j | a_j, b_j) Po(y | \mu_ij (t_i - t_j)) dt_j + + which models the time :math:`x` of a mutation uniformly distributed between + parent age :math:`t_i` and child age :math:`t_j`, on a branch with + :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. + + TODO + + :return: gamma parameters for mutation age + """ + + # switch from natural to canonical parameterization + a_j, b_j = pars_j + y_ij, mu_ij = pars_ij + a_j += 1 + + # skip update, zeroing out message + if not _hyp1f1_valid_parameterization(t_i, a_j, b_j, y_ij, mu_ij): + return np.full(2, np.nan) + + t_m, ln_t_m, va_t_m = mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij) + + valid = _valid_moments(t_m, ln_t_m, va_t_m) + if not valid: + return np.full(2, np.nan) + + if min_kl: + proj_m = approximate_gamma_kl(t_m, ln_t_m) + else: + proj_m = approximate_gamma_mom(t_m, va_t_m) + + return np.array(proj_m) + + +@numba.njit(_f1r(_f, _f1r, _f1r, _b)) +def mutation_rootward_projection(t_j, pars_i, pars_ij, min_kl): + r""" + Match a gamma distribution via KL minimization to the potential function + + ..math:: + + p(x) = \int_{t_j}^{\infty} Unif(x | t_i, t_j) + Ga(t_i | a_i, b_i) Po(y | \mu_ij (t_i - t_j)) dt_i + + which models the time :math:`x` of a mutation uniformly distributed between + parent age :math:`t_i` and child age :math:`t_j`, on a branch with + :math:`y_{ij}` mutations and total mutation rate :math:`\mu_{ij}`. + + TODO + + :return: gamma parameters for mutation age + """ + + # switch from natural to canonical parameterization + a_i, b_i = pars_i + y_ij, mu_ij = pars_ij + a_i += 1 + + # skip update, zeroing out message + if not _hyperu_valid_parameterization(t_j, a_i, b_i, y_ij, mu_ij): + return np.full(2, np.nan) + + t_m, ln_t_m, va_t_m = mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij) + + valid = _valid_moments(t_m, ln_t_m, va_t_m) + if not valid: + return np.full(2, np.nan) + + if min_kl: + proj_m = approximate_gamma_kl(t_m, ln_t_m) + else: + proj_m = approximate_gamma_mom(t_m, va_t_m) + + return np.array(proj_m) diff --git a/tsdate/cli.py b/tsdate/cli.py index 6248f358..24bc4429 100644 --- a/tsdate/cli.py +++ b/tsdate/cli.py @@ -91,6 +91,7 @@ def tsdate_cli_parser(): parser.add_argument( "population_size", type=float, + nargs="?", help="Estimated effective (diploid) population size.", ) parser.add_argument( @@ -191,42 +192,23 @@ def tsdate_cli_parser(): default=1000, ) parser.add_argument( - "--match-central-moments", - action="store_true", + "--normalisation-intervals", + type=float, help=( - "Match mean and variance rather than gamma sufficient statistics in " - 'the "variational_gamma" algorithm. 
Faster with similar accuracy, ' - "but does not exactly minimize KL divergence in each EP update." + "The number of time intervals within which to estimate a time " + "scaling parameter. Default: 1000" ), + default=1000, ) parser.add_argument( "--max-iterations", type=int, help=( "The number of iterations used in the expectation propagation " - "algorithm. Default: 20" - ), - default=20, - ) - parser.add_argument( - "--em-iterations", - type=int, - help=( - "The number of expectation-maximization iterations used to optimize the " - "i.i.d. mixture prior at the end of each expectation propagation step. " - "Setting to zero disables optimization. Default: 10" + "algorithm. Default: 10" ), default=10, ) - parser.add_argument( - "--prior-mixture-dim", - type=int, - help=( - "The number of components in the i.i.d. mixture prior for node " - "ages. Default: 1" - ), - default=1, - ) parser.set_defaults(runner=run_date) parser = subparsers.add_parser( @@ -283,12 +265,11 @@ def run_date(args): progress=args.progress, max_iterations=args.max_iterations, max_shape=args.max_shape, - match_central_moments=args.match_central_moments, - em_iterations=args.em_iterations, - prior_mixture_dim=args.prior_mixture_dim, + normalisation_intervals=args.normalisation_intervals, ) else: params = dict( + population_size=args.population_size, recombination_rate=args.recombination_rate, method=args.method, eps=args.epsilon, @@ -300,7 +281,7 @@ def run_date(args): params["ignore_oldest_root"] = args.ignore_oldest # For backwards compat # TODO: remove and error out if ignore_oldest_root is set, # see https://github.com/tskit-dev/tsdate/issues/262 - dated_ts = tsdate.date(ts, args.mutation_rate, args.population_size, **params) + dated_ts = tsdate.date(ts, args.mutation_rate, **params) dated_ts.dump(args.output) diff --git a/tsdate/core.py b/tsdate/core.py index 880b10b2..a76c0e70 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -29,6 +29,7 @@ import logging import multiprocessing import operator +import time # DEBUG from collections import defaultdict from collections import namedtuple @@ -40,7 +41,6 @@ from . import base from . import demography -from . import mixture from . import prior from . import provenance from . import util @@ -101,9 +101,9 @@ def __init__( # The mut_ll contains unpacked (1D) lower triangular matrices. We need to # index this by row and by column index. self.row_indices = [] - for time in range(self.grid_size): + for t in range(self.grid_size): n = np.arange(self.grid_size) - self.row_indices.append((((n * (n + 1)) // 2) + time)[time:]) + self.row_indices.append((((n * (n + 1)) // 2) + t)[t:]) self.col_indices = [] running_sum = 0 # use this to find the index of the last element of # each column in order to appropriately sum the vv by columns. @@ -869,37 +869,17 @@ def outside_maximization(self, *, eps, progress=None): return self.lik.timepoints[np.array(maximized_node_times).astype("int")] -def constrain_ages_topo(ts, node_times, epsilon, progress=False): - # If node_times violate the topology in ts, return increased node_times so that each - # node is guaranteed to be older than any of its children. - edges_parent = ts.edges_parent - edges_child = ts.edges_child - - new_node_times = np.copy(node_times) - # Traverse through the ARG, ensuring children come before parents. 
- # This can be done by iterating over groups of edges with the same parent - new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1 - for edges_start, edges_end in tqdm( - zip( - itertools.chain([0], new_parent_edge_idx), - itertools.chain(new_parent_edge_idx, [len(edges_parent)]), - ), - desc="Constrain Ages", - total=len(new_parent_edge_idx) + 1, - disable=not progress, - ): - parent = edges_parent[edges_start] - child_ids = edges_child[edges_start:edges_end] # May contain dups - oldest_child_time = np.max(new_node_times[child_ids]) - if oldest_child_time >= new_node_times[parent]: - new_node_times[parent] = oldest_child_time + epsilon - return new_node_times - - # Classes for each method Results = namedtuple( "Results", - ["posterior_mean", "posterior_var", "posterior_obj", "mutation_likelihood"], + [ + "posterior_mean", + "posterior_var", + "posterior_obj", + "mutation_mean", + "mutation_var", + "mutation_likelihood", + ], ) @@ -929,6 +909,7 @@ def __init__( return_posteriors=None, return_likelihood=None, record_provenance=None, + constr_iterations=None, progress=None, ): # Set up all the generic params describe in the tsdate.date function, and define @@ -942,6 +923,7 @@ def __init__( self.time_units = "generations" if time_units is None else time_units if record_provenance is None: record_provenance = True + Ne = population_size # shorthand if isinstance(Ne, dict): Ne = demography.PopulationSizeHistory(**Ne) @@ -957,47 +939,104 @@ def __init__( population_size=Ne.as_dict() if hasattr(Ne, "as_dict") else Ne, ) - if priors is None: - if Ne is None: + if constr_iterations is None: + self.constr_iterations = 0 + else: + if not (isinstance(constr_iterations, int) and constr_iterations >= 0): raise ValueError( - "Must specify population size if priors are not already built using" - f"tsdate.build_{self.prior_grid_func_name}()" + "Number of constrained least squares iterations must be a " + "non-negative integer" ) - mk_prior = getattr(prior, self.prior_grid_func_name) - # Default to not creating approximate priors unless ts has - # greater than DEFAULT_APPROX_PRIOR_SIZE samples - approx = True if ts.num_samples > base.DEFAULT_APPROX_PRIOR_SIZE else False - self.priors = mk_prior(ts, Ne, approximate_priors=approx, progress=progress) - else: - logging.info("Using user-specified priors") + self.constr_iterations = constr_iterations + + if self.prior_grid_func_name is None: + if priors is not None: + raise ValueError(f"Priors are not used for method {self.name}") if Ne is not None: - raise ValueError( - "Cannot specify population size if specifying priors " - f"from tsdate.build_{self.prior_grid_func_name}()" + raise ValueError(f"Population size is not used for method {self.name}") + else: + if priors is None: + if Ne is None: + raise ValueError( + "Must specify population size if priors are not already " + f"built using tsdate.build_{self.prior_grid_func_name}()" + ) + mk_prior = getattr(prior, self.prior_grid_func_name) + # Default to not creating approximate priors unless ts has + # greater than DEFAULT_APPROX_PRIOR_SIZE samples + approx = ( + True if ts.num_samples > base.DEFAULT_APPROX_PRIOR_SIZE else False + ) + self.priors = mk_prior( + ts, Ne, approximate_priors=approx, progress=progress ) - self.priors = priors + else: + logging.info("Using user-specified priors") + if Ne is not None: + raise ValueError( + "Cannot specify population size if specifying priors " + f"from tsdate.build_{self.prior_grid_func_name}()" + ) + self.priors = priors + + # mutation to edge mapping + 
mutspan_timing = time.time() + self.edges_mutations, self.mutations_edge = util.mutation_span_array(ts) + mutspan_timing -= time.time() + logging.info(f"Extracted mutations in {abs(mutspan_timing)} seconds") def get_modified_ts(self, result, eps): # Return a new ts based on the existing one, but with the various # time-related information correctly set. tables = self.ts.dump_tables() nodes = tables.nodes + mutations = tables.mutations if self.provenance_params is not None: provenance.record_provenance(tables, self.name, **self.provenance_params) - nodes.time = constrain_ages_topo(self.ts, result.posterior_mean, eps, self.pbar) + # Constrain node ages for positive branch lengths + constr_timing = time.time() + nodes.time = util.constrain_ages( + self.ts, result.posterior_mean, eps, self.constr_iterations + ) + mutations.time = util.constrain_mutations( + self.ts, nodes.time, self.mutations_edge + ) tables.time_units = self.time_units - tables.mutations.time = np.full(self.ts.num_mutations, tskit.UNKNOWN_TIME) - # Add posterior mean and variance to node metadata - if result.posterior_obj is not None: + constr_timing -= time.time() + logging.info(f"Constrained node ages in {abs(constr_timing)} seconds") + # Add posterior mean and variance to node/mutation metadata + # TODO: retain original metadata? + meta_timing = time.time() + if result.posterior_var is not None: metadata_array = tskit.unpack_bytes(nodes.metadata, nodes.metadata_offset) - for u in result.posterior_obj.nonfixed_nodes: + for u, (mn, vr) in enumerate( + zip(result.posterior_mean, result.posterior_var) + ): metadata_array[u] = json.dumps( { - "mn": result.posterior_mean[u], - "vr": result.posterior_var[u], + "mn": mn, + "vr": vr, } ).encode() nodes.packset_metadata(metadata_array) + if result.mutation_var is not None: + metadata_array = tskit.unpack_bytes( + mutations.metadata, mutations.metadata_offset + ) + for u, (mn, vr) in enumerate( + zip(result.mutation_mean, result.mutation_var) + ): + metadata_array[u] = json.dumps( + { + "mn": mn, + "vr": vr, + } + ).encode() + mutations.packset_metadata(metadata_array) + meta_timing -= time.time() + logging.info( + f"Inserted node and mutation metadata in {abs(meta_timing)} seconds" + ) tables.sort() return tables.tree_sequence() @@ -1105,7 +1144,9 @@ def run( posterior_obj.to_probabilities() posterior_mean, posterior_var = self.mean_var(self.ts, posterior_obj) - return Results(posterior_mean, posterior_var, posterior_obj, marginal_likl) + return Results( + posterior_mean, posterior_var, posterior_obj, None, None, marginal_likl + ) class MaximizationMethod(DiscreteTimeMethod): @@ -1132,26 +1173,18 @@ def run( dynamic_prog = self.main_algorithm(probability_space, eps, num_threads) marginal_likl = dynamic_prog.inside_pass(cache_inside=cache_inside) posterior_mean = dynamic_prog.outside_maximization(eps=eps) - return Results(posterior_mean, None, None, marginal_likl) + return Results(posterior_mean, None, None, None, None, marginal_likl) class VariationalGammaMethod(EstimationMethod): - prior_grid_func_name = "parameter_grid" + prior_grid_func_name = None name = "variational_gamma" def __init__(self, ts, **kwargs): super().__init__(ts, **kwargs) - # convert priors to natural parameterization - for n in self.priors.nonfixed_nodes: - if not np.all(self.priors[n] > 0.0): - raise ValueError( - f"Non-positive shape/rate parameters for node {n}: " - f"{self.priors[n]}" - ) - self.priors[n][0] -= 1.0 @staticmethod - def mean_var(ts, posterior): + def mean_var(posteriors, constraints): """ Mean 
and variance of node age from variational posterior (e.g. gamma distributions). Fixed nodes will be given a mean of their exact time in @@ -1159,28 +1192,27 @@ def mean_var(ts, posterior): fixed_node_set). This is a static method for ease of testing. """ - assert posterior.grid_data.shape[1] == 2 - assert np.all(posterior.grid_data > 0) + mn_post = np.full( + posteriors.shape[0], np.nan + ) # Fill with NaNs so we detect when + va_post = np.full(posteriors.shape[0], np.nan) # there's been an error - mn_post = np.full(ts.num_nodes, np.nan) # Fill with NaNs so we detect when - va_post = np.full(ts.num_nodes, np.nan) # there's been an error + fixed = constraints[:, 0] == constraints[:, 1] + mn_post[fixed] = constraints[fixed, 0] + va_post[fixed] = 0 - is_fixed = np.ones(posterior.num_nodes, dtype=bool) - is_fixed[posterior.nonfixed_nodes] = False - mn_post[is_fixed] = ts.nodes_time[is_fixed] - va_post[is_fixed] = 0 + for i in np.flatnonzero(~fixed): + pars = posteriors[i] + mn_post[i] = (pars[0] + 1) / pars[1] + va_post[i] = (pars[0] + 1) / pars[1] ** 2 - for node in posterior.nonfixed_nodes: - pars = posterior[node] - mn_post[node] = pars[0] / pars[1] - va_post[node] = pars[0] / pars[1] ** 2 return mn_post, va_post def main_algorithm(self): # edge likelihoods # TODO: variable mutation rates across genome # TODO: truncate edge spans with accessiblity mask - likelihoods = util.mutation_span_array(self.ts) + likelihoods = self.edges_mutations.copy() likelihoods[:, 1] *= self.mutation_rate # lower and upper bounds on node ages @@ -1190,7 +1222,7 @@ def main_algorithm(self): constraints[sample_idx, :] = self.ts.nodes_time[sample_idx, np.newaxis] return variational.ExpectationPropagation( - self.ts, likelihoods, constraints, self.global_prior + self.ts, likelihoods, constraints, self.mutations_edge ) def run( @@ -1199,8 +1231,9 @@ def run( max_iterations, max_shape, match_central_moments, - prior_mixture_dim, - em_iterations, + normalisation_intervals, + match_segregating_sites, + regularise_roots, ): if self.provenance_params is not None: self.provenance_params.update( @@ -1208,45 +1241,39 @@ def run( ) if not max_iterations >= 1: raise ValueError("Maximum number of EP iterations must be greater than 0") - if not (isinstance(prior_mixture_dim, int) and prior_mixture_dim >= 0): - raise ValueError( - "Number of mixture components must be a nonnegative integer" - ) if self.mutation_rate is None: raise ValueError("Variational gamma method requires mutation rate") - # initialize weights/shapes/rates for i.i.d mixture prior - # note that self.priors (node-specific priors) are not currently - # used except for initialization of the mixture - self.global_prior = mixture.initialize_mixture( - self.priors.grid_data, prior_mixture_dim - ) - self.priors.grid_data[:] = [0.0, 0.0] # TODO: support node-specific priors - # match sufficient statistics or match central moments min_kl = not match_central_moments dynamic_prog = self.main_algorithm() dynamic_prog.run( ep_maxitt=max_iterations, - em_maxitt=em_iterations, max_shape=max_shape, min_kl=min_kl, + norm_intervals=normalisation_intervals, + regularise=regularise_roots, + norm_segsites=match_segregating_sites, progress=self.pbar, ) - num_skipped = np.sum(np.isnan(dynamic_prog.log_partition)) - if num_skipped > 0: - logging.info( - f"Skipped {num_skipped} messages with invalid posterior updates." 
- ) - - posterior_obj = self.priors.clone_with_new_data( - grid_data=dynamic_prog.posterior[self.priors.nonfixed_nodes, :] + # TODO: use dynamic_prog.point_estimate + posterior_mean, posterior_vari = self.mean_var( + dynamic_prog.posterior, dynamic_prog.constraints ) - posterior_obj.grid_data[:, 0] += 1 # to shape/rate parameterization - posterior_mean, posterior_var = self.mean_var(self.ts, posterior_obj) + + # TODO: clean up + mutation_post = dynamic_prog.mutations_posterior + mutation_mean = np.full(mutation_post.shape[0], np.nan) + mutation_vari = np.full(mutation_post.shape[0], np.nan) + idx = mutation_post[:, 1] > 0 + mutation_mean[idx] = (mutation_post[idx, 0] + 1) / mutation_post[idx, 1] + mutation_vari[idx] = (mutation_post[idx, 0] + 1) / mutation_post[idx, 1] ** 2 + # TODO: return marginal likelihood - return Results(posterior_mean, posterior_var, posterior_obj, None) + return Results( + posterior_mean, posterior_vari, None, mutation_mean, mutation_vari, None + ) def maximization( @@ -1449,9 +1476,10 @@ def variational_gamma( eps=None, max_iterations=None, max_shape=None, - match_central_moments=None, - prior_mixture_dim=None, - em_iterations=None, + normalisation_intervals=None, + match_central_moments=None, # undocumented + match_segregating_sites=None, # undocumented + regularise_roots=None, # undocumented **kwargs, ): """ @@ -1463,41 +1491,23 @@ def variational_gamma( .. code-block:: python - new_ts = tsdate.variational_gamma( - ts, mutation_rate=1e-8, population_size=1e4, max_iterations=10) + new_ts = tsdate.variational_gamma(ts, mutation_rate=1e-8, max_iterations=10) - An i.i.d. gamma mixture is used as a prior for each node, that is - initialized from the conditional coalescent and updated via expectation - maximization in each iteration. If node-specific priors are supplied - (via a grid of shape/rate parameters) then these are used for - initialization. - - .. note:: - The prior parameters for each node-to-be-dated take the form of a - gamma-distributed prior on node age, parameterised by shape and rate. - If the ``priors`` parameter is used, it must specify an object constructed - using :func:`build_parameter_grid`. If not used, ``population_size`` must be - provided, which is used to create an iid prior derived from the conditional - coalescent prior (tilted according to population size), assuming the nodes - to be dated are all the non-sample nodes in the input tree sequence. + An piecewise-constant uniform distribution is used as a prior for each + node, that is updated via expectation maximization in each iteration. + Node-specific priors are not currently supported. :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated. :param float eps: The minimum distance separating parent and child ages in the returned tree sequence. Default: None, treated as 1e-6 :param int max_iterations: The number of iterations used in the expectation - propagation algorithm. Default: None, treated as 20. + propagation algorithm. Default: None, treated as 10. :param float max_shape: The maximum value for the shape parameter in the variational posteriors. This is equivalent to the maximum precision (inverse variance) on a logarithmic scale. Default: None, treated as 1000. - :param bool match_central_moments: If `True`, each expectation propgation - update matches mean and variance rather than expected gamma sufficient - statistics. Faster with a similar accuracy, but does not exactly minimize - Kullback-Leibler divergence. Default: None, treated as False. 
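For intuition, the distinction made here between matching central moments and matching gamma sufficient statistics can be sketched with plain SciPy (an illustrative aside; these helper names are not part of tsdate):

```python
import numpy as np
from scipy.optimize import brentq
from scipy.special import digamma


def gamma_from_central_moments(mean, var):
    # Match E[t] and Var[t]; cheap, but not the KL-optimal gamma projection.
    shape = mean**2 / var
    rate = mean / var
    return shape, rate


def gamma_from_sufficient_stats(mean, logmean):
    # Match E[t] and E[log t], which minimises KL divergence to the target.
    # gap = E[log t] - log E[t] is negative (Jensen) for any non-degenerate
    # target, and digamma(a) - log(a) is increasing in a, so the root is unique.
    gap = logmean - np.log(mean)
    shape = brentq(lambda a: digamma(a) - np.log(a) - gap, 1e-8, 1e8)
    rate = shape / mean
    return shape, rate
```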
- :param int prior_mixture_dim: The number of components in the i.i.d. mixture prior - for node ages. Default: None, treated as 1. - :param int em_iterations: The number of expectation maximization iterations used - to optimize the i.i.d. mixture prior. Setting to zero disables optimization. - Default: None, treated as 10. + :param float normalisation_intervals: For normalisation, the number of time + intervals within which to estimate a rescaling parameter. Default None, + treated as 1000. :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper function, notably ``mutation_rate``, and ``population_size`` or ``priors``. Further arguments include ``time_units``, ``progress``, and @@ -1524,15 +1534,17 @@ def variational_gamma( if eps is None: eps = 1e-6 if max_iterations is None: - max_iterations = 20 + max_iterations = 10 if max_shape is None: max_shape = 1000 + if normalisation_intervals is None: + normalisation_intervals = 1000 if match_central_moments is None: - match_central_moments = False - if prior_mixture_dim is None: - prior_mixture_dim = 1 - if em_iterations is None: - em_iterations = 10 + match_central_moments = True + if match_segregating_sites is None: + match_segregating_sites = False + if regularise_roots is None: + regularise_roots = True dating_method = VariationalGammaMethod(tree_sequence, **kwargs) result = dating_method.run( @@ -1540,8 +1552,9 @@ def variational_gamma( max_iterations=max_iterations, max_shape=max_shape, match_central_moments=match_central_moments, - prior_mixture_dim=prior_mixture_dim, - em_iterations=em_iterations, + normalisation_intervals=normalisation_intervals, + match_segregating_sites=match_segregating_sites, + regularise_roots=regularise_roots, ) return dating_method.parse_result(result, eps, {"parameter": ["shape", "rate"]}) @@ -1572,6 +1585,7 @@ def date( priors=None, method=None, *, + constr_iterations=None, return_posteriors=None, return_likelihood=None, progress=None, @@ -1640,6 +1654,9 @@ def date( :param bool return_posteriors: If ``True``, instead of returning just a dated tree sequence, return a tuple of ``(dated_ts, posteriors)``. Default: None, treated as False. + :param int constr_iterations: The maximum number of constrained least + squares iterations to use prior to forcing positive branch lengths. + Default: None, treated as 0. :param bool return_likelihood: If ``True``, return the log marginal likelihood from the inside algorithm in addition to the dated tree sequence. 
If ``return_posteriors`` is also ``True``, then the marginal likelihood @@ -1679,6 +1696,7 @@ def date( time_units=time_units, priors=priors, progress=progress, + constr_iterations=constr_iterations, return_posteriors=return_posteriors, return_likelihood=return_likelihood, record_provenance=record_provenance, diff --git a/tsdate/evaluation.py b/tsdate/evaluation.py index 896f1708..7ed44784 100644 --- a/tsdate/evaluation.py +++ b/tsdate/evaluation.py @@ -23,9 +23,13 @@ Tools for comparing node times between tree sequences with different node sets """ import copy +import json from collections import defaultdict +from itertools import groupby from itertools import product +from math import isqrt +import matplotlib.pyplot as plt import numpy as np import scipy.sparse import tskit @@ -202,7 +206,7 @@ def shared_node_spans(ts, other): for clade_map in (query, target): if clade_map.interval[1] == right: clade_diff = clade_map.next() - for (prev, curr) in clade_diff.values(): + for prev, curr in clade_diff.values(): if prev != nil: modified.add(prev) if curr != nil: @@ -252,3 +256,435 @@ def match_node_ages(ts, other): matched_time[matched_span == 0] = np.nan return matched_time, matched_span, best_match + + +# --- infrastructure for testing against polytomies --- # + + +def remove_edges(ts, edge_id_remove_list): + edges_to_remove_by_child = defaultdict(list) + edge_id_remove_list = set(edge_id_remove_list) + for m in ts.mutations(): + if m.edge in edge_id_remove_list: + # If we remove this edge, we will remove the associated mutation + # as the child node won't have ancestral material in this region. + # So we force the user to explicitly (re)move the mutations beforehand + raise ValueError("Cannot remove edges that have associated mutations") + for remove_edge in edge_id_remove_list: + e = ts.edge(remove_edge) + edges_to_remove_by_child[e.child].append(e) + + # sort left-to-right for each child + for k, v in edges_to_remove_by_child.items(): + edges_to_remove_by_child[k] = sorted(v, key=lambda e: e.left) + # check no overlaps + for e1, e2 in zip(edges_to_remove_by_child[k], edges_to_remove_by_child[k][1:]): + assert e1.right <= e2.left + + # Sanity check: this means the topmost node will deal with modified edges + # left at the end + assert ts.edge(-1).parent not in edges_to_remove_by_child + + new_edges = defaultdict(list) + tables = ts.dump_tables() + tables.edges.clear() + # Edges are sorted by parent time, youngest first, so we can iterate over + # nodes-as-parents visiting children before parents by using itertools.groupby + for parent_id, ts_edges in groupby(ts.edges(), lambda e: e.parent): + # Iterate through the ts edges *plus* the polytomy edges we created in + # previous steps. This allows us to re-edit polytomy edges when the + # edges_to_remove are stacked + edges = list(ts_edges) + if parent_id in new_edges: + edges += new_edges.pop(parent_id) + if parent_id in edges_to_remove_by_child: + for e in edges: + assert parent_id == e.parent + left = -1 + if e.id in edge_id_remove_list: + continue + # NB: we go left to right along the target edges, reducing edge + # e as required + for target_edge in edges_to_remove_by_child[parent_id]: + # As we go along the target_edges, gradually split e into + # chunks. 
If edge e is in the target_edge region, change + # the edge parent + assert target_edge.left > left + left = target_edge.left + if e.left >= target_edge.right: + # This target edge is entirely to the LHS of edge e, + # with no overlap + continue + elif e.right <= target_edge.left: + # This target edge is entirely to the RHS of edge e + # with no overlap. Since target edges are sorted by + # left coord, all other target edges are to RHS too, + # and we are finished dealing with edge e + tables.edges.append(e) + e = None + break + else: + # Edge e must overlap with current target edge somehow + if e.left < target_edge.left: + # Edge had region to LHS of target + # Add the left hand section (change the edge right coord) + tables.edges.add_row( + left=e.left, + right=target_edge.left, + parent=e.parent, + child=e.child, + ) + e = e.replace(left=target_edge.left) + if e.right > target_edge.right: + # Edge continues after RHS of target + assert e.left < target_edge.right + new_edges[target_edge.parent].append( + e.replace( + right=target_edge.right, parent=target_edge.parent + ) + ) + e = e.replace(left=target_edge.right) + else: + # No more of edge e to RHS + assert e.left < e.right + new_edges[target_edge.parent].append( + e.replace(parent=target_edge.parent) + ) + e = None + break + if e is not None: + # Need to add any remaining regions of edge back in + tables.edges.append(e) + else: + # NB: sanity check at top means that the oldest node will have no + # edges above, so the last iteration should hit this branch + for e in edges: + if e.id not in edge_id_remove_list: + tables.edges.append(e) + assert len(new_edges) == 0 + tables.sort() + return tables.tree_sequence() + + +def unsupported_edges(ts, per_interval=False): + """ + Return the internal edges that are unsupported by a mutation. + If ``per_interval`` is True, each interval needs to be supported, + otherwise, a mutation on an edge (even if there are multiple intervals + per edge) will result in all intervals on that edge being treated + as supported. 
+ """ + edges_to_remove = np.ones(ts.num_edges, dtype="bool") + edges_to_remove[[m.edge for m in ts.mutations()]] = False + # We don't remove edges above samples + edges_to_remove[np.isin(ts.edges_child, ts.samples())] = False + + if per_interval: + return np.where(edges_to_remove)[0] + else: + keep = ~edges_to_remove + for p, c in zip(ts.edges_parent[keep], ts.edges_child[keep]): + edges_to_remove[ + np.logical_and(ts.edges_parent == p, ts.edges_child == c) + ] = False + return np.where(edges_to_remove)[0] + + +# --- first drafts of diagnostic plots --- # + + +def node_coverage(ts, inferred_ts, alpha): + assert np.all(np.logical_and(1 > alpha, alpha > 0)) + posteriors = np.zeros((inferred_ts.num_nodes, 2)) + for n in inferred_ts.nodes(): + mn = json.loads(n.metadata or '{"mn":0}')["mn"] + vr = json.loads(n.metadata or '{"vr":0}')["vr"] + posteriors[n.id] = [mn**2 / vr, mn / vr] if vr > 0 else np.nan + positions = {p: i for i, p in enumerate(ts.sites_position)} + true_child = np.full(ts.sites_position.size, tskit.NULL) + infr_child = np.full(ts.sites_position.size, tskit.NULL) + for s in ts.sites(): + if len(s.mutations) == 1: + sid = positions[s.position] + true_child[sid] = s.mutations[0].node + for s in inferred_ts.sites(): + if len(s.mutations) == 1: + sid = positions[s.position] + nid = s.mutations[0].node + if not np.isnan(posteriors[nid, 0]): + infr_child[s.id] = s.mutations[0].node + missing = np.logical_or(true_child == tskit.NULL, infr_child == tskit.NULL) + infr_child = infr_child[~missing] + true_child = true_child[~missing] + post = posteriors[infr_child] + upper = np.zeros((post.shape[0], alpha.size)) + lower = np.zeros((post.shape[0], alpha.size)) + for i in range(post.shape[0]): + shape, rate = post[i, 0], post[i, 1] + if shape <= 1: + upper[i] = scipy.stats.gamma.ppf(1 - alpha, shape, scale=1 / rate) + lower[i] = 0.0 + else: + upper[i] = scipy.stats.gamma.ppf(1 - alpha / 2, shape, scale=1 / rate) + lower[i] = scipy.stats.gamma.ppf(alpha / 2, shape, scale=1 / rate) + true = ts.nodes_time[true_child] + is_covered = np.logical_and( + true[:, np.newaxis] < upper, true[:, np.newaxis] > lower + ) + prop_covered = np.sum(is_covered, axis=0) / is_covered.shape[0] + # import matplotlib.pyplot as plt + # plt.axline((0,0), slope=1, linestyle="--", color="black") + # plt.xlim(0, 1) + # plt.ylim(0, 1) + # plt.xlabel("Expected coverage") + # plt.xlabel("Observed coverage") + # plt.scatter(1 - alpha, prop_covered, color="red") + # plt.savefig(plot) + # plt.clf() + # plt.clf() + # fig, axs = plt.subplots(1, figsize=(10,5)) + # cmap = plt.get_cmap("plasma") + # samp = np.random.randint(0, true.size, size=1000) + # rnk = scipy.stats.rankdata(true[samp]) + # for i in range(alpha.size): + # axs.vlines( + # x=rnk, + # ymin=np.log10(lower[samp, i]) - np.log10(true[samp]), + # ymax=np.log10(upper[samp, i]) - np.log10(true[samp]), + # color=cmap(i/(alpha.size-1)), + # linewidth=1, + # ) + # axs.axhline(y=0, linestyle="--", color="black") + # axs.set_xlabel("True age rank order") + # axs.set_ylabel("Interval - true age (log)") + # plt.savefig("bar.png") + # plt.clf() + return prop_covered + + +def mutation_coverage(ts, inferred_ts, alpha): + assert np.all(np.logical_and(1 > alpha, alpha > 0)) + # extract mutation posteriors from metadata + posteriors = np.zeros((inferred_ts.num_mutations, 2)) + for m in inferred_ts.mutations(): + mn = json.loads(m.metadata or '{"mn":0}')["mn"] + vr = json.loads(m.metadata or '{"vr":0}')["vr"] + posteriors[m.id] = [mn**2 / vr, mn / vr] if vr > 0 else np.nan + # 
find shared biallelic sites + positions = {p: i for i, p in enumerate(ts.sites_position)} + true_mut = np.full(ts.sites_position.size, tskit.NULL) + infr_mut = np.full(ts.sites_position.size, tskit.NULL) + for s in ts.sites(): + if len(s.mutations) == 1: + sid = positions[s.position] + true_mut[sid] = s.mutations[0].id + for s in inferred_ts.sites(): + if len(s.mutations) == 1: + mid = s.mutations[0].id + if not np.isnan(posteriors[mid, 0]): + sid = positions[s.position] + infr_mut[sid] = s.mutations[0].id + missing = np.logical_or(true_mut == tskit.NULL, infr_mut == tskit.NULL) + infr_mut = infr_mut[~missing] + true_mut = true_mut[~missing] + # calculate coverage + post = posteriors[infr_mut] + upper = np.zeros((post.shape[0], alpha.size)) + lower = np.zeros((post.shape[0], alpha.size)) + for i in range(post.shape[0]): + shape, rate = post[i, 0], post[i, 1] + if shape <= 1: + upper[i] = scipy.stats.gamma.ppf(1 - alpha, shape, scale=1 / rate) + lower[i] = 0.0 + else: + upper[i] = scipy.stats.gamma.ppf(1 - alpha / 2, shape, scale=1 / rate) + lower[i] = scipy.stats.gamma.ppf(alpha / 2, shape, scale=1 / rate) + true = ts.mutations_time[true_mut] + is_covered = np.logical_and( + true[:, np.newaxis] < upper, true[:, np.newaxis] > lower + ) + prop_covered = np.sum(is_covered, axis=0) / is_covered.shape[0] + # plt.clf() + # plt.axline((0,0), slope=1, linestyle="--", color="black") + # plt.xlim(0, 1) + # plt.ylim(0, 1) + # plt.xlabel("Expected coverage") + # plt.xlabel("Observed coverage") + # plt.scatter(1 - alpha, prop_covered, color="red") + # plt.savefig(plot) + # plt.clf() + # fig, axs = plt.subplots(1, figsize=(10,5)) + # cmap = plt.get_cmap("plasma") + # samp = np.random.randint(0, true.size, size=1000) + # rnk = scipy.stats.rankdata(true[samp]) + # for i in range(alpha.size): + # axs.vlines( + # x=rnk, + # ymin=np.log10(lower[samp, i]) - np.log10(true[samp]), + # ymax=np.log10(upper[samp, i]) - np.log10(true[samp]), + # color=cmap(i/(alpha.size-1)), + # linewidth=1, + # ) + # axs.axhline(y=0, linestyle="--", color="black") + # axs.set_xlabel("True age rank order") + # axs.set_ylabel("Interval - true age (log)") + # plt.savefig("foo.png") + # plt.clf() + return prop_covered + + +def mutations_time(ts, inferred_ts, min_freq=None, max_freq=None, plotpath=None): + """ + Return true and inferred mutation ages, optionally creating a scatterplot and + filtering by minimum or maximum frequency. 
+ """ + # find shared biallelic sites + positions = {p: i for i, p in enumerate(ts.sites_position)} + true_mut = np.full(ts.sites_position.size, tskit.NULL) + infr_mut = np.full(ts.sites_position.size, tskit.NULL) + for s in ts.sites(): + if len(s.mutations) == 1: + if s.mutations[0].edge != tskit.NULL: + sid = positions[s.position] + true_mut[sid] = s.mutations[0].id + for s in inferred_ts.sites(): + if len(s.mutations) == 1: + if s.mutations[0].edge != tskit.NULL: + sid = positions[s.position] + infr_mut[sid] = s.mutations[0].id + missing = np.logical_or(true_mut == tskit.NULL, infr_mut == tskit.NULL) + infr_mut = infr_mut[~missing] + true_mut = true_mut[~missing] + mean = inferred_ts.mutations_time[infr_mut] + truth = ts.mutations_time[true_mut] + # filter by frequency + if min_freq is not None or max_freq is not None: + freq = np.zeros(inferred_ts.num_mutations) + for t in inferred_ts.trees(): + for m in t.mutations(): + freq[m.id] = t.num_samples(m.node) + if min_freq is None: + min_freq = np.min(freq) + if max_freq is None: + max_freq = np.max(freq) + freq = freq[infr_mut] + is_freq = np.logical_and(freq >= min_freq, freq <= max_freq) + mean = mean[is_freq] + truth = truth[is_freq] + # plot + if plotpath is not None: + rsq = np.corrcoef(np.log10(mean), np.log10(truth))[0, 1] ** 2 + bias = np.mean(np.log10(mean) - np.log10(truth)) + pt1 = (truth.mean(), truth.mean()) + pt2 = (truth.mean() + 1, truth.mean() + 1) + info = f"$r^2 = {rsq:0.3f}$\n$\\mathrm{{bias}} = {bias:0.3f}$" + plt.hexbin(truth, mean, xscale="log", yscale="log", mincnt=1) + plt.text(0.01, 0.99, info, ha="left", va="top", transform=plt.gca().transAxes) + plt.axline(pt1, pt2, linestyle="--", color="firebrick") + plt.xlabel("True mutation age") + plt.ylabel("Estimated mutation age") + plt.tight_layout() + plt.savefig(plotpath) + plt.clf() + return truth, mean + + +def afs_bias(ts, mutation_rate, plotpath=None, polarised=True): + """ + Calculate site and branch allele frequency spectra across windows, where + adjacent AFS bins are pooled. Optionally produce a scatterplot for each + pooled bin. Optionally truncate the AFS at a given `max_freq`. + """ + ts_trim = ts.trim() + site_afs = ts_trim.allele_frequency_spectrum( + mode="site", span_normalise=False, polarised=polarised + ) + branch_afs = mutation_rate * ts_trim.allele_frequency_spectrum( + mode="branch", span_normalise=False, polarised=polarised + ) + if plotpath is not None: + plt.scatter(np.arange(site_afs.size), site_afs, c="black", s=8) + plt.scatter(np.arange(site_afs.size), branch_afs, c="firebrick", s=8) + plt.xlabel("Mutation frequency") + plt.ylabel("# mutations") + plt.yscale("log") + plt.savefig(plotpath) + plt.clf() + return site_afs, branch_afs + + +def allele_frequency_spectra( + ts, + mutation_rate, + plotpath=None, + title=None, + max_freq=None, + num_bins=9, + num_windows=500, + polarised=True, + size_biased=False, +): + """ + Calculate site and branch allele frequency spectra across windows, where + adjacent AFS bins are pooled. Optionally produce a scatterplot for each + pooled bin. Optionally truncate the AFS at a given `max_freq`. 
+ """ + + if max_freq is None: + max_freq = -1 + ts_trim = ts.trim() + windows = np.linspace(0, ts_trim.sequence_length, num_windows + 1) + if size_biased: + bin_sizes = np.tile(np.arange(ts.num_samples + 1), (num_windows, 1)) + else: + bin_sizes = np.ones((num_windows, ts.num_samples + 1)) + site_afs = bin_sizes * ts_trim.allele_frequency_spectrum( + mode="site", windows=windows, span_normalise=False, polarised=polarised + ) + site_afs = site_afs[:, 1:max_freq] + branch_afs = ( + mutation_rate + * bin_sizes + * ts_trim.allele_frequency_spectrum( + mode="branch", windows=windows, span_normalise=False, polarised=polarised + ) + ) + branch_afs = branch_afs[:, 1:max_freq] + dim = isqrt(num_bins) + num_bins = dim * dim + cumulative = np.arange(0, branch_afs.shape[1], dtype=np.float64) + cumulative /= cumulative[-1] + bins = np.linspace(0, 1, num_bins + 1) + bins = np.searchsorted(cumulative, bins, side="right") - 1 + if plotpath is not None: + fig, axs = plt.subplots(dim, dim, squeeze=0) + fudge = 90 / 100 + for i, j, ax in zip(bins[:-1], bins[1:], axs.reshape(-1)): + obs = site_afs[:, i:j].sum(axis=1) + exp = branch_afs[:, i:j].sum(axis=1) + ax.text( + 0.02, + 0.98, + f"{i+1}:{j+1}", + ha="left", + va="top", + transform=ax.transAxes, + size=8, + ) + ax.set_xticks(np.linspace(exp.min(), exp.max(), 3)) + ax.set_yticks(np.linspace(obs.min(), obs.max(), 3)) + ax.set_xlim(exp.min() * fudge, exp.max() / fudge) + ax.set_ylim(obs.min() * fudge, obs.max() / fudge) + ax.tick_params(labelsize=8) + ax.scatter(exp, obs, color="firebrick", s=4) + ax.axline( + (np.mean(obs), np.mean(obs)), slope=1, linestyle="--", color="black" + ) + fig.supylabel("Observed # sites in window") + fig.supxlabel("Expected # sites in window") + if title is not None: + fig.suptitle(title) + plt.tight_layout() + plt.savefig(plotpath) + plt.clf() + return site_afs, branch_afs diff --git a/tsdate/hypergeo.py b/tsdate/hypergeo.py index d75894dc..b64d9b69 100644 --- a/tsdate/hypergeo.py +++ b/tsdate/hypergeo.py @@ -24,7 +24,12 @@ Numerically stable implementations of the Gauss hypergeometric function with numba. """ import ctypes +from math import erf +from math import exp +from math import lgamma from math import log +from math import pi +from math import pow from math import sqrt import numba @@ -53,6 +58,20 @@ class Invalid2F1(Exception): _gammainc_functype = ctypes.CFUNCTYPE(_dbl, _dbl, _dbl) _gammainc_f8 = _gammainc_functype(_gammainc_addr) +# gammaincinv +_gammaincinv_addr = get_cython_function_address( + "scipy.special.cython_special", "gammaincinv" +) +_gammaincinv_functype = ctypes.CFUNCTYPE(_dbl, _dbl, _dbl) +_gammaincinv_f8 = _gammaincinv_functype(_gammaincinv_addr) + +# erfinv +_erfinv_addr = get_cython_function_address( + "scipy.special.cython_special", "__pyx_fuse_0erfinv" +) +_erfinv_functype = ctypes.CFUNCTYPE(_dbl, _dbl) +_erfinv_f8 = _erfinv_functype(_erfinv_addr) + @numba.cfunc("f8(f8)") def _gammaln(x): @@ -66,6 +85,18 @@ def _gammainc(a, x): return _gammainc_f8(a, x) +@numba.cfunc("f8(f8, f8)") +def _gammainc_inv(a, x): + """scipy.special.cython_special.gammaincinv""" + return _gammaincinv_f8(a, x) + + +@numba.cfunc("f8(f8)") +def _erf_inv(x): + """scipy.special.cython_special.erfinv""" + return _erfinv_f8(x) + + @numba.njit("f8(f8)") def _digamma(x): """ @@ -249,3 +280,105 @@ def _hyp2f1_laplace(a, b, c, x): ) return f - log(r) / 2 + s + + +@numba.njit("f8(f8, f8)") +def _gammainc_der(p, x): + """ + Derivative of lower incomplete gamma function with regards to `p`. 
+ + Based on Shea B (1988) "Algorithm AS 239" Applied Statistics 37: 466-473 + + Adapted from https://people.math.sc.edu/Burkardt/cpp_src/asa239/asa239.cpp + """ + + elimit = -88.0 + oflo = 1.0e37 + plimit = 1.0e3 + tol = 1.0e-14 + xbig = 1.0e8 + + assert x >= 0.0 + assert p > 0.0 + + if x == 0: + value, grad = 0.0, 0.0 + return grad + + if x > xbig: + value, grad = 1.0, 0.0 + return grad + + if p > plimit: # gaussian approximation + pn1 = 3 * sqrt(p) * (pow(x / p, 1 / 3) + 1 / (9 * p) - 1) + grad = pn1 / (2 * p) - 1 / sqrt(p) * (pow(x / p, 1 / 3) + 1 / (3 * p)) + grad *= 1 / sqrt(2 * pi) * exp(-(pn1**2) / 2) + value = (1 + erf(pn1 / sqrt(2))) / 2 + return grad + + if x <= 1 or x < p: # series expansion + arg = p * log(x) - x - lgamma(p + 1) + value, grad = 1.0, 0.0 + c, dc = 1.0, 0.0 + a = p + while True: + a += 1.0 + dc = -c * x / a**2 + dc * x / a + c *= x / a + value += c + grad += dc + if c <= tol: + break + darg = exp(arg) * (log(x) - _digamma(p + 1)) + grad = grad * exp(arg) + value * darg + arg += log(value) + if arg >= elimit: + value = exp(arg) + else: + grad = 0.0 + value = 0.0 + return grad + else: # continued fraction + arg = p * log(x) - x - lgamma(p) + a = 1.0 - p + b = a + x + 1.0 + c = 0.0 + pn1, pn2, pn3, pn4 = 1.0, x, x + 1.0, x * b + dpn1, dpn2, dpn3, dpn4 = 0.0, 0.0, 0.0, -x + value, grad = pn3 / pn4, 0.0 + while True: + a += 1.0 + b += 2.0 + c += 1.0 + an = a * c + pn5 = b * pn3 - an * pn1 + pn6 = b * pn4 - an * pn2 + dpn5 = b * dpn3 - pn3 - an * dpn1 + c * pn1 + dpn6 = b * dpn4 - pn4 - an * dpn2 + c * pn2 + if pn6 != 0.0: + rn = pn5 / pn6 + grad = (dpn5 * pn6 - pn5 * dpn6) / pn6**2 + if abs(value - rn) <= min(tol, tol * rn): + break + value = rn + pn1, pn2, pn3, pn4 = pn3, pn4, pn5, pn6 + dpn1, dpn2, dpn3, dpn4 = dpn3, dpn4, dpn5, dpn6 + if oflo <= abs(pn5): + pn1 /= oflo + pn2 /= oflo + pn3 /= oflo + pn4 /= oflo + dpn1 /= oflo + dpn2 /= oflo + dpn3 /= oflo + dpn4 /= oflo + darg = exp(arg) * (log(x) - _digamma(p)) + grad = grad * exp(arg) + value * darg + arg += log(value) + if arg >= elimit: + grad *= -1.0 + value = 1.0 - exp(arg) + else: + grad = 0.0 + value = 1.0 + return grad diff --git a/tsdate/mixture.py b/tsdate/mixture.py deleted file mode 100644 index ddb5f4f9..00000000 --- a/tsdate/mixture.py +++ /dev/null @@ -1,267 +0,0 @@ -# MIT License -# -# Copyright (c) 2021-23 Tskit Developers -# Copyright (c) 2020-21 University of Oxford -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
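Referring back to `_gammainc_der` above: the analytic gradient with respect to the shape parameter can be sanity-checked against a central finite difference of SciPy's regularised lower incomplete gamma (an illustrative check only, not part of the patch):

```python
from scipy.special import gammainc


def gammainc_der_fd(p, x, h=1e-6):
    # Central finite difference of P(p, x) with respect to the shape p.
    return (gammainc(p + h, x) - gammainc(p - h, x)) / (2 * h)


# e.g. compare gammainc_der_fd(3.5, 2.0) against hypergeo._gammainc_der(3.5, 2.0)
```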
-""" -Mixture of gamma distributions that may be fit via EM to distribution-valued observations -""" -import numba -import numpy as np -from numba.types import Tuple as _tuple -from numba.types import UniTuple as _unituple - -from . import approx -from . import hypergeo - -# shorthand for numba readonly array types, [type][dimension][constness] -# type is one of "i" (int32), "f" (float64) -# dimension is a nonzero integer -# constness is one of "r" (read-only) or "w" (writable) -_f = numba.types.float64 -_i = numba.types.int32 -_b = numba.types.bool_ -_f1w = numba.types.Array(_f, 1, "C", readonly=False) -_f1r = numba.types.Array(_f, 1, "C", readonly=True) -_f2w = numba.types.Array(_f, 2, "C", readonly=False) -_f2r = numba.types.Array(_f, 2, "C", readonly=True) -_f3w = numba.types.Array(_f, 3, "C", readonly=False) -_i1r = numba.types.Array(_i, 1, "C", readonly=True) - - -@numba.njit(_unituple(_f1w, 4)(_f1r, _f1r, _f1r, _f, _f)) -def _conditional_posterior(prior_logweight, prior_alpha, prior_beta, alpha, beta): - r""" - Return expectations of node age :math:`t` from the mixture model, - - ..math:: - - Ga(t | a, b) \sum_j \pi_j w_j Ga(t | \alpha_j, \beta_j) - - where :math:`a` and :math:`b` are variational parameters, - and :math:`\pi_j, \alpha_j, \beta_j` are prior weights and - parameters for a gamma mixture; and :math:`w_j` are fixed, - observation-specific weights. We use natural parameterization, - so that the shape parameter is :math:`\alpha + 1`. - - TODO: - The normalizing constants of the prior are assumed to have already - been integrated into `prior_weight`. - - Returns the contribution from each component to the - posterior expectations of :math:`E[1]`, :math:`E[t]`, :math:`E[log t]`, - and :math:`E[t log t]`. - - Note that :math:`E[1]` is *unnormalized* and *log-transformed*. - """ - - dim = prior_logweight.size - E = np.full(dim, -np.inf) # E[1] (e.g. normalizing constant) - E_t = np.zeros(dim) # E[t] - E_logt = np.zeros(dim) # E[log(t)] - E_tlogt = np.zeros(dim) # E[t * log(t)] - C = (alpha + 1) * np.log(beta) - hypergeo._gammaln(alpha + 1) if beta > 0 else 0.0 - for i in range(dim): - post_alpha = prior_alpha[i] + alpha - post_beta = prior_beta[i] + beta - if (post_alpha <= -1) or (post_beta <= 0): # skip node if divergent - E[:] = -np.inf - break - E[i] = C + ( - +hypergeo._gammaln(post_alpha + 1) - - (post_alpha + 1) * np.log(post_beta) - + prior_logweight[i] - ) - assert np.isfinite(E[i]) - # TODO: option to use moments instead of sufficient statistics? - E_t[i] = (post_alpha + 1) / post_beta - E_logt[i] = hypergeo._digamma(post_alpha + 1) - np.log(post_beta) - E_tlogt[i] = E_t[i] * E_logt[i] + E_t[i] / (post_alpha + 1) - - return E, E_t, E_logt, E_tlogt - - -@numba.njit(_f(_f1w, _f1w, _f1w, _f1r, _f1r)) -def _em_update(prior_weight, prior_alpha, prior_beta, alpha, beta): - """ - Perform an expectation maximization step for parameters of mixture components, - given variational parameters `alpha`, `beta` for each node. - - The maximization step is performed using Ye & Chen (2017) "Closed form - estimators for the gamma distribution ..." - - ``prior_weight``, ``prior_alpha``, ``prior_beta`` are updated in place. 
- """ - assert alpha.size == beta.size - - dim = prior_weight.size - n = np.zeros(dim) - t = np.zeros(dim) - logt = np.zeros(dim) - tlogt = np.zeros(dim) - - # incorporate prior normalizing constants into weights - prior_logweight = np.log(prior_weight) - for k in range(dim): - prior_logweight[k] += (prior_alpha[k] + 1) * np.log( - prior_beta[k] - ) - hypergeo._gammaln(prior_alpha[k] + 1) - - # expectation step: - loglik = 0.0 - for a, b in zip(alpha, beta): - E, E_t, E_logt, E_tlogt = _conditional_posterior( - prior_logweight, prior_alpha, prior_beta, a, b - ) - - # skip if posterior is improper - if np.any(np.isinf(E)): - continue - - # convert evidence to posterior weights - norm_const = np.log(np.sum(np.exp(E - np.max(E)))) + np.max(E) - weight = np.exp(E - norm_const) - - # weighted contributions to sufficient statistics - loglik += norm_const - n += weight - t += E_t * weight - logt += E_logt * weight - tlogt += E_tlogt * weight - - # maximization step: update parameters in place - prior_weight[:] = n / np.sum(n) - prior_beta[:] = n**2 / (n * tlogt - t * logt) - prior_alpha[:] = n * t / (n * tlogt - t * logt) - 1.0 - - return loglik - - -@numba.njit(_f1w(_f1r, _f1r, _f1r, _f1w, _f1w)) -def _gamma_projection(prior_weight, prior_alpha, prior_beta, alpha, beta): - """ - Given variational approximation to posterior: multiply by exact prior, - calculate sufficient statistics, and moment match to get new - approximate posterior. - - Updates ``alpha`` and ``beta`` in-place. - """ - assert alpha.size == beta.size - - dim = prior_weight.size - - # incorporate prior normalizing constants into weights - prior_logweight = np.log(prior_weight) - for k in range(dim): - prior_logweight[k] += (prior_alpha[k] + 1) * np.log( - prior_beta[k] - ) - hypergeo._gammaln(prior_alpha[k] + 1) - - log_const = np.full(alpha.size, -np.inf) - for i in range(alpha.size): - E, E_t, E_logt, E_tlogt = _conditional_posterior( - prior_logweight, prior_alpha, prior_beta, alpha[i], beta[i] - ) - - # skip if posterior is improper for all components - if np.any(np.isinf(E)): - continue - - norm = np.log(np.sum(np.exp(E - np.max(E)))) + np.max(E) - weight = np.exp(E - norm) - t = np.sum(weight * E_t) - logt = np.sum(weight * E_logt) - # tlogt = np.sum(weight * E_tlogt) - log_const[i] = norm - alpha[i], beta[i] = approx.approximate_gamma_kl(t, logt) - # TODO: do something faster, like - # beta[i] = 1 / (tlogt - t * logt) - # alpha[i] = t * beta[i] - 1.0 - - return log_const - - -@numba.njit(_tuple((_f2w, _f2w, _f1w))(_f2r, _f2r, _i, _f, _b)) -def fit_gamma_mixture(mixture, observations, max_iterations, tolerance, verbose): - """ - Run EM until relative tolerance or maximum number of iterations is - reached. Then, perform expectation-propagation update and return new - variational parameters for the posterior approximation. 
- """ - - assert mixture.shape[1] == 3 - assert observations.shape[1] == 2 - - # mix_weight, mix_alpha, mix_beta = np.hsplit(mixture, 3) - # alpha, beta = np.hsplit(observations, 2) - mix_weight = mixture[:, 0].copy() - mix_alpha = mixture[:, 1].copy() - mix_beta = mixture[:, 2].copy() - alpha = observations[:, 0].copy() - beta = observations[:, 1].copy() - - last = np.inf - for itt in range(max_iterations): - loglik = _em_update(mix_weight, mix_alpha, mix_beta, alpha, beta) - loglik /= float(alpha.size) - update = np.abs(loglik - last) - last = loglik - if verbose: - print("EM iteration:", itt, "mean(loglik):", np.round(loglik, 5)) - print(" -> weights:", mix_weight) - print(" -> alpha:", mix_alpha) - print(" -> beta:", mix_beta) - if update < np.abs(loglik) * tolerance: - break - - # conditional posteriors for each observation - log_const = _gamma_projection(mix_weight, mix_alpha, mix_beta, alpha, beta) - - new_mixture = np.zeros(mixture.shape) - new_observations = np.zeros(observations.shape) - new_observations[:, 0] = alpha - new_observations[:, 1] = beta - new_mixture[:, 0] = mix_weight - new_mixture[:, 1] = mix_alpha - new_mixture[:, 2] = mix_beta - - return new_mixture, new_observations, log_const - - -def initialize_mixture(parameters, num_components): - """ - Initialize clusters by dividing nodes into equal groups. - "parameters" are in natural parameterization (not shape/rate) - """ - global_prior = np.empty((num_components, 3)) - if num_components == 0: - return global_prior - num_nodes = parameters.shape[0] - age_classes = np.tile( - np.arange(num_components), - num_nodes // num_components + 1, - )[:num_nodes] - for k in range(num_components): - indices = np.equal(age_classes, k) - alpha, beta = approx.average_gammas( - parameters[indices, 0], parameters[indices, 1] - ) - global_prior[k] = [1.0 / num_components, alpha, beta] - return global_prior diff --git a/tsdate/normalisation.py b/tsdate/normalisation.py new file mode 100644 index 00000000..7f730224 --- /dev/null +++ b/tsdate/normalisation.py @@ -0,0 +1,415 @@ +# MIT License +# +# Copyright (c) 2020 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+""" +Utilities for rescaling time according to a mutational clock +""" +from math import inf +from math import log + +import numba +import numpy as np +import tskit +from numba.types import UniTuple as _unituple + +from .approx import _b +from .approx import _b1r +from .approx import _f +from .approx import _f1r +from .approx import _f1w +from .approx import _f2r +from .approx import _f2w +from .approx import _i +from .approx import _i1r +from .approx import _i1w +from .approx import approximate_gamma_iqr +from .hypergeo import _gammainc_inv as gammainc_inv +from .util import mutation_span_array + + +@numba.njit(_i1w(_f1r, _i)) +def _fixed_changepoints(counts, epochs): + """ + Find breakpoints such that `counts` is divided roughly equally across `epochs` + """ + assert epochs > 0 + Y = np.append(0.0, np.cumsum(counts)) + Z = Y / Y[-1] + z = np.linspace(0, 1, epochs + 1) + e = np.searchsorted(Z, z, "right") - 1 + if e[0] > 0: + e[0] = 0 + if e[-1] < counts.size: + e[-1] = counts.size + return e.astype(np.int32) + + +@numba.njit(_i1w(_f1r, _f1r, _f, _f, _f)) +def _poisson_changepoints(counts, offset, penalty, min_counts, min_offset): + """ + Given Poisson counts and offsets for a sequence of observations, find the set + of changepoints for the Poisson rate that maximizes the profile likelihood + under a linear penalty on complexity (e.g. penalty == 2 is AIC). + + See: "Optimal detection of changepoints with a linear computation cost" + (https://doi.org/10.1080/01621459.2012.737745) + """ + + assert counts.size == offset.size + assert min_counts >= 0 + assert min_offset >= 0 + assert penalty >= 0 + + N = np.append(0, np.cumsum(offset)) + Y = np.append(0, np.cumsum(counts)) + + def f(i, j): # loss + n = N[j] - N[i] + y = Y[j] - Y[i] + s = n < min_offset or y < min_counts + return inf if s else -2 * y * (log(y) - log(n) - 1) + + dim = counts.size + cost = np.empty(dim) + F = np.empty(dim + 1) + C = {0: np.empty(0, dtype=np.int64)} + + F[0] = -penalty + for j in np.arange(1, dim + 1): + argmin, minval = 0, np.inf + for i in C: # minimize + cost[i] = F[i] + f(i, j) + penalty + if cost[i] < minval: + minval = cost[i] + argmin = i + F[j] = minval + for i in set(C): # prune + if cost[i] > F[j] + penalty: + C.pop(i) + C[j] = np.append(C[argmin], argmin) + + breaks = np.append(C[dim], dim).astype(np.int32) + return breaks + + +@numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _f1r, _i)) +def mutational_timescale( + nodes_time, + likelihoods, + constraints, + edges_parent, + edges_child, + edges_weight, + max_intervals, +): + """ + Rescale node ages so that the instantaneous mutation rate is constant. + Edges with a negative duration are ignored when calculating the total + rate. Returns a rescaled point estimate and the posterior. 
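For intuition, the `_fixed_changepoints` routine defined above amounts to the following plain-NumPy sketch, which picks epoch boundaries so that each interval carries roughly equal weight (illustrative only; assumes the counts sum to a positive total):

```python
import numpy as np


def fixed_changepoints_sketch(counts, epochs):
    # Cumulative share of the total weight at each epoch boundary.
    cum = np.append(0.0, np.cumsum(counts)) / np.sum(counts)
    targets = np.linspace(0.0, 1.0, epochs + 1)
    breaks = np.searchsorted(cum, targets, side="right") - 1
    breaks[0], breaks[-1] = 0, counts.size
    return breaks


# fixed_changepoints_sketch(np.array([5.0, 1.0, 1.0, 5.0]), 2) -> array([0, 2, 4])
```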
+ + :param np.ndarray nodes_time: point estimates for node ages + :param np.ndarray likelihoods: edges are rows; mutation + counts and mutational span are columns + :param np.ndarray constraints: lower and upper bounds on node age + :param np.ndarray edges_parent: node index for the parent of each edge + :param np.ndarray edges_child: node index for the child of each edge + :param np.ndarray edges_weight: a weight for each edge + :param int max_intervals: maximum number of intervals within which to + estimate the time scaling + """ + + assert edges_parent.size == edges_child.size == edges_weight.size + assert likelihoods.shape[0] == edges_parent.size and likelihoods.shape[1] == 2 + assert constraints.shape[0] == nodes_time.size and constraints.shape[1] == 2 + assert max_intervals > 0 + + nodes_fixed = constraints[:, 0] == constraints[:, 1] + assert np.all(nodes_time[nodes_fixed] == constraints[nodes_fixed, 0]) + + # index node by unique time breaks + nodes_order = np.argsort(nodes_time) + nodes_index = np.zeros(nodes_time.size, dtype=np.int32) + epoch_breaks = [0.0] + k = 0 + for i, j in zip(nodes_order[1:], nodes_order[:-1]): + if nodes_time[i] > nodes_time[j]: + epoch_breaks.append(nodes_time[i]) + k += 1 + nodes_index[i] = k + epoch_breaks = np.array(epoch_breaks) + epoch_length = np.diff(epoch_breaks) + num_epochs = epoch_length.size + + # instantaneous mutation rate per edge + edges_length = nodes_time[edges_parent] - nodes_time[edges_child] + edges_subset = edges_length > 0 + edges_counts = likelihoods.copy() + edges_counts[edges_subset, 0] /= edges_length[edges_subset] + + # pass over edges, measuring overlap with each time interval + epoch_counts = np.zeros((num_epochs, 2)) + for e in np.flatnonzero(edges_subset): + p, c = edges_parent[e], edges_child[e] + a, b = nodes_index[c], nodes_index[p] + if a < num_epochs: + epoch_counts[a] += edges_counts[e] * edges_weight[e] + if b < num_epochs: + epoch_counts[b] -= edges_counts[e] * edges_weight[e] + counts = np.cumsum(epoch_counts[:, 0]) + offset = np.cumsum(epoch_counts[:, 1]) + + # rescale time such that mutation density is constant between changepoints + # TODO: use poisson changepoints to further refine + changepoints = _fixed_changepoints(offset * epoch_length, max_intervals) + changepoints = np.union1d(changepoints, nodes_index[nodes_fixed]) + adjust = np.zeros(changepoints.size) + k = 0 + for i, j in zip(changepoints[:-1], changepoints[1:]): + assert j > i + # TODO: when changepoint intersects a fixed node? + n = np.sum(offset[i:j]) + y = np.sum(counts[i:j]) + z = np.sum(epoch_length[i:j]) + assert n > 0, "Zero edge span in interval" + adjust[k + 1] = z * y / n + k += 1 + adjust = np.cumsum(adjust) + origin = epoch_breaks[changepoints] + + return origin, adjust + + +@numba.njit(_f2w(_f2r, _f1r, _f1r, _f, _b)) +def piecewise_scale_posterior( + posteriors, + original_breaks, + rescaled_breaks, + quantile_width, + use_median, +): + """ + :param float quantile_width: width of interquantile range to use for estimating + rescaled shape parameter, e.g. 
0.5 uses interquartile range + """ + + assert original_breaks.size == rescaled_breaks.size + assert 1 > quantile_width > 0 + + dim = posteriors.shape[0] + quant_lower = quantile_width / 2 + quant_upper = 1 - quantile_width / 2 + + # use posterior mean or median as a point estimate + freed = np.logical_and(posteriors[:, 0] > -1, posteriors[:, 1] > 0) + lower = np.zeros(dim) + upper = np.zeros(dim) + midpt = np.zeros(dim) + for i in np.flatnonzero(freed): + alpha, beta = posteriors[i] + lower[i] = gammainc_inv(alpha + 1, quant_lower) / beta + upper[i] = gammainc_inv(alpha + 1, quant_upper) / beta + midpt[i] = gammainc_inv(alpha + 1, 0.5) if use_median else (alpha + 1) + midpt[i] /= beta + + # rescale quantiles + assert np.all(np.diff(rescaled_breaks) > 0) + assert np.all(np.diff(original_breaks) > 0) + scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0) + + def rescale(x): + i = np.searchsorted(original_breaks, x, "right") - 1 + assert i.min() >= 0 and i.max() < scalings.size # DEBUG + return rescaled_breaks[i] + scalings[i] * (x - original_breaks[i]) + + midpt = rescale(midpt) + lower = rescale(lower) + upper = rescale(upper) + + # reproject posteriors using inter-quantile range + # TODO: catch rare cases where lower/upper quantiles are nearly identical + new_posteriors = np.zeros(posteriors.shape) + for i in np.flatnonzero(freed): + alpha, beta = approximate_gamma_iqr( + quant_lower, quant_upper, lower[i], upper[i] + ) + beta = gammainc_inv(alpha + 1, 0.5) if use_median else (alpha + 1) + beta /= midpt[i] # choose rate so as to keep mean or median + new_posteriors[i] = alpha, beta + + return new_posteriors + + +@numba.njit(_f1w(_b1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r)) +def edge_sampling_weight( + is_leaf, + edges_parent, + edges_child, + edges_left, + edges_right, + insert_index, + remove_index, +): + """ + Calculate the probability that a randomly selected root-to-leaf path from a + random point on the sequence contains a given edge, for all edges. 
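A quadratic-time but transparent reference for the same quantity, using the tskit tree API and assuming every sample is a leaf in every tree (an illustrative cross-check, not the implementation used here):

```python
import numpy as np
import tskit


def edge_path_weight_reference(ts):
    # P(edge lies on a root-to-leaf path), averaged over leaves and positions.
    weight = np.zeros(ts.num_edges)
    for tree in ts.trees():
        for u in tree.nodes():
            e = tree.edge(u)  # edge above node u in this tree, or tskit.NULL
            if e != tskit.NULL:
                weight[e] += tree.span * tree.num_samples(u)
    return weight / (ts.num_samples * ts.sequence_length)
```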
+ + :param np.ndarray is_leaf: boolean array indicating whether a node is a leaf + """ + num_nodes = is_leaf.size + num_edges = edges_child.size + + insert_position = edges_left[insert_index] + remove_position = edges_right[remove_index] + sequence_length = remove_position[-1] + + nodes_parent = np.full(num_nodes, tskit.NULL) + nodes_edge = np.full(num_nodes, tskit.NULL) + nodes_leaves = np.zeros(num_nodes) + edges_leaves = np.zeros(num_edges) + + nodes_leaves[is_leaf] = 1.0 + total_leaves = 0.0 + position = 0.0 + a, b = 0, 0 + while position < sequence_length: + edges_out = [] + while b < num_edges and remove_position[b] == position: + edges_out.append(remove_index[b]) + b += 1 + + edges_in = [] + while a < num_edges and insert_position[a] == position: # edges in + edges_in.append(insert_index[a]) + a += 1 + + remainder = sequence_length - position + + for e in edges_out: + p, c = edges_parent[e], edges_child[e] + update = nodes_leaves[c] + while p != tskit.NULL: + u = nodes_edge[c] + edges_leaves[u] -= update * remainder + c, p = p, nodes_parent[p] + p, c = edges_parent[e], edges_child[e] + while p != tskit.NULL: + nodes_leaves[p] -= update + p = nodes_parent[p] + nodes_parent[c] = tskit.NULL + nodes_edge[c] = tskit.NULL + if is_leaf[c]: + total_leaves -= remainder + + for e in edges_in: + p, c = edges_parent[e], edges_child[e] + nodes_parent[c] = p + nodes_edge[c] = e + if is_leaf[c]: + total_leaves += remainder + update = nodes_leaves[c] + while p != tskit.NULL: + nodes_leaves[p] += update + p = nodes_parent[p] + p, c = edges_parent[e], edges_child[e] + while p != tskit.NULL: + u = nodes_edge[c] + edges_leaves[u] += update * remainder + c, p = p, nodes_parent[p] + + position = sequence_length + if b < num_edges: + position = min(position, remove_position[b]) + if a < num_edges: + position = min(position, insert_position[a]) + + edges_leaves /= total_leaves + return edges_leaves + + +def normalise_tree_sequence( + ts, mutation_rate, *, normalisation_intervals=1000, match_segregating_sites=False +): + """ + Adjust the time scaling of a tree sequence so that expected mutational area + matches the expected number of mutations on a path from leaf to root, where + the expectation is taken over all paths and bases in the sequence. 
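A minimal usage sketch for this function, assuming contemporary samples and a known per-base mutation rate (the file name and parameter values are placeholders):

```python
import tskit

from tsdate.normalisation import normalise_tree_sequence

ts = tskit.load("dated.trees")  # hypothetical tree sequence with estimated node times
rescaled_ts = normalise_tree_sequence(
    ts,
    mutation_rate=1e-8,
    normalisation_intervals=1000,
    match_segregating_sites=False,
)
```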
+ + :param tskit.TreeSequence ts: the tree sequence to normalise + :param float mutation_rate: the per-base mutation rate + :param int normalisation_intervals: the number of time intervals for which + to estimate a separate time rescaling parameter + :param bool match_segregating_sites: if True, match the total number of + mutations rather than the average number of differences from the ancestral + state + """ + if match_segregating_sites: + edge_weights = np.ones(ts.num_edges) + else: + has_parent = np.full(ts.num_nodes, False) + has_child = np.full(ts.num_nodes, False) + has_parent[ts.edges_child] = True + has_child[ts.edges_parent] = True + is_leaf = np.logical_and(~has_child, has_parent) + edge_weights = edge_sampling_weight( + is_leaf, + ts.edges_parent, + ts.edges_child, + ts.edges_left, + ts.edges_right, + ts.indexes_edge_insertion_order, + ts.indexes_edge_removal_order, + ) + # estimate time rescaling parameter within intervals + samples = list(ts.samples()) + if not np.all(ts.nodes_time[samples] == 0.0): + raise ValueError("Normalisation not implemented for ancient samples") + constraints = np.zeros((ts.num_nodes, 2)) + constraints[:, 1] = np.inf + constraints[samples, :] = ts.nodes_time[samples, np.newaxis] + mutations_span, mutations_edge = mutation_span_array(ts) + mutations_span[:, 1] *= mutation_rate + original_breaks, rescaled_breaks = mutational_timescale( + ts.nodes_time, + mutations_span, + constraints, + ts.edges_parent, + ts.edges_child, + edge_weights, + normalisation_intervals, + ) + # rescale node time + assert np.all(np.diff(rescaled_breaks) > 0) + assert np.all(np.diff(original_breaks) > 0) + scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0) + idx = np.searchsorted(original_breaks, ts.nodes_time, "right") - 1 + nodes_time = rescaled_breaks[idx] + scalings[idx] * ( + ts.nodes_time - original_breaks[idx] + ) + # calculate mutation time + mutations_parent = ts.edges_parent[mutations_edge] + mutations_child = ts.edges_child[mutations_edge] + mutations_time = (nodes_time[mutations_parent] + nodes_time[mutations_child]) / 2 + above_root = mutations_edge == tskit.NULL + mutations_time[above_root] = nodes_time[mutations_child[above_root]] + tables = ts.dump_tables() + tables.nodes.time = nodes_time + tables.mutations.time = mutations_time + return tables.tree_sequence() diff --git a/tsdate/util.py b/tsdate/util.py index 7f826013..7f8139a2 100644 --- a/tsdate/util.py +++ b/tsdate/util.py @@ -33,7 +33,10 @@ from . 
import provenance from .approx import _b1r +from .approx import _f from .approx import _f1r +from .approx import _f1w +from .approx import _i from .approx import _i1r from .approx import _i1w @@ -331,17 +334,19 @@ def add_sampledata_times(samples, sites_time): def mutation_span_array(tree_sequence): """Extract mutation counts and spans per edge into a two-column array""" mutation_spans = np.zeros((tree_sequence.num_edges, 2)) + mutation_edges = np.zeros(tree_sequence.num_mutations, dtype=np.int32) for mut in tree_sequence.mutations(): + mutation_edges[mut.id] = mut.edge if mut.edge != tskit.NULL: mutation_spans[mut.edge, 0] += 1 for edge in tree_sequence.edges(): mutation_spans[edge.id, 1] = edge.span - return mutation_spans + return mutation_spans, mutation_edges -@numba.njit(_unituple(_i1w, 4)(_i1r, _i1r, _f1r, _f1r, _i1r, _b1r)) +@numba.njit(_unituple(_i1w, 3)(_i1r, _i1r, _f1r, _f1r, _b1r)) def _split_disjoint_nodes( - edges_parent, edges_child, edges_left, edges_right, mutations_edge, nodes_exclude + edges_parent, edges_child, edges_left, edges_right, nodes_exclude ): """ Split disconnected regions of nodes into separate nodes. @@ -352,12 +357,12 @@ def _split_disjoint_nodes( assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size num_edges = edges_parent.size num_nodes = nodes_exclude.size - num_mutations = mutations_edge.size # For each edge, check whether parent/child is separated by a gap from the # previous edge involving either parent/child. Label disconnected segments # per node by integers starting at zero. edges_order = np.argsort(edges_left) + # TODO: is a sort really needed here? edges_segments = np.full((2, num_edges), -1, dtype=np.int32) nodes_segments = np.full(num_nodes, -1, dtype=np.int32) nodes_right = np.full(nodes_exclude.size, -np.inf, dtype=np.float64) @@ -392,13 +397,76 @@ def _split_disjoint_nodes( edges_segments[i, e] = n edges_parent, edges_child = edges_segments[0, ...], edges_segments[1, ...] - # Relabel node under each mutation - mutations_node = np.full(num_mutations, tskit.NULL, dtype=np.int32) - for i, e in enumerate(mutations_edge): - if e != tskit.NULL: - mutations_node[i] = edges_child[e] + return edges_parent, edges_child, nodes_order - return edges_parent, edges_child, mutations_node, nodes_order + +@numba.njit(_i1w(_i1r, _f1r, _i1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r)) +def _relabel_mutations_node( + mutations_node, + mutations_position, + nodes_order, + edges_parent, + edges_child, + edges_left, + edges_right, + insert_index, + remove_index, +): + """ + Traverse trees, maintaining a mapping between old and new node IDs in the + current tree. Update `mutations_node` to reflect new IDs. 
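# Orientation note (illustrative, not part of this patch): the body below, like
# `edge_sampling_weight` earlier in this patch, re-implements tskit's standard
# incremental "edge diff" sweep in numba for speed. At the Python level the same
# bookkeeping can be written with ts.edge_diffs(), e.g. maintaining a parent
# array per tree (the same array object is updated and yielded at each step):
import numpy as np
import tskit

def parent_array_sweep(ts):
    nodes_parent = np.full(ts.num_nodes, tskit.NULL, dtype=np.int32)
    for interval, edges_out, edges_in in ts.edge_diffs():
        for edge in edges_out:
            nodes_parent[edge.child] = tskit.NULL
        for edge in edges_in:
            nodes_parent[edge.child] = edge.parent
        yield interval, nodes_parent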
+ """ + assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size + assert edges_parent.size == insert_index.size == remove_index.size + assert mutations_position.size == mutations_node.size + + num_nodes = nodes_order.size + num_edges = edges_parent.size + num_mutations = mutations_position.size + + insert_position = edges_left[insert_index] + remove_position = edges_right[remove_index] + sequence_length = remove_position[-1] + + output = np.full(num_mutations, tskit.NULL, dtype=np.int32) + nodes_map = np.full(num_nodes, tskit.NULL, dtype=np.int32) + a, b, m = 0, 0, 0 + left = 0.0 + while left < sequence_length: + while b < num_edges and remove_position[b] == left: # edges out + b += 1 + + while a < num_edges and insert_position[a] == left: # edges in + e = insert_index[a] + c, p = edges_child[e], edges_parent[e] + nodes_map[nodes_order[c]] = c + nodes_map[nodes_order[p]] = p + a += 1 + + right = sequence_length + if b < num_edges: + right = min(right, remove_position[b]) + if a < num_edges: + right = min(right, insert_position[a]) + left = right + + while m < num_mutations and mutations_position[m] < right: + assert nodes_map[mutations_node[m]] != tskit.NULL + output[m] = nodes_map[mutations_node[m]] + m += 1 + + return output + + +# def _naive_relabel_mutations_node(ts, nodes_order, mutations_node): +# num_nodes = nodes_order.size +# new_node_id = np.full(num_nodes, tskit.NULL) +# for t in ts.trees(): +# for n in t.nodes(): # mapping from original to new node ids +# new_node_id[nodes_order[n]] = n +# for m in t.mutations(): +# mutations_node[m.id] = new_node_id[mutations_node[m.id]] +# return mutations_node def split_disjoint_nodes(ts): @@ -412,23 +480,26 @@ def split_disjoint_nodes(ts): new nodes. """ - mutations_edge = np.full(ts.num_mutations, tskit.NULL, dtype=np.int32) - for m in ts.mutations(): - mutations_edge[m.id] = m.edge - node_is_sample = np.bitwise_and(ts.nodes_flags, tskit.NODE_IS_SAMPLE).astype(bool) - edges_parent, edges_child, mutations_node, nodes_order = _split_disjoint_nodes( + edges_parent, edges_child, nodes_order = _split_disjoint_nodes( ts.edges_parent, ts.edges_child, ts.edges_left, ts.edges_right, - mutations_edge, node_is_sample, ) - # TODO: correctly handle mutations above root (m.edge == tskit.NULL) - nonsegregating = np.flatnonzero(mutations_node == tskit.NULL) - mutations_node[nonsegregating] = ts.mutations_node[nonsegregating] + mutations_node = _relabel_mutations_node( + ts.mutations_node, + ts.sites_position[ts.mutations_site], + nodes_order, + edges_parent, + edges_child, + ts.edges_left, + ts.edges_right, + ts.indexes_edge_insertion_order, + ts.indexes_edge_removal_order, + ) tables = ts.dump_tables() tables.nodes.set_columns( @@ -445,149 +516,125 @@ def split_disjoint_nodes(ts): tables.mutations.node = mutations_node tables.sort() + assert np.array_equal( + tables.nodes.time[tables.mutations.node], ts.nodes_time[ts.mutations_node] + ) + return tables.tree_sequence() -# TODO: numba.njit -def _split_root_nodes(ts): +@numba.njit(_f1w(_f1r, _b1r, _i1r, _i1r, _f, _i)) +def _constrain_ages( + nodes_time, nodes_fixed, edges_parent, edges_child, epsilon, max_iterations +): """ - Split roots whenever the set of children changes. Nodes will only be split - on the interior of the intervals where they are roots. + Approximate least squares solution to the positive branch length + constraint, using the method of alternating projections. 
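# Toy illustration of alternating projections (not part of this patch, and
# omitting the Dykstra-style correction terms and fixed-node handling used
# below): repeatedly project the ages onto each "parent no younger than child"
# constraint, splitting the violation between the two endpoints. Because edges
# share nodes, fixing one edge can re-violate another, hence the iteration.
import numpy as np

ages = np.array([0.0, 2.0, 1.0])   # leaf, internal node, root (root too young)
edges = [(1, 0), (2, 1)]           # (parent, child) pairs
for _ in range(10):
    for p, c in edges:
        gap = ages[c] - ages[p]
        if gap > 0:
            ages[c] -= gap / 2
            ages[p] += gap / 2
print(ages)  # -> [0.  1.5 1.5]; a final pass adds epsilon to break the
             #    remaining tie, as in the forced-constraint loop below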
Loosely based on + Dykstra's algorithm, see: - Returns new edges (parent, child, left, right) and the original ids for - each node. + Dykstra RL, "An algorithm for restricted least squares regression", JASA + 1983 """ + assert nodes_time.size == nodes_fixed.size + assert edges_parent.size == edges_child.size - num_nodes = ts.num_nodes - num_edges = ts.num_edges - - # Find locations where root node changes - roots_node = [] - roots_breaks = [] - last_root = None - for t in ts.trees(): - root = tskit.NULL if t.num_edges == 0 else t.root - if root != last_root: - roots_node.append(root) - roots_breaks.append(t.interval.left) - last_root = root - roots_breaks.append(ts.sequence_length) - roots_node = np.array(roots_node, dtype=np.int32) - roots_breaks = np.array(roots_breaks, dtype=np.float64) - - # Segment roots at edge additions/removals - add_breaks = {n: list() for n in roots_node if n != tskit.NULL} - for e in range(num_edges): - p = ts.edges_parent[e] - if p in add_breaks: - for x in (ts.edges_left[e], ts.edges_right[e]): - i = np.searchsorted(roots_breaks, x, side="right") - 1 - if x == ts.sequence_length: - continue - if ( - p == roots_node[i] and x > roots_breaks[i] - ): # store *internal* breaks for root segments - add_breaks[p].append(x) - - # Create a new node for each segment except the leftmost - add_nodes = {} - add_split = {} - nodes_order = [i for i in range(num_nodes)] - for p in add_breaks: - breaks = np.unique(np.asarray(add_breaks[p])) - if breaks.size > 0: - add_split[p] = breaks - add_nodes[p] = [p] # segment left of first break retains original node ID - for _ in range(breaks.size): - add_nodes[p].append(num_nodes) - nodes_order.append(p) - num_nodes += 1 - - # Split each edge along the union of parent/child segments - new_parent = list(ts.edges_parent) - new_child = list(ts.edges_child) - new_left = list(ts.edges_left) - new_right = list(ts.edges_right) - for e in range(num_edges): - p, c = ts.edges_parent[e], ts.edges_child[e] - - if not (p in add_nodes or c in add_nodes): # no breaks in parent/child - continue - - # find parent/child breaks on edge - left, right = ts.edges_left[e], ts.edges_right[e] - p_nodes = add_nodes.get(p, [p]) - c_nodes = add_nodes.get(c, [c]) - p_split = add_split.get(p, np.empty(0)) - c_split = add_split.get(c, np.empty(0)) - e_split = np.unique(np.append(p_split, c_split)) - e_split = e_split[np.logical_and(e_split > left, e_split < right)] - - e_split = np.append(e_split, right) - p_index = np.searchsorted(p_split, e_split, side="left") - c_index = np.searchsorted(c_split, e_split, side="left") - for x, i, j in zip(e_split, p_index, c_index): - new_p, new_c = p_nodes[i], c_nodes[j] - if ( - left == new_left[e] - ): # segment left of first break retains original edge ID - new_parent[e] = new_p - new_child[e] = new_c - new_left[e] = left - new_right[e] = x - else: - new_parent.append(new_p) - new_child.append(new_c) - new_left.append(left) - new_right.append(x) - left = x - assert left == right - - nodes_order = np.array(nodes_order, dtype=np.int32) - new_parent = np.array(new_parent, dtype=np.int32) - new_child = np.array(new_child, dtype=np.int32) - new_left = np.array(new_left, dtype=np.float64) - new_right = np.array(new_right, dtype=np.float64) + num_edges = edges_parent.size + nodes_time = nodes_time.copy() + edges_cavity = np.zeros((num_edges, 2)) + for _ in range(max_iterations): # method of alternating projections + if np.all(nodes_time[edges_parent] - nodes_time[edges_child] > 0): + return nodes_time + for e in range(num_edges): + 
p, c = edges_parent[e], edges_child[e] + nodes_time[c] -= edges_cavity[e, 0] + nodes_time[p] -= edges_cavity[e, 1] + adjustment = nodes_time[c] - nodes_time[p] # + epsilon + edges_cavity[e, :] = 0.0 + if adjustment > 0: + assert not nodes_fixed[p] # TODO: no reason not to support this + edges_cavity[e, 0] = 0 if nodes_fixed[c] else -adjustment / 2 + edges_cavity[e, 1] = adjustment if nodes_fixed[c] else adjustment / 2 + nodes_time[c] += edges_cavity[e, 0] + nodes_time[p] += edges_cavity[e, 1] + # print( + # "min length:", np.min(nodes_time[edges_parent] - nodes_time[edges_child]) + # ) + for e in range(num_edges): # force constraint + p, c = edges_parent[e], edges_child[e] + if nodes_time[c] >= nodes_time[p]: + nodes_time[p] = nodes_time[c] + epsilon - return new_parent, new_child, new_left, new_right, nodes_order + return nodes_time -def split_root_nodes(ts): +def constrain_ages(ts, nodes_time, epsilon=1e-6, max_iterations=0): """ - Split roots whenever the set of children changes. Nodes are only split in the - interior of intervals where they are roots. + Use a hybrid approach to adjust node times such that branch lengths are + positive. The first pass iteratively solves a constrained least squares + problem that seeks to find constrained ages as close as possible to + unconstrained ages. Progress is initially fast but typically becomes quite + slow, so after a fixed number of iterations the iterative algorithm + terminates and the constraint is forced. + + :param tskit.TreeSequence ts: The input tree sequence, with arbitrary node + times. + :param np.ndarray nodes_time: Unconstrained node ages to inject into the + tree sequence. + :param float epsilon: The minimum allowed branch length when forcing + positive branch lengths. + :param int max_iterations: The number of iterations of alternating + projections before forcing positive branch lengths. + + :return np.ndarray: Constrained node ages """ - edges_parent, edges_child, edges_left, edges_right, nodes_order = _split_root_nodes( - ts + assert nodes_time.size == ts.num_nodes + assert epsilon >= 0 + assert max_iterations >= 0 + + node_is_sample = np.bitwise_and(ts.nodes_flags, tskit.NODE_IS_SAMPLE).astype(bool) + constrained_nodes_time = _constrain_ages( + nodes_time, + node_is_sample, + ts.edges_parent, + ts.edges_child, + epsilon, + max_iterations, ) + modified = np.sum(~np.isclose(nodes_time, constrained_nodes_time)) + if modified: + logging.info(f"Modified ages of {modified} nodes to satisfy constraints") - # TODO: correctly handle mutations above root (m.edge == tskit.NULL) - mutations_node = ts.mutations_node.copy() - for m in ts.mutations(): - if m.edge != tskit.NULL: - mutations_node[m.id] = edges_child[m.edge] + return constrained_nodes_time - tables = ts.dump_tables() - tables.nodes.set_columns( - flags=tables.nodes.flags[nodes_order], - time=tables.nodes.time[nodes_order], - individual=tables.nodes.individual[nodes_order], - population=tables.nodes.population[nodes_order], - ) - # TODO: copy existing metadata for original nodes - # TODO: add new metadata indicating origin for split nodes - # TODO: add flag for split nodes - tables.edges.set_columns( - parent=edges_parent, - child=edges_child, - left=edges_left, - right=edges_right, - ) - tables.mutations.node = mutations_node - tables.sort() - tables.edges.squash() - tables.sort() +def constrain_mutations(ts, nodes_time, mutations_edge): + """ + If the mutation is above a root, its age set to the age of the root. 
If + the mutation is between two internal nodes, the edge midpoint is used. - return tables.tree_sequence() + :param tskit.TreeSequence ts: The input tree sequence, with arbitrary node + times. + :param np.ndarray nodes_time: Constrained node ages. + :param np.ndarray mutations_edge: The edge that each mutation falls on. + + :return np.ndarray: Constrained mutation ages + """ + + parent = ts.edges_parent[mutations_edge] + child = ts.edges_child[mutations_edge] + parent_time = nodes_time[parent] + child_time = nodes_time[child] + assert np.all(parent_time > child_time), "Negative branch lengths" + + mutations_time = (child_time + parent_time) / 2 + internal = mutations_edge != tskit.NULL + constrained_time = np.full(mutations_time.size, tskit.UNKNOWN_TIME) + constrained_time[internal] = mutations_time[internal] + constrained_time[~internal] = nodes_time[ts.mutations_node[~internal]] + + external = np.sum(~internal) + if external: + logging.info(f"Set ages of {external} nonsegregating mutations to root times.") + + return constrained_time diff --git a/tsdate/variational.py b/tsdate/variational.py index cabcce48..6fafabce 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -23,22 +23,31 @@ """ Expectation propagation implementation """ +import logging +import time + import numba import numpy as np +import tskit from numba.types import void as _void from tqdm.auto import tqdm from . import approx -from . import mixture from .approx import _b +from .approx import _b1r from .approx import _f from .approx import _f1r from .approx import _f1w from .approx import _f2r from .approx import _f2w +from .approx import _f3r from .approx import _f3w from .approx import _i from .approx import _i1r +from .hypergeo import _gammainc_inv as gammainc_inv +from .normalisation import edge_sampling_weight +from .normalisation import mutational_timescale +from .normalisation import piecewise_scale_posterior # columns for edge_factors @@ -144,7 +153,7 @@ def _check_valid_constraints(constraints, edges_parent, edges_child): ) @staticmethod - def _check_valid_inputs(ts, likelihoods, constraints, prior): + def _check_valid_inputs(ts, likelihoods, constraints, mutations_edge): if likelihoods.shape != (ts.num_edges, 2): raise ValueError("Edge likelihoods are the wrong shape") if constraints.shape != (ts.num_nodes, 2): @@ -153,6 +162,8 @@ def _check_valid_inputs(ts, likelihoods, constraints, prior): raise ValueError("Edge likelihoods contains negative values") if np.any(constraints < 0.0): raise ValueError("Node age constraints contain negative values") + if mutations_edge.size > 0 and mutations_edge.max() >= ts.num_edges: + raise ValueError("Mutation edge indices are out-of-bounds") ExpectationPropagation._check_valid_constraints( constraints, ts.edges_parent, ts.edges_child ) @@ -170,7 +181,20 @@ def _check_valid_state( posterior_check += node_factors[:, CONSTRNT] return np.allclose(posterior_check, posterior) - def __init__(self, ts, likelihoods, constraints, prior): + @staticmethod + @numba.njit(_f1w(_f2r, _f2r, _b)) + def _point_estimate(posteriors, constraints, median): + assert posteriors.shape == constraints.shape + fixed = constraints[:, 0] == constraints[:, 1] + point_estimate = np.zeros(posteriors.shape[0]) + for i in np.flatnonzero(~fixed): + alpha, beta = posteriors[i] + point_estimate[i] = gammainc_inv(alpha + 1, 0.5) if median else (alpha + 1) + point_estimate[i] /= beta + point_estimate[fixed] = constraints[fixed, 0] + return point_estimate + + def __init__(self, ts, likelihoods, constraints, 
mutations_edge): """ Initialize an expectation propagation algorithm for dating nodes in a tree sequence. @@ -189,51 +213,51 @@ def __init__(self, ts, likelihoods, constraints, prior): :param ~np.ndarray likelihoods: a `ts.num_edges`-by-two array containing mutation counts and mutational spans (e.g. edge span multiplied by mutation rate) per edge. - :param ~np.ndarray prior: a `K`-by-three array containing parameters of - an i.i.d. gamma mixture prior with `K` components, used for all - nonfixed nodes. Each row contains the weight, shape, and rate - for a mixture component. + :param ~np.ndarray mutations_edge: an array containing edge indices + (one per mutation) for which to compute posteriors. """ - self._check_valid_inputs(ts, likelihoods, constraints, prior) + # TODO: pass in edge table rather than tree sequence + # TODO: check valid mutations_edge + self._check_valid_inputs(ts, likelihoods, constraints, mutations_edge) # const self.parents = ts.edges_parent self.children = ts.edges_child self.likelihoods = likelihoods self.constraints = constraints + self.mutations_edge = mutations_edge # mutable - self.prior = prior.copy() self.node_factors = np.zeros((ts.num_nodes, 2, 2)) self.edge_factors = np.zeros((ts.num_edges, 2, 2)) - # self.node_lognorm = np.zeros(ts.num_nodes) - # self.edge_lognorm = np.zeros(ts.num_nodes) self.posterior = np.zeros((ts.num_nodes, 2)) self.log_partition = np.zeros(ts.num_edges) self.scale = np.ones(ts.num_nodes) - # get edge traversal order - node_is_fixed = constraints[:, LOWER] == constraints[:, UPPER] - child_is_contemporary = np.logical_and( - constraints[ts.edges_child, LOWER] == 0.0, - node_is_fixed[ts.edges_child], - ) + # terminal nodes + has_parent = np.full(ts.num_nodes, False) + has_child = np.full(ts.num_nodes, False) + has_parent[self.children] = True + has_child[self.parents] = True + self.roots = np.logical_and(has_child, ~has_parent) + self.leaves = np.logical_and(~has_child, has_parent) + if np.any(np.logical_and(~has_child, ~has_parent)): + raise ValueError("Tree sequence contains disconnected nodes") + + # edge traversal order edges = np.arange(ts.num_edges, dtype=np.int32) - contemp = edges[child_is_contemporary] - noncontemp = edges[~child_is_contemporary] - self.edge_order = np.concatenate( # rootward + leafward - (noncontemp[:-1], np.flip(noncontemp)) + self.edge_order = np.concatenate((edges[:-1], np.flip(edges))) + self.edge_weights = edge_sampling_weight( + self.leaves, + ts.edges_parent, + ts.edges_child, + ts.edges_left, + ts.edges_right, + ts.indexes_edge_insertion_order, + ts.indexes_edge_removal_order, ) - # edges attached to contemporary nodes are visited once - for i in contemp: - p, c = ts.edges_parent[i], ts.edges_child[i] - assert np.all(constraints[c] == 0.0) - self.edge_factors[i, ROOTWARD] = self.likelihoods[i] - self.posterior[p] += self.likelihoods[i] - # self.node_lognorm[i] += ... # TODO - @staticmethod @numba.njit(_f(_i1r, _i1r, _i1r, _f2r, _f2r, _f2w, _f3w, _f1w, _f1w, _f, _f, _b)) def propagate_likelihood( @@ -253,8 +277,6 @@ def propagate_likelihood( """ Update approximating factors for Poisson mutation likelihoods on edges. 
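# The edge likelihood being propagated here is Poisson in the branch length
# (illustrative, not part of this patch): for an edge with `count` observed
# mutations and mutational span `mu_span` (edge span times mutation rate, i.e.
# the two columns of `likelihoods`), the log-likelihood of parent/child ages is:
from scipy.stats import poisson

def edge_loglik(t_parent, t_child, count, mu_span):
    return poisson.logpmf(count, mu_span * (t_parent - t_child))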
- TODO: return max difference in natural parameters for stopping criterion - :param ndarray edges_parent: integer array of parent ids per edge :param ndarray edges_child: integer array of child ids per edge :param ndarray likelihoods: array of dimension `[num_edges, 2]` @@ -368,18 +390,16 @@ def posterior_damping(x): return np.nan @staticmethod - @numba.njit(_f(_f2w, _f2w, _f2w, _f3w, _f1w, _f, _i, _f)) + @numba.njit(_f(_b1r, _f2w, _f3w, _f1w, _f, _i, _f)) def propagate_prior( - prior, constraints, posterior, factors, scale, max_shape, em_maxitt, em_reltol + free, posterior, factors, scale, max_shape, em_maxitt, em_reltol ): """ - Update approximating factors for global prior at each node. + Update approximating factors for global prior. - :param ndarray constraints: rows are nodes, columns are upper and - lower bounds for node age. - :param ndarray prior: rows are mixture components, columns are - zeroth, first, and second natural parameters of gamma mixture - components. Updated in place. + :param ndarray free: boolean array indicating if prior should be + applied to node + :param ndarray penalty: initial value for regularisation penalty :param ndarray posterior: rows are nodes, columns are first and second natural parameters of gamma posteriors. Updated in place. @@ -389,47 +409,37 @@ def propagate_prior( scaling factor for the posteriors, updated in-place. :param float max_shape: the maximum allowed shape for node posteriors. :param int em_maxitt: the maximum number of EM iterations to use when - fitting the mixture model. + fitting the regularisation. :param int em_reltol: the termination criterion for relative change in log-likelihood. """ - assert prior.shape[1] == 3 - assert constraints.shape == posterior.shape - assert factors.shape == (constraints.shape[0], 2, 2) - assert scale.size == constraints.shape[0] + assert free.size == posterior.shape[0] + assert factors.shape == (free.size, 2, 2) + assert scale.size == free.size assert max_shape >= 1.0 - if prior.shape[0] == 0: - return 0.0 - def posterior_damping(x): return _rescale(x, max_shape) - lognorm = np.zeros(constraints.shape[0]) # TODO: move to member - - # fit a mixture-of-gamma model to cavity distributions for unconstrained nodes - free = np.logical_and( - constraints[:, LOWER] == 0.0, constraints[:, UPPER] == np.inf - ) + # fit an exponential to cavity distributions for unconstrained nodes cavity = posterior - factors[:, MIXPRIOR] * scale[:, np.newaxis] - prior[:], posterior[free], lognorm[free] = mixture.fit_gamma_mixture( - prior, cavity[free], em_maxitt, em_reltol, False - ) - - # reset nodes that were skipped (b/c of improper posteriors) - skipped = np.logical_and(free, ~np.isfinite(lognorm)) - posterior[skipped] = ( - cavity[skipped] + factors[skipped, MIXPRIOR] * scale[skipped, np.newaxis] - ) - - # the remaining nodes may be updated - updated = np.logical_and(free, np.isfinite(lognorm)) - factors[updated, MIXPRIOR] = (posterior[updated] - cavity[updated]) / scale[ - updated, np.newaxis + shape, rate = cavity[free, 0] + 1, cavity[free, 1] + penalty = 1 / np.mean(shape / rate) + itt, delta = 0, np.inf + while abs(delta) > abs(penalty) * em_reltol: + if itt > em_maxitt: + break + delta = 1 / np.mean(shape / (rate + penalty)) - penalty + penalty += delta + itt += 1 + assert penalty > 0 + + # update posteriors and rescale to keep shape bounded + posterior[free, 1] = cavity[free, 1] + penalty + factors[free, MIXPRIOR] = (posterior[free] - cavity[free]) / scale[ + free, np.newaxis ] - - # rescale posterior to keep shape 
bounded for i in np.flatnonzero(free): eta = posterior_damping(posterior[i]) posterior[i] *= eta @@ -438,62 +448,93 @@ def posterior_damping(x): return np.nan @staticmethod - @numba.njit(_f(_f2r, _f2w, _f3w, _f1w, _f, _f, _b)) - def propagate_constraints( - constraints, posterior, factors, scale, max_shape, min_step, min_kl + @numba.njit(_f2w(_i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r, _b)) + def propagate_mutations( + mutations_edge, + edges_parent, + edges_child, + likelihoods, + constraints, + posterior, + factors, + scale, + min_kl, ): """ - Update approximating factors for node age constraints (indicator - functions) at each node. + Calculate posteriors for mutations. - :param ndarray constraints: rows are nodes, columns are lower and - upper bounds for age. - :param ndarray posterior: rows are nodes, columns are first and - second natural parameters of gamma posteriors. Updated in - place. - :param ndarray factors: rows are edges, columns are different - types of updates. Updated in place. + :param ndarray mutations_edge: integer array giving edge for each + mutation + :param ndarray edges_parent: integer array of parent ids per edge + :param ndarray edges_child: integer array of child ids per edge + :param ndarray likelihoods: array of dimension `[num_edges, 2]` + containing mutation count and mutational target size per edge. + :param ndarray constraints: array of dimension `[num_nodes, 2]` + containing lower and upper bounds for each node. + :param ndarray posterior: array of dimension `[num_nodes, 2]` + containing natural parameters for each node, updated in-place. + :param ndarray factors: array of dimension `[num_edges, 2, 2]` + containing parent and child factors (natural parameters) for each + edge, updated in-place. :param ndarray scale: array of dimension `[num_nodes]` containing a scaling factor for the posteriors, updated in-place. - :param float max_shape: the maximum allowed shape for node posteriors. - :param float min_step: the minimum allowed step size in (0, 1). :param bool min_kl: minimize KL divergence or match central moments. """ - assert constraints.shape == posterior.shape - assert factors.shape == (constraints.shape[0], 2, 2) - assert scale.size == constraints.shape[0] - assert max_shape >= 1.0 - assert 0.0 < min_step < 1.0 - - def cavity_damping(x, y): - return _damp(x, y, min_step) + # TODO: scale should be 1.0, can we delete + # TODO: we don't seem to need to damp? 
+ # TODO: might as well copy format in other functions and have void return - def posterior_damping(x): - return _rescale(x, max_shape) - - lognorm = np.zeros(constraints.shape[0]) # TODO: move to member - - bounded = np.logical_or( - constraints[:, LOWER] > 0.0, constraints[:, UPPER] < np.inf - ) + assert constraints.shape == posterior.shape + assert edges_child.size == edges_parent.size + assert factors.shape == (edges_parent.size, 2, 2) + assert likelihoods.shape == (edges_parent.size, 2) - for i in np.flatnonzero(bounded): - if constraints[i, LOWER] == constraints[i, UPPER]: + mutations_posterior = np.zeros((mutations_edge.size, 2)) + fixed = constraints[:, LOWER] == constraints[:, UPPER] + for m, i in enumerate(mutations_edge): + if i == tskit.NULL: # skip mutations above root + mutations_posterior[m] = np.nan continue - message = factors[i, CONSTRNT] * scale[i] - delta = cavity_damping(posterior[i], message) - cavity = posterior[i] - delta * message - lognorm[i], posterior[i] = approx.truncated_projection( - constraints[i], cavity, min_kl - ) - factors[i, CONSTRNT] *= 1.0 - delta - factors[i, CONSTRNT] += (posterior[i] - cavity) / scale[i] - eta = posterior_damping(posterior[i]) - posterior[i] *= eta - scale[i] *= eta + p, c = edges_parent[i], edges_child[i] + if fixed[p] and fixed[c]: + child_age = constraints[c, 0] + parent_age = constraints[p, 0] + mean = 1 / 2 * (child_age + parent_age) + variance = 1 / 12 * (parent_age - child_age) ** 2 + mutations_posterior[m] = approx.approximate_gamma_mom(mean, variance) + elif fixed[p] and not fixed[c]: + child_message = factors[i, LEAFWARD] * scale[c] + child_delta = 1.0 # hopefully we don't need to damp + child_cavity = posterior[c] - child_delta * child_message + edge_likelihood = child_delta * likelihoods[i] + parent_age = constraints[p, LOWER] + mutations_posterior[m] = approx.mutation_leafward_projection( + parent_age, child_cavity, edge_likelihood, min_kl + ) + elif fixed[c] and not fixed[p]: + parent_message = factors[i, ROOTWARD] * scale[p] + parent_delta = 1.0 # hopefully we don't need to damp + parent_cavity = posterior[p] - parent_delta * parent_message + edge_likelihood = parent_delta * likelihoods[i] + child_age = constraints[c, LOWER] + mutations_posterior[m] = approx.mutation_rootward_projection( + child_age, parent_cavity, edge_likelihood, min_kl + ) + else: + parent_message = factors[i, ROOTWARD] * scale[p] + child_message = factors[i, LEAFWARD] * scale[c] + parent_delta = 1.0 # hopefully we don't need to damp + child_delta = 1.0 # hopefully we don't need to damp + delta = min(parent_delta, child_delta) + parent_cavity = posterior[p] - delta * parent_message + child_cavity = posterior[c] - delta * child_message + edge_likelihood = delta * likelihoods[i] + mutations_posterior[m] = approx.mutation_gamma_projection( + parent_cavity, child_cavity, edge_likelihood, min_kl + ) - return np.nan + return mutations_posterior @staticmethod @numba.njit(_void(_i1r, _i1r, _f3w, _f3w, _f1w)) @@ -507,26 +548,16 @@ def rescale_factors(edges_parent, edges_child, node_factors, edge_factors, scale scale[:] = 1.0 def iterate( - self, em_maxitt=10, em_reltol=1e-6, max_shape=1000, min_step=0.1, min_kl=True + self, + *, + max_shape=1000, + min_step=0.1, + em_maxitt=100, + em_reltol=1e-8, + min_kl=False, + regularise=True, + check_valid=False, ): - """ - Update approximating factors. 
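# Worked example for the "both endpoints fixed" branch of propagate_mutations
# above (illustrative, not part of this patch): the mutation age is uniform on
# the branch, and a gamma is fitted to its mean and variance. Whether
# approx.approximate_gamma_mom returns (shape, rate) or natural parameters is
# an implementation detail; standard moment matching looks like:
import numpy as np

def gamma_from_moments(mean, variance):
    rate = mean / variance        # mean = shape / rate, variance = shape / rate**2
    shape = mean * rate
    return shape, rate

child_age, parent_age = 10.0, 30.0
mean = (child_age + parent_age) / 2             # 20.0
variance = (parent_age - child_age) ** 2 / 12   # variance of a uniform on the branch
gamma_from_moments(mean, variance)              # -> (12.0, 0.6)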
- - Returns the approximate log marginal likelihood (TODO) - """ - - # mixture prior over unconstrained nodes - self.propagate_prior( - self.prior, - self.constraints, - self.posterior, - self.node_factors, - self.scale, - max_shape, - em_maxitt, - em_reltol, - ) - # rootward + leafward pass through edges self.propagate_likelihood( self.edge_order, @@ -543,16 +574,17 @@ def iterate( min_kl, ) - # upper and lower bounds on node age - # self.propagate_constraints( - # self.constraints, - # self.posterior, - # self.node_factors, - # self.scale, - # max_shape, - # min_step, - # min_kl, - # ) + # exponential regularization on roots + if regularise: + self.propagate_prior( + self.roots, + self.posterior, + self.node_factors, + self.scale, + max_shape, + em_maxitt, + em_reltol, + ) # absorb the scaling term into the factors self.rescale_factors( @@ -563,32 +595,104 @@ def iterate( self.scale, ) - # for debugging - # assert self._check_valid_state( - # self.parents, self.children, self.posterior, - # self.node_factors, self.edge_factors, - # ) + if check_valid: # for debugging + assert self._check_valid_state( + self.parents, + self.children, + self.posterior, + self.node_factors, + self.edge_factors, + ) return np.nan # TODO: placeholder for marginal likelihood + def normalise( + self, + *, + norm_intervals=1000, + norm_segsites=False, + use_median=False, + quantile_width=0.5, + ): + """Normalise posteriors so that empirical mutation rate is constant""" + edge_weights = ( + np.ones(self.edge_weights.size) if norm_segsites else self.edge_weights + ) + nodes_time = self._point_estimate(self.posterior, self.constraints, use_median) + original_breaks, rescaled_breaks = mutational_timescale( + nodes_time, + self.likelihoods, + self.constraints, + self.parents, + self.children, + edge_weights, + norm_intervals, + ) + self.posterior[:] = piecewise_scale_posterior( + self.posterior, + original_breaks, + rescaled_breaks, + quantile_width, + use_median, + ) + self.mutations_posterior[:] = piecewise_scale_posterior( + self.mutations_posterior, + original_breaks, + rescaled_breaks, + quantile_width, + use_median, + ) + def run( self, *, - ep_maxitt=20, - em_maxitt=10, + ep_maxitt=10, max_shape=1000, min_step=0.1, - min_kl=True, - progress=None + min_kl=False, + norm_intervals=1000, + norm_segsites=False, + regularise=True, + progress=None, ): - for itt in tqdm( + nodes_timing = time.time() + for _ in tqdm( np.arange(ep_maxitt), desc="Expectation Propagation", disable=not progress, ): self.iterate( - em_maxitt=em_maxitt if itt else 0, max_shape=max_shape, min_step=min_step, min_kl=min_kl, + regularise=regularise, ) + nodes_timing -= time.time() + skipped_nodes = np.sum(np.isnan(self.log_partition)) + if skipped_nodes: + logging.info(f"Skipped {skipped_nodes} nodes with invalid posteriors") + logging.info(f"Calculated node posteriors in {abs(nodes_timing)} seconds") + + muts_timing = time.time() + self.mutations_posterior = self.propagate_mutations( + self.mutations_edge, + self.parents, + self.children, + self.likelihoods, + self.constraints, + self.posterior, + self.edge_factors, + self.scale, + min_kl, + ) + muts_timing -= time.time() + skipped_muts = np.sum(np.isnan(self.mutations_posterior[:, 0])) + if skipped_muts: + logging.info(f"Skipped {skipped_muts} mutations with invalid posteriors") + logging.info(f"Calculated mutation posteriors in {abs(muts_timing)} seconds") + + if norm_intervals > 0: + norm_timing = time.time() + self.normalise(norm_intervals=norm_intervals, norm_segsites=norm_segsites) + 
norm_timing -= time.time() + logging.info(f"Timescale normalised in {abs(norm_timing)} seconds") From a9e1f835b124f6de19971c6338ee07b8ae556208 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Fri, 10 May 2024 11:10:05 -0700 Subject: [PATCH 2/3] Rename normalisation --- tests/test_cli.py | 6 +- tsdate/__init__.py | 1 - tsdate/cli.py | 4 +- tsdate/core.py | 16 +-- tsdate/{normalisation.py => rescaling.py} | 141 +++++++++++----------- tsdate/variational.py | 32 ++--- 6 files changed, 101 insertions(+), 99 deletions(-) rename tsdate/{normalisation.py => rescaling.py} (78%) diff --git a/tests/test_cli.py b/tests/test_cli.py index 9067b9bf..53eb644e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -248,7 +248,7 @@ def test_verbosity(self, tmp_path, caplog, flag, log_status): ) def test_no_progress(self, method, tmp_path, capfd): input_ts = msprime.simulate(4, random_seed=123) - params = f"-m 0.1 --method {method} --normalisation-intervals 0" + params = f"-m 0.1 --method {method} --rescaling-intervals 0" self.run_tsdate_cli(tmp_path, input_ts, f"{self.popsize} {params}") (out, err) = capfd.readouterr() assert out == "" @@ -257,7 +257,7 @@ def test_no_progress(self, method, tmp_path, capfd): def test_progress(self, tmp_path, capfd): input_ts = msprime.simulate(4, random_seed=123) - params = "--method inside_outside --progress --normalisation-intervals 0" + params = "--method inside_outside --progress --rescaling-intervals 0" self.run_tsdate_cli(tmp_path, input_ts, f"{self.popsize} {params}") (out, err) = capfd.readouterr() assert out == "" @@ -277,7 +277,7 @@ def test_progress(self, tmp_path, capfd): def test_iterative_progress(self, tmp_path, capfd): input_ts = msprime.simulate(4, random_seed=123) params = "--method variational_gamma --mutation-rate 1e-8 " - params += "--progress --normalisation-intervals 0" + params += "--progress --rescaling-intervals 0" self.run_tsdate_cli(tmp_path, input_ts, f"{self.popsize} {params}") (out, err) = capfd.readouterr() assert out == "" diff --git a/tsdate/__init__.py b/tsdate/__init__.py index 220c1c27..ba79af6f 100644 --- a/tsdate/__init__.py +++ b/tsdate/__init__.py @@ -24,7 +24,6 @@ from .core import inside_outside # NOQA: F401 from .core import maximization # NOQA: F401 from .core import variational_gamma # NOQA: F401 -from .normalisation import normalise_tree_sequence as normalise # NOQA: F401 from .prior import parameter_grid as build_parameter_grid # NOQA: F401 from .prior import prior_grid as build_prior_grid # NOQA: F401 from .provenance import __version__ # NOQA: F401 diff --git a/tsdate/cli.py b/tsdate/cli.py index 24bc4429..b30aef00 100644 --- a/tsdate/cli.py +++ b/tsdate/cli.py @@ -192,7 +192,7 @@ def tsdate_cli_parser(): default=1000, ) parser.add_argument( - "--normalisation-intervals", + "--rescaling-intervals", type=float, help=( "The number of time intervals within which to estimate a time " @@ -265,7 +265,7 @@ def run_date(args): progress=args.progress, max_iterations=args.max_iterations, max_shape=args.max_shape, - normalisation_intervals=args.normalisation_intervals, + rescaling_intervals=args.rescaling_intervals, ) else: params = dict( diff --git a/tsdate/core.py b/tsdate/core.py index a76c0e70..573038a6 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1231,7 +1231,7 @@ def run( max_iterations, max_shape, match_central_moments, - normalisation_intervals, + rescaling_intervals, match_segregating_sites, regularise_roots, ): @@ -1251,9 +1251,9 @@ def run( ep_maxitt=max_iterations, max_shape=max_shape, min_kl=min_kl, - 
norm_intervals=normalisation_intervals, + rescale_intervals=rescaling_intervals, regularise=regularise_roots, - norm_segsites=match_segregating_sites, + rescale_segsites=match_segregating_sites, progress=self.pbar, ) @@ -1476,7 +1476,7 @@ def variational_gamma( eps=None, max_iterations=None, max_shape=None, - normalisation_intervals=None, + rescaling_intervals=None, match_central_moments=None, # undocumented match_segregating_sites=None, # undocumented regularise_roots=None, # undocumented @@ -1505,7 +1505,7 @@ def variational_gamma( :param float max_shape: The maximum value for the shape parameter in the variational posteriors. This is equivalent to the maximum precision (inverse variance) on a logarithmic scale. Default: None, treated as 1000. - :param float normalisation_intervals: For normalisation, the number of time + :param float rescaling_intervals: For time rescaling, the number of time intervals within which to estimate a rescaling parameter. Default None, treated as 1000. :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper @@ -1537,8 +1537,8 @@ def variational_gamma( max_iterations = 10 if max_shape is None: max_shape = 1000 - if normalisation_intervals is None: - normalisation_intervals = 1000 + if rescaling_intervals is None: + rescaling_intervals = 1000 if match_central_moments is None: match_central_moments = True if match_segregating_sites is None: @@ -1552,7 +1552,7 @@ def variational_gamma( max_iterations=max_iterations, max_shape=max_shape, match_central_moments=match_central_moments, - normalisation_intervals=normalisation_intervals, + rescaling_intervals=rescaling_intervals, match_segregating_sites=match_segregating_sites, regularise_roots=regularise_roots, ) diff --git a/tsdate/normalisation.py b/tsdate/rescaling.py similarity index 78% rename from tsdate/normalisation.py rename to tsdate/rescaling.py index 7f730224..0d6dbade 100644 --- a/tsdate/normalisation.py +++ b/tsdate/rescaling.py @@ -42,7 +42,7 @@ from .approx import _i1w from .approx import approximate_gamma_iqr from .hypergeo import _gammainc_inv as gammainc_inv -from .util import mutation_span_array +from .util import mutation_span_array # NOQA: F401 @numba.njit(_i1w(_f1r, _i)) @@ -344,72 +344,73 @@ def edge_sampling_weight( return edges_leaves -def normalise_tree_sequence( - ts, mutation_rate, *, normalisation_intervals=1000, match_segregating_sites=False -): - """ - Adjust the time scaling of a tree sequence so that expected mutational area - matches the expected number of mutations on a path from leaf to root, where - the expectation is taken over all paths and bases in the sequence. 
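# Usage after this rename (illustrative, not part of this patch): the CLI flag
# becomes --rescaling-intervals and the Python keyword becomes
# `rescaling_intervals`; a value of 0 skips the rescaling step entirely.
import msprime
import tsdate

ts = msprime.sim_ancestry(
    10, population_size=1e4, sequence_length=1e5,
    recombination_rate=1e-8, random_seed=1,
)
ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1)
dated = tsdate.variational_gamma(ts, mutation_rate=1e-8, rescaling_intervals=100)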
- - :param tskit.TreeSequence ts: the tree sequence to normalise - :param float mutation_rate: the per-base mutation rate - :param int normalisation_intervals: the number of time intervals for which - to estimate a separate time rescaling parameter - :param bool match_segregating_sites: if True, match the total number of - mutations rather than the average number of differences from the ancestral - state - """ - if match_segregating_sites: - edge_weights = np.ones(ts.num_edges) - else: - has_parent = np.full(ts.num_nodes, False) - has_child = np.full(ts.num_nodes, False) - has_parent[ts.edges_child] = True - has_child[ts.edges_parent] = True - is_leaf = np.logical_and(~has_child, has_parent) - edge_weights = edge_sampling_weight( - is_leaf, - ts.edges_parent, - ts.edges_child, - ts.edges_left, - ts.edges_right, - ts.indexes_edge_insertion_order, - ts.indexes_edge_removal_order, - ) - # estimate time rescaling parameter within intervals - samples = list(ts.samples()) - if not np.all(ts.nodes_time[samples] == 0.0): - raise ValueError("Normalisation not implemented for ancient samples") - constraints = np.zeros((ts.num_nodes, 2)) - constraints[:, 1] = np.inf - constraints[samples, :] = ts.nodes_time[samples, np.newaxis] - mutations_span, mutations_edge = mutation_span_array(ts) - mutations_span[:, 1] *= mutation_rate - original_breaks, rescaled_breaks = mutational_timescale( - ts.nodes_time, - mutations_span, - constraints, - ts.edges_parent, - ts.edges_child, - edge_weights, - normalisation_intervals, - ) - # rescale node time - assert np.all(np.diff(rescaled_breaks) > 0) - assert np.all(np.diff(original_breaks) > 0) - scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0) - idx = np.searchsorted(original_breaks, ts.nodes_time, "right") - 1 - nodes_time = rescaled_breaks[idx] + scalings[idx] * ( - ts.nodes_time - original_breaks[idx] - ) - # calculate mutation time - mutations_parent = ts.edges_parent[mutations_edge] - mutations_child = ts.edges_child[mutations_edge] - mutations_time = (nodes_time[mutations_parent] + nodes_time[mutations_child]) / 2 - above_root = mutations_edge == tskit.NULL - mutations_time[above_root] = nodes_time[mutations_child[above_root]] - tables = ts.dump_tables() - tables.nodes.time = nodes_time - tables.mutations.time = mutations_time - return tables.tree_sequence() +# TODO: standalone API for rescaling +# def rescale_tree_sequence( +# ts, mutation_rate, *, rescaling_intervals=1000, match_segregating_sites=False +# ): +# """ +# Adjust the time scaling of a tree sequence so that expected mutational area +# matches the expected number of mutations on a path from leaf to root, where +# the expectation is taken over all paths and bases in the sequence. 
+# +# :param tskit.TreeSequence ts: the tree sequence to rescale +# :param float mutation_rate: the per-base mutation rate +# :param int rescaling_intervals: the number of time intervals for which +# to estimate a separate time rescaling parameter +# :param bool match_segregating_sites: if True, match the total number of +# mutations rather than the average number of differences from the ancestral +# state +# """ +# if match_segregating_sites: +# edge_weights = np.ones(ts.num_edges) +# else: +# has_parent = np.full(ts.num_nodes, False) +# has_child = np.full(ts.num_nodes, False) +# has_parent[ts.edges_child] = True +# has_child[ts.edges_parent] = True +# is_leaf = np.logical_and(~has_child, has_parent) +# edge_weights = edge_sampling_weight( +# is_leaf, +# ts.edges_parent, +# ts.edges_child, +# ts.edges_left, +# ts.edges_right, +# ts.indexes_edge_insertion_order, +# ts.indexes_edge_removal_order, +# ) +# # estimate time rescaling parameter within intervals +# samples = list(ts.samples()) +# if not np.all(ts.nodes_time[samples] == 0.0): +# raise ValueError("Normalisation not implemented for ancient samples") +# constraints = np.zeros((ts.num_nodes, 2)) +# constraints[:, 1] = np.inf +# constraints[samples, :] = ts.nodes_time[samples, np.newaxis] +# mutations_span, mutations_edge = mutation_span_array(ts) +# mutations_span[:, 1] *= mutation_rate +# original_breaks, rescaled_breaks = mutational_timescale( +# ts.nodes_time, +# mutations_span, +# constraints, +# ts.edges_parent, +# ts.edges_child, +# edge_weights, +# rescaling_intervals, +# ) +# # rescale node time +# assert np.all(np.diff(rescaled_breaks) > 0) +# assert np.all(np.diff(original_breaks) > 0) +# scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0) +# idx = np.searchsorted(original_breaks, ts.nodes_time, "right") - 1 +# nodes_time = rescaled_breaks[idx] + scalings[idx] * ( +# ts.nodes_time - original_breaks[idx] +# ) +# # calculate mutation time +# mutations_parent = ts.edges_parent[mutations_edge] +# mutations_child = ts.edges_child[mutations_edge] +# mutations_time = (nodes_time[mutations_parent] + nodes_time[mutations_child]) / 2 +# above_root = mutations_edge == tskit.NULL +# mutations_time[above_root] = nodes_time[mutations_child[above_root]] +# tables = ts.dump_tables() +# tables.nodes.time = nodes_time +# tables.mutations.time = mutations_time +# return tables.tree_sequence() diff --git a/tsdate/variational.py b/tsdate/variational.py index 6fafabce..f1a6b197 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -45,9 +45,9 @@ from .approx import _i from .approx import _i1r from .hypergeo import _gammainc_inv as gammainc_inv -from .normalisation import edge_sampling_weight -from .normalisation import mutational_timescale -from .normalisation import piecewise_scale_posterior +from .rescaling import edge_sampling_weight +from .rescaling import mutational_timescale +from .rescaling import piecewise_scale_posterior # columns for edge_factors @@ -606,17 +606,17 @@ def iterate( return np.nan # TODO: placeholder for marginal likelihood - def normalise( + def rescale( self, *, - norm_intervals=1000, - norm_segsites=False, + rescale_intervals=1000, + rescale_segsites=False, use_median=False, quantile_width=0.5, ): """Normalise posteriors so that empirical mutation rate is constant""" edge_weights = ( - np.ones(self.edge_weights.size) if norm_segsites else self.edge_weights + np.ones(self.edge_weights.size) if rescale_segsites else self.edge_weights ) nodes_time = self._point_estimate(self.posterior, 
self.constraints, use_median) original_breaks, rescaled_breaks = mutational_timescale( @@ -626,7 +626,7 @@ def normalise( self.parents, self.children, edge_weights, - norm_intervals, + rescale_intervals, ) self.posterior[:] = piecewise_scale_posterior( self.posterior, @@ -650,8 +650,8 @@ def run( max_shape=1000, min_step=0.1, min_kl=False, - norm_intervals=1000, - norm_segsites=False, + rescale_intervals=1000, + rescale_segsites=False, regularise=True, progress=None, ): @@ -691,8 +691,10 @@ def run( logging.info(f"Skipped {skipped_muts} mutations with invalid posteriors") logging.info(f"Calculated mutation posteriors in {abs(muts_timing)} seconds") - if norm_intervals > 0: - norm_timing = time.time() - self.normalise(norm_intervals=norm_intervals, norm_segsites=norm_segsites) - norm_timing -= time.time() - logging.info(f"Timescale normalised in {abs(norm_timing)} seconds") + if rescale_intervals > 0: + rescale_timing = time.time() + self.rescale( + rescale_intervals=rescale_intervals, rescale_segsites=rescale_segsites + ) + rescale_timing -= time.time() + logging.info(f"Timescale rescaled in {abs(rescale_timing)} seconds") From b8a9c941549272619ee104269280d8033723d4e0 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Mon, 13 May 2024 10:37:57 -0700 Subject: [PATCH 3/3] Minor testing additions --- tests/test_functions.py | 18 ------------------ tests/test_inference.py | 34 +++++++++++++++++++++++----------- tsdate/core.py | 2 +- 3 files changed, 24 insertions(+), 30 deletions(-) diff --git a/tests/test_functions.py b/tests/test_functions.py index 5e784d96..35ec918e 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -2322,21 +2322,3 @@ def test_split_disjoint_nodes(self): assert not self.has_disjoint_nodes(split_ts) assert split_ts.num_edges == inferred_ts.num_edges assert split_ts.num_nodes > inferred_ts.num_nodes - - # def test_split_root_nodes(self): - # ts = msprime.sim_ancestry( - # 10, - # population_size=1e4, - # recombination_rate=1e-8, - # sequence_length=1e6, - # random_seed=1, - # ) - # ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1) - # sample_data = tsinfer.SampleData.from_tree_sequence(ts) - # inferred_ts = tsinfer.infer(sample_data).simplify() - # split_ts = split_root_nodes(inferred_ts) - # split_root_nodes(ts) - # assert not self.childset_changes_with_root(inferred_ts) - # assert self.childset_changes_with_root(split_ts) - # assert split_ts.num_edges > inferred_ts.num_edges - # assert split_ts.num_nodes > inferred_ts.num_nodes diff --git a/tests/test_inference.py b/tests/test_inference.py index 200a43c5..6339aaef 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -408,17 +408,20 @@ class TestVariational: Tests for tsdate with variational algorithm """ - ts = msprime.sim_ancestry( - samples=10, - recombination_rate=1e-8, - sequence_length=1e5, - population_size=1e4, - random_seed=2, - ) - ts = msprime.sim_mutations( - ts, - rate=1e-8, - ) + @pytest.fixture(autouse=True) + def ts(self): + ts = msprime.sim_ancestry( + samples=10, + recombination_rate=1e-8, + sequence_length=1e5, + population_size=1e4, + random_seed=2, + ) + ts = msprime.sim_mutations( + ts, + rate=1e-8, + ) + self.ts = ts def test_binary(self): tsdate.date(self.ts, mutation_rate=1e-8, method="variational_gamma") @@ -430,3 +433,12 @@ def test_polytomy(self): def test_inferred(self): its = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(self.ts)).simplify() tsdate.date(its, mutation_rate=1e-8, method="variational_gamma") + + def test_bad_arguments(self): + with 
pytest.raises(ValueError, match="Maximum number of EP iterations"): + tsdate.date( + self.ts, + mutation_rate=5, + method="variational_gamma", + max_iterations=-1, + ) diff --git a/tsdate/core.py b/tsdate/core.py index 573038a6..28050d31 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1239,7 +1239,7 @@ def run( self.provenance_params.update( {k: v for k, v in locals().items() if k != "self"} ) - if not max_iterations >= 1: + if not max_iterations > 0: raise ValueError("Maximum number of EP iterations must be greater than 0") if self.mutation_rate is None: raise ValueError("Variational gamma method requires mutation rate")