Coverage for python/lsst/scarlet/lite/fft.py: 12%
160 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 08:25 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 08:25 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["Fourier"]
26import logging
27import operator
28from typing import Callable, Sequence
30import numpy as np
31from numpy.typing import DTypeLike
32from scipy import fftpack
34logger = logging.getLogger("scarlet.lite.fft")
def centered(arr: np.ndarray, newshape: Sequence[int]) -> np.ndarray:
    """Extract the central region of `arr` that has shape `newshape`.

    Parameters
    ----------
    arr:
        The array to crop.
    newshape:
        The shape of the central region to extract.

    Returns
    -------
    result:
        The central `newshape` portion of `arr`.

    Raises
    ------
    ValueError
        If any dimension of `newshape` is larger than the matching
        dimension of `arr`.

    Notes
    -----
    When a dimension of `arr` is odd and the requested size is even,
    the center of `arr` lands on the center-right pixel of the result.
    The scipy implementation uses the center-left pixel instead;
    the convention here follows `np.fft.fftshift` so that converting
    between fft standard order (zero frequency/position in the
    bottom left) and a centered layout is consistent.
    """
    target = np.array(newshape)
    current = np.array(arr.shape)

    if not (target <= current).all():
        msg = f"arr must be larger than newshape in both dimensions, received {arr.shape}, and {target}"
        raise ValueError(msg)

    # Rounding the offset up implements the center-right convention
    first = (current - target + 1) // 2
    window = tuple(slice(lo, lo + size) for lo, size in zip(first, target))
    return arr[window]
def fast_zero_pad(arr: np.ndarray, pad_width: Sequence[Sequence[int]]) -> np.ndarray:
    """Zero-pad an array, as a fast stand-in for `numpy.pad` with zeros.

    The padded array is allocated with `numpy.zeros` and the input is
    copied into its interior. The original author reports `numpy.pad`
    with ``mode="constant"`` to be ~1000 times slower for this case.

    Parameters
    ----------
    arr:
        The array to pad.
    pad_width:
        ``(before, after)`` number of zeros to add along each axis.
        See the `numpy.pad` docs for more.

    Returns
    -------
    result: np.ndarray
        A new array with the same dtype as `arr`, padded with zeros.
    """
    padded_shape = tuple(size + before + after for size, (before, after) in zip(arr.shape, pad_width))
    result = np.zeros(padded_shape, dtype=arr.dtype)
    # Region of ``result`` that is overwritten with the input values
    interior = tuple(slice(before, size - after) for size, (before, after) in zip(padded_shape, pad_width))
    result[interior] = arr
    return result
def _pad(
    arr: np.ndarray,
    newshape: Sequence[int],
    axes: int | Sequence[int] | None = None,
    mode: str = "constant",
    constant_values: float = 0,
) -> np.ndarray:
    """Pad an array to fit into newshape

    Pad `arr` to fit into newshape,
    which uses the `np.fft.fftshift` convention of moving
    the center pixel of `arr` (if `arr.shape` is odd) to
    the center-right pixel in an even shaped `newshape`.

    Parameters
    ----------
    arr:
        The array to pad.
    newshape:
        The new shape of the array. If `axes` is given, this has one
        entry per entry of `axes` (in the same order); otherwise it has
        one entry per dimension of `arr`.
    axes:
        The axes that are being reshaped.
        If `None` then all of the axes are padded.
    mode:
        The numpy mode used to pad the array.
        In other words, how to fill the new padded elements.
        See ``numpy.pad`` for details.
    constant_values:
        If `mode` == "constant" then this is the value to set all of
        the new padded elements to. Ignored for all other modes.

    Returns
    -------
    result:
        The padded array.
    """
    _newshape = np.asarray(newshape)
    if axes is None:
        currshape = np.array(arr.shape)
        diff = _newshape - currshape
        # Rounding up places any extra pixel before the data,
        # which implements the center-right convention.
        startind = (diff + 1) // 2
        endind = diff - startind
        pad_width = list(zip(startind, endind))
    else:
        # only pad the axes that will be transformed
        pad_width = [(0, 0) for _ in arr.shape]
        if isinstance(axes, int):
            axes = [axes]
        for a, axis in enumerate(axes):
            diff = _newshape[a] - arr.shape[axis]
            startind = (diff + 1) // 2
            endind = diff - startind
            pad_width[axis] = (startind, endind)

    if mode == "constant":
        if constant_values == 0:
            return fast_zero_pad(arr, pad_width)
        # ``constant_values`` must be forwarded explicitly, otherwise
        # numpy falls back to its default of padding with zeros.
        return np.pad(  # type: ignore
            arr, tuple(pad_width), mode=mode, constant_values=constant_values
        )
    # numpy rejects ``constant_values`` for non-constant modes,
    # so it is intentionally not passed here.
    return np.pad(arr, tuple(pad_width), mode=mode)  # type: ignore
