| 1 | +# Benchmark: Batch-Level Decoding |
| 2 | + |
| 3 | +## Summary |
| 4 | + |
| 5 | +Replaces the per-item `for i in range(B)` decoding loop with a single set of |
| 6 | +batch-wide tensor operations. This reduces CUDA kernel launches from `B * 8` to |
| 7 | +`~8` regardless of batch size, eliminating the dominant overhead in GPU decoding. |
| 8 | + |
| 9 | +**GPU (bs>=8):** 63-95% reduction in decoder time (median 85%), all statistically significant.
| 10 | +**GPU (bs=1):** Neutral (fast path delegates to per-item decoder). |
| 11 | +**CPU (bs>=8):** Mixed. Improvements at very_long inputs (24-42%), short inputs with bs=8/16 (62-70%), and
| 12 | +long inputs with bs=16 (17%); regressions at short bs=32, all medium inputs, and long bs=8, where 4D
| 13 | +`torch.where` has higher fixed overhead than B separate 3D calls. Absolute CPU regressions are 3-5ms.
| 14 | + |
| 15 | +## Environment |
| 16 | + |
| 17 | +| | | |
| 18 | +|---|---| |
| 19 | +| Python | 3.13.7 | |
| 20 | +| PyTorch | 2.8.0+cu128 | |
| 21 | +| OS | Linux 6.6.87.2 (WSL2) | |
| 22 | +| GPU | NVIDIA GeForce RTX 5090 | |
| 23 | +| CPU | AMD (via WSL2) | |
| 24 | + |
| 25 | +## Methodology |
| 26 | + |
| 27 | +- **Interleaved A/B**: Old and new paths alternate within each iteration to |
| 28 | + avoid warm-cache bias. Execution order flips on odd/even iterations. |
| 29 | +- **Replications**: 10 warmup + 50 measured iterations per condition (n=50). |
| 30 | +- **Conditions**: 4 batch sizes (1, 8, 16, 32) x 4 input lengths (20, 80, 200, |
| 31 | + 500 tokens) x 2 devices (CPU, GPU) = 32 conditions. |
| 32 | +- **Statistical test**: Welch's t-test (two-sided, unequal variance), p<0.05. |
| 33 | +- **Input realism**: Logits biased to -3.0 so roughly 2-56 spans per item (see the bs=1 rows) pass
| 34 | +  threshold=0.5, matching real GLiNER inference. max_width=12, num_classes=8.
| 35 | +- **Correctness**: Bit-identical output verified for every condition before |
| 36 | + benchmarking. |
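
The interleaving and the Welch statistic described above can be sketched as follows. This is a minimal stdlib-only illustration, not the harness used for these numbers: `interleaved_ab` and `welch_t` are hypothetical names, and a real GPU run would also need `torch.cuda.synchronize()` before each clock read so that asynchronous kernels are included in the timing.

```python
import math
import statistics
import time


def interleaved_ab(old_fn, new_fn, warmup=10, iters=50):
    """Time two callables in alternating order; return per-iteration times in ms."""
    old_ms, new_ms = [], []
    for it in range(warmup + iters):
        order = [("old", old_fn), ("new", new_fn)]
        if it % 2 == 1:
            order.reverse()  # flip execution order on odd iterations
        for name, fn in order:
            t0 = time.perf_counter()
            fn()
            elapsed = (time.perf_counter() - t0) * 1e3
            if it >= warmup:  # discard warmup iterations
                (old_ms if name == "old" else new_ms).append(elapsed)
    return old_ms, new_ms


def welch_t(a, b):
    """Welch's t statistic and Welch-Satterthwaite degrees of freedom."""
    na, nb = len(a), len(b)
    va, vb = statistics.variance(a) / na, statistics.variance(b) / nb
    t = (statistics.mean(a) - statistics.mean(b)) / math.sqrt(va + vb)
    df = (va + vb) ** 2 / (va ** 2 / (na - 1) + vb ** 2 / (nb - 1))
    return t, df
```

Turning `t` and `df` into a two-sided p-value needs the Student t CDF, which the standard library does not provide; `scipy.stats.ttest_ind(a, b, equal_var=False)` computes the same test directly.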
| 37 | + |
| 38 | +## GPU Results |
| 39 | + |
| 40 | +The batch-level approach eliminates per-item CUDA kernel launches. At bs=1, a |
| 41 | +fast path delegates to the original per-item decoder to avoid the overhead of |
| 42 | +4D `torch.where` without amortization. |
| 43 | + |
| 44 | +| Batch Size | Input Length | Spans | Old (ms) | New (ms) | Change | p-value | |
| 45 | +|-----------:|:------------|------:|---------:|---------:|-------:|--------:| |
| 46 | +| 1 | short (20) | 3 | 0.862 | 0.889 | -3.2% | 0.42 | |
| 47 | +| 8 | short (20) | 24 | 6.782 | 1.124 | **+83.4%** | <0.001 | |
| 48 | +| 16 | short (20) | 48 | 13.539 | 1.177 | **+91.3%** | <0.001 | |
| 49 | +| 32 | short (20) | 102 | 22.382 | 1.074 | **+95.2%** | <0.001 | |
| 50 | +| 1 | medium (80) | 9 | 0.829 | 0.790 | +4.8% | 0.24 | |
| 51 | +| 8 | medium (80) | 63 | 6.878 | 1.224 | **+82.2%** | <0.001 | |
| 52 | +| 16 | medium (80) | 121 | 13.711 | 1.240 | **+91.0%** | <0.001 | |
| 53 | +| 32 | medium (80) | 273 | 27.550 | 1.613 | **+94.1%** | <0.001 | |
| 54 | +| 1 | long (200) | 19 | 0.789 | 0.764 | +3.1% | 0.39 | |
| 55 | +| 8 | long (200) | 158 | 7.739 | 1.703 | **+78.0%** | <0.001 | |
| 56 | +| 16 | long (200) | 330 | 16.388 | 2.384 | **+85.5%** | <0.001 | |
| 57 | +| 32 | long (200) | 629 | 31.173 | 3.442 | **+89.0%** | <0.001 | |
| 58 | +| 1 | very_long (500) | 50 | 1.303 | 1.349 | -3.5% | 0.12 | |
| 59 | +| 8 | very_long (500) | 366 | 10.149 | 3.716 | **+63.4%** | <0.001 | |
| 60 | +| 16 | very_long (500) | 771 | 18.791 | 6.670 | **+64.5%** | <0.001 | |
| 61 | +| 32 | very_long (500) | 1591 | 38.142 | 12.907 | **+66.2%** | <0.001 | |
| 62 | + |
| 63 | +All values are median wall-clock time over 50 interleaved iterations. Bold |
| 64 | +entries are statistically significant (p<0.05). |
| 65 | + |
| 66 | +### GPU scaling characteristics |
| 67 | + |
| 68 | +The new decoder time is nearly constant across batch sizes for short/medium
| 69 | +inputs (1.1-1.6ms), consistent with the fixed overhead being paid once:
| 70 | + |
| 71 | +``` |
| 72 | +GPU short input: bs=8 → 1.1ms, bs=16 → 1.2ms, bs=32 → 1.1ms |
| 73 | +GPU medium input: bs=8 → 1.2ms, bs=16 → 1.2ms, bs=32 → 1.6ms |
| 74 | +``` |
| 75 | + |
| 76 | +The old path scales roughly linearly with batch size, since each of the B items pays its own kernel launches:
| 77 | + |
| 78 | +``` |
| 79 | +Old short input: bs=8 → 6.8ms, bs=16 → 13.5ms, bs=32 → 22.4ms |
| 80 | +Old medium input: bs=8 → 6.9ms, bs=16 → 13.7ms, bs=32 → 27.6ms |
| 81 | +``` |
| 82 | + |
| 83 | +## CPU Results |
| 84 | + |
| 85 | +On CPU, `torch.where` on a 4D tensor shows a ~3-5ms fixed overhead in most conditions that doesn't
| 86 | +exist when calling it B times on 3D slices. This makes the batch path slower at short bs=32, all
| 87 | +medium inputs, and long bs=8, where the per-item cost is already low (old path <2ms). At short bs=8/16
| 88 | +the overhead does not appear and batching wins by 62-70%. At very_long inputs, the per-item cost is high enough that batching still wins.
| 89 | + |
| 90 | +| Batch Size | Input Length | Spans | Old (ms) | New (ms) | Change | p-value | |
| 91 | +|-----------:|:------------|------:|---------:|---------:|-------:|--------:| |
| 92 | +| 1 | short (20) | 2 | 0.027 | 0.027 | -1.1% | 0.31 | |
| 93 | +| 8 | short (20) | 25 | 0.221 | 0.083 | **+62.4%** | <0.001 | |
| 94 | +| 16 | short (20) | 51 | 0.492 | 0.148 | **+69.8%** | <0.001 | |
| 95 | +| 32 | short (20) | 89 | 1.008 | 5.427 | **-438.5%** | <0.001 | |
| 96 | +| 1 | medium (80) | 5 | 0.040 | 0.040 | -0.9% | 0.95 | |
| 97 | +| 8 | medium (80) | 68 | 0.465 | 5.301 | **-1038.8%** | <0.001 | |
| 98 | +| 16 | medium (80) | 127 | 0.799 | 3.790 | **-374.2%** | <0.001 | |
| 99 | +| 32 | medium (80) | 259 | 1.607 | 4.311 | **-168.2%** | <0.001 | |
| 100 | +| 1 | long (200) | 20 | 0.129 | 0.129 | +0.4% | 0.85 | |
| 101 | +| 8 | long (200) | 154 | 0.998 | 6.177 | **-519.0%** | <0.001 | |
| 102 | +| 16 | long (200) | 323 | 2.065 | 1.714 | **+17.0%** | <0.001 | |
| 103 | +| 32 | long (200) | 638 | 4.193 | 3.760 | +10.3% | 0.08 | |
| 104 | +| 1 | very_long (500) | 56 | 0.455 | 0.447 | +1.6% | 0.88 | |
| 105 | +| 8 | very_long (500) | 384 | 5.389 | 3.115 | **+42.2%** | <0.001 | |
| 106 | +| 16 | very_long (500) | 767 | 8.085 | 6.133 | **+24.2%** | <0.001 | |
| 107 | +| 32 | very_long (500) | 1572 | 17.120 | 12.168 | **+28.9%** | <0.001 | |
| 108 | + |
| 109 | +### CPU regression analysis |
| 110 | + |
| 111 | +The CPU regressions share a pattern: the new path's absolute time clusters
| 112 | +around 3-6ms regardless of batch size or input length, suggesting a fixed floor
| 113 | +in PyTorch's 4D `torch.where` / `nonzero` implementation on CPU. The floor is absent at short inputs with bs=8/16
| 114 | +(new path 0.08-0.15ms), so it appears to set in only above some tensor size. The per-item path avoids it by calling
| 115 | +3D `torch.where` on small tensors (each <50K elements), which stays under 0.1ms per call.
| 116 | + |
| 117 | +The regressions are limited to conditions where the old decoder was already |
| 118 | +fast (<2ms). In absolute terms the worst regression adds ~5ms. At very_long |
| 119 | +inputs where the decoder is the bottleneck (old path 5-17ms), batching |
| 120 | +delivers 24-42% improvement. |
| 121 | + |
| 122 | +Note: CPU benchmarks ran under WSL2, which adds scheduling variance. The |
| 123 | +high stdev on some CPU conditions (30-60% of mean) partly reflects this. |
| 124 | + |
| 125 | +## Why the improvement |
| 126 | + |
| 127 | +The old per-item loop calls `_decode_batch_item` B times, each paying: |
| 128 | +- 1 `torch.where` on (L, K, C) |
| 129 | +- 1 boolean mask + 3 indexing ops |
| 130 | +- 1 score extraction via advanced indexing |
| 131 | +- 5 `.tolist()` GPU→CPU transfers |
| 132 | + |
| 133 | +At bs=32, that's **256 CUDA kernel launches + 160 GPU→CPU transfers**. |
| 134 | + |
| 135 | +The batch-level approach does this once on the full (B, L, K, C) tensor: |
| 136 | +- 1 `torch.where` |
| 137 | +- 1 boolean mask + 3 indexing ops |
| 138 | +- 1 score extraction |
| 139 | +- 6 `.tolist()` transfers |
| 140 | + |
| 141 | +Total: **~8 CUDA ops** regardless of batch size. |
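
The contrast between the two paths can be sketched as below. This is a simplified illustration under stated assumptions, not the PR's code: `decode_per_item` and `decode_batched` are hypothetical stand-ins for the real decoders, the input is assumed to be a probability tensor of shape (B, L, K, C), and details such as span-validity masking are omitted.

```python
import torch


def decode_per_item(probs, threshold=0.5):
    """Per-item path: one torch.where and one set of .tolist() calls per batch element."""
    out = []
    for i in range(probs.shape[0]):
        start, width, cls = torch.where(probs[i] > threshold)
        scores = probs[i][start, width, cls]
        out.append(list(zip(start.tolist(), width.tolist(),
                            cls.tolist(), scores.tolist())))
    return out


def decode_batched(probs, threshold=0.5):
    """Batch path: a single torch.where over the full 4D tensor, then bucket by batch index."""
    b, start, width, cls = torch.where(probs > threshold)
    scores = probs[b, start, width, cls]
    out = [[] for _ in range(probs.shape[0])]
    for bi, s, w, c, p in zip(b.tolist(), start.tolist(), width.tolist(),
                              cls.tolist(), scores.tolist()):
        out[bi].append((s, w, c, p))
    return out


torch.manual_seed(0)
# Logits biased to -3.0, as in the benchmark setup; shape (B, L, K, C).
probs = torch.sigmoid(torch.randn(4, 20, 12, 8) - 3.0)
assert decode_per_item(probs) == decode_batched(probs)
```

In the batched version the number of tensor operations does not depend on B; only the Python-side bucketing loop grows with the number of spans found.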
| 142 | + |
| 143 | +## Design decisions |
| 144 | + |
| 145 | +### bs=1 fast path |
| 146 | + |
| 147 | +At batch size 1, there's nothing to amortize: the 4D `torch.where` only adds
| 148 | +overhead relative to the 3D version. `_decode_batch` detects bs=1 and
| 149 | +delegates directly to `_decode_batch_item`, so it follows the old code path.
| 150 | +Benchmarks show no significant difference at bs=1 on either CPU or GPU (all p>0.1).
| 151 | + |
| 152 | +### Batched `return_class_probs` |
| 153 | + |
| 154 | +When class probabilities are requested, the old path called `torch.topk` per |
| 155 | +span (N_total kernel launches). The batch path gathers all probability vectors |
| 156 | +in one advanced indexing op, then does one batched `topk` on the (N_total, C) |
| 157 | +matrix. |
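
A rough sketch of the difference, assuming a (B, L, K, C) probability tensor and hypothetical helper names (`topk_per_span`, `topk_batched`); the actual `return_class_probs` code may be organized differently:

```python
import torch


def topk_per_span(span_probs, k):
    """Old style: one torch.topk call per span (one row at a time)."""
    vals, idxs = [], []
    for row in span_probs:
        v, i = torch.topk(row, k)
        vals.append(v)
        idxs.append(i)
    return torch.stack(vals), torch.stack(idxs)


def topk_batched(span_probs, k):
    """Batch style: a single torch.topk over the whole (N_total, C) matrix."""
    return torch.topk(span_probs, k, dim=-1)


torch.manual_seed(0)
probs = torch.rand(4, 20, 12, 8)               # (B, L, K, C)
b, start, width, _ = torch.where(probs > 0.9)  # spans passing the threshold
span_probs = probs[b, start, width]            # one gather -> (N_total, C)

loop_vals, _ = topk_per_span(span_probs, k=3)
batch_vals, _ = topk_batched(span_probs, k=3)
assert torch.equal(loop_vals, batch_vals)
```

Only the values are compared here because the returned indices could differ between the two calls if a row contains exactly tied probabilities.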
| 158 | + |
| 159 | +### Correctness |
| 160 | + |
| 161 | +Output is verified bit-identical for all 32 benchmark conditions. The batch |
| 162 | +`torch.where` returns indices in row-major order (batch, start, width, class), |
| 163 | +so within each batch item the span ordering is identical to the per-item path. |
| 164 | +`greedy_search` sorts by score regardless, so final output is deterministic. |