Coverage for python/lsst/drp/tasks/metadetection_shear.py: 19%
265 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 01:20 -0700
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 01:20 -0700
1# This file is part of drp_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = (
25 "MetadetectionProcessingError",
26 "MetadetectionShearConfig",
27 "MetadetectionShearTask",
28)
30from collections.abc import Collection, Mapping, Sequence
31from itertools import product
32from typing import Any, ClassVar
34import esutil as eu
35import numpy as np
36import pyarrow as pa
37from metadetect.lsst.masking import apply_apodized_bright_masks_mbexp, apply_apodized_edge_masks_mbexp
38from metadetect.lsst.metacal_exposures import STEP as SHEAR_STEP
39from metadetect.lsst.metadetect import MetadetectTask
40from metadetect.lsst.util import extract_multiband_coadd_data
42import lsst.pipe.base.connectionTypes as cT
43from lsst.afw.image import ExposureF
44from lsst.afw.table import SimpleCatalog
45from lsst.cell_coadds import MultipleCellCoadd, StitchedCoadd
46from lsst.daf.butler import DataCoordinate, DatasetRef
47from lsst.meas.algorithms import LoadReferenceObjectsConfig, ReferenceObjectLoader
48from lsst.meas.base import FullIdGenerator, SkyMapIdGeneratorConfig
49from lsst.pex.config import ConfigField, ConfigurableField, Field, FieldValidationError, ListField
50from lsst.pipe.base import (
51 AlgorithmError,
52 AnnotatedPartialOutputsError,
53 InputQuantizedConnection,
54 InvalidQuantumError,
55 NoWorkFound,
56 OutputQuantizedConnection,
57 PipelineTask,
58 PipelineTaskConfig,
59 PipelineTaskConnections,
60 QuantumContext,
61 Struct,
62)
63from lsst.pipe.base.connectionTypes import BaseInput, Output
64from lsst.skymap import BaseSkyMap, Index2D
class MetadetectionProcessingError(AlgorithmError):
    """Error signalling that metadetection processing failed.

    `MetadetectionShearTask.run` raises this when no cell of the patch
    produced any detected objects.
    """

    @property
    def metadata(self) -> dict:
        """Additional diagnostic metadata for this error (always empty)."""
        no_metadata: dict = dict()
        return no_metadata
