Coverage for python / lsst / images / psfs / _piff.py: 37%

176 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-15 01:53 -0700

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14from lsst.images.serialization import ArchiveReadError 

15 

16__all__ = ("PiffSerializationModel", "PiffWrapper") 

17 

18import operator 

19from collections.abc import Iterator 

20from contextlib import contextmanager 

21from functools import cached_property 

22from logging import getLogger 

23from typing import TYPE_CHECKING, Annotated, Any, Literal 

24 

25import astropy.io.fits 

26import numpy as np 

27import pydantic 

28 

29from .. import serialization 

30from .._concrete_bounds import SerializableBounds 

31from .._geom import Bounds, Box 

32from .._image import Image 

33from ..utils import round_half_up 

34from ._base import PointSpreadFunction 

35 

36if TYPE_CHECKING: 

37 import galsim.wcs 

38 import piff.config 

39 

40_LOG = getLogger(__name__) 

41 

42 

class PiffWrapper(PointSpreadFunction):
    """A PSF model backed by the Piff library.

    Parameters
    ----------
    impl
        The Piff PSF object to wrap.
    bounds
        The pixel-coordinate region where the model can safely be evaluated.
    stamp_size
        Width and height, in pixels, of the square images drawn by
        `compute_kernel_image` and `compute_stellar_image`.
    """

    def __init__(self, impl: piff.PSF, bounds: Bounds, stamp_size: int):
        self._impl = impl
        self._bounds = bounds
        self._stamp_size = stamp_size

    @property
    def bounds(self) -> Bounds:
        return self._bounds

    @cached_property
    def kernel_bbox(self) -> Box:
        # Exactly the extent of the images returned by `compute_kernel_image`
        # (``stamp_size`` pixels starting at ``-r``). For odd stamp sizes this
        # is the symmetric box [-r, r]; deriving it from the stamp size keeps
        # it consistent with the image for even sizes too.
        r = self._stamp_size // 2
        return self._stamp_box(-r, -r)

    def compute_kernel_image(self, *, x: float, y: float) -> Image:
        self._require_achromatic()
        # ``center=True`` asks Piff to draw the profile centered in the stamp.
        gs_image = self._impl.draw(x, y, stamp_size=self._stamp_size, center=True)
        r = self._stamp_size // 2
        result = Image(gs_image.array.copy(), start=(-r, -r))
        # Normalize to unit sum.
        result.array /= np.sum(result.array)
        return result

    def compute_stellar_image(self, *, x: float, y: float) -> Image:
        self._require_achromatic()
        # ``center=None`` leaves the profile at the star's own (x, y) position
        # within a stamp around the nearest pixel.
        gs_image = self._impl.draw(x, y, stamp_size=self._stamp_size, center=None)
        r = self._stamp_size // 2
        result = Image(gs_image.array.copy(), start=(round_half_up(y) - r, round_half_up(x) - r))
        # Normalize to unit sum.
        result.array /= np.sum(result.array)
        return result

    def compute_stellar_bbox(self, *, x: float, y: float) -> Box:
        # Same extent as the image returned by `compute_stellar_image`.
        r = self._stamp_size // 2
        return self._stamp_box(round_half_up(y) - r, round_half_up(x) - r)

    @property
    def piff_psf(self) -> Any:
        """The backing `piff.PSF` object.

        This is an internal object that must not be modified in place.
        """
        return self._impl

    @classmethod
    def from_legacy(cls, legacy_psf: Any, bounds: Bounds) -> PiffWrapper:
        """Wrap the Piff PSF held by a legacy PSF object.

        Parameters
        ----------
        legacy_psf
            Legacy PSF object. It must expose the wrapped Piff PSF as
            ``_piffResult`` and the stamp width as ``width`` (presumably an
            ``lsst.meas.extensions.piff`` ``PiffPsf``; confirm with callers).
        bounds
            The pixel-coordinate region where the model can safely be
            evaluated.
        """
        return cls(impl=legacy_psf._piffResult, bounds=bounds, stamp_size=int(legacy_psf.width))

    def serialize(self, archive: serialization.OutputArchive[Any]) -> PiffSerializationModel:
        """Serialize the PSF to an archive.

        This method is intended to be usable as the callback function passed to
        `.serialization.OutputArchive.serialize_direct` or
        `.serialization.OutputArchive.serialize_pointer`.
        """
        from piff.config import PiffLogger

        writer = _ArchivePiffWriter()
        # Piff writes into ``writer``; the archive is touched only afterwards.
        with self._without_stars():
            self._impl._write(writer, "piff", PiffLogger(_LOG))
        piff_model = writer.serialize(archive)
        return PiffSerializationModel(
            piff=piff_model,
            stamp_size=self._stamp_size,
            bounds=self._bounds.serialize(),
        )

    @staticmethod
    def _get_archive_tree_type(
        pointer_type: type[pydantic.BaseModel],
    ) -> type[PiffSerializationModel]:
        """Return the serialization model type for this object for an archive
        type that uses the given pointer type.
        """
        return PiffSerializationModel

    @contextmanager
    def _without_stars(self) -> Iterator[None]:
        """Temporarily drop the embedded list of stars used to fit the PSF.

        Notes
        -----
        By default Piff saves the list of stars (including postage stamps) used
        to fit the PSF, which makes the serialized form much larger. But the
        upstream Piff serialization code recognizes the case where that
        ``stars`` attribute has been deleted and serializes everything else.

        Unfortunately, to date, Rubin's pickle-based Piff serialization instead
        just deletes the postage stamp image attributes from inside the Piff
        ``stars`` list, which is not a state the Piff serialization code
        handles gracefully. So for now we have to drop the full stars list
        during serialization if it is present.
        """
        if hasattr(self._impl, "stars"):
            stars = self._impl.stars
            try:
                del self._impl.stars
                yield
            finally:
                # Always restore, even if serialization raised.
                self._impl.stars = stars
        else:
            yield

    def _require_achromatic(self) -> None:
        """Raise `NotImplementedError` if the wrapped PSF depends on color."""
        if "colorValue" in self._impl.interp_property_names:
            raise NotImplementedError("Chromatic PSFs are not yet supported.")

    def _stamp_box(self, y_start: int, x_start: int) -> Box:
        """Return the box covering one ``stamp_size``-square stamp whose first
        pixel is at (``y_start``, ``x_start``).
        """
        size = self._stamp_size
        return Box.factory[y_start : y_start + size, x_start : x_start + size]

