Coverage for python / lsst / images / _transforms / _transform.py: 29%

176 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 08:52 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

# Public names exported by this module (and by ``from ... import *``).
__all__ = (
    "Transform",
    "TransformCompositionError",
    "TransformSerializationModel",
)

19 

20import textwrap 

21from collections.abc import Iterable 

22from typing import TYPE_CHECKING, Any, TypeVar, final 

23 

24import astropy.io.fits.header 

25import astropy.units as u 

26import numpy as np 

27import pydantic 

28 

29from .._concrete_bounds import SerializableBounds 

30from .._geom import XY, Bounds, Box 

31from ..serialization import ArchiveReadError, ArchiveTree, InputArchive, InvalidParameterError, OutputArchive 

32from . import _ast as astshim 

33from ._frames import Frame, SerializableFrame, SkyFrame 

34 

35if TYPE_CHECKING: 

36 from ._projection import Projection 

37 

38 try: 

39 from lsst.afw.geom import TransformPoint2ToPoint2 as LegacyTransform 

40 except ImportError: 

41 type LegacyTransform = Any # type: ignore[no-redef] 

42 

# These pre-python-3.12 declarations are needed by Sphinx (probably the
# autodoc-typehints plugin).  Within this file, the generic classes and
# functions declare their own type parameters with PEP 695 syntax.
I = TypeVar("I", bound=Frame)  # noqa: E741
O = TypeVar("O", bound=Frame)  # noqa: E741
P = TypeVar("P", bound=pydantic.BaseModel)

48 

49 

class TransformCompositionError(RuntimeError):
    """Raised when composing two transforms is not possible.

    `Transform.then` raises this when the output frame of the first
    transform does not equal the input frame of the second.
    """

52 

53 

54@final 

55class Transform[I: Frame, O: Frame]: 

56 """A transform that maps two coordinate frames. 

57 

58 Notes 

59 ----- 

60 The `Transform` class constructor is considered a private implementation 

61 detail. Instead of using this, various factory methods are available: 

62 

63 - `from_fits_wcs` constructs a transform from a FITS WCS, as represented 

64 `astropy.wcs.WCS`; 

65 - `then` composes two transforms; 

66 - `identity` constructs a trivial transform that does nothing; 

67 - `inverted` returns the inverse of a transform; 

68 - `from_legacy` converts an `lsst.afw.geom.Transform` instance. 

69 

70 When applied to celestial coordinate systems, ``x=ra`` and ``y=dec``. 

71 `Projection` provides a more natural interface for pixel-to-sky transforms. 

72 

73 `Transform` is conceptually immutable (the internal AST Mapping should 

74 never be modified in-place after construction), and hence does not need to 

75 be copied when any object that holds it is copied. 

76 """ 

77 

78 def __init__( 

79 self, 

80 in_frame: I, 

81 out_frame: O, 

82 ast_mapping: astshim.Mapping, 

83 in_bounds: Bounds | None = None, 

84 out_bounds: Bounds | None = None, 

85 components: Iterable[Transform[Any, Any]] = (), 

86 ): 

87 self._in_frame = in_frame 

88 self._out_frame = out_frame 

89 self._ast_mapping = ast_mapping 

90 self._in_bounds = in_bounds or getattr(in_frame, "bbox", None) 

91 self._out_bounds = out_bounds or getattr(out_frame, "bbox", None) 

92 self._components = list(components) 

93 

94 @staticmethod 

95 def from_fits_wcs( 

96 fits_wcs: astropy.wcs.WCS, 

97 in_frame: I, 

98 out_frame: O, 

99 in_bounds: Bounds | None = None, 

100 out_bounds: Bounds | None = None, 

101 x0: int = 0, 

102 y0: int = 0, 

103 ) -> Transform[I, O]: 

