"""SQLAlchemy connection helpers."""
from __future__ import annotations
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Optional
# Import duckdb_engine to register the dialect with SQLAlchemy
try:
import duckdb_engine # noqa: F401
except ImportError:
pass
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Connection, Engine
from moltres_core.config import EngineConfig
from moltres_core.sql.dialects import DialectSpec, get_dialect
[docs]
class ConnectionManager:
    """Creates and caches SQLAlchemy engines for Moltres sessions.

    The manager resolves an :class:`Engine` lazily from the supplied config
    (an existing session, an existing engine, or a DSN, in that order) and
    tracks at most one explicit transaction at a time, together with the
    savepoints created inside it.
    """

    # Isolation levels accepted by ``begin_transaction``.
    _ISOLATION_LEVELS = (
        "READ UNCOMMITTED",
        "READ COMMITTED",
        "REPEATABLE READ",
        "SERIALIZABLE",
    )

    def __init__(self, config: EngineConfig):
        """Store the configuration; the engine is created on first use.

        Args:
            config: Engine configuration providing ``session``, ``engine`` or ``dsn``.
        """
        self.config = config
        self._engine: Engine | None = None
        self._session: object | None = None  # SQLAlchemy Session
        self._active_transaction: Optional[Connection] = None
        self._savepoint_stack: list[str] = []
        self._transaction_metadata: Optional[dict[str, object]] = None

    def _create_engine(self) -> Engine:
        """Resolve the engine from the config.

        Precedence: ``config.session`` (its bind is used and the session is
        remembered), then ``config.engine``, then a new engine built from
        ``config.dsn``.

        Raises:
            TypeError: If the session/engine is not backed by a synchronous Engine.
            ValueError: If none of ``dsn``, ``engine`` or ``session`` is provided.
        """
        # If a session is provided, extract engine from it
        if self.config.session is not None:
            session = self.config.session
            # Check if it's a SQLAlchemy Session or SQLModel Session
            if hasattr(session, "get_bind"):
                # SQLAlchemy 2.0 style
                bind = session.get_bind()
            elif hasattr(session, "bind"):
                # SQLAlchemy 1.x style
                bind = session.bind
            else:
                raise TypeError(
                    "session must be a SQLAlchemy Session or SQLModel Session instance. "
                    f"Got: {type(session).__name__}"
                )
            if not isinstance(bind, Engine):
                raise TypeError(
                    "Session's bind must be a synchronous Engine, not AsyncEngine. "
                    "Use async_connect() for async sessions."
                )
            self._session = session
            return bind

        # If an engine is provided in config, use it directly
        if self.config.engine is not None:
            if not isinstance(self.config.engine, Engine):
                raise TypeError("config.engine must be a synchronous Engine, not AsyncEngine")
            return self.config.engine

        # Otherwise, create a new engine from DSN
        if self.config.dsn is None:
            raise ValueError(
                "Either 'dsn', 'engine', or 'session' must be provided in EngineConfig"
            )
        kwargs: dict[str, object] = {"echo": self.config.echo, "future": self.config.future}
        # Only forward pool options that were explicitly configured.
        for option in ("pool_size", "max_overflow", "pool_timeout", "pool_recycle"):
            value = getattr(self.config, option)
            if value is not None:
                kwargs[option] = value
        if self.config.pool_pre_ping:
            kwargs["pool_pre_ping"] = self.config.pool_pre_ping
        return create_engine(self.config.dsn, **kwargs)

    @property
    def engine(self) -> Engine:
        """Return the cached engine, creating it on first access."""
        if self._engine is None:
            self._engine = self._create_engine()
        return self._engine

    @contextmanager
    def connect(self, transaction: Optional[Connection] = None) -> Iterator[Connection]:
        """Get a database connection.

        Args:
            transaction: If provided, use this transaction connection instead of
                creating a new one. This allows operations to share a transaction.
                If None and an active transaction exists, uses the active transaction.

        Yields:
            A SQLAlchemy connection. Connections taken from ``transaction``, the
            active transaction or the session are not closed or committed here;
            otherwise a fresh ``engine.begin()`` block is used.
        """
        if transaction is not None:
            # Use the provided transaction connection
            yield transaction
            return
        if self._active_transaction is not None:
            # Use the active transaction connection automatically
            yield self._active_transaction
            return

        # Resolve the engine *before* looking at the session: the session is
        # only recorded while the engine is being created, so checking
        # ``_session`` first would bypass it on the very first call.
        engine = self.engine
        session = self._session
        if session is not None and hasattr(session, "connection"):
            # Use the session's own connection
            yield session.connection()
            return

        # No session connection available: new connection in a begin() block
        with engine.begin() as connection:
            yield connection

    def begin_transaction(
        self,
        savepoint: bool = False,
        readonly: bool = False,
        isolation_level: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> Connection:
        """Begin a new transaction and return the connection.

        Args:
            savepoint: If True and a transaction is already active, create a savepoint instead.
            readonly: If True, set transaction to read-only mode.
            isolation_level: Optional isolation level (READ UNCOMMITTED, READ COMMITTED,
                REPEATABLE READ, SERIALIZABLE).
            timeout: Optional transaction timeout in seconds.

        Returns:
            Connection that is part of a transaction (not auto-committed)

        Raises:
            RuntimeError: If savepoint=False and a transaction is already active.
            ValueError: If isolation level or readonly is requested but not supported
                by dialect, or if the isolation level is not a recognised name.
        """
        if self._active_transaction is not None:
            if not savepoint:
                raise RuntimeError(
                    "Transaction already active. Use savepoint=True for nested transactions."
                )
            # Create a savepoint instead of a new transaction
            return self.create_savepoint(
                self._active_transaction, self._generate_savepoint_name()
            )

        # Get dialect for feature checking
        dialect_name = self.engine.dialect.name
        dialect_spec = self._resolve_dialect_spec(dialect_name)

        # Validate every option before a connection is checked out so that a
        # bad argument cannot leak a connection or leave stale state behind.
        if isolation_level:
            if not dialect_spec.supports_isolation_levels:
                raise ValueError(
                    f"Dialect '{dialect_name}' does not support isolation levels. "
                    "SQLite only supports SERIALIZABLE and READ UNCOMMITTED via PRAGMA."
                )
            self._normalize_isolation_level(isolation_level)
        if readonly and not dialect_spec.supports_read_only_transactions:
            raise ValueError(
                f"Dialect '{dialect_name}' does not support read-only transactions."
            )

        connection = self.engine.connect()
        try:
            if isolation_level:
                self._set_isolation_level(connection, isolation_level)
            if readonly:
                self._set_readonly(connection, True)
            if timeout:
                self._set_timeout(connection, timeout, dialect_name)
            # Executing the SET statements above may already have started a
            # transaction (autobegin); calling begin() again would then raise.
            if not connection.in_transaction():
                connection.begin()
        except BaseException:
            # Setup failed: release the connection and stay transaction-free.
            connection.close()
            raise

        # Publish state only once the transaction is fully set up.
        self._active_transaction = connection
        self._savepoint_stack = []
        self._transaction_metadata = {
            "readonly": readonly,
            "isolation_level": isolation_level,
            "timeout": timeout,
        }
        return connection

    def _resolve_dialect_spec(self, dialect_name: str) -> DialectSpec:
        """Look up the dialect spec, falling back to defaults for unknown dialects."""
        try:
            return get_dialect(dialect_name)
        except ValueError:
            # Unknown dialect, use conservative defaults
            return DialectSpec(name=dialect_name)

    def _generate_savepoint_name(self) -> str:
        """Generate a savepoint name that is not currently on the stack.

        Starts at ``sp_<stack depth>`` and skips forward past names still in
        use, which can happen after a non-top savepoint has been released.
        """
        index = len(self._savepoint_stack)
        while f"sp_{index}" in self._savepoint_stack:
            index += 1
        return f"sp_{index}"

    def _normalize_isolation_level(self, isolation_level: str) -> str:
        """Return the canonical upper-case isolation level name.

        Raises:
            ValueError: If the name is not one of the supported levels.
        """
        # Collapse surrounding/repeated whitespace so " read  committed " works.
        normalized = " ".join(isolation_level.upper().split())
        if normalized not in self._ISOLATION_LEVELS:
            raise ValueError(
                f"Invalid isolation level '{isolation_level}'. "
                "Must be one of: READ UNCOMMITTED, READ COMMITTED, REPEATABLE READ, SERIALIZABLE"
            )
        return normalized

    def _set_isolation_level(self, connection: Connection, isolation_level: str) -> None:
        """Set transaction isolation level.

        Raises:
            ValueError: If ``isolation_level`` is not a recognised level name.
        """
        # Only whitelisted level names ever reach the SQL text.
        normalized = self._normalize_isolation_level(isolation_level)
        connection.execute(text(f"SET TRANSACTION ISOLATION LEVEL {normalized}"))

    def _set_readonly(self, connection: Connection, readonly: bool) -> None:
        """Set transaction to read-only (or read-write) mode."""
        mode = "READ ONLY" if readonly else "READ WRITE"
        connection.execute(text(f"SET TRANSACTION {mode}"))

    def _set_timeout(self, connection: Connection, timeout: float, dialect_name: str) -> None:
        """Set transaction timeout (database-specific).

        Only PostgreSQL and MySQL dialect names are handled; for any other
        dialect this is a no-op.
        """
        if "postgresql" in dialect_name:
            # PostgreSQL uses statement_timeout (in milliseconds)
            connection.execute(text(f"SET statement_timeout = {int(timeout * 1000)}"))
        elif "mysql" in dialect_name:
            # MySQL uses innodb_lock_wait_timeout (whole seconds, minimum 1);
            # clamp so that sub-second timeouts do not truncate to 0.
            connection.execute(
                text(f"SET innodb_lock_wait_timeout = {max(1, int(timeout))}")
            )
        # SQLite doesn't support transaction timeouts directly
        # Other databases may need specific implementations

    def _require_active_transaction(self, connection: Connection) -> None:
        """Ensure a transaction is active and ``connection`` is the one that owns it.

        Raises:
            RuntimeError: If no transaction is active or the connection differs.
        """
        if self._active_transaction is None:
            raise RuntimeError("No active transaction")
        if connection is not self._active_transaction:
            raise RuntimeError("Connection is not the active transaction")

    @staticmethod
    def _validate_savepoint_name(name: str) -> None:
        """Reject savepoint names that are not plain ASCII identifiers.

        The name is interpolated into SQL text, so it must not contain
        anything other than letters, digits and underscores.

        Raises:
            ValueError: If the name is not a valid identifier.
        """
        if not (isinstance(name, str) and name.isascii() and name.isidentifier()):
            raise ValueError(
                f"Invalid savepoint name {name!r}. "
                "Use only letters, digits and underscores, not starting with a digit."
            )

    def _last_savepoint_index(self, name: str) -> int:
        """Return the stack index of the most recent savepoint called ``name``.

        Raises:
            RuntimeError: If the savepoint is not on the stack.
        """
        for index in range(len(self._savepoint_stack) - 1, -1, -1):
            if self._savepoint_stack[index] == name:
                return index
        raise RuntimeError(f"Savepoint '{name}' not found in current transaction")

    def create_savepoint(self, connection: Connection, name: str) -> Connection:
        """Create a savepoint in the current transaction.

        Args:
            connection: The transaction connection
            name: Savepoint name (must be a plain ASCII identifier)

        Returns:
            The same connection (for compatibility)

        Raises:
            RuntimeError: If no transaction is active or connection doesn't match
                active transaction.
            ValueError: If the name is not a valid identifier or the dialect does
                not support savepoints.
        """
        self._require_active_transaction(connection)
        self._validate_savepoint_name(name)

        # Get dialect for feature checking
        dialect_name = self.engine.dialect.name
        dialect_spec = self._resolve_dialect_spec(dialect_name)
        if not dialect_spec.supports_savepoints:
            raise ValueError(f"Dialect '{dialect_name}' does not support savepoints.")

        connection.execute(text(f"SAVEPOINT {name}"))
        self._savepoint_stack.append(name)
        return connection

    def rollback_to_savepoint(self, connection: Connection, name: str) -> None:
        """Rollback to a specific savepoint.

        The named savepoint stays on the stack; savepoints created after it
        are dropped from the stack.

        Args:
            connection: The transaction connection
            name: Savepoint name to rollback to

        Raises:
            RuntimeError: If no transaction is active, connection doesn't match,
                or savepoint not found.
        """
        self._require_active_transaction(connection)
        # Membership in the stack also guarantees the name is safe to interpolate
        # when it was registered through create_savepoint.
        index = self._last_savepoint_index(name)
        connection.execute(text(f"ROLLBACK TO SAVEPOINT {name}"))
        # Remove all savepoints after the one we're rolling back to
        del self._savepoint_stack[index + 1 :]

    def release_savepoint(self, connection: Connection, name: str) -> None:
        """Release a savepoint.

        Args:
            connection: The transaction connection
            name: Savepoint name to release

        Raises:
            RuntimeError: If no transaction is active, connection doesn't match,
                or savepoint not found.
        """
        self._require_active_transaction(connection)
        index = self._last_savepoint_index(name)
        connection.execute(text(f"RELEASE SAVEPOINT {name}"))
        # Releasing a savepoint also discards every savepoint established
        # after it, so drop the released entry and everything above it.
        del self._savepoint_stack[index:]

    def _finish_transaction(self, connection: Connection) -> None:
        """Close the connection and clear all transaction bookkeeping."""
        try:
            connection.close()
        finally:
            # Reset state even if close() itself fails, so the manager
            # never stays stuck on a dead transaction.
            self._active_transaction = None
            self._savepoint_stack = []
            self._transaction_metadata = None

    def commit_transaction(self, connection: Connection) -> None:
        """Commit a transaction.

        Args:
            connection: The transaction connection to commit

        Raises:
            RuntimeError: If no transaction is active or connection doesn't match.
        """
        self._require_active_transaction(connection)
        try:
            connection.commit()
        finally:
            # Always close connection, even if commit fails
            self._finish_transaction(connection)

    def rollback_transaction(self, connection: Connection) -> None:
        """Rollback a transaction.

        Args:
            connection: The transaction connection to rollback

        Raises:
            RuntimeError: If no transaction is active or connection doesn't match.
        """
        self._require_active_transaction(connection)
        try:
            connection.rollback()
        finally:
            # Always close connection, even if rollback fails
            self._finish_transaction(connection)

    @property
    def active_transaction(self) -> Optional[Connection]:
        """Get the active transaction connection if one exists."""
        return self._active_transaction

    @property
    def transaction_metadata(self) -> Optional[dict[str, object]]:
        """Get transaction metadata if a transaction is active."""
        return self._transaction_metadata

    @property
    def savepoint_stack(self) -> list[str]:
        """Get a copy of the current savepoint stack."""
        return self._savepoint_stack.copy()