impr: Use f16 and better subgroup shader in MNIST example #2412
📊 Bundle Size Comparison
👀 Notable results
Static test results: No major changes.
Dynamic test results: No major changes.
📋 All results: 354 entries.
If you wish to run a comparison for other, slower bundlers, run the 'Tree-shake test' from the GitHub Actions menu.
Resolution Time Benchmark

Random Branching, time (ms) by max depth (🔴 PR | 🔵 main | 🟢 release)

| series | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|---|
| PR | 0.95 | 1.93 | 4.42 | 6.49 | 7.69 | 11.01 | 22.11 | 22.52 |
| main | 0.99 | 2.06 | 4.40 | 6.69 | 7.60 | 11.03 | 21.56 | 24.80 |
| release | 0.91 | 1.89 | 4.39 | 6.48 | 7.36 | 10.73 | 21.20 | 22.31 |

Linear Recursion, time (ms) by max depth (🔴 PR | 🔵 main | 🟢 release)

| series | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|---|
| PR | 0.34 | 0.53 | 0.64 | 0.81 | 1.13 | 1.18 | 1.40 | 1.61 |
| main | 0.29 | 0.54 | 0.71 | 0.87 | 1.18 | 1.25 | 1.47 | 1.58 |
| release | 0.31 | 0.52 | 0.71 | 0.87 | 1.12 | 1.14 | 1.38 | 1.53 |

Full Tree, time (ms) by max depth (🔴 PR | 🔵 main | 🟢 release)

| series | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|---|
| PR | 0.89 | 2.07 | 4.26 | 6.69 | 12.73 | 26.66 | 56.97 | 119.53 |
| main | 0.82 | 2.17 | 3.63 | 6.59 | 12.54 | 26.94 | 55.85 | 114.15 |
| release | 0.83 | 2.14 | 4.31 | 6.51 | 13.08 | 27.06 | 55.75 | 113.60 |
79b1935 to 4328747
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

    export function downloadLayers(
      root: TgpuRoot,
      floatShcema: d.F32 | d.F16,
    ): Promise<[LayerData, LayerData][]> {

Typo in parameter name `floatShcema` (should be `floatSchema`). Keeping the misspelling makes the API harder to read and search, and increases the chance of propagating the typo to call sites.

    const outputCount = buffers[i].biases.dataType.elementCount;
    boundPipeline.dispatchWorkgroups(
      subgroupPipeline ? outputCount : Math.ceil(outputCount / WORKGROUP_SIZE),
    );

`dispatchWorkgroups` uses `outputCount` when the subgroup pipeline is selected, but `subgroupCompute` computes `num_subgroups` outputs per workgroup (`neuronIndex = wid.x * nsg + sgid`). This over-dispatches workgroups by a factor of `nsg` (e.g., 2x for 64 threads with 32-wide subgroups), doing unnecessary work for larger layers. Consider either dispatching `ceil(outputCount / outputsPerWorkgroup)` (if you can determine `outputsPerWorkgroup`) or adjusting the shader/work mapping so each workgroup corresponds to exactly one output when the dispatch count must be `outputCount`.
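The first option in the comment above could be sketched roughly as follows. This is only an illustration, not code from the PR: `workgroupCount` and its parameters are hypothetical, and it assumes the subgroup size is known on the host, which the PR does not confirm. The non-subgroup branch mirrors the existing `Math.ceil(outputCount / WORKGROUP_SIZE)` expression.

```typescript
// Hypothetical helper, not part of the PR: names and parameters are illustrative.
function workgroupCount(
  outputCount: number,
  workgroupSize: number,
  subgroupSize?: number, // assumed known on the host; undefined selects the non-subgroup path
): number {
  if (subgroupSize === undefined) {
    // Non-subgroup path: one thread per output, as in the existing fallback.
    return Math.ceil(outputCount / workgroupSize);
  }
  // Subgroup path: each subgroup produces one output, so a workgroup
  // covers num_subgroups = workgroupSize / subgroupSize outputs.
  const outputsPerWorkgroup = Math.max(1, Math.floor(workgroupSize / subgroupSize));
  return Math.ceil(outputCount / outputsPerWorkgroup);
}
```

With 64 threads per workgroup and 32-wide subgroups, this halves the dispatch compared with passing `outputCount` directly, which matches the 2x figure quoted in the comment. If the subgroup size cannot be determined before dispatch, the second option (one output per workgroup in the shader) avoids the dependency.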

    const context = canvas.getContext('2d') as CanvasRenderingContext2D;

    const bars = Array.from(document.querySelectorAll('.bar')) as HTMLDivElement[];
    const subgroupsEl = document.getElementById('subgroups-status') as HTMLSpanElement;

Could we also include a status for f16? I don't think there is currently a way to tell whether the shader is running on f16 or f32.
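One possible shape for such a status, sketched under assumptions: the helpers below are hypothetical and not from the PR. They only rely on the WebGPU feature set exposing `has()` (as `GPUAdapter.features` and `GPUDevice.features` do) and on the standard feature names `'shader-f16'` and `'subgroups'`. The sketch assumes the example uses f16 whenever the feature is present; the actual selection logic is not shown in this conversation.

```typescript
// Minimal structural type so the helpers work with adapter.features,
// device.features, or a plain Set in tests.
interface FeatureSet {
  has(name: string): boolean;
}

// Hypothetical helper: text for a status element such as 'subgroups-status'.
function featureStatus(features: FeatureSet, name: string): string {
  return features.has(name) ? 'enabled' : 'unavailable';
}

// Hypothetical helper: assumes f16 is chosen whenever 'shader-f16' is present.
function floatFormatLabel(features: FeatureSet): 'f16' | 'f32' {
  return features.has('shader-f16') ? 'f16' : 'f32';
}
```

The label could then be written to a new status element next to the existing subgroups one; checking the device's features rather than the adapter's would reflect what was actually requested.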
4328747 to 82d84e8