Coverage for python / lsst / images / psfs / _piff.py: 35%

186 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-23 08:25 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14from lsst.images.serialization import ArchiveReadError 

15 

16__all__ = ("PiffSerializationModel", "PiffWrapper") 

17 

18import operator 

19from collections.abc import Iterator 

20from contextlib import contextmanager 

21from functools import cached_property 

22from logging import getLogger 

23from typing import TYPE_CHECKING, Annotated, Any, Literal 

24 

25import astropy.io.fits 

26import numpy as np 

27import pydantic 

28 

29from .. import serialization 

30from .._concrete_bounds import SerializableBounds 

31from .._geom import Bounds, Box 

32from .._image import Image 

33from ..utils import round_half_up 

34from ._base import PointSpreadFunction 

35 

36if TYPE_CHECKING: 

37 import galsim.wcs 

38 import piff 

39 import piff.config 

40 

41 try: 

42 from lsst.meas.extensions.piff.piffPsf import PiffPsf as LegacyPiffPsf 

43 except ImportError: 

44 type LegacyPiffPsf = Any # type: ignore[no-redef] 

45 

46 

47_LOG = getLogger(__name__) 

48 

49 

class PiffWrapper(PointSpreadFunction):
    """A PSF model backed by the Piff library.

    Parameters
    ----------
    impl
        The Piff PSF object to wrap.
    bounds
        The pixel-coordinate region where the model can safely be evaluated.
    stamp_size
        Size (in pixels, along each dimension) of the square postage stamps
        Piff is asked to draw.  Kernel and stellar bounding boxes extend
        ``stamp_size // 2`` pixels on either side of their central pixel.

    Notes
    -----
    NOTE(review): an odd ``stamp_size`` appears to be assumed; for an even
    value the bounding boxes (``2 * (stamp_size // 2) + 1`` pixels wide) are
    one pixel wider than the stamps Piff is asked to draw — confirm.
    """

    def __init__(self, impl: piff.PSF, bounds: Bounds, stamp_size: int):
        self._impl = impl
        self._bounds = bounds
        self._stamp_size = stamp_size

    @property
    def bounds(self) -> Bounds:
        return self._bounds

    @cached_property
    def kernel_bbox(self) -> Box:
        r = self._stamp_size // 2
        return Box.factory[-r : r + 1, -r : r + 1]

    def compute_kernel_image(self, *, x: float, y: float) -> Image:
        # Ask Piff to center the profile on the stamp, then place the stamp so
        # that its central pixel is at (0, 0).
        array = self._draw_piff_array(x, y, center=True)
        r = self._stamp_size // 2
        return self._normalize_sum(Image(array, start=(-r, -r)))

    def compute_stellar_image(self, *, x: float, y: float) -> Image:
        # Unlike the kernel image, Piff is called with ``center=None`` and the
        # stamp is anchored on the pixel nearest to (x, y).
        array = self._draw_piff_array(x, y, center=None)
        r = self._stamp_size // 2
        start = (round_half_up(y) - r, round_half_up(x) - r)
        return self._normalize_sum(Image(array, start=start))

    def compute_stellar_bbox(self, *, x: float, y: float) -> Box:
        r = self._stamp_size // 2
        xi = round_half_up(x)
        yi = round_half_up(y)
        return Box.factory[yi - r : yi + r + 1, xi - r : xi + r + 1]

    def _draw_piff_array(self, x: float, y: float, *, center: bool | None) -> np.ndarray:
        """Draw a stamp with Piff at ``(x, y)`` and return a private copy of
        its pixel array.

        Raises
        ------
        NotImplementedError
            Raised if the Piff model interpolates over ``colorValue`` (i.e. it
            is chromatic).
        """
        if "colorValue" in self._impl.interp_property_names:
            raise NotImplementedError("Chromatic PSFs are not yet supported.")
        gs_image = self._impl.draw(x, y, stamp_size=self._stamp_size, center=center)
        # Copy so the returned image does not share memory with GalSim's image.
        return gs_image.array.copy()

    @staticmethod
    def _normalize_sum(image: Image) -> Image:
        """Scale ``image`` in place so its pixels sum to one, and return it.

        If the pixels sum to zero the result is non-finite.
        """
        image.array /= np.sum(image.array)
        return image

    @property
    def piff_psf(self) -> piff.PSF:
        """The backing `piff.PSF` object.

        This is an internal object that must not be modified in place.
        """
        return self._impl

    @classmethod
    def from_legacy(cls, legacy_psf: LegacyPiffPsf, bounds: Bounds) -> PiffWrapper:
        """Construct from a legacy
        `lsst.meas.extensions.piff.piffPsf.PiffPsf` instance.

        Parameters
        ----------
        legacy_psf
            The legacy PSF.  Its wrapped Piff object is reused (not copied)
            and its ``width`` is used as the stamp size.
        bounds
            The pixel-coordinate region where the model can safely be
            evaluated.
        """
        return cls(impl=legacy_psf._piffResult, bounds=bounds, stamp_size=int(legacy_psf.width))

    def to_legacy(self) -> LegacyPiffPsf:
        """Convert to a legacy `lsst.meas.extensions.piff.piffPsf.PiffPsf`.

        The stamp size is used for both the legacy width and height, and the
        underlying Piff object is shared rather than copied.
        """
        from lsst.meas.extensions.piff.piffPsf import PiffPsf as LegacyPiffPsf

        return LegacyPiffPsf(self._stamp_size, self._stamp_size, self._impl)

    def serialize(self, archive: serialization.OutputArchive[Any]) -> PiffSerializationModel:
        """Serialize the PSF to an archive.

        This method is intended to be usable as the callback function passed to
        `.serialization.OutputArchive.serialize_direct` or
        `.serialization.OutputArchive.serialize_pointer`.
        """
        from piff.config import PiffLogger

        writer = _ArchivePiffWriter()
        with self._without_stars():
            self._impl._write(writer, "piff", PiffLogger(_LOG))
        piff_model = writer.serialize(archive)
        return PiffSerializationModel(
            piff=piff_model,
            stamp_size=self._stamp_size,
            bounds=self._bounds.serialize(),
        )

    @staticmethod
    def _get_archive_tree_type(
        pointer_type: type[pydantic.BaseModel],
    ) -> type[PiffSerializationModel]:
        """Return the serialization model type for this object for an archive
        type that uses the given pointer type.
        """
        return PiffSerializationModel

    @contextmanager
    def _without_stars(self) -> Iterator[None]:
        """Temporarily drop the embedded list of stars used to fit the PSF.

        Notes
        -----
        By default Piff saves the list of stars (including postage stamps) used
        to fit the PSF, which makes the serialized form much larger. But the
        upstream Piff serialization code recognizes the case where that
        ``stars`` attribute has been deleted and serializes everything else.

        Unfortunately, to date, Rubin's pickle-based Piff serialization instead
        just deletes the postage stamp image attributes from inside the Piff
        ``stars`` list, which is not a state the Piff serialization code
        handles gracefully. So for now we have to drop the full stars list
        during serialization if it is present.
        """
        if hasattr(self._impl, "stars"):
            stars = self._impl.stars
            try:
                del self._impl.stars
                yield
            finally:
                # Restore even if serialization fails.
                self._impl.stars = stars
        else:
            yield

