-
Notifications
You must be signed in to change notification settings - Fork 286
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add RFFT2D and IRFFT2D transforms (#1662)
- Loading branch information
1 parent
b8c1900
commit cb4f1f6
Showing
5 changed files
with
90 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Tuple | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
|
||
class IRFFT2DTransform: | ||
"""Inverse 2D Fast Fourier Transform (IRFFT2D) Transformation. | ||
This transformation applies the inverse 2D Fast Fourier Transform (IRFFT2D) | ||
to an image in the frequency domain. | ||
Input: | ||
- Tensor of shape (C, H, W), where C is the number of channels. | ||
Output: | ||
- Tensor of shape (C, H, W), where C is the number of channels. | ||
""" | ||
|
||
def __init__(self, shape: Tuple[int, int]): | ||
""" | ||
Args: | ||
shape: The desired output shape (H, W) after applying the inverse FFT | ||
""" | ||
self.shape = shape | ||
|
||
def __call__(self, freq_image: Tensor) -> Tensor: | ||
"""Applies the inverse 2D Fast Fourier Transform (IRFFT2D) to the input tensor. | ||
Args: | ||
freq_image: A tensor in the frequency domain of shape (C, H, W). | ||
Returns: | ||
Tensor: Reconstructed image after applying IRFFT2D, of shape (C, H, W). | ||
""" | ||
reconstructed_image: Tensor = torch.fft.irfft2(freq_image, s=self.shape) | ||
return reconstructed_image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from typing import Union | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
|
||
class RFFT2DTransform: | ||
"""2D Fast Fourier Transform (RFFT2D) Transformation. | ||
This transformation applies the 2D Fast Fourier Transform (RFFT2D) | ||
to an image, converting it from the spatial domain to the frequency domain. | ||
Input: | ||
- Tensor of shape (C, H, W), where C is the number of channels. | ||
Output: | ||
- Tensor of shape (C, H, W) in the frequency domain, where C is the number of channels. | ||
""" | ||
|
||
def __call__(self, image: Tensor) -> Tensor: | ||
"""Applies the 2D Fast Fourier Transform (RFFT2D) to the input image. | ||
Args: | ||
image: Input image as a Tensor of shape (C, H, W). | ||
Returns: | ||
Tensor: The image in the frequency domain after applying RFFT2D, of shape (C, H, W). | ||
""" | ||
|
||
rfft_image: Tensor = torch.fft.rfft2(image) | ||
return rfft_image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import torch | ||
|
||
from lightly.transforms import IRFFT2DTransform | ||
|
||
|
||
def test() -> None: | ||
transform = IRFFT2DTransform((32, 32)) | ||
image = torch.rand(3, 32, 17) | ||
output = transform(image) | ||
assert output.shape == (3, 32, 32) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import torch | ||
|
||
from lightly.transforms import RFFT2DTransform | ||
|
||
|
||
def test() -> None: | ||
transform = RFFT2DTransform() | ||
image = torch.rand(3, 32, 32) | ||
output = transform(image) | ||
assert output.shape == (3, 32, 17) |