Coverage for python / lsst / dax / apdb / cassandra / apdbCassandra.py: 9%

811 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 08:28 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbCassandra"] 

25 

26import datetime 

27import json 

28import logging 

29import random 

30import uuid 

31import warnings 

32from collections import Counter, defaultdict 

33from collections.abc import Iterable, Mapping, Set 

34from typing import TYPE_CHECKING, Any, cast 

35 

36import numpy as np 

37import pandas 

38 

39# If cassandra-driver is not there the module can still be imported 

40# but ApdbCassandra cannot be instantiated. 

41try: 

42 import cassandra 

43 import cassandra.query 

44 from cassandra.query import UNSET_VALUE 

45 

46 CASSANDRA_IMPORTED = True 

47except ImportError: 

48 CASSANDRA_IMPORTED = False 

49 

50import astropy.time 

51import felis.datamodel 

52 

53from lsst import sphgeom 

54from lsst.utils.iteration import chunk_iterable 

55 

56from ..apdb import Apdb, ApdbConfig 

57from ..apdbConfigFreezer import ApdbConfigFreezer 

58from ..apdbReplica import ApdbTableData, ReplicaChunk 

59from ..apdbSchema import ApdbSchema, ApdbTables 

60from ..apdbUpdateRecord import ( 

61 ApdbCloseDiaObjectValidityRecord, 

62 ApdbReassignDiaSourceToDiaObjectRecord, 

63 ApdbUpdateNDiaSourcesRecord, 

64) 

65from ..monitor import MonAgent 

66from ..recordIds import DiaObjectId, DiaSourceId 

67from ..schema_model import Table 

68from ..timer import Timer 

69from ..versionTuple import VersionTuple 

70from .apdbCassandraAdmin import ApdbCassandraAdmin 

71from .apdbCassandraReplica import ApdbCassandraReplica 

72from .apdbCassandraSchema import ApdbCassandraSchema, CreateTableOptions, ExtraTables 

73from .apdbMetadataCassandra import ApdbMetadataCassandra 

74from .cassandra_utils import ( 

75 ApdbCassandraTableData, 

76 execute_concurrent, 

77 literal, 

78 select_concurrent, 

79) 

80from .config import ApdbCassandraConfig, ApdbCassandraConnectionConfig, ApdbCassandraTimePartitionRange 

81from .connectionContext import ConnectionContext, DbVersions 

82from .exceptions import CassandraMissingError 

83from .partitioner import Partitioner 

84from .queries import Column as C # noqa: N817 

85from .queries import ColumnExpr, Delete, Insert, QExpr, Select, Update 

86from .sessionFactory import SessionContext, SessionFactory 

87 

88if TYPE_CHECKING: 

89 from ..apdbMetadata import ApdbMetadata 

90 from ..apdbUpdateRecord import ApdbUpdateRecord 

91 

# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)

# Monitoring agent for this module; it is handed to every `Timer` created
# by `ApdbCassandra._timer`.
_MON = MonAgent(__name__)

VERSION = VersionTuple(1, 3, 0)
"""Version for the code controlling non-replication tables. This needs to be
updated following compatibility rules when schema produced by this code
changes.
"""

101 

102 

103class ApdbCassandra(Apdb): 

104 """Implementation of APDB database with Apache Cassandra backend. 

105 

106 Parameters 

107 ---------- 

108 config : `ApdbCassandraConfig` 

109 Configuration object. 

110 """ 

111 

def __init__(self, config: ApdbCassandraConfig):
    # The module can be imported without cassandra-driver, but an instance
    # cannot be created in that case.
    if not CASSANDRA_IMPORTED:
        raise CassandraMissingError()

    # Remember configuration and the pieces derived directly from it.
    self._config = config
    self._keyspace = config.keyspace
    self._schema = ApdbSchema(config.schema_file, config.ss_schema_file)

    # The session is requested from this factory on first access to the
    # `_context` property, which also fills `_connection_context`.
    self._session_factory = SessionFactory(config)
    self._connection_context: ConnectionContext | None = None

122 

@property
def _context(self) -> ConnectionContext:
    """Return connection context, creating it on the first access."""
    existing = self._connection_context
    if existing is not None:
        return existing

    # Versions of schema/code/replica known to this code, they are passed
    # to the new context together with the session.
    current_versions = DbVersions(
        schema_version=self.schema.schemaVersion(),
        code_version=self.apdbImplementationVersion(),
        replica_version=ApdbCassandraReplica.apdbReplicaImplementationVersion(),
    )
    _LOG.debug("Current versions: %s", current_versions)

    session = self._session_factory.session()
    self._connection_context = ConnectionContext(
        session, self._config, self.schema.tableSchemas, current_versions
    )

    # Dumping the model is only done when the message is going to be
    # emitted.
    if _LOG.isEnabledFor(logging.DEBUG):
        _LOG.debug("ApdbCassandra Configuration: %s", self._connection_context.config.model_dump())

    return self._connection_context

143 

def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer:
    """Make a `Timer` with the given name and optional tags, reporting to
    the module-level monitoring agent.
    """
    timer = Timer(name, _MON, tags=tags)
    return timer

147 

@classmethod
def apdbImplementationVersion(cls) -> VersionTuple:
    """Return the version of this APDB implementation.

    Returns
    -------
    version : `VersionTuple`
        Value of the module-level ``VERSION`` constant.
    """
    version = VERSION
    return version

158 

def getConfig(self) -> ApdbCassandraConfig:
    # docstring is inherited from a base class
    # Configuration is taken from the connection context, accessing it
    # establishes connection if it does not exist yet.
    context = self._context
    return context.config

162 

def tableDef(self, table: ApdbTables) -> Table | None:
    # docstring is inherited from a base class
    # `None` is returned for tables that are not in the schema.
    table_schemas = self.schema.tableSchemas
    return table_schemas.get(table)

166 

@classmethod
def init_database(
    cls,
    hosts: tuple[str, ...],
    keyspace: str,
    *,
    schema_file: str | None = None,
    ss_schema_file: str | None = None,
    read_sources_months: int | None = None,
    read_forced_sources_months: int | None = None,
    enable_replica: bool = False,
    replica_skips_diaobjects: bool = False,
    port: int | None = None,
    username: str | None = None,
    dbauth_alias: str | None = None,
    prefix: str | None = None,
    part_pixelization: str | None = None,
    part_pix_level: int | None = None,
    time_partition_tables: bool = True,
    time_partition_start: str | None = None,
    time_partition_end: str | None = None,
    read_consistency: str | None = None,
    write_consistency: str | None = None,
    read_timeout: int | None = None,
    write_timeout: int | None = None,
    ra_dec_columns: tuple[str, str] | None = None,
    replication_factor: int | None = None,
    drop: bool = False,
    table_options: CreateTableOptions | None = None,
) -> ApdbCassandraConfig:
    """Create schema for a new APDB instance and return its configuration.

    Parameters
    ----------
    hosts : `tuple` [`str`, ...]
        Host names or IP addresses of the Cassandra cluster.
    keyspace : `str`
        Keyspace name for APDB tables.
    schema_file : `str`, optional
        Location of the (YAML) file with APDB schema. Default location is
        used if not specified.
    ss_schema_file : `str`, optional
        Location of the (YAML) file with SSO schema. Default location is
        used if not specified.
    read_sources_months : `int`, optional
        How many months of history to read from DiaSource.
    read_forced_sources_months : `int`, optional
        How many months of history to read from DiaForcedSource.
    enable_replica : `bool`, optional
        If `True`, also create tables used for replication to PPDB.
    replica_skips_diaobjects : `bool`, optional
        If `True` then regular ``DiaObject`` table is not filled when
        ``enable_replica`` is `True`.
    port : `int`, optional
        Port number for Cassandra connections.
    username : `str`, optional
        User name for Cassandra connections.
    dbauth_alias : `str`, optional
        If given, this string is used as a host name when looking for
        credentials in db-auth.yaml, in addition to regular host names
        in contact_points. For example with
        dbauth_alias='pp_apdb_prod_cluster' the entry
        'cassandra://pp_apdb_prod_cluster/' will match. Port number
        should not be used in that entry. Alias takes priority over host
        names.
    prefix : `str`, optional
        Optional prefix for all table names.
    part_pixelization : `str`, optional
        Name of the MOC pixelization used for partitioning.
    part_pix_level : `int`, optional
        Pixelization level.
    time_partition_tables : `bool`, optional
        Create per-partition tables.
    time_partition_start : `str`, optional
        Start time for per-partition tables, in yyyy-mm-ddThh:mm:ss
        format, in TAI.
    time_partition_end : `str`, optional
        End time for per-partition tables, in yyyy-mm-ddThh:mm:ss
        format, in TAI.
    read_consistency : `str`, optional
        Name of the consistency level for read operations.
    write_consistency : `str`, optional
        Name of the consistency level for write operations.
    read_timeout : `int`, optional
        Read timeout in seconds.
    write_timeout : `int`, optional
        Write timeout in seconds.
    ra_dec_columns : `tuple` [`str`, `str`], optional
        Names of ra/dec columns in DiaObject table.
    replication_factor : `int`, optional
        Replication factor used when a new keyspace is created, an
        existing keyspace keeps its replication factor.
    drop : `bool`, optional
        If `True` then existing tables are dropped before the schema is
        re-created.
    table_options : `CreateTableOptions`, optional
        Options used when creating Cassandra tables.

    Returns
    -------
    config : `ApdbCassandraConfig`
        Configuration object for the created APDB instance.
    """
    # Non-standard defaults for a few connection parameters, see Cassandra
    # driver documentation for their meaning. They are not used while the
    # database is initialized but are saved with the generated config and
    # can be changed there later.
    config = ApdbCassandraConfig(
        contact_points=hosts,
        keyspace=keyspace,
        enable_replica=enable_replica,
        replica_skips_diaobjects=replica_skips_diaobjects,
        connection_config=ApdbCassandraConnectionConfig(
            extra_parameters={
                "idle_heartbeat_interval": 0,
                "idle_heartbeat_timeout": 30,
                "control_connection_timeout": 100,
            },
        ),
    )
    config.partitioning.time_partition_tables = time_partition_tables

    # Optional overrides as (config section, attribute, value), section
    # `None` means top-level config. Only values that were actually passed
    # replace config defaults; the order of assignments is kept.
    overrides: tuple[tuple[str | None, str, Any], ...] = (
        (None, "schema_file", schema_file),
        (None, "ss_schema_file", ss_schema_file),
        (None, "read_sources_months", read_sources_months),
        (None, "read_forced_sources_months", read_forced_sources_months),
        ("connection_config", "port", port),
        ("connection_config", "username", username),
        ("connection_config", "dbauth_alias", dbauth_alias),
        (None, "prefix", prefix),
        ("partitioning", "part_pixelization", part_pixelization),
        ("partitioning", "part_pix_level", part_pix_level),
        ("partitioning", "time_partition_start", time_partition_start),
        ("partitioning", "time_partition_end", time_partition_end),
        ("connection_config", "read_consistency", read_consistency),
        ("connection_config", "write_consistency", write_consistency),
        ("connection_config", "read_timeout", read_timeout),
        ("connection_config", "write_timeout", write_timeout),
        (None, "ra_dec_columns", ra_dec_columns),
    )
    for section, attribute, value in overrides:
        if value is None:
            continue
        target = config if section is None else getattr(config, section)
        setattr(target, attribute, value)

    cls._makeSchema(config, drop=drop, replication_factor=replication_factor, table_options=table_options)

    return config

