@@ -696,3 +696,111 @@ def clear_cache(prompt=False):
696696 shutil .rmtree (cachedir )
697697 else :
698698 print ("Not removing cached libraries" )
699+
700+
701+ @collective
702+ def get_prepared_cuda_function (comm , global_kernel ):
703+ from pycuda .compiler import SourceModule
704+
705+ # Determine cache key
706+ hsh = md5 (str (global_kernel .cache_key [1 :]).encode ())
707+ basename = hsh .hexdigest ()
708+ cachedir = configuration ["cache_dir" ]
709+ dirpart , basename = basename [:2 ], basename [2 :]
710+ cachedir = os .path .join (cachedir , dirpart )
711+ cname = os .path .join (cachedir , f"{ basename } _code.cu" )
712+
713+ nvcc_opts = ["-use_fast_math" , "-w" ]
714+
715+ if configuration ["check_src_hashes" ] or configuration ["debug" ]:
716+ matching = comm .allreduce (basename , op = _check_op )
717+ if matching != basename :
718+ # Dump all src code to disk for debugging
719+ output = os .path .join (cachedir , "mismatching-kernels" )
720+ srcfile = os .path .join (output , "src-rank%d.cu" % comm .rank )
721+ if comm .rank == 0 :
722+ os .makedirs (output , exist_ok = True )
723+ comm .barrier ()
724+ with open (srcfile , "w" ) as f :
725+ f .write (global_kernel .code_to_compile )
726+ comm .barrier ()
727+ raise CompilationError ("Generated code differs across ranks"
728+ f" (see output in { output } )" )
729+
730+ if os .path .isfile (cname ):
731+ # Are we in the cache?
732+ with open (cname , "r" ) as f :
733+ source_module = SourceModule (f .read (), options = nvcc_opts ,
734+ cache_dir = cachedir )
735+ else :
736+ # No, let"s go ahead and build
737+ if comm .rank == 0 :
738+ # No need to do this on all ranks
739+ os .makedirs (cachedir , exist_ok = True )
740+ with progress (INFO , "Compiling wrapper" ):
741+ # make sure that compiles successfully before writing to file
742+ source_module = SourceModule (global_kernel .code_to_compile ,
743+ options = nvcc_opts , cache_dir = cachedir )
744+ with open (cname , "w" ) as f :
745+ f .write (global_kernel .code_to_compile )
746+ comm .barrier ()
747+
748+ cu_func = source_module .get_function (global_kernel .name )
749+
750+ type_map = {ctypes .c_void_p : "P" , ctypes .c_int : "i" }
751+ argtypes = "" .join (type_map [t ] for t in global_kernel .argtypes )
752+ cu_func .prepare (argtypes )
753+
754+ return cu_func
755+
756+
757+ @collective
758+ def get_opencl_kernel (comm , global_kernel ):
759+ import pyopencl as cl
760+ from pyop2 .backends .opencl import opencl_backend
761+ cl_ctx = opencl_backend .context
762+
763+ # Determine cache key
764+ hsh = md5 (str (global_kernel .cache_key [1 :]).encode ())
765+ basename = hsh .hexdigest ()
766+ cachedir = configuration ["cache_dir" ]
767+ dirpart , basename = basename [:2 ], basename [2 :]
768+ cachedir = os .path .join (cachedir , dirpart )
769+ cname = os .path .join (cachedir , f"{ basename } _code.cl" )
770+
771+ if configuration ["check_src_hashes" ] or configuration ["debug" ]:
772+ matching = comm .allreduce (basename , op = _check_op )
773+ if matching != basename :
774+ # Dump all src code to disk for debugging
775+ output = os .path .join (cachedir , "mismatching-kernels" )
776+ srcfile = os .path .join (output , "src-rank%d.cl" % comm .rank )
777+ if comm .rank == 0 :
778+ os .makedirs (output , exist_ok = True )
779+ comm .barrier ()
780+ with open (srcfile , "w" ) as f :
781+ f .write (global_kernel .code_to_compile )
782+ comm .barrier ()
783+ raise CompilationError ("Generated code differs across ranks"
784+ f" (see output in { output } )" )
785+
786+ if os .path .isfile (cname ):
787+ # Are we in the cache?
788+ with open (cname , "r" ) as f :
789+ prg = cl .Program (cl_ctx , f .read ()).build (options = [],
790+ cache_dir = cachedir )
791+ else :
792+ # No, let"s go ahead and build
793+ if comm .rank == 0 :
794+ # No need to do this on all ranks
795+ os .makedirs (cachedir , exist_ok = True )
796+ with progress (INFO , "Compiling wrapper" ):
797+ # make sure that compiles successfully before writing to file\
798+ prg = (cl .Program (cl_ctx ,
799+ global_kernel .code_to_compile )
800+ .build (options = [], cache_dir = cachedir ))
801+ with open (cname , "w" ) as f :
802+ f .write (global_kernel .code_to_compile )
803+ comm .barrier ()
804+
805+ cl_knl = cl .Kernel (prg , global_kernel .name )
806+ return cl_knl
0 commit comments