
Commit 59bcf06

Rename DiagonalSparseTensor to StructuredSparseTensor
1 parent 2419c7e commit 59bcf06

8 files changed

Lines changed: 147 additions & 147 deletions


src/torchjd/autogram/_engine.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 from torch import Tensor, nn, vmap
 from torch.autograd.graph import get_gradient_edge

-from torchjd.sparse import make_dst
+from torchjd.sparse import make_sst

 from ._edge_registry import EdgeRegistry
 from ._gramian_accumulator import GramianAccumulator
@@ -176,7 +176,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:

         output_dims = list(range(output.ndim))
         v_to_ps = [[dim] for dim in output_dims * 2]
-        jac_output = make_dst(torch.ones_like(output), v_to_ps)
+        jac_output = make_sst(torch.ones_like(output), v_to_ps)

         vmapped_diff = differentiation
         for _ in output_dims:
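
The seed built here relies on `v_to_ps`, which maps each virtual dimension of the sparse tensor to the physical dimensions backing it; duplicating `output_dims` makes the two halves of the Jacobian's index space share storage, i.e. a diagonal seed that is never materialized. A minimal sketch of the renamed call, assuming `make_sst(physical, v_to_ps)` builds a virtual tensor whose k-th dimension reads from the physical dimensions listed in `v_to_ps[k]` (an interpretation inferred from this diff, not stated in it):

import torch

from torchjd.sparse import make_sst  # renamed from make_dst in this commit

output = torch.randn(3, 5)                    # stand-in for a model output
output_dims = list(range(output.ndim))        # [0, 1]
v_to_ps = [[dim] for dim in output_dims * 2]  # [[0], [1], [0], [1]]

# Virtually a (3, 5, 3, 5) tensor that is nonzero only where the first pair of
# indices equals the second pair, while physically storing just a (3, 5) block.
jac_output = make_sst(torch.ones_like(output), v_to_ps)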

src/torchjd/sparse/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # Need to import this to execute the code inside and thus to override the functions
 from . import _aten_function_overrides
-from ._diagonal_sparse_tensor import DiagonalSparseTensor, make_dst
+from ._structured_sparse_tensor import StructuredSparseTensor, make_sst
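
Downstream code importing the old public names must be updated accordingly; both the class and its factory helper are re-exported under the new names:

# before this commit: from torchjd.sparse import DiagonalSparseTensor, make_dst
from torchjd.sparse import StructuredSparseTensor, make_sst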
Lines changed: 14 additions & 14 deletions
@@ -1,36 +1,36 @@
 from torch import Tensor
 from torch.ops import aten  # type: ignore

-from torchjd.sparse import DiagonalSparseTensor
+from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor


-@DiagonalSparseTensor.implements(aten.threshold_backward.default)
+@StructuredSparseTensor.implements(aten.threshold_backward.default)
 def threshold_backward_default(
-    grad_output: DiagonalSparseTensor, self: Tensor, threshold
-) -> DiagonalSparseTensor:
+    grad_output: StructuredSparseTensor, self: Tensor, threshold
+) -> StructuredSparseTensor:
     new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold)

-    return DiagonalSparseTensor(new_physical, grad_output.v_to_ps)
+    return StructuredSparseTensor(new_physical, grad_output.v_to_ps)


-@DiagonalSparseTensor.implements(aten.hardtanh_backward.default)
+@StructuredSparseTensor.implements(aten.hardtanh_backward.default)
 def hardtanh_backward_default(
-    grad_output: DiagonalSparseTensor,
+    grad_output: StructuredSparseTensor,
     self: Tensor,
     min_val: Tensor | int | float,
     max_val: Tensor | int | float,
-) -> DiagonalSparseTensor:
-    if isinstance(self, DiagonalSparseTensor):
+) -> StructuredSparseTensor:
+    if isinstance(self, StructuredSparseTensor):
         raise NotImplementedError()

     new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val)
-    return DiagonalSparseTensor(new_physical, grad_output.v_to_ps)
+    return StructuredSparseTensor(new_physical, grad_output.v_to_ps)


-@DiagonalSparseTensor.implements(aten.hardswish_backward.default)
-def hardswish_backward_default(grad_output: DiagonalSparseTensor, self: Tensor):
-    if isinstance(self, DiagonalSparseTensor):
+@StructuredSparseTensor.implements(aten.hardswish_backward.default)
+def hardswish_backward_default(grad_output: StructuredSparseTensor, self: Tensor):
+    if isinstance(self, StructuredSparseTensor):
         raise NotImplementedError()

     new_physical = aten.hardswish_backward.default(grad_output.physical, self)
-    return DiagonalSparseTensor(new_physical, grad_output.v_to_ps)
+    return StructuredSparseTensor(new_physical, grad_output.v_to_ps)
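
Each override follows the same pattern: run the aten op on the physical storage and reattach the unchanged `v_to_ps`. The registration machinery itself is not part of this diff; the following is a minimal sketch, assuming `implements` fills a handler table consulted by `__torch_dispatch__` (a common tensor-subclass pattern, not the actual torchjd source):

import torch

class StructuredSparseTensor(torch.Tensor):
    _HANDLED_FUNCTIONS: dict = {}  # aten overload -> override function

    @classmethod
    def implements(cls, op):
        # Decorator registering `func` as the handler for the aten overload `op`.
        def decorator(func):
            cls._HANDLED_FUNCTIONS[op] = func
            return func
        return decorator

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in cls._HANDLED_FUNCTIONS:
            return cls._HANDLED_FUNCTIONS[func](*args, **kwargs)
        raise NotImplementedError(f"no override registered for {func}")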

src/torchjd/sparse/_aten_function_overrides/einsum.py

Lines changed: 42 additions & 42 deletions
@@ -2,23 +2,23 @@
 from torch import Tensor, tensor
 from torch.ops import aten  # type: ignore

-from torchjd.sparse import DiagonalSparseTensor
-from torchjd.sparse._diagonal_sparse_tensor import (
+from torchjd.sparse._structured_sparse_tensor import (
+    StructuredSparseTensor,
     p_to_vs_from_v_to_ps,
-    to_diagonal_sparse_tensor,
     to_most_efficient_tensor,
+    to_structured_sparse_tensor,
 )


 def prepare_for_elementwise_op(
     t1: Tensor | int | float, t2: Tensor | int | float
-) -> tuple[DiagonalSparseTensor, DiagonalSparseTensor]:
+) -> tuple[StructuredSparseTensor, StructuredSparseTensor]:
     """
-    Prepares two DSTs of the same shape from two args, one of those being a DST, and the other being
-    a DST, Tensor, int or float.
+    Prepares two SSTs of the same shape from two args, one of those being a SST, and the other being
+    a SST, Tensor, int or float.
     """

-    assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor)
+    assert isinstance(t1, StructuredSparseTensor) or isinstance(t2, StructuredSparseTensor)

     if isinstance(t1, int) or isinstance(t1, float):
         t1_ = tensor(t1, device=t2.device)
@@ -31,52 +31,52 @@ def prepare_for_elementwise_op(
         t2_ = t2

     t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_])
-    t1_ = to_diagonal_sparse_tensor(t1_)
-    t2_ = to_diagonal_sparse_tensor(t2_)
+    t1_ = to_structured_sparse_tensor(t1_)
+    t2_ = to_structured_sparse_tensor(t2_)

     return t1_, t2_


-@DiagonalSparseTensor.implements(aten.mul.Tensor)
+@StructuredSparseTensor.implements(aten.mul.Tensor)
 def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
     # Element-wise multiplication with broadcasting
     t1_, t2_ = prepare_for_elementwise_op(t1, t2)
     all_dims = list(range(t1_.ndim))
     return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims)


-@DiagonalSparseTensor.implements(aten.div.Tensor)
+@StructuredSparseTensor.implements(aten.div.Tensor)
 def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
     t1_, t2_ = prepare_for_elementwise_op(t1, t2)
-    t2_ = DiagonalSparseTensor(1.0 / t2_.physical, t2_.v_to_ps)
+    t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.v_to_ps)
     all_dims = list(range(t1_.ndim))
     return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims)


-@DiagonalSparseTensor.implements(aten.mul.Scalar)
-def mul_Scalar(t: DiagonalSparseTensor, scalar) -> DiagonalSparseTensor:
-    # TODO: maybe it could be that scalar is a scalar DST and t is a normal tensor. Need to check
+@StructuredSparseTensor.implements(aten.mul.Scalar)
+def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor:
+    # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check
     # that

-    assert isinstance(t, DiagonalSparseTensor)
+    assert isinstance(t, StructuredSparseTensor)
     new_physical = aten.mul.Scalar(t.physical, scalar)
-    return DiagonalSparseTensor(new_physical, t.v_to_ps)
+    return StructuredSparseTensor(new_physical, t.v_to_ps)


-@DiagonalSparseTensor.implements(aten.add.Tensor)
+@StructuredSparseTensor.implements(aten.add.Tensor)
 def add_Tensor(
     t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0
-) -> DiagonalSparseTensor:
+) -> StructuredSparseTensor:
     t1_, t2_ = prepare_for_elementwise_op(t1, t2)

     if t1_.v_to_ps == t2_.v_to_ps:
         new_physical = t1_.physical + t2_.physical * alpha
-        return DiagonalSparseTensor(new_physical, t1_.v_to_ps)
+        return StructuredSparseTensor(new_physical, t1_.v_to_ps)
     else:
         raise NotImplementedError()


-def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) -> Tensor:
+def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) -> Tensor:

     # First part of the algorithm, determine how to cluster physical indices as well as the common
     # p_shapes corresponding to matching v_dims. Second part translates to physical einsum.
@@ -89,7 +89,7 @@ def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) ->
     # get unique indices
     # map output indices (there can be splits)
     # call physical einsum
-    # build resulting dst
+    # build resulting sst

     # OVER

@@ -104,7 +104,7 @@ def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) ->
     # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension.
     # For this reason, an index is decomposed into sub-indices that are then independently
     # clustered.
-    # So if an index i in args for some DiagonalSparseTensor corresponds to a v_to_ps [j, k, l],
+    # So if an index i in args for some StructuredSparseTensor corresponds to a v_to_ps [j, k, l],
     # We will consider three indices (i, 0), (i, 1) and (i, 2).
     # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then
     # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in
@@ -136,7 +136,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None:
     tensors = list[Tensor]()
     indices_to_n_pdims = dict[int, int]()
     for t, indices in args:
-        assert isinstance(t, DiagonalSparseTensor)
+        assert isinstance(t, StructuredSparseTensor)
         tensors.append(t.physical)
         for ps, index in zip(t.v_to_ps, indices):
             if index in indices_to_n_pdims:
@@ -150,7 +150,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None:
         group_indices([(indices[i], sub_i) for i, sub_i in indices_])
         # record the physical dimensions, index[v] for v in vs will end-up mapping to the same
         # final dimension as they were just clustered, so we can take the first, which exists as
-        # t is a valid DST.
+        # t is a valid SST.
         new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs])

     current = 0
@@ -186,52 +186,52 @@ def unique_int(pair: tuple[int, int]) -> int:
     return to_most_efficient_tensor(physical, v_to_ps)


-@DiagonalSparseTensor.implements(aten.bmm.default)
+@StructuredSparseTensor.implements(aten.bmm.default)
 def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
-    assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor)
+    assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
     assert (
         mat1.ndim == 3
         and mat2.ndim == 3
         and mat1.shape[0] == mat2.shape[0]
         and mat1.shape[2] == mat2.shape[1]
     )

-    mat1_ = to_diagonal_sparse_tensor(mat1)
-    mat2_ = to_diagonal_sparse_tensor(mat2)
+    mat1_ = to_structured_sparse_tensor(mat1)
+    mat2_ = to_structured_sparse_tensor(mat2)

     # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes
     # decompositions. If not, can reshape to common decomposition?
     return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3])


-@DiagonalSparseTensor.implements(aten.mm.default)
+@StructuredSparseTensor.implements(aten.mm.default)
 def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
-    assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor)
+    assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
     assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0]

-    mat1_ = to_diagonal_sparse_tensor(mat1)
-    mat2_ = to_diagonal_sparse_tensor(mat2)
+    mat1_ = to_structured_sparse_tensor(mat1)
+    mat2_ = to_structured_sparse_tensor(mat2)

     return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2])


-@DiagonalSparseTensor.implements(aten.mean.default)
-def mean_default(t: DiagonalSparseTensor) -> Tensor:
-    assert isinstance(t, DiagonalSparseTensor)
+@StructuredSparseTensor.implements(aten.mean.default)
+def mean_default(t: StructuredSparseTensor) -> Tensor:
+    assert isinstance(t, StructuredSparseTensor)
     return aten.sum.default(t.physical) / t.numel()


-@DiagonalSparseTensor.implements(aten.sum.default)
-def sum_default(t: DiagonalSparseTensor) -> Tensor:
-    assert isinstance(t, DiagonalSparseTensor)
+@StructuredSparseTensor.implements(aten.sum.default)
+def sum_default(t: StructuredSparseTensor) -> Tensor:
+    assert isinstance(t, StructuredSparseTensor)
     return aten.sum.default(t.physical)


-@DiagonalSparseTensor.implements(aten.sum.dim_IntList)
+@StructuredSparseTensor.implements(aten.sum.dim_IntList)
 def sum_dim_IntList(
-    t: DiagonalSparseTensor, dim: list[int], keepdim: bool = False, dtype=None
+    t: StructuredSparseTensor, dim: list[int], keepdim: bool = False, dtype=None
 ) -> Tensor:
-    assert isinstance(t, DiagonalSparseTensor)
+    assert isinstance(t, StructuredSparseTensor)

     if dtype:
         raise NotImplementedError()
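
To make the virtual-to-physical translation concrete: for an SST whose `v_to_ps` is `[[0], [0]]` (an n x n diagonal matrix stored as its 1-D diagonal), the `mm` override reduces to a physical einsum over the shared index, i.e. row-wise scaling, with the dense matrix never materialized. A small self-contained check in plain torch (illustrative only, bypassing torchjd):

import torch

d = torch.randn(4)     # physical storage of a virtual (4, 4) diagonal matrix
m = torch.randn(4, 3)  # ordinary dense matrix

dense_result = torch.diag(d) @ m                  # what mm means virtually
physical_result = torch.einsum("i,ij->ij", d, m)  # the equivalent physical einsum
assert torch.allclose(dense_result, physical_result)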

src/torchjd/sparse/_aten_function_overrides/pointwise.py

Lines changed: 19 additions & 19 deletions
@@ -1,6 +1,6 @@
 from torch.ops import aten  # type: ignore

-from torchjd.sparse import DiagonalSparseTensor
+from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor

 # pointwise functions applied to one Tensor with `0.0 → 0`
 _POINTWISE_FUNCTIONS = [
@@ -68,18 +68,18 @@


 def _override_pointwise(op):
-    @DiagonalSparseTensor.implements(op)
-    def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor:
-        assert isinstance(t, DiagonalSparseTensor)
-        return DiagonalSparseTensor(op(t.physical), t.v_to_ps)
+    @StructuredSparseTensor.implements(op)
+    def func_(t: StructuredSparseTensor) -> StructuredSparseTensor:
+        assert isinstance(t, StructuredSparseTensor)
+        return StructuredSparseTensor(op(t.physical), t.v_to_ps)

     return func_


 def _override_inplace_pointwise(op):
-    @DiagonalSparseTensor.implements(op)
-    def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor:
-        assert isinstance(t, DiagonalSparseTensor)
+    @StructuredSparseTensor.implements(op)
+    def func_(t: StructuredSparseTensor) -> StructuredSparseTensor:
+        assert isinstance(t, StructuredSparseTensor)
         op(t.physical)
         return t

@@ -91,22 +91,22 @@ def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor:
     _override_inplace_pointwise(pointwise_func)


-@DiagonalSparseTensor.implements(aten.pow.Tensor_Scalar)
-def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor:
-    assert isinstance(t, DiagonalSparseTensor)
+@StructuredSparseTensor.implements(aten.pow.Tensor_Scalar)
+def pow_Tensor_Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor:
+    assert isinstance(t, StructuredSparseTensor)

     if exponent <= 0.0:
         # Need to densify because we don't have pow(0.0, exponent) = 0.0
         return aten.pow.Tensor_Scalar(t.to_dense(), exponent)

     new_physical = aten.pow.Tensor_Scalar(t.physical, exponent)
-    return DiagonalSparseTensor(new_physical, t.v_to_ps)
+    return StructuredSparseTensor(new_physical, t.v_to_ps)


 # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar.
-@DiagonalSparseTensor.implements(aten.pow_.Scalar)
-def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor:
-    assert isinstance(t, DiagonalSparseTensor)
+@StructuredSparseTensor.implements(aten.pow_.Scalar)
+def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor:
+    assert isinstance(t, StructuredSparseTensor)

     if exponent <= 0.0:
         # Need to densify because we don't have pow(0.0, exponent) = 0.0
@@ -117,9 +117,9 @@ def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor:
     return t


-@DiagonalSparseTensor.implements(aten.div.Scalar)
-def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> DiagonalSparseTensor:
-    assert isinstance(t, DiagonalSparseTensor)
+@StructuredSparseTensor.implements(aten.div.Scalar)
+def div_Scalar(t: StructuredSparseTensor, divisor: float) -> StructuredSparseTensor:
+    assert isinstance(t, StructuredSparseTensor)

     new_physical = aten.div.Scalar(t.physical, divisor)
-    return DiagonalSparseTensor(new_physical, t.v_to_ps)
+    return StructuredSparseTensor(new_physical, t.v_to_ps)
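
These overrides rely on the invariant named in the `_POINTWISE_FUNCTIONS` comment: a pointwise f with f(0.0) = 0.0 commutes with densification, so it can be applied to the physical tensor alone. pow with exponent <= 0 violates it (torch defines pow(0.0, 0.0) = 1.0), which is why those overrides densify first. A quick numeric check in plain torch:

import torch

d = torch.tensor([1.0, -2.0, 3.0])  # physical diagonal of a virtual 3 x 3 matrix

# relu(0) == 0, so applying relu to the diagonal matches relu on the dense tensor.
assert torch.equal(torch.relu(torch.diag(d)), torch.diag(torch.relu(d)))

# pow(0, 0) == 1: the off-diagonal zeros become ones, so the result is dense.
print(torch.diag(d).pow(0.0))  # all-ones matrix, structure lost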
