|
1 | 1 | # Accumulator Precision |
2 | 2 |
|
3 | 3 | **Created**: 2026-04-15 14:00:00 |
4 | | -**Edited**: 2026-05-26 14:30:00 |
| 4 | +**Edited**: 2026-05-29 21:08:44 |
5 | 5 |
|
6 | 6 | WarpConvNet's mask_gemm kernels use tensor core MMA instructions that support two accumulator modes: |
7 | 7 |
|
@@ -69,12 +69,59 @@ output = spatially_sparse_conv( |
69 | 69 | ) |
70 | 70 | ``` |
71 | 71 |
|
| 72 | +## Automatic FP16 Accumulator under `torch.inference_mode()` |
| 73 | + |
| 74 | +When `use_fp16_accum` is left unresolved (per-module `None` **and** global |
| 75 | +setting `False`), the fp16 accumulator is **auto-enabled** if both of these |
| 76 | +hold at call time: |
| 77 | + |
| 78 | +1. The convolution runs inside a `torch.inference_mode()` context, and |
| 79 | +2. The effective compute dtype is half precision (`torch.float16` or |
| 80 | + `torch.bfloat16`). |
| 81 | + |
| 82 | +Rationale: the global default is fp32 accum to protect training convergence, |
| 83 | +a concern that does not apply at inference time. fp16 accum gives ~2x |
| 84 | +tensor-core throughput with no effect on training, so it is the better default |
| 85 | +for low-precision inference. |
| 86 | + |
| 87 | +```python |
| 88 | +import torch |
| 89 | +from warpconvnet.nn.modules.sparse_conv import SparseConv3d |
| 90 | + |
| 91 | +conv = SparseConv3d(64, 128, kernel_size=3) # use_fp16_accum=None |
| 92 | + |
| 93 | +# Training / no_grad eval: stays fp32 accum (global default) |
| 94 | +with torch.no_grad(): |
| 95 | + out = conv(x_fp16) # fp32 accumulator |
| 96 | + |
| 97 | +# inference_mode + half precision: auto-enabled fp16 accum |
| 98 | +with torch.inference_mode(): |
| 99 | + out = conv(x_fp16) # fp16 accumulator (auto) |
| 100 | + out = conv(x_bf16) # fp16 accumulator (auto) |
| 101 | + out = conv(x_fp32) # fp32 accumulator (not half precision) |
| 102 | +``` |
| 103 | + |
| 104 | +Notes: |
| 105 | + |
| 106 | +- **Only `torch.inference_mode()` triggers it.** Plain `torch.no_grad()` or a |
| 107 | + bare `model.eval()` will **not** auto-enable it. `model.eval()` only |
| 108 | + switches BN/dropout to eval behavior; gradient tracking stays on. |
| 109 | +- **Explicit settings still win.** A per-module `True`/`False` or a global |
| 110 | + `True` is respected. A global `False` is the default and counts as |
| 111 | + unresolved, so set `use_fp16_accum=False` on the module to keep fp32 accum. |
| 112 | +- **Separate autotune cache entry.** The autotune cache key includes |
| 113 | + `use_fp16_accum`, so the inference path (`fa=1`) and the training path |
| 114 | + (`fa=0`) keep independent entries. The first half-precision |
| 115 | + `inference_mode` pass warms a fresh cache entry rather than reusing the |
| 116 | + training one. |
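
The cache-entry note above can be illustrated with a small sketch. The key layout, `AutotuneKey`, `lookup_or_tune`, and `fake_tune` are all hypothetical names chosen for illustration, not WarpConvNet's actual cache implementation; the only detail taken from the text is that the key carries an `fa` flag.

```python
from typing import Callable, Dict, NamedTuple


class AutotuneKey(NamedTuple):
    # Illustrative fields only; not WarpConvNet's actual key layout.
    in_channels: int
    out_channels: int
    kernel_size: int
    dtype: str
    fa: int  # 1 = fp16 accumulator, 0 = fp32 accumulator


def lookup_or_tune(
    cache: Dict[AutotuneKey, str],
    key: AutotuneKey,
    tune: Callable[[AutotuneKey], str],
) -> str:
    """Return the cached choice for `key`, running `tune` only on a miss."""
    if key not in cache:
        cache[key] = tune(key)
    return cache[key]


cache: Dict[AutotuneKey, str] = {}
tuned = []  # records every key that triggered a (fake) autotune pass


def fake_tune(key: AutotuneKey) -> str:
    tuned.append(key)
    return f"config-fa{key.fa}"


train_key = AutotuneKey(64, 128, 3, "float16", fa=0)
infer_key = train_key._replace(fa=1)

lookup_or_tune(cache, train_key, fake_tune)  # training path: miss, tunes
lookup_or_tune(cache, infer_key, fake_tune)  # first inference pass: miss again
lookup_or_tune(cache, infer_key, fake_tune)  # second inference pass: hit
```

Because the two keys differ only in `fa`, the training entry is never reused for the inference path, which is why the first half-precision `inference_mode` pass pays a one-time tuning cost.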
| 117 | + |
72 | 118 | ## Precedence |
73 | 119 |
|
74 | 120 | 1. **Per-module** `use_fp16_accum=True/False` -- highest priority |
75 | 121 | 2. **Global runtime** `warpconvnet.set_fp16_accum(True)` -- used when module is `None` |
76 | 122 | 3. **Environment variable** `WARPCONVNET_USE_FP16_ACCUM=true` -- sets initial global value |
77 | | -4. **Default** `False` (fp32 accumulator) |
| 123 | +4. **Auto-enable** -- when unresolved (`None` + global `False`), fp16 accum is turned on inside `torch.inference_mode()` with half-precision (fp16/bf16) compute |
| 124 | +5. **Default** `False` (fp32 accumulator) |
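
As a rough illustration of the precedence order above, the resolution could be written as a single function. This is a minimal sketch under stated assumptions: `resolve_fp16_accum` and its parameters are hypothetical names, the dtype is passed as a plain string, and the library's real resolution code may be structured differently.

```python
from typing import Optional

HALF_DTYPES = {"float16", "bfloat16"}


def resolve_fp16_accum(
    module_setting: Optional[bool],
    global_setting: bool,
    in_inference_mode: bool,
    compute_dtype: str,
) -> bool:
    """Hypothetical helper mirroring the documented precedence order."""
    # 1. An explicit per-module True/False always wins.
    if module_setting is not None:
        return module_setting
    # 2-3. A global True (set at runtime or seeded from the env var) is next.
    if global_setting:
        return True
    # 4. Unresolved (None + global False): auto-enable only inside
    #    inference_mode with half-precision compute.
    if in_inference_mode and compute_dtype in HALF_DTYPES:
        return True
    # 5. Default: fp32 accumulator.
    return False
```

In real code the `in_inference_mode` argument could be supplied by `torch.is_inference_mode_enabled()`, which returns `False` under plain `torch.no_grad()`.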
78 | 125 |
|
79 | 126 | ## Small-Channel F16-Accum Pcoff Allowance |
80 | 127 |
|
@@ -140,7 +187,7 @@ Wgrad always uses fp32 accumulator regardless of this setting, since weight grad |
140 | 187 | ## Recommendations |
141 | 188 |
|
142 | 189 | - **Training**: Use fp32 accumulator (default). Switch to fp16 only after verifying convergence is unaffected on your model. The F16Acc tiles' per-step relative difference (up to ~7.5e-4 at C=256) accumulates across epochs and has been observed to slow convergence on ScanNet MinkUNet-style models. |
143 | | -- **Inference**: fp16 accumulator is safe and recommended for maximum throughput. |
| 190 | +- **Inference**: fp16 accumulator is safe and recommended for maximum throughput. Under `torch.inference_mode()` with half-precision (fp16/bf16) compute, it is enabled **automatically**; see [Automatic FP16 Accumulator under `torch.inference_mode()`](#automatic-fp16-accumulator-under-torchinference_mode). |
144 | 191 | - **Large channels (C >= 128)**: Largest speedup from fp16 accumulator (~15%). |
145 | 192 | - **Small channels (C \<= 32)**: Minimal benefit since the computation is memory-bound, not compute-bound. |
146 | 193 | - **After switching**: clear the cache (`rm ~/.cache/warpconvnet/benchmark_cache_generic.*`) so the pool change triggers a fresh autotune pass. |
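
For scripts that switch accumulator modes programmatically, the cache-clearing step above can also be done from Python. The sketch below assumes the cache lives at the path given in the `rm` command; `clear_generic_benchmark_cache` is a hypothetical helper, not part of WarpConvNet.

```python
from pathlib import Path
from typing import List, Optional


def clear_generic_benchmark_cache(cache_dir: Optional[Path] = None) -> List[Path]:
    """Delete benchmark_cache_generic.* files and return the removed paths."""
    if cache_dir is None:
        # Default location taken from the `rm` command in the text.
        cache_dir = Path.home() / ".cache" / "warpconvnet"
    removed: List[Path] = []
    if not cache_dir.is_dir():
        return removed  # nothing to clear
    for path in sorted(cache_dir.glob("benchmark_cache_generic.*")):
        if path.is_file():
            path.unlink()
            removed.append(path)
    return removed
```

Calling `clear_generic_benchmark_cache()` with no arguments targets the default directory; passing an explicit `cache_dir` is convenient for testing against a temporary directory.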