@@ -1023,6 +1023,106 @@ def test_scatter(axis: int, name: str, opset: int):
10231023 check_correctness (model , inputs = {"indices" : indices }, opset = opset )
10241024
10251025
1026+ @pytest .mark .parametrize (
1027+ "reduction, opset, data, indices, updates" ,
1028+ [
1029+ (
1030+ None ,
1031+ 11 ,
1032+ np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = "float32" ),
1033+ np .array ([[2 , 0 , 1 ], [1 , 2 , 0 ]], dtype = "int64" ),
1034+ np .array ([[30 , 10 , 20 ], [50 , 60 , 40 ]], dtype = "float32" ),
1035+ ),
1036+ (
1037+ "none" ,
1038+ 18 ,
1039+ np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = "float32" ),
1040+ np .array ([[2 , 0 , 1 ], [1 , 2 , 0 ]], dtype = "int64" ),
1041+ np .array ([[30 , 10 , 20 ], [50 , 60 , 40 ]], dtype = "float32" ),
1042+ ),
1043+ (
1044+ "add" ,
1045+ 16 ,
1046+ np .full ((2 , 3 ), 10 , dtype = "float32" ),
1047+ np .array ([[0 , 0 , 2 ], [1 , 1 , 2 ]], dtype = "int64" ),
1048+ np .array ([[2 , 5 , 7 ], [20 , 3 , 4 ]], dtype = "float32" ),
1049+ ),
1050+ (
1051+ "mul" ,
1052+ 16 ,
1053+ np .full ((2 , 3 ), 10 , dtype = "float32" ),
1054+ np .array ([[0 , 0 , 2 ], [1 , 1 , 2 ]], dtype = "int64" ),
1055+ np .array ([[2 , 5 , 7 ], [20 , 3 , 4 ]], dtype = "float32" ),
1056+ ),
1057+ (
1058+ "min" ,
1059+ 18 ,
1060+ np .full ((2 , 3 ), 10 , dtype = "float32" ),
1061+ np .array ([[0 , 0 , 2 ], [1 , 1 , 2 ]], dtype = "int64" ),
1062+ np .array ([[2 , 5 , 7 ], [20 , 3 , 4 ]], dtype = "float32" ),
1063+ ),
1064+ (
1065+ "max" ,
1066+ 18 ,
1067+ np .full ((2 , 3 ), 10 , dtype = "float32" ),
1068+ np .array ([[0 , 0 , 2 ], [1 , 1 , 2 ]], dtype = "int64" ),
1069+ np .array ([[2 , 5 , 7 ], [20 , 3 , 4 ]], dtype = "float32" ),
1070+ ),
1071+ ],
1072+ )
1073+ def test_scatter_elements_reduction (reduction , opset , data , indices , updates ):
1074+ attrs = {"axis" : 1 }
1075+ if reduction is not None :
1076+ attrs ["reduction" ] = reduction
1077+ scatter_elements_node = helper .make_node (
1078+ "ScatterElements" , ["data" , "indices" , "updates" ], ["output" ], ** attrs
1079+ )
1080+
1081+ graph = helper .make_graph (
1082+ [scatter_elements_node ],
1083+ "scatter_elements_reduction_test" ,
1084+ inputs = [
1085+ helper .make_tensor_value_info ("data" , TensorProto .FLOAT , list (data .shape )),
1086+ helper .make_tensor_value_info ("indices" , TensorProto .INT64 , list (indices .shape )),
1087+ helper .make_tensor_value_info ("updates" , TensorProto .FLOAT , list (updates .shape )),
1088+ ],
1089+ outputs = [helper .make_tensor_value_info ("output" , TensorProto .FLOAT , list (data .shape ))],
1090+ )
1091+ model = helper .make_model (graph , producer_name = "scatter_elements_reduction_test" )
1092+
1093+ check_correctness (
1094+ model ,
1095+ inputs = {"data" : data , "indices" : indices , "updates" : updates },
1096+ opset = opset ,
1097+ )
1098+
1099+
1100+ def test_scatter_elements_invalid_reduction ():
1101+ data_shape = [2 , 3 ]
1102+ scatter_elements_node = helper .make_node (
1103+ "ScatterElements" ,
1104+ ["data" , "indices" , "updates" ],
1105+ ["output" ],
1106+ axis = 1 ,
1107+ reduction = "unsupported" ,
1108+ )
1109+
1110+ graph = helper .make_graph (
1111+ [scatter_elements_node ],
1112+ "scatter_elements_invalid_reduction_test" ,
1113+ inputs = [
1114+ helper .make_tensor_value_info ("data" , TensorProto .FLOAT , data_shape ),
1115+ helper .make_tensor_value_info ("indices" , TensorProto .INT64 , data_shape ),
1116+ helper .make_tensor_value_info ("updates" , TensorProto .FLOAT , data_shape ),
1117+ ],
1118+ outputs = [helper .make_tensor_value_info ("output" , TensorProto .FLOAT , data_shape )],
1119+ )
1120+ model = helper .make_model (graph , producer_name = "scatter_elements_invalid_reduction_test" )
1121+
1122+ with pytest .raises (ValueError , match = "Only .* reductions are supported, but got unsupported" ):
1123+ from_onnx (model , opset = 18 , keep_params_in_input = True )
1124+
1125+
10261126@pytest .mark .parametrize ("reduction" , ["none" , "add" , "mul" ])
10271127def test_scatter_nd (reduction ):
10281128 def verify_scatter_nd (data_shape , indices_shape , updates_shape ):
0 commit comments