Skip to content

Commit 2419c7e

Browse files
committed
Restructure sparse package
1 parent 26de009 commit 2419c7e

8 files changed

Lines changed: 738 additions & 704 deletions

File tree

src/torchjd/sparse/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1+
# Need to import this to execute the code inside and thus to override the functions
2+
from . import _aten_function_overrides
13
from ._diagonal_sparse_tensor import DiagonalSparseTensor, make_dst
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import backward, einsum, pointwise, shape
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from torch import Tensor
2+
from torch.ops import aten # type: ignore
3+
4+
from torchjd.sparse import DiagonalSparseTensor
5+
6+
7+
@DiagonalSparseTensor.implements(aten.threshold_backward.default)
def threshold_backward_default(
    grad_output: DiagonalSparseTensor, self: Tensor, threshold
) -> DiagonalSparseTensor:
    """
    Backward of ``threshold`` for a DiagonalSparseTensor ``grad_output``.

    Applies the dense backward kernel to the physical tensor and keeps the
    virtual-to-physical mapping of ``grad_output`` unchanged.
    """
    # Consistency with the hardtanh/hardswish overrides: a sparse `self` is not
    # supported, since the physical call below assumes `self` is a dense Tensor
    # matching `grad_output.physical`.
    if isinstance(self, DiagonalSparseTensor):
        raise NotImplementedError()

    new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold)
    return DiagonalSparseTensor(new_physical, grad_output.v_to_ps)
14+
15+
16+
@DiagonalSparseTensor.implements(aten.hardtanh_backward.default)
def hardtanh_backward_default(
    grad_output: DiagonalSparseTensor,
    self: Tensor,
    min_val: Tensor | int | float,
    max_val: Tensor | int | float,
) -> DiagonalSparseTensor:
    """
    Backward of ``hardtanh``: run the dense kernel on the physical tensor and
    keep the sparsity pattern of ``grad_output``.
    """
    # Only a dense `self` is supported here.
    if isinstance(self, DiagonalSparseTensor):
        raise NotImplementedError()

    physical_grad = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val)
    return DiagonalSparseTensor(physical_grad, grad_output.v_to_ps)
28+
29+
30+
@DiagonalSparseTensor.implements(aten.hardswish_backward.default)
def hardswish_backward_default(
    grad_output: DiagonalSparseTensor, self: Tensor
) -> DiagonalSparseTensor:
    """
    Backward of ``hardswish``: run the dense kernel on the physical tensor and
    keep the sparsity pattern of ``grad_output``.
    """
    # Only a dense `self` is supported here.
    if isinstance(self, DiagonalSparseTensor):
        raise NotImplementedError()

    new_physical = aten.hardswish_backward.default(grad_output.physical, self)
    return DiagonalSparseTensor(new_physical, grad_output.v_to_ps)
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
import torch
2+
from torch import Tensor, tensor
3+
from torch.ops import aten # type: ignore
4+
5+
from torchjd.sparse import DiagonalSparseTensor
6+
from torchjd.sparse._diagonal_sparse_tensor import (
7+
p_to_vs_from_v_to_ps,
8+
to_diagonal_sparse_tensor,
9+
to_most_efficient_tensor,
10+
)
11+
12+
13+
def prepare_for_elementwise_op(
    t1: Tensor | int | float, t2: Tensor | int | float
) -> tuple[DiagonalSparseTensor, DiagonalSparseTensor]:
    """
    Prepares two DSTs of the same shape from two args, one of those being a DST, and the other
    being a DST, Tensor, int or float.
    """
    assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor)

    # Promote a python scalar to a tensor on the other operand's device. At most one of the two
    # branches can fire, since the assert guarantees at least one operand is a DST.
    a = tensor(t1, device=t2.device) if isinstance(t1, (int, float)) else t1
    b = tensor(t2, device=t1.device) if isinstance(t2, (int, float)) else t2

    # Broadcast to a common shape, then view both operands as DSTs.
    a, b = aten.broadcast_tensors.default([a, b])
    return to_diagonal_sparse_tensor(a), to_diagonal_sparse_tensor(b)
38+
39+
40+
@DiagonalSparseTensor.implements(aten.mul.Tensor)
def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
    """Element-wise multiplication with broadcasting, expressed as an einsum."""
    a, b = prepare_for_elementwise_op(t1, t2)
    # The same index list on both operands and the output makes this a pure
    # Hadamard product.
    dims = list(range(a.ndim))
    return einsum((a, dims), (b, dims), output=dims)
46+
47+
48+
@DiagonalSparseTensor.implements(aten.div.Tensor)
def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
    """Element-wise division with broadcasting: multiply by the reciprocal."""
    numerator, denominator = prepare_for_elementwise_op(t1, t2)
    # Invert the physical values only; the sparsity pattern is unchanged.
    reciprocal = DiagonalSparseTensor(1.0 / denominator.physical, denominator.v_to_ps)
    dims = list(range(numerator.ndim))
    return einsum((numerator, dims), (reciprocal, dims), output=dims)
