Coverage for python/lsst/drp/tasks/assemble_cell_coadd.py: 15%
411 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 01:20 -0700
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 01:20 -0700
1# This file is part of drp_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = (
25 "AssembleCellCoaddTask",
26 "AssembleCellCoaddConfig",
27 "ConvertMultipleCellCoaddToExposureTask",
28)
30import dataclasses
31import itertools
32import logging
34import numpy as np
36import lsst.afw.geom as afwGeom
37import lsst.afw.image as afwImage
38import lsst.afw.math as afwMath
39import lsst.geom as geom
40from lsst.afw.detection import InvalidPsfError
41from lsst.afw.geom import SinglePolygonException, makeWcsPairTransform
42from lsst.cell_coadds import (
43 CellIdentifiers,
44 CoaddApCorrMapStacker,
45 CoaddInputs,
46 CoaddUnits,
47 CommonComponents,
48 GridContainer,
49 MultipleCellCoadd,
50 ObservationIdentifiers,
51 OwnedImagePlanes,
52 PatchIdentifiers,
53 SingleCellCoadd,
54 UniformGrid,
55)
56from lsst.daf.butler import DataCoordinate, DeferredDatasetHandle
57from lsst.meas.algorithms import AccumulatorMeanStack
58from lsst.pex.config import ConfigField, ConfigurableField, DictField, Field, ListField, RangeField
59from lsst.pipe.base import (
60 InMemoryDatasetHandle,
61 NoWorkFound,
62 PipelineTask,
63 PipelineTaskConfig,
64 PipelineTaskConnections,
65 Struct,
66)
67from lsst.pipe.base.connectionTypes import Input, Output
68from lsst.pipe.tasks.coaddBase import makeSkyInfo, removeMaskPlanes, setRejectedMaskMapping
69from lsst.pipe.tasks.healSparseMapping import HealSparseInputMapTask
70from lsst.pipe.tasks.interpImage import InterpImageTask
71from lsst.pipe.tasks.scaleZeroPoint import ScaleZeroPointTask
72from lsst.skymap import BaseSkyMap
@dataclasses.dataclass
class WarpInputs:
    """Bundle a warp handle with the associated per-warp auxiliary inputs.

    Attributes
    ----------
    warp : `~lsst.daf.butler.DeferredDatasetHandle` or \
            `~lsst.pipe.base.InMemoryDatasetHandle`
        Handle for the warped exposure.
    masked_fraction : handle or `None`, optional
        Handle for the masked fraction image.
    artifact_mask : handle or `None`, optional
        Handle for the CompareWarp artifact mask.
    noise_warps : `list` of handles, optional
        Handles for the noise warps. Each instance gets its own fresh list
        when none is supplied.
    """

    warp: DeferredDatasetHandle | InMemoryDatasetHandle
    masked_fraction: DeferredDatasetHandle | InMemoryDatasetHandle | None = None
    artifact_mask: DeferredDatasetHandle | InMemoryDatasetHandle | None = None
    noise_warps: list[DeferredDatasetHandle | InMemoryDatasetHandle] = dataclasses.field(
        default_factory=list,
    )

    @property
    def dataId(self) -> DataCoordinate:
        """Data ID of the warp, forwarded from the warp handle.

        Returns
        -------
        data_id : `~lsst.daf.butler.DataCoordinate`
            DataID of the warp.
        """
        warp_handle = self.warp
        return warp_handle.dataId
class AssembleCellCoaddConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputWarpName": "deep", "outputCoaddSuffix": "Cell"},
):
    """Connections for `AssembleCellCoaddTask`.

    The set of connections is adjusted at construction time according to the
    config: optional inputs/outputs are dropped when the corresponding
    feature is disabled, and one noise-warp input is added per requested
    noise realization.
    """

    inputWarps = Input(
        doc="Input warps",
        name="{inputWarpName}Coadd_directWarp",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    maskedFractionWarps = Input(
        doc="Mask fraction warps",
        name="{inputWarpName}Coadd_directWarp_maskedFraction",
        storageClass="ImageF",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    artifactMasks = Input(
        doc="Artifact masks to be applied to the input warps",
        name="compare_warp_artifact_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    visitSummaryList = Input(
        doc="Input visit-summary catalogs with updated calibration objects. Mainly used for coadd weights.",
        name="finalVisitSummary",
        storageClass="ExposureCatalog",
        dimensions=("instrument", "visit"),
        deferLoad=True,
        multiple=True,
    )

    skyMap = Input(
        doc="Input definition of geometry/bbox and projection/wcs. This must be cell-based.",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    multipleCellCoadd = Output(
        doc="Output multiple cell coadd",
        name="{inputWarpName}Coadd{outputCoaddSuffix}",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    inputMap = Output(
        doc="Output healsparse map of input images",
        name="{inputWarpName}Coadd_inputMap",
        storageClass="HealSparseMap",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # Without a config there is nothing to tailor; keep every connection.
        if not config:
            return

        # Weights come from the warp itself, so visit summaries are unused.
        if config.do_calculate_weight_from_warp:
            del self.visitSummaryList

        # Artifact masks are only read when they are going to be applied.
        if not config.do_use_artifact_mask:
            del self.artifactMasks

        # The input map is an optional output.
        if not config.do_input_map:
            del self.inputMap

        # Add one noise-warp input per noise realization requested in the
        # config; the attribute and dataset type names carry the index.
        for index in range(config.num_noise_realizations):
            setattr(
                self,
                f"noise{index}_warps",
                Input(
                    doc="Input noise warps",
                    name=f"direct_warp_noise{index}",
                    storageClass="MaskedImageF",
                    dimensions=("tract", "patch", "skymap", "visit", "instrument"),
                    deferLoad=True,
                    multiple=True,
                ),
            )
class AssembleCellCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCellCoaddConnections):
    """Configuration for `AssembleCellCoaddTask`."""

    do_interpolate_coadd = Field[bool](
        doc="Interpolate over pixels with NO_DATA mask set?",
        default=True,
    )
    interpolate_coadd = ConfigurableField(
        doc="Task to interpolate (and extrapolate) over pixels with NO_DATA mask on cell coadds",
        target=InterpImageTask,
    )
    do_scale_zero_point = Field[bool](
        doc="Scale warps to a common zero point? This is not needed if they have absolute flux calibration.",
        default=False,
        deprecated="Now that visits are scaled to nJy it is no longer necessary or "
        "recommended to scale the zero point, so this will be removed "
        "after v29.",
    )
    scale_zero_point = ConfigurableField(
        doc="Task to scale warps to a common zero point",
        target=ScaleZeroPointTask,
        deprecated="Now that visits are scaled to nJy it is no longer necessary or "
        "recommended to scale the zero point, so this will be removed "
        "after v29.",
    )
    do_calculate_weight_from_warp = Field[bool](
        doc="Calculate coadd weight from the input warp? Otherwise, the weight is obtained from the "
        "visitSummaryList connection. This is meant as a fallback when run outside the pipeline.",
        default=False,
    )
    do_use_artifact_mask = Field[bool](
        doc="Substitute the mask planes input warp with an alternative artifact mask?",
        default=True,
    )
    do_coadd_inverse_aperture_corrections = Field[bool](
        doc="Coadd the inverse aperture corrections for each cell? This is formally the more accurate way "
        "but may be turned off for parity with deepCoadd.",
        default=False,
    )
    min_overlap_fraction = RangeField[float](
        doc="The minimum overlap fraction required for a single (visit, detector) input to be included in a "
        "cell.",
        # A value of 1.0 corresponds to ideal, edge-free cells.
        # A value of 0.0 corresponds to the deep_coadd style coadds.
        # This has to be at least 0.5 to ensure that an input overlaps the
        # cell center. Inputs with an overlap fraction less than 0.25 will
        # definitely not overlap the cell center.
        default=1.0,
        min=0.0,
        max=1.0,
        inclusiveMin=True,
        inclusiveMax=True,
    )
    bad_mask_planes = ListField[str](
        doc="Mask planes that count towards the masked fraction within a cell.",
        default=("BAD", "NO_DATA", "SAT", "CLIPPED"),
    )
    remove_mask_planes = ListField[str](
        doc="Mask planes to remove before coadding",
        default=["EDGE", "NOT_DEBLENDED"],
    )
    calc_error_from_input_variance = Field[bool](
        doc="Calculate coadd variance from input variance by stacking "
        "statistic. Passed to AccumulatorMeanStack.",
        default=True,
    )
    mask_propagation_thresholds = DictField[str, float](
        doc=(
            "Threshold (in fractional weight) of rejection at which we "
            "propagate a mask plane to the coadd; that is, we set the mask "
            "bit on the coadd if the fraction the rejected frames "
            "would have contributed exceeds this value."
        ),
        default={"SAT": 0.1},
    )
    max_maskfrac = RangeField[float](
        doc="Maximum fraction of masked pixels in a cell for a given warp. "
        "Warps exceeding this threshold are excluded from the science coadd, "
        "PSF, aperture corrections, and input maps.",
        default=0.5,
        min=0.0,
        max=1.0,
        inclusiveMin=True,
        inclusiveMax=False,
    )
    num_noise_realizations = Field[int](
        doc=(
            "Number of noise planes to include in the coadd. "
            "This should not exceed the corresponding config parameter "
            "specified in `MakeDirectWarpConfig`. "
        ),
        default=0,
        # Negative counts are rejected.
        check=lambda value: value >= 0,
    )
    psf_warper = ConfigField(
        doc="Configuration for the warper that warps the PSFs. It must have the same configuration used to "
        "warp the images.",
        dtype=afwMath.Warper.ConfigClass,
    )
    psf_dimensions = Field[int](
        doc="Dimensions of the PSF image stamp size to be assigned to cells (must be odd).",
        default=35,
        # Only positive, odd stamp sizes are accepted.
        check=lambda value: (value > 0) and (value % 2 == 1),
    )
    require_artifact_mask = Field[bool](
        doc="Require presence of artifact mask for each warp? Use true if using artifact rejection outputs"
        " from CompareWarpTask",
        default=True,
    )
    do_input_map = Field[bool](
        doc="Create a bitwise map of coadd inputs.",
        default=False,
    )
    input_mapper = ConfigurableField(
        doc="Input map creation subtask.",
        target=HealSparseInputMapTask,
    )
308class AssembleCellCoaddTask(PipelineTask):
309 """Assemble a cell-based coadded image from a set of warps.
311 This task reads in the warp one at a time, and accumulates it in all the
312 cells that it completely overlaps with. This is the optimal I/O pattern but
313 this also implies that it is not possible to build one or only a few cells.
315 Each cell coadd is guaranteed to have a well-defined PSF. This is done by
316 1) excluding warps that only partially overlap a cell from that cell coadd;
317 2) interpolating bad pixels in the warps rather than excluding them;
318 3) by computing the coadd as a weighted mean of the warps without clipping;
319 4) by computing the coadd PSF as the weighted mean of the PSF of the warps
320 with the same weights.
322 The cells are (and must be) defined in the skymap, and cannot be configured
323 or redefined here. The cells are assumed to be small enough that the PSF is
324 assumed to be spatially constant within a cell.
326 Raises
327 ------
328 NoWorkFound
329 Raised if no input warps are provided, or no cells could be populated.
330 RuntimeError
331 Raised if the skymap is not cell-based.
333 Notes
334 -----
335 This is not yet a part of the standard DRP pipeline. As such, the Task and
336 especially its Config and Connections are experimental and subject to
337 change any time without a formal RFC or standard deprecation procedures
338 until it is included in the DRP pipeline.
339 """
341 ConfigClass = AssembleCellCoaddConfig
342 _DefaultName = "assembleCellCoadd"
def __init__(self, *args, **kwargs):
    """Create the enabled subtasks, the PSF warper and the PSF padding.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to the parent class constructor.

    Raises
    ------
    ValueError
        Raised if the warping kernel name starts with "lanczos" but the
        remainder of the name is not an integer order.
    """
    super().__init__(*args, **kwargs)
    if self.config.do_interpolate_coadd:
        self.makeSubtask("interpolate_coadd")
        # Suppress the warning message about fallback.
        self.interpolate_coadd.log.setLevel(logging.ERROR)
    if self.config.do_scale_zero_point:
        self.makeSubtask("scale_zero_point")
    if self.config.do_input_map:
        self.makeSubtask("input_mapper")

    self.psf_warper = afwMath.Warper.fromConfig(self.config.psf_warper)

    kernel_name = self.config.psf_warper.warpingKernelName
    kernel_name_lower = kernel_name.lower()
    if kernel_name_lower.startswith("lanczos"):
        # Use removeprefix rather than lstrip: lstrip("lanczos") strips any
        # leading run of the characters {l, a, n, c, z, o, s} instead of the
        # literal prefix, which could silently mis-parse a malformed name.
        lanczos_order = int(kernel_name_lower.removeprefix("lanczos"))
        psf_padding = 2 * lanczos_order - 1
        self.log.debug(
            "Padding PSF image by %d pixels since the warping kernel is %s.",
            psf_padding,
            kernel_name,
        )
    else:
        # Fixed padding for any non-Lanczos kernel.
        psf_padding = 10
        self.log.info(
            "Padding PSF image by %d pixels since the warping kernel is not Lanczos.",
            psf_padding,
        )
    self.psf_padding = psf_padding
def runQuantum(self, butlerQC, inputRefs, outputRefs):
    # Docstring inherited.
    warp_refs = inputRefs.inputWarps
    if not warp_refs:
        raise NoWorkFound("No input warps provided for co-addition")
    self.log.info("Found %d input warps", len(warp_refs))

    # The skymap defines the cells, so it has to be a cell-based one.
    skyMap = butlerQC.get(inputRefs.skyMap)
    if skyMap.config.tractBuilder.name != "cells":
        raise RuntimeError("AssembleCellCoaddTask requires a cell-based skymap.")

    # Build the skyInfo struct expected by ``run`` from the quantum data ID.
    outputDataId = butlerQC.quantum.dataId
    skyInfo = makeSkyInfo(skyMap, tractId=outputDataId["tract"], patchId=outputDataId["patch"])

    # The visit summary connection may have been removed by the config, in
    # which case an empty list is passed through.
    visitSummaryList = butlerQC.get(getattr(inputRefs, "visitSummaryList", []))

    if self.config.do_scale_zero_point:
        units = CoaddUnits.legacy
    else:
        units = CoaddUnits.nJy
    self.common = CommonComponents(
        units=units,
        wcs=skyInfo.patchInfo.wcs,
        band=outputDataId.get("band", None),
        identifiers=PatchIdentifiers.from_data_id(outputDataId),
    )

    # Key every warp by its data ID so that the auxiliary datasets can be
    # attached to the matching warp below.
    inputs: dict[DataCoordinate, WarpInputs] = {
        handle.dataId: WarpInputs(warp=handle, noise_warps=[]) for handle in butlerQC.get(warp_refs)
    }

    for mask_ref in getattr(inputRefs, "artifactMasks", []):
        inputs[mask_ref.dataId].artifact_mask = butlerQC.get(mask_ref)
    for fraction_ref in getattr(inputRefs, "maskedFractionWarps", []):
        inputs[fraction_ref.dataId].masked_fraction = butlerQC.get(fraction_ref)
    for index in range(self.config.num_noise_realizations):
        for noise_ref in getattr(inputRefs, f"noise{index}_warps"):
            inputs[noise_ref.dataId].noise_warps.append(butlerQC.get(noise_ref))

    returnStruct = self.run(inputs=inputs, skyInfo=skyInfo, visitSummaryList=visitSummaryList)
    butlerQC.put(returnStruct, outputRefs)
    return returnStruct
@staticmethod
def _compute_weight(maskedImage, statsCtrl):
    """Compute a weight for a masked image.

    Parameters
    ----------
    maskedImage : `~lsst.afw.image.MaskedImage`
        The masked image to compute the weight.
    statsCtrl : `~lsst.afw.math.StatisticsControl`
        A control (config-like) object for StatisticsStack.

    Returns
    -------
    weight : `float`
        Inverse of the clipped mean variance of the masked image. If the
        clipped mean variance is exactly zero, ``inf`` is returned instead
        of raising `ZeroDivisionError`, so that callers can reject the
        weight with a finiteness check.
    """
    statObj = afwMath.makeStatistics(
        maskedImage.getVariance(), maskedImage.getMask(), afwMath.MEANCLIP, statsCtrl
    )
    meanVar, _ = statObj.getResult(afwMath.MEANCLIP)
    meanVar = float(meanVar)
    if meanVar == 0.0:
        # Python float division by zero raises rather than yielding inf;
        # report a non-finite weight instead.
        return float("inf")
    return 1.0 / meanVar
@staticmethod
def _construct_grid(skyInfo):
    """Build the uniform cell grid described by a SkyInfo struct.

    Parameters
    ----------
    skyInfo : `~lsst.pipe.base.Struct`
        A Struct object whose ``patchInfo`` attribute provides the cell
        border, the outer bounding box and the inner cell dimensions.

    Returns
    -------
    grid : `~lsst.cell_coadds.UniformGrid`
        A UniformGrid object.
    """
    patch_info = skyInfo.patchInfo
    cell_border = patch_info.getCellBorder()
    # The grid is defined on the patch outer bbox eroded by the cell border;
    # the same border is then passed as the grid padding.
    return UniformGrid.from_bbox_cell_size(
        patch_info.outer_bbox.erodedBy(cell_border),
        patch_info.getCellInnerDimensions(),
        padding=cell_border,
    )
def _construct_grid_container(self, skyInfo, statsCtrl):
    """Create one AccumulatorMeanStack per cell, arranged on the cell grid.

    Parameters
    ----------
    skyInfo : `~lsst.pipe.base.Struct`
        A Struct object
    statsCtrl : `~lsst.afw.math.StatisticsControl`
        A control (config-like) object for StatisticsStack.

    Returns
    -------
    gc : `~lsst.cell_coadds.GridContainer`
        A GridContainer object containing one AccumulatorMeanStack per
        cell.
    """
    grid = self._construct_grid(skyInfo)

    maskMap = setRejectedMaskMapping(statsCtrl)
    self.log.debug("Obtained maskMap = %s for %s", maskMap, skyInfo.patchInfo)
    thresholdDict = AccumulatorMeanStack.stats_ctrl_to_threshold_dict(statsCtrl)

    # Settings shared by every per-cell stacker.
    shared_kwargs = dict(
        bit_mask_value=statsCtrl.getAndMask(),
        mask_threshold_dict=thresholdDict,
        calc_error_from_input_variance=self.config.calc_error_from_input_variance,
        compute_n_image=False,
        mask_map=maskMap,
        no_good_pixels_mask=statsCtrl.getNoGoodPixelsMask(),
    )

    # Fill the grid container with one stacker per cell.
    gc = GridContainer[AccumulatorMeanStack](grid.shape)
    for cellInfo in skyInfo.patchInfo:
        outer_bbox = cellInfo.outer_bbox
        # The shape is for the numpy arrays, hence (height, width).
        gc[cellInfo.index] = AccumulatorMeanStack(
            shape=(outer_bbox.height, outer_bbox.width),
            **shared_kwargs,
        )

    return gc
def _construct_stats_control(self):
    """Build the StatisticsControl object used for the coadd.

    Unlike AssembleCoaddTask or CompareWarpAssembleCoaddTask, there is
    very little to be configured apart from setting the mask planes and
    optionally mask propagation thresholds.

    Returns
    -------
    statsCtrl : `~lsst.afw.math.StatisticsControl`
        A control object for StatisticsStack.
    """
    config = self.config
    statsCtrl = afwMath.StatisticsControl()
    # The number of iterations is hardcoded to the default config value set
    # in CompareWarpAssembleCoaddTask to get consistent weights. It is NOT
    # exposed as a config parameter, since this is only meant to be a
    # fallback option that is not recommended for production.
    statsCtrl.setNumIter(2)
    bad_pixel_bitmask = afwImage.Mask.getPlaneBitMask(config.bad_mask_planes)
    statsCtrl.setAndMask(bad_pixel_bitmask)
    statsCtrl.setNanSafe(True)
    # Register the per-plane propagation thresholds by mask-plane bit.
    for plane_name, fraction in config.mask_propagation_thresholds.items():
        statsCtrl.setMaskPropagationThreshold(afwImage.Mask.getMaskPlane(plane_name), fraction)
    return statsCtrl
def _construct_ap_corr_grid_container(self, skyInfo):
    """Create one CoaddApCorrMapStacker per cell, arranged on the cell grid.

    Parameters
    ----------
    skyInfo : `~lsst.pipe.base.Struct`
        A Struct object

    Returns
    -------
    gc : `~lsst.cell_coadds.GridContainer`
        A GridContainer object containing one CoaddApCorrMapStacker per
        cell.
    """
    grid = self._construct_grid(skyInfo)
    coadd_inverse = self.config.do_coadd_inverse_aperture_corrections

    gc = GridContainer[CoaddApCorrMapStacker](grid.shape)
    for cellInfo in skyInfo.patchInfo:
        # Each stacker evaluates at the center of its cell's inner bbox.
        gc[cellInfo.index] = CoaddApCorrMapStacker(
            evaluation_point=cellInfo.inner_bbox.getCenter(),
            do_coadd_inverse_ap_corr=coadd_inverse,
        )

    return gc
550 def run(
551 self,
552 *,
553 inputs: dict[DataCoordinate, WarpInputs],
554 skyInfo,
555 visitSummaryList: list | None = None,
556 ):
557 for mask_plane in self.config.bad_mask_planes:
558 afwImage.Mask.addMaskPlane(mask_plane)
559 for mask_plane in self.config.mask_propagation_thresholds:
560 afwImage.Mask.addMaskPlane(mask_plane)
562 statsCtrl = self._construct_stats_control()
564 warp_stacker_gc = self._construct_grid_container(skyInfo, statsCtrl)
565 maskfrac_stacker_gc = self._construct_grid_container(skyInfo, statsCtrl)
566 noise_stacker_gc_list = [
567 self._construct_grid_container(skyInfo, statsCtrl)
568 for n in range(self.config.num_noise_realizations)
569 ]
570 psf_stacker_gc = GridContainer[AccumulatorMeanStack](warp_stacker_gc.shape)
571 psf_bbox_gc = GridContainer[geom.Box2I](warp_stacker_gc.shape)
572 ap_corr_stacker_gc = self._construct_ap_corr_grid_container(skyInfo)
574 # A cell is in "fallback" mode if it does not yet have any warps that
575 # pass the per-detector cuts; in that mode, we accumulate warps
576 # regardless of that cut, but clear the accumulators and start over if
577 # we later see data that does pass the per-detector cuts.
578 is_fallback_gc = GridContainer[bool](warp_stacker_gc.shape)
580 # We accumulate the information to pass to the Healsparse input-map
581 # accumulator instead of calling it directly, so we can do that only
582 # after we've accumulated all warps and hence know which cells will
583 # stay in fallback mode.
584 input_map_data_gc = GridContainer[list](warp_stacker_gc.shape)
586 # Make a container to hold the cell centers in sky coordinates now,
587 # so we don't have to recompute them for each warp
588 # (they share a common WCS). These are needed to find the various
589 # warp + detector combinations that contributed to each cell, and later
590 # get the corresponding PSFs as well.
591 cell_centers_sky = GridContainer[geom.SpherePoint](warp_stacker_gc.shape)
592 # Make a container to hold the observation identifiers for each cell.
593 observation_identifiers_gc = GridContainer[dict](warp_stacker_gc.shape)
595 if self.config.do_input_map:
596 # We need to know all the visit + detector pairs in the inputs.
597 warp_input_list = [warp_ref.warp.get(component="coaddInputs") for warp_ref in inputs.values()]
598 visit_detectors = []
599 for warp_input in warp_input_list:
600 for row in warp_input.ccds:
601 visit_detectors.append((int(row["visit"]), int(row["ccd"])))
603 self.input_mapper.initialize_cell_input_map(
604 skyInfo.patchInfo.getOuterBBox(),
605 skyInfo.patchInfo.wcs,
606 visit_detectors,
607 )
609 # Populate them.
610 for cellInfo in skyInfo.patchInfo:
611 # Make a list to hold the observation identifiers for each cell.
612 observation_identifiers_gc[cellInfo.index] = {}
613 cell_center_pixel = geom.Point2D(geom.Point2I(cellInfo.inner_bbox.getCenter()))
614 cell_centers_sky[cellInfo.index] = skyInfo.wcs.pixelToSky(cell_center_pixel)
615 psf_bbox_gc[cellInfo.index] = geom.Box2I.makeCenteredBox(
616 cell_center_pixel,
617 geom.Extent2I(self.config.psf_dimensions, self.config.psf_dimensions),
618 )
619 psf_stacker_gc[cellInfo.index] = AccumulatorMeanStack(
620 # The shape is for the numpy arrays, hence transposed.
621 shape=(self.config.psf_dimensions, self.config.psf_dimensions),
622 bit_mask_value=0,
623 calc_error_from_input_variance=self.config.calc_error_from_input_variance,
624 compute_n_image=False,
625 )
626 is_fallback_gc[cellInfo.index] = True
627 input_map_data_gc[cellInfo.index] = []
629 # visit_summary do not have (tract, patch, band, skymap) dimensions.
630 if not visitSummaryList:
631 visitSummaryList = []
632 visitSummaryRefDict = {
633 visitSummaryRef.dataId["visit"]: visitSummaryRef for visitSummaryRef in visitSummaryList
634 }
636 # Keep track of the polygons corresponding to each (visit, detector).
637 visit_polygons: dict[ObservationIdentifiers, afwGeom.Polygon] = {}
639 # Read in one warp at a time, and accumulate it in all the cells that
640 # it completely overlaps.
641 for warp_input in inputs.values():
642 # warps that have been excluded from CompareWarp via visit
643 # selection from SelectVisitsTasks will not have artifact masks.
644 # Exclude them from the cell coadds too.
645 if self.config.require_artifact_mask and warp_input.artifact_mask is None:
646 self.log.info(
647 "Excluding warp %s from cell coadds because it has no artifact mask",
648 warp_input.dataId["visit"],
649 )
650 continue
652 warp = warp_input.warp.get(parameters={"bbox": skyInfo.bbox})
653 masked_fraction_image = (
654 warp_input.masked_fraction.get(parameters={"bbox": skyInfo.bbox})
655 if warp_input.masked_fraction
656 else None
657 )
659 # Pre-process the warp before coadding.
660 # TODO: Can we get these mask names from artifactMask?
661 warp.mask.addMaskPlane("CLIPPED")
662 warp.mask.addMaskPlane("REJECTED")
663 warp.mask.addMaskPlane("SENSOR_EDGE")
664 warp.mask.addMaskPlane("INEXACT_PSF")
666 if artifact_mask_ref := warp_input.artifact_mask:
667 # Apply the artifact mask to the warp.
668 artifact_mask = artifact_mask_ref.get()
669 assert (
670 warp.mask.getMaskPlaneDict() == artifact_mask.getMaskPlaneDict()
671 ), "Mask dicts do not agree."
672 warp.mask.array = artifact_mask.array
673 del artifact_mask
675 if self.config.do_scale_zero_point:
676 # Each Warp that goes into a coadd will typically have an
677 # independent photometric zero-point. Therefore, we must scale
678 # each Warp to set it to a common photometric zeropoint.
679 imageScaler = self.scale_zero_point.run(exposure=warp, dataRef=warp_input.warp).imageScaler
680 zero_point_scale_factor = imageScaler.scale
681 self.log.debug(
682 "Scaled the warp %s by %f to match zero points",
683 warp_input.dataId,
684 zero_point_scale_factor,
685 )
686 else:
687 zero_point_scale_factor = 1.0
688 if "BUNIT" not in warp.metadata:
689 raise ValueError(f"Warp {warp_input.dataId} has no BUNIT metadata")
690 if warp.metadata["BUNIT"] != "nJy":
691 raise ValueError(
692 f"Warp {warp_input.dataId} has BUNIT {warp.metadata['BUNIT']}, expected nJy"
693 )
695 # Only try to remove maks planes that have been registered.
696 to_remove = []
697 for plane in self.config.remove_mask_planes:
698 if plane in warp.mask.getMaskPlaneDict():
699 to_remove.append(plane)
700 removeMaskPlanes(warp.mask, to_remove, self.log)
701 # Instead of using self.config.bad_mask_planes, we explicitly
702 # ask statsCtrl which pixels are going to be ignored/rejected.
703 rejected = afwImage.Mask.getPlaneBitMask(
704 ["CLIPPED", "REJECTED"] + afwImage.Mask.interpret(statsCtrl.getAndMask()).split(",")
705 )
707 # Compute the weight for each CCD in the warp from the visitSummary
708 # or from the warp itself, if not provided. Computing the weight
709 # from the warp is not recommended, and in that case we compute one
710 # weight per warp and not bother with per-detector weights.
711 full_ccd_table = warp.getInfo().getCoaddInputs().ccds
712 weights: dict[int, float] = dict.fromkeys(
713 full_ccd_table["ccd"].tolist(),
714 0.0,
715 ) # Mapping from detector to weight.
717 if visitSummaryRef := visitSummaryRefDict.get(warp_input.dataId["visit"]):
718 visitSummary = visitSummaryRef.get()
719 for detector in full_ccd_table["ccd"].tolist():
720 visitSummaryRow = visitSummary.find(detector)
721 mean_variance = visitSummaryRow["meanVar"]
722 mean_variance *= zero_point_scale_factor**2
723 if warp.metadata.get("BUNIT", None) == "nJy":
724 mean_variance *= visitSummaryRow.photoCalib.getCalibrationMean() ** 2
725 weights[detector] = 1.0 / mean_variance
726 del visitSummary
727 else:
728 self.log.debug("No visit summary found for %s; using warp-based weights", warp_input.dataId)
729 weight = self._compute_weight(warp, statsCtrl)
730 if not np.isfinite(weight):
731 self.log.warning("Non-finite weight for %s: skipping", warp_input.dataId)
732 continue
734 for detector in weights:
735 weights[detector] = weight
737 noise_warps = [ref.get(parameters={"bbox": skyInfo.bbox}) for ref in warp_input.noise_warps]
739 # Create an image where each pixel value corresponds to the
740 # detector ID that pixel comes from.
741 detector_map = afwImage.ImageI(bbox=warp.getBBox(), initialValue=-1)
742 for row in full_ccd_table:
743 transform = makeWcsPairTransform(row.wcs, warp.wcs)
744 if (src_polygon := row.validPolygon) is None:
745 src_polygon = afwGeom.Polygon(geom.Box2D(row.getBBox()))
746 try:
747 dest_polygon = src_polygon.transform(transform).intersectionSingle(
748 geom.Box2D(warp.getBBox())
749 )
750 except SinglePolygonException:
751 continue
753 observation_identifier = ObservationIdentifiers.from_data_id(
754 warp_input.dataId,
755 backup_detector=row["ccd"],
756 )
757 visit_polygons[observation_identifier] = dest_polygon
759 detector_map_slice = dest_polygon.createImage(detector_map.getBBox()).array > 0
760 if not (detector_map.array[detector_map_slice] < 0).all():
761 self.log.warning("Multiple detectors from visit %s are overlapping", warp_input.dataId)
762 detector_map.array[detector_map_slice] = row["ccd"]
764 if (detector_map.array < 0).all():
765 self.log.warning("Unable to split the warp %s into single-detector warps.", warp_input.dataId)
766 detector_map.array[:, :] = 0
768 for cellInfo, ccd_row in itertools.product(skyInfo.patchInfo, full_ccd_table):
769 bbox = cellInfo.outer_bbox
770 inner_bbox = cellInfo.inner_bbox
772 overlap_fraction = (detector_map[inner_bbox].array == ccd_row["ccd"]).mean()
773 assert -1e-4 < overlap_fraction < 1.0001, "Overlap fraction is not within [0, 1]."
774 if (overlap_fraction < self.config.min_overlap_fraction) or (overlap_fraction <= 0.0):
775 self.log.debug(
776 "Skipping %s in cell %s because it had only %.3f < %.3f fractional overlap.",
777 warp_input.dataId,
778 cellInfo.index,
779 overlap_fraction,
780 self.config.min_overlap_fraction,
781 )
782 continue
784 weight = weights[int(ccd_row["ccd"])]
785 if not np.isfinite(weight):
786 self.log.warning(
787 "Non-finite weight for %s in cell %s: skipping", warp_input.dataId, cellInfo.index
788 )
789 continue
791 if weight == 0:
792 self.log.info(
793 "Zero weight for %s in cell %s: skipping", warp_input.dataId, cellInfo.index
794 )
795 continue
797 # Compute the unmasked fraction for this detector in the inner
798 # cell. Used to gate on max_maskfrac.
799 inner_detector_pixels = detector_map[inner_bbox].array == ccd_row["ccd"]
800 inner_unmasked_pixels = (warp[inner_bbox].mask.array & rejected) == 0
801 unmasked_fraction = (
802 inner_detector_pixels & inner_unmasked_pixels
803 ).sum() / inner_detector_pixels.sum()
804 is_fallback = is_fallback_gc[cellInfo.index]
805 if unmasked_fraction <= max(1.0 - self.config.max_maskfrac, 0.0):
806 if not is_fallback:
807 # We already have good data in this cell, so we don't
808 # want this heavily masked warp - it will add too much
809 # INEXACT_PSF.
810 self.log.debug(
811 "Skipping %s in cell %s: masked fraction %.3f exceeds threshold %.3f",
812 warp_input.dataId,
813 cellInfo.index,
814 1.0 - unmasked_fraction,
815 self.config.max_maskfrac,
816 )
817 continue
818 else:
819 self.log.debug(
820 "Including %s in cell %s only as potential fallback: "
821 "masked fraction %.3f exceeds threshold %.3f",
822 warp_input.dataId,
823 cellInfo.index,
824 1.0 - unmasked_fraction,
825 self.config.max_maskfrac,
826 )
827 elif is_fallback:
828 # This is the first good data we've gotten for this cell;
829 # wipe out the fallback coadd we've been accumulating so
830 # far, so we can start fresh.
831 warp_stacker_gc[cellInfo.index].reset()
832 maskfrac_stacker_gc[cellInfo.index].reset()
833 for n in range(self.config.num_noise_realizations):
834 noise_stacker_gc_list[n][cellInfo.index].reset()
835 psf_stacker_gc[cellInfo.index].reset()
836 ap_corr_stacker_gc[cellInfo.index].reset()
837 observation_identifiers_gc[cellInfo.index].clear()
838 input_map_data_gc[cellInfo.index].clear()
839 is_fallback_gc[cellInfo.index] = False
841 overlaps_center = detector_map[geom.Point2I(bbox.getCenter())] == ccd_row["ccd"]
842 if not overlaps_center:
843 self.log.debug(
844 "%s does not overlap with the center of the cell %s",
845 warp_input.dataId,
846 cellInfo.index,
847 )
848 continue
850 # Decide if a deep copy is necessary to apply the single
851 # detector cuts since it involves modifying the image in-place.
852 # If within the inner cell, there are three or more different
853 # values that detector map takes, then there are definitely
854 # multiple detectors (one for chip gaps, two for two detectors)
855 deep_copy = len(set(detector_map[inner_bbox].array.ravel())) >= 3
856 if deep_copy:
857 single_detector_mask_array = detector_map[bbox].array != ccd_row["ccd"]
859 mi = afwImage.MaskedImageF(warp[bbox].maskedImage, deep=deep_copy)
860 if deep_copy:
861 mi.image.array[single_detector_mask_array] = 0.0
862 mi.variance.array[single_detector_mask_array] = np.inf
863 nodata_or_mask = (single_detector_mask_array) * afwImage.Mask.getPlaneBitMask("NO_DATA")
864 mi.mask[bbox].array |= nodata_or_mask
865 warp_stacker_gc[cellInfo.index].add_masked_image(mi, weight=weight)
867 if masked_fraction_image:
868 mi = afwImage.ImageF(masked_fraction_image[bbox], deep=True)
869 if deep_copy:
870 mi.array[single_detector_mask_array] = 0.0
871 mi.array[(warp[bbox].mask.array & rejected) != 0] = 1.0
872 maskfrac_stacker_gc[cellInfo.index].add_image(mi, weight=weight)
874 for n in range(self.config.num_noise_realizations):
875 mi = afwImage.MaskedImageF(noise_warps[n][bbox], deep=deep_copy)
876 if deep_copy:
877 mi.image.array[single_detector_mask_array] = 0.0
878 mi.variance.array[single_detector_mask_array] = np.inf
879 mi.mask[bbox].array |= nodata_or_mask
880 noise_stacker_gc_list[n][cellInfo.index].add_masked_image(mi, weight=weight)
882 # Set the defaults for PSF shape quantities.
883 psf_shape = afwGeom.Quadrupole()
884 psf_shape_flag = True
885 psf_eval_point = None
886 try:
887 # The `if` branch is buggy. `dest_polygon` is technically
888 # out of scope, but Python does not raise an error.
889 # TODO: Fix this properly in DM-53479, but sweep it under
890 # the rug for now.
891 if overlap_fraction < 0.5:
892 psf_eval_point = dest_polygon.intersectionSingle(
893 geom.Box2D(inner_bbox)
894 ).calculateCenter()
895 else:
896 psf_eval_point = geom.Point2D(geom.Point2I(inner_bbox.getCenter()))
897 psf_shape = warp.psf.computeShape(psf_eval_point)
898 psf_shape_flag = False
899 except SinglePolygonException:
900 self.log.info(
901 "Unable to find the overlapping polygon between %d detector in %s and cell %s",
902 ccd_row["ccd"],
903 warp_input.dataId,
904 cellInfo.index,
905 )
906 except InvalidPsfError:
907 self.log.info(
908 "Unable to compute PSF shape from %d detector in %s at %s",
909 ccd_row["ccd"],
910 warp_input.dataId,
911 psf_eval_point,
912 )
914 observation_identifier = ObservationIdentifiers.from_data_id(
915 warp_input.dataId,
916 backup_detector=int(ccd_row["ccd"]),
917 )
918 observation_identifiers_gc[cellInfo.index][observation_identifier] = CoaddInputs(
919 overlaps_center=overlaps_center,
920 overlap_fraction=overlap_fraction,
921 unmasked_overlap_fraction=unmasked_fraction,
922 weight=weight,
923 psf_shape=psf_shape,
924 psf_shape_flag=psf_shape_flag,
925 )
926 input_map_data_gc[cellInfo.index].append((ccd_row, weight))
928 # Everything below this has to do with the center of the cell
929 calexp_point = ccd_row.getWcs().skyToPixel(cell_centers_sky[cellInfo.index])
930 undistorted_psf_im = ccd_row.getPsf().computeImage(calexp_point)
932 assert undistorted_psf_im.getBBox() == geom.Box2I.makeCenteredBox(
933 calexp_point,
934 undistorted_psf_im.getDimensions(),
935 ), "PSF image does not share the coordinates of the 'calexp'"
937 # Convert the PSF image from Image to MaskedImage and
938 # zero-pad the image.
939 undistorted_psf_bbox = undistorted_psf_im.getBBox()
940 undistorted_psf_maskedImage = afwImage.MaskedImageD(
941 undistorted_psf_bbox.dilatedBy(self.psf_padding)
942 )
943 undistorted_psf_maskedImage.image[undistorted_psf_bbox].array[:, :] = undistorted_psf_im.array
944 # TODO: In DM-43585, use the variance plane value from noise.
945 undistorted_psf_maskedImage.variance += 1.0 # Set variance to 1
947 warped_psf_maskedImage = self.psf_warper.warpImage(
948 destWcs=skyInfo.wcs,
949 srcImage=undistorted_psf_maskedImage,
950 srcWcs=ccd_row.getWcs(),
951 destBBox=psf_bbox_gc[cellInfo.index],
952 )
954 # There may be NaNs in the PSF image. Set them to 0.0
955 warped_psf_maskedImage.variance.array[np.isnan(warped_psf_maskedImage.image.array)] = 1.0
956 warped_psf_maskedImage.image.array[np.isnan(warped_psf_maskedImage.image.array)] = 0.0
958 psf_stacker = psf_stacker_gc[cellInfo.index]
959 psf_stacker.add_masked_image(warped_psf_maskedImage, weight=weight)
961 if not (0.995 < (psf_normalization := warped_psf_maskedImage.image.array.sum()) < 1.005):
962 self.log.warning(
963 "PSF image for %s in %s is not normalized to 1.0, but instead %f",
964 warp_input.dataId,
965 cellInfo.index,
966 psf_normalization,
967 )
969 if (ap_corr_map := warp.getInfo().getApCorrMap()) is not None:
970 ap_corr_stacker_gc[cellInfo.index].add(ap_corr_map, weight=weight)
972 del warp
974 # Update common with the visit polygons.
975 self.common = dataclasses.replace(
976 self.common,
977 visit_polygons=visit_polygons,
978 )
980 cells: list[SingleCellCoadd] = []
981 for cellInfo in skyInfo.patchInfo:
982 if len(observation_identifiers_gc[cellInfo.index]) == 0:
983 self.log.debug("Skipping cell %s because it has no input warps", cellInfo.index)
984 continue
986 cell_masked_image = afwImage.MaskedImageF(cellInfo.outer_bbox)
987 cell_maskfrac_image = afwImage.ImageF(cellInfo.outer_bbox)
988 cell_noise_images = [
989 afwImage.MaskedImageF(cellInfo.outer_bbox) for n in range(self.config.num_noise_realizations)
990 ]
991 psf_masked_image = afwImage.MaskedImageF(psf_bbox_gc[cellInfo.index])
993 warp_stacker_gc[cellInfo.index].fill_stacked_masked_image(cell_masked_image)
994 maskfrac_stacker_gc[cellInfo.index].fill_stacked_image(cell_maskfrac_image)
995 for n in range(self.config.num_noise_realizations):
996 noise_stacker_gc_list[n][cellInfo.index].fill_stacked_masked_image(cell_noise_images[n])
997 psf_stacker_gc[cellInfo.index].fill_stacked_masked_image(psf_masked_image)
999 if ap_corr_stacker_gc[cellInfo.index].ap_corr_names:
1000 ap_corr_map = ap_corr_stacker_gc[cellInfo.index].final_ap_corr_map
1001 else:
1002 ap_corr_map = None
1004 # Post-process the coadd before converting to new data structures.
1005 if np.isnan(cell_masked_image.image.array).all():
1006 cell_masked_image.image.array[:, :] = 0.0
1007 cell_masked_image.variance.array[:, :] = np.inf
1008 elif self.config.do_interpolate_coadd:
1009 self.interpolate_coadd.run(cell_masked_image, planeName="NO_DATA")
1010 for noise_image in cell_noise_images:
1011 self.interpolate_coadd.run(noise_image, planeName="NO_DATA")
1012 # The variance must be positive; work around for DM-3201.
1013 varArray = cell_masked_image.variance.array
1014 with np.errstate(invalid="ignore"):
1015 varArray[:] = np.where(varArray > 0, varArray, np.inf)
1017 afwImage.Mask.addMaskPlane("INEXACT_PSF")
1018 cell_masked_image.mask.array[
1019 (cell_masked_image.mask.array & rejected) > 0
1020 ] |= cell_masked_image.mask.getPlaneBitMask("INEXACT_PSF")
1022 if self.config.do_input_map:
1023 self.input_mapper.build_cell_input_map(cellInfo)
1024 for ccd_row, weight in input_map_data_gc[cellInfo.index]:
1025 self.input_mapper.add_warp_to_cell_input_map(ccd_row, weight, cellInfo)
1027 image_planes = OwnedImagePlanes.from_masked_image(
1028 masked_image=cell_masked_image,
1029 mask_fractions=cell_maskfrac_image,
1030 noise_realizations=[noise_image.image for noise_image in cell_noise_images],
1031 )
1032 identifiers = CellIdentifiers(
1033 cell=cellInfo.index,
1034 skymap=self.common.identifiers.skymap,
1035 tract=self.common.identifiers.tract,
1036 patch=self.common.identifiers.patch,
1037 band=self.common.identifiers.band,
1038 )
1040 singleCellCoadd = SingleCellCoadd(
1041 outer=image_planes,
1042 psf=psf_masked_image.image,
1043 inner_bbox=cellInfo.inner_bbox,
1044 inputs=observation_identifiers_gc[cellInfo.index],
1045 common=self.common,
1046 identifiers=identifiers,
1047 aperture_correction_map=ap_corr_map,
1048 )
1049 # TODO: Attach transmission curve when they become available.
1050 cells.append(singleCellCoadd)
1052 if not cells:
1053 raise NoWorkFound("No cells could be populated for the cell coadd.")
1055 grid = self._construct_grid(skyInfo)
1056 multipleCellCoadd = MultipleCellCoadd(
1057 cells,
1058 grid=grid,
1059 outer_cell_size=cellInfo.outer_bbox.getDimensions(),
1060 inner_bbox=None,
1061 common=self.common,
1062 psf_image_size=cells[0].psf_image.getDimensions(),
1063 )
1065 if self.config.do_input_map:
1066 inputMap = self.input_mapper.cell_input_map
1067 else:
1068 inputMap = None
1070 return Struct(
1071 multipleCellCoadd=multipleCellCoadd,
1072 inputMap=inputMap,
1073 )
class ConvertMultipleCellCoaddToExposureConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "inputCoaddSuffix": "Cell"},
):
    """Connections for ConvertMultipleCellCoaddToExposureTask.

    A single ``MultipleCellCoadd`` dataset is read and a single ``ExposureF``
    dataset is written, both with (tract, patch, skymap, band) dimensions.
    The dataset type names are built from the ``inputCoaddName`` and
    ``inputCoaddSuffix`` templates; the output name is the input name with
    ``_stitched`` appended.
    """

    # This is the *input* to the conversion task, so its description must
    # not call it an output.
    cellCoaddExposure = Input(
        doc="Input cell-based coadd, produced by stacking input warps",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "skymap", "band"),
    )

    stitchedCoaddExposure = Output(
        doc="Output stitched coadded exposure, produced by stacking input warps",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}_stitched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )
