From d63e0921c1c1eafe4159f53c4d89990b2b3c6a6f Mon Sep 17 00:00:00 2001 From: Sami Alfattany Date: Thu, 5 Feb 2026 19:04:33 +0300 Subject: [PATCH] Add default coerce to any Select2Field include all field types and cover all default end explicit coerce --- flask_admin/_types.py | 4 + flask_admin/contrib/mongoengine/form.py | 10 +- flask_admin/contrib/mongoengine/view.py | 6 +- flask_admin/contrib/peewee/view.py | 6 +- flask_admin/contrib/sqla/form.py | 28 ++++- flask_admin/contrib/sqla/view.py | 6 +- flask_admin/form/fields.py | 26 ++++ flask_admin/model/base.py | 18 ++- flask_admin/tests/sqla/test_form.py | 153 ++++++++++++++++++++++++ 9 files changed, 247 insertions(+), 10 deletions(-) diff --git a/flask_admin/_types.py b/flask_admin/_types.py index e3cdb3a11..166107e33 100644 --- a/flask_admin/_types.py +++ b/flask_admin/_types.py @@ -260,6 +260,10 @@ class T_FIELD_ARGS_VALIDATORS_ALLOW_BLANK(T_FIELD_ARGS_VALIDATORS): allow_blank: NotRequired[bool] +class T_FIELD_ARGS_VALIDATORS_SELECTABLE(T_FIELD_ARGS_VALIDATORS_ALLOW_BLANK): + coerce: NotRequired[t.Callable[[t.Any], t.Any]] + + class T_FIELD_ARGS_VALIDATORS_FILES(T_FIELD_ARGS_VALIDATORS): base_path: NotRequired[str] allow_overwrite: NotRequired[bool] diff --git a/flask_admin/contrib/mongoengine/form.py b/flask_admin/contrib/mongoengine/form.py index 9f4dbf914..761454aae 100644 --- a/flask_admin/contrib/mongoengine/form.py +++ b/flask_admin/contrib/mongoengine/form.py @@ -36,6 +36,7 @@ from flask_admin.model.form import FieldPlaceholder from ..._types import T_FIELD_ARGS_VALIDATORS_FILES +from ..._types import T_FIELD_ARGS_VALIDATORS_SELECTABLE from ..._types import T_ITER_CHOICES from ..._types import T_MONGO_ENGINE_DOCUMENT from ..._types import T_VALIDATOR @@ -289,7 +290,9 @@ def convert( self, model: type[T_MONGO_ENGINE_DOCUMENT], field: BaseField, - field_args: T_FIELD_ARGS_VALIDATORS_FILES | None, + field_args: T_FIELD_ARGS_VALIDATORS_FILES + | T_FIELD_ARGS_VALIDATORS_SELECTABLE + | None = None, ) -> t.Any: # Check if it is overridden field if isinstance(field, FieldPlaceholder): @@ -624,7 +627,10 @@ def get_form( base_class: type[form.BaseForm] = form.BaseForm, only: t.Iterable[str] | None = None, exclude: t.Iterable[str] | None = None, - field_args: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None, + field_args: dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None = None, extra_fields: dict[str, "UnboundField[F]"] | None = None, ) -> type[form.BaseForm]: """ diff --git a/flask_admin/contrib/mongoengine/view.py b/flask_admin/contrib/mongoengine/view.py index f91170667..690cd0c44 100644 --- a/flask_admin/contrib/mongoengine/view.py +++ b/flask_admin/contrib/mongoengine/view.py @@ -29,6 +29,7 @@ from ..._types import T_AJAX_MODEL_LOADER from ..._types import T_COLUMN_TYPE_FORMATTERS from ..._types import T_FIELD_ARGS_VALIDATORS_FILES +from ..._types import T_FIELD_ARGS_VALIDATORS_SELECTABLE from ..._types import T_MONGO_ENGINE_DOCUMENT from ..._types import T_WIDGET from .ajax import create_ajax_loader @@ -481,7 +482,10 @@ def scaffold_form(self) -> type: def scaffold_list_form( self, widget: type[T_WIDGET] | None = None, - validators: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None, + validators: dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None = None, ) -> type[BaseListForm]: """ Create form for the `index_view` using only the columns from diff --git a/flask_admin/contrib/peewee/view.py b/flask_admin/contrib/peewee/view.py index 629d6365d..4ee2c5e69 100644 --- a/flask_admin/contrib/peewee/view.py +++ b/flask_admin/contrib/peewee/view.py @@ -27,6 +27,7 @@ from flask_admin.model.form import InlineFormAdmin from ..._types import T_FIELD_ARGS_VALIDATORS_FILES +from ..._types import T_FIELD_ARGS_VALIDATORS_SELECTABLE from ..._types import T_FILTER from ..._types import T_PEEWEE_MODEL from ..._types import T_WIDGET @@ -339,7 +340,10 @@ def scaffold_form(self) -> type[Form]: def scaffold_list_form( self, widget: type[T_WIDGET] | None = None, - validators: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None, + validators: dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None = None, ) -> type[Form]: """ Create form for the `index_view` using only the columns from diff --git a/flask_admin/contrib/sqla/form.py b/flask_admin/contrib/sqla/form.py index 50f2ba044..b3e8300b1 100644 --- a/flask_admin/contrib/sqla/form.py +++ b/flask_admin/contrib/sqla/form.py @@ -9,6 +9,8 @@ from sqlalchemy import Boolean from sqlalchemy import Column from sqlalchemy.orm import ColumnProperty +from sqlalchemy.sql.type_api import TypeEngine +from sqlalchemy_utils import ChoiceType from wtforms import fields from wtforms import Form from wtforms import HiddenField @@ -37,6 +39,7 @@ from ..._types import T_FIELD_ARGS_VALIDATORS_ALLOW_BLANK from ..._types import T_FIELD_ARGS_VALIDATORS_COERCE from ..._types import T_FIELD_ARGS_VALIDATORS_FILES +from ..._types import T_FIELD_ARGS_VALIDATORS_SELECTABLE from ..._types import T_INSTRUMENTED_ATTRIBUTE from ..._types import T_MODEL_VIEW from ..._types import T_ORM_MODEL @@ -231,7 +234,7 @@ def convert( if isinstance(prop, FieldPlaceholder): return form.recreate_field(prop.field) - kwargs: T_FIELD_ARGS_VALIDATORS_ALLOW_BLANK = {"validators": [], "filters": []} + kwargs: T_FIELD_ARGS_VALIDATORS_SELECTABLE = {"validators": [], "filters": []} if field_args: kwargs.update(field_args) # type: ignore[typeddict-item] @@ -361,7 +364,11 @@ def convert( form_choices = getattr(self.view, "form_choices", None) if mapper.class_ == self.view.model and form_choices: choices = form_choices.get(prop.key) + if choices: + if "coerce" not in kwargs: + kwargs["coerce"] = coerce_factory(column.type) + return form.Select2Field( # type: ignore[misc] choices=choices, allow_blank=column.nullable, # type: ignore[arg-type] @@ -667,6 +674,18 @@ def avoid_empty_strings(value: T) -> T | None: return value if value else None +def coerce_factory(type_: TypeEngine[t.Any]) -> t.Callable[[t.Any], t.Any]: + """ + Return a function to coerce a column, for use by Select2Field. + :param type_: Column type + """ + + if isinstance(type_, ChoiceType): + return choice_type_coerce_factory(type_) + else: + return type_.python_type + + def choice_type_coerce_factory(type_: T_CHOICE_TYPE) -> t.Callable[[t.Any], t.Any]: """ Return a function to coerce a ChoiceType column, for use by Select2Field. @@ -683,8 +702,10 @@ def choice_type_coerce_factory(type_: T_CHOICE_TYPE) -> t.Callable[[t.Any], t.An def choice_coerce(value: t.Any) -> t.Any: if value is None: return None + if isinstance(value, choice_cls): return getattr(value, key) + return type_.python_type(value) return choice_coerce @@ -711,7 +732,10 @@ def get_form( base_class: type[form.BaseForm] = form.BaseForm, only: t.Collection[str | T_INSTRUMENTED_ATTRIBUTE] | None = None, exclude: t.Collection[str | T_INSTRUMENTED_ATTRIBUTE] | None = None, - field_args: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None, + field_args: dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None = None, hidden_pk: bool = False, ignore_hidden: bool = True, extra_fields: dict[str | T_INSTRUMENTED_ATTRIBUTE, UnboundField[t.Any]] diff --git a/flask_admin/contrib/sqla/view.py b/flask_admin/contrib/sqla/view.py index 4fb25a6ef..7ed89c2bb 100644 --- a/flask_admin/contrib/sqla/view.py +++ b/flask_admin/contrib/sqla/view.py @@ -40,6 +40,7 @@ from ..._types import T_COLUMN_LIST from ..._types import T_COLUMN_TYPE_FORMATTERS from ..._types import T_FIELD_ARGS_VALIDATORS_FILES +from ..._types import T_FIELD_ARGS_VALIDATORS_SELECTABLE from ..._types import T_FILTER from ..._types import T_INSTRUMENTED_ATTRIBUTE from ..._types import T_SQLALCHEMY_COLUMN @@ -899,7 +900,10 @@ def scaffold_form(self) -> type[Form]: def scaffold_list_form( self, widget: type[T_WIDGET] | None = None, - validators: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None, + validators: dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None = None, ) -> type[Form]: """ Create form for the `index_view` using only the columns from diff --git a/flask_admin/form/fields.py b/flask_admin/form/fields.py index 0a96ab961..fa7d51e62 100644 --- a/flask_admin/form/fields.py +++ b/flask_admin/form/fields.py @@ -4,6 +4,7 @@ import time import typing as t from enum import Enum +from inspect import isclass from wtforms import fields from wtforms.form import BaseForm @@ -196,6 +197,9 @@ def process_formdata(self, valuelist: t.Sequence[str] | None) -> None: self.data = None else: try: + if isclass(self.coerce) and issubclass(self.coerce, Enum): + self.coerce = self._enum_coerce_factory(self.coerce) + self.data = self.coerce(valuelist[0]) except ValueError as err: raise ValueError( @@ -208,6 +212,28 @@ def pre_validate(self, form: BaseForm) -> None: super().pre_validate(form) + def _enum_coerce_factory(self, type_: type[Enum]) -> t.Callable[[t.Any], t.Any]: + """ + Return a function to coerce an Enum column, for use by Select2Field. + :param type_: Enum class + """ + + def enum_coerce(value: t.Any) -> t.Any: + if value is None: + return None + + if isinstance(value, type_): + return value + + ename = getattr(value, "name", value) + ename = str(value).replace(type_.__name__ + ".", "") + try: + return type_[ename] + except KeyError: + return type_(value) + + return enum_coerce + class Select2TagsField(fields.StringField): """`Select2Tags `_ styled text field. diff --git a/flask_admin/model/base.py b/flask_admin/model/base.py index e3ea7033d..4a41a2ecc 100644 --- a/flask_admin/model/base.py +++ b/flask_admin/model/base.py @@ -28,6 +28,7 @@ from .._types import T_COLUMN_LIST from .._types import T_COLUMN_TYPE_FORMATTERS from .._types import T_FIELD_ARGS_VALIDATORS_FILES +from .._types import T_FIELD_ARGS_VALIDATORS_SELECTABLE from .._types import T_FILTER from .._types import T_INSTRUMENTED_ATTRIBUTE from .._types import T_ORM_MODEL @@ -668,7 +669,10 @@ class MyModelView(BaseModelView): """ - form_args: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None + form_args: ( + dict[str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE] + | None + ) = None """ Dictionary of form field arguments. Refer to WTForms documentation for list of possible options. @@ -1403,7 +1407,10 @@ def scaffold_form(self) -> type[Form]: def scaffold_list_form( self, widget: type[T_WIDGET] | None = None, - validators: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None, + validators: dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None = None, ) -> type[Form]: """ Create form for the `index_view` using only the columns from @@ -1456,7 +1463,12 @@ class MyModelView(BaseModelView): def get_list_form(self): return self.scaffold_list_form(widget=CustomWidget) """ - validators: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None + validators: ( + dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None + ) = None if self.form_args: # get only validators, other form_args can break FieldList wrapper validators = dict( diff --git a/flask_admin/tests/sqla/test_form.py b/flask_admin/tests/sqla/test_form.py index c97187b40..2ec0624f5 100644 --- a/flask_admin/tests/sqla/test_form.py +++ b/flask_admin/tests/sqla/test_form.py @@ -1,17 +1,28 @@ +import enum import inspect import typing as t from unittest.mock import MagicMock import pytest +import sqlalchemy as sa +import sqlalchemy_utils as sau import wtforms +from flask import Flask from sqlalchemy import ARRAY from sqlalchemy import Column from sqlalchemy import Float from sqlalchemy import Integer from sqlalchemy import String from wtforms.fields.simple import StringField +from wtforms.validators import Length +from wtforms.validators import NumberRange +from flask_admin.base import Admin from flask_admin.contrib.sqla.form import AdminModelConverter +from flask_admin.tests.conftest import skip_or_return_session_or_db +from flask_admin.tests.conftest import T_ANY_SQLA_PROVIDER +from flask_admin.tests.conftest import T_LITERAL_SESSION_OR_DB +from flask_admin.tests.sqla.test_basic import CustomModelView sqla_admin_model_converters = [ method_name @@ -142,3 +153,145 @@ def test_conv_ARRAY_missing_item_type_falls_back_to_text(self) -> None: bound.process_formdata(["x,y"]) assert bound.data == ["x", "y"] + + +class EnumChoices(enum.Enum): + First = 101 + Second = 150 + + +class StrEnumChoices(enum.Enum): + First = "101" + Second = "150" + + +def create_models(sqla_db_ext: T_ANY_SQLA_PROVIDER) -> t.Any: + class Model1(sqla_db_ext.Base): # type: ignore[name-defined, misc] + __tablename__ = "model1" + + id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) + test1 = sa.Column(sa.String(20)) + int_field = sa.Column(sa.Integer) + float_field = sa.Column(sa.Float) + choice_field = sa.Column(sa.String, nullable=True) + enum_field = sa.Column(sa.Enum("101", "150"), nullable=True) # type: ignore[var-annotated] + enum_type_field = sa.Column(sa.Enum(EnumChoices), nullable=True) # type: ignore[var-annotated] + sau_choicetype = sa.Column( + sau.ChoiceType( + [("101", "First"), ("150", "Second")] + ) # default impl=sa.String() + ) + sau_choicetype_impl_int = sa.Column( + sau.ChoiceType([(101, "First"), (150, "Second")], impl=sa.Integer()) + ) + sau_choicetype_with_enum = sa.Column( + sau.ChoiceType(EnumChoices, impl=sa.Integer()) + ) + sau_choicetype_with_strenum = sa.Column( + sau.ChoiceType(StrEnumChoices, impl=sa.String()) + ) + + def __str__(self) -> str: + return str(self.test1.value) if self.test1 else "" + + sqla_db_ext.create_all() + + return Model1 + + +def prepare_kwargs( + expected_coerce: type[t.Any], + use_coerce_explicitly: bool, + field_name: str, +) -> dict[str, t.Any]: + validators: list[t.Any] = [] + kwargs: dict[str, t.Any] = dict() + + if expected_coerce in [int, float]: + f_choices = [(expected_coerce(101), "First"), (expected_coerce(150), "Second")] + validators = [NumberRange(min=100, max=199)] + kwargs["form_choices"] = {field_name: f_choices} + + elif expected_coerce in [ + str, + ]: + f_choices = [(expected_coerce(101), "First"), (expected_coerce(150), "Second")] + validators = [Length(min=1, max=3)] + kwargs["form_choices"] = {field_name: f_choices} + + elif expected_coerce in [EnumChoices, StrEnumChoices]: + pass + + kwargs["form_columns"] = [field_name] + + if use_coerce_explicitly: + kwargs["form_args"] = dict() + kwargs["form_args"][field_name] = {"validators": validators} + kwargs["form_args"][field_name]["coerce"] = expected_coerce + + return kwargs + + +@pytest.mark.parametrize("use_coerce_explicitly", [False, True]) +@pytest.mark.parametrize( + "field_name, expected_coerce, coerced_value, model_value", + [ + ("int_field", int, None, 101), + ("float_field", float, None, 101.0), + ("choice_field", str, None, "101"), + ("enum_field", str, None, "101"), + ("enum_type_field", EnumChoices, "First", EnumChoices.First), + ("sau_choicetype", str, None, sau.Choice("101", "First")), + ("sau_choicetype_impl_int", int, None, sau.Choice(101, "First")), + ("sau_choicetype_with_enum", EnumChoices, "101", EnumChoices.First), + ("sau_choicetype_with_strenum", StrEnumChoices, "101", StrEnumChoices.First), + ], +) +def test_coerce( + app: Flask, + admin: Admin, + sqla_db_ext: T_ANY_SQLA_PROVIDER, + session_or_db: T_LITERAL_SESSION_OR_DB, + field_name: str, + expected_coerce: type[t.Any], + use_coerce_explicitly: bool, + coerced_value: t.Any, + model_value: t.Any, +) -> None: + with app.app_context(): + Model1 = create_models(sqla_db_ext) + sqla_db_ext.db.session.add_all( + [ + Model1(test1="101", int_field=101), + Model1(test1="102", int_field=102), + ] + ) + sqla_db_ext.db.session.commit() + + kwargs = prepare_kwargs(expected_coerce, use_coerce_explicitly, field_name) + + param = skip_or_return_session_or_db(sqla_db_ext, session_or_db) + + view1 = CustomModelView(Model1, param, name="My Model1", **kwargs) + admin.add_view(view1) + + value = expected_coerce(101) if coerced_value is None else coerced_value + + client = app.test_client() + + rv = client.get("/admin/model1/new/") + data = rv.data.decode("utf-8") + assert f'value="{value}"' in data + assert ">First" in data + + rv = client.post( + "/admin/model1/new/", + data={field_name: f"{value}"}, + follow_redirects=True, + ) + data = rv.data.decode("utf-8") + assert "Record was successfully created" in data + + inserted = sqla_db_ext.db.session.query(Model1).order_by(Model1.id.desc()).first() + assert inserted is not None + assert getattr(inserted, field_name) == model_value