Coverage for python / lsst / scarlet / lite / blend.py: 20%
171 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 07:47 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 07:47 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["Blend"]
26from abc import ABC, abstractmethod
27from copy import deepcopy
28from typing import TYPE_CHECKING, Any, Callable, Self, Sequence, cast
30import numpy as np
32from .bbox import Box
33from .component import Component, FactorizedComponent
34from .image import Image
35from .observation import Observation
36from .source import Source, SourceBase
38if TYPE_CHECKING:
39 from .io import ScarletBlendData, ScarletSourceBaseData
class BlendBase(ABC):
    """Abstract interface shared by every blend implementation.

    A blend groups a collection of sources with the single observation
    they are modeled against. This base class derives the blend geometry
    (``shape`` and ``bbox``) from the observation, exposes a flattened view
    of every source's components, and dispatches ``copy`` to the copy
    protocol. Slicing, copying, model rendering and serialization are left
    to concrete subclasses.

    Parameters
    ----------
    sources:
        The sources to fit.
    observation:
        The observation that contains the images,
        PSF, etc. that are being fit.
    metadata:
        Additional metadata to store with the blend.
    """

    sources: Sequence[SourceBase]
    observation: Observation
    metadata: dict | None

    @property
    def shape(self) -> tuple[int, int, int]:
        """Shape of the full blend model, as reported by the observation."""
        obs = self.observation
        return obs.shape

    @property
    def bbox(self) -> Box:
        """Bounding box covering the whole blend, taken from the observation."""
        obs = self.observation
        return obs.bbox

    @property
    def components(self) -> list[Component]:
        """Every component of every source, flattened in source order.

        A new list is assembled on each access so that it always reflects
        the current contents of ``sources`` (which may be edited).
        """
        flattened: list[Component] = []
        for source in self.sources:
            flattened.extend(source.components)
        return flattened

    @abstractmethod
    def __getitem__(self, indices: Any) -> Self:
        """Return a sub-blend restricted to the requested indices.

        Parameters
        ----------
        indices :
            The indices used to slice the blend.

        Returns
        -------
        sub_blend :
            A new `BlendBase` instance holding only the data for the
            requested bands, in the requested order.

        Raises
        ------
        IndexError :
            If the indices reference bands that are not part of the
            original blend, or if any spatial indices are supplied.
        """

    @abstractmethod
    def __copy__(self) -> Self:
        """Return a shallow copy of this blend.

        Returns
        -------
        blend : BlendBase
            The new, shallow-copied blend.
        """

    @abstractmethod
    def __deepcopy__(self, memo: dict[int, Any]) -> Self:
        """Return a deep copy of this blend.

        Parameters
        ----------
        memo : dict[int, Any]
            Memoization table maintained by `copy.deepcopy`.

        Returns
        -------
        blend : BlendBase
            The new, deep-copied blend.
        """

    def copy(self, deep: bool = False) -> Self:
        """Copy this blend, either shallowly or deeply.

        Parameters
        ----------
        deep :
            When `True` a deep copy is produced (via ``__deepcopy__`` with a
            fresh memo); otherwise a shallow copy (via ``__copy__``).
            Defaults to `False`.

        Returns
        -------
        blend : Self
            The copied blend.
        """
        return self.__deepcopy__({}) if deep else self.__copy__()

    @abstractmethod
    def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
        """Render a model image for the whole blend.

        Parameters
        ----------
        convolve:
            Whether to convolve the model with the observed PSF in each band.
        use_flux:
            Whether to build the model from the re-distributed flux stored
            on the sources rather than from the component models.

        Returns
        -------
        model:
            The combined model of all sources in the blend.
        """

    @abstractmethod
    def to_data(self) -> ScarletBlendData:
        """Serialize the blend.

        Returns
        -------
        data:
            A data object holding everything needed to reconstruct the blend.
        """
