Coverage for python / lsst / images / fields / _spline.py: 23%
138 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 08:52 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 08:52 +0000
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14__all__ = ("SplineField", "SplineFieldSerializationModel")
16from typing import TYPE_CHECKING, Any, Literal, final
18import astropy.units
19import numpy as np
20import pydantic
21from scipy.interpolate import Akima1DInterpolator
23from .._concrete_bounds import SerializableBounds
24from .._geom import Bounds, Box
25from .._image import Image
26from ..serialization import (
27 ArchiveTree,
28 ArrayReferenceModel,
29 InlineArray,
30 InputArchive,
31 InvalidParameterError,
32 OutputArchive,
33 Unit,
34)
35from ._base import BaseField
37if TYPE_CHECKING:
38 try:
39 from lsst.afw.math import BackgroundMI as LegacyBackground
40 except ImportError:
41 type LegacyBackground = Any # type: ignore[no-redef]
44@final
45class SplineField(BaseField):
46 """A 2-d Akima spline interpolation of data on a regular grid.
48 Parameters
49 ----------
50 bounds
51 The region where this field can be evaluated.
52 data
53 The data points to be interpolated. Missing values (indicated by NaN)
54 are allowed. Will be set to read-only in place.
55 y
56 Coordinates for the first dimension of ``data``. Will be set to
57 read-only in place.
58 x
59 Coordinates for the second dimension of ``data``. Will be set to
60 read-only in place.
61 unit
62 Units of the field.
64 Notes
65 -----
66 This field is much faster to evaluate on a grid via `render` than at
67 arbitrary points via the function-call operator.
68 """
70 def __init__(
71 self,
72 bounds: Bounds,
73 data: np.ndarray,
74 *,
75 y: np.ndarray,
76 x: np.ndarray,
77 unit: astropy.units.UnitBase | None = None,
78 ):
79 if isinstance(data, astropy.units.Quantity):
80 if unit is not None:
81 raise TypeError("If 'data' is a Quantity, 'unit' cannot be provided separately.")
82 unit = data.unit
83 data = data.to_value()
84 if data.ndim != 2:
85 raise ValueError("'data' must be 2-d.")
86 if y.ndim != 1:
87 raise ValueError("'y' must be 1-d.")
88 if x.ndim != 1:
89 raise ValueError("'x' must be 1-d.")
90 if data.shape != y.shape + x.shape:
91 raise ValueError(
92 f"Shape of 2-d 'data' {data.shape} does not match "
93 f"expected 1-d 'y' {y.shape} and/or 'x' {x.shape}."
94 )
95 self._bounds = bounds
96 self._data = data
97 self._data.flags.writeable = False
98 self._x = x
99 self._x.flags.writeable = False
100 self._y = y
101 self._y.flags.writeable = False
102 self._unit = unit
104 @property
105 def bounds(self) -> Bounds:
106 return self._bounds
108 @property
109 def unit(self) -> astropy.units.UnitBase | None:
110 return self._unit
112 @property
113 def data(self) -> np.ndarray:
114 """The data points to be interpolated (`numpy.ndarray`).
116 May have missing values indicated by NaNs.
117 """
118 return self._data
120 @property
121 def x(self) -> np.ndarray:
122 """Coordinates for the second dimension of `data` (`numpy.ndarray`)."""
123 return self._x
125 @property
126 def y(self) -> np.ndarray:
127 """Coordinates for the first dimension of `data` (`numpy.ndarray`)."""
128 return self._y
130 def evaluate(
131 self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False
132 ) -> np.ndarray | astropy.units.Quantity:
133 y, x = np.broadcast_arrays(y, x)
134 xg = self._x
135 y_render = np.zeros(xg.shape + y.shape, dtype=np.float64)
136 mask = np.zeros(xg.size, dtype=bool)
137 for j in range(xg.size):
138 if (y_interpolator := self._make_y_interpolator(j)) is not None:
139 y_render[j, ...] = y_interpolator(y)
140 mask[j] = True
141 if not np.all(mask):
142 y_render = y_render[mask, ...]
143 xg = xg[mask]
144 result = np.zeros(y.shape, dtype=np.float64)
145 # There doesn't seem to be a way to avoid looping in Python here;
146 # maybe someday we'll push this down to a compiled language.
147 for i, xv in np.ndenumerate(x):
148 if (x_interpolator := self._make_1d_interpolator(xg, y_render[:, *i])) is None:
149 raise ValueError("No valid data points.")
150 v = x_interpolator(xv)
151 result[*i] = v
152 if quantity:
153 return astropy.units.Quantity(result, self._unit)
154 return result
156 def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image:
157 if bbox is None:
158 bbox = self.bounds.bbox
159 xg = self._x
160 y_render = np.zeros((xg.size, bbox.y.size), dtype=dtype)
161 mask = np.zeros(xg.size, dtype=bool)
162 for j in range(xg.size): # we have to loop, but only over bins, not evaluation points.
163 if (y_interpolator := self._make_y_interpolator(j)) is not None:
164 y_render[j, :] = y_interpolator(bbox.y.arange)
165 mask[j] = True
166 if not np.all(mask):
167 y_render = y_render[mask, :]
168 xg = xg[mask]
169 if (x_interpolator := self._make_1d_interpolator(xg, y_render)) is None:
170 raise ValueError("No valid data points.")
171 rendered_array = x_interpolator(bbox.x.arange)
172 return Image(rendered_array.transpose().copy(), bbox=bbox, unit=self._unit, dtype=dtype)
174 def multiply_constant(
175 self, factor: float | astropy.units.Quantity | astropy.units.UnitBase
176 ) -> SplineField:
177 factor, unit = self._handle_factor_units(factor)
178 return SplineField(self._bounds, self._data * factor, y=self._y, x=self._x, unit=unit)
180 def serialize(self, archive: OutputArchive[Any]) -> SplineFieldSerializationModel:
181 """Serialize the spline field to an output archive."""
182 return SplineFieldSerializationModel(
183 bounds=self.bounds.serialize(),
184 data=archive.add_array(self._data, name="data"),
185 y=self._y,
186 x=self._x,
187 unit=self._unit,
188 )
190 @staticmethod
191 def _get_archive_tree_type(
192 pointer_type: type[Any],
193 ) -> type[SplineFieldSerializationModel]:
194 """Return the serialization model type for this object for an archive
195 type that uses the given pointer type.
196 """
197 return SplineFieldSerializationModel
199 @staticmethod
200 def from_legacy_background(
201 legacy_background: LegacyBackground,
202 bounds: Bounds | None = None,
203 unit: astropy.units.UnitBase | None = None,
204 ) -> SplineField:
205 """Convert from a legacy `lsst.afw.math.BackgroundMI` instance.
207 Parameters
208 ----------
209 legacy
210 Legacy background object to convert.
211 bounds
212 The bounds of the returned field, if they should be different from
213 the bounding box of ``legacy_background``.
214 unit
215 The units of the returned field (`lsst.afw.math.Background`
216 objects do not know their units).
218 Notes
219 -----
220 `SplineField.render` and the `lsst.afw` background interpolator both
221 use Akima splines, but with slightly different boundary conditions.
222 They should produce equivalent to single-precision round-off error
223 when evaluated within the region enclosed by bin centers (i.e. where
224 no extrapolation is necessary) and when there are five or more
225 points to be interpolated in each row and column.
226 """
227 from lsst.afw.math import ApproximateControl, Interpolate
229 bg_control = legacy_background.getBackgroundControl()
230 approx_control = bg_control.getApproximateControl()
231 stats_image = legacy_background.getStatsImage()
232 if approx_control.getStyle() != ApproximateControl.UNKNOWN:
233 raise TypeError("Legacy background uses Chebyshev approximation, not splines.")
234 if bg_control.getInterpStyle() != Interpolate.AKIMA_SPLINE:
235 raise TypeError("Legacy background does not use Akima spline interpolation.")
236 x = legacy_background.getBinCentersX()
237 y = legacy_background.getBinCentersY()
238 return SplineField(
239 Box.from_legacy(legacy_background.getImageBBox()) if bounds is None else bounds,
240 stats_image.image.array,
241 x=x,
242 y=y,
243 unit=unit,
244 )
246 def _make_1d_interpolator(self, loc: np.ndarray, val: np.ndarray) -> Akima1DInterpolator | None:
247 match len(loc):
248 case 0:
249 return None
250 case 1:
251 # SciPy can handle only two points by downgrading to linear
252 # interpolation, but it raises if given only one. Mock up
253 # two for the nearest-neighbor fallback.
254 return Akima1DInterpolator(np.array([loc[0], loc[0]]), np.array([val[0], val[0]]))
255 case _:
256 return Akima1DInterpolator(loc, val, extrapolate=True)
258 def _make_y_interpolator(self, j: int) -> Akima1DInterpolator | None:
259 y = self._y
260 z = self._data[:, j]
261 mask = np.isfinite(z)
262 if not np.all(mask):
263 y = y[mask]
264 z = z[mask]
265 del mask
266 return self._make_1d_interpolator(y, z)
class SplineFieldSerializationModel(ArchiveTree):
    """Serialization model for `SplineField`."""

    bounds: SerializableBounds = pydantic.Field(
        description="The region where this field can be evaluated.",
    )

    data: ArrayReferenceModel = pydantic.Field(
        description="2-d data to interpolate. NaNs indicate missing values.",
    )

    y: InlineArray = pydantic.Field(
        description="Row positions of the data points.",
    )

    x: InlineArray = pydantic.Field(
        description="Column positions of the data points.",
    )

    unit: Unit | None = pydantic.Field(
        default=None,
        description="Units of the field.",
    )

    field_type: Literal["SPLINE"] = "SPLINE"

    def deserialize(self, archive: InputArchive, **kwargs: Any) -> SplineField:
        """Rebuild a `SplineField` from this model and an input archive.

        No extra keyword parameters are accepted; passing any raises
        `InvalidParameterError`.
        """
        if not kwargs:
            # Resolve the bounds first, then fetch the data array from the
            # archive; the coordinate arrays are stored inline on the model.
            field_bounds = self.bounds.deserialize()
            field_data = archive.get_array(self.data)
            return SplineField(field_bounds, field_data, y=self.y, x=self.x, unit=self.unit)
        raise InvalidParameterError(f"Unrecognized parameters for SplineField: {set(kwargs.keys())}.")