Coverage for python / lsst / source / injection / utils / _make_injection_pipeline.py: 6%
201 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-23 01:57 -0700
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-23 01:57 -0700
1# This file is part of source_injection.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["make_injection_pipeline"]
26import logging
27import warnings
29from lsst.pipe.base import LabelSpecifier, Pipeline, PipelineGraph
30from lsst.pipe.base.pipelineIR import ContractError
def _infer_injection_pipeline(dataset_type_name: str, logger: logging.Logger) -> str | None:
    """Infer the injection pipeline from the dataset type name.

    Each recognized dataset type name is mapped to one of the exposure, visit
    or coadd injection pipeline stubs under
    ``$SOURCE_INJECTION_DIR/pipelines``.

    Parameters
    ----------
    dataset_type_name : `str`
        Name of the dataset type being injected into.
    logger : `~logging.Logger`
        Logger for warning and info messages.

    Returns
    -------
    injection_pipeline : `str` | `None`
        Location of an injection pipeline definition YAML file stub, or None if
        the dataset type name is not recognized.
    """
    known_stubs = (
        (
            ("postISRCCD", "post_isr_image"),
            "$SOURCE_INJECTION_DIR/pipelines/inject_exposure.yaml",
        ),
        (
            ("icExp", "calexp", "initial_pvi", "pvi", "preliminary_visit_image", "visit_image"),
            "$SOURCE_INJECTION_DIR/pipelines/inject_visit.yaml",
        ),
        (
            (
                "deepCoadd",
                "deepCoadd_calexp",
                "goodSeeingCoadd",
                "deep_coadd_predetection",
                "deep_coadd",
                "deep_coadd_cell_predetection",
                "template_coadd",
            ),
            "$SOURCE_INJECTION_DIR/pipelines/inject_coadd.yaml",
        ),
    )
    stub = next((path for names, path in known_stubs if dataset_type_name in names), None)

    if stub is None:
        # Warn rather than raise: the user may only want the connection
        # renaming, without merging an injection pipeline.
        logger.warning(
            "Unable to infer injection pipeline stub from dataset type name '%s' and none was "
            "provided. No injection pipeline will be merged into the output pipeline.",
            dataset_type_name,
        )
        return None

    logger.info(
        "Injected dataset type '%s' used to infer injection pipeline: %s",
        dataset_type_name,
        stub,
    )
    return stub
def _merge_injection_pipeline(
    pipeline: Pipeline,
    injection_pipeline: Pipeline | str | None,
    dataset_type_name: str,
    prefix: str,
) -> None | str:
    """Merge a single-task injection pipeline into an existing pipeline.

    The injection task's ``input_exposure`` connection is pointed at
    ``dataset_type_name``; its ``output_exposure`` and ``output_catalog``
    connections are pointed at the prefixed name and the prefixed name with a
    ``_catalog`` suffix, respectively.

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline to merge the injection pipeline into.
    injection_pipeline : `~lsst.pipe.base.Pipeline` | `str` | `None`
        Injection pipeline to merge, or location of an injection pipeline
        definition YAML file stub. If None, nothing is merged.
    dataset_type_name : `str`
        Name of the dataset type being injected into.
    prefix : `str`
        Prefix to prepend to each affected post-injection dataset type name.

    Returns
    -------
    injection_task_label : `str` | `None`
        Label of the merged injection task, or None if nothing was merged.

    Raises
    ------
    RuntimeError
        If the injection pipeline does not contain exactly one task.

    Notes
    -----
    This function modifies the input pipeline in place.
    """
    if injection_pipeline is None:
        return None

    stub = Pipeline.fromFile(injection_pipeline) if isinstance(injection_pipeline, str) else injection_pipeline
    task_count = len(stub)
    if task_count != 1:
        raise RuntimeError(f"The injection pipeline contains {task_count} tasks; only 1 task is allowed.")
    pipeline.mergePipeline(stub)

    label = next(iter(stub.task_labels))
    connection_overrides = {
        "connections.input_exposure": dataset_type_name,
        "connections.output_exposure": f"{prefix}{dataset_type_name}",
        "connections.output_catalog": f"{prefix}{dataset_type_name}_catalog",
    }
    for key, value in connection_overrides.items():
        pipeline.addConfigOverride(label, key, value)
    return label
