Coverage for python / felis / db / database_context.py: 32%

353 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 08:23 +0000

1"""API for managing database operations across different dialects.""" 

2 

3# This file is part of felis. 

4# 

5# Developed for the LSST Data Management System. 

6# This product includes software developed by the LSST Project 

7# (https://www.lsst.org). 

8# See the COPYRIGHT file at the top-level directory of this distribution 

9# for details of code ownership. 

10# 

11# This program is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU General Public License as published by 

13# the Free Software Foundation, either version 3 of the License, or 

14# (at your option) any later version. 

15# 

16# This program is distributed in the hope that it will be useful, 

17# but WITHOUT ANY WARRANTY; without even the implied warranty of 

18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

19# GNU General Public License for more details. 

20# 

21# You should have received a copy of the GNU General Public License 

22# along with this program. If not, see <https://www.gnu.org/licenses/>. 

23 

24from __future__ import annotations 

25 

26import logging 

27from abc import abstractmethod 

28from collections.abc import Callable, Iterator 

29from contextlib import AbstractContextManager, contextmanager 

30from typing import IO, Any, Literal, TypeAlias 

31 

32from sqlalchemy import ( 

33 Engine, 

34 MetaData, 

35 create_engine, 

36 inspect, 

37 make_url, 

38 quoted_name, 

39) 

40from sqlalchemy.engine import ( 

41 Connection, 

42 Dialect, 

43 Result, 

44) 

45from sqlalchemy.engine.mock import MockConnection, create_mock_engine 

46from sqlalchemy.engine.url import URL 

47from sqlalchemy.exc import SQLAlchemyError 

48from sqlalchemy.schema import ( 

49 CreateSchema, 

50 DropSchema, 

51) 

52from sqlalchemy.sql import ( 

53 Executable, 

54 text, 

55) 

56from sqlalchemy.sql.elements import TextClause 

57 

58__all__ = [ 

59 "DatabaseContext", 

60 "DatabaseContextError", 

61 "MockContext", 

62 "MySQLContext", 

63 "PostgreSQLContext", 

64 "SQLiteContext", 

65 "create_database_context", 

66] 

67 

68logger = logging.getLogger(__name__) 

69 

70SQLStatement = str | Executable | TextClause 

71 

72 

73def _normalize_statement(statement: SQLStatement) -> Executable | TextClause: 

74 if isinstance(statement, str): 

75 return text(statement) 

76 return statement 

77 

78 

79def _create_mock_connection(engine_url: str | URL, output_file: IO[str] | None = None) -> MockConnection: 

80 writer = _SQLWriter(output_file) 

81 engine = create_mock_engine(engine_url, executor=writer.write, paramstyle="pyformat") 

82 writer.dialect = engine.dialect 

83 return engine 

84 

85 

86def _dialect_name(url: URL) -> str: 

87 dialect_name = url.drivername 

88 # Normalize dialect name (e.g., "postgresql+psycopg2" -> "postgresql") 

89 if "+" in dialect_name: 

90 dialect_name = dialect_name.split("+")[0] 

91 return dialect_name 

92 

93 

94def _clear_schema(metadata: MetaData) -> None: 

95 if metadata.schema: 

96 metadata.schema = None 

97 for table in metadata.tables.values(): 

98 table.schema = None 

99 

100 

101def _get_existing_indexes(inspector: Any, table_name: str, schema: str | None) -> set[str]: 

102 return { 

103 ix["name"] 

104 for ix in inspector.get_indexes(table_name, schema=schema) 

105 if "name" in ix and ix["name"] is not None 

106 } 

107 

108 

109def is_mock_url(url: URL) -> bool: 

110 """Check if the engine URL points to a mock connection. 

111 

112 Parameters 

113 ---------- 

114 url 

115 The SQLAlchemy engine URL. 

116 

117 Returns 

118 ------- 

119 bool 

120 True if the URL is a mock URL, False otherwise. 

121 """ 

122 return (url.drivername == "sqlite" and url.database is None) or ( 

123 url.drivername != "sqlite" and url.host is None 

124 ) 

125 

126 

127def is_sqlite_url(url: URL | str) -> bool: 

