Coverage for python / lsst / images / psfs / _piff.py: 35%

183 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-20 08:29 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14from lsst.images.serialization import ArchiveReadError 

15 

16__all__ = ("PiffSerializationModel", "PiffWrapper") 

17 

18import operator 

19from collections.abc import Iterator 

20from contextlib import contextmanager 

21from functools import cached_property 

22from logging import getLogger 

23from typing import TYPE_CHECKING, Annotated, Any, Literal 

24 

25import astropy.io.fits 

26import numpy as np 

27import pydantic 

28 

29from .. import serialization 

30from .._concrete_bounds import SerializableBounds 

31from .._geom import Bounds, Box 

32from .._image import Image 

33from ..utils import round_half_up 

34from ._base import PointSpreadFunction 

35 

36if TYPE_CHECKING: 

37 import galsim.wcs 

38 import piff.config 

39 

# Module logger. It is wrapped in `piff.config.PiffLogger` and passed to the
# Piff ``_write``/``_read`` calls made by `PiffWrapper.serialize` and
# `PiffSerializationModel.deserialize`.
_LOG = getLogger(__name__)

41 

42 

class PiffWrapper(PointSpreadFunction):
    """A PSF model backed by the Piff library.

    Parameters
    ----------
    impl
        The Piff PSF object to wrap.
    bounds
        The pixel-coordinate region where the model can safely be evaluated.
    stamp_size
        Width and height, in pixels, of the square images drawn by
        `compute_kernel_image` and `compute_stellar_image`.
    """

    def __init__(self, impl: piff.PSF, bounds: Bounds, stamp_size: int):
        self._impl = impl
        self._bounds = bounds
        self._stamp_size = stamp_size

    @property
    def bounds(self) -> Bounds:
        return self._bounds

    @cached_property
    def kernel_bbox(self) -> Box:
        # Must describe the image built by `compute_kernel_image`, which starts
        # at (-r, -r) and is ``stamp_size`` pixels wide. For odd stamp sizes
        # this is the symmetric range [-r, r]. Deriving the upper edge from
        # ``stamp_size`` keeps the box and the image in agreement for even
        # sizes as well.
        r = self._stamp_size // 2
        end = self._stamp_size - r
        return Box.factory[-r:end, -r:end]

    def compute_kernel_image(self, *, x: float, y: float) -> Image:
        self._check_achromatic()
        gs_image = self._impl.draw(x, y, stamp_size=self._stamp_size, center=True)
        r = self._stamp_size // 2
        return self._to_normalized_image(gs_image, start=(-r, -r))

    def compute_stellar_image(self, *, x: float, y: float) -> Image:
        self._check_achromatic()
        gs_image = self._impl.draw(x, y, stamp_size=self._stamp_size, center=None)
        r = self._stamp_size // 2
        return self._to_normalized_image(
            gs_image, start=(round_half_up(y) - r, round_half_up(x) - r)
        )

    def compute_stellar_bbox(self, *, x: float, y: float) -> Box:
        # Same origin as `compute_stellar_image`. The extent is taken from
        # ``stamp_size`` so the box matches that image for even sizes too; for
        # odd sizes this is [i - r, i + r] on each axis.
        r = self._stamp_size // 2
        x0 = round_half_up(x) - r
        y0 = round_half_up(y) - r
        return Box.factory[y0 : y0 + self._stamp_size, x0 : x0 + self._stamp_size]

    @property
    def piff_psf(self) -> Any:
        """The backing `piff.PSF` object.

        This is an internal object that must not be modified in place.
        """
        return self._impl

    @classmethod
    def from_legacy(cls, legacy_psf: Any, bounds: Bounds) -> PiffWrapper:
        """Wrap the Piff PSF held by a legacy PSF object.

        Parameters
        ----------
        legacy_psf
            Object exposing the Piff PSF as ``_piffResult`` and the stamp
            width as ``width`` (converted with `int`).
        bounds
            The pixel-coordinate region where the model can safely be
            evaluated.

        Returns
        -------
        PiffWrapper
            A wrapper around ``legacy_psf._piffResult``.
        """
        return cls(impl=legacy_psf._piffResult, bounds=bounds, stamp_size=int(legacy_psf.width))

    def serialize(self, archive: serialization.OutputArchive[Any]) -> PiffSerializationModel:
        """Serialize the PSF to an archive.

        This method is intended to be usable as the callback function passed to
        `.serialization.OutputArchive.serialize_direct` or
        `.serialization.OutputArchive.serialize_pointer`.

        Parameters
        ----------
        archive
            Archive that receives the binary tables written by Piff.

        Returns
        -------
        PiffSerializationModel
            Model holding the Piff object tree, stamp size and bounds.
        """
        from piff.config import PiffLogger

        writer = _ArchivePiffWriter()
        # The fitting stars are dropped only while Piff writes itself.
        with self._without_stars():
            self._impl._write(writer, "piff", PiffLogger(_LOG))
        piff_model = writer.serialize(archive)
        return PiffSerializationModel(
            piff=piff_model,
            stamp_size=self._stamp_size,
            bounds=self._bounds.serialize(),
        )

    @staticmethod
    def _get_archive_tree_type(
        pointer_type: type[pydantic.BaseModel],
    ) -> type[PiffSerializationModel]:
        """Return the serialization model type for this object for an archive
        type that uses the given pointer type.
        """
        return PiffSerializationModel

    @contextmanager
    def _without_stars(self) -> Iterator[None]:
        """Temporarily drop the embedded list of stars used to fit the PSF.

        Notes
        -----
        By default Piff saves the list of stars (including postage stamps) used
        to fit the PSF, which makes the serialized form much larger. But the
        upstream Piff serialization code recognizes the case where that
        ``stars`` attribute has been deleted and serializes everything else.

        Unfortunately, to date, Rubin's pickle-based Piff serialization instead
        just deletes the postage stamp image attributes from inside the Piff
        ``stars`` list, which is not a state the Piff serialization code
        handles gracefully. So for now we have to drop the full stars list
        during serialization if it is present.
        """
        if hasattr(self._impl, "stars"):
            stars = self._impl.stars
            try:
                del self._impl.stars
                yield
            finally:
                # Always restore, even if serialization raised.
                self._impl.stars = stars
        else:
            yield

    def _check_achromatic(self) -> None:
        """Raise `NotImplementedError` if the wrapped Piff PSF is chromatic.

        The PSF is treated as chromatic when ``colorValue`` is one of its
        interpolation property names.
        """
        if "colorValue" in self._impl.interp_property_names:
            raise NotImplementedError("Chromatic PSFs are not yet supported.")

    @staticmethod
    def _to_normalized_image(gs_image: Any, start: tuple[int, int]) -> Image:
        """Copy a drawn GalSim image into an `Image` whose pixels sum to one.

        Parameters
        ----------
        gs_image
            Image returned by ``piff.PSF.draw``; only its ``array`` is used.
        start
            ``(y, x)`` pixel coordinates of the first pixel of the result.
        """
        result = Image(gs_image.array.copy(), start=start)
        result.array /= np.sum(result.array)
        return result

