Coverage for python/lsst/images/tests/_roundtrip.py: 28%

163 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-29 08:40 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14__all__ = ("RoundtripFits", "RoundtripJson", "RoundtripNdf", "TemporaryButler") 

15 

16import tempfile 

17import unittest 

18import uuid 

19from abc import ABC, abstractmethod 

20from contextlib import ExitStack 

21from typing import TYPE_CHECKING, Any, Self, TypeVar 

22 

23import astropy.io.fits 

24from pydantic_core import from_json 

25 

26if TYPE_CHECKING: 

27 import h5py 

28 

29try: 

30 from lsst.daf.butler import Butler, Config, DataCoordinate, DatasetProvenance, DatasetRef, DatasetType 

31 

32 HAVE_BUTLER = True 

33except ImportError: 

34 HAVE_BUTLER = False 

35 

36from .. import fits, json 

37from .._generalized_image import GeneralizedImage 

38from ..serialization import ArchiveTree, MetadataValue, ReadResult 

39 

# We need an old-style TypeVar for Sphinx.  The classes in this module use
# PEP 695 ``class Name[T]`` syntax, which declares its own class-scoped ``T``;
# this module-level ``T`` exists for the documentation build, not for that
# syntax.
T = TypeVar("T")

42 

43 

class TemporaryButler:
    """Make a temporary butler repository.

    Parameters
    ----------
    run
        Name of a `~lsst.daf.butler.CollectionType.RUN` collection to
        register and use as the default run for the returned butler.
    format
        Optional on-disk format name (``fits``, ``json``, ``sdf``,
        ``zarr``, ...) to bind to every storage class registered by
        ``**kwargs``. When set, the datastore config is overlaid so that
        `~lsst.images.formatters.GenericFormatter` writes that format for
        those storage classes, overriding its ``.fits`` default. Leave as
        `None` to keep the default formatter behaviour.
    **kwargs
        A mapping from a dataset type name to its storage class. For each
        entry, a dataset type will be registered with empty dimensions, and a
        `~lsst.daf.butler.DatasetRef` will be created and added as an
        attribute of this class.

    Raises
    ------
    unittest.SkipTest
        Raised when the context manager is entered if `lsst.daf.butler` could
        not be imported. This is typically handled by using this context
        manager within a `unittest.TestCase.subTest` context, which will skip
        just the butler-required tests in that context while allowing the rest
        of the test to continue.

    Notes
    -----
    If repository setup fails partway through entering the context (for
    example because a storage class is not known to the butler), everything
    already created (the temporary directory and the butler) is cleaned up
    before the exception propagates.
    """

    def __init__(self, run: str = "test_run", *, format: str | None = None, **kwargs: str):
        self.run = run
        self._format = format
        self._kwargs = kwargs
        self._exit_stack = ExitStack()

    def __enter__(self) -> TemporaryButler:
        if not HAVE_BUTLER:
            raise unittest.SkipTest("lsst.daf.butler could not be imported.")
        # ``__exit__`` is never called when ``__enter__`` raises, so register
        # cleanups on a local stack and only hand them over to
        # ``self._exit_stack`` once setup has fully succeeded; on failure the
        # local stack is unwound right here.
        with ExitStack() as stack:
            root = stack.enter_context(
                tempfile.TemporaryDirectory(ignore_cleanup_errors=True, delete=True)
            )
            if self._format is not None:
                butler_config = Butler.makeRepo(root, config=self._make_formatter_overlay())
            else:
                butler_config = Butler.makeRepo(root)
            self.butler = stack.enter_context(Butler.from_config(butler_config, run=self.run))
            self._register_dataset_types()
            self._exit_stack = stack.pop_all()
        return self

    def __exit__(self, *args: Any) -> bool | None:
        return self._exit_stack.__exit__(*args)

    def _make_formatter_overlay(self) -> Config:
        """Build a datastore config overlay that binds ``GenericFormatter``
        with the requested ``format`` parameter to every requested storage
        class, so it writes that format instead of its FITS default.
        """
        # Keyed by the storage class name (matched by the daf_butler
        # formatter factory).
        formatters = {
            storage_class: {
                "formatter": "lsst.images.formatters.GenericFormatter",
                "parameters": {"format": self._format},
            }
            for storage_class in self._kwargs.values()
        }
        return Config({"datastore": {"formatters": formatters}})

    def _register_dataset_types(self) -> None:
        """Register one empty-dimension dataset type per ``**kwargs`` entry and
        expose a matching `~lsst.daf.butler.DatasetRef` as an attribute named
        after the dataset type.
        """
        empty_data_id = DataCoordinate.make_empty(self.butler.dimensions)
        for name, storage_class in self._kwargs.items():
            dataset_type = DatasetType(name, self.butler.dimensions.empty, storage_class)
            try:
                self.butler.registry.registerDatasetType(dataset_type)
            except KeyError as err:
                err.add_note(
                    "Storage class not configured in butler defaults. "
                    "A newer version of daf_butler may be needed."
                )
                raise
            setattr(self, name, DatasetRef(dataset_type, empty_data_id, self.run))

    # Just for typing, since this class uses dynamic attributes: it tells type
    # checkers that unknown attributes are DatasetRefs.  Python only calls it
    # for attributes that do not exist, so it always raises.
    def __getattr__(self, name: str) -> DatasetRef:
        raise AttributeError(name)

