Coverage for python / lsst / afw / geom / testUtils.py: 10%
447 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 00:45 -0700
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 00:45 -0700
1# This file is part of afw.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ["BoxGrid", "makeSipIwcToPixel", "makeSipPixelToIwc"]
24import itertools
25import math
26import pickle
28import astshim as ast
29import numpy as np
30from numpy.testing import assert_allclose, assert_array_equal
31from astshim.test import makeForwardPolyMap, makeTwoWayPolyMap
32from ._geom import getCdMatrixFromMetadata
34import lsst.geom
35import lsst.afw.geom as afwGeom
36from lsst.pex.exceptions import InvalidParameterError
37import lsst.utils
38import lsst.utils.tests
class BoxGrid:
    """Divide a box into nx by ny sub-boxes that tile the region.

    The sub-boxes will be of the same type as `box` and will exactly tile
    `box`; they will also all be the same size, to the extent possible (some
    variation is inevitable for integer boxes that cannot be evenly divided).

    Parameters
    ----------
    box : `lsst.geom.Box2I` or `lsst.geom.Box2D`
        The box to subdivide; the boxes in the grid will be of the same type.
    numColRow : pair of `int`
        Number of columns and rows.

    Raises
    ------
    RuntimeError
        If `numColRow` does not have exactly two elements, or if `box` is
        neither a `lsst.geom.Box2I` nor a `lsst.geom.Box2D`.
    """

    def __init__(self, box, numColRow):
        if len(numColRow) != 2:
            raise RuntimeError(f"numColRow={numColRow!r}; must be a sequence of two integers")
        self._numColRow = tuple(int(val) for val in numColRow)

        # stopDelta is added to the box maximum when the division edges are
        # computed, and subtracted again from each sub-box maximum in
        # __getitem__: 1 for integer boxes, 0 for floating-point boxes.
        if isinstance(box, lsst.geom.Box2I):
            stopDelta = 1
        elif isinstance(box, lsst.geom.Box2D):
            stopDelta = 0
        else:
            raise RuntimeError(f"Unknown class {type(box)} of box {box}")
        self.boxClass = type(box)
        self.stopDelta = stopDelta

        minPoint = box.getMin()
        self.pointClass = type(minPoint)
        # Use the dtype of the box corner so the division edges have the same
        # numeric kind as the box coordinates.
        dtype = np.array(minPoint).dtype

        # One array of numCol + 1 (resp. numRow + 1) edges per axis.
        self._divList = [np.linspace(start=box.getMin()[i],
                                     stop=box.getMax()[i] + self.stopDelta,
                                     num=self._numColRow[i] + 1,
                                     endpoint=True,
                                     dtype=dtype) for i in range(2)]

    @property
    def numColRow(self):
        """Number of columns and rows, as a tuple of two `int`."""
        return self._numColRow

    def __getitem__(self, indXY):
        """Return the box at the specified x,y index.

        Parameters
        ----------
        indXY : pair of `int`
            The x,y index to return.

        Returns
        -------
        subBox : `lsst.geom.Box2I` or `lsst.geom.Box2D`
            The sub-box at column ``indXY[0]`` and row ``indXY[1]``.
        """
        beg = self.pointClass(*[self._divList[i][indXY[i]] for i in range(2)])
        end = self.pointClass(
            *[self._divList[i][indXY[i] + 1] - self.stopDelta for i in range(2)])
        return self.boxClass(beg, end)

    def __len__(self):
        """Return the total number of sub-boxes (columns times rows)."""
        # The original read a non-existent ``self.shape`` attribute and so
        # raised AttributeError; the grid shape is held in ``_numColRow``.
        return self._numColRow[0]*self._numColRow[1]

    def __iter__(self):
        """Return an iterator over all boxes, where column varies most quickly.
        """
        for row in range(self.numColRow[1]):
            for col in range(self.numColRow[0]):
                yield self[col, row]
class FrameSetInfo:
    """Summary of the base and current frames of a FrameSet.

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet to describe.

    Notes
    -----
    **Fields**

    baseInd : `int`
        Index of the base frame (``frameSet.base``).
    currInd : `int`
        Index of the current frame (``frameSet.current``).
    isBaseSkyFrame : `bool`
        True if the class name of the base frame is "SkyFrame".
    isCurrSkyFrame : `bool`
        True if the class name of the current frame is "SkyFrame".
    """
    def __init__(self, frameSet):
        def isSkyFrame(frameIndex):
            # A frame counts as a sky frame purely by its reported class name.
            return frameSet.getFrame(frameIndex).className == "SkyFrame"

        self.baseInd = frameSet.base
        self.currInd = frameSet.current
        self.isBaseSkyFrame = isSkyFrame(self.baseInd)
        self.isCurrSkyFrame = isSkyFrame(self.currInd)
