Skip to content

Commit f80f6cc

Browse files
authored
Merge pull request #945 from CEED/jeremy/more-quoted-kernels
Remove 'quoted' operator assembly kernels
2 parents d92ccc1 + 07b31e0 commit f80f6cc

6 files changed

Lines changed: 606 additions & 554 deletions

File tree

backends/cuda-ref/ceed-cuda-ref-operator.c

Lines changed: 38 additions & 277 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <ceed/ceed.h>
99
#include <ceed/backend.h>
10+
#include <ceed/jit-tools.h>
1011
#include <assert.h>
1112
#include <cuda.h>
1213
#include <cuda_runtime.h>
@@ -719,149 +720,6 @@ static int CeedOperatorLinearAssembleQFunctionUpdate_Cuda(CeedOperator op,
719720
&rstr, request);
720721
}
721722

722-
//------------------------------------------------------------------------------
723-
// Diagonal assembly kernels
724-
//------------------------------------------------------------------------------
725-
// *INDENT-OFF*
726-
static const char *diagonalkernels = QUOTE(
727-
728-
// Evaluation-mode flags used by the diagonal assembly kernels in this quoted
// source. CEED_EVAL_NONE is 0; every other mode is a distinct power of two.
typedef enum {
729-
/// Perform no evaluation (either because there is no data or it is already at
730-
/// quadrature points)
731-
CEED_EVAL_NONE = 0,
732-
/// Interpolate from nodes to quadrature points
733-
CEED_EVAL_INTERP = 1,
734-
/// Evaluate gradients at quadrature points from input in a nodal basis
735-
CEED_EVAL_GRAD = 2,
736-
/// Evaluate divergence at quadrature points from input in a nodal basis
737-
CEED_EVAL_DIV = 4,
738-
/// Evaluate curl at quadrature points from input in a nodal basis
739-
CEED_EVAL_CURL = 8,
740-
/// Using no input, evaluate quadrature weights on the reference element
741-
CEED_EVAL_WEIGHT = 16,
742-
} CeedEvalMode;
743-
744-
//------------------------------------------------------------------------------
745-
// Get Basis Emode Pointer
746-
//------------------------------------------------------------------------------
747-
// Select the basis matrix that corresponds to an evaluation mode.
//   basisptr  [out] set to identity, interp or grad for NONE, INTERP, GRAD
//   emode     [in]  evaluation mode to look up
//   identity, interp, grad [in] candidate matrices; none is dereferenced here
// For WEIGHT, DIV and CURL *basisptr is left untouched, so the caller keeps
// whatever value it stored there before the call.
extern "C" __device__ void CeedOperatorGetBasisPointer_Cuda(const CeedScalar **basisptr,
748-
CeedEvalMode emode, const CeedScalar *identity, const CeedScalar *interp,
749-
const CeedScalar *grad) {
750-
switch (emode) {
751-
case CEED_EVAL_NONE:
752-
*basisptr = identity;
753-
break;
754-
case CEED_EVAL_INTERP:
755-
*basisptr = interp;
756-
break;
757-
case CEED_EVAL_GRAD:
758-
*basisptr = grad;
759-
break;
760-
case CEED_EVAL_WEIGHT:
761-
case CEED_EVAL_DIV:
762-
case CEED_EVAL_CURL:
763-
break; // Caught by QF Assembly
764-
}
765-
}
766-
767-
//------------------------------------------------------------------------------
768-
// Core code for diagonal assembly
769-
//------------------------------------------------------------------------------
770-
// Shared body of the diagonal and point-block diagonal kernels.
// For every element handled by this thread and every (eout, ein) eval-mode
// pair, it sums bt[q][tid] * qfvalue * b[q][tid] over the quadrature points
// and adds the sum into elemdiagarray.
//   nelem            number of elements (loop bound and array stride)
//   maxnorm          scale for the drop tolerance (see qfvaluebound below)
//   pointBlock       true: write NCOMP x NCOMP blocks; false: diagonal only
//   identity, interp*, grad*  basis matrices read as [q*NNODES + node]
//   emodein/emodeout eval modes, NUMEMODEIN / NUMEMODEOUT entries are read
//   assembledqfarray assembled QFunction values (read only)
//   elemdiagarray    output, accumulated with += (never zeroed here)
// NUMEMODEIN, NUMEMODEOUT, NNODES, NQPTS and NCOMP are not defined in this
// block; presumably they are compile-time defines supplied at JIT time.
// NOTE(review): confirm against the CeedCompileCuda call.
__device__ void diagonalCore(const CeedInt nelem,
771-
const CeedScalar maxnorm, const bool pointBlock,
772-
const CeedScalar *identity,
773-
const CeedScalar *interpin, const CeedScalar *gradin,
774-
const CeedScalar *interpout, const CeedScalar *gradout,
775-
const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
776-
const CeedScalar *__restrict__ assembledqfarray,
777-
CeedScalar *__restrict__ elemdiagarray) {
778-
const int tid = threadIdx.x; // running with P threads, tid is evec node
779-
// QFunction entries with magnitude at or below this bound are skipped
const CeedScalar qfvaluebound = maxnorm*1e-12;
780-
781-
// Compute the diagonal of B^T D B
782-
// Each element
783-
for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem;
784-
e += gridDim.x*blockDim.z) {
785-
// Index of the current gradient slab; advanced once per GRAD mode seen
CeedInt dout = -1;
786-
// Each basis eval mode pair
787-
for (CeedInt eout = 0; eout < NUMEMODEOUT; eout++) {
788-
const CeedScalar *bt = NULL;
789-
if (emodeout[eout] == CEED_EVAL_GRAD)
790-
dout += 1;
791-
// The gradient address is only selected when the mode is GRAD, by which
// point dout >= 0; while dout is still -1 the address is formed but unused
CeedOperatorGetBasisPointer_Cuda(&bt, emodeout[eout], identity, interpout,
792-
&gradout[dout*NQPTS*NNODES]);
793-
CeedInt din = -1;
794-
for (CeedInt ein = 0; ein < NUMEMODEIN; ein++) {
795-
const CeedScalar *b = NULL;
796-
if (emodein[ein] == CEED_EVAL_GRAD)
797-
din += 1;
798-
CeedOperatorGetBasisPointer_Cuda(&b, emodein[ein], identity, interpin,
799-
&gradin[din*NQPTS*NNODES]);
800-
// Each component
801-
for (CeedInt compOut = 0; compOut < NCOMP; compOut++) {
802-
// Each qpoint/node pair
803-
if (pointBlock) {
804-
// Point Block Diagonal
805-
for (CeedInt compIn = 0; compIn < NCOMP; compIn++) {
806-
CeedScalar evalue = 0.;
807-
for (CeedInt q = 0; q < NQPTS; q++) {
808-
const CeedScalar qfvalue =
809-
assembledqfarray[((((ein*NCOMP+compIn)*NUMEMODEOUT+eout)*
810-
NCOMP+compOut)*nelem+e)*NQPTS+q];
811-
if (abs(qfvalue) > qfvaluebound)
812-
evalue += bt[q*NNODES+tid] * qfvalue * b[q*NNODES+tid];
813-
}
814-
// One output entry per (compOut, compIn, element, node)
elemdiagarray[((compOut*NCOMP+compIn)*nelem+e)*NNODES+tid] += evalue;
815-
}
816-
} else {
817-
// Diagonal Only
818-
CeedScalar evalue = 0.;
819-
for (CeedInt q = 0; q < NQPTS; q++) {
820-
const CeedScalar qfvalue =
821-
// Input and output component are both compOut here
assembledqfarray[((((ein*NCOMP+compOut)*NUMEMODEOUT+eout)*
822-
NCOMP+compOut)*nelem+e)*NQPTS+q];
823-
if (abs(qfvalue) > qfvaluebound)
824-
evalue += bt[q*NNODES+tid] * qfvalue * b[q*NNODES+tid];
825-
}
826-
// One output entry per (compOut, element, node)
elemdiagarray[(compOut*nelem+e)*NNODES+tid] += evalue;
827-
}
828-
}
829-
}
830-
}
831-
}
832-
}
833-
834-
//------------------------------------------------------------------------------
835-
// Linear diagonal
836-
//------------------------------------------------------------------------------
837-
// Kernel entry point for the plain diagonal: forwards every argument to
// diagonalCore with pointBlock = false.
extern "C" __global__ void linearDiagonal(const CeedInt nelem,
838-
const CeedScalar maxnorm, const CeedScalar *identity,
839-
const CeedScalar *interpin, const CeedScalar *gradin,
840-
const CeedScalar *interpout, const CeedScalar *gradout,
841-
const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
842-
const CeedScalar *__restrict__ assembledqfarray,
843-
CeedScalar *__restrict__ elemdiagarray) {
844-
diagonalCore(nelem, maxnorm, false, identity, interpin, gradin, interpout,
845-
gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
846-
}
847-
848-
//------------------------------------------------------------------------------
849-
// Linear point block diagonal
850-
//------------------------------------------------------------------------------
851-
// Kernel entry point for the point-block diagonal: forwards every argument to
// diagonalCore with pointBlock = true.
extern "C" __global__ void linearPointBlockDiagonal(const CeedInt nelem,
852-
const CeedScalar maxnorm, const CeedScalar *identity,
853-
const CeedScalar *interpin, const CeedScalar *gradin,
854-
const CeedScalar *interpout, const CeedScalar *gradout,
855-
const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
856-
const CeedScalar *__restrict__ assembledqfarray,
857-
CeedScalar *__restrict__ elemdiagarray) {
858-
diagonalCore(nelem, maxnorm, true, identity, interpin, gradin, interpout,
859-
gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
860-
}
861-
862-
);
863-
// *INDENT-ON*
864-
865723
//------------------------------------------------------------------------------
866724
// Create point block restriction
867725
//------------------------------------------------------------------------------
@@ -1027,11 +885,21 @@ static inline int CeedOperatorAssembleDiagonalSetup_Cuda(CeedOperator op,
1027885
diag->numemodeout = numemodeout;
1028886

1029887
// Assemble kernel
888+
char *diagonal_kernel_path, *diagonal_kernel_source;
889+
ierr = CeedGetJitAbsolutePath(ceed,
890+
"ceed/jit-source/cuda/cuda-ref-operator-assemble-diagonal.h",
891+
&diagonal_kernel_path); CeedChkBackend(ierr);
892+
CeedDebug256(ceed, 2, "----- Loading Diagonal Assembly Kernel Source -----\n");
893+
ierr = CeedLoadSourceToBuffer(ceed, diagonal_kernel_path,
894+
&diagonal_kernel_source);
895+
CeedChkBackend(ierr);
896+
CeedDebug256(ceed, 2,
897+
"----- Loading Diagonal Assembly Source Complete! -----\n");
1030898
CeedInt nnodes, nqpts;
1031899
ierr = CeedBasisGetNumNodes(basisin, &nnodes); CeedChkBackend(ierr);
1032900
ierr = CeedBasisGetNumQuadraturePoints(basisin, &nqpts); CeedChkBackend(ierr);
1033901
diag->nnodes = nnodes;
1034-
ierr = CeedCompileCuda(ceed, diagonalkernels, &diag->module, 5,
902+
ierr = CeedCompileCuda(ceed, diagonal_kernel_source, &diag->module, 5,
1035903
"NUMEMODEIN", numemodein,
1036904
"NUMEMODEOUT", numemodeout,
1037905
"NNODES", nnodes,
@@ -1043,6 +911,8 @@ static inline int CeedOperatorAssembleDiagonalSetup_Cuda(CeedOperator op,
1043911
ierr = CeedGetKernelCuda(ceed, diag->module, "linearPointBlockDiagonal",
1044912
&diag->linearPointBlock);
1045913
CeedChk_Cu(ceed, ierr);
914+
ierr = CeedFree(&diagonal_kernel_path); CeedChkBackend(ierr);
915+
ierr = CeedFree(&diagonal_kernel_source); CeedChkBackend(ierr);
1046916

1047917
// Basis matrices
1048918
const CeedInt qBytes = nqpts * sizeof(CeedScalar);
@@ -1246,119 +1116,6 @@ static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda(CeedOperator op,
12461116
}
12471117
}
12481118

1249-
//------------------------------------------------------------------------------
1250-
// Matrix assembly kernel for low-order elements (2D thread block)
1251-
//------------------------------------------------------------------------------
1252-
// *INDENT-OFF*
1253-
static const char *assemblykernel = QUOTE(
1254-
// Element matrix assembly with one thread per (row, column) entry.
// For each element and each (comp_in, comp_out) pair the thread sums
// B_out[j][i] * qf_array[...] * B_in[j][l] over all eval-mode pairs and all
// quadrature points j, then stores the sum in values_array.
//   B_in, B_out   basis matrices, one NQPTS x NNODES slab per eval mode
//   qf_array      assembled QFunction values (read only)
//   values_array  output; each entry is overwritten, not accumulated
// Row i comes from threadIdx.x, column l from threadIdx.y, and the element
// index from blockIdx.x, blockDim.z and threadIdx.z.
// NELEM, NNODES, NQPTS, NCOMP, NUMEMODEIN, NUMEMODEOUT and BLOCK_SIZE are not
// defined in this block; presumably they are JIT-time defines.
// NOTE(review): confirm against the CeedCompileCuda call.
extern "C" __launch_bounds__(BLOCK_SIZE)
1255-
__global__ void linearAssemble(const CeedScalar *B_in, const CeedScalar *B_out,
1256-
const CeedScalar *__restrict__ qf_array,
1257-
CeedScalar *__restrict__ values_array) {
1258-
1259-
// This kernel assumes B_in and B_out have the same number of quadrature points and
1260-
// basis points.
1261-
// TODO: expand to more general cases
1262-
const int i = threadIdx.x; // The output row index of each B^TDB operation
1263-
const int l = threadIdx.y; // The output column index of each B^TDB operation
1264-
// such that we have (Bout^T)_ij D_jk Bin_kl = C_il
1265-
1266-
// Strides for final output ordering, determined by the reference (interface) implementation of
1267-
// the symbolic assembly, slowest --> fastest: element, comp_in, comp_out, node_row, node_col
1268-
const CeedInt comp_out_stride = NNODES * NNODES;
1269-
const CeedInt comp_in_stride = comp_out_stride * NCOMP;
1270-
const CeedInt e_stride = comp_in_stride * NCOMP;
1271-
// Strides for QF array, slowest --> fastest: emode_in, comp_in, emode_out, comp_out, elem, qpt
1272-
const CeedInt qe_stride = NQPTS;
1273-
const CeedInt qcomp_out_stride = NELEM * qe_stride;
1274-
const CeedInt qemode_out_stride = qcomp_out_stride * NCOMP;
1275-
const CeedInt qcomp_in_stride = qemode_out_stride * NUMEMODEOUT;
1276-
const CeedInt qemode_in_stride = qcomp_in_stride * NCOMP;
1277-
1278-
// Loop over each element (if necessary)
1279-
for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < NELEM;
1280-
e += gridDim.x*blockDim.z) {
1281-
for (CeedInt comp_in = 0; comp_in < NCOMP; comp_in++) {
1282-
for (CeedInt comp_out = 0; comp_out < NCOMP; comp_out++) {
1283-
CeedScalar result = 0.0;
1284-
// Part of the QF offset that does not depend on the eval-mode pair
CeedInt qf_index_comp = qcomp_in_stride * comp_in + qcomp_out_stride * comp_out + qe_stride * e;
1285-
for (CeedInt emode_in = 0; emode_in < NUMEMODEIN; emode_in++) {
1286-
CeedInt b_in_index = emode_in * NQPTS * NNODES;
1287-
for (CeedInt emode_out = 0; emode_out < NUMEMODEOUT; emode_out++) {
1288-
CeedInt b_out_index = emode_out * NQPTS * NNODES;
1289-
CeedInt qf_index = qf_index_comp + qemode_out_stride * emode_out + qemode_in_stride * emode_in;
1290-
// Perform the B^T D B operation for this 'chunk' of D (the qf_array)
1291-
for (CeedInt j = 0; j < NQPTS; j++) {
1292-
result += B_out[b_out_index + j * NNODES + i] * qf_array[qf_index + j] * B_in[b_in_index + j * NNODES + l];
1293-
}
1294-
1295-
}// end of emode_out
1296-
} // end of emode_in
1297-
// Entry (row i, column l) of this element's (comp_in, comp_out) block
CeedInt val_index = comp_in_stride * comp_in + comp_out_stride * comp_out + e_stride * e + NNODES * i + l;
1298-
values_array[val_index] = result;
1299-
} // end of out component
1300-
} // end of in component
1301-
} // end of element loop
1302-
}
1303-
);
1304-
1305-
//------------------------------------------------------------------------------
1306-
// Fallback kernel for larger orders (1D thread block)
1307-
//------------------------------------------------------------------------------
1308-
static const char *assemblykernelbigelem = QUOTE(
1309-
// Element matrix assembly with one thread per output column.
// Same arithmetic and output layout as the two-index variant, but only the
// column l comes from a thread index (threadIdx.x); the row i is walked by an
// explicit loop over 0..NNODES-1 inside each thread.
//   B_in, B_out   basis matrices, one NQPTS x NNODES slab per eval mode
//   qf_array      assembled QFunction values (read only)
//   values_array  output; each entry is overwritten, not accumulated
// NELEM, NNODES, NQPTS, NCOMP, NUMEMODEIN, NUMEMODEOUT and BLOCK_SIZE are not
// defined in this block; presumably they are JIT-time defines.
// NOTE(review): confirm against the CeedCompileCuda call.
extern "C" __launch_bounds__(BLOCK_SIZE)
1309-
__global__ void linearAssemble(const CeedScalar *B_in, const CeedScalar *B_out,
1310-
const CeedScalar *__restrict__ qf_array,
1311-
CeedScalar *__restrict__ values_array) {
1312-
1313-
// This kernel assumes B_in and B_out have the same number of quadrature points and
1314-
// basis points.
1315-
// TODO: expand to more general cases
1316-
const int l = threadIdx.x; // The output column index of each B^TDB operation
1317-
// such that we have (Bout^T)_ij D_jk Bin_kl = C_il
1318-
1319-
// Strides for final output ordering, determined by the reference (interface) implementation of
1320-
// the symbolic assembly, slowest --> fastest: element, comp_in, comp_out, node_row, node_col
1321-
const CeedInt comp_out_stride = NNODES * NNODES;
1322-
const CeedInt comp_in_stride = comp_out_stride * NCOMP;
1323-
const CeedInt e_stride = comp_in_stride * NCOMP;
1324-
// Strides for QF array, slowest --> fastest: emode_in, comp_in, emode_out, comp_out, elem, qpt
1325-
const CeedInt qe_stride = NQPTS;
1326-
const CeedInt qcomp_out_stride = NELEM * qe_stride;
1327-
const CeedInt qemode_out_stride = qcomp_out_stride * NCOMP;
1328-
const CeedInt qcomp_in_stride = qemode_out_stride * NUMEMODEOUT;
1329-
const CeedInt qemode_in_stride = qcomp_in_stride * NCOMP;
1330-
1331-
// Loop over each element (if necessary)
1332-
for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < NELEM;
1333-
e += gridDim.x*blockDim.z) {
1334-
for (CeedInt comp_in = 0; comp_in < NCOMP; comp_in++) {
1335-
for (CeedInt comp_out = 0; comp_out < NCOMP; comp_out++) {
1336-
// Row index is iterated serially by each thread in this variant
for (CeedInt i = 0; i < NNODES; i++) {
1337-
CeedScalar result = 0.0;
1338-
CeedInt qf_index_comp = qcomp_in_stride * comp_in + qcomp_out_stride * comp_out + qe_stride * e;
1339-
for (CeedInt emode_in = 0; emode_in < NUMEMODEIN; emode_in++) {
1340-
CeedInt b_in_index = emode_in * NQPTS * NNODES;
1341-
for (CeedInt emode_out = 0; emode_out < NUMEMODEOUT; emode_out++) {
1342-
CeedInt b_out_index = emode_out * NQPTS * NNODES;
1343-
CeedInt qf_index = qf_index_comp + qemode_out_stride * emode_out + qemode_in_stride * emode_in;
1344-
// Perform the B^T D B operation for this 'chunk' of D (the qf_array)
1345-
for (CeedInt j = 0; j < NQPTS; j++) {
1346-
result += B_out[b_out_index + j * NNODES + i] * qf_array[qf_index + j] * B_in[b_in_index + j * NNODES + l];
1347-
}
1348-
1349-
}// end of emode_out
1350-
} // end of emode_in
1351-
// Entry (row i, column l) of this element's (comp_in, comp_out) block
CeedInt val_index = comp_in_stride * comp_in + comp_out_stride * comp_out + e_stride * e + NNODES * i + l;
1352-
values_array[val_index] = result;
1353-
} // end of loop over element node index, i
1354-
} // end of out component
1355-
} // end of in component
1356-
} // end of element loop
1357-
}
1359-
);
1360-
// *INDENT-ON*
1361-
13621119
//------------------------------------------------------------------------------
13631120
// Single operator assembly setup
13641121
//------------------------------------------------------------------------------
@@ -1482,35 +1239,39 @@ static int CeedSingleOperatorAssembleSetup_Cuda(CeedOperator op) {
14821239
CeedInt block_size = esize * esize * elemsPerBlock;
14831240
Ceed_Cuda *cuda_data;
14841241
ierr = CeedGetData(ceed, &cuda_data); CeedChkBackend(ierr);
1485-
if (block_size > cuda_data->device_prop.maxThreadsPerBlock) {
1242+
char *assembly_kernel_path, *assembly_kernel_source;
1243+
ierr = CeedGetJitAbsolutePath(ceed,
1244+
"ceed/jit-source/cuda/cuda-ref-operator-assemble.h",
1245+
&assembly_kernel_path); CeedChkBackend(ierr);
1246+
CeedDebug256(ceed, 2, "----- Loading Assembly Kernel Source -----\n");
1247+
ierr = CeedLoadSourceToBuffer(ceed, assembly_kernel_path,
1248+
&assembly_kernel_source);
1249+
CeedChkBackend(ierr);
1250+
CeedDebug256(ceed, 2, "----- Loading Assembly Source Complete! -----\n");
1251+
bool fallback = block_size > cuda_data->device_prop.maxThreadsPerBlock;
1252+
if (fallback) {
14861253
// Use fallback kernel with 1D threadblock
14871254
block_size = esize * elemsPerBlock;
14881255
asmb->block_size_x = esize;
14891256
asmb->block_size_y = 1;
1490-
ierr = CeedCompileCuda(ceed, assemblykernelbigelem, &asmb->module, 7,
1491-
"NELEM", nelem,
1492-
"NUMEMODEIN", num_emode_in,
1493-
"NUMEMODEOUT", num_emode_out,
1494-
"NQPTS", nqpts,
1495-
"NNODES", esize,
1496-
"BLOCK_SIZE", block_size,
1497-
"NCOMP", ncomp
1498-
); CeedChk_Cu(ceed, ierr);
14991257
} else { // Use kernel with 2D threadblock
15001258
asmb->block_size_x = esize;
15011259
asmb->block_size_y = esize;
1502-
ierr = CeedCompileCuda(ceed, assemblykernel, &asmb->module, 7,
1503-
"NELEM", nelem,
1504-
"NUMEMODEIN", num_emode_in,
1505-
"NUMEMODEOUT", num_emode_out,
1506-
"NQPTS", nqpts,
1507-
"NNODES", esize,
1508-
"BLOCK_SIZE", block_size,
1509-
"NCOMP", ncomp
1510-
); CeedChk_Cu(ceed, ierr);
15111260
}
1512-
ierr = CeedGetKernelCuda(ceed, asmb->module, "linearAssemble",
1261+
ierr = CeedCompileCuda(ceed, assembly_kernel_source, &asmb->module, 7,
1262+
"NELEM", nelem,
1263+
"NUMEMODEIN", num_emode_in,
1264+
"NUMEMODEOUT", num_emode_out,
1265+
"NQPTS", nqpts,
1266+
"NNODES", esize,
1267+
"BLOCK_SIZE", block_size,
1268+
"NCOMP", ncomp
1269+
); CeedChk_Cu(ceed, ierr);
1270+
ierr = CeedGetKernelCuda(ceed, asmb->module,
1271+
fallback ? "linearAssembleFallback" : "linearAssemble",
15131272
&asmb->linearAssemble); CeedChk_Cu(ceed, ierr);
1273+
ierr = CeedFree(&assembly_kernel_path); CeedChkBackend(ierr);
1274+
ierr = CeedFree(&assembly_kernel_source); CeedChkBackend(ierr);
15141275

15151276
// Build 'full' B matrices (not 1D arrays used for tensor-product matrices)
15161277
const CeedScalar *interp_in, *grad_in;

0 commit comments

Comments
 (0)