
E(n) Equivariant Graph Neural Networks

Official implementation (PyTorch 1.7.1) of:

E(n) Equivariant Graph Neural Networks
Victor Garcia Satorras, Emiel Hoogeboom, Max Welling
https://arxiv.org/abs/2102.09844

Abstract: This paper introduces a new model to learn graph neural networks equivariant to rotations, translations, reflections and permutations called E(n)-Equivariant Graph Neural Networks (EGNNs). In contrast with existing methods, our work does not require computationally expensive higher-order representations in intermediate layers while it still achieves competitive or better performance. In addition, whereas existing methods are limited to equivariance on 3 dimensional spaces, our model is easily scaled to higher-dimensional spaces. We demonstrate the effectiveness of our method on dynamical systems modelling, representation learning in graph autoencoders and predicting molecular properties.
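To illustrate what equivariance means here, the following plain-Python sketch applies a simplified coordinate update, x_i + mean over j of (x_i - x_j) * phi(||x_i - x_j||^2), to a small 2D point set. It then checks that rotating and translating the input first gives the same result as rotating and translating the output. This is only an illustration under stated assumptions: the fixed function `phi` stands in for the learned networks of the real model, and the helper names `coord_update` and `rotate_translate` are hypothetical and not part of this repository.

```python
import math


def coord_update(xs, phi=lambda d2: 1.0 / (1.0 + d2)):
    """Simplified update: x_i + mean_j (x_i - x_j) * phi(|x_i - x_j|^2).

    `phi` is a fixed stand-in (assumption) for a learned scalar function.
    """
    n = len(xs)
    out = []
    for i, xi in enumerate(xs):
        shift = [0.0] * len(xi)
        for j, xj in enumerate(xs):
            if i == j:
                continue
            diff = [a - b for a, b in zip(xi, xj)]   # rotates with the input
            w = phi(sum(d * d for d in diff))        # depends only on distance
            shift = [s + w * d for s, d in zip(shift, diff)]
        out.append([a + s / (n - 1) for a, s in zip(xi, shift)])
    return out


def rotate_translate(xs, theta, t):
    """Rotate 2D points by angle theta, then translate by t."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c * x - s * y + t[0], s * x + c * y + t[1]] for x, y in xs]


points = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
theta, t = 0.7, (2.0, -1.0)

# Transform then update, versus update then transform.
a = coord_update(rotate_translate(points, theta, t))
b = rotate_translate(coord_update(points), theta, t)
max_gap = max(abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))
print("largest coordinate difference:", max_gap)
```

The two results agree up to floating-point error because the update only uses relative positions, which rotate with the input, and distances, which do not change.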

Example code

For a simple example of an EGNN implementation, see models/egnn_clean/egnn_clean.py. To try it, copy that file into your working directory and run:

import egnn_clean as eg
import torch

# Dummy parameters
batch_size = 8
n_nodes = 4
n_feat = 1
x_dim = 3

# Dummy variables h, x and fully connected edges
h = torch.ones(batch_size * n_nodes, n_feat)
x = torch.ones(batch_size * n_nodes, x_dim)
edges, edge_attr = eg.get_edges_batch(n_nodes, batch_size)

# Initialize EGNN
egnn = eg.EGNN(in_node_nf=n_feat, hidden_nf=32, out_node_nf=1, in_edge_nf=1)

# Run EGNN
h, x = egnn(h, x, edges, edge_attr)

If you are using the EGNN in a new application, we recommend checking the description of the EGNN attributes, which covers some upgrades not included in the paper.
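In the example above, h and x stack the nodes of all graphs in the batch along the first dimension (batch_size * n_nodes rows), so edge indices have to be offset per graph. The sketch below shows one way fully connected edges for such a flattened batch could be built with plain Python lists. The helper `fully_connected_edges` is hypothetical, written only for illustration; the actual `get_edges_batch` in egnn_clean.py may differ in details such as return types.

```python
def fully_connected_edges(n_nodes, batch_size):
    """Return (rows, cols) index lists for a batch of fully connected graphs.

    Nodes of graph b occupy indices [b * n_nodes, (b + 1) * n_nodes).
    Self-loops are excluded. Hypothetical helper, not the repository's API.
    """
    rows, cols = [], []
    for b in range(batch_size):
        offset = b * n_nodes
        for i in range(n_nodes):
            for j in range(n_nodes):
                if i != j:
                    rows.append(offset + i)
                    cols.append(offset + j)
    return rows, cols


rows, cols = fully_connected_edges(n_nodes=4, batch_size=8)
print(len(rows), "directed edges")
```

With the dummy parameters used above, every edge connects two distinct nodes of the same graph, and no edge crosses graph boundaries.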

N-body system experiment

Create N-body dataset

cd n_body_system/dataset
python -u generate_dataset.py --num-train 10000 --seed 43 --sufix small

Run experiments

EGNN model

python -u main_nbody.py --exp_name exp_1_egnn_vel --model egnn_vel --max_training_samples 3000 --lr 5e-4

GNN model

python -u main_nbody.py --exp_name exp_1_gnn --model gnn --max_training_samples 3000 --lr 1e-3

Radial Field

python -u main_nbody.py --exp_name exp_1_rf --model rf_vel --n_layers 4 --max_training_samples 3000 --lr 2e-4

Tensor Field Networks

python -u main_nbody.py --exp_name exp_1_tfn --model tfn --max_training_samples 3000 --lr 1e-3 --degree 2 --nf 32

SE3 Transformer

python -u main_nbody.py --exp_name exp_1_se3 --model se3_transformer --max_training_samples 3000 --div 1 --degree 3 --nf 64 --lr 5e-3

N-body system sweep experiment

For the experiment that sweeps over different numbers of training samples, first create a larger training dataset:

cd n_body_system/dataset
python -u generate_dataset.py --num-train 50000 --sample-freq 500

Then you can train on this new partition by adding --dataset nbody to the training commands above. Choose the number of training samples with the argument --max_training_samples <number of training samples>.

E.g., for the EGNN with 10,000 samples:

python -u main_nbody.py --exp_name exp_debug --model egnn_vel --max_training_samples 10000 --lr 5e-4 --dataset nbody
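To run the whole sweep rather than a single point, the commands can be generated in a loop. The sketch below only builds command strings from the flags shown above; the helper `sweep_commands`, the sample counts, and the `exp_sweep_...` naming pattern are assumptions chosen for illustration, not values taken from the paper or the repository.

```python
def sweep_commands(sample_counts, model="egnn_vel", lr="5e-4"):
    """Build one main_nbody.py command per training-set size.

    Hypothetical helper: the experiment naming pattern is an assumption.
    """
    cmds = []
    for n in sample_counts:
        args = [
            "python", "-u", "main_nbody.py",
            "--exp_name", f"exp_sweep_{model}_{n}",
            "--model", model,
            "--max_training_samples", str(n),
            "--lr", lr,
            "--dataset", "nbody",
        ]
        cmds.append(" ".join(args))
    return cmds


# Illustrative sample counts only (assumption).
cmds = sweep_commands([200, 1000, 10000])
for cmd in cmds:
    print(cmd)
```

Each printed line can be pasted into a shell or passed to a job scheduler; adjust `model` and `lr` to match the per-model settings listed earlier.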

Graph Autoencoder experiment

GNN Erdos&Renyi

python -u main_ae.py --exp_name exp1_gnn_erdosrenyi --model ae --dataset erdosrenyinodes_0.25_none --K 8 --emb_nf 8 --noise_dim 0

GNN Community

python -u main_ae.py --exp_name exp1_gnn_community --model ae --dataset community_ours --K 8 --emb_nf 8 --noise_dim 0

Noise-GNN Erdos&Renyi

python -u main_ae.py --exp_name exp1_gnn_noise_erdosrenyi --model ae --dataset erdosrenyinodes_0.25_none --K 8 --emb_nf 8 --noise_dim 1

Noise-GNN Community

python -u main_ae.py --exp_name exp1_noise_gnn_community --model ae --dataset community_ours --K 8 --emb_nf 8 --noise_dim 1

Radial Field Erdos&Renyi

python -u main_ae.py --exp_name exp1_rf_erdosrenyi --model ae_rf --dataset erdosrenyinodes_0.25_none --K 8 --emb_nf 8

Radial Field Community

python -u main_ae.py --exp_name exp1_rf_community --model ae_rf --dataset community_ours --K 8 --emb_nf 8

EGNN Erdos&Renyi

python -u main_ae.py --exp_name exp1_egnn_erdosrenyi --model ae_egnn --dataset erdosrenyinodes_0.25_none --K 8 --emb_nf 8

EGNN Community

python -u main_ae.py --exp_name exp1_egnn_community --model ae_egnn --dataset community_ours --K 8 --emb_nf 8

The following overfit experiments use p=0.2. To change p, replace the 0.2 in the dataset name (e.g. erdosrenyinodes_0.2_overfit) with another value.

GNN Erdos&Renyi overfit

python -u main_ae.py --model ae --dataset erdosrenyinodes_0.2_overfit --epochs 10001 --test_interval 200 --K 16 --emb_nf 16 2>&1 | tee outputs_ae/$EXP.log &

Noise-GNN Erdos&Renyi overfit

python -u main_ae.py --model ae --dataset erdosrenyinodes_0.2_overfit --epochs 10001 --test_interval 200 --noise_dim 1 --K 16 --emb_nf 16 2>&1 | tee outputs_ae/$EXP.log &

EGNN Erdos&Renyi overfit

python -u main_ae.py --model ae_egnn --dataset erdosrenyinodes_0.2_overfit --epochs 10001 --test_interval 200 --K 16 --emb_nf 16 2>&1 | tee outputs_ae/$EXP.log &
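The dataset-name convention for the overfit experiments can be captured in a small helper when trying several values of p. The sketch below builds the dataset name and the corresponding command string from the flags shown above, leaving out the log redirection. The function names `overfit_dataset_name` and `overfit_command` are hypothetical and not part of the repository.

```python
def overfit_dataset_name(p):
    """Dataset name for the overfit experiment with edge probability p."""
    return f"erdosrenyinodes_{p}_overfit"


def overfit_command(model, p, noise_dim=None):
    """Build a main_ae.py overfit command (hypothetical helper).

    `noise_dim` is only added when given, as in the Noise-GNN command above.
    """
    args = [
        "python", "-u", "main_ae.py",
        "--model", model,
        "--dataset", overfit_dataset_name(p),
        "--epochs", "10001",
        "--test_interval", "200",
    ]
    if noise_dim is not None:
        args += ["--noise_dim", str(noise_dim)]
    args += ["--K", "16", "--emb_nf", "16"]
    return " ".join(args)


print(overfit_command("ae_egnn", 0.2))
print(overfit_command("ae", 0.2, noise_dim=1))
```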

QM9 experiment

properties --> [alpha | gap | homo | lumo | mu | Cv | G | H | r2 | U | U0 | zpve]
learning rate --> 1e-3 for [gap, homo, lumo], 5e-4 for the rest

python -u main_qm9.py --num_workers 2 --lr 5e-4 --property alpha --exp_name exp_1_alpha
python -u main_qm9.py --num_workers 2 --lr 1e-3 --property gap --exp_name exp_1_gap
python -u main_qm9.py --num_workers 2 --lr 1e-3 --property homo --exp_name exp_1_homo
python -u main_qm9.py --num_workers 2 --lr 1e-3 --property lumo --exp_name exp_1_lumo
python -u main_qm9.py --num_workers 2 --lr 5e-4 --property mu --exp_name exp_1_mu
python -u main_qm9.py --num_workers 2 --lr 5e-4 --property Cv --exp_name exp_1_Cv
python -u main_qm9.py --num_workers 2 --lr 5e-4 --property G --exp_name exp_1_G
python -u main_qm9.py --num_workers 2 --lr 5e-4 --property H --exp_name exp_1_H
python -u main_qm9.py --num_workers 2 --lr 5e-4 --property r2 --exp_name exp_1_r2
python -u main_qm9.py --num_workers 2 --lr 5e-4 --property U --exp_name exp_1_U
python -u main_qm9.py --num_workers 2 --lr 5e-4 --property U0 --exp_name exp_1_U0
python -u main_qm9.py --num_workers 2 --lr 5e-4 --property zpve --exp_name exp_1_zpve
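The learning-rate rule above (1e-3 for gap, homo and lumo, 5e-4 otherwise) can also be expressed as a small lookup that regenerates the twelve commands. This is a convenience sketch only; the helper name `qm9_command` is hypothetical and not part of the repository.

```python
PROPERTIES = ["alpha", "gap", "homo", "lumo", "mu", "Cv",
              "G", "H", "r2", "U", "U0", "zpve"]
HIGH_LR = {"gap", "homo", "lumo"}  # these use 1e-3, all others 5e-4


def qm9_command(prop, num_workers=2):
    """Build the main_qm9.py command for one property (hypothetical helper)."""
    if prop not in PROPERTIES:
        raise ValueError(f"unknown property: {prop}")
    lr = "1e-3" if prop in HIGH_LR else "5e-4"
    return (f"python -u main_qm9.py --num_workers {num_workers} "
            f"--lr {lr} --property {prop} --exp_name exp_1_{prop}")


commands = [qm9_command(p) for p in PROPERTIES]
for cmd in commands:
    print(cmd)
```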

Acknowledgements

The Robert Bosch GmbH is acknowledged for financial support.
