Coverage for python/lsst/images/psfs/_piff.py: 36%

189 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-03 08:08 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14from lsst.images.serialization import ArchiveReadError 

15 

16__all__ = ("PiffSerializationModel", "PiffWrapper") 

17 

18import operator 

19from collections.abc import Iterator 

20from contextlib import contextmanager 

21from functools import cached_property 

22from logging import getLogger 

23from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal 

24 

25import astropy.io.fits 

26import numpy as np 

27import pydantic 

28 

29from .. import serialization 

30from .._concrete_bounds import SerializableBounds 

31from .._geom import Bounds, Box 

32from .._image import Image 

33from ..utils import round_half_up 

34from ._base import PointSpreadFunction 

35 

36if TYPE_CHECKING: 

37 import galsim.wcs 

38 import piff 

39 import piff.config 

40 

41 try: 

42 from lsst.meas.extensions.piff.piffPsf import PiffPsf as LegacyPiffPsf 

43 except ImportError: 

44 type LegacyPiffPsf = Any # type: ignore[no-redef] 

45 

46 

# Module-level logger; Piff's own log output is routed through it via PiffLogger.
_LOG = getLogger(name=__name__)

48 

49 

class PiffWrapper(PointSpreadFunction):
    """A PSF model backed by the Piff library.

    Parameters
    ----------
    impl
        The Piff PSF object to wrap.
    bounds
        The pixel-coordinate region where the model can safely be evaluated.
    stamp_size
        Size in pixels of the square stamps Piff is asked to draw in
        `compute_kernel_image` and `compute_stellar_image`.

    Notes
    -----
    The bounding boxes reported by this class (`kernel_bbox`,
    `compute_stellar_bbox`) span ``2 * (stamp_size // 2) + 1`` pixels per
    side, which equals ``stamp_size`` only when ``stamp_size`` is odd.
    """

    def __init__(self, impl: piff.PSF, bounds: Bounds, stamp_size: int):
        self._impl = impl
        self._bounds = bounds
        self._stamp_size = stamp_size

    @property
    def bounds(self) -> Bounds:
        return self._bounds

    @cached_property
    def kernel_bbox(self) -> Box:
        """Box of half-width ``stamp_size // 2`` centered on pixel (0, 0)."""
        r = self._stamp_size // 2
        return Box.factory[-r : r + 1, -r : r + 1]

    def compute_kernel_image(self, *, x: float, y: float) -> Image:
        """Draw the PSF model at ``(x, y)`` centered in the stamp.

        The returned image is normalized to unit sum and starts at
        ``(-r, -r)`` with ``r = stamp_size // 2``.

        Raises
        ------
        NotImplementedError
            If the Piff model depends on source color.
        """
        self._require_achromatic()
        # center=True: draw the profile in the middle of the stamp rather
        # than at its sub-pixel (x, y) position.
        drawn = self._impl.draw(x, y, stamp_size=self._stamp_size, center=True)
        r = self._stamp_size // 2
        return self._unit_sum_image(drawn.array, start=(-r, -r))

    def compute_stellar_image(self, *, x: float, y: float) -> Image:
        """Draw the PSF model at ``(x, y)`` in image pixel coordinates.

        The returned image is normalized to unit sum and is positioned around
        the pixel nearest ``(x, y)`` (rounding half up).

        Raises
        ------
        NotImplementedError
            If the Piff model depends on source color.
        """
        self._require_achromatic()
        # center=None: leave the profile at its true (x, y) position.
        drawn = self._impl.draw(x, y, stamp_size=self._stamp_size, center=None)
        r = self._stamp_size // 2
        # NOTE(review): this assumes Piff places its stamp around
        # round_half_up(x), round_half_up(y); confirm for exactly
        # half-integer positions.
        start = (round_half_up(y) - r, round_half_up(x) - r)
        return self._unit_sum_image(drawn.array, start=start)

    def compute_stellar_bbox(self, *, x: float, y: float) -> Box:
        """Return the box `compute_stellar_image` is positioned on for
        ``(x, y)``.
        """
        r = self._stamp_size // 2
        xi = round_half_up(x)
        yi = round_half_up(y)
        return Box.factory[yi - r : yi + r + 1, xi - r : xi + r + 1]

    def _require_achromatic(self) -> None:
        """Raise if the wrapped Piff model interpolates over source color."""
        if "colorValue" in self._impl.interp_property_names:
            raise NotImplementedError("Chromatic PSFs are not yet supported.")

    @staticmethod
    def _unit_sum_image(array: np.ndarray, start: tuple[int, int]) -> Image:
        """Copy a drawn Piff array into an `Image` rescaled to unit sum."""
        result = Image(array.copy(), start=start)
        result.array /= np.sum(result.array)
        return result

    @property
    def piff_psf(self) -> piff.PSF:
        """The backing `piff.PSF` object.

        This is an internal object that must not be modified in place.
        """
        return self._impl

    @classmethod
    def from_legacy(cls, legacy_psf: LegacyPiffPsf, bounds: Bounds) -> PiffWrapper:
        """Construct from a legacy
        `lsst.meas.extensions.piff.piffPsf.PiffPsf`.

        Parameters
        ----------
        legacy_psf
            Legacy PSF; its wrapped Piff object and its ``width`` (used as the
            stamp size) are reused.
        bounds
            The pixel-coordinate region where the model can safely be
            evaluated.
        """
        return cls(impl=legacy_psf._piffResult, bounds=bounds, stamp_size=int(legacy_psf.width))

    def to_legacy(self) -> LegacyPiffPsf:
        """Convert to a legacy `lsst.meas.extensions.piff.piffPsf.PiffPsf`.

        The legacy object shares the wrapped Piff PSF and uses the stamp size
        for both its width and height.
        """
        from lsst.meas.extensions.piff.piffPsf import PiffPsf as LegacyPiffPsf

        return LegacyPiffPsf(self._stamp_size, self._stamp_size, self._impl)

    def serialize(self, archive: serialization.OutputArchive[Any]) -> PiffSerializationModel:
        """Serialize the PSF to an archive.

        This method is intended to be usable as the callback function passed to
        `.serialization.OutputArchive.serialize_direct` or
        `.serialization.OutputArchive.serialize_pointer`.
        """
        from piff.config import PiffLogger

        writer = _ArchivePiffWriter()
        # Piff writes into our adapter; the star list is dropped first (see
        # `_without_stars`).
        with self._without_stars():
            self._impl._write(writer, "piff", PiffLogger(_LOG))
        piff_model = writer.serialize(archive)
        return PiffSerializationModel(
            piff=piff_model,
            stamp_size=self._stamp_size,
            bounds=self._bounds.serialize(),
        )

    @staticmethod
    def _get_archive_tree_type(
        pointer_type: type[pydantic.BaseModel],
    ) -> type[PiffSerializationModel]:
        """Return the serialization model type for this object for an archive
        type that uses the given pointer type.
        """
        return PiffSerializationModel

    @contextmanager
    def _without_stars(self) -> Iterator[None]:
        """Temporarily drop the embedded list of stars used to fit the PSF.

        Notes
        -----
        By default Piff saves the list of stars (including postage stamps) used
        to fit the PSF, which makes the serialized form much larger. But the
        upstream Piff serialization code recognizes the case where that
        ``stars`` attribute has been deleted and serializes everything else.

        Unfortunately, to date, Rubin's pickle-based Piff serialization instead
        just deletes the postage stamp image attributes from inside the Piff
        ``stars`` list, which is not a state the Piff serialization code
        handles gracefully. So for now we have to drop the full stars list
        during serialization if it is present.
        """
        if hasattr(self._impl, "stars"):
            stars = self._impl.stars
            try:
                del self._impl.stars
                yield
            finally:
                # Always restore, even if serialization raised.
                self._impl.stars = stars
        else:
            yield

