Coverage for python/lsst/scarlet/lite/parameters.py: 21%

198 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 01:23 -0700

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Public names exported by ``from <module> import *``.
__all__ = [
    # Parameter construction and parameter classes
    "parameter",
    "Parameter",
    "FistaParameter",
    "AdaproxParameter",
    "FixedParameter",
    # Step helpers and ADAM variant lookup
    "relative_step",
    "phi_psi",
    "DEFAULT_ADAPROX_FACTOR",
]

34 

35from copy import deepcopy 

36from typing import Any, Callable, Sequence, cast 

37 

38import numpy as np 

39import numpy.typing as npt 

40 

41from .bbox import Box 

42 

# Default multiplicative factor used for adaprox parameter steps.
DEFAULT_ADAPROX_FACTOR = 0.01

45 

46 

def step_function_wrapper(step: float) -> Callable:
    """Turn a constant numerical step into a step function.

    Parameters
    ----------
    step:
        The step to take for a given array.

    Returns
    -------
    step_function:
        A function that accepts an array and returns `step`,
        ignoring the contents of the array.
    """

    def constant_step(x):
        # The array is accepted only to match the step-function signature.
        return step

    return constant_step

62 

63 

class Parameter:
    """A parameter in a `Component`

    Parameters
    ----------
    x:
        The array of values that is being fit.
    helpers:
        A dictionary of helper arrays that are used by an optimizer to
        persist values like the gradient of `x`, the Hessian of `x`, etc.
    step:
        A numerical step value or function to calculate the step for a
        given `x`. Any value that is not callable (``float``, ``int``,
        numpy scalar, ...) is wrapped into a constant step function.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
    """

    def __init__(
        self,
        x: np.ndarray,
        helpers: dict[str, np.ndarray],
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
    ):
        self.x = x
        self.helpers = helpers

        # Decide on "is it callable" rather than "is it a float": an integer
        # step (for example the ``0`` used for parameters that are never
        # stepped) or a numpy scalar would otherwise be stored unwrapped and
        # the `step` property would fail when trying to call a number.
        if callable(step):
            _step = step
        else:
            _step = step_function_wrapper(step)

        self._step = _step
        self.grad = grad
        self.prox = prox

    @property
    def step(self) -> float:
        """Calculate the step

        Returns
        -------
        step:
            The result of calling the step function on the current `x`.
        """
        return self._step(self.x)

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the array that is being fit."""
        return self.x.shape

    @property
    def dtype(self) -> npt.DTypeLike:
        """The numpy dtype of the array that is being fit."""
        return self.x.dtype

    def __copy__(self) -> Parameter:
        """Create a shallow copy of this parameter.

        The array `x` and every helper array are copied, while the step,
        gradient, and proximal functions are shared with the original.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        helpers = {k: v.copy() for k, v in self.helpers.items()}
        # ``self._step`` is handed over as-is so that a step function keeps
        # being evaluated on the copy's own `x`.
        return Parameter(self.x.copy(), helpers, self._step, self.grad, self.prox)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Parameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        helpers = {k: deepcopy(v, memo) for k, v in self.helpers.items()}
        return Parameter(
            deepcopy(self.x, memo),
            helpers,
            deepcopy(self._step, memo),
            deepcopy(self.grad, memo),
            deepcopy(self.prox, memo),
        )

    def copy(self, deep: bool = False) -> Parameter:
        """Copy this parameter, including all of the helper arrays.

        Parameters
        ----------
        deep:
            If `True`, a deep copy is made.
            If `False`, a shallow copy is made.

        Returns
        -------
        parameter:
            A copy of this parameter.
        """
        if deep:
            return self.__deepcopy__({})
        return self.__copy__()

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter in one iteration.

        This includes the gradient update, proximal update,
        and any meta parameters that are stored as class
        attributes to update the parameter.

        Parameters
        ----------
        it:
            The current iteration
        input_grad:
            The gradient from the full model, passed to the parameter.

        Raises
        ------
        NotImplementedError
            Always, since the base class does not define an update rule.
        """
        raise NotImplementedError("Base Parameters cannot be updated")

    def resize(self, old_box: Box, new_box: Box):
        """Grow the parameter and all of the helper parameters

        `x` and every helper array are replaced by zero-filled arrays with
        the shape of `new_box` and the dtype of `x`, into which the region
        overlapping `old_box` is copied.

        Parameters
        ----------
        old_box:
            The old bounding box for the parameter.
        new_box:
            The new bounding box for the parameter.
        """
        # slices[0] indexes the new array, slices[1] indexes the old array.
        slices = new_box.overlapped_slices(old_box)
        x = np.zeros(new_box.shape, dtype=self.dtype)
        x[slices[0]] = self.x[slices[1]]
        self.x = x

        # Only existing keys are reassigned, so the dictionary does not
        # change size while it is being iterated.
        for name, value in self.helpers.items():
            result = np.zeros(new_box.shape, dtype=self.dtype)
            result[slices[0]] = value[slices[1]]
            self.helpers[name] = result

211 

212 

def parameter(x: np.ndarray | Parameter) -> Parameter:
    """Ensure that `x` is a `Parameter`.

    Parameters
    ----------
    x:
        The array or parameter to convert into a `Parameter`.

    Returns
    -------
    result:
        `x` itself when it already is a `Parameter`, otherwise a new
        `Parameter` wrapping `x` with no helper arrays and a step of ``0``.
    """
    return x if isinstance(x, Parameter) else Parameter(x, {}, 0)

229 

230 

class FistaParameter(Parameter):
    """A `Parameter` that updates itself using the Beck-Teboulle 2009
    FISTA proximal gradient method.

    See https://www.ceremade.dauphine.fr/~carlier/FISTA

    Parameters
    ----------
    x:
        The array of values that is being fit.
    step:
        A numerical step value or function to calculate the step for a
        given `x`.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
    t0:
        The initial value of the FISTA momentum term.
    z0:
        The initial value of the extrapolation array.
        If `None` then a copy of `x` is used.
    """

    def __init__(
        self,
        x: np.ndarray,
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        t0: float = 1,
        z0: np.ndarray | None = None,
    ):
        if z0 is None:
            z0 = x.copy()

        super().__init__(
            x,
            {"z": z0},
            step,
            grad,
            prox,
        )
        self.t = t0

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the PGM

        See `Parameter` for the full description.

        When extra positional arguments are given, the step is divided by
        the sum of the squares of the first extra argument, and all extra
        arguments are forwarded to the gradient function.
        """
        if len(args) == 0:
            step = self.step
        else:
            step = self.step / np.sum(args[0] * args[0])
        _x = self.x
        _z = self.helpers["z"]

        # Gradient step from the extrapolated point, then proximal mapping.
        y = _z - step * cast(Callable, self.grad)(input_grad, _x, *args)
        if self.prox is not None:
            x = self.prox(y)
        else:
            x = y
        # Momentum update; the new extrapolation point must be computed
        # from the previous `x` before `x` is overwritten in place.
        t = 0.5 * (1 + np.sqrt(1 + 4 * self.t**2))
        omega = 1 + (self.t - 1) / t
        self.helpers["z"] = _x + omega * (x - _x)
        _x[:] = x
        self.t = t

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FistaParameter:
        """Create a deep copy of this parameter.

        The arrays are deep-copied; the step, gradient, and proximal
        functions are shared with the original.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        return FistaParameter(
            deepcopy(self.x, memo),
            # Pass the stored step function, not the evaluated ``self.step``:
            # evaluating it would freeze a step function to its current value.
            self._step,
            self.grad,
            self.prox,
            self.t,
            deepcopy(self.helpers["z"], memo),
        )

    def __copy__(self) -> FistaParameter:
        """Create a shallow copy of this parameter.

        `x` and the extrapolation array are copied; the step, gradient,
        and proximal functions are shared with the original.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        return FistaParameter(
            self.x.copy(),
            # See `__deepcopy__`: keep the step function, not its value.
            self._step,
            self.grad,
            self.prox,
            self.t,
            self.helpers["z"].copy(),
        )

