Skip to content

Commit

Permalink
Get changes from mpcdf_raven development branch (#15)
Browse files Browse the repository at this point in the history
* first edits towards Kipf and Welling GCN.

* addition to ptable plot
now also plots number of compounds that LDAU correction has been applied

* add gpu test scripts

* work on gcn

* fix node update function
it is only a simple transformation of the nodes in context of the GCN

* dividing different modules into own files
the models module is now more structured. TODO tests

* fix circular imports in new modules

* changed utilities name to mlp

* add loading module
interfaces train/inference with different models

* updated configs

* run new grid points for egap model

* started rest ef grid and cleaned up crossval plot

* improve crossval boxplots

* new runs in job_scripts

* update of crossval scripts and mcd plotting

* added new script for ensemble predictions

* improved plots for ensemble_error

* add extra printing to classify.py

* improved error_analysis regression plots

* add printing of dataset STDEV for plotting scripts

* work on error calibration plots

* adjustments to crossval plotting

* change some visuals in plotting scripts

* small visual improvements to plots

* cleanup

* calculate pearson correlation for UQ

* updated requirements

* updated plots and added seaborn as requirement:
-larger default font and tick sizes
-added histogram of errors in error analysis
-added printing pearson correlation to mc_error

* job script: go back to standard pbj training

* small changes to axis labels and legend

* implement schnet model
as copy of mpeu without edge updates

* make methods in model modules protected
also includes making megnet model and gcn model independent of mpeu

* start work on test for GCN

* finish test for gcn node updates

* update data pulling for all aflow ef's

* fix dataframe append in plot crossval

* change logging of data conversion
and print validation errors in error analysis

* increase tick size in crossval plots

* cleaned up error_analysis
additional flag allows to choose single plots

* improvements on data conversion
preparation for running on big aflow dataset

* added batch size to training evaluater

* fix inference function
with using batch size from config and update evaluation job script

* updated most tests

* Deleted evaluation.py and inference_file function
added test for get_predictions

* udpate configs to include model name string

* added schnet to model loading
and added configs for schnet ef and egap models

* fixed naming of layers in schnet,
fixes compatibility with models trained in the old schnet branch

---------

Co-authored-by: dts <[email protected]>
  • Loading branch information
tisabe and dts authored Oct 10, 2023
1 parent 1eb6bec commit ab211b0
Show file tree
Hide file tree
Showing 53 changed files with 2,282 additions and 680 deletions.
7 changes: 7 additions & 0 deletions jax_GPU.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/mpcdf/soft/SLE_15/packages/x86_64/anaconda/3/2020.02/bin/python3.7

import jax

devices = jax.local_devices()

print(devices)
26 changes: 26 additions & 0 deletions job_scripts/jax_gpu_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash -l
# Standard output and error:
#SBATCH -o ./output_slurm/singlejob.%j.out
#SBATCH -e ./output_slurm/singlejob.%j.err
# Initial working directory:
#SBATCH -D ./
# Job name
#SBATCH -J egap_pbj
#
#SBATCH --nodes=1 # Request 1 or more full nodes
#SBATCH --constraint="gpu" # Request a GPU node
#SBATCH --gres=gpu:a100:4 # Use all a100 GPU on a node
#SBATCH --cpus-per-task=10
#SBATCH --ntasks-per-core=1
#SBATCH --mem=32000 # Request 32 GB of main memory per node in MB units.
#SBATCH --mail-type=none
#SBATCH [email protected]
#SBATCH --time=12:00:00

# load the environment with modules and python packages
cd ~/envs ; source ~/envs/activate_jax.sh
cd ~/jraph_MPEU

export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK}

srun python jax_GPU.py >> text.out
12 changes: 6 additions & 6 deletions job_scripts_mpcdf/job_csv_to_ase.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash -l
# Standard output and error:
#SBATCH -o ./output_slurm/datajob.out.%j
#SBATCH -e ./output_slurm/datajob.err.%j
#SBATCH -o ./output_slurm/datajob.%j.out
#SBATCH -e ./output_slurm/datajob.%j.err
# Initial working directory:
#SBATCH -D ./
# Job name
Expand All @@ -10,10 +10,10 @@
#SBATCH --nodes=1 # Request 1 or more full nodes
#SBATCH --cpus-per-task=10
#SBATCH --ntasks-per-core=1
#SBATCH --mem=32000 # Request 32 GB of main memory per node in MB units.
#SBATCH --mem=16000 # Request main memory per node in MB units.
#SBATCH --mail-type=none
#SBATCH [email protected]
#SBATCH --time=1:00:00 # 1 hour
#SBATCH --time=24:00:00 # time limit in hours

