Skip to content

Commit d9b3770

Browse files
committed
cuda: Add support for constexpr array compilation inputs
1 parent 171bc85 commit d9b3770

2 files changed

Lines changed: 53 additions & 3 deletions

File tree

backends/cuda/ceed-cuda-compile.cpp

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <cstdlib>
2323
#include <fstream>
24+
#include <iomanip>
2425
#include <iostream>
2526
#include <sstream>
2627
#include <string>
@@ -69,7 +70,7 @@ using std::ofstream;
6970
using std::ostringstream;
7071

7172
static int CeedCompileCore_Cuda(Ceed ceed, const char *source, const bool throw_error, bool *is_compile_good, CUmodule *module,
72-
const CeedInt num_defines, va_list args) {
73+
const CeedInt num_defines, const CeedInt num_extra, va_list args) {
7374
size_t ptx_size;
7475
char *ptx;
7576
const int num_opts = 4;
@@ -106,6 +107,18 @@ static int CeedCompileCore_Cuda(Ceed ceed, const char *source, const bool throw_
106107
// Standard libCEED definitions for CUDA backends
107108
code << "#include <ceed/jit-source/cuda/cuda-jit.h>\n\n";
108109

110+
// Insert kernel specific extra code, e.g. array constants
111+
if (num_extra > 0) {
112+
char *val;
113+
114+
code << "// Kernel defined extra code\n";
115+
for (int i = 0; i < num_extra; i++) {
116+
val = va_arg(args, char *);
117+
code << val << "\n";
118+
}
119+
code << "\n";
120+
}
121+
109122
// Non-macro options
110123
CeedCallBackend(CeedCalloc(num_opts, &opts));
111124
opts[0] = "-default-device";
@@ -404,12 +417,46 @@ static int CeedCompileCore_Cuda(Ceed ceed, const char *source, const bool throw_
404417
return CEED_ERROR_SUCCESS;
405418
}
406419

420+
template <typename ArrayT>
421+
struct CeedArrayView {
422+
const ArrayT *array;
423+
CeedInt size;
424+
425+
CeedArrayView<ArrayT>(const ArrayT *array_, CeedInt size_) : array(array_), size(size_) {}
426+
};
427+
428+
template <typename OStream, typename ArrayT>
429+
OStream &operator<<(OStream &ostream, const CeedArrayView<ArrayT> &view) {
430+
ostream << "{";
431+
for (CeedInt i = 0; i < view.size; i++) ostream << std::setprecision(17) << view.array[i] << (i == view.size - 1 ? "}" : ", ");
432+
return ostream;
433+
}
434+
435+
int CeedBuildArrayConstantSize_Cuda(Ceed ceed, const char *name, CeedInt length, const CeedSize *array, char **line) {
436+
std::ostringstream code;
437+
438+
code << "constexpr CeedSize " << name << "[" << length << "] = " << CeedArrayView<CeedSize>(array, length) << ";";
439+
CeedCallBackend(CeedStringAllocCopy(code.str().c_str(), line));
440+
return CEED_ERROR_SUCCESS;
441+
}
442+
443+
int CeedCompileExtra_Cuda(Ceed ceed, const char *source, CUmodule *module, const CeedInt num_defines, const CeedInt num_extra, ...) {
444+
bool is_compile_good = true;
445+
va_list args;
446+
447+
va_start(args, num_extra);
448+
const CeedInt ierr = CeedCompileCore_Cuda(ceed, source, true, &is_compile_good, module, num_defines, num_extra, args);
449+
va_end(args);
450+
CeedCallBackend(ierr);
451+
return CEED_ERROR_SUCCESS;
452+
}
453+
407454
int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const CeedInt num_defines, ...) {
408455
bool is_compile_good = true;
409456
va_list args;
410457

411458
va_start(args, num_defines);
412-
const CeedInt ierr = CeedCompileCore_Cuda(ceed, source, true, &is_compile_good, module, num_defines, args);
459+
const CeedInt ierr = CeedCompileCore_Cuda(ceed, source, true, &is_compile_good, module, num_defines, 0, args);
413460

414461
va_end(args);
415462
CeedCallBackend(ierr);
@@ -420,7 +467,7 @@ int CeedTryCompile_Cuda(Ceed ceed, const char *source, bool *is_compile_good, CU
420467
va_list args;
421468

422469
va_start(args, num_defines);
423-
const CeedInt ierr = CeedCompileCore_Cuda(ceed, source, false, is_compile_good, module, num_defines, args);
470+
const CeedInt ierr = CeedCompileCore_Cuda(ceed, source, false, is_compile_good, module, num_defines, 0, args);
424471

425472
va_end(args);
426473
CeedCallBackend(ierr);

backends/cuda/ceed-cuda-compile.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
static inline CeedInt CeedDivUpInt(CeedInt numerator, CeedInt denominator) { return (numerator + denominator - 1) / denominator; }
1414

15+
CEED_INTERN int CeedBuildArrayConstantSize_Cuda(Ceed ceed, const char *name, CeedInt length, const CeedSize *array, char **line);
16+
17+
CEED_INTERN int CeedCompileExtra_Cuda(Ceed ceed, const char *source, CUmodule *module, const CeedInt num_defines, const CeedInt num_extra, ...);
1518
CEED_INTERN int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const CeedInt num_defines, ...);
1619
CEED_INTERN int CeedTryCompile_Cuda(Ceed ceed, const char *source, bool *is_compile_good, CUmodule *module, const CeedInt num_defines, ...);
1720

0 commit comments

Comments
 (0)