@@ -507,7 +507,7 @@ class TestTabulateSeTTebd : public ::testing::Test {
507507 4.0511515406019899e-01 };
508508
509509 // Environment matrix data (em) - same as em_x reshaped to 4x4x4
510- std::vector<double > em = em_x; // Will be reshaped in tests
510+ std::vector<double > em = em_x;
511511
512512 // Expected outputs
513513 std::vector<double > expected_xyz_scatter = {
@@ -689,8 +689,7 @@ TEST_F(TestTabulateSeTTebd, tabulate_fusion_se_t_tebd_cpu) {
689689
690690TEST_F (TestTabulateSeTTebd, tabulate_fusion_se_t_tebd_grad_cpu) {
691691 std::vector<double > dy_dem_x (em_x.size ());
692- std::vector<double > dy (expected_xyz_scatter.begin (),
693- expected_xyz_scatter.end ());
692+ std::vector<double > dy (nloc * nnei_i * nnei_j * last_layer_size, 1.0 );
694693
695694 deepmd::tabulate_fusion_se_t_tebd_grad_cpu<double >(
696695 &dy_dem_x[0 ], &table[0 ], &table_info[0 ], &em_x[0 ], &em[0 ], &dy[0 ], nloc,
@@ -738,8 +737,7 @@ TEST_F(TestTabulateSeTTebd, tabulate_fusion_se_t_tebd_gpu) {
738737
739738TEST_F (TestTabulateSeTTebd, tabulate_fusion_se_t_tebd_grad_gpu) {
740739 std::vector<double > dy_dem_x (em_x.size (), 0.0 );
741- std::vector<double > dy (expected_xyz_scatter.begin (),
742- expected_xyz_scatter.end ());
740+ std::vector<double > dy (nloc * nnei_i * nnei_j * last_layer_size, 1.0 );
743741
744742 double *dy_dem_x_dev = NULL , *table_dev = NULL , *table_info_dev = NULL ,
745743 *em_x_dev = NULL , *em_dev = NULL , *dy_dev = NULL ;
0 commit comments