From dd06a716a27382ac7f5679424989c591cb55add6 Mon Sep 17 00:00:00 2001 From: Jonathan Vandermause Date: Fri, 11 Oct 2024 22:32:52 -0400 Subject: [PATCH] check internal matrices are consistent --- src/flare_pp/bffs/sparse_gp.cpp | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/flare_pp/bffs/sparse_gp.cpp b/src/flare_pp/bffs/sparse_gp.cpp index 76692e94..698ee6f0 100644 --- a/src/flare_pp/bffs/sparse_gp.cpp +++ b/src/flare_pp/bffs/sparse_gp.cpp @@ -94,6 +94,12 @@ std::vector SparseGP ::compute_cluster_uncertainties(const Structure &structure) { // TODO: this only computes the energy-energy variance, and the Sigma matrix is not considered? + if (L_inv.rows() != Kuu.rows()) { + throw std::runtime_error( + "L_inv must be up to date to evaluate uncertainties. Please call update_matrices_QR and try again." + ); + } + // Create cluster descriptors. std::vector cluster_descriptors; for (int i = 0; i < structure.descriptors.size(); i++) { @@ -632,6 +638,12 @@ void SparseGP ::update_matrices_QR() { void SparseGP ::predict_mean(Structure &test_structure) { + if (L_inv.rows() != Kuu.rows()) { + throw std::runtime_error( + "L_inv must be up to date to make predictions. Please call update_matrices_QR and try again." + ); + } + int n_atoms = test_structure.noa; int n_out = 1 + 3 * n_atoms + 6; @@ -650,6 +662,12 @@ void SparseGP ::predict_mean(Structure &test_structure) { void SparseGP ::predict_SOR(Structure &test_structure) { + if (L_inv.rows() != Kuu.rows()) { + throw std::runtime_error( + "L_inv must be up to date to make predictions. Please call update_matrices_QR and try again." + ); + } + int n_atoms = test_structure.noa; int n_out = 1 + 3 * n_atoms + 6; @@ -671,6 +689,12 @@ void SparseGP ::predict_SOR(Structure &test_structure) { void SparseGP ::predict_DTC(Structure &test_structure) { + if (L_inv.rows() != Kuu.rows()) { + throw std::runtime_error( + "L_inv must be up to date to make predictions. Please call update_matrices_QR and try again." + ); + } + int n_atoms = test_structure.noa; int n_out = 1 + 3 * n_atoms + 6; @@ -701,6 +725,12 @@ void SparseGP ::predict_DTC(Structure &test_structure) { } void SparseGP ::predict_local_uncertainties(Structure &test_structure) { + if (L_inv.rows() != Kuu.rows()) { + throw std::runtime_error( + "L_inv must be up to date to make predictions. Please call update_matrices_QR and try again." + ); + } + int n_atoms = test_structure.noa; int n_out = 1 + 3 * n_atoms + 6;