11import itertools
22import operator
3- from collections .abc import Callable
3+ from collections .abc import Callable , Sequence
44from functools import wraps
55from itertools import accumulate
66from math import prod
@@ -14,7 +14,7 @@ class SparseLatticedTensor(Tensor):
1414 _HANDLED_FUNCTIONS = dict [Callable , Callable ]()
1515
1616 @staticmethod
17- def __new__ (cls , physical : Tensor , basis : Tensor ):
17+ def __new__ (cls , physical : Tensor , basis : Tensor , offset : Sequence [ int ], shape : Sequence [ int ] ):
1818 assert basis .dtype == torch .int64
1919
2020 # Note [Passing requires_grad=true tensors to subclasses]
@@ -27,13 +27,13 @@ def __new__(cls, physical: Tensor, basis: Tensor):
2727 # (which is bad!)
2828 assert not physical .requires_grad or not torch .is_grad_enabled ()
2929
30- pshape = tensor (physical .shape , dtype = torch .int64 )
31- vshape = basis @ (pshape - 1 ) + 1
3230 return Tensor ._make_wrapper_subclass (
33- cls , tuple ( vshape . tolist ()) , dtype = physical .dtype , device = physical .device
31+ cls , shape , dtype = physical .dtype , device = physical .device
3432 )
3533
36- def __init__ (self , physical : Tensor , basis : Tensor ):
34+ def __init__ (
35+ self , physical : Tensor , basis : Tensor , offset : Sequence [int ], shape : Sequence [int ]
36+ ):
3737 """
3838 This constructor is made for specifying physical and basis exactly. It should not modify
3939 it.
0 commit comments