Coverage for python / felis / datamodel.py: 30%
738 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-16 07:52 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-16 07:52 +0000
1"""Define Pydantic data models for Felis."""
3# This file is part of felis.
4#
5# Developed for the LSST Data Management System.
6# This product includes software developed by the LSST Project
7# (https://www.lsst.org).
8# See the COPYRIGHT file at the top-level directory of this distribution
9# for details of code ownership.
10#
11# This program is free software: you can redistribute it and/or modify
12# it under the terms of the GNU General Public License as published by
13# the Free Software Foundation, either version 3 of the License, or
14# (at your option) any later version.
15#
16# This program is distributed in the hope that it will be useful,
17# but WITHOUT ANY WARRANTY; without even the implied warranty of
18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19# GNU General Public License for more details.
20#
21# You should have received a copy of the GNU General Public License
22# along with this program. If not, see <https://www.gnu.org/licenses/>.
24from __future__ import annotations
26import json
27import logging
28import sys
29from collections.abc import Sequence
30from enum import StrEnum, auto
31from operator import itemgetter
32from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar
34import yaml
35from astropy import units as units # type: ignore
36from astropy.io.votable import ucd # type: ignore
37from lsst.resources import ResourcePath, ResourcePathExpression
38from pydantic import (
39 BaseModel,
40 ConfigDict,
41 Field,
42 PrivateAttr,
43 ValidationError,
44 ValidationInfo,
45 field_serializer,
46 field_validator,
47 model_validator,
48)
49from pydantic_core import InitErrorDetails
51from .db._dialects import get_supported_dialects, string_to_typeengine
52from .db._sqltypes import get_type_func
53from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode
logger = logging.getLogger(__name__)

