Coverage for python/lsst/scarlet/lite/parameters.py: 21%
198 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 08:24 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 08:24 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
# Names exported by ``from ... import *``: the parameter factory and
# classes, the step helper, the ADAM variant table and its default factor.
__all__ = [
    # Parameter factory and parameter classes
    "parameter", "Parameter", "FistaParameter", "AdaproxParameter", "FixedParameter",
    # Step-size helpers and ADAM configuration
    "relative_step", "phi_psi", "DEFAULT_ADAPROX_FACTOR",
]
35from copy import deepcopy
36from typing import Any, Callable, Sequence, cast
38import numpy as np
39import numpy.typing as npt
41from .bbox import Box
# Default scale factor applied to the steps of adaprox parameters.
DEFAULT_ADAPROX_FACTOR = 0.01
def step_function_wrapper(step: float) -> Callable:
    """Turn a constant numerical step into a step function.

    Parameters
    ----------
    step:
        The constant step to report for any array.

    Returns
    -------
    step_function:
        A callable that accepts an array (which it ignores) and always
        returns ``step``.
    """

    def constant_step(x):
        # The array is accepted only to match the step-function signature.
        return step

    return constant_step
class Parameter:
    """A parameter in a `Component`

    Parameters
    ----------
    x:
        The array of values that is being fit.
    helpers:
        A dictionary of helper arrays that are used by an optimizer to
        persist values like the gradient of `x`, the Hessian of `x`, etc.
    step:
        A numerical step value or function to calculate the step for a
        given `x`. Any non-callable value (a `float`, `int`, numpy scalar,
        ...) is treated as a constant step.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
    """

    def __init__(
        self,
        x: np.ndarray,
        helpers: dict[str, np.ndarray],
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
    ):
        self.x = x
        self.helpers = helpers

        # ``self._step`` must always be callable because `step` calls it.
        # Testing ``callable`` (instead of ``isinstance(step, float)``)
        # also wraps ``int`` and numpy scalar steps, which would otherwise
        # be stored raw and make the `step` property raise a TypeError.
        if callable(step):
            _step = step
        else:
            _step = step_function_wrapper(step)

        self._step = _step
        self.grad = grad
        self.prox = prox

    @property
    def step(self) -> float:
        """Calculate the step

        Returns
        -------
        step:
            The value returned by the step function for the current `x`.
        """
        return self._step(self.x)

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the array that is being fit."""
        return self.x.shape

    @property
    def dtype(self) -> npt.DTypeLike:
        """The numpy dtype of the array that is being fit."""
        return self.x.dtype

    def __copy__(self) -> Parameter:
        """Create a copy of this parameter.

        `x` and every helper array are copied; the step, gradient and
        proximal functions are shared with this parameter.

        Returns
        -------
        parameter:
            A copy of this parameter.
        """
        helpers = {k: v.copy() for k, v in self.helpers.items()}
        return Parameter(self.x.copy(), helpers, self._step, self.grad, self.prox)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Parameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        helpers = {k: deepcopy(v, memo) for k, v in self.helpers.items()}
        return Parameter(
            deepcopy(self.x, memo),
            helpers,
            deepcopy(self._step, memo),
            deepcopy(self.grad, memo),
            deepcopy(self.prox, memo),
        )

    def copy(self, deep: bool = False) -> Parameter:
        """Copy this parameter, including all of the helper arrays.

        Parameters
        ----------
        deep:
            If `True`, a deep copy is made.
            If `False`, a shallow copy is made.

        Returns
        -------
        parameter:
            A copy of this parameter.
        """
        if deep:
            return self.__deepcopy__({})
        return self.__copy__()

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter in one iteration.

        This includes the gradient update, proximal update,
        and any meta parameters that are stored as class
        attributes to update the parameter.

        Parameters
        ----------
        it:
            The current iteration
        input_grad:
            The gradient from the full model, passed to the parameter.

        Raises
        ------
        NotImplementedError:
            Always; subclasses implement the actual update.
        """
        raise NotImplementedError("Base Parameters cannot be updated")

    def resize(self, old_box: Box, new_box: Box):
        """Grow the parameter and all of the helper parameters

        Values in the overlap of the two boxes are kept; the rest of the
        new arrays is filled with zeros.

        Parameters
        ----------
        old_box:
            The old bounding box for the parameter.
        new_box:
            The new bounding box for the parameter.
        """
        slices = new_box.overlapped_slices(old_box)
        x = np.zeros(new_box.shape, dtype=self.dtype)
        x[slices[0]] = self.x[slices[1]]
        self.x = x

        for name, value in self.helpers.items():
            # Keep each helper's own dtype instead of forcing that of `x`.
            result = np.zeros(new_box.shape, dtype=value.dtype)
            result[slices[0]] = value[slices[1]]
            self.helpers[name] = result
