Go to `src/models/mlp.py`. `MeanMLP` and `default_HPs` are what you need.
```bash
conda create -n mlp_nn python=3.9
conda activate mlp_nn
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
pip install -r requirements.txt
```
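Optionally, confirm that the CUDA build of PyTorch is visible from the new environment:

```bash
# Optional sanity check: prints the installed torch version and whether CUDA is available.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```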
```bash
DATASETS=('fbirn' 'bsnip' 'cobre' 'abide_869' 'oasis' 'adni' 'hcp' 'ukb' 'ukb_age_bins' 'fbirn_roi' 'abide_roi' 'hcp_roi_752')
MODELS=('mlp' 'lstm' 'pe_transformer' 'milc' 'dice' 'bnt' 'fbnetgen' 'brainnetcnn' 'lr')

for dataset in "${DATASETS[@]}"; do
    for model in "${MODELS[@]}"; do
        PYTHONPATH=. python scripts/run_experiments.py mode=exp dataset=$dataset model=$model prefix=general ++model.default_HP=True
    done
done
```
```bash
DATASETS=('hcp' 'hcp_roi_752' 'hcp_schaefer' 'hcp_non_mni_2' 'hcp_mni_3' 'ukb')
MODELS=('mlp' 'lstm' 'mean_lstm' 'pe_transformer' 'mean_pe_transformer')

for model in "${MODELS[@]}"; do
    PYTHONPATH=. python scripts/run_experiments.py mode=exp dataset='hcp_time' model=$model prefix=additional ++model.default_HP=True
    for dataset in "${DATASETS[@]}"; do
        PYTHONPATH=. python scripts/run_experiments.py mode=exp dataset=$dataset model=$model prefix=additional ++model.default_HP=True
        PYTHONPATH=. python scripts/run_experiments.py mode=exp dataset=$dataset model=$model prefix=additional ++model.default_HP=True permute=Multiple
    done
done
```
Plotting scripts can be found at `scripts/plot_figures.ipynb`.
Data loading scripts rely on fetching the results from WandB. If you set WandB to offline mode while running the experiments, you'll need to load the CSV files from the experiment folders in `assets/logs` instead.
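For the offline case, you can list the logged CSVs straight from the shell; the folder names under `assets/logs` depend on your prefix and run names, so the command below simply shows whatever is there.

```bash
# List every CSV produced by offline runs (directory layout depends on your prefix/runs).
find assets/logs -name '*.csv'
```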
Config options for `scripts/run_experiments.py` (an example run combining them is shown after the list):

- `mode`:
    - `tune` - tune mode: run multiple experiments with different hyperparams
    - `exp` - experiment mode: run experiments with the best hyperparams found in the `tune` mode, or with default hyperparams if `default_HPs` is set to `True`
- `model`: model for the experiment. Models' config files can be found at `src/conf/model`, and their source code is located at `src/models`.
    - `mlp` - our hero, TS model
    - `lstm` - classic LSTM model for classification, TS model (not used in the paper)
    - `mean_lstm` - `lstm` with LSTM output embeddings averaging, TS model
    - `pe_transformer` - BERT-inspired model, uses a transformer encoder, TS model (not used in the paper)
    - `mean_pe_transformer` - `pe_transformer` with encoder output averaging, TS model
    - `dice` - TS model, https://www.sciencedirect.com/science/article/pii/S1053811922008588?via%3Dihub
    - `milc` - TS model, https://arxiv.org/abs/2007.16041
    - `bnt` - FNC model, https://arxiv.org/abs/2210.06681
    - `fbnetgen` - TS+FNC model, https://arxiv.org/abs/2205.12465
    - `brainnetcnn` - FNC model, https://www.sciencedirect.com/science/article/pii/S1053811916305237
    - `lr` - Logistic Regression, FNC model
- `dataset`: dataset for the experiments. Datasets' config files can be found at `src/conf/dataset`, and their loading scripts are located at `src/datasets`.
    - `fbirn` - ICA FBIRN dataset
    - `cobre` - ICA COBRE dataset
    - `bsnip` - ICA BSNIP dataset
    - `abide` - ICA ABIDE dataset (not used in the paper)
    - `abide_869` - ICA ABIDE extended dataset
    - `oasis` - ICA OASIS dataset
    - `adni` - ICA ADNI dataset
    - `hcp` - ICA HCP dataset
    - `ukb` - ICA UKB dataset with `sex` labels
    - `ukb_age_bins` - ICA UKB dataset with `sex X age bins` labels
    - `fbirn_roi` - Schaefer 200 ROIs FBIRN dataset
    - `abide_roi` - Schaefer 200 ROIs ABIDE dataset
    - `hcp_roi_752` - Schaefer 200 ROIs HCP dataset
    - `hcp_non_mni_2` - Desikan-Killiany ROIs HCP dataset in ORIG space
    - `hcp_mni_3` - Desikan-Killiany ROIs HCP dataset in MNI space
    - `hcp_schaefer` - Noisy Schaefer 200 ROIs HCP dataset
    - `hcp_time` - ICA HCP dataset with normal/inverted time direction
- `prefix`: custom prefix for the project
    - default prefix is UTC time
    - appears in the name of the logs directory and the name of the WandB project
    - `exp` mode runs with a custom prefix will use HPs from `tune` mode runs with the same prefix
        - unless `model.default_HP` is set to `True`
- `permute`: whether TS models should be trained on time-reshuffled data
    - set to `permute=Multiple` to permute
- `wandb_silent`: whether the wandb logger should run silently (default: `True`)
- `wandb_offline`: whether the wandb logger should only log results locally (default: `False`)
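For example, a tune-then-experiment cycle for one model and dataset could look like this (the `prefix` value is arbitrary; the command form follows the reproduction scripts above):

```bash
# Search hyperparams for MeanMLP on FBIRN, then rerun the experiment with the
# best HPs found under the same prefix (see the `mode` and `prefix` options above).
PYTHONPATH=. python scripts/run_experiments.py mode=tune dataset=fbirn model=mlp prefix=my_run
PYTHONPATH=. python scripts/run_experiments.py mode=exp dataset=fbirn model=mlp prefix=my_run
```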