Skip to content

Commit

Permalink
check internal matrices are consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
jonpvandermause committed Oct 12, 2024
1 parent cb57c6b commit dd06a71
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/flare_pp/bffs/sparse_gp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ std::vector<Eigen::VectorXd>
SparseGP ::compute_cluster_uncertainties(const Structure &structure) {
// TODO: this only computes the energy-energy variance, and the Sigma matrix is not considered?

if (L_inv.rows() != Kuu.rows()) {
throw std::runtime_error(
"L_inv must be up to date to evaluate uncertainties. Please call update_matrices_QR and try again."
);
}

// Create cluster descriptors.
std::vector<ClusterDescriptor> cluster_descriptors;
for (int i = 0; i < structure.descriptors.size(); i++) {
Expand Down Expand Up @@ -632,6 +638,12 @@ void SparseGP ::update_matrices_QR() {

void SparseGP ::predict_mean(Structure &test_structure) {

if (L_inv.rows() != Kuu.rows()) {
throw std::runtime_error(
"L_inv must be up to date to make predictions. Please call update_matrices_QR and try again."
);
}

int n_atoms = test_structure.noa;
int n_out = 1 + 3 * n_atoms + 6;

Expand All @@ -650,6 +662,12 @@ void SparseGP ::predict_mean(Structure &test_structure) {

void SparseGP ::predict_SOR(Structure &test_structure) {

if (L_inv.rows() != Kuu.rows()) {
throw std::runtime_error(
"L_inv must be up to date to make predictions. Please call update_matrices_QR and try again."
);
}

int n_atoms = test_structure.noa;
int n_out = 1 + 3 * n_atoms + 6;

Expand All @@ -671,6 +689,12 @@ void SparseGP ::predict_SOR(Structure &test_structure) {

void SparseGP ::predict_DTC(Structure &test_structure) {

if (L_inv.rows() != Kuu.rows()) {
throw std::runtime_error(
"L_inv must be up to date to make predictions. Please call update_matrices_QR and try again."
);
}

int n_atoms = test_structure.noa;
int n_out = 1 + 3 * n_atoms + 6;

Expand Down Expand Up @@ -701,6 +725,12 @@ void SparseGP ::predict_DTC(Structure &test_structure) {
}

void SparseGP ::predict_local_uncertainties(Structure &test_structure) {
if (L_inv.rows() != Kuu.rows()) {
throw std::runtime_error(
"L_inv must be up to date to make predictions. Please call update_matrices_QR and try again."
);
}

int n_atoms = test_structure.noa;
int n_out = 1 + 3 * n_atoms + 6;

Expand Down

0 comments on commit dd06a71

Please sign in to comment.