171 

172 

173# Conventions on public visibility of the serialization types: 

174# 

175# - We lift and document the outermost Pydantic model type, since that needs to 

176# be included directly in the Pydantic models of types that hold a PSF. This 

177# type needs to be very clearly documented and named as a *serialization* 

178# model, since there are many other kinds of models in play in this package. 

179# 

180# - We do not lift or document types used in that outermost model, but we do 

181# not give them leading underscores, since they aren't really private. 

182# 

183# - Other utility types do get leading underscores. 

184 

185 

# Piff serialization uses a lot of dictionaries and lists restricted to these
# basic types.
#
# NOTE(review): these aliases are used as Pydantic field types, and Pydantic
# may take union member order into account when validating (e.g. ``bool`` vs.
# ``int``), so the order below should not be changed casually.
type PiffScalar = int | float | str | bool | None
# A scalar or an arbitrarily nested list of scalars (the alias is recursive).
type PiffValue = PiffScalar | list[PiffValue]
# A Piff "struct": string keys mapping to scalars or (nested) lists of them.
type PiffDict = dict[str, PiffValue]

191 

192 

class GalSimPixelScaleModel(pydantic.BaseModel, ser_json_inf_nan="constants"):
    """Model used to serialize `galsim.wcs.PixelScale` instances.

    Only the pixel scale is stored; the `galsim.wcs.PixelScale` is rebuilt
    from it on read.
    """

    # The ``scale`` of the `galsim.wcs.PixelScale` being stored.
    scale: float
    # Discriminator tag used to select this model within a discriminated
    # union of WCS models; always "pixel_scale".
    wcs_type: Literal["pixel_scale"] = "pixel_scale"

198 

199 

# We expect this discriminated union to grow to include other trivial
# pixel-to-pixel transforms that get embedded in PSFs. If we someday have to
# store Piff objects that embed more sophisticated WCSs, we'll hook them into
# the AST-based coordinate transform system instead, but as long as we're just
# talking about simple offsets and scalings, that's a lot of extra complexity
# for very little gain.
#
# For now the "union" has a single member; the discriminator is declared up
# front so that new members can be added by extending the first argument of
# ``Annotated`` with ``|``.
type GalSimLocalWcsModel = Annotated[GalSimPixelScaleModel, pydantic.Field(discriminator="wcs_type")]

