Skip to content

Commit c4f7dfc

Browse files
committed
Add docstring to fix_dim_of_size_1 and fix_ungrouped_dims
1 parent 8e17a77 commit c4f7dfc

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

src/torchjd/sparse/_sparse_latticed_tensor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,11 +217,17 @@ def get_full_source(source: list[int], destination: list[int], ndim: int) -> lis
217217

218218

219219
def fix_dim_of_size_1(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]:
220+
"""
221+
Removes physical dimensions of size one and returns the corresponding new physical and new basis
222+
"""
223+
220224
is_of_size_1 = tensor([s == 1 for s in physical.shape], dtype=torch.bool)
221225
return physical.squeeze(), basis[:, ~is_of_size_1]
222226

223227

224228
def fix_ungrouped_dims(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]:
229+
"""Squash together physical dimensions that can be squashed."""
230+
225231
groups = get_groupings(list(physical.shape), basis)
226232
nphysical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups])
227233
basis_mapping = torch.zeros(physical.ndim, nphysical.ndim, dtype=torch.int64)

0 commit comments

Comments
 (0)