Coverage for python / lsst / daf / butler / arrow_utils.py: 66%
232 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-20 01:07 -0700
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-20 01:07 -0700
1# This file is part of butler4.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = (
31 "DateTimeArrowScalar",
32 "DateTimeArrowType",
33 "RegionArrowScalar",
34 "RegionArrowType",
35 "TimespanArrowScalar",
36 "TimespanArrowType",
37 "ToArrow",
38 "UUIDArrowScalar",
39 "UUIDArrowType",
40)
42import uuid
43from abc import ABC, abstractmethod
44from collections.abc import Callable
45from typing import Any, ClassVar, final
47import astropy.time
48import pyarrow as pa
50from lsst.sphgeom import Region
52from ._timespan import Timespan
53from .time_utils import TimeConverter
class ToArrow(ABC):
    """Abstract converter from Python objects to a column of an Arrow field.

    Concrete converters are obtained from the ``for_*`` static factory
    methods, or by wrapping an existing converter with `dictionary_encoded`.
    Values are accumulated one at a time with `append` and turned into an
    Arrow array with `finish`.
    """

    @staticmethod
    def for_primitive(name: str, data_type: pa.DataType, nullable: bool) -> ToArrow:
        """Build a converter for a type that Arrow already supports natively.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        data_type : `pyarrow.DataType`
            Arrow data type object.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowPrimitive(name=name, data_type=data_type, nullable=nullable)

    @staticmethod
    def for_uuid(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `uuid.UUID` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowUUID(name=name, nullable=nullable)

    @staticmethod
    def for_region(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `lsst.sphgeom.Region` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowRegion(name=name, nullable=nullable)

    @staticmethod
    def for_timespan(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `lsst.daf.butler.Timespan` values.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowTimespan(name=name, nullable=nullable)

    @staticmethod
    def for_datetime(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `astropy.time.Time` values, which are stored
        as TAI nanoseconds since 1970-01-01.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`
            Whether the field should permit null or `None` values.

        Returns
        -------
        to_arrow : `ToArrow`
            Converter instance.
        """
        return _ToArrowDateTime(name=name, nullable=nullable)

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the field."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def nullable(self) -> bool:
        """Whether the field permits null or `None` values."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def data_type(self) -> pa.DataType:
        """Arrow data type for the field this converter produces."""
        raise NotImplementedError()

    @property
    def field(self) -> pa.Field:
        """Arrow field this converter produces, assembled from `name`,
        `data_type` and `nullable`.
        """
        return pa.field(self.name, self.data_type, nullable=self.nullable)

    def dictionary_encoded(self) -> ToArrow:
        """Wrap this converter in one that dictionary-encodes its output.

        Returns
        -------
        to_arrow : `ToArrow`
            New converter with the same name and value type, using dictionary
            encoding (to 32-bit integers) to compress duplicate values.
        """
        return _ToArrowDictionary(to_arrow_value=self)

    @abstractmethod
    def append(self, value: Any, column: list[Any]) -> None:
        """Append the Arrow representation of an object to a `list`.

        Parameters
        ----------
        value : `object`
            Original value to be converted to a row in an Arrow column.
        column : `list`
            List of values to append to.  What exactly gets appended is up to
            the implementation; the only requirement is that `finish` can
            later consume this `list`.
        """
        raise NotImplementedError()

    @abstractmethod
    def finish(self, column: list[Any]) -> pa.Array:
        """Turn a list of values built up via `append` into an Arrow array.

        Parameters
        ----------
        column : `list`
            List of column values populated by `append`.

        Returns
        -------
        array : `pyarrow.Array`
            Arrow array holding the converted values.
        """
        raise NotImplementedError()
class _ToArrowPrimitive(ToArrow):
    """`ToArrow` implementation for primitive types supported directly by
    Arrow.

    Instances should be obtained from the `ToArrow.for_primitive` factory
    method rather than constructed directly.
    """

    def __init__(self, name: str, data_type: pa.DataType, nullable: bool):
        self._name, self._data_type, self._nullable = name, data_type, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return self._data_type

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        # Values are stored untouched; the conversion happens in `finish`.
        column.append(value)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        return pa.array(column, type=self._data_type)
class _ToArrowDictionary(ToArrow):
    """`ToArrow` implementation for Arrow dictionary fields.

    Wraps another converter and delegates everything to it except the data
    type and the final encoding step.  Instances should be obtained from the
    `ToArrow.dictionary_encoded` factory method.
    """

    def __init__(self, to_arrow_value: ToArrow):
        self._to_arrow_value = to_arrow_value

    @property
    def name(self) -> str:
        # Docstring inherited.
        wrapped = self._to_arrow_value
        return wrapped.name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        wrapped = self._to_arrow_value
        return wrapped.nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        # The index type is hard-coded to int32 because that is what the
        # pa.Array.dictionary_encode() method (used in `finish`) produces.
        index_type = pa.int32()
        value_type = self._to_arrow_value.data_type
        return pa.dictionary(index_type, value_type)

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        self._to_arrow_value.append(value, column)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        plain_array = self._to_arrow_value.finish(column)
        return plain_array.dictionary_encode()
class _ToArrowUUID(ToArrow):
    """`ToArrow` implementation for `uuid.UUID` fields.

    Instances should be obtained from the `ToArrow.for_uuid` factory method.
    """

    # Arrow type of the underlying storage: fixed-width 16-byte binary.
    storage_type: ClassVar[pa.DataType] = pa.binary(16)

    def __init__(self, name: str, nullable: bool):
        self._name, self._nullable = name, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return UUIDArrowType()

    def append(self, value: uuid.UUID | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
        else:
            column.append(value.bytes)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        # Build the plain storage array first, then wrap it in the
        # extension type.
        raw = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(UUIDArrowType(), raw)
class _ToArrowRegion(ToArrow):
    """`ToArrow` implementation for `lsst.sphgeom.Region` fields.

    Instances should be obtained from the `ToArrow.for_region` factory
    method.
    """

    # Arrow type of the underlying storage: variable-length binary.
    storage_type: ClassVar[pa.DataType] = pa.binary()

    def __init__(self, name: str, nullable: bool):
        self._name, self._nullable = name, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return RegionArrowType()

    def append(self, value: Region | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        # Regions are stored in the form produced by their ``encode`` method.
        if value is None:
            column.append(None)
        else:
            column.append(value.encode())

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        # Build the plain storage array first, then wrap it in the
        # extension type.
        raw = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(RegionArrowType(), raw)
class _ToArrowTimespan(ToArrow):
    """`ToArrow` implementation for `lsst.daf.butler.Timespan` fields.

    Instances should be obtained from the `ToArrow.for_timespan` factory
    method.
    """

    # Arrow type of the underlying storage: a struct holding the two
    # non-nullable 64-bit integer bounds of the timespan.
    storage_type: ClassVar[pa.DataType] = pa.struct(
        [
            pa.field("begin_nsec", pa.int64(), nullable=False),
            pa.field("end_nsec", pa.int64(), nullable=False),
        ]
    )

    def __init__(self, name: str, nullable: bool):
        self._name, self._nullable = name, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return TimespanArrowType()

    def append(self, value: Timespan | None, column: list[dict[str, int] | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
            return
        # Keys must match the field names declared in `storage_type`.
        column.append({"begin_nsec": value.nsec[0], "end_nsec": value.nsec[1]})

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        # Build the plain storage array first, then wrap it in the
        # extension type.
        raw = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(TimespanArrowType(), raw)
class _ToArrowDateTime(ToArrow):
    """`ToArrow` implementation for `astropy.time.Time` fields.

    Instances should be obtained from the `ToArrow.for_datetime` factory
    method.
    """

    # Arrow type of the underlying storage: a 64-bit integer per value.
    storage_type: ClassVar[pa.DataType] = pa.int64()

    def __init__(self, name: str, nullable: bool):
        self._name, self._nullable = name, nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return DateTimeArrowType()

    def append(self, value: astropy.time.Time | None, column: list[int | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
        else:
            column.append(TimeConverter().astropy_to_nsec(value))

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        # Build the plain storage array first, then wrap it in the
        # extension type.
        raw = pa.array(column, self.storage_type)
        return pa.ExtensionArray.from_storage(DateTimeArrowType(), raw)
@final
class UUIDArrowType(pa.ExtensionType):
    """Arrow extension type representing `uuid.UUID`, stored as 16 bytes."""

    def __init__(self) -> None:
        # Storage layout is shared with the converter that builds the arrays.
        storage = _ToArrowUUID.storage_type
        super().__init__(storage, "uuid.UUID")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type takes no parameters, so there is no state to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> UUIDArrowType:
        # Both arguments are ignored; every instance is constructed the
        # same way.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[UUIDArrowScalar]:
        # Scalars of this type convert to `uuid.UUID` via `UUIDArrowScalar`.
        return UUIDArrowScalar
@final
class UUIDArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `uuid.UUID`.

    Use the standard `as_py` method to convert to an actual `uuid.UUID`
    instance.
    """

    def as_py(self, **_unused: Any) -> uuid.UUID | None:
        """Convert this scalar to a Python object.

        Parameters
        ----------
        **_unused
            Accepted for signature compatibility and ignored.

        Returns
        -------
        value : `uuid.UUID` or `None`
            The UUID built from the stored bytes, or `None` if this scalar
            has no storage value (i.e. it is null).
        """
        storage = self.value
        # UUID fields are nullable by default, so a null entry must map to
        # `None` instead of failing on ``None.as_py()``; this mirrors the
        # handling in `TimespanArrowScalar.as_py`.
        if storage is None:
            return None
        return uuid.UUID(bytes=storage.as_py())
