Coverage for python / lsst / images / fields / _spline.py: 23%

138 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-20 08:26 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14__all__ = ("SplineField", "SplineFieldSerializationModel") 

15 

16from typing import TYPE_CHECKING, Any, Literal, final 

17 

18import astropy.units 

19import numpy as np 

20import pydantic 

21from scipy.interpolate import Akima1DInterpolator 

22 

23from .._concrete_bounds import SerializableBounds 

24from .._geom import Bounds, Box 

25from .._image import Image 

26from ..serialization import ( 

27 ArchiveTree, 

28 ArrayReferenceModel, 

29 InlineArray, 

30 InputArchive, 

31 InvalidParameterError, 

32 OutputArchive, 

33 Unit, 

34) 

35from ._base import BaseField 

36 

37if TYPE_CHECKING: 

38 try: 

39 from lsst.afw.math import BackgroundMI as LegacyBackground 

40 except ImportError: 

41 type LegacyBackground = Any # type: ignore[no-redef] 

42 

43 

44@final 

45class SplineField(BaseField): 

46 """A 2-d Akima spline interpolation of data on a regular grid. 

47 

48 Parameters 

49 ---------- 

50 bounds 

51 The region where this field can be evaluated. 

52 data 

53 The data points to be interpolated. Missing values (indicated by NaN) 

54 are allowed. Will be set to read-only in place. 

55 y 

56 Coordinates for the first dimension of ``data``. Will be set to 

57 read-only in place. 

58 x 

59 Coordinates for the second dimension of ``data``. Will be set to 

60 read-only in place. 

61 unit 

62 Units of the field. 

63 

64 Notes 

65 ----- 

66 This field is much faster to evaluate on a grid via `render` than at 

67 arbitrary points via the function-call operator. 

68 """ 

69 

70 def __init__( 

71 self, 

72 bounds: Bounds, 

73 data: np.ndarray, 

74 *, 

75 y: np.ndarray, 

76 x: np.ndarray, 

77 unit: astropy.units.UnitBase | None = None, 

78 ): 

79 if isinstance(data, astropy.units.Quantity): 

80 if unit is not None: 

81 raise TypeError("If 'data' is a Quantity, 'unit' cannot be provided separately.") 

82 unit = data.unit 

83 data = data.to_value() 

84 if data.ndim != 2: 

85 raise ValueError("'data' must be 2-d.") 

86 if y.ndim != 1: 

87 raise ValueError("'y' must be 1-d.") 

88 if x.ndim != 1: 

89 raise ValueError("'x' must be 1-d.") 

90 if data.shape != y.shape + x.shape: 

91 raise ValueError( 

92 f"Shape of 2-d 'data' {data.shape} does not match " 

93 f"expected 1-d 'y' {y.shape} and/or 'x' {x.shape}." 

94 ) 

95 self._bounds = bounds 

96 self._data = data 

97 self._data.flags.writeable = False 

98 self._x = x 

99 self._x.flags.writeable = False 

100 self._y = y 

101 self._y.flags.writeable = False 

102 self._unit = unit 

103 

104 @property 

105 def bounds(self) -> Bounds: 

106 return self._bounds 

107 

108 @property 

109 def unit(self) -> astropy.units.UnitBase | None: 

110 return self._unit 

111 

112 @property 

113 def data(self) -> np.ndarray: 

114 """The data points to be interpolated (`numpy.ndarray`). 

115 

116 May have missing values indicated by NaNs. 

117 """ 

118 return self._data 

119 

120 @property 

121 def x(self) -> np.ndarray: 

122 """Coordinates for the second dimension of `data` (`numpy.ndarray`).""" 

123 return self._x 

124 

125 @property 

126 def y(self) -> np.ndarray: 

127 """Coordinates for the first dimension of `data` (`numpy.ndarray`).""" 

128 return self._y 

129 

