Coverage for python/felis/datamodel.py: 30%

738 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 08:36 +0000

1"""Define Pydantic data models for Felis.""" 

2 

3# This file is part of felis. 

4# 

5# Developed for the LSST Data Management System. 

6# This product includes software developed by the LSST Project 

7# (https://www.lsst.org). 

8# See the COPYRIGHT file at the top-level directory of this distribution 

9# for details of code ownership. 

10# 

11# This program is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU General Public License as published by 

13# the Free Software Foundation, either version 3 of the License, or 

14# (at your option) any later version. 

15# 

16# This program is distributed in the hope that it will be useful, 

17# but WITHOUT ANY WARRANTY; without even the implied warranty of 

18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

19# GNU General Public License for more details. 

20# 

21# You should have received a copy of the GNU General Public License 

22# along with this program. If not, see <https://www.gnu.org/licenses/>. 

23 

24from __future__ import annotations 

25 

26import json 

27import logging 

28import sys 

29from collections.abc import Sequence 

30from enum import StrEnum, auto 

31from operator import itemgetter 

32from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar 

33 

34import yaml 

35from astropy import units as units # type: ignore 

36from astropy.io.votable import ucd # type: ignore 

37from lsst.resources import ResourcePath, ResourcePathExpression 

38from pydantic import ( 

39 BaseModel, 

40 ConfigDict, 

41 Field, 

42 PrivateAttr, 

43 ValidationError, 

44 ValidationInfo, 

45 field_serializer, 

46 field_validator, 

47 model_validator, 

48) 

49from pydantic_core import InitErrorDetails 

50 

51from .db._dialects import get_supported_dialects, string_to_typeengine 

52from .db._sqltypes import get_type_func 

53from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode 