171 

172 

173# Conventions on public visibility of the serialization types: 

174# 

175# - We lift and document the outermost Pydantic model type, since that needs to 

176# be included directly in the Pydantic models of types that hold a PSF. This 

177# type needs to be very clearly documented and named as a *serialization* 

178# model, since there are many other kinds of models in play in this package. 

179# 

180# - We do not lift or document types used in that outermost model, but we do 

181# not give them leading underscores, since they aren't really private. 

182# 

183# - Other utility types do get leading underscores. 

184 

185 

# Piff serialization uses a lot of dictionaries and lists restricted to these
# basic types.
#
# NOTE: these aliases are used in pydantic model fields, and pydantic may use
# union member order when deciding how to validate a value; keep the order.
type PiffScalar = int | float | str | bool | None
# A scalar or an arbitrarily nested list of scalars (the alias is recursive).
# Nested dictionaries are not representable.
type PiffValue = PiffScalar | list[PiffValue]
# String-keyed mapping used for ``write_struct`` payloads and table metadata.
type PiffDict = dict[str, PiffValue]

191 

192 

class GalSimPixelScaleModel(pydantic.BaseModel, ser_json_inf_nan="constants"):
    """Model used to serialize `galsim.wcs.PixelScale` instances."""

    # Pixel scale; written from ``PixelScale.scale`` and passed back to the
    # ``galsim.wcs.PixelScale`` constructor on read.
    scale: float
    # Discriminator tag for the `GalSimLocalWcsModel` union.
    wcs_type: Literal["pixel_scale"] = "pixel_scale"