def makeSipPolyMapCoeffs(metadata, name):
    """Build the ast.PolyMap coefficient rows for one SIP matrix.

    The returned rows describe the function::

        f(dxy) = dxy + sipPolynomial(dxy)

    where ``dxy = pixelPosition - pixelOrigin`` and ``sipPolynomial`` has
    terms ``<name>_n_m`` multiplying ``x^n y^m`` (e.g. ``A_2_0`` is the
    coefficient of ``x^2 y^0``).

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with the specified SIP coefficients.
        Must provide ``<name>_ORDER``.
    name : str
        The desired SIP terms: one of A, B, AP, BP.

    Returns
    -------
    list
        Rows of ``[coefficient, outputAxis, xPower, yPower]``. The first row
        is the ``out = in`` term; the remaining rows are the SIP terms found
        in `metadata`, ordered by x power and then y power.

    Raises
    ------
    RuntimeError
        If `name` is not a supported SIP name, or if `metadata` contains no
        coefficient for `name`.

    Note
    ----
    This is an internal function for use by makeSipIwcToPixel
    and makeSipPixelToIwc
    """
    # A and AP feed output axis 1 (x); B and BP feed output axis 2 (y).
    sipOutputAxes = {"A": 1, "B": 2, "AP": 1, "BP": 2}
    if name not in sipOutputAxes:
        raise RuntimeError(f"{name} not a supported SIP name")
    outAxis = sipOutputAxes[name]
    nPowers = metadata.getAsInt(f"{name}_ORDER") + 1

    # Term for out = in: power 1 of the matching input axis.
    identityTerm = [1.0, outAxis, 1, 0] if outAxis == 1 else [1.0, outAxis, 0, 1]

    # SIP distortion terms; keys absent from the metadata are skipped.
    distortionTerms = []
    for xPower in range(nPowers):
        for yPower in range(nPowers):
            key = f"{name}_{xPower}_{yPower}"
            if metadata.exists(key):
                distortionTerms.append(
                    [metadata.getAsDouble(key), outAxis, xPower, yPower])
    if not distortionTerms:
        raise RuntimeError(f"No {name} coefficients found")
    return [identityTerm] + distortionTerms
def makeSipIwcToPixel(metadata):
    """Build an IWC-to-pixel transform with SIP distortion from FITS-WCS
    metadata.

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with inverse SIP coefficients
        (the ``AP`` and ``BP`` terms), ``CRPIX1``/``CRPIX2`` and a CD matrix.

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from IWC position to pixel position (zero-based)
        in the forward direction. The inverse direction is not defined.

    Notes
    -----
    The inverse SIP terms APn_m, BPn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial inverseSipPolynomial, the returned transformation is::

        pixelPosition = pixel origin + uv + inverseSipPolynomial(uv)

    where ``uv = inverseCdMatrix * iwcPosition``.
    """
    # CRPIX minus 1 gives the zero-based pixel origin.
    pixelOrigin = (metadata.getScalar("CRPIX1") - 1, metadata.getScalar("CRPIX2") - 1)
    addPixelOrigin = ast.ShiftMap(pixelOrigin)

    cdMatrixMap = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())

    # x terms (AP) first, then y terms (BP), as a single coefficient table.
    inverseSipTerms = makeSipPolyMapCoeffs(metadata, "AP")
    inverseSipTerms = inverseSipTerms + makeSipPolyMapCoeffs(metadata, "BP")
    inverseSipMap = ast.PolyMap(np.array(inverseSipTerms, dtype=float), 2, "IterInverse=0")

    # Chain: inverse CD matrix -> inverse SIP polynomial -> add pixel origin.
    iwcToPixelMap = cdMatrixMap.inverted()
    for nextMap in (inverseSipMap, addPixelOrigin):
        iwcToPixelMap = iwcToPixelMap.then(nextMap)
    return afwGeom.TransformPoint2ToPoint2(iwcToPixelMap)
def makeSipPixelToIwc(metadata):
    """Build a pixel-to-IWC transform with SIP distortion from FITS-WCS
    metadata.

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with forward SIP coefficients
        (the ``A`` and ``B`` terms), ``CRPIX1``/``CRPIX2`` and a CD matrix.

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from pixel position (zero-based) to IWC position
        in the forward direction. The inverse direction is not defined.

    Notes
    -----
    The forward SIP terms An_m, Bn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial sipPolynomial, the returned transformation is::

        iwcPosition = cdMatrix * (dxy + sipPolynomial(dxy))

    where ``dxy = pixelPosition - pixelOrigin``.
    """
    # CRPIX minus 1 gives the zero-based pixel origin.
    pixelOrigin = (metadata.getScalar("CRPIX1") - 1, metadata.getScalar("CRPIX2") - 1)
    subtractPixelOrigin = ast.ShiftMap(pixelOrigin).inverted()

    cdMatrixMap = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())

    # x terms (A) first, then y terms (B), as a single coefficient table.
    forwardSipTerms = makeSipPolyMapCoeffs(metadata, "A")
    forwardSipTerms = forwardSipTerms + makeSipPolyMapCoeffs(metadata, "B")
    forwardSipMap = ast.PolyMap(np.array(forwardSipTerms, dtype=float), 2, "IterInverse=0")

    # Chain: subtract pixel origin -> SIP polynomial -> CD matrix.
    pixelToIwcMap = subtractPixelOrigin
    for nextMap in (forwardSipMap, cdMatrixMap):
        pixelToIwcMap = pixelToIwcMap.then(nextMap)
    return afwGeom.TransformPoint2ToPoint2(pixelToIwcMap)
class PermutedFrameSet:
    """A copy of a FrameSet whose base and/or current frame axes may have
    been swapped, plus flags recording what was done.

    Only two-axis frames will be permuted.

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet you wish to permute. A copy is made via
        ``frameSet.copy()``; the argument itself is not modified here.
    permuteBase : `bool`
        Permute the base frame's axes?
    permuteCurr : `bool`
        Permute the current frame's axes?

    Raises
    ------
    RuntimeError
        If you try to permute a frame that does not have 2 axes

    Notes
    -----
    **Fields**

    frameSet : `ast.FrameSet`
        The local copy of the FrameSet, permuted as requested.
    isBaseSkyFrame : `bool`
        Is the base frame an `ast.SkyFrame`?
    isCurrSkyFrame : `bool`
        Is the current frame an `ast.SkyFrame`?
    isBasePermuted : `bool`
        Are the base frame axes permuted?
    isCurrPermuted : `bool`
        Are the current frame axes permuted?
    """
    def __init__(self, frameSet, permuteBase, permuteCurr):
        self.frameSet = frameSet.copy()
        info = FrameSetInfo(self.frameSet)
        self.isBaseSkyFrame = info.isBaseSkyFrame
        self.isCurrSkyFrame = info.isCurrSkyFrame

        if permuteBase:
            nBaseAxes = self.frameSet.getFrame(info.baseInd).nAxes
            if nBaseAxes != 2:
                raise RuntimeError(f"Base frame has {nBaseAxes} axes; 2 required to permute")
            # Temporarily make the base frame the current one around the
            # permAxes call, then restore the original current frame
            # (presumably because permAxes acts on the current frame).
            self.frameSet.current = info.baseInd
            self.frameSet.permAxes([2, 1])
            self.frameSet.current = info.currInd

        if permuteCurr:
            nCurrAxes = self.frameSet.getFrame(info.currInd).nAxes
            if nCurrAxes != 2:
                raise RuntimeError(f"Current frame has {nCurrAxes} axes; 2 required to permute")
            # Sanity re-check of the axis count, kept from the original.
            assert self.frameSet.getFrame(info.currInd).nAxes == 2
            self.frameSet.permAxes([2, 1])

        self.isBasePermuted = permuteBase
        self.isCurrPermuted = permuteCurr
