diff --git a/flask_admin/_types.py b/flask_admin/_types.py index 2cfbc44e6..6fa82a1bd 100644 --- a/flask_admin/_types.py +++ b/flask_admin/_types.py @@ -264,6 +264,10 @@ class T_FIELD_ARGS_VALIDATORS_ALLOW_BLANK(T_FIELD_ARGS_VALIDATORS): allow_blank: NotRequired[bool] +class T_FIELD_ARGS_VALIDATORS_SELECTABLE(T_FIELD_ARGS_VALIDATORS_ALLOW_BLANK): + coerce: NotRequired[t.Callable[[t.Any], t.Any]] + + class T_FIELD_ARGS_VALIDATORS_FILES(T_FIELD_ARGS_VALIDATORS): base_path: NotRequired[str] allow_overwrite: NotRequired[bool] diff --git a/flask_admin/contrib/peewee/view.py b/flask_admin/contrib/peewee/view.py index d7151f55d..8531dd837 100644 --- a/flask_admin/contrib/peewee/view.py +++ b/flask_admin/contrib/peewee/view.py @@ -26,6 +26,7 @@ from flask_admin.model.form import InlineFormAdmin from ..._types import T_FIELD_ARGS_VALIDATORS_FILES +from ..._types import T_FIELD_ARGS_VALIDATORS_SELECTABLE from ..._types import T_FILTER from ..._types import T_PEEWEE_MODEL from ..._types import T_WIDGET @@ -338,7 +339,10 @@ def scaffold_form(self) -> type[Form]: def scaffold_list_form( self, widget: type[T_WIDGET] | None = None, - validators: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None, + validators: dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None = None, ) -> type[Form]: """ Create form for the `index_view` using only the columns from diff --git a/flask_admin/contrib/sqla/form.py b/flask_admin/contrib/sqla/form.py index a22d021a0..850c817fd 100644 --- a/flask_admin/contrib/sqla/form.py +++ b/flask_admin/contrib/sqla/form.py @@ -9,6 +9,8 @@ from sqlalchemy import Boolean from sqlalchemy import Column from sqlalchemy.orm import ColumnProperty +from sqlalchemy.sql.type_api import TypeEngine +from sqlalchemy_utils import ChoiceType from wtforms import fields from wtforms import Form from wtforms import HiddenField @@ -36,6 +38,7 @@ from ..._types import T_FIELD_ARGS_VALIDATORS from ..._types import T_FIELD_ARGS_VALIDATORS_ALLOW_BLANK from ..._types import T_FIELD_ARGS_VALIDATORS_FILES +from ..._types import T_FIELD_ARGS_VALIDATORS_SELECTABLE from ..._types import T_INSTRUMENTED_ATTRIBUTE from ..._types import T_MODEL_VIEW from ..._types import T_ORM_MODEL @@ -230,7 +233,7 @@ def convert( if isinstance(prop, FieldPlaceholder): return form.recreate_field(prop.field) - kwargs: T_FIELD_ARGS_VALIDATORS_ALLOW_BLANK = {"validators": [], "filters": []} + kwargs: T_FIELD_ARGS_VALIDATORS_SELECTABLE = {"validators": [], "filters": []} if field_args: kwargs.update(field_args) # type: ignore[typeddict-item] @@ -360,7 +363,11 @@ def convert( form_choices = getattr(self.view, "form_choices", None) if mapper.class_ == self.view.model and form_choices: choices = form_choices.get(prop.key) + if choices: + if "coerce" not in kwargs: + kwargs["coerce"] = coerce_factory(column.type) + return form.Select2Field( # type: ignore[misc] choices=choices, allow_blank=column.nullable, # type: ignore[arg-type] @@ -655,6 +662,18 @@ def avoid_empty_strings(value: T) -> T | None: return value if value else None +def coerce_factory(type_: TypeEngine[t.Any]) -> t.Callable[[t.Any], t.Any]: + """ + Return a function to coerce a column, for use by Select2Field. + :param type_: Column type + """ + + if isinstance(type_, ChoiceType): + return choice_type_coerce_factory(type_) + else: + return type_.python_type + + def choice_type_coerce_factory(type_: T_CHOICE_TYPE) -> t.Callable[[t.Any], t.Any]: """ Return a function to coerce a ChoiceType column, for use by Select2Field. @@ -671,8 +690,10 @@ def choice_type_coerce_factory(type_: T_CHOICE_TYPE) -> t.Callable[[t.Any], t.An def choice_coerce(value: t.Any) -> t.Any: if value is None: return None + if isinstance(value, choice_cls): return getattr(value, key) + return type_.python_type(value) return choice_coerce @@ -699,7 +720,10 @@ def get_form( base_class: type[form.BaseForm] = form.BaseForm, only: t.Collection[str | T_INSTRUMENTED_ATTRIBUTE] | None = None, exclude: t.Collection[str | T_INSTRUMENTED_ATTRIBUTE] | None = None, - field_args: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None, + field_args: dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None = None, hidden_pk: bool = False, ignore_hidden: bool = True, extra_fields: dict[str | T_INSTRUMENTED_ATTRIBUTE, UnboundField[t.Any]] diff --git a/flask_admin/contrib/sqla/view.py b/flask_admin/contrib/sqla/view.py index 8126c529d..6b2f402bd 100644 --- a/flask_admin/contrib/sqla/view.py +++ b/flask_admin/contrib/sqla/view.py @@ -39,6 +39,7 @@ from ..._types import T_COLUMN from ..._types import T_COLUMN_LIST from ..._types import T_FIELD_ARGS_VALIDATORS_FILES +from ..._types import T_FIELD_ARGS_VALIDATORS_SELECTABLE from ..._types import T_FILTER from ..._types import T_INSTRUMENTED_ATTRIBUTE from ..._types import T_SQLALCHEMY_COLUMN @@ -892,7 +893,10 @@ def scaffold_form(self) -> type[Form]: def scaffold_list_form( self, widget: type[T_WIDGET] | None = None, - validators: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None, + validators: dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None = None, ) -> type[Form]: """ Create form for the `index_view` using only the columns from diff --git a/flask_admin/form/fields.py b/flask_admin/form/fields.py index e7818a80d..b76b74f99 100644 --- a/flask_admin/form/fields.py +++ b/flask_admin/form/fields.py @@ -4,6 +4,7 @@ import time import typing as t from enum import Enum +from inspect import isclass from wtforms import fields from wtforms.form import BaseForm @@ -196,6 +197,9 @@ def process_formdata(self, valuelist: t.Sequence[str] | None) -> None: self.data = None else: try: + if isclass(self.coerce) and issubclass(self.coerce, Enum): + self.coerce = self._enum_coerce_factory(self.coerce) + self.data = self.coerce(valuelist[0]) except ValueError as err: raise ValueError( @@ -208,6 +212,28 @@ def pre_validate(self, form: BaseForm) -> None: super().pre_validate(form) + def _enum_coerce_factory(self, type_: type[Enum]) -> t.Callable[[t.Any], t.Any]: + """ + Return a function to coerce an Enum column, for use by Select2Field. + :param type_: Enum class + """ + + def enum_coerce(value: t.Any) -> t.Any: + if value is None: + return None + + if isinstance(value, type_): + return value + + ename = getattr(value, "name", value) + ename = str(value).replace(type_.__name__ + ".", "") + try: + return type_[ename] + except KeyError: + return type_(value) + + return enum_coerce + class Select2TagsField(fields.StringField): """`Select2Tags `_ styled text field. diff --git a/flask_admin/model/base.py b/flask_admin/model/base.py index f3c0a4117..4546df8ed 100644 --- a/flask_admin/model/base.py +++ b/flask_admin/model/base.py @@ -28,6 +28,7 @@ from .._types import T_COLUMN_LIST from .._types import T_COLUMN_TYPE_FORMATTERS from .._types import T_FIELD_ARGS_VALIDATORS_FILES +from .._types import T_FIELD_ARGS_VALIDATORS_SELECTABLE from .._types import T_FILTER from .._types import T_INSTRUMENTED_ATTRIBUTE from .._types import T_ORM_MODEL @@ -654,7 +655,10 @@ class MyModelView(BaseModelView): """ - form_args: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None + form_args: ( + dict[str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE] + | None + ) = None """ Dictionary of form field arguments. Refer to WTForms documentation for list of possible options. @@ -1389,7 +1393,10 @@ def scaffold_form(self) -> type[Form]: def scaffold_list_form( self, widget: type[T_WIDGET] | None = None, - validators: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None, + validators: dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None = None, ) -> type[Form]: """ Create form for the `index_view` using only the columns from @@ -1442,7 +1449,12 @@ class MyModelView(BaseModelView): def get_list_form(self): return self.scaffold_list_form(widget=CustomWidget) """ - validators: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None + validators: ( + dict[ + str, T_FIELD_ARGS_VALIDATORS_FILES | T_FIELD_ARGS_VALIDATORS_SELECTABLE + ] + | None + ) = None if self.form_args: # get only validators, other form_args can break FieldList wrapper validators = dict( diff --git a/flask_admin/tests/sqla/test_form.py b/flask_admin/tests/sqla/test_form.py index 48946ee8f..bf363b691 100644 --- a/flask_admin/tests/sqla/test_form.py +++ b/flask_admin/tests/sqla/test_form.py @@ -1,11 +1,19 @@ +import enum import inspect +import typing as t from unittest.mock import MagicMock import pytest +import sqlalchemy as sa +import sqlalchemy_utils as sa_utils import wtforms from wtforms.fields.simple import StringField +from wtforms.validators import Length +from wtforms.validators import NumberRange from flask_admin.contrib.sqla.form import AdminModelConverter +from flask_admin.tests.conftest import skip_or_return_session_or_db +from flask_admin.tests.sqla.test_basic import CustomModelView sqla_admin_model_converters = [ method_name @@ -52,3 +60,178 @@ class TestForm(wtforms.Form): pass assert field() == "

widget overridden

" + + +class EnumChoices(enum.Enum): + first = 101 + second = 150 + + +class StrEnumChoices(enum.Enum): + first = "101" + second = "150" + + +def create_models(sqla_db_ext): + class Model1(sqla_db_ext.Base): # type: ignore[name-defined, misc] + __tablename__ = "model1" + + def __init__(self, test1, int_field): + self.test1 = test1 + self.int_field = int_field + + id = sa.Column(sa.Integer, primary_key=True) + test1 = sa.Column(sa.String(20)) + int_field = sa.Column(sa.Integer) + float_field = sa.Column(sa.Float) + choice_field = sa.Column(sa.String, nullable=True) + enum_field = sa.Column(sa.Enum("101", "150"), nullable=True) # type: ignore[var-annotated] + enum_type_field = sa.Column(sa.Enum(EnumChoices), nullable=True) # type: ignore[var-annotated] + sa_utils_choicetype = sa.Column( + sa_utils.ChoiceType( + [("101", "First"), ("150", "Second")] + ) # default impl=sa.String() + ) + sa_utils_choicetype_impl_int = sa.Column( + sa_utils.ChoiceType([(101, "First"), (150, "Second")], impl=sa.Integer()) + ) + sa_utils_choicetype_with_enum = sa.Column( + sa_utils.ChoiceType(EnumChoices, impl=sa.Integer()) + ) + sa_utils_choicetype_with_strenum = sa.Column( + sa_utils.ChoiceType(StrEnumChoices, impl=sa.String()) + ) + + def __str__(self): + return self.test1 + + sqla_db_ext.create_all() + + return Model1 + + +@pytest.mark.parametrize("use_coerce_explicitly", [False, True]) +@pytest.mark.parametrize( + "field_name, expected_coerce", + [ + ("int_field", int), + ("float_field", float), + ("choice_field", str), + ("enum_field", str), + ("sa_utils_choicetype", str), + ("sa_utils_choicetype_impl_int", int), + ], +) +def test_coerce( + app, + admin, + sqla_db_ext, + session_or_db, + field_name, + expected_coerce, + use_coerce_explicitly, +): + with app.app_context(): + Model1 = create_models(sqla_db_ext) + sqla_db_ext.db.session.add_all( + [ + Model1("101", int_field=101), + Model1("102", int_field=102), + ] + ) + sqla_db_ext.db.session.commit() + + validators: list[t.Any] = [] + if expected_coerce in [int, float]: + validators = [NumberRange(min=100, max=199)] + elif expected_coerce in [str]: + validators = [Length(min=1, max=3)] + + f_choices = [(expected_coerce(101), "First"), (expected_coerce(150), "Second")] + + kwargs = { + "form_columns": [field_name], + "form_choices": {field_name: f_choices}, + } + + if use_coerce_explicitly: + kwargs["form_args"] = dict() + kwargs["form_args"][field_name] = {"validators": validators} # type: ignore[index] + kwargs["form_args"][field_name]["coerce"] = expected_coerce # type: ignore[index] + + param = skip_or_return_session_or_db(sqla_db_ext, session_or_db) + + view1 = CustomModelView(Model1, param, name="My Model1", **kwargs) + admin.add_view(view1) + + client = app.test_client() + rv = client.get("/admin/model1/new/") + data = rv.data.decode("utf-8") + assert f'value="{expected_coerce(101)}"' in data + assert ">First" in data + + rv = client.post( + "/admin/model1/new/", + data={field_name: f"{expected_coerce(150)}"}, + follow_redirects=True, + ) + data = rv.data.decode("utf-8") + assert "Record was successfully created" in data + + +@pytest.mark.parametrize("use_coerce_explicitly", [False, True]) +@pytest.mark.parametrize( + "field_name, expected_coerce, value", + [ + ("enum_type_field", EnumChoices, "first"), + ("sa_utils_choicetype_with_enum", EnumChoices, "101"), + ("sa_utils_choicetype_with_strenum", StrEnumChoices, "101"), + ], +) +def test_enum_coerce( + app, + admin, + sqla_db_ext, + session_or_db, + field_name, + value, + expected_coerce, + use_coerce_explicitly, +): + with app.app_context(): + Model1 = create_models(sqla_db_ext) + sqla_db_ext.db.session.add_all( + [ + Model1("101", int_field=101), + Model1("102", int_field=102), + ] + ) + sqla_db_ext.db.session.commit() + + kwargs = { + "form_columns": [field_name], + } + + if use_coerce_explicitly: + kwargs["form_args"] = dict() # type: ignore[assignment] + kwargs["form_args"][field_name] = dict() + kwargs["form_args"][field_name]["coerce"] = expected_coerce + + param = skip_or_return_session_or_db(sqla_db_ext, session_or_db) + + view1 = CustomModelView(Model1, param, name="My Model1", **kwargs) + admin.add_view(view1) + + client = app.test_client() + rv = client.get("/admin/model1/new/") + data = rv.data.decode("utf-8") + assert f'value="{value}"' in data + assert ">first" in data + + rv = client.post( + "/admin/model1/new/", + data={field_name: value}, + follow_redirects=True, + ) + data = rv.data.decode("utf-8") + assert "Record was successfully created" in data