197 wcs_type: Literal["pixel_scale"] = "pixel_scale" 

198 

199 

# We expect this discriminated union to grow to include other trivial
# pixel-to-pixel transforms that get embedded in PSFs. If we someday have to
# store Piff objects that embed more sophisticated PSFs, we'll hook them into
# the AST-based coordinate transform system instead, but as long as we're just
# talking about simple offsets and scalings, that's a lot of extra complexity
# for very little gain.
#
# There is currently only one member. New members should be joined with ``|``
# and must each declare a distinct ``wcs_type`` literal, since that field is
# the discriminator.
type GalSimLocalWcsModel = Annotated[GalSimPixelScaleModel, pydantic.Field(discriminator="wcs_type")]

207 

208 

class PiffTableModel(pydantic.BaseModel, ser_json_inf_nan="constants"):
    """Serialization model used to embed a reference to a binary-data table in
    a Piff serialization's JSON-like data.
    """

    # Metadata Piff passed alongside the table. On write it is also copied
    # into the table's header; on read it is merged into the caller's dict.
    metadata: PiffDict
    # Archive reference to the stored structured array.
    table: serialization.TableModel

216 

217 

class PiffObjectModel(pydantic.BaseModel, ser_json_inf_nan="constants"):
    """General-purpose serialization model used for various Piff objects."""

    # Each field maps a Piff-chosen name to one kind of output from
    # ``_ArchivePiffWriter``. ``exclude_if=operator.not_`` omits a field from
    # serialized output when its dictionary is empty.

    # Dictionaries written with ``write_struct``.
    structs: dict[str, PiffDict] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)
    # Structured-array tables written with ``write_table``.
    tables: dict[str, PiffTableModel] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)
    # Simple WCS objects written with ``write_wcs_map``.
    wcs: dict[str, GalSimLocalWcsModel] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)
    # Sub-objects written through ``nested``; the model is recursive.
    objects: dict[str, PiffObjectModel] = pydantic.Field(default_factory=dict, exclude_if=operator.not_)

225 

226 

