Coverage for python/lsst/scarlet/lite/wavelet.py: 9%
126 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 08:25 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 08:25 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public API: the single- and multi-band starlet transforms, their inverses,
# and the multi-resolution support estimator.
__all__ = [
    "starlet_transform", "starlet_reconstruction",
    "multiband_starlet_transform", "multiband_starlet_reconstruction",
    "get_multiresolution_support",
]
30from dataclasses import dataclass
31from typing import Callable, Sequence
33import numpy as np
def bspline_convolve(image: np.ndarray, scale: int) -> np.ndarray:
    """Convolve an image with a bspline at a given scale.

    This uses the spline
    `h1d = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])`
    from Starck et al. 2011, applied separably along axis 0 and then
    axis 1, with adjacent taps ``2**scale`` pixels apart. Contributions
    from outside the image are treated as zero.

    Parameters
    ----------
    image:
        The 2D image or wavelet coefficients to convolve.
    scale:
        The wavelet scale for the convolution. Adjacent spline taps are
        ``2**scale`` pixels apart.

    Returns
    -------
    result:
        The result of convolving `image` with the spline. Floating point
        and complex inputs keep their dtype; any other dtype (integer,
        boolean) is promoted to ``float64`` so the fractional filter
        weights are not truncated to zero.
    """
    # Casting the filter to the image dtype keeps e.g. float32 inputs in
    # float32 (a float64 NumPy scalar would otherwise upcast the result
    # under NumPy 2 promotion rules). Integer dtypes would truncate every
    # weight to 0 and silently return zeros, so they use float64 instead.
    if np.issubdtype(image.dtype, np.inexact):
        dtype = image.dtype
    else:
        dtype = np.dtype(np.float64)
    h1d = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16]).astype(dtype)

    near = 2**scale
    far = 2 ** (scale + 1)
    head_far = slice(None, -far)  # all but the last `far` pixels
    head_near = slice(None, -near)
    tail_near = slice(near, None)
    tail_far = slice(far, None)  # all but the first `far` pixels

    # Convolve along axis 0 (rows): out[i] += in[i + offset] * weight,
    # only where i + offset is inside the image.
    col = image * h1d[2]
    col[tail_far] += image[head_far] * h1d[0]
    col[tail_near] += image[head_near] * h1d[1]
    col[head_near] += image[tail_near] * h1d[3]
    col[head_far] += image[tail_far] * h1d[4]

    # Convolve along axis 1 (columns) in the same way.
    result = col * h1d[2]
    result[:, tail_far] += col[:, head_far] * h1d[0]
    result[:, tail_near] += col[:, head_near] * h1d[1]
    result[:, head_near] += col[:, tail_near] * h1d[3]
    result[:, head_far] += col[:, tail_far] * h1d[4]
    return result