54+
55+
56+
@DiagonalSparseTensor.implements(aten.mul.Scalar)
def mul_Scalar(t: DiagonalSparseTensor, scalar) -> DiagonalSparseTensor:
    """Scale the physical tensor; the virtual-to-physical mapping is unchanged."""
    # TODO: maybe it could be that scalar is a scalar DST and t is a normal tensor. Need to check
    # that
    assert isinstance(t, DiagonalSparseTensor)
    scaled = aten.mul.Scalar(t.physical, scalar)
    return DiagonalSparseTensor(scaled, t.v_to_ps)
64+
65+
66+
@DiagonalSparseTensor.implements(aten.add.Tensor)
def add_Tensor(
    t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0
) -> DiagonalSparseTensor:
    """
    Element-wise ``t1 + alpha * t2`` with broadcasting, supported only when both
    operands end up with the same virtual-to-physical mapping.
    """
    a, b = prepare_for_elementwise_op(t1, t2)

    if a.v_to_ps != b.v_to_ps:
        # Adding DSTs with different sparsity patterns is not supported yet.
        raise NotImplementedError()

    return DiagonalSparseTensor(a.physical + b.physical * alpha, a.v_to_ps)
77+
78+
79+
def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) -> Tensor:
    """
    Compute an einsum over DiagonalSparseTensors by translating it to an einsum over their
    physical tensors.

    :param args: Pairs ``(tensor, indices)`` with one virtual einsum index per virtual dim of the
        tensor.
    :param output: The virtual einsum indices of the result, in order.
    """

    # First part of the algorithm, determine how to cluster physical indices as well as the common
    # p_shapes corresponding to matching v_dims. Second part translates to physical einsum.

    # get a map from einsum index to (tensor_idx, v_dims)
    # get a map from einsum index to merge of strides corresponding to v_dims with that index
    # use to_target_physical_strides on each physical and v_to_ps
    # cluster pairs of (einsum_index, new_stride) using new_v_to_ps and possibly its corresponding
    # p_to_vs
    # get unique indices
    # map output indices (there can be splits)
    # call physical einsum
    # build resulting dst

    # OVER

    # an index in the physical einsum is uniquely characterized by a virtual einsum index and a
    # stride corresponding to the physical stride in the virtual one (note that as the virtual shape
    # for two virtual index that match should match, then we want to match the strides and reshape
    # accordingly).
    # We want to cluster such indices whenever several appear in the same p_to_vs

    # TODO: Handle ellipsis
    # If we have an index v for some virtual dim whose corresponding v_to_ps is a non-trivial list
    # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension.
    # For this reason, an index is decomposed into sub-indices that are then independently
    # clustered.
    # So if an index i in args for some DiagonalSparseTensor corresponds to a v_to_ps [j, k, l],
    # We will consider three indices (i, 0), (i, 1) and (i, 2).
    # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then
    # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in
    # the resulting einsum).
    # Note that this is a problem if two virtual dimensions (from possibly different
    # DiagonaSparseTensors) have the same size but not the same decomposition into physical
    # dimension sizes. For now lets leave the responsibility to care about that in the calling
    # functions, if we can factor code later on we will.

    # Union-find parent map over sub-indices ``(virtual einsum index, position in its v_to_ps)``.
    index_parents = dict[tuple[int, int], tuple[int, int]]()

    def get_representative(index: tuple[int, int]) -> tuple[int, int]:
        # Union-find "find" with path compression.
        if index not in index_parents:
            # If an index is not yet in a cluster, put it in its own.
            index_parents[index] = index
        current = index_parents[index]
        if current != index:
            # Compress path to representative
            index_parents[index] = get_representative(current)
        return index_parents[index]

    def group_indices(indices: list[tuple[int, int]]) -> None:
        # Union-find "union": merge all given sub-indices into the first one's cluster.
        first_representative = get_representative(indices[0])
        for i in indices[1:]:
            curr_representative = get_representative(i)
            index_parents[curr_representative] = first_representative

    new_indices_pair = list[list[tuple[int, int]]]()  # per-tensor physical indices, as pairs
    tensors = list[Tensor]()  # the physical tensors, in the order they appear in args
    indices_to_n_pdims = dict[int, int]()  # virtual einsum index -> number of physical dims
    for t, indices in args:
        assert isinstance(t, DiagonalSparseTensor)
        tensors.append(t.physical)
        for ps, index in zip(t.v_to_ps, indices):
            if index in indices_to_n_pdims:
                # Matching virtual indices must decompose into the same number of physical dims
                # (see the note above about differing decompositions).
                assert indices_to_n_pdims[index] == len(ps)
            else:
                indices_to_n_pdims[index] = len(ps)
        p_to_vs = p_to_vs_from_v_to_ps(t.v_to_ps)
        for indices_ in p_to_vs:
            # elements in indices[indices_] map to the same dimension, they should be clustered
            # together
            group_indices([(indices[i], sub_i) for i, sub_i in indices_])
        # record the physical dimensions, index[v] for v in vs will end-up mapping to the same
        # final dimension as they were just clustered, so we can take the first, which exists as
        # t is a valid DST.
        new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs])

    current = 0
    pair_to_int = dict[tuple[int, int], int]()

    def unique_int(pair: tuple[int, int]) -> int:
        # Map each representative sub-index pair to a dense integer, suitable as a
        # torch.einsum index; repeated calls with the same pair return the same integer.
        nonlocal current
        if pair in pair_to_int:
            return pair_to_int[pair]
        pair_to_int[pair] = current
        current += 1
        return pair_to_int[pair]

    new_indices = [
        [unique_int(get_representative(i)) for i in indices] for indices in new_indices_pair
    ]
    new_output = list[int]()  # physical einsum output indices
    v_to_ps = list[list[int]]()  # virtual-to-physical mapping of the result
    for i in output:
        current_v_to_ps = []
        for j in range(indices_to_n_pdims[i]):
            k = unique_int(get_representative((i, j)))
            if k in new_output:
                # Physical dim already in the output: reuse its position (a "split").
                current_v_to_ps.append(new_output.index(k))
            else:
                current_v_to_ps.append(len(new_output))
                new_output.append(k)
        v_to_ps.append(current_v_to_ps)

    physical = torch.einsum(*[x for y in zip(tensors, new_indices) for x in y], new_output)
    # Need to use the safe constructor, otherwise the dimensions may not be maximally grouped.
    # Maybe there is a way to fix that though.
    return to_most_efficient_tensor(physical, v_to_ps)
