99 StructuredSparseTensor ,
1010 encode_v_to_ps ,
1111 fix_dim_encoding ,
12+ impl ,
1213 print_fallback ,
1314 to_most_efficient_tensor ,
1415 unwrap_to_dense ,
1516)
1617
1718
18- @StructuredSparseTensor . implements (aten .view .default )
19+ @impl (aten .view .default )
1920def view_default (t : StructuredSparseTensor , shape : list [int ]) -> Tensor :
2021 assert isinstance (t , StructuredSparseTensor )
2122
@@ -120,14 +121,14 @@ def new_encoding_fn(d: int) -> list[int]:
120121 return new_physical , new_encoding
121122
122123
123- @StructuredSparseTensor . implements (aten ._unsafe_view .default )
124+ @impl (aten ._unsafe_view .default )
124125def _unsafe_view_default (t : StructuredSparseTensor , shape : list [int ]) -> Tensor :
125126 return view_default (
126127 t , shape
127128 ) # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp
128129
129130
130- @StructuredSparseTensor . implements (aten .unsqueeze .default )
131+ @impl (aten .unsqueeze .default )
131132def unsqueeze_default (t : StructuredSparseTensor , dim : int ) -> StructuredSparseTensor :
132133 assert isinstance (t , StructuredSparseTensor )
133134 assert - t .ndim - 1 <= dim < t .ndim + 1
@@ -141,7 +142,7 @@ def unsqueeze_default(t: StructuredSparseTensor, dim: int) -> StructuredSparseTe
141142 return StructuredSparseTensor (t .physical , new_v_to_ps )
142143
143144
144- @StructuredSparseTensor . implements (aten .squeeze .dims )
145+ @impl (aten .squeeze .dims )
145146def squeeze_dims (t : StructuredSparseTensor , dims : list [int ] | int | None ) -> Tensor :
146147 assert isinstance (t , StructuredSparseTensor )
147148
@@ -157,15 +158,15 @@ def squeeze_dims(t: StructuredSparseTensor, dims: list[int] | int | None) -> Ten
157158 return to_most_efficient_tensor (t .physical , new_v_to_ps )
158159
159160
160- @StructuredSparseTensor . implements (aten .permute .default )
161+ @impl (aten .permute .default )
161162def permute_default (t : StructuredSparseTensor , dims : list [int ]) -> StructuredSparseTensor :
162163 new_v_to_ps = [t .v_to_ps [d ] for d in dims ]
163164
164165 new_physical , new_v_to_ps = fix_dim_encoding (t .physical , new_v_to_ps )
165166 return StructuredSparseTensor (new_physical , new_v_to_ps )
166167
167168
168- @StructuredSparseTensor . implements (aten .cat .default )
169+ @impl (aten .cat .default )
169170def cat_default (tensors : list [Tensor ], dim : int ) -> Tensor :
170171 if any (not isinstance (t , StructuredSparseTensor ) for t in tensors ):
171172 print_fallback (aten .cat .default , (tensors , dim ), {})
@@ -217,7 +218,7 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
217218 return StructuredSparseTensor (new_physical , new_v_to_ps )
218219
219220
220- @StructuredSparseTensor . implements (aten .expand .default )
221+ @impl (aten .expand .default )
221222def expand_default (t : StructuredSparseTensor , sizes : list [int ]) -> StructuredSparseTensor :
222223 # note that sizes could also be just an int, or a torch.Size i think
223224 assert isinstance (t , StructuredSparseTensor )
@@ -252,7 +253,7 @@ def expand_default(t: StructuredSparseTensor, sizes: list[int]) -> StructuredSpa
252253 return StructuredSparseTensor (new_physical , new_v_to_ps )
253254
254255
255- @StructuredSparseTensor . implements (aten .broadcast_tensors .default )
256+ @impl (aten .broadcast_tensors .default )
256257def broadcast_tensors_default (tensors : list [Tensor ]) -> tuple [Tensor , Tensor ]:
257258 if len (tensors ) != 2 :
258259 raise NotImplementedError ()
@@ -279,7 +280,7 @@ def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]:
279280 return aten .expand .default (t1 , new_shape ), aten .expand .default (t2 , new_shape )
280281
281282
282- @StructuredSparseTensor . implements (aten .slice .Tensor )
283+ @impl (aten .slice .Tensor )
283284def slice_Tensor (
284285 t : StructuredSparseTensor , dim : int , start : int | None , end : int | None , step : int = 1
285286) -> StructuredSparseTensor :
@@ -315,7 +316,7 @@ def slice_Tensor(
315316 return StructuredSparseTensor (new_physical , t .v_to_ps )
316317
317318
318- @StructuredSparseTensor . implements (aten .transpose .int )
319+ @impl (aten .transpose .int )
319320def transpose_int (t : StructuredSparseTensor , dim0 : int , dim1 : int ) -> StructuredSparseTensor :
320321 assert isinstance (t , StructuredSparseTensor )
321322
0 commit comments