329class TransformTestBaseClass(lsst.utils.tests.TestCase):
330 """Base class for unit tests of Transform<X>To<Y>
332 Subclasses must call `TransformTestBaseClass.setUp(self)`
333 if they provide their own version.
334 """
def setUp(self):
    """Initialize the lookup tables shared by the check* helpers.

    Subclasses that override setUp should call this method.
    """
    # Have unittest append the msg argument of asserts to its own error
    # message instead of replacing it.
    self.longMessage = True

    # Endpoint class name prefixes; the full name is prefix + "Endpoint".
    self.endpointPrefixes = ("Generic", "Point2", "SpherePoint")

    # Endpoint prefix -> tuple of valid numbers of axes.
    # Every listed number of axes is valid for GenericEndpoint.
    self.goodNAxes = dict(
        Generic=(1, 2, 3, 4),
        Point2=(2,),
        SpherePoint=(2,),
    )

    # Endpoint prefix -> tuple of invalid numbers of axes
    # (none for GenericEndpoint).
    self.badNAxes = dict(
        Generic=(),
        Point2=(1, 3, 4),
        SpherePoint=(1, 3, 4),
    )

    # Frame index (starting at 1) -> ident of the frames created by
    # makeFrameSet.
    frameIdents = ("baseFrame", "frame2", "frame3", "currFrame")
    self.frameIdentDict = dict(enumerate(frameIdents, start=1))
@staticmethod
def makeRawArrayData(nPoints, nAxes, delta=0.123):
    """Make an array of generic point data.

    The data will be suitable for spherical points.

    Parameters
    ----------
    nPoints : `int`
        Number of points in the array
    nAxes : `int`
        Number of axes in the point
    delta : `float`, optional
        Increment between the values of successive axes of one point.

    Returns
    -------
    np.array of floats with shape (nAxes, nPoints)
        The values are as follows; if nAxes != 2:
        The first point has values `[0, delta, 2*delta, ..., (nAxes-1)*delta]`
        The Nth point has those values + N
        if nAxes == 2 then the data is scaled so that the max value of axis 1
        is a bit less than pi/2
    """
    # NOTE: the original overwrote ``delta`` with 0.123 here, so a
    # caller-supplied value was silently ignored; the argument is now honored.
    # oneAxis = [0, 1, 2, ...nPoints-1]
    oneAxis = np.arange(nPoints, dtype=float)
    # rawData = [oneAxis, oneAxis + delta, oneAxis + 2 delta, ...]
    rawData = np.array([j * delta + oneAxis for j in range(nAxes)], dtype=float)
    # Skip the scaling for an empty array: np.max raises on zero elements.
    if nAxes == 2 and nPoints > 0:
        # scale rawData so that max value of 2nd axis is a bit less than pi/2,
        # thus making the data safe for SpherePoint
        maxLatitude = np.max(rawData[1])
        rawData *= math.pi * 0.4999 / maxLatitude
    return rawData
@staticmethod
def makeRawPointData(nAxes, delta=0.123):
    """Make the raw data for one generic point.

    Parameters
    ----------
    nAxes : `int`
        Number of axes in the point
    delta : `float`
        Increment between axis values

    Returns
    -------
    A list of `nAxes` values `[0, delta, ..., (nAxes-1)*delta]`
    """
    axisValues = []
    for axisIndex in range(nAxes):
        axisValues.append(axisIndex*delta)
    return axisValues
@staticmethod
def makeEndpoint(name, nAxes=None):
    """Construct the endpoint object for an endpoint name prefix.

    Parameters
    ----------
    name : `str`
        Endpoint class name prefix; the full class name is name + "Endpoint"
        and is looked up in `lsst.afw.geom`.
    nAxes : `int` or `None`, optional
        number of axes; an int is required if `name` == "Generic";
        otherwise ignored

    Returns
    -------
    subclass of `lsst.afw.geom.BaseEndpoint`
        The constructed endpoint

    Raises
    ------
    TypeError
        If `name` == "Generic" and `nAxes` is None. Other invalid values of
        `nAxes` are left to the endpoint constructor to reject.
    """
    endpointFactory = getattr(afwGeom, name + "Endpoint")
    if name != "Generic":
        # Only GenericEndpoint takes a number of axes.
        return endpointFactory()
    if nAxes is None:
        raise TypeError("nAxes must be an integer for GenericEndpoint")
    return endpointFactory(nAxes)
@classmethod
def makeGoodFrame(cls, name, nAxes=None):
    """Return the frame produced by the endpoint for `name` and `nAxes`.

    Parameters
    ----------
    name : `str`
        Endpoint class name prefix; the full class name is name + "Endpoint"
    nAxes : `int` or `None`, optional
        number of axes; an int is required if `name` == "Generic";
        otherwise ignored

    Returns
    -------
    `ast.Frame`
        The result of calling ``makeFrame()`` on the endpoint built by
        ``makeEndpoint(name, nAxes)``.

    Raises
    ------
    TypeError
        Propagated from ``makeEndpoint``, e.g. if `name` == "Generic" and
        `nAxes` is `None`.
    """
    endpoint = cls.makeEndpoint(name, nAxes)
    return endpoint.makeFrame()
