Coverage for python / lsst / images / _transforms / _transform.py: 29%
174 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-15 08:44 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-15 08:44 +0000
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14__all__ = (
15 "Transform",
16 "TransformCompositionError",
17 "TransformSerializationModel",
18)
20import textwrap
21from collections.abc import Iterable
22from typing import TYPE_CHECKING, Any, TypeVar, final
24import astropy.io.fits.header
25import astropy.units as u
26import numpy as np
27import pydantic
29from .._concrete_bounds import SerializableBounds
30from .._geom import XY, Bounds, Box
31from ..serialization import ArchiveReadError, ArchiveTree, InputArchive, OutputArchive
32from . import _ast as astshim
33from ._frames import Frame, SerializableFrame, SkyFrame
35if TYPE_CHECKING:
36 from ._projection import Projection
38 try:
39 from lsst.afw.geom import TransformPoint2ToPoint2 as LegacyTransform
40 except ImportError:
41 type LegacyTransform = Any # type: ignore[no-redef]
43# These pre-Python-3.12 declarations are needed by Sphinx (probably by the
44# autodoc-typehints plugin).
45I = TypeVar("I", bound=Frame) # noqa: E741
46O = TypeVar("O", bound=Frame) # noqa: E741
47P = TypeVar("P", bound=pydantic.BaseModel)
class TransformCompositionError(RuntimeError):
    """Error signalling that a pair of transforms cannot be composed."""
