2828 ) from _err
2929
3030from dash .fingerprint import check_fingerprint
31- from dash import _validate
31+ from dash import _validate , get_app
3232from dash .exceptions import PreventUpdate
3333from .base_server import BaseDashServer , RequestAdapter , ResponseAdapter
3434from ._utils import format_traceback_html
35+ import traceback
3536
3637if TYPE_CHECKING : # pragma: no cover - typing only
3738 from dash import Dash
@@ -122,8 +123,12 @@ async def _initialize_dev_tools(self) -> None:
122123 self .dash_app .enable_dev_tools (** config , first_run = False )
123124 self ._dev_tools_initialized = True
124125
125- def _setup_timing (self , request : Request ) -> None :
126+ async def _setup_timing (self , request : Request ) -> None :
126127 """Set up timing information for the request."""
128+ try :
129+ request .state .json_body = await request .json () if request .headers .get ("content-type" , "" ).startswith ("application/json" ) else None
130+ except :
131+ request .state .json_body = None
127132 if self .enable_timing :
128133 request .state .timing_information = {
129134 "__dash_server" : {"dur" : time .time (), "desc" : None }
@@ -179,6 +184,12 @@ async def _handle_error(
179184 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
180185 # Handle lifespan events (startup/shutdown)
181186 if scope ["type" ] == "lifespan" :
187+ try :
188+ dash_app = get_app ()
189+ dash_app .backend ._setup_catchall ()
190+ except :
191+ print ("Error during catch-all setup:" )
192+ print (traceback .format_exc ())
182193 await self ._initialize_dev_tools ()
183194 await self .app (scope , receive , send )
184195 return
@@ -193,7 +204,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
193204 token = set_current_request (request )
194205
195206 try :
196- self ._setup_timing (request )
207+ await self ._setup_timing (request )
197208 await self ._run_before_hooks ()
198209
199210 await self .app (scope , receive , send )
@@ -275,11 +286,24 @@ async def index(_request: Request):
275286 dash_app ._add_url ("" , index , methods = ["GET" ])
276287
277288 def setup_catchall (self , dash_app : Dash ):
278- async def catchall (_request : Request ):
279- return Response (content = dash_app .index (), media_type = "text/html" )
289+ '''This is needed to ensure that all routes are handled by FastAPI
290+ and passed through the middleware, which is necessary for features like authentication
291+ and timing to work correctly on all routes. FastAPI will match this catch-all route
292+ for any path that isn't matched by a more specific route, allowing the middleware to
293+ process the request and then return the appropriate response (e.g., 404 if no Dash route matches).'''
280294
281- # pylint: disable=protected-access
282- dash_app ._add_url ("{path:path}" , catchall , methods = ["GET" ])
295+
296+ def _setup_catchall (self ):
297+ try :
298+ print ("Setting up catch-all route for unmatched paths" )
299+ dash_app = get_app ()
300+ async def catchall (_request : Request ):
301+ return Response (content = dash_app .index (), media_type = "text/html" )
302+
303+ # pylint: disable=protected-access
304+ self .add_url_rule ("{path:path}" , catchall , methods = ["GET" ])
305+ except :
306+ print (traceback .format_exc ())
283307
284308 def add_url_rule (
285309 self ,
@@ -289,6 +313,7 @@ def add_url_rule(
289313 methods : list [str ] | None = None ,
290314 include_in_schema : bool = False ,
291315 ):
316+ print (f"Adding URL rule: { rule } -> { view_func } (endpoint: { endpoint } , methods: { methods } )" )
292317 if rule == "" :
293318 rule = "/"
294319 if isinstance (view_func , str ):
@@ -481,7 +506,7 @@ def add_redirect_rule(self, app, fullname, path):
481506 def serve_callback (self , dash_app : Dash ):
482507 async def _dispatch (request : Request ):
483508 # pylint: disable=protected-access
484- body = await request . json ()
509+ body = self . request_adapter (). get_json ()
485510 cb_ctx = dash_app ._initialize_context (
486511 body
487512 ) # pylint: disable=protected-access
@@ -641,5 +666,13 @@ def origin(self):
641666 def path (self ):
642667 return self ._request .url .path
643668
669+ async def _get_json (self , request : Request = None ):
670+ req = self ._request
671+ if not hasattr (req .state , "json_body" ):
672+ req .state .json_body = await request .json ()
673+ return req .state .json_body
674+
644675 def get_json (self ):
645- return asyncio .run (self ._request .json ())
676+ if not hasattr (self , "_request" ) or self ._request is None :
677+ self ._request = get_current_request ()
678+ return self ._request .state .json_body
0 commit comments