Coverage for python/lsst/scarlet/lite/models/free_form.py: 27%

112 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 08:24 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public API of this module: both the factorized and the non-factorized
# free-form components.
__all__ = ["FactorizedFreeFormComponent", "FreeFormComponent"]

24 

25from copy import deepcopy 

26from typing import TYPE_CHECKING, Any, Callable, cast 

27 

28import numpy as np 

29 

30from ..bbox import Box 

31from ..component import Component, FactorizedComponent 

32from ..detect import footprints_to_image 

33from ..detect_pybind11 import get_connected_multipeak, get_footprints # type: ignore 

34from ..image import Image 

35from ..parameters import Parameter, parameter 

36 

37if TYPE_CHECKING: 

38 from ..io.component import ScarletComponentBaseData 

39 

40 

class FactorizedFreeFormComponent(FactorizedComponent):
    """Implements a free-form component

    With no constraints this component is typically either a garbage collector,
    or part of a set of components to deconvolve an image by separating out
    the different spectral components.

    See `FactorizedComponent` for a list of parameters not shown here.

    Parameters
    ----------
    peaks: `list` of `tuple`
        A set of ``(cy, cx)`` peaks for detected sources.
        If `peaks` is not ``None`` then only pixels in the same "footprint"
        as one of the peaks are included in the morphology.
        If `peaks` is ``None`` then there is no constraint applied.
    min_area: float
        The minimum area (in pixels) of a connected region of the morphology.
        If `min_area` is greater than zero then all regions of the morphology
        with fewer than `min_area` connected pixels are removed.
    """

    def __init__(
        self,
        bands: tuple,
        spectrum: np.ndarray | Parameter,
        morph: np.ndarray | Parameter,
        model_bbox: Box,
        bg_thresh: float | None = None,
        bg_rms: np.ndarray | None = None,
        floor: float = 1e-20,
        peaks: list[tuple[int, int]] | None = None,
        min_area: float = 0,
    ):
        # A free-form component has no single peak, so ``peak=None`` is
        # passed to the parent; the (optional) peaks are stored separately.
        super().__init__(
            bands=bands,
            spectrum=spectrum,
            morph=morph,
            bbox=model_bbox,
            peak=None,
            bg_rms=bg_rms,
            bg_thresh=bg_thresh,
            floor=floor,
        )

        self.peaks = peaks
        self.min_area = min_area

    def prox_spectrum(self, spectrum: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the spectrum

        This differs from `FactorizedComponent` in that the spectrum of
        this component is normalized to sum to unity.

        Parameters
        ----------
        spectrum:
            The spectrum to constrain. Values below ``self.floor`` are
            clipped to ``self.floor`` in place.

        Returns
        -------
        spectrum:
            The clipped spectrum divided by its sum.
        """
        # prevent divergent spectrum
        spectrum[spectrum < self.floor] = self.floor
        # Normalize the spectrum
        return spectrum / np.sum(spectrum)

    def prox_morph(self, morph: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the morphology

        This is the main difference between this component and a
        `FactorizedComponent`, since this component has fewer constraints:
        only background thresholding (or positivity), optional peak
        connectivity and an optional minimum region area are enforced.

        Parameters
        ----------
        morph:
            The 2D ``(y, x)`` morphology to constrain.
            It may be modified in place.

        Returns
        -------
        morph:
            The constrained morphology.
        """
        if self.bg_thresh is not None and isinstance(self.bg_rms, np.ndarray):
            bg_thresh = self.bg_rms * self.bg_thresh
            # Enforce background thresholding: zero out pixels whose model
            # value is below the threshold in every band
            model = self.spectrum[:, None, None] * morph[None, :, :]
            morph[np.all(model < bg_thresh[:, None, None], axis=0)] = 0
        else:
            # enforce positivity
            morph[morph < 0] = 0

        if self.peaks is not None:
            # Remove pixels that are not connected to one of the peaks
            footprint = get_connected_multipeak(morph, self.peaks, 0)
            morph = morph * footprint

        if self.min_area > 0:
            # Remove regions with fewer than min_area connected pixels.
            # The bbox origin is reset so that the footprint image aligns with
            # the local ``morph`` array (assumes the footprints are returned
            # in array coordinates -- TODO confirm).
            footprints = get_footprints(morph, 4.0, self.min_area, 0, 0, False)
            bbox = self.bbox.copy()
            bbox.origin = (0, 0)
            footprint_image = footprints_to_image(footprints, bbox)
            morph = morph * (footprint_image > 0).data

        if np.all(morph == 0):
            # Keep a single non-zero pixel so the morphology is never
            # identically zero
            morph[0, 0] = self.floor

        return morph

    def resize(self, model_box: Box) -> bool:
        """Do not resize this component.

        Returns
        -------
        result:
            Always ``False``; the bounding box is left unchanged.
        """
        return False

    def __str__(self):
        return (
            f"FactorizedFreeFormComponent(\n    bands={self.bands}\n    "
            f"spectrum={self.spectrum}\n    center={self.peak}\n    "
            f"morph_shape={self.morph.shape}\n)"
        )

    def __repr__(self):
        return self.__str__()

146 

147 

class FreeFormComponent(Component):
    """Implements a component with no spectral or monotonicity constraints

    This is a free-form component that is not factorized into a
    spectrum and morphology, and has no monotonicity constraint.

    Parameters
    ----------
    bands:
        The bands covered by the component.
    model:
        The 3D (bands, y, x) model of the component.
    model_bbox:
        The bounding box of the model.
    bg_thresh:
        The background threshold, in units of `bg_rms`, below which
        pixels are set to zero. If `None` (or if `bg_rms` is not an array)
        then only positivity is enforced.
    bg_rms:
        The background RMS in each band.
    floor:
        The value assigned to a single pixel (in every band) when the
        constrained model would otherwise be identically zero.
    peaks:
        The `(y, x)` peaks of the component, used to keep only the pixels
        connected to a peak. If `None` then no peak connectivity is enforced.
    min_area:
        The minimum area (in pixels) of a connected footprint to keep.
        If it is not greater than zero then no minimum area is enforced.
    """

    def __init__(
        self,
        bands: tuple,
        model: np.ndarray | Parameter,
        model_bbox: Box,
        bg_thresh: float | None = None,
        bg_rms: np.ndarray | None = None,
        floor: float = 1e-20,
        peaks: list[tuple[int, int]] | None = None,
        min_area: float = 0,
    ):
        super().__init__(bands=bands, bbox=model_bbox)
        self._model = parameter(model)
        self.bg_rms = bg_rms
        self.bg_thresh = bg_thresh
        self.floor = floor
        self.peaks = peaks
        self.min_area = min_area

    @property
    def model(self) -> np.ndarray:
        """The current (bands, y, x) model array."""
        return self._model.x

    def get_model(self) -> Image:
        """Return the model as an `Image` placed at the bbox origin."""
        return Image(self.model, bands=self.bands, yx0=cast(tuple[int, int], self.bbox.origin))

    @property
    def shape(self) -> tuple:
        """The shape of the model array."""
        return self.model.shape

    def grad_model(self, input_grad: np.ndarray, model: np.ndarray) -> np.ndarray:
        """Gradient of the model with respect to itself.

        The model is its own parameter, so the input gradient is returned
        unchanged.
        """
        return input_grad

    def prox_model(self, model: np.ndarray) -> np.ndarray:
        """Apply the constraints to the model.

        Parameters
        ----------
        model:
            The 3D ``(bands, y, x)`` model to constrain.
            It may be modified in place.

        Returns
        -------
        model:
            The constrained model.
        """
        if self.bg_thresh is not None and isinstance(self.bg_rms, np.ndarray):
            bg_thresh = self.bg_rms * self.bg_thresh
            # Enforce background thresholding independently in each band
            model[model < bg_thresh[:, None, None]] = 0
        else:
            # enforce positivity
            model[model < 0] = 0

        if self.peaks is not None:
            # Remove pixels not connected to one of the peaks, using the
            # band-summed image to define connectivity
            model2d = np.sum(model, axis=0)
            footprint = get_connected_multipeak(model2d, self.peaks, 0)
            model = model * footprint[None, :, :]

        if self.min_area > 0:
            # Remove regions with fewer than min_area connected pixels.
            # The bbox origin is reset so that the footprint image aligns with
            # the local model array (assumes the footprints are returned in
            # array coordinates -- TODO confirm).
            model2d = np.sum(model, axis=0)
            footprints = get_footprints(model2d, 4.0, self.min_area, 0, 0, False)
            bbox = self.bbox.copy()
            bbox.origin = (0, 0)
            footprint_image = footprints_to_image(footprints, bbox)
            model = model * (footprint_image > 0).data[None, :, :]

        if np.all(model == 0):
            # If the model is all zeros, set a single (y, x) pixel to the
            # floor in every band. Indexing with ``[0, 0]`` on this 3D array
            # would instead fill the entire first row of the first band.
            model[:, 0, 0] = self.floor

        return model

    def resize(self, model_box: Box) -> bool:
        """Do not resize this component.

        Returns
        -------
        result:
            Always ``False``; the bounding box is left unchanged.
        """
        return False

    def update(self, it: int, grad_log_likelihood: np.ndarray):
        """Update the model parameter for iteration ``it``."""
        self._model.update(it, grad_log_likelihood)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        # Update the model parameter in place
        parameterization(self)
        # Attach this component's gradient and proximal operators
        self._model.grad = self.grad_model
        self._model.prox = self.prox_model

    def __str__(self):
        result = f"FreeFormComponent<bands={self.bands}, shape={self.shape}>"
        return result

    def __repr__(self):
        return self.__str__()

    def to_data(self) -> ScarletComponentBaseData:
        raise NotImplementedError("Serialization not implemented for FreeFormComponent")

    def __getitem__(self, indices: Any) -> FreeFormComponent:
        """Get a sub-component containing a subset of the bands.

        Parameters
        ----------
        indices: Any
            Either a single band name or a sequence of band names.
            Each must be one of ``self.bands``.

        Returns
        -------
        component: FreeFormComponent
            A new component restricted to the requested bands. The model,
            and ``bg_rms`` when it is an array, are sliced to match.

        Raises
        ------
        IndexError :
            If the index includes a ``Box``, spatial indices, or any value
            that is not one of the component's bands.
        """
        if isinstance(indices, Box):
            raise IndexError("FreeFormComponent does not support spatial slicing")

        if isinstance(indices, (list, tuple, np.ndarray)):
            bands = tuple(indices)
        else:
            bands = (indices,)

        unknown = [band for band in bands if band not in self.bands]
        if unknown:
            raise IndexError(f"Indices {unknown} are not bands of this component: {self.bands}")

        # Translate band names into positions along the model's band axis
        band_indices = [self.bands.index(band) for band in bands]

        # bg_rms is per band, so it must be sliced consistently with the model
        bg_rms = self.bg_rms
        if isinstance(bg_rms, np.ndarray):
            bg_rms = bg_rms[band_indices]

        return FreeFormComponent(
            bands=bands,
            model=self.model[band_indices],
            model_bbox=self.bbox,
            bg_thresh=self.bg_thresh,
            bg_rms=bg_rms,
            floor=self.floor,
            peaks=self.peaks,
            min_area=self.min_area,
        )

    def __deepcopy__(self, memo: dict[int, Any]) -> FreeFormComponent:
        """Create a deep copy of this component.

        The copy is built from the current model values (``self.model``),
        not from the underlying `Parameter` object.

        Parameters
        ----------
        memo: dict[int, Any]
            A dictionary to keep track of already copied objects.

        Returns
        -------
        component : FreeFormComponent
            A new component that is a deep copy of this one.
        """
        if id(self) in memo:
            return memo[id(self)]

        component = FreeFormComponent.__new__(FreeFormComponent)
        # Register the new object before copying attributes so that any
        # recursive references resolve to it
        memo[id(self)] = component

        component.__init__(  # type: ignore[misc]
            bands=deepcopy(self.bands, memo),
            model=deepcopy(self.model, memo),
            model_bbox=deepcopy(self.bbox, memo),
            bg_thresh=self.bg_thresh,
            bg_rms=deepcopy(self.bg_rms, memo),
            floor=self.floor,
            peaks=deepcopy(self.peaks, memo),
            min_area=self.min_area,
        )
        return component

    def __copy__(self) -> FreeFormComponent:
        """Create a copy of this component.

        The new component references the same model array, bbox, bg_rms and
        peaks objects as this one.

        Returns
        -------
        component : FreeFormComponent
            A new component that is a copy of this one.
        """
        return FreeFormComponent(
            bands=self.bands,
            model=self.model,
            model_bbox=self.bbox,
            bg_thresh=self.bg_thresh,
            bg_rms=self.bg_rms,
            floor=self.floor,
            peaks=self.peaks,
            min_area=self.min_area,
        )