@@ -53,6 +53,37 @@ describe("detectTorchImports", () => {
5353 expect ( result [ 1 ] . module ) . toBe ( "torch.nn" ) ;
5454 } ) ;
5555
56+ test ( "detects bare `import torch`" , async ( ) => {
57+ const result = await detectTorchImports ( pyodide , "import torch\nx = 1\n" ) ;
58+ expect ( result ) . toHaveLength ( 1 ) ;
59+ expect ( result [ 0 ] . type ) . toBe ( "import" ) ;
60+ expect ( result [ 0 ] . module ) . toBe ( "torch" ) ;
61+ expect ( result [ 0 ] . names ) . toEqual ( [ { name : "torch" , alias : null } ] ) ;
62+ } ) ;
63+
64+ test ( "detects `import torch as t`" , async ( ) => {
65+ const result = await detectTorchImports ( pyodide , "import torch as t\nx = 1\n" ) ;
66+ expect ( result ) . toHaveLength ( 1 ) ;
67+ expect ( result [ 0 ] . type ) . toBe ( "import" ) ;
68+ expect ( result [ 0 ] . module ) . toBe ( "torch" ) ;
69+ expect ( result [ 0 ] . names ) . toEqual ( [ { name : "torch" , alias : "t" } ] ) ;
70+ } ) ;
71+
72+ test ( "detects `import torch.nn`" , async ( ) => {
73+ const result = await detectTorchImports ( pyodide , "import torch.nn\nx = 1\n" ) ;
74+ expect ( result ) . toHaveLength ( 1 ) ;
75+ expect ( result [ 0 ] . type ) . toBe ( "import" ) ;
76+ expect ( result [ 0 ] . module ) . toBe ( "torch.nn" ) ;
77+ } ) ;
78+
79+ test ( "detects mix of bare import and from-import" , async ( ) => {
80+ const src = "import torch\nfrom torch.nn import Linear\nx = 1\n" ;
81+ const result = await detectTorchImports ( pyodide , src ) ;
82+ expect ( result ) . toHaveLength ( 2 ) ;
83+ const types = result . map ( r => r . type ) . sort ( ) ;
84+ expect ( types ) . toEqual ( [ "from" , "import" ] ) ;
85+ } ) ;
86+
5687 test ( "ignores non-torch imports" , async ( ) => {
5788 const result = await detectTorchImports ( pyodide , "from math import sqrt\nx = 1\n" ) ;
5889 expect ( result ) . toHaveLength ( 0 ) ;
@@ -97,6 +128,12 @@ describe("getNonTorchImportRoots", () => {
97128 const roots = await getNonTorchImportRoots ( pyodide , "from os.path import join\nx = 1\n" ) ;
98129 expect ( roots ) . toEqual ( new Set ( [ "os" ] ) ) ;
99130 } ) ;
131+
132+ test ( "returns non-torch roots for bare import statements" , async ( ) => {
133+ const src = "import numpy\nimport torch\nx = 1\n" ;
134+ const roots = await getNonTorchImportRoots ( pyodide , src ) ;
135+ expect ( roots ) . toEqual ( new Set ( [ "numpy" ] ) ) ;
136+ } ) ;
100137} ) ;
101138
102139// ---------------------------------------------------------------------------
@@ -157,6 +194,57 @@ describe("rewriteTorchImports", () => {
157194 expect ( code ) . toContain ( "relu = __sa_import_torch.nn.functional.relu" ) ;
158195 } ) ;
159196
197+ test ( "rewrites bare `import torch`" , async ( ) => {
198+ const { code, hasTorch } = await rewriteTorchImports (
199+ pyodide ,
200+ "import torch\nx = torch.tensor([1, 2, 3])\n" ,
201+ ) ;
202+ expect ( hasTorch ) . toBe ( true ) ;
203+ expect ( code ) . toContain ( "torch = __sa_import_torch" ) ;
204+ expect ( code ) . not . toContain ( "import torch" ) ;
205+ expect ( code ) . toContain ( "x = torch.tensor([1, 2, 3])" ) ;
206+ } ) ;
207+
208+ test ( "rewrites `import torch as t`" , async ( ) => {
209+ const { code, hasTorch } = await rewriteTorchImports (
210+ pyodide ,
211+ "import torch as t\nx = t.tensor([1])\n" ,
212+ ) ;
213+ expect ( hasTorch ) . toBe ( true ) ;
214+ expect ( code ) . toContain ( "t = __sa_import_torch" ) ;
215+ expect ( code ) . toContain ( "x = t.tensor([1])" ) ;
216+ } ) ;
217+
218+ test ( "rewrites `import torch.nn as nn`" , async ( ) => {
219+ const { code, hasTorch } = await rewriteTorchImports (
220+ pyodide ,
221+ "import torch.nn as nn\nx = nn.Linear(3, 2)\n" ,
222+ ) ;
223+ expect ( hasTorch ) . toBe ( true ) ;
224+ expect ( code ) . toContain ( "nn = __sa_import_torch.nn" ) ;
225+ expect ( code ) . toContain ( "x = nn.Linear(3, 2)" ) ;
226+ } ) ;
227+
228+ test ( "rewrites `import torch.nn` (no alias)" , async ( ) => {
229+ const { code, hasTorch } = await rewriteTorchImports (
230+ pyodide ,
231+ "import torch.nn\nx = torch.nn.Linear(3, 2)\n" ,
232+ ) ;
233+ expect ( hasTorch ) . toBe ( true ) ;
234+ expect ( code ) . toContain ( "torch = __sa_import_torch" ) ;
235+ expect ( code ) . toContain ( "x = torch.nn.Linear(3, 2)" ) ;
236+ } ) ;
237+
238+ test ( "rewrites mix of bare import and from-import" , async ( ) => {
239+ const src = "import torch\nfrom torch.nn import Linear\nx = torch.tensor(1)\ny = Linear(3, 2)\n" ;
240+ const { code, hasTorch } = await rewriteTorchImports ( pyodide , src ) ;
241+ expect ( hasTorch ) . toBe ( true ) ;
242+ expect ( code ) . toContain ( "torch = __sa_import_torch" ) ;
243+ expect ( code ) . toContain ( "Linear = __sa_import_torch.nn.Linear" ) ;
244+ expect ( code ) . not . toMatch ( / ^ i m p o r t t o r c h $ / m) ;
245+ expect ( code ) . not . toContain ( "from torch" ) ;
246+ } ) ;
247+
160248 test ( "handles full Python body that py-slang cannot parse" , async ( ) => {
161249 const src = "from torch import tensor\nx = tensor([1, 2, 3]).tolist()\nprint(x)\n" ;
162250 const { code, hasTorch } = await rewriteTorchImports ( pyodide , src ) ;
0 commit comments