Coverage for tests/test_component.py: 16%
258 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 08:25 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 08:25 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24from abc import ABC
25from copy import deepcopy
26from typing import Any, Callable
28import numpy as np
29from lsst.scarlet.lite import Box, Image, Parameter
30from lsst.scarlet.lite.component import (
31 Component,
32 CubeComponent,
33 FactorizedComponent,
34 default_adaprox_parameterization,
35 default_fista_parameterization,
36)
37from lsst.scarlet.lite.operators import Monotonicity
38from lsst.scarlet.lite.utils import integrated_circular_gaussian
39from numpy.testing import assert_almost_equal, assert_array_equal
40from utils import ScarletTestCase
class DummyComponent(Component):
    """Stub ``Component`` subclass used by the tests in this file.

    Apart from ``parameterize``, every method has an empty body and
    therefore returns ``None`` regardless of its annotation.
    """

    def resize(self) -> bool:
        """Do nothing; placeholder implementation."""

    def update(self, it: int, input_grad: np.ndarray):
        """Do nothing; placeholder implementation."""

    def get_model(self) -> Image:
        """Do nothing; placeholder implementation."""

    def parameterize(self, parameterization: Callable) -> None:
        """Call ``parameterization`` with this component as its argument."""
        parameterization(self)

    def to_data(self) -> DummyComponent:
        """Do nothing; placeholder implementation."""

    def __getitem__(self, indices: Any) -> DummyComponent:
        """Do nothing; placeholder implementation."""

    def __copy__(self) -> DummyComponent:
        """Do nothing; placeholder implementation."""

    def __deepcopy__(self, memo: dict[int, Any]) -> DummyComponent:
        """Do nothing; placeholder implementation."""
class _ComponentTestBase(ABC):
    """Band-indexing tests shared by the component test cases.

    This class is a mixin, not a test case: it relies on the concrete
    subclass to provide the ``assert*`` methods and to set
    ``self.component`` (the subclasses in this file do so in ``setUp``).
    The expected values below assume the component's bands are
    ``("g", "r", "i")`` in that order.
    """

    def test_slice(self):
        """A slice of band names selects that band range, end included."""
        component = self.component
        sliced = component["g":"r"]
        self.assertTupleEqual(sliced.bands, ("g", "r"))
        np.testing.assert_array_equal(sliced.get_model(), component.get_model().data[0:2])

    def test_reorder(self):
        """Indexing with all bands in a new order reorders the model."""
        component = self.component
        expected_bands = ("i", "g", "r")
        # The same reordering is requested twice: once as a tuple of
        # band names and once as a string of single-character names.
        for key in (expected_bands, "igr"):
            reordered = component[key]
            self.assertTupleEqual(reordered.bands, expected_bands)
            np.testing.assert_array_equal(
                reordered.get_model(),
                component.get_model().data[(2, 0, 1),],
            )

    def test_subset(self):
        """Indexing with a single band name yields a one-band component."""
        component = self.component
        subset = component["r"]
        self.assertTupleEqual(subset.bands, ("r",))
        np.testing.assert_array_equal(
            subset.get_model(),
            component.get_model().data[1:2,],
        )

        # With multi-character band names, a string index must select the
        # band with that name rather than being split into characters.
        component = self.component.copy(deep=True)
        component._bands = ("ab", "cd", "ef")
        subset = component["ab"]
        self.assertTupleEqual(subset.bands, ("ab",))
        np.testing.assert_array_equal(
            subset.get_model(),
            component.get_model().data[0:1,],
        )

    def test_indexing_errors(self):
        """Each unsupported index below must raise ``IndexError``."""
        component = self.component

        # "z" is not one of the bands: alone, at either end of a slice,
        # and inside a tuple of band names.
        with self.assertRaises(IndexError):
            component["z"]

        with self.assertRaises(IndexError):
            component["r":"z"]

        with self.assertRaises(IndexError):
            component["z":"i"]

        with self.assertRaises(IndexError):
            component["g", "z", "i"]

        # Box and multi-axis slice indices.
        with self.assertRaises(IndexError):
            component[Box((0, 0), (10, 10))]

        with self.assertRaises(IndexError):
            component[:, 10:20, 10:20]

        # Integer positions instead of band names.
        with self.assertRaises(IndexError):
            component[1:]

        with self.assertRaises(IndexError):
            component[1]

        with self.assertRaises(IndexError):
            component[0, 1]
