lsst.afw g714e0ed6de+196fb0684f
Loading...
Searching...
No Matches
testUtils.py
Go to the documentation of this file.
1# This file is part of afw.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
# Names exported by ``from <this module> import *``
__all__ = [
    "BoxGrid",
    "makeSipIwcToPixel",
    "makeSipPixelToIwc",
]
23
24import itertools
25import math
26import pickle
27
28import astshim as ast
29import numpy as np
30from numpy.testing import assert_allclose, assert_array_equal
31from astshim.test import makeForwardPolyMap, makeTwoWayPolyMap
32from ._geom import getCdMatrixFromMetadata
33
34import lsst.geom
35import lsst.afw.geom as afwGeom
36from lsst.pex.exceptions import InvalidParameterError
37import lsst.utils
39
40
class BoxGrid:
    """Divide a box into nx by ny sub-boxes that tile the region

    The sub-boxes will be of the same type as `box` and will exactly tile
    `box`; they will also all be the same size, to the extent possible (some
    variation is inevitable for integer boxes that cannot be evenly divided).

    Parameters
    ----------
    box : `lsst.geom.Box2I` or `lsst.geom.Box2D`
        the box to subdivide; the boxes in the grid will be of the same type
    numColRow : pair of `int`
        number of columns and rows

    Raises
    ------
    RuntimeError
        If ``numColRow`` does not have exactly two elements, or if ``box``
        is neither a `lsst.geom.Box2I` nor a `lsst.geom.Box2D`.
    """

    def __init__(self, box, numColRow):
        if len(numColRow) != 2:
            raise RuntimeError(f"numColRow={numColRow!r}; must be a sequence of two integers")
        self._numColRow = tuple(int(val) for val in numColRow)

        # For integer boxes the division points run to max + 1 and each
        # sub-box ends one before the next division point (presumably because
        # Box2I's max corner is inclusive); floating-point boxes need no shift.
        if isinstance(box, lsst.geom.Box2I):
            stopDelta = 1
        elif isinstance(box, lsst.geom.Box2D):
            stopDelta = 0
        else:
            raise RuntimeError(f"Unknown class {type(box)} of box {box}")
        self.boxClass = type(box)
        self.stopDelta = stopDelta

        minPoint = box.getMin()
        self.pointClass = type(minPoint)
        dtype = np.array(minPoint).dtype

        # _divList[i] holds numColRow[i] + 1 division points along axis i
        self._divList = [np.linspace(start=box.getMin()[i],
                                     stop=box.getMax()[i] + self.stopDelta,
                                     num=self._numColRow[i] + 1,
                                     endpoint=True,
                                     dtype=dtype) for i in range(2)]

    @property
    def numColRow(self):
        """Number of columns and rows, as a tuple of two `int`."""
        return self._numColRow

    def __getitem__(self, indXY):
        """Return the box at the specified x,y index

        Parameters
        ----------
        indXY : pair of `ints`
            the x,y index to return

        Returns
        -------
        subBox : `lsst.geom.Box2I` or `lsst.geom.Box2D`
        """
        beg = self.pointClass(*[self._divList[i][indXY[i]] for i in range(2)])
        end = self.pointClass(
            *[self._divList[i][indXY[i] + 1] - self.stopDelta for i in range(2)])
        return self.boxClass(beg, end)

    def __len__(self):
        """Return the total number of sub-boxes (columns * rows)."""
        numCol, numRow = self.numColRow
        return numCol*numRow

    def __iter__(self):
        """Return an iterator over all boxes, where column varies most quickly
        """
        for row in range(self.numColRow[1]):
            for col in range(self.numColRow[0]):
                yield self[col, row]
110
111
class FrameSetInfo:
    """Information about a FrameSet

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet about which you want information

    Notes
    -----
    **Fields**

    baseInd : `int`
        Index of base frame
    currInd : `int`
        Index of current frame
    isBaseSkyFrame : `bool`
        Is the base frame an `ast.SkyFrame`?
    isCurrSkyFrame : `bool`
        Is the current frame an `ast.SkyFrame`?
    """
    def __init__(self, frameSet):
        self.baseInd = frameSet.base
        self.currInd = frameSet.current
        # Frames are identified by AST class name rather than isinstance
        self.isBaseSkyFrame = frameSet.getFrame(self.baseInd).className == "SkyFrame"
        self.isCurrSkyFrame = frameSet.getFrame(self.currInd).className == "SkyFrame"
138
139
def makeSipPolyMapCoeffs(metadata, name):
    """Return a list of ast.PolyMap coefficients for the specified SIP matrix

    The returned coefficients describe one output axis of an ast.PolyMap
    that computes::

        f(dxy) = dxy + sipPolynomial(dxy)

    where ``dxy = pixelPosition - pixelOrigin`` and ``sipPolynomial`` is a
    polynomial with terms ``<name>_n_m`` for x^n y^m
    (e.g. ``A_2_0`` is the coefficient for x^2 y^0).

    Parameters
    ----------
    metadata : `lsst.daf.base.PropertySet`
        FITS metadata describing a WCS with the specified SIP coefficients
    name : `str`
        The desired SIP terms: one of A, B, AP, BP

    Returns
    -------
    coeffs : `list` of `list`
        One ``[coefficient, outAxis, xPower, yPower]`` entry per term. The
        first entry is the identity term (out = in); the remaining entries
        are the SIP terms present in ``metadata``, ordered by x power and
        then by y power.

    Raises
    ------
    RuntimeError
        If ``name`` is not a supported SIP name, or if no ``<name>_n_m``
        coefficient with n, m <= ``<name>_ORDER`` exists in ``metadata``.

    Notes
    -----
    This is an internal function for use by makeSipIwcToPixel
    and makeSipPixelToIwc
    """
    # A and AP terms produce output axis 1 (x); B and BP produce axis 2 (y)
    axisByName = {"A": 1, "B": 2, "AP": 1, "BP": 2}
    if name not in axisByName:
        raise RuntimeError(f"{name} not a supported SIP name")
    outAxis = axisByName[name]

    order = metadata.getAsInt(f"{name}_ORDER")
    # Identity term: output axis outAxis starts as a copy of input axis outAxis
    identityTerm = [1.0, outAxis, 1, 0] if outAxis == 1 else [1.0, outAxis, 0, 1]
    coeffs = [identityTerm]

    powers = range(order + 1)
    for xPower, yPower in itertools.product(powers, powers):
        coeffName = f"{name}_{xPower}_{yPower}"
        if metadata.exists(coeffName):
            coeffs.append([metadata.getAsDouble(coeffName), outAxis, xPower, yPower])

    # Only the identity term means no SIP terms were present
    if len(coeffs) == 1:
        raise RuntimeError(f"No {name} coefficients found")
    return coeffs