def _parse_config_override(config_override: str) -> tuple[str, str, str]:
    """Parse a config override string into a label, a key and a value.

    The expected syntax is ``label:key=value``. The label is everything before
    the first ``:``, the key is everything between that ``:`` and the first
    following ``=``, and the value is the remainder (which may itself contain
    ``:`` or ``=``). The value may be empty; the label and key may not.

    Parameters
    ----------
    config_override : `str`
        Config override string to parse.

    Returns
    -------
    label : `str`
        Label to override.
    key : `str`
        Key to override.
    value : `str`
        Value to override.

    Raises
    ------
    TypeError
        If the config override string cannot be parsed, or if its label or key
        is empty.
    """
    try:
        label, keyvalue = config_override.split(":", 1)
    except ValueError:
        raise TypeError(
            f"Unrecognized syntax for option 'config': '{config_override}' (does not match pattern "
            "(?P<label>.+):(?P<value>.+=.+))"
        ) from None
    try:
        key, value = keyvalue.split("=", 1)
    except ValueError as e:
        raise TypeError(
            f"Could not parse key-value pair '{config_override}' using separator '=', with multiple values "
            f"not allowed: {e}"
        ) from None
    # An empty label can never match a task, and the caller only logs (rather
    # than raises) on unknown labels, so reject it here where the error is
    # clear. An empty key cannot name any config field either.
    if not label or not key:
        raise TypeError(
            f"Unrecognized syntax for option 'config': '{config_override}' (label and key must be "
            "non-empty; does not match pattern (?P<label>.+):(?P<value>.+=.+))"
        )
    return label, key, value
def _configure_injection_pipeline(
    pipeline: Pipeline,
    config: str | list[str],
    logger: logging.Logger,
) -> None:
    """Apply user-supplied config overrides to the pipeline.

    Every override is parsed before any is applied, so a malformed entry
    leaves the pipeline unchanged. Overrides whose label cannot be found in
    the pipeline are skipped with a warning.

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline to apply config overrides to. Pipeline is modified in place.
    config : `str` | `list` [`str`]
        Config override(s) to apply, in the format 'label:key=value'.
    logger : `~logging.Logger`
        Logger for warning and info messages.

    Raises
    ------
    TypeError
        If any config override string cannot be parsed.

    Notes
    -----
    This function modifies the input pipeline in place.
    """
    if isinstance(config, str):
        config = [config]
    # Validate all overrides up front so that one bad entry cannot leave the
    # pipeline partially configured.
    parsed_overrides = [(conf, *_parse_config_override(conf)) for conf in config]
    for conf, config_label, config_key, config_value in parsed_overrides:
        try:
            pipeline.addConfigOverride(config_label, config_key, config_value)
        except LookupError:
            # Logged at WARNING: at DEBUG a mistyped user override would be
            # silently dropped under make_injection_pipeline's default INFO
            # log level.
            logger.warning(
                "Config override '%s' for label '%s' not found in the reference "
                "pipeline, either due to a typo or the label not existing in "
                "the reference pipeline.",
                conf,
                config_label,
            )
def _remove_excluded_tasks(
    pipeline: Pipeline,
    excluded_tasks: set[str] | frozenset[str] | str,
    logger: logging.Logger,
) -> Pipeline:
    """Remove excluded tasks from the pipeline and any subsets,
    and remove any empty subsets.

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline to remove tasks from. Callers should use the returned
        pipeline, which reflects the removals.
    excluded_tasks : `set` [`str`] | `frozenset` [`str`] | `str`
        Task labels to exclude from the injection pipeline, either as a
        collection of labels or as a comma-separated string. For the string
        form, whitespace around each label is ignored and empty entries are
        dropped.
    logger : `~logging.Logger`
        Logger for warning and info messages.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The pipeline with excluded tasks and empty subsets removed.
    """
    if isinstance(excluded_tasks, str):
        # Accept "a, b" and trailing commas: strip each label, drop blanks.
        excluded = {label.strip() for label in excluded_tasks.split(",")}
        excluded.discard("")
    else:
        # Copy into a plain set so any iterable of labels is accepted.
        excluded = set(excluded_tasks)

    all_tasks = set(pipeline.task_labels)
    preserved_task_labels = LabelSpecifier(labels=all_tasks - excluded)
    # EDIT mode removes tasks from parent subsets but keeps the subset itself.
    pipeline = pipeline.subsetFromLabels(preserved_task_labels, pipeline.PipelineSubsetCtrl.EDIT)

    if found_tasks := excluded & all_tasks:
        logger.info(
            "%d %s excluded from the output pipeline: %s",
            len(found_tasks),
            "task" if len(found_tasks) == 1 else "tasks",
            ", ".join(sorted(found_tasks)),
        )

    # Collect the empty subsets first and remove them afterwards, so the
    # subset mapping is never modified while it is being iterated.
    removed_subsets = sorted(label for label, tasks in pipeline.subsets.items() if not tasks)
    for subset_label in removed_subsets:
        pipeline.removeLabeledSubset(subset_label)
    if removed_subsets:
        logger.warning(
            "Removed %d empty %s from the pipeline: %s.",
            len(removed_subsets),
            "subset" if len(removed_subsets) == 1 else "subsets",
            ", ".join(removed_subsets),
        )

    return pipeline
