@@ -330,19 +330,23 @@ def cached(
330330 )
331331 return self ._memory_cache [key ]
332332
333- if verbose :
334- self .logger .debug (f"Checking disk cache for kernel { get_prim_func_name (func , '<unknown>' )} " )
333+ if verbose :
334+ self .logger .debug (f"Checking disk cache for kernel { get_prim_func_name (func , '<unknown>' )} " )
335335
336- # Then check disk cache
337- kernel = self ._load_kernel_from_disk (
338- key , target , target_host , out_idx , execution_backend , pass_configs , compile_flags , func , verbose
339- )
340- if kernel is not None :
341- if verbose :
342- self .logger .debug (f"Found kernel in disk cache for { get_prim_func_name (func , '<unknown>' )} " )
343- # Populate memory cache with disk result
336+ # Disk loads can be expensive for large kernel sets; keep them outside
337+ # the global cache lock so independent cache hits can proceed in parallel.
338+ kernel = self ._load_kernel_from_disk (
339+ key , target , target_host , out_idx , execution_backend , pass_configs , compile_flags , func , verbose
340+ )
341+ if kernel is not None :
342+ if verbose :
343+ self .logger .debug (f"Found kernel in disk cache for { get_prim_func_name (func , '<unknown>' )} " )
344+ with self ._lock :
345+ existing = self ._memory_cache .get (key )
346+ if existing is not None :
347+ return existing
344348 self ._memory_cache [key ] = kernel
345- return kernel
349+ return kernel
346350
347351 if verbose :
348352 self .logger .debug (f"No cached kernel for { get_prim_func_name (func , '<unknown>' )} " )
@@ -518,9 +522,6 @@ def _load_kernel_from_disk(
518522 if not all ([os .path .exists (file ) for file in required_files ]):
519523 return None
520524
521- # Load the kernel source file (optional)
522- device_kernel_source , host_kernel_source = self ._load_kernel_source (device_kernel_path , host_kernel_path , verbose )
523-
524525 # Load kernel parameters
525526 kernel_params : list [KernelParam ] | None = None
526527 try :
@@ -533,8 +534,10 @@ def _load_kernel_from_disk(
533534
534535 return self ._build_kernel (
535536 func = func ,
536- host_kernel_source = host_kernel_source ,
537- device_kernel_source = device_kernel_source ,
537+ host_kernel_source = None ,
538+ device_kernel_source = None ,
539+ host_kernel_path = host_kernel_path ,
540+ device_kernel_path = device_kernel_path ,
538541 kernel_lib_path = kernel_lib_path ,
539542 kernel_params = kernel_params ,
540543 target = target ,
@@ -638,8 +641,10 @@ def _set_adapter_cache_path(self, kernel: JITKernel, cache_path: str):
638641 def _build_kernel (
639642 self ,
640643 func : Callable | None ,
641- host_kernel_source : str ,
642- device_kernel_source : str ,
644+ host_kernel_source : str | None ,
645+ device_kernel_source : str | None ,
646+ host_kernel_path : str | None ,
647+ device_kernel_path : str | None ,
643648 kernel_lib_path : str ,
644649 kernel_params : list [KernelParam ] | None ,
645650 target : str | Target ,
@@ -651,10 +656,6 @@ def _build_kernel(
651656 ) -> JITKernel | None :
652657 # Check all required components and report specific failures
653658 missing_components = []
654- if not host_kernel_source :
655- missing_components .append ("host_kernel_source" )
656- if not device_kernel_source :
657- missing_components .append ("device_kernel_source" )
658659 if not kernel_params :
659660 missing_components .append ("kernel_params" )
660661
@@ -674,4 +675,6 @@ def _build_kernel(
674675 execution_backend = execution_backend ,
675676 pass_configs = pass_configs ,
676677 compile_flags = compile_flags ,
678+ host_kernel_source_path = host_kernel_path ,
679+ device_kernel_source_path = device_kernel_path ,
677680 )
0 commit comments