def parameter(x: np.ndarray | Parameter) -> Parameter:
    """Convert a `np.ndarray` into a `Parameter`.

    Parameters
    ----------
    x:
        The array or parameter to convert into a `Parameter`.

    Returns
    -------
    result:
        `x` itself if it already is a `Parameter`, otherwise a new
        `Parameter` wrapping `x` with no helpers and a constant step of
        zero.
    """
    if isinstance(x, Parameter):
        return x
    # Give the constant step as a float so that it is wrapped into a step
    # function and the `step` property of the result is usable.
    return Parameter(x, {}, 0.0)
class FistaParameter(Parameter):
    """A `Parameter` that updates itself using the Beck-Teboulle 2009
    FISTA proximal gradient method.

    See https://www.ceremade.dauphine.fr/~carlier/FISTA

    Parameters
    ----------
    x:
        The array of values that is being fit.
    step:
        A numerical step value or function to calculate the step for a
        given `x`.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
    t0:
        The initial value of the FISTA momentum term.
    z0:
        The initial value of the extrapolation array.
        If `None` then a copy of `x` is used.
    """

    def __init__(
        self,
        x: np.ndarray,
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        t0: float = 1,
        z0: np.ndarray | None = None,
    ):
        if z0 is None:
            z0 = x.copy()

        super().__init__(
            x,
            {"z": z0},
            step,
            grad,
            prox,
        )
        self.t = t0

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using FISTA

        See `Parameter` for the full description.
        When extra ``args`` are given, the step is divided by the sum of
        squares of ``args[0]``, and all ``args`` are forwarded to the
        gradient function.
        """
        if len(args) == 0:
            step = self.step
        else:
            step = self.step / np.sum(args[0] * args[0])
        _x = self.x
        _z = self.helpers["z"]

        # Gradient step taken from the extrapolated point ``z``
        y = _z - step * cast(Callable, self.grad)(input_grad, _x, *args)
        if self.prox is not None:
            x = self.prox(y)
        else:
            x = y
        # Update the momentum term and extrapolate from the old `x`
        # (``_x`` still holds the previous values at this point).
        t = 0.5 * (1 + np.sqrt(1 + 4 * self.t**2))
        omega = 1 + (self.t - 1) / t
        self.helpers["z"] = _x + omega * (x - _x)
        # Write the new values into the existing array object
        _x[:] = x
        self.t = t

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FistaParameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        # Pass the step *function* (not its current value) so that a
        # callable step keeps being re-evaluated by the copy.
        return FistaParameter(
            deepcopy(self.x, memo),
            self._step,
            self.grad,
            self.prox,
            self.t,
            deepcopy(self.helpers["z"], memo),
        )

    def __copy__(self) -> FistaParameter:
        """Create a copy of this parameter.

        `x` and the extrapolation array ``z`` are copied; the step,
        gradient and proximal functions are shared.

        Returns
        -------
        parameter:
            A copy of this parameter.
        """
        return FistaParameter(
            self.x.copy(),
            self._step,
            self.grad,
            self.prox,
            self.t,
            self.helpers["z"].copy(),
        )
# The following functions implement the update terms (``phi`` and ``psi``)
# for several variants of ADAM.
# `AdaproxParameter` uses `_amsgrad_phi_psi` by default, but any other
# variant can be selected by passing its name (a key of `phi_psi`) as the
# ``scheme`` argument of `AdaproxParameter`.
345# noinspection PyUnusedLocal
346def _adam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
347 # moving averages
348 m[:] = (1 - b1[it]) * g + b1[it] * m
349 v[:] = (1 - b2) * (g**2) + b2 * v
351 # bias correction
352 t = it + 1
353 phi = m / (1 - b1[it] ** t)
354 psi = np.sqrt(v / (1 - b2**t)) + eps
355 return phi, psi
358# noinspection PyUnusedLocal
359def _nadam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
360 # moving averages
361 m[:] = (1 - b1[it]) * g + b1[it] * m
362 v[:] = (1 - b2) * (g**2) + b2 * v
364 # bias correction
365 t = it + 1
366 phi = (b1[it] * m[:] + (1 - b1[it]) * g) / (1 - b1[it] ** t)
367 psi = np.sqrt(v / (1 - b2**t)) + eps
368 return phi, psi
371# noinspection PyUnusedLocal
372def _amsgrad_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
373 # moving averages
374 m[:] = (1 - b1[it]) * g + b1[it] * m
375 v[:] = (1 - b2) * (g**2) + b2 * v
377 phi = m
378 vhat[:] = np.maximum(vhat, v)
379 # sanitize zero-gradient elements
380 if eps > 0:
381 vhat = np.maximum(vhat, eps)
382 psi = np.sqrt(vhat)
383 return phi, psi
386def _padam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
387 # moving averages
388 m[:] = (1 - b1[it]) * g + b1[it] * m
389 v[:] = (1 - b2) * (g**2) + b2 * v
391 phi = m
392 vhat[:] = np.maximum(vhat, v)
393 # sanitize zero-gradient elements
394 if eps > 0:
395 vhat = np.maximum(vhat, eps)
396 psi = vhat**p
397 return phi, psi
400# noinspection PyUnusedLocal
401def _adamx_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
402 # moving averages
403 m[:] = (1 - b1[it]) * g + b1[it] * m
404 v[:] = (1 - b2) * (g**2) + b2 * v
406 phi = m
407 if it == 0:
408 factor = 1.0
409 else:
410 factor = (1 - b1[it]) ** 2 / (1 - b1[it - 1]) ** 2
411 vhat[:] = np.maximum(factor * vhat, v)
412 # sanitize zero-gradient elements
413 if eps > 0:
414 vhat = np.maximum(vhat, eps)
415 psi = np.sqrt(vhat)
416 return phi, psi
419# noinspection PyUnusedLocal
420def _radam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
421 rho_inf = 2 / (1 - b2) - 1
423 # moving averages
424 m[:] = (1 - b1[it]) * g + b1[it] * m
425 v[:] = (1 - b2) * (g**2) + b2 * v
427 # bias correction
428 t = it + 1
429 phi = m / (1 - b1[it] ** t)
430 rho = rho_inf - 2 * t * b2**t / (1 - b2**t)
432 if rho > 4:
433 psi = np.sqrt(v / (1 - b2**t))
434 r = np.sqrt((rho - 4) * (rho - 2) * rho_inf / (rho_inf - 4) / (rho_inf - 2) / rho)
435 psi /= r
436 else:
437 psi = np.ones(g.shape, g.dtype)
438 # sanitize zero-gradient elements
439 if eps > 0:
440 psi = np.maximum(psi, np.sqrt(eps))
441 return phi, psi
# Lookup table from ADAM variant name to the function computing its update
# terms. The keys are the ``scheme`` names accepted by `AdaproxParameter`.
phi_psi = dict(
    adam=_adam_phi_psi,
    nadam=_nadam_phi_psi,
    amsgrad=_amsgrad_phi_psi,
    padam=_padam_phi_psi,
    adamx=_adamx_phi_psi,
    radam=_radam_phi_psi,
)
class SingleItemArray:
    """Array-like object that holds the same value at every index.

    Parameters
    ----------
    value:
        The value returned for any index.
    """

    def __init__(self, value):
        self.value = value

    def __getitem__(self, index):
        # The index is ignored: every position holds ``value``.
        return self.value
class AdaproxParameter(Parameter):
    """Operator updated using the Proximal ADAM algorithm

    Uses multiple variants of adaptive quasi-Newton gradient descent

    * Adam (Kingma & Ba 2015)
    * NAdam (Dozat 2016)
    * AMSGrad (Reddi, Kale & Kumar 2018)
    * PAdam (Chen & Gu 2018)
    * AdamX (Phuong & Phong 2019)
    * RAdam (Liu et al. 2019)

    See details of the algorithms in the respective papers.

    Parameters
    ----------
    x:
        The array of values that is being fit.
    step:
        A numerical step value or function to calculate the step for a
        given `x`.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
    b1:
        The decay rate of the first moment (mean) of the gradient.
        Either a scalar, used for every iteration, or an object indexed
        by the iteration number (e.g. a `SingleItemArray`).
    b2:
        The decay rate of the second moment (variance) of the gradient.
    eps:
        A small constant added for numerical stability.
    p:
        The power used by the ``PAdam`` scheme.
    m0:
        The initial value of the first moment.
        If `None` then an array of zeros is used.
    v0:
        The initial value of the second moment.
        If `None` then an array of zeros is used.
    vhat0:
        The initial value of the maximum second moment.
        If `None` then an array of ``-inf`` is used.
    scheme:
        The name of the ADAM variant to use to update the parameter.
        Must be a key of `phi_psi`.
    prox_e_rel:
        The relative error used by the proximal operator.
        It is stored as ``e_rel``; `update` itself does not read it.

    Raises
    ------
    KeyError:
        If ``scheme`` is not a known ADAM variant.
    """

    def __init__(
        self,
        x: np.ndarray,
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        b1: float | SingleItemArray = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        p: float = 0.25,
        m0: np.ndarray | None = None,
        v0: np.ndarray | None = None,
        vhat0: np.ndarray | None = None,
        scheme: str = "amsgrad",
        prox_e_rel: float = 1e-6,
    ):
        shape = x.shape
        dtype = x.dtype
        if m0 is None:
            m0 = np.zeros(shape, dtype=dtype)

        if v0 is None:
            v0 = np.zeros(shape, dtype=dtype)

        if vhat0 is None:
            # -inf so that the running maximum used by the AMSGrad-like
            # schemes is set by the first ``v``.
            vhat0 = np.ones(shape, dtype=dtype) * -np.inf

        super().__init__(
            x,
            {
                "m": m0,
                "v": v0,
                "vhat": vhat0,
            },
            step,
            grad,
            prox,
        )

        # The scheme functions index ``b1`` by iteration, so any scalar
        # (Python or numpy, float or int) is wrapped to return the same
        # value for every iteration.
        if np.isscalar(b1):
            _b1 = SingleItemArray(b1)
        else:
            _b1 = b1

        self.b1 = _b1
        self.b2 = b2
        self.eps = eps
        self.p = p

        self.scheme = scheme
        try:
            self.phi_psi = phi_psi[scheme]
        except KeyError:
            raise KeyError(
                f"Unknown ADAM scheme {scheme!r}, expected one of {list(phi_psi)}"
            ) from None
        self.e_rel = prox_e_rel

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using proximal ADAM

        See `~Parameter` for more.
        """
        _x = self.x
        # Calculate the gradient
        grad = cast(Callable, self.grad)(input_grad, _x, *args)
        # Get the update for the parameter. The scheme functions update
        # the ``m`` and ``v`` helpers (and ``vhat`` for some schemes)
        # in place.
        phi, psi = self.phi_psi(
            it,
            grad,
            self.helpers["m"],
            self.helpers["v"],
            self.helpers["vhat"],
            self.b1,
            self.b2,
            self.eps,
            self.p,
        )
        # Calculate the step size
        step = self.step
        if it > 0:
            _x += -step * phi / psi
        else:
            # This is a scheme that Peter Melchior and I came up with to
            # dampen the known effect of ADAM, where the first iteration
            # is often much larger than desired.
            _x += -step * phi / psi / 10

        if self.prox is not None:
            self.x = self.prox(_x)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> AdaproxParameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        # Pass the step *function* (not its current value) so that a
        # callable step keeps being re-evaluated by the copy.
        return AdaproxParameter(
            deepcopy(self.x, memo),
            self._step,
            self.grad,
            self.prox,
            self.b1,
            self.b2,
            self.eps,
            self.p,
            deepcopy(self.helpers["m"], memo),
            deepcopy(self.helpers["v"], memo),
            deepcopy(self.helpers["vhat"], memo),
            scheme=self.scheme,
            prox_e_rel=self.e_rel,
        )

    def __copy__(self) -> AdaproxParameter:
        """Create a copy of this parameter.

        `x` and the ``m``, ``v`` and ``vhat`` helper arrays are copied, as
        for the other parameter classes; since `update` modifies these
        arrays in place, sharing them would let an update of the copy
        change this parameter. The step, gradient and proximal functions
        are shared.

        Returns
        -------
        parameter:
            A copy of this parameter.
        """
        return AdaproxParameter(
            self.x.copy(),
            self._step,
            self.grad,
            self.prox,
            self.b1,
            self.b2,
            self.eps,
            self.p,
            self.helpers["m"].copy(),
            self.helpers["v"].copy(),
            self.helpers["vhat"].copy(),
            scheme=self.scheme,
            prox_e_rel=self.e_rel,
        )
