Skip to content

Commit

Permalink
Fix fractional centers in CircularConvolve (#471)
Browse files Browse the repository at this point in the history
* Add failing test

* Fix phase ramp

* Minor comment improvement

* Improve RNG key use

---------

Co-authored-by: Brendt Wohlberg <[email protected]>
  • Loading branch information
Michael-T-McCann and bwohlberg authored Nov 15, 2023
1 parent 4efcba4 commit 61f1c93
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
11 changes: 9 additions & 2 deletions scico/linop/_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def __init__(
output_dtype = snp.dtype(input_dtype) # cannot infer from h_dft because it is complex
else:
fft_shape = input_shape[-self.ndims :]
pad = ()
fft_axes = list(range(h.ndim - self.ndims, h.ndim))
self.h_dft = snp.fft.fftn(h, s=fft_shape, axes=fft_axes)
output_dtype = result_type(h.dtype, input_dtype)
Expand All @@ -140,7 +139,15 @@ def __init__(
offset = -snp.array(self.h_center)
shifts: Tuple[np.ndarray, ...] = np.ix_(
*tuple(
np.exp(-1j * k * 2 * np.pi * np.fft.fftfreq(s)) # type: ignore
np.select(
# see doi:10.1109/78.700979 and doi:10.1109/LSP.2012.2191280
[np.arange(s) < s / 2, np.arange(s) == s / 2, np.arange(s) > s / 2],
[
np.exp(-1j * k * 2 * np.pi * np.arange(s) / s),
np.cos(k * np.pi),
np.exp(1j * k * 2 * np.pi * (s - np.arange(s)) / s),
], # type: ignore
)
for k, s in zip(offset, input_shape[-self.ndims :])
)
)
Expand Down
17 changes: 15 additions & 2 deletions scico/test/linop/test_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def setup_method(self, method):
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS)
def test_eval(self, axes_shape_spec, input_dtype, jit):

x_shape, ndims, h_shape = axes_shape_spec

h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key)
Expand Down Expand Up @@ -62,7 +61,6 @@ def test_eval(self, axes_shape_spec, input_dtype, jit):
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS)
def test_adjoint(self, axes_shape_spec, input_dtype, jit):

x_shape, ndims, h_shape = axes_shape_spec

h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key)
Expand Down Expand Up @@ -160,6 +158,21 @@ def test_center(self, center):
shift = -center[0]
np.testing.assert_allclose(A @ x, snp.roll(B @ x, shift), atol=1e-5)

def test_fractional_center(self):
"""A fractional center should keep outputs real."""
x, key = uniform(minval=-1, maxval=1, shape=(4, 5), key=self.key)
h, _ = uniform(minval=-1, maxval=1, shape=(2, 2), key=key)
A = CircularConvolve(h=h, input_shape=x.shape, h_center=[0.1, 2.7])

# taken from CircularConvolve._eval
x_dft = snp.fft.fftn(x, axes=A.x_fft_axes)
hx = snp.fft.ifftn(
A.h_dft * x_dft,
axes=A.ifft_axes,
)

np.testing.assert_allclose(hx, snp.real(hx))

@pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS)
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("jit_old_op", [True, False])
Expand Down

0 comments on commit 61f1c93

Please sign in to comment.