-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoptimised asm for iterative treesum() example
More file actions
343 lines (302 loc) · 10.3 KB
/
optimised asm for iterative treesum() example
File metadata and controls
343 lines (302 loc) · 10.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
/*
 * Reference C implementation that the scalar and AVX-512 listings
 * immediately following it translate into assembly.
 *
 * Node layout on an LP64 target: val @0, nchildren @4, left @8, right @16.
 */
typedef struct node node;

struct node {
    int   val;        /* payload that treesum() accumulates                     */
    int   nchildren;  /* assumed to be the number of descendants of this node   */
                      /* (node count - 1); the work-array sizing relies on it   */
    node* left;       /* NULL when there is no left child                       */
    node* right;      /* NULL when there is no right child                      */
};

/*
 * Sum `val` over every node reachable from `root`, breadth-first and
 * without recursion.
 *
 * A work array `wa` is used as a queue: waRead is the next entry to
 * process and waWrite is the next free slot. Both child pointers are
 * always stored; a NULL child simply does not advance waWrite, so it is
 * overwritten by the next store. That keeps the loop body branch-free.
 *
 * root:    tree root, may be NULL (returns 0).
 * returns: the sum of all node values.
 */
int treesum(node* root)
{
    if (!root)
        return 0;

    /* One slot per node (nchildren + 1) plus one slot of headroom for the
     * final NULL store that lands at wa[waWrite] once the queue is full. */
    node* wa[root->nchildren + 2];
    int waRead = 0;
    int waWrite = 1;
    int sum = 0;

    wa[0] = root;
    while (waRead < waWrite)
    {
        root = wa[waRead++];
        sum += root->val;

        int t = !!(root->left);              /* 1 if the left child exists  */
        wa[waWrite] = root->left;
        wa[waWrite + t] = root->right;       /* overwrites a NULL left slot */
        waWrite += !!(root->right) + t;      /* keep only non-NULL entries  */
    }
    return sum;
}
;-----------------------------------------------------------------------
; int treesum(node* root)      scalar, iterative, branch-free loop body
; In:    rdi = root node* (must be non-null)
;              layout used here: val @0, n_children @4, left @8, right @16
; Out:   eax = sum of val over every node reachable from root
; Uses:  rcx = read pointer, rdx = write pointer (both into a node* work
;        array), rsi / r9 / r10 scratch, r8 = constant 8
; The work array has n_children + 2 slots and sits directly below rsp.
; NOTE(review): rsp is never adjusted, so for anything larger than the
; red zone (128 bytes, assuming the SysV AMD64 ABI) the array is not
; protected from asynchronous stack use such as signal handlers - confirm
; this is acceptable where the routine is used.
;-----------------------------------------------------------------------
treesum:
        mov     ecx, [rdi + 4]          ; n_children (zero-extended into rcx)
        add     ecx, 2                  ; +1 for the root, +1 headroom for the last null store
        neg     rcx
        lea     rcx, [rsp + rcx * 8]    ; rcx = &wa[0]  :: waRead
        mov     [rcx], rdi              ; wa[0] = root
        lea     rdx, [rcx + 8]          ; rdx = &wa[1]  :: waWrite
        xor     eax, eax                ; sum = 0
        mov     r8d, 8                  ; slot size, source for the conditional increments
.loop:
        mov     rdi, [rcx]              ; node = *waRead
        add     rcx, r8                 ; ++waRead
        add     eax, [rdi]              ; sum += node->val

        mov     rsi, [rdi + 8]          ; left child
        mov     [rdx], rsi              ; always stored ...
        xor     r9d, r9d
        test    rsi, rsi
        cmovnz  r9, r8                  ; ... but only kept (r9 = 8) when non-null

        mov     rsi, [rdi + 16]         ; right child
        mov     [rdx + r9], rsi         ; lands on top of a null left entry
        xor     r10d, r10d
        test    rsi, rsi
        cmovnz  r10, r8                 ; r10 = 8 when right is non-null

        add     r9, r10
        add     rdx, r9                 ; both conditional increments in one op
        cmp     rcx, rdx
        jb      .loop                   ; while read < write (unsigned pointer compare;
                                        ; the original jumped to an undefined label `loop`)
        ; 16µops, 4 read 2 write, 3.2c/l
        ret                             ; result is already in eax
        ; only used permitted registers