336 

337 

338# The following code block contains different update methods for 

339# various implementations of ADAM. 

340# We currently use the `_amsgrad_phi_psi` update (scheme "amsgrad") by default, 

341# but it can easily be interchanged by passing a different 

342# scheme name to the `AdaproxParameter`. 

343 

344 

345# noinspection PyUnusedLocal 

def _adam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    """Adam update terms.

    `m` and `v` are overwritten in place with the new moving averages of
    `g` and ``g**2``. `vhat` and `p` are accepted for signature
    compatibility with the other schemes and are not used.

    Returns
    -------
    phi, psi:
        The bias-corrected first moment and the square root of the
        bias-corrected second moment plus `eps`.
    """
    beta1 = b1[it]

    # Exponential moving averages, written back into the caller's arrays.
    m[:] = (1 - beta1) * g + beta1 * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # Bias correction uses a 1-based iteration count.
    step_count = it + 1
    phi = m / (1 - beta1**step_count)
    psi = np.sqrt(v / (1 - b2**step_count)) + eps
    return phi, psi

356 

357 

358# noinspection PyUnusedLocal 

def _nadam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    """NAdam update terms.

    `m` and `v` are overwritten in place with the new moving averages of
    `g` and ``g**2``. `vhat` and `p` are accepted for signature
    compatibility with the other schemes and are not used.

    Returns
    -------
    phi, psi:
        The bias-corrected blend of the updated first moment with the
        current gradient, and the square root of the bias-corrected
        second moment plus `eps`.
    """
    beta1 = b1[it]

    # Exponential moving averages, written back into the caller's arrays.
    m[:] = (1 - beta1) * g + beta1 * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # Bias correction uses a 1-based iteration count.
    step_count = it + 1
    lookahead = beta1 * m[:] + (1 - beta1) * g
    phi = lookahead / (1 - beta1**step_count)
    psi = np.sqrt(v / (1 - b2**step_count)) + eps
    return phi, psi

