Skip to content

Commit c6e3fd9

Browse files
committed
Add new_implementation idea in einsum.
1 parent c7047e1 commit c6e3fd9

1 file changed

Lines changed: 11 additions & 1 deletion

File tree

src/torchjd/sparse/_diagonal_sparse_tensor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -692,8 +692,18 @@ def einsum(
692692
# First part of the algorithm, determine how to cluster physical indices as well as the common
693693
# p_shapes corresponding to matching v_dims. Second part translates to physical einsum.
694694

695-
# new plan for first part:
696695
# get a map from einsum index to (tensor_idx, v_dims)
696+
# get a map from einsum index to merge of strides corresponding to v_dims with that index
697+
# use to_target_physical_strides on each physical and v_to_ps
698+
# cluster pairs of (einsum_index, new_stride) using new_v_to_ps and possibly its corresponding
699+
# p_to_vs
700+
# get unique indices
701+
# map output indices (there can be splits)
702+
# call physical einsum
703+
# build resulting dst
704+
705+
# OVER
706+
697707
# an index in the physical einsum is uniquely characterized by a virtual einsum index and a
698708
# stride corresponding to the physical stride in the virtual one (note that as the virtual shape
699709
# for two virtual index that match should match, then we want to match the strides and reshape

0 commit comments

Comments
 (0)