Commit

Merge branch 'no_jit_assert' into v0.0.3
andrewilyas authored Jan 24, 2022
2 parents 405c096 + d817cc2 commit 131d562
Showing 8 changed files with 174 additions and 50 deletions.
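
The change repeated across the transform files below is the same in each case: instead of asserting that the incoming pipeline state already has `jit_mode` set, every `declare_state_and_memory` now returns an updated copy of the state with `jit_mode=True` via `dataclasses.replace`. A minimal sketch of that pattern follows (the `State` dataclass here is an illustrative stand-in that models only the fields visible in this diff, not the real `ffcv.pipeline.state.State`):

from dataclasses import dataclass, replace
from typing import Tuple

import numpy as np


@dataclass(frozen=True)
class State:
    # Stand-in for ffcv.pipeline.state.State; only the fields referenced in this diff.
    jit_mode: bool
    shape: Tuple[int, ...]
    dtype: np.dtype


previous_state = State(jit_mode=False, shape=(32, 32, 3), dtype=np.dtype('uint8'))

# Old behavior: hard failure unless an upstream operation had already enabled JIT mode.
# assert previous_state.jit_mode

# New behavior: return an updated copy of the state instead of asserting;
# replace() builds a new instance and leaves previous_state untouched.
new_state = replace(previous_state, jit_mode=True)
assert new_state.jit_mode and not previous_state.jit_mode

Because the operations now set the flag themselves rather than requiring it, the reworked tests at the bottom of this commit run every augmentation twice, toggling compilation through `Compiler.set_enabled(compile)` with `for comp in [True, False]`.
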
2 changes: 2 additions & 0 deletions ffcv/transforms/__init__.py
@@ -7,6 +7,8 @@
 from .replace_label import ReplaceLabel
 from .normalize import NormalizeImage
 from .translate import RandomTranslate
+from .mixup import ImageMixup, LabelMixup, MixupToOneHot
+from .module import ModuleWrapper
 
 __all__ = ['ToTensor', 'ToDevice',
            'ToTorchImage', 'NormalizeImage',
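
With the two imports added above, `ImageMixup`, `LabelMixup`, `MixupToOneHot`, and `ModuleWrapper` become importable directly from `ffcv.transforms`. A hypothetical pipeline using two of them, mirroring the updated tests at the bottom of this commit ('train.beton' is a placeholder dataset path, and the `ImageMixup` arguments are passed positionally exactly as `test_mixup` does):

from torchvision import transforms as tvt

from ffcv.loader import Loader
from ffcv.fields.basics import IntDecoder
from ffcv.fields.rgb_image import SimpleRGBImageDecoder
from ffcv.transforms import ImageMixup, ModuleWrapper, Squeeze, ToTensor, ToTorchImage

# Decode, apply mixup to the raw images, convert to a torch image batch,
# then run an arbitrary torchvision module through ModuleWrapper.
image_pipeline = [
    SimpleRGBImageDecoder(),
    ImageMixup(0.5, False),           # same positional arguments as test_mixup
    ToTensor(),
    ToTorchImage(),
    ModuleWrapper(tvt.Grayscale(3)),  # as in test_module_wrapper
]

# 'train.beton' is a placeholder file, not produced by this commit.
loader = Loader('train.beton', batch_size=7, num_workers=2,
                pipelines={'image': image_pipeline,
                           'label': [IntDecoder(), ToTensor(), Squeeze()]},
                drop_last=False)
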
4 changes: 2 additions & 2 deletions ffcv/transforms/cutout.py
@@ -3,6 +3,7 @@
 """
 import numpy as np
 from typing import Callable, Optional, Tuple
+from dataclasses import replace
 
 from ffcv.pipeline.compiler import Compiler
 from ..pipeline.allocation_query import AllocationQuery
@@ -48,5 +49,4 @@ def cutout_square(images, *_):
         return cutout_square
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        assert previous_state.jit_mode
-        return previous_state, None
+        return replace(previous_state, jit_mode=True), None
6 changes: 3 additions & 3 deletions ffcv/transforms/flip.py
@@ -1,7 +1,7 @@
 """
 Random horizontal flip
 """
-from numpy import dtype
+from dataclasses import replace
 from numpy.random import rand
 from typing import Callable, Optional, Tuple
 from ..pipeline.allocation_query import AllocationQuery
@@ -42,5 +42,5 @@ def flip(images, dst):
         return flip
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        assert previous_state.jit_mode
-        return (previous_state, AllocationQuery(previous_state.shape, previous_state.dtype))
+        return (replace(previous_state, jit_mode=True),
+                AllocationQuery(previous_state.shape, previous_state.dtype))
5 changes: 1 addition & 4 deletions ffcv/transforms/mixup.py
@@ -53,8 +53,6 @@ def mixer(images, dst, indices):
         return mixer
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        # assert previous_state.jit_mode
-        # We do everything in place
         return (previous_state, AllocationQuery(shape=previous_state.shape,
                                                 dtype=previous_state.dtype))
 
@@ -92,8 +90,6 @@ def mixer(labels, temp_array, indices):
         return mixer
 
    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        # assert previous_state.jit_mode
-        # We do everything in place
         return (replace(previous_state, shape=(3,), dtype=np.float32),
                 AllocationQuery((3,), dtype=np.float32))
 
@@ -115,6 +111,7 @@ def one_hotter(mixedup_labels, dst):
         return one_hotter
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
+        # Should already be converted to tensor
         assert not previous_state.jit_mode
         return (replace(previous_state, shape=(self.num_classes,)), \
             AllocationQuery((self.num_classes,), dtype=previous_state.dtype, device=previous_state.device))
9 changes: 3 additions & 6 deletions ffcv/transforms/poisoning.py
@@ -1,13 +1,10 @@
 """
 Poison images by adding a mask
 """
-from collections.abc import Sequence
-from typing import Tuple
+from dataclasses import replace
 
 import numpy as np
-from numpy import dtype
-from numpy.core.numeric import indices
 from numpy.random import rand
 from typing import Callable, Optional, Tuple
 from ..pipeline.allocation_query import AllocationQuery
 from ..pipeline.operation import Operation
@@ -67,6 +64,6 @@ def poison(images, temp_array, indices):
         return poison
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        assert previous_state.jit_mode
         # We do everything in place
-        return (previous_state, AllocationQuery(shape=previous_state.shape, dtype=np.float32))
+        return (replace(previous_state, jit_mode=True), \
+            AllocationQuery(shape=previous_state.shape, dtype=np.dtype('float32')))
9 changes: 2 additions & 7 deletions ffcv/transforms/replace_label.py
@@ -1,13 +1,10 @@
 """
 Replace label
 """
-from collections.abc import Sequence
-from typing import Tuple
 
 import numpy as np
-from numpy import dtype
-from numpy.core.numeric import indices
 from numpy.random import rand
+from dataclasses import replace
 from typing import Callable, Optional, Tuple
 from ..pipeline.allocation_query import AllocationQuery
 from ..pipeline.operation import Operation
@@ -50,6 +47,4 @@ def replace_label(labels, temp_array, indices):
         return replace_label
 
    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        assert previous_state.jit_mode
-        # We do everything in place
-        return (previous_state, None)
+        return (replace(previous_state, jit_mode=True), None)
9 changes: 5 additions & 4 deletions ffcv/transforms/translate.py
@@ -2,9 +2,9 @@
 Random translate
 """
 import numpy as np
-from numpy import dtype
 from numpy.random import randint
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple
+from dataclasses import replace
 from ..pipeline.allocation_query import AllocationQuery
 from ..pipeline.operation import Operation
 from ..pipeline.state import State
@@ -51,5 +51,6 @@ def translate(images, dst):
 
     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
         h, w, c = previous_state.shape
-        assert previous_state.jit_mode
-        return (previous_state, AllocationQuery((h + 2 * self.padding, w + 2 * self.padding, c), previous_state.dtype))
+        return (replace(previous_state, jit_mode=True), \
+            AllocationQuery((h + 2 * self.padding, w + 2 * self.padding, c), previous_state.dtype))
+
180 changes: 156 additions & 24 deletions tests/test_augmentations.py
@@ -1,21 +1,35 @@
+import os
+import uuid
 import numpy as np
+import torch as ch
 from torch.utils.data import Dataset
+from torchvision import transforms as tvt
 from assertpy import assert_that
 from tempfile import NamedTemporaryFile
 from torchvision.datasets import CIFAR10
+from torchvision.utils import save_image, make_grid
 from torch.utils.data import Subset
 from ffcv.fields.basics import IntDecoder
 from ffcv.fields.rgb_image import SimpleRGBImageDecoder
-from ffcv.transforms.cutout import Cutout
 
 from ffcv.writer import DatasetWriter
 from ffcv.fields import IntField, RGBImageField
 from ffcv.loader import Loader
 from ffcv.pipeline.compiler import Compiler
-from ffcv.transforms import Squeeze, Cutout, ToTensor, Poison, RandomHorizontalFlip
+from ffcv.transforms import *
 
-def run_test(length, pipeline, compile):
+SAVE_IMAGES = True
+IMAGES_TMP_PATH = '/tmp/ffcv_augtest_output'
+if SAVE_IMAGES:
+    os.makedirs(IMAGES_TMP_PATH, exist_ok=True)
+
+UNAUGMENTED_PIPELINE=[
+    SimpleRGBImageDecoder(),
+    ToTensor(),
+    ToTorchImage()
+]
+
+def run_test(length, pipeline, compile=False):
     my_dataset = Subset(CIFAR10(root='/tmp', train=True, download=True), range(length))
 
     with NamedTemporaryFile() as handle:
@@ -28,52 +42,170 @@ def run_test(length, pipeline, compile):
 
         writer.from_indexed_dataset(my_dataset, chunksize=10)
 
-        Compiler.set_enabled(True)
+        Compiler.set_enabled(compile)
 
         loader = Loader(name, batch_size=7, num_workers=2, pipelines={
             'image': pipeline,
             'label': [IntDecoder(), ToTensor(), Squeeze()]
         },
         drop_last=False)
 
+        unaugmented_loader = Loader(name, batch_size=7, num_workers=2, pipelines={
+            'image': UNAUGMENTED_PIPELINE,
+            'label': [IntDecoder(), ToTensor(), Squeeze()]
+        }, drop_last=False)
+
         tot_indices = 0
         tot_images = 0
-        for images, label in loader:
-            tot_indices += label.shape[0]
+        for (images, labels), (original_images, original_labels) in zip(loader, unaugmented_loader):
+            print(images.shape, original_images.shape)
+            tot_indices += labels.shape[0]
             tot_images += images.shape[0]
 
+            for label, original_label in zip(labels, original_labels):
+                assert_that(label).is_equal_to(original_label)
+
+            if SAVE_IMAGES:
+                save_image(make_grid(ch.concat([images, original_images])/255., images.shape[0]),
+                    os.path.join(IMAGES_TMP_PATH, str(uuid.uuid4()) + '.jpeg')
+                )
+
         assert_that(tot_indices).is_equal_to(len(my_dataset))
         assert_that(tot_images).is_equal_to(len(my_dataset))
 
+def test_cutout():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            Cutout(8),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
 def test_flip():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            RandomHorizontalFlip(1.0),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
+def test_module_wrapper():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            ToTensor(),
+            ToTorchImage(),
+            ModuleWrapper(tvt.Grayscale(3)),
+        ], comp)
+
+
+def test_mixup():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            ImageMixup(.5, False),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
+def test_poison():
+    mask = np.zeros((32, 32, 3))
+    # Red sqaure
+    mask[:5, :5, 0] = 1
+    alpha = np.ones((32, 32))
+
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            Poison(mask, alpha, list(range(100))),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
+def test_random_resized_crop():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            RandomResizedCrop(scale=(0.08, 1.0),
+                              ratio=(0.75, 4/3),
+                              size=32),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
+def test_translate():
+    for comp in [True, False]:
+        run_test(100, [
+            SimpleRGBImageDecoder(),
+            RandomTranslate(padding=10),
+            ToTensor(),
+            ToTorchImage()
+        ], comp)
+
+
+## Torchvision Transforms
+def test_torchvision_greyscale():
     run_test(100, [
         SimpleRGBImageDecoder(),
-        RandomHorizontalFlip(1.0),
-        ToTensor()
-    ], True)
+        ToTensor(),
+        ToTorchImage(),
+        tvt.Grayscale(3),
+    ])
 
-def test_cutout():
+def test_torchvision_centercrop_pad():
     run_test(100, [
         SimpleRGBImageDecoder(),
-        Cutout(8),
-        ToTensor()
-    ], True)
+        ToTensor(),
+        ToTorchImage(),
+        tvt.CenterCrop(10),
+        tvt.Pad(11)
+    ])
 
+def test_torchvision_random_affine():
     run_test(100, [
         SimpleRGBImageDecoder(),
-        Cutout(8),
-        ToTensor()
-    ], False)
+        ToTensor(),
+        ToTorchImage(),
+        tvt.RandomAffine(25),
+    ])
+
+def test_torchvision_random_crop():
+    run_test(100, [
+        SimpleRGBImageDecoder(),
+        ToTensor(),
+        ToTorchImage(),
+        tvt.Pad(10),
+        tvt.RandomCrop(size=32),
+    ])
 
-def test_poison():
-    mask = np.zeros((32, 32, 3))
-    # Red sqaure
-    mask[:5, :5, 0] = 1
-    alpha = np.ones((32, 32))
+def test_torchvision_color_jitter():
     run_test(100, [
         SimpleRGBImageDecoder(),
-        Poison(mask, alpha, [0, 1, 2]),
-        ToTensor()
-    ], False)
+        ToTensor(),
+        ToTorchImage(),
+        tvt.ColorJitter(.5, .5, .5, .5),
+    ])
 
 
+if __name__ == '__main__':
+    # test_cutout()
+    test_flip()
+    # test_module_wrapper()
+    # test_mixup()
+    # test_poison()
+    # test_random_resized_crop()
+    # test_translate()
+
+    ## Torchvision Transforms
+    # test_torchvision_greyscale()
+    # test_torchvision_centercrop_pad()
+    # test_torchvision_random_affine()
+    # test_torchvision_random_crop()
+    # test_torchvision_color_jitter()
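
The rewritten test module keeps the `__main__` block above for running a single case by hand, and the functions follow the usual `test_*` naming, so a pytest-style runner can also select them individually. A hypothetical invocation (assuming pytest is installed; the repository's test configuration is not shown in this diff):

import pytest

# Run only the flip and mixup cases; -s keeps the shape prints from run_test visible.
# When SAVE_IMAGES is True, the augmented/unaugmented grids land in /tmp/ffcv_augtest_output.
pytest.main(['tests/test_augmentations.py', '-k', 'flip or mixup', '-s'])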
