Coverage for python/lsst/images/cells/_psf.py: 28%
116 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-27 08:29 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-27 08:29 +0000
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14__all__ = ("CellPointSpreadFunction", "CellPointSpreadFunctionSerializationModel")
16from functools import cached_property
17from typing import TYPE_CHECKING, Any, Literal, overload
19import numpy as np
20import pydantic
22from .._cell_grid import CellGrid, CellGridBounds, CellIJ
23from .._geom import YX, Bounds, BoundsError, Box
24from .._image import Image
25from ..psfs import PointSpreadFunction
26from ..serialization import (
27 ArchiveTree,
28 ArrayReferenceModel,
29 InputArchive,
30 InvalidParameterError,
31 OutputArchive,
32)
33from ..utils import round_half_up
35if TYPE_CHECKING:
36 try:
37 from lsst.cell_coadds import StitchedPsf
38 except ImportError:
39 type StitchedPsf = Any # type: ignore[no-redef]
42class CellPointSpreadFunction(PointSpreadFunction):
43 """A PSF model that is at least approximately constant over cells.
45 Parameters
46 ----------
47 array
48 A 4-d array of PSF kernel images with with shape
49 ``(n_cells_y, n_cells_x, psf_shape_y, psf_shape_x)``.
50 bounds
51 Description of the cell grid and any missing cells. Array entries for
52 missing cells should be NaN.
53 resampling_kernel
54 Name of the resampling kernel to use when shifting the kernel image
55 into the stellar image.
57 Notes
58 -----
59 Unlike most PSF model types, `CellPointSpreadFunction` can be subset via
60 slicing:
62 - a bounding `.Box` for a subimage, which returns a new PSF with only the
63 cells that cover that subimage;
64 - a `CellIJ` index, which returns the kernel image for that cell.
65 -
66 """
68 def __init__(
69 self,
70 array: np.ndarray,
71 bounds: CellGridBounds,
72 resampling_kernel: Literal["lanczos3", "lanczos5"] = "lanczos5",
73 ):
74 self._array = array
75 self._bounds: CellGridBounds = bounds
76 self._resampling_kernel = resampling_kernel
78 @property
79 def grid(self) -> CellGrid:
80 """The grid that defines the PSF's cells (`CellGrid`).
82 Notes
83 -----
84 This is usually (but is not guaranteed to be) the grid for a full
85 patch, even when the PSF only covers a subimage.
86 """
87 return self._bounds.grid
89 @property
90 def bounds(self) -> CellGridBounds:
91 """The bounds where the PSF can be evaluated (`CellGridBounds`)."""
92 return self._bounds
94 @cached_property
95 def kernel_bbox(self) -> Box:
96 sy, sx = self._array.shape[2:]
97 ry = sy // 2
98 rx = sx // 2
99 return Box.factory[-ry : ry + 1, -rx : rx + 1]
101 @overload
102 def __getitem__(self, bbox: Box) -> CellPointSpreadFunction: ... 102 ↛ exitline 102 didn't return from function '__getitem__' because
103 @overload
104 def __getitem__(self, index: CellIJ) -> Image: ... 104 ↛ exitline 104 didn't return from function '__getitem__' because
106 def __getitem__(self, key: Box | CellIJ) -> CellPointSpreadFunction | Image:
107 match key:
108 case CellIJ():
109 if key in self._bounds.missing:
110 raise BoundsError(f"Cell {key} is missing for this PSF.")
111 index = key - self._bounds.grid_start
112 try:
113 return Image(self._array[index.i, index.j], bbox=self.kernel_bbox)
114 except IndexError:
115 raise BoundsError(f"Cell {key} is out of bounds for this PSF.")
116 case Box():
117 bounds, slices = self._subset_impl(self._bounds, key)
118 return CellPointSpreadFunction(self._array[slices.y, slices.x, ...].copy(), bounds=bounds)
119 case _:
120 raise TypeError("Invalid argument for CellPointSpreadFunction.__getitem__.")
122 def compute_kernel_image(self, *, x: float, y: float) -> Image:
123 index = self.grid.index_of(x=round(x), y=round(y))
124 try:
125 return self[index]
126 except Exception as err:
127 err.add_note(f"Evaluating cell PSF at x={x}, y={y}.")
128 raise
130 def compute_stellar_image(self, *, x: float, y: float) -> Image:
131 try:
132 from lsst.afw.math import offsetImage
133 from lsst.geom import Point2I
134 except ImportError as err:
135 err.add_note("CellPointSpreadFunction.compute_stellar_image cannot be used without lsst.afw.")
136 raise
137 ix = round_half_up(x)
138 dx = x - ix
139 iy = round_half_up(y)
140 dy = y - iy
141 kernel_image = self.compute_kernel_image(x=x, y=y)
142 if dx != 0 or dy != 0:
143 legacy_result = offsetImage(kernel_image.to_legacy(), dx, dy, self._resampling_kernel, 5)
144 else:
145 # This branch is equal to the other up to round-off error, but it's
146 # convenient nonetheless because it maintains exact compatibility
147 # with the legacy implementation, where the caching mechanism
148 # causes the offsetImage call to be skipped.
149 legacy_result = kernel_image.to_legacy()
150 legacy_result.setXY0(Point2I(legacy_result.getX0() + ix, legacy_result.getY0() + iy))
151 return Image.from_legacy(legacy_result)
153 def compute_stellar_bbox(self, *, x: float, y: float) -> Box:
154 # This is obviously inefficient, but it's what afw does, and hence the
155 # only easy way we've got to replicate what afw does.
156 return self.compute_stellar_image(x=x, y=y).bbox
158 def serialize(self, archive: OutputArchive[Any]) -> CellPointSpreadFunctionSerializationModel:
159 array_model = archive.add_array(self._array)
160 return CellPointSpreadFunctionSerializationModel(array=array_model, bounds=self.bounds)
162 @classmethod
163 def from_legacy(cls, legacy_psf: Any, bounds: Bounds | None = None) -> CellPointSpreadFunction:
164 # 'bounds' is accepted as an argument only for base-class
165 # compatibility; we always generate our own bounds.
166 from lsst.geom import Box2I
168 grid = CellGrid.from_legacy(legacy_psf.grid)
169 # Start with bounds that cover the entire grid.
170 bounds = CellGridBounds(grid=grid, bbox=grid.bbox)
171 # Shrink bounds to just the bbox where we have data.
172 legacy_bbox = Box2I()
173 for legacy_index in legacy_psf.images.keys():
174 legacy_bbox.include(legacy_psf.grid.bbox_of(legacy_index))
175 bounds = bounds[Box.from_legacy(legacy_bbox)]
176 # Allocate and populate the array.
177 psf_image_size_y, psf_image_size_x = legacy_psf.images.arbitrary.array.shape
178 array = np.zeros(
179 (
180 bounds.bbox.y.size // grid.cell_shape.y,
181 bounds.bbox.x.size // grid.cell_shape.x,
182 psf_image_size_y,
183 psf_image_size_x,
184 ),
185 dtype=np.float64,
186 )
187 missing: set[CellIJ] = set()
188 for cell_index in bounds.cell_indices():
189 legacy_index = cell_index.to_legacy()
190 array_index = cell_index - bounds.grid_start
191 if legacy_index in legacy_psf.images:
192 array[array_index.i, array_index.j] = legacy_psf.images[legacy_index].array
193 else:
194 array[array_index.i, array_index.j] = np.nan
195 missing.add(cell_index)
196 # Modify the bounds one last time to account for missing cells.
197 bounds = CellGridBounds(grid=grid, bbox=bounds.bbox, missing=frozenset(missing))
198 return cls(array, bounds=bounds)
200 @staticmethod
201 def _subset_impl(bounds: CellGridBounds, bbox: Box) -> tuple[CellGridBounds, YX[slice]]:
202 subset_bounds = bounds[bbox]
203 start = subset_bounds.grid_start - bounds.grid_start
204 stop = subset_bounds.grid_stop - bounds.grid_start
205 return subset_bounds, YX(y=slice(start.i, stop.i), x=slice(start.j, stop.j))
class CellPointSpreadFunctionSerializationModel(ArchiveTree):
    """Model used to serialize CellPointSpreadFunction objects."""

    array: ArrayReferenceModel = pydantic.Field(
        description=(
            "A 4-d array of PSF kernel images with with shape "
            "(n_cells_y, n_cells_x, psf_shape_y, psf_shape_x)."
        )
    )
    bounds: CellGridBounds = pydantic.Field(
        description=(
            "Description of the cell grid and any missing cells. Array entries for "
            "missing cells should be NaN."
        )
    )

    def deserialize(
        self, archive: InputArchive[Any], *, bbox: Box | None = None, **kwargs: Any
    ) -> CellPointSpreadFunction:
        """Rebuild a `CellPointSpreadFunction` from an archive.

        Parameters
        ----------
        archive
            Archive the kernel array is read from.
        bbox
            If given, only the cells selected by this box are read.
        **kwargs
            Not supported; any keyword argument is rejected.

        Returns
        -------
        `CellPointSpreadFunction`
            PSF built from the stored array and bounds, or from their subset
            when ``bbox`` is given.

        Raises
        ------
        InvalidParameterError
            Raised if any extra keyword argument is passed.
        """
        if kwargs:
            raise InvalidParameterError(
                f"Unrecognized parameters for CellPointSpreadFunction: {set(kwargs.keys())}."
            )
        full_bounds = self.bounds
        if bbox is None:
            # No subimage requested: read the whole array.
            return CellPointSpreadFunction(archive.get_array(self.array), full_bounds)
        # Read only the cells selected by the requested box.
        subset_bounds, cell_slices = CellPointSpreadFunction._subset_impl(full_bounds, bbox)
        kernels = archive.get_array(self.array, slices=cell_slices)
        return CellPointSpreadFunction(kernels, subset_bounds)