Coverage for python / lsst / afw / geom / testUtils.py: 10%

447 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 02:01 -0700

1# This file is part of afw. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["BoxGrid", "makeSipIwcToPixel", "makeSipPixelToIwc"] 

23 

24import itertools 

25import math 

26import pickle 

27 

28import astshim as ast 

29import numpy as np 

30from numpy.testing import assert_allclose, assert_array_equal 

31from astshim.test import makeForwardPolyMap, makeTwoWayPolyMap 

32from ._geom import getCdMatrixFromMetadata 

33 

34import lsst.geom 

35import lsst.afw.geom as afwGeom 

36from lsst.pex.exceptions import InvalidParameterError 

37import lsst.utils 

38import lsst.utils.tests 

39 

40 

class BoxGrid:
    """Divide a box into nx by ny sub-boxes that tile the region

    The sub-boxes will be of the same type as `box` and will exactly tile `box`;
    they will also all be the same size, to the extent possible (some variation
    is inevitable for integer boxes that cannot be evenly divided).

    Parameters
    ----------
    box : `lsst.geom.Box2I` or `lsst.geom.Box2D`
        the box to subdivide; the boxes in the grid will be of the same type
    numColRow : pair of `int`
        number of columns and rows

    Raises
    ------
    RuntimeError
        If ``numColRow`` does not have exactly two elements, or if ``box`` is
        neither a `lsst.geom.Box2I` nor a `lsst.geom.Box2D`.
    """

    def __init__(self, box, numColRow):
        if len(numColRow) != 2:
            raise RuntimeError(f"numColRow={numColRow!r}; must be a sequence of two integers")
        self._numColRow = tuple(int(val) for val in numColRow)

        if isinstance(box, lsst.geom.Box2I):
            # For integer boxes the division points run from min to max + 1,
            # and each sub-box ends 1 before the next division point
            # (see __getitem__), so adjacent sub-boxes do not overlap.
            stopDelta = 1
        elif isinstance(box, lsst.geom.Box2D):
            stopDelta = 0
        else:
            raise RuntimeError(f"Unknown class {type(box)} of box {box}")
        self.boxClass = type(box)
        self.stopDelta = stopDelta

        minPoint = box.getMin()
        self.pointClass = type(minPoint)
        # Division points use the dtype of the box's points
        # (presumably integer for Box2I, floating point for Box2D)
        dtype = np.array(minPoint).dtype

        # _divList[i] holds the numColRow[i] + 1 division points along axis i
        self._divList = [np.linspace(start=box.getMin()[i],
                                     stop=box.getMax()[i] + self.stopDelta,
                                     num=self._numColRow[i] + 1,
                                     endpoint=True,
                                     dtype=dtype) for i in range(2)]

    @property
    def numColRow(self):
        """Number of columns and rows, as a tuple of two `int`."""
        return self._numColRow

    def _normalizeIndex(self, ind, axis):
        """Return ``ind`` as a non-negative index along ``axis``.

        Negative values count back from the end, as for Python sequences.

        Raises
        ------
        IndexError
            If ``ind`` is out of range for ``axis``.
        """
        num = self._numColRow[axis]
        normInd = ind + num if ind < 0 else ind
        if not 0 <= normInd < num:
            raise IndexError(f"index {ind} out of range for axis {axis} with {num} sub-boxes")
        return normInd

    def __getitem__(self, indXY):
        """Return the box at the specified x,y index

        Parameters
        ----------
        indXY : pair of `ints`
            the x,y index to return; negative values count from the end

        Returns
        -------
        subBox : `lsst.geom.Box2I` or `lsst.geom.Box2D`

        Raises
        ------
        IndexError
            If either index is out of range.
        """
        ind = [self._normalizeIndex(indXY[i], i) for i in range(2)]
        beg = self.pointClass(*[self._divList[i][ind[i]] for i in range(2)])
        end = self.pointClass(
            *[self._divList[i][ind[i] + 1] - self.stopDelta for i in range(2)])
        return self.boxClass(beg, end)

    def __len__(self):
        """Return the total number of sub-boxes (columns * rows)."""
        return self._numColRow[0] * self._numColRow[1]

    def __iter__(self):
        """Return an iterator over all boxes, where column varies most quickly
        """
        for row in range(self.numColRow[1]):
            for col in range(self.numColRow[0]):
                yield self[col, row]

110 

111 

class FrameSetInfo:
    """Summary of the base and current frames of an `ast.FrameSet`

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet to summarize; it is only read, never modified.

    Notes
    -----
    **Fields**

    baseInd : `int`
        Index of the base frame
    currInd : `int`
        Index of the current frame
    isBaseSkyFrame : `bool`
        True if the base frame's className is "SkyFrame"
    isCurrSkyFrame : `bool`
        True if the current frame's className is "SkyFrame"
    """
    def __init__(self, frameSet):
        baseInd = frameSet.base
        self.baseInd = baseInd
        currInd = frameSet.current
        self.currInd = currInd
        # Query the base frame first, then the current frame
        self.isBaseSkyFrame, self.isCurrSkyFrame = (
            frameSet.getFrame(frameInd).className == "SkyFrame"
            for frameInd in (baseInd, currInd)
        )

138 

139 

