Skip to content

Commit 95f9490

Browse files
PierreQuinton and ValerianRey
authored and committed
Add StructuredSparseTensor
1 parent 131df9a commit 95f9490

File tree

12 files changed

+1434
-13
lines changed

12 files changed

+1434
-13
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Jupyter notebooks
2+
*.ipynb
3+
14
# uv
25
uv.lock
36

src/torchjd/autogram/_engine.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from torch import Tensor, nn, vmap
55
from torch.autograd.graph import get_gradient_edge
66

7+
from torchjd.sparse import make_sst
8+
79
from ._edge_registry import EdgeRegistry
810
from ._gramian_accumulator import GramianAccumulator
911
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
@@ -173,7 +175,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
173175
)
174176

175177
output_dims = list(range(output.ndim))
176-
jac_output = _make_initial_jac_output(output)
178+
identity = torch.eye(output.ndim, dtype=torch.int64)
179+
strides = torch.concatenate([identity, identity], dim=0)
180+
jac_output = make_sst(torch.ones_like(output), strides)
177181

178182
vmapped_diff = differentiation
179183
for _ in output_dims:
@@ -193,15 +197,3 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
193197
gramian_computer.reset()
194198

195199
return gramian
196-
197-
198-
def _make_initial_jac_output(output: Tensor) -> Tensor:
199-
if output.ndim == 0:
200-
return torch.ones_like(output)
201-
p_index_ranges = [torch.arange(s, device=output.device) for s in output.shape]
202-
p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij")
203-
v_indices_grid = p_indices_grid + p_indices_grid
204-
205-
res = torch.zeros(list(output.shape) * 2, device=output.device, dtype=output.dtype)
206-
res[v_indices_grid] = 1.0
207-
return res

src/torchjd/sparse/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Need to import this to execute the code inside and thus to override the functions
2+
from . import _aten_function_overrides
3+
from ._structured_sparse_tensor import StructuredSparseTensor, make_sst
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import backward, einsum, pointwise, shape
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from torch import Tensor
2+
from torch.ops import aten # type: ignore
3+
4+
from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl
5+
6+
7+
@impl(aten.threshold_backward.default)
def threshold_backward_default(
    grad_output: StructuredSparseTensor, self: Tensor, threshold
) -> StructuredSparseTensor:
    """Backward of ``threshold`` for a sparse ``grad_output``.

    The dense op is applied to the physical data only; the sparsity structure
    (``strides``) of ``grad_output`` is left untouched.
    """
    filtered = aten.threshold_backward.default(grad_output.physical, self, threshold)
    return StructuredSparseTensor(filtered, grad_output.strides)