130 def evaluate( 

131 self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False 

132 ) -> np.ndarray | astropy.units.Quantity: 

133 y, x = np.broadcast_arrays(y, x) 

134 xg = self._x 

135 y_render = np.zeros(xg.shape + y.shape, dtype=np.float64) 

136 mask = np.zeros(xg.size, dtype=bool) 

137 for j in range(xg.size): 

138 if (y_interpolator := self._make_y_interpolator(j)) is not None: 

139 y_render[j, ...] = y_interpolator(y) 

140 mask[j] = True 

141 if not np.all(mask): 

142 y_render = y_render[mask, ...] 

143 xg = xg[mask] 

144 result = np.zeros(y.shape, dtype=np.float64) 

145 # There doesn't seem to be a way to avoid looping in Python here; 

146 # maybe someday we'll push this down to a compiled language. 

147 for i, xv in np.ndenumerate(x): 

148 if (x_interpolator := self._make_1d_interpolator(xg, y_render[:, *i])) is None: 

149 raise ValueError("No valid data points.") 

150 v = x_interpolator(xv) 

151 result[*i] = v 

152 if quantity: 

153 return astropy.units.Quantity(result, self._unit) 

154 return result 

155 

156 def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image: 

157 if bbox is None: 

158 bbox = self.bounds.bbox 

159 xg = self._x 

160 y_render = np.zeros((xg.size, bbox.y.size), dtype=dtype) 

161 mask = np.zeros(xg.size, dtype=bool) 

162 for j in range(xg.size): # we have to loop, but only over bins, not evaluation points. 

163 if (y_interpolator := self._make_y_interpolator(j)) is not None: 

164 y_render[j, :] = y_interpolator(bbox.y.arange) 

165 mask[j] = True 

166 if not np.all(mask): 

167 y_render = y_render[mask, :] 

168 xg = xg[mask] 

169 if (x_interpolator := self._make_1d_interpolator(xg, y_render)) is None: 

170 raise ValueError("No valid data points.") 

171 rendered_array = x_interpolator(bbox.x.arange) 

172 return Image(rendered_array.transpose().copy(), bbox=bbox, unit=self._unit, dtype=dtype) 

173 

174 def multiply_constant( 

175 self, factor: float | astropy.units.Quantity | astropy.units.UnitBase 

176 ) -> SplineField: 

177 factor, unit = self._handle_factor_units(factor) 

178 return SplineField(self._bounds, self._data * factor, y=self._y, x=self._x, unit=unit) 

179 

180 def serialize(self, archive: OutputArchive[Any]) -> SplineFieldSerializationModel: 

181 """Serialize the spline field to an output archive.""" 

182 return SplineFieldSerializationModel( 

183 bounds=self.bounds.serialize(), 

184 data=archive.add_array(self._data, name="data"), 

185 y=self._y, 

186 x=self._x, 

187 unit=self._unit, 

188 ) 

189 

190 @staticmethod 

191 def _get_archive_tree_type( 

192 pointer_type: type[Any], 

193 ) -> type[SplineFieldSerializationModel]: 

194 """Return the serialization model type for this object for an archive 

195 type that uses the given pointer type. 

196 """ 

197 return SplineFieldSerializationModel 

198 

199 @staticmethod 

200 def from_legacy_background( 

201 legacy_background: LegacyBackground, 

202 bounds: Bounds | None = None, 

203 unit: astropy.units.UnitBase | None = None, 

204 ) -> SplineField: 

205 """Convert from a legacy `lsst.afw.math.BackgroundMI` instance. 

206 

207 Parameters 

208 ---------- 

209 legacy 

210 Legacy background object to convert. 

211 bounds 

212 The bounds of the returned field, if they should be different from 

213 the bounding box of ``legacy_background``. 

214 unit 

215 The units of the returned field (`lsst.afw.math.Background` 

216 objects do not know their units). 

217 

218 Notes 

219 ----- 

220 `SplineField.render` and the `lsst.afw` background interpolator both 

221 use Akima splines, but with slightly different boundary conditions. 

222 They should produce equivalent to single-precision round-off error 

223 when evaluated within the region enclosed by bin centers (i.e. where 

224 no extrapolation is necessary) and when there are five or more 

225 points to be interpolated in each row and column. 

226 """ 