369 

370 

371# noinspection PyUnusedLocal 

def _amsgrad_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    """AMSGrad update terms.

    `m`, `v`, and `vhat` are overwritten in place: `m` and `v` with the
    new moving averages of `g` and ``g**2``, and `vhat` with the
    element-wise maximum of its previous value and `v`. `p` is not used.

    Returns
    -------
    phi, psi:
        `m` itself (not a copy) and the square root of `vhat`, where
        `vhat` is floored at `eps` when ``eps > 0``.
    """
    beta1 = b1[it]

    # Exponential moving averages, written back into the caller's arrays.
    m[:] = (1 - beta1) * g + beta1 * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    vhat[:] = np.maximum(vhat, v)

    # Sanitize zero-gradient elements. The floor is applied to a new
    # array, so the caller's `vhat` keeps the un-floored maximum.
    floored = vhat
    if eps > 0:
        floored = np.maximum(vhat, eps)
    return m, np.sqrt(floored)

384 

385 

def _padam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    """PAdam update terms.

    `m`, `v`, and `vhat` are overwritten in place: `m` and `v` with the
    new moving averages of `g` and ``g**2``, and `vhat` with the
    element-wise maximum of its previous value and `v`.

    Returns
    -------
    phi, psi:
        `m` itself (not a copy) and `vhat` raised to the power `p`, where
        `vhat` is floored at `eps` when ``eps > 0``.
    """
    beta1 = b1[it]

    # Exponential moving averages, written back into the caller's arrays.
    m[:] = (1 - beta1) * g + beta1 * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    vhat[:] = np.maximum(vhat, v)

    # Sanitize zero-gradient elements. The floor is applied to a new
    # array, so the caller's `vhat` keeps the un-floored maximum.
    floored = vhat
    if eps > 0:
        floored = np.maximum(vhat, eps)
    return m, floored**p

398 

399 

400# noinspection PyUnusedLocal 

def _adamx_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    """AdamX update terms.

    `m`, `v`, and `vhat` are overwritten in place: `m` and `v` with the
    new moving averages of `g` and ``g**2``, and `vhat` with the
    element-wise maximum of `v` and the previous `vhat` rescaled by
    ``(1 - b1[it])**2 / (1 - b1[it - 1])**2`` (no rescaling at
    iteration 0). `p` is not used.

    Returns
    -------
    phi, psi:
        `m` itself (not a copy) and the square root of `vhat`, where
        `vhat` is floored at `eps` when ``eps > 0``.
    """
    beta1 = b1[it]

    # Exponential moving averages, written back into the caller's arrays.
    m[:] = (1 - beta1) * g + beta1 * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # There is no previous decay rate to compare against at iteration 0.
    factor = 1.0 if it == 0 else (1 - beta1) ** 2 / (1 - b1[it - 1]) ** 2
    vhat[:] = np.maximum(factor * vhat, v)

    # Sanitize zero-gradient elements. The floor is applied to a new
    # array, so the caller's `vhat` keeps the un-floored maximum.
    floored = vhat
    if eps > 0:
        floored = np.maximum(vhat, eps)
    return m, np.sqrt(floored)

