Coverage for python/lsst/scarlet/lite/wavelet.py: 9%

126 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 01:23 -0700

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API exported by ``from ... import *``. Other module-level names
# defined in this file (`bspline_convolve`, `get_starlet_scales`,
# `MultiResolutionSupport`, `apply_wavelet_denoising`) are not listed here.
# NOTE(review): `apply_wavelet_denoising` looks public; confirm its omission
# is intentional.
__all__ = [
    "starlet_transform",
    "starlet_reconstruction",
    "multiband_starlet_transform",
    "multiband_starlet_reconstruction",
    "get_multiresolution_support",
]

29 

30from dataclasses import dataclass 

31from typing import Callable, Sequence 

32 

33import numpy as np 

34 

35 

def bspline_convolve(image: np.ndarray, scale: int) -> np.ndarray:
    """Convolve an image with a bspline at a given scale.

    This uses the separable spline
    `h1d = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])`
    from Starck et al. 2011 (the "a trous" kernel). It is applied first
    along axis 0 and then along axis 1. Taps that fall outside the image
    are dropped, which is equivalent to zero padding.

    Parameters
    ----------
    image:
        The 2D image or wavelet coefficients to convolve.
    scale:
        The wavelet scale for the convolution. Adjacent spline taps are
        separated by ``2**scale`` pixels.

    Returns
    -------
    result:
        The result of convolving the `image` with the spline. Floating
        point inputs keep their dtype. Integer and boolean inputs are
        convolved in float64, because casting the fractional kernel
        weights to an integer dtype would turn them all into zero.
    """
    if np.issubdtype(image.dtype, np.inexact):
        dtype = image.dtype
    else:
        dtype = np.dtype(np.float64)
    data = image.astype(dtype, copy=False)

    # Filter for the starlet transform: a cubic bspline
    h1d = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16], dtype=dtype)

    near = 2**scale
    far = 2 ** (scale + 1)
    slice0 = slice(None, -far)
    slice1 = slice(None, -near)
    slice3 = slice(near, None)
    slice4 = slice(far, None)

    # First pass: filter along axis 0
    col = data * h1d[2]
    col[slice4] += data[slice0] * h1d[0]
    col[slice3] += data[slice1] * h1d[1]
    col[slice1] += data[slice3] * h1d[3]
    col[slice0] += data[slice4] * h1d[4]

    # Second pass: filter along axis 1
    result = col * h1d[2]
    result[:, slice4] += col[:, slice0] * h1d[0]
    result[:, slice3] += col[:, slice1] * h1d[1]
    result[:, slice1] += col[:, slice3] * h1d[3]
    result[:, slice0] += col[:, slice4] * h1d[4]
    return result

78 

79 

80def get_starlet_scales(image_shape: Sequence[int], scales: int | None = None) -> int: 

81 """Get the number of scales to use in the starlet transform. 

82 

83 Parameters 

84 ---------- 

85 image_shape: 

86 The 2D shape of the image that is being transformed 

87 scales: 

88 The number of scales to transform with starlets. 

89 The total dimension of the starlet will have 

90 `scales+1` dimensions, since it will also hold 

91 the image at all scales higher than `scales`. 

92 

93 Returns 

94 ------- 

95 result: 

96 Number of scales, adjusted for the size of the image. 

97 """ 

98 # Number of levels for the Starlet decomposition 

99 max_scale = int(np.log2(np.min(image_shape[-2:]))) - 1 

100 if (scales is None) or scales > max_scale: 

101 scales = max_scale 

102 return int(scales) 

103 

104 

105def starlet_transform( 

106 image: np.ndarray, 

107 scales: int | None = None, 

108 generation: int = 2, 

109 convolve2d: Callable | None = None, 

110) -> np.ndarray: 

