File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11import ast
22import ctypes
3- import itertools
43import pathlib
54import re
65import shutil
@@ -310,11 +309,22 @@ def _spec_from_combo(combo):
310309 (name , dim ) for name , dim in zip (launch_arg_names , combo ) if dim is not None
311310 )
312311
313- dim_specs = tuple (
314- _spec_from_combo (combo )
315- for combo in itertools .product (* per_tensor_dims )
316- if any (dim is not None for dim in combo )
317- ) + ((),)
312+ inner_combo = tuple (dims [0 ] for dims in per_tensor_dims )
313+ combos = [inner_combo ] if any (d is not None for d in inner_combo ) else []
314+
315+ for i , dims in enumerate (per_tensor_dims ):
316+ if len (dims ) <= 1 :
317+ continue
318+
319+ for alt in dims [1 :]:
320+ if alt is None :
321+ continue
322+
323+ combo = list (inner_combo )
324+ combo [i ] = alt
325+ combos .append (tuple (combo ))
326+
327+ dim_specs = tuple (_spec_from_combo (combo ) for combo in combos ) + ((),)
318328
319329 specs = []
320330
You can’t perform that action at this time.
0 commit comments