Skip to content

Commit 5196d22

Browse files
authored
fix Enum validation (#16)
1 parent d905f80 commit 5196d22

5 files changed

Lines changed: 26 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ docs = [
7575
"sphinx-rtd-theme",
7676
]
7777
testing = [
78-
"packaging", # A test case uses packaging.version.Version
7978
"pydantic>=2",
8079
"pytest",
8180
"pytest-cov",

src/py_avro_schema/_schemas.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,9 @@
4646
get_type_hints,
4747
)
4848

49-
import avro.name
5049
import more_itertools
5150
import orjson
5251
import typeguard
53-
from avro.errors import InvalidName
5452

5553
import py_avro_schema._typing
5654
from py_avro_schema._alias import get_aliases, get_field_aliases_and_actual_type
@@ -73,6 +71,7 @@
7371
NamesType = List[str]
7472

7573
RUNTIME_TYPE_KEY = "_runtime_type"
74+
SYMBOL_REGEX = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
7675

7776

7877
class TypeNotSupportedError(TypeError):
@@ -92,6 +91,7 @@ class Option(enum.Flag):
9291
JSON_INDENT_2 = orjson.OPT_INDENT_2
9392

9493
#: Sort keys in JSON data
94+
9595
JSON_SORT_KEYS = orjson.OPT_SORT_KEYS
9696

9797
#: Append a newline character at the end of the JSON data
@@ -930,12 +930,14 @@ def __init__(self, py_type: Type[enum.Enum], namespace: Optional[str] = None, op
930930
raise TypeError(f"Avro enum schema members must be strings. {py_type} uses {symbol_types} values.")
931931

932932
def _is_valid_enum(self) -> bool:
933-
"""Checks if all the symbols of the enum are valid Avro names."""
934-
try:
935-
for _symbol in self.symbols:
936-
avro.name.validate_basename(_symbol)
937-
except InvalidName:
938-
return False
933+
"""
934+
Checks if all the symbols of the enum are valid Avro names.
935+
Based on `fastavro._schema_py._validate_enum_symbols`
936+
"""
937+
for _symbol in self.symbols:
938+
if not isinstance(_symbol, str) or not SYMBOL_REGEX.fullmatch(_symbol):
939+
return False
940+
939941
return True
940942

941943
def data(self, names: NamesType) -> JSONType:

src/py_avro_schema/_testing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""
1414
Test functions
1515
"""
16+
1617
import dataclasses
1718
import difflib
1819
from typing import Dict, Type, Union

src/py_avro_schema/_typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""
1414
Additional type hint classes etc
1515
"""
16+
1617
import dataclasses
1718
import decimal
1819
from typing import _GenericAlias # type: ignore

tests/test_primitives.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,11 @@ class PyType(str): ...
6565

6666

6767
def test_str_subclass_other_classes():
68-
import packaging.version
68+
class MyTestClass:
69+
def capitalize(self):
70+
return self
6971

70-
class PyType(packaging.version.Version, str): ...
72+
class PyType(MyTestClass, str): ...
7173

7274
expected = {
7375
"type": "string",
@@ -545,6 +547,16 @@ class OriginProtocolPolicy(str, enum.Enum):
545547
assert_schema(OriginProtocolPolicy, expected)
546548

547549

550+
def test_str_enum_invalid_name_with_dot():
551+
class StateReasonCode(enum.StrEnum):
552+
FunctionError_ExtensionInitError = "FunctionError.ExtensionInitError"
553+
FunctionError_InvalidEntryPoint = "FunctionError.InvalidEntryPoint"
554+
555+
expected = {"namedString": "StateReasonCode", "type": "string"}
556+
557+
assert_schema(StateReasonCode, expected)
558+
559+
548560
def test_duplicated_invalid_enum():
549561
class OriginProtocolPolicy(str, enum.Enum):
550562
http_only = "http-only"

0 commit comments

Comments
 (0)