Skip to content

Commit 57c477c

Browse files
committed
Detect Python imports using python ast module
1 parent c665bce commit 57c477c

4 files changed

Lines changed: 151 additions & 120 deletions

File tree

src/pyodide/PyodideEvaluator.ts

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,8 @@ export default class PyodideEvaluator extends BasicEvaluator {
2424
async evaluateChunk(chunk: string): Promise<void> {
2525
const pyodide = await this.pyodide;
2626

27-
// --- Use py-slang's parser to detect and rewrite torch imports ---
28-
const { code, hasTorch } = rewriteTorchImports(chunk);
27+
// --- Use Python's ast module (via Pyodide) to detect and rewrite torch imports ---
28+
const { code, hasTorch } = await rewriteTorchImports(pyodide, chunk);
2929

3030
if (hasTorch && !this.torchLoaded) {
3131
await loadTorch(pyodide);
@@ -34,7 +34,7 @@ export default class PyodideEvaluator extends BasicEvaluator {
3434
}
3535

3636
// --- Install any other imported modules via micropip ---
37-
const otherRoots = getNonTorchImportRoots(chunk);
37+
const otherRoots = await getNonTorchImportRoots(pyodide, chunk);
3838
if (otherRoots.size > 0) {
3939
const modulesArray = Array.from(otherRoots);
4040
const installerCode = `

src/pyodide/importAnalyzer.ts

Lines changed: 80 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,13 @@
11
/**
2-
* AST-based import analysis for detecting and rewriting torch imports.
2+
* Import analysis for detecting and rewriting torch imports using Python's
3+
* built-in `ast` module via Pyodide.
34
*
4-
* Uses py-slang's parser to produce an AST from the import prefix of the
5-
* source, then walks FromImport nodes to find torch-related imports —
6-
* replacing the regex-based approach used in sa-conductor-py-torch.
7-
*
8-
* Because pyodide code uses full Python syntax that py-slang cannot parse,
9-
* we extract only the leading `from … import …` lines, append a dummy
10-
* statement so the grammar is satisfied, and parse that fragment.
5+
* This avoids the limitations of py-slang's parser (which only supports a
6+
* subset of Python) by delegating to CPython's own parser running inside
7+
* Pyodide.
118
*/
129

13-
import { parse } from "../parser/parser-adapter";
14-
import { StmtNS } from "../ast-types";
10+
import type { PyodideInterface } from "pyodide";
1511

1612
export interface TorchImportInfo {
1713
/** Full module path, e.g. "torch" or "torch.nn" */
@@ -23,85 +19,88 @@ export interface TorchImportInfo {
2319
}
2420

2521
/**
26-
* Extracts leading `from … import …` lines from the source and returns
27-
* them along with the line index where non-import code begins.
22+
* Python helper that uses the `ast` module to extract import info.
23+
* Returns a JSON string describing every `ast.ImportFrom` (`from … import …`) statement that names a module.
2824
*/
29-
function extractImportPrefix(source: string): {
30-
importLines: string[];
31-
bodyStartIdx: number;
32-
} {
33-
const lines = source.split(/\r?\n/);
34-
let i = 0;
35-
for (; i < lines.length; i++) {
36-
const trimmed = lines[i].trim();
37-
if (trimmed === "" || trimmed.startsWith("#")) continue;
38-
if (trimmed.startsWith("from ") && trimmed.includes(" import ")) continue;
39-
break;
40-
}
41-
return { importLines: lines.slice(0, i), bodyStartIdx: i };
42-
}
25+
const ANALYZE_IMPORTS_PY = `
26+
import ast as _ast, json as _json
27+
28+
def _sa_analyze_imports(source):
29+
"""Parse source and return JSON array of from-import info."""
30+
try:
31+
tree = _ast.parse(source)
32+
except SyntaxError:
33+
return "[]"
34+
result = []
35+
for node in _ast.walk(tree):
36+
if isinstance(node, _ast.ImportFrom) and node.module:
37+
result.append({
38+
"module": node.module,
39+
"names": [
40+
{"name": a.name, "alias": a.asname}
41+
for a in node.names
42+
],
43+
"line": node.lineno,
44+
})
45+
return _json.dumps(result)
46+
`;
47+
48+
let helperLoaded = false;
4349

4450
/**
45-
* Parses only the import prefix of the source using py-slang's parser
46-
* and returns all FromImport nodes whose root module is "torch".
51+
* Ensure the Python-side `_sa_analyze_imports` function is defined.
52+
* Idempotent — only runs once.
4753
*/
48-
export function detectTorchImports(source: string): TorchImportInfo[] {
49-
const { importLines } = extractImportPrefix(source);
50-
if (importLines.length === 0) return [];
51-
52-
// Append a dummy statement so the grammar (import* statement*) is satisfied.
53-
const fragment = importLines.join("\n") + "\n_ = 0\n";
54-
55-
let ast: StmtNS.FileInput;
56-
try {
57-
ast = parse(fragment);
58-
} catch {
59-
return [];
60-
}
61-
62-
const torchImports: TorchImportInfo[] = [];
63-
64-
for (const stmt of ast.statements) {
65-
if (!(stmt instanceof StmtNS.FromImport)) continue;
66-
67-
const moduleName = stmt.module.lexeme;
68-
const root = moduleName.split(".")[0];
69-
if (root !== "torch") continue;
70-
71-
torchImports.push({
72-
module: moduleName,
73-
names: stmt.names.map(n => ({
74-
name: n.name.lexeme,
75-
alias: n.alias ? n.alias.lexeme : null,
76-
})),
77-
line: stmt.startToken.line,
78-
});
79-
}
54+
async function ensureHelper(pyodide: PyodideInterface): Promise<void> {
55+
if (helperLoaded) return;
56+
await pyodide.runPythonAsync(ANALYZE_IMPORTS_PY);
57+
helperLoaded = true;
58+
}
8059

81-
return torchImports;
60+
/**
61+
* Reset the helper loaded state. Useful for testing when pyodide
62+
* instances are recreated.
63+
*/
64+
export function resetHelperState(): void {
65+
helperLoaded = false;
8266
}
8367

8468
/**
85-
* Returns the set of top-level module roots for all non-torch imports.
86-
* These are modules that should be installed via micropip.
69+
* Parses the source code using Python's `ast` module (via Pyodide) and
70+
* returns all `from … import …` statements whose root module is "torch".
8771
*/
88-
export function getNonTorchImportRoots(source: string): Set<string> {
89-
const { importLines } = extractImportPrefix(source);
90-
if (importLines.length === 0) return new Set();
72+
export async function detectTorchImports(
73+
pyodide: PyodideInterface,
74+
source: string,
75+
): Promise<TorchImportInfo[]> {
76+
await ensureHelper(pyodide);
77+
78+
const json = pyodide.runPython(
79+
`_sa_analyze_imports(${JSON.stringify(source)})`,
80+
) as string;
81+
82+
const allImports: TorchImportInfo[] = JSON.parse(json);
83+
return allImports.filter(imp => imp.module.split(".")[0] === "torch");
84+
}
9185

92-
const fragment = importLines.join("\n") + "\n_ = 0\n";
86+
/**
87+
* Returns the set of top-level module roots for all non-torch
88+
* `from … import …` statements. These may need to be installed via micropip.
89+
*/
90+
export async function getNonTorchImportRoots(
91+
pyodide: PyodideInterface,
92+
source: string,
93+
): Promise<Set<string>> {
94+
await ensureHelper(pyodide);
9395

94-
let ast: StmtNS.FileInput;
95-
try {
96-
ast = parse(fragment);
97-
} catch {
98-
return new Set();
99-
}
96+
const json = pyodide.runPython(
97+
`_sa_analyze_imports(${JSON.stringify(source)})`,
98+
) as string;
10099

100+
const allImports: TorchImportInfo[] = JSON.parse(json);
101101
const roots = new Set<string>();
102-
for (const stmt of ast.statements) {
103-
if (!(stmt instanceof StmtNS.FromImport)) continue;
104-
const root = stmt.module.lexeme.split(".")[0];
102+
for (const imp of allImports) {
103+
const root = imp.module.split(".")[0];
105104
if (root !== "torch") {
106105
roots.add(root);
107106
}
@@ -136,11 +135,11 @@ function generateReplacement(imp: TorchImportInfo): string {
136135
*
137136
* Non-torch code is passed through unchanged.
138137
*/
139-
export function rewriteTorchImports(source: string): {
140-
code: string;
141-
hasTorch: boolean;
142-
} {
143-
const imports = detectTorchImports(source);
138+
export async function rewriteTorchImports(
139+
pyodide: PyodideInterface,
140+
source: string,
141+
): Promise<{ code: string; hasTorch: boolean }> {
142+
const imports = await detectTorchImports(pyodide, source);
144143

145144
if (imports.length === 0) {
146145
return { code: source, hasTorch: false };

0 commit comments

Comments
 (0)