@staticmethod
def makeBadFrames(name):
    """Return frames that are not a valid match for the named endpoint.

    Parameters
    ----------
    name : `str`
        Endpoint class name prefix; the full class name is name + "Endpoint".
        Must be one of "Generic", "Point2" or "SpherePoint"; any other value
        raises `KeyError`.

    Returns
    -------
    Collection of `ast.Frame`
        A list of 0 or more frames (empty for "Generic").
    """
    badFramesByName = {"Generic": []}
    # Point2 rejects a sky frame and plain frames with 1 or 3 axes.
    badFramesByName["Point2"] = [ast.SkyFrame(), ast.Frame(1), ast.Frame(3)]
    # SpherePoint rejects plain frames with 1, 2 or 3 axes.
    badFramesByName["SpherePoint"] = [ast.Frame(1), ast.Frame(2), ast.Frame(3)]
    return badFramesByName[name]
def makeFrameSet(self, baseFrame, currFrame):
    """Make a FrameSet with four frames joined by three mappings.

    The identity of each frame is provided by self.frameIdentDict::

        Frame         Index   Mapping from this frame to the next
        `baseFrame`   1       `ast.UnitMap(nIn)`
        Frame(nIn)    2       `polyMap`
        Frame(nOut)   3       `ast.UnitMap(nOut)`
        `currFrame`   4

    where:

    - `nIn` = `baseFrame.nAxes`
    - `nOut` = `currFrame.nAxes`
    - `polyMap` = `makeTwoWayPolyMap(nIn, nOut)`

    Parameters
    ----------
    baseFrame : `ast.Frame`
        base frame
    currFrame : `ast.Frame`
        current frame

    Returns
    -------
    `ast.FrameSet`
        The FrameSet as described above
    """
    nIn, nOut = baseFrame.nAxes, currFrame.nAxes
    polyMap = makeTwoWayPolyMap(nIn, nOut)

    # The ident of a frame in a FrameSet must be set before the frame is
    # added, and the inputs should not be modified, so work on copies.
    firstFrame = baseFrame.copy()
    firstFrame.ident = self.frameIdentDict[1]
    lastFrame = currFrame.copy()
    lastFrame.ident = self.frameIdentDict[4]

    frameSet = ast.FrameSet(firstFrame)
    # Intermediate frames: (frame index, number of axes, mapping that
    # connects the frame to the one added before it).
    intermediateSteps = (
        (2, nIn, ast.UnitMap(nIn)),
        (3, nOut, polyMap),
    )
    for frameIndex, nAxes, mapping in intermediateSteps:
        frame = ast.Frame(nAxes)
        frame.ident = self.frameIdentDict[frameIndex]
        frameSet.addFrame(ast.FrameSet.CURRENT, mapping, frame)
    frameSet.addFrame(ast.FrameSet.CURRENT, ast.UnitMap(nOut), lastFrame)
    return frameSet
@staticmethod
def permuteFrameSetIter(frameSet):
    """Iterate over 0 or more copies of a frameSet with SkyFrame axes
    permuted.

    Only base and current SkyFrames are permuted. If neither the base nor
    the current frame is a SkyFrame then nothing is yielded. Otherwise
    every combination is yielded, starting with the unpermuted one, with
    the current-frame choice varying most quickly.

    Returns
    -------
    iterator over `PermutedFrameSet`
    """
    info = FrameSetInfo(frameSet)
    baseIsSky = info.isBaseSkyFrame
    currIsSky = info.isCurrSkyFrame
    if baseIsSky or currIsSky:
        # A frame that is not a SkyFrame is only ever left unpermuted.
        baseChoices = [False, True] if baseIsSky else [False]
        currChoices = [False, True] if currIsSky else [False]
        for permuteBase, permuteCurr in itertools.product(baseChoices, currChoices):
            yield PermutedFrameSet(frameSet, permuteBase, permuteCurr)
@staticmethod
def makeJacobian(nIn, nOut, inPoint):
    """Make the Jacobian matrix expected for the equation described by
    `makeTwoWayPolyMap`.

    Parameters
    ----------
    nIn, nOut : `int`
        the dimensions of the input and output data; see makeTwoWayPolyMap
    inPoint : `numpy.ndarray`
        an array of size `nIn` representing the point at which the Jacobian
        is measured

    Returns
    -------
    J : `numpy.ndarray`
        an `nOut` x `nIn` array of first derivatives
    """
    basePolyMapCoeff = 0.001  # see makeTwoWayPolyMap
    baseCoeff = 2.0 * basePolyMapCoeff
    jacobian = np.empty((nOut, nIn))
    for outAxis, inAxis in itertools.product(range(nOut), range(nIn)):
        # Element [out, in] is baseCoeff*(in + 1 + out) scaled by the
        # matching input coordinate; done in two steps like the original.
        jacobian[outAxis, inAxis] = baseCoeff * (inAxis + 1) + baseCoeff * outAxis
        jacobian[outAxis, inAxis] *= inPoint[inAxis]
    assert jacobian.ndim == 2
    # Avoid spurious errors when comparing to a simplified array
    assert jacobian.shape == (nOut, nIn)
    return jacobian
