# FuseNet implementation in PyTorch

This is the PyTorch implementation of FuseNet, developed on top of the pix2pix code.
## Prerequisites

- Linux
- Python 3.7.0
- CPU or NVIDIA GPU + CUDA/cuDNN
## Installation

- Install PyTorch 0.4.1.post2 and its dependencies from http://pytorch.org
- Clone this repo:

```bash
git clone https://github.com/MehmetAygun/fusenet-pytorch
cd fusenet-pytorch
pip install -r requirements.txt
```
## Dataset preparation

### SUNRGBD

- Download and untar the preprocessed SUNRGBD dataset under `./datasets/sunrgbd`

### NYUv2

- Download the dataset and create the training set:

```bash
cd datasets
sh download_nyuv2.sh
python create_training_set.py
```

### ScanNetv2

- Download `scannet_frames_25k` and `scannet_frames_test` under `./datasets/scannet/tasks/`
## Visualization

- To view training errors and loss plots, set `--display_id 1`, run `python -m visdom.server`, and click the URL http://localhost:8097
- Checkpoints are saved under `./checkpoints/sunrgbd/`
## Training and testing

### SUNRGBD

```bash
python train.py --dataroot datasets/sunrgbd --dataset sunrgbd --name sunrgbd
python test.py --dataroot datasets/sunrgbd --dataset sunrgbd --name sunrgbd --epoch 400
```

### NYUv2

```bash
python train.py --dataroot datasets/nyuv2 --dataset nyuv2 --name nyuv2
python test.py --dataroot datasets/nyuv2 --dataset nyuv2 --name nyuv2 --epoch 400
```

### ScanNetv2

```bash
python train.py --dataroot datasets/scannet/tasks/scannet_frames_25k --dataset scannetv2 \
    --name scannetv2
python test.py --dataroot datasets/scannet/tasks/scannet_frames_25k --dataset scannetv2 \
    --name scannetv2 --epoch 260 --phase val
python test.py --dataroot datasets/scannet/tasks/scannet_frames_test --dataset scannetv2 \
    --name scannetv2 --epoch 260 --phase test
```
### Training notes

- We use the training scheme defined in the FuseNet paper
- The loss is class-weighted for the SUNRGBD dataset
- The learning rate is set to 0.01 for the NYUv2 dataset
- Results could likely be improved with a hyper-parameter search
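One common way to build such per-class loss weights is median-frequency balancing; the sketch below is illustrative only (the function name and weighting scheme are assumptions, not necessarily what this repo uses), and the resulting weights would typically be passed to `torch.nn.CrossEntropyLoss(weight=...)`:

```python
import numpy as np

def median_freq_class_weights(pixel_counts):
    """Median-frequency balancing: weight_c = median_freq / freq_c.

    pixel_counts: number of labelled pixels per class over the training set.
    Rare classes end up with weights > 1, frequent ones with weights < 1.
    Illustrative helper, not this repository's actual weighting code.
    """
    counts = np.asarray(pixel_counts, dtype=np.float64)
    freq = counts / counts.sum()       # per-class pixel frequency
    return np.median(freq) / freq

# toy example: class 2 is rare, so it receives the largest weight
weights = median_freq_class_weights([500, 300, 50, 150])
print(np.round(weights, 3))  # -> [0.45  0.75  4.5   1.5 ]
```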
## Results

- Results on the scannetv2-test set (w/o class-weighted loss) can be found here
| Dataset | FuseNet-SF5 (CAFFE) | | | FuseNet-SF5 | | |
| --- | --- | --- | --- | --- | --- | --- |
| | overall | mean | iou | overall | mean | iou |
| sunrgbd | 76.30 | 48.30 | 37.30 | 75.41 | 46.48 | 35.69 |
| nyuv2 | 66.00 | 43.40 | 32.70 | 68.76 | 46.42 | 35.48 |
| scannetv2-val | -- | -- | -- | 76.32 | 55.84 | 44.12 |
| scannetv2-cls_weighted-val | -- | -- | -- | 76.26 | 55.74 | 44.40 |
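The three columns per model are the standard segmentation metrics: overall (global) pixel accuracy, mean per-class accuracy, and mean intersection-over-union. A minimal sketch of how they can be computed from a confusion matrix (the helper name is illustrative, not this repo's evaluation API):

```python
import numpy as np

def segmentation_scores(conf):
    """Compute (overall acc, mean class acc, mean IoU) from a confusion
    matrix where conf[i, j] = pixels of true class i predicted as class j.
    Assumes every class occurs at least once. Illustrative sketch only."""
    conf = np.asarray(conf, dtype=np.float64)
    tp = np.diag(conf)
    overall = tp.sum() / conf.sum()                        # global pixel accuracy
    per_class_acc = tp / conf.sum(axis=1)                  # per-class recall
    iou = tp / (conf.sum(axis=1) + conf.sum(axis=0) - tp)  # TP / (TP + FP + FN)
    return overall, per_class_acc.mean(), iou.mean()

# toy 2-class example
print(segmentation_scores([[8, 2],
                           [1, 9]]))
```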
## Citation

```
@inproceedings{hazirbas16fusenet,
  Title     = {{FuseNet}: Incorporating Depth into Semantic Segmentation via Fusion-Based CNN Architecture},
  Author    = {Hazirbas, Caner and Ma, Lingni and Domokos, Csaba and Cremers, Daniel},
  Booktitle = {Asian Conference on Computer Vision ({ACCV})},
  Year      = {2016},
  Doi       = {10.1007/978-3-319-54181-5_14},
  Url       = {https://github.com/tum-vision/fusenet}
}
```
## Acknowledgments

Code is inspired by pytorch-CycleGAN-and-pix2pix.