Coverage for python / lsst / images / psfs / _piff.py: 36%

181 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-16 00:52 -0700

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14from lsst.images.serialization import ArchiveReadError 

15 

16__all__ = ("PiffSerializationModel", "PiffWrapper") 

17 

18import operator 

19from collections.abc import Iterator 

20from contextlib import contextmanager 

21from functools import cached_property 

22from logging import getLogger 

23from typing import TYPE_CHECKING, Annotated, Any, Literal 

24 

25import astropy.io.fits 

26import numpy as np 

27import pydantic 

28 

29from .. import serialization 

30from .._concrete_bounds import SerializableBounds 

31from .._geom import Bounds, Box 

32from .._image import Image 

33from ..utils import round_half_up 

34from ._base import PointSpreadFunction 

35 

36if TYPE_CHECKING: 

37 import galsim.wcs 

38 import piff.config 

39 

40_LOG = getLogger(__name__) 

41 

42 

class PiffWrapper(PointSpreadFunction):
    """A point-spread function whose evaluation is delegated to a Piff PSF.

    Parameters
    ----------
    impl
        The Piff PSF object to wrap.
    bounds
        The pixel-coordinate region where the model can safely be evaluated.
    stamp_size
        Width, in pixels, of the (square) stamps requested from Piff when
        drawing images.
    """

    def __init__(self, impl: piff.PSF, bounds: Bounds, stamp_size: int):
        self._impl = impl
        self._bounds = bounds
        self._stamp_size = stamp_size

    @property
    def bounds(self) -> Bounds:
        """The bounds object passed at construction."""
        return self._bounds

    @cached_property
    def kernel_bbox(self) -> Box:
        """Box spanning ``-r`` through ``r`` (inclusive) in both dimensions,
        where ``r = stamp_size // 2``.
        """
        # NOTE(review): this box is 2*r + 1 pixels wide, which equals
        # ``stamp_size`` only for odd stamp sizes - confirm even sizes never
        # occur.
        half = self._stamp_size // 2
        return Box.factory[-half : half + 1, -half : half + 1]

    def compute_kernel_image(self, *, x: float, y: float) -> Image:
        """Draw the PSF at ``(x, y)`` centered on the stamp, with the returned
        image starting at ``(-r, -r)`` and normalized to unit sum.

        Raises
        ------
        NotImplementedError
            Raised if the Piff PSF interpolates over ``colorValue``.
        """
        self._raise_if_chromatic()
        drawn = self._impl.draw(x, y, stamp_size=self._stamp_size, center=True)
        half = self._stamp_size // 2
        result = Image(drawn.array.copy(), start=(-half, -half))
        # Normalize the private copy so the pixel values sum to one.
        result.array /= np.sum(result.array)
        return result

    def compute_stellar_image(self, *, x: float, y: float) -> Image:
        """Draw the PSF at ``(x, y)`` with Piff's default centering
        (``center=None``), with the returned image starting ``r`` pixels
        before the rounded position and normalized to unit sum.

        Raises
        ------
        NotImplementedError
            Raised if the Piff PSF interpolates over ``colorValue``.
        """
        self._raise_if_chromatic()
        drawn = self._impl.draw(x, y, stamp_size=self._stamp_size, center=None)
        half = self._stamp_size // 2
        # The image start is given as (y, x).
        origin = (round_half_up(y) - half, round_half_up(x) - half)
        result = Image(drawn.array.copy(), start=origin)
        result.array /= np.sum(result.array)
        return result

    def compute_stellar_bbox(self, *, x: float, y: float) -> Box:
        """Return the box extending ``r = stamp_size // 2`` pixels on each
        side of the rounded position ``(x, y)``.
        """
        half = self._stamp_size // 2
        col = round_half_up(x)
        row = round_half_up(y)
        # Slices are given as the y range first, then the x range.
        return Box.factory[row - half : row + half + 1, col - half : col + half + 1]

    @property
    def piff_psf(self) -> Any:
        """The backing `piff.PSF` object.

        This is an internal object that must not be modified in place.
        """
        return self._impl

    @classmethod
    def from_legacy(cls, legacy_psf: Any, bounds: Bounds) -> PiffWrapper:
        """Construct from a legacy PSF object.

        The wrapped Piff object is read from ``legacy_psf._piffResult`` and
        the stamp size from ``legacy_psf.width``.
        """
        piff_result = legacy_psf._piffResult
        return cls(impl=piff_result, bounds=bounds, stamp_size=int(legacy_psf.width))

    def serialize(self, archive: serialization.OutputArchive[Any]) -> PiffSerializationModel:
        """Serialize the PSF to an archive.

        This method is intended to be usable as the callback function passed to
        `.serialization.OutputArchive.serialize_direct` or
        `.serialization.OutputArchive.serialize_pointer`.
        """
        from piff.config import PiffLogger

        piff_writer = _ArchivePiffWriter()
        # Piff fills the writer while the stars list is detached, so the
        # stars are never captured in the serialized form.
        with self._without_stars():
            self._impl._write(piff_writer, "piff", PiffLogger(_LOG))
        return PiffSerializationModel(
            piff=piff_writer.serialize(archive),
            stamp_size=self._stamp_size,
            bounds=self._bounds.serialize(),
        )

    @staticmethod
    def _get_archive_tree_type(
        pointer_type: type[pydantic.BaseModel],
    ) -> type[PiffSerializationModel]:
        """Return the serialization model type for this object for an archive
        type that uses the given pointer type.
        """
        # The same model is returned whatever the pointer type is.
        return PiffSerializationModel

    def _raise_if_chromatic(self) -> None:
        """Raise `NotImplementedError` if the wrapped PSF interpolates over
        ``colorValue``.
        """
        if "colorValue" in self._impl.interp_property_names:
            raise NotImplementedError("Chromatic PSFs are not yet supported.")

    @contextmanager
    def _without_stars(self) -> Iterator[None]:
        """Temporarily drop the embedded list of stars used to fit the PSF.

        Notes
        -----
        By default Piff saves the list of stars (including postage stamps) used
        to fit the PSF, which makes the serialized form much larger.  But the
        upstream Piff serialization code recognizes the case where that
        ``stars`` attribute has been deleted and serializes everything else.

        Unfortunately, to date, Rubin's pickle-based Piff serialization instead
        just deletes the postage stamp image attributes from inside the Piff
        ``stars`` list, which is not a state the Piff serialization code
        handles gracefully.  So for now we have to drop the full stars list
        during serialization if it is present.
        """
        if not hasattr(self._impl, "stars"):
            # Nothing to detach; just run the managed block.
            yield
            return
        saved_stars = self._impl.stars
        try:
            del self._impl.stars
            yield
        finally:
            # Always restore the list, even if the managed block raised.
            self._impl.stars = saved_stars

