Coverage for python/lsst/pipe/base/quantum_graph/aggregator/_writer.py: 25%
85 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-29 08:23 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-29 08:23 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ("Writer",)
32import dataclasses
33from collections.abc import ByteString
35import zstandard
37from ...log_on_close import LogOnClose
38from ...pipeline_graph import TaskImportMode
39from .._predicted import PredictedQuantumGraphComponents, PredictedQuantumGraphReader
40from .._provenance import ProvenanceQuantumGraphWriter, ProvenanceQuantumScanData
41from ._communicators import WriterCommunicator
@dataclasses.dataclass
class Writer:
    """A helper class for the provenance aggregator that actually writes the
    provenance quantum graph file.

    Notes
    -----
    Constructing a `Writer` immediately reads the predicted quantum graph
    from ``predicted_path`` (see `__post_init__`).  The usual entry point is
    `run`, which builds the writer inside the communicator's context and
    then calls `loop`.
    """

    predicted_path: str
    """Path to the predicted quantum graph."""

    comms: WriterCommunicator
    """Communicator object for this worker."""

    predicted: PredictedQuantumGraphComponents = dataclasses.field(init=False)
    """Components of the predicted quantum graph."""

    pending_compression_training: list[ProvenanceQuantumScanData] = dataclasses.field(default_factory=list)
    """Unprocessed quantum scans that are being accumulated in order to
    build a compression dictionary.
    """

    def __post_init__(self) -> None:
        """Read the predicted quantum graph components needed for writing.

        The init quanta and the quantum datasets are read, with a
        cancellation check before each read, and the reader's components
        are stored in `predicted`.
        """
        assert self.comms.config.is_writing_provenance, "Writer should not be used if writing is disabled."
        self.comms.log.info("Reading predicted quantum graph.")
        with PredictedQuantumGraphReader.open(
            self.predicted_path, import_mode=TaskImportMode.DO_NOT_IMPORT
        ) as reader:
            self.comms.check_for_cancel()
            reader.read_init_quanta()
            self.comms.check_for_cancel()
            reader.read_quantum_datasets()
            self.predicted = reader.components

    @staticmethod
    def run(predicted_path: str, comms: WriterCommunicator) -> None:
        """Run the writer.

        Parameters
        ----------
        predicted_path : `str`
            Path to the predicted quantum graph.
        comms : `WriterCommunicator`
            Communicator for the writer.

        Notes
        -----
        This method is designed to run as the ``target`` in
        `WorkerFactory.make_worker`.
        """
        with comms:
            writer = Writer(predicted_path, comms)
            writer.loop()

    def loop(self) -> None:
        """Run the main loop for the writer.

        Notes
        -----
        If no compression dictionary is configured (``zstd_dict_size`` is
        falsy) the low-level writer is opened up front and every polled
        request is written immediately.  Otherwise requests are queued in
        `pending_compression_training` until ``zstd_dict_n_inputs`` of them
        have arrived, at which point `make_qg_writer` trains the dictionary
        and flushes the queue.  Init outputs are always written last.
        """
        qg_writer: ProvenanceQuantumGraphWriter | None = None
        if not self.comms.config.zstd_dict_size:
            qg_writer = self.make_qg_writer()
        self.comms.log.info("Polling for write requests from scanners.")
        for request in self.comms.poll():
            if qg_writer is None:
                # Still collecting training inputs; the write (and its
                # report) happens when make_qg_writer flushes the queue.
                self.pending_compression_training.append(request)
                if len(self.pending_compression_training) >= self.comms.config.zstd_dict_n_inputs:
                    qg_writer = self.make_qg_writer()
            else:
                qg_writer.write_scan_data(request)
                self.comms.report_write()
        if qg_writer is None:
            # Polling ended before enough requests arrived to trigger
            # dictionary training; open the writer now so that anything
            # queued is still written.
            qg_writer = self.make_qg_writer()
        self.comms.log.info("Writing init outputs.")
        qg_writer.write_init_outputs(assume_existence=False)

    def make_qg_writer(self) -> ProvenanceQuantumGraphWriter:
        """Make a compression dictionary, open the low-level writers, and
        write any accumulated scans that were needed to make the compression
        dictionary.

        Returns
        -------
        qg_writer : `ProvenanceQuantumGraphWriter`
            Low-level writers struct.

        Notes
        -----
        After the queued scans have been written,
        `pending_compression_training` is reset to an empty list.
        """
        cdict = self.make_compression_dictionary()
        # Serialize once; the same bytes go to the communicator and to the
        # low-level writer.
        cdict_data = cdict.as_bytes()
        self.comms.send_compression_dict(cdict_data)
        assert self.comms.config.is_writing_provenance and self.comms.config.output_path is not None
        self.comms.log.info("Opening output files and processing predicted graph.")
        qg_writer = ProvenanceQuantumGraphWriter(
            self.comms.config.output_path,
            exit_stack=self.comms.exit_stack,
            log_on_close=LogOnClose(self.comms.log_progress),
            predicted=self.predicted,
            zstd_level=self.comms.config.zstd_level,
            cdict_data=cdict_data,
            loop_wrapper=self.comms.periodically_check_for_cancel,
            log=self.comms.log,
        )
        self.comms.check_for_cancel()
        self.comms.log.info("Compressing and writing queued scan requests.")
        for request in self.pending_compression_training:
            qg_writer.write_scan_data(request)
            self.comms.report_write()
        # Drop the queued scans so their memory can be reclaimed.  Rebind to
        # a fresh list rather than deleting the attribute: it is a declared
        # dataclass field, and removing it would make the generated
        # ``__repr__``/``__eq__`` raise `AttributeError`.
        self.pending_compression_training = []
        self.comms.check_for_cancel()
        self.comms.log.info("Writing overall inputs.")
        qg_writer.write_overall_inputs(self.comms.periodically_check_for_cancel)
        qg_writer.write_packages()
        self.comms.log.info("Returning to write request loop.")
        return qg_writer

    def make_compression_dictionary(self) -> zstandard.ZstdCompressionDict:
        """Make the compression dictionary.

        Returns
        -------
        cdict : `zstandard.ZstdCompressionDict`
            The compression dictionary.  This is built from empty data when
            ``zstd_dict_size`` is falsy or fewer than ``zstd_dict_n_inputs``
            scans have been accumulated.

        Notes
        -----
        When a dictionary is trained, the ``datastore_records`` of each
        predicted quantum used as a training sample are cleared in place.
        """
        if (
            not self.comms.config.zstd_dict_size
            or len(self.pending_compression_training) < self.comms.config.zstd_dict_n_inputs
        ):
            self.comms.log.info("Making compressor with no dictionary.")
            return zstandard.ZstdCompressionDict(b"")
        self.comms.log.info("Training compression dictionary.")
        training_inputs: list[ByteString] = []
        # We start the dictionary training with *predicted* quantum dataset
        # models, since those have almost all of the same attributes as the
        # provenance quantum and dataset models, and we can get a nice random
        # sample from just the first N, since they're ordered by UUID.  We
        # chop out the datastore records since those don't appear in the
        # provenance graph.
        for predicted_quantum in self.predicted.quantum_datasets.values():
            if len(training_inputs) == self.comms.config.zstd_dict_n_inputs:
                break
            predicted_quantum.datastore_records.clear()
            training_inputs.append(predicted_quantum.model_dump_json().encode())
        # Add the provenance quanta, metadata, and logs we've accumulated.
        for write_request in self.pending_compression_training:
            assert not write_request.is_compressed, "We can't compress without the compression dictionary."
            training_inputs.append(write_request.quantum)
            training_inputs.append(write_request.metadata)
            training_inputs.append(write_request.logs)
        return zstandard.train_dictionary(self.comms.config.zstd_dict_size, training_inputs)