class MetadetectionShearConnections(PipelineTaskConnections, dimensions={"patch"}):
    """Definitions of inputs and outputs for MetadetectionShearTask."""

    input_coadds = cT.Input(
        "deep_coadd_cell_predetection",
        storageClass="MultipleCellCoadd",
        doc="Per-band deep coadds.",
        multiple=True,
        dimensions={"patch", "band"},
    )

    sky_map = cT.Input(
        doc="Cell-based skymap defining the patch structure.",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    ref_cat = cT.PrerequisiteInput(
        doc="Reference catalog used to mask bright objects.",
        name="the_monster_20250219",
        storageClass="SimpleCatalog",
        dimensions=("skypix",),
        deferLoad=True,
        multiple=True,
    )

    metadetect_catalog = cT.Output(
        "object_shear_patch",
        storageClass="ArrowTable",
        doc="Output catalog with all quantities measured inside the metacalibration loop.",
        multiple=False,
        dimensions={"patch"},
    )

    metadetect_schema = cT.InitOutput(
        "object_shear_schema",
        storageClass="ArrowSchema",
        doc="Schema of the output catalog.",
    )

    config: MetadetectionShearConfig

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if not config:
            return

        # The reference catalog is only read when bright-object masking is
        # enabled, so drop the prerequisite otherwise.
        if not config.do_mask_bright_objects:
            del self.ref_cat

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Docstring inherited.
        # This is a hook for customizing what is input and output to each
        # invocation of the task as early as possible, which we override here
        # to make sure we have exactly the required bands, no more, no less:
        # coadds in bands outside ``photometry_bands`` are dropped, and the
        # quantum is skipped (NoWorkFound) if any shear band is absent.
        # Bands in ``photometry_bands`` but not in the shear bands may be
        # missing without skipping the quantum.
        connection, original_input_coadds = inputs["input_coadds"]
        wanted_bands = set(self.config.photometry_bands)
        adjusted_input_coadds = [
            ref for ref in original_input_coadds if ref.dataId["band"] in wanted_bands
        ]
        bands_present = {ref.dataId["band"] for ref in adjusted_input_coadds}
        bands_missing = wanted_bands - bands_present
        if missing_shear_bands := bands_missing.intersection(self.config.metadetect.shear_bands):
            raise NoWorkFound(f"Required bands {missing_shear_bands} not present for {label}@{data_id}.")
        adjusted_inputs = {"input_coadds": (connection, adjusted_input_coadds)}
        # Let the base class validate the full, updated set of inputs, but
        # report back only the connection we actually changed.
        inputs.update(adjusted_inputs)
        super().adjustQuantum(inputs, outputs, label, data_id)
        return adjusted_inputs, {}
class MetadetectionShearConfig(PipelineTaskConfig, pipelineConnections=MetadetectionShearConnections):
    """Configuration definition for MetadetectionShearTask.

    `validate` requires ``metadetect.shear_bands`` (when set) to be a subset
    of ``photometry_bands``.
    """

    metadetect = ConfigurableField(
        doc="Configuration for metadetection.",
        target=MetadetectTask,
    )

    photometry_bands = ListField[str](
        doc=(
            "Bands expected to be present. Cells with one or more of these bands "
            "missing will be skipped. Bands other than those listed here will "
            "not be processed."
        ),
        default=["g", "r", "i", "z"],
    )

    do_mask_bright_objects = Field[bool](
        doc="Mask bright objects in coadds?",
        default=False,
    )

    ref_loader = ConfigField(
        doc="Reference object loader used for bright-object masking.",
        dtype=LoadReferenceObjectsConfig,
    )

    ref_loader_filter_name = Field[str](
        doc="Filter name from ref_loader used for bright-object masking.",
        default="monster_DES_r",
    )

    border = Field[int](
        doc="Border to apply to single cell images, if skymap has no cell borders",
        default=50,
    )

    id_generator = SkyMapIdGeneratorConfig.make_field()

    def setDefaults(self):
        super().setDefaults()
        # Shear is measured from r, i, z; all five metacal counterfactual
        # images (including the unsheared one) are produced.
        metadetect_config = self.metadetect
        metadetect_config.shear_bands = ["r", "i", "z"]
        metadetect_config.metacal.types = ["noshear", "1p", "1m", "2p", "2m"]

    def validate(self):
        super().validate()
        shear_bands = self.metadetect.shear_bands
        # An unset shear-band list imposes no constraint.
        if shear_bands is None:
            return
        if set(shear_bands) <= set(self.photometry_bands):
            return
        raise FieldValidationError(
            self.__class__.metadetect,
            self,
            "photometry_bands must be a list of bands that is a superset of metadetect.shear_bands",
        )
210class MetadetectionShearTask(PipelineTask):
211 """A PipelineTask that measures shear using metadetection."""
213 _DefaultName: ClassVar[str] = "metadetectionShear"
214 ConfigClass: ClassVar[type[MetadetectionShearConfig]] = MetadetectionShearConfig
216 config: MetadetectionShearConfig
218 def __init__(self, *, initInputs: dict[str, Any] | None = None, **kwargs: Any):
219 super().__init__(initInputs=initInputs, **kwargs)
220 self.metadetect_schema = self.make_metadetect_schema(self.config)
221 self.makeSubtask("metadetect")
223 @classmethod
224 def make_metadetect_schema(cls, config: MetadetectionShearConfig) -> pa.Schema:
225 """Construct a PyArrow Schema for this task's main output catalog.
227 Parameters
228 ----------
229 config : `MetadetectionShearConfig`
230 Configuration that may be used to control details of the schema.
232 Returns
233 -------
234 object_schema : `pyarrow.Schema`
235 Schema for the object catalog produced by this task. Each field's
236 metadata should include both a 'doc' entry and a 'unit' entry.
237 """
238 pa_schema = pa.schema(
239 [
240 # Fields from pipeline bookkeeping.
241 pa.field(
242 "shearObjectId",
243 pa.int64(),
244 nullable=False,
245 metadata={
246 "doc": (
247 "Unique identifier for a ShearObject, specific "
248 "to a single metacalibration counterfactual image."
249 ),
250 "unit": "",
251 },
252 ),
253 pa.field(
254 "tract",
255 pa.int64(),
256 nullable=False,
257 metadata={
258 "doc": "ID of the tract on which this measurement was made.",
259 "unit": "",
260 },
261 ),
262 pa.field(
263 "patch",
264 pa.int64(),
265 nullable=False,
266 metadata={
267 "doc": "ID of the patch within the tract on which this measurement was made.",
268 "unit": "",
269 },
270 ),
271 pa.field(
272 "cell_x",
273 pa.int32(),
274 nullable=False,
275 metadata={
276 "doc": "Column of the cell within the patch on which this measurement was made.",
277 "unit": "",
278 },
279 ),
280 pa.field(
281 "cell_y",
282 pa.int32(),
283 nullable=False,
284 metadata={
285 "doc": "Row of the cell within the patch on which this measurement was made.",
286 "unit": "",
287 },
288 ),
289 # Fields from metadetection (generic).
290 pa.field(
291 "metaStep",
292 pa.string(),
293 nullable=False,
294 metadata={
295 "doc": (
296 "Type of artificial shear applied to image. "
297 "One of: 'ns', '1p', '1m', '2p', '2m'."
298 ),
299 "unit": "",
300 },
301 ),
302 pa.field(
303 "image_flags",
304 pa.int32(),
305 nullable=False,
306 metadata={
307 "doc": "Flags for the image on which this measurement was made.",
308 "unit": "",
309 },
310 ),
311 pa.field(
312 "x",
313 pa.float32(),
314 nullable=False,
315 metadata={
316 "doc": "Centroid (tract, x-axis) of the detected ShearObject.",
317 "unit": "",
318 },
319 ),
320 pa.field(
321 "y",
322 pa.float32(),
323 nullable=False,
324 metadata={
325 "doc": "Centroid (tract, y-axis) of the detected ShearObject.",
326 "unit": "",
327 },
328 ),
329 pa.field(
330 "ra",
331 pa.float64(),
332 nullable=False,
333 metadata={
334 "doc": "Detected Right Ascension of the ShearObject.",
335 "unit": "degrees",
336 },
337 ),
338 pa.field(
339 "dec",
340 pa.float64(),
341 nullable=False,
342 metadata={
343 "doc": "Detected Declination of the ShearObject.",
344 "unit": "degrees",
345 },
346 ),
347 # Original PSF measurements
348 pa.field(
349 "psfOriginal_flags",
350 pa.int32(),
351 nullable=False,
352 metadata={
353 "doc": "Flags for the original PSF measurement.",
354 "unit": "",
355 },
356 ),
357 pa.field(
358 "psfOriginal_e1",
359 pa.float32(),
360 nullable=False,
361 metadata={
362 "doc": "Distortion-style e1 of the original PSF from adaptive moments.",
363 "unit": "",
364 },
365 ),
366 pa.field(
367 "psfOriginal_e2",
368 pa.float32(),
369 nullable=False,
370 metadata={
371 "doc": "Distortion-style e2 of the original PSF from adaptive moments.",
372 "unit": "",
373 },
374 ),
375 pa.field(
376 "psfOriginal_T",
377 pa.float32(),
378 nullable=False,
379 metadata={
380 "doc": "Trace (<x^2> + <y^2>) measurement of the original PSF from adaptive moments.",
381 "unit": "arcseconds squared",
382 },
383 ),
384 pa.field(
385 "bmask_flags",
386 pa.int32(),
387 nullable=False,
388 metadata={
389 "doc": "`bmask` flags for the ShearObject",
390 "unit": "",
391 },
392 ),
393 pa.field(
394 "ormask_flags",
395 pa.int32(),
396 nullable=False,
397 metadata={
398 "doc": "`ored` mask flags for the ShearObject",
399 "unit": "",
400 },
401 ),
402 pa.field(
403 "mfrac",
404 pa.float32(),
405 nullable=False,
406 metadata={
407 "doc": "Gaussian-weighted masked fraction for the ShearObject.",
408 "unit": "",
409 },
410 ),
411 # Fields that come only from gauss algorithm.
412 # Reconvolved PSF measurements (gauss)
413 pa.field(
414 "gauss_psfReconvolved_flags",
415 pa.int32(),
416 nullable=False,
417 metadata={
418 "doc": "Flags for reconvolved PSF (measured with gauss algorithm).",
419 "unit": "",
420 },
421 ),
422 pa.field(
423 "gauss_psfReconvolved_g1",
424 pa.float32(),
425 nullable=False,
426 metadata={
427 "doc": "Reduced-shear g1 of the reconvolved PSF (measured with gauss algorithm).",
428 "unit": "",
429 },
430 ),
431 pa.field(
432 "gauss_psfReconvolved_g2",
433 pa.float32(),
434 nullable=False,
435 metadata={
436 "doc": "Reduced-shear g2 of the reconvolved PSF (measured with gauss algorithm).",
437 "unit": "",
438 },
439 ),
440 pa.field(
441 "gauss_psfReconvolved_T",
442 pa.float32(),
443 nullable=False,
444 metadata={
445 "doc": (
446 "Trace (<x^2> + <y^2>) of the reconvolved PSF (measured with gauss algorithm)."
447 ),
448 "unit": "arcseconds squared",
449 },
450 ),
451 # Object measurements (gauss algorithm).
452 pa.field(
453 "gauss_g1",
454 pa.float32(),
455 nullable=False,
456 metadata={
457 "doc": (
458 "Reduced-shear g1 measurement of the ShearObject "
459 "(measured with gauss algorithm)."
460 ),
461 "unit": "",
462 },
463 ),
464 pa.field(
465 "gauss_g2",
466 pa.float32(),
467 nullable=False,
468 metadata={
469 "doc": (
470 "Reduced-shear g2 measurement of the ShearObject "
471 "(measured with gauss algorithm)."
472 ),
473 "unit": "",
474 },
475 ),
476 pa.field(
477 "gauss_g1_g1_Cov",
478 pa.float32(),
479 nullable=False,
480 metadata={
481 "doc": (
482 "Auto-covariance of g1 measurement of the ShearObject "
483 "(measured with gauss algorithm)."
484 ),
485 "unit": "",
486 },
487 ),
488 pa.field(
489 "gauss_g1_g2_Cov",
490 pa.float32(),
491 nullable=False,
492 metadata={
493 "doc": (
494 "Cross-covariance of g1 and g2 measurement of the ShearObject "
495 "(measured with gauss algorithm)."
496 ),
497 "unit": "",
498 },
499 ),
500 pa.field(
501 "gauss_g2_g2_Cov",
502 pa.float32(),
503 nullable=False,
504 metadata={
505 "doc": (
506 "Auto-covariance of g2 measurement of the ShearObject "
507 "(measured with gauss algorithm)."
508 ),
509 "unit": "",
510 },
511 ),
512 ],
513 metadata={
514 "shear_step": str(SHEAR_STEP),
515 "shear_bands": "".join(sorted(config.metadetect.shear_bands)),
516 },
517 )
519 for alg_name in ("gauss", "pgauss"):
520 pa_schema = pa_schema.append(
521 pa.field(
522 f"{alg_name}_snr",
523 pa.float32(),
524 nullable=False,
525 metadata={
526 "doc": (
527 "Signal-to-noise ratio measure of the ShearObject "
528 f"(measured with {alg_name} algorithm)."
529 ),
530 "unit": "",
531 },
532 ),
533 )
534 pa_schema = pa_schema.append(
535 pa.field(
536 f"{alg_name}_T",
537 pa.float32(),
538 nullable=False,
539 metadata={
540 "doc": (
541 "Trace (<x^2> + <y^2>) measurement of the ShearObject "
542 f"(measured with {alg_name} algorithm)."
543 ),
544 "unit": "arcseconds squared",
545 },
546 ),
547 )
548 pa_schema = pa_schema.append(
549 pa.field(
550 f"{alg_name}_TErr",
551 pa.float32(),
552 nullable=False,
553 metadata={
554 "doc": (
555 "Uncertainty in the trace measurement of the ShearObject "
556 f"(measured with {alg_name} algorithm)."
557 ),
558 "unit": "arcseconds squared",
559 },
560 ),
561 )
562 pa_schema = pa_schema.append(
563 pa.field(
564 f"{alg_name}_shape_flags",
565 pa.int32(),
566 nullable=False,
567 metadata={
568 "doc": (
569 "Flags for the second order moments measurement of the ShearObject "
570 f"(measured with {alg_name} algorithm)."
571 ),
572 "unit": "",
573 },
574 ),
575 )
576 pa_schema = pa_schema.append(
577 pa.field(
578 f"{alg_name}_object_flags",
579 pa.int32(),
580 nullable=False,
581 metadata={
582 "doc": f"Flags for the ShearObject measurement (measured with {alg_name} algorithm).",
583 "unit": "",
584 },
585 ),
586 )
587 pa_schema = pa_schema.append(
588 pa.field(
589 f"{alg_name}_flags",
590 pa.int32(),
591 nullable=False,
592 metadata={
593 "doc": f"Overall flags for {alg_name} measurement algorithm.",
594 "unit": "",
595 },
596 ),
597 )
599 # Per-band quantities, typically fluxes and associated quantites.
600 for b in config.photometry_bands:
601 pa_schema = pa_schema.append(
602 pa.field(
603 f"{b}_{alg_name}Flux_flags",
604 pa.int32(),
605 nullable=False,
606 metadata={
607 "doc": f"Flags set for flux in {b} band measured with {alg_name} algorithm.",
608 "unit": "",
609 },
610 ),
611 )
612 pa_schema = pa_schema.append(
613 pa.field(
614 f"{b}_{alg_name}Flux",
615 pa.float32(),
616 nullable=b not in config.metadetect.shear_bands,
617 metadata={
618 "doc": f"Flux in {b} band (measured with {alg_name} algorithm).",
619 "unit": "",
620 },
621 ),
622 )
623 pa_schema = pa_schema.append(
624 pa.field(
625 f"{b}_{alg_name}FluxErr",
626 pa.float32(),
627 nullable=b not in config.metadetect.shear_bands,
628 metadata={
629 "doc": f"Flux uncertainty in {b} band (measured with {alg_name} algorithm).",
630 "unit": "",
631 },
632 ),
633 )
635 return pa_schema
637 def validate_skymap_config(self, skymap_config: BaseSkyMap.ConfigClass) -> None:
638 if not skymap_config.tractBuilder.name == "cells":
639 raise InvalidQuantumError("MetadetectionShearTask requires a cell-based skymap.")
641 cell_config = skymap_config.tractBuilder.active
642 if (self.config.border == 0 and cell_config.cellBorder == 0) or (
643 self.config.border > 0 and cell_config.cellBorder > 0
644 ):
645 raise InvalidQuantumError(
646 "MetadetectionShearTask requires a positive border to be set either in the skymap config "
647 "or in the task config (but not in both)."
648 )
650 if self.config.border:
651 # In case cellInnerDimensions are different in different directions
652 # (which is rare in practice), take the min of it
653 cell_inner_dimensions = min(cell_config.cellInnerDimensions)
654 if self.config.border > (
655 max_border_value := cell_config.numCellsInPatchBorder * cell_inner_dimensions
656 ):
657 raise InvalidQuantumError(
658 "The border value is too large for the skymap configuration. "
659 f"Maximum border value is {max_border_value}."
660 )
662 # Ensure that the amount by which we withdraw inwards does not
663 # create gaps at tract boundaries.
664 # tractOverlap is specified in degrees.
665 if (
666 skymap_config.tractOverlap * 3600 / skymap_config.pixelScale
667 < self.count_cells_along_edges(skymap_config) * cell_inner_dimensions + cell_config.cellBorder
668 ):
669 raise InvalidQuantumError(
670 "The tract overlap is insufficient given the borders. "
671 "This will result in missed regions between adjacent tracts."
672 )
674 @staticmethod
675 def count_cells_along_edges(skymap_config: BaseSkyMap.ConfigClass) -> int:
676 """Count the number of cells along the edges to skip processing."""
677 if skymap_config.tractBuilder["cells"].cellBorder > 0:
678 return 0
680 return skymap_config.tractBuilder["cells"].numCellsInPatchBorder
682 def runQuantum(
683 self,
684 qc: QuantumContext,
685 inputRefs: InputQuantizedConnection,
686 outputRefs: OutputQuantizedConnection,
687 ) -> None:
688 # Docstring inherited.
690 # Get the skyMap for this quantum
691 sky_map = qc.get(inputRefs.sky_map)
693 self.validate_skymap_config(sky_map.config)
695 id_generator = self.config.id_generator.apply(qc.quantum.dataId)
697 if self.config.do_mask_bright_objects:
698 ref_loader = ReferenceObjectLoader(
699 dataIds=[ref.datasetRef.dataId for ref in inputRefs.ref_cat],
700 refCats=[qc.get(ref) for ref in inputRefs.ref_cat],
701 name=self.config.connections.ref_cat,
702 config=self.config.ref_loader,
703 log=self.log,
704 )
705 ref_cat = ref_loader.loadRegion(
706 qc.quantum.dataId.region, filterName=self.config.ref_loader_filter_name
707 )
708 else:
709 ref_cat = None
711 # Read the coadds and put them in the order defined by
712 # config.photometry_bands (note that each MultipleCellCoadd object also
713 # knows its own band, if that's needed).
715 coadds_by_band = {
716 ref.dataId["band"]: qc.get(ref)
717 for ref in inputRefs.input_coadds
718 if ref.dataId["band"] in self.config.photometry_bands
719 }
721 try:
722 outputs = self.run(
723 patch_coadds=coadds_by_band,
724 id_generator=id_generator,
725 sky_map=sky_map,
726 ref_cat=ref_cat,
727 )
728 except AlgorithmError as err:
729 # We know there are no actual outputs in this case, but this is
730 # still the right exception to raise (it's just badly named).
731 raise AnnotatedPartialOutputsError.annotate(err, self, log=self.log) from err
732 qc.put(outputs, outputRefs)
734 def run(
735 self,
736 *,
737 patch_coadds: Mapping[str, MultipleCellCoadd],
738 id_generator: FullIdGenerator,
739 sky_map: BaseSkyMap,
740 ref_cat: SimpleCatalog | None,
741 ) -> Struct:
742 """Run metadetection on a patch.
744 Parameters
745 ----------
746 patch_coadds : `~collections.abc.Mapping` [ \
747 `~lsst.cell_coadds.MultipleCellCoadd` ]
748 Per-band, per-patch coadds, in the order specified by
749 `MetadetectionShearConfig.photometry_bands`.
750 id_generator : `~lsst.meas.base.FullIdGenerator`
751 Generator for object IDs and to seed the random number generator.
752 sky_map : `~lsst.skymap.BaseSkyMap`
753 Sky map to use for determining the patch boundaries.
754 ref_cat : `lsst.afw.table.SimpleCatalog`, optional
755 Reference catalog to use when masking bright stars.
757 Returns
758 -------
759 results : `lsst.pipe.base.Struct`
760 Structure with the following attributes:
762 - ``metadetect_catalog`` [ `pyarrow.Table` ]: the output object
763 catalog for the patch, with schema equal to `metadetect_schema`.
764 """
765 seed = id_generator.catalog_id
766 self.rng = np.random.RandomState(seed)
767 idstart = 0
769 match sky_map.config.tractBuilder.active.cellBorder:
770 case 0:
771 # If cells have no borders, we cannot apply the metacal
772 # procedure to the edge cells in the same way as inner cells.
773 # We can do so for all that cells that are marked as borders,
774 # but it has to be at least 1.
775 if (num_edge_cells_skip := sky_map.config.tractBuilder.active.numCellsInPatchBorder) < 1:
776 raise InvalidQuantumError("No border cells found in the skymap configuration.")
777 case _:
778 num_edge_cells_skip = 0
780 dilate_by = self.config.border or 0
782 grid = patch_coadds[self.config.metadetect.shear_bands[0]].grid
783 nx_cells, ny_cells = grid.shape
784 single_cell_tables: list[pa.Table] = []
785 for nx, ny in product(
786 range(num_edge_cells_skip, nx_cells - num_edge_cells_skip),
787 range(num_edge_cells_skip, ny_cells - num_edge_cells_skip),
788 ):
789 cell_id = Index2D(nx, ny)
790 bbox = grid.bbox_of(cell_id).dilatedBy(dilate_by)
791 cell_coadds = [patch_coadd.stitch(bbox) for patch_coadd in patch_coadds.values()]
792 self.log.debug("Processing cell %s %s", nx, ny)
794 try:
795 res = self.process_cell(cell_coadds, cell_id=cell_id)
796 except Exception as e:
797 self.log.error("Failed to process cell %s %s: %s", nx, ny, e)
798 continue
800 if len(res) > 0:
801 res["id"] = id_generator.arange(idstart, idstart + len(res))
802 # TODO: Avoid back and forth conversion between array and dict.
803 da = self._dictify(
804 res,
805 tract=id_generator.data_id.tract.id,
806 patch=id_generator.data_id.patch.id,
807 )
808 table = pa.Table.from_pydict(da, self.metadetect_schema)
810 single_cell_tables.append(table)
811 idstart += len(res)
813 if not single_cell_tables:
814 raise MetadetectionProcessingError("No objects found in any cell")
816 # TODO: DM-53796 De-duplicate objects before concatenation.
817 return Struct(
818 metadetect_catalog=pa.concat_tables(single_cell_tables),
819 )
821 def process_cell(
822 self,
823 cell_coadds: Sequence[StitchedCoadd],
824 cell_id: Index2D,
825 ) -> pa.Table:
826 """Run metadetection on a single cell.
828 Parameters
829 ----------
830 cell_coadds : `~collections.abc.Sequence` [ \
831 `~lsst.cell_coadds.StitchedCoadd` ]
832 Per-band, per-cell coadds, in the order specified by
833 `MetadetectionShearConfig.photometry_bands`.
834 cell_id : `~lsst.skymap.Index2D`
835 The cell ID for the cell being processed.
837 Returns
838 -------
839 metadetect_catalog : `pyarrow.Table`
840 Output object catalog for the cell, with schema equal to
841 `metadetect_schema`.
842 """
844 coadd_data = self._cell_to_coadd_data(cell_coadds)
845 # TODO get bright star etc. info as input
846 bright_info = []
848 apply_apodized_edge_masks_mbexp(**coadd_data)
850 if len(bright_info) > 0:
851 apply_apodized_bright_masks_mbexp(bright_info=bright_info, **coadd_data)
853 mask_frac = _get_mask_frac(
854 coadd_data["mfrac_mbexp"],
855 trim_pixels=0,
856 )
858 res = self.metadetect.run(rng=self.rng, **coadd_data)
860 comb_res = _make_comb_data(
861 cell_coadd=cell_coadds[0],
862 res=res,
863 mask_frac=mask_frac,
864 bands=[cell_coadd.band for cell_coadd in cell_coadds],
865 cell_id=cell_id,
866 )
868 return comb_res
870 @staticmethod
871 def _cell_to_coadd_data(cell_coadds: Sequence[StitchedCoadd]):
872 coadd_data_list = []
873 for cell_coadd in cell_coadds:
874 coadd_data = {}
875 coadd_data["coadd_exp"] = cell_coadd.asExposure()
876 coadd_data["coadd_noise_exp"] = cell_coadd.asExposure(noise_index=0)
877 coadd_data["coadd_mfrac_exp"] = ExposureF(coadd_data["coadd_exp"], deep=True)
878 coadd_data["coadd_mfrac_exp"].image = cell_coadd.mask_fractions
879 coadd_data_list.append(coadd_data)
881 return extract_multiband_coadd_data(coadd_data_list)
883 def _dictify(self, data, tract: int, patch: int):
884 output = {}
885 # TODO: Move this to a better location after DP2.
886 mapping = {
887 "bmask": "bmask_flags",
888 "cell_x": "cell_x",
889 "cell_y": "cell_y",
890 "col": "x",
891 "col_diff": "x_offset", # dropped.
892 "dec": "dec",
893 "gauss_flags": "gauss_flags",
894 "gauss_g_1": "gauss_g1",
895 "gauss_g_2": "gauss_g2",
896 "gauss_g_cov_11": "gauss_g1_g1_Cov",
897 "gauss_g_cov_12": "gauss_g1_g2_Cov", # same as 21.
898 "gauss_g_cov_22": "gauss_g2_g2_Cov",
899 "gauss_obj_flags": "gauss_object_flags",
900 "gauss_psf_flags": "gauss_psfReconvolved_flags",
901 "gauss_psf_g_1": "gauss_psfReconvolved_g1",
902 "gauss_psf_g_2": "gauss_psfReconvolved_g2",
903 "gauss_psf_T": "gauss_psfReconvolved_T",
904 "gauss_s2n": "gauss_snr",
905 "gauss_T": "gauss_T",
906 "gauss_T_err": "gauss_TErr",
907 "gauss_T_flags": "gauss_shape_flags",
908 "gauss_T_ratio": "gauss_T_ratio", # dropped.
909 "id": "shearObjectId",
910 "mfrac": "mfrac",
911 "ormask": "ormask_flags",
912 "pgauss_flags": "pgauss_flags",
913 "pgauss_obj_flags": "pgauss_object_flags",
914 "pgauss_s2n": "pgauss_snr",
915 "pgauss_T": "pgauss_T",
916 "pgauss_T_err": "pgauss_TErr",
917 "pgauss_T_flags": "pgauss_shape_flags",
918 "pgauss_T_ratio": "pgauss_T_ratio", # dropped.
919 "psfrec_flags": "psfOriginal_flags",
920 "psfrec_g_1": "psfOriginal_e1",
921 "psfrec_g_2": "psfOriginal_e2",
922 "psfrec_T": "psfOriginal_T",
923 "ra": "ra",
924 "row": "y",
925 "row_diff": "y_offset", # dropped.
926 "shear_type": "metaStep",
927 "stamp_flags": "image_flags",
928 }
930 for b, alg_name in product(self.config.photometry_bands, ("gauss", "pgauss")):
931 mapping[f"{alg_name}_band_flux_{b}"] = f"{b}_{alg_name}Flux"
932 mapping[f"{alg_name}_band_flux_err_{b}"] = f"{b}_{alg_name}FluxErr"
933 mapping[f"{alg_name}_band_flux_flags_{b}"] = f"{b}_{alg_name}Flux_flags"
935 for name in mapping:
936 if name in data.dtype.names:
937 output[mapping.get(name, name)] = data[name]
938 else:
939 if "flags" in name.lower():
940 output[mapping.get(name, name)] = np.ones_like(data["id"], dtype=np.int32)
941 else:
942 output[mapping.get(name, name)] = np.ones_like(data["id"], dtype=np.float32)
943 output[mapping.get(name, name)] *= np.nan
945 output["tract"] = tract * np.ones_like(data["id"], dtype=np.int64)
946 output["patch"] = patch * np.ones_like(data["id"], dtype=np.int32)
948 return output
951def _make_comb_data(
952 cell_coadd,
953 res,
954 mask_frac,
955 bands,
956 cell_id,
957):
958 idinfo = cell_coadd.identifiers
960 copy_dt = [
961 # we will copy out of arrays to these
962 ("psfrec_g_1", "f4"),
963 ("psfrec_g_2", "f4"),
964 ("gauss_psf_g_1", "f4"),
965 ("gauss_psf_g_2", "f4"),
966 ("gauss_g_1", "f4"),
967 ("gauss_g_2", "f4"),
968 ("gauss_g_cov_11", "f4"),
969 ("gauss_g_cov_12", "f4"),
970 ("gauss_g_cov_21", "f4"),
971 ("gauss_g_cov_22", "f4"),
972 ]
974 for b in bands:
975 copy_dt.append(("gauss_band_flux_flags_%s" % b, "i4"))
976 copy_dt.append(("gauss_band_flux_%s" % b, "f4"))
977 copy_dt.append(("gauss_band_flux_err_%s" % b, "f4"))
978 copy_dt.append(("pgauss_band_flux_flags_%s" % b, "i4"))
979 copy_dt.append(("pgauss_band_flux_%s" % b, "f4"))
980 copy_dt.append(("pgauss_band_flux_err_%s" % b, "f4"))
982 add_dt = [
983 ("id", "u8"),
984 ("tract", "u4"),
985 ("patch_x", "u1"),
986 ("patch_y", "u1"),
987 ("cell_x", "u1"),
988 ("cell_y", "u1"),
989 ("shear_type", "U2"),
990 ("mask_frac", "f4"),
991 ("primary", bool),
992 ] + copy_dt
994 if not hasattr(res, "keys"):
995 res = {"noshear": res}
997 dlist = []
998 for stype in res.keys():
999 data = res[stype]
1000 if data is not None:
1001 newdata = eu.numpy_util.add_fields(data, add_dt)
1002 newdata["psfrec_g_1"] = newdata["psfrec_g"][:, 0]
1003 newdata["psfrec_g_2"] = newdata["psfrec_g"][:, 1]
1005 newdata["gauss_psf_g_1"] = newdata["gauss_psf_g"][:, 0]
1006 newdata["gauss_psf_g_2"] = newdata["gauss_psf_g"][:, 1]
1007 newdata["gauss_g_1"] = newdata["gauss_g"][:, 0]
1008 newdata["gauss_g_2"] = newdata["gauss_g"][:, 1]
1010 newdata["gauss_g_cov_11"] = newdata["gauss_g_cov"][:, 0, 0]
1011 newdata["gauss_g_cov_12"] = newdata["gauss_g_cov"][:, 0, 1]
1012 newdata["gauss_g_cov_21"] = newdata["gauss_g_cov"][:, 1, 0]
1013 newdata["gauss_g_cov_22"] = newdata["gauss_g_cov"][:, 1, 1]
1015 # To-do make compatible with a single band better than this.
1016 if len(bands) > 1:
1017 for i, b in enumerate(bands):
1018 newdata["gauss_band_flux_flags_%s" % b] = newdata["gauss_band_flux_flags"][:, i]
1019 newdata["gauss_band_flux_%s" % b] = newdata["gauss_band_flux"][:, i]
1020 newdata["gauss_band_flux_err_%s" % b] = newdata["gauss_band_flux_err"][:, i]
1021 newdata["pgauss_band_flux_flags_%s" % b] = newdata["pgauss_band_flux_flags"][:, i]
1022 newdata["pgauss_band_flux_%s" % b] = newdata["pgauss_band_flux"][:, i]
1023 newdata["pgauss_band_flux_err_%s" % b] = newdata["pgauss_band_flux_err"][:, i]
1024 newdata["gauss_band_flux_flags_%s" % b] = newdata["gauss_band_flux_flags"][:, i]
1025 newdata["gauss_band_flux_%s" % b] = newdata["gauss_band_flux"][:, i]
1026 newdata["gauss_band_flux_err_%s" % b] = newdata["gauss_band_flux_err"][:, i]
1027 else:
1028 b = bands[0]
1029 newdata["gauss_band_flux_flags_%s" % b] = newdata["gauss_band_flux_flags"]
1030 newdata["gauss_band_flux_%s" % b] = newdata["gauss_band_flux"]
1031 newdata["gauss_band_flux_err_%s" % b] = newdata["gauss_band_flux_err"]
1032 newdata["pgauss_band_flux_flags_%s" % b] = newdata["pgauss_band_flux_flags"]
1033 newdata["pgauss_band_flux_%s" % b] = newdata["pgauss_band_flux"]
1034 newdata["pgauss_band_flux_err_%s" % b] = newdata["pgauss_band_flux_err"]
1035 newdata["gauss_band_flux_flags_%s" % b] = newdata["gauss_band_flux_flags"]
1036 newdata["gauss_band_flux_%s" % b] = newdata["gauss_band_flux"]
1037 newdata["gauss_band_flux_err_%s" % b] = newdata["gauss_band_flux_err"]
1039 newdata["tract"] = idinfo.tract
1040 newdata["patch_x"] = idinfo.patch.x
1041 newdata["patch_y"] = idinfo.patch.y
1042 newdata["cell_x"] = cell_id.x
1043 newdata["cell_y"] = cell_id.y
1045 if stype == "noshear":
1046 newdata["shear_type"] = "ns"
1047 else:
1048 newdata["shear_type"] = stype
1050 dlist.append(newdata)
1052 if len(dlist) > 0:
1053 output = eu.numpy_util.combine_arrlist(dlist)
1054 else:
1055 output = []
1057 return output
1060def _get_mask_frac(mfrac_mbexp, trim_pixels=0):
1061 """
1062 get the average mask frac for each band and then return the max of those
1063 """
1065 mask_fracs = []
1066 for mfrac_exp in mfrac_mbexp:
1067 mfrac = mfrac_exp.image.array
1068 dim = mfrac.shape[0]
1069 mfrac = mfrac[
1070 trim_pixels : dim - trim_pixels - 1,
1071 trim_pixels : dim - trim_pixels - 1,
1072 ]
1073 mask_fracs.append(mfrac.mean())
1075 return max(mask_fracs)