Coverage for python/lsst/images/fields/_product.py: 35%

86 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-29 08:40 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14__all__ = ("ProductField", "ProductFieldSerializationModel") 

15 

16from collections.abc import Iterable 

17from typing import TYPE_CHECKING, Any, Literal, final 

18 

19import astropy.units 

20import numpy as np 

21import pydantic 

22 

23from .._geom import Bounds, Box 

24from .._image import Image 

25from ..serialization import ArchiveTree, InputArchive, InvalidParameterError, OutputArchive 

26from ._base import BaseField 

27 

28if TYPE_CHECKING: 

29 try: 

30 from lsst.afw.math import ProductBoundedField as LegacyProductBoundedField 

31 except ImportError: 

32 type LegacyProductBoundedField = Any # type: ignore[no-redef] 

33 

34 from ._concrete import Field, FieldSerializationModel 

35 

36 

@final
class ProductField(BaseField):
    """A field that multiplies other fields lazily.

    Parameters
    ----------
    operands : `~collections.abc.Iterable` [ `BaseField` ]
        The fields to multiply together.  At least one must be provided.

    Raises
    ------
    ValueError
        Raised if ``operands`` is empty.

    Notes
    -----
    The bounds of the product are the intersection of the bounds of all
    operands.  The unit of the product is the product of all operand units
    that are not `None`; it is `None` only if every operand is unitless.
    """

    def __init__(self, operands: Iterable[Field]):
        self._operands = tuple(operands)
        if not self._operands:
            raise ValueError("At least one operand must be provided.")
        first, *rest = self._operands
        bounds = first.bounds
        unit = first.unit
        for operand in rest:
            bounds = bounds.intersection(operand.bounds)
            # Unitless operands do not contribute to the unit of the product.
            if operand.unit is not None:
                unit = operand.unit if unit is None else unit * operand.unit
        self._bounds = bounds
        self._unit = unit

    def __eq__(self, other: object) -> bool:
        if type(other) is not ProductField:
            return NotImplemented
        # ``_bounds`` and ``_unit`` are derived from the operands, so
        # comparing the operand tuple is sufficient.
        return self._operands == other._operands

    # Defining __eq__ without a consistent hash: mark instances unhashable.
    __hash__ = None  # type: ignore[assignment]

    @property
    def bounds(self) -> Bounds:
        """Intersection of the bounds of all operands (`Bounds`)."""
        return self._bounds

    @property
    def unit(self) -> astropy.units.UnitBase | None:
        """Product of the non-`None` operand units, or `None` if all
        operands are unitless (`astropy.units.UnitBase` | `None`).
        """
        return self._unit

    @property
    def operands(self) -> tuple[Field, ...]:
        """The fields that are multiplied together
        (`tuple` [`BaseField`, ...]).
        """
        return self._operands

    @property
    def is_constant(self) -> bool:
        """Whether every operand is constant (`bool`)."""
        return all(operand.is_constant for operand in self._operands)

    def evaluate(
        self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False
    ) -> np.ndarray | astropy.units.Quantity:
        """Evaluate the product of all operands at the given points.

        Parameters
        ----------
        x, y
            Coordinates at which to evaluate each operand.
        quantity
            If `True`, return an `astropy.units.Quantity` carrying
            `unit` (dimensionless if `unit` is `None`).

        Returns
        -------
        `numpy.ndarray` | `astropy.units.Quantity`
            Element-wise product of the operand values.
        """
        first, *rest = self._operands
        result = first(x=x, y=y, quantity=False)
        for operand in rest:
            # Multiply out of place: the first operand's return value may be
            # an array it still owns, a read-only/broadcast view, a smaller
            # shape, or a lower-precision dtype, none of which tolerate
            # in-place multiplication.
            result = result * operand(x=x, y=y, quantity=False)
        if quantity:
            # Multiplying by ``None`` would raise TypeError; a product of
            # unitless operands is reported as dimensionless instead.
            unit = self.unit if self.unit is not None else astropy.units.dimensionless_unscaled
            return result * unit
        return result

    def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image:
        """Render the product of all operands onto an image.

        Parameters
        ----------
        bbox
            Region to render.  Defaults to the bounding box of `bounds`.
        dtype
            Pixel type passed to `Image` and to each operand's ``render``.

        Returns
        -------
        `Image`
            Image with unit `unit` whose pixels are the product of the
            rendered operands.
        """
        if bbox is None:
            bbox = self.bounds.bbox
        # Start from ones so that each operand can be multiplied in; this
        # array is freshly allocated here, so in-place updates are safe.
        result = Image(1.0, bbox=bbox, dtype=dtype, unit=self.unit)
        for operand in self._operands:
            result.array *= operand.render(bbox, dtype=dtype).array
        return result

    def multiply_constant(
        self, factor: float | astropy.units.Quantity | astropy.units.UnitBase
    ) -> ProductField:
        """Return a new product with ``factor`` folded into the last operand.

        The other operands are reused unchanged, so the result is still a
        `ProductField` with the same number of operands.
        """
        new_operands = list(self._operands[:-1])
        new_operands.append(self._operands[-1] * factor)
        return ProductField(new_operands)

    def serialize(self, archive: OutputArchive[Any]) -> ProductFieldSerializationModel:
        """Serialize the field to an output archive.

        Each operand is serialized in order into
        `ProductFieldSerializationModel.operands`.
        """
        return ProductFieldSerializationModel(
            operands=[operand.serialize(archive) for operand in self._operands]
        )

    @staticmethod
    def _get_archive_tree_type(
        pointer_type: type[Any],
    ) -> type[ProductFieldSerializationModel]:
        """Return the serialization model type for this object for an archive
        type that uses the given pointer type.

        The model type is the same regardless of ``pointer_type``.
        """
        return ProductFieldSerializationModel

    @staticmethod
    def from_legacy(
        legacy: LegacyProductBoundedField,
        unit: astropy.units.UnitBase | None = None,
        bounds: Bounds | None = None,
    ) -> ProductField:
        """Convert from a legacy `lsst.afw.math.ProductBoundedField`.

        Parameters
        ----------
        legacy
            Legacy field to convert.
        unit
            The units of the returned field (`lsst.afw.math.BoundedField`
            objects do not know their units).  Attached to the last factor
            only, so the product carries exactly this unit.
        bounds
            The bounds of the returned field, if they should be different from
            the bounding box of ``legacy``.  Applied to every factor.
        """
        from ._concrete import field_from_legacy

        legacy_factors = legacy.getFactors()
        operands = [field_from_legacy(f, bounds=bounds) for f in legacy_factors[:-1]]
        operands.append(field_from_legacy(legacy_factors[-1], unit=unit, bounds=bounds))
        return ProductField(operands)

    def to_legacy(self) -> LegacyProductBoundedField:
        """Convert to a legacy `lsst.afw.math.ProductBoundedField`."""
        from lsst.afw.math import ProductBoundedField

        # Not all Field types have a to_legacy, since they don't all have an
        # afw analog. But we just let that "no method" exception propagate.
        return ProductBoundedField(
            [operand.to_legacy() for operand in self._operands]  # type: ignore[union-attr]
        )

169 

class ProductFieldSerializationModel(ArchiveTree):
    """Pydantic model used to save and load a `ProductField`.

    Each entry of ``operands`` is the serialized form of one factor, stored
    in the same order as `ProductField.operands`.
    """

    # Serialized factor fields, in multiplication order.
    operands: list[FieldSerializationModel] = pydantic.Field(default_factory=list)

    # Fixed tag naming the kind of field this model describes.
    field_type: Literal["PRODUCT"] = "PRODUCT"

    def deserialize(self, archive: InputArchive, **kwargs: Any) -> ProductField:
        """Rebuild the `ProductField` described by this model.

        Parameters
        ----------
        archive
            Archive passed to each operand model's ``deserialize``.
        **kwargs
            Not supported; any keyword argument is rejected.

        Returns
        -------
        `ProductField`
            Product of the deserialized operands.

        Raises
        ------
        InvalidParameterError
            Raised if any keyword arguments are given.
        """
        if not kwargs:
            factors = [model.deserialize(archive) for model in self.operands]
            return ProductField(factors)
        raise InvalidParameterError(f"Unrecognized parameters for ProductField: {set(kwargs.keys())}.")