Commit

Merge pull request #165 from usnistgov/develop
Add MPtraj mode.
knc6 authored Aug 29, 2024
2 parents 08a2b15 + ea84615 commit 3397aa6
Showing 9 changed files with 152 additions and 51 deletions.
46 changes: 27 additions & 19 deletions README.md
@@ -28,9 +28,11 @@
* [Funding support](#fund)

<a name="intro"></a>
# ALIGNN (Introduction)
# ALIGNN & ALIGNN-FF (Introduction)
The Atomistic Line Graph Neural Network (https://www.nature.com/articles/s41524-021-00650-1) introduces a new graph convolution layer that explicitly models both two- and three-body interactions in atomistic systems. This is achieved by composing two edge-gated graph convolution layers, the first applied to the atomistic line graph *L(g)* (representing triplet interactions) and the second applied to the atomistic bond graph *g* (representing pair interactions).

A unified force-field model, ALIGNN-FF (https://pubs.rsc.org/en/content/articlehtml/2023/dd/d2dd00096b), was developed that can model both structurally and chemically diverse solids with any combination of 89 elements from the periodic table.


![ALIGNN layer schematic](https://github.com/usnistgov/alignn/blob/develop/alignn/tex/schematic_lg.jpg)
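
The composition described above can be sketched in plain PyTorch. This is a minimal illustrative sketch of the update order only (toy gating, no normalization or residual details), not the repository's DGL-based implementation; `ToyEdgeGatedConv`, `toy_alignn_layer`, and the edge-index tensors are hypothetical names introduced here.

```
import torch
import torch.nn as nn


class ToyEdgeGatedConv(nn.Module):
    """Toy edge-gated graph convolution (illustration only)."""

    def __init__(self, dim):
        super().__init__()
        self.src = nn.Linear(dim, dim)
        self.dst = nn.Linear(dim, dim)
        self.edge = nn.Linear(dim, dim)
        self.msg = nn.Linear(dim, dim)

    def forward(self, node_feats, edge_feats, edge_index):
        src, dst = edge_index  # two LongTensors of shape (n_edges,)
        # gate each edge using its endpoints and its own features
        e_new = (
            self.src(node_feats)[src]
            + self.dst(node_feats)[dst]
            + self.edge(edge_feats)
        )
        gate = torch.sigmoid(e_new)
        # gated sum of neighbour messages onto the destination nodes
        agg = torch.zeros_like(node_feats)
        agg.index_add_(0, dst, gate * self.msg(node_feats)[src])
        return torch.relu(node_feats + agg), torch.relu(edge_feats + e_new)


def toy_alignn_layer(conv_on_lg, conv_on_g, x, y, z, g_edges, lg_edges):
    # 1) edge-gated convolution on the line graph L(g):
    #    bond features y act as nodes, triplet features z as edges
    y, z = conv_on_lg(y, z, lg_edges)
    # 2) edge-gated convolution on the bond graph g:
    #    atom features x are nodes, bond features y are edges
    x, y = conv_on_g(x, y, g_edges)
    return x, y, z
```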

@@ -185,35 +187,41 @@ Atomistic line graph neural network-based FF (ALIGNN-FF) can be used to model b
[ASE calculator](https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html) provides an interface to various codes. An example for ALIGNN-FF is given below. Note that there are multiple pretrained ALIGNN-FF models available; here we use the default_path model. As more accurate models are developed, they will be made available as well:

```
from alignn.ff.ff import AlignnAtomwiseCalculator,default_path
model_path = default_path()
calc = AlignnAtomwiseCalculator(path=model_path)
from alignn.ff.ff import (
AlignnAtomwiseCalculator,
default_path,
mptraj_path,
wt01_path,
)
import matplotlib.pyplot as plt
from ase import Atom, Atoms
import time
from ase.build import bulk
import numpy as np
import matplotlib.pyplot as plt
from ase.build import make_supercell
%matplotlib inline
lattice_params = np.linspace(3.5, 3.8)
model_path = default_path()
calc = AlignnAtomwiseCalculator(path=model_path)
t1 = time.time()
# a = 5.43
lattice_params = np.linspace(5.2, 5.6)
fcc_energies = []
ready = True
for a in lattice_params:
atoms = Atoms([Atom('Cu', (0, 0, 0))],
cell=0.5 * a * np.array([[1.0, 1.0, 0.0],
[0.0, 1.0, 1.0],
[1.0, 0.0, 1.0]]),
pbc=True)
atoms = bulk("Si", "diamond", a=a)
atoms.set_tags(np.ones(len(atoms)))
atoms.calc = calc
e = atoms.get_potential_energy()
fcc_energies.append(e)
import matplotlib.pyplot as plt
%matplotlib inline
plt.plot(lattice_params, fcc_energies)
plt.title('1x1x1')
plt.xlabel('Lattice constant ($\AA$)')
plt.ylabel('Total energy (eV)')
t2 = time.time()
print("Time", t2 - t1)
plt.plot(lattice_params, fcc_energies, "-o")
plt.title("Si")
plt.xlabel("Lattice constant ($\AA$)")
plt.ylabel("Total energy (eV)")
plt.show()
```
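
Because this commit also publishes an MPtraj-trained force field, the calculator can be pointed at it through the new `mptraj_path()` helper added in `alignn/ff/ff.py` below. A minimal sketch (the weights are downloaded from figshare on first use):

```
from alignn.ff.ff import AlignnAtomwiseCalculator, mptraj_path
from ase.build import bulk
import numpy as np

# MPtraj-trained ALIGNN-FF model added in this commit
calc = AlignnAtomwiseCalculator(path=mptraj_path())
atoms = bulk("Si", "diamond", a=5.43)
atoms.set_tags(np.ones(len(atoms)))
atoms.calc = calc
print("Energy (eV):", atoms.get_potential_energy())
```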

2 changes: 1 addition & 1 deletion alignn/__init__.py
@@ -1,3 +1,3 @@
"""Version number."""

__version__ = "2024.5.27"
__version__ = "2024.8.27"
16 changes: 15 additions & 1 deletion alignn/ff/all_models_ff.json
@@ -1 +1,15 @@
{"v5.27.2024": "https://figshare.com/ndownloader/files/47286127", "alignnff_fmult": "https://figshare.com/ndownloader/files/41583585", "alignnff_wt10": "https://figshare.com/ndownloader/files/41583594", "alignnff_fd": "https://figshare.com/ndownloader/files/41583582", "alignnff_wt01": "https://figshare.com/ndownloader/files/41583588", "alignnff_wt1": "https://figshare.com/ndownloader/files/41583591", "fmult_mlearn_only": "https://figshare.com/ndownloader/files/41583597", "aff_Oct23": "https://figshare.com/ndownloader/files/42880573", "revised": "https://figshare.com/ndownloader/files/41583600", "scf_fd_top_10_en_42_fmax_600_wt01": "https://figshare.com/ndownloader/files/41967375", "scf_fd_top_10_en_42_fmax_600_wt10": "https://figshare.com/ndownloader/files/41967372"}
{
"v8.29.2024_dft_3d": "https://figshare.com/ndownloader/files/48889834",
"v8.29.2024_mpf": "https://figshare.com/ndownloader/files/48889837",
"v5.27.2024": "https://figshare.com/ndownloader/files/47286127",
"alignnff_fmult": "https://figshare.com/ndownloader/files/41583585",
"alignnff_wt10": "https://figshare.com/ndownloader/files/41583594",
"alignnff_fd": "https://figshare.com/ndownloader/files/41583582",
"alignnff_wt01": "https://figshare.com/ndownloader/files/41583588",
"alignnff_wt1": "https://figshare.com/ndownloader/files/41583591",
"fmult_mlearn_only": "https://figshare.com/ndownloader/files/41583597",
"aff_Oct23": "https://figshare.com/ndownloader/files/42880573",
"revised": "https://figshare.com/ndownloader/files/41583600",
"scf_fd_top_10_en_42_fmax_600_wt01": "https://figshare.com/ndownloader/files/41967375",
"scf_fd_top_10_en_42_fmax_600_wt10": "https://figshare.com/ndownloader/files/41967372"
}
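
The two `v8.29.2024_*` entries are new in this release; `v8.29.2024_mpf` is the MPtraj-trained model exposed through `mptraj_path()` below. A small sketch for inspecting the registered model names from an installed package; the path construction is an assumption about where this JSON ends up after installation:

```
import json
import os

import alignn

# Assumed location of the registry inside an installed alignn package.
json_path = os.path.join(
    os.path.dirname(alignn.__file__), "ff", "all_models_ff.json"
)
with open(json_path) as f:
    models = json.load(f)
print(list(models))  # names accepted by get_figshare_model_ff(model_name=...)
```
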
19 changes: 16 additions & 3 deletions alignn/ff/ff.py
@@ -126,7 +126,8 @@ def get_figshare_model_ff(

def default_path():
"""Get default model path."""
dpath = get_figshare_model_ff(model_name="v5.27.2024")
dpath = get_figshare_model_ff(model_name="v8.29.2024_dft_3d")
# dpath = get_figshare_model_ff(model_name="v5.27.2024")
# dpath = get_figshare_model_ff(model_name="alignnff_wt10")
# dpath = get_figshare_model_ff(model_name="alignnff_fmult")
# print("model_path", dpath)
@@ -141,12 +142,19 @@ def revised_path():


def alignnff_fmult():
"""Get defaukt model path."""
"""Get default model path."""
dpath = get_figshare_model_ff(model_name="alignnff_fmult")
print("model_path", dpath)
return dpath


def mptraj_path():
"""Get MPtraj model path."""
dpath = get_figshare_model_ff(model_name="v8.29.2024_mpf")
print("model_path", dpath)
return dpath


def mlearn_path():
"""Get model trained on mlearn path."""
dpath = get_figshare_model_ff(model_name="fmult_mlearn_only")
@@ -287,6 +295,7 @@ def calculate(self, atoms, properties=None, system_changes=None):
"""Calculate properties."""
j_atoms = ase_to_atoms(atoms)
num_atoms = j_atoms.num_atoms
# g, lg = Graph.atom_dgl_multigraph(
g, lg = Graph.atom_dgl_multigraph(
j_atoms,
neighbor_strategy=self.config["neighbor_strategy"],
@@ -295,7 +304,11 @@
atom_features=self.config["atom_features"],
use_canonize=self.config["use_canonize"],
)
result = self.net((g.to(self.device), lg.to(self.device)))
if self.config["model"]["alignn_layers"] > 0:
# g,lg = g
result = self.net((g.to(self.device), lg.to(self.device)))
else:
result = self.net((g.to(self.device)))
# print ('stress',result["stress"].detach().numpy())
if self.force_mult_natoms:
mult = num_atoms
32 changes: 25 additions & 7 deletions alignn/lmdb_dataset.py
@@ -38,11 +38,12 @@ def prepare_line_graph_batch(
class TorchLMDBDataset(Dataset):
"""Dataset of crystal DGLGraphs using LMDB."""

def __init__(self, lmdb_path="", ids=[]):
def __init__(self, lmdb_path="", line_graph=True, ids=[]):
"""Intitialize with path and ids array."""
super(TorchLMDBDataset, self).__init__()
self.lmdb_path = lmdb_path
self.ids = ids
self.line_graph = line_graph
self.env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
with self.env.begin() as txn:
self.length = txn.stat()["entries"]
@@ -56,8 +57,12 @@ def __getitem__(self, idx):
"""Get sample."""
with self.env.begin() as txn:
serialized_data = txn.get(f"{idx}".encode())
graph, line_graph, label = pk.loads(serialized_data)
return graph, line_graph, label
if self.line_graph:
graph, line_graph, label = pk.loads(serialized_data)
return graph, line_graph, label
else:
graph, label = pk.loads(serialized_data)
return graph, label

def close(self):
"""Close connection."""
@@ -70,7 +75,9 @@ def __del__(self):
@staticmethod
def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]):
"""Dataloader helper to batch graphs cross `samples`."""
# print('samples',samples)
graphs, labels = map(list, zip(*samples))
# graphs, lgs, labels = map(list, zip(*samples))
batched_graph = dgl.batch(graphs)
return batched_graph, torch.tensor(labels)

@@ -113,6 +120,7 @@ def get_torch_dataset(
"""Get Torch Dataset with LMDB."""
vals = np.array([ii[target] for ii in dataset]) # df[target].values
print("data range", np.max(vals), np.min(vals))
print("line_graph", line_graph)
f = open(os.path.join(output_dir, tmp_name + "_data_range"), "w")
line = "Max=" + str(np.max(vals)) + "\n"
f.write(line)
@@ -123,15 +131,18 @@
if os.path.exists(tmp_name) and read_existing:
for idx, (d) in tqdm(enumerate(dataset), total=len(dataset)):
ids.append(d[id_tag])
dat = TorchLMDBDataset(lmdb_path=tmp_name, ids=ids)
dat = TorchLMDBDataset(
lmdb_path=tmp_name, line_graph=line_graph, ids=ids
)
print("Reading dataset", tmp_name)
return dat
ids = []
env = lmdb.open(tmp_name, map_size=int(map_size))
with env.begin(write=True) as txn:
for idx, (d) in tqdm(enumerate(dataset), total=len(dataset)):
ids.append(d[id_tag])
g, lg = Graph.atom_dgl_multigraph(
# g, lg = Graph.atom_dgl_multigraph(
g = Graph.atom_dgl_multigraph(
Atoms.from_dict(d["atoms"]),
cutoff=float(cutoff),
max_neighbors=max_neighbors,
@@ -140,6 +151,8 @@
use_canonize=use_canonize,
cutoff_extra=cutoff_extra,
)
if line_graph:
g, lg = g
label = torch.tensor(d[target]).type(torch.get_default_dtype())
# print('label',label,label.view(-1).long())
if classification:
@@ -165,11 +178,16 @@
).type(torch.get_default_dtype())

# labels.append(label)
serialized_data = pk.dumps((g, lg, label))
if line_graph:
serialized_data = pk.dumps((g, lg, label))
else:
serialized_data = pk.dumps((g, label))
txn.put(f"{idx}".encode(), serialized_data)

env.close()
lmdb_dataset = TorchLMDBDataset(lmdb_path=tmp_name, ids=ids)
lmdb_dataset = TorchLMDBDataset(
lmdb_path=tmp_name, line_graph=line_graph, ids=ids
)
return lmdb_dataset
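
With the new `line_graph` flag, an LMDB store written without line graphs (for instance when training a model without line-graph layers) unpickles two-tuples of (graph, label) instead of three-tuples. A small sketch of reading such a store back; the path and ids below are placeholders:

```
from alignn.lmdb_dataset import TorchLMDBDataset

# "my_lmdb_data" is assumed to have been written by get_torch_dataset(...)
# with line_graph=False, so each record holds (graph, label) only.
dset = TorchLMDBDataset(
    lmdb_path="my_lmdb_data", line_graph=False, ids=["id-0", "id-1"]
)
graph, label = dset[0]  # no line graph in the unpickled sample
```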


39 changes: 32 additions & 7 deletions alignn/models/alignn_atomwise.py
@@ -55,7 +55,9 @@ class ALIGNNAtomWiseConfig(BaseSettings):
add_reverse_forces: bool = False # will make True as default soon
lg_on_fly: bool = False # will make True as default soon
batch_stress: bool = True
multiply_cutoff: bool = False
extra_features: int = 0
exponent: int = 3

class Config:
"""Configure model settings behavior."""
@@ -99,6 +101,14 @@ def cutoff_function_based_edges(r, inner_cutoff=4, exponent=3):
+ c2 * ratio ** (exponent + 1)
+ c3 * ratio ** (exponent + 2)
)
# r_cut = inner_cutoff
# r_on = inner_cutoff+1

# r_sq = r * r
# r_on_sq = r_on * r_on
# r_cut_sq = r_cut * r_cut
# envelope = (r_cut_sq - r_sq)
# ** 2 * (r_cut_sq + 2 * r_sq - 3 * r_on_sq)/ (r_cut_sq - r_on_sq) ** 3
return torch.where(r <= inner_cutoff, envelope, torch.zeros_like(r))
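
Only the tail of the envelope polynomial is visible in this hunk. The sketch below reconstructs the cutoff under the assumption of the standard polynomial-envelope coefficients (chosen so the envelope equals 1 at r = 0 and falls smoothly to 0 at the inner cutoff), just to illustrate the new `exponent` knob; `smooth_cutoff` is a stand-in name, not the repository function:

```
import torch


def smooth_cutoff(r, inner_cutoff=4.0, exponent=3):
    # Polynomial envelope: 1 at r = 0, decays to 0 with zero slope at
    # r = inner_cutoff; coefficients assumed, only the tail is shown above.
    p = exponent
    ratio = r / inner_cutoff
    c1 = -(p + 1) * (p + 2) / 2
    c2 = p * (p + 2)
    c3 = -p * (p + 1) / 2
    envelope = 1 + c1 * ratio**p + c2 * ratio ** (p + 1) + c3 * ratio ** (p + 2)
    return torch.where(r <= inner_cutoff, envelope, torch.zeros_like(r))


print(smooth_cutoff(torch.tensor([0.0, 2.0, 4.0, 6.0])))
# -> roughly tensor([1.0, 0.5, 0.0, 0.0]) with these assumed coefficients
```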


@@ -371,7 +381,6 @@ def forward(
features = g.ndata["extra_features"]
# print('features',features,features.shape)
features = self.extra_feature_embedding(features)

g = g.local_var()
result = {}

@@ -381,6 +390,9 @@
r = g.edata["r"]
if self.config.calculate_gradient:
r.requires_grad_(True)
bondlength = torch.norm(r, dim=1)
# mask = bondlength >= self.config.inner_cutoff
# bondlength[mask]=float(1.1)
if self.config.lg_on_fly and len(self.alignn_layers) > 0:
# re-compute bond angle cosines here to ensure
# the three-body interactions are fully included
@@ -390,13 +402,26 @@
z = self.angle_embedding(lg.edata.pop("h"))

# r = g.edata["r"].clone().detach().requires_grad_(True)
bondlength = torch.norm(r, dim=1)
if self.config.use_cutoff_function:
bondlength = cutoff_function_based_edges(
bondlength, inner_cutoff=self.config.inner_cutoff
)
y = self.edge_embedding(bondlength)

# bondlength = cutoff_function_based_edges(
if self.config.multiply_cutoff:
c_off = cutoff_function_based_edges(
bondlength,
inner_cutoff=self.config.inner_cutoff,
exponent=self.config.exponent,
).unsqueeze(dim=1)

y = self.edge_embedding(bondlength) * c_off
else:
bondlength = cutoff_function_based_edges(
bondlength,
inner_cutoff=self.config.inner_cutoff,
exponent=self.config.exponent,
)
y = self.edge_embedding(bondlength)
else:
y = self.edge_embedding(bondlength)
# y = self.edge_embedding(bondlength)
# ALIGNN updates: update node, edge, triplet features
for alignn_layer in self.alignn_layers:
x, y, z = alignn_layer(g, lg, x, y, z)
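
The new `multiply_cutoff` flag switches from rescaling the bond lengths to multiplying the radial edge embedding by the envelope (the `if self.config.multiply_cutoff:` branch above). A sketch of enabling it through the config; the `ALIGNNAtomWise` model class and any required fields beyond those visible in this diff are assumptions:

```
from alignn.models.alignn_atomwise import ALIGNNAtomWise, ALIGNNAtomWiseConfig

config = ALIGNNAtomWiseConfig(
    name="alignn_atomwise",
    use_cutoff_function=True,
    multiply_cutoff=True,  # scale the edge embedding by the envelope
    inner_cutoff=4,
    exponent=3,
)
model = ALIGNNAtomWise(config)
```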