def _get_pipeline_graph(pipeline: Pipeline, logger: logging.Logger) -> PipelineGraph:
    """Get the pipeline graph, handling any contract errors.

    Pipeline contracts that are violated by any modifications made to the
    pipeline will be removed, with a single warning logged listing every
    removed contract.

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline to validate contracts for.
    logger : `~logging.Logger`
        Logger for warning and info messages.

    Returns
    -------
    pipeline_graph : `~lsst.pipe.base.PipelineGraph`
        The pipeline graph for the input pipeline, with any violated contracts
        removed from the input pipeline.

    Notes
    -----
    This function modifies the input pipeline in place, removing any violated
    contracts. Each contract is checked in isolation by rebuilding the graph
    with only that contract present. If anything other than a `ContractError`
    interrupts that check, the original contracts are restored before the
    exception propagates.
    """
    contracts_failed = []
    with warnings.catch_warnings():
        # Keep this filter active for every graph build below, including the
        # per-contract rebuilds, so the warning is not re-emitted by each one.
        warnings.filterwarnings(
            "ignore",
            message=r".*formatted like a Pipeline parameter but was not found within the Pipeline.*",
            category=UserWarning,
        )
        try:
            return pipeline.to_graph()
        except ContractError:
            pass

        contracts_initial = pipeline._pipelineIR.contracts
        contracts_passed = []
        try:
            for contract in contracts_initial:
                pipeline._pipelineIR.contracts = [contract]
                try:
                    pipeline.to_graph()
                except ContractError:
                    contracts_failed.append(contract)
                else:
                    contracts_passed.append(contract)
        except BaseException:
            # Do not leave the pipeline holding a single contract if the
            # check is aborted by an unexpected error.
            pipeline._pipelineIR.contracts = contracts_initial
            raise

        pipeline._pipelineIR.contracts = contracts_passed
        pipeline_graph = pipeline.to_graph()

    if contracts_failed:
        logger.warning(
            "The following contracts were violated and have been removed: \n%s",
            "\n".join(str(contract) for contract in contracts_failed),
        )
    return pipeline_graph
def _collect_injected_task_labels(
    pipeline_graph: PipelineGraph,
    dataset_type_name: str,
) -> set[str]:
    """Collect tasks downstream of the injection point.

    Starting from ``dataset_type_name``, repeatedly follow
    "dataset type -> consuming task -> produced dataset types" until no new
    dataset types are reached. Each task and dataset type is visited once, so
    cycles terminate.

    Parameters
    ----------
    pipeline_graph : `~lsst.pipe.base.PipelineGraph`
        Pipeline graph to inspect.
    dataset_type_name : `str`
        Name of the dataset type being injected into.

    Returns
    -------
    injected_task_labels : `set` [`str`]
        Labels of all tasks that consume the injected dataset type directly or
        indirectly, including the injection task itself if present.
    """
    # Walk the graph via its public API (``consumers_of`` and the task
    # nodes' output edges) rather than the private ``_xgraph`` attribute,
    # which is an implementation detail and not a guaranteed interface.
    visited_tasks: set[str] = set()
    visited_dataset_types = {dataset_type_name}
    pending = [dataset_type_name]

    while pending:
        current = pending.pop()
        for task_node in pipeline_graph.consumers_of(current):
            if task_node.label in visited_tasks:
                continue
            visited_tasks.add(task_node.label)
            for edge in task_node.iter_all_outputs():
                produced = edge.parent_dataset_type_name
                if produced not in visited_dataset_types:
                    visited_dataset_types.add(produced)
                    pending.append(produced)

    return visited_tasks
