Coverage for tests/test_component.py: 16%

258 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 08:24 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24from abc import ABC 

25from copy import deepcopy 

26from typing import Any, Callable 

27 

28import numpy as np 

29from lsst.scarlet.lite import Box, Image, Parameter 

30from lsst.scarlet.lite.component import ( 

31 Component, 

32 CubeComponent, 

33 FactorizedComponent, 

34 default_adaprox_parameterization, 

35 default_fista_parameterization, 

36) 

37from lsst.scarlet.lite.operators import Monotonicity 

38from lsst.scarlet.lite.utils import integrated_circular_gaussian 

39from numpy.testing import assert_almost_equal, assert_array_equal 

40from utils import ScarletTestCase 

41 

42 

class DummyComponent(Component):
    """Minimal ``Component`` subclass used as a test double.

    Every method except ``parameterize`` is an empty stub that returns
    ``None``; ``parameterize`` simply applies the supplied callable to
    this instance.
    """

    def resize(self) -> bool:
        """Stub: does nothing and returns ``None``."""

    def update(self, it: int, input_grad: np.ndarray):
        """Stub: ignores its arguments and returns ``None``."""

    def get_model(self) -> Image:
        """Stub: does nothing and returns ``None``."""

    def parameterize(self, parameterization: Callable) -> None:
        """Call ``parameterization`` with this component as its argument."""
        parameterization(self)

    def to_data(self) -> DummyComponent:
        """Stub: does nothing and returns ``None``."""

    def __getitem__(self, indices: Any) -> DummyComponent:
        """Stub: ignores ``indices`` and returns ``None``."""

    def __copy__(self) -> DummyComponent:
        """Stub: does nothing and returns ``None``."""

    def __deepcopy__(self, memo: dict[int, Any]) -> DummyComponent:
        """Stub: ignores ``memo`` and returns ``None``."""

67 

68 

class _ComponentTestBase(ABC):
    """Band-indexing tests shared by the concrete component test cases.

    Subclasses are expected to also derive from a ``unittest``-style test
    case (the ``assert*`` methods are used below) and to assign
    ``self.component`` in ``setUp``. The tests index that component with
    the band names "g", "r" and "i", which must map to model planes 0, 1
    and 2 respectively.
    """

    def test_slice(self):
        """A band slice keeps the selected bands and their model planes."""
        component = self.component
        component_sliced = component["g":"r"]
        self.assertTupleEqual(component_sliced.bands, ("g", "r"))
        np.testing.assert_array_equal(component_sliced.get_model(), component.get_model().data[0:2])

    def test_reorder(self):
        """Bands can be reordered with a tuple or a concatenated string."""
        component = self.component
        indices = ("i", "g", "r")

        # Index with an explicit tuple of band names.
        component_reordered = component[indices]
        self.assertTupleEqual(component_reordered.bands, indices)
        np.testing.assert_array_equal(
            component_reordered.get_model(),
            component.get_model().data[(2, 0, 1),],
        )

        # Index with the single-character band names joined into a string.
        component_reordered = component["igr"]
        self.assertTupleEqual(component_reordered.bands, indices)
        np.testing.assert_array_equal(
            component_reordered.get_model(),
            component.get_model().data[(2, 0, 1),],
        )

    def test_subset(self):
        """A single band name selects a one-band component."""
        component = self.component
        indices = ("r",)
        component_subset = component["r"]
        self.assertTupleEqual(component_subset.bands, indices)
        np.testing.assert_array_equal(
            component_subset.get_model(),
            component.get_model().data[1:2,],
        )

        # With multi-character band names, the string "ab" must be read as
        # one band rather than as the two bands "a" and "b".
        component = self.component.copy(deep=True)
        component._bands = ("ab", "cd", "ef")
        indices = "ab"
        component_reordered = component[indices]
        self.assertTupleEqual(component_reordered.bands, (indices,))
        np.testing.assert_array_equal(
            component_reordered.get_model(),
            component.get_model().data[0:1,],
        )

    def test_indexing_errors(self):
        """Unsupported or unknown indices raise ``IndexError``."""
        component = self.component

        # "z" is not one of the component's bands, whether it is used
        # alone, as a slice endpoint, or inside a tuple of bands.
        with self.assertRaises(IndexError):
            component["z"]

        with self.assertRaises(IndexError):
            component["r":"z"]

        with self.assertRaises(IndexError):
            component["z":"i"]

        with self.assertRaises(IndexError):
            component["g", "z", "i"]

        # Box and spatial-slice indices are expected to be rejected.
        with self.assertRaises(IndexError):
            component[Box((0, 0), (10, 10))]

        with self.assertRaises(IndexError):
            component[:, 10:20, 10:20]

        # Integer-based indices are expected to be rejected as well.
        with self.assertRaises(IndexError):
            component[1:]

        with self.assertRaises(IndexError):
            component[1]

        with self.assertRaises(IndexError):
            component[0, 1]

