Add W4A8 INT8 activation kernels for batched MoE prefill #19187
digantdesai wants to merge 5 commits into gh/digantdesai/50/base
Conversation
INT8 tensor core variants of the batched MoE GEMM kernels that dynamically quantize bf16 activations to INT8 per-row per-tile and dequantize INT4 weights directly to INT8 (skipping bf16 conversion). Uses tl.dot(int8, int8) → int32 accumulation with per-tile float32 rescale. 1.7× MoE speedup on A100 at M=1024 with 0.9998 cosine similarity vs bf16 baseline. Co-authored-by: Claude <noreply@anthropic.com> [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19187
Note: Links to docs will display an error until the docs builds have been completed.
❌ 7 New Failures, 5 Cancelled Jobs, 3 Unrelated Failures
As of commit 2b1e1eb with merge base cb4e5ae:
NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Gasoonjia left a comment:
Thanks for your work! Can you also update the CI to use the int8 activation type for MoE prefill?
| help="Disable split-K (flash-decoding) SDPA for decode; use tiled SDPA instead.", | ||
| ) | ||
| parser.add_argument( | ||
| "--moe-activation-dtype", |
Maybe calling it `prefill-moe-activation-dtype` would be better?
I didn't do that because skipping int8 for decode is something we may revisit later.
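For context, here is a minimal sketch of how such a flag could be declared; the default, choices, and help text below are assumptions for illustration, not taken from the PR:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--moe-activation-dtype",
    type=str,
    default="bf16",            # assumed default; the PR may differ
    choices=["bf16", "int8"],  # assumed choices
    help="Activation dtype for the batched MoE GEMMs "
         "(int8 selects the W4A8 tensor-core prefill path).",
)
args = parser.parse_args(["--moe-activation-dtype", "int8"])
print(args.moe_activation_dtype)  # -> int8
```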
Stack from ghstack (oldest at bottom):
INT8 tensor core variants of the batched MoE GEMM kernels that
dynamically quantize bf16 activations to INT8 per-row per-tile and
dequantize INT4 weights directly to INT8 (skipping bf16 conversion).
Uses tl.dot(int8, int8) → int32 accumulation with per-tile float32
rescale. 1.7× MoE speedup on A100 at M=1024 with 0.9998 cosine
similarity vs bf16 baseline.
Co-authored-by: Claude <noreply@anthropic.com>
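The description above is the whole numeric recipe. A minimal plain-PyTorch sketch of those numerics (not the Triton kernel itself) may make it concrete; the function name, tile size, and per-output-channel `w_scale` below are illustrative assumptions, and the PR's kernel additionally unpacks packed INT4 nibbles in-kernel and uses the batched MoE layout, which this reference skips:

```python
import torch
import torch.nn.functional as F

def w4a8_tile_matmul_reference(a_bf16, b_int8, w_scale, tile_k=128):
    """Emulate the kernel's numerics: per K-tile, dynamically quantize the
    activation rows to int8, multiply against int8 weights with int32
    accumulation (what tl.dot(int8, int8) does on tensor cores), then
    rescale the partial product in float32."""
    M, K = a_bf16.shape
    out = torch.zeros(M, b_int8.shape[1], dtype=torch.float32)
    for k0 in range(0, K, tile_k):
        a_tile = a_bf16[:, k0:k0 + tile_k].float()
        # Per-row dynamic quantization of this activation tile.
        a_scale = a_tile.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        a_q = torch.round(a_tile / a_scale).clamp(-127, 127).to(torch.int8)
        # int8 x int8 -> int32 accumulation.
        acc = a_q.to(torch.int32) @ b_int8[k0:k0 + tile_k].to(torch.int32)
        # Per-tile float32 rescale: activation row scale * weight column scale.
        out += acc.float() * a_scale * w_scale
    return out

# Toy check against a float reference (shapes and scales are illustrative).
M, K, N = 64, 256, 128
a = torch.randn(M, K, dtype=torch.bfloat16)
b_q = torch.randint(-8, 8, (K, N), dtype=torch.int8)  # int4 values stored as int8
w_scale = torch.rand(1, N) * 0.05                     # per-output-channel weight scale
ref = a.float() @ (b_q.float() * w_scale)
out = w4a8_tile_matmul_reference(a, b_q, w_scale)
print(F.cosine_similarity(out.flatten(), ref.flatten(), dim=0))
```

Quantizing per row per tile keeps the int8 range matched to each tile's dynamic range, which is what lets the per-tile float32 rescale recover accuracy close to the bf16 baseline.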