def _add_injected_subsets(
    pipeline: Pipeline,
    injected_task_labels: set[str],
    prefix: str,
    logger: logging.Logger,
) -> int:
    """Create injected variants of existing subsets.

    For every subset that still has members once restricted to
    ``injected_task_labels``, a new subset named ``prefix + <subset label>``
    holding those members is added to ``pipeline``.

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline to modify in place.
    injected_task_labels : `set` [`str`]
        Labels of tasks downstream of the injection point.
    prefix : `str`
        Prefix to prepend to the subset names.
    logger : `~logging.Logger`
        Logger for warning and info messages (currently unused).

    Returns
    -------
    subset_count : `int`
        Number of injected subsets created; 0 if ``injected_task_labels`` is
        empty.
    """
    if not injected_task_labels:
        return 0

    # Build a pipeline restricted to the injected tasks; EDIT mode trims each
    # subset down to its injected members while keeping the subset itself.
    injected_view = pipeline.subsetFromLabels(
        LabelSpecifier(labels=injected_task_labels),
        pipeline.PipelineSubsetCtrl.EDIT,
    )

    created_labels = set()
    for label, members in injected_view.subsets.items():
        if members:
            new_label = f"{prefix}{label}"
            pipeline.addLabeledSubset(
                new_label,
                f"All tasks from the '{label}' subset impacted by source injection.",
                members,
            )
            created_labels.add(new_label)
    return len(created_labels)
def _reconfigure_injection_pipeline(
    pipeline: Pipeline,
    dataset_type_name: str,
    prefix: str,
    injection_task_label: str | None,
    update_subsets: bool,
    logger: logging.Logger,
) -> None:
    """Reconfigure the injection pipeline by prefixing post-injection dataset
    type names and updating subsets.

    Every task other than the injection task that reads ``dataset_type_name``
    has the matching input connection renamed to ``prefix`` followed by the
    edge's dataset type name. If ``update_subsets`` is True, the injection task
    is added to every subset containing the producer of ``dataset_type_name``
    (if there is one), and prefixed variants of existing subsets restricted to
    the tasks downstream of the injection point are created.

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline to configure. This pipeline is modified in place.
    dataset_type_name : `str`
        Name of the dataset type being injected into.
    prefix : `str`
        Prefix to prepend to each affected post-injection dataset type name.
    injection_task_label : `str` | `None`
        Label of the injection task.
    update_subsets : `bool`
        If True, update pipeline subsets to include the injection task.
    logger : `~logging.Logger`
        Logger for warning and info messages.

    Notes
    -----
    This function modifies the input pipeline in place. A warning is logged
    when no task other than the injection task consumes ``dataset_type_name``.
    """
    # Use the pipeline graph to determine tasks with connections to be modified.
    pipeline_graph = _get_pipeline_graph(pipeline, logger)
    injected_task_labels = _collect_injected_task_labels(pipeline_graph, dataset_type_name)

    # The injection task itself consumes the injected dataset type. Drop it
    # *before* testing for consumers; otherwise its presence alone would
    # suppress the warning even when no reference task reads the dataset type.
    post_injection_tasks = [
        task_node
        for task_node in pipeline_graph.consumers_of(dataset_type_name)
        if task_node.label != injection_task_label
    ]
    if not post_injection_tasks:
        logger.warning(
            "Dataset type '%s' not found in the reference pipeline; no input connection edits to be made.",
            dataset_type_name,
        )

    # Prefix only the input connections that read the injected dataset type
    # and are declared on the task's connections class.
    for task_node in post_injection_tasks:
        connections_class = task_node.config.connections.ConnectionsClass
        for edge in task_node.iter_all_inputs():
            if edge.parent_dataset_type_name == dataset_type_name and hasattr(
                connections_class, edge.connection_name
            ):
                pipeline.addConfigOverride(
                    task_node.label,
                    "connections." + edge.connection_name,
                    prefix + edge.dataset_type_name,
                )

    # Add the injection task to every subset containing its upstream producer.
    if (
        update_subsets
        and injection_task_label is not None
        and (pre_injection_task := pipeline_graph.producer_of(dataset_type_name)) is not None
    ):
        for subset in pipeline.findSubsetsWithLabel(pre_injection_task.label):
            pipeline.addLabelToSubset(subset, injection_task_label)

    injected_subset_count = 0
    if update_subsets:
        injected_subset_count = _add_injected_subsets(pipeline, injected_task_labels, prefix, logger)

    logger.info(
        "Made an injection pipeline containing %d task%s and %d injected subset%s.",
        len(pipeline),
        "" if len(pipeline) == 1 else "s",
        injected_subset_count,
        "" if injected_subset_count == 1 else "s",
    )