128 """Check if the engine URL points to a SQLite database. 

129 

130 Parameters 

131 ---------- 

132 url 

133 The SQLAlchemy engine URL or string. 

134 

135 Returns 

136 ------- 

137 bool 

138 True if the URL is a SQLite URL, False otherwise. 

139 """ 

140 if isinstance(url, str): 

141 url = make_url(url) 

142 return url.drivername.startswith("sqlite") 

143 

144 

145class DatabaseContextError(Exception): 

146 """Exception raised for errors in the DatabaseContext operations.""" 

147 

148 

149class DatabaseContext(AbstractContextManager): 

150 """Interface for managing database operations across different 

151 SQL dialects. 

152 """ 

153 

154 def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: 

155 """Exit the context manager and clean up resources.""" 

156 try: 

157 self.close() 

158 except Exception: 

159 logger.exception("Error during cleanup of database context") 

160 return False 

161 

162 @abstractmethod 

163 def close(self) -> None: 

164 """Close and clean up database resources.""" 

165 ... 

166 

167 @property 

168 @abstractmethod 

169 def metadata(self) -> MetaData: 

170 """The SQLAlchemy metadata representing the database for the context 

171 (`~sqlalchemy.sql.schema.MetaData`). 

172 """ 

173 ... 

174 

175 @property 

176 @abstractmethod 

177 def engine(self) -> Engine: 

178 """The SQAlchemy engine for the context 

179 (`~sqlalchemy.engine.Engine`). 

180 """ 

181 ... 

182 

183 @property 

184 @abstractmethod 

185 def dialect(self) -> Dialect: 

186 """The SQLAlchemy dialect for the context 

187 (`~sqlalchemy.engine.Dialect`). 

188 """ 

189 ... 

190 

191 @property 

192 @abstractmethod 

193 def dialect_name(self) -> str: 

194 """Get the dialect name for this database context (``str``).""" 

195 ... 

196 

197 @abstractmethod 

198 def initialize(self) -> None: 

199 """Create the target schema in the database if it does not exist 

200 already. 

201 

202 Sub-classes should implement idempotent behavior so that calling this 

203 method multiple times has no adverse effects. If the schema already 

204 exists, the method should simply return without raising an error. (A 

205 warning message may be logged in this case.) 

206 

207 Raises 

208 ------ 

209 DatabaseContextError 

210 If there is an error instantiating the schema. 

211 """ 

212 ... 

213 

214 @abstractmethod 

215 def drop(self) -> None: 

216 """Drop the schema in the database if it exists. 

217 

218 Implementations should use ``IF EXISTS`` semantics to avoid raising 

219 an error if the schema does not exist. 

220 

221 Raises 

222 ------ 

223 DatabaseContextError 

224 If there is an error dropping the schema. 

225 """ 

226 ... 

227 

228 @abstractmethod 

229 def create_all(self) -> None: 

230 """Create all database objects in the schema using the metadata 

231 object. 

232 

233 Raises 

234 ------ 

235 DatabaseContextError 

236 If there is an error creating the schema objects in the database. 

237 """ 

238 ... 

239 

240 @abstractmethod 

241 def create_indexes(self) -> None: 

242 """Create all indexes in the schema using the metadata object. 

243 

244 Raises 

245 ------ 

246 DatabaseContextError 

247 If there is an error creating the indexes in the database. 

248 """ 

249 ... 

250 

251 @abstractmethod 

252 def drop_indexes(self) -> None: 

253 """Drop all indexes in the schema using the metadata object. 

254 

255 Raises 

256 ------ 

257 DatabaseContextError 

258 If there is an error dropping the indexes in the database. 

259 """ 

260 ... 

261 

262 @abstractmethod 

263 def execute(self, statement: SQLStatement, params: dict[str, Any] | None = None) -> Result: 

264 """Execute a SQL statement and return the result. 

265 

266 Parameters 

267 ---------- 

268 statement 

269 The SQL statement to execute. 

270 params 

271 Optional parameters to use for the SQL statement. 

272 

273 Returns 

274 ------- 

275 `~sqlalchemy.engine.Result` 

276 The result of the statement execution. 

277 

278 Raises 

279 ------ 

280 DatabaseContextError 

281 If there is an error executing the SQL statement. 

282 """ 

283 ... 

284 

285 

286class _BaseContext(DatabaseContext): 

