Skip to content

Commit d931e87

Browse files
Micky774 authored and assistant-librarian[bot] committed
[rocm-libraries] ROCm/rocm-libraries#6867 (commit 3cb0219)
Added custom FMHA codegen receipt for TransformerEngine (#6867)

## Motivation

TE uses AITER to build static MHA libraries, which ultimately rely on CK kernels. We use the `600` receipt, which generates more kernels than TE truly needs. This bespoke receipt allows us to minimize the kernel count, compile time, and memory footprint of our MHA library.

## Technical Details

Extended the receipt mechanism to include a custom `700` receipt for TE's needs.

## Test Plan

Test by building TE using the same receipt profile.

## Test Result

Build validated in TE using custom feature branches of AITER/CK to temporarily apply the patch.

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent 83566ed commit d931e87

6 files changed

Lines changed: 29 additions & 1 deletion

File tree

example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -878,6 +878,8 @@ def get_fwd_blobs(
878878
cond &= pipeline.F_qscale == "no"
879879
if not cond:
880880
continue
881+
elif receipt == 700:
882+
continue # TE does not use this API
881883

882884
# fp32 only
883885
if receipt == 800 or receipt == 801:

example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1149,6 +1149,12 @@ def get_bwd_blobs(
11491149
cond = dtype in ["fp16", "bf16"]
11501150
if not cond:
11511151
continue
1152+
# TransformerEngine integration
1153+
elif receipt == 700:
1154+
cond = dtype in ["fp16", "bf16"]
1155+
cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"]
1156+
if not cond:
1157+
continue
11521158

11531159
# fp32 only, all variations
11541160
if receipt == 800:

example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1454,6 +1454,20 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
14541454
return cond
14551455

14561456
return Product(name="aiter::mha_fwd C++ api integration", rule=fit)
1457+
# TransformerEngine integration
1458+
elif receipt == 700:
1459+
1460+
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
1461+
cond = problem_ctx.dtype in ["fp16", "bf16"]
1462+
cond &= kernel_ctx.pipeline.F_vlayout == "row"
1463+
cond &= kernel_ctx.pipeline.F_qscale == "no"
1464+
cond &= kernel_ctx.pipeline.F_lse == "t"
1465+
cond &= kernel_ctx.pipeline.F_skip == "f"
1466+
cond &= kernel_ctx.pipeline.F_sink == "f"
1467+
cond &= kernel_ctx.pipeline.F_logits == "f"
1468+
return cond
1469+
1470+
return Product(name="TransformerEngine integration", rule=fit)
14571471
elif receipt == 888:
14581472

14591473
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:

example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -970,6 +970,8 @@ def get_fwd_splitkv_blobs(
970970
cond &= pipeline.F_squant == "f"
971971
if not cond:
972972
continue
973+
elif receipt == 700:
974+
continue # TE does not use this API
973975

974976
# fp32 only
975977
if receipt == 800 or receipt == 801:

example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -745,6 +745,8 @@ def get_fwd_blobs(
745745
cond &= pipeline.F_squant == "f"
746746
if not cond:
747747
continue
748+
elif receipt == 700:
749+
continue # TE does not use this API
748750

749751
# fp32 only
750752
if receipt == 800 or receipt == 801:

example/ck_tile/01_fmha/generate.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,9 @@ def list_blobs(
139139
+ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
140140
+ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
141141
+ " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
142-
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
142+
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration\n"
143+
+ " 700: Only generate instance for TransformerEngine integration (fwd + bwd, fp16/bf16 only,\n"
144+
+ " invariants: row vlayout, has_lse, no skip/sink/logits/qscale)",
143145
)
144146

145147
parser.add_argument(

0 commit comments

Comments
 (0)