80def get_starlet_scales(image_shape: Sequence[int], scales: int | None = None) -> int:
81 """Get the number of scales to use in the starlet transform.
83 Parameters
84 ----------
85 image_shape:
86 The 2D shape of the image that is being transformed
87 scales:
88 The number of scales to transform with starlets.
89 The total dimension of the starlet will have
90 `scales+1` dimensions, since it will also hold
91 the image at all scales higher than `scales`.
93 Returns
94 -------
95 result:
96 Number of scales, adjusted for the size of the image.
97 """
98 # Number of levels for the Starlet decomposition
99 max_scale = int(np.log2(np.min(image_shape[-2:]))) - 1
100 if (scales is None) or scales > max_scale:
101 scales = max_scale
102 return int(scales)
105def starlet_transform(
106 image: np.ndarray,
107 scales: int | None = None,
108 generation: int = 2,
109 convolve2d: Callable | None = None,
110) -> np.ndarray:
111 """Perform a starlet transform, or 2nd gen starlet transform.
113 Parameters
114 ----------
115 image:
116 The image to transform into starlet coefficients.
117 scales:
118 The number of scale to transform with starlets.
119 The total dimension of the starlet will have
120 `scales+1` dimensions, since it will also hold
121 the image at all scales higher than `scales`.
122 generation:
123 The generation of the transform.
124 This must be `1` or `2`.
125 convolve2d:
126 The filter function to use to convolve the image
127 with starlets in 2D.
129 Returns
130 -------
131 starlet:
132 The starlet dictionary for the input `image`.
133 """
134 if len(image.shape) != 2:
135 raise ValueError(f"Image should be 2D, got {len(image.shape)}")
136 if generation not in (1, 2):
137 raise ValueError(f"generation should be 1 or 2, got {generation}")
139 scales = get_starlet_scales(image.shape, scales)
140 c = image
141 if convolve2d is None:
142 convolve2d = bspline_convolve
144 # wavelet set of coefficients.
145 starlet = np.zeros((scales + 1,) + image.shape, dtype=image.dtype)
146 for j in range(scales):
147 gen1 = convolve2d(c, j)
149 if generation == 2:
150 gen2 = convolve2d(gen1, j)
151 starlet[j] = c - gen2
152 else:
153 starlet[j] = c - gen1
155 c = gen1
157 starlet[-1] = c
158 return starlet
161def multiband_starlet_transform(
162 image: np.ndarray,
163 scales: int | None = None,
164 generation: int = 2,
165 convolve2d: Callable | None = None,
166) -> np.ndarray:
167 """Perform a starlet transform of a multiband image.
169 See `starlet_transform` for a description of the parameters.
170 """
171 if len(image.shape) != 3:
172 raise ValueError(f"Image should be 3D (bands, height, width), got shape {len(image.shape)}")
173 if generation not in (1, 2):
174 raise ValueError(f"generation should be 1 or 2, got {generation}")
175 scales = get_starlet_scales(image.shape, scales)
177 wavelets = np.empty((scales + 1,) + image.shape, dtype=image.dtype)
178 for b, band_image in enumerate(image):
179 wavelets[:, b] = starlet_transform(
180 band_image, scales=scales, generation=generation, convolve2d=convolve2d
181 )
182 return wavelets
185def starlet_reconstruction(
186 starlets: np.ndarray,
187 generation: int = 2,
188 convolve2d: Callable | None = None,
189) -> np.ndarray:
190 """Reconstruct an image from a dictionary of starlets
192 Parameters
193 ----------
194 starlets:
195 The starlet dictionary used to reconstruct the image
196 with dimension (scales+1, Ny, Nx).
197 generation:
198 The generation of the starlet transform (either ``1`` or ``2``).
199 convolve2d:
200 The filter function to use to convolve the image
201 with starlets in 2D.
203 Returns
204 -------
205 image:
206 The 2D image reconstructed from the input `starlet`.
207 """
208 if generation == 1:
209 return np.sum(starlets, axis=0)
210 if convolve2d is None:
211 convolve2d = bspline_convolve
212 scales = len(starlets) - 1
214 c = starlets[-1]
215 for i in range(1, scales + 1):
216 j = scales - i
217 cj = convolve2d(c, j)
218 c = cj + starlets[j]
219 return c
222def multiband_starlet_reconstruction(
223 starlets: np.ndarray,
224 generation: int = 2,
225 convolve2d: Callable | None = None,
226) -> np.ndarray:
227 """Reconstruct a multiband image.
229 See `starlet_reconstruction` for a description of the
230 remainder of the parameters.
231 """
232 _, bands, height, width = starlets.shape
233 result = np.zeros((bands, height, width), dtype=starlets.dtype)
234 for band in range(bands):
235 result[band] = starlet_reconstruction(starlets[:, band], generation=generation, convolve2d=convolve2d)
236 return result
@dataclass
class MultiResolutionSupport:
    """Result of a multi-resolution support estimate.

    Attributes
    ----------
    support:
        Mask with the same shape as the starlet coefficients; entries are
        non-zero where the corresponding coefficient is significant.
    sigma:
        One noise standard deviation per starlet scale.
    """

    # Per-coefficient significance mask.
    support: np.ndarray
    # Per-scale noise standard deviation.
    sigma: np.ndarray