def get_fft_shape(
    im_or_shape1: np.ndarray | Sequence[int],
    im_or_shape2: np.ndarray | Sequence[int],
    padding: int = 3,
    axes: int | Sequence[int] | None = None,
    use_max: bool = False,
) -> tuple:
    """Compute a fast FFT shape that fits two images

    For each transformed axis the sizes of the two inputs are combined
    (summed, or the larger one taken), `padding` is added, and the
    result is rounded up with `scipy.fftpack.next_fast_len`.

    Parameters
    ----------
    im_or_shape1:
        The left image, or the shape of the left image.
        A `numpy.ndarray` is always treated as an image.
    im_or_shape2:
        The right image, or the shape of the right image.
        A `numpy.ndarray` is always treated as an image.
    padding:
        Extra size added to every combined dimension.
    axes:
        The axes that are being transformed.
        If `None` then all axes are used.
    use_max:
        If `True` combine the two sizes with their maximum,
        otherwise with their sum.

    Returns
    -------
    shape:
        Tuple with one fast FFT length per transformed axis.

    Raises
    ------
    ValueError
        If the two inputs do not have the same number of dimensions.
    """
    shape1, shape2 = (
        np.asarray(item.shape if isinstance(item, np.ndarray) else item)
        for item in (im_or_shape1, im_or_shape2)
    )

    # Both inputs must describe the same number of dimensions
    ndim1, ndim2 = len(shape1), len(shape2)
    if ndim1 != ndim2:
        msg = "img1 and img2 must have the same number of dimensions, " f"but got {ndim1} and {ndim2}"
        raise ValueError(msg)

    if axes is None:
        # Combine every dimension at once
        combined = np.max([shape1, shape2], axis=0) if use_max else shape1 + shape2
    else:
        if isinstance(axes, int):
            axes = [axes]
        # Combine only the requested axes, in the order given
        combined = np.zeros(len(axes), dtype="int")
        for position, ax in enumerate(axes):
            if use_max:
                combined[position] = np.max([shape1[ax], shape2[ax]])
            else:
                combined[position] = shape1[ax] + shape2[ax]

    combined += padding
    # Round each dimension up to the next fast FFT length
    return tuple(fftpack.next_fast_len(size) for size in combined)