def makeSipPolyMapCoeffs(metadata, name):
    """Return a list of ast.PolyMap coefficients for the specified SIP matrix

    The returned coefficients describe an ast.PolyMap that computes

        f(dxy) = dxy + sipPolynomial(dxy)

    where dxy = pixelPosition - pixelOrigin and sipPolynomial is a polynomial
    with terms `<name>_n_m` for x^n y^m (e.g. A_2_0 is the coefficient for
    x^2 y^0).

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with the specified SIP coefficients;
        must contain ``<name>_ORDER``.
    name : str
        The desired SIP terms: one of A, B, AP, BP

    Returns
    -------
    list
        Entries of the form ``[coefficient, outAxis, xPower, yPower]``.
        The first entry is the identity term (out = in); it is followed by
        every ``<name>_<xPower>_<yPower>`` found in ``metadata`` with both
        powers in ``[0, order]``, ordered by xPower then yPower.

    Raises
    ------
    RuntimeError
        If ``name`` is not a supported SIP name, or if no coefficients for
        ``name`` are present in ``metadata``.

    Notes
    -----
    This is an internal function for use by makeSipIwcToPixel
    and makeSipPixelToIwc
    """
    axisForName = {"A": 1, "B": 2, "AP": 1, "BP": 2}
    if name not in axisForName:
        raise RuntimeError(f"{name} not a supported SIP name")
    outAxis = axisForName[name]
    nPowers = metadata.getAsInt(f"{name}_ORDER") + 1

    # The identity term makes the PolyMap compute out = in + polynomial(in)
    identityTerm = [1.0, outAxis, 1, 0] if outAxis == 1 else [1.0, outAxis, 0, 1]

    distortionTerms = []
    for xPower, yPower in itertools.product(range(nPowers), repeat=2):
        key = f"{name}_{xPower}_{yPower}"
        if metadata.exists(key):
            distortionTerms.append([metadata.getAsDouble(key), outAxis, xPower, yPower])

    if not distortionTerms:
        raise RuntimeError(f"No {name} coefficients found")
    return [identityTerm] + distortionTerms

193 

194 

def makeSipIwcToPixel(metadata):
    """Make an IWC to pixel transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with inverse SIP coefficients
        (``AP_ORDER``, ``BP_ORDER`` and ``AP_n_m``/``BP_n_m`` terms), plus
        ``CRPIX1``, ``CRPIX2`` and a CD matrix readable by
        ``getCdMatrixFromMetadata``.

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from IWC position to pixel position (zero-based)
        in the forward direction. The inverse direction is not defined.

    Raises
    ------
    RuntimeError
        If no AP or no BP coefficients are present (raised by
        ``makeSipPolyMapCoeffs``).

    Notes
    -----
    The inverse SIP terms APn_m, BPn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial inverseSipPolynomial, the returned transformation is:

        pixelPosition = pixelOrigin + uv + inverseSipPolynomial(uv)
        where uv = inverseCdMatrix * iwcPosition
    """
    # CRPIX is 1-based; subtracting 1 gives the zero-based pixel origin
    zeroBasedCrpix = tuple(metadata.getScalar(f"CRPIX{axis}") - 1 for axis in (1, 2))
    relativeToAbsolute = ast.ShiftMap(zeroBasedCrpix)
    cdMap = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())
    inverseSipCoeffs = np.array(
        makeSipPolyMapCoeffs(metadata, "AP") + makeSipPolyMapCoeffs(metadata, "BP"),
        dtype=float,
    )
    inverseSipMap = ast.PolyMap(inverseSipCoeffs, 2, "IterInverse=0")

    # IWC --(CD^-1)--> uv --(uv + inverse SIP)--> relative pixel --(shift)--> pixel
    return afwGeom.TransformPoint2ToPoint2(
        cdMap.inverted().then(inverseSipMap).then(relativeToAbsolute)
    )

232 

233 

def makeSipPixelToIwc(metadata):
    """Make a pixel to IWC transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with forward SIP coefficients
        (``A_ORDER``, ``B_ORDER`` and ``A_n_m``/``B_n_m`` terms), plus
        ``CRPIX1``, ``CRPIX2`` and a CD matrix readable by
        ``getCdMatrixFromMetadata``.

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from pixel position (zero-based) to IWC position
        in the forward direction. The inverse direction is not defined.

    Raises
    ------
    RuntimeError
        If no A or no B coefficients are present (raised by
        ``makeSipPolyMapCoeffs``).

    Notes
    -----
    The forward SIP terms An_m, Bn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial sipPolynomial, the returned transformation is:

        iwcPosition = cdMatrix * (dxy + sipPolynomial(dxy))
        where dxy = pixelPosition - pixelOrigin
    """
    # CRPIX is 1-based; subtracting 1 gives the zero-based pixel origin
    zeroBasedCrpix = tuple(metadata.getScalar(f"CRPIX{axis}") - 1 for axis in (1, 2))
    absoluteToRelative = ast.ShiftMap(zeroBasedCrpix).inverted()
    cdMap = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())
    sipCoeffs = np.array(
        makeSipPolyMapCoeffs(metadata, "A") + makeSipPolyMapCoeffs(metadata, "B"),
        dtype=float,
    )
    sipMap = ast.PolyMap(sipCoeffs, 2, "IterInverse=0")

    # pixel --(shift)--> dxy --(dxy + SIP)--> distorted dxy --(CD)--> IWC
    return afwGeom.TransformPoint2ToPoint2(absoluteToRelative.then(sipMap).then(cdMap))

270 

271 

class PermutedFrameSet:
    """A copy of a FrameSet whose base and/or current frame axes may be
    swapped, together with information about what was done

    Only two-axis frames can be permuted.

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet to permute. It is copied (via ``frameSet.copy()``);
        the original is not modified.
    permuteBase : `bool`
        Swap the axes of the base frame?
    permuteCurr : `bool`
        Swap the axes of the current frame?

    Raises
    ------
    RuntimeError
        If a frame that is to be permuted does not have exactly 2 axes

    Notes
    -----
    **Fields**

    frameSet : `ast.FrameSet`
        The (possibly permuted) copy of the input FrameSet.
    isBaseSkyFrame : `bool`
        Is the base frame an `ast.SkyFrame`?
    isCurrSkyFrame : `bool`
        Is the current frame an `ast.SkyFrame`?
    isBasePermuted : `bool`
        Were the base frame axes permuted?
    isCurrPermuted : `bool`
        Were the current frame axes permuted?
    """
    def __init__(self, frameSet, permuteBase, permuteCurr):
        self.frameSet = frameSet.copy()
        info = FrameSetInfo(self.frameSet)
        self.isBaseSkyFrame = info.isBaseSkyFrame
        self.isCurrSkyFrame = info.isCurrSkyFrame

        if permuteBase:
            self._requireTwoAxes(info.baseInd, "Base")
            # Make the base frame current so permAxes applies to it,
            # then restore the original current frame
            self.frameSet.current = info.baseInd
            self.frameSet.permAxes([2, 1])
            self.frameSet.current = info.currInd
        if permuteCurr:
            self._requireTwoAxes(info.currInd, "Current")
            self.frameSet.permAxes([2, 1])

        self.isBasePermuted = permuteBase
        self.isCurrPermuted = permuteCurr

    def _requireTwoAxes(self, frameInd, label):
        """Raise RuntimeError unless frame ``frameInd`` has exactly 2 axes."""
        nAxes = self.frameSet.getFrame(frameInd).nAxes
        if nAxes != 2:
            raise RuntimeError(f"{label} frame has {nAxes} axes; 2 required to permute")