158 

159 

160# Conventions on public visibility of the serialization types: 

161# 

162# - We lift and document the outermost Pydantic model type, since that needs to 

163# be included directly in the Pydantic models of types that hold a PSF. This 

164# type needs to be very clearly documented and named as a *serialization* 

165# model, since there are many other kinds of models in play in this package. 

166# 

167# - We do not lift or document types used in that outermost model, but we do 

168# not give them leading underscores, since they aren't really private. 

169# 

170# - Other utility types do get leading underscores. 

171 

172 

# Piff serialization uses a lot of dictionaries and lists restricted to these
# basic types.
#
# PiffScalar: a single leaf value of one of the builtin scalar types below.
type PiffScalar = int | float | str | bool | None
# PiffValue: a scalar or an arbitrarily nested list of scalars. This alias is
# recursive: it refers to itself in ``list[PiffValue]``.
type PiffValue = PiffScalar | list[PiffValue]
# PiffDict: the string-keyed mapping that Piff passes to ``write_struct`` and
# gets back from ``read_struct`` (and that accompanies tables as metadata).
type PiffDict = dict[str, PiffValue]

178 

179 

class GalSimPixelScaleModel(pydantic.BaseModel):
    """Serialized form of a `galsim.wcs.PixelScale`, i.e. a pure scaling WCS.

    ``wcs_type`` is the tag on which `GalSimLocalWcsModel` discriminates.
    """

    model_config = pydantic.ConfigDict(ser_json_inf_nan="constants")

    # The ``scale`` attribute of the serialized ``PixelScale``.
    scale: float
    # Constant discriminator tag for this WCS kind.
    wcs_type: Literal["pixel_scale"] = "pixel_scale"

185 

186 

# We expect this discriminated union to grow to include other trivial
# pixel-to-pixel transforms that get embedded in PSFs. If we someday have to
# store Piff objects that embed more sophisticated WCSs, we'll hook them into
# the AST-based coordinate transform system instead, but as long as we're just
# talking about simple offsets and scalings, that's a lot of extra complexity
# for very little gain.
#
# The union is discriminated on the ``wcs_type`` field; GalSimPixelScaleModel
# is currently its only member.
type GalSimLocalWcsModel = Annotated[GalSimPixelScaleModel, pydantic.Field(discriminator="wcs_type")]

194 