104 """Construct a transform from a FITS WCS. 

105 

106 Parameters 

107 ---------- 

108 fits_wcs 

109 FITS WCS to convert. 

110 in_frame 

111 Coordinate frame for input points to the forward transform. 

112 out_frame 

113 Coordinate frame for output points from the forward transform. 

114 in_bounds 

115 The region that bounds valid input points. 

116 out_bounds 

117 The region that bounds valid output points. 

118 x0 

119 Logical coordinate of the first column in the array this WCS 

120 relates to world coordinates. 

121 y0 

122 Logical coordinate of the first column in the array this WCS 

123 relates to world coordinates. 

124 

125 Notes 

126 ----- 

127 The ``x0`` and ``y0`` parameters reflect the fact that for FITS, the 

128 first row and column are always labeled ``(1, 1)``, while in Astropy 

129 and most other Python libraries, they are ``(0, 0)``. The `types` in 

130 this package (e.g. `Image`, `Mask`) allow them to be any pair of 

131 integers. 

132 

133 See Also 

134 -------- 

135 Projection.from_fits_wcs 

136 """ 

137 ast_stream = astshim.StringStream(fits_wcs.to_header_string(relax=True)) 

138 ast_fits_chan = astshim.FitsChan(ast_stream, "Encoding=FITS-WCS, SipReplace=0, IWC=1") 

139 ast_frame_set = ast_fits_chan.read() 

140 _prepend_ast_shift(ast_frame_set, x=x0 - 1.0, y=y0 - 1.0, ast_domain="PIXEL") 

141 return Transform( 

142 in_frame, 

143 out_frame, 

144 ast_frame_set, 

145 in_bounds=in_bounds, 

146 out_bounds=out_bounds, 

147 ) 

148 

149 @staticmethod 

150 def identity(frame: I) -> Transform[I, I]: 

151 """Construct a trivial transform that maps a frame to itelf. 

152 

153 Parameters 

154 ---------- 

155 frame 

156 Frame used for both input and output points. 

157 """ 

158 return Transform(frame, frame, astshim.UnitMap(2)) 

159 

160 @property 

161 def in_frame(self) -> I: 

162 """Coordinate frame for input points.""" 

163 return self._in_frame 

164 

165 @property 

166 def out_frame(self) -> O: 

167 """Coordinate frame for output points.""" 

168 return self._out_frame 

169 

170 @property 

171 def in_bounds(self) -> Bounds | None: 

172 """The region that bounds valid input points (`Bounds` | `None`).""" 

173 return self._in_bounds 

174 

175 @property 

176 def out_bounds(self) -> Bounds | None: 

177 """The region that bounds valid output points (`Bounds` | `None`).""" 

178 return self._out_bounds 

179 

180 def show(self, simplified: bool = False, comments: bool = False) -> str: 

181 """Return the AST native representation of the transform. 

182 

183 Parameters 

184 ---------- 

185 simplified 

186 Whether to ask AST to simplify the mapping before showing it. 

187 This will make it much more likely that two equivalent transforms 

188 have the same `show` result. If the internal mapping is actually 

189 a frame set (as needed to round-trip legacy 

190 `lsst.afw.geom.SkyWcs` objects), this will also just show the 

191 mapping with no frame set information. 

192 comments 

193 Whether to include descriptive comments. 

194 """ 

195 ast_mapping = self._ast_mapping 

196 if simplified: 

197 if isinstance(ast_mapping, astshim.FrameSet): 

198 ast_mapping = ast_mapping.getMapping() 

199 ast_mapping = ast_mapping.simplified() 

200 return ast_mapping.show(comments) 

201 

202 def apply_forward[T: np.ndarray | float](self, *, x: T, y: T) -> XY[T]: 

203 """Apply the forward transform to one or more points. 

204 

205 Parameters 

206 ---------- 

207 x : `numpy.ndarray` | `float` 

208 ``x`` values of the points to transform. 

209 y : `numpy.ndarray` | `float` 

210 ``y`` values of the points to transform. 

211 

212 Returns 

213 ------- 

214 `XY` [`numpy.ndarray` | `float`] 

215 The transformed point or points. 

216 """ 

217 return _standardize_xy( 

218 _ast_apply( 

219 self._ast_mapping.applyForward, 

220 x=self._in_frame.standardize_x(x), 

221 y=self._in_frame.standardize_y(y), 

222 ), 

223 self._out_frame, 

224 ) 

225 

226 def apply_inverse[T: np.ndarray | float](self, *, x: T, y: T) -> XY[T]: 

