No negative volumes in rare cases
Linux-cpp-lisp committed Jun 28, 2023
1 parent 2f43aa8 commit 0b02c41
Showing 2 changed files with 7 additions and 5 deletions.

CHANGELOG.md: 1 change (1 addition & 0 deletions)

@@ -38,6 +38,7 @@ Most recent change on the bottom.
 - Work with `wandb>=0.13.8`
 - Better error for standard deviation with too few data
 - `load_model_state` GPU -> CPU
+- No negative volumes in rare cases
 
 ### Removed
 - [Breaking] `fixed_fields` machinery (`npz_fixed_field_keys` is still supported, but through a more straightforward implementation)

nequip/nn/_grad_output.py: 11 changes (6 additions & 5 deletions)

@@ -330,12 +330,13 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
             # ^ can only scale by cell volume if we have one...:
             # Rescale stress tensor
             # See https://github.com/atomistic-machine-learning/schnetpack/blob/master/src/schnetpack/atomistic/output_modules.py#L180
+            # See also https://en.wikipedia.org/wiki/Triple_product
+            # See also https://gitlab.com/ase/ase/-/blob/master/ase/cell.py,
+            # which uses np.abs(np.linalg.det(cell))
             # First dim is batch, second is vec, third is xyz
-            volume = torch.einsum(
-                "zi,zi->z",
-                cell[:, 0, :],
-                torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
-            ).unsqueeze(-1)
+            # Note the .abs(), since volume should always be positive
+            # det is equal to a dot (b cross c)
+            volume = torch.linalg.det(cell).abs().unsqueeze(-1)
             stress = virial / volume.view(num_batch, 1, 1)
             data[AtomicDataDict.CELL_KEY] = orig_cell
         else:
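
For context (not part of the commit): a minimal, self-contained PyTorch sketch of why the change matters. The scalar triple product a · (b × c) used previously is a signed volume, so a cell whose lattice vectors form a left-handed set yields a negative value, which would flip the sign of the stress when the virial is divided by it; `torch.linalg.det(cell).abs()` has the same magnitude but is always positive. The example cell below is hypothetical, with `cell` shaped `(num_batch, 3, 3)` as described in the diff's comments.

    import torch

    # Hypothetical left-handed cell (not taken from nequip): rows are the
    # lattice vectors a, b, c; the leading dim is the batch.
    cell = torch.tensor(
        [[[1.0, 0.0, 0.0],
          [0.0, 1.0, 0.0],
          [0.0, 0.0, -2.0]]]
    )

    # Old approach: signed volume from the scalar triple product a . (b x c)
    signed_volume = torch.einsum(
        "zi,zi->z",
        cell[:, 0, :],
        torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
    ).unsqueeze(-1)

    # New approach: |det(cell)| equals |a . (b x c)| and can never be negative
    volume = torch.linalg.det(cell).abs().unsqueeze(-1)

    print(signed_volume)  # tensor([[-2.]])
    print(volume)         # tensor([[2.]])

In the actual module the stress is still computed as `virial / volume.view(num_batch, 1, 1)`, which this commit leaves unchanged; only the sign handling of the volume differs.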
