Coverage for python/lsst/images/tests/_roundtrip.py: 28%
163 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 09:00 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 09:00 +0000
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14__all__ = ("RoundtripFits", "RoundtripJson", "RoundtripNdf", "TemporaryButler")
16import tempfile
17import unittest
18import uuid
19from abc import ABC, abstractmethod
20from contextlib import ExitStack
21from typing import TYPE_CHECKING, Any, Self, TypeVar
23import astropy.io.fits
24from pydantic_core import from_json
26if TYPE_CHECKING:
27 import h5py
29try:
30 from lsst.daf.butler import Butler, Config, DataCoordinate, DatasetProvenance, DatasetRef, DatasetType
32 HAVE_BUTLER = True
33except ImportError:
34 HAVE_BUTLER = False
36from .. import fits, json
37from .._generalized_image import GeneralizedImage
38from ..serialization import ArchiveTree, MetadataValue, ReadResult
40# We need an old-style TypeVar for Sphinx.
41T = TypeVar("T")
class TemporaryButler:
    """Make a temporary butler repository.

    Parameters
    ----------
    run
        Name of a `~lsst.daf.butler.CollectionType.RUN` collection to
        register and use as the default run for the returned butler.
    format
        Optional on-disk format name (``fits``, ``json``, ``sdf``,
        ``zarr``, ...) to bind to every storage class registered by
        ``**kwargs``.  When set, the datastore config is overlaid so that
        `~lsst.images.formatters.GenericFormatter` writes that format for
        those storage classes, overriding its ``.fits`` default.  Leave as
        `None` to keep the default formatter behaviour.
    **kwargs
        A mapping from a dataset type name to its storage class.  For each
        entry, a dataset type will be registered with empty dimensions, and a
        `~lsst.daf.butler.DatasetRef` will be created and added as an
        attribute of this instance (under the dataset type name).

    Raises
    ------
    unittest.SkipTest
        Raised when the context manager is entered if `lsst.daf.butler` could
        not be imported.  This is typically handled by using this context
        manager within a `unittest.TestCase.subTest` context, which will skip
        just the butler-required tests in that context while allowing the rest
        of the test to continue.

    Notes
    -----
    The repository lives in a temporary directory that is created when the
    context is entered and removed when it is exited.  The butler itself is
    available as the ``butler`` attribute while the context is active.  If
    setting up the repository fails part-way through, everything acquired so
    far is released before the exception propagates.
    """

    def __init__(self, run: str = "test_run", *, format: str | None = None, **kwargs: str):
        self.run = run
        self._format = format
        self._kwargs = kwargs
        self._exit_stack = ExitStack()

    def __enter__(self) -> Self:
        if not HAVE_BUTLER:
            raise unittest.SkipTest("lsst.daf.butler could not be imported.")
        self._exit_stack.__enter__()
        try:
            root = self._exit_stack.enter_context(
                tempfile.TemporaryDirectory(ignore_cleanup_errors=True, delete=True)
            )
            if self._format is not None:
                # Overlay a per-storage-class formatter binding so the default
                # FITS-writing GenericFormatter writes the requested format
                # instead.  Keyed by the storage class name (matched by the
                # daf_butler formatter factory).
                overlay = Config(
                    {
                        "datastore": {
                            "formatters": {
                                storage_class: {
                                    "formatter": "lsst.images.formatters.GenericFormatter",
                                    "parameters": {"format": self._format},
                                }
                                for storage_class in self._kwargs.values()
                            }
                        }
                    }
                )
                butler_config = Butler.makeRepo(root, config=overlay)
            else:
                butler_config = Butler.makeRepo(root)
            self.butler = self._exit_stack.enter_context(Butler.from_config(butler_config, run=self.run))
            empty_data_id = DataCoordinate.make_empty(self.butler.dimensions)
            for name, storage_class in self._kwargs.items():
                dataset_type = DatasetType(name, self.butler.dimensions.empty, storage_class)
                try:
                    self.butler.registry.registerDatasetType(dataset_type)
                except KeyError as err:
                    err.add_note(
                        "Storage class not configured in butler defaults. "
                        "A newer version of daf_butler may be needed."
                    )
                    raise
                setattr(self, name, DatasetRef(dataset_type, empty_data_id, self.run))
        except BaseException:
            # ``with`` never calls ``__exit__`` when ``__enter__`` raises, so
            # release the temporary directory and butler ourselves.
            self._exit_stack.close()
            raise
        return self

    def __exit__(self, *args: Any) -> bool | None:
        return self._exit_stack.__exit__(*args)

    # Just for typing, since this class uses dynamic attributes.  Only called
    # when normal attribute lookup has already failed.
    def __getattr__(self, name: str) -> DatasetRef:
        raise AttributeError(name)