# load the environment with modules and python packages
cd ~/envs ; source ~/envs/activate_jax.sh
Expand All @@ -22,5 +22,5 @@ cd ~/jraph_MPEU
export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK}

srun python scripts/data/aflow_to_graphs.py \
-f aflow/egaps_eform_all.csv -o aflow/graphs_all_24knn.db -cutoff_type knearest \
-cutoff 24.0
--file_in=aflow/eform_all.csv --file_out=aflow/eform_all_graphs_1.db --cutoff_type=knearest \
--cutoff=12.0
29 changes: 29 additions & 0 deletions job_scripts_mpcdf/job_evaluate_batch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/bin/bash -l
# specify the indexes (max. 30000) of the job array elements (max. 300 - the default job submit limit per user)
#SBATCH --array=183,17,85,236,231,53,458,470,140,409
# Standard output and error:
#SBATCH -o ./output_slurm/eval_%A_%a.out
#SBATCH -e ./output_slurm/eval_%A_%a.err
# Initial working directory:
#SBATCH -D ./
# Job name
#SBATCH -J eval_batch
#
#SBATCH --nodes=1 # Request 1 or more full nodes
#SBATCH --constraint="gpu" # Request a GPU node
#SBATCH --gres=gpu:a100:1 # Use one a100 GPU
#SBATCH --cpus-per-task=10
#SBATCH --ntasks-per-core=1
#SBATCH --mem=32000 # Request 32 GB of main memory per node in MB units.
#SBATCH --mail-type=none
#SBATCH [email protected]
#SBATCH --time=12:00:00 # 12h should be enough for any configuration

# load the environment with modules and python packages
cd ~/envs ; source ~/envs/activate_jax.sh
cd ~/jraph_MPEU

export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK}

srun python scripts/plotting/error_analysis.py \
--file=./results/aflow/ef_rand_search/id${SLURM_ARRAY_TASK_ID} \
27 changes: 27 additions & 0 deletions job_scripts_mpcdf/job_evaluate_single.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash -l
# Standard output and error:
#SBATCH -o ./output_slurm/evaljob.%j.out
#SBATCH -e ./output_slurm/evaljob.%j.err
# Initial working directory:
#SBATCH -D ./
# Job name
#SBATCH -J eval
#
#SBATCH --nodes=1 # Request 1 or more full nodes
#SBATCH --constraint="gpu" # Request a GPU node
#SBATCH --gres=gpu:a100:1 # Use one a100 GPU
#SBATCH --cpus-per-task=10
#SBATCH --ntasks-per-core=1
#SBATCH --mem=32000 # Request 32 GB of main memory per node in MB units.
#SBATCH --mail-type=none
#SBATCH [email protected]
#SBATCH --time=12:00:00

# load the environment with modules and python packages
cd ~/envs ; source ~/envs/activate_jax.sh
cd ~/jraph_MPEU

export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK}

srun python scripts/plotting/error_analysis.py \
--file=results/aflow/ef_full_data/ --label=ef --plot=nothing
4 changes: 2 additions & 2 deletions job_scripts_mpcdf/job_rand_search.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash -l
# specify the indexes (max. 30000) of the job array elements (max. 300 - the default job submit limit per user)
#SBATCH --array=101-200%50
#SBATCH --array=12
# Standard output and error:
#SBATCH -o ./output_slurm/job_%A_%a.out
#SBATCH -e ./output_slurm/job_%A_%a.err
Expand Down Expand Up @@ -29,4 +29,4 @@ srun python scripts/crossval/crossval_mc.py \
--workdir=./results/aflow/egap_rand_search/id${SLURM_ARRAY_TASK_ID} \
--config=jraph_MPEU_configs/aflow_rand_search_egap.py \
--index=${SLURM_ARRAY_TASK_ID} \
--split_file=./results/aflow/egap_rand_search/splits_ins.json
# --split_file=./results/aflow/ef_rand_search/splits_ins.json
6 changes: 3 additions & 3 deletions job_scripts_mpcdf/job_train_single.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Initial working directory:
#SBATCH -D ./
# Job name
#SBATCH -J egap_pbj
#SBATCH -J ef_schnet
#
#SBATCH --nodes=1 # Request 1 or more full nodes
#SBATCH --constraint="gpu" # Request a GPU node
Expand All @@ -24,5 +24,5 @@ cd ~/jraph_MPEU
export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK}