class PiffSerializationModel(serialization.ArchiveTree):
    """Serialization model for a Piff PSF."""

    SCHEMA_NAME: ClassVar[str] = "piff_psf"
    SCHEMA_VERSION: ClassVar[str] = "1.0.0"
    MIN_READ_VERSION: ClassVar[int] = 1

    piff: PiffObjectModel = pydantic.Field(description="The Piff PSF object itself.")

    stamp_size: int = pydantic.Field(
        description="Width of the (square) images returned by this PSF's methods."
    )

    bounds: SerializableBounds = pydantic.Field(
        description="The bounds object that represents the PSF's validity region."
    )

    def deserialize(self, archive: serialization.InputArchive[Any], **kwargs: Any) -> PiffWrapper:
        """Deserialize the PSF from an archive.

        This method is intended to be usable as the callback function passed to
        `.serialization.InputArchive.deserialize_pointer`.

        Parameters
        ----------
        archive
            Archive holding the binary tables referenced by ``self.piff``.
        **kwargs
            No extra parameters are supported; any given are rejected.

        Returns
        -------
        PiffWrapper
            The reconstructed PSF.

        Raises
        ------
        serialization.InvalidParameterError
            If any keyword arguments are passed.
        serialization.ArchiveReadError
            If the ``piff`` package cannot be imported.
        """
        if kwargs:
            raise serialization.InvalidParameterError(
                f"Unrecognized parameters for PiffWrapper: {set(kwargs.keys())}."
            )
        try:
            from piff import PSF
            from piff.config import PiffLogger
        except ImportError as err:
            # Chain the ImportError so the underlying reason (missing package,
            # broken dependency, ...) stays visible in the traceback.
            raise serialization.ArchiveReadError("Failed to import piff.") from err

        reader = _ArchivePiffReader(self.piff, archive)
        impl = PSF._read(reader, "piff", PiffLogger(_LOG))
        return PiffWrapper(impl, bounds=self.bounds.deserialize(), stamp_size=self.stamp_size)

263 

264 

265class _ArchivePiffWriter: 

266 """An adapter from the Piff serialization interface to the 

267 `.serialization.OutputArchive` class. 

268 

269 Notes 

270 ----- 

271 Piff has its own simple serialization framework (contributed upstream by 

272 Rubin DM) that maps everything to dictionaries, structured numpy arrays, 

273 and a library of GalSim WCS objects, with the native implementation writing 

274 standalone FITS files. That mostly maps nicely to the `lsst.images` 

275 archive system, but we don't get to leverage any Pydantic validation or 

276 JSON schema functionality since we only get opaque dictionaries from Piff. 

277 

278 See `piff.FitsWriter` for most method documentation; this class is designed 

279 to mimic it exactly (the Piff authors prefer to just use duck-typing rather 

280 than ABCs or protocols for interface definition). 

281 """ 

282 

283 def __init__(self, base_name: str = ""): 

284 self._base_name = base_name 

285 self.structs: dict[str, PiffDict] = {} 

286 self.tables: dict[str, tuple[np.ndarray, PiffDict]] = {} 

287 self.wcs_models: dict[str, GalSimLocalWcsModel] = {} 

288 self.writers: dict[str, _ArchivePiffWriter] = {} 

289 

290 def write_struct(self, name: str, struct: PiffDict) -> None: 

291 self.structs[name] = {k: self._to_builtin(v) for k, v in struct.items()} 

292 

293 def write_table(self, name: str, array: np.ndarray, metadata: PiffDict | None = None) -> None: 

294 self.tables[name] = ( 

295 array, 

296 {k: self._to_builtin(v) for k, v in (metadata or {}).items()}, 

297 ) 

298 

299 def write_wcs_map( 

300 self, name: str, wcs_map: dict[int, galsim.wcs.BaseWCS], pointing: galsim.CelestialCoord | None 

301 ) -> None: 

302 import galsim.wcs 

303 

304 match wcs_map: 

305 case {0: galsim.wcs.PixelScale() as wcs} if pointing is None: 

306 self.wcs_models[name] = GalSimPixelScaleModel(scale=wcs.scale) 

307 case _: 

308 raise NotImplementedError("PSFs with complex embedded WCSs are not supported.") 

