Skip to content

Commit dc5740b

Browse files
committed
Make hnf_decomposition return the reduced HNF rather than the HNF.
1 parent 3e9e7d4 commit dc5740b

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/torchjd/sparse/_sparse_latticed_tensor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import itertools
22
import operator
3-
from collections.abc import Callable
3+
from collections.abc import Callable, Sequence
44
from functools import wraps
55
from itertools import accumulate
66
from math import prod
@@ -14,7 +14,7 @@ class SparseLatticedTensor(Tensor):
1414
_HANDLED_FUNCTIONS = dict[Callable, Callable]()
1515

1616
@staticmethod
17-
def __new__(cls, physical: Tensor, basis: Tensor):
17+
def __new__(cls, physical: Tensor, basis: Tensor, offset: Sequence[int], shape: Sequence[int]):
1818
assert basis.dtype == torch.int64
1919

2020
# Note [Passing requires_grad=true tensors to subclasses]
@@ -27,13 +27,13 @@ def __new__(cls, physical: Tensor, basis: Tensor):
2727
# (which is bad!)
2828
assert not physical.requires_grad or not torch.is_grad_enabled()
2929

30-
pshape = tensor(physical.shape, dtype=torch.int64)
31-
vshape = basis @ (pshape - 1) + 1
3230
return Tensor._make_wrapper_subclass(
33-
cls, tuple(vshape.tolist()), dtype=physical.dtype, device=physical.device
31+
cls, shape, dtype=physical.dtype, device=physical.device
3432
)
3533

36-
def __init__(self, physical: Tensor, basis: Tensor):
34+
def __init__(
35+
self, physical: Tensor, basis: Tensor, offset: Sequence[int], shape: Sequence[int]
36+
):
3737
"""
3838
This constructor is made for specifying physical and basis exactly. It should not modify
3939
it.

0 commit comments

Comments
 (0)