327 

328 

329class TransformTestBaseClass(lsst.utils.tests.TestCase): 

330 """Base class for unit tests of Transform<X>To<Y> 

331 

332 Subclasses must call `TransformTestBaseClass.setUp(self)` 

333 if they provide their own version. 

334 """ 

335 

336 def setUp(self): 

337 """Set up a test 

338 

339 Subclasses should call this method if they override setUp. 

340 """ 

341 # tell unittest to use the msg argument of asserts as a supplement 

342 # to the error message, rather than as the whole error message 

343 self.longMessage = True 

344 

345 # list of endpoint class name prefixes; the full name is prefix + "Endpoint" 

346 self.endpointPrefixes = ("Generic", "Point2", "SpherePoint") 

347 

348 # GoodNAxes is dict of endpoint class name prefix: 

349 # tuple containing 0 or more valid numbers of axes 

350 self.goodNAxes = { 

351 "Generic": (1, 2, 3, 4), # all numbers of axes are valid for GenericEndpoint 

352 "Point2": (2,), 

353 "SpherePoint": (2,), 

354 } 

355 

356 # BadAxes is dict of endpoint class name prefix: 

357 # tuple containing 0 or more invalid numbers of axes 

358 self.badNAxes = { 

359 "Generic": (), # all numbers of axes are valid for GenericEndpoint 

360 "Point2": (1, 3, 4), 

361 "SpherePoint": (1, 3, 4), 

362 } 

363 

364 # Dict of frame index: identity name for frames created by makeFrameSet 

365 self.frameIdentDict = { 

366 1: "baseFrame", 

367 2: "frame2", 

368 3: "frame3", 

369 4: "currFrame", 

370 } 

371 

372 @staticmethod 

373 def makeRawArrayData(nPoints, nAxes, delta=0.123): 

374 """Make an array of generic point data 

375 

376 The data will be suitable for spherical points 

377 

378 Parameters 

379 ---------- 

380 nPoints : `int` 

381 Number of points in the array 

382 nAxes : `int` 

383 Number of axes in the point 

384 

385 Returns 

386 ------- 

387 np.array of floats with shape (nAxes, nPoints) 

388 The values are as follows; if nAxes != 2: 

389 The first point has values `[0, delta, 2*delta, ..., (nAxes-1)*delta]` 

390 The Nth point has those values + N 

391 if nAxes == 2 then the data is scaled so that the max value of axis 1 

392 is a bit less than pi/2 

393 """ 

394 delta = 0.123 

395 # oneAxis = [0, 1, 2, ...nPoints-1] 

396 oneAxis = np.arange(nPoints, dtype=float) # [0, 1, 2...] 

397 # rawData = [oneAxis, oneAxis + delta, oneAxis + 2 delta, ...] 

398 rawData = np.array([j * delta + oneAxis for j in range(nAxes)], dtype=float) 

399 if nAxes == 2: 

400 # scale rawData so that max value of 2nd axis is a bit less than pi/2, 

401 # thus making the data safe for SpherePoint 

402 maxLatitude = np.max(rawData[1]) 

403 rawData *= math.pi * 0.4999 / maxLatitude 

404 return rawData 

405 

406 @staticmethod 

407 def makeRawPointData(nAxes, delta=0.123): 

408 """Make one generic point 

409 

410 Parameters 

411 ---------- 

412 nAxes : `int` 

413 Number of axes in the point 

414 delta : `float` 

415 Increment between axis values 

416 

417 Returns 

418 ------- 

419 A list of `nAxes` floats with values `[0, delta, ..., (nAxes-1)*delta] 

420 """ 

421 return [i*delta for i in range(nAxes)] 

422 

423 @staticmethod 

424 def makeEndpoint(name, nAxes=None): 

425 """Make an endpoint 

426 

427 Parameters 

428 ---------- 

429 name : `str` 

430 Endpoint class name prefix; the full class name is name + "Endpoint" 

431 nAxes : `int` or `None`, optional 

432 number of axes; an int is required if `name` == "Generic"; 

433 otherwise ignored 

434 

435 Returns 

436 ------- 

437 subclass of `lsst.afw.geom.BaseEndpoint` 

438 The constructed endpoint 

439 

440 Raises 

441 ------ 

442 TypeError 

443 If `name` == "Generic" and `nAxes` is None or <= 0 

444 """ 

445 EndpointClassName = name + "Endpoint" 

446 EndpointClass = getattr(afwGeom, EndpointClassName) 

447 if name == "Generic": 

448 if nAxes is None: 

449 raise TypeError("nAxes must be an integer for GenericEndpoint") 

450 return EndpointClass(nAxes) 

451 return EndpointClass() 

452 

453 @classmethod 

454 def makeGoodFrame(cls, name, nAxes=None): 

455 """Return the appropriate frame for the given name and nAxes 

456 

457 Parameters 

458 ---------- 

459 name : `str` 

460 Endpoint class name prefix; the full class name is name + "Endpoint" 

461 nAxes : `int` or `None`, optional 

462 number of axes; an int is required if `name` == "Generic"; 

463 otherwise ignored 

464 

465 Returns 

466 ------- 

467 `ast.Frame` 

468 The constructed frame 

469 

470 Raises 

471 ------ 

472 TypeError 

473 If `name` == "Generic" and `nAxes` is `None` or <= 0 

474 """ 

475 return cls.makeEndpoint(name, nAxes).makeFrame() 

476 

477 @staticmethod 