193
194
def makeSipIwcToPixel(metadata):
    """Make an IWC to pixel transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : `lsst.daf.base.PropertySet`
        FITS metadata describing a WCS with inverse SIP coefficients

    Returns
    -------
    transform : `lsst.afw.geom.TransformPoint2ToPoint2`
        Transform from IWC position to pixel position (zero-based)
        in the forward direction. The inverse direction is not defined.

    Notes
    -----
    The inverse SIP terms APn_m, BPn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial inverseSipPolynomial, the returned transformation is::

        pixelPosition = pixelOrigin + uv + inverseSipPolynomial(uv)
        where uv = inverseCdMatrix * iwcPosition
    """
    # CRPIX is one-based in FITS; the returned pixel positions are zero-based
    crpixZeroBased = tuple(metadata.getScalar(key) - 1 for key in ("CRPIX1", "CRPIX2"))
    uvToPixel = ast.ShiftMap(crpixZeroBased)

    cdMatrix = getCdMatrixFromMetadata(metadata)
    uvToIwc = ast.MatrixMap(cdMatrix.copy())

    inverseSipCoeffs = np.array(
        makeSipPolyMapCoeffs(metadata, "AP") + makeSipPolyMapCoeffs(metadata, "BP"),
        dtype=float,
    )
    # Forward coefficients only, with IterInverse=0, so no inverse is defined
    inverseSip = ast.PolyMap(inverseSipCoeffs, 2, "IterInverse=0")

    iwcToPixelMap = uvToIwc.inverted().then(inverseSip).then(uvToPixel)
    return afwGeom.TransformPoint2ToPoint2(iwcToPixelMap)
232
233
def makeSipPixelToIwc(metadata):
    """Make a pixel to IWC transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : `lsst.daf.base.PropertySet`
        FITS metadata describing a WCS with forward SIP coefficients

    Returns
    -------
    transform : `lsst.afw.geom.TransformPoint2ToPoint2`
        Transform from pixel position (zero-based) to IWC position
        in the forward direction. The inverse direction is not defined.

    Notes
    -----
    The forward SIP terms An_m, Bn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial sipPolynomial, the returned transformation is::

        iwcPosition = cdMatrix * (dxy + sipPolynomial(dxy))
        where dxy = pixelPosition - pixelOrigin
    """
    # CRPIX is one-based in FITS; the input pixel positions are zero-based
    crpixZeroBased = tuple(metadata.getScalar(key) - 1 for key in ("CRPIX1", "CRPIX2"))
    pixelToDxy = ast.ShiftMap(crpixZeroBased).inverted()

    cdMatrix = getCdMatrixFromMetadata(metadata)
    dxyToIwc = ast.MatrixMap(cdMatrix.copy())

    sipCoeffs = np.array(
        makeSipPolyMapCoeffs(metadata, "A") + makeSipPolyMapCoeffs(metadata, "B"),
        dtype=float,
    )
    # Forward coefficients only, with IterInverse=0, so no inverse is defined
    sipMap = ast.PolyMap(sipCoeffs, 2, "IterInverse=0")

    pixelToIwcMap = pixelToDxy.then(sipMap).then(dxyToIwc)
    return afwGeom.TransformPoint2ToPoint2(pixelToIwcMap)
270
271
class PermutedFrameSet:
    """A FrameSet with base or current frame possibly permuted, with associated
    information

    Only two-axis frames will be permuted.

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet you wish to permute. A deep copy is made.
    permuteBase : `bool`
        Permute the base frame's axes?
    permuteCurr : `bool`
        Permute the current frame's axes?

    Raises
    ------
    RuntimeError
        If you try to permute a frame that does not have 2 axes

    Notes
    -----
    **Fields**

    frameSet : `ast.FrameSet`
        The FrameSet that may be permuted. A local copy is made.
    isBaseSkyFrame : `bool`
        Is the base frame an `ast.SkyFrame`?
    isCurrSkyFrame : `bool`
        Is the current frame an `ast.SkyFrame`?
    isBasePermuted : `bool`
        Are the base frame axes permuted?
    isCurrPermuted : `bool`
        Are the current frame axes permuted?
    """
    def __init__(self, frameSet, permuteBase, permuteCurr):
        self.frameSet = frameSet.copy()
        fsInfo = FrameSetInfo(self.frameSet)
        self.isBaseSkyFrame = fsInfo.isBaseSkyFrame
        self.isCurrSkyFrame = fsInfo.isCurrSkyFrame
        if permuteBase:
            baseNAxes = self.frameSet.getFrame(fsInfo.baseInd).nAxes
            if baseNAxes != 2:
                raise RuntimeError(f"Base frame has {baseNAxes} axes; 2 required to permute")
            # permAxes acts on the current frame (the permuteCurr branch below
            # relies on this), so temporarily make the base frame current,
            # permute it, then restore the original current frame.
            self.frameSet.current = fsInfo.baseInd
            self.frameSet.permAxes([2, 1])
            self.frameSet.current = fsInfo.currInd
        if permuteCurr:
            currNAxes = self.frameSet.getFrame(fsInfo.currInd).nAxes
            if currNAxes != 2:
                raise RuntimeError(f"Current frame has {currNAxes} axes; 2 required to permute")
            self.frameSet.permAxes([2, 1])
        self.isBasePermuted = permuteBase
        self.isCurrPermuted = permuteCurr