287 """Base database context providing common behavior. 

288 

289 Parameters 

290 ---------- 

291 engine_url 

292 The SQLAlchemy engine for connecting to the database. 

293 metadata 

294 The SQLAlchemy metadata representing the database objects. 

295 require_schema 

296 True if a valid schema name is required on the MetaData, False if not. 

297 """ 

298 

299 # Subclasses should set this to the dialect name. 

300 DIALECT: str 

301 

302 def __init__(self, engine_url: URL, metadata: MetaData, require_schema: bool = False) -> None: 

303 self._engine_url = engine_url 

304 self._metadata = metadata 

305 self._schema_name: str | None = metadata.schema 

306 self._engine: Engine | None = None 

307 self._echo: bool = False 

308 

309 # Check that the URL dialect matches this context's expected dialect 

310 self._validate_dialect(engine_url) 

311 

312 # Ensure the schema name is set for dialects that require it 

313 if require_schema and self._schema_name is None: 

314 raise DatabaseContextError(f"Schema name must be set for context: {self.dialect_name}") 

315 

316 @property 

317 def echo(self) -> bool: 

318 """Whether to log all SQL statements executed by the engine 

319 (``bool``). 

320 """ 

321 return self._echo 

322 

323 @echo.setter 

324 def echo(self, value: bool) -> None: 

325 self._echo = value 

326 if self.engine is not None: 

327 self.engine.echo = value 

328 

329 @classmethod 

330 def _validate_dialect(cls, engine_url: URL) -> None: 

331 """Validate that the engine dialect matches this context's expected 

332 dialect. 

333 

334 Parameters 

335 ---------- 

336 engine_url 

337 The SQLAlchemy database URL to validate. 

338 

339 Raises 

340 ------ 

341 DatabaseContextError 

342 If the engine dialect doesn't match the context's expected dialect. 

343 """ 

344 # Normalize both the engine dialect and expected dialect for comparison 

345 engine_dialect = _dialect_name(engine_url) 

346 expected_dialect = cls.DIALECT.lower() 

347 

348 if engine_dialect != expected_dialect: 

349 raise DatabaseContextError( 

350 f"Engine dialect '{engine_dialect}' does not match the context's expected dialect: " 

351 f"{expected_dialect}" 

352 ) 

353 

354 @property 

355 def engine(self) -> Engine: 

356 if self._engine is None: 

357 self._engine = create_engine(self._engine_url) 

358 return self._engine 

359 

360 @property 

361 def metadata(self) -> MetaData: 

362 return self._metadata 

363 

364 @property 

365 def dialect(self) -> Dialect: 

366 return self.engine.dialect 

367 

368 @property 

369 def dialect_name(self) -> str: 

370 """Get the dialect name for this database context. 

371 

372 Returns 

373 ------- 

374 str 

375 The normalized dialect name. 

376 """ 

377 return self.DIALECT 

378 

379 @property 

380 def schema_name(self) -> str | None: 

381 """Effective schema name for this context (may be None). 

382 

383 Returns 

384 ------- 

385 str | None 

386 The schema name, or None if no schema is set. 

387 """ 

388 return self._schema_name 

389 

390 @contextmanager 

391 def connect(self) -> Iterator[Connection]: 

392 """Context manager for database connection.""" 

393 with self.engine.connect() as connection: 

394 yield connection 

395 

396 def execute(self, statement: SQLStatement, params: dict[str, Any] | None = None) -> Result: 

397 statement = _normalize_statement(statement) 

398 try: 

399 with self.connect() as conn: 

400 with conn.begin(): 

401 if params: 

402 result = conn.execute(statement, params) 

403 else: 

404 result = conn.execute(statement) 

405 return result 

406 except SQLAlchemyError as e: 

407 raise DatabaseContextError(f"Error executing statement: {e}") from e 

408 

409 def create_all(self) -> None: 

410 with self.connect() as conn: 

411 with conn.begin(): 

412 try: 

413 self.metadata.create_all(bind=conn) 

414 except SQLAlchemyError as e: 

415 raise DatabaseContextError(f"Error creating database: {e}") from e 

416 

417 def _manage_indexes(self, action: str) -> None: 

