Coverage for tests / test_utils.py: 22%
138 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 08:05 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 08:05 +0000
1# This file is part of source_injection.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import logging
23import os
24import unittest
25from contextlib import redirect_stdout
26from io import StringIO
28import numpy as np
30import lsst.utils.tests
31from lsst.daf.butler.tests import makeTestCollection, makeTestRepo
32from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir
33from lsst.obs.base.instrument_tests import DummyCam
34from lsst.pipe.base import Pipeline
35from lsst.skymap.ringsSkyMap import RingsSkyMap, RingsSkyMapConfig
36from lsst.source.injection import (
37 ConsolidateInjectedCatalogsConfig,
38 ConsolidateInjectedCatalogsTask,
39 ExposureInjectTask,
40 ingest_injection_catalog,
41 make_injection_pipeline,
42 show_source_types,
43)
44from lsst.source.injection.utils.test_utils import (
45 make_test_exposure,
46 make_test_injection_catalog,
47 make_test_reference_pipeline,
48)
49from lsst.utils.tests import TestCase
# Absolute path of the directory holding this test module; it is handed to
# makeTestTempDir() in setUpClass as the base location for the temporary
# test repository.
TEST_DIR = os.path.abspath(os.path.dirname(__file__))
class SourceInjectionUtilsTestCase(TestCase):
    """Test the utility functions in the source_injection package.

    One temporary butler repository and one sky map are shared by every
    test in the class. A fresh exposure, injection catalog, reference
    pipeline, consolidation config and "injected" catalog are built for
    each individual test.
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared butler repository and sky map."""
        cls.root = makeTestTempDir(TEST_DIR)
        cls.creator_butler = makeTestRepo(cls.root)
        cls.writeable_butler = makeTestCollection(cls.creator_butler)
        # Register an instrument so we can get some bands.
        DummyCam().register(cls.writeable_butler.registry)
        sky_map_config = RingsSkyMapConfig()
        sky_map_config.numRings = 3
        cls.skyMap = RingsSkyMap(config=sky_map_config)
        # Suppress logging output for the duration of the class; this is
        # undone in tearDownClass.
        logging.disable(logging.CRITICAL)

    @classmethod
    def tearDownClass(cls):
        """Release the shared fixtures and re-enable logging."""
        try:
            del cls.writeable_butler
            del cls.creator_butler
            del cls.skyMap
            removeTestTempDir(cls.root)
        finally:
            # logging.disable() is process-wide, so it must be reset even if
            # the cleanup above fails; otherwise later test modules would run
            # with logging silenced.
            logging.disable(logging.NOTSET)

    def setUp(self):
        """Build the per-test exposure, catalogs, pipeline and config."""
        self.exposure = make_test_exposure()
        self.injection_catalog = make_test_injection_catalog(
            self.exposure.getWcs(),
            self.exposure.getBBox(),
        )
        n_rows = len(self.injection_catalog)
        group_ids = np.arange(n_rows)
        # Every second row in the middle half of the catalog takes the
        # group_id of the row before it, so some groups hold two components.
        group_ids[n_rows // 4 : (n_rows * 3) // 4 : 2] -= 1
        self.injection_catalog["group_id"] = group_ids
        self.reference_pipeline = make_test_reference_pipeline()
        self.consolidate_injected_config = ConsolidateInjectedCatalogsConfig(
            get_catalogs_from_butler=False,
        )
        # Stand-in for a post-injection catalog: the input catalog plus the
        # two extra columns, with the first five rows flagged.
        self.injected_catalog = self.injection_catalog.copy()
        self.injected_catalog.add_columns(cols=[0, 0], names=["injection_draw_size", "injection_flag"])
        self.injected_catalog["injection_flag"][:5] = 1

    def tearDown(self):
        """Drop every attribute created by setUp."""
        del self.exposure
        del self.injection_catalog
        del self.reference_pipeline
        del self.consolidate_injected_config
        del self.injected_catalog

    def test_generate_injection_catalog(self):
        """Check the size and column set of the generated test catalog."""
        self.assertEqual(len(self.injection_catalog), 30)
        expected_columns = {"injection_id", "ra", "dec", "source_type", "mag", "group_id"}
        self.assertEqual(set(self.injection_catalog.columns), expected_columns)

    def test_make_injection_pipeline(self):
        """Merge an injection pipeline into the reference pipeline and check
        the surviving tasks, their subsets and their connection names.
        """
        injection_pipeline = Pipeline("injection_pipeline")
        injection_pipeline.addTask(ExposureInjectTask, "inject_exposure")

        additional_pipeline = Pipeline("additional_pipeline")
        additional_pipeline.addTask(ConsolidateInjectedCatalogsTask, "additional_task")

        # Explicitly set connection names to non-default values. The
        # connection assertions at the end of this test expect the merge to
        # have replaced them.
        for connection, placeholder in (
            ("input_exposure", "A"),
            ("output_exposure", "B"),
            ("output_catalog", "C"),
        ):
            injection_pipeline.addConfigOverride("inject_exposure", f"connections.{connection}", placeholder)

        # Merge the injection pipeline into the main reference pipeline.
        merged_pipeline = make_injection_pipeline(
            dataset_type_name="postISRCCD",  # Unchanged to match task default
            reference_pipeline=self.reference_pipeline,
            injection_pipeline=injection_pipeline,
            update_subsets=True,
            excluded_tasks={"calibrate"},
            prefix="injected_",
            instrument="lsst.obs.subaru.HyperSuprimeCam",
            additional_pipelines=[additional_pipeline],
            additional_subset=["additional_subset:Additional subset description"],
            log_level=logging.DEBUG,
        )

        # Test that only the expected tasks are present in the merged pipeline.
        reference_labels = set(self.reference_pipeline.task_labels)
        expected_task_labels = reference_labels - {"calibrate"}
        surviving_task_labels = reference_labels & set(merged_pipeline.task_labels)
        self.assertEqual(expected_task_labels, surviving_task_labels)

        # Test that all surviving tasks are still in a subset.
        for label in surviving_task_labels:
            self.assertTrue(
                merged_pipeline.findSubsetsWithLabel(label),
                msg=f"Task {label!r} is no longer a member of any subset",
            )
        self.assertIn("additional_subset", merged_pipeline.findSubsetsWithLabel("additional_task"))
        self.assertNotIn("injected_test_subset", merged_pipeline.findSubsetsWithLabel("isr"))
        self.assertIn("injected_test_subset", merged_pipeline.findSubsetsWithLabel("inject_exposure"))
        self.assertIn("injected_test_subset", merged_pipeline.findSubsetsWithLabel("characterizeImage"))

        # Test that connection names have been properly configured. Tasks
        # are looked up by label so that a missing task fails the test
        # instead of having its checks silently skipped.
        tasks = {task.label: task for task in merged_pipeline.to_graph().tasks.values()}
        # label -> (expected input dataset types, expected output dataset types)
        expected_connections = {
            "isr": (
                {},
                {"outputExposure": "postISRCCD"},
            ),
            "inject_exposure": (
                {"input_exposure": "postISRCCD"},
                {
                    "output_exposure": "injected_postISRCCD",
                    "output_catalog": "injected_postISRCCD_catalog",
                },
            ),
            "characterizeImage": (
                {"exposure": "injected_postISRCCD"},
                {
                    "characterized": "icExp",
                    "backgroundModel": "icExpBackground",
                    "sourceCat": "icSrc",
                },
            ),
        }
        for label, (expected_inputs, expected_outputs) in expected_connections.items():
            self.assertIn(label, tasks)
            task = tasks[label]
            for connection, dataset_type_name in expected_inputs.items():
                self.assertEqual(task.inputs[connection].dataset_type_name, dataset_type_name)
            for connection, dataset_type_name in expected_outputs.items():
                self.assertEqual(task.outputs[connection].dataset_type_name, dataset_type_name)

    def test_ingest_injection_catalog(self):
        """Ingest the test catalog and read it back through the butler."""
        input_dataset_refs = ingest_injection_catalog(
            writeable_butler=self.writeable_butler,
            table=self.injection_catalog,
            band="g",
            output_collection="test_collection",
            dataset_type_name="injection_catalog",
            log_level=logging.DEBUG,
        )
        output_dataset_refs = self.writeable_butler.registry.queryDatasets(
            "injection_catalog",
            collections="test_collection",
        )
        # The refs returned by the ingest must be exactly the refs that the
        # registry now reports for the output collection.
        self.assertEqual(len(input_dataset_refs), output_dataset_refs.count())
        input_ids = {ref.id for ref in input_dataset_refs}
        output_ids = {ref.id for ref in output_dataset_refs}
        self.assertEqual(input_ids, output_ids)
        # Round trip: the stored table must equal the table that was ingested.
        injected_catalog = self.writeable_butler.get(input_dataset_refs[0])
        self.assertTrue(all(self.injection_catalog == injected_catalog))

    def test_consolidate_injected_catalogs(self):
        """Consolidate two per-band catalogs via the config helper and check
        the output columns and flag counts.
        """
        # The same catalog is supplied for both bands.
        catalog_dict = {"g": self.injected_catalog, "r": self.injected_catalog}
        output_catalog = self.consolidate_injected_config.consolidate_catalogs(
            catalog_dict=catalog_dict,
            skymap=self.skyMap,
            tract=9,
            copy_catalogs=True,
        )
        self.assertEqual(len(output_catalog), 30)
        # Column order matters here: the comparison below is list equality.
        expected_columns = [
            "injected_id",
            "ra",
            "dec",
            "source_type",
            "g_mag",
            "r_mag",
            "patch",
            "injection_id",
            "injection_draw_size",
            "injection_flag",
            "injected_isPatchInner",
            "injected_isTractInner",
            "injected_isPrimary",
            "group_id",
            "g_injection_flag",
            "r_injection_flag",
        ]
        self.assertListEqual(output_catalog.colnames, expected_columns)
        # setUp flags the first five rows of the injected catalog.
        self.assertEqual(sum(output_catalog["injection_flag"]), 5)
        self.assertEqual(sum(output_catalog["injected_isPatchInner"]), 30)
        self.assertEqual(sum(output_catalog["injected_isTractInner"]), 30)
        self.assertEqual(sum(output_catalog["injected_isPrimary"]), 25)

    def test_consolidate_injected_catalog_task(self):
        """Run ConsolidateInjectedCatalogsTask with grouping enabled and
        check the output has one row per group and the expected columns.
        """
        group_id_key = "group_id"
        config = ConsolidateInjectedCatalogsConfig(
            groupIdKey=group_id_key,
            pixel_match_radius=-1,
            columns_extra=[],
            get_catalogs_from_butler=False,
        )
        task = ConsolidateInjectedCatalogsTask(config=config)
        # The same catalog is supplied for both bands.
        catalog_dict = {"g": self.injected_catalog, "r": self.injected_catalog}
        output_catalog = task.run(
            catalog_dict=catalog_dict,
            skymap=self.skyMap,
            tract=9,
        ).output_catalog
        group_ids, counts = np.unique(
            self.injection_catalog[group_id_key],
            return_counts=True,
        )
        # Largest number of rows sharing one group_id; sets how many
        # per-component columns are expected for each band.
        n_comps = np.max(counts)
        self.assertEqual(len(output_catalog), len(group_ids))
        expected_columns = [
            config.groupIdKey,
            config.injectionKey,
            config.col_ra,
            config.col_dec,
        ]
        for band in catalog_dict:
            expected_columns.append(f"{band}_{config.injectionKey}")
            expected_columns.append(f"{band}_{config.col_mag}")
            for compnum in range(1, n_comps + 1):
                expected_columns.append(f"{band}_comp{compnum}_source_type")
                expected_columns.append(f"{band}_comp{compnum}_{config.injectionKey}")
        expected_columns.extend(
            [
                "patch",
                "injected_isPatchInner",
                "injected_isTractInner",
                "injected_isPrimary",
                "injected_id",
            ]
        )
        # Column order is not checked here, only membership.
        self.assertEqual(set(output_catalog.colnames), set(expected_columns))
        self.assertEqual(sum(output_catalog["injection_flag"]), 5)
        self.assertEqual(sum(output_catalog["injected_isPatchInner"]), 22)
        self.assertEqual(sum(output_catalog["injected_isTractInner"]), 22)
        self.assertEqual(sum(output_catalog["injected_isPrimary"]), 17)

    def test_show_source_types(self):
        """Check the printed source type listing contains the Sersic entry."""
        buffer = StringIO()
        with redirect_stdout(buffer):
            show_source_types(wrap_width=80)
        output = buffer.getvalue()
        # NOTE(review): the leading whitespace inside the second and third
        # literals is significant for this substring match - confirm it
        # against the output of show_source_types(wrap_width=80).
        self.assertIn(
            "Sersic:\n"
            " (n, half_light_radius=None, scale_radius=None, mag=None, trunc=0.0,\n"
            " flux_untruncated=False)",
            output,
        )
class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Memory test case for this module.

    Nothing is added to or overridden in the inherited
    `lsst.utils.tests.MemoryTestCase`.
    """
def setup_module(module):
    """Module-level pytest setup hook.

    Parameters
    ----------
    module : module
        The test module being set up; not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution as a script: make the same lsst.utils.tests.init()
    # call that setup_module() makes under pytest, then hand over to the
    # unittest command-line runner.
    lsst.utils.tests.init()
    unittest.main()