@@ -997,24 +997,6 @@ TestResult NNModuleTest::testModuleRoPE() {
997997 infinicore::nn::RoPE rope1 (128 , 2048 );
998998
999999 auto state1 = rope1.state_dict ();
1000- if (state1.find (" sin_cache" ) == state1.end ()) {
1001- spdlog::error (" RoPE sin_cache not found in state dict" );
1002- return false ;
1003- }
1004- if (state1.find (" cos_cache" ) == state1.end ()) {
1005- spdlog::error (" RoPE cos_cache not found in state dict" );
1006- return false ;
1007- }
1008-
1009- if (rope1.sin_cache ()->shape () != std::vector<size_t >({2048 , 64 })) {
1010- spdlog::error (" RoPE sin_cache shape mismatch. Expected {{2048, 64}}" );
1011- return false ;
1012- }
1013-
1014- if (rope1.cos_cache ()->shape () != std::vector<size_t >({2048 , 64 })) {
1015- spdlog::error (" RoPE cos_cache shape mismatch. Expected {{2048, 64}}" );
1016- return false ;
1017- }
10181000
10191001 if (rope1.head_dim () != 128 ) {
10201002 spdlog::error (" head_dim mismatch. Expected 128, got {}" , rope1.head_dim ());
@@ -1050,15 +1032,15 @@ TestResult NNModuleTest::testModuleRoPE() {
10501032
10511033 // Test 3: Different algorithms
10521034 spdlog::info (" Test 3: Testing different algorithms" );
1053- infinicore::nn::RoPE rope_gptj (64 , 1024 , 10000.0 , INFINIOP_ROPE_ALGO_GPT_J );
1054- infinicore::nn::RoPE rope_gptneox (64 , 1024 , 10000.0 , INFINIOP_ROPE_ALGO_GPT_NEOX );
1035+ infinicore::nn::RoPE rope_gptj (64 , 1024 , 10000.0 , infinicore::nn::RoPEAlgo::GPT_J );
1036+ infinicore::nn::RoPE rope_gptneox (64 , 1024 , 10000.0 , infinicore::nn::RoPEAlgo::GPT_NEOX );
10551037
1056- if (rope_gptj.algo () != INFINIOP_ROPE_ALGO_GPT_J ) {
1038+ if (rope_gptj.algo () != infinicore::nn::RoPEAlgo::GPT_J ) {
10571039 spdlog::error (" GPT_J algorithm not set correctly" );
10581040 return false ;
10591041 }
10601042
1061- if (rope_gptneox.algo () != INFINIOP_ROPE_ALGO_GPT_NEOX ) {
1043+ if (rope_gptneox.algo () != infinicore::nn::RoPEAlgo::GPT_NEOX ) {
10621044 spdlog::error (" GPT_NEOX algorithm not set correctly" );
10631045 return false ;
10641046 }
@@ -1104,27 +1086,9 @@ TestResult NNModuleTest::testModuleRoPE() {
11041086 spdlog::debug (" Different theta values test passed" );
11051087
11061088 // Test 5: load_state_dict
1107- spdlog::info (" Test 5: Testing load_state_dict for RoPE" );
1108- auto new_sin_cache = infinicore::Tensor::ones ({2048 , 64 }, infinicore::DataType::F32, infinicore::Device ());
1109- auto new_cos_cache = infinicore::Tensor::ones ({2048 , 64 }, infinicore::DataType::F32, infinicore::Device ());
1110-
11111089 std::unordered_map<std::string, infinicore::Tensor> new_state;
1112- new_state.emplace (" sin_cache" , new_sin_cache);
1113- new_state.emplace (" cos_cache" , new_cos_cache);
1114-
11151090 rope1.load_state_dict (new_state);
1116-
1117- if (!tensorsAllClose (rope1.sin_cache (), new_sin_cache, 1e-7 , 1e-7 )) {
1118- spdlog::error (" RoPE sin_cache not loaded correctly" );
1119- return false ;
1120- }
1121-
1122- if (!tensorsAllClose (rope1.cos_cache (), new_cos_cache, 1e-7 , 1e-7 )) {
1123- spdlog::error (" RoPE cos_cache not loaded correctly" );
1124- return false ;
1125- }
1126-
1127- spdlog::debug (" load_state_dict for RoPE passed" );
1091+ spdlog::debug (" load_state_dict for RoPE passed (no parameters to load)" );
11281092
11291093 // Test 6: extra_repr
11301094 spdlog::info (" Test 6: Testing extra_repr" );
0 commit comments