158 

159 

160# Conventions on public visibility of the serialization types: 

161# 

162# - We lift and document the outermost Pydantic model type, since that needs to 

163# be included directly in the Pydantic models of types that hold a PSF. This 

164# type needs to be very clearly documented and named as a *serialization* 

165# model, since there are many other kinds of models in play in this package. 

166# 

167# - We do not lift or document types used in that outermost model, but we do 

168# not give them leading underscores, since they aren't really private. 

169# 

170# - Other utility types do get leading underscores. 

171 

172 

# Piff serialization uses a lot of dictionaries and lists restricted to these
# basic types.
#
# NOTE: the order of the union members is kept as-is; pydantic may consult it
# when validating values against these aliases.
type PiffScalar = int | float | str | bool | None
# A flat mapping whose values are scalars or lists of scalars (no nested
# dictionaries).
type PiffDict = dict[str, PiffScalar | list[PiffScalar]]

177 

178 

class GalSimPixelScaleModel(pydantic.BaseModel):
    """Model used to serialize `galsim.wcs.PixelScale` instances."""

    model_config = pydantic.ConfigDict(ser_json_inf_nan="constants")

    # Pixel scale passed to `galsim.wcs.PixelScale`.
    scale: float
    # Discriminator tag identifying this member of `GalSimLocalWcsModel`.
    wcs_type: Literal["pixel_scale"] = "pixel_scale"

184 

185 

# We expect this discriminated union to grow to include other trivial
# pixel-to-pixel transforms that get embedded in PSFs. If we someday have to
# store Piff objects that embed more sophisticated WCSs, we'll hook them into
# the AST-based coordinate transform system instead, but as long as we're just
# talking about simple offsets and scalings, that's a lot of extra complexity
# for very little gain.
#
# Members are distinguished by their ``wcs_type`` literal field; only
# `GalSimPixelScaleModel` (``wcs_type="pixel_scale"``) exists so far.
type GalSimLocalWcsModel = Annotated[GalSimPixelScaleModel, pydantic.Field(discriminator="wcs_type")]

193 

194 