def _add_additional_pipelines(
    pipeline: Pipeline,
    additional_pipelines: list[Pipeline] | list[str],
    additional_subset: list[str] | str | None,
    logger: logging.Logger,
) -> None:
    """Add additional pipelines to the injection pipeline, and optionally add
    all additional tasks to existing or new subsets.

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline to add additional pipelines to. Pipeline is modified in place.
    additional_pipelines : `list` [`~lsst.pipe.base.Pipeline`] | `list` [`str`]
        Additional pipelines to merge, or locations of additional pipeline
        definition YAML file stubs.
    additional_subset : `list` [`str`] | `str` | `None`
        A subset definition, or a collection of subset definitions, in the
        form "subset_name[:subset_description]".
        These subsets will be created if they don't already exist. All tasks
        from the additional pipelines will be added to these subsets.
        If None, additional tasks will not be added to any subsets.
    logger : `~logging.Logger`
        Logger for warning and info messages.

    Notes
    -----
    This function modifies the input pipeline in place, and revalidates the
    pipeline graph (removing any violated contracts) after merging.
    """
    # Merge all additional pipelines into the main pipeline.
    additional_tasks: set[str] = set()
    for additional_pipeline in additional_pipelines:
        if isinstance(additional_pipeline, str):
            additional_pipeline = Pipeline.fromFile(additional_pipeline)
        additional_tasks.update(additional_pipeline.task_labels)
        pipeline.mergePipeline(additional_pipeline)

    # Add all tasks to each requested subset, creating it if it does not exist.
    subset_text = ""
    if additional_subset is not None:
        # Only a bare string is a single definition; any other iterable is a
        # collection of definitions.
        if isinstance(additional_subset, str):
            additional_subset = [additional_subset]
        subset_names_old = []
        subset_names_new = []
        for subset in additional_subset:
            # "name:description" -> ("name", "description"); "name" -> ("name", "").
            subset_name, _, subset_description = subset.partition(":")
            if subset_name in pipeline.subsets:
                subset_names_old.append(subset_name)
                for additional_task in additional_tasks:
                    pipeline.addLabelToSubset(subset_name, additional_task)
            else:
                subset_names_new.append(subset_name)
                # Give each new subset its own set so that separate subsets
                # never share (and silently co-mutate) one container.
                pipeline.addLabeledSubset(subset_name, subset_description, set(additional_tasks))
        if subset_names_old:
            subset_text += f", and to existing subset{'s' if len(subset_names_old) > 1 else ''} "
            subset_text += ", ".join(sorted(subset_names_old))
        if subset_names_new:
            subset_text += f", and to new subset{'s' if len(subset_names_new) > 1 else ''} "
            subset_text += ", ".join(sorted(subset_names_new))

    # Revalidate the pipeline graph.
    _ = _get_pipeline_graph(pipeline, logger)

    grammar = "task" if len(additional_tasks) == 1 else "tasks"
    logger.info(
        "Added %d %s to the pipeline%s: %s",
        len(additional_tasks),
        grammar,
        subset_text,
        ", ".join(sorted(additional_tasks)),
    )
