Coverage for python/lsst/pipe/base/quantum_graph/aggregator/_writer.py: 25%

85 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 01:48 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("Writer",) 

31 

32import dataclasses 

33from collections.abc import ByteString 

34 

35import zstandard 

36 

37from ...log_on_close import LogOnClose 

38from ...pipeline_graph import TaskImportMode 

39from .._predicted import PredictedQuantumGraphComponents, PredictedQuantumGraphReader 

40from .._provenance import ProvenanceQuantumGraphWriter, ProvenanceQuantumScanData 

41from ._communicators import WriterCommunicator 

42 

43 

@dataclasses.dataclass
class Writer:
    """A helper class for the provenance aggregator that actually writes the
    provenance quantum graph file.

    Notes
    -----
    Constructing a `Writer` reads the predicted quantum graph from
    ``predicted_path`` (see `__post_init__`).  `run` is the entry point: it
    builds the writer inside the communicator's context and then calls
    `loop`.
    """

    predicted_path: str
    """Path to the predicted quantum graph."""

    comms: WriterCommunicator
    """Communicator object for this worker."""

    predicted: PredictedQuantumGraphComponents = dataclasses.field(init=False)
    """Components of the predicted quantum graph."""

    pending_compression_training: list[ProvenanceQuantumScanData] = dataclasses.field(default_factory=list)
    """Unprocessed quantum scans that are being accumulated in order to
    build a compression dictionary.
    """

    def __post_init__(self) -> None:
        """Read the predicted quantum graph components.

        The init quanta and quantum datasets are read from
        ``predicted_path`` without importing task classes, checking for
        cancellation before each read, and the reader's components are
        stored in ``predicted``.
        """
        assert self.comms.config.is_writing_provenance, "Writer should not be used if writing is disabled."
        self.comms.log.info("Reading predicted quantum graph.")
        with PredictedQuantumGraphReader.open(
            self.predicted_path, import_mode=TaskImportMode.DO_NOT_IMPORT
        ) as reader:
            self.comms.check_for_cancel()
            reader.read_init_quanta()
            self.comms.check_for_cancel()
            reader.read_quantum_datasets()
            self.predicted = reader.components

    @staticmethod
    def run(predicted_path: str, comms: WriterCommunicator) -> None:
        """Run the writer.

        Parameters
        ----------
        predicted_path : `str`
            Path to the predicted quantum graph.
        comms : `WriterCommunicator`
            Communicator for the writer.

        Notes
        -----
        This method is designed to run as the ``target`` in
        `WorkerFactory.make_worker`.
        """
        with comms:
            writer = Writer(predicted_path, comms)
            writer.loop()

    def loop(self) -> None:
        """Run the main loop for the writer.

        Notes
        -----
        When ``zstd_dict_size`` is falsy the low-level writer is opened
        immediately.  Otherwise incoming requests are queued in
        ``pending_compression_training`` until ``zstd_dict_n_inputs`` of
        them have arrived, at which point `make_qg_writer` trains the
        dictionary and writes the queued requests.  Requests that arrive
        after the writer exists are written (and reported) directly.  Init
        outputs are written last, once polling is finished.
        """
        qg_writer: ProvenanceQuantumGraphWriter | None = None
        if not self.comms.config.zstd_dict_size:
            qg_writer = self.make_qg_writer()
        self.comms.log.info("Polling for write requests from scanners.")
        for request in self.comms.poll():
            if qg_writer is None:
                # Still collecting dictionary-training inputs; the queued
                # request is written and reported by make_qg_writer.
                self.pending_compression_training.append(request)
                if len(self.pending_compression_training) >= self.comms.config.zstd_dict_n_inputs:
                    qg_writer = self.make_qg_writer()
            else:
                qg_writer.write_scan_data(request)
                self.comms.report_write()
        if qg_writer is None:
            # Polling ended before enough training inputs arrived; open the
            # writer now so that everything queued so far is still written.
            qg_writer = self.make_qg_writer()
        self.comms.log.info("Writing init outputs.")
        qg_writer.write_init_outputs(assume_existence=False)

    def make_qg_writer(self) -> ProvenanceQuantumGraphWriter:
        """Make a compression dictionary, open the low-level writers, and
        write any accumulated scans that were needed to make the compression
        dictionary.

        Returns
        -------
        qg_writer : `ProvenanceQuantumGraphWriter`
            Low-level writers struct.

        Notes
        -----
        The compression dictionary bytes are also sent through the
        communicator.  On return ``pending_compression_training`` is empty.
        """
        cdict = self.make_compression_dictionary()
        # Serialize once; the same bytes go to the communicator and to the
        # low-level writer.
        cdict_data = cdict.as_bytes()
        self.comms.send_compression_dict(cdict_data)
        assert self.comms.config.is_writing_provenance and self.comms.config.output_path is not None
        self.comms.log.info("Opening output files and processing predicted graph.")
        qg_writer = ProvenanceQuantumGraphWriter(
            self.comms.config.output_path,
            exit_stack=self.comms.exit_stack,
            log_on_close=LogOnClose(self.comms.log_progress),
            predicted=self.predicted,
            zstd_level=self.comms.config.zstd_level,
            cdict_data=cdict_data,
            loop_wrapper=self.comms.periodically_check_for_cancel,
            log=self.comms.log,
        )
        self.comms.check_for_cancel()
        self.comms.log.info("Compressing and writing queued scan requests.")
        for request in self.pending_compression_training:
            qg_writer.write_scan_data(request)
            self.comms.report_write()
        # Drop the queued scans so their memory can be reclaimed.  Rebind to
        # a fresh list rather than deleting the attribute: removing a declared
        # dataclass field would make the generated __repr__/__eq__ (and any
        # later access to the field) raise AttributeError.
        self.pending_compression_training = []
        self.comms.check_for_cancel()
        self.comms.log.info("Writing overall inputs.")
        qg_writer.write_overall_inputs(self.comms.periodically_check_for_cancel)
        qg_writer.write_packages()
        self.comms.log.info("Returning to write request loop.")
        return qg_writer

    def make_compression_dictionary(self) -> zstandard.ZstdCompressionDict:
        """Make the compression dictionary.

        Returns
        -------
        cdict : `zstandard.ZstdCompressionDict`
            The compression dictionary.  This is built from empty data when
            ``zstd_dict_size`` is falsy or when fewer than
            ``zstd_dict_n_inputs`` scans have been accumulated.

        Notes
        -----
        Training clears ``datastore_records`` in place on each predicted
        quantum that is used as a training input.
        """
        if (
            not self.comms.config.zstd_dict_size
            or len(self.pending_compression_training) < self.comms.config.zstd_dict_n_inputs
        ):
            self.comms.log.info("Making compressor with no dictionary.")
            return zstandard.ZstdCompressionDict(b"")
        self.comms.log.info("Training compression dictionary.")
        # NOTE(review): collections.abc.ByteString is deprecated since
        # Python 3.12; consider annotating with a concrete buffer type once
        # the payload types of the scan data are confirmed.
        training_inputs: list[ByteString] = []
        # We start the dictionary training with *predicted* quantum dataset
        # models, since those have almost all of the same attributes as the
        # provenance quantum and dataset models, and we can get a nice random
        # sample from just the first N, since they're ordered by UUID.  We
        # chop out the datastore records since those don't appear in the
        # provenance graph.
        for predicted_quantum in self.predicted.quantum_datasets.values():
            if len(training_inputs) == self.comms.config.zstd_dict_n_inputs:
                break
            predicted_quantum.datastore_records.clear()
            training_inputs.append(predicted_quantum.model_dump_json().encode())
        # Add the provenance quanta, metadata, and logs we've accumulated.
        for write_request in self.pending_compression_training:
            assert not write_request.is_compressed, "We can't compress without the compression dictionary."
            training_inputs.append(write_request.quantum)
            training_inputs.append(write_request.metadata)
            training_inputs.append(write_request.logs)
        return zstandard.train_dictionary(self.comms.config.zstd_dict_size, training_inputs)