Coverage for python/lsst/scarlet/lite/fft.py: 12%

160 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 08:24 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["Fourier"] 

25 

26import logging 

27import operator 

28from typing import Callable, Sequence 

29 

30import numpy as np 

31from numpy.typing import DTypeLike 

32from scipy import fftpack 

33 

34logger = logging.getLogger("scarlet.lite.fft") 

35 

36 

def centered(arr: np.ndarray, newshape: Sequence[int]) -> np.ndarray:
    """Extract the central ``newshape`` region of an array.

    Parameters
    ----------
    arr:
        The array to trim.
    newshape:
        The shape of the region to extract.

    Returns
    -------
    result:
        The slice of `arr` with shape `newshape`, taken around its center.

    Raises
    ------
    ValueError
        If any entry of `newshape` is larger than the matching
        dimension of `arr`.

    Notes
    -----
    The first index of the window is ``(old - new + 1) // 2``, so when
    the size difference along an axis is odd the extra element is
    dropped from the *start* of that axis. This places the center of an
    odd `arr` on the center-right pixel of an even `newshape`, which is
    the `np.fft.fftshift` convention, and differs from scipy, which uses
    the center-left pixel.
    """
    target = np.array(newshape)
    current = np.array(arr.shape)

    if not (target <= current).all():
        msg = f"arr must be larger than newshape in both dimensions, received {arr.shape}, and {target}"
        raise ValueError(msg)

    # Round the leading offset up so odd differences trim from the start
    first = (current - target + 1) // 2
    last = first + target
    window = tuple(slice(lo, hi) for lo, hi in zip(first, last))

    return arr[window]

72 

73 

def fast_zero_pad(arr: np.ndarray, pad_width: Sequence[Sequence[int]]) -> np.ndarray:
    """Zero-pad an array by copying it into a pre-allocated block of zeros

    This is a replacement for `numpy.pad` with ``mode="constant"`` and a
    fill value of zero. The output is allocated with `numpy.zeros` and
    `arr` is written into its interior.

    Parameters
    ----------
    arr:
        The array to pad.
    pad_width:
        For each axis of `arr`, the ``(before, after)`` number of
        elements to add. See the `numpy.pad` docs for more.

    Returns
    -------
    result: np.ndarray
        A new array with the same dtype as `arr`, containing `arr`
        surrounded by zeros.
    """
    padded_shape = []
    interior = []
    for size, (before, after) in zip(arr.shape, pad_width):
        padded_shape.append(size + before + after)
        # The original data occupies [before, before + size) on this axis
        interior.append(slice(before, size + before))

    result = np.zeros(tuple(padded_shape), dtype=arr.dtype)
    result[tuple(interior)] = arr
    return result

99 

100 

101def _pad( 

102 arr: np.ndarray, 

103 newshape: Sequence[int], 

104 axes: int | Sequence[int] | None = None, 

105 mode: str = "constant", 

106 constant_values: float = 0, 

107) -> np.ndarray: 

108 """Pad an array to fit into newshape 

109 

110 Pad `arr` with zeros to fit into newshape, 

111 which uses the `np.fft.fftshift` convention of moving 

112 the center pixel of `arr` (if `arr.shape` is odd) to 

113 the center-right pixel in an even shaped `newshape`. 

114 

115 Parameters 

116 ---------- 

117 arr: 

118 The arrray to pad. 

119 newshape: 

120 The new shape of the array. 

121 axes: 

122 The axes that are being reshaped. 

123 mode: 

124 The numpy mode used to pad the array. 

125 In other words, how to fill the new padded elements. 

126 See ``numpy.pad`` for details. 

127 constant_values: 

128 If `mode` == "constant" then this is the value to set all of 

129 the new padded elements to. 

130 """ 

131 _newshape = np.asarray(newshape) 

132 if axes is None: 

133 currshape = np.array(arr.shape) 

134 diff = _newshape - currshape 