418 """Manage indexes by creating or dropping them. 

419 

420 Parameters 

421 ---------- 

422 action 

423 The action to perform, either "create" or "drop". 

424 

425 Raises 

426 ------ 

427 DatabaseContextError 

428 If there is an error managing the indexes in the database. 

429 """ 

430 with self.connect() as conn: 

431 with conn.begin(): 

432 try: 

433 inspector = inspect(conn) 

434 for table in self.metadata.tables.values(): 

435 # Fetch all existing indexes for this table once 

436 existing_indexes = _get_existing_indexes(inspector, table.name, self.schema_name) 

437 

438 for index in table.indexes: 

439 if index.name is None: 

440 # Anonymous indexes can't be checked by name 

441 logger.warning("Skipping anonymous index on table '%s'", table.name) 

442 continue 

443 

444 if action == "create": 

445 if index.name in existing_indexes: 

446 logger.warning( 

447 "Skipping creation of index '%s' which already exists", 

448 index.name, 

449 ) 

450 continue 

451 index.create(bind=conn, checkfirst=False) # We already checked 

452 logger.info("Created index '%s'", index.name) 

453 elif action == "drop": 

454 if index.name not in existing_indexes: 

455 logger.warning("Skipping index '%s' which does not exist", index.name) 

456 continue 

457 index.drop(bind=conn, checkfirst=False) # We already checked 

458 logger.info("Dropped index '%s'", index.name) 

459 else: 

460 raise ValueError(f"Invalid action '{action}'. Must be 'create' or 'drop'.") 

461 except SQLAlchemyError as e: 

462 raise DatabaseContextError(f"Error {action}ing indexes: {e}") from e 

463 

464 def create_indexes(self) -> None: 

465 """Create all indexes in the schema using the metadata object. 

466 

467 Raises 

468 ------ 

469 DatabaseContextError 

470 If there is an error creating the indexes in the database. 

471 """ 

472 self._manage_indexes("create") 

473 

474 def drop_indexes(self) -> None: 

475 """Drop all indexes in the schema using the metadata object. 

476 

477 Raises 

478 ------ 

479 DatabaseContextError 

480 If there is an error dropping the indexes in the database. 

481 """ 

482 self._manage_indexes("drop") 

483 

484 def _required_schema_name(self) -> str: 

485 """Return the schema name, ensuring that it is set. 

486 

487 This is mainly here for typing purposes, because the schema_name 

488 property may be None, and mypy doesn't understand that we already 

489 checked it during initialization. 

490 """ 

491 if self.schema_name is None: 

492 raise DatabaseContextError("Schema name is required but not set.") 

493 return self.schema_name 

494 

495 def close(self) -> None: 

496 """Close and dispose of the database engine.""" 

497 if self._engine is not None: 

498 self._engine.dispose() 

499 self._engine = None 

500 

501 

502_ContextClass: TypeAlias = type[_BaseContext] 

503_ContextDecorator: TypeAlias = Callable[[_ContextClass], _ContextClass] 

504 

505 

506class DatabaseContextFactory: 

507 """Factory for creating DatabaseContext instances based on dialect type.""" 

508 

509 _registry: dict[str, _ContextClass] = {} 

510 

511 @classmethod 

512 def register(cls) -> _ContextDecorator: 

513 """Register a context class for its dialect. 

514 

515 The dialect is determined by reading the DIALECT attribute from the 

516 decorated class. 

517 

518 Returns 

519 ------- 

520 Callable 

521 The decorator function that registers the context class. 

522 

523 Examples 

524 -------- 

525 >>> @DatabaseContextFactory.register() 

526 ... class PostgreSQLContext(_BaseContext): 

527 ... DIALECT = "postgresql" 

528 ... pass 

529 

530 Notes 

531 ----- 

532 The registry is populated at module import time and afterwards should 

533 be treated as read-only. 

534 """ 

535 

536 def decorator(context_class: type[_BaseContext]) -> type[_BaseContext]: 

537 # Get the dialect from the class's DIALECT attribute 

538 if not hasattr(context_class, "DIALECT"): 538 ↛ 539line 538 didn't jump to line 539 because the condition on line 538 was never true

539 raise ValueError(f"Context class {context_class.__name__} must define a DIALECT attribute") 

540 cls._registry[context_class.DIALECT] = context_class 

541 return context_class 

542 

543 return decorator 

544 

