diff --git a/ninja_jwt/__init__.py b/ninja_jwt/__init__.py index 90a7e3764..00d32e620 100644 --- a/ninja_jwt/__init__.py +++ b/ninja_jwt/__init__.py @@ -1,3 +1,3 @@ """Django Ninja JWT - JSON Web Token for Django-Ninja""" -__version__ = "5.3.7" +__version__ = "5.3.9" diff --git a/ninja_jwt/schema.py b/ninja_jwt/schema.py index 56ddac83f..e548f83c1 100644 --- a/ninja_jwt/schema.py +++ b/ninja_jwt/schema.py @@ -11,7 +11,8 @@ from ninja.schema import DjangoGetter from ninja_extra import service_resolver from ninja_extra.context import RouteContext -from pydantic import ConfigDict, model_validator +from pydantic import ConfigDict, ValidationInfo, model_validator +from pydantic.main import BaseModel import ninja_jwt.exceptions as exceptions from ninja_jwt.utils import token_error @@ -28,14 +29,21 @@ class SchemaInputService: - def __init__(self, values: SCHEMA_INPUT, model_config: ConfigDict) -> None: + def __init__( + self, + values: SCHEMA_INPUT, + model_config: ConfigDict, + request: Optional[HttpRequest] = None, + ) -> None: self.model_config = model_config self.values = values + self._request: Optional[HttpRequest] = request + def get_request(self) -> HttpRequest: - if self.model_config.get("extra") == "forbid": + if self.model_config.get("extra") == "forbid" and self._request is None: return service_resolver(RouteContext).request - return self.values._context.get("request") + return self._request def get_values(self) -> Dict: if self.model_config.get("extra") == "forbid": @@ -75,7 +83,7 @@ def check_user_authentication_rule(self) -> None: ) @classmethod - def validate_values(cls, request: HttpRequest, values: Dict) -> Dict: + def validate_values(cls, values: Dict) -> Dict: if user_name_field not in values and "password" not in values: raise exceptions.ValidationError( { @@ -92,16 +100,16 @@ def validate_values(cls, request: HttpRequest, values: Dict) -> Dict: if not values.get("password"): raise exceptions.ValidationError({"password": "password is required"}) - _user = authenticate(request, **values) - cls._user = _user + return values + + def authenticate(self, request: HttpRequest, credentials: Dict) -> None: + self._user = authenticate(request, **credentials) - if not (_user is not None and _user.is_active): + if not (self._user is not None and self._user.is_active): raise exceptions.AuthenticationFailed( - cls._default_error_messages["no_active_account"] + self._default_error_messages["no_active_account"] ) - return values - def output_schema(self) -> Schema: warnings.warn( "output_schema() is deprecated in favor of to_response_schema()", @@ -119,28 +127,37 @@ def get_token(cls, user: AbstractUser) -> Dict: class TokenObtainInputSchemaBase(ModelSchema, TokenInputSchemaMixin): class Config: - # extra = "allow" + # extra = "forbid" model = get_user_model() model_fields = ["password", user_name_field] - extra = "forbid" @model_validator(mode="before") def validate_inputs(cls, values: SCHEMA_INPUT) -> dict: schema_input = SchemaInputService(values, cls.model_config) input_values = schema_input.get_values() - request = schema_input.get_request() if isinstance(input_values, dict): - values.update(cls.validate_values(request=request, values=input_values)) - return values + cls.validate_values(values=input_values) return values @model_validator(mode="after") - def post_validate(cls, values: Dict) -> dict: - return cls.post_validate_schema(values) + def post_validate( + cls, values: "TokenObtainInputSchemaBase", info: ValidationInfo + ) -> BaseModel: + schema_input = SchemaInputService( + values.model_dump(), cls.model_config, info.context.get("request") + ) + + credentials = schema_input.get_values() + request = schema_input.get_request() + + values.authenticate(request, credentials) + cls.post_validate_schema(values) + + return values @classmethod - def post_validate_schema(cls, values: Dict) -> dict: + def post_validate_schema(cls, values: "TokenObtainInputSchemaBase") -> None: """ This is a post validate process which is common for any token generating schema. :param values: @@ -148,7 +165,7 @@ def post_validate_schema(cls, values: Dict) -> dict: """ # get_token can return values that wants to apply to `OutputSchema` - data = cls.get_token(cls._user) + data = cls.get_token(values._user) if not isinstance(data, dict): raise Exception("`get_token` must return a `typing.Dict` type.") @@ -158,9 +175,7 @@ def post_validate_schema(cls, values: Dict) -> dict: values.__dict__.update(token_data=data) if api_settings.UPDATE_LAST_LOGIN: - update_last_login(None, cls._user) - - return values + update_last_login(None, values._user) def get_response_schema_init_kwargs(self) -> dict: return dict(