Coverage for python/lsst/pipe/tasks/prettyPictureMaker/_task.py: 18%

534 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-03 08:13 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Public names exported by this module (``from ... import *``).
__all__ = (
    "ChannelRGBConfig",
    "PrettyPictureTask",
    "PrettyPictureConnections",
    "PrettyPictureConfig",
    "PrettyMosaicTask",
    "PrettyMosaicConnections",
    "PrettyMosaicConfig",
    "PrettyPictureBackgroundFixerConfig",
    "PrettyPictureBackgroundFixerTask",
    "PrettyPictureStarFixerConfig",
    "PrettyPictureStarFixerTask",
)

37 

38import colour 

39import copy 

40import itertools 

41from collections.abc import Iterable, Mapping 

42from lsst.afw.image import ExposureF 

43import numpy as np 

44from typing import TYPE_CHECKING, cast, Any 

45from lsst.skymap import BaseSkyMap 

46 

47from scipy.stats import halfnorm, mode 

48from scipy.ndimage import binary_dilation 

49from scipy.interpolate import RBFInterpolator 

50from skimage.restoration import inpaint_biharmonic 

51 

52from lsst.daf.butler import Butler, DataCoordinate, DeferredDatasetHandle 

53from lsst.daf.butler import DatasetRef 

54from lsst.images import ColorImage, Projection, Box, TractFrame 

55from lsst.pex.config import Field, Config, ConfigDictField, ListField, ChoiceField 

56from lsst.pex.config.configurableActions import ConfigurableActionField 

57from lsst.pipe.base import ( 

58 PipelineTask, 

59 PipelineTaskConfig, 

60 PipelineTaskConnections, 

61 Struct, 

62 InMemoryDatasetHandle, 

63 NoWorkFound, 

64 QuantaAdjuster, 

65) 

66from lsst.rubinoxide import rbf_interpolator 

67import cv2 

68 

69from lsst.pipe.base.connectionTypes import Input, Output 

70from lsst.geom import Box2I, Point2I, Extent2I 

71from lsst.afw.image import Exposure, Mask 

72from lsst.skymap import Index2D 

73 

74from ._plugins import plugins 

75from ._colorMapper import lsstRGB 

76from ._utils import FeatheredMosaicCreator 

77from ._functors import ( 

78 BoundsRemapper, 

79 ColorScaler, 

80 LumCompressor, 

81 ExposureBracketer, 

82 GamutFixer, 

83 LocalContrastEnhancer, 

84) 

85 

86import logging 

87import tempfile 

88 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

90 

91if TYPE_CHECKING: 

92 from numpy.typing import NDArray 

93 from lsst.pipe.base import QuantumContext, InputQuantizedConnection, OutputQuantizedConnection 

94 from lsst.skymap import TractInfo, PatchInfo 

95 

96 

