Coverage for python/lsst/images/fields/_spline.py: 24%
146 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 02:13 -0700
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 02:13 -0700
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14__all__ = ("SplineField", "SplineFieldSerializationModel")
16from typing import TYPE_CHECKING, Any, Literal, final
18import astropy.units
19import numpy as np
20import pydantic
21from scipy.interpolate import Akima1DInterpolator
23from .._concrete_bounds import SerializableBounds
24from .._geom import Bounds, Box
25from .._image import Image
26from ..serialization import (
27 ArchiveTree,
28 ArrayReferenceModel,
29 InlineArray,
30 InputArchive,
31 InvalidParameterError,
32 OutputArchive,
33 Unit,
34)
35from ._base import BaseField
37if TYPE_CHECKING:
38 try:
39 from lsst.afw.math import BackgroundMI as LegacyBackground
40 except ImportError:
41 type LegacyBackground = Any # type: ignore[no-redef]
44@final
45class SplineField(BaseField):
46 """A 2-d Akima spline interpolation of data on a regular grid.
48 Parameters
49 ----------
50 bounds
51 The region where this field can be evaluated.
52 data
53 The data points to be interpolated. Missing values (indicated by NaN)
54 are allowed. Will be set to read-only in place.
55 y
56 Coordinates for the first dimension of ``data``. Will be set to
57 read-only in place.
58 x
59 Coordinates for the second dimension of ``data``. Will be set to
60 read-only in place.
61 unit
62 Units of the field.
64 Notes
65 -----
66 This field is much faster to evaluate on a grid via `render` than at
67 arbitrary points via the function-call operator.
68 """
70 def __init__(
71 self,
72 bounds: Bounds,
73 data: np.ndarray,
74 *,
75 y: np.ndarray,
76 x: np.ndarray,
77 unit: astropy.units.UnitBase | None = None,
78 ):
79 if isinstance(data, astropy.units.Quantity):
80 if unit is not None:
81 raise TypeError("If 'data' is a Quantity, 'unit' cannot be provided separately.")
82 unit = data.unit
83 data = data.to_value()
84 if data.ndim != 2:
85 raise ValueError("'data' must be 2-d.")
86 if y.ndim != 1:
87 raise ValueError("'y' must be 1-d.")
88 if x.ndim != 1:
89 raise ValueError("'x' must be 1-d.")
90 if data.shape != y.shape + x.shape:
91 raise ValueError(
92 f"Shape of 2-d 'data' {data.shape} does not match "
93 f"expected 1-d 'y' {y.shape} and/or 'x' {x.shape}."
94 )
95 self._bounds = bounds
96 self._data = data
97 self._data.flags.writeable = False
98 self._x = x
99 self._x.flags.writeable = False
100 self._y = y
101 self._y.flags.writeable = False
102 self._unit = unit
104 def __eq__(self, other: object) -> bool:
105 if type(other) is not SplineField:
106 return NotImplemented
107 return (
108 self._bounds == other._bounds
109 and self._unit == other._unit
110 and np.array_equal(self._data, other._data, equal_nan=True)
111 and np.array_equal(self._x, other._x, equal_nan=True)
112 and np.array_equal(self._y, other._y, equal_nan=True)
113 )
115 __hash__ = None # type: ignore[assignment]
117 @property
118 def bounds(self) -> Bounds:
119 return self._bounds
121 @property
122 def unit(self) -> astropy.units.UnitBase | None:
123 return self._unit
125 @property
126 def data(self) -> np.ndarray:
127 """The data points to be interpolated (`numpy.ndarray`).
129 May have missing values indicated by NaNs.
130 """
131 return self._data
133 @property
134 def x(self) -> np.ndarray:
135 """Coordinates for the second dimension of `data` (`numpy.ndarray`)."""
136 return self._x
138 @property
139 def y(self) -> np.ndarray:
140 """Coordinates for the first dimension of `data` (`numpy.ndarray`)."""
141 return self._y
143 @property
144 def is_constant(self) -> bool:
145 # We really do want an exact floating-point comparison here.
146 return (self._data == self._data[0, 0]).all()
148 def evaluate(
149 self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False
150 ) -> np.ndarray | astropy.units.Quantity:
151 y, x = np.broadcast_arrays(y, x)
152 xg = self._x
153 y_render = np.zeros(xg.shape + y.shape, dtype=np.float64)
154 mask = np.zeros(xg.size, dtype=bool)
155 for j in range(xg.size):
156 if (y_interpolator := self._make_y_interpolator(j)) is not None:
157 y_render[j, ...] = y_interpolator(y)
158 mask[j] = True
159 if not np.all(mask):
160 y_render = y_render[mask, ...]
161 xg = xg[mask]
162 result = np.zeros(y.shape, dtype=np.float64)
163 # There doesn't seem to be a way to avoid looping in Python here;
164 # maybe someday we'll push this down to a compiled language.
165 for i, xv in np.ndenumerate(x):
166 if (x_interpolator := self._make_1d_interpolator(xg, y_render[:, *i])) is None:
167 raise ValueError("No valid data points.")
168 v = x_interpolator(xv)
169 result[*i] = v
170 if quantity:
171 return astropy.units.Quantity(result, self._unit)
172 return result
174 def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image:
175 if bbox is None:
176 bbox = self.bounds.bbox
177 xg = self._x
178 y_render = np.zeros((xg.size, bbox.y.size), dtype=dtype)
179 mask = np.zeros(xg.size, dtype=bool)
180 for j in range(xg.size): # we have to loop, but only over bins, not evaluation points.
181 if (y_interpolator := self._make_y_interpolator(j)) is not None:
182 y_render[j, :] = y_interpolator(bbox.y.arange)
183 mask[j] = True
184 if not np.all(mask):
185 y_render = y_render[mask, :]
186 xg = xg[mask]
187 if (x_interpolator := self._make_1d_interpolator(xg, y_render)) is None:
188 raise ValueError("No valid data points.")
189 rendered_array = x_interpolator(bbox.x.arange)
190 return Image(rendered_array.transpose().copy(), bbox=bbox, unit=self._unit, dtype=dtype)
192 def multiply_constant(
193 self, factor: float | astropy.units.Quantity | astropy.units.UnitBase
194 ) -> SplineField:
195 factor, unit = self._handle_factor_units(factor)
196 return SplineField(self._bounds, self._data * factor, y=self._y, x=self._x, unit=unit)
198 def serialize(self, archive: OutputArchive[Any]) -> SplineFieldSerializationModel:
199 """Serialize the spline field to an output archive."""
200 return SplineFieldSerializationModel(
201 bounds=self.bounds.serialize(),
202 data=archive.add_array(self._data, name="data"),
203 y=self._y,
204 x=self._x,
205 unit=self._unit,
206 )
208 @staticmethod
209 def _get_archive_tree_type(
210 pointer_type: type[Any],
211 ) -> type[SplineFieldSerializationModel]:
212 """Return the serialization model type for this object for an archive
213 type that uses the given pointer type.
214 """
215 return SplineFieldSerializationModel
217 @staticmethod
218 def from_legacy_background(
219 legacy_background: LegacyBackground,
220 bounds: Bounds | None = None,
221 unit: astropy.units.UnitBase | None = None,
222 ) -> SplineField:
223 """Convert from a legacy `lsst.afw.math.BackgroundMI` instance.
225 Parameters
226 ----------
227 legacy
228 Legacy background object to convert.
229 bounds
230 The bounds of the returned field, if they should be different from
231 the bounding box of ``legacy_background``.
232 unit
233 The units of the returned field (`lsst.afw.math.Background`
234 objects do not know their units).
236 Notes
237 -----
238 `SplineField.render` and the `lsst.afw` background interpolator both
239 use Akima splines, but with slightly different boundary conditions.
240 They should produce equivalent to single-precision round-off error
241 when evaluated within the region enclosed by bin centers (i.e. where
242 no extrapolation is necessary) and when there are five or more
243 points to be interpolated in each row and column.
244 """
245 from lsst.afw.math import ApproximateControl, Interpolate
247 bg_control = legacy_background.getBackgroundControl()
248 approx_control = bg_control.getApproximateControl()
249 stats_image = legacy_background.getStatsImage()
250 if approx_control.getStyle() != ApproximateControl.UNKNOWN:
251 raise TypeError("Legacy background uses Chebyshev approximation, not splines.")
252 if bg_control.getInterpStyle() != Interpolate.AKIMA_SPLINE:
253 raise TypeError("Legacy background does not use Akima spline interpolation.")
254 x = legacy_background.getBinCentersX()
255 y = legacy_background.getBinCentersY()
256 return SplineField(
257 Box.from_legacy(legacy_background.getImageBBox()) if bounds is None else bounds,
258 stats_image.image.array,
259 x=x,
260 y=y,
261 unit=unit,
262 )
264 def _make_1d_interpolator(self, loc: np.ndarray, val: np.ndarray) -> Akima1DInterpolator | None:
265 match len(loc):
266 case 0:
267 return None
268 case 1:
269 # SciPy can handle only two points by downgrading to linear
270 # interpolation, but it raises if given only one. Mock up
271 # two for the nearest-neighbor fallback.
272 return Akima1DInterpolator(np.array([loc[0], loc[0]]), np.array([val[0], val[0]]))
273 case _:
274 return Akima1DInterpolator(loc, val, extrapolate=True)
276 def _make_y_interpolator(self, j: int) -> Akima1DInterpolator | None:
277 y = self._y
278 z = self._data[:, j]
279 mask = np.isfinite(z)
280 if not np.all(mask):
281 y = y[mask]
282 z = z[mask]
283 del mask
284 return self._make_1d_interpolator(y, z)
class SplineFieldSerializationModel(ArchiveTree):
    """Archive representation of a `SplineField`.

    The 2-d ``data`` array is held as an archive array reference, while the
    ``y`` and ``x`` coordinate arrays are `InlineArray` fields.
    """

    bounds: SerializableBounds = pydantic.Field(description="The region where this field can be evaluated.")

    data: ArrayReferenceModel = pydantic.Field(
        description="2-d data to interpolate. NaNs indicate missing values."
    )

    y: InlineArray = pydantic.Field(description="Row positions of the data points.")

    x: InlineArray = pydantic.Field(description="Column positions of the data points.")

    unit: Unit | None = pydantic.Field(default=None, description="Units of the field.")

    # Discriminator identifying this field type in a serialized archive.
    field_type: Literal["SPLINE"] = "SPLINE"

    def deserialize(self, archive: InputArchive, **kwargs: Any) -> SplineField:
        """Reconstruct the `SplineField` described by this model.

        Parameters
        ----------
        archive
            Archive from which the referenced ``data`` array is read.
        **kwargs
            No extra parameters are supported.

        Raises
        ------
        InvalidParameterError
            Raised if any keyword arguments are given.
        """
        if not kwargs:
            field_bounds = self.bounds.deserialize()
            field_data = archive.get_array(self.data)
            return SplineField(field_bounds, field_data, y=self.y, x=self.x, unit=self.unit)
        unexpected = set(kwargs)
        raise InvalidParameterError(f"Unrecognized parameters for SplineField: {unexpected}.")