Use idiomatic torch tools to handle dtype and device
frostedoyster committed Feb 7, 2024
1 parent 900153b commit b4c166b
Showing 7 changed files with 35 additions and 68 deletions.
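
For context, the idiomatic pattern this commit adopts is the standard PyTorch one: a module keeps all of its tensor state in parameters and buffers, and the caller moves and casts everything with a single .to() call after construction, instead of threading device=/dtype= arguments through every constructor. Below is a minimal sketch of the idea, using an illustrative TinyCalculator class that is not part of torch_spex:

    import torch

    class TinyCalculator(torch.nn.Module):
        """All tensor state lives in parameters and buffers, so a single .to()
        call moves and casts the whole module."""

        def __init__(self, n_features: int = 8):
            super().__init__()
            # learnable state: follows .to() like any parameter
            self.linear = torch.nn.Linear(n_features, n_features, bias=False)
            # non-learnable state: register_buffer makes this tensor follow .to() as well
            self.register_buffer("offsets", torch.linspace(0.0, 1.0, n_features))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # no manual device/dtype handling needed here
            return self.linear(x + self.offsets)

    # the caller picks device and dtype once, after construction, mirroring
    # SphericalExpansion(self.hypers, self.all_species).to(self.device, self.dtype) in the tests below
    model = TinyCalculator().to(device="cpu", dtype=torch.float64)
    x = torch.zeros(2, 8, dtype=torch.float64)
    print(model(x).dtype)  # torch.float64

This is what the diffs below do throughout: constructors lose their device/dtype keyword arguments, and the tests and internal calculators rely on .to() instead.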
19 changes: 6 additions & 13 deletions tests/test_spherical_expansions.py
@@ -3,7 +3,6 @@
import torch

import metatensor.torch
from metatensor.torch import Labels, TensorBlock, TensorMap
import ase.io

from torch_spex.spherical_expansions import VectorExpansion, SphericalExpansion
@@ -36,8 +35,7 @@ def test_vector_expansion_coeffs(self):
# we need to sort both computed and reference pair expansion coeffs,
# because ase.neighborlist can get different neighborlist order for some reasons
tm_ref = metatensor.torch.sort(tm_ref)
vector_expansion = VectorExpansion(self.hypers, self.all_species,
device=self.device, dtype=self.dtype)
vector_expansion = VectorExpansion(self.hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
tm = metatensor.torch.sort(vector_expansion.forward(**self.batch))
# Default types are float32 so we cannot get higher accuracy than 1e-7.
@@ -54,8 +52,7 @@ def test_vector_expansion_coeffs(self):
def test_spherical_expansion_coeffs(self):
tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-ethanol1_0-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(self.hypers,
self.all_species, device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(self.hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
tm = spherical_expansion_calculator.forward(**self.batch)
# Default types are float32 so we cannot get higher accuracy than 1e-7.
@@ -75,8 +72,7 @@ def test_spherical_expansion_coeffs_alchemical(self):
tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-ethanol1_0-alchemical-seed0-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
torch.manual_seed(0)
spherical_expansion_calculator = SphericalExpansion(hypers, self.all_species,
device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(hypers, self.all_species).to(self.device, self.dtype)
# Because setting seed seems not be enough to get the same initial combination matrix
# as in the reference values, we set the combination matrix manually
with torch.no_grad():
@@ -117,17 +113,15 @@ def test_vector_expansion_coeffs(self):
tm_ref = metatensor.torch.load("tests/data/vector_expansion_coeffs-artificial-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
tm_ref = metatensor.torch.sort(tm_ref)
vector_expansion = VectorExpansion(self.hypers, self.all_species,
device=self.device, dtype=self.dtype)
vector_expansion = VectorExpansion(self.hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
tm = metatensor.torch.sort(vector_expansion.forward(**self.batch))
assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5)

def test_spherical_expansion_coeffs(self):
tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-artificial-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(self.hypers,
self.all_species, device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(self.hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
tm = spherical_expansion_calculator.forward(**self.batch)
# The absolute accuracy is a bit smaller than in the ethanol case
@@ -139,8 +133,7 @@ def test_spherical_expansion_coeffs_artificial(self):
hypers = json.load(f)
tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-artificial-alchemical-seed0-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(hypers, self.all_species,
device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(
torch.tensor(
4 changes: 1 addition & 3 deletions torch_spex/le.py
@@ -26,7 +26,7 @@ def Jn_zeros(n, nt):
return zeros_j


def get_le_spliner(E_max, r_cut, normalize, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None):
def get_le_spliner(E_max, r_cut, normalize):

l_big = 50
n_big = 50
@@ -106,7 +106,5 @@ def laplacian_eigenstate_basis_derivative(index, r):
np.sum(n_max_l),
r_cut,
requested_accuracy=1e-6,
device=device,
dtype=dtype
)

4 changes: 1 addition & 3 deletions torch_spex/physical_LE/physical_LE.py
@@ -21,7 +21,7 @@ def dc(n, x):
return -np.pi*(n+0.5)*np.sin(np.pi*(n+0.5)*x/10.0)/10.0


def get_physical_le_spliner(E_max, r_cut, normalize, device, dtype):
def get_physical_le_spliner(E_max, r_cut, normalize):

l_max = 50
n_max = 50
@@ -115,8 +115,6 @@ def function_for_splining_index_derivative(index, r):
np.sum(n_max_l),
a,
requested_accuracy=1e-6,
dtype=dtype,
device=device
)
print("Number of spline points:", len(spliner.spline_positions))

13 changes: 5 additions & 8 deletions torch_spex/radial_basis.py
@@ -8,9 +8,7 @@

class RadialBasis(torch.nn.Module):

def __init__(self, hypers, all_species,
device:Optional[torch.device] = None,
dtype:Optional[torch.dtype] = None) -> None:
def __init__(self, hypers, all_species) -> None:
super().__init__()

# Only for the physical basis, but initialized for all branches
@@ -28,10 +26,9 @@ def __init__(self, hypers, all_species,
self.is_physical = False

if hypers["type"] == "le":
self.n_max_l, self.spliner = get_le_spliner(hypers["E_max"],
hypers["r_cut"], hypers["normalize"], device=device, dtype=dtype)
self.n_max_l, self.spliner = get_le_spliner(hypers["E_max"], hypers["r_cut"], hypers["normalize"])
elif hypers["type"] == "physical":
self.n_max_l, self.spliner = get_physical_le_spliner(hypers["E_max"], hypers["r_cut"], hypers["normalize"], device=device, dtype=dtype)
self.n_max_l, self.spliner = get_physical_le_spliner(hypers["E_max"], hypers["r_cut"], hypers["normalize"])
self.is_physical = True
elif hypers["type"] == "custom":
# The custom keyword here allows the user to set splines from outside.
@@ -49,8 +46,8 @@ def __init__(self, hypers, all_species,
self.is_alchemical = True
self.n_pseudo_species = hypers["alchemical"]
self.combination_matrix = normalize("embedding",
torch.nn.Linear(len(all_species), self.n_pseudo_species, bias=False,
device=device, dtype=dtype))
torch.nn.Linear(len(all_species), self.n_pseudo_species, bias=False)
)
self.species_neighbor_labels = Labels(
names = ["species_neighbor"],
values = torch.tensor(self.all_species, dtype=torch.int).unsqueeze(1)
30 changes: 12 additions & 18 deletions torch_spex/spherical_expansions.py
@@ -66,7 +66,7 @@ class SphericalExpansion(torch.nn.Module):
>>> dataset = InMemoryDataset([h2o], transformers)
>>> loader = DataLoader(dataset, batch_size=1, collate_fn=collate_nl)
>>> batch = next(iter(loader))
>>> spherical_expansion = SphericalExpansion(hypers, [1, 8], device="cpu").to(torch.float64) #why?BUG
>>> spherical_expansion = SphericalExpansion(hypers, [1, 8])
>>> expansion = spherical_expansion.forward(**batch)
>>> print(expansion.keys)
Labels(
@@ -77,9 +77,7 @@
"""

def __init__(self, hypers: Dict, all_species: List[int],
device: Optional[torch.device] = None,
dtype: Optional[torch.device] = None) -> None:
def __init__(self, hypers: Dict, all_species: List[int]) -> None:
super().__init__()

self.hypers = hypers
@@ -92,8 +90,7 @@ def __init__(self, hypers: Dict, all_species: List[int],
self.normalization_factor = 1.0 # dummy for torchscript
self.normalization_factor_0 = 1.0 # dummy for torchscript
self.all_species = all_species
self.vector_expansion_calculator = VectorExpansion(hypers, self.all_species,
device=device, dtype=dtype)
self.vector_expansion_calculator = VectorExpansion(hypers, self.all_species)

if "alchemical" in self.hypers:
self.is_alchemical = True
@@ -165,7 +162,7 @@ def forward(self,
dtype = expanded_vectors_l.dtype,
device = expanded_vectors_l.device
)
densities_l.index_add_(dim=0, index=density_indices.to(expanded_vectors_l.device), source=expanded_vectors_l)
densities_l.index_add_(dim=0, index=density_indices, source=expanded_vectors_l)
densities_l = densities_l.reshape((n_centers, 2*l+1, -1))
densities.append(densities_l)
unique_species = -torch.arange(self.n_pseudo_species, dtype=torch.int64, device=density_indices.device)
@@ -181,7 +178,7 @@ def forward(self,
dtype = expanded_vectors_l.dtype,
device = expanded_vectors_l.device
)
densities_l.index_add_(dim=0, index=density_indices.to(expanded_vectors_l.device), source=expanded_vectors_l)
densities_l.index_add_(dim=0, index=density_indices, source=expanded_vectors_l)
densities_l = densities_l.reshape((n_centers, n_species, 2*l+1, -1)).swapaxes(1, 2).reshape((n_centers, 2*l+1, -1)) # need to swap n, a indices which are in the wrong order
densities.append(densities_l)
unique_species = torch.tensor(self.all_species, dtype=torch.int, device=species.device)
@@ -264,9 +261,7 @@ class VectorExpansion(torch.nn.Module):
"""

def __init__(self, hypers: Dict, all_species,
device: Optional[torch.device] = None,
dtype: Optional[torch.device] = None) -> None:
def __init__(self, hypers: Dict, all_species) -> None:
super().__init__()

self.hypers = hypers
@@ -282,8 +277,7 @@ def __init__(self, hypers: Dict, all_species,
else:
self.n_pseudo_species = 0 # dummy for torchscript
self.is_alchemical = False
self.radial_basis_calculator = RadialBasis(hypers_radial_basis, all_species,
device=device, dtype=dtype)
self.radial_basis_calculator = RadialBasis(hypers_radial_basis, all_species)
self.l_max = self.radial_basis_calculator.l_max
self.spherical_harmonics_calculator = sphericart.torch.SphericalHarmonics(self.l_max, normalized=True)
self.spherical_harmonics_split_list = [(2*l+1) for l in range(self.l_max+1)]
@@ -369,17 +363,17 @@ def forward(self,
samples = cartesian_vectors.samples,
components = [Labels(
names = ("m",),
values = torch.arange(start=-l, end=l+1, dtype=torch.int32).reshape(2*l+1, 1)
values = torch.arange(start=-l, end=l+1, dtype=torch.int32, device=vector_expansion_l.device).reshape(2*l+1, 1)
)],
properties = properties
properties = properties.to(vector_expansion_l.device)
)
)

l_max = len(vector_expansion_blocks) - 1
vector_expansion_tmap = TensorMap(
keys = Labels(
names = ("l",),
values = torch.arange(start=0, end=l_max+1, dtype=torch.int32).reshape(l_max+1, 1),
values = torch.arange(start=0, end=l_max+1, dtype=torch.int32, device=vector_expansion_blocks[0].values.device).reshape(l_max+1, 1),
),
blocks = vector_expansion_blocks
)
@@ -423,10 +417,10 @@ def get_cartesian_vectors(positions, cells, species, cell_shifts, centers, pairs
components = [
Labels(
names = ["cartesian_dimension"],
values = torch.tensor([-1, 0, 1], dtype=torch.int32).reshape((-1, 1))
values = torch.tensor([-1, 0, 1], dtype=torch.int32, device=direction_vectors.device).reshape((-1, 1))
)
],
properties = Labels.single()
properties = Labels.single().to(direction_vectors.device)
)

return block
31 changes: 9 additions & 22 deletions torch_spex/splines.py
@@ -8,9 +8,7 @@ def generate_splines(
radial_basis_derivatives,
max_index,
cutoff_radius,
requested_accuracy=1e-8,
device: Optional[torch.device]=None,
dtype: Optional[torch.dtype]=None
requested_accuracy=1e-8
):
"""Spline generator for tabulated radial integrals.
@@ -55,8 +53,6 @@ def derivative_evaluator_2D(positions):
value_evaluator_2D,
derivative_evaluator_2D,
requested_accuracy,
device=device,
dtype=dtype
)
return dynamic_spliner

@@ -69,23 +65,18 @@ def __init__(
stop,
values_fn,
derivatives_fn,
requested_accuracy,
device: Optional[torch.device]=None,
dtype: Optional[torch.dtype]=None,
requested_accuracy
) -> None:
super().__init__()

self.start = start
self.stop = stop
self.values_fn = values_fn
self.derivatives_fn = derivatives_fn
self.requested_accuracy = requested_accuracy

# initialize spline with 11 points
positions = torch.linspace(start, stop, 11)
self.spline_positions = positions
self.spline_values = values_fn(positions)
self.spline_derivatives = derivatives_fn(positions)
self.register_buffer("spline_positions", positions)
self.register_buffer("spline_values", values_fn(positions))
self.register_buffer("spline_derivatives", derivatives_fn(positions))

self.number_of_custom_dimensions = len(self.spline_values.shape) - 1

@@ -104,20 +95,20 @@ def __init__(
)

estimated_values = self.compute(intermediate_positions)
new_values = self.values_fn(intermediate_positions)
new_values = values_fn(intermediate_positions)

mean_absolute_error = torch.mean(torch.abs(estimated_values - new_values))
mean_relative_error = torch.mean(
torch.abs((estimated_values - new_values) / new_values)
)

if (
mean_absolute_error < self.requested_accuracy
or mean_relative_error < self.requested_accuracy
mean_absolute_error < requested_accuracy
or mean_relative_error < requested_accuracy
):
break

new_derivatives = self.derivatives_fn(intermediate_positions)
new_derivatives = derivatives_fn(intermediate_positions)

concatenated_positions = torch.cat(
[self.spline_positions, intermediate_positions], dim=0
Expand All @@ -135,10 +126,6 @@ def __init__(
self.spline_values = concatenated_values[sort_indices]
self.spline_derivatives = concatenated_derivatives[sort_indices]

self.spline_positions = self.spline_positions.to(device=device, dtype=dtype)
self.spline_values = self.spline_values.to(device=device, dtype=dtype)
self.spline_derivatives = self.spline_derivatives.to(device=device, dtype=dtype)

def compute(self, positions):
x = positions
delta_x = self.spline_positions[1] - self.spline_positions[0]
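The splines.py hunk above swaps plain tensor attributes for register_buffer calls. A buffer is non-trainable module state: it is saved in state_dict and, unlike a plain attribute, it is moved and cast by .to(), which is what allows the device/dtype arguments to disappear from the spliner. A standalone sketch of the behaviour (not the actual DynamicSpliner):

    import torch

    class BufferedTable(torch.nn.Module):
        def __init__(self):
            super().__init__()
            table = torch.linspace(0.0, 1.0, 5)
            self.plain = table                       # plain attribute: ignored by .to()
            self.register_buffer("buffered", table)  # buffer: cast/moved by .to(), in state_dict

    m = BufferedTable().to(dtype=torch.float64)
    print(m.plain.dtype)                # torch.float32, unchanged
    print(m.buffered.dtype)             # torch.float64, cast by .to()
    print(list(m.state_dict().keys()))  # ['buffered']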
2 changes: 1 addition & 1 deletion torch_spex/structures.py
@@ -15,8 +15,8 @@ def structure_to_torch(structure : AtomicStructure,
:returns:
Tuple of posititions, species, cell and periodic boundary conditions
"""
if dtype is None: dtype = torch.get_default_dtype()
if isinstance(structure, ase.Atoms):
# dtype is automatically referred from the type in the structure object if None
positions = torch.tensor(structure.positions, device=device, dtype=dtype)
species = torch.tensor(structure.numbers, device=device)
cell = torch.tensor(structure.cell.array, device=device, dtype=dtype)
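The structures.py change makes the dtype fallback explicit: with dtype=None, torch.tensor would previously inherit float64 from ASE's numpy arrays, whereas torch.get_default_dtype() returns float32 unless the user has changed the global default. A quick standalone illustration of the difference (a plain numpy array standing in for ASE positions):

    import numpy as np
    import torch

    positions = np.zeros((3, 3))          # ASE stores positions as float64 numpy arrays
    print(torch.tensor(positions).dtype)  # torch.float64, inferred from numpy

    dtype = None
    if dtype is None:
        dtype = torch.get_default_dtype()  # torch.float32 unless changed globally
    print(torch.tensor(positions, dtype=dtype).dtype)  # torch.float32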
