
Commit 7991ac1

Add alias impl for StructuredSparseTensor.implements
1 parent f693e99 commit 7991ac1

5 files changed: 34 additions and 29 deletions


src/torchjd/sparse/_aten_function_overrides/backward.py

Lines changed: 4 additions & 4 deletions
@@ -1,10 +1,10 @@
 from torch import Tensor
 from torch.ops import aten  # type: ignore
 
-from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor
+from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl
 
 
-@StructuredSparseTensor.implements(aten.threshold_backward.default)
+@impl(aten.threshold_backward.default)
 def threshold_backward_default(
     grad_output: StructuredSparseTensor, self: Tensor, threshold
 ) -> StructuredSparseTensor:
@@ -13,7 +13,7 @@ def threshold_backward_default(
     return StructuredSparseTensor(new_physical, grad_output.v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.hardtanh_backward.default)
+@impl(aten.hardtanh_backward.default)
 def hardtanh_backward_default(
     grad_output: StructuredSparseTensor,
     self: Tensor,
@@ -27,7 +27,7 @@ def hardtanh_backward_default(
     return StructuredSparseTensor(new_physical, grad_output.v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.hardswish_backward.default)
+@impl(aten.hardswish_backward.default)
 def hardswish_backward_default(grad_output: StructuredSparseTensor, self: Tensor):
     if isinstance(self, StructuredSparseTensor):
         raise NotImplementedError()
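
As context (not part of the diff): `aten.threshold_backward.default` is the backward of ReLU/threshold. It keeps the incoming gradient wherever the forward input exceeded the threshold and zeroes it elsewhere, which is what lets the override above produce a new physical tensor while reusing `grad_output.v_to_ps`. A quick dense-tensor check of that semantics:

```python
import torch
from torch.ops import aten  # same import style as the file above

x = torch.tensor([-1.0, 0.5, 2.0])   # forward input
g = torch.tensor([10.0, 20.0, 30.0])  # incoming gradient
# Gradient is masked to zero where x <= threshold (here, threshold 0).
out = aten.threshold_backward.default(g, x, 0)
assert torch.equal(out, torch.tensor([0.0, 20.0, 30.0]))
```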

src/torchjd/sparse/_aten_function_overrides/einsum.py

Lines changed: 10 additions & 9 deletions
@@ -4,6 +4,7 @@
 
 from torchjd.sparse._structured_sparse_tensor import (
     StructuredSparseTensor,
+    impl,
     p_to_vs_from_v_to_ps,
     to_most_efficient_tensor,
     to_structured_sparse_tensor,
@@ -37,23 +38,23 @@ def prepare_for_elementwise_op(
     return t1_, t2_
 
 
-@StructuredSparseTensor.implements(aten.mul.Tensor)
+@impl(aten.mul.Tensor)
 def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
     # Element-wise multiplication with broadcasting
     t1_, t2_ = prepare_for_elementwise_op(t1, t2)
     all_dims = list(range(t1_.ndim))
     return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims)
 
 
-@StructuredSparseTensor.implements(aten.div.Tensor)
+@impl(aten.div.Tensor)
 def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
     t1_, t2_ = prepare_for_elementwise_op(t1, t2)
     t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.v_to_ps)
     all_dims = list(range(t1_.ndim))
     return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims)
 
 
-@StructuredSparseTensor.implements(aten.mul.Scalar)
+@impl(aten.mul.Scalar)
 def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor:
     # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check
     # that
@@ -63,7 +64,7 @@ def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor:
     return StructuredSparseTensor(new_physical, t.v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.add.Tensor)
+@impl(aten.add.Tensor)
 def add_Tensor(
     t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0
 ) -> StructuredSparseTensor:
@@ -186,7 +187,7 @@ def unique_int(pair: tuple[int, int]) -> int:
     return to_most_efficient_tensor(physical, v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.bmm.default)
+@impl(aten.bmm.default)
 def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
     assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
     assert (
@@ -204,7 +205,7 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
     return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3])
 
 
-@StructuredSparseTensor.implements(aten.mm.default)
+@impl(aten.mm.default)
 def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
     assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
     assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0]
@@ -215,19 +216,19 @@ def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
     return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2])
 
 
-@StructuredSparseTensor.implements(aten.mean.default)
+@impl(aten.mean.default)
 def mean_default(t: StructuredSparseTensor) -> Tensor:
     assert isinstance(t, StructuredSparseTensor)
     return aten.sum.default(t.physical) / t.numel()
 
 
-@StructuredSparseTensor.implements(aten.sum.default)
+@impl(aten.sum.default)
 def sum_default(t: StructuredSparseTensor) -> Tensor:
     assert isinstance(t, StructuredSparseTensor)
     return aten.sum.default(t.physical)
 
 
-@StructuredSparseTensor.implements(aten.sum.dim_IntList)
+@impl(aten.sum.dim_IntList)
 def sum_dim_IntList(
     t: StructuredSparseTensor, dim: list[int], keepdim: bool = False, dtype=None
 ) -> Tensor:
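
A brief aside (not part of the diff): `mul_Tensor` and `div_Tensor` above express elementwise ops as an einsum in which every dimension index appears in both operands and in the output, and `div_Tensor` reuses that path after swapping the divisor's physical values for their reciprocals. A plain-`torch` sketch of the same two identities, using the sublist form of einsum that the overrides mirror:

```python
import torch

a, b = torch.randn(3, 4), torch.randn(3, 4)

# Every dim index shared by both inputs and the output -> elementwise product.
all_dims = list(range(a.ndim))
prod = torch.einsum(a, all_dims, b, all_dims, all_dims)
assert torch.allclose(prod, a * b)

# div_Tensor's trick: a / b == a * (1 / b), so division reuses the mul path.
quot = torch.einsum(a, all_dims, 1.0 / b, all_dims, all_dims)
assert torch.allclose(quot, a / b)
```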

src/torchjd/sparse/_aten_function_overrides/pointwise.py

Lines changed: 6 additions & 6 deletions
@@ -1,6 +1,6 @@
 from torch.ops import aten  # type: ignore
 
-from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor
+from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl
 
 # pointwise functions applied to one Tensor with `0.0 → 0`
 _POINTWISE_FUNCTIONS = [
@@ -68,7 +68,7 @@
 
 
 def _override_pointwise(op):
-    @StructuredSparseTensor.implements(op)
+    @impl(op)
     def func_(t: StructuredSparseTensor) -> StructuredSparseTensor:
         assert isinstance(t, StructuredSparseTensor)
         return StructuredSparseTensor(op(t.physical), t.v_to_ps)
@@ -77,7 +77,7 @@ def func_(t: StructuredSparseTensor) -> StructuredSparseTensor:
 
 
 def _override_inplace_pointwise(op):
-    @StructuredSparseTensor.implements(op)
+    @impl(op)
     def func_(t: StructuredSparseTensor) -> StructuredSparseTensor:
         assert isinstance(t, StructuredSparseTensor)
         op(t.physical)
@@ -91,7 +91,7 @@ def func_(t: StructuredSparseTensor) -> StructuredSparseTensor:
     _override_inplace_pointwise(pointwise_func)
 
 
-@StructuredSparseTensor.implements(aten.pow.Tensor_Scalar)
+@impl(aten.pow.Tensor_Scalar)
 def pow_Tensor_Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor:
     assert isinstance(t, StructuredSparseTensor)
 
@@ -104,7 +104,7 @@ def pow_Tensor_Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredS
 
 
 # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar.
-@StructuredSparseTensor.implements(aten.pow_.Scalar)
+@impl(aten.pow_.Scalar)
 def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor:
     assert isinstance(t, StructuredSparseTensor)
 
@@ -117,7 +117,7 @@ def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseT
     return t
 
 
-@StructuredSparseTensor.implements(aten.div.Scalar)
+@impl(aten.div.Scalar)
 def div_Scalar(t: StructuredSparseTensor, divisor: float) -> StructuredSparseTensor:
     assert isinstance(t, StructuredSparseTensor)
 
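
Worth spelling out (context, not part of the diff): the `0.0 → 0` comment above is the criterion that makes these overrides valid. A pointwise `f` may be applied to only the stored physical values when `f(0) == 0`, since the sparse tensor's implicit zeros are never materialized and must stay zero:

```python
import torch

zero = torch.tensor(0.0)
assert torch.sin(zero).item() == 0.0  # sin(0) == 0: safe on physical values alone
assert torch.exp(zero).item() == 1.0  # exp(0) == 1: would turn implicit zeros into ones
```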

src/torchjd/sparse/_aten_function_overrides/shape.py

Lines changed: 11 additions & 10 deletions
@@ -9,13 +9,14 @@
     StructuredSparseTensor,
     encode_v_to_ps,
     fix_dim_encoding,
+    impl,
     print_fallback,
     to_most_efficient_tensor,
     unwrap_to_dense,
 )
 
 
-@StructuredSparseTensor.implements(aten.view.default)
+@impl(aten.view.default)
 def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor:
     assert isinstance(t, StructuredSparseTensor)
 
@@ -120,14 +121,14 @@ def new_encoding_fn(d: int) -> list[int]:
     return new_physical, new_encoding
 
 
-@StructuredSparseTensor.implements(aten._unsafe_view.default)
+@impl(aten._unsafe_view.default)
 def _unsafe_view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor:
     return view_default(
         t, shape
     )  # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp
 
 
-@StructuredSparseTensor.implements(aten.unsqueeze.default)
+@impl(aten.unsqueeze.default)
 def unsqueeze_default(t: StructuredSparseTensor, dim: int) -> StructuredSparseTensor:
     assert isinstance(t, StructuredSparseTensor)
     assert -t.ndim - 1 <= dim < t.ndim + 1
@@ -141,7 +142,7 @@ def unsqueeze_default(t: StructuredSparseTensor, dim: int) -> StructuredSparseTe
     return StructuredSparseTensor(t.physical, new_v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.squeeze.dims)
+@impl(aten.squeeze.dims)
 def squeeze_dims(t: StructuredSparseTensor, dims: list[int] | int | None) -> Tensor:
     assert isinstance(t, StructuredSparseTensor)
 
@@ -157,15 +158,15 @@ def squeeze_dims(t: StructuredSparseTensor, dims: list[int] | int | None) -> Ten
     return to_most_efficient_tensor(t.physical, new_v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.permute.default)
+@impl(aten.permute.default)
 def permute_default(t: StructuredSparseTensor, dims: list[int]) -> StructuredSparseTensor:
     new_v_to_ps = [t.v_to_ps[d] for d in dims]
 
     new_physical, new_v_to_ps = fix_dim_encoding(t.physical, new_v_to_ps)
     return StructuredSparseTensor(new_physical, new_v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.cat.default)
+@impl(aten.cat.default)
 def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
     if any(not isinstance(t, StructuredSparseTensor) for t in tensors):
         print_fallback(aten.cat.default, (tensors, dim), {})
@@ -217,7 +218,7 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
     return StructuredSparseTensor(new_physical, new_v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.expand.default)
+@impl(aten.expand.default)
 def expand_default(t: StructuredSparseTensor, sizes: list[int]) -> StructuredSparseTensor:
     # note that sizes could also be just an int, or a torch.Size i think
     assert isinstance(t, StructuredSparseTensor)
@@ -252,7 +253,7 @@ def expand_default(t: StructuredSparseTensor, sizes: list[int]) -> StructuredSpa
     return StructuredSparseTensor(new_physical, new_v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.broadcast_tensors.default)
+@impl(aten.broadcast_tensors.default)
 def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]:
     if len(tensors) != 2:
         raise NotImplementedError()
@@ -279,7 +280,7 @@ def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]:
     return aten.expand.default(t1, new_shape), aten.expand.default(t2, new_shape)
 
 
-@StructuredSparseTensor.implements(aten.slice.Tensor)
+@impl(aten.slice.Tensor)
 def slice_Tensor(
     t: StructuredSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1
 ) -> StructuredSparseTensor:
@@ -315,7 +316,7 @@ def slice_Tensor(
     return StructuredSparseTensor(new_physical, t.v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.transpose.int)
+@impl(aten.transpose.int)
 def transpose_int(t: StructuredSparseTensor, dim0: int, dim1: int) -> StructuredSparseTensor:
     assert isinstance(t, StructuredSparseTensor)
 
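
As an illustration with hypothetical values (the internal structure of `v_to_ps` is not shown in this diff; the name suggests each virtual dim maps to a list of physical dims): `permute_default` above never moves data, it only reindexes the mapping before `fix_dim_encoding` normalizes the encoding:

```python
# Assumed toy v_to_ps: one list of physical dims per virtual dim.
v_to_ps = [[0], [1], [0, 1]]
dims = [2, 0, 1]  # the `dims` argument of aten.permute.default

# The core step of permute_default: reorder the mapping, not the data.
new_v_to_ps = [v_to_ps[d] for d in dims]
assert new_v_to_ps == [[0, 1], [0], [1]]
```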

src/torchjd/sparse/_structured_sparse_tensor.py

Lines changed: 3 additions & 0 deletions
@@ -126,6 +126,9 @@ def decorator(func):
         return decorator
 
 
+impl = StructuredSparseTensor.implements
+
+
 def print_fallback(func, args, kwargs) -> None:
     def tensor_to_str(t: Tensor) -> str:
         result = f"{t.__class__.__name__} - shape: {t.shape}"
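
For readers unfamiliar with the pattern: `implements` is a registration decorator, and this commit simply binds it to the shorter module-level name `impl`, so each override file can write `@impl(...)` instead of `@StructuredSparseTensor.implements(...)`. A minimal sketch of how such a decorator typically works; this is an assumption about code the diff only shows in part, not the library's actual implementation:

```python
from collections.abc import Callable

# Hypothetical registry: aten overload -> override function. The real class
# presumably consults a table like this from __torch_dispatch__.
_HANDLED_FUNCTIONS: dict[object, Callable] = {}

def implements(op) -> Callable:
    """Register the decorated function as the override for `op`."""
    def decorator(func: Callable) -> Callable:
        _HANDLED_FUNCTIONS[op] = func
        return func
    return decorator

impl = implements  # the alias this commit adds, at module scope

@impl("aten.relu.default")  # stand-in key; the real code passes aten overloads
def relu_default(t):
    ...
```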