545 @classmethod 

546 def register_class(cls, dialect: str, context_class: type[_BaseContext]) -> None: 

547 """Register a context class for a specific dialect programmatically. 

548 

549 Parameters 

550 ---------- 

551 dialect 

552 The dialect name to register. 

553 context_class 

554 The context class to use for this dialect. 

555 """ 

556 dialect_name = dialect.lower() 

557 if "+" in dialect_name: 

558 dialect_name = dialect_name.split("+")[0] 

559 cls._registry[dialect_name] = context_class 

560 

561 @classmethod 

562 def create_context(cls, dialect: str, engine_url: URL, metadata: MetaData) -> DatabaseContext: 

563 """Create a context instance for the given dialect. 

564 

565 Parameters 

566 ---------- 

567 dialect 

568 The database dialect name. 

569 engine_url 

570 The SQLAlchemy database URL. 

571 metadata 

572 The SQLAlchemy metadata. 

573 

574 Returns 

575 ------- 

576 DatabaseContext 

577 The appropriate context instance. 

578 

579 Raises 

580 ------ 

581 ValueError 

582 If no context class is registered for the dialect. 

583 """ 

584 dialect_name = dialect.lower() 

585 if "+" in dialect_name: 

586 dialect_name = dialect_name.split("+")[0] 

587 

588 if dialect_name not in cls._registry: 

589 supported = cls.get_supported_dialects() 

590 raise ValueError( 

591 f"No context class registered for dialect: {dialect_name}. " 

592 f"Supported dialects: {', '.join(supported)}" 

593 ) 

594 

595 context_class = cls._registry[dialect_name] 

596 return context_class(engine_url, metadata) 

597 

598 @classmethod 

599 def get_supported_dialects(cls) -> list[str]: 

600 """Get a list of supported dialect names. 

601 

602 Returns 

603 ------- 

604 list[str] 

605 List of supported dialect names. 

606 """ 

607 return list(cls._registry.keys()) 

608 

609 

610class _SQLWriter: 

611 """Write SQL statements to stdout or a file. 

612 

613 Parameters 

614 ---------- 

615 file 

616 The file to write the SQL statements to. If None, the statements 

617 will be written to stdout. 

618 """ 

619 

620 def __init__(self, file: IO[str] | None = None) -> None: 

621 """Initialize the SQL writer.""" 

622 self.file = file 

623 self.dialect: Dialect | None = None 

624 

625 def write(self, sql: Any, *multiparams: Any, **params: Any) -> None: 

626 """Write the SQL statement to a file or stdout. 

627 

628 Statements with parameters will be formatted with the values 

629 inserted into the resultant SQL output. 

630 

631 Parameters 

632 ---------- 

633 sql 

634 The SQL statement to write. 

635 *multiparams 

636 The multiparams to use for the SQL statement. 

637 **params 

638 The params to use for the SQL statement. 

639 

640 Notes 

641 ----- 

642 The functions arguments are typed very loosely because this method in 

643 SQLAlchemy is untyped, amd we do not call it directly. 

644 """ 

645 compiled = sql.compile(dialect=self.dialect) 

646 sql_str = str(compiled) + ";" 

647 params_list = [compiled.params] 

648 for params in params_list: 

649 if not params: 

650 print(sql_str, file=self.file) 

651 continue 

652 new_params = {} 

653 for key, value in params.items(): 

654 if isinstance(value, str): 

655 new_params[key] = f"'{value}'" 

656 elif value is None: 

657 new_params[key] = "null" 

658 else: 

659 new_params[key] = value 

660 print(sql_str % new_params, file=self.file) 

661 

662 

663@DatabaseContextFactory.register() 

664class PostgreSQLContext(_BaseContext): 

665 """Database context for Postgres. 

666 

667 Parameters 

668 ---------- 

669 engine_url 

670 The SQLAlchemy database URL for connecting to the database. 

671 metadata 

672 The SQLAlchemy metadata representing the database objects. 

673 """ 

674 

675 DIALECT = "postgresql" 

676 

677 def __init__(self, engine_url: URL, metadata: MetaData): 

678 super().__init__(engine_url, metadata, require_schema=True) 

679 

680 def initialize(self) -> None: 

681 schema_name = self._required_schema_name() 

682 try: 

