Coverage for python/lsst/images/fields/_spline.py: 24%

146 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 09:00 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14__all__ = ("SplineField", "SplineFieldSerializationModel") 

15 

16from typing import TYPE_CHECKING, Any, Literal, final 

17 

18import astropy.units 

19import numpy as np 

20import pydantic 

21from scipy.interpolate import Akima1DInterpolator 

22 

23from .._concrete_bounds import SerializableBounds 

24from .._geom import Bounds, Box 

25from .._image import Image 

26from ..serialization import ( 

27 ArchiveTree, 

28 ArrayReferenceModel, 

29 InlineArray, 

30 InputArchive, 

31 InvalidParameterError, 

32 OutputArchive, 

33 Unit, 

34) 

35from ._base import BaseField 

36 

37if TYPE_CHECKING: 

38 try: 

39 from lsst.afw.math import BackgroundMI as LegacyBackground 

40 except ImportError: 

41 type LegacyBackground = Any # type: ignore[no-redef] 

42 

43 

44@final 

45class SplineField(BaseField): 

46 """A 2-d Akima spline interpolation of data on a regular grid. 

47 

48 Parameters 

49 ---------- 

50 bounds 

51 The region where this field can be evaluated. 

52 data 

53 The data points to be interpolated. Missing values (indicated by NaN) 

54 are allowed. Will be set to read-only in place. 

55 y 

56 Coordinates for the first dimension of ``data``. Will be set to 

57 read-only in place. 

58 x 

59 Coordinates for the second dimension of ``data``. Will be set to 

60 read-only in place. 

61 unit 

62 Units of the field. 

63 

64 Notes 

65 ----- 

66 This field is much faster to evaluate on a grid via `render` than at 

67 arbitrary points via the function-call operator. 

68 """ 

69 

70 def __init__( 

71 self, 

72 bounds: Bounds, 

73 data: np.ndarray, 

74 *, 

75 y: np.ndarray, 

76 x: np.ndarray, 

77 unit: astropy.units.UnitBase | None = None, 

78 ): 

79 if isinstance(data, astropy.units.Quantity): 

80 if unit is not None: 

81 raise TypeError("If 'data' is a Quantity, 'unit' cannot be provided separately.") 

82 unit = data.unit 

83 data = data.to_value() 

84 if data.ndim != 2: 

85 raise ValueError("'data' must be 2-d.") 

86 if y.ndim != 1: 

87 raise ValueError("'y' must be 1-d.") 

88 if x.ndim != 1: 

89 raise ValueError("'x' must be 1-d.") 

90 if data.shape != y.shape + x.shape: 

91 raise ValueError( 

92 f"Shape of 2-d 'data' {data.shape} does not match " 

93 f"expected 1-d 'y' {y.shape} and/or 'x' {x.shape}." 

94 ) 

95 self._bounds = bounds 

96 self._data = data 

97 self._data.flags.writeable = False 

98 self._x = x 

99 self._x.flags.writeable = False 

100 self._y = y 

101 self._y.flags.writeable = False 

102 self._unit = unit 

103 

104 def __eq__(self, other: object) -> bool: 

105 if type(other) is not SplineField: 

106 return NotImplemented 

107 return ( 

108 self._bounds == other._bounds 

109 and self._unit == other._unit 

110 and np.array_equal(self._data, other._data, equal_nan=True) 

111 and np.array_equal(self._x, other._x, equal_nan=True) 

112 and np.array_equal(self._y, other._y, equal_nan=True) 

113 ) 

114 

115 __hash__ = None # type: ignore[assignment] 

116 

117 @property 

118 def bounds(self) -> Bounds: 

119 return self._bounds 

120 

121 @property 

122 def unit(self) -> astropy.units.UnitBase | None: 

123 return self._unit 

124 

125 @property 

126 def data(self) -> np.ndarray: 

127 """The data points to be interpolated (`numpy.ndarray`). 

128 

129 May have missing values indicated by NaNs. 

130 """ 

131 return self._data 

132 

133 @property 

134 def x(self) -> np.ndarray: 

135 """Coordinates for the second dimension of `data` (`numpy.ndarray`).""" 