135 startind = (diff + 1) // 2 

136 endind = diff - startind 

137 pad_width = list(zip(startind, endind)) 

138 else: 

139 # only pad the axes that will be transformed 

140 pad_width = [(0, 0) for _ in arr.shape] 

141 if isinstance(axes, int): 

142 axes = [axes] 

143 for a, axis in enumerate(axes): 

144 diff = _newshape[a] - arr.shape[axis] 

145 startind = (diff + 1) // 2 

146 endind = diff - startind 

147 pad_width[axis] = (startind, endind) 

148 if mode == "constant" and constant_values == 0: 

149 result = fast_zero_pad(arr, pad_width) 

150 else: 

151 result = np.pad(arr, tuple(pad_width), mode=mode) # type: ignore 

152 return result 

153 

154 

def get_fft_shape(
    im_or_shape1: np.ndarray | Sequence[int],
    im_or_shape2: np.ndarray | Sequence[int],
    padding: int = 3,
    axes: int | Sequence[int] | None = None,
    use_max: bool = False,
) -> tuple:
    """Return the fast fft shapes for each spatial axis

    The two input shapes are combined along each requested axis
    (summed, or reduced with the maximum when `use_max` is set),
    `padding` is added, and the result is rounded up with
    `scipy.fftpack.next_fast_len`.

    Parameters
    ----------
    im_or_shape1:
        The left image or shape of an image.
    im_or_shape2:
        The right image or shape of an image.
    padding:
        Any additional padding to add to the final shape.
    axes:
        The axes that are being transformed.
        If `None`, every axis is used.
    use_max:
        Whether or not to use the maximum of the two shapes,
        or the sum of the two shapes.

    Returns
    -------
    shape:
        Tuple of the shape to use when the two images are transformed
        into k-space, with one entry per transformed axis.

    Raises
    ------
    ValueError
        If the two inputs do not have the same number of dimensions.
    """
    dims1 = np.asarray(im_or_shape1.shape if isinstance(im_or_shape1, np.ndarray) else im_or_shape1)
    dims2 = np.asarray(im_or_shape2.shape if isinstance(im_or_shape2, np.ndarray) else im_or_shape2)

    # Make sure the shapes are the same size
    if len(dims1) != len(dims2):
        msg = (
            "img1 and img2 must have the same number of dimensions, "
            f"but got {len(dims1)} and {len(dims2)}"
        )
        raise ValueError(msg)

    if axes is None:
        # Combine along every dimension
        combined = np.max([dims1, dims2], axis=0) if use_max else dims1 + dims2
    else:
        selected = [axes] if isinstance(axes, int) else axes
        combined = np.zeros(len(selected), dtype="int")
        for position, axis in enumerate(selected):
            first, second = dims1[axis], dims2[axis]
            combined[position] = np.max([first, second]) if use_max else first + second

    combined += padding
    # Use the next fastest shape in each dimension
    return tuple(fftpack.next_fast_len(size) for size in combined)

221 

222 

