Skip to content

Commit a238442

Browse files
committed
cuda: Add support for constexpr array compilation inputs
1 parent 3e7b5a6 commit a238442

2 files changed

Lines changed: 26 additions & 0 deletions

File tree

backends/cuda/ceed-cuda-compile.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <cstdlib>
2323
#include <fstream>
24+
#include <iomanip>
2425
#include <iostream>
2526
#include <sstream>
2627
#include <string>
@@ -421,6 +422,29 @@ static int CeedCompileCore_Cuda(Ceed ceed, const char *source, const bool throw_
421422
return CEED_ERROR_SUCCESS;
422423
}
423424

425+
template <typename ArrayT>
426+
struct CeedArrayView {
427+
const ArrayT *array;
428+
CeedInt size;
429+
430+
CeedArrayView(const ArrayT *array_, CeedInt size_) : array(array_), size(size_) {}
431+
};
432+
433+
template <typename OStream, typename ArrayT>
434+
OStream &operator<<(OStream &ostream, const CeedArrayView<ArrayT> &view) {
435+
ostream << "{";
436+
for (CeedInt i = 0; i < view.size; i++) ostream << std::setprecision(17) << view.array[i] << (i == view.size - 1 ? "}" : ", ");
437+
return ostream;
438+
}
439+
440+
int CeedBuildArrayConstantSize_Cuda(Ceed ceed, const char *name, CeedInt length, const CeedSize *array, char **line) {
441+
std::ostringstream code;
442+
443+
code << "constexpr CeedSize " << name << "[" << length << "] = " << CeedArrayView<CeedSize>(array, length) << ";";
444+
CeedCallBackend(CeedStringAllocCopy(code.str().c_str(), line));
445+
return CEED_ERROR_SUCCESS;
446+
}
447+
424448
int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const CeedInt num_defines, ...) {
425449
bool is_compile_good = true;
426450
va_list args;

backends/cuda/ceed-cuda-compile.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
static inline CeedInt CeedDivUpInt(CeedInt numerator, CeedInt denominator) { return (numerator + denominator - 1) / denominator; }
1414

15+
CEED_INTERN int CeedBuildArrayConstantSize_Cuda(Ceed ceed, const char *name, CeedInt length, const CeedSize *array, char **line);
16+
1517
CEED_INTERN int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const CeedInt num_defines, ...);
1618
CEED_INTERN int CeedTryCompile_Cuda(Ceed ceed, const char *source, bool *is_compile_good, CUmodule *module, const CeedInt num_defines, ...);
1719

0 commit comments

Comments
 (0)