Coverage for tests / test_detect.py: 13%

159 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-22 00:46 -0700

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23 

24import numpy as np 

25from lsst.scarlet.lite import Box, Image 

26from lsst.scarlet.lite.detect import ( 

27 bbox_to_bounds, 

28 bounds_to_bbox, 

29 detect_footprints, 

30 footprints_to_image, 

31 get_detect_wavelets, 

32 get_wavelets, 

33) 

34from lsst.scarlet.lite.detect_pybind11 import ( 

35 Footprint, 

36 Peak, 

37 get_connected_multipeak, 

38 get_connected_pixels, 

39 get_footprints, 

40) 

41from lsst.scarlet.lite.utils import integrated_circular_gaussian 

42from numpy.testing import assert_array_equal 

43from utils import ScarletTestCase 

44 

45 

class TestDetect(ScarletTestCase):
    """Tests for the footprint/peak detection helpers in
    ``lsst.scarlet.lite.detect`` and ``lsst.scarlet.lite.detect_pybind11``.
    """

    def setUp(self):
        """Build a 51x51 test image from four Gaussian sources and load
        the HSC COSMOS sample data used by the wavelet tests.
        """
        centers = (
            (17, 9),
            (27, 14),
            (41, 25),
            (10, 42),
        )
        sigmas = (1.0, 0.95, 0.9, 1.5)

        sources = []
        for sigma, center in zip(sigmas, centers):
            # Each source stamp is placed with its origin offset by (7, 7)
            # from the requested center.
            yx0 = center[0] - 7, center[1] - 7
            source = Image(integrated_circular_gaussian(sigma=sigma).astype(np.float32), yx0=yx0)
            sources.append(source)

        image = Image.from_box(Box((51, 51)))
        for source in sources:
            image += source
        # A short strip of pixels containing no source center;
        # ``test_connected`` expects the multi-peak search to exclude it.
        image.data[30:32, 40] = 0.5

        self.image = image
        self.centers = centers
        self.sources = sources

        # Resolve ``../data/hsc_cosmos_35.npz`` relative to the directory
        # containing this test file.
        filename = os.path.join(os.path.dirname(__file__), "..", "data", "hsc_cosmos_35.npz")
        filename = os.path.abspath(filename)
        self.hsc_data = np.load(filename)

    def tearDown(self):
        """Release the sample data loaded in ``setUp``."""
        # ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that keeps
        # the archive open; close it explicitly instead of relying on
        # garbage collection.
        self.hsc_data.close()
        del self.hsc_data

    def test_connected(self):
        """Check ``get_connected_pixels`` and ``get_connected_multipeak``
        at different thresholds.
        """
        image = self.image.copy()

        # Check that the first 3 footprints are all connected
        # with thresholding at zero
        truth = self.sources[0] + self.sources[1] + self.sources[2]
        bbox = truth.bbox
        truth = truth.data > 0

        unchecked = np.ones(self.image.shape, dtype=bool)
        footprint = np.zeros(self.image.shape, dtype=bool)
        y, x = self.centers[0]
        get_connected_pixels(
            y,
            x,
            image.data,
            unchecked,
            footprint,
            np.array([y, y, x, x], dtype=np.int32),
            0,
        )
        assert_array_equal(footprint[bbox.slices], truth)

        # Check that only the first 2 footprints are all connected
        # with thresholding at 1e-15
        truth = self.sources[0] + self.sources[1]
        bbox = truth.bbox
        truth = truth.data > 1e-15

        unchecked = np.ones(self.image.shape, dtype=bool)
        footprint = np.zeros(self.image.shape, dtype=bool)
        y, x = self.centers[0]
        get_connected_pixels(
            y,
            x,
            image.data,
            unchecked,
            footprint,
            np.array([y, y, x, x], dtype=np.int32),
            1e-15,
        )
        assert_array_equal(footprint[bbox.slices], truth)

        # Test finding all peaks: the isolated strip added in ``setUp``
        # contains no peak and must not be part of the result.
        footprint = get_connected_multipeak(self.image.data, self.centers, 1e-15)
        truth = self.image.data > 1e-15
        truth[30:32, 40] = False
        assert_array_equal(footprint, truth)

    def _check_footprints(self, footprints):
        """Assert that ``footprints`` matches the three expected footprints
        built from ``self.sources`` (order, masks, peaks and bounding boxes).
        """
        self.assertEqual(len(footprints), 3)

        # The first footprint has a single peak
        assert_array_equal(footprints[0].data, self.sources[3].data > 1e-15)
        self.assertEqual(len(footprints[0].peaks), 1)
        self.assertBoxEqual(footprints[0].bbox, self.sources[3].bbox)
        self.assertEqual(footprints[0].peaks[0].y, self.centers[3][0])
        self.assertEqual(footprints[0].peaks[0].x, self.centers[3][1])

        # The second footprint has two peaks
        truth = self.sources[0] + self.sources[1]
        assert_array_equal(footprints[1].data, truth.data > 1e-15)
        self.assertEqual(len(footprints[1].peaks), 2)
        self.assertBoxEqual(footprints[1].bbox, truth.bbox)
        self.assertEqual(footprints[1].peaks[0].y, self.centers[1][0])
        self.assertEqual(footprints[1].peaks[0].x, self.centers[1][1])
        self.assertEqual(footprints[1].peaks[1].y, self.centers[0][0])
        self.assertEqual(footprints[1].peaks[1].x, self.centers[0][1])

        # The third footprint has a single peak
        assert_array_equal(footprints[2].data, self.sources[2].data > 1e-15)
        self.assertEqual(len(footprints[2].peaks), 1)
        self.assertBoxEqual(footprints[2].bbox, self.sources[2].bbox)
        self.assertEqual(footprints[2].peaks[0].y, self.centers[2][0])
        self.assertEqual(footprints[2].peaks[0].x, self.centers[2][1])

        truth = 1 * self.sources[3] + 2 * (self.sources[0] + self.sources[1]) + 3 * self.sources[2]
        truth.data[truth.data < 1e-15] = 0
        fp_image = footprints_to_image(footprints, truth.bbox)
        assert_array_equal(fp_image, truth.data)

    def test_get_footprints(self):
        """``get_footprints`` on the synthetic image yields the expected
        three footprints.
        """
        footprints = get_footprints(self.image.data, 1, 4, 1e-15, 1e-15, True)
        self._check_footprints(footprints)

    def test_get_footprints_min_area_boundary(self):
        """A footprint that exactly meets ``min_area`` in a tight
        bounding box of the same area must be kept.

        Audit finding D-1: the C++ pre-filter on the bounding-box
        area used strict ``>`` while the actual pixel count check
        uses ``>=``. A 2x2 filled square with ``min_area=4`` was
        therefore rejected by the pre-filter (4 > 4 is false) before
        the true area check (4 >= 4) ever ran.
        """
        img = np.zeros((10, 10), dtype=np.float32)
        img[3:5, 5:7] = 1.0  # 2x2 filled square: area=4, bbox=2x2=4

        # ``find_peaks=False`` to keep the test focused on the
        # min_area logic; with ``find_peaks=True`` an ambiguous
        # plateau can fail the peak check for unrelated reasons.
        footprints = get_footprints(img, 1, 4, 1e-15, 1e-15, False)
        self.assertEqual(len(footprints), 1)

        # Sanity: with ``min_area=5`` the same footprint must be
        # rejected, confirming the boundary is tight.
        footprints = get_footprints(img, 1, 5, 1e-15, 1e-15, False)
        self.assertEqual(len(footprints), 0)

    def _check_peaks(self, peaks):
        """Assert that every entry of ``self.centers`` has a peak in
        ``peaks`` at exactly that (y, x) position.
        """
        matched_peaks = []
        for center in self.centers:
            for peak in peaks:
                if peak.y == center[0] and peak.x == center[1]:
                    matched_peaks.append(peak)
                    break
        self.assertEqual(len(matched_peaks), len(self.centers))

    def test_detect_footprints(self):
        """``detect_footprints`` recovers every source peak, with and
        without high-frequency removal.
        """
        # This method doesn't test for accuracy, since
        # there is no variance, so we set it to ones.
        variance = np.ones(self.image.shape, dtype=self.image.dtype)

        footprints = detect_footprints(
            self.image.data[None, :, :],
            variance[None, :, :],
            scales=1,
            generation=2,
            origin=(0, 0),
            min_separation=1,
            min_area=4,
            peak_thresh=1e-15,
            footprint_thresh=1e-15,
            find_peaks=True,
            remove_high_freq=False,
            min_pixel_detect=1,
        )

        self.assertEqual(len(footprints), 3)
        peaks = [peak for footprint in footprints for peak in footprint.peaks]
        self._check_peaks(peaks)

        footprints = detect_footprints(
            self.image.data[None, :, :],
            variance[None, :, :],
            scales=1,
            generation=1,
            min_separation=1,
            min_area=4,
            peak_thresh=1e-15,
            footprint_thresh=1e-15,
            find_peaks=True,
            remove_high_freq=True,
            min_pixel_detect=1,
        )

        self.assertEqual(len(footprints), 2)
        peaks = [peak for footprint in footprints for peak in footprint.peaks]
        self._check_peaks(peaks)

    def test_detect_footprints_min_pixel_detect(self):
        """``min_pixel_detect`` requires the detection pixel to be
        above zero in at least N bands. Verify both that single-band
        input is filtered out entirely when ``min_pixel_detect=2``,
        and that multi-band input filters selectively.
        """
        variance = np.ones(self.image.shape, dtype=self.image.dtype)

        # Single-band: with min_pixel_detect=2 every pixel fails the
        # "at least 2 bands above 0" check, so nothing survives.
        footprints = detect_footprints(
            self.image.data[None, :, :],
            variance[None, :, :],
            scales=1,
            generation=2,
            origin=(0, 0),
            min_separation=1,
            min_area=4,
            peak_thresh=1e-15,
            footprint_thresh=1e-15,
            find_peaks=True,
            remove_high_freq=False,
            min_pixel_detect=2,
        )
        self.assertEqual(len(footprints), 0)

        # Two-band: band 0 only contains sources 0+1, band 1 only
        # contains sources 2+3. With min_pixel_detect=2 no pixel is
        # above zero in *both* bands, so nothing survives.
        band0 = self.sources[0] + self.sources[1]
        band1 = self.sources[2] + self.sources[3]
        full = Image.from_box(Box((51, 51)))
        b0 = (full + band0).data
        b1 = (full + band1).data
        images = np.stack([b0, b1])
        variance2 = np.ones(images.shape, dtype=images.dtype)
        footprints = detect_footprints(
            images,
            variance2,
            scales=1,
            generation=2,
            origin=(0, 0),
            min_separation=1,
            min_area=4,
            peak_thresh=1e-15,
            footprint_thresh=1e-15,
            find_peaks=True,
            remove_high_freq=False,
            min_pixel_detect=2,
        )
        self.assertEqual(len(footprints), 0)

        # Sanity: with min_pixel_detect=1 the same multi-band input
        # produces the union of both bands' footprints.
        footprints = detect_footprints(
            images,
            variance2,
            scales=1,
            generation=2,
            origin=(0, 0),
            min_separation=1,
            min_area=4,
            peak_thresh=1e-15,
            footprint_thresh=1e-15,
            find_peaks=True,
            remove_high_freq=False,
            min_pixel_detect=1,
        )
        self.assertGreater(len(footprints), 0)

    def test_bounds_to_bbox(self):
        """``bounds_to_bbox`` and ``bbox_to_bounds`` are inverses."""
        bounds = (3, 27, 11, 52)
        truth = Box((25, 42), (3, 11))
        bbox = bounds_to_bbox(bounds)
        self.assertBoxEqual(bbox, truth)

        # Check that the reverse operation also works
        new_bounds = bbox_to_bounds(bbox)
        self.assertTupleEqual(new_bounds, bounds)

    def test_footprint(self):
        """Exercise ``Footprint`` construction, rendering with
        ``footprints_to_image``, and ``intersection``/``union``.
        """
        # NOTE: ``footprint`` aliases the source's data array, so this
        # zeroes the faint pixels of ``self.sources[0]`` in place; the
        # truth arrays below are built from the modified sources.
        footprint = self.sources[0].data
        footprint[footprint < 1e-15] = 0
        # Inclusive (y_min, y_max, x_min, x_max) bounds of the source box.
        bounds = [
            self.sources[0].bbox.start[0],
            self.sources[0].bbox.stop[0] - 1,
            self.sources[0].bbox.start[1],
            self.sources[0].bbox.stop[1] - 1,
        ]
        peaks = [Peak(self.centers[0][0], self.centers[0][1], self.image.data[self.centers[0]])]
        footprint1 = Footprint(footprint, peaks, bounds)

        # Same in-place zeroing for the second source.
        footprint = self.sources[1].data
        footprint[footprint < 1e-15] = 0
        bounds = [
            self.sources[1].bbox.start[0],
            self.sources[1].bbox.stop[0] - 1,
            self.sources[1].bbox.start[1],
            self.sources[1].bbox.stop[1] - 1,
        ]
        peaks = [Peak(self.centers[1][0], self.centers[1][1], self.image.data[self.centers[1]])]
        footprint2 = Footprint(footprint, peaks, bounds)

        truth = self.sources[0] + self.sources[1]
        truth.data[truth.data < 1e-15] = 0
        image = footprints_to_image([footprint1, footprint2], truth.bbox)
        assert_array_equal(image, truth.data)

        # Test intersection
        truth = (self.sources[0] > 1e-15) & (self.sources[1] > 1e-15)
        intersection = footprint1.intersection(footprint2)
        self.assertImageEqual(intersection, truth)

        # Test union
        truth = (self.sources[0] > 1e-15) | (self.sources[1] > 1e-15)
        union = footprint1.union(footprint2)
        self.assertImageEqual(union, truth)

    def test_get_wavelets(self):
        """``get_wavelets`` on the HSC sample returns a float32 array of
        the expected shape.
        """
        images = self.hsc_data["images"]
        variance = self.hsc_data["variance"]
        wavelets = get_wavelets(images, variance)

        self.assertTupleEqual(wavelets.shape, (5, 5, 58, 48))
        self.assertEqual(wavelets.dtype, np.float32)

    def test_get_detect_wavelets(self):
        """``get_detect_wavelets`` on the HSC sample returns an array of
        the expected shape.
        """
        images = self.hsc_data["images"]
        variance = self.hsc_data["variance"]
        wavelets = get_detect_wavelets(images, variance)

        self.assertTupleEqual(wavelets.shape, (4, 58, 48))