195 

class PiffTableModel(pydantic.BaseModel):
    """Pointer from Piff's JSON-like serialized data to a binary table held
    in the archive, together with the metadata Piff attached to that table.
    """

    model_config = pydantic.ConfigDict(ser_json_inf_nan="constants")

    # Metadata dict supplied by Piff alongside the array.
    metadata: PiffDict
    # Archive reference to the stored structured array.
    table: serialization.TableModel

203 

204 

class PiffObjectModel(pydantic.BaseModel):
    """Serialization model shared by every kind of Piff object.

    Each field maps Piff-chosen names to one category of payload. Fields
    whose value is falsy (e.g. an empty dict) are omitted from the output.
    """

    model_config = pydantic.ConfigDict(ser_json_inf_nan="constants")

    # Plain dicts of scalar/list values.
    structs: dict[str, PiffDict] = pydantic.Field(exclude_if=operator.not_, default_factory=dict)
    # References to binary tables stored in the archive.
    tables: dict[str, PiffTableModel] = pydantic.Field(exclude_if=operator.not_, default_factory=dict)
    # Simple embedded WCS descriptions.
    wcs: dict[str, GalSimLocalWcsModel] = pydantic.Field(exclude_if=operator.not_, default_factory=dict)
    # Nested Piff objects, serialized recursively with this same model.
    objects: dict[str, PiffObjectModel] = pydantic.Field(exclude_if=operator.not_, default_factory=dict)

212 

213 

class PiffSerializationModel(serialization.ArchiveTree):
    """Serialization model for a Piff PSF.

    Instances are produced by `PiffWrapper.serialize` and converted back into
    a `PiffWrapper` by `deserialize`.
    """

    piff: PiffObjectModel = pydantic.Field(description="The Piff PSF object itself.")

    stamp_size: int = pydantic.Field(
        description="Width of the (square) images returned by this PSF's methods."
    )

    bounds: SerializableBounds = pydantic.Field(
        description="The bounds object that represents the PSF's validity region."
    )

    def deserialize(self, archive: serialization.InputArchive[Any], **kwargs: Any) -> PiffWrapper:
        """Deserialize the PSF from an archive.

        This method is intended to be usable as the callback function passed to
        `.serialization.InputArchive.deserialize_pointer`.

        Parameters
        ----------
        archive
            Archive holding the binary tables referenced by ``self.piff``.
        **kwargs
            Not supported; passing any keyword argument is an error.

        Returns
        -------
        PiffWrapper
            The reconstructed PSF.

        Raises
        ------
        serialization.InvalidParameterError
            If any keyword arguments are given.
        serialization.ArchiveReadError
            If the ``piff`` package cannot be imported.
        """
        if kwargs:
            raise serialization.InvalidParameterError(
                f"Unrecognized parameters for PiffWrapper: {set(kwargs.keys())}."
            )
        try:
            from piff import PSF
            from piff.config import PiffLogger
        except ImportError as err:
            # Chain the original error so the underlying reason (e.g. a missing
            # transitive dependency of piff) is still visible in the traceback.
            raise serialization.ArchiveReadError("Failed to import piff.") from err

        reader = _ArchivePiffReader(self.piff, archive)
        impl = PSF._read(reader, "piff", PiffLogger(_LOG))
        return PiffWrapper(impl, bounds=self.bounds.deserialize(), stamp_size=self.stamp_size)

246 

247 

248class _ArchivePiffWriter: 

249 """An adapter from the Piff serialization interface to the 

250 `.serialization.OutputArchive` class. 

251 

252 Notes 

253 ----- 

254 Piff has its own simple serialization framework (contributed upstream by 

255 Rubin DM) that maps everything to dictionaries, structured numpy arrays, 

256 and a library of GalSim WCS objects, with the native implementation writing 

257 standalone FITS files. That mostly maps nicely to the `lsst.images` 

258 archive system, but we don't get to leverage any Pydantic validation or 

259 JSON schema functionality since we only get opaque dictionaries from Piff. 

260 

261 See `piff.FitsWriter` for most method documentation; this class is designed 

262 to mimic it exactly (the Piff authors prefer to just use duck-typing rather 

263 than ABCs or protocols for interface definition). 

264 """ 

265 

266 def __init__(self, base_name: str = ""): 

267 self._base_name = base_name 

268 self.structs: dict[str, PiffDict] = {} 

269 self.tables: dict[str, tuple[np.ndarray, PiffDict]] = {} 

270 self.wcs_models: dict[str, GalSimLocalWcsModel] = {} 

271 self.writers: dict[str, _ArchivePiffWriter] = {} 

272 

273 def write_struct(self, name: str, struct: PiffDict) -> None: 

274 self.structs[name] = {k: self._to_builtin(v) for k, v in struct.items()} 

