
RL with vllm-native support (qwen3 converter) #3767

Merged
copybara-service[bot] merged 1 commit into main from hengtaoguo-rl-merge on May 1, 2026

Conversation


@hengtaoguo (Collaborator) commented on Apr 28, 2026

Description

Refactor the MaxText-to-vLLM weight conversion into a reusable converter pattern:

  • Introduce BaseMaxTextToVLLMConverter as the shared converter base class, and isolate the Qwen3-MoE implementation in torchax_converter/qwen3_moe.py.
  • Add a standalone, config-driven validator module, maxtext.integration.vllm.torchax_converter.validate_converter, for on-VM testing and generation-based weight-transfer checks.
  • Add a use_standalone_converter config flag so RL rollout can explicitly opt into the standalone MaxText-to-vLLM converter path.
  • Update the rollout integration to use MaxTextVllmRollout with model-specific converter creation instead of importing converter logic from the bench script.
  • Add TP-aware padding to multiples of 128 in the Qwen3 converter's _make_fuse_all (a sketch of the converter pattern and padding follows this list).
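
For orientation, a minimal sketch of what this converter pattern can look like. The class and module names (BaseMaxTextToVLLMConverter, torchax_converter/qwen3_moe.py) come from this PR; the method names, signatures, registry, and padding helper below are illustrative assumptions, not the actual implementation:

```python
# Illustrative sketch only: apart from BaseMaxTextToVLLMConverter and the
# qwen3_moe module, the names and signatures here are assumptions.
from abc import ABC, abstractmethod

import numpy as np


def pad_to_multiple(x: np.ndarray, multiple: int, axis: int = -1) -> np.ndarray:
  """Zero-pad `axis` of `x` up to the next multiple of `multiple`."""
  pad = (-x.shape[axis]) % multiple
  if pad == 0:
    return x
  widths = [(0, 0)] * x.ndim
  widths[axis] = (0, pad)
  return np.pad(x, widths)


class BaseMaxTextToVLLMConverter(ABC):
  """Shared logic for mapping MaxText checkpoint weights onto a vLLM model."""

  @abstractmethod
  def convert(self, maxtext_weights: dict) -> dict:
    """Return vLLM parameter names mapped to converted weight arrays."""


class Qwen3MoEConverter(BaseMaxTextToVLLMConverter):
  """Qwen3-MoE-specific conversion (isolated in torchax_converter/qwen3_moe.py)."""

  def __init__(self, tensor_parallelism: int = 1):
    self.tp = tensor_parallelism

  def convert(self, maxtext_weights: dict) -> dict:
    converted = {}
    for name, weight in maxtext_weights.items():
      w = np.asarray(weight)
      if "moe" in name:
        # Pad fused expert projections so every TP shard stays 128-aligned;
        # this stands in for the PR's TP-aware padding in _make_fuse_all.
        w = pad_to_multiple(w, multiple=128 * self.tp)
      converted[name] = w
    return converted


# Roughly how a rollout integration might pick a converter per model:
_CONVERTERS = {"qwen3-30b-a3b": Qwen3MoEConverter}


def create_converter(model_name: str, **kwargs) -> BaseMaxTextToVLLMConverter:
  return _CONVERTERS[model_name](**kwargs)
```

Padding to a multiple of 128 matches the TPU lane width; making the multiple TP-aware keeps the fused weights evenly divisible across tensor-parallel shards.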

Tests

  • requirements (tpu-inference/vllm are critical)
  • logs

Standalone weight sync and decode on a v5p-8 VM:

python -m maxtext.integration.vllm.torchax_converter.validate_converter src/maxtext/configs/base.yml model_name=qwen3-30b-a3b tokenizer_type=huggingface tokenizer_path=Qwen/Qwen3-30B-A3B  load_parameters_path=gs://hengtaoguo-maxtext-logs/checkpoints/qwen3-30b-a3b/scanned/2026-01-23-14-00/0/items/0/items run_name=qwen3_converter_validation per_device_batch_size=1 max_prefill_predict_length=8 max_target_length=16 steps=1 scan_layers=true skip_jax_distributed_system=true weight_dtype=bfloat16 attention=dot_product remat_policy=custom decoder_layer_input=offload query_proj=offload key_proj=offload value_proj=offload rollout_tensor_parallelism=4 hbm_utilization_vllm=0.6 async_scheduling=false prompt=Paris\ is hf_access_token=xxx
Assigning 435 weights to vLLM model took 0.1711 seconds

================================================================================
Generation test after weight transfer:
Rendering prompts: 100%|█| 1/1 [00:00<00:
Processed prompts:   0%| | 0/1 [00:00<?,
/home/hengtaoguo_google_com/projects/venv2/lib/python3.12/site-packages/torchax/tensor.py:167: UserWarning: Explicitly requested dtype int64 requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  res = jax_function(self._elem, *args, **kwargs)
WARNING 04-30 04:23:44 [tuned_block_sizes.py:4368] Couldn't find tuned sizes for the RPA v3 kernel with ('TPU v5', 16, 'q_bfloat16_kv_bfloat16', 'q_head-8_kv_head-1_head-128', 'max_model_len-16-sw-None')
INFO 04-30 04:23:44 [tuned_block_sizes.py:4389] RPA v3 kernel tuned block sizes for ('TPU v5', 16, 'q_bfloat16_kv_bfloat16', 'q_head-8_kv_head-1_head-128', 'max_model_len-16-sw-None'): bkv_p=1, bq=16
Processed prompts: 100%|█| 1/1 [00:06<00:
[RequestOutput(request_id=0, prompt='Paris is', prompt_token_ids=[59604, 374], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' the capital of France. It is also the city', token_ids=[279, 6722, 315, 9625, 13, 1084, 374, 1083, 279, 3283], routed_experts=None, cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=0)]
Generation took 6.9421 seconds
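
The generation check at the end of the log mirrors vLLM's offline inference API. A minimal sketch of that kind of post-transfer smoke test (validate_converter additionally loads the MaxText checkpoint and transfers the weights first; the model name and sampling settings below are illustrative):

```python
# Minimal decode smoke test via vLLM's offline API. Illustrative only:
# the real validator transfers MaxText weights into this model first.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-30B-A3B", tensor_parallel_size=4)
outputs = llm.generate(["Paris is"], SamplingParams(max_tokens=10))

completion = outputs[0].outputs[0]
print(completion.text)  # expect e.g. " the capital of France. It is also the city"
assert completion.finish_reason == "length"
```

A healthy weight transfer should produce a coherent continuation like the one in the log above; garbage tokens here usually point to a mis-mapped or mis-padded tensor.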

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@hengtaoguo force-pushed the hengtaoguo-rl-merge branch 4 times, most recently from 49611e5 to 150c252, on April 30, 2026 04:55
@hengtaoguo marked this pull request as ready for review on April 30, 2026 04:55
@hengtaoguo changed the title from [WIP] RL with vllm-native support to RL with vllm-native support (qwen3 converter) on Apr 30, 2026
@hengtaoguo force-pushed the hengtaoguo-rl-merge branch from 91ca4c4 to a208ed1 on April 30, 2026 05:02
@github-actions

🤖 Hi @hengtaoguo, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @hengtaoguo, but I was unable to process your request. Please see the logs for more details.

1 similar comment

Comment thread: src/maxtext/integration/vllm/torchax_converter/validate_converter.py (outdated)
Comment thread: src/maxtext/integration/vllm/torchax_converter/validate_converter.py (outdated)
@hengtaoguo force-pushed the hengtaoguo-rl-merge branch 2 times, most recently from c6e8ee3 to 1ea9cf3, on April 30, 2026 21:19
@khatwanimohit self-requested a review on April 30, 2026 22:18
@khatwanimohit (Collaborator) left a comment:

LGTM!

@hengtaoguo force-pushed the hengtaoguo-rl-merge branch 3 times, most recently from 714bd8f to ff4ec31, on April 30, 2026 23:09
@NicoGrande (Collaborator) left a comment:

LGTM!

Comment thread: src/maxtext/integration/vllm/torchax_converter/qwen3_moe.py
Comment thread: src/maxtext/integration/vllm/maxtext_vllm_rollout.py
@hengtaoguo force-pushed the hengtaoguo-rl-merge branch from 4e68df7 to 8cd268b on April 30, 2026 23:46
@hengtaoguo force-pushed the hengtaoguo-rl-merge branch from 3b8eb7b to c927864 on April 30, 2026 23:48
@copybara-service (bot) merged commit 64455b2 into main on May 1, 2026
24 checks passed
@copybara-service (bot) deleted the hengtaoguo-rl-merge branch on May 1, 2026 00:44