Coverage for python / lsst / images / fields / _spline.py: 24%

141 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-23 08:27 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14__all__ = ("SplineField", "SplineFieldSerializationModel") 

15 

16from typing import TYPE_CHECKING, Any, Literal, final 

17 

18import astropy.units 

19import numpy as np 

20import pydantic 

21from scipy.interpolate import Akima1DInterpolator 

22 

23from .._concrete_bounds import SerializableBounds 

24from .._geom import Bounds, Box 

25from .._image import Image 

26from ..serialization import ( 

27 ArchiveTree, 

28 ArrayReferenceModel, 

29 InlineArray, 

30 InputArchive, 

31 InvalidParameterError, 

32 OutputArchive, 

33 Unit, 

34) 

35from ._base import BaseField 

36 

37if TYPE_CHECKING: 

38 try: 

39 from lsst.afw.math import BackgroundMI as LegacyBackground 

40 except ImportError: 

41 type LegacyBackground = Any # type: ignore[no-redef] 

42 

43 

44@final 

45class SplineField(BaseField): 

46 """A 2-d Akima spline interpolation of data on a regular grid. 

47 

48 Parameters 

49 ---------- 

50 bounds 

51 The region where this field can be evaluated. 

52 data 

53 The data points to be interpolated. Missing values (indicated by NaN) 

54 are allowed. Will be set to read-only in place. 

55 y 

56 Coordinates for the first dimension of ``data``. Will be set to 

57 read-only in place. 

58 x 

59 Coordinates for the second dimension of ``data``. Will be set to 

60 read-only in place. 

61 unit 

62 Units of the field. 

63 

64 Notes 

65 ----- 

66 This field is much faster to evaluate on a grid via `render` than at 

67 arbitrary points via the function-call operator. 

68 """ 

69 

70 def __init__( 

71 self, 

72 bounds: Bounds, 

73 data: np.ndarray, 

74 *, 

75 y: np.ndarray, 

76 x: np.ndarray, 

77 unit: astropy.units.UnitBase | None = None, 

78 ): 

79 if isinstance(data, astropy.units.Quantity): 

80 if unit is not None: 

81 raise TypeError("If 'data' is a Quantity, 'unit' cannot be provided separately.") 

82 unit = data.unit 

83 data = data.to_value() 

84 if data.ndim != 2: 

85 raise ValueError("'data' must be 2-d.") 

86 if y.ndim != 1: 

87 raise ValueError("'y' must be 1-d.") 

88 if x.ndim != 1: 

89 raise ValueError("'x' must be 1-d.") 

90 if data.shape != y.shape + x.shape: 

91 raise ValueError( 

92 f"Shape of 2-d 'data' {data.shape} does not match " 

93 f"expected 1-d 'y' {y.shape} and/or 'x' {x.shape}." 

94 ) 

95 self._bounds = bounds 

96 self._data = data 

97 self._data.flags.writeable = False 

98 self._x = x 

99 self._x.flags.writeable = False 

100 self._y = y 

101 self._y.flags.writeable = False 

102 self._unit = unit 

103 

104 @property 

105 def bounds(self) -> Bounds: 

106 return self._bounds 

107 

108 @property 

109 def unit(self) -> astropy.units.UnitBase | None: 

110 return self._unit 

111 

112 @property 

113 def data(self) -> np.ndarray: 

114 """The data points to be interpolated (`numpy.ndarray`). 

115 

116 May have missing values indicated by NaNs. 

117 """ 

118 return self._data 

119 

120 @property 

121 def x(self) -> np.ndarray: 

122 """Coordinates for the second dimension of `data` (`numpy.ndarray`).""" 

123 return self._x 

124 

125 @property 

126 def y(self) -> np.ndarray: 

127 """Coordinates for the first dimension of `data` (`numpy.ndarray`).""" 

128 return self._y 

129 

130 @property 

131 def is_constant(self) -> bool: 

132 # We really do want an exact floating-point comparison here. 