683 logger.debug("Checking if PG schema exists: %s", schema_name) 

684 result = self.execute( 

685 """ 

686 SELECT schema_name 

687 FROM information_schema.schemata 

688 WHERE schema_name = :schema_name 

689 """, 

690 {"schema_name": schema_name}, 

691 ) 

692 if result.fetchone(): 

693 return 

694 logger.debug("Creating PG schema: %s", schema_name) 

695 self.execute(CreateSchema(schema_name)) 

696 except SQLAlchemyError as e: 

697 raise DatabaseContextError(f"Error initializing Postgres schema: {e}") from e 

698 

699 def drop(self) -> None: 

700 schema_name = self._required_schema_name() 

701 try: 

702 logger.debug("Dropping PostgreSQL schema if exists: %s", schema_name) 

703 self.execute(DropSchema(schema_name, if_exists=True, cascade=True)) 

704 except SQLAlchemyError as e: 

705 raise DatabaseContextError(f"Error dropping Postgres database: {e}") from e 

706 

707 

708@DatabaseContextFactory.register() 

709class MySQLContext(_BaseContext): 

710 """Database context for MySQL. 

711 

712 Parameters 

713 ---------- 

714 engine_url 

715 The SQLAlchemy database URL for connecting to the database. 

716 metadata 

717 The SQLAlchemy metadata representing the database objects. 

718 """ 

719 

720 DIALECT = "mysql" 

721 

722 def __init__(self, engine_url: URL, metadata: MetaData): 

723 super().__init__(engine_url, metadata, require_schema=True) 

724 

725 def initialize(self) -> None: 

726 # The schema is instantiated as a database, as MySQL does not have a 

727 # distinct schema concept, unlike Postgres. 

728 schema_name = self._required_schema_name() 

729 try: 

730 logger.debug("Checking if MySQL database exists: %s", schema_name) 

731 result = self.execute("SHOW DATABASES LIKE :schema_name", {"schema_name": schema_name}) 

732 if result.fetchone(): 

733 return 

734 logger.debug("Creating MySQL database: %s", schema_name) 

735 from sqlalchemy import DDL 

736 

737 create_stmt = DDL(f"CREATE DATABASE {quoted_name(schema_name, quote=True)}") 

738 self.execute(create_stmt) 

739 except SQLAlchemyError as e: 

740 raise DatabaseContextError(f"Error initializing MySQL database: {e}") from e 

741 

742 def drop(self) -> None: 

743 schema_name = self._required_schema_name() 

744 try: 

745 logger.debug("Dropping MySQL database if exists: %s", schema_name) 

746 from sqlalchemy import DDL 

747 

748 drop_stmt = DDL(f"DROP DATABASE IF EXISTS {quoted_name(schema_name, quote=True)}") 

749 self.execute(drop_stmt) 

750 except SQLAlchemyError as e: 

751 raise DatabaseContextError(f"Error dropping MySQL database: {e}") from e 

752 

753 

754@DatabaseContextFactory.register() 

755class SQLiteContext(_BaseContext): 

756 """Database context for SQLite. 

757 

758 Parameters 

759 ---------- 

760 engine_url 

761 The SQLAlchemy database URL for connecting to the database. 

762 metadata 

763 The SQLAlchemy metadata representing the database objects. 

764 """ 

765 

766 DIALECT = "sqlite" 

767 

768 def __init__(self, engine_url: URL, metadata: MetaData): 

769 # Schema name needs to be cleared, if set. 

770 _clear_schema(metadata) 

771 # Schema name is not required. 

772 super().__init__(engine_url, metadata) 

773 

774 def initialize(self) -> None: 

775 # Nothing needs to be done for SQLite initialization. 

776 return 

777 

778 def drop(self) -> None: 

779 try: 

780 logger.debug("Dropping tables in SQLite schema") 

781 # Drop all the tables in the database file. 

782 self.metadata.drop_all(bind=self.engine) 

783 except SQLAlchemyError as e: 

784 raise DatabaseContextError(f"Error dropping SQLite database: {e}") from e 

785 

786 

787class MockContext(DatabaseContext): 

788 """Database context for a mock connection. 

789 

790 Parameters 

791 ---------- 

792 metadata 

793 The SQLAlchemy metadata defining the database objects. 

794 connection 

795 The SQLAlchemy mock connection. 

796 """ 