; this code is obviously longer and more complex than the recursive
; version, BUT every step is certain, so although the operations within an
; iteration are dependent on earlier operations within that iteration,
; once the write pointer is far enough ahead of the read pointer there's
; nothing stopping the core from initiating multiple iterations and
; running them all in parallel, meaning the execution time simplifies to
; just (µops)/(issue_width).
; the only thing I'm not quite happy with is that the loop condition is
; dependent on the loop result, meaning the core cannot know that the loop
; has ended before it gets there. Changing so that n_children elements from
; the start are always counted would improve this, but this cannot be done
; for the avx version, so I didn't bother.
; ********* avx512 version **********
;-----------------------------------------------------------------------
; treesum, AVX-512 variant of the scalar breadth-first loop.
; In:    rdi = root node* (layout used: val @0, n_children @4, left @8,
;              right @16)
; Out:   eax = sum (horizontal add of the 8 dword partial sums in ymm0)
; State: r8  = base of the node* work array placed below rsp
;        ecx = read index, edx = write index (in node* slots)
;        r9d = full lane mask, r10 = zero (also used as the gather base)
;
; NOTE(review): this listing reads as a sketch, not assemblable code.
; Points to confirm before relying on it:
;  - `vmovupq`, `vmovapq`, `vmovudq` do not look like real mnemonics;
;    presumably vmovdqu64 / vmovdqa64 were intended.
;  - zmm1 holds 8 qword pointers, yet the read index advances by 16 and
;    the full mask is 0xffff; the later "Vector V2" listing uses 8 / 0xff.
;  - zmm20 / zmm21 (permute tables) are never initialised in this listing.
;  - k1 is reused by three gathers and a masked add; the Intel SDM says a
;    gather clears its mask register on completion and does not accept
;    {z} - confirm.
;  - `add rdx, eax` mixes a 64-bit and a 32-bit register operand.
;  - The work array lives below rsp without rsp being adjusted.
;-----------------------------------------------------------------------
treesum:
mov ecx, [rdi + 4] ; get n_children
add ecx, 20 ; headroom: full-vector stores may run past the write index
neg rcx
lea r8, [rsp + rcx * 8] ; node* array on stack :: waRead
mov [r8], rdi ; wa[0] = root
xor ecx, ecx ; read index = 0
mov edx, 1 ; waWrite
mov r9d, 0xffff ; full mask
xor r10d, r10d ; a handy zero
vpxor ymm0, ymm0 ; sum: 8 dwords
.loop:
vmovupq zmm1, [r8 + rcx * 8] ; read nodes
add ecx, 16
mov eax, ecx
sub eax, edx ; if (read_index + 16) - write_index
cmovns ecx, edx ; >=0, read_index = write_index
cmovs eax, r10d ; <0, =0
; eax is now the overshoot past the write index (0 if none); shifting the
; full mask right by it leaves one bit per valid entry
shrx eax, r9d, eax
kmovw k1, eax
vpgatherqd ymm4{k1}{z}, [r10 + zmm1] ; sums
vpgatherqq zmm2{k1}{z}, [r10 + zmm1 + 8] ; left node*s
vpgatherqq zmm3{k1}{z}, [r10 + zmm1 + 16] ; right node*s
vpaddd ymm0{k1}, ymm4 ; accumulate vals into the partial sums
vmovapq zmm4, zmm20 ; permute table for first half merge
vpermi2q zmm4, zmm2, zmm3 ; (optional)
vmovapq zmm5, zmm21 ; permute table for second half merge
vpermi2q zmm5, zmm2, zmm3 ; (optional)
vptestmq k1, zmm4, zmm4 ; test for null pointers
kmovw eax, k1
vptestmq k2, zmm5, zmm5
vpcompressq zmm4{k1}, zmm4 ; compress out the nulls
vmovudq [r8 + rdx * 8], zmm4 ; full-width store; only the first popcnt(k1) slots are kept
popcnt eax, eax ; how many was that?
add rdx, eax
vpcompressq zmm5{k2}, zmm5 ; compress out the nulls
vmovudq [r8 + rdx * 8], zmm5
kmovw eax, k2
popcnt eax, eax ; how many was that?
add rdx, eax
cmp rcx, rdx
js .loop ; while read < write
; 36 µops, 25 reads 2 writes 12.5c/l
; (12p5,5p0,1p05,3p1,3p06,5p0156)
; horizontal reduction: rotate by 4, 2, then 1 dwords and add, so that
; lane 0 ends up holding the total of all 8 partial sums
valignd ymm1, ymm0, ymm0, 4
vpaddd ymm0, ymm1
valignd ymm1, ymm0, ymm0, 2
vpaddd ymm0, ymm1
valignd ymm1, ymm0, ymm0, 1
vpaddd ymm0, ymm1
movd eax, xmm0 ; return the total
ret
; this does effectively the same as the scalar version:
; - read from the list
; - work out how many of the entries were valid and increment read index
; appropriately while generating a mask of valid entries in the register
; - gather vals,lefts & rights from the valid entries
; - rearrange the lefts and rights into sequence (optional)
; - check for nulls and compress before writing to the list
; - increment the write index by number of valid entries written
; - check read index has not caught up with write index and repeat
; then it has to sum the sums laterally to get the final total
; However, this is not the *right* way to vectorise this operation, merely
; an attempt to make it work from a faulty starting point.
; The right way is to completely redesign the memory layout of the data,
; probably a struct-of-arrays approach, with dword indexes instead of
; pointers for left and right, and an occupancy bitfield element, updated
; during mutations. Then all that is needed is to scan the val array,
; 16-at-a-time, masking the adds using the occupancy data.
;; ************ V2, using carry for index increment
typedef struct node node;