327 

def get_replica(self) -> ApdbCassandraReplica:
    """Make `ApdbReplica` instance for this database."""
    # Replica receives a reference to this instance because this instance
    # has to stay alive for as long as replica exists.
    replica = ApdbCassandraReplica(self)
    return replica

333 

@classmethod
def _makeSchema(
    cls,
    config: ApdbConfig,
    *,
    drop: bool = False,
    replication_factor: int | None = None,
    table_options: CreateTableOptions | None = None,
) -> None:
    # docstring is inherited from a base class

    if not isinstance(config, ApdbCassandraConfig):
        raise TypeError(f"Unexpected type of configuration object: {type(config)}")

    simple_schema = ApdbSchema(config.schema_file, config.ss_schema_file)

    with SessionContext(config) as session:
        schema = ApdbCassandraSchema(
            session=session,
            keyspace=config.keyspace,
            table_schemas=simple_schema.tableSchemas,
            prefix=config.prefix,
            time_partition_tables=config.partitioning.time_partition_tables,
            enable_replica=config.enable_replica,
            replica_skips_diaobjects=config.replica_skips_diaobjects,
        )

        # Ask schema to create all tables. Range of time partitions is only
        # calculated and passed when per-partition tables are used.
        make_schema_kwargs: dict[str, Any] = {
            "drop": drop,
            "replication_factor": replication_factor,
            "table_options": table_options,
        }
        part_range_config: ApdbCassandraTimePartitionRange | None = None
        if config.partitioning.time_partition_tables:
            partitioner = Partitioner(config)
            range_start = astropy.time.Time(
                config.partitioning.time_partition_start, format="isot", scale="tai"
            )
            range_end = astropy.time.Time(config.partitioning.time_partition_end, format="isot", scale="tai")
            part_range_config = ApdbCassandraTimePartitionRange(
                start=partitioner.time_partition(range_start),
                end=partitioner.time_partition(range_end),
            )
            make_schema_kwargs["part_range"] = part_range_config
        schema.makeSchema(**make_schema_kwargs)

        metadata = ApdbMetadataCassandra(
            session,
            ApdbTables.metadata.table_name(config.prefix),
            config.keyspace,
            "read_tuples",
            "write",
        )

        # Fill version numbers, overrides if they existed before.
        schema_version = str(simple_schema.schemaVersion())
        metadata.set(ConnectionContext.metadataSchemaVersionKey, schema_version, force=True)
        code_version = str(cls.apdbImplementationVersion())
        metadata.set(ConnectionContext.metadataCodeVersionKey, code_version, force=True)

        # Replica code version is only stored if replica is enabled.
        if config.enable_replica:
            replica_version = str(ApdbCassandraReplica.apdbReplicaImplementationVersion())
            metadata.set(ConnectionContext.metadataReplicaVersionKey, replica_version, force=True)

        # Store frozen part of a configuration in metadata.
        freezer = ApdbConfigFreezer[ApdbCassandraConfig](ConnectionContext.frozen_parameters)
        frozen_json = freezer.to_json(config)
        metadata.set(ConnectionContext.metadataConfigKey, frozen_json, force=True)

        # Store time partition range.
        if part_range_config:
            part_range_config.save_to_meta(metadata)

414 

def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
    # docstring is inherited from a base class
    context = self._context
    config = context.config

    sp_where, num_sp_part = context.partitioner.spatial_where(region)
    _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where))

    # Objects are read from DiaObjectLast table. Only APDB columns are
    # selected, this excludes extra partitioning columns from result.
    last_table = ApdbTables.DiaObjectLast
    column_names = context.schema.apdbColumnNames(last_table)
    table_name = context.schema.tableName(last_table)
    base_query = Select(self._keyspace, table_name, column_names)

    # One statement per spatial constraint.
    statements: list[tuple] = [
        context.stmt_factory.with_params(base_query.where(where_clause), prepare=True)
        for where_clause in sp_where
    ]
    _LOG.debug("getDiaObjects: #queries: %s", len(statements))

    with self._timer("select_time", tags={"table": "DiaObject", "method": "getDiaObjects"}) as timer:
        raw_objects = cast(
            ApdbCassandraTableData,
            select_concurrent(
                context.session,
                statements,
                "read_raw_multi",
                config.connection_config.read_concurrency,
            ),
        )
        objects = raw_objects.to_pandas(context.schema._table_schema(ApdbTables.DiaObjectLast))
        timer.add_values(row_count=len(objects), num_sp_part=num_sp_part, num_queries=len(statements))

    _LOG.debug("found %s DiaObjects", objects.shape[0])
    return objects

448 

def getDiaSources(
    self,
    region: sphgeom.Region,
    object_ids: Iterable[int] | None,
    visit_time: astropy.time.Time,
    start_time: astropy.time.Time | None = None,
) -> pandas.DataFrame | None:
    # docstring is inherited from a base class
    context = self._context
    config = context.config

    months = config.read_sources_months
    no_explicit_start = start_time is None
    if no_explicit_start and months == 0:
        # Nothing to read: no start time and zero months of history.
        return None

    # Time range as TAI MJD. Without explicit start time the range goes
    # back configured number of months, each counted as 30 days.
    mjd_end = float(visit_time.tai.mjd)
    mjd_start = mjd_end - months * 30 if no_explicit_start else float(start_time.tai.mjd)

    return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource)

471 

def getDiaForcedSources(
    self,
    region: sphgeom.Region,
    object_ids: Iterable[int] | None,
    visit_time: astropy.time.Time,
    start_time: astropy.time.Time | None = None,
) -> pandas.DataFrame | None:
    # docstring is inherited from a base class
    context = self._context
    config = context.config

    months = config.read_forced_sources_months
    no_explicit_start = start_time is None
    if no_explicit_start and months == 0:
        # Nothing to read: no start time and zero months of history.
        return None

    # Time range as TAI MJD. Without explicit start time the range goes
    # back configured number of months, each counted as 30 days.
    mjd_end = float(visit_time.tai.mjd)
    mjd_start = mjd_end - months * 30 if no_explicit_start else float(start_time.tai.mjd)

    return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource)

494 

def getDiaObjectsForDedup(self, since: astropy.time.Time | None = None) -> pandas.DataFrame:
    # docstring is inherited from a base class
    context = self._context
    config = context.config

    if not context.has_dedup_table:
        raise TypeError("DiaObjectDedup table does not exist in this APDB instance.")

    if since is None:
        # Fall back to the last deduplication time saved in metadata, if
        # there is one.
        dedup_str = context.metadata.get(context.metadataDedupKey)
        if dedup_str is not None:
            dedup_time_str = json.loads(dedup_str)["dedup_time_iso_tai"]
            since = astropy.time.Time(dedup_time_str, format="iso", scale="tai")

    dedup_table = ExtraTables.DiaObjectDedup
    column_names = context.schema.apdbColumnNames(dedup_table)

    validity_start_column = self._timestamp_column_name("validityStart")
    if since is None:
        timestamp = None
    else:
        timestamp = self._timestamp_column_value(since)

    # Literal zeros in the query are stand-ins, actual values are given as
    # parameters of each statement below.
    table_name = context.schema.tableName(dedup_table)
    query = Select(self._keyspace, table_name, column_names, extra_clause="ALLOW FILTERING")
    query = query.where(C("dedup_part") == 0)
    if since is not None:
        query = query.where(C(validity_start_column) >= 0)

    statement = context.stmt_factory(query, prepare=False)

    # The same statement is executed once per dedup partition.
    num_part = config.partitioning.num_part_dedup
    statements = [
        (statement, (dedup_part,) if timestamp is None else (dedup_part, timestamp))
        for dedup_part in range(num_part)
    ]

    with self._timer(
        "select_time", tags={"table": "DiaObjectDedup", "method": "getDiaObjectsForDedup"}
    ) as timer:
        objects_raw = cast(
            ApdbCassandraTableData,
            select_concurrent(
                context.session,
                statements,
                "read_raw_multi_dedup",
                config.connection_config.read_concurrency,
            ),
        )
        objects = objects_raw.to_pandas(context.schema._table_schema(ExtraTables.DiaObjectDedup))
        timer.add_values(row_count=len(objects), num_queries=num_part)

    _LOG.debug("found %s DiaObjectDedup records", objects.shape[0])
    return objects

