diff --git a/applications/FLASK/MPNN/MPN.py b/applications/FLASK/MPNN/MPN.py
new file mode 100644
index 00000000000..a71e39a1aff
--- /dev/null
+++ b/applications/FLASK/MPNN/MPN.py
@@ -0,0 +1,155 @@
+import lbann
+from lbann.modules import Module, ChannelwiseFullyConnectedModule
+
+
+class MPNEncoder(Module):
+    """Message-passing encoder for molecular graphs, following the
+    directed-MPNN formulation used by ChemProp."""
+
+    global_count = 0
+
+    def __init__(
+        self,
+        atom_fdim,
+        bond_fdim,
+        hidden_size,
+        activation_func,
+        max_atoms,
+        bias=False,
+        depth=3,
+        name=None,
+    ):
+        MPNEncoder.global_count += 1
+        # For debugging
+        self.name = name if name else "MPNEncoder_{}".format(MPNEncoder.global_count)
+
+        self.atom_fdim = atom_fdim
+        self.bond_fdim = bond_fdim
+        self.max_atoms = max_atoms
+        self.hidden_size = hidden_size
+        self.bias = bias
+        self.depth = depth
+        self.activation_func = activation_func
+
+        # Channelwise fully connected layer: (*, *, bond_fdim) -> (*, *, hidden_size)
+        self.W_i = ChannelwiseFullyConnectedModule(
+            self.hidden_size,
+            bias=self.bias,
+            activation=self.activation_func,
+            name=self.name + "W_i",
+        )
+
+        # Channelwise fully connected layer: (*, *, hidden_size) -> (*, *, hidden_size)
+        self.W_h = ChannelwiseFullyConnectedModule(
+            self.hidden_size,
+            bias=self.bias,
+            activation=self.activation_func,
+            name=self.name + "W_h",
+        )
+        # Channelwise fully connected layer: (*, *, atom_fdim + hidden_size) -> (*, *, hidden_size)
+        self.W_o = ChannelwiseFullyConnectedModule(
+            self.hidden_size,
+            bias=True,
+            activation=self.activation_func,
+            name=self.name + "W_o",
+        )
+
+    def message(
+        self,
+        bond_features,
+        bond2atom_mapping,
+        atom2bond_sources_mapping,
+        atom2bond_target_mapping,
+        bond2revbond_mapping,
+    ):
+        """Run (depth - 1) rounds of directed message passing over the bond graph."""
+        messages = self.W_i(bond_features)
+        for depth in range(self.depth - 1):
+            nei_message = lbann.Gather(messages, atom2bond_target_mapping, axis=0)
+
+            a_message = lbann.Scatter(
+                nei_message,
+                atom2bond_sources_mapping,
+                dims=[self.max_atoms, self.hidden_size],
+                axis=0,
+            )
+
+            bond_message = lbann.Gather(
+                a_message,
+                bond2atom_mapping,
+                axis=0,
+                name=self.name + f"_bond_messages_{depth}",
+            )
+            rev_message = lbann.Gather(
+                messages,
+                bond2revbond_mapping,
+                axis=0,
+                name=self.name + f"_rev_bond_messages_{depth}",
+            )
+
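+            # a_message holds, per atom, the sum of all incoming bond messages
+            # (formed by the Scatter above). Gathering it back onto each
+            # directed bond and subtracting that bond's reverse message leaves
+            # the neighborhood sum excluding the reverse edge, which is the
+            # directed message-passing update used by ChemProp.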
+            messages = lbann.Subtract(bond_message, rev_message)
+            messages = self.W_h(messages)
+
+        return messages
+
+    def aggregate(self, atom_messages, bond_messages, bond2atom_mapping):
+        """Aggregate bond messages onto their atoms and combine them with the
+        raw atom features."""
+        a_messages = lbann.Scatter(
+            bond_messages,
+            bond2atom_mapping,
+            axis=0,
+            dims=[self.max_atoms, self.hidden_size],
+        )
+
+        atoms_hidden = lbann.Concatenation(
+            [atom_messages, a_messages], axis=1, name=self.name + "atom_hidden_concat"
+        )
+        return self.W_o(atoms_hidden)
+
+    def readout(self, atom_encoded_features, graph_mask, num_atoms):
+        """Mean-pool the encoded atom features into a single molecule encoding."""
+        mol_encoding = lbann.Scatter(
+            atom_encoded_features,
+            graph_mask,
+            name=self.name + "graph_scatter",
+            axis=0,
+            dims=[1, self.hidden_size],
+        )
+        num_atoms = lbann.Reshape(num_atoms, dims=[1, 1])
+
+        mol_encoding = lbann.Divide(
+            mol_encoding,
+            lbann.Tessellate(
+                num_atoms,
+                dims=[1, self.hidden_size],
+                name=self.name + "expand_num_nodes",
+            ),
+            name=self.name + "_reduce",
+        )
+        return mol_encoding
+
+    def forward(
+        self,
+        atom_input_features,
+        bond_input_features,
+        atom2bond_sources_mapping,
+        atom2bond_target_mapping,
+        bond2atom_mapping,
+        bond2revbond_mapping,
+        graph_mask,
+        num_atoms,
+    ):
+        """Encode a molecule: message passing, aggregation, then readout."""
+        bond_messages = self.message(
+            bond_input_features,
+            bond2atom_mapping,
+            atom2bond_sources_mapping,
+            atom2bond_target_mapping,
+            bond2revbond_mapping,
+        )
+
+        atom_encoded_features = self.aggregate(
+            atom_input_features, bond_messages, bond2atom_mapping
+        )
+
+        readout = self.readout(atom_encoded_features, graph_mask, num_atoms)
+        return readout
diff --git a/applications/FLASK/MPNN/PrepareDataset.py b/applications/FLASK/MPNN/PrepareDataset.py
new file mode 100644
index 00000000000..538af6cff87
--- /dev/null
+++ b/applications/FLASK/MPNN/PrepareDataset.py
@@ -0,0 +1,79 @@
+from config import DATASET_CONFIG
+from tqdm import tqdm
+import numpy as np
+from chemprop.args import TrainArgs
+from chemprop.features import reset_featurization_parameters
+from chemprop.data import MoleculeDataLoader, utils
+import os.path as osp
+import pickle
+
+
+def retrieve_dual_mapping(atom2bond, ascope):
+    atom_bond_mapping = []
+    for a_start, a_size in ascope:
+        _a2b = atom2bond.narrow(0, a_start, a_size)
+        for row, possible_bonds in enumerate(_a2b):
+            for bond in possible_bonds:
+                ind = bond.item() - 1  # Shift by 1 to account for null nodes
+                if ind >= 0:
+                    atom_bond_mapping.append([row, ind])
+    return np.array(atom_bond_mapping)
+
+
+def PrepareDataset(save_file_name, target_file):
+    data_file = osp.join(DATASET_CONFIG["DATA_DIR"], target_file)
+
+    arguments = [
+        "--data_path",
+        data_file,
+        "--dataset_type",
+        "regression",
+        "--save_dir",
+        "./data/10k_dft_density",
+    ]
+    args = TrainArgs().parse_args(arguments)
+    reset_featurization_parameters()
+    data = utils.get_data(data_file, args=args)
+    # Need to use the data loader as the featurization happens in the dataloader
+    # Only use 1 mol as in LBANN we do not do coalesced batching (yet)
+    dataloader = MoleculeDataLoader(data, batch_size=1)
+    lbann_data = []
+    for mol in tqdm(dataloader):
+        mol_data = {}
+
+        mol_data["target"] = mol.targets()[0][0]
+        mol_data["num_atoms"] = mol.number_of_atoms[0][0]
+        # Multiply by 2 for directional bonds
+        mol_data["num_bonds"] = mol.number_of_bonds[0][0] * 2
+
+        mol_batch = mol.batch_graph()[0]
+        f_atoms, f_bonds, a2b, b2a, b2revb, ascope, bscope = mol_batch.get_components(
+            False
+        )
+
+        # Shift by 1 as we don't use null nodes as in the ChemProp implementation
+        mol_data["atom_features"] = f_atoms[1:].numpy()
+        mol_data["bond_features"] = f_bonds[1:].numpy()
+        dual_graph_mapping = retrieve_dual_mapping(a2b, ascope)
+
+        mol_data["dual_graph_atom2bond_source"] = dual_graph_mapping[:, 0]
+        mol_data["dual_graph_atom2bond_target"] = dual_graph_mapping[:, 1]
+
+        # Subtract 1 to shift the indices
+        mol_data["bond_graph_source"] = b2a[1:].numpy() - 1
+        mol_data["bond_graph_target"] = b2revb[1:].numpy() - 1
+
+        lbann_data.append(mol_data)
+
+    save_file = osp.join(DATASET_CONFIG["DATA_DIR"], save_file_name)
+    with open(save_file, "wb") as f:
+        pickle.dump(lbann_data, f)
+
+
+def main():
+    PrepareDataset("10k_density_lbann.bin", "10k_dft_density_data.csv")
+    PrepareDataset("10k_hof_lbann.bin", "10k_dft_hof_data.csv")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/applications/FLASK/MPNN/README.md b/applications/FLASK/MPNN/README.md
new file mode 100644
index 00000000000..0804ef4bddd
--- /dev/null
+++ b/applications/FLASK/MPNN/README.md
@@ -0,0 +1,38 @@
+# ChemProp on LBANN
+
+## Prepare Dataset (optional)
+
+Follow these steps if you are not on an LBANN system, or if you need to regenerate the data file so that it is ingestible by LBANN.
+
+### Requirements
+
+```
+chemprop
+numpy
+torch
+```
+
+### Generate Data
+
+The data generator reads from and writes to the `DATA_DIR` directory set in `config.py`. Update that entry to read from and write to a custom directory.
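+
+For example, assuming a local copy of the dataset at the hypothetical path `/path/to/CSD10K`, the `DATA_DIR` entry in `config.py` would be changed to:
+
+```
+DATASET_CONFIG: dict = {
+    # ... other fields unchanged ...
+    "DATA_DIR": "/path/to/CSD10K",  # hypothetical local copy of the CSD10K files
+}
+```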
+
+Generate the data by calling:
+
+```
+python PrepareDataset.py
+```
+
+## Run the Trainer
+
+### Hyperparameters
+
+The hyperparameters for the model and training algorithms can be set in `config.py`.
+
+### Run the trainer
+
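+A minimal launch sketch (the scheduler and optimizer flags registered by `lbann.contrib.args` can be listed with `python train.py --help`):
+
+```
+python train.py
+```
+
+This builds the model and data reader from `model.py` and submits the job through `lbann.contrib.launcher.run`.
+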
+### Results
+
diff --git a/applications/FLASK/MPNN/config.py b/applications/FLASK/MPNN/config.py
new file mode 100644
index 00000000000..e6b5cd18fef
--- /dev/null
+++ b/applications/FLASK/MPNN/config.py
@@ -0,0 +1,21 @@
+# Dataset feature defaults
+# In general, don't change these unless using custom data - S.Z.
+
+DATASET_CONFIG: dict = {
+    "MAX_ATOMS": 100,  # The maximum number of atoms in the CSD dataset
+    "MAX_BONDS": 224,  # The maximum number of bonds in the CSD dataset
+    "ATOM_FEATURES": 133,
+    "BOND_FEATURES": 147,
+    "DATA_DIR": "/p/vast1/lbann/datasets/FLASK/CSD10K",
+    "TARGET_FILE": "10k_dft_density_data.csv",  # Change to 10k_dft_hof_data.csv for heat of formation
+}
+
+# Hyperparameters used to set up the trainer and MPN
+# These can be changed freely
+HYPERPARAMETERS_CONFIG: dict = {
+    "HIDDEN_SIZE": 300,
+    "LR": 0.001,
+    "BATCH_SIZE": 64,
+    "EPOCH": 100,
+    "MPN_DEPTH": 3,
+}
diff --git a/applications/FLASK/MPNN/dataset.py b/applications/FLASK/MPNN/dataset.py
new file mode 100644
index 00000000000..fd55e050195
--- /dev/null
+++ b/applications/FLASK/MPNN/dataset.py
@@ -0,0 +1,116 @@
+import pickle
+import numpy as np
+
+
+# Keep these in sync with DATASET_CONFIG in config.py
+MAX_ATOMS = 100  # The maximum number of atoms in the CSD dataset
+MAX_BONDS = 224  # The maximum number of bonds in the CSD dataset
+ATOM_FEATURES = 133
+BOND_FEATURES = 147
+
+SAMPLE_SIZE = (
+    (MAX_ATOMS * ATOM_FEATURES)
+    + (MAX_BONDS * BOND_FEATURES)
+    + 4 * MAX_BONDS
+    + MAX_ATOMS
+    + 2
+)
+
+DATA_DIR = "/p/vast1/lbann/datasets/FLASK/CSD10K/"
+
+with open(DATA_DIR + "10k_density_lbann.bin", "rb") as f:
+    data = pickle.load(f)
+
+train_index = np.load(DATA_DIR + "train_sample_indices.npy")
+valid_index = np.load(DATA_DIR + "valid_sample_indices.npy")
+test_index = np.load(DATA_DIR + "test_sample_indices.npy")
+
+
+def padded_index_array(size, special_ignore_index=-1):
+    padded_indices = np.zeros(size, dtype=np.float32) + special_ignore_index
+    return padded_indices
+
+
+def pad_data_sample(data):
+    """
+    Args:
+        data (dict): Dictionary describing one molecule, with fields 'num_atoms',
+            'num_bonds', 'atom_features', 'bond_features',
+            'dual_graph_atom2bond_source', 'dual_graph_atom2bond_target',
+            'bond_graph_source', 'bond_graph_target', and 'target'
+
+    Returns:
+        (np.array): The padded sample, flattened into a 1-D array of length
+            SAMPLE_SIZE
+    """
+    num_atoms = data["num_atoms"]
+    num_bonds = data["num_bonds"]
+    f_atoms = np.zeros((MAX_ATOMS, ATOM_FEATURES), dtype=np.float32)
+    f_atoms[:num_atoms, :] = data["atom_features"]
+
+    f_bonds = np.zeros((MAX_BONDS, BOND_FEATURES), dtype=np.float32)
+    f_bonds[:num_bonds, :] = data["bond_features"]
+
+    atom2bond_source = padded_index_array(MAX_BONDS)
+    atom2bond_source[:num_bonds] = data["dual_graph_atom2bond_source"]
+
+    atom2bond_target = padded_index_array(MAX_BONDS)
+    atom2bond_target[:num_bonds] = data["dual_graph_atom2bond_target"]
+
+    bond2atom_source = padded_index_array(MAX_BONDS)
+    bond2atom_source[:num_bonds] = data["bond_graph_source"]
+    bond2bond_target = padded_index_array(MAX_BONDS)
+    bond2bond_target[:num_bonds] = data["bond_graph_target"]
+
+    atom_mask = padded_index_array(MAX_ATOMS)
+    atom_mask[:num_atoms] = np.zeros(num_atoms)
+
+    num_atoms = np.array([num_atoms]).astype(np.float32)
+    # Shift and scale the regression target with precomputed dataset constants
+    target = (np.array([data["target"]]).astype(np.float32) + 67.14776709141553) / (
+        108.13423283538837
+    )
+
+    _data_array = [
+        f_atoms.flatten(),
+        f_bonds.flatten(),
+        atom2bond_source.flatten(),
+        atom2bond_target.flatten(),
+        bond2atom_source.flatten(),
+        bond2bond_target.flatten(),
+        atom_mask.flatten(),
+        num_atoms.flatten(),
+        target.flatten(),
+    ]
+
+    flattened_data_array = np.concatenate(_data_array, axis=None)
+    return flattened_data_array
+
+
+def train_sample(index):
+    return pad_data_sample(data[train_index[index]])
+
+
+def validation_sample(index):
+    return pad_data_sample(data[valid_index[index]])
+
+
+def test_sample(index):
+    return pad_data_sample(data[test_index[index]])
+
+
+def train_num_samples():
+    return 8164
+
+
+def validation_num_samples():
+    return 1020
+
+
+def test_num_samples():
+    return 1022
+
+
+def sample_dims():
+    return (SAMPLE_SIZE,)
+
+
+if __name__ == "__main__":
+    print(train_sample(2).shape, sample_dims())
diff --git a/applications/FLASK/MPNN/model.py b/applications/FLASK/MPNN/model.py
new file mode 100644
index 00000000000..5931497e37f
--- /dev/null
+++ b/applications/FLASK/MPNN/model.py
@@ -0,0 +1,176 @@
+import lbann
+from config import DATASET_CONFIG, HYPERPARAMETERS_CONFIG
+from MPN import MPNEncoder
+import os.path as osp
+
+
+def graph_splitter(_input):
+    """
+    Args:
+        _input (lbann.InputLayer): The padded, flattened graph data
+
+    Returns:
+        (lbann.Layer, lbann.Layer, lbann.Layer, lbann.Layer, lbann.Layer,
+         lbann.Layer, lbann.Layer, lbann.Layer, lbann.Layer)
+
+        A 9-tuple with the atom features, bond features, source atom-to-bond
+        graph mapping, target atom-to-bond graph mapping, bond-to-atom
+        mapping, bond-to-bond mapping, graph mask, number of atoms in the
+        molecule, and the target
+    """
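+    # Slice layout of the flattened sample (matches the concatenation order in
+    # dataset.py's pad_data_sample):
+    #   [atom features | bond features | atom2bond source | atom2bond target |
+    #    bond2atom | bond2revbond | atom mask | num_atoms | target]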
+    split_indices = [0]
+
+    max_atoms = DATASET_CONFIG["MAX_ATOMS"]
+    max_bonds = DATASET_CONFIG["MAX_BONDS"]
+    atom_features = DATASET_CONFIG["ATOM_FEATURES"]
+    bond_features = DATASET_CONFIG["BOND_FEATURES"]
+
+    indices_length = max_bonds
+
+    f_atom_size = max_atoms * atom_features
+    split_indices.append(f_atom_size)
+
+    f_bond_size = max_bonds * bond_features
+    split_indices.append(f_bond_size)
+
+    split_indices.append(max_bonds)
+    split_indices.append(max_bonds)
+    split_indices.append(max_bonds)
+    split_indices.append(max_bonds)
+
+    split_indices.append(max_atoms)
+    split_indices.append(1)
+    split_indices.append(1)
+
+    # Convert the per-field sizes into cumulative offsets for lbann.Slice
+    for i in range(1, len(split_indices)):
+        split_indices[i] = split_indices[i] + split_indices[i - 1]
+
+    graph_input = lbann.Slice(_input, axis=0, slice_points=split_indices)
+    f_atoms = lbann.Reshape(
+        lbann.Identity(graph_input), dims=[max_atoms, atom_features]
+    )
+    f_bonds = lbann.Reshape(
+        lbann.Identity(graph_input), dims=[max_bonds, bond_features]
+    )
+    atom2bond_source_mapping = lbann.Reshape(
+        lbann.Identity(graph_input), dims=[indices_length]
+    )
+    atom2bond_target_mapping = lbann.Reshape(
+        lbann.Identity(graph_input), dims=[indices_length]
+    )
+    bond2atom_mapping = lbann.Reshape(
+        lbann.Identity(graph_input), dims=[indices_length]
+    )
+    bond2bond_mapping = lbann.Reshape(
+        lbann.Identity(graph_input), dims=[indices_length]
+    )
+    graph_mask = lbann.Reshape(lbann.Identity(graph_input), dims=[max_atoms])
+    num_atoms = lbann.Reshape(lbann.Identity(graph_input), dims=[1])
+    target = lbann.Reshape(lbann.Identity(graph_input), dims=[1], name="TARGET")
+
+    return (
+        f_atoms,
+        f_bonds,
+        atom2bond_source_mapping,
+        atom2bond_target_mapping,
+        bond2atom_mapping,
+        bond2bond_mapping,
+        graph_mask,
+        num_atoms,
+        target,
+    )
+
+
+def make_model():
+    """
+    Returns:
+        (lbann.Model): LBANN model for a regression target on the CSD10K dataset
+    """
+    _input = lbann.Input(data_field="samples")
+
+    (
+        f_atoms,
+        f_bonds,
+        atom2bond_source_mapping,
+        atom2bond_target_mapping,
+        bond2atom_mapping,
+        bond2bond_mapping,
+        graph_mask,
+        num_atoms,
+        target,
+    ) = graph_splitter(_input)
+
+    encoder = MPNEncoder(
+        atom_fdim=DATASET_CONFIG["ATOM_FEATURES"],
+        bond_fdim=DATASET_CONFIG["BOND_FEATURES"],
+        max_atoms=DATASET_CONFIG["MAX_ATOMS"],
+        hidden_size=HYPERPARAMETERS_CONFIG["HIDDEN_SIZE"],
+        activation_func=lbann.Relu,
+        depth=HYPERPARAMETERS_CONFIG["MPN_DEPTH"],
+    )
+
+    encoded_vec = encoder(
+        f_atoms,
+        f_bonds,
+        atom2bond_source_mapping,
+        atom2bond_target_mapping,
+        bond2atom_mapping,
+        bond2bond_mapping,
+        graph_mask,
+        num_atoms,
+    )
+
+    # Readout layers
+    x = lbann.FullyConnected(
+        encoded_vec,
+        num_neurons=HYPERPARAMETERS_CONFIG["HIDDEN_SIZE"],
+        name="READOUT_Linear_1",
+    )
+    x = lbann.Relu(x, name="READOUT_Activation_1")
+
+    x = lbann.FullyConnected(x, num_neurons=1, name="PREDICTION")
+
+    loss = lbann.MeanSquaredError(x, target)
+
+    layers = lbann.traverse_layer_graph(_input)
+
+    # Callbacks
+    training_output = lbann.CallbackPrint(interval=1, print_global_stat_only=False)
+    gpu_usage = lbann.CallbackGPUMemoryUsage()
+    timer = lbann.CallbackTimer()
+    predictions = lbann.CallbackDumpOutputs(layers="PREDICTION", execution_modes="test")
+
+    targets = lbann.CallbackDumpOutputs(layers="TARGET", execution_modes="test")
+    step_learning_rate = lbann.CallbackStepLearningRate(step=10, amt=0.9)
+    callbacks = [
+        training_output,
+        gpu_usage,
+        timer,
+        predictions,
+        targets,
+        step_learning_rate,
+    ]
+    model = lbann.Model(
+        HYPERPARAMETERS_CONFIG["EPOCH"],
+        layers=layers,
+        objective_function=loss,
+        callbacks=callbacks,
+    )
+    return model
+
+
+def make_data_reader(classname="dataset", sample="sample", num_samples="num_samples"):
+    data_dir = osp.dirname(osp.realpath(__file__))
+    reader = lbann.reader_pb2.DataReader()
+
+    for role in ["train", "validation", "test"]:
+        _reader = reader.reader.add()
+        _reader.name = "python"
+        _reader.role = role
+        _reader.shuffle = True
+        _reader.fraction_of_data_to_use = 1.0
+        _reader.python.module = classname
+        _reader.python.module_dir = data_dir
+        _reader.python.sample_function = f"{role}_{sample}"
+        _reader.python.num_samples_function = f"{role}_{num_samples}"
+        _reader.python.sample_dims_function = "sample_dims"
+
+    return reader
diff --git a/applications/FLASK/MPNN/train.py b/applications/FLASK/MPNN/train.py
new file mode 100644
index 00000000000..21b5ab7b708
--- /dev/null
+++ b/applications/FLASK/MPNN/train.py
@@ -0,0 +1,25 @@
+import lbann
+import lbann.contrib.launcher
+import lbann.contrib.args
+from config import HYPERPARAMETERS_CONFIG
+from model import make_model, make_data_reader
+import argparse
+
+
+desc = "Training an MPNN model using LBANN"
+parser = argparse.ArgumentParser(description=desc)
+lbann.contrib.args.add_scheduler_arguments(parser, "ChemProp")
+lbann.contrib.args.add_optimizer_arguments(parser)
+
+args = parser.parse_args()
+kwargs = lbann.contrib.args.get_scheduler_kwargs(args)
+job_name = args.job_name
+
+model = make_model()
+data_reader = make_data_reader()
+optimizer = lbann.Adam(
+    learn_rate=HYPERPARAMETERS_CONFIG["LR"],
+    beta1=0.9,
+    beta2=0.99,
+    eps=1e-8,
+    adamw_weight_decay=0,
+)
+trainer = lbann.Trainer(mini_batch_size=HYPERPARAMETERS_CONFIG["BATCH_SIZE"])
+
+lbann.contrib.launcher.run(
+    trainer, model, data_reader, optimizer, job_name=job_name, **kwargs
+)