diff --git a/rascaline/src/calculators/bondatom_neighbor_list.rs b/rascaline/src/calculators/bondatom_neighbor_list.rs index f5c3c3da5..6ce88b9da 100644 --- a/rascaline/src/calculators/bondatom_neighbor_list.rs +++ b/rascaline/src/calculators/bondatom_neighbor_list.rs @@ -759,7 +759,7 @@ mod tests { use metatensor::Labels; use crate::calculators::bondatom_neighbor_list::FullBANeighborList; - use crate::systems::test_utils::{test_systems, test_system}; + use crate::systems::test_utils::{test_systems}; use crate::Calculator; use super::{BANeighborList, HalfBANeighborList}; diff --git a/rascaline/src/calculators/descriptors_by_systems.rs b/rascaline/src/calculators/descriptors_by_systems.rs index 34bf0aee4..d3eecdbb6 100644 --- a/rascaline/src/calculators/descriptors_by_systems.rs +++ b/rascaline/src/calculators/descriptors_by_systems.rs @@ -22,6 +22,19 @@ struct UnsafeArrayViewMut { data: *mut f64, } +impl UnsafeArrayViewMut{ + fn as_arrayview(&self) -> ndarray::ArrayView { + // SAFETY: we checked that the arrays do not overlap when creating + // `UnsafeArrayViewMut` in split_by_system + unsafe{ndarray::ArrayView::from_shape_ptr(self.shape.clone(), self.data)} + } + fn as_arrayview_mut(&mut self) -> ndarray::ArrayViewMut { + // SAFETY: we checked that the arrays do not overlap when creating + // `UnsafeArrayViewMut` in split_by_system + unsafe{ndarray::ArrayViewMut::from_shape_ptr(self.shape.clone(), self.data)} + } +} + // SAFETY: `UnsafeArrayViewMut` can be transferred from one thread to another unsafe impl Send for UnsafeArrayViewMut {} // SAFETY: `UnsafeArrayViewMut` is Sync since there is no interior mutability @@ -74,12 +87,7 @@ impl metatensor::Array for UnsafeArrayViewMut { /// Extract an array stored in the `TensorBlock` returned by `split_tensor_map_by_system` pub fn array_mut_for_system(array: metatensor::ArrayRefMut<'_>) -> ArrayViewMutD<'_, f64> { let array = array.to_any_mut().downcast_mut::().expect("invalid array type"); - - // SAFETY: we checked that the arrays do not overlap when creating - // `UnsafeArrayViewMut` in split_by_system - return unsafe { - ArrayViewMutD::from_shape_ptr(array.shape.clone(), array.data) - }; + array.as_arrayview_mut() } /// View inside a `TensorMap` corresponding to one system diff --git a/rascaline/src/calculators/mod.rs b/rascaline/src/calculators/mod.rs index a9b66771f..f8d61a812 100644 --- a/rascaline/src/calculators/mod.rs +++ b/rascaline/src/calculators/mod.rs @@ -96,6 +96,8 @@ pub(crate) use self::descriptors_by_systems::{array_mut_for_system, split_tensor pub mod soap; pub use self::soap::{SphericalExpansionByPair, SphericalExpansionParameters}; pub use self::soap::SphericalExpansion; +pub use self::soap::{SphericalExpansionForBondType, SphericalExpansionForBondsParameters}; +pub use self::soap::SphericalExpansionForBonds; pub use self::soap::{SoapPowerSpectrum, PowerSpectrumParameters}; pub use self::soap::{SoapRadialSpectrum, RadialSpectrumParameters}; diff --git a/rascaline/src/calculators/soap/mod.rs b/rascaline/src/calculators/soap/mod.rs index 8afa7d866..8ca13dcd8 100644 --- a/rascaline/src/calculators/soap/mod.rs +++ b/rascaline/src/calculators/soap/mod.rs @@ -12,9 +12,15 @@ pub use self::cutoff::RadialScaling; mod spherical_expansion_pair; pub use self::spherical_expansion_pair::{SphericalExpansionByPair, SphericalExpansionParameters}; +mod spherical_expansion_bondcentered_pair; +pub use self::spherical_expansion_bondcentered_pair::{SphericalExpansionForBondType, SphericalExpansionForBondsParameters}; + mod spherical_expansion; pub use self::spherical_expansion::SphericalExpansion; +mod spherical_expansion_bondcentered; +pub use self::spherical_expansion_bondcentered::SphericalExpansionForBonds; + mod power_spectrum; pub use self::power_spectrum::{SoapPowerSpectrum, PowerSpectrumParameters}; diff --git a/rascaline/src/calculators/soap/spherical_expansion_bondcentered.rs b/rascaline/src/calculators/soap/spherical_expansion_bondcentered.rs new file mode 100644 index 000000000..bd1fa77d9 --- /dev/null +++ b/rascaline/src/calculators/soap/spherical_expansion_bondcentered.rs @@ -0,0 +1,539 @@ +use std::collections::{BTreeMap}; +use std::collections::btree_map::Entry; + +use ndarray::s; +use rayon::prelude::*; + +use metatensor::{LabelsBuilder, Labels, LabelValue}; +use metatensor::TensorMap; + +use crate::{Error, System}; + +use crate::labels::{SamplesBuilder, SpeciesFilter, BondCenteredSamples}; +use crate::labels::{KeysBuilder, TwoCentersSingleNeighborsSpeciesKeys}; + +use super::super::CalculatorBase; + +use super::{SphericalExpansionForBondType, SphericalExpansionForBondsParameters}; +use super::spherical_expansion_pair::{GradientsOptions}; + +use super::super::{split_tensor_map_by_system, array_mut_for_system}; + + +/// The actual calculator used to compute SOAP spherical expansion coefficients +#[derive(Debug)] +pub struct SphericalExpansionForBonds { + /// Underlying calculator, computing spherical expansion on pair at the time + by_type: SphericalExpansionForBondType, + // /// Cache for (-1)^l values + //m_1_pow_l: Vec, +} + +impl SphericalExpansionForBonds { + /// Create a new `SphericalExpansion` calculator with the given parameters + pub fn new(parameters: SphericalExpansionForBondsParameters) -> Result { + // let m_1_pow_l = (0..=parameters.max_angular) + // .map(|l| f64::powi(-1.0, l as i32)) + // .collect::>(); + + return Ok(Self { + by_type: SphericalExpansionForBondType::new(parameters)?, + //m_1_pow_l, + }); + } +} + +impl CalculatorBase for SphericalExpansionForBonds { + fn name(&self) -> String { + "spherical expansion".into() + } + + fn cutoffs(&self) -> &[f64] { + &self.by_type.parameters.cutoffs + } + + fn parameters(&self) -> String { + serde_json::to_string(self.by_type.parameters()).expect("failed to serialize to JSON") + } + + fn keys(&self, systems: &mut [Box]) -> Result { + let builder = TwoCentersSingleNeighborsSpeciesKeys { + cutoffs: self.by_type.parameters.cutoffs, + self_contributions: true, + }; + let keys = builder.keys(systems)?; + + let mut builder = LabelsBuilder::new(vec!["spherical_harmonics_l", "species_center_1", "species_center_2", "species_neighbor"]); + for &[species_center_1, species_center_2, species_neighbor] in keys.iter_fixed_size() { + for spherical_harmonics_l in 0..=self.by_type.parameters().max_angular { + builder.add(&[spherical_harmonics_l.into(), species_center_1, species_center_2, species_neighbor]); + } + } + + return Ok(builder.finish()); + } + + fn samples_names(&self) -> Vec<&str> { + BondCenteredSamples::samples_names() + } + + fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + assert_eq!(keys.names(), ["spherical_harmonics_l", "species_center_1", "species_center_2", "species_neighbor"]); + + // only compute the samples once for each `species_center, species_neighbor`, + // and re-use the results across `spherical_harmonics_l`. + let mut samples_per_species = BTreeMap::new(); + for [_, species_center_1, species_center_2, species_neighbor] in keys.iter_fixed_size() { + if samples_per_species.contains_key(&(species_center_1, species_center_2, species_neighbor)) { + continue; + } + + let builder = BondCenteredSamples { + cutoffs: self.by_type.parameters().cutoffs, + species_center_1: SpeciesFilter::Single(species_center_1.i32()), + species_center_2: SpeciesFilter::Single(species_center_2.i32()), + species_neighbor: SpeciesFilter::Single(species_neighbor.i32()), + self_contributions: true, + }; + + samples_per_species.insert((species_center_1, species_center_2, species_neighbor), builder.samples(systems)?); + } + + let mut result = Vec::new(); + for [_, species_center_1, species_center_2, species_neighbor] in keys.iter_fixed_size() { + let samples = samples_per_species.get( + &(species_center_1, species_center_2, species_neighbor) + ).expect("missing samples"); + + result.push(samples.clone()); + } + + return Ok(result); + } + + fn supports_gradient(&self, parameter: &str) -> bool { + self.by_type.supports_gradient(parameter) + } + + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + assert_eq!(keys.names(), ["spherical_harmonics_l", "species_center_1", "species_center_2", "species_neighbor"]); + assert_eq!(keys.count(), samples.len()); + + let mut gradient_samples = Vec::new(); + for ([_, species_center_1, species_center_2, species_neighbor], samples) in keys.iter_fixed_size().zip(samples) { + // TODO: we don't need to rebuild the gradient samples for different + // spherical_harmonics_l + let builder = BondCenteredSamples { + cutoffs: self.by_type.parameters().cutoffs, + species_center_1: SpeciesFilter::Single(species_center_1.i32()), + species_center_2: SpeciesFilter::Single(species_center_2.i32()), + species_neighbor: SpeciesFilter::Single(species_neighbor.i32()), + self_contributions: true, + }; + + gradient_samples.push(builder.gradients_for(systems, samples)?); + } + + return Ok(gradient_samples); + } + + fn components(&self, keys: &Labels) -> Vec> { + assert_eq!(keys.names(), ["spherical_harmonics_l", "species_center_1", "species_center_2", "species_neighbor"]); + + // only compute the components once for each `spherical_harmonics_l`, + // and re-use the results across `species_center, species_neighbor`. + let mut component_by_l = BTreeMap::new(); + for [spherical_harmonics_l, _, _, _] in keys.iter_fixed_size() { + if component_by_l.contains_key(spherical_harmonics_l) { + continue; + } + + let mut component = LabelsBuilder::new(vec!["spherical_harmonics_m"]); + for m in -spherical_harmonics_l.i32()..=spherical_harmonics_l.i32() { + component.add(&[LabelValue::new(m)]); + } + + let components = vec![component.finish()]; + component_by_l.insert(*spherical_harmonics_l, components); + } + + let mut result = Vec::new(); + for [spherical_harmonics_l, _, _, _] in keys.iter_fixed_size() { + let components = component_by_l.get(spherical_harmonics_l).expect("missing samples"); + result.push(components.clone()); + } + return result; + } + + fn properties_names(&self) -> Vec<&str> { + vec!["n"] + } + + fn properties(&self, keys: &Labels) -> Vec { + let mut properties = LabelsBuilder::new(self.properties_names()); + for n in 0..self.by_type.parameters().max_radial { + properties.add(&[n]); + } + let properties = properties.finish(); + + return vec![properties; keys.count()]; + } + + #[time_graph::instrument(name = "SphericalExpansion::compute")] + fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + assert_eq!(descriptor.keys().names(), ["spherical_harmonics_l", "species_center_1", "species_center_2", "species_neighbor"]); + if descriptor.blocks().len() == 0 { + return Ok(()); + } + + let max_angular = self.by_type.parameters().max_angular; + let l_slices: Vec<_> = (0..=max_angular).map(|l|{ + let lsize = l*l; + let msize = 2*l+1; + lsize..lsize+msize + }).collect(); + + let do_gradients = GradientsOptions { + positions: descriptor.block_by_id(0).gradient("positions").is_some(), + cell: descriptor.block_by_id(0).gradient("cell").is_some(), + }; + if do_gradients.positions { + assert!(self.by_type.supports_gradient("positions")); + } + if do_gradients.cell { + assert!(self.by_type.supports_gradient("cell")); + } + + let radial_selection = descriptor.blocks().iter().map(|b|{ + let prop = b.properties(); + assert_eq!(prop.names(), ["n"]); + prop.iter_fixed_size().map(|&[n]|n.i32()).collect::>() + }).collect::>(); + // first, create some partial-key -> block lookup tables to avoid linear searches within blocks later + + // {(s1,s2,s3) -> i_s3} + let mut s1s2s3_to_is3: BTreeMap<(i32,i32,i32),usize> = BTreeMap::new(); + // {(s1,s2) -> [i_s3->(s3,[l->i_block])]} + let mut s1s2_to_block_ids: BTreeMap<(i32,i32),Vec<(i32,Vec)>> = BTreeMap::new(); + + for (block_i, &[l, s1,s2,s3]) in descriptor.keys().iter_fixed_size().enumerate(){ + let s1=s1.i32(); + let s2=s2.i32(); + let s3=s3.i32(); + let l=l.usize(); + let s1s2_blocks = s1s2_to_block_ids.entry((s1,s2)) + .or_insert_with(Vec::new); + let l_blocks = match s1s2s3_to_is3.entry((s1,s2,s3)) { + Entry::Occupied(i_s3_e) => { + let (s3_b, l_blocks) = & mut s1s2_blocks[*i_s3_e.get()]; + debug_assert_eq!(s3_b,&s3); + l_blocks + }, + Entry::Vacant(i_s3_e) => { + let i_s3 = s1s2_blocks.len(); + i_s3_e.insert(i_s3); + s1s2_blocks.push((s3,vec![usize::MAX;max_angular+1])); + &mut s1s2_blocks[i_s3].1 + }, + }; + l_blocks[l] = block_i; + } + + #[cfg(debug_assertions)]{ + for block in descriptor.blocks() { + assert_eq!(block.samples().names(), ["structure", "first_center", "second_center", "bond_i"]); + } + } + let mut descriptors_by_system = split_tensor_map_by_system(descriptor, systems.len()); + + systems.par_iter_mut() + .zip_eq(&mut descriptors_by_system) + .try_for_each(|(system, descriptor)| + { + system.compute_triplet_neighbors(self.by_type.parameters().bond_cutoff(), self.by_type.parameters().third_cutoff())?; + let system = &**system; + let species = system.species()?; + let triplets = system.triplets()?; + + for ((s1,s2),s1s2_blocks) in s1s2_to_block_ids.iter() { + let (s3_list,per_s3_blocks): (Vec,Vec<&Vec<_>>) = s1s2_blocks.iter().map( + |(s3,blocks)|(*s3,blocks) + ).unzip(); + // half-assume that blocks that share s1,s2,s3 have the same sample list + #[cfg(debug_assertions)]{ + for (_s3,s3blocks) in s1s2_blocks.iter(){ + debug_assert!(s3blocks.len()>0); + let mut s3goodblocks = s3blocks.iter().filter(|b_i|(**b_i)!=usize::MAX); + let first_goodblock = s3goodblocks.next(); + debug_assert!(first_goodblock.is_some()); + + let samples_n = descriptor.block_by_id(*first_goodblock.unwrap()).samples().size(); + for lblock in s3goodblocks { + debug_assert_eq!(descriptor.block_by_id(*lblock).samples().size(), samples_n); + } + } + } + // {bond_i->(i_s3,sample_i)} + let mut s3_samples = vec![]; + let mut sample_lut: BTreeMap> = BTreeMap::new(); + + // also assume that the systems are in order in the samples + for (i_s3, s3blocks) in per_s3_blocks.into_iter().enumerate() { + let first_good_block = s3blocks.iter().filter(|b_i|**b_i!=usize::MAX).next().unwrap(); + let samples = descriptor.block_by_id(*first_good_block).samples(); + for (sample_i, &[_system_i,_atom_1,_atom_2,bond_i]) in samples.iter_fixed_size().enumerate(){ + match sample_lut.entry(bond_i.usize()) { + Entry::Vacant(e) => { + e.insert(vec![(i_s3,sample_i)]); + }, + Entry::Occupied(mut e) => { + e.get_mut().push((i_s3,sample_i)); + }, + } + } + s3_samples.push(samples); + } + for (triplet_i,inverted,contribution) in self.by_type.get_coefficients_for(system, *s1, *s2, &s3_list, do_gradients)? { + let triplet = &triplets[triplet_i]; + + #[cfg(debug_assertions)]{ + if triplet.bond.first == triplet.bond.second{ + unimplemented!("cannot deal with bonds formed of self-images quite yet.") + } + } + + let contribution = contribution.borrow(); + let these_samples = match sample_lut.get(&triplet.bond_i){ + None => {continue;}, + Some(a) => a, + }; + + for (i_s3,sample_i) in these_samples.iter(){ + if s3_list[*i_s3] != species[triplet.third] { + continue // this triplet does not contribute to this block + } + let sample = &s3_samples[*i_s3][*sample_i]; + let (atom_1,atom_2) = (sample[1],sample[2]); + if (!inverted) && triplet.bond.first != atom_1.usize(){ + continue; + } else if inverted && triplet.bond.first != atom_2.usize(){ + continue; + } + + let ret_blocks = &s1s2_blocks[*i_s3].1; + for (l,lslice) in l_slices.iter().enumerate() { + let block_i = ret_blocks[l]; + if block_i == usize::MAX { + continue; + } + let mut block = descriptor.block_mut_by_id(block_i); + let mut array = array_mut_for_system(block.values_mut()); + let mut value_slice = array.slice_mut(s![*sample_i,..,..]); + let input_slice = contribution.values.slice(s![lslice.clone(),..]); + for (n_i,n) in radial_selection[block_i].iter().enumerate() { + let mut value_slice = value_slice.slice_mut(s![..,n_i]); + value_slice += &input_slice.slice(s![..,*n]); + } + + } + } + } + } + Ok::<_, Error>(()) + })?; + + Ok(()) + } +} + + +#[cfg(test)] +mod tests { + use ndarray::ArrayD; + use metatensor::{Labels, TensorBlock, EmptyArray, LabelsBuilder, TensorMap}; + + use crate::systems::test_utils::{test_systems}; + use crate::{Calculator, CalculationOptions, LabelsSelection}; + use crate::calculators::CalculatorBase; + + use super::{SphericalExpansionForBonds, SphericalExpansionForBondsParameters}; + use super::super::{CutoffFunction, RadialScaling}; + use crate::calculators::radial_basis::RadialBasis; + + + fn parameters() -> SphericalExpansionForBondsParameters { + SphericalExpansionForBondsParameters { + cutoffs: [3.5,3.5], + max_radial: 6, + max_angular: 6, + atomic_gaussian_width: 0.3, + center_atoms_weight: 10.0, + radial_basis: RadialBasis::splined_gto(1e-8), + radial_scaling: RadialScaling::Willatt2018 { scale: 1.5, rate: 0.8, exponent: 2.0}, + cutoff_function: CutoffFunction::ShiftedCosine { width: 0.5 }, + } + } + + #[test] + fn values() { + let mut calculator = Calculator::from(Box::new(SphericalExpansionForBonds::new( + parameters() + ).unwrap()) as Box); + + let mut systems = test_systems(&["water"]); + let descriptor = calculator.compute(&mut systems, Default::default()).unwrap(); + + for l in 0..6 { + for species_center1 in [1, -42] { + for species_center2 in [1, -42] { + if species_center1==-42 && species_center2==-42 { + continue; + } + for species_neighbor in [1, -42] { + let block_i = descriptor.keys().position(&[ + l.into(), species_center1.into(), species_center2.into(), species_neighbor.into() + ]); + assert!(block_i.is_some()); + let block = &descriptor.block_by_id(block_i.unwrap()); + let array = block.values().to_array(); + assert_eq!(array.shape().len(), 3); + assert_eq!(array.shape()[1], 2 * l + 1); + } + } + } + } + + // exact values for spherical expansion are regression-tested in + // `rascaline/tests/spherical-expansion.rs` + } + + #[test] + fn compute_partial() { + let calculator = Calculator::from(Box::new(SphericalExpansionForBonds::new( + SphericalExpansionForBondsParameters { + max_angular: 2, + ..parameters() + } + ).unwrap()) as Box); + + let mut systems = test_systems(&["water"]); + + let properties = Labels::new(["n"], &[ + [0], + [3], + [2], + ]); + + let samples = Labels::new(["structure", "first_center", "second_center", "bond_i"], &[ + [0, 0, 2, 1], + [0, 0, 1, 0], + //[0, 1, 2, 2], // excluding this one + ]); + + let keys = Labels::new(["spherical_harmonics_l", "species_center_1", "species_center_2", "species_neighbor"], &[ + [0, -42, 1, -42], + [2, -42, 1, -42], + [0, 1, -42, -42], + [0, 1, 1, -42], + [0, 6, 1, 1], // not part of the default keys + [1, -42, 1, -42], + [1, -42, 1, 1], + [1, 1, -42, 1], + [0, -42, 1, 1], + [1, 1, 1, -42], + [2, -42, 1, 1], + [0, 1, 1, 1], + [2, 1, -42, -42], + [1, 1, 1, 1], + [2, 1, 1, -42], + [0, 1, -42, 1], + [2, 1, -42, 1], + [2, 1, 1, 1], + [1, 1, -42, -42], + ]); + + crate::calculators::tests_utils::compute_partial( + calculator, &mut systems, &keys, &samples, &properties + ); + } + + #[test] + fn non_existing_samples() { + let mut calculator = Calculator::from(Box::new(SphericalExpansionForBonds::new( + parameters() + ).unwrap()) as Box); + + let mut systems = test_systems(&["water"]); + + // include the three atoms in all blocks, regardless of the + // species_center key. + let block = TensorBlock::new( + EmptyArray::new(vec![3, 1]), + &Labels::new(["structure", "first_center", "second_center", "bond_i"], &[ + [0, 0, 2, 1], + [0, 1, 2, 2], + [0, 0, 1, 0], + ]), + &[], + &Labels::single(), + ).unwrap(); + + let mut keys = LabelsBuilder::new(vec!["spherical_harmonics_l", "species_center_1", "species_center_2", "species_neighbor"]); + let mut blocks = Vec::new(); + for l in 0..(parameters().max_angular + 1) as isize { + for species_center1 in [1, -42] { + for species_center2 in [1, -42] { + for species_neighbor in [1, -42] { + keys.add(&[l, species_center1, species_center2, species_neighbor]); + blocks.push(block.as_ref().try_clone().unwrap()); + } + } + } + } + let select_all_samples = TensorMap::new(keys.finish(), blocks).unwrap(); + + let options = CalculationOptions { + selected_samples: LabelsSelection::Predefined(&select_all_samples), + ..Default::default() + }; + let descriptor = calculator.compute(&mut systems, options).unwrap(); + + // get the block for oxygen + assert_eq!(descriptor.keys().names(), ["spherical_harmonics_l", "species_center_1", "species_center_2", "species_neighbor"]); + assert_eq!(descriptor.keys()[0], [0, -42, 1, -42]); + + let block = descriptor.block_by_id(0); + let block = block.data(); + + // entries centered on H atoms should be zero + assert_eq!( + *block.samples, + Labels::new(["structure", "first_center", "second_center", "bond_i"], &[ + [0, 0, 2, 1], + [0, 1, 2, 2], // the sample that doesn't exist + [0, 0, 1, 0], + ]) + ); + let array = block.values.as_array(); + assert_eq!(array.index_axis(ndarray::Axis(0), 1), ArrayD::from_elem(vec![1, 6], 0.0)); + + // get the block for hydrogen + assert_eq!(descriptor.keys().names(), ["spherical_harmonics_l", "species_center_1", "species_center_2", "species_neighbor"]); + assert_eq!(descriptor.keys()[21], [0, 1, -42, 1]); + + let block = descriptor.block_by_id(21); + let block = block.data(); + + // entries centered on O atoms should be zero + assert_eq!( + *block.samples, + Labels::new(["structure", "first_center", "second_center", "bond_i"], &[ + [0, 0, 2, 1], + [0, 1, 2, 2], + [0, 0, 1, 0], + ]) + ); + let array = block.values.as_array(); + assert_eq!(array.index_axis(ndarray::Axis(0), 0), ArrayD::from_elem(vec![1, 6], 0.0)); + } +} diff --git a/rascaline/src/calculators/soap/spherical_expansion_bondcentered_pair.rs b/rascaline/src/calculators/soap/spherical_expansion_bondcentered_pair.rs new file mode 100644 index 000000000..d8b83d3fa --- /dev/null +++ b/rascaline/src/calculators/soap/spherical_expansion_bondcentered_pair.rs @@ -0,0 +1,733 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::collections::btree_map::Entry; +use std::cell::RefCell; + +use ndarray::s; +use thread_local::ThreadLocal; +use rayon::prelude::*; + +use metatensor::{Labels, LabelsBuilder, LabelValue, TensorMap}; + +use crate::{Error, System, Vector3D}; + +use crate::math::SphericalHarmonicsCache; + +use super::super::CalculatorBase; +use super::super::bondatom_neighbor_list::{FullBANeighborList,BANeighborList}; + +use super::{CutoffFunction, RadialScaling}; + +use crate::calculators::radial_basis::RadialBasis; +use super::SoapRadialIntegralCache; + +use super::radial_integral::SoapRadialIntegralParameters; + +use super::spherical_expansion_pair::{ + SphericalExpansionParameters, + SphericalExpansionByPair, + GradientsOptions, + PairContribution +}; + +/// Parameters for spherical expansion calculator. +/// +/// The spherical expansion is at the core of representations in the SOAP +/// (Smooth Overlap of Atomic Positions) family. See [this review +/// article](https://doi.org/10.1063/1.5090481) for more information on the SOAP +/// representation, and [this paper](https://doi.org/10.1063/5.0044689) for +/// information on how it is implemented in rascaline. +#[derive(Debug, Clone)] +#[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)] +pub struct SphericalExpansionForBondsParameters { + /// Spherical cutoffs to use for atomic environments + pub(super) cutoffs: [f64;2], // bond_, third_cutoff + //pub bond_cutoff: f64, + //pub third_cutoff: f64, + //pub compute_both_sides: bool, + /// Number of radial basis function to use in the expansion + pub max_radial: usize, + /// Number of spherical harmonics to use in the expansion + pub max_angular: usize, + /// Width of the atom-centered gaussian used to create the atomic density + pub atomic_gaussian_width: f64, + /// Weight of the central atom contribution to the + /// features. If `1` the center atom contribution is weighted the same + /// as any other contribution. If `0` the central atom does not + /// contribute to the features at all. + pub center_atoms_weight: f64, + /// Radial basis to use for the radial integral + pub radial_basis: RadialBasis, + /// Cutoff function used to smooth the behavior around the cutoff radius + pub cutoff_function: CutoffFunction, + /// radial scaling can be used to reduce the importance of neighbor atoms + /// further away from the center, usually improving the performance of the + /// model + #[serde(default)] + pub radial_scaling: RadialScaling, +} + +impl SphericalExpansionForBondsParameters { + /// Validate all the parameters + pub fn validate(&self) -> Result<(), Error> { + self.cutoff_function.validate()?; + self.radial_scaling.validate()?; + + // try constructing a radial integral + SoapRadialIntegralCache::new(self.radial_basis.clone(), SoapRadialIntegralParameters { + max_radial: self.max_radial, + max_angular: self.max_angular, + atomic_gaussian_width: self.atomic_gaussian_width, + cutoff: self.third_cutoff(), + })?; + + return Ok(()); + } + pub fn bond_cutoff(&self) -> f64 { + self.cutoffs[0] + } + pub fn third_cutoff(&self) -> f64 { + self.cutoffs[1] + } +} + +impl Into for SphericalExpansionForBondsParameters { + fn into(self) -> SphericalExpansionParameters{ + SphericalExpansionParameters{ + cutoff: self.third_cutoff(), + max_radial: self.max_radial, + max_angular: self.max_angular, + atomic_gaussian_width: self.atomic_gaussian_width, + center_atom_weight: self.center_atoms_weight, + radial_basis: self.radial_basis, + cutoff_function: self.cutoff_function, + radial_scaling: self.radial_scaling, + } + } +} + +/// The actual calculator used to compute spherical expansion pair-by-pair +pub struct SphericalExpansionForBondType { + pub(crate) parameters: SphericalExpansionForBondsParameters, + /// several functions require the SphericalExpansionForBonds to behave like a regular spherical expansion + /// let's store most of the data in an actual SphericalExpansionByPair object! + faker: SphericalExpansionByPair, + distance_calculator: BANeighborList, + +} + +impl std::fmt::Debug for SphericalExpansionForBondType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.parameters) + } +} + + +impl SphericalExpansionForBondType { + pub fn new(parameters: SphericalExpansionForBondsParameters) -> Result { + parameters.validate()?; + + let m_1_pow_l = (0..=parameters.max_angular) + .map(|l| f64::powi(-1.0, l as i32)) + .collect::>(); + + Ok(SphericalExpansionForBondType { + faker: SphericalExpansionByPair{ + parameters: parameters.clone().into(), + radial_integral: ThreadLocal::new(), + spherical_harmonics: ThreadLocal::new(), + m_1_pow_l, + }, + distance_calculator: BANeighborList::Full(FullBANeighborList { + cutoffs: parameters.cutoffs, + bond_contribution: true, + }), + parameters, + }) + } + + /// Access the spherical expansion parameters used by this calculator + pub fn parameters(&self) -> &SphericalExpansionForBondsParameters { + &self.parameters + } + + /// Compute the product of radial scaling & cutoff smoothing functions + fn scaling_functions(&self, r: f64) -> f64 { + self.faker.scaling_functions(r) + } + + /// Compute the gradient of the product of radial scaling & cutoff smoothing functions + fn scaling_functions_gradient(&self, r: f64) -> f64 { + self.faker.scaling_functions_gradient(r) + } + + /// Compute the self-contribution (contribution coming from an atom "seeing" + /// it's own density). This is equivalent to a normal pair contribution, + /// with a distance of 0. + /// + /// For now, the same density is used for all atoms, so this function can be + /// called only once and re-used for all atoms (see `do_self_contributions` + /// below). + /// + /// By symmetry, the self-contribution is only non-zero for `L=0`, and does + /// not contributes to the gradients. + pub(super) fn compute_coefficients(&self, contribution: &mut PairContribution, vector: Vector3D, is_self_contribution: bool, gradients: Option<(Vector3D,Vector3D,Vector3D)>){ + let mut radial_integral = self.faker.radial_integral.get_or(|| { + let radial_integral = SoapRadialIntegralCache::new( + self.parameters.radial_basis.clone(), + SoapRadialIntegralParameters { + max_radial: self.parameters.max_radial, + max_angular: self.parameters.max_angular, + atomic_gaussian_width: self.parameters.atomic_gaussian_width, + cutoff: self.parameters.third_cutoff(), + } + ).expect("invalid radial integral parameters"); + return RefCell::new(radial_integral); + }).borrow_mut(); + + let mut spherical_harmonics = self.faker.spherical_harmonics.get_or(|| { + RefCell::new(SphericalHarmonicsCache::new(self.parameters.max_angular)) + }).borrow_mut(); + + let distance = vector.norm(); + let direction = vector/distance; + // Compute the three factors that appear in the center contribution. + // Note that this is simply the pair contribution for the special + // case where the pair distance is zero. + radial_integral.compute(distance, gradients.is_some()); + spherical_harmonics.compute(direction, gradients.is_some()); + + let f_scaling = self.scaling_functions(distance); + let f_scaling = if is_self_contribution{ + f_scaling * self.parameters.center_atoms_weight + } else { + f_scaling + }; + + let (values, gradient_values_o) = (&mut contribution.values, contribution.gradients.as_mut()); + + debug_assert_eq!( + values.shape(), + [(self.parameters.max_angular+1)*(self.parameters.max_angular+1), self.parameters.max_radial] + ); + for l in 0..=self.parameters.max_angular { + let l_offset = l*l; + let msize = 2*l+1; + //values.slice_mut(s![l_offset..l_offset+msize, ..]) *= spherical_harmonics.values.slice(l); + for m in 0..msize { + let lm = l_offset+m; + for n in 0..self.parameters.max_radial { + values[[lm, n]] = spherical_harmonics.values[lm] + * radial_integral.values[[l,n]] + * f_scaling; + } + } + } + + if let Some((dvdx,dvdy,dvdz)) = gradients { + let gradient_values = gradient_values_o.unwrap(); + + let ilen = 1./distance; + let dlendv = vector*ilen; + let dlendx = dlendv*dvdx; + let dlendy = dlendv*dvdy; + let dlendz = dlendv*dvdz; + let ddirdx = dvdx*ilen - vector*dlendx*ilen*ilen; + let ddirdy = dvdy*ilen - vector*dlendy*ilen*ilen; + let ddirdz = dvdy*ilen - vector*dlendz*ilen*ilen; + + let single_grad = |l,n,m,dlenda,ddirda: Vector3D| { + f_scaling * ( + radial_integral.gradients[[l,n]]*dlenda*spherical_harmonics.values[[l as isize,m as isize]] + + radial_integral.values[[l,n]]*( + spherical_harmonics.gradients[0][[l as isize,m as isize]]*ddirda[0] + +spherical_harmonics.gradients[1][[l as isize,m as isize]]*ddirda[1] + +spherical_harmonics.gradients[2][[l as isize,m as isize]]*ddirda[2] + ) + // todo scaling_function_gradient + ) + }; + + for l in 0..=self.parameters.max_angular { + let l_offset = l*l; + let msize = 2*l+1; + for m in 0..(msize) { + let lm = l_offset+m; + for n in 0..self.parameters.max_radial { + gradient_values[[0,lm,n]] = single_grad(l,n,m,dlendx,ddirdx); + gradient_values[[1,lm,n]] = single_grad(l,n,m,dlendy,ddirdy); + gradient_values[[2,lm,n]] = single_grad(l,n,m,dlendz,ddirdz); + } + } + } + } + } + + /// a smart-ish way to obtain the coefficients of all bond expansions: + /// this function's API is designed to be resource-efficient for both SphericalExpansionForBondType and + /// SphericalExpansionForBonds, while being computationally efficient for the underlying BANeighborList calculator. + pub(super) fn get_coefficients_for<'a>( + &'a self, system: &'a dyn System, + s1: i32, s2: i32, s3_list: &'a Vec, + do_gradients: GradientsOptions, + ) -> Result>)> + 'a, Error> { + + let max_angular = self.parameters.max_angular; + let max_radial = self.parameters.max_radial; + let species = system.species().unwrap(); + + let _ = system.triplets()?; + for s3 in s3_list{ + let triplets = system.triplets_with_species(s1, s2, *s3); + if let Err(error) = triplets { + return Err(error); + } + } + + let pre_iter = s3_list.iter().flat_map(|s3|{ + let all_triplets = system.triplets().unwrap(); + let triplets = system.triplets_with_species(s1, s2, *s3).unwrap(); + triplets.iter().map(|triplet_i|{ + let triplet = &all_triplets[*triplet_i]; + #[cfg(debug_assertions)]{ + assert_eq!(species[triplet.bond.first],s1); + assert_eq!(species[triplet.bond.second],s2); + assert_eq!(species[triplet.third],*s3); + } + (*triplet_i,triplet) + }) + }).flat_map(|(triplet_i, triplet)| { + let invert: &[bool] = { + if s1==s2 {&[false,true]} + else if species[triplet.bond.first] == s1 {&[false]} + else {&[true]} + }; + invert.iter().map(move |invert|(triplet_i,triplet,*invert)) + }).collect::>(); + + let contribution = std::rc::Rc::new(RefCell::new( + PairContribution::new(max_radial, max_angular, do_gradients.either()) + )); + + let mut mtx_cache = BTreeMap::new(); + let mut dmtx_cache = BTreeMap::new(); + + return Ok(pre_iter.into_iter().map(move |(triplet_i,triplet,invert)| { + let vector = BANeighborList::compute_single_triplet(triplet, invert, false, &mut mtx_cache, &mut dmtx_cache); + self.compute_coefficients(&mut *contribution.borrow_mut(), vector.vect,triplet.is_self_contribution,None); + (triplet_i, invert, contribution.clone()) + })); + + } +} + + +impl CalculatorBase for SphericalExpansionForBondType { + fn name(&self) -> String { + "spherical expansion by pair".into() + } + + fn cutoffs(&self) -> &[f64] { + &self.parameters.cutoffs + } + + fn parameters(&self) -> String { + serde_json::to_string(&self.parameters).expect("failed to serialize to JSON") + } + + fn keys(&self, systems: &mut [Box]) -> Result { + // the species part of the keys is the same for all l + let species_keys = self.distance_calculator.keys(systems)?; + + let all_species_triplets = species_keys.iter().map(|p| (p[0], p[1], p[2])).collect::>(); + + let mut keys = LabelsBuilder::new(vec![ + "spherical_harmonics_l", + "species_bond_atom_1", + "species_bond_atom_2", + "species_third_atom", + ]); + + for (s1, s2, s3) in all_species_triplets { + for l in 0..=self.parameters.max_angular { + keys.add(&[l.into(), s1, s2, s3]); + } + } + + + return Ok(keys.finish()); + } + + fn samples_names(&self) -> Vec<&str> { + return vec!["structure", "triplet_i", "bond_i", "first_bond_atom", "second_bond_atom", "third_atom"]; + } + + fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + let newkey_values = keys.iter_fixed_size().map(|&[_l, s1,s2,s3]|{ [s1.i32(),s2.i32(),s3.i32()] }).collect::>(); + let mut newkeys_lut = vec![0_usize;newkey_values.len()]; + let new_unique_keys: Vec<[i32;3]> = BTreeSet::from_iter(newkey_values.iter()).into_iter().map(|t|t.clone()).collect(); // note: this step sorts the keys + for (key,f_index) in newkey_values.iter().zip(newkeys_lut.iter_mut()) { + *f_index = new_unique_keys.binary_search(key).expect("unreachable: new_unique_keys was constructed with all keys"); + } + + let newkeys = Labels::new( + ["species_first_bond_atom", "species_second_bond_atom", "species_third_atom"], + &new_unique_keys, + ); + let samples = self.distance_calculator.samples(&newkeys, systems)?; + let ret = Ok(newkeys_lut.into_iter().map(|i|samples[i].clone()).collect::>()); + ret + } + + fn supports_gradient(&self, parameter: &str) -> bool { + return false; + self.distance_calculator.supports_gradient(parameter) + } + + fn positions_gradient_samples(&self, _keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + self.distance_calculator.positions_gradient_samples(_keys, samples, systems) + } + + fn components(&self, keys: &Labels) -> Vec> { + assert_eq!(keys.names().len(), 4); + assert_eq!(keys.names()[0], "spherical_harmonics_l"); + + let mut result = Vec::new(); + // only compute the components once for each `spherical_harmonics_l`, + // and re-use the results across the other keys. + let mut cache: BTreeMap<_, Vec> = BTreeMap::new(); + for &[spherical_harmonics_l, _, _, _] in keys.iter_fixed_size() { + let components = match cache.entry(spherical_harmonics_l) { + Entry::Occupied(entry) => entry.get().clone(), + Entry::Vacant(entry) => { + let mut component = LabelsBuilder::new(vec!["spherical_harmonics_m"]); + for m in -spherical_harmonics_l.i32()..=spherical_harmonics_l.i32() { + component.add(&[LabelValue::new(m)]); + } + + let components = vec![component.finish()]; + entry.insert(components).clone() + } + }; + + result.push(components); + } + + return result; + } + + fn properties_names(&self) -> Vec<&str> { + vec!["n"] + } + + fn properties(&self, keys: &Labels) -> Vec { + let mut properties = LabelsBuilder::new(self.properties_names()); + for n in 0..self.parameters.max_radial { + properties.add(&[n]); + } + + return vec![properties.finish(); keys.count()]; + } + + #[time_graph::instrument(name = "SphericalExpansionByBondType::compute")] + fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + assert_eq!(descriptor.keys().names(), ["spherical_harmonics_l", "species_bond_atom_1", "species_bond_atom_2", "species_third_atom"]); + + let max_angular = self.parameters.max_angular; + let l_slices: Vec<_> = (0..=max_angular).map(|l|{ + let lsize = l*l; + let msize = 2*l+1; + lsize..lsize+msize + }).collect(); + + let do_gradients = GradientsOptions { + positions: descriptor.block_by_id(0).gradient("positions").is_some(), + cell: descriptor.block_by_id(0).gradient("cell").is_some(), + }; + if do_gradients.positions { + assert!(self.distance_calculator.supports_gradient("positions")); + } + if do_gradients.cell { + assert!(self.distance_calculator.supports_gradient("cell")); + } + + // first, create some partial-key -> block lookup tables to avoid linear searches within blocks later + + // {(s1,s2,s3) -> i_s3} + let mut s1s2s3_to_is3: BTreeMap<(i32,i32,i32),usize> = BTreeMap::new(); + // {(s1,s2) -> [i_s3->(s3,[l->i_block])]} + let mut s1s2_to_block_ids: BTreeMap<(i32,i32),Vec<(i32,Vec)>> = BTreeMap::new(); + + for (block_i, &[l, s1,s2,s3]) in descriptor.keys().iter_fixed_size().enumerate(){ + let s1=s1.i32(); + let s2=s2.i32(); + let s3=s3.i32(); + let l=l.usize(); + let s1s2_blocks = s1s2_to_block_ids.entry((s1,s2)) + .or_insert_with(Vec::new); + let l_blocks = match s1s2s3_to_is3.entry((s1,s2,s3)) { + Entry::Occupied(i_s3_e) => { + let (s3_b,l_blocks) = &mut s1s2_blocks[*i_s3_e.get()]; + debug_assert_eq!(s3_b, &s3); + l_blocks + }, + Entry::Vacant(i_s3_e) => { + let i_s3 = s1s2_blocks.len(); + i_s3_e.insert(i_s3); + s1s2_blocks.push((s3,vec![usize::MAX;max_angular+1])); + &mut s1s2_blocks[i_s3].1 + }, + }; + l_blocks[l] = block_i; + } + + // then, for every of those partial-keys construct a similar lookup table that helps select + // the right blocks and samples for the given compound and species. + for ((s1,s2),s1s2_blocks) in s1s2_to_block_ids.iter() { + let s3_list: Vec = s1s2_blocks.iter().map(|t|t.0).collect(); + // half-assume that blocks that share s1,s2,s3 have the same sample list + #[cfg(debug_assertions)]{ + for s3blocks in s1s2_blocks.iter().map(|t|&t.1) { + debug_assert!(s3blocks.len()>0); + let len = descriptor.block_by_id(s3blocks[0]).samples().size(); + for lblock in s3blocks { + if lblock != &usize::MAX{ + debug_assert_eq!(descriptor.block_by_id(*lblock).samples().size(), len); + } + } + } + } + + // [system_i -> {(triplet_i,inverted)->(i_s3,sample_i)} ] + let mut sample_lut: Vec>> = vec![BTreeMap::new(); systems.len()]; + let mut s3_samples = vec![]; + + // also assume that the systems are in order in the samples + for (i_s3,s3blocks) in s1s2_blocks.iter().map(|t|&t.1).enumerate() { + let good_block_i = s3blocks.iter().filter(|b_i|**b_i!=usize::MAX).next().unwrap(); + + let samples = descriptor.block_by_id(*good_block_i).samples(); + #[cfg(debug_assertions)]{s3_samples.push(samples.clone());} + for (sample_i, &[system_i, triplet_i, _bond_i, atom_1,atom_2,_atom_3]) in samples.iter_fixed_size().enumerate(){ + let (system_i, triplet_i, atom_1, atom_2) = (system_i.usize(),triplet_i.usize(),atom_1.usize(),atom_2.usize()); + let triplet = systems[system_i].triplets()?[triplet_i]; + if atom_1 != atom_2 && atom_1 == triplet.bond.first{ + // simple case 1: we know the triplet is uninverted in the sample + sample_lut[system_i].entry((triplet_i, false)) + .or_insert_with(||vec![]).push((i_s3,sample_i)); + } else if atom_1 != atom_2 && atom_2 == triplet.bond.first { + // simple case 2: we know the triplet is inverted in the sample + sample_lut[system_i].entry((triplet_i, true)) + .or_insert_with(||vec![]).push((i_s3,sample_i)); + } else if atom_1 == atom_2 && atom_1 == triplet.bond.first { + // complex case: bond's atoms are images of each other: + // we probably already crashed because of duped samples. oh well. + unimplemented!("I'm surpised we haven't crashed earlier"); + } else {unreachable!();} + + } + } + for (system_i,system) in systems.iter_mut().enumerate() { + if system_i >= sample_lut.len() || sample_lut[system_i].len() == 0 { + continue // sometimes someone would specify extra samples which have no underlying data… welp. + } + let system = &mut **system; + system.compute_triplet_neighbors(self.parameters.bond_cutoff(), self.parameters.third_cutoff())?; + for (i_s3,sample_i,_triplet_i,contribution) in self.get_coefficients_for(system, *s1, *s2, &s3_list, do_gradients)? + .filter_map(|(triplet_i, inverted,contribution)| + sample_lut[system_i].get(&(triplet_i,inverted)) + .map(|lutvec|(lutvec,triplet_i,contribution)) + ).flat_map(|(lutvec,triplet_i,contribution)| + lutvec.into_iter().map(move |(i_s3,sample_i)|((i_s3,sample_i,triplet_i,contribution.clone()))) + ){ + + let ret_blocks = &s1s2_blocks[*i_s3].1; + let contribution = contribution.borrow(); + for (l,lslice) in l_slices.iter().enumerate() { + if ret_blocks[l] == usize::MAX{ + continue; // sometimes the key for that l,s1,s2,s3 combination was not provided + } + let mut block = descriptor.block_mut_by_id(ret_blocks[l]); + debug_assert_eq!(block.samples(), s3_samples[*i_s3]); + let n_subset = block.properties(); + let mut values = block.values_mut(); + for (i_n,&[n]) in n_subset.iter_fixed_size().enumerate() { + let mut value_slice = values.as_array_mut().slice_mut(s![*sample_i,..,i_n]); + value_slice.assign(&contribution.values.slice(s![lslice.clone(),n.usize()])); + } + } + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use metatensor::Labels; + use ndarray::{s, Axis}; + use approx::assert_ulps_eq; + + use crate::systems::test_utils::{test_system, test_systems}; + use crate::Calculator; + use crate::calculators::{CalculatorBase, SphericalExpansionForBonds}; + + use super::{SphericalExpansionForBondType, SphericalExpansionForBondsParameters}; + use super::super::{CutoffFunction, RadialScaling}; + use crate::calculators::radial_basis::RadialBasis; + + + fn parameters() -> SphericalExpansionForBondsParameters{ + SphericalExpansionForBondsParameters{ + cutoffs: [3.5,3.5], + max_radial: 6, + max_angular: 6, + atomic_gaussian_width: 0.3, + center_atoms_weight: 10.0, + radial_basis: RadialBasis::splined_gto(1e-8), + radial_scaling: RadialScaling::Willatt2018 { scale: 1.5, rate: 0.8, exponent: 2.0}, + cutoff_function: CutoffFunction::ShiftedCosine { width: 0.5 }, + } + } + + // #[test] + // fn finite_differences_positions() { + // let calculator = Calculator::from(Box::new(SphericalExpansionForBondType::new( + // parameters() + // ).unwrap()) as Box); + + // let system = test_system("water"); + // let options = crate::calculators::tests_utils::FinalDifferenceOptions { + // displacement: 1e-6, + // max_relative: 1e-5, + // epsilon: 1e-16, + // }; + // crate::calculators::tests_utils::finite_differences_positions(calculator, &system, options); + // } + + // #[test] + // fn finite_differences_cell() { + // let calculator = Calculator::from(Box::new(SphericalExpansionForBondType::new( + // parameters() + // ).unwrap()) as Box); + + // let system = test_system("water"); + // let options = crate::calculators::tests_utils::FinalDifferenceOptions { + // displacement: 1e-6, + // max_relative: 1e-5, + // epsilon: 1e-16, + // }; + // crate::calculators::tests_utils::finite_differences_cell(calculator, &system, options); + // } + + #[test] + fn compute_partial() { + let calculator = Calculator::from(Box::new(SphericalExpansionForBondType::new( + SphericalExpansionForBondsParameters { + max_angular: 2, + ..parameters() + } + ).unwrap()) as Box); + + let mut systems = test_systems(&["water"]); + + let properties = Labels::new(["n"], &[ + [0], + [3], + [2], + ]); + + let samples = Labels::new([ + "structure", + "triplet_i", + "bond_i", + "first_bond_atom", + "second_bond_atom", + "third_atom" + ], &[ + //[0, 1, 2], + //[0, 2, 1], + [0, 7, 2, 1, 2, 1], + [0, 7, 2, 2, 1, 1], + [0, 8, 2, 1, 2, 2], + [0, 8, 2, 2, 1, 2], + [0, 5, 1, 0, 2, 2], + + ]); + + let keys = Labels::new([ + "spherical_harmonics_l", + "species_bond_atom_1", + "species_bond_atom_2", + "species_third_atom", + ], &[ + [0, -42, 1, -42], + [0, -42, 1, 1], + [0, 1, -42, -42], + [0, 1, -42, 1], + [0, 1, 1, 1], + [0, 1, 1, -42], + [0, 6, 1, 1], // not part of the default keys + [1, -42, 1, -42], + [1, -42, 1, 1], + [1, 1, -42, -42], + [1, 1, -42, 1], + [1, 1, 1, 1], + [1, 1, 1, -42], + [2, -42, 1, -42], + [2, -42, 1, 1], + [2, 1, -42, -42], + [2, 1, -42, 1], + [2, 1, 1, 1], + [2, 1, 1, -42], + ]); + + crate::calculators::tests_utils::compute_partial( + calculator, &mut systems, &keys, &samples, &properties + ); + } + + #[test] + fn sums_to_spherical_expansion() { + let mut calculator_by_pair = Calculator::from(Box::new(SphericalExpansionForBondType::new( + parameters() + ).unwrap()) as Box); + let mut calculator = Calculator::from(Box::new(SphericalExpansionForBonds::new( + parameters() + ).unwrap()) as Box); + + let mut systems = test_systems(&["water", "methane"]); + + let expected = calculator.compute(&mut systems, Default::default()).unwrap(); + + + let by_pair = calculator_by_pair.compute(&mut systems, Default::default()).unwrap(); + + // check that keys are the same appart for the names + assert_eq!(expected.keys().count(), by_pair.keys().count());//, "wrong key count: {} vs {}", expected.keys().count(), by_pair.keys().count()); + assert_eq!( + expected.keys().iter().collect::>(), + by_pair.keys().iter().collect::>(), + ); + + for (_bl_i,(block, spx)) in by_pair.blocks().iter().zip(expected.blocks()).enumerate() { + let spx = spx.data(); + let spx_values = spx.values.as_array(); + + let block = block.data(); + let values = block.values.as_array(); + + for (&[spx_structure, spx_center1,spx_center2, spx_bond_i], expected) in spx.samples.iter_fixed_size().zip(spx_values.axis_iter(Axis(0))) { + let mut sum = ndarray::Array::zeros(expected.raw_dim()); + + for (sample_i, &[structure, _triplet_i, bond_i, center1, center2, _atom3]) in block.samples.iter_fixed_size().enumerate() { + if spx_structure == structure && spx_bond_i == bond_i && spx_center1 == center1 && spx_center2 == center2 { + sum += &values.slice(s![sample_i, .., ..]); + } + } + + assert_ulps_eq!(sum, expected); + } + } + } +} diff --git a/rascaline/src/calculators/soap/spherical_expansion_pair.rs b/rascaline/src/calculators/soap/spherical_expansion_pair.rs index 897ecd7cf..aadafe0ba 100644 --- a/rascaline/src/calculators/soap/spherical_expansion_pair.rs +++ b/rascaline/src/calculators/soap/spherical_expansion_pair.rs @@ -79,12 +79,12 @@ pub struct SphericalExpansionByPair { pub(crate) parameters: SphericalExpansionParameters, /// implementation + cached allocation to compute the radial integral for a /// single pair - radial_integral: ThreadLocal>, + pub(super) radial_integral: ThreadLocal>, /// implementation + cached allocation to compute the spherical harmonics /// for a single pair - spherical_harmonics: ThreadLocal>, + pub(super) spherical_harmonics: ThreadLocal>, /// Cache for (-1)^l values - m_1_pow_l: Vec, + pub(super) m_1_pow_l: Vec, } impl std::fmt::Debug for SphericalExpansionByPair { @@ -193,14 +193,14 @@ impl SphericalExpansionByPair { } /// Compute the product of radial scaling & cutoff smoothing functions - fn scaling_functions(&self, r: f64) -> f64 { + pub(super) fn scaling_functions(&self, r: f64) -> f64 { let cutoff = self.parameters.cutoff_function.compute(r, self.parameters.cutoff); let scaling = self.parameters.radial_scaling.compute(r); return cutoff * scaling; } /// Compute the gradient of the product of radial scaling & cutoff smoothing functions - fn scaling_functions_gradient(&self, r: f64) -> f64 { + pub(super) fn scaling_functions_gradient(&self, r: f64) -> f64 { let cutoff = self.parameters.cutoff_function.compute(r, self.parameters.cutoff); let cutoff_grad = self.parameters.cutoff_function.derivative(r, self.parameters.cutoff); diff --git a/rascaline/src/labels/keys.rs b/rascaline/src/labels/keys.rs index 49eff665f..cc6b5639d 100644 --- a/rascaline/src/labels/keys.rs +++ b/rascaline/src/labels/keys.rs @@ -95,6 +95,53 @@ impl KeysBuilder for CenterSingleNeighborsSpeciesKeys { } } +/// Compute a set of keys with three variables: the species of two central atoms within a given cutoff to each other, +/// and the species of a third, neighbor atom, within a cutoff of the first two. +pub struct TwoCentersSingleNeighborsSpeciesKeys { + /// Spherical cutoff to use when searching for neighbors around an atom + pub(crate) cutoffs: [f64;2], // bond_, third_cutoff + //pub bond_cutoff: f64, + //pub third_cutoff: f64, + /// Should we consider an atom to be it's own neighbor or not? + pub self_contributions: bool, +} + +impl TwoCentersSingleNeighborsSpeciesKeys{ + pub fn bond_cutoff(&self) -> f64 { + self.cutoffs[0] + } + pub fn third_cutoff(&self) -> f64 { + self.cutoffs[1] + } +} + + +impl KeysBuilder for TwoCentersSingleNeighborsSpeciesKeys { + fn keys(&self, systems: &mut [Box]) -> Result { + assert!(self.bond_cutoff() > 0.0 && self.bond_cutoff().is_finite() && self.third_cutoff() > 0.0 && self.third_cutoff().is_finite()); + + let mut all_species_triplets = BTreeSet::new(); + for system in systems { + system.compute_triplet_neighbors(self.bond_cutoff(), self.third_cutoff())?; + + let species = system.species()?; + for triplet in system.triplets()? { + if (!self.self_contributions) && triplet.is_self_contribution { + continue; + } + all_species_triplets.insert((species[triplet.bond.first], species[triplet.bond.second], species[triplet.third])); + all_species_triplets.insert((species[triplet.bond.second], species[triplet.bond.first], species[triplet.third])); + } + } + + let mut keys = LabelsBuilder::new(vec!["species_center_1", "species_center_2", "species_neighbor"]); + for (center1, center2, neighbor) in all_species_triplets { + keys.add(&[center1,center2, neighbor]); + } + + return Ok(keys.finish()); + } +} /// Compute a set of keys with three variables: the central atom species and two /// neighbor atom species. diff --git a/rascaline/src/labels/mod.rs b/rascaline/src/labels/mod.rs index 2adeef803..413df4640 100644 --- a/rascaline/src/labels/mod.rs +++ b/rascaline/src/labels/mod.rs @@ -1,11 +1,11 @@ mod samples; pub use self::samples::{SpeciesFilter, SamplesBuilder}; -pub use self::samples::AtomCenteredSamples; +pub use self::samples::{AtomCenteredSamples,BondCenteredSamples}; pub use self::samples::LongRangeSamplesPerAtom; mod keys; pub use self::keys::KeysBuilder; pub use self::keys::CenterSpeciesKeys; -pub use self::keys::{CenterSingleNeighborsSpeciesKeys, AllSpeciesPairsKeys}; +pub use self::keys::{CenterSingleNeighborsSpeciesKeys, TwoCentersSingleNeighborsSpeciesKeys, AllSpeciesPairsKeys}; pub use self::keys::{CenterTwoNeighborsSpeciesKeys}; diff --git a/rascaline/src/labels/samples/bond_centered.rs b/rascaline/src/labels/samples/bond_centered.rs new file mode 100644 index 000000000..41619be62 --- /dev/null +++ b/rascaline/src/labels/samples/bond_centered.rs @@ -0,0 +1,377 @@ +use std::collections::{BTreeSet,BTreeMap}; + +use metatensor::{Labels, LabelsBuilder}; + +use crate::{Error, System}; +use super::{SamplesBuilder, SpeciesFilter}; + + +/// `SampleBuilder` for bond-centered representations. This will create one +/// sample for each pair of atoms (within a spherical cutoff to each other), +/// optionally filtering on the bond's atom species. The samples names are +/// (structure", "first_center", "second_center", "bond_i"). +/// +/// Positions gradient samples include all atoms within a spherical cutoff to the bond center, +/// optionally filtering on the neighbor atom species. +pub struct BondCenteredSamples { + /// spherical cutoff radius used to construct the atom-centered environments + pub cutoffs: [f64;2], // bond_, third_cutoff + //pub bond_cutoff: f64, + //pub third_cutoff: f64, + /// Filter for the central atom species + pub species_center_1: SpeciesFilter, + pub species_center_2: SpeciesFilter, + /// Filter for the neighbor atom species + pub species_neighbor: SpeciesFilter, + /// Should the central atom be considered it's own neighbor? + pub self_contributions: bool, +} + +impl BondCenteredSamples{ + pub fn bond_cutoff(&self) -> f64 { + self.cutoffs[0] + } + pub fn third_cutoff(&self) -> f64 { + self.cutoffs[1] + } +} + +impl SamplesBuilder for BondCenteredSamples { + fn samples_names() -> Vec<&'static str> { + // bond_i is needed in case we have several bonds with the same atoms (periodic boundaries) + vec!["structure", "first_center", "second_center", "bond_i"] + } + + fn samples(&self, systems: &mut [Box]) -> Result { + assert!( + self.bond_cutoff() > 0.0 && self.bond_cutoff().is_finite() && self.third_cutoff() > 0.0 && self.third_cutoff().is_finite(), + "cutoffs must be positive for BondCenteredSamples" + ); + let mut builder = LabelsBuilder::new(Self::samples_names()); + for (system_i, system) in systems.iter_mut().enumerate() { + system.compute_triplet_neighbors(self.bond_cutoff(), self.third_cutoff())?; + let species = system.species()?; + let all_triplets = system.triplets()?; + + let mut center_cache: BTreeMap<(usize,usize,usize), BTreeSet> = BTreeMap::new(); + + match (&self.species_center_1, &self.species_center_2) { + (SpeciesFilter::Any, SpeciesFilter::Any) => { + for triplet in system.triplets()? { + if self.self_contributions || (!triplet.is_self_contribution){ + center_cache.entry((triplet.bond.first, triplet.bond.second, triplet.bond_i)) + .or_insert_with(BTreeSet::new) + .insert(species[triplet.third]); + } + } + } + (SpeciesFilter::AllOf(_),_)|(_,SpeciesFilter::AllOf(_)) => + panic!("Cannot use Species::AllOf on BondCenteredSamples.center_species"), + (SpeciesFilter::Single(s1), SpeciesFilter::Single(s2)) => { + let species_set = BTreeSet::from_iter(species.iter()); + for s3 in species_set { + for triplet_i in system.triplets_with_species(*s1, *s2, *s3)? { + let triplet = &all_triplets[*triplet_i]; + if !self.self_contributions && triplet.is_self_contribution { + continue; + } + // TODO full/half logic + center_cache.entry((triplet.bond.first, triplet.bond.second, triplet.bond_i)) + .or_insert_with(BTreeSet::new) + .insert(species[triplet.third]); + } + } + + }, + (selection_1, selection_2) => { + for (center_i, &species_center) in species.iter().enumerate() { + if !selection_1.matches(species_center) { + continue; + } + for (center_j, &species_center_2) in species.iter().enumerate() { + if !selection_2.matches(species_center_2) { + continue; + } + for triplet in system.triplets_containing(center_i, center_j)? { + if !self.self_contributions && triplet.is_self_contribution { + continue; + } + // TODO full/half logic + center_cache.entry((triplet.bond.first, triplet.bond.second, triplet.bond_i)) + .or_insert_with(BTreeSet::new) + .insert(species[triplet.third]); + } + } + } + } + } + match &self.species_neighbor { + SpeciesFilter::Any => { + for (center_1,center_2,bond_i) in center_cache.keys() { + builder.add(&[system_i,*center_1,*center_2,*bond_i]); + } + }, + SpeciesFilter::AllOf(requirements) => { + for ((center_1,center_2,bond_i), neigh_set) in center_cache.iter() { + if requirements.is_subset(neigh_set) { + builder.add(&[system_i,*center_1,*center_2,*bond_i]); + } + } + }, + SpeciesFilter::Single(requirement) => { + for ((center_1,center_2,bond_i), neigh_set) in center_cache.iter() { + if neigh_set.contains(requirement) { + builder.add(&[system_i,*center_1,*center_2,*bond_i]); + } + } + }, + SpeciesFilter::OneOf(requirements) => { + let requirements: BTreeSet = BTreeSet::from_iter(requirements.iter().map(|x|*x)); + for ((center_1,center_2,bond_i), neigh_set) in center_cache.iter() { + if neigh_set.intersection(&requirements).count()>0 { + builder.add(&[system_i,*center_1,*center_2,*bond_i]); + } + } + }, + } + } + + return Ok(builder.finish()); + } + + fn gradients_for(&self, systems: &mut [Box], samples: &Labels) -> Result { + assert!( + self.bond_cutoff() > 0.0 && self.bond_cutoff().is_finite() && self.third_cutoff() > 0.0 && self.third_cutoff().is_finite(), + "cutoffs must be positive for BondCenteredSamples" + ); + assert_eq!(samples.names(), ["structure", "first_center", "second_center", "bond_i"]); + let mut builder = LabelsBuilder::new(vec!["sample", "structure", "atom"]); + + // we could try to find a better way to estimate this, but in the worst + // case this would only over-allocate a bit + let average_neighbors_per_atom = 10; + builder.reserve(average_neighbors_per_atom * samples.count()); + + for (sample_i, [structure_i, center_1, center_2, bond_i]) in samples.iter_fixed_size().enumerate() { + let structure_i = structure_i.usize(); + let center_1 = center_1.usize(); + let center_2 = center_2.usize(); + let bond_i = bond_i.usize(); + + let system = &mut systems[structure_i]; + system.compute_triplet_neighbors(self.bond_cutoff(), self.third_cutoff())?; + let species = system.species()?; + + let mut grad_contributors = BTreeSet::new(); + grad_contributors.insert(center_1); + grad_contributors.insert(center_2); + + for triplet in system.triplets_containing(center_1, center_2)? { + if triplet.bond_i != bond_i { + continue; + } + match &self.species_neighbor{ + SpeciesFilter::Any | SpeciesFilter::AllOf(_) => { + // in both of those cases, the sample already has been validated, and all known neighbors contribute + grad_contributors.insert(triplet.third); + }, + neighbor_filter => { + if neighbor_filter.matches(species[triplet.third]) { + grad_contributors.insert(triplet.third); + } + }, + } + } + + for contrib in grad_contributors{ + builder.add(&[sample_i, structure_i, contrib]); + } + } + + return Ok(builder.finish()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::systems::test_utils::test_systems; + + #[test] + fn all_samples() { + let mut systems = test_systems(&["CH", "water"]); + let builder = BondCenteredSamples { + cutoffs: [2.0,2.0], + species_center_1: SpeciesFilter::Any, + species_center_2: SpeciesFilter::Any, + species_neighbor: SpeciesFilter::Any, + self_contributions: true, + }; + + let samples = builder.samples(&mut systems).unwrap(); + assert_eq!(samples, Labels::new( + ["structure", "first_center", "second_center", "bond_i"], + &[[0, 0, 1, 0], [1, 0, 1, 0], [1, 0, 2, 1], [1, 1, 2, 2]], + )); + + let gradient_samples = builder.gradients_for(&mut systems, &samples).unwrap(); + assert_eq!(gradient_samples, Labels::new( + ["sample", "structure", "atom"], + &[ + // gradients of atoms in CH + [0, 0, 0], [0, 0, 1], + // gradients of atoms in water + [1, 1, 0], [1, 1, 1], [1, 1, 2], + [2, 1, 0], [2, 1, 1], [2, 1, 2], + [3, 1, 0], [3, 1, 1], [3, 1, 2], + ], + )); + } + + #[test] + fn filter_species_center() { + let mut systems = test_systems(&["CH", "water"]); + let builder = BondCenteredSamples { + cutoffs: [2.0,2.0], + species_center_1: SpeciesFilter::Single(6), + species_center_2: SpeciesFilter::Single(1), + species_neighbor: SpeciesFilter::Any, + self_contributions: true, + }; + + let samples = builder.samples(&mut systems).unwrap(); + assert_eq!(samples, Labels::new( + ["structure", "first_center", "second_center", "bond_i"], + &[[0, 0, 1, 0]], + )); + + let gradient_samples = builder.gradients_for(&mut systems, &samples).unwrap(); + assert_eq!(gradient_samples, Labels::new( + ["sample", "structure", "atom"], + &[ + // gradients of atoms in CH + [0, 0, 0], [0, 0, 1], + ] + )); + + let builder = BondCenteredSamples { + cutoffs: [2.0,2.0], + species_center_1: SpeciesFilter::Single(1), + species_center_2: SpeciesFilter::Single(1), + species_neighbor: SpeciesFilter::Any, + self_contributions: true, + }; + + let samples = builder.samples(&mut systems).unwrap(); + assert_eq!(samples, Labels::new( + ["structure", "first_center", "second_center", "bond_i"], + &[[1, 1, 2, 2]], + )); + + let gradient_samples = builder.gradients_for(&mut systems, &samples).unwrap(); + assert_eq!(gradient_samples, Labels::new( + ["sample", "structure", "atom"], + &[ + // gradients of atoms in H2O + [0, 1, 0], [0, 1, 1], [0, 1, 2], + ] + )); + } + + #[test] + fn filter_species_neighbor() { + let mut systems = test_systems(&["CH", "water"]); + let builder = BondCenteredSamples { + cutoffs: [2.0,2.0], + species_center_1: SpeciesFilter::Any, + species_center_2: SpeciesFilter::Any, + species_neighbor: SpeciesFilter::Single(1), + self_contributions: true, + }; + + let samples = builder.samples(&mut systems).unwrap(); + assert_eq!(samples, Labels::new( + ["structure", "first_center", "second_center", "bond_i"], + &[[0, 0, 1, 0], [1, 0, 1, 0], [1, 0, 2, 1], [1, 1, 2, 2]], + )); + + let gradient_samples = builder.gradients_for(&mut systems, &samples).unwrap(); + assert_eq!(gradient_samples, Labels::new( + ["sample", "structure", "atom"], + &[ + // gradients of atoms in CH w.r.t H atom only + [0, 0, 0], [0, 0, 1], + // gradients of atoms in water w.r.t H atoms only + [1, 1, 0], [1, 1, 1], [1, 1, 2], + [2, 1, 0], [2, 1, 1], [2, 1, 2], + [3, 1, 1], [3, 1, 2], + ] + )); + + let builder = BondCenteredSamples { + cutoffs: [2.0,2.0], + species_center_1: SpeciesFilter::Any, + species_center_2: SpeciesFilter::Any, + species_neighbor: SpeciesFilter::OneOf(vec![1, 6]), + self_contributions: true, + }; + + let gradient_samples = builder.gradients_for(&mut systems, &samples).unwrap(); + assert_eq!(gradient_samples, Labels::new( + ["sample", "structure", "atom"], + &[ + // gradients of atoms in CH w.r.t C and H atoms + [0, 0, 0], [0, 0, 1], + // gradients of atoms in water w.r.t H atoms only + [1, 1, 0], [1, 1, 1], [1, 1, 2], + [2, 1, 0], [2, 1, 1], [2, 1, 2], + [3, 1, 1], [3, 1, 2], + ] + )); + } + + #[test] + fn partial_gradients() { + let samples = Labels::new(["structure", "first_center", "second_center", "bond_i"], &[ + [1, 0, 1, 0], + [0, 0, 1, 0], + [1, 1, 2, 2], + ]); + + let mut systems = test_systems(&["CH", "water"]); + let builder = BondCenteredSamples { + cutoffs: [2.0,2.0], + species_center_1: SpeciesFilter::Any, + species_center_2: SpeciesFilter::Any, + species_neighbor: SpeciesFilter::Single(-42), + self_contributions: true, + }; + + let gradients = builder.gradients_for(&mut systems, &samples).unwrap(); + assert_eq!(gradients, Labels::new(["sample", "structure", "atom"], &[ + [0, 1, 0], [0, 1, 1], + [1, 0, 0], [1, 0, 1], + [2, 1, 0], [2, 1, 1], [2, 1, 2], + ])); + + let builder = BondCenteredSamples { + cutoffs: [2.0,2.0], + species_center_1: SpeciesFilter::Any, + species_center_2: SpeciesFilter::Any, + species_neighbor: SpeciesFilter::Single(1), + self_contributions: true, + }; + let gradients = builder.gradients_for(&mut systems, &samples).unwrap(); + assert_eq!(gradients, Labels::new( + ["sample", "structure", "atom"], + &[ + // gradients of first sample, O-H1 in water + [0, 1, 0], [0, 1, 1], [0, 1, 2], + // gradients of second sample, C-H in CH + [1, 0, 0], [1, 0, 1], + // gradients of third sample, H1-H2 in water + [2, 1, 1], [2, 1, 2], + ] + )); + } +} \ No newline at end of file diff --git a/rascaline/src/labels/samples/mod.rs b/rascaline/src/labels/samples/mod.rs index fe7ada49f..7faf0dd3b 100644 --- a/rascaline/src/labels/samples/mod.rs +++ b/rascaline/src/labels/samples/mod.rs @@ -54,6 +54,8 @@ pub trait SamplesBuilder { mod atom_centered; pub use self::atom_centered::AtomCenteredSamples; +mod bond_centered; +pub use self::bond_centered::BondCenteredSamples; mod long_range; pub use self::long_range::LongRangeSamplesPerAtom;