547 

def getDiaSourcesForDiaObjects(
    self, objects: list[DiaObjectId], start_time: astropy.time.Time, max_dist_arcsec: float = 1.0
) -> pandas.DataFrame:
    # docstring is inherited from a base class
    context = self._context
    config = context.config

    # Tables to query and temporal constraints, time range ends at
    # current time.
    end_time = self._current_time()
    tables, temporal_where = context.partitioner.temporal_where(
        ApdbTables.DiaSource,
        start_time,
        end_time,
        partitons_range=context.time_partitions_range,
        query_per_time_part=False,
    )
    if not tables:
        warnings.warn(
            f"Query time range ({start_time.isot} - {end_time.isot}) does not overlap database "
            "time partitions."
        )

    # DiaObject IDs grouped by spatial partition.
    ids_by_partition = self._group_dia_objects_by_partition(context.partitioner, objects, max_dist_arcsec)

    # Columns to return.
    column_names = context.schema.apdbColumnNames(ApdbTables.DiaSource)

    # One query for every combination of partition, table, chunk of IDs
    # (at most 10,000 IDs per chunk) and combined WHERE clause.
    statements = []
    for apdb_part, part_object_ids in ids_by_partition.items():
        spatial_where = [C("apdb_part") == apdb_part]
        for table in tables:
            query = Select(self._keyspace, table, column_names, extra_clause="ALLOW FILTERING")
            for id_chunk in chunk_iterable(part_object_ids, 10_000):
                id_where = C("diaObjectId").in_(id_chunk)
                statements.extend(
                    context.stmt_factory.with_params(query.where(clause), prepare=False)
                    for clause in QExpr.combine(spatial_where, temporal_where, extra=id_where)
                )

    _LOG.debug("getDiaSourcesForDiaObjects #queries: %s", len(statements))

    with self._timer(
        "select_time", tags={"table": "DiaSource", "method": "getDiaSourcesForDiaObjects"}
    ) as timer:
        table_data_raw = cast(
            ApdbCassandraTableData,
            select_concurrent(
                context.session,
                statements,
                "read_raw_multi",
                config.connection_config.read_concurrency,
            ),
        )
        catalog = table_data_raw.to_pandas(context.schema._table_schema(ApdbTables.DiaSource))
        timer.add_values(row_count_from_db=len(catalog), num_queries=len(statements))

        # Precise filtering on midpointMjdTai, rows earlier than start time
        # are dropped.
        not_too_early = catalog["midpointMjdTai"] >= start_time.tai.mjd
        catalog = cast(pandas.DataFrame, catalog[not_too_early])

        timer.add_values(row_count=len(catalog))

    _LOG.debug("found %d DiaSources", len(catalog))
    return catalog

615 

def containsVisitDetector(
    self,
    visit: int,
    detector: int,
    region: sphgeom.Region,
    visit_time: astropy.time.Time,
) -> bool:
    # docstring is inherited from a base class
    context = self._context

    # Count records for this visit/detector pair; `region` and
    # `visit_time` are not used by this query.
    table_name = context.schema.tableName(ExtraTables.ApdbVisitDetector)
    count_query = Select(self._keyspace, table_name, [ColumnExpr("count(*)")])
    same_visit_detector = (C("visit") == visit) & (C("detector") == detector)
    count_query = count_query.where(same_visit_detector)
    stmt, params = context.stmt_factory.with_params(count_query, prepare=False)

    with self._timer("contains_visit_detector_time", tags={"table": table_name}):
        first_row = context.session.execute(stmt, params).one()
        return bool(first_row[0])

634 

def store(
    self,
    visit_time: astropy.time.Time,
    objects: pandas.DataFrame,
    sources: pandas.DataFrame | None = None,
    forced_sources: pandas.DataFrame | None = None,
) -> None:
    # docstring is inherited from a base class
    context = self._context
    config = context.config

    # Visit/detector pairs are stored in a special table before all other
    # writes, so that a record of the attempted write exists even if
    # something fails later.
    visit_detector: set[tuple[int, int]] = set()
    for catalog in (sources, forced_sources):
        if catalog is None or catalog.empty:
            continue
        pairs = catalog[["visit", "detector"]]
        visit_detector.update((visit, detector) for visit, detector in pairs.itertuples(index=False))

    if visit_detector:
        # Typically there is only one entry, do not bother with
        # concurrency.
        table_name = context.schema.tableName(ExtraTables.ApdbVisitDetector)
        insert = Insert(self._keyspace, table_name, ("visit", "detector"))
        stmt = context.stmt_factory(insert)
        for pair in visit_detector:
            context.session.execute(stmt, pair, execution_profile="write")

    # Fix timestamps in all catalogs that were passed.
    objects = self._fix_input_timestamps(objects)
    if sources is not None:
        sources = self._fix_input_timestamps(sources)
    if forced_sources is not None:
        forced_sources = self._fix_input_timestamps(forced_sources)

    # Replica chunk is only made and stored when replication is enabled.
    replica_chunk: ReplicaChunk | None = None
    if context.schema.replication_enabled:
        replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, config.replica_chunk_seconds)
        self._storeReplicaChunk(replica_chunk)

    # Fill region partition column for DiaObjects.
    objects = self._add_apdb_part(objects)
    self._storeDiaObjects(objects, visit_time, replica_chunk)

    have_sources = sources is not None and len(sources) > 0
    if have_sources:
        # Add partition column to DiaSources before storing them.
        sources = self._add_apdb_part(sources)
        subchunk = self._storeDiaSources(ApdbTables.DiaSource, sources, replica_chunk)
        self._storeDiaSourcesPartitions(sources, visit_time, replica_chunk, subchunk)

    have_forced_sources = forced_sources is not None and len(forced_sources) > 0
    if have_forced_sources:
        forced_sources = self._add_apdb_part(forced_sources)
        self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, replica_chunk)

689 

