Coverage for tests/test_initialization.py: 13%
176 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 08:24 +0000
1# This file is part of lsst.scarlet.lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23from unittest.mock import patch
25import numpy as np
26from deprecated.sphinx import deprecated
27from lsst.scarlet.lite import Box, Image, Observation
28from lsst.scarlet.lite.initialization import (
29 FactorizedInitialization,
30 FactorizedWaveletInitialization,
31 init_monotonic_morph,
32 multifit_spectra,
33 trim_morphology,
34)
35from lsst.scarlet.lite.operators import Monotonicity, prox_monotonic_mask
36from lsst.scarlet.lite.utils import integrated_circular_gaussian
37from numpy.testing import assert_almost_equal, assert_array_equal
38from scipy.signal import convolve as scipy_convolve
39from utils import ObservationData, ScarletTestCase
42class TestInitialization(ScarletTestCase):
43 def setUp(self) -> None:
44 yx0 = (1000, 2000)
45 filename = os.path.join(__file__, "..", "..", "data", "hsc_cosmos_35.npz")
46 filename = os.path.abspath(filename)
47 data = np.load(filename)
48 model_psf = integrated_circular_gaussian(sigma=0.8)
49 self.detect = np.sum(data["images"], axis=0)
50 self.centers = np.array([data["catalog"]["y"], data["catalog"]["x"]]).T + np.array(yx0)
51 bands = data["filters"]
52 self.observation = Observation(
53 Image(data["images"], bands=bands, yx0=yx0),
54 Image(data["variance"], bands=bands, yx0=yx0),
55 Image(1 / data["variance"], bands=bands, yx0=yx0),
56 data["psfs"],
57 model_psf[None],
58 bands=bands,
59 )
61 def test_trim_morphology(self):
62 # Default parameters: returns a tight bbox around the non-zero
63 # support of the input.
64 morph = np.zeros((50, 50)).astype(np.float32)
65 morph[10:15, 12:27] = 1
66 trimmed, trimmed_box = trim_morphology(morph)
67 assert_array_equal(trimmed, morph)
68 self.assertTupleEqual(trimmed_box.origin, (10, 12))
69 self.assertTupleEqual(trimmed_box.shape, (5, 15))
70 self.assertEqual(trimmed.dtype, np.float32)
72 # With a threshold: pixels at or below the threshold are zeroed,
73 # and the bbox is the tight box around what remains. The input
74 # array must not be mutated (audit finding I-8).
75 morph = np.full((50, 50), 0.1).astype(np.float32)
76 morph[10:15, 12:27] = 1
77 original = morph.copy()
78 truth = np.zeros(morph.shape)
79 truth[10:15, 12:27] = 1
80 trimmed, trimmed_box = trim_morphology(morph, 0.5)
81 assert_array_equal(trimmed, truth)
82 assert_array_equal(morph, original)
83 self.assertTupleEqual(trimmed_box.origin, (10, 12))
84 self.assertTupleEqual(trimmed_box.shape, (5, 15))
85 self.assertEqual(trimmed.dtype, np.float32)
87 def test_init_monotonic_mask(self):
88 full_box = self.observation.bbox
89 center = self.centers[0]
90 local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])
92 # Default parameters
93 bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box)
94 self.assertBoxEqual(bbox, Box((38, 29), (1014, 2000)))
95 _, masked_morph, _ = prox_monotonic_mask(self.detect.copy(), local_center, max_iter=0)
96 assert_array_equal(morph, masked_morph / np.max(masked_morph))
97 self.assertEqual(morph.dtype, np.float32)
99 # Non-zero threshold AND non-zero padding. This combination
100 # exercises the path-1 trim_morphology call AND the post-trim
101 # padding step; if those ever double up again (audit I-2), the
102 # bbox below would grow to (34, 28) at origin (1017, 2000)
103 # rather than (30, 25) at (1019, 2001).
104 bbox, morph = init_monotonic_morph(
105 self.detect.copy(),
106 center,
107 full_box,
108 2, # padding
109 False, # normalize
110 None, # monotonicity
111 0.2, # threshold
112 )
113 self.assertBoxEqual(bbox, Box((30, 25), (1019, 2001)))
114 # Remove pixels below the threshold
115 truth = masked_morph.copy()
116 truth[truth < 0.2] = 0
117 assert_array_equal(morph, truth)
118 self.assertEqual(morph.dtype, np.float32)
120 # Test an empty morphology
121 bbox, morph = init_monotonic_morph(np.zeros(self.detect.shape), center, full_box)
122 self.assertBoxEqual(bbox, Box((0, 0)))
123 self.assertIsNone(morph)
125 def test_init_monotonic_weighted(self):
126 full_box = self.observation.bbox
127 center = self.centers[0]
128 local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])
129 monotonicity = Monotonicity((101, 101))
131 # Default parameters
132 bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box, monotonicity=monotonicity)
133 truth = monotonicity(self.detect.copy(), local_center)
134 truth[truth < 0] = 0
135 truth = truth / np.max(truth)
136 self.assertBoxEqual(bbox, Box((58, 48), origin=(1000, 2000)))
137 assert_array_equal(morph, truth)
138 self.assertEqual(morph.dtype, np.float32)
140 # Non-zero threshold AND non-zero padding. trim_morphology is
141 # always called on path 2; pairing that with padding > 0 makes
142 # the regression for audit I-2 (double padding) observable
143 # rather than masked by clipping or padding=0.
144 bbox, morph = init_monotonic_morph(
145 self.detect.copy(),
146 center,
147 full_box,
148 2, # padding
149 False, # normalize
150 monotonicity, # monotonicity
151 0.2, # threshold
152 )
153 truth = monotonicity(self.detect.copy(), local_center)
154 truth[truth < 0.2] = 0
155 self.assertBoxEqual(bbox, Box((49, 47), origin=(1008, 2001)))
156 assert_array_equal(morph, truth)
157 self.assertEqual(morph.dtype, np.float32)
159 # Test zero morphology
160 zeros = np.zeros(self.detect.shape)
161 bbox, morph = init_monotonic_morph(zeros, center, full_box, monotonicity=monotonicity)
162 self.assertBoxEqual(bbox, Box((0, 0), (1000, 2000)))
163 self.assertIsNone(morph)
165 def test_multifit_spectra(self):
166 bands = ("g", "r", "i")
167 variance = np.ones((3, 35, 35), dtype=np.float32)
168 weights = 1 / variance
169 psfs = np.array([integrated_circular_gaussian(sigma=sigma) for sigma in [1.05, 0.9, 1.2]])
170 psfs = psfs.astype(np.float32)
171 model_psf = integrated_circular_gaussian(sigma=0.8).astype(np.float32)
173 # The spectrum of each source
174 spectra = np.array(
175 [
176 [31, 10, 0],
177 [0, 5, 20],
178 [15, 8, 3],
179 [20, 3, 4],
180 [0, 30, 60],
181 ],
182 dtype=np.float32,
183 )
185 # Use a point source for all of the sources
186 morphs = [
187 integrated_circular_gaussian(sigma=sigma).astype(np.float32)
188 for sigma in [0.8, 3.1, 1.1, 2.1, 1.5]
189 ]
190 # Make the second component a disk component
191 morphs[1] = scipy_convolve(morphs[1], model_psf, mode="same")
193 # Give the first two components the same center, and unique centers
194 # for the remaining sources
195 centers = [
196 (10, 12),
197 (10, 12),
198 (20, 23),
199 (20, 10),
200 (25, 20),
201 ]
203 # Create the Observation
204 test_data = ObservationData(bands, psfs, spectra, morphs, centers, model_psf, dtype=np.float32)
205 observation = Observation(
206 test_data.convolved,
207 variance,
208 weights,
209 psfs,
210 model_psf[None],
211 bands=bands,
212 )
214 fit_spectra = multifit_spectra(observation, test_data.morphs)
215 self.assertEqual(fit_spectra.dtype, spectra.dtype)
216 assert_almost_equal(fit_spectra, spectra, decimal=5)
218 def test_psf_component_at_boundary(self):
219 """``get_psf_component`` must extract the surviving region of
220 the model PSF when the source center is close enough to the
221 observation boundary that the PSF box is clipped.
223 Audit finding I-3: the original code created the Image with
224 ``yx0=bbox.origin`` (the *intersection's* origin) instead of
225 the original PSF origin, so ``[bbox]`` returned the top-left
226 of the PSF rather than the portion of the PSF that survived
227 the clip. The PSF was therefore spatially misaligned with the
228 actual source center.
229 """
230 init = FactorizedInitialization(self.observation, self.centers)
231 model_psf = self.observation.model_psf[0]
233 # Case 1: positive psf_bbox.origin. Center at the top-left
234 # corner of the observation (origin (1000, 2000)): the 15x15
235 # model PSF (py=px=7) extends 7 rows above and 3 columns to
236 # the left of the observation bbox, so 7 rows and 3 columns
237 # are clipped. psf_bbox.origin = (993, 1997) — both positive.
238 center = (1000, 2004)
239 component = init.get_psf_component(center)
240 self.assertBoxEqual(component.bbox, Box((8, 12), origin=(1000, 2000)))
241 # The surviving region is psf[7:15, 3:15] — the bottom-right
242 # of the PSF, not the top-left.
243 assert_array_equal(component.morph, model_psf[7:15, 3:15])
245 # Case 2: negative psf_bbox.origin. Build a synthetic
246 # observation at origin (0, 0) and place a center near the
247 # corner so psf_bbox.origin = (-5, -2) — both negative. The
248 # negative-origin path must still produce the correct
249 # surviving region. ``Box.slices`` rejects negative origins,
250 # so this also guards against future refactors that would
251 # call ``.slices`` on ``psf_bbox`` directly.
252 bands = ("r",)
253 shape = (30, 30)
254 images = np.ones((1,) + shape, dtype=np.float32)
255 variance = np.ones((1,) + shape, dtype=np.float32)
256 psfs = np.array([integrated_circular_gaussian(sigma=1.0)], dtype=np.float32)
257 small_model_psf = integrated_circular_gaussian(sigma=0.8).astype(np.float32)
258 small_obs = Observation(
259 Image(images, bands=bands, yx0=(0, 0)),
260 Image(variance, bands=bands, yx0=(0, 0)),
261 Image(1 / variance, bands=bands, yx0=(0, 0)),
262 psfs,
263 small_model_psf[None],
264 bands=bands,
265 )
266 small_init = FactorizedInitialization(small_obs, [(2, 5)])
267 component = small_init.get_psf_component((2, 5))
268 self.assertBoxEqual(component.bbox, Box((10, 13), origin=(0, 0)))
269 assert_array_equal(component.morph, small_model_psf[5:15, 2:15])
271 # Case 3: PSF footprint does not overlap the observation at
272 # all -> raise an informative error rather than silently
273 # producing a degenerate component.
274 with self.assertRaises(ValueError):
275 small_init.get_psf_component((-100, -100))
277 def test_get_psf_component_zero_psf_spectrum(self):
278 """``get_psf_component`` must produce a finite spectrum even
279 when one of the per-band ``psf_spectrum`` values is zero.
281 Audit finding I-10: dividing by ``self.psf_spectrum`` produces
282 ``inf`` (or ``nan`` for 0/0) for any band with a degenerate
283 model PSF whose central value is zero. The trailing
284 ``spectrum[spectrum < 0] = 0`` mask does not catch either.
285 Same fix pattern as I-4.
286 """
287 init = FactorizedInitialization(self.observation, self.centers)
288 # Force one band's psf_spectrum to zero. Real model PSFs always
289 # have a positive central pixel, but a degenerate PSF could
290 # exhibit this — and the masked branch should still be finite.
291 init.psf_spectrum[0] = 0
292 center = (int(self.centers[0][0]), int(self.centers[0][1]))
293 component = init.get_psf_component(center)
294 np.testing.assert_array_equal(np.isfinite(component.spectrum), True)
295 np.testing.assert_array_equal(component.spectrum >= 0, True)
296 # The zero-psf_spectrum band must yield zero flux rather than
297 # inf or a saturating large value.
298 self.assertEqual(component.spectrum[0], 0)
300 def test_get_single_component_zero_convolved(self):
301 """``get_single_component`` must produce a finite spectrum
302 even when the convolved detection image is zero at the source
303 center.
305 Audit finding I-4: ``spectrum = images / convolved`` produces
306 ``inf`` (or ``nan`` for 0/0) at zero-convolved pixels, and
307 the subsequent ``spectrum[spectrum < 0] = 0`` does not catch
308 either, so the component is initialized with non-finite
309 flux.
310 """
311 init = FactorizedInitialization(self.observation, self.centers)
312 center = (int(self.centers[0][0]), int(self.centers[0][1]))
313 local_center = (
314 center[0] - init.observation.bbox.origin[0],
315 center[1] - init.observation.bbox.origin[1],
316 )
317 # Force the convolved detection image to zero at the source
318 # center across all bands. The detection image still has flux
319 # at this pixel, so ``init_monotonic_morph`` returns a valid
320 # morph and the spectrum branch is exercised.
321 init.convolved.data[:, local_center[0], local_center[1]] = 0
323 thresh = np.mean(self.observation.noise_rms) * init.initial_bg_thresh
324 component = init.get_single_component(center, init.detect.copy(), thresh, init.padding)
325 assert component is not None
326 np.testing.assert_array_equal(np.isfinite(component.spectrum), True)
327 np.testing.assert_array_equal(component.spectrum >= 0, True)
329 def test_factorized_chi2_init(self):
330 # Test default parameters
331 init = FactorizedInitialization(self.observation, self.centers)
332 self.assertEqual(init.observation, self.observation)
333 self.assertEqual(init.min_snr, 50)
334 self.assertIsNone(init.monotonicity)
335 self.assertEqual(init.disk_percentile, 25)
336 self.assertEqual(init.thresh, 0.5)
337 self.assertTupleEqual((init.py, init.px), (7, 7))
338 self.assertEqual(len(init.sources), 7)
339 for src in init.sources:
340 self.assertEqual(src.get_model().dtype, np.float32)
342 centers = tuple(tuple(center.astype(int)) for center in self.centers) + ((1000, 2004),)
343 init = FactorizedInitialization(self.observation, centers)
344 self.assertEqual(len(init.sources), 8)
345 for src in init.sources:
346 self.assertEqual(src.get_model().dtype, np.float32)
348 @deprecated(
349 version="v29.0",
350 reason="FactorizedWaveletInitialization is deprecated and will be removed after v29.0",
351 )
352 def test_wavelet_init_source_falls_back_to_psf(self):
353 """init_source must always return a Source with at least one
354 component, even when individual init paths fail.
356 Audit finding I-1: when get_single_component returned None or
357 the two-component path produced all-zero spectra, ``components``
358 was either left unbound (UnboundLocalError) or set to an empty
359 list. Both cases must now fall back to a PSF component.
360 """
361 init = FactorizedWaveletInitialization(self.observation, self.centers)
362 int_centers = [(int(round(c[0])), int(round(c[1]))) for c in self.centers]
364 # Failure mode 1: get_single_component always returns None.
365 # Centers hitting the single-component branch or the
366 # two-component fallback must fall back to PSF rather than
367 # raising or producing an empty Source.
368 with patch.object(FactorizedWaveletInitialization, "get_single_component", return_value=None):
369 for center in int_centers:
370 source = init.init_source(center)
371 self.assertGreater(len(source.components), 0)
373 # Failure mode 2: two-component path returns all-zero spectra
374 # for both bulge and disk -> empty components list, also a fall
375 # back case.
376 n_bands = len(self.observation.bands)
377 with patch(
378 "lsst.scarlet.lite.initialization.multifit_spectra",
379 return_value=np.zeros((2, n_bands), dtype=np.float32),
380 ):
381 for center in int_centers:
382 source = init.init_source(center)
383 self.assertGreater(len(source.components), 0)
385 @deprecated(
386 version="v29.0",
387 reason="FactorizedWaveletInitialization is deprecated and will be removed after v29.0",
388 )
389 def test_factorized_wavelet_init(self):
390 # Test default parameters
391 init = FactorizedWaveletInitialization(self.observation, self.centers)
392 self.assertEqual(init.observation, self.observation)
393 self.assertEqual(init.min_snr, 50)
394 self.assertIsNone(init.monotonicity)
395 self.assertTupleEqual((init.py, init.px), (7, 7))
396 self.assertEqual(len(init.sources), 7)
397 components = np.sum([len(src.components) for src in init.sources])
398 self.assertEqual(components, 8)
399 for src in init.sources:
400 self.assertEqual(src.get_model().dtype, np.float32)