Coverage for python/lsst/scarlet/lite/models/fit_psf.py: 28%

71 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 08:25 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["FittedPsfObservation", "FittedPsfBlend"] 

23 

24from typing import Callable, cast 

25 

26import numpy as np 

27 

28from ..bbox import Box 

29from ..blend import Blend 

30from ..fft import Fourier, centered 

31from ..fft import convolve as fft_convolve 

32from ..image import Image 

33from ..observation import Observation 

34from ..parameters import parameter 

35 

36 

class FittedPsfObservation(Observation):
    """An observation that fits the PSF used to convolve the model.

    The difference kernel created by `Observation` (``self.diff_kernel``)
    is used as the initial value of a fittable parameter. The current
    value of that parameter replaces the fixed kernel whenever the model
    is convolved on the FFT path of `convolve`.

    Parameters
    ----------
    images:
        (bands, y, x) array of observed images.
    variance:
        (bands, y, x) array of variance for each image pixel.
    weights:
        (bands, y, x) array of weights to use when calculating the
        likelihood of each pixel.
    psfs:
        (bands, y, x) array of the PSF image in each band.
    model_psf:
        (bands, y, x) array of the model PSF image in each band.
        If `model_psf` is `None` then no convolution is performed,
        which should only be done when the observation is a
        PSF matched coadd, and the scarlet model has the same PSF.
        Note that this class needs a difference kernel to fit, so a
        `model_psf` is effectively required here.
    noise_rms:
        Per-band average noise RMS. If `noise_rms` is `None` then the mean
        of the sqrt of the variance is used.
    bbox:
        The bounding box containing the model. If `bbox` is `None` then
        a `Box` is created that is the shape of `images` with an origin
        at `(0, 0)`.
    bands:
        The bands covered by the observation.
    padding:
        Padding to use when performing an FFT convolution.
    convolution_mode:
        The method of convolution. This should be either "fft" or "real".
    shape:
        Intended `(height, width)` shape of the fitted PSF kernel.
        NOTE(review): this argument is currently not used. The fitted
        kernel is initialized from ``self.diff_kernel.image`` and keeps
        that shape. It is accepted only for backward compatibility.

    Raises
    ------
    ValueError
        If the base class did not create a difference kernel
        (``self.diff_kernel`` is `None`), since there is then nothing
        to fit.
    """

    def __init__(
        self,
        images: np.ndarray | Image,
        variance: np.ndarray | Image,
        weights: np.ndarray | Image,
        psfs: np.ndarray,
        model_psf: np.ndarray | None = None,
        noise_rms: np.ndarray | None = None,
        bbox: Box | None = None,
        bands: tuple | None = None,
        padding: int = 3,
        convolution_mode: str = "fft",
        shape: tuple[int, int] | None = None,
    ):
        super().__init__(
            images,
            variance,
            weights,
            psfs,
            model_psf,
            noise_rms,
            bbox,
            bands,
            padding,
            convolution_mode,
        )

        self.axes = (-2, -1)

        # `shape` is intentionally ignored (see the class docstring): the
        # fitted kernel always starts from the difference kernel's image.
        if self.diff_kernel is None:
            # Without this check the failure would be an opaque
            # AttributeError on `None.image` below.
            raise ValueError(
                "FittedPsfObservation requires a difference kernel to fit; "
                "provide a `model_psf`."
            )

        # Make the image of the difference kernel a fittable parameter
        self._fitted_kernel = parameter(cast(Fourier, self.diff_kernel).image)

    def grad_fit_kernel(self, input_grad: np.ndarray, psf: np.ndarray, model: np.ndarray) -> np.ndarray:
        """Gradient of the loss with respect to the fitted PSF kernel.

        This is the cross correlation of the input gradient with the
        model. It is implemented as an FFT convolution of ``model`` with
        ``input_grad`` flipped along its last two axes, then cropped to
        the kernel shape.

        NOTE(review): for ``y = x * k`` the exact gradient is
        ``sum_n r[n] x[n - j]``. Convolving ``model`` with the flipped
        gradient yields that quantity at index ``-j`` (a 180 degree
        rotation). This is harmless for symmetric kernels; TODO confirm
        the intended orientation.

        Parameters
        ----------
        input_grad:
            The gradient of the loss wrt the (convolved) model, indexed
            as ``(bands, y, x)``.
        psf:
            Presumably the current value of the fitted kernel (supplied
            by the parameter update). Only its shape is used, to crop
            the result.
        model:
            The deconvolved model.

        Returns
        -------
        grad:
            The kernel gradient, with the same shape as ``psf``.
        """
        grad = cast(
            np.ndarray,
            fft_convolve(
                Fourier(model),
                Fourier(input_grad[:, ::-1, ::-1]),
                axes=(1, 2),
                return_fourier=False,
            ),
        )

        return centered(grad, psf.shape)

    def prox_kernel(self, kernel: np.ndarray) -> np.ndarray:
        """Proximal operator for the fitted kernel.

        No constraint is applied yet, so the kernel is returned unchanged.
        """
        return kernel

    @property
    def fitted_kernel(self) -> np.ndarray:
        """The current value of the fitted kernel parameter."""
        return self._fitted_kernel.x

    @property
    def cached_kernel(self) -> np.ndarray:
        """The fitted kernel flipped along its last two axes.

        This is used by `convolve` when ``grad=True``, i.e. for the
        backward (adjoint) convolution.
        """
        return self.fitted_kernel[:, ::-1, ::-1]

    def convolve(
        self,
        image: Image,
        mode: str | None = None,
        grad: bool = False,
        cache: bool = False,
    ) -> Image:
        """Convolve the model into the observed seeing in each band.

        Parameters
        ----------
        image:
            The image to convolve.
        mode:
            The convolution mode to use. Any value other than "fft" or
            `None` is delegated to `Observation.convolve`, which does not
            use the fitted kernel. Both "fft" and `None` convolve with the
            fitted kernel via an FFT, regardless of the `convolution_mode`
            given at init.
        grad:
            Whether this is a backward gradient convolution
            (`grad==True`, using the flipped kernel) or a pure
            convolution with the PSF.
        cache:
            See `Observation.convolve`. The fitted-PSF kernel is wrapped
            in a fresh `Fourier` on every call (since the kernel data
            changes during fitting). The flag is forwarded unchanged to
            the delegated or FFT convolution.

        Returns
        -------
        result:
            The convolved image, with the same bands and origin as
            ``image``.
        """
        if mode is not None and mode != "fft":
            return super().convolve(image, mode, grad, cache=cache)

        # Only the FFT path uses the fitted kernel, so select it here.
        kernel = self.cached_kernel if grad else self.fitted_kernel

        result = fft_convolve(
            Fourier(image.data),
            Fourier(kernel),
            axes=(1, 2),
            return_fourier=False,
            cache=cache,
        )
        return Image(cast(np.ndarray, result), bands=image.bands, yx0=image.yx0)

    def update(self, it: int, input_grad: np.ndarray, model: np.ndarray):
        """Update the PSF given the gradient of the loss.

        Parameters
        ----------
        it: int
            The current iteration.
        input_grad: np.ndarray
            The gradient of the loss wrt the model.
        model: np.ndarray
            The deconvolved model.
        """
        self._fitted_kernel.update(it, input_grad, model)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the fitted kernel into a `Parameter` instance.

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the object to be parameterized (here, this observation).
        """
        # Convert the fitted kernel in place
        parameterization(self)
        # Attach the kernel gradient and proximal operator to the new parameter
        self._fitted_kernel.grad = self.grad_fit_kernel
        self._fitted_kernel.prox = self.prox_kernel

222 

223 

class FittedPsfBlend(Blend):
    """A blend that attempts to fit the PSF along with the source models."""

    def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
        """Compute the likelihood gradient and record the current loss.

        Returns
        -------
        residual:
            ``weights * (model - images)``: the gradient of the loss with
            respect to the *convolved* model, in the observation frame.
            `fit` convolves it with ``grad=True`` to obtain the gradient
            with respect to the unconvolved model.
        model:
            The data of the *unconvolved* model. The PSF update
            (`FittedPsfObservation.update`) documents this argument as the
            deconvolved model, which is what the kernel gradient needs.
        """
        unconvolved = self.get_model(convolve=False)
        convolved = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(convolved))
        # Calculate the gradient wrt the convolved model d(logL)/d(model)
        residual = self.observation.weights * (convolved - self.observation.images)

        return residual, unconvolved.data

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 1,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters, including the PSF.

        Parameters
        ----------
        max_iter: int
            The maximum number of iterations.
        e_rel: float
            The relative error to use for determining convergence.
        min_iter: int
            The minimum number of iterations.
        resize: int | None
            Number of iterations between attempts to resize the resizable
            components. If `resize` is `None` (or 0) then no resizing is
            ever attempted.

        Returns
        -------
        it: int
            The iteration counter after fitting (also stored in `self.it`).
        loss: float
            The most recent loss value.
        """
        observation = cast(FittedPsfObservation, self.observation)
        it = self.it
        while it < max_iter:
            # Gradient wrt the convolved model, plus the unconvolved model
            residual, model = self._grad_log_likelihood()
            # Gradient wrt the unconvolved model
            grad_model = observation.convolve(residual, grad=True, cache=True)
            # Use the live loop counter: `self.it` is only written after
            # the loop, so checking it here would never trigger a resize.
            do_resize = bool(resize) and it > 0 and it % resize == 0
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(it, grad_model[overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)

            # Update the PSF with the same iteration number as the components
            observation.update(it, residual.data, model)

            # Stopping criteria (needs at least two recorded losses)
            it += 1
            if (
                it > min_iter
                and len(self.loss) > 1
                and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1])
            ):
                break
        self.it = it
        return it, self.loss[-1]

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the source and PSF parameter arrays into Parameter instances.

        Parameters
        ----------
        parameterization:
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        for source in self.sources:
            source.parameterize(parameterization)
        # The PSF kernel is fitted too, so it must also be parameterized
        cast(FittedPsfObservation, self.observation).parameterize(parameterization)