417 

418 

419# noinspection PyUnusedLocal 

def _radam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    """RAdam update terms.

    `m` and `v` are overwritten in place with the new moving averages of
    `g` and ``g**2``. `vhat` and `p` are accepted for signature
    compatibility with the other schemes and are not used.

    Returns
    -------
    phi, psi:
        The bias-corrected first moment, and either the rectified square
        root of the bias-corrected second moment (when ``rho > 4``) or an
        array of ones; `psi` is floored at ``sqrt(eps)`` when ``eps > 0``.
    """
    beta1 = b1[it]
    step_count = it + 1

    # Exponential moving averages, written back into the caller's arrays.
    m[:] = (1 - beta1) * g + beta1 * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # Bias-corrected first moment.
    phi = m / (1 - beta1**step_count)

    rho_inf = 2 / (1 - b2) - 1
    rho = rho_inf - 2 * step_count * b2**step_count / (1 - b2**step_count)

    if rho > 4:
        psi = np.sqrt(v / (1 - b2**step_count))
        r = np.sqrt((rho - 4) * (rho - 2) * rho_inf / (rho_inf - 4) / (rho_inf - 2) / rho)
        # In-place division keeps the dtype of `psi` unchanged.
        psi /= r
    else:
        psi = np.ones(g.shape, g.dtype)

    # Sanitize zero-gradient elements.
    if eps > 0:
        psi = np.maximum(psi, np.sqrt(eps))
    return phi, psi

442 

443 

# Lookup table from the name of an ADAM variant to the function that
# computes its update terms.
phi_psi = dict(
    adam=_adam_phi_psi,
    nadam=_nadam_phi_psi,
    amsgrad=_amsgrad_phi_psi,
    padam=_padam_phi_psi,
    adamx=_adamx_phi_psi,
    radam=_radam_phi_psi,
)

453 

454 

class SingleItemArray:
    """Array stand-in that yields the same value for every index.

    Parameters
    ----------
    value:
        The single value returned for any index.
    """

    def __init__(self, value):
        # The one value handed back by every lookup.
        self.value = value

    def __getitem__(self, item):
        """Return the stored value, ignoring `item` entirely."""
        stored = self.value
        return stored

469 

470 

