Coverage for python / lsst / meas / extensions / trailedSources / NaivePlugin.py: 13%
226 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 08:43 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 08:43 +0000
1#
2# This file is part of meas_extensions_trailedSources.
3#
4# Developed for the LSST Data Management System.
5# This product includes software developed by the LSST Project
6# (http://www.lsst.org).
7# See the COPYRIGHT file at the top-level directory of this distribution
8# for details of code ownership.
9#
10# This program is free software: you can redistribute it and/or modify
11# it under the terms of the GNU General Public License as published by
12# the Free Software Foundation, either version 3 of the License, or
13# (at your option) any later version.
14#
15# This program is distributed in the hope that it will be useful,
16# but WITHOUT ANY WARRANTY; without even the implied warranty of
17# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18# GNU General Public License for more details.
19#
20# You should have received a copy of the GNU General Public License
21# along with this program. If not, see <http://www.gnu.org/licenses/>.
22#
24import logging
25import numpy as np
26import scipy.optimize as sciOpt
27from scipy.special import erf
28from math import sqrt
30from lsst.geom import Point2D, Point2I
31from lsst.meas.base.pluginRegistry import register
32from lsst.meas.base import SingleFramePlugin, SingleFramePluginConfig
33from lsst.meas.base import FlagHandler, FlagDefinitionList
34import lsst.pex.config
36from ._trailedSources import VeresModel
37from .utils import getMeasurementCutout
# Public names exported by this module.
__all__ = (
    "SingleFrameNaiveTrailConfig",
    "SingleFrameNaiveTrailPlugin",
)
class SingleFrameNaiveTrailConfig(SingleFramePluginConfig):
    """Configuration for `SingleFrameNaiveTrailPlugin`.

    Attributes
    ----------
    maxFlux : `float`
        Largest absolute model flux accepted from the trail model; above
        this the plugin substitutes the aperture flux.
    """

    maxFlux = lsst.pex.config.Field(
        doc="Maximum calculated model flux before falling back on aperture flux.",
        dtype=float,
        default=1e10,
    )