478 def makeBadFrames(name): 

479 """Return a list of 0 or more frames that are not a valid match for the 

480 named endpoint 

481 

482 Parameters 

483 ---------- 

484 name : `str` 

485 Endpoint class name prefix; the full class name is name + "Endpoint" 

486 

487 Returns 

488 ------- 

489 Collection of `ast.Frame` 

490 A collection of 0 or more frames 

491 """ 

492 return { 

493 "Generic": [], 

494 "Point2": [ 

495 ast.SkyFrame(), 

496 ast.Frame(1), 

497 ast.Frame(3), 

498 ], 

499 "SpherePoint": [ 

500 ast.Frame(1), 

501 ast.Frame(2), 

502 ast.Frame(3), 

503 ], 

504 }[name] 

505 

506 def makeFrameSet(self, baseFrame, currFrame): 

507 """Make a FrameSet 

508 

509 The FrameSet will contain 4 frames and three transforms connecting them. 

510 The idenity of each frame is provided by self.frameIdentDict 

511 

512 Frame Index Mapping from this frame to the next 

513 `baseFrame` 1 `ast.UnitMap(nIn)` 

514 Frame(nIn) 2 `polyMap` 

515 Frame(nOut) 3 `ast.UnitMap(nOut)` 

516 `currFrame` 4 

517 

518 where: 

519 - `nIn` = `baseFrame.nAxes` 

520 - `nOut` = `currFrame.nAxes` 

521 - `polyMap` = `makeTwoWayPolyMap(nIn, nOut)` 

522 

523 Returns 

524 ------ 

525 `ast.FrameSet` 

526 The FrameSet as described above 

527 

528 Parameters 

529 ---------- 

530 baseFrame : `ast.Frame` 

531 base frame 

532 currFrame : `ast.Frame` 

533 current frame 

534 """ 

535 nIn = baseFrame.nAxes 

536 nOut = currFrame.nAxes 

537 polyMap = makeTwoWayPolyMap(nIn, nOut) 

538 

539 # The only way to set the Ident of a frame in a FrameSet is to set it in advance, 

540 # and I don't want to modify the inputs, so replace the input frames with copies 

541 baseFrame = baseFrame.copy() 

542 baseFrame.ident = self.frameIdentDict[1] 

543 currFrame = currFrame.copy() 

544 currFrame.ident = self.frameIdentDict[4] 

545 

546 frameSet = ast.FrameSet(baseFrame) 

547 frame2 = ast.Frame(nIn) 

548 frame2.ident = self.frameIdentDict[2] 

549 frameSet.addFrame(ast.FrameSet.CURRENT, ast.UnitMap(nIn), frame2) 

550 frame3 = ast.Frame(nOut) 

551 frame3.ident = self.frameIdentDict[3] 

552 frameSet.addFrame(ast.FrameSet.CURRENT, polyMap, frame3) 

553 frameSet.addFrame(ast.FrameSet.CURRENT, ast.UnitMap(nOut), currFrame) 

554 return frameSet 

555 

556 @staticmethod 

557 def permuteFrameSetIter(frameSet): 

558 """Iterator over 0 or more frameSets with SkyFrames axes permuted 

559 

560 Only base and current SkyFrames are permuted. If neither the base nor 

561 the current frame is a SkyFrame then no frames are returned. 

562 

563 Returns 

564 ------- 

565 iterator over `PermutedFrameSet` 

566 """ 

567 

568 fsInfo = FrameSetInfo(frameSet) 

569 if not (fsInfo.isBaseSkyFrame or fsInfo.isCurrSkyFrame): 

570 return 

571 

572 permuteBaseList = [False, True] if fsInfo.isBaseSkyFrame else [False] 

573 permuteCurrList = [False, True] if fsInfo.isCurrSkyFrame else [False] 

574 for permuteBase in permuteBaseList: 

575 for permuteCurr in permuteCurrList: 

576 yield PermutedFrameSet(frameSet, permuteBase, permuteCurr) 

577 

578 @staticmethod 

579 def makeJacobian(nIn, nOut, inPoint): 

580 """Make a Jacobian matrix for the equation described by 

581 `makeTwoWayPolyMap`. 

582 

583 Parameters 

584 ---------- 

585 nIn, nOut : `int` 

586 the dimensions of the input and output data; see makeTwoWayPolyMap 

587 inPoint : `numpy.ndarray` 

588 an array of size `nIn` representing the point at which the Jacobian 

589 is measured 

590 

591 Returns 

592 ------- 

593 J : `numpy.ndarray` 

594 an `nOut` x `nIn` array of first derivatives 

595 """ 

596 basePolyMapCoeff = 0.001 # see makeTwoWayPolyMap 

597 baseCoeff = 2.0 * basePolyMapCoeff 

598 coeffs = np.empty((nOut, nIn)) 

599 for iOut in range(nOut): 

600 coeffOffset = baseCoeff * iOut 

601 for iIn in range(nIn): 

602 coeffs[iOut, iIn] = baseCoeff * (iIn + 1) + coeffOffset 

603 coeffs[iOut, iIn] *= inPoint[iIn] 

604 assert coeffs.ndim == 2 

605 # Avoid spurious errors when comparing to a simplified array 

606 assert coeffs.shape == (nOut, nIn) 

607 return coeffs 

