Use input_precision="ieee" in matmul_kernel_persistent for fp32 inputs #24
Open
chenwenxiaolive wants to merge 1 commit into
tl.dot defaults to TF32 for fp32 inputs on NVIDIA Tensor Cores (10-bit
mantissa; integers exact only to 2**11=2048). matmul_kernel_persistent backs
aten::{mm,addmm} under batch-invariant mode, so enabling the mode silently
downgrades every fp32 matmul from IEEE (torch.mm's default) to TF32, corrupting
values > 2048 -- e.g. the rotary-embedding phase (positions @ inv_freq) for long
context.
Add input_precision="ieee" so fp32 keeps full precision. No-op for bf16/fp16
(TF32 only applies to fp32 inputs), so the typical GEMM path is unchanged.
Adds test_fp32_matmul_precision() to test_batch_invariance.py.
Closes thinking-machines-lab#23
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Use input_precision="ieee" in matmul_kernel_persistent for fp32 inputs
Closes #23
Problem
matmul_kernel_persistent calls tl.dot(a, b, accumulator) without input_precision. On NVIDIA Tensor-Core GPUs, tl.dot defaults to TF32 for fp32 inputs (Triton docs): a 10-bit mantissa, so integers are exact only to 2¹¹ = 2048.

This kernel backs aten::{mm, addmm} under set_batch_invariant_mode(True). PyTorch's own torch.mm runs fp32 in IEEE by default (allow_tf32 == False since 1.12), so enabling batch-invariant mode silently downgrades fp32 matmul precision, an accuracy change unrelated to batch invariance.

The standard rotary-embedding phase is an fp32 matmul (positions ⊗ inv_freq). Through this kernel with real Llama-3.1 params (head_dim=128, rope_theta=500000), cos() of the phase diverges from the IEEE reference by up to 1.35 over a 4096-token context (and exactly 0.0 with this fix).

Change
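The change adds input_precision="ieee" to the tl.dot call. What that flag buys can be modelled on the CPU without Triton or a GPU. The sketch below is not the kernel code: to_tf32 and phase are hypothetical helpers, TF32 rounding is modelled as plain truncation of the fp32 mantissa (hardware may round to nearest), and the frequencies are plain unscaled RoPE frequencies built from the head_dim and rope_theta quoted above. It only illustrates that rounding the inputs to TF32 visibly shifts cos() of the rotary phase, and makes no attempt to reproduce the exact figure measured on the GPU.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (a double) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_tf32(x: float) -> float:
    """Model TF32 input rounding: clear the low 13 of fp32's 23 mantissa bits.

    Truncation is an assumption made for simplicity; real hardware may
    round to nearest instead.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFE000))[0]


def phase(position: float, inv_freq: float, input_precision: str) -> float:
    """Hypothetical scalar stand-in for one positions x inv_freq product."""
    if input_precision == "tf32":
        position, inv_freq = to_tf32(position), to_tf32(inv_freq)
    # The product of two fp32 values is exact in a double, so rounding it
    # once to fp32 matches an fp32 multiply.
    return to_fp32(position * inv_freq)


head_dim, rope_theta, context = 128, 500000.0, 4096
inv_freqs = [
    to_fp32(rope_theta ** (-2.0 * i / head_dim)) for i in range(head_dim // 2)
]

# Largest gap between cos() of the full-precision and TF32-rounded phase.
max_diff = max(
    abs(math.cos(phase(float(p), f, "ieee")) - math.cos(phase(float(p), f, "tf32")))
    for p in range(context)
    for f in inv_freqs
)
print(f"max |cos(ieee) - cos(tf32)| = {max_diff:.3f}")
```

Even the first frequency (inv_freq = 1.0, which is exact in both formats) is affected: position 4095 becomes 4094 once its mantissa is cut to TF32 width, so the phase is off by a full radian.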
input_precision="ieee" is a no-op for bf16/fp16 (TF32 only applies to fp32 inputs), so the typical GEMM path is unchanged.

Test
Adds test_fp32_matmul_precision() to test_batch_invariance.py, which asserts that fp32 arange(8192) round-trips exactly through torch.mm under batch-invariant mode. Fails before this change (4096 mismatches), passes after.
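The 4096-mismatch figure can be sanity-checked on the CPU. The sketch below is not the test itself (its body is not shown here). It assumes the round trip leaves each integer unchanged apart from TF32 rounding of the input, and tf32_truncate is a hypothetical helper that models that rounding as truncation. Which integers are representable does not depend on the rounding mode, so the count should hold either way.

```python
import struct


def tf32_truncate(x: float) -> float:
    """Keep 10 explicit mantissa bits of the fp32 encoding (assumed model of TF32)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFE000))[0]


N = 8192
# Integers in [0, N) that do not survive a trip through the TF32 format.
mismatched = [n for n in range(N) if tf32_truncate(float(n)) != float(n)]

print(len(mismatched), mismatched[0])  # 4096 2049
```

The count breaks down as 1024 odd integers in [2048, 4096), where the spacing is 2, plus 3072 non-multiples of 4 in [4096, 8192), where the spacing is 4.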