class Fourier:
    """An array that stores its Fourier Transform

    The `Fourier` class is used for images that will make
    use of their Fourier Transform multiple times.
    In order to prevent numerical artifacts the same image
    convolved with different images might require different
    padding, so the FFT for each different shape is stored
    in a dictionary.

    Parameters
    ----------
    image:
        The real space image.
    image_fft:
        A dictionary of precalculated FFTs. Each key is the tuple
        ``(fft_shape, axes, all_axes)`` built by `fft` and `from_fft`.
        The dictionary is stored by reference, not copied.
    """

    def __init__(
        self,
        image: np.ndarray,
        image_fft: dict[Sequence[Sequence[int]], np.ndarray] | None = None,
    ):
        if image_fft is None:
            self._fft: dict[Sequence[Sequence[int]], np.ndarray] = {}
        else:
            self._fft = image_fft
        self._image = image

    @staticmethod
    def from_fft(
        image_fft: np.ndarray,
        fft_shape: Sequence[int],
        image_shape: Sequence[int],
        axes: int | Sequence[int] | None = None,
        dtype: DTypeLike = float,
    ) -> Fourier:
        """Generate a new Fourier object from an FFT

        If the fft of an image has been generated but not its
        real space image (for example when creating a convolution kernel),
        this method can be called to create a new `Fourier` instance
        from the k-space representation.

        Parameters
        ----------
        image_fft:
            The FFT of the image.
        fft_shape:
            "Fast" shape of the image used to generate the FFT.
            This is passed to `np.fft.irfftn` as the output size along
            `axes`, since that size cannot always be recovered from
            `image_fft.shape` alone.
        image_shape:
            The shape of the image *before padding*.
            This will regenerate the image with the extra
            padding stripped.
        axes:
            The dimension(s) of the array that will be transformed.
            If `None`, all axes of `image_shape` are used.
        dtype:
            The dtype that the real space image is cast to.

        Returns
        -------
        result:
            A `Fourier` object generated from the FFT, with `image_fft`
            already stored in its cache.
        """
        if axes is None:
            axes = range(len(image_shape))
        if isinstance(axes, int):
            axes = [axes]
        all_axes = range(len(image_shape))
        image = np.fft.irfftn(image_fft, fft_shape, axes=axes).astype(dtype)
        # Shift the center of the image from the bottom left to the center
        image = np.fft.fftshift(image, axes=axes)
        # Trim the image to remove the padding added
        # to reduce fft artifacts
        image = centered(image, image_shape)
        key = (tuple(fft_shape), tuple(axes), tuple(all_axes))

        return Fourier(image, {key: image_fft})

    @property
    def image(self) -> np.ndarray:
        """The real space image"""
        return self._image

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the real space image"""
        return self._image.shape

    def fft(
        self,
        fft_shape: Sequence[int],
        axes: int | Sequence[int],
        cache: bool = True,
    ) -> np.ndarray:
        """The FFT of an image for a given `fft_shape` along desired `axes`

        The image is padded to `fft_shape` along `axes`, shifted with
        `np.fft.ifftshift` and transformed with `np.fft.rfftn`.

        Parameters
        ----------
        fft_shape:
            "Fast" shape of the image used to generate the FFT,
            with one entry per entry of `axes`.
        axes:
            The dimension(s) of the array that will be transformed.
        cache:
            Whether to store the computed FFT in ``self._fft`` for reuse.
            An already-cached entry for the same key is always returned
            regardless of this flag; ``cache=False`` only suppresses
            *adding* new entries, which prevents unbounded growth when a
            long-lived `Fourier` (e.g. a kernel stored on `Observation`)
            is convolved against many different shapes.

        Returns
        -------
        result:
            The FFT of the padded image.

        Raises
        ------
        ValueError
            If `fft_shape` and `axes` have different lengths and no
            cached entry exists for them.
        """
        if isinstance(axes, int):
            axes = (axes,)
        all_axes = range(len(self.image.shape))
        fft_key = (tuple(fft_shape), tuple(axes), tuple(all_axes))

        # The cache lookup happens before validation, so a cached
        # entry is returned without checking the lengths again.
        if fft_key in self._fft:
            return self._fft[fft_key]

        if len(fft_shape) != len(axes):
            msg = f"fft_shape self.axes must have the same number of dimensions, got {fft_shape}, {axes}"
            raise ValueError(msg)
        image = _pad(self.image, fft_shape, axes)
        result = np.fft.rfftn(np.fft.ifftshift(image, axes), axes=axes)
        if cache:
            self._fft[fft_key] = result
        return result

    def __len__(self) -> int:
        """Length of the image"""
        return len(self.image)

    def __getitem__(self, index: int | Sequence[int] | slice) -> Fourier:
        """Index the image and every cached FFT with the same `index`

        Parameters
        ----------
        index:
            The index applied to the real space image and to each
            cached FFT.

        Returns
        -------
        result:
            A new `Fourier` holding ``self.image[index]`` and the indexed
            cache entries, with the indexed axes removed from the keys.
        """
        # Make the index a tuple. numpy integer scalars are accepted as
        # well, since they are not instances of the builtin ``int``.
        if isinstance(index, (int, np.integer)):
            index = (index,)

        # Axes that are removed from the shape of the new object
        if isinstance(index, slice):
            removed = np.array([])
        else:
            # NOTE(review): every non-None entry is counted as removed,
            # including slices inside a tuple index, which keep their
            # axis -- confirm whether that is intended.
            removed = np.array([n for n, idx in enumerate(index) if idx is not None])

        # Create views into the fft transformed values, appropriately adjusting
        # the shapes for the new axes
        fft_kernels = {}
        for key, kernel in self._fft.items():
            fft_shape, axes, all_axes = key[0], key[1], key[2]
            new_key = (
                # An fft size is dropped together with the transformed axis
                # it belongs to, keeping both parts of the key aligned.
                tuple(size for size, axis in zip(fft_shape, axes) if axis not in removed),
                tuple(axis for axis in axes if axis not in removed),
                tuple(axis for axis in all_axes if axis not in removed),
            )
            fft_kernels[new_key] = kernel[index]
        # mypy doesn't recognize that tuple[int, ...]
        # is a valid Sequence[int] for some reason
        return Fourier(self.image[index], fft_kernels)  # type: ignore

389 

390 

def _kspace_operation(
    image1: Fourier,
    image2: Fourier,
    padding: int,
    op: Callable,
    shape: Sequence[int],
    axes: int | Sequence[int],
    cache: bool = True,
) -> Fourier:
    """Combine two images in k-space using a given `operator`

    Both images are transformed with a common fft shape, combined with
    `op`, and the result is converted back into a `Fourier` object.

    Parameters
    ----------
    image1:
        The LHS of the equation.
    image2:
        The RHS of the equation.
    padding:
        The amount of padding to add before transforming into k-space.
    op:
        The operator used to combine the two images.
        This is either ``operator.mul`` for a convolution
        or ``operator.truediv`` for deconvolution.
    shape:
        The shape of the output image.
    axes:
        The dimension(s) of the array that will be transformed.
    cache:
        Whether the per-shape FFT of each input should be stored on the
        respective `Fourier` instance. Forwarded to `Fourier.fft`.

    Returns
    -------
    result:
        The combined image, with the dtype of ``image1.image``.

    Raises
    ------
    ValueError
        If the two images do not have the same number of axes.
    """
    if len(image1.shape) != len(image2.shape):
        msg = (
            "Both images must have the same number of axes, "
            f"got {len(image1.shape)} and {len(image2.shape)}"
        )
        raise ValueError(msg)

    fft_shape = get_fft_shape(image1.image, image2.image, padding, axes)
    lhs = image1.fft(fft_shape, axes, cache=cache)
    rhs = image2.fft(fft_shape, axes, cache=cache)

    division_ops = (operator.truediv, operator.floordiv, operator.itruediv, operator.ifloordiv)
    if op not in division_ops:
        transformed_fft = op(lhs, rhs)
    else:
        # Broadcast along the first axis, if necessary
        if rhs.shape[0] == 1 and lhs.shape[0] != 1:
            rhs = np.tile(rhs, (lhs.shape[0],) + (1,) * (rhs.ndim - 1))
        elif lhs.shape[0] == 1 and rhs.shape[0] != 1:
            lhs = np.tile(lhs, (rhs.shape[0],) + (1,) * (lhs.ndim - 1))
        # Prevent divide by zero: only elements with a non-zero
        # denominator are divided, all others are left at zero
        nonzero = rhs != 0
        transformed_fft = np.zeros(lhs.shape, dtype=lhs.dtype)
        transformed_fft[nonzero] = op(lhs[nonzero], rhs[nonzero])

    return Fourier.from_fft(transformed_fft, fft_shape, shape, axes, image1.image.dtype)

455 

456 

def match_kernel(
    kernel1: np.ndarray | Fourier,
    kernel2: np.ndarray | Fourier,
    padding: int = 3,
    axes: int | Sequence[int] = (-2, -1),
    return_fourier: bool = True,
    normalize: None = None,
    cache: bool = True,
) -> Fourier | np.ndarray:
    """Calculate the difference kernel to match kernel1 to kernel2

    The difference kernel is obtained by dividing `kernel1` by
    `kernel2` in k-space.

    Parameters
    ----------
    kernel1:
        The first kernel, either as array or as `Fourier` object
    kernel2:
        The second kernel, either as array or as `Fourier` object
    padding:
        Additional padding to use when generating the FFT
        to suppress artifacts.
    axes:
        Axes that contain the spatial information for the kernels.
    return_fourier:
        Whether to return `Fourier` or array
    normalize:
        Deprecated and unused. Will be removed after v31.0.
    cache:
        Whether the per-shape FFTs of the inputs should be stored on the
        respective `Fourier` instances. Forwarded to `Fourier.fft`.

    Returns
    -------
    result:
        The difference kernel to go from `kernel1` to `kernel2`.
    """
    if normalize is not None:
        message = (
            "normalize is deprecated and will be removed after v31.0. "
            "It has never been used by match_kernel."
        )
        logger.warning(message)

    lhs = kernel1 if isinstance(kernel1, Fourier) else Fourier(kernel1)
    rhs = kernel2 if isinstance(kernel2, Fourier) else Fourier(kernel2)

    # The output takes the shape of whichever kernel
    # is longer along its first axis
    shape = rhs.shape if lhs.shape[0] < rhs.shape[0] else lhs.shape

    diff = _kspace_operation(lhs, rhs, padding, operator.truediv, shape, axes=axes, cache=cache)
    return diff if return_fourier else np.real(diff.image)

512 

513 

def convolve(
    image: np.ndarray | Fourier,
    kernel: np.ndarray | Fourier,
    padding: int = 3,
    axes: int | Sequence[int] = (-2, -1),
    return_fourier: bool = True,
    normalize: None = None,
    cache: bool = True,
) -> np.ndarray | Fourier:
    """Convolve image with a kernel

    The convolution is performed by multiplying the two inputs in
    k-space; the result has the shape of `image`.

    Parameters
    ----------
    image:
        Image either as array or as `Fourier` object
    kernel:
        Convolution kernel either as array or as `Fourier` object
    padding:
        Additional padding to use when generating the FFT
        to suppress artifacts.
    axes:
        Axes that contain the spatial information for the PSFs.
    return_fourier:
        Whether to return `Fourier` or array
    normalize:
        Deprecated and unused. Will be removed after v31.0.
    cache:
        Whether the per-shape FFTs of `image` and `kernel` should be
        stored on the respective `Fourier` instances. Forwarded to
        `Fourier.fft`. Long-lived `kernel` objects (e.g. those held by
        `Observation`) accumulate one cache entry per distinct image
        shape; pass ``cache=False`` when convolving with shapes that
        won't recur, to avoid unbounded growth.

    Returns
    -------
    result:
        The convolution of the image with the kernel.
    """
    if normalize is not None:
        message = (
            "normalize is deprecated and will be removed after v31.0. "
            "It has never been used by convolve."
        )
        logger.warning(message)

    lhs = image if isinstance(image, Fourier) else Fourier(image)
    rhs = kernel if isinstance(kernel, Fourier) else Fourier(kernel)

    convolved = _kspace_operation(lhs, rhs, padding, operator.mul, lhs.shape, axes=axes, cache=cache)
    return convolved if return_fourier else np.real(convolved.image)