Coverage for python / lsst / images / fields / _spline.py: 24%
136 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 08:01 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 08:01 +0000
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14__all__ = ("SplineField", "SplineFieldSerializationModel")
16from typing import TYPE_CHECKING, Any, Literal, final
18import astropy.units
19import numpy as np
20import pydantic
21from scipy.interpolate import Akima1DInterpolator
23from .._concrete_bounds import SerializableBounds
24from .._geom import Bounds, Box
25from .._image import Image
26from ..serialization import ArchiveTree, ArrayReferenceModel, InlineArray, InputArchive, OutputArchive, Unit
27from ._base import BaseField
29if TYPE_CHECKING:
30 try:
31 from lsst.afw.math import BackgroundMI as LegacyBackground
32 except ImportError:
33 type LegacyBackground = Any # type: ignore[no-redef]
36@final
37class SplineField(BaseField):
38 """A 2-d Akima spline interpolation of data on a regular grid.
40 Parameters
41 ----------
42 bounds
43 The region where this field can be evaluated.
44 data
45 The data points to be interpolated. Missing values (indicated by NaN)
46 are allowed. Will be set to read-only in place.
47 y
48 Coordinates for the first dimension of ``data``. Will be set to
49 read-only in place.
50 x
51 Coordinates for the second dimension of ``data``. Will be set to
52 read-only in place.
53 unit
54 Units of the field.
56 Notes
57 -----
58 This field is much faster to evaluate on a grid via `render` than at
59 arbitrary points via the function-call operator.
60 """
62 def __init__(
63 self,
64 bounds: Bounds,
65 data: np.ndarray,
66 *,
67 y: np.ndarray,
68 x: np.ndarray,
69 unit: astropy.units.UnitBase | None = None,
70 ):
71 if isinstance(data, astropy.units.Quantity):
72 if unit is not None:
73 raise TypeError("If 'data' is a Quantity, 'unit' cannot be provided separately.")
74 unit = data.unit
75 data = data.to_value()
76 if data.ndim != 2:
77 raise ValueError("'data' must be 2-d.")
78 if y.ndim != 1:
79 raise ValueError("'y' must be 1-d.")
80 if x.ndim != 1:
81 raise ValueError("'x' must be 1-d.")
82 if data.shape != y.shape + x.shape:
83 raise ValueError(
84 f"Shape of 2-d 'data' {data.shape} does not match "
85 f"expected 1-d 'y' {y.shape} and/or 'x' {x.shape}."
86 )
87 self._bounds = bounds
88 self._data = data
89 self._data.flags.writeable = False
90 self._x = x
91 self._x.flags.writeable = False
92 self._y = y
93 self._y.flags.writeable = False
94 self._unit = unit
96 @property
97 def bounds(self) -> Bounds:
98 return self._bounds
100 @property
101 def unit(self) -> astropy.units.UnitBase | None:
102 return self._unit
104 @property
105 def data(self) -> np.ndarray:
106 """The data points to be interpolated (`numpy.ndarray`).
108 May have missing values indicated by NaNs.
109 """
110 return self._data
112 @property
113 def x(self) -> np.ndarray:
114 """Coordinates for the second dimension of `data` (`numpy.ndarray`)."""
115 return self._x
117 @property
118 def y(self) -> np.ndarray:
119 """Coordinates for the first dimension of `data` (`numpy.ndarray`)."""
120 return self._y
122 def evaluate(
123 self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False
124 ) -> np.ndarray | astropy.units.Quantity:
125 y, x = np.broadcast_arrays(y, x)
126 xg = self._x
127 y_render = np.zeros(xg.shape + y.shape, dtype=np.float64)
128 mask = np.zeros(xg.size, dtype=bool)
129 for j in range(xg.size):
130 if (y_interpolator := self._make_y_interpolator(j)) is not None:
131 y_render[j, ...] = y_interpolator(y)
132 mask[j] = True
133 if not np.all(mask):
134 y_render = y_render[mask, ...]
135 xg = xg[mask]
136 result = np.zeros(y.shape, dtype=np.float64)
137 # There doesn't seem to be a way to avoid looping in Python here;
138 # maybe someday we'll push this down to a compiled language.
139 for i, xv in np.ndenumerate(x):
140 if (x_interpolator := self._make_1d_interpolator(xg, y_render[:, *i])) is None:
141 raise ValueError("No valid data points.")
142 v = x_interpolator(xv)
143 result[*i] = v
144 if quantity:
145 return astropy.units.Quantity(result, self._unit)
146 return result
148 def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image:
149 if bbox is None:
150 bbox = self.bounds.bbox
151 xg = self._x
152 y_render = np.zeros((xg.size, bbox.y.size), dtype=dtype)
153 mask = np.zeros(xg.size, dtype=bool)
154 for j in range(xg.size): # we have to loop, but only over bins, not evaluation points.
155 if (y_interpolator := self._make_y_interpolator(j)) is not None:
156 y_render[j, :] = y_interpolator(bbox.y.arange)
157 mask[j] = True
158 if not np.all(mask):
159 y_render = y_render[mask, :]
160 xg = xg[mask]
161 if (x_interpolator := self._make_1d_interpolator(xg, y_render)) is None:
162 raise ValueError("No valid data points.")
163 rendered_array = x_interpolator(bbox.x.arange)
164 return Image(rendered_array.transpose().copy(), bbox=bbox, unit=self._unit, dtype=dtype)
166 def multiply_constant(
167 self, factor: float | astropy.units.Quantity | astropy.units.UnitBase
168 ) -> SplineField:
169 factor, unit = self._handle_factor_units(factor)
170 return SplineField(self._bounds, self._data * factor, y=self._y, x=self._x, unit=unit)
172 def serialize(self, archive: OutputArchive[Any]) -> SplineFieldSerializationModel:
173 """Serialize the spline field to an output archive."""
174 return SplineFieldSerializationModel(
175 bounds=self.bounds.serialize(),
176 data=archive.add_array(self._data, name="data"),
177 y=self._y,
178 x=self._x,
179 unit=self._unit,
180 )
182 @staticmethod
183 def _get_archive_tree_type(
184 pointer_type: type[Any],
185 ) -> type[SplineFieldSerializationModel]:
186 """Return the serialization model type for this object for an archive
187 type that uses the given pointer type.
188 """
189 return SplineFieldSerializationModel
191 @staticmethod
192 def from_legacy_background(
193 legacy_background: LegacyBackground,
194 unit: astropy.units.UnitBase | None = None,
195 ) -> SplineField:
196 """Convert from a legacy `lsst.afw.math.BackgroundMI` instance.
198 Notes
199 -----
200 `SplineField.render` and the `lsst.afw` background interpolator both
201 use Akima splines, but with slightly different boundary conditions.
202 They should produce equivalent to single-precision round-off error
203 when evaluated within the region enclosed by bin centers (i.e. where
204 no extrapolation is necessary) and when there are five or more
205 points to be interpolated in each row and column.
206 """
207 from lsst.afw.math import ApproximateControl, Interpolate
209 bg_control = legacy_background.getBackgroundControl()
210 approx_control = bg_control.getApproximateControl()
211 stats_image = legacy_background.getStatsImage()
212 if approx_control.getStyle() != ApproximateControl.UNKNOWN:
213 raise TypeError("Legacy background uses Chebyshev approximation, not splines.")
214 if bg_control.getInterpStyle() != Interpolate.AKIMA_SPLINE:
215 raise TypeError("Legacy background does not use Akima spline interpolation.")
216 x = legacy_background.getBinCentersX()
217 y = legacy_background.getBinCentersY()
218 return SplineField(
219 Box.from_legacy(legacy_background.getImageBBox()), stats_image.image.array, x=x, y=y, unit=unit
220 )
222 def _make_1d_interpolator(self, loc: np.ndarray, val: np.ndarray) -> Akima1DInterpolator | None:
223 match len(loc):
224 case 0:
225 return None
226 case 1:
227 # SciPy can handle only two points by downgrading to linear
228 # interpolation, but it raises if given only one. Mock up
229 # two for the nearest-neighbor fallback.
230 return Akima1DInterpolator(np.array([loc[0], loc[0]]), np.array([val[0], val[0]]))
231 case _:
232 return Akima1DInterpolator(loc, val, extrapolate=True)
234 def _make_y_interpolator(self, j: int) -> Akima1DInterpolator | None:
235 y = self._y
236 z = self._data[:, j]
237 mask = np.isfinite(z)
238 if not np.all(mask):
239 y = y[mask]
240 z = z[mask]
241 del mask
242 return self._make_1d_interpolator(y, z)
class SplineFieldSerializationModel(ArchiveTree):
    """Archive representation of a `SplineField`.

    The 2-d data array is stored as an archive array reference, while the
    grid coordinates are stored inline.
    """

    bounds: SerializableBounds = pydantic.Field(
        description="The region where this field can be evaluated.",
    )

    data: ArrayReferenceModel = pydantic.Field(
        description="2-d data to interpolate. NaNs indicate missing values.",
    )

    y: InlineArray = pydantic.Field(
        description="Row positions of the data points.",
    )

    x: InlineArray = pydantic.Field(
        description="Column positions of the data points.",
    )

    unit: Unit | None = pydantic.Field(
        default=None,
        description="Units of the field.",
    )

    field_type: Literal["SPLINE"] = "SPLINE"

    def deserialize(self, archive: InputArchive) -> SplineField:
        """Rebuild the `SplineField` from this model.

        Parameters
        ----------
        archive
            Input archive from which the referenced data array is read.

        Returns
        -------
        field
            The reconstructed spline field.
        """
        field_bounds = self.bounds.deserialize()
        grid_values = archive.get_array(self.data)
        return SplineField(field_bounds, grid_values, y=self.y, x=self.x, unit=self.unit)