@@ -1,7 +1,6 @@
 import operator
 from itertools import accumulate
 from math import prod
-from typing import cast
 
 import torch
 from torch import Tensor, arange, cat, tensor
@@ -41,6 +40,9 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor:
 
     assert isinstance(t, SparseLatticedTensor)
 
+    if not torch.equal(t.padding, torch.zeros_like(t.padding)):
+        raise NotImplementedError()
+
     shape = infer_shape(shape, t.numel())
 
     if prod(shape) != t.numel():
@@ -51,7 +53,9 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor:
     c = _reverse_cumulative_product(vshape)
     c_prime = _reverse_cumulative_product(shape)
     new_basis = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1)
-    return to_most_efficient_tensor(t.physical, new_basis)
+
+    new_offset = torch.zeros(len(shape), dtype=torch.int64)
+    return to_most_efficient_tensor(t.physical, new_basis, new_offset, shape)
 
 
 def _reverse_cumulative_product(values: list[int]) -> Tensor:
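The body of _reverse_cumulative_product falls outside this hunk. Judging by its name, the operator/accumulate imports kept above, and the "// c_prime ... % shape" arithmetic in view_default, it presumably returns the contiguous strides of a shape. A minimal sketch under that assumption (not necessarily the file's actual implementation):

import operator
from itertools import accumulate

import torch
from torch import Tensor


def _reverse_cumulative_product(values: list[int]) -> Tensor:
    # [2, 3, 4] -> [12, 4, 1]: the contiguous strides of that shape, so that
    # (flat_index // result[d]) % values[d] recovers the coordinate along d.
    rev = list(accumulate(reversed(values[1:] + [1]), operator.mul))
    return torch.tensor(rev[::-1], dtype=torch.int64)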
@@ -87,7 +91,7 @@ def unsqueeze_default(t: SparseLatticedTensor, dim: int) -> SparseLatticedTensor
     pdims = t.basis.shape[1]
     new_basis = cat([t.basis[:dim], torch.zeros(1, pdims, dtype=torch.int64), t.basis[dim:]])
     new_offset = cat([t.offset[:dim], torch.zeros(1, dtype=torch.int64), t.offset[dim:]])
-    new_size = cat([t.size[:dim], torch.zeros(1, dtype=torch.int64), t.size[dim:]])
+    new_size = cat([t.shape_t[:dim], torch.ones(1, dtype=torch.int64), t.shape_t[dim:]])
     return SparseLatticedTensor(t.physical, new_basis, new_offset, new_size)
 
 
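The zeros-to-ones change matches dense unsqueeze semantics: the inserted dimension has size one, not zero, so the size vector must gain a 1. For reference:

import torch

t = torch.zeros(2, 3)
# The inserted dimension has size 1, which the new size vector mirrors.
assert t.unsqueeze(1).shape == (2, 1, 3)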
@@ -106,15 +110,15 @@ def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tenso
     is_row_kept = [i not in excluded for i in range(t.ndim)]
     new_basis = t.basis[is_row_kept]
     new_offset = t.offset[is_row_kept]
-    new_size = t.size[is_row_kept]
+    new_size = t.shape_t[is_row_kept]
     return to_most_efficient_tensor(t.physical, new_basis, new_offset, new_size)
 
 
 @impl(aten.permute.default)
 def permute_default(t: SparseLatticedTensor, dims: list[int]) -> SparseLatticedTensor:
     new_basis = t.basis[dims]
     new_offset = t.offset[dims]
-    new_size = t.size[dims]
+    new_size = t.shape_t[dims]
     return SparseLatticedTensor(t.physical, new_basis, new_offset, new_size)
 
 
@@ -124,56 +128,10 @@ def cat_default(tensors: list[Tensor], dim: int = 0) -> Tensor:
         print_fallback(aten.cat.default, (tensors, dim), {})
         return aten.cat.default([unwrap_to_dense(t) for t in tensors])
 
-    tensors_ = [cast(SparseLatticedTensor, t) for t in tensors]
-    ref_tensor = tensors_[0]
-    ref_basis = ref_tensor.basis
-    if any(not torch.equal(t.basis, ref_basis) for t in tensors_[1:]):
-        raise NotImplementedError(
-            "Override for aten.cat.default does not support SLTs that do not all have the same "
-            f"basis. Found the following tensors:\n{[repr(t) for t in tensors_]} and the following "
-            f"dim: {dim}."
-        )
-    if any(t.physical.shape != ref_tensor.physical.shape for t in tensors_[1:]):
-        # This can happen in the following example:
-        # t1 = SLT([1 2 3], [[2]])
-        # t2 = SLT([4 5 6 7], [[2]])
-        # The expected result would be 1 0 2 0 3 4 0 5 0 6 0 7, but this is not representable
-        # efficiently as an SLT (because there is no 0 between 3 and 4, and both physicals have a
-        # different shape so we can't just stack them).
-
-        # TODO: Maybe a partial densify is possible rather than a full densify.
-        print_fallback(aten.cat.default, (tensors, dim), {})
-        return aten.cat.default([unwrap_to_dense(t) for t in tensors])
-
-    # We need to try to find the (pretty sure it either does not exist or is unique) physical
-    # dimension that makes us only move on virtual dimension dim. It also needs to be such that
-    # traversing it entirely brings us exactly to the end of virtual dimension dim.
-
-    ref_virtual_dim_size = ref_tensor.shape[dim]
-    indices = torch.argwhere(
-        torch.eq(ref_basis[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
-        & torch.eq(ref_basis.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
-    )
-    assert len(indices) <= 1
-
-    if len(indices) == 0:
-        # Add a physical dimension pdim on which we can concatenate the physicals such that this
-        # translates into a concatenation of the virtuals on virtual dimension dim.
-
-        pdim = ref_tensor.physical.ndim
-        physicals = [t.physical.unsqueeze(-1) for t in tensors_]
-        new_basis_vector = torch.zeros(ref_tensor.ndim, 1, dtype=torch.int64)
-        new_basis_vector[dim, 0] = ref_virtual_dim_size
-        new_basis = torch.concatenate([ref_tensor.basis, new_basis_vector], dim=1)
-    else:
-        # Such a physical dimension already exists. Note that an alternative implementation would be
-        # to simply always add the physical dimension, and squash it if it ends up being not needed.
-        physicals = [t.physical for t in tensors_]
-        pdim = cast(int, indices[0, 0].item())
-        new_basis = ref_tensor.basis
+    print_fallback(aten.cat.default, (tensors, dim), {})
+    return aten.cat.default([unwrap_to_dense(t) for t in tensors])
 
-    new_physical = aten.cat.default(physicals, dim=pdim)
-    return SparseLatticedTensor(new_physical, new_basis)
+    # TODO: add implementation based on adding some margin to tensors and summing them
 
 
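The TODO's margin-and-sum idea can be shown at the dense level: zero-pad each operand to the full output extent along the concatenation dimension, shift later operands past earlier ones, and add. A hypothetical sketch of that identity (for SLTs, the zero margin could presumably be expressed through the offset/size metadata rather than materialized):

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0])
out_len = a.numel() + b.numel()
# Pad each operand with zeros to the full output length, shifting b past a...
a_pad = torch.cat([a, torch.zeros(out_len - a.numel())])
b_pad = torch.cat([torch.zeros(out_len - b.numel()), b])
# ...so that summing the padded operands reproduces the concatenation.
assert torch.equal(a_pad + b_pad, torch.cat([a, b]))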
179137@impl (aten .expand .default )
@@ -190,7 +148,7 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedT
     # Try to expand each dimension to its new size
     new_physical = t.physical
     new_basis = t.basis
-    new_sizes = t.size
+    new_sizes = t.shape_t
     for d, (v, orig_size, new_size) in enumerate(zip(t.basis, t.shape, sizes, strict=True)):
         if v.sum() > 0 and orig_size != new_size and new_size != -1:
             raise ValueError(
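For context on the guard: dense expand only broadcasts size-1 dimensions and accepts -1 for "keep this size", which is the behavior the ValueError condition mirrors for dimensions the lattice basis actually touches:

import torch

t = torch.zeros(2, 1)
assert t.expand(2, 3).shape == (2, 3)   # size-1 dims may grow
assert t.expand(-1, 3).shape == (2, 3)  # -1 keeps the original size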