207 

208 

class PiffTableModel(pydantic.BaseModel, ser_json_inf_nan="constants"):
    """Serialization model used to embed a reference to a binary-data table in
    a Piff serialization's JSON-like data.
    """

    # Metadata dictionary Piff supplied along with the table; it is stored
    # here and also copied into the header of the table written to the
    # archive.
    metadata: PiffDict
    # Archive reference to the structured array holding the table's rows.
    table: serialization.TableModel

216 

217 

class PiffObjectModel(pydantic.BaseModel, ser_json_inf_nan="constants"):
    """General-purpose serialization model used for various Piff objects.

    Each field is keyed by the name Piff used when writing that piece of the
    object.  ``exclude_if=operator.not_`` means a field is left out of the
    serialized form when its dictionary is empty.
    """

    # Dictionaries written via Piff's ``write_struct``.
    structs: dict[str, PiffDict] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)
    # Structured arrays written via Piff's ``write_table``.
    tables: dict[str, PiffTableModel] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)
    # Simple WCSs written via Piff's ``write_wcs_map``.
    wcs: dict[str, GalSimLocalWcsModel] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)
    # Nested sub-objects, written via Piff's ``nested`` writer context.
    objects: dict[str, PiffObjectModel] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)

225 

226 

class PiffSerializationModel(serialization.ArchiveTree):
    """Serialization model for a Piff PSF."""

    piff: PiffObjectModel = pydantic.Field(description="The Piff PSF object itself.")

    stamp_size: int = pydantic.Field(
        description="Width of the (square) images returned by this PSF's methods."
    )

    bounds: SerializableBounds = pydantic.Field(
        description="The bounds object that represents the PSF's validity region."
    )

    def deserialize(self, archive: serialization.InputArchive[Any], **kwargs: Any) -> PiffWrapper:
        """Deserialize the PSF from an archive.

        This method is intended to be usable as the callback function passed to
        `.serialization.InputArchive.deserialize_pointer`.

        Parameters
        ----------
        archive
            Archive holding the binary tables referenced by ``self.piff``.
        **kwargs
            Not supported; any keyword argument is rejected.

        Returns
        -------
        PiffWrapper
            The reconstructed PSF, with the stored bounds and stamp size.

        Raises
        ------
        serialization.InvalidParameterError
            Raised if any keyword arguments are passed.
        serialization.ArchiveReadError
            Raised if the ``piff`` package cannot be imported.
        """
        if kwargs:
            raise serialization.InvalidParameterError(
                f"Unrecognized parameters for PiffWrapper: {set(kwargs.keys())}."
            )
        try:
            from piff import PSF
            from piff.config import PiffLogger
        except ImportError as err:
            # Keep the original error as the cause so it is clear whether piff
            # itself is missing or one of its own dependencies failed to load.
            raise serialization.ArchiveReadError("Failed to import piff.") from err

        reader = _ArchivePiffReader(self.piff, archive)
        impl = PSF._read(reader, "piff", PiffLogger(_LOG))
        return PiffWrapper(impl, bounds=self.bounds.deserialize(), stamp_size=self.stamp_size)

259 

260 

261class _ArchivePiffWriter: 

262 """An adapter from the Piff serialization interface to the 

263 `.serialization.OutputArchive` class. 

264 

265 Notes 

266 ----- 

267 Piff has its own simple serialization framework (contributed upstream by 

268 Rubin DM) that maps everything to dictionaries, structured numpy arrays, 

269 and a library of GalSim WCS objects, with the native implementation writing 

270 standalone FITS files. That mostly maps nicely to the `lsst.images` 

271 archive system, but we don't get to leverage any Pydantic validation or 

272 JSON schema functionality since we only get opaque dictionaries from Piff. 

273 

274 See `piff.FitsWriter` for most method documentation; this class is designed 

275 to mimic it exactly (the Piff authors prefer to just use duck-typing rather 

276 than ABCs or protocols for interface definition). 

277 """ 

278 

279 def __init__(self, base_name: str = ""): 

280 self._base_name = base_name 

281 self.structs: dict[str, PiffDict] = {} 

282 self.tables: dict[str, tuple[np.ndarray, PiffDict]] = {} 

283 self.wcs_models: dict[str, GalSimLocalWcsModel] = {} 

284 self.writers: dict[str, _ArchivePiffWriter] = {} 

285 

286 def write_struct(self, name: str, struct: PiffDict) -> None: 

287 self.structs[name] = {k: self._to_builtin(v) for k, v in struct.items()} 

