Commit 0da5dda

Document inference fp16 accumulator default
1 parent 748a5ad commit 0da5dda

2 files changed

Lines changed: 55 additions & 8 deletions

docs/user_guide/accumulator_precision.md

Lines changed: 50 additions & 3 deletions
@@ -1,7 +1,7 @@
 # Accumulator Precision

 **Created**: 2026-04-15 14:00:00
-**Edited**: 2026-05-26 14:30:00
+**Edited**: 2026-05-29 21:08:44

 WarpConvNet's mask_gemm kernels use tensor core MMA instructions that support two accumulator modes:

@@ -69,12 +69,59 @@ output = spatially_sparse_conv(
 )
 ```

+## Automatic FP16 Accumulator under `torch.inference_mode()`
+
+When `use_fp16_accum` is left unresolved (per-module `None` **and** global
+setting `False`), the fp16 accumulator is **auto-enabled** if both of these
+hold at call time:
+
+1. The convolution runs inside a `torch.inference_mode()` context, and
+2. The effective compute dtype is half precision (`torch.float16` or
+   `torch.bfloat16`).
+
+Rationale: the global default is fp32 accum for training-convergence safety,
+but that concern is moot during inference. fp16 accum gives ~2x tensor-core
+throughput at no training cost, so it is the better default for low-precision
+inference.
+
+```python
+import torch
+from warpconvnet.nn.modules.sparse_conv import SparseConv3d
+
+conv = SparseConv3d(64, 128, kernel_size=3)  # use_fp16_accum=None
+
+# Training / no_grad eval: stays fp32 accum (global default)
+with torch.no_grad():
+    out = conv(x_fp16)  # fp32 accumulator
+
+# inference_mode + half precision: auto-enabled fp16 accum
+with torch.inference_mode():
+    out = conv(x_fp16)  # fp16 accumulator (auto)
+    out = conv(x_bf16)  # fp16 accumulator (auto)
+    out = conv(x_fp32)  # fp32 accumulator (not half precision)
+```
+
+Notes:
+
+- **Only `torch.inference_mode()` triggers it.** Plain `torch.no_grad()` or a
+  bare `model.eval()` (which does *not* disable grad) will **not** auto-enable
+  it. `model.eval()` only flips BN/dropout; grad stays on.
+- **Explicit settings still win.** A per-module or global `True`/`False` is
+  respected; the auto-enable only fills in the unresolved (`None` + global
+  `False`) case.
+- **Separate autotune cache entry.** The autotune cache key includes
+  `use_fp16_accum`, so the inference path (`fa=1`) and the training path
+  (`fa=0`) keep independent entries. The first half-precision
+  `inference_mode` pass warms a fresh cache entry rather than reusing the
+  training one.
+
 ## Precedence

 1. **Per-module** `use_fp16_accum=True/False` -- highest priority
 2. **Global runtime** `warpconvnet.set_fp16_accum(True)` -- used when module is `None`
 3. **Environment variable** `WARPCONVNET_USE_FP16_ACCUM=true` -- sets initial global value
-4. **Default** `False` (fp32 accumulator)
+4. **Auto-enable** -- when unresolved (`None` + global `False`), fp16 accum is turned on inside `torch.inference_mode()` with half-precision (fp16/bf16) compute
+5. **Default** `False` (fp32 accumulator)

 ## Small-Channel F16-Accum Pcoff Allowance

@@ -140,7 +187,7 @@ Wgrad always uses fp32 accumulator regardless of this setting, since weight grad
 ## Recommendations

 - **Training**: Use fp32 accumulator (default). Switch to fp16 only after verifying convergence is unaffected on your model. The F16Acc tiles' per-step relative difference (up to ~7.5e-4 at C=256) accumulates across epochs and has been observed to slow convergence on ScanNet MinkUNet-style models.
-- **Inference**: fp16 accumulator is safe and recommended for maximum throughput.
+- **Inference**: fp16 accumulator is safe and recommended for maximum throughput. Under `torch.inference_mode()` with half-precision (fp16/bf16) compute it is now enabled **automatically**; see [Automatic FP16 Accumulator under `torch.inference_mode()`](#automatic-fp16-accumulator-under-torchinference_mode).
 - **Large channels (C >= 128)**: Largest speedup from fp16 accumulator (~15%).
 - **Small channels (C \<= 32)**: Minimal benefit since the computation is memory-bound, not compute-bound.
 - **After switching**: clear the cache (`rm ~/.cache/warpconvnet/benchmark_cache_generic.*`) so the pool change triggers a fresh autotune pass.

warpconvnet/nn/functional/sparse_conv/helper.py

Lines changed: 5 additions & 5 deletions
@@ -261,13 +261,13 @@ def spatially_sparse_conv(

     # Inference default: enable fp16 accumulator for low-precision
     # inference even when the global setting is False. Training stability
-    # (the reason fp32 accum is the global default) is moot with grad
-    # disabled, and fp16 accum gives ~2x tensor-core throughput. Gated on
-    # half-precision compute (float16 or bfloat16) and a grad-disabled
-    # context covering both torch.no_grad() and torch.inference_mode().
+    # (the reason fp32 accum is the global default) is moot here, and
+    # fp16 accum gives ~2x tensor-core throughput. Gated on half-precision
+    # compute (float16 or bfloat16) inside a torch.inference_mode()
+    # context only (not plain torch.no_grad()).
     if (
         not use_fp16_accum
-        and not torch.is_grad_enabled()
+        and torch.is_inference_mode_enabled()
         and effective_compute_dtype in (torch.float16, torch.bfloat16)
     ):
         use_fp16_accum = True
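
The documented precedence and the new gating condition can be modeled together in a small standalone sketch. This is not WarpConvNet code: `resolve_fp16_accum`, `HALF_DTYPES`, and the string dtype names are hypothetical stand-ins, and the hunk above does not show how `use_fp16_accum` is resolved before the check, so the sketch follows the precedence list from the documentation change (an explicit per-module value wins, then a global `True`, then the inference-mode auto-enable, then the fp32 default).

```python
from typing import Optional

# Hypothetical stand-ins for torch.float16 / torch.bfloat16 so the sketch
# runs without PyTorch installed.
HALF_DTYPES = {"float16", "bfloat16"}


def resolve_fp16_accum(
    module_setting: Optional[bool],
    global_setting: bool,
    inference_mode: bool,
    compute_dtype: str,
) -> bool:
    """Model the documented precedence for the accumulator choice.

    Assumption: this mirrors the documentation's precedence list, not the
    library's actual resolution code, which is not shown in the diff.
    """
    # 1. An explicit per-module True/False has the highest priority.
    if module_setting is not None:
        return module_setting
    # 2./3. A global True (set at runtime or via the environment variable).
    if global_setting:
        return True
    # 4. Auto-enable: inference mode plus half-precision compute.
    if inference_mode and compute_dtype in HALF_DTYPES:
        return True
    # 5. Default: fp32 accumulator.
    return False
```

In this model, `resolve_fp16_accum(None, False, True, "bfloat16")` selects the fp16 accumulator, while the same call with `inference_mode=False` (standing in for plain `torch.no_grad()` or training) keeps the fp32 default.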