111 """Perform a starlet transform, or 2nd gen starlet transform. 

112 

113 Parameters 

114 ---------- 

115 image: 

116 The image to transform into starlet coefficients. 

117 scales: 

118 The number of scale to transform with starlets. 

119 The total dimension of the starlet will have 

120 `scales+1` dimensions, since it will also hold 

121 the image at all scales higher than `scales`. 

122 generation: 

123 The generation of the transform. 

124 This must be `1` or `2`. 

125 convolve2d: 

126 The filter function to use to convolve the image 

127 with starlets in 2D. 

128 

129 Returns 

130 ------- 

131 starlet: 

132 The starlet dictionary for the input `image`. 

133 """ 

134 if len(image.shape) != 2: 

135 raise ValueError(f"Image should be 2D, got {len(image.shape)}") 

136 if generation not in (1, 2): 

137 raise ValueError(f"generation should be 1 or 2, got {generation}") 

138 

139 scales = get_starlet_scales(image.shape, scales) 

140 c = image 

141 if convolve2d is None: 

142 convolve2d = bspline_convolve 

143 

144 # wavelet set of coefficients. 

145 starlet = np.zeros((scales + 1,) + image.shape, dtype=image.dtype) 

146 for j in range(scales): 

147 gen1 = convolve2d(c, j) 

148 

149 if generation == 2: 

150 gen2 = convolve2d(gen1, j) 

151 starlet[j] = c - gen2 

152 else: 

153 starlet[j] = c - gen1 

154 

155 c = gen1 

156 

157 starlet[-1] = c 

158 return starlet 

159 

160 

161def multiband_starlet_transform( 

162 image: np.ndarray, 

163 scales: int | None = None, 

164 generation: int = 2, 

165 convolve2d: Callable | None = None, 

166) -> np.ndarray: 

167 """Perform a starlet transform of a multiband image. 

168 

169 See `starlet_transform` for a description of the parameters. 

170 """ 

171 if len(image.shape) != 3: 

172 raise ValueError(f"Image should be 3D (bands, height, width), got shape {len(image.shape)}") 

173 if generation not in (1, 2): 

174 raise ValueError(f"generation should be 1 or 2, got {generation}") 

175 scales = get_starlet_scales(image.shape, scales) 

176 

177 wavelets = np.empty((scales + 1,) + image.shape, dtype=image.dtype) 

178 for b, band_image in enumerate(image): 

179 wavelets[:, b] = starlet_transform( 

180 band_image, scales=scales, generation=generation, convolve2d=convolve2d 

181 ) 

182 return wavelets 

183 

184 

185def starlet_reconstruction( 

186 starlets: np.ndarray, 

187 generation: int = 2, 

188 convolve2d: Callable | None = None, 

189) -> np.ndarray: 

190 """Reconstruct an image from a dictionary of starlets 

191 

192 Parameters 

193 ---------- 

194 starlets: 

195 The starlet dictionary used to reconstruct the image 

196 with dimension (scales+1, Ny, Nx). 

197 generation: 

198 The generation of the starlet transform (either ``1`` or ``2``). 

199 convolve2d: 

200 The filter function to use to convolve the image 

201 with starlets in 2D. 

202 

203 Returns 

204 ------- 

205 image: 

206 The 2D image reconstructed from the input `starlet`. 

207 """ 

208 if generation == 1: 

209 return np.sum(starlets, axis=0) 

210 if convolve2d is None: 

211 convolve2d = bspline_convolve 

212 scales = len(starlets) - 1 

213 

214 c = starlets[-1] 

215 for i in range(1, scales + 1): 

216 j = scales - i 

217 cj = convolve2d(c, j) 

218 c = cj + starlets[j] 

219 return c 

220 

221 

222def multiband_starlet_reconstruction( 

223 starlets: np.ndarray, 

224 generation: int = 2, 

225 convolve2d: Callable | None = None, 

226) -> np.ndarray: 

227 """Reconstruct a multiband image. 

228 

229 See `starlet_reconstruction` for a description of the 

230 remainder of the parameters. 

231 """ 

232 _, bands, height, width = starlets.shape 

233 result = np.zeros((bands, height, width), dtype=starlets.dtype) 

234 for band in range(bands): 

235 result[band] = starlet_reconstruction(starlets[:, band], generation=generation, convolve2d=convolve2d) 

236 return result 

237 

238 

