11/**
2- * AST-based import analysis for detecting and rewriting torch imports.
2+ * Import analysis for detecting and rewriting torch imports using Python's
3+ * built-in `ast` module via Pyodide.
34 *
4- * Uses py-slang's parser to produce an AST from the import prefix of the
5- * source, then walks FromImport nodes to find torch-related imports —
6- * replacing the regex-based approach used in sa-conductor-py-torch.
7- *
8- * Because pyodide code uses full Python syntax that py-slang cannot parse,
9- * we extract only the leading `from … import …` lines, append a dummy
10- * statement so the grammar is satisfied, and parse that fragment.
5+ * This avoids the limitations of py-slang's parser (which only supports a
6+ * subset of Python) by delegating to CPython's own parser running inside
7+ * Pyodide.
118 */
129
13- import { parse } from "../parser/parser-adapter" ;
14- import { StmtNS } from "../ast-types" ;
10+ import type { PyodideInterface } from "pyodide" ;
1511
1612export interface TorchImportInfo {
1713 /** Full module path, e.g. "torch" or "torch.nn" */
@@ -23,85 +19,88 @@ export interface TorchImportInfo {
2319}
2420
2521/**
26- * Extracts leading `from … import …` lines from the source and returns
27- * them along with the line index where non-import code begins .
22+ * Python helper that uses the `ast` module to extract import info.
23+ * Returns a JSON string describing all FromImport statements .
2824 */
29- function extractImportPrefix ( source : string ) : {
30- importLines : string [ ] ;
31- bodyStartIdx : number ;
32- } {
33- const lines = source . split ( / \r ? \n / ) ;
34- let i = 0 ;
35- for ( ; i < lines . length ; i ++ ) {
36- const trimmed = lines [ i ] . trim ( ) ;
37- if ( trimmed === "" || trimmed . startsWith ( "#" ) ) continue ;
38- if ( trimmed . startsWith ( "from " ) && trimmed . includes ( " import " ) ) continue ;
39- break ;
40- }
41- return { importLines : lines . slice ( 0 , i ) , bodyStartIdx : i } ;
42- }
25+ const ANALYZE_IMPORTS_PY = `
26+ import ast as _ast, json as _json
27+
28+ def _sa_analyze_imports(source):
29+ """Parse source and return JSON array of from-import info."""
30+ try:
31+ tree = _ast.parse(source)
32+ except SyntaxError:
33+ return "[]"
34+ result = []
35+ for node in _ast.walk(tree):
36+ if isinstance(node, _ast.ImportFrom) and node.module:
37+ result.append({
38+ "module": node.module,
39+ "names": [
40+ {"name": a.name, "alias": a.asname}
41+ for a in node.names
42+ ],
43+ "line": node.lineno,
44+ })
45+ return _json.dumps(result)
46+ ` ;
47+
48+ let helperLoaded = false ;
4349
4450/**
45- * Parses only the import prefix of the source using py-slang's parser
46- * and returns all FromImport nodes whose root module is "torch" .
51+ * Ensure the Python-side `_sa_analyze_imports` function is defined.
52+ * Idempotent — only runs once .
4753 */
48- export function detectTorchImports ( source : string ) : TorchImportInfo [ ] {
49- const { importLines } = extractImportPrefix ( source ) ;
50- if ( importLines . length === 0 ) return [ ] ;
51-
52- // Append a dummy statement so the grammar (import* statement*) is satisfied.
53- const fragment = importLines . join ( "\n" ) + "\n_ = 0\n" ;
54-
55- let ast : StmtNS . FileInput ;
56- try {
57- ast = parse ( fragment ) ;
58- } catch {
59- return [ ] ;
60- }
61-
62- const torchImports : TorchImportInfo [ ] = [ ] ;
63-
64- for ( const stmt of ast . statements ) {
65- if ( ! ( stmt instanceof StmtNS . FromImport ) ) continue ;
66-
67- const moduleName = stmt . module . lexeme ;
68- const root = moduleName . split ( "." ) [ 0 ] ;
69- if ( root !== "torch" ) continue ;
70-
71- torchImports . push ( {
72- module : moduleName ,
73- names : stmt . names . map ( n => ( {
74- name : n . name . lexeme ,
75- alias : n . alias ? n . alias . lexeme : null ,
76- } ) ) ,
77- line : stmt . startToken . line ,
78- } ) ;
79- }
54+ async function ensureHelper ( pyodide : PyodideInterface ) : Promise < void > {
55+ if ( helperLoaded ) return ;
56+ await pyodide . runPythonAsync ( ANALYZE_IMPORTS_PY ) ;
57+ helperLoaded = true ;
58+ }
8059
81- return torchImports ;
60+ /**
61+ * Reset the helper loaded state. Useful for testing when pyodide
62+ * instances are recreated.
63+ */
64+ export function resetHelperState ( ) : void {
65+ helperLoaded = false ;
8266}
8367
8468/**
85- * Returns the set of top-level module roots for all non-torch imports.
86- * These are modules that should be installed via micropip .
69+ * Parses the source code using Python's `ast` module (via Pyodide) and
70+ * returns all `from … import …` statements whose root module is "torch" .
8771 */
88- export function getNonTorchImportRoots ( source : string ) : Set < string > {
89- const { importLines } = extractImportPrefix ( source ) ;
90- if ( importLines . length === 0 ) return new Set ( ) ;
72+ export async function detectTorchImports (
73+ pyodide : PyodideInterface ,
74+ source : string ,
75+ ) : Promise < TorchImportInfo [ ] > {
76+ await ensureHelper ( pyodide ) ;
77+
78+ const json = pyodide . runPython (
79+ `_sa_analyze_imports(${ JSON . stringify ( source ) } )` ,
80+ ) as string ;
81+
82+ const allImports : TorchImportInfo [ ] = JSON . parse ( json ) ;
83+ return allImports . filter ( imp => imp . module . split ( "." ) [ 0 ] === "torch" ) ;
84+ }
9185
92- const fragment = importLines . join ( "\n" ) + "\n_ = 0\n" ;
86+ /**
87+ * Returns the set of top-level module roots for all non-torch
88+ * `from … import …` statements. These may need to be installed via micropip.
89+ */
90+ export async function getNonTorchImportRoots (
91+ pyodide : PyodideInterface ,
92+ source : string ,
93+ ) : Promise < Set < string > > {
94+ await ensureHelper ( pyodide ) ;
9395
94- let ast : StmtNS . FileInput ;
95- try {
96- ast = parse ( fragment ) ;
97- } catch {
98- return new Set ( ) ;
99- }
96+ const json = pyodide . runPython (
97+ `_sa_analyze_imports(${ JSON . stringify ( source ) } )` ,
98+ ) as string ;
10099
100+ const allImports : TorchImportInfo [ ] = JSON . parse ( json ) ;
101101 const roots = new Set < string > ( ) ;
102- for ( const stmt of ast . statements ) {
103- if ( ! ( stmt instanceof StmtNS . FromImport ) ) continue ;
104- const root = stmt . module . lexeme . split ( "." ) [ 0 ] ;
102+ for ( const imp of allImports ) {
103+ const root = imp . module . split ( "." ) [ 0 ] ;
105104 if ( root !== "torch" ) {
106105 roots . add ( root ) ;
107106 }
@@ -136,11 +135,11 @@ function generateReplacement(imp: TorchImportInfo): string {
136135 *
137136 * Non-torch code is passed through unchanged.
138137 */
139- export function rewriteTorchImports ( source : string ) : {
140- code : string ;
141- hasTorch : boolean ;
142- } {
143- const imports = detectTorchImports ( source ) ;
138+ export async function rewriteTorchImports (
139+ pyodide : PyodideInterface ,
140+ source : string ,
141+ ) : Promise < { code : string ; hasTorch : boolean } > {
142+ const imports = await detectTorchImports ( pyodide , source ) ;
144143
145144 if ( imports . length === 0 ) {
146145 return { code : source , hasTorch : false } ;
0 commit comments