Coverage for python/lsst/scarlet/lite/models/fit_psf.py: 28%
71 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 08:24 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 08:24 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ["FittedPsfObservation", "FittedPsfBlend"]
24from typing import Callable, cast
26import numpy as np
28from ..bbox import Box
29from ..blend import Blend
30from ..fft import Fourier, centered
31from ..fft import convolve as fft_convolve
32from ..image import Image
33from ..observation import Observation
34from ..parameters import parameter
class FittedPsfObservation(Observation):
    """An observation that fits the PSF used to convolve the model.

    The kernel used to convolve the model is wrapped in a fittable
    parameter, so that it can be updated alongside the source models.

    Parameters
    ----------
    images:
        (bands, y, x) array of observed images.
    variance:
        (bands, y, x) array of variance for each image pixel.
    weights:
        (bands, y, x) array of weights to use when calculating the
        likelihood of each pixel.
    psfs:
        (bands, y, x) array of the PSF image in each band.
    model_psf:
        (bands, y, x) array of the model PSF image in each band.
        If `model_psf` is `None` then convolution is performed,
        which should only be done when the observation is a
        PSF matched coadd, and the scarlet model has the same PSF.
        NOTE(review): ``__init__`` reads ``self.diff_kernel.image``
        unconditionally; confirm that the base class provides a kernel
        when `model_psf` is `None`.
    noise_rms:
        Per-band average noise RMS. If `noise_rms` is `None` then the mean
        of the sqrt of the variance is used.
    bbox:
        The bounding box containing the model. If `bbox` is `None` then
        a `Box` is created that is the shape of `images` with an origin
        at `(0, 0)`.
    bands:
        The bands covered by the observation.
    padding:
        Padding to use when performing an FFT convolution.
    convolution_mode:
        The method of convolution. This should be either "fft" or "real".
    shape:
        Accepted for backward compatibility but currently ignored:
        the fitted kernel is initialised from ``self.diff_kernel.image``
        and therefore takes its shape from that array.
    """

    def __init__(
        self,
        images: np.ndarray | Image,
        variance: np.ndarray | Image,
        weights: np.ndarray | Image,
        psfs: np.ndarray,
        model_psf: np.ndarray | None = None,
        noise_rms: np.ndarray | None = None,
        bbox: Box | None = None,
        bands: tuple | None = None,
        padding: int = 3,
        convolution_mode: str = "fft",
        shape: tuple[int, int] | None = None,
    ):
        super().__init__(
            images,
            variance,
            weights,
            psfs,
            model_psf,
            noise_rms,
            bbox,
            bands,
            padding,
            convolution_mode,
        )

        # The last two axes of the (bands, y, x) arrays.
        self.axes = (-2, -1)

        # `shape` is intentionally not used (see the class docstring);
        # the previous code only assigned a default to it and then
        # discarded the value.

        # Wrap the kernel image in a fittable parameter.
        self._fitted_kernel = parameter(cast(Fourier, self.diff_kernel).image)

    def grad_fit_kernel(self, input_grad: np.ndarray, psf: np.ndarray, model: np.ndarray) -> np.ndarray:
        """Gradient of the loss wrt the PSF

        This is just the cross correlation of the input gradient
        with the model, implemented as an FFT convolution of the model
        with the spatially flipped gradient.

        Parameters
        ----------
        input_grad:
            The gradient of the loss wrt the model
        psf:
            The PSF of the model. Only its shape is used, to crop
            the result.
        model:
            The deconvolved model.

        Returns
        -------
        grad:
            The gradient, cropped with `centered` to ``psf.shape``.
        """
        # Flipping both spatial axes turns the convolution into
        # a cross correlation.
        flipped_grad = input_grad[:, ::-1, ::-1]
        grad = cast(
            np.ndarray,
            fft_convolve(
                Fourier(model),
                Fourier(flipped_grad),
                axes=(1, 2),
                return_fourier=False,
            ),
        )

        return centered(grad, psf.shape)

    def prox_kernel(self, kernel: np.ndarray) -> np.ndarray:
        """Proximal operator for the fitted kernel.

        No constraint is applied for now; `kernel` is returned unchanged.
        """
        return kernel

    @property
    def fitted_kernel(self) -> np.ndarray:
        """The current value (``x``) of the fitted kernel parameter."""
        return self._fitted_kernel.x

    @property
    def cached_kernel(self) -> np.ndarray:
        """The fitted kernel flipped along both spatial axes.

        This is the kernel used by `convolve` when ``grad=True``.
        Despite the name nothing is stored: the flipped view is rebuilt
        from `fitted_kernel` on every access.
        """
        return self.fitted_kernel[:, ::-1, ::-1]

    def convolve(
        self,
        image: Image,
        mode: str | None = None,
        grad: bool = False,
        cache: bool = False,
    ) -> Image:
        """Convolve the model into the observed seeing in each band.

        Parameters
        ----------
        image:
            The image to convolve
        mode:
            The convolution mode to use.
            ``"fft"`` and `None` both convolve with the fitted kernel
            using an FFT; note that `None` selects the FFT path here
            regardless of the `convolution_mode` given at construction.
            Any other value is forwarded unchanged to
            `Observation.convolve`, which is not handed the fitted
            kernel by this method.
        grad:
            Whether this is a backward gradient convolution
            (`grad==True`) or a pure convolution with the PSF.
        cache:
            See `Observation.convolve`. The fitted-PSF kernel is wrapped
            in a fresh `Fourier` on every call (since the kernel data
            changes during fitting), so this flag mainly serves to
            propagate caller intent through to `super().convolve` when
            delegating for non-FFT modes.

        Returns
        -------
        result:
            The convolved image, with the same bands and origin as `image`.
        """
        # Delegate before selecting a kernel: the base implementation
        # never receives the kernel chosen below.
        if mode is not None and mode != "fft":
            return super().convolve(image, mode, grad, cache=cache)

        # The backward pass uses the spatially flipped kernel.
        kernel = self.cached_kernel if grad else self.fitted_kernel

        result = fft_convolve(
            Fourier(image.data),
            Fourier(kernel),
            axes=(1, 2),
            return_fourier=False,
            cache=cache,
        )
        return Image(cast(np.ndarray, result), bands=image.bands, yx0=image.yx0)

    def update(self, it: int, input_grad: np.ndarray, model: np.ndarray):
        """Update the PSF given the gradient of the loss

        Parameters
        ----------
        it: int
            The current iteration
        input_grad: np.ndarray
            The gradient of the loss wrt the model
        model: np.ndarray
            The deconvolved model.
        """
        self._fitted_kernel.update(it, input_grad, model)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
            Here it is called with this observation as its argument.
        """
        # Update the fitted kernel in place
        parameterization(self)
        # Attach the gradient and proximal operators afterwards, so they
        # are bound to whatever parameter `parameterization` left in place.
        self._fitted_kernel.grad = self.grad_fit_kernel
        self._fitted_kernel.prox = self.prox_kernel
class FittedPsfBlend(Blend):
    """A blend that attempts to fit the PSF along with the source models."""

    def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
        """Evaluate the loss and the weighted residual of the model.

        The convolved model is built, its log-likelihood is appended to
        ``self.loss``, and the weighted residual is computed.

        Returns
        -------
        residual:
            ``weights * (model - images)`` for the convolved model.
        model:
            The data array of the convolved model.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        # Calculate the gradient wrt the model d(logL)/d(model)
        residual = self.observation.weights * (model - self.observation.images)

        return residual, model.data

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 1,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters

        Parameters
        ----------
        max_iter: int
            The maximum number of iterations
        e_rel: float
            The relative error to use for determining convergence.
        min_iter: int
            The minimum number of iterations.
        resize: int
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` (or ``0``) then
            no resizing is ever attempted.

        Returns
        -------
        it:
            The iteration counter after fitting (also stored in
            ``self.it``).
        loss:
            The most recent entry of ``self.loss``.
        """
        observation = cast(FittedPsfObservation, self.observation)
        it = self.it
        while it < max_iter:
            # Calculate the gradient wrt the un-convolved model
            grad_log_likelihood, model = self._grad_log_likelihood()
            _grad_log_likelihood = self.observation.convolve(grad_log_likelihood, grad=True, cache=True)
            # Check if resizing needs to be performed in this iteration.
            # The live counter `it` is used: `self.it` is only written
            # after the loop, so it would be stale here. The truthiness
            # test on `resize` also keeps `resize=0` away from the modulo.
            do_resize = bool(resize) and it > 0 and it % resize == 0
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(it, _grad_log_likelihood[overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)

            # Update the PSF with the same iteration index that the
            # components received.
            # NOTE(review): `model` is the convolved model returned by
            # `_grad_log_likelihood`, whereas `FittedPsfObservation.update`
            # documents its argument as the deconvolved model - confirm
            # which one is intended.
            observation.update(
                it,
                grad_log_likelihood.data,
                model,
            )
            # Stopping criteria
            it += 1
            # Convergence needs two loss values to compare.
            if (
                it > min_iter
                and len(self.loss) > 1
                and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1])
            ):
                break
        self.it = it
        return it, self.loss[-1]

    def parameterize(self, parameterization: Callable):
        """Convert the component parameter arrays into Parameter instances

        Both the sources and the observation (for its fitted PSF kernel)
        are parameterized.

        Parameters
        ----------
        parameterization:
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        for source in self.sources:
            source.parameterize(parameterization)
        cast(FittedPsfObservation, self.observation).parameterize(parameterization)