class Fourier:
    """An array that stores its Fourier Transform

    The `Fourier` class is used for images that will make
    use of their Fourier Transform multiple times.
    In order to prevent numerical artifacts the same image
    convolved with different images might require different
    padding, so the FFT for each different shape is stored
    in a dictionary.

    Parameters
    ----------
    image: np.ndarray
        The real space image.
    image_fft: dict[Sequence[Sequence[int]], np.ndarray]
        A dictionary of precalculated FFTs. Each key is the tuple
        ``(fft_shape, axes, all_axes)`` used by `Fourier.fft`, where
        `all_axes` lists every axis of the image.
        The dictionary is stored as given (not copied).
    """

    def __init__(
        self,
        image: np.ndarray,
        image_fft: dict[Sequence[Sequence[int]], np.ndarray] | None = None,
    ):
        if image_fft is None:
            self._fft: dict[Sequence[Sequence[int]], np.ndarray] = {}
        else:
            self._fft = image_fft
        self._image = image

    @staticmethod
    def from_fft(
        image_fft: np.ndarray,
        fft_shape: Sequence[int],
        image_shape: Sequence[int],
        axes: int | Sequence[int] | None = None,
        dtype: DTypeLike = float,
    ) -> Fourier:
        """Generate a new Fourier object from an FFT dictionary

        If the fft of an image has been generated but not its
        real space image (for example when creating a convolution kernel),
        this method can be called to create a new `Fourier` instance
        from the k-space representation.

        Parameters
        ----------
        image_fft:
            The FFT of the image.
        fft_shape:
            "Fast" shape of the image used to generate the FFT.
            This will be different than `image_fft.shape` if
            any of the dimensions are odd, since `np.fft.rfft`
            requires an even number of dimensions (for symmetry),
            so this tells `np.fft.irfft` how to go from
            complex k-space to real space.
        image_shape:
            The shape of the image *before padding*.
            This will regenerate the image with the extra
            padding stripped.
        axes:
            The dimension(s) of the array that will be transformed.
            If `None` then all axes of `image_shape` are used.
        dtype:
            The data type that the real space image is cast to.

        Returns
        -------
        result:
            A `Fourier` object generated from the FFT.
        """
        if axes is None:
            axes = range(len(image_shape))
        if isinstance(axes, int):
            axes = [axes]
        all_axes = range(len(image_shape))
        image = np.fft.irfftn(image_fft, fft_shape, axes=axes).astype(dtype)
        # Shift the center of the image from the bottom left to the center
        image = np.fft.fftshift(image, axes=axes)
        # Trim the image to remove the padding added
        # to reduce fft artifacts
        image = centered(image, image_shape)
        key = (tuple(fft_shape), tuple(axes), tuple(all_axes))

        return Fourier(image, {key: image_fft})

    @property
    def image(self) -> np.ndarray:
        """The real space image"""
        return self._image

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the real space image"""
        return self._image.shape

    def fft(
        self,
        fft_shape: Sequence[int],
        axes: int | Sequence[int],
        cache: bool = True,
    ) -> np.ndarray:
        """The FFT of an image for a given `fft_shape` along desired `axes`

        Parameters
        ----------
        fft_shape:
            "Fast" shape of the image used to generate the FFT.
            This will be different than `image_fft.shape` if
            any of the dimensions are odd, since `np.fft.rfft`
            requires an even number of dimensions (for symmetry),
            so this tells `np.fft.irfft` how to go from
            complex k-space to real space.
        axes:
            The dimension(s) of the array that will be transformed.
        cache:
            Whether to store the computed FFT in ``self._fft`` for reuse.
            An already-cached entry for the same key is always returned
            regardless of this flag; ``cache=False`` only suppresses
            *adding* new entries, which prevents unbounded growth when a
            long-lived `Fourier` (e.g. a kernel stored on `Observation`)
            is convolved against many different shapes.

        Returns
        -------
        result:
            The real FFT (`np.fft.rfftn`) of the padded image.

        Raises
        ------
        ValueError
            If `fft_shape` and `axes` have different lengths and no
            cached entry exists for them.
        """
        if isinstance(axes, int):
            axes = (axes,)
        all_axes = range(len(self.image.shape))
        fft_key = (tuple(fft_shape), tuple(axes), tuple(all_axes))

        if fft_key in self._fft:
            return self._fft[fft_key]

        if len(fft_shape) != len(axes):
            msg = f"fft_shape self.axes must have the same number of dimensions, got {fft_shape}, {axes}"
            raise ValueError(msg)
        image = _pad(self.image, fft_shape, axes)
        # Move the center of the image to the origin before transforming
        result = np.fft.rfftn(np.fft.ifftshift(image, axes), axes=axes)
        if cache:
            self._fft[fft_key] = result
        return result

    def __len__(self) -> int:
        """Length of the image"""
        return len(self.image)

    def __getitem__(self, index: int | Sequence[int] | slice) -> Fourier:
        """Index the image and every cached FFT with the same `index`

        Parameters
        ----------
        index:
            The index applied to the real space image and to each
            cached FFT.

        Returns
        -------
        result:
            A new `Fourier` built from the indexed image and the indexed
            FFTs, with the cache keys adjusted for the removed axes.
        """
        # Make the index a tuple
        if isinstance(index, int):
            index = (index,)

        # Axes that are removed from the shape of the new object
        if isinstance(index, slice):
            removed: set = set()
        else:
            removed = {n for n, idx in enumerate(index) if idx is not None}

        # NOTE(review): ``removed`` holds non-negative positions, so axes
        # stored as negative numbers (e.g. ``(-2, -1)``) never match it.

        # Create views into the fft transformed values, appropriately adjusting
        # the shapes for the new axes.
        # ``key`` is ``(fft_shape, axes, all_axes)``; each ``fft_shape`` entry
        # belongs to the axis at the same position in ``axes``, so it is
        # dropped exactly when that axis is removed.
        fft_kernels = {
            (
                tuple(size for size, axis in zip(key[0], key[1]) if axis not in removed),
                tuple(axis for axis in key[1] if axis not in removed),
                tuple(axis for axis in key[2] if axis not in removed),
            ): kernel[index]
            for key, kernel in self._fft.items()
        }
        # mypy doesn't recognize that tuple[int, ...]
        # is a valid Sequence[int] for some reason
        return Fourier(self.image[index], fft_kernels)  # type: ignore
def _kspace_operation(
    image1: Fourier,
    image2: Fourier,
    padding: int,
    op: Callable,
    shape: Sequence[int],
    axes: int | Sequence[int],
    cache: bool = True,
) -> Fourier:
    """Apply a binary operator to two images in k-space

    Both images are transformed with a common fast FFT shape,
    combined with `op`, and the result is transformed back
    into a new `Fourier` object.

    Parameters
    ----------
    image1:
        The LHS of the equation.
    image2:
        The RHS of the equation.
    padding:
        The amount of padding to add before transforming into k-space.
    op:
        The operator used to combine the two images.
        This is either ``operator.mul`` for a convolution
        or ``operator.truediv`` for deconvolution.
    shape:
        The shape of the output image.
    axes:
        The dimension(s) of the array that will be transformed.
    cache:
        Whether the per-shape FFT of each input should be stored on the
        respective `Fourier` instance. Forwarded to `Fourier.fft`.

    Returns
    -------
    result:
        The combined image, with the dtype of ``image1.image``.

    Raises
    ------
    ValueError
        If the two images do not have the same number of axes.
    """
    ndim1, ndim2 = len(image1.shape), len(image2.shape)
    if ndim1 != ndim2:
        msg = "Both images must have the same number of axes, " f"got {ndim1} and {ndim2}"
        raise ValueError(msg)

    fft_shape = get_fft_shape(image1.image, image2.image, padding, axes)
    lhs = image1.fft(fft_shape, axes, cache=cache)
    rhs = image2.fft(fft_shape, axes, cache=cache)

    division_ops = (
        operator.truediv,
        operator.floordiv,
        operator.itruediv,
        operator.ifloordiv,
    )
    if any(op == division for division in division_ops):
        # Broadcast along the first axis, if necessary
        if rhs.shape[0] == 1 and lhs.shape[0] != rhs.shape[0]:
            rhs = np.tile(rhs, (lhs.shape[0],) + (1,) * (len(rhs.shape) - 1))
        if lhs.shape[0] == 1 and lhs.shape[0] != rhs.shape[0]:
            lhs = np.tile(lhs, (rhs.shape[0],) + (1,) * (len(lhs.shape) - 1))
        # Prevent divide by zero: only divide where the denominator
        # is non-zero and leave every other element as zero.
        nonzero = rhs != 0
        transformed_fft = np.zeros(lhs.shape, dtype=lhs.dtype)
        transformed_fft[nonzero] = op(lhs[nonzero], rhs[nonzero])
    else:
        transformed_fft = op(lhs, rhs)

    return Fourier.from_fft(transformed_fft, fft_shape, shape, axes, image1.image.dtype)
def match_kernel(
    kernel1: np.ndarray | Fourier,
    kernel2: np.ndarray | Fourier,
    padding: int = 3,
    axes: int | Sequence[int] = (-2, -1),
    return_fourier: bool = True,
    normalize: None = None,
    cache: bool = True,
) -> Fourier | np.ndarray:
    """Build the difference kernel that matches kernel1 to kernel2

    The difference kernel is obtained by dividing the two kernels
    in k-space.

    Parameters
    ----------
    kernel1:
        The first kernel, either as array or as `Fourier` object
    kernel2:
        The second kernel, either as array or as `Fourier` object
    padding:
        Additional padding to use when generating the FFT
        to suppress artifacts.
    axes:
        Axes that contain the spatial information for the kernels.
    return_fourier:
        Whether to return `Fourier` or array
    normalize:
        Deprecated and unused. Will be removed after v31.0.
    cache:
        Whether the per-shape FFTs of the inputs should be stored on the
        respective `Fourier` instances. Forwarded to `Fourier.fft`.

    Returns
    -------
    result:
        The difference kernel to go from `kernel1` to `kernel2`.
    """
    if normalize is not None:
        logger.warning(
            "normalize is deprecated and will be removed after v31.0. It has never been used by match_kernel."
        )

    kernel1 = kernel1 if isinstance(kernel1, Fourier) else Fourier(kernel1)
    kernel2 = kernel2 if isinstance(kernel2, Fourier) else Fourier(kernel2)

    # The output takes the shape of whichever kernel is longer
    # along its first axis (kernel1 wins a tie).
    shape = kernel2.shape if kernel1.shape[0] < kernel2.shape[0] else kernel1.shape

    diff = _kspace_operation(kernel1, kernel2, padding, operator.truediv, shape, axes=axes, cache=cache)
    return diff if return_fourier else np.real(diff.image)
def convolve(
    image: np.ndarray | Fourier,
    kernel: np.ndarray | Fourier,
    padding: int = 3,
    axes: int | Sequence[int] = (-2, -1),
    return_fourier: bool = True,
    normalize: None = None,
    cache: bool = True,
) -> np.ndarray | Fourier:
    """Convolve an image with a kernel by multiplying in k-space

    Parameters
    ----------
    image:
        Image either as array or as `Fourier` object
    kernel:
        Convolution kernel either as array or as `Fourier` object
    padding:
        Additional padding to use when generating the FFT
        to suppress artifacts.
    axes:
        Axes that contain the spatial information for the PSFs.
    return_fourier:
        Whether to return `Fourier` or array
    normalize:
        Deprecated and unused. Will be removed after v31.0.
    cache:
        Whether the per-shape FFTs of `image` and `kernel` should be
        stored on the respective `Fourier` instances. Forwarded to
        `Fourier.fft`. Long-lived `kernel` objects (e.g. those held by
        `Observation`) accumulate one cache entry per distinct image
        shape; pass ``cache=False`` when convolving with shapes that
        won't recur, to avoid unbounded growth.

    Returns
    -------
    result:
        The convolution of the image with the kernel,
        with the same shape as `image`.
    """
    if normalize is not None:
        logger.warning(
            "normalize is deprecated and will be removed after v31.0. It has never been used by convolve."
        )

    image = image if isinstance(image, Fourier) else Fourier(image)
    kernel = kernel if isinstance(kernel, Fourier) else Fourier(kernel)

    convolved = _kspace_operation(image, kernel, padding, operator.mul, image.shape, axes=axes, cache=cache)
    return convolved if return_fourier else np.real(convolved.image)