class PiffTableModel(pydantic.BaseModel):
    """Serialization model that stands in for a binary-data table inside a
    Piff serialization's JSON-like data.

    The table itself lives in the archive; this model pairs a reference to it
    with the Piff metadata that accompanied it.
    """

    model_config = pydantic.ConfigDict(ser_json_inf_nan="constants")

    # Metadata dictionary Piff supplied alongside the table.
    metadata: PiffDict
    # Archive reference to the stored table.
    table: serialization.TableModel

202 

203 

class PiffObjectModel(pydantic.BaseModel):
    """Generic serialization model shared by all kinds of Piff objects.

    Each field maps the names Piff used while writing the object to the
    corresponding serialized piece; nested Piff objects recurse through
    ``objects``. Any field that is empty is left out of serialized output.
    """

    model_config = pydantic.ConfigDict(ser_json_inf_nan="constants")

    # Plain key/value structs.
    structs: dict[str, PiffDict] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)
    # References to binary tables stored in the archive.
    tables: dict[str, PiffTableModel] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)
    # Simple embedded GalSim WCSs.
    wcs: dict[str, GalSimLocalWcsModel] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)
    # Nested Piff objects.
    objects: dict[str, PiffObjectModel] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)

211 

212 

class PiffSerializationModel(serialization.ArchiveTree):
    """Serialization model for a Piff PSF."""

    piff: PiffObjectModel = pydantic.Field(description="The Piff PSF object itself.")

    stamp_size: int = pydantic.Field(
        description="Width of the (square) images returned by this PSF's methods."
    )

    bounds: SerializableBounds = pydantic.Field(
        description="The bounds object that represents the PSF's validity region."
    )

    def deserialize(self, archive: serialization.InputArchive[Any]) -> PiffWrapper:
        """Deserialize the PSF from an archive.

        This method is intended to be usable as the callback function passed to
        `.serialization.InputArchive.deserialize_pointer`.

        Parameters
        ----------
        archive
            Archive from which the binary tables referenced by ``piff`` are
            read.

        Returns
        -------
        PiffWrapper
            The reconstructed PSF.

        Raises
        ------
        ArchiveReadError
            Raised if the ``piff`` package cannot be imported; the original
            `ImportError` is attached as the cause.
        """
        try:
            from piff import PSF
            from piff.config import PiffLogger
        except ImportError as err:
            # Chain rather than suppress: the underlying error says *why* the
            # import failed (e.g. piff installed but a dependency missing).
            raise ArchiveReadError("Failed to import piff.") from err

        reader = _ArchivePiffReader(self.piff, archive)
        impl = PSF._read(reader, "piff", PiffLogger(_LOG))
        return PiffWrapper(impl, bounds=self.bounds.deserialize(), stamp_size=self.stamp_size)

241 

242 

243class _ArchivePiffWriter: 

244 """An adapter from the Piff serialization interface to the 

245 `.serialization.OutputArchive` class. 

246 

247 Notes 

248 ----- 

249 Piff has its own simple serialization framework (contributed upstream by 

250 Rubin DM) that maps everything to dictionaries, structured numpy arrays, 

251 and a library of GalSim WCS objects, with the native implementation writing 

252 standalone FITS files. That mostly maps nicely to the `lsst.images` 

253 archive system, but we don't get to leverage any Pydantic validation or 

254 JSON schema functionality since we only get opaque dictionaries from Piff. 

255 

256 See `piff.FitsWriter` for most method documentation; this class is designed 

257 to mimic it exactly (the Piff authors prefer to just use duck-typing rather 

258 than ABCs or protocols for interface definition). 

259 """ 

260 

261 def __init__(self, base_name: str = ""): 

262 self._base_name = base_name 

263 self.structs: dict[str, PiffDict] = {} 

264 self.tables: dict[str, tuple[np.ndarray, PiffDict]] = {} 

265 self.wcs_models: dict[str, GalSimLocalWcsModel] = {} 

266 self.writers: dict[str, _ArchivePiffWriter] = {} 

267 

268 def write_struct(self, name: str, struct: PiffDict) -> None: 

269 self.structs[name] = {k: self._to_builtin(v) for k, v in struct.items()} 

270 

271 def write_table(self, name: str, array: np.ndarray, metadata: PiffDict | None = None) -> None: 

272 self.tables[name] = (array, metadata or {}) 

