Coverage for python/lsst/meas/algorithms/accumulator_mean_stack.py: 9%

95 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-03 08:01 +0000

1# This file is part of meas_algorithms. 

2# 

3# LSST Data Management System 

4# This product includes software developed by the 

5# LSST Project (http://www.lsst.org/). 

6# See COPYRIGHT file at the top of the source tree. 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <https://www.lsstcorp.org/LegalNotices/>. 

21# 

22import warnings 

23 

24import numpy as np 

25 

26 

27__all__ = ['AccumulatorMeanStack'] 

28 

29 

class AccumulatorMeanStack:
    """Stack masked images.

    Images are accumulated one at a time into running per-pixel sums
    (``add_masked_image`` / ``add_image``), and the weighted mean is written
    out afterwards (``fill_stacked_masked_image`` / ``fill_stacked_image``).

    Parameters
    ----------
    shape : `tuple`
        Shape of the input and output images.
    bit_mask_value : `int`
        Bit mask to flag for "bad" inputs that should not be stacked.
    mask_threshold_dict : `dict` [`int`: `float`], optional
        Dictionary of mapping from bit number to threshold for flagging.
        Only bad bits (in bit_mask_value) which mask fractional weight
        greater than this threshold will be flagged in the output image.
        `None` (the default) is treated as an empty dictionary.
    mask_map : `list` [`tuple`], optional
        Mapping from input image bits to aggregated coadd bits.
        `None` (the default) is treated as an empty list.
    no_good_pixels_mask : `int`, optional
        Bit mask to set when there are no good pixels in the stack.
        If not set then will set coadd masked image 'NO_DATA' bit.
    calc_error_from_input_variance : `bool`, optional
        Calculate the error from the input variance?
    compute_n_image : `bool`, optional
        Calculate the n_image map as well as stack?
    """
    def __init__(self, shape,
                 bit_mask_value, mask_threshold_dict=None,
                 mask_map=None, no_good_pixels_mask=None,
                 calc_error_from_input_variance=True,
                 compute_n_image=False):
        # Avoid mutable default arguments: a default ``[]`` stored on the
        # instance would be shared between every default-constructed stack.
        if mask_threshold_dict is None:
            mask_threshold_dict = {}
        if mask_map is None:
            mask_map = []

        self.shape = shape
        self.bit_mask_value = bit_mask_value
        self.mask_map = mask_map
        self.no_good_pixels_mask = no_good_pixels_mask
        self.calc_error_from_input_variance = calc_error_from_input_variance
        self.compute_n_image = compute_n_image

        # Only track threshold bits that are in the bad bit_mask_value.
        self.mask_threshold_dict = {
            bit: threshold
            for bit, threshold in mask_threshold_dict.items()
            if (self.bit_mask_value & 2**bit) > 0
        }

        # sum_weight holds the sum of weights for each pixel.
        self.sum_weight = np.zeros(shape, dtype=np.float64)
        # sum_wdata holds the sum of weight*data for each pixel.
        self.sum_wdata = np.zeros(shape, dtype=np.float64)

        if calc_error_from_input_variance:
            # sum_w2var holds the sum of weight**2 * variance for each pixel.
            self.sum_w2var = np.zeros(shape, dtype=np.float64)
        else:
            # sum_weight2 holds the sum of weight**2 for each pixel.
            self.sum_weight2 = np.zeros(shape, dtype=np.float64)
            # sum_wdata2 holds the sum of weight * data**2 for each pixel.
            self.sum_wdata2 = np.zeros(shape, dtype=np.float64)

        # or_mask holds the OR of the mask bits of every good input pixel.
        self.or_mask = np.zeros(shape, dtype=np.int64)
        # rejected_weights_by_bit holds, per tracked bit, the summed weight
        # of the inputs that had that bit set.
        self.rejected_weights_by_bit = {
            bit: np.zeros(shape, dtype=np.float64)
            for bit in self.mask_threshold_dict
        }

        # masked_pixels_mask records which tracked bits were ever seen set
        # at each pixel.
        self.masked_pixels_mask = np.zeros(shape, dtype=np.int64)

        if self.compute_n_image:
            # n_image counts the good inputs contributing to each pixel.
            self.n_image = np.zeros(shape, dtype=np.int32)

    def reset(self):
        """Reset all accumulator arrays."""
        self.sum_weight[...] = 0
        self.sum_wdata[...] = 0
        if self.calc_error_from_input_variance:
            self.sum_w2var[...] = 0
        else:
            self.sum_weight2[...] = 0
            self.sum_wdata2[...] = 0
        self.or_mask[...] = 0
        for bit in self.mask_threshold_dict:
            self.rejected_weights_by_bit[bit][...] = 0
        self.masked_pixels_mask[...] = 0
        if self.compute_n_image:
            self.n_image[...] = 0

    def add_masked_image(self, masked_image, weight=1.0):
        """Add a masked image to the stack.

        A pixel contributes only if none of the ``bit_mask_value`` bits are
        set in its mask and its image value is finite.

        Parameters
        ----------
        masked_image : `lsst.afw.image.MaskedImage`
            Masked image to add to the stack.
        weight : `float`, optional
            Weight to apply for weighted mean.
        """
        image_array = masked_image.image.array
        mask_array = masked_image.mask.array

        # The finiteness test must look at the image plane: the mask plane
        # is integer-valued, so np.isfinite() on it is always True and a
        # single NaN/inf input would poison the running sums.
        good_pixels = np.where(((mask_array & self.bit_mask_value) == 0)
                               & np.isfinite(image_array))

        self.sum_weight[good_pixels] += weight
        self.sum_wdata[good_pixels] += weight*image_array[good_pixels]

        if self.compute_n_image:
            self.n_image[good_pixels] += 1

        if self.calc_error_from_input_variance:
            self.sum_w2var[good_pixels] += (weight**2.)*masked_image.variance.array[good_pixels]
        else:
            self.sum_weight2[good_pixels] += weight**2.
            self.sum_wdata2[good_pixels] += weight*(image_array[good_pixels]**2.)

        # Mask bits are propagated for good pixels
        self.or_mask[good_pixels] |= mask_array[good_pixels]

        # Bad pixels are only tracked if they cross a threshold
        for bit in self.mask_threshold_dict:
            bad_pixels = ((mask_array & 2**bit) > 0)
            self.rejected_weights_by_bit[bit][bad_pixels] += weight
            self.masked_pixels_mask[bad_pixels] |= 2**bit

    def fill_stacked_masked_image(self, stacked_masked_image):
        """Fill the stacked mask image after accumulation.

        The accumulator arrays are only read, never modified, so this may
        be called repeatedly and accumulation may continue afterwards.

        Parameters
        ----------
        stacked_masked_image : `lsst.afw.image.MaskedImage`
            Total masked image.
        """
        with warnings.catch_warnings():
            # Let the NaNs through and flag bad pixels below
            warnings.simplefilter("ignore")

            # The image plane is sum(weight*data)/sum(weight)
            stacked_masked_image.image.array[:, :] = self.sum_wdata/self.sum_weight

            if self.calc_error_from_input_variance:
                mean_var = self.sum_w2var/(self.sum_weight**2.)
            else:
                # Compute the biased estimator
                variance = self.sum_wdata2/self.sum_weight - stacked_masked_image.image.array[:, :]**2.
                # De-bias
                variance *= (self.sum_weight**2.)/(self.sum_weight**2. - self.sum_weight2)

                # Compute the mean variance
                mean_var = variance*self.sum_weight2/(self.sum_weight**2.)

            stacked_masked_image.variance.array[:, :] = mean_var

            # Build the output mask on a copy so that bits added below do
            # not leak back into the accumulated OR mask.
            output_mask = self.or_mask.copy()

            # Propagate bits when they cross the threshold
            for bit, threshold in self.mask_threshold_dict.items():
                rejected_weight = self.rejected_weights_by_bit[bit]
                hypothetical_total_weight = self.sum_weight + rejected_weight
                # Fraction of the would-be total weight that was rejected
                # because of this bit; computed out-of-place so the
                # accumulated rejected weights stay intact.
                rejected_fraction = rejected_weight/hypothetical_total_weight
                propagate = np.where(rejected_fraction > threshold)
                output_mask[propagate] |= 2**bit

            # Map mask planes to new bits for pixels that had at least one
            # bad input rejected and are in the mask_map.
            for mask_tuple in self.mask_map:
                output_mask[(self.masked_pixels_mask & mask_tuple[0]) > 0] |= mask_tuple[1]

            stacked_masked_image.mask.array[:, :] = output_mask

            if self.no_good_pixels_mask is None:
                mask_dict = stacked_masked_image.mask.getMaskPlaneDict()
                no_good_pixels_mask = 2**(mask_dict['NO_DATA'])
            else:
                no_good_pixels_mask = self.no_good_pixels_mask

            # Pixels that never received any weight have no good inputs.
            bad_pixels = (self.sum_weight <= 0.0)
            stacked_masked_image.mask.array[bad_pixels] |= no_good_pixels_mask

    def add_image(self, image, weight=1.0):
        """Add an image to the stack.

        No bit-filtering is performed when adding an image.

        Parameters
        ----------
        image : `lsst.afw.image.Image`
            Image to add to the stack.
        weight : `float`, optional
            Weight to apply for weighted mean.
        """
        self.sum_weight[:, :] += weight
        self.sum_wdata[:, :] += weight*image.array[:]

        if self.compute_n_image:
            self.n_image[:, :] += 1

    def fill_stacked_image(self, stacked_image):
        """Fill the image after accumulation.

        Parameters
        ----------
        stacked_image : `lsst.afw.image.Image`
            Total image.
        """
        with warnings.catch_warnings():
            # Let the NaNs through, this should only happen
            # if we're stacking with no inputs.
            warnings.simplefilter("ignore")

            # The image plane is sum(weight*data)/sum(weight)
            stacked_image.array[:, :] = self.sum_wdata/self.sum_weight

    @staticmethod
    def stats_ctrl_to_threshold_dict(stats_ctrl):
        """Convert stats control to threshold dict.

        Parameters
        ----------
        stats_ctrl : `lsst.afw.math.StatisticsControl`
            Queried via ``getMaskPropagationThreshold`` for bits 0 to 63.

        Returns
        -------
        threshold_dict : `dict`
            Dict mapping from bit to propagation threshold.
        """
        return {bit: stats_ctrl.getMaskPropagationThreshold(bit) for bit in range(64)}