def checkTransformation(self, transform, mapping, msg=""):
    """Check applyForward and applyInverse of a transform against a mapping.

    Results are compared both to `mapping` and to the mapping returned by
    ``transform.getMapping()``, for a single point and for an array of
    points.

    Parameters
    ----------
    transform : `lsst.afw.geom.Transform`
        The transform to check
    mapping : `ast.Mapping`
        The mapping the transform should use. This mapping
        must contain valid forward or inverse transformations,
        but they need not match if both present. Hence the
        mappings returned by make*PolyMap are acceptable.
    msg : `str`
        Error message suffix describing test parameters
    """
    fromEndpoint = transform.fromEndpoint
    toEndpoint = transform.toEndpoint
    referenceMappings = (mapping, transform.getMapping())

    nIn = mapping.nIn
    nOut = mapping.nOut
    self.assertEqual(nIn, fromEndpoint.nAxes, msg=msg)
    self.assertEqual(nOut, toEndpoint.nAxes, msg=msg)

    # Inputs: one point, and an array of points
    nPoints = 7  # arbitrary
    rawInPoint = self.makeRawPointData(nIn)
    inPoint = fromEndpoint.pointFromData(rawInPoint)
    rawInArray = self.makeRawArrayData(nPoints, nIn)
    inArray = fromEndpoint.arrayFromData(rawInArray)

    if not mapping.hasForward:
        # Outputs are still needed for the inverse check, but they need
        # not be consistent with the inputs
        rawOutPoint = self.makeRawPointData(nOut)
        outPoint = toEndpoint.pointFromData(rawOutPoint)
        rawOutArray = self.makeRawArrayData(nPoints, nOut)
        outArray = toEndpoint.arrayFromData(rawOutArray)

        self.assertFalse(transform.hasForward)
    else:
        self.assertTrue(transform.hasForward)
        # forward transformation of one point
        outPoint = transform.applyForward(inPoint)
        rawOutPoint = toEndpoint.dataFromPoint(outPoint)
        for reference in referenceMappings:
            assert_allclose(rawOutPoint, reference.applyForward(rawInPoint), err_msg=msg)

        # forward transformation of an array of points
        outArray = transform.applyForward(inArray)
        rawOutArray = toEndpoint.dataFromArray(outArray)
        for reference in referenceMappings:
            self.assertFloatsAlmostEqual(rawOutArray, reference.applyForward(rawInArray), msg=msg)

    if not mapping.hasInverse:
        self.assertFalse(transform.hasInverse)
        return

    self.assertTrue(transform.hasInverse)
    # inverse transformation of one point;
    # remember that the inverse need not give the original values
    # (see the description of the `mapping` parameter)
    inversePoint = transform.applyInverse(outPoint)
    rawInversePoint = fromEndpoint.dataFromPoint(inversePoint)
    for reference in referenceMappings:
        assert_allclose(rawInversePoint, reference.applyInverse(rawOutPoint), err_msg=msg)

    # inverse transformation of an array of points;
    # again the inverse need not give the original values
    inverseArray = transform.applyInverse(outArray)
    rawInverseArray = fromEndpoint.dataFromArray(inverseArray)
    for reference in referenceMappings:
        self.assertFloatsAlmostEqual(rawInverseArray, reference.applyInverse(rawOutArray), msg=msg)
def checkInverseTransformation(self, forward, inverse, msg=""):
    """Check that two Transforms are each others' inverses.

    The endpoints and direction flags must be swapped between the two,
    and each direction of `forward` must give exactly the same result as
    the opposite direction of `inverse`, both for the transforms and for
    their underlying mappings.

    Parameters
    ----------
    forward : `lsst.afw.geom.Transform`
        the reference Transform to test
    inverse : `lsst.afw.geom.Transform`
        the transform that should be the inverse of `forward`
    msg : `str`
        error message suffix describing test parameters
    """
    fromEndpoint = forward.fromEndpoint
    toEndpoint = forward.toEndpoint
    forwardMapping = forward.getMapping()
    inverseMapping = inverse.getMapping()

    # properties
    self.assertEqual(forward.fromEndpoint, inverse.toEndpoint, msg=msg)
    self.assertEqual(forward.toEndpoint, inverse.fromEndpoint, msg=msg)
    self.assertEqual(forward.hasForward, inverse.hasInverse, msg=msg)
    self.assertEqual(forward.hasInverse, inverse.hasForward, msg=msg)

    # Single points. Whether the transformation itself is correct is not
    # checked here (see checkTransformation), so the "in" and "out" data
    # need not be related.
    rawInPoint = self.makeRawPointData(fromEndpoint.nAxes)
    inPoint = fromEndpoint.pointFromData(rawInPoint)
    rawOutPoint = self.makeRawPointData(toEndpoint.nAxes)
    outPoint = toEndpoint.pointFromData(rawOutPoint)

    # Arrays of points
    nPoints = 7  # arbitrary
    rawInArray = self.makeRawArrayData(nPoints, fromEndpoint.nAxes)
    inArray = fromEndpoint.arrayFromData(rawInArray)
    rawOutArray = self.makeRawArrayData(nPoints, toEndpoint.nAxes)
    outArray = toEndpoint.arrayFromData(rawOutArray)

    # Each pair is (object from `forward`, matching object from `inverse`):
    # first the transforms themselves, then their mappings.
    if forward.hasForward:
        for (fwd, inv), point in (((forward, inverse), inPoint),
                                  ((forwardMapping, inverseMapping), rawInPoint)):
            self.assertEqual(fwd.applyForward(point), inv.applyInverse(point), msg=msg)
        # Assertions must work with both lists and numpy arrays
        for (fwd, inv), array in (((forward, inverse), inArray),
                                  ((forwardMapping, inverseMapping), rawInArray)):
            assert_array_equal(fwd.applyForward(array), inv.applyInverse(array), err_msg=msg)

    if forward.hasInverse:
        for (fwd, inv), point in (((forward, inverse), outPoint),
                                  ((forwardMapping, inverseMapping), rawOutPoint)):
            self.assertEqual(fwd.applyInverse(point), inv.applyForward(point), msg=msg)
        for (fwd, inv), array in (((forward, inverse), outArray),
                                  ((forwardMapping, inverseMapping), rawOutArray)):
            assert_array_equal(fwd.applyInverse(array), inv.applyForward(array), err_msg=msg)