327
328
329class TransformTestBaseClass(lsst.utils.tests.TestCase):
330 """Base class for unit tests of Transform<X>To<Y>
331
332 Subclasses must call `TransformTestBaseClass.setUp(self)`
333 if they provide their own version.
334 """
335
336 def setUp(self):
337 """Set up a test
338
339 Subclasses should call this method if they override setUp.
340 """
341 # tell unittest to use the msg argument of asserts as a supplement
342 # to the error message, rather than as the whole error message
343 self.longMessage = True
344
345 # list of endpoint class name prefixes; the full name is prefix + "Endpoint"
346 self.endpointPrefixes = ("Generic", "Point2", "SpherePoint")
347
348 # GoodNAxes is dict of endpoint class name prefix:
349 # tuple containing 0 or more valid numbers of axes
350 self.goodNAxes = {
351 "Generic": (1, 2, 3, 4), # all numbers of axes are valid for GenericEndpoint
352 "Point2": (2,),
353 "SpherePoint": (2,),
354 }
355
356 # BadAxes is dict of endpoint class name prefix:
357 # tuple containing 0 or more invalid numbers of axes
358 self.badNAxes = {
359 "Generic": (), # all numbers of axes are valid for GenericEndpoint
360 "Point2": (1, 3, 4),
361 "SpherePoint": (1, 3, 4),
362 }
363
364 # Dict of frame index: identity name for frames created by makeFrameSet
366 1: "baseFrame",
367 2: "frame2",
368 3: "frame3",
369 4: "currFrame",
370 }
371
372 @staticmethod
373 def makeRawArrayData(nPoints, nAxes, delta=0.123):
374 """Make an array of generic point data
375
376 The data will be suitable for spherical points
377
378 Parameters
379 ----------
380 nPoints : `int`
381 Number of points in the array
382 nAxes : `int`
383 Number of axes in the point
384
385 Returns
386 -------
387 np.array of floats with shape (nAxes, nPoints)
388 The values are as follows; if nAxes != 2:
389 The first point has values `[0, delta, 2*delta, ..., (nAxes-1)*delta]`
390 The Nth point has those values + N
391 if nAxes == 2 then the data is scaled so that the max value of axis 1
392 is a bit less than pi/2
393 """
394 delta = 0.123
395 # oneAxis = [0, 1, 2, ...nPoints-1]
396 oneAxis = np.arange(nPoints, dtype=float) # [0, 1, 2...]
397 # rawData = [oneAxis, oneAxis + delta, oneAxis + 2 delta, ...]
398 rawData = np.array([j * delta + oneAxis for j in range(nAxes)], dtype=float)
399 if nAxes == 2:
400 # scale rawData so that max value of 2nd axis is a bit less than pi/2,
401 # thus making the data safe for SpherePoint
402 maxLatitude = np.max(rawData[1])
403 rawData *= math.pi * 0.4999 / maxLatitude
404 return rawData
405
406 @staticmethod
407 def makeRawPointData(nAxes, delta=0.123):
408 """Make one generic point
409
410 Parameters
411 ----------
412 nAxes : `int`
413 Number of axes in the point
414 delta : `float`
415 Increment between axis values
416
417 Returns
418 -------
419 A list of `nAxes` floats with values `[0, delta, ..., (nAxes-1)*delta]
420 """
421 return [i*delta for i in range(nAxes)]
422
423 @staticmethod
424 def makeEndpoint(name, nAxes=None):
425 """Make an endpoint
426
427 Parameters
428 ----------
429 name : `str`
430 Endpoint class name prefix; the full class name is name + "Endpoint"
431 nAxes : `int` or `None`, optional
432 number of axes; an int is required if `name` == "Generic";
433 otherwise ignored
434
435 Returns
436 -------
437 subclass of `lsst.afw.geom.BaseEndpoint`
438 The constructed endpoint
439
440 Raises
441 ------
442 TypeError
443 If `name` == "Generic" and `nAxes` is None or <= 0
444 """
445 EndpointClassName = name + "Endpoint"
446 EndpointClass = getattr(afwGeom, EndpointClassName)
447 if name == "Generic":
448 if nAxes is None:
449 raise TypeError("nAxes must be an integer for GenericEndpoint")
450 return EndpointClass(nAxes)
451 return EndpointClass()
452
453 @classmethod
454 def makeGoodFrame(cls, name, nAxes=None):
455 """Return the appropriate frame for the given name and nAxes
456
457 Parameters
458 ----------
459 name : `str`
460 Endpoint class name prefix; the full class name is name + "Endpoint"
461 nAxes : `int` or `None`, optional
462 number of axes; an int is required if `name` == "Generic";
463 otherwise ignored
464
465 Returns
466 -------
467 `ast.Frame`
468 The constructed frame
469
470 Raises
471 ------
472 TypeError
473 If `name` == "Generic" and `nAxes` is `None` or <= 0
474 """
475 return cls.makeEndpoint(name, nAxes).makeFrame()
476
477 @staticmethod
478 def makeBadFrames(name):
479 """Return a list of 0 or more frames that are not a valid match for the
480 named endpoint
481
482 Parameters
483 ----------
484 name : `str`
485 Endpoint class name prefix; the full class name is name + "Endpoint"
486
487 Returns
488 -------
489 Collection of `ast.Frame`
490 A collection of 0 or more frames
491 """
492 return {
493 "Generic": [],
494 "Point2": [
495 ast.SkyFrame(),
496 ast.Frame(1),
497 ast.Frame(3),
498 ],
499 "SpherePoint": [
500 ast.Frame(1),
501 ast.Frame(2),
502 ast.Frame(3),
503 ],
504 }[name]
505
506 def makeFrameSet(self, baseFrame, currFrame):
507 """Make a FrameSet
508
509 The FrameSet will contain 4 frames and three transforms connecting them.
510 The idenity of each frame is provided by self.frameIdentDict
511
512 Frame Index Mapping from this frame to the next
513 `baseFrame` 1 `ast.UnitMap(nIn)`
514 Frame(nIn) 2 `polyMap`
515 Frame(nOut) 3 `ast.UnitMap(nOut)`
516 `currFrame` 4
517
518 where:
519 - `nIn` = `baseFrame.nAxes`
520 - `nOut` = `currFrame.nAxes`
521 - `polyMap` = `makeTwoWayPolyMap(nIn, nOut)`
522
523 Returns
524 ------
525 `ast.FrameSet`
526 The FrameSet as described above
527
528 Parameters
529 ----------
530 baseFrame : `ast.Frame`
531 base frame
532 currFrame : `ast.Frame`
533 current frame
534 """
535 nIn = baseFrame.nAxes
536 nOut = currFrame.nAxes
537 polyMap = makeTwoWayPolyMap(nIn, nOut)
538
539 # The only way to set the Ident of a frame in a FrameSet is to set it in advance,
540 # and I don't want to modify the inputs, so replace the input frames with copies
541 baseFrame = baseFrame.copy()
542 baseFrame.ident = self.frameIdentDict[1]
543 currFrame = currFrame.copy()
544 currFrame.ident = self.frameIdentDict[4]
545
546 frameSet = ast.FrameSet(baseFrame)
547 frame2 = ast.Frame(nIn)
548 frame2.ident = self.frameIdentDict[2]
549 frameSet.addFrame(ast.FrameSet.CURRENT, ast.UnitMap(nIn), frame2)
550 frame3 = ast.Frame(nOut)
551 frame3.ident = self.frameIdentDict[3]
552 frameSet.addFrame(ast.FrameSet.CURRENT, polyMap, frame3)
553 frameSet.addFrame(ast.FrameSet.CURRENT, ast.UnitMap(nOut), currFrame)
554 return frameSet
555
556 @staticmethod
557 def permuteFrameSetIter(frameSet):
558 """Iterator over 0 or more frameSets with SkyFrames axes permuted
559
560 Only base and current SkyFrames are permuted. If neither the base nor
561 the current frame is a SkyFrame then no frames are returned.
562
563 Returns
564 -------
565 iterator over `PermutedFrameSet`
566 """
567
568 fsInfo = FrameSetInfo(frameSet)
569 if not (fsInfo.isBaseSkyFrame or fsInfo.isCurrSkyFrame):
570 return
571
572 permuteBaseList = [False, True] if fsInfo.isBaseSkyFrame else [False]
573 permuteCurrList = [False, True] if fsInfo.isCurrSkyFrame else [False]
574 for permuteBase in permuteBaseList:
575 for permuteCurr in permuteCurrList:
576 yield PermutedFrameSet(frameSet, permuteBase, permuteCurr)
577
578 @staticmethod
579 def makeJacobian(nIn, nOut, inPoint):
580 """Make a Jacobian matrix for the equation described by
581 `makeTwoWayPolyMap`.
582
583 Parameters
584 ----------
585 nIn, nOut : `int`
586 the dimensions of the input and output data; see makeTwoWayPolyMap
587 inPoint : `numpy.ndarray`
588 an array of size `nIn` representing the point at which the Jacobian
589 is measured
590
591 Returns
592 -------
593 J : `numpy.ndarray`
594 an `nOut` x `nIn` array of first derivatives
595 """
596 basePolyMapCoeff = 0.001 # see makeTwoWayPolyMap
597 baseCoeff = 2.0 * basePolyMapCoeff
598 coeffs = np.empty((nOut, nIn))
599 for iOut in range(nOut):
600 coeffOffset = baseCoeff * iOut
601 for iIn in range(nIn):
602 coeffs[iOut, iIn] = baseCoeff * (iIn + 1) + coeffOffset
603 coeffs[iOut, iIn] *= inPoint[iIn]
604 assert coeffs.ndim == 2
605 # Avoid spurious errors when comparing to a simplified array
606 assert coeffs.shape == (nOut, nIn)
607 return coeffs
608
609 def checkTransformation(self, transform, mapping, msg=""):
610 """Check applyForward and applyInverse for a transform
611
612 Parameters
613 ----------
614 transform : `lsst.afw.geom.Transform`
615 The transform to check
616 mapping : `ast.Mapping`
617 The mapping the transform should use. This mapping
618 must contain valid forward or inverse transformations,
619 but they need not match if both present. Hence the
620 mappings returned by make*PolyMap are acceptable.
621 msg : `str`
622 Error message suffix describing test parameters
623 """
624 fromEndpoint = transform.fromEndpoint
625 toEndpoint = transform.toEndpoint
626 mappingFromTransform = transform.getMapping()
627
628 nIn = mapping.nIn
629 nOut = mapping.nOut
630 self.assertEqual(nIn, fromEndpoint.nAxes, msg=msg)
631 self.assertEqual(nOut, toEndpoint.nAxes, msg=msg)
632
633 # forward transformation of one point
634 rawInPoint = self.makeRawPointData(nIn)
635 inPoint = fromEndpoint.pointFromData(rawInPoint)
636
637 # forward transformation of an array of points
638 nPoints = 7 # arbitrary
639 rawInArray = self.makeRawArrayData(nPoints, nIn)
640 inArray = fromEndpoint.arrayFromData(rawInArray)
641
642 if mapping.hasForward:
643 self.assertTrue(transform.hasForward)
644 outPoint = transform.applyForward(inPoint)
645 rawOutPoint = toEndpoint.dataFromPoint(outPoint)
646 assert_allclose(rawOutPoint, mapping.applyForward(rawInPoint), err_msg=msg)
647 assert_allclose(rawOutPoint, mappingFromTransform.applyForward(rawInPoint), err_msg=msg)
648
649 outArray = transform.applyForward(inArray)
650 rawOutArray = toEndpoint.dataFromArray(outArray)
651 self.assertFloatsAlmostEqual(rawOutArray, mapping.applyForward(rawInArray), msg=msg)
652 self.assertFloatsAlmostEqual(rawOutArray, mappingFromTransform.applyForward(rawInArray), msg=msg)
653 else:
654 # Need outPoint, but don't need it to be consistent with inPoint
655 rawOutPoint = self.makeRawPointData(nOut)
656 outPoint = toEndpoint.pointFromData(rawOutPoint)
657 rawOutArray = self.makeRawArrayData(nPoints, nOut)
658 outArray = toEndpoint.arrayFromData(rawOutArray)
659
660 self.assertFalse(transform.hasForward)
661
662 if mapping.hasInverse:
663 self.assertTrue(transform.hasInverse)
664 # inverse transformation of one point;
665 # remember that the inverse need not give the original values
666 # (see the description of the `mapping` parameter)
667 inversePoint = transform.applyInverse(outPoint)
668 rawInversePoint = fromEndpoint.dataFromPoint(inversePoint)
669 assert_allclose(rawInversePoint, mapping.applyInverse(rawOutPoint), err_msg=msg)
670 assert_allclose(rawInversePoint, mappingFromTransform.applyInverse(rawOutPoint), err_msg=msg)
671
672 # inverse transformation of an array of points;
673 # remember that the inverse will not give the original values
674 # (see the description of the `mapping` parameter)
675 inverseArray = transform.applyInverse(outArray)
676 rawInverseArray = fromEndpoint.dataFromArray(inverseArray)
677 self.assertFloatsAlmostEqual(rawInverseArray, mapping.applyInverse(rawOutArray), msg=msg)
678 self.assertFloatsAlmostEqual(rawInverseArray, mappingFromTransform.applyInverse(rawOutArray),
679 msg=msg)
680 else:
681 self.assertFalse(transform.hasInverse)
682
683 def checkInverseTransformation(self, forward, inverse, msg=""):
684 """Check that two Transforms are each others' inverses.
685
686 Parameters
687 ----------
688 forward : `lsst.afw.geom.Transform`
689 the reference Transform to test
690 inverse : `lsst.afw.geom.Transform`
691 the transform that should be the inverse of `forward`
692 msg : `str`
693 error message suffix describing test parameters
694 """
695 fromEndpoint = forward.fromEndpoint
696 toEndpoint = forward.toEndpoint
697 forwardMapping = forward.getMapping()
698 inverseMapping = inverse.getMapping()
699
700 # properties
701 self.assertEqual(forward.fromEndpoint,
702 inverse.toEndpoint, msg=msg)
703 self.assertEqual(forward.toEndpoint,
704 inverse.fromEndpoint, msg=msg)
705 self.assertEqual(forward.hasForward, inverse.hasInverse, msg=msg)
706 self.assertEqual(forward.hasInverse, inverse.hasForward, msg=msg)
707
708 # transformations of one point
709 # we don't care about whether the transformation itself is correct
710 # (see checkTransformation), so inPoint/outPoint need not be related
711 rawInPoint = self.makeRawPointData(fromEndpoint.nAxes)
712 inPoint = fromEndpoint.pointFromData(rawInPoint)
713 rawOutPoint = self.makeRawPointData(toEndpoint.nAxes)
714 outPoint = toEndpoint.pointFromData(rawOutPoint)
715
716 # transformations of arrays of points
717 nPoints = 7 # arbitrary
718 rawInArray = self.makeRawArrayData(nPoints, fromEndpoint.nAxes)
719 inArray = fromEndpoint.arrayFromData(rawInArray)
720 rawOutArray = self.makeRawArrayData(nPoints, toEndpoint.nAxes)
721 outArray = toEndpoint.arrayFromData(rawOutArray)
722
723 if forward.hasForward:
724 self.assertEqual(forward.applyForward(inPoint),
725 inverse.applyInverse(inPoint), msg=msg)
726 self.assertEqual(forwardMapping.applyForward(rawInPoint),
727 inverseMapping.applyInverse(rawInPoint), msg=msg)
728 # Assertions must work with both lists and numpy arrays
729 assert_array_equal(forward.applyForward(inArray),
730 inverse.applyInverse(inArray),
731 err_msg=msg)
732 assert_array_equal(forwardMapping.applyForward(rawInArray),
733 inverseMapping.applyInverse(rawInArray),
734 err_msg=msg)
735
736 if forward.hasInverse:
737 self.assertEqual(forward.applyInverse(outPoint),
738 inverse.applyForward(outPoint), msg=msg)
739 self.assertEqual(forwardMapping.applyInverse(rawOutPoint),
740 inverseMapping.applyForward(rawOutPoint), msg=msg)
741 assert_array_equal(forward.applyInverse(outArray),
742 inverse.applyForward(outArray),
743 err_msg=msg)
744 assert_array_equal(forwardMapping.applyInverse(rawOutArray),
745 inverseMapping.applyForward(rawOutArray),
746 err_msg=msg)
747
748 def checkTransformFromMapping(self, fromName, toName):
749 """Check Transform_<fromName>_<toName> using the Mapping constructor
750
751 Parameters
752 ----------
753 fromName, toName : `str`
754 Endpoint name prefix for "from" and "to" endpoints, respectively,
755 e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
756 fromAxes, toAxes : `int`
757 number of axes in fromFrame and toFrame, respectively
758 """
759 transformClassName = "Transform{}To{}".format(fromName, toName)
760 TransformClass = getattr(afwGeom, transformClassName)
761 baseMsg = "TransformClass={}".format(TransformClass.__name__)
762
763 # check valid numbers of inputs and outputs
764 for nIn, nOut in itertools.product(self.goodNAxes[fromName],
765 self.goodNAxes[toName]):
766 msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
767 polyMap = makeTwoWayPolyMap(nIn, nOut)
768 transform = TransformClass(polyMap)
769
770 # desired output from `str(transform)`
771 desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut)
772 self.assertEqual("{}".format(transform), desStr)
773 self.assertEqual(repr(transform), "lsst.afw.geom." + desStr)
774
775 self.checkTransformation(transform, polyMap, msg=msg)
776
777 # Forward transform but no inverse
778 polyMap = makeForwardPolyMap(nIn, nOut)
779 transform = TransformClass(polyMap)
780 self.checkTransformation(transform, polyMap, msg=msg)
781
782 # Inverse transform but no forward
783 polyMap = makeForwardPolyMap(nOut, nIn).inverted()
784 transform = TransformClass(polyMap)
785 self.checkTransformation(transform, polyMap, msg=msg)
786
787 # check invalid # of output against valid # of inputs
788 for nIn, badNOut in itertools.product(self.goodNAxes[fromName],
789 self.badNAxes[toName]):
790 badPolyMap = makeTwoWayPolyMap(nIn, badNOut)
791 msg = "{}, nIn={}, badNOut={}".format(baseMsg, nIn, badNOut)
792 with self.assertRaises(InvalidParameterError, msg=msg):
793 TransformClass(badPolyMap)
794
795 # check invalid # of inputs against valid and invalid # of outputs
796 for badNIn, nOut in itertools.product(self.badNAxes[fromName],
797 self.goodNAxes[toName] + self.badNAxes[toName]):
798 badPolyMap = makeTwoWayPolyMap(badNIn, nOut)
799 msg = "{}, badNIn={}, nOut={}".format(baseMsg, nIn, nOut)
800 with self.assertRaises(InvalidParameterError, msg=msg):
801 TransformClass(badPolyMap)
802
803 def checkTransformFromFrameSet(self, fromName, toName):
804 """Check Transform_<fromName>_<toName> using the FrameSet constructor
805
806 Parameters
807 ----------
808 fromName, toName : `str`
809 Endpoint name prefix for "from" and "to" endpoints, respectively,
810 e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
811 """
812 transformClassName = "Transform{}To{}".format(fromName, toName)
813 TransformClass = getattr(afwGeom, transformClassName)
814 baseMsg = "TransformClass={}".format(TransformClass.__name__)
815 for nIn, nOut in itertools.product(self.goodNAxes[fromName],
816 self.goodNAxes[toName]):
817 msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
818
819 baseFrame = self.makeGoodFrame(fromName, nIn)
820 currFrame = self.makeGoodFrame(toName, nOut)
821 frameSet = self.makeFrameSet(baseFrame, currFrame)
822 self.assertEqual(frameSet.nFrame, 4)
823
824 # construct 0 or more frame sets that are invalid for this transform class
825 for badBaseFrame in self.makeBadFrames(fromName):
826 badFrameSet = self.makeFrameSet(badBaseFrame, currFrame)
827 with self.assertRaises(InvalidParameterError):
828 TransformClass(badFrameSet)
829 for badCurrFrame in self.makeBadFrames(toName):
830 reallyBadFrameSet = self.makeFrameSet(badBaseFrame, badCurrFrame)
831 with self.assertRaises(InvalidParameterError):
832 TransformClass(reallyBadFrameSet)
833 for badCurrFrame in self.makeBadFrames(toName):
834 badFrameSet = self.makeFrameSet(baseFrame, badCurrFrame)
835 with self.assertRaises(InvalidParameterError):
836 TransformClass(badFrameSet)
837
838 transform = TransformClass(frameSet)
839
840 desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut)
841 self.assertEqual("{}".format(transform), desStr)
842 self.assertEqual(repr(transform), "lsst.afw.geom." + desStr)
843
844 self.checkPersistence(transform)
845
846 mappingFromTransform = transform.getMapping()
847 transformCopy = TransformClass(mappingFromTransform)
848 self.assertEqual(type(transform), type(transformCopy))
849 self.assertEqual(transform.getMapping(), mappingFromTransform)
850
851 polyMap = makeTwoWayPolyMap(nIn, nOut)
852
853 self.checkTransformation(transform, mapping=polyMap, msg=msg)
854
855 # If the base and/or current frame of frameSet is a SkyFrame,
856 # try permuting that frame (in place, so the connected mappings are
857 # correctly updated). The Transform constructor should undo the permutation,
858 # (via SpherePointEndpoint.normalizeFrame) in its internal copy of frameSet,
859 # forcing the axes of the SkyFrame into standard (longitude, latitude) order
860 for permutedFS in self.permuteFrameSetIter(frameSet):
861 if permutedFS.isBaseSkyFrame:
862 baseFrame = permutedFS.frameSet.getFrame(ast.FrameSet.BASE)
863 # desired base longitude axis
864 desBaseLonAxis = 2 if permutedFS.isBasePermuted else 1
865 self.assertEqual(baseFrame.lonAxis, desBaseLonAxis)
866 if permutedFS.isCurrSkyFrame:
867 currFrame = permutedFS.frameSet.getFrame(ast.FrameSet.CURRENT)
868 # desired current base longitude axis
869 desCurrLonAxis = 2 if permutedFS.isCurrPermuted else 1
870 self.assertEqual(currFrame.lonAxis, desCurrLonAxis)
871
872 permTransform = TransformClass(permutedFS.frameSet)
873 self.checkTransformation(permTransform, mapping=polyMap, msg=msg)
874
875 def checkInverted(self, fromName, toName):
876 """Test Transform<fromName>To<toName>.inverted
877
878 Parameters
879 ----------
880 fromName, toName : `str`
881 Endpoint name prefix for "from" and "to" endpoints, respectively,
882 e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
883 """
884 transformClassName = "Transform{}To{}".format(fromName, toName)
885 TransformClass = getattr(afwGeom, transformClassName)
886 baseMsg = "TransformClass={}".format(TransformClass.__name__)
887 for nIn, nOut in itertools.product(self.goodNAxes[fromName],
888 self.goodNAxes[toName]):
889 msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
891 TransformClass,
892 makeTwoWayPolyMap(nIn, nOut),
893 "{}, Map={}".format(msg, "TwoWay"))
895 TransformClass,
896 makeForwardPolyMap(nIn, nOut),
897 "{}, Map={}".format(msg, "Forward"))
899 TransformClass,
900 makeForwardPolyMap(nOut, nIn).inverted(),
901 "{}, Map={}".format(msg, "Inverse"))
902
def checkInverseMapping(self, TransformClass, mapping, msg):
    """Test Transform<fromName>To<toName>.inverted for a specific
    mapping.

    Build a transform from ``mapping`` and invert it twice. Then:

    - pass (transform, inverse) to `checkInverseTransformation`;
    - pass (inverse, inverse-of-inverse) to `checkInverseTransformation`;
    - pass the inverse-of-inverse and ``mapping`` to `checkTransformation`.

    The last step verifies that inverting twice behaves like the original
    mapping.

    Parameters
    ----------
    TransformClass : `type`
        The class of transform to test, such as TransformPoint2ToPoint2
    mapping : `ast.Mapping`
        The mapping to use for the transform
    msg : `str`
        Error message suffix
    """
    transform = TransformClass(mapping)
    inverse = transform.inverted()
    inverseInverse = inverse.inverted()

    self.checkInverseTransformation(transform, inverse, msg=msg)
    self.checkInverseTransformation(inverse, inverseInverse, msg=msg)
    self.checkTransformation(inverseInverse, mapping, msg=msg)
