Coverage for tests / test_wavelet.py: 19%

74 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-22 07:47 +0000

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23 

24import numpy as np 

25from lsst.scarlet.lite.wavelet import ( 

26 apply_wavelet_denoising, 

27 get_multiresolution_support, 

28 multiband_starlet_reconstruction, 

29 multiband_starlet_transform, 

30 starlet_reconstruction, 

31 starlet_transform, 

32) 

33from numpy.testing import assert_almost_equal 

34from utils import ScarletTestCase 

35 

36 

class TestWavelet(ScarletTestCase):
    """Tests for the starlet transforms and the multiresolution support.

    ``setUp`` loads the ``hsc_cosmos_35.npz`` fixture; tests that need real
    data read its ``"images"`` entry, the others build synthetic noise images.
    """

    def setUp(self) -> None:
        """Load the ``.npz`` fixture from ``../data`` relative to this file."""
        # Start from this file's directory instead of joining ``..`` onto the
        # file path itself; the resolved location is the same.
        test_dir = os.path.dirname(os.path.abspath(__file__))
        filename = os.path.abspath(os.path.join(test_dir, "..", "data", "hsc_cosmos_35.npz"))
        self.data = np.load(filename)

    def tearDown(self) -> None:
        """Close the fixture archive and drop the reference to it."""
        # ``np.load`` on an ``.npz`` returns an archive object that keeps the
        # underlying file open; close it explicitly rather than relying on
        # garbage collection to release the handle.
        self.data.close()
        del self.data

    def test_transform_inverse(self):
        """Check shape, dtype and round trip of the single-image transform."""
        image = np.sum(self.data["images"], axis=0)
        starlets = starlet_transform(image, scales=3)
        self.assertEqual(starlets.dtype, np.float32)

        # Test number of levels
        self.assertTupleEqual(starlets.shape, (4, 58, 48))

        # Test inverse
        inverse = starlet_reconstruction(starlets)
        assert_almost_equal(inverse, image, decimal=5)
        self.assertEqual(inverse.dtype, starlets.dtype)

        # Test using gen1 starlets
        starlets = starlet_transform(image, scales=3, generation=1)

        # Test number of levels
        self.assertTupleEqual(starlets.shape, (4, 58, 48))

        # Test inverse
        inverse = starlet_reconstruction(starlets, generation=1)
        assert_almost_equal(inverse, image, decimal=5)

    def test_multiband_transform(self):
        """Check shape, dtype and round trip of the multiband transform."""
        image = self.data["images"]
        starlets = multiband_starlet_transform(image, scales=3)
        self.assertEqual(starlets.dtype, np.float32)

        # Test number of levels
        self.assertTupleEqual(starlets.shape, (4, 5, 58, 48))

        # Test inverse
        inverse = multiband_starlet_reconstruction(starlets)
        assert_almost_equal(inverse, image, decimal=5)
        self.assertEqual(inverse.dtype, np.float32)

    def test_extras(self):
        """Smoke-test helpers that are not used in production."""
        # This is code that is not used in production,
        # but that might be used in the future,
        # so we test to prevent bitrot
        image = np.sum(self.data["images"].astype(float), axis=0)
        starlets = starlet_transform(image, scales=3)

        # Execute to ensure that the code runs
        get_multiresolution_support(image, starlets, 0.1)
        get_multiresolution_support(image, starlets, 0.1, image_type="space")
        apply_wavelet_denoising(image)

    def test_ground_branch_unbiased_sigma(self):
        """Audit finding D-5: the per-scale noise estimate in the
        ``image_type='ground'`` branch must compute std over the
        insignificant pixels only, not over the full array with
        significant pixels zeroed (which pulls the variance down).

        Run the algorithm on a synthetic starlet image where a
        large fraction of pixels are above the significance
        threshold. ``sigma_j`` is the noise-only std at each scale,
        so even though most pixels are masked, the returned value
        must match ``np.std`` of the underlying noise pixels — not
        ``np.std`` of those pixels mixed with zeros.
        """
        rng = np.random.default_rng(0)
        # Build a single-scale "starlet" array where everything is
        # noise: half the pixels are unit-sigma noise, the other
        # half are very-large-amplitude pixels that the iterative
        # threshold will mask out. The unmasked-only std should
        # converge to ~1.0; the bug's zero-padded std would be
        # roughly sqrt(0.5) ~ 0.71.
        noise = rng.normal(scale=1.0, size=(64, 64)).astype(np.float32)
        starlets_per_scale = noise.copy()
        starlets_per_scale[:32] += 100.0  # half the array is "signal"
        # Stack one finest-scale band plus a coarse residual.
        starlets = np.stack([starlets_per_scale, np.zeros_like(noise)])
        # The image just needs a matching shape for the API.
        image = starlets.sum(axis=0)

        result = get_multiresolution_support(image, starlets, 1.0, image_type="ground")
        # The finest scale's converged sigma must match the std of
        # the unmasked noise pixels (~1.0 to within iteration
        # tolerance), not the bug's zero-padded ~0.71.
        self.assertGreater(result.sigma[0], 0.9)
        self.assertLess(result.sigma[0], 1.1)

    def test_space_branch_reproducible(self):
        """Audit finding D-6: the ``space`` branch draws a Gaussian
        noise realization to calibrate ``sigma_je``. Pre-fix it used
        the global ``np.random`` state, so two identical calls
        produced different supports unless the caller had seeded the
        global RNG. The default behavior must now be reproducible.
        """
        rng = np.random.default_rng(42)
        image = rng.normal(scale=1.0, size=(64, 64))
        starlets = starlet_transform(image, generation=1, scales=3)

        r1 = get_multiresolution_support(image, starlets, 1.0, image_type="space")
        r2 = get_multiresolution_support(image, starlets, 1.0, image_type="space")
        np.testing.assert_array_equal(r1.support, r2.support)
        np.testing.assert_array_equal(r1.sigma, r2.sigma)

        # Caller-supplied generator overrides the default seed.
        # Two calls each given a *fresh* seed-123 generator must
        # produce identical results.
        r3 = get_multiresolution_support(
            image,
            starlets,
            1.0,
            image_type="space",
            rng=np.random.default_rng(123),
        )
        r4 = get_multiresolution_support(
            image,
            starlets,
            1.0,
            image_type="space",
            rng=np.random.default_rng(123),
        )
        np.testing.assert_array_equal(r3.support, r4.support)
        # Mirror the default-seed check above: identical generators must
        # also reproduce the noise estimate, not just the support.
        np.testing.assert_array_equal(r3.sigma, r4.sigma)

        # And conversely: re-using the *same* generator instance
        # across two calls advances its state between them, so the
        # second call sees a different noise draw. For this fixed
        # input the test asserts that the two supports differ.
        # (This is the standard ``np.random.Generator`` contract —
        # included to make the difference between "fresh seed each
        # call" and "shared mutable generator" explicit.)
        shared = np.random.default_rng(123)
        r5 = get_multiresolution_support(image, starlets, 1.0, image_type="space", rng=shared)
        r6 = get_multiresolution_support(image, starlets, 1.0, image_type="space", rng=shared)
        with self.assertRaises(AssertionError):
            np.testing.assert_array_equal(r5.support, r6.support)

    def test_space_branch_iterates_sigma(self):
        """Audit finding D-2: the ``image_type='space'`` branch of
        ``get_multiresolution_support`` implements the Starck &
        Murtagh 1998 multi-resolution support algorithm, which
        iteratively refines the global noise ``sigma_e`` from pixels
        that are insignificant at every scale. The iteration is
        meaningful only if each step's threshold uses the *previous*
        iteration's ``sigma``, otherwise the support never changes
        after iteration 0 and the loop is a no-op.

        With a deliberately wrong input ``sigma`` (3x the true noise
        level), the algorithm must still converge to a support
        close to what the correct-sigma run produces.
        """
        rng = np.random.default_rng(0)
        image = rng.normal(scale=1.0, size=(64, 64))
        starlets = starlet_transform(image, generation=1, scales=3)

        result_correct = get_multiresolution_support(image, starlets, 1.0, image_type="space")
        result_overestimate = get_multiresolution_support(image, starlets, 3.0, image_type="space")
        # With the bug, the overestimate run never re-thresholds the
        # mask and produces an essentially empty support (count = 0);
        # with the fix the iteration adapts and the support count is
        # within a small factor of the correct-sigma run.
        correct_count = result_correct.support.sum()
        overestimate_count = result_overestimate.support.sum()
        self.assertGreater(overestimate_count, 0)
        self.assertLess(abs(overestimate_count - correct_count), correct_count)

204 self.assertLess(abs(overestimate_count - correct_count), correct_count)