class Blend(BlendBase):
    """A single blend.

    This class holds all of the sources and the observation that are to be
    fit, and it performs the fitting and the joint initialization of the
    spectral components (when applicable).

    Parameters
    ----------
    sources:
        The sources to fit.
    observation:
        The observation that contains the images,
        PSF, etc. that are being fit.
    metadata:
        Additional metadata to store with the blend. An empty dictionary
        is stored as `None`.
    """

    sources: list[Source]

    def __init__(self, sources: Sequence[Source], observation: Observation, metadata: dict | None = None):
        self.sources = list(sources)
        self.observation = observation
        # Normalize an empty metadata dict to ``None`` so that "no metadata"
        # has a single representation.
        if metadata is not None and len(metadata) == 0:
            metadata = None
        self.metadata = metadata

        # Initialize the iteration count and loss function
        self.it = 0
        self.loss: list[float] = []

    def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
        """Generate a model of the entire blend.

        Parameters
        ----------
        convolve:
            Whether to convolve the model with the observed PSF in each band.
            This only applies to the model built from the components; it is
            ignored when ``use_flux`` is `True`, because the flux-weighted
            images are built from the observed images.
        use_flux:
            Whether to use the re-distributed flux associated with the sources
            instead of the component models.

        Returns
        -------
        model:
            The model created by combining all of the source models.

        Raises
        ------
        ValueError
            If ``use_flux`` is `True` and any source has no
            ``flux_weighted_image`` (run `conserve_flux` first).
        """
        model = Image(
            np.zeros(self.shape, dtype=self.observation.images.dtype),
            bands=self.observation.bands,
            yx0=cast(tuple[int, int], self.observation.bbox.origin[-2:]),
        )

        if use_flux:
            for src in self.sources:
                if src.flux_weighted_image is None:
                    raise ValueError(
                        "Some sources do not have 'flux' attribute set. Run measure.conserve_flux"
                    )
                src.flux_weighted_image.insert_into(model)
        else:
            for component in self.components:
                component.get_model().insert_into(model)
            if convolve:
                return self.observation.convolve(model, cache=True)
        return model

    def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
        """Gradient of the likelihood with respect to the unconvolved model.

        As a side effect the log-likelihood of the current (convolved) model
        is appended to ``self.loss``.

        Returns
        -------
        result:
            The gradient of the likelihood with respect to the model.
        model_data:
            The convolved model data used to calculate the gradient.
            This can be useful for debugging but is not used in
            production.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        # Calculate the gradient wrt the model d(logL)/d(model)
        result = self.observation.weights * (model - self.observation.images)
        # Propagate the gradient back through the PSF convolution
        result = self.observation.convolve(result, grad=True, cache=True)
        return result, model.data

    @property
    def log_likelihood(self) -> float:
        """The current log-likelihood.

        This is calculated on the fly to ensure that it is always up to date
        with the current model parameters.
        """
        return self.observation.log_likelihood(self.get_model(convolve=True))

    def fit_spectra(self, clip: bool = False) -> Blend:
        """Fit all of the spectra given their current morphologies with a
        linear least squares algorithm.

        Components that expose both ``morph`` and ``spectrum`` are treated as
        factorized and have their spectra refit. All other components are
        held fixed and contribute a (convolved) background model to the fit.

        Parameters
        ----------
        clip:
            Whether or not to remove components that were not
            assigned any positive flux during the fit.

        Returns
        -------
        blend:
            The blend with updated components is returned.
        """
        from .initialization import multifit_spectra

        morphs = []
        # Bounding box of each entry in ``morphs``. It is collected together
        # with the morphology so the two lists stay aligned even when
        # non-factorized components are interleaved in ``components``.
        morph_boxes = []
        # Position in ``components`` of each entry in ``morphs``
        factorized_indices = []
        model = Image.from_box(
            self.observation.bbox,
            bands=self.observation.bands,
            dtype=self.observation.dtype,
        )
        components = self.components
        for idx, component in enumerate(components):
            if hasattr(component, "morph") and hasattr(component, "spectrum"):
                component = cast(FactorizedComponent, component)
                morphs.append(component.morph)
                morph_boxes.append(component.bbox)
                factorized_indices.append(idx)
            else:
                # Components that are not fit contribute a fixed model
                model.insert(component.get_model())
        model = self.observation.convolve(model, mode="real", cache=True)

        # Nothing to fit when there are no factorized components
        if morphs:
            fitted_spectra = multifit_spectra(
                self.observation,
                [
                    Image(morph, yx0=cast(tuple[int, int], bbox.origin))
                    for morph, bbox in zip(morphs, morph_boxes)
                ],
                model,
            )
            for fit_idx, component_idx in enumerate(factorized_indices):
                component = cast(FactorizedComponent, components[component_idx])
                component.spectrum[:] = fitted_spectra[fit_idx]
                component.spectrum[component.spectrum < 0] = 0

        # Run the proxes for all of the components to make sure that the
        # spectra are consistent with the constraints.
        # In practice this usually means making sure that they are
        # non-negative.
        for src in self.sources:
            for component in src.components:
                if (
                    hasattr(component, "spectrum")
                    and hasattr(component, "prox_spectrum")
                    and component.prox_spectrum is not None  # type: ignore
                ):
                    component.prox_spectrum(component.spectrum)  # type: ignore

        if clip:
            # Remove components with no positive flux
            for src in self.sources:
                _components = []
                for component in src.components:
                    component_model = component.get_model()
                    component_model.data[component_model.data < 0] = 0
                    if np.sum(component_model.data) > 0:
                        _components.append(component)
                src.components = _components

        return self

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 15,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters.

        Iteration continues from the current value of ``self.it``, so calling
        `fit` again resumes a previous fit rather than restarting it.

        Parameters
        ----------
        max_iter:
            The maximum number of iterations (counted from zero, including
            any iterations already run on this blend).
        e_rel:
            The relative error to use for determining convergence.
        min_iter:
            The minimum number of iterations.
        resize:
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` (or not positive)
            then no resizing is ever attempted.

        Returns
        -------
        it:
            Number of iterations.
        loss:
            Loss for the last solution. If no iteration has ever been run
            (e.g. ``max_iter <= 0`` on a new blend) the log-likelihood of
            the current model is returned instead.
        """
        # A non-positive interval would divide by zero below, so treat it
        # the same as ``None``.
        resize_enabled = resize is not None and resize > 0
        while self.it < max_iter:
            # Calculate the gradient wrt the un-convolved model
            gradient = self._grad_log_likelihood()[0]
            do_resize = resize_enabled and self.it > 0 and self.it % resize == 0  # type: ignore[operator]
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(self.it, gradient[overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)
            # Stopping criteria: relative change between the two most recent
            # losses. The length check avoids an IndexError when
            # ``min_iter < 1``.
            self.it += 1
            if (
                self.it > min_iter
                and len(self.loss) > 1
                and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1])
            ):
                break
        if not self.loss:
            return self.it, self.log_likelihood
        return self.it, self.loss[-1]

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances.

        Parameters
        ----------
        parameterization:
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        for source in self.sources:
            source.parameterize(parameterization)

    def conserve_flux(self, mask_footprint: bool = True, weight_image: Image | None = None) -> None:
        """Use the source models as templates to re-distribute flux
        from the data.

        The source models are used as approximations to the data,
        which redistribute the flux in the data according to the
        ratio of the models for each source.
        There is no return value for this function;
        instead it sets a ``flux_weighted_image``
        attribute on each of the sources with the flux attributed to
        that source.

        Parameters
        ----------
        mask_footprint:
            Whether or not to zero the data in pixels with zero weight.
        weight_image:
            The weight image to use for the redistribution.
            If `None` then the absolute value of the PSF-convolved
            blend model is used.
        """
        observation = self.observation
        # Half-size of the PSF, used to grow each source box by the PSF wings
        py = observation.psfs.shape[-2] // 2
        px = observation.psfs.shape[-1] // 2

        images = observation.images.copy()
        if mask_footprint:
            images.data[observation.weights.data == 0] = 0

        if weight_image is None:
            weight_image = self.get_model()
            # Always convolve in real space to avoid FFT artifacts
            weight_image = observation.convolve(weight_image, mode="real", cache=True)

            # Due to ringing in the PSF, the convolved model can have
            # negative values. We take the absolute value to avoid
            # negative fluxes in the flux weighted images.
            weight_image.data[:] = np.abs(weight_image.data)

        for src in self.sources:
            if src.is_null:
                src.flux_weighted_image = Image.from_box(Box((0, 0)), bands=observation.bands)  # type: ignore
                continue
            src_model = src.get_model()

            # Grow the model to include the wings of the PSF
            src_box = src.bbox.grow((py, px))
            overlap = observation.bbox & src_box
            src_model = src_model.project(bbox=overlap)
            src_model = observation.convolve(src_model, mode="real")
            src_model.data[:] = np.abs(src_model.data)
            numerator = src_model.data
            denominator = weight_image[overlap].data
            # Pixels with zero total weight receive no flux
            cuts = denominator != 0
            ratio = np.zeros(numerator.shape, dtype=numerator.dtype)
            ratio[cuts] = numerator[cuts] / denominator[cuts]
            # sometimes numerical errors can cause a hot pixel to have a
            # slightly higher ratio than 1
            ratio[ratio > 1] = 1
            src.flux_weighted_image = src_model.copy_with(data=ratio) * images[overlap]

    def to_data(self) -> ScarletBlendData:
        """Convert the Blend into a persistable data object.

        Each source is keyed by ``metadata["id"]`` when its metadata has an
        ``"id"`` entry, and by its position in ``self.sources`` otherwise.

        Returns
        -------
        blend_data :
            The data model for a single blend.
        """
        from .io import ScarletBlendData

        sources: dict[Any, ScarletSourceBaseData] = {}
        for sidx, source in enumerate(self.sources):
            metadata = source.metadata or {}
            if "id" in metadata:
                sources[metadata["id"]] = source.to_data()
            else:
                sources[sidx] = source.to_data()

        blend_data = ScarletBlendData(
            origin=self.bbox.origin,  # type: ignore
            shape=self.bbox.shape,  # type: ignore
            sources=sources,
            metadata=self.metadata,
        )

        return blend_data

    def __getitem__(self, indices: Any) -> Blend:
        """Get a sub-blend corresponding to the given indices.

        Parameters
        ----------
        indices :
            The indices to use to slice the blend.

        Returns
        -------
        blend :
            A new `Blend` instance containing only data from the
            specified bands in the specified order. The iteration count
            and loss history are not carried over.

        Raises
        ------
        IndexError :
            If the indices contain bands not included in the original
            blend or a bounding box is given.
        """
        return Blend(
            sources=[src[indices] for src in self.sources],
            observation=self.observation[indices],
            metadata=self.metadata,
        )

    def __copy__(self) -> Blend:
        """Create a shallow copy of this blend.

        The source list is a new list, but the sources, observation and
        metadata objects are shared. The iteration count and loss history
        are reset.

        Returns
        -------
        blend : Blend
            A new blend that is a copy of this one.
        """
        return Blend(sources=self.sources, observation=self.observation, metadata=self.metadata)

    def __deepcopy__(self, memo: dict[int, Any]) -> Blend:
        """Create a deep copy of this blend.

        Parameters
        ----------
        memo : dict[int, Any]
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        blend : Blend
            A new blend that is a deep copy of this one. The iteration count
            and loss history are reset.
        """
        # Check if already copied
        if id(self) in memo:
            return memo[id(self)]

        # Create placeholder and add to memo FIRST so that reference cycles
        # back to this blend resolve to the new copy
        blend = Blend.__new__(Blend)
        memo[id(self)] = blend

        # Now safely initialize the placeholder with deepcopied arguments
        blend.__init__(  # type: ignore[misc]
            sources=[deepcopy(src, memo) for src in self.sources],
            observation=deepcopy(self.observation, memo),
            metadata=deepcopy(self.metadata, memo),
        )

        return blend