class ConvertMultipleCellCoaddToExposureConfig(
    PipelineTaskConfig,
    pipelineConnections=ConvertMultipleCellCoaddToExposureConnections,
):
    """Configuration for ConvertMultipleCellCoaddToExposureTask.

    No fields are declared here; the class only associates the task with
    ConvertMultipleCellCoaddToExposureConnections.
    """
class ConvertMultipleCellCoaddToExposureTask(PipelineTask):
    """Afterburner task that stitches a cell-based coadd into an exposure.

    The input is a coadd in `MultipleCellCoadd` format and the output is in
    `ExposureF` format. The conversion is lossy: only the pixels inside the
    inner bounding box of each cell are kept, and the values in the buffer
    region around each cell are discarded.

    Notes
    -----
    This task has no configurable parameters.
    """

    ConfigClass = ConvertMultipleCellCoaddToExposureConfig
    _DefaultName = "convertMultipleCellCoaddToExposure"

    def run(self, cellCoaddExposure):
        """Stitch the cells of a coadd together and return an exposure.

        Parameters
        ----------
        cellCoaddExposure : `MultipleCellCoadd`
            The cell-based coadd to convert.

        Returns
        -------
        result : `Struct`
            Result struct with a single attribute:

            ``stitchedCoaddExposure``
                The object returned by calling ``asExposure()`` on the
                stitched coadd.
        """
        # Two explicit steps: join the cells, then convert the joined coadd.
        stitched_coadd = cellCoaddExposure.stitch()
        exposure = stitched_coadd.asExposure()
        return Struct(stitchedCoaddExposure=exposure)