Coverage for python/lsst/scarlet/lite/observation.py: 17%

163 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 08:24 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["Observation", "convolve"] 

25 

26from copy import deepcopy 

27from typing import Any, cast 

28 

29import numpy as np 

30import numpy.typing as npt 

31 

32from .bbox import Box 

33from .fft import Fourier, _pad, centered 

34from .fft import convolve as fft_convolve 

35from .fft import match_kernel 

36from .image import Image 

37 

38 

39def get_filter_coords(filter_values: np.ndarray, center: tuple[int, int] | None = None) -> np.ndarray: 

40 """Create filter coordinate grid needed for the apply filter function 

41 

42 Parameters 

43 ---------- 

44 filter_values: 

45 The 2D array of the filter to apply. 

46 center: 

47 The center (y,x) of the filter. If `center` is `None` then 

48 `filter_values` must have an odd number of rows and columns 

49 and the center will be set to the center of `filter_values`. 

50 

51 Returns 

52 ------- 

53 coords: 

54 The coordinates of the pixels in `filter_values`, 

55 where the coordinates of the `center` pixel are `(0,0)`. 

56 """ 

57 if filter_values.ndim != 2: 

58 raise ValueError("`filter_values` must be 2D") 

59 if center is None: 

60 if filter_values.shape[0] % 2 == 0 or filter_values.shape[1] % 2 == 0: 

61 msg = """Ambiguous center of the `filter_values` array, 

62 you must use a `filter_values` array 

63 with an odd number of rows and columns or 

64 calculate `coords` on your own.""" 

65 raise ValueError(msg) 

