Coverage for tests / test_mask.py: 12%
136 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-23 08:25 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-23 08:25 +0000
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14import os
15import unittest
17import numpy as np
19import lsst.utils.tests
20from lsst.images import Box, Mask, MaskPlane, MaskSchema, get_legacy_visit_image_mask_planes
21from lsst.images.tests import RoundtripFits, assert_masks_equal, compare_mask_to_legacy
# Location of the optional external test-data checkout, or None when the
# environment variable is not set (tests that need it are skipped then).
DATA_DIR = os.environ.get("TESTDATA_IMAGES_DIR")
class MaskTestCase(unittest.TestCase):
    """Tests for Mask and its helper classes."""

    def setUp(self) -> None:
        # Show complete diffs for the long string comparisons below.
        self.maxDiff = None
        # Fixed seed: test_basics checks the exact plane order produced by
        # make_mask_planes, so the shuffle must be reproducible.
        self.rng = np.random.default_rng(500)

    def make_mask_planes(self, n_planes: int, n_placeholders: int) -> list[MaskPlane | None]:
        """Return ``n_planes`` named planes mixed with ``n_placeholders``
        `None` entries, in an order shuffled by ``self.rng``.

        Plane ``i`` is named ``M{i}`` and described as ``D{i}``.
        """
        planes: list[MaskPlane | None] = [MaskPlane(f"M{i}", f"D{i}") for i in range(n_planes)]
        planes.extend([None] * n_placeholders)
        self.rng.shuffle(planes)
        return planes

    def test_schema(self) -> None:
        """Test MaskSchema."""
        n_planes = 17
        self.assertEqual(MaskSchema.bits_per_element(np.uint8), 8)
        planes = self.make_mask_planes(n_planes, 5)
        with self.assertRaises(TypeError):
            # A floating-point dtype is rejected.
            MaskSchema.bits_per_element(np.float32)
        schema = MaskSchema(planes, dtype=np.uint8)
        # Iteration, length and indexing reflect the input list, placeholders
        # included.
        self.assertEqual(list(schema), planes)
        self.assertEqual(len(schema), len(planes))
        self.assertEqual(schema[5], planes[5])
        # repr() must round-trip through eval().
        self.assertEqual(
            eval(repr(schema), {"dtype": np.dtype, "MaskSchema": MaskSchema, "MaskPlane": MaskPlane}), schema
        )
        # str() has one line per real (non-placeholder) plane.
        string = str(schema)
        self.assertEqual(len(string.split("\n")), n_planes)
        bit5 = schema.bit("M5")
        self.assertIn(f"M5 [{bit5.index}@{hex(bit5.mask)}]: D5", string)
        # Equality depends on both the plane list and the dtype.
        self.assertEqual(schema, MaskSchema(planes, np.uint8))
        self.assertNotEqual(schema, MaskSchema(planes, np.int16))
        self.assertNotEqual(schema, MaskSchema(planes[:-1], np.uint8))
        self.assertEqual(schema.dtype, np.dtype(np.uint8))
        self.assertEqual(schema.mask_size, 3)
        self.assertEqual(schema.names, {f"M{i}" for i in range(n_planes)})
        self.assertEqual(schema.descriptions, {f"M{i}": f"D{i}" for i in range(n_planes)})
        # bitmask() sets the bits of exactly the requested planes: clearing
        # those two bits must leave nothing set.
        bit7 = schema.bit("M7")
        bitmask57 = schema.bitmask("M5", "M7")
        self.assertTrue(bitmask57[bit5.index] & bit5.mask)
        self.assertTrue(bitmask57[bit7.index] & bit7.mask)
        bitmask57[bit5.index] &= ~bit5.mask
        bitmask57[bit7.index] &= ~bit7.mask
        self.assertFalse(bitmask57.any())
        # Splitting for a narrower dtype drops placeholders and keeps order.
        splits = schema.split(np.int16)
        self.assertEqual(len(splits), 2)
        self.assertEqual(splits[0].mask_size, 1)
        self.assertEqual(splits[1].mask_size, 1)
        self.assertEqual(list(splits[0]) + list(splits[1]), [p for p in planes if p is not None])
        self.assertEqual(len(splits[0]), 15)
        self.assertEqual(len(splits[1]), 2)

    def test_basics(self) -> None:
        """Test some basic Mask functionality."""
        planes = self.make_mask_planes(35, n_placeholders=5)
        schema = MaskSchema(planes, dtype=np.uint8)
        bbox = Box.factory[5:50, 6:60]
        mask = Mask(
            0,
            schema=schema,
            bbox=bbox,
            metadata={"four_and_a_half": 4.5},
        )

        self.assertIs(mask[...], mask)
        # Comparing with an unrelated type must defer to the other operand.
        self.assertIs(mask.__eq__(42), NotImplemented)
        self.assertEqual(mask, mask)
        self.assertEqual(
            str(mask),
            "Mask([y=5:50, x=6:60], ['M34', 'M15', 'M29', 'M1', 'M20', 'M11', 'M13', 'M7', 'M17', 'M12', "
            "'M31', 'M16', 'M2', 'M3', 'M8', 'M26', 'M22', 'M5', 'M18', 'M19', 'M24', 'M21', 'M27', 'M6', "
            "'M28', 'M10', 'M4', 'M23', 'M0', 'M25', 'M9', 'M14', 'M33', 'M32', 'M30'])",
        )
        self.assertTrue(
            repr(mask).startswith(
                "Mask(..., bbox=Box(y=Interval(start=5, stop=50), x=Interval(start=6, stop=60)), "
                "schema=MaskSchema([MaskPlane(name='M34', description='D34')"
            ),
            f"Repr: {mask!r}",
        )

        with self.assertRaises(TypeError):
            # No bbox, size or array.
            Mask(0, schema=schema)

        with self.assertRaises(ValueError):
            # Box mismatch.
            Mask(mask.array, schema=schema, bbox=Box.factory[0:20, -5:45])

        with self.assertRaises(ValueError):
            # Shape mismatch.
            Mask(mask.array, schema=schema, shape=(5, 10, 5))

        with self.assertRaises(ValueError):
            # Cannot be 2-D.  Collapse all but the mask-element axis instead
            # of hard-coding the product of the bbox dimensions.
            # NOTE(review): this bbox also disagrees with the array, so the
            # error may come from the box check rather than the rank check.
            Mask(mask.array.reshape((-1, schema.mask_size)), schema=schema, bbox=Box.factory[0:20, -5:45])

    def test_read_write(self) -> None:
        """Explicit calls to read and write fits."""
        planes = self.make_mask_planes(35, n_placeholders=5)
        schema = MaskSchema(planes, dtype=np.uint8)
        bbox = Box.factory[5:50, 6:60]
        mask = Mask(
            0,
            schema=schema,
            bbox=bbox,
            metadata={"four_and_a_half": 4.5},
        )
        with lsst.utils.tests.getTempFilePath(".fits") as tmp_file:
            mask.write_fits(tmp_file)

            # Read back while the temporary file is still in scope.
            new = Mask.read_fits(tmp_file)
            self.assertEqual(new, mask)
            # __eq__ ignores metadata, so check it separately.
            self.assertEqual(new.metadata["four_and_a_half"], 4.5)
            self.assertEqual(new.metadata, mask.metadata)

    def test_serialize_multi(self) -> None:
        """Test serializing a mask with more than 31 mask planes, requiring
        more than one HDU and EXTVER.

        Note that serialization for simpler cases is covered by
        test_masked_image.py.
        """
        # Maximum number of planes the header checks below expect per HDU.
        planes_per_hdu = 31
        planes = self.make_mask_planes(35, n_placeholders=5)
        schema = MaskSchema(planes, dtype=np.uint8)
        bbox = Box.factory[5:50, 6:60]
        mask = Mask(0, schema=schema, bbox=bbox, metadata={"four_and_a_half": 4.5})
        # Give every real plane a random pattern so the round trip is
        # non-trivial.
        for plane in schema:
            if plane is not None:
                mask.set(plane.name, self.rng.random(bbox.shape) > 0.5)
        with RoundtripFits(self, mask) as roundtrip:
            fits = roundtrip.inspect()
            # Two compressed MASK HDUs distinguished by EXTVER (the first may
            # omit EXTVER).
            self.assertEqual(fits[1].header["EXTNAME"], "MASK")
            self.assertEqual(fits[1].header.get("EXTVER", 1), 1)
            self.assertEqual(fits[1].header["ZCMPTYPE"], "GZIP_2")
            self.assertEqual(fits[2].header["EXTNAME"], "MASK")
            self.assertEqual(fits[2].header["EXTVER"], 2)
            self.assertEqual(fits[2].header["ZCMPTYPE"], "GZIP_2")
            # Placeholders are skipped: the n-th real plane is bit
            # n % planes_per_hdu of MASK HDU number n // planes_per_hdu.
            real_planes = [plane for plane in planes if plane is not None]
            for n, plane in enumerate(real_planes):
                hdu_offset, bit = divmod(n, planes_per_hdu)
                header = fits[1 + hdu_offset].header
                self.assertEqual(header[f"MSKN{bit:04d}"], plane.name)
                self.assertEqual(header[f"MSKM{bit:04d}"], 1 << bit)
                self.assertEqual(header[f"MSKD{bit:04d}"], plane.description)
            assert_masks_equal(self, mask, roundtrip.result)

    @unittest.skipUnless(DATA_DIR is not None, "TESTDATA_IMAGES_DIR is not in the environment.")
    def test_legacy(self) -> None:
        """Test Mask.read_legacy, Mask.to_legacy, and Mask.from_legacy."""
        assert DATA_DIR is not None, "Guaranteed by decorator."
        filename = os.path.join(DATA_DIR, "dp2", "legacy", "visit_image.fits")
        plane_map = get_legacy_visit_image_mask_planes()
        mask = Mask.read_legacy(filename, ext=2, plane_map=plane_map)
        try:
            from lsst.afw.image import Mask as LegacyMask
            from lsst.afw.image import MaskedImageFitsReader
        except ImportError:
            raise unittest.SkipTest("'lsst.afw.image' could not be imported.") from None
        reader = MaskedImageFitsReader(filename)
        self.assertEqual(mask.bbox, Box.from_legacy(reader.readBBox()))
        legacy_mask = reader.readMask()
        compare_mask_to_legacy(self, mask, legacy_mask, plane_map)
        compare_mask_to_legacy(self, mask, mask.to_legacy(plane_map), plane_map)
        assert_masks_equal(self, mask, Mask.from_legacy(legacy_mask, plane_map=plane_map))
        # Write the mask out in the new format, and test that we can read it
        # back either way.
        with RoundtripFits(self, mask, storage_class="MaskV2") as roundtrip:
            with self.subTest():
                # afw was imported above (or the test was skipped), so no
                # second import guard is needed here.
                butler_legacy_mask = roundtrip.get(storageClass="Mask")
                self.assertIsInstance(butler_legacy_mask, LegacyMask)
                compare_mask_to_legacy(self, mask, butler_legacy_mask)
            assert_masks_equal(self, roundtrip.result, mask)
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()