Commit bcedecd
Support MoE for GPTModelPipe (#373)
* MOE: Support MoE layers creation for GPTModelPipe
Signed-off-by: Moshe Island <misland@habana.ai>
* MOE: Support MoE aux loss for GPTModelPipe
Propagate aux loss along GPTModelPipe layers by forwarding the aggregated loss
from each transformer layer to the next transformer layer.
In addition, add a layer to GPTModelPipe, after the last transformer layer, to
catch the final aggregated aux loss and cache it for use in the loss function.
Signed-off-by: Moshe Island <misland@habana.ai>
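
The aux loss propagation described above can be sketched in plain Python, without PyTorch or DeepSpeed. This is only an illustration of the idea, not the code from this commit: `ToyMoELayer`, `AuxLossCatcher`, `run_pipeline`, `loss_fn` and the `aux_coeff` parameter are hypothetical names, and scaling the aux loss by a coefficient in the loss function is an assumption.

```python
# Hypothetical sketch: all names here are illustrative, not taken from the commit.

class ToyMoELayer:
    """Stands in for a transformer layer that produces its own aux loss."""

    def __init__(self, scale, layer_aux_loss):
        self.scale = scale
        self.layer_aux_loss = layer_aux_loss

    def __call__(self, inputs):
        hidden, aggregated_aux_loss = inputs
        hidden = [h * self.scale for h in hidden]
        # Forward the running total to the next layer.
        return hidden, aggregated_aux_loss + self.layer_aux_loss


class AuxLossCatcher:
    """Placed after the last transformer layer; caches the final total."""

    def __init__(self):
        self.cached_aux_loss = None

    def __call__(self, inputs):
        hidden, aggregated_aux_loss = inputs
        self.cached_aux_loss = aggregated_aux_loss
        return hidden  # later layers see only the activations


def run_pipeline(layers, hidden):
    """Feed (activations, aux loss) through the layers, starting from 0.0."""
    out = (hidden, 0.0)
    for layer in layers:
        out = layer(out)
    return out


def loss_fn(output, catcher, aux_coeff=0.01):
    """Combine a placeholder LM loss with the cached aux loss (assumed form)."""
    lm_loss = sum(output) / len(output)
    return lm_loss + aux_coeff * catcher.cached_aux_loss
```

In this sketch, only the catcher holds state, so the loss function can read the aggregated value without the pipeline having to return a second output.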
* MOE: Support display of MoE loss for GPTModelPipe
Signed-off-by: Moshe Island <misland@habana.ai>
* MOE: Verify MoE with no pipe/grad partitioned
Currently, PipelineEngine supports partitioning of only a single tensor with grad.
An MoE model must forward both the activations and the aux_loss with grad.
Therefore, until the PipelineEngine limitation is removed, verify that no
partitioning is used with MoE.
Signed-off-by: Moshe Island <misland@habana.ai>
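
A check of this kind might look like the sketch below. It is an assumption-laden illustration, not the verification added in this commit: `verify_moe_pipeline_config` is a hypothetical helper, the `pipe_partitioned` and `grad_partitioned` keys are assumed to be the relevant pipeline settings, and requiring them to be explicitly set to `False` is a choice made for the sketch.

```python
# Hypothetical sketch: the helper name and the config keys are assumptions.

def verify_moe_pipeline_config(num_experts, pipeline_cfg):
    """Raise if MoE is enabled while pipe/grad partitioning is not disabled.

    num_experts: list of expert counts (values > 1 mean MoE is in use).
    pipeline_cfg: dict holding the assumed partitioning flags.
    """
    uses_moe = any(n > 1 for n in num_experts)
    if not uses_moe:
        return  # nothing to verify for a dense model

    for key in ("pipe_partitioned", "grad_partitioned"):
        # Sketch-level choice: a missing key counts as "not disabled".
        if pipeline_cfg.get(key, True):
            raise ValueError(
                f"MoE requires {key}=False: both the activations and the "
                "aux_loss must be forwarded with grad"
            )
```

Failing early at configuration time keeps the limitation visible, instead of surfacing later as a missing gradient on the aux loss.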
---------
Signed-off-by: Moshe Island <misland@habana.ai>
Co-authored-by: Moshe Island <misland@habana.ai>
1 parent 3c5f475 commit bcedecd
3 files changed
Lines changed: 121 additions & 27 deletions