227 from lsst.afw.math import ApproximateControl, Interpolate 

228 

229 bg_control = legacy_background.getBackgroundControl() 

230 approx_control = bg_control.getApproximateControl() 

231 stats_image = legacy_background.getStatsImage() 

232 if approx_control.getStyle() != ApproximateControl.UNKNOWN: 

233 raise TypeError("Legacy background uses Chebyshev approximation, not splines.") 

234 if bg_control.getInterpStyle() != Interpolate.AKIMA_SPLINE: 

235 raise TypeError("Legacy background does not use Akima spline interpolation.") 

236 x = legacy_background.getBinCentersX() 

237 y = legacy_background.getBinCentersY() 

238 return SplineField( 

239 Box.from_legacy(legacy_background.getImageBBox()) if bounds is None else bounds, 

240 stats_image.image.array, 

241 x=x, 

242 y=y, 

243 unit=unit, 

244 ) 

245 

246 def _make_1d_interpolator(self, loc: np.ndarray, val: np.ndarray) -> Akima1DInterpolator | None: 

247 match len(loc): 

248 case 0: 

249 return None 

250 case 1: 

251 # SciPy can handle only two points by downgrading to linear 

252 # interpolation, but it raises if given only one. Mock up 

253 # two for the nearest-neighbor fallback. 

254 return Akima1DInterpolator(np.array([loc[0], loc[0]]), np.array([val[0], val[0]])) 

255 case _: 

256 return Akima1DInterpolator(loc, val, extrapolate=True) 

257 

258 def _make_y_interpolator(self, j: int) -> Akima1DInterpolator | None: 

259 y = self._y 

260 z = self._data[:, j] 

261 mask = np.isfinite(z) 

262 if not np.all(mask): 

263 y = y[mask] 

264 z = z[mask] 

265 del mask 

266 return self._make_1d_interpolator(y, z) 

267 

268 

class SplineFieldSerializationModel(ArchiveTree):
    """Serialization model for `SplineField`.

    The 2-d data array is written to the archive and referenced from here
    (`ArrayReferenceModel`). The 1-d grid coordinates are stored inline
    (`InlineArray`).
    """

    bounds: SerializableBounds = pydantic.Field(description="The region where this field can be evaluated.")

    data: ArrayReferenceModel = pydantic.Field(
        description="2-d data to interpolate. NaNs indicate missing values."
    )

    y: InlineArray = pydantic.Field(description="Row positions of the data points.")

    x: InlineArray = pydantic.Field(description="Column positions of the data points.")

    unit: Unit | None = pydantic.Field(default=None, description="Units of the field.")

    # Constant tag for this model; presumably used to tell serialized field
    # types apart — confirm against the other field models.
    field_type: Literal["SPLINE"] = "SPLINE"

    def deserialize(self, archive: InputArchive, **kwargs: Any) -> SplineField:
        """Deserialize the spline field from an input archive.

        Parameters
        ----------
        archive
            Archive from which the referenced data array is read.
        **kwargs
            Not supported; any keyword argument is rejected.

        Returns
        -------
        field
            The reconstructed spline field.

        Raises
        ------
        InvalidParameterError
            Raised if any keyword arguments are given.
        """
        if kwargs:
            unknown = set(kwargs.keys())
            raise InvalidParameterError(f"Unrecognized parameters for SplineField: {unknown}.")
        field_bounds = self.bounds.deserialize()
        grid_values = archive.get_array(self.data)
        return SplineField(field_bounds, grid_values, y=self.y, x=self.x, unit=self.unit)
296 )