Coverage for python / lsst / images / fields / _spline.py: 24%
141 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-23 01:30 -0700
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-23 01:30 -0700
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14__all__ = ("SplineField", "SplineFieldSerializationModel")
16from typing import TYPE_CHECKING, Any, Literal, final
18import astropy.units
19import numpy as np
20import pydantic
21from scipy.interpolate import Akima1DInterpolator
23from .._concrete_bounds import SerializableBounds
24from .._geom import Bounds, Box
25from .._image import Image
26from ..serialization import (
27 ArchiveTree,
28 ArrayReferenceModel,
29 InlineArray,
30 InputArchive,
31 InvalidParameterError,
32 OutputArchive,
33 Unit,
34)
35from ._base import BaseField
37if TYPE_CHECKING:
38 try:
39 from lsst.afw.math import BackgroundMI as LegacyBackground
40 except ImportError:
41 type LegacyBackground = Any # type: ignore[no-redef]
44@final
45class SplineField(BaseField):
46 """A 2-d Akima spline interpolation of data on a regular grid.
48 Parameters
49 ----------
50 bounds
51 The region where this field can be evaluated.
52 data
53 The data points to be interpolated. Missing values (indicated by NaN)
54 are allowed. Will be set to read-only in place.
55 y
56 Coordinates for the first dimension of ``data``. Will be set to
57 read-only in place.
58 x
59 Coordinates for the second dimension of ``data``. Will be set to
60 read-only in place.
61 unit
62 Units of the field.
64 Notes
65 -----
66 This field is much faster to evaluate on a grid via `render` than at
67 arbitrary points via the function-call operator.
68 """
70 def __init__(
71 self,
72 bounds: Bounds,
73 data: np.ndarray,
74 *,
75 y: np.ndarray,
76 x: np.ndarray,
77 unit: astropy.units.UnitBase | None = None,
78 ):
79 if isinstance(data, astropy.units.Quantity):
80 if unit is not None:
81 raise TypeError("If 'data' is a Quantity, 'unit' cannot be provided separately.")
82 unit = data.unit
83 data = data.to_value()
84 if data.ndim != 2:
85 raise ValueError("'data' must be 2-d.")
86 if y.ndim != 1:
87 raise ValueError("'y' must be 1-d.")
88 if x.ndim != 1:
89 raise ValueError("'x' must be 1-d.")
90 if data.shape != y.shape + x.shape:
91 raise ValueError(
92 f"Shape of 2-d 'data' {data.shape} does not match "
93 f"expected 1-d 'y' {y.shape} and/or 'x' {x.shape}."
94 )
95 self._bounds = bounds
96 self._data = data
97 self._data.flags.writeable = False
98 self._x = x
99 self._x.flags.writeable = False
100 self._y = y
101 self._y.flags.writeable = False
102 self._unit = unit
104 @property
105 def bounds(self) -> Bounds:
106 return self._bounds
108 @property
109 def unit(self) -> astropy.units.UnitBase | None:
110 return self._unit
112 @property
113 def data(self) -> np.ndarray:
114 """The data points to be interpolated (`numpy.ndarray`).
116 May have missing values indicated by NaNs.
117 """
118 return self._data
120 @property
121 def x(self) -> np.ndarray:
122 """Coordinates for the second dimension of `data` (`numpy.ndarray`)."""
123 return self._x
125 @property
126 def y(self) -> np.ndarray:
127 """Coordinates for the first dimension of `data` (`numpy.ndarray`)."""
128 return self._y
130 @property
131 def is_constant(self) -> bool:
132 # We really do want an exact floating-point comparison here.
133 return (self._data == self._data[0, 0]).all()
135 def evaluate(
136 self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False
137 ) -> np.ndarray | astropy.units.Quantity:
138 y, x = np.broadcast_arrays(y, x)
139 xg = self._x
140 y_render = np.zeros(xg.shape + y.shape, dtype=np.float64)
141 mask = np.zeros(xg.size, dtype=bool)
142 for j in range(xg.size):
143 if (y_interpolator := self._make_y_interpolator(j)) is not None:
144 y_render[j, ...] = y_interpolator(y)
145 mask[j] = True
146 if not np.all(mask):
147 y_render = y_render[mask, ...]
148 xg = xg[mask]
149 result = np.zeros(y.shape, dtype=np.float64)
150 # There doesn't seem to be a way to avoid looping in Python here;
151 # maybe someday we'll push this down to a compiled language.
152 for i, xv in np.ndenumerate(x):
153 if (x_interpolator := self._make_1d_interpolator(xg, y_render[:, *i])) is None:
154 raise ValueError("No valid data points.")
155 v = x_interpolator(xv)
156 result[*i] = v
157 if quantity:
158 return astropy.units.Quantity(result, self._unit)
159 return result
161 def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image:
162 if bbox is None:
163 bbox = self.bounds.bbox
164 xg = self._x
165 y_render = np.zeros((xg.size, bbox.y.size), dtype=dtype)
166 mask = np.zeros(xg.size, dtype=bool)
167 for j in range(xg.size): # we have to loop, but only over bins, not evaluation points.
168 if (y_interpolator := self._make_y_interpolator(j)) is not None:
169 y_render[j, :] = y_interpolator(bbox.y.arange)
170 mask[j] = True
171 if not np.all(mask):
172 y_render = y_render[mask, :]
173 xg = xg[mask]
174 if (x_interpolator := self._make_1d_interpolator(xg, y_render)) is None:
175 raise ValueError("No valid data points.")
176 rendered_array = x_interpolator(bbox.x.arange)
177 return Image(rendered_array.transpose().copy(), bbox=bbox, unit=self._unit, dtype=dtype)
179 def multiply_constant(
180 self, factor: float | astropy.units.Quantity | astropy.units.UnitBase
181 ) -> SplineField:
182 factor, unit = self._handle_factor_units(factor)
183 return SplineField(self._bounds, self._data * factor, y=self._y, x=self._x, unit=unit)
185 def serialize(self, archive: OutputArchive[Any]) -> SplineFieldSerializationModel:
186 """Serialize the spline field to an output archive."""
187 return SplineFieldSerializationModel(
188 bounds=self.bounds.serialize(),
189 data=archive.add_array(self._data, name="data"),
190 y=self._y,
191 x=self._x,
192 unit=self._unit,
193 )
195 @staticmethod
196 def _get_archive_tree_type(
197 pointer_type: type[Any],
198 ) -> type[SplineFieldSerializationModel]:
199 """Return the serialization model type for this object for an archive
200 type that uses the given pointer type.
201 """
202 return SplineFieldSerializationModel
204 @staticmethod
205 def from_legacy_background(
206 legacy_background: LegacyBackground,
207 bounds: Bounds | None = None,
208 unit: astropy.units.UnitBase | None = None,
209 ) -> SplineField:
210 """Convert from a legacy `lsst.afw.math.BackgroundMI` instance.
212 Parameters
213 ----------
214 legacy
215 Legacy background object to convert.
216 bounds
217 The bounds of the returned field, if they should be different from
218 the bounding box of ``legacy_background``.
219 unit
220 The units of the returned field (`lsst.afw.math.Background`
221 objects do not know their units).
223 Notes
224 -----
225 `SplineField.render` and the `lsst.afw` background interpolator both
226 use Akima splines, but with slightly different boundary conditions.
227 They should produce equivalent to single-precision round-off error
228 when evaluated within the region enclosed by bin centers (i.e. where
229 no extrapolation is necessary) and when there are five or more
230 points to be interpolated in each row and column.
231 """
232 from lsst.afw.math import ApproximateControl, Interpolate
234 bg_control = legacy_background.getBackgroundControl()
235 approx_control = bg_control.getApproximateControl()
236 stats_image = legacy_background.getStatsImage()
237 if approx_control.getStyle() != ApproximateControl.UNKNOWN:
238 raise TypeError("Legacy background uses Chebyshev approximation, not splines.")
239 if bg_control.getInterpStyle() != Interpolate.AKIMA_SPLINE:
240 raise TypeError("Legacy background does not use Akima spline interpolation.")
241 x = legacy_background.getBinCentersX()
242 y = legacy_background.getBinCentersY()
243 return SplineField(
244 Box.from_legacy(legacy_background.getImageBBox()) if bounds is None else bounds,
245 stats_image.image.array,
246 x=x,
247 y=y,
248 unit=unit,
249 )
251 def _make_1d_interpolator(self, loc: np.ndarray, val: np.ndarray) -> Akima1DInterpolator | None:
252 match len(loc):
253 case 0:
254 return None
255 case 1:
256 # SciPy can handle only two points by downgrading to linear
257 # interpolation, but it raises if given only one. Mock up
258 # two for the nearest-neighbor fallback.
259 return Akima1DInterpolator(np.array([loc[0], loc[0]]), np.array([val[0], val[0]]))
260 case _:
261 return Akima1DInterpolator(loc, val, extrapolate=True)
263 def _make_y_interpolator(self, j: int) -> Akima1DInterpolator | None:
264 y = self._y
265 z = self._data[:, j]
266 mask = np.isfinite(z)
267 if not np.all(mask):
268 y = y[mask]
269 z = z[mask]
270 del mask
271 return self._make_1d_interpolator(y, z)
class SplineFieldSerializationModel(ArchiveTree):
    """Serialization model for `SplineField`.

    The 2-d data array is stored by reference in the archive, while the
    (small) 1-d grid coordinates are stored inline.
    """

    bounds: SerializableBounds = pydantic.Field(
        description="The region where this field can be evaluated."
    )

    data: ArrayReferenceModel = pydantic.Field(
        description="2-d data to interpolate. NaNs indicate missing values."
    )

    y: InlineArray = pydantic.Field(
        description="Row positions of the data points."
    )

    x: InlineArray = pydantic.Field(
        description="Column positions of the data points."
    )

    unit: Unit | None = pydantic.Field(
        default=None,
        description="Units of the field.",
    )

    field_type: Literal["SPLINE"] = "SPLINE"

    def deserialize(self, archive: InputArchive, **kwargs: Any) -> SplineField:
        """Deserialize the spline field from an input archive.

        Raises
        ------
        InvalidParameterError
            Raised if any extra keyword arguments are given; none are
            supported.
        """
        if kwargs:
            unrecognized = set(kwargs.keys())
            raise InvalidParameterError(f"Unrecognized parameters for SplineField: {unrecognized}.")
        field_bounds = self.bounds.deserialize()
        field_data = archive.get_array(self.data)
        return SplineField(field_bounds, field_data, y=self.y, x=self.x, unit=self.unit)