Commit dc9af4a
Implement 4over6 NVFP4 recipe (#2972)
* Initial implementation
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Make 4over6 compile time for dequant
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Expand 1d fwd+bwd test
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add gemm test
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add more tests and fix offload
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Fix offload
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up arg
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add more test
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add more tests
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up test
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor cuh kernel impl
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Further extract
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add recipe_id
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Fix failing unit tests
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up test
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor ref
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Update comments and docs
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Drop unnecessary test_sanity workaround
The following tests passed:
`NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
`
`NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
`
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor `QuantizerRole`
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Allow separate recipe 4over6 config
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Support 2d
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor 2d
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up anti pattern
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Enforce 4over6 consistency
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Update comments
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Update docs
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Fix test
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Drop test_fusible_ops
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Revert "Drop test_fusible_ops"
This reverts commit 69f9ccc.
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor test_fusible_ops
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor ref and extend cpp test
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up cpp test
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Minor comment
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Drop doc
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Explicit handle conditional smem buffer
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Further clean up
Signed-off-by: Ziang Li <ziangli@umich.edu>
* More templates
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Simplify cpp
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Drop write back lifting
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add MAE and dedicated fast math env var
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Harden cpp test
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add warning and err fast math coverage
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Fold test case and clean up cpp test
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Initial 448 vs 256 implementation
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Use e4m3 max instead of boolean, more template
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add benchmark script and minor optimization
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Use standalone kernels
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Use cp async
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add benchmark script
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Minor fix after rebase
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Naming consistency
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Remove 4over6 benchmark
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor modes
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Relax tol for `test_layernorm_mlp` for `nvfp4_4over6`
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Minor fix recipe naming
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Remove gradient 4over6 quantization and partially allow SR/RHT
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Allow RHT in pytorch ref
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Update transformer_engine/pytorch/csrc/quantizer.cpp
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
* Minor fix TODO lint
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Use standard nvfp4 for grad ref in test_fusible_ops.py since 4over6 is not applied to gradient quantizers
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Minor fix test-fusible_ops 4over6 helper
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Default to 256 for 4over6
Signed-off-by: Ziang Li <ziangli@umich.edu>
* Reset RNG state for each TE ops test
Adding tests affected RNG in unrelated tests.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Remove loosened NVFP4 tols in layernorm MLP test.
Make sure tensors are representable in quantized format.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>1 parent 9af70a8 commit dc9af4a
37 files changed
Lines changed: 2595 additions & 251 deletions
File tree
- docs
- tests
- cpp
- operator
- pytorch
- nvfp4
- transformer_engine
- common
- cast
- dispatch
- nvfp4
- comm_gemm_overlap
- include/transformer_engine
- recipe
- pytorch
- csrc
- extensions
- custom_recipes
- tensor
- storage
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
287 | 287 | | |
288 | 288 | | |
289 | 289 | | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
290 | 314 | | |
291 | 315 | | |
292 | 316 | | |
| |||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
46 | 46 | | |
47 | 47 | | |
48 | 48 | | |
49 | | - | |
50 | | - | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
51 | 52 | | |
52 | 53 | | |
53 | 54 | | |
| |||
86 | 87 | | |
87 | 88 | | |
88 | 89 | | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
89 | 95 | | |
90 | 96 | | |
91 | 97 | | |
92 | 98 | | |
93 | | - | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
94 | 102 | | |
95 | 103 | | |
96 | 104 | | |
| |||
105 | 113 | | |
106 | 114 | | |
107 | 115 | | |
| 116 | + | |
| 117 | + | |
108 | 118 | | |
109 | 119 | | |
110 | 120 | | |
| |||
116 | 126 | | |
117 | 127 | | |
118 | 128 | | |
119 | | - | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
120 | 132 | | |
121 | 133 | | |
122 | 134 | | |
| |||
146 | 158 | | |
147 | 159 | | |
148 | 160 | | |
149 | | - | |
| 161 | + | |
150 | 162 | | |
151 | 163 | | |
152 | 164 | | |
| |||
156 | 168 | | |
157 | 169 | | |
158 | 170 | | |
159 | | - | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
160 | 174 | | |
161 | 175 | | |
162 | 176 | | |
| |||
165 | 179 | | |
166 | 180 | | |
167 | 181 | | |
| 182 | + | |
| 183 | + | |
168 | 184 | | |
169 | 185 | | |
170 | 186 | | |
| |||
174 | 190 | | |
175 | 191 | | |
176 | 192 | | |
177 | | - | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
178 | 196 | | |
179 | 197 | | |
180 | 198 | | |
| |||
186 | 204 | | |
187 | 205 | | |
188 | 206 | | |
| 207 | + | |
| 208 | + | |
189 | 209 | | |
190 | 210 | | |
191 | 211 | | |
| |||
260 | 280 | | |
261 | 281 | | |
262 | 282 | | |
263 | | - | |
| 283 | + | |
| 284 | + | |
264 | 285 | | |
265 | 286 | | |
266 | 287 | | |
| |||
271 | 292 | | |
272 | 293 | | |
273 | 294 | | |
| 295 | + | |
274 | 296 | | |
275 | 297 | | |
276 | 298 | | |
277 | | - | |
| 299 | + | |
| 300 | + | |
278 | 301 | | |
279 | 302 | | |
280 | 303 | | |
| |||
284 | 307 | | |
285 | 308 | | |
286 | 309 | | |
287 | | - | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
288 | 314 | | |
289 | 315 | | |
| 316 | + | |
| 317 | + | |
290 | 318 | | |
291 | 319 | | |
292 | 320 | | |
293 | | - | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
294 | 324 | | |
295 | 325 | | |
296 | 326 | | |
297 | 327 | | |
298 | 328 | | |
299 | 329 | | |
300 | 330 | | |
301 | | - | |
| 331 | + | |
| 332 | + | |
302 | 333 | | |
303 | 334 | | |
304 | 335 | | |
| |||
309 | 340 | | |
310 | 341 | | |
311 | 342 | | |
| 343 | + | |
312 | 344 | | |
313 | 345 | | |
314 | 346 | | |
315 | | - | |
| 347 | + | |
| 348 | + | |
316 | 349 | | |
317 | 350 | | |
318 | 351 | | |
| |||
322 | 355 | | |
323 | 356 | | |
324 | 357 | | |
325 | | - | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
326 | 362 | | |
327 | 363 | | |
| 364 | + | |
| 365 | + | |
328 | 366 | | |
329 | 367 | | |
330 | 368 | | |
331 | 369 | | |
| 370 | + | |
| 371 | + | |
332 | 372 | | |
333 | 373 | | |
334 | 374 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
440 | 440 | | |
441 | 441 | | |
442 | 442 | | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
443 | 455 | | |
444 | 456 | | |
445 | 457 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
293 | 293 | | |
294 | 294 | | |
295 | 295 | | |
| 296 | + | |
| 297 | + | |
296 | 298 | | |
297 | 299 | | |
298 | 300 | | |
299 | 301 | | |
| 302 | + | |
300 | 303 | | |
301 | 304 | | |
302 | 305 | | |
| |||
0 commit comments