Coverage for python / lsst / scarlet / lite / models / free_form.py: 27%
112 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 00:46 -0700
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
__all__ = ["FactorizedFreeFormComponent", "FreeFormComponent"]
25from copy import deepcopy
26from typing import TYPE_CHECKING, Any, Callable, cast
28import numpy as np
30from ..bbox import Box
31from ..component import Component, FactorizedComponent
32from ..detect import footprints_to_image
33from ..detect_pybind11 import get_connected_multipeak, get_footprints # type: ignore
34from ..image import Image
35from ..parameters import Parameter, parameter
37if TYPE_CHECKING:
38 from ..io.component import ScarletComponentBaseData
class FactorizedFreeFormComponent(FactorizedComponent):
    """Implements a free-form component

    With no constraints this component is typically either a garbage collector,
    or part of a set of components to deconvolve an image by separating out
    the different spectral components.

    The spectrum is normalized to unit sum, so the total flux is carried by
    the morphology. The morphology is only constrained by background
    thresholding (or positivity), optional peak connectivity, and an
    optional minimum footprint area.

    Parameters
    ----------
    bands:
        The bands covered by the component.
    spectrum:
        The spectrum of the component.
    morph:
        The 2D ``(y, x)`` morphology of the component.
    model_bbox:
        The bounding box of the component.
    bg_thresh:
        The background threshold, in units of `bg_rms`. Pixels whose model
        value is below the threshold in every band are set to zero.
        If `None` (or if `bg_rms` is not an array) only positivity is
        enforced.
    bg_rms:
        The background RMS in each band.
    floor:
        The minimum value allowed in the spectrum. It is also assigned to a
        single morphology pixel when the morphology would otherwise be
        entirely zero.
    peaks:
        A set of ``(cy, cx)`` peaks for detected sources.
        If `peaks` is not ``None`` then only pixels in the same "footprint"
        as one of the peaks are included in the morphology.
        If `peaks` is ``None`` then there is no constraint applied.
    min_area:
        The minimum area for a peak.
        If `min_area` is greater than zero then all regions of the morphology
        with fewer than `min_area` connected pixels are removed.
    """

    def __init__(
        self,
        bands: tuple,
        spectrum: np.ndarray | Parameter,
        morph: np.ndarray | Parameter,
        model_bbox: Box,
        bg_thresh: float | None = None,
        bg_rms: np.ndarray | None = None,
        floor: float = 1e-20,
        peaks: list[tuple[int, int]] | None = None,
        min_area: float = 0,
    ):
        # ``peak`` is always left as ``None``: peak connectivity for this
        # component is handled through ``self.peaks`` in `prox_morph`.
        super().__init__(
            bands=bands,
            spectrum=spectrum,
            morph=morph,
            bbox=model_bbox,
            peak=None,
            bg_rms=bg_rms,
            bg_thresh=bg_thresh,
            floor=floor,
        )

        self.peaks = peaks
        self.min_area = min_area

    def prox_spectrum(self, spectrum: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the spectrum

        Values below ``self.floor`` are clipped to ``self.floor`` to prevent
        a divergent spectrum, and the result is normalized to unit sum.

        Parameters
        ----------
        spectrum:
            The spectrum to constrain. The clipping modifies it in place.

        Returns
        -------
        spectrum:
            The clipped spectrum normalized to unit sum (a new array).
        """
        # prevent divergent spectrum
        spectrum[spectrum < self.floor] = self.floor
        # Normalize the spectrum to unit sum
        return spectrum / np.sum(spectrum)

    def prox_morph(self, morph: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the morphology

        This is the main difference between this component and a
        `FactorizedComponent`, since this component has fewer constraints:

        - background thresholding when ``bg_thresh`` is set and ``bg_rms``
          is an array, otherwise simple positivity;
        - if ``peaks`` is not ``None``, only pixels connected to one of the
          peaks are kept;
        - if ``min_area > 0``, connected regions smaller than ``min_area``
          are removed.

        Parameters
        ----------
        morph:
            The 2D morphology to constrain. The thresholding/positivity step
            modifies it in place.

        Returns
        -------
        morph:
            The constrained morphology.
        """
        if self.bg_thresh is not None and isinstance(self.bg_rms, np.ndarray):
            bg_thresh = self.bg_rms * self.bg_thresh
            # Enforce background thresholding: zero the pixels whose model
            # value is below the per-band threshold in every band.
            model = self.spectrum[:, None, None] * morph[None, :, :]
            morph[np.all(model < bg_thresh[:, None, None], axis=0)] = 0
        else:
            # enforce positivity
            morph[morph < 0] = 0

        if self.peaks is not None:
            # Keep only the pixels connected to one of the peaks
            footprint = get_connected_multipeak(morph, self.peaks, 0)
            morph = morph * footprint

        if self.min_area > 0:
            footprints = get_footprints(morph, 4.0, self.min_area, 0, 0, False)
            # Render the footprints on the morphology's own pixel grid so the
            # resulting mask lines up with ``morph`` element by element.
            bbox = self.bbox.copy()
            bbox.origin = (0, 0)
            footprint_image = footprints_to_image(footprints, bbox)
            morph = morph * (footprint_image > 0).data

        if np.all(morph == 0):
            # Keep a single non-zero pixel so the morphology never vanishes
            morph[0, 0] = self.floor

        return morph

    def resize(self, model_box: Box) -> bool:
        """This component never resizes its bounding box.

        Returns
        -------
        resized:
            Always ``False``.
        """
        return False

    def __str__(self):
        return (
            f"FactorizedFreeFormComponent(\n    bands={self.bands}\n    "
            f"spectrum={self.spectrum}\n    center={self.peak}\n    "
            f"morph_shape={self.morph.shape}\n)"
        )

    def __repr__(self):
        return self.__str__()
class FreeFormComponent(Component):
    """Implements a component with no spectral or monotonicity constraints

    Unlike `FactorizedFreeFormComponent`, this component is not factorized
    into a spectrum and a morphology: the full 3D ``(bands, y, x)`` model is
    the fitted parameter, and there is no monotonicity constraint.

    Parameters
    ----------
    bands:
        The bands covered by the component.
    model:
        The 3D (bands, y, x) model of the component.
    model_bbox:
        The bounding box of the model.
    bg_thresh:
        The background threshold, in units of `bg_rms`. In each band, pixels
        below ``bg_thresh * bg_rms`` for that band are set to zero.
        If `None` (or if `bg_rms` is not an array) only positivity is
        enforced.
    bg_rms:
        The background RMS in each band.
    floor:
        The value assigned to a single pixel (in every band) when the model
        would otherwise be entirely zero.
    peaks:
        The `(y, x)` peaks of the component, used to keep only the pixels
        connected to a peak. If `None` then no peak connectivity is enforced.
    min_area:
        The minimum area (in pixels) of a connected footprint to keep.
        Only applied when greater than zero.
    """

    def __init__(
        self,
        bands: tuple,
        model: np.ndarray | Parameter,
        model_bbox: Box,
        bg_thresh: float | None = None,
        bg_rms: np.ndarray | None = None,
        floor: float = 1e-20,
        peaks: list[tuple[int, int]] | None = None,
        min_area: float = 0,
    ):
        super().__init__(bands=bands, bbox=model_bbox)
        self._model = parameter(model)
        self.bg_rms = bg_rms
        self.bg_thresh = bg_thresh
        self.floor = floor
        self.peaks = peaks
        self.min_area = min_area

    @property
    def model(self) -> np.ndarray:
        """The current model array held by the model parameter."""
        return self._model.x

    def get_model(self) -> Image:
        """Return the model as an `Image` positioned at the bbox origin."""
        return Image(self.model, bands=self.bands, yx0=cast(tuple[int, int], self.bbox.origin))

    @property
    def shape(self) -> tuple:
        """The shape of the model array."""
        return self.model.shape

    def grad_model(self, input_grad: np.ndarray, model: np.ndarray) -> np.ndarray:
        """Gradient of the model with respect to itself (the identity).

        The model array is the fitted parameter, so the incoming gradient is
        passed through unchanged.
        """
        return input_grad

    def prox_model(self, model: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the 3D model

        Parameters
        ----------
        model:
            The ``(bands, y, x)`` model to constrain. The
            thresholding/positivity step modifies it in place.

        Returns
        -------
        model:
            The constrained model.
        """
        if self.bg_thresh is not None and isinstance(self.bg_rms, np.ndarray):
            bg_thresh = self.bg_rms * self.bg_thresh
            # Enforce background thresholding independently in each band
            model[model < bg_thresh[:, None, None]] = 0
        else:
            # enforce positivity
            model[model < 0] = 0

        if self.peaks is not None:
            # Remove pixels not connected to one of the peaks, using the
            # band-summed image to define connectivity.
            model2d = np.sum(model, axis=0)
            footprint = get_connected_multipeak(model2d, self.peaks, 0)
            model = model * footprint[None, :, :]

        if self.min_area > 0:
            # Remove regions with fewer than min_area connected pixels
            model2d = np.sum(model, axis=0)
            footprints = get_footprints(model2d, 4.0, self.min_area, 0, 0, False)
            # Render the footprints on the model's own pixel grid
            bbox = self.bbox.copy()
            bbox.origin = (0, 0)
            footprint_image = footprints_to_image(footprints, bbox)
            model = model * (footprint_image > 0).data[None, :, :]

        if np.all(model == 0):
            # If the model is all zeros, set a single pixel (in every band)
            # to the floor. ``model[0, 0]`` would set a whole row of band 0.
            model[:, 0, 0] = self.floor

        return model

    def resize(self, model_box: Box) -> bool:
        """This component never resizes its bounding box.

        Returns
        -------
        resized:
            Always ``False``.
        """
        return False

    def update(self, it: int, grad_log_likelihood: np.ndarray):
        """Update the model parameter for iteration ``it``."""
        self._model.update(it, grad_log_likelihood)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        # Convert the model array into a Parameter in place
        parameterization(self)
        # Attach this component's gradient and prox to the new parameter
        self._model.grad = self.grad_model
        self._model.prox = self.prox_model

    def __str__(self):
        result = f"FreeFormComponent<bands={self.bands}, shape={self.shape}>"
        return result

    def __repr__(self):
        return self.__str__()

    def to_data(self) -> ScarletComponentBaseData:
        """Serialization is not implemented for this component.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Serialization not implemented for FreeFormComponent")

    def __getitem__(self, indices: Any) -> FreeFormComponent:
        """Get a sub-component corresponding to the given bands.

        Parameters
        ----------
        indices: Any
            A single band, or an iterable of bands, to select from the
            component.

        Returns
        -------
        component: FreeFormComponent
            A new component restricted to the selected bands. The model keeps
            its band axis even when a single band is selected, and the
            per-band ``bg_rms`` is subset to match.

        Raises
        ------
        IndexError :
            If the index is not a band or an iterable of bands of this
            component (for example a ``Box`` or spatial indices).
        """
        if indices in self.bands:
            bands = (indices,)
        else:
            try:
                bands = tuple(indices)
            except TypeError as err:
                raise IndexError(
                    f"FreeFormComponent can only be indexed by band, got {indices!r}"
                ) from err

        missing = [band for band in bands if band not in self.bands]
        if missing:
            raise IndexError(f"Bands {missing} are not in the component bands {self.bands}")

        # Translate band names into positions along the model's band axis.
        # A list index keeps the band axis even for a single band.
        band_indices = [self.bands.index(band) for band in bands]

        bg_rms = self.bg_rms
        if isinstance(bg_rms, np.ndarray):
            # Keep the per-band RMS aligned with the selected bands
            bg_rms = bg_rms[band_indices]

        return FreeFormComponent(
            bands=bands,
            model=self.model[band_indices],
            model_bbox=self.bbox,
            bg_thresh=self.bg_thresh,
            bg_rms=bg_rms,
            floor=self.floor,
            peaks=self.peaks,
            min_area=self.min_area,
        )

    def __deepcopy__(self, memo: dict[int, Any]) -> FreeFormComponent:
        """Create a deep copy of this component.

        Parameters
        ----------
        memo: dict[int, Any]
            A dictionary to keep track of already copied objects.

        Returns
        -------
        component : FreeFormComponent
            A new component that is a deep copy of this one.
        """
        if id(self) in memo:
            return memo[id(self)]

        # Register the new object before copying members so that
        # self-references resolve to the copy.
        component = FreeFormComponent.__new__(FreeFormComponent)
        memo[id(self)] = component

        component.__init__(  # type: ignore[misc]
            bands=deepcopy(self.bands),
            model=deepcopy(self.model),
            model_bbox=deepcopy(self.bbox),
            bg_thresh=self.bg_thresh,
            bg_rms=deepcopy(self.bg_rms),
            floor=self.floor,
            peaks=deepcopy(self.peaks),
            min_area=self.min_area,
        )
        return component

    def __copy__(self) -> FreeFormComponent:
        """Create a copy of this component.

        The new component shares the model array, bbox, ``bg_rms`` and
        ``peaks`` objects with this one.

        Returns
        -------
        component : FreeFormComponent
            A new component that is a copy of this one.
        """
        return FreeFormComponent(
            bands=self.bands,
            model=self.model,
            model_bbox=self.bbox,
            bg_thresh=self.bg_thresh,
            bg_rms=self.bg_rms,
            floor=self.floor,
            peaks=self.peaks,
            min_area=self.min_area,
        )