Coverage for python/lsst/images/cells/_psf.py: 28%
116 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 09:08 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 09:08 +0000
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14__all__ = ("CellPointSpreadFunction", "CellPointSpreadFunctionSerializationModel")
16from functools import cached_property
17from typing import TYPE_CHECKING, Any, Literal, overload
19import numpy as np
20import pydantic
22from .._cell_grid import CellGrid, CellGridBounds, CellIJ
23from .._geom import YX, Bounds, BoundsError, Box
24from .._image import Image
25from ..psfs import PointSpreadFunction
26from ..serialization import (
27 ArchiveTree,
28 ArrayReferenceModel,
29 InlineArrayModel,
30 InputArchive,
31 InvalidParameterError,
32 OutputArchive,
33)
34from ..utils import round_half_up
36if TYPE_CHECKING:
37 try:
38 from lsst.cell_coadds import StitchedPsf
39 except ImportError:
40 type StitchedPsf = Any # type: ignore[no-redef]
43class CellPointSpreadFunction(PointSpreadFunction):
44 """A PSF model that is at least approximately constant over cells.
46 Parameters
47 ----------
48 array
49 A 4-d array of PSF kernel images with with shape
50 ``(n_cells_y, n_cells_x, psf_shape_y, psf_shape_x)``.
51 bounds
52 Description of the cell grid and any missing cells. Array entries for
53 missing cells should be NaN.
54 resampling_kernel
55 Name of the resampling kernel to use when shifting the kernel image
56 into the stellar image.
58 Notes
59 -----
60 Unlike most PSF model types, `CellPointSpreadFunction` can be subset via
61 slicing:
63 - a bounding `.Box` for a subimage, which returns a new PSF with only the
64 cells that cover that subimage;
65 - a `CellIJ` index, which returns the kernel image for that cell.
66 -
67 """
69 def __init__(
70 self,
71 array: np.ndarray,
72 bounds: CellGridBounds,
73 resampling_kernel: Literal["lanczos3", "lanczos5"] = "lanczos5",
74 ):
75 self._array = array
76 self._bounds: CellGridBounds = bounds
77 self._resampling_kernel = resampling_kernel
79 @property
80 def grid(self) -> CellGrid:
81 """The grid that defines the PSF's cells (`CellGrid`).
83 Notes
84 -----
85 This is usually (but is not guaranteed to be) the grid for a full
86 patch, even when the PSF only covers a subimage.
87 """
88 return self._bounds.grid
90 @property
91 def bounds(self) -> CellGridBounds:
92 """The bounds where the PSF can be evaluated (`CellGridBounds`)."""
93 return self._bounds
95 @cached_property
96 def kernel_bbox(self) -> Box:
97 sy, sx = self._array.shape[2:]
98 ry = sy // 2
99 rx = sx // 2
100 return Box.factory[-ry : ry + 1, -rx : rx + 1]
102 @overload
103 def __getitem__(self, bbox: Box) -> CellPointSpreadFunction: ... 103 ↛ exitline 103 didn't return from function '__getitem__' because
104 @overload
105 def __getitem__(self, index: CellIJ) -> Image: ... 105 ↛ exitline 105 didn't return from function '__getitem__' because
107 def __getitem__(self, key: Box | CellIJ) -> CellPointSpreadFunction | Image:
108 match key:
109 case CellIJ():
110 if key in self._bounds.missing:
111 raise BoundsError(f"Cell {key} is missing for this PSF.")
112 index = key - self._bounds.grid_start
113 try:
114 return Image(self._array[index.i, index.j], bbox=self.kernel_bbox)
115 except IndexError:
116 raise BoundsError(f"Cell {key} is out of bounds for this PSF.")
117 case Box():
118 bounds, slices = self._subset_impl(self._bounds, key)
119 return CellPointSpreadFunction(self._array[slices.y, slices.x, ...].copy(), bounds=bounds)
120 case _:
121 raise TypeError("Invalid argument for CellPointSpreadFunction.__getitem__.")
123 def compute_kernel_image(self, *, x: float, y: float) -> Image:
124 index = self.grid.index_of(x=round(x), y=round(y))
125 try:
126 return self[index]
127 except Exception as err:
128 err.add_note(f"Evaluating cell PSF at x={x}, y={y}.")
129 raise
131 def compute_stellar_image(self, *, x: float, y: float) -> Image:
132 try:
133 from lsst.afw.math import offsetImage
134 from lsst.geom import Point2I
135 except ImportError as err:
136 err.add_note("CellPointSpreadFunction.compute_stellar_image cannot be used without lsst.afw.")
137 raise
138 ix = round_half_up(x)
139 dx = x - ix
140 iy = round_half_up(y)
141 dy = y - iy
142 kernel_image = self.compute_kernel_image(x=x, y=y)
143 if dx != 0 or dy != 0:
144 legacy_result = offsetImage(kernel_image.to_legacy(), dx, dy, self._resampling_kernel, 5)
145 else:
146 # This branch is equal to the other up to round-off error, but it's
147 # convenient nonetheless because it maintains exact compatibility
148 # with the legacy implementation, where the caching mechanism
149 # causes the offsetImage call to be skipped.
150 legacy_result = kernel_image.to_legacy()
151 legacy_result.setXY0(Point2I(legacy_result.getX0() + ix, legacy_result.getY0() + iy))
152 return Image.from_legacy(legacy_result)
154 def compute_stellar_bbox(self, *, x: float, y: float) -> Box:
155 # This is obviously inefficient, but it's what afw does, and hence the
156 # only easy way we've got to replicate what afw does.
157 return self.compute_stellar_image(x=x, y=y).bbox
159 def serialize(self, archive: OutputArchive[Any]) -> CellPointSpreadFunctionSerializationModel:
160 array_model = archive.add_array(self._array)
161 return CellPointSpreadFunctionSerializationModel(array=array_model, bounds=self.bounds)
163 @classmethod
164 def from_legacy(cls, legacy_psf: Any, bounds: Bounds | None = None) -> CellPointSpreadFunction:
165 # 'bounds' is accepted as an argument only for base-class
166 # compatibility; we always generate our own bounds.
167 from lsst.geom import Box2I
169 grid = CellGrid.from_legacy(legacy_psf.grid)
170 # Start with bounds that cover the entire grid.
171 bounds = CellGridBounds(grid=grid, bbox=grid.bbox)
172 # Shrink bounds to just the bbox where we have data.
173 legacy_bbox = Box2I()
174 for legacy_index in legacy_psf.images.keys():
175 legacy_bbox.include(legacy_psf.grid.bbox_of(legacy_index))
176 bounds = bounds[Box.from_legacy(legacy_bbox)]
177 # Allocate and populate the array.
178 psf_image_size_y, psf_image_size_x = legacy_psf.images.arbitrary.array.shape
179 array = np.zeros(
180 (
181 bounds.bbox.y.size // grid.cell_shape.y,
182 bounds.bbox.x.size // grid.cell_shape.x,
183 psf_image_size_y,
184 psf_image_size_x,
185 ),
186 dtype=np.float64,
187 )
188 missing: set[CellIJ] = set()
189 for cell_index in bounds.cell_indices():
190 legacy_index = cell_index.to_legacy()
191 array_index = cell_index - bounds.grid_start
192 if legacy_index in legacy_psf.images:
193 array[array_index.i, array_index.j] = legacy_psf.images[legacy_index].array
194 else:
195 array[array_index.i, array_index.j] = np.nan
196 missing.add(cell_index)
197 # Modify the bounds one last time to account for missing cells.
198 bounds = CellGridBounds(grid=grid, bbox=bounds.bbox, missing=frozenset(missing))
199 return cls(array, bounds=bounds)
201 @staticmethod
202 def _subset_impl(bounds: CellGridBounds, bbox: Box) -> tuple[CellGridBounds, YX[slice]]:
203 subset_bounds = bounds[bbox]
204 start = subset_bounds.grid_start - bounds.grid_start
205 stop = subset_bounds.grid_stop - bounds.grid_start
206 return subset_bounds, YX(y=slice(start.i, stop.i), x=slice(start.j, stop.j))
class CellPointSpreadFunctionSerializationModel(ArchiveTree):
    """Model used to serialize `CellPointSpreadFunction` objects."""

    # Reference to (or inline copy of) the 4-d kernel array.
    array: ArrayReferenceModel | InlineArrayModel = pydantic.Field(
        description=(
            "A 4-d array of PSF kernel images with shape "
            "(n_cells_y, n_cells_x, psf_shape_y, psf_shape_x)."
        )
    )
    # Cell grid, covered bounding box, and set of missing cells.
    bounds: CellGridBounds = pydantic.Field(
        description=(
            "Description of the cell grid and any missing cells. Array entries for "
            "missing cells should be NaN."
        )
    )

    def deserialize(
        self, archive: InputArchive[Any], *, bbox: Box | None = None, **kwargs: Any
    ) -> CellPointSpreadFunction:
        """Reconstruct a `CellPointSpreadFunction` from an archive.

        Parameters
        ----------
        archive
            Archive holding the kernel array.
        bbox
            If provided, only the cells covering this box are read, and the
            returned PSF's bounds are restricted accordingly.
        **kwargs
            Not supported; any extra keyword argument is an error.

        Returns
        -------
        `CellPointSpreadFunction`
            The reconstructed PSF (with the default resampling kernel, since
            the model does not store one).

        Raises
        ------
        InvalidParameterError
            Raised if unrecognized keyword arguments are passed.
        """
        if kwargs:
            raise InvalidParameterError(
                f"Unrecognized parameters for CellPointSpreadFunction: {set(kwargs.keys())}."
            )
        if bbox is None:
            return CellPointSpreadFunction(archive.get_array(self.array), self.bounds)
        # Read only the slice of the array covering the requested box.
        bounds, slices = CellPointSpreadFunction._subset_impl(self.bounds, bbox)
        return CellPointSpreadFunction(archive.get_array(self.array, slices=slices), bounds)