54@final
55class Transform[I: Frame, O: Frame]:
56 """A transform that maps two coordinate frames.
58 Notes
59 -----
60 The `Transform` class constructor is considered a private implementation
61 detail. Instead of using this, various factory methods are available:
63 - `from_fits_wcs` constructs a transform from a FITS WCS, as represented
64 `astropy.wcs.WCS`;
65 - `then` composes two transforms;
66 - `identity` constructs a trivial transform that does nothing;
67 - `inverted` returns the inverse of a transform;
68 - `from_legacy` converts an `lsst.afw.geom.Transform` instance.
70 When applied to celestial coordinate systems, ``x=ra`` and ``y=dec``.
71 `Projection` provides a more natural interface for pixel-to-sky transforms.
73 `Transform` is conceptually immutable (the internal AST Mapping should
74 never be modified in-place after construction), and hence does not need to
75 be copied when any object that holds it is copied.
76 """
78 def __init__(
79 self,
80 in_frame: I,
81 out_frame: O,
82 ast_mapping: astshim.Mapping,
83 in_bounds: Bounds | None = None,
84 out_bounds: Bounds | None = None,
85 components: Iterable[Transform[Any, Any]] = (),
86 ):
87 self._in_frame = in_frame
88 self._out_frame = out_frame
89 self._ast_mapping = ast_mapping
90 self._in_bounds = in_bounds or getattr(in_frame, "bbox", None)
91 self._out_bounds = out_bounds or getattr(out_frame, "bbox", None)
92 self._components = list(components)
94 @staticmethod
95 def from_fits_wcs(
96 fits_wcs: astropy.wcs.WCS,
97 in_frame: I,
98 out_frame: O,
99 in_bounds: Bounds | None = None,
100 out_bounds: Bounds | None = None,
101 x0: int = 0,
102 y0: int = 0,
103 ) -> Transform[I, O]:
104 """Construct a transform from a FITS WCS.
106 Parameters
107 ----------
108 fits_wcs
109 FITS WCS to convert.
110 in_frame
111 Coordinate frame for input points to the forward transform.
112 out_frame
113 Coordinate frame for output points from the forward transform.
114 in_bounds
115 The region that bounds valid input points.
116 out_bounds
117 The region that bounds valid output points.
118 x0
119 Logical coordinate of the first column in the array this WCS
120 relates to world coordinates.
121 y0
122 Logical coordinate of the first column in the array this WCS
123 relates to world coordinates.
125 Notes
126 -----
127 The ``x0`` and ``y0`` parameters reflect the fact that for FITS, the
128 first row and column are always labeled ``(1, 1)``, while in Astropy
129 and most other Python libraries, they are ``(0, 0)``. The `types` in
130 this package (e.g. `Image`, `Mask`) allow them to be any pair of
131 integers.
133 See Also
134 --------
135 Projection.from_fits_wcs
136 """
137 ast_stream = astshim.StringStream(fits_wcs.to_header_string(relax=True))
138 ast_fits_chan = astshim.FitsChan(ast_stream, "Encoding=FITS-WCS, SipReplace=0, IWC=1")
139 ast_frame_set = ast_fits_chan.read()
140 _prepend_ast_shift(ast_frame_set, x=x0 - 1.0, y=y0 - 1.0, ast_domain="PIXEL")
141 return Transform(
142 in_frame,
143 out_frame,
144 ast_frame_set,
145 in_bounds=in_bounds,
146 out_bounds=out_bounds,
147 )
149 @staticmethod
150 def identity(frame: I) -> Transform[I, I]:
151 """Construct a trivial transform that maps a frame to itelf.
153 Parameters
154 ----------
155 frame
156 Frame used for both input and output points.
157 """
158 return Transform(frame, frame, astshim.UnitMap(2))
160 @property
161 def in_frame(self) -> I:
162 """Coordinate frame for input points."""
163 return self._in_frame
165 @property
166 def out_frame(self) -> O:
167 """Coordinate frame for output points."""
168 return self._out_frame
170 @property
171 def in_bounds(self) -> Bounds | None:
172 """The region that bounds valid input points (`Bounds` | `None`)."""
173 return self._in_bounds
175 @property
176 def out_bounds(self) -> Bounds | None:
177 """The region that bounds valid output points (`Bounds` | `None`)."""
178 return self._out_bounds
180 def show(self, simplified: bool = False, comments: bool = False) -> str:
181 """Return the AST native representation of the transform.
183 Parameters
184 ----------
185 simplified
186 Whether to ask AST to simplify the mapping before showing it.
187 This will make it much more likely that two equivalent transforms
188 have the same `show` result. If the internal mapping is actually
189 a frame set (as needed to round-trip legacy
190 `lsst.afw.geom.SkyWcs` objects), this will also just show the
191 mapping with no frame set information.
192 comments
193 Whether to include descriptive comments.
194 """
195 ast_mapping = self._ast_mapping
196 if simplified:
197 if isinstance(ast_mapping, astshim.FrameSet):
198 ast_mapping = ast_mapping.getMapping()
199 ast_mapping = ast_mapping.simplified()
200 return ast_mapping.show(comments)
202 def apply_forward[T: np.ndarray | float](self, *, x: T, y: T) -> XY[T]:
203 """Apply the forward transform to one or more points.
205 Parameters
206 ----------
207 x : `numpy.ndarray` | `float`
208 ``x`` values of the points to transform.
209 y : `numpy.ndarray` | `float`
210 ``y`` values of the points to transform.
212 Returns
213 -------
214 `XY` [`numpy.ndarray` | `float`]
215 The transformed point or points.
216 """
217 return _standardize_xy(
218 _ast_apply(
219 self._ast_mapping.applyForward,
220 x=self._in_frame.standardize_x(x),
221 y=self._in_frame.standardize_y(y),
222 ),
223 self._out_frame,
224 )
226 def apply_inverse[T: np.ndarray | float](self, *, x: T, y: T) -> XY[T]:
227 """Apply the inverse transform to one or more points.
229 Parameters
230 ----------
231 x : `numpy.ndarray` | `float`
232 ``x`` values of the points to transform.
233 y : `numpy.ndarray` | `float`
234 ``y`` values of the points to transform.
236 Returns
237 -------
238 `XY` [`numpy.ndarray` | `float`]
239 The transformed point or points.
240 """
241 return _standardize_xy(
242 _ast_apply(
243 self._ast_mapping.applyInverse,
244 x=self._out_frame.standardize_x(x),
245 y=self._out_frame.standardize_y(y),
246 ),
247 self._in_frame,
248 )
250 def apply_forward_q(self, *, x: u.Quantity, y: u.Quantity) -> XY[u.Quantity]:
251 """Apply the forward transform to one or more unit-aware points.
253 Parameters
254 ----------
255 x
256 ``x`` values of the points to transform.
257 y
258 ``y`` values of the points to transform.
260 Returns
261 -------
262 `XY` [`astropy.units.Quantity`]
263 The transformed point or points.
264 """
265 xy = self.apply_forward(x=x.to_value(self._in_frame.unit), y=y.to_value(self._in_frame.unit))
266 return XY(xy.x * self._out_frame.unit, xy.y * self._out_frame.unit)
268 def apply_inverse_q(self, *, x: u.Quantity, y: u.Quantity) -> XY[u.Quantity]:
269 """Apply the inverse transform to one or more unit-aware points.
271 Parameters
272 ----------
273 x
274 ``x`` values of the points to transform.
275 y
276 ``y`` values of the points to transform.
278 Returns
279 -------
280 `XY` [`astropy.units.Quantity`]
281 The transformed point or points.
282 """
283 xy = self.apply_inverse(x=x.to_value(self._out_frame.unit), y=y.to_value(self._out_frame.unit))
284 return XY(xy.x * self._in_frame.unit, xy.y * self._in_frame.unit)
286 def decompose(self) -> list[Transform[Any, Any]]:
287 """Deconstruct a composed transform into its constituent parts.
289 Notes
290 -----
291 Most transforms will just return a single-element list holding
292 ``self``. Identity transform will return an empty list, and
293 transforms composed with `then` will return the original transforms.
294 Transforms constructed by `FrameSet` may or may not be decomposable.
295 """
296 if not self._components:
297 if self.in_frame == self._out_frame:
298 return []
299 else:
300 return [self]
301 else:
302 return list(self._components)
304 def inverted(self) -> Transform[O, I]:
305 """Return the inverse of this transform."""
306 return Transform[O, I](
307 self._out_frame,
308 self._in_frame,
309 self._ast_mapping.inverted(),
310 in_bounds=self.out_bounds,
311 out_bounds=self.in_bounds,
312 components=[t.inverted() for t in reversed(self._components)],
313 )
315 def then[F: Frame](self, next: Transform[O, F], remember_components: bool = True) -> Transform[I, F]:
316 """Compose two transforms into another.
318 Parameters
319 ----------
320 next
321 Another transform to apply after ``self``.
322 remember_components
323 If `True`, the returned composed transform will remember ``self``
324 and ``other`` so they can be returned by `decompose`.
325 """
326 if self._out_frame != next._in_frame:
327 raise TransformCompositionError(
328 "Cannot compose transforms that do not share a common intermediate frame: "
329 f"{self._out_frame} != {next._in_frame}."
330 )
331 components = self.decompose() + next.decompose() if remember_components else ()
332 return Transform(
333 self._in_frame,
334 next._out_frame,
335 self._ast_mapping.then(next._ast_mapping),
336 in_bounds=self.in_bounds,
337 out_bounds=next.out_bounds,
338 components=components,
339 )
341 def as_projection(self: Transform[I, SkyFrame]) -> Projection[I]:
342 """Return a `Projection` view of this transform.
344 This is only valid when `out_frame` is `~SkyFrame.ICRS`.
345 """
346 from ._projection import Projection
348 return Projection(self)
350 def as_fits_wcs(self, bbox: Box) -> astropy.wcs.WCS | None:
351 """Return a FITS WCS representation of this transform, if possible.
353 Parameters
354 ----------
355 bbox
356 Bounding box of the array the FITS WCS will describe. This
357 transform object is assumed to work on the same coordinate system
358 in which ``bbox`` is defined, while the FITS WCS will consider the
359 first row and column in that box to be ``(0, 0)`` (in Astropy
360 interfaces) or ``(1, 1)`` (in the FITS representation itself).
362 Notes
363 -----
364 This method assumes the transform maps pixel coordinates to world
365 coordinates.
367 Not all transforms can be represented exactly; when a FITS
368 represention is not possible, `None` is returned. When the returned
369 WCS is not `None`, it will have the same functional form, but it may
370 not evaluate identically due to small implementation differences in
371 the order of floating-point operations.
372 """
373 ast_frame_set = self._get_ast_frame_set()
374 _prepend_ast_shift(ast_frame_set, x=1.0 - bbox.x.start, y=1.0 - bbox.y.start, ast_domain="GRID")
375 ast_stream = astshim.StringStream()
376 ast_fits_chan = astshim.FitsChan(
377 ast_stream, "Encoding=FITS-WCS, CDMatrix=1, FitsAxisOrder=<copy>, FitsTol=0.0001"
378 )
379 ast_fits_chan.setFitsI("NAXIS1", bbox.x.size)
380 ast_fits_chan.setFitsI("NAXIS2", bbox.y.size)
381 n_writes = ast_fits_chan.write(ast_frame_set)
382 if not n_writes:
383 return None
384 header = astropy.io.fits.Header(astropy.io.fits.Card.fromstring(c) for c in ast_fits_chan)
385 return astropy.wcs.WCS(header)
387 def serialize[P: pydantic.BaseModel](
388 self, archive: OutputArchive[P], *, use_frame_sets: bool = False
389 ) -> TransformSerializationModel[P]:
390 """Serialize a transform to an archive.
392 Parameters
393 ----------
394 archive
395 Archive to serialize to.
396 use_frame_sets
397 If `True`, decompose the transform and try to reference component
398 mappings that were already serialized into a `FrameSet` in the
399 archive. Note that if multiple transforms exist between a pair of
400 frames (e.g. a `Projection` and its FITS approximation), this may
401 cause the wrong one to be saved. When this option is used, the
402 frame set must be saved before the transform, and it must be
403 deserialized before the transform as well.
405 Returns
406 -------
407 `TransformSerializationModel`
408 Serialized form of the transform.
409 """
410 model = TransformSerializationModel[P]()
411 if use_frame_sets:
412 for link in self.decompose():
413 model.frames.append(link.in_frame.serialize())
414 model.bounds.append(link.in_bounds.serialize() if link.in_bounds is not None else None)
415 for frame_set, pointer in archive.iter_frame_sets():
416 if link.in_frame in frame_set and link.out_frame in frame_set:
417 model.mappings.append(pointer)
418 break
419 else:
420 model.mappings.append(MappingSerializationModel(ast=link._ast_mapping.show()))
421 else:
422 model.frames.append(self.in_frame.serialize())
423 model.bounds.append(self.in_bounds.serialize() if self.in_bounds is not None else None)
424 model.mappings.append(MappingSerializationModel(ast=self._ast_mapping.show()))
425 model.frames.append(self.out_frame.serialize())
426 model.bounds.append(self.out_bounds.serialize() if self.out_bounds is not None else None)
427 return model
429 @staticmethod
430 def _get_archive_tree_type[P: pydantic.BaseModel](
431 pointer_type: type[P],
432 ) -> type[TransformSerializationModel[P]]:
433 """Return the serialization model type for this object for an archive
434 type that uses the given pointer type.
435 """
436 return TransformSerializationModel[pointer_type] # type: ignore
438 @staticmethod
439 def from_legacy(
440 legacy: LegacyTransform,
441 in_frame: I,
442 out_frame: O,
443 in_bounds: Bounds | None = None,
444 out_bounds: Bounds | None = None,
445 ) -> Transform[I, O]:
446 """Construct a transform from a legacy `lsst.afw.geom.Transform`.
448 Parameters
449 ----------
450 legacy : `lsst.afw.geom.Transform`
451 Legacy transform object.
452 in_frame
453 Coordinate frame for input points to the forward transform.
454 out_frame
455 Coordinate frame for output points from the forward transform.
456 in_bounds
457 The region that bounds valid input points.
458 out_bounds
459 The region that bounds valid output points.
460 """
461 return Transform(
462 in_frame,
463 out_frame,
464 legacy.getMapping(),
465 in_bounds=in_bounds,
466 out_bounds=out_bounds,
467 )
469 def to_legacy(self) -> LegacyTransform:
470 """Convert to a legacy `lsst.afw.geom.TransformPoint2ToPoint2`
471 instance.
472 """
473 from lsst.afw.geom import TransformPoint2ToPoint2 as LegacyTransform
475 return LegacyTransform(self._ast_mapping, False)
477 def _get_ast_frame_set(self) -> Any:
478 ast_frame_set = astshim.FrameSet(_make_ast_frame(self._in_frame))
479 ast_frame_set.addFrame(astshim.FrameSet.BASE, self._ast_mapping, _make_ast_frame(self._out_frame))
480 return ast_frame_set
483def _ast_apply[T: np.ndarray | float](method: Any, *, x: T, y: T) -> XY[T]:
484 # TODO: add bounds argument and check inputs
485 # TODO: broadcast arrays with different shapes.
486 xy_in = np.vstack([x, y]).astype(np.float64)
487 xy_out = method(xy_in)
488 return XY(xy_out[0, :], xy_out[1, :])
def _prepend_ast_shift(ast_frame_set: Any, x: float, y: float, ast_domain: str) -> None:
    """Add a shifted frame to an AST frame set and make it the base frame.

    The frame set is modified in place: a new two-axis frame with the given
    domain is attached to the existing base frame through a shift of
    ``(x, y)``, the base index is pointed at whatever frame is current after
    the addition, and the original current frame is then restored.

    Parameters
    ----------
    ast_frame_set
        AST frame set to modify.
    x
        Shift along the first axis.
    y
        Shift along the second axis.
    ast_domain
        Value for the ``Domain`` attribute of the new frame.
    """
    original_current = ast_frame_set.current
    shift_map = astshim.ShiftMap([x, y])
    shifted_frame = astshim.Frame(2, f"Domain={ast_domain}")
    ast_frame_set.addFrame(astshim.FrameSet.BASE, shift_map, shifted_frame)
    # Presumably ``addFrame`` makes the new frame current (TODO confirm), so
    # this promotes the new frame to be the base.
    ast_frame_set.base = ast_frame_set.current
    ast_frame_set.current = original_current
