diff --git a/src/viur/core/modules/user.py b/src/viur/core/modules/user.py index 5d2e65c9d..56881a81b 100644 --- a/src/viur/core/modules/user.py +++ b/src/viur/core/modules/user.py @@ -27,6 +27,7 @@ from viur.core.prototypes.list import List from viur.core.ratelimit import RateLimit from viur.core.securityheaders import extendCsp +from viur.core.skeleton import SkeletonInstance @functools.total_ordering @@ -194,7 +195,7 @@ def __init__(self, moduleName, modulePath, userModule): super().__init__(moduleName, modulePath) self._user_module = userModule - def can_handle(self, skel: skeleton.SkeletonInstance) -> bool: + def can_handle(self, skel: skeleton.SkeletonInstance[UserSkel]) -> bool: return True @classmethod @@ -1309,7 +1310,7 @@ def editSkel(self, *args, **kwargs): def secondFactorProviderByClass(self, cls) -> UserSecondFactorAuthentication: return getattr(self, f"f2_{cls.__name__.lower()}") - def getCurrentUser(self): + def getCurrentUser(self) -> SkeletonInstance[UserSkel] | None: session = current.session.get() req = current.request.get() diff --git a/src/viur/core/skeleton.py b/src/viur/core/skeleton.py index 9ef951221..1722f4ba6 100644 --- a/src/viur/core/skeleton.py +++ b/src/viur/core/skeleton.py @@ -14,7 +14,6 @@ from itertools import chain from deprecated.sphinx import deprecated - from viur.core import conf, current, db, email, errors, translate, utils from viur.core.bones import ( BaseBone, @@ -41,6 +40,8 @@ ABSTRACT_SKEL_CLS_SUFFIX = "AbstractSkel" KeyType: t.TypeAlias = db.Key | str | int +Skeleton_Cls = t.TypeVar("Skeleton_Cls", bound="BaseSkeleton") + class MetaBaseSkel(type): """ @@ -129,7 +130,7 @@ def __setattr__(self, key, value): value.__set_name__(self, key) -class SkeletonInstance: +class SkeletonInstance(t.Generic[Skeleton_Cls]): """ The actual wrapper around a Skeleton-Class. An object of this class is what's actually returned when you call a Skeleton-Class. With ViUR3, you don't get an instance of a Skeleton-Class any more - it's always this @@ -148,7 +149,7 @@ class SkeletonInstance: def __init__( self, - skel_cls: t.Type[Skeleton], + skel_cls: t.Type[Skeleton_Cls], *, bones: t.Iterable[str] = (), bone_map: t.Optional[t.Dict[str, BaseBone]] = None, @@ -180,7 +181,7 @@ def __init__( bone_map = bone_map or {} if bones: - names = ("key", ) + tuple(bones) + names = ("key",) + tuple(bones) # generate full keys sequence based on definition; keeps order of patterns! keys = [] @@ -218,7 +219,7 @@ def __init__( self.is_cloned = clone self.renderAccessedValues = {} self.renderPreparation = None - self.skeletonCls = skel_cls + self.skeletonCls: t.Type[Skeleton_Cls] = skel_cls def items(self, yieldBoneValues: bool = False) -> t.Iterable[tuple[str, BaseBone]]: if yieldBoneValues: @@ -447,7 +448,7 @@ def __deepcopy__(self, memodict): return res -class BaseSkeleton(object, metaclass=MetaBaseSkel): +class BaseSkeleton(metaclass=MetaBaseSkel): """ This is a container-object holding information about one database entity. @@ -610,7 +611,7 @@ def setBoneValue( @classmethod def fromClient( cls, - skel: SkeletonInstance, + skel: SkeletonInstance[t.Self], data: dict[str, list[str] | str], *, amend: bool = False, @@ -690,7 +691,7 @@ def fromClient( return complete @classmethod - def refresh(cls, skel: SkeletonInstance): + def refresh(cls, skel: SkeletonInstance[t.Self]): """ Refresh the bones current content. @@ -706,7 +707,7 @@ def refresh(cls, skel: SkeletonInstance): _ = skel[key] # Ensure value gets loaded bone.refresh(skel, key) - def __new__(cls, *args, **kwargs) -> SkeletonInstance: + def __new__(cls, *args, **kwargs) -> SkeletonInstance[t.Self]: return SkeletonInstance(cls, *args, **kwargs) @@ -1157,7 +1158,7 @@ def read( if db_res := db.Get(db_key): skel.setEntity(db_res) return skel - elif create in (False, None): + elif create in (False, None): return None elif isinstance(create, dict): if create and not skel.fromClient(create, amend=True): @@ -1378,8 +1379,10 @@ def __txn_write(write_skel): skel.dbEntity["viur"].setdefault("viurActiveSeoKeys", []) for language, seo_key in last_set_seo_keys.items(): - if skel.dbEntity["viur"]["viurCurrentSeoKeys"][language] not in \ - skel.dbEntity["viur"]["viurActiveSeoKeys"]: + if ( + skel.dbEntity["viur"]["viurCurrentSeoKeys"][language] + not in skel.dbEntity["viur"]["viurActiveSeoKeys"] + ): # Ensure the current, active seo key is in the list of all seo keys skel.dbEntity["viur"]["viurActiveSeoKeys"].insert(0, seo_key) if str(skel.dbEntity.key.id_or_name) not in skel.dbEntity["viur"]["viurActiveSeoKeys"]: @@ -1803,7 +1806,7 @@ def read(self, key: t.Optional[db.Key | str | int] = None) -> SkeletonInstance: return skel -class SkelList(list): +class SkelList(list, t.Generic[Skeleton_Cls]): """ This class is used to hold multiple skeletons together with other, commonly used information. @@ -1822,12 +1825,12 @@ class SkelList(list): "renderPreparation", ) - def __init__(self, baseSkel=None): + def __init__(self, baseSkel: SkeletonInstance[Skeleton_Cls] = None): """ :param baseSkel: The baseclass for all entries in this list """ - super(SkelList, self).__init__() - self.baseSkel = baseSkel or {} + super().__init__() + self.baseSkel: SkeletonInstance[Skeleton_Cls] | dict = baseSkel or {} self.getCursor = lambda: None self.get_orders = lambda: None self.renderPreparation = None