608 

    def checkTransformation(self, transform, mapping, msg=""):
        """Check applyForward and applyInverse for a transform

        Each direction that ``mapping`` supports is applied to one point and
        to an array of points, and the transform's results are compared with
        both ``mapping`` and ``transform.getMapping()``. Each direction that
        ``mapping`` does not support must also be unsupported by ``transform``.

        Parameters
        ----------
        transform : `lsst.afw.geom.Transform`
            The transform to check
        mapping : `ast.Mapping`
            The mapping the transform should use. This mapping
            must contain valid forward or inverse transformations,
            but they need not match if both present. Hence the
            mappings returned by make*PolyMap are acceptable.
        msg : `str`
            Error message suffix describing test parameters
        """
        fromEndpoint = transform.fromEndpoint
        toEndpoint = transform.toEndpoint
        mappingFromTransform = transform.getMapping()

        # The endpoints must agree with the mapping's input/output axis counts
        nIn = mapping.nIn
        nOut = mapping.nOut
        self.assertEqual(nIn, fromEndpoint.nAxes, msg=msg)
        self.assertEqual(nOut, toEndpoint.nAxes, msg=msg)

        # input data for forward transformation of one point
        rawInPoint = self.makeRawPointData(nIn)
        inPoint = fromEndpoint.pointFromData(rawInPoint)

        # input data for forward transformation of an array of points
        nPoints = 7  # arbitrary
        rawInArray = self.makeRawArrayData(nPoints, nIn)
        inArray = fromEndpoint.arrayFromData(rawInArray)

        if mapping.hasForward:
            self.assertTrue(transform.hasForward)
            outPoint = transform.applyForward(inPoint)
            rawOutPoint = toEndpoint.dataFromPoint(outPoint)
            assert_allclose(rawOutPoint, mapping.applyForward(rawInPoint), err_msg=msg)
            assert_allclose(rawOutPoint, mappingFromTransform.applyForward(rawInPoint), err_msg=msg)

            outArray = transform.applyForward(inArray)
            rawOutArray = toEndpoint.dataFromArray(outArray)
            self.assertFloatsAlmostEqual(rawOutArray, mapping.applyForward(rawInArray), msg=msg)
            self.assertFloatsAlmostEqual(rawOutArray, mappingFromTransform.applyForward(rawInArray), msg=msg)
        else:
            # No forward direction, so synthesize outPoint/outArray (and their
            # raw data) for the inverse checks below; they need not be
            # consistent with inPoint/inArray
            rawOutPoint = self.makeRawPointData(nOut)
            outPoint = toEndpoint.pointFromData(rawOutPoint)
            rawOutArray = self.makeRawArrayData(nPoints, nOut)
            outArray = toEndpoint.arrayFromData(rawOutArray)

            self.assertFalse(transform.hasForward)

        if mapping.hasInverse:
            self.assertTrue(transform.hasInverse)
            # inverse transformation of one point;
            # remember that the inverse need not give the original values
            # (see the description of the `mapping` parameter)
            inversePoint = transform.applyInverse(outPoint)
            rawInversePoint = fromEndpoint.dataFromPoint(inversePoint)
            assert_allclose(rawInversePoint, mapping.applyInverse(rawOutPoint), err_msg=msg)
            assert_allclose(rawInversePoint, mappingFromTransform.applyInverse(rawOutPoint), err_msg=msg)

            # inverse transformation of an array of points;
            # remember that the inverse will not give the original values
            # (see the description of the `mapping` parameter)
            inverseArray = transform.applyInverse(outArray)
            rawInverseArray = fromEndpoint.dataFromArray(inverseArray)
            self.assertFloatsAlmostEqual(rawInverseArray, mapping.applyInverse(rawOutArray), msg=msg)
            self.assertFloatsAlmostEqual(rawInverseArray, mappingFromTransform.applyInverse(rawOutArray),
                                         msg=msg)
        else:
            self.assertFalse(transform.hasInverse)

682 

    def checkInverseTransformation(self, forward, inverse, msg=""):
        """Check that two Transforms are each others' inverses.

        Endpoints and hasForward/hasInverse must be swapped between the two
        transforms, and every supported direction of ``forward`` must give
        exactly the same results as the opposite direction of ``inverse``,
        both for the Transforms and for their underlying mappings.

        Parameters
        ----------
        forward : `lsst.afw.geom.Transform`
            the reference Transform to test
        inverse : `lsst.afw.geom.Transform`
            the transform that should be the inverse of `forward`
        msg : `str`
            error message suffix describing test parameters
        """
        fromEndpoint = forward.fromEndpoint
        toEndpoint = forward.toEndpoint
        forwardMapping = forward.getMapping()
        inverseMapping = inverse.getMapping()

        # properties
        self.assertEqual(forward.fromEndpoint,
                         inverse.toEndpoint, msg=msg)
        self.assertEqual(forward.toEndpoint,
                         inverse.fromEndpoint, msg=msg)
        self.assertEqual(forward.hasForward, inverse.hasInverse, msg=msg)
        self.assertEqual(forward.hasInverse, inverse.hasForward, msg=msg)

        # transformations of one point
        # we don't care about whether the transformation itself is correct
        # (see checkTransformation), so inPoint/outPoint need not be related
        rawInPoint = self.makeRawPointData(fromEndpoint.nAxes)
        inPoint = fromEndpoint.pointFromData(rawInPoint)
        rawOutPoint = self.makeRawPointData(toEndpoint.nAxes)
        outPoint = toEndpoint.pointFromData(rawOutPoint)

        # transformations of arrays of points
        nPoints = 7  # arbitrary
        rawInArray = self.makeRawArrayData(nPoints, fromEndpoint.nAxes)
        inArray = fromEndpoint.arrayFromData(rawInArray)
        rawOutArray = self.makeRawArrayData(nPoints, toEndpoint.nAxes)
        outArray = toEndpoint.arrayFromData(rawOutArray)

        if forward.hasForward:
            self.assertEqual(forward.applyForward(inPoint),
                             inverse.applyInverse(inPoint), msg=msg)
            self.assertEqual(forwardMapping.applyForward(rawInPoint),
                             inverseMapping.applyInverse(rawInPoint), msg=msg)
            # Array results are compared with numpy's assert_array_equal
            # (single points above use assertEqual)
            assert_array_equal(forward.applyForward(inArray),
                               inverse.applyInverse(inArray),
                               err_msg=msg)
            assert_array_equal(forwardMapping.applyForward(rawInArray),
                               inverseMapping.applyInverse(rawInArray),
                               err_msg=msg)

        if forward.hasInverse:
            self.assertEqual(forward.applyInverse(outPoint),
                             inverse.applyForward(outPoint), msg=msg)
            self.assertEqual(forwardMapping.applyInverse(rawOutPoint),
                             inverseMapping.applyForward(rawOutPoint), msg=msg)
            assert_array_equal(forward.applyInverse(outArray),
                               inverse.applyForward(outArray),
                               err_msg=msg)
            assert_array_equal(forwardMapping.applyInverse(rawOutArray),
                               inverseMapping.applyForward(rawOutArray),
                               err_msg=msg)

