Coverage for python/lsst/images/fields/_sum.py: 34%

83 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-29 08:43 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14__all__ = ("SumField", "SumFieldSerializationModel") 

15 

16from collections.abc import Iterable 

17from typing import TYPE_CHECKING, Any, Literal, final 

18 

19import astropy.units 

20import numpy as np 

21import pydantic 

22 

23from .._geom import Bounds, Box 

24from .._image import Image 

25from ..serialization import ArchiveTree, InputArchive, InvalidParameterError, OutputArchive 

26from ._base import BaseField 

27 

28if TYPE_CHECKING: 

29 try: 

30 from lsst.afw.math import BackgroundList as LegacyBackgroundList 

31 except ImportError: 

32 type LegacyBackgroundList = Any # type: ignore[no-redef] 

33 

34 from ._concrete import Field, FieldSerializationModel 

35 

36 

@final
class SumField(BaseField):
    """A field that sums other fields lazily.

    Parameters
    ----------
    operands : `~collections.abc.Iterable` [ `BaseField` ]
        The fields to sum together.  At least one field must be given.
        Either every operand has units convertible to those of the first
        operand, or no operand has units.

    Raises
    ------
    ValueError
        Raised if ``operands`` is empty.
    astropy.units.UnitConversionError
        Raised if some operands have units and others do not, or if an
        operand's units cannot be converted to the first operand's units.

    Notes
    -----
    The bounds of the sum are the intersection of all operand bounds, and
    its unit is the unit of the first operand.
    """

    def __init__(self, operands: Iterable[Field]):
        self._operands = tuple(operands)
        if not self._operands:
            raise ValueError("At least one operand must be provided.")
        first, *rest = self._operands
        self._bounds = first.bounds
        self._unit = first.unit
        for operand in rest:
            # The sum is only defined where every operand is defined.
            self._bounds = self._bounds.intersection(operand.bounds)
            if operand.unit is None:
                if self._unit is not None:
                    raise astropy.units.UnitConversionError(
                        "Cannot add a field with no units to a field with units."
                    )
            elif self._unit is None:
                raise astropy.units.UnitConversionError(
                    "Cannot add a field with units to a field with no units."
                )
            else:
                # Raise if these units are not sum-compatible; the
                # conversion factor itself is not needed.
                self._unit.to(operand.unit)

    def __eq__(self, other: object) -> bool:
        if type(other) is not SumField:
            return NotImplemented
        # ``_bounds`` and ``_unit`` are derived from the operands, so
        # comparing the operand tuple is sufficient.
        return self._operands == other._operands

    # Value-based ``__eq__`` without a matching hash: mark unhashable
    # explicitly.
    __hash__ = None  # type: ignore[assignment]

    @property
    def bounds(self) -> Bounds:
        return self._bounds

    @property
    def unit(self) -> astropy.units.UnitBase | None:
        return self._unit

    @property
    def operands(self) -> tuple[Field, ...]:
        """The fields that are summed together (`tuple` [`BaseField`, ...])."""
        return self._operands

    @property
    def is_constant(self) -> bool:
        # A sum is constant only if every term is.
        return all(operand.is_constant for operand in self._operands)

    def evaluate(
        self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False
    ) -> np.ndarray | astropy.units.Quantity:
        """Evaluate the sum at the given coordinates.

        Parameters
        ----------
        x, y
            Coordinates at which to evaluate; passed unchanged to every
            operand.
        quantity
            If `True`, return an `astropy.units.Quantity` (dimensionless
            when this field has no units).  If `False`, return a plain
            array, expressed in ``self.unit`` when the field has units.

        Returns
        -------
        `numpy.ndarray` or `astropy.units.Quantity`
            The sum of all operands evaluated at ``(x, y)``.
        """
        # We have to add quantities if this is a unit-aware field, as the
        # terms in the sum might have different-but-compatible units.
        as_quantity = self.unit is not None
        first, *rest = self._operands
        result = first(x=x, y=y, quantity=as_quantity)
        for operand in rest:
            # Accumulate out of place.  An in-place ``+=`` would write into
            # the array returned by the first operand (which may alias data
            # owned by that operand, or be a read-only/broadcast view).  It
            # would also fail when the terms' dtypes or shapes cannot be
            # combined in place.
            result = result + operand(x=x, y=y, quantity=as_quantity)
        if self.unit is not None and not quantity:
            # Caller doesn't want a Quantity back.
            assert isinstance(result, astropy.units.Quantity)
            return result.to_value(self.unit)
        if self.unit is None and quantity:
            # Caller wants a Quantity back even though there's no units.
            return astropy.units.Quantity(result)
        return result

    def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image:
        """Render the sum onto an image.

        Parameters
        ----------
        bbox
            Region to render; defaults to ``self.bounds.bbox``.
        dtype
            Pixel type, passed to the `Image` constructor and to each
            operand's ``render``.

        Returns
        -------
        `Image`
            Image with units ``self.unit`` holding the sum of every
            operand's rendering over ``bbox``.
        """
        if bbox is None:
            bbox = self.bounds.bbox
        result = Image(0.0, bbox=bbox, dtype=dtype, unit=self.unit)
        for operand in self._operands:
            # Adding quantities lets operands with compatible-but-different
            # units be converted into the result's units.
            result.quantity += operand.render(bbox, dtype=dtype).quantity
        return result

    def multiply_constant(self, factor: float | astropy.units.Quantity | astropy.units.UnitBase) -> SumField:
        """Return a new sum in which every operand is multiplied by
        ``factor`` (multiplication distributes over the sum).
        """
        return SumField([operand * factor for operand in self._operands])

    def serialize(self, archive: OutputArchive[Any]) -> SumFieldSerializationModel:
        """Serialize the field to an output archive.

        Each operand is serialized into ``archive`` in order, and the
        resulting models are collected into a `SumFieldSerializationModel`.
        """
        return SumFieldSerializationModel(operands=[operand.serialize(archive) for operand in self._operands])

    @staticmethod
    def _get_archive_tree_type(
        pointer_type: type[Any],
    ) -> type[SumFieldSerializationModel]:
        """Return the serialization model type for this object for an archive
        type that uses the given pointer type.

        The model type does not depend on ``pointer_type``.
        """
        return SumFieldSerializationModel

    @staticmethod
    def from_legacy_background(
        legacy_background: LegacyBackgroundList,
        bounds: Bounds | None = None,
        unit: astropy.units.UnitBase | None = None,
    ) -> SumField:
        """Convert from a legacy `lsst.afw.math.BackgroundList` instance.

        Parameters
        ----------
        legacy_background
            Legacy background object to convert.  Only the first element of
            each entry is converted; any remaining elements of the entry are
            ignored.
        bounds
            The bounds of the returned field, if they should be different from
            the bounding box of ``legacy_background``.
        unit
            The units of the returned field (`lsst.afw.math.BackgroundList`
            objects do not know their units).

        Returns
        -------
        `SumField`
            A sum with one operand per entry of ``legacy_background``.

        Raises
        ------
        ValueError
            Raised if ``legacy_background`` has no entries.
        """
        from ._concrete import field_from_legacy_background

        return SumField(
            [field_from_legacy_background(b, bounds=bounds, unit=unit) for b, *_ in legacy_background]
        )

163 

164 

class SumFieldSerializationModel(ArchiveTree):
    """Serialization model for `SumField`.

    Holds the serialized form of every summed operand, in summation order.
    """

    # Serialized operands, in the order they are summed.
    operands: list[FieldSerializationModel] = pydantic.Field(default_factory=list)

    # Literal tag identifying this model as a sum field (presumably used to
    # discriminate between field serialization models; confirm in _concrete).
    field_type: Literal["SUM"] = "SUM"

    def deserialize(self, archive: InputArchive, **kwargs: Any) -> SumField:
        """Deserialize the field from an input archive.

        Parameters
        ----------
        archive
            Archive each operand is deserialized from.
        **kwargs
            Not supported; passing any keyword argument is an error.

        Returns
        -------
        `SumField`
            A sum of the deserialized operands, in stored order.

        Raises
        ------
        InvalidParameterError
            Raised if any keyword arguments are given.
        """
        unexpected = set(kwargs.keys())
        if unexpected:
            raise InvalidParameterError(f"Unrecognized parameters for SumField: {unexpected}.")
        restored = tuple(model.deserialize(archive) for model in self.operands)
        return SumField(restored)