Skip to content

Commit

Permalink
Add tests checking all calculator support key selection
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Dec 1, 2022
1 parent 5a2c2d7 commit 4842eb9
Show file tree
Hide file tree
Showing 12 changed files with 268 additions and 138 deletions.
3 changes: 2 additions & 1 deletion rascaline/src/calculators/dummy_calculator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,10 @@ mod tests {

let samples = Labels::new(["structure", "center"], &[[0, 1]]);
let properties = Labels::new(["index_delta", "x_y_z"], &[[0, 1]]);
let keys = Labels::new(["species_center"], &[[0], [1], [6], [-42]]);

crate::calculators::tests_utils::compute_partial(
calculator, &mut systems, &samples, &properties
calculator, &mut systems, &keys, &samples, &properties
);
}
}
36 changes: 32 additions & 4 deletions rascaline/src/calculators/lode/spherical_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,12 @@ impl LodeSphericalExpansion {
0.into(),
species[center_i].into(),
species[center_i].into(),
]).expect("missing block");
]);

if block_i.is_none() {
continue;
}
let block_i = block_i.expect("we just checked");

let mut block = descriptor.block_mut_by_id(block_i);
let values = block.values_mut();
Expand Down Expand Up @@ -627,7 +632,14 @@ impl CalculatorBase for LodeSphericalExpansion {
spherical_harmonics_l.into(),
species[center_i].into(),
species_neighbor.into(),
]).expect("missing block");
]);

if block_i.is_none() {
continue;
}
let block_i = block_i.expect("we just checked");


let mut block = descriptor.block_mut_by_id(block_i);
let values = block.values_mut();
let mut array = array_mut_for_system(&mut values.data);
Expand Down Expand Up @@ -802,7 +814,7 @@ mod tests {
cutoff: 1.0,
k_cutoff: None,
max_radial: 4,
max_angular: 4,
max_angular: 2,
atomic_gaussian_width: 1.0,
center_atom_weight: 1.0,
radial_basis: RadialBasis::splined_gto(1e-8),
Expand All @@ -824,8 +836,24 @@ mod tests {
[0, 2],
]);

let keys = Labels::new(["spherical_harmonics_l", "species_center", "species_neighbor"], &[
[0, -42, -42],
[0, 6, 1], // not part of the default keys
[2, -42, -42],
[1, -42, -42],
[1, -42, 1],
[1, 1, -42],
[0, -42, 1],
[2, -42, 1],
[0, 1, 1],
[1, 1, 1],
[0, 1, -42],
[2, 1, -42],
[2, 1, 1],
]);

crate::calculators::tests_utils::compute_partial(
calculator, &mut [Box::new(system)], &samples, &properties
calculator, &mut [Box::new(system)], &keys, &samples, &properties
);
}

Expand Down
47 changes: 26 additions & 21 deletions rascaline/src/calculators/neighbor_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,10 @@ impl CalculatorBase for NeighborList {
}

fn components(&self, keys: &Labels) -> Vec<Vec<Arc<Labels>>> {
let mut component = LabelsBuilder::new(vec!["pair_direction"]);
component.add(&[0]);
component.add(&[1]);
component.add(&[2]);

return vec![vec![Arc::new(component.finish())]; keys.count()];
return vec![vec![Arc::new(Labels::new(
["pair_direction"],
&[[0], [1], [2]]
))]; keys.count()];
}