136 return self._x 

137 

138 @property 

139 def y(self) -> np.ndarray: 

140 """Coordinates for the first dimension of `data` (`numpy.ndarray`).""" 

141 return self._y 

142 

143 @property 

144 def is_constant(self) -> bool: 

145 # We really do want an exact floating-point comparison here. 

146 return (self._data == self._data[0, 0]).all() 

147 

148 def evaluate( 

149 self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False 

150 ) -> np.ndarray | astropy.units.Quantity: 

151 y, x = np.broadcast_arrays(y, x) 

152 xg = self._x 

153 y_render = np.zeros(xg.shape + y.shape, dtype=np.float64) 

154 mask = np.zeros(xg.size, dtype=bool) 

155 for j in range(xg.size): 

156 if (y_interpolator := self._make_y_interpolator(j)) is not None: 

157 y_render[j, ...] = y_interpolator(y) 

158 mask[j] = True 

159 if not np.all(mask): 

160 y_render = y_render[mask, ...] 

161 xg = xg[mask] 

162 result = np.zeros(y.shape, dtype=np.float64) 

163 # There doesn't seem to be a way to avoid looping in Python here; 

164 # maybe someday we'll push this down to a compiled language. 

165 for i, xv in np.ndenumerate(x): 

166 if (x_interpolator := self._make_1d_interpolator(xg, y_render[:, *i])) is None: 

167 raise ValueError("No valid data points.") 

168 v = x_interpolator(xv) 

169 result[*i] = v 

170 if quantity: 

171 return astropy.units.Quantity(result, self._unit) 

172 return result 

173 

174 def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image: 

175 if bbox is None: 

176 bbox = self.bounds.bbox 

177 xg = self._x 

178 y_render = np.zeros((xg.size, bbox.y.size), dtype=dtype) 

179 mask = np.zeros(xg.size, dtype=bool) 

180 for j in range(xg.size): # we have to loop, but only over bins, not evaluation points. 

181 if (y_interpolator := self._make_y_interpolator(j)) is not None: 

182 y_render[j, :] = y_interpolator(bbox.y.arange) 

183 mask[j] = True 

184 if not np.all(mask): 

185 y_render = y_render[mask, :] 

186 xg = xg[mask] 

187 if (x_interpolator := self._make_1d_interpolator(xg, y_render)) is None: 

188 raise ValueError("No valid data points.") 

189 rendered_array = x_interpolator(bbox.x.arange) 

190 return Image(rendered_array.transpose().copy(), bbox=bbox, unit=self._unit, dtype=dtype) 

191 

192 def multiply_constant( 

193 self, factor: float | astropy.units.Quantity | astropy.units.UnitBase 

194 ) -> SplineField: 

195 factor, unit = self._handle_factor_units(factor) 

196 return SplineField(self._bounds, self._data * factor, y=self._y, x=self._x, unit=unit) 

197 

198 def serialize(self, archive: OutputArchive[Any]) -> SplineFieldSerializationModel: 

199 """Serialize the spline field to an output archive.""" 

200 return SplineFieldSerializationModel( 

201 bounds=self.bounds.serialize(), 

202 data=archive.add_array(self._data, name="data"), 

203 y=self._y, 

204 x=self._x, 

205 unit=self._unit, 

206 ) 

207 

208 @staticmethod 

209 def _get_archive_tree_type( 

210 pointer_type: type[Any], 

211 ) -> type[SplineFieldSerializationModel]: 

212 """Return the serialization model type for this object for an archive 

213 type that uses the given pointer type. 

214 """ 

215 return SplineFieldSerializationModel 

216 

217 @staticmethod 

218 def from_legacy_background( 

219 legacy_background: LegacyBackground, 

220 bounds: Bounds | None = None, 

221 unit: astropy.units.UnitBase | None = None, 

222 ) -> SplineField: 