14+
15+
16+
@impl(aten.hardtanh_backward.default)
17+
def hardtanh_backward_default(
18+
grad_output: StructuredSparseTensor,
19+
self: Tensor,
20+
min_val: Tensor | int | float,
21+
max_val: Tensor | int | float,
22+
) -> StructuredSparseTensor:
23+
if isinstance(self, StructuredSparseTensor):
24+
raise NotImplementedError()
25+
26+
new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val)
27+
return StructuredSparseTensor(new_physical, grad_output.strides)
28+
29+
30+
@impl(aten.hardswish_backward.default)
31+
def hardswish_backward_default(grad_output: StructuredSparseTensor, self: Tensor):
32+
if isinstance(self, StructuredSparseTensor):
33+
raise NotImplementedError()
34+
35+
new_physical = aten.hardswish_backward.default(grad_output.physical, self)
36+
return StructuredSparseTensor(new_physical, grad_output.strides)
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
import torch
2+
from torch import Tensor, tensor
3+
from torch.ops import aten # type: ignore
4+
5+
from torchjd.sparse._structured_sparse_tensor import (
6+
StructuredSparseTensor,
7+
impl,
8+
to_most_efficient_tensor,
9+
to_structured_sparse_tensor,
10+
)
11+
12+
13+
def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) -> Tensor:
    """Virtual einsum over StructuredSparseTensors, expressed as a physical einsum.

    Each ``args`` element pairs an SST with the einsum indices of its virtual
    dimensions; ``output`` lists the indices of the result's virtual dimensions.

    NOTE(review): this raises immediately — everything below the ``raise`` is an
    unreachable work-in-progress draft (``p_to_vs`` is still an Ellipsis
    placeholder), kept as a sketch of the intended algorithm.
    """
    raise NotImplementedError()

    # First part of the algorithm, determine how to cluster physical indices as well as the common
    # p_shapes corresponding to matching v_dims. Second part translates to physical einsum.

    # get a map from einsum index to (tensor_idx, v_dims)
    # get a map from einsum index to merge of strides corresponding to v_dims with that index
    # use to_target_physical_strides on each physical and v_to_ps
    # cluster pairs of (einsum_index, new_stride) using new_v_to_ps and possibly its corresponding
    # p_to_vs
    # get unique indices
    # map output indices (there can be splits)
    # call physical einsum
    # build resulting sst

    # OVER

    # an index in the physical einsum is uniquely characterized by a virtual einsum index and a
    # stride corresponding to the physical stride in the virtual one (note that as the virtual shape
    # for two virtual index that match should match, then we want to match the strides and reshape
    # accordingly).
    # We want to cluster such indices whenever several appear in the same p_to_vs

    # TODO: Handle ellipsis
    # If we have an index v for some virtual dim whose corresponding v_to_ps is a non-trivial list
    # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension.
    # For this reason, an index is decomposed into sub-indices that are then independently
    # clustered.
    # So if an index i in args for some StructuredSparseTensor corresponds to a v_to_ps [j, k, l],
    # We will consider three indices (i, 0), (i, 1) and (i, 2).
    # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then
    # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in
    # the resulting einsum).
    # Note that this is a problem if two virtual dimensions (from possibly different
    # StructuredSparseTensors) have the same size but not the same decomposition into physical
    # dimension sizes. For now lets leave the responsibility to care about that in the calling
    # functions, if we can factor code later on we will.

    # Union-find forest over (einsum_index, physical_sub_index) pairs.
    index_parents = dict[tuple[int, int], tuple[int, int]]()

    def get_representative(index: tuple[int, int]) -> tuple[int, int]:
        if index not in index_parents:
            # If an index is not yet in a cluster, put it in its own.
            index_parents[index] = index
        current = index_parents[index]
        if current != index:
            # Compress path to representative
            index_parents[index] = get_representative(current)
        return index_parents[index]

    def group_indices(indices: list[tuple[int, int]]) -> None:
        # Merge all given sub-indices into the cluster of the first one.
        first_representative = get_representative(indices[0])
        for i in indices[1:]:
            curr_representative = get_representative(i)
            index_parents[curr_representative] = first_representative

    new_indices_pair = list[list[tuple[int, int]]]()
    physicals = list[Tensor]()
    # Number of physical dims backing each virtual einsum index (must agree across tensors).
    indices_to_n_pdims = dict[int, int]()
    for t, indices in args:
        assert isinstance(t, StructuredSparseTensor)
        physicals.append(t.physical)
        for pdims, index in zip(t.v_to_ps, indices):
            if index in indices_to_n_pdims:
                if indices_to_n_pdims[index] != len(pdims):
                    raise NotImplementedError(
                        "einsum currently does not support having a different number of physical "
                        "dimensions corresponding to matching virtual dimensions of different "
                        f"tensors. Found {[(t.debug_info(), indices) for t, indices in args]}, "
                        f"output_indices={output}."
                    )
            else:
                indices_to_n_pdims[index] = len(pdims)
        # TODO(draft): placeholder — should be p_to_vs_from_v_to_ps(t.v_to_ps).
        p_to_vs = ...  # p_to_vs_from_v_to_ps(t.v_to_ps)
        for indices_ in p_to_vs:
            # elements in indices[indices_] map to the same dimension, they should be clustered
            # together
            group_indices([(indices[i], sub_i) for i, sub_i in indices_])
        # record the physical dimensions, index[v] for v in vs will end-up mapping to the same
        # final dimension as they were just clustered, so we can take the first, which exists as
        # t is a valid SST.
        new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs])

    # Relabel cluster representatives with dense consecutive integers for torch.einsum.
    current = 0
    pair_to_int = dict[tuple[int, int], int]()

    def unique_int(pair: tuple[int, int]) -> int:
        nonlocal current
        if pair in pair_to_int:
            return pair_to_int[pair]
        pair_to_int[pair] = current
        current += 1
        return pair_to_int[pair]

    new_indices = [
        [unique_int(get_representative(i)) for i in indices] for indices in new_indices_pair
    ]
    new_output = list[int]()
    v_to_ps = list[list[int]]()
    for i in output:
        # Expand each virtual output index into its physical sub-indices.
        current_v_to_ps = []
        for j in range(indices_to_n_pdims[i]):
            k = unique_int(get_representative((i, j)))
            if k in new_output:
                current_v_to_ps.append(new_output.index(k))
            else:
                current_v_to_ps.append(len(new_output))
                new_output.append(k)
        v_to_ps.append(current_v_to_ps)

    physical = torch.einsum(*[x for y in zip(physicals, new_indices) for x in y], new_output)
    # Need to use the safe constructor, otherwise the dimensions may not be maximally grouped.
    # Maybe there is a way to fix that though.
    return to_most_efficient_tensor(physical, v_to_ps)