926
def checkGetJacobian(self, fromName, toName):
    """Verify Transform<fromName>To<toName>.getJacobian

    For every supported (nIn, nOut) pair, wrap a `makeForwardPolyMap`
    mapping in the transform class. Compare ``getJacobian`` against
    ``self.makeJacobian`` at two distinct input points. Checking more
    than one point guards against a Jacobian with the wrong functional
    form that happens to match at a single point.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefix for "from" and "to" endpoints, respectively,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
    """
    TransformClass = getattr(afwGeom,
                             "Transform{}To{}".format(fromName, toName))
    baseMsg = "TransformClass={}".format(TransformClass.__name__)
    axisPairs = itertools.product(self.goodNAxes[fromName],
                                  self.goodNAxes[toName])
    for nIn, nOut in axisPairs:
        msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
        transform = TransformClass(makeForwardPolyMap(nIn, nOut))
        fromEndpoint = transform.fromEndpoint

        # First the default raw point, then one built with delta=0.111.
        for extraArgs in ((), (0.111,)):
            rawInPoint = self.makeRawPointData(nIn, *extraArgs)
            jacobian = transform.getJacobian(
                fromEndpoint.pointFromData(rawInPoint))
            expected = self.makeJacobian(nIn, nOut, rawInPoint)
            assert_allclose(jacobian, expected, err_msg=msg)
