Coverage for python / lsst / drp / tasks / metadetection_shear.py: 19%

265 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-22 17:11 +0000

1# This file is part of drp_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Names exported by ``from lsst.drp.tasks.metadetection_shear import *``.
__all__ = ("MetadetectionProcessingError", "MetadetectionShearConfig", "MetadetectionShearTask")

29 

30from collections.abc import Collection, Mapping, Sequence 

31from itertools import product 

32from typing import Any, ClassVar 

33 

34import esutil as eu 

35import numpy as np 

36import pyarrow as pa 

37from metadetect.lsst.masking import apply_apodized_bright_masks_mbexp, apply_apodized_edge_masks_mbexp 

38from metadetect.lsst.metacal_exposures import STEP as SHEAR_STEP 

39from metadetect.lsst.metadetect import MetadetectTask 

40from metadetect.lsst.util import extract_multiband_coadd_data 

41 

42import lsst.pipe.base.connectionTypes as cT 

43from lsst.afw.image import ExposureF 

44from lsst.afw.table import SimpleCatalog 

45from lsst.cell_coadds import MultipleCellCoadd, StitchedCoadd 

46from lsst.daf.butler import DataCoordinate, DatasetRef 

47from lsst.meas.algorithms import LoadReferenceObjectsConfig, ReferenceObjectLoader 

48from lsst.meas.base import FullIdGenerator, SkyMapIdGeneratorConfig 

49from lsst.pex.config import ConfigField, ConfigurableField, Field, FieldValidationError, ListField 

50from lsst.pipe.base import ( 

51 AlgorithmError, 

52 AnnotatedPartialOutputsError, 

53 InputQuantizedConnection, 

54 InvalidQuantumError, 

55 NoWorkFound, 

56 OutputQuantizedConnection, 

57 PipelineTask, 

58 PipelineTaskConfig, 

59 PipelineTaskConnections, 

60 QuantumContext, 

61 Struct, 

62) 

63from lsst.pipe.base.connectionTypes import BaseInput, Output 

64from lsst.skymap import BaseSkyMap, Index2D 

65 

66 

class MetadetectionProcessingError(AlgorithmError):
    """Raised when metadetection processing of a patch cannot succeed."""

    @property
    def metadata(self) -> dict:
        # This error carries no extra diagnostic metadata.
        return dict()

73 

74 

