Coverage for python / lsst / daf / butler / arrow_utils.py: 66%

232 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-14 07:38 +0000

1# This file is part of butler4. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

# Public names re-exported by ``from .arrow_utils import *``; kept in
# alphabetical order.
__all__ = (
    "DateTimeArrowScalar",
    "DateTimeArrowType",
    "RegionArrowScalar",
    "RegionArrowType",
    "TimespanArrowScalar",
    "TimespanArrowType",
    "ToArrow",
    "UUIDArrowScalar",
    "UUIDArrowType",
)

41 

42import uuid 

43from abc import ABC, abstractmethod 

44from collections.abc import Callable 

45from typing import Any, ClassVar, final 

46 

47import astropy.time 

48import pyarrow as pa 

49 

50from lsst.sphgeom import Region 

51 

52from ._timespan import Timespan 

53from .time_utils import TimeConverter 

54 

55 

class ToArrow(ABC):
    """Abstract interface for objects that convert Python values into the
    rows of a single Arrow field.

    Concrete converters are obtained from the ``for_*`` static factory
    methods (or from `dictionary_encoded`).  Values are accumulated one row
    at a time with `append` and turned into a `pyarrow.Array` by `finish`.
    """

    @staticmethod
    def for_primitive(name: str, data_type: pa.DataType, nullable: bool) -> ToArrow:
        """Build a converter for a type that Arrow supports natively.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        data_type : `pyarrow.DataType`
            Arrow type of the field; values are handed to Arrow unchanged.
        nullable : `bool`
            Whether null / `None` values are allowed in the field.

        Returns
        -------
        to_arrow : `ToArrow`
            The new converter.
        """
        return _ToArrowPrimitive(name, data_type, nullable)

    @staticmethod
    def for_uuid(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `uuid.UUID` values, stored as their 16 raw
        bytes.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`, optional
            Whether null / `None` values are allowed in the field.

        Returns
        -------
        to_arrow : `ToArrow`
            The new converter.
        """
        return _ToArrowUUID(name, nullable)

    @staticmethod
    def for_region(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `lsst.sphgeom.Region` values, stored in the
        binary form produced by ``Region.encode``.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`, optional
            Whether null / `None` values are allowed in the field.

        Returns
        -------
        to_arrow : `ToArrow`
            The new converter.
        """
        return _ToArrowRegion(name, nullable)

    @staticmethod
    def for_timespan(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `lsst.daf.butler.Timespan` values, stored as
        a struct of begin/end nanosecond integers.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`, optional
            Whether null / `None` values are allowed in the field.

        Returns
        -------
        to_arrow : `ToArrow`
            The new converter.
        """
        return _ToArrowTimespan(name, nullable)

    @staticmethod
    def for_datetime(name: str, nullable: bool = True) -> ToArrow:
        """Build a converter for `astropy.time.Time` values, stored as TAI
        nanoseconds since 1970-01-01.

        Parameters
        ----------
        name : `str`
            Name of the schema field.
        nullable : `bool`, optional
            Whether null / `None` values are allowed in the field.

        Returns
        -------
        to_arrow : `ToArrow`
            The new converter.
        """
        return _ToArrowDateTime(name, nullable)

    @property
    @abstractmethod
    def name(self) -> str:
        """Schema name of the produced field."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def nullable(self) -> bool:
        """`True` if the produced field accepts null / `None` values."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def data_type(self) -> pa.DataType:
        """Arrow type of the produced field."""
        raise NotImplementedError()

    @property
    def field(self) -> pa.Field:
        """The `pyarrow.Field` (name, type and nullability) this converter
        produces.
        """
        return pa.field(self.name, self.data_type, nullable=self.nullable)

    def dictionary_encoded(self) -> ToArrow:
        """Wrap this converter so that its output is dictionary-encoded with
        32-bit integer indices, compressing repeated values.

        The wrapper keeps this converter's name and nullability.
        """
        return _ToArrowDictionary(self)

    @abstractmethod
    def append(self, value: Any, column: list[Any]) -> None:
        """Add the Arrow-ready form of one value to a column buffer.

        Parameters
        ----------
        value : `object`
            Python value for the new row.
        column : `list`
            Buffer being filled.  What gets appended is up to each
            implementation; it only has to be something `finish` accepts.
        """
        raise NotImplementedError()

    @abstractmethod
    def finish(self, column: list[Any]) -> pa.Array:
        """Turn a buffer filled by `append` into an Arrow array.

        Parameters
        ----------
        column : `list`
            Buffer previously filled by calls to `append`.

        Returns
        -------
        array : `pyarrow.Array`
            Array whose type matches `data_type`.
        """
        raise NotImplementedError()

209 

210 

class _ToArrowPrimitive(ToArrow):
    """`ToArrow` implementation for types Arrow supports directly; values
    are passed through to `pyarrow.array` untouched.

    Create instances with `ToArrow.for_primitive`.
    """

    def __init__(self, name: str, data_type: pa.DataType, nullable: bool):
        self._name = name
        self._data_type = data_type
        self._nullable = nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return self._data_type

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        # No conversion needed: Arrow understands the value as-is.
        column.append(value)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        return pa.array(column, type=self._data_type)

244 

245 

class _ToArrowDictionary(ToArrow):
    """`ToArrow` decorator that dictionary-encodes the output of another
    converter.

    Create instances with `ToArrow.dictionary_encoded`.
    """

    def __init__(self, to_arrow_value: ToArrow):
        # The wrapped converter does all per-value work.
        self._to_arrow_value = to_arrow_value

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._to_arrow_value.name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._to_arrow_value.nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        # The index type is fixed to int32 to match what
        # `pa.Array.dictionary_encode()` (used in `finish`) produces.
        value_type = self._to_arrow_value.data_type
        return pa.dictionary(pa.int32(), value_type)

    def append(self, value: Any, column: list[Any]) -> None:
        # Docstring inherited.
        self._to_arrow_value.append(value, column)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        plain = self._to_arrow_value.finish(column)
        return plain.dictionary_encode()

279 

280 

class _ToArrowUUID(ToArrow):
    """`ToArrow` implementation that stores `uuid.UUID` values as 16-byte
    binary strings wrapped in `UUIDArrowType`.

    Create instances with `ToArrow.for_uuid`.
    """

    # Fixed-width storage: ``uuid.UUID.bytes`` is always 16 bytes long.
    storage_type: ClassVar[pa.DataType] = pa.binary(16)

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return UUIDArrowType()

    def append(self, value: uuid.UUID | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
        else:
            column.append(value.bytes)

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage = pa.array(column, type=self.storage_type)
        return pa.ExtensionArray.from_storage(UUIDArrowType(), storage)

316 

317 

class _ToArrowRegion(ToArrow):
    """`ToArrow` implementation that stores `lsst.sphgeom.Region` values as
    variable-length binary (the output of ``Region.encode``) wrapped in
    `RegionArrowType`.

    Create instances with `ToArrow.for_region`.
    """

    # Variable-length binary: encoded regions differ in size.
    storage_type: ClassVar[pa.DataType] = pa.binary()

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return RegionArrowType()

    def append(self, value: Region | None, column: list[bytes | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
        else:
            column.append(value.encode())

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage = pa.array(column, type=self.storage_type)
        return pa.ExtensionArray.from_storage(RegionArrowType(), storage)

353 

354 

class _ToArrowTimespan(ToArrow):
    """`ToArrow` implementation that stores `lsst.daf.butler.Timespan`
    values as a ``{begin_nsec, end_nsec}`` struct of int64 values wrapped
    in `TimespanArrowType`.

    Create instances with `ToArrow.for_timespan`.
    """

    # Both struct members are non-nullable; a null timespan is represented
    # by a null struct entry instead.
    storage_type: ClassVar[pa.DataType] = pa.struct(
        [
            pa.field("begin_nsec", pa.int64(), nullable=False),
            pa.field("end_nsec", pa.int64(), nullable=False),
        ]
    )

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return TimespanArrowType()

    def append(self, value: Timespan | None, column: list[dict[str, int] | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
            return
        bounds = value.nsec
        column.append({"begin_nsec": bounds[0], "end_nsec": bounds[1]})

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage = pa.array(column, type=self.storage_type)
        return pa.ExtensionArray.from_storage(TimespanArrowType(), storage)

395 

396 

class _ToArrowDateTime(ToArrow):
    """`ToArrow` implementation that stores `astropy.time.Time` values as
    int64 TAI nanoseconds since 1970-01-01, wrapped in `DateTimeArrowType`.

    Create instances with `ToArrow.for_datetime`.
    """

    # Integer nanoseconds, as produced by ``TimeConverter.astropy_to_nsec``.
    storage_type: ClassVar[pa.DataType] = pa.int64()

    def __init__(self, name: str, nullable: bool):
        self._name = name
        self._nullable = nullable

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    @property
    def nullable(self) -> bool:
        # Docstring inherited.
        return self._nullable

    @property
    def data_type(self) -> pa.DataType:
        # Docstring inherited.
        return DateTimeArrowType()

    def append(self, value: astropy.time.Time | None, column: list[int | None]) -> None:
        # Docstring inherited.
        if value is None:
            column.append(None)
        else:
            column.append(TimeConverter().astropy_to_nsec(value))

    def finish(self, column: list[Any]) -> pa.Array:
        # Docstring inherited.
        storage = pa.array(column, type=self.storage_type)
        return pa.ExtensionArray.from_storage(DateTimeArrowType(), storage)

432 

433 

@final
class UUIDArrowType(pa.ExtensionType):
    """Arrow extension type for `uuid.UUID`, backed by 16-byte fixed-size
    binary storage.
    """

    def __init__(self) -> None:
        # Share the converter's storage type so the two cannot disagree.
        super().__init__(_ToArrowUUID.storage_type, "uuid.UUID")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type carries no parameters, so there is nothing to serialize.
        return bytes()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> UUIDArrowType:
        # Parameter-free type: every payload maps to a fresh instance.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[UUIDArrowScalar]:
        # Elements of arrays of this type are exposed as `UUIDArrowScalar`.
        return UUIDArrowScalar

450 

451 

@final
class UUIDArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `uuid.UUID`.

    Use the standard `as_py` method to convert to an actual `uuid.UUID`
    instance (or `None` for a null entry).
    """

    def as_py(self, **_unused: Any) -> uuid.UUID | None:
        """Return the `uuid.UUID` held by this scalar.

        Returns
        -------
        value : `uuid.UUID` or `None`
            The decoded UUID, or `None` when the scalar is null.
        """
        storage = self.value
        # ``value`` is None for null entries (UUID fields are nullable by
        # default); previously this raised AttributeError on ``None.as_py``.
        if storage is None:
            return None
        return uuid.UUID(bytes=storage.as_py())

462 

463 

@final
class RegionArrowType(pa.ExtensionType):
    """Arrow extension type for `lsst.sphgeom.Region`, backed by
    variable-length binary storage holding ``Region.encode`` output.
    """

    def __init__(self) -> None:
        # Share the converter's storage type so the two cannot disagree.
        super().__init__(_ToArrowRegion.storage_type, "lsst.sphgeom.Region")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type carries no parameters, so there is nothing to serialize.
        return bytes()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> RegionArrowType:
        # Parameter-free type: every payload maps to a fresh instance.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[RegionArrowScalar]:
        # Elements of arrays of this type are exposed as `RegionArrowScalar`.
        return RegionArrowScalar

480 

481 

@final
class RegionArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `lsst.sphgeom.Region`.

    Use the standard `as_py` method to convert to an actual region (or
    `None` for a null entry).
    """

    def as_py(self, **_unused: Any) -> Region | None:
        """Return the `lsst.sphgeom.Region` held by this scalar.

        Returns
        -------
        value : `lsst.sphgeom.Region` or `None`
            The decoded region, or `None` when the scalar is null.
        """
        storage = self.value
        # ``value`` is None for null entries (region fields are nullable by
        # default); previously this raised AttributeError on ``None.as_py``.
        if storage is None:
            return None
        return Region.decode(storage.as_py())

491 

492 

@final
class TimespanArrowType(pa.ExtensionType):
    """Arrow extension type for `lsst.daf.butler.Timespan`, backed by a
    struct of ``begin_nsec``/``end_nsec`` int64 values.
    """

    def __init__(self) -> None:
        # Share the converter's storage type so the two cannot disagree.
        super().__init__(_ToArrowTimespan.storage_type, "lsst.daf.butler.Timespan")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type carries no parameters, so there is nothing to serialize.
        return bytes()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> TimespanArrowType:
        # Parameter-free type: every payload maps to a fresh instance.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[TimespanArrowScalar]:
        # Elements of arrays of this type are exposed as `TimespanArrowScalar`.
        return TimespanArrowScalar

509 

510 

@final
class TimespanArrowScalar(pa.ExtensionScalar):
    """Arrow scalar wrapping one `lsst.daf.butler.Timespan` entry.

    Call the standard `as_py` method to obtain the `Timespan`, or `None`
    when the entry is null.
    """

    def as_py(self, **_unused: Any) -> Timespan | None:
        storage = self.value
        if storage is None:
            return None
        # Rebuild the timespan directly from its stored nanosecond bounds.
        bounds = (storage["begin_nsec"].as_py(), storage["end_nsec"].as_py())
        return Timespan(None, None, _nsec=bounds)

525 

526 

@final
class DateTimeArrowType(pa.ExtensionType):
    """An Arrow extension type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01 in an ``int64`` storage column.
    """

    def __init__(self) -> None:
        # The storage type must be the one `_ToArrowDateTime.finish` builds
        # (``pa.int64()``). The timespan struct type was used here before,
        # so ``pa.ExtensionArray.from_storage`` rejected the converter's
        # arrays and ``_ToArrowDateTime.field`` reported the wrong storage.
        super().__init__(_ToArrowDateTime.storage_type, "astropy.time.Time")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type carries no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> DateTimeArrowType:
        # Parameter-free type: every payload maps to a fresh instance.
        return cls()

    def __arrow_ext_scalar_class__(self) -> type[DateTimeArrowScalar]:
        # Elements of arrays of this type are exposed as `DateTimeArrowScalar`.
        return DateTimeArrowScalar

545 

546 

@final
class DateTimeArrowScalar(pa.ExtensionScalar):
    """An Arrow scalar type for `astropy.time.Time`, stored as TAI
    nanoseconds since 1970-01-01.

    Use the standard `as_py` method to convert to an actual `astropy.time.Time`
    instance (or `None` for a null entry).
    """

    def as_py(self, **_unused: Any) -> astropy.time.Time | None:
        """Return the `astropy.time.Time` held by this scalar.

        Returns
        -------
        value : `astropy.time.Time` or `None`
            The decoded time, or `None` when the scalar is null.
        """
        storage = self.value
        # ``value`` is None for null entries (datetime fields are nullable by
        # default); previously this raised AttributeError on ``None.as_py``.
        if storage is None:
            return None
        return TimeConverter().nsec_to_astropy(storage.as_py())

558 

559 

class ArrowTableUtils:
    """Helpers for building modified copies of `pyarrow.Table` objects."""

    @staticmethod
    def replace_column(table: pa.Table, column_name: str, new_column: pa.Array) -> pa.Table:
        """Swap one named column of a table for a new array.

        Parameters
        ----------
        table
            Table to start from; it is not modified.
        column_name
            Name of the column to swap out.
        new_column
            Array to put in its place, under the same name.

        Returns
        -------
        table
            New table with the column replaced.

        Raises
        ------
        ValueError
            Raised if no column, or more than one column, is named
            ``column_name``.
        """
        position = table.schema.get_field_index(column_name)
        if position >= 0:
            return table.set_column(position, column_name, new_column)
        raise ValueError(
            f"Column {column_name} not found in arrow table, or multiple columns have the same name."
        )

    @staticmethod
    def modify_column(
        table: pa.Table, column_name: str, function: Callable[[pa.Array], pa.Array]
    ) -> pa.Table:
        """Transform one named column of a table with a callback.

        Parameters
        ----------
        table
            Table to start from; it is not modified.
        column_name
            Name of the column to transform.
        function
            Callable that receives the existing column and returns its
            replacement.

        Returns
        -------
        table
            New table in which the column holds the callback's result.
        """
        original = table.column(column_name)
        return ArrowTableUtils.replace_column(table, column_name, function(original))

615 

616 

# Register every extension type defined in this module with pyarrow's global
# registry. pyarrow consults this registry to rebuild extension types from
# serialized data (e.g. IPC streams or Parquet) instead of falling back to the
# bare storage type. UUIDArrowType was previously missing from this list,
# unlike the other three types.
pa.register_extension_type(RegionArrowType())
pa.register_extension_type(TimespanArrowType())
pa.register_extension_type(DateTimeArrowType())
pa.register_extension_type(UUIDArrowType())