11import datetime
22import json
3+ import time
34from abc import ABC , abstractmethod
45from datetime import timedelta
56from typing import Any
67from uuid import uuid4
78
9+ from cryptojwt .jwk .jwk import key_from_jwk_dict
810from jinja2 import Template
911from pydantic import ValidationError
1012from pymdoccbor .mdoc .issuer import MdocCborIssuer
13+
14+ from pyeudiw .jwt .utils import decode_jwt_header , decode_jwt_payload
1115from satosa .context import Context
1216from satosa .response import Response
1317
1822 POST_ACCEPTED_METHODS ,
1923 VCIBaseEndpoint ,
2024)
21- from pyeudiw .satosa .frontends .openid4vci .models .credential_endpoint_request import (
22- CredentialEndpointRequest ,
23- )
2425from pyeudiw .satosa .frontends .openid4vci .models .openid4vci_basemodel import (
25- OpenId4VciBaseModel ,
26+ OpenId4VciBaseModel , ENDPOINT_CTX , CONFIG_CTX ,
2627)
2728from pyeudiw .satosa .frontends .openid4vci .storage .engine import OpenId4VciDBEngineHandler
2829from pyeudiw .satosa .frontends .openid4vci .storage .entity import AuthorizationSession
3738 CredentialConfiguration ,
3839 CredentialConfigurationFormatEnum ,
3940)
40- from pyeudiw .satosa .utils .session import get_session_id
4141from pyeudiw .satosa .utils .validation import (
4242 validate_content_type ,
43- validate_request_method ,
43+ validate_request_method , DPOP_HEADER , AUTHORIZATION_HEADER ,
4444)
4545from pyeudiw .sd_jwt .issuer import SDJWTIssuer
4646from pyeudiw .sd_jwt .utils .yaml_specification import (
6363
6464class BaseCredentialEndpoint (ABC , VCIBaseEndpoint ):
6565
66+ _ENDPOINT_NAME = "credential"
67+
6668 def __init__ (
6769 self ,
6870 config : dict ,
@@ -111,15 +113,17 @@ def endpoint(self, context: Context) -> Response:
111113 if self .dpop_required :
112114 if (
113115 not context .http_headers
114- or ("DPoP" not in context .http_headers )
115- or ("Authorization" not in context .http_headers )
116+ or (DPOP_HEADER not in context .http_headers )
117+ or (AUTHORIZATION_HEADER not in context .http_headers )
116118 ):
117119 raise InvalidRequestException (
118120 "Missing DPoP and/or Authorization header"
119121 )
120122
121- dpop = context .http_headers .get ("DPoP" )
122- authz = context .http_headers .get ("Authorization" )
123+ dpop = context .http_headers .get (DPOP_HEADER )
124+ authz = context .http_headers .get (AUTHORIZATION_HEADER )
125+ if not dpop or not authz :
126+ raise InvalidRequestException ("Invalid headers" )
123127
124128 try :
125129 dpop_verifier = DPoPVerifier (
@@ -134,14 +138,54 @@ def endpoint(self, context: Context) -> Response:
134138 )
135139 return self ._handle_400 (context , str (e ), e )
136140
137- entity = self .db_engine .get_by_session_id (get_session_id (context ))
138- req = self .validate_request (context , entity )
139- credential_id = None
140- if isinstance (req , CredentialEndpointRequest ):
141- credential_id = (
142- req .credential_identifier or req .credential_configuration_id
143- )
144- return self .to_response (context , entity , credential_id )
141+ auth_token = decode_jwt_payload (dpop_verifier .dpop_authz_token )
142+ entity = self .db_engine .search_session_by_field ("access_token_jti" , auth_token .get ("jti" ))
143+ auth_session = AuthorizationSession .model_validate (entity , context = {
144+ ENDPOINT_CTX : self ._ENDPOINT_NAME ,
145+ CONFIG_CTX : self .config
146+ })
147+
148+ data = self ._get_body (context ) or {}
149+ credential_identifier = data .get ("credential_identifier" ) or ""
150+ # TODO: check/validate scope
151+ if data .get ("credential_configuration_id" ):
152+ credential_configuration_id = data
153+ else : # validate credential_identifier with authorization_details of token
154+ if auth_session .authorization_details :
155+ for auth_details in auth_session .authorization_details :
156+ if auth_details .credential_identifiers :
157+ if credential_identifier in auth_details .credential_identifiers :
158+ credential_configuration_id = "_" .join (credential_identifier .split ("_" )[:- 1 ])
159+ break
160+ else :
161+ raise InvalidRequestException (
162+ "credential_identifier not match with token authorization_details" )
163+ else :
164+ raise InvalidRequestException ("Invalid credential_configuration_id" )
165+
166+ self .validate_request (context , entity )
167+
168+ proof_jwt = data .get ("proof" , {}).get ("jwt" ) or ""
169+ request_header = decode_jwt_header (proof_jwt )
170+ request_payload = decode_jwt_payload (proof_jwt )
171+ client_id = request_payload .get ("iss" )
172+
173+ #validate nonce
174+ self ._consume_nonce (request_payload .get ("nonce" ))
175+
176+ #validate client --> todo: move to self.validate_request
177+ if not (key_attestation := request_header .get ("key_attestation" )):
178+ return self ._handle_400 (context , "invalid key_attestation" , InvalidRequestException ("invalid_proof" ))
179+
180+ k_payload = decode_jwt_payload (key_attestation )
181+ for _k in k_payload .get ("attested_keys" ) or []:
182+ t_print = key_from_jwk_dict (_k ).thumbprint ("SHA-256" ).decode ()
183+ if t_print == client_id :
184+ break
185+ else :
186+ return self ._handle_400 (context , "client_id mismatch" , InvalidRequestException ("invalid_proof" ))
187+
188+ return self .to_response (context , auth_session , credential_configuration_id )
145189
146190 except (
147191 InvalidRequestException ,
@@ -160,6 +204,17 @@ def endpoint(self, context: Context) -> Response:
160204 context , "error during invoke credential endpoint" , e
161205 )
162206
207+ def _consume_nonce (self , nonce ):
208+ if not (found_nonce := self .db_engine .get ("get_nonce" , nonce )):
209+ raise InvalidRequestException ("Invalid nonce" )
210+
211+ now = round (time .time () * 1000 )
212+ if found_nonce ["created_at" ] + found_nonce ["expires_in" ] <= now :
213+ raise InvalidRequestException ("Expired nonce" )
214+
215+ if self .db_engine .write ("consume_nonce" , nonce , now ) < 1 :
216+ raise Exception ("Unable to consume nonce, storage error" )
217+
163218 @abstractmethod
164219 def validate_request (self , context : Context , entity : dict ) -> OpenId4VciBaseModel :
165220 pass
@@ -171,20 +226,16 @@ def to_response(
171226 pass
172227
173228 def build_credential (
174- self , context : Context , credential_id : str | None
229+ self , vci_entity : AuthorizationSession , credential_id : str | None
175230 ) -> list [str ]:
176231 credential_list = []
177- entity = self .db_engine .get_by_session_id (get_session_id (context ))
178-
179- if not entity :
232+ if not vci_entity :
180233 self ._log_error (
181234 self .__class__ .__name__ , "No entity found for the current session."
182235 )
183236 return credential_list
184237
185- vci_entity = AuthorizationSession (** entity )
186-
187- user = self ._db_user_engine .get_by_fields (
238+ user = self ._db_user_engine .get ("get_by_fields" ,
188239 self ._extract_lookup_identifiers (vci_entity .attributes or {})
189240 )
190241 if credential_id :
@@ -315,11 +366,11 @@ def _loader(
315366 )
316367
317368 def _build_status_list_payload (self , user_id : str ):
318- credential = self ._db_credential_engine .get_credential_by_user_id ( user_id )
369+ # credential = self._db_credential_engine.get("get_credential_by_user_id", user_id) # todo: store credential
319370 return {
320371 "status_list" : {
321- "idx" : credential .incremental_id ,
322- "uri" : f"{ self .status_endpoint } /{ credential .incremental_id } " ,
372+ "idx" : " credential.incremental_id" ,
373+ "uri" : f"{ self .status_endpoint } /{ " credential.incremental_id" } " ,
323374 }
324375 }
325376
0 commit comments