Skip to content

Commit 9bd92a5

Browse files
committed
Basic loop fusion added
1 parent 4b56c07 commit 9bd92a5

1 file changed

Lines changed: 83 additions & 3 deletions

File tree

src/ninetoothed/generation.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import astor
23
import collections
34
import copy
45
import functools
@@ -25,7 +26,6 @@
2526
CACHE_DIR = pathlib.Path.home() / ".ninetoothed"
2627
CACHE_DIR.mkdir(exist_ok=True)
2728

28-
2929
class CodeGenerator(ast.NodeTransformer):
3030
def __init__(self):
3131
super().__init__()
@@ -59,9 +59,11 @@ def _get_tree(func):
5959

6060
inliner = _Inliner(func.__globals__)
6161
inliner.visit(func_def)
62-
62+
6363
func_def = ast.parse(ast.unparse(func_def))
64-
64+
name_mapping = type(self)._generate_name_mapping_from_tensors(self._args)
65+
loop_fuser = _LoopFuser(self._context, name_mapping)
66+
loop_fuser.visit(func_def)
6567
module = ast.Module(body=[func_def], type_ignores=[])
6668

6769
if inliner.libdevice_used:
@@ -1264,3 +1266,81 @@ def visit_FunctionDef(self, node):
12641266
self.result = node
12651267

12661268
self.generic_visit(node)
1269+
1270+
class _LoopFuser(ast.NodeVisitor):
1271+
def __init__(self, context, name_mapping):
1272+
self._context = context
1273+
self._name_mapping = name_mapping
1274+
self.result = None
1275+
1276+
def _same_loop(self, f1, f2):
1277+
return (
1278+
ast.dump(f1.target) == ast.dump(f2.target) and
1279+
ast.dump(f1.iter) == ast.dump(f2.iter)
1280+
)
1281+
1282+
# === 新增:变量分析 ===
1283+
class _VarRWAnalyzer(ast.NodeVisitor):
1284+
def __init__(self):
1285+
self.read = set()
1286+
self.write = set()
1287+
1288+
def visit_Name(self, node):
1289+
if isinstance(node.ctx, ast.Load):
1290+
self.read.add(node.id)
1291+
elif isinstance(node.ctx, ast.Store):
1292+
self.write.add(node.id)
1293+
1294+
def _loop_carried_vars(self, loop):
1295+
analyzer = self._VarRWAnalyzer()
1296+
for stmt in loop.body:
1297+
analyzer.visit(stmt)
1298+
return analyzer.read & analyzer.write
1299+
1300+
def _loop_reads_vars(self, loop, vars):
1301+
analyzer = self._VarRWAnalyzer()
1302+
for stmt in loop.body:
1303+
analyzer.visit(stmt)
1304+
return bool(analyzer.read & vars)
1305+
1306+
# === 新增:是否允许融合 ===
1307+
def _can_fuse(self, loop1, loop2):
1308+
carried = self._loop_carried_vars(loop1)
1309+
if not carried:
1310+
return True
1311+
if self._loop_reads_vars(loop2, carried):
1312+
return False
1313+
return True
1314+
1315+
def visit_FunctionDef(self, node):
1316+
self.generic_visit(node)
1317+
1318+
fused_body = []
1319+
body = node.body
1320+
i = 0
1321+
1322+
while i < len(body):
1323+
if not isinstance(body[i], ast.For):
1324+
fused_body.append(body[i])
1325+
i += 1
1326+
continue
1327+
1328+
fused_for = body[i]
1329+
j = i + 1
1330+
1331+
while (
1332+
j < len(body)
1333+
and isinstance(body[j], ast.For)
1334+
and self._same_loop(fused_for, body[j])
1335+
and self._can_fuse(fused_for, body[j]) # 👈 关键
1336+
):
1337+
fused_for.body.extend(body[j].body)
1338+
j += 1
1339+
1340+
fused_body.append(fused_for)
1341+
i = j
1342+
1343+
node.body = fused_body
1344+
self.result = node
1345+
return node
1346+

0 commit comments

Comments
 (0)