22import json
33import os
44import uuid
5- from typing import Any , Dict
5+ from typing import Any , Callable , Dict , overload
66
77from llama_index .core .workflow import StopEvent , Workflow
88from llama_index .core .workflow .drawing import StepConfig # type: ignore
@@ -284,8 +284,14 @@ def get_event_style(event_type) -> str:
284284 return "defaultEventStyle"
285285
286286
287- async def llamaindex_init_middleware_async (entrypoint : str ) -> MiddlewareResult :
287+ async def llamaindex_init_middleware_async (
288+ entrypoint : str ,
289+ options : dict [str , Any ] | None = None ,
290+ write_config : Callable [[Any ], str ] | None = None ,
291+ ) -> MiddlewareResult :
288292 """Middleware to check for llama_index.json and create uipath.json with schemas"""
293+ options = options or {}
294+
289295 config = LlamaIndexConfig ()
290296 if not config .exists :
291297 return MiddlewareResult (
@@ -305,8 +311,9 @@ async def llamaindex_init_middleware_async(entrypoint: str) -> MiddlewareResult:
305311 loaded_workflow = await workflow .load_workflow ()
306312 schema = generate_schema_from_workflow (loaded_workflow )
307313 try :
314+ should_infer_bindings = options .get ("infer_bindings" , True )
308315 # Make sure the file path exists
309- if os .path .exists (workflow .file_path ):
316+ if os .path .exists (workflow .file_path ) and should_infer_bindings :
310317 file_bindings = generate_bindings_json (workflow .file_path )
311318 # Merge bindings
312319 if "resources" in file_bindings :
@@ -345,10 +352,13 @@ async def llamaindex_init_middleware_async(entrypoint: str) -> MiddlewareResult:
345352
346353 uipath_config = {"entryPoints" : entrypoints , "bindings" : all_bindings }
347354
348- # Save the uipath.json file
349- config_path = "uipath.json"
350- with open (config_path , "w" ) as f :
351- json .dump (uipath_config , f , indent = 2 )
355+ if write_config :
356+ config_path = write_config (uipath_config )
357+ else :
358+ # Save the uipath.json file
359+ config_path = "uipath.json"
360+ with open (config_path , "w" ) as f :
361+ json .dump (uipath_config , f , indent = 4 )
352362
353363 console .success (f" Created '{ config_path } ' file." )
354364 return MiddlewareResult (should_continue = False )
@@ -361,6 +371,24 @@ async def llamaindex_init_middleware_async(entrypoint: str) -> MiddlewareResult:
361371 )
362372
363373
364- def llamaindex_init_middleware (entrypoint : str ) -> MiddlewareResult :
374+ @overload
375+ def llamaindex_init_middleware (entrypoint : str ) -> MiddlewareResult : ...
376+
377+
378+ @overload
379+ def llamaindex_init_middleware (
380+ entrypoint : str ,
381+ options : dict [str , Any ],
382+ write_config : Callable [[Any ], str ],
383+ ) -> MiddlewareResult : ...
384+
385+
386+ def llamaindex_init_middleware (
387+ entrypoint : str ,
388+ options : dict [str , Any ] | None = None ,
389+ write_config : Callable [[Any ], str ] | None = None ,
390+ ) -> MiddlewareResult :
365391 """Middleware to check for llama_index.json and create uipath.json with schemas"""
366- return asyncio .run (llamaindex_init_middleware_async (entrypoint ))
392+ return asyncio .run (
393+ llamaindex_init_middleware_async (entrypoint , options , write_config )
394+ )
0 commit comments