@@ -147,6 +147,29 @@ def sniff_compiler(exe):
147147 return compiler
148148
149149
150+ def _check_src_hashes (comm , global_kernel ):
151+ hsh = md5 (str (global_kernel .cache_key [1 :]).encode ())
152+ basename = hsh .hexdigest ()
153+ dirpart , basename = basename [:2 ], basename [2 :]
154+ cachedir = configuration ["cache_dir" ]
155+ cachedir = os .path .join (cachedir , dirpart )
156+
157+ if configuration ["check_src_hashes" ] or configuration ["debug" ]:
158+ matching = comm .allreduce (basename , op = _check_op )
159+ if matching != basename :
160+ # Dump all src code to disk for debugging
161+ output = os .path .join (cachedir , "mismatching-kernels" )
162+ srcfile = os .path .join (output , "src-rank%d.c" % comm .rank )
163+ if comm .rank == 0 :
164+ os .makedirs (output , exist_ok = True )
165+ comm .barrier ()
166+ with open (srcfile , "w" ) as f :
167+ f .write (global_kernel .code_to_compile )
168+ comm .barrier ()
169+ raise CompilationError ("Generated code differs across ranks"
170+ f" (see output in { output } )" )
171+
172+
150173class Compiler (ABC ):
151174 """A compiler for shared libraries.
152175
@@ -317,19 +340,8 @@ def get_so(self, jitmodule, extension):
317340 # atomically (avoiding races).
318341 tmpname = os .path .join (cachedir , "%s_p%d.so.tmp" % (basename , pid ))
319342
320- if configuration ['check_src_hashes' ] or configuration ['debug' ]:
321- matching = self .comm .allreduce (basename , op = _check_op )
322- if matching != basename :
323- # Dump all src code to disk for debugging
324- output = os .path .join (configuration ["cache_dir" ], "mismatching-kernels" )
325- srcfile = os .path .join (output , "src-rank%d.c" % self .comm .rank )
326- if self .comm .rank == 0 :
327- os .makedirs (output , exist_ok = True )
328- self .comm .barrier ()
329- with open (srcfile , "w" ) as f :
330- f .write (jitmodule .code_to_compile )
331- self .comm .barrier ()
332- raise CompilationError ("Generated code differs across ranks (see output in %s)" % output )
343+ _check_src_hashes (self .comm , jitmodule )
344+
333345 try :
334346 # Are we in the cache?
335347 return ctypes .CDLL (soname )
@@ -662,3 +674,81 @@ def clear_cache(prompt=False):
662674 shutil .rmtree (cachedir )
663675 else :
664676 print ("Not removing cached libraries" )
677+
678+
679+ def _get_code_to_compile (comm , global_kernel ):
680+ # Determine cache key
681+ hsh = md5 (str (global_kernel .cache_key [1 :]).encode ())
682+ basename = hsh .hexdigest ()
683+ cachedir = configuration ["cache_dir" ]
684+ dirpart , basename = basename [:2 ], basename [2 :]
685+ cachedir = os .path .join (cachedir , dirpart )
686+ cname = os .path .join (cachedir , f"{ basename } _code.cu" )
687+
688+ _check_src_hashes (comm , global_kernel )
689+
690+ if os .path .isfile (cname ):
691+ # Are we in the cache?
692+ with open (cname , "r" ) as f :
693+ code_to_compile = f .read ()
694+ else :
695+ # No, let"s go ahead and build
696+ if comm .rank == 0 :
697+ # No need to do this on all ranks
698+ os .makedirs (cachedir , exist_ok = True )
699+ with progress (INFO , "Compiling wrapper" ):
700+ # make sure that compiles successfully before writing to file
701+ code_to_compile = global_kernel .code_to_compile
702+ with open (cname , "w" ) as f :
703+ f .write (code_to_compile )
704+ comm .barrier ()
705+
706+ return code_to_compile
707+
708+
709+ @mpi .collective
710+ def get_prepared_cuda_function (comm , global_kernel ):
711+ from pycuda .compiler import SourceModule
712+
713+ # Determine cache key
714+ hsh = md5 (str (global_kernel .cache_key [1 :]).encode ())
715+ basename = hsh .hexdigest ()
716+ cachedir = configuration ["cache_dir" ]
717+ dirpart , basename = basename [:2 ], basename [2 :]
718+ cachedir = os .path .join (cachedir , dirpart )
719+
720+ nvcc_opts = ["-use_fast_math" , "-w" ]
721+
722+ code_to_compile = _get_code_to_compile (comm , global_kernel )
723+ source_module = SourceModule (code_to_compile , options = nvcc_opts ,
724+ cache_dir = cachedir )
725+
726+ cu_func = source_module .get_function (global_kernel .name )
727+
728+ type_map = {ctypes .c_void_p : "P" , ctypes .c_int : "i" }
729+ argtypes = "" .join (type_map [t ] for t in global_kernel .argtypes )
730+ cu_func .prepare (argtypes )
731+
732+ return cu_func
733+
734+
735+ @mpi .collective
736+ def get_opencl_kernel (comm , global_kernel ):
737+ import pyopencl as cl
738+ from pyop2 .backends .opencl import opencl_backend
739+ cl_ctx = opencl_backend .context
740+
741+ # Determine cache key
742+ hsh = md5 (str (global_kernel .cache_key [1 :]).encode ())
743+ basename = hsh .hexdigest ()
744+ cachedir = configuration ["cache_dir" ]
745+ dirpart , basename = basename [:2 ], basename [2 :]
746+ cachedir = os .path .join (cachedir , dirpart )
747+
748+ code_to_compile = _get_code_to_compile (comm , global_kernel )
749+
750+ prg = cl .Program (cl_ctx , code_to_compile ).build (options = [],
751+ cache_dir = cachedir )
752+
753+ cl_knl = cl .Kernel (prg , global_kernel .name )
754+ return cl_knl
0 commit comments