diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..fa3e3ffcf --- /dev/null +++ b/.gitignore @@ -0,0 +1,134 @@ +# LaTex +main.pdf +supp.pdf +**/*.aux +**/*.log +**/*.synctex.gz +**/*.aux +**/*.bbl +**/*.blg +**/*.brf +**/*.sublime-project +**/*.sublime-workspace +**/*.fdb_latexmk +**/*.fls +**/*.toc + +tools/debug.sh + +# MacOS stuff +.DS_Store +**/.DS_Store + +**/__pycache__ +**/*.pyc +**/.settings +.project +.pydevproject + +# external/* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/INSTALL.md b/INSTALL.md index c132d760d..bd34e63c0 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -16,6 +16,7 @@ - psutil: `pip install psutil` - OpenCV: `pip install opencv-python` - torchvision: `pip install torchvision` or `conda install torchvision -c pytorch` +- librosa: `pip install librosa` (if using Audiovisual SlowFast Networks) - tensorboard: `pip install tensorboard` - moviepy: (optional, for visualizing video on tensorboard) `conda install -c conda-forge moviepy` or `pip install moviepy` - [Detectron2](https://github.com/facebookresearch/detectron2): diff --git a/configs/Kinetics/AVSLOWFAST_4x16_R50.yaml b/configs/Kinetics/AVSLOWFAST_4x16_R50.yaml new file mode 100644 index 000000000..bc4d1b1f4 --- /dev/null +++ b/configs/Kinetics/AVSLOWFAST_4x16_R50.yaml @@ -0,0 +1,108 @@ +TRAIN: + ENABLE: True + DATASET: kinetics + BATCH_SIZE: 64 + EVAL_PERIOD: 10 + CHECKPOINT_PERIOD: 1 + AUTO_RESUME: True + # CHECKPOINT_FILE_PATH: ../../data/output/checkpoints/avslowfast.pth + # CHECKPOINT_TYPE: pytorch # caffe2 or pytorch +DATA: + USE_BGR_ORDER: False # False + NUM_FRAMES: 32 + SAMPLING_RATE: 2 + TRAIN_JITTER_SCALES: [256, 320] + TRAIN_CROP_SIZE: 224 + TEST_CROP_SIZE: 256 + INPUT_CHANNEL_NUM: [3, 3, 1] + USE_AUDIO: True + GET_MISALIGNED_AUDIO: True + AUDIO_SAMPLE_RATE: 16000 + AUDIO_WIN_SZ: 32 + AUDIO_STEP_SZ: 16 + AUDIO_FRAME_NUM: 128 + AUDIO_MEL_NUM: 80 + AUDIO_MISALIGNED_GAP: 32 # half second + LOGMEL_MEAN: -7.03 # -7.03, -24.227 + LOGMEL_STD: 4.66 # 4.66, 1.0 + EASY_NEG_RATIO: 0.75 + MIX_NEG_EPOCH: 96 +SLOWFAST: + ALPHA: 8 + BETA_INV: 8 + FUSION_CONV_CHANNEL_RATIO: 2 + FUSION_KERNEL_SZ: 5 + AU_ALPHA: 32 + AU_BETA_INV: 2 + AU_FUSION_CONV_CHANNEL_MODE: ByDim # ByDim, ByRatio + AU_FUSION_CONV_CHANNEL_RATIO: 0.25 + 
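+  # With ByDim, the intermediate width of the A->FS fusion convs is taken from
+  # AU_FUSION_CONV_CHANNEL_DIM below; AU_FUSION_CONV_CHANNEL_RATIO is only read
+  # in ByRatio mode (see FuseAV in video_model_builder.py).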
AU_FUSION_CONV_CHANNEL_DIM: 64 + AU_FUSION_KERNEL_SZ: 5 + AU_FUSION_CONV_NUM: 2 + AU_REDUCE_TF_DIM: True + FS_FUSION: [False, False, True, True] + AFS_FUSION: [False, False, True, True] + AVS_FLAG: [False, False, True, True, True] + AVS_PROJ_DIM: 64 + AVS_VAR_THRESH: 0.01 + AVS_DUPLICATE_THRESH: 0.99999 + DROPPATHWAY_RATE: 0.8 # 0.8 +RESNET: + ZERO_INIT_FINAL_BN: True + WIDTH_PER_GROUP: 64 + NUM_GROUPS: 1 + DEPTH: 50 + TRANS_FUNC: bottleneck_transform + AUDIO_TRANS_FUNC: tf_bottleneck_transform_v1 + AUDIO_TRANS_NUM: 2 + STRIDE_1X1: False + # 18: [[2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + # 34: [[3, 3, 3], [4, 4, 4], [6, 6, 6], [3, 3, 3]] + # 50: [[3, 3, 3], [4, 4, 4], [6, 6, 6], [3, 3, 3]] + # 101: [[3, 3, 3], [4, 4, 4], [23, 23, 23], [3, 3, 3]] + # 152: [[3, 3, 3], [8, 8, 8], [36, 36, 36], [3, 3, 3]] + NUM_BLOCK_TEMP_KERNEL: [[3, 3, 3], [4, 4, 4], [6, 6, 6], [3, 3, 3]] + SPATIAL_DILATIONS: [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]] +NONLOCAL: + LOCATION: [[[], [], []], [[], [], []], [[], [], []], [[], [], []]] + GROUP: [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]] + POOL: [ + [[1, 2, 2], [1, 2, 2], [1, 2, 2]], + [[1, 2, 2], [1, 2, 2], [1, 2, 2]], + [[1, 2, 2], [1, 2, 2], [1, 2, 2]], + [[1, 2, 2], [1, 2, 2], [1, 2, 2]], + ] + INSTANTIATION: dot_product +BN: + USE_PRECISE_STATS: True + NUM_BATCHES_PRECISE: 200 + MOMENTUM: 0.1 + WEIGHT_DECAY: 0.0 +SOLVER: + BASE_LR: 0.1 # 0.1 + LR_POLICY: cosine + MAX_EPOCH: 196 + MOMENTUM: 0.9 + WEIGHT_DECAY: 1e-4 + WARMUP_EPOCHS: 34.0 # 34.0 + WARMUP_START_LR: 0.01 # 0.01 + OPTIMIZING_METHOD: sgd +MODEL: + NUM_CLASSES: 400 + MODEL_NAME: AVSlowFast + ARCH: avslowfast + LOSS_FUNC: cross_entropy + DROPOUT_RATE: 0.5 +TEST: + ENABLE: True + DATASET: kinetics + BATCH_SIZE: 64 + # CHECKPOINT_FILE_PATH: ../../data/output/checkpoints/avslowfast.pth + # CHECKPOINT_TYPE: pytorch # caffe2 or pytorch +DATA_LOADER: + NUM_WORKERS: 8 # 8 + PIN_MEMORY: True +NUM_GPUS: 8 +NUM_SHARDS: 1 +RNG_SEED: 0 +OUTPUT_DIR: ./output/AVSlowFast-R50-4x16 diff --git a/configs/Kinetics/AVSLOWFAST_8x8_R50.yaml b/configs/Kinetics/AVSLOWFAST_8x8_R50.yaml new file mode 100644 index 000000000..bfa2be51c --- /dev/null +++ b/configs/Kinetics/AVSLOWFAST_8x8_R50.yaml @@ -0,0 +1,108 @@ +TRAIN: + ENABLE: True + DATASET: kinetics + BATCH_SIZE: 32 + EVAL_PERIOD: 10 + CHECKPOINT_PERIOD: 1 + AUTO_RESUME: True + # CHECKPOINT_FILE_PATH: ../../data/output/checkpoints/avslowfast.pth + # CHECKPOINT_TYPE: pytorch # caffe2 or pytorch +DATA: + USE_BGR_ORDER: False # False + NUM_FRAMES: 32 + SAMPLING_RATE: 2 + TRAIN_JITTER_SCALES: [256, 320] + TRAIN_CROP_SIZE: 224 + TEST_CROP_SIZE: 256 + INPUT_CHANNEL_NUM: [3, 3, 1] + USE_AUDIO: True + GET_MISALIGNED_AUDIO: True + AUDIO_SAMPLE_RATE: 16000 + AUDIO_WIN_SZ: 32 + AUDIO_STEP_SZ: 16 + AUDIO_FRAME_NUM: 128 + AUDIO_MEL_NUM: 80 + AUDIO_MISALIGNED_GAP: 32 # half second + LOGMEL_MEAN: -7.03 # -7.03, -24.227 + LOGMEL_STD: 4.66 # 4.66, 1.0 + EASY_NEG_RATIO: 0.75 + MIX_NEG_EPOCH: 96 +SLOWFAST: + ALPHA: 4 + BETA_INV: 8 + FUSION_CONV_CHANNEL_RATIO: 2 + FUSION_KERNEL_SZ: 7 + AU_ALPHA: 16 + AU_BETA_INV: 2 + AU_FUSION_CONV_CHANNEL_MODE: ByDim # ByDim, ByRatio + AU_FUSION_CONV_CHANNEL_RATIO: 0.25 + AU_FUSION_CONV_CHANNEL_DIM: 64 + AU_FUSION_KERNEL_SZ: 5 + AU_FUSION_CONV_NUM: 2 + AU_REDUCE_TF_DIM: True + FS_FUSION: [False, False, True, True] + AFS_FUSION: [False, False, True, True] + AVS_FLAG: [False, False, True, True, True] + AVS_PROJ_DIM: 64 + AVS_VAR_THRESH: 0.01 + AVS_DUPLICATE_THRESH: 0.99999 + DROPPATHWAY_RATE: 0.8 # 0.8 +RESNET: + ZERO_INIT_FINAL_BN: True + 
WIDTH_PER_GROUP: 64 + NUM_GROUPS: 1 + DEPTH: 50 + TRANS_FUNC: bottleneck_transform + AUDIO_TRANS_FUNC: tf_bottleneck_transform_v1 + AUDIO_TRANS_NUM: 2 + STRIDE_1X1: False + # 18: [[2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + # 34: [[3, 3, 3], [4, 4, 4], [6, 6, 6], [3, 3, 3]] + # 50: [[3, 3, 3], [4, 4, 4], [6, 6, 6], [3, 3, 3]] + # 101: [[3, 3, 3], [4, 4, 4], [23, 23, 23], [3, 3, 3]] + # 152: [[3, 3, 3], [8, 8, 8], [36, 36, 36], [3, 3, 3]] + NUM_BLOCK_TEMP_KERNEL: [[3, 3, 3], [4, 4, 4], [6, 6, 6], [3, 3, 3]] + SPATIAL_DILATIONS: [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]] +NONLOCAL: + LOCATION: [[[], [], []], [[], [], []], [[], [], []], [[], [], []]] + GROUP: [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]] + POOL: [ + [[1, 2, 2], [1, 2, 2], [1, 2, 2]], + [[1, 2, 2], [1, 2, 2], [1, 2, 2]], + [[1, 2, 2], [1, 2, 2], [1, 2, 2]], + [[1, 2, 2], [1, 2, 2], [1, 2, 2]], + ] + INSTANTIATION: dot_product +BN: + USE_PRECISE_STATS: True + NUM_BATCHES_PRECISE: 400 + MOMENTUM: 0.1 + WEIGHT_DECAY: 0.0 +SOLVER: + BASE_LR: 0.1 # 0.1 + LR_POLICY: cosine + MAX_EPOCH: 196 + MOMENTUM: 0.9 + WEIGHT_DECAY: 1e-4 + WARMUP_EPOCHS: 34.0 # 34.0 + WARMUP_START_LR: 0.01 # 0.01 + OPTIMIZING_METHOD: sgd +MODEL: + NUM_CLASSES: 400 + MODEL_NAME: AVSlowFast + ARCH: avslowfast + LOSS_FUNC: cross_entropy + DROPOUT_RATE: 0.5 +TEST: + ENABLE: True + DATASET: kinetics + BATCH_SIZE: 32 + # CHECKPOINT_FILE_PATH: ../../data/output/checkpoints/avslowfast.pth + # CHECKPOINT_TYPE: pytorch # caffe2 or pytorch +DATA_LOADER: + NUM_WORKERS: 8 # 8 + PIN_MEMORY: True +NUM_GPUS: 8 +NUM_SHARDS: 1 +RNG_SEED: 0 +OUTPUT_DIR: ./output/AVSlowFast-R50-8x8 diff --git a/projects/avslowfast/README.md b/projects/avslowfast/README.md new file mode 100644 index 000000000..61cec4c8b --- /dev/null +++ b/projects/avslowfast/README.md @@ -0,0 +1,37 @@ +# Getting Started with PyAVSlowFast + +This section supplements the original doc in PySlowFast (attached below) and provide instructions on how to start training AVSlowFast model with this codebase. + +First, a note that `DATA.PATH_TO_DATA_DIR` points to the directory where annotation csv files reside and `DATA.PATH_PREFIX` to the root of the data directory. + +Then, issue the following training command +``` +python tools/run_net.py \ + --cfg configs/Kinetics/AVSLOWFAST_4x16_R50.yaml \ + DATA.PATH_TO_DATA_DIR path_to_your_annotation \ + DATA.PATH_PREFIX path_to_your_dataset_root \ + NUM_GPUS 8 \ + DATA_LOADER.NUM_WORKERS 8 \ + TRAIN.BATCH_SIZE 64 \ +``` + +For testing, run the following +``` +python tools/run_net.py \ + --cfg configs/Kinetics/AVSLOWFAST_4x16_R50.yaml \ + DATA.PATH_TO_DATA_DIR path_to_your_annotation \ + DATA.PATH_PREFIX path_to_your_dataset_root \ + TEST.BATCH_SIZE 32 \ + TEST.CHECKPOINT_FILE_PATH path_to_your_checkpoint \ + TRAIN.ENABLE False \ +``` + +## Citing AVSlowFast +Please cite AVSlowFast if you use it in your research, you can use the following BibTeX entry. 
+```BibTeX +@article{xiao-avslowfast2020, + author = {Xiao, Fanyi and Lee, Yong Jae and Grauman, Kristen and Malik, Jitendra and Feichtenhofer, Christoph}, + title = {{Audiovisual SlowFast Networks for Video Recognition}}, + journal = {arXiv preprint arXiv:2001.08740}, + Year = {2020}} +``` diff --git a/setup.py b/setup.py index 9194bc968..2e62a2e9d 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,6 @@ "pandas", "torchvision>=0.4.2", "sklearn", - "tensorboard", ], extras_require={"tensorboard_video_visualization": ["moviepy"]}, packages=find_packages(exclude=("configs", "tests")), diff --git a/slowfast/config/defaults.py b/slowfast/config/defaults.py index 586620b88..056c29f7d 100644 --- a/slowfast/config/defaults.py +++ b/slowfast/config/defaults.py @@ -108,6 +108,12 @@ # Transformation function. _C.RESNET.TRANS_FUNC = "bottleneck_transform" +# Transformation for audio pathway. +_C.RESNET.AUDIO_TRANS_FUNC = "tf_bottleneck_transform" + +# Number of ResStage that applies audio-specific transformation. +_C.RESNET.AUDIO_TRANS_NUM = 2 + # Number of groups. 1 for ResNet, and larger than 1 for ResNeXt). _C.RESNET.NUM_GROUPS = 1 @@ -185,7 +191,7 @@ _C.MODEL.SINGLE_PATHWAY_ARCH = ["c2d", "i3d", "slow"] # Model architectures that has multiple pathways. -_C.MODEL.MULTI_PATHWAY_ARCH = ["slowfast"] +_C.MODEL.MULTI_PATHWAY_ARCH = ["slowfast", "avslowfast"] # Dropout rate before final projection in the backbone. _C.MODEL.DROPOUT_RATE = 0.5 @@ -217,6 +223,37 @@ # pathway. _C.SLOWFAST.FUSION_KERNEL_SZ = 5 +# Audio pathway channel ratio +_C.SLOWFAST.AU_BETA_INV = 2 + +# Frame rate ratio between audio and slow pathways +_C.SLOWFAST.AU_ALPHA = 32 + +_C.SLOWFAST.AU_FUSION_CONV_CHANNEL_RATIO = 0.125 + +_C.SLOWFAST.AU_FUSION_CONV_CHANNEL_DIM = 64 + +_C.SLOWFAST.AU_FUSION_CONV_CHANNEL_MODE = 'ByRatio' # ByDim, ByRatio + +_C.SLOWFAST.AU_FUSION_KERNEL_SZ = 5 + +_C.SLOWFAST.AU_FUSION_CONV_NUM = 2 + +_C.SLOWFAST.AU_REDUCE_TF_DIM = True + +_C.SLOWFAST.FS_FUSION = [True, True, True, True] + +_C.SLOWFAST.AFS_FUSION = [True, True, True, True] + +_C.SLOWFAST.AVS_FLAG = [False, False, False, False, False] + +_C.SLOWFAST.AVS_PROJ_DIM = 64 + +_C.SLOWFAST.AVS_VAR_THRESH = 0.01 + +_C.SLOWFAST.AVS_DUPLICATE_THRESH = 0.99 + +_C.SLOWFAST.DROPPATHWAY_RATE = 0.8 # ----------------------------------------------------------------------------- # Data options @@ -243,13 +280,19 @@ # The mean value of the video raw pixels across the R G B channels. _C.DATA.MEAN = [0.45, 0.45, 0.45] -# List of input frame channel dimensions. +# List of input frame channel dimensions. _C.DATA.INPUT_CHANNEL_NUM = [3, 3] # The std value of the video raw pixels across the R G B channels. _C.DATA.STD = [0.225, 0.225, 0.225] +# Mean of logmel spectrogram +_C.DATA.LOGMEL_MEAN = 0.0 + +# Std of logmel spectrogram +_C.DATA.LOGMEL_STD = 1.0 + # The spatial augmentation jitter scales for training. _C.DATA.TRAIN_JITTER_SCALES = [256, 320] @@ -259,6 +302,29 @@ # The spatial crop size for testing. _C.DATA.TEST_CROP_SIZE = 256 +# Decode audio +_C.DATA.USE_AUDIO = False + +_C.DATA.GET_MISALIGNED_AUDIO = False + +_C.DATA.AUDIO_SAMPLE_RATE = 16000 + +_C.DATA.AUDIO_WIN_SZ = 32 + +_C.DATA.AUDIO_STEP_SZ = 16 + +_C.DATA.AUDIO_FRAME_NUM = 128 + +_C.DATA.AUDIO_MEL_NUM = 40 + +_C.DATA.AUDIO_MISALIGNED_GAP = 32 + +_C.DATA.EASY_NEG_RATIO = 0.75 + +_C.DATA.MIX_NEG_EPOCH = 96 + +_C.DATA.USE_BGR_ORDER = False + # Input videos may has different fps, convert it to the target video fps before # frame sampling. 
_C.DATA.TARGET_FPS = 30 @@ -530,6 +596,7 @@ _C.MULTIGRID.DEFAULT_T = 0 _C.MULTIGRID.DEFAULT_S = 0 + # ----------------------------------------------------------------------------- # Tensorboard Visualization Options # ----------------------------------------------------------------------------- @@ -694,7 +761,6 @@ "bend/bow (at the waist)", ] - # Add custom config with default values. custom_config.add_custom_config(_C) diff --git a/slowfast/datasets/decoder.py b/slowfast/datasets/decoder.py index c7073e8a4..2f951dd35 100644 --- a/slowfast/datasets/decoder.py +++ b/slowfast/datasets/decoder.py @@ -6,6 +6,7 @@ import random import torch import torchvision.io as io +import librosa def temporal_sampling(frames, start_idx, end_idx, num_samples): @@ -196,8 +197,39 @@ def torchvision_decode( return v_frames, video_meta["video_fps"], decode_all_video +def gen_logmel(y, orig_sr, sr, win_sz, step_sz, n_mels): + """ + Generate log-mel-spectrogram features from audio waveform + + Args: + y (ndarray): audio waveform input. + orig_sr (int): original sampling rate of audio inputs. + sr (int): targeted sampling rate. + win_sz (int): window step size in ms. + step_sz (int): step size in ms. + n_mels (int): number of frequency bins. + Returns: + logS (ndarray): log-mel-spectrogram computed from the input waveform. + """ + n_fft = int(float(sr) / 1000 * win_sz) + hop_length = int(float(sr) / 1000 * step_sz) + win_length = int(float(sr) / 1000 * win_sz) + eps = 1e-8 + y = y.reshape(-1) + y = np.asfortranarray(y) + y_resample = librosa.resample(y, orig_sr, sr, res_type='polyphase') + T = len(y_resample) / sr + S = librosa.feature.melspectrogram(y=y_resample, sr=sr, n_fft=n_fft, + win_length=win_length, hop_length=hop_length, + n_mels=n_mels, htk=True, center=False) + logS = np.log(S+eps) + return logS + + def pyav_decode( - container, sampling_rate, num_frames, clip_idx, num_clips=10, target_fps=30 + container, sampling_rate, num_frames, clip_idx, num_clips=10, target_fps=30, + decode_audio=False, extract_logmel=True, decode_all_audio=False, + au_sr=16000, au_win_sz=32, au_step_sz=16, au_n_mels=40, ): """ Convert the video from its original fps to the target_fps. If the video @@ -230,7 +262,7 @@ def pyav_decode( fps = float(container.streams.video[0].average_rate) frames_length = container.streams.video[0].frames duration = container.streams.video[0].duration - + if duration is None: # If failed to fetch the decoding information, decode the entire video. decode_all_video = True @@ -248,7 +280,9 @@ def pyav_decode( video_start_pts = int(start_idx * timebase) video_end_pts = int(end_idx * timebase) - frames = None + frames, audio_frames, au_raw_sr = None, None, None + meta = {} + # If video stream was found, fetch video frames from the video. if container.streams.video: video_frames, max_pts = pyav_decode_stream( @@ -258,11 +292,87 @@ def pyav_decode( container.streams.video[0], {"video": 0}, ) - container.close() - frames = [frame.to_rgb().to_ndarray() for frame in video_frames] frames = torch.as_tensor(np.stack(frames)) - return frames, fps, decode_all_video + + meta.update({ + 'video_start': video_start_pts / duration, + 'video_end': video_end_pts / duration, + }) + + # If audio stream was found, extract audio waveform from the video. 
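+    # The audio PTS window below is mapped proportionally from the selected video
+    # frame range (start_idx / frames_length of the audio stream duration), the
+    # stream is averaged to mono, and gen_logmel turns it into a [T, F] log-mel
+    # spectrogram. With the Kinetics configs (16 kHz, 32 ms window, 16 ms step),
+    # that is 62.5 spectrogram frames per second, so AUDIO_FRAME_NUM = 128 covers
+    # roughly 2 seconds of audio.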
+ if decode_audio and container.streams.audio: + au_raw_sr = container.streams.audio[0].codec_context.sample_rate + audio_duration = container.streams.audio[0].duration + # audio_frames_length = container.streams.audio[0].frames + # audio_timebase = audio_duration / audio_frames_length + if decode_all_video or decode_all_audio: + audio_start_pts = 0 + audio_end_pts = math.inf + else: + audio_start_pts = int(start_idx / frames_length * audio_duration) + audio_end_pts = int(end_idx / frames_length * audio_duration) + audio_frames, audio_max_pts = pyav_decode_stream( + container, + audio_start_pts, + audio_end_pts, + container.streams.audio[0], + {"audio": 0}, + ) + + audio_frames = [frame.to_ndarray() for frame in audio_frames] + if len({x.shape[1] for x in audio_frames}) == 1: + # This is a bit faster then the alternative + audio_frames = np.concatenate([x[None] for x in audio_frames], axis=0) + audio_frames = np.mean(audio_frames, axis=1) + audio_frames = audio_frames.reshape(-1) + else: + audio_frames = [np.mean(x, axis=0) for x in audio_frames] + audio_frames = np.concatenate(audio_frames, axis=0) + meta.update({ + 'audio_start': audio_start_pts / audio_duration, + 'audio_end': audio_end_pts / audio_duration, + }) + + # Extract log-mel-spectrogram features. + if extract_logmel: + audio_frames = gen_logmel(audio_frames, au_raw_sr, au_sr, + au_win_sz, au_step_sz, au_n_mels) + audio_frames = audio_frames.transpose(1, 0) # [F,T]->[T,F] + audio_frames = torch.as_tensor(audio_frames) + + meta.update({ + 'decode_all_video': decode_all_video, + 'decode_all_audio': decode_all_audio, + }) + + container.close() + + return frames, fps, audio_frames, au_raw_sr, meta + + +def sample_misaligned_start(start_idx, gap, frames): + """ + Decide the starting point of a misaligned (i.e., negative) audio sample, + which can be used for audiovisual synchronization training for self and + semi-supervised training. + + Args: + start_idx (float): starting point of the positive sample. + gap (int): the minimal gap to maintain between positive and negative samples. + frames (tensor): decoded log-mel-spectrogram features. + Returns: + misaligned_start (float): starting point of the misaligned sample. + """ + total_frames = frames.shape[0] + pre_sample_region = (0, max(start_idx - gap, 0)) + post_sample_region = (min(start_idx + gap, total_frames), total_frames) + pre_size = pre_sample_region[1] - pre_sample_region[0] + post_size = post_sample_region[1] - post_sample_region[0] + misaligned_start = random.random() * (pre_size + post_size) + if misaligned_start > pre_size: + misaligned_start = misaligned_start - pre_size + post_sample_region[0] + return misaligned_start def decode( @@ -275,6 +385,16 @@ def decode( target_fps=30, backend="pyav", max_spatial_scale=0, + # audio-related + decode_audio=False, + get_misaligned_audio=False, + extract_logmel=False, + au_sr=16000, + au_win_sz=32, + au_step_sz=16, + num_audio_frames=128, + au_n_mels=40, + au_misaligned_gap=32, ): """ Decode the video and perform temporal sampling. @@ -303,16 +423,26 @@ def decode( """ # Currently support two decoders: 1) PyAV, and 2) TorchVision. 
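+    # decode() now returns a (frames, audio_frames, misaligned_audio_frames) tuple;
+    # the audio entries stay None unless decode_audio is set, and audio decoding is
+    # only implemented for the PyAV backend.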
assert clip_idx >= -1, "Not valied clip_idx {}".format(clip_idx) + if decode_audio: assert backend == "pyav", 'Use PyAV for audio decoding' + frames, audio_frames, misaligned_audio_frames = None, None, None try: if backend == "pyav": - frames, fps, decode_all_video = pyav_decode( + frames, fps, audio_frames, au_raw_sr, meta = pyav_decode( container, sampling_rate, num_frames, clip_idx, num_clips, target_fps, + decode_audio=decode_audio, + extract_logmel=extract_logmel, + decode_all_audio=get_misaligned_audio, + au_sr=au_sr, + au_win_sz=au_win_sz, + au_step_sz=au_step_sz, + au_n_mels=au_n_mels, ) + decode_all_video = meta['decode_all_video'] elif backend == "torchvision": frames, fps, decode_all_video = torchvision_decode( container, @@ -331,11 +461,11 @@ def decode( ) except Exception as e: print("Failed to decode by {} with exception: {}".format(backend, e)) - return None + return frames, audio_frames, misaligned_audio_frames # Return None if the frames was not decoded successfully. if frames is None or frames.size(0) == 0: - return None + return frames, audio_frames, misaligned_audio_frames start_idx, end_idx = get_start_end_idx( frames.shape[0], @@ -343,6 +473,48 @@ def decode( clip_idx if decode_all_video else 0, num_clips if decode_all_video else 1, ) + if decode_audio and audio_frames is not None: + if get_misaligned_audio: + video_start = meta['video_start'] + video_end = meta['video_end'] + video_duration = video_end - video_start + audio_start_idx = (video_start + start_idx / frames.shape[0] * \ + video_duration) * audio_frames.shape[0] + audio_end_idx = (video_start + end_idx / frames.shape[0] * \ + video_duration) * audio_frames.shape[0] + else: + audio_start_idx = start_idx / frames.shape[0] * audio_frames.shape[0] + audio_end_idx = end_idx / frames.shape[0] * audio_frames.shape[0] + # audio_end_idx = audio_start_idx + num_audio_frames - 1 + # Perform temporal sampling from the decoded video. frames = temporal_sampling(frames, start_idx, end_idx, num_frames) - return frames + + # Perform temporal sampling from the decoded audio. + if decode_audio and audio_frames is not None: + if get_misaligned_audio: + audio_frame_len = audio_end_idx - audio_start_idx + misaligned_audio_start_idx = sample_misaligned_start( + audio_start_idx, + au_misaligned_gap, + audio_frames, + ) + misaligned_audio_end_idx = misaligned_audio_start_idx + audio_frame_len + misaligned_audio_frames = temporal_sampling( + audio_frames, + misaligned_audio_start_idx, + misaligned_audio_end_idx, + num_audio_frames + ) + misaligned_audio_frames = misaligned_audio_frames.reshape( + 1, + 1, + misaligned_audio_frames.size(0), + misaligned_audio_frames.size(1) + ) + audio_frames = temporal_sampling(audio_frames, audio_start_idx, + audio_end_idx, num_audio_frames) + audio_frames = audio_frames.reshape(1, 1, \ + audio_frames.size(0), audio_frames.size(1)) + + return frames, audio_frames, misaligned_audio_frames diff --git a/slowfast/datasets/kinetics.py b/slowfast/datasets/kinetics.py index 100417f7a..62ff772d4 100644 --- a/slowfast/datasets/kinetics.py +++ b/slowfast/datasets/kinetics.py @@ -72,6 +72,7 @@ def __init__(self, cfg, mode, num_retries=10): logger.info("Constructing Kinetics {}...".format(mode)) self._construct_loader() + def _construct_loader(self): """ @@ -204,7 +205,7 @@ def __getitem__(self, index): continue # Decode video. Meta info is used to perform selective decoding. 
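+            # With DATA.USE_AUDIO enabled, decode() also returns aligned (and, if
+            # DATA.GET_MISALIGNED_AUDIO, misaligned) log-mel clips for the audio pathway.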
- frames = decoder.decode( + frames, audio_frames, misaligned_audio_frames = decoder.decode( video_container, sampling_rate, self.cfg.DATA.NUM_FRAMES, @@ -214,6 +215,16 @@ def __getitem__(self, index): target_fps=self.cfg.DATA.TARGET_FPS, backend=self.cfg.DATA.DECODING_BACKEND, max_spatial_scale=max_scale, + # audio-related configs + decode_audio=self.cfg.DATA.USE_AUDIO, + get_misaligned_audio=self.cfg.DATA.GET_MISALIGNED_AUDIO, + extract_logmel=self.cfg.DATA.USE_AUDIO, + au_sr=self.cfg.DATA.AUDIO_SAMPLE_RATE, + au_win_sz=self.cfg.DATA.AUDIO_WIN_SZ, + au_step_sz=self.cfg.DATA.AUDIO_STEP_SZ, + num_audio_frames=self.cfg.DATA.AUDIO_FRAME_NUM, + au_n_mels=self.cfg.DATA.AUDIO_MEL_NUM, + au_misaligned_gap=self.cfg.DATA.AUDIO_MISALIGNED_GAP, ) # If decoding failed (wrong format, video is too short, and etc), @@ -221,6 +232,12 @@ def __getitem__(self, index): if frames is None: index = random.randint(0, len(self._path_to_videos) - 1) continue + + # If audio sampling is turned on but no audio is available, + # we discard this sample and continue. + if self.cfg.DATA.USE_AUDIO and audio_frames is None: + index = random.randint(0, len(self._path_to_videos) - 1) + continue # Perform color normalization. frames = utils.tensor_normalize( @@ -238,9 +255,30 @@ def __getitem__(self, index): random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP, inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE, ) + + # The default order is RGB, this is to convert it + # to BGR if needed. + if self.cfg.DATA.USE_BGR_ORDER: + frames = frames[[2, 1, 0], ...] + + # Optionally normalize audio inputs (log-mel-spectrogram) + if self.cfg.DATA.USE_AUDIO: + audio_frames = utils.tensor_normalize( + audio_frames, + self.cfg.DATA.LOGMEL_MEAN, + self.cfg.DATA.LOGMEL_STD + ) + if self.cfg.DATA.GET_MISALIGNED_AUDIO: + misaligned_audio_frames = utils.tensor_normalize( + misaligned_audio_frames, + self.cfg.DATA.LOGMEL_MEAN, + self.cfg.DATA.LOGMEL_STD + ) + audio_frames = torch.cat([audio_frames, \ + misaligned_audio_frames], dim=0) label = self._labels[index] - frames = utils.pack_pathway_output(self.cfg, frames) + frames = utils.pack_pathway_output(self.cfg, frames, audio_frames) return frames, label, index, {} else: raise RuntimeError( diff --git a/slowfast/datasets/loader.py b/slowfast/datasets/loader.py index 2a0ab045a..5e25753f5 100644 --- a/slowfast/datasets/loader.py +++ b/slowfast/datasets/loader.py @@ -52,6 +52,33 @@ def detection_collate(batch): return inputs, labels, video_idx, collated_extra_data +def shuffle_misaligned_audio(epoch, inputs, cfg): + """ + Shuffle the misaligned (negative) input audio clips, + such that creating positive/negative pairs that are + from different videos. + + Args: + epoch (int): the current epoch number. + inputs (list of tensors): inputs to model, + inputs[2] corresponds to audio inputs. + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + """ + + if len(inputs) > 2 and cfg.DATA.GET_MISALIGNED_AUDIO: + N = inputs[2].size(0) + # We only leave "hard negatives" after + # cfg.DATA.MIX_NEG_EPOCH epochs + SN = max(int(cfg.DATA.EASY_NEG_RATIO * N), 1) if \ + epoch >= cfg.DATA.MIX_NEG_EPOCH else N + with torch.no_grad(): + idx = torch.arange(N) + idx[:SN] = torch.arange(1, SN+1) % SN + inputs[2][:, 1, ...] = inputs[2][idx, 1, ...] + return inputs + + def construct_loader(cfg, split, is_precise_bn=False): """ Constructs the data loader for the given dataset. 
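A minimal sketch (not part of the patch) of how `shuffle_misaligned_audio` above rolls negatives across the batch, assuming the `[N, 2, 1, T, F]` audio packing produced by the kinetics.py change (channel 0 aligned, channel 1 misaligned):

```python
import torch

# Toy batch: N clips, each carrying an aligned (channel 0) and a misaligned
# (channel 1) log-mel clip from the same video; values mark the source video.
N, T, F = 4, 128, 80
audio = torch.arange(N, dtype=torch.float32).view(N, 1, 1, 1, 1).expand(N, 2, 1, T, F).clone()

# Same arithmetic as shuffle_misaligned_audio: after MIX_NEG_EPOCH only the
# first SN negatives are rolled to other videos ("easy" negatives), while the
# rest keep the time-shifted clip from their own video ("hard" negatives).
EASY_NEG_RATIO, MIX_NEG_EPOCH, epoch = 0.75, 96, 100
SN = max(int(EASY_NEG_RATIO * N), 1) if epoch >= MIX_NEG_EPOCH else N
idx = torch.arange(N)
idx[:SN] = torch.arange(1, SN + 1) % SN
audio[:, 1, ...] = audio[idx, 1, ...]

print(idx.tolist())                   # [1, 2, 0, 3]: clip 3 keeps a hard negative
print(audio[:, 1, 0, 0, 0].tolist())  # source video of each clip's negative
```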
diff --git a/slowfast/datasets/utils.py b/slowfast/datasets/utils.py index e0d4f599a..06af58fe0 100644 --- a/slowfast/datasets/utils.py +++ b/slowfast/datasets/utils.py @@ -69,16 +69,22 @@ def get_sequence(center_idx, half_len, sample_rate, num_frames): return seq -def pack_pathway_output(cfg, frames): +def pack_pathway_output(cfg, frames, audio_frames=None): """ Prepare output as a list of tensors. Each tensor corresponding to a unique pathway. Args: frames (tensor): frames of images sampled from the video. The dimension is `channel` x `num frames` x `height` x `width`. + audio_frames (tensor): audio inputs in log-mel-spectrogram + of shape C x 1 x T x F. Where C is 2 if misaligned audio + samples are extracted (1st and 2nd channels correspond to + pos and neg audio samples) otherwise C is 1. T corresponds + to cfg.DATA.AUDIO_FRAME_NUM and F is cfg.DATA.AUDIO_MEL_NUM. Returns: frame_list (list): list of tensors with the dimension of - `channel` x `num frames` x `height` x `width`. + `channel` x `num frames` x `height` x `width`. audio_frames + is untouched. """ if cfg.DATA.REVERSE_INPUT_CHANNEL: frames = frames[[2, 1, 0], :, :, :] @@ -94,7 +100,10 @@ def pack_pathway_output(cfg, frames): 0, frames.shape[1] - 1, frames.shape[1] // cfg.SLOWFAST.ALPHA ).long(), ) - frame_list = [slow_pathway, fast_pathway] + if cfg.MODEL.ARCH == "slowfast": + frame_list = [slow_pathway, fast_pathway] + elif cfg.MODEL.ARCH == "avslowfast": + frame_list = [slow_pathway, fast_pathway, audio_frames] else: raise NotImplementedError( "Model arch {} is not in {}".format( diff --git a/slowfast/models/head_helper.py b/slowfast/models/head_helper.py index 731ebc5ac..1881d6afe 100644 --- a/slowfast/models/head_helper.py +++ b/slowfast/models/head_helper.py @@ -203,6 +203,15 @@ def forward(self, inputs): for pathway in range(self.num_pathways): m = getattr(self, "pathway{}_avgpool".format(pathway)) pool_out.append(m(inputs[pathway])) + # check if audio pathway is compatible with visual ones + if len(pool_out) > 2: + a_H, a_W = pool_out[2].size(-2), pool_out[2].size(-1) + v_H, v_W = pool_out[0].size(-2), pool_out[0].size(-1) + if a_H != v_H or a_W != v_W: + assert v_H % a_H == 0 and v_W % a_W == 0, \ + 'Visual pool output should be divisible by audio pool output size' + a_N, a_C, a_T, _, _ = pool_out[2].shape + pool_out[2] = pool_out[2].expand([a_N, a_C, a_T, v_H, v_W]) x = torch.cat(pool_out, 1) # (N, C, T, H, W) -> (N, T, H, W, C). x = x.permute((0, 2, 3, 4, 1)) diff --git a/slowfast/models/resnet_helper.py b/slowfast/models/resnet_helper.py index 67c96ef9d..23d918ae6 100644 --- a/slowfast/models/resnet_helper.py +++ b/slowfast/models/resnet_helper.py @@ -4,6 +4,7 @@ """Video models.""" import torch.nn as nn +from torch import cat from slowfast.models.nonlocal_helper import Nonlocal @@ -15,6 +16,13 @@ def get_trans_func(name): trans_funcs = { "bottleneck_transform": BottleneckTransform, "basic_transform": BasicTransform, + # the following two are transform that decouples + # time and frequency in log-mel-spectrogram as described + # in AVSlowFast paper. Specifically, tf_bottleneck_transform_v1 + # is used in the paper, but tf_bottleneck_transform_v2 is + # more memory efficient. 
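+        # v1 runs 1x3x1 (time) and 1x1x3 (frequency) branches in parallel on the
+        # block input and concatenates them before the 1x1x1 conv; v2 first applies
+        # a Tx1x1 conv, then sums the two branches, keeping the width at dim_inner.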
+ "tf_bottleneck_transform_v2": TimeFreqBottleneckTransform_v2, + "tf_bottleneck_transform_v1": TimeFreqBottleneckTransform_v1, } assert ( name in trans_funcs.keys() @@ -107,6 +115,288 @@ def forward(self, x): return x +class TimeFreqBottleneckTransform_v1(nn.Module): + """ + The transformation function that decouples time + and frequency axis in log-mel-spectrogram inputs, + as described in the AVSlowFast paper. + 1x1x3, 1x3x1, 1x1x1 + """ + + def __init__( + self, + dim_in, + dim_out, + temp_kernel_size, + stride, + dim_inner, + num_groups, + stride_1x1=False, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + dilation=1, + norm_module=nn.BatchNorm3d, + ): + """ + Args: + dim_in (int): the channel dimensions of the input. + dim_out (int): the channel dimension of the output. + temp_kernel_size (int): the temporal kernel sizes of the middle + convolution in the bottleneck. + stride (int): the stride of the bottleneck. + dim_inner (int): the inner dimension of the block. + num_groups (int): number of groups for the convolution. num_groups=1 + is for standard ResNet like networks, and num_groups>1 is for + ResNeXt like networks. + stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise + apply stride to the 3x3 conv. + inplace_relu (bool): if True, calculate the relu on the original + input without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + dilation (int): size of dilation. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(TimeFreqBottleneckTransform_v1, self).__init__() + self.temp_kernel_size = temp_kernel_size + self._inplace_relu = inplace_relu + self._eps = eps + self._bn_mmt = bn_mmt + self._stride_1x1 = stride_1x1 + self._construct( + dim_in, + dim_out, + stride, + dim_inner, + num_groups, + dilation, + norm_module, + ) + + + def _construct(self, dim_in, dim_out, stride, dim_inner, num_groups, + dilation, norm_module): + (str1x1, str3x3) = (stride, 1) if self._stride_1x1 else (1, stride) + # 1x3x1, BN, ReLU. + self.t = nn.Conv3d( + dim_in, + dim_inner, + [1, 3, 1], + stride=[1, str3x3, str3x3], + padding=[0, 1, 0], + groups=num_groups, + bias=False, + ) + self.t_bn = norm_module( + dim_inner, eps=self._eps, momentum=self._bn_mmt + ) + self.t_relu = nn.ReLU(inplace=self._inplace_relu) + + # 1x1x3, BN, ReLU. + self.f = nn.Conv3d( + dim_in, + dim_inner, + [1, 1, 3], + stride=[1, str3x3, str3x3], + padding=[0, 0, 1], + groups=num_groups, + bias=False, + ) + self.f_bn = norm_module( + dim_inner, eps=self._eps, momentum=self._bn_mmt + ) + self.f_relu = nn.ReLU(inplace=self._inplace_relu) + + # 1x1x1, BN. + self.out = nn.Conv3d( + dim_inner*2, + dim_out, + kernel_size=[1, 1, 1], + stride=[1, 1, 1], + padding=[0, 0, 0], + bias=False, + ) + self.out_bn = norm_module( + dim_out, eps=self._eps, momentum=self._bn_mmt + ) + self.out_bn.transform_final_bn = True + + + def forward(self, x): + # Explicitly forward every layer. + # Branch2a_t. + x_t = self.t(x) + x_t = self.t_bn(x_t) + x_t = self.t_relu(x_t) + + # Branch2a_f. + x_f = self.f(x) + x_f = self.f_bn(x_f) + x_f = self.f_relu(x_f) + + # Merge 2a_t and 2a_f. 
+ x = cat([x_t, x_f], 1) + + # Branch2b + x = self.out(x) + x = self.out_bn(x) + return x + + +class TimeFreqBottleneckTransform_v2(nn.Module): + """ + A more memory efficient version of the transformation + function that decouples time and frequency axis in + log-mel-spectrogram inputs, as described in the + AVSlowFast paper. + Tx1x1, 1x1x3, 1x3x1, 1x1x1 + """ + + def __init__( + self, + dim_in, + dim_out, + temp_kernel_size, + stride, + dim_inner, + num_groups, + stride_1x1=False, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + dilation=1, + norm_module=nn.BatchNorm3d, + ): + """ + Args: + dim_in (int): the channel dimensions of the input. + dim_out (int): the channel dimension of the output. + temp_kernel_size (int): the temporal kernel sizes of the middle + convolution in the bottleneck. + stride (int): the stride of the bottleneck. + dim_inner (int): the inner dimension of the block. + num_groups (int): number of groups for the convolution. num_groups=1 + is for standard ResNet like networks, and num_groups>1 is for + ResNeXt like networks. + stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise + apply stride to the 3x3 conv. + inplace_relu (bool): if True, calculate the relu on the original + input without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + dilation (int): size of dilation. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(TimeFreqBottleneckTransform_v2, self).__init__() + self.temp_kernel_size = temp_kernel_size + self._inplace_relu = inplace_relu + self._eps = eps + self._bn_mmt = bn_mmt + self._stride_1x1 = stride_1x1 + self._construct( + dim_in, + dim_out, + stride, + dim_inner, + num_groups, + dilation, + norm_module, + ) + + def _construct(self, dim_in, dim_out, stride, dim_inner, num_groups, + dilation, norm_module): + (str1x1, str3x3) = (stride, 1) if self._stride_1x1 else (1, stride) + + # Tx1x1, BN, ReLU. + self.a = nn.Conv3d( + dim_in, + dim_inner, + kernel_size=[self.temp_kernel_size, 1, 1], + stride=[1, str1x1, str1x1], + padding=[int(self.temp_kernel_size // 2), 0, 0], + bias=False, + ) + self.a_bn = norm_module( + dim_inner, eps=self._eps, momentum=self._bn_mmt + ) + self.a_relu = nn.ReLU(inplace=self._inplace_relu) + + # 1x3x1, BN, ReLU. + self.b_t = nn.Conv3d( + dim_inner, + dim_inner, + [1, 3, 1], + stride=[1, str3x3, str3x3], + padding=[0, 1, 0], + groups=num_groups, + bias=False, + ) + self.b_t_bn = norm_module( + dim_inner, eps=self._eps, momentum=self._bn_mmt + ) + self.b_t_relu = nn.ReLU(inplace=self._inplace_relu) + + # 1x1x3, BN, ReLU. + self.b_f = nn.Conv3d( + dim_inner, + dim_inner, + [1, 1, 3], + stride=[1, str3x3, str3x3], + padding=[0, 0, 1], + groups=num_groups, + bias=False, + ) + self.b_f_bn = norm_module( + dim_inner, eps=self._eps, momentum=self._bn_mmt + ) + self.b_f_relu = nn.ReLU(inplace=self._inplace_relu) + + # 1x1x1, BN. + self.c = nn.Conv3d( + dim_inner, + dim_out, + kernel_size=[1, 1, 1], + stride=[1, 1, 1], + padding=[0, 0, 0], + bias=False, + ) + self.c_bn = norm_module( + dim_out, eps=self._eps, momentum=self._bn_mmt + ) + self.c_bn.transform_final_bn = True + + + def forward(self, x): + # Explicitly forward every layer. + # Branch2a. + x = self.a(x) + x = self.a_bn(x) + x = self.a_relu(x) + + # Branch2b_t. + x_t = self.b_t(x) + x_t = self.b_t_bn(x_t) + x_t = self.b_t_relu(x_t) + + # Branch2b_f. 
+ x_f = self.b_f(x) + x_f = self.b_f_bn(x_f) + x_f = self.b_f_relu(x_f) + + # Merge 2b_t and 2b_f. + x = x_t + x_f + + # Branch2c + x = self.c(x) + x = self.c_bn(x) + return x + + class BottleneckTransform(nn.Module): """ Bottleneck transformation: Tx1x1, 1x3x3, 1x1x1, where T is the size of @@ -496,10 +786,13 @@ def _construct( dilation, norm_module, ): + if not isinstance(trans_func_name, list): + trans_func_name = [trans_func_name] * self.num_pathways + for pathway in range(self.num_pathways): for i in range(self.num_blocks[pathway]): # Retrieve the transformation function. - trans_func = get_trans_func(trans_func_name) + trans_func = get_trans_func(trans_func_name[pathway]) # Construct the block. res_block = ResBlock( dim_in[pathway] if i == 0 else dim_out[pathway], diff --git a/slowfast/models/stem_helper.py b/slowfast/models/stem_helper.py index 481977b15..b576d85dd 100644 --- a/slowfast/models/stem_helper.py +++ b/slowfast/models/stem_helper.py @@ -23,6 +23,7 @@ def __init__( eps=1e-5, bn_mmt=0.1, norm_module=nn.BatchNorm3d, + stride_pool=[True, True, True], ): """ The `__init__` method of any subclass should also contain these @@ -72,11 +73,15 @@ def __init__( self.eps = eps self.bn_mmt = bn_mmt # Construct the stem layer. - self._construct_stem(dim_in, dim_out, norm_module) + self._construct_stem(dim_in, dim_out, norm_module, stride_pool) - def _construct_stem(self, dim_in, dim_out, norm_module): + def _construct_stem(self, dim_in, dim_out, norm_module, stride_pool): for pathway in range(len(dim_in)): - stem = ResNetBasicStem( + if pathway == 2: + stem_func = AudioTFBasicStem + else: + stem_func = ResNetBasicStem + stem = stem_func( dim_in[pathway], dim_out[pathway], self.kernel[pathway], @@ -86,6 +91,7 @@ def _construct_stem(self, dim_in, dim_out, norm_module): self.eps, self.bn_mmt, norm_module, + stride_pool=stride_pool[pathway], ) self.add_module("pathway{}_stem".format(pathway), stem) @@ -99,6 +105,99 @@ def forward(self, x): return x +class AudioTFBasicStem(nn.Module): + """ + Audio time-frequency stem module. + Performs separate time and frequency Convolution, + BN, and Relu following by a spatiotemporal pooling. + """ + + def __init__( + self, + dim_in, + dim_out, + kernel, + stride, + padding, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + norm_module=nn.BatchNorm3d, + stride_pool=True, + ): + """ + The `__init__` method of any subclass should also contain these arguments. + + Args: + dim_in (int): the channel dimension of the input. Normally 1 is used + for audio log-mel-spectrogram input. + dim_out (int): the output dimension of the convolution in the stem + layer. + kernel (list): the kernel size of the convolution in the stem layer. + temporal kernel size, height kernel size, width kernel size in + order. + stride (list): the stride size of the convolution in the stem layer. + temporal kernel stride, height kernel size, width kernel size in + order. + padding (int): the padding size of the convolution in the stem + layer, temporal padding size, height padding size, width + padding size in order. + inplace_relu (bool): calculate the relu on the original input + without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. 
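+            stride_pool (bool): if True, apply the 1x3x3 max-pooling (stride
+                1x2x2) after the stem; the audio pathway passes False to keep
+                the full time-frequency resolution.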
+ """ + + super(AudioTFBasicStem, self).__init__() + self.kernel = kernel + self.stride = stride + self.padding = padding + self.inplace_relu = inplace_relu + self.eps = eps + self.bn_mmt = bn_mmt + self.stride_pool = stride_pool + + # Construct the stem layer. + self._construct_stem(dim_in, dim_out, norm_module) + + def _construct_stem(self, dim_in, dim_out, norm_module): + self.conv_t = nn.Conv3d( + dim_in, + dim_out, + self.kernel[0], + stride=self.stride[0], + padding=self.padding[0], + bias=False, + ) + self.conv_f = nn.Conv3d( + dim_in, + dim_out, + self.kernel[1], + stride=self.stride[1], + padding=self.padding[1], + bias=False, + ) + self.bn = norm_module(dim_out, eps=self.eps, momentum=self.bn_mmt) + self.relu = nn.ReLU(self.inplace_relu) + if self.stride_pool: + self.pool_layer = nn.MaxPool3d( + kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1] + ) + + def forward(self, x): + x_t = self.conv_t(x) + x_f = self.conv_f(x) + x = x_t + x_f + x = self.bn(x) + x = self.relu(x) + if self.stride_pool: + x = self.pool_layer(x) + return x + + class ResNetBasicStem(nn.Module): """ ResNe(X)t 3D stem module. @@ -117,6 +216,7 @@ def __init__( eps=1e-5, bn_mmt=0.1, norm_module=nn.BatchNorm3d, + stride_pool=True, ): """ The `__init__` method of any subclass should also contain these arguments. @@ -150,6 +250,7 @@ def __init__( self.inplace_relu = inplace_relu self.eps = eps self.bn_mmt = bn_mmt + self.stride_pool = stride_pool # Construct the stem layer. self._construct_stem(dim_in, dim_out, norm_module) @@ -166,13 +267,15 @@ def _construct_stem(self, dim_in, dim_out, norm_module): num_features=dim_out, eps=self.eps, momentum=self.bn_mmt ) self.relu = nn.ReLU(self.inplace_relu) - self.pool_layer = nn.MaxPool3d( - kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1] - ) + if self.stride_pool: + self.pool_layer = nn.MaxPool3d( + kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1] + ) def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) - x = self.pool_layer(x) + if self.stride_pool: + x = self.pool_layer(x) return x diff --git a/slowfast/models/video_model_builder.py b/slowfast/models/video_model_builder.py index e54aea8a0..4701ffcd7 100644 --- a/slowfast/models/video_model_builder.py +++ b/slowfast/models/video_model_builder.py @@ -5,9 +5,12 @@ import torch import torch.nn as nn +import torch.nn.functional as F +import random import slowfast.utils.weight_init_helper as init_helper from slowfast.models.batchnorm_helper import get_norm +from slowfast.utils import misc from . import head_helper, resnet_helper, stem_helper from .build import MODEL_REGISTRY @@ -59,6 +62,13 @@ [[3], [3]], # res4 temporal kernel for slow and fast pathway. [[3], [3]], # res5 temporal kernel for slow and fast pathway. ], + "avslowfast": [ + [[1], [5], [1]], # conv1 temp kernel for slow, fast and audio pathway. + [[1], [3], [1]], # res2 temp kernel for slow, fast and audio pathway. + [[1], [3], [1]], # res3 temp kernel for slow, fast and audio pathway. + [[3], [3], [1]], # res4 temp kernel for slow, fast and audio pathway. + [[3], [3], [1]], # res5 temp kernel for slow, fast and audio pathway. + ], } _POOL1 = { @@ -68,9 +78,970 @@ "i3d_nopool": [[1, 1, 1]], "slow": [[1, 1, 1]], "slowfast": [[1, 1, 1], [1, 1, 1]], + "avslowfast": [[1, 1, 1], [1, 1, 1], [1, 1, 1]], } +class AVS(nn.Module): + """ + Compute Audio-Visual synchronization loss. 
+ """ + + def __init__(self, ref_dim, query_dim, proj_dim, num_gpus, num_shards): + """ + Args: + ref_dim (int): the channel dimension of the reference data point + (usually a visual input). + query_dim (int): the channel dimension of the query data point + (usually an audio input). + proj_dim (int): the channel dimension of the projected codes. + num_gpus (int): number of gpus used. + num_shards (int): number of shards used. + """ + + super(AVS, self).__init__() + + # initialize fc projection layers + self.proj_dim = proj_dim + self.ref_fc = nn.Linear(ref_dim, proj_dim, bias=True) + self.query_fc = nn.Linear(query_dim, proj_dim, bias=True) + self.num_gpus = num_gpus + self.num_shards = num_shards + + + def contrastive_loss(self, ref, pos, neg, audio_mask, margin): + """ + Implement the contrastive loss used in https://arxiv.org/abs/1807.00230 + """ + N = torch.sum(audio_mask) + + pos_dist = ref - pos + neg_dist = ref - neg + pos_dist = pos_dist[audio_mask] + neg_dist = neg_dist[audio_mask] + + pos_loss = torch.norm(pos_dist)**2 + neg_dist = torch.norm(neg_dist, dim=1) + neg_loss = torch.sum(torch.clamp(margin - neg_dist, min=0)**2) + loss = (pos_loss + neg_loss) / (2*N + 1e-8) + return loss + + + def forward(self, ref, pos, neg, audio_mask, norm='L2', margin=0.99): + # reduce T, H, W dims + ref = torch.mean(ref, (2, 3, 4)) + pos = torch.mean(pos, (2, 3, 4)) + neg = torch.mean(neg, (2, 3, 4)) + + # projection + ref = self.ref_fc(ref) + pos = self.query_fc(pos) + neg = self.query_fc(neg) + + # normalize + if norm == 'L2': + ref = torch.nn.functional.normalize(ref, p=2, dim=1) + pos = torch.nn.functional.normalize(pos, p=2, dim=1) + neg = torch.nn.functional.normalize(neg, p=2, dim=1) + # scale data so that ||x-y||^2 fall in [0, 1] + ref = ref * 0.5 + pos = pos * 0.5 + neg = neg * 0.5 + elif norm == 'Tanh': + scale = 1.0 / self.proj_dim + ref = torch.nn.functional.tanh(ref) * scale + pos = torch.nn.functional.tanh(pos) * scale + neg = torch.nn.functional.tanh(neg) * scale + + # contrstive loss + loss = self.contrastive_loss(ref, pos, neg, audio_mask, margin) + + # scale the loss with nGPUs and shards + # loss = loss / float(self.num_gpus * self.num_shards) + loss = loss / float(self.num_shards) + + return loss + + +class FuseAV(nn.Module): + """ + Fuses information from audio to visual pathways. + """ + + def __init__( + self, + # slow pathway + dim_in_s, + # fast pathway + dim_in_f, + fusion_conv_channel_ratio_f, + fusion_kernel_f, + alpha_f, + # audio pathway + dim_in_a, + fusion_conv_channel_mode_a, + fusion_conv_channel_dim_a, + fusion_conv_channel_ratio_a, + fusion_kernel_a, + alpha_a, + conv_num_a, + # fusion connections + use_fs_fusion, + use_afs_fusion, + # AVS + use_avs, + avs_proj_dim, + # general params + num_gpus=1, + num_shards=1, + eps=1e-5, + bn_mmt=0.1, + inplace_relu=True, + ): + """ + Perform A2TS fusion described in AVSlowFast paper. + + Args: + dim_in_s (int): channel dimension of the slow pathway. + dim_in_f (int): channel dimension of the fast pathway. + fusion_conv_channel_ratio_f (int): channel ratio for the convolution + used to fuse from Fast pathway to Slow pathway. + fusion_kernel_f (int): kernel size of the convolution used to fuse + from Fast pathway to Slow pathway. + alpha_f (int): the frame rate ratio between the Fast and Slow pathway. + dim_in_a (int): channel dimension of audio inputs. + fusion_conv_channel_mode_a (str): 'ByDim' or 'ByRatio'. Decide how to + compute intermediate feature dimension for Audiovisual fusion. 
+ fusion_conv_channel_dim_a (int): used when 'fusion_conv_channel_mode_a' + == 'ByDim', decide intermediate feature dimension for Audiovisual fusion. + fusion_conv_channel_ratio_a (float): used when 'fusion_conv_channel_mode_a' + == 'ByRatio', decide intermediate feature dimension for Audiovisual fusion. + fusion_kernel_a (int): kernel size of the convolution used to fuse + from Audio pathway to SlowFast pathways. + alpha_a (int): the frame rate ratio between the Audio and Slow pathway. + conv_num_a (int): number of convs applied on audio, before fusing into + SlowFast pathways. + use_fs_fusion (bool): whether use Fast->Slow fusion. + use_afs_fusion (bool): whether use Audio->SlowFast fusion. + use_avs (bool): whether compute audiovisual synchronization loss. + avs_proj_dim (int): channel dimension of the projection codes for audiovisual + synchronization loss. + num_gpus (int): number of gpus used. + num_shards (int): number of shards used. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + inplace_relu (bool): if True, calculate the relu on the original + input without allocating new memory. + """ + super(FuseAV, self).__init__() + self.conv_num_a = conv_num_a + self.use_fs_fusion = use_fs_fusion + self.use_afs_fusion = use_afs_fusion + + # perform F->S fusion + if use_fs_fusion: + self.conv_f2s = nn.Conv3d( + dim_in_f, + dim_in_f * fusion_conv_channel_ratio_f, + kernel_size=[fusion_kernel_f, 1, 1], + stride=[alpha_f, 1, 1], + padding=[fusion_kernel_f // 2, 0, 0], + bias=False, + ) + self.bn_f2s = nn.BatchNorm3d( + dim_in_f * fusion_conv_channel_ratio_f, eps=eps, momentum=bn_mmt + ) + self.relu_f2s = nn.ReLU(inplace_relu) + + # perform A->FS fusion + if fusion_conv_channel_mode_a == 'ByDim': + afs_fusion_interm_dim = int(fusion_conv_channel_dim_a) + elif fusion_conv_channel_mode_a == 'ByRatio': + afs_fusion_interm_dim = int(dim_in_a * fusion_conv_channel_ratio_a) + else: + raise RuntimeError + if use_afs_fusion: + cur_dim_in = dim_in_a + for idx in range(conv_num_a): + if idx == conv_num_a - 1: + cur_stride = alpha_a + cur_dim_out = int(dim_in_f * fusion_conv_channel_ratio_f \ + + dim_in_s) + else: + cur_stride = 1 + cur_dim_out = afs_fusion_interm_dim + conv_a2fs = nn.Conv3d( + cur_dim_in, + cur_dim_out, + kernel_size=[1, fusion_kernel_a, 1], + stride=[1, cur_stride, 1], + padding=[0, fusion_kernel_a // 2, 0], + bias=False, + ) + bn_a2fs = nn.BatchNorm3d( + cur_dim_out, eps=eps, momentum=bn_mmt + ) + relu_a2fs = nn.ReLU(inplace_relu) + self.add_module('conv_a2fs_%d' % idx, conv_a2fs) + self.add_module('bn_a2fs_%d' % idx, bn_a2fs) + self.add_module('relu_a2fs_%d' % idx, relu_a2fs) + cur_dim_in = cur_dim_out + + dim_in_a = int(dim_in_f * fusion_conv_channel_ratio_f + dim_in_s) + + # optionally compute audiovisual synchronization loss + if use_avs: + self.avs = AVS( + dim_in_f * fusion_conv_channel_ratio_f + dim_in_s, + dim_in_a, + avs_proj_dim, + num_gpus, + num_shards, + ) + + + def forward(self, x, get_misaligned_audio=False, mode='AFS'): + """ + Forward function for audiovisual fusion, note that it currently only + supports A->FS fusion mode (which is the default used in AVSlowFast paper) + Args: + x (list): contains slow, fast and audio features + get_misaligned_audio (bool): whether misaligned audio is carried in x + mode (str): + AFS -- fuse audio, fast and slow + AS -- fuse audio and slow + FS -- fuse fast and slow + NONE -- do not fuse at all + """ + x_s = x[0] + x_f = x[1] + x_a = 
x[2] + fuse = x_s + cache = {} + + if mode != 'NONE': + fs_proc, afs_proc = None, None + + # F->S + if self.use_fs_fusion: + fs_proc = self.conv_f2s(x_f) + fs_proc = self.bn_f2s(fs_proc) + fs_proc = self.relu_f2s(fs_proc) + fuse = torch.cat([fuse, fs_proc], 1) + cache['fs'] = fuse + + # A->FS + if self.use_afs_fusion: + # [N C 1 T F] -> [N C 1 T 1] + afs_proc = torch.mean(x_a, dim=-1, keepdim=True) + for idx in range(self.conv_num_a): + conv = getattr(self, 'conv_a2fs_%d' % idx) + bn = getattr(self, 'bn_a2fs_%d' % idx) + relu = getattr(self, 'relu_a2fs_%d' % idx) + afs_proc = conv(afs_proc) + afs_proc = bn(afs_proc) + afs_proc = relu(afs_proc) + if get_misaligned_audio: + afs_proc_pos, afs_proc_neg = torch.chunk(afs_proc, 2, dim=0) + cache['a_pos'] = afs_proc_pos + cache['a_neg'] = afs_proc_neg + else: + afs_proc_pos = afs_proc + # [N C 1 T 1] -> [N C T 1 1] + afs_proc_pos = afs_proc_pos.permute(0, 1, 3, 2, 4) + if 'A' in mode: + fuse = afs_proc_pos + fuse + else: + fuse = afs_proc_pos * 0.0 + fuse + return [fuse, x_f, x_a], cache + + +class FuseFastToSlow(nn.Module): + """ + Fuses the information from the Fast pathway to the Slow pathway. Given the + tensors from Slow pathway and Fast pathway, fuse information from Fast to + Slow, then return the fused tensors from Slow and Fast pathway in order. + """ + + def __init__( + self, + dim_in, + fusion_conv_channel_ratio, + fusion_kernel, + alpha, + eps=1e-5, + bn_mmt=0.1, + inplace_relu=True, + ): + """ + Args: + dim_in (int): the channel dimension of the input. + fusion_conv_channel_ratio (int): channel ratio for the convolution + used to fuse from Fast pathway to Slow pathway. + fusion_kernel (int): kernel size of the convolution used to fuse + from Fast pathway to Slow pathway. + alpha (int): the frame rate ratio between the Fast and Slow pathway. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + inplace_relu (bool): if True, calculate the relu on the original + input without allocating new memory. + """ + super(FuseFastToSlow, self).__init__() + self.conv_f2s = nn.Conv3d( + dim_in, + dim_in * fusion_conv_channel_ratio, + kernel_size=[fusion_kernel, 1, 1], + stride=[alpha, 1, 1], + padding=[fusion_kernel // 2, 0, 0], + bias=False, + ) + self.bn = nn.BatchNorm3d( + dim_in * fusion_conv_channel_ratio, eps=eps, momentum=bn_mmt + ) + self.relu = nn.ReLU(inplace_relu) + + + def forward(self, x): + x_s = x[0] + x_f = x[1] + fuse = self.conv_f2s(x_f) + fuse = self.bn(fuse) + fuse = self.relu(fuse) + x_s_fuse = torch.cat([x_s, fuse], 1) + return [x_s_fuse, x_f] + + +@MODEL_REGISTRY.register() +class AVSlowFast(nn.Module): + """ + Model builder for AVSlowFast network. + Fanyi Xiao, Yong Jae Lee, Kristen Grauman, Jitendra Malik, Christoph Feichtenhofer. + "Audiovisual Slowfast Networks for Video Recognition." + https://arxiv.org/abs/2001.08740 + """ + + def __init__(self, cfg): + """ + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(AVSlowFast, self).__init__() + self.norm_module = get_norm(cfg) + self.num_pathways = 3 + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + + def _construct_network(self, cfg): + """ + Builds an AVSlowFast model. The first pathway is the Slow pathway and the + second pathway is the Fast pathway, and the third one is the Audio + pathway. 
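+        Audio is fused into the visual pathways by FuseAV blocks after the stages
+        enabled in SLOWFAST.FS_FUSION / AFS_FUSION, and audiovisual synchronization
+        (AVS) losses are attached at the stages enabled in SLOWFAST.AVS_FLAG.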
+ + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + self.DROPPATHWAY_RATE = cfg.SLOWFAST.DROPPATHWAY_RATE + self.FS_FUSION = cfg.SLOWFAST.FS_FUSION + self.AFS_FUSION = cfg.SLOWFAST.AFS_FUSION + self.GET_MISALIGNED_AUDIO = cfg.DATA.GET_MISALIGNED_AUDIO + self.AVS_FLAG = cfg.SLOWFAST.AVS_FLAG + self.AVS_PROJ_DIM = cfg.SLOWFAST.AVS_PROJ_DIM + self.AVS_VAR_THRESH = cfg.SLOWFAST.AVS_VAR_THRESH + self.AVS_DUPLICATE_THRESH = cfg.SLOWFAST.AVS_DUPLICATE_THRESH + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + tf_trans_func = [cfg.RESNET.TRANS_FUNC] * 2 + \ + [cfg.RESNET.AUDIO_TRANS_FUNC] + trans_func = [tf_trans_func] * cfg.RESNET.AUDIO_TRANS_NUM + \ + [cfg.RESNET.TRANS_FUNC] * (4 - cfg.RESNET.AUDIO_TRANS_NUM) + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + out_dim_ratio = ( + cfg.SLOWFAST.BETA_INV // cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO + ) + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + if cfg.SLOWFAST.AU_REDUCE_TF_DIM: + tf_stride = 2 + else: + tf_stride = 1 + tf_dim_reduction = 1 + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[ + width_per_group, + width_per_group // cfg.SLOWFAST.BETA_INV, + width_per_group // cfg.SLOWFAST.AU_BETA_INV + ], + kernel=[ + temp_kernel[0][0] + [7, 7], + temp_kernel[0][1] + [7, 7], + [temp_kernel[0][2] + [9, 1], temp_kernel[0][2] + [1, 9]], + ], + stride=[[1, 2, 2], [1, 2, 2], [[1, 1, 1], [1, 1, 1]]], + padding=[ + [temp_kernel[0][0][0] // 2, 3, 3], + [temp_kernel[0][1][0] // 2, 3, 3], + [[temp_kernel[0][2][0] // 2, 4, 0], [temp_kernel[0][2][0] // 2, 0, 4]], + ], + stride_pool=[True, True, False], + ) + + if self.FS_FUSION[0] or self.AFS_FUSION[0]: + self.s1_fuse = FuseAV( + # Slow + width_per_group, + # Fast + width_per_group // cfg.SLOWFAST.BETA_INV, + cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.FUSION_KERNEL_SZ, + cfg.SLOWFAST.ALPHA, + # Audio + width_per_group // cfg.SLOWFAST.AU_BETA_INV, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_MODE, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_DIM, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.AU_FUSION_KERNEL_SZ, + cfg.SLOWFAST.AU_ALPHA // tf_dim_reduction, + cfg.SLOWFAST.AU_FUSION_CONV_NUM, + # Fusion connections + self.FS_FUSION[0], + self.AFS_FUSION[0], + # AVS + self.AVS_FLAG[0], + self.AVS_PROJ_DIM, + # nGPUs and shards + num_gpus=cfg.NUM_GPUS, + num_shards=cfg.NUM_SHARDS, + ) + + slow_dim = width_per_group + \ + (width_per_group // out_dim_ratio if self.FS_FUSION[0] else 0) + self.s2 = resnet_helper.ResStage( + dim_in=[ + slow_dim, + width_per_group // cfg.SLOWFAST.BETA_INV, + width_per_group // cfg.SLOWFAST.AU_BETA_INV, + ], + dim_out=[ + width_per_group * 4, + width_per_group * 4 // cfg.SLOWFAST.BETA_INV, + width_per_group * 4 // cfg.SLOWFAST.AU_BETA_INV, + ], + dim_inner=[ + dim_inner, + dim_inner // cfg.SLOWFAST.BETA_INV, + dim_inner // cfg.SLOWFAST.AU_BETA_INV + ], + temp_kernel_sizes=temp_kernel[1], + stride=[1] * 3, + num_blocks=[d2] * 3, + num_groups=[num_groups] * 3, + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + 
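+            # trans_func[0] is a per-pathway list here: the audio pathway uses
+            # RESNET.AUDIO_TRANS_FUNC in the first RESNET.AUDIO_TRANS_NUM stages,
+            # while the Slow and Fast pathways keep RESNET.TRANS_FUNC.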
trans_func_name=trans_func[0], + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + if self.FS_FUSION[1] or self.AFS_FUSION[1]: + self.s2_fuse = FuseAV( + # Slow + width_per_group * 4, + # Fast + width_per_group * 4 // cfg.SLOWFAST.BETA_INV, + cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.FUSION_KERNEL_SZ, + cfg.SLOWFAST.ALPHA, + # Audio + width_per_group * 4 // cfg.SLOWFAST.AU_BETA_INV, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_MODE, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_DIM, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.AU_FUSION_KERNEL_SZ, + cfg.SLOWFAST.AU_ALPHA // tf_dim_reduction, + cfg.SLOWFAST.AU_FUSION_CONV_NUM, + # Fusion connections + self.FS_FUSION[1], + self.AFS_FUSION[1], + # AVS + self.AVS_FLAG[1], + self.AVS_PROJ_DIM, + # nGPUs and shards + num_gpus=cfg.NUM_GPUS, + num_shards=cfg.NUM_SHARDS, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + slow_dim = width_per_group * 4 + \ + (width_per_group * 4 // out_dim_ratio if self.FS_FUSION[1] else 0) + self.s3 = resnet_helper.ResStage( + dim_in=[ + slow_dim, + width_per_group * 4 // cfg.SLOWFAST.BETA_INV, + width_per_group * 4 // cfg.SLOWFAST.AU_BETA_INV, + ], + dim_out=[ + width_per_group * 8, + width_per_group * 8 // cfg.SLOWFAST.BETA_INV, + width_per_group * 8 // cfg.SLOWFAST.AU_BETA_INV, + ], + dim_inner=[ + dim_inner * 2, + dim_inner * 2 // cfg.SLOWFAST.BETA_INV, + dim_inner * 2 // cfg.SLOWFAST.AU_BETA_INV + ], + temp_kernel_sizes=temp_kernel[2], + stride=[2, 2, tf_stride], + num_blocks=[d3] * 3, + num_groups=[num_groups] * 3, + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=trans_func[1], + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + tf_dim_reduction *= tf_stride + + if self.FS_FUSION[2] or self.AFS_FUSION[2]: + self.s3_fuse = FuseAV( + # Slow + width_per_group * 8, + # Fast + width_per_group * 8 // cfg.SLOWFAST.BETA_INV, + cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.FUSION_KERNEL_SZ, + cfg.SLOWFAST.ALPHA, + # Audio + width_per_group * 8 // cfg.SLOWFAST.AU_BETA_INV, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_MODE, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_DIM, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.AU_FUSION_KERNEL_SZ, + cfg.SLOWFAST.AU_ALPHA // tf_dim_reduction, + cfg.SLOWFAST.AU_FUSION_CONV_NUM, + # Fusion connections + self.FS_FUSION[2], + self.AFS_FUSION[2], + # AVS + self.AVS_FLAG[2], + self.AVS_PROJ_DIM, + # nGPUs and shards + num_gpus=cfg.NUM_GPUS, + num_shards=cfg.NUM_SHARDS, + ) + + slow_dim = width_per_group * 8 + \ + (width_per_group * 8 // out_dim_ratio if self.FS_FUSION[2] else 0) + self.s4 = resnet_helper.ResStage( + dim_in=[ + slow_dim, + width_per_group * 8 // cfg.SLOWFAST.BETA_INV, + width_per_group * 8 // cfg.SLOWFAST.AU_BETA_INV, + ], + dim_out=[ + width_per_group * 16, + width_per_group * 16 // cfg.SLOWFAST.BETA_INV, + width_per_group * 16 // cfg.SLOWFAST.AU_BETA_INV, + ], + dim_inner=[ + dim_inner * 4, + dim_inner * 4 // cfg.SLOWFAST.BETA_INV, + dim_inner * 4 // cfg.SLOWFAST.AU_BETA_INV + ], + temp_kernel_sizes=temp_kernel[3], + stride=[2, 2, tf_stride], + num_blocks=[d4] * 3, + num_groups=[num_groups] * 3, + 
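+            # NUM_BLOCK_TEMP_KERNEL caps how many blocks in this stage apply the temporal
+            # kernel; any remaining blocks fall back to a temporal kernel size of 1
+            # (see slowfast/config/defaults.py).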
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=trans_func[2], + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + tf_dim_reduction *= tf_stride + + if self.FS_FUSION[3] or self.AFS_FUSION[3]: + self.s4_fuse = FuseAV( + # Slow + width_per_group * 16, + # Fast + width_per_group * 16 // cfg.SLOWFAST.BETA_INV, + cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.FUSION_KERNEL_SZ, + cfg.SLOWFAST.ALPHA, + # Audio + width_per_group * 16 // cfg.SLOWFAST.AU_BETA_INV, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_MODE, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_DIM, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.AU_FUSION_KERNEL_SZ, + cfg.SLOWFAST.AU_ALPHA // tf_dim_reduction, + cfg.SLOWFAST.AU_FUSION_CONV_NUM, + # Fusion connections + self.FS_FUSION[3], + self.AFS_FUSION[3], + # AVS + self.AVS_FLAG[3], + self.AVS_PROJ_DIM, + # nGPUs and shards + num_gpus=cfg.NUM_GPUS, + num_shards=cfg.NUM_SHARDS, + ) + + slow_dim = width_per_group * 16 + \ + (width_per_group * 16 // out_dim_ratio if self.FS_FUSION[3] else 0) + self.s5 = resnet_helper.ResStage( + dim_in=[ + slow_dim, + width_per_group * 16 // cfg.SLOWFAST.BETA_INV, + width_per_group * 16 // cfg.SLOWFAST.AU_BETA_INV, + ], + dim_out=[ + width_per_group * 32, + width_per_group * 32 // cfg.SLOWFAST.BETA_INV, + width_per_group * 32 // cfg.SLOWFAST.AU_BETA_INV, + ], + dim_inner=[ + dim_inner * 8, + dim_inner * 8 // cfg.SLOWFAST.BETA_INV, + dim_inner * 8 // cfg.SLOWFAST.AU_BETA_INV, + ], + temp_kernel_sizes=temp_kernel[4], + stride=[2, 2, tf_stride], + num_blocks=[d5] * 3, + num_groups=[num_groups] * 3, + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + nonlocal_group=cfg.NONLOCAL.GROUP[3], + nonlocal_pool=cfg.NONLOCAL.POOL[3], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=trans_func[3], + dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + norm_module=self.norm_module, + ) + tf_dim_reduction *= tf_stride + + # setup AVS for pool5 output + if self.AVS_FLAG[4]: + # this FuseAV object is used for compute AVS loss only + self.s5_fuse = FuseAV( + # Slow + width_per_group * 32, + # Fast + width_per_group * 32 // cfg.SLOWFAST.BETA_INV, + cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.FUSION_KERNEL_SZ, + cfg.SLOWFAST.ALPHA, + # Audio + width_per_group * 32 // cfg.SLOWFAST.AU_BETA_INV, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_MODE, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_DIM, + cfg.SLOWFAST.AU_FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.AU_FUSION_KERNEL_SZ, + cfg.SLOWFAST.AU_ALPHA // tf_dim_reduction, + cfg.SLOWFAST.AU_FUSION_CONV_NUM, + # Fusion connections + True, + True, + # AVS + self.AVS_FLAG[4], + self.AVS_PROJ_DIM, + # nGPUs and shards + num_gpus=cfg.NUM_GPUS, + num_shards=cfg.NUM_SHARDS, + ) + + self.head = head_helper.ResNetBasicHead( + dim_in=[ + width_per_group * 32, + width_per_group * 32 // cfg.SLOWFAST.BETA_INV, + width_per_group * 32 // cfg.SLOWFAST.AU_BETA_INV, + ], + num_classes=cfg.MODEL.NUM_CLASSES, + pool_size=[ + [ + cfg.DATA.NUM_FRAMES + // cfg.SLOWFAST.ALPHA + // pool_size[0][0], + cfg.DATA.CROP_SIZE // 32 // pool_size[0][1], + cfg.DATA.CROP_SIZE // 32 // pool_size[0][2], + ], + [ + cfg.DATA.NUM_FRAMES // pool_size[1][0], + cfg.DATA.CROP_SIZE // 32 // pool_size[1][1], + cfg.DATA.CROP_SIZE // 32 // pool_size[1][2], + ], + [ + 1, + cfg.DATA.AUDIO_FRAME_NUM 
// tf_dim_reduction, + cfg.DATA.AUDIO_MEL_NUM // tf_dim_reduction, + ], + ], + dropout_rate=cfg.MODEL.DROPOUT_RATE, + ) + + + def freeze_bn(self, freeze_bn_affine): + """ + Freeze the BN parameters + """ + print("Freezing Mean/Var of BatchNorm.") + if freeze_bn_affine: + print("Freezing Weight/Bias of BatchNorm.") + for name, m in self.named_modules(): + if isinstance(m, nn.BatchNorm1d) or \ + isinstance(m, nn.BatchNorm2d) or \ + isinstance(m, nn.BatchNorm3d): + # if 'pathway2' in name or 'a2fs' in name: + # continue + m.eval() + if freeze_bn_affine: + m.weight.requires_grad = False + m.bias.requires_grad = False + + + def gen_fusion_avs_pattern(self): + """ + Generate a fusion pattern and an AVS-loss compute pattern. + The fusion pattern is determined both by the pre-defined fusion + connections between Slow/Fast/Audio and by an on-the-fly flag that decides + whether to drop the audio pathway for this iteration. + The AVS pattern is then derived from the fusion pattern. For example, + suppose the AFS fusion pattern is [False, False, True, True], meaning audio + is fused into the visual pathways after res3 and res4, and the AVS loss is + only enabled at res4 or later; then we will not compute the AVS loss + anywhere. Once audio has been fused into the visual features at res3, every + later visual feature has already "seen" the audio, so deciding whether audio + and visual are in sync becomes trivial. + """ + is_drop = self.training and random.random() < self.DROPPATHWAY_RATE + fs_fusion = self.FS_FUSION + afs_fusion = self.AFS_FUSION + runtime_afs_fusion = [] + fusion_pattern, avs_pattern = [], [] + + for idx in range(4): + # If a junction has both audiovisual fusion and slowfast fusion, + # we call it 'AFS'. If it only has slowfast fusion, we call it 'FS'.
+ # If it only has audio and slow fusion, we call it 'AS' + cur_fs_fusion = fs_fusion[idx] + cur_afs_fusion = afs_fusion[idx] and not is_drop + if cur_fs_fusion and cur_afs_fusion: + fusion_pattern.append('AFS') + elif cur_fs_fusion and not cur_afs_fusion: + fusion_pattern.append('FS') + elif not cur_fs_fusion and cur_afs_fusion: + fusion_pattern.append('AS') + else: + fusion_pattern.append('NONE') + runtime_afs_fusion.append(cur_afs_fusion) + + # compute the earliest audiovisual fusion, so that we don't do AVS + # for any stage later than that + earliest_afs = 4 + for idx in range(3, -1, -1): + if runtime_afs_fusion[idx]: + earliest_afs = idx + + for idx in range(5): + if idx <= earliest_afs and self.AVS_FLAG[idx]: + avs_pattern.append(True) + else: + avs_pattern.append(False) + + return fusion_pattern, avs_pattern + + + def move_C_to_N(self, x): + """ + Assume x is with shape [N C T H W], this function merges C into N which + results in shape [N*C 1 T H W] + """ + N, C, T, H, W = x[2].size() + x[2] = x[2].reshape(N*C, 1, T, H, W) + return x + + + def filter_duplicates(self, x): + """ + Compute a valid mask for near-duplicates and near-zero audios + """ + mask = None + if self.GET_MISALIGNED_AUDIO: + with torch.no_grad(): + audio = x[2] + N, C, T, H, W = audio.size() + audio = audio.reshape(N//2, C*2, -1) + # remove pairs that are near-zero + audio_std = torch.std(audio, dim=2) ** 2 + mask = audio_std > self.AVS_VAR_THRESH + mask = mask[:, 0] * mask[:, 1] + # remove pairs that are near-duplicate + audio = F.normalize(audio, dim=2) + similarity = audio[:, 0, :] * audio[:, 1, :] + similarity = torch.sum(similarity, dim=1) + similarity = similarity < self.AVS_DUPLICATE_THRESH + # integrate var and dup mask + mask = mask * similarity + # mask = mask.float() + return mask + + + def get_pos_audio(self, x): + """ + Slice the data and only take the first half + along the first dim for positive data + """ + x[2], _ = torch.chunk(x[2], 2, dim=0) + return x + + + def avs_forward(self, features, audio_mask): + """ + Forward for AVS loss + """ + loss_list = {} + avs_pattern = features['avs_pattern'] + for idx in range(5): + if self.AVS_FLAG[idx]: + a_pos = features['s{}_a_pos'.format(idx + 1)] + a_neg = features['s{}_a_neg'.format(idx + 1)] + fs = features['s{}_fs'.format(idx + 1)] + fuse = getattr(self, 's{}_fuse'.format(idx + 1)) + avs = getattr(fuse, 'avs') + loss = avs(fs, a_pos, a_neg, audio_mask) + if not avs_pattern[idx]: + loss = loss * 0.0 + loss_list['s{}_avs'.format(idx + 1)] = loss + return loss_list + + + def forward(self, x): + # generate fusion pattern + fusion_pattern, avs_pattern = self.gen_fusion_avs_pattern() + + # tackle misaligned logmel + if self.GET_MISALIGNED_AUDIO: + x = self.move_C_to_N(x) + + # generate mask for audio + audio_mask = self.filter_duplicates(x) + + # initialize feature list + features = {'avs_pattern': avs_pattern} + + # execute forward + x = self.s1(x) + if self.FS_FUSION[0] or self.AFS_FUSION[0]: + x, interm_feat = self.s1_fuse( + x, + get_misaligned_audio=self.GET_MISALIGNED_AUDIO, + mode=fusion_pattern[0], + ) + features = misc.update_dict_with_prefix( + features, + interm_feat, + prefix='s1_' + ) + x = self.s2(x) + if self.FS_FUSION[1] or self.AFS_FUSION[1]: + x, interm_feat = self.s2_fuse( + x, + get_misaligned_audio=self.GET_MISALIGNED_AUDIO, + mode=fusion_pattern[1], + ) + features = misc.update_dict_with_prefix( + features, + interm_feat, + prefix='s2_' + ) + for pathway in range(self.num_pathways): + pool = getattr(self, 
"pathway{}_pool".format(pathway)) + x[pathway] = pool(x[pathway]) + x = self.s3(x) + if self.FS_FUSION[2] or self.AFS_FUSION[2]: + x, interm_feat = self.s3_fuse( + x, + get_misaligned_audio=self.GET_MISALIGNED_AUDIO, + mode=fusion_pattern[2], + ) + features = misc.update_dict_with_prefix( + features, + interm_feat, + prefix='s3_' + ) + x = self.s4(x) + if self.FS_FUSION[3] or self.AFS_FUSION[3]: + x, interm_feat = self.s4_fuse( + x, + get_misaligned_audio=self.GET_MISALIGNED_AUDIO, + mode=fusion_pattern[3], + ) + features = misc.update_dict_with_prefix( + features, + interm_feat, + prefix='s4_' + ) + x = self.s5(x) + if self.AVS_FLAG[4]: + _, interm_feat = self.s5_fuse( + x, + get_misaligned_audio=self.GET_MISALIGNED_AUDIO, + mode='FS', + ) + features = misc.update_dict_with_prefix( + features, + interm_feat, + prefix='s5_' + ) + + # drop the negative samples in audio + if self.GET_MISALIGNED_AUDIO: + x = self.get_pos_audio(x) + + x = self.head(x) + + if self.training and self.GET_MISALIGNED_AUDIO: + # compute loss if in training + loss = self.avs_forward(features, audio_mask) + return x, loss + else: + return x + + class FuseFastToSlow(nn.Module): """ Fuses the information from the Fast pathway to the Slow pathway. Given the diff --git a/slowfast/utils/meters.py b/slowfast/utils/meters.py index f5a7d1c03..c7b78f3bd 100644 --- a/slowfast/utils/meters.py +++ b/slowfast/utils/meters.py @@ -9,7 +9,6 @@ from collections import defaultdict, deque import torch from fvcore.common.timer import Timer -from sklearn.metrics import average_precision_score import slowfast.datasets.ava_helper as ava_helper import slowfast.utils.logging as logging @@ -22,6 +21,8 @@ read_labelmap, ) +from sklearn.metrics import average_precision_score + logger = logging.get_logger(__name__) diff --git a/slowfast/utils/misc.py b/slowfast/utils/misc.py index e31c57426..a7221de36 100644 --- a/slowfast/utils/misc.py +++ b/slowfast/utils/misc.py @@ -95,7 +95,16 @@ def _get_model_analysis_input(cfg, use_train_input): cfg.DATA.TEST_CROP_SIZE, cfg.DATA.TEST_CROP_SIZE, ) - model_inputs = pack_pathway_output(cfg, input_tensors) + input_audio = None + if cfg.DATA.USE_AUDIO: + chn = 2 if cfg.DATA.GET_MISALIGNED_AUDIO else 1 + input_audio = torch.rand( + chn, + 1, + cfg.DATA.AUDIO_FRAME_NUM, + cfg.DATA.AUDIO_MEL_NUM, + ) + model_inputs = pack_pathway_output(cfg, input_tensors, input_audio) for i in range(len(model_inputs)): model_inputs[i] = model_inputs[i].unsqueeze(0) if cfg.NUM_GPUS: @@ -257,6 +266,20 @@ def aggregate_sub_bn_stats(module): return count +def update_dict_with_prefix(dict_dst, dict_src, prefix=''): + """ + Update a dictionary with the contents of another dictionary, with its keys + augmented with a prefix + Args: + dict_dst: destination dictionary + dict_src: source dictionary + prefix: the prefix to be inserted + """ + for k, v in dict_src.items(): + dict_dst[prefix + k] = v + return dict_dst + + def launch_job(cfg, init_method, func, daemon=False): """ Run 'func' on one or more GPUs, specified in cfg diff --git a/tools/run_net.py b/tools/run_net.py index 20d275466..b703eafeb 100644 --- a/tools/run_net.py +++ b/tools/run_net.py @@ -2,6 +2,8 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
"""Wrapper to train and test a video classification model.""" +import torch + from slowfast.utils.misc import launch_job from slowfast.utils.parser import load_config, parse_args @@ -35,5 +37,6 @@ def main(): demo(cfg) -if __name__ == "__main__": +if __name__ == "__main__": + # torch.multiprocessing.set_start_method("forkserver") main() diff --git a/tools/train_net.py b/tools/train_net.py index 4224acb85..b8a095c66 100644 --- a/tools/train_net.py +++ b/tools/train_net.py @@ -2,6 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. """Train a video classification model.""" + import numpy as np import pprint import torch @@ -23,9 +24,7 @@ logger = logging.get_logger(__name__) -def train_epoch( - train_loader, model, optimizer, train_meter, cur_epoch, cfg, writer=None -): +def train_epoch(train_loader, model, optimizer, train_meter, cur_epoch, cfg): """ Perform the video training for one epoch. Args: @@ -37,8 +36,6 @@ def train_epoch( cur_epoch (int): current epoch of training. cfg (CfgNode): configs. Details can be found in slowfast/config/defaults.py - writer (TensorboardWriter, optional): TensorboardWriter object - to writer Tensorboard log. """ # Enable train mode. model.train() @@ -60,10 +57,16 @@ def train_epoch( val[i] = val[i].cuda(non_blocking=True) else: meta[key] = val.cuda(non_blocking=True) + + # Optionally shuffle misaligned audio data + inputs = loader.shuffle_misaligned_audio(cur_epoch, inputs, cfg) # Update the learning rate. lr = optim.get_epoch_lr(cur_epoch + float(cur_iter) / data_size, cfg) optim.set_lr(optimizer, lr) + + # Auxilliary losses. + aux_loss = {} if cfg.DETECTION.ENABLE: # Compute the predictions. @@ -72,11 +75,21 @@ def train_epoch( else: # Perform the forward pass. preds = model(inputs) + + # Fetch actual preds. + if type(preds) == type(()): + preds, avs_loss = preds + aux_loss.update(avs_loss) + # Explicitly declare reduction to mean. loss_fun = losses.get_loss_func(cfg.MODEL.LOSS_FUNC)(reduction="mean") # Compute the loss. loss = loss_fun(preds, labels) + + # Accumulate auxilliary losses. + if len(aux_loss) > 0: + loss = loss + sum(aux_loss.values()) # check Nan Loss. misc.check_nan_losses(loss) @@ -95,13 +108,6 @@ def train_epoch( train_meter.iter_toc() # Update and log stats. train_meter.update_stats(None, None, None, loss, lr) - # write to tensorboard format if available. - if writer is not None: - writer.add_scalars( - {"Train/loss": loss, "Train/lr": lr}, - global_step=data_size * cur_epoch + cur_iter, - ) - else: top1_err, top5_err = None, None if cfg.DATA.MULTI_LABEL: @@ -141,17 +147,6 @@ def train_epoch( cfg.NUM_GPUS, 1 ), # If running on CPU (cfg.NUM_GPUS == 1), use 1 to represent 1 CPU. ) - # write to tensorboard format if available. - if writer is not None: - writer.add_scalars( - { - "Train/loss": loss, - "Train/lr": lr, - "Train/Top1_err": top1_err, - "Train/Top5_err": top5_err, - }, - global_step=data_size * cur_epoch + cur_iter, - ) train_meter.log_iter_stats(cur_epoch, cur_iter) train_meter.iter_tic() @@ -162,7 +157,7 @@ def train_epoch( @torch.no_grad() -def eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer=None): +def eval_epoch(val_loader, model, val_meter, cur_epoch, cfg): """ Evaluate the model on the val set. Args: @@ -172,8 +167,6 @@ def eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer=None): cur_epoch (int): number of the current epoch of training. cfg (CfgNode): configs. 
Details can be found in slowfast/config/defaults.py - writer (TensorboardWriter, optional): TensorboardWriter object - to writer Tensorboard log. """ # Evaluation mode enabled. The running stats would not be updated. @@ -222,6 +215,7 @@ def eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer=None): if cfg.DATA.MULTI_LABEL: if cfg.NUM_GPUS > 1: preds, labels = du.all_gather([preds, labels]) + val_meter.update_predictions(preds, labels) else: # Compute the errors. num_topks_correct = metrics.topks_correct(preds, labels, (1, 5)) @@ -246,20 +240,13 @@ def eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer=None): cfg.NUM_GPUS, 1 ), # If running on CPU (cfg.NUM_GPUS == 1), use 1 to represent 1 CPU. ) - # write to tensorboard format if available. - if writer is not None: - writer.add_scalars( - {"Val/Top1_err": top1_err, "Val/Top5_err": top5_err}, - global_step=len(val_loader) * cur_epoch + cur_iter, - ) - - val_meter.update_predictions(preds, labels) val_meter.log_iter_stats(cur_epoch, cur_iter) val_meter.iter_tic() # Log epoch stats. val_meter.log_epoch_stats(cur_epoch) + # write to tensorboard format if available. if writer is not None: if cfg.DETECTION.ENABLE: @@ -404,14 +391,6 @@ def train(cfg): train_meter = TrainMeter(len(train_loader), cfg) val_meter = ValMeter(len(val_loader), cfg) - # set up writer for logging to Tensorboard format. - if cfg.TENSORBOARD.ENABLE and du.is_master_proc( - cfg.NUM_GPUS * cfg.NUM_SHARDS - ): - writer = tb.TensorboardWriter(cfg) - else: - writer = None - # Perform the training loop. logger.info("Start epoch: {}".format(start_epoch + 1)) @@ -443,9 +422,7 @@ def train(cfg): # Shuffle the dataset. loader.shuffle_dataset(train_loader, cur_epoch) # Train for one epoch. - train_epoch( - train_loader, model, optimizer, train_meter, cur_epoch, cfg, writer - ) + train_epoch(train_loader, model, optimizer, train_meter, cur_epoch, cfg) # Compute precise BN stats. if cfg.BN.USE_PRECISE_STATS and len(get_bn_modules(model)) > 0: @@ -466,7 +443,4 @@ def train(cfg): if misc.is_eval_epoch( cfg, cur_epoch, None if multigrid is None else multigrid.schedule ): - eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer) - - if writer is not None: - writer.close() + eval_epoch(val_loader, model, val_meter, cur_epoch, cfg)
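To tie the train_net.py changes together, the following sketch shows how the auxiliary audiovisual-synchronization (AVS) losses returned by AVSlowFast are folded into the classification loss during training; it mirrors the hunks above rather than quoting them verbatim, and the tensors are dummies:

    # In training mode with cfg.DATA.GET_MISALIGNED_AUDIO enabled, the model
    # returns a (predictions, avs_loss_dict) tuple instead of a bare tensor.
    outputs = model(inputs)
    if isinstance(outputs, tuple):
        preds, avs_losses = outputs      # e.g. {'s3_avs': ..., 's4_avs': ..., 's5_avs': ...}
    else:
        preds, avs_losses = outputs, {}
    loss_fun = losses.get_loss_func(cfg.MODEL.LOSS_FUNC)(reduction="mean")
    loss = loss_fun(preds, labels) + sum(avs_losses.values())
    # Stages whose runtime AVS pattern is disabled still appear in the dict but
    # contribute a zeroed term, so they never change the gradient.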