128+
129+
130+
def prepare_for_elementwise_op(
    t1: Tensor | int | float, t2: Tensor | int | float
) -> tuple[StructuredSparseTensor, StructuredSparseTensor]:
    """
    Turns the two operands of an element-wise op into SSTs of the same shape.

    At least one operand must already be a StructuredSparseTensor; the other may be
    a SST, Tensor, int or float.
    """
    assert isinstance(t1, StructuredSparseTensor) or isinstance(t2, StructuredSparseTensor)

    # Promote python scalars to 0-dim tensors on the device of the other operand
    # (the assert guarantees the other operand is a tensor in that case).
    first = tensor(t1, device=t2.device) if isinstance(t1, (int, float)) else t1
    second = tensor(t2, device=t1.device) if isinstance(t2, (int, float)) else t2

    first, second = aten.broadcast_tensors.default([first, second])
    return to_structured_sparse_tensor(first), to_structured_sparse_tensor(second)
155+
156+
157+
@impl(aten.mul.Tensor)
def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
    """Element-wise multiplication (with broadcasting), expressed as a sparse einsum."""
    lhs, rhs = prepare_for_elementwise_op(t1, t2)
    dims = list(range(lhs.ndim))
    # Same index on every dimension of both operands and the output == Hadamard product.
    return einsum((lhs, dims), (rhs, dims), output=dims)
163+
164+
165+
@impl(aten.div.Tensor)
def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
    """Element-wise division, done as multiplication by the divisor's reciprocal.

    Only the stored (physical) values of the divisor are inverted; its sparsity
    structure is reused as-is.
    """
    lhs, rhs = prepare_for_elementwise_op(t1, t2)
    reciprocal = StructuredSparseTensor(1.0 / rhs.physical, rhs.strides)
    dims = list(range(lhs.ndim))
    return einsum((lhs, dims), (reciprocal, dims), output=dims)
171+
172+
173+
@impl(aten.mul.Scalar)
def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor:
    """Scale the physical values by ``scalar``; the sparsity structure is unaffected."""
    # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check
    # that
    assert isinstance(t, StructuredSparseTensor)

    scaled = aten.mul.Scalar(t.physical, scalar)
    return StructuredSparseTensor(scaled, t.strides)