@final
class RegionArrowType(pa.ExtensionType):
    """Arrow extension type representing `lsst.sphgeom.Region`."""

    def __init__(self) -> None:
        # Storage layout is shared with the converter that builds the arrays.
        storage = _ToArrowRegion.storage_type
        super().__init__(storage, "lsst.sphgeom.Region")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type takes no parameters, so there is no state to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> RegionArrowType:
        # Both arguments are ignored; every instance is constructed the
        # same way.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[RegionArrowScalar]:
        # Scalars of this type convert to regions via `RegionArrowScalar`.
        return RegionArrowScalar
@final
class RegionArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.sphgeom.Region`.

    Use the standard `as_py` method to convert to an actual region.
    """

    def as_py(self, **_unused: Any) -> Region | None:
        """Convert this scalar to a Python object.

        Parameters
        ----------
        **_unused
            Accepted for signature compatibility and ignored.

        Returns
        -------
        value : `lsst.sphgeom.Region` or `None`
            The region decoded from the stored bytes, or `None` if this
            scalar has no storage value (i.e. it is null).
        """
        storage = self.value
        # Region fields are nullable by default, so a null entry must map to
        # `None` instead of failing on ``None.as_py()``; this mirrors the
        # handling in `TimespanArrowScalar.as_py`.
        if storage is None:
            return None
        return Region.decode(storage.as_py())
@final
class TimespanArrowType(pa.ExtensionType):
    """Arrow extension type representing `lsst.daf.butler.Timespan`."""

    def __init__(self) -> None:
        # Storage layout is shared with the converter that builds the arrays.
        storage = _ToArrowTimespan.storage_type
        super().__init__(storage, "lsst.daf.butler.Timespan")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type takes no parameters, so there is no state to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> TimespanArrowType:
        # Both arguments are ignored; every instance is constructed the
        # same way.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[TimespanArrowScalar]:
        # Scalars of this type convert to timespans via
        # `TimespanArrowScalar`.
        return TimespanArrowScalar
@final
class TimespanArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.daf.butler.Timespan`.

    Call the standard `as_py` method to obtain an actual timespan; a scalar
    without a storage value converts to `None`.
    """

    def as_py(self, **_unused: Any) -> Timespan | None:
        storage = self.value
        if storage is None:
            return None
        # Field names match the struct declared in
        # `_ToArrowTimespan.storage_type`.
        begin_nsec = storage["begin_nsec"].as_py()
        end_nsec = storage["end_nsec"].as_py()
        return Timespan(None, None, _nsec=(begin_nsec, end_nsec))
@final
class DateTimeArrowType(pa.ExtensionType):
    """An Arrow extension type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01.
    """

    def __init__(self) -> None:
        # The storage type must be the 64-bit integer type of
        # `_ToArrowDateTime`, not the timespan struct: `_ToArrowDateTime.finish`
        # wraps an int64 storage array in this extension type, and
        # `DateTimeArrowScalar.as_py` passes the storage value straight to
        # `TimeConverter.nsec_to_astropy`.
        super().__init__(_ToArrowDateTime.storage_type, "astropy.time.Time")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type takes no parameters, so there is no state to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> DateTimeArrowType:
        # Both arguments are ignored; every instance is constructed the
        # same way.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[DateTimeArrowScalar]:
        # Scalars of this type convert to `astropy.time.Time` via
        # `DateTimeArrowScalar`.
        return DateTimeArrowScalar
@final
class DateTimeArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01.

    Use the standard `as_py` method to convert to an actual `astropy.time.Time`
    instance.
    """

    def as_py(self, **_unused: Any) -> astropy.time.Time | None:
        """Convert this scalar to a Python object.

        Parameters
        ----------
        **_unused
            Accepted for signature compatibility and ignored.

        Returns
        -------
        value : `astropy.time.Time` or `None`
            The time obtained by passing the stored integer to
            `TimeConverter.nsec_to_astropy`, or `None` if this scalar has no
            storage value (i.e. it is null).
        """
        storage = self.value
        # Datetime fields are nullable by default, so a null entry must map
        # to `None` instead of failing on ``None.as_py()``; this mirrors the
        # handling in `TimespanArrowScalar.as_py`.
        if storage is None:
            return None
        return TimeConverter().nsec_to_astropy(storage.as_py())
