Skip to content

Commit

Permalink
fix docs and some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
liam-o-marsh committed Oct 24, 2023
1 parent fd77379 commit 30d6f47
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 104 deletions.
56 changes: 27 additions & 29 deletions rascaline/src/calculators/bondatom_neighbor_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,38 +162,40 @@ fn rotate_vector_to_z_derivatives(vec: Vector3D) -> (Matrix3,Matrix3,Matrix3) {



/// TODO fix docstring
/// This calculator computes the neighbor list for a given spherical cutoff, and
/// returns the list of distance vectors between all pairs of atoms strictly
/// inside the cutoff.
/// Manages a list of 'neighbors', where one neighbor is the center of a pair of atoms
/// (first and second atom), and the other neighbor is a simple atom (third atom).
/// Both the length of the bond and the distance between neighbors are subjected to a spherical cutoff.
///
/// Unlike the corresponding pre_calculator, this calculator focuses on storing
/// the canonical-orientation vector between bond and atom, rather than the bond vector and 'third vector'.
///
/// Users can request either a "full" neighbor list (including an entry for both
/// `i - j` pairs and `j - i` pairs) or save memory/computational by only
/// working with "half" neighbor list (only including one entry for each `i/j`
/// pair)
/// `i - j` bonds and `j - i` bonds) or save memory/computational by only
/// working with "half" neighbor list (only including one entry for each `i - j`
/// bond)
/// if memory is saved, the order of i and j is that the atom with
/// the smallest Z (or species ID in general) comes first.
///
/// Self pairs (pairs between an atom and periodic copy itself) can appear when
/// the cutoff is larger than the cell under periodic boundary conditions. Self
/// pairs with a distance of 0 are not included in this calculator, even though
/// they are required when computing SOAP.
/// The two first atoms must not be the same atom, but the third atom may be one of them,
/// if the `bond_conbtribution` option is active
/// (When periodic boundaries arise, atom which must not be the same may be images of each other.)
///
/// This sample produces a single property (`"distance"`) with three components
/// (`"vector_direction"`) containing the x, y, and z component of the vector from
/// the first atom in the pair to the second. In addition to the atom indexes,
/// the samples also contain a pair index, to be able to distinguish between
/// multiple pairs between the same atom (if the cutoff is larger than the
/// cell).
/// the center of the triplet's 'bond' to the triplet's 'third atom', in the bond's canonical orientation.
///
/// In addition to the atom indexes, the samples also contain a pair and triplet index,
/// to be able to distinguish between multiple triplets involving the same atoms
/// (which can occur in periodic boundary conditions when the cutoffs are larger than the unit cell).
#[derive(Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct BANeighborList {
/// the pre-calculator responsible for making a raw enumeration of the system's bond-atom triplets
pub raw_triplets: BATripletNeighborList,
/// Should individual atoms be considered their own neighbor? Setting this
/// to `true` will add "self pairs", i.e. pairs between an atom and itself,
/// with the distance 0. The `pair_i` of such pairs is set to -1.
/// Should we include triplets where the third atom is one of the bond's atoms?
pub bond_contribution: bool,
/// Should we compute a full neighbor list (each pair appears twice, once as
/// `i-j` and once as `j-i`), or a half neighbor list (each pair only
/// Should we compute a full neighbor list (each triplet appears twice, once as
/// `i-j +k` and once as `j-i +k`), or a half neighbor list (each triplet only
/// appears once, (such that `species_i <= species_j`))
pub use_half_enumeration: bool,
}
Expand Down Expand Up @@ -224,11 +226,7 @@ impl BANeighborList {
self.raw_triplets.third_cutoff()
}

/// validate that the cutoffs make sense
fn validate_cutoffs(&self) {
self.raw_triplets.validate_cutoffs()
}