@register("ext_trailedSources_Naive")
class SingleFrameNaiveTrailPlugin(SingleFramePlugin):
    """Naive trailed source measurement plugin

    Measures the length, angle from +x-axis, and end points of an extended
    source using the second moments.

    Parameters
    ----------
    config: `SingleFrameNaiveTrailConfig`
        Plugin configuration.
    name: `str`
        Plugin name.
    schema: `lsst.afw.table.Schema`
        Schema for the output catalog.
    metadata: `lsst.daf.base.PropertySet`
        Metadata to be attached to output catalog.
    logName: `str`, optional
        Name of the logger; defaults to this module's ``__name__``.

    Notes
    -----
    This measurement plugin aims to utilize the already measured adaptive
    second moments to naively estimate the length and angle, and thus
    end-points, of a fast-moving, trailed source. The length is solved for via
    finding the root of the difference between the numerical (stack computed)
    and the analytic adaptive second moments. The angle, theta, from the x-axis
    is also computed via adaptive moments: theta = arctan(2*Ixy/(Ixx - Iyy))/2
    (evaluated with ``arctan2`` to resolve the quadrant).
    The end points of the trail are then given by (xc +/- (length/2)*cos(theta)
    and yc +/- (length/2)*sin(theta)), with xc and yc being the centroid
    coordinates.

    See also
    --------
    lsst.meas.base.SingleFramePlugin
    """

    ConfigClass = SingleFrameNaiveTrailConfig

    @classmethod
    def getExecutionOrder(cls):
        # Needs centroids, shape, and flux measurements.
        # VeresPlugin is run after, which requires image data.
        return cls.APCORR_ORDER + 0.1

    def __init__(self, config, name, schema, metadata, logName=None):
        if logName is None:
            logName = __name__
        super().__init__(config, name, schema, metadata, logName=logName)

        # Measurement Keys
        self.keyRa = schema.addField(name + "_ra", type="D", doc="Trail centroid right ascension.")
        self.keyDec = schema.addField(name + "_dec", type="D", doc="Trail centroid declination.")
        self.keyX0 = schema.addField(name + "_x0", type="D", doc="Trail head X coordinate.", units="pixel")
        self.keyY0 = schema.addField(name + "_y0", type="D", doc="Trail head Y coordinate.", units="pixel")
        self.keyX1 = schema.addField(name + "_x1", type="D", doc="Trail tail X coordinate.", units="pixel")
        self.keyY1 = schema.addField(name + "_y1", type="D", doc="Trail tail Y coordinate.", units="pixel")
        self.keyFlux = schema.addField(name + "_flux", type="D", doc="Trailed source flux.", units="count")
        self.keyLength = schema.addField(name + "_length", type="D", doc="Trail length.", units="pixel")
        self.keyAngle = schema.addField(name + "_angle", type="D", doc="Angle measured from +x-axis.")

        # Measurement Error Keys
        self.keyX0Err = schema.addField(name + "_x0Err", type="D",
                                        doc="Trail head X coordinate error.", units="pixel")
        self.keyY0Err = schema.addField(name + "_y0Err", type="D",
                                        doc="Trail head Y coordinate error.", units="pixel")
        self.keyX1Err = schema.addField(name + "_x1Err", type="D",
                                        doc="Trail tail X coordinate error.", units="pixel")
        self.keyY1Err = schema.addField(name + "_y1Err", type="D",
                                        doc="Trail tail Y coordinate error.", units="pixel")
        self.keyFluxErr = schema.addField(name + "_fluxErr", type="D",
                                          doc="Trail flux error.", units="count")
        self.keyLengthErr = schema.addField(name + "_lengthErr", type="D",
                                            doc="Trail length error.", units="pixel")
        self.keyAngleErr = schema.addField(name + "_angleErr", type="D", doc="Trail angle error.")
        # 1 = SdssShape moments used, 2 = slot-shape fallback (set in measure).
        self.keyAlgorithm = schema.addField(name + "_algorithmKey", type="I",
                                            doc="Algorithm key indicating which algorithm is "
                                                "used to measure trailed source.")

        flagDefs = FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("No trailed-source measured")
        self.NO_FLUX = flagDefs.add("flag_noFlux", "No suitable prior flux measurement")
        self.NO_CONVERGE = flagDefs.add("flag_noConverge", "The root finder did not converge")
        self.NO_SIGMA = flagDefs.add("flag_noSigma", "No PSF width (sigma)")
        self.EDGE = flagDefs.add("flag_edge", "Trail contains edge pixels")
        self.OFFIMAGE = flagDefs.add("flag_off_image", "Trail extends off image")
        self.NAN = flagDefs.add("flag_nan", "One or more trail coordinates are missing")
        self.SUSPECT_LONG_TRAIL = flagDefs.add("flag_suspect_long_trail",
                                               "Trail length is greater than three times the psf radius")
        self.SHAPE = flagDefs.add("flag_shape", "Shape flag is set, trail length not calculated")
        self.SAFE_CENTROID = flagDefs.add("flag_centroid",
                                          "Centroid flag is set, trail length not calculated")
        self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)

        self.log = logging.getLogger(self.logName)

    def measure(self, measRecord, exposure):
        """Run the Naive trailed source measurement algorithm.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record describing the object being measured.
        exposure : `lsst.afw.image.Exposure`
            Pixel data to be measured.

        Notes
        -----
        The SdssShape centroid and moments are used when
        ``base_SdssShape_flag`` is not set (algorithm key 1). Otherwise the
        slot shape is used (algorithm key 2) and no errors are computed.
        If the slot shape is also flagged, the measurement fails.

        See also
        --------
        lsst.meas.base.SingleFramePlugin.measure
        """
        self.flagHandler.setValue(measRecord, self.FAILURE.number, False)
        use_sdss_shape = True
        if measRecord['base_SdssShape_flag']:
            if measRecord.getShapeFlag():
                self.log.debug("HSM shape flag is also set for measRecord: %s. Trail measurement "
                               "will not be made. All trail values will be set to nan.", measRecord.getId())
                self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
                self.flagHandler.setValue(measRecord, self.SHAPE.number, True)
                return
            # Fall back on the slot shape; it carries no uncertainties here.
            use_sdss_shape = False
            measRecord.set(self.keyAlgorithm, 2)
            self.log.debug(
                "SDSS Shape flag is set for measRecord: %s. Falling back "
                "to HSMshape. No error measurements will be made.", measRecord.getId())
            xc = measRecord["slot_Shape_x"]
            yc = measRecord["slot_Shape_y"]
            Ixx, Iyy, Ixy = measRecord.getShape().getParameterVector()
        else:
            measRecord.set(self.keyAlgorithm, 1)
            xc = measRecord["base_SdssShape_x"]
            yc = measRecord["base_SdssShape_y"]
            Ixx = measRecord["base_SdssShape_xx"]
            Iyy = measRecord["base_SdssShape_yy"]
            Ixy = measRecord["base_SdssShape_xy"]

        if not np.isfinite(xc) or not np.isfinite(yc):
            self.flagHandler.setValue(measRecord, self.SAFE_CENTROID.number, True)
            self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
            return
        ra, dec = self.computeRaDec(exposure, xc, yc)

        # Transform the second moments to semi-major and semi-minor axes
        # (the eigenvalues of the moment matrix).
        xmy = Ixx - Iyy
        xpy = Ixx + Iyy
        xmy2 = xmy*xmy
        xy2 = Ixy*Ixy
        desc = sqrt(xmy2 + 4.0*xy2)  # Square root of the eigenvalue-equation discriminant
        a2 = 0.5*(xpy + desc)
        b2 = 0.5*(xpy - desc)

        # Measure the trail length
        length, gradLength, results = self.findLength(a2, b2)
        if not results.converged:
            self.log.info("Results not converged: %s", results.flag)
            self.flagHandler.setValue(measRecord, self.NO_CONVERGE.number, True)
            self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
            return

        # Compute the angle of the trail from the x-axis
        theta = 0.5*np.arctan2(2.0*Ixy, xmy)

        # Get end-points of the trail (which end is the "head" is degenerate).
        radius = length/2.0  # Trail 'radius'
        dx = radius*np.cos(theta)
        dy = radius*np.sin(theta)
        x0 = xc - dx
        y0 = yc - dy
        x1 = xc + dx
        y1 = yc + dy

        self.check_trail(measRecord, exposure, x0, y0, x1, y1, length)

        # Compute flux assuming fixed parameters for VeresModel
        cutout = getMeasurementCutout(measRecord, exposure)
        params = np.array([xc, yc, 1.0, length, theta])  # Flux = 1.0
        model = VeresModel(cutout)
        flux, gradFlux = model.computeFluxWithGradient(params)

        # Fall back to aperture flux when the model flux is unusable.
        usedApFlux = False
        if not np.isfinite(flux) or np.abs(flux) > self.config.maxFlux:
            apFlux = measRecord.getApInstFlux()
            if not np.isfinite(apFlux):
                self.flagHandler.setValue(measRecord, self.NO_FLUX.number, True)
                self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
                return
            flux = apFlux
            usedApFlux = True

        # Errors can only be calculated when using SDSS shape. Otherwise,
        # the error fields are not set (left as nan).
        if use_sdss_shape:
            # Propagate errors from second moments and centroid. The stored
            # errors are standard deviations, so square them to get variances.
            IxxErr2 = measRecord["base_SdssShape_xxErr"]**2
            IyyErr2 = measRecord["base_SdssShape_yyErr"]**2
            IxyErr2 = measRecord["base_SdssShape_xyErr"]**2

            # Centroid Errors
            xcErr2 = measRecord["base_SdssCentroid_xErr"]**2
            ycErr2 = measRecord["base_SdssCentroid_yErr"]**2

            # Error in length
            da2dIxx = 0.5*(1.0 + (xmy/desc))
            da2dIyy = 0.5*(1.0 - (xmy/desc))
            da2dIxy = 2.0*Ixy/desc
            a2Err2 = IxxErr2*da2dIxx*da2dIxx + IyyErr2*da2dIyy*da2dIyy + IxyErr2*da2dIxy*da2dIxy
            # db2/dIxx = da2/dIyy, db2/dIyy = da2/dIxx, |db2/dIxy| = |da2/dIxy|
            b2Err2 = IxxErr2*da2dIyy*da2dIyy + IyyErr2*da2dIxx*da2dIxx + IxyErr2*da2dIxy*da2dIxy
            dLda2, dLdb2 = gradLength
            lengthErr = np.sqrt(dLda2*dLda2*a2Err2 + dLdb2*dLdb2*b2Err2)

            # Error in theta
            thetaDenom = xmy2 + 4.0*xy2
            dThetadIxx = -Ixy/thetaDenom  # dThetadIxx = -dThetadIyy
            dThetadIxy = xmy/thetaDenom
            thetaErr = sqrt(dThetadIxx*dThetadIxx*(IxxErr2 + IyyErr2) + dThetadIxy*dThetadIxy*IxyErr2)

            # Error in flux
            if usedApFlux:
                # The model gradient does not describe the substituted
                # aperture flux, so report the aperture flux error instead.
                fluxErr = measRecord.getApInstFluxErr()
            else:
                dFdxc, dFdyc, _, dFdL, dFdTheta = gradFlux
                fluxErr = sqrt(dFdL*dFdL*lengthErr*lengthErr + dFdTheta*dFdTheta*thetaErr*thetaErr
                               + dFdxc*dFdxc*xcErr2 + dFdyc*dFdyc*ycErr2)

            # Errors in end-points: x = xc +/- radius*cos(theta) and
            # y = yc +/- radius*sin(theta), so |dx/dtheta| = |dy| and
            # |dy/dtheta| = |dx|. The same errors apply to both ends.
            cosTheta = np.cos(theta)
            sinTheta = np.sin(theta)
            radiusErr2 = lengthErr*lengthErr/4.0
            xErr2 = xcErr2 + radiusErr2*cosTheta*cosTheta + thetaErr*thetaErr*dy*dy
            yErr2 = ycErr2 + radiusErr2*sinTheta*sinTheta + thetaErr*thetaErr*dx*dx
            x0Err = sqrt(xErr2)  # Same for x1
            y0Err = sqrt(yErr2)  # Same for y1

            # Set error values
            measRecord.set(self.keyX0Err, x0Err)
            measRecord.set(self.keyY0Err, y0Err)
            measRecord.set(self.keyX1Err, x0Err)
            measRecord.set(self.keyY1Err, y0Err)
            measRecord.set(self.keyFluxErr, fluxErr)
            measRecord.set(self.keyLengthErr, lengthErr)
            measRecord.set(self.keyAngleErr, thetaErr)

        # Set values
        measRecord.set(self.keyRa, ra)
        measRecord.set(self.keyDec, dec)
        measRecord.set(self.keyX0, x0)
        measRecord.set(self.keyY0, y0)
        measRecord.set(self.keyX1, x1)
        measRecord.set(self.keyY1, y1)
        measRecord.set(self.keyFlux, flux)
        measRecord.set(self.keyLength, length)
        measRecord.set(self.keyAngle, theta)

    def check_trail(self, measRecord, exposure, x0, y0, x1, y1, length):
        """Set flags for NaN, off-image, and edge trail end points, and flag
        trails that are long compared to the PSF.

        - If any end-point coordinate is NaN, ``flag_nan`` is set, and end
          points with a NaN coordinate are skipped by the remaining checks.
        - If a remaining end point falls outside the exposure's bounding box,
          ``flag_off_image`` and ``flag_edge`` are set.
        - Otherwise, ``flag_edge`` is set if the mask pixel containing either
          end point has the ``EDGE`` bit set.
        - ``flag_suspect_long_trail`` is set if ``length`` exceeds three
          times the PSF determinant radius, evaluated at the center of the
          exposure's bounding box.

        Parameters
        ----------
        measRecord: `lsst.afw.table.SourceRecord`
            Record describing the object being measured.
        exposure: `lsst.afw.image.Exposure`
            Pixel data to be measured.
        x0: `float`
            x coordinate of the beginning of the trail.
        y0: `float`
            y coordinate of the beginning of the trail.
        x1: `float`
            x coordinate of the end of the trail.
        y1: `float`
            y coordinate of the end of the trail.
        length: `float`
            Length of the trail, in pixels.
        """
        bbox = exposure.getBBox()

        # Flag missing coordinates, then drop any end point that has one.
        # Filtering (x, y) pairs together keeps the coordinates aligned.
        if np.isnan([x0, y0, x1, y1]).any():
            self.flagHandler.setValue(measRecord, self.NAN.number, True)
        endPoints = [(x, y) for x, y in [(x0, y0), (x1, y1)] if not (np.isnan(x) or np.isnan(y))]

        # Pixel centers sit at integer coordinates, so the pixel containing
        # an end point is found by rounding to the nearest integer. Testing
        # that pixel against the bounding box also guarantees the mask lookup
        # below stays in bounds.
        offImage = False
        pixels = []
        for x, y in endPoints:
            if not (np.isfinite(x) and np.isfinite(y)):
                # An infinite coordinate cannot lie on the image.
                offImage = True
                continue
            pixel = Point2I(int(np.floor(x + 0.5)), int(np.floor(y + 0.5)))
            if not bbox.contains(pixel):
                offImage = True
            pixels.append(pixel)

        if offImage:
            self.flagHandler.setValue(measRecord, self.EDGE.number, True)
            self.flagHandler.setValue(measRecord, self.OFFIMAGE.number, True)
        else:
            # Either end landing on an EDGE pixel is enough to flag the trail.
            edgeBit = exposure.mask.getPlaneBitMask('EDGE')
            if any(exposure.mask[pixel] & edgeBit for pixel in pixels):
                self.flagHandler.setValue(measRecord, self.EDGE.number, True)

        psfShape = exposure.psf.computeShape(bbox.getCenter())
        psfRadius = psfShape.getDeterminantRadius()

        if length > psfRadius*3.0:
            self.flagHandler.setValue(measRecord, self.SUSPECT_LONG_TRAIL.number, True)

    def fail(self, measRecord, error=None):
        """Record failure

        See also
        --------
        lsst.meas.base.SingleFramePlugin.fail
        """
        if error is None:
            self.flagHandler.handleFailure(measRecord)
        else:
            self.flagHandler.handleFailure(measRecord, error.cpp)

    @staticmethod
    def _computeSecondMomentDiff(z, c):
        """Compute difference of the numerical and analytic second moments.

        Parameters
        ----------
        z : `float`
            Proportional to the length of the trail. (see notes)
        c : `float`
            Constant (see notes)

        Returns
        -------
        diff : `float`
            Difference in numerical and analytic second moments.

        Notes
        -----
        This is a simplified expression for the difference between the stack
        computed adaptive second-moment and the analytic solution. The variable
        z is proportional to the length such that length=2*z*sqrt(2*(Ixx+Iyy)),
        and c is a constant (c = 4*Ixx/((Ixx+Iyy)*sqrt(pi))). Both have been
        defined to avoid unnecessary floating-point operations in the root
        finder.
        """
        diff = erf(z) - c*z*np.exp(-z*z)
        return diff

    @classmethod
    def findLength(cls, Ixx, Iyy):
        """Find the length of a trail, given adaptive second-moments.

        Uses a root finder to compute the length of a trail corresponding to
        the adaptive second-moments computed by previous measurements
        (ie. SdssShape).

        Parameters
        ----------
        Ixx : `float`
            Adaptive second-moment along the trail (`measure` passes the
            larger eigenvalue of the moment matrix).
        Iyy : `float`
            Adaptive second-moment across the trail (`measure` passes the
            smaller eigenvalue of the moment matrix).

        Returns
        -------
        length : `float`
            Length of the trail.
        gradLength : `tuple` [`float`, `float`]
            Partial derivatives of ``length`` with respect to ``Ixx`` and
            ``Iyy``.
        results : `scipy.optimize.RootResults`
            Contains messages about convergence from the root finder.

        Raises
        ------
        ValueError
            Raised by `scipy.optimize.brentq` if the moment difference does
            not change sign on the search interval (as for a point source,
            ``Ixx == Iyy``).
        """
        xpy = Ixx + Iyy
        c = 4.0*Ixx/(xpy*np.sqrt(np.pi))

        # Given a 'c' in (c_min, c_max], the root is contained in (0,1].
        # c_min is given by the case: Ixx == Iyy, ie. a point source.
        # c_max is given by the limit Ixx >> Iyy.
        # Empirically, 0.001 is a suitable lower bound, assuming Ixx > Iyy.
        z, results = sciOpt.brentq(lambda z: cls._computeSecondMomentDiff(z, c),
                                   0.001, 1.0, full_output=True)

        length = 2.0*z*np.sqrt(2.0*xpy)
        gradLength = cls._gradFindLength(Ixx, Iyy, z, c)
        return length, gradLength, results

    @staticmethod
    def _gradFindLength(Ixx, Iyy, z, c):
        """Compute the gradient of the findLength function.

        Differentiates ``length = 2*z*sqrt(2*(Ixx + Iyy))``, where ``z`` is
        the root of ``_computeSecondMomentDiff(z, c)``. The dependence of
        ``z`` on ``c`` comes from the implicit function theorem.

        Returns
        -------
        dLdIxx, dLdIyy : `float`
            Partial derivatives of the length with respect to ``Ixx`` and
            ``Iyy``.
        """
        spi = np.sqrt(np.pi)
        xpy = Ixx+Iyy
        xpy2 = xpy*xpy
        enz2 = np.exp(-z*z)
        sxpy = np.sqrt(xpy)

        fac = 4.0 / (spi*xpy2)
        dcdIxx = Iyy*fac
        dcdIyy = -Ixx*fac

        # Derivatives of the _computeSecondMomentDiff function
        dfdc = z*enz2  # Negative of the partial derivative with respect to c
        dzdf = spi / (enz2*(spi*c*(2.0*z*z - 1.0) + 2.0))  # inverse of dfdz

        dLdz = 2.0*np.sqrt(2.0)*sxpy
        pLpIxx = np.sqrt(2.0)*z / sxpy  # Same as pLpIyy

        dLdc = dLdz*dzdf*dfdc
        dLdIxx = dLdc*dcdIxx + pLpIxx
        dLdIyy = dLdc*dcdIyy + pLpIxx
        return dLdIxx, dLdIyy

    @staticmethod
    def computeLength(Ixx, Iyy):
        """Compute the length of a trail, given unweighted second-moments.

        Returns
        -------
        length : `float`
            ``sqrt(6*(Ixx - 2*Iyy))``.
        grad : `tuple` [`float`, `float`]
            Partial derivatives of ``length`` with respect to ``Ixx`` and
            ``Iyy``.
        """
        denom = np.sqrt(Ixx - 2.0*Iyy)

        length = np.sqrt(6.0)*denom

        dLdIxx = np.sqrt(1.5) / denom
        dLdIyy = -np.sqrt(6.0) / denom
        return length, (dLdIxx, dLdIyy)

    @staticmethod
    def computeRaDec(exposure, x, y):
        """Convert pixel coordinates to RA and Dec.

        Parameters
        ----------
        exposure : `lsst.afw.image.ExposureF`
            Exposure object containing the WCS.
        x : `float`
            x coordinate of the trail centroid
        y : `float`
            y coordinate of the trail centroid

        Returns
        -------
        ra : `float`
            Right ascension, in degrees.
        dec : `float`
            Declination, in degrees.
        """
        wcs = exposure.getWcs()
        center = wcs.pixelToSky(Point2D(x, y))
        ra = center.getRa().asDegrees()
        dec = center.getDec().asDegrees()
        return ra, dec