class PrettyPictureConnections(
    PipelineTaskConnections,
    dimensions={"tract", "patch", "skymap"},
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureTask`.

    One quantum per (tract, patch, skymap): it reads one coadd per band for
    that patch together with the skymap, and writes a color image and the
    mask fused from the per-band input masks.
    """

    inputCoadds = Input(
        doc=(
            "Per-band coadds of a single patch that are mixed into the output RGB image according to "
            "channelConfig."
        ),
        name="pretty_coadd",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )

    # Needed to attach the patch/tract geometry (WCS and bounding box) to
    # the output image.
    skyMap = Input(
        doc="The skymap which the data has been mapped onto",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    outputRGB = Output(
        doc="A RGB image created from the input data stored as a 3d array",
        name="rgb_picture",
        storageClass="ColorImage",
        dimensions=("tract", "patch", "skymap"),
    )

    outputRGBMask = Output(
        doc="A Mask corresponding to the fused masks of the input channels",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
    )

133 

134 

class ChannelRGBConfig(Config):
    """Weights describing how one input band contributes to red, green and blue.

    Each field is a multiplier applied to the band's image before it is added
    to the matching output channel. A band drawn as pure red uses
    ``r=1, g=0, b=0``; a band drawn as cyan uses ``r=0, g=1, b=1``.
    """

    r = Field[float](
        doc="The amount of red contained in this channel",
    )
    g = Field[float](
        doc="The amount of green contained in this channel",
    )
    b = Field[float](
        doc="The amount of blue contained in this channel",
    )

146 

147 

class PrettyPictureConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureConnections):
    """Configuration for `PrettyPictureTask`.

    Deprecated fields are migrated to their replacements when the config is
    frozen (see `_handle_deprecated`).
    """

    channelConfig = ConfigDictField(
        doc="A dictionary that maps band names to their rgb channel configurations",
        keytype=str,
        itemtype=ChannelRGBConfig,
        default={},
    )
    cieWhitePoint = ListField[float](
        doc="The white point of the input arrays in CIE xy chromaticity coordinates",
        maxLength=2,
        default=[0.28, 0.28],
    )
    arrayType = ChoiceField[str](
        doc="The dataset type for the output image array",
        default="uint8",
        allowed={
            "uint8": "Use 8 bit arrays, 255 max",
            "uint16": "Use 16 bit arrays, 65535 max",
            "half": "Use 16 bit float arrays, 1 max",
            "float": "Use 32 bit float arrays, 1 max",
        },
    )
    recenterNoise = Field[float](
        doc="Recenter the noise away from zero. Supplied value is in units of sigma",
        optional=True,
        default=None,
    )
    noiseSearchThreshold = Field[float](
        doc=(
            "Flux threshold below which most flux will be considered noise, used to estimate noise properties"
        ),
        default=2,
    )
    maxNoiseImbalance = Field[float](
        doc=(
            "When recentering noise, if the ratio of counts of positive pixels, to negative pixels passes "
            "this threshold, consider there to be extended low flux and only estimate noise below zero."
        ),
        default=1.5,
    )
    doPsfDeconvolve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.", default=False
    )
    # Misspelled predecessor of doPsfDeconvolve, kept for backward
    # compatibility and migrated in _handle_deprecated.
    doPSFDeconcovlve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.",
        default=False,
        deprecated="This field will be removed in v32. Use doPsfDeconvolve instead.",
        optional=True,
    )
    doRemapGamut = Field[bool](
        doc="Apply a color correction to unrepresentable colors; if False, clip them.", default=True
    )
    doExposureBrackets = Field[bool](
        doc="Apply exposure bracketing to aid in dynamic range compression", default=True
    )
    doLocalContrast = Field[bool](doc="Apply local contrast optimizations to luminance.", default=True)

    imageRemappingConfig = ConfigurableActionField[BoundsRemapper](
        doc="Action controlling normalization process"
    )
    luminanceConfig = ConfigurableActionField[LumCompressor](
        doc="Action controlling luminance scaling when making an RGB image"
    )
    localContrastConfig = ConfigurableActionField[LocalContrastEnhancer](
        doc="Action controlling the local contrast correction in RGB image production"
    )
    colorConfig = ConfigurableActionField[ColorScaler](
        doc="Action to control the color scaling process in RGB image production"
    )
    exposureBracketerConfig = ConfigurableActionField[ExposureBracketer](
        doc=(
            "Exposure scaling action used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
    )
    gamutMapperConfig = ConfigurableActionField[GamutFixer](
        doc="Action to fix pixels which lay outside RGB color gamut"
    )

    exposureBrackets = ListField[float](
        doc=(
            "Exposure scaling factors used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
        optional=True,
        default=[1.25, 1, 0.75],
        deprecated=(
            "This field will stop working in v31 and be removed in v32, "
            "please set exposureBracketerConfig.exposureBrackets"
        ),
    )
    gamutMethod = ChoiceField[str](
        doc="If doRemapGamut is True this determines the method",
        default="inpaint",
        allowed={
            "mapping": "Use a mapping function",
            "inpaint": "Use surrounding pixels to determine likely value",
        },
        deprecated="This field will stop working in v31 and be removed in v32, please set gamutMapperConfig",
    )

    def setDefaults(self):
        # Default mapping: i -> red, r -> green, g -> blue.
        self.channelConfig["i"] = ChannelRGBConfig(r=1, g=0, b=0)
        self.channelConfig["r"] = ChannelRGBConfig(r=0, g=1, b=0)
        self.channelConfig["g"] = ChannelRGBConfig(r=0, g=0, b=1)
        return super().setDefaults()

    def _handle_deprecated(self):
        """Handle deprecated configuration migration.

        Copies the values of deprecated fields into their replacements. A
        field is treated as explicitly set when its config history has more
        than one entry (i.e. something beyond the initial default).

        Notes
        -----
        The following deprecated fields are migrated (old -> new):

        - ``gamutMethod`` -> ``gamutMapperConfig.gamutMethod``
        - ``exposureBrackets`` -> ``exposureBracketerConfig.exposureBrackets``;
          explicitly setting it to `None` also sets ``doExposureBrackets``
          to `False`.
        - ``localContrastConfig.doLocalContrast`` -> ``doLocalContrast``
        - ``doPSFDeconcovlve`` -> ``doPsfDeconvolve``
        """
        # check if gamutMethod is set
        if len(self._history["gamutMethod"]) > 1:
            # This has been set in config, update it in the new location
            self.gamutMapperConfig.gamutMethod = self.gamutMethod

        if len(self._history["exposureBrackets"]) > 1:
            self.exposureBracketerConfig.exposureBrackets = self.exposureBrackets
            if self.exposureBrackets is None:
                self.doExposureBrackets = False

        # The sub-config flag is the deprecated location; the top-level
        # doLocalContrast is what PrettyPictureTask.run consults.
        if len(self.localContrastConfig._history["doLocalContrast"]) > 1:
            self.doLocalContrast = self.localContrastConfig.doLocalContrast

        # Handle doPsfDeconcovlve typo fix
        if len(self._history["doPSFDeconcovlve"]) > 1:
            self.doPsfDeconvolve = self.doPSFDeconcovlve

    def freeze(self):
        """Migrate deprecated fields, then make the config immutable."""
        # Only migrate on the first freeze; a frozen config cannot be edited.
        if self._frozen is not True:
            self._handle_deprecated()
        super().freeze()

293class PrettyPictureTask(PipelineTask): 

294 """Turns inputs into an RGB image.""" 

295 

296 _DefaultName = "prettyPicture" 

297 ConfigClass = PrettyPictureConfig 

298 

299 config: ConfigClass 

300 

301 def _find_normal_stats(self, array): 

302 """Calculate standard deviation from negative values using half-normal distribution. 

303 

304 Raises 

305 ------ 

306 ValueError 

307 Array dimension validation fails. 

308 

309 Parameters 

310 ---------- 

311 array : `numpy.array` 

312 Input array of numerical values. 

313 

314 Returns 

315 ------- 

316 mean : `float` 

317 The central moment of the distribution 

318 sigma : `float` 

319 Estimated standard deviation from negative values. Returns np.inf if: 

320 - No negative values exist in the array 

321 - Half-normal fitting fails 

322 """ 

323 # Extract negative values efficiently 

324 values_noise = array[array < self.config.noiseSearchThreshold] 

325 

326 # find the mode 

327 center = mode(np.round(values_noise, 2)).mode 

328 

329 # extract the negative values 

330 values_neg = array[array < center] 

331 

332 # Return infinity if no negative values found 

333 if values_neg.size == 0: 

334 return 0, np.inf 

335 

336 try: 

337 # Fit half-normal distribution to absolute negative values 

338 _, sigma = halfnorm.fit(np.abs(values_neg - center), floc=0) 

339 mu = center 

340 except (ValueError, RuntimeError): 

341 # Handle fitting failures (e.g., constant data, optimization issues) 

342 return 0, np.inf 

343 

344 # examine for excess positive flux, this means there is contaminating signal 

345 new_cut = array[array < (mu + 3 * sigma)] 

346 positivity_ratio = np.sum(new_cut > mu) / np.sum(new_cut < mu) 

347 

348 if positivity_ratio > self.config.maxNoiseImbalance: 

349 # This means there is an excess flux, possibly diffuse source, 

350 # only estimate around zero. 

351 mu, sigma = halfnorm.fit(np.abs(values_noise[values_noise < 0]), floc=0) 

352 

353 return mu, sigma 

354 

355 def _match_sigmas_and_recenter(self, *arrays, factor=1): 

356 """Scale array values to match minimum standard deviation across arrays 

357 and recenter noise. 

358 

359 Adjusts values below each array's sigma by scaling and shifting them to 

360 align with the minimum sigma value across all input arrays. This operates 

361 in-place for efficiency. 

362 

363 Parameters 

364 ---------- 

365 *arrays : any number of `numpy.array` 

366 Variable number of input arrays to process. 

367 factor : float, optional 

368 Scaling factor for adjustments (default: 1). 

369 

370 """ 

371 # Calculate standard deviations for all arrays 

372 sigmas = [] 

373 mus = [] 

374 for arr in arrays: 

375 m, s = self._find_normal_stats(arr) 

376 mus.append(m) 

377 sigmas.append(s) 

378 mus = np.array(mus) 

379 sigmas = np.array(sigmas) 

380 

381 # If no sigmas could be determined, return the original 

382 # arrays. 

383 if not np.any(np.isfinite(sigmas)): 

384 return 

385 

386 min_sig = np.min(sigmas) 

387 

388 for mu, sigma, array in zip(mus, sigmas, arrays): 

389 # Identify values below the array's sigma threshold 

390 lower_pos = (array - mu) < sigma 

391 

392 # Skip processing if sigma is invalid 

393 if not np.isfinite(sigma): 

394 continue 

395 

396 # Calculate scaling ratio relative to minimum sigma 

397 sigma_ratio = min_sig / sigma 

398 

399 # Apply adjustment to qualifying values 

400 array[lower_pos] = (array[lower_pos] - mu) * sigma_ratio + min_sig * factor 

401 

402 def run( 

403 self, 

404 images: Mapping[str, Exposure], 

405 image_wcs: Projection[Any] | None = None, 

406 image_box: Box | None = None, 

407 ) -> Struct: 

408 """Turns the input arguments in arguments into an RGB array. 

409 

410 Parameters 

411 ---------- 

412 images : `Mapping` of `str` to `Exposure` 

413 A mapping of input images and the band they correspond to. 

414 image_wcs : `~lsst.images.Projection`, optional 

415 A projection describing the sky coordinate of each pixel. 

416 image_box : `~lsst.images.Box`, optional 

417 A box that defines this image as part of a larger region. 

418 

419 Returns 

420 ------- 

421 result : `Struct` 

422 A struct with the corresponding RGB image, and mask used in 

423 RGB image construction. The struct will have the attributes 

424 outputRGB and outputRGBMask. Each of the outputs will 

425 be a `~lsst.images.ColorImage` object. 

426 

427 Notes 

428 ----- 

429 Construction of input images are made easier by use of the 

430 makeInputsFrom* methods. 

431 """ 

432 channels = {} 

433 shape = (0, 0) 

434 jointMask: None | NDArray = None 

435 maskDict: Mapping[str, int] = {} 

436 doJointMaskInit = False 

437 if jointMask is None: 

438 doJointMask = True 

439 doJointMaskInit = True 

440 for channel, imageExposure in images.items(): 

441 imageArray = imageExposure.image.array 

442 # run all the plugins designed for array based interaction 

443 for plug in plugins.channel(): 

444 imageArray = plug( 

445 imageArray, imageExposure.mask.array, imageExposure.mask.getMaskPlaneDict(), self.config 

446 ).astype(np.float32) 

447 channels[channel] = imageArray 

448 # These operations are trivial look-ups and don't matter if they 

449 # happen in each loop. 

450 shape = imageArray.shape 

451 maskDict = imageExposure.mask.getMaskPlaneDict() 

452 if doJointMaskInit: 

453 jointMask = np.zeros(shape, dtype=imageExposure.mask.dtype) 

454 doJointMaskInit = False 

455 if doJointMask: 

456 jointMask |= imageExposure.mask.array 

457 

458 # mix the images to RGB 

459 imageRArray = np.zeros(shape, dtype=np.float32) 

460 imageGArray = np.zeros(shape, dtype=np.float32) 

461 imageBArray = np.zeros(shape, dtype=np.float32) 

462 

463 for band, image in channels.items(): 

464 if band not in self.config.channelConfig: 

465 logger.info(f"{band} image found but not requested in RGB image, skipping") 

466 continue 

467 mix = self.config.channelConfig[band] 

468 if mix.r: 

469 imageRArray += mix.r * image 

470 if mix.g: 

471 imageGArray += mix.g * image 

472 if mix.b: 

473 imageBArray += mix.b * image 

474 

475 exposure = next(iter(images.values())) 

476 box: Box2I = exposure.getBBox() 

477 boxCenter = box.getCenter() 

478 try: 

479 psf = exposure.psf.computeImage(boxCenter).array 

480 except Exception: 

481 psf = None 

482 

483 if self.config.recenterNoise: 

484 self._match_sigmas_and_recenter( 

485 imageRArray, imageGArray, imageBArray, factor=self.config.recenterNoise 

486 ) 

487 

488 # assert for typing reasons 

489 assert jointMask is not None 

490 # Run any image level correction plugins 

491 colorImage = np.zeros((*imageRArray.shape, 3)) 

492 colorImage[:, :, 0] = imageRArray 

493 colorImage[:, :, 1] = imageGArray 

494 colorImage[:, :, 2] = imageBArray 

495 for plug in plugins.partial(): 

496 colorImage = plug(colorImage, jointMask, maskDict, self.config) 

497 

498 # Filter the local contrast parameters for diffusion that are None 

499 # This is so we only apply key word overrides that are specifically set. 

500 local_contrast_config = self.config.localContrastConfig.toDict() 

501 to_remove = [] 

502 for k, v in local_contrast_config["diffusionFunction"].items(): 

503 if v is None: 

504 to_remove.append(k) 

505 for item in to_remove: 

506 local_contrast_config["diffusionControl"].pop(item) 

507 

508 colorImage = lsstRGB( 

509 colorImage[:, :, 0], 

510 colorImage[:, :, 1], 

511 colorImage[:, :, 2], 

512 local_contrast=self.config.localContrastConfig if self.config.doLocalContrast else None, 

513 scale_lum=self.config.luminanceConfig, 

514 scale_color=self.config.colorConfig, 

515 remap_bounds=self.config.imageRemappingConfig, 

516 bracketing_function=( 

517 self.config.exposureBracketerConfig if self.config.doExposureBrackets else None 

518 ), 

519 gamut_remapping_function=self.config.gamutMapperConfig if self.config.doRemapGamut else None, 

520 cieWhitePoint=tuple(self.config.cieWhitePoint), # type: ignore 

521 psf=psf if self.config.doPsfDeconvolve else None, 

522 ) 

523 

524 # Find the dataset type and thus the maximum values as well 

525 maxVal: int | float 

526 match self.config.arrayType: 

527 case "uint8": 

528 dtype = np.uint8 

529 maxVal = 255 

530 case "uint16": 

531 dtype = np.uint16 

532 maxVal = 65535 

533 case "half": 

534 dtype = np.half 

535 maxVal = 1.0 

536 case "float": 

537 dtype = np.float32 

538 maxVal = 1.0 

539 case _: 

540 assert True, "This code path should be unreachable" 

541 

542 # lsstRGB returns an image in 0-1 scale it to the maximum value 

543 colorImage *= maxVal # type: ignore 

544 

545 # pack the joint mask back into a mask object 

546 lsstMask = Mask(width=jointMask.shape[1], height=jointMask.shape[0], planeDefs=maskDict) 

547 lsstMask.array = jointMask # type: ignore 

548 return Struct( 

549 outputRGB=ColorImage(colorImage.astype(dtype), bbox=image_box, projection=image_wcs), 

550 outputRGBMask=lsstMask, 

551 ) # type: ignore 

552 

553 def runQuantum( 

554 self, 

555 butlerQC: QuantumContext, 

556 inputRefs: InputQuantizedConnection, 

557 outputRefs: OutputQuantizedConnection, 

558 ) -> None: 

559 imageRefs: list[DatasetRef] = inputRefs.inputCoadds 

560 sortedImages = self.makeInputsFromRefs(imageRefs, butlerQC) 

561 if not sortedImages: 

562 requested = ", ".join(self.config.channelConfig.keys()) 

563 raise NoWorkFound(f"No input images of band(s) {requested}") 

564 

565 # get the patch tract bounding box and wcs 

566 skymap = butlerQC.get(inputRefs.skyMap) 

567 quantumDataId = butlerQC.quantum.dataId 

568 tractInfo = skymap[quantumDataId["tract"]] 

569 patchInfo = tractInfo[quantumDataId["patch"]] 

570 outputs = self.run( 

571 images=sortedImages, 

572 image_wcs=Projection.from_legacy( 

573 patchInfo.wcs, 

574 TractFrame( 

575 skymap=quantumDataId["skymap"], 

576 tract=quantumDataId["tract"], 

577 bbox=Box.from_legacy(tractInfo.bbox), 

578 ), 

579 ), 

580 image_box=Box.from_legacy(patchInfo.getOuterBBox()), 

581 ) 

582 butlerQC.put(outputs, outputRefs) 

583 

584 def makeInputsFromRefs( 

585 self, refs: Iterable[DatasetRef], butler: Butler | QuantumContext 

586 ) -> dict[str, Exposure]: 

587 r"""Make valid inputs for the run method from butler references. 

588 

589 Parameters 

590 ---------- 

591 refs : `Iterable` of `DatasetRef` 

592 Some `Iterable` container of `Butler` `DatasetRef`\ s 

593 butler : `Butler` or `QuantumContext` 

594 This is the object that fetches the input data. 

595 

596 Returns 

597 ------- 

598 sortedImages : `dict` of `str` to `Exposure` 

599 A dictionary of `Exposure`\ s keyed by the band they 

600 correspond to. 

601 """ 

602 sortedImages: dict[str, Exposure] = {} 

603 for ref in refs: 

604 key: str = cast(str, ref.dataId["band"]) 

605 image = butler.get(ref) 

606 sortedImages[key] = image 

607 return sortedImages 

608 

609 def makeInputsFromArrays(self, **kwargs) -> dict[str, DeferredDatasetHandle]: 

610 r"""Make valid inputs for the run method from numpy arrays. 

611 

612 Parameters 

613 ---------- 

614 kwargs : `numpy.ndarray` 

615 This is standard python kwargs where the left side of the equals 

616 is the data band, and the right side is the corresponding `numpy.ndarray` 

617 array. 

618 

619 Returns 

620 ------- 

621 sortedImages : `dict` of `str` to \ 

622 `~lsst.daf.butler.DeferredDatasetHandle` 

623 A dictionary of `~lsst.daf.butlger.DeferredDatasetHandle`\ s keyed 

624 by the band they correspond to. 

625 """ 

626 # ignore type because there aren't proper stubs for afw 

627 temp = {} 

628 for key, array in kwargs.items(): 

629 temp[key] = Exposure(Box2I(Point2I(0, 0), Extent2I(*array.shape)), dtype=array.dtype) 

630 temp[key].image.array[:] = array 

631 

632 return self.makeInputsFromExposures(**temp) 

633 

634 def makeInputsFromExposures(self, **kwargs) -> dict[int, DeferredDatasetHandle]: 

635 r"""Make valid inputs for the run method from `Exposure` objects. 

636 

637 Parameters 

638 ---------- 

639 kwargs : `Exposure` 

640 This is standard python kwargs where the left side of the equals 

641 is the data band, and the right side is the corresponding 

642 `Exposure`. 

643 

644 Returns 

645 ------- 

646 sortedImages : `dict` of `int` to \ 

647 `~lsst.daf.butler.DeferredDatasetHandle` 

648 A dictionary of `~lsst.daf.butler.DeferredDatasetHandle`\ s keyed 

649 by the band they correspond to. 

650 """ 

651 sortedImages = {} 

652 for key, value in kwargs.items(): 

653 sortedImages[key] = value 

654 return sortedImages 

655 

656 

class PrettyPictureBackgroundFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap", "band"),
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureBackgroundFixerTask`.

    Each quantum is one (tract, patch, skymap, band). Besides the coadd of
    its own patch, `adjust_all_quanta` adds as inputs the existing coadds of
    the adjacent patches (including diagonal neighbors) in the same tract
    and band.
    """

    inputCoadd = Input(
        doc=("Input coadd for which the background is to be removed"),
        name="{coaddTypeName}CoaddPsfMatched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
    outputCoadd = Output(
        doc="The coadd with the background fixed and subtracted",
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )

    def adjust_all_quanta(self, adjuster: QuantaAdjuster) -> None:
        """Add the coadds of neighboring patches as inputs to each quantum.

        Parameters
        ----------
        adjuster : `QuantaAdjuster`
            Object used to enumerate the quanta being built, look up their
            inputs, load the skymap, and add new input edges.
        """
        # At this stage of a QG build, we won't necessarily have pruned
        # quanta that don't have inputs, so instead of a set of quantum data
        # IDs, we want to start from a set of quantum *input* data IDs. We'll
        # just ignore quanta that don't; they'll get pruned later.
        flat_data_ids: set[DataCoordinate] = set()
        for quantum_data_id in adjuster.iter_data_ids():
            flat_data_ids.update(adjuster.get_inputs(quantum_data_id)["inputCoadd"])
        # Reorganize task data IDs into
        #
        # {skymap: {(band, tract): [patches]}}
        #
        # It's unlikely we'll ever get more than one skymap in a single QG
        # build, of course, but not impossible.
        hierarchical_data_ids: dict[str, dict[tuple[str, int], list[int]]] = {}
        for input_data_id in flat_data_ids:
            hierarchical_data_ids.setdefault(input_data_id["skymap"], {}).setdefault(
                (input_data_id["band"], input_data_id["tract"]), []
            ).append(input_data_id["patch"])
        for skymap_name, data_ids_for_skymap in hierarchical_data_ids.items():
            # We need to load the skyMap to turn single-int patch IDs into
            # grid (x, y) pairs, so we can offset those to find neighbors.
            skyMap = adjuster.butler.get(BaseSkyMap.SKYMAP_DATASET_TYPE_NAME, skymap=skymap_name)
            for (band, tract_id), patches in data_ids_for_skymap.items():
                tract_info = skyMap[tract_id]
                num_x, num_y = tract_info.num_patches.x, tract_info.num_patches.y
                base_data_id = DataCoordinate.standardize(
                    skymap=skymap_name,
                    tract=tract_id,
                    band=band,
                    universe=adjuster.butler.dimensions,
                )
                for patch_id in patches:
                    patch_index: Index2D = tract_info.getPatchIndexPair(patch_id)
                    quantum_data_id = DataCoordinate.standardize(base_data_id, patch=patch_id)
                    # Find all adjacent patches (including corner neighbors).
                    for offset_x, offset_y in itertools.product((-1, 0, 1), (-1, 0, 1)):
                        if not (offset_x or offset_y):
                            # Skip the input that matches the quantum data ID;
                            # it's already got an edge to this quantum.
                            continue

                        proposed_patch_x = patch_index.x + offset_x
                        proposed_patch_y = patch_index.y + offset_y
                        # Valid grid indices are 0..num-1 along each axis. An
                        # index equal to num_x may still convert to an
                        # in-range sequential ID belonging to an unrelated
                        # patch, which the membership check below would not
                        # catch, so reject it here.
                        if not (0 <= proposed_patch_x < num_x and 0 <= proposed_patch_y < num_y):
                            continue
                        neighbor_patch_id = tract_info.getSequentialPatchIndexFromPair(
                            Index2D(x=proposed_patch_x, y=proposed_patch_y)
                        )
                        neighbor_data_id = DataCoordinate.standardize(base_data_id, patch=neighbor_patch_id)
                        if neighbor_data_id not in flat_data_ids:
                            # Skip neighbors that have no input coadd.
                            continue
                        # Finally we can add an edge from the neighboring
                        # input to this quantum.
                        adjuster.add_input(quantum_data_id, "inputCoadd", neighbor_data_id)

739 

740 

class PrettyPictureBackgroundFixerConfig(
    PipelineTaskConfig, pipelineConnections=PrettyPictureBackgroundFixerConnections
):
    """Configuration for `PrettyPictureBackgroundFixerTask`."""

    use_detection_mask = Field[bool](
        doc="Use the detection mask to determine background instead of empirically finding it in this task",
        default=False,
    )
    num_background_bins = Field[int](
        doc="The number of bins along each axis when determining background", default=5
    )
    min_bin_fraction = Field[float](
        doc="Bins with fewer pixels than this fraction of the total will be ignored", default=0.1
    )

    pos_sigma_multiplier = Field[float](
        doc="How many sigma to consider as background in the positive direction", default=2
    )
    extra_pixel_rad = Field[int](
        doc=(
            "If there are neighboring input images, use as control points those lying within the radius "
            "of the input image, (x^2 + y^2)^(1/2), plus this many pixels of the center of the input image."
        ),
        default=2000,
    )
    max_flux_imbalance = Field[float](
        doc="When determining background, if the ratio of counts of positive pixels to negative pixels"
        " passes this threshold, consider there to be extended low flux and only estimate noise below zero.",
        default=1.7,
    )
    max_search_flux = Field[float](
        doc="Pixels above this value should never be considered background", default=3
    )

774 

775 

776class PrettyPictureBackgroundFixerTask(PipelineTask): 

777 """Empirically flatten an images background. 

778 

779 Many astrophysical images have backgrounds with imperfections in them. 

780 This Task attempts to determine control points which are considered 

781 background values, and fits a radial basis function model to those 

782 points. This model is then subtracted off the image. 

783 

784 """ 

785 

786 _DefaultName = "prettyPictureBackgroundFixer" 

787 ConfigClass = PrettyPictureBackgroundFixerConfig 

788 

789 config: ConfigClass 

790 

791 def _tile_slices(self, arr, R, C): 

792 """Generate slices for tiling an array. 

793 

794 This function divides an array into a grid of tiles and returns a list of 

795 slice objects representing each tile. It handles cases where the array 

796 dimensions are not evenly divisible by the number of tiles in each 

797 dimension, distributing the remainder among the tiles. 

798 

799 Parameters 

800 ---------- 

801 arr : `numyp.ndarray` 

802 The input array to be tiled. Used only to determine the array's shape. 

803 R : `int` 

804 The number of tiles in the row dimension. 

805 C : `int` 

806 The number of tiles in the column dimension. 

807 

808 Returns 

809 ------- 

810 slices : `list` of `tuple` 

811 A list of tuples, where each tuple contains two `slice` objects 

812 representing the row and column slices for a single tile. 

813 """ 

814 M = arr.shape[0] 

815 N = arr.shape[1] 

816 

817 # Function to compute slices for a given dimension size and number of divisions 

818 def get_slices(total_size: int, num_divisions: int) -> list[tuple[int, int]]: 

819 """Generate slice ranges for dividing a size into equal parts. 

820 

821 Parameters 

822 ---------- 

823 total_size : `int` 

824 Total size to be divided into slices. 

825 num_divisions : `int` 

826 Number of divisions to create. 

827 

828 Returns 

829 ------- 

830 `list` of `tuple` of `int` 

831 List of (start, end) tuples representing each slice. 

832 

833 Notes 

834 ----- 

835 This function divides the total_size into num_divisions equal parts. 

836 If the division is not exact, the remainder is distributed by adding 

837 1 to the first 'remainder' slices, ensuring balanced distribution. 

838 """ 

839 base = total_size // num_divisions 

840 remainder = total_size % num_divisions 

841 slices = [] 

842 start = 0 

843 for i in range(num_divisions): 

844 end = start + base 

845 if i < remainder: 

846 end += 1 

847 slices.append((start, end)) 

848 start = end 

849 return slices 

850 

851 # Get row and column slices 

852 row_slices = get_slices(M, R) 

853 col_slices = get_slices(N, C) 

854 

855 # Generate all possible tile combinations of row and column slices 

856 tiles = [] 

857 for rs in row_slices: 

858 r_start, r_end = rs 

859 for cs in col_slices: 

860 c_start, c_end = cs 

861 tile_slice = (slice(r_start, r_end), slice(c_start, c_end)) 

862 tiles.append(tile_slice) 

863 

864 return tiles 

865 

866 @staticmethod 

867 def _findImageStatistics(image, pos_sigma_mult=1, max_flux_imbalance=1.7, max_search_flux=3): 

868 """Find pixels that are likely to be background based on image statistics. 

869 

870 This method estimates background pixels by analyzing the distribution of 

871 pixel values in the image. It uses the median as an estimate of the background 

872 level and fits a half-normal distribution to values below the median to 

873 determine the background sigma. Pixels below a threshold (mean + sigma) are 

874 classified as background. 

875 

876 Parameters 

877 ---------- 

878 image : `numpy.ndarray` 

879 Input image array for which to find background pixels. 

880 pos_sigma_mult : `float` 

881 How many sigma to consider as background in the positive direction 

882 max_flux_imbalance : `float` 

883 Limit on the ratio of the area below the determined average to that 

884 above. If the ratio is in excess of this value, then there is 

885 assumed to be diffuse background flux in the image, and zero is 

886 assumed to be the background level. Some excess flux is expected as 

887 there is a real distribution of faint flux. 

888 max_search_flux : `float` 

889 Pixels above this value should never be considered background 

890 """ 

891 # Find the median value in the image, which is likely to be 

892 # close to average background. Note this doesn't work well 

893 # in fields with high density or diffuse flux. 

894 maxLikely = np.median(image[image < max_search_flux], axis=None) 

895 

896 # find all the pixels that are fainter than this 

897 # and find the std. This is just used as an initialization 

898 # parameter and doesn't need to be accurate. 

899 mask = image < maxLikely 

900 initial_std = (image[mask] - maxLikely).std() 

901 

902 # new estimate 

903 sub_image = image[image < max(maxLikely + 3 * initial_std, max_search_flux)] 

904 

905 center = mode(np.round(sub_image, 2)).mode 

906 

907 # Don't do anything if there are no pixels to check 

908 if np.any(mask): 

909 # use a minimizer to determine best sigma for a Gaussian 

910 # given only samples below the mean of the Gaussian. # 

911 _, sigma_hat = halfnorm.fit(np.abs(image[image < center] - center), floc=0) 

912 mu_hat = center 

913 else: 

914 mu_hat, sigma_hat = (maxLikely, 3 * initial_std) 

915 

916 new_cut = image[image < (mu_hat + 3 * sigma_hat)] 

917 positivity_ratio = np.sum(new_cut > mu_hat) / np.sum(new_cut < mu_hat) 

918 

919 if positivity_ratio > max_flux_imbalance: 

920 # This means there is an excess flux, possibly diffuse source, 

921 # assume all flux is diffuse 

922 return np.inf, -np.inf 

923 

924 # create a new masking threshold that's the determined 

925 # mean plus std from the fit 

926 threshold_pos = min(mu_hat + pos_sigma_mult * sigma_hat, max_search_flux) 

927 

928 # The reason a lower threshold is needed is that occasionally for some 

929 # reason we have images with a few random pixels that are REALLY negative. 

930 # We generate control points for background fixing from bins of the pixels we 

931 # consider background. When there are a few pixels that are so negative, it brings 

932 # the estimate way way down, and the control point then drags the background 

933 # unrealistically high. By excluding them we get a more realistic estimate, and 

934 # the only consequence is that these pixels that were excluded will still end up 

935 # negative, and be clipped to black. In principle we could add a lower clipped 

936 # boundary too, but it would not change the results much at all as the estimation 

937 # of the sigma at the low side is robust and we do not expect contaminating flux 

938 # from the low side (by construction of the algorithm). 

939 threshold_neg = mu_hat - pos_sigma_mult * sigma_hat 

940 return threshold_pos, threshold_neg 

941 

942 def findBackgroundPixels(self, image, threshold_pair=None): 

943 """Find pixels that are likely to be background based on image statistics. 

944 

945 This method estimates background pixels by analyzing the distribution of 

946 pixel values in the image. It uses the median as an estimate of the background 

947 level and fits a half-normal distribution to values below the median to 

948 determine the background sigma. Pixels below a threshold (mean + sigma) are 

949 classified as background. 

950 

951 Parameters 

952 ---------- 

953 image : `numpy.ndarray` 

954 Input image array for which to find background pixels. 

955 threshold_pair : `tuple` of `float`, `float` or None 

956 A tuple representing the bottom and top values for which all pixels inside 

957 these bounds should be considered background. If `None` theses bounds are 

958 determined from the image statistics. 

959 

960 Returns 

961 ------- 

962 result : `numpy.ndarray` 

963 Boolean mask array where True indicates background pixels. 

964 

965 Notes 

966 ----- 

967 This method works best for images with relatively uniform background. It may 

968 not perform well in fields with high density or diffuse flux, as noted in 

969 the implementation comments. 

970 """ 

971 if threshold_pair is None: 

972 threshhold_pos, threshhold_neg = self._findImageStatistics( 

973 image, 

974 self.config.pos_sigma_multiplier, 

975 self.config.max_flux_imbalance, 

976 self.config.max_search_flux, 

977 ) 

978 else: 

979 threshhold_pos, threshhold_neg = threshold_pair 

980 if np.isinf(threshhold_pos): 

981 return None 

982 # mean plus std from the fit 

983 image_mask = (image < threshhold_pos) * (image > threshhold_neg) 

984 return image_mask 

985 

986 def _findControlPoints( 

987 self, 

988 exposure: Exposure, 

989 origin: Point2I, 

990 use_detection_mask: bool = False, 

991 threshhold_pair: tuple[float, float] | None = None, 

992 ): 

993 if use_detection_mask: 

994 mask_plane_dict = exposure.mask.getMaskPlaneDict() 

995 image_mask = ~(exposure.mask.array & 2 ** mask_plane_dict["DETECTED"]) 

996 else: 

997 image_mask = self.findBackgroundPixels(exposure.image.array, threshold_pair=threshhold_pair) 

998 

999 yloc = [] 

1000 xloc = [] 

1001 values = [] 

1002 

1003 if image_mask is None: 

1004 logger.debug("returning early from _findControlPoints") 

1005 return values, yloc, xloc 

1006 

1007 tiles = self._tile_slices( 

1008 exposure.image.array, self.config.num_background_bins, self.config.num_background_bins 

1009 ) 

1010 

1011 # adjust for the offset of the origin 

1012 this_origin: Point2I = exposure.getBBox().getBegin() 

1013 offset: Extent2I = this_origin - origin 

1014 x_offset = offset.getX() 

1015 y_offset = offset.getY() 

1016 

1017 # for each box find the middle position and the median background 

1018 # value in the window. 

1019 for i, (xslice, yslice) in enumerate(tiles): 

1020 ypos = (yslice.stop - yslice.start) / 2 + yslice.start 

1021 xpos = (xslice.stop - xslice.start) / 2 + xslice.start 

1022 window = exposure.image.array[yslice, xslice][image_mask[yslice, xslice]] 

1023 # make sure each bin is at least 1% filled 

1024 min_fill = int((yslice.stop - yslice.start) ** 2 * self.config.min_bin_fraction) 

1025 if window.size > min_fill: 

1026 value = np.mean(window) 

1027 else: 

1028 continue 

1029 values.append(value) 

1030 yloc.append(ypos + y_offset) 

1031 xloc.append(xpos + x_offset) 

1032 

1033 return values, yloc, xloc 

1034 

1035 def fixBackground(self, image, yloc: list[float], xloc: list[float], values: list[float]): 

1036 """Estimate and subtract the background from an image. 

1037 

1038 This function estimates the background level in an image using supplied control 

1039 values using Gaussian fitting and radial basis function interpolation. 

1040 It aims to provide a more accurate background estimation than a simple median 

1041 filter, especially in images with varying background levels. 

1042 

1043 Parameters 

1044 ---------- 

1045 image : `numpy.ndarray` 

1046 The input image as a NumPy array. 

1047 yloc : `list` of `float` 

1048 The list of y control points 

1049 xloc : `list` of `float` 

1050 The list of x control points 

1051 values : `list` of `float` 

1052 The list of the values at the control points 

1053 

1054 Returns 

1055 ------- 

1056 numpy.ndarray 

1057 An array representing the estimated background level across the image. 

1058 """ 

1059 

1060 # create an interpolant for the background and interpolate over the image. 

1061 inter = RBFInterpolator( 

1062 np.vstack((yloc, xloc)).T, 

1063 values, 

1064 kernel="thin_plate_spline", 

1065 degree=4, 

1066 smoothing=0.05, 

1067 neighbors=None, 

1068 ) 

1069 

1070 backgrounds = rbf_interpolator.fast_rbf_interpolation_on_grid(inter, image.shape) 

1071 

1072 return backgrounds 

1073 

1074 def run(self, inputCoadd: Exposure, neighbors: list[Exposure] | None = None): 

1075 """Estimate a background for an input Exposure and remove it. 

1076 

1077 Parameters 

1078 ---------- 

1079 inputCoadd : `Exposure` 

1080 The exposure the background will be removed from. 

1081 neighbors : `list` of `Exposure` 

1082 Neighboring `Exposure` objects that can be used to constrain 

1083 backgrounds across boundaries. 

1084 

1085 Returns 

1086 ------- 

1087 result : `Struct` 

1088 A `Struct` that contains the exposure with the background removed. 

1089 This `Struct` will have an attribute named ``outputCoadd``. 

1090 

1091 """ 

1092 origin = inputCoadd.getBBox().getBegin() 

1093 input_y_dim, input_x_dim = inputCoadd.image.array.shape 

1094 input_rad = np.sqrt((input_y_dim / 2) ** 2 + (input_x_dim / 2) ** 2) 

1095 inside_rad = input_rad + self.config.extra_pixel_rad 

1096 center_x = input_x_dim / 2 

1097 center_y = input_y_dim / 2 

1098 

1099 # calculate joint statistics 

1100 if self.config.use_detection_mask is False: 

1101 background_pixels = inputCoadd.image.array[inputCoadd.image.array < self.config.max_search_flux] 

1102 if neighbors: 

1103 for n_exp in neighbors: 

1104 background_pixels = np.append( 

1105 background_pixels, n_exp.image.array[n_exp.image.array < self.config.max_search_flux] 

1106 ) 

1107 joint_thresh = self._findImageStatistics( 

1108 background_pixels, 

1109 self.config.pos_sigma_multiplier, 

1110 self.config.max_flux_imbalance, 

1111 self.config.max_search_flux, 

1112 ) 

1113 # There is no background to be found, return early 

1114 if np.isinf(joint_thresh[0]): 

1115 joint_thresh = None 

1116 

1117 else: 

1118 joint_thresh = None 

1119 

1120 values, yloc, xloc = self._findControlPoints( 

1121 inputCoadd, origin, self.config.use_detection_mask, joint_thresh 

1122 ) 

1123 

1124 if len(values) == 0: 

1125 output = ExposureF(inputCoadd, deep=True) 

1126 logger.warning( 

1127 "No control points could be determined, likely due to extended flux, leaving background" 

1128 " unmodified." 

1129 ) 

1130 return Struct(outputCoadd=output) 

1131 

1132 if neighbors: 

1133 for n_exp in neighbors: 

1134 tmp_values, tmp_yloc, tmp_xloc = self._findControlPoints( 

1135 n_exp, origin, self.config.use_detection_mask, joint_thresh 

1136 ) 

1137 

1138 for value, y_pos, x_pos in zip(tmp_values, tmp_yloc, tmp_xloc): 

1139 if np.sqrt((y_pos - center_y) ** 2 + (x_pos - center_x) ** 2) < inside_rad: 

1140 values.append(value) 

1141 yloc.append(y_pos) 

1142 xloc.append(x_pos) 

1143 

1144 # At least 15 points are requred for TPS with 4th order polynomial 

1145 if len(yloc) < 15: 

1146 logger.warning( 

1147 "Not enough control points could be determined, likely due to extended flux, leaving" 

1148 " background unmodified." 

1149 ) 

1150 return Struct(outputCoadd=inputCoadd) 

1151 

1152 background = self.fixBackground(inputCoadd.image.array, yloc, xloc, values) 

1153 

1154 # create a copy to mutate 

1155 output = ExposureF(inputCoadd, deep=True) 

1156 output.image.array -= background 

1157 return Struct(outputCoadd=output) 

1158 

1159 def runQuantum( 

1160 self, 

1161 butlerQC: QuantumContext, 

1162 inputRefs: InputQuantizedConnection, 

1163 outputRefs: OutputQuantizedConnection, 

1164 ) -> None: 

1165 quantum_patch = butlerQC.quantum.dataId["patch"] 

1166 primary_ref = None 

1167 neighbor_refs = [] 

1168 for ref in inputRefs.inputCoadd: 

1169 if quantum_patch == ref.dataId["patch"]: 

1170 primary_ref = ref 

1171 else: 

1172 neighbor_refs.append(ref) 

1173 if primary_ref is None: 

1174 # This should be unreachable 

1175 raise RuntimeError( 

1176 "There is a major problem, the input ref associated with this quantum can't be found." 

1177 ) 

1178 inputCoadd = butlerQC.get(primary_ref) 

1179 neighbors = [] 

1180 for n_ref in neighbor_refs: 

1181 neighbors.append(butlerQC.get(n_ref)) 

1182 results = self.run(inputCoadd, neighbors=neighbors if neighbors else None) 

1183 butlerQC.put(results, outputRefs) 

1184 

1185 

1186class PrettyPictureStarFixerConnections( 

1187 PipelineTaskConnections, 

1188 dimensions=("tract", "patch", "skymap"), 

1189): 

1190 inputCoadd = Input( 

1191 doc=("Input coadd for which the background is to be removed"), 

1192 name="pretty_picture_coadd_bg_subtracted", 

1193 storageClass="ExposureF", 

1194 dimensions=("tract", "patch", "skymap", "band"), 

1195 multiple=True, 

1196 ) 

1197 outputCoadd = Output( 

1198 doc="The coadd with the background fixed and subtracted", 

1199 name="pretty_picture_coadd_fixed_stars", 

1200 storageClass="ExposureF", 

1201 dimensions=("tract", "patch", "skymap", "band"), 

1202 multiple=True, 

1203 ) 

1204 

1205 

1206class PrettyPictureStarFixerConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureStarFixerConnections): 

1207 brightnessThresh = Field[float]( 

1208 doc="The flux value below which pixels with SAT or NO_DATA bits will be ignored" 

1209 ) 

1210 

1211 

1212class PrettyPictureStarFixerTask(PipelineTask): 

1213 """This class fixes up regions in an image where there is no, or bad data. 

1214 

1215 The fixes done by this task are overwhelmingly comprised of the cores of 

1216 bright stars for which there is no data. 

1217 """ 

1218 

1219 _DefaultName = "prettyPictureStarFixer" 

1220 ConfigClass = PrettyPictureStarFixerConfig 

1221 

1222 config: ConfigClass 

1223 

1224 def run(self, inputs: Mapping[str, ExposureF]) -> Struct: 

1225 """Fix areas in an image where this is no data, most likely to be 

1226 the cores of bright stars. 

1227 

1228 Because we want to have consistent fixes accross bands, this method 

1229 relies on supplying all bands and fixing pixels that are marked 

1230 as having a defect in any band even if within one band there is 

1231 no issue. 

1232 

1233 Parameters 

1234 ---------- 

1235 inputs : `Mapping` of `str` to `ExposureF` 

1236 This mapping has keys of band as a `str` and the corresponding 

1237 ExposureF as a value. 

1238 

1239 Returns 

1240 ------- 

1241 results : `Struct` of `Mapping` of `str` to `ExposureF` 

1242 A `Struct` that has a mapping of band to `ExposureF`. The `Struct` 

1243 has an attribute named ``results``. 

1244 

1245 """ 

1246 # make the joint mask of all the channels 

1247 doJointMaskInit = True 

1248 for imageExposure in inputs.values(): 

1249 maskDict = imageExposure.mask.getMaskPlaneDict() 

1250 if doJointMaskInit: 

1251 jointMask = np.zeros(imageExposure.mask.array.shape, dtype=imageExposure.mask.array.dtype) 

1252 doJointMaskInit = False 

1253 jointMask |= imageExposure.mask.array 

1254 

1255 sat_bit = maskDict["SAT"] 

1256 no_data_bit = maskDict["NO_DATA"] 

1257 together = (jointMask & 2**sat_bit).astype(bool) | (jointMask & 2**no_data_bit).astype(bool) 

1258 

1259 # use the last imageExposure as it is likely close enough across all bands 

1260 bright_mask = imageExposure.image.array > self.config.brightnessThresh 

1261 

1262 # dilate the mask a bit, this helps get a bit fainter mask without starting 

1263 # to include pixels in an irregular shape, as only the star cores should be 

1264 # fixed. 

1265 both = together & bright_mask 

1266 struct = np.array(((0, 1, 0), (1, 1, 1), (0, 1, 0)), dtype=bool) 

1267 both = binary_dilation(both, struct, iterations=4).astype(bool) 

1268 

1269 # do the actual fixing of values 

1270 results = {} 

1271 for band, imageExposure in inputs.items(): 

1272 if np.sum(both) > 0: 

1273 inpainted = inpaint_biharmonic(imageExposure.image.array, both, split_into_regions=True) 

1274 imageExposure.image.array[both] = inpainted[both] 

1275 results[band] = imageExposure 

1276 return Struct(results=results) 

1277 

1278 def runQuantum( 

1279 self, 

1280 butlerQC: QuantumContext, 

1281 inputRefs: InputQuantizedConnection, 

1282 outputRefs: OutputQuantizedConnection, 

1283 ) -> None: 

1284 refs = inputRefs.inputCoadd 

1285 sortedImages: dict[str, Exposure] = {} 

1286 for ref in refs: 

1287 key: str = cast(str, ref.dataId["band"]) 

1288 image = butlerQC.get(ref) 

1289 sortedImages[key] = image 

1290 

1291 outputs = self.run(sortedImages).results 

1292 sortedOutputs = {} 

1293 for ref in outputRefs.outputCoadd: 

1294 sortedOutputs[ref.dataId["band"]] = ref 

1295 

1296 for band, data in outputs.items(): 

1297 butlerQC.put(data, sortedOutputs[band]) 

1298 

1299 

1300class PrettyMosaicConnections(PipelineTaskConnections, dimensions=("tract", "skymap")): 

1301 inputRGB = Input( 

1302 doc="Individual RGB images that are to go into the mosaic", 

1303 name="rgb_picture", 

1304 storageClass="ColorImage", 

1305 dimensions=("tract", "patch", "skymap"), 

1306 multiple=True, 

1307 deferLoad=True, 

1308 ) 

1309 

1310 skyMap = Input( 

1311 doc="The skymap which the data has been mapped onto", 

1312 storageClass="SkyMap", 

1313 name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME, 

1314 dimensions=("skymap",), 

1315 ) 

1316 

1317 inputRGBMask = Input( 

1318 doc="Individual RGB images that are to go into the mosaic", 

1319 name="rgb_picture_mask", 

1320 storageClass="Mask", 

1321 dimensions=("tract", "patch", "skymap"), 

1322 multiple=True, 

1323 deferLoad=True, 

1324 ) 

1325 

1326 outputRGBMosaic = Output( 

1327 doc="A RGB mosaic created from the input data stored as a 3d array", 

1328 name="rgb_mosaic", 

1329 storageClass="ColorImage", 

1330 dimensions=("tract", "skymap"), 

1331 ) 

1332 

1333 

1334class PrettyMosaicConfig(PipelineTaskConfig, pipelineConnections=PrettyMosaicConnections): 

1335 binFactor = Field[int](doc="The factor to bin by when producing the mosaic", default=1) 

1336 doDCID65Convert = Field[bool]( 

1337 "Force the output to be converted from display p3 to DCI-D65 colorspace.", default=False 

1338 ) 

1339 useLocalTemp = Field[bool](doc="Use the current directory when creating local temp files.", default=False) 

1340 

1341 

1342class PrettyMosaicTask(PipelineTask): 

1343 """Combines multiple RGB arrays into one mosaic.""" 

1344 

1345 _DefaultName = "prettyMosaic" 

1346 ConfigClass = PrettyMosaicConfig 

1347 

1348 config: ConfigClass 

1349 

1350 def run( 

1351 self, 

1352 inputRGB: Iterable[DeferredDatasetHandle], 

1353 skyMap: BaseSkyMap, 

1354 inputRGBMask: Iterable[DeferredDatasetHandle], 

1355 ) -> Struct: 

1356 r"""Assemble individual `numpy.ndarrays` into a mosaic. 

1357 

1358 Each input is a `~lsst.daf.butler.DeferredDatasetHandle` because 

1359 they're loaded in one at a time to be placed into the mosaic to save 

1360 memory. 

1361 

1362 Parameters 

1363 ---------- 

1364 inputRGB : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle` 

1365 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to RGB 

1366 `numpy.ndarrays`. 

1367 skyMap : `BaseSkyMap` 

1368 The skymap that defines the relative position of each of the input 

1369 images. 

1370 inputRGBMask : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle` 

1371 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to masks for 

1372 each of the corresponding images. 

1373 

1374 Returns 

1375 ------- 

1376 result : `Struct` 

1377 The `Struct` containing the combined mosaic. The `Struct` has 

1378 and attribute named ``outputRGBMosaic``. 

1379 """ 

1380 # create the bounding region 

1381 newBox = Box2I() 

1382 # store the bounds as they are retrieved from the skymap 

1383 boxes = [] 

1384 tractMaps = [] 

1385 for handle in inputRGB: 

1386 dataId = handle.dataId 

1387 tractInfo: TractInfo = skyMap[dataId["tract"]] 

1388 patchInfo: PatchInfo = tractInfo[dataId["patch"]] 

1389 bbox = patchInfo.getOuterBBox() 

1390 boxes.append(bbox) 

1391 newBox.include(bbox) 

1392 tractMaps.append(tractInfo) 

1393 # This will be overwritten in the loop, but that is ok, because 

1394 # it is the same for each patch. 

1395 patch_grow: int = patchInfo.getCellInnerDimensions().getX() 

1396 

1397 # fixup the boxes to be smaller if needed, and put the origin at zero, 

1398 # this must be done after constructing the complete outer box 

1399 modifiedBoxes = [] 

1400 origin = newBox.getBegin() 

1401 for iterBox in boxes: 

1402 localOrigin = iterBox.getBegin() - origin 

1403 localOrigin = Point2I( 

1404 x=int(np.floor(localOrigin.x / self.config.binFactor)), 

1405 y=int(np.floor(localOrigin.y / self.config.binFactor)), 

1406 ) 

1407 localExtent = Extent2I( 

1408 x=int(np.floor(iterBox.getWidth() / self.config.binFactor)), 

1409 y=int(np.floor(iterBox.getHeight() / self.config.binFactor)), 

1410 ) 

1411 tmpBox = Box2I(localOrigin, localExtent) 

1412 modifiedBoxes.append(tmpBox) 

1413 boxes = modifiedBoxes 

1414 

1415 # scale the container box 

1416 newBoxOrigin = Point2I(0, 0) 

1417 newBoxExtent = Extent2I( 

1418 x=int(np.floor(newBox.getWidth() / self.config.binFactor)), 

1419 y=int(np.floor(newBox.getHeight() / self.config.binFactor)), 

1420 ) 

1421 newBox = Box2I(newBoxOrigin, newBoxExtent) 

1422 

1423 # Allocate storage for the mosaic 

1424 self.imageHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None) 

1425 self.maskHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None) 

1426 consolidatedImage = None 

1427 consolidatedMask = None 

1428 

1429 # Setup color space conversion in case they are used. 

1430 d65 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DCI_P3) 

1431 dp3 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DISPLAY_P3) 

1432 d65.whitepoint = dp3.whitepoint 

1433 d65.whitepoint_name = dp3.whitepoint_name 

1434 

1435 # Actually assemble the mosaic 

1436 maskDict = {} 

1437 mosaic_maker = FeatheredMosaicCreator(patch_grow, self.config.binFactor) 

1438 for box, handle, handleMask, tractInfo in zip(boxes, inputRGB, inputRGBMask, tractMaps): 

1439 rgb = handle.get().array 

1440 # convert to the dci-d65 colorspace 

1441 if self.config.doDCID65Convert: 

1442 rgb = colour.RGB_to_RGB(np.clip(rgb, 0, 1), dp3, d65) 

1443 rgbMask = handleMask.get() 

1444 maskDict = rgbMask.getMaskPlaneDict() 

1445 # allocate the memory for the mosaic 

1446 if consolidatedImage is None: 

1447 consolidatedImage = np.memmap( 

1448 self.imageHandle.name, 

1449 mode="w+", 

1450 shape=(newBox.getHeight(), newBox.getWidth(), 3), 

1451 dtype=rgb.dtype, 

1452 ) 

1453 if consolidatedMask is None: 

1454 consolidatedMask = np.memmap( 

1455 self.maskHandle.name, 

1456 mode="w+", 

1457 shape=(newBox.getHeight(), newBox.getWidth()), 

1458 dtype=rgbMask.array.dtype, 

1459 ) 

1460 

1461 if self.config.binFactor > 1: 

1462 # opencv wants things in x, y dimensions 

1463 shape = tuple(box.getDimensions())[::-1] 

1464 rgb = cv2.resize( 

1465 rgb, 

1466 dst=None, 

1467 dsize=shape, 

1468 fx=shape[0] / self.config.binFactor, 

1469 fy=shape[1] / self.config.binFactor, 

1470 ) 

1471 mask_array = rgbMask.array[:: self.config.binFactor, :: self.config.binFactor] 

1472 rgbMask = Mask(*(mask_array.shape[::-1])) 

1473 mosaic_maker.add_to_image(consolidatedImage, rgb, newBox, box, reverse=False) 

1474 

1475 consolidatedMask[*box.slices] = np.bitwise_or(consolidatedMask[*box.slices], rgbMask.array) 

1476 

1477 for plugin in plugins.full(): 

1478 if consolidatedImage is not None and consolidatedMask is not None: 

1479 consolidatedImage = plugin(consolidatedImage, consolidatedMask, maskDict) 

1480 # If consolidated image still None, that means there was no work to do. 

1481 # Return an empty image instead of letting this task fail. 

1482 if consolidatedImage is None: 

1483 consolidatedImage = np.zeros((0, 0, 0), dtype=np.uint8) 

1484 

1485 return Struct(outputRGBMosaic=ColorImage(consolidatedImage)) 

1486 

1487 def runQuantum( 

1488 self, 

1489 butlerQC: QuantumContext, 

1490 inputRefs: InputQuantizedConnection, 

1491 outputRefs: OutputQuantizedConnection, 

1492 ) -> None: 

1493 inputs = butlerQC.get(inputRefs) 

1494 outputs = self.run(**inputs) 

1495 butlerQC.put(outputs, outputRefs) 

1496 if hasattr(self, "imageHandle"): 

1497 self.imageHandle.close() 

1498 if hasattr(self, "maskHandle"): 

1499 self.maskHandle.close() 

1500 

1501 def makeInputsFromArrays( 

1502 self, inputs: Iterable[tuple[Mapping[str, Any], NDArray]] 

1503 ) -> Iterable[DeferredDatasetHandle]: 

1504 r"""Make valid inputs for the run method from numpy arrays. 

1505 

1506 Parameters 

1507 ---------- 

1508 inputs : `Iterable` of `tuple` of `Mapping` and `numpy.ndarray` 

1509 An iterable where each element is a tuple with the first 

1510 element is a mapping that corresponds to an arrays dataId, 

1511 and the second is an `numpy.ndarray`. 

1512 

1513 Returns 

1514 ------- 

1515 sortedImages : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle` 

1516 An iterable of `~lsst.daf.butler.DeferredDatasetHandle`\ s 

1517 containing the input data. 

1518 """ 

1519 structuredInputs = [] 

1520 for dataId, array in inputs: 

1521 structuredInputs.append(InMemoryDatasetHandle(inMemoryDataset=array, **dataId)) 

1522 

1523 return structuredInputs