Skip to content

Commit 6d63a44

Browse files
authored
Merge pull request #1975 from bstatcomp/fix_kg_local_variables
make kernel generator declare local variables at top level of kernel function.
2 parents e5f00e2 + fb887d5 commit 6d63a44

File tree

4 files changed

+16
-8
lines changed

4 files changed

+16
-8
lines changed

stan/math/opencl/kernel_generator/colwise_reduction.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,10 @@ class colwise_reduction
101101
const bool view_handled,
102102
const std::string& var_name_arg) const {
103103
kernel_parts res;
104-
res.initialization = type_str<Scalar>() + " " + var_name_ + " = " + init_
105-
+ ";\n" "__local " + type_str<Scalar>() + " "
106-
+ var_name_ + "_local[LOCAL_SIZE_];\n";
104+
res.declarations = "__local " + type_str<Scalar>() + " " + var_name_
105+
+ "_local[LOCAL_SIZE_];\n";
106+
res.initialization
107+
= type_str<Scalar>() + " " + var_name_ + " = " + init_ + ";\n";
107108
res.body = var_name_ + " = " + var_name_arg + ";\n";
108109
res.reduction =
109110
var_name_ + "_local[lid_i] = " + var_name_ + ";\n"

stan/math/opencl/kernel_generator/multi_result_kernel.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ class results_cl {
343343
"const int i0 = lsize_i * wg_id_i;\n"
344344
"const int i = i0 + lid_i;\n"
345345
"const int j0 = lsize_i * wg_id_j;\n"
346+
+ parts.declarations +
346347
"for(int lid_j = 0; lid_j < min(cols - j0, lsize_i); lid_j++){\n"
347348
"const int j = j0 + lid_j;\n"
348349
+ parts.initialization +
@@ -360,6 +361,7 @@ class results_cl {
360361
"){\n"
361362
"int i = get_global_id(0);\n"
362363
"int j = get_global_id(1);\n"
364+
+ parts.declarations
363365
+ parts.initialization
364366
+ parts.body
365367
+ parts.reduction +

stan/math/opencl/kernel_generator/operation_cl.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace math {
3333
struct kernel_parts {
3434
std::string includes; // any function definitions - as if they were includet
3535
// at the start of kernel source
36+
std::string declarations; // declarations of any local variables
3637
std::string initialization; // the code for initializations done by all
3738
// threads, even if they have no work
3839
std::string body_prefix; // the code that should be placed at the start of
@@ -43,14 +44,18 @@ struct kernel_parts {
4344
std::string args; // kernel arguments
4445

4546
kernel_parts operator+(const kernel_parts& other) {
46-
return {
47-
includes + other.includes, initialization + other.initialization,
48-
body_prefix + other.body_prefix, body + other.body,
49-
reduction + other.reduction, args + other.args};
47+
return {includes + other.includes,
48+
declarations += other.declarations,
49+
initialization + other.initialization,
50+
body_prefix + other.body_prefix,
51+
body + other.body,
52+
reduction + other.reduction,
53+
args + other.args};
5054
}
5155

5256
kernel_parts operator+=(const kernel_parts& other) {
5357
includes += other.includes;
58+
declarations += other.declarations;
5459
initialization += other.initialization;
5560
body_prefix += other.body_prefix;
5661
body += other.body;

test/unit/math/opencl/kernel_generator/reference_kernels/colwise_sum.cl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ const int blocks_cols = (cols + lsize_i - 1) / lsize_i;
1010
const int i0 = lsize_i * wg_id_i;
1111
const int i = i0 + lid_i;
1212
const int j0 = lsize_i * wg_id_j;
13+
__local double var1_local[LOCAL_SIZE_];
1314
for(int lid_j = 0; lid_j < min(cols - j0, lsize_i); lid_j++){
1415
const int j = j0 + lid_j;
1516
double var1 = 0;
16-
__local double var1_local[LOCAL_SIZE_];
1717
if(i < rows){
1818
double var2 = 0; if (!((!contains_nonzero(var2_view, LOWER) && j < i) || (!contains_nonzero(var2_view, UPPER) && j > i))) {var2 = var2_global[i + var2_rows * j];}
1919
var1 = var2;

0 commit comments

Comments
 (0)