def _make_ast_frame(frame: Frame) -> Any:
    """Create the AST frame object that corresponds to ``frame``.

    `SkyFrame.ICRS` becomes an AST sky frame built with empty options.  Any
    other frame becomes a two-axis AST frame identified by the frame's
    ``_ast_ident``, with axes labeled ``x`` and ``y`` and, when the frame has
    a unit, that unit (in FITS string form) set on both axes.
    """
    if frame is SkyFrame.ICRS:
        return astshim.SkyFrame("")
    result = astshim.Frame(2, f"Ident={frame._ast_ident}")
    unit = frame.unit
    if unit is not None:
        unit_name = unit.to_string(format="fits")
        for axis in (1, 2):
            result.setUnit(axis, unit_name)
    for axis, label in enumerate(("x", "y"), start=1):
        result.setLabel(axis, label)
    return result
515def _standardize_xy[T: np.ndarray | float](xy: XY[T], frame: Frame) -> XY[T]:
516 return XY(x=frame.standardize_x(xy.x), y=frame.standardize_y(xy.y))
class MappingSerializationModel(pydantic.BaseModel):
    """Serialization model for an AST Mapping."""

    # Text form of the mapping; `Transform.serialize` fills it from the
    # mapping's ``show()`` output and deserialization parses it back.
    ast: str = pydantic.Field(
        description="A serialized Starlink AST Mapping, using the AST native encoding.",
    )