class AdaproxParameter(Parameter):
    """Operator updated using the Proximal ADAM algorithm

    Uses multiple variants of adaptive quasi-Newton gradient descent

        * Adam (Kingma & Ba 2015)
        * NAdam (Dozat 2016)
        * AMSGrad (Reddi, Kale & Kumar 2018)
        * PAdam (Chen & Gu 2018)
        * AdamX (Phuong & Phong 2019)
        * RAdam (Liu et al. 2019)

    See details of the algorithms in the respective papers.

    Parameters
    ----------
    x:
        The array of values that is being fit.
    step:
        A numerical step value or function to calculate the step for a
        given `x`.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
    b1:
        The decay rate of the first moment (mean) of the gradient.
        Either a scalar, used for every iteration, or an object that is
        indexed by the iteration number.
    b2:
        The decay rate of the second moment (variance) of the gradient.
    eps:
        A small constant added for numerical stability.
    p:
        The power used by the ``PAdam`` scheme.
    m0:
        The initial value of the first moment.
        If `None` then an array of zeros is used.
    v0:
        The initial value of the second moment.
        If `None` then an array of zeros is used.
    vhat0:
        The initial value of the maximum second moment.
        If `None` then an array of ``-inf`` is used.
    scheme:
        The name of the ADAM variant to use to update the parameter.
        Must be a key of `phi_psi`.
    prox_e_rel:
        The relative error used by the proximal operator
        (stored as the `e_rel` attribute).
    """

    def __init__(
        self,
        x: np.ndarray,
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        b1: float | SingleItemArray = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        p: float = 0.25,
        m0: np.ndarray | None = None,
        v0: np.ndarray | None = None,
        vhat0: np.ndarray | None = None,
        scheme: str = "amsgrad",
        prox_e_rel: float = 1e-6,
    ):
        shape = x.shape
        dtype = x.dtype
        if m0 is None:
            m0 = np.zeros(shape, dtype=dtype)

        if v0 is None:
            v0 = np.zeros(shape, dtype=dtype)

        if vhat0 is None:
            vhat0 = np.ones(shape, dtype=dtype) * -np.inf

        super().__init__(
            x,
            {
                "m": m0,
                "v": v0,
                "vhat": vhat0,
            },
            step,
            grad,
            prox,
        )

        # The update functions always index ``b1[it]``, so every scalar
        # (``float``, ``int``, numpy scalar) has to be wrapped, not only
        # exact ``float`` instances.
        if np.isscalar(b1):
            _b1 = SingleItemArray(b1)
        else:
            _b1 = b1

        self.b1 = _b1
        self.b2 = b2
        self.eps = eps
        self.p = p

        self.scheme = scheme
        self.phi_psi = phi_psi[scheme]
        self.e_rel = prox_e_rel

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the PGM

        See `~Parameter` for more.
        """
        _x = self.x
        # Calculate the gradient
        grad = cast(Callable, self.grad)(input_grad, _x, *args)
        # Get the update for the parameter; this also updates the
        # ``m``, ``v`` and ``vhat`` helper arrays in place.
        phi, psi = self.phi_psi(
            it,
            grad,
            self.helpers["m"],
            self.helpers["v"],
            self.helpers["vhat"],
            self.b1,
            self.b2,
            self.eps,
            self.p,
        )
        # Calculate the step size
        step = self.step
        if it > 0:
            _x += -step * phi / psi
        else:
            # This is a scheme that Peter Melchior and I came up with to
            # dampen the known effect of ADAM, where the first iteration
            # is often much larger than desired.
            _x += -step * phi / psi / 10

        if self.prox is not None:
            self.x = self.prox(_x)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> AdaproxParameter:
        """Create a deep copy of this parameter.

        The arrays are deep-copied; the step, gradient, and proximal
        functions and the scalar settings are shared with the original.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        return AdaproxParameter(
            deepcopy(self.x, memo),
            # Pass the stored step function, not the evaluated ``self.step``:
            # evaluating it would freeze a step function to its current value.
            self._step,
            self.grad,
            self.prox,
            self.b1,
            self.b2,
            self.eps,
            self.p,
            deepcopy(self.helpers["m"], memo),
            deepcopy(self.helpers["v"], memo),
            deepcopy(self.helpers["vhat"], memo),
            scheme=self.scheme,
            prox_e_rel=self.e_rel,
        )

    def __copy__(self) -> AdaproxParameter:
        """Create a shallow copy of this parameter.

        `x` and the helper arrays are copied, because `update` modifies
        them in place and a copy that shared them would change together
        with the original. The step, gradient, and proximal functions and
        the scalar settings are shared with the original.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        return AdaproxParameter(
            self.x.copy(),
            # See `__deepcopy__`: keep the step function, not its value.
            self._step,
            self.grad,
            self.prox,
            self.b1,
            self.b2,
            self.eps,
            self.p,
            self.helpers["m"].copy(),
            self.helpers["v"].copy(),
            self.helpers["vhat"].copy(),
            scheme=self.scheme,
            prox_e_rel=self.e_rel,
        )

656 

657 

class FixedParameter(Parameter):
    """A parameter whose values are never changed by `update`.

    Parameters
    ----------
    x:
        The array of values that is held fixed.
    """

    def __init__(self, x: np.ndarray):
        # No helper arrays are needed and the step is zero.
        super().__init__(x, helpers={}, step=0)

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Do nothing; all arguments are ignored."""
        return None

    def __copy__(self) -> FixedParameter:
        """Create a shallow copy of this parameter.

        Returns
        -------
        parameter:
            A new `FixedParameter` that refers to the same array `x`.
        """
        shared_values = self.x
        return FixedParameter(shared_values)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FixedParameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A new `FixedParameter` holding a deep copy of `x`.
        """
        copied_values = deepcopy(self.x, memo)
        return FixedParameter(copied_values)

697 

698 

def relative_step(
    x: np.ndarray,
    factor: float = 0.1,
    minimum: float = 0,
    axis: int | Sequence[int] | None = None,
):
    """Step size set at `factor` times the mean of `x` in direction `axis`

    The result is never smaller than `minimum`.
    """
    scaled_mean = factor * x.mean(axis=axis)
    return np.maximum(minimum, scaled_mean)