def checkTransformFromMapping(self, fromName, toName):
    """Check Transform<fromName>To<toName> using the Mapping constructor.

    For every valid combination of input and output axis counts the
    transform is built from a two-way, a forward-only and an inverse-only
    polynomial mapping and checked with `checkTransformation`. For every
    combination with an invalid number of inputs or outputs the
    constructor must raise `InvalidParameterError`.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefix for "from" and "to" endpoints, respectively,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
    """
    transformClassName = "Transform{}To{}".format(fromName, toName)
    TransformClass = getattr(afwGeom, transformClassName)
    baseMsg = "TransformClass={}".format(TransformClass.__name__)

    # check valid numbers of inputs and outputs
    for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                       self.goodNAxes[toName]):
        msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
        polyMap = makeTwoWayPolyMap(nIn, nOut)
        transform = TransformClass(polyMap)

        # desired output from `str(transform)`
        desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut)
        self.assertEqual("{}".format(transform), desStr)
        self.assertEqual(repr(transform), "lsst.afw.geom." + desStr)

        self.checkTransformation(transform, polyMap, msg=msg)

        # Forward transform but no inverse
        polyMap = makeForwardPolyMap(nIn, nOut)
        transform = TransformClass(polyMap)
        self.checkTransformation(transform, polyMap, msg=msg)

        # Inverse transform but no forward
        polyMap = makeForwardPolyMap(nOut, nIn).inverted()
        transform = TransformClass(polyMap)
        self.checkTransformation(transform, polyMap, msg=msg)

    # check invalid # of output against valid # of inputs
    for nIn, badNOut in itertools.product(self.goodNAxes[fromName],
                                          self.badNAxes[toName]):
        badPolyMap = makeTwoWayPolyMap(nIn, badNOut)
        msg = "{}, nIn={}, badNOut={}".format(baseMsg, nIn, badNOut)
        with self.assertRaises(InvalidParameterError, msg=msg):
            TransformClass(badPolyMap)

    # check invalid # of inputs against valid and invalid # of outputs
    for badNIn, nOut in itertools.product(self.badNAxes[fromName],
                                          self.goodNAxes[toName] + self.badNAxes[toName]):
        badPolyMap = makeTwoWayPolyMap(badNIn, nOut)
        # Report the axis count under test; the original formatted the
        # stale ``nIn`` left over from the loops above.
        msg = "{}, badNIn={}, nOut={}".format(baseMsg, badNIn, nOut)
        with self.assertRaises(InvalidParameterError, msg=msg):
            TransformClass(badPolyMap)
def checkTransformFromFrameSet(self, fromName, toName):
    """Check Transform<fromName>To<toName> using the FrameSet constructor.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefix for "from" and "to" endpoints, respectively,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
    """
    transformClassName = f"Transform{fromName}To{toName}"
    TransformClass = getattr(afwGeom, transformClassName)
    baseMsg = f"TransformClass={TransformClass.__name__}"

    def assertRejected(base, curr):
        # The frame set is built outside the assertRaises context so that
        # only the Transform constructor is expected to raise.
        invalidFrameSet = self.makeFrameSet(base, curr)
        with self.assertRaises(InvalidParameterError):
            TransformClass(invalidFrameSet)

    for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                       self.goodNAxes[toName]):
        msg = f"{baseMsg}, nIn={nIn}, nOut={nOut}"

        baseFrame = self.makeGoodFrame(fromName, nIn)
        currFrame = self.makeGoodFrame(toName, nOut)
        frameSet = self.makeFrameSet(baseFrame, currFrame)
        self.assertEqual(frameSet.nFrame, 4)

        # construct 0 or more frame sets that are invalid for this
        # transform class: bad base, bad base and current, bad current
        for badBaseFrame in self.makeBadFrames(fromName):
            assertRejected(badBaseFrame, currFrame)
            for badCurrFrame in self.makeBadFrames(toName):
                assertRejected(badBaseFrame, badCurrFrame)
        for badCurrFrame in self.makeBadFrames(toName):
            assertRejected(baseFrame, badCurrFrame)

        transform = TransformClass(frameSet)

        desStr = f"{transformClassName}[{nIn}->{nOut}]"
        self.assertEqual(f"{transform}", desStr)
        self.assertEqual(repr(transform), "lsst.afw.geom." + desStr)

        self.checkPersistence(transform)

        mappingFromTransform = transform.getMapping()
        transformCopy = TransformClass(mappingFromTransform)
        self.assertEqual(type(transform), type(transformCopy))
        self.assertEqual(transform.getMapping(), mappingFromTransform)

        polyMap = makeTwoWayPolyMap(nIn, nOut)

        self.checkTransformation(transform, mapping=polyMap, msg=msg)

        # If the base and/or current frame of frameSet is a SkyFrame,
        # try permuting that frame (in place, so the connected mappings are
        # correctly updated). The Transform constructor should undo the
        # permutation (via SpherePointEndpoint.normalizeFrame) in its
        # internal copy of frameSet, forcing the axes of the SkyFrame into
        # standard (longitude, latitude) order
        for permutedFS in self.permuteFrameSetIter(frameSet):
            if permutedFS.isBaseSkyFrame:
                permBaseFrame = permutedFS.frameSet.getFrame(ast.FrameSet.BASE)
                # desired base longitude axis
                desBaseLonAxis = 2 if permutedFS.isBasePermuted else 1
                self.assertEqual(permBaseFrame.lonAxis, desBaseLonAxis)
            if permutedFS.isCurrSkyFrame:
                permCurrFrame = permutedFS.frameSet.getFrame(ast.FrameSet.CURRENT)
                # desired current longitude axis
                desCurrLonAxis = 2 if permutedFS.isCurrPermuted else 1
                self.assertEqual(permCurrFrame.lonAxis, desCurrLonAxis)

            permTransform = TransformClass(permutedFS.frameSet)
            self.checkTransformation(permTransform, mapping=polyMap, msg=msg)
