2121
2222#include < cstdlib>
2323#include < fstream>
24+ #include < iomanip>
2425#include < iostream>
2526#include < sstream>
2627#include < string>
@@ -69,7 +70,7 @@ using std::ofstream;
6970using std::ostringstream;
7071
7172static int CeedCompileCore_Cuda (Ceed ceed, const char *source, const bool throw_error, bool *is_compile_good, CUmodule *module ,
72- const CeedInt num_defines, va_list args) {
73+ const CeedInt num_defines, const CeedInt num_extra, va_list args) {
7374 size_t ptx_size;
7475 char *ptx;
7576 const int num_opts = 4 ;
@@ -106,6 +107,18 @@ static int CeedCompileCore_Cuda(Ceed ceed, const char *source, const bool throw_
106107 // Standard libCEED definitions for CUDA backends
107108 code << " #include <ceed/jit-source/cuda/cuda-jit.h>\n\n " ;
108109
110+ // Insert kernel specific extra code, e.g. array constants
111+ if (num_extra > 0 ) {
112+ char *val;
113+
114+ code << " // Kernel defined extra code\n " ;
115+ for (int i = 0 ; i < num_extra; i++) {
116+ val = va_arg (args, char *);
117+ code << val << " \n " ;
118+ }
119+ code << " \n " ;
120+ }
121+
109122 // Non-macro options
110123 CeedCallBackend (CeedCalloc (num_opts, &opts));
111124 opts[0 ] = " -default-device" ;
@@ -404,12 +417,46 @@ static int CeedCompileCore_Cuda(Ceed ceed, const char *source, const bool throw_
404417 return CEED_ERROR_SUCCESS;
405418}
406419
420+ template <typename ArrayT>
421+ struct CeedArrayView {
422+ const ArrayT *array;
423+ CeedInt size;
424+
425+ CeedArrayView<ArrayT>(const ArrayT *array_, CeedInt size_) : array(array_), size(size_) {}
426+ };
427+
428+ template <typename OStream, typename ArrayT>
429+ OStream &operator <<(OStream &ostream, const CeedArrayView<ArrayT> &view) {
430+ ostream << " {" ;
431+ for (CeedInt i = 0 ; i < view.size ; i++) ostream << std::setprecision (17 ) << view.array [i] << (i == view.size - 1 ? " }" : " , " );
432+ return ostream;
433+ }
434+
435+ int CeedBuildArrayConstantSize_Cuda (Ceed ceed, const char *name, CeedInt length, const CeedSize *array, char **line) {
436+ std::ostringstream code;
437+
438+ code << " constexpr CeedSize " << name << " [" << length << " ] = " << CeedArrayView<CeedSize>(array, length) << " ;" ;
439+ CeedCallBackend (CeedStringAllocCopy (code.str ().c_str (), line));
440+ return CEED_ERROR_SUCCESS;
441+ }
442+
443+ int CeedCompileExtra_Cuda (Ceed ceed, const char *source, CUmodule *module , const CeedInt num_defines, const CeedInt num_extra, ...) {
444+ bool is_compile_good = true ;
445+ va_list args;
446+
447+ va_start (args, num_extra);
448+ const CeedInt ierr = CeedCompileCore_Cuda (ceed, source, true , &is_compile_good, module , num_defines, num_extra, args);
449+ va_end (args);
450+ CeedCallBackend (ierr);
451+ return CEED_ERROR_SUCCESS;
452+ }
453+
407454int CeedCompile_Cuda (Ceed ceed, const char *source, CUmodule *module , const CeedInt num_defines, ...) {
408455 bool is_compile_good = true ;
409456 va_list args;
410457
411458 va_start (args, num_defines);
412- const CeedInt ierr = CeedCompileCore_Cuda (ceed, source, true , &is_compile_good, module , num_defines, args);
459+ const CeedInt ierr = CeedCompileCore_Cuda (ceed, source, true , &is_compile_good, module , num_defines, 0 , args);
413460
414461 va_end (args);
415462 CeedCallBackend (ierr);
@@ -420,7 +467,7 @@ int CeedTryCompile_Cuda(Ceed ceed, const char *source, bool *is_compile_good, CU
420467 va_list args;
421468
422469 va_start (args, num_defines);
423- const CeedInt ierr = CeedCompileCore_Cuda (ceed, source, false , is_compile_good, module , num_defines, args);
470+ const CeedInt ierr = CeedCompileCore_Cuda (ceed, source, false , is_compile_good, module , num_defines, 0 , args);
424471
425472 va_end (args);
426473 CeedCallBackend (ierr);
0 commit comments