def reassignDiaSourcesToDiaObjects(
    self,
    idMap: Mapping[DiaSourceId, int],
    *,
    increment_nDiaSources: bool = True,
    decrement_nDiaSources: bool = True,
) -> None:
    # docstring is inherited from a base class
    #
    # Summary of what this implementation does:
    #   1. Looks up all DiaSources named in ``idMap`` and all DiaObjects
    #      involved (current and new), raising `LookupError` if any is
    #      missing.
    #   2. Updates ``diaObjectId`` of each DiaSource.
    #   3. If ``increment_nDiaSources`` and/or ``decrement_nDiaSources`` is
    #      set, updates ``nDiaSources`` in DiaObjectLast: incremented for
    #      new objects, decremented for objects that sources belonged to.
    #   4. If replication is enabled, stores update records for all of
    #      the above.

    # Nothing to do for empty input, return before any connection is made
    # or any query is sent (same as `setValidityEnd` does for empty list).
    if not idMap:
        return

    context = self._context
    config = context.config

    source_ids = {source_id.diaSourceId for source_id in idMap}

    # Find all DiaSources.
    found_sources = self._get_diasource_data(
        idMap, "apdb_part", "diaObjectId", "ra", "dec", "midpointMjdTai"
    )

    if missing_ids := (source_ids - {row.diaSourceId for row in found_sources}):
        raise LookupError(f"Some source IDs were not found in DiaSource table: {missing_ids}")

    found_sources_by_id = {row.diaSourceId: row for row in found_sources}

    # Make sure that all DiaObjects exist, we also want to know
    # nDiaSources count for current and new records because we want to
    # send updated values to replica.
    current_object_ids = {
        DiaObjectId(diaObjectId=row.diaObjectId, ra=row.ra, dec=row.dec) for row in found_sources
    }
    # Assume that DiaSource ra/dec are very close to re-assigned objects.
    new_object_ids = {
        DiaObjectId(diaObjectId=diaObjectId, ra=source_id.ra, dec=source_id.dec)
        for source_id, diaObjectId in idMap.items()
    }
    all_object_ids = new_object_ids | current_object_ids
    found_objects = self._get_diaobject_data(all_object_ids, "apdb_part", "ra", "dec", "nDiaSources")

    if missing_ids := (
        {row.diaObjectId for row in all_object_ids} - {row.diaObjectId for row in found_objects}
    ):
        raise LookupError(f"Some object IDs were not found in DiaObjectLast table: {missing_ids}")

    # Update records share one timestamp, `update_order` is a running
    # counter over all records made by this call.
    update_records: list[ApdbUpdateRecord] = []
    update_order = 0
    current_time = self._current_time()
    current_time_ns = int(current_time.unix_tai * 1e9)

    # Update DiaSources.
    statements: list[tuple] = []
    for source_id, diaObjectId in idMap.items():
        source_row = found_sources_by_id[source_id.diaSourceId]
        apdb_part = source_row.apdb_part
        time_part = context.partitioner.time_partition(source_row.midpointMjdTai)

        if config.partitioning.time_partition_tables:
            # Time partition selects the table.
            table_name = context.schema.tableName(ApdbTables.DiaSource, time_part)
            update = (
                Update(self._keyspace, table_name)
                .values(C("diaObjectId").update(diaObjectId))
                .where(C("apdb_part") == apdb_part)
                .where(C("diaSourceId") == source_id.diaSourceId)
            )
        else:
            # Single table, time partition is a part of WHERE clause.
            table_name = context.schema.tableName(ApdbTables.DiaSource)
            update = (
                Update(self._keyspace, table_name)
                .values(C("diaObjectId").update(diaObjectId))
                .where(C("apdb_part") == apdb_part)
                .where(C("apdb_time_part") == time_part)
                .where(C("diaSourceId") == source_id.diaSourceId)
            )
        statements.append(context.stmt_factory.with_params(update, prepare=True))

        if context.schema.replication_enabled:
            update_records.append(
                ApdbReassignDiaSourceToDiaObjectRecord(
                    diaSourceId=source_id.diaSourceId,
                    ra=source_id.ra,
                    dec=source_id.dec,
                    midpointMjdTai=source_id.midpointMjdTai,
                    diaObjectId=diaObjectId,
                    update_time_ns=current_time_ns,
                    update_order=update_order,
                )
            )
            update_order += 1

    with self._timer(
        "update_time", tags={"table": "DiaSource", "method": "reassignDiaSourcesToDiaObjects"}
    ) as timer:
        execute_concurrent(context.session, statements, execution_profile="write")
        timer.add_values(num_queries=len(statements))

    # Update nDiaSources in DiaObjectLast. We do not update DiaObject table
    # here because it may not even exist. PPDB updates DiaObject from
    # update records.
    if increment_nDiaSources or decrement_nDiaSources:
        table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
        # The -1 values are stand-ins, actual values are given as
        # parameters of each statement below.
        update = (
            Update(self._keyspace, table_name)
            .values(C("nDiaSources").update(-1))
            .where(C("apdb_part") == -1)
            .where(C("diaObjectId") == -1)
        )
        statement = context.stmt_factory(update, prepare=True)
        statements = []

        # Calculate increments/decrements for all affected DiaObjects.
        increments: Counter = Counter()
        if increment_nDiaSources:
            increments.update(idMap.values())
        if decrement_nDiaSources:
            increments.subtract(row.diaObjectId for row in found_sources)

        for row in found_objects:
            # Objects whose net change is zero are left alone.
            if increments.get(row.diaObjectId):
                nDiaSources = row.nDiaSources + increments[row.diaObjectId]
                statements.append((statement, (nDiaSources, row.apdb_part, row.diaObjectId)))

                # Also send updated values to replica.
                if context.schema.replication_enabled:
                    update_records.append(
                        ApdbUpdateNDiaSourcesRecord(
                            diaObjectId=row.diaObjectId,
                            ra=row.ra,
                            dec=row.dec,
                            nDiaSources=nDiaSources,
                            update_time_ns=current_time_ns,
                            update_order=update_order,
                        )
                    )
                    update_order += 1

        if statements:
            with self._timer(
                "update_time", tags={"table": table_name, "method": "reassignDiaSourcesToDiaObjects"}
            ) as timer:
                execute_concurrent(context.session, statements, execution_profile="write")
                timer.add_values(num_queries=len(statements))

    if update_records:
        replica_chunk = ReplicaChunk.make_replica_chunk(current_time, config.replica_chunk_seconds)
        self._storeUpdateRecords(update_records, replica_chunk, store_chunk=True)

833 

def setValidityEnd(
    self, objects: list[DiaObjectId], validityEnd: astropy.time.Time, raise_on_missing_id: bool = False
) -> int:
    # docstring is inherited from a base class
    #
    # Implementation outline:
    #   1. find which of the requested objects exist in DiaObjectLast,
    #   2. delete the existing ones from DiaObjectLast and from the
    #      DiaObjectLastToPartition reverse map,
    #   3. if replication is enabled, store one validity-closing update
    #      record per object that was found.
    # The returned number is the count of requested objects that were found.
    if not objects:
        return 0

    context = self._context
    config = context.config

    # Group requested objects by partition; the padding value is passed
    # to the grouping helper together with the objects.
    pad_arcsec = 1.0
    partitioned_object_ids = self._group_dia_objects_by_partition(
        context.partitioner, objects, pad_arcsec
    )

    # Check that all objects exist, one SELECT per partition.
    table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
    lookups: list[tuple] = [
        context.stmt_factory.with_params(
            Select(self._keyspace, table_name, ["apdb_part", "diaObjectId"])
            .where(C("apdb_part") == part)
            .where(C("diaObjectId").in_(ids_in_part)),
            prepare=False,
        )
        for part, ids_in_part in partitioned_object_ids.items()
    ]

    with self._timer("select_time", tags={"table": table_name, "method": "setValidityEnd"}) as timer:
        records = cast(
            list[tuple[int, int]],
            select_concurrent(
                context.session,
                lookups,
                "read_tuples",
                config.connection_config.read_concurrency,
            ),
        )
        # NOTE(review): this reports the number of requested objects, not
        # the number of returned rows - confirm that this is intended.
        timer.add_values(row_count=len(objects))

    requested_ids = {obj.diaObjectId for obj in objects}
    found_ids = {object_id for _, object_id in records}

    # NOTE(review): this detects IDs that were returned but never
    # requested; the message talks about duplicates - confirm intent.
    extra_ids = found_ids - requested_ids
    if extra_ids:
        raise RuntimeError(f"Consistency error - found duplicate records for object IDs: {extra_ids}")
    if raise_on_missing_id:
        missing_ids = requested_ids - found_ids
        if missing_ids:
            raise LookupError(f"Some object IDs are missing from DiaObjectLast table: {missing_ids}")

    # Keep only the objects that actually exist.
    if len(objects) != len(found_ids):
        objects = [obj for obj in objects if obj.diaObjectId in found_ids]

    if not objects:
        return 0

    # Group by partitions again, this time using partitions from database.
    grouped_object_ids: dict[int, list[int]] = defaultdict(list)
    for part, object_id in records:
        grouped_object_ids[part].append(object_id)

    # Remove all matching rows from DiaObjectLast.
    deletes = [
        context.stmt_factory.with_params(
            Delete(self._keyspace, table_name)
            .where(C("apdb_part") == part)
            .where(C("diaObjectId").in_(ids_in_part))
        )
        for part, ids_in_part in grouped_object_ids.items()
    ]

    # Also remove from DiaObjectLastToPartition.
    reverse_table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition)
    reverse_delete = Delete(self._keyspace, reverse_table_name).where(
        C("diaObjectId").in_([rec[1] for rec in records])
    )
    deletes.append(context.stmt_factory.with_params(reverse_delete))

    with self._timer("delete_time", tags={"table": table_name, "method": "setValidityEnd"}) as timer:
        execute_concurrent(context.session, deletes, execution_profile="write")
        timer.add_values(row_count=len(records))

    # If replication is enabled then send all updates.
    if context.schema.replication_enabled:
        current_time = self._current_time()
        current_time_ns = int(current_time.unix_tai * 1e9)
        replica_chunk = ReplicaChunk.make_replica_chunk(current_time, config.replica_chunk_seconds)

        update_records = []
        for index, obj in enumerate(objects):
            update_records.append(
                ApdbCloseDiaObjectValidityRecord(
                    diaObjectId=obj.diaObjectId,
                    ra=obj.ra,
                    dec=obj.dec,
                    update_time_ns=current_time_ns,
                    update_order=index,
                    validityEndMjdTai=float(validityEnd.tai.mjd),
                    nDiaSources=None,
                )
            )

        self._storeUpdateRecords(update_records, replica_chunk, store_chunk=True)

    return len(objects)

933 

def resetDedup(self, dedup_time: astropy.time.Time | None = None) -> None:
    # docstring is inherited from a base class
    #
    # Removes records older than ``dedup_time`` from the deduplication
    # table (truncating the whole table when everything in it is old
    # enough) and then stores ``dedup_time`` in the metadata table.
    context = self._context

    if not context.has_dedup_table:
        raise TypeError("DiaObjectDedup table does not exist in this APDB instance.")

    if dedup_time is None:
        dedup_time = self._current_time()

    validity_start_column = self._timestamp_column_name("validityStart")

    # Find latest timestamp in deduplication table.
    table_name = context.schema.tableName(ExtraTables.DiaObjectDedup)
    query = Select(self._keyspace, table_name, [ColumnExpr(f'MAX("{validity_start_column}")')])
    stmt = context.stmt_factory(query, prepare=False)

    result = context.session.execute(stmt, execution_profile="read_tuples")
    row = result.one()
    # The aggregate is presumably NULL when the table has no rows; guard
    # against both a missing row and a NULL value.
    max_value = row[0] if row is not None else None

    # Without a maximum timestamp there is nothing to clean up, and the
    # value cannot be converted to a time object, so skip the cleanup.
    if max_value is not None:
        if self._schema.has_mjd_timestamps:
            max_validity_start = astropy.time.Time(max_value, format="mjd", scale="tai")
        else:
            max_validity_start = astropy.time.Time(max_value, format="datetime", scale="tai")

        # If max time is lower than dedup time we can do TRUNCATE.
        if dedup_time >= max_validity_start:
            query_str = f'TRUNCATE TABLE "{self._keyspace}"."{table_name}"'
            context.session.execute(query_str, execution_profile="write")
        else:
            # Only part of the table is old enough, delete selectively.
            dedup_time_value = self._timestamp_column_value(dedup_time)
            delete = Delete(self._keyspace, table_name).where(C(validity_start_column) < dedup_time_value)
            stmt, params = context.stmt_factory.with_params(delete)
            context.session.execute(stmt, params, execution_profile="write")

    # Store dedup time.
    data = {"dedup_time_iso_tai": dedup_time.tai.to_value("iso")}
    data_json = json.dumps(data)
    context.metadata.set(context.metadataDedupKey, data_json, force=True)

