Skip to content

Commit 731de15

Browse files
committed
Remove unsquash_pdim
* It was unused and I think it will be replaced by functions that find divisors of the basis
1 parent c4f7dfc commit 731de15

2 files changed

Lines changed: 0 additions & 93 deletions

File tree

src/torchjd/sparse/_aten_function_overrides/shape.py

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -69,61 +69,6 @@ def infer_shape(shape: list[int], numel: int) -> list[int]:
6969
return [inferred if s == -1 else s for s in shape]
7070

7171

72-
def unsquash_pdim(
73-
physical: Tensor, basis: Tensor, pdim: int, new_pdim_shape: list[int]
74-
) -> tuple[Tensor, Tensor]:
75-
"""
76-
EXAMPLE:
77-
78-
physical = [
79-
[1, 2, 3, 4, 5, 6],
80-
[7, 8, 9, 10, 11, 12],
81-
[13, 14, 15, 16, 17, 18],
82-
]
83-
basis = [
84-
[1, 1],
85-
[0, 2],
86-
]
87-
88-
dim = 1
89-
shape = [2, 3]
90-
91-
new_physical = [[
92-
[1, 2, 3],
93-
[4, 5, 6],
94-
], [
95-
[7, 8, 9],
96-
[10, 11, 12],
97-
], [
98-
[13, 14, 15],
99-
[16, 17, 18],
100-
]]
101-
102-
new_basis = [
103-
[1, 3, 1],
104-
[0, 6, 2]
105-
"""
106-
107-
# TODO: handle working with multiple dimensions at once
108-
109-
old_shape = list(physical.shape)
110-
new_shape = old_shape[:pdim] + new_pdim_shape + old_shape[pdim + 1 :]
111-
new_physical = physical.reshape(new_shape)
112-
113-
multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))])
114-
115-
new_basis = torch.concat(
116-
[
117-
basis[:, :pdim],
118-
torch.outer(basis[:, pdim], multipliers),
119-
basis[:, pdim + 1 :],
120-
],
121-
dim=1,
122-
)
123-
124-
return new_physical, new_basis
125-
126-
12772
@impl(aten._unsafe_view.default)
12873
def _unsafe_view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor:
12974
return view_default(

tests/unit/sparse/test_sparse_latticed_tensor.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
_IN_PLACE_POINTWISE_FUNCTIONS,
1111
_POINTWISE_FUNCTIONS,
1212
)
13-
from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim
1413
from torchjd.sparse._coalesce import fix_zero_basis_vectors
1514
from torchjd.sparse._sparse_latticed_tensor import (
1615
SparseLatticedTensor,
@@ -312,43 +311,6 @@ def test_fix_ungrouped_dims(
312311
assert torch.equal(fixed_basis, expected_basis)
313312

314313

315-
@mark.parametrize(
316-
[
317-
"physical_shape",
318-
"basis",
319-
"pdim",
320-
"new_pdim_shape",
321-
"expected_physical_shape",
322-
"expected_basis",
323-
],
324-
[
325-
([4], tensor([[1], [2]]), 0, [4], [4], tensor([[1], [2]])), # trivial
326-
([4], tensor([[1], [2]]), 0, [2, 2], [2, 2], tensor([[2, 1], [4, 2]])),
327-
(
328-
[3, 4, 5],
329-
tensor([[1, 2, 0], [1, 0, 1], [0, 1, 1]]),
330-
1,
331-
[2, 1, 1, 2],
332-
[3, 2, 1, 1, 2, 5],
333-
tensor([[1, 4, 4, 4, 2, 0], [1, 0, 0, 0, 0, 1], [0, 2, 2, 2, 1, 1]]),
334-
),
335-
],
336-
)
337-
def test_unsquash_pdim(
338-
physical_shape: list[int],
339-
basis: Tensor,
340-
pdim: int,
341-
new_pdim_shape: list[int],
342-
expected_physical_shape: list[int],
343-
expected_basis: Tensor,
344-
):
345-
physical = randn_(physical_shape)
346-
new_physical, new_basis = unsquash_pdim(physical, basis, pdim, new_pdim_shape)
347-
348-
assert list(new_physical.shape) == expected_physical_shape
349-
assert torch.equal(new_basis, expected_basis)
350-
351-
352314
@mark.parametrize(
353315
[
354316
"source",

0 commit comments

Comments
 (0)