/* Node layout with the child pointers first. The offsets in the field
 * comments (LP64) are the ones the accompanying assembly addresses. */
struct node {
    node* left;       /* offset 0:  NULL when absent                 */
    node* right;      /* offset 8:  NULL when absent                 */
    int val;          /* offset 16: value accumulated by treesum     */
    int childcount;   /* offset 20: child count stored with the node */
};
;-----------------------------------------------------------------------
; int treesum(node* root)      V2: index bookkeeping via the carry flag
; In:    rdi = root node* (must be non-null)
;              layout used here: left @0, right @8, val @16
; Out:   eax = sum of val over every node reachable from root
; Uses:  rcx = top of the work array (rsp - 8), rdx = read index,
;        r8 = write index, rsi scratch. Indexes are <= 0 and count
;        DOWNWARD, so "more negative" means "further ahead".
; Trick: `cmp ptr, 1` sets carry only when ptr == 0, and `adc idx, -1`
;        computes idx - 1 + carry, so the write index moves only for
;        non-null pointers without any branch.
; NOTE(review): the array grows below rsp without rsp being adjusted;
; beyond the red zone (128 bytes, assuming SysV AMD64) it is exposed to
; asynchronous stack use - confirm this is acceptable.
;-----------------------------------------------------------------------
; node* in rdi
treesum:
        lea     rcx, [rsp - 8]          ; growing node* array downward on stack :: waRead
        mov     [rcx], rdi              ; first node written ... but never actually read
        xor     eax, eax                ; sum = 0
        xor     edx, edx                ; read index
        mov     r8, -1                  ; write index