288 

289 def write_table(self, name: str, array: np.ndarray, metadata: PiffDict | None = None) -> None: 

290 self.tables[name] = ( 

291 array, 

292 {k: self._to_builtin(v) for k, v in (metadata or {}).items()}, 

293 ) 

294 

295 def write_wcs_map( 

296 self, name: str, wcs_map: dict[int, galsim.wcs.BaseWCS], pointing: galsim.CelestialCoord | None 

297 ) -> None: 

298 import galsim.wcs 

299 

300 match wcs_map: 

301 case {0: galsim.wcs.PixelScale() as wcs} if pointing is None: 

302 self.wcs_models[name] = GalSimPixelScaleModel(scale=wcs.scale) 

303 case _: 

304 raise NotImplementedError("PSFs with complex embedded WCSs are not supported.") 

305 

306 @contextmanager 

307 def nested(self, name: str) -> Iterator[_ArchivePiffWriter]: 

308 nested = _ArchivePiffWriter(self.get_full_name(name)) 

309 yield nested 

310 self.writers[name] = nested 

311 

312 def get_full_name(self, name: str) -> str: 

313 return f"{self._base_name}/{name}" 

314 

315 def serialize(self, archive: serialization.OutputArchive[Any]) -> PiffObjectModel: 

316 """Serialize to an archive. 

317 

318 This method is intended to be used as the callable passed to 

319 `.serialization.OutputArchive.serialize_direct` and 

320 `.serialization.OutputArchive.serialize_pointer`, after first passing 

321 this writer to a Piff object's ``write`` or ``_write`` method. 

322 """ 

323 model = PiffObjectModel() 

324 for name, struct in self.structs.items(): 

325 model.structs[name] = struct 

326 for name, (array, metadata) in self.tables.items(): 

327 model.tables[name] = PiffTableModel( 

328 metadata=metadata, 

329 table=archive.add_structured_array( 

330 array, name=name, update_header=lambda header: header.update(metadata) 

331 ), 

332 ) 

333 for name, wcs_model in self.wcs_models.items(): 

334 model.wcs[name] = wcs_model 

335 for name, writer in self.writers.items(): 

336 model.objects[name] = archive.serialize_direct(name, writer.serialize) 

337 return model 

338 

339 @staticmethod 

340 def _to_builtin(val: Any) -> PiffValue: 

341 match val: 

342 case np.integer(): 

343 return int(val) 

344 case np.floating(): 

345 return float(val) 

346 case np.bool_(): 

347 return bool(val) 

348 case np.str_(): 

349 return str(val) 

350 case tuple() | list(): 

351 return [_ArchivePiffWriter._to_builtin(item) for item in val] 

352 return val 

353 

354 

355class _ArchivePiffReader: 

356 """An adapter from the Piff serialization interface to the 

357 `.serialization.InputArchive` class. 

358 

359 See `ArchivePiffWriter` for additional notes. 

360 """ 

361 

362 def __init__( 

363 self, object_model: PiffObjectModel, archive: serialization.InputArchive[Any], base_name: str = "" 

364 ): 

365 self._model = object_model 

366 self._archive = archive 

367 self._base_name = base_name 

368 

369 def read_struct(self, name: str) -> PiffDict | None: 

370 return self._model.structs.get(name) 

371 

372 def read_table(self, name: str, metadata: PiffDict | None = None) -> np.ndarray | None: 

373 table_model = self._model.tables.get(name) 

374 if table_model is None: 

375 return None 

376 if metadata is not None: 

377 metadata.update(table_model.metadata) 

378 return self._archive.get_structured_array( 

379 table_model.table, strip_header=astropy.io.fits.Header.clear 

380 ) 

381 

382 def read_wcs_map( 

383 self, name: str, logger: piff.config.LoggerWrapper 

384 ) -> tuple[dict[int, galsim.wcs.BaseWCS] | None, galsim.CelestialCoord | None]: 

385 import galsim.wcs 

386 

387 match self._model.wcs.get(name): 

388 case GalSimPixelScaleModel(scale=scale): 

389 return {0: galsim.wcs.PixelScale(scale)}, None 

390 case None: 

391 return None, None 

392 case unexpected: 

393 raise serialization.ArchiveReadError( 

394 f"{self.get_full_name(name)} should be a WCS or WCS map, not {unexpected!r}." 

395 ) 

396 

397 @contextmanager 

398 def nested(self, name: str) -> Iterator[_ArchivePiffReader]: 

399 nested_model = self._model.objects[name] 

400 yield _ArchivePiffReader(nested_model, self._archive, self.get_full_name(name)) 

401 

402 def get_full_name(self, name: str) -> str: 

403 return f"{self._base_name}/{name}"