Skip to content

Commit e58d032

Browse files
committed
feat: improve CUDA realization
1 parent 27a4e30 commit e58d032

2 files changed

Lines changed: 179 additions & 83 deletions

File tree

deepmd/pt/model/descriptor/se_t_tebd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,8 +960,8 @@ def forward(
960960
gg_s = torch.ops.deepmd.tabulate_fusion_se_t_tebd(
961961
self.compress_data[0].contiguous(),
962962
self.compress_info[0].cpu().contiguous(),
963-
ebd_env_ij.contiguous(),
964-
env_ij.contiguous(),
963+
ebd_env_ij.contiguous(), # em_x: (nfnl * nt_i * nt_j, 1)
964+
env_ij.contiguous(), # em: (nfnl, nt_i, nt_j)
965965
self.filter_neuron[-1],
966966
)[0]
967967
# SE_T_TEBD tabulation preserves the full neighbor structure

source/lib/src/gpu/tabulate.cu

Lines changed: 177 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,38 @@ __forceinline__ __device__ void locate_xx_se_t(FPTYPE& xx,
103103
}
104104
}
105105

106+
// Locate xx in the piecewise-uniform table: set table_idx to the index of the
// interval that contains xx and replace xx by its offset from the start of
// that interval. Both xx and table_idx are written in place.
//   xx <  min           : table_idx = 0, xx = 0
//   min   <= xx < lower : spacing stride1, counted from min
//   lower <= xx < upper : spacing stride0, counted from lower
//   upper <= xx < max   : spacing stride1, counted from upper
//   otherwise           : last interval index, xx = 0
// Per the original note this is the same as locate_xx_se_t (not re-verified
// here). The se_t_tebd kernels in this file call it with min = -max.
107+
template <typename FPTYPE>
108+
__forceinline__ __device__ void locate_xx_se_t_tebd(FPTYPE& xx,
109+
int& table_idx,
110+
const FPTYPE& lower,
111+
const FPTYPE& upper,
112+
const FPTYPE& min,
113+
const FPTYPE& max,
114+
const FPTYPE& stride0,
115+
const FPTYPE& stride1) {
116+
// Below the table range: use the first interval with zero offset.
if (xx < min) {
117+
table_idx = 0;
118+
xx = (FPTYPE)0.;
119+
} else if (xx < lower) {
120+
table_idx = (int)((xx - min) / stride1);
121+
xx -= (table_idx * stride1 + min);
122+
} else if (xx < upper) {
123+
// Number of stride1 intervals in [min, lower); offsets the index below.
int first_stride = int((lower - min) / stride1);
124+
table_idx = first_stride + (int)((xx - lower) / stride0);
125+
xx -= ((table_idx - first_stride) * stride0 + lower);
126+
} else if (xx < max) {
127+
// Intervals in [min, lower) plus intervals in [lower, upper).
int first_stride =
128+
int((lower - min) / stride1) + int((upper - lower) / stride0);
129+
table_idx = first_stride + (int)((xx - upper) / stride1);
130+
xx -= ((table_idx - first_stride) * stride1 + upper);
131+
} else {
132+
// No earlier branch matched: total interval count minus one, zero offset.
table_idx = int((lower - min) / stride1) + int((upper - lower) / stride0) +
133+
(int)((max - upper) / stride1) - 1;
134+
xx = (FPTYPE)0.;
135+
}
136+
}
137+
106138
template <typename FPTYPE>
107139
__forceinline__ __device__ void locate_xx_se_r(FPTYPE& xx,
108140
int& table_idx,
@@ -644,30 +676,48 @@ __global__ void tabulate_fusion_se_t_tebd_fifth_order_polynomial(
644676
const int nnei_i,
645677
const int nnei_j,
646678
const int last_layer_size) {
647-
const int_64 block_idx = blockIdx.x; // nloc
648-
const int thread_idx = threadIdx.x; // last_layer_size
649-
650-
for (int ii = 0; ii < nnei_i; ii++) {
651-
for (int jj = 0; jj < nnei_j; jj++) {
652-
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
653-
int table_idx = 0;
654-
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
655-
656-
FPTYPE var[6];
657-
load_polynomial_params(var, table, table_idx, thread_idx,
658-
last_layer_size);
659-
660-
FPTYPE res =
661-
var[0] +
662-
(var[1] +
663-
(var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
664-
xx;
665-
666-
// Store result preserving the nt_i x nt_j structure
667-
out[block_idx * nnei_i * nnei_j * last_layer_size +
668-
ii * nnei_j * last_layer_size + jj * last_layer_size + thread_idx] =
669-
res;
670-
}
679+
// NOT USED: em: (nfnl, nnei_i, nnei_j)
680+
// em_x: (nfnl * nnei_i * nnei_j, 1) flat version of em
681+
// blockDim.x -> total threads in a block: nnei_i * nnei_j
682+
// gridDim.x -> total blocks in a grid: nloc
683+
684+
// Identify which atom and neighbor pair this thread is responsible for.
685+
// block_idx corresponds to the atom index, given by the block index.
686+
const int_64 block_idx = blockIdx.x;
687+
688+
// thread_idx is the flattened 1D index for the (ii, jj) neighbor pair.
689+
const int_64 thread_idx = threadIdx.x;
690+
691+
// Recover the 2D (ii, jj) indices from the 1D thread index.
692+
// thread_idx = ii * nnei_j + jj
693+
const int_64 ii = thread_idx / nnei_j;
694+
const int_64 jj = thread_idx % nnei_j;
695+
696+
// Read the input value xx for this specific neighbor pair.
697+
const int_64 em_x_idx = (int_64)block_idx * nnei_i * nnei_j + thread_idx;
698+
FPTYPE xx = em_x[em_x_idx];
699+
700+
// Determine the table index based on the value of xx.
701+
int table_idx = 0;
702+
locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0, stride1);
703+
704+
// Serially loop through the 'last_layer_size' dimension to calculate all
705+
// features.
706+
for (int idx = 0; idx < last_layer_size; idx++) {
707+
FPTYPE var[6];
708+
load_polynomial_params(var, table, table_idx, idx, last_layer_size);
709+
FPTYPE res =
710+
var[0] +
711+
(var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
712+
xx;
713+
// Calculate the unique 1D output index for the 4D tensor (block_idx, ii,
714+
// jj, idx).
715+
const int_64 out_idx =
716+
(int_64)block_idx * nnei_i * nnei_j * last_layer_size +
717+
(int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
718+
idx;
719+
// Write the result to the global output memory.
720+
out[out_idx] = res;
671721
}
672722
}
673723

@@ -686,35 +736,49 @@ __global__ void tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial(
686736
const int nnei_i,
687737
const int nnei_j,
688738
const int last_layer_size) {
689-
const int_64 block_idx = blockIdx.x; // nloc
690-
const int thread_idx = threadIdx.x; // thread within block
691-
692-
for (int ii = 0; ii < nnei_i; ii++) {
693-
for (int jj = 0; jj < nnei_j; jj++) {
694-
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
695-
int table_idx = 0;
696-
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
697-
698-
FPTYPE grad_sum = 0.0;
699-
for (int mm = 0; mm < last_layer_size; mm++) {
700-
FPTYPE var[6];
701-
load_polynomial_params(var, table, table_idx, mm, last_layer_size);
739+
// blockDim.x -> nnei_i * nnei_j
740+
// gridDim.x -> nloc
741+
742+
// Identify which atom and neighbor pair this thread is responsible for.
743+
const int block_idx = blockIdx.x;
744+
const int thread_idx = threadIdx.x;
745+
const int ii = thread_idx / nnei_j;
746+
const int jj = thread_idx % nnei_j;
747+
748+
// Determine the table index based on the value of xx.
749+
const int_64 em_x_idx = (int_64)block_idx * nnei_i * nnei_j + thread_idx;
750+
FPTYPE xx = em_x[em_x_idx];
751+
int table_idx = 0;
752+
locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0, stride1);
753+
754+
// Accumulate the gradient contributions from all features.
755+
FPTYPE grad_sum = 0.0;
756+
for (int idx = 0; idx < last_layer_size; idx++) {
757+
FPTYPE var[6];
758+
load_polynomial_params(var, table, table_idx, idx, last_layer_size);
702759

703-
FPTYPE dres_dxx = var[1] + 2.0 * var[2] * xx + 3.0 * var[3] * xx * xx +
704-
4.0 * var[4] * xx * xx * xx +
705-
5.0 * var[5] * xx * xx * xx * xx;
760+
// Calculate the derivative of the polynomial with respect to xx.
761+
FPTYPE dres_dxx =
762+
var[1] + ((FPTYPE)2. * var[2] +
763+
((FPTYPE)3. * var[3] +
764+
((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
765+
xx) *
766+
xx;
706767

707-
FPTYPE dy_val =
708-
dy[block_idx * nnei_i * nnei_j * last_layer_size +
709-
ii * nnei_j * last_layer_size + jj * last_layer_size + mm];
710-
grad_sum += dy_val * dres_dxx;
711-
}
768+
// Read the incoming gradient from the previous layer.
769+
const int_64 dy_idx =
770+
(int_64)block_idx * nnei_i * nnei_j * last_layer_size +
771+
(int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
772+
idx;
773+
FPTYPE dy_val = dy[dy_idx];
712774

713-
if (thread_idx == 0) { // Only one thread writes the gradient
714-
dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = grad_sum;
715-
}
716-
}
775+
// Apply the chain rule: dL/dxx = sum over idx [ (dL/d_res_mm) *
776+
// (d_res_mm/dxx) ]
777+
grad_sum += dy_val * dres_dxx;
717778
}
779+
780+
// Write the final summed gradient to the output array.
781+
dy_dem_x[em_x_idx] = grad_sum;
718782
}
719783

720784
template <typename FPTYPE, int MTILE, int KTILE>
@@ -732,31 +796,50 @@ __global__ void tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial(
732796
const int nnei_i,
733797
const int nnei_j,
734798
const int last_layer_size) {
735-
const int_64 block_idx = blockIdx.x; // nloc
736-
const int thread_idx = threadIdx.x; // last_layer_size
799+
// blockDim.x -> nnei_i * nnei_j
800+
// gridDim.x -> nloc
737801

738-
for (int ii = 0; ii < nnei_i; ii++) {
739-
for (int jj = 0; jj < nnei_j; jj++) {
740-
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
741-
FPTYPE dz_dy_dem_x_val =
742-
dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
802+
// Identify which atom and neighbor pair this thread is responsible for.
803+
const int block_idx = blockIdx.x;
804+
const int thread_idx = threadIdx.x;
805+
const int ii = thread_idx / nnei_j;
806+
const int jj = thread_idx % nnei_j;
743807

744-
int table_idx = 0;
745-
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
808+
const int_64 em_x_idx = (int_64)block_idx * nnei_i * nnei_j + thread_idx;
809+
FPTYPE xx = em_x[em_x_idx];
746810

747-
FPTYPE var[6];
748-
load_polynomial_params(var, table, table_idx, thread_idx,
749-
last_layer_size);
811+
// Read the incoming gradient for xx. This value is the same for all 'idx'
812+
// features.
813+
const FPTYPE dz_dy_dem_x_val = dz_dy_dem_x[em_x_idx];
750814

751-
FPTYPE dres_dxx = var[1] + 2.0 * var[2] * xx + 3.0 * var[3] * xx * xx +
752-
4.0 * var[4] * xx * xx * xx +
753-
5.0 * var[5] * xx * xx * xx * xx;
815+
// Determine the table index based on the value of xx.
816+
int table_idx = 0;
817+
locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0, stride1);
754818

755-
// Store result preserving the nt_i x nt_j structure
756-
dz_dy[block_idx * nnei_i * nnei_j * last_layer_size +
757-
ii * nnei_j * last_layer_size + jj * last_layer_size + thread_idx] =
758-
dz_dy_dem_x_val * dres_dxx;
759-
}
819+
// Serially loop through the 'last_layer_size' dimension.
820+
for (int idx = 0; idx < last_layer_size; idx++) {
821+
FPTYPE var[6];
822+
load_polynomial_params(var, table, table_idx, idx, last_layer_size);
823+
824+
// Calculate the derivative of the polynomial with respect to xx.
825+
FPTYPE dres_dxx =
826+
var[1] + ((FPTYPE)2. * var[2] +
827+
((FPTYPE)3. * var[3] +
828+
((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
829+
xx) *
830+
xx;
831+
832+
// Apply the chain rule: dz/dy_idx = (dz/dxx) * (dxx/dy_idx)
833+
// which simplifies to dz_dy_dem_x_val * dres_dxx
834+
FPTYPE out_grad = dz_dy_dem_x_val * dres_dxx;
835+
836+
// Calculate the unique 1D output index for the 4D tensor (block_idx, ii,
837+
// jj, idx).
838+
const int_64 out_idx =
839+
(int_64)block_idx * nnei_i * nnei_j * last_layer_size +
840+
(int_64)ii * nnei_j * last_layer_size + (int_64)jj * last_layer_size +
841+
idx;
842+
dz_dy[out_idx] = out_grad;
760843
}
761844
}
762845

@@ -1064,13 +1147,19 @@ void tabulate_fusion_se_t_tebd_gpu(FPTYPE* out,
10641147
const int nnei_i,
10651148
const int nnei_j,
10661149
const int last_layer_size) {
1067-
if (nloc <= 0) {
1150+
if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0) {
10681151
return;
10691152
}
1153+
// Grid dimension: One block for each atom: nloc
1154+
dim3 num_blocks(nloc);
1155+
// Block dimension: One thread for each (ii, jj) neighbor pair: nnei_i *
1156+
// nnei_j
1157+
dim3 num_threads(nnei_i * nnei_j);
1158+
10701159
DPErrcheck(gpuGetLastError());
10711160
DPErrcheck(gpuDeviceSynchronize());
10721161
tabulate_fusion_se_t_tebd_fifth_order_polynomial<FPTYPE, MM, KK>
1073-
<<<nloc, last_layer_size>>>(
1162+
<<<num_blocks, num_threads>>>(
10741163
out, table, em_x, em, table_info[0], table_info[1], table_info[2],
10751164
table_info[3], table_info[4], nnei_i, nnei_j, last_layer_size);
10761165
DPErrcheck(gpuGetLastError());
@@ -1088,18 +1177,22 @@ void tabulate_fusion_se_t_tebd_grad_gpu(FPTYPE* dy_dem_x,
10881177
const int nnei_i,
10891178
const int nnei_j,
10901179
const int last_layer_size) {
1091-
if (nloc <= 0) {
1180+
if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0) {
10921181
return;
10931182
}
1183+
// Define Grid and Block dimensions, matching the forward pass strategy.
1184+
dim3 num_blocks(nloc);
1185+
dim3 num_threads(nnei_i * nnei_j);
1186+
10941187
DPErrcheck(gpuGetLastError());
10951188
DPErrcheck(gpuDeviceSynchronize());
10961189
DPErrcheck(gpuMemset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei_i * nnei_j));
10971190
// table_info should be on CPU side
10981191
tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial<FPTYPE, MM, KK>
1099-
<<<nloc, KK * WARP_SIZE>>>(dy_dem_x, table, em_x, em, dy, table_info[0],
1100-
table_info[1], table_info[2], table_info[3],
1101-
table_info[4], nnei_i, nnei_j,
1102-
last_layer_size);
1192+
<<<num_blocks, num_threads>>>(dy_dem_x, table, em_x, em, dy,
1193+
table_info[0], table_info[1], table_info[2],
1194+
table_info[3], table_info[4], nnei_i,
1195+
nnei_j, last_layer_size);
11031196
DPErrcheck(gpuGetLastError());
11041197
DPErrcheck(gpuDeviceSynchronize());
11051198
}
@@ -1115,19 +1208,22 @@ void tabulate_fusion_se_t_tebd_grad_grad_gpu(FPTYPE* dz_dy,
11151208
const int nnei_i,
11161209
const int nnei_j,
11171210
const int last_layer_size) {
1118-
if (nloc <= 0) {
1211+
if (nloc <= 0 || nnei_i <= 0 || nnei_j <= 0) {
11191212
return;
11201213
}
1214+
dim3 num_blocks(nloc);
1215+
dim3 num_threads(nnei_i * nnei_j);
1216+
11211217
DPErrcheck(gpuGetLastError());
11221218
DPErrcheck(gpuDeviceSynchronize());
11231219
DPErrcheck(gpuMemset(
11241220
dz_dy, 0, sizeof(FPTYPE) * nloc * nnei_i * nnei_j * last_layer_size));
11251221

11261222
tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
1127-
<<<nloc, last_layer_size>>>(dz_dy, table, em_x, em, dz_dy_dem_x,
1128-
table_info[0], table_info[1], table_info[2],
1129-
table_info[3], table_info[4], nnei_i, nnei_j,
1130-
last_layer_size);
1223+
<<<num_blocks, num_threads>>>(dz_dy, table, em_x, em, dz_dy_dem_x,
1224+
table_info[0], table_info[1], table_info[2],
1225+
table_info[3], table_info[4], nnei_i,
1226+
nnei_j, last_layer_size);
11311227
DPErrcheck(gpuGetLastError());
11321228
DPErrcheck(gpuDeviceSynchronize());
11331229
}

0 commit comments

Comments
 (0)