[Do Not Merge] Optimizations on Qwen3-Next GatedDeltaNet w/ Kernel & XProf Agent#3077
Open
Rohan-Bierneni wants to merge 5 commits intomainfrom
Open
[Do Not Merge] Optimizations on Qwen3-Next GatedDeltaNet w/ Kernel & XProf Agent#3077Rohan-Bierneni wants to merge 5 commits intomainfrom
Rohan-Bierneni wants to merge 5 commits intomainfrom
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
Add backward pass checks & memory checks Add backward pass & memory consumption checks Update memory calcs Optimizations made to GDN impl in qwen3.py (3x speedup) Update dummy configs to align with q3-next Update tflops calc to align with WY-optimized GDN remove mixed precision Update config for chunk size update dtype Add NaN test in backward pass Fix exploding gradient in gdn Reintroduce mixed precision typo in bloat16 typo fixed convert to float test pallas kernel for gdn wrong api name fix function positional args fix pallas code fix tensor indexing error only optimize forward pass update pallas code use float mask fix function returns add shardmap to kernel update with kernel agent suggestions fix matrix indexing fix matrix indexing mask before exp update benchmarking script
acc1d56 to
b03fe07
Compare
|
This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Using the suggestions from kernel & xprof agent, try to improve the Gated Delta Net implementation in qwen3.py. We test our changes using the script added as part of this pr. The script tests the forward pass, backward pass, overall train step, and memory consumption between the baseline implementation of the GDN versus our optimized version in qwen3.py. This allowed us to test out changes iteratively and quickly.
To test the script, please use the command:
Note: run this script on a TPU/GPU vm since on CPU it will take a while.
So far, total improvements on the Gated Delta Rule using Q3-Next configs & 4k Seq len are:
If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
Tests
Tested our changes using the benchmarking script and pr unit tests (train_compile test for qwen3 next)
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.