133 return (self._data == self._data[0, 0]).all() 

134 

135 def evaluate( 

136 self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False 

137 ) -> np.ndarray | astropy.units.Quantity: 

138 y, x = np.broadcast_arrays(y, x) 

139 xg = self._x 

140 y_render = np.zeros(xg.shape + y.shape, dtype=np.float64) 

141 mask = np.zeros(xg.size, dtype=bool) 

142 for j in range(xg.size): 

143 if (y_interpolator := self._make_y_interpolator(j)) is not None: 

144 y_render[j, ...] = y_interpolator(y) 

145 mask[j] = True 

146 if not np.all(mask): 

147 y_render = y_render[mask, ...] 

148 xg = xg[mask] 

149 result = np.zeros(y.shape, dtype=np.float64) 

150 # There doesn't seem to be a way to avoid looping in Python here; 

151 # maybe someday we'll push this down to a compiled language. 

152 for i, xv in np.ndenumerate(x): 

153 if (x_interpolator := self._make_1d_interpolator(xg, y_render[:, *i])) is None: 

154 raise ValueError("No valid data points.") 

155 v = x_interpolator(xv) 

156 result[*i] = v 

157 if quantity: 

158 return astropy.units.Quantity(result, self._unit) 

159 return result 

160 

161 def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image: 

162 if bbox is None: 

163 bbox = self.bounds.bbox 

164 xg = self._x 

165 y_render = np.zeros((xg.size, bbox.y.size), dtype=dtype) 

166 mask = np.zeros(xg.size, dtype=bool) 

167 for j in range(xg.size): # we have to loop, but only over bins, not evaluation points. 

168 if (y_interpolator := self._make_y_interpolator(j)) is not None: 

169 y_render[j, :] = y_interpolator(bbox.y.arange) 

170 mask[j] = True 

171 if not np.all(mask): 

172 y_render = y_render[mask, :] 

173 xg = xg[mask] 

174 if (x_interpolator := self._make_1d_interpolator(xg, y_render)) is None: 

175 raise ValueError("No valid data points.") 

176 rendered_array = x_interpolator(bbox.x.arange) 

177 return Image(rendered_array.transpose().copy(), bbox=bbox, unit=self._unit, dtype=dtype) 

178 

179 def multiply_constant( 

180 self, factor: float | astropy.units.Quantity | astropy.units.UnitBase 

181 ) -> SplineField: 

182 factor, unit = self._handle_factor_units(factor) 

183 return SplineField(self._bounds, self._data * factor, y=self._y, x=self._x, unit=unit) 

184 

185 def serialize(self, archive: OutputArchive[Any]) -> SplineFieldSerializationModel: 

186 """Serialize the spline field to an output archive.""" 

187 return SplineFieldSerializationModel( 

188 bounds=self.bounds.serialize(), 

189 data=archive.add_array(self._data, name="data"), 

190 y=self._y, 

191 x=self._x, 

192 unit=self._unit, 

193 ) 

194 

195 @staticmethod 

196 def _get_archive_tree_type( 

197 pointer_type: type[Any], 

198 ) -> type[SplineFieldSerializationModel]: 

199 """Return the serialization model type for this object for an archive 

200 type that uses the given pointer type. 

201 """ 

202 return SplineFieldSerializationModel 

203 

204 @staticmethod 

205 def from_legacy_background( 

206 legacy_background: LegacyBackground, 

207 bounds: Bounds | None = None, 

208 unit: astropy.units.UnitBase | None = None, 

209 ) -> SplineField: 

210 """Convert from a legacy `lsst.afw.math.BackgroundMI` instance. 

211 

212 Parameters 

213 ---------- 

214 legacy 

215 Legacy background object to convert. 

216 bounds 

217 The bounds of the returned field, if they should be different from 

218 the bounding box of ``legacy_background``. 

219 unit 

220 The units of the returned field (`lsst.afw.math.Background` 

221 objects do not know their units). 

222 

223 Notes 

224 ----- 

225 `SplineField.render` and the `lsst.afw` background interpolator both 

226 use Akima splines, but with slightly different boundary conditions. 

227 They should produce equivalent to single-precision round-off error 

228 when evaluated within the region enclosed by bin centers (i.e. where 

229 no extrapolation is necessary) and when there are five or more 

230 points to be interpolated in each row and column. 

231 """ 

