Coverage for tests / test_initialization.py: 13%

176 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-22 00:46 -0700

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23from unittest.mock import patch 

24 

25import numpy as np 

26from deprecated.sphinx import deprecated 

27from lsst.scarlet.lite import Box, Image, Observation 

28from lsst.scarlet.lite.initialization import ( 

29 FactorizedInitialization, 

30 FactorizedWaveletInitialization, 

31 init_monotonic_morph, 

32 multifit_spectra, 

33 trim_morphology, 

34) 

35from lsst.scarlet.lite.operators import Monotonicity, prox_monotonic_mask 

36from lsst.scarlet.lite.utils import integrated_circular_gaussian 

37from numpy.testing import assert_almost_equal, assert_array_equal 

38from scipy.signal import convolve as scipy_convolve 

39from utils import ObservationData, ScarletTestCase 

40 

41 

class TestInitialization(ScarletTestCase):
    """Tests for the ``lsst.scarlet.lite.initialization`` helpers and the
    factorized initialization classes.
    """

    def setUp(self) -> None:
        """Load the ``hsc_cosmos_35`` test blend and build an ``Observation``.

        The observation is placed at ``yx0 = (1000, 2000)`` so that the tests
        exercise non-zero bounding-box origins; ``self.centers`` is shifted
        by the same offset.
        """
        yx0 = (1000, 2000)
        # The ``data`` directory sits one level above the directory that
        # contains this test module.
        filename = os.path.join(os.path.dirname(__file__), "..", "data", "hsc_cosmos_35.npz")
        filename = os.path.abspath(filename)
        # ``np.load`` on an .npz archive returns an object holding an open
        # file handle; use it as a context manager so the handle is released
        # once every array has been read into memory.
        with np.load(filename) as data:
            images = data["images"]
            variance = data["variance"]
            psfs = data["psfs"]
            catalog = data["catalog"]
            bands = data["filters"]
        model_psf = integrated_circular_gaussian(sigma=0.8)
        self.detect = np.sum(images, axis=0)
        self.centers = np.array([catalog["y"], catalog["x"]]).T + np.array(yx0)
        self.observation = Observation(
            Image(images, bands=bands, yx0=yx0),
            Image(variance, bands=bands, yx0=yx0),
            Image(1 / variance, bands=bands, yx0=yx0),
            psfs,
            model_psf[None],
            bands=bands,
        )

    def test_trim_morphology(self) -> None:
        """``trim_morphology`` returns the tight bbox around the support of a
        morphology, optionally thresholding, without mutating its input.
        """
        # Default parameters: returns a tight bbox around the non-zero
        # support of the input.
        morph = np.zeros((50, 50)).astype(np.float32)
        morph[10:15, 12:27] = 1
        trimmed, trimmed_box = trim_morphology(morph)
        assert_array_equal(trimmed, morph)
        self.assertTupleEqual(trimmed_box.origin, (10, 12))
        self.assertTupleEqual(trimmed_box.shape, (5, 15))
        self.assertEqual(trimmed.dtype, np.float32)

        # With a threshold: pixels at or below the threshold are zeroed,
        # and the bbox is the tight box around what remains. The input
        # array must not be mutated (audit finding I-8).
        morph = np.full((50, 50), 0.1).astype(np.float32)
        morph[10:15, 12:27] = 1
        original = morph.copy()
        truth = np.zeros(morph.shape)
        truth[10:15, 12:27] = 1
        trimmed, trimmed_box = trim_morphology(morph, 0.5)
        assert_array_equal(trimmed, truth)
        assert_array_equal(morph, original)
        self.assertTupleEqual(trimmed_box.origin, (10, 12))
        self.assertTupleEqual(trimmed_box.shape, (5, 15))
        self.assertEqual(trimmed.dtype, np.float32)

    def test_init_monotonic_mask(self) -> None:
        """``init_monotonic_morph`` without a ``Monotonicity`` operator
        (the ``prox_monotonic_mask`` path).
        """
        full_box = self.observation.bbox
        center = self.centers[0]
        local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])

        # Default parameters
        bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box)
        self.assertBoxEqual(bbox, Box((38, 29), (1014, 2000)))
        _, masked_morph, _ = prox_monotonic_mask(self.detect.copy(), local_center, max_iter=0)
        assert_array_equal(morph, masked_morph / np.max(masked_morph))
        self.assertEqual(morph.dtype, np.float32)

        # Non-zero threshold AND non-zero padding. This combination
        # exercises the path-1 trim_morphology call AND the post-trim
        # padding step; if those ever double up again (audit I-2), the
        # bbox below would grow to (34, 28) at origin (1017, 2000)
        # rather than (30, 25) at (1019, 2001).
        bbox, morph = init_monotonic_morph(
            self.detect.copy(),
            center,
            full_box,
            2,  # padding
            False,  # normalize
            None,  # monotonicity
            0.2,  # threshold
        )
        self.assertBoxEqual(bbox, Box((30, 25), (1019, 2001)))
        # Remove pixels below the threshold
        truth = masked_morph.copy()
        truth[truth < 0.2] = 0
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Test an empty morphology
        bbox, morph = init_monotonic_morph(np.zeros(self.detect.shape), center, full_box)
        self.assertBoxEqual(bbox, Box((0, 0)))
        self.assertIsNone(morph)

    def test_init_monotonic_weighted(self) -> None:
        """``init_monotonic_morph`` with a weighted ``Monotonicity``
        operator (path 2).
        """
        full_box = self.observation.bbox
        center = self.centers[0]
        local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])
        monotonicity = Monotonicity((101, 101))

        # Default parameters
        bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box, monotonicity=monotonicity)
        truth = monotonicity(self.detect.copy(), local_center)
        truth[truth < 0] = 0
        truth = truth / np.max(truth)
        self.assertBoxEqual(bbox, Box((58, 48), origin=(1000, 2000)))
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Non-zero threshold AND non-zero padding. trim_morphology is
        # always called on path 2; pairing that with padding > 0 makes
        # the regression for audit I-2 (double padding) observable
        # rather than masked by clipping or padding=0.
        bbox, morph = init_monotonic_morph(
            self.detect.copy(),
            center,
            full_box,
            2,  # padding
            False,  # normalize
            monotonicity,  # monotonicity
            0.2,  # threshold
        )
        truth = monotonicity(self.detect.copy(), local_center)
        truth[truth < 0.2] = 0
        self.assertBoxEqual(bbox, Box((49, 47), origin=(1008, 2001)))
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Test zero morphology
        zeros = np.zeros(self.detect.shape)
        bbox, morph = init_monotonic_morph(zeros, center, full_box, monotonicity=monotonicity)
        self.assertBoxEqual(bbox, Box((0, 0), (1000, 2000)))
        self.assertIsNone(morph)

    def test_multifit_spectra(self) -> None:
        """``multifit_spectra`` recovers known spectra from a synthetic
        blend of overlapping components.
        """
        bands = ("g", "r", "i")
        variance = np.ones((3, 35, 35), dtype=np.float32)
        weights = 1 / variance
        psfs = np.array([integrated_circular_gaussian(sigma=sigma) for sigma in [1.05, 0.9, 1.2]])
        psfs = psfs.astype(np.float32)
        model_psf = integrated_circular_gaussian(sigma=0.8).astype(np.float32)

        # The spectrum of each source
        spectra = np.array(
            [
                [31, 10, 0],
                [0, 5, 20],
                [15, 8, 3],
                [20, 3, 4],
                [0, 30, 60],
            ],
            dtype=np.float32,
        )

        # Use circular Gaussian morphologies of varying widths for all of
        # the sources
        morphs = [
            integrated_circular_gaussian(sigma=sigma).astype(np.float32)
            for sigma in [0.8, 3.1, 1.1, 2.1, 1.5]
        ]
        # Make the second component a disk component
        morphs[1] = scipy_convolve(morphs[1], model_psf, mode="same")

        # Give the first two components the same center, and unique centers
        # for the remaining sources
        centers = [
            (10, 12),
            (10, 12),
            (20, 23),
            (20, 10),
            (25, 20),
        ]

        # Create the Observation
        test_data = ObservationData(bands, psfs, spectra, morphs, centers, model_psf, dtype=np.float32)
        observation = Observation(
            test_data.convolved,
            variance,
            weights,
            psfs,
            model_psf[None],
            bands=bands,
        )

        fit_spectra = multifit_spectra(observation, test_data.morphs)
        self.assertEqual(fit_spectra.dtype, spectra.dtype)
        assert_almost_equal(fit_spectra, spectra, decimal=5)

    def test_psf_component_at_boundary(self) -> None:
        """``get_psf_component`` must extract the surviving region of
        the model PSF when the source center is close enough to the
        observation boundary that the PSF box is clipped.

        Audit finding I-3: the original code created the Image with
        ``yx0=bbox.origin`` (the *intersection's* origin) instead of
        the original PSF origin, so ``[bbox]`` returned the top-left
        of the PSF rather than the portion of the PSF that survived
        the clip. The PSF was therefore spatially misaligned with the
        actual source center.
        """
        init = FactorizedInitialization(self.observation, self.centers)
        model_psf = self.observation.model_psf[0]

        # Case 1: positive psf_bbox.origin. Center just inside the
        # top-left corner of the observation (origin (1000, 2000)): the
        # 15x15 model PSF (py=px=7) extends 7 rows above and 3 columns
        # to the left of the observation bbox, so 7 rows and 3 columns
        # are clipped. psf_bbox.origin = (993, 1997) — both positive.
        center = (1000, 2004)
        component = init.get_psf_component(center)
        self.assertBoxEqual(component.bbox, Box((8, 12), origin=(1000, 2000)))
        # The surviving region is psf[7:15, 3:15] — the bottom-right
        # of the PSF, not the top-left.
        assert_array_equal(component.morph, model_psf[7:15, 3:15])

        # Case 2: negative psf_bbox.origin. Build a synthetic
        # observation at origin (0, 0) and place a center near the
        # corner so psf_bbox.origin = (-5, -2) — both negative. The
        # negative-origin path must still produce the correct
        # surviving region. ``Box.slices`` rejects negative origins,
        # so this also guards against future refactors that would
        # call ``.slices`` on ``psf_bbox`` directly.
        bands = ("r",)
        shape = (30, 30)
        images = np.ones((1,) + shape, dtype=np.float32)
        variance = np.ones((1,) + shape, dtype=np.float32)
        psfs = np.array([integrated_circular_gaussian(sigma=1.0)], dtype=np.float32)
        small_model_psf = integrated_circular_gaussian(sigma=0.8).astype(np.float32)
        small_obs = Observation(
            Image(images, bands=bands, yx0=(0, 0)),
            Image(variance, bands=bands, yx0=(0, 0)),
            Image(1 / variance, bands=bands, yx0=(0, 0)),
            psfs,
            small_model_psf[None],
            bands=bands,
        )
        small_init = FactorizedInitialization(small_obs, [(2, 5)])
        component = small_init.get_psf_component((2, 5))
        self.assertBoxEqual(component.bbox, Box((10, 13), origin=(0, 0)))
        assert_array_equal(component.morph, small_model_psf[5:15, 2:15])

        # Case 3: PSF footprint does not overlap the observation at
        # all -> raise an informative error rather than silently
        # producing a degenerate component.
        with self.assertRaises(ValueError):
            small_init.get_psf_component((-100, -100))

    def test_get_psf_component_zero_psf_spectrum(self) -> None:
        """``get_psf_component`` must produce a finite spectrum even
        when one of the per-band ``psf_spectrum`` values is zero.

        Audit finding I-10: dividing by ``self.psf_spectrum`` produces
        ``inf`` (or ``nan`` for 0/0) for any band with a degenerate
        model PSF whose central value is zero. The trailing
        ``spectrum[spectrum < 0] = 0`` mask does not catch either.
        Same fix pattern as I-4.
        """
        init = FactorizedInitialization(self.observation, self.centers)
        # Force one band's psf_spectrum to zero. Real model PSFs always
        # have a positive central pixel, but a degenerate PSF could
        # exhibit this — and the masked branch should still be finite.
        init.psf_spectrum[0] = 0
        center = (int(self.centers[0][0]), int(self.centers[0][1]))
        component = init.get_psf_component(center)
        assert_array_equal(np.isfinite(component.spectrum), True)
        assert_array_equal(component.spectrum >= 0, True)
        # The zero-psf_spectrum band must yield zero flux rather than
        # inf or a saturating large value.
        self.assertEqual(component.spectrum[0], 0)

    def test_get_single_component_zero_convolved(self) -> None:
        """``get_single_component`` must produce a finite spectrum
        even when the convolved detection image is zero at the source
        center.

        Audit finding I-4: ``spectrum = images / convolved`` produces
        ``inf`` (or ``nan`` for 0/0) at zero-convolved pixels, and
        the subsequent ``spectrum[spectrum < 0] = 0`` does not catch
        either, so the component is initialized with non-finite
        flux.
        """
        init = FactorizedInitialization(self.observation, self.centers)
        center = (int(self.centers[0][0]), int(self.centers[0][1]))
        local_center = (
            center[0] - init.observation.bbox.origin[0],
            center[1] - init.observation.bbox.origin[1],
        )
        # Force the convolved detection image to zero at the source
        # center across all bands. The detection image still has flux
        # at this pixel, so ``init_monotonic_morph`` returns a valid
        # morph and the spectrum branch is exercised.
        init.convolved.data[:, local_center[0], local_center[1]] = 0

        thresh = np.mean(self.observation.noise_rms) * init.initial_bg_thresh
        component = init.get_single_component(center, init.detect.copy(), thresh, init.padding)
        # Use the TestCase assertion rather than a bare ``assert``, which
        # is stripped when Python runs with ``-O``.
        self.assertIsNotNone(component)
        assert_array_equal(np.isfinite(component.spectrum), True)
        assert_array_equal(component.spectrum >= 0, True)

    def test_factorized_chi2_init(self) -> None:
        """``FactorizedInitialization`` defaults and source counts,
        including an extra center at the observation edge.
        """
        # Test default parameters
        init = FactorizedInitialization(self.observation, self.centers)
        self.assertEqual(init.observation, self.observation)
        self.assertEqual(init.min_snr, 50)
        self.assertIsNone(init.monotonicity)
        self.assertEqual(init.disk_percentile, 25)
        self.assertEqual(init.thresh, 0.5)
        self.assertTupleEqual((init.py, init.px), (7, 7))
        self.assertEqual(len(init.sources), 7)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)

        centers = tuple(tuple(center.astype(int)) for center in self.centers) + ((1000, 2004),)
        init = FactorizedInitialization(self.observation, centers)
        self.assertEqual(len(init.sources), 8)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)

    @deprecated(
        version="v29.0",
        reason="FactorizedWaveletInitialization is deprecated and will be removed after v29.0",
    )
    def test_wavelet_init_source_falls_back_to_psf(self) -> None:
        """init_source must always return a Source with at least one
        component, even when individual init paths fail.

        Audit finding I-1: when get_single_component returned None or
        the two-component path produced all-zero spectra, ``components``
        was either left unbound (UnboundLocalError) or set to an empty
        list. Both cases must now fall back to a PSF component.
        """
        init = FactorizedWaveletInitialization(self.observation, self.centers)
        int_centers = [(int(round(c[0])), int(round(c[1]))) for c in self.centers]

        # Failure mode 1: get_single_component always returns None.
        # Centers hitting the single-component branch or the
        # two-component fallback must fall back to PSF rather than
        # raising or producing an empty Source.
        with patch.object(FactorizedWaveletInitialization, "get_single_component", return_value=None):
            for center in int_centers:
                source = init.init_source(center)
                self.assertGreater(len(source.components), 0)

        # Failure mode 2: two-component path returns all-zero spectra
        # for both bulge and disk -> empty components list, also a fall
        # back case.
        n_bands = len(self.observation.bands)
        with patch(
            "lsst.scarlet.lite.initialization.multifit_spectra",
            return_value=np.zeros((2, n_bands), dtype=np.float32),
        ):
            for center in int_centers:
                source = init.init_source(center)
                self.assertGreater(len(source.components), 0)

    @deprecated(
        version="v29.0",
        reason="FactorizedWaveletInitialization is deprecated and will be removed after v29.0",
    )
    def test_factorized_wavelet_init(self) -> None:
        """``FactorizedWaveletInitialization`` defaults and the resulting
        source/component counts.
        """
        # Test default parameters
        init = FactorizedWaveletInitialization(self.observation, self.centers)
        self.assertEqual(init.observation, self.observation)
        self.assertEqual(init.min_snr, 50)
        self.assertIsNone(init.monotonicity)
        self.assertTupleEqual((init.py, init.px), (7, 7))
        self.assertEqual(len(init.sources), 7)
        components = np.sum([len(src.components) for src in init.sources])
        self.assertEqual(components, 8)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)