Coverage for tests/test_parameters.py: 10%

158 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 08:24 +0000

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy as np 

23from lsst.scarlet.lite import Box 

24from lsst.scarlet.lite.parameters import ( 

25 AdaproxParameter, 

26 FistaParameter, 

27 FixedParameter, 

28 Parameter, 

29 parameter, 

30 phi_psi, 

31) 

32from numpy.testing import assert_array_equal 

33from utils import ScarletTestCase 

34 

35 

def prox_ceiling(x, thresh: float = 20):
    """Clip ``x`` in place so that no element exceeds ``thresh``.

    Serves as a simple stand-in proximal operator for the parameter tests.
    The input array itself is modified and then returned.
    """
    above = x > thresh
    x[above] = thresh
    return x

40 

41 

def grad(input_grad: np.ndarray, x: np.ndarray, *args):
    """Stand-in gradient for the parameter tests.

    Returns ``2 * x`` multiplied elementwise by ``input_grad``. Any extra
    positional arguments are accepted and ignored.
    """
    scaled = 2 * x
    return scaled * input_grad

45 

46 

class TestParameters(ScarletTestCase):
    """Unit tests for the classes in ``lsst.scarlet.lite.parameters``."""

    def test_parameter_class(self):
        """Exercise the base ``Parameter`` API: construction, ``update``,
        copying, and the ``parameter`` factory."""
        x = np.arange(15, dtype=float).reshape(3, 5)
        param = parameter(x)
        self.assertIsInstance(param, Parameter)
        assert_array_equal(param.x, x)
        self.assertTupleEqual(param.shape, (3, 5))
        self.assertEqual(param.dtype, float)

        # The base class does not implement an update rule.
        with self.assertRaises(NotImplementedError):
            param.update(1, np.zeros((3, 5)))

        # Test copy method: the copy must hold new arrays with equal values.
        y = np.zeros((3, 5), dtype=float)
        y[1, 3] = 1
        param = Parameter(x, {"y": y}, 0)
        copied = param.copy()
        self.assertIsNot(copied.x, x)
        assert_array_equal(copied.x, x)
        self.assertIsNot(copied.helpers["y"], y)
        assert_array_equal(copied.helpers["y"], y)

        # Passing an existing Parameter to ``parameter`` returns the same object.
        param2 = parameter(param)
        self.assertIs(param2, param)

        # Audit finding K-1: ``__copy__`` and ``__deepcopy__`` must
        # propagate ``step``, ``grad``, and ``prox``. Previously the
        # base class re-built the copy with ``step=0`` and no grad
        # or prox, leaving the copy non-functional for optimization.
        param = Parameter(x, {"y": y}, 0.25, grad=grad, prox=prox_ceiling)
        for copied in (param.copy(deep=False), param.copy(deep=True)):
            self.assertEqual(copied.step, 0.25)
            self.assertIs(copied.grad, grad)
            self.assertIs(copied.prox, prox_ceiling)

    def test_growing(self):
        """``resize`` zero-pads when the new box is larger and crops when
        it is smaller."""
        x = np.arange(15, dtype=float).reshape(3, 5)
        y = np.zeros((3, 5), dtype=float)
        y[1, 3] = 1
        param = Parameter(x, {"y": y}, 0)

        # Test growing in all dimensions
        old_box = Box((3, 5), (21, 15))
        new_box = Box((11, 20), (19, 10))
        param.resize(old_box, new_box)
        truth = np.zeros((11, 20), dtype=float)
        # The old box starts (21 - 19, 15 - 10) = (2, 5) pixels into the new one.
        truth[2:5, 5:10] = x
        assert_array_equal(param.x, truth)

        # Test shrinking in all directions
        param = Parameter(x, {"y": y}, 0)
        old_box = Box((3, 5), (21, 15))
        new_box = Box((1, 3), (22, 16))
        param.resize(old_box, new_box)
        truth = x[1:2, 1:4]
        assert_array_equal(param.x, truth)

    def test_fista_parameter(self):
        """Check that ``FistaParameter`` exposes its grad/prox and can update."""
        x = np.arange(10, dtype=float)
        x2 = x**2
        param = FistaParameter(
            x2,
            0.1,
            grad,
            prox_ceiling,
        )

        assert_array_equal(param.x, x2)
        assert_array_equal(param.grad(np.full(x.shape, 0.1), x), 0.2 * x)
        truth = x2.copy()
        truth[truth > 20] = 20
        # ``prox_ceiling`` clips in place: hand it a copy so the fixture
        # (which ``param.x`` may alias) is not clipped before ``update``.
        assert_array_equal(param.prox(x2.copy()), truth)
        param.update(10, x, x2)

    def test_adprox_parameter(self):
        """Check that ``AdaproxParameter`` exposes its grad/prox and can
        update with every scheme in ``phi_psi``."""
        x = np.arange(10, dtype=float)
        x2 = x**2
        param = AdaproxParameter(
            x2,
            0.1,
            grad,
            prox_ceiling,
        )

        assert_array_equal(param.x, x2)
        assert_array_equal(param.grad(np.full(x.shape, 0.1), x), 0.2 * x)
        truth = x2.copy()
        truth[truth > 20] = 20
        # ``prox_ceiling`` clips in place: hand it a copy so the fixture
        # (which ``param.x`` may alias) is not clipped before ``update``.
        assert_array_equal(param.prox(x2.copy()), truth)
        param.update(10, x, x2)

        schemes = tuple(phi_psi.keys())
        for scheme in schemes:
            param = AdaproxParameter(
                x2,
                0.1,
                grad,
                prox_ceiling,
                scheme=scheme,
            )
            param.update(10, x, x2)

        # Audit finding O-1: ``update`` must work when ``prox`` is
        # None (the default), matching ``FistaParameter.update``.
        # Previously ``self.prox(_x)`` was called unconditionally and
        # raised ``TypeError`` on ``None``.
        param = AdaproxParameter(x2.copy(), 0.1, grad)
        param.update(10, x, x2)
        param.update(0, x, x2)

    def test_adaprox_variants_converge(self):
        """Each ADAM variant must drive a simple quadratic loss to
        its optimum.

        The loss is ``0.5 * sum((x - target)**2)`` with gradient
        ``x - target``. Every scheme should reach ``target`` within
        a small tolerance after a fixed iteration budget. This
        catches semantic regressions in any of the per-iteration
        update formulas (the kinds of bug in O-2).
        """
        target = np.array([3.0, -2.0, 5.0])

        def quad_grad(input_grad, x):
            return x - target

        for scheme in tuple(phi_psi.keys()):
            param = AdaproxParameter(
                np.zeros_like(target),
                step=0.1,
                grad=quad_grad,
                scheme=scheme,
            )
            for it in range(2000):
                param.update(it, np.zeros_like(target))
            np.testing.assert_allclose(
                param.x,
                target,
                atol=1e-3,
                err_msg=f"AdaproxParameter scheme={scheme!r} failed to converge",
            )

    def test_adamx_first_iteration(self):
        """``_adamx_phi_psi`` must treat ``factor`` as 1 on the first
        iteration rather than indexing ``b1[it-1] = b1[-1]``.

        Audit finding O-2: at ``it=0`` the formula
        ``(1 - b1[it])**2 / (1 - b1[it-1])**2`` accidentally reads
        the *last* element of a varying ``b1`` schedule. With the
        default ``SingleItemArray`` (constant ``b1``) this returns
        the right value by coincidence; with a real array of varying
        decay rates the factor is wrong on the very first step.
        """
        adamx = phi_psi["adamx"]
        # b1[-1] differs sharply from b1[0], so the buggy and fixed
        # branches diverge.
        b1 = np.array([0.9, 0.5])
        g = np.array([1.0])
        m = np.array([0.0])
        v = np.array([0.0])
        # Non-default ``vhat`` so the factor multiplies a finite
        # value (the default ``-inf`` would absorb any positive
        # factor).
        vhat = np.array([1.0])
        _, psi = adamx(0, g, m, v, vhat, b1, 0.999, 0, 0.5)
        # v after update: (1-0.999)*1 = 0.001
        # Fixed: vhat = max(1.0 * 1.0, 0.001) = 1.0; psi = sqrt(1.0) = 1.0
        # Buggy: factor = (0.1)**2/(0.5)**2 = 0.04
        #        vhat = max(0.04 * 1.0, 0.001) = 0.04; psi = sqrt(0.04) = 0.2
        np.testing.assert_allclose(psi, 1.0)

    def test_fixed_parameter(self):
        """``FixedParameter.update`` must leave ``x`` unchanged."""
        x = np.arange(10, dtype=float)
        param = FixedParameter(x)
        param.update(10, np.arange(10) * 2)
        assert_array_equal(param.x, x)

    def test_shallow_copy(self):
        """``copy()`` keeps the concrete subclass and equal values for ``x``
        and every helper array."""
        x = np.arange(10, dtype=float)

        # FistaParameter
        param = FistaParameter(x, 0.1)
        param_copy = param.copy()
        self.assertIsInstance(param_copy, FistaParameter)

        assert_array_equal(param.x, param_copy.x)
        assert_array_equal(param.helpers["z"], param_copy.helpers["z"])

        # AdaproxParameter
        param = AdaproxParameter(x, 0.1)
        param_copy = param.copy()
        self.assertIsInstance(param_copy, AdaproxParameter)

        assert_array_equal(param.x, param_copy.x)
        assert_array_equal(param.helpers["m"], param_copy.helpers["m"])
        assert_array_equal(param.helpers["v"], param_copy.helpers["v"])
        assert_array_equal(param.helpers["vhat"], param_copy.helpers["vhat"])

        # FixedParameter
        param = FixedParameter(x)
        param_copy = param.copy()
        self.assertIsInstance(param_copy, FixedParameter)
        assert_array_equal(param.x, param_copy.x)

    def test_deep_copy(self):
        """``copy(deep=True)`` must produce independent arrays: mutating the
        copy in place must not affect the original."""
        x = np.arange(10, dtype=float)

        # FistaParameter
        param = FistaParameter(x, 0.1)
        param_deepcopy = param.copy(deep=True)
        self.assertIsInstance(param_deepcopy, FistaParameter)

        assert_array_equal(param.x, param_deepcopy.x)
        param_deepcopy.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.x, param_deepcopy.x)

        assert_array_equal(param.helpers["z"], param_deepcopy.helpers["z"])
        param_deepcopy.helpers["z"] += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.helpers["z"], param_deepcopy.helpers["z"])

        # AdaproxParameter
        param = AdaproxParameter(x, 0.1)
        param_deepcopy = param.copy(deep=True)
        self.assertIsInstance(param_deepcopy, AdaproxParameter)

        assert_array_equal(param.x, param_deepcopy.x)
        param_deepcopy.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.x, param_deepcopy.x)

        # Mutate each helper array *in place*. Rebinding the dict entry
        # (``helpers[name] = -1``) would only show that the two helpers
        # dicts are distinct, not that the arrays inside were copied.
        # Slice assignment is used instead of ``+= 1`` because ``vhat``
        # may start at ``-inf`` (see the note in
        # ``test_adamx_first_iteration``), which ``+= 1`` leaves unchanged.
        for name in ("m", "v", "vhat"):
            with self.subTest(helper=name):
                assert_array_equal(param.helpers[name], param_deepcopy.helpers[name])
                param_deepcopy.helpers[name][...] = -1
                with self.assertRaises(AssertionError):
                    assert_array_equal(param.helpers[name], param_deepcopy.helpers[name])

        # FixedParameter
        param = FixedParameter(x)
        param_deepcopy = param.copy(deep=True)
        self.assertIsInstance(param_deepcopy, FixedParameter)
        assert_array_equal(param.x, param_deepcopy.x)
        param_deepcopy.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.x, param_deepcopy.x)