diff --git a/rascaline/src/calculators/dummy_calculator.rs b/rascaline/src/calculators/dummy_calculator.rs index 889df7aa1..e92aef831 100644 --- a/rascaline/src/calculators/dummy_calculator.rs +++ b/rascaline/src/calculators/dummy_calculator.rs @@ -245,9 +245,10 @@ mod tests { let samples = Labels::new(["structure", "center"], &[[0, 1]]); let properties = Labels::new(["index_delta", "x_y_z"], &[[0, 1]]); + let keys = Labels::new(["species_center"], &[[0], [1], [6], [-42]]); crate::calculators::tests_utils::compute_partial( - calculator, &mut systems, &samples, &properties + calculator, &mut systems, &keys, &samples, &properties ); } } diff --git a/rascaline/src/calculators/lode/spherical_expansion.rs b/rascaline/src/calculators/lode/spherical_expansion.rs index 6ee9c4613..77eb3ccb6 100644 --- a/rascaline/src/calculators/lode/spherical_expansion.rs +++ b/rascaline/src/calculators/lode/spherical_expansion.rs @@ -388,7 +388,12 @@ impl LodeSphericalExpansion { 0.into(), species[center_i].into(), species[center_i].into(), - ]).expect("missing block"); + ]); + + if block_i.is_none() { + continue; + } + let block_i = block_i.expect("we just checked"); let mut block = descriptor.block_mut_by_id(block_i); let values = block.values_mut(); @@ -627,7 +632,14 @@ impl CalculatorBase for LodeSphericalExpansion { spherical_harmonics_l.into(), species[center_i].into(), species_neighbor.into(), - ]).expect("missing block"); + ]); + + if block_i.is_none() { + continue; + } + let block_i = block_i.expect("we just checked"); + + let mut block = descriptor.block_mut_by_id(block_i); let values = block.values_mut(); let mut array = array_mut_for_system(&mut values.data); @@ -802,7 +814,7 @@ mod tests { cutoff: 1.0, k_cutoff: None, max_radial: 4, - max_angular: 4, + max_angular: 2, atomic_gaussian_width: 1.0, center_atom_weight: 1.0, radial_basis: RadialBasis::splined_gto(1e-8), @@ -824,8 +836,24 @@ mod tests { [0, 2], ]); + let keys = Labels::new(["spherical_harmonics_l", "species_center", "species_neighbor"], &[ + [0, -42, -42], + [0, 6, 1], // not part of the default keys + [2, -42, -42], + [1, -42, -42], + [1, -42, 1], + [1, 1, -42], + [0, -42, 1], + [2, -42, 1], + [0, 1, 1], + [1, 1, 1], + [0, 1, -42], + [2, 1, -42], + [2, 1, 1], + ]); + crate::calculators::tests_utils::compute_partial( - calculator, &mut [Box::new(system)], &samples, &properties + calculator, &mut [Box::new(system)], &keys, &samples, &properties ); } diff --git a/rascaline/src/calculators/neighbor_list.rs b/rascaline/src/calculators/neighbor_list.rs index 61da7ac5d..327e8edcc 100644 --- a/rascaline/src/calculators/neighbor_list.rs +++ b/rascaline/src/calculators/neighbor_list.rs @@ -107,12 +107,10 @@ impl CalculatorBase for NeighborList { } fn components(&self, keys: &Labels) -> Vec>> { - let mut component = LabelsBuilder::new(vec!["pair_direction"]); - component.add(&[0]); - component.add(&[1]); - component.add(&[2]); - - return vec![vec![Arc::new(component.finish())]; keys.count()]; + return vec![vec![Arc::new(Labels::new( + ["pair_direction"], + &[[0], [1], [2]] + ))]; keys.count()]; } fn properties_names(&self) -> Vec<&str> { @@ -439,7 +437,7 @@ impl FullNeighborList { #[cfg(test)] mod tests { use approx::assert_relative_eq; - use equistore::{LabelValue, LabelsBuilder, Labels}; + use equistore::Labels; use crate::systems::test_utils::{test_systems, test_system}; use crate::Calculator; @@ -464,7 +462,7 @@ mod tests { )); // O-H block - let block = descriptor.blocks()[0].values(); + let block = descriptor.block_by_id(0).values(); assert_eq!(*block.properties, Labels::new(["distance"], &[[0]])); assert_eq!(block.components.len(), 1); @@ -484,7 +482,7 @@ mod tests { assert_relative_eq!(array, expected, max_relative=1e-6); // H-H block - let block = descriptor.blocks()[1].values(); + let block = descriptor.block_by_id(1).values(); assert_eq!(*block.samples, Labels::new( ["structure", "pair_id", "first_atom", "second_atom"], // we have one H-H pair @@ -515,7 +513,7 @@ mod tests { )); // O-H block - let block = descriptor.blocks()[0].values(); + let block = descriptor.block_by_id(0).values(); assert_eq!(*block.properties, Labels::new(["distance"], &[[0]])); assert_eq!(block.components.len(), 1); @@ -535,7 +533,7 @@ mod tests { assert_relative_eq!(array, expected, max_relative=1e-6); // H-O block - let block = descriptor.blocks()[1].values(); + let block = descriptor.block_by_id(1).values(); assert_eq!(*block.properties, Labels::new(["distance"], &[[0]])); assert_eq!(block.components.len(), 1); @@ -555,7 +553,7 @@ mod tests { assert_relative_eq!(array, expected, max_relative=1e-6); // H-H block - let block = descriptor.blocks()[2].values(); + let block = descriptor.block_by_id(2).values(); assert_eq!(*block.samples, Labels::new( ["structure", "pair_id", "first_atom", "second_atom"], // we have one H-H pair, twice @@ -601,18 +599,25 @@ mod tests { cutoff: 1.0, full_neighbor_list: false, }) as Box); - let mut systems = test_systems(&["water"]); + let mut systems = test_systems(&["water", "methane"]); - let mut samples = LabelsBuilder::new(vec!["structure", "first_atom"]); - samples.add(&[LabelValue::new(0), LabelValue::new(1)]); - let samples = samples.finish(); + let samples = Labels::new( + ["structure", "first_atom"], + &[[0, 1]], + ); - let mut properties = LabelsBuilder::new(vec!["distance"]); - properties.add(&[LabelValue::new(0)]); - let properties = properties.finish(); + let properties = Labels::new( + ["distance"], + &[[0]], + ); + + let keys = Labels::new( + ["species_first_atom", "species_second_atom"], + &[[-42, 1], [1, -42], [1, 1], [6, 6]] + ); crate::calculators::tests_utils::compute_partial( - calculator, &mut systems, &samples, &properties + calculator, &mut systems, &keys, &samples, &properties ); // full neighbor list @@ -621,7 +626,7 @@ mod tests { full_neighbor_list: true, }) as Box); crate::calculators::tests_utils::compute_partial( - calculator, &mut systems, &samples, &properties + calculator, &mut systems, &keys, &samples, &properties ); } } diff --git a/rascaline/src/calculators/soap/power_spectrum.rs b/rascaline/src/calculators/soap/power_spectrum.rs index 7da32b166..734a56597 100644 --- a/rascaline/src/calculators/soap/power_spectrum.rs +++ b/rascaline/src/calculators/soap/power_spectrum.rs @@ -265,13 +265,13 @@ impl SoapPowerSpectrum { let block_id_1 = spherical_expansion.keys().position(&[ first_l, species_center, species_neighbor_1 ]).expect("missing block in spherical expansion"); - let spx_block_1 = &spherical_expansion.blocks()[block_id_1]; + let spx_block_1 = &spherical_expansion.block_by_id(block_id_1); let spx_samples_1 = &spx_block_1.values().samples; let block_id_2 = spherical_expansion.keys().position(&[ first_l, species_center, species_neighbor_2 ]).expect("missing block in spherical expansion"); - let spx_block_2 = &spherical_expansion.blocks()[block_id_2]; + let spx_block_2 = &spherical_expansion.block_by_id(block_id_2); let spx_samples_2 = &spx_block_2.values().samples; values_mapping.reserve(values.samples.count()); @@ -321,12 +321,12 @@ impl SoapPowerSpectrum { let block_1 = spherical_expansion.keys().position( &[l, species_center, species_neighbor_1] ).expect("missing first neighbor species block in spherical expansion"); - let block_1 = &spherical_expansion.blocks()[block_1]; + let block_1 = &spherical_expansion.block_by_id(block_1); let block_2 = spherical_expansion.keys().position( &[l, species_center, species_neighbor_2] ).expect("missing second neighbor species block in spherical expansion"); - let block_2 = &spherical_expansion.blocks()[block_2]; + let block_2 = &spherical_expansion.block_by_id(block_2); let values_1 = block_1.values().data.as_array(); let values_2 = block_2.values().data.as_array(); @@ -518,6 +518,7 @@ impl CalculatorBase for SoapPowerSpectrum { gradients: &gradients, selected_samples: LabelsSelection::Predefined(&selected), selected_properties: LabelsSelection::Predefined(&selected), + selected_keys: Some(selected.keys()), ..Default::default() }; @@ -807,7 +808,7 @@ mod tests { parameters() ).unwrap()) as Box); - let mut systems = test_systems(&["water", "methane"]); + let mut systems = test_systems(&["methane"]); let properties = Labels::new(["l", "n1", "n2"], &[ [0, 0, 1], @@ -819,14 +820,22 @@ mod tests { ]); let samples = Labels::new(["structure", "center"], &[ - [0, 1], [0, 2], - [1, 0], - [1, 2], + [0, 1], + ]); + + let keys = Labels::new(["species_center", "species_neighbor_1", "species_neighbor_2"], &[ + [1, 1, 1], + [6, 6, 6], + [1, 8, 6], // not part of the default keys + [1, 6, 6], + [1, 1, 6], + [6, 1, 1], + [6, 1, 6], ]); crate::calculators::tests_utils::compute_partial( - calculator, &mut systems, &samples, &properties + calculator, &mut systems, &keys, &samples, &properties ); } diff --git a/rascaline/src/calculators/soap/radial_spectrum.rs b/rascaline/src/calculators/soap/radial_spectrum.rs index dda69bedf..4cf1324e1 100644 --- a/rascaline/src/calculators/soap/radial_spectrum.rs +++ b/rascaline/src/calculators/soap/radial_spectrum.rs @@ -223,6 +223,7 @@ impl CalculatorBase for SoapRadialSpectrum { gradients: &gradients, selected_samples: LabelsSelection::Predefined(&selected), selected_properties: LabelsSelection::Predefined(&selected), + selected_keys: Some(selected.keys()), ..Default::default() }; @@ -365,13 +366,24 @@ mod tests { ]); let samples = Labels::new(["structure", "center"], &[ + [1, 0], [0, 1], [0, 0], - [1, 0], + ]); + + let keys = Labels::new(["species_center", "species_neighbor"], &[ + [1, 1], + [9, 1], // not part of the default keys + [-42, 1], + [1, -42], + [1, 6], + [-42, -42], + [6, 1], + [6, 6], ]); crate::calculators::tests_utils::compute_partial( - calculator, &mut systems, &samples, &properties, + calculator, &mut systems, &keys, &samples, &properties, ); } } diff --git a/rascaline/src/calculators/soap/spherical_expansion.rs b/rascaline/src/calculators/soap/spherical_expansion.rs index aa1248e90..2909d6002 100644 --- a/rascaline/src/calculators/soap/spherical_expansion.rs +++ b/rascaline/src/calculators/soap/spherical_expansion.rs @@ -769,7 +769,7 @@ mod tests { l.into(), species_center.into() , species_neighbor.into() ]); assert!(block_i.is_some()); - let block = &descriptor.blocks()[block_i.unwrap()]; + let block = &descriptor.block_by_id(block_i.unwrap()); let array = block.values().data.as_array(); assert_eq!(array.shape().len(), 3); assert_eq!(array.shape()[1], 2 * l + 1); @@ -814,10 +814,13 @@ mod tests { #[test] fn compute_partial() { let calculator = Calculator::from(Box::new(SphericalExpansion::new( - parameters() + SphericalExpansionParameters { + max_angular: 2, + ..parameters() + } ).unwrap()) as Box); - let mut systems = test_systems(&["water", "methane"]); + let mut systems = test_systems(&["water"]); let properties = Labels::new(["n"], &[ [0], @@ -826,14 +829,28 @@ mod tests { ]); let samples = Labels::new(["structure", "center"], &[ - [0, 1], [0, 2], - [1, 0], - [1, 2], + [0, 1], + ]); + + let keys = Labels::new(["spherical_harmonics_l", "species_center", "species_neighbor"], &[ + [0, -42, -42], + [0, 6, 1], // not part of the default keys + [2, -42, -42], + [1, -42, -42], + [1, -42, 1], + [1, 1, -42], + [0, -42, 1], + [2, -42, 1], + [0, 1, 1], + [1, 1, 1], + [0, 1, -42], + [2, 1, -42], + [2, 1, 1], ]); crate::calculators::tests_utils::compute_partial( - calculator, &mut systems, &samples, &properties + calculator, &mut systems, &keys, &samples, &properties ); } diff --git a/rascaline/src/calculators/sorted_distances.rs b/rascaline/src/calculators/sorted_distances.rs index 2f68b33f5..3227d3c75 100644 --- a/rascaline/src/calculators/sorted_distances.rs +++ b/rascaline/src/calculators/sorted_distances.rs @@ -225,11 +225,12 @@ mod tests { let mut systems = test_systems(&["water"]); + let keys = Labels::new(["species_center"], &[[1], [6], [8], [-42]]); let samples = Labels::new(["structure", "center"], &[[0, 1]]); let properties = Labels::new(["neighbor"], &[[2], [0]]); crate::calculators::tests_utils::compute_partial( - calculator, &mut systems, &samples, &properties + calculator, &mut systems, &keys, &samples, &properties ); } } diff --git a/rascaline/src/calculators/tests_utils.rs b/rascaline/src/calculators/tests_utils.rs index 6bcb3bebc..b31460f73 100644 --- a/rascaline/src/calculators/tests_utils.rs +++ b/rascaline/src/calculators/tests_utils.rs @@ -1,7 +1,7 @@ use ndarray::Axis; use approx::{assert_relative_eq, assert_ulps_eq}; -use equistore::{Labels, TensorMap}; +use equistore::{Labels, TensorMap, LabelsBuilder}; use crate::calculator::LabelsSelection; use crate::{CalculationOptions, Calculator}; @@ -16,16 +16,61 @@ use crate::systems::{System, SimpleSystem, UnitCell}; pub fn compute_partial( mut calculator: Calculator, systems: &mut [Box], + keys: &Labels, samples: &Labels, properties: &Labels, ) { let full = calculator.compute(systems, Default::default()).unwrap(); + dbg!(full.keys()); + assert!(full.keys().count() < keys.count(), "selected keys should be a superset of the keys"); + check_compute_partial_keys(&mut calculator, &mut *systems, &full, keys); + + assert!(keys.count() > 3, "selected keys should have more than 3 keys"); + let mut subset_keys = LabelsBuilder::new(keys.names()); + for key in keys.iter().take(3) { + subset_keys.add(key); + } + check_compute_partial_keys(&mut calculator, &mut *systems, &full, &subset_keys.finish()); + check_compute_partial_properties(&mut calculator, &mut *systems, &full, properties); check_compute_partial_samples(&mut calculator, &mut *systems, &full, samples); check_compute_partial_both(&mut calculator, &mut *systems, &full, samples, properties); } +fn check_compute_partial_keys( + calculator: &mut Calculator, + systems: &mut [Box], + full: &TensorMap, + keys: &Labels, +) { + // select keys manually + let options = CalculationOptions { + selected_keys: Some(keys), + ..Default::default() + }; + let partial = calculator.compute(systems, options).unwrap(); + + assert_eq!(partial.keys(), keys); + for key in keys { + let mut selected_key = LabelsBuilder::new(keys.names()); + selected_key.add(key); + let selected_key = selected_key.finish(); + + let partial = partial.block(&selected_key).expect("missing block in partial"); + let full = full.block(&selected_key); + if let Ok(full) = full { + assert_eq!(full.values().samples, partial.values().samples); + assert_eq!(full.values().components, partial.values().components); + assert_eq!(full.values().properties, partial.values().properties); + + let full_values = full.values().data.as_array(); + let partial_values = partial.values().data.as_array(); + assert_ulps_eq!(full_values, partial_values); + } + } +} + fn check_compute_partial_properties( calculator: &mut Calculator, systems: &mut [Box], @@ -227,8 +272,8 @@ pub fn finite_differences_positions(mut calculator: Calculator, system: &SimpleS for (block_i, (_, block)) in reference.iter().enumerate() { let gradients = &block.gradient("positions").unwrap(); - let block_pos = &updated_pos.blocks()[block_i]; - let block_neg = &updated_neg.blocks()[block_i]; + let block_pos = &updated_pos.block_by_id(block_i); + let block_neg = &updated_neg.block_by_id(block_i); for (gradient_i, [sample_i, _, atom]) in gradients.samples.iter_fixed_size().enumerate() { if atom.usize() != atom_i { @@ -298,8 +343,8 @@ pub fn finite_differences_cell(mut calculator: Calculator, system: &SimpleSystem for (block_i, (_, block)) in reference.iter().enumerate() { let gradients = &block.gradient("cell").unwrap(); - let block_pos = &updated_pos.blocks()[block_i]; - let block_neg = &updated_neg.blocks()[block_i]; + let block_pos = &updated_pos.block_by_id(block_i); + let block_neg = &updated_neg.block_by_id(block_i); for (gradient_i, [sample_i]) in gradients.samples.iter_fixed_size().enumerate() { let sample_i = sample_i.usize(); diff --git a/rascaline/src/tutorials/moments/moments.rs b/rascaline/src/tutorials/moments/moments.rs index bcf1a8029..1cf50a50b 100644 --- a/rascaline/src/tutorials/moments/moments.rs +++ b/rascaline/src/tutorials/moments/moments.rs @@ -110,21 +110,26 @@ impl CalculatorBase for GeometricMoments { for pair in system.pairs()? { let first_block_id = descriptor.keys().position(&[ species[pair.first].into(), species[pair.second].into(), - ]).expect("missing block for the first atom"); - let first_block = &descriptor.blocks()[first_block_id]; + ]); + + let first_sample_position = if let Some(block_id) = first_block_id { + descriptor.block_by_id(block_id).values().samples.position(&[ + system_i.into(), pair.first.into() + ]) + } else { + None + }; let second_block_id = descriptor.keys().position(&[ species[pair.second].into(), species[pair.first].into(), - ]).expect("missing block for the second atom"); - let second_block = &descriptor.blocks()[second_block_id]; - - - let first_sample_position = first_block.values().samples.position(&[ - system_i.into(), pair.first.into() - ]); - let second_sample_position = second_block.values().samples.position(&[ - system_i.into(), pair.second.into() ]); + let second_sample_position = if let Some(block_id) = second_block_id { + descriptor.block_by_id(block_id).values().samples.position(&[ + system_i.into(), pair.second.into() + ]) + } else { + None + }; if first_sample_position.is_none() && second_sample_position.is_none() { continue; @@ -134,7 +139,8 @@ impl CalculatorBase for GeometricMoments { let n_neighbors_second = system.pairs_containing(pair.second)?.len() as f64; if let Some(sample_i) = first_sample_position { - let mut block = descriptor.block_mut_by_id(first_block_id); + let block_id = first_block_id.expect("we have a sample in this block"); + let mut block = descriptor.block_mut_by_id(block_id); let values = block.values_mut(); let array = values.data.as_array_mut(); @@ -145,7 +151,8 @@ impl CalculatorBase for GeometricMoments { } if let Some(sample_i) = second_sample_position { - let mut block = descriptor.block_mut_by_id(second_block_id); + let block_id = second_block_id.expect("we have a sample in this block"); + let mut block = descriptor.block_mut_by_id(block_id); let values = block.values_mut(); let array = values.data.as_array_mut(); @@ -166,7 +173,9 @@ impl CalculatorBase for GeometricMoments { } if let Some(sample_position) = first_sample_position { - let mut block = descriptor.block_mut_by_id(first_block_id); + let block_id = first_block_id.expect("we have a sample in this block"); + let mut block = descriptor.block_mut_by_id(block_id); + let gradient = block.gradient_mut("positions").expect("missing gradient storage"); let array = gradient.data.as_array_mut(); @@ -195,7 +204,9 @@ impl CalculatorBase for GeometricMoments { } if let Some(sample_position) = second_sample_position { - let mut block = descriptor.block_mut_by_id(second_block_id); + let block_id = second_block_id.expect("we have a sample in this block"); + let mut block = descriptor.block_mut_by_id(block_id); + let gradient = block.gradient_mut("positions").expect("missing gradient storage"); let array = gradient.data.as_array_mut(); @@ -258,27 +269,20 @@ mod tests { let descriptor = calculator.compute(&mut systems, Default::default()).unwrap(); // check the results - assert_eq!(descriptor.keys().names(), &["species_center", "species_neighbor"]); - assert_eq!(descriptor.keys().iter().collect::>(), [ - &[-42, 1], - &[1, -42], - &[1, 1], - &[1, 6], - &[6, 1] - ]); - - let mut expected_properties = LabelsBuilder::new(vec!["k"]); - expected_properties.add(&[0]); - let expected_properties = Arc::new(expected_properties.finish()); + assert_eq!(*descriptor.keys(), Labels::new( + ["species_center", "species_neighbor"], + &[[-42, 1], [1, -42], [1, 1], [1, 6], [6, 1]] + )); + + let expected_properties = Arc::new(Labels::new(["k"], &[[0]])); /**********************************************************************/ // O center, H neighbor let block = &descriptor.block_by_id(0); - let samples = &block.values().samples; - assert_eq!(samples.names(), ["structure", "center"]); - assert_eq!(samples.iter().collect::>(), [ - &[0, 0], - ]); + assert_eq!(*block.values().samples, Labels::new( + ["structure", "center"], + &[[0, 0]] + )); assert_eq!(block.values().properties, expected_properties); @@ -287,11 +291,10 @@ mod tests { /**********************************************************************/ // H center, O neighbor let block = &descriptor.block_by_id(1); - let samples = &block.values().samples; - assert_eq!(samples.names(), ["structure", "center"]); - assert_eq!(samples.iter().collect::>(), [ - &[0, 1], &[0, 2], - ]); + assert_eq!(*block.values().samples, Labels::new( + ["structure", "center"], + &[[0, 1], [0, 2]] + )); assert_eq!(block.values().properties, expected_properties); @@ -300,11 +303,10 @@ mod tests { /**********************************************************************/ // H center, H neighbor let block = &descriptor.block_by_id(2); - let samples = &block.values().samples; - assert_eq!(samples.names(), ["structure", "center"]); - assert_eq!(samples.iter().collect::>(), [ - &[0, 1], &[0, 2], - ]); + assert_eq!(*block.values().samples, Labels::new( + ["structure", "center"], + &[[0, 1], [0, 2]] + )); assert_eq!(block.values().properties, expected_properties); @@ -313,11 +315,10 @@ mod tests { /**********************************************************************/ // H center, C neighbor let block = &descriptor.block_by_id(3); - let samples = &block.values().samples; - assert_eq!(samples.names(), ["structure", "center"]); - assert_eq!(samples.iter().collect::>(), [ - &[1, 1], - ]); + assert_eq!(*block.values().samples, Labels::new( + ["structure", "center"], + &[[1, 1]] + )); assert_eq!(block.values().properties, expected_properties); @@ -326,11 +327,10 @@ mod tests { /**********************************************************************/ // C center, H neighbor let block = &descriptor.block_by_id(4); - let samples = &block.values().samples; - assert_eq!(samples.names(), ["structure", "center"]); - assert_eq!(samples.iter().collect::>(), [ - &[1, 0], - ]); + assert_eq!(*block.values().samples, Labels::new( + ["structure", "center"], + &[[1, 0]] + )); assert_eq!(block.values().properties, expected_properties); @@ -356,27 +356,25 @@ mod more_tests { let mut systems = test_systems(&["water", "methane"]); // build a list of samples to compute - let mut samples = LabelsBuilder::new(vec![ - "structure", "center" - ]); - samples.add(&[0, 1]); - samples.add(&[0, 2]); - samples.add(&[1, 0]); - samples.add(&[1, 2]); - let samples = samples.finish(); + let samples = Labels::new( + ["structure", "center"], + &[[0, 1], [0, 2], [1, 0], [1, 2]] + ); // create some properties. There is no need to order them in the same way // as the default calculator - let mut properties = LabelsBuilder::new(vec!["k"]); - properties.add(&[2]); - properties.add(&[1]); - properties.add(&[5]); - let properties = properties.finish(); + let properties = Labels::new(["k"], &[[2], [1], [5]]); + + // Some keys (more than the calculator would produce by default) + let keys = Labels::new( + ["species_center", "species_neighbor"], + &[[-42, 1], [1, 8], [1, -42], [8, 8], [1, 1], [1, 6], [6, 1]] + ); - // this function will check that selecting samples/properties or both will + // this function will check that selecting keys/samples/properties will // not change the result of the calculation crate::calculators::tests_utils::compute_partial( - calculator, &mut systems, &samples, &properties + calculator, &mut systems, &keys, &samples, &properties ); } // [partial-test] diff --git a/rascaline/src/tutorials/moments/s3_compute_3.rs b/rascaline/src/tutorials/moments/s3_compute_3.rs index c092a26b9..1040703c6 100644 --- a/rascaline/src/tutorials/moments/s3_compute_3.rs +++ b/rascaline/src/tutorials/moments/s3_compute_3.rs @@ -69,24 +69,32 @@ impl CalculatorBase for GeometricMoments { // get the block where the first atom is the center let first_block_id = descriptor.keys().position(&[ species[pair.first].into(), species[pair.second].into(), - ]).expect("missing block for the first atom"); - let first_block = &descriptor.blocks()[first_block_id]; + ]); + + // get the sample corresponding to the first atom as a center + // + // This will be `None` if the block or samples are not present + // in the descriptor, i.e. if the user did not request them. + let first_sample_position = if let Some(block_id) = first_block_id { + descriptor.block_by_id(block_id).values().samples.position(&[ + system_i.into(), pair.first.into() + ]) + } else { + None + }; // get the id of the block where the second atom is the center let second_block_id = descriptor.keys().position(&[ species[pair.second].into(), species[pair.first].into(), - ]).expect("missing block for the second atom"); - let second_block = &descriptor.blocks()[second_block_id]; - - // get the positions of the samples in their respective blocks. - // These variables will be `None` if the samples are not present - // in the blocks, i.e. if the user did not request them. - let first_sample_position = first_block.values().samples.position(&[ - system_i.into(), pair.first.into() - ]); - let second_sample_position = second_block.values().samples.position(&[ - system_i.into(), pair.second.into() ]); + // get the sample corresponding to the first atom as a center + let second_sample_position = if let Some(block_id) = second_block_id { + descriptor.block_by_id(block_id).values().samples.position(&[ + system_i.into(), pair.second.into() + ]) + } else { + None + }; // skip calculation if neither of the samples is present if first_sample_position.is_none() && second_sample_position.is_none() { diff --git a/rascaline/src/tutorials/moments/s3_compute_4.rs b/rascaline/src/tutorials/moments/s3_compute_4.rs index 3fc1d5678..8c4ebdc72 100644 --- a/rascaline/src/tutorials/moments/s3_compute_4.rs +++ b/rascaline/src/tutorials/moments/s3_compute_4.rs @@ -11,8 +11,8 @@ use crate::calculators::CalculatorBase; // these are here just to make the code below compile const first_sample_position: Option = None; const second_sample_position: Option = None; -const first_block_id: usize = 0; -const second_block_id: usize = 0; +const first_block_id: Option = None; +const second_block_id: Option = None; #[derive(Clone, Debug)] #[derive(serde::Serialize, serde::Deserialize)] @@ -74,7 +74,8 @@ impl CalculatorBase for GeometricMoments { let n_neighbors_second = system.pairs_containing(pair.second)?.len() as f64; if let Some(sample_i) = first_sample_position { - let mut block = descriptor.block_mut_by_id(first_block_id); + let block_id = first_block_id.expect("we have a sample in this block"); + let mut block = descriptor.block_mut_by_id(block_id); let values = block.values_mut(); let array = values.data.as_array_mut(); @@ -85,7 +86,8 @@ impl CalculatorBase for GeometricMoments { } if let Some(sample_i) = second_sample_position { - let mut block = descriptor.block_mut_by_id(second_block_id); + let block_id = second_block_id.expect("we have a sample in this block"); + let mut block = descriptor.block_mut_by_id(block_id); let values = block.values_mut(); let array = values.data.as_array_mut(); diff --git a/rascaline/src/tutorials/moments/s3_compute_5.rs b/rascaline/src/tutorials/moments/s3_compute_5.rs index 448ab92c7..8ac17787f 100644 --- a/rascaline/src/tutorials/moments/s3_compute_5.rs +++ b/rascaline/src/tutorials/moments/s3_compute_5.rs @@ -10,8 +10,8 @@ use crate::calculators::CalculatorBase; // these are here just to make the code below compile const first_sample_position: Option = None; const second_sample_position: Option = None; -const first_block_id: usize = 0; -const second_block_id: usize = 0; +const first_block_id: Option = None; +const second_block_id: Option = None; const n_neighbors_first: f64 = 0.0; const n_neighbors_second: f64 = 0.0; @@ -86,7 +86,9 @@ impl CalculatorBase for GeometricMoments { } if let Some(sample_position) = first_sample_position { - let mut block = descriptor.block_mut_by_id(first_block_id); + let block_id = first_block_id.expect("we have a sample in this block"); + let mut block = descriptor.block_mut_by_id(block_id); + let gradient = block.gradient_mut("positions").expect("missing gradient storage"); let array = gradient.data.as_array_mut(); @@ -118,7 +120,9 @@ impl CalculatorBase for GeometricMoments { } if let Some(sample_position) = second_sample_position { - let mut block = descriptor.block_mut_by_id(second_block_id); + let block_id = second_block_id.expect("we have a sample in this block"); + let mut block = descriptor.block_mut_by_id(block_id); + let gradient = block.gradient_mut("positions").expect("missing gradient storage"); let array = gradient.data.as_array_mut();