@@ -630,6 +630,129 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
630630 dz_dy[block_idx * last_layer_size + thread_idx] = sum;
631631}
632632
633+ template <typename FPTYPE , int MM , int KK >
634+ __global__ void tabulate_fusion_se_t_tebd_fifth_order_polynomial (
635+ FPTYPE * out,
636+ const FPTYPE * table,
637+ const FPTYPE * em_x,
638+ const FPTYPE * em,
639+ const FPTYPE lower,
640+ const FPTYPE upper,
641+ const FPTYPE max,
642+ const FPTYPE stride0,
643+ const FPTYPE stride1,
644+ const int nnei_i,
645+ const int nnei_j,
646+ const int last_layer_size) {
647+ const int_64 block_idx = blockIdx .x ; // nloc
648+ const int thread_idx = threadIdx .x ; // last_layer_size
649+
650+ for (int ii = 0 ; ii < nnei_i; ii++) {
651+ for (int jj = 0 ; jj < nnei_j; jj++) {
652+ FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
653+ int table_idx = 0 ;
654+ locate_xx_se_t (xx, table_idx, lower, upper, -max, max, stride0, stride1);
655+
656+ FPTYPE var[6 ];
657+ load_polynomial_params (var, table, table_idx, thread_idx, last_layer_size);
658+
659+ FPTYPE res = var[0 ] +
660+ (var[1 ] +
661+ (var[2 ] + (var[3 ] + (var[4 ] + var[5 ] * xx) * xx) * xx) * xx) * xx;
662+
663+ // Store result preserving the nt_i x nt_j structure
664+ out[block_idx * nnei_i * nnei_j * last_layer_size +
665+ ii * nnei_j * last_layer_size +
666+ jj * last_layer_size + thread_idx] = res;
667+ }
668+ }
669+ }
670+
671+ template <typename FPTYPE , int MM , int KK >
672+ __global__ void tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial (
673+ FPTYPE * dy_dem_x,
674+ const FPTYPE * table,
675+ const FPTYPE * em_x,
676+ const FPTYPE * em,
677+ const FPTYPE * dy,
678+ const FPTYPE lower,
679+ const FPTYPE upper,
680+ const FPTYPE max,
681+ const FPTYPE stride0,
682+ const FPTYPE stride1,
683+ const int nnei_i,
684+ const int nnei_j,
685+ const int last_layer_size) {
686+ const int_64 block_idx = blockIdx .x ; // nloc
687+ const int thread_idx = threadIdx .x ; // thread within block
688+
689+ for (int ii = 0 ; ii < nnei_i; ii++) {
690+ for (int jj = 0 ; jj < nnei_j; jj++) {
691+ FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
692+ int table_idx = 0 ;
693+ locate_xx_se_t (xx, table_idx, lower, upper, -max, max, stride0, stride1);
694+
695+ FPTYPE grad_sum = 0.0 ;
696+ for (int mm = 0 ; mm < last_layer_size; mm++) {
697+ FPTYPE var[6 ];
698+ load_polynomial_params (var, table, table_idx, mm, last_layer_size);
699+
700+ FPTYPE dres_dxx = var[1 ] + 2.0 * var[2 ] * xx + 3.0 * var[3 ] * xx * xx +
701+ 4.0 * var[4 ] * xx * xx * xx + 5.0 * var[5 ] * xx * xx * xx * xx;
702+
703+ FPTYPE dy_val = dy[block_idx * nnei_i * nnei_j * last_layer_size +
704+ ii * nnei_j * last_layer_size +
705+ jj * last_layer_size + mm];
706+ grad_sum += dy_val * dres_dxx;
707+ }
708+
709+ if (thread_idx == 0 ) { // Only one thread writes the gradient
710+ dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = grad_sum;
711+ }
712+ }
713+ }
714+ }
715+
716+ template <typename FPTYPE , int MTILE , int KTILE >
717+ __global__ void tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial (
718+ FPTYPE * dz_dy,
719+ const FPTYPE * table,
720+ const FPTYPE * em_x,
721+ const FPTYPE * em,
722+ const FPTYPE * dz_dy_dem_x,
723+ const FPTYPE lower,
724+ const FPTYPE upper,
725+ const FPTYPE max,
726+ const FPTYPE stride0,
727+ const FPTYPE stride1,
728+ const int nnei_i,
729+ const int nnei_j,
730+ const int last_layer_size) {
731+ const int_64 block_idx = blockIdx .x ; // nloc
732+ const int thread_idx = threadIdx .x ; // last_layer_size
733+
734+ for (int ii = 0 ; ii < nnei_i; ii++) {
735+ for (int jj = 0 ; jj < nnei_j; jj++) {
736+ FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
737+ FPTYPE dz_dy_dem_x_val = dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
738+
739+ int table_idx = 0 ;
740+ locate_xx_se_t (xx, table_idx, lower, upper, -max, max, stride0, stride1);
741+
742+ FPTYPE var[6 ];
743+ load_polynomial_params (var, table, table_idx, thread_idx, last_layer_size);
744+
745+ FPTYPE dres_dxx = var[1 ] + 2.0 * var[2 ] * xx + 3.0 * var[3 ] * xx * xx +
746+ 4.0 * var[4 ] * xx * xx * xx + 5.0 * var[5 ] * xx * xx * xx * xx;
747+
748+ // Store result preserving the nt_i x nt_j structure
749+ dz_dy[block_idx * nnei_i * nnei_j * last_layer_size +
750+ ii * nnei_j * last_layer_size +
751+ jj * last_layer_size + thread_idx] = dz_dy_dem_x_val * dres_dxx;
752+ }
753+ }
754+ }
755+
633756template <typename FPTYPE , int MTILE , int KTILE >
634757__global__ void tabulate_fusion_se_r_fifth_order_polynomial (
635758 FPTYPE * out,
@@ -923,6 +1046,82 @@ void tabulate_fusion_se_t_grad_grad_gpu(FPTYPE* dz_dy,
9231046 DPErrcheck (gpuDeviceSynchronize ());
9241047}
9251048
1049+ // SE_T_TEBD kernels - preserve full nt_i x nt_j structure unlike SE_T
1050+ template <typename FPTYPE >
1051+ void tabulate_fusion_se_t_tebd_gpu (FPTYPE * out,
1052+ const FPTYPE * table,
1053+ const FPTYPE * table_info,
1054+ const FPTYPE * em_x,
1055+ const FPTYPE * em,
1056+ const int nloc,
1057+ const int nnei_i,
1058+ const int nnei_j,
1059+ const int last_layer_size) {
1060+ if (nloc <= 0 ) {
1061+ return ;
1062+ }
1063+ DPErrcheck (gpuGetLastError ());
1064+ DPErrcheck (gpuDeviceSynchronize ());
1065+ tabulate_fusion_se_t_tebd_fifth_order_polynomial<FPTYPE , MM , KK >
1066+ <<<nloc, last_layer_size>>> (
1067+ out, table, em_x, em, table_info[0 ], table_info[1 ], table_info[2 ],
1068+ table_info[3 ], table_info[4 ], nnei_i, nnei_j, last_layer_size);
1069+ DPErrcheck (gpuGetLastError ());
1070+ DPErrcheck (gpuDeviceSynchronize ());
1071+ }
1072+
1073+ template <typename FPTYPE >
1074+ void tabulate_fusion_se_t_tebd_grad_gpu (FPTYPE * dy_dem_x,
1075+ const FPTYPE * table,
1076+ const FPTYPE * table_info,
1077+ const FPTYPE * em_x,
1078+ const FPTYPE * em,
1079+ const FPTYPE * dy,
1080+ const int nloc,
1081+ const int nnei_i,
1082+ const int nnei_j,
1083+ const int last_layer_size) {
1084+ if (nloc <= 0 ) {
1085+ return ;
1086+ }
1087+ DPErrcheck (gpuGetLastError ());
1088+ DPErrcheck (gpuDeviceSynchronize ());
1089+ DPErrcheck (gpuMemset (dy_dem_x, 0 , sizeof (FPTYPE ) * nloc * nnei_i * nnei_j));
1090+ tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial<FPTYPE , MM , KK >
1091+ <<<nloc, KK * WARP_SIZE >>> (
1092+ dy_dem_x, table, em_x, em, dy, table_info[0 ], table_info[1 ],
1093+ table_info[2 ], table_info[3 ], table_info[4 ], nnei_i, nnei_j, last_layer_size);
1094+ DPErrcheck (gpuGetLastError ());
1095+ DPErrcheck (gpuDeviceSynchronize ());
1096+ }
1097+
1098+ template <typename FPTYPE >
1099+ void tabulate_fusion_se_t_tebd_grad_grad_gpu (FPTYPE * dz_dy,
1100+ const FPTYPE * table,
1101+ const FPTYPE * table_info,
1102+ const FPTYPE * em_x,
1103+ const FPTYPE * em,
1104+ const FPTYPE * dz_dy_dem_x,
1105+ const int nloc,
1106+ const int nnei_i,
1107+ const int nnei_j,
1108+ const int last_layer_size) {
1109+ if (nloc <= 0 ) {
1110+ return ;
1111+ }
1112+ DPErrcheck (gpuGetLastError ());
1113+ DPErrcheck (gpuDeviceSynchronize ());
1114+ DPErrcheck (gpuMemset (dz_dy, 0 , sizeof (FPTYPE ) * nloc * nnei_i * nnei_j * last_layer_size));
1115+
1116+ tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial<FPTYPE , MM , KK >
1117+ <<<nloc, last_layer_size>>> (
1118+ dz_dy, table, em_x, em, dz_dy_dem_x,
1119+ table_info[0 ], table_info[1 ], table_info[2 ], table_info[3 ], table_info[4 ],
1120+ nnei_i, nnei_j, last_layer_size);
1121+ DPErrcheck (gpuGetLastError ());
1122+ DPErrcheck (gpuDeviceSynchronize ());
1123+ }
1124+
9261125template <typename FPTYPE >
9271126void tabulate_fusion_se_r_gpu (FPTYPE * out,
9281127 const FPTYPE * table,
@@ -1181,4 +1380,75 @@ template void tabulate_fusion_se_r_grad_grad_gpu<double>(
11811380 const int nnei,
11821381 const int last_layer_size);
11831382
1383+ // Template instantiations for SE_T_TEBD GPU functions
1384+ template void tabulate_fusion_se_t_tebd_gpu<float >(
1385+ float * out,
1386+ const float * table,
1387+ const float * table_info,
1388+ const float * em_x,
1389+ const float * em,
1390+ const int nloc,
1391+ const int nnei_i,
1392+ const int nnei_j,
1393+ const int last_layer_size);
1394+
1395+ template void tabulate_fusion_se_t_tebd_gpu<double >(
1396+ double * out,
1397+ const double * table,
1398+ const double * table_info,
1399+ const double * em_x,
1400+ const double * em,
1401+ const int nloc,
1402+ const int nnei_i,
1403+ const int nnei_j,
1404+ const int last_layer_size);
1405+
1406+ template void tabulate_fusion_se_t_tebd_grad_gpu<float >(
1407+ float * dy_dem_x,
1408+ const float * table,
1409+ const float * table_info,
1410+ const float * em_x,
1411+ const float * em,
1412+ const float * dy,
1413+ const int nloc,
1414+ const int nnei_i,
1415+ const int nnei_j,
1416+ const int last_layer_size);
1417+
1418+ template void tabulate_fusion_se_t_tebd_grad_gpu<double >(
1419+ double * dy_dem_x,
1420+ const double * table,
1421+ const double * table_info,
1422+ const double * em_x,
1423+ const double * em,
1424+ const double * dy,
1425+ const int nloc,
1426+ const int nnei_i,
1427+ const int nnei_j,
1428+ const int last_layer_size);
1429+
1430+ template void tabulate_fusion_se_t_tebd_grad_grad_gpu<float >(
1431+ float * dz_dy,
1432+ const float * table,
1433+ const float * table_info,
1434+ const float * em_x,
1435+ const float * em,
1436+ const float * dz_dy_dem_x,
1437+ const int nloc,
1438+ const int nnei_i,
1439+ const int nnei_j,
1440+ const int last_layer_size);
1441+
1442+ template void tabulate_fusion_se_t_tebd_grad_grad_gpu<double >(
1443+ double * dz_dy,
1444+ const double * table,
1445+ const double * table_info,
1446+ const double * em_x,
1447+ const double * em,
1448+ const double * dz_dy_dem_x,
1449+ const int nloc,
1450+ const int nnei_i,
1451+ const int nnei_j,
1452+ const int last_layer_size);
1453+
11841454} // namespace deepmd
0 commit comments