From f3bd6388fe8fba9e2a24f7ab1fa5f85459122757 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Mon, 8 Jan 2024 17:22:27 +0100 Subject: [PATCH] Update all calculators to remove neighbors list workaround --- rascaline/src/calculators/neighbor_list.rs | 104 ++++++++++++++---- .../calculators/soap/spherical_expansion.rs | 8 +- .../soap/spherical_expansion_pair.rs | 10 +- 3 files changed, 88 insertions(+), 34 deletions(-) diff --git a/rascaline/src/calculators/neighbor_list.rs b/rascaline/src/calculators/neighbor_list.rs index 0834713c0..8537a679d 100644 --- a/rascaline/src/calculators/neighbor_list.rs +++ b/rascaline/src/calculators/neighbor_list.rs @@ -39,7 +39,7 @@ pub struct NeighborList { pub full_neighbor_list: bool, /// Should individual atoms be considered their own neighbor? Setting this /// to `true` will add "self pairs", i.e. pairs between an atom and itself, - /// with the distance 0. The `pair_id` of such pairs is set to -1. + /// with the distance 0. pub self_pairs: bool, } @@ -423,7 +423,8 @@ impl FullNeighborList { let cell_c = pair.cell_shift_indices[2]; if species_first == species_second { - // same species for both atoms in the pair + // same species for both atoms in the pair, add the pair + // twice in both directions. if species[pair.first] == species_first.i32() && species[pair.second] == species_second.i32() { builder.add(&[ LabelValue::from(system_i), @@ -434,18 +435,14 @@ impl FullNeighborList { LabelValue::from(cell_c), ]); - if pair.first != pair.second { - // if the pair is between two different atoms, - // also add the reversed (second -> first) pair. - builder.add(&[ - LabelValue::from(system_i), - LabelValue::from(pair.second), - LabelValue::from(pair.first), - LabelValue::from(-cell_a), - LabelValue::from(-cell_b), - LabelValue::from(-cell_c), - ]); - } + builder.add(&[ + LabelValue::from(system_i), + LabelValue::from(pair.second), + LabelValue::from(pair.first), + LabelValue::from(-cell_a), + LabelValue::from(-cell_b), + LabelValue::from(-cell_c), + ]); } } else { // different species, find the right order for the pair @@ -501,6 +498,11 @@ impl FullNeighborList { let species = system.species()?; for pair in system.pairs()? { + if pair.first == pair.second { + // self pairs should not be part of the neighbors list + assert_ne!(pair.cell_shift_indices, [0, 0, 0]); + } + let first_block_i = descriptor.keys().position(&[ species[pair.first].into(), species[pair.second].into() ]); @@ -565,11 +567,6 @@ impl FullNeighborList { } } - if pair.first == pair.second { - // do not duplicate self pairs - continue; - } - // then the pair second -> first if let Some(second_block_i) = second_block_i { let mut block = descriptor.block_mut_by_id(second_block_i); @@ -764,6 +761,75 @@ mod tests { assert_relative_eq!(array, expected, max_relative=1e-6); } + #[test] + fn periodic_neighbor_list() { + let mut calculator = Calculator::from(Box::new(NeighborList{ + cutoff: 12.0, + full_neighbor_list: false, + self_pairs: false, + }) as Box); + + let mut systems = test_systems(&["CH"]); + + let descriptor = calculator.compute(&mut systems, Default::default()).unwrap(); + assert_eq!(*descriptor.keys(), Labels::new( + ["species_first_atom", "species_second_atom"], + &[[1, 1], [1, 6], [6, 6]] + )); + + // H-H block + let block = descriptor.block_by_id(0); + assert_eq!(block.samples(), Labels::new( + ["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"], + // the pairs only differ in cell shifts + &[[0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 1, 0], [0, 1, 1, 1, 0, 0]] + )); + + let array = block.values().to_array(); + let expected = &ndarray::arr3(&[ + [[0.0], [0.0], [10.0]], + [[0.0], [10.0], [0.0]], + [[10.0], [0.0], [0.0]], + ]).into_dyn(); + assert_relative_eq!(array, expected, max_relative=1e-6); + + // now a full NL + let mut calculator = Calculator::from(Box::new(NeighborList{ + cutoff: 12.0, + full_neighbor_list: true, + self_pairs: false, + }) as Box); + + let descriptor = calculator.compute(&mut systems, Default::default()).unwrap(); + assert_eq!(*descriptor.keys(), Labels::new( + ["species_first_atom", "species_second_atom"], + &[[1, 1], [1, 6], [6, 1], [6, 6]] + )); + + // H-H block + let block = descriptor.block_by_id(0); + assert_eq!(block.samples(), Labels::new( + ["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"], + // twice as many pairs + &[ + [0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 0, -1], + [0, 1, 1, 0, 1, 0], [0, 1, 1, 0, -1, 0], + [0, 1, 1, 1, 0, 0], [0, 1, 1, -1, 0, 0], + ] + )); + + let array = block.values().to_array(); + let expected = &ndarray::arr3(&[ + [[0.0], [0.0], [10.0]], + [[0.0], [0.0], [-10.0]], + [[0.0], [10.0], [0.0]], + [[0.0], [-10.0], [0.0]], + [[10.0], [0.0], [0.0]], + [[-10.0], [0.0], [0.0]], + ]).into_dyn(); + assert_relative_eq!(array, expected, max_relative=1e-6); + } + #[test] fn finite_differences_positions() { // half neighbor list diff --git a/rascaline/src/calculators/soap/spherical_expansion.rs b/rascaline/src/calculators/soap/spherical_expansion.rs index 63a8ddf19..aac10d3ca 100644 --- a/rascaline/src/calculators/soap/spherical_expansion.rs +++ b/rascaline/src/calculators/soap/spherical_expansion.rs @@ -241,12 +241,6 @@ impl SphericalExpansion { } } - if pair.first == pair.second { - // do not compute for the reversed pair if the pair is - // between an atom and its image - continue; - } - if let Some(mapped_center) = result.centers_mapping[pair.second] { // add the pair contribution to the atomic environnement // corresponding to the **second** atom in the pair @@ -778,7 +772,7 @@ mod tests { fn parameters() -> SphericalExpansionParameters { SphericalExpansionParameters { - cutoff: 3.5, + cutoff: 7.8, max_radial: 6, max_angular: 6, atomic_gaussian_width: 0.3, diff --git a/rascaline/src/calculators/soap/spherical_expansion_pair.rs b/rascaline/src/calculators/soap/spherical_expansion_pair.rs index 737b68797..898a31523 100644 --- a/rascaline/src/calculators/soap/spherical_expansion_pair.rs +++ b/rascaline/src/calculators/soap/spherical_expansion_pair.rs @@ -755,13 +755,7 @@ impl CalculatorBase for SphericalExpansionByPair { } } - // also check for the block with a reversed pair, except if - // we are handling a pair between an atom and it's own - // periodic image - if pair.first == pair.second { - continue; - } - + // also check for the block with a reversed pair contribution.inverse_pair(&self.m_1_pow_l); for spherical_harmonics_l in 0..=self.parameters.max_angular { @@ -817,7 +811,7 @@ mod tests { fn parameters() -> SphericalExpansionParameters { SphericalExpansionParameters { - cutoff: 3.5, + cutoff: 7.3, max_radial: 6, max_angular: 6, atomic_gaussian_width: 0.3,