958
def checkThen(self, fromName, midName, toName):
    """Test Transform<fromName>To<midName>.then(Transform<midName>To<toName>)

    For every supported combination of axis counts, check two things.
    First, the concatenated transform's forward and inverse results match
    applying the two transforms one after the other. Second,
    concatenation raises `InvalidParameterError` for a mismatched number
    of axes (checked only when ``midName == "Generic"``) and for
    mismatched endpoint types (checked only when
    ``fromName != midName``).

    Parameters
    ----------
    fromName : `str`
        the prefix of the starting endpoint (e.g., "Point2" for a
        Point2Endpoint) for the final, concatenated Transform
    midName : `str`
        the prefix for the shared endpoint where two Transforms will be
        concatenated
    toName : `str`
        the prefix of the ending endpoint for the final, concatenated
        Transform
    """
    TransformClass1 = getattr(afwGeom,
                              "Transform{}To{}".format(fromName, midName))
    TransformClass2 = getattr(afwGeom,
                              "Transform{}To{}".format(midName, toName))
    baseMsg = "{}.then({})".format(TransformClass1.__name__,
                                   TransformClass2.__name__)
    for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName],
                                             self.goodNAxes[midName],
                                             self.goodNAxes[toName]):
        msg = "{}, nIn={}, nMid={}, nOut={}".format(
            baseMsg, nIn, nMid, nOut)
        polyMap1 = makeTwoWayPolyMap(nIn, nMid)
        transform1 = TransformClass1(polyMap1)
        polyMap2 = makeTwoWayPolyMap(nMid, nOut)
        transform2 = TransformClass2(polyMap2)
        transform = transform1.then(transform2)

        fromEndpoint = transform1.fromEndpoint
        toEndpoint = transform2.toEndpoint

        # Forward: merged transform vs. transform1 followed by transform2
        inPoint = fromEndpoint.pointFromData(self.makeRawPointData(nIn))
        outPointMerged = transform.applyForward(inPoint)
        outPointSeparate = transform2.applyForward(
            transform1.applyForward(inPoint))
        assert_allclose(toEndpoint.dataFromPoint(outPointMerged),
                        toEndpoint.dataFromPoint(outPointSeparate),
                        err_msg=msg)

        # Inverse: merged inverse vs. transform2 inverse then transform1's
        outPoint = toEndpoint.pointFromData(self.makeRawPointData(nOut))
        inPointMerged = transform.applyInverse(outPoint)
        inPointSeparate = transform1.applyInverse(
            transform2.applyInverse(outPoint))
        assert_allclose(
            fromEndpoint.dataFromPoint(inPointMerged),
            fromEndpoint.dataFromPoint(inPointSeparate),
            err_msg=msg)

    # Mismatched number of axes should fail. This is only exercised for a
    # Generic shared endpoint: transform1 ends with 3 axes while transform2
    # starts with 2. (Presumably only Generic accepts both axis counts.)
    if midName == "Generic":
        nIn = self.goodNAxes[fromName][0]
        nOut = self.goodNAxes[toName][0]
        transform1 = TransformClass1(makeTwoWayPolyMap(nIn, 3))
        transform2 = TransformClass2(makeTwoWayPolyMap(2, nOut))
        with self.assertRaises(
                InvalidParameterError,
                msg="{}: mismatched number of axes".format(baseMsg)):
            transform1.then(transform2)

    # Mismatched types of endpoints should fail
    if fromName != midName:
        # Use TransformClass1 for both args to keep test logic simple.
        # transform1 ends at a midName endpoint, but transform2 starts at a
        # fromName endpoint. nMid must therefore be valid for both
        # endpoint types.
        joinNAxes = set(self.goodNAxes[fromName]).intersection(
            self.goodNAxes[midName])

        for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName],
                                                 joinNAxes,
                                                 self.goodNAxes[midName]):
            transform1 = TransformClass1(makeTwoWayPolyMap(nIn, nMid))
            transform2 = TransformClass1(makeTwoWayPolyMap(nMid, nOut))
            mismatchMsg = ("{}: mismatched endpoint types, "
                           "nIn={}, nMid={}, nOut={}").format(
                               baseMsg, nIn, nMid, nOut)
            with self.assertRaises(InvalidParameterError, msg=mismatchMsg):
                transform1.then(transform2)