227 """Apply the inverse transform to one or more points. 

228 

229 Parameters 

230 ---------- 

231 x : `numpy.ndarray` | `float` 

232 ``x`` values of the points to transform. 

233 y : `numpy.ndarray` | `float` 

234 ``y`` values of the points to transform. 

235 

236 Returns 

237 ------- 

238 `XY` [`numpy.ndarray` | `float`] 

239 The transformed point or points. 

240 """ 

241 return _standardize_xy( 

242 _ast_apply( 

243 self._ast_mapping.applyInverse, 

244 x=self._out_frame.standardize_x(x), 

245 y=self._out_frame.standardize_y(y), 

246 ), 

247 self._in_frame, 

248 ) 

249 

250 def apply_forward_q(self, *, x: u.Quantity, y: u.Quantity) -> XY[u.Quantity]: 

251 """Apply the forward transform to one or more unit-aware points. 

252 

253 Parameters 

254 ---------- 

255 x 

256 ``x`` values of the points to transform. 

257 y 

258 ``y`` values of the points to transform. 

259 

260 Returns 

261 ------- 

262 `XY` [`astropy.units.Quantity`] 

263 The transformed point or points. 

264 """ 

265 xy = self.apply_forward(x=x.to_value(self._in_frame.unit), y=y.to_value(self._in_frame.unit)) 

266 return XY(xy.x * self._out_frame.unit, xy.y * self._out_frame.unit) 

267 

268 def apply_inverse_q(self, *, x: u.Quantity, y: u.Quantity) -> XY[u.Quantity]: 

269 """Apply the inverse transform to one or more unit-aware points. 

270 

271 Parameters 

272 ---------- 

273 x 

274 ``x`` values of the points to transform. 

275 y 

276 ``y`` values of the points to transform. 

277 

278 Returns 

279 ------- 

280 `XY` [`astropy.units.Quantity`] 

281 The transformed point or points. 

282 """ 

283 xy = self.apply_inverse(x=x.to_value(self._out_frame.unit), y=y.to_value(self._out_frame.unit)) 

284 return XY(xy.x * self._in_frame.unit, xy.y * self._in_frame.unit) 

285 

286 def decompose(self) -> list[Transform[Any, Any]]: 

287 """Deconstruct a composed transform into its constituent parts. 

288 

289 Notes 

290 ----- 

291 Most transforms will just return a single-element list holding 

292 ``self``. Identity transform will return an empty list, and 

293 transforms composed with `then` will return the original transforms. 

294 Transforms constructed by `FrameSet` may or may not be decomposable. 

295 """ 

296 if not self._components: 

297 if self.in_frame == self._out_frame: 

298 return [] 

299 else: 

300 return [self] 

301 else: 

302 return list(self._components) 

303 

304 def inverted(self) -> Transform[O, I]: 

305 """Return the inverse of this transform.""" 

306 return Transform[O, I]( 

307 self._out_frame, 

308 self._in_frame, 

309 self._ast_mapping.inverted(), 

310 in_bounds=self.out_bounds, 

311 out_bounds=self.in_bounds, 

312 components=[t.inverted() for t in reversed(self._components)], 

313 ) 

314 

315 def then[F: Frame](self, next: Transform[O, F], remember_components: bool = True) -> Transform[I, F]: 

316 """Compose two transforms into another. 

317 

318 Parameters 

319 ---------- 

320 next 

321 Another transform to apply after ``self``. 

322 remember_components 

323 If `True`, the returned composed transform will remember ``self`` 

324 and ``other`` so they can be returned by `decompose`. 

325 """ 

326 if self._out_frame != next._in_frame: 

327 raise TransformCompositionError( 

328 "Cannot compose transforms that do not share a common intermediate frame: " 

329 f"{self._out_frame} != {next._in_frame}." 

330 ) 

331 components = self.decompose() + next.decompose() if remember_components else () 

332 return Transform( 

333 self._in_frame, 

334 next._out_frame, 

335 self._ast_mapping.then(next._ast_mapping), 

336 in_bounds=self.in_bounds, 

337 out_bounds=next.out_bounds, 

338 components=components, 

339 ) 

340 

341 def as_projection(self: Transform[I, SkyFrame]) -> Projection[I]: 