.loop:
        sub     rdx, 1                  ; growing down
        add     eax, [rdi + 16]         ; sum += node->val

        mov     rsi, [rdi]              ; left node*
        mov     [rcx + r8 * 8], rsi     ; written
        cmp     rsi, 1                  ; carry (borrow) set only if rsi == 0
        adc     r8, -1                  ; subtract 1 only if carry not set

        mov     rsi, [rdi + 8]          ; right node*
        mov     [rcx + r8 * 8], rsi     ; written (replaces a null left entry)
        mov     rdi, [rcx + rdx * 8]    ; read node* for next iteration
        cmp     rsi, 1                  ; carry (borrow) set only if rsi == 0
        adc     r8, -1                  ; subtract 1 only if carry not set

        cmp     r8, rdx
        jl      .loop                   ; while write < read (more negative == ahead);
                                        ; signed compare, and the local label is `.loop`
                                        ; (the original jumped to an undefined `loop`)
        ; 14µops, 4 read 2 write, 2.8c/l
        ret                             ; result is already in eax
        ; only used permitted registers
;; ************ V2.5, same but simpler **********************
;;----------------------------------------------------------------------
;; int treesum(node* root)     V2.5: same carry trick, indexed from rsp
;; In:    rdi = root node* (must be non-null)
;;              layout used here: left @0, right @8, val @16
;; Out:   eax = sum of val over every node reachable from root
;; Uses:  rdx = write index, rcx = read index (both negative, counting
;;        down from -1, in 8-byte slots relative to rsp), r8 scratch
;; Fixes relative to the original listing:
;;   - the read used scale 4 while the writes used scale 8, so it read
;;     the wrong (and misaligned) slot; both now use 8
;;   - the accumulation was a 64-bit add that also pulled the adjacent
;;     childcount field into the upper half of rax; it is now 32-bit
;; NOTE(review): the slots live below rsp without rsp being adjusted;
;; beyond the red zone (128 bytes, assuming SysV AMD64) they are exposed
;; to asynchronous stack use - confirm this is acceptable.
;;----------------------------------------------------------------------
treesum:                                ;; root node pointer in rdi
        xor     eax, eax                ;; for value accumulation
        mov     rdx, -1                 ;; write index
        mov     rcx, -1                 ;; read index
.loop:
        add     eax, [rdi + 16]         ;; get the node value and accumulate

        mov     r8, [rdi + 8]           ;; get right node pointer
        mov     [rsp + rdx * 8], r8     ;; write right node pointer to stack
        cmp     r8, 1                   ;; carry set only when the pointer is null
        adc     rdx, -1                 ;; decrement of write index cancelled by carry (null pointer discard)

        mov     r8, [rdi]               ;; read left node pointer
        mov     [rsp + rdx * 8], r8     ;; write left node pointer to stack
        cmp     r8, 1                   ;; carry set only when the pointer is null
        adc     rdx, -1                 ;; decrement of write index cancelled by carry (null pointer discard)

        mov     rdi, [rsp + rcx * 8]    ;; read next node pointer (same 8-byte scale as the writes)
        sub     rcx, 1                  ;; decrement read index
        cmp     rcx, rdx                ;; if read index is still more or equal to write index (no carry) the read node is valid
        jnc     .loop                   ;; so loop again
        ret
;; ************ V3, using childcount for loop end
typedef struct node node;

/* Node layout with the child pointers first. The offsets in the field
 * comments (LP64) are the ones the accompanying assembly addresses. */