972 

def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
    # docstring is inherited from a base class
    context = self._context
    config = context.config

    now = self._current_time()
    reassign_time_column = self._timestamp_column_name("ssObjectReassocTime")
    reassignTime = self._timestamp_column_value(now)

    # An update needs the exact primary key (including partition key) of
    # the record, so the first step is to look up partitioning values for
    # every diaSourceId in the DiaSourceToPartition table.
    table_name = context.schema.tableName(ExtraTables.DiaSourceToPartition)
    columns = ["diaSourceId", "apdb_part", "apdb_time_part", "apdb_replica_chunk"]
    base_select = Select(self._keyspace, table_name, columns)
    # Split the lookup into queries of 1k IDs each.
    selects: list[tuple] = [
        context.stmt_factory.with_params(base_select.where(C("diaSourceId").in_(ids)), prepare=False)
        for ids in chunk_iterable(idMap.keys(), 1_000)
    ]

    # No need for DataFrame here, read data as tuples.
    result = cast(
        list[tuple[int, int, int, int | None]],
        select_concurrent(
            context.session, selects, "read_tuples", config.connection_config.read_concurrency
        ),
    )

    # Make mapping from source ID to its partitions, and remember replica
    # chunks of the records that have one.
    id2partitions: dict[int, tuple[int, int]] = {}
    id2chunk_id: dict[int, int] = {}
    for source_id, spatial_part, time_part, chunk_id in result:
        id2partitions[source_id] = (spatial_part, time_part)
        if chunk_id is not None:
            id2chunk_id[source_id] = chunk_id

    # Make sure we know partitions for each ID.
    known_ids = set(id2partitions)
    wanted_ids = set(idMap)
    if known_ids != wanted_ids:
        missing = ",".join(str(item) for item in wanted_ids - known_ids)
        raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

    # Reassign in standard tables.
    per_partition_tables = config.partitioning.time_partition_tables
    queries: list[tuple[cassandra.query.PreparedStatement, tuple]] = []
    for diaSourceId, ssObjectId in idMap.items():
        apdb_part, apdb_time_part = id2partitions[diaSourceId]
        if per_partition_tables:
            # Time partition selects the table.
            table_name = context.schema.tableName(ApdbTables.DiaSource, apdb_time_part)
        else:
            table_name = context.schema.tableName(ApdbTables.DiaSource)
        update = (
            Update(self._keyspace, table_name)
            .values(
                C("ssObjectId").update(ssObjectId),
                C("diaObjectId").update(None),
                C(reassign_time_column).update(reassignTime),
            )
            .where(C("apdb_part") == apdb_part)
        )
        if not per_partition_tables:
            # With a single table the time partition is part of the key.
            update = update.where(C("apdb_time_part") == apdb_time_part)
        update = update.where(C("diaSourceId") == diaSourceId)
        queries.append(context.stmt_factory.with_params(update, prepare=True))

    # TODO: (DM-50190) Replication for updated records is not implemented.
    if id2chunk_id:
        warnings.warn("Replication of reassigned DiaSource records is not implemented.", stacklevel=2)

    _LOG.debug("%s: will update %d records", table_name, len(idMap))
    with self._timer("source_reassign_time") as timer:
        execute_concurrent(context.session, queries, execution_profile="write")
        timer.add_values(source_count=len(idMap))

1055 

def countUnassociatedObjects(self) -> int:
    # docstring is inherited from a base class

    # It's too inefficient to implement it for Cassandra in current schema,
    # so this method unconditionally raises `NotImplementedError`.
    raise NotImplementedError()

1061 

@property
def schema(self) -> ApdbSchema:
    # docstring is inherited from a base class

    # Returns the schema object stored on this instance.
    return self._schema

1066 

@property
def metadata(self) -> ApdbMetadata:
    # docstring is inherited from a base class

    # Metadata object is owned by the context of this instance.
    return self._context.metadata

1072 

@property
def admin(self) -> ApdbCassandraAdmin:
    # docstring is inherited from a base class

    # A new admin object wrapping this instance is made on every access.
    return ApdbCassandraAdmin(self)

1077 

1078 def _getSources( 

1079 self, 

1080 region: sphgeom.Region, 

1081 object_ids: Iterable[int] | None, 

1082 mjd_start: float, 

1083 mjd_end: float, 

1084 table_name: ApdbTables, 

1085 ) -> pandas.DataFrame: 

1086 """Return catalog of DiaSource instances given set of DiaObject IDs. 

1087 

1088 Parameters 

1089 ---------- 

1090 region : `lsst.sphgeom.Region` 

1091 Spherical region. 

1092 object_ids : 

1093 Collection of DiaObject IDs 

1094 mjd_start : `float` 

1095 Lower bound of time interval. 

1096 mjd_end : `float` 

1097 Upper bound of time interval. 

1098 table_name : `ApdbTables` 

1099 Name of the table. 

1100 

1101 Returns 

1102 ------- 

1103 catalog : `pandas.DataFrame`, or `None` 

1104 Catalog containing DiaSource records. Empty catalog is returned if 

1105 ``object_ids`` is empty. 

1106 """ 

1107 context = self._context 

1108 config = context.config 

1109 

1110 object_id_set: Set[int] = set() 

1111 if object_ids is not None: 

1112 object_id_set = set(object_ids) 

1113 if len(object_id_set) == 0: 

1114 return self._make_empty_catalog(table_name) 

1115 

1116 sp_where, num_sp_part = context.partitioner.spatial_where(region) 

1117 tables, temporal_where = context.partitioner.temporal_where( 

1118 table_name, mjd_start, mjd_end, partitons_range=context.time_partitions_range 

1119 ) 

1120 if not tables: 

1121 start = astropy.time.Time(mjd_start, format="mjd", scale="tai") 

1122 end = astropy.time.Time(mjd_end, format="mjd", scale="tai") 

1123 warnings.warn( 

1124 f"Query time range ({start.isot} - {end.isot}) does not overlap database time partitions." 

1125 ) 

1126 

1127 # We need to exclude extra partitioning columns from result. 

1128 column_names = context.schema.apdbColumnNames(table_name) 

1129 

1130 # Build all queries 

1131 statements: list[tuple] = [] 

1132 for table in tables: 

1133 query = Select(self._keyspace, table, column_names) 

1134 for clause in QExpr.combine(sp_where, temporal_where): 

1135 statements.append(context.stmt_factory.with_params(query.where(clause), prepare=True)) 

1136 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements)) 

1137 

1138 with self._timer("select_time", tags={"table": table_name.name, "method": "_getSources"}) as timer: 

1139 table_data_raw = cast( 

1140 ApdbCassandraTableData, 

1141 select_concurrent( 

1142 context.session, 

1143 statements, 

1144 "read_raw_multi", 

1145 config.connection_config.read_concurrency, 

1146 ), 

1147 ) 

1148 catalog = table_data_raw.to_pandas(context.schema._table_schema(table_name)) 

1149 timer.add_values( 

1150 row_count_from_db=len(catalog), num_sp_part=num_sp_part, num_queries=len(statements) 

1151 ) 

1152 

1153 # filter by given object IDs 

1154 if len(object_id_set) > 0: 

1155 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)]) 

1156 

1157 # precise filtering on midpointMjdTai 

1158 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start]) 

1159 

1160 timer.add_values(row_count=len(catalog)) 

1161 

1162 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name) 

1163 return catalog 

1164 

def _storeReplicaChunk(self, replica_chunk: ReplicaChunk) -> None:
    """Insert a record for a replica chunk into the replica chunks table.

    Parameters
    ----------
    replica_chunk : `ReplicaChunk`
        Replica chunk to store.
    """
    context = self._context
    config = context.config

    table_name = context.schema.tableName(ExtraTables.ApdbReplicaChunks)

    row: dict[str, Any] = {
        # Everything goes into a single partition.
        "partition": 0,
        "apdb_replica_chunk": replica_chunk.id,
        # Cassandra timestamp uses milliseconds since epoch.
        "last_update_time": int(replica_chunk.last_update_time.unix_tai * 1000),
        "unique_id": replica_chunk.unique_id,
    }
    if context.has_chunk_sub_partitions:
        row["has_subchunks"] = True

    insert = Insert(self._keyspace, table_name, list(row))
    stmt = context.stmt_factory(insert)

    context.session.execute(
        stmt,
        list(row.values()),
        timeout=config.connection_config.write_timeout,
        execution_profile="write",
    )

1192 

