Skip to content

Commit 81d40e3

Browse files
committed
feat: add some missing types
1 parent 61268ab commit 81d40e3

File tree

12 files changed

+140
-100
lines changed

12 files changed

+140
-100
lines changed

conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
@pytest.fixture(autouse=True)
8-
def add_doctest_namespace(doctest_namespace):
8+
def add_doctest_namespace(doctest_namespace: dict) -> dict:
99
doctest_namespace["pydantic"] = pydantic
1010
imports = {item: getattr(scim2_models, item) for item in scim2_models.__all__}
1111
doctest_namespace.update(imports)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ exclude_lines = [
6868
"pragma: no cover",
6969
"raise NotImplementedError",
7070
"except ImportError",
71+
"if TYPE_CHECKING:",
7172
]
7273

7374
[tool.ruff.lint]

scim2_models/base.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pydantic import BaseModel as PydanticBaseModel
1515
from pydantic import ConfigDict
1616
from pydantic import Field
17+
from pydantic import FieldSerializationInfo
1718
from pydantic import GetCoreSchemaHandler
1819
from pydantic import SerializationInfo
1920
from pydantic import SerializerFunctionWrapHandler
@@ -115,7 +116,9 @@ def validate_attribute_urn(
115116
return f"{schema}:{attribute_base}"
116117

117118

118-
def contains_attribute_or_subattributes(attribute_urns: list[str], attribute_urn: str):
119+
def contains_attribute_or_subattributes(
120+
attribute_urns: list[str], attribute_urn: str
121+
) -> bool:
119122
return attribute_urn in attribute_urns or any(
120123
item.startswith(f"{attribute_urn}.") or item.startswith(f"{attribute_urn}:")
121124
for item in attribute_urns
@@ -412,7 +415,7 @@ class Required(Enum):
412415

413416
_default = false
414417

415-
def __bool__(self):
418+
def __bool__(self) -> bool:
416419
return self.value
417420

418421

@@ -424,7 +427,7 @@ class CaseExact(Enum):
424427

425428
_default = false
426429

427-
def __bool__(self):
430+
def __bool__(self) -> bool:
428431
return self.value
429432

430433

@@ -449,7 +452,7 @@ def get_field_annotation(cls, field_name: str, annotation_type: type) -> Any:
449452

450453
default_value = getattr(annotation_type, "_default", None)
451454

452-
def annotation_type_filter(item):
455+
def annotation_type_filter(item: Any) -> bool:
453456
return isinstance(item, annotation_type)
454457

455458
field_annotation = next(
@@ -647,7 +650,9 @@ def check_replacement_request_mutability(
647650
return value
648651

649652
@classmethod
650-
def check_mutability_issues(cls, original: "BaseModel", replacement: "BaseModel"):
653+
def check_mutability_issues(
654+
cls, original: "BaseModel", replacement: "BaseModel"
655+
) -> None:
651656
"""Compare two instances, and check for differences of values on the fields marked as immutable."""
652657
model = replacement.__class__
653658
for field_name in model.model_fields:
@@ -662,15 +667,17 @@ def check_mutability_issues(cls, original: "BaseModel", replacement: "BaseModel"
662667
)
663668

664669
attr_type = model.get_field_root_type(field_name)
665-
if is_complex_attribute(attr_type) and not model.get_field_multiplicity(
666-
field_name
670+
if (
671+
attr_type
672+
and is_complex_attribute(attr_type)
673+
and not model.get_field_multiplicity(field_name)
667674
):
668675
original_val = getattr(original, field_name)
669676
replacement_value = getattr(replacement, field_name)
670677
if original_val is not None and replacement_value is not None:
671678
cls.check_mutability_issues(original_val, replacement_value)
672679

673-
def mark_with_schema(self):
680+
def mark_with_schema(self) -> None:
674681
"""Navigate through attributes and sub-attributes of type ComplexAttribute, and mark them with a '_schema' attribute.
675682
676683
'_schema' will later be used by 'get_attribute_urn'.
@@ -679,7 +686,7 @@ def mark_with_schema(self):
679686

680687
for field_name in self.__class__.model_fields:
681688
attr_type = self.get_field_root_type(field_name)
682-
if not is_complex_attribute(attr_type):
689+
if not attr_type or not is_complex_attribute(attr_type):
683690
continue
684691

685692
main_schema = (
@@ -702,7 +709,7 @@ def scim_serializer(
702709
self,
703710
value: Any,
704711
handler: SerializerFunctionWrapHandler,
705-
info: SerializationInfo,
712+
info: FieldSerializationInfo,
706713
) -> Any:
707714
"""Serialize the fields according to mutability indications passed in the serialization context."""
708715
value = handler(value)
@@ -716,7 +723,7 @@ def scim_serializer(
716723

717724
return value
718725

719-
def scim_request_serializer(self, value: Any, info: SerializationInfo) -> Any:
726+
def scim_request_serializer(self, value: Any, info: FieldSerializationInfo) -> Any:
720727
"""Serialize the fields according to mutability indications passed in the serialization context."""
721728
mutability = self.get_field_annotation(info.field_name, Mutability)
722729
scim_ctx = info.context.get("scim") if info.context else None
@@ -740,7 +747,7 @@ def scim_request_serializer(self, value: Any, info: SerializationInfo) -> Any:
740747

741748
return value
742749

743-
def scim_response_serializer(self, value: Any, info: SerializationInfo) -> Any:
750+
def scim_response_serializer(self, value: Any, info: FieldSerializationInfo) -> Any:
744751
"""Serialize the fields according to returnability indications passed in the serialization context."""
745752
returnability = self.get_field_annotation(info.field_name, Returned)
746753
attribute_urn = self.get_attribute_urn(info.field_name)
@@ -774,7 +781,7 @@ def scim_response_serializer(self, value: Any, info: SerializationInfo) -> Any:
774781

775782
@model_serializer(mode="wrap")
776783
def model_serializer_exclude_none(
777-
self, handler, info: SerializationInfo
784+
self, handler: SerializerFunctionWrapHandler, info: SerializationInfo
778785
) -> dict[str, Any]:
779786
"""Remove `None` values inserted by the :meth:`~scim2_models.base.BaseModel.scim_serializer`."""
780787
self.mark_with_schema()
@@ -787,7 +794,7 @@ def model_validate(
787794
*args,
788795
scim_ctx: Optional[Context] = Context.DEFAULT,
789796
original: Optional["BaseModel"] = None,
790-
**kwargs,
797+
**kwargs: Any,
791798
) -> Self:
792799
"""Validate SCIM payloads and generate model representation by using Pydantic :code:`BaseModel.model_validate`.
793800
@@ -812,8 +819,8 @@ def _prepare_model_dump(
812819
scim_ctx: Optional[Context] = Context.DEFAULT,
813820
attributes: Optional[list[str]] = None,
814821
excluded_attributes: Optional[list[str]] = None,
815-
**kwargs,
816-
):
822+
**kwargs: Any,
823+
) -> dict[str, Any]:
817824
kwargs.setdefault("context", {}).setdefault("scim", scim_ctx)
818825
kwargs["context"]["scim_attributes"] = [
819826
validate_attribute_urn(attribute, self.__class__)
@@ -832,11 +839,11 @@ def _prepare_model_dump(
832839

833840
def model_dump(
834841
self,
835-
*args,
842+
*args: Any,
836843
scim_ctx: Optional[Context] = Context.DEFAULT,
837844
attributes: Optional[list[str]] = None,
838845
excluded_attributes: Optional[list[str]] = None,
839-
**kwargs,
846+
**kwargs: Any,
840847
) -> dict:
841848
"""Create a model representation that can be included in SCIM messages by using Pydantic :code:`BaseModel.model_dump`.
842849
@@ -853,12 +860,12 @@ def model_dump(
853860

854861
def model_dump_json(
855862
self,
856-
*args,
863+
*args: Any,
857864
scim_ctx: Optional[Context] = Context.DEFAULT,
858865
attributes: Optional[list[str]] = None,
859866
excluded_attributes: Optional[list[str]] = None,
860-
**kwargs,
861-
) -> dict:
867+
**kwargs: Any,
868+
) -> str:
862869
"""Create a JSON model representation that can be included in SCIM messages by using Pydantic :code:`BaseModel.model_dump_json`.
863870
864871
:param scim_ctx: If a SCIM context is passed, some default values of
@@ -920,12 +927,12 @@ class MultiValuedComplexAttribute(ComplexAttribute):
920927
reference."""
921928

922929

923-
def is_complex_attribute(type) -> bool:
930+
def is_complex_attribute(type_: type) -> bool:
924931
# issubclass raise a TypeError with 'Reference' on python < 3.11
925932
return (
926-
get_origin(type) != Reference
927-
and isclass(type)
928-
and issubclass(type, (ComplexAttribute, MultiValuedComplexAttribute))
933+
get_origin(type_) != Reference
934+
and isclass(type_)
935+
and issubclass(type_, (ComplexAttribute, MultiValuedComplexAttribute))
929936
)
930937

931938

scim2_models/constants.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
PYTHON_RESERVED_WORDS = [
1+
PYTHON_RESERVED_WORDS: list[str] = [
22
"False",
33
"def",
44
"if",
@@ -34,5 +34,5 @@
3434
"pass",
3535
]
3636

37-
PYDANTIC_RESERVED_WORDS = ["schema"]
38-
RESERVED_WORDS = PYTHON_RESERVED_WORDS + PYDANTIC_RESERVED_WORDS
37+
PYDANTIC_RESERVED_WORDS: list[str] = ["schema"]
38+
RESERVED_WORDS: list[str] = PYTHON_RESERVED_WORDS + PYDANTIC_RESERVED_WORDS

0 commit comments

Comments
 (0)