def checkInverted(self, fromName, toName):
    """Test Transform<fromName>To<toName>.inverted

    For every valid combination of axis counts, `checkInverseMapping` is
    run on a two-way, a forward-only and an inverse-only polynomial
    mapping, in that order.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefix for "from" and "to" endpoints, respectively,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
    """
    transformClassName = f"Transform{fromName}To{toName}"
    TransformClass = getattr(afwGeom, transformClassName)
    baseMsg = f"TransformClass={TransformClass.__name__}"

    # (label used in the message, factory for the mapping); factories keep
    # each mapping from being built until just before it is checked.
    mappingFactories = (
        ("TwoWay", lambda nIn, nOut: makeTwoWayPolyMap(nIn, nOut)),
        ("Forward", lambda nIn, nOut: makeForwardPolyMap(nIn, nOut)),
        ("Inverse", lambda nIn, nOut: makeForwardPolyMap(nOut, nIn).inverted()),
    )
    for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                       self.goodNAxes[toName]):
        msg = f"{baseMsg}, nIn={nIn}, nOut={nOut}"
        for label, makeMapping in mappingFactories:
            self.checkInverseMapping(
                TransformClass,
                makeMapping(nIn, nOut),
                f"{msg}, Map={label}")
def checkInverseMapping(self, TransformClass, mapping, msg):
    """Test Transform<fromName>To<toName>.inverted for a specific
    mapping.

    The transform built from `mapping` is inverted twice. The first
    inverse is checked against the transform, the second inverse against
    the first, and the second inverse is finally checked against
    `mapping` itself.

    Parameters
    ----------
    TransformClass : `type`
        The class of transform to test, such as TransformPoint2ToPoint2
    mapping : `ast.Mapping`
        The mapping to use for the transform
    msg : `str`
        Error message suffix
    """
    original = TransformClass(mapping)
    onceInverted = original.inverted()
    twiceInverted = onceInverted.inverted()

    for forward, inverse in ((original, onceInverted), (onceInverted, twiceInverted)):
        self.checkInverseTransformation(forward, inverse, msg=msg)
    self.checkTransformation(twiceInverted, mapping, msg=msg)
def checkGetJacobian(self, fromName, toName):
    """Test Transform<fromName>To<toName>.getJacobian

    For every pairing of axis counts listed in ``self.goodNAxes`` for the
    two endpoints, a transform is built from ``makeForwardPolyMap`` and
    its Jacobian at two input points is compared with the matrix returned
    by ``self.makeJacobian`` for the same point.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefix for "from" and "to" endpoints, respectively,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
    """
    TransformClass = getattr(afwGeom, f"Transform{fromName}To{toName}")
    prefix = f"TransformClass={TransformClass.__name__}"

    for nIn, nOut in itertools.product(self.goodNAxes[fromName], self.goodNAxes[toName]):
        msg = f"{prefix}, nIn={nIn}, nOut={nOut}"
        transform = TransformClass(makeForwardPolyMap(nIn, nOut))
        fromEndpoint = transform.fromEndpoint

        # Test multiple points to ensure correct functional form: first
        # makeRawPointData(nIn), then makeRawPointData(nIn, 0.111).
        for extraArgs in ((), (0.111,)):
            rawInPoint = self.makeRawPointData(nIn, *extraArgs)
            inPoint = fromEndpoint.pointFromData(rawInPoint)
            actual = transform.getJacobian(inPoint)
            expected = self.makeJacobian(nIn, nOut, rawInPoint)
            assert_allclose(actual, expected, err_msg=msg)