342 """Return a `Projection` view of this transform. 

343 

344 This is only valid when `out_frame` is `~SkyFrame.ICRS`. 

345 """ 

346 from ._projection import Projection 

347 

348 return Projection(self) 

349 

350 def as_fits_wcs(self, bbox: Box) -> astropy.wcs.WCS | None: 

351 """Return a FITS WCS representation of this transform, if possible. 

352 

353 Parameters 

354 ---------- 

355 bbox 

356 Bounding box of the array the FITS WCS will describe. This 

357 transform object is assumed to work on the same coordinate system 

358 in which ``bbox`` is defined, while the FITS WCS will consider the 

359 first row and column in that box to be ``(0, 0)`` (in Astropy 

360 interfaces) or ``(1, 1)`` (in the FITS representation itself). 

361 

362 Notes 

363 ----- 

364 This method assumes the transform maps pixel coordinates to world 

365 coordinates. 

366 

367 Not all transforms can be represented exactly; when a FITS 

368 represention is not possible, `None` is returned. When the returned 

369 WCS is not `None`, it will have the same functional form, but it may 

370 not evaluate identically due to small implementation differences in 

371 the order of floating-point operations. 

372 """ 

373 ast_frame_set = self._get_ast_frame_set() 

374 _prepend_ast_shift(ast_frame_set, x=1.0 - bbox.x.start, y=1.0 - bbox.y.start, ast_domain="GRID") 

375 ast_stream = astshim.StringStream() 

376 ast_fits_chan = astshim.FitsChan( 

377 ast_stream, "Encoding=FITS-WCS, CDMatrix=1, FitsAxisOrder=<copy>, FitsTol=0.0001" 

378 ) 

379 ast_fits_chan.setFitsI("NAXIS1", bbox.x.size) 

380 ast_fits_chan.setFitsI("NAXIS2", bbox.y.size) 

381 n_writes = ast_fits_chan.write(ast_frame_set) 

382 if not n_writes: 

383 return None 

384 header = astropy.io.fits.Header(astropy.io.fits.Card.fromstring(c) for c in ast_fits_chan) 

385 return astropy.wcs.WCS(header) 

386 

387 def serialize[P: pydantic.BaseModel]( 

388 self, archive: OutputArchive[P], *, use_frame_sets: bool = False 

389 ) -> TransformSerializationModel[P]: 

390 """Serialize a transform to an archive. 

391 

392 Parameters 

393 ---------- 

394 archive 

395 Archive to serialize to. 

396 use_frame_sets 

397 If `True`, decompose the transform and try to reference component 

398 mappings that were already serialized into a `FrameSet` in the 

399 archive. Note that if multiple transforms exist between a pair of 

400 frames (e.g. a `Projection` and its FITS approximation), this may 

401 cause the wrong one to be saved. When this option is used, the 

402 frame set must be saved before the transform, and it must be 

403 deserialized before the transform as well. 

404 

405 Returns 

406 ------- 

407 `TransformSerializationModel` 

408 Serialized form of the transform. 

409 """ 

410 model = TransformSerializationModel[P]() 

411 if use_frame_sets: 

412 for link in self.decompose(): 

413 model.frames.append(link.in_frame.serialize()) 

414 model.bounds.append(link.in_bounds.serialize() if link.in_bounds is not None else None) 

415 for frame_set, pointer in archive.iter_frame_sets(): 

416 if link.in_frame in frame_set and link.out_frame in frame_set: 

417 model.mappings.append(pointer) 

418 break 

419 else: 

420 model.mappings.append(MappingSerializationModel(ast=link._ast_mapping.show())) 

421 else: 

422 model.frames.append(self.in_frame.serialize()) 

423 model.bounds.append(self.in_bounds.serialize() if self.in_bounds is not None else None) 

424 model.mappings.append(MappingSerializationModel(ast=self._ast_mapping.show())) 

425 model.frames.append(self.out_frame.serialize()) 

426 model.bounds.append(self.out_bounds.serialize() if self.out_bounds is not None else None) 

427 return model 

428 

429 @staticmethod 

430 def _get_archive_tree_type[P: pydantic.BaseModel]( 

431 pointer_type: type[P], 

432 ) -> type[TransformSerializationModel[P]]: 