1038
def assertTransformsEqual(self, transform1, transform2):
    """Assert that two transforms are equal

    First check the type, both endpoints and the mapping for equality.
    Then, for each direction the mapping supports (``hasForward`` /
    ``hasInverse``), apply both transforms to the same 7 raw test points
    and require the results to agree. The endpoints of ``transform1`` are
    used for all data conversions.
    """
    self.assertEqual(type(transform1), type(transform2))
    self.assertEqual(transform1.fromEndpoint, transform2.fromEndpoint)
    self.assertEqual(transform1.toEndpoint, transform2.toEndpoint)
    self.assertEqual(transform1.getMapping(), transform2.getMapping())

    srcEndpoint = transform1.fromEndpoint
    dstEndpoint = transform1.toEndpoint
    theMapping = transform1.getMapping()
    numPoints = 7  # arbitrary

    def compareDirection(apply1, apply2, startEndpoint, endEndpoint, nAxes):
        # Feed identical input to both transforms; compare the raw output.
        startArray = startEndpoint.arrayFromData(
            self.makeRawArrayData(numPoints, nAxes))
        expected = endEndpoint.dataFromArray(apply1(startArray))
        actual = endEndpoint.dataFromArray(apply2(startArray))
        assert_allclose(expected, actual)

    if theMapping.hasForward:
        compareDirection(transform1.applyForward, transform2.applyForward,
                         srcEndpoint, dstEndpoint, theMapping.nIn)

    if theMapping.hasInverse:
        compareDirection(transform1.applyInverse, transform2.applyInverse,
                         dstEndpoint, srcEndpoint, theMapping.nOut)
