Coverage for tests / test_parameters.py: 10%
158 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 00:46 -0700
1# This file is part of lsst.scarlet.lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23from lsst.scarlet.lite import Box
24from lsst.scarlet.lite.parameters import (
25 AdaproxParameter,
26 FistaParameter,
27 FixedParameter,
28 Parameter,
29 parameter,
30 phi_psi,
31)
32from numpy.testing import assert_array_equal
33from utils import ScarletTestCase
def prox_ceiling(x, thresh: float = 20):
    """Clip ``x`` from above at ``thresh``; a toy prox for the tests.

    The array is modified in place: every entry strictly greater than
    ``thresh`` is overwritten with ``thresh``, and the same array object
    is returned. Callers that need to keep the input intact must pass a
    copy.
    """
    above_ceiling = x > thresh
    x[above_ceiling] = thresh
    return x
def grad(input_grad: np.ndarray, x: np.ndarray, *args):
    """Toy gradient used by the parameter tests.

    Returns ``2 * x * input_grad``. Any additional positional arguments
    are accepted, so the signature matches how parameters call their
    gradient, and are ignored.
    """
    doubled = 2 * x
    return doubled * input_grad
class TestParameters(ScarletTestCase):
    """Tests for the parameter classes in ``lsst.scarlet.lite.parameters``.

    ``prox_ceiling`` clips its argument in place. Direct calls to
    ``param.prox`` are therefore made on copies, and parameters are built
    from copies of shared input arrays. This prevents one step of a test
    from silently altering the inputs used by a later step.
    """

    def test_parameter_class(self):
        x = np.arange(15, dtype=float).reshape(3, 5)
        param = parameter(x)
        self.assertIsInstance(param, Parameter)
        assert_array_equal(param.x, x)
        self.assertTupleEqual(param.shape, (3, 5))
        self.assertEqual(param.dtype, float)

        # The base class provides no update rule.
        with self.assertRaises(NotImplementedError):
            param.update(1, np.zeros((3, 5)))

        # Test copy method
        y = np.zeros((3, 5), dtype=float)
        y[1, 3] = 1
        param = Parameter(x, {"y": y}, 0)
        self.assertIsNot(param.copy().x, x)
        assert_array_equal(param.copy().x, x)
        self.assertIsNot(param.copy().helpers["y"], y)
        assert_array_equal(param.copy().helpers["y"], y)

        # Wrapping an existing Parameter returns the very same object.
        param2 = parameter(param)
        self.assertIs(param2, param)

        # Audit finding K-1: ``__copy__`` and ``__deepcopy__`` must
        # propagate ``step``, ``grad``, and ``prox``. Previously the
        # base class re-built the copy with ``step=0`` and no grad
        # or prox, leaving the copy non-functional for optimization.
        param = Parameter(x, {"y": y}, 0.25, grad=grad, prox=prox_ceiling)
        for copied in (param.copy(deep=False), param.copy(deep=True)):
            self.assertEqual(copied.step, 0.25)
            self.assertIs(copied.grad, grad)
            self.assertIs(copied.prox, prox_ceiling)

    def test_growing(self):
        x = np.arange(15, dtype=float).reshape(3, 5)
        y = np.zeros((3, 5), dtype=float)
        y[1, 3] = 1
        param = Parameter(x, {"y": y}, 0)

        # Test growing in all dimensions.
        # NOTE(review): Box arguments appear to be (shape, origin); the
        # new origin is offset by (-2, -5), so x lands at [2:5, 5:10].
        old_box = Box((3, 5), (21, 15))
        new_box = Box((11, 20), (19, 10))
        param.resize(old_box, new_box)
        truth = np.zeros((11, 20), dtype=float)
        truth[2:5, 5:10] = x
        assert_array_equal(param.x, truth)

        # Test shrinking in all directions
        param = Parameter(x, {"y": y}, 0)
        old_box = Box((3, 5), (21, 15))
        new_box = Box((1, 3), (22, 16))
        param.resize(old_box, new_box)
        truth = x[1:2, 1:4]
        assert_array_equal(param.x, truth)

    def test_fista_parameter(self):
        x = np.arange(10, dtype=float)
        x2 = x**2
        # Construct from a copy so nothing done to ``param.x`` can leak
        # back into ``x2``.
        param = FistaParameter(
            x2.copy(),
            0.1,
            grad,
            prox_ceiling,
        )

        assert_array_equal(param.x, x2)
        assert_array_equal(param.grad(np.full(x.shape, 0.1), x), 0.2 * x)
        truth = x2.copy()
        truth[truth > 20] = 20
        # ``prox_ceiling`` clips in place; pass a copy so ``x2`` keeps its
        # original values for the update below.
        assert_array_equal(param.prox(x2.copy()), truth)
        param.update(10, x, x2)

    def test_adprox_parameter(self):
        x = np.arange(10, dtype=float)
        x2 = x**2
        param = AdaproxParameter(
            x2.copy(),
            0.1,
            grad,
            prox_ceiling,
        )

        assert_array_equal(param.x, x2)
        assert_array_equal(param.grad(np.full(x.shape, 0.1), x), 0.2 * x)
        truth = x2.copy()
        truth[truth > 20] = 20
        # ``prox_ceiling`` clips in place; pass a copy so the unclipped
        # ``x2`` is what the scheme loop below starts from.
        assert_array_equal(param.prox(x2.copy()), truth)
        param.update(10, x, x2)

        # Every ADAM variant must be able to run an update. Each one gets
        # a fresh copy so earlier schemes cannot alter later inputs.
        for scheme in tuple(phi_psi.keys()):
            param = AdaproxParameter(
                x2.copy(),
                0.1,
                grad,
                prox_ceiling,
                scheme=scheme,
            )
            param.update(10, x, x2)

        # Audit finding O-1: ``update`` must work when ``prox`` is
        # None (the default), matching ``FistaParameter.update``.
        # Previously ``self.prox(_x)`` was called unconditionally and
        # raised ``TypeError`` on ``None``.
        param = AdaproxParameter(x2.copy(), 0.1, grad)
        param.update(10, x, x2)
        param.update(0, x, x2)

    def test_adaprox_variants_converge(self):
        """Each ADAM variant must drive a simple quadratic loss to
        its optimum.

        The loss is ``0.5 * sum((x - target)**2)`` with gradient
        ``x - target``. Every scheme should reach ``target`` within
        a small tolerance after a fixed iteration budget. This
        catches semantic regressions in any of the per-iteration
        update formulas (the kinds of bug in O-2).
        """
        target = np.array([3.0, -2.0, 5.0])

        def quad_grad(input_grad, x):
            # Gradient of the quadratic loss; ``input_grad`` is unused.
            return x - target

        for scheme in tuple(phi_psi.keys()):
            param = AdaproxParameter(
                np.zeros_like(target),
                step=0.1,
                grad=quad_grad,
                scheme=scheme,
            )
            for it in range(2000):
                param.update(it, np.zeros_like(target))
            np.testing.assert_allclose(
                param.x,
                target,
                atol=1e-3,
                err_msg=f"AdaproxParameter scheme={scheme!r} failed to converge",
            )

    def test_adamx_first_iteration(self):
        """``_adamx_phi_psi`` must treat ``factor`` as 1 on the first
        iteration rather than indexing ``b1[it-1] = b1[-1]``.

        Audit finding O-2: at ``it=0`` the formula
        ``(1 - b1[it])**2 / (1 - b1[it-1])**2`` accidentally reads
        the *last* element of a varying ``b1`` schedule. With the
        default ``SingleItemArray`` (constant ``b1``) this returns
        the right value by coincidence; with a real array of varying
        decay rates the factor is wrong on the very first step.
        """
        adamx = phi_psi["adamx"]
        # b1[-1] differs sharply from b1[0], so the buggy and fixed
        # branches diverge.
        b1 = np.array([0.9, 0.5])
        g = np.array([1.0])
        m = np.array([0.0])
        v = np.array([0.0])
        # Non-default ``vhat`` so the factor multiplies a finite
        # value (the default ``-inf`` would absorb any positive
        # factor).
        vhat = np.array([1.0])
        _, psi = adamx(0, g, m, v, vhat, b1, 0.999, 0, 0.5)
        # v after update: (1-0.999)*1 = 0.001
        # Fixed: vhat = max(1.0 * 1.0, 0.001) = 1.0; psi = sqrt(1.0) = 1.0
        # Buggy: factor = (0.1)**2/(0.5)**2 = 0.04
        #        vhat = max(0.04 * 1.0, 0.001) = 0.04; psi = sqrt(0.04) = 0.2
        np.testing.assert_allclose(psi, 1.0)

    def test_fixed_parameter(self):
        x = np.arange(10, dtype=float)
        param = FixedParameter(x)
        # An update must leave a fixed parameter untouched.
        param.update(10, np.arange(10) * 2)
        assert_array_equal(param.x, x)

    def test_shallow_copy(self):
        x = np.arange(10, dtype=float)

        # FistaParameter
        param = FistaParameter(x, 0.1)
        param_copy = param.copy()
        self.assertIsInstance(param_copy, FistaParameter)

        assert_array_equal(param.x, param_copy.x)
        assert_array_equal(param.helpers["z"], param_copy.helpers["z"])

        # AdaproxParameter
        param = AdaproxParameter(x, 0.1)
        param_copy = param.copy()
        self.assertIsInstance(param_copy, AdaproxParameter)

        assert_array_equal(param.x, param_copy.x)
        assert_array_equal(param.helpers["m"], param_copy.helpers["m"])
        assert_array_equal(param.helpers["v"], param_copy.helpers["v"])
        assert_array_equal(param.helpers["vhat"], param_copy.helpers["vhat"])

        # FixedParameter
        param = FixedParameter(x)
        param_copy = param.copy()
        self.assertIsInstance(param_copy, FixedParameter)
        assert_array_equal(param.x, param_copy.x)

    def test_deep_copy(self):
        x = np.arange(10, dtype=float)

        # FistaParameter
        param = FistaParameter(x, 0.1)
        param_deepcopy = param.copy(deep=True)
        self.assertIsInstance(param_deepcopy, FistaParameter)

        assert_array_equal(param.x, param_deepcopy.x)
        param_deepcopy.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.x, param_deepcopy.x)

        assert_array_equal(param.helpers["z"], param_deepcopy.helpers["z"])
        param_deepcopy.helpers["z"] += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.helpers["z"], param_deepcopy.helpers["z"])

        # AdaproxParameter
        param = AdaproxParameter(x, 0.1)
        param_deepcopy = param.copy(deep=True)
        self.assertIsInstance(param_deepcopy, AdaproxParameter)

        assert_array_equal(param.x, param_deepcopy.x)
        param_deepcopy.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.x, param_deepcopy.x)

        # Overwrite each helper array *in place*. Rebinding the dict entry
        # (``helpers[name] = -1``) would only show that the ``helpers``
        # dict was copied, not that the arrays inside it are independent.
        # A fill is used rather than ``+= 1`` because the default ``vhat``
        # is described above as ``-inf``, which ``+= 1`` would not change.
        for name in ("m", "v", "vhat"):
            assert_array_equal(param.helpers[name], param_deepcopy.helpers[name])
            param_deepcopy.helpers[name][...] = -1
            with self.assertRaises(AssertionError):
                assert_array_equal(param.helpers[name], param_deepcopy.helpers[name])

        # FixedParameter
        param = FixedParameter(x)
        param_deepcopy = param.copy(deep=True)
        self.assertIsInstance(param_deepcopy, FixedParameter)
        assert_array_equal(param.x, param_deepcopy.x)
        param_deepcopy.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.x, param_deepcopy.x)
299 assert_array_equal(param.x, param_deepcopy.x)