158 

159 

160# Conventions on public visibility of the serialization types: 

161# 

162# - We lift and document the outermost Pydantic model type, since that needs to 

163# be included directly in the Pydantic models of types that hold a PSF. This 

164# type needs to be very clearly documented and named as a *serialization* 

165# model, since there are many other kinds of models in play in this package. 

166# 

167# - We do not lift or document types used in that outermost model, but we do 

168# not give them leading underscores, since they aren't really private. 

169# 

170# - Other utility types do get leading underscores. 

171 

172 

173# Piff serialization uses a lot of dictionaries and lists restricted to these 

174# basic types. 

175type PiffScalar = int | float | str | bool | None 

176type PiffValue = PiffScalar | list[PiffValue] 

177type PiffDict = dict[str, PiffValue] 

178 

179 

class GalSimPixelScaleModel(pydantic.BaseModel, ser_json_inf_nan="constants"):
    """Serialization model for `galsim.wcs.PixelScale` instances."""

    # Value of the ``scale`` attribute of the PixelScale being stored.
    scale: float

    # Constant tag identifying this model in the local-WCS discriminated union.
    wcs_type: Literal["pixel_scale"] = "pixel_scale"

185 

186 

187# We expect this discriminated union to grow to include other trivial 

188# pixel-to-pixel transforms that get embedded in PSFs. If we someday have to 

189# store Piff objects that embed more sophisticated PSFs, we'll hook them into 

190# the AST-based coordinate transform system instead, but as long as we're just 

191# talking about simple offsets and scalings, that's a lot of extra complexity 

192# for very little gain. 

193type GalSimLocalWcsModel = Annotated[GalSimPixelScaleModel, pydantic.Field(discriminator="wcs_type")] 

194 

195 

class PiffTableModel(pydantic.BaseModel, ser_json_inf_nan="constants"):
    """Serialization model that embeds a reference to a binary-data table in
    the JSON-like data of a Piff serialization.
    """

    # Dictionary of basic-typed values stored alongside the table.
    metadata: PiffDict

    # Archive model describing the table data itself.
    table: serialization.TableModel

203 

204 

