Coverage for python / lsst / images / fields / _product.py: 35%

81 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-23 08:25 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14__all__ = ("ProductField", "ProductFieldSerializationModel") 

15 

16from collections.abc import Iterable 

17from typing import TYPE_CHECKING, Any, Literal, final 

18 

19import astropy.units 

20import numpy as np 

21import pydantic 

22 

23from .._geom import Bounds, Box 

24from .._image import Image 

25from ..serialization import ArchiveTree, InputArchive, InvalidParameterError, OutputArchive 

26from ._base import BaseField 

27 

28if TYPE_CHECKING: 

29 try: 

30 from lsst.afw.math import ProductBoundedField as LegacyProductBoundedField 

31 except ImportError: 

32 type LegacyProductBoundedField = Any # type: ignore[no-redef] 

33 

34 from ._concrete import Field, FieldSerializationModel 

35 

36 

@final
class ProductField(BaseField):
    """A field that multiplies other fields lazily.

    The bounds of the product are the intersection of the operands' bounds,
    and its unit is the product of every operand unit that is not `None`
    (the unit is `None` only if no operand has a unit).

    Parameters
    ----------
    operands : `~collections.abc.Iterable` [ `BaseField` ]
        The fields to multiply together.  The iterable is consumed once and
        stored as a `tuple`.

    Raises
    ------
    ValueError
        Raised if ``operands`` is empty.
    """

    def __init__(self, operands: Iterable[Field]):
        self._operands = tuple(operands)
        if not self._operands:
            raise ValueError("At least one operand must be provided.")
        iterator = iter(self._operands)
        first = next(iterator)
        self._bounds = first.bounds
        self._unit = first.unit
        for operand in iterator:
            self._bounds = self._bounds.intersection(operand.bounds)
            # Operands without a unit do not contribute to the combined unit.
            if operand.unit is not None:
                if self._unit is None:
                    self._unit = operand.unit
                else:
                    self._unit *= operand.unit

    @property
    def bounds(self) -> Bounds:
        # Intersection of all operand bounds, computed once in __init__.
        return self._bounds

    @property
    def unit(self) -> astropy.units.UnitBase | None:
        # Product of the non-None operand units, computed once in __init__.
        return self._unit

    @property
    def operands(self) -> tuple[Field, ...]:
        """The fields that are multiplied together
        (`tuple` [`BaseField`, ...]).
        """
        return self._operands

    @property
    def is_constant(self) -> bool:
        # The product is constant only if every factor is constant.
        return all(operand.is_constant for operand in self._operands)

    def evaluate(
        self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False
    ) -> np.ndarray | astropy.units.Quantity:
        # Each operand is evaluated without units; the combined unit is
        # attached once at the end if a Quantity was requested.
        iterator = iter(self._operands)
        result = next(iterator)(x=x, y=y, quantity=False)
        for operand in iterator:
            # Multiply out of place: the first operand's result is not owned
            # here (it may alias caller data such as ``x``, or have a shape or
            # dtype that cannot hold the broadcast product), so ``*=`` could
            # mutate arrays we do not own or raise a broadcasting/casting
            # error.
            result = result * operand(x=x, y=y, quantity=False)
        if quantity:
            return result * self.unit
        return result

    def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image:
        if bbox is None:
            bbox = self.bounds.bbox
        # The accumulator starts as an image of ones carrying the combined
        # unit.  It is freshly allocated here, so multiplying each operand's
        # rendering into it in place is safe.
        result = Image(1.0, bbox=bbox, dtype=dtype, unit=self.unit)
        for operand in self._operands:
            result.array *= operand.render(bbox, dtype=dtype).array
        return result

    def multiply_constant(
        self, factor: float | astropy.units.Quantity | astropy.units.UnitBase
    ) -> ProductField:
        # Fold the factor into the last operand only.  The other operands are
        # reused unchanged, and the result is a new ProductField.
        new_operands = list(self._operands[:-1])
        new_operands.append(self._operands[-1] * factor)
        return ProductField(new_operands)

    def serialize(self, archive: OutputArchive[Any]) -> ProductFieldSerializationModel:
        """Serialize the field to an output archive."""
        return ProductFieldSerializationModel(
            operands=[operand.serialize(archive) for operand in self._operands]
        )

    @staticmethod
    def _get_archive_tree_type(
        pointer_type: type[Any],
    ) -> type[ProductFieldSerializationModel]:
        """Return the serialization model type for this object for an archive
        type that uses the given pointer type.
        """
        return ProductFieldSerializationModel

    @staticmethod
    def from_legacy(
        legacy: LegacyProductBoundedField,
        unit: astropy.units.UnitBase | None = None,
        bounds: Bounds | None = None,
    ) -> ProductField:
        """Convert from a legacy `lsst.afw.math.ProductBoundedField`.

        Parameters
        ----------
        legacy
            Legacy field to convert.
        unit
            The units of the returned field (`lsst.afw.math.BoundedField`
            objects do not know their units).
        bounds
            The bounds of the returned field, if they should be different from
            the bounding box of ``legacy``.
        """
        from ._concrete import field_from_legacy

        legacy_factors = legacy.getFactors()
        operands = [field_from_legacy(f, bounds=bounds) for f in legacy_factors[:-1]]
        # Only the last factor receives the unit, so the product carries
        # ``unit`` exactly once.
        operands.append(field_from_legacy(legacy_factors[-1], unit=unit, bounds=bounds))
        return ProductField(operands)

    def to_legacy(self) -> LegacyProductBoundedField:
        """Convert to a legacy `lsst.afw.math.ProductBoundedField`."""
        from lsst.afw.math import ProductBoundedField

        # Not all Field types have a to_legacy, since they don't all have an
        # afw analog. But we just let that "no method" exception propagate.
        return ProductBoundedField(
            [operand.to_legacy() for operand in self._operands]  # type: ignore[union-attr]
        )

159 

160 

class ProductFieldSerializationModel(ArchiveTree):
    """Serialization model for `ProductField`.

    Holds the serialized form of each operand of the product, in order.
    """

    # Serialized operand models, kept in the same order as the operands of
    # the `ProductField` they came from.
    operands: list[FieldSerializationModel] = pydantic.Field(default_factory=list)

    # Fixed tag naming this model's field type (presumably used to tell field
    # models apart when several types can appear in one slot; confirm).
    field_type: Literal["PRODUCT"] = "PRODUCT"

    def deserialize(self, archive: InputArchive, **kwargs: Any) -> ProductField:
        """Deserialize the field from an input archive.

        No extra keyword parameters are accepted; any that are passed cause
        `InvalidParameterError` to be raised.
        """
        if kwargs:
            unrecognized = set(kwargs)
            raise InvalidParameterError(f"Unrecognized parameters for ProductField: {unrecognized}.")
        restored = [model.deserialize(archive) for model in self.operands]
        return ProductField(restored)