Coverage for python / lsst / dax / apdb / cassandra / apdbCassandra.py: 9%
811 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 01:30 -0700
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 01:30 -0700
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ApdbCassandra"]
26import datetime
27import json
28import logging
29import random
30import uuid
31import warnings
32from collections import Counter, defaultdict
33from collections.abc import Iterable, Mapping, Set
34from typing import TYPE_CHECKING, Any, cast
36import numpy as np
37import pandas
39# If cassandra-driver is not there the module can still be imported
40# but ApdbCassandra cannot be instantiated.
41try:
42 import cassandra
43 import cassandra.query
44 from cassandra.query import UNSET_VALUE
46 CASSANDRA_IMPORTED = True
47except ImportError:
48 CASSANDRA_IMPORTED = False
50import astropy.time
51import felis.datamodel
53from lsst import sphgeom
54from lsst.utils.iteration import chunk_iterable
56from ..apdb import Apdb, ApdbConfig
57from ..apdbConfigFreezer import ApdbConfigFreezer
58from ..apdbReplica import ApdbTableData, ReplicaChunk
59from ..apdbSchema import ApdbSchema, ApdbTables
60from ..apdbUpdateRecord import (
61 ApdbCloseDiaObjectValidityRecord,
62 ApdbReassignDiaSourceToDiaObjectRecord,
63 ApdbUpdateNDiaSourcesRecord,
64)
65from ..monitor import MonAgent
66from ..recordIds import DiaObjectId, DiaSourceId
67from ..schema_model import Table
68from ..timer import Timer
69from ..versionTuple import VersionTuple
70from .apdbCassandraAdmin import ApdbCassandraAdmin
71from .apdbCassandraReplica import ApdbCassandraReplica
72from .apdbCassandraSchema import ApdbCassandraSchema, CreateTableOptions, ExtraTables
73from .apdbMetadataCassandra import ApdbMetadataCassandra
74from .cassandra_utils import (
75 ApdbCassandraTableData,
76 execute_concurrent,
77 literal,
78 select_concurrent,
79)
80from .config import ApdbCassandraConfig, ApdbCassandraConnectionConfig, ApdbCassandraTimePartitionRange
81from .connectionContext import ConnectionContext, DbVersions
82from .exceptions import CassandraMissingError
83from .partitioner import Partitioner
84from .queries import Column as C # noqa: N817
85from .queries import ColumnExpr, Delete, Insert, QExpr, Select, Update
86from .sessionFactory import SessionContext, SessionFactory
88if TYPE_CHECKING:
89 from ..apdbMetadata import ApdbMetadata
90 from ..apdbUpdateRecord import ApdbUpdateRecord
# Module logger.
_LOG: logging.Logger = logging.getLogger(__name__)

# Monitoring agent shared by every timer created in this module
# (see ``ApdbCassandra._timer``).
_MON = MonAgent(__name__)