def make_injection_pipeline(
    dataset_type_name: str,
    reference_pipeline: Pipeline | str,
    injection_pipeline: Pipeline | str | None = None,
    update_subsets: bool = True,
    excluded_tasks: set[str] | frozenset[str] | str = frozenset(
        {
            "jointcal",
            "gbdesAstrometricFit",
            "fgcmBuildFromIsolatedStars",
            "fgcmFitCycle",
            "fgcmOutputProducts",
        }
    ),
    prefix: str = "injected_",
    instrument: str | None = None,
    config: str | list[str] | None = None,
    additional_pipelines: list[Pipeline] | list[str] | None = None,
    additional_subset: list[str] | str | None = None,
    log_level: int = logging.INFO,
) -> Pipeline:
    """Make an expanded source injection pipeline.

    This function takes a reference pipeline definition file and prefixes all
    immediately post-injection dataset type names with the injected prefix. If
    an optional injection pipeline definition YAML file is also provided, the
    injection task will be merged into the pipeline.

    Unless subset updates are explicitly disabled, all subsets from the
    reference pipeline containing the task which generates the injection
    dataset type will also be updated to include the injection task.

    When the injection pipeline is constructed, a check on all existing
    pipeline contracts is performed. If any contracts are violated, they're
    removed from the pipeline, and the removed contracts are listed in a
    logged warning.

    Parameters
    ----------
    dataset_type_name : `str`
        Name of the dataset type being injected into.
    reference_pipeline : `~lsst.pipe.base.Pipeline` | `str`
        Reference pipeline, or location of a reference pipeline definition YAML
        file. A `~lsst.pipe.base.Pipeline` instance passed here is used
        directly and may be modified in place.
    injection_pipeline : `~lsst.pipe.base.Pipeline` | `str`, optional
        Injection pipeline, or location of an injection pipeline definition
        YAML file stub. If None or an empty string, an attempt to infer the
        injection pipeline will be made based on the injected dataset type
        name.
    update_subsets : `bool`, optional
        If True, update pipeline subsets to include the injection task.
    excluded_tasks : `set` [`str`] | `frozenset` [`str`] | `str`
        Task labels to exclude, or a comma-separated string of labels.
        Defaults to the labels listed in the signature.
    prefix : `str`, optional
        Prefix to prepend to each affected post-injection dataset type name.
    instrument : `str`, optional
        Add instrument overrides. Must be a fully qualified class name.
    config : `str` | `list` [`str`], optional
        Config override for a task, in the format 'label:key=value'.
    additional_pipelines : `list` [`~lsst.pipe.base.Pipeline`] | `list` [`str`], optional
        Additional pipelines to merge into the output pipeline, or their YAML
        file locations. Tasks from these additional pipelines will be added to
        the output injection pipeline.
    additional_subset : `list` [`str`] | `str`, optional
        A list of subset definitions in the form
        "subset_name[:subset_description]".
        These subsets will be created if they don't already exist.
        All tasks from additional_pipelines will be added to these subsets.
    log_level : `int`, optional
        The log level to use for logging.

    Returns
    -------
    pipeline : `lsst.pipe.base.Pipeline`
        An expanded source injection pipeline.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(log_level)

    # Get the main reference pipeline.
    if isinstance(reference_pipeline, str):
        pipeline = Pipeline.fromFile(reference_pipeline)
    else:
        pipeline = reference_pipeline

    # Add an instrument override.
    if instrument:
        pipeline.addInstrument(instrument)

    # Infer the injection pipeline only when none was supplied. Test for None
    # or an empty string explicitly: a Pipeline defines __len__, so an empty
    # Pipeline object would otherwise be falsy and silently replaced.
    if injection_pipeline is None or (isinstance(injection_pipeline, str) and not injection_pipeline):
        injection_pipeline = _infer_injection_pipeline(dataset_type_name, logger)

    # Merge the injection pipeline into the main pipeline.
    injection_task_label = _merge_injection_pipeline(pipeline, injection_pipeline, dataset_type_name, prefix)

    # Apply all user-supplied config overrides.
    if config is not None:
        _configure_injection_pipeline(pipeline, config, logger)

    # Remove excluded tasks from the pipeline, and remove any empty subsets.
    pipeline = _remove_excluded_tasks(pipeline, excluded_tasks, logger)

    # Prefix post-injection dataset type name connections and update subsets.
    _reconfigure_injection_pipeline(
        pipeline, dataset_type_name, prefix, injection_task_label, update_subsets, logger
    )

    # Optionally include additional tasks in the injection pipeline.
    if additional_pipelines:
        _add_additional_pipelines(pipeline, additional_pipelines, additional_subset, logger)

    return pipeline