@@ -76,5 +76,24 @@ TEST(CrossEntropyTest, SoftmaxCrossEntropyLossGrad_LabelTooLarge) {
7676 test.Run (OpTester::ExpectResult::kExpectFailure , " out of range" );
7777}
7878
79+ TEST (CrossEntropyTest, SoftmaxCrossEntropyLossGrad_LabelTooLargeWithWeights) {
80+ OpTester test (" SoftmaxCrossEntropyLossGrad" , 1 , onnxruntime::kMSDomain );
81+ test.AddAttribute (" reduction" , std::string (" mean" ));
82+ test.AddAttribute (" ignore_index" , static_cast <int64_t >(-1 ));
83+
84+ std::vector<float > dY_data = {1 .0f };
85+ std::vector<float > log_prob_data (3 * 5 , -1 .6094f );
86+ std::vector<int64_t > index_data = {0 , 5 , 2 }; // 5 is out of range [0, 5)
87+ std::vector<float > weight_data = {1 .0f , 1 .0f , 1 .0f , 1 .0f , 1 .0f };
88+
89+ test.AddInput <float >(" dY" , {}, dY_data);
90+ test.AddInput <float >(" log_prob" , {3 , 5 }, log_prob_data);
91+ test.AddInput <int64_t >(" index" , {3 }, index_data);
92+ test.AddInput <float >(" weight" , {5 }, weight_data);
93+ test.AddOutput <float >(" dX" , {3 , 5 }, std::vector<float >(15 , 0 .0f ));
94+
95+ test.Run (OpTester::ExpectResult::kExpectFailure , " out of range" );
96+ }
97+
7998} // namespace test
8099} // namespace onnxruntime
0 commit comments