VERSION = VersionTuple(1, 3, 0)
"""Version for the code controlling non-replication tables. This needs to be
updated following compatibility rules when schema produced by this code
changes.
"""
103class ApdbCassandra(Apdb):
104 """Implementation of APDB database with Apache Cassandra backend.
106 Parameters
107 ----------
108 config : `ApdbCassandraConfig`
109 Configuration object.
110 """
112 def __init__(self, config: ApdbCassandraConfig):
113 if not CASSANDRA_IMPORTED:
114 raise CassandraMissingError()
116 self._config = config
117 self._keyspace = config.keyspace
118 self._schema = ApdbSchema(config.schema_file, config.ss_schema_file)
120 self._session_factory = SessionFactory(config)
121 self._connection_context: ConnectionContext | None = None
123 @property
124 def _context(self) -> ConnectionContext:
125 """Establish connection if not established and return context."""
126 if self._connection_context is None:
127 current_versions = DbVersions(
128 schema_version=self.schema.schemaVersion(),
129 code_version=self.apdbImplementationVersion(),
130 replica_version=ApdbCassandraReplica.apdbReplicaImplementationVersion(),
131 )
132 _LOG.debug("Current versions: %s", current_versions)
134 session = self._session_factory.session()
135 self._connection_context = ConnectionContext(
136 session, self._config, self.schema.tableSchemas, current_versions
137 )
139 if _LOG.isEnabledFor(logging.DEBUG):
140 _LOG.debug("ApdbCassandra Configuration: %s", self._connection_context.config.model_dump())
142 return self._connection_context
144 def _timer(self, name: str, *, tags: Mapping[str, str | int] | None = None) -> Timer:
145 """Create `Timer` instance given its name."""
146 return Timer(name, _MON, tags=tags)
148 @classmethod
149 def apdbImplementationVersion(cls) -> VersionTuple:
150 """Return version number for current APDB implementation.
152 Returns
153 -------
154 version : `VersionTuple`
155 Version of the code defined in implementation class.
156 """
157 return VERSION
159 def getConfig(self) -> ApdbCassandraConfig:
160 # docstring is inherited from a base class
161 return self._context.config
163 def tableDef(self, table: ApdbTables) -> Table | None:
164 # docstring is inherited from a base class
165 return self.schema.tableSchemas.get(table)
167 @classmethod
168 def init_database(
169 cls,
170 hosts: tuple[str, ...],
171 keyspace: str,
172 *,
173 schema_file: str | None = None,
174 ss_schema_file: str | None = None,
175 read_sources_months: int | None = None,
176 read_forced_sources_months: int | None = None,
177 enable_replica: bool = False,
178 replica_skips_diaobjects: bool = False,
179 port: int | None = None,
180 username: str | None = None,
181 dbauth_alias: str | None = None,
182 prefix: str | None = None,
183 part_pixelization: str | None = None,
184 part_pix_level: int | None = None,
185 time_partition_tables: bool = True,
186 time_partition_start: str | None = None,
187 time_partition_end: str | None = None,
188 read_consistency: str | None = None,
189 write_consistency: str | None = None,
190 read_timeout: int | None = None,
191 write_timeout: int | None = None,
192 ra_dec_columns: tuple[str, str] | None = None,
193 replication_factor: int | None = None,
194 drop: bool = False,
195 table_options: CreateTableOptions | None = None,
196 ) -> ApdbCassandraConfig:
197 """Initialize new APDB instance and make configuration object for it.
199 Parameters
200 ----------
201 hosts : `tuple` [`str`, ...]
202 List of host names or IP addresses for Cassandra cluster.
203 keyspace : `str`
204 Name of the keyspace for APDB tables.
205 schema_file : `str`, optional
206 Location of (YAML) configuration file with APDB schema. If not
207 specified then default location will be used.
208 ss_schema_file : `str`, optional
209 Location of (YAML) configuration file with SSO schema. If not
210 specified then default location will be used.
211 read_sources_months : `int`, optional
212 Number of months of history to read from DiaSource.
213 read_forced_sources_months : `int`, optional
214 Number of months of history to read from DiaForcedSource.
215 enable_replica : `bool`, optional
216 If True, make additional tables used for replication to PPDB.
217 replica_skips_diaobjects : `bool`, optional
218 If `True` then do not fill regular ``DiaObject`` table when
219 ``enable_replica`` is `True`.
220 port : `int`, optional
221 Port number to use for Cassandra connections.
222 username : `str`, optional
223 User name for Cassandra connections.
224 dbauth_alias : `str`, optional
225 If specified then this string will be used to as a host name when
226 checking credentials in db-auth.yaml in addition to regular host
227 names in contact_points. For example if
228 dbauth_alias='pp_apdb_prod_cluster' then the entry
229 'cassandra://pp_apdb_prod_cluster/' will match. Port number should
230 not be used in that entry. Alias has higher priority than host
231 names.
232 prefix : `str`, optional
233 Optional prefix for all table names.
234 part_pixelization : `str`, optional
235 Name of the MOC pixelization used for partitioning.
236 part_pix_level : `int`, optional
237 Pixelization level.
238 time_partition_tables : `bool`, optional
239 Create per-partition tables.
240 time_partition_start : `str`, optional
241 Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss
242 format, in TAI.
243 time_partition_end : `str`, optional
244 Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss
245 format, in TAI.
246 read_consistency : `str`, optional
247 Name of the consistency level for read operations.
248 write_consistency : `str`, optional
249 Name of the consistency level for write operations.
250 read_timeout : `int`, optional
251 Read timeout in seconds.
252 write_timeout : `int`, optional
253 Write timeout in seconds.
254 ra_dec_columns : `tuple` [`str`, `str`], optional
255 Names of ra/dec columns in DiaObject table.
256 replication_factor : `int`, optional
257 Replication factor used when creating new keyspace, if keyspace
258 already exists its replication factor is not changed.
259 drop : `bool`, optional
260 If `True` then drop existing tables before re-creating the schema.
261 table_options : `CreateTableOptions`, optional
262 Options used when creating Cassandra tables.
264 Returns
265 -------
266 config : `ApdbCassandraConfig`
267 Resulting configuration object for a created APDB instance.
268 """
269 # Some non-standard defaults for connection parameters, these can be
270 # changed later in generated config. Check Cassandra driver
271 # documentation for what these parameters do. These parameters are not
272 # used during database initialization, but they will be saved with
273 # generated config.
274 connection_config = ApdbCassandraConnectionConfig(
275 extra_parameters={
276 "idle_heartbeat_interval": 0,
277 "idle_heartbeat_timeout": 30,
278 "control_connection_timeout": 100,
279 },
280 )
281 config = ApdbCassandraConfig(
282 contact_points=hosts,
283 keyspace=keyspace,
284 enable_replica=enable_replica,
285 replica_skips_diaobjects=replica_skips_diaobjects,
286 connection_config=connection_config,
287 )
288 config.partitioning.time_partition_tables = time_partition_tables
289 if schema_file is not None:
290 config.schema_file = schema_file
291 if ss_schema_file is not None:
292 config.ss_schema_file = ss_schema_file
293 if read_sources_months is not None:
294 config.read_sources_months = read_sources_months
295 if read_forced_sources_months is not None:
296 config.read_forced_sources_months = read_forced_sources_months
297 if port is not None:
298 config.connection_config.port = port
299 if username is not None:
300 config.connection_config.username = username
301 if dbauth_alias is not None:
302 config.connection_config.dbauth_alias = dbauth_alias
303 if prefix is not None:
304 config.prefix = prefix
305 if part_pixelization is not None:
306 config.partitioning.part_pixelization = part_pixelization
307 if part_pix_level is not None:
308 config.partitioning.part_pix_level = part_pix_level
309 if time_partition_start is not None:
310 config.partitioning.time_partition_start = time_partition_start
311 if time_partition_end is not None:
312 config.partitioning.time_partition_end = time_partition_end
313 if read_consistency is not None:
314 config.connection_config.read_consistency = read_consistency
315 if write_consistency is not None:
316 config.connection_config.write_consistency = write_consistency
317 if read_timeout is not None:
318 config.connection_config.read_timeout = read_timeout
319 if write_timeout is not None:
320 config.connection_config.write_timeout = write_timeout
321 if ra_dec_columns is not None:
322 config.ra_dec_columns = ra_dec_columns
324 cls._makeSchema(config, drop=drop, replication_factor=replication_factor, table_options=table_options)
326 return config
328 def get_replica(self) -> ApdbCassandraReplica:
329 """Return `ApdbReplica` instance for this database."""
330 # Note that this instance has to stay alive while replica exists, so
331 # we pass reference to self.
332 return ApdbCassandraReplica(self)
334 @classmethod
335 def _makeSchema(
336 cls,
337 config: ApdbConfig,
338 *,
339 drop: bool = False,
340 replication_factor: int | None = None,
341 table_options: CreateTableOptions | None = None,
342 ) -> None:
343 # docstring is inherited from a base class
345 if not isinstance(config, ApdbCassandraConfig):
346 raise TypeError(f"Unexpected type of configuration object: {type(config)}")
348 simple_schema = ApdbSchema(config.schema_file, config.ss_schema_file)
350 with SessionContext(config) as session:
351 schema = ApdbCassandraSchema(
352 session=session,
353 keyspace=config.keyspace,
354 table_schemas=simple_schema.tableSchemas,
355 prefix=config.prefix,
356 time_partition_tables=config.partitioning.time_partition_tables,
357 enable_replica=config.enable_replica,
358 replica_skips_diaobjects=config.replica_skips_diaobjects,
359 )
361 # Ask schema to create all tables.
362 part_range_config: ApdbCassandraTimePartitionRange | None = None
363 if config.partitioning.time_partition_tables:
364 partitioner = Partitioner(config)
365 time_partition_start = astropy.time.Time(
366 config.partitioning.time_partition_start, format="isot", scale="tai"
367 )
368 time_partition_end = astropy.time.Time(
369 config.partitioning.time_partition_end, format="isot", scale="tai"
370 )
371 part_range_config = ApdbCassandraTimePartitionRange(
372 start=partitioner.time_partition(time_partition_start),
373 end=partitioner.time_partition(time_partition_end),
374 )
375 schema.makeSchema(
376 drop=drop,
377 part_range=part_range_config,
378 replication_factor=replication_factor,
379 table_options=table_options,
380 )
381 else:
382 schema.makeSchema(
383 drop=drop, replication_factor=replication_factor, table_options=table_options
384 )
386 meta_table_name = ApdbTables.metadata.table_name(config.prefix)
387 metadata = ApdbMetadataCassandra(
388 session, meta_table_name, config.keyspace, "read_tuples", "write"
389 )
391 # Fill version numbers, overrides if they existed before.
392 metadata.set(
393 ConnectionContext.metadataSchemaVersionKey, str(simple_schema.schemaVersion()), force=True
394 )
395 metadata.set(
396 ConnectionContext.metadataCodeVersionKey, str(cls.apdbImplementationVersion()), force=True
397 )
399 if config.enable_replica:
400 # Only store replica code version if replica is enabled.
401 metadata.set(
402 ConnectionContext.metadataReplicaVersionKey,
403 str(ApdbCassandraReplica.apdbReplicaImplementationVersion()),
404 force=True,
405 )
407 # Store frozen part of a configuration in metadata.
408 freezer = ApdbConfigFreezer[ApdbCassandraConfig](ConnectionContext.frozen_parameters)
409 metadata.set(ConnectionContext.metadataConfigKey, freezer.to_json(config), force=True)
411 # Store time partition range.
412 if part_range_config:
413 part_range_config.save_to_meta(metadata)
415 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
416 # docstring is inherited from a base class
417 context = self._context
418 config = context.config
420 sp_where, num_sp_part = context.partitioner.spatial_where(region)
421 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where))
423 # We need to exclude extra partitioning columns from result.
424 column_names = context.schema.apdbColumnNames(ApdbTables.DiaObjectLast)
425 table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
426 query = Select(self._keyspace, table_name, column_names)
427 statements: list[tuple] = []
428 for where_clause in sp_where:
429 full_query = query.where(where_clause)
430 statements.append(context.stmt_factory.with_params(full_query, prepare=True))
431 _LOG.debug("getDiaObjects: #queries: %s", len(statements))
433 with self._timer("select_time", tags={"table": "DiaObject", "method": "getDiaObjects"}) as timer:
434 raw_objects = cast(
435 ApdbCassandraTableData,
436 select_concurrent(
437 context.session,
438 statements,
439 "read_raw_multi",
440 config.connection_config.read_concurrency,
441 ),
442 )
443 objects = raw_objects.to_pandas(context.schema._table_schema(ApdbTables.DiaObjectLast))
444 timer.add_values(row_count=len(objects), num_sp_part=num_sp_part, num_queries=len(statements))
446 _LOG.debug("found %s DiaObjects", objects.shape[0])
447 return objects
449 def getDiaSources(
450 self,
451 region: sphgeom.Region,
452 object_ids: Iterable[int] | None,
453 visit_time: astropy.time.Time,
454 start_time: astropy.time.Time | None = None,
455 ) -> pandas.DataFrame | None:
456 # docstring is inherited from a base class
457 context = self._context
458 config = context.config
460 months = config.read_sources_months
461 if start_time is None and months == 0:
462 return None
464 mjd_end = float(visit_time.tai.mjd)
465 if start_time is None:
466 mjd_start = mjd_end - months * 30
467 else:
468 mjd_start = float(start_time.tai.mjd)
470 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource)
472 def getDiaForcedSources(
473 self,
474 region: sphgeom.Region,
475 object_ids: Iterable[int] | None,
476 visit_time: astropy.time.Time,
477 start_time: astropy.time.Time | None = None,
478 ) -> pandas.DataFrame | None:
479 # docstring is inherited from a base class
480 context = self._context
481 config = context.config
483 months = config.read_forced_sources_months
484 if start_time is None and months == 0:
485 return None
487 mjd_end = float(visit_time.tai.mjd)
488 if start_time is None:
489 mjd_start = mjd_end - months * 30
490 else:
491 mjd_start = float(start_time.tai.mjd)
493 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource)
495 def getDiaObjectsForDedup(self, since: astropy.time.Time | None = None) -> pandas.DataFrame:
496 # docstring is inherited from a base class
497 context = self._context
498 config = context.config
500 if not context.has_dedup_table:
501 raise TypeError("DiaObjectDedup table does not exist in this APDB instance.")
503 if since is None:
504 # Read last deduplication time from metadata.
505 dedup_str = context.metadata.get(context.metadataDedupKey)
506 if dedup_str is not None:
507 dedup_state = json.loads(dedup_str)
508 dedup_time_str = dedup_state["dedup_time_iso_tai"]
509 since = astropy.time.Time(dedup_time_str, format="iso", scale="tai")
511 column_names = context.schema.apdbColumnNames(ExtraTables.DiaObjectDedup)
513 validity_start_column = self._timestamp_column_name("validityStart")
514 timestamp = None if since is None else self._timestamp_column_value(since)
516 table_name = context.schema.tableName(ExtraTables.DiaObjectDedup)
517 query = Select(self._keyspace, table_name, column_names, extra_clause="ALLOW FILTERING")
518 query = query.where(C("dedup_part") == 0)
519 if since is not None:
520 query = query.where(C(validity_start_column) >= 0)
522 statement = context.stmt_factory(query, prepare=False)
524 num_part = config.partitioning.num_part_dedup
525 statements = []
526 for dedup_part in range(num_part):
527 params = (dedup_part,) if timestamp is None else (dedup_part, timestamp)
528 statements.append((statement, params))
530 with self._timer(
531 "select_time", tags={"table": "DiaObjectDedup", "method": "getDiaObjectsForDedup"}
532 ) as timer:
533 objects_raw = cast(
534 ApdbCassandraTableData,
535 select_concurrent(
536 context.session,
537 statements,
538 "read_raw_multi_dedup",
539 config.connection_config.read_concurrency,
540 ),
541 )
542 objects = objects_raw.to_pandas(context.schema._table_schema(ExtraTables.DiaObjectDedup))
543 timer.add_values(row_count=len(objects), num_queries=num_part)
545 _LOG.debug("found %s DiaObjectDedup records", objects.shape[0])
546 return objects
548 def getDiaSourcesForDiaObjects(
549 self, objects: list[DiaObjectId], start_time: astropy.time.Time, max_dist_arcsec: float = 1.0
550 ) -> pandas.DataFrame:
551 # docstring is inherited from a base class
552 context = self._context
553 config = context.config
555 # Which tables to query and temporal constraints.
556 end_time = self._current_time()
557 tables, temporal_where = context.partitioner.temporal_where(
558 ApdbTables.DiaSource,
559 start_time,
560 end_time,
561 partitons_range=context.time_partitions_range,
562 query_per_time_part=False,
563 )
564 if not tables:
565 warnings.warn(
566 f"Query time range ({start_time.isot} - {end_time.isot}) does not overlap database "
567 "time partitions."
568 )
570 # Group DiaObjects by partition.
571 partitioned_object_ids = self._group_dia_objects_by_partition(
572 context.partitioner, objects, max_dist_arcsec
573 )
575 # Columns to return.
576 column_names = context.schema.apdbColumnNames(ApdbTables.DiaSource)
578 # Make a bunch of queries.
579 statements = []
580 for apdb_part, diaObjectIds in partitioned_object_ids.items():
581 spatial_where = [C("apdb_part") == apdb_part]
582 for table in tables:
583 query = Select(self._keyspace, table, column_names, extra_clause="ALLOW FILTERING")
584 for id_chunk in chunk_iterable(diaObjectIds, 10_000):
585 id_where = C("diaObjectId").in_(id_chunk)
586 for clause in QExpr.combine(spatial_where, temporal_where, extra=id_where):
587 statements.append(
588 context.stmt_factory.with_params(query.where(clause), prepare=False)
589 )
591 _LOG.debug("getDiaSourcesForDiaObjects #queries: %s", len(statements))
593 with self._timer(
594 "select_time", tags={"table": "DiaSource", "method": "getDiaSourcesForDiaObjects"}
595 ) as timer:
596 table_data_raw = cast(
597 ApdbCassandraTableData,
598 select_concurrent(
599 context.session,
600 statements,
601 "read_raw_multi",
602 config.connection_config.read_concurrency,
603 ),
604 )
605 catalog = table_data_raw.to_pandas(context.schema._table_schema(ApdbTables.DiaSource))
606 timer.add_values(row_count_from_db=len(catalog), num_queries=len(statements))
608 # precise filtering on midpointMjdTai
609 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] >= start_time.tai.mjd])
611 timer.add_values(row_count=len(catalog))
613 _LOG.debug("found %d DiaSources", len(catalog))
614 return catalog
616 def containsVisitDetector(
617 self,
618 visit: int,
619 detector: int,
620 region: sphgeom.Region,
621 visit_time: astropy.time.Time,
622 ) -> bool:
623 # docstring is inherited from a base class
624 context = self._context
626 table_name = context.schema.tableName(ExtraTables.ApdbVisitDetector)
627 query = Select(self._keyspace, table_name, [ColumnExpr("count(*)")])
628 query = query.where((C("visit") == visit) & (C("detector") == detector))
629 stmt, params = context.stmt_factory.with_params(query, prepare=False)
631 with self._timer("contains_visit_detector_time", tags={"table": table_name}):
632 result = context.session.execute(stmt, params)
633 return bool(result.one()[0])
635 def store(
636 self,
637 visit_time: astropy.time.Time,
638 objects: pandas.DataFrame,
639 sources: pandas.DataFrame | None = None,
640 forced_sources: pandas.DataFrame | None = None,
641 ) -> None:
642 # docstring is inherited from a base class
643 context = self._context
644 config = context.config
646 # Store visit/detector in a special table, this has to be done
647 # before all other writes so if there is a failure at any point
648 # later we still have a record for attempted write.
649 visit_detector: set[tuple[int, int]] = set()
650 for df in sources, forced_sources:
651 if df is not None and not df.empty:
652 df = df[["visit", "detector"]]
653 for visit, detector in df.itertuples(index=False):
654 visit_detector.add((visit, detector))
656 if visit_detector:
657 # Typically there is only one entry, do not bother with
658 # concurrency.
659 table_name = context.schema.tableName(ExtraTables.ApdbVisitDetector)
660 query = Insert(self._keyspace, table_name, ("visit", "detector"))
661 stmt = context.stmt_factory(query)
662 for item in visit_detector:
663 context.session.execute(stmt, item, execution_profile="write")
665 objects = self._fix_input_timestamps(objects)
666 if sources is not None:
667 sources = self._fix_input_timestamps(sources)
668 if forced_sources is not None:
669 forced_sources = self._fix_input_timestamps(forced_sources)
671 replica_chunk: ReplicaChunk | None = None
672 if context.schema.replication_enabled:
673 replica_chunk = ReplicaChunk.make_replica_chunk(visit_time, config.replica_chunk_seconds)
674 self._storeReplicaChunk(replica_chunk)
676 # fill region partition column for DiaObjects
677 objects = self._add_apdb_part(objects)
678 self._storeDiaObjects(objects, visit_time, replica_chunk)
680 if sources is not None and len(sources) > 0:
681 # copy apdb_part column from DiaObjects to DiaSources
682 sources = self._add_apdb_part(sources)
683 subchunk = self._storeDiaSources(ApdbTables.DiaSource, sources, replica_chunk)
684 self._storeDiaSourcesPartitions(sources, visit_time, replica_chunk, subchunk)
686 if forced_sources is not None and len(forced_sources) > 0:
687 forced_sources = self._add_apdb_part(forced_sources)
688 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, replica_chunk)
690 def reassignDiaSourcesToDiaObjects(
691 self,
692 idMap: Mapping[DiaSourceId, int],
693 *,
694 increment_nDiaSources: bool = True,
695 decrement_nDiaSources: bool = True,
696 ) -> None:
697 # docstring is inherited from a base class
698 context = self._context
699 config = context.config
701 source_ids = {source_id.diaSourceId for source_id in idMap}
703 # Find all DiaSources.
704 found_sources = self._get_diasource_data(
705 idMap, "apdb_part", "diaObjectId", "ra", "dec", "midpointMjdTai"
706 )
708 if missing_ids := (source_ids - {row.diaSourceId for row in found_sources}):
709 raise LookupError(f"Some source IDs were not found in DiaSource table: {missing_ids}")
711 found_sources_by_id = {row.diaSourceId: row for row in found_sources}
713 # Make sure that all DiaObjects exist, we also want to know
714 # nDiaSources count for current and new records because we want to
715 # send updated values to replica.
716 current_object_ids = {
717 DiaObjectId(diaObjectId=row.diaObjectId, ra=row.ra, dec=row.dec) for row in found_sources
718 }
719 # Assume that DiaSource ra/dec are very close to re-assigned objects.
720 new_object_ids = {
721 DiaObjectId(diaObjectId=diaObjectId, ra=source_id.ra, dec=source_id.dec)
722 for source_id, diaObjectId in idMap.items()
723 }
724 all_object_ids = new_object_ids | current_object_ids
725 found_objects = self._get_diaobject_data(all_object_ids, "apdb_part", "ra", "dec", "nDiaSources")
727 if missing_ids := (
728 {row.diaObjectId for row in all_object_ids} - {row.diaObjectId for row in found_objects}
729 ):
730 raise LookupError(f"Some object IDs were not found in DiaObjectLast table: {missing_ids}")
732 update_records: list[ApdbUpdateRecord] = []
733 update_order = 0
734 current_time = self._current_time()
735 current_time_ns = int(current_time.unix_tai * 1e9)
737 # Update DiaSources.
738 statements: list[tuple] = []
739 for source_id, diaObjectId in idMap.items():
740 source_row = found_sources_by_id[source_id.diaSourceId]
741 apdb_part = source_row.apdb_part
742 time_part = context.partitioner.time_partition(source_row.midpointMjdTai)
744 if config.partitioning.time_partition_tables:
745 table_name = context.schema.tableName(ApdbTables.DiaSource, time_part)
746 update = (
747 Update(self._keyspace, table_name)
748 .values(C("diaObjectId").update(diaObjectId))
749 .where(C("apdb_part") == apdb_part)
750 .where(C("diaSourceId") == source_id.diaSourceId)
751 )
752 else:
753 table_name = context.schema.tableName(ApdbTables.DiaSource)
754 update = (
755 Update(self._keyspace, table_name)
756 .values(C("diaObjectId").update(diaObjectId))
757 .where(C("apdb_part") == apdb_part)
758 .where(C("apdb_time_part") == time_part)
759 .where(C("diaSourceId") == source_id.diaSourceId)
760 )
761 statements.append(context.stmt_factory.with_params(update, prepare=True))
763 if context.schema.replication_enabled:
764 update_records.append(
765 ApdbReassignDiaSourceToDiaObjectRecord(
766 diaSourceId=source_id.diaSourceId,
767 ra=source_id.ra,
768 dec=source_id.dec,
769 midpointMjdTai=source_id.midpointMjdTai,
770 diaObjectId=diaObjectId,
771 update_time_ns=current_time_ns,
772 update_order=update_order,
773 )
774 )
775 update_order += 1
777 with self._timer(
778 "update_time", tags={"table": "DiaSource", "method": "reassignDiaSourcesToDiaObjects"}
779 ) as timer:
780 execute_concurrent(context.session, statements, execution_profile="write")
781 timer.add_values(num_queries=len(statements))
783 # Update nDiaSources in DiaObjectLast. We do not update DiaObject table
784 # here because it may not even exist. PPDB updates DiaObject from
785 # update records.
786 if increment_nDiaSources or decrement_nDiaSources:
787 table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
788 update = (
789 Update(self._keyspace, table_name)
790 .values(C("nDiaSources").update(-1))
791 .where(C("apdb_part") == -1)
792 .where(C("diaObjectId") == -1)
793 )
794 statement = context.stmt_factory(update, prepare=True)
795 statements = []
797 # Calculate increments/decrements for all affected DiaObjects.
798 increments: Counter = Counter()
799 if increment_nDiaSources:
800 increments.update(idMap.values())
801 if decrement_nDiaSources:
802 increments.subtract(row.diaObjectId for row in found_sources)
804 for row in found_objects:
805 if increments.get(row.diaObjectId):
806 nDiaSources = row.nDiaSources + increments[row.diaObjectId]
807 statements.append((statement, (nDiaSources, row.apdb_part, row.diaObjectId)))
809 # Also send updated values to replica.
810 if context.schema.replication_enabled:
811 update_records.append(
812 ApdbUpdateNDiaSourcesRecord(
813 diaObjectId=row.diaObjectId,
814 ra=row.ra,
815 dec=row.dec,
816 nDiaSources=nDiaSources,
817 update_time_ns=current_time_ns,
818 update_order=update_order,
819 )
820 )
821 update_order += 1
823 if statements:
824 with self._timer(
825 "update_time", tags={"table": table_name, "method": "reassignDiaSourcesToDiaObjects"}
826 ) as timer:
827 execute_concurrent(context.session, statements, execution_profile="write")
828 timer.add_values(num_queries=len(statements))
830 if update_records:
831 replica_chunk = ReplicaChunk.make_replica_chunk(current_time, config.replica_chunk_seconds)
832 self._storeUpdateRecords(update_records, replica_chunk, store_chunk=True)
834 def setValidityEnd(
835 self, objects: list[DiaObjectId], validityEnd: astropy.time.Time, raise_on_missing_id: bool = False
836 ) -> int:
837 # docstring is inherited from a base class
838 if not objects:
839 return 0
841 context = self._context
842 config = context.config
844 pad_arcsec = 1.0
845 partitioned_object_ids = self._group_dia_objects_by_partition(
846 context.partitioner, objects, pad_arcsec
847 )
849 # Check that all objects exist.
850 table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
851 statements: list[tuple] = []
852 for apdb_part, diaObjectIds in partitioned_object_ids.items():
853 query = Select(self._keyspace, table_name, ["apdb_part", "diaObjectId"])
854 query = query.where(C("apdb_part") == apdb_part)
855 query = query.where(C("diaObjectId").in_(diaObjectIds))
856 statements.append(context.stmt_factory.with_params(query, prepare=False))
858 with self._timer("select_time", tags={"table": table_name, "method": "setValidityEnd"}) as timer:
859 records = cast(
860 list[tuple[int, int]],
861 select_concurrent(
862 context.session,
863 statements,
864 "read_tuples",
865 config.connection_config.read_concurrency,
866 ),
867 )
868 timer.add_values(row_count=len(objects))
870 requested_ids = {obj.diaObjectId for obj in objects}
871 found_ids = {rec[1] for rec in records}
872 if extra_ids := (found_ids - requested_ids):
873 raise RuntimeError(f"Consistency error - found duplicate records for object IDs: {extra_ids}")
874 if raise_on_missing_id:
875 if missing_ids := (requested_ids - found_ids):
876 raise LookupError(f"Some object IDs are missing from DiaObjectLast table: {missing_ids}")
878 # Filter existing records.
879 if len(objects) != len(found_ids):
880 objects = [obj for obj in objects if obj.diaObjectId in found_ids]
882 if not objects:
883 return 0
885 # Group by partitions again.
886 grouped_object_ids: dict[int, list[int]] = defaultdict(list)
887 for apdb_part, diaObjectId in records:
888 grouped_object_ids[apdb_part].append(diaObjectId)
890 # Remove all matching rows from DiaObjectLast.
891 statements = []
892 for apdb_part, diaObjectIds in grouped_object_ids.items():
893 delete = (
894 Delete(self._keyspace, table_name)
895 .where(C("apdb_part") == apdb_part)
896 .where(C("diaObjectId").in_(diaObjectIds))
897 )
898 statements.append(context.stmt_factory.with_params(delete))
900 # Also remove from DiaObjectLastToPartition.
901 reverse_table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition)
902 delete = Delete(self._keyspace, reverse_table_name).where(
903 C("diaObjectId").in_([rec[1] for rec in records])
904 )
905 statements.append(context.stmt_factory.with_params(delete))
907 with self._timer("delete_time", tags={"table": table_name, "method": "setValidityEnd"}) as timer:
908 execute_concurrent(context.session, statements, execution_profile="write")
909 timer.add_values(row_count=len(records))
911 # If repication is enabled then send all updates.
912 if context.schema.replication_enabled:
913 current_time = self._current_time()
914 current_time_ns = int(current_time.unix_tai * 1e9)
915 replica_chunk = ReplicaChunk.make_replica_chunk(current_time, config.replica_chunk_seconds)
917 update_records = [
918 ApdbCloseDiaObjectValidityRecord(
919 diaObjectId=obj.diaObjectId,
920 ra=obj.ra,
921 dec=obj.dec,
922 update_time_ns=current_time_ns,
923 update_order=index,
924 validityEndMjdTai=float(validityEnd.tai.mjd),
925 nDiaSources=None,
926 )
927 for index, obj in enumerate(objects)
928 ]
930 self._storeUpdateRecords(update_records, replica_chunk, store_chunk=True)
932 return len(objects)
934 def resetDedup(self, dedup_time: astropy.time.Time | None = None) -> None:
935 # docstring is inherited from a base class
936 context = self._context
938 if not context.has_dedup_table:
939 raise TypeError("DiaObjectDedup table does not exist in this APDB instance.")
941 if dedup_time is None:
942 dedup_time = self._current_time()
944 validity_start_column = self._timestamp_column_name("validityStart")
946 # Find latest timestamp in deduplication table.
947 table_name = context.schema.tableName(ExtraTables.DiaObjectDedup)
948 query = Select(self._keyspace, table_name, [ColumnExpr(f'MAX("{validity_start_column}")')])
949 stmt = context.stmt_factory(query, prepare=False)
951 result = context.session.execute(stmt, execution_profile="read_tuples")
952 max_value = result.one()[0]
953 if self._schema.has_mjd_timestamps:
954 max_validity_start = astropy.time.Time(max_value, format="mjd", scale="tai")
955 else:
956 max_validity_start = astropy.time.Time(max_value, format="datetime", scale="tai")
958 # If max time is lower than dedup time we can do TRUNCATE.
959 if dedup_time >= max_validity_start:
960 query_str = f'TRUNCATE TABLE "{self._keyspace}"."{table_name}"'
961 context.session.execute(query_str, execution_profile="write")
962 else:
963 dedup_time_value = self._timestamp_column_value(dedup_time)
964 delete = Delete(self._keyspace, table_name).where(C(validity_start_column) < dedup_time_value)
965 stmt, params = context.stmt_factory.with_params(delete)
966 context.session.execute(stmt, params, execution_profile="write")
968 # Store dedup time.
969 data = {"dedup_time_iso_tai": dedup_time.tai.to_value("iso")}
970 data_json = json.dumps(data)
971 context.metadata.set(context.metadataDedupKey, data_json, force=True)
973 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
974 # docstring is inherited from a base class
975 context = self._context
976 config = context.config
978 now = self._current_time()
979 reassign_time_column = self._timestamp_column_name("ssObjectReassocTime")
980 reassignTime = self._timestamp_column_value(now)
982 # To update a record we need to know its exact primary key (including
983 # partition key) so we start by querying for diaSourceId to find the
984 # primary keys.
986 table_name = context.schema.tableName(ExtraTables.DiaSourceToPartition)
987 # split it into 1k IDs per query
988 selects: list[tuple] = []
989 columns = ["diaSourceId", "apdb_part", "apdb_time_part", "apdb_replica_chunk"]
990 query = Select(self._keyspace, table_name, columns)
991 for ids in chunk_iterable(idMap.keys(), 1_000):
992 full_query = query.where(C("diaSourceId").in_(ids))
993 selects.append(context.stmt_factory.with_params(full_query, prepare=False))
995 # No need for DataFrame here, read data as tuples.
996 result = cast(
997 list[tuple[int, int, int, int | None]],
998 select_concurrent(
999 context.session, selects, "read_tuples", config.connection_config.read_concurrency
1000 ),
1001 )
1003 # Make mapping from source ID to its partition.
1004 id2partitions: dict[int, tuple[int, int]] = {}
1005 id2chunk_id: dict[int, int] = {}
1006 for row in result:
1007 id2partitions[row[0]] = row[1:3]
1008 if row[3] is not None:
1009 id2chunk_id[row[0]] = row[3]
1011 # make sure we know partitions for each ID
1012 if set(id2partitions) != set(idMap):
1013 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions))
1014 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
1016 # Reassign in standard tables
1017 queries: list[tuple[cassandra.query.PreparedStatement, tuple]] = []
1018 for diaSourceId, ssObjectId in idMap.items():
1019 apdb_part, apdb_time_part = id2partitions[diaSourceId]
1020 if config.partitioning.time_partition_tables:
1021 table_name = context.schema.tableName(ApdbTables.DiaSource, apdb_time_part)
1022 update = (
1023 Update(self._keyspace, table_name)
1024 .values(
1025 C("ssObjectId").update(ssObjectId),
1026 C("diaObjectId").update(None),
1027 C(reassign_time_column).update(reassignTime),
1028 )
1029 .where(C("apdb_part") == apdb_part)
1030 .where(C("diaSourceId") == diaSourceId)
1031 )
1032 else:
1033 table_name = context.schema.tableName(ApdbTables.DiaSource)
1034 update = (
1035 Update(self._keyspace, table_name)
1036 .values(
1037 C("ssObjectId").update(ssObjectId),
1038 C("diaObjectId").update(None),
1039 C(reassign_time_column).update(reassignTime),
1040 )
1041 .where(C("apdb_part") == apdb_part)
1042 .where(C("apdb_time_part") == apdb_time_part)
1043 .where(C("diaSourceId") == diaSourceId)
1044 )
1045 queries.append(context.stmt_factory.with_params(update, prepare=True))
1047 # TODO: (DM-50190) Replication for updated records is not implemented.
1048 if id2chunk_id:
1049 warnings.warn("Replication of reassigned DiaSource records is not implemented.", stacklevel=2)
1051 _LOG.debug("%s: will update %d records", table_name, len(idMap))
1052 with self._timer("source_reassign_time") as timer:
1053 execute_concurrent(context.session, queries, execution_profile="write")
1054 timer.add_values(source_count=len(idMap))
1056 def countUnassociatedObjects(self) -> int:
1057 # docstring is inherited from a base class
1059 # It's too inefficient to implement it for Cassandra in current schema.
1060 raise NotImplementedError()
1062 @property
1063 def schema(self) -> ApdbSchema:
1064 # docstring is inherited from a base class
1065 return self._schema
1067 @property
1068 def metadata(self) -> ApdbMetadata:
1069 # docstring is inherited from a base class
1070 context = self._context
1071 return context.metadata
1073 @property
1074 def admin(self) -> ApdbCassandraAdmin:
1075 # docstring is inherited from a base class
1076 return ApdbCassandraAdmin(self)
1078 def _getSources(
1079 self,
1080 region: sphgeom.Region,
1081 object_ids: Iterable[int] | None,
1082 mjd_start: float,
1083 mjd_end: float,
1084 table_name: ApdbTables,
1085 ) -> pandas.DataFrame:
1086 """Return catalog of DiaSource instances given set of DiaObject IDs.
1088 Parameters
1089 ----------
1090 region : `lsst.sphgeom.Region`
1091 Spherical region.
1092 object_ids :
1093 Collection of DiaObject IDs
1094 mjd_start : `float`
1095 Lower bound of time interval.
1096 mjd_end : `float`
1097 Upper bound of time interval.
1098 table_name : `ApdbTables`
1099 Name of the table.
1101 Returns
1102 -------
1103 catalog : `pandas.DataFrame`, or `None`
1104 Catalog containing DiaSource records. Empty catalog is returned if
1105 ``object_ids`` is empty.
1106 """
1107 context = self._context
1108 config = context.config
1110 object_id_set: Set[int] = set()
1111 if object_ids is not None:
1112 object_id_set = set(object_ids)
1113 if len(object_id_set) == 0:
1114 return self._make_empty_catalog(table_name)
1116 sp_where, num_sp_part = context.partitioner.spatial_where(region)
1117 tables, temporal_where = context.partitioner.temporal_where(
1118 table_name, mjd_start, mjd_end, partitons_range=context.time_partitions_range
1119 )
1120 if not tables:
1121 start = astropy.time.Time(mjd_start, format="mjd", scale="tai")
1122 end = astropy.time.Time(mjd_end, format="mjd", scale="tai")
1123 warnings.warn(
1124 f"Query time range ({start.isot} - {end.isot}) does not overlap database time partitions."
1125 )
1127 # We need to exclude extra partitioning columns from result.
1128 column_names = context.schema.apdbColumnNames(table_name)
1130 # Build all queries
1131 statements: list[tuple] = []
1132 for table in tables:
1133 query = Select(self._keyspace, table, column_names)
1134 for clause in QExpr.combine(sp_where, temporal_where):
1135 statements.append(context.stmt_factory.with_params(query.where(clause), prepare=True))
1136 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))
1138 with self._timer("select_time", tags={"table": table_name.name, "method": "_getSources"}) as timer:
1139 table_data_raw = cast(
1140 ApdbCassandraTableData,
1141 select_concurrent(
1142 context.session,
1143 statements,
1144 "read_raw_multi",
1145 config.connection_config.read_concurrency,
1146 ),
1147 )
1148 catalog = table_data_raw.to_pandas(context.schema._table_schema(table_name))
1149 timer.add_values(
1150 row_count_from_db=len(catalog), num_sp_part=num_sp_part, num_queries=len(statements)
1151 )
1153 # filter by given object IDs
1154 if len(object_id_set) > 0:
1155 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])
1157 # precise filtering on midpointMjdTai
1158 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start])
1160 timer.add_values(row_count=len(catalog))
1162 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
1163 return catalog
1165 def _storeReplicaChunk(self, replica_chunk: ReplicaChunk) -> None:
1166 context = self._context
1167 config = context.config
1169 # Cassandra timestamp uses milliseconds since epoch
1170 timestamp = int(replica_chunk.last_update_time.unix_tai * 1000)
1172 # everything goes into a single partition
1173 partition = 0
1175 table_name = context.schema.tableName(ExtraTables.ApdbReplicaChunks)
1177 columns = ["partition", "apdb_replica_chunk", "last_update_time", "unique_id"]
1178 values = [partition, replica_chunk.id, timestamp, replica_chunk.unique_id]
1179 if context.has_chunk_sub_partitions:
1180 columns.append("has_subchunks")
1181 values.append(True)
1183 query = Insert(self._keyspace, table_name, columns)
1184 stmt = context.stmt_factory(query)
1186 context.session.execute(
1187 stmt,
1188 values,
1189 timeout=config.connection_config.write_timeout,
1190 execution_profile="write",
1191 )
1193 def _queryDiaObjectLastPartitions(self, ids: Iterable[int]) -> Mapping[int, int]:
1194 """Return existing mapping of diaObjectId to its last partition."""
1195 context = self._context
1196 config = context.config
1198 table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition)
1199 queries = []
1200 object_count = 0
1201 for id_chunk in chunk_iterable(ids, 10_000):
1202 id_chunk_list = tuple(id_chunk)
1203 query = Select(self._keyspace, table_name, ("diaObjectId", "apdb_part"))
1204 query = query.where(C("diaObjectId").in_(id_chunk_list))
1205 queries.append(context.stmt_factory.with_params(query, prepare=False))
1206 object_count += len(id_chunk_list)
1208 with self._timer("query_object_last_partitions", tags={"table": table_name}) as timer:
1209 data = cast(
1210 ApdbTableData,
1211 select_concurrent(
1212 context.session,
1213 queries,
1214 "read_raw_multi",
1215 config.connection_config.read_concurrency,
1216 ),
1217 )
1218 timer.add_values(object_count=object_count, row_count=len(data.rows()))
1220 if data.column_names() != ["diaObjectId", "apdb_part"]:
1221 raise RuntimeError(f"Unexpected column names in query result: {data.column_names()}")
1223 return {row[0]: row[1] for row in data.rows()}
1225 def _deleteMovingObjects(self, objs: pandas.DataFrame) -> None:
1226 """Objects in DiaObjectsLast can move from one spatial partition to
1227 another. For those objects inserting new version does not replace old
1228 one, so we need to explicitly remove old versions before inserting new
1229 ones.
1230 """
1231 context = self._context
1233 # Extract all object IDs.
1234 new_partitions = dict(zip(objs["diaObjectId"], objs["apdb_part"]))
1235 old_partitions = self._queryDiaObjectLastPartitions(objs["diaObjectId"])
1237 moved_oids: dict[int, tuple[int, int]] = {}
1238 for oid, old_part in old_partitions.items():
1239 new_part = new_partitions.get(oid, old_part)
1240 if new_part != old_part:
1241 moved_oids[oid] = (old_part, new_part)
1242 _LOG.debug("DiaObject IDs that moved to new partition: %s", moved_oids)
1244 if moved_oids:
1245 # Delete old records from DiaObjectLast.
1246 table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
1247 query = Delete(self._keyspace, table_name)
1248 query = query.where('apdb_part = {} AND "diaObjectId" = {}', (-1, -1))
1249 statement = context.stmt_factory(query, prepare=True)
1250 queries = []
1251 for oid, (old_part, _) in moved_oids.items():
1252 queries.append((statement, (old_part, oid)))
1253 with self._timer("delete_object_last", tags={"table": table_name}) as timer:
1254 execute_concurrent(context.session, queries, execution_profile="write")
1255 timer.add_values(row_count=len(moved_oids))
1257 # Add all new records to the map.
1258 table_name = context.schema.tableName(ExtraTables.DiaObjectLastToPartition)
1259 insert = Insert(self._keyspace, table_name, ("diaObjectId", "apdb_part"))
1260 statement = context.stmt_factory(insert, prepare=True)
1262 queries = []
1263 for oid, new_part in new_partitions.items():
1264 queries.append((statement, (oid, new_part)))
1266 with self._timer("update_object_last_partition", tags={"table": table_name}) as timer:
1267 execute_concurrent(context.session, queries, execution_profile="write")
1268 timer.add_values(row_count=len(queries))
1270 def _storeDiaObjects(
1271 self, objs: pandas.DataFrame, visit_time: astropy.time.Time, replica_chunk: ReplicaChunk | None
1272 ) -> None:
1273 """Store catalog of DiaObjects from current visit.
1275 Parameters
1276 ----------
1277 objs : `pandas.DataFrame`
1278 Catalog with DiaObject records
1279 visit_time : `astropy.time.Time`
1280 Time of the current visit.
1281 replica_chunk : `ReplicaChunk` or `None`
1282 Replica chunk identifier if replication is configured.
1283 """
1284 if len(objs) == 0:
1285 _LOG.debug("No objects to write to database.")
1286 return
1288 context = self._context
1289 config = context.config
1291 self._deleteMovingObjects(objs)
1293 validity_start_column = self._timestamp_column_name("validityStart")
1294 timestamp = self._timestamp_column_value(visit_time)
1296 # DiaObjectLast did not have this column in the past.
1297 extra_columns: dict[str, Any] = {}
1298 if context.schema.check_column(ApdbTables.DiaObjectLast, validity_start_column):
1299 extra_columns[validity_start_column] = timestamp
1301 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns)
1303 extra_columns[validity_start_column] = timestamp
1304 visit_time_part = context.partitioner.time_partition(visit_time)
1305 time_part: int | None = visit_time_part
1306 if (time_partitions_range := context.time_partitions_range) is not None:
1307 self._check_time_partitions([visit_time_part], time_partitions_range)
1308 if not config.partitioning.time_partition_tables:
1309 extra_columns["apdb_time_part"] = time_part
1310 time_part = None
1312 # Only store DiaObects if not doing replication or explicitly
1313 # configured to always store them.
1314 if replica_chunk is None or not config.replica_skips_diaobjects:
1315 self._storeObjectsPandas(
1316 objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part
1317 )
1319 if replica_chunk is not None:
1320 extra_columns = {"apdb_replica_chunk": replica_chunk.id, validity_start_column: timestamp}
1321 table = ExtraTables.DiaObjectChunks
1322 if context.has_chunk_sub_partitions:
1323 table = ExtraTables.DiaObjectChunks2
1324 # Use a random number for a second part of partitioning key so
1325 # that different clients could wrtite to different partitions.
1326 # This makes it not exactly reproducible.
1327 extra_columns["apdb_replica_subchunk"] = random.randrange(config.replica_sub_chunk_count)
1328 self._storeObjectsPandas(objs, table, extra_columns=extra_columns)
1330 # Store copy of the records in dedup table.
1331 if context.has_dedup_table:
1332 table = ExtraTables.DiaObjectDedup
1333 extra_columns = {
1334 "dedup_part": random.randrange(config.partitioning.num_part_dedup),
1335 validity_start_column: timestamp,
1336 }
1337 self._storeObjectsPandas(objs, table, extra_columns=extra_columns)
1339 def _storeDiaSources(
1340 self,
1341 table_name: ApdbTables,
1342 sources: pandas.DataFrame,
1343 replica_chunk: ReplicaChunk | None,
1344 ) -> int | None:
1345 """Store catalog of DIASources or DIAForcedSources from current visit.
1347 Parameters
1348 ----------
1349 table_name : `ApdbTables`
1350 Table where to store the data.
1351 sources : `pandas.DataFrame`
1352 Catalog containing DiaSource records
1353 visit_time : `astropy.time.Time`
1354 Time of the current visit.
1355 replica_chunk : `ReplicaChunk` or `None`
1356 Replica chunk identifier if replication is configured.
1358 Returns
1359 -------
1360 subchunk : `int` or `None`
1361 Subchunk number for resulting replica data, `None` if relication is
1362 not enabled ot subchunking is not enabled.
1363 """
1364 context = self._context
1365 config = context.config
1367 # Time partitioning has to be based on midpointMjdTai, not visit_time
1368 # as visit_time is not really a visit time.
1369 tp_sources = sources.copy(deep=False)
1370 tp_sources["apdb_time_part"] = tp_sources["midpointMjdTai"].apply(context.partitioner.time_partition)
1371 if (time_partitions_range := context.time_partitions_range) is not None:
1372 self._check_time_partitions(tp_sources["apdb_time_part"], time_partitions_range)
1373 extra_columns: dict[str, Any] = {}
1374 if not config.partitioning.time_partition_tables:
1375 self._storeObjectsPandas(tp_sources, table_name)
1376 else:
1377 # Group by time partition
1378 partitions = set(tp_sources["apdb_time_part"])
1379 if len(partitions) == 1:
1380 # Single partition - just save the whole thing.
1381 time_part = partitions.pop()
1382 self._storeObjectsPandas(sources, table_name, time_part=time_part)
1383 else:
1384 # group by time partition.
1385 for time_part, sub_frame in tp_sources.groupby(by="apdb_time_part"):
1386 sub_frame.drop(columns="apdb_time_part", inplace=True)
1387 self._storeObjectsPandas(sub_frame, table_name, time_part=time_part)
1389 subchunk: int | None = None
1390 if replica_chunk is not None:
1391 extra_columns = {"apdb_replica_chunk": replica_chunk.id}
1392 if context.has_chunk_sub_partitions:
1393 subchunk = random.randrange(config.replica_sub_chunk_count)
1394 extra_columns["apdb_replica_subchunk"] = subchunk
1395 if table_name is ApdbTables.DiaSource:
1396 extra_table = ExtraTables.DiaSourceChunks2
1397 else:
1398 extra_table = ExtraTables.DiaForcedSourceChunks2
1399 else:
1400 if table_name is ApdbTables.DiaSource:
1401 extra_table = ExtraTables.DiaSourceChunks
1402 else:
1403 extra_table = ExtraTables.DiaForcedSourceChunks
1404 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns)
1406 return subchunk
1408 def _check_time_partitions(
1409 self, partitions: Iterable[int], time_partitions_range: ApdbCassandraTimePartitionRange
1410 ) -> None:
1411 """Check that time partitons for new data actually exist.
1413 Parameters
1414 ----------
1415 partitions : `~collections.abc.Iterable` [`int`]
1416 Time partitions for new data.
1417 time_partitions_range : `ApdbCassandraTimePartitionRange`
1418 Currrent time partition range.
1419 """
1420 partitions = set(partitions)
1421 min_part = min(partitions)
1422 max_part = max(partitions)
1423 if min_part < time_partitions_range.start or max_part > time_partitions_range.end:
1424 raise ValueError(
1425 "Attempt to store data for time partitions that do not yet exist. "
1426 f"Partitons for new records: {min_part}-{max_part}. "
1427 f"Database partitons: {time_partitions_range.start}-{time_partitions_range.end}."
1428 )
1429 # Make a noise when writing to the last partition.
1430 if max_part == time_partitions_range.end:
1431 warnings.warn(
1432 "Writing into the last temporal partition. Partition range needs to be extended soon.",
1433 stacklevel=3,
1434 )
1436 def _storeDiaSourcesPartitions(
1437 self,
1438 sources: pandas.DataFrame,
1439 visit_time: astropy.time.Time,
1440 replica_chunk: ReplicaChunk | None,
1441 subchunk: int | None,
1442 ) -> None:
1443 """Store mapping of diaSourceId to its partitioning values.
1445 Parameters
1446 ----------
1447 sources : `pandas.DataFrame`
1448 Catalog containing DiaSource records
1449 visit_time : `astropy.time.Time`
1450 Time of the current visit.
1451 replica_chunk : `ReplicaChunk` or `None`
1452 Replication chunk, or `None` when replication is disabled.
1453 subchunk : `int` or `None`
1454 Replication sub-chunk, or `None` when replication is disabled or
1455 sub-chunking is not used.
1456 """
1457 context = self._context
1459 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]])
1460 extra_columns = {
1461 "apdb_time_part": context.partitioner.time_partition(visit_time),
1462 "apdb_replica_chunk": replica_chunk.id if replica_chunk is not None else None,
1463 }
1464 if context.has_chunk_sub_partitions:
1465 extra_columns["apdb_replica_subchunk"] = subchunk
1467 self._storeObjectsPandas(
1468 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None
1469 )
1471 def _storeObjectsPandas(
1472 self,
1473 records: pandas.DataFrame,
1474 table_name: ApdbTables | ExtraTables,
1475 extra_columns: Mapping | None = None,
1476 time_part: int | None = None,
1477 ) -> None:
1478 """Store generic objects.
1480 Takes Pandas catalog and stores a bunch of records in a table.
1482 Parameters
1483 ----------
1484 records : `pandas.DataFrame`
1485 Catalog containing object records
1486 table_name : `ApdbTables`
1487 Name of the table as defined in APDB schema.
1488 extra_columns : `dict`, optional
1489 Mapping (column_name, column_value) which gives fixed values for
1490 columns in each row, overrides values in ``records`` if matching
1491 columns exist there.
1492 time_part : `int`, optional
1493 If not `None` then insert into a per-partition table.
1495 Notes
1496 -----
1497 If Pandas catalog contains additional columns not defined in table
1498 schema they are ignored. Catalog does not have to contain all columns
1499 defined in a table, but partition and clustering keys must be present
1500 in a catalog or ``extra_columns``.
1501 """
1502 context = self._context
1504 # use extra columns if specified
1505 if extra_columns is None:
1506 extra_columns = {}
1507 extra_fields = list(extra_columns.keys())
1509 # Fields that will come from dataframe.
1510 df_fields = [column for column in records.columns if column not in extra_fields]
1512 column_map = context.schema.getColumnMap(table_name)
1513 # list of columns (as in felis schema)
1514 fields = [column_map[field].name for field in df_fields if field in column_map]
1515 fields += extra_fields
1517 # check that all partitioning and clustering columns are defined
1518 partition_columns = context.schema.partitionColumns(table_name)
1519 required_columns = partition_columns + context.schema.clusteringColumns(table_name)
1520 missing_columns = [column for column in required_columns if column not in fields]
1521 if missing_columns:
1522 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")
1524 batch_size = self._batch_size(table_name)
1526 with self._timer("insert_build_time", tags={"table": table_name.name}):
1527 # Multi-partition batches are problematic in general, so we want to
1528 # group records in a batch by their partition key.
1529 values_by_key: dict[tuple, list[list]] = defaultdict(list)
1530 for rec in records.itertuples(index=False):
1531 values = []
1532 partitioning_values: dict[str, Any] = {}
1533 for field in df_fields:
1534 if field not in column_map:
1535 continue
1536 value = getattr(rec, field)
1537 if column_map[field].datatype is felis.datamodel.DataType.timestamp:
1538 if isinstance(value, pandas.Timestamp):
1539 value = value.to_pydatetime()
1540 elif value is pandas.NaT:
1541 value = None
1542 else:
1543 # Assume it's seconds since epoch, Cassandra
1544 # datetime is in milliseconds
1545 value = int(value * 1000)
1546 value = literal(value)
1547 values.append(UNSET_VALUE if value is None else value)
1548 if field in partition_columns:
1549 partitioning_values[field] = value
1550 for field in extra_fields:
1551 value = literal(extra_columns[field])
1552 values.append(UNSET_VALUE if value is None else value)
1553 if field in partition_columns:
1554 partitioning_values[field] = value
1556 key = tuple(partitioning_values[field] for field in partition_columns)
1557 values_by_key[key].append(values)
1559 table = context.schema.tableName(table_name, time_part)
1561 query = Insert(self._keyspace, table, fields)
1562 statement = context.stmt_factory(query, prepare=True)
1563 # Cassandra has 64k limit on batch size, normally that should be
1564 # enough but some tests generate too many forced sources.
1565 queries = []
1566 for key_values in values_by_key.values():
1567 for values_chunk in chunk_iterable(key_values, batch_size):
1568 batch = cassandra.query.BatchStatement()
1569 for row_values in values_chunk:
1570 batch.add(statement, row_values)
1571 queries.append((batch, None))
1572 assert batch.routing_key is not None and batch.keyspace is not None
1574 _LOG.debug("%s: will store %d records", context.schema.tableName(table_name), records.shape[0])
1575 with self._timer(
1576 "insert_time", tags={"table": table_name.name, "method": "_storeObjectsPandas"}
1577 ) as timer:
1578 execute_concurrent(context.session, queries, execution_profile="write")
1579 timer.add_values(row_count=len(records), num_batches=len(queries))
1581 def _storeUpdateRecords(
1582 self, records: Iterable[ApdbUpdateRecord], chunk: ReplicaChunk, *, store_chunk: bool = False
1583 ) -> None:
1584 """Store ApdbUpdateRecords in the replica table for those records.
1586 Parameters
1587 ----------
1588 records : `list` [`ApdbUpdateRecord`]
1589 Records to store.
1590 chunk : `ReplicaChunk`
1591 Replica chunk for these records.
1592 store_chunk : `bool`
1593 If True then also store replica chunk.
1595 Raises
1596 ------
1597 TypeError
1598 Raised if replication is not enabled for this instance.
1599 """
1600 context = self._context
1601 config = context.config
1603 if not context.schema.replication_enabled:
1604 raise TypeError("Replication is not enabled for this APDB instance.")
1606 if store_chunk:
1607 self._storeReplicaChunk(chunk)
1609 apdb_replica_chunk = chunk.id
1610 # Do not use unique_if from ReplicaChunk as it could be reused in
1611 # multiple calls to this method.
1612 update_unique_id = uuid.uuid4()
1614 rows = []
1615 for record in records:
1616 rows.append(
1617 [
1618 apdb_replica_chunk,
1619 record.update_time_ns,
1620 record.update_order,
1621 update_unique_id,
1622 record.to_json(),
1623 ]
1624 )
1625 columns = [
1626 "apdb_replica_chunk",
1627 "update_time_ns",
1628 "update_order",
1629 "update_unique_id",
1630 "update_payload",
1631 ]
1632 if context.has_chunk_sub_partitions:
1633 subchunk = random.randrange(config.replica_sub_chunk_count)
1634 for row in rows:
1635 row.append(subchunk)
1636 columns.append("apdb_replica_subchunk")
1638 table_name = context.schema.tableName(ExtraTables.ApdbUpdateRecordChunks)
1639 query = Insert(self._keyspace, table_name, columns)
1640 stmt = context.stmt_factory(query)
1641 queries = [(stmt, row) for row in rows]
1643 with self._timer("store_update_record", tags={"table": table_name}) as timer:
1644 execute_concurrent(context.session, queries, execution_profile="write")
1645 timer.add_values(row_count=len(queries))
1647 def _add_apdb_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
1648 """Calculate spatial partition for each record and add it to a
1649 DataFrame.
1651 Parameters
1652 ----------
1653 df : `pandas.DataFrame`
1654 DataFrame which has to contain ra/dec columns, names of these
1655 columns are defined by configuration ``ra_dec_columns`` field.
1657 Returns
1658 -------
1659 df : `pandas.DataFrame`
1660 DataFrame with ``apdb_part`` column which contains pixel index
1661 for ra/dec coordinates.
1663 Notes
1664 -----
1665 This overrides any existing column in a DataFrame with the same name
1666 (``apdb_part``). Original DataFrame is not changed, copy of a DataFrame
1667 is returned.
1668 """
1669 context = self._context
1670 config = context.config
1672 # Calculate pixelization index for every record.
1673 apdb_part = np.zeros(df.shape[0], dtype=np.int64)
1674 ra_col, dec_col = config.ra_dec_columns
1675 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
1676 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
1677 idx = context.partitioner.pixel(uv3d)
1678 apdb_part[i] = idx
1679 df = df.copy()
1680 df["apdb_part"] = apdb_part
1681 return df
1683 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
1684 """Make an empty catalog for a table with a given name.
1686 Parameters
1687 ----------
1688 table_name : `ApdbTables`
1689 Name of the table.
1691 Returns
1692 -------
1693 catalog : `pandas.DataFrame`
1694 An empty catalog.
1695 """
1696 table = self.schema.tableSchemas[table_name]
1698 data = {columnDef.name: pandas.Series(dtype=columnDef.pandas_type) for columnDef in table.columns}
1699 return pandas.DataFrame(data)
1701 def _fix_input_timestamps(self, df: pandas.DataFrame) -> pandas.DataFrame:
1702 """Update timestamp columns in input DataFrame to be naive datetime
1703 type.
1705 Clients may or may not generate aware timestamps, code in this class
1706 assumes that timestamps are naive, so we convert them to UTC and
1707 drop timezone.
1708 """
1709 # Find all columns with aware timestamps.
1710 columns = [column for column, dtype in df.dtypes.items() if isinstance(dtype, pandas.DatetimeTZDtype)]
1711 for column in columns:
1712 # tz_convert(None) will convert to UTC and drop timezone.
1713 df[column] = df[column].dt.tz_convert(None)
1714 return df
1716 def _batch_size(self, table: ApdbTables | ExtraTables) -> int:
1717 """Calculate batch size based on config parameters."""
1718 context = self._context
1719 config = context.config
1721 # Cassandra limit on number of statements in a batch is 64k.
1722 batch_size = 65_535
1723 if 0 < config.batch_statement_limit < batch_size:
1724 batch_size = config.batch_statement_limit
1725 if config.batch_size_limit > 0:
1726 # The purpose of this limit is to try not to exceed batch size
1727 # threshold which is set on server side. Cassandra wire protocol
1728 # for prepared queries (and batches) only sends column values with
1729 # with an additional 4 bytes per value specifying size. Value is
1730 # not included for NULL or NOT_SET values, but the size is always
1731 # there. There is additional small per-query overhead, which we
1732 # ignore.
1733 row_size = context.schema.table_row_size(table)
1734 row_size += 4 * len(context.schema.getColumnMap(table))
1735 batch_size = min(batch_size, (config.batch_size_limit // row_size) + 1)
1736 return batch_size
1738 def _group_dia_objects_by_partition(
1739 self, partitioner: Partitioner, objects: list[DiaObjectId], pad_arcsec: float
1740 ) -> Mapping[int, list[int]]:
1741 """Group DiaObjects by partition.
1743 Parameters
1744 ----------
1745 partitioner : `Partitioner`
1746 Objects which knows how to partition things.
1747 objects : `list` [`DiaObjectId`]
1748 Collection of objects to partition.
1749 pad_arcsec : `float`
1750 Additional padding around object position.
1752 Returns
1753 -------
1754 grouped_objects
1755 Mapping of spatial patition ID to list ob object IDs that it
1756 contains. Some objects may belong to more than one partition.
1757 """
1758 partitioned_object_ids: dict[int, list[int]] = defaultdict(list)
1759 for obj_id in objects:
1760 partitions = partitioner.pixelization.circle_pixels(obj_id.ra, obj_id.dec, pad_arcsec)
1761 for pixel in partitions:
1762 partitioned_object_ids[pixel].append(obj_id.diaObjectId)
1763 return partitioned_object_ids
1765 def _timestamp_column_name(self, column: str) -> str:
1766 """Return column name before/after schema migration to MJD TAI."""
1767 return self._schema.timestamp_column_name(column)
1769 def _timestamp_column_value(self, time: astropy.time.Time) -> float | int:
1770 """Return column value before/after schema migration to MJD TAI."""
1771 if self._schema.has_mjd_timestamps:
1772 return float(time.tai.mjd)
1773 else:
1774 return int(time.datetime.astimezone(tz=datetime.UTC).timestamp() * 1000)
def _get_diasource_data(self, source_ids: Iterable[DiaSourceId], *columns: str) -> list:
    """Fetch DiaSource records by ``diaSourceId``.

    For every source, queries are built for the spatial partitions around
    its position and the time partitions around its ``midpointMjdTai``,
    each restricted to the exact ``diaSourceId``. All queries are then run
    concurrently.

    Parameters
    ----------
    source_ids : `~collections.abc.Iterable` [`DiaSourceId`]
        Sources to look up; ``ra``, ``dec``, ``midpointMjdTai`` and
        ``diaSourceId`` attributes are used.
    *columns : `str`
        Extra columns to return; ``diaSourceId`` always comes first.

    Returns
    -------
    records : `list`
        All returned records as a list of named tuples.
    """
    ctx = self._context
    cfg = ctx.config
    part = ctx.partitioner

    selected_columns = ("diaSourceId", *columns)

    # Allow some uncertainty in coordinates (1 arcsec) and time (10 seconds,
    # expressed in days) when computing which partitions to query.
    statements: list[tuple] = []
    pad_arcsec = 1.0
    pad_days = 10 / (24 * 3600)
    for src in source_ids:
        lonlat = sphgeom.LonLat.fromDegrees(src.ra, src.dec)
        circle = sphgeom.Circle(
            sphgeom.UnitVector3d(lonlat), sphgeom.Angle.fromDegrees(pad_arcsec / 3600.0)
        )
        spatial_where, _unused = part.spatial_where(circle)

        mjd = src.midpointMjdTai
        tables, temporal_where = part.temporal_where(
            ApdbTables.DiaSource,
            mjd - pad_days,
            mjd + pad_days,
            partitons_range=ctx.time_partitions_range,
            query_per_time_part=True,
        )

        id_where = QExpr('"diaSourceId" = {}', (src.diaSourceId,))

        for table in tables:
            select = Select(self._keyspace, table, selected_columns)
            statements.extend(
                ctx.stmt_factory.with_params(select.where(clause), prepare=True)
                for clause in QExpr.combine(spatial_where, temporal_where, extra=id_where)
            )

    timer_tags = {"table": "DiaSource", "method": "_get_diasource_data"}
    with self._timer("select_time", tags=timer_tags) as timer:
        rows = select_concurrent(
            ctx.session,
            statements,
            "read_named_tuples",
            cfg.connection_config.read_concurrency,
        )
        result = cast(list[tuple], rows)
        timer.add_values(row_count=len(result), num_queries=len(statements))

    return result
def _get_diaobject_data(self, object_ids: Iterable[DiaObjectId], *columns: str) -> list:
    """Select records from DiaObjectLast table by diaObjectId and return
    all records as a list of named tuples.

    Parameters
    ----------
    object_ids : `~collections.abc.Iterable` [`DiaObjectId`]
        Objects to look up; their positions select candidate spatial
        partitions and their ``diaObjectId`` selects rows.
    *columns : `str`
        Extra columns to return; ``diaObjectId`` always comes first.

    Returns
    -------
    records : `list`
        All returned records, as produced by `select_concurrent` in
        ``"read_named_tuples"`` mode.
    """
    context = self._context
    config = context.config
    partitioner = context.partitioner

    table_name = context.schema.tableName(ApdbTables.DiaObjectLast)
    columns = ("diaObjectId",) + columns

    # Allow some uncertainty for coordinates when calculating partitions.
    # Reuse the shared grouping helper instead of duplicating its logic.
    pad_arcsec = 1.0
    ids_by_partition = self._group_dia_objects_by_partition(
        partitioner, list(object_ids), pad_arcsec
    )

    # One query per candidate partition, matching all its object IDs at once.
    statements: list[tuple] = []
    for apdb_part, dia_object_ids in ids_by_partition.items():
        query = Select(self._keyspace, table_name, columns)
        query = query.where(C("apdb_part") == apdb_part)
        query = query.where(C("diaObjectId").in_(dia_object_ids))
        # NOTE(review): not prepared, presumably because the IN list length
        # varies between queries — confirm.
        statements.append(context.stmt_factory.with_params(query, prepare=False))

    with self._timer(
        "select_time", tags={"table": "DiaObjectLast", "method": "_get_diaobject_data"}
    ) as timer:
        result = cast(
            list[tuple],
            select_concurrent(
                context.session,
                statements,
                "read_named_tuples",
                config.connection_config.read_concurrency,
            ),
        )
        timer.add_values(row_count=len(result), num_queries=len(statements))

    return result