Use input_precision="ieee" in matmul_kernel_persistent for fp32 inputs#24

Open
chenwenxiaolive wants to merge 1 commit into
thinking-machines-lab:mainfrom
chenwenxiaolive:fix/fp32-tf32-input-precision

Use input_precision="ieee" in matmul_kernel_persistent for fp32 inputs

Closes #23

Problem

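matmul_kernel_persistent calls tl.dot(a, b, accumulator) without passing input_precision. On NVIDIA Tensor Core GPUs, tl.dot defaults to TF32 for fp32 inputs (Triton docs). TF32 keeps a 10-bit explicit mantissa (11 significant bits counting the implicit leading one), so consecutive integers are exactly representable only up to 2¹¹ = 2048.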
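To make the 2048 limit concrete, here is a minimal CPU-only sketch that emulates TF32 input rounding by clearing the low 13 mantissa bits of an fp32 value. The helper name round_to_tf32 is hypothetical, and round-to-nearest-even is an assumption about the rounding mode; this is not the kernel's code path, only an illustration of the format.

```python
import struct


def round_to_tf32(x: float) -> float:
    """Round an fp32 value to TF32 precision (10 explicit mantissa bits).

    Assumption: round-to-nearest-even on the 13 dropped bits; real hardware
    may round differently. Intended for small finite values only.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 13) & 1                     # lowest mantissa bit that survives
    bits = (bits + 0x0FFF + lsb) & 0xFFFFE000  # round, then clear the dropped bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


values = [float(n) for n in range(8192)]
rounded = [round_to_tf32(v) for v in values]

# Count integers that do not survive the round trip, and find the first one.
mismatches = sum(r != v for r, v in zip(rounded, values))
first_bad = next(v for r, v in zip(rounded, values) if r != v)

print(mismatches, first_bad, rounded[2049])  # 4096 2049.0 2048.0
```

Under this emulation, half of the integers in [2048, 4096) and three quarters of those in [4096, 8192) are altered, which accounts for the 4096 mismatches in the reproduction below.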

This kernel backs aten::{mm, addmm} under set_batch_invariant_mode(True). PyTorch's own torch.mm runs fp32 in IEEE precision by default (torch.backends.cuda.matmul.allow_tf32 has defaulted to False since 1.12), so enabling batch-invariant mode silently downgrades fp32 matmul precision. That is an accuracy change unrelated to batch invariance.

import torch
from batch_invariant_ops import set_batch_invariant_mode

pos  = torch.arange(8192, device="cuda", dtype=torch.float32)  # exact in fp32
ones = torch.ones(1, 1, device="cuda", dtype=torch.float32)
with set_batch_invariant_mode(True):
    out = torch.mm(ones, pos[None, :])
# Before this fix: 4096 mismatching elements, and 2049 comes back as 2048.0.
print(int((out[0] != pos).sum()), out[0, 2049].item())   # 4096 2048.0

The standard rotary-embedding phase is an fp32 matmul (positions ⊗ inv_freq). Through this kernel with real Llama-3.1 params (head_dim=128, rope_theta=500000), cos() of the phase diverges from the IEEE reference by up to 1.35 over a 4096-token context (and exactly 0.0 with this fix).
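The effect on the rotary phase can be illustrated without a GPU. The sketch below is an emulation under stated assumptions, not the kernel's arithmetic: it rounds both operands to TF32 with a hypothetical round_to_tf32 helper (round-to-nearest-even assumed), uses the plain inv_freq = rope_theta^(-2i/head_dim) formula with no RoPE scaling, and forms each product in Python floats. It is not expected to reproduce the 1.35 figure above, only to show that the divergence in cos() is of order one.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_to_tf32(x: float) -> float:
    """Emulate TF32 input rounding (assumed round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 13) & 1
    bits = (bits + 0x0FFF + lsb) & 0xFFFFE000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


HEAD_DIM = 128         # from the PR description
ROPE_THETA = 500000.0  # from the PR description
CONTEXT = 4096         # token positions 0..4095

# Plain RoPE inverse frequencies, stored as fp32 (no RoPE scaling: an assumption).
inv_freq = [to_fp32(ROPE_THETA ** (-2.0 * i / HEAD_DIM)) for i in range(HEAD_DIM // 2)]

max_err = 0.0
for pos in range(CONTEXT):
    pos_tf32 = round_to_tf32(float(pos))
    for f in inv_freq:
        exact = pos * f                       # reference phase from fp32 inputs
        approx = pos_tf32 * round_to_tf32(f)  # phase from TF32-rounded inputs
        max_err = max(max_err, abs(math.cos(exact) - math.cos(approx)))

print(max_err)
```

Even the highest-frequency channel alone (inv_freq[0] == 1.0) shows the problem: any odd position above 2048 is shifted by a full radian before cos() is taken.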

Change

-            accumulator = tl.dot(a, b, accumulator)
+            accumulator = tl.dot(a, b, accumulator, input_precision="ieee")

input_precision="ieee" is a no-op for bf16/fp16 (TF32 only applies to fp32 inputs), so the typical GEMM path is unchanged.

Test

Adds test_fp32_matmul_precision() to test_batch_invariance.py — asserts fp32 arange(8192) round-trips exactly through torch.mm under batch-invariant mode. Fails before this change (4096 mismatches), passes after.

Batch-Invariant Mode:
Batch Deterministic: True ... for torch.float32 in 10 iterations
Batch Deterministic: True ... for torch.bfloat16 in 10 iterations

fp32 precision (TF32 regression):
  fp32 matmul mismatches vs IEEE reference: 0 (expected 0)

tl.dot defaults to TF32 for fp32 inputs on NVIDIA Tensor Cores (10-bit
mantissa; integers exact only to 2**11=2048). matmul_kernel_persistent backs
aten::{mm,addmm} under batch-invariant mode, so enabling the mode silently
downgrades every fp32 matmul from IEEE (torch.mm's default) to TF32, corrupting
values > 2048 -- e.g. the rotary-embedding phase (positions @ inv_freq) for long
context.

Add input_precision="ieee" so fp32 keeps full precision. No-op for bf16/fp16
(TF32 only applies to fp32 inputs), so the typical GEMM path is unchanged.

Adds test_fp32_matmul_precision() to test_batch_invariance.py.

Closes thinking-machines-lab#23

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

Development

Successfully merging this pull request may close these issues.

matmul_kernel_persistent runs fp32 inputs through TF32 — silent precision loss vs torch.mm