Skip to content

Commit e8dadce

Browse files
committed
Update torch and bridge
1 parent 9d86260 commit e8dadce

3 files changed

Lines changed: 16 additions & 19 deletions

File tree

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
"moo": "^0.5.2",
6363
"nearley": "^2.20.1",
6464
"pyodide": "^0.29.3",
65-
"torch": "https://pkg.pr.new/veehz/torch@297abf6",
65+
"torch": "https://pkg.pr.new/veehz/torch@687fc81",
6666
"wabt": "^1.0.37"
6767
}
6868
}

src/pyodide/bridge.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,7 @@ def transpose(self, dim0, dim1):
184184
return Tensor(self._js.transpose(dim0, dim1))
185185

186186
def flatten(self, start_dim=0, end_dim=-1):
187-
n = self.numel()
188-
return self.reshape([n])
187+
return Tensor(self._js.flatten(start_dim, end_dim))
189188

190189
# ------------------------------------------------------------------
191190
# Reductions — default (no dim) sums all elements, matching PyTorch
@@ -387,7 +386,7 @@ def __setattr__(self, name, value):
387386
object.__setattr__(self, name, value)
388387
return
389388

390-
if isinstance(value, Tensor) and value.requires_grad:
389+
if isinstance(value, Parameter):
391390
params[name] = value
392391
elif isinstance(value, (Module, _NNModule)):
393392
modules[name] = value
@@ -674,17 +673,15 @@ def is_grad_enabled(self):
674673
return bool(js_torch.is_grad_enabled())
675674

676675
def cat(self, tensors, dim=0):
677-
"""Concatenate tensors along dim. NOTE: gradient is not tracked."""
678-
if dim != 0:
679-
raise NotImplementedError("torch.cat only supports dim=0 in this bridge")
680-
result = []
681-
for t in tensors:
682-
data = t.tolist()
683-
if isinstance(data, list):
684-
result.extend(data)
685-
else:
686-
result.append(data)
687-
return Tensor(result)
676+
if isinstance(tensors, Tensor):
677+
tensors = [tensors]
678+
return Tensor(js_torch.cat(to_js([t._js for t in tensors]), dim))
679+
680+
def concatenate(self, tensors, dim=0):
681+
return self.cat(tensors, dim)
682+
683+
def concat(self, tensors, dim=0):
684+
return self.cat(tensors, dim)
688685

689686
def Size(self, shape):
690687
return list(shape)

yarn.lock

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4306,7 +4306,7 @@ __metadata:
43064306
prettier: "npm:^3.8.1"
43074307
pyodide: "npm:^0.29.3"
43084308
rollup: "npm:^4.59.0"
4309-
torch: "https://pkg.pr.new/veehz/torch@297abf6"
4309+
torch: "https://pkg.pr.new/veehz/torch@687fc81"
43104310
ts-jest: "npm:^29.0.5"
43114311
ts-node: "npm:^10.9.1"
43124312
tslib: "npm:^2.8.1"
@@ -4818,10 +4818,10 @@ __metadata:
48184818
languageName: node
48194819
linkType: hard
48204820

4821-
"torch@https://pkg.pr.new/veehz/torch@297abf6":
4821+
"torch@https://pkg.pr.new/veehz/torch@687fc81":
48224822
version: 0.1.0
4823-
resolution: "torch@https://pkg.pr.new/veehz/torch@297abf6"
4824-
checksum: 10c0/bde381a0d266c845970fd34b3d27048709b60fd9e64a2d96f3080eb9c3a4bb1433bcc0ea1bc1123fbb31883ad2623c031ede94319e514f54398b9752c290c6d4
4823+
resolution: "torch@https://pkg.pr.new/veehz/torch@687fc81"
4824+
checksum: 10c0/84fd0d0fae8f5e67dccfa01d429755ef5db147340544f697a14ce32332a31b5515b84aa47ac849c327f1026c8c4ff29a7af3c60bd5c1db541f49ddfa9f529185
48254825
languageName: node
48264826
linkType: hard
48274827

0 commit comments

Comments (0)