Support of gdn kernel from tpu-inference #4051
🤖 Hi @khatwanimohit, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request introduces support for the Gated Delta Net (GDN) kernel from tpu-inference into MaxText, specifically for vLLM decoding. It includes integration of the GDN logic, sharding support, profiling capabilities, and necessary monkey-patches for hybrid Attention+GDN models.
🔍 General Feedback
- Configurability: The use of environment variables like `MAXTEXT_GDN_REPLICATE_EXPERT` for model behavior should be transitioned to the formal `Config` system for better reproducibility.
- Robustness: Global monkey-patching of library classes (e.g., `ModelConfig.uses_mrope` and `KVCacheManager`) is quite brittle and may lead to issues with future updates or different model types.
- Performance: The dynamic padding logic in `Qwen3NextGatedDeltaNet` could trigger frequent JAX re-compilations if batch sizes vary, which should be addressed for production workloads.
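The re-compilation concern in the Performance point can be illustrated with a small sketch. `jax.jit` caches compiled programs per input shape, so padding each batch to its exact size gives one compilation per distinct size, while rounding up to a small fixed set of buckets bounds the number of shapes. The helper names and bucket values below are hypothetical and are not taken from this PR; the sketch is pure Python and only models the shape bookkeeping, not the kernel itself.

```python
from bisect import bisect_left

# Hypothetical bucket sizes; a real deployment would tune these.
BUCKETS = (8, 16, 32, 64, 128)


def padded_batch_size(batch_size: int, buckets=BUCKETS) -> int:
    """Round batch_size up to the nearest bucket so only a few shapes exist."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    idx = bisect_left(buckets, batch_size)
    if idx == len(buckets):
        raise ValueError(
            f"batch_size {batch_size} exceeds largest bucket {buckets[-1]}"
        )
    return buckets[idx]


def distinct_shapes(batch_sizes, pad_fn) -> int:
    """Count distinct padded sizes, a proxy for the number of compilations."""
    return len({pad_fn(b) for b in batch_sizes})


observed = [3, 5, 7, 9, 12, 17, 30, 33]
exact = distinct_shapes(observed, lambda b: b)           # one shape per size
bucketed = distinct_shapes(observed, padded_batch_size)  # at most len(BUCKETS)
```

With the sample sizes above, exact padding produces eight distinct shapes while bucketing produces four. The trade-off is some wasted compute on padded rows in exchange for a bounded compile cache.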
This Pull Request introduces comprehensive support for Gated Delta Net (GDN) kernels from tpu-inference into MaxText, including integration with vLLM and necessary sharding updates. The changes are logically sound and significantly improve the flexibility of the Qwen3/Qwen3.5 model implementations.
🔍 General Feedback
- Consistency: The simplification of `kv_cache` handling in `decoders.py` is a great improvement, bringing Qwen3 into alignment with other models in the codebase.
- Complexity: The monkey-patching and adapter logic are necessary evils for this level of integration, but should be monitored closely for breakages when upstream `vLLM` or `tpu-inference` APIs evolve.
- Performance: The use of `shard_map` and specialized kernels in `qwen3.py` demonstrates a high level of optimization for TPU sharding.
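One way to make such patches easier to monitor is to fail loudly when the patched attribute no longer exists upstream, and to restore the original afterwards. The following is a generic sketch, not the patching code from this PR; `guarded_patch` is a hypothetical helper, and the `ModelConfig` class here is a stand-in that does not reflect the real vLLM API.

```python
import contextlib


class ModelConfig:
    """Stand-in for an upstream class; NOT the real vLLM ModelConfig."""

    @property
    def uses_mrope(self) -> bool:
        return True


@contextlib.contextmanager
def guarded_patch(owner, name, replacement):
    """Temporarily replace owner.name; raise if upstream removed or renamed it."""
    if not hasattr(owner, name):
        raise AttributeError(
            f"{owner.__name__}.{name} not found; upstream API may have changed"
        )
    missing = object()
    # Read the raw attribute from the class dict so descriptors are preserved.
    original = vars(owner).get(name, missing)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        if original is missing:
            delattr(owner, name)  # attribute was inherited; drop our override
        else:
            setattr(owner, name, original)


cfg = ModelConfig()
with guarded_patch(ModelConfig, "uses_mrope", property(lambda self: False)):
    patched_value = cfg.uses_mrope   # sees the replacement property
restored_value = cfg.uses_mrope      # original property is back
```

Scoping the patch to a context manager keeps the override local to the code path that needs it, and the up-front `hasattr` check turns a silent behavior change after an upstream rename into an immediate error.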
🤖 I'm sorry @khatwanimohit, but I was unable to process your request. Please see the logs for more details.
Description
BUGS: b/517158881
Tests
Profile: https://xprof.corp.google.com/trace_viewer/mohitkhatwani-11546393038070241787?view_start=71184.720&view_end=71204.219
Logs: https://paste.googleplex.com/5464600022745088
Decode performance: 25 ms
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.