1071
def checkPersistence(self, transform):
    """Check persistence of a transform

    Exercise each persistence path and require the result to equal
    ``transform`` (via `assertTransformsEqual`):

    - ``writeString``/``readString``. The serialized string must start
      with version ``1`` and the class name. A wrong version or class
      name must raise `InvalidParameterError`.
    - ``lsst.afw.geom.transformFromString``.
    - pickling.
    - a FITS round trip through ``writeFits``/``readFits`` on a
      temporary file.

    Parameters
    ----------
    transform : `lsst.afw.geom.Transform`-like
        The transform to round-trip (any ``Transform<From>To<To>``
        instance).
    """
    className = type(transform).__name__

    # check writeString and readString.
    # The serialized form is "<version> <className> <rest>".
    transformStr = transform.writeString()
    serialVersion, serialClassName, serialRest = transformStr.split(" ", 2)
    self.assertEqual(int(serialVersion), 1)
    self.assertEqual(serialClassName, className)
    badStr1 = " ".join(["2", serialClassName, serialRest])
    with self.assertRaises(InvalidParameterError):
        transform.readString(badStr1)
    badClassName = "x" + serialClassName
    badStr2 = " ".join(["1", badClassName, serialRest])
    with self.assertRaises(InvalidParameterError):
        transform.readString(badStr2)
    transformFromStr1 = transform.readString(transformStr)
    self.assertTransformsEqual(transform, transformFromStr1)

    # check transformFromString
    transformFromStr2 = afwGeom.transformFromString(transformStr)
    self.assertTransformsEqual(transform, transformFromStr2)

    # Check pickling
    self.assertTransformsEqual(transform, pickle.loads(pickle.dumps(transform)))

    # Check afw::table::io persistence round-trip
    with lsst.utils.tests.getTempFilePath(".fits") as filename:
        transform.writeFits(filename)
        self.assertTransformsEqual(transform, type(transform).readFits(filename))
__init__(self, box, numColRow)
Definition testUtils.py:56
__init__(self, frameSet, permuteBase, permuteCurr)
Definition testUtils.py:307
checkTransformFromFrameSet(self, fromName, toName)
Definition testUtils.py:803
checkTransformation(self, transform, mapping, msg="")
Definition testUtils.py:609
checkInverseMapping(self, TransformClass, mapping, msg)
Definition testUtils.py:903
makeRawArrayData(nPoints, nAxes, delta=0.123)
Definition testUtils.py:373
assertTransformsEqual(self, transform1, transform2)
Definition testUtils.py:1039
checkThen(self, fromName, midName, toName)
Definition testUtils.py:959
checkTransformFromMapping(self, fromName, toName)
Definition testUtils.py:748
checkInverseTransformation(self, forward, inverse, msg="")
Definition testUtils.py:683
makeSipPolyMapCoeffs(metadata, name)
Definition testUtils.py:140
Eigen::Matrix2d getCdMatrixFromMetadata(daf::base::PropertySet &metadata)
Read a CD matrix from FITS WCS metadata.
Definition wcsUtils.cc:75