Commit 6ada564

Finish switching to offset and shape

1 parent 799c88f · commit 6ada564
5 files changed · 111 additions & 84 deletions

src/torchjd/sparse/_aten_function_overrides/backward.py

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@ def threshold_backward_default(
     new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold)
 
     return SparseLatticedTensor(
-        new_physical, grad_output.basis, grad_output.offset, grad_output.size
+        new_physical, grad_output.basis, grad_output.offset, grad_output.shape_t
     )
 
 
@@ -27,7 +27,7 @@ def hardtanh_backward_default(
 
     new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val)
     return SparseLatticedTensor(
-        new_physical, grad_output.basis, grad_output.offset, grad_output.size
+        new_physical, grad_output.basis, grad_output.offset, grad_output.shape_t
     )
 
 
@@ -38,5 +38,5 @@ def hardswish_backward_default(grad_output: SparseLatticedTensor, self: Tensor):
 
     new_physical = aten.hardswish_backward.default(grad_output.physical, self)
     return SparseLatticedTensor(
-        new_physical, grad_output.basis, grad_output.offset, grad_output.size
+        new_physical, grad_output.basis, grad_output.offset, grad_output.shape_t
    )
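
All three overrides share one pattern: an elementwise backward map that sends 0 to 0 only needs to touch the stored physical values, so the lattice metadata (basis, offset, and the renamed shape_t) is reused unchanged. A minimal sketch of the pattern, assuming SparseLatticedTensor is importable from torchjd.sparse._sparse_latticed_tensor as in this tree, with torch.relu standing in for the aten backward kernel and illustrative values:

import torch

from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor

# physical [1., -2., 3.] with basis [[2]] represents the dense vector [1, 0, -2, 0, 3]
grad_output = SparseLatticedTensor(
    torch.tensor([1.0, -2.0, 3.0]),
    torch.tensor([[2]]),
    torch.zeros(1, dtype=torch.int64),  # offset
    torch.tensor([5]),                  # shape
)
new_physical = torch.relu(grad_output.physical)  # stand-in for the aten kernel
result = SparseLatticedTensor(
    new_physical, grad_output.basis, grad_output.offset, grad_output.shape_t
)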

src/torchjd/sparse/_aten_function_overrides/einsum.py

Lines changed: 4 additions & 4 deletions
@@ -165,7 +165,7 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
 @impl(aten.div.Tensor)
 def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
     t1_, t2_ = prepare_for_elementwise_op(t1, t2)
-    t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.basis, t2_.offset, t2_.size)
+    t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.basis, t2_.offset, t2_.shape_t)
     all_dims = list(range(t1_.ndim))
     return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims)
 
@@ -177,7 +177,7 @@ def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor:
 
     assert isinstance(t, SparseLatticedTensor)
     new_physical = aten.mul.Scalar(t.physical, scalar)
-    return SparseLatticedTensor(new_physical, t.basis, t.offset, t.size)
+    return SparseLatticedTensor(new_physical, t.basis, t.offset, t.shape_t)
 
 
 @impl(aten.add.Tensor)
@@ -189,10 +189,10 @@ def add_Tensor(
     if (
         torch.equal(t1_.basis, t2_.basis)
         and torch.equal(t1_.offset, t2_.offset)
-        and torch.equal(t1_.size, t2_.size)
+        and torch.equal(t1_.shape_t, t2_.shape_t)
     ):
         new_physical = t1_.physical + t2_.physical * alpha
-        return SparseLatticedTensor(new_physical, t1_.basis, t1_.offset, t1_.size)
+        return SparseLatticedTensor(new_physical, t1_.basis, t1_.offset, t1_.shape_t)
     else:
         raise NotImplementedError()
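
The fast path in add_Tensor only fires when both operands live on exactly the same lattice (equal basis, offset and shape_t); anything else raises NotImplementedError. A minimal sketch of that case, with illustrative values and the same import assumption as above:

import torch

from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor

a = SparseLatticedTensor(
    torch.tensor([1.0, 2.0, 3.0]), torch.tensor([[2]]),
    torch.zeros(1, dtype=torch.int64), torch.tensor([5]),
)  # dense view: [1, 0, 2, 0, 3]
b = SparseLatticedTensor(
    torch.tensor([10.0, 20.0, 30.0]), torch.tensor([[2]]),
    torch.zeros(1, dtype=torch.int64), torch.tensor([5]),
)  # dense view: [10, 0, 20, 0, 30]
c = a + b  # physicals add; basis, offset and shape_t are reused
assert torch.equal(c.to_dense(), torch.tensor([11.0, 0.0, 22.0, 0.0, 33.0]))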

src/torchjd/sparse/_aten_function_overrides/shape.py

Lines changed: 13 additions & 55 deletions
@@ -1,7 +1,6 @@
 import operator
 from itertools import accumulate
 from math import prod
-from typing import cast
 
 import torch
 from torch import Tensor, arange, cat, tensor
@@ -41,6 +40,9 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor:
 
     assert isinstance(t, SparseLatticedTensor)
 
+    if not torch.equal(t.padding, torch.zeros_like(t.padding)):
+        raise NotImplementedError()
+
     shape = infer_shape(shape, t.numel())
 
     if prod(shape) != t.numel():
@@ -51,7 +53,9 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor:
     c = _reverse_cumulative_product(vshape)
     c_prime = _reverse_cumulative_product(shape)
     new_basis = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1)
-    return to_most_efficient_tensor(t.physical, new_basis)
+
+    new_offset = torch.zeros(len(shape), dtype=torch.int64)
+    return to_most_efficient_tensor(t.physical, new_basis, new_offset, shape)
 
 
 def _reverse_cumulative_product(values: list[int]) -> Tensor:
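
_reverse_cumulative_product(shape) returns the row-major strides of shape: for [a, b, c] it is [b*c, c, 1], so the flat index of (i, j, k) is i*b*c + j*c + k. view uses these strides (c for the old shape, c_prime for the new one) to re-derive each basis column through the flat index. A small illustration of the stride arithmetic only, written with plain lists rather than the helper's int tensors:

import operator
from itertools import accumulate

def reverse_cumulative_product(values: list[int]) -> list[int]:
    # [2, 3, 4] -> [12, 4, 1]
    return list(reversed(list(accumulate(reversed(values[1:] + [1]), operator.mul))))

strides = reverse_cumulative_product([2, 3, 4])
assert strides == [12, 4, 1]
# Flat index of element (1, 2, 3) in a row-major [2, 3, 4] tensor:
assert 1 * strides[0] + 2 * strides[1] + 3 * strides[2] == 23
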
@@ -87,7 +91,7 @@ def unsqueeze_default(t: SparseLatticedTensor, dim: int) -> SparseLatticedTensor:
     pdims = t.basis.shape[1]
     new_basis = cat([t.basis[:dim], torch.zeros(1, pdims, dtype=torch.int64), t.basis[dim:]])
     new_offset = cat([t.offset[:dim], torch.zeros(1, dtype=torch.int64), t.offset[dim:]])
-    new_size = cat([t.size[:dim], torch.zeros(1, dtype=torch.int64), t.size[dim:]])
+    new_size = cat([t.shape_t[:dim], torch.ones(1, dtype=torch.int64), t.shape_t[dim:]])
     return SparseLatticedTensor(t.physical, new_basis, new_offset, new_size)
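
This hunk also fixes a bug: the entry inserted into shape_t must be a 1 (torch.ones), not a 0, since the new virtual dimension has size one. A quick check of the intended behaviour, hedged on the constructor and dispatch shown in this commit:

import torch

from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor

t = SparseLatticedTensor(
    torch.tensor([1.0, 2.0, 3.0]), torch.tensor([[2]]),
    torch.zeros(1, dtype=torch.int64), torch.tensor([5]),
)  # dense view: [1, 0, 2, 0, 3], shape [5]
u = t.unsqueeze(0)  # dispatches to unsqueeze_default
assert list(u.shape) == [1, 5]  # new dim has size 1, matching the zero basis row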

@@ -106,15 +110,15 @@ def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tensor:
     is_row_kept = [i not in excluded for i in range(t.ndim)]
     new_basis = t.basis[is_row_kept]
     new_offset = t.offset[is_row_kept]
-    new_size = t.size[is_row_kept]
+    new_size = t.shape_t[is_row_kept]
     return to_most_efficient_tensor(t.physical, new_basis, new_offset, new_size)
 
 
 @impl(aten.permute.default)
 def permute_default(t: SparseLatticedTensor, dims: list[int]) -> SparseLatticedTensor:
     new_basis = t.basis[dims]
     new_offset = t.offset[dims]
-    new_size = t.size[dims]
+    new_size = t.shape_t[dims]
     return SparseLatticedTensor(t.physical, new_basis, new_offset, new_size)
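
squeeze and permute never touch the physical: they drop or reorder the per-dimension metadata, one basis row, one offset entry and one shape_t entry per virtual dimension. An illustrative sketch under the same assumptions as above:

import torch

from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor

t = SparseLatticedTensor(
    torch.ones(2, 3), torch.tensor([[3, 0], [0, 1]]),
    torch.zeros(2, dtype=torch.int64), torch.tensor([4, 3]),
)  # virtual shape [4, 3]; rows 0 and 3 hold data, rows 1 and 2 are zero
p = t.permute(1, 0)  # basis rows, offset and shape_t entries move together
assert list(p.shape) == [3, 4]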

@@ -124,56 +128,10 @@ def cat_default(tensors: list[Tensor], dim: int = 0) -> Tensor:
         print_fallback(aten.cat.default, (tensors, dim), {})
         return aten.cat.default([unwrap_to_dense(t) for t in tensors])
 
-    tensors_ = [cast(SparseLatticedTensor, t) for t in tensors]
-    ref_tensor = tensors_[0]
-    ref_basis = ref_tensor.basis
-    if any(not torch.equal(t.basis, ref_basis) for t in tensors_[1:]):
-        raise NotImplementedError(
-            "Override for aten.cat.default does not support SLTs that do not all have the same "
-            f"basis. Found the following tensors:\n{[repr(t) for t in tensors_]} and the following "
-            f"dim: {dim}."
-        )
-    if any(t.physical.shape != ref_tensor.physical.shape for t in tensors_[1:]):
-        # This can happen in the following example:
-        # t1 = SLT([1 2 3], [[2]])
-        # t2 = SLT([4 5 6 7], [[2]])
-        # The expected result would be 1 0 2 0 3 4 0 5 0 6 0 7, but this is not representable
-        # efficiently as an SLT (because there is no 0 between 3 and 4, and both physicals have a
-        # different shape so we can't just stack them).
-
-        # TODO: Maybe a partial densify is possible rather than a full densify.
-        print_fallback(aten.cat.default, (tensors, dim), {})
-        return aten.cat.default([unwrap_to_dense(t) for t in tensors])
-
-    # We need to try to find the (pretty sure it either does not exist or is unique) physical
-    # dimension that makes us only move on virtual dimension dim. It also needs to be such that
-    # traversing it entirely brings us exactly to the end of virtual dimension dim.
-
-    ref_virtual_dim_size = ref_tensor.shape[dim]
-    indices = torch.argwhere(
-        torch.eq(ref_basis[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
-        & torch.eq(ref_basis.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
-    )
-    assert len(indices) <= 1
-
-    if len(indices) == 0:
-        # Add a physical dimension pdim on which we can concatenate the physicals such that this
-        # translates into a concatenation of the virtuals on virtual dimension dim.
-
-        pdim = ref_tensor.physical.ndim
-        physicals = [t.physical.unsqueeze(-1) for t in tensors_]
-        new_basis_vector = torch.zeros(ref_tensor.ndim, 1, dtype=torch.int64)
-        new_basis_vector[dim, 0] = ref_virtual_dim_size
-        new_basis = torch.concatenate([ref_tensor.basis, new_basis_vector], dim=1)
-    else:
-        # Such a physical dimension already exists. Note that an alternative implementation would be
-        # to simply always add the physical dimension, and squash it if it ends up being not needed.
-        physicals = [t.physical for t in tensors_]
-        pdim = cast(int, indices[0, 0].item())
-        new_basis = ref_tensor.basis
+    print_fallback(aten.cat.default, (tensors, dim), {})
+    return aten.cat.default([unwrap_to_dense(t) for t in tensors])
 
-    new_physical = aten.cat.default(physicals, dim=pdim)
-    return SparseLatticedTensor(new_physical, new_basis)
+    # TODO: add implementation based on adding some margin to tensors and summing them
 
 
 @impl(aten.expand.default)
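
The deleted fast path is replaced by an unconditional densifying fallback; the old inline comment explains why no cheap general answer exists. Reusing its example: [1, 0, 2, 0, 3] concatenated with [4, 0, 5, 0, 6, 0, 7] is [1, 0, 2, 0, 3, 4, 0, 5, 0, 6, 0, 7], and the stride-2 pattern breaks at the seam (no zero between 3 and 4), so the result fits no single lattice. A sketch of what the fallback now computes, assuming the override is registered:

import torch

from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor

t1 = SparseLatticedTensor(
    torch.tensor([1.0, 2.0, 3.0]), torch.tensor([[2]]),
    torch.zeros(1, dtype=torch.int64), torch.tensor([5]),
)  # dense view: [1, 0, 2, 0, 3]
t2 = SparseLatticedTensor(
    torch.tensor([4.0, 5.0, 6.0, 7.0]), torch.tensor([[2]]),
    torch.zeros(1, dtype=torch.int64), torch.tensor([7]),
)  # dense view: [4, 0, 5, 0, 6, 0, 7]
out = torch.cat([t1, t2])  # fallback: both operands are densified first
assert torch.equal(out, torch.cat([t1.to_dense(), t2.to_dense()]))
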
@@ -190,7 +148,7 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedTensor:
     # Try to expand each dimension to its new size
     new_physical = t.physical
     new_basis = t.basis
-    new_sizes = t.size
+    new_sizes = t.shape_t
     for d, (v, orig_size, new_size) in enumerate(zip(t.basis, t.shape, sizes, strict=True)):
         if v.sum() > 0 and orig_size != new_size and new_size != -1:
             raise ValueError(

src/torchjd/sparse/_sparse_latticed_tensor.py

Lines changed: 76 additions & 10 deletions
@@ -17,7 +17,7 @@ def __new__(
         physical: Tensor,
         basis: Tensor,
         offset: Tensor | None = None,
-        size: list[int] | tuple[int, ...] | torch.Size | Tensor | None = None,
+        shape: list[int] | tuple[int, ...] | torch.Size | Tensor | None = None,
     ):
         assert basis.dtype == torch.int64

@@ -31,20 +31,20 @@ def __new__(
         # (which is bad!)
         assert not physical.requires_grad or not torch.is_grad_enabled()
 
-        if size is None:
+        if shape is None:
             pshape = tensor(physical.shape, dtype=torch.int64)
-            size = basis @ (pshape - 1) + 1
+            shape = basis @ (pshape - 1) + 1
 
         return Tensor._make_wrapper_subclass(
-            cls, list(size), dtype=physical.dtype, device=physical.device
+            cls, list(shape), dtype=physical.dtype, device=physical.device
         )
 
     def __init__(
         self,
         physical: Tensor,
         basis: Tensor,
         offset: Tensor | None,
-        size: list[int] | tuple[int, ...] | torch.Size | Tensor | None,
+        shape: list[int] | tuple[int, ...] | torch.Size | Tensor | None,
     ):
         """
         This constructor is made for specifying physical and basis exactly. It should not modify
@@ -58,15 +58,15 @@ def __init__(
             the linear transformation between an index in the physical tensor and the corresponding
             index in the virtual tensor, i.e. v_index = basis @ p_index + offset.
         :param offset: Offset for the virtual index, i.e. v_index = basis @ p_index + offset.
-        :param size: Size of the sparse tensor. If not provided, the size will be inferred as the
+        :param shape: Shape of the sparse tensor. If not provided, the shape will be inferred as the
             minimum size big enough to hold all non-zero elements.
 
         # TODO: make a nicer interface where it's possible to provide lists or sizes instead of
         always having to provide int tensors
         """
 
         if offset is None:
-            offset = torch.zeros(len(self.shape))
+            offset = torch.zeros(len(self.shape), dtype=torch.int64)
 
         if any(s == 1 for s in physical.shape):
             raise ValueError(
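
A worked example of the mapping in the docstring, v_index = basis @ p_index + offset: with physical shape [3], basis [[2]] and offset [1], the stored values land at virtual indices 1, 3 and 5 (values illustrative):

import torch

basis = torch.tensor([[2]])
offset = torch.tensor([1])
for p in range(3):
    v = basis @ torch.tensor([p]) + offset
    print(p, "->", v.item())  # 0 -> 1, 1 -> 3, 2 -> 5
# With shape [7], physical [a, b, c] therefore encodes [0, a, 0, b, 0, c, 0]:
# the leading zero comes from the offset, i.e. the start padding.
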
@@ -95,7 +95,16 @@ def __init__(
         self.physical = physical
         self.basis = basis
         self.offset = offset
-        self.size = tensor(size, dtype=torch.int64)
+
+        if shape is None:
+            pshape = tensor(physical.shape, dtype=torch.int64)
+            shape = basis @ (pshape - 1) + 1
+        if isinstance(shape, torch.Tensor):
+            self.shape_t = shape
+        else:
+            self.shape_t = tensor(shape, dtype=torch.int64)
+
+        self.pshape_t = tensor(physical.shape, dtype=torch.int64)
 
     def to_dense(
         self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None
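
When shape is omitted, shape_t defaults to basis @ (pshape - 1) + 1: the virtual position of the largest physical index, plus one. That is the tightest shape, with zero end padding. Numerically:

import torch

basis = torch.tensor([[2]])
pshape = torch.tensor([3])  # the physical holds 3 values
shape = basis @ (pshape - 1) + 1
assert shape.tolist() == [5]  # [a, 0, b, 0, c] needs exactly 5 slots
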
@@ -110,7 +119,9 @@ def to_dense(
         p_indices_grid = stack(meshgrid(*p_index_ranges, indexing="ij"))
 
         # addmm_cuda not implemented for Long tensors => gotta have these tensors on cpu
-        v_indices_grid = tensordot(self.basis, p_indices_grid, dims=1) + self.offset
+        reshaped_offset = self.offset.reshape([-1] + [1] * self.physical.ndim)
+        v_indices_grid = tensordot(self.basis, p_indices_grid, dims=1) + reshaped_offset
+        # v_indices_grid is of shape [n_virtual_dims] + physical_shape
         res = zeros(self.shape, device=self.device, dtype=self.dtype)
         res[tuple(v_indices_grid)] = self.physical
         return res
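
The fix here is about broadcasting: v_indices_grid has shape [n_virtual_dims] + physical_shape while offset has shape [n_virtual_dims], so the offset must be reshaped to [n_virtual_dims, 1, ..., 1] to shift the indices of each virtual dimension, rather than broadcasting along the last physical axis. A standalone illustration:

import torch

# Index grid for a [2, 3] physical, with 2 virtual dims: shape [2, 2, 3].
grid = torch.stack(torch.meshgrid(torch.arange(2), torch.arange(3), indexing="ij"))
offset = torch.tensor([10, 100])
reshaped_offset = offset.reshape([-1] + [1] * 2)  # shape [2, 1, 1]
shifted = grid + reshaped_offset  # dim-0 indices shift by 10, dim-1 by 100
assert shifted[0].min().item() == 10 and shifted[1].min().item() == 100
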
@@ -128,7 +139,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             return func(*unwrapped_args, **unwrapped_kwargs)
 
     def __repr__(self, *, tensor_contents=None) -> str:
-        return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis}, offset={self.offset}, size={self.size})"
+        return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis}, offset={self.offset}, size={self.shape_t})"
 
     @classmethod
     def implements(cls, torch_function):
@@ -141,6 +152,61 @@ def decorator(func):
 
         return decorator
 
+    @property
+    def start_padding(self) -> Tensor:
+        """
+        Returns the number of zeros of padding at the start of each virtual dimension.
+
+        The result is an int tensor of shape [virtual_ndim].
+        """
+
+        return self.offset
+
+    @property
+    def end_padding(self) -> Tensor:
+        """
+        Returns the number of zeros of padding at the end of each virtual dimension.
+
+        The result is an int tensor of shape [virtual_ndim].
+        """
+
+        return self.shape_t - self.physical_image_size - self.offset
+
+    @property
+    def padding(self) -> Tensor:
+        """
+        Returns the number of zeros of padding at the start and end of each virtual dimension.
+
+        The result is an int tensor of shape [virtual_ndim, 2].
+        """
+
+        return torch.stack([self.start_padding, self.end_padding], dim=1)
+
+    @property
+    def min_natural_virtual_indices(self) -> Tensor:
+        # Basis where each positive element is replaced by 0
+        non_positive_basis = torch.min(self.basis, torch.zeros_like(self.basis))
+        max_physical_index = self.pshape_t - 1
+        return (non_positive_basis * max_physical_index.unsqueeze(0)).sum(dim=1)
+
+    @property
+    def max_natural_virtual_indices(self) -> Tensor:
+        # Basis where each negative element is replaced by 0
+        non_negative = torch.max(self.basis, torch.zeros_like(self.basis))
+        max_physical_index = self.pshape_t - 1
+        return (non_negative * max_physical_index.unsqueeze(0)).sum(dim=1)
+
+    @property
+    def physical_image_size(self) -> Tensor:
+        """
+        Returns the shape of the image of the physical through the basis transform.
+
+        The result is an int tensor of shape [virtual_ndim].
+        """
+
+        one = torch.ones(self.ndim, dtype=torch.int64)
+        return self.max_natural_virtual_indices - self.min_natural_virtual_indices + one
 
 
 impl = SparseLatticedTensor.implements
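
A worked example tying the new properties together, on a 1-D SLT with physical [a, b, c], basis [[2]], offset [1] and shape [8], whose dense view is [0, a, 0, b, 0, c, 0, 0] (values illustrative, same import assumption as earlier):

import torch

from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor

t = SparseLatticedTensor(
    torch.tensor([1.0, 2.0, 3.0]), torch.tensor([[2]]),
    torch.tensor([1]), torch.tensor([8]),
)
# Natural virtual indices span 0 to 2 * (3 - 1) = 4, so the image covers 5 slots.
assert t.physical_image_size.tolist() == [5]
assert t.start_padding.tolist() == [1]  # the offset itself
assert t.end_padding.tolist() == [2]    # 8 - 5 - 1
assert t.padding.tolist() == [[1, 2]]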
