@@ -473,13 +473,17 @@ static int CeedOperatorBuildKernelRestriction_Cuda_gen(std::ostringstream &code,
473473 CeedInt comp_stride;
474474
475475 CeedCallBackend (CeedElemRestrictionGetLVectorSize (elem_rstr, &l_size));
476+ code << tab << " if (e < num_elem) {\n " ;
477+ tab.push ();
476478 code << tab << " const CeedInt l_size" << var_suffix << " = " << l_size << " ;\n " ;
477479 CeedCallBackend (CeedElemRestrictionGetCompStride (elem_rstr, &comp_stride));
478- code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n " ;
480+ code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n\n " ;
479481 data->indices .outputs [i] = (CeedInt *)rstr_data->d_offsets ;
480482 code << tab << " WriteLVecStandard" << (is_all_tensor ? max_dim : 1 ) << " d<num_comp" << var_suffix << " , comp_stride" << var_suffix << " , "
481483 << P_name << " >(data, l_size" << var_suffix << " , elem, indices.outputs[" << i << " ], r_e" << var_suffix << " , d" << var_suffix
482484 << " );\n " ;
485+ tab.pop ();
486+ code << tab << " }\n " ;
483487 break ;
484488 }
485489 case CEED_RESTRICTION_STRIDED: {
@@ -493,11 +497,15 @@ static int CeedOperatorBuildKernelRestriction_Cuda_gen(std::ostringstream &code,
493497 if (!has_backend_strides) {
494498 CeedCallBackend (CeedElemRestrictionGetStrides (elem_rstr, strides));
495499 }
500+ code << tab << " if (e < num_elem) {\n " ;
501+ tab.push ();
496502 code << tab << " const CeedInt strides" << var_suffix << " _0 = " << strides[0 ] << " , strides" << var_suffix << " _1 = " << strides[1 ]
497- << " , strides" << var_suffix << " _2 = " << strides[2 ] << " ;\n " ;
503+ << " , strides" << var_suffix << " _2 = " << strides[2 ] << " ;\n\n " ;
498504 code << tab << " WriteLVecStrided" << (is_all_tensor ? max_dim : 1 ) << " d<num_comp" << var_suffix << " , " << P_name << " , strides"
499505 << var_suffix << " _0, strides" << var_suffix << " _1, strides" << var_suffix << " _2>(data, elem, r_e" << var_suffix << " , d" << var_suffix
500506 << " );\n " ;
507+ tab.pop ();
508+ code << tab << " }\n " ;
501509 break ;
502510 }
503511 case CEED_RESTRICTION_POINTS:
@@ -1033,10 +1041,14 @@ static int CeedOperatorBuildKernelQFunction_Cuda_gen(std::ostringstream &code, C
10331041 CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_output_fields[i], &elem_rstr));
10341042 CeedCallBackend (CeedElemRestrictionGetCompStride (elem_rstr, &comp_stride));
10351043 CeedCallBackend (CeedElemRestrictionDestroy (&elem_rstr));
1036- code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n " ;
1044+ code << tab << " if (e < num_elem) {\n " ;
1045+ tab.push ();
1046+ code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n\n " ;
10371047 code << tab << " WritePoint<num_comp" << var_suffix << " , comp_stride" << var_suffix
10381048 << " , max_num_points>(data, elem, i, points.num_per_elem[elem], indices.outputs[" << i << " ]"
10391049 << " , r_s" << var_suffix << " , d" << var_suffix << " );\n " ;
1050+ tab.pop ();
1051+ code << tab << " }\n " ;
10401052 break ;
10411053 }
10421054 case CEED_EVAL_INTERP:
@@ -1482,8 +1494,10 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_b
14821494 // Loop over all elements
14831495 code << " \n " << tab << " // Element loop\n " ;
14841496 code << tab << " __syncthreads();\n " ;
1485- code << tab << " for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x*blockDim.z) {\n " ;
1497+ code << tab << " const CeedInt elem_loop_bound = num_elem * ceil(1.0*num_elem/(gridDim.x*blockDim.z));\n\n " ;
1498+ code << tab << " for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < elem_loop_bound; e += gridDim.x*blockDim.z) {\n " ;
14861499 tab.push ();
1500+ code << tab << " const CeedInt elem = e % num_elem;\n\n " ;
14871501
14881502 // -- Compute minimum buffer space needed
14891503 CeedInt max_rstr_buffer_size = 1 ;
@@ -1848,8 +1862,10 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
18481862 // Loop over all elements
18491863 code << " \n " << tab << " // Element loop\n " ;
18501864 code << tab << " __syncthreads();\n " ;
1851- code << tab << " for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x*blockDim.z) {\n " ;
1865+ code << tab << " const CeedInt elem_loop_bound = num_elem * ceil(1.0*num_elem/(gridDim.x*blockDim.z));\n\n " ;
1866+ code << tab << " for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < elem_loop_bound; e += gridDim.x*blockDim.z) {\n " ;
18521867 tab.push ();
1868+ code << tab << " const CeedInt elem = e % num_elem;\n\n " ;
18531869
18541870 // -- Compute minimum buffer space needed
18551871 CeedInt max_rstr_buffer_size = 1 ;
@@ -2042,11 +2058,15 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
20422058
20432059 CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_output_fields[i], &elem_rstr));
20442060 CeedCallBackend (CeedElemRestrictionGetLVectorSize (elem_rstr, &l_size));
2061+ code << tab << " if (e < num_elem) {\n " ;
2062+ tab.push ();
20452063 code << tab << " const CeedInt l_size" << var_suffix << " = " << l_size << " ;\n " ;
20462064 CeedCallBackend (CeedElemRestrictionGetCompStride (elem_rstr, &comp_stride));
2047- code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n " ;
2065+ code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n\n " ;
20482066 code << tab << " WriteLVecStandard" << max_dim << " d_Assembly<num_comp" << var_suffix << " , comp_stride" << var_suffix << " , P_1d" + var_suffix
20492067 << " >(data, l_size" << var_suffix << " , elem, n, r_e" << var_suffix << " , values_array);\n " ;
2068+ tab.pop ();
2069+ code << tab << " }\n " ;
20502070 CeedCallBackend (CeedElemRestrictionDestroy (&elem_rstr));
20512071 } else {
20522072 std::string var_suffix = " _out_" + std::to_string (i);
@@ -2056,11 +2076,15 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
20562076
20572077 CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_output_fields[i], &elem_rstr));
20582078 CeedCallBackend (CeedElemRestrictionGetLVectorSize (elem_rstr, &l_size));
2079+ code << tab << " if (e < num_elem) {\n " ;
2080+ tab.push ();
20592081 code << tab << " const CeedInt l_size" << var_suffix << " = " << l_size << " ;\n " ;
20602082 CeedCallBackend (CeedElemRestrictionGetCompStride (elem_rstr, &comp_stride));
2061- code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n " ;
2083+ code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n\n " ;
20622084 code << tab << " WriteLVecStandard" << max_dim << " d_Single<num_comp" << var_suffix << " , comp_stride" << var_suffix << " , P_1d" + var_suffix
20632085 << " >(data, l_size" << var_suffix << " , elem, n, indices.outputs[" << i << " ], r_e" << var_suffix << " , values_array);\n " ;
2086+ tab.pop ();
2087+ code << tab << " }\n " ;
20642088 CeedCallBackend (CeedElemRestrictionDestroy (&elem_rstr));
20652089 }
20662090 }
@@ -2425,8 +2449,10 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Cuda_gen(CeedOpera
24252449 // Loop over all elements
24262450 code << " \n " << tab << " // Element loop\n " ;
24272451 code << tab << " __syncthreads();\n " ;
2428- code << tab << " for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x*blockDim.z) {\n " ;
2452+ code << tab << " const CeedInt elem_loop_bound = num_elem * ceil(1.0*num_elem/(gridDim.x*blockDim.z));\n\n " ;
2453+ code << tab << " for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < elem_loop_bound; e += gridDim.x*blockDim.z) {\n " ;
24292454 tab.push ();
2455+ code << tab << " const CeedInt elem = e % num_elem;\n\n " ;
24302456
24312457 // -- Compute minimum buffer space needed
24322458 CeedInt max_rstr_buffer_size = 1 ;
@@ -2642,8 +2668,12 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Cuda_gen(CeedOpera
26422668 // ---- Restriction
26432669 CeedInt field_size;
26442670
2671+ code << tab << " if (e < num_elem) {\n " ;
2672+ tab.push ();
26452673 code << tab << " WriteLVecStandard" << (is_all_tensor ? max_dim : 1 ) << " d_QFAssembly<total_size_out, field_size_out_" << i << " , "
26462674 << (is_all_tensor ? " Q_1d" : " Q" ) << " >(data, num_elem, elem, input_offset + s, " << offset << " , r_q_out_" << i << " , values_array);\n " ;
2675+ tab.pop ();
2676+ code << tab << " }\n " ;
26472677 CeedCallBackend (CeedQFunctionFieldGetSize (qf_output_fields[i], &field_size));
26482678 offset += field_size;
26492679 }
0 commit comments