Skip to content

Commit 3aff58b

Browse files
committed
realize compress in forward through OP: tabulate_fusion_se_t && add tabulate_fusion_se_t_tebd custom OP but not used and tested
1 parent cfcbb7b commit 3aff58b

5 files changed

Lines changed: 747 additions & 4 deletions

File tree

deepmd/pt/model/descriptor/se_t_tebd.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -953,15 +953,16 @@ def forward(
953953
if self.compress:
954954
# Use tabulated computation for the geometric embedding
955955
ebd_env_ij = env_ij.view(-1, 1)
956-
gg_s = torch.ops.deepmd.tabulate_fusion_se_t(
956+
gg_s_compressed = torch.ops.deepmd.tabulate_fusion_se_t(
957957
self.compress_data[0].contiguous(),
958958
self.compress_info[0].cpu().contiguous(),
959959
ebd_env_ij.contiguous(),
960960
env_ij.contiguous(),
961961
self.filter_neuron[-1],
962962
)[0]
963-
# Reshape back to the expected format: nfnl x nt_i x nt_j x ng
964-
gg_s = gg_s.view(nfnl, nnei, nnei, self.filter_neuron[-1])
963+
# The compressed output is nfnl x ng, need to expand to nfnl x nt_i x nt_j x ng
964+
# by replicating across the neighbor dimensions
965+
gg_s = gg_s_compressed.view(nfnl, 1, 1, self.filter_neuron[-1]).expand(nfnl, nnei, nnei, self.filter_neuron[-1])
965966
else:
966967
# nfnl x nt_i x nt_j x ng
967968
gg_s = self.filter_layers.networks[0](ss)

source/lib/include/tabulate.h

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,41 @@ void tabulate_fusion_se_r_grad_grad_cpu(FPTYPE* dz_dy,
111111
const int nnei,
112112
const int last_layer_size);
113113

114+
template <typename FPTYPE>
115+
void tabulate_fusion_se_t_tebd_cpu(FPTYPE* out,
116+
const FPTYPE* table,
117+
const FPTYPE* table_info,
118+
const FPTYPE* em_x,
119+
const FPTYPE* em,
120+
const int nloc,
121+
const int nnei_i,
122+
const int nnei_j,
123+
const int last_layer_size);
124+
125+
template <typename FPTYPE>
126+
void tabulate_fusion_se_t_tebd_grad_cpu(FPTYPE* dy_dem_x,
127+
const FPTYPE* table,
128+
const FPTYPE* table_info,
129+
const FPTYPE* em_x,
130+
const FPTYPE* em,
131+
const FPTYPE* dy,
132+
const int nloc,
133+
const int nnei_i,
134+
const int nnei_j,
135+
const int last_layer_size);
136+
137+
template <typename FPTYPE>
138+
void tabulate_fusion_se_t_tebd_grad_grad_cpu(FPTYPE* dz_dy,
139+
const FPTYPE* table,
140+
const FPTYPE* table_info,
141+
const FPTYPE* em_x,
142+
const FPTYPE* em,
143+
const FPTYPE* dz_dy_dem_x,
144+
const int nloc,
145+
const int nnei_i,
146+
const int nnei_j,
147+
const int last_layer_size);
148+
114149
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
115150
template <typename FPTYPE>
116151
void tabulate_fusion_se_a_gpu(FPTYPE* out,
@@ -219,5 +254,40 @@ void tabulate_fusion_se_r_grad_grad_gpu(FPTYPE* dz_dy,
219254
const int nloc,
220255
const int nnei,
221256
const int last_layer_size);
257+
258+
template <typename FPTYPE>
259+
void tabulate_fusion_se_t_tebd_gpu(FPTYPE* out,
260+
const FPTYPE* table,
261+
const FPTYPE* table_info,
262+
const FPTYPE* em_x,
263+
const FPTYPE* em,
264+
const int nloc,
265+
const int nnei_i,
266+
const int nnei_j,
267+
const int last_layer_size);
268+
269+
template <typename FPTYPE>
270+
void tabulate_fusion_se_t_tebd_grad_gpu(FPTYPE* dy_dem_x,
271+
const FPTYPE* table,
272+
const FPTYPE* table_info,
273+
const FPTYPE* em_x,
274+
const FPTYPE* em,
275+
const FPTYPE* dy,
276+
const int nloc,
277+
const int nnei_i,
278+
const int nnei_j,
279+
const int last_layer_size);
280+
281+
template <typename FPTYPE>
282+
void tabulate_fusion_se_t_tebd_grad_grad_gpu(FPTYPE* dz_dy,
283+
const FPTYPE* table,
284+
const FPTYPE* table_info,
285+
const FPTYPE* em_x,
286+
const FPTYPE* em,
287+
const FPTYPE* dz_dy_dem_x,
288+
const int nloc,
289+
const int nnei_i,
290+
const int nnei_j,
291+
const int last_layer_size);
222292
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
223293
} // namespace deepmd

source/lib/src/gpu/tabulate.cu

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,129 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
630630
dz_dy[block_idx * last_layer_size + thread_idx] = sum;
631631
}
632632

633+
template <typename FPTYPE, int MM, int KK>
634+
__global__ void tabulate_fusion_se_t_tebd_fifth_order_polynomial(
635+
FPTYPE* out,
636+
const FPTYPE* table,
637+
const FPTYPE* em_x,
638+
const FPTYPE* em,
639+
const FPTYPE lower,
640+
const FPTYPE upper,
641+
const FPTYPE max,
642+
const FPTYPE stride0,
643+
const FPTYPE stride1,
644+
const int nnei_i,
645+
const int nnei_j,
646+
const int last_layer_size) {
647+
const int_64 block_idx = blockIdx.x; // nloc
648+
const int thread_idx = threadIdx.x; // last_layer_size
649+
650+
for (int ii = 0; ii < nnei_i; ii++) {
651+
for (int jj = 0; jj < nnei_j; jj++) {
652+
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
653+
int table_idx = 0;
654+
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
655+
656+
FPTYPE var[6];
657+
load_polynomial_params(var, table, table_idx, thread_idx, last_layer_size);
658+
659+
FPTYPE res = var[0] +
660+
(var[1] +
661+
(var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx;
662+
663+
// Store result preserving the nt_i x nt_j structure
664+
out[block_idx * nnei_i * nnei_j * last_layer_size +
665+
ii * nnei_j * last_layer_size +
666+
jj * last_layer_size + thread_idx] = res;
667+
}
668+
}
669+
}
670+
671+
template <typename FPTYPE, int MM, int KK>
672+
__global__ void tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial(
673+
FPTYPE* dy_dem_x,
674+
const FPTYPE* table,
675+
const FPTYPE* em_x,
676+
const FPTYPE* em,
677+
const FPTYPE* dy,
678+
const FPTYPE lower,
679+
const FPTYPE upper,
680+
const FPTYPE max,
681+
const FPTYPE stride0,
682+
const FPTYPE stride1,
683+
const int nnei_i,
684+
const int nnei_j,
685+
const int last_layer_size) {
686+
const int_64 block_idx = blockIdx.x; // nloc
687+
const int thread_idx = threadIdx.x; // thread within block
688+
689+
for (int ii = 0; ii < nnei_i; ii++) {
690+
for (int jj = 0; jj < nnei_j; jj++) {
691+
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
692+
int table_idx = 0;
693+
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
694+
695+
FPTYPE grad_sum = 0.0;
696+
for (int mm = 0; mm < last_layer_size; mm++) {
697+
FPTYPE var[6];
698+
load_polynomial_params(var, table, table_idx, mm, last_layer_size);
699+
700+
FPTYPE dres_dxx = var[1] + 2.0 * var[2] * xx + 3.0 * var[3] * xx * xx +
701+
4.0 * var[4] * xx * xx * xx + 5.0 * var[5] * xx * xx * xx * xx;
702+
703+
FPTYPE dy_val = dy[block_idx * nnei_i * nnei_j * last_layer_size +
704+
ii * nnei_j * last_layer_size +
705+
jj * last_layer_size + mm];
706+
grad_sum += dy_val * dres_dxx;
707+
}
708+
709+
if (thread_idx == 0) { // Only one thread writes the gradient
710+
dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = grad_sum;
711+
}
712+
}
713+
}
714+
}
715+
716+
template <typename FPTYPE, int MTILE, int KTILE>
717+
__global__ void tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial(
718+
FPTYPE* dz_dy,
719+
const FPTYPE* table,
720+
const FPTYPE* em_x,
721+
const FPTYPE* em,
722+
const FPTYPE* dz_dy_dem_x,
723+
const FPTYPE lower,
724+
const FPTYPE upper,
725+
const FPTYPE max,
726+
const FPTYPE stride0,
727+
const FPTYPE stride1,
728+
const int nnei_i,
729+
const int nnei_j,
730+
const int last_layer_size) {
731+
const int_64 block_idx = blockIdx.x; // nloc
732+
const int thread_idx = threadIdx.x; // last_layer_size
733+
734+
for (int ii = 0; ii < nnei_i; ii++) {
735+
for (int jj = 0; jj < nnei_j; jj++) {
736+
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
737+
FPTYPE dz_dy_dem_x_val = dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
738+
739+
int table_idx = 0;
740+
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
741+
742+
FPTYPE var[6];
743+
load_polynomial_params(var, table, table_idx, thread_idx, last_layer_size);
744+
745+
FPTYPE dres_dxx = var[1] + 2.0 * var[2] * xx + 3.0 * var[3] * xx * xx +
746+
4.0 * var[4] * xx * xx * xx + 5.0 * var[5] * xx * xx * xx * xx;
747+
748+
// Store result preserving the nt_i x nt_j structure
749+
dz_dy[block_idx * nnei_i * nnei_j * last_layer_size +
750+
ii * nnei_j * last_layer_size +
751+
jj * last_layer_size + thread_idx] = dz_dy_dem_x_val * dres_dxx;
752+
}
753+
}
754+
}
755+
633756
template <typename FPTYPE, int MTILE, int KTILE>
634757
__global__ void tabulate_fusion_se_r_fifth_order_polynomial(
635758
FPTYPE* out,
@@ -923,6 +1046,82 @@ void tabulate_fusion_se_t_grad_grad_gpu(FPTYPE* dz_dy,
9231046
DPErrcheck(gpuDeviceSynchronize());
9241047
}
9251048

1049+
// SE_T_TEBD kernels - preserve full nt_i x nt_j structure unlike SE_T
1050+
template <typename FPTYPE>
1051+
void tabulate_fusion_se_t_tebd_gpu(FPTYPE* out,
1052+
const FPTYPE* table,
1053+
const FPTYPE* table_info,
1054+
const FPTYPE* em_x,
1055+
const FPTYPE* em,
1056+
const int nloc,
1057+
const int nnei_i,
1058+
const int nnei_j,
1059+
const int last_layer_size) {
1060+
if (nloc <= 0) {
1061+
return;
1062+
}
1063+
DPErrcheck(gpuGetLastError());
1064+
DPErrcheck(gpuDeviceSynchronize());
1065+
tabulate_fusion_se_t_tebd_fifth_order_polynomial<FPTYPE, MM, KK>
1066+
<<<nloc, last_layer_size>>>(
1067+
out, table, em_x, em, table_info[0], table_info[1], table_info[2],
1068+
table_info[3], table_info[4], nnei_i, nnei_j, last_layer_size);
1069+
DPErrcheck(gpuGetLastError());
1070+
DPErrcheck(gpuDeviceSynchronize());
1071+
}
1072+
1073+
template <typename FPTYPE>
1074+
void tabulate_fusion_se_t_tebd_grad_gpu(FPTYPE* dy_dem_x,
1075+
const FPTYPE* table,
1076+
const FPTYPE* table_info,
1077+
const FPTYPE* em_x,
1078+
const FPTYPE* em,
1079+
const FPTYPE* dy,
1080+
const int nloc,
1081+
const int nnei_i,
1082+
const int nnei_j,
1083+
const int last_layer_size) {
1084+
if (nloc <= 0) {
1085+
return;
1086+
}
1087+
DPErrcheck(gpuGetLastError());
1088+
DPErrcheck(gpuDeviceSynchronize());
1089+
DPErrcheck(gpuMemset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei_i * nnei_j));
1090+
tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial<FPTYPE, MM, KK>
1091+
<<<nloc, KK * WARP_SIZE>>>(
1092+
dy_dem_x, table, em_x, em, dy, table_info[0], table_info[1],
1093+
table_info[2], table_info[3], table_info[4], nnei_i, nnei_j, last_layer_size);
1094+
DPErrcheck(gpuGetLastError());
1095+
DPErrcheck(gpuDeviceSynchronize());
1096+
}
1097+
1098+
template <typename FPTYPE>
1099+
void tabulate_fusion_se_t_tebd_grad_grad_gpu(FPTYPE* dz_dy,
1100+
const FPTYPE* table,
1101+
const FPTYPE* table_info,
1102+
const FPTYPE* em_x,
1103+
const FPTYPE* em,
1104+
const FPTYPE* dz_dy_dem_x,
1105+
const int nloc,
1106+
const int nnei_i,
1107+
const int nnei_j,
1108+
const int last_layer_size) {
1109+
if (nloc <= 0) {
1110+
return;
1111+
}
1112+
DPErrcheck(gpuGetLastError());
1113+
DPErrcheck(gpuDeviceSynchronize());
1114+
DPErrcheck(gpuMemset(dz_dy, 0, sizeof(FPTYPE) * nloc * nnei_i * nnei_j * last_layer_size));
1115+
1116+
tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
1117+
<<<nloc, last_layer_size>>>(
1118+
dz_dy, table, em_x, em, dz_dy_dem_x,
1119+
table_info[0], table_info[1], table_info[2], table_info[3], table_info[4],
1120+
nnei_i, nnei_j, last_layer_size);
1121+
DPErrcheck(gpuGetLastError());
1122+
DPErrcheck(gpuDeviceSynchronize());
1123+
}
1124+
9261125
template <typename FPTYPE>
9271126
void tabulate_fusion_se_r_gpu(FPTYPE* out,
9281127
const FPTYPE* table,
@@ -1181,4 +1380,75 @@ template void tabulate_fusion_se_r_grad_grad_gpu<double>(
11811380
const int nnei,
11821381
const int last_layer_size);
11831382

1383+
// Template instantiations for SE_T_TEBD GPU functions
1384+
template void tabulate_fusion_se_t_tebd_gpu<float>(
1385+
float* out,
1386+
const float* table,
1387+
const float* table_info,
1388+
const float* em_x,
1389+
const float* em,
1390+
const int nloc,
1391+
const int nnei_i,
1392+
const int nnei_j,
1393+
const int last_layer_size);
1394+
1395+
template void tabulate_fusion_se_t_tebd_gpu<double>(
1396+
double* out,
1397+
const double* table,
1398+
const double* table_info,
1399+
const double* em_x,
1400+
const double* em,
1401+
const int nloc,
1402+
const int nnei_i,
1403+
const int nnei_j,
1404+
const int last_layer_size);
1405+
1406+
template void tabulate_fusion_se_t_tebd_grad_gpu<float>(
1407+
float* dy_dem_x,
1408+
const float* table,
1409+
const float* table_info,
1410+
const float* em_x,
1411+
const float* em,
1412+
const float* dy,
1413+
const int nloc,
1414+
const int nnei_i,
1415+
const int nnei_j,
1416+
const int last_layer_size);
1417+
1418+
template void tabulate_fusion_se_t_tebd_grad_gpu<double>(
1419+
double* dy_dem_x,
1420+
const double* table,
1421+
const double* table_info,
1422+
const double* em_x,
1423+
const double* em,
1424+
const double* dy,
1425+
const int nloc,
1426+
const int nnei_i,
1427+
const int nnei_j,
1428+
const int last_layer_size);
1429+
1430+
template void tabulate_fusion_se_t_tebd_grad_grad_gpu<float>(
1431+
float* dz_dy,
1432+
const float* table,
1433+
const float* table_info,
1434+
const float* em_x,
1435+
const float* em,
1436+
const float* dz_dy_dem_x,
1437+
const int nloc,
1438+
const int nnei_i,
1439+
const int nnei_j,
1440+
const int last_layer_size);
1441+
1442+
template void tabulate_fusion_se_t_tebd_grad_grad_gpu<double>(
1443+
double* dz_dy,
1444+
const double* table,
1445+
const double* table_info,
1446+
const double* em_x,
1447+
const double* em,
1448+
const double* dz_dy_dem_x,
1449+
const int nloc,
1450+
const int nnei_i,
1451+
const int nnei_j,
1452+
const int last_layer_size);
1453+
11841454
} // namespace deepmd

0 commit comments

Comments
 (0)