309 

310 @contextmanager 

311 def nested(self, name: str) -> Iterator[_ArchivePiffWriter]: 

312 nested = _ArchivePiffWriter(self.get_full_name(name)) 

313 yield nested 

314 self.writers[name] = nested 

315 

316 def get_full_name(self, name: str) -> str: 

317 return f"{self._base_name}/{name}" 

318 

319 def serialize(self, archive: serialization.OutputArchive[Any]) -> PiffObjectModel: 

320 """Serialize to an archive. 

321 

322 This method is intended to be used as the callable passed to 

323 `.serialization.OutputArchive.serialize_direct` and 

324 `.serialization.OutputArchive.serialize_pointer`, after first passing 

325 this writer to a Piff object's ``write`` or ``_write`` method. 

326 """ 

327 model = PiffObjectModel() 

328 for name, struct in self.structs.items(): 

329 model.structs[name] = struct 

330 for name, (array, metadata) in self.tables.items(): 

331 model.tables[name] = PiffTableModel( 

332 metadata=metadata, 

333 table=archive.add_structured_array( 

334 array, name=name, update_header=lambda header: header.update(metadata) 

335 ), 

336 ) 

337 for name, wcs_model in self.wcs_models.items(): 

338 model.wcs[name] = wcs_model 

339 for name, writer in self.writers.items(): 

340 model.objects[name] = archive.serialize_direct(name, writer.serialize) 

341 return model 

342 

343 @staticmethod 

344 def _to_builtin(val: Any) -> PiffValue: 

345 match val: 

346 case np.integer(): 

347 return int(val) 

348 case np.floating(): 

349 return float(val) 

350 case np.bool_(): 

351 return bool(val) 

352 case np.str_(): 

353 return str(val) 

354 case tuple() | list(): 

355 return [_ArchivePiffWriter._to_builtin(item) for item in val] 

356 return val 

357 

358 

359class _ArchivePiffReader: 

360 """An adapter from the Piff serialization interface to the 

361 `.serialization.InputArchive` class. 

362 

363 See `ArchivePiffWriter` for additional notes. 

364 """ 

365 

366 def __init__( 

367 self, object_model: PiffObjectModel, archive: serialization.InputArchive[Any], base_name: str = "" 

368 ): 

369 self._model = object_model 

370 self._archive = archive 

371 self._base_name = base_name 

372 

373 def read_struct(self, name: str) -> PiffDict | None: 

374 return self._model.structs.get(name) 

375 

376 def read_table(self, name: str, metadata: PiffDict | None = None) -> np.ndarray | None: 

377 table_model = self._model.tables.get(name) 

378 if table_model is None: 

379 return None 

380 if metadata is not None: 

381 metadata.update(table_model.metadata) 

382 return self._archive.get_structured_array( 

383 table_model.table, strip_header=astropy.io.fits.Header.clear 

384 ) 

385 

386 def read_wcs_map( 

387 self, name: str, logger: piff.config.LoggerWrapper 

388 ) -> tuple[dict[int, galsim.wcs.BaseWCS] | None, galsim.CelestialCoord | None]: 

389 import galsim.wcs 

390 

391 match self._model.wcs.get(name): 

392 case GalSimPixelScaleModel(scale=scale): 

393 return {0: galsim.wcs.PixelScale(scale)}, None 

394 case None: 

395 return None, None 

396 case unexpected: 

397 raise serialization.ArchiveReadError( 

398 f"{self.get_full_name(name)} should be a WCS or WCS map, not {unexpected!r}." 

399 ) 

400 

401 @contextmanager 

402 def nested(self, name: str) -> Iterator[_ArchivePiffReader]: 

403 nested_model = self._model.objects[name] 

404 yield _ArchivePiffReader(nested_model, self._archive, self.get_full_name(name)) 

405 

406 def get_full_name(self, name: str) -> str: 

407 return f"{self._base_name}/{name}"