struct node {
    node* left;       /* offset 0:  NULL when absent                          */
    node* right;      /* offset 8:  NULL when absent                          */
    int val;          /* offset 16: value accumulated by treesum              */
    int childcount;   /* offset 20: presumably the number of descendants      */
                      /* (node count - 1); treesum uses it as its loop count  */
};
;-----------------------------------------------------------------------
; int treesum(node* root)      V3: loop count taken from root->childcount
; In:    rdi = root node* (must be non-null)
;              layout used here: left @0, right @8, val @16, childcount @20
;              root->childcount must equal the number of descendants
;              (node count - 1); the loop runs exactly that many times.
; Out:   eax = sum of val over every node in the tree
; Uses:  rcx = base of a (childcount + 1)-slot node* array below rsp,
;        edx = read index (counts childcount .. 1), r8 = write index
;        (starts at the top slot and moves down for non-null pointers),
;        rsi scratch
; Fixes relative to the original listing:
;   - a root with childcount == 0 used to dereference the null pointer
;     it had just stored (and the counter wrapped); it now returns
;     root->val immediately
;   - the loop branch targeted an undefined label `loop`
; NOTE(review): the array lives below rsp without rsp being adjusted;
; beyond the red zone (128 bytes, assuming SysV AMD64) it is exposed to
; asynchronous stack use - confirm this is acceptable.
;-----------------------------------------------------------------------
; node* in rdi
treesum:
        mov     ecx, [rdi + 20]         ; get n-1 child-count of the tree root
        mov     eax, [rdi + 16]         ; sum = root->val
        test    ecx, ecx
        jz      .done                   ; leaf root: nothing to walk
        mov     edx, ecx                ; will be our read index
        not     rcx                     ; (-rcx)-1
        lea     rcx, [rsp + rcx * 8]    ; space for node* array on stack :: wa
        mov     r8d, edx                ; write index. starting at top of array.
.loop:
        mov     rsi, [rdi]              ; left node* read,
        mov     [rcx + r8 * 8], rsi     ; and written.
        cmp     rsi, 1                  ; carry set only if rsi == 0
        adc     r8, -1                  ; subtract 1 only if carry not set

        mov     rsi, [rdi + 8]          ; right node* read,
        mov     [rcx + r8 * 8], rsi     ; and written.
        cmp     rsi, 1                  ; carry set only if rsi == 0
        adc     r8, -1                  ; subtract 1 only if carry not set

        mov     rdi, [rcx + rdx * 8]    ; read node* for next iteration
        add     eax, [rdi + 16]         ; and accumulate its value
        sub     edx, 1                  ; growing down
        jnz     .loop                   ; while read index is positive
        ; 13µops, 4 read 2 write, 2.6cc/l
        ; last element at rdx==0
.done:
        ret                             ; result is already in eax
        ; only used permitted registers
