
Explicitly pass qwix config for deepseek batch split #3405

Closed

shuningjin wants to merge 1 commit into main from shuningjin-qwix1


Conversation


shuningjin (Collaborator) commented Mar 13, 2026

Description

Fix: b/489513157

Overview

PR 3182 switched the deepseek batch-split implementation to pure JAX, which broke the existing Qwix integration that relies on Qwix's interception feature. As a workaround, PR 3319 explicitly passes the QwixRule to the gmm kernels.

This PR extends that workaround: we explicitly pass a QwixQuantization object to individual layers, in the same vein as the existing AQTQuantization.

# flow1 (existing): AQT plumbing
DeepSeekMoE, non-batch-split version (MoE & MLA): Flax layers, AQT plumbed in

# flow2 (existing): Qwix interception
DeepSeekMoE, non-batch-split version (MoE & MLA): Flax layers, Qwix interception

# flow3 (goal of this PR): Qwix plumbing
DeepSeekMoE, batch-split version (MoE & MLA): pure-JAX layers, Qwix plumbed in

In summary:

  • deepseek + batch_split: use flow3
  • all other cases: unchanged, either flow1 or flow2 (a sketch of the selection logic follows below)
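
The dispatch could look like the following minimal sketch. This is hedged: configure_quantization, the config field names, and the stub classes are illustrative assumptions mirroring the flow descriptions above, not the exact MaxText code.

from dataclasses import dataclass

@dataclass
class AqtQuantization:  # stand-in for the existing AQT wrapper (assumption)
  pass

@dataclass
class QwixQuantization:  # stand-in for the wrapper added by this PR (assumption)
  pass

def configure_quantization(config):
  """Picks one of the three quantization flows; field names are assumptions."""
  if not config.quantization:
    return None  # unquantized path
  if config.quantization.startswith("qwix"):
    if config.use_batch_split_schedule:
      return QwixQuantization()  # flow3: plumb Qwix into the pure-JAX layers
    return None  # flow2: model is intercepted later via qwix.quantize_model
  return AqtQuantization()  # flow1: AQT plumbed through the Flax layers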

Main Changes

  • quantizations.py

    • Add QwixQuantization, analogous to the existing AQTQuantization
      • implements dot_general_cls and einsum via QwixDotGeneral and QwixEinsum
      • uses qwix._src.core.dot_general_qt.dot_general_qt, configured with "fp8_full" and a calibration method
    • when Qwix and batch split are both enabled, return a QwixQuantization instead of intercepting via qwix.quantize_model
  • deepseek_batchsplit.py

    • allow dot to use the quantized quant.dot_general_cls
    • pass quant through the layer methods (see the plumbing sketch after this list)
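
A minimal sketch of the plumbing pattern follows. Hedged: the class shapes below mirror the description above, but the exact qwix entry-point signature is not reproduced (assumption), so the sketch falls back to the unquantized primitive where the real code would call qwix.

import dataclasses
import jax

@dataclasses.dataclass
class QwixDotGeneral:
  """Drop-in for jax.lax.dot_general that would route through Qwix."""
  rule: object = None  # e.g. a Qwix rule for "fp8_full" plus a calibration method

  def __call__(self, lhs, rhs, dimension_numbers, precision=None,
               preferred_element_type=None):
    # The real implementation calls qwix._src.core.dot_general_qt.dot_general_qt
    # with self.rule; its exact signature is an assumption, so this sketch
    # just forwards to the unquantized primitive.
    return jax.lax.dot_general(lhs, rhs, dimension_numbers,
                               precision=precision,
                               preferred_element_type=preferred_element_type)

@dataclasses.dataclass
class QwixQuantization:
  """Plumb-style quantization object, analogous to AQTQuantization."""
  rule: object = None

  def dot_general_cls(self):
    # Layers ask for a dot_general implementation instead of being intercepted.
    return QwixDotGeneral(self.rule)

On the deepseek_batchsplit.py side, the pure-JAX dot would then pick its implementation along the lines of dg = quant.dot_general_cls() if quant else jax.lax.dot_general, with quant threaded through the method signatures.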

Tests

Tested locally on v5p-8 with deepseek3-test (trimmed down to 2 dense layers + 2 MoE layers).

Internal e2e tests in: b/489513157

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Mar 13, 2026

Codecov Report

❌ Patch coverage is 44.00000% with 28 lines in your changes missing coverage. Please review.

Files with missing lines                    Patch %   Lines
src/maxtext/models/deepseek_batchsplit.py   11.76%    15 missing ⚠️
src/maxtext/layers/quantizations.py         60.60%    12 missing, 1 partial ⚠️


Comment thread on src/maxtext/layers/quantizations.py (outdated)

BirdsOfAFthr (Collaborator) left a comment


Have you tested this on the unquantized path?


shuningjin (Collaborator, Author) commented Mar 17, 2026

> Have you tested this on the unquantized path?

I tested the unquantized path for deepseek batch split (quantization unset), and it is functional. In the profile, the HLO contains no "f8" ops; a minimal sketch of this check follows after the links below.

pip install -e .
git rev-parse HEAD
RUN_NAME=ds3-fp8-$(date +%Y-%m-%d-%H-%M-%S)
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=gs://runner-maxtext-logs run_name=$RUN_NAME \
model_name=deepseek3-test override_model_config=True first_num_dense_layers=2 base_num_decoder_layers=4 \
dtype=bfloat16 mu_dtype=bfloat16 grad_dtype=bfloat16 \
per_device_batch_size=1 max_target_length=4096 \
attention=flash sa_use_fused_bwd_kernel=true \
sparse_matmul=true megablox=true use_tokamax_gmm=true use_tokamax_splash=true \
moe_fsdp_use_two_stage_all_gather=false use_iota_embed=true \
opt_type=adamw float32_weight_sum=false \
dataset_type=synthetic enable_checkpointing=false async_checkpointing=false \
steps=10 profiler=xplane skip_first_n_steps_for_profiler=5 profiler_steps=2 \
cost_estimate_flops_fwd=5000000000000 cost_estimate_flops_bwd=5000000000000 use_max_logit_estimate=-1 \
use_batch_split_schedule=true batch_split_factor=1

log: https://paste.googleplex.com/5697881110609920

profile: http://shortn/_eFfObXlmZN
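
For reference, the "no f8 in HLO" check can be reproduced on any small jitted function. Hedged: the toy matmul below stands in for the model; it is not the actual test run above.

import jax
import jax.numpy as jnp

def matmul(x, w):
  return x @ w

# Lower and compile, then inspect the compiled HLO text for fp8 types.
compiled = jax.jit(matmul).lower(jnp.ones((8, 8)), jnp.ones((8, 8))).compile()
hlo_text = compiled.as_text()
assert "f8" not in hlo_text  # no fp8 ops expected when quantization is unset
print("unquantized HLO confirmed")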

copybara-service Bot pushed a commit that referenced this pull request Mar 17, 2026
COPYBARA_INTEGRATE_REVIEW=#3405 from AI-Hypercomputer:shuningjin-qwix1 17800bf
PiperOrigin-RevId: 885270721
shuningjin (Collaborator, Author) commented

cf051eb is merged

shuningjin closed this Mar 18, 2026
Shuwen-Fang pushed a commit that referenced this pull request Mar 25, 2026
COPYBARA_INTEGRATE_REVIEW=#3405 from AI-Hypercomputer:shuningjin-qwix1 17800bf
PiperOrigin-RevId: 885270721