66 center = tuple([filter_values.shape[0] // 2, filter_values.shape[1] // 2]) # type: ignore 

67 _x = np.arange(filter_values.shape[1]) 

68 _y = np.arange(filter_values.shape[0]) 

69 x, y = np.meshgrid(_x, _y) 

70 x -= center[1] 

71 y -= center[0] 

72 coords = np.dstack([y, x]) 

73 return coords 

74 

75 

def get_filter_bounds(coords: np.ndarray) -> tuple[int, int, int, int]:
    """Compute the per-pixel slice trims used when applying a filter.

    Parameters
    ----------
    coords:
        Flattened ``(N, 2)`` array of ``(y, x)`` filter coordinates,
        as produced by `get_filter_coords` and reshaped to two columns.

    Returns
    -------
    y_start, y_end, x_start, x_end:
        Arrays of length ``N`` giving, for each filter pixel, the start and
        end trims of the slices that are passed to `apply_filter`.
        Positive offsets trim the start of a slice, and negative offsets
        trim the end (stored as a positive amount).
    """
    # Integer zeros fix the promoted dtype of the results.
    zeros = np.zeros((len(coords),), dtype=int)
    dy = coords[:, 0]
    dx = coords[:, 1]
    return (
        np.maximum(zeros, dy),
        -np.minimum(zeros, dy),
        np.maximum(zeros, dx),
        -np.minimum(zeros, dx),
    )

98 

99 

def convolve(image: np.ndarray, psf: np.ndarray, bounds: tuple[int, int, int, int]):
    """Convolve each band of an image with its PSF in real space.

    Parameters
    ----------
    image:
        The multi-band image to convolve; the first axis indexes bands.
    psf:
        The PSF to convolve the image with, one kernel per band.
    bounds:
        The filter bounds required by the ``apply_filter`` C++ method,
        usually obtained by calling `get_filter_bounds`.

    Returns
    -------
    result:
        A new array with the same shape and dtype as `image`, filled
        band by band by ``apply_filter``.
    """
    from lsst.scarlet.lite.operators_pybind11 import apply_filter  # type: ignore

    output = np.empty(image.shape, dtype=image.dtype)
    for band, band_image in enumerate(image):
        # The C++ filter takes the kernel as a flat array and writes its
        # output into the provided view of `output`.
        apply_filter(
            band_image,
            psf[band].reshape(-1),
            bounds[0],
            bounds[1],
            bounds[2],
            bounds[3],
            output[band],
        )
    return output

129 

130 

def _set_image_like(images: np.ndarray | Image, bands: tuple | None = None, bbox: Box | None = None) -> Image:
    """Ensure that an image-like array is cast appropriately as an image

    Parameters
    ----------
    images:
        The multiband image-like array to cast as an Image.
        If it is already an `Image` it is returned with no modifications.
    bands:
        The bands for the multiband image. Only used when `images` is a
        numpy array; ignored when `images` is already an `Image`.
    bbox:
        Bounding box containing the image.
        If `images` is an `Image`, `bbox` must be `None` or equal to
        ``images.bbox``.
        If `images` is a numpy array and `bbox` is `None`, a `Box` with the
        shape of the last two axes of `images` is created. Otherwise `bbox`
        must have that same shape, and its origin becomes the image origin.

    Returns
    -------
    images: Image
        The input images converted into an image.

    Raises
    ------
    ValueError
        If `bbox` disagrees with the bounding box of an `Image` input, or
        if its shape does not match the spatial shape of an array input.
    """
    if isinstance(images, Image):
        # This is already an image
        if bbox is not None and images.bbox != bbox:
            raise ValueError(f"Bounding boxes {images.bbox} and {bbox} do not agree")
        return images

    if bbox is None:
        bbox = Box(images.shape[-2:])
    elif tuple(bbox.shape) != tuple(images.shape[-2:]):
        # Without this check the returned Image would silently have a
        # bounding box that differs from the requested `bbox`.
        raise ValueError(
            f"Bounding box shape {tuple(bbox.shape)} does not match "
            f"the image shape {tuple(images.shape[-2:])}"
        )
    return Image(images, bands=bands, yx0=cast(tuple[int, int], bbox.origin))

165 

166 

class Observation:
    """A single observation

    This class contains all of the observed images and derived
    properties, like PSFs, variance map, and weight maps,
    required for most optimizers.
    This includes methods to match a scarlet model PSF to the observed PSF
    in each band.

    Notes
    -----
    This is effectively a combination of the `Observation` and
    `Renderer` class from scarlet main, greatly simplified due
    to the assumptions that the observations are all resampled
    onto the same pixel grid and that the `images` contain all
    of the information for all of the model bands.

    Parameters
    ----------
    images:
        (bands, y, x) array of observed images.
    variance:
        (bands, y, x) array of variance for each image pixel.
    weights:
        (bands, y, x) array of weights to use when calculating the
        likelihood of each pixel.
    psfs:
        (bands, y, x) array of the PSF image in each band.
    model_psf:
        (bands, y, x) array of the model PSF image in each band.
        If `model_psf` is `None` then no convolution is performed
        (`convolve` returns its input unchanged), which should only be
        done when the observation is a PSF matched coadd, and the scarlet
        model has the same PSF.
    noise_rms:
        Per-band average noise RMS. If `noise_rms` is `None` then the mean
        of the sqrt of the finite variance pixels is used.
    bbox:
        The bounding box containing the model. If `bbox` is `None` then
        a `Box` is created that is the shape of `images` with an origin
        at `(0, 0)`.
    bands:
        The bands of the observation. Only used when `images` is a numpy
        array; when `images` is already an `Image` its bands are used.
    padding:
        Padding to use when performing an FFT convolution.
    convolution_mode:
        The method of convolution. This should be either "fft" or "real".
    """

    def __init__(
        self,
        images: np.ndarray | Image,
        variance: np.ndarray | Image,
        weights: np.ndarray | Image,
        psfs: np.ndarray,
        model_psf: np.ndarray | None = None,
        noise_rms: np.ndarray | None = None,
        bbox: Box | None = None,
        bands: tuple | None = None,
        padding: int = 3,
        convolution_mode: str = "fft",
    ):
        # Convert the images to a multi-band `Image` and use the resulting
        # bbox and bands.
        images = _set_image_like(images, bands, bbox)
        bands = images.bands
        bbox = images.bbox
        self.images = images
        self.variance = _set_image_like(variance, bands, bbox)
        self.weights = _set_image_like(weights, bands, bbox)
        # make sure that the images and psfs have the same dtype
        if psfs.dtype != images.dtype:
            psfs = psfs.astype(images.dtype)
        self.psfs = psfs

        if convolution_mode not in [
            "fft",
            "real",
        ]:
            raise ValueError("convolution_mode must be either 'fft' or 'real'")
        self.mode = convolution_mode
        if noise_rms is None:
            noise_rms = np.array([np.mean(np.sqrt(v[np.isfinite(v)])) for v in self.variance.data])
        self.noise_rms = noise_rms

        # Match the model PSF dtype to the images *before* building the
        # kernels, so the difference kernel is computed in the image dtype.
        if model_psf is not None and model_psf.dtype != images.dtype:
            model_psf = model_psf.astype(images.dtype)
        self.model_psf = model_psf
        self.padding = padding

        # Create a difference kernel to convolve the model to the PSF
        # in each band
        if model_psf is not None:
            self.diff_kernel: Fourier | None = cast(Fourier, match_kernel(psfs, model_psf, padding=padding))
            # The gradient of a convolution is another convolution,
            # but with the flipped and transposed kernel.
            diff_img = self.diff_kernel.image
            self.grad_kernel: Fourier | None = Fourier(diff_img[:, ::-1, ::-1])
        else:
            self.diff_kernel = None
            self.grad_kernel = None

        # Lazily computed by the `convolution_bounds` property.
        self._convolution_bounds: tuple[int, int, int, int] | None = None

    @property
    def bands(self) -> tuple:
        """The bands in the observations."""
        return self.images.bands

    @property
    def bbox(self) -> Box:
        """The bounding box for the full observation."""
        return self.images.bbox

    def convolve(
        self,
        image: Image,
        mode: str | None = None,
        grad: bool = False,
        cache: bool = False,
    ) -> Image:
        """Convolve the model into the observed seeing in each band.

        Parameters
        ----------
        image:
            The 3D image to convolve.
        mode:
            The convolution mode to use.
            This should be "real" or "fft" or `None`,
            where `None` will use the default `convolution_mode`
            specified during init.
        grad:
            Whether this is a backward gradient convolution
            (`grad==True`) or a pure convolution with the PSF.
        cache:
            Whether to cache the FFT of the kernel at this image's shape.
            Defaults to ``False`` because most call sites convolve
            many different shapes (per-source / per-component) and would
            grow `diff_kernel._fft` unboundedly. Pass ``cache=True``
            for repeated full-blend convolutions (e.g. inside the fit
            loop), where the same shape recurs every iteration.
            Ignored for ``mode == "real"``.

        Returns
        -------
        result:
            The convolved image. If the observation has no model PSF,
            `image` itself is returned unchanged.

        Raises
        ------
        ValueError
            If `mode` is neither "fft" nor "real".
        """
        if grad:
            kernel = self.grad_kernel
        else:
            kernel = self.diff_kernel

        if kernel is None:
            return image

        if mode is None:
            mode = self.mode
        if mode == "fft":
            result = fft_convolve(
                Fourier(image.data),
                kernel,
                axes=(1, 2),
                return_fourier=False,
                cache=cache,
            )
        elif mode == "real":
            dy = image.shape[1] - kernel.image.shape[1]
            dx = image.shape[2] - kernel.image.shape[2]
            if dy < 0 or dx < 0:
                # The image needs to be padded because it is smaller than
                # the psf kernel
                _image = image.data
                newshape = list(_image.shape)
                if dy < 0:
                    newshape[1] += kernel.image.shape[1] - image.shape[1]
                if dx < 0:
                    newshape[2] += kernel.image.shape[2] - image.shape[2]
                _image = _pad(_image, newshape)
                result = convolve(_image, kernel.image, self.convolution_bounds)
                # Crop back to the shape of the unpadded input.
                result = centered(result, image.data.shape)  # type: ignore
            else:
                result = convolve(image.data, kernel.image, self.convolution_bounds)
        else:
            raise ValueError(f"mode must be either 'fft' or 'real', got {mode}")
        return Image(cast(np.ndarray, result), bands=image.bands, yx0=image.yx0)

    def log_likelihood(self, model: Image) -> float:
        """Calculate the log likelihood of the given model

        The result is ``-0.5 * sum(weights * (images - model) ** 2)``.

        Parameters
        ----------
        model:
            Model to compare with the observed images.

        Returns
        -------
        result:
            The log-likelihood of the given model.
        """
        result = 0.5 * -np.sum((self.weights * (self.images - model) ** 2).data)
        return result

    def __getitem__(self, indices: Any) -> Observation:
        """Create a new observation from a subset of this one

        Parameters
        ----------
        indices:
            The indices to select a subsection of the images, variance,
            and weights.

        Returns
        -------
        result:
            A new `Observation` built from the subsets of the images
            selected by `indices`, with the PSFs, noise RMS, and (when it
            has one entry per band) the model PSF restricted to the
            selected bands.
        """
        new_image = self.images[indices]
        new_variance = self.variance[indices]
        new_weights = self.weights[indices]

        # If the indices is a single band, make sure to keep the band axis
        if new_image.ndim == 2:
            if indices in self.bands:
                new_bands = (indices,)
            else:
                # The indices contain spatial and band indices
                new_bands = (indices[0],)
            new_image = Image(
                new_image.data[None, :, :],
                yx0=new_image.yx0,
                bands=new_bands,
            )
            new_variance = Image(
                new_variance.data[None, :, :],
                yx0=new_variance.yx0,
                bands=new_bands,
            )
            new_weights = Image(
                new_weights.data[None, :, :],
                yx0=new_weights.yx0,
                bands=new_bands,
            )

        # Extract the appropriate bands from the PSF
        bands = self.images.bands
        new_bands = new_image.bands
        model_psf = self.model_psf
        if bands != new_bands:
            band_indices = self.images.spectral_indices(new_bands)
            psfs = self.psfs[band_indices,]
            noise_rms = self.noise_rms[band_indices,]
            # A per-band (bands, y, x) model PSF must be subset together with
            # `psfs`, otherwise the two band axes no longer line up when the
            # difference kernel is rebuilt.
            if (
                model_psf is not None
                and model_psf.ndim == 3
                and len(bands) > 1
                and model_psf.shape[0] == len(bands)
            ):
                model_psf = model_psf[band_indices,]
        else:
            psfs = self.psfs
            noise_rms = self.noise_rms

        return Observation(
            images=new_image,
            variance=new_variance,
            weights=new_weights,
            psfs=psfs,
            model_psf=model_psf,
            noise_rms=noise_rms,
            bbox=new_image.bbox,
            bands=new_bands,
            padding=self.padding,
            convolution_mode=self.mode,
        )

    def __copy__(self) -> Observation:
        """Create a copy of the observation

        The new observation references the same image, PSF and noise
        objects as this one; only the derived kernels are rebuilt.

        Returns
        -------
        result:
            The copy of the observation.
        """
        return Observation(
            images=self.images,
            variance=self.variance,
            weights=self.weights,
            psfs=self.psfs,
            model_psf=self.model_psf,
            noise_rms=self.noise_rms,
            bands=self.bands,
            padding=self.padding,
            convolution_mode=self.mode,
        )

    def __deepcopy__(self, memo: dict[int, Any]) -> Observation:
        """Create a deep copy of the observation

        Parameters
        ----------
        memo: dict[int, Any]
            The memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        result:
            The deep copy of the observation.
        """
        # Check if already copied
        if id(self) in memo:
            return memo[id(self)]

        # Create placeholder and add to memo FIRST so that recursive
        # references to this observation resolve to the copy.
        result = Observation.__new__(Observation)
        memo[id(self)] = result

        # Now safely initialize the placeholder with deepcopied arguments
        result.__init__(  # type: ignore[misc]
            images=deepcopy(self.images, memo),
            variance=deepcopy(self.variance, memo),
            weights=deepcopy(self.weights, memo),
            psfs=deepcopy(self.psfs, memo),
            model_psf=deepcopy(self.model_psf, memo),
            noise_rms=deepcopy(self.noise_rms, memo),
            bands=deepcopy(self.bands, memo),
            padding=deepcopy(self.padding, memo),
            convolution_mode=self.mode,
        )

        return result

    def copy(self, deep: bool = False) -> Observation:
        """Create a copy of the observation

        Parameters
        ----------
        deep:
            Whether to perform a deep copy or not.

        Returns
        -------
        result:
            The copy of the observation.
        """
        if deep:
            return self.__deepcopy__({})
        return self.__copy__()

    @property
    def shape(self) -> tuple[int, int, int]:
        """The shape of the images, variance, etc."""
        return cast(tuple[int, int, int], self.images.shape)

    @property
    def n_bands(self) -> int:
        """The number of bands in the observation"""
        return self.images.shape[0]

    @property
    def dtype(self) -> npt.DTypeLike:
        """The dtype of the observation is the dtype of the images"""
        return self.images.dtype

    @property
    def convolution_bounds(self) -> tuple[int, int, int, int]:
        """Build the slices needed for convolution in real space

        The bounds are computed from the first band of `diff_kernel` on
        first access and cached afterwards. This requires the observation
        to have a model PSF (so that `diff_kernel` is not `None`).
        """
        if self._convolution_bounds is None:
            coords = get_filter_coords(cast(Fourier, self.diff_kernel).image[0])
            self._convolution_bounds = get_filter_bounds(coords.reshape(-1, 2))
        return self._convolution_bounds

    @staticmethod
    def empty(
        bands: tuple[Any, ...], psfs: np.ndarray, model_psf: np.ndarray, bbox: Box, dtype: npt.DTypeLike
    ) -> Observation:
        """Create an observation with all-zero images, variance and weights.

        Parameters
        ----------
        bands:
            The bands of the observation.
        psfs:
            (bands, y, x) array of the PSF image in each band.
        model_psf:
            The model PSF image(s).
        bbox:
            The bounding box of the observation; its shape sets the
            spatial shape of the zero images.
        dtype:
            The dtype of the zero images and noise RMS.

        Returns
        -------
        result:
            An observation using real-space convolution with a zero
            per-band noise RMS.
        """
        shape = (len(bands),) + bbox.shape

        # Allocate a separate buffer for each plane so that writing into
        # one of them (e.g. filling in the images) does not also modify
        # the variance and weights.
        return Observation(
            images=np.zeros(shape, dtype=dtype),
            variance=np.zeros(shape, dtype=dtype),
            weights=np.zeros(shape, dtype=dtype),
            psfs=psfs,
            model_psf=model_psf,
            noise_rms=np.zeros((len(bands),), dtype=dtype),
            bbox=bbox,
            bands=bands,
            convolution_mode="real",
        )