class PiffObjectModel(pydantic.BaseModel, ser_json_inf_nan="constants"):
    """General-purpose serialization model used for various Piff objects."""

    # Each field defaults to an empty dict and is excluded from the serialized
    # output while falsy (``exclude_if=operator.not_``).

    # Plain dictionaries, keyed by name.
    structs: dict[str, PiffDict] = pydantic.Field(
        default_factory=dict,
        exclude_if=operator.not_,
    )

    # References to binary tables (with their metadata), keyed by name.
    tables: dict[str, PiffTableModel] = pydantic.Field(
        default_factory=dict,
        exclude_if=operator.not_,
    )

    # Simple local WCS models, keyed by name.
    wcs: dict[str, GalSimLocalWcsModel] = pydantic.Field(
        default_factory=dict,
        exclude_if=operator.not_,
    )

    # Nested models of this same type, keyed by name.
    objects: dict[str, PiffObjectModel] = pydantic.Field(
        default_factory=dict,
        exclude_if=operator.not_,
    )

212 

213 

class PiffSerializationModel(serialization.ArchiveTree):
    """Serialization model for a Piff PSF."""

    piff: PiffObjectModel = pydantic.Field(
        description="The Piff PSF object itself.",
    )

    stamp_size: int = pydantic.Field(
        description="Width of the (square) images returned by this PSF's methods.",
    )

    bounds: SerializableBounds = pydantic.Field(
        description="The bounds object that represents the PSF's validity region.",
    )

    def deserialize(self, archive: serialization.InputArchive[Any]) -> PiffWrapper:
        """Deserialize the PSF from an archive.

        This method is intended to be usable as the callback function passed to
        `.serialization.InputArchive.deserialize_pointer`.

        Raises
        ------
        ArchiveReadError
            Raised if the ``piff`` package cannot be imported.
        """
        # Piff is imported lazily so that a missing install is reported as an
        # archive read failure rather than an ImportError.
        try:
            from piff import PSF
            from piff.config import PiffLogger
        except ImportError:
            raise ArchiveReadError("Failed to import piff.") from None

        impl = PSF._read(_ArchivePiffReader(self.piff, archive), "piff", PiffLogger(_LOG))
        return PiffWrapper(impl, bounds=self.bounds.deserialize(), stamp_size=self.stamp_size)

242 

243 

244class _ArchivePiffWriter: 

245 """An adapter from the Piff serialization interface to the 

246 `.serialization.OutputArchive` class. 

247 

248 Notes 

249 ----- 

250 Piff has its own simple serialization framework (contributed upstream by 

251 Rubin DM) that maps everything to dictionaries, structured numpy arrays, 

252 and a library of GalSim WCS objects, with the native implementation writing 

253 standalone FITS files. That mostly maps nicely to the `lsst.images` 

254 archive system, but we don't get to leverage any Pydantic validation or 

255 JSON schema functionality since we only get opaque dictionaries from Piff. 

256 

257 See `piff.FitsWriter` for most method documentation; this class is designed 

258 to mimic it exactly (the Piff authors prefer to just use duck-typing rather 

259 than ABCs or protocols for interface definition). 

260 """ 

261 

262 def __init__(self, base_name: str = ""): 

263 self._base_name = base_name 

264 self.structs: dict[str, PiffDict] = {} 

265 self.tables: dict[str, tuple[np.ndarray, PiffDict]] = {} 

266 self.wcs_models: dict[str, GalSimLocalWcsModel] = {} 

267 self.writers: dict[str, _ArchivePiffWriter] = {} 

268 

269 def write_struct(self, name: str, struct: PiffDict) -> None: 

270 self.structs[name] = {k: self._to_builtin(v) for k, v in struct.items()} 

271 

272 def write_table(self, name: str, array: np.ndarray, metadata: PiffDict | None = None) -> None: 

273 self.tables[name] = ( 

274 array, 

275 {k: self._to_builtin(v) for k, v in (metadata or {}).items()}, 

276 ) 

277 

278 def write_wcs_map( 

279 self, name: str, wcs_map: dict[int, galsim.wcs.BaseWCS], pointing: galsim.CelestialCoord | None 

280 ) -> None: 

281 import galsim.wcs 

282 

283 match wcs_map: 

284 case {0: galsim.wcs.PixelScale() as wcs} if pointing is None: 

285 self.wcs_models[name] = GalSimPixelScaleModel(scale=wcs.scale) 

286 case _: 

287 raise NotImplementedError("PSFs with complex embedded WCSs are not supported.") 

288 

289 @contextmanager 