747 

748 def checkTransformFromMapping(self, fromName, toName): 

749 """Check Transform_<fromName>_<toName> using the Mapping constructor 

750 

751 Parameters 

752 ---------- 

753 fromName, toName : `str` 

754 Endpoint name prefix for "from" and "to" endpoints, respectively, 

755 e.g. "Point2" for `lsst.afw.geom.Point2Endpoint` 

756 fromAxes, toAxes : `int` 

757 number of axes in fromFrame and toFrame, respectively 

758 """ 

759 transformClassName = "Transform{}To{}".format(fromName, toName) 

760 TransformClass = getattr(afwGeom, transformClassName) 

761 baseMsg = "TransformClass={}".format(TransformClass.__name__) 

762 

763 # check valid numbers of inputs and outputs 

764 for nIn, nOut in itertools.product(self.goodNAxes[fromName], 

765 self.goodNAxes[toName]): 

766 msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut) 

767 polyMap = makeTwoWayPolyMap(nIn, nOut) 

768 transform = TransformClass(polyMap) 

769 

770 # desired output from `str(transform)` 

771 desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut) 

772 self.assertEqual("{}".format(transform), desStr) 

773 self.assertEqual(repr(transform), "lsst.afw.geom." + desStr) 

774 

775 self.checkTransformation(transform, polyMap, msg=msg) 

776 

777 # Forward transform but no inverse 

778 polyMap = makeForwardPolyMap(nIn, nOut) 

779 transform = TransformClass(polyMap) 

780 self.checkTransformation(transform, polyMap, msg=msg) 

781 

782 # Inverse transform but no forward 

783 polyMap = makeForwardPolyMap(nOut, nIn).inverted() 

784 transform = TransformClass(polyMap) 

785 self.checkTransformation(transform, polyMap, msg=msg) 

786 

787 # check invalid # of output against valid # of inputs 

788 for nIn, badNOut in itertools.product(self.goodNAxes[fromName], 

789 self.badNAxes[toName]): 

790 badPolyMap = makeTwoWayPolyMap(nIn, badNOut) 

791 msg = "{}, nIn={}, badNOut={}".format(baseMsg, nIn, badNOut) 

792 with self.assertRaises(InvalidParameterError, msg=msg): 

793 TransformClass(badPolyMap) 

794 

795 # check invalid # of inputs against valid and invalid # of outputs 

796 for badNIn, nOut in itertools.product(self.badNAxes[fromName], 

797 self.goodNAxes[toName] + self.badNAxes[toName]): 

798 badPolyMap = makeTwoWayPolyMap(badNIn, nOut) 

799 msg = "{}, badNIn={}, nOut={}".format(baseMsg, nIn, nOut) 

800 with self.assertRaises(InvalidParameterError, msg=msg): 

801 TransformClass(badPolyMap) 

802 

803 def checkTransformFromFrameSet(self, fromName, toName): 

804 """Check Transform_<fromName>_<toName> using the FrameSet constructor 

805 

806 Parameters 

807 ---------- 

808 fromName, toName : `str` 

809 Endpoint name prefix for "from" and "to" endpoints, respectively, 

810 e.g. "Point2" for `lsst.afw.geom.Point2Endpoint` 

811 """ 

812 transformClassName = "Transform{}To{}".format(fromName, toName) 

813 TransformClass = getattr(afwGeom, transformClassName) 

814 baseMsg = "TransformClass={}".format(TransformClass.__name__) 

815 for nIn, nOut in itertools.product(self.goodNAxes[fromName], 

816 self.goodNAxes[toName]): 

817 msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut) 

818 

819 baseFrame = self.makeGoodFrame(fromName, nIn) 

820 currFrame = self.makeGoodFrame(toName, nOut) 

821 frameSet = self.makeFrameSet(baseFrame, currFrame) 

822 self.assertEqual(frameSet.nFrame, 4) 

823 

824 # construct 0 or more frame sets that are invalid for this transform class 

825 for badBaseFrame in self.makeBadFrames(fromName): 

826 badFrameSet = self.makeFrameSet(badBaseFrame, currFrame) 

827 with self.assertRaises(InvalidParameterError): 

828 TransformClass(badFrameSet) 

829 for badCurrFrame in self.makeBadFrames(toName): 

830 reallyBadFrameSet = self.makeFrameSet(badBaseFrame, badCurrFrame) 

831 with self.assertRaises(InvalidParameterError): 

832 TransformClass(reallyBadFrameSet) 

833 for badCurrFrame in self.makeBadFrames(toName): 

834 badFrameSet = self.makeFrameSet(baseFrame, badCurrFrame) 

835 with self.assertRaises(InvalidParameterError): 

836 TransformClass(badFrameSet) 

837 

838 transform = TransformClass(frameSet) 

839 

840 desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut) 

841 self.assertEqual("{}".format(transform), desStr) 

842 self.assertEqual(repr(transform), "lsst.afw.geom." + desStr) 

843 

844 self.checkPersistence(transform) 

845 

846 mappingFromTransform = transform.getMapping() 

847 transformCopy = TransformClass(mappingFromTransform) 

848 self.assertEqual(type(transform), type(transformCopy)) 

849 self.assertEqual(transform.getMapping(), mappingFromTransform) 

850 

851 polyMap = makeTwoWayPolyMap(nIn, nOut) 

852 

853 self.checkTransformation(transform, mapping=polyMap, msg=msg) 

854 

855 # If the base and/or current frame of frameSet is a SkyFrame, 

856 # try permuting that frame (in place, so the connected mappings are 

857 # correctly updated). The Transform constructor should undo the permutation, 

858 # (via SpherePointEndpoint.normalizeFrame) in its internal copy of frameSet, 

859 # forcing the axes of the SkyFrame into standard (longitude, latitude) order 

