Skip to content

Commit

Permalink
Update NeuralNetwork class
Browse files Browse the repository at this point in the history
- moved the calculation of dxdr from calculate_loss into the normalize function
- modified the calculation of dxdr in the calculate_loss and evaluate functions
  • Loading branch information
VScoldness committed Jan 6, 2024
1 parent 22d182b commit 534edc5
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions pyxtal_ff/models/neuralnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,7 @@ def evaluate(self, data, figname):
_Energy += _e
if self.force_coefficient:
dedx = torch.autograd.grad(_e, _x)[0]
_dxdr = value['dxdr'][element]

tmp = np.zeros([len(value['x'][element]), n_atoms, value['x'][element].shape[1], 3])
for _m in range(n_atoms):
rows = np.where(value['seq'][element][:,1]==_m)[0]
tmp[value['seq'][element][rows, 0], _m, :, :] += _dxdr[rows, :, :]
_Force -= torch.einsum("ik, ijkl->jl", dedx, torch.from_numpy(tmp))
_Force -= torch.einsum("ik, ijkl->jl", dedx, value['dxdr'][element])

if self.stress_coefficient and (data2['group'] in self.stress_group):
if self.force_coefficient is None:
Expand Down Expand Up @@ -495,11 +489,7 @@ def calculate_loss(self, models, batch):

if self.force_coefficient:
dedx[element] = torch.autograd.grad(_energy, _x, create_graph=True)[0]
tmp = np.zeros([len(x[element]), n_atoms, x[element].shape[1], 3])
for _m in range(n_atoms):
rows = np.where(seq[element][:,1]==_m)[0]
tmp[seq[element][rows, 0], _m, :, :] += dxdr[element][rows, :, :]
_force -= torch.einsum("ik, ijkl->jl", dedx[element], torch.from_numpy(tmp))
_force -= torch.einsum("ik, ijkl->jl", dedx[element], dxdr[element])

if self.stress_coefficient and (group in self.stress_group):
if self.force_coefficient is None:
Expand Down Expand Up @@ -1019,6 +1009,20 @@ def normalize(self, data, drange, unit, norm=[0., 1.]):
#d['seq'][element][:, 0] -= torch.min(d['seq'][element][:, 0]) #adjust the initial position
d['seq'][element][:, 0] -= np.min(d['seq'][element][:, 0]) #adjust the initial position

x, dxdr, seq = d['x'], d['dxdr'], d['seq']
n_atoms = sum(len(value) for value in x.values())
new_dxdr = {}
for element in x.keys():
if x[element].nelement() > 0:
tmp = np.zeros([len(x[element]), n_atoms, x[element].shape[1], 3])
for _m in range(n_atoms):
rows = np.where(seq[element][:,1]==_m)[0]
tmp[seq[element][rows, 0], _m, :, :] += dxdr[element][rows, :, :]
new_dxdr[element] = torch.from_numpy(tmp)
else:
new_dxdr[element] = torch.empty((30,3), dtype=torch.float)
d['dxdr'] = new_dxdr

db2[str(i)] = d

db1.close()
Expand Down

0 comments on commit 534edc5

Please sign in to comment.