Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions alembic/autogenerate/compare/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from typing import TYPE_CHECKING

from . import check_constraints
from . import comments
from . import constraints
from . import schema
Expand Down Expand Up @@ -60,3 +61,6 @@ def _produce_net_changes(
server_defaults, "alembic.autogenerate.defaults"
)
Plugin.setup_plugin_from_module(comments, "alembic.autogenerate.comments")
Plugin.setup_plugin_from_module(
check_constraints, "alembic.ext.checkconstraint"
)
195 changes: 195 additions & 0 deletions alembic/autogenerate/compare/check_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# mypy: allow-untyped-defs, allow-untyped-calls, allow-incomplete-defs

from __future__ import annotations

import logging
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from sqlalchemy import schema as sa_schema

from .util import _InspectorConv
from ...operations import ops
from ...util import PriorityDispatchResult
from ...util import sqla_compat

if TYPE_CHECKING:
from sqlalchemy.engine.interfaces import ReflectedCheckConstraint
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import Table

from ...autogenerate.api import AutogenContext
from ...ddl.impl import DefaultImpl
from ...operations.ops import ModifyTableOps
from ...runtime.plugins import Plugin


log = logging.getLogger(__name__)


def _make_check_constraint(
impl: DefaultImpl,
params: ReflectedCheckConstraint,
conn_table: Table,
) -> CheckConstraint:
const = sa_schema.CheckConstraint(
params["sqltext"],
name=params["name"],
**impl.adjust_reflected_dialect_options(params, "check_constraint"),
)
conn_table.append_constraint(const)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Federico Caselli (CaselIT) wrote:

I don't think this is needed.

View this in Gerrit at https://gerrit.sqlalchemy.org/c/sqlalchemy/alembic/+/6672

return const


def _compare_check_constraints(
autogen_context: AutogenContext,
modify_table_ops: ModifyTableOps,
schema: Optional[str],
tname: Union[quoted_name, str],
conn_table: Optional[Table],
metadata_table: Optional[Table],
) -> PriorityDispatchResult:
if conn_table is None or metadata_table is None:
return PriorityDispatchResult.CONTINUE

inspector = autogen_context.inspector
impl = autogen_context.migration_context.impl

metadata_ck_constraints = {
ck
for ck in metadata_table.constraints
if isinstance(ck, sa_schema.CheckConstraint)
and not sqla_compat._is_type_bound(ck)
}

try:
conn_ck_list = _InspectorConv(inspector).get_check_constraints(
tname, schema=schema
)
except NotImplementedError:
return PriorityDispatchResult.CONTINUE

conn_ck_list = [
ck
for ck in conn_ck_list
if ck.get("name") is not None
and autogen_context.run_name_filters(
ck["name"],
"check_constraint",
{"table_name": tname, "schema_name": schema},
)
]

conn_ck_objs = {
_make_check_constraint(impl, ck_def, conn_table)
for ck_def in conn_ck_list
}

metadata_ck_sig = {
impl._create_metadata_constraint_sig(ck)
for ck in metadata_ck_constraints
if sqla_compat._constraint_is_named(ck, autogen_context.dialect)
}

conn_ck_sig = {
impl._create_reflected_constraint_sig(ck) for ck in conn_ck_objs
}

metadata_ck_by_name = {
c.name: c
for c in metadata_ck_sig
if sqla_compat.constraint_name_string(c.name)
}
conn_ck_by_name = {
c.name: c
for c in conn_ck_sig
if sqla_compat.constraint_name_string(c.name)
}

for removed_name in sorted(
set(conn_ck_by_name).difference(metadata_ck_by_name)
):
conn_obj = conn_ck_by_name[removed_name]
if autogen_context.run_object_filters(
conn_obj.const,
conn_obj.name,
"check_constraint",
True,
None,
):
modify_table_ops.ops.append(
ops.DropConstraintOp.from_constraint(conn_obj.const)
)
log.info(
"Detected removed check constraint %r on table %r",
conn_obj.name,
tname,
)

for existing_name in sorted(
set(metadata_ck_by_name).intersection(conn_ck_by_name)
):
metadata_obj = metadata_ck_by_name[existing_name]
conn_obj = conn_ck_by_name[existing_name]

comparison = metadata_obj.compare_to_reflected(conn_obj)

if comparison.is_different:
if autogen_context.run_object_filters(
metadata_obj.const,
metadata_obj.name,
"check_constraint",
False,
conn_obj.const,
):
log.info(
"Detected changed check constraint %r on table %r: %s",
existing_name,
tname,
comparison.message,
)
modify_table_ops.ops.append(
ops.DropConstraintOp.from_constraint(conn_obj.const)
)
modify_table_ops.ops.append(
ops.AddConstraintOp.from_constraint(metadata_obj.const)
)
elif comparison.is_skip:
log.info(
"Cannot compare check constraint %r, "
"assuming equal and skipping. %s",
existing_name,
comparison.message,
)

for added_name in sorted(
set(metadata_ck_by_name).difference(conn_ck_by_name)
):
metadata_obj = metadata_ck_by_name[added_name]
if autogen_context.run_object_filters(
metadata_obj.const,
metadata_obj.name,
"check_constraint",
False,
None,
):
modify_table_ops.ops.append(
ops.AddConstraintOp.from_constraint(metadata_obj.const)
)
log.info(
"Detected added check constraint %r on table %r",
metadata_obj.name,
tname,
)

return PriorityDispatchResult.CONTINUE


