Source code for moltres.engine.connection

"""SQLAlchemy connection helpers."""

from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from typing import Optional

# Import duckdb_engine to register the dialect with SQLAlchemy
try:
    import duckdb_engine  # noqa: F401
except ImportError:
    pass

from sqlalchemy import create_engine
from sqlalchemy.engine import Connection, Engine

from ..config import EngineConfig


class ConnectionManager:
    """Creates and caches SQLAlchemy engines for Moltres sessions.

    The engine is resolved lazily on first access of :attr:`engine`, from
    (in priority order) ``config.session``, ``config.engine`` or
    ``config.dsn``.  The manager also tracks at most one explicitly managed
    transaction at a time (see :meth:`begin_transaction`).
    """

    def __init__(self, config: EngineConfig):
        """Store the configuration; no engine or connection is created yet.

        Args:
            config: Engine configuration supplying a session, an engine or a DSN.
        """
        self.config = config
        self._engine: Engine | None = None
        self._session: object | None = None  # SQLAlchemy Session
        self._active_transaction: Optional[Connection] = None

    def _create_engine(self) -> Engine:
        """Resolve the engine described by ``self.config``.

        Returns:
            The session's bind, the configured engine, or a new engine
            created from the DSN.

        Raises:
            TypeError: If the session has neither ``get_bind`` nor ``bind``,
                or if the session bind / configured engine is not a
                synchronous ``Engine``.
            ValueError: If none of ``session``, ``engine`` or ``dsn`` is set.
        """
        config = self.config

        # If a session is provided, extract the engine from it.
        if config.session is not None:
            session = config.session
            # Works for both SQLAlchemy and SQLModel sessions.
            if hasattr(session, "get_bind"):
                # SQLAlchemy 2.0 style
                bind = session.get_bind()
            elif hasattr(session, "bind"):
                # SQLAlchemy 1.x style
                bind = session.bind
            else:
                raise TypeError(
                    "session must be a SQLAlchemy Session or SQLModel Session instance. "
                    f"Got: {type(session).__name__}"
                )
            if not isinstance(bind, Engine):
                raise TypeError(
                    "Session's bind must be a synchronous Engine, not AsyncEngine. "
                    "Use async_connect() for async sessions."
                )
            # Remember the session only once it has been validated, so that
            # connect() can hand out the session's own connection.
            self._session = session
            return bind

        # If an engine is provided in config, use it directly.
        if config.engine is not None:
            if not isinstance(config.engine, Engine):
                raise TypeError("config.engine must be a synchronous Engine, not AsyncEngine")
            return config.engine

        # Otherwise, create a new engine from the DSN.
        if config.dsn is None:
            raise ValueError(
                "Either 'dsn', 'engine', or 'session' must be provided in EngineConfig"
            )

        kwargs: dict[str, object] = {"echo": config.echo, "future": config.future}
        # Pool options are forwarded only when explicitly configured.
        for option in ("pool_size", "max_overflow", "pool_timeout", "pool_recycle"):
            value = getattr(config, option)
            if value is not None:
                kwargs[option] = value
        if config.pool_pre_ping:
            kwargs["pool_pre_ping"] = config.pool_pre_ping
        return create_engine(config.dsn, **kwargs)

    @property
    def engine(self) -> Engine:
        """The engine for this manager, created on first access and then cached."""
        if self._engine is None:
            self._engine = self._create_engine()
        return self._engine

    @contextmanager
    def connect(self, transaction: Optional[Connection] = None) -> Iterator[Connection]:
        """Get a database connection.

        Args:
            transaction: If provided, use this transaction connection instead of
                creating a new one. This allows operations to share a transaction.

        Yields:
            A SQLAlchemy ``Connection``: the supplied ``transaction``, the
            configured session's connection, or a new connection opened with
            ``engine.begin()``.  Only the last of these is managed (begun and
            finished) by this context manager; the other two are yielded as-is.
        """
        if transaction is not None:
            # Use the provided transaction connection; its owner finishes it.
            yield transaction
            return

        # Resolve the engine *before* looking at ``self._session``: the session
        # is only recorded while the engine is being resolved, so without this
        # the very first call would bypass a configured session.
        engine = self.engine

        session = self._session
        if session is not None and hasattr(session, "connection"):
            # SQLAlchemy sessions expose their connection via connection().
            yield session.connection()
            return

        # No usable session: open a new connection via ``engine.begin()``.
        with engine.begin() as connection:
            yield connection

    def begin_transaction(self) -> Connection:
        """Begin a new transaction and return the connection.

        Returns:
            Connection that is part of a transaction (not auto-committed)

        Raises:
            RuntimeError: If a transaction is already active on this manager.
        """
        if self._active_transaction is not None:
            raise RuntimeError("Transaction already active. Nested transactions not yet supported.")

        connection = self.engine.connect()
        try:
            connection.begin()
        except BaseException:
            # Do not leak the connection, and do not leave the manager stuck
            # in an "active transaction" state that nothing can ever end.
            connection.close()
            raise
        self._active_transaction = connection
        return connection

    def commit_transaction(self, connection: Connection) -> None:
        """Commit a transaction.

        Args:
            connection: The transaction connection to commit

        Raises:
            RuntimeError: If ``connection`` is not the active transaction.
        """
        if connection is not self._active_transaction:
            raise RuntimeError("Connection is not the active transaction")
        try:
            connection.commit()
        finally:
            # Always close connection, even if commit fails
            connection.close()
            self._active_transaction = None

    def rollback_transaction(self, connection: Connection) -> None:
        """Rollback a transaction.

        Args:
            connection: The transaction connection to rollback

        Raises:
            RuntimeError: If ``connection`` is not the active transaction.
        """
        if connection is not self._active_transaction:
            raise RuntimeError("Connection is not the active transaction")
        try:
            connection.rollback()
        finally:
            # Always close connection, even if rollback fails
            connection.close()
            self._active_transaction = None

    @property
    def active_transaction(self) -> Optional[Connection]:
        """Get the active transaction connection if one exists."""
        return self._active_transaction