232 from lsst.afw.math import ApproximateControl, Interpolate 

233 

234 bg_control = legacy_background.getBackgroundControl() 

235 approx_control = bg_control.getApproximateControl() 

236 stats_image = legacy_background.getStatsImage() 

237 if approx_control.getStyle() != ApproximateControl.UNKNOWN: 

238 raise TypeError("Legacy background uses Chebyshev approximation, not splines.") 

239 if bg_control.getInterpStyle() != Interpolate.AKIMA_SPLINE: 

240 raise TypeError("Legacy background does not use Akima spline interpolation.") 

241 x = legacy_background.getBinCentersX() 

242 y = legacy_background.getBinCentersY() 

243 return SplineField( 

244 Box.from_legacy(legacy_background.getImageBBox()) if bounds is None else bounds, 

245 stats_image.image.array, 

246 x=x, 

247 y=y, 

248 unit=unit, 

249 ) 

250 

251 def _make_1d_interpolator(self, loc: np.ndarray, val: np.ndarray) -> Akima1DInterpolator | None: 

252 match len(loc): 

253 case 0: 

254 return None 

255 case 1: 

256 # SciPy can handle only two points by downgrading to linear 

257 # interpolation, but it raises if given only one. Mock up 

258 # two for the nearest-neighbor fallback. 

259 return Akima1DInterpolator(np.array([loc[0], loc[0]]), np.array([val[0], val[0]])) 

260 case _: 

261 return Akima1DInterpolator(loc, val, extrapolate=True) 

262 

263 def _make_y_interpolator(self, j: int) -> Akima1DInterpolator | None: 

264 y = self._y 

265 z = self._data[:, j] 

266 mask = np.isfinite(z) 

267 if not np.all(mask): 

268 y = y[mask] 

269 z = z[mask] 

270 del mask 

271 return self._make_1d_interpolator(y, z) 

272 

273 

class SplineFieldSerializationModel(ArchiveTree):
    """Serialization model for `SplineField`.

    The 2-d ``data`` array is stored as a reference into the archive. The
    ``y`` and ``x`` coordinate arrays are held directly in the model as
    ``InlineArray`` values.
    """

    bounds: SerializableBounds = pydantic.Field(
        description="The region where this field can be evaluated."
    )

    data: ArrayReferenceModel = pydantic.Field(
        description="2-d data to interpolate. NaNs indicate missing values."
    )

    y: InlineArray = pydantic.Field(
        description="Row positions of the data points."
    )

    x: InlineArray = pydantic.Field(
        description="Column positions of the data points."
    )

    unit: Unit | None = pydantic.Field(
        default=None,
        description="Units of the field.",
    )

    # Fixed tag identifying this model's field type.
    field_type: Literal["SPLINE"] = "SPLINE"

    def deserialize(self, archive: InputArchive, **kwargs: Any) -> SplineField:
        """Deserialize the spline field from an input archive.

        Parameters
        ----------
        archive
            Archive that the referenced ``data`` array is read from.
        **kwargs
            Not supported. Passing any keyword argument is an error.

        Returns
        -------
        SplineField
            The reconstructed field. The coordinate arrays held by this model
            are passed through without copying, so the field's constructor
            marks them read-only in place.

        Raises
        ------
        InvalidParameterError
            Raised if any keyword arguments are given.
        """
        if not kwargs:
            field_bounds = self.bounds.deserialize()
            values = archive.get_array(self.data)
            return SplineField(field_bounds, values, y=self.y, x=self.x, unit=self.unit)
        raise InvalidParameterError(f"Unrecognized parameters for SplineField: {set(kwargs)}.")