Coverage for python / lsst / images / fields / _product.py: 36%

76 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-14 08:07 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14__all__ = ("ProductField", "ProductFieldSerializationModel") 

15 

16from collections.abc import Iterable 

17from typing import TYPE_CHECKING, Any, Literal, final 

18 

19import astropy.units 

20import numpy as np 

21import pydantic 

22 

23from .._geom import Bounds, Box 

24from .._image import Image 

25from ..serialization import ArchiveTree, InputArchive, OutputArchive 

26from ._base import BaseField 

27 

28if TYPE_CHECKING: 

29 try: 

30 from lsst.afw.math import ProductBoundedField as LegacyProductBoundedField 

31 except ImportError: 

32 type LegacyProductBoundedField = Any # type: ignore[no-redef] 

33 

34 from ._concrete import Field, FieldSerializationModel 

35 

36 

@final
class ProductField(BaseField):
    """A field that multiplies other fields lazily.

    Parameters
    ----------
    operands : `~collections.abc.Iterable` [ `BaseField` ]
        The fields to multiply together.  At least one must be provided.

    Raises
    ------
    ValueError
        Raised if ``operands`` is empty.

    Notes
    -----
    The bounds of the product are the intersection of the bounds of all
    operands.  The unit of the product is the product of all operand units
    that are not `None`, or `None` if no operand has a unit.
    """

    def __init__(self, operands: Iterable[Field]):
        self._operands = tuple(operands)
        if not self._operands:
            raise ValueError("At least one operand must be provided.")
        first, *rest = self._operands
        bounds = first.bounds
        unit = first.unit
        for operand in rest:
            bounds = bounds.intersection(operand.bounds)
            if operand.unit is not None:
                # Operands without a unit leave the combined unit unchanged.
                unit = operand.unit if unit is None else unit * operand.unit
        self._bounds = bounds
        self._unit = unit

    @property
    def bounds(self) -> Bounds:
        return self._bounds

    @property
    def unit(self) -> astropy.units.UnitBase | None:
        return self._unit

    @property
    def operands(self) -> tuple[Field, ...]:
        """The fields that are multiplied together
        (`tuple` [`BaseField`, ...]).
        """
        return self._operands

    def evaluate(
        self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False
    ) -> np.ndarray | astropy.units.Quantity:
        """Evaluate the product of all operands at the given points.

        Parameters
        ----------
        x, y : `numpy.ndarray`
            Coordinates at which to evaluate each operand.
        quantity : `bool`, optional
            If `True`, wrap the result in an `astropy.units.Quantity` with
            this field's `unit` (dimensionless when `unit` is `None`).

        Returns
        -------
        values : `numpy.ndarray` or `astropy.units.Quantity`
            Product of the unitless values of every operand.
        """
        first, *rest = self._operands
        result = first(x=x, y=y, quantity=False)
        for operand in rest:
            # Multiply out of place: the array returned by an operand may be
            # shared with the operand, read-only (e.g. a broadcast view), of
            # a narrower dtype, or of a smaller broadcastable shape, any of
            # which makes an in-place ``*=`` unsafe.
            result = result * operand(x=x, y=y, quantity=False)
        if quantity:
            # ``result * None`` would raise TypeError; Quantity interprets a
            # `None` unit as dimensionless.
            return astropy.units.Quantity(result, self.unit)
        return result

    def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image:
        """Render the product of all operands into a new image.

        Parameters
        ----------
        bbox : `Box`, optional
            Pixel region to render.  Defaults to ``self.bounds.bbox``.
        dtype : `numpy.typing.DTypeLike`, optional
            Pixel type, forwarded to the new `Image` and to each operand's
            ``render``.

        Returns
        -------
        image : `Image`
            Image with this field's `unit` holding the pixelwise product of
            every operand rendered over ``bbox``.
        """
        if bbox is None:
            bbox = self.bounds.bbox
        result = Image(1.0, bbox=bbox, dtype=dtype, unit=self.unit)
        for operand in self._operands:
            # In-place multiplication is safe here because ``result`` was
            # freshly allocated above and is owned by this call.
            result.array *= operand.render(bbox, dtype=dtype).array
        return result

    def multiply_constant(
        self, factor: float | astropy.units.Quantity | astropy.units.UnitBase
    ) -> ProductField:
        """Return a new product field scaled by ``factor``.

        The factor is applied to the last operand only; all other operands
        are reused unchanged.
        """
        *leading, last = self._operands
        return ProductField([*leading, last * factor])

    def serialize(self, archive: OutputArchive[Any]) -> ProductFieldSerializationModel:
        """Serialize the field to an output archive."""
        return ProductFieldSerializationModel(
            operands=[operand.serialize(archive) for operand in self._operands]
        )

    @staticmethod
    def _get_archive_tree_type(
        pointer_type: type[Any],
    ) -> type[ProductFieldSerializationModel]:
        """Return the serialization model type for this object for an archive
        type that uses the given pointer type.
        """
        return ProductFieldSerializationModel

    @staticmethod
    def from_legacy(
        legacy: LegacyProductBoundedField, unit: astropy.units.UnitBase | None = None
    ) -> ProductField:
        """Convert from a legacy `lsst.afw.math.ProductBoundedField`.

        Parameters
        ----------
        legacy : `lsst.afw.math.ProductBoundedField`
            Legacy field to convert.
        unit : `astropy.units.UnitBase`, optional
            Unit passed to the conversion of the last factor only; the other
            factors are converted without an explicit unit.

        Returns
        -------
        field : `ProductField`
            Product of the converted factors, in the legacy order.

        Raises
        ------
        ValueError
            Raised if the legacy field has no factors.
        """
        from ._concrete import field_from_legacy

        legacy_factors = list(legacy.getFactors())
        if not legacy_factors:
            raise ValueError("Legacy ProductBoundedField has no factors.")
        *leading, last = legacy_factors
        operands = [field_from_legacy(f) for f in leading]
        operands.append(field_from_legacy(last, unit=unit))
        return ProductField(operands)

    def to_legacy(self) -> LegacyProductBoundedField:
        """Convert to a legacy `lsst.afw.math.ProductBoundedField`.

        Units are not passed to the legacy object; only each operand's own
        ``to_legacy`` result is used.
        """
        from lsst.afw.math import ProductBoundedField

        # Not all Field types have a to_legacy, since they don't all have an
        # afw analog. But we just let that "no method" exception propagate.
        return ProductBoundedField(
            [operand.to_legacy() for operand in self._operands]  # type: ignore[union-attr]
        )

141 

142 

class ProductFieldSerializationModel(ArchiveTree):
    """Archive representation of a `ProductField`.

    Each entry of ``operands`` is the serialized form of one factor, kept in
    the same order in which `ProductField.serialize` writes them.
    """

    # Serialized factors, in multiplication order.
    operands: list[FieldSerializationModel] = pydantic.Field(default_factory=list)

    # Constant tag identifying this model as a product field.
    field_type: Literal["PRODUCT"] = "PRODUCT"

    def deserialize(self, archive: InputArchive) -> ProductField:
        """Rebuild the `ProductField` by deserializing each stored operand
        from ``archive``.
        """
        factors = tuple(model.deserialize(archive) for model in self.operands)
        return ProductField(factors)