Skip to content

Commit c5138d8

Browse files
committed
feat: use Grid-Stride loop to reduce threads num
1 parent 7da6874 commit c5138d8

1 file changed

Lines changed: 149 additions & 141 deletions

File tree

source/lib/src/gpu/tabulate.cu

Lines changed: 149 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
662662
dz_dy[block_idx * last_layer_size + thread_idx] = sum;
663663
}
664664

665+
// Apply Grid-Stride Loop
665666
template <typename FPTYPE, int MTILE, int KTILE>
666667
__global__ void tabulate_fusion_se_t_tebd_fifth_order_polynomial(
667668
FPTYPE* out,
@@ -675,52 +676,51 @@ __global__ void tabulate_fusion_se_t_tebd_fifth_order_polynomial(
675676
const FPTYPE stride1,
676677
const int nnei_i,
677678
const int nnei_j,
678-
const int last_layer_size) {
679+
const int last_layer_size,
680+
const int_64 total_work) {
679681
// NOT USED: em: (nfnl, nnei_i, nnei_j)
680682
// em_x: (nfnl * nnei_i * nnei_j, 1) flat version of em
681-
// blockDim.x -> total threads in a block: nnei_i * nnei_j
682-
// gridDim.x -> total blocks in a grid: nloc
683-
684-
// Identify which atom and neighbor pair this thread is responsible for.
685-
// block_idx corresponds to the atom index, given by the block index.
686-
const int_64 block_idx = blockIdx.x;
687-
688-
// thread_idx is the flattened 1D index for the (ii, jj) neighbor pair.
689-
const int_64 thread_idx = threadIdx.x;
690-
691-
// Recover the 2D (ii, jj) indices from the 1D thread index.
692-
// thread_idx = ii * nnei_j + jj
693-
const int_64 ii = thread_idx / nnei_j;
694-
const int_64 jj = thread_idx % nnei_j;
695-
696-
// Read the input value xx for this specific neighbor pair.
697-
const int_64 em_x_idx = (int_64)block_idx * nnei_i * nnei_j + thread_idx;
698-
FPTYPE xx = em_x[em_x_idx];
699-
700-
// Determine the table index based on the value of xx.
701-
int table_idx = 0;
702-
locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0, stride1);
683+
// total_work = nloc * nnei_i * nnei_j
684+
// Grid-Stride Loop
685+
for (int_64 i = (int_64)blockIdx.x * blockDim.x + threadIdx.x; i < total_work;
686+
i += (int_64)gridDim.x * blockDim.x) {
687+
// Decompose the 1D index 'i' to get atom and neighbor indices
688+
const int_64 block_idx = i / (nnei_i * nnei_j);
689+
const int_64 local_idx = i % (nnei_i * nnei_j);
690+
const int_64 ii = local_idx / nnei_j;
691+
const int_64 jj = local_idx % nnei_j;
692+
693+
// Read the input value xx for this specific neighbor pair.
694+
FPTYPE xx = em_x[i];
695+
696+
// Determine the table index based on the value of xx.
697+
int table_idx = 0;
698+
locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0,
699+
stride1);
703700

704-
// Serially loop through the 'last_layer_size' dimension to calculate all
705-
// features.
706-
for (int idx = 0; idx < last_layer_size; idx++) {
707-
FPTYPE var[6];
708-
load_polynomial_params(var, table, table_idx, idx, last_layer_size);
709-
FPTYPE res =
710-
var[0] +
711-
(var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
712-
xx;
713-
// Calculate the unique 1D output index for the 4D tensor (block_idx, ii,
714-
// jj, idx).
715-
const int_64 out_idx =
716-
(int_64)block_idx * nnei_i * nnei_j * last_layer_size +
717-
(int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
718-
idx;
719-
// Write the result to the global output memory.
720-
out[out_idx] = res;
701+
// Serially loop through the 'last_layer_size' dimension to calculate all
702+
// features.
703+
for (int idx = 0; idx < last_layer_size; idx++) {
704+
FPTYPE var[6];
705+
load_polynomial_params(var, table, table_idx, idx, last_layer_size);
706+
FPTYPE res =
707+
var[0] +
708+
(var[1] +
709+
(var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
710+
xx;
711+
// Calculate the unique 1D output index for the 4D tensor (block_idx, ii,
712+
// jj, idx).
713+
const int_64 out_idx =
714+
(int_64)block_idx * nnei_i * nnei_j * last_layer_size +
715+
(int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
716+
idx;
717+
// Write the result to the global output memory.
718+
out[out_idx] = res;
719+
}
721720
}
722721
}
723722

723+
// Apply Grid-Stride Loop
724724
template <typename FPTYPE, int MTILE, int KTILE>
725725
__global__ void tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial(
726726
FPTYPE* dy_dem_x,
@@ -735,52 +735,56 @@ __global__ void tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial(
735735
const FPTYPE stride1,
736736
const int nnei_i,
737737
const int nnei_j,
738-
const int last_layer_size) {
739-
// blockDim.x -> nnei_i * nnei_j
740-
// gridDim.x -> nloc
741-
742-
// Identify which atom and neighbor pair this thread is responsible for.
743-
const int_64 block_idx = blockIdx.x;
744-
const int_64 thread_idx = threadIdx.x;
745-
const int ii = thread_idx / nnei_j;
746-
const int jj = thread_idx % nnei_j;
747-
748-
// Determine the table index based on the value of xx.
749-
const int_64 em_x_idx = (int_64)block_idx * nnei_i * nnei_j + thread_idx;
750-
FPTYPE xx = em_x[em_x_idx];
751-
int table_idx = 0;
752-
locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0, stride1);
753-
754-
// Accumulate the gradient contributions from all features.
755-
FPTYPE grad_sum = 0.0;
756-
for (int idx = 0; idx < last_layer_size; idx++) {
757-
FPTYPE var[6];
758-
load_polynomial_params(var, table, table_idx, idx, last_layer_size);
738+
const int last_layer_size,
739+
const int_64 total_work) {
740+
// total_work = nloc * nnei_i * nnei_j
741+
// Grid-Stride Loop
742+
for (int_64 i = (int_64)blockIdx.x * blockDim.x + threadIdx.x; i < total_work;
743+
i += (int_64)gridDim.x * blockDim.x) {
744+
// Decompose the 1D index 'i' to get atom and neighbor indices
745+
const int_64 block_idx = i / (nnei_i * nnei_j);
746+
const int_64 local_idx = i % (nnei_i * nnei_j);
747+
const int ii = local_idx / nnei_j;
748+
const int jj = local_idx % nnei_j;
749+
750+
// Determine the table index based on the value of xx.
751+
FPTYPE xx = em_x[i];
752+
int table_idx = 0;
753+
locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0,
754+
stride1);
759755

760-
// Calculate the derivative of the polynomial with respect to xx.
761-
FPTYPE dres_dxx =
762-
var[1] + ((FPTYPE)2. * var[2] +
763-
((FPTYPE)3. * var[3] +
764-
((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
765-
xx) *
766-
xx;
756+
// Accumulate the gradient contributions from all features.
757+
FPTYPE grad_sum = 0.0;
758+
for (int idx = 0; idx < last_layer_size; idx++) {
759+
FPTYPE var[6];
760+
load_polynomial_params(var, table, table_idx, idx, last_layer_size);
767761

768-
// Read the incoming gradient from the previous layer.
769-
const int_64 dy_idx =
770-
(int_64)block_idx * nnei_i * nnei_j * last_layer_size +
771-
(int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
772-
idx;
773-
FPTYPE dy_val = dy[dy_idx];
762+
// Calculate the derivative of the polynomial with respect to xx.
763+
FPTYPE dres_dxx =
764+
var[1] + ((FPTYPE)2. * var[2] +
765+
((FPTYPE)3. * var[3] +
766+
((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
767+
xx) *
768+
xx;
774769

775-
// Apply the chain rule: dL/dxx = sum over idx [ (dL/d_res_mm) *
776-
// (d_res_mm/dxx) ]
777-
grad_sum += dy_val * dres_dxx;
778-
}
770+
// Read the incoming gradient from the previous layer.
771+
const int_64 dy_idx =
772+
(int_64)block_idx * nnei_i * nnei_j * last_layer_size +
773+
(int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
774+
idx;
775+
FPTYPE dy_val = dy[dy_idx];
779776

780-
// Write the final summed gradient to the output array.
781-
dy_dem_x[em_x_idx] = grad_sum;
777+
// Apply the chain rule: dL/dxx = sum over idx [ (dL/d_res_mm) *
778+
// (d_res_mm/dxx) ]
779+
grad_sum += dy_val * dres_dxx;
780+
}
781+
782+
// Write the final summed gradient to the output array.
783+
dy_dem_x[i] = grad_sum;
784+
}
782785
}
783786

787+
// Apply Grid-Stride Loop
784788
template <typename FPTYPE, int MTILE, int KTILE>
785789
__global__ void tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial(
786790
FPTYPE* dz_dy,
@@ -795,51 +799,54 @@ __global__ void tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial(
795799
const FPTYPE stride1,
796800
const int nnei_i,
797801
const int nnei_j,
798-
const int last_layer_size) {
799-
// blockDim.x -> nnei_i * nnei_j
800-
// gridDim.x -> nloc
801-
802-
// Identify which atom and neighbor pair this thread is responsible for.
803-
const int_64 block_idx = blockIdx.x;
804-
const int_64 thread_idx = threadIdx.x;
805-
const int ii = thread_idx / nnei_j;
806-
const int jj = thread_idx % nnei_j;
807-
808-
const int_64 em_x_idx = (int_64)block_idx * nnei_i * nnei_j + thread_idx;
809-
FPTYPE xx = em_x[em_x_idx];
810-
811-
// Read the incoming gradient for xx. This value is the same for all 'idx'
812-
// features.
813-
const FPTYPE dz_dy_dem_x_val = dz_dy_dem_x[em_x_idx];
814-
815-
// Determine the table index based on the value of xx.
816-
int table_idx = 0;
817-
locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0, stride1);
802+
const int last_layer_size,
803+
const int_64 total_work) {
804+
// total_work = nloc * nnei_i * nnei_j
805+
// Grid-Stride Loop
806+
for (int_64 i = (int_64)blockIdx.x * blockDim.x + threadIdx.x; i < total_work;
807+
i += (int_64)gridDim.x * blockDim.x) {
808+
// Decompose the 1D index 'i' to get atom and neighbor indices
809+
const int_64 block_idx = i / (nnei_i * nnei_j);
810+
const int_64 local_idx = i % (nnei_i * nnei_j);
811+
const int ii = local_idx / nnei_j;
812+
const int jj = local_idx % nnei_j;
813+
814+
FPTYPE xx = em_x[i];
815+
816+
// Read the incoming gradient for xx. This value is the same for all 'idx'
817+
// features.
818+
const FPTYPE dz_dy_dem_x_val = dz_dy_dem_x[i];
819+
820+
// Determine the table index based on the value of xx.
821+
int table_idx = 0;
822+
locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0,
823+
stride1);
818824

819-
// Serially loop through the 'last_layer_size' dimension.
820-
for (int idx = 0; idx < last_layer_size; idx++) {
821-
FPTYPE var[6];
822-
load_polynomial_params(var, table, table_idx, idx, last_layer_size);
825+
// Serially loop through the 'last_layer_size' dimension.
826+
for (int idx = 0; idx < last_layer_size; idx++) {
827+
FPTYPE var[6];
828+
load_polynomial_params(var, table, table_idx, idx, last_layer_size);
823829

824-
// Calculate the derivative of the polynomial with respect to xx.
825-
FPTYPE dres_dxx =
826-
var[1] + ((FPTYPE)2. * var[2] +
827-
((FPTYPE)3. * var[3] +
828-
((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
829-
xx) *
830-
xx;
830+
// Calculate the derivative of the polynomial with respect to xx.
831+
FPTYPE dres_dxx =
832+
var[1] + ((FPTYPE)2. * var[2] +
833+
((FPTYPE)3. * var[3] +
834+
((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
835+
xx) *
836+
xx;
831837

832-
// Apply the chain rule: dz/dy_idx = (dz/dxx) * (dxx/dy_idx)
833-
// which simplifies to dz_dy_dem_x_val * dres_dxx
834-
FPTYPE out_grad = dz_dy_dem_x_val * dres_dxx;
835-
836-
// Calculate the unique 1D output index for the 4D tensor (block_idx, ii,
837-
// jj, idx).
838-
const int_64 out_idx =
839-
(int_64)block_idx * nnei_i * nnei_j * last_layer_size +
840-
(int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
841-
idx;
842-
dz_dy[out_idx] = out_grad;
838+
// Apply the chain rule: dz/dy_idx = (dz/dxx) * (dxx/dy_idx)
839+
// which simplifies to dz_dy_dem_x_val * dres_dxx
840+
FPTYPE out_grad = dz_dy_dem_x_val * dres_dxx;
841+
842+
// Calculate the unique 1D output index for the 4D tensor (block_idx, ii,
843+
// jj, idx).
844+
const int_64 out_idx =
845+
(int_64)block_idx * nnei_i * nnei_j * last_layer_size +
846+
(int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
847+
idx;
848+
dz_dy[out_idx] = out_grad;
849+
}
843850
}
844851
}
845852

@@ -1150,18 +1157,19 @@ void tabulate_fusion_se_t_tebd_gpu(FPTYPE* out,
11501157
if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0) {
11511158
return;
11521159
}
1153-
// Grid dimension: One block for each atom: nloc
1154-
dim3 num_blocks(nloc);
1155-
// Block dimension: One thread for each (ii, jj) neighbor pair: nnei_i *
1156-
// nnei_j
1157-
dim3 num_threads(nnei_i * nnei_j);
1160+
const int_64 total_work = (int_64)nloc * nnei_i * nnei_j;
1161+
// Use fixed number of threads per block
1162+
const int num_threads = TPB;
1163+
// Calculate number of blocks needed
1164+
const int num_blocks = (total_work + num_threads - 1) / num_threads;
11581165

11591166
DPErrcheck(gpuGetLastError());
11601167
DPErrcheck(gpuDeviceSynchronize());
11611168
tabulate_fusion_se_t_tebd_fifth_order_polynomial<FPTYPE, MM, KK>
1162-
<<<num_blocks, num_threads>>>(
1163-
out, table, em_x, em, table_info[0], table_info[1], table_info[2],
1164-
table_info[3], table_info[4], nnei_i, nnei_j, last_layer_size);
1169+
<<<num_blocks, num_threads>>>(out, table, em_x, em, table_info[0],
1170+
table_info[1], table_info[2], table_info[3],
1171+
table_info[4], nnei_i, nnei_j,
1172+
last_layer_size, total_work);
11651173
DPErrcheck(gpuGetLastError());
11661174
DPErrcheck(gpuDeviceSynchronize());
11671175
}
@@ -1180,19 +1188,18 @@ void tabulate_fusion_se_t_tebd_grad_gpu(FPTYPE* dy_dem_x,
11801188
if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0) {
11811189
return;
11821190
}
1183-
// Define Grid and Block dimensions, matching the forward pass strategy.
1184-
dim3 num_blocks(nloc);
1185-
dim3 num_threads(nnei_i * nnei_j);
1191+
const int_64 total_work = (int_64)nloc * nnei_i * nnei_j;
1192+
const int num_threads = TPB;
1193+
const int num_blocks = (total_work + num_threads - 1) / num_threads;
11861194

11871195
DPErrcheck(gpuGetLastError());
11881196
DPErrcheck(gpuDeviceSynchronize());
1189-
DPErrcheck(gpuMemset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei_i * nnei_j));
1190-
// table_info should be on CPU side
1197+
DPErrcheck(gpuMemset(dy_dem_x, 0, sizeof(FPTYPE) * total_work));
11911198
tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial<FPTYPE, MM, KK>
11921199
<<<num_blocks, num_threads>>>(dy_dem_x, table, em_x, em, dy,
11931200
table_info[0], table_info[1], table_info[2],
11941201
table_info[3], table_info[4], nnei_i,
1195-
nnei_j, last_layer_size);
1202+
nnei_j, last_layer_size, total_work);
11961203
DPErrcheck(gpuGetLastError());
11971204
DPErrcheck(gpuDeviceSynchronize());
11981205
}
@@ -1211,19 +1218,20 @@ void tabulate_fusion_se_t_tebd_grad_grad_gpu(FPTYPE* dz_dy,
12111218
if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0) {
12121219
return;
12131220
}
1214-
dim3 num_blocks(nloc);
1215-
dim3 num_threads(nnei_i * nnei_j);
1221+
const int_64 total_work = (int_64)nloc * nnei_i * nnei_j;
1222+
const int num_threads = TPB;
1223+
const int num_blocks = (total_work + num_threads - 1) / num_threads;
12161224

12171225
DPErrcheck(gpuGetLastError());
12181226
DPErrcheck(gpuDeviceSynchronize());
1219-
DPErrcheck(gpuMemset(
1220-
dz_dy, 0, sizeof(FPTYPE) * nloc * nnei_i * nnei_j * last_layer_size));
1227+
DPErrcheck(
1228+
gpuMemset(dz_dy, 0, sizeof(FPTYPE) * total_work * last_layer_size));
12211229

12221230
tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
12231231
<<<num_blocks, num_threads>>>(dz_dy, table, em_x, em, dz_dy_dem_x,
12241232
table_info[0], table_info[1], table_info[2],
12251233
table_info[3], table_info[4], nnei_i,
1226-
nnei_j, last_layer_size);
1234+
nnei_j, last_layer_size, total_work);
12271235
DPErrcheck(gpuGetLastError());
12281236
DPErrcheck(gpuDeviceSynchronize());
12291237
}

0 commit comments

Comments
 (0)