433 """Return the serialization model type for this object for an archive 

434 type that uses the given pointer type. 

435 """ 

436 return TransformSerializationModel[pointer_type] # type: ignore 

437 

438 @staticmethod 

439 def from_legacy( 

440 legacy: LegacyTransform, 

441 in_frame: I, 

442 out_frame: O, 

443 in_bounds: Bounds | None = None, 

444 out_bounds: Bounds | None = None, 

445 ) -> Transform[I, O]: 

446 """Construct a transform from a legacy `lsst.afw.geom.Transform`. 

447 

448 Parameters 

449 ---------- 

450 legacy : `lsst.afw.geom.Transform` 

451 Legacy transform object. 

452 in_frame 

453 Coordinate frame for input points to the forward transform. 

454 out_frame 

455 Coordinate frame for output points from the forward transform. 

456 in_bounds 

457 The region that bounds valid input points. 

458 out_bounds 

459 The region that bounds valid output points. 

460 """ 

461 return Transform( 

462 in_frame, 

463 out_frame, 

464 legacy.getMapping(), 

465 in_bounds=in_bounds, 

466 out_bounds=out_bounds, 

467 ) 

468 

469 def to_legacy(self) -> LegacyTransform: 

470 """Convert to a legacy `lsst.afw.geom.TransformPoint2ToPoint2` 

471 instance. 

472 """ 

473 from lsst.afw.geom import TransformPoint2ToPoint2 as LegacyTransform 

474 

475 return LegacyTransform(self._ast_mapping, False) 

476 

477 def _get_ast_frame_set(self) -> Any: 

478 ast_frame_set = astshim.FrameSet(_make_ast_frame(self._in_frame)) 

479 ast_frame_set.addFrame(astshim.FrameSet.BASE, self._ast_mapping, _make_ast_frame(self._out_frame)) 

480 return ast_frame_set 

481 

482 

483def _ast_apply[T: np.ndarray | float](method: Any, *, x: T, y: T) -> XY[T]: 

484 # TODO: add bounds argument and check inputs 

485 # TODO: broadcast arrays with different shapes. 

486 xy_in = np.vstack([x, y]).astype(np.float64) 

487 xy_out = method(xy_in) 

488 return XY(xy_out[0, :], xy_out[1, :]) 

489 

490 

def _prepend_ast_shift(ast_frame_set: Any, x: float, y: float, ast_domain: str) -> None:
    """Give an AST FrameSet a new base frame related by a constant shift.

    The FrameSet is modified in place.  A new two-axis frame with the given
    AST ``Domain`` is attached to the existing base frame through a
    ``ShiftMap`` of ``(x, y)``.  It then becomes the base frame, while the
    current frame is restored to what it was before the call.

    Parameters
    ----------
    ast_frame_set
        AST FrameSet to modify.
    x, y
        Shift offsets for the two axes.
    ast_domain
        Value for the new frame's AST ``Domain`` attribute.
    """
    # Adding a frame makes it the current frame, so remember the original
    # current frame in order to put it back afterwards.
    saved_current = ast_frame_set.current
    shift_map = astshim.ShiftMap([x, y])
    new_frame = astshim.Frame(2, f"Domain={ast_domain}")
    ast_frame_set.addFrame(astshim.FrameSet.BASE, shift_map, new_frame)
    ast_frame_set.base = ast_frame_set.current
    ast_frame_set.current = saved_current

500 

501 

def _make_ast_frame(frame: Frame) -> Any:
    """Build the AST frame that stands in for ``frame`` in an AST FrameSet.

    `SkyFrame.ICRS` becomes an AST ``SkyFrame`` created with no attributes.
    Any other frame becomes a generic two-axis AST ``Frame`` with the
    following properties:

    - it is identified by the frame's ``_ast_ident``;
    - its axes are labeled ``x`` and ``y``;
    - when ``frame.unit`` is not `None`, that unit (in FITS string form) is
      set on both axes.
    """
    if frame is SkyFrame.ICRS:
        return astshim.SkyFrame("")
    result = astshim.Frame(2, f"Ident={frame._ast_ident}")
    if frame.unit is not None:
        unit_string = frame.unit.to_string(format="fits")
        for axis in (1, 2):
            result.setUnit(axis, unit_string)
    for axis, label in ((1, "x"), (2, "y")):
        result.setLabel(axis, label)
    return result

