From b60fbc5cf1b90b6cd65bc5d2dd53f2b20530b3aa Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Fri, 13 Oct 2023 18:24:00 +0200 Subject: [PATCH] Allow to add data to a System --- rascaline-c-api/src/calculator.rs | 4 +- rascaline-c-api/src/system.rs | 6 +- rascaline/benches/lode-spherical-expansion.rs | 14 +---- rascaline/benches/soap-power-spectrum.rs | 15 +---- rascaline/benches/soap-spherical-expansion.rs | 14 +---- rascaline/examples/compute-soap.rs | 8 +-- rascaline/examples/profiling.rs | 7 +-- rascaline/src/calculator.rs | 6 +- .../src/calculators/atomic_composition.rs | 8 +-- rascaline/src/calculators/dummy_calculator.rs | 8 +-- .../calculators/lode/spherical_expansion.rs | 12 ++-- rascaline/src/calculators/mod.rs | 8 +-- rascaline/src/calculators/neighbor_list.rs | 20 +++---- .../src/calculators/soap/power_spectrum.rs | 8 +-- .../src/calculators/soap/radial_spectrum.rs | 8 +-- .../calculators/soap/spherical_expansion.rs | 20 +++---- .../soap/spherical_expansion_pair.rs | 10 ++-- rascaline/src/calculators/sorted_distances.rs | 8 +-- rascaline/src/calculators/tests_utils.rs | 30 +++++----- rascaline/src/labels/keys.rs | 10 ++-- rascaline/src/labels/samples/atom_centered.rs | 4 +- rascaline/src/labels/samples/long_range.rs | 4 +- rascaline/src/labels/samples/mod.rs | 4 +- rascaline/src/lib.rs | 2 +- rascaline/src/systems/chemfiles.rs | 33 +++++++++-- rascaline/src/systems/mod.rs | 57 ++++++++++++++++++- rascaline/src/systems/simple_system.rs | 8 +-- rascaline/src/systems/test_utils.rs | 4 +- rascaline/src/tutorials/moments/moments.rs | 8 +-- .../src/tutorials/moments/s1_scaffold.rs | 8 +-- .../src/tutorials/moments/s2_metadata.rs | 8 +-- .../src/tutorials/moments/s3_compute_1.rs | 8 +-- .../src/tutorials/moments/s3_compute_2.rs | 8 +-- .../src/tutorials/moments/s3_compute_3.rs | 8 +-- .../src/tutorials/moments/s3_compute_4.rs | 8 +-- .../src/tutorials/moments/s3_compute_5.rs | 8 +-- rascaline/tests/data/mod.rs | 4 +- rascaline/tests/lode-madelung.rs | 18 +++--- rascaline/tests/lode-vs-soap.rs | 2 +- 39 files changed, 237 insertions(+), 191 deletions(-) diff --git a/rascaline-c-api/src/calculator.rs b/rascaline-c-api/src/calculator.rs index 7bd3bfad4..4a836f462 100644 --- a/rascaline-c-api/src/calculator.rs +++ b/rascaline-c-api/src/calculator.rs @@ -402,7 +402,7 @@ pub unsafe extern fn rascal_calculator_compute( } check_pointers!(calculator, descriptor, systems); - // Create a Vec> from the passed systems + // Create a Vec from the passed systems let c_systems = if systems_count == 0 { &mut [] } else { @@ -411,7 +411,7 @@ pub unsafe extern fn rascal_calculator_compute( }; let mut systems = Vec::with_capacity(c_systems.len()); for system in c_systems { - systems.push(Box::new(system) as Box); + systems.push(System::new(system)); } let c_gradients = if options.gradients_count == 0 { diff --git a/rascaline-c-api/src/system.rs b/rascaline-c-api/src/system.rs index 61cbb8a35..c29c8bd1f 100644 --- a/rascaline-c-api/src/system.rs +++ b/rascaline-c-api/src/system.rs @@ -3,7 +3,7 @@ use std::ffi::CStr; use rascaline::types::{Vector3D, Matrix3}; use rascaline::systems::{SimpleSystem, Pair, UnitCell}; -use rascaline::{Error, System}; +use rascaline::{Error, SystemBase}; use crate::RASCAL_SYSTEM_ERROR; @@ -111,7 +111,7 @@ pub struct rascal_system_t { unsafe impl Send for rascal_system_t {} unsafe impl Sync for rascal_system_t {} -impl<'a> System for &'a mut rascal_system_t { +impl<'a> SystemBase for &'a mut rascal_system_t { fn size(&self) -> Result { let function = self.size.ok_or_else(|| Error::External { status: RASCAL_SYSTEM_ERROR, @@ -442,7 +442,7 @@ pub unsafe extern fn rascal_basic_systems_read( catch_unwind(move || { check_pointers!(path, systems, count); let path = CStr::from_ptr(path).to_str()?; - let simple_systems = rascaline::systems::read_from_file(path)?; + let simple_systems = rascaline::systems::read_simple_systems_from_file(path)?; let mut c_systems = Vec::with_capacity(simple_systems.len()); for system in simple_systems { diff --git a/rascaline/benches/lode-spherical-expansion.rs b/rascaline/benches/lode-spherical-expansion.rs index 60bcd2d39..eb338506a 100644 --- a/rascaline/benches/lode-spherical-expansion.rs +++ b/rascaline/benches/lode-spherical-expansion.rs @@ -1,24 +1,16 @@ #![allow(clippy::needless_return)] -use rascaline::{Calculator, System, CalculationOptions}; +use rascaline::{Calculator, CalculationOptions}; use criterion::{BenchmarkGroup, Criterion, measurement::WallTime, SamplingMode}; use criterion::{criterion_group, criterion_main}; -fn load_systems(path: &str) -> Vec> { - let systems = rascaline::systems::read_from_file(format!("benches/data/{}", path)) - .expect("failed to read file"); - - return systems.into_iter() - .map(|s| Box::new(s) as Box) - .collect() -} - fn run_spherical_expansion(mut group: BenchmarkGroup, path: &str, gradients: bool, test_mode: bool, ) { - let mut systems = load_systems(path); + let mut systems = rascaline::systems::read_from_file(format!("benches/data/{}", path)) + .expect("failed to read file"); if test_mode { // Reduce the time/RAM required to test the benchmarks code. diff --git a/rascaline/benches/soap-power-spectrum.rs b/rascaline/benches/soap-power-spectrum.rs index ae0a34772..d3cf8a4c9 100644 --- a/rascaline/benches/soap-power-spectrum.rs +++ b/rascaline/benches/soap-power-spectrum.rs @@ -1,27 +1,18 @@ #![allow(clippy::needless_return)] -use rascaline::{Calculator, System, CalculationOptions}; +use rascaline::{Calculator, CalculationOptions}; use criterion::{BenchmarkGroup, Criterion, measurement::WallTime, SamplingMode}; use criterion::{criterion_group, criterion_main}; - -fn load_systems(path: &str) -> Vec> { - let systems = rascaline::systems::read_from_file(format!("benches/data/{}", path)) - .expect("failed to read file"); - - return systems.into_iter() - .map(|s| Box::new(s) as Box) - .collect() -} - fn run_soap_power_spectrum( mut group: BenchmarkGroup, path: &str, gradients: bool, test_mode: bool, ) { - let mut systems = load_systems(path); + let mut systems = rascaline::systems::read_from_file(format!("benches/data/{}", path)) + .expect("failed to read file"); if test_mode { // Reduce the time/RAM required to test the benchmarks code. diff --git a/rascaline/benches/soap-spherical-expansion.rs b/rascaline/benches/soap-spherical-expansion.rs index a460ca622..3621fda2f 100644 --- a/rascaline/benches/soap-spherical-expansion.rs +++ b/rascaline/benches/soap-spherical-expansion.rs @@ -1,24 +1,16 @@ #![allow(clippy::needless_return)] -use rascaline::{Calculator, System, CalculationOptions}; +use rascaline::{Calculator, CalculationOptions}; use criterion::{BenchmarkGroup, Criterion, measurement::WallTime, SamplingMode}; use criterion::{criterion_group, criterion_main}; -fn load_systems(path: &str) -> Vec> { - let systems = rascaline::systems::read_from_file(format!("benches/data/{}", path)) - .expect("failed to read file"); - - return systems.into_iter() - .map(|s| Box::new(s) as Box) - .collect() -} - fn run_spherical_expansion(mut group: BenchmarkGroup, path: &str, gradients: bool, test_mode: bool, ) { - let mut systems = load_systems(path); + let mut systems = rascaline::systems::read_from_file(format!("benches/data/{}", path)) + .expect("failed to read file"); if test_mode { // Reduce the time/RAM required to test the benchmarks code. diff --git a/rascaline/examples/compute-soap.rs b/rascaline/examples/compute-soap.rs index 42b9b1221..674e7a4de 100644 --- a/rascaline/examples/compute-soap.rs +++ b/rascaline/examples/compute-soap.rs @@ -1,14 +1,10 @@ use metatensor::Labels; -use rascaline::{Calculator, System, CalculationOptions}; +use rascaline::{Calculator, CalculationOptions}; fn main() -> Result<(), Box> { // load the systems from command line argument let path = std::env::args().nth(1).expect("expected a command line argument"); - let systems = rascaline::systems::read_from_file(path)?; - // transform systems into a vector of trait objects (`Vec>`) - let mut systems = systems.into_iter() - .map(|s| Box::new(s) as Box) - .collect::>(); + let mut systems = rascaline::systems::read_from_file(path)?; // pass hyper-parameters as JSON let parameters = r#"{ diff --git a/rascaline/examples/profiling.rs b/rascaline/examples/profiling.rs index c172e610c..ce7f6111b 100644 --- a/rascaline/examples/profiling.rs +++ b/rascaline/examples/profiling.rs @@ -1,5 +1,5 @@ use metatensor::{TensorMap, Labels}; -use rascaline::{Calculator, System, CalculationOptions}; +use rascaline::{Calculator, CalculationOptions}; fn main() -> Result<(), Box> { let path = std::env::args().nth(1).expect("expected a command line argument"); @@ -28,10 +28,7 @@ fn main() -> Result<(), Box> { /// Compute SOAP power spectrum, this is the same code as the 'compute-soap' /// example fn compute_soap(path: &str) -> Result> { - let systems = rascaline::systems::read_from_file(path)?; - let mut systems = systems.into_iter() - .map(|s| Box::new(s) as Box) - .collect::>(); + let mut systems = rascaline::systems::read_from_file(path)?; let parameters = r#"{ "cutoff": 5.0, diff --git a/rascaline/src/calculator.rs b/rascaline/src/calculator.rs index 71e1d8fd6..077a4b561 100644 --- a/rascaline/src/calculator.rs +++ b/rascaline/src/calculator.rs @@ -285,7 +285,7 @@ impl Calculator { } #[time_graph::instrument(name="Calculator::prepare")] - fn prepare(&mut self, systems: &mut [Box], options: CalculationOptions) -> Result { + fn prepare(&mut self, systems: &mut [System], options: CalculationOptions) -> Result { let default_keys = self.implementation.keys(systems)?; let keys = match options.selected_keys { @@ -503,14 +503,14 @@ impl Calculator { /// features. pub fn compute( &mut self, - systems: &mut [Box], + systems: &mut [System], options: CalculationOptions, ) -> Result { let mut native_systems; let systems = if options.use_native_system { native_systems = Vec::with_capacity(systems.len()); for system in systems { - native_systems.push(Box::new(SimpleSystem::try_from(&**system)?) as Box); + native_systems.push(System::new(SimpleSystem::try_from(&**system)?) as System); } &mut native_systems } else { diff --git a/rascaline/src/calculators/atomic_composition.rs b/rascaline/src/calculators/atomic_composition.rs index 41be866dd..530db5ba2 100644 --- a/rascaline/src/calculators/atomic_composition.rs +++ b/rascaline/src/calculators/atomic_composition.rs @@ -36,7 +36,7 @@ impl CalculatorBase for AtomicComposition { &[] } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { return CenterTypesKeys.keys(systems); } @@ -48,7 +48,7 @@ impl CalculatorBase for AtomicComposition { return vec!["system", "atom"]; } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["center_type"]); let mut samples = Vec::new(); for [center_type_key] in keys.iter_fixed_size() { @@ -84,7 +84,7 @@ impl CalculatorBase for AtomicComposition { &self, keys: &Labels, _samples: &[Labels], - _systems: &mut [Box], + _systems: &mut [System], ) -> Result, Error> { // Positions/cell gradients of the composition are zero everywhere. // Therefore, we only return a vector of empty labels (one for each key). @@ -110,7 +110,7 @@ impl CalculatorBase for AtomicComposition { fn compute( &mut self, - systems: &mut [Box], + systems: &mut [System], descriptor: &mut TensorMap, ) -> Result<(), Error> { assert_eq!(descriptor.keys().names(), ["center_type"]); diff --git a/rascaline/src/calculators/dummy_calculator.rs b/rascaline/src/calculators/dummy_calculator.rs index 89644e1f0..e9aa94738 100644 --- a/rascaline/src/calculators/dummy_calculator.rs +++ b/rascaline/src/calculators/dummy_calculator.rs @@ -43,7 +43,7 @@ impl CalculatorBase for DummyCalculator { std::slice::from_ref(&self.cutoff) } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { return CenterTypesKeys.keys(systems); } @@ -51,7 +51,7 @@ impl CalculatorBase for DummyCalculator { AtomCenteredSamples::sample_names() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["center_type"]); let mut samples = Vec::new(); for [center_type] in keys.iter_fixed_size() { @@ -75,7 +75,7 @@ impl CalculatorBase for DummyCalculator { } } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { debug_assert_eq!(keys.count(), samples.len()); let mut gradient_samples = Vec::new(); for ([center_type], samples) in keys.iter_fixed_size().zip(samples) { @@ -110,7 +110,7 @@ impl CalculatorBase for DummyCalculator { } #[time_graph::instrument(name = "DummyCalculator::compute")] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { if self.name.contains("log-test-info:") { info!("{}", self.name); } else if self.name.contains("log-test-warn:") { diff --git a/rascaline/src/calculators/lode/spherical_expansion.rs b/rascaline/src/calculators/lode/spherical_expansion.rs index 6599458cd..417b18a13 100644 --- a/rascaline/src/calculators/lode/spherical_expansion.rs +++ b/rascaline/src/calculators/lode/spherical_expansion.rs @@ -400,7 +400,7 @@ impl LodeSphericalExpansion { /// By symmetry, this only affects the (l, m) = (0, 0) components of the /// projection coefficients and only the neighbor type blocks that agrees /// with the center atom. - fn do_center_contribution(&mut self, systems: &mut[Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn do_center_contribution(&mut self, systems: &mut[System], descriptor: &mut TensorMap) -> Result<(), Error> { let mut radial_integral = self.radial_integral.get_or(|| { let radial_integral = LodeRadialIntegralCache::new( self.parameters.radial_basis.clone(), @@ -470,7 +470,7 @@ impl CalculatorBase for LodeSphericalExpansion { &[] } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { let builder = AllTypesPairsKeys {}; let keys = builder.keys(systems)?; @@ -488,7 +488,7 @@ impl CalculatorBase for LodeSphericalExpansion { LongRangeSamplesPerAtom::sample_names() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]); // only compute the samples once for each `center_type, neighbor_type`, @@ -527,7 +527,7 @@ impl CalculatorBase for LodeSphericalExpansion { } } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]); assert_eq!(keys.count(), samples.len()); @@ -588,7 +588,7 @@ impl CalculatorBase for LodeSphericalExpansion { } #[time_graph::instrument(name = "LodeSphericalExpansion::compute")] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { assert_eq!(descriptor.keys().names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]); self.do_center_contribution(systems, descriptor)?; @@ -949,7 +949,7 @@ mod tests { ]); crate::calculators::tests_utils::compute_partial( - calculator, &mut [Box::new(system)], &keys, &samples, &properties + calculator, &mut [System::new(system)], &keys, &samples, &properties ); } diff --git a/rascaline/src/calculators/mod.rs b/rascaline/src/calculators/mod.rs index 5da33e0cf..851cb5deb 100644 --- a/rascaline/src/calculators/mod.rs +++ b/rascaline/src/calculators/mod.rs @@ -22,14 +22,14 @@ pub trait CalculatorBase: std::panic::RefUnwindSafe { fn cutoffs(&self) -> &[f64]; /// Get the set of keys for this calculator and the given systems - fn keys(&self, systems: &mut [Box]) -> Result; + fn keys(&self, systems: &mut [System]) -> Result; /// Get the names used for sample labels by this calculator fn sample_names(&self) -> Vec<&str>; /// Get the full list of samples this calculator would create for the given /// systems. This function should return one set of samples for each key. - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error>; + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error>; /// Can this calculator compute gradients with respect to the `parameter`? /// Right now, `parameter` can be either `"positions"`, `"strain"` or @@ -43,7 +43,7 @@ pub trait CalculatorBase: std::panic::RefUnwindSafe { /// /// If the gradients with respect to positions are not available, this /// function should return an error. - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error>; + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error>; /// Get the components this calculator computes for each key. fn components(&self, keys: &Labels) -> Vec>; @@ -66,7 +66,7 @@ pub trait CalculatorBase: std::panic::RefUnwindSafe { /// block if they are supported according to /// [`CalculatorBase::supports_gradient`], and the users requested them as /// part of the calculation options. - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error>; + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error>; } diff --git a/rascaline/src/calculators/neighbor_list.rs b/rascaline/src/calculators/neighbor_list.rs index 668d96424..5fc5dbadf 100644 --- a/rascaline/src/calculators/neighbor_list.rs +++ b/rascaline/src/calculators/neighbor_list.rs @@ -65,7 +65,7 @@ impl CalculatorBase for NeighborList { std::slice::from_ref(&self.cutoff) } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { assert!(self.cutoff > 0.0 && self.cutoff.is_finite()); if self.full_neighbor_list { @@ -85,7 +85,7 @@ impl CalculatorBase for NeighborList { return vec!["system", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"]; } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { assert!(self.cutoff > 0.0 && self.cutoff.is_finite()); if self.full_neighbor_list { @@ -109,7 +109,7 @@ impl CalculatorBase for NeighborList { } } - fn positions_gradient_samples(&self, _keys: &Labels, samples: &[Labels], _systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, _keys: &Labels, samples: &[Labels], _systems: &mut [System]) -> Result, Error> { let mut results = Vec::new(); for block_samples in samples { @@ -147,7 +147,7 @@ impl CalculatorBase for NeighborList { } #[time_graph::instrument(name = "NeighborList::compute")] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { if self.full_neighbor_list { FullNeighborList { cutoff: self.cutoff, @@ -171,7 +171,7 @@ struct HalfNeighborList { } impl HalfNeighborList { - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { let mut all_types_pairs = BTreeSet::new(); for system in systems { system.compute_neighbors(self.cutoff)?; @@ -199,7 +199,7 @@ impl HalfNeighborList { return Ok(keys.finish()); } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { let mut results = Vec::new(); for [first_atom_type, second_atom_type] in keys.iter_fixed_size() { @@ -267,7 +267,7 @@ impl HalfNeighborList { return Ok(results); } - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { for (system_i, system) in systems.iter_mut().enumerate() { system.compute_neighbors(self.cutoff)?; let types = system.types()?; @@ -372,7 +372,7 @@ pub struct FullNeighborList { impl FullNeighborList { /// Get the list of keys for these systems (list of pair types present in the systems) - pub(crate) fn keys(&self, systems: &mut [Box]) -> Result { + pub(crate) fn keys(&self, systems: &mut [System]) -> Result { let mut all_types_pairs = BTreeSet::new(); for system in systems { system.compute_neighbors(self.cutoff)?; @@ -400,7 +400,7 @@ impl FullNeighborList { return Ok(keys.finish()); } - pub(crate) fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + pub(crate) fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { let mut results = Vec::new(); for &[first_atom_type, second_atom_type] in keys.iter_fixed_size() { @@ -492,7 +492,7 @@ impl FullNeighborList { } #[allow(clippy::too_many_lines)] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { for (system_i, system) in systems.iter_mut().enumerate() { system.compute_neighbors(self.cutoff)?; let types = system.types()?; diff --git a/rascaline/src/calculators/soap/power_spectrum.rs b/rascaline/src/calculators/soap/power_spectrum.rs index 33662f023..4043f7298 100644 --- a/rascaline/src/calculators/soap/power_spectrum.rs +++ b/rascaline/src/calculators/soap/power_spectrum.rs @@ -441,7 +441,7 @@ impl CalculatorBase for SoapPowerSpectrum { self.spherical_expansion.cutoffs() } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { let builder = CenterTwoNeighborsTypesKeys { cutoff: self.parameters.cutoff, self_pairs: true, @@ -454,7 +454,7 @@ impl CalculatorBase for SoapPowerSpectrum { AtomCenteredSamples::sample_names() } - fn samples(&self, keys: &metatensor::Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &metatensor::Labels, systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["center_type", "neighbor_1_type", "neighbor_2_type"]); let mut result = Vec::new(); for [center_type, neighbor_1_type, neighbor_2_type] in keys.iter_fixed_size() { @@ -478,7 +478,7 @@ impl CalculatorBase for SoapPowerSpectrum { return Ok(result); } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["center_type", "neighbor_1_type", "neighbor_2_type"]); assert_eq!(keys.count(), samples.len()); @@ -532,7 +532,7 @@ impl CalculatorBase for SoapPowerSpectrum { #[time_graph::instrument(name = "SoapPowerSpectrum::compute")] #[allow(clippy::too_many_lines)] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { assert!(descriptor.keys().count() > 0); let mut gradients = Vec::new(); diff --git a/rascaline/src/calculators/soap/radial_spectrum.rs b/rascaline/src/calculators/soap/radial_spectrum.rs index 0670cd2ae..3989c4ed5 100644 --- a/rascaline/src/calculators/soap/radial_spectrum.rs +++ b/rascaline/src/calculators/soap/radial_spectrum.rs @@ -130,7 +130,7 @@ impl CalculatorBase for SoapRadialSpectrum { self.spherical_expansion.cutoffs() } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { let builder = CenterSingleNeighborsTypesKeys { cutoff: self.parameters.cutoff, self_pairs: true, @@ -145,7 +145,7 @@ impl CalculatorBase for SoapRadialSpectrum { fn samples( &self, keys: &metatensor::Labels, - systems: &mut [Box], + systems: &mut [System], ) -> Result, Error> { assert_eq!(keys.names(), ["center_type", "neighbor_type"]); let mut result = Vec::new(); @@ -170,7 +170,7 @@ impl CalculatorBase for SoapRadialSpectrum { } } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["center_type", "neighbor_type"]); assert_eq!(keys.count(), samples.len()); @@ -208,7 +208,7 @@ impl CalculatorBase for SoapRadialSpectrum { } #[time_graph::instrument(name = "SoapRadialSpectrum::compute")] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]); assert!(descriptor.keys().count() > 0); diff --git a/rascaline/src/calculators/soap/spherical_expansion.rs b/rascaline/src/calculators/soap/spherical_expansion.rs index b703f833f..b75b54e1d 100644 --- a/rascaline/src/calculators/soap/spherical_expansion.rs +++ b/rascaline/src/calculators/soap/spherical_expansion.rs @@ -44,7 +44,7 @@ impl SphericalExpansion { /// Accumulate the self contribution to the spherical expansion /// coefficients, i.e. the contribution arising from the density of the /// center atom around itself. - fn do_self_contributions(&mut self, systems: &[Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn do_self_contributions(&mut self, systems: &[System], descriptor: &mut TensorMap) -> Result<(), Error> { debug_assert_eq!(descriptor.keys().names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]); let self_contribution = self.by_pair.self_contribution(); @@ -95,7 +95,7 @@ impl SphericalExpansion { #[allow(clippy::too_many_lines)] fn accumulate_all_pairs( &self, - system: &dyn System, + system: &System, do_gradients: GradientsOptions, requested_atoms: &BTreeSet, ) -> Result { @@ -344,7 +344,7 @@ impl SphericalExpansion { &self, key: &[LabelValue], block: &mut TensorBlockRefMut, - system: &dyn System, + system: &System, result: &PairAccumulationResult, ) -> Result<(), Error> { let types = system.types()?; @@ -407,7 +407,7 @@ impl SphericalExpansion { &self, key: &[LabelValue], block: &mut TensorBlockRefMut, - system: &dyn System, + system: &System, result: &PairAccumulationResult, ) -> Result<(), Error> { let positions_gradients = if let Some(ref data) = result.positions_gradient_by_pair { @@ -511,7 +511,7 @@ impl SphericalExpansion { key: &[LabelValue], parameter: &str, block: &mut TensorBlockRefMut, - system: &dyn System, + system: &System, result: &PairAccumulationResult, ) -> Result<(), Error> { let contributions = if parameter == "strain" { @@ -637,7 +637,7 @@ impl CalculatorBase for SphericalExpansion { self.by_pair.cutoffs() } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { let builder = CenterSingleNeighborsTypesKeys { cutoff: self.by_pair.parameters().cutoff, self_pairs: true, @@ -658,7 +658,7 @@ impl CalculatorBase for SphericalExpansion { AtomCenteredSamples::sample_names() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]); // only compute the samples once for each `center_type, neighbor_type`, @@ -698,7 +698,7 @@ impl CalculatorBase for SphericalExpansion { } } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]); assert_eq!(keys.count(), samples.len()); @@ -762,7 +762,7 @@ impl CalculatorBase for SphericalExpansion { } #[time_graph::instrument(name = "SphericalExpansion::compute")] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { assert_eq!(descriptor.keys().names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]); assert!(descriptor.keys().count() > 0); @@ -778,7 +778,7 @@ impl CalculatorBase for SphericalExpansion { .zip_eq(&mut descriptors_by_system) .try_for_each(|(system, descriptor)| { system.compute_neighbors(self.by_pair.parameters().cutoff)?; - let system = &**system; + let system = &*system; // we will only run the calculation on pairs where one of the // atom is part of the requested samples diff --git a/rascaline/src/calculators/soap/spherical_expansion_pair.rs b/rascaline/src/calculators/soap/spherical_expansion_pair.rs index b4a78fab1..8bdc4b012 100644 --- a/rascaline/src/calculators/soap/spherical_expansion_pair.rs +++ b/rascaline/src/calculators/soap/spherical_expansion_pair.rs @@ -262,7 +262,7 @@ impl SphericalExpansionByPair { /// /// For the pair-by-pair spherical expansion, we use a special `pair_id` /// (-1) to store the data associated with self-pairs. - fn do_self_contributions(&self, systems: &[Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn do_self_contributions(&self, systems: &[System], descriptor: &mut TensorMap) -> Result<(), Error> { debug_assert_eq!(descriptor.keys().names(), ["o3_lambda", "o3_sigma", "first_atom_type", "second_atom_type"]); let self_contribution = self.self_contribution(); @@ -552,7 +552,7 @@ impl CalculatorBase for SphericalExpansionByPair { std::slice::from_ref(&self.parameters.cutoff) } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { // the atomic type part of the keys is the same for all l, and the same // as what a FullNeighborList with `self_pairs=True` produces. let full_neighbors_list_keys = FullNeighborList { @@ -580,7 +580,7 @@ impl CalculatorBase for SphericalExpansionByPair { return vec!["system", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"]; } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { // get all atomic types pairs in keys as a new set of Labels let mut types_keys = BTreeSet::new(); for &[_, _, first_type, second_type] in keys.iter_fixed_size() { @@ -642,7 +642,7 @@ impl CalculatorBase for SphericalExpansionByPair { } } - fn positions_gradient_samples(&self, _: &Labels, samples: &[Labels], _: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, _: &Labels, samples: &[Labels], _: &mut [System]) -> Result, Error> { let mut results = Vec::new(); for block_samples in samples { @@ -706,7 +706,7 @@ impl CalculatorBase for SphericalExpansionByPair { } #[time_graph::instrument(name = "SphericalExpansionByPair::compute")] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { assert_eq!(descriptor.keys().names(), ["o3_lambda", "o3_sigma", "first_atom_type", "second_atom_type"]); assert!(descriptor.keys().count() > 0); diff --git a/rascaline/src/calculators/sorted_distances.rs b/rascaline/src/calculators/sorted_distances.rs index b92163fbf..9f283339a 100644 --- a/rascaline/src/calculators/sorted_distances.rs +++ b/rascaline/src/calculators/sorted_distances.rs @@ -41,7 +41,7 @@ impl CalculatorBase for SortedDistances { std::slice::from_ref(&self.cutoff) } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { if self.separate_neighbor_types { let builder = CenterSingleNeighborsTypesKeys { cutoff: self.cutoff, @@ -57,7 +57,7 @@ impl CalculatorBase for SortedDistances { AtomCenteredSamples::sample_names() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { let mut samples = Vec::new(); if self.separate_neighbor_types { assert_eq!(keys.names(), ["center_type", "neighbor_type"]); @@ -92,7 +92,7 @@ impl CalculatorBase for SortedDistances { return false; } - fn positions_gradient_samples(&self, _: &Labels, _: &[Labels], _: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, _: &Labels, _: &[Labels], _: &mut [System]) -> Result, Error> { unimplemented!() } @@ -115,7 +115,7 @@ impl CalculatorBase for SortedDistances { } #[time_graph::instrument(name = "SortedDistances::compute")] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { if self.separate_neighbor_types { assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]); } else { diff --git a/rascaline/src/calculators/tests_utils.rs b/rascaline/src/calculators/tests_utils.rs index 5594acaba..466713eb7 100644 --- a/rascaline/src/calculators/tests_utils.rs +++ b/rascaline/src/calculators/tests_utils.rs @@ -5,7 +5,7 @@ use metatensor::{Labels, TensorMap, LabelsBuilder}; use crate::calculator::LabelsSelection; use crate::{CalculationOptions, Calculator, Matrix3}; -use crate::systems::{System, SimpleSystem, UnitCell}; +use crate::systems::{System, SystemBase, SimpleSystem, UnitCell}; /// Check that computing a partial subset of features/samples works as intended /// for the given `calculator` and `systems`. @@ -15,7 +15,7 @@ use crate::systems::{System, SimpleSystem, UnitCell}; /// gradients. pub fn compute_partial( mut calculator: Calculator, - systems: &mut [Box], + systems: &mut [System], keys: &Labels, samples: &Labels, properties: &Labels, @@ -51,7 +51,7 @@ pub fn compute_partial( fn check_compute_partial_keys( calculator: &mut Calculator, - systems: &mut [Box], + systems: &mut [System], full: &TensorMap, keys: &Labels, ) { @@ -84,7 +84,7 @@ fn check_compute_partial_keys( fn check_compute_partial_properties( calculator: &mut Calculator, - systems: &mut [Box], + systems: &mut [System], full: &TensorMap, properties: &Labels, ) { @@ -130,7 +130,7 @@ fn check_compute_partial_properties( fn check_compute_partial_samples( calculator: &mut Calculator, - systems: &mut [Box], + systems: &mut [System], full: &TensorMap, samples: &Labels, ) { @@ -183,7 +183,7 @@ fn check_compute_partial_samples( fn check_compute_partial_both( calculator: &mut Calculator, - systems: &mut [Box], + systems: &mut [System], full: &TensorMap, samples: &Labels, properties: &Labels, @@ -265,17 +265,17 @@ pub fn finite_differences_positions(mut calculator: Calculator, system: &SimpleS gradients: &["positions"], ..Default::default() }; - let reference = calculator.compute(&mut [Box::new(system.clone())], calculation_options).unwrap(); + let reference = calculator.compute(&mut [System::new(system.clone())], calculation_options).unwrap(); for neighbor_i in 0..system.size().unwrap() { for xyz in 0..3 { let mut system_pos = system.clone(); system_pos.positions_mut()[neighbor_i][xyz] += options.displacement / 2.0; - let updated_pos = calculator.compute(&mut [Box::new(system_pos)], Default::default()).unwrap(); + let updated_pos = calculator.compute(&mut [System::new(system_pos)], Default::default()).unwrap(); let mut system_neg = system.clone(); system_neg.positions_mut()[neighbor_i][xyz] -= options.displacement / 2.0; - let updated_neg = calculator.compute(&mut [Box::new(system_neg)], Default::default()).unwrap(); + let updated_neg = calculator.compute(&mut [System::new(system_neg)], Default::default()).unwrap(); assert_eq!(updated_pos.keys(), reference.keys()); assert_eq!(updated_neg.keys(), reference.keys()); @@ -327,7 +327,7 @@ pub fn finite_differences_cell(mut calculator: Calculator, system: &SimpleSystem gradients: &["cell"], ..Default::default() }; - let reference = calculator.compute(&mut [Box::new(system.clone())], calculation_options).unwrap(); + let reference = calculator.compute(&mut [System::new(system.clone())], calculation_options).unwrap(); let original_cell = system.cell().unwrap().matrix(); for abc in 0..3 { @@ -336,12 +336,12 @@ pub fn finite_differences_cell(mut calculator: Calculator, system: &SimpleSystem deformed_cell[abc][xyz] += options.displacement / 2.0; let mut system_pos = system.clone(); system_pos.set_cell(UnitCell::from(deformed_cell)); - let updated_pos = calculator.compute(&mut [Box::new(system_pos)], Default::default()).unwrap(); + let updated_pos = calculator.compute(&mut [System::new(system_pos)], Default::default()).unwrap(); deformed_cell[abc][xyz] -= options.displacement; let mut system_neg = system.clone(); system_neg.set_cell(UnitCell::from(deformed_cell)); - let updated_neg = calculator.compute(&mut [Box::new(system_neg)], Default::default()).unwrap(); + let updated_neg = calculator.compute(&mut [System::new(system_neg)], Default::default()).unwrap(); for (block_i, (_, block)) in reference.iter().enumerate() { let gradients = &block.gradient("cell").unwrap(); @@ -389,7 +389,7 @@ pub fn finite_differences_strain(mut calculator: Calculator, system: &SimpleSyst gradients: &["strain"], ..Default::default() }; - let reference = calculator.compute(&mut [Box::new(system.clone())], calculation_options).unwrap(); + let reference = calculator.compute(&mut [System::new(system.clone())], calculation_options).unwrap(); let original_cell = system.cell().unwrap().matrix(); for xyz_1 in 0..3 { @@ -401,7 +401,7 @@ pub fn finite_differences_strain(mut calculator: Calculator, system: &SimpleSyst for position in system_pos.positions_mut() { *position = *position * strain; } - let updated_pos = calculator.compute(&mut [Box::new(system_pos)], Default::default()).unwrap(); + let updated_pos = calculator.compute(&mut [System::new(system_pos)], Default::default()).unwrap(); strain[xyz_1][xyz_2] -= options.displacement; let mut system_neg = system.clone(); @@ -409,7 +409,7 @@ pub fn finite_differences_strain(mut calculator: Calculator, system: &SimpleSyst for position in system_neg.positions_mut() { *position = *position * strain; } - let updated_neg = calculator.compute(&mut [Box::new(system_neg)], Default::default()).unwrap(); + let updated_neg = calculator.compute(&mut [System::new(system_neg)], Default::default()).unwrap(); for (block_i, (_, block)) in reference.iter().enumerate() { let gradients = &block.gradient("strain").unwrap(); diff --git a/rascaline/src/labels/keys.rs b/rascaline/src/labels/keys.rs index aae620af5..48fbe393d 100644 --- a/rascaline/src/labels/keys.rs +++ b/rascaline/src/labels/keys.rs @@ -8,14 +8,14 @@ use crate::{System, Error}; /// Common interface to create a set of metatensor's `TensorMap` keys from systems pub trait KeysBuilder { /// Compute the keys corresponding to these systems - fn keys(&self, systems: &mut [Box]) -> Result; + fn keys(&self, systems: &mut [System]) -> Result; } /// Compute a set of keys with a single variable, the central atom type. pub struct CenterTypesKeys; impl KeysBuilder for CenterTypesKeys { - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { let mut all_types = BTreeSet::new(); for system in systems { for &atomic_type in system.types()? { @@ -36,7 +36,7 @@ impl KeysBuilder for CenterTypesKeys { pub struct AllTypesPairsKeys {} impl KeysBuilder for AllTypesPairsKeys { - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { let mut all_types_pairs = BTreeSet::new(); for system in systems { @@ -66,7 +66,7 @@ pub struct CenterSingleNeighborsTypesKeys { } impl KeysBuilder for CenterSingleNeighborsTypesKeys { - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { assert!(self.cutoff > 0.0 && self.cutoff.is_finite()); let mut all_types_pairs = BTreeSet::new(); @@ -108,7 +108,7 @@ pub struct CenterTwoNeighborsTypesKeys { } impl KeysBuilder for CenterTwoNeighborsTypesKeys { - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { assert!(self.cutoff > 0.0 && self.cutoff.is_finite()); let mut keys = BTreeSet::new(); diff --git a/rascaline/src/labels/samples/atom_centered.rs b/rascaline/src/labels/samples/atom_centered.rs index 96592c733..c7d520646 100644 --- a/rascaline/src/labels/samples/atom_centered.rs +++ b/rascaline/src/labels/samples/atom_centered.rs @@ -28,7 +28,7 @@ impl SamplesBuilder for AtomCenteredSamples { vec!["system", "atom"] } - fn samples(&self, systems: &mut [Box]) -> Result { + fn samples(&self, systems: &mut [System]) -> Result { assert!(self.cutoff > 0.0 && self.cutoff.is_finite(), "cutoff must be positive for AtomCenteredSamples"); let mut builder = LabelsBuilder::new(Self::sample_names()); for (system_i, system) in systems.iter_mut().enumerate() { @@ -100,7 +100,7 @@ impl SamplesBuilder for AtomCenteredSamples { return Ok(builder.finish()); } - fn gradients_for(&self, systems: &mut [Box], samples: &Labels) -> Result { + fn gradients_for(&self, systems: &mut [System], samples: &Labels) -> Result { assert!(self.cutoff > 0.0 && self.cutoff.is_finite(), "cutoff must be positive for AtomCenteredSamples"); assert_eq!(samples.names(), ["system", "atom"]); let mut builder = LabelsBuilder::new(vec!["sample", "system", "atom"]); diff --git a/rascaline/src/labels/samples/long_range.rs b/rascaline/src/labels/samples/long_range.rs index c4ac4c4a4..2edc59de8 100644 --- a/rascaline/src/labels/samples/long_range.rs +++ b/rascaline/src/labels/samples/long_range.rs @@ -22,7 +22,7 @@ impl SamplesBuilder for LongRangeSamplesPerAtom { vec!["system", "atom"] } - fn samples(&self, systems: &mut [Box]) -> Result { + fn samples(&self, systems: &mut [System]) -> Result { assert!(self.self_pairs, "self.self_pairs = false is not implemented"); let mut builder = LabelsBuilder::new(Self::sample_names()); @@ -52,7 +52,7 @@ impl SamplesBuilder for LongRangeSamplesPerAtom { return Ok(builder.finish()); } - fn gradients_for(&self, systems: &mut [Box], samples: &Labels) -> Result { + fn gradients_for(&self, systems: &mut [System], samples: &Labels) -> Result { assert_eq!(samples.names(), ["system", "atom"]); let mut builder = LabelsBuilder::new(vec!["sample", "system", "atom"]); diff --git a/rascaline/src/labels/samples/mod.rs b/rascaline/src/labels/samples/mod.rs index 96bd7c9d2..6fdc9207c 100644 --- a/rascaline/src/labels/samples/mod.rs +++ b/rascaline/src/labels/samples/mod.rs @@ -43,12 +43,12 @@ pub trait SamplesBuilder { /// Create `Labels` containing all the samples corresponding to the given /// list of systems. - fn samples(&self, systems: &mut [Box]) -> Result; + fn samples(&self, systems: &mut [System]) -> Result; /// Create a set of `Labels` containing the gradient samples corresponding /// to the given `samples` in the given `systems`; and only these. #[allow(unused_variables)] - fn gradients_for(&self, systems: &mut [Box], samples: &Labels) -> Result; + fn gradients_for(&self, systems: &mut [System], samples: &Labels) -> Result; } diff --git a/rascaline/src/lib.rs b/rascaline/src/lib.rs index 41acd16dc..d4e8318e6 100644 --- a/rascaline/src/lib.rs +++ b/rascaline/src/lib.rs @@ -24,7 +24,7 @@ mod errors; pub use self::errors::Error; pub mod systems; -pub use self::systems::{System, SimpleSystem}; +pub use self::systems::{System, SystemBase, SimpleSystem}; pub mod labels; diff --git a/rascaline/src/systems/chemfiles.rs b/rascaline/src/systems/chemfiles.rs index 35cb99b73..545f43a95 100644 --- a/rascaline/src/systems/chemfiles.rs +++ b/rascaline/src/systems/chemfiles.rs @@ -1,6 +1,6 @@ use std::path::Path; -use super::SimpleSystem; +use super::{SimpleSystem, System}; use crate::Error; #[cfg(feature = "chemfiles")] @@ -16,8 +16,7 @@ impl From for Error { /// This function can read all [formats supported by /// chemfiles](https://chemfiles.org/chemfiles/latest/formats.html). #[cfg(feature = "chemfiles")] -#[allow(clippy::needless_range_loop)] -pub fn read_from_file(path: impl AsRef) -> Result, Error> { +pub fn read_simple_systems_from_file(path: impl AsRef) -> Result, Error> { use std::collections::HashMap; use crate::Matrix3; use crate::systems::UnitCell; @@ -65,12 +64,38 @@ pub fn read_from_file(path: impl AsRef) -> Result, Error return Ok(systems); } +/// Read all structures in the file at the given `path` using +/// [chemfiles](https://chemfiles.org/), and convert them to `System`s. +/// +/// This function can read all [formats supported by +/// chemfiles](https://chemfiles.org/chemfiles/latest/formats.html). +#[cfg(feature = "chemfiles")] +pub fn read_from_file(path: impl AsRef) -> Result, Error> { + return Ok(read_simple_systems_from_file(path)? + .into_iter() + .map(System::new) + .collect()); +} + /// Read all structures in the file at the given `path` using /// [chemfiles](https://chemfiles.org/), and convert them to `SimpleSystem`s. /// /// This function can read all [formats supported by /// chemfiles](https://chemfiles.org/chemfiles/latest/formats.html). #[cfg(not(feature = "chemfiles"))] +pub fn read_simple_systems_from_file(_: impl AsRef) -> Result, Error> { + Err(Error::Chemfiles( + "read_simple_systems_from_file is only available with the chemfiles \ + feature enabled (RASCALINE_ENABLE_CHEMFILES=ON in CMake)".into() + )) +} + +/// Read all structures in the file at the given `path` using +/// [chemfiles](https://chemfiles.org/), and convert them to `System`s. +/// +/// This function can read all [formats supported by +/// chemfiles](https://chemfiles.org/chemfiles/latest/formats.html). +#[cfg(not(feature = "chemfiles"))] pub fn read_from_file(_: impl AsRef) -> Result, Error> { Err(Error::Chemfiles( "read_from_file is only available with the chemfiles feature enabled \ @@ -83,7 +108,7 @@ mod tests { use std::path::PathBuf; use approx::assert_relative_eq; - use crate::{System, Vector3D}; + use crate::Vector3D; use super::*; #[test] diff --git a/rascaline/src/systems/mod.rs b/rascaline/src/systems/mod.rs index 6e3252a11..4e74fd509 100644 --- a/rascaline/src/systems/mod.rs +++ b/rascaline/src/systems/mod.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use crate::{Error, Vector3D}; mod cell; @@ -11,6 +13,7 @@ pub use self::simple_system::SimpleSystem; mod chemfiles; pub use self::chemfiles::read_from_file; +pub use self::chemfiles::read_simple_systems_from_file; #[cfg(test)] pub(crate) mod test_utils; @@ -36,9 +39,9 @@ pub struct Pair { pub cell_shift_indices: [i32; 3], } -/// A `System` deals with the storage of atoms and related information, as well +/// A `SystemBase` deals with the storage of atoms and related information, as well /// as the computation of neighbor lists. -pub trait System: Send + Sync { +pub trait SystemBase: Send + Sync { /// Get the unit cell for this system fn cell(&self) -> Result; @@ -75,3 +78,53 @@ pub trait System: Send + Sync { /// `pairs_containing(j)`. fn pairs_containing(&self, atom: usize) -> Result<&[Pair], Error>; } + +/// A `System` deals with the storage of atoms and related information, as well +/// as the computation of neighbor lists. +/// +/// `System` also allows calculator to store arbitrary data inside the system +/// (to be used a a cache between different function calls in the +/// `CalculatorBase` trait). +pub struct System { + implementation: Box, + data: BTreeMap>, +} + +impl std::ops::Deref for System { + type Target = dyn SystemBase + 'static; + + fn deref(&self) -> &Self::Target { + &*self.implementation + } +} + +impl std::ops::DerefMut for System { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.implementation + } +} + +impl System { + /// TODO + pub fn new(system: impl SystemBase + 'static) -> System { + System { + implementation: Box::new(system), + data: BTreeMap::new() + } + } + + /// TODO + pub fn store_data(&mut self, name: String, data: impl std::any::Any + Send + Sync + 'static) { + self.data.insert(name, Box::new(data)); + } + + /// TODO + pub fn data(&self, name: &str) -> Option<&(dyn std::any::Any + Send + Sync)> { + self.data.get(name).map(|v| &**v) + } + + /// TODO + pub fn data_mut(&mut self, name: &str) -> Option<&mut (dyn std::any::Any + Send + Sync)> { + self.data.get_mut(name).map(|v| &mut **v) + } +} diff --git a/rascaline/src/systems/simple_system.rs b/rascaline/src/systems/simple_system.rs index 4af856ae2..c14042048 100644 --- a/rascaline/src/systems/simple_system.rs +++ b/rascaline/src/systems/simple_system.rs @@ -1,6 +1,6 @@ use crate::Error; -use super::{UnitCell, System, Vector3D, Pair}; +use super::{UnitCell, SystemBase, Vector3D, Pair}; use super::neighbors::NeighborsList; @@ -45,7 +45,7 @@ impl SimpleSystem { } } -impl System for SimpleSystem { +impl SystemBase for SimpleSystem { fn size(&self) -> Result { Ok(self.types.len()) } @@ -90,10 +90,10 @@ impl System for SimpleSystem { } } -impl std::convert::TryFrom<&dyn System> for SimpleSystem { +impl std::convert::TryFrom<&dyn SystemBase> for SimpleSystem { type Error = Error; - fn try_from(system: &dyn System) -> Result { + fn try_from(system: &dyn SystemBase) -> Result { let mut new = SimpleSystem::new(system.cell()?); for (&atomic_type, &position) in system.types()?.iter().zip(system.positions()?) { new.add_atom(atomic_type, position); diff --git a/rascaline/src/systems/test_utils.rs b/rascaline/src/systems/test_utils.rs index 39888a344..0e439d8ae 100644 --- a/rascaline/src/systems/test_utils.rs +++ b/rascaline/src/systems/test_utils.rs @@ -1,9 +1,9 @@ use crate::{System, Vector3D}; use super::{UnitCell, SimpleSystem}; -pub fn test_systems(names: &[&str]) -> Vec> { +pub fn test_systems(names: &[&str]) -> Vec { return names.iter() - .map(|&name| Box::new(test_system(name)) as Box) + .map(|&name| System::new(test_system(name))) .collect(); } diff --git a/rascaline/src/tutorials/moments/moments.rs b/rascaline/src/tutorials/moments/moments.rs index 0a3684446..bdee2f63f 100644 --- a/rascaline/src/tutorials/moments/moments.rs +++ b/rascaline/src/tutorials/moments/moments.rs @@ -25,7 +25,7 @@ impl CalculatorBase for GeometricMoments { std::slice::from_ref(&self.cutoff) } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { let builder = CenterSingleNeighborsTypesKeys { cutoff: self.cutoff, self_pairs: false, @@ -37,7 +37,7 @@ impl CalculatorBase for GeometricMoments { AtomCenteredSamples::sample_names() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["center_type", "neighbor_type"]); let mut samples = Vec::new(); @@ -62,7 +62,7 @@ impl CalculatorBase for GeometricMoments { } } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["center_type", "neighbor_type"]); debug_assert_eq!(keys.count(), samples.len()); @@ -100,7 +100,7 @@ impl CalculatorBase for GeometricMoments { } // [compute] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]); assert!(descriptor.keys().count() > 0); diff --git a/rascaline/src/tutorials/moments/s1_scaffold.rs b/rascaline/src/tutorials/moments/s1_scaffold.rs index 0b12962b8..f60c150c4 100644 --- a/rascaline/src/tutorials/moments/s1_scaffold.rs +++ b/rascaline/src/tutorials/moments/s1_scaffold.rs @@ -29,7 +29,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { todo!() } @@ -37,7 +37,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { todo!() } @@ -45,7 +45,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { todo!() } @@ -61,7 +61,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { todo!() } } diff --git a/rascaline/src/tutorials/moments/s2_metadata.rs b/rascaline/src/tutorials/moments/s2_metadata.rs index 983f94093..51aaaf2ca 100644 --- a/rascaline/src/tutorials/moments/s2_metadata.rs +++ b/rascaline/src/tutorials/moments/s2_metadata.rs @@ -34,7 +34,7 @@ impl CalculatorBase for GeometricMoments { // [CalculatorBase::cutoffs] // [CalculatorBase::keys] - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { let builder = CenterSingleNeighborsTypesKeys { cutoff: self.cutoff, // self pairs would have a distance of 0 and would not contribute @@ -50,7 +50,7 @@ impl CalculatorBase for GeometricMoments { AtomCenteredSamples::sample_names() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["center_type", "neighbor_type"]); let mut samples = Vec::new(); @@ -82,7 +82,7 @@ impl CalculatorBase for GeometricMoments { // [CalculatorBase::supports_gradient] // [CalculatorBase::positions_gradient_samples] - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { assert_eq!(keys.names(), ["center_type", "neighbor_type"]); debug_assert_eq!(keys.count(), samples.len()); @@ -127,7 +127,7 @@ impl CalculatorBase for GeometricMoments { } // [CalculatorBase::properties] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { todo!() } diff --git a/rascaline/src/tutorials/moments/s3_compute_1.rs b/rascaline/src/tutorials/moments/s3_compute_1.rs index 434d1f058..18542ea46 100644 --- a/rascaline/src/tutorials/moments/s3_compute_1.rs +++ b/rascaline/src/tutorials/moments/s3_compute_1.rs @@ -25,7 +25,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { todo!() } @@ -33,7 +33,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { todo!() } @@ -41,7 +41,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { todo!() } @@ -58,7 +58,7 @@ impl CalculatorBase for GeometricMoments { } // [compute] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]); // we'll add more code here diff --git a/rascaline/src/tutorials/moments/s3_compute_2.rs b/rascaline/src/tutorials/moments/s3_compute_2.rs index 0ef29fabd..5099447bc 100644 --- a/rascaline/src/tutorials/moments/s3_compute_2.rs +++ b/rascaline/src/tutorials/moments/s3_compute_2.rs @@ -25,7 +25,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { todo!() } @@ -33,7 +33,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { todo!() } @@ -41,7 +41,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { todo!() } @@ -58,7 +58,7 @@ impl CalculatorBase for GeometricMoments { } // [compute] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]); for (system_i, system) in systems.iter_mut().enumerate() { diff --git a/rascaline/src/tutorials/moments/s3_compute_3.rs b/rascaline/src/tutorials/moments/s3_compute_3.rs index 13994cfc9..69f917e9e 100644 --- a/rascaline/src/tutorials/moments/s3_compute_3.rs +++ b/rascaline/src/tutorials/moments/s3_compute_3.rs @@ -25,7 +25,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { todo!() } @@ -33,7 +33,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { todo!() } @@ -41,7 +41,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { todo!() } @@ -58,7 +58,7 @@ impl CalculatorBase for GeometricMoments { } // [compute] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]); for (system_i, system) in systems.iter_mut().enumerate() { diff --git a/rascaline/src/tutorials/moments/s3_compute_4.rs b/rascaline/src/tutorials/moments/s3_compute_4.rs index 3dd337c32..7c478e05d 100644 --- a/rascaline/src/tutorials/moments/s3_compute_4.rs +++ b/rascaline/src/tutorials/moments/s3_compute_4.rs @@ -32,7 +32,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { todo!() } @@ -40,7 +40,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { todo!() } @@ -48,7 +48,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { todo!() } @@ -65,7 +65,7 @@ impl CalculatorBase for GeometricMoments { } // [compute] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { // ... for (system_i, system) in systems.iter_mut().enumerate() { // ... diff --git a/rascaline/src/tutorials/moments/s3_compute_5.rs b/rascaline/src/tutorials/moments/s3_compute_5.rs index 5bb15fffc..5ac9e27b2 100644 --- a/rascaline/src/tutorials/moments/s3_compute_5.rs +++ b/rascaline/src/tutorials/moments/s3_compute_5.rs @@ -33,7 +33,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn keys(&self, systems: &mut [Box]) -> Result { + fn keys(&self, systems: &mut [System]) -> Result { todo!() } @@ -41,7 +41,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result, Error> { todo!() } @@ -49,7 +49,7 @@ impl CalculatorBase for GeometricMoments { todo!() } - fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box]) -> Result, Error> { + fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result, Error> { todo!() } @@ -66,7 +66,7 @@ impl CalculatorBase for GeometricMoments { } // [compute] - fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { + fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> { // ... // add these lines diff --git a/rascaline/tests/data/mod.rs b/rascaline/tests/data/mod.rs index fd6045c55..0b68f43ee 100644 --- a/rascaline/tests/data/mod.rs +++ b/rascaline/tests/data/mod.rs @@ -11,7 +11,7 @@ use rascaline::systems::UnitCell; type HyperParameters = String; -pub fn load_calculator_input(path: impl AsRef) -> (Vec>, HyperParameters) { +pub fn load_calculator_input(path: impl AsRef) -> (Vec, HyperParameters) { let json = std::fs::read_to_string(format!("tests/data/generated/{}", path.as_ref().display())) .expect("failed to read input file"); @@ -38,7 +38,7 @@ pub fn load_calculator_input(path: impl AsRef) -> (Vec>, H simple_system.add_atom(atomic_type, position); } - systems.push(Box::new(simple_system) as Box); + systems.push(System::new(simple_system)); } (systems, parameters) diff --git a/rascaline/tests/lode-madelung.rs b/rascaline/tests/lode-madelung.rs index 159407563..5a2ffd13d 100644 --- a/rascaline/tests/lode-madelung.rs +++ b/rascaline/tests/lode-madelung.rs @@ -13,7 +13,7 @@ use rascaline::systems::{System, SimpleSystem, UnitCell}; use rascaline::{Calculator, Matrix3, Vector3D, CalculationOptions}; struct CrystalParameters { - systems: Vec>, + systems: Vec, charges: Vec, madelung: f64, } @@ -22,42 +22,42 @@ struct CrystalParameters { /// Using a primitive unit cell, the distance between the /// closest Na-Cl pair is exactly 1. The cubic unit cell /// in these units would have a length of 2. -fn get_nacl() -> Vec> { +fn get_nacl() -> Vec { let cell = Matrix3::new([[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]); let mut system = SimpleSystem::new(UnitCell::from(cell)); system.add_atom(11, Vector3D::new(0.0, 0.0, 0.0)); system.add_atom(17, Vector3D::new(1.0, 0.0, 0.0)); - vec![Box::new(system) as Box] + vec![System::new(system)] } /// CsCl structure /// This structure is simple since the primitive unit cell /// is just the usual cubic cell with side length set to one. -fn get_cscl() -> Vec> { +fn get_cscl() -> Vec { let mut system = SimpleSystem::new(UnitCell::cubic(1.0)); system.add_atom(17, Vector3D::new(0.0, 0.0, 0.0)); system.add_atom(55, Vector3D::new(0.5, 0.5, 0.5)); - vec![Box::new(system) as Box] + vec![System::new(system)] } /// ZnS (zincblende) structure /// As for NaCl, a primitive unit cell is used which makes /// the lattice parameter of the cubic cell equal to 2. /// In these units, the closest Zn-S distance is sqrt(3)/2. -fn get_zns() -> Vec> { +fn get_zns() -> Vec { let cell = Matrix3::new([[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]); let mut system = SimpleSystem::new(UnitCell::from(cell)); system.add_atom(16, Vector3D::new(0.0, 0.0, 0.0)); system.add_atom(30, Vector3D::new(0.5, 0.5, 0.5)); - vec![Box::new(system) as Box] + vec![System::new(system)] } /// ZnS (O4) in wurtzite structure (triclinic cell) -fn get_znso4() -> Vec> { +fn get_znso4() -> Vec { let u = 3. / 8.; let c = f64::sqrt(1. / u); let cell = Matrix3::new([[0.5, -0.5 * f64::sqrt(3.0), 0.0], [0.5, 0.5 * f64::sqrt(3.0), 0.0], [0.0, 0.0, c]]); @@ -67,7 +67,7 @@ fn get_znso4() -> Vec> { system.add_atom(16, Vector3D::new(0.5, -0.5 / f64::sqrt(3.0), 0.5 * c)); system.add_atom(30, Vector3D::new(0.5, -0.5 / f64::sqrt(3.0), (0.5 + u) * c)); - vec![Box::new(system) as Box] + vec![System::new(system)] } /// Test the agreement with Madelung constant for a variety of diff --git a/rascaline/tests/lode-vs-soap.rs b/rascaline/tests/lode-vs-soap.rs index 357e47b69..ef797946c 100644 --- a/rascaline/tests/lode-vs-soap.rs +++ b/rascaline/tests/lode-vs-soap.rs @@ -13,7 +13,7 @@ fn lode_vs_soap() { system.add_atom(8, Vector3D::new(2.0, 2.2, 1.0)); system.add_atom(8, Vector3D::new(2.3, 2.0, 1.5)); - let mut systems = vec![Box::new(system) as Box]; + let mut systems = vec![System::new(system)]; // reduce max_radial/max_angular for debug builds to make this test faster let (max_radial, max_angular) = if cfg!(debug_assertions) {