lsst.pipe.tasks gef5401d743+4408856ac0
Loading...
Searching...
No Matches
_task.py
Go to the documentation of this file.
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
24__all__ = (
25 "ChannelRGBConfig",
26 "PrettyPictureTask",
27 "PrettyPictureConnections",
28 "PrettyPictureConfig",
29 "PrettyMosaicTask",
30 "PrettyMosaicConnections",
31 "PrettyMosaicConfig",
32 "PrettyPictureBackgroundFixerConfig",
33 "PrettyPictureBackgroundFixerTask",
34 "PrettyPictureStarFixerConfig",
35 "PrettyPictureStarFixerTask",
36)
37
38import colour
39import copy
40import itertools
41from collections.abc import Iterable, Mapping
42from lsst.afw.image import ExposureF
43import numpy as np
44from typing import TYPE_CHECKING, cast, Any
45from lsst.skymap import BaseSkyMap
46
47from scipy.stats import halfnorm, mode
48from scipy.ndimage import binary_dilation
49from scipy.interpolate import RBFInterpolator
50from skimage.restoration import inpaint_biharmonic
51
52from lsst.daf.butler import Butler, DataCoordinate, DeferredDatasetHandle
53from lsst.daf.butler import DatasetRef
54from lsst.images import ColorImage, Projection, Box, TractFrame
55from lsst.pex.config import Field, Config, ConfigDictField, ListField, ChoiceField
56from lsst.pex.config.configurableActions import ConfigurableActionField
57from lsst.pipe.base import (
58 PipelineTask,
59 PipelineTaskConfig,
60 PipelineTaskConnections,
61 Struct,
62 InMemoryDatasetHandle,
63 NoWorkFound,
64 QuantaAdjuster,
65)
66from lsst.rubinoxide import rbf_interpolator
67import cv2
68
69from lsst.pipe.base.connectionTypes import Input, Output
70from lsst.geom import Box2I, Point2I, Extent2I
71from lsst.afw.image import Exposure, Mask
72from lsst.skymap import Index2D
73
74from ._plugins import plugins
75from ._colorMapper import lsstRGB
76from ._utils import FeatheredMosaicCreator
77from ._functors import (
78 BoundsRemapper,
79 ColorScaler,
80 LumCompressor,
81 ExposureBracketer,
82 GamutFixer,
83 LocalContrastEnhancer,
84)
85
86import logging
87import tempfile
88
89logger = logging.getLogger(__name__)
90
91if TYPE_CHECKING:
92 from numpy.typing import NDArray
93 from lsst.pipe.base import QuantumContext, InputQuantizedConnection, OutputQuantizedConnection
94 from lsst.skymap import TractInfo, PatchInfo
95
96
class PrettyPictureConnections(
    PipelineTaskConnections,
    dimensions={"tract", "patch", "skymap"},
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureTask`.

    For one tract/patch, all per-band coadds are read and a single RGB
    image, plus the fused mask of all inputs, is written.
    """

    # One coadd per band; ``multiple=True`` delivers every band of the patch
    # to a single quantum so they can be mixed together.
    inputCoadds = Input(
        doc="Per-band coadds of a patch that are combined into the RGB image.",
        name="pretty_coadd",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )

    # Used to look up the tract/patch geometry (bounding boxes and WCS).
    skyMap = Input(
        doc="The skymap which the data has been mapped onto",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    outputRGB = Output(
        doc="An RGB image created from the input data stored as a 3d array",
        name="rgb_picture",
        storageClass="ColorImage",
        dimensions=("tract", "patch", "skymap"),
    )

    outputRGBMask = Output(
        doc="A Mask corresponding to the fused masks of the input channels",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
    )
133
134
class ChannelRGBConfig(Config):
    """Weights with which one input channel contributes to R, G and B.

    Each field is the multiplier applied to the channel's image before it
    is added into the matching output color plane. A channel rendered as
    pure red would use ``r=1, g=0, b=0``; one rendered as cyan would use
    ``r=0, g=1, b=1``.
    """

    # Multiplier for the red output plane.
    r = Field[float](doc="The amount of red contained in this channel")
    # Multiplier for the green output plane.
    g = Field[float](doc="The amount of green contained in this channel")
    # Multiplier for the blue output plane.
    b = Field[float](doc="The amount of blue contained in this channel")
146
147
class PrettyPictureConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureConnections):
    """Configuration for `PrettyPictureTask`.

    Deprecated fields are migrated to their replacements when the config is
    frozen (see `_handle_deprecated`).
    """

    channelConfig = ConfigDictField(
        doc="A dictionary that maps band names to their rgb channel configurations",
        keytype=str,
        itemtype=ChannelRGBConfig,
        default={},
    )
    cieWhitePoint = ListField[float](
        doc="The white point of the input arrays in CIE xy chromaticity coordinates",
        maxLength=2,
        default=[0.28, 0.28],
    )
    arrayType = ChoiceField[str](
        doc="The data type of the output image array",
        default="uint8",
        allowed={
            "uint8": "Use 8 bit arrays, 255 max",
            "uint16": "Use 16 bit arrays, 65535 max",
            "half": "Use 16 bit float arrays, 1 max",
            "float": "Use 32 bit float arrays, 1 max",
        },
    )
    recenterNoise = Field[float](
        doc="Recenter the noise away from zero. Supplied value is in units of sigma",
        optional=True,
        default=None,
    )
    noiseSearchThreshold = Field[float](
        doc=(
            "Flux threshold below which most flux will be considered noise, used to estimate noise properties"
        ),
        default=2,
    )
    maxNoiseImbalance = Field[float](
        doc=(
            "When recentering noise, if the ratio of counts of positive pixels, to negative pixels passes "
            "this threshold, consider there to be extended low flux and only estimate noise below zero."
        ),
        default=1.5,
    )
    doPsfDeconvolve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.", default=False
    )
    # Misspelled predecessor of doPsfDeconvolve, kept for backward compatibility.
    doPSFDeconcovlve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.",
        default=False,
        deprecated="This field will be removed in v32. Use doPsfDeconvolve instead.",
        optional=True,
    )
    doRemapGamut = Field[bool](
        doc="Apply a color correction to unrepresentable colors; if False, clip them.", default=True
    )
    doExposureBrackets = Field[bool](
        doc="Apply exposure bracketing to aid in dynamic range compression", default=True
    )
    doLocalContrast = Field[bool](doc="Apply local contrast optimizations to luminance.", default=True)

    imageRemappingConfig = ConfigurableActionField[BoundsRemapper](
        doc="Action controlling normalization process"
    )
    luminanceConfig = ConfigurableActionField[LumCompressor](
        doc="Action controlling luminance scaling when making an RGB image"
    )
    localContrastConfig = ConfigurableActionField[LocalContrastEnhancer](
        doc="Action controlling the local contrast correction in RGB image production"
    )
    colorConfig = ConfigurableActionField[ColorScaler](
        doc="Action to control the color scaling process in RGB image production"
    )
    exposureBracketerConfig = ConfigurableActionField[ExposureBracketer](
        doc=(
            "Exposure scaling action used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
    )
    gamutMapperConfig = ConfigurableActionField[GamutFixer](
        doc="Action to fix pixels which lie outside the RGB color gamut"
    )

    exposureBrackets = ListField[float](
        doc=(
            "Exposure scaling factors used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
        optional=True,
        default=[1.25, 1, 0.75],
        deprecated=(
            "This field will stop working in v31 and be removed in v32, "
            "please set exposureBracketerConfig.exposureBrackets"
        ),
    )
    gamutMethod = ChoiceField[str](
        doc="If doRemapGamut is True this determines the method",
        default="inpaint",
        allowed={
            "mapping": "Use a mapping function",
            "inpaint": "Use surrounding pixels to determine likely value",
        },
        deprecated="This field will stop working in v31 and be removed in v32, please set gamutMapperConfig",
    )

    def setDefaults(self):
        # Default band -> color assignment: i is red, r is green, g is blue.
        self.channelConfig["i"] = ChannelRGBConfig(r=1, g=0, b=0)
        self.channelConfig["r"] = ChannelRGBConfig(r=0, g=1, b=0)
        self.channelConfig["g"] = ChannelRGBConfig(r=0, g=0, b=1)
        return super().setDefaults()

    def _handle_deprecated(self):
        """Copy explicitly-set deprecated fields to their replacements.

        A field counts as explicitly set when its config history has more
        than one entry (i.e. more than just the default assignment).

        Notes
        -----
        The following deprecated fields are migrated (source -> destination):

        - ``gamutMethod`` -> ``gamutMapperConfig.gamutMethod``
        - ``exposureBrackets`` -> ``exposureBracketerConfig.exposureBrackets``;
          setting it to `None` also sets ``doExposureBrackets`` to `False`.
        - ``localContrastConfig.doLocalContrast`` -> ``doLocalContrast``
        - ``doPSFDeconcovlve`` -> ``doPsfDeconvolve``
        """
        if len(self._history["gamutMethod"]) > 1:
            self.gamutMapperConfig.gamutMethod = self.gamutMethod

        if len(self._history["exposureBrackets"]) > 1:
            self.exposureBracketerConfig.exposureBrackets = self.exposureBrackets
            if self.exposureBrackets is None:
                self.doExposureBrackets = False

        if len(self.localContrastConfig._history["doLocalContrast"]) > 1:
            self.doLocalContrast = self.localContrastConfig.doLocalContrast

        # Migrate the misspelled doPSFDeconcovlve field.
        if len(self._history["doPSFDeconcovlve"]) > 1:
            self.doPsfDeconvolve = self.doPSFDeconcovlve

    def freeze(self):
        """Migrate deprecated fields (once) and then freeze the config."""
        # Migration must happen before freezing, since frozen configs
        # cannot be modified; skip it if already frozen.
        if self._frozen is not True:
            self._handle_deprecated()
        super().freeze()
291
292
293class PrettyPictureTask(PipelineTask):
294 """Turns inputs into an RGB image."""
295
296 _DefaultName = "prettyPicture"
297 ConfigClass = PrettyPictureConfig
298
299 config: ConfigClass
300
301 def _find_normal_stats(self, array):
302 """Calculate standard deviation from negative values using half-normal distribution.
303
304 Raises
305 ------
306 ValueError
307 Array dimension validation fails.
308
309 Parameters
310 ----------
311 array : `numpy.array`
312 Input array of numerical values.
313
314 Returns
315 -------
316 mean : `float`
317 The central moment of the distribution
318 sigma : `float`
319 Estimated standard deviation from negative values. Returns np.inf if:
320 - No negative values exist in the array
321 - Half-normal fitting fails
322 """
323 # Extract negative values efficiently
324 values_noise = array[array < self.config.noiseSearchThreshold]
325
326 # find the mode
327 center = mode(np.round(values_noise, 2)).mode
328
329 # extract the negative values
330 values_neg = array[array < center]
331
332 # Return infinity if no negative values found
333 if values_neg.size == 0:
334 return 0, np.inf
335
336 try:
337 # Fit half-normal distribution to absolute negative values
338 _, sigma = halfnorm.fit(np.abs(values_neg - center), floc=0)
339 mu = center
340 except (ValueError, RuntimeError):
341 # Handle fitting failures (e.g., constant data, optimization issues)
342 return 0, np.inf
343
344 # examine for excess positive flux, this means there is contaminating signal
345 new_cut = array[array < (mu + 3 * sigma)]
346 positivity_ratio = np.sum(new_cut > mu) / np.sum(new_cut < mu)
347
348 if positivity_ratio > self.config.maxNoiseImbalance:
349 # This means there is an excess flux, possibly diffuse source,
350 # only estimate around zero.
351 mu, sigma = halfnorm.fit(np.abs(values_noise[values_noise < 0]), floc=0)
352
353 return mu, sigma
354
355 def _match_sigmas_and_recenter(self, *arrays, factor=1):
356 """Scale array values to match minimum standard deviation across arrays
357 and recenter noise.
358
359 Adjusts values below each array's sigma by scaling and shifting them to
360 align with the minimum sigma value across all input arrays. This operates
361 in-place for efficiency.
362
363 Parameters
364 ----------
365 *arrays : any number of `numpy.array`
366 Variable number of input arrays to process.
367 factor : float, optional
368 Scaling factor for adjustments (default: 1).
369
370 """
371 # Calculate standard deviations for all arrays
372 sigmas = []
373 mus = []
374 for arr in arrays:
375 m, s = self._find_normal_stats(arr)
376 mus.append(m)
377 sigmas.append(s)
378 mus = np.array(mus)
379 sigmas = np.array(sigmas)
380
381 # If no sigmas could be determined, return the original
382 # arrays.
383 if not np.any(np.isfinite(sigmas)):
384 return
385
386 min_sig = np.min(sigmas)
387
388 for mu, sigma, array in zip(mus, sigmas, arrays):
389 # Identify values below the array's sigma threshold
390 lower_pos = (array - mu) < sigma
391
392 # Skip processing if sigma is invalid
393 if not np.isfinite(sigma):
394 continue
395
396 # Calculate scaling ratio relative to minimum sigma
397 sigma_ratio = min_sig / sigma
398
399 # Apply adjustment to qualifying values
400 array[lower_pos] = (array[lower_pos] - mu) * sigma_ratio + min_sig * factor
401
402 def run(
403 self,
404 images: Mapping[str, Exposure],
405 image_wcs: Projection[Any] | None = None,
406 image_box: Box | None = None,
407 ) -> Struct:
408 """Turns the input arguments in arguments into an RGB array.
409
410 Parameters
411 ----------
412 images : `Mapping` of `str` to `Exposure`
413 A mapping of input images and the band they correspond to.
414 image_wcs : `~lsst.images.Projection`, optional
415 A projection describing the sky coordinate of each pixel.
416 image_box : `~lsst.images.Box`, optional
417 A box that defines this image as part of a larger region.
418
419 Returns
420 -------
421 result : `Struct`
422 A struct with the corresponding RGB image, and mask used in
423 RGB image construction. The struct will have the attributes
424 outputRGB and outputRGBMask. Each of the outputs will
425 be a `~lsst.images.ColorImage` object.
426
427 Notes
428 -----
429 Construction of input images are made easier by use of the
430 makeInputsFrom* methods.
431 """
432 channels = {}
433 shape = (0, 0)
434 jointMask: None | NDArray = None
435 maskDict: Mapping[str, int] = {}
436 doJointMaskInit = False
437 if jointMask is None:
438 doJointMask = True
439 doJointMaskInit = True
440 for channel, imageExposure in images.items():
441 imageArray = imageExposure.image.array
442 # run all the plugins designed for array based interaction
443 for plug in plugins.channel():
444 imageArray = plug(
445 imageArray, imageExposure.mask.array, imageExposure.mask.getMaskPlaneDict(), self.config
446 ).astype(np.float32)
447 channels[channel] = imageArray
448 # These operations are trivial look-ups and don't matter if they
449 # happen in each loop.
450 shape = imageArray.shape
451 maskDict = imageExposure.mask.getMaskPlaneDict()
452 if doJointMaskInit:
453 jointMask = np.zeros(shape, dtype=imageExposure.mask.dtype)
454 doJointMaskInit = False
455 if doJointMask:
456 jointMask |= imageExposure.mask.array
457
458 # mix the images to RGB
459 imageRArray = np.zeros(shape, dtype=np.float32)
460 imageGArray = np.zeros(shape, dtype=np.float32)
461 imageBArray = np.zeros(shape, dtype=np.float32)
462
463 for band, image in channels.items():
464 if band not in self.config.channelConfig:
465 logger.info(f"{band} image found but not requested in RGB image, skipping")
466 continue
467 mix = self.config.channelConfig[band]
468 if mix.r:
469 imageRArray += mix.r * image
470 if mix.g:
471 imageGArray += mix.g * image
472 if mix.b:
473 imageBArray += mix.b * image
474
475 exposure = next(iter(images.values()))
476 box: Box2I = exposure.getBBox()
477 boxCenter = box.getCenter()
478 try:
479 psf = exposure.psf.computeImage(boxCenter).array
480 except Exception:
481 psf = None
482
483 if self.config.recenterNoise:
484 self._match_sigmas_and_recenter(
485 imageRArray, imageGArray, imageBArray, factor=self.config.recenterNoise
486 )
487
488 # assert for typing reasons
489 assert jointMask is not None
490 # Run any image level correction plugins
491 colorImage = np.zeros((*imageRArray.shape, 3))
492 colorImage[:, :, 0] = imageRArray
493 colorImage[:, :, 1] = imageGArray
494 colorImage[:, :, 2] = imageBArray
495 for plug in plugins.partial():
496 colorImage = plug(colorImage, jointMask, maskDict, self.config)
497
498 # Filter the local contrast parameters for diffusion that are None
499 # This is so we only apply key word overrides that are specifically set.
500 local_contrast_config = self.config.localContrastConfig.toDict()
501 to_remove = []
502 for k, v in local_contrast_config["diffusionFunction"].items():
503 if v is None:
504 to_remove.append(k)
505 for item in to_remove:
506 local_contrast_config["diffusionControl"].pop(item)
507
508 colorImage = lsstRGB(
509 colorImage[:, :, 0],
510 colorImage[:, :, 1],
511 colorImage[:, :, 2],
512 local_contrast=self.config.localContrastConfig if self.config.doLocalContrast else None,
513 scale_lum=self.config.luminanceConfig,
514 scale_color=self.config.colorConfig,
515 remap_bounds=self.config.imageRemappingConfig,
516 bracketing_function=(
517 self.config.exposureBracketerConfig if self.config.doExposureBrackets else None
518 ),
519 gamut_remapping_function=self.config.gamutMapperConfig if self.config.doRemapGamut else None,
520 cieWhitePoint=tuple(self.config.cieWhitePoint), # type: ignore
521 psf=psf if self.config.doPsfDeconvolve else None,
522 )
523
524 # Find the dataset type and thus the maximum values as well
525 maxVal: int | float
526 match self.config.arrayType:
527 case "uint8":
528 dtype = np.uint8
529 maxVal = 255
530 case "uint16":
531 dtype = np.uint16
532 maxVal = 65535
533 case "half":
534 dtype = np.half
535 maxVal = 1.0
536 case "float":
537 dtype = np.float32
538 maxVal = 1.0
539 case _:
540 assert True, "This code path should be unreachable"
541
542 # lsstRGB returns an image in 0-1 scale it to the maximum value
543 colorImage *= maxVal # type: ignore
544
545 # pack the joint mask back into a mask object
546 lsstMask = Mask(width=jointMask.shape[1], height=jointMask.shape[0], planeDefs=maskDict)
547 lsstMask.array = jointMask # type: ignore
548 return Struct(
549 outputRGB=ColorImage(colorImage.astype(dtype), bbox=image_box, projection=image_wcs),
550 outputRGBMask=lsstMask,
551 ) # type: ignore
552
554 self,
555 butlerQC: QuantumContext,
556 inputRefs: InputQuantizedConnection,
557 outputRefs: OutputQuantizedConnection,
558 ) -> None:
559 imageRefs: list[DatasetRef] = inputRefs.inputCoadds
560 sortedImages = self.makeInputsFromRefs(imageRefs, butlerQC)
561 if not sortedImages:
562 requested = ", ".join(self.config.channelConfig.keys())
563 raise NoWorkFound(f"No input images of band(s) {requested}")
564
565 # get the patch tract bounding box and wcs
566 skymap = butlerQC.get(inputRefs.skyMap)
567 quantumDataId = butlerQC.quantum.dataId
568 tractInfo = skymap[quantumDataId["tract"]]
569 patchInfo = tractInfo[quantumDataId["patch"]]
570 outputs = self.run(
571 images=sortedImages,
572 image_wcs=Projection.from_legacy(
573 patchInfo.wcs,
574 TractFrame(
575 skymap=quantumDataId["skymap"],
576 tract=quantumDataId["tract"],
577 bbox=Box.from_legacy(tractInfo.bbox),
578 ),
579 ),
580 image_box=Box.from_legacy(patchInfo.getOuterBBox()),
581 )
582 butlerQC.put(outputs, outputRefs)
583
585 self, refs: Iterable[DatasetRef], butler: Butler | QuantumContext
586 ) -> dict[str, Exposure]:
587 r"""Make valid inputs for the run method from butler references.
588
589 Parameters
590 ----------
591 refs : `Iterable` of `DatasetRef`
592 Some `Iterable` container of `Butler` `DatasetRef`\ s
593 butler : `Butler` or `QuantumContext`
594 This is the object that fetches the input data.
595
596 Returns
597 -------
598 sortedImages : `dict` of `str` to `Exposure`
599 A dictionary of `Exposure`\ s keyed by the band they
600 correspond to.
601 """
602 sortedImages: dict[str, Exposure] = {}
603 for ref in refs:
604 key: str = cast(str, ref.dataId["band"])
605 image = butler.get(ref)
606 sortedImages[key] = image
607 return sortedImages
608
609 def makeInputsFromArrays(self, **kwargs) -> dict[str, DeferredDatasetHandle]:
610 r"""Make valid inputs for the run method from numpy arrays.
611
612 Parameters
613 ----------
614 kwargs : `numpy.ndarray`
615 This is standard python kwargs where the left side of the equals
616 is the data band, and the right side is the corresponding `numpy.ndarray`
617 array.
618
619 Returns
620 -------
621 sortedImages : `dict` of `str` to \
622 `~lsst.daf.butler.DeferredDatasetHandle`
623 A dictionary of `~lsst.daf.butlger.DeferredDatasetHandle`\ s keyed
624 by the band they correspond to.
625 """
626 # ignore type because there aren't proper stubs for afw
627 temp = {}
628 for key, array in kwargs.items():
629 temp[key] = Exposure(Box2I(Point2I(0, 0), Extent2I(*array.shape)), dtype=array.dtype)
630 temp[key].image.array[:] = array
631
632 return self.makeInputsFromExposures(**temp)
633
634 def makeInputsFromExposures(self, **kwargs) -> dict[int, DeferredDatasetHandle]:
635 r"""Make valid inputs for the run method from `Exposure` objects.
636
637 Parameters
638 ----------
639 kwargs : `Exposure`
640 This is standard python kwargs where the left side of the equals
641 is the data band, and the right side is the corresponding
642 `Exposure`.
643
644 Returns
645 -------
646 sortedImages : `dict` of `int` to \
647 `~lsst.daf.butler.DeferredDatasetHandle`
648 A dictionary of `~lsst.daf.butler.DeferredDatasetHandle`\ s keyed
649 by the band they correspond to.
650 """
651 sortedImages = {}
652 for key, value in kwargs.items():
653 sortedImages[key] = value
654 return sortedImages
655
656
class PrettyPictureBackgroundFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap", "band"),
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureBackgroundFixerTask`.

    Each quantum is one band of one patch. Besides its own coadd, each
    quantum also receives the coadds of the (up to eight) adjacent patches
    of the same band and tract that are inputs elsewhere in the graph; see
    `adjust_all_quanta`.
    """

    inputCoadd = Input(
        doc=("Input coadd for which the background is to be removed"),
        name="{coaddTypeName}CoaddPsfMatched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
    outputCoadd = Output(
        doc="The coadd with the background fixed and subtracted",
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )

    def adjust_all_quanta(self, adjuster: QuantaAdjuster) -> None:
        """Add the coadds of neighboring patches as inputs to each quantum.

        Parameters
        ----------
        adjuster : `lsst.pipe.base.QuantaAdjuster`
            Helper used to inspect and modify this task's quanta.
        """
        # At this stage of a QG build, quanta without inputs may not have
        # been pruned yet, so start from the set of quantum *input* data IDs
        # rather than quantum data IDs; quanta without inputs are ignored and
        # will be pruned later.
        flat_data_ids: set[DataCoordinate] = set()
        for quantum_data_id in adjuster.iter_data_ids():
            flat_data_ids.update(adjuster.get_inputs(quantum_data_id)["inputCoadd"])

        # Reorganize the input data IDs into
        #
        #   {skymap: {(band, tract): [patches]}}
        #
        # More than one skymap in a single QG build is unlikely, but not
        # impossible.
        hierarchical_data_ids: dict[str, dict[tuple[str, int], list[int]]] = {}
        for input_data_id in flat_data_ids:
            hierarchical_data_ids.setdefault(input_data_id["skymap"], {}).setdefault(
                (input_data_id["band"], input_data_id["tract"]), []
            ).append(input_data_id["patch"])

        for skymap_name, data_ids_for_skymap in hierarchical_data_ids.items():
            # The skymap is needed to turn sequential patch IDs into grid
            # (x, y) pairs that can be offset to find neighbors.
            skyMap = adjuster.butler.get(BaseSkyMap.SKYMAP_DATASET_TYPE_NAME, skymap=skymap_name)
            for (band, tract_id), patches in data_ids_for_skymap.items():
                tract_info = skyMap[tract_id]
                num_x, num_y = tract_info.num_patches.x, tract_info.num_patches.y
                base_data_id = DataCoordinate.standardize(
                    skymap=skymap_name,
                    tract=tract_id,
                    band=band,
                    universe=adjuster.butler.dimensions,
                )
                for patch_id in patches:
                    patch_index: Index2D = tract_info.getPatchIndexPair(patch_id)
                    quantum_data_id = DataCoordinate.standardize(base_data_id, patch=patch_id)
                    # Visit all adjacent patches, including corner neighbors.
                    for offset_x, offset_y in itertools.product((-1, 0, 1), (-1, 0, 1)):
                        if not (offset_x or offset_y):
                            # The matching input already has an edge to
                            # this quantum.
                            continue

                        proposed_patch_x = patch_index.x + offset_x
                        proposed_patch_y = patch_index.y + offset_y
                        # Valid grid indices run from 0 to num - 1; an index
                        # equal to num_x or num_y lies outside the tract and
                        # must not be converted to a sequential patch ID.
                        if not (0 <= proposed_patch_x < num_x and 0 <= proposed_patch_y < num_y):
                            continue
                        neighbor_patch_id = tract_info.getSequentialPatchIndexFromPair(
                            Index2D(x=proposed_patch_x, y=proposed_patch_y)
                        )
                        neighbor_data_id = DataCoordinate.standardize(base_data_id, patch=neighbor_patch_id)
                        if neighbor_data_id not in flat_data_ids:
                            # Skip neighbors that have no input data.
                            continue
                        # Add an edge from the neighboring input to this
                        # quantum.
                        adjuster.add_input(quantum_data_id, "inputCoadd", neighbor_data_id)
739
740
class PrettyPictureBackgroundFixerConfig(
    PipelineTaskConfig, pipelineConnections=PrettyPictureBackgroundFixerConnections
):
    """Configuration for `PrettyPictureBackgroundFixerTask`."""

    use_detection_mask = Field[bool](
        doc="Use the detection mask to determine background instead of empirically finding it in this task",
        default=False,
    )
    num_background_bins = Field[int](
        doc="The number of bins along each axis when determining background", default=5
    )
    min_bin_fraction = Field[float](
        doc="Bins with fewer pixels than this fraction of the total will be ignored", default=0.1
    )

    pos_sigma_multiplier = Field[float](
        doc="How many sigma to consider as background in the positive direction", default=2
    )
    extra_pixel_rad = Field[int](
        doc=(
            "If there are neighboring input images consider control points that are radius of input image"
            ", (x^2 + y^2)^(1/2), plus this many pixels away from the center of the input image as control "
            "points to use."
        ),
        default=2000,
    )
    max_flux_imbalance = Field[float](
        doc="When determining background, if the ratio of counts of positive pixels, to negative pixels"
        " passes this threshold, consider there to be extended low flux and only estimate noise below zero.",
        default=1.7,
    )
    max_search_flux = Field[float](
        doc="Pixels above this value should never be considered background", default=3
    )
774
775
776class PrettyPictureBackgroundFixerTask(PipelineTask):
777 """Empirically flatten an images background.
778
779 Many astrophysical images have backgrounds with imperfections in them.
780 This Task attempts to determine control points which are considered
781 background values, and fits a radial basis function model to those
782 points. This model is then subtracted off the image.
783
784 """
785
786 _DefaultName = "prettyPictureBackgroundFixer"
787 ConfigClass = PrettyPictureBackgroundFixerConfig
788
789 config: ConfigClass
790
791 def _tile_slices(self, arr, R, C):
792 """Generate slices for tiling an array.
793
794 This function divides an array into a grid of tiles and returns a list of
795 slice objects representing each tile. It handles cases where the array
796 dimensions are not evenly divisible by the number of tiles in each
797 dimension, distributing the remainder among the tiles.
798
799 Parameters
800 ----------
801 arr : `numyp.ndarray`
802 The input array to be tiled. Used only to determine the array's shape.
803 R : `int`
804 The number of tiles in the row dimension.
805 C : `int`
806 The number of tiles in the column dimension.
807
808 Returns
809 -------
810 slices : `list` of `tuple`
811 A list of tuples, where each tuple contains two `slice` objects
812 representing the row and column slices for a single tile.
813 """
814 M = arr.shape[0]
815 N = arr.shape[1]
816
817 # Function to compute slices for a given dimension size and number of divisions
818 def get_slices(total_size: int, num_divisions: int) -> list[tuple[int, int]]:
819 """Generate slice ranges for dividing a size into equal parts.
820
821 Parameters
822 ----------
823 total_size : `int`
824 Total size to be divided into slices.
825 num_divisions : `int`
826 Number of divisions to create.
827
828 Returns
829 -------
830 `list` of `tuple` of `int`
831 List of (start, end) tuples representing each slice.
832
833 Notes
834 -----
835 This function divides the total_size into num_divisions equal parts.
836 If the division is not exact, the remainder is distributed by adding
837 1 to the first 'remainder' slices, ensuring balanced distribution.
838 """
839 base = total_size // num_divisions
840 remainder = total_size % num_divisions
841 slices = []
842 start = 0
843 for i in range(num_divisions):
844 end = start + base
845 if i < remainder:
846 end += 1
847 slices.append((start, end))
848 start = end
849 return slices
850
851 # Get row and column slices
852 row_slices = get_slices(M, R)
853 col_slices = get_slices(N, C)
854
855 # Generate all possible tile combinations of row and column slices
856 tiles = []
857 for rs in row_slices:
858 r_start, r_end = rs
859 for cs in col_slices:
860 c_start, c_end = cs
861 tile_slice = (slice(r_start, r_end), slice(c_start, c_end))
862 tiles.append(tile_slice)
863
864 return tiles
865
866 @staticmethod
867 def _findImageStatistics(image, pos_sigma_mult=1, max_flux_imbalance=1.7, max_search_flux=3):
868 """Find pixels that are likely to be background based on image statistics.
869
870 This method estimates background pixels by analyzing the distribution of
871 pixel values in the image. It uses the median as an estimate of the background
872 level and fits a half-normal distribution to values below the median to
873 determine the background sigma. Pixels below a threshold (mean + sigma) are
874 classified as background.
875
876 Parameters
877 ----------
878 image : `numpy.ndarray`
879 Input image array for which to find background pixels.
880 pos_sigma_mult : `float`
881 How many sigma to consider as background in the positive direction
882 max_flux_imbalance : `float`
883 Limit on the ratio of the area below the determined average to that
884 above. If the ratio is in excess of this value, then there is
885 assumed to be diffuse background flux in the image, and zero is
886 assumed to be the background level. Some excess flux is expected as
887 there is a real distribution of faint flux.
888 max_search_flux : `float`
889 Pixels above this value should never be considered background
890 """
891 # Find the median value in the image, which is likely to be
892 # close to average background. Note this doesn't work well
893 # in fields with high density or diffuse flux.
894 maxLikely = np.median(image[image < max_search_flux], axis=None)
895
896 # find all the pixels that are fainter than this
897 # and find the std. This is just used as an initialization
898 # parameter and doesn't need to be accurate.
899 mask = image < maxLikely
900 initial_std = (image[mask] - maxLikely).std()
901
902 # new estimate
903 sub_image = image[image < max(maxLikely + 3 * initial_std, max_search_flux)]
904
905 center = mode(np.round(sub_image, 2)).mode
906
907 # Don't do anything if there are no pixels to check
908 if np.any(mask):
909 # use a minimizer to determine best sigma for a Gaussian
910 # given only samples below the mean of the Gaussian. #
911 _, sigma_hat = halfnorm.fit(np.abs(image[image < center] - center), floc=0)
912 mu_hat = center
913 else:
914 mu_hat, sigma_hat = (maxLikely, 3 * initial_std)
915
916 new_cut = image[image < (mu_hat + 3 * sigma_hat)]
917 positivity_ratio = np.sum(new_cut > mu_hat) / np.sum(new_cut < mu_hat)
918
919 if positivity_ratio > max_flux_imbalance:
920 # This means there is an excess flux, possibly diffuse source,
921 # assume all flux is diffuse
922 return np.inf, -np.inf
923
924 # create a new masking threshold that's the determined
925 # mean plus std from the fit
926 threshold_pos = min(mu_hat + pos_sigma_mult * sigma_hat, max_search_flux)
927
928 # The reason a lower threshold is needed is that occasionally for some
929 # reason we have images with a few random pixels that are REALLY negative.
930 # We generate control points for background fixing from bins of the pixels we
931 # consider background. When there are a few pixels that are so negative, it brings
932 # the estimate way way down, and the control point then drags the background
933 # unrealistically high. By excluding them we get a more realistic estimate, and
934 # the only consequence is that these pixels that were excluded will still end up
935 # negative, and be clipped to black. In principle we could add a lower clipped
936 # boundary too, but it would not change the results much at all as the estimation
937 # of the sigma at the low side is robust and we do not expect contaminating flux
938 # from the low side (by construction of the algorithm).
939 threshold_neg = mu_hat - pos_sigma_mult * sigma_hat
940 return threshold_pos, threshold_neg
941
def findBackgroundPixels(self, image, threshold_pair=None):
    """Build a boolean mask of the pixels that are likely to be background.

    A pixel is considered background when its value lies strictly between a
    lower and an upper threshold. The thresholds are either supplied by the
    caller or estimated from ``image`` with `_findImageStatistics`. The
    estimate uses the ``pos_sigma_multiplier``, ``max_flux_imbalance`` and
    ``max_search_flux`` config values.

    Parameters
    ----------
    image : `numpy.ndarray`
        Input image array for which to find background pixels.
    threshold_pair : `tuple` [`float`, `float`] or `None`, optional
        ``(upper, lower)`` thresholds. The upper bound comes first, matching
        the return order of `_findImageStatistics`. If `None`, the thresholds
        are determined from the image statistics.

    Returns
    -------
    result : `numpy.ndarray` or `None`
        Boolean mask with the shape of ``image`` where `True` marks
        background pixels. Returns `None` if the upper threshold is
        infinite, which `_findImageStatistics` reports when it detects an
        excess of positive (diffuse) flux.

    Notes
    -----
    This method works best for images with a relatively uniform background.
    It may not perform well in fields with high source density or diffuse
    flux.
    """
    if threshold_pair is None:
        upper, lower = self._findImageStatistics(
            image,
            self.config.pos_sigma_multiplier,
            self.config.max_flux_imbalance,
            self.config.max_search_flux,
        )
    else:
        upper, lower = threshold_pair

    # No usable background range was found.
    if np.isinf(upper):
        return None

    return (image < upper) & (image > lower)
985
def _findControlPoints(
    self,
    exposure: Exposure,
    origin: Point2I,
    use_detection_mask: bool = False,
    threshhold_pair: tuple[float, float] | None = None,
):
    """Measure background control points on a regular grid of tiles.

    The image is split into tiles with `_tile_slices`. For every tile in
    which enough pixels are classified as background, the mean of those
    pixels becomes a control value located at the tile centre.

    Parameters
    ----------
    exposure : `Exposure`
        Exposure to sample.
    origin : `Point2I`
        Pixel origin of the frame in which the returned positions are
        expressed. Positions are shifted by ``exposure``'s origin minus
        ``origin``, so that points from neighboring exposures share one
        frame.
    use_detection_mask : `bool`, optional
        If `True`, background pixels are those without the ``DETECTED``
        mask bit. Otherwise they are found with `findBackgroundPixels`.
    threshhold_pair : `tuple` [`float`, `float`] or `None`, optional
        ``(upper, lower)`` thresholds forwarded to `findBackgroundPixels`.
        Ignored when ``use_detection_mask`` is `True`.

    Returns
    -------
    values : `list` [`float`]
        Mean background value of each accepted tile.
    yloc : `list` [`float`]
        Row coordinate of each accepted tile centre in the ``origin`` frame.
    xloc : `list` [`float`]
        Column coordinate of each accepted tile centre in the ``origin``
        frame.
    """
    image_array = exposure.image.array
    if use_detection_mask:
        detected_bit = 2 ** exposure.mask.getMaskPlaneDict()["DETECTED"]
        # Compare with zero to obtain a boolean mask. A bitwise ``~`` on the
        # integer mask would give non-zero integers, which numpy would treat
        # as fancy indices rather than as a selection.
        image_mask = (exposure.mask.array & detected_bit) == 0
    else:
        image_mask = self.findBackgroundPixels(image_array, threshold_pair=threshhold_pair)

    values: list[float] = []
    yloc: list[float] = []
    xloc: list[float] = []

    if image_mask is None:
        logger.debug("returning early from _findControlPoints")
        return values, yloc, xloc

    tiles = self._tile_slices(
        image_array, self.config.num_background_bins, self.config.num_background_bins
    )

    # Express positions relative to ``origin`` rather than this exposure.
    this_origin: Point2I = exposure.getBBox().getBegin()
    offset: Extent2I = this_origin - origin
    x_offset = offset.getX()
    y_offset = offset.getY()

    # Each tile is a (row slice, column slice) pair.
    for yslice, xslice in tiles:
        n_rows = yslice.stop - yslice.start
        n_cols = xslice.stop - xslice.start
        window = image_array[yslice, xslice][image_mask[yslice, xslice]]
        # Require at least ``min_bin_fraction`` of the tile's area to be
        # background before trusting its mean.
        min_fill = int(n_rows * n_cols * self.config.min_bin_fraction)
        if window.size <= min_fill:
            continue
        values.append(np.mean(window))
        yloc.append(n_rows / 2 + yslice.start + y_offset)
        xloc.append(n_cols / 2 + xslice.start + x_offset)

    return values, yloc, xloc
1034
def fixBackground(self, image, yloc: list[float], xloc: list[float], values: list[float]):
    """Model the background by interpolating between control points.

    A radial basis function interpolant is fit through the control values.
    It uses a thin-plate-spline kernel, a degree-4 polynomial term and a
    small amount of smoothing. The interpolant is then evaluated over a grid
    with the shape of ``image``. This method only models the background;
    subtracting it is left to the caller.

    Parameters
    ----------
    image : `numpy.ndarray`
        Image whose shape defines the evaluation grid. Its pixel values are
        not used.
    yloc : `list` of `float`
        Row coordinates of the control points.
    xloc : `list` of `float`
        Column coordinates of the control points.
    values : `list` of `float`
        Background value measured at each control point.

    Returns
    -------
    backgrounds : `numpy.ndarray`
        The background model evaluated on the grid. This is presumably an
        array shaped like ``image``; TODO confirm against
        ``rbf_interpolator.fast_rbf_interpolation_on_grid``.

    Raises
    ------
    ValueError
        Raised by `scipy.interpolate.RBFInterpolator` when there are fewer
        control points than the 15 monomials of a 2-D degree-4 polynomial.
    """
    # RBFInterpolator expects an (n_points, n_dims) array of coordinates.
    control_points = np.column_stack((yloc, xloc))
    interpolant = RBFInterpolator(
        control_points,
        values,
        kernel="thin_plate_spline",
        degree=4,
        smoothing=0.05,
        neighbors=None,
    )

    backgrounds = rbf_interpolator.fast_rbf_interpolation_on_grid(interpolant, image.shape)
    return backgrounds
1073
def run(self, inputCoadd: Exposure, neighbors: list[Exposure] | None = None):
    """Estimate a background for an input Exposure and remove it.

    Control points (mean background values on a grid of tiles) are measured
    on ``inputCoadd`` and on each neighbor, if any. A neighbor's points are
    kept only if they lie within the input's half-diagonal plus
    ``config.extra_pixel_rad`` of the input's centre. A smooth background is
    interpolated through all kept points and subtracted from a copy of the
    input.

    Unless ``config.use_detection_mask`` is set, the background thresholds
    are estimated jointly from the pixels of the input and all neighbors.
    This way every exposure is sampled consistently.

    Parameters
    ----------
    inputCoadd : `Exposure`
        The exposure the background will be removed from. It is not
        modified.
    neighbors : `list` of `Exposure`, optional
        Neighboring `Exposure` objects that can be used to constrain
        backgrounds across boundaries.

    Returns
    -------
    result : `Struct`
        A `Struct` with one attribute, ``outputCoadd``. It holds a deep copy
        of ``inputCoadd`` with the background subtracted. If too few control
        points could be found, it holds an unmodified copy instead.
    """
    cfg = self.config
    origin = inputCoadd.getBBox().getBegin()
    input_y_dim, input_x_dim = inputCoadd.image.array.shape
    center_y = input_y_dim / 2
    center_x = input_x_dim / 2
    # Neighbor control points are accepted out to the half-diagonal of the
    # input plus a configurable margin.
    inside_rad = np.hypot(center_y, center_x) + cfg.extra_pixel_rad

    joint_thresh = None
    if not cfg.use_detection_mask:
        # Pool the candidate pixels of every exposure, so that all exposures
        # are sampled with the same thresholds. A single concatenate avoids
        # the repeated copying of np.append in a loop.
        exposures = [inputCoadd, *(neighbors or ())]
        background_pixels = np.concatenate(
            [exp.image.array[exp.image.array < cfg.max_search_flux] for exp in exposures]
        )
        thresholds = self._findImageStatistics(
            background_pixels,
            cfg.pos_sigma_multiplier,
            cfg.max_flux_imbalance,
            cfg.max_search_flux,
        )
        # An infinite upper threshold means the pooled pixels look dominated
        # by diffuse flux. Leave ``joint_thresh`` as None so that
        # ``_findControlPoints`` estimates thresholds per exposure instead.
        if not np.isinf(thresholds[0]):
            joint_thresh = thresholds

    values, yloc, xloc = self._findControlPoints(
        inputCoadd, origin, cfg.use_detection_mask, joint_thresh
    )

    if not values:
        output = ExposureF(inputCoadd, deep=True)
        logger.warning(
            "No control points could be determined, likely due to extended flux, leaving background"
            " unmodified."
        )
        return Struct(outputCoadd=output)

    for n_exp in neighbors or ():
        n_values, n_yloc, n_xloc = self._findControlPoints(
            n_exp, origin, cfg.use_detection_mask, joint_thresh
        )
        for value, y_pos, x_pos in zip(n_values, n_yloc, n_xloc):
            if np.hypot(y_pos - center_y, x_pos - center_x) < inside_rad:
                values.append(value)
                yloc.append(y_pos)
                xloc.append(x_pos)

    # A thin-plate spline with a degree-4 polynomial term needs at least 15
    # points (the number of 2-D monomials of degree <= 4).
    if len(yloc) < 15:
        logger.warning(
            "Not enough control points could be determined, likely due to extended flux, leaving"
            " background unmodified."
        )
        # Return a copy, like the other paths, so callers never receive an
        # alias of the input exposure.
        return Struct(outputCoadd=ExposureF(inputCoadd, deep=True))

    background = self.fixBackground(inputCoadd.image.array, yloc, xloc, values)

    # Subtract from a copy so the input stays untouched.
    output = ExposureF(inputCoadd, deep=True)
    output.image.array -= background
    return Struct(outputCoadd=output)
1158
def runQuantum(
    self,
    butlerQC: QuantumContext,
    inputRefs: InputQuantizedConnection,
    outputRefs: OutputQuantizedConnection,
) -> None:
    """Load this quantum's coadd and its neighbors, then run the task.

    The ``inputCoadd`` connection carries coadds for several patches. The
    one whose ``patch`` data ID value matches the quantum is the coadd to
    correct. All the others are handed to `run` as neighbors.

    Raises
    ------
    RuntimeError
        Raised if no input reference belongs to the quantum's own patch.
    """
    target_patch = butlerQC.quantum.dataId["patch"]
    primary_ref = None
    neighbor_refs = []
    for candidate in inputRefs.inputCoadd:
        is_primary = target_patch == candidate.dataId["patch"]
        if not is_primary:
            neighbor_refs.append(candidate)
            continue
        primary_ref = candidate

    if primary_ref is None:
        # This should be unreachable.
        raise RuntimeError(
            "There is a major problem, the input ref associated with this quantum can't be found."
        )

    inputCoadd = butlerQC.get(primary_ref)
    loaded_neighbors = [butlerQC.get(n_ref) for n_ref in neighbor_refs]
    results = self.run(inputCoadd, neighbors=loaded_neighbors or None)
    butlerQC.put(results, outputRefs)
1184
1185
class PrettyPictureStarFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
):
    """Connections for `PrettyPictureStarFixerTask`.

    The coadds of every band of one patch are read and written together, so
    that the same pixels can be fixed in all bands.
    """

    inputCoadd = Input(
        doc="Background-subtracted input coadds, one per band, in which bright star cores are to be fixed",
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
    outputCoadd = Output(
        doc="The coadds with regions of missing or saturated data in bright star cores filled in",
        name="pretty_picture_coadd_fixed_stars",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
1204
1205
class PrettyPictureStarFixerConfig(
    PipelineTaskConfig,
    pipelineConnections=PrettyPictureStarFixerConnections,
):
    """Configuration for `PrettyPictureStarFixerTask`."""

    # No default is given. Presumably this must be set explicitly before the
    # task runs; confirm against pex_config validation of required fields.
    brightnessThresh = Field[float](
        doc="The flux value below which pixels with SAT or NO_DATA bits will be ignored",
    )
1210
1211
class PrettyPictureStarFixerTask(PipelineTask):
    """This class fixes up regions in an image where there is no, or bad data.

    The fixes done by this task are overwhelmingly comprised of the cores of
    bright stars for which there is no data.
    """

    _DefaultName = "prettyPictureStarFixer"
    ConfigClass = PrettyPictureStarFixerConfig

    config: ConfigClass

    def run(self, inputs: Mapping[str, ExposureF]) -> Struct:
        """Fix areas in an image where there is no data, most likely to be
        the cores of bright stars.

        Because we want consistent fixes across bands, this method relies on
        receiving all bands. It fixes pixels marked as having a defect in any
        band, even if they show no issue within a particular band.

        A pixel is fixed when it has the ``SAT`` or ``NO_DATA`` bit in the
        union of all band masks and is brighter than
        ``config.brightnessThresh``. That region is grown by four dilations
        with a cross-shaped structuring element. It is then filled by
        biharmonic inpainting in every band.

        Parameters
        ----------
        inputs : `Mapping` of `str` to `ExposureF`
            This mapping has keys of band as a `str` and the corresponding
            ExposureF as a value. The exposures' image arrays are modified in
            place.

        Returns
        -------
        results : `Struct` of `Mapping` of `str` to `ExposureF`
            A `Struct` with an attribute named ``results``. It maps each band
            to its (same, in-place fixed) `ExposureF`. The mapping is empty
            when ``inputs`` is empty.
        """
        if not inputs:
            return Struct(results={})

        exposures = list(inputs.values())

        # Union of the masks of all bands, so every band is fixed in the
        # same places.
        jointMask = np.zeros_like(exposures[0].mask.array)
        for imageExposure in exposures:
            jointMask |= imageExposure.mask.array

        # Mask plane indices are read from the last exposure; they are
        # presumably identical across bands.
        maskDict = exposures[-1].mask.getMaskPlaneDict()
        defect_bits = (2 ** maskDict["SAT"]) | (2 ** maskDict["NO_DATA"])
        together = (jointMask & defect_bits).astype(bool)

        # Use the last exposure's image; it is likely close enough across all
        # bands.
        bright_mask = exposures[-1].image.array > self.config.brightnessThresh

        # Dilate the mask a bit. This picks up slightly fainter pixels without
        # growing into irregular shapes, since only the star cores should be
        # fixed.
        both = together & bright_mask
        struct = np.array(((0, 1, 0), (1, 1, 1), (0, 1, 0)), dtype=bool)
        both = binary_dilation(both, struct, iterations=4).astype(bool)

        # Do the actual fixing of values. The mask is shared by all bands, so
        # whether there is anything to fix is decided once.
        needs_fix = bool(both.any())
        results = {}
        for band, imageExposure in inputs.items():
            if needs_fix:
                inpainted = inpaint_biharmonic(imageExposure.image.array, both, split_into_regions=True)
                imageExposure.image.array[both] = inpainted[both]
            results[band] = imageExposure
        return Struct(results=results)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Load every band's coadd, fix them jointly and write each band out."""
        sortedImages: dict[str, Exposure] = {}
        for ref in inputRefs.inputCoadd:
            key: str = cast(str, ref.dataId["band"])
            sortedImages[key] = butlerQC.get(ref)

        outputs = self.run(sortedImages).results
        sortedOutputs = {ref.dataId["band"]: ref for ref in outputRefs.outputCoadd}

        for band, data in outputs.items():
            butlerQC.put(data, sortedOutputs[band])
1298
1299
class PrettyMosaicConnections(PipelineTaskConnections, dimensions=("tract", "skymap")):
    """Connections for `PrettyMosaicTask`.

    Per-patch RGB images and their masks are loaded lazily (``deferLoad``)
    and combined into a single mosaic per tract.
    """

    inputRGB = Input(
        doc="Individual RGB images that are to go into the mosaic",
        name="rgb_picture",
        storageClass="ColorImage",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    skyMap = Input(
        doc="The skymap which the data has been mapped onto",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    inputRGBMask = Input(
        doc="Masks for the individual RGB images that are to go into the mosaic",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    outputRGBMosaic = Output(
        doc="A RGB mosaic created from the input data stored as a 3d array",
        name="rgb_mosaic",
        storageClass="ColorImage",
        dimensions=("tract", "skymap"),
    )
1332
1333
class PrettyMosaicConfig(
    PipelineTaskConfig,
    pipelineConnections=PrettyMosaicConnections,
):
    """Configuration for `PrettyMosaicTask`."""

    binFactor = Field[int](
        doc="The factor to bin by when producing the mosaic",
        default=1,
    )
    doDCID65Convert = Field[bool](
        doc="Force the output to be converted from display p3 to DCI-D65 colorspace.",
        default=False,
    )
    useLocalTemp = Field[bool](
        doc="Use the current directory when creating local temp files.",
        default=False,
    )
1340
1341
1342class PrettyMosaicTask(PipelineTask):
1343 """Combines multiple RGB arrays into one mosaic."""
1344
1345 _DefaultName = "prettyMosaic"
1346 ConfigClass = PrettyMosaicConfig
1347
1348 config: ConfigClass
1349
1350 def run(
1351 self,
1352 inputRGB: Iterable[DeferredDatasetHandle],
1353 skyMap: BaseSkyMap,
1354 inputRGBMask: Iterable[DeferredDatasetHandle],
1355 ) -> Struct:
1356 r"""Assemble individual `numpy.ndarrays` into a mosaic.
1357
1358 Each input is a `~lsst.daf.butler.DeferredDatasetHandle` because
1359 they're loaded in one at a time to be placed into the mosaic to save
1360 memory.
1361
1362 Parameters
1363 ----------
1364 inputRGB : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1365 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to RGB
1366 `numpy.ndarrays`.
1367 skyMap : `BaseSkyMap`
1368 The skymap that defines the relative position of each of the input
1369 images.
1370 inputRGBMask : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1371 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to masks for
1372 each of the corresponding images.
1373
1374 Returns
1375 -------
1376 result : `Struct`
1377 The `Struct` containing the combined mosaic. The `Struct` has
1378 and attribute named ``outputRGBMosaic``.
1379 """
1380 # create the bounding region
1381 newBox = Box2I()
1382 # store the bounds as they are retrieved from the skymap
1383 boxes = []
1384 tractMaps = []
1385 for handle in inputRGB:
1386 dataId = handle.dataId
1387 tractInfo: TractInfo = skyMap[dataId["tract"]]
1388 patchInfo: PatchInfo = tractInfo[dataId["patch"]]
1389 bbox = patchInfo.getOuterBBox()
1390 boxes.append(bbox)
1391 newBox.include(bbox)
1392 tractMaps.append(tractInfo)
1393 # This will be overwritten in the loop, but that is ok, because
1394 # it is the same for each patch.
1395 patch_grow: int = patchInfo.getCellInnerDimensions().getX()
1396
1397 # fixup the boxes to be smaller if needed, and put the origin at zero,
1398 # this must be done after constructing the complete outer box
1399 modifiedBoxes = []
1400 origin = newBox.getBegin()
1401 for iterBox in boxes:
1402 localOrigin = iterBox.getBegin() - origin
1403 localOrigin = Point2I(
1404 x=int(np.floor(localOrigin.x / self.config.binFactor)),
1405 y=int(np.floor(localOrigin.y / self.config.binFactor)),
1406 )
1407 localExtent = Extent2I(
1408 x=int(np.floor(iterBox.getWidth() / self.config.binFactor)),
1409 y=int(np.floor(iterBox.getHeight() / self.config.binFactor)),
1410 )
1411 tmpBox = Box2I(localOrigin, localExtent)
1412 modifiedBoxes.append(tmpBox)
1413 boxes = modifiedBoxes
1414
1415 # scale the container box
1416 newBoxOrigin = Point2I(0, 0)
1417 newBoxExtent = Extent2I(
1418 x=int(np.floor(newBox.getWidth() / self.config.binFactor)),
1419 y=int(np.floor(newBox.getHeight() / self.config.binFactor)),
1420 )
1421 newBox = Box2I(newBoxOrigin, newBoxExtent)
1422
1423 # Allocate storage for the mosaic
1424 self.imageHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None)
1425 self.maskHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None)
1426 consolidatedImage = None
1427 consolidatedMask = None
1428
1429 # Setup color space conversion in case they are used.
1430 d65 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DCI_P3)
1431 dp3 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DISPLAY_P3)
1432 d65.whitepoint = dp3.whitepoint
1433 d65.whitepoint_name = dp3.whitepoint_name
1434
1435 # Actually assemble the mosaic
1436 maskDict = {}
1437 mosaic_maker = FeatheredMosaicCreator(patch_grow, self.config.binFactor)
1438 for box, handle, handleMask, tractInfo in zip(boxes, inputRGB, inputRGBMask, tractMaps):
1439 rgb = handle.get().array
1440 # convert to the dci-d65 colorspace
1441 if self.config.doDCID65Convert:
1442 rgb = colour.RGB_to_RGB(np.clip(rgb, 0, 1), dp3, d65)
1443 rgbMask = handleMask.get()
1444 maskDict = rgbMask.getMaskPlaneDict()
1445 # allocate the memory for the mosaic
1446 if consolidatedImage is None:
1447 consolidatedImage = np.memmap(
1448 self.imageHandle.name,
1449 mode="w+",
1450 shape=(newBox.getHeight(), newBox.getWidth(), 3),
1451 dtype=rgb.dtype,
1452 )
1453 if consolidatedMask is None:
1454 consolidatedMask = np.memmap(
1455 self.maskHandle.name,
1456 mode="w+",
1457 shape=(newBox.getHeight(), newBox.getWidth()),
1458 dtype=rgbMask.array.dtype,
1459 )
1460
1461 if self.config.binFactor > 1:
1462 # opencv wants things in x, y dimensions
1463 shape = tuple(box.getDimensions())[::-1]
1464 rgb = cv2.resize(
1465 rgb,
1466 dst=None,
1467 dsize=shape,
1468 fx=shape[0] / self.config.binFactor,
1469 fy=shape[1] / self.config.binFactor,
1470 )
1471 mask_array = rgbMask.array[:: self.config.binFactor, :: self.config.binFactor]
1472 rgbMask = Mask(*(mask_array.shape[::-1]))
1473 mosaic_maker.add_to_image(consolidatedImage, rgb, newBox, box, reverse=False)
1474
1475 consolidatedMask[*box.slices] = np.bitwise_or(consolidatedMask[*box.slices], rgbMask.array)
1476
1477 for plugin in plugins.full():
1478 if consolidatedImage is not None and consolidatedMask is not None:
1479 consolidatedImage = plugin(consolidatedImage, consolidatedMask, maskDict)
1480 # If consolidated image still None, that means there was no work to do.
1481 # Return an empty image instead of letting this task fail.
1482 if consolidatedImage is None:
1483 consolidatedImage = np.zeros((0, 0, 0), dtype=np.uint8)
1484
1485 return Struct(outputRGBMosaic=ColorImage(consolidatedImage))
1486
1488 self,
1489 butlerQC: QuantumContext,
1490 inputRefs: InputQuantizedConnection,
1491 outputRefs: OutputQuantizedConnection,
1492 ) -> None:
1493 inputs = butlerQC.get(inputRefs)
1494 outputs = self.run(**inputs)
1495 butlerQC.put(outputs, outputRefs)
1496 if hasattr(self, "imageHandle"):
1497 self.imageHandle.close()
1498 if hasattr(self, "maskHandle"):
1499 self.maskHandle.close()
1500
1502 self, inputs: Iterable[tuple[Mapping[str, Any], NDArray]]
1503 ) -> Iterable[DeferredDatasetHandle]:
1504 r"""Make valid inputs for the run method from numpy arrays.
1505
1506 Parameters
1507 ----------
1508 inputs : `Iterable` of `tuple` of `Mapping` and `numpy.ndarray`
1509 An iterable where each element is a tuple with the first
1510 element is a mapping that corresponds to an arrays dataId,
1511 and the second is an `numpy.ndarray`.
1512
1513 Returns
1514 -------
1515 sortedImages : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1516 An iterable of `~lsst.daf.butler.DeferredDatasetHandle`\ s
1517 containing the input data.
1518 """
1519 structuredInputs = []
1520 for dataId, array in inputs:
1521 structuredInputs.append(InMemoryDatasetHandle(inMemoryDataset=array, **dataId))
1522
1523 return structuredInputs
Iterable[DeferredDatasetHandle] makeInputsFromArrays(self, Iterable[tuple[Mapping[str, Any], NDArray]] inputs)
Definition _task.py:1503
Struct run(self, Iterable[DeferredDatasetHandle] inputRGB, BaseSkyMap skyMap, Iterable[DeferredDatasetHandle] inputRGBMask)
Definition _task.py:1355
None runQuantum(self, QuantumContext butlerQC, InputQuantizedConnection inputRefs, OutputQuantizedConnection outputRefs)
Definition _task.py:1492
None runQuantum(self, QuantumContext butlerQC, InputQuantizedConnection inputRefs, OutputQuantizedConnection outputRefs)
Definition _task.py:1283
Struct run(self, Mapping[str, ExposureF] inputs)
Definition _task.py:1224
RGBImage lsstRGB(FloatImagePlane rArray, FloatImagePlane gArray, FloatImagePlane bArray, LocalContrastFunction|None|_SentinalDefault local_contrast=DEFAULT_FUNCTION, ScaleLumFunction|None|_SentinalDefault scale_lum=DEFAULT_FUNCTION, ScaleColorFunction|None|_SentinalDefault scale_color=DEFAULT_FUNCTION, RemapBoundsFunction|None|_SentinalDefault remap_bounds=DEFAULT_FUNCTION, BracketingFunction|None|_SentinalDefault bracketing_function=DEFAULT_FUNCTION, GamutRemappingFunction|None|_SentinalDefault gamut_remapping_function=DEFAULT_FUNCTION, FloatImagePlane|None psf=None, tuple[float, float] cieWhitePoint=(0.28, 0.28))
dict[str, Exposure] makeInputsFromRefs(self, Iterable[DatasetRef] refs, Butler|QuantumContext butler)
Definition _task.py:586
None runQuantum(self, QuantumContext butlerQC, InputQuantizedConnection inputRefs, OutputQuantizedConnection outputRefs)
Definition _task.py:558
Struct run(self, Mapping[str, Exposure] images, Projection[Any]|None image_wcs=None, Box|None image_box=None)
Definition _task.py:407
fixBackground(self, image, list[float] yloc, list[float] xloc, list[float] values)
Definition _task.py:1035
_findControlPoints(self, Exposure exposure, Point2I origin, bool use_detection_mask=False, tuple[float, float]|None threshhold_pair=None)
Definition _task.py:992
dict[int, DeferredDatasetHandle] makeInputsFromExposures(self, **kwargs)
Definition _task.py:634
findBackgroundPixels(self, image, threshold_pair=None)
Definition _task.py:942
dict[str, DeferredDatasetHandle] makeInputsFromArrays(self, **kwargs)
Definition _task.py:609
STL namespace.