181+
182+
183+
@impl(aten.add.Tensor)
def add_Tensor(
    t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0
) -> StructuredSparseTensor:
    """Element-wise ``t1 + alpha * t2``; both operands must share the same structure."""
    lhs, rhs = prepare_for_elementwise_op(t1, t2)

    if not torch.equal(lhs.strides, rhs.strides):
        # Adding SSTs with different sparsity structures is not supported yet.
        raise NotImplementedError()

    combined = lhs.physical + rhs.physical * alpha
    return StructuredSparseTensor(combined, lhs.strides)
194+
195+
196+
@impl(aten.bmm.default)
def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Batched matrix multiplication where at least one operand is an SST."""
    assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
    # Expected shapes: (b, n, m) @ (b, m, p).
    assert (
        mat1.ndim == 3
        and mat2.ndim == 3
        and mat1.shape[0] == mat2.shape[0]
        and mat1.shape[2] == mat2.shape[1]
    )

    lhs = to_structured_sparse_tensor(mat1)
    rhs = to_structured_sparse_tensor(mat2)

    # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes
    # decompositions. If not, can reshape to common decomposition?
    # Index 0 is the shared batch dim, index 2 is contracted away.
    return einsum((lhs, [0, 1, 2]), (rhs, [0, 2, 3]), output=[0, 1, 3])
212+
213+
214+
@impl(aten.mm.default)
def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Matrix multiplication where at least one operand is an SST."""
    assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
    # Expected shapes: (n, m) @ (m, p).
    assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0]

    lhs = to_structured_sparse_tensor(mat1)
    rhs = to_structured_sparse_tensor(mat2)

    # Contract the shared inner dimension (index 1).
    return einsum((lhs, [0, 1]), (rhs, [1, 2]), output=[0, 2])
223+
224+
225+
@impl(aten.mean.default)
def mean_default(t: StructuredSparseTensor) -> Tensor:
    """Mean over all elements: sum of the stored values divided by the element count.

    NOTE(review): assumes ``t.numel()`` reports the virtual (full) element count,
    so implicit zeros are included in the denominator — confirm against the SST class.
    """
    assert isinstance(t, StructuredSparseTensor)
    total = aten.sum.default(t.physical)
    return total / t.numel()
229+
230+
231+
@impl(aten.sum.default)
def sum_default(t: StructuredSparseTensor) -> Tensor:
    """Full sum: only the stored (physical) values can contribute to the total."""
    assert isinstance(t, StructuredSparseTensor)

    return aten.sum.default(t.physical)
235+
236+
237+
@impl(aten.sum.dim_IntList)
def sum_dim_IntList(
    t: StructuredSparseTensor, dim: list[int], keepdim: bool = False, dtype=None
) -> Tensor:
    """Sum ``t`` over the virtual dimensions listed in ``dim``.

    The reduction is expressed as an einsum that simply omits the reduced
    dimensions from the output indices. If ``keepdim`` is True, the reduced
    dimensions are re-inserted as size-1 dimensions.

    :raises NotImplementedError: if a ``dtype`` conversion is requested.
    """
    assert isinstance(t, StructuredSparseTensor)

    if dtype is not None:
        raise NotImplementedError()

    # Normalize negative dims so the membership test against all_dims works.
    reduced_dims = [d % t.ndim for d in dim]
    all_dims = list(range(t.ndim))
    result = einsum((t, all_dims), output=[d for d in all_dims if d not in reduced_dims])

    if keepdim:
        # Re-insert in ascending order: unsqueezing in the caller-given order would
        # target positions that do not exist yet (e.g. dim=[2, 0] on a rank-3 input
        # leaves a rank-1 result, where unsqueeze(2) is out of range).
        for d in sorted(reduced_dims):
            result = result.unsqueeze(d)

    return result

0 commit comments

Comments
 (0)