290 def nested(self, name: str) -> Iterator[_ArchivePiffWriter]: 

291 nested = _ArchivePiffWriter(self.get_full_name(name)) 

292 yield nested 

293 self.writers[name] = nested 

294 

295 def get_full_name(self, name: str) -> str: 

296 return f"{self._base_name}/{name}" 

297 

298 def serialize(self, archive: serialization.OutputArchive[Any]) -> PiffObjectModel: 

299 """Serialize to an archive. 

300 

301 This method is intended to be used as the callable passed to 

302 `.serialization.OutputArchive.serialize_direct` and 

303 `.serialization.OutputArchive.serialize_pointer`, after first passing 

304 this writer to a Piff object's ``write`` or ``_write`` method. 

305 """ 

306 model = PiffObjectModel() 

307 for name, struct in self.structs.items(): 

308 model.structs[name] = struct 

309 for name, (array, metadata) in self.tables.items(): 

310 model.tables[name] = PiffTableModel( 

311 metadata=metadata, 

312 table=archive.add_structured_array( 

313 array, name=name, update_header=lambda header: header.update(metadata) 

314 ), 

315 ) 

316 for name, wcs_model in self.wcs_models.items(): 

317 model.wcs[name] = wcs_model 

318 for name, writer in self.writers.items(): 

319 model.objects[name] = archive.serialize_direct(name, writer.serialize) 

320 return model 

321 

322 @staticmethod 

323 def _to_builtin(val: Any) -> PiffValue: 

324 match val: 

325 case np.integer(): 

326 return int(val) 

327 case np.floating(): 

328 return float(val) 

329 case np.bool_(): 

330 return bool(val) 

331 case np.str_(): 

332 return str(val) 

333 case tuple() | list(): 

334 return [_ArchivePiffWriter._to_builtin(item) for item in val] 

335 return val 

336 

337 

class _ArchivePiffReader:
    """An adapter from the Piff serialization interface to the
    `.serialization.InputArchive` class.

    See `_ArchivePiffWriter` for additional notes.

    Parameters
    ----------
    object_model
        The model whose structs, tables, WCSs, and nested objects are read.
    archive
        The archive that structured arrays are loaded from.
    base_name
        Prefix used by `get_full_name`.
    """

    def __init__(
        self, object_model: PiffObjectModel, archive: serialization.InputArchive[Any], base_name: str = ""
    ):
        self._model = object_model
        self._archive = archive
        self._base_name = base_name

    def read_struct(self, name: str) -> PiffDict | None:
        """Return the named dictionary, or `None` if there is no such entry."""
        structs = self._model.structs
        return structs.get(name)

    def read_table(self, name: str, metadata: PiffDict | None = None) -> np.ndarray | None:
        """Return the named table as a structured array, or `None` if there is
        no such entry.

        If ``metadata`` is given and the table exists, it is updated in place
        with the table's stored metadata.
        """
        entry = self._model.tables.get(name)
        if entry is None:
            return None
        if metadata is not None:
            metadata.update(entry.metadata)
        return self._archive.get_structured_array(entry.table, strip_header=astropy.io.fits.Header.clear)

    def read_wcs_map(
        self, name: str, logger: piff.config.LoggerWrapper
    ) -> tuple[dict[int, galsim.wcs.BaseWCS] | None, galsim.CelestialCoord | None]:
        """Return the named WCS map and pointing.

        The pointing is always `None`; the map is `None` too if there is no
        such entry.  ``logger`` is accepted for interface compatibility and
        is not used.

        Raises
        ------
        ArchiveReadError
            Raised if the stored entry is not a recognized WCS model.
        """
        import galsim.wcs

        entry = self._model.wcs.get(name)
        if entry is None:
            return None, None
        if isinstance(entry, GalSimPixelScaleModel):
            return {0: galsim.wcs.PixelScale(entry.scale)}, None
        raise serialization.ArchiveReadError(
            f"{self.get_full_name(name)} should be a WCS or WCS map, not {entry!r}."
        )

    @contextmanager
    def nested(self, name: str) -> Iterator[_ArchivePiffReader]:
        """Yield a reader for the named nested object.

        A missing name raises `KeyError` from the dictionary lookup.
        """
        child_model = self._model.objects[name]
        child = _ArchivePiffReader(child_model, self._archive, self.get_full_name(name))
        yield child

    def get_full_name(self, name: str) -> str:
        """Return ``name`` prefixed by this reader's base name and a slash."""
        return f"{self._base_name}/{name}"