class ArrowTableUtils:
    """Helpers that produce modified copies of `pyarrow.Table` instances."""

    @staticmethod
    def replace_column(table: pa.Table, column_name: str, new_column: pa.Array) -> pa.Table:
        """Build a new `pyarrow.Table` in which one named column has been
        swapped for a replacement.

        Parameters
        ----------
        table
            Original arrow table.
        column_name
            Name of the column to be replaced.
        new_column
            Replacement arrow column.

        Returns
        -------
        table
            Copy of the given table with the column replaced.

        Raises
        ------
        ValueError
            Raised if the schema lookup for ``column_name`` yields a negative
            index, which the error message attributes to a missing or
            duplicated column name.
        """
        position = table.schema.get_field_index(column_name)
        if position >= 0:
            return table.set_column(position, column_name, new_column)
        raise ValueError(
            f"Column {column_name} not found in arrow table, or multiple columns have the same name."
        )

    @staticmethod
    def modify_column(
        table: pa.Table, column_name: str, function: Callable[[pa.Array], pa.Array]
    ) -> pa.Table:
        """Build a new `pyarrow.Table` in which one named column has been
        replaced by the result of applying a function to it.

        Parameters
        ----------
        table
            Original arrow table.
        column_name
            Name of the column to be replaced.
        function
            Callable invoked with the object returned by
            ``table.column(column_name)``; its return value becomes the new
            column.

        Returns
        -------
        table
            Copy of the given table with the column replaced with the value
            returned from the callback function.
        """
        transformed = function(table.column(column_name))
        return ArrowTableUtils.replace_column(table, column_name, transformed)
# Register every extension type defined in this module with pyarrow, so that
# each type's ``__arrow_ext_deserialize__`` hook is available when data
# carrying these extension names is read back.  `UUIDArrowType` is included
# alongside the others; without registration its deserialization hook would
# never be used.
pa.register_extension_type(RegionArrowType())
pa.register_extension_type(TimespanArrowType())
pa.register_extension_type(DateTimeArrowType())
pa.register_extension_type(UUIDArrowType())