Commit c81126e
authored
Arm backend: Refactor and bug-fix RewriteIndexPutPass (#18197)
The patch should hopefully make the pass easier to understand. Make
explicit that we set N=1, handle explicit indexing by folding them in
the K dimension, and handle full indexing (select all values) by folding
them in the C dimension.
Note that TOSA and torch has switched terminology regarding what the
parameter 'values' means, instead, use a new naming: TOSA values_in ==
torch x/self tensor, call this 'destination'. TOSA input == torch
values, call this 'data'.
Additionally, the pass earlier didn't account for that 1) There are
fully indexed dimensions
2) Index tensors can be broadcast
3) The data tensor can be smaller than (N, W, C), and
require broadcasting first.
4) None index tensors were incorrectly handled.
Regarding 1-3):
Given destination of shape (N, K, C),
TOSA.SCATTER semantics require the shape (N, W) of the index tensor,
including possibly an implicit C dimension, to match the data shape (N,
W_d, C_d). Torch can however broadcast both these inputs. We need to
expand/reshape the data tensor correctly.
Example (ignoring N, it's always 1):
>>> destination = torch.ones(5, 2), K=5, C=2
>>> indices = (torch.tensor([0, 2]),) # Indexes K dim W=2 times,
C is implicitly assumed to be C=2.
>>> data = torch.tensor([10.0, 20.0]) # W_d = 1 !!, C_d=2
>>> torch.index_put(destination, indices, data)
tensor([[10., 20.],
[ 1., 1.],
[10., 20.],
[ 1., 1.],
[ 1., 1.]])
Or even
>>> [...]
>>> data = torch.tensor([10.0]) # W_d = 1, C_d=1 !!
>>> torch.index_put(destination, indices, data)
tensor([[10., 10.],
[ 1., 1.],
[10., 10.],
[ 1., 1.],
[10., 10.]])
The patch generalizes this to multiple dimensions. Refer to docstring in
patch for complete explaination.
4)
Is handled by adding a normalization pass
that rewrites None indice tensors to fully indexed tensors.
Signed-off-by: Erik Lundell <erik.lundell@arm.com>1 parent 22174fa commit c81126e
6 files changed
Lines changed: 374 additions & 167 deletions
File tree
- backends/arm
- _passes
- test
- modules
- ops
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
123 | 123 | | |
124 | 124 | | |
125 | 125 | | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
126 | 129 | | |
127 | 130 | | |
128 | 131 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
111 | 111 | | |
112 | 112 | | |
113 | 113 | | |
| 114 | + | |
114 | 115 | | |
115 | 116 | | |
116 | 117 | | |
| |||
444 | 445 | | |
445 | 446 | | |
446 | 447 | | |
| 448 | + | |
447 | 449 | | |
448 | 450 | | |
449 | 451 | | |
| |||
Lines changed: 143 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
0 commit comments