187+
188+
189+
@DiagonalSparseTensor.implements(aten.bmm.default)
def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Batched matrix multiplication where at least one operand is a DST."""
    assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor)
    assert (
        mat1.ndim == 3
        and mat2.ndim == 3
        and mat1.shape[0] == mat2.shape[0]
        and mat1.shape[2] == mat2.shape[1]
    )

    a = to_diagonal_sparse_tensor(mat1)
    b = to_diagonal_sparse_tensor(mat2)

    # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes
    # decompositions. If not, can reshape to common decomposition?
    # Index 0 is the shared batch dim, index 2 is contracted.
    return einsum((a, [0, 1, 2]), (b, [0, 2, 3]), output=[0, 1, 3])
205+
206+
207+
@DiagonalSparseTensor.implements(aten.mm.default)
def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Matrix multiplication where at least one operand is a DST."""
    assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor)
    assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0]

    # Contract over the shared index 1.
    return einsum(
        (to_diagonal_sparse_tensor(mat1), [0, 1]),
        (to_diagonal_sparse_tensor(mat2), [1, 2]),
        output=[0, 2],
    )
216+
217+
218+
@DiagonalSparseTensor.implements(aten.mean.default)
def mean_default(t: DiagonalSparseTensor) -> Tensor:
    """Mean over all (virtual) elements of the tensor."""
    assert isinstance(t, DiagonalSparseTensor)
    # Summing the physical tensor suffices (implicit elements presumably contribute
    # nothing to the sum); the denominator is the *virtual* element count.
    total = aten.sum.default(t.physical)
    return total / t.numel()
222+
223+
224+
@DiagonalSparseTensor.implements(aten.sum.default)
def sum_default(t: DiagonalSparseTensor) -> Tensor:
    """Sum of all elements, computed as the sum of the physical tensor."""
    assert isinstance(t, DiagonalSparseTensor)
    return aten.sum.default(t.physical)
228+
229+
230+
@DiagonalSparseTensor.implements(aten.sum.dim_IntList)
def sum_dim_IntList(
    t: DiagonalSparseTensor, dim: list[int], keepdim: bool = False, dtype=None
) -> Tensor:
    """
    Sum ``t`` over the virtual dimensions in ``dim``.

    :param t: The DiagonalSparseTensor to reduce.
    :param dim: Virtual dimensions to sum over; negative indices are accepted.
    :param keepdim: Whether the reduced dimensions are kept with size 1.
    :param dtype: Not supported; must be None.
    """
    assert isinstance(t, DiagonalSparseTensor)

    if dtype:
        raise NotImplementedError()

    # Normalize negative dims so that the membership test against `all_dims` below
    # works, and sort them so that keepdim re-insertion is well-defined.
    dims = sorted(d % t.ndim for d in dim)

    all_dims = list(range(t.ndim))
    result = einsum((t, all_dims), output=[d for d in all_dims if d not in dims])

    if keepdim:
        # Re-insert singleton dims in increasing order: each index then refers to its
        # final position. Unsqueezing in an arbitrary order could target an
        # out-of-range dim on the reduced tensor.
        for d in dims:
            result = result.unsqueeze(d)

    return result

0 commit comments

Comments
 (0)