142 

143 

class TestFactorizedComponent(_ComponentTestBase, ScarletTestCase):
    """Tests for ``FactorizedComponent``.

    Covers construction, model generation, gradients, proximal
    operators, resizing, parameterization and copying, in addition to
    the band-indexing tests inherited from ``_ComponentTestBase``.
    """

    def setUp(self) -> None:
        """Create a three-band component with a 4x5 morphology."""
        self.bands = ("g", "r", "i")
        self.spectrum = np.arange(3).astype(np.float32)
        self.morph = np.arange(20).reshape(4, 5).astype(np.float32)
        self.full_shape = (3, 100, 100)
        self.model_box = Box((100, 100))

        self.component = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            Box((4, 5), (22, 31)),
            (24, 33),
        )

    def test_constructor(self):
        """Defaults are applied and explicit arguments are stored."""
        # Only the required arguments are supplied here.
        minimal = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            self.component.bbox,
        )

        self.assertIsInstance(minimal._spectrum, Parameter)
        assert_array_equal(minimal.spectrum, self.spectrum)
        self.assertIsInstance(minimal._morph, Parameter)
        assert_array_equal(minimal.morph, self.morph)
        self.assertBoxEqual(minimal.bbox, self.component.bbox)
        self.assertIsNone(minimal.peak)
        self.assertIsNone(minimal.bg_rms)
        self.assertEqual(minimal.bg_thresh, 0.25)
        self.assertEqual(minimal.floor, 1e-20)
        self.assertTupleEqual(minimal.shape, (3, 4, 5))

        # Every optional positional argument is supplied here.
        peak = self.component.peak
        noise_rms = np.arange(5) / 10
        threshold = 0.9
        minimum = 1e-10

        configured = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            self.component.bbox,
            peak,
            noise_rms,
            threshold,
            minimum,
        )

        self.assertTupleEqual(configured.peak, peak)
        assert_array_equal(configured.bg_rms, noise_rms)  # type: ignore
        self.assertEqual(configured.bg_thresh, threshold)
        self.assertEqual(configured.floor, minimum)
        self.assertEqual(configured.get_model().dtype, np.float32)

    def test_get_model(self):
        """The model is the outer product of spectrum and morphology."""
        model = self.component.get_model()
        expected = self.spectrum[:, None, None] * self.morph[None, :, :]
        assert_array_equal(model, expected)

        # Adding the component to a full-size image must place it at the
        # rows/columns given by its bounding box.
        full_model = np.zeros(self.full_shape)
        full_model[:, 22:26, 31:36] = expected

        canvas = Image(np.zeros(self.full_shape), bands=self.bands)
        canvas += self.component.get_model()

        assert_array_equal(canvas.data, full_model)

    def test_gradients(self):
        """Spectrum and morphology gradients match hand-computed values."""
        component = self.component
        morph = self.morph
        spectrum = self.spectrum

        input_grad = np.array([morph, 2 * morph, 3 * morph])

        expected_spectrum_grad = np.array(
            [
                np.sum(morph**2),
                np.sum(2 * morph**2),
                np.sum(3 * morph**2),
            ]
        )
        spectrum_grad = component.grad_spectrum(input_grad, spectrum, morph)
        assert_almost_equal(spectrum_grad, expected_spectrum_grad)

        expected_morph_grad = np.sum(input_grad * spectrum[:, None, None], axis=0)
        morph_grad = component.grad_morph(input_grad, morph, spectrum)
        assert_almost_equal(morph_grad, expected_morph_grad)

    def test_proximal_operators(self):
        """Check the spectrum and morphology proximal operators.

        The first component uses a peak, background threshold and
        monotonicity operator; the second uses none of them.
        """
        raw_spectrum = np.array([-1, 2, 3], dtype=float)
        raw_morph = np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, -1]], dtype=float)
        bbox = Box((3, 3), (10, 10))
        morph_bbox = Box((100, 100))
        center = (11, 11)
        monotonicity = Monotonicity((101, 101), fit_radius=0)

        constrained = FactorizedComponent(
            self.bands,
            raw_spectrum.copy(),
            raw_morph.copy(),
            bbox,
            center,
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=monotonicity,
        )

        expected_spectrum = np.array([1e-20, 2, 3])
        expected_morph = np.array([[2.6666666666666667, 2, 1], [1, 5, 3], [0, 4, 0]])
        expected_morph = expected_morph / 5

        constrained.prox_spectrum(constrained.spectrum)
        constrained.prox_morph(constrained.morph)

        assert_array_equal(constrained.spectrum, expected_spectrum)
        assert_array_equal(constrained.morph, expected_morph)

        unconstrained = FactorizedComponent(
            self.bands,
            raw_spectrum.copy(),
            raw_morph.copy(),
            bbox,
            None,
        )

        expected_spectrum = np.array([1e-20, 2, 3])
        expected_morph = np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, 0]])
        expected_morph = expected_morph / 10

        unconstrained.prox_spectrum(unconstrained.spectrum)
        unconstrained.prox_morph(unconstrained.morph)

        assert_array_equal(unconstrained.spectrum, expected_spectrum)
        assert_array_equal(unconstrained.morph, expected_morph)

        self.assertFalse(unconstrained.resize(morph_bbox))

    def test_prox_morph_enforces_positivity_with_bg_thresh(self):
        """Negative morph pixels are zeroed even when ``bg_thresh`` is set.

        Audit finding K-2: ``prox_morph`` used to enforce positivity
        only when no ``bg_thresh`` was given. With a threshold active, a
        negative morph pixel could survive whenever
        ``spectrum * morph >= bg_thresh`` held in some band, which can
        happen when the spectrum has a negative band because the product
        of two negatives is positive. Positivity must hold in both
        cases.
        """
        bands = ("g", "r", "i")
        # The negative first band makes the model positive at the
        # negative morph pixel, so the threshold test alone does not
        # reject that pixel.
        mixed_sign_spectrum = np.array([-1.0, 1.0, 1.0])
        morph = np.full((3, 3), 0.5, dtype=float)
        morph[0, 0] = -0.3

        component = FactorizedComponent(
            bands,
            mixed_sign_spectrum,
            morph,
            Box((3, 3), (0, 0)),
            peak=None,
            bg_rms=np.array([0.1, 0.1, 0.1]),
            bg_thresh=0.5,
        )

        result = component.prox_morph(component.morph.copy())
        # The positivity guard, not the threshold, must clear this pixel.
        self.assertEqual(result[0, 0], 0)

    def test_resize(self):
        """Resizing shrinks the box around the non-zero morphology."""
        spectrum = np.array([1, 2, 3], dtype=float)
        morph = np.zeros((10, 10), dtype=float)
        morph[3:6, 5:8] = np.arange(9).reshape(3, 3)
        bbox = Box((10, 10), (3, 5))

        morph_bbox = Box((100, 100))
        monotonicity = Monotonicity((101, 101), fit_radius=0)

        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            None,
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=monotonicity,
            padding=1,
        )

        # Before resizing the morphology still has its original shape.
        self.assertTupleEqual(component.morph.shape, (10, 10))
        self.assertIsNone(component.component_center)

        component.resize(morph_bbox)
        self.assertTupleEqual(component.morph.shape, (5, 5))
        self.assertTupleEqual(component.bbox.origin, (5, 9))
        self.assertTupleEqual(component.bbox.shape, (5, 5))
        self.assertIsNone(component.component_center)

    def test_parameterization(self):
        """Each default parameterization installs its own morph helpers."""
        component = self.component
        model = component.get_model()
        assert_array_equal(model, self.spectrum[:, None, None] * self.morph[None, :, :])

        component.parameterize(default_fista_parameterization)
        fista_helpers = set(component._morph.helpers.keys())
        self.assertSetEqual(fista_helpers, {"z"})

        component.parameterize(default_adaprox_parameterization)
        adaprox_helpers = set(component._morph.helpers.keys())
        self.assertSetEqual(adaprox_helpers, {"m", "v", "vhat"})

        # Both default parameterizations must reject a component type
        # they do not know how to handle.
        dummy_args = (tuple("grizy"), Box((5, 5)))
        with self.assertRaises(NotImplementedError):
            default_fista_parameterization(DummyComponent(*dummy_args))

        with self.assertRaises(NotImplementedError):
            default_adaprox_parameterization(DummyComponent(*dummy_args))

    def test_shallow_copy(self):
        """A shallow copy is a new object that shares its attributes."""
        original = self.component
        original.monotonicity = Monotonicity((11, 11), fit_radius=0)

        duplicate = original.copy()

        self.assertIsNot(original, duplicate)
        np.testing.assert_array_equal(original._spectrum.x, duplicate._spectrum.x)
        np.testing.assert_array_equal(original._morph.x, duplicate._morph.x)
        self.assertIs(original.bbox, duplicate.bbox)
        self.assertIs(original.peak, duplicate.peak)
        self.assertIs(original.bg_thresh, duplicate.bg_thresh)
        self.assertIs(original.monotonicity, duplicate.monotonicity)

    def test_deep_copy(self):
        """A deep copy owns independent parameters, box and monotonicity."""
        original = self.component
        original.monotonicity = Monotonicity((11, 11), fit_radius=0)
        duplicate = original.copy(deep=True)

        self.assertIsNot(original, duplicate)

        # Modifying the copy's spectrum must leave the original untouched.
        np.testing.assert_array_equal(original._spectrum.x, duplicate._spectrum.x)
        duplicate._spectrum.x += 1
        with self.assertRaises(AssertionError):
            np.testing.assert_array_equal(original._spectrum.x, duplicate._spectrum.x)

        # The same must hold for the morphology.
        np.testing.assert_array_equal(original._morph.x, duplicate._morph.x)
        duplicate._morph.x += 1
        with self.assertRaises(AssertionError):
            np.testing.assert_array_equal(original._morph.x, duplicate._morph.x)

        self.assertIsNot(original.bbox, duplicate.bbox)
        self.assertBoxEqual(original.bbox, duplicate.bbox)

        self.assertTupleEqual(original.peak, duplicate.peak)
        self.assertEqual(original.bg_thresh, duplicate.bg_thresh)
        self.assertIsNot(original.monotonicity, duplicate.monotonicity)

