@@ -662,6 +662,7 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
662662 dz_dy[block_idx * last_layer_size + thread_idx] = sum;
663663}
664664
// Apply Grid-Stride Loop: the launch configuration is decoupled from the
// problem size, so any grid fits any total_work (including single-block
// debug runs).
template <typename FPTYPE, int MTILE, int KTILE>
__global__ void tabulate_fusion_se_t_tebd_fifth_order_polynomial(
    FPTYPE* out,
    const FPTYPE* table,
    const FPTYPE* em_x,
    const FPTYPE* em,
    const FPTYPE lower,
    const FPTYPE upper,
    const FPTYPE max,
    const FPTYPE stride0,
    const FPTYPE stride1,
    const int nnei_i,
    const int nnei_j,
    const int last_layer_size,
    const int_64 total_work) {
  // NOT USED: em: (nfnl, nnei_i, nnei_j)
  // em_x: (nfnl * nnei_i * nnei_j, 1) flat version of em
  // out:  flat (total_work, last_layer_size)
  // total_work = nloc * nnei_i * nnei_j
  // Grid-Stride Loop over every (atom, ii, jj) entry. The 64-bit index and
  // 64-bit stride avoid overflow when total_work exceeds 2^31.
  for (int_64 i = (int_64)blockIdx.x * blockDim.x + threadIdx.x;
       i < total_work; i += (int_64)gridDim.x * blockDim.x) {
    // em_x is flat over exactly the same index space as the loop, so 'i'
    // addresses it directly.
    FPTYPE xx = em_x[i];

    // Determine the spline table row based on the value of xx.
    int table_idx = 0;
    locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0,
                        stride1);

    // The 4D output index (block_idx, ii, jj, idx) flattens to
    //   ((block_idx * nnei_i + ii) * nnei_j + jj) * last_layer_size + idx
    // which is exactly i * last_layer_size + idx — no div/mod decomposition
    // of 'i' is required (saves two 64-bit divisions per element).
    const int_64 out_base = i * (int_64)last_layer_size;

    // Serially evaluate all 'last_layer_size' features; each is a fifth-order
    // polynomial in xx evaluated via Horner's rule.
    for (int idx = 0; idx < last_layer_size; idx++) {
      FPTYPE var[6];
      load_polynomial_params(var, table, table_idx, idx, last_layer_size);
      FPTYPE res =
          var[0] +
          (var[1] +
           (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
              xx;
      // Write the result to global memory.
      out[out_base + idx] = res;
    }
  }
}
723722
// Apply Grid-Stride Loop: backward pass w.r.t. em_x, one output gradient per
// (atom, ii, jj) entry, accumulated over all last_layer_size features.
template <typename FPTYPE, int MTILE, int KTILE>
__global__ void tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial(
    FPTYPE* dy_dem_x,
    const FPTYPE* table,
    const FPTYPE* em_x,
    const FPTYPE* em,
    const FPTYPE* dy,
    const FPTYPE lower,
    const FPTYPE upper,
    const FPTYPE max,
    const FPTYPE stride0,
    const FPTYPE stride1,
    const int nnei_i,
    const int nnei_j,
    const int last_layer_size,
    const int_64 total_work) {
  // em_x / dy_dem_x: flat (total_work, 1); dy: flat (total_work,
  // last_layer_size). total_work = nloc * nnei_i * nnei_j
  // Grid-Stride Loop with 64-bit index arithmetic to avoid overflow.
  for (int_64 i = (int_64)blockIdx.x * blockDim.x + threadIdx.x;
       i < total_work; i += (int_64)gridDim.x * blockDim.x) {
    // Determine the spline table row based on the value of xx.
    FPTYPE xx = em_x[i];
    int table_idx = 0;
    locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0,
                        stride1);

    // The incoming-gradient index (block_idx, ii, jj, idx) flattens to
    // i * last_layer_size + idx, so no div/mod decomposition of 'i' is
    // needed.
    const int_64 dy_base = i * (int_64)last_layer_size;

    // Accumulate the gradient contributions from all features.
    FPTYPE grad_sum = (FPTYPE)0.;
    for (int idx = 0; idx < last_layer_size; idx++) {
      FPTYPE var[6];
      load_polynomial_params(var, table, table_idx, idx, last_layer_size);

      // Derivative of the fifth-order polynomial with respect to xx
      // (Horner form of var1 + 2*var2*xx + ... + 5*var5*xx^4).
      FPTYPE dres_dxx =
          var[1] + ((FPTYPE)2. * var[2] +
                    ((FPTYPE)3. * var[3] +
                     ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
                        xx) *
                       xx;

      // Chain rule: dL/dxx = sum over idx [ (dL/d_res_idx) * (d_res_idx/dxx) ]
      grad_sum += dy[dy_base + idx] * dres_dxx;
    }

    // Write the final summed gradient to the output array.
    dy_dem_x[i] = grad_sum;
  }
}
783786
// Apply Grid-Stride Loop: second-order backward pass, broadcasting the
// incoming scalar gradient of each (atom, ii, jj) entry across all features.
template <typename FPTYPE, int MTILE, int KTILE>
__global__ void tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial(
    FPTYPE* dz_dy,
    const FPTYPE* table,
    const FPTYPE* em_x,
    const FPTYPE* em,
    const FPTYPE* dz_dy_dem_x,
    const FPTYPE lower,
    const FPTYPE upper,
    const FPTYPE max,
    const FPTYPE stride0,
    const FPTYPE stride1,
    const int nnei_i,
    const int nnei_j,
    const int last_layer_size,
    const int_64 total_work) {
  // em_x / dz_dy_dem_x: flat (total_work, 1); dz_dy: flat (total_work,
  // last_layer_size). total_work = nloc * nnei_i * nnei_j
  // Grid-Stride Loop with 64-bit index arithmetic to avoid overflow.
  for (int_64 i = (int_64)blockIdx.x * blockDim.x + threadIdx.x;
       i < total_work; i += (int_64)gridDim.x * blockDim.x) {
    FPTYPE xx = em_x[i];

    // Incoming gradient for xx; the same value is reused for every 'idx'
    // feature below.
    const FPTYPE dz_dy_dem_x_val = dz_dy_dem_x[i];

    // Determine the spline table row based on the value of xx.
    int table_idx = 0;
    locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0,
                        stride1);

    // The 4D output index (block_idx, ii, jj, idx) flattens to
    // i * last_layer_size + idx — no div/mod decomposition of 'i' is needed.
    const int_64 out_base = i * (int_64)last_layer_size;

    // Serially loop through the 'last_layer_size' dimension.
    for (int idx = 0; idx < last_layer_size; idx++) {
      FPTYPE var[6];
      load_polynomial_params(var, table, table_idx, idx, last_layer_size);

      // Derivative of the fifth-order polynomial with respect to xx.
      FPTYPE dres_dxx =
          var[1] + ((FPTYPE)2. * var[2] +
                    ((FPTYPE)3. * var[3] +
                     ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
                        xx) *
                       xx;

      // Chain rule: dz/dy_idx = (dz/dxx) * (dxx/dy_idx)
      // which simplifies to dz_dy_dem_x_val * dres_dxx.
      dz_dy[out_base + idx] = dz_dy_dem_x_val * dres_dxx;
    }
  }
}
845852
@@ -1150,18 +1157,19 @@ void tabulate_fusion_se_t_tebd_gpu(FPTYPE* out,
11501157 if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0 ) {
11511158 return ;
11521159 }
1153- // Grid dimension: One block for each atom: nloc
1154- dim3 num_blocks (nloc);
1155- // Block dimension: One thread for each (ii, jj) neighbor pair: nnei_i *
1156- // nnei_j
1157- dim3 num_threads (nnei_i * nnei_j) ;
1160+ const int_64 total_work = (int_64)nloc * nnei_i * nnei_j;
1161+ // Use fixed number of threads per block
1162+ const int num_threads = TPB;
1163+ // Calculate number of blocks needed
1164+ const int num_blocks = (total_work + num_threads - 1 ) / num_threads ;
11581165
11591166 DPErrcheck (gpuGetLastError ());
11601167 DPErrcheck (gpuDeviceSynchronize ());
11611168 tabulate_fusion_se_t_tebd_fifth_order_polynomial<FPTYPE, MM, KK>
1162- <<<num_blocks, num_threads>>> (
1163- out, table, em_x, em, table_info[0 ], table_info[1 ], table_info[2 ],
1164- table_info[3 ], table_info[4 ], nnei_i, nnei_j, last_layer_size);
1169+ <<<num_blocks, num_threads>>> (out, table, em_x, em, table_info[0 ],
1170+ table_info[1 ], table_info[2 ], table_info[3 ],
1171+ table_info[4 ], nnei_i, nnei_j,
1172+ last_layer_size, total_work);
11651173 DPErrcheck (gpuGetLastError ());
11661174 DPErrcheck (gpuDeviceSynchronize ());
11671175}
@@ -1180,19 +1188,18 @@ void tabulate_fusion_se_t_tebd_grad_gpu(FPTYPE* dy_dem_x,
11801188 if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0 ) {
11811189 return ;
11821190 }
1183- // Define Grid and Block dimensions, matching the forward pass strategy.
1184- dim3 num_blocks (nloc) ;
1185- dim3 num_threads (nnei_i * nnei_j) ;
1191+ const int_64 total_work = (int_64)nloc * nnei_i * nnei_j;
1192+ const int num_threads = TPB ;
1193+ const int num_blocks = (total_work + num_threads - 1 ) / num_threads ;
11861194
11871195 DPErrcheck (gpuGetLastError ());
11881196 DPErrcheck (gpuDeviceSynchronize ());
1189- DPErrcheck (gpuMemset (dy_dem_x, 0 , sizeof (FPTYPE) * nloc * nnei_i * nnei_j));
1190- // table_info should be on CPU side
1197+ DPErrcheck (gpuMemset (dy_dem_x, 0 , sizeof (FPTYPE) * total_work));
11911198 tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial<FPTYPE, MM, KK>
11921199 <<<num_blocks, num_threads>>> (dy_dem_x, table, em_x, em, dy,
11931200 table_info[0 ], table_info[1 ], table_info[2 ],
11941201 table_info[3 ], table_info[4 ], nnei_i,
1195- nnei_j, last_layer_size);
1202+ nnei_j, last_layer_size, total_work );
11961203 DPErrcheck (gpuGetLastError ());
11971204 DPErrcheck (gpuDeviceSynchronize ());
11981205}
@@ -1211,19 +1218,20 @@ void tabulate_fusion_se_t_tebd_grad_grad_gpu(FPTYPE* dz_dy,
12111218 if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0 ) {
12121219 return ;
12131220 }
1214- dim3 num_blocks (nloc);
1215- dim3 num_threads (nnei_i * nnei_j);
1221+ const int_64 total_work = (int_64)nloc * nnei_i * nnei_j;
1222+ const int num_threads = TPB;
1223+ const int num_blocks = (total_work + num_threads - 1 ) / num_threads;
12161224
12171225 DPErrcheck (gpuGetLastError ());
12181226 DPErrcheck (gpuDeviceSynchronize ());
1219- DPErrcheck (gpuMemset (
1220- dz_dy, 0 , sizeof (FPTYPE) * nloc * nnei_i * nnei_j * last_layer_size));
1227+ DPErrcheck (
1228+ gpuMemset ( dz_dy, 0 , sizeof (FPTYPE) * total_work * last_layer_size));
12211229
12221230 tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
12231231 <<<num_blocks, num_threads>>> (dz_dy, table, em_x, em, dz_dy_dem_x,
12241232 table_info[0 ], table_info[1 ], table_info[2 ],
12251233 table_info[3 ], table_info[4 ], nnei_i,
1226- nnei_j, last_layer_size);
1234+ nnei_j, last_layer_size, total_work );
12271235 DPErrcheck (gpuGetLastError ());
12281236 DPErrcheck (gpuDeviceSynchronize ());
12291237}
0 commit comments