275 

276 def write_table(self, name: str, array: np.ndarray, metadata: PiffDict | None = None) -> None: 

277 self.tables[name] = ( 

278 array, 

279 {k: self._to_builtin(v) for k, v in (metadata or {}).items()}, 

280 ) 

281 

282 def write_wcs_map( 

283 self, name: str, wcs_map: dict[int, galsim.wcs.BaseWCS], pointing: galsim.CelestialCoord | None 

284 ) -> None: 

285 import galsim.wcs 

286 

287 match wcs_map: 

288 case {0: galsim.wcs.PixelScale() as wcs} if pointing is None: 

289 self.wcs_models[name] = GalSimPixelScaleModel(scale=wcs.scale) 

290 case _: 

291 raise NotImplementedError("PSFs with complex embedded WCSs are not supported.") 

292 

293 @contextmanager 

294 def nested(self, name: str) -> Iterator[_ArchivePiffWriter]: 

295 nested = _ArchivePiffWriter(self.get_full_name(name)) 

296 yield nested 

297 self.writers[name] = nested 

298 

299 def get_full_name(self, name: str) -> str: 

300 return f"{self._base_name}/{name}" 

301 

302 def serialize(self, archive: serialization.OutputArchive[Any]) -> PiffObjectModel: 

303 """Serialize to an archive. 

304 

305 This method is intended to be used as the callable passed to 

306 `.serialization.OutputArchive.serialize_direct` and 

307 `.serialization.OutputArchive.serialize_pointer`, after first passing 

308 this writer to a Piff object's ``write`` or ``_write`` method. 

309 """ 

310 model = PiffObjectModel() 

311 for name, struct in self.structs.items(): 

312 model.structs[name] = struct 

313 for name, (array, metadata) in self.tables.items(): 

314 model.tables[name] = PiffTableModel( 

315 metadata=metadata, 

316 table=archive.add_structured_array( 

317 array, name=name, update_header=lambda header: header.update(metadata) 

318 ), 

319 ) 

320 for name, wcs_model in self.wcs_models.items(): 

321 model.wcs[name] = wcs_model 

322 for name, writer in self.writers.items(): 

323 model.objects[name] = archive.serialize_direct(name, writer.serialize) 

324 return model 

325 

326 @staticmethod 

327 def _to_builtin(val: Any) -> PiffValue: 

328 match val: 

329 case np.integer(): 

330 return int(val) 

331 case np.floating(): 

332 return float(val) 

333 case np.bool_(): 

334 return bool(val) 

335 case np.str_(): 

336 return str(val) 

337 case tuple() | list(): 

338 return [_ArchivePiffWriter._to_builtin(item) for item in val] 

339 return val 

340 

341 

342class _ArchivePiffReader: 

343 """An adapter from the Piff serialization interface to the 

344 `.serialization.InputArchive` class. 

345 

346 See `ArchivePiffWriter` for additional notes. 

347 """ 

348 

349 def __init__( 

350 self, object_model: PiffObjectModel, archive: serialization.InputArchive[Any], base_name: str = "" 

351 ): 

352 self._model = object_model 

353 self._archive = archive 

354 self._base_name = base_name 

355 

356 def read_struct(self, name: str) -> PiffDict | None: 

357 return self._model.structs.get(name) 

358 

359 def read_table(self, name: str, metadata: PiffDict | None = None) -> np.ndarray | None: 

360 table_model = self._model.tables.get(name) 

361 if table_model is None: 

362 return None 

363 if metadata is not None: 

364 metadata.update(table_model.metadata) 

365 return self._archive.get_structured_array( 

366 table_model.table, strip_header=astropy.io.fits.Header.clear 

367 ) 

368 

369 def read_wcs_map( 

370 self, name: str, logger: piff.config.LoggerWrapper 

371 ) -> tuple[dict[int, galsim.wcs.BaseWCS] | None, galsim.CelestialCoord | None]: 

372 import galsim.wcs 

373 

374 match self._model.wcs.get(name): 

375 case GalSimPixelScaleModel(scale=scale): 

376 return {0: galsim.wcs.PixelScale(scale)}, None 

377 case None: 

378 return None, None 

379 case unexpected: 

380 raise serialization.ArchiveReadError( 

381 f"{self.get_full_name(name)} should be a WCS or WCS map, not {unexpected!r}." 

382 ) 

383 

384 @contextmanager 

385 def nested(self, name: str) -> Iterator[_ArchivePiffReader]: 

386 nested_model = self._model.objects[name] 

387 yield _ArchivePiffReader(nested_model, self._archive, self.get_full_name(name)) 

388 

389 def get_full_name(self, name: str) -> str: 

390 return f"{self._base_name}/{name}"