256def get_multiresolution_support(
257 image: np.ndarray,
258 starlets: np.ndarray,
259 sigma: np.floating,
260 sigma_scaling: float = 3,
261 epsilon: float = 1e-1,
262 max_iter: int = 20,
263 image_type: str = "ground",
264 rng: np.random.Generator | None = None,
265) -> MultiResolutionSupport:
266 """Calculate the multi-resolution support for a
267 dictionary of starlet coefficients.
269 This is different for ground and space based telescopes.
270 For space-based telescopes the procedure in Starck and Murtagh 1998
271 iteratively calculates the multi-resolution support.
272 For ground based images, where the PSF is much wider and there are no
273 pixels with no signal at all scales, we use a modified method that
274 estimates support at each scale independently.
276 Parameters
277 ----------
278 image:
279 The image to transform into starlet coefficients.
280 starlets:
281 The starlet dictionary used to reconstruct `image` with
282 dimension (scales+1, Ny, Nx).
283 sigma:
284 The standard deviation of the `image`.
285 sigma_scaling:
286 The multiple of `sigma` to use to calculate significance.
287 Coefficients `w` where `|w| > K*sigma_j`, where `sigma_j` is
288 standard deviation at the jth scale, are considered significant.
289 epsilon:
290 The convergence criteria of the algorithm.
291 Once `|new_sigma_j - sigma_j|/new_sigma_j < epsilon` the
292 algorithm has completed.
293 max_iter:
294 Maximum number of iterations to fit `sigma_j` at each scale.
295 image_type:
296 The type of image that is being used.
297 This should be "ground" for ground based images with wide PSFs or
298 "space" for images from space-based telescopes with a narrow PSF.
299 rng:
300 Random number generator used to draw the Gaussian noise
301 realization that calibrates ``sigma_je`` in the ``space``
302 branch. Defaults to ``np.random.default_rng(0)`` so repeated
303 calls with the same input return the same support.
305 Returns
306 -------
307 M:
308 Mask with significant coefficients in `starlets` set to `True`.
309 """
310 if image_type not in ("ground", "space"):
311 raise ValueError(f"image_type must be 'ground' or 'space', got {image_type}")
313 if image_type == "space":
314 # Calculate sigma_je, the standard deviation at
315 # each scale due to gaussian noise
316 if rng is None:
317 rng = np.random.default_rng(0)
318 noise_img = rng.normal(size=image.shape)
319 noise_starlet = starlet_transform(noise_img, generation=1, scales=len(starlets) - 1)
320 sigma_je = np.zeros((len(noise_starlet),))
321 for j, star in enumerate(noise_starlet):
322 sigma_je[j] = np.std(star)
323 noise = image - starlets[-1]
325 last_sigma_i = sigma
326 for it in range(max_iter):
327 m = np.abs(starlets) > sigma_scaling * last_sigma_i * sigma_je[:, None, None]
328 s = np.sum(m, axis=0) == 0
329 sigma_i = np.std(noise * s)
330 if np.abs(sigma_i - last_sigma_i) / sigma_i < epsilon:
331 break
332 last_sigma_i = sigma_i
333 sigma_j = sigma_je
334 else:
335 # Sigma to use for significance at each scale
336 # Initially we use the input `sigma`
337 sigma_j = np.full(len(starlets), sigma, dtype=image.dtype)
338 last_sigma_j = sigma_j
339 for it in range(max_iter):
340 m = np.abs(starlets) > sigma_scaling * sigma_j[:, None, None]
341 # Take the standard deviation of the current
342 # insignificant coeffs at each scale, excluding
343 # significant pixels entirely. Including them as zeros
344 # (the pre-fix behavior) biased ``sigma_j`` downward
345 # whenever a non-trivial fraction of pixels were
346 # significant. Scales where every pixel is significant
347 # get ``sigma_j[j] = 0``, treated downstream as "skip
348 # this scale" by the ``sigma_j > 0`` cut.
349 sigma_j = np.zeros(len(starlets), dtype=image.dtype)
350 for j in range(len(starlets)):
351 unmasked = starlets[j][~m[j]]
352 if unmasked.size > 0:
353 sigma_j[j] = np.std(unmasked)
354 # At lower scales all of the pixels may be significant,
355 # so sigma is effectively zero. To avoid infinities we
356 # only check the scales with non-zero sigma
357 cut = sigma_j > 0
358 if np.all(np.abs(sigma_j[cut] - last_sigma_j[cut]) / sigma_j[cut] < epsilon):
359 break
361 last_sigma_j = sigma_j
362 # noinspection PyUnboundLocalVariable
363 return MultiResolutionSupport(support=m.astype(int), sigma=sigma_j)
def apply_wavelet_denoising(
    image: np.ndarray,
    sigma: np.floating | None = None,
    sigma_scaling: float = 3,
    epsilon: float = 1e-1,
    max_iter: int = 20,
    image_type: str = "ground",
    positive: bool = True,
) -> np.ndarray:
    """Denoise an image by iterating on its significant starlet coefficients.

    Uses the algorithm and notation from Starck et al. 2011, section 4.1.
    The estimate starts from the reconstruction of the image's starlet
    coefficients. Each iteration adds back the reconstruction of the
    residual coefficients (observed minus current), restricted to the
    multi-resolution support.

    Parameters
    ----------
    image:
        The image to denoise.
    sigma:
        The standard deviation of the image. When ``None``, the median
        absolute deviation of `image` about its median is used.
    sigma_scaling:
        The threshold in units of sigma to declare a coefficient significant.
    epsilon:
        Convergence criteria for determining the support.
    max_iter:
        The maximum number of iterations. This applies to both finding
        the support and the denoising loop.
    image_type:
        The type of image that is being used.
        This should be "ground" for ground based images with wide PSFs or
        "space" for images from space-based telescopes with a narrow PSF.
    positive:
        Whether or not the expected result should be positive. If so,
        negative pixels are clipped to 0 after every iteration.

    Returns
    -------
    result:
        The denoised image after `max_iter` iterations.
    """
    observed = starlet_transform(image)
    if sigma is None:
        # NOTE(review): the MAD is not rescaled by ~1.4826, so for Gaussian
        # noise this underestimates the standard deviation — confirm intent.
        center = np.median(image)
        sigma = np.median(np.absolute(image - center))

    working = observed.copy()
    support = get_multiresolution_support(
        image=image,
        starlets=working,
        sigma=sigma,
        sigma_scaling=sigma_scaling,
        epsilon=epsilon,
        max_iter=max_iter,
        image_type=image_type,
    )
    mask = support.support

    estimate = starlet_reconstruction(working)
    for _ in range(max_iter):
        current = starlet_transform(estimate)
        estimate = estimate + starlet_reconstruction(mask * (observed - current))
        if positive:
            estimate[estimate < 0] = 0
    return estimate