diff --git a/voluptuous/schema_builder.py b/voluptuous/schema_builder.py index 895c6e9..b64eed8 100644 --- a/voluptuous/schema_builder.py +++ b/voluptuous/schema_builder.py @@ -2,6 +2,7 @@ from __future__ import annotations import collections +import enum import inspect import itertools import re @@ -226,6 +227,8 @@ def _compile(self, schema): return self._compile_tuple(schema) elif isinstance(schema, (frozenset, set)): return self._compile_set(schema) + elif isinstance(schema, enum.Enum): + return _compile_scalar(schema) type_ = type(schema) if inspect.isclass(schema): type_ = schema diff --git a/voluptuous/tests/tests.py b/voluptuous/tests/tests.py index a3dd219..befe4fb 100644 --- a/voluptuous/tests/tests.py +++ b/voluptuous/tests/tests.py @@ -1869,6 +1869,29 @@ class StringChoice(str, Enum): string_schema("hello") +def test_enum_as_schema(): + """Test enum members can be used as scalar schemas and mapping keys.""" + + class Choice(Enum): + Easy = 1 + Hard = 3 + + class StringChoice(str, Enum): + Easy = 'easy' + Hard = 'hard' + + # As a scalar schema, an enum member is matched by equality. + schema = Schema(Choice.Easy) + assert schema(Choice.Easy) == Choice.Easy + with raises(Invalid, 'not a valid value'): + schema(Choice.Hard) + + # As a mapping key, including default handling for missing keys. + dict_schema = Schema({Optional(StringChoice.Easy, default=True): bool}) + assert dict_schema({}) == {StringChoice.Easy: True} + assert dict_schema({'easy': False}) == {'easy': False} + + class MyValueClass(object): def __init__(self, value=None): self.value = value