Commit 2a2ef8c
authored
[WebGPU] Support continuous decoding (RewindTo) with graph capture (microsoft#2083)
This pull request introduces improvements to the handling of attention
masks in both the CUDA and WebGPU backends, focusing on more efficient
and correct updates of mask buffers during decoding. The main changes
are the implementation of a CPU-side update for static attention masks
in CUDA and the addition of a reusable staging buffer for efficient mask
updates in WebGPU, with logic to avoid redundant work for single-beam
cases.
**CUDA backend improvements:**
* Replaced the previous (commented-out and incorrect) CUDA memory set
logic in `DefaultPositionInputs::RewindMask` with a CPU-side update that
correctly sets attended and non-attended positions in the attention mask
for each batch/beam, followed by a copy back to the device. This ensures
the mask is set with 1s for attended tokens and 0s for future tokens,
supporting both `int32_t` and `int64_t` types.
**WebGPU backend improvements:**
* Added a reusable CPU staging buffer (`mask_staging_buffer_`) to the
`InterfaceImpl` struct for efficient attention mask updates, avoiding
repeated allocations and redundant writes.
* Implemented the `UpdateAttentionMask` method to efficiently update the
mask for single-beam cases by only filling new positions with 1s and
copying the relevant portion to the device, falling back to CPU for
multi-beam cases. This method handles static update path and supports
both `int32_t` and `int64_t` mask types.1 parent 09e69e4 commit 2a2ef8c
3 files changed
Lines changed: 137 additions & 14 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
380 | 380 | | |
381 | 381 | | |
382 | 382 | | |
383 | | - | |
384 | | - | |
385 | | - | |
386 | | - | |
387 | | - | |
388 | | - | |
389 | | - | |
390 | | - | |
391 | | - | |
392 | | - | |
393 | | - | |
394 | | - | |
395 | | - | |
396 | | - | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
397 | 408 | | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
398 | 414 | | |
399 | 415 | | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
400 | 427 | | |
401 | 428 | | |
402 | 429 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
171 | 171 | | |
172 | 172 | | |
173 | 173 | | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
174 | 179 | | |
175 | 180 | | |
176 | 181 | | |
| |||
190 | 195 | | |
191 | 196 | | |
192 | 197 | | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
193 | 239 | | |
194 | 240 | | |
195 | 241 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1334 | 1334 | | |
1335 | 1335 | | |
1336 | 1336 | | |
| 1337 | + | |
| 1338 | + | |
| 1339 | + | |
| 1340 | + | |
| 1341 | + | |
| 1342 | + | |
| 1343 | + | |
| 1344 | + | |
| 1345 | + | |
| 1346 | + | |
| 1347 | + | |
| 1348 | + | |
| 1349 | + | |
| 1350 | + | |
| 1351 | + | |
| 1352 | + | |
| 1353 | + | |
| 1354 | + | |
| 1355 | + | |
| 1356 | + | |
| 1357 | + | |
| 1358 | + | |
| 1359 | + | |
| 1360 | + | |
| 1361 | + | |
| 1362 | + | |
| 1363 | + | |
| 1364 | + | |
| 1365 | + | |
| 1366 | + | |
| 1367 | + | |
| 1368 | + | |
| 1369 | + | |
| 1370 | + | |
| 1371 | + | |
| 1372 | + | |
| 1373 | + | |
| 1374 | + | |
| 1375 | + | |
| 1376 | + | |
| 1377 | + | |
| 1378 | + | |
| 1379 | + | |
| 1380 | + | |
| 1381 | + | |
| 1382 | + | |
| 1383 | + | |
| 1384 | + | |
| 1385 | + | |
| 1386 | + | |
1337 | 1387 | | |
1338 | 1388 | | |
1339 | 1389 | | |
| |||
0 commit comments