@@ -180,9 +180,7 @@ def _get_cdist_implementation(
180180 opsets : Dict [str , int ],
181181 ** kwargs : Any ,
182182 ) -> FunctionProto :
183- """
184- Returns the CDist implementation as a function.
185- """
183+ """Returns the CDist implementation as a function."""
186184 assert len (node_inputs ) == 2
187185 assert len (node_outputs ) == 1
188186 assert opsets
@@ -191,39 +189,39 @@ def _get_cdist_implementation(
191189 metric = kwargs ["metric" ]
192190 assert metric in ("euclidean" , "sqeuclidean" )
193191 # subgraph
194- nodes = [
195- oh .make_node ("Sub" , ["next" , "next_in" ], ["diff" ]),
196- oh .make_node ("Constant" , [], ["axis" ], value_ints = [1 ]),
197- oh .make_node ("ReduceSumSquare" , ["diff" , "axis" ], ["scan_out" ], keepdims = 0 ),
198- oh .make_node ("Identity" , ["next_in" ], ["next_out" ]),
199- ]
200192
201193 def make_value (name ):
202194 value = ValueInfoProto ()
203195 value .name = name
204196 return value
205197
206198 graph = oh .make_graph (
207- nodes ,
199+ [
200+ oh .make_node ("Sub" , ["next" , "next_in" ], ["diff" ]),
201+ oh .make_node ("Constant" , [], ["axis" ], value_ints = [1 ]),
202+ oh .make_node ("ReduceSumSquare" , ["diff" , "axis" ], ["scan_out" ], keepdims = 0 ),
203+ oh .make_node ("Identity" , ["next_in" ], ["next_out" ]),
204+ ],
208205 "loop" ,
209206 [make_value ("next_in" ), make_value ("next" )],
210207 [make_value ("next_out" ), make_value ("scan_out" )],
211208 )
212209
213- scan = oh .make_node (
214- "Scan" , ["xb" , "xa" ], ["next_out" , "zout" ], num_scan_inputs = 1 , body = graph
215- )
216- final = (
217- oh .make_node ("Sqrt" , ["zout" ], ["z" ])
218- if metric == "euclidean"
219- else oh .make_node ("Identity" , ["zout" ], ["z" ])
220- )
221210 return oh .make_function (
222211 "npx" ,
223212 f"CDist_{ metric } " ,
224213 ["xa" , "xb" ],
225214 ["z" ],
226- [scan , final ],
215+ [
216+ oh .make_node (
217+ "Scan" , ["xb" , "xa" ], ["next_out" , "zout" ], num_scan_inputs = 1 , body = graph
218+ ),
219+ (
220+ oh .make_node ("Sqrt" , ["zout" ], ["z" ])
221+ if metric == "euclidean"
222+ else oh .make_node ("Identity" , ["zout" ], ["z" ])
223+ ),
224+ ],
227225 [oh .make_opsetid ("" , opsets ["" ])],
228226 )
229227
@@ -234,9 +232,7 @@ def test_iterate_function(self):
234232 )
235233 model = oh .make_model (
236234 oh .make_graph (
237- [
238- oh .make_node (proto .name , ["X" , "Y" ], ["Z" ]),
239- ],
235+ [oh .make_node (proto .name , ["X" , "Y" ], ["Z" ])],
240236 "dummy" ,
241237 [
242238 oh .make_tensor_value_info ("X" , itype , [None , None ]),
0 commit comments