def checkThen(self, fromName, midName, toName):
    """Test Transform<fromName>To<midName>.then(Transform<midName>To<toName>)

    For each combination of axis counts listed in ``self.goodNAxes`` for
    the three endpoints, the concatenated transform is compared, in both
    directions, against applying the two transforms one after the other.
    The method then checks that ``then`` raises `InvalidParameterError`
    when the two transforms disagree on the number of axes at the join
    (only attempted when ``midName`` is "Generic") and when their endpoint
    types do not match at the join (only attempted when ``fromName`` and
    ``midName`` differ).

    Parameters
    ----------
    fromName : `str`
        the prefix of the starting endpoint (e.g., "Point2" for a
        Point2Endpoint) for the final, concatenated Transform
    midName : `str`
        the prefix for the shared endpoint where two Transforms will be
        concatenated
    toName : `str`
        the prefix of the ending endpoint for the final, concatenated
        Transform
    """
    TransformClass1 = getattr(afwGeom,
                              "Transform{}To{}".format(fromName, midName))
    TransformClass2 = getattr(afwGeom,
                              "Transform{}To{}".format(midName, toName))
    baseMsg = "{}.then({})".format(TransformClass1.__name__,
                                   TransformClass2.__name__)
    for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName],
                                             self.goodNAxes[midName],
                                             self.goodNAxes[toName]):
        msg = "{}, nIn={}, nMid={}, nOut={}".format(
            baseMsg, nIn, nMid, nOut)
        polyMap1 = makeTwoWayPolyMap(nIn, nMid)
        transform1 = TransformClass1(polyMap1)
        polyMap2 = makeTwoWayPolyMap(nMid, nOut)
        transform2 = TransformClass2(polyMap2)
        transform = transform1.then(transform2)

        fromEndpoint = transform1.fromEndpoint
        toEndpoint = transform2.toEndpoint

        # Forward direction: the merged transform must agree with
        # applying transform1 and then transform2.
        inPoint = fromEndpoint.pointFromData(self.makeRawPointData(nIn))
        outPointMerged = transform.applyForward(inPoint)
        outPointSeparate = transform2.applyForward(
            transform1.applyForward(inPoint))
        assert_allclose(toEndpoint.dataFromPoint(outPointMerged),
                        toEndpoint.dataFromPoint(outPointSeparate),
                        err_msg=msg)

        # Inverse direction: the merged transform must agree with
        # undoing transform2 and then undoing transform1.
        outPoint = toEndpoint.pointFromData(self.makeRawPointData(nOut))
        inPointMerged = transform.applyInverse(outPoint)
        inPointSeparate = transform1.applyInverse(
            transform2.applyInverse(outPoint))
        assert_allclose(
            fromEndpoint.dataFromPoint(inPointMerged),
            fromEndpoint.dataFromPoint(inPointSeparate),
            err_msg=msg)

    # Mismatched number of axes should fail
    # NOTE(review): presumably only a "Generic" midpoint can be built with
    # differing axis counts, hence the guard -- confirm.
    if midName == "Generic":
        nIn = self.goodNAxes[fromName][0]
        nOut = self.goodNAxes[toName][0]
        # transform1 is built with 3 as its second axis count while
        # transform2 is built with 2 as its first, so they cannot be joined.
        polyMap = makeTwoWayPolyMap(nIn, 3)
        transform1 = TransformClass1(polyMap)
        polyMap = makeTwoWayPolyMap(2, nOut)
        transform2 = TransformClass2(polyMap)
        with self.assertRaises(InvalidParameterError):
            transform = transform1.then(transform2)

    # Mismatched types of endpoints should fail
    if fromName != midName:
        # Use TransformClass1 for both args to keep test logic simple
        # (transform2 then starts at a fromName endpoint, which differs
        # from the midName endpoint where transform1 ends).
        outName = midName
        # Axis counts that are listed for both endpoints.
        joinNAxes = set(self.goodNAxes[fromName]).intersection(
            self.goodNAxes[outName])

        for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName],
                                                 joinNAxes,
                                                 self.goodNAxes[outName]):
            polyMap = makeTwoWayPolyMap(nIn, nMid)
            transform1 = TransformClass1(polyMap)
            polyMap = makeTwoWayPolyMap(nMid, nOut)
            transform2 = TransformClass1(polyMap)
            with self.assertRaises(InvalidParameterError):
                transform = transform1.then(transform2)
def assertTransformsEqual(self, transform1, transform2):
    """Assert that two transforms are equal

    The transforms must have the same type, equal "from" and "to"
    endpoints and equal mappings. In addition, for each direction that
    the mapping of ``transform1`` reports as available (``hasForward``,
    ``hasInverse``), both transforms are applied to the same array from
    ``self.makeRawArrayData`` and their results are compared with
    `numpy.testing.assert_allclose`.

    Parameters
    ----------
    transform1, transform2
        The two transforms to compare
    """
    self.assertEqual(type(transform1), type(transform2))
    self.assertEqual(transform1.fromEndpoint, transform2.fromEndpoint)
    self.assertEqual(transform1.toEndpoint, transform2.toEndpoint)
    self.assertEqual(transform1.getMapping(), transform2.getMapping())

    fromEndpoint = transform1.fromEndpoint
    toEndpoint = transform1.toEndpoint
    mapping = transform1.getMapping()
    nIn = mapping.nIn
    nOut = mapping.nOut
    numPoints = 7  # arbitrary

    def compareResults(applyFirst, applySecond, srcEndpoint, dstEndpoint, nAxes):
        # Push one shared input array through both transforms and
        # require matching output data.
        rawData = self.makeRawArrayData(numPoints, nAxes)
        srcArray = srcEndpoint.arrayFromData(rawData)
        firstData = dstEndpoint.dataFromArray(applyFirst(srcArray))
        secondData = dstEndpoint.dataFromArray(applySecond(srcArray))
        assert_allclose(firstData, secondData)

    if mapping.hasForward:
        compareResults(transform1.applyForward, transform2.applyForward,
                       fromEndpoint, toEndpoint, nIn)

    if mapping.hasInverse:
        compareResults(transform1.applyInverse, transform2.applyInverse,
                       toEndpoint, fromEndpoint, nOut)
def checkPersistence(self, transform):
    """Check persistence of a transform

    The following are exercised in turn: ``writeString`` and
    ``readString`` (including rejection of a string with the wrong
    version number and of one with the wrong class name),
    ``afwGeom.transformFromString``, pickling, and a ``writeFits`` /
    ``readFits`` round trip through a temporary file. Every
    reconstructed transform is compared with the original using
    `assertTransformsEqual`.

    Parameters
    ----------
    transform
        The transform to persist and restore
    """
    expectedName = type(transform).__name__

    # check writeString and readString
    serialized = transform.writeString()
    # The string is split as: version, class name, everything else.
    versionField, nameField, remainder = serialized.split(" ", 2)
    self.assertEqual(int(versionField), 1)
    self.assertEqual(nameField, expectedName)

    # A wrong version number, then a wrong class name, must each be
    # rejected by readString.
    corruptedStrings = (
        " ".join(["2", nameField, remainder]),
        " ".join(["1", "x" + nameField, remainder]),
    )
    for badStr in corruptedStrings:
        with self.assertRaises(InvalidParameterError):
            transform.readString(badStr)

    restoredFromString = transform.readString(serialized)
    self.assertTransformsEqual(transform, restoredFromString)

    # check transformFromString
    restoredByFactory = afwGeom.transformFromString(serialized)
    self.assertTransformsEqual(transform, restoredByFactory)

    # Check pickling
    unpickled = pickle.loads(pickle.dumps(transform))
    self.assertTransformsEqual(transform, unpickled)

    # Check afw::table::io persistence round-trip
    with lsst.utils.tests.getTempFilePath(".fits") as fitsPath:
        transform.writeFits(fitsPath)
        reloaded = type(transform).readFits(fitsPath)
        self.assertTransformsEqual(transform, reloaded)