273 

274 def write_wcs_map( 

275 self, name: str, wcs_map: dict[int, galsim.wcs.BaseWCS], pointing: galsim.CelestialCoord | None 

276 ) -> None: 

277 import galsim.wcs 

278 

279 match wcs_map: 

280 case {0: galsim.wcs.PixelScale() as wcs} if pointing is None: 

281 self.wcs_models[name] = GalSimPixelScaleModel(scale=wcs.scale) 

282 case _: 

283 raise NotImplementedError("PSFs with complex embedded WCSs are not supported.") 

284 

285 @contextmanager 

286 def nested(self, name: str) -> Iterator[_ArchivePiffWriter]: 

287 nested = _ArchivePiffWriter(self.get_full_name(name)) 

288 yield nested 

289 self.writers[name] = nested 

290 

291 def get_full_name(self, name: str) -> str: 

292 return f"{self._base_name}/{name}" 

293 

294 def serialize(self, archive: serialization.OutputArchive[Any]) -> PiffObjectModel: 

295 """Serialize to an archive. 

296 

297 This method is intended to be used as the callable passed to 

298 `.serialization.OutputArchive.serialize_direct` and 

299 `.serialization.OutputArchive.serialize_pointer`, after first passing 

300 this writer to a Piff object's ``write`` or ``_write`` method. 

301 """ 

302 model = PiffObjectModel() 

303 for name, struct in self.structs.items(): 

304 model.structs[name] = struct 

305 for name, (array, metadata) in self.tables.items(): 

306 model.tables[name] = PiffTableModel( 

307 metadata=metadata, 

308 table=archive.add_structured_array( 

309 array, name=name, update_header=lambda header: header.update(metadata) 

310 ), 

311 ) 

312 for name, wcs_model in self.wcs_models.items(): 

313 model.wcs[name] = wcs_model 

314 for name, writer in self.writers.items(): 

315 model.objects[name] = archive.serialize_direct(name, writer.serialize) 

316 return model 

317 

318 @staticmethod 

319 def _to_builtin(val: Any) -> PiffScalar: 

320 match val: 

321 case np.integer(): 

322 return int(val) 

323 case np.floating(): 

324 return float(val) 

325 case np.str_(): 

326 return str(val) 

327 return val 

328 

329 

330class _ArchivePiffReader: 

331 """An adapter from the Piff serialization interface to the 

332 `.serialization.InputArchive` class. 

333 

334 See `ArchivePiffWriter` for additional notes. 

335 """ 

336 

337 def __init__( 

338 self, object_model: PiffObjectModel, archive: serialization.InputArchive[Any], base_name: str = "" 

339 ): 

340 self._model = object_model 

341 self._archive = archive 

342 self._base_name = base_name 

343 

344 def read_struct(self, name: str) -> PiffDict | None: 

345 return self._model.structs.get(name) 

346 

347 def read_table(self, name: str, metadata: PiffDict | None = None) -> np.ndarray | None: 

348 table_model = self._model.tables.get(name) 

349 if table_model is None: 

350 return None 

351 if metadata is not None: 

352 metadata.update(table_model.metadata) 

353 return self._archive.get_structured_array( 

354 table_model.table, strip_header=astropy.io.fits.Header.clear 

355 ) 

356 

357 def read_wcs_map( 

358 self, name: str, logger: piff.config.LoggerWrapper 

359 ) -> tuple[dict[int, galsim.wcs.BaseWCS] | None, galsim.CelestialCoord | None]: 

360 import galsim.wcs 

361 

362 match self._model.wcs.get(name): 

363 case GalSimPixelScaleModel(scale=scale): 

364 return {0: galsim.wcs.PixelScale(scale)}, None 

365 case None: 

366 return None, None 

367 case unexpected: 

368 raise serialization.ArchiveReadError( 

369 f"{self.get_full_name(name)} should be a WCS or WCS map, not {unexpected!r}." 

370 ) 

371 

372 @contextmanager 

373 def nested(self, name: str) -> Iterator[_ArchivePiffReader]: 

374 nested_model = self._model.objects[name] 

375 yield _ArchivePiffReader(nested_model, self._archive, self.get_full_name(name)) 

376 

377 def get_full_name(self, name: str) -> str: 

378 return f"{self._base_name}/{name}"