|
4 | 4 | #pragma once |
5 | 5 |
|
6 | 6 | #include "ck/utility/common_header.hpp" |
| 7 | +#include "ck/utility/env.hpp" |
7 | 8 | #include "ck/tensor_description/multi_index_transform_helper.hpp" |
8 | 9 | #include "ck/tensor_description/tensor_descriptor.hpp" |
9 | 10 | #include "ck/tensor_description/tensor_descriptor_helper.hpp" |
@@ -606,6 +607,203 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 |
606 | 607 | c_block_size * sizeof(CShuffleDataType)); |
607 | 608 | } |
608 | 609 |
|
| 610 | + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} |
| 611 | + __host__ static constexpr bool CheckValidity(const Argument& karg) |
| 612 | + { |
| 613 | + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && |
| 614 | + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, |
| 615 | + "Invalid tuning param!"); |
| 616 | + |
| 617 | + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || |
| 618 | + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || |
| 619 | + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || |
| 620 | + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && |
| 621 | + !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)) |
| 622 | + { |
| 623 | + if(!(karg.M % MPerBlock == 0)) |
| 624 | + { |
| 625 | + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 626 | + { |
| 627 | + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " |
| 628 | + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ |
| 629 | + << std::endl; |
| 630 | + } |
| 631 | + return false; |
| 632 | + } |
| 633 | + } |
| 634 | + |
| 635 | + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || |
| 636 | + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || |
| 637 | + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || |
| 638 | + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && |
| 639 | + (is_same<tensor_layout::gemm::RowMajor, BLayout>::value)) |
| 640 | + { |
| 641 | + if(!(karg.N % NPerBlock == 0)) |
| 642 | + { |
| 643 | + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 644 | + { |
| 645 | + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " |
| 646 | + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ |
| 647 | + << std::endl; |
| 648 | + } |
| 649 | + return false; |
| 650 | + } |
| 651 | + } |
| 652 | + |
| 653 | + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || |
| 654 | + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || |
| 655 | + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || |
| 656 | + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) |
| 657 | + { |
| 658 | + |
| 659 | + auto K_t = karg.KBatch * KPerBlock; |
| 660 | + if(!(karg.K % K_t == 0)) |
| 661 | + { |
| 662 | + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 663 | + { |
| 664 | + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " |
| 665 | + << karg.K << " " << __FILE__ << ":" << __LINE__ |
| 666 | + << ", in function: " << __func__ << std::endl; |
| 667 | + } |
| 668 | + return false; |
| 669 | + } |
| 670 | + } |
| 671 | + else |
| 672 | + { |
| 673 | + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); |
| 674 | + auto K_t = karg.KBatch * KReadVec; |
| 675 | + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; |
| 676 | + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) |
| 677 | + { |
| 678 | + return false; |
| 679 | + } |
| 680 | + } |
| 681 | + |
| 682 | + if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) |
| 683 | + { |
| 684 | + if(karg.K % ABlockTransferSrcScalarPerVector != 0) |
| 685 | + { |
| 686 | + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 687 | + { |
| 688 | + std::cout << "Arg K (" << karg.K |
| 689 | + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" |
| 690 | + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" |
| 691 | + << __LINE__ << ", in function: " << __func__ << std::endl; |
| 692 | + } |
| 693 | + return false; |
| 694 | + } |
| 695 | + } |
| 696 | + else |
| 697 | + { |
| 698 | + if(karg.M % ABlockTransferSrcScalarPerVector != 0) |
| 699 | + { |
| 700 | + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 701 | + { |
| 702 | + std::cout << "Arg M (" << karg.M |
| 703 | + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" |
| 704 | + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" |
| 705 | + << __LINE__ << ", in function: " << __func__ << std::endl; |
| 706 | + } |
| 707 | + return false; |
| 708 | + } |
| 709 | + } |
| 710 | + |
| 711 | + if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) |
| 712 | + { |
| 713 | + if(karg.N % BBlockTransferSrcScalarPerVector != 0) |
| 714 | + { |
| 715 | + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 716 | + { |
| 717 | + std::cout << "Arg N (" << karg.N |
| 718 | + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" |
| 719 | + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" |
| 720 | + << __LINE__ << ", in function: " << __func__ << std::endl; |
| 721 | + } |
| 722 | + return false; |
| 723 | + } |
| 724 | + } |
| 725 | + else |
| 726 | + { |
| 727 | + if(karg.K % BBlockTransferSrcScalarPerVector != 0) |
| 728 | + { |
| 729 | + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 730 | + { |
| 731 | + std::cout << "Arg K (" << karg.K |
| 732 | + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" |
| 733 | + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" |
| 734 | + << __LINE__ << ", in function: " << __func__ << std::endl; |
| 735 | + } |
| 736 | + return false; |
| 737 | + } |
| 738 | + } |
| 739 | + |
| 740 | + if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value) |
| 741 | + { |
| 742 | + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) |
| 743 | + { |
| 744 | + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 745 | + { |
| 746 | + std::cout << "Arg N (" << karg.N |
| 747 | + << ") value is not a multiple of " |
| 748 | + "CShuffleBlockTransferScalarPerVector_NPerBlock (" |
| 749 | + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " |
| 750 | + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ |
| 751 | + << std::endl; |
| 752 | + } |
| 753 | + return false; |
| 754 | + } |
| 755 | + } |
| 756 | + else |
| 757 | + { |
| 758 | + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) |
| 759 | + { |
| 760 | + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 761 | + { |
| 762 | + std::cout << "Arg M (" << karg.M |
| 763 | + << ") value is not a multiple of " |
| 764 | + "CShuffleBlockTransferScalarPerVector_NPerBlock (" |
| 765 | + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " |
| 766 | + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ |
| 767 | + << std::endl; |
| 768 | + } |
| 769 | + return false; |
| 770 | + } |
| 771 | + } |
| 772 | + |
| 773 | + if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value || |
| 774 | + is_same<remove_cvref_t<CDataType>, float>::value || |
| 775 | + is_same<remove_cvref_t<CDataType>, bhalf_t>::value || |
| 776 | + is_same<remove_cvref_t<CDataType>, int32_t>::value)) |
| 777 | + { |
| 778 | + if(!karg.IsReduceAdd()) |
| 779 | + { |
| 780 | + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 781 | + { |
| 782 | + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ |
| 783 | + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; |
| 784 | + } |
| 785 | + if(karg.KBatch > 1) |
| 786 | + { |
| 787 | + return false; |
| 788 | + } |
| 789 | + } |
| 790 | + } |
| 791 | + |
| 792 | + // check gridwise gemm pipeline |
| 793 | + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); |
| 794 | + |
| 795 | + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) |
| 796 | + { |
| 797 | + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) |
| 798 | + { |
| 799 | + return false; |
| 800 | + } |
| 801 | + } |
| 802 | + |
| 803 | + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) |
| 804 | + return true; |
| 805 | + } |
| 806 | + |
609 | 807 | __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) |
610 | 808 | { |
611 | 809 | const index_t num_loop = K / KPerBlock; |
|
0 commit comments