408 

409 

class TestCubeComponent(_ComponentTestBase, ScarletTestCase):
    """Tests for ``CubeComponent`` construction and copy semantics.

    The band-indexing tests are inherited from ``_ComponentTestBase``.
    """

    def setUp(self) -> None:
        """Build a three-band cube component from a circular Gaussian."""
        super().setUp()
        self.bands = tuple("gri")
        peak = (27, 32)
        bbox = Box((15, 15), (20, 25))
        morph = integrated_circular_gaussian(sigma=0.8).astype(np.float32)
        spectrum = np.arange(3, dtype=np.float32)
        model = morph[None, :, :] * spectrum[:, None, None]
        # Keep an independent copy of the input so the constructor test
        # compares against known values rather than the component itself.
        self.expected_model = model.copy()
        model_image = Image(model, yx0=bbox.origin, bands=self.bands)
        self.component = CubeComponent(model=model_image, peak=peak)

    def test_constructor(self):
        """The component stores the model image, bands, box and peak."""
        component = self.component
        self.assertIsInstance(component._model, Image)
        np.testing.assert_array_equal(component._model.data, self.expected_model)
        self.assertTupleEqual(component.bands, self.bands)
        self.assertBoxEqual(component.bbox, Box((15, 15), (20, 25)))
        self.assertTupleEqual(component.peak, (27, 32))

    def test_shallow_copy(self):
        """A shallow copy is a new object with an equal peak and model."""
        component = self.component
        component_copy = component.copy()

        self.assertIsNot(component_copy, component)
        self.assertTupleEqual(component_copy.peak, component.peak)
        self.assertImageEqual(component_copy._model, component._model)

    def test_deep_copy(self):
        """A deep copy owns model data independent of the original."""
        component = self.component
        component_copy = component.copy(deep=True)

        self.assertIsNot(component, component_copy)

        self.assertTupleEqual(component_copy.peak, component.peak)
        self.assertImageEqual(component_copy._model, component._model)

        # Mutate outside the ``assertRaises`` block so that only the
        # comparison, not the mutation, can satisfy the expectation.
        component_copy._model._data -= 1
        with self.assertRaises(AssertionError):
            self.assertImageEqual(component_copy._model, component._model)

    def test_deep_copy_preserves_memo(self):
        """Audit finding K-6: ``__deepcopy__`` must call ``deepcopy``
        with the memo dict on its sub-objects so shared references
        in the input graph remain shared in the copy. Previously
        ``self._model.copy()`` allocated a fresh image regardless of
        whether ``self._model`` already appeared in the memo.
        """
        c1 = CubeComponent(model=self.component._model, peak=self.component.peak)
        c2 = CubeComponent(model=self.component._model, peak=self.component.peak)
        # Two CubeComponents sharing the same Image instance.
        self.assertIs(c1._model, c2._model)

        c1_copy, c2_copy = deepcopy([c1, c2])

        # The model should be different in the deepcopy,
        # since CubeComponent's deepcopy creates a new Image instance.
        self.assertIsNot(c1_copy._model, c1._model)

        # The shared Image must still be shared in the deepcopy,
        # otherwise larger object graphs that rely on identity
        # (e.g. observation sharing) silently fork.
        self.assertIs(c1_copy._model, c2_copy._model)