132class RoundtripBase[T](ABC):
133 """A context manager for testing serialization.
135 Parameters
136 ----------
137 tc
138 A test case object to used for internal checks.
139 original
140 The object to serialize.
141 storage_class
142 A butler storage class name to use. If not provided (or
143 `lsst.daf.butler` cannot be imported), the roundtrip will just use
144 a direct write to a temporary file.
145 format
146 Archive/file format to use when not using a butler (ignored when
147 using a butler).
149 Notes
150 -----
151 When entered, this context manager writes the object and reads it back in
152 to the ``result`` attribute. When exited, any temporary files or
153 directories are deleted, but the ``result`` attribute is still usable.
154 In between the `inspect` and `get` methods can be used to perform other
155 tests.
157 This helper internally tests that butler provenance and metadata are saved
158 with any `.GeneralizedImage` object.
159 """
161 def __init__(
162 self,
163 tc: unittest.TestCase,
164 original: T,
165 storage_class: str | None = None,
166 ):
167 self._original = original
168 self._storage_class = storage_class
169 self._serialized: Any = None
170 self._exit_stack = ExitStack()
171 self._filename: str | None = None
172 self._tc = tc
173 self.result: Any
174 self.butler: Butler | None = None
175 self.ref: DatasetRef | None = None
176 self._test_metadata: dict[str, MetadataValue] = {
177 "roundtrip_test_1": 1,
178 "roundtrip_test_2": 2.5,
179 "roundtrip_test_3": "three",
180 "roundtrip_test_4": True,
181 "roundtrip_test_5": None,
182 }
184 def __enter__(self) -> Self:
185 self._exit_stack.__enter__()
186 if isinstance(self._original, GeneralizedImage):
187 self._original.metadata.update(self._test_metadata)
188 if HAVE_BUTLER and self._storage_class is not None:
189 self._run_with_butler()
190 else:
191 self._run_without_butler()
192 if isinstance(self._original, GeneralizedImage):
193 assert isinstance(self.result, GeneralizedImage)
194 for k in self._test_metadata:
195 self._tc.assertEqual(self.result.metadata[k], self._test_metadata[k])
196 del self._original.metadata[k]
197 del self.result.metadata[k]
198 return self
200 def __exit__(self, *args: Any) -> bool | None:
201 return self._exit_stack.__exit__(*args)
203 @property
204 def filename(self) -> str:
205 """The name of the file the object was written to."""
206 if self._filename is None:
207 assert self.butler is not None and self.ref is not None
208 self._filename = self.butler.getURI(self.ref).ospath
209 return self._filename
211 @property
212 def serialized(self) -> Any:
213 """The serialization model for this object
214 (`.serialization.ArchiveTree`).
215 """
216 if self._serialized is None:
217 # The butler code path doesn't give us a way to inspect the
218 # serialized model, so we have to save it again directly to another
219 # file (which we then discard).
220 with tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False, delete=True) as tmp:
221 tmp.close()
222 self._serialized = fits.write(self._original, tmp.name)
223 return self._serialized
225 def get(self, component: str | None = None, storageClass: str | None = None, **kwargs: Any) -> Any:
226 """Perform a partial read.
228 Parameters
229 ----------
230 component
231 Component to read instead of the main object. This requires the
232 roundtrip to use a butler, raising `unittest.SkipTest` otherwise;
233 this generally means these tests should be nested within a
234 `~unittest.TestCase.subTest` context.
235 storageClass
236 Override storage class name to affect the type returned by
237 the get. Only used if a butler is active.
238 **kwargs
239 Keyword arguments either passed directly to `.fits.read` or used
240 as ``parameters`` for a `~lsst.daf.butler.Butler.get`.
242 Return
243 ------
244 object
245 Result of the partial read.
246 """
247 if self.butler is None:
248 if component is not None:
249 raise unittest.SkipTest("Cannot test component reads without a butler.")
250 if storageClass is not None:
251 raise unittest.SkipTest("Cannot test storage class override without a butler")
252 result = fits.read(type(self._original), self.filename, **kwargs).deserialized
253 else:
254 assert self.ref is not None, "butler and ref should be None or not together"
255 ref = self.ref
256 if component is not None:
257 ref = ref.makeComponentRef(component)
258 result = self.butler.get(ref, parameters=kwargs, storageClass=storageClass)
259 if isinstance(result, GeneralizedImage):
260 # The metadata the RoundtripFits object added for the test may or
261 # may not be present; strip it if it does so comparisons to the
262 # original are not messed up.
263 for k in self._test_metadata:
264 result.metadata.pop(k, None)
265 return result
267 def _run_with_butler(self) -> None:
268 assert self._storage_class is not None, "Should not use butler if no storage class"
269 # ``GenericFormatter`` defaults to FITS; tell the temporary butler
270 # which format this Roundtrip variant wants so the on-disk file
271 # matches ``_get_extension()`` on the round-trip check below.
272 fmt = self._get_extension().lstrip(".")
273 butler_helper = self._exit_stack.enter_context(
274 TemporaryButler(test_dataset=self._storage_class, format=fmt)
275 )
276 self.butler = butler_helper.butler
277 quantum_id = uuid.uuid4()
278 self.ref = self.butler.put(
279 self._original, butler_helper.test_dataset, provenance=DatasetProvenance(quantum_id=quantum_id)
280 )
281 self.result = self.butler.get(self.ref)
282 if isinstance(self._original, GeneralizedImage):
283 self._tc.assertEqual(
284 DatasetRef.from_simple(self.result.butler_dataset, universe=self.butler.dimensions), self.ref
285 )
286 self._tc.assertEqual(self.result.butler_provenance.quantum_id, quantum_id)
287 self._tc.assertTrue(
288 self.filename.endswith(self._get_extension()),
289 f"{self.filename} did not end with {self._get_extension()}",
290 )
292 def _run_without_butler(self) -> None:
293 tmp = self._exit_stack.enter_context(
294 tempfile.NamedTemporaryFile(suffix=".fits", delete_on_close=False, delete=True)
295 )
296 tmp.close()
297 self._filename = tmp.name
298 self._serialized = self._write(self._original, tmp.name)
299 read_result = self._read(type(self._original), tmp.name)
300 self._tc.assertIsNone(read_result.butler_info)
301 self.result = read_result.deserialized
303 @abstractmethod
304 def _get_extension(self) -> str:
305 raise NotImplementedError()
307 @abstractmethod
308 def _write(self, obj: Any, filename: str) -> ArchiveTree:
309 raise NotImplementedError()
311 @abstractmethod
312 def _read(self, obj_type: Any, filename: str) -> ReadResult:
313 raise NotImplementedError()
316class RoundtripFits[T](RoundtripBase[T]):
317 def inspect(self) -> astropy.io.fits.HDUList:
318 """Open the FITS file with Astropy."""
319 return self._exit_stack.enter_context(
320 astropy.io.fits.open(self.filename, disable_image_compression=True)
321 )
323 def _get_extension(self) -> str:
324 return ".fits"
326 def _write(self, obj: Any, filename: str) -> ArchiveTree:
327 return fits.write(obj, filename)
329 def _read(self, obj_type: Any, filename: str) -> ReadResult:
330 return fits.read(obj_type, filename)
333class RoundtripJson[T](RoundtripBase[T]):
334 def inspect(self) -> dict[str, Any]:
335 """Read the JSON file as a dictionary."""
336 with open(self.filename, "rb") as stream:
337 return from_json(stream.read())
339 def _get_extension(self) -> str:
340 return ".json"
342 def _write(self, obj: Any, filename: str) -> ArchiveTree:
343 return json.write(obj, filename)
345 def _read(self, obj_type: Any, filename: str) -> ReadResult:
346 return json.read(obj_type, filename)
349class RoundtripNdf[T](RoundtripBase[T]):
350 def inspect(self) -> h5py.File:
351 """Open the NDF file with h5py."""
352 import h5py
354 return self._exit_stack.enter_context(h5py.File(self.filename, "r"))
356 def _get_extension(self) -> str:
357 return ".sdf"
359 def _write(self, obj: Any, filename: str) -> ArchiveTree:
360 from .. import ndf
362 return ndf.write(obj, filename)
364 def _read(self, obj_type: Any, filename: str) -> ReadResult:
365 from .. import ndf
367 return ndf.read(obj_type, filename)