525class TransformSerializationModel[P: pydantic.BaseModel](ArchiveTree):
526 """Serialization model for coordinate transforms."""
528 frames: list[SerializableFrame] = pydantic.Field(
529 default_factory=list,
530 description=textwrap.dedent(
531 """
532 List of frames that this transform passes through.
534 All transforms include at least two frames (the endpoints). Others
535 intermediate frames may be included to facilitate data-sharing
536 between transforms.
537 """
538 ),
539 )
541 bounds: list[SerializableBounds | None] = pydantic.Field(
542 default_factory=list,
543 description=textwrap.dedent(
544 """
545 List of the bounds of the ``frames`` for this transform.
547 This always has the same number of elements as ``frames``.
548 """
549 ),
550 )
552 mappings: list[P | MappingSerializationModel] = pydantic.Field(
553 default_factory=list,
554 description=textwrap.dedent(
555 """
556 The actual mappings between frames, or archive pointers to
557 serialized FrameSet objects from which they can be obtained.
559 This always has one fewer element than ``frames``.
560 """
561 ),
562 )
564 def deserialize(self, archive: InputArchive[P]) -> Transform[Any, Any]:
565 """Deserialize a transform from an archive.
567 Parameters
568 ----------
569 archive
570 Archive to read from.
571 """
572 if len(self.frames) != len(self.bounds):
573 raise ArchiveReadError(
574 f"Inconsistent lengths for 'frames' ({len(self.frames)}) and 'bounds' ({len(self.bounds)})."
575 )
576 if len(self.frames) != len(self.mappings) + 1:
577 raise ArchiveReadError(
578 f"Inconsistent lengths for 'frames' ({len(self.frames)}) and "
579 f"'mappings' ({len(self.mappings)}; should be one less)."
580 )
581 # We can't just compose onto an identity Transform if we want to
582 # preserve the FrameSet-ness of any of these mappings.
583 transform: Transform | None = None
584 for n, mapping in enumerate(self.mappings):
585 match mapping:
586 case MappingSerializationModel(ast=serialized_mapping):
587 ast_mapping = astshim.Mapping.fromString(serialized_mapping)
588 in_bounds = self.bounds[n]
589 out_bounds = self.bounds[n + 1]
590 new_transform = Transform(
591 self.frames[n].deserialize(),
592 self.frames[n + 1].deserialize(),
593 ast_mapping,
594 in_bounds.deserialize() if in_bounds is not None else None,
595 out_bounds.deserialize() if out_bounds is not None else None,
596 )
597 case reference:
598 frame_set = archive.get_frame_set(reference)
599 new_transform = frame_set[self.frames[n].deserialize(), self.frames[n + 1].deserialize()]
600 if transform is None:
601 transform = new_transform
602 else:
603 transform = transform.then(new_transform)
604 if transform is None:
605 transform = Transform.identity(self.frames[0].deserialize())
606 return transform