|
1 | 1 | import ast |
| 2 | +import astor |
2 | 3 | import collections |
3 | 4 | import copy |
4 | 5 | import functools |
|
25 | 26 | CACHE_DIR = pathlib.Path.home() / ".ninetoothed" |
26 | 27 | CACHE_DIR.mkdir(exist_ok=True) |
27 | 28 |
|
28 | | - |
29 | 29 | class CodeGenerator(ast.NodeTransformer): |
30 | 30 | def __init__(self): |
31 | 31 | super().__init__() |
@@ -59,9 +59,11 @@ def _get_tree(func): |
59 | 59 |
|
60 | 60 | inliner = _Inliner(func.__globals__) |
61 | 61 | inliner.visit(func_def) |
62 | | - |
| 62 | + |
63 | 63 | func_def = ast.parse(ast.unparse(func_def)) |
64 | | - |
| 64 | + name_mapping = type(self)._generate_name_mapping_from_tensors(self._args) |
| 65 | + loop_fuser = _LoopFuser(self._context, name_mapping) |
| 66 | + loop_fuser.visit(func_def) |
65 | 67 | module = ast.Module(body=[func_def], type_ignores=[]) |
66 | 68 |
|
67 | 69 | if inliner.libdevice_used: |
@@ -1264,3 +1266,81 @@ def visit_FunctionDef(self, node): |
1264 | 1266 | self.result = node |
1265 | 1267 |
|
1266 | 1268 | self.generic_visit(node) |
| 1269 | + |
| 1270 | +class _LoopFuser(ast.NodeVisitor): |
| 1271 | + def __init__(self, context, name_mapping): |
| 1272 | + self._context = context |
| 1273 | + self._name_mapping = name_mapping |
| 1274 | + self.result = None |
| 1275 | + |
| 1276 | + def _same_loop(self, f1, f2): |
| 1277 | + return ( |
| 1278 | + ast.dump(f1.target) == ast.dump(f2.target) and |
| 1279 | + ast.dump(f1.iter) == ast.dump(f2.iter) |
| 1280 | + ) |
| 1281 | + |
| 1282 | + # === 新增:变量分析 === |
| 1283 | + class _VarRWAnalyzer(ast.NodeVisitor): |
| 1284 | + def __init__(self): |
| 1285 | + self.read = set() |
| 1286 | + self.write = set() |
| 1287 | + |
| 1288 | + def visit_Name(self, node): |
| 1289 | + if isinstance(node.ctx, ast.Load): |
| 1290 | + self.read.add(node.id) |
| 1291 | + elif isinstance(node.ctx, ast.Store): |
| 1292 | + self.write.add(node.id) |
| 1293 | + |
| 1294 | + def _loop_carried_vars(self, loop): |
| 1295 | + analyzer = self._VarRWAnalyzer() |
| 1296 | + for stmt in loop.body: |
| 1297 | + analyzer.visit(stmt) |
| 1298 | + return analyzer.read & analyzer.write |
| 1299 | + |
| 1300 | + def _loop_reads_vars(self, loop, vars): |
| 1301 | + analyzer = self._VarRWAnalyzer() |
| 1302 | + for stmt in loop.body: |
| 1303 | + analyzer.visit(stmt) |
| 1304 | + return bool(analyzer.read & vars) |
| 1305 | + |
| 1306 | + # === 新增:是否允许融合 === |
| 1307 | + def _can_fuse(self, loop1, loop2): |
| 1308 | + carried = self._loop_carried_vars(loop1) |
| 1309 | + if not carried: |
| 1310 | + return True |
| 1311 | + if self._loop_reads_vars(loop2, carried): |
| 1312 | + return False |
| 1313 | + return True |
| 1314 | + |
| 1315 | + def visit_FunctionDef(self, node): |
| 1316 | + self.generic_visit(node) |
| 1317 | + |
| 1318 | + fused_body = [] |
| 1319 | + body = node.body |
| 1320 | + i = 0 |
| 1321 | + |
| 1322 | + while i < len(body): |
| 1323 | + if not isinstance(body[i], ast.For): |
| 1324 | + fused_body.append(body[i]) |
| 1325 | + i += 1 |
| 1326 | + continue |
| 1327 | + |
| 1328 | + fused_for = body[i] |
| 1329 | + j = i + 1 |
| 1330 | + |
| 1331 | + while ( |
| 1332 | + j < len(body) |
| 1333 | + and isinstance(body[j], ast.For) |
| 1334 | + and self._same_loop(fused_for, body[j]) |
| 1335 | + and self._can_fuse(fused_for, body[j]) # 👈 关键 |
| 1336 | + ): |
| 1337 | + fused_for.body.extend(body[j].body) |
| 1338 | + j += 1 |
| 1339 | + |
| 1340 | + fused_body.append(fused_for) |
| 1341 | + i = j |
| 1342 | + |
| 1343 | + node.body = fused_body |
| 1344 | + self.result = node |
| 1345 | + return node |
| 1346 | + |
0 commit comments