223 """Convert from a legacy `lsst.afw.math.BackgroundMI` instance. 

224 

225 Parameters 

226 ---------- 

227 legacy 

228 Legacy background object to convert. 

229 bounds 

230 The bounds of the returned field, if they should be different from 

231 the bounding box of ``legacy_background``. 

232 unit 

233 The units of the returned field (`lsst.afw.math.Background` 

234 objects do not know their units). 

235 

236 Notes 

237 ----- 

238 `SplineField.render` and the `lsst.afw` background interpolator both 

239 use Akima splines, but with slightly different boundary conditions. 

240 They should produce equivalent to single-precision round-off error 

241 when evaluated within the region enclosed by bin centers (i.e. where 

242 no extrapolation is necessary) and when there are five or more 

243 points to be interpolated in each row and column. 

244 """ 

245 from lsst.afw.math import ApproximateControl, Interpolate 

246 

247 bg_control = legacy_background.getBackgroundControl() 

248 approx_control = bg_control.getApproximateControl() 

249 stats_image = legacy_background.getStatsImage() 

250 if approx_control.getStyle() != ApproximateControl.UNKNOWN: 

251 raise TypeError("Legacy background uses Chebyshev approximation, not splines.") 

252 if bg_control.getInterpStyle() != Interpolate.AKIMA_SPLINE: 

253 raise TypeError("Legacy background does not use Akima spline interpolation.") 

254 x = legacy_background.getBinCentersX() 

255 y = legacy_background.getBinCentersY() 

256 return SplineField( 

257 Box.from_legacy(legacy_background.getImageBBox()) if bounds is None else bounds, 

258 stats_image.image.array, 

259 x=x, 

260 y=y, 

261 unit=unit, 

262 ) 

263 

264 def _make_1d_interpolator(self, loc: np.ndarray, val: np.ndarray) -> Akima1DInterpolator | None: 

265 match len(loc): 

266 case 0: 

267 return None 

268 case 1: 

269 # SciPy can handle only two points by downgrading to linear 

270 # interpolation, but it raises if given only one. Mock up 

271 # two for the nearest-neighbor fallback. 

272 return Akima1DInterpolator(np.array([loc[0], loc[0]]), np.array([val[0], val[0]])) 

273 case _: 

274 return Akima1DInterpolator(loc, val, extrapolate=True) 

275 

276 def _make_y_interpolator(self, j: int) -> Akima1DInterpolator | None: 

277 y = self._y 

278 z = self._data[:, j] 

279 mask = np.isfinite(z) 

280 if not np.all(mask): 

281 y = y[mask] 

282 z = z[mask] 

283 del mask 

284 return self._make_1d_interpolator(y, z) 

285 

286 

class SplineFieldSerializationModel(ArchiveTree):
    """Serialization model for `SplineField`.

    The 2-d ``data`` array is held as an `ArrayReferenceModel` that is
    resolved through the archive on deserialization, while the 1-d ``y`` and
    ``x`` coordinate arrays are `InlineArray` fields of the model itself.
    """

    bounds: SerializableBounds = pydantic.Field(description="The region where this field can be evaluated.")

    data: ArrayReferenceModel = pydantic.Field(description="2-d data to interpolate. NaNs indicate missing values.")

    y: InlineArray = pydantic.Field(description="Row positions of the data points.")

    x: InlineArray = pydantic.Field(description="Column positions of the data points.")

    unit: Unit | None = pydantic.Field(default=None, description="Units of the field.")

    # Discriminator tag identifying this kind of field in serialized form.
    field_type: Literal["SPLINE"] = "SPLINE"

    def deserialize(self, archive: InputArchive, **kwargs: Any) -> SplineField:
        """Reconstruct the `SplineField` described by this model.

        Parameters
        ----------
        archive
            Archive used to resolve the reference to the ``data`` array.
        **kwargs
            No keyword arguments are currently accepted.

        Returns
        -------
        SplineField
            The deserialized field.

        Raises
        ------
        InvalidParameterError
            Raised if any keyword arguments are given.
        """
        if not kwargs:
            return SplineField(
                self.bounds.deserialize(),
                archive.get_array(self.data),
                y=self.y,
                x=self.x,
                unit=self.unit,
            )
        unexpected = set(kwargs)
        raise InvalidParameterError(f"Unrecognized parameters for SplineField: {unexpected}.")