def _queryDiaObjectLastPartitions(self, ids: Iterable[int]) -> Mapping[int, int]:
    """Return existing mapping of diaObjectId to its last partition.

    Parameters
    ----------
    ids : `~collections.abc.Iterable` [`int`]
        DiaObject IDs to look up.

    Returns
    -------
    partitions : `~collections.abc.Mapping` [`int`, `int`]
        Mapping of diaObjectId to ``apdb_part`` for the records found in
        the DiaObjectLastToPartition table.

    Raises
    ------
    RuntimeError
        Raised if query result has unexpected column names.
    """
    context = self._context
    config = context.config

    table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition)
    result_columns = ("diaObjectId", "apdb_part")

    # Query in chunks of 10k IDs, counting IDs along the way.
    queries = []
    object_count = 0
    for id_chunk in chunk_iterable(ids, 10_000):
        chunk_ids = tuple(id_chunk)
        object_count += len(chunk_ids)
        select = Select(self._keyspace, table_name, result_columns).where(
            C("diaObjectId").in_(chunk_ids)
        )
        queries.append(context.stmt_factory.with_params(select, prepare=False))

    with self._timer("query_object_last_partitions", tags={"table": table_name}) as timer:
        data = cast(
            ApdbTableData,
            select_concurrent(
                context.session,
                queries,
                "read_raw_multi",
                config.connection_config.read_concurrency,
            ),
        )
        timer.add_values(object_count=object_count, row_count=len(data.rows()))

    # Rows are interpreted positionally, so column order has to match.
    if data.column_names() != ["diaObjectId", "apdb_part"]:
        raise RuntimeError(f"Unexpected column names in query result: {data.column_names()}")

    return {row[0]: row[1] for row in data.rows()}

1224 

def _deleteMovingObjects(self, objs: pandas.DataFrame) -> None:
    """Remove old versions of objects that moved to another partition and
    update the object-to-partition map.

    Objects in DiaObjectLast can move from one spatial partition to
    another. For those objects inserting new version does not replace old
    one, so we need to explicitly remove old versions before inserting new
    ones.

    Parameters
    ----------
    objs : `pandas.DataFrame`
        Catalog of new object versions, its ``diaObjectId`` and
        ``apdb_part`` columns are used.
    """
    context = self._context

    # Partitions of the new versions and currently recorded partitions.
    new_partitions = dict(zip(objs["diaObjectId"], objs["apdb_part"]))
    old_partitions = self._queryDiaObjectLastPartitions(objs["diaObjectId"])

    # Objects whose recorded partition differs from the new one.
    moved_oids: dict[int, tuple[int, int]] = {
        oid: (old_part, new_part)
        for oid, old_part in old_partitions.items()
        if (new_part := new_partitions.get(oid, old_part)) != old_part
    }
    _LOG.debug("DiaObject IDs that moved to new partition: %s", moved_oids)

    if moved_oids:
        # Delete old records from DiaObjectLast.
        table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
        delete = Delete(self._keyspace, table_name).where(
            'apdb_part = {} AND "diaObjectId" = {}', (-1, -1)
        )
        statement = context.stmt_factory(delete, prepare=True)
        queries = [(statement, (old_part, oid)) for oid, (old_part, _) in moved_oids.items()]
        with self._timer("delete_object_last", tags={"table": table_name}) as timer:
            execute_concurrent(context.session, queries, execution_profile="write")
            timer.add_values(row_count=len(moved_oids))

    # Add all new records to the map.
    table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition)
    insert = Insert(self._keyspace, table_name, ("diaObjectId", "apdb_part"))
    statement = context.stmt_factory(insert, prepare=True)

    queries = [(statement, (oid, new_part)) for oid, new_part in new_partitions.items()]

    with self._timer("update_object_last_partition", tags={"table": table_name}) as timer:
        execute_concurrent(context.session, queries, execution_profile="write")
        timer.add_values(row_count=len(queries))

1269 

def _storeDiaObjects(
    self, objs: pandas.DataFrame, visit_time: astropy.time.Time, replica_chunk: ReplicaChunk | None
) -> None:
    """Store catalog of DiaObjects from current visit.

    The catalog is written to DiaObjectLast, optionally to DiaObject, to
    the replica chunk table when replication is configured, and to the
    deduplication table when that table exists.

    Parameters
    ----------
    objs : `pandas.DataFrame`
        Catalog with DiaObject records
    visit_time : `astropy.time.Time`
        Time of the current visit.
    replica_chunk : `ReplicaChunk` or `None`
        Replica chunk identifier if replication is configured.
    """
    if len(objs) == 0:
        _LOG.debug("No objects to write to database.")
        return

    context = self._context
    config = context.config

    self._deleteMovingObjects(objs)

    validity_start_column = self._timestamp_column_name("validityStart")
    timestamp = self._timestamp_column_value(visit_time)

    # DiaObjectLast did not have this column in the past.
    last_columns: dict[str, Any] = {}
    if context.schema.check_column(ApdbTables.DiaObjectLast, validity_start_column):
        last_columns[validity_start_column] = timestamp

    self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=last_columns)

    object_columns: dict[str, Any] = {validity_start_column: timestamp}
    visit_time_part = context.partitioner.time_partition(visit_time)
    time_partitions_range = context.time_partitions_range
    if time_partitions_range is not None:
        self._check_time_partitions([visit_time_part], time_partitions_range)

    time_part: int | None
    if config.partitioning.time_partition_tables:
        # Time partition selects per-partition table.
        time_part = visit_time_part
    else:
        # Single table, time partition is stored as a regular column.
        object_columns["apdb_time_part"] = visit_time_part
        time_part = None

    # Only store DiaObjects if not doing replication or explicitly
    # configured to always store them.
    if replica_chunk is None or not config.replica_skips_diaobjects:
        self._storeObjectsPandas(
            objs, ApdbTables.DiaObject, extra_columns=object_columns, time_part=time_part
        )

    if replica_chunk is not None:
        chunk_columns: dict[str, Any] = {
            "apdb_replica_chunk": replica_chunk.id,
            validity_start_column: timestamp,
        }
        table = ExtraTables.DiaObjectChunks
        if context.has_chunk_sub_partitions:
            table = ExtraTables.DiaObjectChunks2
            # Use a random number for a second part of partitioning key so
            # that different clients could write to different partitions.
            # This makes it not exactly reproducible.
            chunk_columns["apdb_replica_subchunk"] = random.randrange(config.replica_sub_chunk_count)
        self._storeObjectsPandas(objs, table, extra_columns=chunk_columns)

    # Store copy of the records in dedup table.
    if context.has_dedup_table:
        dedup_columns: dict[str, Any] = {
            "dedup_part": random.randrange(config.partitioning.num_part_dedup),
            validity_start_column: timestamp,
        }
        self._storeObjectsPandas(objs, ExtraTables.DiaObjectDedup, extra_columns=dedup_columns)

1338 

1339 def _storeDiaSources( 

1340 self, 

1341 table_name: ApdbTables, 

1342 sources: pandas.DataFrame, 

1343 replica_chunk: ReplicaChunk | None, 

1344 ) -> int | None: 

1345 """Store catalog of DIASources or DIAForcedSources from current visit. 

1346 

1347 Parameters 

1348 ---------- 

1349 table_name : `ApdbTables` 

1350 Table where to store the data. 

1351 sources : `pandas.DataFrame` 

1352 Catalog containing DiaSource records 

1353 visit_time : `astropy.time.Time` 

1354 Time of the current visit. 

1355 replica_chunk : `ReplicaChunk` or `None` 

1356 Replica chunk identifier if replication is configured. 

1357 

1358 Returns 

1359 ------- 

1360 subchunk : `int` or `None` 

1361 Subchunk number for resulting replica data, `None` if relication is 

1362 not enabled ot subchunking is not enabled. 

1363 """ 

1364 context = self._context 

1365 config = context.config 

1366 

1367 # Time partitioning has to be based on midpointMjdTai, not visit_time 

1368 # as visit_time is not really a visit time. 

1369 tp_sources = sources.copy(deep=False) 

1370 tp_sources["apdb_time_part"] = tp_sources["midpointMjdTai"].apply(context.partitioner.time_partition) 

1371 if (time_partitions_range := context.time_partitions_range) is not None: 

1372 self._check_time_partitions(tp_sources["apdb_time_part"], time_partitions_range) 

1373 extra_columns: dict[str, Any] = {} 

1374 if not config.partitioning.time_partition_tables: 

1375 self._storeObjectsPandas(tp_sources, table_name) 

1376 else: 

1377 # Group by time partition 

1378 partitions = set(tp_sources["apdb_time_part"]) 

1379 if len(partitions) == 1: 

1380 # Single partition - just save the whole thing. 

1381 time_part = partitions.pop() 

1382 self._storeObjectsPandas(sources, table_name, time_part=time_part) 

1383 else: 

1384 # group by time partition. 

1385 for time_part, sub_frame in tp_sources.groupby(by="apdb_time_part"): 

1386 sub_frame.drop(columns="apdb_time_part", inplace=True) 

1387 self._storeObjectsPandas(sub_frame, table_name, time_part=time_part) 

1388 

1389 subchunk: int | None = None 

1390 if replica_chunk is not None: 

1391 extra_columns = {"apdb_replica_chunk": replica_chunk.id} 

1392 if context.has_chunk_sub_partitions: 

1393 subchunk = random.randrange(config.replica_sub_chunk_count) 

1394 extra_columns["apdb_replica_subchunk"] = subchunk 

1395 if table_name is ApdbTables.DiaSource: 

1396 extra_table = ExtraTables.DiaSourceChunks2 

1397 else: 

1398 extra_table = ExtraTables.DiaForcedSourceChunks2 

1399 else: 

1400 if table_name is ApdbTables.DiaSource: 

1401 extra_table = ExtraTables.DiaSourceChunks 

1402 else: 

1403 extra_table = ExtraTables.DiaForcedSourceChunks 

1404 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns) 

1405 

1406 return subchunk 

1407 

1408 def _check_time_partitions( 

1409 self, partitions: Iterable[int], time_partitions_range: ApdbCassandraTimePartitionRange 

1410 ) -> None: 