class MetadetectionShearConnections(PipelineTaskConnections, dimensions={"patch"}):
    """Definitions of inputs and outputs for MetadetectionShearTask."""

    input_coadds = cT.Input(
        "deep_coadd_cell_predetection",
        storageClass="MultipleCellCoadd",
        doc="Per-band deep coadds.",
        multiple=True,
        dimensions={"patch", "band"},
    )

    sky_map = cT.Input(
        doc="Cell-based skymap defining the patch structure.",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    ref_cat = cT.PrerequisiteInput(
        doc="Reference catalog used to mask bright objects.",
        name="the_monster_20250219",
        storageClass="SimpleCatalog",
        dimensions=("skypix",),
        deferLoad=True,
        multiple=True,
    )

    metadetect_catalog = cT.Output(
        "object_shear_patch",
        storageClass="ArrowTable",
        doc="Output catalog with all quantities measured inside the metacalibration loop.",
        multiple=False,
        dimensions={"patch"},
    )

    metadetect_schema = cT.InitOutput(
        "object_shear_schema",
        storageClass="ArrowSchema",
        doc="Schema of the output catalog.",
    )

    config: MetadetectionShearConfig

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if not config:
            return

        # The reference catalog is only read when bright-object masking is
        # enabled, so drop the connection otherwise.
        if not config.do_mask_bright_objects:
            del self.ref_cat

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Docstring inherited.
        # This is a hook for customizing what is input and output to each
        # invocation of the task as early as possible, which we override here
        # to make sure we have exactly the required bands, no more, no less:
        # coadds in bands outside ``photometry_bands`` are dropped, and the
        # quantum is skipped if any of the shear bands is absent.
        connection, original_input_coadds = inputs["input_coadds"]
        wanted_bands = set(self.config.photometry_bands)
        adjusted_input_coadds = [ref for ref in original_input_coadds if ref.dataId["band"] in wanted_bands]
        bands_missing = wanted_bands - {ref.dataId["band"] for ref in adjusted_input_coadds}
        if missing_shear_bands := bands_missing.intersection(self.config.metadetect.shear_bands):
            raise NoWorkFound(f"Required bands {missing_shear_bands} not present for {label}@{data_id}.")
        adjusted_inputs = {"input_coadds": (connection, adjusted_input_coadds)}
        # The outer dict is a temporary copy owned by us, so it is safe to
        # update before delegating the generic multiplicity checks to super.
        inputs.update(adjusted_inputs)
        super().adjustQuantum(inputs, outputs, label, data_id)
        return adjusted_inputs, {}

154 

155 

class MetadetectionShearConfig(PipelineTaskConfig, pipelineConnections=MetadetectionShearConnections):
    """Configuration options for `MetadetectionShearTask`."""

    metadetect = ConfigurableField(
        doc="Configuration for metadetection.",
        target=MetadetectTask,
    )

    photometry_bands = ListField[str](
        doc=(
            "Bands expected to be present. Cells with one or more of these bands "
            "missing will be skipped. Bands other than those listed here will "
            "not be processed."
        ),
        default=["g", "r", "i", "z"],
    )

    do_mask_bright_objects = Field[bool](
        doc="Mask bright objects in coadds?",
        default=False,
    )

    ref_loader = ConfigField(
        doc="Reference object loader used for bright-object masking.",
        dtype=LoadReferenceObjectsConfig,
    )

    ref_loader_filter_name = Field[str](
        doc="Filter name from ref_loader used for bright-object masking.",
        default="monster_DES_r",
    )

    border = Field[int](
        doc="Border to apply to single cell images, if skymap has no cell borders",
        default=50,
    )

    id_generator = SkyMapIdGeneratorConfig.make_field()

    def setDefaults(self):
        super().setDefaults()
        # Shear is measured from r, i and z; all five metacal images are made.
        self.metadetect.shear_bands = ["r", "i", "z"]
        self.metadetect.metacal.types = ["noshear", "1p", "1m", "2p", "2m"]

    def validate(self):
        super().validate()
        shear_bands = self.metadetect.shear_bands
        # Nothing to cross-check when no shear bands are configured.
        if shear_bands is None:
            return
        if not set(shear_bands) <= set(self.photometry_bands):
            raise FieldValidationError(
                type(self).metadetect,
                self,
                "photometry_bands must be a list of bands that is a superset of metadetect.shear_bands",
            )

208 

209 

210class MetadetectionShearTask(PipelineTask): 

211 """A PipelineTask that measures shear using metadetection.""" 

212 

213 _DefaultName: ClassVar[str] = "metadetectionShear" 

214 ConfigClass: ClassVar[type[MetadetectionShearConfig]] = MetadetectionShearConfig 

215 

216 config: MetadetectionShearConfig 

217 

218 def __init__(self, *, initInputs: dict[str, Any] | None = None, **kwargs: Any): 

219 super().__init__(initInputs=initInputs, **kwargs) 

220 self.metadetect_schema = self.make_metadetect_schema(self.config) 

221 self.makeSubtask("metadetect") 

222 

223 @classmethod 

224 def make_metadetect_schema(cls, config: MetadetectionShearConfig) -> pa.Schema: 

225 """Construct a PyArrow Schema for this task's main output catalog. 

226 

227 Parameters 

228 ---------- 

229 config : `MetadetectionShearConfig` 

230 Configuration that may be used to control details of the schema. 

231 

232 Returns 

233 ------- 

234 object_schema : `pyarrow.Schema` 

235 Schema for the object catalog produced by this task. Each field's 

236 metadata should include both a 'doc' entry and a 'unit' entry. 

237 """ 

238 pa_schema = pa.schema( 

239 [ 

240 # Fields from pipeline bookkeeping. 

241 pa.field( 

242 "shearObjectId", 

243 pa.int64(), 

244 nullable=False, 

245 metadata={ 

246 "doc": ( 

247 "Unique identifier for a ShearObject, specific " 

248 "to a single metacalibration counterfactual image." 

249 ), 

250 "unit": "", 

251 }, 

252 ), 

253 pa.field( 

254 "tract", 

255 pa.int64(), 

256 nullable=False, 

257 metadata={ 

258 "doc": "ID of the tract on which this measurement was made.", 

259 "unit": "", 

260 }, 

261 ), 

262 pa.field( 

263 "patch", 

264 pa.int64(), 

265 nullable=False, 

266 metadata={ 

267 "doc": "ID of the patch within the tract on which this measurement was made.", 

268 "unit": "", 

269 }, 

270 ), 

271 pa.field( 

272 "cell_x", 

273 pa.int32(), 

274 nullable=False, 

275 metadata={ 

276 "doc": "Column of the cell within the patch on which this measurement was made.", 

277 "unit": "", 

278 }, 

279 ), 

280 pa.field( 

281 "cell_y", 

282 pa.int32(), 

283 nullable=False, 

284 metadata={ 

285 "doc": "Row of the cell within the patch on which this measurement was made.", 

286 "unit": "", 

287 }, 

288 ), 

289 # Fields from metadetection (generic). 

290 pa.field( 

291 "metaStep", 

292 pa.string(), 

293 nullable=False, 

294 metadata={ 

295 "doc": ( 

296 "Type of artificial shear applied to image. " 

297 "One of: 'ns', '1p', '1m', '2p', '2m'." 

298 ), 

299 "unit": "", 

300 }, 

301 ), 

302 pa.field( 

303 "image_flags", 

304 pa.int32(), 

305 nullable=False, 

306 metadata={ 

307 "doc": "Flags for the image on which this measurement was made.", 

308 "unit": "", 

309 }, 

310 ), 

311 pa.field( 

312 "x", 

313 pa.float32(), 

314 nullable=False, 

315 metadata={ 

316 "doc": "Centroid (tract, x-axis) of the detected ShearObject.", 

317 "unit": "", 

318 }, 

319 ), 

320 pa.field( 

321 "y", 

322 pa.float32(), 

323 nullable=False, 

324 metadata={ 

325 "doc": "Centroid (tract, y-axis) of the detected ShearObject.", 

326 "unit": "", 

327 }, 

328 ), 

329 pa.field( 

330 "ra", 

331 pa.float64(), 

332 nullable=False, 

333 metadata={ 

334 "doc": "Detected Right Ascension of the ShearObject.", 

335 "unit": "degrees", 

336 }, 

337 ), 

338 pa.field( 

339 "dec", 

340 pa.float64(), 

341 nullable=False, 

342 metadata={ 

343 "doc": "Detected Declination of the ShearObject.", 

344 "unit": "degrees", 

345 }, 

346 ), 

347 # Original PSF measurements 

348 pa.field( 

349 "psfOriginal_flags", 

350 pa.int32(), 

351 nullable=False, 

352 metadata={ 

353 "doc": "Flags for the original PSF measurement.", 

354 "unit": "", 

355 }, 

356 ), 

357 pa.field( 

358 "psfOriginal_e1", 

359 pa.float32(), 

360 nullable=False, 

361 metadata={ 

362 "doc": "Distortion-style e1 of the original PSF from adaptive moments.", 

363 "unit": "", 

364 }, 

365 ), 

366 pa.field( 

367 "psfOriginal_e2", 

368 pa.float32(), 

369 nullable=False, 

370 metadata={ 

371 "doc": "Distortion-style e2 of the original PSF from adaptive moments.", 

372 "unit": "", 

373 }, 

374 ), 

375 pa.field( 

376 "psfOriginal_T", 

377 pa.float32(), 

378 nullable=False, 

379 metadata={ 

380 "doc": "Trace (<x^2> + <y^2>) measurement of the original PSF from adaptive moments.", 

381 "unit": "arcseconds squared", 

382 }, 

383 ), 

384 pa.field( 

385 "bmask_flags", 

386 pa.int32(), 

387 nullable=False, 

388 metadata={ 

389 "doc": "`bmask` flags for the ShearObject", 

390 "unit": "", 

391 }, 

392 ), 

393 pa.field( 

394 "ormask_flags", 

395 pa.int32(), 

396 nullable=False, 

397 metadata={ 

398 "doc": "`ored` mask flags for the ShearObject", 

399 "unit": "", 

400 }, 

401 ), 

402 pa.field( 

403 "mfrac", 

404 pa.float32(), 

405 nullable=False, 

406 metadata={ 

407 "doc": "Gaussian-weighted masked fraction for the ShearObject.", 

408 "unit": "", 

409 }, 

410 ), 

411 # Fields that come only from gauss algorithm. 

412 # Reconvolved PSF measurements (gauss) 

413 pa.field( 

414 "gauss_psfReconvolved_flags", 

415 pa.int32(), 

416 nullable=False, 

417 metadata={ 

418 "doc": "Flags for reconvolved PSF (measured with gauss algorithm).", 

419 "unit": "", 

420 }, 

421 ), 

422 pa.field( 

423 "gauss_psfReconvolved_g1", 

424 pa.float32(), 

425 nullable=False, 

426 metadata={ 

427 "doc": "Reduced-shear g1 of the reconvolved PSF (measured with gauss algorithm).", 

428 "unit": "", 

429 }, 

430 ), 

431 pa.field( 

432 "gauss_psfReconvolved_g2", 

433 pa.float32(), 

434 nullable=False, 

435 metadata={ 

436 "doc": "Reduced-shear g2 of the reconvolved PSF (measured with gauss algorithm).", 

437 "unit": "", 

438 }, 

439 ), 

440 pa.field( 

441 "gauss_psfReconvolved_T", 

442 pa.float32(), 

443 nullable=False, 

444 metadata={ 

445 "doc": ( 

446 "Trace (<x^2> + <y^2>) of the reconvolved PSF (measured with gauss algorithm)." 

447 ), 

448 "unit": "arcseconds squared", 

449 }, 

450 ), 

451 # Object measurements (gauss algorithm). 

452 pa.field( 

453 "gauss_g1", 

454 pa.float32(), 

455 nullable=False, 

456 metadata={ 

457 "doc": ( 

458 "Reduced-shear g1 measurement of the ShearObject " 

459 "(measured with gauss algorithm)." 

460 ), 

461 "unit": "", 

462 }, 

463 ), 

464 pa.field( 

465 "gauss_g2", 

466 pa.float32(), 

467 nullable=False, 

468 metadata={ 

469 "doc": ( 

470 "Reduced-shear g2 measurement of the ShearObject " 

471 "(measured with gauss algorithm)." 

472 ), 

473 "unit": "", 

474 }, 

475 ), 

476 pa.field( 

477 "gauss_g1_g1_Cov", 

478 pa.float32(), 

479 nullable=False, 

480 metadata={ 

481 "doc": ( 

482 "Auto-covariance of g1 measurement of the ShearObject " 

483 "(measured with gauss algorithm)." 

484 ), 

485 "unit": "", 

486 }, 

487 ), 

488 pa.field( 

489 "gauss_g1_g2_Cov", 

490 pa.float32(), 

491 nullable=False, 

492 metadata={ 

493 "doc": ( 

494 "Cross-covariance of g1 and g2 measurement of the ShearObject " 

495 "(measured with gauss algorithm)." 

496 ), 

497 "unit": "", 

498 }, 

499 ), 

500 pa.field( 

501 "gauss_g2_g2_Cov", 

502 pa.float32(), 

503 nullable=False, 

504 metadata={ 

505 "doc": ( 

506 "Auto-covariance of g2 measurement of the ShearObject " 

507 "(measured with gauss algorithm)." 

508 ), 

509 "unit": "", 

510 }, 

511 ), 

512 ], 

513 metadata={ 

514 "shear_step": str(SHEAR_STEP), 

515 "shear_bands": "".join(sorted(config.metadetect.shear_bands)), 

516 }, 

517 ) 

518 

519 for alg_name in ("gauss", "pgauss"): 

520 pa_schema = pa_schema.append( 

521 pa.field( 

522 f"{alg_name}_snr", 

523 pa.float32(), 

524 nullable=False, 

525 metadata={ 

526 "doc": ( 

527 "Signal-to-noise ratio measure of the ShearObject " 

528 f"(measured with {alg_name} algorithm)." 

529 ), 

530 "unit": "", 

531 }, 

532 ), 

533 ) 

534 pa_schema = pa_schema.append( 

535 pa.field( 

536 f"{alg_name}_T", 

537 pa.float32(), 

538 nullable=False, 

539 metadata={ 

540 "doc": ( 

541 "Trace (<x^2> + <y^2>) measurement of the ShearObject " 

542 f"(measured with {alg_name} algorithm)." 

543 ), 

544 "unit": "arcseconds squared", 

545 }, 

546 ), 

547 ) 

548 pa_schema = pa_schema.append( 

549 pa.field( 

550 f"{alg_name}_TErr", 

551 pa.float32(), 

552 nullable=False, 

553 metadata={ 

554 "doc": ( 

555 "Uncertainty in the trace measurement of the ShearObject " 

556 f"(measured with {alg_name} algorithm)." 

557 ), 

558 "unit": "arcseconds squared", 

559 }, 

560 ), 

561 ) 

562 pa_schema = pa_schema.append( 

563 pa.field( 

564 f"{alg_name}_shape_flags", 

565 pa.int32(), 

566 nullable=False, 

567 metadata={ 

568 "doc": ( 

569 "Flags for the second order moments measurement of the ShearObject " 

570 f"(measured with {alg_name} algorithm)." 

571 ), 

572 "unit": "", 

573 }, 

574 ), 

575 ) 

576 pa_schema = pa_schema.append( 

577 pa.field( 

578 f"{alg_name}_object_flags", 

579 pa.int32(), 

580 nullable=False, 

581 metadata={ 

582 "doc": f"Flags for the ShearObject measurement (measured with {alg_name} algorithm).", 

583 "unit": "", 

584 }, 

585 ), 

586 ) 

587 pa_schema = pa_schema.append( 

588 pa.field( 

589 f"{alg_name}_flags", 

590 pa.int32(), 

591 nullable=False, 

592 metadata={ 

593 "doc": f"Overall flags for {alg_name} measurement algorithm.", 

594 "unit": "", 

595 }, 

596 ), 

597 ) 

598 

599 # Per-band quantities, typically fluxes and associated quantites. 

600 for b in config.photometry_bands: 

601 pa_schema = pa_schema.append( 

602 pa.field( 

603 f"{b}_{alg_name}Flux_flags", 

604 pa.int32(), 

605 nullable=False, 

606 metadata={ 

607 "doc": f"Flags set for flux in {b} band measured with {alg_name} algorithm.", 

608 "unit": "", 

609 }, 

610 ), 

611 ) 

612 pa_schema = pa_schema.append( 

613 pa.field( 

614 f"{b}_{alg_name}Flux", 

615 pa.float32(), 

616 nullable=b not in config.metadetect.shear_bands, 

617 metadata={ 

618 "doc": f"Flux in {b} band (measured with {alg_name} algorithm).", 

619 "unit": "", 

620 }, 

621 ), 

622 ) 

623 pa_schema = pa_schema.append( 

624 pa.field( 

625 f"{b}_{alg_name}FluxErr", 

626 pa.float32(), 

627 nullable=b not in config.metadetect.shear_bands, 

628 metadata={ 

629 "doc": f"Flux uncertainty in {b} band (measured with {alg_name} algorithm).", 

630 "unit": "", 

631 }, 

632 ), 

633 ) 

634 

635 return pa_schema 

636 

637 def validate_skymap_config(self, skymap_config: BaseSkyMap.ConfigClass) -> None: 

638 if not skymap_config.tractBuilder.name == "cells": 

639 raise InvalidQuantumError("MetadetectionShearTask requires a cell-based skymap.") 

640 

641 cell_config = skymap_config.tractBuilder.active 

642 if (self.config.border == 0 and cell_config.cellBorder == 0) or ( 

643 self.config.border > 0 and cell_config.cellBorder > 0 

644 ): 

645 raise InvalidQuantumError( 

646 "MetadetectionShearTask requires a positive border to be set either in the skymap config " 

647 "or in the task config (but not in both)." 

648 ) 

649 

650 if self.config.border: 

651 # In case cellInnerDimensions are different in different directions 

652 # (which is rare in practice), take the min of it 

653 cell_inner_dimensions = min(cell_config.cellInnerDimensions) 

654 if self.config.border > ( 

655 max_border_value := cell_config.numCellsInPatchBorder * cell_inner_dimensions 

656 ): 

657 raise InvalidQuantumError( 

658 "The border value is too large for the skymap configuration. " 

659 f"Maximum border value is {max_border_value}." 

660 ) 

661 

662 # Ensure that the amount by which we withdraw inwards does not 

663 # create gaps at tract boundaries. 

664 # tractOverlap is specified in degrees. 

665 if ( 

666 skymap_config.tractOverlap * 3600 / skymap_config.pixelScale 

667 < self.count_cells_along_edges(skymap_config) * cell_inner_dimensions + cell_config.cellBorder 

668 ): 

669 raise InvalidQuantumError( 

670 "The tract overlap is insufficient given the borders. " 

671 "This will result in missed regions between adjacent tracts." 

672 ) 

673 

674 @staticmethod 

675 def count_cells_along_edges(skymap_config: BaseSkyMap.ConfigClass) -> int: 

676 """Count the number of cells along the edges to skip processing.""" 

677 if skymap_config.tractBuilder["cells"].cellBorder > 0: 

678 return 0 

679 

680 return skymap_config.tractBuilder["cells"].numCellsInPatchBorder 

681 

682 def runQuantum( 

683 self, 

684 qc: QuantumContext, 

685 inputRefs: InputQuantizedConnection, 

686 outputRefs: OutputQuantizedConnection, 

687 ) -> None: 

688 # Docstring inherited. 

689 

690 # Get the skyMap for this quantum 

691 sky_map = qc.get(inputRefs.sky_map) 

692 

693 self.validate_skymap_config(sky_map.config) 

694 

695 id_generator = self.config.id_generator.apply(qc.quantum.dataId) 

696 

697 if self.config.do_mask_bright_objects: 

698 ref_loader = ReferenceObjectLoader( 

699 dataIds=[ref.datasetRef.dataId for ref in inputRefs.ref_cat], 

700 refCats=[qc.get(ref) for ref in inputRefs.ref_cat], 

701 name=self.config.connections.ref_cat, 

702 config=self.config.ref_loader, 

703 log=self.log, 

704 ) 

705 ref_cat = ref_loader.loadRegion( 

706 qc.quantum.dataId.region, filterName=self.config.ref_loader_filter_name 

707 ) 

708 else: 

709 ref_cat = None 

710 

711 # Read the coadds and put them in the order defined by 

712 # config.photometry_bands (note that each MultipleCellCoadd object also 

713 # knows its own band, if that's needed). 

714 

715 coadds_by_band = { 

716 ref.dataId["band"]: qc.get(ref) 

717 for ref in inputRefs.input_coadds 

718 if ref.dataId["band"] in self.config.photometry_bands 

719 } 

720 

721 try: 

722 outputs = self.run( 

723 patch_coadds=coadds_by_band, 

724 id_generator=id_generator, 

725 sky_map=sky_map, 

726 ref_cat=ref_cat, 

727 ) 

728 except AlgorithmError as err: 

729 # We know there are no actual outputs in this case, but this is 

730 # still the right exception to raise (it's just badly named). 

731 raise AnnotatedPartialOutputsError.annotate(err, self, log=self.log) from err 

732 qc.put(outputs, outputRefs) 

733 

734 def run( 

735 self, 

736 *, 

737 patch_coadds: Mapping[str, MultipleCellCoadd], 

738 id_generator: FullIdGenerator, 

739 sky_map: BaseSkyMap, 

740 ref_cat: SimpleCatalog | None, 

741 ) -> Struct: 

742 """Run metadetection on a patch. 

743 

744 Parameters 

745 ---------- 

746 patch_coadds : `~collections.abc.Mapping` [ \ 

747 `~lsst.cell_coadds.MultipleCellCoadd` ] 

748 Per-band, per-patch coadds, in the order specified by 

749 `MetadetectionShearConfig.photometry_bands`. 

750 id_generator : `~lsst.meas.base.FullIdGenerator` 

751 Generator for object IDs and to seed the random number generator. 

752 sky_map : `~lsst.skymap.BaseSkyMap` 

753 Sky map to use for determining the patch boundaries. 

754 ref_cat : `lsst.afw.table.SimpleCatalog`, optional 

755 Reference catalog to use when masking bright stars. 

756 

757 Returns 

758 ------- 

759 results : `lsst.pipe.base.Struct` 

760 Structure with the following attributes: 

761 

762 - ``metadetect_catalog`` [ `pyarrow.Table` ]: the output object 

763 catalog for the patch, with schema equal to `metadetect_schema`. 

764 """ 

765 seed = id_generator.catalog_id 

766 self.rng = np.random.RandomState(seed) 

767 idstart = 0 

768 

769 match sky_map.config.tractBuilder.active.cellBorder: 

770 case 0: 

771 # If cells have no borders, we cannot apply the metacal 

772 # procedure to the edge cells in the same way as inner cells. 

773 # We can do so for all that cells that are marked as borders, 

774 # but it has to be at least 1. 

775 if (num_edge_cells_skip := sky_map.config.tractBuilder.active.numCellsInPatchBorder) < 1: 

776 raise InvalidQuantumError("No border cells found in the skymap configuration.") 

777 case _: 

778 num_edge_cells_skip = 0 

779 

780 dilate_by = self.config.border or 0 

781 

782 grid = patch_coadds[self.config.metadetect.shear_bands[0]].grid 

783 nx_cells, ny_cells = grid.shape 

784 single_cell_tables: list[pa.Table] = [] 

785 for nx, ny in product( 

786 range(num_edge_cells_skip, nx_cells - num_edge_cells_skip), 

787 range(num_edge_cells_skip, ny_cells - num_edge_cells_skip), 

788 ): 

789 cell_id = Index2D(nx, ny) 

790 bbox = grid.bbox_of(cell_id).dilatedBy(dilate_by) 

791 cell_coadds = [patch_coadd.stitch(bbox) for patch_coadd in patch_coadds.values()] 

792 self.log.debug("Processing cell %s %s", nx, ny) 

793 

794 try: 

795 res = self.process_cell(cell_coadds, cell_id=cell_id) 

796 except Exception as e: 

797 self.log.error("Failed to process cell %s %s: %s", nx, ny, e) 

798 continue 

799 

800 if len(res) > 0: 

801 res["id"] = id_generator.arange(idstart, idstart + len(res)) 

802 # TODO: Avoid back and forth conversion between array and dict. 

803 da = self._dictify( 

804 res, 

805 tract=id_generator.data_id.tract.id, 

806 patch=id_generator.data_id.patch.id, 

807 ) 

808 table = pa.Table.from_pydict(da, self.metadetect_schema) 

809 

810 single_cell_tables.append(table) 

811 idstart += len(res) 

812 

813 if not single_cell_tables: 

814 raise MetadetectionProcessingError("No objects found in any cell") 

815 

816 # TODO: DM-53796 De-duplicate objects before concatenation. 

817 return Struct( 

818 metadetect_catalog=pa.concat_tables(single_cell_tables), 

819 ) 

820 

821 def process_cell( 

822 self, 

823 cell_coadds: Sequence[StitchedCoadd], 

824 cell_id: Index2D, 

825 ) -> pa.Table: 

826 """Run metadetection on a single cell. 

827 

828 Parameters 

829 ---------- 

830 cell_coadds : `~collections.abc.Sequence` [ \ 

831 `~lsst.cell_coadds.StitchedCoadd` ] 

832 Per-band, per-cell coadds, in the order specified by 

833 `MetadetectionShearConfig.photometry_bands`. 

834 cell_id : `~lsst.skymap.Index2D` 

835 The cell ID for the cell being processed. 

836 

837 Returns 

838 ------- 

839 metadetect_catalog : `pyarrow.Table` 

840 Output object catalog for the cell, with schema equal to 

841 `metadetect_schema`. 

842 """ 

843 

844 coadd_data = self._cell_to_coadd_data(cell_coadds) 

845 # TODO get bright star etc. info as input 

846 bright_info = [] 

847 

848 apply_apodized_edge_masks_mbexp(**coadd_data) 

849 

850 if len(bright_info) > 0: 

851 apply_apodized_bright_masks_mbexp(bright_info=bright_info, **coadd_data) 

852 

853 mask_frac = _get_mask_frac( 

854 coadd_data["mfrac_mbexp"], 

855 trim_pixels=0, 

856 ) 

857 

858 res = self.metadetect.run(rng=self.rng, **coadd_data) 

859 

860 comb_res = _make_comb_data( 

861 cell_coadd=cell_coadds[0], 

862 res=res, 

863 mask_frac=mask_frac, 

864 bands=[cell_coadd.band for cell_coadd in cell_coadds], 

865 cell_id=cell_id, 

866 ) 

867 

868 return comb_res 

869 

870 @staticmethod 

871 def _cell_to_coadd_data(cell_coadds: Sequence[StitchedCoadd]): 

872 coadd_data_list = [] 

873 for cell_coadd in cell_coadds: 

874 coadd_data = {} 

875 coadd_data["coadd_exp"] = cell_coadd.asExposure() 

876 coadd_data["coadd_noise_exp"] = cell_coadd.asExposure(noise_index=0) 

877 coadd_data["coadd_mfrac_exp"] = ExposureF(coadd_data["coadd_exp"], deep=True) 

878 coadd_data["coadd_mfrac_exp"].image = cell_coadd.mask_fractions 

879 coadd_data_list.append(coadd_data) 

880 

881 return extract_multiband_coadd_data(coadd_data_list) 

882 

883 def _dictify(self, data, tract: int, patch: int): 

884 output = {} 

885 # TODO: Move this to a better location after DP2. 

886 mapping = { 

887 "bmask": "bmask_flags", 

888 "cell_x": "cell_x", 

889 "cell_y": "cell_y", 

890 "col": "x", 

891 "col_diff": "x_offset", # dropped. 

892 "dec": "dec", 

893 "gauss_flags": "gauss_flags", 

894 "gauss_g_1": "gauss_g1", 

895 "gauss_g_2": "gauss_g2", 

896 "gauss_g_cov_11": "gauss_g1_g1_Cov", 

897 "gauss_g_cov_12": "gauss_g1_g2_Cov", # same as 21. 

898 "gauss_g_cov_22": "gauss_g2_g2_Cov", 

899 "gauss_obj_flags": "gauss_object_flags", 

900 "gauss_psf_flags": "gauss_psfReconvolved_flags", 

901 "gauss_psf_g_1": "gauss_psfReconvolved_g1", 

902 "gauss_psf_g_2": "gauss_psfReconvolved_g2", 

903 "gauss_psf_T": "gauss_psfReconvolved_T", 

904 "gauss_s2n": "gauss_snr", 

905 "gauss_T": "gauss_T", 

906 "gauss_T_err": "gauss_TErr", 

907 "gauss_T_flags": "gauss_shape_flags", 

908 "gauss_T_ratio": "gauss_T_ratio", # dropped. 

909 "id": "shearObjectId", 

910 "mfrac": "mfrac", 

911 "ormask": "ormask_flags", 

912 "pgauss_flags": "pgauss_flags", 

913 "pgauss_obj_flags": "pgauss_object_flags", 

914 "pgauss_s2n": "pgauss_snr", 

915 "pgauss_T": "pgauss_T", 

916 "pgauss_T_err": "pgauss_TErr", 

917 "pgauss_T_flags": "pgauss_shape_flags", 

918 "pgauss_T_ratio": "pgauss_T_ratio", # dropped. 

919 "psfrec_flags": "psfOriginal_flags", 

920 "psfrec_g_1": "psfOriginal_e1", 

921 "psfrec_g_2": "psfOriginal_e2", 

922 "psfrec_T": "psfOriginal_T", 

923 "ra": "ra", 

924 "row": "y", 

925 "row_diff": "y_offset", # dropped. 

926 "shear_type": "metaStep", 

927 "stamp_flags": "image_flags", 

928 } 

929 

930 for b, alg_name in product(self.config.photometry_bands, ("gauss", "pgauss")): 

931 mapping[f"{alg_name}_band_flux_{b}"] = f"{b}_{alg_name}Flux" 

932 mapping[f"{alg_name}_band_flux_err_{b}"] = f"{b}_{alg_name}FluxErr" 

933 mapping[f"{alg_name}_band_flux_flags_{b}"] = f"{b}_{alg_name}Flux_flags" 

934 

935 for name in mapping: 

936 if name in data.dtype.names: 

937 output[mapping.get(name, name)] = data[name] 

938 else: 

939 if "flags" in name.lower(): 

940 output[mapping.get(name, name)] = np.ones_like(data["id"], dtype=np.int32) 

941 else: 

942 output[mapping.get(name, name)] = np.ones_like(data["id"], dtype=np.float32) 

943 output[mapping.get(name, name)] *= np.nan 

944 

945 output["tract"] = tract * np.ones_like(data["id"], dtype=np.int64) 

946 output["patch"] = patch * np.ones_like(data["id"], dtype=np.int32) 

947 

948 return output 

949 

950 

def _make_comb_data(
    cell_coadd,
    res,
    mask_frac,
    bands,
    cell_id,
):
    """Combine per-shear-type metadetection results into a single array.

    Parameters
    ----------
    cell_coadd : `~lsst.cell_coadds.StitchedCoadd`
        A coadd of the cell; only its ``identifiers`` (tract and patch) are
        used.
    res : `dict` or array
        Metadetection results: either a mapping from shear type (e.g.
        ``"noshear"``, ``"1p"``) to a structured array (or `None`), or a
        single result that is treated as ``"noshear"``.
    mask_frac : `float`
        Masked fraction of the cell, stored in the ``mask_frac`` column.
    bands : `list` [`str`]
        Band names, in the order of the band axis of the per-band columns in
        ``res``.
    cell_id : `~lsst.skymap.Index2D`
        Index of the cell within the patch.

    Returns
    -------
    output : `numpy.ndarray` or `list`
        All detections across shear types with the extra columns in
        ``add_dt`` filled in, or an empty list if there were no results.
    """
    idinfo = cell_coadd.identifiers
    algorithms = ("gauss", "pgauss")
    # Suffixes of the per-band columns, paired with their dtypes.
    band_quantities = (("flux_flags", "i4"), ("flux", "f4"), ("flux_err", "f4"))

    copy_dt = [
        # Scalar columns that are copied out of vector-valued ones below.
        ("psfrec_g_1", "f4"),
        ("psfrec_g_2", "f4"),
        ("gauss_psf_g_1", "f4"),
        ("gauss_psf_g_2", "f4"),
        ("gauss_g_1", "f4"),
        ("gauss_g_2", "f4"),
        ("gauss_g_cov_11", "f4"),
        ("gauss_g_cov_12", "f4"),
        ("gauss_g_cov_21", "f4"),
        ("gauss_g_cov_22", "f4"),
    ]
    for b in bands:
        for alg_name in algorithms:
            for quantity, dtype in band_quantities:
                copy_dt.append((f"{alg_name}_band_{quantity}_{b}", dtype))

    add_dt = [
        ("id", "u8"),
        ("tract", "u4"),
        ("patch_x", "u1"),
        ("patch_y", "u1"),
        ("cell_x", "u1"),
        ("cell_y", "u1"),
        ("shear_type", "U2"),
        ("mask_frac", "f4"),
        ("primary", bool),
    ] + copy_dt

    if not hasattr(res, "keys"):
        res = {"noshear": res}

    dlist = []
    for stype in res.keys():
        data = res[stype]
        if data is None:
            continue

        newdata = eu.numpy_util.add_fields(data, add_dt)

        # Split the two-component vectors into separate columns.
        for prefix in ("psfrec_g", "gauss_psf_g", "gauss_g"):
            newdata[f"{prefix}_1"] = newdata[prefix][:, 0]
            newdata[f"{prefix}_2"] = newdata[prefix][:, 1]

        newdata["gauss_g_cov_11"] = newdata["gauss_g_cov"][:, 0, 0]
        newdata["gauss_g_cov_12"] = newdata["gauss_g_cov"][:, 0, 1]
        newdata["gauss_g_cov_21"] = newdata["gauss_g_cov"][:, 1, 0]
        newdata["gauss_g_cov_22"] = newdata["gauss_g_cov"][:, 1, 1]

        # To-do make compatible with a single band better than this.
        for alg_name in algorithms:
            for quantity, _ in band_quantities:
                values = newdata[f"{alg_name}_band_{quantity}"]
                if len(bands) > 1:
                    for i, b in enumerate(bands):
                        newdata[f"{alg_name}_band_{quantity}_{b}"] = values[:, i]
                else:
                    # With a single band no band axis is indexed; the column
                    # is copied as-is.
                    newdata[f"{alg_name}_band_{quantity}_{bands[0]}"] = values

        newdata["tract"] = idinfo.tract
        newdata["patch_x"] = idinfo.patch.x
        newdata["patch_y"] = idinfo.patch.y
        newdata["cell_x"] = cell_id.x
        newdata["cell_y"] = cell_id.y
        newdata["mask_frac"] = mask_frac
        newdata["shear_type"] = "ns" if stype == "noshear" else stype

        dlist.append(newdata)

    if dlist:
        return eu.numpy_util.combine_arrlist(dlist)
    return []

1058 

1059 

1060def _get_mask_frac(mfrac_mbexp, trim_pixels=0): 

1061 """ 

1062 get the average mask frac for each band and then return the max of those 

1063 """ 

1064 

1065 mask_fracs = [] 

1066 for mfrac_exp in mfrac_mbexp: 

1067 mfrac = mfrac_exp.image.array 

1068 dim = mfrac.shape[0] 

1069 mfrac = mfrac[ 

1070 trim_pixels : dim - trim_pixels - 1, 

1071 trim_pixels : dim - trim_pixels - 1, 

1072 ] 

1073 mask_fracs.append(mfrac.mean()) 

1074 

1075 return max(mask_fracs)