; ********* Vector V2: vpinserti32x4 for node reads
;-----------------------------------------------------------------------
; treesum, vector V2: vals are gathered 8 at a time, while each node's
; left/right pointer pair is fetched with one 16-byte load and inserted
; into zmm4 (nodes 0-3) or zmm5 (nodes 4-7).
; In:    rdi = root node* (layout used: left @0, right @8, val @16,
;              child count @20)
; Out:   eax = sum (horizontal add of the 8 dword partial sums in ymm0)
; State: r8  = base of the node* work array placed below rsp
;        ecx = read index, edx = write index (in node* slots)
;        r9d = full lane mask (8 lanes), r10 = zero / gather base
;
; NOTE(review): points to confirm before relying on this listing:
;  - `vmovupq` / `vmovudq` do not look like real mnemonics; presumably
;    vmovdqu64 was intended.
;  - zmm5 is only written when at least 5 entries are valid, yet .nomore
;    always tests, compresses and stores it; stale pointers from an
;    earlier iteration (or from the caller) look like they get queued.
;  - k1 is consumed by the gather and then reused for the masked add;
;    the Intel SDM says a gather clears its mask register on completion
;    and does not accept {z} - confirm.
;  - `add rdx, eax` mixes a 64-bit and a 32-bit register operand.
;  - L-line with `3 (8);` carries a stray `(8)` token ahead of the
;    comment marker.
;  - The work array lives below rsp without rsp being adjusted.
;-----------------------------------------------------------------------
treesum:
mov ecx, [rdi + 20] ; get n_children
add ecx, 20 ; headroom: full-vector stores may run past the write index
neg rcx
lea r8, [rsp + rcx * 8] ; node* array on stack :: wa
mov [r8], rdi ; first node*
xor ecx, ecx ; waRead
mov edx, 1 ; waWrite
mov r9d, 0xff ; full mask
xor r10d, r10d ; a handy zero
vpxor ymm0, ymm0 ; sum: 8 dwords
.loop:
vmovupq zmm1, [r8 + rcx * 8] ; read 8 node*s from array
add ecx, 8
mov eax, ecx
sub eax, edx ; if (read_index += 8) - write_index
cmovns ecx, edx ; >=0, read_index = write_index
cmovs eax, r10d ; <0, =0
; eax is now the overshoot past the write index (0 if none); shifting the
; full mask right by it leaves one bit per valid entry
shrx eax, r9d, eax
kmovw k1, eax
vpgatherqd ymm4{k1}{z}, [r10 + zmm1 + 16] ; sums
vpaddd ymm0{k1}, ymm4
; eax = number of valid entries; each step below fetches one node's
; left/right pair and leaves through .nomore once the count reaches zero
popcnt eax, eax ;p1
movq rdi, xmm1 ;p0
vmovdqu xmm4, [rdi] ;p23 fetch l/r pointers in one read
sub eax, 1 ;p06
jz .nomore ;-↑
vpextrq rdi, xmm1, 1 ;p0,p15
vpinserti64x2 ymm4, ymm4, [rdi], 1 ;p05,p23
sub eax, 1 ;p06
jz .nomore ;-↑
vpextracti64x2 xmm2, zmm1, 1 ;p5 ;;9
movq rdi, xmm2
vpinserti64x2 zmm4, zmm4, [rdi], 2 ; fetch l/r pointers in one read
sub eax, 1
jz .nomore
vpextrq rdi, xmm2, 1
vpinserti64x2 zmm4, zmm4, [rdi], 3
sub eax, 1
jz .nomore
vpextracti64x2 xmm2, zmm1, 2 ;;19
movq rdi, xmm2
vmovdqu xmm5, [rdi] ; fetch l/r pointers in one read
sub eax, 1
jz .nomore
vpextrq rdi, xmm2, 1
vpinserti64x2 ymm5, ymm5, [rdi], 1
sub eax, 1
jz .nomore
vpextracti64x2 xmm2, zmm1, 3 ;;29
movq rdi, xmm2
vpinserti64x2 zmm5, zmm5, [rdi], 2 ; fetch l/r pointers in one read
sub eax, 1
jz .nomore
vpextrq rdi, xmm2, 1
vpinserti64x2 zmm5, zmm5, [rdi], 3 (8); 37µop (-12 removed) to do 8 reads instead of 16
.nomore:
vptestmq k1, zmm4, zmm4 ; test for null pointers
kmovw eax, k1
vptestmq k2, zmm5, zmm5
vpcompressq zmm4{k1}, zmm4 ; compress out the nulls
vmovudq [r8 + rdx * 8], zmm4 ; full-width store; only the first popcnt(k1) slots are kept
popcnt eax, eax ; how many was that?
add rdx, eax
vpcompressq zmm5{k2}, zmm5 ; compress out the nulls
vmovudq [r8 + rdx * 8], zmm5
kmovw eax, k2
popcnt eax, eax ; how many was that?
add rdx, eax
cmp rcx, rdx
js .loop ; while read < write
; 61 µops, 17 reads 2 writes 8.5c/l
; (9p5,11p0,1p05,4p1,10p06,5p0156)
; horizontal reduction: rotate by 4, 2, then 1 dwords and add, so that
; lane 0 ends up holding the total of all 8 partial sums
valignd ymm1, ymm0, ymm0, 4
vpaddd ymm0, ymm1
valignd ymm1, ymm0, ymm0, 2
vpaddd ymm0, ymm1
valignd ymm1, ymm0, ymm0, 1
vpaddd ymm0, ymm1
movd eax, xmm0 ; return the total
ret