diff --git a/jest.config.js b/jest.config.js index 92da7561..0c0f7816 100644 --- a/jest.config.js +++ b/jest.config.js @@ -46,6 +46,7 @@ module.exports = { }, }, ], + "\\.py$": "<rootDir>/src/tests/raw-text-transformer.js", }, transformIgnorePatterns: [ "/node_modules/(?!(@sourceacademy/wasm-util|@sourceacademy/conductor)/).+\\.js$", diff --git a/package.json b/package.json index 33c7789f..6896f809 100644 --- a/package.json +++ b/package.json @@ -12,8 +12,8 @@ "jsdoc": "./scripts/jsdoc.sh prepare", "jsdoc:run": "./scripts/jsdoc.sh run", "jsdoc:clean": "./scripts/jsdoc.sh clean", - "test": "jest", - "test-coverage": "jest --coverage", + "test": "node --experimental-vm-modules node_modules/.bin/jest", + "test-coverage": "node --experimental-vm-modules node_modules/.bin/jest --coverage", "lint": "eslint --concurrency=auto src", "format": "prettier --write \"**/*.{ts,tsx,json,js,mjs}\"", "format:ci": "prettier --list-different \"**/*.{ts,tsx,json,js,mjs}\"", @@ -60,11 +60,13 @@ }, "dependencies": { "@sourceacademy/conductor": "^0.3.0", + "@sourceacademy/torch": "^0.1.0", "@sourceacademy/wasm-util": "^1.0.6", "fast-levenshtein": "^3.0.0", "mathjs": "^14.9.1", "moo": "^0.5.2", "nearley": "^2.20.1", + "pyodide": "^0.29.3", "wabt": "^1.0.37" } } diff --git a/rollup.config.mjs b/rollup.config.mjs index 5f1f44b0..e4288605 100644 --- a/rollup.config.mjs +++ b/rollup.config.mjs @@ -6,6 +6,7 @@ import commonjs from "@rollup/plugin-commonjs"; import nodePolyfills from "rollup-plugin-polyfill-node"; import replace from "@rollup/plugin-replace"; import wasm from "@rollup/plugin-wasm"; +import { readFileSync } from "fs"; // Env EVALUATOR is set by scripts/build.ts. const EVALUATOR = process.env.EVALUATOR; @@ -13,8 +14,22 @@ if (!EVALUATOR) { throw new Error("EVALUATOR env var must be set. Use scripts/build.ts."); } +/** Plugin: import .py files as strings.
*/
+function rawPy() {
+  return {
+    name: "raw-py",
+    // `load` receives a module id; this plugin reads it from disk as UTF-8 text.
+    load(id) {
+      // Only ids ending in ".py" are handled. For any other id the function
+      // falls off the end and returns undefined.
+      if (id.endsWith(".py")) {
+        const text = readFileSync(id, "utf-8");
+        // JSON.stringify turns the file text into a quoted, escaped string
+        // literal, so the generated module default-exports the raw source.
+        return `export default ${JSON.stringify(text)};`;
+      }
+    },
+  };
+}
+
function plugins() { return [ + rawPy(), replace({ preventAssignment: true, values: { __EVALUATOR__: EVALUATOR }, @@ -45,6 +60,7 @@ const config = [ format: "iife", name: "PySlangWorker", sourcemap: true, + inlineDynamicImports: true, }, plugins: plugins(), }, @@ -58,6 +74,7 @@ const config = [ format: "cjs", exports: "default", sourcemap: true, + inlineDynamicImports: true, }, plugins: plugins(), }, diff --git a/scripts/build.ts b/scripts/build.ts index c73017b8..bd2cdd27 100644 --- a/scripts/build.ts +++ b/scripts/build.ts @@ -12,6 +12,11 @@ const allTargets = [ "PyWasmEvaluator", "PySvmlEvaluator", "PySvmlSinterEvaluator", + "PyodideEvaluator1", + "PyodideEvaluator2", + "PyodideEvaluator3", + "PyodideEvaluator4", + "PyodideEvaluatorFull", ] as const; type EvaluatorName = (typeof allTargets)[number]; diff --git a/src/conductor/index.ts b/src/conductor/index.ts index 0e39bcac..fd0e0316 100644 --- a/src/conductor/index.ts +++ b/src/conductor/index.ts @@ -1,3 +1,10 @@ +export { + PyodideEvaluator1, + PyodideEvaluator2, + PyodideEvaluator3, + PyodideEvaluator4, + PyodideEvaluatorFull, +} from "../pyodide/PyodideEvaluator"; export { PyCseEvaluator1, PyCseEvaluator2, @@ -5,5 +12,5 @@ export { PyCseEvaluator4, } from "./PyCseEvaluator"; export { PySvmlEvaluator } from "./PySvmlEvaluator"; -export { PyWasmEvaluator } from "./PyWasmEvaluator"; export { PySvmlSinterEvaluator } from "./PySvmlSinterEvaluator"; +export { PyWasmEvaluator } from "./PyWasmEvaluator"; diff --git a/src/pyodide/PyodideEvaluator.ts b/src/pyodide/PyodideEvaluator.ts new file mode 100644 index 00000000..17b784be --- /dev/null +++ b/src/pyodide/PyodideEvaluator.ts @@ -0,0 +1,120 @@ +import { ConductorError } from "@sourceacademy/conductor/common"; +import { BasicEvaluator, IRunnerPlugin } from
"@sourceacademy/conductor/runner";
+import type { PyodideInterface } from "pyodide";
+import { parse } from "../parser/parser-adapter";
+import { analyze } from "../resolver/analysis";
+import { getNonTorchImportRoots, rewriteTorchImports } from "./importAnalyzer";
+import { loadPyodideGeneric } from "./loadPyodide";
+import { loadTorch } from "./loadTorch";
+
+/**
+ * Builds the Python snippet that installs, via micropip, every module in
+ * `modules` that cannot currently be imported.
+ *
+ * The work happens inside a throwaway async function so that its locals
+ * (`missing`, `m`, the `importlib`/`micropip` imports) never land in the
+ * interpreter's global namespace, which is shared with user code across
+ * chunks. The helper function itself is deleted again even when the
+ * installation raises.
+ */
+function buildInstallerCode(modules: string[]): string {
+  return `
+async def __sa_install_missing(mods):
+    import importlib, micropip
+    missing = []
+    for m in mods:
+        try:
+            importlib.import_module(m)
+        except Exception:
+            missing.append(m)
+    if missing:
+        await micropip.install(missing)
+
+try:
+    await __sa_install_missing(${JSON.stringify(modules)})
+finally:
+    del __sa_install_missing
+`;
+}
+
+/**
+ * Evaluator that runs Python chunks inside Pyodide.
+ *
+ * The Pyodide instance is created once in the constructor (with micropip
+ * loaded and stdout forwarded to the conductor) and awaited by every chunk.
+ * Subclasses decide how a chunk is validated before it is executed.
+ */
+export default abstract class PyodideEvaluator extends BasicEvaluator {
+  protected pyodide: Promise<PyodideInterface>;
+  private torchLoaded = false;
+
+  constructor(conductor: IRunnerPlugin) {
+    super(conductor);
+    this.pyodide = loadPyodideGeneric().then(async pyodide => {
+      await pyodide.loadPackage("micropip");
+      await pyodide.setStdout({
+        batched: (output: string) => {
+          this.conductor.sendOutput(output);
+        },
+      });
+      return pyodide;
+    });
+  }
+
+  /** Throws if `_chunk` must not be executed by this evaluator. */
+  protected abstract validateChunk(_chunk: string): void;
+
+  /**
+   * Validates, prepares and executes one chunk of Python source.
+   *
+   * Steps: validate, rewrite torch imports, load torch on first use, install
+   * other imported modules through micropip, then run the rewritten code.
+   * The result is sent with `sendResult`; a failure in ANY step is reported
+   * with `sendError` instead of rejecting the returned promise, so setup
+   * failures are surfaced the same way as errors raised by the user's code.
+   */
+  async evaluateChunk(chunk: string): Promise<void> {
+    try {
+      this.validateChunk(chunk);
+
+      const pyodide = await this.pyodide;
+
+      // --- Use Python's ast module (via Pyodide) to detect and rewrite torch imports ---
+      const { code, hasTorch } = await rewriteTorchImports(pyodide, chunk);
+
+      if (hasTorch && !this.torchLoaded) {
+        await loadTorch(pyodide);
+        // NOTE(review): relies on loadTorch leaving a `torch` global behind — confirm.
+        pyodide.globals.set("__sa_import_torch", pyodide.globals.get("torch"));
+        this.torchLoaded = true;
+      }
+
+      // --- Install any other imported modules via micropip ---
+      const otherRoots = await getNonTorchImportRoots(pyodide, chunk);
+      if (otherRoots.size > 0) {
+        await pyodide.runPythonAsync(buildInstallerCode(Array.from(otherRoots)));
+      }
+
+      // --- Execute the (possibly rewritten) code ---
+      const output = await pyodide.runPythonAsync(code);
+      this.conductor.sendResult(output);
+    } catch (err: unknown) {
+      const message = err instanceof Error ? err.message : String(err);
+      this.conductor.sendError(new ConductorError(message));
+    }
+  }
+}
+
+/** Evaluator that validates each chunk against a numbered chapter before running it. */
+export class ChapterPyodideEvaluator extends PyodideEvaluator {
+  private chapter: number;
+
+  constructor(conductor: IRunnerPlugin, chapter: number) {
+    super(conductor);
+    this.chapter = chapter;
+  }
+
+  /** Parses the chunk (with a trailing newline appended) and analyzes it for `this.chapter`. */
+  protected validateChunk(chunk: string): void {
+    const script = chunk + "\n";
+    const ast = parse(script);
+    analyze(ast, script, this.chapter);
+  }
+}
+
+export class PyodideEvaluator1 extends ChapterPyodideEvaluator {
+  constructor(conductor: IRunnerPlugin) {
+    super(conductor, 1);
+  }
+}
+
+export class PyodideEvaluator2 extends ChapterPyodideEvaluator {
+  constructor(conductor: IRunnerPlugin) {
+    super(conductor, 2);
+  }
+}
+
+export class PyodideEvaluator3 extends ChapterPyodideEvaluator {
+  constructor(conductor: IRunnerPlugin) {
+    super(conductor, 3);
+  }
+}
+
+export class PyodideEvaluator4 extends ChapterPyodideEvaluator {
+  constructor(conductor: IRunnerPlugin) {
+    super(conductor, 4);
+  }
+}
+
+/** Evaluator that executes chunks without any chapter validation. */
+export class PyodideEvaluatorFull extends PyodideEvaluator {
+  constructor(conductor: IRunnerPlugin) {
+    super(conductor);
+  }
+
+  protected validateChunk(_chunk: string): void {
+    // No-op validation
+  }
+}
diff --git a/src/pyodide/bridge.py b/src/pyodide/bridge.py
new file mode 100644
index 00000000..eea4ae73
--- /dev/null
+++ b/src/pyodide/bridge.py
@@ -0,0 +1,697 @@
+# bridge.py
+# Provides a PyTorch-compatible Python API over js_torch (the TypeScript torch library).
+#
+# Before loading this file, set the following globals in Pyodide:
+#   js_torch - the torch module (window.torch from the UMD build)
+
+from pyodide.ffi import JsProxy, to_js
+
+
+# ---------------------------------------------------------------------------
+# Internal helpers
+# ---------------------------------------------------------------------------
+
+def _wrap_result(result):
+    """
+    Wrap a JS return value:
+    - JsProxy (JS object/Tensor) -> Python Tensor
+    - Python primitive (int, float, bool) -> return as-is
+    JS primitives are automatically converted to Python by Pyodide,
+    so they will NOT be JsProxy instances.
+    """
+    # NOTE(review): every JsProxy is wrapped, so a JS call that returns a
+    # non-tensor object (e.g. an array) also ends up inside a Tensor — confirm
+    # that callers only route tensor-returning functions through here.
+    if isinstance(result, JsProxy):
+        return Tensor(result)
+    return result
+
+
+def _transform(obj):
+    """Convert Python objects to JS-compatible types before passing to JS."""
+    if isinstance(obj, Tensor):
+        return obj._js
+    if isinstance(obj, (list, tuple)):
+        # Recurse first so Tensors nested inside the sequence are unwrapped too.
+        return to_js([_transform(item) for item in obj])
+    return obj
+
+
+def _transform_args(args):
+    """Apply `_transform` to each positional argument and return them as a list."""
+    return [_transform(a) for a in args]
+
+
+# ---------------------------------------------------------------------------
+# Tensor
+# ---------------------------------------------------------------------------
+
+class Tensor:
+    """Python wrapper around a JS Tensor, mirroring the PyTorch Tensor API."""
+
+    # ------------------------------------------------------------------
+    # Construction
+    # ------------------------------------------------------------------
+
+    def __new__(cls, data, requires_grad=False):
+        # Return None for missing tensors so e.g. `tensor.grad` returns None
+        # when there is no gradient — matching PyTorch behaviour.
+        # Pyodide may represent JS null as a special JsNull type (not JsProxy, not None).
+        if data is None or type(data).__name__ in ('JsNull', 'JsUndefined'):
+            return None
+        return super().__new__(cls)
+
+    def __init__(self, data, requires_grad=False):
+        # An existing JS tensor is wrapped as-is; `requires_grad` is not applied
+        # on that path. Anything else is handed to js_torch.tensor, with lists
+        # and tuples converted to JS arrays first.
+        if isinstance(data, JsProxy):
+            self._js = data
+        else:
+            js_data = to_js(data) if isinstance(data, (list, tuple)) else data
+            self._js = js_torch.tensor(js_data, requires_grad)
+
+    # ------------------------------------------------------------------
+    # Representation
+    # ------------------------------------------------------------------
+
+    def __repr__(self):
+        extra = ", requires_grad=True" if self.requires_grad else ""
+        return f"tensor({self.tolist()}{extra})"
+
+    # ------------------------------------------------------------------
+    # Data access
+    # ------------------------------------------------------------------
+
+    def tolist(self):
+        """Return tensor data as a (nested) Python list, or a Python scalar for 0-d tensors."""
+        result = self._js.toArray()
+        if isinstance(result, JsProxy):
+            return result.to_py()
+        return result  # scalar
+
+    def item(self):
+        """Return whatever the JS tensor's `item()` returns."""
+        return self._js.item()
+
+    # ------------------------------------------------------------------
+    # Properties
+    # ------------------------------------------------------------------
+
+    @property
+    def shape(self):
+        """Shape of the tensor as a Python tuple."""
+        return tuple(self._js.shape.to_py())
+
+    @property
+    def data(self):
+        """Detached view of the tensor data (no gradient)."""
+        return self.detach()
+
+    @property
+    def requires_grad(self):
+        return bool(self._js.requires_grad)
+
+    @requires_grad.setter
+    def requires_grad(self, value):
+        self._js.requires_grad = value
+
+    @property
+    def grad(self):
+        """Gradient as a Tensor, or None when the JS side has none."""
+        raw = self._js.grad
+        if raw is None or type(raw).__name__ in ('JsNull', 'JsUndefined'):
+            return None
+        return Tensor(raw)
+
+    @grad.setter
+    def grad(self, value):
+        # Any value that is not a Tensor (including None) clears the gradient.
+        self._js.grad = value._js if isinstance(value, Tensor) else None
+
+    @property
+    def T(self):
+        # Tensors with fewer than two dimensions are returned unchanged;
+        # otherwise only dimensions 0 and 1 are swapped.
+        if len(self.shape) < 2:
+            return self
+        return Tensor(self._js.transpose(0, 1))
+
+    # ------------------------------------------------------------------
+    # Grad utilities
+    # ------------------------------------------------------------------
+
+    def backward(self, gradient=None):
+        """Run the JS tensor's backward pass, optionally with a gradient Tensor."""
+        if gradient is None:
+            self._js.backward()
+        else:
+            self._js.backward(gradient._js)
+
+    def detach(self):
+        return Tensor(self._js.detach())
+
+    def zero_(self):
+        self._js.zero_()
+        return self
+
+    def retain_grad(self):
+        self._js.retain_grad()
+
+    # ------------------------------------------------------------------
+    # Shape utilities
+    # ------------------------------------------------------------------
+
+    def size(self, dim=None):
+        """Return the whole shape tuple, or the extent of dimension `dim`."""
+        s = self.shape
+        return s if dim is None else s[dim]
+
+    def dim(self):
+        return len(self.shape)
+
+    def numel(self):
+        """Number of elements: the product of all shape entries (1 for an empty shape)."""
+        n = 1
+        for s in self.shape:
+            n *= s
+        return n
+
+    def reshape(self, *args):
+        # Accepts both reshape(2, 3) and reshape((2, 3)) / reshape([2, 3]).
+        shape = list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args)
+        return Tensor(self._js.reshape(to_js(shape)))
+
+    def view(self, *args):
+        return self.reshape(*args)
+
+    def squeeze(self, dim=None):
+        if dim is None:
+            # Drop every size-1 dimension via reshape; if nothing would remain
+            # the tensor is reshaped to [1] rather than to an empty shape.
+            new_shape = [s for s in self.shape if s != 1]
+            return Tensor(self._js.reshape(to_js(new_shape or [1])))
+        return Tensor(self._js.squeeze(dim))
+
+    def unsqueeze(self, dim):
+        return Tensor(self._js.unsqueeze(dim))
+
+    def expand(self, *args):
+        # Same argument convention as reshape().
+        shape = list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args)
+        return Tensor(self._js.expand(to_js(shape)))
+
+    def transpose(self, dim0, dim1):
+        return Tensor(self._js.transpose(dim0, dim1))
+
+    def flatten(self, start_dim=0, end_dim=-1):
+        return Tensor(self._js.flatten(start_dim, end_dim))
+
+    # ------------------------------------------------------------------
+    # Reductions — default (no dim) sums all elements, matching PyTorch
+    # ------------------------------------------------------------------
+
+    # When `dim` is None the JS reduction is called with no arguments, so
+    # `keepdim` only takes effect together with an explicit `dim`.
+    def sum(self, dim=None, keepdim=False):
+        return Tensor(self._js.sum() if dim is None else self._js.sum(dim, keepdim))
+
+    def mean(self, dim=None, keepdim=False):
+        return Tensor(self._js.mean() if dim is None else self._js.mean(dim, keepdim))
+
+    def max(self, dim=None, keepdim=False):
+        return Tensor(self._js.max() if dim is None else self._js.max(dim, keepdim))
+
+    def min(self, dim=None, keepdim=False):
+        return Tensor(self._js.min() if dim is None else self._js.min(dim, keepdim))
+
+    # ------------------------------------------------------------------
+    # Arithmetic — explicit methods
+    # ------------------------------------------------------------------
+
+    def _to_js(self, other):
+        """Unwrap a Tensor operand to its JS tensor; pass anything else through."""
+        return other._js if isinstance(other, Tensor) else other
+
+    def add(self, other): return Tensor(self._js.add(self._to_js(other)))
+    def sub(self, other): return Tensor(self._js.sub(self._to_js(other)))
+    def mul(self, other): return Tensor(self._js.mul(self._to_js(other)))
+    def div(self, other): return Tensor(self._js.div(self._to_js(other)))
+    def pow(self, other): return Tensor(self._js.pow(self._to_js(other)))
+    def matmul(self, other): return Tensor(self._js.matmul(self._to_js(other)))
+
+    # ------------------------------------------------------------------
+    # Arithmetic operators
+    # ------------------------------------------------------------------
+
+    # The reflected forms of the non-commutative operators first promote the
+    # left operand to a Tensor and then apply the forward method on it.
+    def __add__(self, other): return self.add(other)
+    def __radd__(self, other): return self.add(other)  # add is commutative
+    def __sub__(self, other): return self.sub(other)
+    def __rsub__(self, other):
+        o = other if isinstance(other, Tensor) else Tensor(other)
+        return o.sub(self)
+    def __mul__(self, other): return self.mul(other)
+    def __rmul__(self, other): return self.mul(other)  # mul is commutative
+    def __truediv__(self, other): return self.div(other)
+    def __rtruediv__(self, other):
+        o = other if isinstance(other, Tensor) else Tensor(other)
+        return o.div(self)
+    def __pow__(self, other): return self.pow(other)
+    def __rpow__(self, other):
+        o = other if isinstance(other, Tensor) else Tensor(other)
+        return o.pow(self)
+    def __matmul__(self, other): return self.matmul(other)
+    def __neg__(self): return Tensor(self._js.neg())
+    def __abs__(self): return Tensor(self._js.abs())
+
+    # ------------------------------------------------------------------
+    # Unary operations
+    # ------------------------------------------------------------------
+
+    def neg(self): return Tensor(self._js.neg())
+    def abs(self): return Tensor(self._js.abs())
+    def log(self): return Tensor(self._js.log())
+    def exp(self): return Tensor(self._js.exp())
+    def sqrt(self): return Tensor(self._js.sqrt())
+    def square(self): return Tensor(self._js.square())
+    def sin(self): return Tensor(self._js.sin())
+    def cos(self): return Tensor(self._js.cos())
+    def tan(self): return Tensor(self._js.tan())
+    def sigmoid(self): return Tensor(self._js.sigmoid())
+    # relu goes through js_torch.nn.functional rather than a tensor method.
+    def relu(self): return Tensor(js_torch.nn.functional.relu(self._js))
+    def sign(self): return Tensor(self._js.sign())
+    def reciprocal(self): return Tensor(self._js.reciprocal())
+    def nan_to_num(self): return Tensor(self._js.nan_to_num())
+
+    # ------------------------------------------------------------------
+    # Comparison
+    # ------------------------------------------------------------------
+
+    def lt(self, other): return Tensor(self._js.lt(self._to_js(other)))
+    def gt(self, other): return Tensor(self._js.gt(self._to_js(other)))
+    def le(self, other): return Tensor(self._js.le(self._to_js(other)))
+    def ge(self, other): return Tensor(self._js.ge(self._to_js(other)))
+    def eq(self, other): return Tensor(self._js.eq(self._to_js(other)))
+    def ne(self, other): return Tensor(self._js.ne(self._to_js(other)))
+
+    def allclose(self, other, rtol=1e-5, atol=1e-8, equal_nan=False):
+        # `other` must be a Tensor: its `_js` attribute is read unconditionally.
+        return bool(js_torch.allclose(self._js, other._js, rtol, atol, equal_nan))
+
+    # ------------------------------------------------------------------
+    # Type conversions
+    # ------------------------------------------------------------------
+
+    def __float__(self): return float(self.item())
+    def __int__(self): return int(self.item())
+    def __bool__(self): return bool(self.item())
+    def __format__(self, fmt): return format(self.item(), fmt)
+
+    # ------------------------------------------------------------------
+    # Indexing
+    # ------------------------------------------------------------------
+
+    def __getitem__(self, key):
+        """
+        Index with an int, a tuple of ints, or a slice over the first dimension.
+        Raises NotImplementedError for non-int entries in a tuple and TypeError
+        for any other key type.
+        """
+        if isinstance(key, int):
+            return Tensor(self._js.index(key))
+        if isinstance(key, tuple):
+            # Apply the integer indices one dimension at a time.
+            result = self._js
+            for k in key:
+                if isinstance(k, int):
+                    result = result.index(k)
+                else:
+                    raise NotImplementedError(
+                        "Only integer indexing is supported in multi-dimensional indexing"
+                    )
+            return Tensor(result)
+        if isinstance(key, slice):
+            # The slice is materialised through Python lists and a new tensor is
+            # built from them, so the result is constructed from plain values.
+            start, stop, step = key.indices(self.shape[0])
+            data = [Tensor(self._js.index(i)).tolist() for i in range(start, stop, step)]
+            return Tensor(data)
+        raise TypeError(f"Invalid index type: {type(key).__name__}")
+
+    # ------------------------------------------------------------------
+    # Iteration and length
+    # ------------------------------------------------------------------
+
+    def __len__(self):
+        return self.shape[0]
+
+    def __iter__(self):
+        # Iterates over the first dimension; each yielded Tensor is rebuilt
+        # from the corresponding entry of tolist().
+        data = self.tolist()
+        if not isinstance(data, list):
+            raise TypeError("iteration over a 0-d tensor")
+        for item in data:
+            yield Tensor(item)
+
+    # ------------------------------------------------------------------
+    # Catch-all: delegate unknown attribute accesses to the JS tensor.
+    # Returned JsProxy objects are wrapped in Tensor; primitives pass through.
+    # ------------------------------------------------------------------
+
+    def __getattr__(self, name):
+        # Underscore names are never forwarded; this also prevents infinite
+        # recursion when `_js` itself has not been set yet.
+        if name.startswith('_'):
+            raise AttributeError(name)
+        # NOTE(review): keyword arguments are accepted but not forwarded to JS.
+        def method(*args, **kwargs):
+            js_args = _transform_args(args)
+            return _wrap_result(getattr(self._js, name)(*js_args))
+        return method
+
+
+# ---------------------------------------------------------------------------
+# no_grad context manager — actually disables grad in the JS engine
+# ---------------------------------------------------------------------------
+
+class _NoGrad:
+    """Context manager that brackets a block with js_torch.enable_no_grad / disable_no_grad."""
+
+    def __enter__(self):
+        # Keep whatever enable_no_grad() returns so __exit__ can hand it back.
+        self._prev = js_torch.enable_no_grad()
+        return self
+
+    def __exit__(self, *args):
+        js_torch.disable_no_grad(self._prev)
+
+
+# ---------------------------------------------------------------------------
+# Parameter
+# ---------------------------------------------------------------------------
+
+class Parameter(Tensor):
+    """A Tensor that is automatically registered as a parameter."""
+    def __init__(self, data, requires_grad=True):
+        if isinstance(data, Tensor):
+            self._js = js_torch.nn.Parameter.new(data._js)
+        elif isinstance(data, JsProxy):
+            self._js = js_torch.nn.Parameter.new(data)
+        else:
+            # Convert Python sequences to JS arrays first, exactly as
+            # Tensor.__init__ does, instead of handing JS a raw Python list.
+            js_data = to_js(data) if isinstance(data, (list, tuple)) else data
+            self._js = js_torch.nn.Parameter.new(js_torch.tensor(js_data))
+        if not requires_grad:
+            self._js.requires_grad = False
+
+
+# ---------------------------------------------------------------------------
+# Module — pure-Python base class for user-defined models
+# ---------------------------------------------------------------------------
+
+class Module:
+    """
+    Pure-Python nn.Module. Subclass this to build models using bridge Tensors.
+    Assign `Parameter` or `_NNModule` instances as attributes and they are
+    automatically tracked by `parameters()`.
+    """
+
+    def __init__(self):
+        object.__setattr__(self, '_parameters', {})
+        object.__setattr__(self, '_modules', {})
+
+    def __setattr__(self, name, value):
+        """
+        Set an attribute and keep the parameter/sub-module registries in sync.
+        A name is registered in at most one registry, and re-assigning it to a
+        value of a different kind removes the stale registration.
+        """
+        try:
+            params = object.__getattribute__(self, '_parameters')
+            modules = object.__getattribute__(self, '_modules')
+        except AttributeError:
+            # Registries do not exist yet (subclass assigned attributes before
+            # calling Module.__init__): store the attribute untracked.
+            object.__setattr__(self, name, value)
+            return
+
+        if isinstance(value, Parameter):
+            modules.pop(name, None)
+            params[name] = value
+        elif isinstance(value, (Module, _NNModule)):
+            params.pop(name, None)
+            modules[name] = value
+        else:
+            # Plain value: make sure an earlier Parameter/module stored under
+            # this name is no longer reported by parameters().
+            params.pop(name, None)
+            modules.pop(name, None)
+        object.__setattr__(self, name, value)
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def parameters(self):
+        """Own parameters followed by those of every registered sub-module, as a list."""
+        params = list(object.__getattribute__(self, '_parameters').values())
+        for mod in object.__getattribute__(self, '_modules').values():
+            params.extend(mod.parameters())
+        return params
+
+    def named_parameters(self, prefix=''):
+        """List of (dotted name, parameter) pairs, each name prefixed with `prefix` when given."""
+        result = []
+        for name, p in object.__getattribute__(self, '_parameters').items():
+            full = f"{prefix}.{name}" if prefix else name
+            result.append((full, p))
+        for mod_name, mod in object.__getattribute__(self, '_modules').items():
+            full_mod = f"{prefix}.{mod_name}" if prefix else mod_name
+            result.extend(mod.named_parameters(full_mod))
+        return result
+
+    def zero_grad(self):
+        for p in self.parameters():
+            p.grad = None
+
+
+# ---------------------------------------------------------------------------
+# _NNModule — wraps a JS nn.Module instance
+# ---------------------------------------------------------------------------
+
+class _NNModule:
+    """Wraps a JS nn.Module returned by the nn factory functions."""
+
+    def __init__(self, js_module):
+        self._module = js_module
+
+    def __call__(self, *args):
+        # Use the shared argument conversion so Tensors nested inside lists or
+        # tuples are unwrapped the same way as everywhere else in this file.
+        js_args = _transform_args(args)
+        return Tensor(self._module.forward(*js_args))
+
+    def forward(self, *args):
+        return self(*args)
+
+    def parameters(self):
+        return [Tensor(p) for p in self._module.parameters().to_py()]
+
+    def named_parameters(self, prefix=''):
+        raw = self._module.named_parameters(prefix).to_py()
+        return [(pair[0], Tensor(pair[1])) for pair in raw]
+
+    def zero_grad(self):
+        for p in self.parameters():
+            p.grad = None
+
+
+# ---------------------------------------------------------------------------
+# nn.functional
+# ---------------------------------------------------------------------------
+
+class _NNFunctional:
+    """Explicit wrappers for relu/sigmoid; every other name is forwarded to js_torch.nn.functional."""
+
+    def relu(self, input):
+        return Tensor(js_torch.nn.functional.relu(input._js))
+
+    def sigmoid(self, input):
+        return Tensor(js_torch.nn.functional.sigmoid(input._js))
+
+    def __getattr__(self, name):
+        if name.startswith('_'):
+            raise AttributeError(name)
+        # NOTE(review): keyword arguments are accepted but not forwarded to JS.
+        def fn(*args, **kwargs):
+            return _wrap_result(getattr(js_torch.nn.functional, name)(*_transform_args(args)))
+        return fn
+
+
+# ---------------------------------------------------------------------------
+# nn.parameter namespace
+# ---------------------------------------------------------------------------
+
+class _NNParameterNamespace:
+    def __init__(self):
+        self.Parameter = Parameter
+
+
+# ---------------------------------------------------------------------------
+# nn namespace
+# ---------------------------------------------------------------------------
+
+class _NNNamespace:
+    """Stand-in for `torch.nn`: exposes Module/Parameter plus factories that wrap JS layers."""
+
+    def __init__(self):
+        self.functional = _NNFunctional()
+        self.parameter = _NNParameterNamespace()
+        self.Module = Module
+        self.Parameter = Parameter
+
+    def Linear(self, in_features, out_features, bias=True):
+        return _NNModule(js_torch.nn.Linear.new(in_features, out_features, bias))
+
+    def ReLU(self):
+        return _NNModule(js_torch.nn.ReLU.new())
+
+    def Sigmoid(self):
+        return _NNModule(js_torch.nn.Sigmoid.new())
+
+    def Sequential(self, *modules):
+        # Every entry must be an _NNModule: its wrapped JS module is read here.
+        js_mods = [m._module for m in modules]
+        return _NNModule(js_torch.nn.Sequential.new(*js_mods))
+
+    def MSELoss(self, reduction='mean'):
+        return _NNModule(js_torch.nn.MSELoss.new(reduction))
+
+    def L1Loss(self, reduction='mean'):
+        return _NNModule(js_torch.nn.L1Loss.new(reduction))
+
+    def BCELoss(self, weight=None, reduction='mean'):
+        js_weight = weight._js if isinstance(weight, Tensor) else None
+        return _NNModule(js_torch.nn.BCELoss.new(js_weight, reduction))
+
+    def CrossEntropyLoss(self, reduction='mean'):
+        return _NNModule(js_torch.nn.CrossEntropyLoss.new(reduction))
+
+    def Conv1d(self, in_channels, out_channels, kernel_size,
+               stride=1, padding=0, dilation=1, groups=1, bias=True):
+        return _NNModule(js_torch.nn.Conv1d.new(
+            in_channels, out_channels, kernel_size,
+            stride, padding, dilation, groups, bias
+        ))
+
+    def Conv2d(self, in_channels, out_channels, kernel_size,
+               stride=1, padding=0, dilation=1, groups=1, bias=True):
+        return _NNModule(js_torch.nn.Conv2d.new(
+            in_channels, out_channels, kernel_size,
+            stride, padding, dilation, groups, bias
+        ))
+
+    def Conv3d(self, in_channels, out_channels, kernel_size,
+               stride=1, padding=0, dilation=1, groups=1, bias=True):
+        return _NNModule(js_torch.nn.Conv3d.new(
+            in_channels, out_channels, kernel_size,
+            stride, padding, dilation, groups, bias
+        ))
+
+
+# ---------------------------------------------------------------------------
+# optim wrappers
+# ---------------------------------------------------------------------------
+
+class _Optimizer:
+    """Thin wrapper that forwards step() and zero_grad() to a JS optimizer."""
+
+    def __init__(self, js_optim):
+        self._optim = js_optim
+
+    def step(self):
+        self._optim.step()
+
+    def zero_grad(self):
+        self._optim.zero_grad()
+
+
+class _OptimNamespace:
+    """Stand-in for `torch.optim`; `params` must be an iterable of bridge Tensors."""
+
+    def SGD(self, params, lr=0.001, momentum=0.0, dampening=0.0,
+            weight_decay=0.0, nesterov=False, maximize=False):
+        js_params = to_js([p._js for p in params])
+        return _Optimizer(js_torch.optim.SGD.new(
+            js_params, lr, momentum, dampening, weight_decay, nesterov, maximize
+        ))
+
+    def Adam(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8,
+             weight_decay=0.0, amsgrad=False, maximize=False):
+        js_params = to_js([p._js for p in params])
+        js_betas = to_js(list(betas))
+        return _Optimizer(js_torch.optim.Adam.new(
+            js_params, lr, js_betas, eps, weight_decay, amsgrad, maximize
+        ))
+
+
+# ---------------------------------------------------------------------------
+# torch namespace
+# ---------------------------------------------------------------------------
+
+class _Torch:
+    """Stand-in for the `torch` module; unknown names are forwarded to js_torch."""
+
+    def __init__(self):
+        self.nn = _NNNamespace()
+        self.optim = _OptimNamespace()
+        self.no_grad = _NoGrad
+
+    @property
+    def tensor(self):
+        # Returns the class itself, so `torch.tensor(data, requires_grad=...)`
+        # constructs a Tensor.
+        return Tensor
+
+    # --- creation functions ---
+
+    def _shape_from_args(self, args):
+        # Accepts both f(2, 3) and f((2, 3)) / f([2, 3]).
+        return list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args)
+
+    # NOTE(review): the creation functions below accept **kwargs but do not use them.
+    def zeros(self, *args, **kwargs):
+        return Tensor(js_torch.zeros(to_js(self._shape_from_args(args))))
+
+    def ones(self, *args, **kwargs):
+        return Tensor(js_torch.ones(to_js(self._shape_from_args(args))))
+
+    def zeros_like(self, input):
+        return Tensor(js_torch.zeros_like(input._js))
+
+    def ones_like(self, input):
+        return Tensor(js_torch.ones_like(input._js))
+
+    def randn(self, *args, **kwargs):
+        return Tensor(js_torch.randn(to_js(self._shape_from_args(args))))
+
+    def rand(self, *args, **kwargs):
+        return Tensor(js_torch.rand(to_js(self._shape_from_args(args))))
+
+    def arange(self, start, end=None, step=1):
+        # arange(n) means arange(0, n).
+        if end is None:
+            end = start
+            start = 0
+        return Tensor(js_torch.arange(start, end, step))
+
+    def linspace(self, start, end, steps):
+        return Tensor(js_torch.linspace(start, end, steps))
+
+    def empty(self, *args, **kwargs):
+        return Tensor(js_torch.empty(to_js(self._shape_from_args(args))))
+
+    def empty_like(self, input):
+        return Tensor(js_torch.empty_like(input._js))
+
+    def full(self, shape, fill_value):
+        return Tensor(js_torch.full(to_js(list(shape)), fill_value))
+
+    def full_like(self, input, fill_value):
+        return Tensor(js_torch.full_like(input._js, fill_value))
+
+    def rand_like(self, input):
+        return Tensor(js_torch.rand_like(input._js))
+
+    def randn_like(self, input):
+        return Tensor(js_torch.randn_like(input._js))
+
+    def randint_like(self, input, low, high):
+        return Tensor(js_torch.randint_like(input._js, low, high))
+
+    # --- utility functions ---
+
+    def is_tensor(self, obj):
+        return isinstance(obj, Tensor)
+
+    def is_nonzero(self, input):
+        if input.numel() != 1:
+            raise RuntimeError(
+                "Boolean value of Tensor with more than one element is ambiguous"
+            )
+        return bool(input.item() != 0)
+
+    def numel(self, input):
+        return input.numel()
+
+    # --- functional wrappers ---
+
+    def sum(self, input, dim=None, keepdim=False):
+        return input.sum(dim, keepdim)
+
+    def mean(self, input, dim=None, keepdim=False):
+        return input.mean(dim, keepdim)
+
+    def sigmoid(self, input):
+        return input.sigmoid()
+
+    def relu(self, input):
+        return input.relu()
+
+    def flatten(self, input, start_dim=0, end_dim=-1):
+        return input.flatten(start_dim, end_dim)
+
+    def allclose(self, a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
+        return a.allclose(b, rtol, atol, equal_nan)
+
+    def is_grad_enabled(self):
+        return bool(js_torch.is_grad_enabled())
+
+    def cat(self, tensors, dim=0):
+        # A single Tensor is treated as a one-element sequence.
+        if isinstance(tensors, Tensor):
+            tensors = [tensors]
+        return Tensor(js_torch.cat(to_js([t._js for t in tensors]), dim))
+
+    def concatenate(self, tensors, dim=0):
+        return self.cat(tensors, dim)
+
+    def concat(self, tensors, dim=0):
+        return self.cat(tensors, dim)
+
+    def Size(self, shape):
+        # Tensor.shape is a tuple, so return a tuple as well; otherwise
+        # `t.shape == torch.Size([...])` could never be True.
+        return tuple(shape)
+
+    def __getattr__(self, name):
+        if name.startswith('_'):
+            raise AttributeError(name)
+        # NOTE(review): keyword arguments are accepted but not forwarded to JS.
+        def fn(*args, **kwargs):
+            return _wrap_result(getattr(js_torch, name)(*_transform_args(args)))
+        return fn
+
+
+torch = _Torch()
diff --git a/src/pyodide/bridge.py.d.ts b/src/pyodide/bridge.py.d.ts
new file mode 100644
index 00000000..c4ef32bd
--- /dev/null
+++ b/src/pyodide/bridge.py.d.ts
@@ -0,0 +1,2 @@
+declare const bridgeCode: string;
+export default bridgeCode;
diff --git a/src/pyodide/importAnalyzer.ts b/src/pyodide/importAnalyzer.ts
new file mode 100644
index 00000000..905a4e9b
--- /dev/null
+++ b/src/pyodide/importAnalyzer.ts
@@ -0,0
+1,186 @@
+/**
+ * Import analysis for detecting and rewriting torch imports using Python's
+ * built-in `ast` module via Pyodide.
+ *
+ * This avoids the limitations of py-slang's parser (which only supports a
+ * subset of Python) by delegating to CPython's own parser running inside
+ * Pyodide.
+ */
+
+import type { PyodideInterface } from "pyodide";
+
+export interface TorchImportInfo {
+  /** "import" for bare `import torch`, "from" for `from torch import ...` */
+  type: "import" | "from";
+  /** Full module path, e.g. "torch" or "torch.nn" */
+  module: string;
+  /** Imported names with optional aliases */
+  names: { name: string; alias: string | null }[];
+  /** 1-based line number in the original source */
+  line: number;
+}
+
+/**
+ * Python helper that uses the `ast` module to extract import info.
+ * Returns a JSON string describing every absolute `import ...` and
+ * `from ... import ...` statement (one entry per imported module for plain
+ * `import` statements). Relative imports are skipped, and source that does
+ * not parse yields an empty array.
+ */
+const ANALYZE_IMPORTS_PY = `
+import ast as _ast, json as _json
+
+def _sa_analyze_imports(source):
+    """Parse source and return JSON array of import info (both 'import' and 'from ... import')."""
+    try:
+        tree = _ast.parse(source)
+    except SyntaxError:
+        return "[]"
+    result = []
+    for node in _ast.walk(tree):
+        # level > 0 marks a relative import ("from .torch import x"); its module
+        # name must not be mistaken for a top-level package.
+        if isinstance(node, _ast.ImportFrom) and node.module and node.level == 0:
+            result.append({
+                "type": "from",
+                "module": node.module,
+                "names": [
+                    {"name": a.name, "alias": a.asname}
+                    for a in node.names
+                ],
+                "line": node.lineno,
+            })
+        elif isinstance(node, _ast.Import):
+            for a in node.names:
+                result.append({
+                    "type": "import",
+                    "module": a.name,
+                    "names": [{"name": a.name, "alias": a.asname}],
+                    "line": node.lineno,
+                })
+    return _json.dumps(result)
+`;
+
+/** Pyodide instances in which `_sa_analyze_imports` has already been defined. */
+let helperLoadedFor = new WeakSet<PyodideInterface>();
+
+/**
+ * Ensure the Python-side `_sa_analyze_imports` function is defined in the
+ * given Pyodide instance.
+ * Idempotent — runs at most once per instance, so a freshly created
+ * interpreter is never assumed to contain the helper.
+ */
+async function ensureHelper(pyodide: PyodideInterface): Promise<void> {
+  if (helperLoadedFor.has(pyodide)) return;
+  await pyodide.runPythonAsync(ANALYZE_IMPORTS_PY);
+  helperLoadedFor.add(pyodide);
+}
+
+/**
+ * Reset the helper loaded state. Useful for testing when pyodide
+ * instances are recreated.
+ */
+export function resetHelperState(): void {
+  helperLoadedFor = new WeakSet<PyodideInterface>();
+}
+
+/** Top-level package of a dotted module path, e.g. "torch" for "torch.nn". */
+function rootOf(module: string): string {
+  return module.split(".")[0];
+}
+
+/**
+ * Runs the Python-side analyzer on `source` and returns every import it
+ * reports, in the order the analyzer produced them.
+ */
+async function analyzeImports(
+  pyodide: PyodideInterface,
+  source: string,
+): Promise<TorchImportInfo[]> {
+  await ensureHelper(pyodide);
+
+  // JSON.stringify yields a double-quoted literal that is embedded directly
+  // in the Python expression as the source argument.
+  const json = pyodide.runPython(`_sa_analyze_imports(${JSON.stringify(source)})`) as string;
+
+  const allImports: TorchImportInfo[] = JSON.parse(json);
+  return allImports;
+}
+
+/**
+ * Parses the source code using Python's `ast` module (via Pyodide) and
+ * returns all `import …` and `from … import …` statements whose root module
+ * is "torch".
+ */
+export async function detectTorchImports(
+  pyodide: PyodideInterface,
+  source: string,
+): Promise<TorchImportInfo[]> {
+  const allImports = await analyzeImports(pyodide, source);
+  return allImports.filter(imp => rootOf(imp.module) === "torch");
+}
+
+/**
+ * Returns the set of top-level module roots for all non-torch `import …` and
+ * `from … import …` statements. These may need to be installed via micropip.
+ */
+export async function getNonTorchImportRoots(
+  pyodide: PyodideInterface,
+  source: string,
+): Promise<Set<string>> {
+  const allImports = await analyzeImports(pyodide, source);
+
+  const roots = new Set<string>();
+  for (const imp of allImports) {
+    const root = rootOf(imp.module);
+    if (root !== "torch") {
+      roots.add(root);
+    }
+  }
+  return roots;
+}
+
+/**
+ * Generates Python assignment code that replaces a torch import statement.
+ * + * Examples: + * import torch → torch = __sa_import_torch + * import torch as t → t = __sa_import_torch + * import torch.nn → torch = __sa_import_torch + * from torch.nn import Linear as L, Conv2d + * → L = __sa_import_torch.nn.Linear + * Conv2d = __sa_import_torch.nn.Conv2d + */ +function generateReplacement(imp: TorchImportInfo): string { + const injected = "__sa_import_torch"; + + if (imp.type === "import") { + // `import torch` or `import torch as t` or `import torch.nn` + const alias = imp.names[0].alias; + if (alias) { + // import torch as t → t = __sa_import_torch + // import torch.nn as nn → nn = __sa_import_torch.nn + const subparts = imp.module.split(".").slice(1); + const rhs = subparts.length > 0 ? `${injected}.${subparts.join(".")}` : injected; + return `${alias} = ${rhs}`; + } + // import torch → torch = __sa_import_torch + // import torch.nn → torch = __sa_import_torch (Python binds the top-level name) + return `torch = ${injected}`; + } + + // from torch.nn import Linear as L, Conv2d + const subparts = imp.module.split(".").slice(1); + const base = subparts.length > 0 ? `${injected}.${subparts.join(".")}` : injected; + + return imp.names + .map(({ name, alias }) => { + const binding = alias ?? name; + return `${binding} = ${base}.${name}`; + }) + .join("\n"); +} + +/** + * Rewrites the source code by replacing torch import lines with + * variable assignments that reference the injected `__sa_import_torch` global. + * + * Non-torch code is passed through unchanged. + */ +export async function rewriteTorchImports( + pyodide: PyodideInterface, + source: string, +): Promise<{ code: string; hasTorch: boolean }> { + const imports = await detectTorchImports(pyodide, source); + + if (imports.length === 0) { + return { code: source, hasTorch: false }; + } + + const lines = source.split(/\r?\n/); + + // Process in reverse order so earlier line indices stay valid. 
+ for (let i = imports.length - 1; i >= 0; i--) { + const imp = imports[i]; + const replacement = generateReplacement(imp); + const idx = imp.line - 1; + lines.splice(idx, 1, replacement); + } + + return { code: lines.join("\n"), hasTorch: true }; +} diff --git a/src/pyodide/loadPyodide.ts b/src/pyodide/loadPyodide.ts new file mode 100644 index 00000000..d8b02f90 --- /dev/null +++ b/src/pyodide/loadPyodide.ts @@ -0,0 +1,47 @@ +import { version, loadPyodide } from "pyodide"; +import type { PyodideInterface } from "pyodide"; + +const IN_NODE = + typeof process !== "undefined" && process.versions != null && process.versions.node != null; + +async function ensureLocalPyodideAssets(baseUrl: string): Promise { + const path = await import("node:path"); + const fs = await import("node:fs/promises"); + const os = await import("node:os"); + + const dir = path.join(os.tmpdir(), `pyodide-${version}`); + await fs.mkdir(dir, { recursive: true }); + + const assets = [ + { name: "pyodide.asm.js", mode: "text" as const }, + { name: "pyodide.asm.wasm", mode: "binary" as const }, + { name: "python_stdlib.zip", mode: "binary" as const }, + { name: "pyodide-lock.json", mode: "text" as const }, + ]; + + for (const asset of assets) { + const url = baseUrl + asset.name; + const dest = path.join(dir, asset.name); + try { + await fs.access(dest); + continue; + } catch { + // File doesn't exist yet — download it. + } + const res = await fetch(url); + if (!res.ok) throw new Error(`Failed to fetch ${url}: ${res.status} ${res.statusText}`); + const data = + asset.mode === "text" + ? Buffer.from(await res.text(), "utf8") + : Buffer.from(await res.arrayBuffer()); + await fs.writeFile(dest, data); + } + + return dir + path.sep; +} + +export async function loadPyodideGeneric(): Promise { + const cdnBase = `https://cdn.jsdelivr.net/pyodide/v${version}/full/`; + const indexURL = IN_NODE ? 
await ensureLocalPyodideAssets(cdnBase) : cdnBase;
  return loadPyodide({ indexURL, fullStdLib: true });
}
diff --git a/src/pyodide/loadTorch.ts b/src/pyodide/loadTorch.ts new file mode 100644 index 00000000..f7ba47f4 --- /dev/null +++ b/src/pyodide/loadTorch.ts @@ -0,0 +1,20 @@
import type { PyodideInterface } from "pyodide";
import * as torch from "@sourceacademy/torch";
import bridgeCode from "./bridge.py";

/**
 * Loads the torch library into Pyodide by exposing the JS torch object
 * and running bridge.py to set up the Python-side `torch` module.
 *
 * After this call, `pyodide.globals.get("torch")` is the usable torch module.
 *
 * @param pyodide - Interpreter to install torch into. Its globals gain
 *   `js_torch` plus whatever bridge.py defines.
 * @throws Error when bridge.py ran but left no `torch` name in the globals.
 */
export async function loadTorch(pyodide: PyodideInterface): Promise<void> {
  // bridge.py reads the JS module through this global.
  pyodide.globals.set("js_torch", torch);

  await pyodide.runPythonAsync(bridgeCode);

  const hasTorch = pyodide.runPython("'torch' in globals()");
  if (!hasTorch) {
    throw new Error("torch not found in globals after running bridge.py");
  }
}
diff --git a/src/tests/import-analyzer.test.ts b/src/tests/import-analyzer.test.ts new file mode 100644 index 00000000..220b88ec --- /dev/null +++ b/src/tests/import-analyzer.test.ts @@ -0,0 +1,256 @@
import { loadPyodide } from "pyodide";
import type { PyodideInterface } from "pyodide";
import {
  detectTorchImports,
  getNonTorchImportRoots,
  rewriteTorchImports,
  resetHelperState,
} from "../pyodide/importAnalyzer";

let pyodide: PyodideInterface;

// A fresh interpreter has no analyzer helper yet, so clear the "loaded" flag first.
beforeAll(async () => {
  resetHelperState();
  pyodide = await loadPyodide();
}, 60_000);

// ---------------------------------------------------------------------------
// detectTorchImports
// ---------------------------------------------------------------------------
describe("detectTorchImports", () => {
  test("detects `from torch import tensor`", async () => {
    const result = await detectTorchImports(pyodide, "from torch import tensor\nx = 1\n");
    expect(result).toHaveLength(1);
    expect(result[0].module).toBe("torch");
expect(result[0].names).toEqual([{ name: "tensor", alias: null }]);
  });

  test("detects `from torch.nn import Linear as L`", async () => {
    const result = await detectTorchImports(pyodide, "from torch.nn import Linear as L\nx = 1\n");
    expect(result).toHaveLength(1);
    expect(result[0].module).toBe("torch.nn");
    expect(result[0].names).toEqual([{ name: "Linear", alias: "L" }]);
  });

  test("detects multiple names", async () => {
    const result = await detectTorchImports(
      pyodide,
      "from torch import tensor, zeros, ones\nx = 1\n",
    );
    // One statement → one record, listing every imported name in order.
    expect(result).toHaveLength(1);
    expect(result[0].names).toEqual([
      { name: "tensor", alias: null },
      { name: "zeros", alias: null },
      { name: "ones", alias: null },
    ]);
  });

  test("detects multiple torch import statements", async () => {
    const src = "from torch import tensor\nfrom torch.nn import Linear\nx = 1\n";
    const result = await detectTorchImports(pyodide, src);
    // These index-based assertions expect the records in source order.
    expect(result).toHaveLength(2);
    expect(result[0].module).toBe("torch");
    expect(result[1].module).toBe("torch.nn");
  });

  test("detects bare `import torch`", async () => {
    const result = await detectTorchImports(pyodide, "import torch\nx = 1\n");
    expect(result).toHaveLength(1);
    expect(result[0].type).toBe("import");
    expect(result[0].module).toBe("torch");
    expect(result[0].names).toEqual([{ name: "torch", alias: null }]);
  });

  test("detects `import torch as t`", async () => {
    const result = await detectTorchImports(pyodide, "import torch as t\nx = 1\n");
    expect(result).toHaveLength(1);
    expect(result[0].type).toBe("import");
    expect(result[0].module).toBe("torch");
    expect(result[0].names).toEqual([{ name: "torch", alias: "t" }]);
  });

  test("detects `import torch.nn`", async () => {
    const result = await detectTorchImports(pyodide, "import torch.nn\nx = 1\n");
    expect(result).toHaveLength(1);
    expect(result[0].type).toBe("import");
    expect(result[0].module).toBe("torch.nn");
  });

  test("detects mix of bare import and from-import", async () => {
    const src = "import torch\nfrom torch.nn import Linear\nx = 1\n";
    const result = await detectTorchImports(pyodide, src);
    expect(result).toHaveLength(2);
    // Sorted before comparing, so this case does not depend on record order.
    const types = result.map(r => r.type).sort();
    expect(types).toEqual(["from", "import"]);
  });

  test("ignores non-torch imports", async () => {
    const result = await detectTorchImports(pyodide, "from math import sqrt\nx = 1\n");
    expect(result).toHaveLength(0);
  });

  test("returns empty array for syntax errors", async () => {
    // Unparseable source is reported as "no imports" rather than as an error.
    const result = await detectTorchImports(pyodide, "def (broken\n");
    expect(result).toHaveLength(0);
  });

  test("returns empty for code with no imports", async () => {
    const result = await detectTorchImports(pyodide, "x = 1\n");
    expect(result).toHaveLength(0);
  });

  test("works with full Python syntax (method calls, list literals)", async () => {
    const result = await detectTorchImports(
      pyodide,
      "from torch import tensor\nx = tensor([1,2]).tolist()\n",
    );
    expect(result).toHaveLength(1);
    expect(result[0].module).toBe("torch");
  });
});

// ---------------------------------------------------------------------------
// getNonTorchImportRoots
// ---------------------------------------------------------------------------
describe("getNonTorchImportRoots", () => {
  test("returns non-torch module roots", async () => {
    const src = "from math import sqrt\nfrom torch import tensor\nfrom numpy import array\nx = 1\n";
    const roots = await getNonTorchImportRoots(pyodide, src);
    expect(roots).toEqual(new Set(["math", "numpy"]));
  });

  test("returns empty set when only torch imports exist", async () => {
    const roots = await getNonTorchImportRoots(pyodide, "from torch import tensor\nx = 1\n");
    expect(roots).toEqual(new Set());
  });

  test("extracts root from dotted module name", async () => {
    const roots = await getNonTorchImportRoots(pyodide, "from os.path import join\nx = 1\n");
    expect(roots).toEqual(new Set(["os"]));
  });

  test("returns non-torch roots for bare import statements", async () => {
    const src = "import numpy\nimport torch\nx = 1\n";
    const roots = await getNonTorchImportRoots(pyodide, src);
    expect(roots).toEqual(new Set(["numpy"]));
  });
});

// ---------------------------------------------------------------------------
// rewriteTorchImports
// ---------------------------------------------------------------------------
describe("rewriteTorchImports", () => {
  test("rewrites `from torch import tensor`", async () => {
    const { code, hasTorch } = await rewriteTorchImports(
      pyodide,
      "from torch import tensor\nx = tensor(1)\n",
    );
    expect(hasTorch).toBe(true);
    expect(code).toContain("tensor = __sa_import_torch.tensor");
    expect(code).not.toContain("from torch");
    expect(code).toContain("x = tensor(1)");
  });

  test("rewrites `from torch.nn import Linear as L`", async () => {
    const { code, hasTorch } = await rewriteTorchImports(
      pyodide,
      "from torch.nn import Linear as L\nx = L(3, 2)\n",
    );
    expect(hasTorch).toBe(true);
    expect(code).toContain("L = __sa_import_torch.nn.Linear");
    expect(code).toContain("x = L(3, 2)");
  });

  test("rewrites multiple names", async () => {
    const { code, hasTorch } = await rewriteTorchImports(
      pyodide,
      "from torch import tensor, zeros\nx = 1\n",
    );
    expect(hasTorch).toBe(true);
    expect(code).toContain("tensor = __sa_import_torch.tensor");
    expect(code).toContain("zeros = __sa_import_torch.zeros");
  });

  test("leaves non-torch code unchanged", async () => {
    const src = "x = 1\n";
    const { code, hasTorch } = await rewriteTorchImports(pyodide, src);
    expect(hasTorch).toBe(false);
    // Without torch imports the input must come back byte-for-byte.
    expect(code).toBe(src);
  });

  test("preserves non-torch imports", async () => {
    const src = "from math import sqrt\nfrom torch import tensor\nx = sqrt(tensor(4))\n";
    const { code, hasTorch } = await rewriteTorchImports(pyodide, src);
    expect(hasTorch).toBe(true);
    expect(code).toContain("from math import sqrt");
    expect(code).toContain("tensor = __sa_import_torch.tensor");
  });

  test("rewrites deeply nested module path", async () => {
    const { code } = await rewriteTorchImports(
      pyodide,
      "from torch.nn.functional import relu\nx = relu(1)\n",
    );
    expect(code).toContain("relu = __sa_import_torch.nn.functional.relu");
  });

  test("rewrites bare `import torch`", async () => {
    const { code, hasTorch } = await rewriteTorchImports(
      pyodide,
      "import torch\nx = torch.tensor([1, 2, 3])\n",
    );
    expect(hasTorch).toBe(true);
    expect(code).toContain("torch = __sa_import_torch");
    expect(code).not.toContain("import torch");
    expect(code).toContain("x = torch.tensor([1, 2, 3])");
  });

  test("rewrites `import torch as t`", async () => {
    const { code, hasTorch } = await rewriteTorchImports(
      pyodide,
      "import torch as t\nx = t.tensor([1])\n",
    );
    expect(hasTorch).toBe(true);
    expect(code).toContain("t = __sa_import_torch");
    expect(code).toContain("x = t.tensor([1])");
  });

  test("rewrites `import torch.nn as nn`", async () => {
    const { code, hasTorch } = await rewriteTorchImports(
      pyodide,
      "import torch.nn as nn\nx = nn.Linear(3, 2)\n",
    );
    expect(hasTorch).toBe(true);
    expect(code).toContain("nn = __sa_import_torch.nn");
    expect(code).toContain("x = nn.Linear(3, 2)");
  });

  test("rewrites `import torch.nn` (no alias)", async () => {
    const { code, hasTorch } = await rewriteTorchImports(
      pyodide,
      "import torch.nn\nx = torch.nn.Linear(3, 2)\n",
    );
    expect(hasTorch).toBe(true);
    // Unaliased dotted import binds only the top-level name `torch`.
    expect(code).toContain("torch = __sa_import_torch");
    expect(code).toContain("x = torch.nn.Linear(3, 2)");
  });

  test("rewrites mix of bare import and from-import", async () => {
    const src =
      "import torch\nfrom torch.nn import Linear\nx = torch.tensor(1)\ny = Linear(3, 2)\n";
    const { code, hasTorch } = await rewriteTorchImports(pyodide, src);
expect(hasTorch).toBe(true);
    expect(code).toContain("torch = __sa_import_torch");
    expect(code).toContain("Linear = __sa_import_torch.nn.Linear");
    expect(code).not.toMatch(/^import torch$/m);
    expect(code).not.toContain("from torch");
  });

  test("handles full Python body that py-slang cannot parse", async () => {
    const src = "from torch import tensor\nx = tensor([1, 2, 3]).tolist()\nprint(x)\n";
    const { code, hasTorch } = await rewriteTorchImports(pyodide, src);
    expect(hasTorch).toBe(true);
    expect(code).toContain("tensor = __sa_import_torch.tensor");
    expect(code).toContain("x = tensor([1, 2, 3]).tolist()");
  });
});
diff --git a/src/tests/pyodide-torch.test.ts b/src/tests/pyodide-torch.test.ts new file mode 100644 index 00000000..1819a065 --- /dev/null +++ b/src/tests/pyodide-torch.test.ts @@ -0,0 +1,87 @@
/**
 * Integration tests for the pyodide+torch pipeline.
 *
 * These tests load real pyodide and torch, so they are slow on first run
 * while pyodide downloads assets. They verify the full flow: parse → rewrite
 * imports → load torch → execute in pyodide.
 */

import { loadPyodide } from "pyodide";
import type { PyodideInterface } from "pyodide";
import * as torch from "@sourceacademy/torch";
import bridgeCode from "../pyodide/bridge.py";
import { rewriteTorchImports, resetHelperState } from "../pyodide/importAnalyzer";

let pyodide: PyodideInterface;

beforeAll(async () => {
  resetHelperState();
  pyodide = await loadPyodide({ fullStdLib: true });

  // Set up torch in pyodide (mirrors what loadTorch does)
  pyodide.globals.set("js_torch", torch);
  await pyodide.runPythonAsync(bridgeCode);
  // The rewritten imports resolve against this global.
  pyodide.globals.set("__sa_import_torch", pyodide.globals.get("torch"));
}, 60_000);

/**
 * Rewrites the torch imports in `source` and executes the result in the
 * shared interpreter.
 *
 * @returns Whatever `runPythonAsync` yields for the rewritten code.
 */
async function runTorchCode(source: string): Promise<unknown> {
  const { code } = await rewriteTorchImports(pyodide, source);
  return pyodide.runPythonAsync(code);
}

/**
 * Convert a pyodide result to a plain JS value.
 *
 * Objects exposing `toJs` are converted through it; otherwise objects
 * exposing `to_py` are converted through that; anything else is returned
 * unchanged.
 */
function toJS(result: unknown): unknown {
  if (result != null && typeof result === "object" && "toJs" in result) {
    return (result as { toJs: () => unknown }).toJs();
  }
  if (result != null && typeof result === "object" && "to_py" in result) {
    return (result as { to_py: () => unknown }).to_py();
  }
  return result;
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

describe("pyodide + torch integration", () => {
  test("from torch import tensor — create and read a tensor", async () => {
    const result = await runTorchCode(
      "from torch import tensor\nx = tensor([1, 2, 3])\nx.tolist()\n",
    );
    const val = toJS(result);
    expect(val).toEqual([1, 2, 3]);
  });

  test("from torch import zeros — creation function", async () => {
    const result = await runTorchCode("from torch import zeros\nx = zeros(3)\nx.tolist()\n");
    const val = toJS(result);
    expect(val).toEqual([0, 0, 0]);
  });

  test("from torch.nn import Linear as L — submodule with alias", async () => {
    // Just verify it doesn't throw
    await runTorchCode(
      "from torch.nn import Linear as L\nfrom torch import zeros\nlayer = L(2, 3)\nresult = layer(zeros(2))\nresult.shape\n",
    );
  });

  test("tensor arithmetic", async () => {
    const result = await runTorchCode(
      "from torch import tensor\na = tensor([1, 2, 3])\nb = tensor([4, 5, 6])\n(a + b).tolist()\n",
    );
    const val = toJS(result);
    expect(val).toEqual([5, 7, 9]);
  });

  test("autograd — backward pass", async () => {
    const result = await runTorchCode(`from torch import tensor
x = tensor([2.0], True)
y = x * x
y.backward()
x.grad.tolist()
`);
    const val = toJS(result);
    expect(val).toEqual([4.0]);
  });
});
diff --git a/src/tests/raw-text-transformer.js b/src/tests/raw-text-transformer.js new file mode 100644 index 00000000..e6e0a1c9 --- /dev/null +++ b/src/tests/raw-text-transformer.js @@
-0,0 +1,7 @@
/**
 * Transformer that turns a file's raw text into a CommonJS module whose
 * export is that text as a string.
 */
module.exports = {
  /**
   * @param sourceText - Raw contents of the file being transformed.
   * @returns An object whose `code` is JS source assigning the JSON-encoded
   *   text to `module.exports`.
   */
  process(sourceText) {
    return {
      // JSON.stringify produces a correctly escaped JS string literal.
      code: `module.exports = ${JSON.stringify(sourceText)};`,
    };
  },
};
diff --git a/src/tests/utils.ts b/src/tests/utils.ts index 2078bc0e..7a727caa 100644 --- a/src/tests/utils.ts +++ b/src/tests/utils.ts @@ -1,3 +1,4 @@ +import { jest } from "@jest/globals"; import { ConductorError, ErrorType } from "@sourceacademy/conductor/common"; import { StmtNS } from "../ast-types"; import { Context } from "../engines/cse/context"; diff --git a/yarn.lock b/yarn.lock index bb2ef638..3f3fed56 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1643,6 +1643,13 @@ __metadata: languageName: node linkType: hard +"@sourceacademy/torch@npm:^0.1.0": + version: 0.1.0 + resolution: "@sourceacademy/torch@npm:0.1.0" + checksum: 10c0/bc6408d374a23080d3581ac06edbfd02f94215a9d2c27c38913a7b2ef99d5eadcec5f8b496dba4fac44341a216bded3412622f5b66152253373d969cdbbf7431 + languageName: node + linkType: hard + "@sourceacademy/wasm-util@npm:^1.0.6": version: 1.0.6 resolution: "@sourceacademy/wasm-util@npm:1.0.6" @@ -1691,6 +1698,13 @@ __metadata: languageName: node linkType: hard +"@types/emscripten@npm:^1.41.4": + version: 1.41.5 + resolution: "@types/emscripten@npm:1.41.5" + checksum: 10c0/ae816da716f896434e59df7a71b67c71ae7e85ca067a32aef1616572fc4757459515d42ade6f5b8fd8d69733a9dbd0cf23010fec5b2f41ce52c09501aa350e45 + languageName: node + linkType: hard + "@types/estree@npm:*, @types/estree@npm:1.0.8, @types/estree@npm:^1.0.0, @types/estree@npm:^1.0.6": version: 1.0.8 resolution: "@types/estree@npm:1.0.8" @@ -4768,6 +4782,7 @@ __metadata: "@rollup/plugin-typescript": "npm:^12.1.2" "@rollup/plugin-wasm": "npm:^6.2.2" "@sourceacademy/conductor": "npm:^0.3.0" + "@sourceacademy/torch": "npm:^0.1.0" "@sourceacademy/wasm-util": "npm:^1.0.6" "@types/fast-levenshtein": "npm:^0.0.4" "@types/jest": "npm:^29.5.14" @@ -4784,6 +4799,7 @@ __metadata: moo: "npm:^0.5.2" nearley: "npm:^2.20.1" prettier: "npm:^3.8.1" + pyodide: "npm:^0.29.3" rollup:
"npm:^4.59.0" rollup-plugin-polyfill-node: "npm:^0.13.0" ts-jest: "npm:^29.0.5" @@ -4795,6 +4811,16 @@ __metadata: languageName: unknown linkType: soft +"pyodide@npm:^0.29.3": + version: 0.29.3 + resolution: "pyodide@npm:0.29.3" + dependencies: + "@types/emscripten": "npm:^1.41.4" + ws: "npm:^8.5.0" + checksum: 10c0/4c8108e9af7cd8997812507a01c3dd48789ab58973bdef3ac5a336c38837e7495146195be7fb8b0798bcc3c4f79e98877efaf0672d18e088488512eefbf1d3ca + languageName: node + linkType: hard + "railroad-diagrams@npm:^1.0.0": version: 1.0.0 resolution: "railroad-diagrams@npm:1.0.0" @@ -5628,6 +5654,21 @@ __metadata: languageName: node linkType: hard +"ws@npm:^8.5.0": + version: 8.20.0 + resolution: "ws@npm:8.20.0" + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: ">=5.0.2" + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + checksum: 10c0/956ac5f11738c914089b65878b9223692ace77337ba55379ae68e1ecbeae9b47a0c6eb9403688f609999a58c80d83d99865fe0029b229d308b08c1ef93d4ea14 + languageName: node + linkType: hard + "xmlcreate@npm:^2.0.4": version: 2.0.4 resolution: "xmlcreate@npm:2.0.4"