130 

131 

132class RoundtripBase[T](ABC): 

133 """A context manager for testing serialization. 

134 

135 Parameters 

136 ---------- 

137 tc 

138 A test case object to used for internal checks. 

139 original 

140 The object to serialize. 

141 storage_class 

142 A butler storage class name to use. If not provided (or 

143 `lsst.daf.butler` cannot be imported), the roundtrip will just use 

144 a direct write to a temporary file. 

145 format 

146 Archive/file format to use when not using a butler (ignored when 

147 using a butler). 

148 

149 Notes 

150 ----- 

151 When entered, this context manager writes the object and reads it back in 

152 to the ``result`` attribute. When exited, any temporary files or 

153 directories are deleted, but the ``result`` attribute is still usable. 

154 In between the `inspect` and `get` methods can be used to perform other 

155 tests. 

156 

157 This helper internally tests that butler provenance and metadata are saved 

158 with any `.GeneralizedImage` object. 

159 """ 

160 

161 def __init__( 

162 self, 

163 tc: unittest.TestCase, 

164 original: T, 

165 storage_class: str | None = None, 

166 ): 

167 self._original = original 

168 self._storage_class = storage_class 

169 self._serialized: Any = None 

170 self._exit_stack = ExitStack() 

171 self._filename: str | None = None 

172 self._tc = tc 

173 self.result: Any 

174 self.butler: Butler | None = None 

175 self.ref: DatasetRef | None = None 

176 self._test_metadata: dict[str, MetadataValue] = { 

177 "roundtrip_test_1": 1, 

178 "roundtrip_test_2": 2.5, 

179 "roundtrip_test_3": "three", 

180 "roundtrip_test_4": True, 

181 "roundtrip_test_5": None, 

182 } 

183 

184 def __enter__(self) -> Self: 

185 self._exit_stack.__enter__() 

186 if isinstance(self._original, GeneralizedImage): 

187 self._original.metadata.update(self._test_metadata) 

188 if HAVE_BUTLER and self._storage_class is not None: 

189 self._run_with_butler() 

190 else: 

191 self._run_without_butler() 

192 if isinstance(self._original, GeneralizedImage): 

193 assert isinstance(self.result, GeneralizedImage) 

194 for k in self._test_metadata: 

195 self._tc.assertEqual(self.result.metadata[k], self._test_metadata[k]) 

196 del self._original.metadata[k] 

197 del self.result.metadata[k] 

198 return self 

199 

200 def __exit__(self, *args: Any) -> bool | None: 

201 return self._exit_stack.__exit__(*args) 

202 

203 @property 

204 def filename(self) -> str: 

205 """The name of the file the object was written to.""" 

206 if self._filename is None: 

207 assert self.butler is not None and self.ref is not None 

208 self._filename = self.butler.getURI(self.ref).ospath 

209 return self._filename 

210 

211 @property 

212 def serialized(self) -> Any: 

213 """The serialization model for this object 

214 (`.serialization.ArchiveTree`). 

215 """ 

216 if self._serialized is None: 

217 # The butler code path doesn't give us a way to inspect the 

218 # serialized model, so we have to save it again directly to another 

219 # file (which we then discard). 

220 with tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False, delete=True) as tmp: 

221 tmp.close() 

222 self._serialized = fits.write(self._original, tmp.name) 

223 return self._serialized 

224 

225 def get(self, component: str | None = None, storageClass: str | None = None, **kwargs: Any) -> Any: 

226 """Perform a partial read. 

227 

228 Parameters 

229 ---------- 

230 component 

231 Component to read instead of the main object. This requires the 

232 roundtrip to use a butler, raising `unittest.SkipTest` otherwise; 

233 this generally means these tests should be nested within a 

234 `~unittest.TestCase.subTest` context. 

235 storageClass 

236 Override storage class name to affect the type returned by 

237 the get. Only used if a butler is active. 

238 **kwargs 

239 Keyword arguments either passed directly to `.fits.read` or used 

240 as ``parameters`` for a `~lsst.daf.butler.Butler.get`. 

241 

242 Return 

243 ------ 

244 object 

245 Result of the partial read. 

246 """ 

247 if self.butler is None: 

248 if component is not None: 

249 raise unittest.SkipTest("Cannot test component reads without a butler.") 

250 if storageClass is not None: 