54 

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Names exported by a star-import of this module.
__all__ = (
    "BaseObject",
    "CheckConstraint",
    "Column",
    "ColumnOverrides",
    "ColumnResourceRef",
    "Constraint",
    "DataType",
    "ForeignKeyConstraint",
    "Index",
    "Resource",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

CONFIG = ConfigDict(
    populate_by_name=True,  # Accept the field name as well as its alias on input.
    extra="forbid",  # Do not allow extra fields.
    str_strip_whitespace=True,  # Strip whitespace from string fields.
    use_enum_values=False,  # Keep enum members on the model instead of their raw values.
)
"""Pydantic model configuration as described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum length for a description field."""

# ``str`` constrained to at least ``DESCR_MIN_LENGTH`` characters.
DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""Type for a description, which must be three or more characters long."""

89 

90 

class BaseObject(BaseModel):
    """Common base class of the Felis data model.

    Every class that represents an object in the Felis data model is
    expected to derive from this one, which supplies the shared ``name``,
    ``@id``, ``description`` and ``votable:utype`` fields.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """Name of the database object."""

    id: str = Field(alias="@id")
    """Unique identifier of the database object."""

    description: DescriptionStr | None = None
    """Description of the database object."""

    votable_utype: str | None = Field(None, alias="votable:utype")
    """VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Require a description when the validation context asks for one.

        Parameters
        ----------
        info
            Validation info whose ``context`` mapping may contain the
            ``check_description`` flag that switches this check on.

        Returns
        -------
        `BaseObject`
            The object being validated.

        Raises
        ------
        ValueError
            Raised if the check is enabled and the description is missing,
            empty, or shorter than ``DESCR_MIN_LENGTH`` characters.
        """
        # A missing or empty context means the check is switched off.
        context = info.context or {}
        if context.get("check_description", False):
            text = self.description
            if text is None or text == "":
                raise ValueError("Description is required and must be non-empty")
            if len(text) < DESCR_MIN_LENGTH:
                raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self

135 

136 

137class DataType(StrEnum): 

138 """``Enum`` representing the data types supported by Felis.""" 

139 

140 boolean = auto() 

141 byte = auto() 

142 short = auto() 

143 int = auto() 

144 long = auto() 

145 float = auto() 

146 double = auto() 

147 char = auto() 

148 string = auto() 

149 unicode = auto() 

150 text = auto() 

151 binary = auto() 

152 timestamp = auto() 

153 

154 

def validate_ivoa_ucd(ivoa_ucd: str | None) -> str | None:
    """Validate IVOA UCD values.

    Parameters
    ----------
    ivoa_ucd
        IVOA UCD value to check. `None` is accepted and passed through
        unchanged, so optional fields can use this helper directly.

    Returns
    -------
    `str` or `None`
        The IVOA UCD value if it is valid, or `None` if none was given.

    Raises
    ------
    ValueError
        If the IVOA UCD value is invalid. The parser's original exception
        is attached as the cause.
    """
    if ivoa_ucd is None:
        return None
    try:
        # NOTE(review): ``has_colon`` is switched on by the presence of a
        # semicolon in the value -- confirm this is the intended trigger.
        ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
    except ValueError as e:
        # Chain the cause so the parser's traceback is not lost.
        raise ValueError(f"Invalid IVOA UCD: {e}") from e
    return ivoa_ucd

179 

180 

class Column(BaseObject):
    """Column model.

    Holds the datatype, sizing, default value and the IVOA/TAP/VOTable and
    per-dialect annotations of a single table column, together with the
    validators that keep those attributes mutually consistent.
    """

    datatype: DataType
    """Datatype of the column."""

    length: int | None = Field(None, gt=0)
    """Length of the column."""

    precision: int | None = Field(None, ge=0)
    """The numerical precision of the column.

    For timestamps, this is the number of fractional digits retained in the
    seconds field.
    """

    nullable: bool = True
    """Whether the column can be ``NULL``."""

    value: str | int | float | bool | None = None
    """Default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column."""

    votable_arraysize: int | str | None = Field(None, alias="votable:arraysize")
    """VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """VOTable datatype of the column."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """MySQL datatype override on the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """PostgreSQL datatype override on the column."""

    _is_resource_ref: bool = PrivateAttr(False)
    """Whether this column is a resource reference column."""

    @model_validator(mode="after")
    def check_value(self) -> Column:
        """Check that the default value is valid.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if a default value is combined with ``autoincrement``, or
            if the Python type of the default does not match the column's
            datatype.
        """
        value = self.value
        if value is None:
            # No default value: nothing to check.
            return self
        if self.autoincrement is True:
            raise ValueError("Column cannot have both a default value and be autoincremented")
        felis_type = FelisType.felis_type(self.datatype)
        if felis_type.is_numeric:
            if felis_type in (Byte, Short, Int, Long) and not isinstance(value, int):
                raise ValueError("Default value must be an int for integer type columns")
            elif felis_type in (Float, Double) and not isinstance(value, float):
                raise ValueError("Default value must be a decimal number for float and double columns")
        elif felis_type in (String, Char, Unicode, Text):
            if not isinstance(value, str):
                raise ValueError("Default value must be a string for string columns")
            if not value:
                raise ValueError("Default value must be a non-empty string for string columns")
        elif felis_type is Boolean and not isinstance(value, bool):
            raise ValueError("Default value must be a boolean for boolean columns")
        return self

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str | None) -> str | None:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str` or `None`
            The value returned by `validate_ivoa_ucd` for ``ivoa_ucd``.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_units(self) -> Column:
        """Check that the ``fits:tunit`` or ``ivoa:unit`` field has valid
        units according to astropy. Only one may be provided.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if both FITS and IVOA units are provided, or if the unit is
            invalid.
        """
        fits_unit = self.fits_tunit
        ivoa_unit = self.ivoa_unit

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                # Chain the cause so the unit parser's traceback is kept.
                raise ValueError(f"Invalid unit: {e}") from e

        return self

    @model_validator(mode="before")
    @classmethod
    def check_length(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that a valid length is provided for sized types.

        Parameters
        ----------
        values
            Values of the column.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Raises
        ------
        ValueError
            Raised if a length is not provided for a sized type.

        Notes
        -----
        A length given for a type that is not sized is only reported with a
        warning; it does not fail validation.
        """
        datatype = values.get("datatype")
        if datatype is None:
            # Skip this validation if datatype is not provided
            return values
        length = values.get("length")
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized and length is None:
            raise ValueError(
                f"Length must be provided for type '{datatype}'"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        elif not felis_type.is_sized and length is not None:
            msg = f"The datatype '{datatype}' does not support a specified length"
            if "@id" in values:
                msg += f" in column '{values['@id']}'"
            logger.warning("%s", msg)
        return values

    @model_validator(mode="after")
    def check_redundant_datatypes(self, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if a datatype override is redundant.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return self

        dialects = get_supported_dialects()

        # Collect the overrides actually set on this column. They are stored
        # under the field names ``<dialect>_datatype``; ``<dialect>:datatype``
        # is only the alias and is not an attribute of the model.
        overrides: dict[str, str] = {}
        for dialect_name in dialects:
            override = getattr(self, f"{dialect_name}_datatype", None)
            if override:
                overrides[dialect_name] = override
        if not overrides:
            # Nothing can be redundant, so skip building the type objects.
            return self

        datatype = self.datatype
        length: int | None = self.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            datatype_obj = datatype_func(length)
        else:
            datatype_obj = datatype_func()

        for dialect_name, datatype_string in overrides.items():
            dialect = dialects[dialect_name]
            db_annotation = f"{dialect_name}_datatype"
            db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
            base_compiled = datatype_obj.compile(dialect)
            override_compiled = db_datatype_obj.compile(dialect)
            if base_compiled == override_compiled:
                raise ValueError(
                    "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
                        db_annotation,
                        datatype_string,
                        self.datatype,
                        self.id,
                        "" if length is None else f" with length {length}",
                    )
                )
            logger.debug(
                "Type override of 'datatype: %s' with '%s: %s' in column '%s' "
                "compiled to '%s' and '%s'",
                self.datatype,
                db_annotation,
                datatype_string,
                self.id,
                base_compiled,
                override_compiled,
            )
        return self

    @model_validator(mode="after")
    def check_precision(self) -> Column:
        """Check that precision is only valid for timestamp columns.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if a precision is set on a column that is not a timestamp.
        """
        if self.precision is not None and self.datatype != "timestamp":
            raise ValueError("Precision is only valid for timestamp columns")
        return self

    @model_validator(mode="before")
    @classmethod
    def check_votable_arraysize(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Set the default value for the ``votable_arraysize`` field, which
        corresponds to ``arraysize`` in the IVOA VOTable standard.

        Parameters
        ----------
        values
            Values of the column.
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Notes
        -----
        Following the IVOA VOTable standard, an ``arraysize`` of 1 should not
        be used.

        An explicitly provided ``votable:arraysize`` is kept but converted to
        a string; integer values trigger a deprecation warning.
        """
        if values.get("name", None) is None or values.get("datatype", None) is None:
            # Skip bad column data that will not validate
            return values
        context = info.context if info.context else {}
        arraysize = values.get("votable:arraysize", None)
        if arraysize is None:
            length = values.get("length", None)
            datatype = values.get("datatype")
            if length is not None and length > 1:
                # Following the IVOA standard, arraysize of 1 is disallowed
                if datatype == "char":
                    arraysize = str(length)
                elif datatype in ("string", "unicode", "binary"):
                    if context.get("force_unbounded_arraysize", False):
                        arraysize = "*"
                        logger.debug(
                            "Forced VOTable's 'arraysize' to '*' on column '%s' with datatype "
                            "'%s' and length '%s'",
                            values["name"],
                            values["datatype"],
                            length,
                        )
                    else:
                        arraysize = f"{length}*"
            elif datatype in ("timestamp", "text"):
                arraysize = "*"
            if arraysize is not None:
                values["votable:arraysize"] = arraysize
                logger.debug(
                    "Set default 'votable:arraysize' to '%s' on column '%s'"
                    " with datatype '%s' and length '%s'",
                    arraysize,
                    values["name"],
                    values["datatype"],
                    values.get("length", None),
                )
        else:
            logger.debug(
                "Using existing 'votable:arraysize' of '%s' on column '%s'", arraysize, values["name"]
            )
            if isinstance(values["votable:arraysize"], int):
                logger.warning(
                    "Usage of an integer value for 'votable:arraysize' in column '%s' is deprecated",
                    values["name"],
                )
            values["votable:arraysize"] = str(arraysize)
        return values

    @field_serializer("datatype")
    def serialize_datatype(self, value: DataType) -> str:
        """Convert `DataType` to string when serializing to JSON/YAML.

        Parameters
        ----------
        value
            The `DataType` value to serialize.

        Returns
        -------
        `str`
            The serialized `DataType` value.
        """
        return str(value)

    @field_validator("datatype", mode="before")
    @classmethod
    def deserialize_datatype(cls, value: str) -> DataType:
        """Convert string back into `DataType` when loading from JSON/YAML.

        Parameters
        ----------
        value
            The string value to deserialize.

        Returns
        -------
        `DataType`
            The deserialized `DataType` value.
        """
        return DataType(value)

    @model_validator(mode="after")
    def check_votable_xtype(self) -> Column:
        """Set the default value for the ``votable_xtype`` field, which
        corresponds to an Extended Datatype or ``xtype`` in the IVOA VOTable
        standard.

        Returns
        -------
        `Column`
            The column being validated.

        Notes
        -----
        This is currently only set automatically for the Felis ``timestamp``
        datatype.
        """
        if self.datatype == DataType.timestamp and self.votable_xtype is None:
            self.votable_xtype = "timestamp"
        return self

    def _update_from_overrides(self, overrides: ColumnOverrides) -> None:
        """Update the column attributes from the given overrides.

        Parameters
        ----------
        overrides
            The column overrides to apply.

        Notes
        -----
        Using ``model_fields_set`` allows updating only the fields that are
        explicitly set in the `overrides` object. This prevents overwriting
        existing column attributes which were not explicitly provided.
        """
        if overrides.model_fields_set:
            logger.debug("Applying overrides to column '%s': %s", self.id, overrides.model_fields_set)
        for field in overrides.model_fields_set:
            setattr(self, field, getattr(overrides, field))

576 

577 

class Constraint(BaseObject):
    """Base model shared by all table constraints."""

    deferrable: bool = False
    """Whether this constraint will be declared as deferrable."""

    initially: Literal["IMMEDIATE", "DEFERRED"] | None = None
    """Value for ``INITIALLY`` clause; only used if `deferrable` is
    `True`."""

    @model_validator(mode="after")
    def check_deferrable(self) -> Constraint:
        """Reject an ``INITIALLY`` clause on a non-deferrable constraint.

        Returns
        -------
        `Constraint`
            The constraint being validated.

        Raises
        ------
        ValueError
            Raised if ``initially`` is set while ``deferrable`` is false.
        """
        has_initially = self.initially is not None
        if has_initially and not self.deferrable:
            raise ValueError("INITIALLY clause can only be used if deferrable is True")
        return self

601 

602 

class CheckConstraint(Constraint):
    """Model of a table ``CHECK`` constraint."""

    type: Literal["Check"] = Field("Check", alias="@type")
    """Type of the constraint."""

    expression: str
    """Expression for the check constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the constraint type through so '@type' is serialized.

        Parameters
        ----------
        value
            The constraint type to serialize.

        Returns
        -------
        `str`
            ``value``, unchanged.
        """
        return value

627 

628 

class UniqueConstraint(Constraint):
    """Model of a table ``UNIQUE`` constraint."""

    type: Literal["Unique"] = Field("Unique", alias="@type")
    """Type of the constraint."""

    columns: list[str]
    """Columns in the unique constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the constraint type through so '@type' is serialized.

        Parameters
        ----------
        value
            The constraint type to serialize.

        Returns
        -------
        `str`
            ``value``, unchanged.
        """
        return value

653 

654 

class ForeignKeyConstraint(Constraint):
    """Table foreign key constraint model.

    Defines a foreign key relationship between two tables in the schema.
    Both the `columns` list and the `referenced_columns` list must contain
    at least one entry, and the two lists must be the same length;
    otherwise a validation error is raised.

    Notes
    -----
    These relationships will be reflected in the TAP_SCHEMA ``keys`` and
    ``key_columns`` data.
    """

    type: Literal["ForeignKey"] = Field("ForeignKey", alias="@type")
    """Type of the constraint."""

    columns: list[str] = Field(min_length=1)
    """The columns comprising the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns", min_length=1)
    """The columns referenced by the foreign key."""

    on_delete: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action to take when the referenced row is deleted."""

    on_update: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action to take when the referenced row is updated."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the constraint type through so '@type' is serialized.

        Parameters
        ----------
        value
            The constraint type to serialize.

        Returns
        -------
        `str`
            ``value``, unchanged.
        """
        return value

    @model_validator(mode="after")
    def check_column_lengths(self) -> ForeignKeyConstraint:
        """Require `columns` and `referenced_columns` to pair up one-to-one.

        Returns
        -------
        `ForeignKeyConstraint`
            The foreign key constraint being validated.

        Raises
        ------
        ValueError
            Raised if the `columns` and `referenced_columns` lists do not have
            the same length.
        """
        local_count = len(self.columns)
        referenced_count = len(self.referenced_columns)
        if local_count == referenced_count:
            return self
        raise ValueError(
            "Columns and referencedColumns must have the same length for a ForeignKey constraint"
        )

721 

722 

# Union of the concrete constraint models; the ``type`` field of each model
# (a distinct ``Literal``) is used as the discriminator.
_ConstraintType = Annotated[
    CheckConstraint | ForeignKeyConstraint | UniqueConstraint, Field(discriminator="type")
]
"""Type alias for a constraint type."""

727 

728 

class Index(BaseObject):
    """Table index model.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """Columns in the index."""

    expressions: list[str] | None = None
    """Expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: Any) -> Any:
        """Check that columns or expressions are specified, but not both.

        Parameters
        ----------
        values
            Raw input for the index. Anything other than a `dict` is
            returned untouched and left to the regular field validation.

        Returns
        -------
        `Any`
            The unmodified input.

        Raises
        ------
        ValueError
            Raised if both columns and expressions are specified, or if neither
            are specified.

        Notes
        -----
        A key whose value is `None` counts as not specified. This keeps the
        output of ``model_dump()``, which contains both keys with one of
        them set to `None`, loadable again, and rejects input such as
        ``{"columns": None}`` that defines nothing.
        """
        if not isinstance(values, dict):
            return values
        has_columns = values.get("columns") is not None
        has_expressions = values.get("expressions") is not None
        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not has_columns and not has_expressions:
            raise ValueError("Must define columns or expressions")
        return values

767 

768 

# A column reference is represented as a plain string.
ColumnRef: TypeAlias = str
"""Type alias for a column reference."""

771 

772 

class ColumnGroup(BaseObject):
    """Model of a named group of columns belonging to one table."""

    columns: list[ColumnRef | Column] = Field(..., min_length=1)
    """Columns in the group."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column."""

    table: Table | None = Field(None, exclude=True)
    """Reference to the parent table."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Validate the group's IVOA UCD.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The value returned by `validate_ivoa_ucd` for ``ivoa_ucd``.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_unique_columns(self) -> ColumnGroup:
        """Ensure no column appears in the group more than once.

        Entries are compared by ID: a string entry is its own ID, a
        `Column` entry contributes its ``id``.

        Returns
        -------
        `ColumnGroup`
            The column group being validated.

        Raises
        ------
        ValueError
            Raised if two entries share the same ID.
        """
        seen_ids: set[str] = set()
        for entry in self.columns:
            entry_id = entry if isinstance(entry, str) else entry.id
            if entry_id in seen_ids:
                raise ValueError("Columns in the group must be unique")
            seen_ids.add(entry_id)
        return self

    def _dereference_columns(self) -> None:
        """Replace string column references with their `Column` objects.

        Raises
        ------
        ValueError
            Raised if the group has no parent table, or if a referenced
            column ID is not found in that table.
        """
        parent = self.table
        if parent is None:
            raise ValueError("ColumnGroup must have a reference to its parent table")

        resolved: list[ColumnRef | Column] = []
        for entry in self.columns:
            if not isinstance(entry, str):
                # Already a Column object; keep it as is.
                resolved.append(entry)
                continue
            try:
                found = parent._find_column_by_id(entry)
            except KeyError as e:
                raise ValueError(f"Column '{entry}' not found in table '{parent.name}'") from e
            resolved.append(found)

        self.columns = resolved

    @field_serializer("columns")
    def serialize_columns(self, columns: list[ColumnRef | Column]) -> list[str]:
        """Serialize the group's columns as a list of column IDs.

        Parameters
        ----------
        columns
            The columns to serialize.

        Returns
        -------
        `list` [ `str` ]
            The serialized column IDs.
        """
        column_ids: list[str] = []
        for entry in columns:
            column_ids.append(entry if isinstance(entry, str) else entry.id)
        return column_ids

850 

851 

class ColumnOverrides(BaseModel):
    """Allowed overrides for a referenced column.

    Notes
    -----
    All of these fields are optional. Values of None may be explicitly set to
    override the corresponding attribute in the referenced column but only
    for certain fields (see validation in `_check_non_nullable_overrides`).
    """

    model_config = CONFIG.copy()

    datatype: DataType | None = None
    """New datatype for the column."""

    length: int | None = None
    """New length for the column."""

    description: str | None = None
    """New description for the column."""

    nullable: bool | None = None
    """New nullable flag for the column."""

    tap_principal: int | None = Field(default=None, alias="tap:principal")
    """Override for the TAP_SCHEMA 'principal' flag."""

    tap_column_index: int | None = Field(default=None, alias="tap:column_index")
    """Override for the TAP_SCHEMA column index."""

    @model_validator(mode="before")
    @classmethod
    def _check_non_nullable_overrides(cls, data: Any) -> Any:
        """Check that certain fields are not overridden to null.

        Parameters
        ----------
        data
            Raw input for the overrides. Anything other than a `dict` is
            returned untouched.

        Returns
        -------
        `Any`
            The unmodified input.

        Raises
        ------
        ValueError
            Raised if ``datatype``, ``length``, ``nullable`` or
            ``tap_principal`` is explicitly set to `None`, whether it was
            given under its field name or under its alias.
        """
        if not isinstance(data, dict):
            return data
        non_nullable_fields = ("datatype", "length", "nullable", "tap_principal")
        for name in non_nullable_fields:
            # This validator sees the raw input, where a field may be keyed
            # by its alias (e.g. 'tap:principal') instead of its name.
            alias = cls.model_fields[name].alias
            for key in (name, alias):
                if key is not None and key in data and data[key] is None:
                    raise ValueError(f"The '{key}' field cannot be overridden to null")
        return data

    @field_serializer("datatype")
    def serialize_datatype(self, value: DataType | None) -> str | None:
        """Convert `DataType` to string when serializing to JSON/YAML.

        Parameters
        ----------
        value
            The `DataType` value to serialize, or None.

        Returns
        -------
        `str` | None
            The serialized `DataType` value, or None if the input was None.
        """
        if value is None:
            return None
        return str(value)

    @field_validator("datatype", mode="before")
    @classmethod
    def deserialize_datatype(cls, value: str | None) -> DataType | None:
        """Convert string back into `DataType` when loading from JSON/YAML.

        Parameters
        ----------
        value
            The string value to deserialize, or None.

        Returns
        -------
        `DataType` | None
            The deserialized `DataType` value, or None if the input was None.
        """
        if value is None:
            return None
        return DataType(value)

930 

931 

class ColumnResourceRef(BaseModel):
    """A column which is derived from an external resource."""

    ref_name: str | None = None
    """Name of the referenced column in the resource
    (if different from the key)."""

    overrides: ColumnOverrides | None = None
    """Optional overrides of the referenced column's attributes."""

941 

942 

# Type aliases for the nested mapping structure of referenced columns.
# The three levels nest as ResourceMap -> ResourceTableMap -> ResourceColumnMap;
# the innermost value is a ColumnResourceRef or None.
# NOTE(review): the keys are presumably resource, table and column names,
# in that order -- confirm against the code that consumes ``columnRefs``.
ResourceColumnMap: TypeAlias = dict[str, ColumnResourceRef | None]
ResourceTableMap: TypeAlias = dict[str, ResourceColumnMap]
ResourceMap: TypeAlias = dict[str, ResourceTableMap]

947 

948 

949class Table(BaseObject): 

950 """Table model.""" 

951 

952 primary_key: str | list[str] | None = Field(None, alias="primaryKey") 

953 """Primary key of the table.""" 

954 

955 tap_table_index: int | None = Field(None, alias="tap:table_index") 

956 """IVOA TAP_SCHEMA table index of the table.""" 

957 

958 mysql_engine: str | None = Field("MyISAM", alias="mysql:engine") 

959 """MySQL engine to use for the table.""" 

960 

961 mysql_charset: str | None = Field(None, alias="mysql:charset") 

962 """MySQL charset to use for the table.""" 

963 

964 column_refs: ResourceMap = Field(default_factory=dict, alias="columnRefs") 

965 """Referenced columns from external resources.""" 

966 

967 columns: list[Column] = Field(default_factory=list) 

968 """Columns in the table.""" 

969 

970 column_groups: list[ColumnGroup] = Field(default_factory=list, alias="columnGroups") 

971 """Column groups in the table.""" 

972 

973 constraints: list[_ConstraintType] = Field(default_factory=list) 

974 """Constraints on the table.""" 

975 

976 indexes: list[Index] = Field(default_factory=list) 

977 """Indexes on the table.""" 

978 

@field_validator("columns", mode="after")
@classmethod
def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
    """Reject a column list in which two columns share a name.

    Parameters
    ----------
    columns
        The columns to check.

    Returns
    -------
    `list` [ `Column` ]
        The same list, if every name occurs only once.

    Raises
    ------
    ValueError
        Raised if column names are not unique.
    """
    distinct_names = {column.name for column in columns}
    # Fewer distinct names than columns means at least one duplicate.
    if len(distinct_names) != len(columns):
        raise ValueError("Column names must be unique")
    return columns

1002 

@model_validator(mode="after")
def check_tap_table_index(self, info: ValidationInfo) -> Table:
    """Require a TAP table index when the validation context asks for it.

    Parameters
    ----------
    info
        Validation info whose ``context`` mapping may contain the
        ``check_tap_table_indexes`` flag that switches this check on.

    Returns
    -------
    `Table`
        The table being validated.

    Raises
    ------
    ValueError
        Raised if the check is enabled and the table is missing a TAP
        table index.
    """
    # A missing or empty context means the check is switched off.
    context = info.context or {}
    if context.get("check_tap_table_indexes", False) and self.tap_table_index is None:
        raise ValueError("Table is missing a TAP table index")
    return self

1028 

1029 @model_validator(mode="after") 

1030 def check_tap_principal(self, info: ValidationInfo) -> Table: 

1031 """Check that at least one column is flagged as 'principal' for TAP 

1032 purposes. 

1033 

1034 Parameters 

1035 ---------- 

1036 info 

1037 Validation context used to determine if the check is enabled. 

1038 

1039 Returns 

1040 ------- 

1041 `Table` 

1042 The table being validated. 

1043 

1044 Raises 

1045 ------ 

1046 ValueError 

1047 Raised if the table is missing a column flagged as 'principal'. 

1048 """ 

1049 context = info.context 

1050 if not context or not context.get("check_tap_principal", False): 

1051 return self 

1052 for col in self.columns: 

1053 if col.tap_principal == 1: 

1054 return self 

1055 raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'") 

1056 

1057 def _find_column_by_id(self, id: str) -> Column: 

1058 """Find a column by ID. 

1059 

1060 Parameters 

1061 ---------- 

1062 id 

1063 The ID of the column to find. 

1064 

1065 Returns 

1066 ------- 

1067 `Column` 

1068 The column with the given ID. 

1069 

1070 Raises 

1071 ------ 

1072 ValueError 

1073 Raised if the column is not found. 

1074 """ 

1075 for column in self.columns: 

1076 if column.id == id: 

1077 return column 

1078 raise KeyError(f"Column '{id}' not found in table '{self.name}'") 

1079 

1080 def _find_column_by_name(self, name: str) -> Column: 

1081 for column in self.columns: 

1082 if column.name == name: 

1083 return column 

1084 raise KeyError(f"Column '{name}' not found in table '{self.name}'") 

1085 

1086 @model_validator(mode="after") 

1087 def dereference_column_groups(self: Table) -> Table: 

1088 """Dereference columns in column groups. 

1089 

1090 Returns 

1091 ------- 

1092 `Table` 

1093 The table with dereferenced column groups. 

1094 """ 

1095 for group in self.column_groups: 

1096 group.table = self 

1097 group._dereference_columns() 

1098 return self 

1099 

1100 @field_serializer("columns") 

1101 def _serialize_columns(self, columns: list[Column]) -> list[dict[str, Any]]: 

1102 """Serialize only non-resource columns.""" 

1103 return [ 

1104 col.model_dump( 

1105 by_alias=True, 

1106 exclude_none=True, 

1107 exclude_defaults=True, 

1108 ) 

1109 for col in columns 

1110 if not col._is_resource_ref 

1111 ] 

1112 

1113 

class SchemaVersion(BaseModel):
    """Schema version model.

    Holds the current version string of a schema together with two
    optional lists of other version strings, both empty by default.
    """

    current: str
    """The current version of the schema (required)."""

    compatible: list[str] = Field(default_factory=list)
    """The compatible versions of the schema."""

    read_compatible: list[str] = Field(default_factory=list)
    """The read compatible versions of the schema."""

1125 

1126 

class SchemaIdVisitor:
    """Walk a schema and register its objects in the schema's ID map.

    Notes
    -----
    An ID that is seen a second time is recorded in the ``duplicates`` set
    instead of raising an error, and the object that was registered first
    stays in the map. This is acceptable because the schema's
    ``model_validator`` is expected to raise a ``ValidationError`` whenever
    duplicates were collected.
    """

    def __init__(self) -> None:
        """Create a new visitor with no schema and no duplicates."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Register an object in the ID map of the current schema.

        Objects without an ``id`` attribute are ignored, and nothing is
        registered while no schema has been set.

        Parameters
        ----------
        obj
            The object to add to the ID map.
        """
        if not hasattr(obj, "id"):
            return
        obj_id = obj.id
        if self.schema is None:
            return
        id_map = self.schema._id_map
        if obj_id in id_map:
            # Keep the first registration and only remember the clash.
            self.duplicates.add(obj_id)
        else:
            id_map[obj_id] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Visit a schema and all of its tables to build the ID map.

        Parameters
        ----------
        schema
            The schema object to visit.

        Notes
        -----
        The schema is stored on the visitor and previously collected
        duplicates are discarded before the walk starts.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(schema)
        for table in schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table, then its columns, then its constraints.

        Parameters
        ----------
        table
            The table object to visit.
        """
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)

    def visit_column(self, column: Column) -> None:
        """Visit a column object.

        Parameters
        ----------
        column
            The column object to visit.
        """
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object.

        Parameters
        ----------
        constraint
            The constraint object to visit.
        """
        self.add(constraint)

1212 

1213 

# Type variable constrained to `BaseObject` and its subclasses.
T = TypeVar("T", bound=BaseObject)

1215 

1216 

1217def _strip_ids(data: Any) -> Any: 

1218 """Recursively strip '@id' fields from a dictionary or list. 

1219 

1220 Parameters 

1221 ---------- 

1222 data 

1223 The data to strip IDs from, which can be a dictionary, list, or any 

1224 other type. Other types will be returned unchanged. 

1225 """ 

1226 if isinstance(data, dict): 

1227 data.pop("@id", None) 

1228 for k, v in data.items(): 

1229 data[k] = _strip_ids(v) 

1230 return data 

1231 elif isinstance(data, list): 

1232 return [_strip_ids(item) for item in data] 

1233 else: 

1234 return data 

1235 

1236 

1237def _append_error( 

1238 errors: list[InitErrorDetails], 

1239 loc: tuple, 

1240 input_value: Any, 

1241 error_message: str, 

1242 error_type: str = "value_error", 

1243) -> None: 

1244 """Append an error to the errors list. 

1245 

1246 Parameters 

1247 ---------- 

1248 errors : list[InitErrorDetails] 

1249 The list of errors to append to. 

1250 loc : tuple 

1251 The location of the error in the schema. 

1252 input_value : Any 

1253 The input value that caused the error. 

1254 error_message : str 

1255 The error message to include in the context. 

1256 """ 

1257 errors.append( 

1258 { 

1259 "type": error_type, 

1260 "loc": loc, 

1261 "input": input_value, 

1262 "ctx": {"error": error_message}, 

1263 } 

1264 ) 

1265 

1266 

class Resource(BaseModel):
    """A resource definition referencing an external schema.

    The only field is the required ``uri`` string.
    """

    uri: str = Field(..., description="Resource URI or path")
    """URI of the schema resource which may be a local path, ``resource://``,
    or remote URL."""

1273 

1274 

1275class Schema(BaseObject, Generic[T]): 

1276 """Database schema model. 

1277 

1278 This represents a database schema, which contains one or more tables. 

1279 """ 

1280 

1281 version: SchemaVersion | str | None = None 

1282 """The version of the schema.""" 

1283 

1284 resources: dict[str, Resource] = Field(default_factory=dict) 

1285 """External resources referenced by this schema.""" 

1286 

1287 tables: Sequence[Table] 

1288 """The tables in the schema.""" 

1289 

1290 _id_map: dict[str, Any] = PrivateAttr(default_factory=dict) 

1291 """Map of IDs to objects.""" 

1292 

1293 _resource_map: dict[str, Schema] = PrivateAttr(default_factory=dict) 

1294 """Map of resource names to loaded schemas.""" 

1295 

1296 @model_validator(mode="after") 

1297 def _load_resources(self: Schema, info: ValidationInfo) -> Schema: 

1298 """Load external resources referenced by this schema into an internal 

1299 mapping of resource names to their `Schema` objects. 

1300 

1301 Returns 

1302 ------- 

1303 `Schema` 

1304 The schema being validated. 

1305 

1306 Raises 

1307 ------ 

1308 ValueError 

1309 Raised if a resource cannot be loaded. 

1310 """ 

1311 if info.context: 

1312 context = info.context.copy() 

1313 # Ignore this flag for loading the resources themselves 

1314 context.pop("dereference_resources", None) 

1315 else: 

1316 context = {} 

1317 

1318 # Get the base URI for resolving relative resource paths from the 

1319 # validation context, if available. 

1320 resource_path = context.pop("resource_path", None) 

1321 base_uri = None 

1322 if resource_path is not None: 

1323 base_uri = resource_path.parent() 

1324 

1325 for resource_name, resource in self.resources.items(): 

1326 uri = resource.uri 

1327 

1328 # Apply the base URI to the resource URI, if available. 

1329 if base_uri is not None: 

1330 orig_uri = uri 

1331 uri = base_uri.join(uri, forceDirectory=False) 

1332 if uri != orig_uri: 

1333 logger.info( 

1334 "Resolved relative URI '%s' for resource '%s' to '%s' using base URI '%s'", 

1335 resource.uri, 

1336 resource_name, 

1337 uri, 

1338 base_uri, 

1339 ) 

1340 

1341 try: 

1342 loaded_schema = Schema.from_uri(uri, context=context) 

1343 self._resource_map[resource_name] = loaded_schema 

1344 logger.debug("Loaded resource '%s' from URI '%s'", resource_name, uri) 

1345 except Exception as e: 

1346 raise ValueError(f"Failed to load resource '{resource_name}' from URI '{uri}': {e}") from e 

1347 return self 

1348 

1349 def _find_table_by_name(self, name: str) -> Table: 

1350 """Find a table by name. 

1351 

1352 Parameters 

1353 ---------- 

1354 name 

1355 The name of the table to find. 

1356 

1357 Returns 

1358 ------- 

1359 `Table` 

1360 The table with the given name. 

1361 

1362 Raises 

1363 ------ 

1364 KeyError 

1365 Raised if the table is not found. 

1366 """ 

1367 for table in self.tables: 

1368 if table.name == name: 

1369 return table 

1370 raise KeyError(f"Table '{name}' not found in schema '{self.name}'") 

1371 

1372 @model_validator(mode="after") 

1373 def _dereference_resource_columns(self: Schema, info: ValidationInfo) -> Schema: 

1374 """Dereference columns from external resources and add them to the 

1375 tables in this schema. 

1376 """ 

1377 context = info.context 

1378 column_ref_index_increment: int | None = None 

1379 dereference_resources = False 

1380 if context is not None: 

1381 dereference_resources = context.get("dereference_resources", False) 

1382 column_ref_index_increment = context.get("column_ref_index_increment", None) 

1383 

1384 for table in self.tables: 

1385 if column_refs := table.column_refs: 

1386 for resource_name, tables in column_refs.items(): 

1387 resource_schema = self._resource_map.get(resource_name) 

1388 if resource_schema is None: 

1389 raise ValueError(f"Schema resource '{resource_name}' was not found in resources") 

1390 self._process_column_refs( 

1391 table, 

1392 tables, 

1393 resource_schema, 

1394 dereference_resources, 

1395 column_ref_index_increment, 

1396 ) 

1397 if dereference_resources and len(table.column_refs) > 0: 

1398 # Clear column refs in table if fully dereferencing 

1399 logger.debug( 

1400 "Clearing columnRefs in table '%s' after dereferencing resource columns", 

1401 table.name, 

1402 ) 

1403 table.column_refs = {} 

1404 return self 

1405 

    @classmethod
    def _process_column_refs(
        cls,
        table: Table,
        ref_tables: ResourceTableMap,
        resource_schema: Schema,
        dereference_resources: bool = False,
        column_ref_index_increment: int | None = None,
    ) -> None:
        """Process column references from an external resource and add them
        to the given table as columns.

        Parameters
        ----------
        table
            The local table that receives the copied columns; its
            ``columns`` list is appended to in place.
        ref_tables
            Mapping of resource table names to mappings of local column
            names to column references. A column reference may be `None`.
        resource_schema
            The loaded schema in which the referenced tables and columns
            are looked up.
        dereference_resources
            If `False`, each copied column is flagged as a resource
            reference.
        column_ref_index_increment
            If not `None`, ``tap_column_index`` is assigned automatically
            to copied columns that do not override it. The first assigned
            value equals the increment, and the counter grows by the
            increment after each assignment.

        Raises
        ------
        ValueError
            Raised if a referenced table or column is not found in the
            resource schema.
        """
        # The counter is only read when an increment was given, so the -1
        # fallback is never assigned to a column.
        current_column_index = column_ref_index_increment if column_ref_index_increment is not None else -1

        for table_name, columns in ref_tables.items():
            try:
                resource_table = resource_schema._find_table_by_name(table_name)
            except KeyError as e:
                raise ValueError(
                    f"Table '{table_name}' not found in resource '{resource_schema.name}'"
                ) from e
            for local_column_name, column_ref in columns.items():
                if column_ref is not None and column_ref.ref_name is not None:
                    # Use specified ref_name
                    ref_column_name = column_ref.ref_name
                else:
                    # Use the local column name if no ref_name
                    # specified
                    ref_column_name = local_column_name

                # Check if referenced column exists in resource
                try:
                    base_column = resource_table._find_column_by_name(ref_column_name)
                except KeyError:
                    # The ref_name is specified but column is not
                    # found
                    if column_ref is not None and column_ref.ref_name is not None:
                        raise ValueError(
                            f"Column '{ref_column_name}' not found in table '{table_name}' "
                            f"from resource '{resource_schema.name}'"
                        )
                    # The ref_name is not specified and the local
                    # column name is not found
                    raise ValueError(
                        f"Column '{local_column_name}' not found in table '{table_name}' "
                        f"from resource '{resource_schema.name}' and no ref_name provided"
                    )

                # Create a copy of the base column
                column_copy = base_column.model_copy()

                # Set the local name (key from the mapping)
                column_copy.name = local_column_name

                if not dereference_resources:
                    # Flag the column as a resource reference so it will not be
                    # written out during serialization
                    column_copy._is_resource_ref = True

                # Apply overrides to the referenced column definition
                overrides = column_ref.overrides if column_ref is not None else None
                if overrides is not None:
                    column_copy._update_from_overrides(overrides)

                # Manually set the ID of the copied column as ID generation has
                # already occurred by now
                column_copy.id = f"{table.id}.{local_column_name}"

                # Apply automatic assignment of 'tap:column_index', if enabled
                if column_ref_index_increment is not None:
                    if (not overrides) or (overrides.tap_column_index is None):
                        column_copy.tap_column_index = current_column_index
                        current_column_index += column_ref_index_increment
                        logger.debug(
                            "Automatically assigned 'tap:column_index' %s to "
                            "column '%s' in table '%s' from resource '%s'",
                            column_copy.tap_column_index,
                            local_column_name,
                            table_name,
                            resource_schema.name,
                        )
                    else:
                        # Skip automatic assignment of 'tap:column_index' if it
                        # is already overridden
                        logger.debug(
                            "Skipping automatic assignment of 'tap:column_index' for column "
                            "'%s' in table '%s' from resource '%s' as it is already overridden to %s",
                            local_column_name,
                            table_name,
                            resource_schema.name,
                            column_copy.tap_column_index,
                        )
                table.columns.append(column_copy)
                logger.debug(
                    "Dereferenced column '%s' from table '%s' in resource '%s' into table '%s'",
                    local_column_name,
                    table_name,
                    resource_schema.name,
                    table.name,
                )

1506 

1507 @model_validator(mode="before") 

1508 @classmethod 

1509 def generate_ids(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]: 

1510 """Generate IDs for objects that do not have them. 

1511 

1512 Parameters 

1513 ---------- 

1514 values 

1515 The values of the schema. 

1516 info 

1517 Validation context used to determine if ID generation is enabled. 

1518 

1519 Returns 

1520 ------- 

1521 `dict` [ `str`, `Any` ] 

1522 The values of the schema with generated IDs. 

1523 """ 

1524 context = info.context 

1525 if not context or not context.get("id_generation", False): 

1526 logger.debug("Skipping ID generation") 

1527 return values 

1528 schema_name = values["name"] 

1529 if "@id" not in values: 

1530 values["@id"] = f"#{schema_name}" 

1531 logger.debug("Generated ID '%s' for schema '%s'", values["@id"], schema_name) 

1532 if "tables" in values: 

1533 for table in values["tables"]: 

1534 if "@id" not in table: 

1535 table["@id"] = f"#{table['name']}" 

1536 logger.debug("Generated ID '%s' for table '%s'", table["@id"], table["name"]) 

1537 if "columns" in table: 

1538 for column in table["columns"]: 

1539 if "@id" not in column: 

1540 column["@id"] = f"#{table['name']}.{column['name']}" 

1541 logger.debug("Generated ID '%s' for column '%s'", column["@id"], column["name"]) 

1542 if "columnGroups" in table: 

1543 for column_group in table["columnGroups"]: 

1544 if "@id" not in column_group: 

1545 column_group["@id"] = f"#{table['name']}.{column_group['name']}" 

1546 logger.debug( 

1547 "Generated ID '%s' for column group '%s'", 

1548 column_group["@id"], 

1549 column_group["name"], 

1550 ) 

1551 if "constraints" in table: 

1552 for constraint in table["constraints"]: 

1553 if "@id" not in constraint: 

1554 constraint["@id"] = f"#{constraint['name']}" 

1555 logger.debug( 

1556 "Generated ID '%s' for constraint '%s'", 

1557 constraint["@id"], 

1558 constraint["name"], 

1559 ) 

1560 if "indexes" in table: 

1561 for index in table["indexes"]: 

1562 if "@id" not in index: 

1563 index["@id"] = f"#{index['name']}" 

1564 logger.debug("Generated ID '%s' for index '%s'", index["@id"], index["name"]) 

1565 return values 

1566 

1567 @field_validator("tables", mode="after") 

1568 @classmethod 

1569 def check_unique_table_names(cls, tables: list[Table]) -> list[Table]: 

1570 """Check that table names are unique. 

1571 

1572 Parameters 

1573 ---------- 

1574 tables 

1575 The tables to check. 

1576 

1577 Returns 

1578 ------- 

1579 `list` [ `Table` ] 

1580 The tables if they are unique. 

1581 

1582 Raises 

1583 ------ 

1584 ValueError 

1585 Raised if table names are not unique. 

1586 """ 

1587 if len(tables) != len(set(table.name for table in tables)): 

1588 raise ValueError("Table names must be unique") 

1589 return tables 

1590 

1591 @model_validator(mode="after") 

1592 def check_tap_table_indexes(self, info: ValidationInfo) -> Schema: 

1593 """Check that the TAP table indexes are unique. 

1594 

1595 Parameters 

1596 ---------- 

1597 info 

1598 The validation context used to determine if the check is enabled. 

1599 

1600 Returns 

1601 ------- 

1602 `Schema` 

1603 The schema being validated. 

1604 """ 

1605 context = info.context 

1606 if not context or not context.get("check_tap_table_indexes", False): 

1607 return self 

1608 table_indicies = set() 

1609 for table in self.tables: 

1610 table_index = table.tap_table_index 

1611 if table_index is not None: 

1612 if table_index in table_indicies: 

1613 raise ValueError(f"Duplicate 'tap:table_index' value {table_index} found in schema") 

1614 table_indicies.add(table_index) 

1615 return self 

1616 

1617 @model_validator(mode="after") 

1618 def check_unique_constraint_names(self: Schema) -> Schema: 

1619 """Check for duplicate constraint names in the schema. 

1620 

1621 Returns 

1622 ------- 

1623 `Schema` 

1624 The schema being validated. 

1625 

1626 Raises 

1627 ------ 

1628 ValueError 

1629 Raised if duplicate constraint names are found in the schema. 

1630 """ 

1631 constraint_names = set() 

1632 duplicate_names = [] 

1633 

1634 for table in self.tables: 

1635 for constraint in table.constraints: 

1636 constraint_name = constraint.name 

1637 if constraint_name in constraint_names: 

1638 duplicate_names.append(constraint_name) 

1639 else: 

1640 constraint_names.add(constraint_name) 

1641 

1642 if duplicate_names: 

1643 raise ValueError(f"Duplicate constraint names found in schema: {duplicate_names}") 

1644 

1645 return self 

1646 

1647 @model_validator(mode="after") 

1648 def check_unique_index_names(self: Schema) -> Schema: 

1649 """Check for duplicate index names in the schema. 

1650 

1651 Returns 

1652 ------- 

1653 `Schema` 

1654 The schema being validated. 

1655 

1656 Raises 

1657 ------ 

1658 ValueError 

1659 Raised if duplicate index names are found in the schema. 

1660 """ 

1661 index_names = set() 

1662 duplicate_names = [] 

1663 

1664 for table in self.tables: 

1665 for index in table.indexes: 

1666 index_name = index.name 

1667 if index_name in index_names: 

1668 duplicate_names.append(index_name) 

1669 else: 

1670 index_names.add(index_name) 

1671 

1672 if duplicate_names: 

1673 raise ValueError(f"Duplicate index names found in schema: {duplicate_names}") 

1674 

1675 return self 

1676 

1677 @model_validator(mode="after") 

1678 def create_id_map(self: Schema) -> Schema: 

1679 """Create a map of IDs to objects. 

1680 

1681 Returns 

1682 ------- 

1683 `Schema` 

1684 The schema with the ID map created. 

1685 

1686 Raises 

1687 ------ 

1688 ValueError 

1689 Raised if duplicate identifiers are found in the schema. 

1690 """ 

1691 if self._id_map: 

1692 logger.debug("Ignoring call to create_id_map() - ID map was already populated") 

1693 return self 

1694 visitor: SchemaIdVisitor = SchemaIdVisitor() 

1695 visitor.visit_schema(self) 

1696 if len(visitor.duplicates): 

1697 raise ValueError( 

1698 "Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n" 

1699 ) 

1700 logger.debug("Created ID map with %d entries", len(self._id_map)) 

1701 return self 

1702 

1703 def _validate_column_id( 

1704 self: Schema, 

1705 column_id: str, 

1706 loc: tuple, 

1707 errors: list[InitErrorDetails], 

1708 ) -> None: 

1709 """Validate a column ID from a constraint and append errors if invalid. 

1710 

1711 Parameters 

1712 ---------- 

1713 schema : Schema 

1714 The schema being validated. 

1715 column_id : str 

1716 The column ID to validate. 

1717 loc : tuple 

1718 The location of the error in the schema. 

1719 errors : list[InitErrorDetails] 

1720 The list of errors to append to. 

1721 """ 

1722 if column_id not in self: 

1723 _append_error( 

1724 errors, 

1725 loc, 

1726 column_id, 

1727 f"Column ID '{column_id}' not found in schema", 

1728 ) 

1729 elif not isinstance(self[column_id], Column): 

1730 _append_error( 

1731 errors, 

1732 loc, 

1733 column_id, 

1734 f"ID '{column_id}' does not refer to a Column object", 

1735 ) 

1736 

1737 def _validate_foreign_key_column( 

1738 self: Schema, 

1739 column_id: str, 

1740 table: Table, 

1741 loc: tuple, 

1742 errors: list[InitErrorDetails], 

1743 ) -> None: 

1744 """Validate a foreign key column ID from a constraint and append errors 

1745 if invalid. 

1746 

1747 Parameters 

1748 ---------- 

1749 schema : Schema 

1750 The schema being validated. 

1751 column_id : str 

1752 The foreign key column ID to validate. 

1753 loc : tuple 

1754 The location of the error in the schema. 

1755 errors : list[InitErrorDetails] 

1756 The list of errors to append to. 

1757 """ 

1758 try: 

1759 table._find_column_by_id(column_id) 

1760 except KeyError: 

1761 _append_error( 

1762 errors, 

1763 loc, 

1764 column_id, 

1765 f"Column '{column_id}' not found in table '{table.name}'", 

1766 ) 

1767 

1768 @model_validator(mode="after") 

1769 def check_constraints(self: Schema) -> Schema: 

1770 """Check constraint objects for validity. This needs to be deferred 

1771 until after the schema is fully loaded and the ID map is created. 

1772 

1773 Raises 

1774 ------ 

1775 pydantic.ValidationError 

1776 Raised if any constraints are invalid. 

1777 

1778 Returns 

1779 ------- 

1780 `Schema` 

1781 The schema being validated. 

1782 """ 

1783 errors: list[InitErrorDetails] = [] 

1784 

1785 for table_index, table in enumerate(self.tables): 

1786 for constraint_index, constraint in enumerate(table.constraints): 

1787 column_ids: list[str] = [] 

1788 referenced_column_ids: list[str] = [] 

1789 

1790 if isinstance(constraint, ForeignKeyConstraint): 

1791 column_ids += constraint.columns 

1792 referenced_column_ids += constraint.referenced_columns 

1793 elif isinstance(constraint, UniqueConstraint): 

1794 column_ids += constraint.columns 

1795 # No extra checks are required on CheckConstraint objects. 

1796 

1797 # Validate the foreign key columns 

1798 for column_id in column_ids: 

1799 self._validate_column_id( 

1800 column_id, 

1801 ( 

1802 "tables", 

1803 table_index, 

1804 "constraints", 

1805 constraint_index, 

1806 "columns", 

1807 column_id, 

1808 ), 

1809 errors, 

1810 ) 

1811 # Check that the foreign key column is within the source 

1812 # table. 

1813 self._validate_foreign_key_column( 

1814 column_id, 

1815 table, 

1816 ( 

1817 "tables", 

1818 table_index, 

1819 "constraints", 

1820 constraint_index, 

1821 "columns", 

1822 column_id, 

1823 ), 

1824 errors, 

1825 ) 

1826 

1827 # Validate the primary key (reference) columns 

1828 for referenced_column_id in referenced_column_ids: 

1829 self._validate_column_id( 

1830 referenced_column_id, 

1831 ( 

1832 "tables", 

1833 table_index, 

1834 "constraints", 

1835 constraint_index, 

1836 "referenced_columns", 

1837 referenced_column_id, 

1838 ), 

1839 errors, 

1840 ) 

1841 

1842 if errors: 

1843 raise ValidationError.from_exception_data("Schema validation failed", errors) 

1844 

1845 return self 

1846 

1847 def __getitem__(self, id: str) -> BaseObject: 

1848 """Get an object by its ID. 

1849 

1850 Parameters 

1851 ---------- 

1852 id 

1853 The ID of the object to get. 

1854 

1855 Raises 

1856 ------ 

1857 KeyError 

1858 Raised if the object with the given ID is not found in the schema. 

1859 """ 

1860 if id not in self: 

1861 raise KeyError(f"Object with ID '{id}' not found in schema") 

1862 return self._id_map[id] 

1863 

1864 def __contains__(self, id: str) -> bool: 

1865 """Check if an object with the given ID is in the schema. 

1866 

1867 Parameters 

1868 ---------- 

1869 id 

1870 The ID of the object to check. 

1871 """ 

1872 return id in self._id_map 

1873 

1874 def find_object_by_id(self, id: str, obj_type: type[T]) -> T: 

1875 """Find an object with the given type by its ID. 

1876 

1877 Parameters 

1878 ---------- 

1879 id 

1880 The ID of the object to find. 

1881 obj_type 

1882 The type of the object to find. 

1883 

1884 Returns 

1885 ------- 

1886 BaseObject 

1887 The object with the given ID and type. 

1888 

1889 Raises 

1890 ------ 

1891 KeyError 

1892 If the object with the given ID is not found in the schema. 

1893 TypeError 

1894 If the object that is found does not have the right type. 

1895 

1896 Notes 

1897 ----- 

1898 The actual return type is the user-specified argument ``T``, which is 

1899 expected to be a subclass of `BaseObject`. 

1900 """ 

1901 obj = self[id] 

1902 if not isinstance(obj, obj_type): 

1903 raise TypeError(f"Object with ID '{id}' is not of type '{obj_type.__name__}'") 

1904 return obj 

1905 

1906 def get_table_by_column(self, column: Column) -> Table: 

1907 """Find the table that contains a column. 

1908 

1909 Parameters 

1910 ---------- 

1911 column 

1912 The column to find. 

1913 

1914 Returns 

1915 ------- 

1916 `Table` 

1917 The table that contains the column. 

1918 

1919 Raises 

1920 ------ 

1921 ValueError 

1922 If the column is not found in any table. 

1923 """ 

1924 for table in self.tables: 

1925 if column in table.columns: 

1926 return table 

1927 raise ValueError(f"Column '{column.name}' not found in any table") 

1928 

1929 @classmethod 

1930 def from_uri(cls, resource_path: ResourcePathExpression, context: dict[str, Any] = {}) -> Schema: 

1931 """Load a `Schema` from a string representing a ``ResourcePath``. 

1932 

1933 Parameters 

1934 ---------- 

1935 resource_path 

1936 The ``ResourcePath`` pointing to a YAML file. 

1937 context 

1938 Pydantic context to be used in validation. 

1939 

1940 Returns 

1941 ------- 

1942 `str` 

1943 The ID of the object. 

1944 

1945 Raises 

1946 ------ 

1947 yaml.YAMLError 

1948 Raised if there is an error loading the YAML data. 

1949 ValueError 

1950 Raised if there is an error reading the resource. 

1951 pydantic.ValidationError 

1952 Raised if the schema fails validation. 

1953 """ 

1954 try: 

1955 rp = ResourcePath(resource_path, forceAbsolute=False, forceDirectory=False) 

1956 rp_data = rp.read() 

1957 except Exception as e: 

1958 raise ValueError(f"Error reading resource from '{resource_path}' : {e}") from e 

1959 yaml_data = yaml.safe_load(rp_data) 

1960 context = dict(context) 

1961 # Append the resource path to the context for resolving resource URLs. 

1962 context["resource_path"] = rp 

1963 return Schema.model_validate(yaml_data, context=context) 

1964 

1965 @classmethod 

1966 def from_stream(cls, source: IO[str], context: dict[str, Any] = {}) -> Schema: 

1967 """Load a `Schema` from a file stream which should contain YAML data. 

1968 

1969 Parameters 

1970 ---------- 

1971 source 

1972 The file stream to read from. 

1973 context 

1974 Pydantic context to be used in validation. 

1975 

1976 Returns 

1977 ------- 

1978 `Schema` 

1979 The Felis schema loaded from the stream. 

1980 

1981 Raises 

1982 ------ 

1983 yaml.YAMLError 

1984 Raised if there is an error loading the YAML file. 

1985 pydantic.ValidationError 

1986 Raised if the schema fails validation. 

1987 """ 

1988 logger.debug("Loading schema from: '%s'", source) 

1989 yaml_data = yaml.safe_load(source) 

1990 return Schema.model_validate(yaml_data, context=context) 

1991 

1992 def _model_dump(self, strip_ids: bool = False, sort_columns: bool = False) -> dict[str, Any]: 

1993 """Dump the schema as a dictionary with some default arguments 

1994 applied. 

1995 

1996 Parameters 

1997 ---------- 

1998 strip_ids 

1999 Whether to strip the IDs from the dumped data. Defaults to `False`. 

2000 sort_columns 

2001 Whether to sort columns alphabetically by name. Defaults to 

2002 `False`. 

2003 

2004 Returns 

2005 ------- 

2006 `dict` [ `str`, `Any` ] 

2007 The dumped schema data as a dictionary. 

2008 """ 

2009 data = self.model_dump(by_alias=True, exclude_none=True, exclude_defaults=True) 

2010 if strip_ids: 

2011 data = _strip_ids(data) 

2012 if sort_columns: 

2013 for table in data.get("tables", []): 

2014 table["columns"] = sorted(table.get("columns", []), key=itemgetter("name")) 

2015 return data 

2016 

2017 def dump_yaml( 

2018 self, stream: IO[str] = sys.stdout, strip_ids: bool = False, sort_columns: bool = False 

2019 ) -> None: 

2020 """Pretty print the schema as YAML. 

2021 

2022 Parameters 

2023 ---------- 

2024 stream 

2025 The stream to write the YAML data to. 

2026 strip_ids 

2027 Whether to strip the IDs from the dumped data. Defaults to `False`. 

2028 sort_columns 

2029 Whether to sort columns alphabetically by name. Defaults to 

2030 `False`. 

2031 """ 

2032 data = self._model_dump(strip_ids=strip_ids, sort_columns=sort_columns) 

2033 yaml.safe_dump( 

2034 data, 

2035 stream, 

2036 default_flow_style=False, 

2037 sort_keys=False, 

2038 ) 

2039 

2040 def dump_json( 

2041 self, stream: IO[str] = sys.stdout, strip_ids: bool = False, sort_columns: bool = False 

2042 ) -> None: 

2043 """Pretty print the schema as JSON. 

2044 

2045 Parameters 

2046 ---------- 

2047 stream 

2048 The stream to write the JSON data to. 

2049 strip_ids 

2050 Whether to strip the IDs from the dumped data. Defaults to `False`. 

2051 sort_columns 

2052 Whether to sort columns alphabetically by name. Defaults to 

2053 `False`. 

2054 """ 

2055 data = self._model_dump(strip_ids=strip_ids, sort_columns=sort_columns) 

2056 json.dump( 

2057 data, 

2058 stream, 

2059 indent=4, 

2060 sort_keys=False, 

2061 )