@@ -184,8 +184,7 @@ def transpose(self, dim0, dim1):
         return Tensor(self._js.transpose(dim0, dim1))
 
     def flatten(self, start_dim=0, end_dim=-1):
-        n = self.numel()
-        return self.reshape([n])
+        return Tensor(self._js.flatten(start_dim, end_dim))
 
     # ------------------------------------------------------------------
     # Reductions — default (no dim) sums all elements, matching PyTorch
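
flatten now delegates start_dim/end_dim handling to the JS tensor instead of always collapsing to 1-D. A minimal usage sketch of the new behavior, assuming the bridge's Tensor mirrors PyTorch's flatten semantics and that torch.zeros and .shape behave as in PyTorch:

    t = torch.zeros(2, 3, 4)
    t.flatten().shape        # [24]    — full flatten, the old behavior
    t.flatten(1).shape       # [2, 12] — partial flatten, newly supported
    t.flatten(0, 1).shape    # [6, 4]  — explicit start_dim/end_dim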
@@ -387,7 +386,7 @@ def __setattr__(self, name, value):
             object.__setattr__(self, name, value)
             return
 
-        if isinstance(value, Tensor) and value.requires_grad:
+        if isinstance(value, Parameter):
             params[name] = value
         elif isinstance(value, (Module, _NNModule)):
             modules[name] = value
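
Attribute registration now keys off the Parameter type rather than requires_grad, matching PyTorch: a plain tensor with requires_grad=True assigned to a module attribute is no longer collected into _parameters. A sketch under the assumption that the bridge exposes nn.Parameter, nn.Linear, and torch.tensor like PyTorch:

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.w = nn.Parameter(torch.tensor([0.0]))        # goes into _parameters
            self.s = torch.tensor([1.0], requires_grad=True)  # previously registered, now a plain attribute
            self.fc = nn.Linear(3, 3)                         # submodule, goes into _modules

    list(Tiny().parameters())  # w plus fc's weight/bias; s is excluded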
@@ -674,17 +673,15 @@ def is_grad_enabled(self):
         return bool(js_torch.is_grad_enabled())
 
     def cat(self, tensors, dim=0):
-        """Concatenate tensors along dim. NOTE: gradient is not tracked."""
-        if dim != 0:
-            raise NotImplementedError("torch.cat only supports dim=0 in this bridge")
-        result = []
-        for t in tensors:
-            data = t.tolist()
-            if isinstance(data, list):
-                result.extend(data)
-            else:
-                result.append(data)
-        return Tensor(result)
+        if isinstance(tensors, Tensor):
+            tensors = [tensors]
+        return Tensor(js_torch.cat(to_js([t._js for t in tensors]), dim))
+
+    def concatenate(self, tensors, dim=0):
+        return self.cat(tensors, dim)
+
+    def concat(self, tensors, dim=0):
+        return self.cat(tensors, dim)
 
     def Size(self, shape):
         return list(shape)
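
cat now forwards to js_torch.cat, so any dim is supported and gradients flow through the JS op instead of being dropped by the old tolist() round-trip; concatenate and concat are the NumPy- and PyTorch-style aliases. A usage sketch, assuming torch.ones exists in the bridge:

    a = torch.ones(2, 3)
    b = torch.ones(2, 3)
    torch.cat([a, b]).shape         # [4, 3] — dim=0, as before
    torch.cat([a, b], dim=1).shape  # [2, 6] — previously raised NotImplementedError
    torch.concat([a, b], dim=1)     # alias of cat, same result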