From 53c459190aa0110e3d1e4baadabbd82e16dab7ed Mon Sep 17 00:00:00 2001 From: Liam Marsh Date: Tue, 19 Nov 2024 17:12:39 +0100 Subject: [PATCH] Add bond-atom spherical expansion, feature-gated. [restructured contribution, commit 3] --- python/rascaline/rascaline/__init__.py | 2 + python/rascaline/rascaline/calculators.py | 51 ++ rascaline/src/calculator.rs | 2 + .../calculators/bondatom/bond_atom_math.rs | 715 +++++++++++++++++ rascaline/src/calculators/bondatom/mod.rs | 39 + .../spherical_expansion_bondcentered.rs | 717 ++++++++++++++++++ rascaline/src/calculators/mod.rs | 3 + rascaline/src/labels/keys.rs | 48 +- rascaline/src/labels/mod.rs | 4 +- rascaline/src/labels/samples/bond_centered.rs | 413 ++++++++++ rascaline/src/labels/samples/mod.rs | 2 + 11 files changed, 1993 insertions(+), 3 deletions(-) create mode 100644 rascaline/src/calculators/bondatom/bond_atom_math.rs create mode 100644 rascaline/src/calculators/bondatom/mod.rs create mode 100644 rascaline/src/calculators/bondatom/spherical_expansion_bondcentered.rs create mode 100644 rascaline/src/labels/samples/bond_centered.rs diff --git a/python/rascaline/rascaline/__init__.py b/python/rascaline/rascaline/__init__.py index c11deb7a8..615a0bbda 100644 --- a/python/rascaline/rascaline/__init__.py +++ b/python/rascaline/rascaline/__init__.py @@ -12,6 +12,7 @@ SortedDistances, SphericalExpansion, SphericalExpansionByPair, + SphericalExpansionForBonds, ) from .log import set_logging_callback # noqa from .profiling import Profiler # noqa @@ -29,4 +30,5 @@ "SortedDistances", "SphericalExpansion", "SphericalExpansionByPair", + "SphericalExpansionForBonds", ] diff --git a/python/rascaline/rascaline/calculators.py b/python/rascaline/rascaline/calculators.py index 6ecd94959..5a7b8a490 100644 --- a/python/rascaline/rascaline/calculators.py +++ b/python/rascaline/rascaline/calculators.py @@ -201,6 +201,57 @@ def __init__( super().__init__("spherical_expansion_by_pair", json.dumps(parameters)) +class 
SphericalExpansionForBonds(CalculatorBase): + """A SOAP-like spherical expansion coefficients for bond-centered environments + In other words, the spherical expansion of the neighbor density function centered + on the center of a bond, + 'after' rotating the system so that the bond is aligned with the z axis. + + This is not rotationally invariant, and as such you should use some + not-implemented-here matheatical trick + similar to what SOAP (the :py:class:`SoapPowerSpectrum` class) uses. + + Most hyperparameters are identical to that of the regulat spherical expansion: + :ref:`documentation `. + + the few changes to this are: + + - "cutoff" renamed to "third_cutoff" + - "bond_cutoff" which expresses how the pairs of atoms used for the 'bonds' are + chosen. + - "center_atomS_weight" (caps only used for emphasis): the weight multiplier + for the coefficients of the self interactions + (where the neighboring atom is one of the pair's atoms). + """ + + def __init__( + self, + bond_cutoff, + third_cutoff, + max_radial, + max_angular, + atomic_gaussian_width, + center_atoms_weight, + radial_basis, + cutoff_function, + radial_scaling=None, + ): + parameters = { + "cutoffs": [bond_cutoff, third_cutoff], + "max_radial": max_radial, + "max_angular": max_angular, + "atomic_gaussian_width": atomic_gaussian_width, + "center_atoms_weight": center_atoms_weight, + "radial_basis": radial_basis, + "cutoff_function": cutoff_function, + } + + if radial_scaling is not None: + parameters["radial_scaling"] = radial_scaling + + super().__init__("spherical_expansion_for_bonds", json.dumps(parameters)) + + class SoapRadialSpectrum(CalculatorBase): """Radial spectrum of Smooth Overlap of Atomic Positions (SOAP). 
diff --git a/rascaline/src/calculator.rs b/rascaline/src/calculator.rs index 077a4b561..cf81298a1 100644 --- a/rascaline/src/calculator.rs +++ b/rascaline/src/calculator.rs @@ -549,6 +549,7 @@ use crate::calculators::{SphericalExpansionByPair, SphericalExpansionParameters} use crate::calculators::SphericalExpansion; use crate::calculators::{SoapPowerSpectrum, PowerSpectrumParameters}; use crate::calculators::{SoapRadialSpectrum, RadialSpectrumParameters}; +use crate::calculators::{SphericalExpansionForBonds, SphericalExpansionForBondsParameters}; use crate::calculators::{LodeSphericalExpansion, LodeSphericalExpansionParameters}; type CalculatorCreator = fn(&str) -> Result, Error>; @@ -581,6 +582,7 @@ static REGISTERED_CALCULATORS: Lazy> = add_calculator!(map, "spherical_expansion", SphericalExpansion, SphericalExpansionParameters); add_calculator!(map, "soap_radial_spectrum", SoapRadialSpectrum, RadialSpectrumParameters); add_calculator!(map, "soap_power_spectrum", SoapPowerSpectrum, PowerSpectrumParameters); + add_calculator!(map, "spherical_expansion_for_bonds", SphericalExpansionForBonds, SphericalExpansionForBondsParameters); add_calculator!(map, "lode_spherical_expansion", LodeSphericalExpansion, LodeSphericalExpansionParameters); return map; diff --git a/rascaline/src/calculators/bondatom/bond_atom_math.rs b/rascaline/src/calculators/bondatom/bond_atom_math.rs new file mode 100644 index 000000000..6358b769c --- /dev/null +++ b/rascaline/src/calculators/bondatom/bond_atom_math.rs @@ -0,0 +1,715 @@ +use std::collections::BTreeMap; +use std::collections::btree_map::Entry; +use std::cell::RefCell; +use thread_local::ThreadLocal; +use log::warn; + +use metatensor::TensorBlockRefMut; + +use crate::Error; +use crate::types::{Vector3D,Matrix3}; +use crate::systems::BATripletInfo; + +use crate::calculators::soap::{CutoffFunction, RadialScaling}; +use crate::calculators::radial_basis::RadialBasis; +use crate::calculators::soap::{ + SoapRadialIntegralCache, + 
SoapRadialIntegralParameters, +}; +use crate::math::SphericalHarmonicsCache; + + +/// for a given vector (`vec`), compute a rotation matrix (`M`) so that `M×vec` +/// is expressed as `(0,0,+z)` +/// currently, this matrix corresponds to a rotatoin expressed as `-z;+y;+z` in euler angles, +/// or as `(x,y,0),theta` in axis-angle representation. +fn rotate_vector_to_z(vec: Vector3D) -> Matrix3 { + // re-orientation is done through a rotation matrix, computed through the axis-angle and quaternion representations of the rotation + // axis/angle representation of the rotation: axis is norm(-y,x,0), angle is arctan2( sqrt(x**2+y**2), z) + // meaning sin(angle) = sqrt((x**2+y**2) /r2); cos(angle) = z/sqrt(r2) + + let (xylen,len) = { + let xyl = vec[0]*vec[0] + vec[1]*vec[1]; + (xyl.sqrt(), (xyl+vec[2]*vec[2]).sqrt()) + }; + + if xylen.abs()<1E-7 { + if vec[2] < 0. { + return Matrix3::new([[-1.,0.,0.], [0.,1.,0.], [0.,0.,-1.]]) + } + else { + return Matrix3::new([[1.,0.,0.], [0.,1.,0.], [0.,0.,1.]]) + } + } + + let c = vec[2]/len; + let s = xylen/len; + let t = 1. - c; + + let x2 = -vec[1]/xylen; + let y2 = vec[0]/xylen; + + let tx = t*x2; + let sx = s*x2; + let sy = s*y2; + + return Matrix3::new([ + [tx*x2 +c, tx*y2, -sy], + [tx*y2, t*y2*y2 + c, sx], + [sy, -sx, c], + ]); +} + + +/// returns the derivatives of the reoriention matrix with the three components of the vector to reorient +fn rotate_vector_to_z_derivatives(vec: Vector3D) -> (Matrix3,Matrix3,Matrix3) { + + let (xylen,len) = { + let xyl = vec[0]*vec[0] + vec[1]*vec[1]; + (xyl.sqrt(), (xyl+vec[2]*vec[2]).sqrt()) + }; + + if xylen.abs()<1E-7 { + let co = 1./len; + if vec[2] < 0. 
{ + warn!("trying to get the derivative of a rotation near a breaking point: expect pure jank"); + return ( + //Matrix3::new([[-1.,0.,0.], [0.,1.,0.], [0.,0.,-1.]]) <- the value to derive off of: a +y rotation + Matrix3::new([[0.,0.,-co], [0.,0.,0.], [co,0.,0.]]), // +x change -> +y rotation + Matrix3::new([[0.,0.,0.], [0.,0.,-co], [0.,-co,0.]]), // +y change -> -x rotation + Matrix3::new([[0.,0.,0.], [0.,0.,0.], [0.,0.,0.]]), // +z change -> nuthin + ) + } + else { + return ( + //Matrix3::new([[1.,0.,0.], [0.,1.,0.], [0.,0.,1.]]) <- the value to derive off of + Matrix3::new([[0.,0.,-co], [0.,0.,0.], [co,0.,0.]]), // +x change -> -y rotation + Matrix3::new([[0.,0.,0.], [0.,0.,-co], [0.,co,0.]]), // +y change -> +x rotation + Matrix3::new([[0.,0.,0.], [0.,0.,0.], [0.,0.,0.]]), // +z change -> nuthin + ) + } + } + + let inv_len = 1./len; + let inv_len2 = inv_len*inv_len; + let inv_len3 = inv_len2*inv_len; + let inv_xy = 1./xylen; + let inv_xy2 = inv_xy*inv_xy; + let inv_xy3 = inv_xy2*inv_xy; + + let c = vec[2]/len; // needed + let dcdz = 1./len - vec[2]*vec[2]*inv_len3; + let dcdx = -vec[2]*vec[0]*inv_len3; + let dcdy = -vec[2]*vec[1]*inv_len3; + let s = xylen/len; + let dsdx = vec[0]*inv_len*(inv_xy - xylen*inv_len2); + let dsdy = vec[1]*inv_len*(inv_xy - xylen*inv_len2); + let dsdz = -xylen*vec[2]*inv_len3; + + let t = 1. - c; + + let x2 = -vec[1]*inv_xy; + let dx2dx = vec[1]*vec[0]*inv_xy3; + let dx2dy = inv_xy * (-1. + vec[1]*vec[1]*inv_xy2); + + let y2 = vec[0]/xylen; + let dy2dy = -vec[1]*vec[0]*inv_xy3; + let dy2dx = inv_xy * (1. 
- vec[0]*vec[0]*inv_xy2); + + let tx = t*x2; + let dtxdx = -dcdx*x2 + t*dx2dx; + let dtxdy = -dcdy*x2 + t*dx2dy; + let dtxdz = -dcdz*x2; + + //let sx = s*x2; // needed + let dsxdx = dsdx*x2 + s*dx2dx; + let dsxdy = dsdy*x2 + s*dx2dy; + let dsxdz = dsdz*x2; + + //let sy = s*y2; //needed + let dsydx = dsdx*y2 + s*dy2dx; + let dsydy = dsdy*y2 + s*dy2dy; + let dsydz = dsdz*y2; + + //let t1 = tx*x2 +c; // needed + let dt1dx = dcdx + dtxdx*x2 + tx*dx2dx; + let dt1dy = dcdy + dtxdy*x2 + tx*dx2dy; + let dt1dz = dcdz + dtxdz*x2; + + //let t2 = tx*y2; // needed + let dt2dx = dtxdx*y2 + tx*dy2dx; + let dt2dy = dtxdy*y2 + tx*dy2dy; + let dt2dz = dtxdz*y2; + + //let t3 = t*y2*y2 +c; // needed + let dt3dx = -dcdx*y2*y2 + 2.*t*y2*dy2dx +dcdx; + let dt3dy = -dcdy*y2*y2 + 2.*t*y2*dy2dy +dcdy; + let dt3dz = -dcdz*y2*y2 +dcdz; + + return ( + // Matrix3::new([ + // [tx*x2 +c, tx*y2, -sy], + // [tx*y2, t*y2*y2 + c, sx], + // [sy, -sx, c], + // ]), + Matrix3::new([ + [dt1dx, dt2dx, -dsydx], + [dt2dx, dt3dx, dsxdx], + [dsydx, -dsxdx, dcdx], + ]), + Matrix3::new([ + [dt1dy, dt2dy, -dsydy], + [dt2dy, dt3dy, dsxdy], + [dsydy, -dsxdy, dcdy], + ]), + Matrix3::new([ + [dt1dz, dt2dz, -dsydz], + [dt2dz, dt3dz, dsxdz], + [dsydz, -dsxdz, dcdz], + ]), + ); +} + +/// result structure for canonical_vector_for_single_triplet +#[derive(Default,Debug)] +pub(crate) struct VectorResult{ + /// the canonical vector itelf + pub vect: Vector3D, + /// gradients of the canonical vector, as an array of three matrices + /// matrix is [quantity_component,gradient_component] + /// each matrix corresponds to a different atom to gradiate upon + pub grads: [Option<(usize,Matrix3)>;3], +} + +/// From a list of bond/atom triplets, compute the 'canonical third vector'. +/// Each triplet is composed of two 'neighbors': one is the center of a pair of atoms +/// (first and second atom), and the other is a simple atom (third atom). 
+/// The third vector of such a triplet is the vector from the center of the atom pair to the third atom. +/// This third vector becomes canonical when the frame of reference is rotated to express +/// the triplet's bond vector (vector from the first to the second atom) as (0,0,+z). +/// +/// Users can request either a "full" neighbor list (including an entry for both +/// `i-j +k` triplets and `j-i +k` triplets) or save memory/computation by only +/// working with a "half" neighbor list (only including one entry for each `i-j +k` +/// bond). +/// When using a half neighbor list, i and j are ordered so the atom with the smallest species comes first. +/// +/// The two first atoms must not be the same atom, but the third atom may be one of them, +/// if the `bond_contribution` option is active. +/// (When periodic boundaries arise, atoms which must not be the same may be images of each other.) +/// +/// This sample produces a single property (`"distance"`) with three components +/// (`"vector_direction"`) containing the x, y, and z components of the vector from +/// the center of the triplet's 'bond' to the triplet's 'third atom', in the bond's canonical orientation. +/// +/// In addition to the atom indexes, the samples also contain a pair and triplet index, +/// to be able to distinguish between multiple triplets involving the same atoms +/// (which can occur in periodic boundary conditions when the cutoffs are larger than the unit cell). 
+pub(crate) fn canonical_vector_for_single_triplet( + triplet: &BATripletInfo, + invert: bool, + compute_grad: bool, + mtx_cache: &mut BTreeMap<(usize,usize,[i32;3],bool),Matrix3>, + dmtx_cache: &mut BTreeMap<(usize,usize,[i32;3],bool),(Matrix3,Matrix3,Matrix3)>, +) -> Result { + + let bond_vector = triplet.bond_vector; + let third_vector = triplet.third_vector; + let (atom_i,atom_j,bond_vector) = if invert { + (triplet.atom_j, triplet.atom_i, -bond_vector) + } else { + (triplet.atom_i, triplet.atom_j, bond_vector) + }; + + let mut res = VectorResult::default(); + + if triplet.is_self_contrib { + let vec_len = third_vector.norm(); + let vec_len = if third_vector * bond_vector > 0. { + // third atom on second atom + vec_len + } else { + // third atom on first atom + -vec_len + }; + res.vect[2] = vec_len; + + if compute_grad { + let inv_len = 1./vec_len; + + res.grads[0] = Some((atom_i,Matrix3::new([ + [ -0.25* inv_len * third_vector[0], 0., 0.], + [ 0., -0.25* inv_len * third_vector[0], 0.], + [ 0., 0., -0.25* inv_len * third_vector[0]], + ]))); + res.grads[1] = Some((atom_j,Matrix3::new([ + [ 0.25* inv_len * third_vector[0], 0., 0.], + [ 0., 0.25* inv_len * third_vector[0], 0.], + [ 0., 0., 0.25* inv_len * third_vector[0]], + ]))); + + } + } else { + + let tf_mtx = match mtx_cache.entry((triplet.atom_i,triplet.atom_j,triplet.bond_cell_shift,invert)) { + Entry::Occupied(entry) => entry.get().clone(), + Entry::Vacant(entry) => { + entry.insert(rotate_vector_to_z(bond_vector)).clone() + }, + }; + res.vect = tf_mtx * third_vector; + + if compute_grad { + + // for a transformed vector v from an untransformed vector u, + // dv = TF*du + dTF*u + // also: the indexing of the gradient array is: i_gradsample, derivation_component, value_component, i_property + + let du_term = -0.5* tf_mtx; + let (tf_mtx_dx, tf_mtx_dy, tf_mtx_dz) = match dmtx_cache.entry((triplet.atom_i,triplet.atom_j,triplet.bond_cell_shift,invert)) { + Entry::Occupied(entry) => entry.get().clone(), + 
Entry::Vacant(entry) => { + entry.insert(rotate_vector_to_z_derivatives(bond_vector)).clone() + }, + }; + + let dmat_term_dx = tf_mtx_dx * third_vector; + let dmat_term_dy = tf_mtx_dy * third_vector; + let dmat_term_dz = tf_mtx_dz * third_vector; + + res.grads[0] = Some((atom_i,Matrix3::new([ + [-dmat_term_dx[0] + du_term[0][0], -dmat_term_dy[0] + du_term[0][1], -dmat_term_dz[0] + du_term[0][2]], + [-dmat_term_dx[1] + du_term[1][0], -dmat_term_dy[1] + du_term[1][1], -dmat_term_dz[1] + du_term[1][2]], + [-dmat_term_dx[2] + du_term[2][0], -dmat_term_dy[2] + du_term[2][1], -dmat_term_dz[2] + du_term[2][2]], + ]))); + res.grads[1] = Some((atom_j,Matrix3::new([ + [dmat_term_dx[0] + du_term[0][0], dmat_term_dy[0] + du_term[0][1], dmat_term_dz[0] + du_term[0][2]], + [dmat_term_dx[1] + du_term[1][0], dmat_term_dy[1] + du_term[1][1], dmat_term_dz[1] + du_term[1][2]], + [dmat_term_dx[2] + du_term[2][0], dmat_term_dy[2] + du_term[2][1], dmat_term_dz[2] + du_term[2][2]], + ]))); + res.grads[2] = Some((triplet.atom_k,tf_mtx)); + } + } + return Ok(res); +} + +/// get the result of canonical_vector_for_single_triplet +/// and store it in a TensorBlock +pub(crate) fn canonical_vector_for_single_triplet_inplace( + triplet: &BATripletInfo, + out_block: &mut TensorBlockRefMut, + sample_i: usize, + system_i: usize, + invert: bool, + mtx_cache: &mut BTreeMap<(usize,usize,[i32;3],bool),Matrix3>, + dmtx_cache: &mut BTreeMap<(usize,usize,[i32;3],bool),(Matrix3,Matrix3,Matrix3)>, +) -> Result<(),Error> { + let compute_grad = out_block.gradient_mut("positions").is_some(); + let block_data = out_block.data_mut(); + let array = block_data.values.to_array_mut(); + + let res = canonical_vector_for_single_triplet( + triplet, + invert, + compute_grad, + mtx_cache, + dmtx_cache + )?; + + array[[sample_i, 0, 0]] = res.vect[0]; + array[[sample_i, 1, 0]] = res.vect[1]; + array[[sample_i, 2, 0]] = res.vect[2]; + + if let Some(mut gradient) = out_block.gradient_mut("positions") { + let gradient = 
gradient.data_mut(); + let array = gradient.values.to_array_mut(); + + for grad in res.grads { + if let Some((atom_i, grad_mtx)) = grad { + let grad_sample_i = gradient.samples.position(&[ + sample_i.into(), system_i.into(), atom_i.into() + ]).expect("missing gradient sample"); + + array[[grad_sample_i, 0, 0, 0]] = grad_mtx[0][0]; + array[[grad_sample_i, 1, 0, 0]] = grad_mtx[0][1]; + array[[grad_sample_i, 2, 0, 0]] = grad_mtx[0][2]; + array[[grad_sample_i, 0, 1, 0]] = grad_mtx[1][0]; + array[[grad_sample_i, 1, 1, 0]] = grad_mtx[1][1]; + array[[grad_sample_i, 2, 1, 0]] = grad_mtx[1][2]; + array[[grad_sample_i, 0, 2, 0]] = grad_mtx[2][0]; + array[[grad_sample_i, 1, 2, 0]] = grad_mtx[2][1]; + array[[grad_sample_i, 2, 2, 0]] = grad_mtx[2][2]; + } + } + } + Ok(()) +} + + +/// Contribution of a single triplet to the spherical expansion +pub(super) struct ExpansionContribution { + /// Values of the contribution. The shape is (lm, n), where the lm index + /// runs over both l and m + pub values: ndarray::Array2, + /// Gradients of the contribution w.r.t. the canonical vector of the triplet. + /// The shape is (x/y/z, lm, n). 
+ pub gradients: Option>, +} + +impl ExpansionContribution { + pub fn new(max_radial: usize, max_angular: usize, do_gradients: bool) -> Self { + let lm_shape = (max_angular + 1) * (max_angular + 1); + Self { + values: ndarray::Array2::from_elem((lm_shape, max_radial), 0.0), + gradients: if do_gradients { + Some(ndarray::Array3::from_elem((3, lm_shape, max_radial), 0.0)) + } else { + None + } + } + } +} + +#[derive(Debug, Clone)] +#[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)] +pub(super) struct RawSphericalExpansionParameters { + /// the cutoff beyond which we neglect neighbors (often called third cutoff) + pub cutoff: f64, + /// Number of radial basis function to use in the expansion + pub max_radial: usize, + /// Number of spherical harmonics to use in the expansion + pub max_angular: usize, + /// Width of the atom-centered gaussian used to create the atomic density + pub atomic_gaussian_width: f64, + + /// Radial basis to use for the radial integral + pub radial_basis: RadialBasis, + /// Cutoff function used to smooth the behavior around the cutoff radius + pub cutoff_function: CutoffFunction, + /// radial scaling can be used to reduce the importance of neighbor atoms + /// further away from the center, usually improving the performance of the + /// model + pub radial_scaling: RadialScaling, +} + +pub(super) struct RawSphericalExpansion{ + parameters: RawSphericalExpansionParameters, + /// implementation + cached allocation to compute the radial integral for a + /// single pair + radial_integral: ThreadLocal>, + /// implementation + cached allocation to compute the spherical harmonics + /// for a single pair + spherical_harmonics: ThreadLocal>, +} + +impl RawSphericalExpansion { + pub(super) fn new(parameters: RawSphericalExpansionParameters) -> Self { + Self{ + parameters, + radial_integral: ThreadLocal::new(), + spherical_harmonics: ThreadLocal::new(), + } + } + + pub(super) fn parameters(&self) -> &RawSphericalExpansionParameters { + 
&self.parameters + } + + pub(super) fn make_contribution_buffer(&self, do_gradients: bool) -> ExpansionContribution { + ExpansionContribution::new( + self.parameters.max_radial, + self.parameters.max_angular, + do_gradients, + ) + } + + /// Compute the product of radial scaling & cutoff smoothing functions + pub(crate) fn scaling_functions(&self, r: f64) -> f64 { + let cutoff = self.parameters.cutoff_function.compute(r, self.parameters.cutoff); + let scaling = self.parameters.radial_scaling.compute(r); + return cutoff * scaling; + } + + /// Compute the gradient of the product of radial scaling & cutoff smoothing functions + pub(crate) fn scaling_functions_gradient(&self, r: f64) -> f64 { + let cutoff = self.parameters.cutoff_function.compute(r, self.parameters.cutoff); + let cutoff_grad = self.parameters.cutoff_function.derivative(r, self.parameters.cutoff); + + let scaling = self.parameters.radial_scaling.compute(r); + let scaling_grad = self.parameters.radial_scaling.derivative(r); + + return cutoff_grad * scaling + cutoff * scaling_grad; + } + + + /// compute the spherical expansion coefficients associated with + /// a single center->neighbor vector (`vector`), + /// and store it in a ExpansionContribution buffer (`contribution`). + /// `gradient_orientation` serves two purposes: + /// it tells whether or not the gradient should be computed, + /// and it deals with the case where the vector (and the spherical expansion) take place + /// in a rotated/scaled/sheared frame of reference: its three vectors contain + /// the changes of the vector (in the rotated frame of reference) when adding the + /// +x, +y, and +z vectors from the 'real' frame of reference. 
+ /// `extra_scaling` simply applies a scaling factor to all coefficients + pub(super) fn compute_coefficients(&self, contribution: &mut ExpansionContribution, vector: Vector3D, extra_scaling: f64, gradient_orientation: Option<(Vector3D,Vector3D,Vector3D)>){ + let mut radial_integral = self.radial_integral.get_or(|| { + let radial_integral = SoapRadialIntegralCache::new( + self.parameters.radial_basis.clone(), + SoapRadialIntegralParameters { + max_radial: self.parameters.max_radial, + max_angular: self.parameters.max_angular, + atomic_gaussian_width: self.parameters.atomic_gaussian_width, + cutoff: self.parameters.cutoff, + } + ).expect("invalid radial integral parameters"); + return RefCell::new(radial_integral); + }).borrow_mut(); + + let mut spherical_harmonics = self.spherical_harmonics.get_or(|| { + RefCell::new(SphericalHarmonicsCache::new(self.parameters.max_angular)) + }).borrow_mut(); + + let distance = vector.norm(); + let direction = vector/distance; + // Compute the three factors that appear in the center contribution. + // Note that this is simply the pair contribution for the special + // case where the pair distance is zero. 
+ radial_integral.compute(distance, gradient_orientation.is_some()); + spherical_harmonics.compute(direction, gradient_orientation.is_some()); + + let f_scaling = self.scaling_functions(distance) * extra_scaling; + + let (values, gradient_values_o) = (&mut contribution.values, contribution.gradients.as_mut()); + + debug_assert_eq!( + values.shape(), + [(self.parameters.max_angular+1)*(self.parameters.max_angular+1), self.parameters.max_radial] + ); + for l in 0..=self.parameters.max_angular { + let l_offset = l*l; + let msize = 2*l+1; + //values.slice_mut(s![l_offset..l_offset+msize, ..]) *= spherical_harmonics.values.slice(l); + for m in 0..msize { + let lm = l_offset+m; + for n in 0..self.parameters.max_radial { + values[[lm, n]] = spherical_harmonics.values[lm] + * radial_integral.values[[l,n]] + * f_scaling; + } + } + } + + if let Some((dvdx,dvdy,dvdz)) = gradient_orientation { + unimplemented!("ööps, gradient not ready yet"); + let gradient_values = gradient_values_o.unwrap(); + + let ilen = 1./distance; + let dlendv = vector*ilen; + let dlendx = dlendv*dvdx; + let dlendy = dlendv*dvdy; + let dlendz = dlendv*dvdz; + let ddirdx = dvdx*ilen - vector*dlendx*ilen*ilen; + let ddirdy = dvdy*ilen - vector*dlendy*ilen*ilen; + let ddirdz = dvdy*ilen - vector*dlendz*ilen*ilen; + + let single_grad = |l,n,m,dlenda,ddirda: Vector3D| { + f_scaling * ( + radial_integral.gradients[[l,n]]*dlenda*spherical_harmonics.values[[l as isize,m as isize]] + + radial_integral.values[[l,n]]*( + spherical_harmonics.gradients[0][[l as isize,m as isize]]*ddirda[0] + +spherical_harmonics.gradients[1][[l as isize,m as isize]]*ddirda[1] + +spherical_harmonics.gradients[2][[l as isize,m as isize]]*ddirda[2] + ) + // TODO scaling_function_gradient + ) + }; + + for l in 0..=self.parameters.max_angular { + let l_offset = l*l; + let msize = 2*l+1; + for m in 0..(msize) { + let lm = l_offset+m; + for n in 0..self.parameters.max_radial { + gradient_values[[0,lm,n]] = single_grad(l,n,m,dlendx,ddirdx); 
+ gradient_values[[1,lm,n]] = single_grad(l,n,m,dlendy,ddirdy); + gradient_values[[2,lm,n]] = single_grad(l,n,m,dlendz,ddirdz); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + use super::Vector3D; + + use approx::assert_relative_eq; + + use crate::systems::BATripletNeighborList; + use crate::systems::test_utils::test_systems; + use super::{RawSphericalExpansion,RawSphericalExpansionParameters}; + use super::canonical_vector_for_single_triplet; + //use super::super::CalculatorBase; + + #[test] + fn half_neighbor_list() { + let pre_calculator = BATripletNeighborList{ + cutoffs: [2.0,2.0], + }; + + let mut systems = test_systems(&["water"]); + let expected = &[[ + [0.0, 0.0, -0.478948537162397], // SC 0 1 0 + [0.0, 0.0, 0.478948537162397], // SC 0 1 1 + [0.0, 0.9289563, -0.7126298], // 0 1 2 + [0.0, 0.0, -0.478948537162397], // SC 1 0 1 + [0.0, 0.0, 0.478948537162397], // SC 1 0 0 + [0.0, -0.9289563, -0.7126298], // 1 0 2 + [0.0, 0.0, -0.75545], // SC 1 2 1 + [0.0, 0.0, 0.75545], // SC 1 2 2 + [0.0, 0.58895, 0.0], // 1 2 0 + ]]; + for (system,expected) in systems.iter_mut().zip(expected) { + let mut mtx_cache = BTreeMap::new(); + let mut dmtx_cache = BTreeMap::new(); + pre_calculator.ensure_computed_for_system(system).unwrap(); + let triplets = pre_calculator.get_for_system(system).unwrap() + .into_iter().filter(|v| v.bond_cell_shift == [0,0,0] && v.third_cell_shift == [0,0,0]); + for (expected,triplet) in expected.iter().zip(triplets) { + let res = canonical_vector_for_single_triplet(&triplet, false, false, &mut mtx_cache, &mut dmtx_cache).unwrap(); + assert_relative_eq!(res.vect, Vector3D::new(expected[0],expected[1],expected[2]), max_relative=1e-6); + } + } + } + + #[test] + fn full_neighbor_list() { + let pre_calculator = BATripletNeighborList{ + cutoffs: [2.0,2.0], + }; + + let mut systems = test_systems(&["water"]); + let expected = &[[ + [0.0, 0.0, 0.478948537162397], // SC 0 1 0 + [0.0, 0.0, -0.478948537162397], // SC 0 
1 1 + [0.0, -0.9289563, 0.7126298], // 0 1 2 + [0.0, 0.0, 0.478948537162397], // SC 1 0 1 + [0.0, 0.0, -0.478948537162397], // SC 1 0 0 + [0.0, 0.9289563, 0.7126298], // 1 0 2 + [0.0, 0.0, 0.75545], // SC 1 2 1 + [0.0, 0.0, -0.75545], // SC 1 2 2 + [0.0, -0.58895, 0.0], // 1 2 0 + ]]; + for (system,expected) in systems.iter_mut().zip(expected) { + let mut mtx_cache = BTreeMap::new(); + let mut dmtx_cache = BTreeMap::new(); + pre_calculator.ensure_computed_for_system(system).unwrap(); + let triplets = pre_calculator.get_for_system(system).unwrap() + .into_iter().filter(|v| v.bond_cell_shift == [0,0,0] && v.third_cell_shift == [0,0,0]); + for (expected,triplet) in expected.iter().zip(triplets) { + let res = canonical_vector_for_single_triplet(&triplet, true, false, &mut mtx_cache, &mut dmtx_cache).unwrap(); + assert_relative_eq!(res.vect, Vector3D::new(expected[0],expected[1],expected[2]), max_relative=1e-6); + } + } + } + + // note: the following test does pass, but gradients are disabled because we discovered that + // the values of this calculator ARE NOT CONTINUOUS around the values of bond_vector == (0,0,-z) + // //// + // #[test] + // fn finite_differences_positions() { + // // half neighbor list + // let calculator = Calculator::from(Box::new(BANeighborList::Half(HalfBANeighborList{ + // cutoffs: [2.0,3.0], + // bond_contribution: false, + // })) as Box); + + // let system = test_system("water"); + // let options = crate::calculators::tests_utils::FinalDifferenceOptions { + // displacement: 1e-6, + // max_relative: 1e-9, + // epsilon: 1e-16, + // }; + // crate::calculators::tests_utils::finite_differences_positions(calculator, &system, options); + + // // full neighbor list + // let calculator = Calculator::from(Box::new(BANeighborList::Full(FullBANeighborList{ + // cutoffs: [2.0,3.0], + // bond_contribution: false, + // })) as Box); + // crate::calculators::tests_utils::finite_differences_positions(calculator, &system, options); + // } + + use 
super::{RadialBasis,RadialScaling,CutoffFunction}; + + #[test] + fn spherical_expansion() { + + let expected = [ + [[0.16902879658926248, 0.028869505770363096, -0.012939303519269344], + [0.0, 0.0, -0.0], + [0.26212372007374773, 0.04923860892292029, -0.02052369607421798], + [0.0, 0.0, -0.0], + [0.0, 0.0, -0.0], + [0.0, 0.0, -0.0], + [0.2734914300150501, 0.05977378771423668, -0.022198889165475678], + [0.0, 0.0, -0.0], + [0.0, 0.0, -0.0]], + [[0.16902879658926248, 0.028869505770363096, -0.012939303519269344], + [0.0, 0.0, -0.0], + [0.0, 0.0, -0.0], + [0.26212372007374773, 0.04923860892292029, -0.02052369607421798], + [0.0, 0.0, -0.0], + [0.0, 0.0, -0.0], + [-0.13674571500752503, -0.029886893857118332, 0.011099444582737835], + [0.0, 0.0, -0.0], + [0.23685052611036728, 0.051765618640947135, -0.019224801953097073]], + [[0.055690489760816295, 0.06300370381466462, -0.002920081734154629], + [0.0, 0.0, -0.0], + [0.06460583443346805, 0.07393339000508466, -0.003326135960304247], + [0.06460583443346805, 0.07393339000508466, -0.003326135960304247], + [0.0, 0.0, -0.0], + [0.0, 0.0, -0.0], + [0.02647774981023968, 0.030983159900568692, -0.0013111835538098728], + [0.09172161588286466, 0.10732881425363133, -0.004542073066494842], + [0.04586080794143233, 0.053664407126815666, -0.002271036533247421]], + [[0.11139019922564429, 0.05256952008314618, -0.00976648821140304], + [-0.09132548515017183, -0.044385580585388364, 0.008057500982983558], + [0.15220914191695306, 0.07397596764231394, -0.013429168304972592], + [0.015220914191695308, 0.007397596764231396, -0.00134291683049726], + [-0.014919461593876988, -0.007651634517671902, 0.001329600535188854], + [-0.1491946159387699, -0.07651634517671901, 0.013296005351888539], + [0.11700350769036943, 0.06000672829237635, -0.010427180998805892], + [0.024865769323128322, 0.012752724196119839, -0.0022160008919814237], + [-0.04351509631547453, -0.022317267343209705, 0.0038780015609674885]], + ].map(|s|ndarray::arr2(&s)); + + let vectors = [ + (0., 0., 
/// Name of the environment variable that opts into the experimental
/// bond-atom spherical expansion.
const FEATURE_GATE: &str = "RASCALINE_EXPERIMENTAL_BOND_ATOM_SPX";

/// Check whether the experimental feature gate is enabled.
///
/// The gate is open when the `FEATURE_GATE` environment variable is set to
/// anything other than an empty string, `0`, `false`, `no` or `off` (compared
/// after lower-casing). If the variable can not be read (`std::env::var`
/// returns an error), the gate stays closed.
fn get_feature_gate() -> bool {
    match std::env::var(FEATURE_GATE) {
        Ok(value) => !matches!(
            value.to_lowercase().as_str(),
            "" | "0" | "false" | "no" | "off"
        ),
        Err(_) => false,
    }
}

/// Abort the current operation unless the experimental feature gate is enabled.
///
/// # Panics
///
/// Panics (through `unimplemented!`) with a message naming the environment
/// variable to set when `get_feature_gate` returns `false`.
fn assert_feature_gate() {
    // a single check is enough: the gate is re-read from the environment on
    // every call, so testing it twice in a row adds nothing
    if !get_feature_gate() {
        unimplemented!("Bond-Atom spherical expansion requires UNSTABLE feature gate: {}", FEATURE_GATE);
    }
}

/// Enable the experimental feature gate by setting the `FEATURE_GATE`
/// environment variable to `"true"` for the current process.
fn set_feature_gate() {
    std::env::set_var(FEATURE_GATE, "true");
}
+#[derive(Debug, Clone)] +#[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)] +pub struct SphericalExpansionForBondsParameters { + /// Spherical cutoffs to use for atomic environments + pub(super) cutoffs: [f64;2], + /// Number of radial basis function to use in the expansion + pub max_radial: usize, + /// Number of spherical harmonics to use in the expansion + pub max_angular: usize, + /// Width of the atom-centered gaussian used to create the atomic density + pub atomic_gaussian_width: f64, + /// Weight of the central atom contribution to the + /// features. If `1` the center atom contribution is weighted the same + /// as any other contribution. If `0` the central atom does not + /// contribute to the features at all. + pub center_atoms_weight: f64, + /// Radial basis to use for the radial integral + pub radial_basis: RadialBasis, + /// Cutoff function used to smooth the behavior around the cutoff radius + pub cutoff_function: CutoffFunction, + /// radial scaling can be used to reduce the importance of neighbor atoms + /// further away from the center, usually improving the performance of the + /// model + #[serde(default)] + pub radial_scaling: RadialScaling, +} + +impl SphericalExpansionForBondsParameters { + /// Validate all the parameters + pub fn validate(&self) -> Result<(), Error> { + assert_feature_gate(); + self.cutoff_function.validate()?; + self.radial_scaling.validate()?; + + // try constructing a radial integral + SoapRadialIntegralCache::new(self.radial_basis.clone(), SoapRadialIntegralParameters { + max_radial: self.max_radial, + max_angular: self.max_angular, + atomic_gaussian_width: self.atomic_gaussian_width, + cutoff: self.third_cutoff(), + })?; + + return Ok(()); + } + pub fn bond_cutoff(&self) -> f64 { + self.cutoffs[0] + } + pub fn third_cutoff(&self) -> f64 { + self.cutoffs[1] + } + + fn decompose(self) -> (RawSphericalExpansionParameters,f64,f64){ + let (bond_cutoff,center_atoms_weight) = 
(self.bond_cutoff(),self.center_atoms_weight); + ( + RawSphericalExpansionParameters{ + cutoff: self.third_cutoff(), + max_radial: self.max_radial, + max_angular: self.max_angular, + atomic_gaussian_width: self.atomic_gaussian_width, + radial_basis: self.radial_basis, + cutoff_function: self.cutoff_function, + radial_scaling: self.radial_scaling, + }, + bond_cutoff, + center_atoms_weight, + )} + fn recompose(expansion_params: RawSphericalExpansionParameters, bond_cutoff: f64, center_atoms_weight: f64) -> Self { + Self{ + cutoffs: [bond_cutoff, expansion_params.cutoff], + max_radial: expansion_params.max_radial, + max_angular: expansion_params.max_angular, + atomic_gaussian_width: expansion_params.atomic_gaussian_width, + radial_basis: expansion_params.radial_basis, + cutoff_function: expansion_params.cutoff_function, + radial_scaling: expansion_params.radial_scaling, + center_atoms_weight, + } + } + +} + + +/// The actual calculator used to compute SOAP-like spherical expansion coefficients for bond-centered environments +/// In other words, the spherical expansion of the neighbor density function centered on the center of a bond, +/// 'after' rotating the system so that the bond is aligned with the z axis. +/// +/// This radial+angular decomposition yields coefficients with labels `n` (radial), and `l` and `m` (angular) +/// as a Calculator, it yields tonsorblocks of with individual values of `l` +/// and individual atomic types for center_1, center_2, and neighbor. +/// Each block has components for each possible value of `m`, and properties for each value of `n`. +/// a given sample corresponds to a single center bond (a pair of center atoms) within a given structure. 
+pub struct SphericalExpansionForBonds { + /// The object in charge of computing the vectors and distances + /// between the bond and the lone atom of the BA triplet (after rotating the system to put the bond in it canonical orientation) + distance_calculator: BATripletNeighborList, + /// actual spherical expansion object + raw_expansion: RawSphericalExpansion, + /// a weight multiplier for expansion coefficients from self-contributions + center_atoms_weight: f64, +} + +impl SphericalExpansionForBonds { + /// Create a new `SphericalExpansion` calculator with the given parameters + pub fn new(parameters: SphericalExpansionForBondsParameters) -> Result { + parameters.validate()?; + let cutoffs = parameters.cutoffs.clone(); + let (exp_params, _bond_cut, center_weight) = parameters.decompose(); + + return Ok(Self { + center_atoms_weight: center_weight, + raw_expansion: RawSphericalExpansion::new(exp_params), + distance_calculator: BATripletNeighborList{ + cutoffs, + }, + }); + } + + + /// a smart-ish way to obtain the coefficients of all bond expansions: + /// this function's API is designed to be resource-efficient for both SphericalExpansionForBondType and + /// SphericalExpansionForBonds, while being computationally efficient for the underlying BANeighborList calculator. 
+ pub(super) fn get_coefficients_for<'a>( + &'a self, system: &'a System, + s1: i32, s2: i32, s3_list: &'a Vec, + do_gradients: GradientsOptions, + ) -> Result>)> + 'a, Error> { + + let types = system.types().unwrap(); + + + let pre_iter = s3_list.iter().flat_map(|s3|{ + self.distance_calculator.get_per_system_per_type_enumerated(system,s1,s2,*s3).unwrap().into_iter() + }).flat_map(|(triplet_i,triplet)| { + let invert: &'static [bool] = { + if s1==s2 {&[false,true]} + else if types[triplet.atom_i] == s1 {&[false]} + else {&[true]} + }; + invert.iter().map(move |invert|(triplet_i,triplet,*invert)) + }).collect::>(); + + let contribution = std::rc::Rc::new(RefCell::new( + self.raw_expansion.make_contribution_buffer(do_gradients.any()) + )); + + let mut mtx_cache = BTreeMap::new(); + let mut dmtx_cache = BTreeMap::new(); + + return Ok(pre_iter.into_iter().map(move |(triplet_i,triplet,invert)| { + let mut vector = canonical_vector_for_single_triplet(&triplet, invert, false, &mut mtx_cache, &mut dmtx_cache).unwrap(); + let weight = if triplet.is_self_contrib {self.center_atoms_weight} else {1.0}; + self.raw_expansion.compute_coefficients(&mut *contribution.borrow_mut(), vector.vect,weight,None); + (triplet_i, invert, contribution.clone()) + })); + + } + +} + +impl CalculatorBase for SphericalExpansionForBonds { + fn name(&self) -> String { + "spherical expansion".into() + } + + fn cutoffs(&self) -> &[f64] { + &self.distance_calculator.cutoffs + } + + fn parameters(&self) -> String { + let params = SphericalExpansionForBondsParameters::recompose( + (*self.raw_expansion.parameters()).clone(), + self.distance_calculator.bond_cutoff(), + self.center_atoms_weight, + ); + serde_json::to_string(¶ms).expect("failed to serialize to JSON") + } + + fn keys(&self, systems: &mut [System]) -> Result { + let builder = TwoCentersSingleNeighborsTypesKeys { + cutoffs: self.distance_calculator.cutoffs, + self_contributions: true, + raw_triplets: &self.distance_calculator, + }; + let keys 
= builder.keys(systems)?; + + let mut builder = LabelsBuilder::new(vec!["o3_lambda", "center_1_type", "center_2_type", "neighbor_type"]); + for &[center_1_type, center_2_type, neighbor_type] in keys.iter_fixed_size() { + for o3_lambda in 0..=self.raw_expansion.parameters().max_angular { + builder.add(&[o3_lambda.into(), center_1_type, center_2_type, neighbor_type]); + } + } + + return Ok(builder.finish()); + } + + fn sample_names(&self) -> Vec<&str> { + BondCenteredSamples::sample_names() + } + + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { + assert_eq!(keys.names(), ["o3_lambda", "center_1_type", "center_2_type", "neighbor_type"]); + + // only compute the samples once for each `atom_type, neighbor_type`, + // and re-use the results across `o3_lambda`. + let mut samples_per_type = BTreeMap::new(); + for [_, center_1_type, center_2_type, neighbor_type] in keys.iter_fixed_size() { + if samples_per_type.contains_key(&(center_1_type, center_2_type, neighbor_type)) { + continue; + } + + let builder = BondCenteredSamples { + cutoffs: self.distance_calculator.cutoffs, + center_1_type: AtomicTypeFilter::Single(center_1_type.i32()), + center_2_type: AtomicTypeFilter::Single(center_2_type.i32()), + neighbor_type: AtomicTypeFilter::Single(neighbor_type.i32()), + self_contributions: true, + raw_triplets: &self.distance_calculator, + }; + + samples_per_type.insert((center_1_type, center_2_type, neighbor_type), builder.samples(systems)?); + } + + let mut result = Vec::new(); + for [_, center_1_type, center_2_type, neighbor_type] in keys.iter_fixed_size() { + let samples = samples_per_type.get( + &(center_1_type, center_2_type, neighbor_type) + ).expect("missing samples"); + + result.push(samples.clone()); + } + + return Ok(result); + } + + fn supports_gradient(&self, parameter: &str) -> bool { + false // for now, discontinuities are a pain + } + + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> 
Result, Error> { + assert_eq!(keys.names(), ["o3_lambda", "center_1_type", "center_2_type", "neighbor_type"]); + assert_eq!(keys.count(), samples.len()); + + let mut gradient_samples = Vec::new(); + for ([_, center_1_type, center_2_type, neighbor_type], samples) in keys.iter_fixed_size().zip(samples) { + // TODO: we don't need to rebuild the gradient samples for different + // o3_lambda + let builder = BondCenteredSamples { + cutoffs: self.distance_calculator.cutoffs, + center_1_type: AtomicTypeFilter::Single(center_1_type.i32()), + center_2_type: AtomicTypeFilter::Single(center_2_type.i32()), + neighbor_type: AtomicTypeFilter::Single(neighbor_type.i32()), + self_contributions: true, + raw_triplets: &self.distance_calculator, + }; + + gradient_samples.push(builder.gradients_for(systems, samples)?); + } + + return Ok(gradient_samples); + } + + fn components(&self, keys: &Labels) -> Vec> { + assert_eq!(keys.names(), ["o3_lambda", "center_1_type", "center_2_type", "neighbor_type"]); + + // only compute the components once for each `o3_lambda`, + // and re-use the results across `atom_type, neighbor_type`. 
+ let mut component_by_l = BTreeMap::new(); + for [o3_lambda, _, _, _] in keys.iter_fixed_size() { + if component_by_l.contains_key(o3_lambda) { + continue; + } + + let mut component = LabelsBuilder::new(vec!["spherical_harmonics_m"]); + for m in -o3_lambda.i32()..=o3_lambda.i32() { + component.add(&[LabelValue::new(m)]); + } + + let components = vec![component.finish()]; + component_by_l.insert(*o3_lambda, components); + } + + let mut result = Vec::new(); + for [o3_lambda, _, _, _] in keys.iter_fixed_size() { + let components = component_by_l.get(o3_lambda).expect("missing samples"); + result.push(components.clone()); + } + return result; + } + + fn property_names(&self) -> Vec<&str> { + vec!["n"] + } + + fn properties(&self, keys: &Labels) -> Vec { + let mut properties = LabelsBuilder::new(self.property_names()); + for n in 0..self.raw_expansion.parameters().max_radial { + properties.add(&[n]); + } + let properties = properties.finish(); + + return vec![properties; keys.count()]; + } + + #[time_graph::instrument(name = "SphericalExpansion::compute")] + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { + assert_feature_gate(); + assert_eq!(descriptor.keys().names(), ["o3_lambda", "center_1_type", "center_2_type", "neighbor_type"]); + if descriptor.blocks().len() == 0 { + return Ok(()); + } + + let max_angular = self.raw_expansion.parameters().max_angular; + let l_slices: Vec<_> = (0..=max_angular).map(|l|{ + let lsize = l*l; + let msize = 2*l+1; + lsize..lsize+msize + }).collect(); + + let do_gradients = GradientsOptions { + positions: descriptor.block_by_id(0).gradient("positions").is_some(), + cell: descriptor.block_by_id(0).gradient("cell").is_some(), + strain: descriptor.block_by_id(0).gradient("strain").is_some(), + }; + if do_gradients.positions { + assert!(self.supports_gradient("positions")); + } + if do_gradients.cell { + assert!(self.supports_gradient("cell")); + } + + let radial_selection = 
descriptor.blocks().iter().map(|b|{ + let prop = b.properties(); + assert_eq!(prop.names(), ["n"]); + prop.iter_fixed_size().map(|&[n]|n.i32()).collect::>() + }).collect::>(); + // first, create some partial-key -> block lookup tables to avoid linear searches within blocks later + + // {(s1,s2,s3) -> i_s3} + let mut s1s2s3_to_is3: BTreeMap<(i32,i32,i32),usize> = BTreeMap::new(); + // {(s1,s2) -> [i_s3->(s3,[l->i_block])]} + let mut s1s2_to_block_ids: BTreeMap<(i32,i32),Vec<(i32,Vec)>> = BTreeMap::new(); + + for (block_i, &[l, s1,s2,s3]) in descriptor.keys().iter_fixed_size().enumerate(){ + let s1=s1.i32(); + let s2=s2.i32(); + let s3=s3.i32(); + let l=l.usize(); + let s1s2_blocks = s1s2_to_block_ids.entry((s1,s2)) + .or_insert_with(Vec::new); + let l_blocks = match s1s2s3_to_is3.entry((s1,s2,s3)) { + Entry::Occupied(i_s3_e) => { + let (s3_b, l_blocks) = & mut s1s2_blocks[*i_s3_e.get()]; + debug_assert_eq!(s3_b,&s3); + l_blocks + }, + Entry::Vacant(i_s3_e) => { + let i_s3 = s1s2_blocks.len(); + i_s3_e.insert(i_s3); + s1s2_blocks.push((s3,vec![usize::MAX;max_angular+1])); + &mut s1s2_blocks[i_s3].1 + }, + }; + l_blocks[l] = block_i; + } + + #[cfg(debug_assertions)]{ + for block in descriptor.blocks() { + assert_eq!(block.samples().names(), ["system", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"]); + } + } + let mut descriptors_by_system = split_tensor_map_by_system(descriptor, systems.len()); + + systems.par_iter_mut() + .zip_eq(&mut descriptors_by_system) + .try_for_each(|(system, descriptor)| + { + //system.compute_triplet_neighbors(self.parameters.bond_cutoff(), self.parameters.third_cutoff())?; + self.distance_calculator.ensure_computed_for_system(system)?; + let triplets = self.distance_calculator.get_for_system(system)?; + let types = system.types()?; + + for ((s1,s2),s1s2_blocks) in s1s2_to_block_ids.iter() { + let (s3_list,per_s3_blocks): (Vec,Vec<&Vec<_>>) = s1s2_blocks.iter().map( + |(s3,blocks)|(*s3,blocks) + ).unzip(); + // 
half-assume that blocks that share s1,s2,s3 have the same sample list + #[cfg(debug_assertions)]{ + for (_s3,s3blocks) in s1s2_blocks.iter(){ + debug_assert!(s3blocks.len()>0); + let mut s3goodblocks = s3blocks.iter().filter(|b_i|(**b_i)!=usize::MAX); + let first_goodblock = s3goodblocks.next(); + debug_assert!(first_goodblock.is_some()); + + let samples_n = descriptor.block_by_id(*first_goodblock.unwrap()).samples().size(); + for lblock in s3goodblocks { + debug_assert_eq!(descriptor.block_by_id(*lblock).samples().size(), samples_n); + } + } + } + // {bond_i->(i_s3,sample_i)} + let mut s3_samples = vec![]; + let mut sample_lut: BTreeMap<(usize,usize,[i32;3]),Vec<(usize,usize)>> = BTreeMap::new(); + + // also assume that the systems are in order in the samples + for (i_s3, s3blocks) in per_s3_blocks.into_iter().enumerate() { + let first_good_block = s3blocks.iter().filter(|b_i|**b_i!=usize::MAX).next().unwrap(); + let samples = descriptor.block_by_id(*first_good_block).samples(); + for (sample_i, &[_system_i,atom_i,atom_j,cell_shift_a, cell_shift_b, cell_shift_c]) in samples.iter_fixed_size().enumerate(){ + match sample_lut.entry( + (atom_i.usize(),atom_j.usize(),[cell_shift_a.i32(),cell_shift_b.i32(),cell_shift_c.i32()]) + ) { + Entry::Vacant(e) => { + e.insert(vec![(i_s3,sample_i)]); + }, + Entry::Occupied(mut e) => { + e.get_mut().push((i_s3,sample_i)); + }, + } + } + s3_samples.push(samples); + } + for (triplet_i,inverted,contribution) in self.get_coefficients_for(system, *s1, *s2, &s3_list, do_gradients)? 
{ + let triplet = &triplets[triplet_i]; + + let contribution = contribution.borrow(); + let these_samples = match sample_lut.get( + &(triplet.atom_i,triplet.atom_j,triplet.bond_cell_shift) + ){ + None => {continue;}, + Some(a) => a, + }; + + for (i_s3,sample_i) in these_samples.iter(){ + if s3_list[*i_s3] != types[triplet.atom_k] { + continue // this triplet does not contribute to this block + } + let sample = &s3_samples[*i_s3][*sample_i]; + let (atom_i,atom_j, ce_sh) = (sample[1].usize(),sample[2].usize(),[sample[3].i32(),sample[4].i32(),sample[5].i32()]); + if (!inverted) && ( + triplet.atom_i != atom_i || triplet.atom_j != atom_j + || triplet.bond_cell_shift != ce_sh + ){ + continue; + } else if inverted && ( + triplet.atom_i != atom_j || triplet.atom_j != atom_i + || triplet.bond_cell_shift != ce_sh.map(|x|-x) + ){ + continue; + } + + let ret_blocks = &s1s2_blocks[*i_s3].1; + for (l,lslice) in l_slices.iter().enumerate() { + let block_i = ret_blocks[l]; + if block_i == usize::MAX { + continue; + } + let mut block = descriptor.block_mut_by_id(block_i); + let mut array = array_mut_for_system(block.values_mut()); + let mut value_slice = array.slice_mut(s![*sample_i,..,..]); + let input_slice = contribution.values.slice(s![lslice.clone(),..]); + for (n_i,n) in radial_selection[block_i].iter().enumerate() { + let mut value_slice = value_slice.slice_mut(s![..,n_i]); + value_slice += &input_slice.slice(s![..,*n]); + } + + } + } + } + } + Ok::<_, Error>(()) + })?; + + Ok(()) + } +} + + +#[cfg(test)] +mod tests { + use ndarray::ArrayD; + use metatensor::{Labels, TensorBlock, EmptyArray, LabelsBuilder, TensorMap}; + + use crate::calculators::bondatom::set_feature_gate; + use crate::systems::test_utils::test_systems; + use crate::{Calculator, CalculationOptions, LabelsSelection}; + use crate::calculators::CalculatorBase; + + use super::{SphericalExpansionForBonds, SphericalExpansionForBondsParameters}; + use crate::calculators::soap::{CutoffFunction, RadialScaling}; + 
use crate::calculators::radial_basis::RadialBasis; + + + fn parameters() -> SphericalExpansionForBondsParameters { + set_feature_gate(); + SphericalExpansionForBondsParameters { + cutoffs: [3.5,3.5], + max_radial: 6, + max_angular: 6, + atomic_gaussian_width: 0.3, + center_atoms_weight: 10.0, + radial_basis: RadialBasis::splined_gto(1e-8), + radial_scaling: RadialScaling::Willatt2018 { scale: 1.5, rate: 0.8, exponent: 2.0}, + cutoff_function: CutoffFunction::ShiftedCosine { width: 0.5 }, + } + } + + #[test] + fn values() { + let mut calculator = Calculator::from(Box::new(SphericalExpansionForBonds::new( + parameters() + ).unwrap()) as Box); + + let mut systems = test_systems(&["water"]); + let descriptor = calculator.compute(&mut systems, Default::default()).unwrap(); + + for l in 0..6 { + for center_1_type in [1, -42] { + for center_2_type in [1, -42] { + if center_1_type==-42 && center_2_type==-42 { + continue; + } + for neighbor_type in [1, -42] { + let block_i = descriptor.keys().position(&[ + l.into(), center_1_type.into(), center_2_type.into(), neighbor_type.into() + ]); + assert!(block_i.is_some()); + let block = &descriptor.block_by_id(block_i.unwrap()); + let array = block.values().to_array(); + assert_eq!(array.shape().len(), 3); + assert_eq!(array.shape()[1], 2 * l + 1); + } + } + } + } + + // exact values for spherical expansion are regression-tested in + // `rascaline/tests/spherical-expansion.rs` + } + + #[test] + fn compute_partial() { + let calculator = Calculator::from(Box::new(SphericalExpansionForBonds::new( + SphericalExpansionForBondsParameters { + max_angular: 2, + ..parameters() + } + ).unwrap()) as Box); + + let mut systems = test_systems(&["water"]); + + let properties = Labels::new(["n"], &[ + [0], + [3], + [2], + ]); + + let samples = Labels::new(["system", "first_atom", "second_atom", "cell_shift_a","cell_shift_b","cell_shift_c"], &[ + [0, 0, 2, 0,0,0], + [0, 0, 1, 0,0,0], + //[0, 1, 2, 0,0,0], // excluding this one + ]); + + let keys = 
Labels::new(["o3_lambda", "center_1_type", "center_2_type", "neighbor_type"], &[ + // every key that will be generated (in scrambled order) plus one + [0, -42, 1, -42], + [0, -42, -42, -42], + [2, -42, 1, -42], + [0, 1, -42, -42], + [0, 1, 1, -42], + [0, 6, 1, 1], // not part of the default keys + [1, -42, 1, -42], + [1, -42, 1, 1], + [2, -42, -42, -42], + [1, 1, -42, 1], + [0, -42, 1, 1], + [1, 1, 1, -42], + [2, -42, 1, 1], + [0, 1, 1, 1], + [2, 1, -42, -42], + [1, 1, 1, 1], + [0, -42, -42, 1], + [1, -42, -42, 1], + [2, -42, -42, 1], + [2, 1, 1, -42], + [0, 1, -42, 1], + [2, 1, -42, 1], + [2, 1, 1, 1], + [1, -42, -42, -42], + [1, 1, -42, -42], + ]); + + crate::calculators::tests_utils::compute_partial( + calculator, &mut systems, &keys, &samples, &properties + ); + } + + #[test] + fn non_existing_samples() { + let mut calculator = Calculator::from(Box::new(SphericalExpansionForBonds::new( + parameters() + ).unwrap()) as Box); + + let angular_stride = parameters().max_angular +1; + let mut systems = test_systems(&["water"]); + + // include the three atoms in all blocks, regardless of the + // atom_type key. 
+ let block = TensorBlock::new( + EmptyArray::new(vec![3, 1]), + &Labels::new(["system", "first_atom", "second_atom", "cell_shift_a","cell_shift_b","cell_shift_c"], &[ + [0, 0, 2, 0,0,0], + [0, 1, 2, 0,0,0], + [0, 0, 1, 0,0,0], + ]), + &[], + &Labels::single(), + ).unwrap(); + + let mut keys = LabelsBuilder::new(vec!["o3_lambda", "center_1_type", "center_2_type", "neighbor_type"]); + let mut blocks = Vec::new(); + for l in 0..(parameters().max_angular + 1) as isize { + for center_1_type in [1, -42] { + for center_2_type in [1, -42] { + for neighbor_type in [1, -42] { + keys.add(&[l, center_1_type, center_2_type, neighbor_type]); + blocks.push(block.as_ref().try_clone().unwrap()); + } + } + } + } + let select_all_samples = TensorMap::new(keys.finish(), blocks).unwrap(); + + let options = CalculationOptions { + selected_samples: LabelsSelection::Predefined(&select_all_samples), + ..Default::default() + }; + let descriptor = calculator.compute(&mut systems, options).unwrap(); + + // get the block for oxygen + // println!("{:?}", descriptor.keys()); + assert_eq!(descriptor.keys().names(), ["o3_lambda", "center_1_type", "center_2_type", "neighbor_type"]); + assert_eq!(descriptor.keys()[2*angular_stride], [0, -42, 1, -42]); // start with [n, -42, -42, -42], then [n, -42, -42, 1] + + let block = descriptor.block_by_id(2*angular_stride); + let block = block.data(); + + // entries centered on H atoms should be zero + assert_eq!( + *block.samples, + Labels::new(["system", "first_atom", "second_atom", "cell_shift_a","cell_shift_b","cell_shift_c"], &[ + [0, 0, 2, 0,0,0], + [0, 1, 2, 0,0,0], // the sample that doesn't exist + [0, 0, 1, 0,0,0], + ]) + ); + let array = block.values.as_array(); + assert_eq!(array.index_axis(ndarray::Axis(0), 1), ArrayD::from_elem(vec![1, 6], 0.0)); + + // get the block for hydrogen + assert_eq!(descriptor.keys().names(), ["o3_lambda", "center_1_type", "center_2_type", "neighbor_type"]); + assert_eq!(descriptor.keys()[5*angular_stride], [0, 1, -42, 
1]); + + let block = descriptor.block_by_id(5*angular_stride); + let block = block.data(); + + // entries centered on O atoms should be zero + assert_eq!( + *block.samples, + Labels::new(["system", "first_atom", "second_atom", "cell_shift_a","cell_shift_b","cell_shift_c"], &[ + [0, 0, 2, 0,0,0], + [0, 1, 2, 0,0,0], + [0, 0, 1, 0,0,0], + ]) + ); + let array = block.values.as_array(); + assert_eq!(array.index_axis(ndarray::Axis(0), 0), ArrayD::from_elem(vec![1, 6], 0.0)); + } +} diff --git a/rascaline/src/calculators/mod.rs b/rascaline/src/calculators/mod.rs index 7eaa43d35..d13414842 100644 --- a/rascaline/src/calculators/mod.rs +++ b/rascaline/src/calculators/mod.rs @@ -114,3 +114,6 @@ pub use self::soap::{SoapRadialSpectrum, RadialSpectrumParameters}; pub mod lode; pub use self::lode::{LodeSphericalExpansion, LodeSphericalExpansionParameters}; + +mod bondatom; +pub use self::bondatom::{SphericalExpansionForBondsParameters, SphericalExpansionForBonds}; diff --git a/rascaline/src/labels/keys.rs b/rascaline/src/labels/keys.rs index 48fbe393d..e4bc4810c 100644 --- a/rascaline/src/labels/keys.rs +++ b/rascaline/src/labels/keys.rs @@ -3,7 +3,7 @@ use std::collections::BTreeSet; use metatensor::{Labels, LabelsBuilder}; use crate::{System, Error}; - +use crate::systems::BATripletNeighborList; /// Common interface to create a set of metatensor's `TensorMap` keys from systems pub trait KeysBuilder { @@ -95,6 +95,52 @@ impl KeysBuilder for CenterSingleNeighborsTypesKeys { } } +/// Compute a set of keys with three variables: the types of two central atoms within a given cutoff to each other, +/// and the type of a third, neighbor atom, within a cutoff of the first two. +pub struct TwoCentersSingleNeighborsTypesKeys<'a> { + /// Spherical cutoff to use when searching for neighbors around an atom + pub(crate) cutoffs: [f64;2], + /// Should we consider an atom to be it's own neighbor or not? 
+ pub self_contributions: bool, + pub raw_triplets: &'a BATripletNeighborList, +} + +impl<'a> TwoCentersSingleNeighborsTypesKeys<'a>{ + pub fn bond_cutoff(&self) -> f64 { + self.cutoffs[0] + } + pub fn third_cutoff(&self) -> f64 { + self.cutoffs[1] + } +} + + +impl<'a> KeysBuilder for TwoCentersSingleNeighborsTypesKeys<'a> { + fn keys(&self, systems: &mut [System]) -> Result { + assert!(self.bond_cutoff() > 0.0 && self.bond_cutoff().is_finite() && self.third_cutoff() > 0.0 && self.third_cutoff().is_finite()); + + let mut all_types_triplets = BTreeSet::new(); + for system in systems { + self.raw_triplets.ensure_computed_for_system(system)?; + + let types = system.types()?; + for triplet in self.raw_triplets.get_for_system(system)? { + if (!self.self_contributions) && triplet.is_self_contrib { + continue; + } + all_types_triplets.insert((types[triplet.atom_i], types[triplet.atom_j], types[triplet.atom_k])); + all_types_triplets.insert((types[triplet.atom_j], types[triplet.atom_i], types[triplet.atom_k])); + } + } + + let mut keys = LabelsBuilder::new(vec!["center_1_type", "center_2_type", "neighbor_type"]); + for (center1, center2, neighbor) in all_types_triplets { + keys.add(&[center1,center2, neighbor]); + } + + return Ok(keys.finish()); + } +} /// Compute a set of keys with three variables: the central atom type and two /// neighbor atom types. 
diff --git a/rascaline/src/labels/mod.rs b/rascaline/src/labels/mod.rs index 7c4e6e1a9..025b1f865 100644 --- a/rascaline/src/labels/mod.rs +++ b/rascaline/src/labels/mod.rs @@ -1,11 +1,11 @@ mod samples; pub use self::samples::{AtomicTypeFilter, SamplesBuilder}; -pub use self::samples::AtomCenteredSamples; +pub use self::samples::{AtomCenteredSamples,BondCenteredSamples}; pub use self::samples::LongRangeSamplesPerAtom; mod keys; pub use self::keys::KeysBuilder; pub use self::keys::CenterTypesKeys; -pub use self::keys::{CenterSingleNeighborsTypesKeys, AllTypesPairsKeys}; +pub use self::keys::{CenterSingleNeighborsTypesKeys, TwoCentersSingleNeighborsTypesKeys, AllTypesPairsKeys}; pub use self::keys::CenterTwoNeighborsTypesKeys; diff --git a/rascaline/src/labels/samples/bond_centered.rs b/rascaline/src/labels/samples/bond_centered.rs new file mode 100644 index 000000000..fe7703c4c --- /dev/null +++ b/rascaline/src/labels/samples/bond_centered.rs @@ -0,0 +1,413 @@ +use std::collections::{BTreeSet,BTreeMap}; + +use metatensor::{Labels, LabelsBuilder}; + +use crate::{Error, System}; +use super::{SamplesBuilder, AtomicTypeFilter}; +use crate::systems::BATripletNeighborList; + + +/// `SampleBuilder` for bond-centered representations. This will create one +/// sample for each pair of atoms (within a spherical cutoff to each other), +/// optionally filtering on the bond's atom types. The samples names are +/// (structure", "first_center", "second_center", "bond_i"). +/// (with type(first_center)<=type(second_center)) +/// +/// Positions gradient samples include all atoms within a spherical cutoff to the bond center, +/// optionally filtering on the neighbor atom types. 
+pub struct BondCenteredSamples<'a> {
+    /// Cutoff radii, stored as `[bond_cutoff, third_cutoff]`; see
+    /// `bond_cutoff()` and `third_cutoff()` for the individual values
+    pub cutoffs: [f64; 2],
+    /// Filter for the type of the first center of the bond
+    pub center_1_type: AtomicTypeFilter,
+    /// Filter for the type of the second center of the bond
+    pub center_2_type: AtomicTypeFilter,
+    /// Filter for the neighbor atom types
+    pub neighbor_type: AtomicTypeFilter,
+    /// Should the atoms of the bond be considered as their own neighbors?
+    /// When `false`, triplets flagged `is_self_contrib` are ignored while
+    /// building the samples.
+    pub self_contributions: bool,
+    /// Triplet neighbor list queried to enumerate the bonds and their
+    /// neighbors
+    pub raw_triplets: &'a BATripletNeighborList,
+}
+
+impl<'a> BondCenteredSamples<'a> {
+    /// First entry of `cutoffs`
+    pub fn bond_cutoff(&self) -> f64 {
+        self.cutoffs[0]
+    }
+
+    /// Second entry of `cutoffs`
+    pub fn third_cutoff(&self) -> f64 {
+        self.cutoffs[1]
+    }
+
+    /// Panic unless both cutoffs are strictly positive and finite.
+    fn assert_valid_cutoffs(&self) {
+        assert!(
+            self.bond_cutoff() > 0.0 && self.bond_cutoff().is_finite()
+                && self.third_cutoff() > 0.0 && self.third_cutoff().is_finite(),
+            "cutoffs must be positive for BondCenteredSamples"
+        );
+    }
+}
+
+impl<'a> SamplesBuilder for BondCenteredSamples<'a> {
+    fn sample_names() -> Vec<&'static str> {
+        // the cell shift is needed in case we have several bonds with the
+        // same atoms (periodic boundaries)
+        vec!["system", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"]
+    }
+
+    /// Build one sample per `(system, first_atom, second_atom, cell shift)`
+    /// bond found in the triplet list that passes the center and neighbor
+    /// type filters.
+    ///
+    /// # Panics
+    ///
+    /// If one of the cutoffs is not strictly positive and finite, or if
+    /// `center_1_type` or `center_2_type` is `AtomicTypeFilter::AllOf`.
+    fn samples(&self, systems: &mut [System]) -> Result<Labels, Error> {
+        self.assert_valid_cutoffs();
+
+        let mut builder = LabelsBuilder::new(Self::sample_names());
+        for (system_i, system) in systems.iter_mut().enumerate() {
+            self.raw_triplets.ensure_computed_for_system(system)?;
+            let types = system.types()?;
+
+            // for each bond (both atoms + cell shift), the set of atomic
+            // types found among the third atoms of its triplets
+            let mut center_cache: BTreeMap<(usize, usize, [i32; 3]), BTreeSet<i32>> = BTreeMap::new();
+
+            match (&self.center_1_type, &self.center_2_type) {
+                (AtomicTypeFilter::Any, AtomicTypeFilter::Any) => {
+                    for triplet in self.raw_triplets.get_for_system(system)? {
+                        if !self.self_contributions && triplet.is_self_contrib {
+                            continue;
+                        }
+                        center_cache.entry((triplet.atom_i, triplet.atom_j, triplet.bond_cell_shift))
+                            .or_default()
+                            .insert(types[triplet.atom_k]);
+                    }
+                }
+                (AtomicTypeFilter::AllOf(_), _) | (_, AtomicTypeFilter::AllOf(_)) =>
+                    panic!("Cannot use AtomicTypeFilter::AllOf on BondCenteredSamples.center_types"),
+                (AtomicTypeFilter::Single(s1), AtomicTypeFilter::Single(s2)) => {
+                    // query the triplet list once per type present in the system
+                    let types_set = BTreeSet::from_iter(types.iter());
+                    for s3 in types_set {
+                        for triplet in self.raw_triplets.get_per_system_per_type(system, *s1, *s2, *s3)? {
+                            if !self.self_contributions && triplet.is_self_contrib {
+                                continue;
+                            }
+                            center_cache.entry((triplet.atom_i, triplet.atom_j, triplet.bond_cell_shift))
+                                .or_default()
+                                .insert(types[triplet.atom_k]);
+                        }
+                    }
+                },
+                (selection_1, selection_2) => {
+                    // generic case: try all pairs of atoms matching the filters
+                    for (center_i, &center_1_type) in types.iter().enumerate() {
+                        if !selection_1.matches(center_1_type) {
+                            continue;
+                        }
+                        for (center_j, &center_2_type) in types.iter().enumerate() {
+                            if !selection_2.matches(center_2_type) {
+                                continue;
+                            }
+                            for triplet in self.raw_triplets.get_per_system_per_center(system, center_i, center_j)? {
+                                if !self.self_contributions && triplet.is_self_contrib {
+                                    continue;
+                                }
+                                center_cache.entry((triplet.atom_i, triplet.atom_j, triplet.bond_cell_shift))
+                                    .or_default()
+                                    .insert(types[triplet.atom_k]);
+                            }
+                        }
+                    }
+                }
+            }
+
+            // keep the bonds whose neighbor types satisfy the neighbor
+            // filter; the iteration order over the map is sorted by key
+            for ((center_1, center_2, cell_shift), neighbor_types) in &center_cache {
+                let selected = match &self.neighbor_type {
+                    AtomicTypeFilter::Any => true,
+                    AtomicTypeFilter::AllOf(requirements) => requirements.is_subset(neighbor_types),
+                    AtomicTypeFilter::Single(requirement) => neighbor_types.contains(requirement),
+                    AtomicTypeFilter::OneOf(requirements) => {
+                        requirements.iter().any(|atomic_type| neighbor_types.contains(atomic_type))
+                    },
+                };
+
+                if selected {
+                    builder.add(&[
+                        system_i as i32, *center_1 as i32, *center_2 as i32,
+                        cell_shift[0], cell_shift[1], cell_shift[2],
+                    ]);
+                }
+            }
+        }
+
+        return Ok(builder.finish());
+    }
+
+    /// Build the `["sample", "system", "atom"]` gradient samples for the
+    /// given `samples`: for each bond, the two atoms of the bond plus every
+    /// third atom of its triplets accepted by the neighbor filter.
+    ///
+    /// # Panics
+    ///
+    /// If one of the cutoffs is not strictly positive and finite, or if the
+    /// names of `samples` are not the ones returned by `sample_names()`.
+    fn gradients_for(&self, systems: &mut [System], samples: &Labels) -> Result<Labels, Error> {
+        self.assert_valid_cutoffs();
+        assert_eq!(samples.names(), Self::sample_names());
+        let mut builder = LabelsBuilder::new(vec!["sample", "system", "atom"]);
+
+        // we could try to find a better way to estimate this, but in the worst
+        // case this would only over-allocate a bit
+        let average_neighbors_per_atom = 10;
+        builder.reserve(average_neighbors_per_atom * samples.count());
+
+        // with `Any` and `AllOf`, the sample already has been validated, and
+        // all known neighbors contribute
+        let all_neighbors_contribute = matches!(
+            self.neighbor_type,
+            AtomicTypeFilter::Any | AtomicTypeFilter::AllOf(_)
+        );
+
+        for (sample_i, [system_i, center_1, center_2, shift_a, shift_b, shift_c]) in samples.iter_fixed_size().enumerate() {
+            let system_i = system_i.usize();
+            let center_1 = center_1.usize();
+            let center_2 = center_2.usize();
+            let cell_shift = [shift_a.i32(), shift_b.i32(), shift_c.i32()];
+
+            let system = &mut systems[system_i];
+            self.raw_triplets.ensure_computed_for_system(system)?;
+            let types = system.types()?;
+
+            // the two atoms of the bond always contribute to the gradients
+            let mut grad_contributors = BTreeSet::new();
+            grad_contributors.insert(center_1);
+            grad_contributors.insert(center_2);
+
+            for triplet in self.raw_triplets.get_per_system_per_center(system, center_1, center_2)? {
+                // same pair of atoms, but a different periodic image
+                if triplet.bond_cell_shift != cell_shift {
+                    continue;
+                }
+
+                if all_neighbors_contribute || self.neighbor_type.matches(types[triplet.atom_k]) {
+                    grad_contributors.insert(triplet.atom_k);
+                }
+            }
+
+            for contrib in grad_contributors {
+                builder.add(&[sample_i, system_i, contrib]);
+            }
+        }
+
+        return Ok(builder.finish());
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::systems::test_utils::test_systems;
+
+    #[test]
+    fn all_samples() {
+        let mut systems = test_systems(&["CH", "water"]);
+        let raw = BATripletNeighborList {
+            cutoffs: [2.0,2.0],
+        };
+        let builder = BondCenteredSamples {
+            cutoffs: [2.0,2.0],
+            center_1_type: AtomicTypeFilter::Any,
+            center_2_type: AtomicTypeFilter::Any,
+            neighbor_type: AtomicTypeFilter::Any,
+            self_contributions: true,
+            raw_triplets: &raw,
+        };
+
+        let samples = builder.samples(&mut systems).unwrap();
+        assert_eq!(samples, Labels::new(
+            ["system", "first_atom", "second_atom", "cell_shift_a","cell_shift_b","cell_shift_c"],
+            &[
+                // CH
+                [0, 1, 0, 0,0,0],
+                // water, single cell
+                [1, 0, 1, 0,0,0], [1, 0, 2, 0,0,0], [1, 1, 2, 0,0,0],
+                // water, H-H bond through cell bounds
+                [1, 1, 2, 0,1,0]
+            ],
+        ));
+
+        let gradient_samples = builder.gradients_for(&mut systems, &samples).unwrap();
+        assert_eq!(gradient_samples, Labels::new(
+            ["sample", "system", "atom"],
+            &[
+                // gradients of atoms in CH
+                [0, 0, 0], [0, 0, 1],
+                // gradients of atoms in water
+                [1, 1, 0], [1, 1, 1], [1, 1, 2],
+                [2, 1, 0], [2, 1, 1], [2, 1, 2],
+                [3, 1, 0], [3, 1, 1], [3, 1, 2],
+                [4, 1, 0], [4, 1, 1], [4, 1, 2],
+            ],
+        ));
+    }
+
+    #[test]
+    fn filter_types_center() {
+        let mut systems = test_systems(&["CH", "water"]);
+        let raw = BATripletNeighborList {
+            cutoffs: [2.0,2.0],
+        };
+        let builder = BondCenteredSamples {
+            cutoffs: [2.0,2.0],
+            center_1_type: AtomicTypeFilter::Single(6),
+            center_2_type: AtomicTypeFilter::Single(1),
+            neighbor_type: AtomicTypeFilter::Any,
+            self_contributions: true,
+            raw_triplets: &raw,
+        };
+
+        let samples = builder.samples(&mut systems).unwrap();
+        assert_eq!(samples, Labels::new(
+            ["system", "first_atom", "second_atom", "cell_shift_a","cell_shift_b","cell_shift_c"],
+            &[[0, 1, 0, 0,0,0]],
+        ));
+
+        let gradient_samples = builder.gradients_for(&mut systems, &samples).unwrap();
+        assert_eq!(gradient_samples, Labels::new(
+            ["sample", "system", "atom"],
+            &[
+                // gradients of atoms in CH
+                [0, 0, 0], [0, 0, 1],
+            ]
+        ));
+
+        let builder = BondCenteredSamples {
+            cutoffs: [2.0,2.0],
+            center_1_type: AtomicTypeFilter::Single(1),
+            center_2_type: AtomicTypeFilter::Single(1),
+            neighbor_type: AtomicTypeFilter::Any,
+            self_contributions: true,
+            raw_triplets: &raw,
+        };
+
+        let samples = builder.samples(&mut systems).unwrap();
+        assert_eq!(samples, Labels::new(
+            ["system", "first_atom", "second_atom", "cell_shift_a","cell_shift_b","cell_shift_c"],
+            &[[1, 1, 2, 0,0,0], [1, 1, 2, 0,1,0]],
+        ));
+
+        let gradient_samples = builder.gradients_for(&mut systems, &samples).unwrap();
+        assert_eq!(gradient_samples, Labels::new(
+            ["sample", "system", "atom"],
+            &[
+                // gradients of atoms in H2O
+                [0, 1, 0], [0, 1, 1], [0, 1, 2],
+                [1, 1, 0], [1, 1, 1], [1, 1, 2],
+            ]
+        ));
+    }
+
+    #[test]
+    fn filter_neighbor_type() {
+        let mut systems = test_systems(&["CH", "water"]);
+        let raw = BATripletNeighborList {
+            cutoffs: [2.0,2.0],
+        };
+        let builder = BondCenteredSamples {
+            cutoffs: [2.0,2.0],
+            center_1_type: AtomicTypeFilter::Any,
+            center_2_type: AtomicTypeFilter::Any,
+            neighbor_type: AtomicTypeFilter::Single(1),
+            self_contributions: true,
+            raw_triplets: &raw,
+        };
+
+        let samples = builder.samples(&mut systems).unwrap();
+        assert_eq!(samples, Labels::new(
+            ["system", "first_atom", "second_atom", "cell_shift_a","cell_shift_b","cell_shift_c"],
+            &[
+                // CH
+                [0, 1, 0, 0,0,0],
+                // water, in-cell
+                [1, 0, 1, 0,0,0], [1, 0, 2, 0,0,0], [1, 1, 2, 0,0,0],
+                // water, H-H through cell boundary
+                [1, 1, 2, 0,1,0]
+            ],
+        ));
+
+        let gradient_samples = builder.gradients_for(&mut systems, &samples).unwrap();
+        assert_eq!(gradient_samples, Labels::new(
+            ["sample", "system", "atom"],
+            &[
+                // gradients of atoms in CH w.r.t H atom only
+                [0, 0, 0], [0, 0, 1],
+                // gradients of atoms in water w.r.t H atoms only
+                [1, 1, 0], [1, 1, 1], [1, 1, 2],
+                [2, 1, 0], [2, 1, 1], [2, 1, 2],
+                [3, 1, 1], [3, 1, 2],
+                [4, 1, 1], [4, 1, 2],
+            ]
+        ));
+
+        let builder = BondCenteredSamples {
+            cutoffs: [2.0,2.0],
+            center_1_type: AtomicTypeFilter::Any,
+            center_2_type: AtomicTypeFilter::Any,
+            neighbor_type: AtomicTypeFilter::OneOf(vec![1, 6]),
+            self_contributions: true,
+            raw_triplets: &raw,
+        };
+
+        let gradient_samples = builder.gradients_for(&mut systems, &samples).unwrap();
+        assert_eq!(gradient_samples, Labels::new(
+            ["sample", "system", "atom"],
+            &[
+                // gradients of atoms in CH w.r.t C and H atoms
+                [0, 0, 0], [0, 0, 1],
+                // gradients of atoms in water w.r.t H atoms only
+                [1, 1, 0], [1, 1, 1], [1, 1, 2],
+                [2, 1, 0], [2, 1, 1], [2, 1, 2],
+                [3, 1, 1], [3, 1, 2],
+                [4, 1, 1], [4, 1, 2],
+            ]
+        ));
+    }
+
+    #[test]
+    fn partial_gradients() {
+        let samples = Labels::new(["system", "first_atom", "second_atom", "cell_shift_a","cell_shift_b","cell_shift_c"], &[
+            [1, 1, 0, 0,0,0],
+            [0, 0, 1, 0,0,0],
+            [1, 1, 2, 0,0,0],
+        ]);
+
+        let mut systems = test_systems(&["CH", "water"]);
+
+        let raw = BATripletNeighborList {
+            cutoffs: [2.0,2.0],
+        };
+
+        // no atom has type -42: only the two atoms of each bond are expected
+        // in the first two samples
+        let builder = BondCenteredSamples {
+            cutoffs: [2.0,2.0],
+            center_1_type: AtomicTypeFilter::Any,
+            center_2_type: AtomicTypeFilter::Any,
+            neighbor_type: AtomicTypeFilter::Single(-42),
+            self_contributions: true,
+            raw_triplets: &raw,
+        };
+
+        let gradients = builder.gradients_for(&mut systems, &samples).unwrap();
+        assert_eq!(gradients, Labels::new(["sample", "system", "atom"], &[
+            [0, 1, 0], [0, 1, 1],
+            [1, 0, 0], [1, 0, 1],
+            [2, 1, 0], [2, 1, 1], [2, 1, 2],
+        ]));
+
+        let builder = BondCenteredSamples {
+            cutoffs: [2.0,2.0],
+            center_1_type: AtomicTypeFilter::Any,
+            center_2_type: AtomicTypeFilter::Any,
+            neighbor_type: AtomicTypeFilter::Single(1),
+            self_contributions: true,
+            raw_triplets: &raw,
+        };
+        let gradients = builder.gradients_for(&mut systems, &samples).unwrap();
+        assert_eq!(gradients, Labels::new(
+            ["sample", "system", "atom"],
+            &[
+                // gradients of first sample, O-H1 in water
+                [0, 1, 0], [0, 1, 1], [0, 1, 2],
+                // gradients of second sample, C-H in CH
+                [1, 0, 0], [1, 0, 1],
+                // gradients of third sample, H1-H2 in water
+                [2, 1, 1], [2, 1, 2],
+            ]
+        ));
+    }
+}
diff --git a/rascaline/src/labels/samples/mod.rs b/rascaline/src/labels/samples/mod.rs
index 6fdc9207c..968fb991e 100644
--- a/rascaline/src/labels/samples/mod.rs
+++ b/rascaline/src/labels/samples/mod.rs
@@ -54,6 +54,8 @@ pub trait SamplesBuilder {
 
 mod atom_centered;
 pub use self::atom_centered::AtomCenteredSamples;
+mod bond_centered;
+pub use self::bond_centered::BondCenteredSamples;
 
 mod long_range;
 pub use self::long_range::LongRangeSamplesPerAtom;