Skip to content

Commit

Permalink
Allow to add data to a System
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Nov 1, 2024
1 parent 5326b6e commit b60fbc5
Show file tree
Hide file tree
Showing 39 changed files with 237 additions and 191 deletions.
4 changes: 2 additions & 2 deletions rascaline-c-api/src/calculator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ pub unsafe extern fn rascal_calculator_compute(
}
check_pointers!(calculator, descriptor, systems);

// Create a Vec<Box<dyn System>> from the passed systems
// Create a Vec<System> from the passed systems
let c_systems = if systems_count == 0 {
&mut []
} else {
Expand All @@ -411,7 +411,7 @@ pub unsafe extern fn rascal_calculator_compute(
};
let mut systems = Vec::with_capacity(c_systems.len());
for system in c_systems {
systems.push(Box::new(system) as Box<dyn System>);
systems.push(System::new(system));
}

let c_gradients = if options.gradients_count == 0 {
Expand Down
6 changes: 3 additions & 3 deletions rascaline-c-api/src/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ffi::CStr;

use rascaline::types::{Vector3D, Matrix3};
use rascaline::systems::{SimpleSystem, Pair, UnitCell};
use rascaline::{Error, System};
use rascaline::{Error, SystemBase};

use crate::RASCAL_SYSTEM_ERROR;

Expand Down Expand Up @@ -111,7 +111,7 @@ pub struct rascal_system_t {
unsafe impl Send for rascal_system_t {}
unsafe impl Sync for rascal_system_t {}

impl<'a> System for &'a mut rascal_system_t {
impl<'a> SystemBase for &'a mut rascal_system_t {
fn size(&self) -> Result<usize, Error> {
let function = self.size.ok_or_else(|| Error::External {
status: RASCAL_SYSTEM_ERROR,
Expand Down Expand Up @@ -442,7 +442,7 @@ pub unsafe extern fn rascal_basic_systems_read(
catch_unwind(move || {
check_pointers!(path, systems, count);
let path = CStr::from_ptr(path).to_str()?;
let simple_systems = rascaline::systems::read_from_file(path)?;
let simple_systems = rascaline::systems::read_simple_systems_from_file(path)?;

let mut c_systems = Vec::with_capacity(simple_systems.len());
for system in simple_systems {
Expand Down
14 changes: 3 additions & 11 deletions rascaline/benches/lode-spherical-expansion.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
#![allow(clippy::needless_return)]
use rascaline::{Calculator, System, CalculationOptions};
use rascaline::{Calculator, CalculationOptions};

use criterion::{BenchmarkGroup, Criterion, measurement::WallTime, SamplingMode};
use criterion::{criterion_group, criterion_main};

fn load_systems(path: &str) -> Vec<Box<dyn System>> {
let systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
.expect("failed to read file");

return systems.into_iter()
.map(|s| Box::new(s) as Box<dyn System>)
.collect()
}

fn run_spherical_expansion(mut group: BenchmarkGroup<WallTime>,
path: &str,
gradients: bool,
test_mode: bool,
) {
let mut systems = load_systems(path);
let mut systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
.expect("failed to read file");

if test_mode {
// Reduce the time/RAM required to test the benchmarks code.
Expand Down
15 changes: 3 additions & 12 deletions rascaline/benches/soap-power-spectrum.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,18 @@
#![allow(clippy::needless_return)]

use rascaline::{Calculator, System, CalculationOptions};
use rascaline::{Calculator, CalculationOptions};

use criterion::{BenchmarkGroup, Criterion, measurement::WallTime, SamplingMode};
use criterion::{criterion_group, criterion_main};


fn load_systems(path: &str) -> Vec<Box<dyn System>> {
let systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
.expect("failed to read file");

return systems.into_iter()
.map(|s| Box::new(s) as Box<dyn System>)
.collect()
}

fn run_soap_power_spectrum(
mut group: BenchmarkGroup<WallTime>,
path: &str,
gradients: bool,
test_mode: bool,
) {
let mut systems = load_systems(path);
let mut systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
.expect("failed to read file");

if test_mode {
// Reduce the time/RAM required to test the benchmarks code.
Expand Down
14 changes: 3 additions & 11 deletions rascaline/benches/soap-spherical-expansion.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
#![allow(clippy::needless_return)]
use rascaline::{Calculator, System, CalculationOptions};
use rascaline::{Calculator, CalculationOptions};

use criterion::{BenchmarkGroup, Criterion, measurement::WallTime, SamplingMode};
use criterion::{criterion_group, criterion_main};

fn load_systems(path: &str) -> Vec<Box<dyn System>> {
let systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
.expect("failed to read file");

return systems.into_iter()
.map(|s| Box::new(s) as Box<dyn System>)
.collect()
}

fn run_spherical_expansion(mut group: BenchmarkGroup<WallTime>,
path: &str,
gradients: bool,
test_mode: bool,
) {
let mut systems = load_systems(path);
let mut systems = rascaline::systems::read_from_file(format!("benches/data/{}", path))
.expect("failed to read file");

if test_mode {
// Reduce the time/RAM required to test the benchmarks code.
Expand Down
8 changes: 2 additions & 6 deletions rascaline/examples/compute-soap.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use metatensor::Labels;
use rascaline::{Calculator, System, CalculationOptions};
use rascaline::{Calculator, CalculationOptions};

fn main() -> Result<(), Box<dyn std::error::Error>> {
// load the systems from command line argument
let path = std::env::args().nth(1).expect("expected a command line argument");
let systems = rascaline::systems::read_from_file(path)?;
// transform systems into a vector of trait objects (`Vec<Box<dyn System>>`)
let mut systems = systems.into_iter()
.map(|s| Box::new(s) as Box<dyn System>)
.collect::<Vec<_>>();
let mut systems = rascaline::systems::read_from_file(path)?;

// pass hyper-parameters as JSON
let parameters = r#"{
Expand Down
7 changes: 2 additions & 5 deletions rascaline/examples/profiling.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use metatensor::{TensorMap, Labels};
use rascaline::{Calculator, System, CalculationOptions};
use rascaline::{Calculator, CalculationOptions};

fn main() -> Result<(), Box<dyn std::error::Error>> {
let path = std::env::args().nth(1).expect("expected a command line argument");
Expand Down Expand Up @@ -28,10 +28,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
/// Compute SOAP power spectrum, this is the same code as the 'compute-soap'
/// example
fn compute_soap(path: &str) -> Result<TensorMap, Box<dyn std::error::Error>> {
let systems = rascaline::systems::read_from_file(path)?;
let mut systems = systems.into_iter()
.map(|s| Box::new(s) as Box<dyn System>)
.collect::<Vec<_>>();
let mut systems = rascaline::systems::read_from_file(path)?;

let parameters = r#"{
"cutoff": 5.0,
Expand Down
6 changes: 3 additions & 3 deletions rascaline/src/calculator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ impl Calculator {
}

#[time_graph::instrument(name="Calculator::prepare")]
fn prepare(&mut self, systems: &mut [Box<dyn System>], options: CalculationOptions) -> Result<TensorMap, Error> {
fn prepare(&mut self, systems: &mut [System], options: CalculationOptions) -> Result<TensorMap, Error> {
let default_keys = self.implementation.keys(systems)?;

let keys = match options.selected_keys {
Expand Down Expand Up @@ -503,14 +503,14 @@ impl Calculator {
/// features.
pub fn compute(
&mut self,
systems: &mut [Box<dyn System>],
systems: &mut [System],
options: CalculationOptions,
) -> Result<TensorMap, Error> {
let mut native_systems;
let systems = if options.use_native_system {
native_systems = Vec::with_capacity(systems.len());
for system in systems {
native_systems.push(Box::new(SimpleSystem::try_from(&**system)?) as Box<dyn System>);
native_systems.push(System::new(SimpleSystem::try_from(&**system)?) as System);
}
&mut native_systems
} else {
Expand Down
8 changes: 4 additions & 4 deletions rascaline/src/calculators/atomic_composition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl CalculatorBase for AtomicComposition {
&[]
}

fn keys(&self, systems: &mut [Box<dyn System>]) -> Result<Labels, Error> {
fn keys(&self, systems: &mut [System]) -> Result<Labels, Error> {
return CenterTypesKeys.keys(systems);
}

Expand All @@ -48,7 +48,7 @@ impl CalculatorBase for AtomicComposition {
return vec!["system", "atom"];
}

fn samples(&self, keys: &Labels, systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result<Vec<Labels>, Error> {
assert_eq!(keys.names(), ["center_type"]);
let mut samples = Vec::new();
for [center_type_key] in keys.iter_fixed_size() {
Expand Down Expand Up @@ -84,7 +84,7 @@ impl CalculatorBase for AtomicComposition {
&self,
keys: &Labels,
_samples: &[Labels],
_systems: &mut [Box<dyn System>],
_systems: &mut [System],
) -> Result<Vec<Labels>, Error> {
// Positions/cell gradients of the composition are zero everywhere.
// Therefore, we only return a vector of empty labels (one for each key).
Expand All @@ -110,7 +110,7 @@ impl CalculatorBase for AtomicComposition {

fn compute(
&mut self,
systems: &mut [Box<dyn System>],
systems: &mut [System],
descriptor: &mut TensorMap,
) -> Result<(), Error> {
assert_eq!(descriptor.keys().names(), ["center_type"]);
Expand Down
8 changes: 4 additions & 4 deletions rascaline/src/calculators/dummy_calculator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ impl CalculatorBase for DummyCalculator {
std::slice::from_ref(&self.cutoff)
}

fn keys(&self, systems: &mut [Box<dyn System>]) -> Result<Labels, Error> {
fn keys(&self, systems: &mut [System]) -> Result<Labels, Error> {
return CenterTypesKeys.keys(systems);
}

fn sample_names(&self) -> Vec<&str> {
AtomCenteredSamples::sample_names()
}

fn samples(&self, keys: &Labels, systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result<Vec<Labels>, Error> {
assert_eq!(keys.names(), ["center_type"]);
let mut samples = Vec::new();
for [center_type] in keys.iter_fixed_size() {
Expand All @@ -75,7 +75,7 @@ impl CalculatorBase for DummyCalculator {
}
}

fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result<Vec<Labels>, Error> {
debug_assert_eq!(keys.count(), samples.len());
let mut gradient_samples = Vec::new();
for ([center_type], samples) in keys.iter_fixed_size().zip(samples) {
Expand Down Expand Up @@ -110,7 +110,7 @@ impl CalculatorBase for DummyCalculator {
}

#[time_graph::instrument(name = "DummyCalculator::compute")]
fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> {
if self.name.contains("log-test-info:") {
info!("{}", self.name);
} else if self.name.contains("log-test-warn:") {
Expand Down
12 changes: 6 additions & 6 deletions rascaline/src/calculators/lode/spherical_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ impl LodeSphericalExpansion {
/// By symmetry, this only affects the (l, m) = (0, 0) components of the
/// projection coefficients and only the neighbor type blocks that agrees
/// with the center atom.
fn do_center_contribution(&mut self, systems: &mut[Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
fn do_center_contribution(&mut self, systems: &mut[System], descriptor: &mut TensorMap) -> Result<(), Error> {
let mut radial_integral = self.radial_integral.get_or(|| {
let radial_integral = LodeRadialIntegralCache::new(
self.parameters.radial_basis.clone(),
Expand Down Expand Up @@ -470,7 +470,7 @@ impl CalculatorBase for LodeSphericalExpansion {
&[]
}

fn keys(&self, systems: &mut [Box<dyn System>]) -> Result<Labels, Error> {
fn keys(&self, systems: &mut [System]) -> Result<Labels, Error> {
let builder = AllTypesPairsKeys {};
let keys = builder.keys(systems)?;

Expand All @@ -488,7 +488,7 @@ impl CalculatorBase for LodeSphericalExpansion {
LongRangeSamplesPerAtom::sample_names()
}

fn samples(&self, keys: &Labels, systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result<Vec<Labels>, Error> {
assert_eq!(keys.names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]);

// only compute the samples once for each `center_type, neighbor_type`,
Expand Down Expand Up @@ -527,7 +527,7 @@ impl CalculatorBase for LodeSphericalExpansion {
}
}

fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result<Vec<Labels>, Error> {
assert_eq!(keys.names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]);
assert_eq!(keys.count(), samples.len());

Expand Down Expand Up @@ -588,7 +588,7 @@ impl CalculatorBase for LodeSphericalExpansion {
}

#[time_graph::instrument(name = "LodeSphericalExpansion::compute")]
fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error> {
assert_eq!(descriptor.keys().names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]);

self.do_center_contribution(systems, descriptor)?;
Expand Down Expand Up @@ -949,7 +949,7 @@ mod tests {
]);

crate::calculators::tests_utils::compute_partial(
calculator, &mut [Box::new(system)], &keys, &samples, &properties
calculator, &mut [System::new(system)], &keys, &samples, &properties
);
}

Expand Down
8 changes: 4 additions & 4 deletions rascaline/src/calculators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ pub trait CalculatorBase: std::panic::RefUnwindSafe {
fn cutoffs(&self) -> &[f64];

/// Get the set of keys for this calculator and the given systems
fn keys(&self, systems: &mut [Box<dyn System>]) -> Result<Labels, Error>;
fn keys(&self, systems: &mut [System]) -> Result<Labels, Error>;

/// Get the names used for sample labels by this calculator
fn sample_names(&self) -> Vec<&str>;

/// Get the full list of samples this calculator would create for the given
/// systems. This function should return one set of samples for each key.
fn samples(&self, keys: &Labels, systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error>;
fn samples(&self, keys: &Labels, systems: &mut [System]) -> Result<Vec<Labels>, Error>;

/// Can this calculator compute gradients with respect to the `parameter`?
/// Right now, `parameter` can be either `"positions"`, `"strain"` or
Expand All @@ -43,7 +43,7 @@ pub trait CalculatorBase: std::panic::RefUnwindSafe {
///
/// If the gradients with respect to positions are not available, this
/// function should return an error.
fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error>;
fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [System]) -> Result<Vec<Labels>, Error>;

/// Get the components this calculator computes for each key.
fn components(&self, keys: &Labels) -> Vec<Vec<Labels>>;
Expand All @@ -66,7 +66,7 @@ pub trait CalculatorBase: std::panic::RefUnwindSafe {
/// block if they are supported according to
/// [`CalculatorBase::supports_gradient`], and the users requested them as
/// part of the calculation options.
fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error>;
fn compute(&mut self, systems: &mut [System], descriptor: &mut TensorMap) -> Result<(), Error>;
}


Expand Down
Loading

0 comments on commit b60fbc5

Please sign in to comment.