Skip to content

Commit efbf963

Browse files
authored
Linearize AOT variant enumeration (#154)
1 parent dc51d41 commit efbf963

1 file changed

Lines changed: 16 additions & 6 deletions

File tree

src/ninetoothed/aot.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import ast
22
import ctypes
3-
import itertools
43
import pathlib
54
import re
65
import shutil
@@ -310,11 +309,22 @@ def _spec_from_combo(combo):
310309
(name, dim) for name, dim in zip(launch_arg_names, combo) if dim is not None
311310
)
312311

313-
dim_specs = tuple(
314-
_spec_from_combo(combo)
315-
for combo in itertools.product(*per_tensor_dims)
316-
if any(dim is not None for dim in combo)
317-
) + ((),)
312+
base_combo = tuple(dims[0] for dims in per_tensor_dims)
313+
combos = [base_combo] if any(dim is not None for dim in base_combo) else []
314+
315+
for i, dims in enumerate(per_tensor_dims):
316+
if len(dims) <= 1:
317+
continue
318+
319+
for alternative_dim in dims[1:]:
320+
if alternative_dim is None:
321+
continue
322+
323+
combo = list(base_combo)
324+
combo[i] = alternative_dim
325+
combos.append(tuple(combo))
326+
327+
dim_specs = tuple(_spec_from_combo(combo) for combo in combos) + ((),)
318328

319329
specs = []
320330

0 commit comments

Comments
 (0)