/// a "flatter" initialisation method than the structure-based one
pub fn from_params(cutoffs: [f64;2], use_half_enumeration: bool, bond_contribution: bool) -> Self {
Self{
raw_triplets: BATripletNeighborList {
Expand All @@ -239,7 +237,8 @@ impl BANeighborList {
}
}


/// the core of the calculation being done here:
/// computing the canonical-orientation vector and distance of a given bond-atom triplet.
pub(super) fn compute_single_triplet(
triplet: &BATripletInfo,
invert: bool,
Expand Down Expand Up @@ -334,6 +333,8 @@ impl BANeighborList {
return Ok(res);
}

/// get the canonical-orientation vector and distance of a triplet
/// and store it in a TensorBlock
fn compute_single_triplet_inplace(
triplet: &BATripletInfo,
out_block: &mut TensorBlockRefMut,
Expand Down Expand Up @@ -400,8 +401,6 @@ impl CalculatorBase for BANeighborList {
}

fn keys(&self, systems: &mut [System]) -> Result<Labels, Error> {
self.validate_cutoffs();

let mut all_species_triplets = BTreeSet::new();
for system in systems {
self.raw_triplets.ensure_computed_for_system(system)?;
Expand Down Expand Up @@ -437,7 +436,6 @@ impl CalculatorBase for BANeighborList {
}

fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result<Vec<Labels>, Error> {
self.validate_cutoffs();
let mut results = Vec::new();

for &[species_first, species_second, species_third] in keys.iter_fixed_size() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@ use super::spherical_expansion_pair::{
PairContribution
};

/// Parameters for spherical expansion calculator.
/// Parameters for spherical expansion calculator for bond-centered neighbor densities.
///
/// The spherical expansion is at the core of representations in the SOAP
/// (The spherical expansion is at the core of representations in the SOAP
/// (Smooth Overlap of Atomic Positions) family. See [this review
/// article](https://doi.org/10.1063/1.5090481) for more information on the SOAP
/// representation, and [this paper](https://doi.org/10.1063/5.0044689) for
/// information on how it is implemented in rascaline.
/// information on how it is implemented in rascaline.)
///
/// This calculator is only needed to characterize local environments that are centered
/// on a pair of atoms rather than a single one.
#[derive(Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct SphericalExpansionForBondsParameters {
/// Spherical cutoffs to use for atomic environments
pub(super) cutoffs: [f64;2], // bond_, third_cutoff
//pub bond_cutoff: f64,
//pub third_cutoff: f64,
//pub compute_both_sides: bool,
pub(super) cutoffs: [f64;2],
/// Number of radial basis function to use in the expansion
pub max_radial: usize,
/// Number of spherical harmonics to use in the expansion
Expand Down Expand Up @@ -277,26 +277,8 @@ impl SphericalExpansionForBondType {
let max_radial = self.parameters.max_radial;
let species = system.species().unwrap();

//let _ = system.triplets()?;
// for s3 in s3_list{
// let triplets = system.triplets_with_species(s1, s2, *s3)?;
// if let Err(error) = triplets {
// return Err(error);
// }
// }

let pre_iter = s3_list.iter().flat_map(|s3|{
// let all_triplets = system.triplets().unwrap();
// let triplets = system.triplets_with_species(s1, s2, *s3).unwrap();
// triplets.iter().map(|triplet_i|{
// let triplet = &all_triplets[*triplet_i];
// #[cfg(debug_assertions)]{
// assert_eq!(species[triplet.bond.first],s1);
// assert_eq!(species[triplet.bond.second],s2);
// assert_eq!(species[triplet.third],*s3);
// }
// (*triplet_i,triplet)
// })
self.distance_calculator.raw_triplets.get_per_system_per_species(system,s1,s2,*s3,true).unwrap().into_iter()
}).flat_map(|triplet| {
let invert: &'static [bool] = {
Expand Down
5 changes: 1 addition & 4 deletions rascaline/src/labels/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,7 @@ impl KeysBuilder for CenterSingleNeighborsSpeciesKeys {
/// and the species of a third, neighbor atom, within a cutoff of the first two.
pub struct TwoCentersSingleNeighborsSpeciesKeys<'a> {
/// Spherical cutoff to use when searching for neighbors around an atom
pub(crate) cutoffs: [f64;2], // bond_, third_cutoff
//pub bond_cutoff: f64,
//pub third_cutoff: f64,
pub(crate) cutoffs: [f64;2],
/// Should we consider an atom to be it's own neighbor or not?
pub self_contributions: bool,
pub raw_triplets: &'a BATripletNeighborList,
Expand All @@ -123,7 +121,6 @@ impl<'a> KeysBuilder for TwoCentersSingleNeighborsSpeciesKeys<'a> {

let mut all_species_triplets = BTreeSet::new();
for system in systems {
//system.compute_triplet_neighbors(self.bond_cutoff(), self.third_cutoff())?;
self.raw_triplets.ensure_computed_for_system(system)?;

let species = system.species()?;
Expand Down
4 changes: 1 addition & 3 deletions rascaline/src/labels/samples/bond_centered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ use super::super::super::pre_calculators::BATripletNeighborList;
/// optionally filtering on the neighbor atom species.
pub struct BondCenteredSamples<'a> {
/// spherical cutoff radius used to construct the atom-centered environments
pub cutoffs: [f64;2], // bond_, third_cutoff
//pub bond_cutoff: f64,
//pub third_cutoff: f64,
pub cutoffs: [f64;2],
/// Filter for the central atom species
pub species_center_1: SpeciesFilter,
pub species_center_2: SpeciesFilter,
Expand Down
62 changes: 19 additions & 43 deletions rascaline/src/pre_calculators/bond_atom_triplet_neighbors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,17 @@ pub struct BATripletInfo{


/// Manages a list of 'neighbors', where one neighbor is the center of a pair of atoms
/// (first and second atom), and the other neighbot is a simple atom (third atom).
/// (first and second atom), and the other neighbor is a simple atom (third atom).
/// Both the length of the bond and the distance between neighbors are subjected to a spherical cutoff.
/// This pre-calculator can compute and cache this list within a given system
/// (with two distance vectors per entry: one within the bond and one between neighbors).
/// Then, it can re-enumerate those neighbors, either for a full system, or with restrictions on the atoms or their species.
///
/// Users can request either a "full" neighbor list (including an entry for both
/// `i - j` bonds and `j - i` bonds) or save memory/computational by only
/// working with "half" neighbor list (only including one entry for each `i - j`
/// bond)
/// This saves memory/computational power by only working with "half" neighbor list
/// This is done by only including one entry for each `i - j` bond, not both `i - j` and `j - i`.
/// The order of i and j is that the atom with the smallest Z (or species ID in general) comes first.
///
/// The two first atoms may not be the same atom, but the third atom may be one of them.
/// The two first atoms must not be the same atom, but the third atom may be one of them.
/// (When periodic boundaries arise, the two first atoms may be images of each other.)
#[derive(Debug,Clone)]
#[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
Expand All @@ -72,6 +71,7 @@ pub struct BATripletNeighborList {
pub cutoffs: [f64;2], // bond_, third_cutoff
}

/// the internal function doing the triplet computing itself
fn list_raw_triplets(system: &mut dyn SystemBase, bond_cutoff: f64, third_cutoff: f64) -> Result<Vec<BATripletInfo>,Error> {
system.compute_neighbors(bond_cutoff)?;
let bonds = system.pairs()?.to_owned();
Expand Down Expand Up @@ -206,53 +206,18 @@ impl BATripletNeighborList {
}

/// validate that the cutoffs make sense
pub(crate) fn validate_cutoffs(&self) { // TODO: un-pub
pub fn validate_cutoffs(&self) {
let (bond_cutoff, third_cutoff) = (self.bond_cutoff(), self.third_cutoff());
assert!(bond_cutoff > 0.0 && bond_cutoff.is_finite());
assert!(third_cutoff >= bond_cutoff && third_cutoff.is_finite());
}

/// internal function that deletages computing the triplets, but deals with storing them for a given system.
fn do_compute_for_system(&self, system: &mut System) -> Result<(), Error> {
// let triplets_raw = TripletNeighborsList::for_system(&**system, self.bond_cutoff(), self.third_cutoff())?;
// let triplets = triplets_raw.triplets();
let triplets = list_raw_triplets(&mut **system, self.cutoffs[0], self.cutoffs[1])?;

//let species = system.species()?;

// let triplets = triplets.into_iter().map(|tr|{
// if species[tr.atom_i] <= species[tr.atom_j] {
// tr
// } else {
// BATripletInfo{
// atom_i: tr.atom_j,
// atom_j: tr.atom_i,
// atom_k: tr.atom_k,
// bond_i: tr.bond_i,
// triplet_i: tr.triplet_i,
// is_self_contrib: tr.is_self_contrib,
// bond_vector: tr.bond_vector.map(|v|-v),
// third_vector: tr.third_vector,
// }
// // BATriplet{
// // bond_i: tr.bond_i,
// // bond: Pair {
// // first: tr.bond.second,
// // second: tr.bond.first,
// // distance: tr.bond.distance,
// // vector: -tr.bond.vector,
// // cell_shift_indices: {let c = &tr.bond.cell_shift_indices; [-c[0],-c[1],-c[2]]},
// // },
// // third: tr.third, third_vector: tr.third_vector,
// // is_self_contribution: tr.is_self_contribution,
// // distance: tr.distance,
// // cell_shift_indices: {
// // let (c1,c2) = (&tr.cell_shift_indices,&tr.bond.cell_shift_indices);
// // [c1[0]-c2[0],c1[1]-c2[1],c1[2]-c2[2]]
// // },
// // }
// }
// }).collect::<Vec<_>>();

let components = [Labels::new(
["vector_pair_component"],
&[[0x00_i32],[0x01],[0x02], [0x10],[0x11],[0x12]],
Expand Down Expand Up @@ -332,7 +297,10 @@ impl BATripletNeighborList {
Ok(())
}

/// check that the precalculator has computed its values for a given system,
/// and if not, compute them.
pub fn ensure_computed_for_system(&self, system: &mut System) -> Result<(),Error> {
self.validate_cutoffs();
'cached_path: {
let cutoffs2: &[f64;2] = match system.data(Self::CACHE_NAME_ATTR.into()) {
Some(cutoff) => cutoff.downcast_ref()
Expand All @@ -349,6 +317,8 @@ impl BATripletNeighborList {
return self.do_compute_for_system(system);
}

/// for a given system, get a copy of all the bond-atom triplets.
/// optionally include the vectors tied to these triplets
pub fn get_for_system(&self, system: &System, with_vectors: bool) -> Result<Vec<BATripletInfo>, Error>{
let block: &TensorBlock = system.data(&Self::CACHE_NAME1)
.ok_or_else(||Error::Internal("triplets not yet computed".into()))?
Expand All @@ -375,6 +345,9 @@ impl BATripletNeighborList {
Ok(res)
}

/// for a given system, get a copy of the bond-atom triplets of given set of atomic species.
/// optionally include the vectors tied to these triplets
/// note: inverting s1 and s2 does not change the result, and the returned triplets may have these species swapped
pub fn get_per_system_per_species(&self, system: &System, s1:i32,s2:i32,s3:i32, with_vectors: bool) -> Result<Vec<BATripletInfo>, Error>{
let block: &TensorBlock = system.data(&Self::CACHE_NAME1)
.ok_or_else(||Error::Internal("triplets not yet computed".into()))?
Expand Down Expand Up @@ -414,6 +387,9 @@ impl BATripletNeighborList {
Ok(res)
}

/// for a given system, get a copy of the bond-atom triplets of given set of atomic species.
/// optionally include the vectors tied to these triplets
/// note: the triplets may be for (c2,c1) rather than (c1,c2)
pub fn get_per_system_per_center(&self, system: &System, c1:usize,c2:usize, with_vectors: bool) -> Result<Vec<BATripletInfo>, Error>{
{
let sz = system.size()?;
Expand Down

0 comments on commit 30d6f47

Please sign in to comment.