def setup(plugin: Plugin) -> None:
plugin.add_autogenerate_comparator(
_compare_check_constraints,
"table",
"checkconstraints",
)
25 changes: 25 additions & 0 deletions alembic/autogenerate/compare/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
if TYPE_CHECKING:
from sqlalchemy import Table
from sqlalchemy.engine import Inspector
from sqlalchemy.engine.interfaces import ReflectedCheckConstraint
from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
from sqlalchemy.engine.interfaces import ReflectedIndex
from sqlalchemy.engine.interfaces import ReflectedUniqueConstraint
Expand Down Expand Up @@ -78,6 +79,11 @@ def get_foreign_keys(
) -> list[ReflectedForeignKeyConstraint]:
raise NotImplementedError()

def get_check_constraints(
self, tname: str, schema: str | None
) -> list[ReflectedCheckConstraint]:
raise NotImplementedError()

def reflect_table(self, table: Table) -> None:
raise NotImplementedError()

Expand Down Expand Up @@ -123,6 +129,13 @@ def get_foreign_keys(
self.inspector.get_foreign_keys(tname, schema=schema)
)

def get_check_constraints(
self, tname: str, schema: str | None
) -> list[ReflectedCheckConstraint]:
return self._apply_reflectinfo_conv(
self.inspector.get_check_constraints(tname, schema=schema)
)

def reflect_table(self, table: Table) -> None:
self.inspector.reflect_table(table, include_columns=None)

Expand Down Expand Up @@ -252,6 +265,18 @@ def get_foreign_keys(
apply_constraint_conv=True,
)

def get_check_constraints(
self, tname: str, schema: str | None
) -> list[ReflectedCheckConstraint]:
return self._return_from_cache(
tname,
schema,
"alembic_check_constraints",
self.inspector.get_check_constraints,
apply_constraint_conv=True,
optional=False,
)

def _apply_reflectinfo_conv(self, consts):
if not consts:
return consts
Expand Down
20 changes: 18 additions & 2 deletions alembic/autogenerate/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,24 @@ def _add_pk_constraint(constraint, autogen_context):


@renderers.dispatch_for(ops.CreateCheckConstraintOp)
def _add_check_constraint(constraint, autogen_context):
raise NotImplementedError()
def _add_check_constraint(
autogen_context: AutogenContext, op: ops.CreateCheckConstraintOp
) -> str:
constraint = op.to_constraint()
args = [repr(_render_gen_name(autogen_context, op.constraint_name))]
if not autogen_context._has_batch:
args.append(repr(_ident(op.table_name)))
args.append(
_render_potential_expr(
constraint.sqltext, autogen_context, wrap_in_element=False
)
)
if not autogen_context._has_batch and op.schema:
args.append("schema=%r" % _ident(op.schema))
return "%(prefix)screate_check_constraint(%(args)s)" % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}


@renderers.dispatch_for(ops.DropConstraintOp)
Expand Down
35 changes: 35 additions & 0 deletions alembic/ddl/_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import TypeVar
from typing import Union

from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
Expand Down Expand Up @@ -86,6 +87,7 @@ class _constraint_sig(Generic[_C]):
_is_index: ClassVar[bool] = False
_is_fk: ClassVar[bool] = False
_is_uq: ClassVar[bool] = False
_is_ck: ClassVar[bool] = False

_is_metadata: bool

Expand Down Expand Up @@ -325,5 +327,38 @@ def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:
return sig._is_uq


class _ck_constraint_sig(_constraint_sig[CheckConstraint]):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Federico Caselli (CaselIT) wrote:

nit: let's move this before the other typeguards

View this in Gerrit at https://gerrit.sqlalchemy.org/c/sqlalchemy/alembic/+/6672

_is_ck = True

@classmethod
def _register(cls) -> None:
_clsreg["check_constraint"] = cls
_clsreg["table_or_column_check_constraint"] = cls
_clsreg["column_check_constraint"] = cls

def __init__(
self,
is_metadata: bool,
impl: DefaultImpl,
const: CheckConstraint,
) -> None:
self._is_metadata = is_metadata
self.impl = impl
self.const = const
self.name = sqla_compat.constraint_name_or_none(const.name)
self._sig = (self.name,)

def _compare_to_reflected(
self, other: _constraint_sig[_C]
) -> ComparisonResult:
assert self._is_metadata
assert is_ck_sig(other)
return self.impl.compare_check_constraint(self.const, other.const)


def is_ck_sig(sig: _constraint_sig) -> TypeGuard[_ck_constraint_sig]:
return sig._is_ck


def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
return sig._is_fk
12 changes: 11 additions & 1 deletion alembic/ddl/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.interfaces import ReflectedCheckConstraint
from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
from sqlalchemy.engine.interfaces import ReflectedIndex
from sqlalchemy.engine.interfaces import ReflectedPrimaryKeyConstraint
Expand All @@ -51,6 +52,7 @@
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql import Executable
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
Expand All @@ -64,7 +66,8 @@
from ..operations.batch import BatchOperationsImpl

_ReflectedConstraint = (
ReflectedForeignKeyConstraint
ReflectedCheckConstraint
| ReflectedForeignKeyConstraint
| ReflectedPrimaryKeyConstraint
| ReflectedIndex
| ReflectedUniqueConstraint
Expand Down Expand Up @@ -840,6 +843,13 @@ def compare_unique_constraint(
else:
return ComparisonResult.Equal()

def compare_check_constraint(
self,
metadata_constraint: CheckConstraint,
reflected_constraint: CheckConstraint,
) -> ComparisonResult:
return ComparisonResult.Equal()

def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
conn_indexes_by_name = {c.name: c for c in conn_indexes}

Expand Down
1 change: 1 addition & 0 deletions alembic/runtime/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"index",
"unique_constraint",
"foreign_key_constraint",
"check_constraint",
]
NameFilterParentNames = MutableMapping[
Literal["schema_name", "table_name", "schema_qualified_table_name"],
Expand Down
Loading
Loading