860 for permutedFS in self.permuteFrameSetIter(frameSet): 

861 if permutedFS.isBaseSkyFrame: 

862 baseFrame = permutedFS.frameSet.getFrame(ast.FrameSet.BASE) 

863 # desired base longitude axis 

864 desBaseLonAxis = 2 if permutedFS.isBasePermuted else 1 

865 self.assertEqual(baseFrame.lonAxis, desBaseLonAxis) 

866 if permutedFS.isCurrSkyFrame: 

867 currFrame = permutedFS.frameSet.getFrame(ast.FrameSet.CURRENT) 

868 # desired current base longitude axis 

869 desCurrLonAxis = 2 if permutedFS.isCurrPermuted else 1 

870 self.assertEqual(currFrame.lonAxis, desCurrLonAxis) 

871 

872 permTransform = TransformClass(permutedFS.frameSet) 

873 self.checkTransformation(permTransform, mapping=polyMap, msg=msg) 

874 

875 def checkInverted(self, fromName, toName): 

876 """Test Transform<fromName>To<toName>.inverted 

877 

878 Parameters 

879 ---------- 

880 fromName, toName : `str` 

881 Endpoint name prefix for "from" and "to" endpoints, respectively, 

882 e.g. "Point2" for `lsst.afw.geom.Point2Endpoint` 

883 """ 

884 transformClassName = "Transform{}To{}".format(fromName, toName) 

885 TransformClass = getattr(afwGeom, transformClassName) 

886 baseMsg = "TransformClass={}".format(TransformClass.__name__) 

887 for nIn, nOut in itertools.product(self.goodNAxes[fromName], 

888 self.goodNAxes[toName]): 

889 msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut) 

890 self.checkInverseMapping( 

891 TransformClass, 

892 makeTwoWayPolyMap(nIn, nOut), 

893 "{}, Map={}".format(msg, "TwoWay")) 

894 self.checkInverseMapping( 

895 TransformClass, 

896 makeForwardPolyMap(nIn, nOut), 

897 "{}, Map={}".format(msg, "Forward")) 

898 self.checkInverseMapping( 

899 TransformClass, 

900 makeForwardPolyMap(nOut, nIn).inverted(), 

901 "{}, Map={}".format(msg, "Inverse")) 

902 

903 def checkInverseMapping(self, TransformClass, mapping, msg): 

904 """Test Transform<fromName>To<toName>.inverted for a specific 

905 mapping. 

906 

907 Also check that inverted() and getInverted() return the same 

908 transform. 

909 

910 Parameters 

911 ---------- 

912 TransformClass : `type` 

913 The class of transform to test, such as TransformPoint2ToPoint2 

914 mapping : `ast.Mapping` 

915 The mapping to use for the transform 

916 msg : `str` 

917 Error message suffix 

918 """ 

919 transform = TransformClass(mapping) 

920 inverse = transform.inverted() 

921 inverseInverse = inverse.inverted() 

922 

923 self.checkInverseTransformation(transform, inverse, msg=msg) 

924 self.checkInverseTransformation(inverse, inverseInverse, msg=msg) 

925 self.checkTransformation(inverseInverse, mapping, msg=msg) 

926 

927 def checkGetJacobian(self, fromName, toName): 

928 """Test Transform<fromName>To<toName>.getJacobian 

929 

930 Parameters 

931 ---------- 

932 fromName, toName : `str` 

933 Endpoint name prefix for "from" and "to" endpoints, respectively, 

934 e.g. "Point2" for `lsst.afw.geom.Point2Endpoint` 

935 """ 

936 transformClassName = "Transform{}To{}".format(fromName, toName) 

937 TransformClass = getattr(afwGeom, transformClassName) 

938 baseMsg = "TransformClass={}".format(TransformClass.__name__) 

939 for nIn, nOut in itertools.product(self.goodNAxes[fromName], 

940 self.goodNAxes[toName]): 

941 msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut) 

942 polyMap = makeForwardPolyMap(nIn, nOut) 

943 transform = TransformClass(polyMap) 

944 fromEndpoint = transform.fromEndpoint 

945 

946 # Test multiple points to ensure correct functional form 

947 rawInPoint = self.makeRawPointData(nIn) 

948 inPoint = fromEndpoint.pointFromData(rawInPoint) 

949 jacobian = transform.getJacobian(inPoint) 

950 assert_allclose(jacobian, self.makeJacobian(nIn, nOut, rawInPoint), 

951 err_msg=msg) 

952 

953 rawInPoint = self.makeRawPointData(nIn, 0.111) 

954 inPoint = fromEndpoint.pointFromData(rawInPoint) 

955 jacobian = transform.getJacobian(inPoint) 

956 assert_allclose(jacobian, self.makeJacobian(nIn, nOut, rawInPoint), 

957 err_msg=msg) 

958 

959 def checkThen(self, fromName, midName, toName): 

960 """Test Transform<fromName>To<midName>.then(Transform<midName>To<toName>) 

961 

962 Parameters 

963 ---------- 

964 fromName : `str` 

965 the prefix of the starting endpoint (e.g., "Point2" for a 

966 Point2Endpoint) for the final, concatenated Transform 

967 midName : `str` 

968 the prefix for the shared endpoint where two Transforms will be 

969 concatenated 

970 toName : `str` 

971 the prefix of the ending endpoint for the final, concatenated 

972 Transform 

973 """ 

974 TransformClass1 = getattr(afwGeom, 

975 "Transform{}To{}".format(fromName, midName)) 

976 TransformClass2 = getattr(afwGeom, 

977 "Transform{}To{}".format(midName, toName)) 

978 baseMsg = "{}.then({})".format(TransformClass1.__name__, 

979 TransformClass2.__name__) 

980 for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName], 

981 self.goodNAxes[midName], 

982 self.goodNAxes[toName]): 

983 msg = "{}, nIn={}, nMid={}, nOut={}".format( 

984 baseMsg, nIn, nMid, nOut) 

985 polyMap1 = makeTwoWayPolyMap(nIn, nMid) 

986 transform1 = TransformClass1(polyMap1) 

