1414from pydantic import BaseModel as PydanticBaseModel
1515from pydantic import ConfigDict
1616from pydantic import Field
17+ from pydantic import FieldSerializationInfo
1718from pydantic import GetCoreSchemaHandler
1819from pydantic import SerializationInfo
1920from pydantic import SerializerFunctionWrapHandler
@@ -115,7 +116,9 @@ def validate_attribute_urn(
115116 return f"{ schema } :{ attribute_base } "
116117
117118
118- def contains_attribute_or_subattributes (attribute_urns : list [str ], attribute_urn : str ):
119+ def contains_attribute_or_subattributes (
120+ attribute_urns : list [str ], attribute_urn : str
121+ ) -> bool :
119122 return attribute_urn in attribute_urns or any (
120123 item .startswith (f"{ attribute_urn } ." ) or item .startswith (f"{ attribute_urn } :" )
121124 for item in attribute_urns
@@ -412,7 +415,7 @@ class Required(Enum):
412415
413416 _default = false
414417
415- def __bool__ (self ):
418+ def __bool__ (self ) -> bool :
416419 return self .value
417420
418421
@@ -424,7 +427,7 @@ class CaseExact(Enum):
424427
425428 _default = false
426429
427- def __bool__ (self ):
430+ def __bool__ (self ) -> bool :
428431 return self .value
429432
430433
@@ -449,7 +452,7 @@ def get_field_annotation(cls, field_name: str, annotation_type: type) -> Any:
449452
450453 default_value = getattr (annotation_type , "_default" , None )
451454
452- def annotation_type_filter (item ) :
455+ def annotation_type_filter (item : Any ) -> bool :
453456 return isinstance (item , annotation_type )
454457
455458 field_annotation = next (
@@ -647,7 +650,9 @@ def check_replacement_request_mutability(
647650 return value
648651
649652 @classmethod
650- def check_mutability_issues (cls , original : "BaseModel" , replacement : "BaseModel" ):
653+ def check_mutability_issues (
654+ cls , original : "BaseModel" , replacement : "BaseModel"
655+ ) -> None :
651656 """Compare two instances, and check for differences of values on the fields marked as immutable."""
652657 model = replacement .__class__
653658 for field_name in model .model_fields :
@@ -662,15 +667,17 @@ def check_mutability_issues(cls, original: "BaseModel", replacement: "BaseModel"
662667 )
663668
664669 attr_type = model .get_field_root_type (field_name )
665- if is_complex_attribute (attr_type ) and not model .get_field_multiplicity (
666- field_name
670+ if (
671+ attr_type
672+ and is_complex_attribute (attr_type )
673+ and not model .get_field_multiplicity (field_name )
667674 ):
668675 original_val = getattr (original , field_name )
669676 replacement_value = getattr (replacement , field_name )
670677 if original_val is not None and replacement_value is not None :
671678 cls .check_mutability_issues (original_val , replacement_value )
672679
673- def mark_with_schema (self ):
680+ def mark_with_schema (self ) -> None :
674681 """Navigate through attributes and sub-attributes of type ComplexAttribute, and mark them with a '_schema' attribute.
675682
676683 '_schema' will later be used by 'get_attribute_urn'.
@@ -679,7 +686,7 @@ def mark_with_schema(self):
679686
680687 for field_name in self .__class__ .model_fields :
681688 attr_type = self .get_field_root_type (field_name )
682- if not is_complex_attribute (attr_type ):
689+ if not attr_type or not is_complex_attribute (attr_type ):
683690 continue
684691
685692 main_schema = (
@@ -702,7 +709,7 @@ def scim_serializer(
702709 self ,
703710 value : Any ,
704711 handler : SerializerFunctionWrapHandler ,
705- info : SerializationInfo ,
712+ info : FieldSerializationInfo ,
706713 ) -> Any :
707714 """Serialize the fields according to mutability indications passed in the serialization context."""
708715 value = handler (value )
@@ -716,7 +723,7 @@ def scim_serializer(
716723
717724 return value
718725
719- def scim_request_serializer (self , value : Any , info : SerializationInfo ) -> Any :
726+ def scim_request_serializer (self , value : Any , info : FieldSerializationInfo ) -> Any :
720727 """Serialize the fields according to mutability indications passed in the serialization context."""
721728 mutability = self .get_field_annotation (info .field_name , Mutability )
722729 scim_ctx = info .context .get ("scim" ) if info .context else None
@@ -740,7 +747,7 @@ def scim_request_serializer(self, value: Any, info: SerializationInfo) -> Any:
740747
741748 return value
742749
743- def scim_response_serializer (self , value : Any , info : SerializationInfo ) -> Any :
750+ def scim_response_serializer (self , value : Any , info : FieldSerializationInfo ) -> Any :
744751 """Serialize the fields according to returnability indications passed in the serialization context."""
745752 returnability = self .get_field_annotation (info .field_name , Returned )
746753 attribute_urn = self .get_attribute_urn (info .field_name )
@@ -774,7 +781,7 @@ def scim_response_serializer(self, value: Any, info: SerializationInfo) -> Any:
774781
775782 @model_serializer (mode = "wrap" )
776783 def model_serializer_exclude_none (
777- self , handler , info : SerializationInfo
784+ self , handler : SerializerFunctionWrapHandler , info : SerializationInfo
778785 ) -> dict [str , Any ]:
779786 """Remove `None` values inserted by the :meth:`~scim2_models.base.BaseModel.scim_serializer`."""
780787 self .mark_with_schema ()
@@ -787,7 +794,7 @@ def model_validate(
787794 * args ,
788795 scim_ctx : Optional [Context ] = Context .DEFAULT ,
789796 original : Optional ["BaseModel" ] = None ,
790- ** kwargs ,
797+ ** kwargs : Any ,
791798 ) -> Self :
792799 """Validate SCIM payloads and generate model representation by using Pydantic :code:`BaseModel.model_validate`.
793800
@@ -812,8 +819,8 @@ def _prepare_model_dump(
812819 scim_ctx : Optional [Context ] = Context .DEFAULT ,
813820 attributes : Optional [list [str ]] = None ,
814821 excluded_attributes : Optional [list [str ]] = None ,
815- ** kwargs ,
816- ):
822+ ** kwargs : Any ,
823+ ) -> dict [ str , Any ] :
817824 kwargs .setdefault ("context" , {}).setdefault ("scim" , scim_ctx )
818825 kwargs ["context" ]["scim_attributes" ] = [
819826 validate_attribute_urn (attribute , self .__class__ )
@@ -832,11 +839,11 @@ def _prepare_model_dump(
832839
833840 def model_dump (
834841 self ,
835- * args ,
842+ * args : Any ,
836843 scim_ctx : Optional [Context ] = Context .DEFAULT ,
837844 attributes : Optional [list [str ]] = None ,
838845 excluded_attributes : Optional [list [str ]] = None ,
839- ** kwargs ,
846+ ** kwargs : Any ,
840847 ) -> dict :
841848 """Create a model representation that can be included in SCIM messages by using Pydantic :code:`BaseModel.model_dump`.
842849
@@ -853,12 +860,12 @@ def model_dump(
853860
854861 def model_dump_json (
855862 self ,
856- * args ,
863+ * args : Any ,
857864 scim_ctx : Optional [Context ] = Context .DEFAULT ,
858865 attributes : Optional [list [str ]] = None ,
859866 excluded_attributes : Optional [list [str ]] = None ,
860- ** kwargs ,
861- ) -> dict :
867+ ** kwargs : Any ,
868+ ) -> str :
862869 """Create a JSON model representation that can be included in SCIM messages by using Pydantic :code:`BaseModel.model_dump_json`.
863870
864871 :param scim_ctx: If a SCIM context is passed, some default values of
@@ -920,12 +927,12 @@ class MultiValuedComplexAttribute(ComplexAttribute):
920927 reference."""
921928
922929
923- def is_complex_attribute (type ) -> bool :
930+ def is_complex_attribute (type_ : type ) -> bool :
924931 # issubclass raise a TypeError with 'Reference' on python < 3.11
925932 return (
926- get_origin (type ) != Reference
927- and isclass (type )
928- and issubclass (type , (ComplexAttribute , MultiValuedComplexAttribute ))
933+ get_origin (type_ ) != Reference
934+ and isclass (type_ )
935+ and issubclass (type_ , (ComplexAttribute , MultiValuedComplexAttribute ))
929936 )
930937
931938
0 commit comments