1411 """Check that time partitons for new data actually exist. 

1412 

1413 Parameters 

1414 ---------- 

1415 partitions : `~collections.abc.Iterable` [`int`] 

1416 Time partitions for new data. 

1417 time_partitions_range : `ApdbCassandraTimePartitionRange` 

1418 Currrent time partition range. 

1419 """ 

1420 partitions = set(partitions) 

1421 min_part = min(partitions) 

1422 max_part = max(partitions) 

1423 if min_part < time_partitions_range.start or max_part > time_partitions_range.end: 

1424 raise ValueError( 

1425 "Attempt to store data for time partitions that do not yet exist. " 

1426 f"Partitons for new records: {min_part}-{max_part}. " 

1427 f"Database partitons: {time_partitions_range.start}-{time_partitions_range.end}." 

1428 ) 

1429 # Make a noise when writing to the last partition. 

1430 if max_part == time_partitions_range.end: 

1431 warnings.warn( 

1432 "Writing into the last temporal partition. Partition range needs to be extended soon.", 

1433 stacklevel=3, 

1434 ) 

1435 

def _storeDiaSourcesPartitions(
    self,
    sources: pandas.DataFrame,
    visit_time: astropy.time.Time,
    replica_chunk: ReplicaChunk | None,
    subchunk: int | None,
) -> None:
    """Store mapping of diaSourceId to its partitioning values.

    Parameters
    ----------
    sources : `pandas.DataFrame`
        Catalog containing DiaSource records
    visit_time : `astropy.time.Time`
        Time of the current visit. Only used for time partition when
        catalog has no ``midpointMjdTai`` column.
    replica_chunk : `ReplicaChunk` or `None`
        Replication chunk, or `None` when replication is disabled.
    subchunk : `int` or `None`
        Replication sub-chunk, or `None` when replication is disabled or
        sub-chunking is not used.

    Notes
    -----
    DiaSource records are time-partitioned by their ``midpointMjdTai``, so
    the time partition recorded in the mapping is calculated from the same
    column. Otherwise the mapping could point to a partition that does not
    hold the record.
    """
    context = self._context

    id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]])
    common_columns: dict[str, Any] = {
        "apdb_replica_chunk": replica_chunk.id if replica_chunk is not None else None,
    }
    if context.has_chunk_sub_partitions:
        common_columns["apdb_replica_subchunk"] = subchunk

    groups: Iterable[tuple[Any, pandas.DataFrame]]
    if "midpointMjdTai" in sources.columns:
        # Per-record partition; use positional array so that grouping does
        # not depend on catalog index.
        time_parts = sources["midpointMjdTai"].apply(context.partitioner.time_partition).to_numpy()
        groups = ((int(part), sub_map) for part, sub_map in id_map.groupby(time_parts))
    else:
        groups = [(context.partitioner.time_partition(visit_time), id_map)]

    # Time partition is a fixed value within each group, so it is passed
    # the same way as other fixed columns.
    for time_part, sub_map in groups:
        extra_columns = {"apdb_time_part": time_part, **common_columns}
        self._storeObjectsPandas(
            sub_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None
        )

1470 

def _storeObjectsPandas(
    self,
    records: pandas.DataFrame,
    table_name: ApdbTables | ExtraTables,
    extra_columns: Mapping | None = None,
    time_part: int | None = None,
) -> None:
    """Store generic objects.

    Takes Pandas catalog and stores a bunch of records in a table. Records
    are inserted in batches, each batch only has records with the same
    partition key.

    Parameters
    ----------
    records : `pandas.DataFrame`
        Catalog containing object records
    table_name : `ApdbTables` or `ExtraTables`
        Name of the table as defined in APDB schema.
    extra_columns : `~collections.abc.Mapping`, optional
        Mapping (column_name, column_value) which gives fixed values for
        columns in each row, overrides values in ``records`` if matching
        columns exist there.
    time_part : `int`, optional
        If not `None` then insert into a per-partition table.

    Raises
    ------
    ValueError
        Raised if partition or clustering columns are missing.

    Notes
    -----
    If Pandas catalog contains additional columns not defined in table
    schema they are ignored. Catalog does not have to contain all columns
    defined in a table, but partition and clustering keys must be present
    in a catalog or ``extra_columns``.
    """
    context = self._context

    # Use extra columns if specified.
    if extra_columns is None:
        extra_columns = {}
    extra_fields = list(extra_columns.keys())

    # Fields that will come from dataframe.
    df_fields = [name for name in records.columns if name not in extra_fields]

    column_map = context.schema.getColumnMap(table_name)
    # Dataframe columns that are not in table schema are skipped.
    known_df_fields = [name for name in df_fields if name in column_map]
    # List of columns (as in felis schema).
    fields = [column_map[name].name for name in known_df_fields]
    fields += extra_fields

    # Check that all partitioning and clustering columns are defined.
    partition_columns = context.schema.partitionColumns(table_name)
    required_columns = partition_columns + context.schema.clusteringColumns(table_name)
    missing_columns = [column for column in required_columns if column not in fields]
    if missing_columns:
        raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")

    batch_size = self._batch_size(table_name)

    with self._timer("insert_build_time", tags={"table": table_name.name}):
        # Multi-partition batches are problematic in general, so we want to
        # group records in a batch by their partition key.
        values_by_key: dict[tuple, list[list]] = defaultdict(list)
        for rec in records.itertuples(index=False):
            row_values: list = []
            key_parts: dict[str, Any] = {}

            # Values from the catalog come first, in the order of ``fields``.
            for name in known_df_fields:
                value = getattr(rec, name)
                if column_map[name].datatype is felis.datamodel.DataType.timestamp:
                    if isinstance(value, pandas.Timestamp):
                        value = value.to_pydatetime()
                    elif value is pandas.NaT:
                        value = None
                    else:
                        # Assume it's seconds since epoch, Cassandra
                        # datetime is in milliseconds
                        value = int(value * 1000)
                value = literal(value)
                row_values.append(UNSET_VALUE if value is None else value)
                if name in partition_columns:
                    key_parts[name] = value

            # Fixed values follow.
            for name in extra_fields:
                value = literal(extra_columns[name])
                row_values.append(UNSET_VALUE if value is None else value)
                if name in partition_columns:
                    key_parts[name] = value

            partition_key = tuple(key_parts[column] for column in partition_columns)
            values_by_key[partition_key].append(row_values)

        table = context.schema.tableName(table_name, time_part)

        insert = Insert(self._keyspace, table, fields)
        statement = context.stmt_factory(insert, prepare=True)
        # Cassandra has 64k limit on batch size, normally that should be
        # enough but some tests generate too many forced sources.
        queries = []
        for key_values in values_by_key.values():
            for values_chunk in chunk_iterable(key_values, batch_size):
                batch = cassandra.query.BatchStatement()
                for one_row in values_chunk:
                    batch.add(statement, one_row)
                queries.append((batch, None))
                assert batch.routing_key is not None and batch.keyspace is not None

    _LOG.debug("%s: will store %d records", context.schema.tableName(table_name), records.shape[0])
    with self._timer(
        "insert_time", tags={"table": table_name.name, "method": "_storeObjectsPandas"}
    ) as timer:
        execute_concurrent(context.session, queries, execution_profile="write")
        timer.add_values(row_count=len(records), num_batches=len(queries))

1580 

def _storeUpdateRecords(
    self, records: Iterable[ApdbUpdateRecord], chunk: ReplicaChunk, *, store_chunk: bool = False
) -> None:
    """Write update records to the replica table holding update records.

    Parameters
    ----------
    records : `~collections.abc.Iterable` [`ApdbUpdateRecord`]
        Records to store.
    chunk : `ReplicaChunk`
        Replica chunk that these records belong to.
    store_chunk : `bool`, optional
        If `True` then the replica chunk itself is stored before the
        records.

    Raises
    ------
    TypeError
        Raised if replication is not enabled for this instance.
    """
    context = self._context
    config = context.config

    if not context.schema.replication_enabled:
        raise TypeError("Replication is not enabled for this APDB instance.")

    if store_chunk:
        self._storeReplicaChunk(chunk)

    chunk_id = chunk.id
    # A fresh UUID is generated on every call instead of taking the unique
    # ID of the ReplicaChunk, because the same chunk can be passed to
    # several calls of this method.
    update_unique_id = uuid.uuid4()

    column_names = [
        "apdb_replica_chunk",
        "update_time_ns",
        "update_order",
        "update_unique_id",
        "update_payload",
    ]
    rows = [
        [chunk_id, record.update_time_ns, record.update_order, update_unique_id, record.to_json()]
        for record in records
    ]

    if context.has_chunk_sub_partitions:
        # One randomly chosen sub-chunk is shared by all rows of this call.
        subchunk = random.randrange(config.replica_sub_chunk_count)
        for row in rows:
            row.append(subchunk)
        column_names.append("apdb_replica_subchunk")

    table_name = context.schema.tableName(ExtraTables.ApdbUpdateRecordChunks)
    insert = Insert(self._keyspace, table_name, column_names)
    statement = context.stmt_factory(insert)
    # Every row is sent as a separate execution of the same statement.
    queries = [(statement, row) for row in rows]

    with self._timer("store_update_record", tags={"table": table_name}) as timer:
        execute_concurrent(context.session, queries, execution_profile="write")
        timer.add_values(row_count=len(queries))

1646 