__all__ = (
    "BaseObject",
    "CheckConstraint",
    "Column",
    "ColumnOverrides",
    "ColumnResourceRef",
    "Constraint",
    "DataType",
    "ForeignKeyConstraint",
    "Index",
    "Resource",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

CONFIG: ConfigDict = {
    # Populate attributes by name.
    "populate_by_name": True,
    # Do not allow extra fields.
    "extra": "forbid",
    # Strip whitespace from string fields.
    "str_strip_whitespace": True,
    # Do not use enum values during serialization.
    "use_enum_values": False,
}
"""Shared Pydantic model configuration; the available options are described
at https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Smallest number of characters accepted for a description field."""

DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""String type for descriptions, constrained to at least
``DESCR_MIN_LENGTH`` characters.
"""
class BaseObject(BaseModel):
    """Common base for objects in the Felis data model.

    Every class that represents an object in the Felis data model should
    derive from this class.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """Name of the database object."""

    id: str = Field(alias="@id")
    """Unique identifier of the database object."""

    description: DescriptionStr | None = None
    """Description of the database object."""

    votable_utype: str | None = Field(None, alias="votable:utype")
    """VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Require a description when the validation context asks for it.

        Parameters
        ----------
        info
            Validation info whose context may carry the
            ``check_description`` flag that enables this check.

        Returns
        -------
        `BaseObject`
            The object being validated.

        Raises
        ------
        ValueError
            Raised if the check is enabled and the description is missing,
            empty, or shorter than ``DESCR_MIN_LENGTH`` characters.
        """
        ctx = info.context
        if ctx and ctx.get("check_description", False):
            text = self.description
            if not text:
                raise ValueError("Description is required and must be non-empty")
            if len(text) < DESCR_MIN_LENGTH:
                raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self
137class DataType(StrEnum):
138 """``Enum`` representing the data types supported by Felis."""
140 boolean = auto()
141 byte = auto()
142 short = auto()
143 int = auto()
144 long = auto()
145 float = auto()
146 double = auto()
147 char = auto()
148 string = auto()
149 unicode = auto()
150 text = auto()
151 binary = auto()
152 timestamp = auto()
def validate_ivoa_ucd(ivoa_ucd: str) -> str:
    """Validate IVOA UCD values.

    Parameters
    ----------
    ivoa_ucd
        IVOA UCD value to check. A value of `None` is passed through
        without being parsed.

    Returns
    -------
    `str`
        The IVOA UCD value if it is valid.

    Raises
    ------
    ValueError
        If the IVOA UCD value is invalid. The parser's original exception is
        chained as the cause.
    """
    if ivoa_ucd is None:
        return ivoa_ucd
    try:
        # NOTE(review): ``has_colon`` is derived from the presence of ';'
        # rather than ':' in the value -- confirm that this is intended.
        ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
    except ValueError as e:
        # Chain the cause so the parser's traceback is preserved.
        raise ValueError(f"Invalid IVOA UCD: {e}") from e
    return ivoa_ucd
class Column(BaseObject):
    """Column model."""

    datatype: DataType
    """Datatype of the column."""

    length: int | None = Field(None, gt=0)
    """Length of the column."""

    precision: int | None = Field(None, ge=0)
    """The numerical precision of the column.

    For timestamps, this is the number of fractional digits retained in the
    seconds field.
    """

    nullable: bool = True
    """Whether the column can be ``NULL``."""

    value: str | int | float | bool | None = None
    """Default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column."""

    votable_arraysize: int | str | None = Field(None, alias="votable:arraysize")
    """VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """VOTable datatype of the column."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """MySQL datatype override on the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """PostgreSQL datatype override on the column."""

    _is_resource_ref: bool = PrivateAttr(False)
    """Whether this column is a resource reference column."""

    @model_validator(mode="after")
    def check_value(self) -> Column:
        """Check that the default value is valid.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if a default value is combined with autoincrement, or if
            the type of the default value does not match the column datatype.
        """
        value = self.value
        if value is None:
            return self
        if self.autoincrement is True:
            raise ValueError("Column cannot have both a default value and be autoincremented")
        felis_type = FelisType.felis_type(self.datatype)
        if felis_type.is_numeric:
            if felis_type in (Byte, Short, Int, Long):
                # ``bool`` is a subclass of ``int``, so it has to be rejected
                # explicitly or ``True``/``False`` would pass as integers.
                if isinstance(value, bool) or not isinstance(value, int):
                    raise ValueError("Default value must be an int for integer type columns")
            elif felis_type in (Float, Double) and not isinstance(value, float):
                raise ValueError("Default value must be a decimal number for float and double columns")
        elif felis_type in (String, Char, Unicode, Text):
            if not isinstance(value, str):
                raise ValueError("Default value must be a string for string columns")
            if not value:
                raise ValueError("Default value must be a non-empty string for string columns")
        elif felis_type is Boolean and not isinstance(value, bool):
            raise ValueError("Default value must be a boolean for boolean columns")
        return self

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The IVOA UCD value if it is valid.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_units(self) -> Column:
        """Check that the ``fits:tunit`` or ``ivoa:unit`` field has valid
        units according to astropy. Only one may be provided.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if both FITS and IVOA units are provided, or if the unit is
            invalid.
        """
        fits_unit = self.fits_tunit
        ivoa_unit = self.ivoa_unit

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                # Chain the cause so the parser's traceback is preserved.
                raise ValueError(f"Invalid unit: {e}") from e

        return self

    @model_validator(mode="before")
    @classmethod
    def check_length(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that a valid length is provided for sized types.

        Parameters
        ----------
        values
            Values of the column.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Raises
        ------
        ValueError
            Raised if a length is not provided for a sized type.

        Notes
        -----
        A length given for a type that is not sized only logs a warning.
        """
        if not isinstance(values, dict):
            # Leave non-mapping input untouched so that regular field
            # validation reports it instead of failing on ``.get``.
            return values
        datatype = values.get("datatype")
        if datatype is None:
            # Skip this validation if datatype is not provided
            return values
        length = values.get("length")
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized and length is None:
            raise ValueError(
                f"Length must be provided for type '{datatype}'"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        elif not felis_type.is_sized and length is not None:
            msg = f"The datatype '{datatype}' does not support a specified length"
            if "@id" in values:
                msg += f" in column '{values['@id']}'"
            logger.warning("%s", msg)
        return values

    @model_validator(mode="after")
    def check_redundant_datatypes(self, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if a datatype override is redundant.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return self

        dialects = get_supported_dialects()
        # Nothing to compare when no dialect-specific override is set, so
        # skip building the generic type object altogether.
        if not any(getattr(self, f"{name}_datatype", None) for name in dialects):
            return self

        datatype = self.datatype
        length: int | None = self.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            datatype_obj = datatype_func(length)
        else:
            datatype_obj = datatype_func()

        for dialect_name, dialect in dialects.items():
            db_annotation = f"{dialect_name}_datatype"
            datatype_string = getattr(self, db_annotation, None)
            if not datatype_string:
                continue
            db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
            generic_compiled = datatype_obj.compile(dialect)
            override_compiled = db_datatype_obj.compile(dialect)
            if generic_compiled == override_compiled:
                # The override compiles to the same type as the generic one.
                raise ValueError(
                    "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
                        db_annotation,
                        datatype_string,
                        self.datatype,
                        self.id,
                        "" if length is None else f" with length {length}",
                    )
                )
            logger.debug(
                "Type override of 'datatype: %s' with '%s: %s' in column '%s' "
                "compiled to '%s' and '%s'",
                self.datatype,
                db_annotation,
                datatype_string,
                self.id,
                generic_compiled,
                override_compiled,
            )
        return self

    @model_validator(mode="after")
    def check_precision(self) -> Column:
        """Check that precision is only valid for timestamp columns.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if a precision is set on a column that is not a timestamp.
        """
        if self.precision is not None and self.datatype != DataType.timestamp:
            raise ValueError("Precision is only valid for timestamp columns")
        return self

    @model_validator(mode="before")
    @classmethod
    def check_votable_arraysize(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Set the default value for the ``votable_arraysize`` field, which
        corresponds to ``arraysize`` in the IVOA VOTable standard.

        Parameters
        ----------
        values
            Values of the column.
        info
            Validation context; the ``force_unbounded_arraysize`` flag makes
            string, unicode and binary columns default to ``*``.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Notes
        -----
        Following the IVOA VOTable standard, an ``arraysize`` of 1 should not
        be used. An existing integer ``votable:arraysize`` is converted to a
        string and a deprecation warning is logged.
        """
        if not isinstance(values, dict):
            # Leave non-mapping input untouched so that regular field
            # validation reports it instead of failing on ``.get``.
            return values
        if values.get("name", None) is None or values.get("datatype", None) is None:
            # Skip bad column data that will not validate
            return values
        context = info.context if info.context else {}
        arraysize = values.get("votable:arraysize", None)
        if arraysize is None:
            length = values.get("length", None)
            datatype = values.get("datatype")
            if length is not None and length > 1:
                # Following the IVOA standard, arraysize of 1 is disallowed
                if datatype == "char":
                    arraysize = str(length)
                elif datatype in ("string", "unicode", "binary"):
                    if context.get("force_unbounded_arraysize", False):
                        arraysize = "*"
                        logger.debug(
                            "Forced VOTable's 'arraysize' to '*' on column '%s' with datatype "
                            "'%s' and length '%s'",
                            values["name"],
                            values["datatype"],
                            length,
                        )
                    else:
                        arraysize = f"{length}*"
            elif datatype in ("timestamp", "text"):
                arraysize = "*"
            if arraysize is not None:
                values["votable:arraysize"] = arraysize
                logger.debug(
                    "Set default 'votable:arraysize' to '%s' on column '%s'"
                    " with datatype '%s' and length '%s'",
                    arraysize,
                    values["name"],
                    values["datatype"],
                    values.get("length", None),
                )
        else:
            logger.debug(
                "Using existing 'votable:arraysize' of '%s' on column '%s'", arraysize, values["name"]
            )
            if isinstance(values["votable:arraysize"], int):
                logger.warning(
                    "Usage of an integer value for 'votable:arraysize' in column '%s' is deprecated",
                    values["name"],
                )
                values["votable:arraysize"] = str(arraysize)
        return values

    @field_serializer("datatype")
    def serialize_datatype(self, value: DataType) -> str:
        """Convert `DataType` to string when serializing to JSON/YAML.

        Parameters
        ----------
        value
            The `DataType` value to serialize.

        Returns
        -------
        `str`
            The serialized `DataType` value.
        """
        return str(value)

    @field_validator("datatype", mode="before")
    @classmethod
    def deserialize_datatype(cls, value: str) -> DataType:
        """Convert string back into `DataType` when loading from JSON/YAML.

        Parameters
        ----------
        value
            The string value to deserialize.

        Returns
        -------
        `DataType`
            The deserialized `DataType` value.
        """
        return DataType(value)

    @model_validator(mode="after")
    def check_votable_xtype(self) -> Column:
        """Set the default value for the ``votable_xtype`` field, which
        corresponds to an Extended Datatype or ``xtype`` in the IVOA VOTable
        standard.

        Returns
        -------
        `Column`
            The column being validated.

        Notes
        -----
        This is currently only set automatically for the Felis ``timestamp``
        datatype.
        """
        if self.datatype == DataType.timestamp and self.votable_xtype is None:
            self.votable_xtype = "timestamp"
        return self

    def _update_from_overrides(self, overrides: ColumnOverrides) -> None:
        """Update the column attributes from the given overrides.

        Parameters
        ----------
        overrides
            The column overrides to apply.

        Notes
        -----
        Using ``model_fields_set`` allows updating only the fields that are
        explicitly set in the `overrides` object. This prevents overwriting
        existing column attributes which were not explicitly provided.
        """
        if overrides.model_fields_set:
            logger.debug("Applying overrides to column '%s': %s", self.id, overrides.model_fields_set)
        for field in overrides.model_fields_set:
            setattr(self, field, getattr(overrides, field))
class Constraint(BaseObject):
    """Base model for table constraints."""

    deferrable: bool = False
    """Whether the constraint is declared as deferrable."""

    initially: Literal["IMMEDIATE", "DEFERRED"] | None = None
    """Value of the ``INITIALLY`` clause; it may only be given when
    `deferrable` is `True`."""

    @model_validator(mode="after")
    def check_deferrable(self) -> Constraint:
        """Reject an ``INITIALLY`` clause on a non-deferrable constraint.

        Returns
        -------
        `Constraint`
            The constraint being validated.

        Raises
        ------
        ValueError
            Raised if `initially` is set while `deferrable` is false.
        """
        if self.deferrable or self.initially is None:
            return self
        raise ValueError("INITIALLY clause can only be used if deferrable is True")
class CheckConstraint(Constraint):
    """Model of a table check constraint."""

    type: Literal["Check"] = Field("Check", alias="@type")
    """Constraint type tag, always ``Check``."""

    expression: str
    """Expression evaluated by the check constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the ``@type`` value through unchanged when serializing.

        Parameters
        ----------
        value
            The type value to serialize.

        Returns
        -------
        `str`
            The same value that was passed in.
        """
        return value
class UniqueConstraint(Constraint):
    """Model of a table unique constraint."""

    type: Literal["Unique"] = Field("Unique", alias="@type")
    """Constraint type tag, always ``Unique``."""

    columns: list[str]
    """Columns covered by the unique constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the ``@type`` value through unchanged when serializing.

        Parameters
        ----------
        value
            The type value to serialize.

        Returns
        -------
        `str`
            The same value that was passed in.
        """
        return value
class ForeignKeyConstraint(Constraint):
    """Model of a table foreign key constraint.

    A foreign key relates two tables in the schema. Both the `columns` list
    and the `referenced_columns` list must contain at least one entry,
    otherwise validation fails.

    Notes
    -----
    These relationships will be reflected in the TAP_SCHEMA ``keys`` and
    ``key_columns`` data.
    """

    type: Literal["ForeignKey"] = Field("ForeignKey", alias="@type")
    """Constraint type tag, always ``ForeignKey``."""

    columns: list[str] = Field(min_length=1)
    """Columns that make up the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns", min_length=1)
    """Columns that the foreign key refers to."""

    on_delete: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action taken when the referenced row is deleted."""

    on_update: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action taken when the referenced row is updated."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the ``@type`` value through unchanged when serializing.

        Parameters
        ----------
        value
            The type value to serialize.

        Returns
        -------
        `str`
            The same value that was passed in.
        """
        return value

    @model_validator(mode="after")
    def check_column_lengths(self) -> ForeignKeyConstraint:
        """Require `columns` and `referenced_columns` to be equally long.

        Returns
        -------
        `ForeignKeyConstraint`
            The foreign key constraint being validated.

        Raises
        ------
        ValueError
            Raised if the two lists contain a different number of entries.
        """
        n_local = len(self.columns)
        n_referenced = len(self.referenced_columns)
        if n_local == n_referenced:
            return self
        raise ValueError(
            "Columns and referencedColumns must have the same length for a ForeignKey constraint"
        )
_ConstraintType = Annotated[
    CheckConstraint | ForeignKeyConstraint | UniqueConstraint,
    Field(discriminator="type"),
]
"""Union of the concrete constraint models, discriminated on their ``type``
field."""
class Index(BaseObject):
    """Table index model.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """Columns in the index."""

    expressions: list[str] | None = None
    """Expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that columns or expressions are specified, but not both.

        Parameters
        ----------
        values
            Values of the index.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the index.

        Raises
        ------
        ValueError
            Raised if both columns and expressions are specified, or if neither
            are specified.

        Notes
        -----
        Only the presence of the ``columns`` and ``expressions`` keys is
        checked, not their values.
        """
        if not isinstance(values, dict):
            # Leave non-mapping input untouched so that regular field
            # validation reports it instead of a misleading key check.
            return values
        has_columns = "columns" in values
        has_expressions = "expressions" in values
        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not has_columns and not has_expressions:
            raise ValueError("Must define columns or expressions")
        return values
ColumnRef: TypeAlias = str
"""Alias of `str` used where a string refers to a column."""
class ColumnGroup(BaseObject):
    """Model of a group of columns within a table."""

    columns: list[ColumnRef | Column] = Field(..., min_length=1)
    """Members of the group, given as column references or `Column`
    objects."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column group."""

    table: Table | None = Field(None, exclude=True)
    """Parent table; excluded from serialization."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Validate the IVOA UCD by delegating to `validate_ivoa_ucd`.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The IVOA UCD value if it is valid.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_unique_columns(self) -> ColumnGroup:
        """Ensure that no column appears more than once in the group.

        String entries are compared as-is and `Column` entries by their
        ``id``.

        Returns
        -------
        `ColumnGroup`
            The column group being validated.

        Raises
        ------
        ValueError
            Raised if the same column occurs more than once.
        """
        seen: set[str] = set()
        for entry in self.columns:
            key = entry if isinstance(entry, str) else entry.id
            if key in seen:
                raise ValueError("Columns in the group must be unique")
            seen.add(key)
        return self

    def _dereference_columns(self) -> None:
        """Replace string column references with `Column` objects.

        Each string entry is looked up by ID in the parent table; entries
        that are already `Column` objects are kept unchanged.

        Raises
        ------
        ValueError
            Raised if the parent table is not set or if a referenced column
            cannot be found in it.
        """
        parent = self.table
        if parent is None:
            raise ValueError("ColumnGroup must have a reference to its parent table")

        resolved: list[ColumnRef | Column] = []
        for entry in self.columns:
            if not isinstance(entry, str):
                resolved.append(entry)
                continue
            try:
                found = parent._find_column_by_id(entry)
            except KeyError as e:
                raise ValueError(f"Column '{entry}' not found in table '{parent.name}'") from e
            resolved.append(found)

        self.columns = resolved

    @field_serializer("columns")
    def serialize_columns(self, columns: list[ColumnRef | Column]) -> list[str]:
        """Serialize the group's columns as a list of IDs.

        Parameters
        ----------
        columns
            The columns to serialize.

        Returns
        -------
        `list` [ `str` ]
            String entries unchanged and the ``id`` of each `Column` entry.
        """
        ids: list[str] = []
        for entry in columns:
            ids.append(entry if isinstance(entry, str) else entry.id)
        return ids
class ColumnOverrides(BaseModel):
    """Allowed overrides for a referenced column.

    Notes
    -----
    All of these fields are optional. Values of None may be explicitly set to
    override the corresponding attribute in the referenced column but only
    for certain fields (see validation in `_check_non_nullable_overrides`).
    """

    model_config = CONFIG.copy()

    datatype: DataType | None = None
    """New datatype for the column."""

    length: int | None = None
    """New length for the column."""

    description: str | None = None
    """New description for the column."""

    nullable: bool | None = None
    """New nullable flag for the column."""

    tap_principal: int | None = Field(default=None, alias="tap:principal")
    """Override for the TAP_SCHEMA 'principal' flag."""

    tap_column_index: int | None = Field(default=None, alias="tap:column_index")
    """Override for the TAP_SCHEMA column index."""

    @model_validator(mode="before")
    @classmethod
    def _check_non_nullable_overrides(cls, data: Any) -> Any:
        """Check that certain fields are not overridden to null.

        Parameters
        ----------
        data
            Raw input for the model. Input that is not a `dict` is returned
            unchanged.

        Returns
        -------
        `Any`
            The unmodified input data.

        Raises
        ------
        ValueError
            Raised if ``datatype``, ``length``, ``nullable`` or the TAP
            principal flag is explicitly set to `None`.
        """
        if not isinstance(data, dict):
            return data
        # The TAP principal flag may arrive under its field name or under
        # its ``tap:principal`` alias, so both spellings are checked.
        non_nullable_keys = ("datatype", "length", "nullable", "tap_principal", "tap:principal")
        for name in non_nullable_keys:
            if name in data and data[name] is None:
                raise ValueError(f"The '{name}' field cannot be overridden to null")
        return data

    @field_serializer("datatype")
    def serialize_datatype(self, value: DataType | None) -> str | None:
        """Convert `DataType` to string when serializing to JSON/YAML.

        Parameters
        ----------
        value
            The `DataType` value to serialize, or None.

        Returns
        -------
        `str` | None
            The serialized `DataType` value, or None if the input was None.
        """
        if value is None:
            return None
        return str(value)

    @field_validator("datatype", mode="before")
    @classmethod
    def deserialize_datatype(cls, value: str | None) -> DataType | None:
        """Convert string back into `DataType` when loading from JSON/YAML.

        Parameters
        ----------
        value
            The string value to deserialize, or None.

        Returns
        -------
        `DataType` | None
            The deserialized `DataType` value, or None if the input was None.
        """
        if value is None:
            return None
        return DataType(value)
class ColumnResourceRef(BaseModel):
    """A column which is derived from an external resource."""

    ref_name: str | None = None
    """Name of the referenced column in the resource, when it differs from
    the key."""

    overrides: ColumnOverrides | None = None
    """Optional overrides applied to the referenced column's attributes."""
# Nested mapping types describing columns referenced from external
# resources. NOTE(review): the three levels are presumably keyed by resource,
# table and column name respectively -- confirm against the code that
# consumes ``Table.column_refs``.
ResourceColumnMap: TypeAlias = dict[str, ColumnResourceRef | None]
ResourceTableMap: TypeAlias = dict[str, ResourceColumnMap]
ResourceMap: TypeAlias = dict[str, ResourceTableMap]
class Table(BaseObject):
    """Model of a database table."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """Primary key of the table, as a single string or a list of strings."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field("MyISAM", alias="mysql:engine")
    """MySQL engine to use for the table."""

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """MySQL charset to use for the table."""

    column_refs: ResourceMap = Field(default_factory=dict, alias="columnRefs")
    """Columns referenced from external resources."""

    columns: list[Column] = Field(default_factory=list)
    """Columns of the table."""

    column_groups: list[ColumnGroup] = Field(default_factory=list, alias="columnGroups")
    """Column groups of the table."""

    constraints: list[_ConstraintType] = Field(default_factory=list)
    """Constraints defined on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """Indexes defined on the table."""

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Ensure that no two columns share a name.

        Parameters
        ----------
        columns
            The columns to check.

        Returns
        -------
        `list` [ `Column` ]
            The columns if their names are unique.

        Raises
        ------
        ValueError
            Raised if column names are not unique.
        """
        distinct_names = {column.name for column in columns}
        if len(distinct_names) != len(columns):
            raise ValueError("Column names must be unique")
        return columns

    @model_validator(mode="after")
    def check_tap_table_index(self, info: ValidationInfo) -> Table:
        """Require a TAP table index when the validation context asks for it.

        Parameters
        ----------
        info
            Validation info whose context may carry the
            ``check_tap_table_indexes`` flag that enables this check.

        Returns
        -------
        `Table`
            The table being validated.

        Raises
        ------
        ValueError
            Raised if the check is enabled and the table has no TAP table
            index.
        """
        ctx = info.context
        enabled = ctx and ctx.get("check_tap_table_indexes", False)
        if enabled and self.tap_table_index is None:
            raise ValueError("Table is missing a TAP table index")
        return self

    @model_validator(mode="after")
    def check_tap_principal(self, info: ValidationInfo) -> Table:
        """Require at least one column flagged as 'principal' for TAP when
        the validation context asks for it.

        Parameters
        ----------
        info
            Validation info whose context may carry the
            ``check_tap_principal`` flag that enables this check.

        Returns
        -------
        `Table`
            The table being validated.

        Raises
        ------
        ValueError
            Raised if the check is enabled and no column has
            ``tap_principal`` equal to 1.
        """
        ctx = info.context
        if not ctx or not ctx.get("check_tap_principal", False):
            return self
        if any(col.tap_principal == 1 for col in self.columns):
            return self
        raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'")

    def _find_column_by_id(self, id: str) -> Column:
        """Return the first column with the given ID.

        Parameters
        ----------
        id
            The ID of the column to find.

        Returns
        -------
        `Column`
            The column with the given ID.

        Raises
        ------
        KeyError
            Raised if no column of the table has the given ID.
        """
        match = next((column for column in self.columns if column.id == id), None)
        if match is None:
            raise KeyError(f"Column '{id}' not found in table '{self.name}'")
        return match

    def _find_column_by_name(self, name: str) -> Column:
        """Return the first column with the given name.

        Parameters
        ----------
        name
            The name of the column to find.

        Returns
        -------
        `Column`
            The column with the given name.

        Raises
        ------
        KeyError
            Raised if no column of the table has the given name.
        """
        match = next((column for column in self.columns if column.name == name), None)
        if match is None:
            raise KeyError(f"Column '{name}' not found in table '{self.name}'")
        return match

    @model_validator(mode="after")
    def dereference_column_groups(self: Table) -> Table:
        """Attach this table to each column group and resolve its column
        references.

        Returns
        -------
        `Table`
            The table with dereferenced column groups.
        """
        for column_group in self.column_groups:
            column_group.table = self
            column_group._dereference_columns()
        return self

    @field_serializer("columns")
    def _serialize_columns(self, columns: list[Column]) -> list[dict[str, Any]]:
        """Serialize the columns, omitting resource reference columns.

        Each remaining column is dumped by alias, leaving out `None` values
        and values equal to their defaults.
        """
        dumped: list[dict[str, Any]] = []
        for col in columns:
            if col._is_resource_ref:
                continue
            dumped.append(
                col.model_dump(
                    by_alias=True,
                    exclude_none=True,
                    exclude_defaults=True,
                )
            )
        return dumped
class SchemaVersion(BaseModel):
    """Model of a schema's version information."""

    current: str
    """Current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """Versions the schema is compatible with; empty by default."""

    read_compatible: list[str] = Field(default_factory=list)
    """Versions the schema is read compatible with; empty by default."""
class SchemaIdVisitor:
    """Walk a schema and register its objects in the schema's ID map.

    Notes
    -----
    An ID seen more than once is recorded in the ``duplicates`` set rather
    than raising an error. Only the first object carrying a given ID ends
    up in the map; this should not matter because the schema's
    ``model_validator`` raises if any duplicates were collected.
    """

    def __init__(self) -> None:
        """Create a new SchemaIdVisitor with no schema and no duplicates."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Register an object in the ID map of the schema being visited.

        Objects without an ``id`` attribute are ignored, and nothing is
        recorded while no schema has been set.

        Parameters
        ----------
        obj
            The object to add to the ID map.
        """
        if not hasattr(obj, "id"):
            return
        obj_id = obj.id
        if self.schema is None:
            return
        id_map = self.schema._id_map
        if obj_id in id_map:
            # Keep the first object and remember the clash
            self.duplicates.add(obj_id)
            return
        id_map[obj_id] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Visit the objects in a schema and build the ID map.

        Parameters
        ----------
        schema
            The schema object to visit.

        Notes
        -----
        The schema is remembered on the visitor and the set of duplicates
        from any previous visit is emptied before the walk starts.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(schema)
        for table in schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table, then its columns, then its constraints.

        Parameters
        ----------
        table
            The table object to visit.
        """
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)

    def visit_column(self, column: Column) -> None:
        """Visit a column object.

        Parameters
        ----------
        column
            The column object to visit.
        """
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object.

        Parameters
        ----------
        constraint
            The constraint object to visit.
        """
        self.add(constraint)
1214T = TypeVar("T", bound=BaseObject)
1217def _strip_ids(data: Any) -> Any:
1218 """Recursively strip '@id' fields from a dictionary or list.
1220 Parameters
1221 ----------
1222 data
1223 The data to strip IDs from, which can be a dictionary, list, or any
1224 other type. Other types will be returned unchanged.
1225 """
1226 if isinstance(data, dict):
1227 data.pop("@id", None)
1228 for k, v in data.items():
1229 data[k] = _strip_ids(v)
1230 return data
1231 elif isinstance(data, list):
1232 return [_strip_ids(item) for item in data]
1233 else:
1234 return data
1237def _append_error(
1238 errors: list[InitErrorDetails],
1239 loc: tuple,
1240 input_value: Any,
1241 error_message: str,
1242 error_type: str = "value_error",
1243) -> None:
1244 """Append an error to the errors list.
1246 Parameters
1247 ----------
1248 errors : list[InitErrorDetails]
1249 The list of errors to append to.
1250 loc : tuple
1251 The location of the error in the schema.
1252 input_value : Any
1253 The input value that caused the error.
1254 error_message : str
1255 The error message to include in the context.
1256 """
1257 errors.append(
1258 {
1259 "type": error_type,
1260 "loc": loc,
1261 "input": input_value,
1262 "ctx": {"error": error_message},
1263 }
1264 )
class Resource(BaseModel):
    """A resource definition referencing an external schema.

    The only field is the URI the referenced schema is loaded from.
    """

    # Required field: ``Field(...)`` declares no default value.
    uri: str = Field(..., description="Resource URI or path")
    """URI of the schema resource which may be a local path, ``resource://``,
    or remote URL."""
class Schema(BaseObject, Generic[T]):
    """Database schema model.

    This represents a database schema, which contains one or more tables.

    Notes
    -----
    Besides its public fields a schema keeps two private maps. ``_id_map``
    (ID to object) is filled by ``create_id_map`` and backs ``schema[id]``
    and ``id in schema``. ``_resource_map`` (resource name to loaded
    `Schema`) is filled by ``_load_resources`` from the ``resources`` field.
    """

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    resources: dict[str, Resource] = Field(default_factory=dict)
    """External resources referenced by this schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    _id_map: dict[str, Any] = PrivateAttr(default_factory=dict)
    """Map of IDs to objects."""

    _resource_map: dict[str, Schema] = PrivateAttr(default_factory=dict)
    """Map of resource names to loaded schemas."""
@model_validator(mode="after")
def _load_resources(self: Schema, info: ValidationInfo) -> Schema:
    """Load every external resource of this schema into ``_resource_map``.

    Each entry of ``resources`` is loaded with `Schema.from_uri` and stored
    under its resource name. If the validation context carries a
    ``resource_path``, resource URIs are joined onto its parent first.

    Returns
    -------
    `Schema`
        The schema being validated.

    Raises
    ------
    ValueError
        Raised if a resource cannot be loaded.
    """
    # Work on a copy so the caller's context is left untouched.
    context = info.context.copy() if info.context else {}
    # Ignore this flag for loading the resources themselves
    context.pop("dereference_resources", None)

    # The location this schema was read from, if known, is the base for
    # resolving relative resource paths.
    resource_path = context.pop("resource_path", None)
    base_uri = resource_path.parent() if resource_path is not None else None

    for resource_name, resource in self.resources.items():
        uri = resource.uri

        if base_uri is not None:
            resolved_uri = base_uri.join(uri, forceDirectory=False)
            if resolved_uri != resource.uri:
                logger.info(
                    "Resolved relative URI '%s' for resource '%s' to '%s' using base URI '%s'",
                    resource.uri,
                    resource_name,
                    resolved_uri,
                    base_uri,
                )
            uri = resolved_uri

        try:
            loaded_schema = Schema.from_uri(uri, context=context)
            self._resource_map[resource_name] = loaded_schema
            logger.debug("Loaded resource '%s' from URI '%s'", resource_name, uri)
        except Exception as e:
            raise ValueError(f"Failed to load resource '{resource_name}' from URI '{uri}': {e}") from e
    return self
1349 def _find_table_by_name(self, name: str) -> Table:
1350 """Find a table by name.
1352 Parameters
1353 ----------
1354 name
1355 The name of the table to find.
1357 Returns
1358 -------
1359 `Table`
1360 The table with the given name.
1362 Raises
1363 ------
1364 KeyError
1365 Raised if the table is not found.
1366 """
1367 for table in self.tables:
1368 if table.name == name:
1369 return table
1370 raise KeyError(f"Table '{name}' not found in schema '{self.name}'")
@model_validator(mode="after")
def _dereference_resource_columns(self: Schema, info: ValidationInfo) -> Schema:
    """Turn the column references of every table into actual columns.

    For each table with ``column_refs``, the referenced resource schema is
    looked up in ``_resource_map`` and the referenced columns are added to
    the table by ``_process_column_refs``. Two context keys are read:
    ``dereference_resources`` (default `False`) and
    ``column_ref_index_increment`` (default `None`).

    Returns
    -------
    `Schema`
        The schema being validated.

    Raises
    ------
    ValueError
        Raised if a referenced resource name is not in ``_resource_map``.
    """
    context = info.context
    if context is None:
        dereference_resources = False
        index_increment: int | None = None
    else:
        dereference_resources = context.get("dereference_resources", False)
        index_increment = context.get("column_ref_index_increment", None)

    for table in self.tables:
        column_refs = table.column_refs
        if not column_refs:
            continue

        for resource_name, ref_tables in column_refs.items():
            resource_schema = self._resource_map.get(resource_name)
            if resource_schema is None:
                raise ValueError(f"Schema resource '{resource_name}' was not found in resources")
            self._process_column_refs(
                table,
                ref_tables,
                resource_schema,
                dereference_resources,
                index_increment,
            )

        if dereference_resources and len(table.column_refs) > 0:
            # Clear column refs in table if fully dereferencing
            logger.debug(
                "Clearing columnRefs in table '%s' after dereferencing resource columns",
                table.name,
            )
            table.column_refs = {}
    return self
1406 @classmethod
1407 def _process_column_refs(
1408 cls,
1409 table: Table,
1410 ref_tables: ResourceTableMap,
1411 resource_schema: Schema,
1412 dereference_resources: bool = False,
1413 column_ref_index_increment: int | None = None,
1414 ) -> None:
1415 """Process column references from an external resource and add them
1416 to the given table as columns.
1417 """
1418 current_column_index = column_ref_index_increment if column_ref_index_increment is not None else -1
1420 for table_name, columns in ref_tables.items():
1421 try:
1422 resource_table = resource_schema._find_table_by_name(table_name)
1423 except KeyError as e:
1424 raise ValueError(
1425 f"Table '{table_name}' not found in resource '{resource_schema.name}'"
1426 ) from e
1427 for local_column_name, column_ref in columns.items():
1428 if column_ref is not None and column_ref.ref_name is not None:
1429 # Use specified ref_name
1430 ref_column_name = column_ref.ref_name
1431 else:
1432 # Use the local column name if no ref_name
1433 # specified
1434 ref_column_name = local_column_name
1436 # Check if referenced column exists in resource
1437 try:
1438 base_column = resource_table._find_column_by_name(ref_column_name)
1439 except KeyError:
1440 # The ref_name is specified but column is not
1441 # found
1442 if column_ref is not None and column_ref.ref_name is not None:
1443 raise ValueError(
1444 f"Column '{ref_column_name}' not found in table '{table_name}' "
1445 f"from resource '{resource_schema.name}'"
1446 )
1447 # The ref_name is not specified and the local
1448 # column name is not found
1449 raise ValueError(
1450 f"Column '{local_column_name}' not found in table '{table_name}' "
1451 f"from resource '{resource_schema.name}' and no ref_name provided"
1452 )
1454 # Create a copy of the base column
1455 column_copy = base_column.model_copy()
1457 # Set the local name (key from the mapping)
1458 column_copy.name = local_column_name
1460 if not dereference_resources:
1461 # Flag the column as a resource reference so it will not be
1462 # written out during serialization
1463 column_copy._is_resource_ref = True
1465 # Apply overrides to the referenced column definition
1466 overrides = column_ref.overrides if column_ref is not None else None
1467 if overrides is not None:
1468 column_copy._update_from_overrides(overrides)
1470 # Manually set the ID of the copied column as ID generation has
1471 # already occurred by now
1472 column_copy.id = f"{table.id}.{local_column_name}"
1474 # Apply automatic assignment of 'tap:column_index', if enabled
1475 if column_ref_index_increment is not None:
1476 if (not overrides) or (overrides.tap_column_index is None):
1477 column_copy.tap_column_index = current_column_index
1478 current_column_index += column_ref_index_increment
1479 logger.debug(
1480 "Automatically assigned 'tap:column_index' %s to "
1481 "column '%s' in table '%s' from resource '%s'",
1482 column_copy.tap_column_index,
1483 local_column_name,
1484 table_name,
1485 resource_schema.name,
1486 )
1487 else:
1488 # Skip automatic assignment of 'tap:column_index' if it
1489 # is already overridden
1490 logger.debug(
1491 "Skipping automatic assignment of 'tap:column_index' for column "
1492 "'%s' in table '%s' from resource '%s' as it is already overridden to %s",
1493 local_column_name,
1494 table_name,
1495 resource_schema.name,
1496 column_copy.tap_column_index,
1497 )
1498 table.columns.append(column_copy)
1499 logger.debug(
1500 "Dereferenced column '%s' from table '%s' in resource '%s' into table '%s'",
1501 local_column_name,
1502 table_name,
1503 resource_schema.name,
1504 table.name,
1505 )
@model_validator(mode="before")
@classmethod
def generate_ids(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
    """Fill in '@id' for schema objects that do not already have one.

    IDs are only generated when the validation context has a truthy
    ``id_generation`` entry. Schema, table, constraint and index IDs are
    ``#<name>``; column and column group IDs are ``#<table name>.<name>``.

    Parameters
    ----------
    values
        The values of the schema.
    info
        Validation context used to determine if ID generation is enabled.

    Returns
    -------
    `dict` [ `str`, `Any` ]
        The same ``values`` mapping, with generated IDs added in place.
    """
    context = info.context
    if not (context and context.get("id_generation", False)):
        logger.debug("Skipping ID generation")
        return values

    schema_name = values["name"]
    if "@id" not in values:
        values["@id"] = f"#{schema_name}"
        logger.debug("Generated ID '%s' for schema '%s'", values["@id"], schema_name)

    for table in values.get("tables", ()):
        if "@id" not in table:
            table["@id"] = f"#{table['name']}"
            logger.debug("Generated ID '%s' for table '%s'", table["@id"], table["name"])

        for column in table.get("columns", ()):
            if "@id" in column:
                continue
            column["@id"] = f"#{table['name']}.{column['name']}"
            logger.debug("Generated ID '%s' for column '%s'", column["@id"], column["name"])

        for column_group in table.get("columnGroups", ()):
            if "@id" in column_group:
                continue
            column_group["@id"] = f"#{table['name']}.{column_group['name']}"
            logger.debug(
                "Generated ID '%s' for column group '%s'",
                column_group["@id"],
                column_group["name"],
            )

        for constraint in table.get("constraints", ()):
            if "@id" in constraint:
                continue
            constraint["@id"] = f"#{constraint['name']}"
            logger.debug(
                "Generated ID '%s' for constraint '%s'",
                constraint["@id"],
                constraint["name"],
            )

        for index in table.get("indexes", ()):
            if "@id" in index:
                continue
            index["@id"] = f"#{index['name']}"
            logger.debug("Generated ID '%s' for index '%s'", index["@id"], index["name"])
    return values
@field_validator("tables", mode="after")
@classmethod
def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
    """Reject a schema in which two tables share a name.

    Parameters
    ----------
    tables
        The tables to check.

    Returns
    -------
    `list` [ `Table` ]
        The tables, unchanged, if all names are distinct.

    Raises
    ------
    ValueError
        Raised if table names are not unique.
    """
    names = [table.name for table in tables]
    if len(set(names)) != len(names):
        raise ValueError("Table names must be unique")
    return tables
@model_validator(mode="after")
def check_tap_table_indexes(self, info: ValidationInfo) -> Schema:
    """Check that no two tables share a 'tap:table_index' value.

    The check only runs when the validation context has a truthy
    ``check_tap_table_indexes`` entry. Tables whose index is `None` are
    ignored.

    Parameters
    ----------
    info
        The validation context used to determine if the check is enabled.

    Returns
    -------
    `Schema`
        The schema being validated.

    Raises
    ------
    ValueError
        Raised if a 'tap:table_index' value occurs on more than one table.
    """
    context = info.context
    if not (context and context.get("check_tap_table_indexes", False)):
        return self

    seen_indexes = set()
    for table in self.tables:
        table_index = table.tap_table_index
        if table_index is None:
            continue
        if table_index in seen_indexes:
            raise ValueError(f"Duplicate 'tap:table_index' value {table_index} found in schema")
        seen_indexes.add(table_index)
    return self
@model_validator(mode="after")
def check_unique_constraint_names(self: Schema) -> Schema:
    """Check that constraint names are unique across all tables.

    Returns
    -------
    `Schema`
        The schema being validated.

    Raises
    ------
    ValueError
        Raised if duplicate constraint names are found in the schema. The
        message lists a name once for every repeated occurrence.
    """
    seen_names = set()
    repeated = []

    all_names = (constraint.name for table in self.tables for constraint in table.constraints)
    for name in all_names:
        if name in seen_names:
            repeated.append(name)
        else:
            seen_names.add(name)

    if repeated:
        raise ValueError(f"Duplicate constraint names found in schema: {repeated}")

    return self
@model_validator(mode="after")
def check_unique_index_names(self: Schema) -> Schema:
    """Check that index names are unique across all tables.

    Returns
    -------
    `Schema`
        The schema being validated.

    Raises
    ------
    ValueError
        Raised if duplicate index names are found in the schema. The
        message lists a name once for every repeated occurrence.
    """
    seen_names = set()
    repeated = []

    all_names = (index.name for table in self.tables for index in table.indexes)
    for name in all_names:
        if name in seen_names:
            repeated.append(name)
        else:
            seen_names.add(name)

    if repeated:
        raise ValueError(f"Duplicate index names found in schema: {repeated}")

    return self
@model_validator(mode="after")
def create_id_map(self: Schema) -> Schema:
    """Populate the map of IDs to objects for this schema.

    Nothing is done if the map already has entries. Otherwise a
    `SchemaIdVisitor` walks the schema and fills the map.

    Returns
    -------
    `Schema`
        The schema with the ID map created.

    Raises
    ------
    ValueError
        Raised if duplicate identifiers are found in the schema.
    """
    if self._id_map:
        logger.debug("Ignoring call to create_id_map() - ID map was already populated")
        return self

    visitor: SchemaIdVisitor = SchemaIdVisitor()
    visitor.visit_schema(self)

    duplicates = visitor.duplicates
    if duplicates:
        raise ValueError(
            "Duplicate IDs found in schema:\n " + "\n ".join(duplicates) + "\n"
        )

    logger.debug("Created ID map with %d entries", len(self._id_map))
    return self
def _validate_column_id(
    self: Schema,
    column_id: str,
    loc: tuple,
    errors: list[InitErrorDetails],
) -> None:
    """Check that an ID used by a constraint refers to a column.

    An error is appended if the ID is unknown to this schema, or if the
    object it identifies is not a `Column`. Nothing is raised.

    Parameters
    ----------
    column_id : str
        The column ID to validate.
    loc : tuple
        The location of the error in the schema.
    errors : list[InitErrorDetails]
        The list of errors to append to.
    """
    if column_id not in self:
        message = f"Column ID '{column_id}' not found in schema"
    elif not isinstance(self[column_id], Column):
        message = f"ID '{column_id}' does not refer to a Column object"
    else:
        return
    _append_error(errors, loc, column_id, message)
def _validate_foreign_key_column(
    self: Schema,
    column_id: str,
    table: Table,
    loc: tuple,
    errors: list[InitErrorDetails],
) -> None:
    """Check that a constraint column belongs to the given table.

    An error is appended if the table has no column with this ID. Nothing
    is raised.

    Parameters
    ----------
    column_id : str
        The column ID to validate.
    table : Table
        The table that is expected to contain the column.
    loc : tuple
        The location of the error in the schema.
    errors : list[InitErrorDetails]
        The list of errors to append to.
    """
    try:
        table._find_column_by_id(column_id)
    except KeyError:
        message = f"Column '{column_id}' not found in table '{table.name}'"
        _append_error(errors, loc, column_id, message)
@model_validator(mode="after")
def check_constraints(self: Schema) -> Schema:
    """Check constraint objects for validity. This needs to be deferred
    until after the schema is fully loaded and the ID map is created.

    For foreign key and unique constraints, every entry of ``columns`` must
    be the ID of a `Column` and must belong to the constraint's own table.
    For foreign key constraints, every entry of ``referenced_columns`` must
    be the ID of a `Column`. All problems are collected and reported
    together.

    Raises
    ------
    pydantic.ValidationError
        Raised if any constraints are invalid.

    Returns
    -------
    `Schema`
        The schema being validated.
    """
    errors: list[InitErrorDetails] = []

    for table_index, table in enumerate(self.tables):
        for constraint_index, constraint in enumerate(table.constraints):
            if isinstance(constraint, ForeignKeyConstraint):
                column_ids = list(constraint.columns)
                referenced_column_ids = list(constraint.referenced_columns)
            elif isinstance(constraint, UniqueConstraint):
                column_ids = list(constraint.columns)
                referenced_column_ids = []
            else:
                # No extra checks are required on CheckConstraint objects.
                column_ids = []
                referenced_column_ids = []

            base_loc = ("tables", table_index, "constraints", constraint_index)

            for column_id in column_ids:
                loc = (*base_loc, "columns", column_id)
                # The ID must refer to a column somewhere in the schema...
                self._validate_column_id(column_id, loc, errors)
                # ...and that column must be within the source table.
                self._validate_foreign_key_column(column_id, table, loc, errors)

            # Validate the primary key (reference) columns
            for referenced_column_id in referenced_column_ids:
                self._validate_column_id(
                    referenced_column_id,
                    (*base_loc, "referenced_columns", referenced_column_id),
                    errors,
                )

    if errors:
        raise ValidationError.from_exception_data("Schema validation failed", errors)

    return self
1847 def __getitem__(self, id: str) -> BaseObject:
1848 """Get an object by its ID.
1850 Parameters
1851 ----------
1852 id
1853 The ID of the object to get.
1855 Raises
1856 ------
1857 KeyError
1858 Raised if the object with the given ID is not found in the schema.
1859 """
1860 if id not in self:
1861 raise KeyError(f"Object with ID '{id}' not found in schema")
1862 return self._id_map[id]
1864 def __contains__(self, id: str) -> bool:
1865 """Check if an object with the given ID is in the schema.
1867 Parameters
1868 ----------
1869 id
1870 The ID of the object to check.
1871 """
1872 return id in self._id_map
def find_object_by_id(self, id: str, obj_type: type[T]) -> T:
    """Look up an object by ID and check that it has the expected type.

    Parameters
    ----------
    id
        The ID of the object to find.
    obj_type
        The type of the object to find.

    Returns
    -------
    BaseObject
        The object with the given ID and type.

    Raises
    ------
    KeyError
        If the object with the given ID is not found in the schema.
    TypeError
        If the object that is found does not have the right type.

    Notes
    -----
    The actual return type is the user-specified argument ``T``, which is
    expected to be a subclass of `BaseObject`.
    """
    candidate = self[id]
    if isinstance(candidate, obj_type):
        return candidate
    raise TypeError(f"Object with ID '{id}' is not of type '{obj_type.__name__}'")
def get_table_by_column(self, column: Column) -> Table:
    """Find the table that contains a column.

    Parameters
    ----------
    column
        The column to find.

    Returns
    -------
    `Table`
        The first table whose ``columns`` contains the column, as decided
        by the ``in`` operator.

    Raises
    ------
    ValueError
        If the column is not found in any table.
    """
    owner = next((candidate for candidate in self.tables if column in candidate.columns), None)
    if owner is None:
        raise ValueError(f"Column '{column.name}' not found in any table")
    return owner
@classmethod
def from_uri(cls, resource_path: ResourcePathExpression, context: dict[str, Any] | None = None) -> Schema:
    """Load a `Schema` from a string representing a ``ResourcePath``.

    Parameters
    ----------
    resource_path
        The ``ResourcePath`` pointing to a YAML file.
    context
        Pydantic context to be used in validation. It is copied, so the
        caller's mapping is never modified. `None` means an empty context.

    Returns
    -------
    `Schema`
        The Felis schema loaded from the resource.

    Raises
    ------
    yaml.YAMLError
        Raised if there is an error loading the YAML data.
    ValueError
        Raised if there is an error reading the resource.
    pydantic.ValidationError
        Raised if the schema fails validation.
    """
    try:
        rp = ResourcePath(resource_path, forceAbsolute=False, forceDirectory=False)
        rp_data = rp.read()
    except Exception as e:
        raise ValueError(f"Error reading resource from '{resource_path}' : {e}") from e
    yaml_data = yaml.safe_load(rp_data)
    context = dict(context) if context is not None else {}
    # Append the resource path to the context for resolving resource URLs.
    context["resource_path"] = rp
    return Schema.model_validate(yaml_data, context=context)
@classmethod
def from_stream(cls, source: IO[str], context: dict[str, Any] | None = None) -> Schema:
    """Load a `Schema` from a file stream which should contain YAML data.

    Parameters
    ----------
    source
        The file stream to read from.
    context
        Pydantic context to be used in validation. `None` means an empty
        context; a fresh dictionary is created for each such call.

    Returns
    -------
    `Schema`
        The Felis schema loaded from the stream.

    Raises
    ------
    yaml.YAMLError
        Raised if there is an error loading the YAML file.
    pydantic.ValidationError
        Raised if the schema fails validation.
    """
    logger.debug("Loading schema from: '%s'", source)
    yaml_data = yaml.safe_load(source)
    # Validators still receive a dictionary, as they did with the old
    # ``{}`` default, but it is no longer one object shared by all calls.
    if context is None:
        context = {}
    return Schema.model_validate(yaml_data, context=context)
1992 def _model_dump(self, strip_ids: bool = False, sort_columns: bool = False) -> dict[str, Any]:
1993 """Dump the schema as a dictionary with some default arguments
1994 applied.
1996 Parameters
1997 ----------
1998 strip_ids
1999 Whether to strip the IDs from the dumped data. Defaults to `False`.
2000 sort_columns
2001 Whether to sort columns alphabetically by name. Defaults to
2002 `False`.
2004 Returns
2005 -------
2006 `dict` [ `str`, `Any` ]
2007 The dumped schema data as a dictionary.
2008 """
2009 data = self.model_dump(by_alias=True, exclude_none=True, exclude_defaults=True)
2010 if strip_ids:
2011 data = _strip_ids(data)
2012 if sort_columns:
2013 for table in data.get("tables", []):
2014 table["columns"] = sorted(table.get("columns", []), key=itemgetter("name"))
2015 return data
def dump_yaml(
    self, stream: IO[str] = sys.stdout, strip_ids: bool = False, sort_columns: bool = False
) -> None:
    """Write the schema to a stream as block-style YAML.

    Parameters
    ----------
    stream
        The stream to write the YAML data to.
    strip_ids
        Whether to strip the IDs from the dumped data. Defaults to `False`.
    sort_columns
        Whether to sort columns alphabetically by name. Defaults to
        `False`.
    """
    # Keys keep the order produced by the dump instead of being sorted.
    yaml_options = {"default_flow_style": False, "sort_keys": False}
    dumped = self._model_dump(strip_ids=strip_ids, sort_columns=sort_columns)
    yaml.safe_dump(dumped, stream, **yaml_options)
def dump_json(
    self, stream: IO[str] = sys.stdout, strip_ids: bool = False, sort_columns: bool = False
) -> None:
    """Write the schema to a stream as JSON indented by four spaces.

    Parameters
    ----------
    stream
        The stream to write the JSON data to.
    strip_ids
        Whether to strip the IDs from the dumped data. Defaults to `False`.
    sort_columns
        Whether to sort columns alphabetically by name. Defaults to
        `False`.
    """
    # Keys keep the order produced by the dump instead of being sorted.
    json_options = {"indent": 4, "sort_keys": False}
    dumped = self._model_dump(strip_ids=strip_ids, sort_columns=sort_columns)
    json.dump(dumped, stream, **json_options)