Skip to content

Commit 579ffe0

Browse files
committed
Basic loop fusion added
1 parent 4b56c07 commit 579ffe0

1 file changed

Lines changed: 82 additions & 3 deletions

File tree

src/ninetoothed/generation.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
CACHE_DIR = pathlib.Path.home() / ".ninetoothed"
2626
CACHE_DIR.mkdir(exist_ok=True)
2727

28-
2928
class CodeGenerator(ast.NodeTransformer):
3029
def __init__(self):
3130
super().__init__()
@@ -59,9 +58,11 @@ def _get_tree(func):
5958

6059
inliner = _Inliner(func.__globals__)
6160
inliner.visit(func_def)
62-
61+
6362
func_def = ast.parse(ast.unparse(func_def))
64-
63+
name_mapping = type(self)._generate_name_mapping_from_tensors(self._args)
64+
loop_fuser = _LoopFuser(self._context, name_mapping)
65+
loop_fuser.visit(func_def)
6566
module = ast.Module(body=[func_def], type_ignores=[])
6667

6768
if inliner.libdevice_used:
@@ -1264,3 +1265,81 @@ def visit_FunctionDef(self, node):
12641265
self.result = node
12651266

12661267
self.generic_visit(node)
1268+
1269+
class _LoopFuser(ast.NodeVisitor):
1270+
def __init__(self, context, name_mapping):
1271+
self._context = context
1272+
self._name_mapping = name_mapping
1273+
self.result = None
1274+
1275+
def _same_loop(self, f1, f2):
1276+
return (
1277+
ast.dump(f1.target) == ast.dump(f2.target) and
1278+
ast.dump(f1.iter) == ast.dump(f2.iter)
1279+
)
1280+
1281+
# === 新增:变量分析 ===
1282+
class _VarRWAnalyzer(ast.NodeVisitor):
1283+
def __init__(self):
1284+
self.read = set()
1285+
self.write = set()
1286+
1287+
def visit_Name(self, node):
1288+
if isinstance(node.ctx, ast.Load):
1289+
self.read.add(node.id)
1290+
elif isinstance(node.ctx, ast.Store):
1291+
self.write.add(node.id)
1292+
1293+
def _loop_carried_vars(self, loop):
1294+
analyzer = self._VarRWAnalyzer()
1295+
for stmt in loop.body:
1296+
analyzer.visit(stmt)
1297+
return analyzer.read & analyzer.write
1298+
1299+
def _loop_reads_vars(self, loop, vars):
1300+
analyzer = self._VarRWAnalyzer()
1301+
for stmt in loop.body:
1302+
analyzer.visit(stmt)
1303+
return bool(analyzer.read & vars)
1304+
1305+
# === 新增:是否允许融合 ===
1306+
def _can_fuse(self, loop1, loop2):
1307+
carried = self._loop_carried_vars(loop1)
1308+
if not carried:
1309+
return True
1310+
if self._loop_reads_vars(loop2, carried):
1311+
return False
1312+
return True
1313+
1314+
def visit_FunctionDef(self, node):
1315+
self.generic_visit(node)
1316+
1317+
fused_body = []
1318+
body = node.body
1319+
i = 0
1320+
1321+
while i < len(body):
1322+
if not isinstance(body[i], ast.For):
1323+
fused_body.append(body[i])
1324+
i += 1
1325+
continue
1326+
1327+
fused_for = body[i]
1328+
j = i + 1
1329+
1330+
while (
1331+
j < len(body)
1332+
and isinstance(body[j], ast.For)
1333+
and self._same_loop(fused_for, body[j])
1334+
and self._can_fuse(fused_for, body[j]) # 👈 关键
1335+
):
1336+
fused_for.body.extend(body[j].body)
1337+
j += 1
1338+
1339+
fused_body.append(fused_for)
1340+
i = j
1341+
1342+
node.body = fused_body
1343+
self.result = node
1344+
return node
1345+

0 commit comments

Comments
 (0)