From a2cf1a714acdc2840b71fc21e1700030f6a4bf29 Mon Sep 17 00:00:00 2001 From: Seon-Wook Park Date: Sat, 7 Dec 2019 12:56:34 +0100 Subject: [PATCH] First code commit for open-sourcing Faze --- .gitignore | 2 + .gitmodules | 3 + LICENSE | 49 ++ README.md | 76 ++- preprocess | 1 + requirements.txt | 9 + src/.flake8 | 7 + src/1_train_dt_ed.py | 861 ++++++++++++++++++++++++++++ src/2_meta_learning.py | 733 +++++++++++++++++++++++ src/3_combine_maml_results.py | 127 ++++ src/checkpoints_manager.py | 73 +++ src/data.py | 172 ++++++ src/full_train_test_and_plot.bash | 113 ++++ src/gazecapture_split.json | 1 + src/losses/__init__.py | 18 + src/losses/all_frontals_equal.py | 38 ++ src/losses/batch_hard_triplet.py | 128 +++++ src/losses/embedding_consistency.py | 118 ++++ src/losses/gaze_angular.py | 35 ++ src/losses/gaze_mse.py | 22 + src/losses/reconstruction_l1.py | 21 + src/models/__init__.py | 12 + src/models/densenet.py | 207 +++++++ src/models/dt_ed.py | 386 +++++++++++++ 24 files changed, 3201 insertions(+), 11 deletions(-) create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 LICENSE create mode 160000 preprocess create mode 100644 requirements.txt create mode 100644 src/.flake8 create mode 100644 src/1_train_dt_ed.py create mode 100644 src/2_meta_learning.py create mode 100644 src/3_combine_maml_results.py create mode 100644 src/checkpoints_manager.py create mode 100644 src/data.py create mode 100644 src/full_train_test_and_plot.bash create mode 100644 src/gazecapture_split.json create mode 100644 src/losses/__init__.py create mode 100644 src/losses/all_frontals_equal.py create mode 100644 src/losses/batch_hard_triplet.py create mode 100644 src/losses/embedding_consistency.py create mode 100644 src/losses/gaze_angular.py create mode 100644 src/losses/gaze_mse.py create mode 100644 src/losses/reconstruction_l1.py create mode 100644 src/models/__init__.py create mode 100644 src/models/densenet.py create mode 100644 src/models/dt_ed.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..789c90b --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +src/output* diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..5dc36fb --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "preprocess"] + path = preprocess + url = https://github.com/swook/faze_preprocess diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0c9a4bd --- /dev/null +++ b/LICENSE @@ -0,0 +1,49 @@ +Nvidia Source Code License (1-Way Commercial) + + +1. Definitions + +“Licensor” means any person or entity that distributes its Work. + +“Software” means the original work of authorship made available under this License. + +“Work” means the Software and any additions to or derivative works of the Software that are made available under this License. + +“Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by Nvidia or its affiliates. + +The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 
+ +Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. + + +2. License Grants + +2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. + +2.2 Patent Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free patent license to make, have made, use, sell, offer for sale, import, and otherwise transfer its Work, in whole or in part. The foregoing license applies only to the patent claims licensable by Licensor that would be infringed by Licensor’s Work (or portion thereof) individually and excluding any combinations with any other materials or technology. + + +3. Limitations + +3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. + +3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. + +3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only. + +3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate immediately. + +3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License. + +3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants in Sections 2.1 and 2.2) will terminate immediately. + + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. SOME STATES’ CONSUMER LAWS DO NOT ALLOW EXCLUSION OF AN IMPLIED WARRANTY, SO THIS DISCLAIMER MAY NOT APPLY TO YOU. + + +5. Limitation of Liability.
+ +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + diff --git a/README.md b/README.md index 50c30d8..beaacaa 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,72 @@ -# Faze: Few-Shot Adaptive Gaze Estimation +# FAZE: Few-Shot Adaptive Gaze Estimation -This repository will contain the code for training, evaluation, and live demonstration of our ICCV 2019 work, which was presented as an Oral presentation in Seoul, Korea. Faze is a framework for few-shot adaptation of gaze estimation networks, consisting of equivariance learning (via the **DT-ED** or Disentangling Transforming Encoder-Decoder architecture) and meta-learning with gaze embeddings as input. +This repository contains the code for training and evaluation of our ICCV 2019 work, which was presented as an Oral presentation. FAZE is a framework for few-shot adaptation of gaze estimation networks, consisting of equivariance learning (via the **DT-ED** or Disentangling Transforming Encoder-Decoder architecture) and meta-learning with gaze-direction embeddings as input. + +![The FAZE Framework](https://ait.ethz.ch/projects/2019/faze/banner.jpg) + + +## Links +* [NVIDIA Project Page](https://research.nvidia.com/publication/2019-10_Few-Shot-Adaptive-Gaze) +* [ETH Zurich Project Page](https://ait.ethz.ch/projects/2019/faze/) +* [arXiv Page](https://arxiv.org/abs/1905.01941) +* [CVF Open Access PDF](http://openaccess.thecvf.com/content_ICCV_2019/papers/Park_Few-Shot_Adaptive_Gaze_Estimation_ICCV_2019_paper.pdf) +* [ICCV 2019 Presentation](https://conftube.com/video/ByfFufRhuRc?tocitem=17) +* [Pre-processing Code GitHub Repository](https://github.com/swook/faze_preprocess) _(also included as a submodule in this repository)_ -![The Faze Framework](https://ait.ethz.ch/projects/2019/faze/banner.jpg) ## Setup -Further setup instructions will be made available soon. For now, please pre-process the *GazeCapture* and *MPIIGaze* datasets using the code-base at https://github.com/swook/faze_preprocess -## Additional Resources -* Project Page (ETH Zurich): https://ait.ethz.ch/projects/2019/faze/ -* Project Page (Nvidia): https://research.nvidia.com/publication/2019-10_Few-Shot-Adaptive-Gaze -* arXiv Page: https://arxiv.org/abs/1905.01941 -* CVF Open Access PDF: http://openaccess.thecvf.com/content_ICCV_2019/papers/Park_Few-Shot_Adaptive_Gaze_Estimation_ICCV_2019_paper.pdf -* Pre-processing Code: https://github.com/swook/faze_preprocess +### 1. Datasets + +Pre-process the *GazeCapture* and *MPIIGaze* datasets using the code-base at https://github.com/swook/faze_preprocess, which is also available as a git submodule at the relative path `preprocess/`. + +If you have already cloned this `few_shot_gaze` repository without pulling the submodules, please run: + + git submodule update --init --recursive + +Once the datasets have been pre-processed, you can move on to the next steps. + +### 2. Prerequisites + +This codebase should run on most standard Linux systems.
We specifically used Ubuntu + +Please install the following prerequisites and their dependencies manually, following the instructions linked below: +* PyTorch 1.3 - https://pytorch.org/get-started/locally/ +* NVIDIA Apex - https://github.com/NVIDIA/apex#quick-start + * *Please note that only NVIDIA Volta and newer architectures can benefit from AMP training via NVIDIA Apex.* + +The remaining Python package dependencies can be installed by running: + + pip3 install --user --upgrade -r requirements.txt + +### 3. Pre-trained weights for the DT-ED architecture + +You can obtain a copy of the pre-trained weights for the Disentangling Transforming Encoder-Decoder by running: + + cd src/ + wget -N https://ait.ethz.ch/projects/2019/faze/downloads/outputs_of_full_train_test_and_plot.zip + unzip -o outputs_of_full_train_test_and_plot.zip + +### 4. Training, Meta-Learning, and Final Evaluation + +Run the all-in-one example bash script with: + + cd src/ + bash full_train_test_and_plot.bash + +The bash script should be self-explanatory and can be edited to replicate the final FAZE model evaluation procedure, provided that the hardware requirements are satisfied (8x NVIDIA Tesla V100 GPUs, each with 32 GB of memory). + +The pre-trained DT-ED weights should be loaded automatically by the script `1_train_dt_ed.py`. Please note that this model can take a long time to train from scratch, so we recommend adjusting batch sizes and using multiple GPUs (the code is multi-GPU-ready). + +The Meta-Learning step is also very time-consuming, particularly because it must be run for every value of `k`, the *number of calibration samples*. The code for this step is `2_meta_learning.py`, and we recommend running it in parallel across values of `k`, as shown in `full_train_test_and_plot.bash` and sketched below. + +### 5. Outputs + +When the full pipeline runs successfully, you will find its outputs in the path `src/outputs_of_full_train_test_and_plot`, in particular: +* **walks/**: mp4 videos of latent space walks in gaze direction and head orientation. +* **Zg_OLR1e-03_IN5_ILR1e-05_Net64/**: outputs of the meta-learning step. +* **Zg_OLR1e-03_IN5_ILR1e-05_Net64 MAML MPIIGaze.pdf**: plotted results of the few-shot learning evaluations on MPIIGaze. +* **Zg_OLR1e-03_IN5_ILR1e-05_Net64 MAML GazeCapture (test).pdf**: plotted results of the few-shot learning evaluations on the GazeCapture test set. ## Bibtex Please cite our paper when referencing or using our code. @@ -25,6 +79,6 @@ Please cite our paper when referencing or using our code. location = {Seoul, Korea} } -## Acknowledgements +## Acknowledgements Seonwook Park carried out this work during his internship at Nvidia. This work was supported in part by the ERC Grant OPTINT (StG-2016-717054).
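To make the parallel meta-learning launch described in Section 4 of the README above concrete, here is a minimal sketch that runs `2_meta_learning.py` once per value of `k`, one GPU per job. The predictions directory, the list of `k` values, and the one-GPU-per-job mapping are illustrative assumptions only; the two positional arguments (`input_dir` and `k`) and the `--disable-tqdm` flag are taken from the script's own argument parser shown later in this patch, and `full_train_test_and_plot.bash` remains the reference procedure.

    #!/bin/bash
    # Illustrative sketch only: paths, k values, and GPU mapping are assumptions.
    cd src/
    # Assumed location of the gc_train/gc_val/gc_test/mpi *_predictions.h5 files
    # written by 1_train_dt_ed.py (its --save-path).
    PRED_DIR=outputs_of_full_train_test_and_plot
    K_VALUES=(1 2 4 8 16)   # assumed sweep over calibration set sizes

    for i in "${!K_VALUES[@]}"; do
        k="${K_VALUES[$i]}"
        # input_dir and k are the two positional arguments of 2_meta_learning.py.
        CUDA_VISIBLE_DEVICES="$i" python3 2_meta_learning.py "$PRED_DIR" "$k" \
            --disable-tqdm &
    done
    wait   # block until every k-specific meta-learning job has finished

With the script's default `--lr-outer`, `--lr-inner`, `--tasks-per-meta-iteration`, and `--layer-num-features` settings, each job writes its results under `Zg_OLR1e-03_IN5_ILR1e-05_Net64/` inside the input directory, which matches the output folder names listed in Section 5 of the README.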
diff --git a/preprocess b/preprocess new file mode 160000 index 0000000..5c33caa --- /dev/null +++ b/preprocess @@ -0,0 +1 @@ +Subproject commit 5c33caaa1bc271a8d6aad21837e334108f293683 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3e35445 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +apex +h5py +imageio +moviepy +numpy +opencv_python +torch +torchvision +tqdm diff --git a/src/.flake8 b/src/.flake8 new file mode 100644 index 0000000..95c0e0e --- /dev/null +++ b/src/.flake8 @@ -0,0 +1,7 @@ +[flake8] +doctests = True +enable-extensions = docstrings +ignore = E402, W503 +max-line-length = 100 +statistics = True +show-source = True diff --git a/src/1_train_dt_ed.py b/src/1_train_dt_ed.py new file mode 100644 index 0000000..5cf7525 --- /dev/null +++ b/src/1_train_dt_ed.py @@ -0,0 +1,861 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. +# -------------------------------------------------------- + + +import argparse +parser = argparse.ArgumentParser(description='Train DT-ED') + +# architecture specification +parser.add_argument('--densenet-growthrate', type=int, default=32, + help='growth rate of encoder/decoder base densenet archi. (default: 32)') +parser.add_argument('--z-dim-app', type=int, default=64, + help='size of 1D latent code for appearance (default: 64)') +parser.add_argument('--z-dim-gaze', type=int, default=2, + help='size of 2nd dim. of 3D latent code for each gaze direction (default: 2)') +parser.add_argument('--z-dim-head', type=int, default=16, + help='size of 2nd dim. of 3D latent code for each head rotation (default: 16)') +parser.add_argument('--decoder-input-c', type=int, default=32, + help='size of feature map stack as input to decoder (default: 32)') + +parser.add_argument('--normalize-3d-codes', action='store_true', + help='normalize rows of 3D latent codes') +parser.add_argument('--normalize-3d-codes-axis', default=1, type=int, choices=[1, 2, 3], + help='axis over which to normalize 3D latent codes') + +parser.add_argument('--triplet-loss-type', choices=['angular', 'euclidean'], + help='Apply triplet loss with selected distance metric') +parser.add_argument('--triplet-loss-margin', type=float, default=0.0, + help='Triplet loss margin') +parser.add_argument('--triplet-regularize-d-within', action='store_true', + help='Regularize triplet loss by mean within-person distance') + +parser.add_argument('--all-equal-embeddings', action='store_true', + help='Apply loss to make all frontalized embeddings similar') + +parser.add_argument('--embedding-consistency-loss-type', + choices=['angular', 'euclidean'], default=None, + help='Apply embedding_consistency loss with selected distance metric') +parser.add_argument('--embedding-consistency-loss-warmup-samples', + type=int, default=1000000, + help='Start from 0.0 and warm up embedding consistency loss until n samples') + +parser.add_argument('--backprop-gaze-to-encoder', action='store_true', + help='Add gaze loss term to single loss and backprop to entire network.') + +parser.add_argument('--coeff-l1-recon-loss', type=float, default=1.0, + help='Weight/coefficient for L1 reconstruction loss term') +parser.add_argument('--coeff-gaze-loss', type=float, default=0.1, + help='Weight/coefficient for gaze direction loss term') +parser.add_argument('--coeff-embedding_consistency-loss', 
type=float, default=2.0, + help='Weight/coefficient for embedding_consistency loss term') + +# training +parser.add_argument('--pick-exactly-per-person', type=int, default=None, + help='Pick exactly this many entries per person for training.') +parser.add_argument('--pick-at-least-per-person', type=int, default=400, + help='Only pick person for training if at least this many entries.') +parser.add_argument('--use-apex', action='store_true', + help='Use half-precision floating points via the apex library.') +parser.add_argument('--base-lr', type=float, default=0.00005, metavar='LR', + help='learning rate (to be multiplied with batch size) (default: 0.00005)') +parser.add_argument('--warmup-period-for-lr', type=int, default=1000000, metavar='LR', + help=('no. of data entries (not batches) to have processed ' + + 'when stopping gradual ramp up of LR (default: 1000000)')) +parser.add_argument('--batch-size', type=int, default=128, metavar='N', + help='training batch size (default: 128)') +parser.add_argument('--decay-interval', type=int, default=0, metavar='N', + help='iterations after which to decay the learning rate (default: 0)') +parser.add_argument('--decay', type=float, default=0.8, metavar='decay', + help='learning rate decay multiplier (default: 0.8)') +parser.add_argument('--num-training-epochs', type=float, default=20, metavar='N', + help='number of steps to train (default: 20)') +parser.add_argument('--l2-reg', type=float, default=1e-4, + help='l2 weights regularization coefficient (default: 1e-4)') +parser.add_argument('--print-freq-train', type=int, default=20, metavar='N', + help='print training statistics after every N iterations (default: 20)') +parser.add_argument('--print-freq-test', type=int, default=5000, metavar='N', + help='print test statistics after every N iterations (default: 5000)') + +# data +parser.add_argument('--mpiigaze-file', type=str, default='../preprocess/outputs/MPIIGaze.h5', + help='Path to MPIIGaze dataset in HDF format.') +parser.add_argument('--gazecapture-file', type=str, default='../preprocess/outputs/GazeCapture.h5', + help='Path to GazeCapture dataset in HDF format.') +parser.add_argument('--test-subsample', type=float, default=1.0, + help='proportion of test set to use (default: 1.0)') +parser.add_argument('--num-data-loaders', type=int, default=0, metavar='N', + help='number of data loading workers (default: 0)') + +# logging +parser.add_argument('--use-tensorboard', action='store_true', default=False, + help='create tensorboard logs (stored in the args.save_path directory)') +parser.add_argument('--save-path', type=str, default='.', + help='path to save network parameters (default: .)') +parser.add_argument('--show-warnings', action='store_true', default=False, + help='show default Python warnings') + +# image saving +parser.add_argument('--save-freq-images', type=int, default=1000, + help='save sample images after every N iterations (default: 1000)') +parser.add_argument('--save-image-samples', type=int, default=100, + help='Save image outputs for N samples per dataset (default: 100)') + +# evaluation / prediction of outputs +parser.add_argument('--skip-training', action='store_true', + help='skip training to go straight to prediction generation') +parser.add_argument('--eval-batch-size', type=int, default=512, metavar='N', + help='evaluation batch size (default: 512)') + +args = parser.parse_args() + +import h5py +import numpy as np +from collections import OrderedDict +import gc +import json +import time +import os + +import moviepy.editor as 
mpy + +import torch +import torch.optim as optim +import torch.nn as nn +from torch.utils.data import DataLoader, Subset + +import logging +logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) + +from data import HDFDataset + +# Set device +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +# Ignore warnings +if not args.show_warnings: + import warnings + warnings.filterwarnings('ignore') + +############################# +# Sanity check some arguments + +if args.embedding_consistency_loss_type is not None: + assert args.triplet_loss_type is None +if args.triplet_loss_type is not None: + assert args.embedding_consistency_loss_type is None + +if (args.triplet_loss_type == 'angular' + or args.embedding_consistency_loss_type == 'angular'): + assert args.normalize_3d_codes is True +elif (args.triplet_loss_type == 'euclidean' + or args.embedding_consistency_loss_type == 'euclidean'): + assert args.normalize_3d_codes is False + + +def embedding_consistency_loss_weight_at_step(current_step): + final_value = args.coeff_embedding_consistency_loss + if args.embedding_consistency_loss_warmup_samples is None: + return final_value + warmup_steps = int(args.embedding_consistency_loss_warmup_samples / args.batch_size) + if current_step <= warmup_steps: + return (final_value / warmup_steps) * current_step + else: + return final_value + + +##################################################### +# Calculate how to handle learning rate at given step + +max_lr = args.base_lr * args.batch_size +ramp_up_until_step = int(args.warmup_period_for_lr / args.batch_size) +ramp_up_a = (max_lr - args.base_lr) / ramp_up_until_step +ramp_up_b = args.base_lr + + +def learning_rate_at_step(current_step): + if current_step <= ramp_up_until_step: + return ramp_up_a * current_step + ramp_up_b + elif args.decay_interval != 0: + return np.power(args.decay, int((current_step - ramp_up_until_step) + / args.decay_interval)) + else: + return max_lr + + +def update_learning_rate(current_step): + global optimizer + lr = learning_rate_at_step(current_step) + all_param_groups = optimizer.param_groups + for i, param_group in enumerate(all_param_groups): + if i == 0: # Don't do it for the gaze-related weights + param_group['lr'] = lr + + +################################################ +# Create network +from models import DTED +network = DTED( + growth_rate=args.densenet_growthrate, + z_dim_app=args.z_dim_app, + z_dim_gaze=args.z_dim_gaze, + z_dim_head=args.z_dim_head, + decoder_input_c=args.decoder_input_c, + normalize_3d_codes=args.normalize_3d_codes, + normalize_3d_codes_axis=args.normalize_3d_codes_axis, + use_triplet=args.triplet_loss_type is not None, + backprop_gaze_to_encoder=args.backprop_gaze_to_encoder, +) +logging.info(network) + +################################################ +# Transfer on the GPU before constructing and optimizer +if torch.cuda.device_count() > 1: + logging.info('Using %d GPUs!' 
% torch.cuda.device_count()) + network = nn.DataParallel(network) +network = network.to(device) + +################################################ +# Build optimizers +gaze_lr = 1.0 * args.base_lr +if args.backprop_gaze_to_encoder: + optimizer = optim.SGD( + [ + {'params': [p for n, p in network.named_parameters() if not n.startswith('gaze')]}, + { + 'params': [p for n, p in network.named_parameters() if n.startswith('gaze')], + 'lr': gaze_lr, + }, + ], + lr=args.base_lr, momentum=0.9, + nesterov=True, weight_decay=args.l2_reg) +else: + optimizer = optim.SGD( + [p for n, p in network.named_parameters() if not n.startswith('gaze')], + lr=args.base_lr, momentum=0.9, + nesterov=True, weight_decay=args.l2_reg, + ) + + # one additional optimizer for just gaze estimation head + gaze_optimizer = optim.SGD( + [p for n, p in network.named_parameters() if n.startswith('gaze')], + lr=gaze_lr, momentum=0.9, + nesterov=True, weight_decay=args.l2_reg, + ) + +# Wrap optimizer instances with AMP +if args.use_apex: + from apex import amp + optimizers = ([optimizer] + if args.backprop_gaze_to_encoder + else [optimizer, gaze_optimizer]) + network, optimizers = amp.initialize(network, optimizers, + opt_level='O1', num_losses=len(optimizers)) + if args.backprop_gaze_to_encoder: + optimizer = optimizers[0] + else: + optimizer, gaze_optimizer = optimizers + +logging.info('Initialized optimizer(s)') + +################################################ +# Define loss functions +from losses import (ReconstructionL1Loss, GazeAngularLoss, BatchHardTripletLoss, + AllFrontalsEqualLoss, EmbeddingConsistencyLoss) + +loss_functions = OrderedDict() +loss_functions['recon_l1'] = ReconstructionL1Loss(suffix='b') +loss_functions['gaze'] = GazeAngularLoss() + +if args.triplet_loss_type is not None: + loss_functions['triplet'] = BatchHardTripletLoss( + distance_type=args.triplet_loss_type, + margin=args.triplet_loss_margin, + ) + +if args.all_equal_embeddings: + loss_functions['all_equal'] = AllFrontalsEqualLoss() + +if args.embedding_consistency_loss_type is not None: + loss_functions['embedding_consistency'] = EmbeddingConsistencyLoss( + distance_type=args.embedding_consistency_loss_type, + ) + +################################################ +# Create the train and test datasets. +# We train on the GazeCapture training set +# and test on the val+test set, and the entire MPIIGaze. +all_data = OrderedDict() + + +def worker_init_fn(worker_id): + # Custom worker init to not repeat pairs + np.random.seed(np.random.get_state()[1][0] + worker_id) + + +# Load GazeCapture prefixes with train/val/test split spec. 
+with open('./gazecapture_split.json', 'r') as f: + all_gc_prefixes = json.load(f) + +# Define single training dataset +train_tag = 'gc/train' +train_prefixes = all_gc_prefixes['train'] +train_dataset = HDFDataset(hdf_file_path=args.gazecapture_file, + prefixes=train_prefixes, + get_2nd_sample=True, + pick_exactly_per_person=args.pick_exactly_per_person, + pick_at_least_per_person=args.pick_at_least_per_person, + ) +train_dataloader = DataLoader(train_dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + num_workers=args.num_data_loaders, + pin_memory=True, + # worker_init_fn=worker_init_fn, + ) +all_data[train_tag] = {'dataset': train_dataset, 'dataloader': train_dataloader} + +# Define multiple validation/test datasets +for tag, hdf_file, prefixes in [('gc/val', args.gazecapture_file, all_gc_prefixes['val']), + ('gc/test', args.gazecapture_file, all_gc_prefixes['test']), + ('mpi', args.mpiigaze_file, None), + ]: + # Define dataset structure based on selected prefixes + dataset = HDFDataset(hdf_file_path=hdf_file, + prefixes=prefixes, + get_2nd_sample=True) + subsample = args.test_subsample + if tag == 'gc/test': # reduce no. of test samples for this case + subsample /= 10.0 + if subsample < (1.0 - 1e-6): # subsample if requested + dataset = Subset(dataset, np.linspace( + start=0, stop=len(dataset), + num=int(subsample * len(dataset)), + endpoint=False, + dtype=np.uint32, + )) + all_data[tag] = { + 'dataset': dataset, + 'dataloader': DataLoader(dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=2, # args.num_data_loaders, + pin_memory=True, + worker_init_fn=worker_init_fn), + } + +# Print some stats. +logging.info('') +for tag, val in all_data.items(): + tag = '[%s]' % tag + dataset = val['dataset'] + original_dataset = dataset.dataset if isinstance(dataset, Subset) else dataset + num_original_entries = len(original_dataset) + num_people = len(original_dataset.prefixes) + logging.info('%10s full set size: %7d' % (tag, num_original_entries)) + logging.info('%10s current set size: %7d' % (tag, len(dataset))) + logging.info('%10s num people: %7d' % (tag, num_people)) + logging.info('%10s mean entries per person: %7d' % (tag, num_original_entries / num_people)) + logging.info('') + +logging.info('Prepared Datasets') + +###################################################### +# Utility methods for accessing datasets + + +def send_data_dict_to_gpu(data): + for k, v in data.items(): + if isinstance(v, torch.Tensor): + data[k] = v.detach().to(device, non_blocking=True) + return data + + +###################################################### +# Pre-collect entries for which to generate images for + +for tag, data_dict in all_data.items(): + dataset = data_dict['dataset'] + indices = np.linspace(start=0, stop=len(dataset), endpoint=False, + num=args.save_image_samples, dtype=np.uint32) + retrieved_samples = [dataset[index] for index in indices] + stacked_samples = {} + for k in ['image_a', 'face_a', 'R_gaze_a', 'R_head_a']: + if k in retrieved_samples[0]: + stacked_samples[k] = torch.stack([s[k] for s in retrieved_samples]) + data_dict['to_visualize'] = stacked_samples + + # Have dataloader re-open HDF to avoid multi-processing related errors. 
+ original_dataset = dataset.dataset if isinstance(dataset, Subset) else dataset + original_dataset.close_hdf() + + +################################# +# Latent Space Walk Preparations + + +def R_x(theta): + sin_ = np.sin(theta) + cos_ = np.cos(theta) + return np.array([ + [1., 0., 0.], + [0., cos_, -sin_], + [0., sin_, cos_] + ]). astype(np.float32) + + +def R_y(phi): + sin_ = np.sin(phi) + cos_ = np.cos(phi) + return np.array([ + [cos_, 0., sin_], + [0., 1., 0.], + [-sin_, 0., cos_] + ]). astype(np.float32) + + +walking_spec = [] +for rotation_fn, short, min_d, max_d, num_d in [(R_x, 'x', 45, -45, 15), + (R_y, 'y', -45, 45, 15)]: + degrees = np.linspace(min_d, max_d, num_d, dtype=np.float32) + walking_spec.append({ + 'name': '%s_%d_%d' % (short, min_d, max_d), + 'matrices': [ + torch.from_numpy(np.repeat( + np.expand_dims(rotation_fn(np.radians(deg)), 0), + args.save_image_samples, + axis=0, + )) + for deg in degrees + ], + }) + +identity_rotation = torch.from_numpy(np.repeat( + np.expand_dims(np.eye(3, dtype=np.float32), 0), + args.save_image_samples, + axis=0, +)) + + +def recover_images(x): + # Every specified iterations save sample images + # Note: We're doing this separate to Tensorboard to control which input + # samples we visualize, and also because Tensorboard is an inefficient + # way to store such images. + x = x.cpu().numpy() + x = (x + 1.0) * (255.0 / 2.0) + x = np.clip(x, 0, 255) # Avoid artifacts due to slight under/overflow + x = np.transpose(x, [0, 2, 3, 1]) # CHW to HWC + x = x.astype(np.uint8) + x = x[:, :, :, ::-1] # RGB to BGR for OpenCV + return x + + +############################ +# Load weights if available + +from checkpoints_manager import CheckpointsManager +saver = CheckpointsManager(network, args.save_path) +initial_step = saver.load_last_checkpoint() + +###################### +# Training step update + + +class RunningStatistics(object): + def __init__(self): + self.losses = OrderedDict() + + def add(self, key, value): + if key not in self.losses: + self.losses[key] = [] + self.losses[key].append(value) + + def means(self): + return OrderedDict([ + (k, np.mean(v)) for k, v in self.losses.items() if len(v) > 0 + ]) + + def reset(self): + for key in self.losses.keys(): + self.losses[key] = [] + + +time_epoch_start = None +num_elapsed_epochs = 0 + + +def execute_training_step(current_step): + global train_data_iterator, time_epoch_start, num_elapsed_epochs + time_iteration_start = time.time() + + # Get data + try: + if time_epoch_start is None: + time_epoch_start = time.time() + time_batch_fetch_start = time.time() + input_dict = next(train_data_iterator) + except StopIteration: + # Epoch counter and timer + num_elapsed_epochs += 1 + time_epoch_end = time.time() + time_epoch_diff = time_epoch_end - time_epoch_start + if args.use_tensorboard: + tensorboard.add_scalar('timing/epoch', time_epoch_diff, num_elapsed_epochs) + + # Done with an epoch now...! + if num_elapsed_epochs % 5 == 0: + saver.save_checkpoint(current_step) + + np.random.seed() # Ensure randomness + + # Some cleanup + train_data_iterator = None + torch.cuda.empty_cache() + gc.collect() + + # Restart! 
+ time_epoch_start = time.time() + global train_dataloader + train_data_iterator = iter(train_dataloader) + time_batch_fetch_start = time.time() + input_dict = next(train_data_iterator) + + # get the inputs + input_dict = send_data_dict_to_gpu(input_dict) + running_timings.add('batch_fetch', time.time() - time_batch_fetch_start) + + # zero the parameter gradient + network.train() + optimizer.zero_grad() + if not args.backprop_gaze_to_encoder: + gaze_optimizer.zero_grad() + + # forward + backward + optimize + time_forward_start = time.time() + output_dict, loss_dict = network(input_dict, loss_functions=loss_functions) + # torch.cuda.synchronize() + + # If doing multi-GPU training, just take an average + for key, value in loss_dict.items(): + if value.dim() > 0: + value = torch.mean(value) + loss_dict[key] = value + + # Construct main loss + loss_to_optimize = args.coeff_l1_recon_loss * loss_dict['recon_l1'] + if args.triplet_loss_type is not None: + triplet_losses = [] + triplet_losses = [ + loss_dict['triplet_gaze_' + args.triplet_loss_type], + loss_dict['triplet_head_' + args.triplet_loss_type], + ] + if args.triplet_regularize_d_within: + triplet_losses += [ + loss_dict['triplet_gaze_%s_d_within' % args.triplet_loss_type], + loss_dict['triplet_head_%s_d_within' % args.triplet_loss_type], + ] + loss_to_optimize += 1.0 * sum(triplet_losses) + + if args.embedding_consistency_loss_type is not None: + embedding_consistency_losses = [ + loss_dict['embedding_consistency_gaze_' + args.embedding_consistency_loss_type], + # loss_dict['embedding_consistency_head_' + args.embedding_consistency_loss_type], + ] + coeff_embedding_consistency_loss = embedding_consistency_loss_weight_at_step(current_step) + loss_to_optimize += coeff_embedding_consistency_loss * sum(embedding_consistency_losses) + + if args.all_equal_embeddings: + loss_to_optimize += sum([ + loss_dict['all_equal_gaze'], + loss_dict['all_equal_head'], + ]) + + if args.backprop_gaze_to_encoder: + loss_to_optimize += args.coeff_gaze_loss * loss_dict['gaze'] + + # Learning rate ramp-up until specified no. 
of samples passed, or decay + update_learning_rate(current_step) + + # Optimize main objective + if args.use_apex: + with amp.scale_loss(loss_to_optimize, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss_to_optimize.backward() + optimizer.step() + + # optimize small gaze part too, separately (if required) + if not args.backprop_gaze_to_encoder: + if args.use_apex: + with amp.scale_loss(loss_dict['gaze'], gaze_optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss_dict['gaze'].backward() + gaze_optimizer.step() + + # Register timing + time_backward_end = time.time() + running_timings.add('forward_and_backward', time_backward_end - time_forward_start) + + # Store values for logging later + for key, value in loss_dict.items(): + loss_dict[key] = value.detach().cpu() + for key, value in loss_dict.items(): + running_losses.add(key, value.numpy()) + + running_timings.add('iteration', time.time() - time_iteration_start) + +#################################### +# Test for particular validation set + + +def execute_test(tag, data_dict): + test_losses = RunningStatistics() + with torch.no_grad(): + for input_dict in data_dict['dataloader']: + network.eval() + input_dict = send_data_dict_to_gpu(input_dict) + output_dict, loss_dict = network(input_dict, loss_functions=loss_functions) + for key, value in loss_dict.items(): + test_losses.add(key, value.detach().cpu().numpy()) + test_loss_means = test_losses.means() + logging.info('Test Losses at [%7d] for %10s: %s' % + (current_step + 1, '[' + tag + ']', + ', '.join(['%s: %.6f' % v for v in test_loss_means.items()]))) + if args.use_tensorboard: + for k, v in test_loss_means.items(): + tensorboard.add_scalar('test/%s/%s' % (tag, k), v, current_step + 1) + + +############ +# Main loop + +num_training_steps = int(args.num_training_epochs * len(train_dataset) / args.batch_size) +if args.skip_training: + num_training_steps = 0 +else: + logging.info('Training') + last_training_step = num_training_steps - 1 + if args.use_tensorboard: + from tensorboardX import SummaryWriter + tensorboard = SummaryWriter(log_dir=args.save_path) + +train_data_iterator = iter(train_dataloader) +running_losses = RunningStatistics() +running_timings = RunningStatistics() +for current_step in range(initial_step, num_training_steps): + + ################ + # Training loop + execute_training_step(current_step) + + if current_step % args.print_freq_train == args.print_freq_train - 1: + conv1_wt_lr = optimizer.param_groups[0]['lr'] + running_loss_means = running_losses.means() + logging.info('Losses at [%7d]: %s' % + (current_step + 1, + ', '.join(['%s: %.5f' % v + for v in running_loss_means.items()]))) + if args.use_tensorboard: + tensorboard.add_scalar('train_lr', conv1_wt_lr, current_step + 1) + for k, v in running_loss_means.items(): + tensorboard.add_scalar('train/' + k, v, current_step + 1) + running_losses.reset() + + # Print some timing statistics + if current_step % 100 == 99: + if args.use_tensorboard: + for k, v in running_timings.means().items(): + tensorboard.add_scalar('timing/' + k, v, current_step + 1) + running_timings.reset() + + # print some memory statistics + if current_step % 5000 == 0: + for i in range(torch.cuda.device_count()): + bytes = (torch.cuda.memory_allocated(device=i) + + torch.cuda.memory_cached(device=i)) + logging.info('GPU %d: probably allocated approximately %.2f GB' % (i, bytes / 1e9)) + + ############### + # Testing loop: every specified iterations compute the test statistics + if (current_step % args.print_freq_test 
== (args.print_freq_test - 1) + or current_step == last_training_step): + network.eval() + optimizer.zero_grad() + if not args.backprop_gaze_to_encoder: + gaze_optimizer.zero_grad() + torch.cuda.empty_cache() + + for tag, data_dict in list(all_data.items())[1:]: + execute_test(tag, data_dict) + + # This might help with memory leaks + torch.cuda.empty_cache() + + ##################### + # Visualization loop + + # Latent space walks (only store latest results) + if (args.save_image_samples > 0 + and (current_step % args.save_freq_images + == (args.save_freq_images - 1) + or current_step == last_training_step)): + network.eval() + torch.cuda.empty_cache() + with torch.no_grad(): + for tag, data_dict in all_data.items(): + + def save_images(images, dname, stem): + dpath = '%s/walks/%s/%s' % (args.save_path, tag, dname) + if not os.path.isdir(dpath): + os.makedirs(dpath) + for i in range(args.save_image_samples): + # Write single image + frames = [images[j][i] for j in range(len(images))] + # Write video + frames = [f[:, :, ::-1] for f in frames] # BGR to RGB + frames += frames[1:-1][::-1] # continue in reverse + clip = mpy.ImageSequenceClip(frames, fps=15) + clip.write_videofile('%s/%04d_%s.mp4' % (dpath, i, stem), + audio=False, threads=8, + logger=None, verbose=False) + + for spec in walking_spec: # Gaze-direction-walk + output_images = [] + for rotation_mat in spec['matrices']: + adjusted_input = data_dict['to_visualize'].copy() + adjusted_input['R_gaze_b'] = rotation_mat + adjusted_input['R_head_b'] = identity_rotation + adjusted_input = dict([(k, v.to(device)) + for k, v in adjusted_input.items()]) + output_dict = network(adjusted_input) + output_images.append(recover_images(output_dict['image_b_hat'])) + save_images(output_images, 'gaze', spec['name']) + + for spec in walking_spec: # Head-pose-walk + output_images = [] + for rotation_mat in spec['matrices']: + adjusted_input = data_dict['to_visualize'].copy() + adjusted_input['R_gaze_b'] = identity_rotation + adjusted_input['R_head_b'] = rotation_mat + adjusted_input = dict([(k, v.to(device)) + for k, v in adjusted_input.items()]) + output_dict = network(adjusted_input) + output_images.append(recover_images(output_dict['image_b_hat'])) + save_images(output_images, 'head', spec['name']) + + torch.cuda.empty_cache() + +if not args.skip_training: + logging.info('Finished Training') + + # Save model parameters + saver.save_checkpoint(current_step) + + if args.use_tensorboard: + tensorboard.close() + del tensorboard + +# Clean up a bit +optimizer.zero_grad() +del (train_dataloader, train_dataset, all_data, + walking_spec, optimizer, identity_rotation) + +######################################### +# Generating predictions with final model +logging.info('Now generating predictions with final model...') +all_data = OrderedDict() +for tag, hdf_file, prefixes in [('gc/train', args.gazecapture_file, all_gc_prefixes['train']), + ('gc/val', args.gazecapture_file, all_gc_prefixes['val']), + ('gc/test', args.gazecapture_file, all_gc_prefixes['test']), + ('mpi', args.mpiigaze_file, None), + ]: + # Define dataset structure based on selected prefixes + dataset = HDFDataset(hdf_file_path=hdf_file, + prefixes=prefixes, + get_2nd_sample=False) + all_data[tag] = { + 'dataset': dataset, + 'dataloader': DataLoader(dataset, + batch_size=args.eval_batch_size, + shuffle=False, + num_workers=args.num_data_loaders, + pin_memory=True, + worker_init_fn=worker_init_fn), + } +logging.info('') +for tag, val in all_data.items(): + tag = '[%s]' % tag + dataset = 
val['dataset'] + num_entries = len(dataset) + num_people = len(dataset.prefixes) + logging.info('%10s set size: %7d' % (tag, num_entries)) + logging.info('%10s num people: %7d' % (tag, num_people)) + logging.info('%10s mean entries per person: %7d' % (tag, num_entries / num_people)) + logging.info('') + +# every specified iterations compute the test statistics: +for tag, data_dict in all_data.items(): + current_person_id = None + current_person_data = {} + ofpath = '%s/%s_predictions.h5' % (args.save_path, tag.replace('/', '_')) + ofdir = os.path.dirname(ofpath) + if not os.path.isdir(ofdir): + os.makedirs(ofdir) + h5f = h5py.File(ofpath, 'w') + + def store_person_predictions(): + global current_person_data + if len(current_person_data) > 0: + g = h5f.create_group(current_person_id) + for key, data in current_person_data.items(): + g.create_dataset(key, data=data, dtype=np.float32) + current_person_data = {} + with torch.no_grad(): + np.random.seed() + num_batches = int(np.ceil(len(data_dict['dataset']) / args.eval_batch_size)) + for i, input_dict in enumerate(data_dict['dataloader']): + # Get embeddings + network.eval() + output_dict = network(send_data_dict_to_gpu(input_dict)) + output_dict = dict([(k, v.cpu().numpy()) for k, v in output_dict.items()]) + + # Process output line by line + zipped_data = zip( + input_dict['key'], + input_dict['gaze_a'].cpu().numpy(), + input_dict['head_a'].cpu().numpy(), + output_dict['z_app'], + output_dict['z_gaze_enc'], + output_dict['z_head_enc'], + output_dict['gaze_a_hat'], + ) + for (person_id, gaze, head, z_app, z_gaze, z_head, gaze_hat) in zipped_data: + # Store predictions if moved on to next person + if person_id != current_person_id: + store_person_predictions() + current_person_id = person_id + # Now write it + to_write = { + 'gaze': gaze, + 'head': head, + 'z_app': z_app, + 'z_gaze': z_gaze, + 'z_head': z_head, + 'gaze_hat': gaze_hat, + } + for k, v in to_write.items(): + if k not in current_person_data: + current_person_data[k] = [] + current_person_data[k].append(v.astype(np.float32)) + logging.info('[%s] processed batch [%04d/%04d] with %d entries.' % + (tag, i + 1, num_batches, len(next(iter(input_dict.values()))))) + store_person_predictions() + logging.info('Completed processing %s' % tag) +logging.info('Done') diff --git a/src/2_meta_learning.py b/src/2_meta_learning.py new file mode 100644 index 0000000..8c17527 --- /dev/null +++ b/src/2_meta_learning.py @@ -0,0 +1,733 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. 
+# -------------------------------------------------------- + + +import argparse +import os +import pickle +import random +from collections import OrderedDict + +import h5py +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable as V +from tensorboardX import SummaryWriter +from tqdm import tqdm + +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + +""" + Utility functions +""" + + +def angular_error(a, b): + """Calculate angular error (via cosine similarity).""" + a = pitchyaw_to_vector(a) if a.shape[1] == 2 else a + b = pitchyaw_to_vector(b) if b.shape[1] == 2 else b + + ab = np.sum(np.multiply(a, b), axis=1) + a_norm = np.linalg.norm(a, axis=1) + b_norm = np.linalg.norm(b, axis=1) + + # Avoid zero-values (to avoid NaNs) + a_norm = np.clip(a_norm, a_min=1e-6, a_max=None) + b_norm = np.clip(b_norm, a_min=1e-6, a_max=None) + + similarity = np.divide(ab, np.multiply(a_norm, b_norm)) + similarity = np.clip(similarity, a_min=-1.0 + 1e-6, a_max=1.0 - 1e-6) + + return np.degrees(np.arccos(similarity)) + + +def nn_angular_error(y, y_hat): + sim = F.cosine_similarity(y, y_hat, eps=1e-6) + sim = F.hardtanh(sim, -1.0 + 1e-6, 1.0 - 1e-6) + return torch.acos(sim) * (180 / np.pi) + + +def nn_mean_angular_loss(y, y_hat): + return torch.mean(nn_angular_error(y, y_hat)) + + +def nn_mean_asimilarity(y, y_hat): + return torch.mean(1.0 - F.cosine_similarity(y, y_hat, eps=1e-6)) + + +def pitchyaw_to_vector(pitchyaws): + n = pitchyaws.shape[0] + sin = np.sin(pitchyaws) + cos = np.cos(pitchyaws) + out = np.empty((n, 3)) + out[:, 0] = np.multiply(cos[:, 0], sin[:, 1]) + out[:, 1] = sin[:, 0] + out[:, 2] = np.multiply(cos[:, 0], cos[:, 1]) + return out + + +""" + Tasks class for grabbing training/testing samples +""" + + +class Tasks(object): + def __init__(self, hdf_path, x_keys=['z_gaze']): + + # Select tasks for which min. 1000 entries exist + self.data = h5py.File(hdf_path, 'r') + previous_len = len(self.data.keys()) + self.selected_tasks = [k for k in self.data.keys() + if self.data[k + '/gaze'].len() > 1000] + self.num_tasks = len(self.selected_tasks) + + # Now load in all data into memory for selected tasks + self.processed_data = [] + for task in self.selected_tasks: + num_entries = self.data[task + '/gaze'].len() + xs = np.concatenate([ + np.array(self.data[task + '/' + key]).reshape(num_entries, -1) + for key in x_keys + ], axis=1) + ys = pitchyaw_to_vector(np.array(self.data[task + '/gaze']).reshape(-1, 2)) + self.processed_data.append((xs, ys)) + print('Loaded %s (%d -> %d tasks)' % (os.path.basename(hdf_path), + previous_len, self.num_tasks)) + + # By default, we just sample disjoint sets from the entire given data + self.all_indices = [list(range(len(entries[0]))) + for entries in self.processed_data] + self.train_indices = self.all_indices + self.test_indices = self.all_indices + + def create_sample(self, task_index, indices): + """Create a sample of a task for meta-learning. + + This consists of a x, y pair. 
+ """ + xs, ys = zip(*[(self.processed_data[task_index][0][i], + self.processed_data[task_index][1][i]) + for i in indices]) + xs, ys = np.array(xs), np.array(ys) + return (torch.Tensor(xs).to(device), + torch.Tensor(ys).to(device)) + + def sample(self, num_train=4, num_test=100): + """Yields training and testing samples.""" + picked_task = random.randint(0, self.num_tasks - 1) + return self.sample_for_task(picked_task, num_train=num_train, + num_test=num_test) + + def sample_for_task(self, task, num_train=4, num_test=100): + if self.train_indices[task] is self.test_indices[task]: + # This is for meta-training and meta-validation + indices = random.sample(self.all_indices[task], num_train + num_test) + train_indices = indices[:num_train] + test_indices = indices[-num_test:] + else: + # This is for meta-testing + train_indices = random.sample(self.train_indices[task], num_train) + test_indices = self.test_indices[task] + return (self.create_sample(task, train_indices), + self.create_sample(task, test_indices)) + + +class TestTasks(Tasks): + """Class for final testing (not testing within meta-learning.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.train_indices = [indices[:-500] for indices in self.all_indices] + self.test_indices = [indices[-500:] for indices in self.all_indices] + + +""" + Replacement classes for standard PyTorch Module and Linear. +""" + + +class ModifiableModule(nn.Module): + def params(self): + return [p for _, p in self.named_params()] + + def named_leaves(self): + return [] + + def named_submodules(self): + return [] + + def named_params(self): + subparams = [] + for name, mod in self.named_submodules(): + for subname, param in mod.named_params(): + subparams.append((name + '.' + subname, param)) + return self.named_leaves() + subparams + + def set_param(self, name, param, copy=False): + if '.' in name: + n = name.split('.') + module_name = n[0] + rest = '.'.join(n[1:]) + for name, mod in self.named_submodules(): + if module_name == name: + mod.set_param(rest, param, copy=copy) + break + else: + if copy is True: + setattr(self, name, V(param.data.clone(), requires_grad=True)) + else: + assert hasattr(self, name) + setattr(self, name, param) + + def copy(self, other, same_var=False): + for name, param in other.named_params(): + self.set_param(name, param, copy=not same_var) + + +class GradLinear(ModifiableModule): + def __init__(self, *args, **kwargs): + super().__init__() + ignore = nn.Linear(*args, **kwargs).to(device) + + nn.init.normal_(ignore.weight.data, mean=0.0, std=np.sqrt(1. 
/ args[0])) + nn.init.constant_(ignore.bias.data, val=0) + + self.weights = V(ignore.weight.data, requires_grad=True).to(device) + self.bias = V(ignore.bias.data, requires_grad=True).to(device) + + def forward(self, x): + return F.linear(x, self.weights, self.bias).to(device) + + def named_leaves(self): + return [('weights', self.weights), ('bias', self.bias)] + + +""" + Meta-learnable fully-connected neural network model definition +""" + + +class GazeEstimationModel(ModifiableModule): + def __init__(self, activation_type='relu', layer_num_features=[48, 64, 3], make_alpha=False): + super().__init__() + self.activation_type = activation_type + + # Construct layers + self.layer_num_features = layer_num_features + self.layers = [] + for i, f_now in enumerate(self.layer_num_features[:-1]): + f_next = self.layer_num_features[i + 1] + layer = GradLinear(f_now, f_next) + self.layers.append(('layer%02d' % (i + 1), layer)) + + # For use with Meta-SGD + self.alphas = [] + if make_alpha: + for i, f_now in enumerate(self.layer_num_features[:-1]): + f_next = self.layer_num_features[i + 1] + alphas = GradLinear(f_now, f_next) + alphas.weights.data.uniform_(0.005, 0.1) + alphas.bias.data.uniform_(0.005, 0.1) + self.alphas.append(('alpha%02d' % (i + 1), alphas)) + + def clone(self, make_alpha=None): + if make_alpha is None: + make_alpha = (self.alphas is not None and len(self.alphas) > 0) + new_model = self.__class__(self.activation_type, self.layer_num_features, + make_alpha=make_alpha) + new_model.copy(self) + return new_model + + def state_dict(self): + output = {} + for key, layer in self.layers: + output[key + '.weights'] = layer.weights.data + output[key + '.bias'] = layer.bias.data + return output + + def load_state_dict(self, weights): + for key, tensor in weights.items(): + self.set_param(key, tensor, copy=True) + + def forward(self, x): + for name, layer in self.layers[:-1]: + x = layer(x) + if self.activation_type == 'relu': + x = F.relu_(x) + elif self.activation_type == 'leaky_relu': + x = F.leaky_relu_(x) + elif self.activation_type == 'elu': + x = F.elu_(x) + elif self.activation_type == 'selu': + x = F.selu_(x) + elif self.activation_type == 'tanh': + x = torch.tanh_(x) + elif self.activation_type == 'sigmoid': + x = torch.sigmoid_(x) + elif self.activation_type == 'none': + pass + else: + raise ValueError('Unknown activation function "%s"' % self.activation_type) + x = self.layers[-1][1](x) # No activation on output of last layer + x = F.normalize(x, dim=-1) # Normalize + return x + + def named_submodules(self): + return self.layers + self.alphas + + +class GazeEstimationModelPreExtended(ModifiableModule): + def __init__(self): + super().__init__() + + # Construct layers + self.layer00 = GradLinear(640, 118) # 64 + 2*3 + 16*3 + self.layer01 = GradLinear(118, 64) + self.layer02 = GradLinear(64, 3) + self.layers = [('layer00', self.layer00), + ('layer01', self.layer01), + ('layer02', self.layer02)] + + def clone(self, make_alpha=None): + new_model = self.__class__() + new_model.copy(self) + return new_model + + def forward(self, x): + x = self.layer00(x) + x = x[:, 64:70] # Extract at hardcoded z_gaze indices + x = F.selu_(x) + x = self.layer01(x) + x = F.selu_(x) + x = self.layer02(x) + x = F.normalize(x, dim=-1) # Normalize + return x + + def named_submodules(self): + return self.layers + + +""" + Meta-learning utility functions. 
+""" + + +def forward_and_backward(model, data, optim=None, create_graph=False, + train_data=None, loss_function=nn_mean_angular_loss): + model.train() + if optim is not None: + optim.zero_grad() + loss = forward(model, data, train_data=train_data, for_backward=True, + loss_function=loss_function) + loss.backward(create_graph=create_graph, retain_graph=(optim is None)) + if optim is not None: + # nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optim.step() + return loss.data.cpu().numpy() + + +def forward(model, data, return_predictions=False, train_data=None, + for_backward=False, loss_function=nn_mean_angular_loss): + model.train() + x, y = data + y_hat = model(V(x)) + loss = loss_function(y_hat, V(y)) + if return_predictions: + return y_hat.data.cpu().numpy() + elif for_backward: + return loss + else: + return loss.data.cpu().numpy() + + +""" + Inference through model (with/without gradient calculation) +""" + + +class MAML(object): + def __init__(self, model, k, output_dir='./outputs/', + train_tasks=None, valid_tasks=None, no_tensorboard=False): + self.model = model + self.meta_model = model.clone() + + self.train_tasks = train_tasks + self.valid_tasks = valid_tasks + self.k = k + + self.output_dir = None + self.tensorboard = None + if output_dir is not None: + self.output_dir = '%s/%s_%02d' % (output_dir, self.__class__.__name__, k) + if not os.path.isdir(self.output_dir): + os.makedirs(self.output_dir) + + if not no_tensorboard: + self.tensorboard = SummaryWriter(self.output_dir) + + @property + def model_parameters_path(self): + return '%s/meta_learned_parameters.pth.tar' % self.output_dir + + def save_model_parameters(self): + if self.output_dir is not None: + torch.save(self.model.state_dict(), self.model_parameters_path) + + def load_model_parameters(self): + if os.path.isfile(self.model_parameters_path): + weights = torch.load(self.model_parameters_path) + self.model.load_state_dict(weights) + print('> Loaded weights from %s' % self.model_parameters_path) + + def train(self, steps_outer, steps_inner=1, lr_inner=0.01, lr_outer=0.001, + disable_tqdm=False): + self.lr_inner = lr_inner + print('\nBeginning meta-learning for k = %d' % self.k) + print('> Please check tensorboard logs for progress.\n') + + # Outer loop optimizer + optimizer = torch.optim.Adam(self.model.params(), lr=lr_outer) + + # Model and optimizer for validation + valid_model = self.model.clone() + valid_optim = torch.optim.SGD(valid_model.params(), lr=self.lr_inner) + + for i in tqdm(range(steps_outer), disable=disable_tqdm): + for j in range(steps_inner): + # Make copy of main model + self.meta_model.copy(self.model, same_var=True) + + # Get a task + train_data, test_data = self.train_tasks.sample(num_train=self.k) + + # Run the rest of the inner loop + task_loss = self.inner_loop(train_data, self.lr_inner) + + # Calculate gradients on a held-out set + new_task_loss = forward_and_backward( + self.meta_model, test_data, train_data=train_data, + ) + + # Update the main model + optimizer.step() + optimizer.zero_grad() + + if (i + 1) % 100 == 0: + # Log to Tensorflow + if self.tensorboard is not None: + self.tensorboard.add_scalar('meta-train/train-loss', task_loss, i) + self.tensorboard.add_scalar('meta-train/valid-loss', new_task_loss, i) + + # Validation + losses = [] + for j in range(self.valid_tasks.num_tasks): + valid_model.copy(self.model) + train_data, test_data = self.valid_tasks.sample_for_task(j, num_train=self.k) + train_loss = forward_and_backward(valid_model, train_data, valid_optim) + valid_loss = 
forward(valid_model, test_data, train_data=train_data) + losses.append((train_loss, valid_loss)) + train_losses, valid_losses = zip(*losses) + if self.tensorboard is not None: + self.tensorboard.add_scalar('meta-valid/train-loss', np.mean(train_losses), i) + self.tensorboard.add_scalar('meta-valid/valid-loss', np.mean(valid_losses), i) + + # Save MAML initial parameters + self.save_model_parameters() + + def test(self, test_tasks_list, num_iterations=[1, 5, 10], num_repeats=20): + print('\nBeginning testing for meta-learned model with k = %d\n' % self.k) + model = self.model.clone() + + # IMPORTANT + # + # Sets consistent seed such that as long as --num-test-repeats is the + # same, experiment results from multiple invocations of this script can + # yield the same calibration samples. + random.seed(4089213955) + + for test_set_name, test_tasks in test_tasks_list.items(): + predictions = OrderedDict() + losses = OrderedDict([(n, []) for n in num_iterations]) + for i, task_name in enumerate(test_tasks.selected_tasks): + predictions[task_name] = [] + for t in range(num_repeats): + model.copy(self.model) + optim = torch.optim.SGD(model.params(), lr=self.lr_inner) + + train_data, test_data = test_tasks.sample_for_task(i, num_train=self.k) + if num_iterations[0] == 0: + train_loss = forward(model, train_data) + test_loss = forward(model, test_data, train_data=train_data) + losses[0].append((train_loss, test_loss)) + for j in range(np.amax(num_iterations)): + train_loss = forward_and_backward(model, train_data, optim) + if (j + 1) in num_iterations: + test_loss = forward(model, test_data, train_data=train_data) + losses[j + 1].append((train_loss, test_loss)) + + # Register ground truth and prediction + predictions[task_name].append({ + 'groundtruth': test_data[1].cpu().numpy(), + 'predictions': forward(model, test_data, + return_predictions=True, + train_data=train_data), + }) + predictions[task_name][-1]['errors'] = angular_error( + predictions[task_name][-1]['groundtruth'], + predictions[task_name][-1]['predictions'], + ) + + print('Done for k = %3d, %s/%s... train: %.3f, test: %.3f' % ( + self.k, test_set_name, task_name, + np.mean([both[0] for both in losses[num_iterations[-1]][-num_repeats:]]), + np.mean([both[1] for both in losses[num_iterations[-1]][-num_repeats:]]), + )) + + if self.output_dir is not None: + # Save predictions to file + pkl_path = '%s/predictions_%s.pkl' % (self.output_dir, test_set_name) + with open(pkl_path, 'wb') as f: + pickle.dump(predictions, f) + + # Finally, log values to tensorboard + if self.tensorboard is not None: + for n, v in losses.items(): + train_losses, test_losses = zip(*v) + stem = 'meta-test/%s/' % test_set_name + self.tensorboard.add_scalar(stem + 'train-loss', np.mean(train_losses), n) + self.tensorboard.add_scalar(stem + 'valid-loss', np.mean(test_losses), n) + + # Write loss values as plain text too + np.savetxt('%s/losses_%s_train.txt' % (self.output_dir, test_set_name), + [[n, np.mean(list(zip(*v))[0])] for n, v in losses.items()]) + np.savetxt('%s/losses_%s_valid.txt' % (self.output_dir, test_set_name), + [[n, np.mean(list(zip(*v))[1])] for n, v in losses.items()]) + + out_msg = '> Completed test on %s for k = %d' % (test_set_name, self.k) + final_n = sorted(num_iterations)[-1] + final_train_losses, final_test_losses = zip(*(losses[final_n])) + out_msg += ('\n at %d steps losses were... 
train: %.3f, test: %.3f +/- %.3f' % + (final_n, np.mean(final_train_losses), + np.mean(final_test_losses), + np.mean([ + np.std([ + data['errors'] for data in person_data + ], axis=0) + for person_data in predictions.values() + ]))) + print(out_msg) + + def inner_loop(self, train_data, lr_inner=0.01): + # Forward-pass and calculate gradients on meta model + loss = forward_and_backward(self.meta_model, train_data, + create_graph=True) + + # Apply gradients + for name, param in self.meta_model.named_params(): + self.meta_model.set_param(name, param - lr_inner * param.grad) + return loss + + +class FOMAML(MAML): + def inner_loop(self, train_data, lr_inner=0.01): + # Forward-pass and calculate gradients on meta model + loss = forward_and_backward(self.meta_model, train_data) + + # Apply gradients + for name, param in self.meta_model.named_params(): + grad = V(param.grad.detach().data) + self.meta_model.set_param(name, param - lr_inner * grad) + return loss + + +class MetaSGD(MAML): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model = self.model.clone(make_alpha=True) + self.meta_model = self.model.clone() + + def inner_loop(self, train_data, lr_inner=0.01): + # Forward-pass and calculate gradients on meta model + loss = forward_and_backward(self.meta_model, train_data, + create_graph=True) + + # Apply gradients + named_params = dict(self.meta_model.named_params()) + for name, param in named_params.items(): + if name.startswith('layer'): + alpha = named_params['alpha' + str(name[5:])] + self.meta_model.set_param(name, param - lr_inner * alpha * param.grad) + return loss + + +class NONE(MAML): + def train(self, steps_outer, steps_inner=1, lr_inner=0.01, lr_outer=0.001, + disable_tqdm=False): + self.lr_inner = lr_inner + + # Save randomly initialized MLP parameters + self.save_model_parameters() + + +""" + Actual run script +""" + +if __name__ == '__main__': + + # Available meta-learning methods + meta_learner_classes = { + 'MAML': MAML, + 'FOMAML': FOMAML, + 'Meta-SGD': MetaSGD, + 'NONE': NONE, + } + + # Define and parse configuration for training and evaluations + parser = argparse.ArgumentParser(description='Meta-learn gaze estimator from RotAE embeddings.') + parser.add_argument('input_dir', type=str, + help='Input directory for experiment data') + parser.add_argument('--output-dir', type=str, default='./', + help='Output directory for tensorboard log relative to input dir') + parser.add_argument('--no-tensorboard', action='store_true', + help='Log training and validation progress to tensorboard.') + parser.add_argument('--disable-tqdm', action='store_true', + help='Disable progress bar from tqdm (in particular on NGC).') + + parser.add_argument('--maml-use-pretrained-mlp', action='store_true', + help='Even for MAML, use pre-trained MLP paramters.') + + # Gaze estimation neural network configuration + parser.add_argument('--select-z', type=str, default='z_gaze', + help='Embeddings/features to select for using as input to MAML ' + + '(default: z_gaze)') + parser.add_argument('--layer-num-features', type=str, default='64', + help='Network configuration, number of FC features delimited by \',\' ' + + '(default: 64)') + parser.add_argument('--activation', type=str, default='selu', + choices=['sigmoid', 'relu', 'leaky_relu', 'elu', 'selu', 'tanh', 'none'], + help='Neural network activation function.') + + # Parameters for meta-learning + parser.add_argument('--meta-learner', type=str, default='MAML', + choices=list(meta_learner_classes.keys()), + 
help='Meta-learning algorithm') + parser.add_argument('--steps-meta-training', type=int, default=100000, + help='Number of steps to meta-learn for (default: 100000)') + parser.add_argument('--tasks-per-meta-iteration', type=int, default=5, + help='Tasks to evaluate per meta-learning iteration (default: 5)') + parser.add_argument('--lr-inner', type=float, default=1e-5, + help='Learning rate for inner loop (for the task) (default: 1e-5)') + parser.add_argument('--lr-outer', type=float, default=1e-3, + help='Learning rate for outer loop (the meta learner) (default: 1e-3)') + + # Evaluation + parser.add_argument('--skip-training', action='store_true', + help='Skips meta-training') + parser.add_argument('k', type=int, + help='Number of calibration samples to use - k as in k-shot learning.') + parser.add_argument('--num-test-repeats', type=int, default=10, + help='Number of times to repeat drawing of k samples for testing ' + + '(default: 10)') + parser.add_argument('--steps-testing', type=int, default=1000, + help='Number of steps to meta-learn for (default: 1000)') + + args = parser.parse_args() + + # Define data sources (tasks) + x_keys = args.select_z.split(',') + meta_train_tasks = Tasks(args.input_dir + '/gc_train_predictions.h5', x_keys=x_keys) + meta_val_tasks = Tasks(args.input_dir + '/gc_val_predictions.h5', x_keys=x_keys) + meta_test_tasks = [ + ('gc', TestTasks(args.input_dir + '/gc_test_predictions.h5', x_keys=x_keys)), + ('mpi', TestTasks(args.input_dir + '/mpi_predictions.h5', x_keys=x_keys)), + ] + + # Construct output directory path string + output_dir = None + if args.output_dir is not None: + output_dir = (os.path.realpath(args.input_dir + '/' + args.output_dir) + if args.output_dir[0] != '/' else args.output_dir) + output_dir += '/' + output_dir += 'Zg' + output_dir += '_OLR%.0e' % args.lr_outer + output_dir += '_IN%d' % args.tasks_per_meta_iteration + output_dir += '_ILR%.0e' % args.lr_inner + # output_dir += '_OutN%e' % args.steps_meta_training + output_dir += '_Net%s' % args.layer_num_features.replace(',', '-') + + # Get an example entry to design gaze estimation model + sample_train, _ = meta_train_tasks.sample(num_train=1, num_test=0) + + # Training configuration + layer_num_features = [int(f) for f in args.layer_num_features.split(',')] + layer_num_features = [sample_train[0].shape[1]] + layer_num_features + [3] + if not args.select_z == 'before_z': + model = GazeEstimationModel(activation_type=args.activation, + layer_num_features=layer_num_features) + else: + assert args.maml_use_pretrained_mlp is True + assert args.layer_num_features == '64' + assert args.activation == 'selu' + model = GazeEstimationModelPreExtended() + meta_learner_class = meta_learner_classes[args.meta_learner] + meta_learner = meta_learner_class(model, args.k, output_dir, + meta_train_tasks, meta_val_tasks, + no_tensorboard=args.no_tensorboard) + + # If doing fine-tuning... 
try to load pre-trained MLP weights + if args.meta_learner == 'NONE' or args.maml_use_pretrained_mlp: + import glob + checkpoint_path = sorted( + glob.glob('%s/checkpoints/at_step_*.pth.tar' % args.input_dir) + )[-1] + weights = torch.load(checkpoint_path) + try: + state_dict = { + 'layer01.weights': weights['module.gaze1.weight'], + 'layer01.bias': weights['module.gaze1.bias'], + 'layer02.weights': weights['module.gaze2.weight'], + 'layer02.bias': weights['module.gaze2.bias'], + } + if args.select_z == 'before_z': + state_dict['layer00.weights'] = weights['module.fc_enc.weight'] + state_dict['layer00.bias'] = weights['module.fc_enc.bias'] + except: # noqa + state_dict = { + 'layer01.weights': weights['gaze1.weight'], + 'layer01.bias': weights['gaze1.bias'], + 'layer02.weights': weights['gaze2.weight'], + 'layer02.bias': weights['gaze2.bias'], + } + if args.select_z == 'before_z': + state_dict['layer00.weights'] = weights['fc_enc.weight'] + state_dict['layer00.bias'] = weights['fc_enc.bias'] + for key, values in state_dict.items(): + model.set_param(key, values, copy=True) + del state_dict + print('Loaded %s' % checkpoint_path) + + if not args.skip_training: + meta_learner.train( + steps_outer=args.steps_meta_training, + steps_inner=args.tasks_per_meta_iteration, + lr_inner=args.lr_inner, + lr_outer=args.lr_outer, + disable_tqdm=args.disable_tqdm, + ) + + # Perform test (which entails the repeated training of person-specific models + if args.skip_training: + meta_learner.load_model_parameters() + meta_learner.lr_inner = args.lr_inner + meta_learner.test( + test_tasks_list=OrderedDict(meta_test_tasks), + num_iterations=list(np.arange(start=0, stop=args.steps_testing + 1, step=20)), + num_repeats=args.num_test_repeats, + ) diff --git a/src/3_combine_maml_results.py b/src/3_combine_maml_results.py new file mode 100644 index 0000000..e165e46 --- /dev/null +++ b/src/3_combine_maml_results.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. 
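+#
+# This script collects the per-k test predictions written by 2_meta_learning.py
+# (predictions_gc.pkl / predictions_mpi.pkl inside the MAML_* or NONE_* output
+# directories), computes the mean angular test error for each k, and writes a
+# PDF error-bar plot, a plain-text summary and tensorboard scalars per dataset.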
+# -------------------------------------------------------- + +import argparse +import os +import pickle +import re +from collections import OrderedDict + +import matplotlib.pyplot as plt +import numpy as np +from tensorboardX import SummaryWriter + +pickles_to_process = OrderedDict([ + ('GazeCapture (test)', 'predictions_gc.pkl'), + ('MPIIGaze', 'predictions_mpi.pkl'), +]) + + +def process_dir(input_dir, meta_learner_identifier): + """Process experiment directory.""" + selected_dirs = OrderedDict() + candidate_dirs = sorted([ + d for d in os.listdir(input_dir) if os.path.isdir(input_dir + '/' + d) + ]) + for exp_dir in candidate_dirs: + maml_dirs = sorted([ + p for p in os.listdir('%s/%s' % (input_dir, exp_dir)) + if re.match(r'^%s_\d{2,4}$' % meta_learner_identifier, p) + ], key=lambda x: int(x.split('_')[-1])) + if len(maml_dirs) > 0: + selected_dirs[exp_dir] = [ # get full paths + '%s/%s/%s' % (input_dir, exp_dir, p) + for p in maml_dirs + ] + + for exp_dir, maml_dirs in selected_dirs.items(): + for dataset, pkl_fname in pickles_to_process.items(): + data = get_all_data(maml_dirs, fname=pkl_fname) + + output_path = '%s/%s %s %s.pdf' % (input_dir, exp_dir, + meta_learner_identifier, dataset) + plot_mean_error_with_bars(dataset, data, output_path, + meta_learner_identifier) + + +def get_all_data(all_dirs, fname): + """Process individual outputs for different k.""" + all_data = OrderedDict() + for d in all_dirs: + k = int(d.split('_')[-1]) + ifpath = '%s/%s' % (d, fname) + if os.path.isfile(ifpath): + with open(ifpath, 'rb') as f: + all_data[k] = pickle.load(f) + else: + print('Skipping %s' % ifpath) + return all_data + + +def common_post(dataset, output_path): + plt.title(dataset) + plt.xlabel('k') + plt.ylabel('Mean Test Error') + + plt.grid() + plt.tight_layout() + + plt.savefig(output_path) + print('> Wrote to %s' % output_path) + + +def plot_mean_error_with_bars(dataset, data, output_path, + meta_learner_identifier): + """Plot standard deviation of mean errors over trials.""" + # Pick out errors from people into single list + errors = [ + ( + k, + np.concatenate([ + np.concatenate([ + trial_data['errors'].reshape(-1, 1) + for trial_data in person_data + ], axis=1) + for person_data in k_data.values() + ], axis=0), + ) + for k, k_data in data.items() + ] + + ks = [k for k, _ in errors] + ys = [np.mean(y.reshape(-1)) for _, y in errors] + es = [np.std(np.mean(y, axis=0)) for _, y in errors] + print('means: ', ys) + print('stddev: ', es) + + plt.clf() + plt.errorbar(ks, ys, yerr=es, fmt='.-', capsize=5) + common_post(dataset, output_path) + + # Write means to file + np.savetxt(output_path[:-3] + 'txt', np.vstack([ + np.array(ks).reshape(1, -1), + np.array(ys).reshape(1, -1), + np.array(es).reshape(1, -1), + ]), fmt='%f') + + # Write means to tensorboard + tensorboard = SummaryWriter(os.path.dirname(output_path)) + for k, e in zip(ks, ys): + tensorboard.add_scalar( + 'meta-test-final/%s/%s' % (meta_learner_identifier, dataset), e, k) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Merge MAML outputs on NGC') + parser.add_argument('input_dir', type=str, + help='Training output directory to source MAML predictions from.') + parser.add_argument('--meta-learner', type=str, choices=['MAML', 'NONE'], + default='MAML', help='Select meta learning output to use') + args = parser.parse_args() + process_dir(args.input_dir, args.meta_learner) diff --git a/src/checkpoints_manager.py b/src/checkpoints_manager.py new file mode 100644 index 0000000..9e02b51 --- /dev/null 
+++ b/src/checkpoints_manager.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. +# -------------------------------------------------------- + + +import logging +import os + +import torch + +ckpt_extension = '.pth.tar' +ckpt_fmtstring = 'at_step_%07d' + ckpt_extension + + +def step_number_from_fname(fpath): + fname = fpath.split('/')[-1] + stem = fname.split('.')[0] + return int(stem.split('_')[-1]) + + +class CheckpointsManager(object): + + def __init__(self, network, output_dir): + self.network = network + self.output_dir = os.path.realpath(output_dir + '/checkpoints') + + @property + def all_available_checkpoint_files(self): + if not os.path.isdir(self.output_dir): + return [] + fpaths = [ + (step_number_from_fname(p), self.output_dir + '/' + p) + for p in os.listdir(self.output_dir) + if os.path.isfile(self.output_dir + '/' + p) + and p.endswith(ckpt_extension) + ] + fpaths = sorted(fpaths) # sort by step number + return fpaths + + def load_last_checkpoint(self): + available_fpaths = self.all_available_checkpoint_files + if len(available_fpaths) > 0: + step_number, fpath = available_fpaths[-1] + logging.info('Found weights file: %s' % fpath) + loaded_step_number = self.load_checkpoint(step_number, fpath) + return loaded_step_number + return 0 + + def load_checkpoint(self, step_number, checkpoint_fpath): + assert os.path.isfile(checkpoint_fpath) + weights = torch.load(checkpoint_fpath) + + # If was stored using DataParallel but being read on 1 GPU + if torch.cuda.device_count() == 1: + if next(iter(weights.keys())).startswith('module.'): + weights = dict([(k[7:], v) for k, v in weights.items()]) + + self.network.load_state_dict(weights) + logging.info('Loaded known model weights at step %d' % step_number) + return step_number + + def save_checkpoint(self, step_number): + assert os.path.isdir(os.path.abspath(self.output_dir + '/../')) + fname = ckpt_fmtstring % step_number + if not os.path.isdir(self.output_dir): + os.makedirs(self.output_dir) + ofpath = '%s/%s' % (self.output_dir, fname) + torch.save(self.network.state_dict(), ofpath) + torch.cuda.empty_cache() diff --git a/src/data.py b/src/data.py new file mode 100644 index 0000000..5b3e000 --- /dev/null +++ b/src/data.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. 
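+#
+# data.py defines HDFDataset, which reads per-person groups from a preprocessed
+# HDF5 file and yields an eye image (optionally a second image of the same
+# person) together with gaze/head pitch-yaw labels and the corresponding
+# rotation matrices. Images are histogram-equalized in YCrCb and rescaled to
+# the range [-1, 1].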
+# -------------------------------------------------------- + +import os +import torch +import numpy as np +from torch.utils.data import Dataset + +import cv2 as cv +import h5py + + +class HDFDataset(Dataset): + """Dataset from HDF5 archives formed of 'groups' of specific persons.""" + + def __init__(self, hdf_file_path, + prefixes=None, + get_2nd_sample=False, + pick_exactly_per_person=None, + pick_at_least_per_person=None): + assert os.path.isfile(hdf_file_path) + self.get_2nd_sample = get_2nd_sample + self.pick_exactly_per_person = pick_exactly_per_person + self.hdf_path = hdf_file_path + self.hdf = None # h5py.File(hdf_file, 'r') + + with h5py.File(self.hdf_path, 'r', libver='latest', swmr=True) as h5f: + hdf_keys = sorted(list(h5f.keys())) + self.prefixes = hdf_keys if prefixes is None else prefixes + if pick_exactly_per_person is not None: + assert pick_at_least_per_person is None + # Pick exactly x many entries from front of group + self.prefixes = [ + k for k in self.prefixes if k in h5f + and len(next(iter(h5f[k].values()))) >= pick_exactly_per_person + ] + self.index_to_query = sum([ + [(prefix, i) for i in range(pick_exactly_per_person)] + for prefix in self.prefixes + ], []) + elif pick_at_least_per_person is not None: + assert pick_exactly_per_person is None + # Pick people for which there exists at least x many entries + self.prefixes = [ + k for k in self.prefixes if k in h5f + and len(next(iter(h5f[k].values()))) >= pick_at_least_per_person + ] + self.index_to_query = sum([ + [(prefix, i) for i in range(len(next(iter(h5f[prefix].values()))))] + for prefix in self.prefixes + ], []) + else: + # Pick all entries of person + self.prefixes = [ # to address erroneous inputs + k for k in self.prefixes if k in h5f + and len(next(iter(h5f[k].values()))) > 0 + ] + self.index_to_query = sum([ + [(prefix, i) for i in range(len(next(iter(h5f[prefix].values()))))] + for prefix in self.prefixes + ], []) + + def __len__(self): + return len(self.index_to_query) + + def close_hdf(self): + if self.hdf is not None: + self.hdf.close() + self.hdf = None + + def preprocess_image(self, image): + ycrcb = cv.cvtColor(image, cv.COLOR_RGB2YCrCb) + ycrcb[:, :, 0] = cv.equalizeHist(ycrcb[:, :, 0]) + image = cv.cvtColor(ycrcb, cv.COLOR_YCrCb2RGB) + image = np.transpose(image, [2, 0, 1]) # Colour image + image = 2.0 * image / 255.0 - 1 + return image + + def preprocess_entry(self, entry): + for key, val in entry.items(): + if isinstance(val, np.ndarray): + entry[key] = torch.from_numpy(val.astype(np.float32)) + elif isinstance(val, int): + # NOTE: maybe ints should be signed and 32-bits sometimes + entry[key] = torch.tensor(val, dtype=torch.int16, requires_grad=False) + return entry + + def __getitem__(self, idx): + if self.hdf is None: # Need to lazy-open this to avoid read error + self.hdf = h5py.File(self.hdf_path, 'r', libver='latest', swmr=True) + + # Pick entry a and b from same person + key_a, idx_a = self.index_to_query[idx] + group_a = self.hdf[key_a] + group_b = group_a + all_indices = list(range(len(next(iter(group_a.values()))))) + all_indices_but_a = np.delete(all_indices, idx_a) + idx_b = np.random.choice(all_indices_but_a) + + def retrieve(group, index): + eyes = self.preprocess_image(group['pixels'][index, :]) + g = group['labels'][index, :2] + h = group['labels'][index, 2:4] + return eyes, g, h + + # Functions to calculate relative rotation matrices for gaze dir. 
and head pose + def R_x(theta): + sin_ = np.sin(theta) + cos_ = np.cos(theta) + return np.array([ + [1., 0., 0.], + [0., cos_, -sin_], + [0., sin_, cos_] + ]). astype(np.float32) + + def R_y(phi): + sin_ = np.sin(phi) + cos_ = np.cos(phi) + return np.array([ + [cos_, 0., sin_], + [0., 1., 0.], + [-sin_, 0., cos_] + ]). astype(np.float32) + + def vector_to_pitchyaw(vectors): + n = vectors.shape[0] + out = np.empty((n, 2)) + vectors = np.divide(vectors, np.linalg.norm(vectors, axis=1).reshape(n, 1)) + out[:, 0] = -np.arcsin(vectors[:, 1]) # theta + out[:, 1] = np.arctan2(vectors[:, 0], vectors[:, 2]) # phi + return out + + def pitchyaw_to_vector(pitchyaws): + n = pitchyaws.shape[0] + sin = np.sin(pitchyaws) + cos = np.cos(pitchyaws) + out = np.empty((n, 3)) + out[:, 0] = np.multiply(cos[:, 0], sin[:, 1]) + out[:, 1] = sin[:, 0] + out[:, 2] = np.multiply(cos[:, 0], cos[:, 1]) + return out + + def calculate_rotation_matrix(e): + return np.matmul(R_y(e[1]), R_x(e[0])) + + # Grab 1st (input) entry + eyes_a, g_a, h_a = retrieve(group_a, idx_a) + entry = { + 'key': key_a, + 'key_index': self.prefixes.index(key_a), + 'image_a': eyes_a, + 'gaze_a': g_a, + 'head_a': h_a, + 'R_gaze_a': calculate_rotation_matrix(g_a), + 'R_head_a': calculate_rotation_matrix(h_a), + } + + if self.get_2nd_sample: + # Grab 2nd entry from same person + eyes_b, g_b, h_b = retrieve(group_b, idx_b) + entry['image_b'] = eyes_b + entry['gaze_b'] = g_b + entry['head_b'] = h_b + entry['R_gaze_b'] = calculate_rotation_matrix(entry['gaze_b']) + entry['R_head_b'] = calculate_rotation_matrix(entry['head_b']) + + return self.preprocess_entry(entry) diff --git a/src/full_train_test_and_plot.bash b/src/full_train_test_and_plot.bash new file mode 100644 index 0000000..e26283b --- /dev/null +++ b/src/full_train_test_and_plot.bash @@ -0,0 +1,113 @@ +#!/bin/bash + +########################### +# Necessary Configurations + +# We skip DT-ED training by default, such that the pre-trained weights +# can be used as-is. Please make sure that you have followed the README.md +# instructions to acquire these weights. +# +# Note: This is different to the `--skip-training` argument to +# `1_train_dt_ed.py` in that it skips the script completely. +# +# Set to 0 to perform DT-ED training and inference for the HDF output. +SKIP_DTED_TRAINING=1 + +# NOTE: please make sure to update the two paths below as necessary. +MPIIGAZE_FILE="../preprocess/outputs/MPIIGaze.h5" +GAZECAPTURE_FILE="../preprocess/outputs/GazeCapture.h5" + +# This batch size should fit a 11GB single GPU +# The original training used 8x Tesla V100 GPUs. +BATCH_SIZE=64 + +# Set the experiment output directory. +# NOTE: make sure to change this if you do not intend to over-write +# previous outputs. +OUTPUT_DIR="outputs_of_full_train_test_and_plot" + + +if [[ $SKIP_DTED_TRAINING -eq 0 ]] +then + ############################ + # 1. 
Perform DT-ED training + # + # The original setup used here was with: + # > Batch size: 1536 + # > # of epochs: 20 + # > # of GPUs: 8 + # > GPU model: Tesla V100 (32GB) + # > Mixed precision training with apex -O1 + + python3 1_train_dt_ed.py \ + --mpiigaze-file ${MPIIGAZE_FILE} \ + --gazecapture-file ${GAZECAPTURE_FILE} \ + \ + --num-training-epochs 20 \ + --batch-size $BATCH_SIZE \ + --eval-batch-size 1024 \ + \ + --normalize-3d-codes \ + --embedding-consistency-loss-type angular \ + --backprop-gaze-to-encoder \ + \ + --num-data-loaders 16 \ + \ + --save-image-samples 20 \ + --use-tensorboard \ + --save-path ${OUTPUT_DIR} \ + + ##################################################################################### + # NOTE: when adding the lines below, make sure to use the backslash ( \ ) correctly, + # such that the full command is correctly constructed and registered. + + # Use (append to above) the line below if wanting to use pre-trained weights, and skip training + # DO NOT JUST UNCOMMENT IT, IT WILL HAVE NO EFFECT DUE TO BASH PARSING + # --skip-training \ + + # Use (append to above) the line below if wanting to use mixed-precision training (as done in the paper) + # DO NOT JUST UNCOMMENT IT, IT WILL HAVE NO EFFECT DUE TO BASH PARSING + # --use-apex \ +fi + +########################### +# 2. Perform Meta Learning +# +# This step processes 6 experiments at a time because a single experiment +# does not make use of the GPU capacity sufficiently well. +# +# Please note, that you need output HDF files from the previous step to +# proceed to the next step. These HDF files are provided to you by default +# in this specific example pipeline. + +ML_COMMON=" --disable-tqdm --output-dir ./" + +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 1 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 2 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 3 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 4 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 5 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 6 & +wait + +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 7 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 8 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 9 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 10 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 11 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 12 & +wait + +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 13 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 14 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 15 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 16 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 17 & +python3 2_meta_learning.py ${ML_COMMON} ${OUTPUT_DIR} 18 & +wait + + +#################################################################### +# 3. 
Collect all of the individual meta-learning experiment results + +python3 3_combine_maml_results.py ${OUTPUT_DIR} diff --git a/src/gazecapture_split.json b/src/gazecapture_split.json new file mode 100644 index 0000000..5b08efa --- /dev/null +++ b/src/gazecapture_split.json @@ -0,0 +1 @@ +{"test-tablet": ["00010", "00509", "00546", "00619", "00653", "00741", "00796", "00876", "00880", "00921", "00932", "01091", "01152", "01183", "01200", "01702", "01805", "01941", "02032", "02155", "02190", "02194", "02364", "02413", "02450", "02480", "02734", "02805", "03117"], "val-tablet": ["00135", "00960", "01373", "02869", "03039"], "train": ["00002", "00003", "00005", "00006", "00024", "00028", "00033", "00034", "00087", "00089", "00097", "00098", "00099", "00102", "00103", "00104", "00114", "00117", "00120", "00121", "00122", "00123", "00124", "00127", "00128", "00130", "00132", "00133", "00137", "00138", "00139", "00140", "00141", "00142", "00143", "00144", "00145", "00146", "00148", "00149", "00150", "00153", "00154", "00156", "00162", "00164", "00165", "00173", "00179", "00191", "00194", "00198", "00200", "00202", "00207", "00208", "00209", "00210", "00211", "00212", "00214", "00218", "00221", "00224", "00225", "00226", "00227", "00228", "00229", "00232", "00234", "00236", "00237", "00238", "00239", "00240", "00241", "00243", "00245", "00247", "00249", "00251", "00252", "00258", "00266", "00268", "00269", "00273", "00274", "00285", "00288", "00289", "00295", "00296", "00299", "00300", "00303", "00304", "00305", "00307", "00309", "00310", "00311", "00312", "00317", "00322", "00324", "00325", "00326", "00331", "00332", "00339", "00342", "00351", "00354", "00355", "00356", "00357", "00358", "00359", "00363", "00376", "00377", "00383", "00459", "00463", "00465", "00466", "00467", "00469", "00472", "00473", "00475", "00477", "00480", "00481", "00487", "00488", "00491", "00492", "00493", "00494", "00495", "00496", "00499", "00500", "00501", "00503", "00505", "00510", "00512", "00513", "00514", "00518", "00519", "00520", "00521", "00522", "00525", "00531", "00533", "00534", "00535", "00539", "00540", "00542", "00544", "00545", "00548", "00549", "00550", "00553", "00554", "00555", "00560", "00562", "00565", "00566", "00569", "00572", "00574", "00575", "00578", "00581", "00584", "00588", "00590", "00595", "00597", "00599", "00600", "00601", "00602", "00605", "00606", "00607", "00610", "00613", "00617", "00621", "00622", "00623", "00624", "00626", "00627", "00632", "00633", "00634", "00636", "00638", "00641", "00642", "00643", "00644", "00645", "00649", "00650", "00658", "00661", "00663", "00666", "00667", "00668", "00669", "00670", "00672", "00675", "00676", "00677", "00678", "00679", "00682", "00683", "00687", "00688", "00690", "00691", "00693", "00694", "00695", "00696", "00699", "00704", "00706", "00707", "00710", "00711", "00712", "00714", "00716", "00718", "00719", "00722", "00728", "00729", "00730", "00731", "00732", "00733", "00737", "00740", "00742", "00743", "00745", "00747", "00748", "00749", "00750", "00752", "00753", "00755", "00756", "00757", "00764", "00765", "00767", "00771", "00772", "00773", "00774", "00775", "00779", "00789", "00790", "00791", "00795", "00798", "00801", "00802", "00804", "00806", "00807", "00808", "00810", "00811", "00812", "00814", "00818", "00819", "00820", "00821", "00823", "00825", "00827", "00828", "00831", "00832", "00833", "00835", "00837", "00840", "00841", "00842", "00849", "00850", "00851", "00852", "00853", "00855", "00859", "00861", "00864", "00865", "00869", 
"00872", "00873", "00874", "00875", "00878", "00881", "00882", "00886", "00888", "00889", "00890", "00891", "00892", "00894", "00896", "00897", "00898", "00899", "00900", "00904", "00905", "00907", "00911", "00912", "00914", "00915", "00923", "00924", "00927", "00930", "00931", "00933", "00934", "00938", "00939", "00944", "00945", "00947", "00948", "00955", "00956", "00961", "00963", "00969", "00971", "00974", "00976", "00980", "00981", "00982", "00983", "00984", "00986", "00989", "00991", "00992", "00997", "00998", "00999", "01000", "01001", "01002", "01003", "01009", "01010", "01012", "01015", "01018", "01019", "01020", "01021", "01022", "01024", "01025", "01029", "01030", "01031", "01032", "01034", "01035", "01038", "01039", "01042", "01044", "01045", "01046", "01050", "01052", "01054", "01055", "01056", "01057", "01058", "01059", "01060", "01062", "01063", "01064", "01065", "01066", "01069", "01070", "01073", "01075", "01076", "01077", "01080", "01081", "01082", "01083", "01084", "01085", "01086", "01087", "01088", "01089", "01090", "01092", "01093", "01095", "01099", "01100", "01102", "01104", "01105", "01106", "01107", "01109", "01110", "01118", "01120", "01121", "01122", "01123", "01126", "01127", "01128", "01129", "01134", "01135", "01138", "01139", "01143", "01145", "01146", "01147", "01149", "01151", "01156", "01157", "01158", "01161", "01162", "01163", "01164", "01165", "01166", "01167", "01168", "01169", "01170", "01171", "01172", "01173", "01174", "01175", "01177", "01178", "01180", "01181", "01182", "01184", "01185", "01186", "01188", "01191", "01195", "01199", "01201", "01204", "01206", "01207", "01208", "01209", "01211", "01212", "01213", "01219", "01221", "01222", "01224", "01225", "01231", "01232", "01233", "01237", "01243", "01244", "01247", "01250", "01252", "01254", "01255", "01256", "01259", "01260", "01262", "01266", "01267", "01269", "01270", "01275", "01276", "01279", "01281", "01282", "01283", "01285", "01293", "01295", "01298", "01300", "01301", "01303", "01304", "01315", "01316", "01320", "01323", "01327", "01328", "01330", "01331", "01333", "01340", "01347", "01348", "01349", "01350", "01351", "01352", "01353", "01354", "01356", "01357", "01358", "01360", "01361", "01362", "01366", "01367", "01368", "01372", "01375", "01377", "01379", "01380", "01382", "01383", "01384", "01386", "01387", "01388", "01389", "01390", "01391", "01392", "01393", "01396", "01400", "01405", "01406", "01414", "01415", "01420", "01421", "01423", "01424", "01428", "01430", "01431", "01432", "01434", "01435", "01438", "01440", "01443", "01445", "01446", "01448", "01451", "01454", "01456", "01459", "01460", "01462", "01467", "01470", "01471", "01472", "01473", "01474", "01478", "01479", "01480", "01481", "01482", "01483", "01485", "01486", "01487", "01488", "01491", "01492", "01496", "01497", "01499", "01508", "01510", "01511", "01514", "01515", "01516", "01519", "01523", "01524", "01528", "01531", "01532", "01533", "01534", "01540", "01542", "01544", "01546", "01551", "01553", "01556", "01566", "01569", "01574", "01577", "01581", "01582", "01583", "01584", "01602", "01603", "01604", "01606", "01611", "01612", "01613", "01617", "01618", "01627", "01630", "01631", "01633", "01635", "01636", "01637", "01640", "01643", "01644", "01645", "01648", "01650", "01651", "01653", "01658", "01661", "01665", "01669", "01671", "01676", "01678", "01680", "01681", "01682", "01684", "01687", "01690", "01692", "01693", "01697", "01698", "01700", "01703", "01705", "01706", "01709", "01710", "01713", "01717", 
"01718", "01719", "01720", "01726", "01727", "01728", "01729", "01730", "01731", "01734", "01738", "01741", "01744", "01745", "01747", "01748", "01755", "01762", "01763", "01768", "01770", "01771", "01775", "01778", "01779", "01783", "01789", "01792", "01795", "01796", "01798", "01802", "01803", "01806", "01809", "01812", "01816", "01817", "01818", "01819", "01821", "01823", "01825", "01826", "01827", "01828", "01833", "01843", "01849", "01858", "01859", "01860", "01862", "01866", "01867", "01868", "01869", "01870", "01874", "01876", "01878", "01880", "01882", "01883", "01884", "01885", "01887", "01888", "01889", "01892", "01896", "01897", "01900", "01901", "01902", "01905", "01906", "01907", "01908", "01912", "01915", "01921", "01922", "01924", "01925", "01926", "01927", "01930", "01933", "01936", "01939", "01943", "01960", "01961", "01962", "01964", "01965", "01966", "01975", "01976", "01977", "01979", "01984", "01987", "01995", "02002", "02009", "02011", "02015", "02019", "02022", "02023", "02024", "02025", "02026", "02027", "02028", "02029", "02033", "02034", "02035", "02038", "02045", "02047", "02048", "02051", "02052", "02056", "02058", "02059", "02061", "02064", "02065", "02077", "02084", "02085", "02086", "02087", "02090", "02092", "02093", "02099", "02102", "02105", "02106", "02112", "02113", "02114", "02115", "02117", "02118", "02119", "02123", "02131", "02136", "02137", "02138", "02140", "02141", "02142", "02152", "02154", "02156", "02159", "02161", "02162", "02165", "02168", "02170", "02172", "02173", "02174", "02186", "02187", "02193", "02198", "02203", "02204", "02206", "02207", "02212", "02216", "02219", "02220", "02223", "02229", "02230", "02232", "02234", "02237", "02241", "02243", "02244", "02249", "02250", "02255", "02257", "02264", "02266", "02267", "02270", "02272", "02277", "02278", "02279", "02282", "02293", "02297", "02298", "02300", "02311", "02314", "02319", "02321", "02322", "02324", "02326", "02327", "02328", "02332", "02334", "02337", "02339", "02342", "02343", "02347", "02349", "02350", "02352", "02353", "02355", "02358", "02359", "02361", "02362", "02365", "02366", "02367", "02368", "02370", "02371", "02373", "02375", "02379", "02394", "02412", "02414", "02415", "02417", "02418", "02420", "02421", "02424", "02426", "02430", "02431", "02432", "02434", "02435", "02436", "02439", "02440", "02441", "02442", "02443", "02445", "02447", "02448", "02452", "02454", "02456", "02457", "02458", "02459", "02462", "02465", "02467", "02468", "02469", "02472", "02474", "02478", "02510", "02518", "02520", "02521", "02522", "02524", "02525", "02526", "02533", "02534", "02535", "02540", "02542", "02547", "02550", "02551", "02552", "02553", "02554", "02557", "02559", "02566", "02567", "02571", "02573", "02575", "02576", "02578", "02581", "02585", "02587", "02588", "02590", "02595", "02610", "02611", "02613", "02615", "02617", "02619", "02622", "02629", "02632", "02634", "02649", "02663", "02666", "02669", "02673", "02681", "02689", "02690", "02700", "02705", "02709", "02713", "02718", "02721", "02722", "02723", "02725", "02729", "02730", "02732", "02737", "02739", "02740", "02741", "02749", "02758", "02760", "02761", "02762", "02763", "02764", "02765", "02772", "02773", "02774", "02776", "02780", "02781", "02785", "02797", "02818", "02819", "02827", "02829", "02832", "02837", "02840", "02841", "02843", "02846", "02847", "02852", "02854", "02857", "02868", "02872", "02873", "02874", "02876", "02877", "02878", "02879", "02880", "02882", "02883", "02888", "02898", "02902", "02908", 
"02911", "02919", "02920", "02921", "02922", "02924", "02925", "02928", "02938", "02941", "02944", "02945", "02954", "02955", "02956", "02960", "02961", "02964", "02967", "02976", "02977", "02978", "02979", "02980", "02984", "02985", "02987", "02988", "02989", "02991", "02997", "02998", "03003", "03004", "03006", "03007", "03009", "03012", "03013", "03014", "03023", "03026", "03027", "03037", "03042", "03051", "03057", "03059", "03060", "03064", "03065", "03079", "03089", "03102", "03107", "03116", "03122", "03125", "03130", "03133", "03134", "03137", "03139", "03160", "03163", "03172", "03174", "03178", "03179", "03180", "03188", "03189", "03190", "03192", "03193", "03197", "03199", "03200", "03205", "03206", "03211", "03212", "03218", "03219", "03222", "03224", "03225", "03231", "03239", "03246", "03248", "03251", "03253", "03255", "03259", "03263", "03265", "03266", "03273", "03275", "03277", "03278", "03282", "03283", "03302", "03303", "03304", "03307", "03314", "03315", "03327", "03328", "03332", "03336", "03340", "03342", "03343", "03348", "03351", "03354", "03358", "03359", "03360", "03367", "03371", "03374", "03375", "03377", "03378", "03379", "03380", "03381", "03382", "03384", "03389", "03397", "03403", "03406", "03413", "03425", "03431", "03432", "03435", "03442", "03453", "03454", "03456", "03463", "03465", "03466", "03467", "03469", "03473", "03474", "03491", "03492", "03495", "03498", "03501", "03502"], "test": ["00010", "00110", "00126", "00178", "00190", "00192", "00220", "00222", "00233", "00319", "00330", "00343", "00382", "00460", "00509", "00511", "00546", "00563", "00580", "00585", "00611", "00616", "00619", "00646", "00653", "00654", "00680", "00686", "00700", "00721", "00741", "00777", "00796", "00868", "00876", "00880", "00921", "00932", "00935", "00949", "00953", "00965", "00968", "01036", "01041", "01051", "01091", "01148", "01152", "01155", "01183", "01200", "01273", "01278", "01286", "01326", "01329", "01370", "01376", "01425", "01457", "01477", "01506", "01517", "01525", "01575", "01625", "01672", "01674", "01689", "01702", "01782", "01794", "01805", "01813", "01830", "01855", "01863", "01877", "01893", "01941", "01959", "01978", "01983", "01985", "01997", "02006", "02020", "02032", "02043", "02078", "02091", "02109", "02155", "02190", "02194", "02197", "02213", "02239", "02240", "02269", "02275", "02281", "02292", "02301", "02348", "02364", "02413", "02419", "02450", "02455", "02461", "02480", "02536", "02601", "02734", "02755", "02756", "02805", "02833", "02851", "02885", "02899", "02942", "02966", "02986", "03011", "03024", "03043", "03117", "03126", "03140", "03177", "03183", "03185", "03202", "03216", "03223", "03247", "03270", "03324", "03326", "03344", "03352", "03361", "03366", "03404", "03412", "03451", "03523"], "test-phone": ["00110", "00126", "00178", "00190", "00192", "00220", "00222", "00233", "00319", "00330", "00343", "00382", "00460", "00511", "00563", "00580", "00585", "00611", "00616", "00646", "00654", "00680", "00686", "00700", "00721", "00777", "00868", "00935", "00949", "00953", "00965", "00968", "01036", "01041", "01051", "01148", "01155", "01273", "01278", "01286", "01326", "01329", "01370", "01376", "01425", "01457", "01477", "01506", "01517", "01525", "01575", "01625", "01672", "01674", "01689", "01782", "01794", "01813", "01830", "01855", "01863", "01877", "01893", "01959", "01978", "01983", "01985", "01997", "02006", "02020", "02043", "02078", "02091", "02109", "02197", "02213", "02239", "02240", "02269", "02275", "02281", "02292", 
"02301", "02348", "02419", "02455", "02461", "02536", "02601", "02755", "02756", "02833", "02851", "02885", "02899", "02942", "02966", "02986", "03011", "03024", "03043", "03126", "03140", "03177", "03183", "03185", "03202", "03216", "03223", "03247", "03270", "03324", "03326", "03344", "03352", "03361", "03366", "03404", "03412", "03451", "03523"], "val": ["00135", "00213", "00267", "00471", "00507", "00547", "00618", "00746", "00776", "00803", "00926", "00960", "01103", "01119", "01220", "01248", "01274", "01297", "01319", "01373", "01404", "01475", "01509", "01543", "01646", "01760", "01773", "01786", "01845", "01998", "02060", "02133", "02166", "02217", "02236", "02265", "02416", "02593", "02650", "02731", "02736", "02869", "03039", "03093", "03214", "03232", "03258", "03312", "03349", "03450"], "train-tablet": ["00028", "00117", "00124", "00133", "00138", "00198", "00207", "00208", "00210", "00229", "00241", "00251", "00252", "00258", "00295", "00317", "00322", "00325", "00326", "00358", "00383", "00463", "00500", "00521", "00549", "00578", "00595", "00597", "00613", "00666", "00696", "00718", "00740", "00748", "00756", "00757", "00779", "00791", "00801", "00804", "00806", "00808", "00825", "00827", "00828", "00842", "00850", "00853", "00861", "00890", "00891", "00894", "00927", "00930", "00934", "00976", "00980", "00981", "00986", "00998", "01001", "01022", "01025", "01029", "01030", "01038", "01039", "01052", "01060", "01066", "01085", "01088", "01089", "01090", "01099", "01104", "01109", "01122", "01126", "01128", "01134", "01151", "01158", "01173", "01206", "01221", "01224", "01225", "01232", "01233", "01243", "01266", "01267", "01269", "01282", "01283", "01295", "01350", "01352", "01366", "01367", "01383", "01384", "01390", "01392", "01432", "01443", "01474", "01508", "01524", "01544", "01556", "01582", "01618", "01627", "01661", "01717", "01726", "01727", "01809", "01819", "01843", "01859", "01862", "01876", "01896", "01901", "01905", "01939", "02002", "02027", "02033", "02048", "02087", "02117", "02118", "02119", "02165", "02174", "02193", "02198", "02204", "02223", "02243", "02267", "02272", "02298", "02342", "02349", "02353", "02368", "02370", "02414", "02417", "02436", "02456", "02522", "02526", "02533", "02542", "02551", "02613", "02622", "02700", "02739", "02761", "02840", "02857", "02878", "02879", "02883", "02902", "02961", "02967", "02976", "02984", "03007", "03027", "03059", "03060", "03212", "03224", "03239", "03266", "03277", "03380", "03389", "03413", "03442", "03474"], "val-phone": ["00213", "00267", "00471", "00507", "00547", "00618", "00746", "00776", "00803", "00926", "01103", "01119", "01220", "01248", "01274", "01297", "01319", "01404", "01475", "01509", "01543", "01646", "01760", "01773", "01786", "01845", "01998", "02060", "02133", "02166", "02217", "02236", "02265", "02416", "02593", "02650", "02731", "02736", "03093", "03214", "03232", "03258", "03312", "03349", "03450"], "train-phone": ["00002", "00003", "00005", "00006", "00024", "00033", "00034", "00087", "00089", "00097", "00098", "00099", "00102", "00103", "00104", "00114", "00120", "00121", "00122", "00123", "00127", "00128", "00130", "00132", "00137", "00139", "00140", "00141", "00142", "00143", "00144", "00145", "00146", "00148", "00149", "00150", "00153", "00154", "00156", "00162", "00164", "00165", "00173", "00179", "00191", "00194", "00200", "00202", "00209", "00211", "00212", "00214", "00218", "00221", "00224", "00225", "00226", "00227", "00228", "00232", "00234", "00236", "00237", "00238", 
"00239", "00240", "00243", "00245", "00247", "00249", "00266", "00268", "00269", "00273", "00274", "00285", "00288", "00289", "00296", "00299", "00300", "00303", "00304", "00305", "00307", "00309", "00310", "00311", "00312", "00324", "00331", "00332", "00339", "00342", "00351", "00354", "00355", "00356", "00357", "00359", "00363", "00376", "00377", "00459", "00465", "00466", "00467", "00469", "00472", "00473", "00475", "00477", "00480", "00481", "00487", "00488", "00491", "00492", "00493", "00494", "00495", "00496", "00499", "00501", "00503", "00505", "00510", "00512", "00513", "00514", "00518", "00519", "00520", "00522", "00525", "00531", "00533", "00534", "00535", "00539", "00540", "00542", "00544", "00545", "00548", "00550", "00553", "00554", "00555", "00560", "00562", "00565", "00566", "00569", "00572", "00574", "00575", "00581", "00584", "00588", "00590", "00599", "00600", "00601", "00602", "00605", "00606", "00607", "00610", "00617", "00621", "00622", "00623", "00624", "00626", "00627", "00632", "00633", "00634", "00636", "00638", "00641", "00642", "00643", "00644", "00645", "00649", "00650", "00658", "00661", "00663", "00667", "00668", "00669", "00670", "00672", "00675", "00676", "00677", "00678", "00679", "00682", "00683", "00687", "00688", "00690", "00691", "00693", "00694", "00695", "00699", "00704", "00706", "00707", "00710", "00711", "00712", "00714", "00716", "00719", "00722", "00728", "00729", "00730", "00731", "00732", "00733", "00737", "00742", "00743", "00745", "00747", "00749", "00750", "00752", "00753", "00755", "00764", "00765", "00767", "00771", "00772", "00773", "00774", "00775", "00789", "00790", "00795", "00798", "00802", "00807", "00810", "00811", "00812", "00814", "00818", "00819", "00820", "00821", "00823", "00831", "00832", "00833", "00835", "00837", "00840", "00841", "00849", "00851", "00852", "00855", "00859", "00864", "00865", "00869", "00872", "00873", "00874", "00875", "00878", "00881", "00882", "00886", "00888", "00889", "00892", "00896", "00897", "00898", "00899", "00900", "00904", "00905", "00907", "00911", "00912", "00914", "00915", "00923", "00924", "00931", "00933", "00938", "00939", "00944", "00945", "00947", "00948", "00955", "00956", "00961", "00963", "00969", "00971", "00974", "00982", "00983", "00984", "00989", "00991", "00992", "00997", "00999", "01000", "01002", "01003", "01009", "01010", "01012", "01015", "01018", "01019", "01020", "01021", "01024", "01031", "01032", "01034", "01035", "01042", "01044", "01045", "01046", "01050", "01054", "01055", "01056", "01057", "01058", "01059", "01062", "01063", "01064", "01065", "01069", "01070", "01073", "01075", "01076", "01077", "01080", "01081", "01082", "01083", "01084", "01086", "01087", "01092", "01093", "01095", "01100", "01102", "01105", "01106", "01107", "01110", "01118", "01120", "01121", "01123", "01127", "01129", "01135", "01138", "01139", "01143", "01145", "01146", "01147", "01149", "01156", "01157", "01161", "01162", "01163", "01164", "01165", "01166", "01167", "01168", "01169", "01170", "01171", "01172", "01174", "01175", "01177", "01178", "01180", "01181", "01182", "01184", "01185", "01186", "01188", "01191", "01195", "01199", "01201", "01204", "01207", "01208", "01209", "01211", "01212", "01213", "01219", "01222", "01231", "01237", "01244", "01247", "01250", "01252", "01254", "01255", "01256", "01259", "01260", "01262", "01270", "01275", "01276", "01279", "01281", "01285", "01293", "01298", "01300", "01301", "01303", "01304", "01315", "01316", "01320", "01323", "01327", "01328", "01330", 
"01331", "01333", "01340", "01347", "01348", "01349", "01351", "01353", "01354", "01356", "01357", "01358", "01360", "01361", "01362", "01368", "01372", "01375", "01377", "01379", "01380", "01382", "01386", "01387", "01388", "01389", "01391", "01393", "01396", "01400", "01405", "01406", "01414", "01415", "01420", "01421", "01423", "01424", "01428", "01430", "01431", "01434", "01435", "01438", "01440", "01445", "01446", "01448", "01451", "01454", "01456", "01459", "01460", "01462", "01467", "01470", "01471", "01472", "01473", "01478", "01479", "01480", "01481", "01482", "01483", "01485", "01486", "01487", "01488", "01491", "01492", "01496", "01497", "01499", "01510", "01511", "01514", "01515", "01516", "01519", "01523", "01528", "01531", "01532", "01533", "01534", "01540", "01542", "01546", "01551", "01553", "01566", "01569", "01574", "01577", "01581", "01583", "01584", "01602", "01603", "01604", "01606", "01611", "01612", "01613", "01617", "01630", "01631", "01633", "01635", "01636", "01637", "01640", "01643", "01644", "01645", "01648", "01650", "01651", "01653", "01658", "01665", "01669", "01671", "01676", "01678", "01680", "01681", "01682", "01684", "01687", "01690", "01692", "01693", "01697", "01698", "01700", "01703", "01705", "01706", "01709", "01710", "01713", "01718", "01719", "01720", "01728", "01729", "01730", "01731", "01734", "01738", "01741", "01744", "01745", "01747", "01748", "01755", "01762", "01763", "01768", "01770", "01771", "01775", "01778", "01779", "01783", "01789", "01792", "01795", "01796", "01798", "01802", "01803", "01806", "01812", "01816", "01817", "01818", "01821", "01823", "01825", "01826", "01827", "01828", "01833", "01849", "01858", "01860", "01866", "01867", "01868", "01869", "01870", "01874", "01878", "01880", "01882", "01883", "01884", "01885", "01887", "01888", "01889", "01892", "01897", "01900", "01902", "01906", "01907", "01908", "01912", "01915", "01921", "01922", "01924", "01925", "01926", "01927", "01930", "01933", "01936", "01943", "01960", "01961", "01962", "01964", "01965", "01966", "01975", "01976", "01977", "01979", "01984", "01987", "01995", "02009", "02011", "02015", "02019", "02022", "02023", "02024", "02025", "02026", "02028", "02029", "02034", "02035", "02038", "02045", "02047", "02051", "02052", "02056", "02058", "02059", "02061", "02064", "02065", "02077", "02084", "02085", "02086", "02090", "02092", "02093", "02099", "02102", "02105", "02106", "02112", "02113", "02114", "02115", "02123", "02131", "02136", "02137", "02138", "02140", "02141", "02142", "02152", "02154", "02156", "02159", "02161", "02162", "02168", "02170", "02172", "02173", "02186", "02187", "02203", "02206", "02207", "02212", "02216", "02219", "02220", "02229", "02230", "02232", "02234", "02237", "02241", "02244", "02249", "02250", "02255", "02257", "02264", "02266", "02270", "02277", "02278", "02279", "02282", "02293", "02297", "02300", "02311", "02314", "02319", "02321", "02322", "02324", "02326", "02327", "02328", "02332", "02334", "02337", "02339", "02343", "02347", "02350", "02352", "02355", "02358", "02359", "02361", "02362", "02365", "02366", "02367", "02371", "02373", "02375", "02379", "02394", "02412", "02415", "02418", "02420", "02421", "02424", "02426", "02430", "02431", "02432", "02434", "02435", "02439", "02440", "02441", "02442", "02443", "02445", "02447", "02448", "02452", "02454", "02457", "02458", "02459", "02462", "02465", "02467", "02468", "02469", "02472", "02474", "02478", "02510", "02518", "02520", "02521", "02524", "02525", "02534", "02535", "02540", 
"02547", "02550", "02552", "02553", "02554", "02557", "02559", "02566", "02567", "02571", "02573", "02575", "02576", "02578", "02581", "02585", "02587", "02588", "02590", "02595", "02610", "02611", "02615", "02617", "02619", "02629", "02632", "02634", "02649", "02663", "02666", "02669", "02673", "02681", "02689", "02690", "02705", "02709", "02713", "02718", "02721", "02722", "02723", "02725", "02729", "02730", "02732", "02737", "02740", "02741", "02749", "02758", "02760", "02762", "02763", "02764", "02765", "02772", "02773", "02774", "02776", "02780", "02781", "02785", "02797", "02818", "02819", "02827", "02829", "02832", "02837", "02841", "02843", "02846", "02847", "02852", "02854", "02868", "02872", "02873", "02874", "02876", "02877", "02880", "02882", "02888", "02898", "02908", "02911", "02919", "02920", "02921", "02922", "02924", "02925", "02928", "02938", "02941", "02944", "02945", "02954", "02955", "02956", "02960", "02964", "02977", "02978", "02979", "02980", "02985", "02987", "02988", "02989", "02991", "02997", "02998", "03003", "03004", "03006", "03009", "03012", "03013", "03014", "03023", "03026", "03037", "03042", "03051", "03057", "03064", "03065", "03079", "03089", "03102", "03107", "03116", "03122", "03125", "03130", "03133", "03134", "03137", "03139", "03160", "03163", "03172", "03174", "03178", "03179", "03180", "03188", "03189", "03190", "03192", "03193", "03197", "03199", "03200", "03205", "03206", "03211", "03218", "03219", "03222", "03225", "03231", "03246", "03248", "03251", "03253", "03255", "03259", "03263", "03265", "03273", "03275", "03278", "03282", "03283", "03302", "03303", "03304", "03307", "03314", "03315", "03327", "03328", "03332", "03336", "03340", "03342", "03343", "03348", "03351", "03354", "03358", "03359", "03360", "03367", "03371", "03374", "03375", "03377", "03378", "03379", "03381", "03382", "03384", "03397", "03403", "03406", "03425", "03431", "03432", "03435", "03453", "03454", "03456", "03463", "03465", "03466", "03467", "03469", "03473", "03491", "03492", "03495", "03498", "03501", "03502"]} \ No newline at end of file diff --git a/src/losses/__init__.py b/src/losses/__init__.py new file mode 100644 index 0000000..67d7368 --- /dev/null +++ b/src/losses/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. +# -------------------------------------------------------- + +from .all_frontals_equal import AllFrontalsEqualLoss +from .batch_hard_triplet import BatchHardTripletLoss +from .gaze_angular import GazeAngularLoss +from .gaze_mse import GazeMSELoss +from .reconstruction_l1 import ReconstructionL1Loss +from .embedding_consistency import EmbeddingConsistencyLoss + +__all__ = ('AllFrontalsEqualLoss', 'BatchHardTripletLoss', + 'GazeAngularLoss', 'GazeMSELoss', + 'ReconstructionL1Loss', 'EmbeddingConsistencyLoss') diff --git a/src/losses/all_frontals_equal.py b/src/losses/all_frontals_equal.py new file mode 100644 index 0000000..4c5c3d2 --- /dev/null +++ b/src/losses/all_frontals_equal.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. 
+# -------------------------------------------------------- + +from collections import OrderedDict +import numpy as np +import torch +import torch.nn.functional as F + + +def nn_batch_angular_distance(a, b): + sim = F.cosine_similarity(a, b, dim=-1, eps=1e-6) + sim = F.hardtanh(sim, 1e-6, 1.0 - 1e-6) + return torch.mean(torch.acos(sim) * (180 / np.pi), dim=1) + + +class AllFrontalsEqualLoss(object): + + def __call__(self, input_dict, output_dict): + # Perform for each gaze and head modes + loss_terms = OrderedDict() + for mode in ['gaze', 'head']: + # Calculate the mean 3D frontalized embedding + all_embeddings = torch.cat([ + output_dict['canon_z_' + mode + '_a'], + output_dict['canon_z_' + mode + '_b'], + ], dim=0) + mean_embedding = torch.mean(all_embeddings, dim=0) + + # Final calculate and reduce to single loss term + loss_terms[mode] = torch.std( + nn_batch_angular_distance(mean_embedding, all_embeddings) + ) + return loss_terms diff --git a/src/losses/batch_hard_triplet.py b/src/losses/batch_hard_triplet.py new file mode 100644 index 0000000..a0a220e --- /dev/null +++ b/src/losses/batch_hard_triplet.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. +# -------------------------------------------------------- + +from collections import OrderedDict +import numpy as np +import torch +import torch.nn.functional as F + + +def nn_batch_angular_distance(a, b): + sim = F.cosine_similarity(a, b, dim=-1, eps=1e-6) + sim = F.hardtanh(sim, 1e-6, 1.0 - 1e-6) + return torch.mean(torch.acos(sim) * (180 / np.pi), dim=1) + + +def nn_batch_euclidean_distance(a, b): + return torch.mean(torch.norm(a - b, dim=-1, p='fro'), dim=-1) + + +class BatchHardTripletLoss(object): + + def __init__(self, distance_type, margin=0.0): + self.margin = margin + + # Select distance function + self.distance_type = distance_type + self.distance_fn = None + if distance_type == 'angular': + self.distance_fn = nn_batch_angular_distance + elif distance_type == 'euclidean': + self.distance_fn = nn_batch_euclidean_distance + else: + raise ValueError('Unknown triplet loss distance type: ' + distance_type) + + # Zero loss tensor for when no triplet can be found + self.zero_loss = torch.tensor(0, dtype=torch.float, requires_grad=False, + device="cuda:0" if torch.cuda.is_available() else "cpu") + + def construct_person_identicality(self, input_dict): + """ + Construct a binary matrix which describes whether for a specific matrix + element, the row-column position mean that the distance is measured + between (inter) or within (intra) people. + The value is 1 for the case where person identity is identical, + and 0 when row and column indices refer to different people. + """ + num_entries = 2 * len(input_dict['key_index']) + person_indices = input_dict['key_index'].repeat(2).view(-1, 1) + person_indices = person_indices.repeat(1, num_entries) + identicality = (person_indices == person_indices.t()).byte() + inv_identicality = 1 - identicality + return identicality.float(), inv_identicality.float() + + def calculate_pairwise_distances(self, output_dict, mode): + """ + For all given pairs in a batch, calculate distances with selected + function, pairwise. This means that for a given batch of size B, + there are 2*B entries. We will calculate (2B)**2 distances. 
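+        For example, with a batch size of B = 2 the concatenated embeddings
+        have shape (4, F, 3) and the returned grid has shape (4, 4), where
+        entry (i, j) is the distance between embeddings i and j.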
+ """ + num_entries = 2 * len(output_dict['canon_z_' + mode + '_a']) + all_embeddings = torch.cat([ + output_dict['canon_z_' + mode + '_a'], + output_dict['canon_z_' + mode + '_b'], + ], dim=0) + a = all_embeddings.view(num_entries, 1, -1, 3).repeat(1, num_entries, 1, 1) + a = a.view(num_entries * num_entries, -1, 3) + b = all_embeddings.repeat(num_entries, 1, 1) + return self.distance_fn(a, b).view(num_entries, num_entries) + + def select_hard_triplets(self, dist_grid, person_identicality, + inv_person_identicality, selected_row_indices): + """ + In this function, we select the largest inter-person distance for each + input entry, and the smallest non-zero intra-person distance for each + input entry. We only select entries which form valid triplets. + """ + dist_same = dist_grid * person_identicality + dist_same_max, _ = torch.max(dist_same, dim=1) + + dist_diff = dist_grid * inv_person_identicality + dist_diff[dist_diff < 1e-6] = 1e6 # set some large value to zero values + dist_diff_min, _ = torch.min(dist_diff, dim=1) + + if len(selected_row_indices) < len(dist_grid): + dist_same_max = torch.take(dist_same_max, selected_row_indices) + dist_diff_min = torch.take(dist_diff_min, selected_row_indices) + return dist_same_max, dist_diff_min + + def __call__(self, input_dict, output_dict): + # Calculate masks + person_identicality, inv_person_identicality = \ + self.construct_person_identicality(input_dict) + + # Select only those that have both same and diff entries + # Basically ensure that there are valid triplets + num_per_row_same = torch.sum(person_identicality.byte(), dim=-1) - 1 + num_per_row_diff = torch.sum(inv_person_identicality.byte(), dim=-1) + num_per_row_both = num_per_row_same * num_per_row_diff + selected_row_indices = torch.nonzero(num_per_row_both) + + # Perform for each gaze and head modes + loss_terms = OrderedDict() + for mode in ['gaze', 'head']: + # Calculate pairwise distances + pairwise_distances = self.calculate_pairwise_distances(output_dict, mode) + + # Reduce to hard samples + d_positive, d_negative = self.select_hard_triplets( + pairwise_distances, person_identicality, + inv_person_identicality, selected_row_indices, + ) + + # Final calculate and reduce to single loss term + stem = mode + '_' + self.distance_type + if len(d_positive) > 0: + loss_terms[stem] = torch.mean(F.softplus(d_positive - d_negative + self.margin)) + loss_terms[stem + '_d_within'] = torch.mean(d_positive) + loss_terms[stem + '_d_between'] = torch.mean(d_negative) + else: + loss_terms[stem] = self.zero_loss + loss_terms[stem + '_d_within'] = self.zero_loss + loss_terms[stem + '_d_between'] = self.zero_loss + return loss_terms diff --git a/src/losses/embedding_consistency.py b/src/losses/embedding_consistency.py new file mode 100644 index 0000000..430c2fc --- /dev/null +++ b/src/losses/embedding_consistency.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. +# -------------------------------------------------------- + +from collections import OrderedDict +import numpy as np +import torch +import torch.nn.functional as F + + +def nn_batch_angular_distance(a, b): + # The inputs here are of shape: B x F x 3 + # where B: batch size + # F: no. 
+    #       F: no. of features
+    # we would like to compare each corresponding feature separately
+    assert a.dim() == b.dim() == 3
+    assert a.shape[-1] == b.shape[-1] == 3
+    sim = F.cosine_similarity(a, b, dim=-1, eps=1e-6)
+    # We now have distances with shape B x F
+
+    # Ensure no NaNs occur due to the input to the arccos function
+    sim = F.hardtanh(sim, 1e-6, 1.0 - 1e-6)
+
+    # Now, we want to convert the similarity measure to degrees and calculate
+    # a single scalar distance value per entry in the batch
+    batch_distance = torch.mean(torch.acos(sim) * (180 / np.pi), dim=1)
+
+    # The output is of length B
+    assert batch_distance.dim() == 1
+    return batch_distance
+
+
+def nn_batch_euclidean_distance(a, b):
+    # The inputs here are of shape: B x F x 3
+    # Let's compare each 3D unit vector feature separately
+    assert a.dim() == b.dim() == 3
+    assert a.shape[-1] == b.shape[-1] == 3
+    featurewise_dists = torch.norm(a - b, dim=-1, p='fro')
+
+    # Calculate a single scalar distance value per entry in the batch
+    entrywise_dists = torch.mean(featurewise_dists, dim=-1)
+    return entrywise_dists
+
+
+class EmbeddingConsistencyLoss(object):
+
+    def __init__(self, distance_type):
+        # Select distance function
+        self.distance_type = distance_type
+        self.distance_fn = None
+        if distance_type == 'angular':
+            self.distance_fn = nn_batch_angular_distance
+        elif distance_type == 'euclidean':
+            self.distance_fn = nn_batch_euclidean_distance
+        else:
+            raise ValueError('Unknown distance type: ' + distance_type)
+
+    def construct_person_identicality(self, input_dict):
+        """
+        Construct a binary matrix which describes whether for a specific matrix
+        element, the row-column position means that the distance is measured
+        between (inter) or within (intra) people.
+        The value is 1 for the case where person identity is identical,
+        and 0 when row and column indices refer to different people.
+        """
+        num_entries = 2 * len(input_dict['key_index'])
+        person_indices = input_dict['key_index'].repeat(2).view(-1, 1)
+        person_indices = person_indices.repeat(1, num_entries)
+        identicality = (person_indices == person_indices.t()).byte()
+        return identicality.float()
+
+    def calculate_pairwise_distances(self, output_dict, mode):
+        """
+        For all given pairs in a batch, calculate distances with selected
+        function, pairwise. This means that for a given batch of size B,
+        there are 2*B entries. We will calculate (2B)**2 distances.
+        """
+        num_entries = 2 * len(output_dict['canon_z_' + mode + '_a'])
+        all_embeddings = torch.cat([
+            output_dict['canon_z_' + mode + '_a'],
+            output_dict['canon_z_' + mode + '_b'],
+        ], dim=0)
+        a = all_embeddings.view(num_entries, 1, -1, 3).repeat(1, num_entries, 1, 1)
+        a = a.view(num_entries * num_entries, -1, 3)
+        b = all_embeddings.repeat(num_entries, 1, 1)
+        return self.distance_fn(a, b).view(num_entries, num_entries)
+
+    def select_max_distances(self, dist_grid, person_identicality):
+        """
+        For each input entry, select the largest distance to any other entry
+        that belongs to the same person. Unlike the triplet loss, no triplets
+        are mined here; these maximum within-person distances are later
+        reduced to their mean.
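+
+        Concretely, given the (2B) x (2B) distance grid D and the same-person
+        mask M from construct_person_identicality, this returns the vector r
+        with r_i = max_j (D_ij * M_ij).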
+ """ + dist_same = dist_grid * person_identicality + dist_same_max, _ = torch.max(dist_same, dim=1) + return dist_same_max + + def __call__(self, input_dict, output_dict): + # Calculate masks + person_identicality = self.construct_person_identicality(input_dict) + + # Perform for each gaze and head modes + loss_terms = OrderedDict() + for mode in ['gaze', 'head']: + # Calculate pairwise distances + pairwise_distances = self.calculate_pairwise_distances(output_dict, mode) + + # Reduce to hard samples + d_positive = self.select_max_distances(pairwise_distances, person_identicality) + + # Final calculate and reduce to single loss term + stem = mode + '_' + self.distance_type + loss_terms[stem] = torch.mean(d_positive) + return loss_terms diff --git a/src/losses/gaze_angular.py b/src/losses/gaze_angular.py new file mode 100644 index 0000000..f8c3f79 --- /dev/null +++ b/src/losses/gaze_angular.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. +# -------------------------------------------------------- + +import numpy as np +import torch +import torch.nn.functional as F + + +def nn_angular_distance(a, b): + sim = F.cosine_similarity(a, b, eps=1e-6) + sim = F.hardtanh(sim, 1e-6, 1.0 - 1e-6) + return torch.acos(sim) * (180 / np.pi) + + +class GazeAngularLoss(object): + + def __init__(self, key_true='gaze_a', key_pred='gaze_a_hat'): + self.key_true = key_true + self.key_pred = key_pred + + def __call__(self, input_dict, output_dict): + def pitchyaw_to_vector(pitchyaws): + sin = torch.sin(pitchyaws) + cos = torch.cos(pitchyaws) + return torch.stack([cos[:, 0] * sin[:, 1], sin[:, 0], cos[:, 0] * cos[:, 1]], 1) + y = pitchyaw_to_vector(input_dict[self.key_true]).detach() + y_hat = output_dict[self.key_pred] + if y_hat.shape[1] == 2: + y_hat = pitchyaw_to_vector(y_hat) + return torch.mean(nn_angular_distance(y, y_hat)) diff --git a/src/losses/gaze_mse.py b/src/losses/gaze_mse.py new file mode 100644 index 0000000..6141620 --- /dev/null +++ b/src/losses/gaze_mse.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. +# -------------------------------------------------------- + +import torch + + +class GazeMSELoss(object): + + def __call__(self, input_dict, output_dict): + def pitchyaw_to_vector(pitchyaws): + sin = torch.sin(pitchyaws) + cos = torch.cos(pitchyaws) + return torch.stack([cos[:, 0] * sin[:, 1], sin[:, 0], cos[:, 0] * cos[:, 1]], 1) + y = pitchyaw_to_vector(input_dict['gaze_a']).detach() + y_hat = output_dict['gaze_a_hat'] + assert y.shape[1] == y_hat.shape[1] == 3 + return torch.mean((y - y_hat) ** 2) diff --git a/src/losses/reconstruction_l1.py b/src/losses/reconstruction_l1.py new file mode 100644 index 0000000..fa2bf41 --- /dev/null +++ b/src/losses/reconstruction_l1.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. 
+# -------------------------------------------------------- + +import torch.nn as nn + + +class ReconstructionL1Loss(object): + + def __init__(self, suffix='b'): + self.suffix = suffix + self.loss_fn = nn.L1Loss(reduction='mean') + + def __call__(self, input_dict, output_dict): + x = input_dict['image_' + self.suffix].detach() + x_hat = output_dict['image_' + self.suffix + '_hat'] + return self.loss_fn(x, x_hat) diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..32b4083 --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. +# -------------------------------------------------------- + +from .densenet import DenseNet +from .dt_ed import DTED + +__all__ = ('DenseNet', 'DTED') diff --git a/src/models/densenet.py b/src/models/densenet.py new file mode 100644 index 0000000..1a542b3 --- /dev/null +++ b/src/models/densenet.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. +# -------------------------------------------------------- + +import torch +import torch.nn as nn + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class DenseNet(nn.Module): + + p_dropout = 0.0 # DON'T use this + + num_blocks = 4 + num_layers_per_block = 4 + use_bottleneck = False # Enabling this usually makes training unstable + compression_factor = 1.0 # Makes training less stable if != 1.0 + + fc_feats = [2] + + def __init__(self, growth_rate=8, activation_fn=nn.ReLU, + normalization_fn=nn.BatchNorm2d): + super(DenseNet, self).__init__() + + # Initial down-sampling conv layers + self.initial = DenseNetInitialLayers(growth_rate=growth_rate, + activation_fn=activation_fn, + normalization_fn=normalization_fn) + c_now = self.initial.c_now + + assert (self.num_layers_per_block % 2) == 0 + for i in range(self.num_blocks): + i_ = i + 1 + # Define dense block + self.add_module('block%d' % i_, DenseNetBlock( + c_now, + num_layers=(int(self.num_layers_per_block / 2) + if self.use_bottleneck + else self.num_layers_per_block), + growth_rate=growth_rate, + p_dropout=self.p_dropout, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + use_bottleneck=self.use_bottleneck, + )) + c_now = list(self.children())[-1].c_now + + # Define transition block if not last layer + if i < (self.num_blocks - 1): + self.add_module('trans%d' % i_, DenseNetTransitionDown( + c_now, p_dropout=self.p_dropout, + compression_factor=self.compression_factor, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + )) + c_now = list(self.children())[-1].c_now + + # Final FC layers + self.fcs = [] + f_now = c_now + for f in self.fc_feats: + fc = nn.Linear(f_now, f).to(device) + fc.weight.data.normal_(0, 0.01) + fc.bias.data.fill_(0) + self.fcs.append(fc) + f_now = f + self.fcs = nn.ModuleList(self.fcs) + + def forward(self, x): + # Apply initial layers and dense blocks + for name, module in self.named_children(): + if name == 'fcs': + break + x = module(x) + + # Global average pooling + x = torch.mean(x, dim=2) # reduce h + x = torch.mean(x, dim=2) # reduce w + + # fc to gaze direction + for fc 
in self.fcs: + x = fc(x) + + return x + + +class DenseNetInitialLayers(nn.Module): + + def __init__(self, growth_rate=8, activation_fn=nn.ReLU, + normalization_fn=nn.BatchNorm2d): + super(DenseNetInitialLayers, self).__init__() + c_next = 2 * growth_rate + self.conv1 = nn.Conv2d(3, c_next, bias=False, + kernel_size=3, stride=2, padding=1) + nn.init.kaiming_normal_(self.conv1.weight.data) + + self.norm = normalization_fn(c_next, track_running_stats=False).to(device) + self.act = activation_fn(inplace=True) + + c_out = 4 * growth_rate + self.conv2 = nn.Conv2d(2 * growth_rate, c_out, bias=False, + kernel_size=3, stride=2, padding=1) + nn.init.kaiming_normal_(self.conv2.weight.data) + + self.c_now = c_out + self.c_list = [c_next, c_out] + + def forward(self, x): + x = self.conv1(x) + x = self.norm(x) + x = self.act(x) + prev_scale_x = x + x = self.conv2(x) + return x, prev_scale_x + + +class DenseNetBlock(nn.Module): + + def __init__(self, c_in, num_layers=4, growth_rate=8, p_dropout=0.1, + use_bottleneck=False, activation_fn=nn.ReLU, + normalization_fn=nn.BatchNorm2d, transposed=False): + super(DenseNetBlock, self).__init__() + self.use_bottleneck = use_bottleneck + c_now = c_in + for i in range(num_layers): + i_ = i + 1 + if use_bottleneck: + self.add_module('bneck%d' % i_, DenseNetCompositeLayer( + c_now, 4 * growth_rate, kernel_size=1, p_dropout=p_dropout, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + )) + self.add_module('compo%d' % i_, DenseNetCompositeLayer( + 4 * growth_rate if use_bottleneck else c_now, growth_rate, + kernel_size=3, p_dropout=p_dropout, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + transposed=transposed, + )) + c_now += list(self.children())[-1].c_now + self.c_now = c_now + + def forward(self, x): + x_before = x + for i, (name, module) in enumerate(self.named_children()): + if ((self.use_bottleneck and name.startswith('bneck')) + or name.startswith('compo')): + x_before = x + x = module(x) + if name.startswith('compo'): + x = torch.cat([x_before, x], dim=1) + return x + + +class DenseNetTransitionDown(nn.Module): + + def __init__(self, c_in, compression_factor=0.1, p_dropout=0.1, + activation_fn=nn.ReLU, normalization_fn=nn.BatchNorm2d): + super(DenseNetTransitionDown, self).__init__() + c_out = int(compression_factor * c_in) + self.composite = DenseNetCompositeLayer( + c_in, c_out, + kernel_size=1, p_dropout=p_dropout, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + ) + self.pool = nn.AvgPool2d(kernel_size=2, stride=2) + self.c_now = c_out + + def forward(self, x): + x = self.composite(x) + x = self.pool(x) + return x + + +class DenseNetCompositeLayer(nn.Module): + + def __init__(self, c_in, c_out, kernel_size=3, growth_rate=8, p_dropout=0.1, + activation_fn=nn.ReLU, normalization_fn=nn.BatchNorm2d, + transposed=False): + super(DenseNetCompositeLayer, self).__init__() + self.norm = normalization_fn(c_in, track_running_stats=False).to(device) + self.act = activation_fn(inplace=True) + if transposed: + assert kernel_size > 1 + self.conv = nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel_size, + padding=1 if kernel_size > 1 else 0, + stride=1, bias=False).to(device) + else: + self.conv = nn.Conv2d(c_in, c_out, kernel_size=kernel_size, stride=1, + padding=1 if kernel_size > 1 else 0, bias=False).to(device) + nn.init.kaiming_normal_(self.conv.weight.data) + self.drop = nn.Dropout2d(p=p_dropout) if p_dropout > 1e-5 else None + self.c_now = c_out + + def forward(self, x): + x = self.norm(x) + x = 
self.act(x) + x = self.conv(x) + if self.drop is not None: + x = self.drop(x) + return x diff --git a/src/models/dt_ed.py b/src/models/dt_ed.py new file mode 100644 index 0000000..9ce90b0 --- /dev/null +++ b/src/models/dt_ed.py @@ -0,0 +1,386 @@ +#!/usr/bin/env python3 + +# -------------------------------------------------------- +# Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License (1-Way Commercial) +# Code written by Seonwook Park, Shalini De Mello. +# -------------------------------------------------------- + +from collections import OrderedDict + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .densenet import ( + DenseNetInitialLayers, + DenseNetBlock, + DenseNetTransitionDown, +) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class DTED(nn.Module): + + def __init__(self, z_dim_app, z_dim_gaze, z_dim_head, + growth_rate=32, activation_fn=nn.LeakyReLU, + normalization_fn=nn.InstanceNorm2d, + decoder_input_c=16, + normalize_3d_codes=False, + normalize_3d_codes_axis=None, + use_triplet=False, + gaze_hidden_layer_neurons=64, + backprop_gaze_to_encoder=False, + ): + super(DTED, self).__init__() + + # Cache some specific configurations + self.normalize_3d_codes = normalize_3d_codes + self.normalize_3d_codes_axis = normalize_3d_codes_axis + self.use_triplet = use_triplet + self.gaze_hidden_layer_neurons = gaze_hidden_layer_neurons + self.backprop_gaze_to_encoder = backprop_gaze_to_encoder + if self.normalize_3d_codes: + assert self.normalize_3d_codes_axis is not None + + # Define feature map dimensions at bottleneck + bottleneck_shape = (2, 8) + self.bottleneck_shape = bottleneck_shape + + # The meaty parts + self.encoder = DenseNetEncoder( + num_blocks=4, + growth_rate=growth_rate, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + ) + c_now = list(self.children())[-1].c_now + self.decoder_input_c = decoder_input_c + enc_num_all = np.prod(bottleneck_shape) * self.decoder_input_c + self.decoder = DenseNetDecoder( + self.decoder_input_c, + num_blocks=4, + growth_rate=growth_rate, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + compression_factor=1.0, + ) + + # The latent code parts + self.z_dim_app = z_dim_app + self.z_dim_gaze = z_dim_gaze + self.z_dim_head = z_dim_head + z_num_all = 3 * (z_dim_gaze + z_dim_head) + z_dim_app + + self.fc_enc = self.linear(c_now, z_num_all) + self.fc_dec = self.linear(z_num_all, enc_num_all) + + self.build_gaze_layers(3 * z_dim_gaze) + + def build_gaze_layers(self, num_input_neurons, num_hidden_neurons=64): + self.gaze1 = self.linear(num_input_neurons, self.gaze_hidden_layer_neurons) + self.gaze2 = self.linear(self.gaze_hidden_layer_neurons, 3) + + def linear(self, f_in, f_out): + fc = nn.Linear(f_in, f_out) + nn.init.kaiming_normal(fc.weight.data) + nn.init.constant(fc.bias.data, val=0) + return fc + + def rotate_code(self, data, code, mode, fr=None, to=None): + """Must calculate transposed rotation matrices to be able to + post-multiply to 3D codes.""" + key_stem = 'R_' + mode + if fr is not None and to is not None: + rotate_mat = torch.matmul( + data[key_stem + '_' + fr], + torch.transpose(data[key_stem + '_' + to], 1, 2) + ) + elif to is not None: + rotate_mat = torch.transpose(data[key_stem + '_' + to], 1, 2) + elif fr is not None: + # transpose-of-inverse is itself + rotate_mat = data[key_stem + '_' + fr] + return torch.matmul(code, rotate_mat) + + def encode_to_z(self, data, suffix): + x = 
self.encoder(data['image_' + suffix]) + enc_output_shape = x.shape + x = x.mean(-1).mean(-1) # Global-Average Pooling + + # Create latent codes + z_all = self.fc_enc(x) + z_app = z_all[:, :self.z_dim_app] + z_all = z_all[:, self.z_dim_app:] + z_all = z_all.view(self.batch_size, -1, 3) + z_gaze_enc = z_all[:, :self.z_dim_gaze, :] + z_head_enc = z_all[:, self.z_dim_gaze:, :] + + z_gaze_enc = z_gaze_enc.view(self.batch_size, -1, 3) + z_head_enc = z_head_enc.view(self.batch_size, -1, 3) + return [z_app, z_gaze_enc, z_head_enc, x, enc_output_shape] + + def decode_to_image(self, codes): + z_all = torch.cat([code.view(self.batch_size, -1) for code in codes], dim=1) + x = self.fc_dec(z_all) + x = x.view(self.batch_size, self.decoder_input_c, *self.bottleneck_shape) + x = self.decoder(x) + return x + + def maybe_do_norm(self, code): + if self.normalize_3d_codes: + norm_axis = self.normalize_3d_codes_axis + assert code.dim() == 3 + assert code.shape[-1] == 3 + if norm_axis == 3: + b, f, _ = code.shape + code = code.view(b, -1) + normalized_code = F.normalize(code, dim=-1) + return normalized_code.view(b, f, -1) + else: + return F.normalize(code, dim=norm_axis) + return code + + def forward(self, data, loss_functions=None): + is_inference_time = ('image_b' not in data) + self.batch_size = data['image_a'].shape[0] + + # Encode input from a + (z_a_a, ze1_g_a, ze1_h_a, ze1_before_z_a, _) = self.encode_to_z(data, 'a') + if not is_inference_time: + z_a_b, ze1_g_b, ze1_h_b, _, _ = self.encode_to_z(data, 'b') + + # Make each row a unit vector through L2 normalization to constrain + # embeddings to the surface of a hypersphere + if self.normalize_3d_codes: + assert ze1_g_a.dim() == ze1_h_a.dim() == 3 + assert ze1_g_a.shape[-1] == ze1_h_a.shape[-1] == 3 + ze1_g_a = self.maybe_do_norm(ze1_g_a) + ze1_h_a = self.maybe_do_norm(ze1_h_a) + if not is_inference_time: + ze1_g_b = self.maybe_do_norm(ze1_g_b) + ze1_h_b = self.maybe_do_norm(ze1_h_b) + + # Gaze estimation output for image a + if self.backprop_gaze_to_encoder: + gaze_features = ze1_g_a.clone().view(self.batch_size, -1) + else: + # Detach input embeddings from graph! 
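+            # (The gaze head is then trained on a detached copy of the gaze
+            # code, so gradients from the gaze loss do not reach the encoder;
+            # the clone() branch above keeps that gradient path intact.)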
+ gaze_features = ze1_g_a.detach().view(self.batch_size, -1) + gaze_a_hat = self.gaze2(F.relu_(self.gaze1(gaze_features))) + gaze_a_hat = F.normalize(gaze_a_hat, dim=-1) + + output_dict = { + 'gaze_a_hat': gaze_a_hat, + 'z_app': z_a_a, + 'z_gaze_enc': ze1_g_a, + 'z_head_enc': ze1_h_a, + 'canon_z_gaze_a': self.rotate_code(data, ze1_g_a, 'gaze', fr='a'), + 'canon_z_head_a': self.rotate_code(data, ze1_h_a, 'head', fr='a'), + } + if 'R_gaze_b' not in data: + return output_dict + + if not is_inference_time: + output_dict['canon_z_gaze_b'] = self.rotate_code(data, ze1_g_b, 'gaze', fr='b') + output_dict['canon_z_head_b'] = self.rotate_code(data, ze1_h_b, 'head', fr='b') + + # Rotate codes + zd1_g_b = self.rotate_code(data, ze1_g_a, 'gaze', fr='a', to='b') + zd1_h_b = self.rotate_code(data, ze1_h_a, 'head', fr='a', to='b') + output_dict['z_gaze_dec'] = zd1_g_b + output_dict['z_head_dec'] = zd1_h_b + + # Reconstruct + x_b_hat = self.decode_to_image([z_a_a, zd1_g_b, zd1_h_b]) + output_dict['image_b_hat'] = x_b_hat + + # If loss functions specified, apply them + if loss_functions is not None: + losses_dict = OrderedDict() + for key, func in loss_functions.items(): + losses = func(data, output_dict) # may be dict or single value + if isinstance(losses, dict): + for sub_key, loss in losses.items(): + losses_dict[key + '_' + sub_key] = loss + else: + losses_dict[key] = losses + return output_dict, losses_dict + + return output_dict + + +class DenseNetEncoder(nn.Module): + + def __init__(self, growth_rate=8, num_blocks=4, num_layers_per_block=4, + p_dropout=0.0, compression_factor=1.0, + activation_fn=nn.ReLU, normalization_fn=nn.BatchNorm2d): + super(DenseNetEncoder, self).__init__() + self.c_at_end_of_each_scale = [] + + # Initial down-sampling conv layers + self.initial = DenseNetInitialLayers(growth_rate=growth_rate, + activation_fn=activation_fn, + normalization_fn=normalization_fn) + c_now = list(self.children())[-1].c_now + self.c_at_end_of_each_scale += list(self.children())[-1].c_list + + assert (num_layers_per_block % 2) == 0 + for i in range(num_blocks): + i_ = i + 1 + # Define dense block + self.add_module('block%d' % i_, DenseNetBlock( + c_now, + num_layers=num_layers_per_block, + growth_rate=growth_rate, + p_dropout=p_dropout, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + )) + c_now = list(self.children())[-1].c_now + self.c_at_end_of_each_scale.append(c_now) + + # Define transition block if not last layer + if i < (num_blocks - 1): + self.add_module('trans%d' % i_, DenseNetTransitionDown( + c_now, p_dropout=p_dropout, + compression_factor=compression_factor, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + )) + c_now = list(self.children())[-1].c_now + self.c_now = c_now + + def forward(self, x): + # Apply initial layers and dense blocks + for name, module in self.named_children(): + if name == 'initial': + x, prev_scale_x = module(x) + else: + x = module(x) + return x + + +class DenseNetDecoder(nn.Module): + + def __init__(self, c_in, growth_rate=8, num_blocks=4, num_layers_per_block=4, + p_dropout=0.0, compression_factor=1.0, + activation_fn=nn.ReLU, normalization_fn=nn.BatchNorm2d, + use_skip_connections_from=None): + super(DenseNetDecoder, self).__init__() + + self.use_skip_connections = (use_skip_connections_from is not None) + if self.use_skip_connections: + c_to_concat = use_skip_connections_from.c_at_end_of_each_scale + c_to_concat = list(reversed(c_to_concat))[1:] + else: + c_to_concat = [0] * (num_blocks + 2) + + assert 
(num_layers_per_block % 2) == 0 + c_now = c_in + for i in range(num_blocks): + i_ = i + 1 + # Define dense block + self.add_module('block%d' % i_, DenseNetBlock( + c_now, + num_layers=num_layers_per_block, + growth_rate=growth_rate, + p_dropout=p_dropout, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + transposed=True, + )) + c_now = list(self.children())[-1].c_now + + # Define transition block if not last layer + if i < (num_blocks - 1): + self.add_module('trans%d' % i_, DenseNetTransitionUp( + c_now, p_dropout=p_dropout, + compression_factor=compression_factor, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + )) + c_now = list(self.children())[-1].c_now + c_now += c_to_concat[i] + + # Last up-sampling conv layers + self.last = DenseNetDecoderLastLayers(c_now, + growth_rate=growth_rate, + activation_fn=activation_fn, + normalization_fn=normalization_fn, + skip_connection_growth=c_to_concat[-1]) + self.c_now = 1 + + def forward(self, x): + # Apply initial layers and dense blocks + for name, module in self.named_children(): + x = module(x) + return x + + +class DenseNetDecoderLastLayers(nn.Module): + + def __init__(self, c_in, growth_rate=8, activation_fn=nn.ReLU, + normalization_fn=nn.BatchNorm2d, + skip_connection_growth=0): + super(DenseNetDecoderLastLayers, self).__init__() + # First deconv + self.conv1 = nn.ConvTranspose2d(c_in, 4 * growth_rate, bias=False, + kernel_size=3, stride=2, padding=1, + output_padding=1) + nn.init.kaiming_normal_(self.conv1.weight.data) + + # Second deconv + c_in = 4 * growth_rate + skip_connection_growth + self.norm2 = normalization_fn(c_in, track_running_stats=False).to(device) + self.act = activation_fn(inplace=True) + self.conv2 = nn.ConvTranspose2d(c_in, 2 * growth_rate, bias=False, + kernel_size=3, stride=2, padding=1, + output_padding=1) + nn.init.kaiming_normal_(self.conv2.weight.data) + + # Final conv + c_in = 2 * growth_rate + c_out = 3 + self.norm3 = normalization_fn(c_in, track_running_stats=False).to(device) + self.conv3 = nn.Conv2d(c_in, c_out, bias=False, + kernel_size=3, stride=1, padding=1) + nn.init.kaiming_normal_(self.conv3.weight.data) + self.c_now = c_out + + def forward(self, x): + x = self.conv1(x) + # + x = self.norm2(x) + x = self.act(x) + x = self.conv2(x) + # + x = self.norm3(x) + x = self.act(x) + x = self.conv3(x) + return x + + +class DenseNetTransitionUp(nn.Module): + + def __init__(self, c_in, compression_factor=0.1, p_dropout=0.1, + activation_fn=nn.ReLU, normalization_fn=nn.BatchNorm2d): + super(DenseNetTransitionUp, self).__init__() + c_out = int(compression_factor * c_in) + self.norm = normalization_fn(c_in, track_running_stats=False).to(device) + self.act = activation_fn(inplace=True) + self.conv = nn.ConvTranspose2d(c_in, c_out, kernel_size=3, + stride=2, padding=1, output_padding=1, + bias=False).to(device) + nn.init.kaiming_normal_(self.conv.weight.data) + self.drop = nn.Dropout2d(p=p_dropout) if p_dropout > 1e-5 else None + self.c_now = c_out + + def forward(self, x): + x = self.norm(x) + x = self.act(x) + x = self.conv(x) + return x
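
A minimal usage sketch of the pieces above, for orientation; it is illustrative rather than part of the diff. The latent code sizes and batch size are arbitrary example values, the 64x256 input resolution is an assumption derived from the 2x8 bottleneck and the five 2x down-samplings in the encoder, and the dictionary keys mirror those consumed by DTED.forward. It assumes it is run from within src/ with the PyTorch version targeted by requirements.txt.

    #!/usr/bin/env python3
    import torch

    from models import DTED
    from losses.gaze_angular import GazeAngularLoss

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Latent code sizes here are arbitrary example values, not repo defaults
    network = DTED(z_dim_app=64, z_dim_gaze=4, z_dim_head=16).to(device)
    network.eval()

    batch_size = 4
    data = {
        # Only view "a" is provided, so DTED.forward takes its inference path
        # (no 'image_b' / 'R_gaze_b'). 64x256 is the assumed input size.
        'image_a': torch.randn(batch_size, 3, 64, 256, device=device),
        'R_gaze_a': torch.eye(3, device=device).repeat(batch_size, 1, 1),
        'R_head_a': torch.eye(3, device=device).repeat(batch_size, 1, 1),
        # Ground-truth gaze as pitch/yaw, consumed only by the loss below
        'gaze_a': torch.zeros(batch_size, 2, device=device),
    }

    with torch.no_grad():
        output = network(data)  # 'gaze_a_hat' plus canonicalized codes
        loss = GazeAngularLoss()(data, output)
    print(output['gaze_a_hat'].shape, float(loss))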