797 

798 def __init__(self, metadata: MetaData, connection: MockConnection): 

799 self._metadata = metadata 

800 self._connection = connection 

801 self._dialect = connection.dialect 

802 

803 @property 

804 def dialect(self) -> Dialect: 

805 return self._dialect 

806 

807 @property 

808 def dialect_name(self) -> str: 

809 return self.dialect.name 

810 

811 @property 

812 def metadata(self) -> MetaData: 

813 return self._metadata 

814 

815 @property 

816 def engine(self) -> Engine: 

817 raise DatabaseContextError("MockContext does not provide an engine.") 

818 

819 def initialize(self) -> None: 

820 # Mock connection doesn't do any initialization. 

821 pass 

822 

823 def drop(self) -> None: 

824 # Mock connection doesn't drop. 

825 pass 

826 

827 def create_all(self) -> None: 

828 self._metadata.create_all(self._connection) 

829 

830 def create_indexes(self) -> None: 

831 # Mock connection can't create indexes. 

832 pass 

833 

834 def drop_indexes(self) -> None: 

835 # Mock connection can't drop indexes. 

836 pass 

837 

838 def execute(self, statement: SQLStatement, params: dict[str, Any] | None = None) -> Result: 

839 statement = _normalize_statement(statement) 

840 if params: 

841 return self._connection.connect().execute(statement, params) 

842 else: 

843 return self._connection.connect().execute(statement) 

844 

845 def close(self) -> None: 

846 """Close the mock connection (no-op).""" 

847 pass 

848 

849 

850def create_database_context( 

851 engine_url: str | URL, 

852 metadata: MetaData, 

853 output_file: IO[str] | None = None, 

854 dry_run: bool = False, 

855 echo: bool | None = None, 

856) -> DatabaseContext: 

857 """Create a DatabaseContext object based on the engine URL. 

858 

859 Parameters 

860 ---------- 

861 engine_url 

862 The database URL for the database connection. 

863 metadata 

864 The SQLAlchemy MetaData representing the database objects. 

865 output_file 

866 Output file for writing generated SQL commands. 

867 dry_run 

868 If True, configure the context to perform a dry run, where operations 

869 will not be executed. 

870 If False, use a normal context where operations are executed. 

871 echo 

872 If True, the SQLAlchemy engine will log all statements to the console. 

873 

874 Returns 

875 ------- 

876 DatabaseContext 

877 A database context appropriate for the given engine URL. This will be 

878 a `MockContext` if the URL appears like a mock URL or if ``dry_run`` is 

879 True, otherwise it will be a context based on the dialect using the 

880 factory pattern. 

881 

882 Raises 

883 ------ 

884 DatabaseContextError 

885 If the dialect is not supported or if there's an issue creating 

886 the context. 

887 """ 

888 if isinstance(engine_url, str): 

889 engine_url = make_url(engine_url) 

890 

891 if is_mock_url(engine_url) or dry_run: 

892 # Use a mock context for mock URLs or dry run mode. 

893 dialect_name = _dialect_name(engine_url) 

894 if dialect_name == "sqlite": 

895 _clear_schema(metadata) 

896 mock_connection = _create_mock_connection(engine_url, output_file) 

897 return MockContext(metadata, mock_connection) 

898 else: 

899 # Create a real engine and context for the given dialect. 

900 try: 

901 dialect_name = _dialect_name(engine_url) 

902 

903 # Use the factory to create the appropriate context 

904 try: 

905 db_ctx = DatabaseContextFactory.create_context(dialect_name, engine_url, metadata) 

906 if echo is not None: 

907 # This is settable for real contexts only. 

908 if hasattr(db_ctx, "echo"): 

909 db_ctx.echo = echo 

910 return db_ctx 

911 except ValueError as e: 

912 supported = DatabaseContextFactory.get_supported_dialects() 

913 raise DatabaseContextError( 

914 f"Unsupported dialect: {dialect_name}. Supported dialects are: {', '.join(supported)}" 

915 ) from e 

916 

917 except Exception as e: 

918 if isinstance(e, DatabaseContextError): 

919 raise 

920 raise DatabaseContextError(f"Failed to create database context: {e}") from e