srun python scripts/main.py \
--workdir=./results/aflow/egap_pbj \
--config=jraph_MPEU_configs/aflow_egap_pbj.py
--workdir=./results/aflow/ef_schnet_small_new \
--config=jraph_MPEU_configs/aflow_ef_schnet.py
58 changes: 8 additions & 50 deletions jraph_MPEU/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
from jraph_MPEU.utils import (
get_valid_mask, load_config, normalize_targets_dict, scale_targets
)
from jraph_MPEU.models import load_model
from jraph_MPEU.models.loading import load_model


def get_predictions(dataset, net, params, hk_state, label_type, mc_dropout=False):
def get_predictions(dataset, net, params, hk_state, label_type,
mc_dropout=False, batch_size=32):
"""Get predictions for a single dataset split.
Args:
Expand Down Expand Up @@ -52,7 +53,7 @@ def predict_batch(graphs, rng, hk_state):
preds = []
for _ in range(n_samples):
reader = DataReader(
data=dataset, batch_size=32, repeat=False)
data=dataset, batch_size=batch_size, repeat=False)
preds_sample = np.array([])
for graph in reader:
key, subkey = jax.random.split(key)
Expand All @@ -76,51 +77,6 @@ def predict_batch(graphs, rng, hk_state):
return preds


def load_inference_file(workdir, redo=False):
"""Return the inferences of the model and data defined in workdir.
This function finds inferences that have already been saved in the working
directory. If a file with inferences has been found, they are loaded and
returned in a dictionary with splits as keys.
If there is no file with inferences in workdir or 'redo' is true, the model
is loaded and inferences are calculated.
"""
config = load_config(workdir)
inference_dict = {}
path = workdir + '/inferences.pkl'
if not os.path.exists(path) or redo:
# compute the inferences
logging.info('Loading model.')
net, params, hk_state = load_model(workdir, is_training=False)
logging.info('Loading datasets.')
dataset, mean, std = load_data(workdir)
splits = dataset.keys()
print(splits)

for split in splits:
data_list = dataset[split]
logging.info(f'Predicting {split} data.')
preds = get_predictions(
data_list, net, params, hk_state, config.label_type)
targets = [graph.globals[0] for graph in data_list]
# scale the predictions and targets using the std
preds = np.array(preds)*float(std) + mean
targets = np.array(targets)*float(std) + mean

inference_dict[split] = {}
inference_dict[split]['preds'] = preds
inference_dict[split]['targets'] = targets

with open(path, 'wb') as inference_file:
pickle.dump(inference_dict, inference_file)
else:
# load inferences from dict
logging.info('Loading existing inference.')
with open(path, 'rb') as inference_file:
inference_dict = pickle.load(inference_file)
return inference_dict


def get_results_df(workdir, limit=None, mc_dropout=False):
"""Return a pandas dataframe with predictions and their database entries.
Expand Down Expand Up @@ -162,7 +118,8 @@ def get_results_df(workdir, limit=None, mc_dropout=False):
row_dict['numbers'] = row.numbers # get atomic numbers, when loading
# the csv from file, this has to be converted from string to list
row_dict['formula'] = row.formula
inference_df = inference_df.append(row_dict, ignore_index=True)
inference_df = pandas.concat(
[inference_df, pandas.DataFrame([row_dict])], ignore_index=True)
# Normalize graphs and targets
# Convert the atomic numbers in nodes to classes and set number of classes.
num_path = os.path.join(workdir, 'atomic_num_list.json')
Expand All @@ -181,7 +138,8 @@ def get_results_df(workdir, limit=None, mc_dropout=False):

logging.info('Predicting on dataset.')
preds = get_predictions(
graphs, net, params, hk_state, config.label_type, mc_dropout)
graphs, net, params, hk_state, config.label_type, mc_dropout,
config.batch_size)
if config.label_type == 'scalar':
# scale the predictions using the std and mean
logging.debug(f'using {pooling} pooling function')
Expand Down
4 changes: 4 additions & 0 deletions jraph_MPEU/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,12 @@ def get_graph_cutoff(atoms: Atoms, cutoff):
for i in range(len(atoms)):
nodes.append(atom_numbers[i])

# Loop over the atoms in the unit cell.
for i in range(len(atoms)):
# Get the neighbourhoods of atom i
neighbor_indices, offset = neighborhood.get_neighbors(i)
# Loop over the neighbours of atom i. Offset helps us calculate the
# distance to atoms in neighbouring unit cells.
for j, offs in zip(neighbor_indices, offset):
i_pos = atom_positions[i]
j_pos = atom_positions[j] + np.dot(offs, unitcell)
Expand Down
Empty file added jraph_MPEU/models/__init__.py
Empty file.
Loading

0 comments on commit ab211b0

Please sign in to comment.