|
5 | 5 | #extension GL_KHR_shader_subgroup_shuffle : enable |
6 | 6 | #include "types.glsl" |
7 | 7 |
|
8 | | -#if defined(SET_ROWS) && defined(DATA_A_TURBO3_0) |
| 8 | +#if defined(SET_ROWS) && (defined(DATA_A_TURBO2_0) || defined(DATA_A_TURBO3_0) || defined(DATA_A_TURBO4_0)) |
9 | 9 | layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in; |
10 | 10 | const uint BLOCK_SIZE = 128; |
11 | 11 | #elif defined(SET_ROWS) && QUANT_K == 1 |
@@ -469,6 +469,245 @@ void main() { |
469 | 469 | data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm / rn) : gnrm); |
470 | 470 | } |
471 | 471 | } |
| 472 | +#elif defined(SET_ROWS) && defined(DATA_A_TURBO2_0) |
| 473 | +// Mirror of the TURBO3_0 block above, adapted for turbo2 (4 centroids, |
| 474 | +// 2-bit pack, no signs byte). WHT tables and reduction structure are |
| 475 | +// identical (QK = 128 for both). |
| 476 | +const float TS1_T2[128] = float[128]( |
| 477 | + -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1, |
| 478 | + 1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1, |
| 479 | + -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, |
| 480 | + 1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1, |
| 481 | + -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, 1, |
| 482 | + 1, -1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1, -1, |
| 483 | + -1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1, |
| 484 | + 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, 1 |
| 485 | +); |
| 486 | +const float TS2_T2[128] = float[128]( |
| 487 | + 1, 1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1, -1, |
| 488 | + 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, 1, 1, |
| 489 | + 1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, |
| 490 | + 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, |
| 491 | + 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, 1, |
| 492 | + -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, |
| 493 | + 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, |
| 494 | + -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1 |
| 495 | +); |
| 496 | +const float TINV_T2 = 0.08838834764831845; // 1 / sqrt(128) |
| 497 | +// Lloyd-Max centroids for N(0, 1/128), 4 levels (matches CENTROIDS_2BIT in C ref) |
| 498 | +const float TC2[4] = float[4](-0.133462, -0.039994, 0.039994, 0.133462); |
| 499 | +// Midpoints between adjacent centroids |
| 500 | +const float TM2[3] = float[3](-0.086728, 0.0, 0.086728); |
| 501 | + |
| 502 | +shared float wht_t2[128]; |
| 503 | +shared float sg_acc_t2[16]; |
| 504 | +shared float gnrm_t2; |
| 505 | + |
| 506 | +void main() { |
| 507 | + const uint t = gl_LocalInvocationID.x; |
| 508 | + const uint g = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; |
| 509 | + const uint gpr = p.ne00 / 128; |
| 510 | + |
| 511 | + if (gpr == 0) return; |
| 512 | + if (g >= p.ne / 128) return; |
| 513 | + |
| 514 | + uint tmp = g; |
| 515 | + const uint ig = tmp % gpr; tmp /= gpr; |
| 516 | + const uint i01 = tmp % p.ne01; tmp /= p.ne01; |
| 517 | + const uint i02 = tmp % p.ne12; |
| 518 | + const uint i03 = tmp / p.ne12; |
| 519 | + |
| 520 | + const uint sb = src0_idx(ig * 128, i01, i02, i03) + get_aoffset(); |
| 521 | + const uint i1 = data_i[src1_idx(i01, fastmod(i02, p.ne11), fastmod(i03, p.ne12), 0) + get_boffset()] DATA_I_SWIZZLE; |
| 522 | + const uint db = dst_idx(ig, i1, i02, i03) + get_doffset(); |
| 523 | + |
| 524 | + wht_t2[t] = data_s[sb + t]; |
| 525 | + barrier(); |
| 526 | + |
| 527 | + float v2 = wht_t2[t] * wht_t2[t]; |
| 528 | + v2 = subgroupAdd(v2); |
| 529 | + if (gl_SubgroupInvocationID == 0) sg_acc_t2[gl_SubgroupID] = v2; |
| 530 | + barrier(); |
| 531 | + if (t == 0) { |
| 532 | + float total = 0.0; |
| 533 | + for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t2[w]; |
| 534 | + gnrm_t2 = sqrt(total); |
| 535 | + } |
| 536 | + barrier(); |
| 537 | + |
| 538 | + wht_t2[t] *= (gnrm_t2 > 1e-10) ? (1.0 / gnrm_t2) : 0.0; |
| 539 | + barrier(); |
| 540 | + |
| 541 | + wht_t2[t] *= TS1_T2[t]; |
| 542 | + barrier(); |
| 543 | + |
| 544 | + [[unroll]] for (uint h = 1; h < 128; h *= 2) { |
| 545 | + if ((t % (2 * h)) < h) { |
| 546 | + float a = wht_t2[t]; |
| 547 | + float b = wht_t2[t + h]; |
| 548 | + wht_t2[t] = a + b; |
| 549 | + wht_t2[t + h] = a - b; |
| 550 | + } |
| 551 | + barrier(); |
| 552 | + } |
| 553 | + |
| 554 | + float rv = wht_t2[t] * TINV_T2 * TS2_T2[t]; |
| 555 | + |
| 556 | + // Quantize to nearest of 4 centroids (2-bit index, no signs byte) |
| 557 | + uint idx = rv < TM2[0] ? 0u : rv < TM2[1] ? 1u : rv < TM2[2] ? 2u : 3u; |
| 558 | + |
| 559 | + // Pack qs: 4 elements per byte (full 2-bit each, no high bit) |
| 560 | + uint sg_lane = gl_SubgroupInvocationID; |
| 561 | + uint qs_byte = 0u; |
| 562 | + [[unroll]] for (uint k = 0; k < 4; k++) { |
| 563 | + uint contrib = subgroupShuffle(idx & 0x3u, (sg_lane & ~3u) + k); |
| 564 | + qs_byte |= contrib << (k * 2u); |
| 565 | + } |
| 566 | + if (sg_lane % 4u == 0u) { |
| 567 | + data_q[db].qs[t / 4u] = uint8_t(qs_byte); |
| 568 | + } |
| 569 | + |
| 570 | + // Reconstruction norm via subgroup reduction |
| 571 | + float rc = TC2[idx] * TC2[idx]; |
| 572 | + rc = subgroupAdd(rc); |
| 573 | + if (sg_lane == 0u) sg_acc_t2[gl_SubgroupID] = rc; |
| 574 | + barrier(); |
| 575 | + if (t == 0u) { |
| 576 | + float total = 0.0; |
| 577 | + for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t2[w]; |
| 578 | + float rn = sqrt(total); |
| 579 | + data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm_t2 / rn) : gnrm_t2); |
| 580 | + } |
| 581 | +} |
| 582 | + |
| 583 | +#elif defined(SET_ROWS) && defined(DATA_A_TURBO4_0) |
| 584 | +// Mirror of the TURBO3_0 block above, adapted for turbo4 (16 centroids, |
| 585 | +// 4-bit nibble pack, no signs byte). WHT tables and reduction structure |
| 586 | +// are identical (QK = 128 for both). The block struct keeps a reserved |
| 587 | +// rnorm field for ABI parity with the legacy 3-bit + QJL layout. |
| 588 | +const float TS1_T4[128] = float[128]( |
| 589 | + -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1, |
| 590 | + 1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1, |
| 591 | + -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, |
| 592 | + 1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1, |
| 593 | + -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, 1, |
| 594 | + 1, -1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1, -1, |
| 595 | + -1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1, |
| 596 | + 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, 1 |
| 597 | +); |
| 598 | +const float TS2_T4[128] = float[128]( |
| 599 | + 1, 1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1, -1, |
| 600 | + 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, 1, 1, |
| 601 | + 1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, |
| 602 | + 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, |
| 603 | + 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, 1, |
| 604 | + -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, |
| 605 | + 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, |
| 606 | + -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1 |
| 607 | +); |
| 608 | +const float TINV_T4 = 0.08838834764831845; // 1 / sqrt(128) |
| 609 | +// Lloyd-Max centroids for N(0, 1/128), 16 levels (matches CENTROIDS_4BIT in C ref) |
| 610 | +const float TC4[16] = float[16]( |
| 611 | + -0.173926, -0.117195, -0.089527, -0.068756, |
| 612 | + -0.051262, -0.035597, -0.020989, -0.006938, |
| 613 | + 0.006938, 0.020989, 0.035597, 0.051262, |
| 614 | + 0.068756, 0.089527, 0.117195, 0.173926 |
| 615 | +); |
| 616 | +// 15 midpoints between adjacent centroids |
| 617 | +const float TM4[15] = float[15]( |
| 618 | + -0.145561, -0.103361, -0.079142, -0.060009, |
| 619 | + -0.043430, -0.028293, -0.013964, 0.0, |
| 620 | + 0.013964, 0.028293, 0.043430, 0.060009, |
| 621 | + 0.079142, 0.103361, 0.145561 |
| 622 | +); |
| 623 | + |
| 624 | +shared float wht_t4[128]; |
| 625 | +shared float sg_acc_t4[16]; |
| 626 | +shared float gnrm_t4; |
| 627 | + |
| 628 | +void main() { |
| 629 | + const uint t = gl_LocalInvocationID.x; |
| 630 | + const uint g = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; |
| 631 | + const uint gpr = p.ne00 / 128; |
| 632 | + |
| 633 | + if (gpr == 0) return; |
| 634 | + if (g >= p.ne / 128) return; |
| 635 | + |
| 636 | + uint tmp = g; |
| 637 | + const uint ig = tmp % gpr; tmp /= gpr; |
| 638 | + const uint i01 = tmp % p.ne01; tmp /= p.ne01; |
| 639 | + const uint i02 = tmp % p.ne12; |
| 640 | + const uint i03 = tmp / p.ne12; |
| 641 | + |
| 642 | + const uint sb = src0_idx(ig * 128, i01, i02, i03) + get_aoffset(); |
| 643 | + const uint i1 = data_i[src1_idx(i01, fastmod(i02, p.ne11), fastmod(i03, p.ne12), 0) + get_boffset()] DATA_I_SWIZZLE; |
| 644 | + const uint db = dst_idx(ig, i1, i02, i03) + get_doffset(); |
| 645 | + |
| 646 | + wht_t4[t] = data_s[sb + t]; |
| 647 | + barrier(); |
| 648 | + |
| 649 | + float v2 = wht_t4[t] * wht_t4[t]; |
| 650 | + v2 = subgroupAdd(v2); |
| 651 | + if (gl_SubgroupInvocationID == 0) sg_acc_t4[gl_SubgroupID] = v2; |
| 652 | + barrier(); |
| 653 | + if (t == 0) { |
| 654 | + float total = 0.0; |
| 655 | + for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t4[w]; |
| 656 | + gnrm_t4 = sqrt(total); |
| 657 | + } |
| 658 | + barrier(); |
| 659 | + |
| 660 | + wht_t4[t] *= (gnrm_t4 > 1e-10) ? (1.0 / gnrm_t4) : 0.0; |
| 661 | + barrier(); |
| 662 | + |
| 663 | + wht_t4[t] *= TS1_T4[t]; |
| 664 | + barrier(); |
| 665 | + |
| 666 | + [[unroll]] for (uint h = 1; h < 128; h *= 2) { |
| 667 | + if ((t % (2 * h)) < h) { |
| 668 | + float a = wht_t4[t]; |
| 669 | + float b = wht_t4[t + h]; |
| 670 | + wht_t4[t] = a + b; |
| 671 | + wht_t4[t + h] = a - b; |
| 672 | + } |
| 673 | + barrier(); |
| 674 | + } |
| 675 | + |
| 676 | + float rv = wht_t4[t] * TINV_T4 * TS2_T4[t]; |
| 677 | + |
| 678 | + // Quantize to nearest of 16 centroids (4-bit index, no signs byte) |
| 679 | + uint idx = 0u; |
| 680 | + [[unroll]] for (uint i = 0; i < 15; i++) { |
| 681 | + if (rv >= TM4[i]) idx = i + 1u; |
| 682 | + } |
| 683 | + |
| 684 | + // Pack qs: 2 elements per byte (4-bit nibble each) |
| 685 | + uint sg_lane = gl_SubgroupInvocationID; |
| 686 | + uint pair_low = subgroupShuffle(idx & 0xFu, sg_lane & ~1u); |
| 687 | + uint pair_high = subgroupShuffle(idx & 0xFu, (sg_lane & ~1u) + 1u); |
| 688 | + uint qs_byte = pair_low | (pair_high << 4u); |
| 689 | + if (sg_lane % 2u == 0u) { |
| 690 | + data_q[db].qs[t / 2u] = uint8_t(qs_byte); |
| 691 | + } |
| 692 | + |
| 693 | + // Reset rnorm field (reserved in 4-bit mode) |
| 694 | + if (t == 0u) { |
| 695 | + data_q[db].rnorm = float16_t(0.0); |
| 696 | + } |
| 697 | + |
| 698 | + // Reconstruction norm via subgroup reduction |
| 699 | + float rc = TC4[idx] * TC4[idx]; |
| 700 | + rc = subgroupAdd(rc); |
| 701 | + if (sg_lane == 0u) sg_acc_t4[gl_SubgroupID] = rc; |
| 702 | + barrier(); |
| 703 | + if (t == 0u) { |
| 704 | + float total = 0.0; |
| 705 | + for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc_t4[w]; |
| 706 | + float rn = sqrt(total); |
| 707 | + data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm_t4 / rn) : gnrm_t4); |
| 708 | + } |
| 709 | +} |
| 710 | + |
472 | 711 | #elif defined(SET_ROWS) && defined(DATA_A_TQ4_1S) |
473 | 712 |
|
474 | 713 | void main() { |
|
0 commit comments