class FixedParameter(Parameter):
    """A parameter that is not updated

    Parameters
    ----------
    x:
        The array of values that is held fixed.
    """

    def __init__(self, x: np.ndarray):
        # No helpers are needed. The constant zero step is given as a float
        # so that it is wrapped into a step function and the `step`
        # property is usable.
        super().__init__(x, {}, 0.0)

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Do nothing: a fixed parameter is never updated."""
        pass

    def __copy__(self) -> FixedParameter:
        """Create a shallow copy of this parameter.

        The new parameter shares the `x` array with this one.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        return FixedParameter(self.x)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FixedParameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        return FixedParameter(deepcopy(self.x, memo))
def relative_step(
    x: np.ndarray,
    factor: float = 0.1,
    minimum: float = 0,
    axis: int | Sequence[int] | None = None,
):
    """Step size set at `factor` times the mean of `x` along `axis`.

    Parameters
    ----------
    x:
        The array whose mean sets the step.
    factor:
        The multiplier applied to the mean.
    minimum:
        The lower bound of the returned step.
    axis:
        The axis, or sequence of axes, to average over.
        If `None` the mean of the whole array is used.

    Returns
    -------
    step:
        ``factor * mean(x, axis)``, clipped from below at ``minimum``.
    """
    if axis is not None and not isinstance(axis, (int, np.integer)):
        # numpy reductions only accept a tuple (not a list or other
        # sequence) when averaging over several axes.
        axis = tuple(axis)
    return np.maximum(minimum, factor * x.mean(axis=axis))