fn properties_names(&self) -> Vec<&str> {
Expand Down Expand Up @@ -439,7 +437,7 @@ impl FullNeighborList {
#[cfg(test)]
mod tests {
use approx::assert_relative_eq;
use equistore::{LabelValue, LabelsBuilder, Labels};
use equistore::Labels;

use crate::systems::test_utils::{test_systems, test_system};
use crate::Calculator;
Expand All @@ -464,7 +462,7 @@ mod tests {
));

// O-H block
let block = descriptor.blocks()[0].values();
let block = descriptor.block_by_id(0).values();
assert_eq!(*block.properties, Labels::new(["distance"], &[[0]]));

assert_eq!(block.components.len(), 1);
Expand All @@ -484,7 +482,7 @@ mod tests {
assert_relative_eq!(array, expected, max_relative=1e-6);

// H-H block
let block = descriptor.blocks()[1].values();
let block = descriptor.block_by_id(1).values();
assert_eq!(*block.samples, Labels::new(
["structure", "pair_id", "first_atom", "second_atom"],
// we have one H-H pair
Expand Down Expand Up @@ -515,7 +513,7 @@ mod tests {
));

// O-H block
let block = descriptor.blocks()[0].values();
let block = descriptor.block_by_id(0).values();
assert_eq!(*block.properties, Labels::new(["distance"], &[[0]]));

assert_eq!(block.components.len(), 1);
Expand All @@ -535,7 +533,7 @@ mod tests {
assert_relative_eq!(array, expected, max_relative=1e-6);

// H-O block
let block = descriptor.blocks()[1].values();
let block = descriptor.block_by_id(1).values();
assert_eq!(*block.properties, Labels::new(["distance"], &[[0]]));

assert_eq!(block.components.len(), 1);
Expand All @@ -555,7 +553,7 @@ mod tests {
assert_relative_eq!(array, expected, max_relative=1e-6);

// H-H block
let block = descriptor.blocks()[2].values();
let block = descriptor.block_by_id(2).values();
assert_eq!(*block.samples, Labels::new(
["structure", "pair_id", "first_atom", "second_atom"],
// we have one H-H pair, twice
Expand Down Expand Up @@ -601,18 +599,25 @@ mod tests {
cutoff: 1.0,
full_neighbor_list: false,
}) as Box<dyn CalculatorBase>);
let mut systems = test_systems(&["water"]);
let mut systems = test_systems(&["water", "methane"]);

let mut samples = LabelsBuilder::new(vec!["structure", "first_atom"]);
samples.add(&[LabelValue::new(0), LabelValue::new(1)]);
let samples = samples.finish();
let samples = Labels::new(
["structure", "first_atom"],
&[[0, 1]],
);

let mut properties = LabelsBuilder::new(vec!["distance"]);
properties.add(&[LabelValue::new(0)]);
let properties = properties.finish();
let properties = Labels::new(
["distance"],
&[[0]],
);

let keys = Labels::new(
["species_first_atom", "species_second_atom"],
&[[-42, 1], [1, -42], [1, 1], [6, 6]]
);

crate::calculators::tests_utils::compute_partial(
calculator, &mut systems, &samples, &properties
calculator, &mut systems, &keys, &samples, &properties
);

// full neighbor list
Expand All @@ -621,7 +626,7 @@ mod tests {
full_neighbor_list: true,
}) as Box<dyn CalculatorBase>);
crate::calculators::tests_utils::compute_partial(
calculator, &mut systems, &samples, &properties
calculator, &mut systems, &keys, &samples, &properties
);
}
}
27 changes: 18 additions & 9 deletions rascaline/src/calculators/soap/power_spectrum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,13 @@ impl SoapPowerSpectrum {
let block_id_1 = spherical_expansion.keys().position(&[
first_l, species_center, species_neighbor_1
]).expect("missing block in spherical expansion");
let spx_block_1 = &spherical_expansion.blocks()[block_id_1];
let spx_block_1 = &spherical_expansion.block_by_id(block_id_1);
let spx_samples_1 = &spx_block_1.values().samples;

let block_id_2 = spherical_expansion.keys().position(&[
first_l, species_center, species_neighbor_2
]).expect("missing block in spherical expansion");
let spx_block_2 = &spherical_expansion.blocks()[block_id_2];
let spx_block_2 = &spherical_expansion.block_by_id(block_id_2);
let spx_samples_2 = &spx_block_2.values().samples;

values_mapping.reserve(values.samples.count());
Expand Down Expand Up @@ -321,12 +321,12 @@ impl SoapPowerSpectrum {
let block_1 = spherical_expansion.keys().position(
&[l, species_center, species_neighbor_1]
).expect("missing first neighbor species block in spherical expansion");
let block_1 = &spherical_expansion.blocks()[block_1];
let block_1 = &spherical_expansion.block_by_id(block_1);

let block_2 = spherical_expansion.keys().position(
&[l, species_center, species_neighbor_2]
).expect("missing second neighbor species block in spherical expansion");
let block_2 = &spherical_expansion.blocks()[block_2];
let block_2 = &spherical_expansion.block_by_id(block_2);

let values_1 = block_1.values().data.as_array();
let values_2 = block_2.values().data.as_array();
Expand Down Expand Up @@ -518,6 +518,7 @@ impl CalculatorBase for SoapPowerSpectrum {
gradients: &gradients,
selected_samples: LabelsSelection::Predefined(&selected),
selected_properties: LabelsSelection::Predefined(&selected),
selected_keys: Some(selected.keys()),
..Default::default()
};

Expand Down Expand Up @@ -807,7 +808,7 @@ mod tests {
parameters()
).unwrap()) as Box<dyn CalculatorBase>);

let mut systems = test_systems(&["water", "methane"]);
let mut systems = test_systems(&["methane"]);

let properties = Labels::new(["l", "n1", "n2"], &[
[0, 0, 1],
Expand All @@ -819,14 +820,22 @@ mod tests {
]);

let samples = Labels::new(["structure", "center"], &[
[0, 1],
[0, 2],
[1, 0],
[1, 2],
[0, 1],
]);

let keys = Labels::new(["species_center", "species_neighbor_1", "species_neighbor_2"], &[
[1, 1, 1],
[6, 6, 6],
[1, 8, 6], // not part of the default keys
[1, 6, 6],
[1, 1, 6],
[6, 1, 1],
[6, 1, 6],
]);

crate::calculators::tests_utils::compute_partial(
calculator, &mut systems, &samples, &properties
calculator, &mut systems, &keys, &samples, &properties
);
}

Expand Down
16 changes: 14 additions & 2 deletions rascaline/src/calculators/soap/radial_spectrum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ impl CalculatorBase for SoapRadialSpectrum {
gradients: &gradients,
selected_samples: LabelsSelection::Predefined(&selected),
selected_properties: LabelsSelection::Predefined(&selected),
selected_keys: Some(selected.keys()),
..Default::default()
};

Expand Down Expand Up @@ -365,13 +366,24 @@ mod tests {
]);

let samples = Labels::new(["structure", "center"], &[
[1, 0],
[0, 1],
[0, 0],
[1, 0],
]);

let keys = Labels::new(["species_center", "species_neighbor"], &[
[1, 1],
[9, 1], // not part of the default keys
[-42, 1],
[1, -42],
[1, 6],
[-42, -42],
[6, 1],
[6, 6],
]);

crate::calculators::tests_utils::compute_partial(
calculator, &mut systems, &samples, &properties,
calculator, &mut systems, &keys, &samples, &properties,
);
}
}
31 changes: 24 additions & 7 deletions rascaline/src/calculators/soap/spherical_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ mod tests {
l.into(), species_center.into() , species_neighbor.into()
]);
assert!(block_i.is_some());
let block = &descriptor.blocks()[block_i.unwrap()];
let block = &descriptor.block_by_id(block_i.unwrap());
let array = block.values().data.as_array();
assert_eq!(array.shape().len(), 3);
assert_eq!(array.shape()[1], 2 * l + 1);
Expand Down Expand Up @@ -814,10 +814,13 @@ mod tests {
#[test]
fn compute_partial() {
let calculator = Calculator::from(Box::new(SphericalExpansion::new(
parameters()
SphericalExpansionParameters {
max_angular: 2,
..parameters()
}
).unwrap()) as Box<dyn CalculatorBase>);

let mut systems = test_systems(&["water", "methane"]);
let mut systems = test_systems(&["water"]);

let properties = Labels::new(["n"], &[
[0],
Expand All @@ -826,14 +829,28 @@ mod tests {
]);

let samples = Labels::new(["structure", "center"], &[
[0, 1],
[0, 2],
[1, 0],
[1, 2],
[0, 1],
]);

let keys = Labels::new(["spherical_harmonics_l", "species_center", "species_neighbor"], &[
[0, -42, -42],
[0, 6, 1], // not part of the default keys
[2, -42, -42],
[1, -42, -42],
[1, -42, 1],
[1, 1, -42],
[0, -42, 1],
[2, -42, 1],
[0, 1, 1],
[1, 1, 1],
[0, 1, -42],
[2, 1, -42],
[2, 1, 1],
]);

crate::calculators::tests_utils::compute_partial(
calculator, &mut systems, &samples, &properties
calculator, &mut systems, &keys, &samples, &properties
);
}

Expand Down
3 changes: 2 additions & 1 deletion rascaline/src/calculators/sorted_distances.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,12 @@ mod tests {

let mut systems = test_systems(&["water"]);

let keys = Labels::new(["species_center"], &[[1], [6], [8], [-42]]);
let samples = Labels::new(["structure", "center"], &[[0, 1]]);
let properties = Labels::new(["neighbor"], &[[2], [0]]);

crate::calculators::tests_utils::compute_partial(
calculator, &mut systems, &samples, &properties
calculator, &mut systems, &keys, &samples, &properties
);
}
}
Loading

0 comments on commit 4842eb9

Please sign in to comment.