@dataclass()
class MultiResolutionSupport:
    """Result of `get_multiresolution_support`.

    Bundles the per-scale significance mask of a set of starlet
    coefficients with the noise level estimated at each scale.

    Attributes
    ----------
    support:
        Mask with the same shape as the starlet coefficients. It is
        non-zero where a coefficient was judged significant.
    sigma:
        One noise standard deviation per starlet scale.
    """

    # Significance mask, same shape as the starlet coefficients
    support: np.ndarray
    # Estimated noise standard deviation for each scale
    sigma: np.ndarray

254 

255 

256def get_multiresolution_support( 

257 image: np.ndarray, 

258 starlets: np.ndarray, 

259 sigma: np.floating, 

260 sigma_scaling: float = 3, 

261 epsilon: float = 1e-1, 

262 max_iter: int = 20, 

263 image_type: str = "ground", 

264 rng: np.random.Generator | None = None, 

265) -> MultiResolutionSupport: 

266 """Calculate the multi-resolution support for a 

267 dictionary of starlet coefficients. 

268 

269 This is different for ground and space based telescopes. 

270 For space-based telescopes the procedure in Starck and Murtagh 1998 

271 iteratively calculates the multi-resolution support. 

272 For ground based images, where the PSF is much wider and there are no 

273 pixels with no signal at all scales, we use a modified method that 

274 estimates support at each scale independently. 

275 

276 Parameters 

277 ---------- 

278 image: 

279 The image to transform into starlet coefficients. 

280 starlets: 

281 The starlet dictionary used to reconstruct `image` with 

282 dimension (scales+1, Ny, Nx). 

283 sigma: 

284 The standard deviation of the `image`. 

285 sigma_scaling: 

286 The multiple of `sigma` to use to calculate significance. 

287 Coefficients `w` where `|w| > K*sigma_j`, where `sigma_j` is 

288 standard deviation at the jth scale, are considered significant. 

289 epsilon: 

290 The convergence criteria of the algorithm. 

291 Once `|new_sigma_j - sigma_j|/new_sigma_j < epsilon` the 

292 algorithm has completed. 

293 max_iter: 

294 Maximum number of iterations to fit `sigma_j` at each scale. 

295 image_type: 

296 The type of image that is being used. 

297 This should be "ground" for ground based images with wide PSFs or 

298 "space" for images from space-based telescopes with a narrow PSF. 

299 rng: 

300 Random number generator used to draw the Gaussian noise 

301 realization that calibrates ``sigma_je`` in the ``space`` 

302 branch. Defaults to ``np.random.default_rng(0)`` so repeated 

303 calls with the same input return the same support. 

304 

305 Returns 

306 ------- 

307 M: 

308 Mask with significant coefficients in `starlets` set to `True`. 

309 """ 

310 if image_type not in ("ground", "space"): 

311 raise ValueError(f"image_type must be 'ground' or 'space', got {image_type}") 

312 

313 if image_type == "space": 

314 # Calculate sigma_je, the standard deviation at 

315 # each scale due to gaussian noise 

316 if rng is None: 

317 rng = np.random.default_rng(0) 

318 noise_img = rng.normal(size=image.shape) 

319 noise_starlet = starlet_transform(noise_img, generation=1, scales=len(starlets) - 1) 

320 sigma_je = np.zeros((len(noise_starlet),)) 

321 for j, star in enumerate(noise_starlet): 

322 sigma_je[j] = np.std(star) 

323 noise = image - starlets[-1] 

324 

325 last_sigma_i = sigma 

326 for it in range(max_iter): 

327 m = np.abs(starlets) > sigma_scaling * last_sigma_i * sigma_je[:, None, None] 

328 s = np.sum(m, axis=0) == 0 

329 sigma_i = np.std(noise * s) 

330 if np.abs(sigma_i - last_sigma_i) / sigma_i < epsilon: 

331 break 

332 last_sigma_i = sigma_i 

333 sigma_j = sigma_je 

334 else: 

335 # Sigma to use for significance at each scale 

336 # Initially we use the input `sigma` 

