11import json
22from http import HTTPStatus
33from typing import Annotated
4+ from typing import Any
45
56from fastapi import APIRouter
67from fastapi import Depends
78from fastapi import FastAPI
89from fastapi import HTTPException
10+ from fastapi import Query
911from fastapi import Request
1012from fastapi import Response
1113from pydantic import ValidationError
1214
1315from scim2_models import Context
16+ from scim2_models import CreationRequestContext
1417from scim2_models import Error
1518from scim2_models import ListResponse
1619from scim2_models import PatchOp
20+ from scim2_models import PatchRequestContext
21+ from scim2_models import QueryResponseContext
22+ from scim2_models import ReplacementRequestContext
1723from scim2_models import ResourceType
1824from scim2_models import ResponseParameters
1925from scim2_models import Schema
2026from scim2_models import SCIMException
21- from scim2_models import SCIMSerializer
2227from scim2_models import ServiceProviderConfig
23- from scim2_models import SCIMValidator
2428from scim2_models import SearchRequest
2529from scim2_models import User
2630
@@ -46,16 +50,14 @@ class SCIMResponse(Response):
4650
4751 media_type = "application/scim+json"
4852
49- def __init__ (self , content = None , ** kwargs ):
50- if isinstance (content , (dict , list )):
51- content = json .dumps (content , ensure_ascii = False )
52- super ().__init__ (content = content , ** kwargs )
53- try :
54- meta = json .loads (content ).get ("meta" , {})
55- if version := meta .get ("version" ):
56- self .headers ["ETag" ] = version
57- except (json .JSONDecodeError , AttributeError , TypeError ):
58- pass
53+ def render (self , content : Any ) -> bytes :
54+ self ._etag = (content or {}).get ("meta" , {}).get ("version" )
55+ return json .dumps (content , ensure_ascii = False ).encode ("utf-8" )
56+
57+ def __init__ (self , content : Any = None , ** kwargs : Any ) -> None :
58+ super ().__init__ (content , ** kwargs )
59+ if self ._etag :
60+ self .headers ["ETag" ] = self ._etag
5961
6062
6163router = APIRouter (prefix = "/scim/v2" , default_response_class = SCIMResponse )
@@ -103,21 +105,21 @@ def resolve_user(user_id: str):
103105async def handle_validation_error (request , error ):
104106 """Turn Pydantic validation errors into SCIM error responses."""
105107 scim_error = Error .from_validation_error (error .errors ()[0 ])
106- return SCIMResponse (scim_error .model_dump_json (), status_code = scim_error .status )
108+ return SCIMResponse (scim_error .model_dump (), status_code = scim_error .status )
107109
108110
109111@app .exception_handler (HTTPException )
110112async def handle_http_exception (request , error ):
111113 """Turn HTTP exceptions into SCIM error responses."""
112114 scim_error = Error (status = error .status_code , detail = error .detail or "" )
113- return SCIMResponse (scim_error .model_dump_json (), status_code = error .status_code )
115+ return SCIMResponse (scim_error .model_dump (), status_code = error .status_code )
114116
115117
116118@app .exception_handler (SCIMException )
117119async def handle_scim_error (request , error ):
118120 """Turn SCIM exceptions into SCIM error responses."""
119121 scim_error = error .to_error ()
120- return SCIMResponse (scim_error .model_dump_json (), status_code = scim_error .status )
122+ return SCIMResponse (scim_error .model_dump (), status_code = scim_error .status )
121123# -- error-handlers-end --
122124# -- refinements-end --
123125
@@ -126,16 +128,19 @@ async def handle_scim_error(request, error):
126128# -- single-resource-start --
127129# -- get-user-start --
128130@router .get ("/Users/{user_id}" )
129- async def get_user (request : Request , app_record : dict = Depends (resolve_user )):
131+ async def get_user (
132+ request : Request ,
133+ req : Annotated [ResponseParameters , Query ()],
134+ app_record : dict = Depends (resolve_user ),
135+ ):
130136 """Return one SCIM user."""
131- req = ResponseParameters .model_validate (dict (request .query_params ))
132137 scim_user = to_scim_user (app_record , resource_location (request , app_record ))
133138 etag = make_etag (app_record )
134139 if_none_match = request .headers .get ("If-None-Match" )
135140 if if_none_match and etag in [t .strip () for t in if_none_match .split ("," )]:
136141 return Response (status_code = HTTPStatus .NOT_MODIFIED )
137142 return SCIMResponse (
138- scim_user .model_dump_json (
143+ scim_user .model_dump (
139144 scim_ctx = Context .RESOURCE_QUERY_RESPONSE ,
140145 attributes = req .attributes ,
141146 excluded_attributes = req .excluded_attributes ,
@@ -148,11 +153,10 @@ async def get_user(request: Request, app_record: dict = Depends(resolve_user)):
148153@router .patch ("/Users/{user_id}" )
149154async def patch_user (
150155 request : Request ,
151- patch : Annotated [
152- PatchOp [User ], SCIMValidator (Context .RESOURCE_PATCH_REQUEST )
153- ],
156+ patch : PatchRequestContext [PatchOp [User ]],
157+ req : Annotated [ResponseParameters , Query ()],
154158 app_record : dict = Depends (resolve_user ),
155- ) -> Annotated [ User , SCIMSerializer ( Context . RESOURCE_PATCH_RESPONSE )] :
159+ ):
156160 """Apply a SCIM PatchOp to an existing user."""
157161 check_etag (app_record , request )
158162 scim_user = to_scim_user (app_record , resource_location (request , app_record ))
@@ -161,19 +165,25 @@ async def patch_user(
161165 updated_record = from_scim_user (scim_user )
162166 save_record (updated_record )
163167
164- return to_scim_user (updated_record , resource_location (request , updated_record ))
168+ response_user = to_scim_user (updated_record , resource_location (request , updated_record ))
169+ return SCIMResponse (
170+ response_user .model_dump (
171+ scim_ctx = Context .RESOURCE_PATCH_RESPONSE ,
172+ attributes = req .attributes ,
173+ excluded_attributes = req .excluded_attributes ,
174+ ),
175+ )
165176# -- patch-user-end --
166177
167178
168179# -- put-user-start --
169180@router .put ("/Users/{user_id}" )
170181async def replace_user (
171182 request : Request ,
172- replacement : Annotated [
173- User , SCIMValidator (Context .RESOURCE_REPLACEMENT_REQUEST )
174- ],
183+ replacement : ReplacementRequestContext [User ],
184+ req : Annotated [ResponseParameters , Query ()],
175185 app_record : dict = Depends (resolve_user ),
176- ) -> Annotated [ User , SCIMSerializer ( Context . RESOURCE_REPLACEMENT_RESPONSE )] :
186+ ):
177187 """Replace an existing user with a full SCIM resource."""
178188 check_etag (app_record , request )
179189 existing_user = to_scim_user (app_record , resource_location (request , app_record ))
@@ -183,7 +193,14 @@ async def replace_user(
183193 updated_record = from_scim_user (replacement )
184194 save_record (updated_record )
185195
186- return to_scim_user (updated_record , resource_location (request , updated_record ))
196+ response_user = to_scim_user (updated_record , resource_location (request , updated_record ))
197+ return SCIMResponse (
198+ response_user .model_dump (
199+ scim_ctx = Context .RESOURCE_REPLACEMENT_RESPONSE ,
200+ attributes = req .attributes ,
201+ excluded_attributes = req .excluded_attributes ,
202+ ),
203+ )
187204# -- put-user-end --
188205
189206
@@ -201,9 +218,10 @@ async def delete_user(request: Request, app_record: dict = Depends(resolve_user)
201218# -- collection-start --
202219# -- list-users-start --
203220@router .get ("/Users" )
204- async def list_users (request : Request ):
221+ async def list_users (
222+ request : Request , req : Annotated [SearchRequest , Query ()]
223+ ):
205224 """Return one page of users as a SCIM ListResponse."""
206- req = SearchRequest .model_validate (dict (request .query_params ))
207225 total , page = list_records (req .start_index_0 , req .stop_index_0 )
208226 resources = [
209227 to_scim_user (record , resource_location (request , record )) for record in page
@@ -215,7 +233,7 @@ async def list_users(request: Request):
215233 resources = resources ,
216234 )
217235 return SCIMResponse (
218- response .model_dump_json (
236+ response .model_dump (
219237 scim_ctx = Context .RESOURCE_QUERY_RESPONSE ,
220238 attributes = req .attributes ,
221239 excluded_attributes = req .excluded_attributes ,
@@ -228,25 +246,31 @@ async def list_users(request: Request):
228246@router .post ("/Users" , status_code = HTTPStatus .CREATED )
229247async def create_user (
230248 request : Request ,
231- request_user : Annotated [
232- User , SCIMValidator (Context .RESOURCE_CREATION_REQUEST )
233- ],
234- ) -> Annotated [User , SCIMSerializer (Context .RESOURCE_CREATION_RESPONSE )]:
249+ request_user : CreationRequestContext [User ],
250+ req : Annotated [ResponseParameters , Query ()],
251+ ):
235252 """Validate a SCIM creation payload and store the new user."""
236253 app_record = from_scim_user (request_user )
237254 save_record (app_record )
238255
239- return to_scim_user (app_record , resource_location (request , app_record ))
256+ response_user = to_scim_user (app_record , resource_location (request , app_record ))
257+ return SCIMResponse (
258+ response_user .model_dump (
259+ scim_ctx = Context .RESOURCE_CREATION_RESPONSE ,
260+ attributes = req .attributes ,
261+ excluded_attributes = req .excluded_attributes ,
262+ ),
263+ status_code = HTTPStatus .CREATED ,
264+ )
240265# -- create-user-end --
241266# -- collection-end --
242267
243268
244269# -- discovery-start --
245270# -- schemas-start --
246271@router .get ("/Schemas" )
247- async def list_schemas (request : Request ):
272+ async def list_schemas (req : Annotated [ SearchRequest , Query ()] ):
248273 """Return one page of SCIM schemas the server exposes."""
249- req = SearchRequest .model_validate (dict (request .query_params ))
250274 total , page = get_schemas (req .start_index_0 , req .stop_index_0 )
251275 response = ListResponse [Schema ](
252276 total_results = total ,
@@ -255,7 +279,7 @@ async def list_schemas(request: Request):
255279 resources = page ,
256280 )
257281 return SCIMResponse (
258- response .model_dump_json (scim_ctx = Context .RESOURCE_QUERY_RESPONSE ),
282+ response .model_dump (scim_ctx = Context .RESOURCE_QUERY_RESPONSE ),
259283 )
260284
261285
@@ -266,18 +290,17 @@ async def get_schema_by_id(schema_id: str):
266290 schema = get_schema (schema_id )
267291 except KeyError :
268292 scim_error = Error (status = 404 , detail = f"Schema { schema_id !r} not found" )
269- return SCIMResponse (scim_error .model_dump_json (), status_code = HTTPStatus .NOT_FOUND )
293+ return SCIMResponse (scim_error .model_dump (), status_code = HTTPStatus .NOT_FOUND )
270294 return SCIMResponse (
271- schema .model_dump_json (scim_ctx = Context .RESOURCE_QUERY_RESPONSE ),
295+ schema .model_dump (scim_ctx = Context .RESOURCE_QUERY_RESPONSE ),
272296 )
273297# -- schemas-end --
274298
275299
276300# -- resource-types-start --
277301@router .get ("/ResourceTypes" )
278- async def list_resource_types (request : Request ):
302+ async def list_resource_types (req : Annotated [ SearchRequest , Query ()] ):
279303 """Return one page of SCIM resource types the server exposes."""
280- req = SearchRequest .model_validate (dict (request .query_params ))
281304 total , page = get_resource_types (req .start_index_0 , req .stop_index_0 )
282305 response = ListResponse [ResourceType ](
283306 total_results = total ,
@@ -286,7 +309,7 @@ async def list_resource_types(request: Request):
286309 resources = page ,
287310 )
288311 return SCIMResponse (
289- response .model_dump_json (scim_ctx = Context .RESOURCE_QUERY_RESPONSE ),
312+ response .model_dump (scim_ctx = Context .RESOURCE_QUERY_RESPONSE ),
290313 )
291314
292315
@@ -299,17 +322,17 @@ async def get_resource_type_by_id(resource_type_id: str):
299322 scim_error = Error (
300323 status = 404 , detail = f"ResourceType { resource_type_id !r} not found"
301324 )
302- return SCIMResponse (scim_error .model_dump_json (), status_code = HTTPStatus .NOT_FOUND )
325+ return SCIMResponse (scim_error .model_dump (), status_code = HTTPStatus .NOT_FOUND )
303326 return SCIMResponse (
304- rt .model_dump_json (scim_ctx = Context .RESOURCE_QUERY_RESPONSE ),
327+ rt .model_dump (scim_ctx = Context .RESOURCE_QUERY_RESPONSE ),
305328 )
306329# -- resource-types-end --
307330
308331
309332# -- service-provider-config-start --
310333@router .get ("/ServiceProviderConfig" )
311- async def get_service_provider_config () -> Annotated [
312- ServiceProviderConfig , SCIMSerializer ( Context . RESOURCE_QUERY_RESPONSE )
334+ async def get_service_provider_config () -> QueryResponseContext [
335+ ServiceProviderConfig
313336]:
314337 """Return the SCIM service provider configuration."""
315338 return service_provider_config
0 commit comments