Skip to content

Commit f87ecbb

Browse files
committed
Use new strides in to_dense
* The result is the same as before.
* Previously, we only iterated over the pdims used by each virtual dim, and summed them if a pdim was present multiple times.
* Now the new stride is already the sum of the old strides when a pdim is present multiple times in a vdim. We iterate over all dimensions because, for dimensions not present in the vdim, the stride is simply 0.
* There's probably a more efficient implementation.
1 parent 48387ab commit f87ecbb

1 file changed

Lines changed: 6 additions & 20 deletions

File tree

src/torchjd/sparse/_diagonal_sparse_tensor.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -76,36 +76,22 @@ def to_dense(
7676
if self.physical.ndim == 0:
7777
return self.physical
7878

79-
# This is a list of strides whose shape matches that of v_to_ps except that each element
80-
# is the stride factor of the index to get the right element for the corresponding virtual
81-
# dimension. Stride is the jump necessary to go from one element to the next one in the
82-
# specified dimension. For instance if the i'th element of v_to_ps is [0, 1, 2], then the
83-
# i'th element of _strides is [physical.shape[1] * physical.shape[2], physical.shape[2], 1]
84-
# and so, if we index dimension i with j=j_0 * stride[0] + j_1 * stride[1] + j_2 * stride[2]
85-
# which isa unique decomposition, then this corresponds to indexing dimensions v_to_ps[i] at
86-
# indices [j_0, j_1, j_2]
87-
s = list(self.physical.shape)
88-
strides = [strides_from_p_dims_and_p_shape(dims, s) for dims in self.v_to_ps]
79+
strides = [strides_v2(p_dims, list(self.physical.shape)) for p_dims in self.v_to_ps]
8980

9081
# TODO: I think it's ok to create index tensors on CPU when tensor to index is on cuda. Idk
9182
# what's faster
9283
p_index_ranges = [torch.arange(s) for s in self.physical.shape]
9384
p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij")
9485
v_indices_grid = list[Tensor]()
86+
all_pdims = list(range(self.physical.ndim))
9587
for stride, dims in zip(strides, self.v_to_ps):
9688
stride_ = torch.tensor(stride, dtype=torch.int)
9789

98-
if len(dims) > 0:
99-
v_indices_grid.append(
100-
torch.sum(
101-
torch.stack([p_indices_grid[d] for d in dims], dim=-1) * stride_, dim=-1
102-
)
90+
v_indices_grid.append(
91+
torch.sum(
92+
torch.stack([p_indices_grid[d] for d in all_pdims], dim=-1) * stride_, dim=-1
10393
)
104-
else:
105-
v_indices_grid.append(torch.tensor(0, dtype=torch.int))
106-
# This is supposed to be a vector of shape d_1 * d_2 ...
107-
# whose elements are the coordinates 1 in p_indices_grad[d_1] times stride 1
108-
# plus coordinates 2 in p_indices_grad[d_2] times stride 2, etc...
94+
)
10995

11096
res = torch.zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype)
11197
res[tuple(v_indices_grid)] = self.physical

0 commit comments

Comments
 (0)