337 sigma_j = np.full(len(starlets), sigma, dtype=image.dtype) 

338 last_sigma_j = sigma_j 

339 for it in range(max_iter): 

340 m = np.abs(starlets) > sigma_scaling * sigma_j[:, None, None] 

341 # Take the standard deviation of the current 

342 # insignificant coeffs at each scale, excluding 

343 # significant pixels entirely. Including them as zeros 

344 # (the pre-fix behavior) biased ``sigma_j`` downward 

345 # whenever a non-trivial fraction of pixels were 

346 # significant. Scales where every pixel is significant 

347 # get ``sigma_j[j] = 0``, treated downstream as "skip 

348 # this scale" by the ``sigma_j > 0`` cut. 

349 sigma_j = np.zeros(len(starlets), dtype=image.dtype) 

350 for j in range(len(starlets)): 

351 unmasked = starlets[j][~m[j]] 

352 if unmasked.size > 0: 

353 sigma_j[j] = np.std(unmasked) 

354 # At lower scales all of the pixels may be significant, 

355 # so sigma is effectively zero. To avoid infinities we 

356 # only check the scales with non-zero sigma 

357 cut = sigma_j > 0 

358 if np.all(np.abs(sigma_j[cut] - last_sigma_j[cut]) / sigma_j[cut] < epsilon): 

359 break 

360 

361 last_sigma_j = sigma_j 

362 # noinspection PyUnboundLocalVariable 

363 return MultiResolutionSupport(support=m.astype(int), sigma=sigma_j) 

364 

365 

def apply_wavelet_denoising(
    image: np.ndarray,
    sigma: np.floating | None = None,
    sigma_scaling: float = 3,
    epsilon: float = 1e-1,
    max_iter: int = 20,
    image_type: str = "ground",
    positive: bool = True,
) -> np.ndarray:
    """Apply wavelet denoising

    Uses the algorithm and notation from Starck et al. 2011, section 4.1.
    The multi-resolution support ``M`` of the image coefficients ``W_Y`` is
    computed once. The solution starts from the reconstruction of the
    significant coefficients, ``R(M W_Y)``. It is then refined with
    ``X <- P(X + R(M (W_Y - W_X)))``, where ``P`` optionally clips
    negative pixels.

    Parameters
    ----------
    image:
        The image to denoise
    sigma:
        Initial estimate of the standard deviation of the image noise.
        If `None`, the median absolute deviation of `image` is used.
    sigma_scaling:
        The threshold in units of sigma to declare a coefficient significant
    epsilon:
        Convergence criteria for determining the support
    max_iter:
        The maximum number of iterations.
        This applies to both finding the support and the denoising loop.
    image_type:
        The type of image that is being used.
        This should be "ground" for ground based images with wide PSFs or
        "space" for images from space-based telescopes with a narrow PSF.
    positive:
        Whether or not the expected result should be positive. If `True`,
        negative pixels are set to zero after every update.

    Returns
    -------
    result:
        The resulting denoised image after `max_iter` iterations.
    """
    image_coeffs = starlet_transform(image)
    if sigma is None:
        # Median absolute deviation. It is only the starting guess that
        # `get_multiresolution_support` refines.
        # NOTE(review): not rescaled by 1.4826 to a Gaussian sigma; confirm intent.
        sigma = np.median(np.absolute(image - np.median(image)))
    support = get_multiresolution_support(
        image=image,
        starlets=image_coeffs,
        sigma=sigma,
        sigma_scaling=sigma_scaling,
        epsilon=epsilon,
        max_iter=max_iter,
        image_type=image_type,
    )
    mask = support.support

    # Start from the significant coefficients only. Starting from the full
    # reconstruction (i.e. the input image itself) makes every residual
    # W_Y - W_X zero, so the loop below would do nothing but clip negatives.
    x = starlet_reconstruction(mask * image_coeffs)
    if positive:
        x[x < 0] = 0

    for _ in range(max_iter):
        residual = image_coeffs - starlet_transform(x)
        x = x + starlet_reconstruction(mask * residual)
        if positive:
            x[x < 0] = 0
    return x