251 raise unittest.SkipTest("Cannot test storage class override without a butler") 

252 result = fits.read(type(self._original), self.filename, **kwargs).deserialized 

253 else: 

254 assert self.ref is not None, "butler and ref should be None or not together" 

255 ref = self.ref 

256 if component is not None: 

257 ref = ref.makeComponentRef(component) 

258 result = self.butler.get(ref, parameters=kwargs, storageClass=storageClass) 

259 if isinstance(result, GeneralizedImage): 

260 # The metadata the RoundtripFits object added for the test may or 

261 # may not be present; strip it if it does so comparisons to the 

262 # original are not messed up. 

263 for k in self._test_metadata: 

264 result.metadata.pop(k, None) 

265 return result 

266 

267 def _run_with_butler(self) -> None: 

268 assert self._storage_class is not None, "Should not use butler if no storage class" 

269 # ``GenericFormatter`` defaults to FITS; tell the temporary butler 

270 # which format this Roundtrip variant wants so the on-disk file 

271 # matches ``_get_extension()`` on the round-trip check below. 

272 fmt = self._get_extension().lstrip(".") 

273 butler_helper = self._exit_stack.enter_context( 

274 TemporaryButler(test_dataset=self._storage_class, format=fmt) 

275 ) 

276 self.butler = butler_helper.butler 

277 quantum_id = uuid.uuid4() 

278 self.ref = self.butler.put( 

279 self._original, butler_helper.test_dataset, provenance=DatasetProvenance(quantum_id=quantum_id) 

280 ) 

281 self.result = self.butler.get(self.ref) 

282 if isinstance(self._original, GeneralizedImage): 

283 self._tc.assertEqual( 

284 DatasetRef.from_simple(self.result.butler_dataset, universe=self.butler.dimensions), self.ref 

285 ) 

286 self._tc.assertEqual(self.result.butler_provenance.quantum_id, quantum_id) 

287 self._tc.assertTrue( 

288 self.filename.endswith(self._get_extension()), 

289 f"{self.filename} did not end with {self._get_extension()}", 

290 ) 

291 

292 def _run_without_butler(self) -> None: 

293 tmp = self._exit_stack.enter_context( 

294 tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False, delete=True) 

295 ) 

296 tmp.close() 

297 self._filename = tmp.name 

298 self._serialized = self._write(self._original, tmp.name) 

299 read_result = self._read(type(self._original), tmp.name) 

300 self._tc.assertIsNone(read_result.butler_info) 

301 self.result = read_result.deserialized 

302 

303 @abstractmethod 

304 def _get_extension(self) -> str: 

305 raise NotImplementedError() 

306 

307 @abstractmethod 

308 def _write(self, obj: Any, filename: str) -> ArchiveTree: 

309 raise NotImplementedError() 

310 

311 @abstractmethod 

312 def _read(self, obj_type: Any, filename: str) -> ReadResult: 

313 raise NotImplementedError() 

314 

315 

316class RoundtripFits[T](RoundtripBase[T]): 

317 def inspect(self) -> astropy.io.fits.HDUList: 

318 """Open the FITS file with Astropy.""" 

319 return self._exit_stack.enter_context( 

320 astropy.io.fits.open(self.filename, disable_image_compression=True) 

321 ) 

322 

323 def _get_extension(self) -> str: 

324 return ".fits" 

325 

326 def _write(self, obj: Any, filename: str) -> ArchiveTree: 

327 return fits.write(obj, filename) 

328 

329 def _read(self, obj_type: Any, filename: str) -> ReadResult: 

330 return fits.read(obj_type, filename) 

331 

332 

333class RoundtripJson[T](RoundtripBase[T]): 

334 def inspect(self) -> dict[str, Any]: 

335 """Read the JSON file as a dictionary.""" 

336 with open(self.filename, "rb") as stream: 

337 return from_json(stream.read()) 

338 

339 def _get_extension(self) -> str: 

340 return ".json" 

341 

342 def _write(self, obj: Any, filename: str) -> ArchiveTree: 

343 return json.write(obj, filename) 

344 

345 def _read(self, obj_type: Any, filename: str) -> ReadResult: 

346 return json.read(obj_type, filename) 

347 

348 

349class RoundtripNdf[T](RoundtripBase[T]): 

350 def inspect(self) -> h5py.File: 

351 """Open the NDF file with h5py.""" 

352 import h5py 

353 

354 return self._exit_stack.enter_context(h5py.File(self.filename, "r")) 

355 

356 def _get_extension(self) -> str: 

357 return ".sdf" 

358 

359 def _write(self, obj: Any, filename: str) -> ArchiveTree: 

360 from .. import ndf 

361 

362 return ndf.write(obj, filename) 

363 

364 def _read(self, obj_type: Any, filename: str) -> ReadResult: 

365 from .. import ndf 

366 

367 return ndf.read(obj_type, filename)