|
37 | 37 |
|
38 | 38 | #include "kernels/chunk_prefill/chunk_prefill_runner.hpp" |
39 | 39 | #include "kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp" |
40 | | -#include "kernels/flash_attention_v2/xe_fmha_fwd_prefill_runner.hpp" |
| 40 | +#include "kernels/flash_attention_v2/xe_fmha_fwd_prefill_dispatch.hpp" |
41 | 41 |
|
42 | 42 | namespace decode { |
43 | 43 |
|
@@ -423,6 +423,294 @@ std::vector<at::Tensor> mha_fwd( |
423 | 423 |
|
424 | 424 | } // namespace decode |
425 | 425 |
|
| 426 | +namespace prefill { |
| 427 | + |
| 428 | +// Dispatch macro following the same pattern as decode. |
| 429 | +// Directly calls FmhaPrefillRunner<HD>::operator() on the local variable named `params` - no function pointers. |
| 430 | + |
| 431 | +#define DISPATCH_PREFILL_KERNEL(HD) FmhaPrefillRunner<HD>{}(params) |
| 432 | +// Paged-KV prefill forward: q must be ragged (total_q, h, d) and page_table is required; returns {out, softmax_lse, out_accum, softmax_lse_accum}, where the two accum tensors are default-constructed and never assigned here. |
| 433 | +std::vector<at::Tensor> mha_fwd( |
| 434 | + const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q |
| 435 | + const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, |
| 436 | + // h_k, d) if there is page_table. |
| 437 | + const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, |
| 438 | + // page_size, h_k, dv) if there is page_table. |
| 439 | + std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q |
| 440 | + const at::Tensor& cu_seqlens_q, // b+1 |
| 441 | + const at::Tensor& cu_seqlens_k, // b+1 |
| 442 | + int max_seqlen_q, |
| 443 | + int max_seqlen_k, |
| 444 | + std::optional<const at::Tensor>& page_table, // (b_k, max_num_pages_per_seq) |
| 445 | + std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache |
| 446 | + std::optional<const at::Tensor>& leftpad_k_, // b |
| 447 | + std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2) |
| 448 | + std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2) |
| 449 | + std::optional<const at::Tensor>& seqlens_rotary_, // b |
| 450 | + std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h) |
| 451 | + std::optional<at::Tensor>& k_descale_, // (b, h_k) |
| 452 | + std::optional<at::Tensor>& v_descale_, // (b, h_k) |
| 453 | + const float softmax_scale_, |
| 454 | + std::optional<const at::Tensor>& sinks_, |
| 455 | + bool is_causal, |
| 456 | + int window_size_left, |
| 457 | + int window_size_right, |
| 458 | + float const softcap, |
| 459 | + bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 |
| 460 | + std::optional<at::Tensor>& scheduler_metadata_, // (b + 1) |
| 461 | + int num_splits, |
| 462 | + std::optional<bool> pack_gqa_, |
| 463 | + int const sm_margin) { |
| 464 | + auto q_type = q.scalar_type(); |
| 465 | + TORCH_CHECK( |
| 466 | + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, |
| 467 | + "mha_fwd only supports Half and BFloat16, got ", |
| 468 | + q_type); |
| 469 | + |
| 470 | + TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); |
| 471 | + TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); |
| 472 | + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q); |
| 473 | + CHECK_LAST_DIM_CONTIGUOUS_INPUT(k); |
| 474 | + CHECK_LAST_DIM_CONTIGUOUS_INPUT(v); |
| 475 | + TORCH_CHECK(page_table.has_value(), "page_table must be provided for prefill attention"); |
| 476 | + TORCH_CHECK(page_table.value().dtype() == torch::kInt32, "page_table must have dtype torch.int32"); |
| 477 | + TORCH_CHECK(page_table.value().stride(-1) == 1, "page_table must have contiguous last dimension"); |
| 478 | + |
| 479 | + TORCH_CHECK(q.dim() == 3, "query must be in ragged format"); |
| 480 | + CHECK_INPUT(cu_seqlens_q); |
| 481 | + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); |
| 482 | + |
| 483 | + CHECK_INPUT(cu_seqlens_k); |
| 484 | + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); |
| 485 | + |
| 486 | + auto const sizes = q.sizes(); |
| 487 | + const int batch_size = cu_seqlens_q.size(0) - 1; |
| 488 | + int seqlen_q = max_seqlen_q; |
| 489 | + int total_q = q.size(0); |
| 490 | + int num_heads = q.size(-2); |
| 491 | + int const head_size = q.size(-1); |
| 492 | + int const head_size_v = v.size(-1); |
| 493 | + int const max_num_pages_per_seq = page_table.value().size(1); |
| 494 | + int const num_pages = k.size(0); |
| 495 | + int const page_size = k.size(1); |
| 496 | + int const seqlen_k = max_num_pages_per_seq * page_size; |
| 497 | + int const total_k = num_pages * page_size; |
| 498 | + int const num_heads_k = k.size(-2); |
| 499 | + |
| 500 | + int const batch_size_k = page_table.value().size(0); |
| 501 | + float softmax_scale = softmax_scale_; |
| 502 | + |
| 503 | + if (!kv_batch_idx_.has_value()) { |
| 504 | + TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); |
| 505 | + } |
| 506 | + |
| 507 | + // Currently only support head dims <= 512 |
| 508 | + static constexpr int max_headdim = 512; |
| 509 | + TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most ", max_headdim); |
| 510 | + TORCH_CHECK(num_heads == num_heads_k, "Only support number of heads in key/value equals to number of heads in query"); |
| 511 | + |
| 512 | + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM |
| 513 | + // TODO: check this |
| 514 | + |
| 515 | + if (window_size_left >= seqlen_k - 1) { |
| 516 | + window_size_left = -1; |
| 517 | + } |
| 518 | + window_size_right = min(window_size_right, seqlen_q); |
| 519 | + // Causal masking is encoded as a right window of 0 (see params.is_causal below). |
| 520 | + if (is_causal) { |
| 521 | + window_size_right = 0; |
| 522 | + } |
| 523 | + |
| 524 | + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); |
| 525 | + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); |
| 526 | + CHECK_SHAPE(page_table.value(), batch_size_k, max_num_pages_per_seq); |
| 527 | + |
| 528 | + if (leftpad_k_.has_value()) { |
| 529 | + auto leftpad_k = leftpad_k_.value(); |
| 530 | + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); |
| 531 | + CHECK_INPUT(leftpad_k); |
| 532 | + CHECK_SHAPE(leftpad_k, batch_size); |
| 533 | + } |
| 534 | + |
| 535 | + static constexpr int alignment = 8; |
| 536 | + TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); |
| 537 | + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); |
| 538 | + |
| 539 | + auto opts = q.options(); |
| 540 | + at::Tensor out; |
| 541 | + out = torch::empty({total_q, num_heads, head_size_v}, opts); |
| 542 | + |
| 543 | + int const head_size_rounded = round_up_headdim(head_size); |
| 544 | + int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdim(head_size_v); |
| 545 | + |
| 546 | + // Make q's device the current device for the rest of this function; |
| 547 | + // otherwise the kernel would be launched on the default device. |
| 548 | + c10::DeviceGuard device_guard(q.device()); |
| 549 | + |
| 550 | + at::Tensor softmax_lse; |
| 551 | + softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); |
| 552 | + |
| 553 | + // align with FA3 |
| 554 | + Arguments params; |
| 555 | + params.is_bf16 = q.dtype() == torch::kBFloat16; |
| 556 | + |
| 557 | + // Set the pointers and strides. |
| 558 | + params.q_ptr = q.data_ptr(); |
| 559 | + params.k_ptr = k.data_ptr(); |
| 560 | + params.v_ptr = v.data_ptr(); |
| 561 | + // All strides are in elements, not bytes. |
| 562 | + params.q_row_stride = q.stride(-3); |
| 563 | + params.k_row_stride = k.stride(-3); |
| 564 | + params.v_row_stride = v.stride(-3); |
| 565 | + params.q_head_stride = q.stride(-2); |
| 566 | + params.k_head_stride = k.stride(-2); |
| 567 | + params.v_head_stride = v.stride(-2); |
| 568 | + params.v_dim_stride = v.stride(-1); |
| 569 | + params.o_ptr = out.data_ptr(); |
| 570 | + params.o_row_stride = out.stride(-3); |
| 571 | + params.o_head_stride = out.stride(-2); |
| 572 | + |
| 573 | + params.cu_seqlens_q = cu_seqlens_q.data_ptr<int>(); |
| 574 | + params.cu_seqlens_k = cu_seqlens_k.data_ptr<int>(); |
| 575 | + |
| 576 | + // Softmax sum |
| 577 | + params.softmax_lse_ptr = softmax_lse.data_ptr(); |
| 578 | + |
| 579 | + // Set the dimensions. |
| 580 | + params.b = batch_size; |
| 581 | + params.h = num_heads; |
| 582 | + params.h_k = num_heads_k; |
| 583 | + params.q_group_size = 1; |
| 584 | + params.seqlen_q = seqlen_q; |
| 585 | + params.seqlen_k = seqlen_k; |
| 586 | + params.d = head_size; |
| 587 | + params.d_rounded = head_size_rounded; |
| 588 | + |
| 589 | + // Set the different scale values. |
| 590 | + params.softmax_scale = softmax_scale; |
| 591 | + params.softmax_sink_ptr = sinks_.has_value() ? sinks_.value().data_ptr() : nullptr; |
| 592 | + |
| 593 | + params.softcap = softcap; |
| 594 | + |
| 595 | + // Set this to probability of keeping an element to simplify things. |
| 596 | + params.p_dropout = 1.f; |
| 597 | + |
| 598 | + // Causal is the special case where window_size_right == 0 and window_size_left < 0. |
| 599 | + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. |
| 600 | + params.is_causal = window_size_left < 0 && window_size_right == 0; |
| 601 | + params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; |
| 602 | + |
| 603 | + // TODO: check this |
| 604 | + if (window_size_left < 0) { |
| 605 | + window_size_left = seqlen_k - 1; |
| 606 | + } |
| 607 | + if (window_size_right < 0) { |
| 608 | + window_size_right = seqlen_q - 1; |
| 609 | + } |
| 610 | + params.window_size_left = window_size_left; |
| 611 | + params.window_size_right = window_size_right; |
| 612 | + params.total_q = total_q; |
| 613 | + params.total_k = total_k; |
| 614 | + params.b_k = batch_size_k; |
| 615 | + params.dv = head_size_v; |
| 616 | + params.page_table = page_table.value().data_ptr<int>(); |
| 617 | + params.page_table_batch_stride = page_table.value().stride(0); |
| 618 | + params.max_num_pages_per_seq = max_num_pages_per_seq; |
| 619 | + params.page_size = page_size; |
| 620 | + params.num_pages = num_pages; |
| 621 | + |
| 622 | + if (q_v_.has_value()) { |
| 623 | + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); |
| 624 | + TORCH_CHECK( |
| 625 | + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, |
| 626 | + "q_v is only supported for fp16 and bf16 data type"); |
| 627 | + TORCH_CHECK(false, "q_v is not supported yet"); |
| 628 | + at::Tensor q_v = q_v_.value(); |
| 629 | + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); |
| 630 | + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); |
| 631 | + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); |
| 632 | + params.qv_ptr = q_v.data_ptr(); |
| 633 | + // All strides are in elements, not bytes. |
| 634 | + params.qv_row_stride = q_v.stride(-3); |
| 635 | + params.qv_head_stride = q_v.stride(-2); |
| 636 | + } |
| 637 | + |
| 638 | + if (rotary_cos_.has_value()) { |
| 639 | + auto rotary_cos = rotary_cos_.value(); |
| 640 | + CHECK_INPUT(rotary_cos); |
| 641 | + params.rotary_dim = rotary_cos.size(1) * 2; |
| 642 | + TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); |
| 643 | + TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); |
| 644 | + const int seqlen_ro = rotary_cos.size(0); |
| 645 | + TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); |
| 646 | + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); |
| 647 | + TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); |
| 648 | + |
| 649 | + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); |
| 650 | + auto rotary_sin = rotary_sin_.value(); |
| 651 | + CHECK_INPUT(rotary_sin); |
| 652 | + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); |
| 653 | + TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_sin must have the same dtype as query"); |
| 654 | + params.rotary_cos_ptr = rotary_cos.data_ptr(); |
| 655 | + params.rotary_sin_ptr = rotary_sin.data_ptr(); |
| 656 | + params.is_rotary_interleaved = is_rotary_interleaved; |
| 657 | + if (seqlens_rotary_.has_value()) { |
| 658 | + at::Tensor seqlens_rotary = seqlens_rotary_.value(); |
| 659 | + CHECK_INPUT(seqlens_rotary); |
| 660 | + TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); |
| 661 | + CHECK_SHAPE(seqlens_rotary, batch_size); |
| 662 | + params.seqlens_rotary = seqlens_rotary.data_ptr<int>(); |
| 663 | + } |
| 664 | + } else { |
| 665 | + params.rotary_dim = 0; |
| 666 | + } |
| 667 | + |
| 668 | + if (kv_batch_idx_.has_value()) { |
| 669 | + auto kv_batch_idx = kv_batch_idx_.value(); |
| 670 | + CHECK_INPUT(kv_batch_idx); |
| 671 | + TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); |
| 672 | + params.kv_batch_idx = reinterpret_cast<int*>(kv_batch_idx.data_ptr()); |
| 673 | + } |
| 674 | + |
| 675 | + params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device()); |
| 676 | + |
| 677 | + at::Tensor out_accum, softmax_lse_accum; |
| 678 | + |
| 679 | + TORCH_CHECK( |
| 680 | + params.d == 64 || params.d == 96 || params.d == 128 || params.d == 192 || params.d == 256 || params.d == 512, |
| 681 | + "Unsupported head size for prefill attention: ", |
| 682 | + params.d); |
| 683 | + |
| 684 | + switch (params.d) { |
| 685 | + case 64: |
| 686 | + DISPATCH_PREFILL_KERNEL(64); |
| 687 | + break; |
| 688 | + case 96: |
| 689 | + DISPATCH_PREFILL_KERNEL(96); |
| 690 | + break; |
| 691 | + case 128: |
| 692 | + DISPATCH_PREFILL_KERNEL(128); |
| 693 | + break; |
| 694 | + case 192: |
| 695 | + DISPATCH_PREFILL_KERNEL(192); |
| 696 | + break; |
| 697 | + case 256: |
| 698 | + DISPATCH_PREFILL_KERNEL(256); |
| 699 | + break; |
| 700 | + case 512: |
| 701 | + DISPATCH_PREFILL_KERNEL(512); |
| 702 | + break; |
| 703 | + default: |
| 704 | + TORCH_CHECK(false, "Unsupported head size for prefill attention: ", params.d); |
| 705 | + } |
| 706 | + |
| 707 | + return {out, softmax_lse, out_accum, softmax_lse_accum}; |
| 708 | +} |
| 709 | + |
| 710 | +#undef DISPATCH_PREFILL_KERNEL |
| 711 | + |
| 712 | +} // namespace prefill |
| 713 | + |
426 | 714 | std::vector<at::Tensor> mha_fwd( |
427 | 715 | const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q |
428 | 716 | const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, |
|
0 commit comments