513 

514 

515def _standardize_xy[T: np.ndarray | float](xy: XY[T], frame: Frame) -> XY[T]: 

516 return XY(x=frame.standardize_x(xy.x), y=frame.standardize_y(xy.y)) 

517 

518 

class MappingSerializationModel(pydantic.BaseModel):
    """Pydantic model that stores one AST Mapping as text.

    The text is what the mapping's ``show()`` produces and what
    ``Mapping.fromString`` reads back.
    """

    ast: str = pydantic.Field(
        description="A serialized Starlink AST Mapping, using the AST native encoding.",
    )

523 

524 

525class TransformSerializationModel[P: pydantic.BaseModel](ArchiveTree): 

526 """Serialization model for coordinate transforms.""" 

527 

528 frames: list[SerializableFrame] = pydantic.Field( 

529 default_factory=list, 

530 description=textwrap.dedent( 

531 """ 

532 List of frames that this transform passes through. 

533 

534 All transforms include at least two frames (the endpoints). Others 

535 intermediate frames may be included to facilitate data-sharing 

536 between transforms. 

537 """ 

538 ), 

539 ) 

540 

541 bounds: list[SerializableBounds | None] = pydantic.Field( 

542 default_factory=list, 

543 description=textwrap.dedent( 

544 """ 

545 List of the bounds of the ``frames`` for this transform. 

546 

547 This always has the same number of elements as ``frames``. 

548 """ 

549 ), 

550 ) 

551 

552 mappings: list[P | MappingSerializationModel] = pydantic.Field( 

553 default_factory=list, 

554 description=textwrap.dedent( 

555 """ 

556 The actual mappings between frames, or archive pointers to 

557 serialized FrameSet objects from which they can be obtained. 

558 

559 This always has one fewer element than ``frames``. 

560 """ 

561 ), 

562 ) 

563 

564 def deserialize(self, archive: InputArchive[P], **kwargs: Any) -> Transform[Any, Any]: 

565 """Deserialize a transform from an archive. 

566 

567 Parameters 

568 ---------- 

569 archive 

570 Archive to read from. 

571 **kwargs 

572 Unsupported keyword arguments are accepted only to provide better 

573 error messages (raising `serialization.InvalidParameterError`). 

574 """ 

575 if kwargs: 

576 raise InvalidParameterError(f"Unrecognized parameters for Transform: {set(kwargs.keys())}.") 

577 if len(self.frames) != len(self.bounds): 

578 raise ArchiveReadError( 

579 f"Inconsistent lengths for 'frames' ({len(self.frames)}) and 'bounds' ({len(self.bounds)})." 

580 ) 

581 if len(self.frames) != len(self.mappings) + 1: 

582 raise ArchiveReadError( 

583 f"Inconsistent lengths for 'frames' ({len(self.frames)}) and " 

584 f"'mappings' ({len(self.mappings)}; should be one less)." 

585 ) 

586 # We can't just compose onto an identity Transform if we want to 

587 # preserve the FrameSet-ness of any of these mappings. 

588 transform: Transform | None = None 

589 for n, mapping in enumerate(self.mappings): 

590 match mapping: 

591 case MappingSerializationModel(ast=serialized_mapping): 

592 ast_mapping = astshim.Mapping.fromString(serialized_mapping) 

593 in_bounds = self.bounds[n] 

594 out_bounds = self.bounds[n + 1] 

595 new_transform = Transform( 

596 self.frames[n].deserialize(), 

597 self.frames[n + 1].deserialize(), 

598 ast_mapping, 

599 in_bounds.deserialize() if in_bounds is not None else None, 

600 out_bounds.deserialize() if out_bounds is not None else None, 

601 ) 

602 case reference: 

603 frame_set = archive.get_frame_set(reference) 

604 new_transform = frame_set[self.frames[n].deserialize(), self.frames[n + 1].deserialize()] 

605 if transform is None: 

606 transform = new_transform 

607 else: 

608 transform = transform.then(new_transform) 

609 if transform is None: 

610 transform = Transform.identity(self.frames[0].deserialize()) 

611 return transform