def _add_apdb_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
    """Return a copy of a DataFrame with the spatial partition of every
    record added to it.

    Parameters
    ----------
    df : `pandas.DataFrame`
        DataFrame that must contain ra/dec columns; the names of these
        columns are taken from the configuration ``ra_dec_columns`` field.

    Returns
    -------
    df : `pandas.DataFrame`
        Copy of the input with an ``apdb_part`` column holding the pixel
        index computed from the ra/dec coordinates of each row.

    Notes
    -----
    An existing ``apdb_part`` column is overwritten in the returned copy.
    The DataFrame passed in is left untouched.
    """
    context = self._context
    ra_name, dec_name = context.config.ra_dec_columns

    # One pixel index per row, filled in row order.
    pixel_indices = np.zeros(df.shape[0], dtype=np.int64)
    for row, (ra, dec) in enumerate(zip(df[ra_name], df[dec_name])):
        direction = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
        pixel_indices[row] = context.partitioner.pixel(direction)

    result = df.copy()
    result["apdb_part"] = pixel_indices
    return result

1682 

def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
    """Build a catalog without rows for the given table.

    Parameters
    ----------
    table_name : `ApdbTables`
        Name of the table.

    Returns
    -------
    catalog : `pandas.DataFrame`
        An empty catalog with one column per column of the table schema,
        each created with the ``pandas_type`` of its column definition.
    """
    table_schema = self.schema.tableSchemas[table_name]

    empty_columns = {}
    for column in table_schema.columns:
        empty_columns[column.name] = pandas.Series(dtype=column.pandas_type)
    return pandas.DataFrame(empty_columns)

1700 

def _fix_input_timestamps(self, df: pandas.DataFrame) -> pandas.DataFrame:
    """Convert timezone-aware timestamp columns of the input DataFrame to
    naive datetime type.

    Clients may or may not generate aware timestamps, while code in this
    class assumes that timestamps are naive, so aware columns are converted
    to UTC and their timezone is dropped.

    Parameters
    ----------
    df : `pandas.DataFrame`
        Input DataFrame, its aware timestamp columns are replaced in place.

    Returns
    -------
    df : `pandas.DataFrame`
        The same DataFrame object that was passed in.
    """
    for name, dtype in df.dtypes.items():
        if not isinstance(dtype, pandas.DatetimeTZDtype):
            # Not an aware timestamp column, nothing to do.
            continue
        # tz_convert(None) will convert to UTC and drop timezone.
        df[name] = df[name].dt.tz_convert(None)
    return df

1715 

def _batch_size(self, table: ApdbTables | ExtraTables) -> int:
    """Calculate batch size based on config parameters.

    Parameters
    ----------
    table : `ApdbTables` or `ExtraTables`
        Table for which the batch size is calculated.

    Returns
    -------
    batch_size : `int`
        Maximum number of rows to include in one batch.
    """
    context = self._context
    config = context.config

    # Cassandra limit on number of statements in a batch is 64k.
    max_statements = 65_535
    statement_limit = config.batch_statement_limit
    # A configured statement limit is honored only when it is positive and
    # below the 64k limit.
    batch_size = statement_limit if 0 < statement_limit < max_statements else max_statements

    if config.batch_size_limit <= 0:
        # No limit on the batch size in bytes is configured.
        return batch_size

    # The purpose of this limit is to try not to exceed batch size
    # threshold which is set on server side. Cassandra wire protocol for
    # prepared queries (and batches) only sends column values with an
    # additional 4 bytes per value specifying size. Value is not included
    # for NULL or NOT_SET values, but the size is always there. There is
    # additional small per-query overhead, which we ignore.
    bytes_per_row = context.schema.table_row_size(table)
    bytes_per_row += 4 * len(context.schema.getColumnMap(table))
    rows_in_limit = (config.batch_size_limit // bytes_per_row) + 1
    return min(batch_size, rows_in_limit)

1737 

def _group_dia_objects_by_partition(
    self, partitioner: Partitioner, objects: list[DiaObjectId], pad_arcsec: float
) -> Mapping[int, list[int]]:
    """Group DiaObject IDs by the spatial partitions that they fall into.

    Parameters
    ----------
    partitioner : `Partitioner`
        Object which knows how to partition things.
    objects : `list` [`DiaObjectId`]
        Collection of objects to partition.
    pad_arcsec : `float`
        Additional padding around object position.

    Returns
    -------
    grouped_objects
        Mapping of spatial partition ID to the list of object IDs that it
        contains. Some objects may belong to more than one partition.
    """
    grouped: dict[int, list[int]] = defaultdict(list)
    for dia_object in objects:
        # The padded circle around an object can cover several pixels, the
        # object ID is registered under each of them.
        for pixel in partitioner.pixelization.circle_pixels(dia_object.ra, dia_object.dec, pad_arcsec):
            grouped[pixel].append(dia_object.diaObjectId)
    return grouped

1764 

def _timestamp_column_name(self, column: str) -> str:
    """Return column name before/after schema migration to MJD TAI.

    Parameters
    ----------
    column : `str`
        Column name which is passed to the schema for translation.

    Returns
    -------
    name : `str`
        Column name returned by the schema ``timestamp_column_name`` method.
    """
    schema = self._schema
    return schema.timestamp_column_name(column)

1768 

1769 def _timestamp_column_value(self, time: astropy.time.Time) -> float | int: 

1770 """Return column value before/after schema migration to MJD TAI.""" 

1771 if self._schema.has_mjd_timestamps: 

1772 return float(time.tai.mjd) 

1773 else: 

1774 return int(time.datetime.astimezone(tz=datetime.UTC).timestamp() * 1000) 

1775 

def _get_diasource_data(self, source_ids: Iterable[DiaSourceId], *columns: str) -> list:
    """Select records from DiaSource table by diaSourceId and return all
    records as a list of named tuples.

    Parameters
    ----------
    source_ids : `~collections.abc.Iterable` [`DiaSourceId`]
        Identifiers of the sources to select; position and time of each
        identifier are used to locate its partitions.
    *columns : `str`
        Names of the columns to return in addition to ``diaSourceId``.

    Returns
    -------
    records : `list`
        All records returned by the queries.
    """
    context = self._context
    config = context.config
    partitioner = context.partitioner

    selected_columns = ("diaSourceId", *columns)

    # Allow some uncertainty for coordinates and time when calculating
    # partitions: 1 arcsecond around position, 10 seconds around time.
    position_pad_arcsec = 1.0
    time_pad_days = 10 / (24 * 3600)

    statements: list[tuple] = []
    for source_id in source_ids:
        lon_lat = sphgeom.LonLat.fromDegrees(source_id.ra, source_id.dec)
        circle = sphgeom.Circle(
            sphgeom.UnitVector3d(lon_lat), sphgeom.Angle.fromDegrees(position_pad_arcsec / 3600.0)
        )
        spatial_where, _ = partitioner.spatial_where(circle)

        tables, temporal_where = partitioner.temporal_where(
            ApdbTables.DiaSource,
            source_id.midpointMjdTai - time_pad_days,
            source_id.midpointMjdTai + time_pad_days,
            partitons_range=context.time_partitions_range,
            query_per_time_part=True,
        )

        id_where = QExpr('"diaSourceId" = {}', (source_id.diaSourceId,))

        # One statement per table and per combination of spatial and
        # temporal constraints, each restricted to this source ID.
        for table in tables:
            select = Select(self._keyspace, table, selected_columns)
            statements.extend(
                context.stmt_factory.with_params(select.where(clause), prepare=True)
                for clause in QExpr.combine(spatial_where, temporal_where, extra=id_where)
            )

    timer_tags = {"table": "DiaSource", "method": "_get_diasource_data"}
    with self._timer("select_time", tags=timer_tags) as timer:
        rows = select_concurrent(
            context.session,
            statements,
            "read_named_tuples",
            config.connection_config.read_concurrency,
        )
        result = cast(list[tuple], rows)
        timer.add_values(row_count=len(result), num_queries=len(statements))

    return result

1826 

def _get_diaobject_data(self, object_ids: Iterable[DiaObjectId], *columns: str) -> list:
    """Select records from DiaObjectLast table by diaObjectId and return
    all records as a list of named tuples.

    Parameters
    ----------
    object_ids : `~collections.abc.Iterable` [`DiaObjectId`]
        Identifiers of the objects to select; position of each identifier
        is used to locate its spatial partitions.
    *columns : `str`
        Names of the columns to return in addition to ``diaObjectId``.

    Returns
    -------
    records : `list`
        All records returned by the queries.
    """
    context = self._context
    config = context.config
    partitioner = context.partitioner

    table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
    selected_columns = ("diaObjectId", *columns)

    # Allow some uncertainty for coordinates when calculating partitions.
    position_pad_arcsec = 1.0
    ids_by_partition = defaultdict(list)
    for object_id in object_ids:
        for pixel in partitioner.pixelization.circle_pixels(
            object_id.ra, object_id.dec, position_pad_arcsec
        ):
            ids_by_partition[pixel].append(object_id.diaObjectId)

    # One query per spatial partition, selecting all IDs that may be in it.
    # NOTE(review): statements are not prepared here, presumably because
    # the length of the IN list varies between queries - confirm.
    statements: list[tuple] = [
        context.stmt_factory.with_params(
            Select(self._keyspace, table_name, selected_columns)
            .where(C("apdb_part") == apdb_part)
            .where(C("diaObjectId").in_(partition_ids)),
            prepare=False,
        )
        for apdb_part, partition_ids in ids_by_partition.items()
    ]

    timer_tags = {"table": "DiaObjectLast", "method": "_get_diaobject_data"}
    with self._timer("select_time", tags=timer_tags) as timer:
        rows = select_concurrent(
            context.session,
            statements,
            "read_named_tuples",
            config.connection_config.read_concurrency,
        )
        result = cast(list[tuple], rows)
        timer.add_values(row_count=len(result), num_queries=len(statements))

    return result