Skip to content

Commit d687c49

Browse files
committed
dy should be all ones for sum
1 parent 7002e33 commit d687c49

1 file changed

Lines changed: 3 additions & 5 deletions

File tree

source/lib/tests/test_tabulate_se_t_tebd.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ class TestTabulateSeTTebd : public ::testing::Test {
507507
4.0511515406019899e-01};
508508

509509
// Environment matrix data (em) - same as em_x reshaped to 4x4x4
510-
std::vector<double> em = em_x; // Will be reshaped in tests
510+
std::vector<double> em = em_x;
511511

512512
// Expected outputs
513513
std::vector<double> expected_xyz_scatter = {
@@ -689,8 +689,7 @@ TEST_F(TestTabulateSeTTebd, tabulate_fusion_se_t_tebd_cpu) {
689689

690690
TEST_F(TestTabulateSeTTebd, tabulate_fusion_se_t_tebd_grad_cpu) {
691691
std::vector<double> dy_dem_x(em_x.size());
692-
std::vector<double> dy(expected_xyz_scatter.begin(),
693-
expected_xyz_scatter.end());
692+
std::vector<double> dy(nloc * nnei_i * nnei_j * last_layer_size, 1.0);
694693

695694
deepmd::tabulate_fusion_se_t_tebd_grad_cpu<double>(
696695
&dy_dem_x[0], &table[0], &table_info[0], &em_x[0], &em[0], &dy[0], nloc,
@@ -738,8 +737,7 @@ TEST_F(TestTabulateSeTTebd, tabulate_fusion_se_t_tebd_gpu) {
738737

739738
TEST_F(TestTabulateSeTTebd, tabulate_fusion_se_t_tebd_grad_gpu) {
740739
std::vector<double> dy_dem_x(em_x.size(), 0.0);
741-
std::vector<double> dy(expected_xyz_scatter.begin(),
742-
expected_xyz_scatter.end());
740+
std::vector<double> dy(nloc * nnei_i * nnei_j * last_layer_size, 1.0);
743741

744742
double *dy_dem_x_dev = NULL, *table_dev = NULL, *table_info_dev = NULL,
745743
*em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL;

0 commit comments

Comments
 (0)