From 0b02c41cbd30ef9a2f58d95cc3dd41a8beb0ff5d Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 28 Jun 2023 09:43:26 -0600
Subject: [PATCH] No negative volumes in rare cases

---
 CHANGELOG.md              |  1 +
 nequip/nn/_grad_output.py | 11 ++++++-----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 93b5cb55..24812ad5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -38,6 +38,7 @@ Most recent change on the bottom.
 - Work with `wandb>=0.13.8`
 - Better error for standard deviation with too few data
 - `load_model_state` GPU -> CPU
+- No negative volumes in rare cases
 
 ### Removed
 - [Breaking] `fixed_fields` machinery (`npz_fixed_field_keys` is still supported, but through a more straightforward implementation)

diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py
index c03ec350..ee0ce6f9 100644
--- a/nequip/nn/_grad_output.py
+++ b/nequip/nn/_grad_output.py
@@ -330,12 +330,13 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
             # ^ can only scale by cell volume if we have one...:
             # Rescale stress tensor
             # See https://github.com/atomistic-machine-learning/schnetpack/blob/master/src/schnetpack/atomistic/output_modules.py#L180
+            # See also https://en.wikipedia.org/wiki/Triple_product
+            # See also https://gitlab.com/ase/ase/-/blob/master/ase/cell.py,
+            # which uses np.abs(np.linalg.det(cell))
             # First dim is batch, second is vec, third is xyz
-            volume = torch.einsum(
-                "zi,zi->z",
-                cell[:, 0, :],
-                torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
-            ).unsqueeze(-1)
+            # Note the .abs(), since volume should always be positive
+            # det is equal to a dot (b cross c)
+            volume = torch.linalg.det(cell).abs().unsqueeze(-1)
             stress = virial / volume.view(num_batch, 1, 1)
             data[AtomicDataDict.CELL_KEY] = orig_cell
         else:
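
Commentary (not part of the patch above): a minimal sketch of why the `.abs()` matters.
For a left-handed cell basis the signed triple product a . (b x c), which equals det(cell),
is negative, so the old einsum-based volume could come out negative. The example cell below
is hypothetical and only for illustration; it assumes a PyTorch version that provides
`torch.linalg.det`.

    import torch

    # Hypothetical left-handed cell: orthorhombic lattice vectors, but ordered
    # so that the basis has negative orientation. Batch dimension added to
    # match the (batch, vec, xyz) layout used in the patch.
    cell = torch.tensor(
        [
            [0.0, 4.0, 0.0],
            [3.0, 0.0, 0.0],
            [0.0, 0.0, 5.0],
        ]
    ).unsqueeze(0)  # shape (1, 3, 3)

    # Old approach from the patch: signed triple product a . (b x c) per batch entry.
    signed_volume = torch.einsum(
        "zi,zi->z",
        cell[:, 0, :],
        torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
    )
    print(signed_volume)  # tensor([-60.]) -- negative "volume" for this cell

    # New approach from the patch: determinant with .abs(), as ASE does.
    volume = torch.linalg.det(cell).abs()
    print(volume)  # tensor([60.])

Dividing the virial by the signed value would flip the sign of the stress for such cells,
which is the rare-case bug the patch title refers to; taking the absolute value of the
determinant keeps the volume positive regardless of the handedness of the cell vectors.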