Skip to content

Commit

Permalink
Add RFFT2D and IRFFT2D transforms (#1662)
Browse files Browse the repository at this point in the history
  • Loading branch information
snehilchatterjee authored Oct 4, 2024
1 parent b8c1900 commit cb4f1f6
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform
from lightly.transforms.fast_siam_transform import FastSiamTransform
from lightly.transforms.gaussian_blur import GaussianBlur
from lightly.transforms.irfft2d_transform import IRFFT2DTransform
from lightly.transforms.jigsaw import Jigsaw
from lightly.transforms.mae_transform import MAETransform
from lightly.transforms.mmcr_transform import MMCRTransform
from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform
from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform
from lightly.transforms.pirl_transform import PIRLTransform
from lightly.transforms.rfft2d_transform import RFFT2DTransform
from lightly.transforms.rotation import (
RandomRotate,
RandomRotateDegrees,
Expand Down
37 changes: 37 additions & 0 deletions lightly/transforms/irfft2d_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Tuple

import torch
from torch import Tensor


class IRFFT2DTransform:
"""Inverse 2D Fast Fourier Transform (IRFFT2D) Transformation.
This transformation applies the inverse 2D Fast Fourier Transform (IRFFT2D)
to an image in the frequency domain.
Input:
- Tensor of shape (C, H, W), where C is the number of channels.
Output:
- Tensor of shape (C, H, W), where C is the number of channels.
"""

def __init__(self, shape: Tuple[int, int]):
"""
Args:
shape: The desired output shape (H, W) after applying the inverse FFT
"""
self.shape = shape

def __call__(self, freq_image: Tensor) -> Tensor:
"""Applies the inverse 2D Fast Fourier Transform (IRFFT2D) to the input tensor.
Args:
freq_image: A tensor in the frequency domain of shape (C, H, W).
Returns:
Tensor: Reconstructed image after applying IRFFT2D, of shape (C, H, W).
"""
reconstructed_image: Tensor = torch.fft.irfft2(freq_image, s=self.shape)
return reconstructed_image
31 changes: 31 additions & 0 deletions lightly/transforms/rfft2d_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Union

import torch
from torch import Tensor


class RFFT2DTransform:
"""2D Fast Fourier Transform (RFFT2D) Transformation.
This transformation applies the 2D Fast Fourier Transform (RFFT2D)
to an image, converting it from the spatial domain to the frequency domain.
Input:
- Tensor of shape (C, H, W), where C is the number of channels.
Output:
- Tensor of shape (C, H, W) in the frequency domain, where C is the number of channels.
"""

def __call__(self, image: Tensor) -> Tensor:
"""Applies the 2D Fast Fourier Transform (RFFT2D) to the input image.
Args:
image: Input image as a Tensor of shape (C, H, W).
Returns:
Tensor: The image in the frequency domain after applying RFFT2D, of shape (C, H, W).
"""

rfft_image: Tensor = torch.fft.rfft2(image)
return rfft_image
10 changes: 10 additions & 0 deletions tests/transforms/test_irfft2d_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch

from lightly.transforms import IRFFT2DTransform


def test() -> None:
transform = IRFFT2DTransform((32, 32))
image = torch.rand(3, 32, 17)
output = transform(image)
assert output.shape == (3, 32, 32)
10 changes: 10 additions & 0 deletions tests/transforms/test_rfft2d_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch

from lightly.transforms import RFFT2DTransform


def test() -> None:
transform = RFFT2DTransform()
image = torch.rand(3, 32, 32)
output = transform(image)
assert output.shape == (3, 32, 17)

0 comments on commit cb4f1f6

Please sign in to comment.