class TestFactorizedComponent(_ComponentTestBase, ScarletTestCase):
    """Tests for ``FactorizedComponent``.

    The band-indexing tests inherited from ``_ComponentTestBase`` run
    against the component built in ``setUp``.
    """

    def setUp(self) -> None:
        """Build a three-band component and keep its inputs for the tests."""
        # Chain to the base test case, as TestCubeComponent.setUp does.
        super().setUp()

        bands = ("g", "r", "i")
        spectrum = np.arange(3).astype(np.float32)
        morph = np.arange(20).reshape(4, 5).astype(np.float32)
        bbox = Box((4, 5), (22, 31))
        center = (24, 33)
        self.model_box = Box((100, 100))

        self.component = FactorizedComponent(
            bands,
            spectrum,
            morph,
            bbox,
            center,
        )

        self.bands = bands
        self.spectrum = spectrum
        self.morph = morph
        self.full_shape = (3, 100, 100)

    def test_constructor(self):
        """Check constructor defaults and that optional arguments are stored."""
        # Test with only required parameters
        component = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            self.component.bbox,
        )

        self.assertIsInstance(component._spectrum, Parameter)
        assert_array_equal(component.spectrum, self.spectrum)
        self.assertIsInstance(component._morph, Parameter)
        assert_array_equal(component.morph, self.morph)
        self.assertBoxEqual(component.bbox, self.component.bbox)
        self.assertIsNone(component.peak)
        self.assertIsNone(component.bg_rms)
        self.assertEqual(component.bg_thresh, 0.25)
        self.assertEqual(component.floor, 1e-20)
        self.assertTupleEqual(component.shape, (3, 4, 5))

        # Test that parameters are passed through
        center = self.component.peak
        bg_rms = np.arange(5) / 10
        bg_thresh = 0.9
        floor = 1e-10

        component = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            self.component.bbox,
            center,
            bg_rms,
            bg_thresh,
            floor,
        )

        self.assertTupleEqual(component.peak, center)
        assert_array_equal(component.bg_rms, bg_rms)  # type: ignore
        self.assertEqual(component.bg_thresh, bg_thresh)
        self.assertEqual(component.floor, floor)
        self.assertEqual(component.get_model().dtype, np.float32)

    def test_get_model(self):
        """The model is spectrum x morph and lands at the bbox location."""
        component = self.component
        expected = self.spectrum[:, None, None] * self.morph[None, :, :]
        assert_array_equal(component.get_model(), expected)

        # Insert component into a larger model
        full_model = np.zeros(self.full_shape)
        full_model[:, 22:26, 31:36] = expected

        test_model = Image(np.zeros(self.full_shape), bands=self.bands)
        test_model += component.get_model()

        assert_array_equal(test_model.data, full_model)

    def test_gradients(self):
        """Compare ``grad_spectrum`` / ``grad_morph`` with hand-computed values."""
        component = self.component
        morph = self.morph
        spectrum = self.spectrum

        input_grad = np.array([morph, 2 * morph, 3 * morph])

        # Per band: sum over pixels of (input gradient * morph).
        true_spectrum_grad = np.array(
            [
                np.sum(morph**2),
                np.sum(2 * morph**2),
                np.sum(3 * morph**2),
            ]
        )
        assert_almost_equal(component.grad_spectrum(input_grad, spectrum, morph), true_spectrum_grad)

        # Per pixel: sum over bands of (input gradient * spectrum).
        true_morph_grad = np.sum(input_grad * spectrum[:, None, None], axis=0)
        assert_almost_equal(component.grad_morph(input_grad, morph, spectrum), true_morph_grad)

    def test_proximal_operators(self):
        """Check ``prox_spectrum`` / ``prox_morph`` results with and without
        a background threshold and monotonicity operator.
        """
        # Test spectrum positivity, morph threshold, and monotonicity
        spectrum = np.array([-1, 2, 3], dtype=float)
        morph = np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, -1]], dtype=float)
        bbox = Box((3, 3), (10, 10))
        morph_bbox = Box((100, 100))
        center = (11, 11)
        monotonicity = Monotonicity((101, 101), fit_radius=0)

        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            center,
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=monotonicity,
        )

        proxed_spectrum = np.array([1e-20, 2, 3])
        proxed_morph = np.array([[2.6666666666666667, 2, 1], [1, 5, 3], [0, 4, 0]])
        proxed_morph = proxed_morph / 5

        component.prox_spectrum(component.spectrum)
        component.prox_morph(component.morph)

        assert_array_equal(component.spectrum, proxed_spectrum)
        assert_array_equal(component.morph, proxed_morph)

        # Same inputs, but no peak, background threshold or monotonicity.
        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            None,
        )

        proxed_spectrum = np.array([1e-20, 2, 3])
        proxed_morph = np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, 0]])
        proxed_morph = proxed_morph / 10

        component.prox_spectrum(component.spectrum)
        component.prox_morph(component.morph)

        assert_array_equal(component.spectrum, proxed_spectrum)
        assert_array_equal(component.morph, proxed_morph)

        self.assertFalse(component.resize(morph_bbox))

    def test_prox_morph_enforces_positivity_with_bg_thresh(self):
        """Audit finding K-2: ``prox_morph`` previously enforced
        positivity only on the ``else`` branch (no ``bg_thresh``).
        With ``bg_thresh`` active, a negative morph pixel survived
        when ``spectrum * morph >= bg_thresh`` in at least one band
        — possible whenever the spectrum has a negative band, since
        ``neg * neg`` is positive. Positivity must always be
        enforced.
        """
        bands = ("g", "r", "i")
        # Mixed-sign spectrum: the negative band makes the model
        # positive at a negative-morph pixel for that band, so the
        # threshold check no longer catches it.
        spectrum = np.array([-1.0, 1.0, 1.0])
        morph = np.full((3, 3), 0.5, dtype=float)
        morph[0, 0] = -0.3

        component = FactorizedComponent(
            bands,
            spectrum,
            morph,
            Box((3, 3), (0, 0)),
            peak=None,
            bg_rms=np.array([0.1, 0.1, 0.1]),
            bg_thresh=0.5,
        )

        proxed = component.prox_morph(component.morph.copy())
        # The negative pixel must be zeroed by the positivity guard,
        # not slip through because spectrum*morph >= bg_thresh in
        # the negative-spectrum band.
        self.assertEqual(proxed[0, 0], 0)

    def test_resize(self):
        """``resize`` shrinks the morphology and moves the bbox to match."""
        spectrum = np.array([1, 2, 3], dtype=float)
        # Only a 3x3 patch of the 10x10 morphology is populated.
        morph = np.zeros((10, 10), dtype=float)
        morph[3:6, 5:8] = np.arange(9).reshape(3, 3)
        bbox = Box((10, 10), (3, 5))

        morph_bbox = Box((100, 100))
        monotonicity = Monotonicity((101, 101), fit_radius=0)

        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            None,
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=monotonicity,
            padding=1,
        )

        self.assertTupleEqual(component.morph.shape, (10, 10))
        self.assertIsNone(component.component_center)

        component.resize(morph_bbox)
        self.assertTupleEqual(component.morph.shape, (5, 5))
        self.assertTupleEqual(component.bbox.origin, (5, 9))
        self.assertTupleEqual(component.bbox.shape, (5, 5))
        self.assertIsNone(component.component_center)

    def test_parameterization(self):
        """Each default parameterization installs its own helper arrays and
        rejects component types it does not know.
        """
        component = self.component
        assert_array_equal(component.get_model(), self.spectrum[:, None, None] * self.morph[None, :, :])

        component.parameterize(default_fista_parameterization)
        helpers = set(component._morph.helpers.keys())
        self.assertSetEqual(helpers, {"z"})
        component.parameterize(default_adaprox_parameterization)
        helpers = set(component._morph.helpers.keys())
        self.assertSetEqual(helpers, {"m", "v", "vhat"})

        # DummyComponent is not handled by either parameterization.
        params = (tuple("grizy"), Box((5, 5)))
        with self.assertRaises(NotImplementedError):
            default_fista_parameterization(DummyComponent(*params))

        with self.assertRaises(NotImplementedError):
            default_adaprox_parameterization(DummyComponent(*params))

    def test_shallow_copy(self):
        """A shallow copy is a new component that shares its attributes."""
        component = self.component
        component.monotonicity = Monotonicity((11, 11), fit_radius=0)

        component_copy = component.copy()

        self.assertIsNot(component, component_copy)
        assert_array_equal(component._spectrum.x, component_copy._spectrum.x)
        assert_array_equal(component._morph.x, component_copy._morph.x)
        self.assertIs(component.bbox, component_copy.bbox)
        self.assertIs(component.peak, component_copy.peak)
        self.assertIs(component.bg_thresh, component_copy.bg_thresh)
        self.assertIs(component.monotonicity, component_copy.monotonicity)

    def test_deep_copy(self):
        """A deep copy has equal values that are independent of the original."""
        component = self.component
        component.monotonicity = Monotonicity((11, 11), fit_radius=0)
        component_deepcopy = component.copy(deep=True)

        self.assertIsNot(component, component_deepcopy)

        # Modifying the copy's arrays in place must not touch the original.
        assert_array_equal(component._spectrum.x, component_deepcopy._spectrum.x)
        component_deepcopy._spectrum.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(component._spectrum.x, component_deepcopy._spectrum.x)

        assert_array_equal(component._morph.x, component_deepcopy._morph.x)
        component_deepcopy._morph.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(component._morph.x, component_deepcopy._morph.x)

        self.assertIsNot(component.bbox, component_deepcopy.bbox)
        self.assertBoxEqual(component.bbox, component_deepcopy.bbox)

        self.assertTupleEqual(component.peak, component_deepcopy.peak)
        self.assertEqual(component.bg_thresh, component_deepcopy.bg_thresh)
        self.assertIsNot(component.monotonicity, component_deepcopy.monotonicity)