987 polyMap2 = makeTwoWayPolyMap(nMid, nOut) 

988 transform2 = TransformClass2(polyMap2) 

989 transform = transform1.then(transform2) 

990 

991 fromEndpoint = transform1.fromEndpoint 

992 toEndpoint = transform2.toEndpoint 

993 

994 inPoint = fromEndpoint.pointFromData(self.makeRawPointData(nIn)) 

995 outPointMerged = transform.applyForward(inPoint) 

996 outPointSeparate = transform2.applyForward( 

997 transform1.applyForward(inPoint)) 

998 assert_allclose(toEndpoint.dataFromPoint(outPointMerged), 

999 toEndpoint.dataFromPoint(outPointSeparate), 

1000 err_msg=msg) 

1001 

1002 outPoint = toEndpoint.pointFromData(self.makeRawPointData(nOut)) 

1003 inPointMerged = transform.applyInverse(outPoint) 

1004 inPointSeparate = transform1.applyInverse( 

1005 transform2.applyInverse(outPoint)) 

1006 assert_allclose( 

1007 fromEndpoint.dataFromPoint(inPointMerged), 

1008 fromEndpoint.dataFromPoint(inPointSeparate), 

1009 err_msg=msg) 

1010 

1011 # Mismatched number of axes should fail 

1012 if midName == "Generic": 

1013 nIn = self.goodNAxes[fromName][0] 

1014 nOut = self.goodNAxes[toName][0] 

1015 polyMap = makeTwoWayPolyMap(nIn, 3) 

1016 transform1 = TransformClass1(polyMap) 

1017 polyMap = makeTwoWayPolyMap(2, nOut) 

1018 transform2 = TransformClass2(polyMap) 

1019 with self.assertRaises(InvalidParameterError): 

1020 transform = transform1.then(transform2) 

1021 

1022 # Mismatched types of endpoints should fail 

1023 if fromName != midName: 

1024 # Use TransformClass1 for both args to keep test logic simple 

1025 outName = midName 

1026 joinNAxes = set(self.goodNAxes[fromName]).intersection( 

1027 self.goodNAxes[outName]) 

1028 

1029 for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName], 

1030 joinNAxes, 

1031 self.goodNAxes[outName]): 

1032 polyMap = makeTwoWayPolyMap(nIn, nMid) 

1033 transform1 = TransformClass1(polyMap) 

1034 polyMap = makeTwoWayPolyMap(nMid, nOut) 

1035 transform2 = TransformClass1(polyMap) 

1036 with self.assertRaises(InvalidParameterError): 

1037 transform = transform1.then(transform2) 

1038 

1039 def assertTransformsEqual(self, transform1, transform2): 

1040 """Assert that two transforms are equal""" 

1041 self.assertEqual(type(transform1), type(transform2)) 

1042 self.assertEqual(transform1.fromEndpoint, transform2.fromEndpoint) 

1043 self.assertEqual(transform1.toEndpoint, transform2.toEndpoint) 

1044 self.assertEqual(transform1.getMapping(), transform2.getMapping()) 

1045 

1046 fromEndpoint = transform1.fromEndpoint 

1047 toEndpoint = transform1.toEndpoint 

1048 mapping = transform1.getMapping() 

1049 nIn = mapping.nIn 

1050 nOut = mapping.nOut 

1051 

1052 if mapping.hasForward: 

1053 nPoints = 7 # arbitrary 

1054 rawInArray = self.makeRawArrayData(nPoints, nIn) 

1055 inArray = fromEndpoint.arrayFromData(rawInArray) 

1056 outArray = transform1.applyForward(inArray) 

1057 outData = toEndpoint.dataFromArray(outArray) 

1058 outArrayRoundTrip = transform2.applyForward(inArray) 

1059 outDataRoundTrip = toEndpoint.dataFromArray(outArrayRoundTrip) 

1060 assert_allclose(outData, outDataRoundTrip) 

1061 

1062 if mapping.hasInverse: 

1063 nPoints = 7 # arbitrary 

1064 rawOutArray = self.makeRawArrayData(nPoints, nOut) 

1065 outArray = toEndpoint.arrayFromData(rawOutArray) 

1066 inArray = transform1.applyInverse(outArray) 

1067 inData = fromEndpoint.dataFromArray(inArray) 

1068 inArrayRoundTrip = transform2.applyInverse(outArray) 

1069 inDataRoundTrip = fromEndpoint.dataFromArray(inArrayRoundTrip) 

1070 assert_allclose(inData, inDataRoundTrip) 

1071 

1072 def checkPersistence(self, transform): 

1073 """Check persistence of a transform 

1074 """ 

1075 className = type(transform).__name__ 

1076 

1077 # check writeString and readString 

1078 transformStr = transform.writeString() 

1079 serialVersion, serialClassName, serialRest = transformStr.split(" ", 2) 

1080 self.assertEqual(int(serialVersion), 1) 

1081 self.assertEqual(serialClassName, className) 

1082 badStr1 = " ".join(["2", serialClassName, serialRest]) 

1083 with self.assertRaises(lsst.pex.exceptions.InvalidParameterError): 

1084 transform.readString(badStr1) 

1085 badClassName = "x" + serialClassName 

1086 badStr2 = " ".join(["1", badClassName, serialRest]) 

1087 with self.assertRaises(lsst.pex.exceptions.InvalidParameterError): 

1088 transform.readString(badStr2) 

1089 transformFromStr1 = transform.readString(transformStr) 

1090 self.assertTransformsEqual(transform, transformFromStr1) 

1091 

1092 # check transformFromString 

1093 transformFromStr2 = afwGeom.transformFromString(transformStr) 

1094 self.assertTransformsEqual(transform, transformFromStr2) 

1095 

1096 # Check pickling 

1097 self.assertTransformsEqual(transform, pickle.loads(pickle.dumps(transform))) 

1098 

1099 # Check afw::table::io persistence round-trip 

1100 with lsst.utils.tests.getTempFilePath(".fits") as filename: 

1101 transform.writeFits(filename) 

1102 self.assertTransformsEqual(transform, type(transform).readFits(filename))