@@ -103,6 +103,38 @@ __forceinline__ __device__ void locate_xx_se_t(FPTYPE& xx,
103103 }
104104}
105105
// Same as locate_xx_se_t, but used by the se_t_tebd tabulation kernels.
//
// Maps the input coordinate `xx` onto a piecewise-uniform lookup table made
// of three regions:
//   [min, lower)   -> coarse spacing `stride1`
//   [lower, upper) -> fine spacing `stride0`
//   [upper, max)   -> coarse spacing `stride1`
// On return, `table_idx` holds the index of the table interval containing the
// input, and `xx` is rewritten in place as the offset of the input within that
// interval (the local polynomial variable). Inputs below `min` clamp to the
// first interval and inputs at or above `max` clamp to the last interval, both
// with a zero offset.
template <typename FPTYPE>
__forceinline__ __device__ void locate_xx_se_t_tebd(FPTYPE& xx,
                                                    int& table_idx,
                                                    const FPTYPE& lower,
                                                    const FPTYPE& upper,
                                                    const FPTYPE& min,
                                                    const FPTYPE& max,
                                                    const FPTYPE& stride0,
                                                    const FPTYPE& stride1) {
  if (xx < min) {
    // Below the tabulated range: clamp to the first interval.
    table_idx = 0;
    xx = (FPTYPE)0.;
  } else if (xx < lower) {
    // Coarse region [min, lower) with spacing stride1.
    table_idx = (int)((xx - min) / stride1);
    xx -= (table_idx * stride1 + min);
  } else if (xx < upper) {
    // Fine region [lower, upper) with spacing stride0; intervals of the
    // first coarse region precede it in the table.
    const int first_stride = (int)((lower - min) / stride1);
    table_idx = first_stride + (int)((xx - lower) / stride0);
    xx -= ((table_idx - first_stride) * stride0 + lower);
  } else if (xx < max) {
    // Coarse region [upper, max) with spacing stride1; intervals of the two
    // preceding regions come first in the table.
    const int first_stride =
        (int)((lower - min) / stride1) + (int)((upper - lower) / stride0);
    table_idx = first_stride + (int)((xx - upper) / stride1);
    xx -= ((table_idx - first_stride) * stride1 + upper);
  } else {
    // At or above the tabulated range: clamp to the last interval.
    table_idx = (int)((lower - min) / stride1) +
                (int)((upper - lower) / stride0) +
                (int)((max - upper) / stride1) - 1;
    xx = (FPTYPE)0.;
  }
}
137+
106138template <typename FPTYPE>
107139__forceinline__ __device__ void locate_xx_se_r (FPTYPE& xx,
108140 int & table_idx,
@@ -644,30 +676,48 @@ __global__ void tabulate_fusion_se_t_tebd_fifth_order_polynomial(
644676 const int nnei_i,
645677 const int nnei_j,
646678 const int last_layer_size) {
647- const int_64 block_idx = blockIdx .x ; // nloc
648- const int thread_idx = threadIdx .x ; // last_layer_size
649-
650- for (int ii = 0 ; ii < nnei_i; ii++) {
651- for (int jj = 0 ; jj < nnei_j; jj++) {
652- FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
653- int table_idx = 0 ;
654- locate_xx_se_t (xx, table_idx, lower, upper, -max, max, stride0, stride1);
655-
656- FPTYPE var[6 ];
657- load_polynomial_params (var, table, table_idx, thread_idx,
658- last_layer_size);
659-
660- FPTYPE res =
661- var[0 ] +
662- (var[1 ] +
663- (var[2 ] + (var[3 ] + (var[4 ] + var[5 ] * xx) * xx) * xx) * xx) *
664- xx;
665-
666- // Store result preserving the nt_i x nt_j structure
667- out[block_idx * nnei_i * nnei_j * last_layer_size +
668- ii * nnei_j * last_layer_size + jj * last_layer_size + thread_idx] =
669- res;
670- }
679+ // NOT USED: em: (nfnl, nnei_i, nnei_j)
680+ // em_x: (nfnl * nnei_i * nnei_j, 1) flat version of em
681+ // blockDim.x -> total threads in a block: nnei_i * nnei_j
682+ // gridDim.x -> total blocks in a grid: nloc
683+
684+ // Identify which atom and neighbor pair this thread is responsible for.
685+ // block_idx corresponds to the atom index, given by the block index.
686+ const int_64 block_idx = blockIdx .x ;
687+
688+ // thread_idx is the flattened 1D index for the (ii, jj) neighbor pair.
689+ const int_64 thread_idx = threadIdx .x ;
690+
691+ // Recover the 2D (ii, jj) indices from the 1D thread index.
692+ // thread_idx = ii * nnei_j + jj
693+ const int_64 ii = thread_idx / nnei_j;
694+ const int_64 jj = thread_idx % nnei_j;
695+
696+ // Read the input value xx for this specific neighbor pair.
697+ const int_64 em_x_idx = (int_64)block_idx * nnei_i * nnei_j + thread_idx;
698+ FPTYPE xx = em_x[em_x_idx];
699+
700+ // Determine the table index based on the value of xx.
701+ int table_idx = 0 ;
702+ locate_xx_se_t_tebd (xx, table_idx, lower, upper, -max, max, stride0, stride1);
703+
704+ // Serially loop through the 'last_layer_size' dimension to calculate all
705+ // features.
706+ for (int idx = 0 ; idx < last_layer_size; idx++) {
707+ FPTYPE var[6 ];
708+ load_polynomial_params (var, table, table_idx, idx, last_layer_size);
709+ FPTYPE res =
710+ var[0 ] +
711+ (var[1 ] + (var[2 ] + (var[3 ] + (var[4 ] + var[5 ] * xx) * xx) * xx) * xx) *
712+ xx;
713+ // Calculate the unique 1D output index for the 4D tensor (block_idx, ii,
714+ // jj, idx).
715+ const int_64 out_idx =
716+ (int_64)block_idx * nnei_i * nnei_j * last_layer_size +
717+ (int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
718+ idx;
719+ // Write the result to the global output memory.
720+ out[out_idx] = res;
671721 }
672722}
673723
@@ -686,35 +736,49 @@ __global__ void tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial(
686736 const int nnei_i,
687737 const int nnei_j,
688738 const int last_layer_size) {
689- const int_64 block_idx = blockIdx .x ; // nloc
690- const int thread_idx = threadIdx .x ; // thread within block
691-
692- for (int ii = 0 ; ii < nnei_i; ii++) {
693- for (int jj = 0 ; jj < nnei_j; jj++) {
694- FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
695- int table_idx = 0 ;
696- locate_xx_se_t (xx, table_idx, lower, upper, -max, max, stride0, stride1);
697-
698- FPTYPE grad_sum = 0.0 ;
699- for (int mm = 0 ; mm < last_layer_size; mm++) {
700- FPTYPE var[6 ];
701- load_polynomial_params (var, table, table_idx, mm, last_layer_size);
739+ // blockDim.x -> nnei_i * nnei_j
740+ // gridDim.x -> nloc
741+
742+ // Identify which atom and neighbor pair this thread is responsible for.
743+ const int block_idx = blockIdx .x ;
744+ const int thread_idx = threadIdx .x ;
745+ const int ii = thread_idx / nnei_j;
746+ const int jj = thread_idx % nnei_j;
747+
748+ // Determine the table index based on the value of xx.
749+ const int_64 em_x_idx = (int_64)block_idx * nnei_i * nnei_j + thread_idx;
750+ FPTYPE xx = em_x[em_x_idx];
751+ int table_idx = 0 ;
752+ locate_xx_se_t_tebd (xx, table_idx, lower, upper, -max, max, stride0, stride1);
753+
754+ // Accumulate the gradient contributions from all features.
755+ FPTYPE grad_sum = 0.0 ;
756+ for (int idx = 0 ; idx < last_layer_size; idx++) {
757+ FPTYPE var[6 ];
758+ load_polynomial_params (var, table, table_idx, idx, last_layer_size);
702759
703- FPTYPE dres_dxx = var[1 ] + 2.0 * var[2 ] * xx + 3.0 * var[3 ] * xx * xx +
704- 4.0 * var[4 ] * xx * xx * xx +
705- 5.0 * var[5 ] * xx * xx * xx * xx;
760+ // Calculate the derivative of the polynomial with respect to xx.
761+ FPTYPE dres_dxx =
762+ var[1 ] + ((FPTYPE)2 . * var[2 ] +
763+ ((FPTYPE)3 . * var[3 ] +
764+ ((FPTYPE)4 . * var[4 ] + (FPTYPE)5 . * var[5 ] * xx) * xx) *
765+ xx) *
766+ xx;
706767
707- FPTYPE dy_val =
708- dy[block_idx * nnei_i * nnei_j * last_layer_size +
709- ii * nnei_j * last_layer_size + jj * last_layer_size + mm];
710- grad_sum += dy_val * dres_dxx;
711- }
768+ // Read the incoming gradient from the previous layer.
769+ const int_64 dy_idx =
770+ (int_64)block_idx * nnei_i * nnei_j * last_layer_size +
771+ (int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
772+ idx;
773+ FPTYPE dy_val = dy[dy_idx];
712774
713- if (thread_idx == 0 ) { // Only one thread writes the gradient
714- dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = grad_sum;
715- }
716- }
775+ // Apply the chain rule: dL/dxx = sum over idx [ (dL/d_res_mm) *
776+ // (d_res_mm/dxx) ]
777+ grad_sum += dy_val * dres_dxx;
717778 }
779+
780+ // Write the final summed gradient to the output array.
781+ dy_dem_x[em_x_idx] = grad_sum;
718782}
719783
720784template <typename FPTYPE, int MTILE, int KTILE>
@@ -732,31 +796,50 @@ __global__ void tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial(
732796 const int nnei_i,
733797 const int nnei_j,
734798 const int last_layer_size) {
735- const int_64 block_idx = blockIdx . x ; // nloc
736- const int thread_idx = threadIdx . x ; // last_layer_size
799+ // blockDim.x -> nnei_i * nnei_j
800+ // gridDim.x -> nloc
737801
738- for ( int ii = 0 ; ii < nnei_i; ii++) {
739- for ( int jj = 0 ; jj < nnei_j; jj++) {
740- FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] ;
741- FPTYPE dz_dy_dem_x_val =
742- dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] ;
802+ // Identify which atom and neighbor pair this thread is responsible for.
803+ const int block_idx = blockIdx . x ;
804+ const int thread_idx = threadIdx . x ;
805+ const int ii = thread_idx / nnei_j;
806+ const int jj = thread_idx % nnei_j;
743807
744- int table_idx = 0 ;
745- locate_xx_se_t (xx, table_idx, lower, upper, -max, max, stride0, stride1) ;
808+ const int_64 em_x_idx = (int_64)block_idx * nnei_i * nnei_j + thread_idx ;
809+ FPTYPE xx = em_x[em_x_idx] ;
746810
747- FPTYPE var[ 6 ];
748- load_polynomial_params (var, table, table_idx, thread_idx,
749- last_layer_size) ;
811+ // Read the incoming gradient for xx. This value is the same for all 'idx'
812+ // features.
813+ const FPTYPE dz_dy_dem_x_val = dz_dy_dem_x[em_x_idx] ;
750814
751- FPTYPE dres_dxx = var[ 1 ] + 2.0 * var[ 2 ] * xx + 3.0 * var[ 3 ] * xx * xx +
752- 4.0 * var[ 4 ] * xx * xx * xx +
753- 5.0 * var[ 5 ] * xx * xx * xx * xx ;
815+ // Determine the table index based on the value of xx.
816+ int table_idx = 0 ;
817+ locate_xx_se_t_tebd (xx, table_idx, lower, upper, -max, max, stride0, stride1) ;
754818
755- // Store result preserving the nt_i x nt_j structure
756- dz_dy[block_idx * nnei_i * nnei_j * last_layer_size +
757- ii * nnei_j * last_layer_size + jj * last_layer_size + thread_idx] =
758- dz_dy_dem_x_val * dres_dxx;
759- }
819+ // Serially loop through the 'last_layer_size' dimension.
820+ for (int idx = 0 ; idx < last_layer_size; idx++) {
821+ FPTYPE var[6 ];
822+ load_polynomial_params (var, table, table_idx, idx, last_layer_size);
823+
824+ // Calculate the derivative of the polynomial with respect to xx.
825+ FPTYPE dres_dxx =
826+ var[1 ] + ((FPTYPE)2 . * var[2 ] +
827+ ((FPTYPE)3 . * var[3 ] +
828+ ((FPTYPE)4 . * var[4 ] + (FPTYPE)5 . * var[5 ] * xx) * xx) *
829+ xx) *
830+ xx;
831+
832+ // Apply the chain rule: dz/dy_idx = (dz/dxx) * (dxx/dy_idx)
833+ // which simplifies to dz_dy_dem_x_val * dres_dxx
834+ FPTYPE out_grad = dz_dy_dem_x_val * dres_dxx;
835+
836+ // Calculate the unique 1D output index for the 4D tensor (block_idx, ii,
837+ // jj, idx).
838+ const int_64 out_idx =
839+ (int_64)block_idx * nnei_i * nnei_j * last_layer_size +
840+ (int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
841+ idx;
842+ dz_dy[out_idx] = out_grad;
760843 }
761844}
762845
@@ -1064,13 +1147,19 @@ void tabulate_fusion_se_t_tebd_gpu(FPTYPE* out,
10641147 const int nnei_i,
10651148 const int nnei_j,
10661149 const int last_layer_size) {
1067- if (nloc <= 0 ) {
1150+ if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0 ) {
10681151 return ;
10691152 }
1153+ // Grid dimension: One block for each atom: nloc
1154+ dim3 num_blocks (nloc);
1155+ // Block dimension: One thread for each (ii, jj) neighbor pair: nnei_i *
1156+ // nnei_j
1157+ dim3 num_threads (nnei_i * nnei_j);
1158+
10701159 DPErrcheck (gpuGetLastError ());
10711160 DPErrcheck (gpuDeviceSynchronize ());
10721161 tabulate_fusion_se_t_tebd_fifth_order_polynomial<FPTYPE, MM, KK>
1073- <<<nloc, last_layer_size >>> (
1162+ <<<num_blocks, num_threads >>> (
10741163 out, table, em_x, em, table_info[0 ], table_info[1 ], table_info[2 ],
10751164 table_info[3 ], table_info[4 ], nnei_i, nnei_j, last_layer_size);
10761165 DPErrcheck (gpuGetLastError ());
@@ -1088,18 +1177,22 @@ void tabulate_fusion_se_t_tebd_grad_gpu(FPTYPE* dy_dem_x,
10881177 const int nnei_i,
10891178 const int nnei_j,
10901179 const int last_layer_size) {
1091- if (nloc <= 0 ) {
1180+ if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0 ) {
10921181 return ;
10931182 }
1183+ // Define Grid and Block dimensions, matching the forward pass strategy.
1184+ dim3 num_blocks (nloc);
1185+ dim3 num_threads (nnei_i * nnei_j);
1186+
10941187 DPErrcheck (gpuGetLastError ());
10951188 DPErrcheck (gpuDeviceSynchronize ());
10961189 DPErrcheck (gpuMemset (dy_dem_x, 0 , sizeof (FPTYPE) * nloc * nnei_i * nnei_j));
10971190 // table_info should be on CPU side
10981191 tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial<FPTYPE, MM, KK>
1099- <<<nloc, KK * WARP_SIZE >>> (dy_dem_x, table, em_x, em, dy, table_info[ 0 ] ,
1100- table_info[1 ], table_info[2 ], table_info[3 ],
1101- table_info[4 ], nnei_i, nnei_j ,
1102- last_layer_size);
1192+ <<<num_blocks, num_threads >>> (dy_dem_x, table, em_x, em, dy,
1193+ table_info[0 ], table_info[1 ], table_info[2 ],
1194+ table_info[3 ], table_info[ 4 ], nnei_i ,
1195+ nnei_j, last_layer_size);
11031196 DPErrcheck (gpuGetLastError ());
11041197 DPErrcheck (gpuDeviceSynchronize ());
11051198}
@@ -1115,19 +1208,22 @@ void tabulate_fusion_se_t_tebd_grad_grad_gpu(FPTYPE* dz_dy,
11151208 const int nnei_i,
11161209 const int nnei_j,
11171210 const int last_layer_size) {
1118- if (nloc <= 0 ) {
1211+ if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0 ) {
11191212 return ;
11201213 }
1214+ dim3 num_blocks (nloc);
1215+ dim3 num_threads (nnei_i * nnei_j);
1216+
11211217 DPErrcheck (gpuGetLastError ());
11221218 DPErrcheck (gpuDeviceSynchronize ());
11231219 DPErrcheck (gpuMemset (
11241220 dz_dy, 0 , sizeof (FPTYPE) * nloc * nnei_i * nnei_j * last_layer_size));
11251221
11261222 tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
1127- <<<nloc, last_layer_size >>> (dz_dy, table, em_x, em, dz_dy_dem_x,
1128- table_info[0 ], table_info[1 ], table_info[2 ],
1129- table_info[3 ], table_info[4 ], nnei_i, nnei_j ,
1130- last_layer_size);
1223+ <<<num_blocks, num_threads >>> (dz_dy, table, em_x, em, dz_dy_dem_x,
1224+ table_info[0 ], table_info[1 ], table_info[2 ],
1225+ table_info[3 ], table_info[4 ], nnei_i,
1226+ nnei_j, last_layer_size);
11311227 DPErrcheck (gpuGetLastError ());
11321228 DPErrcheck (gpuDeviceSynchronize ());
11331229}
0 commit comments