class TestCubeComponent(_ComponentTestBase, ScarletTestCase):
    """Tests for ``CubeComponent``.

    The band-indexing tests inherited from ``_ComponentTestBase`` run
    against the component built in ``setUp``.
    """

    def setUp(self) -> None:
        """Build a three-band cube component from a Gaussian morphology."""
        super().setUp()
        self.bands = tuple("gri")
        peak = (27, 32)
        bbox = Box((15, 15), (20, 25))
        morph = integrated_circular_gaussian(sigma=0.8).astype(np.float32)
        spectrum = np.arange(3, dtype=np.float32)
        model = morph[None, :, :] * spectrum[:, None, None]
        # Independent copy of the input array, so test_constructor has a
        # reference that does not alias the component's own data.
        self.model_data = model.copy()
        model_image = Image(model, yx0=bbox.origin, bands=self.bands)
        self.component = CubeComponent(model=model_image, peak=peak)

    def test_constructor(self):
        """The component exposes the model, bands, bbox and peak it was given."""
        component = self.component
        self.assertIsInstance(component._model, Image)
        # Compare with the array passed in, not with the component itself.
        np.testing.assert_array_equal(component._model.data, self.model_data)
        self.assertTupleEqual(component.bands, self.bands)
        self.assertBoxEqual(component.bbox, Box((15, 15), (20, 25)))
        self.assertTupleEqual(component.peak, (27, 32))

    def test_shallow_copy(self):
        """A shallow copy is a new component with an equal peak and model."""
        component = self.component
        component_copy = component.copy()

        self.assertIsNot(component_copy, component)
        self.assertTupleEqual(component_copy.peak, component.peak)
        self.assertImageEqual(component_copy._model, component._model)

    def test_deep_copy(self):
        """A deep copy starts equal and has a model independent of the original."""
        component = self.component
        component_copy = component.copy(deep=True)

        self.assertIsNot(component, component_copy)

        self.assertTupleEqual(component_copy.peak, component.peak)
        self.assertImageEqual(component_copy._model, component._model)

        # Mutate outside the assertRaises block so that only the image
        # comparison can satisfy the expected AssertionError.
        component_copy._model._data -= 1
        with self.assertRaises(AssertionError):
            self.assertImageEqual(component_copy._model, component._model)

    def test_deep_copy_preserves_memo(self):
        """Audit finding K-6: ``__deepcopy__`` must call ``deepcopy``
        with the memo dict on its sub-objects so shared references
        in the input graph remain shared in the copy. Previously
        ``self._model.copy()`` allocated a fresh image regardless of
        whether ``self._model`` already appeared in the memo.
        """
        c1 = CubeComponent(model=self.component._model, peak=self.component.peak)
        c2 = CubeComponent(model=self.component._model, peak=self.component.peak)
        # Two CubeComponents sharing the same Image instance.
        self.assertIs(c1._model, c2._model)

        c1_copy, c2_copy = deepcopy([c1, c2])

        # The model should be different in the deepcopy,
        # since CubeComponent's deepcopy creates a new Image instance.
        self.assertIsNot(c1_copy._model, c1._model)

        # The shared Image must still be shared in the deepcopy,
        # otherwise larger object graphs that rely on identity
        # (e.g. observation sharing) silently fork.
        self.assertIs(c1_copy._model, c2_copy._model)