diff --git a/R/RcppExports.R b/R/RcppExports.R index de8d113..65fd7f3 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -5,6 +5,10 @@ run_mcmc_rcpp <- function(args) { .Call(`_moire_run_mcmc`, args) } +openmp_enabled <- function() { + .Call(`_moire_openmp_enabled`) +} + start_profiler <- function(str) { .Call(`_moire_start_profiler`, str) } diff --git a/R/mcmc.R b/R/mcmc.R index 20ed7ca..d7ef3d1 100644 --- a/R/mcmc.R +++ b/R/mcmc.R @@ -62,6 +62,9 @@ #' adaptation steps. Only used if `adapt_temp` is TRUE. #' @param max_initialization_tries Number of times to try to initialize the #' chain before giving up +#' @param max_runtime Maximum runtime in minutes. If the MCMC is running for +#' more than this amount of time, the function will stop and return the current +#' state of the MCMC. run_mcmc <- function(data, is_missing = FALSE, @@ -91,7 +94,8 @@ run_mcmc <- adapt_temp = TRUE, pre_adapt_steps = 25, temp_adapt_steps = 25, - max_initialization_tries = 10000) { + max_initialization_tries = 10000, + max_runtime = Inf) { start_time <- Sys.time() args <- as.list(environment()) mcmc_args <- as.list(environment()) diff --git a/R/utils.R b/R/utils.R index 0c9ea7f..9892223 100644 --- a/R/utils.R +++ b/R/utils.R @@ -16,6 +16,7 @@ #' @importFrom rlang .data load_long_form_data <- function(df, warn_uninformative = TRUE) { uninformative_loci <- df |> + dplyr::ungroup() |> dplyr::group_by(.data$locus) |> dplyr::summarise(total_alleles = length(unique(.data$allele))) |> dplyr::filter(.data$total_alleles == 1) |> diff --git a/man/run_mcmc.Rd b/man/run_mcmc.Rd index 11bec87..ca79b02 100644 --- a/man/run_mcmc.Rd +++ b/man/run_mcmc.Rd @@ -33,7 +33,8 @@ run_mcmc( adapt_temp = TRUE, pre_adapt_steps = 25, temp_adapt_steps = 25, - max_initialization_tries = 10000 + max_initialization_tries = 10000, + max_runtime = Inf ) } \arguments{ @@ -125,6 +126,10 @@ adaptation steps. Only used if \code{adapt_temp} is TRUE.} \item{max_initialization_tries}{Number of times to try to initialize the chain before giving up} + +\item{max_runtime}{Maximum runtime in minutes. If the MCMC is running for +more than this amount of time, the function will stop and return the current +state of the MCMC.} } \description{ Sample from the target distribution using MCMC diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 714c620..3e0aaa3 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -21,6 +21,16 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// openmp_enabled +SEXP openmp_enabled(); +RcppExport SEXP _moire_openmp_enabled() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(openmp_enabled()); + return rcpp_result_gen; +END_RCPP +} // start_profiler SEXP start_profiler(SEXP str); RcppExport SEXP _moire_start_profiler(SEXP strSEXP) { @@ -45,6 +55,7 @@ END_RCPP static const R_CallMethodDef CallEntries[] = { {"_moire_run_mcmc", (DL_FUNC) &_moire_run_mcmc, 1}, + {"_moire_openmp_enabled", (DL_FUNC) &_moire_openmp_enabled, 0}, {"_moire_start_profiler", (DL_FUNC) &_moire_start_profiler, 1}, {"_moire_stop_profiler", (DL_FUNC) &_moire_stop_profiler, 0}, {NULL, NULL, 0} diff --git a/src/chain.cpp b/src/chain.cpp index 58cde8a..c0342aa 100644 --- a/src/chain.cpp +++ b/src/chain.cpp @@ -1016,6 +1016,37 @@ float Chain::get_llik() { return llik; } float Chain::get_prior() { return prior; } float Chain::get_posterior() { return llik * temp + prior; } +float Chain::get_llik(int sample) { + int idx = sample * genotyping_data.num_loci; + + #ifdef HAS_EXECUTION + return std::reduce(std::execution::unseq, genotyping_llik_new.begin() + idx, + genotyping_llik_new.begin() + idx + genotyping_data.num_loci); + #else + float llik = 0.0f; + #pragma omp simd reduction(+:llik) + for (int i = 0; i < genotyping_data.num_loci; ++i) { + llik += genotyping_llik_new[idx + i]; + } + return llik; + #endif +} +float Chain::get_prior(int sample) { + float prior = 0.0f; + if (params.allow_relatedness) { + prior += relatedness_prior_new[sample]; + } + prior += coi_prior_new[sample]; + prior += eps_neg_prior_new[sample]; + prior += eps_pos_prior_new[sample]; + return prior; +} + +float Chain::get_posterior(int sample) { + float posterior = get_llik(sample) * temp + get_prior(sample); + return posterior; +} + void Chain::calculate_genotype_likelihood(int sample_idx, int locus_idx) { int idx = sample_idx * genotyping_data.num_loci + locus_idx; diff --git a/src/chain.h b/src/chain.h index 05d2430..093d0f6 100644 --- a/src/chain.h +++ b/src/chain.h @@ -154,6 +154,9 @@ class Chain float get_llik(); float get_prior(); float get_posterior(); + float get_llik(int sample); + float get_prior(int sample); + float get_posterior(int sample); void set_llik(float llik); void set_temp(float temp); diff --git a/src/main.cpp b/src/main.cpp index 2baa67c..1d01354 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -29,13 +29,20 @@ Rcpp::List run_mcmc(Rcpp::List args) params.adapt_temp ? "Yes" : "No"); } + enum events { + START_COMPUTATION + }; + + Timer timer; + MCMC mcmc(genotyping_data, params); MCMCProgressBar pb(params.burnin, params.samples, params.use_message); Progress p(params.burnin + params.samples, params.verbose, pb); pb.set_llik(mcmc.get_llik()); int step = 0; - while (step < params.burnin) + timer.record_event(events::START_COMPUTATION); + while (step < params.burnin && timer.time_since_event(events::START_COMPUTATION).count() < params.max_runtime) { Rcpp::checkUserInterrupt(); mcmc.burnin(step); @@ -54,7 +61,7 @@ Rcpp::List run_mcmc(Rcpp::List args) } step = 0; - while (step < params.samples) + while (step < params.samples && timer.time_since_event(events::START_COMPUTATION).count() < params.max_runtime) { Rcpp::checkUserInterrupt(); mcmc.sample(step); @@ -72,7 +79,10 @@ Rcpp::List run_mcmc(Rcpp::List args) } } + bool max_runtime_reached = timer.time_since_event(events::START_COMPUTATION).count() >= params.max_runtime && step < params.samples; + mcmc.finalize(); + float runtime = timer.time_since_event(events::START_COMPUTATION).count(); Rcpp::List acceptance_rates; Rcpp::List sampling_variances; @@ -123,6 +133,7 @@ Rcpp::List run_mcmc(Rcpp::List args) res.push_back(Rcpp::wrap(mcmc.prior_sample)); res.push_back(Rcpp::wrap(mcmc.posterior_burnin)); res.push_back(Rcpp::wrap(mcmc.posterior_sample)); + res.push_back(Rcpp::wrap(mcmc.data_llik_store)); res.push_back(Rcpp::wrap(mcmc.m_store)); res.push_back(Rcpp::wrap(mcmc.mean_coi_store)); res.push_back(Rcpp::wrap(mcmc.p_store)); @@ -137,6 +148,9 @@ Rcpp::List run_mcmc(Rcpp::List args) res.push_back(Rcpp::wrap(mcmc.temp_gradient)); res.push_back(Rcpp::wrap(acceptance_rates)); res.push_back(Rcpp::wrap(sampling_variances)); + res.push_back(Rcpp::wrap(max_runtime_reached)); + res.push_back(Rcpp::wrap(runtime)); + Rcpp::StringVector res_names; res_names.push_back("llik_burnin"); @@ -145,6 +159,7 @@ Rcpp::List run_mcmc(Rcpp::List args) res_names.push_back("prior_sample"); res_names.push_back("posterior_burnin"); res_names.push_back("posterior_sample"); + res_names.push_back("data_llik"); res_names.push_back("coi"); res_names.push_back("lam_coi"); res_names.push_back("allele_freqs"); @@ -159,6 +174,8 @@ Rcpp::List run_mcmc(Rcpp::List args) res_names.push_back("temp_gradient"); res_names.push_back("acceptance_rates"); res_names.push_back("sampling_variances"); + res_names.push_back("max_runtime_reached"); + res_names.push_back("total_runtime"); res.names() = res_names; return res; diff --git a/src/mcmc.cpp b/src/mcmc.cpp index 3c46f87..2555f93 100644 --- a/src/mcmc.cpp +++ b/src/mcmc.cpp @@ -31,6 +31,7 @@ MCMC::MCMC(GenotypingData genotyping_data, Parameters params) r_store.resize(genotyping_data.num_samples); eps_neg_store.resize(genotyping_data.num_samples); eps_pos_store.resize(genotyping_data.num_samples); + data_llik_store.resize(genotyping_data.num_samples); swap_acceptances.resize(params.pt_chains.size() - 1, 0); swap_barriers.resize(params.pt_chains.size() - 1, 0.0); swap_indices.resize(params.pt_chains.size(), 0); @@ -253,6 +254,7 @@ void MCMC::sample(int step) eps_neg_store[jj].push_back(chain.eps_neg[jj]); eps_pos_store[jj].push_back(chain.eps_pos[jj]); r_store[jj].push_back(chain.r[jj]); + data_llik_store[jj].push_back(chain.get_llik(jj)); if (params.record_latent_genotypes) { for (size_t kk = 0; kk < genotyping_data.num_loci; ++kk) @@ -282,3 +284,13 @@ void MCMC::finalize() float MCMC::get_llik() { return chains[swap_indices[0]].get_llik(); } float MCMC::get_prior() { return chains[swap_indices[0]].get_prior(); } float MCMC::get_posterior() { return chains[swap_indices[0]].get_posterior(); } + +// [[Rcpp::export]] +SEXP openmp_enabled() +{ + #ifdef _OPENMP + return Rcpp::wrap(true); + #else + return Rcpp::wrap(false); + #endif +} diff --git a/src/mcmc.h b/src/mcmc.h index 401f76c..21ce9fb 100644 --- a/src/mcmc.h +++ b/src/mcmc.h @@ -22,6 +22,7 @@ class MCMC std::vector> m_store{}; std::vector>> p_store{}; std::vector>>> latent_genotypes_store{}; + std::vector> data_llik_store{}; std::vector> eps_pos_store{}; std::vector> eps_neg_store{}; std::vector> r_store{}; diff --git a/src/mcmc_progress_bar.cpp b/src/mcmc_progress_bar.cpp index 9d606f4..6718a2e 100644 --- a/src/mcmc_progress_bar.cpp +++ b/src/mcmc_progress_bar.cpp @@ -25,7 +25,7 @@ void MCMCProgressBar::update(float progress) } else { - // stop and record time no more than every .1 seconds + // stop and record time no more than every 1 second if (clock_.time_since_event(events::UPDATE_CONSOLE).count() < 1000) { return; diff --git a/src/parameters.cpp b/src/parameters.cpp index 3642adb..3b032ed 100644 --- a/src/parameters.cpp +++ b/src/parameters.cpp @@ -21,6 +21,7 @@ Parameters::Parameters(const Rcpp::List &args) temp_adapt_steps = UtilFunctions::r_to_int(args["temp_adapt_steps"]); max_initialization_tries = UtilFunctions::r_to_int(args["max_initialization_tries"]); record_latent_genotypes = UtilFunctions::r_to_bool(args["record_latent_genotypes"]); + max_runtime = UtilFunctions::r_to_float(args["max_runtime"]); // Model max_coi = UtilFunctions::r_to_int(args["max_coi"]); diff --git a/src/parameters.h b/src/parameters.h index 9d0d90e..4853419 100644 --- a/src/parameters.h +++ b/src/parameters.h @@ -22,6 +22,7 @@ class Parameters int pre_adapt_steps; int temp_adapt_steps; int max_initialization_tries; + float max_runtime; bool record_latent_genotypes; diff --git a/src/prob_any_missing.cpp b/src/prob_any_missing.cpp index ca3090d..1bbbdee 100644 --- a/src/prob_any_missing.cpp +++ b/src/prob_any_missing.cpp @@ -69,6 +69,12 @@ std::vector probAnyMissingFunctor::vectorized( const std::size_t totalEvents = eventProbs.size(); std::vector probVec(maxNumEvents - minNumEvents + 1, 0.0); + + if (maxNumEvents < totalEvents) { + std::fill_n(probVec.begin(), maxNumEvents - minNumEvents + 1, 1.0); + return probVec; + } + std::fill_n(probVec.begin(), totalEvents - 1, 1.0); // Calculate via inclusion-exclusion principle