Skip to content

Commit 2e641c7

Browse files
committed
Rename SST to SLT
1 parent 63549ca commit 2e641c7

File tree

7 files changed

+16
-16
lines changed

7 files changed

+16
-16
lines changed

src/torchjd/autogram/_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor, nn, vmap
55
from torch.autograd.graph import get_gradient_edge
66

7-
from torchjd.sparse import make_sst
7+
from torchjd.sparse import make_slt
88

99
from ._edge_registry import EdgeRegistry
1010
from ._gramian_accumulator import GramianAccumulator
@@ -177,7 +177,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
177177
output_dims = list(range(output.ndim))
178178
identity = torch.eye(output.ndim, dtype=torch.int64)
179179
basis = torch.concatenate([identity, identity], dim=0)
180-
jac_output = make_sst(torch.ones_like(output), basis)
180+
jac_output = make_slt(torch.ones_like(output), basis)
181181

182182
vmapped_diff = differentiation
183183
for _ in output_dims:

src/torchjd/sparse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# Need to import this to execute the code inside and thus to override the functions
22
from . import _aten_function_overrides
3-
from ._sparse_latticed_tensor import SparseLatticedTensor, make_sst
3+
from ._sparse_latticed_tensor import SparseLatticedTensor, make_slt

src/torchjd/sparse/_aten_function_overrides/einsum.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ def prepare_for_elementwise_op(
131131
t1: Tensor | int | float, t2: Tensor | int | float
132132
) -> tuple[SparseLatticedTensor, SparseLatticedTensor]:
133133
"""
134-
Prepares two SSTs of the same shape from two args, one of those being a SST, and the other being
135-
a SST, Tensor, int or float.
134+
Prepares two SLTs of the same shape from two args, one of those being a SLT, and the other being
135+
a SLT, Tensor, int or float.
136136
"""
137137

138138
assert isinstance(t1, SparseLatticedTensor) or isinstance(t2, SparseLatticedTensor)
@@ -172,7 +172,7 @@ def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
172172

173173
@impl(aten.mul.Scalar)
174174
def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor:
175-
# TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check
175+
# TODO: maybe it could be that scalar is a scalar SLT and t is a normal tensor. Need to check
176176
# that
177177

178178
assert isinstance(t, SparseLatticedTensor)

src/torchjd/sparse/_aten_function_overrides/shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
178178
ref_basis = ref_tensor.basis
179179
if any(not torch.equal(t.basis, ref_basis) for t in tensors_[1:]):
180180
raise NotImplementedError(
181-
"Override for aten.cat.default does not support SSTs that do not all have the same "
181+
"Override for aten.cat.default does not support SLTs that do not all have the same "
182182
f"basis. Found the following tensors:\n{[t.debug_info() for t in tensors_]} and the "
183183
f"following dim: {dim}."
184184
)

src/torchjd/sparse/_linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]:
179179
# S1 = G @ K1
180180
# S2 = G @ K2
181181
#
182-
# SST(p1, S1) = SST(SST(p1, K1), G)
183-
# SST(p2, S2) = SST(SST(p2, K2), G)
182+
# SLT(p1, S1) = SLT(SLT(p1, K1), G)
183+
# SLT(p2, S2) = SLT(SLT(p2, K2), G)
184184

185185
col_magnitudes = torch.sum(torch.abs(H), dim=0)
186186
non_zero_indices = torch.nonzero(col_magnitudes, as_tuple=True)[0]

src/torchjd/sparse/_sparse_latticed_tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]:
160160
161161
Example:
162162
Imagine a vector of size 3, and of value [1, 2, 3].
163-
Imagine a SST t of shape [3, 3] using this vector as physical and using [[0, 0]] as v_to_ps.
163+
Imagine a SLT t of shape [3, 3] using this vector as physical and using [[0, 0]] as v_to_ps.
164164
t.to_dense() is [1, 0, 0, 0, 2, 0, 0, 0, 3] (it's the flattening of the diagonal matrix
165165
[[1, 0, 0], [0, 2, 0], [0, 0, 3]]).
166166
When you move by 1 on physical dimension 0, you move by 4 on virtual dimension 0, i.e.
@@ -203,15 +203,15 @@ def to_sparse_latticed_tensor(t: Tensor) -> SparseLatticedTensor:
203203
if isinstance(t, SparseLatticedTensor):
204204
return t
205205
else:
206-
return make_sst(physical=t, basis=torch.eye(t.ndim, dtype=torch.int64))
206+
return make_slt(physical=t, basis=torch.eye(t.ndim, dtype=torch.int64))
207207

208208

209209
def to_most_efficient_tensor(physical: Tensor, basis: Tensor) -> Tensor:
210210
physical, basis = fix_dim_of_size_1(physical, basis)
211211
physical, basis = fix_ungrouped_dims(physical, basis)
212212

213213
if (basis.sum(dim=0) == 1).all():
214-
# TODO: this can be done more efficiently (without even creating the SST)
214+
# TODO: this can be done more efficiently (without even creating the SLT)
215215
return SparseLatticedTensor(physical, basis).to_dense()
216216
else:
217217
return SparseLatticedTensor(physical, basis)
@@ -264,7 +264,7 @@ def fix_ungrouped_dims(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]
264264
return nphysical, new_basis
265265

266266

267-
def make_sst(physical: Tensor, basis: Tensor) -> SparseLatticedTensor:
267+
def make_slt(physical: Tensor, basis: Tensor) -> SparseLatticedTensor:
268268
"""Fix physical and basis and create a SparseLatticedTensor with them."""
269269

270270
physical, basis = fix_dim_of_size_1(physical, basis)

tests/unit/sparse/test_sparse_latticed_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,17 +370,17 @@ def test_get_column_indices(source: list[int], destination: list[int], ndim: int
370370

371371

372372
@mark.parametrize(
373-
["sst_args", "dim"],
373+
["slt_args", "dim"],
374374
[
375375
([([3], tensor([[1], [1]])), ([3], tensor([[1], [1]]))], 1),
376376
([([3, 2], tensor([[1, 0], [1, 3]])), ([3, 2], tensor([[1, 0], [1, 3]]))], 1),
377377
],
378378
)
379379
def test_concatenate(
380-
sst_args: list[tuple[list[int], Tensor]],
380+
slt_args: list[tuple[list[int], Tensor]],
381381
dim: int,
382382
):
383-
tensors = [SparseLatticedTensor(randn_(pshape), basis) for pshape, basis in sst_args]
383+
tensors = [SparseLatticedTensor(randn_(pshape), basis) for pshape, basis in slt_args]
384384
res = aten.cat.default(tensors, dim)
385385
expected = aten.cat.default([t.to_dense() for t in tensors], dim)
386386

0 commit comments

Comments
 (0)